wueesnin's picture
Update app.py
75bcf2a verified
import os
import base64
from io import BytesIO
import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification, CLIPModel, CLIPProcessor
# The OpenAI client is an optional dependency: when it cannot be imported the
# name is bound to None, and predict_openai() reports that instead of failing.
try:
    from openai import OpenAI
except Exception:
    # Any import-time failure disables the OpenAI comparison path.
    OpenAI = None
# One row per character: (label, example image file name, CLIP text prompt).
# LABELS, EXAMPLES and PROMPTS are all derived from this table so they stay
# index-aligned; predict_clip() pairs the score of PROMPTS[i] with LABELS[i].
_CHARACTERS = (
    ("eren", "eren.JPG", "Eren Yeager from Attack on Titan"),
    ("naruto", "naruto.webp", "Naruto Uzumaki from Naruto"),
    ("totoro", "totoro.webp", "Totoro from My Neighbor Totoro"),
    ("sakura", "sakura.webp", "Sakura Haruno from Naruto"),
    ("mikasa", "mikasa.JPG", "Mikasa Ackermann from Attack on Titan"),
    ("luffy", "luffy.webp", "Luffy from One Piece"),
    ("cherry", "cherry.webp", "Cherry Magic :3"),
    ("kirito", "kirito.webp", "Kirito from Sword Art Online"),
    ("doraemon", "doraemon.webp", "Doraemon"),
    ("asuna", "asuna.webp", "Asuna Yuuki from Sword Art Online"),
    ("chihiro", "chihiro.webp", "Chihiro Ogino from Spirited Away"),
)

# Short class names shown in the results.
LABELS = [label for label, _, _ in _CHARACTERS]

# Example inputs; each entry is a one-element list holding an image path.
EXAMPLES = [[f"example_images/{filename}"] for _, filename, _ in _CHARACTERS]

# Descriptive prompts used for CLIP zero-shot classification.
PROMPTS = [prompt for _, _, prompt in _CHARACTERS]
# Run inference on the GPU when torch reports CUDA as available, else on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Both model identifiers can be overridden through environment variables;
# the second argument is the fallback used when the variable is unset.
CUSTOM_MODEL_ID = os.environ.get("CUSTOM_MODEL_ID", "your-username/your-model-name")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4.1-mini")
def rgb(image):
    """Return *image* in RGB mode.

    An image whose ``mode`` is already ``"RGB"`` is returned unchanged (the
    same object); any other mode goes through ``image.convert("RGB")``.
    """
    if image.mode == "RGB":
        return image
    return image.convert("RGB")
def format_top3(labels, scores):
    """Keep the three highest-scoring labels and format them for display.

    Args:
        labels: Iterable of label names.
        scores: Iterable of numeric scores, paired with ``labels`` by position.

    Returns:
        A ``(text, mapping)`` tuple. ``text`` has one numbered line per kept
        label in the form ``"1. label (0.1234)"``, best score first;
        ``mapping`` maps each kept label to its score.
    """
    def score_of(pair):
        return pair[1]

    ranked = list(zip(labels, scores))
    # list.sort is stable, so labels with equal scores keep their input order.
    ranked.sort(key=score_of, reverse=True)
    best = ranked[:3]

    lines = []
    for rank, (label, score) in enumerate(best, start=1):
        lines.append(f"{rank}. {label} ({score:.4f})")
    return "\n".join(lines), dict(best)
# Load the fine-tuned classifier once at import time. On any failure all three
# names are still defined: processor and model become None and custom_error
# holds the error text that predict_custom() shows to the user.
custom_processor = custom_model = custom_error = None
try:
    custom_processor = AutoImageProcessor.from_pretrained(CUSTOM_MODEL_ID)
    custom_model = AutoModelForImageClassification.from_pretrained(CUSTOM_MODEL_ID).to(DEVICE)
    custom_model.eval()
except Exception as exc:
    # Reset both handles so a partially loaded pair is never used.
    custom_processor, custom_model, custom_error = None, None, str(exc)
# Load the CLIP checkpoint once at import time. On any failure all three names
# are still defined: processor and model become None and clip_error holds the
# error text that predict_clip() shows to the user.
clip_processor = clip_model = clip_error = None
try:
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)
    clip_model.eval()
except Exception as exc:
    # Reset both handles so a partially loaded pair is never used.
    clip_processor, clip_model, clip_error = None, None, str(exc)
