wueesnin committed on
Commit
4194462
·
verified ·
1 Parent(s): 2fa315d

Update app.py

Browse files

Updated app.py function

Files changed (1) hide show
  1. app.py +280 -0
app.py CHANGED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Dict, Tuple
3
+
4
+ import gradio as gr
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import (
8
+ AutoImageProcessor,
9
+ AutoModelForImageClassification,
10
+ CLIPModel,
11
+ CLIPProcessor,
12
+ )
13
+
14
+ # Optional OpenAI client. The app still works without it.
15
+ try:
16
+ from openai import OpenAI
17
+ except Exception:
18
+ OpenAI = None
19
+
20
+
21
+ # =========================================================
22
+ # Configuration
23
+ # =========================================================
24
+ # Replace these labels with your final dataset classes.
25
+ CLASS_LABELS: List[str] = [
26
+ "sphynx",
27
+ "russian blue",
28
+ "maine coon",
29
+ "ragdoll",
30
+ "bengal",
31
+ "singapura",
32
+ "calico cat"
33
+ ]
34
+
35
+ # Your fine-tuned Hugging Face image classification model.
36
+ # Example: "your-username/cat-breed-vit"
37
+ CUSTOM_MODEL_ID = os.getenv("CUSTOM_MODEL_ID", "your-username/your-model-name")
38
+
39
+ # Open-source comparison model.
40
+ CLIP_MODEL_ID = os.getenv("CLIP_MODEL_ID", "openai/clip-vit-base-patch32")
41
+
42
+ # Example images shown in Gradio. Add real files before deployment.
43
+ EXAMPLE_IMAGES = [
44
+ ["example_images/sphynx.jpg"],
45
+ ["example_images/russian-blue.jpg"],
46
+ ["example_images/maine-coon.jpg"],
47
+ ["example_images/ragdoll.jpg"],
48
+ ["example_images/bengal.jpg"],
49
+ ["example_images/singapura.jpg"],
50
+ ["example_images/calico.jpg"],
51
+ ]
52
+
53
+
54
+ # =========================================================
55
+ # Model loading
56
+ # =========================================================
57
+ device = "cuda" if torch.cuda.is_available() else "cpu"
58
+
59
+ custom_processor = None
60
+ custom_model = None
61
+ custom_model_error = None
62
+
63
+ clip_processor = None
64
+ clip_model = None
65
+ clip_model_error = None
66
+
67
+
68
def load_custom_model() -> None:
    """Load the fine-tuned image classifier named by ``CUSTOM_MODEL_ID``.

    On success the module globals ``custom_processor`` and ``custom_model``
    hold the processor and the model (moved to ``device`` and put in eval
    mode) and ``custom_model_error`` is ``None``. On any failure both objects
    are reset to ``None`` and ``custom_model_error`` holds the exception text,
    so callers never observe a half-initialised processor/model pair.
    """
    global custom_processor, custom_model, custom_model_error
    try:
        # Build everything in locals first; the globals are published only
        # after every step, including the device transfer, has succeeded.
        processor = AutoImageProcessor.from_pretrained(CUSTOM_MODEL_ID)
        model = AutoModelForImageClassification.from_pretrained(CUSTOM_MODEL_ID)
        model.to(device)
        model.eval()
    except Exception as exc:
        # Deliberate best-effort: the error text is kept so it can be
        # reported later instead of crashing at import time.
        custom_processor = None
        custom_model = None
        custom_model_error = str(exc)
    else:
        custom_processor = processor
        custom_model = model
        custom_model_error = None
77
+
78
+
79
+
80
def load_clip_model() -> None:
    """Load the CLIP comparison model named by ``CLIP_MODEL_ID``.

    On success the module globals ``clip_processor`` and ``clip_model`` hold
    the processor and the model (moved to ``device`` and put in eval mode) and
    ``clip_model_error`` is ``None``. On any failure both objects are reset to
    ``None`` and ``clip_model_error`` holds the exception text, so callers
    never observe a half-initialised processor/model pair.
    """
    global clip_processor, clip_model, clip_model_error
    try:
        # Build everything in locals first; the globals are published only
        # after every step, including the device transfer, has succeeded.
        processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
        model = CLIPModel.from_pretrained(CLIP_MODEL_ID)
        model.to(device)
        model.eval()
    except Exception as exc:
        # Deliberate best-effort: the error text is kept so it can be
        # reported later instead of crashing at import time.
        clip_processor = None
        clip_model = None
        clip_model_error = str(exc)
    else:
        clip_processor = processor
        clip_model = model
        clip_model_error = None
