Sbwg committed on
Commit
4d15a8e
·
verified ·
1 Parent(s): 90c7d1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -15
app.py CHANGED
@@ -1,25 +1,30 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
3
 
4
- # Load a simple image classification model
5
- classifier = pipeline("image-classification", model="google/vit-large-patch16-224-in21k")
 
 
 
6
 
7
- def classify_image(image):
8
- # Run prediction
9
- results = classifier(image)
10
- # Get top label and confidence
11
- top_result = results[0]
12
- label = top_result["label"]
13
- score = round(top_result["score"] * 100, 2)
14
- return f"{label} ({score}%)"
 
 
15
 
16
- # Simple Gradio UI
17
  demo = gr.Interface(
18
- fn=classify_image,
19
  inputs=gr.Image(type="pil"),
20
  outputs="text",
21
- title="🖼️ Simple Image Classifier",
22
- description="what is this food"
23
  )
24
 
25
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
3
+ from PIL import Image
4
 
5
# Load the image-captioning model once at import time so every request
# reuses the same weights: a ViT encoder paired with a GPT-2 decoder.
model_id = "nlpconnect/vit-gpt2-image-captioning"
# Encoder-decoder model that maps pixel values to caption token ids.
model = VisionEncoderDecoderModel.from_pretrained(model_id)
# Preprocessor that resizes/normalizes PIL images into model tensors.
feature_extractor = ViTImageProcessor.from_pretrained(model_id)
# Tokenizer used to decode generated token ids back into text.
tokenizer = AutoTokenizer.from_pretrained(model_id)
10
 
11
def classify_better(image):
    """Generate a short English caption for an uploaded image.

    Args:
        image: A ``PIL.Image.Image`` from the Gradio UI, or ``None`` when
            the user submits without uploading anything.

    Returns:
        The decoded caption string, or a prompt to upload an image when
        no image was provided.
    """
    # Gradio passes None if Submit is clicked with no image uploaded;
    # without this guard, image.mode below raises AttributeError.
    if image is None:
        return "Please upload an image first."
    # The ViT feature extractor expects 3-channel RGB input
    # (uploads may be RGBA, grayscale, palette, etc.).
    if image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

    # Beam search produces more fluent captions than greedy decoding;
    # max_length keeps the caption short and generation fast.
    output_ids = model.generate(pixel_values, max_length=20, num_beams=5)
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    return caption
21
 
 
22
# Wire the captioning function into a minimal Gradio UI:
# one image upload in, one text caption out.
demo = gr.Interface(
    fn=classify_better,
    # type="pil" makes Gradio hand the upload to us as a PIL image,
    # which is what the feature extractor expects.
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Better Image Captioning",
    description="Upload an image and the model will try to describe it (better)."
)
29
 
30
  if __name__ == "__main__":