Spaces:
Sleeping
Sleeping
| import os | |
| import base64 | |
| from io import BytesIO | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification, CLIPModel, CLIPProcessor | |
| try: | |
| from openai import OpenAI | |
| except Exception: | |
| OpenAI = None | |
# Short class names shown in the results. predict_clip pairs these with the
# CLIP scores by position, so the order must match PROMPTS below.
LABELS = [
    "eren", "naruto", "totoro", "sakura",
    "mikasa", "luffy", "cherry", "kirito",
    "doraemon", "asuna", "chihiro",
]

# One single-element row per example image, in the shape gr.Examples expects
# for a single input component.
EXAMPLES = [
    ["example_images/" + filename]
    for filename in (
        "eren.JPG",
        "naruto.webp",
        "totoro.webp",
        "sakura.webp",
        "mikasa.JPG",
        "luffy.webp",
        "cherry.webp",
        "kirito.webp",
        "doraemon.webp",
        "asuna.webp",
        "chihiro.webp",
    )
]

# Text prompts fed to CLIP; entry i describes LABELS[i].
PROMPTS = [
    "Eren Yeager from Attack on Titan",
    "Naruto Uzumaki from Naruto",
    "Totoro from My Neighbor Totoro",
    "Sakura Haruno from Naruto",
    "Mikasa Ackermann from Attack on Titan",
    "Luffy from One Piece",
    "Cherry Magic :3",
    "Kirito from Sword Art Online",
    "Doraemon",
    "Asuna Yuuki from Sword Art Online",
    "Chihiro Ogino from Spirited Away",
]

# Run the models on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Both identifiers can be overridden through environment variables.
CUSTOM_MODEL_ID = os.environ.get("CUSTOM_MODEL_ID", "your-username/your-model-name")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4.1-mini")
def rgb(image):
    """Return *image* in RGB mode.

    An image whose ``mode`` is already ``"RGB"`` is returned unchanged (same
    object); any other mode is converted with ``image.convert("RGB")``.
    """
    if image.mode == "RGB":
        return image
    return image.convert("RGB")
def format_top3(labels, scores):
    """Summarise the three highest-scoring labels.

    ``labels`` and ``scores`` are paired by position. Returns a tuple
    ``(text, mapping)``: ``text`` has one line per entry in the form
    ``"<rank>. <label> (<score to 4 decimals>)"`` and ``mapping`` is a dict
    from label to score for the same entries, highest score first.
    """
    ranked = list(zip(labels, scores))
    # Stable descending sort: equal scores keep their original relative order.
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    top = ranked[:3]

    lines = []
    for rank, (label, score) in enumerate(top, start=1):
        lines.append(f"{rank}. {label} ({score:.4f})")
    return "\n".join(lines), dict(top)
# Load the fine-tuned classifier once at import time. A failure is recorded in
# custom_error instead of being raised, so the app still starts and
# predict_custom can report the reason.
try:
    custom_processor = AutoImageProcessor.from_pretrained(CUSTOM_MODEL_ID)
    custom_model = AutoModelForImageClassification.from_pretrained(CUSTOM_MODEL_ID)
    custom_model = custom_model.to(DEVICE)
    custom_model.eval()
except Exception as exc:
    # Reset everything so a half-loaded processor is never used on its own.
    custom_processor = None
    custom_model = None
    custom_error = str(exc)
else:
    custom_error = None
# Load the CLIP checkpoint used for zero-shot scoring. As with the custom
# model, a failure is stored in clip_error rather than raised, so that
# predict_clip can report it.
try:
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = clip_model.to(DEVICE)
    clip_model.eval()
except Exception as exc:
    clip_processor = None
    clip_model = None
    clip_error = str(exc)
else:
    clip_error = None
def predict_custom(image):
    """Classify *image* with the fine-tuned model.

    Returns ``(text, scores)`` as produced by ``format_top3``. When the model
    failed to load, returns an explanatory message containing the recorded
    load error, together with an empty dict.
    """
    if custom_model is None:
        return f"Custom model could not be loaded.\n\n{custom_error}", {}

    batch = custom_processor(images=rgb(image), return_tensors="pt")
    batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}

    with torch.no_grad():
        logits = custom_model(**batch).logits
        # Index 0 selects the only image in the batch.
        probs = torch.softmax(logits, dim=-1)[0].tolist()

    names = custom_model.config.id2label
    # Normalise the model's class names: underscores to spaces, lower-case.
    labels = [str(names[idx]).replace("_", " ").lower() for idx, _ in enumerate(probs)]
    return format_top3(labels, probs)
