wueesnin commited on
Commit
4011ff2
·
verified ·
1 Parent(s): 80987aa

Updated app to anime images

Browse files
Files changed (1) hide show
  1. app.py +119 -70
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import List, Dict, Tuple
3
 
4
  import gradio as gr
5
  import torch
@@ -11,7 +11,6 @@ from transformers import (
11
  CLIPProcessor,
12
  )
13
 
14
- # Optional OpenAI client. The app still works without it.
15
  try:
16
  from openai import OpenAI
17
  except Exception:
@@ -21,35 +20,53 @@ except Exception:
21
  # =========================================================
22
  # Configuration
23
  # =========================================================
24
- # Replace these labels with your final dataset classes.
25
  CLASS_LABELS: List[str] = [
26
- "sphynx",
27
- "russian blue",
28
- "maine coon",
29
- "ragdoll",
30
- "bengal",
31
- "singapura",
32
- "calico cat"
 
 
33
  ]
34
 
35
  # Your fine-tuned Hugging Face image classification model.
36
- # Example: "your-username/cat-vs-wild-animal-vit"
37
- CUSTOM_MODEL_ID = os.getenv("CUSTOM_MODEL_ID", "your-username/your-model-name")
38
 
39
- # Open-source comparison model.
40
  CLIP_MODEL_ID = os.getenv("CLIP_MODEL_ID", "openai/clip-vit-base-patch32")
 
41
 
42
- # Example images shown in Gradio. Add real files before deployment.
43
  EXAMPLE_IMAGES = [
44
- ["example_images/sphynx.jpg"],
45
- ["example_images/russian-blue.jpg"],
46
- ["example_images/maine-coon.jpg"],
47
- ["example_images/ragdoll.jpg"],
48
- ["example_images/bengal.jpg"],
49
- ["example_images/singapura.jpg"],
50
- ["example_images/calico.jpg"],
 
 
 
 
51
  ]
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # =========================================================
55
  # Model loading
@@ -97,27 +114,44 @@ load_clip_model()
97
  # =========================================================
98
  def ensure_rgb(image: Image.Image) -> Image.Image:
99
  if image.mode != "RGB":
100
- image = image.convert("RGB")
101
  return image
102
 
103
 
104
 
105
  def format_topk(predictions: List[Tuple[str, float]]) -> str:
