Skorm committed on
Commit
2636156
·
verified ·
1 Parent(s): 25224a3

Update App to compare to Zero Shot model

Browse files
Files changed (2) hide show
  1. app.py +34 -7
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,11 +1,35 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
3
 
4
  classifier = pipeline("image-classification", model="Skorm/food11-vit")
5
 
6
- def classify_food(image):
7
- results = classifier(image)
8
- return {result["label"]: round(result["score"], 4) for result in results}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Example image paths
11
  examples = [
@@ -21,9 +45,12 @@ examples = [
21
  iface = gr.Interface(
22
  fn=classify_food,
23
  inputs=gr.Image(type="filepath"),
24
- outputs=gr.Label(num_top_classes=3),
25
- title="🍽️ Food Classification with ViT",
26
- description="Upload a food image to classify it into 1 of 11 food categories.",
 
 
 
27
  examples=examples
28
  )
29
 
 
1
  import gradio as gr
2
+ from transformers import pipeline, CLIPProcessor, CLIPModel
3
+ from PIL import Image
4
+ import torch
5
 
6
# Fine-tuned ViT classifier — model name suggests the Food-11 dataset
# (11 categories); TODO confirm against the model card.
classifier = pipeline("image-classification", model="Skorm/food11-vit")

# Zero-shot baseline: pretrained CLIP model plus its paired processor
# (the processor handles both the text prompts and the image tensors).
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# Candidate text labels for CLIP zero-shot classification. Order matters:
# classify_food zips this list against CLIP's probability vector.
clip_labels = [
    "bread", "dairy product", "dessert", "egg", "fried food",
    "meat", "noodles or pasta", "rice", "seafood", "soup", "vegetables or fruits"
]
17
+
18
def classify_food(image_path):
    """Classify a food image with both the fine-tuned ViT model and zero-shot CLIP.

    Args:
        image_path: Filesystem path to the uploaded image
            (Gradio supplies this via ``gr.Image(type="filepath")``).

    Returns:
        A 2-tuple of dicts mapping label -> probability rounded to 4 places:
        (fine-tuned ViT predictions, CLIP zero-shot predictions).
    """
    # Force RGB so RGBA/palette/grayscale uploads can't trip the CLIP processor.
    image = Image.open(image_path).convert("RGB")

    # ----- ViT prediction -----
    vit_results = classifier(image_path)
    vit_output = {result["label"]: round(result["score"], 4) for result in vit_results}

    # ----- CLIP zero-shot prediction -----
    inputs = clip_processor(text=clip_labels, images=image, return_tensors="pt", padding=True)
    # Inference only: no_grad skips autograd-graph construction, saving
    # memory and time on every request.
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # logits_per_image is (1, num_labels); softmax over labels, drop batch dim.
    probs = outputs.logits_per_image.softmax(dim=1)[0]

    clip_output = {label: round(float(score), 4) for label, score in zip(clip_labels, probs)}

    return vit_output, clip_output
33
 
34
  # Example image paths
35
  examples = [
 
45
# Two side-by-side label panels: the fine-tuned model vs. the zero-shot
# baseline, matching the (vit_output, clip_output) tuple classify_food returns.
output_panels = [
    gr.Label(num_top_classes=3, label="ViT (Fine-tuned) Prediction"),
    gr.Label(num_top_classes=3, label="CLIP Zero-Shot Prediction"),
]

iface = gr.Interface(
    fn=classify_food,
    inputs=gr.Image(type="filepath"),
    outputs=output_panels,
    title="🍽️ Food Classification with ViT and Zero-Shot CLIP",
    description="Upload a food image. The app compares predictions between your fine-tuned ViT model and zero-shot CLIP.",
    examples=examples,
)
56
 
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  transformers
2
- torch
 
 
 
1
  transformers
2
+ torch
3
+ gradio
4
+ Pillow