def predict_clip(image):
    """Score *image* against every entry of PROMPTS with CLIP (zero-shot).

    Returns ``(text, scores)`` as produced by ``format_top3``, using LABELS as
    the display names. When CLIP failed to load, returns an explanatory
    message containing the recorded load error, together with an empty dict.
    """
    if clip_model is None:
        return f"CLIP model could not be loaded.\n\n{clip_error}", {}

    captions = [f"anime character {prompt}" for prompt in PROMPTS]
    batch = clip_processor(
        text=captions,
        images=rgb(image),
        return_tensors="pt",
        padding=True,
    )

    on_device = {}
    for key, value in batch.items():
        on_device[key] = value.to(DEVICE)

    with torch.no_grad():
        # Row 0 holds the similarity of the single input image to each caption.
        image_logits = clip_model(**on_device).logits_per_image[0]
        probs = torch.softmax(image_logits, dim=-1).tolist()

    return format_top3(LABELS, probs)
def predict_openai(image):
    """Ask the configured OpenAI model to pick one label for *image*.

    Always returns a string: the model's stripped text output on success, or
    a message explaining why no prediction was made (package missing, API key
    missing, or any exception raised while encoding or calling the API).
    """
    if OpenAI is None:
        return "OpenAI package is not installed."

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "OPENAI_API_KEY is not set."

    try:
        # Encode the image as a base64 JPEG so it can be sent as a data URL.
        jpeg = BytesIO()
        rgb(image).save(jpeg, format="JPEG")
        encoded = base64.b64encode(jpeg.getvalue()).decode("utf-8")

        client = OpenAI(api_key=api_key)
        instruction = (
            "Choose exactly one anime character label from this list: "
            f"{', '.join(LABELS)}. "
            "Return exactly:\nlabel: <label>\nreason: <short reason>"
        )
        message = {
            "role": "user",
            "content": [
                {"type": "input_text", "text": instruction},
                {"type": "input_image", "image_url": f"data:image/jpeg;base64,{encoded}"},
            ],
        }
        response = client.responses.create(model=OPENAI_MODEL, input=[message])
        return response.output_text.strip()
    except Exception as exc:
        # Report the failure in the output box instead of raising.
        return f"OpenAI prediction failed: {exc}"
def compare(image):
    """Run all three predictors on *image* and collect their outputs.

    Returns a 5-tuple in the order the UI outputs are wired:
    ``(custom_text, custom_scores, clip_text, clip_scores, openai_text)``.
    When no image was supplied, every text slot carries an upload prompt and
    both score slots are empty dicts.
    """
    if image is not None:
        custom_text, custom_scores = predict_custom(image)
        clip_text, clip_scores = predict_clip(image)
        openai_text = predict_openai(image)
        return custom_text, custom_scores, clip_text, clip_scores, openai_text

    prompt = "Please upload an image."
    return prompt, {}, prompt, {}, prompt
# Build the UI: one image input, a button, three text outputs and two score
# widgets, all driven by compare().
with gr.Blocks() as demo:
    gr.Markdown("# Anime Character Classifier")
    # The count is derived from LABELS so the description cannot drift from
    # the real label set (it previously claimed 9 labels while LABELS has 11).
    gr.Markdown(
        "Compare a fine-tuned model, CLIP zero-shot, and OpenAI vision "
        f"on {len(LABELS)} anime character labels."
    )

    # type="pil" makes Gradio hand compare() a PIL image.
    image = gr.Image(type="pil", label="Upload image")
    button = gr.Button("Run comparison")

    with gr.Row():
        out_custom = gr.Textbox(label="Fine-tuned model", lines=6)
        out_clip = gr.Textbox(label="CLIP zero-shot", lines=6)
        out_openai = gr.Textbox(label="OpenAI vision", lines=6)

    with gr.Row():
        scores_custom = gr.Label(label="Fine-tuned scores")
        scores_clip = gr.Label(label="CLIP scores")

    # The output order must match the tuple returned by compare().
    button.click(compare, image, [out_custom, scores_custom, out_clip, scores_clip, out_openai])
    gr.Examples(EXAMPLES, inputs=image)

demo.launch()