106
- lines = []
107
- for rank, (label, score) in enumerate(predictions, start=1):
108
- lines.append(f"{rank}. {label} ({score:.4f})")
109
- return "\n".join(lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
 
112
 
113
  def predict_custom_model(image: Image.Image, top_k: int = 3) -> Tuple[str, Dict[str, float]]:
114
  if custom_model is None or custom_processor is None:
115
- message = (
116
- "Custom model could not be loaded.\n\n"
117
- f"Model ID: {CUSTOM_MODEL_ID}\n"
118
- f"Error: {custom_model_error}"
 
 
 
 
119
  )
120
- return message, {}
121
 
122
  image = ensure_rgb(image)
123
  inputs = custom_processor(images=image, return_tensors="pt")
@@ -127,33 +161,36 @@ def predict_custom_model(image: Image.Image, top_k: int = 3) -> Tuple[str, Dict[
127
  outputs = custom_model(**inputs)
128
  probs = torch.softmax(outputs.logits, dim=-1)[0]
129
 
130
- id2label = custom_model.config.id2label
131
  top_indices = torch.topk(probs, k=min(top_k, probs.shape[0])).indices.tolist()
132
 
133
- top_preds = []
134
- label_scores = {}
135
  for idx in top_indices:
136
- label = id2label.get(idx, str(idx))
137
- score = probs[idx].item()
 
138
  top_preds.append((label, score))
139
- label_scores[label] = score
140
 
141
- return format_topk(top_preds), label_scores
142
 
143
 
144
 
145
  def predict_clip(image: Image.Image, class_labels: List[str], top_k: int = 3) -> Tuple[str, Dict[str, float]]:
146
  if clip_model is None or clip_processor is None:
147
- message = (
148
- "CLIP model could not be loaded.\n\n"
149
- f"Model ID: {CLIP_MODEL_ID}\n"
150
- f"Error: {clip_model_error}"
 
 
 
 
151
  )
152
- return message, {}
153
 
154
  image = ensure_rgb(image)
155
- prompts = [f"a photo of a {label}" for label in class_labels]
156
-
157
  inputs = clip_processor(text=prompts, images=image, return_tensors="pt", padding=True)
158
  inputs = {k: v.to(device) for k, v in inputs.items()}
159
 
@@ -162,12 +199,12 @@ def predict_clip(image: Image.Image, class_labels: List[str], top_k: int = 3) ->
162
  logits = outputs.logits_per_image[0]
163
  probs = torch.softmax(logits, dim=-1)
164
 
165
- pairs = [(label, probs[i].item()) for i, label in enumerate(class_labels)]
166
  pairs.sort(key=lambda x: x[1], reverse=True)
167
  top_preds = pairs[:top_k]
168
- label_scores = {label: score for label, score in pairs}
169
 
170
- return format_topk(top_preds), label_scores
171
 
172
 
173
 
@@ -177,36 +214,51 @@ def predict_openai(image: Image.Image, class_labels: List[str]) -> str:
177
 
178
  api_key = os.getenv("OPENAI_API_KEY")
179
  if not api_key:
180
- return "OPENAI_API_KEY is not set. The app can still run without the OpenAI comparison."
181
 
182
  try:
183
- client = OpenAI(api_key=api_key)
184
-
185
- # Convert image to bytes for upload.
186
  import io
187
 
188
  buffer = io.BytesIO()
189
  ensure_rgb(image).save(buffer, format="JPEG")
190
- buffer.seek(0)
191
 
192
- uploaded = client.files.create(file=("image.jpg", buffer.getvalue(), "image/jpeg"), purpose="vision")
 
 
 
 
 
193
 
194
  prompt = (
195
- "You are an image classifier. "
196
- "Choose exactly one label from this label set: "
197
- f"{', '.join(class_labels)}. "
198
- "Return a short answer with this structure only: "
199
- "label: <chosen label>\\nreason: <very short reason>."
 
 
 
 
 
 
 
 
 
200
  )
201
 
202
  response = client.responses.create(
203
- model="gpt-4.1-mini",
204
  input=[
205
  {
206
  "role": "user",
207
  "content": [
208
  {"type": "input_text", "text": prompt},
209
- {"type": "input_image", "file_id": uploaded.id},
 
 
 
210
  ],
211
  }
212
  ],
@@ -217,31 +269,28 @@ def predict_openai(image: Image.Image, class_labels: List[str]) -> str:
217
 
218
 
219
 
220
- def compare_models(image: Image.Image) -> Tuple[str, Dict[str, float], str, Dict[str, float], str]:
221
  if image is None:
222
- return "Please upload an image.", {}, "Please upload an image.", {}, "Please upload an image."
 
223
 
224
  custom_text, custom_scores = predict_custom_model(image)
225
  clip_text, clip_scores = predict_clip(image, CLASS_LABELS)
226
  openai_text = predict_openai(image, CLASS_LABELS)
227
-
228
  return custom_text, custom_scores, clip_text, clip_scores, openai_text
229
 
230
 
231
- # =========================================================
232
- # UI
233
- # =========================================================
234
  DESCRIPTION = """
235
- Upload an image and compare three approaches:
236
  1. Fine-tuned transfer learning model
237
  2. Zero-shot CLIP
238
  3. OpenAI vision model
239
 
240
- This version focuses only on cat breed classification.
241
  """
242
 
243
  with gr.Blocks() as demo:
244
- gr.Markdown("# Cat Breed Classifier")
245
  gr.Markdown(DESCRIPTION)
246
 
247
  with gr.Row():
 
1
  import os
2
+ from typing import Dict, List, Tuple
3
 
4
  import gradio as gr
5
  import torch
 
11
  CLIPProcessor,
12
  )
13
 
 
14
  try:
15
  from openai import OpenAI
16
  except Exception:
 
20
  # =========================================================
21
  # Configuration
22
  # =========================================================
 
23
  CLASS_LABELS: List[str] = [
24
+ "cherry",
25
+ "sakura",
26
+ "naruto",
27
+ "eren",
28
+ "kirito",
29
+ "doraemon",
30
+ "asuna",
31
+ "totoro",
32
+ "chihiro",
33
  ]
34
 
35
  # Your fine-tuned Hugging Face image classification model.
36
+ CUSTOM_MODEL_ID = os.getenv("CUSTOM_MODEL_ID", "wueesnin/image_comparison")
 
37
 
38
+ # Open-source comparison model (openai)
39
  CLIP_MODEL_ID = os.getenv("CLIP_MODEL_ID", "openai/clip-vit-base-patch32")
40
+ OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
41
 
42
+ # Example anime images :3
43
  EXAMPLE_IMAGES = [
44
+ ["example_images/eren.JPG"],
45
+ ["example_images/mikasa.JPG"],
46
+ ["example_images/naruto.webp"],
47
+ ["example_images/sakura.webp"],
48
+ ["example_images/cherry.webp"],
49
+ ["example_images/kirito.webp"],
50
+ ["example_images/doraemon.webp"],
51
+ ["example_images/luffy.webp"],
52
+ ["example_images/asuna.webp"],
53
+ ["example_images/totoro.webp"],
54
+ ["example_images/chihiro.webp"],
55
  ]
56
 
57
+ # Better prompt wording for CLIP / OpenAI.
58
+ LABEL_DESCRIPTIONS: Dict[str, str] = {
59
+ "eren": "Eren Yeager from Attack on Titan",
60
+ "mikasa": "Mikasa Akermann from Attack on Titan",
61
+ "totoro": "Totoro from My Neighbor Totoro",
62
+ "sakura": "Sakura Haruno from Naruto",
63
+ "naruto": "Naruto Uzumaki from Naruto",
64
+ "cherry": "Cherry Magic",
65
+ "kirito": "Kirito from Sword Art Online",
66
+ "doraemon": "Doraemon",
67
+ "asuna": "Asuna Yuuki from Sword Art Online",
68
+ "chihiro": "Chihiro Ogino from Spirited Away",
69
+ }
70
 
71
  # =========================================================
72
  # Model loading
 
114
  # =========================================================
115
  def ensure_rgb(image: Image.Image) -> Image.Image:
116
  if image.mode != "RGB":
117
+ return image.convert("RGB")
118
  return image
119
 
120
 
121
 
122
  def format_topk(predictions: List[Tuple[str, float]]) -> str:
123
+ return "
124
+ ".join(
125
+ f"{rank}. {label} ({score:.4f})"
126
+ for rank, (label, score) in enumerate(predictions, start=1)
127
+ )
128
+
129
+
130
+
131
+ def normalize_model_label(label: str) -> str:
132
+ return str(label).strip().lower().replace("_", " ")
133
+
134
+
135
+
136
+ def build_clip_prompts(class_labels: List[str]) -> List[str]:
137
+ return [
138
+ f"anime character, {LABEL_DESCRIPTIONS.get(label, label)}"
139
+ for label in class_labels
140
+ ]
141
 
142
 
143
 
144
  def predict_custom_model(image: Image.Image, top_k: int = 3) -> Tuple[str, Dict[str, float]]:
145
  if custom_model is None or custom_processor is None:
146
+ return (
147
+ "Custom model could not be loaded.
148
+
149
+ "
150
+ f"Model ID: {CUSTOM_MODEL_ID}
151
+ "
152
+ f"Error: {custom_model_error}",
153
+ {},
154
  )
 
155
 
156
  image = ensure_rgb(image)
157
  inputs = custom_processor(images=image, return_tensors="pt")
 
161
  outputs = custom_model(**inputs)
162
  probs = torch.softmax(outputs.logits, dim=-1)[0]
163
 
164
+ id2label = getattr(custom_model.config, "id2label", {})
165
  top_indices = torch.topk(probs, k=min(top_k, probs.shape[0])).indices.tolist()
166
 
167
+ top_preds: List[Tuple[str, float]] = []
168
+ score_map: Dict[str, float] = {}
169
  for idx in top_indices:
170
+ raw_label = id2label.get(idx, str(idx))
171
+ label = normalize_model_label(raw_label)
172
+ score = float(probs[idx].item())
173
  top_preds.append((label, score))
174
+ score_map[label] = score
175
 
176
+ return format_topk(top_preds), score_map
177
 
178
 
179
 
180
  def predict_clip(image: Image.Image, class_labels: List[str], top_k: int = 3) -> Tuple[str, Dict[str, float]]:
181
  if clip_model is None or clip_processor is None:
182
+ return (
183
+ "CLIP model could not be loaded.
184
+
185
+ "
186
+ f"Model ID: {CLIP_MODEL_ID}
187
+ "
188
+ f"Error: {clip_model_error}",
189
+ {},
190
  )
 
191
 
192
  image = ensure_rgb(image)
193
+ prompts = build_clip_prompts(class_labels)
 
194
  inputs = clip_processor(text=prompts, images=image, return_tensors="pt", padding=True)
195
  inputs = {k: v.to(device) for k, v in inputs.items()}
196
 
 
199
  logits = outputs.logits_per_image[0]
200
  probs = torch.softmax(logits, dim=-1)
201
 
202
+ pairs = [(class_labels[i], float(probs[i].item())) for i in range(len(class_labels))]
203
  pairs.sort(key=lambda x: x[1], reverse=True)
204
  top_preds = pairs[:top_k]
205
+ score_map = {label: score for label, score in pairs}
206
 
207
+ return format_topk(top_preds), score_map
208
 
209
 
210
 
 
214
 
215
  api_key = os.getenv("OPENAI_API_KEY")
216
  if not api_key:
217
+ return "OPENAI_API_KEY is not set. The app still works for the custom model and CLIP."
218
 
219
  try:
220
+ import base64
 
 
221
  import io
222
 
223
  buffer = io.BytesIO()
224
  ensure_rgb(image).save(buffer, format="JPEG")
225
+ encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
226
 
227
+ client = OpenAI(api_key=api_key)
228
+ allowed_labels = ", ".join(class_labels)
229
+ descriptions = "
230
+ ".join(
231
+ f"- {label}: {LABEL_DESCRIPTIONS.get(label, label)}" for label in class_labels
232
+ )
233
 
234
  prompt = (
235
+ "Classify this anime image. Choose exactly one label from this list: "
236
+ f"{allowed_labels}.
237
+
238
+ "
239
+ "Label meanings:
240
+ "
241
+ f"{descriptions}
242
+
243
+ "
244
+ "Return exactly this format:
245
+ "
246
+ "label: <one label from the list>
247
+ "
248
+ "reason: <short reason>"
249
  )
250
 
251
  response = client.responses.create(
252
+ model=OPENAI_MODEL,
253
  input=[
254
  {
255
  "role": "user",
256
  "content": [
257
  {"type": "input_text", "text": prompt},
258
+ {
259
+ "type": "input_image",
260
+ "image_url": f"data:image/jpeg;base64,{encoded}",
261
+ },
262
  ],
263
  }
264
  ],
 
269
 
270
 
271
 
272
+ def compare_models(image: Image.Image):
273
  if image is None:
274
+ msg = "Please upload or select an example image."
275
+ return msg, {}, msg, {}, msg
276
 
277
  custom_text, custom_scores = predict_custom_model(image)
278
  clip_text, clip_scores = predict_clip(image, CLASS_LABELS)
279
  openai_text = predict_openai(image, CLASS_LABELS)
 
280
  return custom_text, custom_scores, clip_text, clip_scores, openai_text
281
 
282
 
 
 
 
283
  DESCRIPTION = """
284
+ Upload an anime image and compare three approaches:
285
  1. Fine-tuned transfer learning model
286
  2. Zero-shot CLIP
287
  3. OpenAI vision model
288
 
289
+ This version uses 9 fixed character labels.
290
  """
291
 
292
  with gr.Blocks() as demo:
293
+ gr.Markdown("# Anime Character Classifier")
294
  gr.Markdown(DESCRIPTION)
295
 
296
  with gr.Row():