import os
import base64
from io import BytesIO

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    CLIPModel,
    CLIPProcessor,
)

try:
    from openai import OpenAI
except Exception:
    # The OpenAI client is optional: predict_openai checks for None and
    # reports that the package is missing instead of crashing at import time.
    OpenAI = None

# Class names reported to the user. predict_clip pairs these positionally
# with PROMPTS, so the two lists must stay in the same order.
LABELS = [
    "eren",
    "naruto",
    "totoro",
    "sakura",
    "mikasa",
    "luffy",
    "cherry",
    "kirito",
    "doraemon",
    "asuna",
    "chihiro",
]

# One single-element row per example image path.
EXAMPLES = [
    [f"example_images/{name}"]
    for name in [
        "eren.JPG",
        "naruto.webp",
        "totoro.webp",
        "sakura.webp",
        "mikasa.JPG",
        "luffy.webp",
        "cherry.webp",
        "kirito.webp",
        "doraemon.webp",
        "asuna.webp",
        "chihiro.webp",
    ]
]

# Text descriptions scored by CLIP; entry i describes LABELS[i].
PROMPTS = [
    "Eren Yeager from Attack on Titan",
    "Naruto Uzumaki from Naruto",
    "Totoro from My Neighbor Totoro",
    "Sakura Haruno from Naruto",
    "Mikasa Ackermann from Attack on Titan",
    "Luffy from One Piece",
    "Cherry Magic :3",
    "Kirito from Sword Art Online",
    "Doraemon",
    "Asuna Yuuki from Sword Art Online",
    "Chihiro Ogino from Spirited Away",
]

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CUSTOM_MODEL_ID = os.getenv("CUSTOM_MODEL_ID", "your-username/your-model-name")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")

# Returned by the local predictors when they are called without an image.
NO_IMAGE_MESSAGE = "No image provided. Please upload an image first."


def rgb(image):
    """Return *image* in RGB mode, converting only if it is not RGB already."""
    return image.convert("RGB") if image.mode != "RGB" else image


def format_top3(labels, scores):
    """Summarise the three highest-scoring labels.

    Args:
        labels: Iterable of label names.
        scores: Iterable of numeric scores, paired positionally with *labels*.

    Returns:
        A ``(text, mapping)`` tuple: *text* is a numbered list with one
        ``"<rank>. <label> (<score to 4 decimals>)"`` line per entry, and
        *mapping* is a dict from label to score for the same entries, in
        descending score order.
    """
    pairs = sorted(zip(labels, scores), key=lambda x: x[1], reverse=True)[:3]
    text = "\n".join(
        f"{i+1}. {label} ({score:.4f})"
        for i, (label, score) in enumerate(pairs)
    )
    return text, dict(pairs)


# Load the fine-tuned classifier once at import time. Any failure is kept as
# a message so predict_custom can report it rather than the app failing to start.
try:
    custom_processor = AutoImageProcessor.from_pretrained(CUSTOM_MODEL_ID)
    custom_model = AutoModelForImageClassification.from_pretrained(CUSTOM_MODEL_ID).to(DEVICE)
    custom_model.eval()
    custom_error = None
except Exception as e:
    custom_processor, custom_model, custom_error = None, None, str(e)

# Same best-effort loading for the CLIP model used by predict_clip.
try:
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)
    clip_model.eval()
    clip_error = None
except Exception as e:
    clip_processor, clip_model, clip_error = None, None, str(e)


def predict_custom(image):
    """Classify *image* with the custom model and return its top-3 labels.

    Args:
        image: Image object exposing ``mode`` and ``convert`` (see ``rgb``),
            or ``None`` when no image was supplied.

    Returns:
        The ``(text, mapping)`` pair produced by ``format_top3``. When the
        model failed to load or no image was supplied, an explanatory
        message paired with an empty dict.
    """
    if custom_model is None:
        return f"Custom model could not be loaded.\n\n{custom_error}", {}
    if image is None:
        # Without this guard rgb() would raise AttributeError on None.
        return NO_IMAGE_MESSAGE, {}

    inputs = custom_processor(images=rgb(image), return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        probs = torch.softmax(custom_model(**inputs).logits, dim=-1)[0].tolist()

    # id2label is looked up by class index; names are normalised to
    # lower-case with spaces instead of underscores.
    id2label = custom_model.config.id2label
    labels = [str(id2label[i]).replace("_", " ").lower() for i in range(len(probs))]
    return format_top3(labels, probs)


def predict_clip(image):
    """Score *image* against PROMPTS with CLIP and return the top-3 labels.

    Args:
        image: Image object exposing ``mode`` and ``convert`` (see ``rgb``),
            or ``None`` when no image was supplied.

    Returns:
        The ``(text, mapping)`` pair produced by ``format_top3`` over LABELS.
        When the model failed to load or no image was supplied, an
        explanatory message paired with an empty dict.
    """
    if clip_model is None:
        return f"CLIP model could not be loaded.\n\n{clip_error}", {}
    if image is None:
        # Without this guard rgb() would raise AttributeError on None.
        return NO_IMAGE_MESSAGE, {}

    inputs = clip_processor(
        text=[f"anime character {p}" for p in PROMPTS],
        images=rgb(image),
        return_tensors="pt",
        padding=True,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        probs = torch.softmax(clip_model(**inputs).logits_per_image[0], dim=-1).tolist()
    # probs follows the order of PROMPTS, which mirrors LABELS.
    return format_top3(LABELS, probs)


def predict_openai(image):
    # NOTE(review): these early exits return a bare string, whereas
    # predict_custom / predict_clip return a (text, dict) pair - confirm the
    # caller expects a single value here.
    if OpenAI is None:
        return "OpenAI package is not installed."
    if not os.getenv("OPENAI_API_KEY"):
        return "OPENAI_API_KEY is not set."
    # The remainder of this function (the try-block that encodes the image
    # and calls the OpenAI API) continues directly below.
try: buffer = BytesIO() rgb(image).save(buffer, format="JPEG") img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8") client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) response = client.responses.create( model=OPENAI_MODEL, input=[{ "role": "user", "content": [ { "type": "input_text", "text": ( "Choose exactly one anime character label from this list: " f"{', '.join(LABELS)}. " "Return exactly:\nlabel: