""" Pokemon classifier — comparison of three models: 1. Custom fine-tuned ViT (transfer learning, your HF model) 2. CLIP zero-shot (open-source: openai/clip-vit-large-patch14) 3. OpenAI GPT-4o vision (closed-source) """ import base64 import io import os import gradio as gr from PIL import Image from transformers import pipeline # ----- Config ----- CUSTOM_MODEL_ID = "dubattim/pokemon-vit-fs26" POKEMON_LABELS = ["charizard", "charmander", "charmeleon", "ditto", "eevee", "ekans"] # ----- Models ----- custom_classifier = pipeline("image-classification", model=CUSTOM_MODEL_ID) clip_classifier = pipeline( task="zero-shot-image-classification", model="openai/clip-vit-large-patch14", ) # OpenAI client (lazy import so app still loads if package missing locally) try: from openai import OpenAI openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) except Exception: openai_client = None def _encode_image(path: str) -> str: with Image.open(path) as img: img = img.convert("RGB") buf = io.BytesIO() img.save(buf, format="PNG") return base64.b64encode(buf.getvalue()).decode("utf-8") def classify_with_openai(image_path: str) -> dict: if openai_client is None or not os.environ.get("OPENAI_API_KEY"): return {"error": "OPENAI_API_KEY not configured"} b64 = _encode_image(image_path) prompt = ( "Classify the Pokemon in this image. Respond with EXACTLY one of these " f"labels and nothing else: {', '.join(POKEMON_LABELS)}." ) resp = openai_client.chat.completions.create( model="gpt-4o", messages=[ { "role": "user", "content": [ {"type": "text", "text": prompt}, { "type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}, }, ], } ], max_tokens=20, ) answer = resp.choices[0].message.content.strip().lower() return {label: (1.0 if label in answer else 0.0) for label in POKEMON_LABELS} def classify(image_path): if image_path is None: return {}, {}, {} # 1. Custom ViT custom = {r["label"]: float(r["score"]) for r in custom_classifier(image_path)} # 2. CLIP zero-shot clip = { r["label"]: float(r["score"]) for r in clip_classifier(image_path, candidate_labels=POKEMON_LABELS) } # 3. OpenAI GPT-4o openai_out = classify_with_openai(image_path) return custom, clip, openai_out examples = [ ["example_images/charizard.png"], ["example_images/charmander.png"], ["example_images/charmeleon.png"], ["example_images/ditto.png"], ["example_images/eevee.png"], ["example_images/ekans.png"], ] with gr.Blocks(title="Pokemon Classifier — Model Comparison") as demo: gr.Markdown( "# Pokemon Classifier — Model Comparison\n" "Upload a Pokemon image and compare three approaches: a fine-tuned ViT " "(transfer learning), CLIP zero-shot, and OpenAI GPT-4o vision.\n\n" f"**Classes:** {', '.join(POKEMON_LABELS)}" ) with gr.Row(): inp = gr.Image(type="filepath", label="Input image") with gr.Column(): out_custom = gr.Label(label="Custom ViT (fine-tuned)", num_top_classes=6) out_clip = gr.Label(label="CLIP zero-shot", num_top_classes=6) out_openai = gr.Label(label="OpenAI GPT-4o", num_top_classes=6) btn = gr.Button("Classify", variant="primary") btn.click(classify, inputs=inp, outputs=[out_custom, out_clip, out_openai]) gr.Examples(examples=examples, inputs=inp) if __name__ == "__main__": demo.launch()