89
+
90
+
91
+ load_custom_model()
92
+ load_clip_model()
93
+
94
+
95
+ # =========================================================
96
+ # Helpers
97
+ # =========================================================
98
def ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode.

    An image whose mode is already ``"RGB"`` is returned as-is; any other
    mode goes through ``image.convert("RGB")``.
    """
    return image if image.mode == "RGB" else image.convert("RGB")
102
+
103
+
104
+
105
def format_topk(predictions: List[Tuple[str, float]]) -> str:
    """Render ranked predictions as a numbered, newline-separated list.

    Args:
        predictions: ``(label, score)`` pairs, already in display order.

    Returns:
        One line per pair in the form ``"<rank>. <label> (<score>)"``, with
        ranks starting at 1 and scores formatted to four decimals. An empty
        input yields an empty string.
    """
    return "\n".join(
        f"{rank}. {label} ({score:.4f})"
        for rank, (label, score) in enumerate(predictions, start=1)
    )
110
+
111
+
112
+
113
def predict_custom_model(image: Image.Image, top_k: int = 3) -> Tuple[str, Dict[str, float]]:
    """Classify ``image`` with the fine-tuned model and report its top classes.

    Args:
        image: Input picture; converted to RGB before preprocessing.
        top_k: Maximum number of classes to report. Values above the number
            of model outputs are clamped to it; values below zero are treated
            as zero.

    Returns:
        ``(text, scores)``. ``text`` is the ranked list built by
        ``format_topk`` and ``scores`` maps each reported label to its softmax
        probability. When the model or its processor is not loaded, ``text``
        is an error message and ``scores`` is empty.
    """
    if custom_model is None or custom_processor is None:
        message = (
            "Custom model could not be loaded.\n\n"
            f"Model ID: {CUSTOM_MODEL_ID}\n"
            f"Error: {custom_model_error}"
        )
        return message, {}

    image = ensure_rgb(image)
    inputs = custom_processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = custom_model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)[0]

    # Clamp on both sides: torch.topk rejects a negative k and a k larger
    # than the number of classes.
    k = max(0, min(top_k, probs.shape[0]))
    top = torch.topk(probs, k=k)

    id2label = custom_model.config.id2label
    # One bulk conversion for indices and values instead of a separate
    # ``.item()`` call per reported class.
    top_preds = [
        (id2label.get(idx, str(idx)), score)
        for idx, score in zip(top.indices.tolist(), top.values.tolist())
    ]

    return format_topk(top_preds), dict(top_preds)
142
+
143
+
144
+
145
def predict_clip(image: Image.Image, class_labels: List[str], top_k: int = 3) -> Tuple[str, Dict[str, float]]:
    """Score ``image`` against ``class_labels`` with zero-shot CLIP.

    Each label is turned into the prompt ``"a photo of a <label>"`` and the
    image-to-text logits are converted to probabilities with a softmax over
    the labels.

    Args:
        image: Input picture; converted to RGB before preprocessing.
        class_labels: Candidate labels to rank.
        top_k: Maximum number of labels in the text summary; values below
            zero are treated as zero.

    Returns:
        ``(text, scores)``. ``text`` is the ranked top-``top_k`` list built by
        ``format_topk``; ``scores`` maps every label to its probability,
        ordered from most to least likely. When CLIP is not loaded or no
        labels are given, ``text`` is a message and ``scores`` is empty.
    """
    if clip_model is None or clip_processor is None:
        message = (
            "CLIP model could not be loaded.\n\n"
            f"Model ID: {CLIP_MODEL_ID}\n"
            f"Error: {clip_model_error}"
        )
        return message, {}

    if not class_labels:
        # Nothing to rank; avoid handing an empty prompt list to the processor.
        return "No class labels provided.", {}

    image = ensure_rgb(image)
    prompts = [f"a photo of a {label}" for label in class_labels]

    inputs = clip_processor(text=prompts, images=image, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = clip_model(**inputs)
        probs = torch.softmax(outputs.logits_per_image[0], dim=-1)

    # One bulk conversion instead of a separate ``.item()`` call per label.
    ranked = sorted(
        zip(class_labels, probs.tolist()),
        key=lambda pair: pair[1],
        reverse=True,
    )
    top_preds = ranked[: max(top_k, 0)]

    return format_topk(top_preds), dict(ranked)
171
+
172
+
173
+
174
def predict_openai(image: Image.Image, class_labels: List[str]) -> str:
    """Ask an OpenAI vision model to pick one label for ``image``.

    The image is JPEG-encoded and sent inline as a base64 data URL, so no
    file is left behind in the OpenAI account after the request. The model id
    defaults to ``gpt-4.1-mini`` and can be overridden with the
    ``OPENAI_MODEL_ID`` environment variable.

    Args:
        image: Input picture; converted to RGB before JPEG encoding.
        class_labels: Label set the model is told to choose from.

    Returns:
        The model's stripped text answer, or a human-readable message when
        the ``openai`` package is missing, ``OPENAI_API_KEY`` is unset, or the
        request fails. This function does not raise for request errors.
    """
    if OpenAI is None:
        return "OpenAI package is not installed. Add `openai` to requirements.txt."

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "OPENAI_API_KEY is not set. The app can still run without the OpenAI comparison."

    try:
        import base64
        import io

        client = OpenAI(api_key=api_key)

        # Encode the image in memory and embed it in the request itself.
        buffer = io.BytesIO()
        ensure_rgb(image).save(buffer, format="JPEG")
        encoded = base64.b64encode(buffer.getvalue()).decode("ascii")

        prompt = (
            "You are an image classifier. "
            "Choose exactly one label from this label set: "
            f"{', '.join(class_labels)}. "
            "Return a short answer with this structure only: "
            # A real line break separates the two fields of the template.
            "label: <chosen label>\nreason: <very short reason>."
        )

        response = client.responses.create(
            model=os.getenv("OPENAI_MODEL_ID", "gpt-4.1-mini"),
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {
                            "type": "input_image",
                            "image_url": f"data:image/jpeg;base64,{encoded}",
                        },
                    ],
                }
            ],
        )
        return response.output_text.strip()
    except Exception as exc:
        # Best-effort comparison: surface the failure in the UI text box.
        return f"OpenAI prediction failed: {exc}"
217
+
218
+
219
+
220
def compare_models(image: Image.Image) -> Tuple[str, Dict[str, float], str, Dict[str, float], str]:
    """Run the three classifiers on ``image`` and bundle their outputs.

    Returns a 5-tuple in this order: custom-model text, custom-model scores,
    CLIP text, CLIP scores, OpenAI text. When ``image`` is ``None`` each text
    slot carries an upload reminder and both score dicts are empty.
    """
    if image is None:
        reminder = "Please upload an image."
        return reminder, {}, reminder, {}, reminder

    # Same call order as the output order: custom model, CLIP, then OpenAI.
    custom_result = predict_custom_model(image)
    clip_result = predict_clip(image, CLASS_LABELS)
    openai_result = predict_openai(image, CLASS_LABELS)

    return (*custom_result, *clip_result, openai_result)
229
+
230
+
231
+ # =========================================================
232
+ # UI
233
+ # =========================================================
234
# Markdown shown under the page title.
DESCRIPTION = """
Upload an image and compare three approaches:
1. Fine-tuned transfer learning model
2. Zero-shot CLIP
3. OpenAI vision model

