dubattim's picture
initial app: pokemon classifier with 3-model comparison
6bea7d9
"""
Pokemon classifier — comparison of three models:
1. Custom fine-tuned ViT (transfer learning, your HF model)
2. CLIP zero-shot (open-source: openai/clip-vit-large-patch14)
3. OpenAI GPT-4o vision (closed-source)
"""
import base64
import io
import os
import gradio as gr
from PIL import Image
from transformers import pipeline
# ----- Config -----
# Model identifier handed to the image-classification pipeline below.
CUSTOM_MODEL_ID = "dubattim/pokemon-vit-fs26"
# Class names: used as CLIP's zero-shot candidate labels, listed in the GPT-4o
# prompt, and shown in the page header.
POKEMON_LABELS = [
    "charizard",
    "charmander",
    "charmeleon",
    "ditto",
    "eevee",
    "ekans",
]

# ----- Models -----
# Both pipelines are constructed once, when the module is imported, and are
# reused by every call to classify().
custom_classifier = pipeline(task="image-classification", model=CUSTOM_MODEL_ID)
clip_classifier = pipeline(
    "zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)
# Optional OpenAI client. The import and the construction are attempted
# together so that any failure (for example the package not being installed
# locally) leaves the app loadable with `openai_client = None`.
try:
    from openai import OpenAI

    openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
except Exception:  # deliberately broad: any setup failure disables this model
    openai_client = None
def _encode_image(path: str) -> str:
    """Return the image at *path* as a base64-encoded PNG string.

    The file is opened with Pillow, converted to RGB, re-encoded as PNG into
    an in-memory buffer, and the buffer's bytes are base64-encoded and
    decoded to a UTF-8 ``str``.
    """
    png_buffer = io.BytesIO()
    with Image.open(path) as source:
        source.convert("RGB").save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def classify_with_openai(image_path: str) -> dict:
    """Classify the Pokemon in an image with OpenAI GPT-4o vision.

    The image is sent as a base64 PNG data URL together with a prompt asking
    the model to answer with exactly one of ``POKEMON_LABELS``.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        ``{"error": "OPENAI_API_KEY not configured"}`` when no client or API
        key is available. Otherwise a dict mapping every label in
        ``POKEMON_LABELS`` to ``1.0`` if that label occurs in the model's
        lower-cased reply, else ``0.0``. All scores are ``0.0`` when the
        reply is missing, empty, or names none of the labels.

    Raises:
        Exceptions from image loading or from the OpenAI request are not
        caught here and propagate to the caller.
    """
    if openai_client is None or not os.environ.get("OPENAI_API_KEY"):
        # NOTE(review): this value is a string, unlike the float scores of the
        # success path -- confirm the consuming gr.Label renders it as intended.
        return {"error": "OPENAI_API_KEY not configured"}

    b64 = _encode_image(image_path)
    prompt = (
        "Classify the Pokemon in this image. Respond with EXACTLY one of these "
        f"labels and nothing else: {', '.join(POKEMON_LABELS)}."
    )
    resp = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{b64}"},
                    },
                ],
            }
        ],
        max_tokens=20,
    )

    # The reply text is nullable in the chat-completions response and the
    # choices list may be empty; treat both as an empty answer instead of
    # crashing on `.strip()` / `[0]`.
    choices = resp.choices
    content = choices[0].message.content if choices else None
    answer = (content or "").strip().lower()

    # Substring match so that trailing punctuation or extra words in the reply
    # do not prevent the label from being recognised.
    return {label: (1.0 if label in answer else 0.0) for label in POKEMON_LABELS}
def classify(image_path):
    """Run all three classifiers on one image.

    Args:
        image_path: Path of the image to classify, or ``None`` when no image
            was provided.

    Returns:
        A 3-tuple ``(custom, clip, openai)``. The first two are dicts mapping
        label to float score, the third is whatever ``classify_with_openai``
        returns. Three empty dicts are returned when ``image_path`` is
        ``None``.
    """
    if image_path is None:
        return {}, {}, {}

    # 1. Custom ViT
    vit_scores = {}
    for prediction in custom_classifier(image_path):
        name = prediction["label"]
        vit_scores[name] = float(prediction["score"])

    # 2. CLIP zero-shot
    clip_scores = {}
    clip_predictions = clip_classifier(image_path, candidate_labels=POKEMON_LABELS)
    for prediction in clip_predictions:
        name = prediction["label"]
        clip_scores[name] = float(prediction["score"])

    # 3. OpenAI GPT-4o
    gpt_scores = classify_with_openai(image_path)

    return vit_scores, clip_scores, gpt_scores
# Rows passed to gr.Examples: one single-element list per image path, matching
# the single image input component.
examples = [
    [image_file]
    for image_file in (
        "example_images/charizard.png",
        "example_images/charmander.png",
        "example_images/charmeleon.png",
        "example_images/ditto.png",
        "example_images/eevee.png",
        "example_images/ekans.png",
    )
]
# ----- UI -----
# Page layout: a header, then a row holding the input image next to a column
# with one Label per model, then the trigger button and the example picker.
# The em dash in the title and heading is written as a real "—" character
# (it was previously mis-encoded and rendered as garbage).
with gr.Blocks(title="Pokemon Classifier — Model Comparison") as demo:
    gr.Markdown(
        "# Pokemon Classifier — Model Comparison\n"
        "Upload a Pokemon image and compare three approaches: a fine-tuned ViT "
        "(transfer learning), CLIP zero-shot, and OpenAI GPT-4o vision.\n\n"
        f"**Classes:** {', '.join(POKEMON_LABELS)}"
    )
    with gr.Row():
        # type="filepath" makes the component hand classify() a path string.
        inp = gr.Image(type="filepath", label="Input image")
        with gr.Column():
            out_custom = gr.Label(label="Custom ViT (fine-tuned)", num_top_classes=6)
            out_clip = gr.Label(label="CLIP zero-shot", num_top_classes=6)
            out_openai = gr.Label(label="OpenAI GPT-4o", num_top_classes=6)
    btn = gr.Button("Classify", variant="primary")
    # classify() returns three values, routed to the three labels in order.
    btn.click(classify, inputs=inp, outputs=[out_custom, out_clip, out_openai])
    gr.Examples(examples=examples, inputs=inp)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()