def predict_custom(image):
    """Classify *image* with the fine-tuned model.

    Args:
        image: Image to classify; it is passed through ``rgb()`` first.

    Returns:
        The ``(text, scores)`` pair produced by ``format_top3``. When the
        custom model failed to load, returns an explanatory message that
        includes the load error, together with an empty dict.
    """
    if custom_model is None:
        return f"Custom model could not be loaded.\n\n{custom_error}", {}

    encoded = custom_processor(images=rgb(image), return_tensors="pt")
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = custom_model(**encoded).logits
        # Softmax over the last axis, then keep the first row (one image in).
        probs = torch.softmax(logits, dim=-1)[0].tolist()

    id2label = custom_model.config.id2label
    # Normalise the model's class names: underscores -> spaces, lower-case.
    labels = [
        str(id2label[idx]).replace("_", " ").lower()
        for idx, _ in enumerate(probs)
    ]
    return format_top3(labels, probs)
def predict_clip(image):
    """Classify *image* zero-shot with CLIP against the character prompts.

    Each entry of ``PROMPTS`` is prefixed with ``"anime character "`` and
    scored against the image; the resulting probabilities are paired with
    ``LABELS`` by position.

    Args:
        image: Image to classify; it is passed through ``rgb()`` first.

    Returns:
        The ``(text, scores)`` pair produced by ``format_top3``. When the
        CLIP model failed to load, returns an explanatory message that
        includes the load error, together with an empty dict.
    """
    if clip_model is None:
        return f"CLIP model could not be loaded.\n\n{clip_error}", {}

    candidate_texts = [f"anime character {p}" for p in PROMPTS]
    encoded = clip_processor(
        text=candidate_texts,
        images=rgb(image),
        return_tensors="pt",
        padding=True,
    )
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    with torch.no_grad():
        # First row of logits_per_image: this image's score for every prompt.
        image_logits = clip_model(**encoded).logits_per_image[0]
        probs = torch.softmax(image_logits, dim=-1).tolist()

    return format_top3(LABELS, probs)
def predict_openai(image):
    """Ask an OpenAI vision model to pick one label from ``LABELS``.

    The image is converted with ``rgb()``, JPEG-encoded in memory and sent as
    a base64 data URL alongside a text instruction.

    Args:
        image: Image to classify.

    Returns:
        The stripped text of the model response, or a human-readable message
        when the ``openai`` package is unavailable, ``OPENAI_API_KEY`` is not
        set, or the request raises.
    """
    if OpenAI is None:
        return "OpenAI package is not installed."
    if not os.getenv("OPENAI_API_KEY"):
        return "OPENAI_API_KEY is not set."

    try:
        jpeg = BytesIO()
        rgb(image).save(jpeg, format="JPEG")
        encoded = base64.b64encode(jpeg.getvalue()).decode("utf-8")

        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        instruction = (
            "Choose exactly one anime character label from this list: "
            f"{', '.join(LABELS)}. "
            "Return exactly:\nlabel: <label>\nreason: <short reason>"
        )
        message = {
            "role": "user",
            "content": [
                {"type": "input_text", "text": instruction},
                {"type": "input_image", "image_url": f"data:image/jpeg;base64,{encoded}"},
            ],
        }
        response = client.responses.create(model=OPENAI_MODEL, input=[message])
        return response.output_text.strip()
    except Exception as exc:
        # Surface any failure as text so the UI shows it in the output box.
        return f"OpenAI prediction failed: {exc}"
def compare(image):
    """Run all three predictors on *image* and collect their outputs.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A 5-tuple ``(custom_text, custom_scores, clip_text, clip_scores,
        openai_text)``. When ``image`` is ``None`` every text slot holds an
        upload reminder and both score slots are empty dicts.
    """
    if image is None:
        reminder = "Please upload an image."
        return reminder, {}, reminder, {}, reminder

    custom_text, custom_scores = predict_custom(image)
    clip_text, clip_scores = predict_clip(image)
    openai_text = predict_openai(image)
    return custom_text, custom_scores, clip_text, clip_scores, openai_text
# Gradio UI: one image input, a button that runs compare(), three text boxes
# for the predictors' answers and two label widgets for the score dicts.
with gr.Blocks() as demo:
    gr.Markdown("# Anime Character Classifier")
    # Derive the label count from LABELS so the description cannot drift from
    # the actual list (it previously said 9 while LABELS has 11 entries).
    gr.Markdown(
        "Compare a fine-tuned model, CLIP zero-shot, and OpenAI vision "
        f"on {len(LABELS)} anime character labels."
    )
    image = gr.Image(type="pil", label="Upload image")
    button = gr.Button("Run comparison")
    with gr.Row():
        out_custom = gr.Textbox(label="Fine-tuned model", lines=6)
        out_clip = gr.Textbox(label="CLIP zero-shot", lines=6)
        out_openai = gr.Textbox(label="OpenAI vision", lines=6)
    with gr.Row():
        scores_custom = gr.Label(label="Fine-tuned scores")
        scores_clip = gr.Label(label="CLIP scores")
    # The output order must match the tuple returned by compare().
    button.click(compare, image, [out_custom, scores_custom, out_clip, scores_clip, out_openai])
    gr.Examples(EXAMPLES, inputs=image)

demo.launch()