Spaces:
Sleeping
Sleeping
| """ | |
| Pokemon classifier — comparison of three models: | |
| 1. Custom fine-tuned ViT (transfer learning, your HF model) | |
| 2. CLIP zero-shot (open-source: openai/clip-vit-large-patch14) | |
| 3. OpenAI GPT-4o vision (closed-source) | |
| """ | |
| import base64 | |
| import io | |
| import os | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import pipeline | |
# ----- Config -----
# Model id handed to the image-classification pipeline for the fine-tuned ViT.
CUSTOM_MODEL_ID = "dubattim/pokemon-vit-fs26"

# Class names used as CLIP candidate labels, in the GPT-4o prompt, and in the
# UI description.
POKEMON_LABELS = [
    "charizard",
    "charmander",
    "charmeleon",
    "ditto",
    "eevee",
    "ekans",
]
# ----- Models -----
# Both pipelines are built once at import time and reused for every request.

# 1. Fine-tuned ViT identified by CUSTOM_MODEL_ID.
custom_classifier = pipeline(task="image-classification", model=CUSTOM_MODEL_ID)

# 2. CLIP, used zero-shot: candidate labels are supplied per call.
clip_classifier = pipeline(
    "zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)
# OpenAI client (optional). Start from "unavailable" and only replace it with
# a real client when both the import and the construction succeed, so the app
# still loads if the package is missing locally.
openai_client = None
try:
    from openai import OpenAI

    openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
except Exception:
    # Deliberate best-effort: any failure above leaves the client disabled.
    openai_client = None
def _encode_image(path: str) -> str:
    """Return the image at *path* as a base64 string of RGB PNG bytes.

    The image is opened with PIL, converted to RGB, re-encoded as PNG into an
    in-memory buffer, and the buffer contents are base64-encoded and decoded
    to a UTF-8 ``str``.
    """
    png_buffer = io.BytesIO()
    with Image.open(path) as source:
        source.convert("RGB").save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def classify_with_openai(image_path: str) -> dict:
    """Classify the image with OpenAI GPT-4o and score each known label.

    The image is re-encoded by ``_encode_image`` and sent as a PNG data URL
    together with a prompt asking for exactly one of ``POKEMON_LABELS``.

    Args:
        image_path: Filesystem path of the image to classify.

    Returns:
        ``{"error": "OPENAI_API_KEY not configured"}`` when no client was
        created or the ``OPENAI_API_KEY`` environment variable is unset.
        Otherwise a dict mapping every label in ``POKEMON_LABELS`` to ``1.0``
        if the label occurs in the lower-cased reply and ``0.0`` if not.

    Errors raised while encoding the image or calling the API are not caught
    here and propagate to the caller.
    """
    if openai_client is None or not os.environ.get("OPENAI_API_KEY"):
        return {"error": "OPENAI_API_KEY not configured"}

    b64 = _encode_image(image_path)
    prompt = (
        "Classify the Pokemon in this image. Respond with EXACTLY one of these "
        f"labels and nothing else: {', '.join(POKEMON_LABELS)}."
    )
    resp = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{b64}"},
                    },
                ],
            }
        ],
        # A single label is expected, so a short completion is enough.
        max_tokens=20,
    )

    # The message content is optional in the API response; treat a missing
    # reply as an empty answer (all scores 0.0) instead of failing on
    # `None.strip()`.
    content = resp.choices[0].message.content or ""
    answer = content.strip().lower()

    # Substring match tolerates surrounding punctuation or quotes in the reply.
    return {label: (1.0 if label in answer else 0.0) for label in POKEMON_LABELS}
def classify(image_path):
    """Run the three classifiers on one image and return their scores.

    Args:
        image_path: Filesystem path of the uploaded image, or ``None`` when
            no image has been provided.

    Returns:
        A 3-tuple ``(custom, clip, openai_out)``. ``custom`` and ``clip`` map
        label -> float score from the fine-tuned ViT and CLIP zero-shot
        pipelines; ``openai_out`` is whatever ``classify_with_openai``
        returns. All three are empty dicts when ``image_path`` is ``None``.
    """
    if image_path is None:
        return {}, {}, {}

    def _as_scores(predictions):
        # Pipelines yield [{"label": ..., "score": ...}, ...].
        return {p["label"]: float(p["score"]) for p in predictions}

    # 1. Custom ViT. Ask for as many classes as there are labels; the
    #    pipeline's default top_k would otherwise cut the result short of the
    #    six classes the UI displays.
    custom = _as_scores(custom_classifier(image_path, top_k=len(POKEMON_LABELS)))

    # 2. CLIP zero-shot over the same label set.
    clip = _as_scores(clip_classifier(image_path, candidate_labels=POKEMON_LABELS))

    # 3. OpenAI GPT-4o
    openai_out = classify_with_openai(image_path)

    return custom, clip, openai_out
# One single-element row per example image, in the shape gr.Examples is given
# below (one value per input component).
examples = [
    [image_file]
    for image_file in (
        "example_images/charizard.png",
        "example_images/charmander.png",
        "example_images/charmeleon.png",
        "example_images/ditto.png",
        "example_images/eevee.png",
        "example_images/ekans.png",
    )
]
# ----- UI -----
# Layout: input image on the left, the three model outputs stacked on the
# right, then the trigger button and the clickable examples.
with gr.Blocks(title="Pokemon Classifier — Model Comparison") as demo:
    gr.Markdown(
        "# Pokemon Classifier — Model Comparison\n"
        "Upload a Pokemon image and compare three approaches: a fine-tuned ViT "
        "(transfer learning), CLIP zero-shot, and OpenAI GPT-4o vision.\n\n"
        f"**Classes:** {', '.join(POKEMON_LABELS)}"
    )

    # Show every known class in each output; derived from the label list so
    # the count cannot drift from POKEMON_LABELS.
    n_classes = len(POKEMON_LABELS)

    with gr.Row():
        # type="filepath" hands `classify` a path string rather than pixels.
        inp = gr.Image(type="filepath", label="Input image")
        with gr.Column():
            out_custom = gr.Label(
                label="Custom ViT (fine-tuned)", num_top_classes=n_classes
            )
            out_clip = gr.Label(label="CLIP zero-shot", num_top_classes=n_classes)
            out_openai = gr.Label(label="OpenAI GPT-4o", num_top_classes=n_classes)

    btn = gr.Button("Classify", variant="primary")
    # `classify` returns one value per output component, in this order.
    btn.click(classify, inputs=inp, outputs=[out_custom, out_clip, out_openai])

    gr.Examples(examples=examples, inputs=inp)

# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()