This version focuses only on cat breed classification.
"""

# Offer only the example images that exist on disk. EXAMPLE_IMAGES holds
# placeholder paths, and a missing file may make gr.Examples fail while the
# UI is being built (NOTE(review): confirm against the pinned Gradio version).
_available_examples = [row for row in EXAMPLE_IMAGES if os.path.isfile(row[0])]

with gr.Blocks() as demo:
    gr.Markdown("# Cat Breed Classifier")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload image")

    run_btn = gr.Button("Run comparison")

    # One column per approach, side by side.
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Fine-tuned model")
            custom_text = gr.Textbox(label="Top predictions", lines=6)
            custom_plot = gr.Label(label="Scores")

        with gr.Column():
            gr.Markdown("## CLIP zero-shot")
            clip_text = gr.Textbox(label="Top predictions", lines=6)
            clip_plot = gr.Label(label="Scores")

        with gr.Column():
            gr.Markdown("## OpenAI vision")
            openai_text = gr.Textbox(label="Prediction", lines=6)

    # Output order must match the 5-tuple returned by compare_models.
    run_btn.click(
        fn=compare_models,
        inputs=image_input,
        outputs=[custom_text, custom_plot, clip_text, clip_plot, openai_text],
    )

    # Skip the examples widget entirely when no example file is available.
    if _available_examples:
        gr.Examples(
            examples=_available_examples,
            inputs=image_input,
            label="Example images",
        )

if __name__ == "__main__":
    demo.launch()