import os
import base64
from io import BytesIO
import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification, CLIPModel, CLIPProcessor
# The OpenAI client is optional. The broad ``except`` is deliberate: any
# import-time failure leaves ``OpenAI`` as None, which predict_openai() reports
# to the user instead of the app failing to start.
try:
    from openai import OpenAI
except Exception:
    OpenAI = None
# One row per supported character: (label, example image file, CLIP prompt).
# LABELS, EXAMPLES and PROMPTS are all derived from this table so the three
# lists stay index-aligned (predict_clip pairs LABELS with scores computed
# from PROMPTS, position by position).
_CHARACTERS = (
    ("eren", "eren.JPG", "Eren Yeager from Attack on Titan"),
    ("naruto", "naruto.webp", "Naruto Uzumaki from Naruto"),
    ("totoro", "totoro.webp", "Totoro from My Neighbor Totoro"),
    ("sakura", "sakura.webp", "Sakura Haruno from Naruto"),
    ("mikasa", "mikasa.JPG", "Mikasa Ackermann from Attack on Titan"),
    ("luffy", "luffy.webp", "Luffy from One Piece"),
    ("cherry", "cherry.webp", "Cherry Magic :3"),
    ("kirito", "kirito.webp", "Kirito from Sword Art Online"),
    ("doraemon", "doraemon.webp", "Doraemon"),
    ("asuna", "asuna.webp", "Asuna Yuuki from Sword Art Online"),
    ("chihiro", "chihiro.webp", "Chihiro Ogino from Spirited Away"),
)

# Short class names shown in the results.
LABELS = [label for label, _, _ in _CHARACTERS]

# Single-element rows, the shape gr.Examples takes for one input component.
EXAMPLES = [[f"example_images/{filename}"] for _, filename, _ in _CHARACTERS]

# Text descriptions used as zero-shot prompts for CLIP.
PROMPTS = [prompt for _, _, prompt in _CHARACTERS]
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Model identifiers; each can be overridden through an environment variable
# of the same name.
CUSTOM_MODEL_ID = os.environ.get("CUSTOM_MODEL_ID", "your-username/your-model-name")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4.1-mini")
def rgb(image):
    """Return ``image`` in RGB mode.

    An image whose ``mode`` is already ``"RGB"`` is returned unchanged (the
    same object); any other mode is passed through ``image.convert("RGB")``.
    """
    if image.mode == "RGB":
        return image
    return image.convert("RGB")
def format_top3(labels, scores):
    """Rank labels by score and keep the three highest.

    Args:
        labels: Class names, index-aligned with ``scores``.
        scores: One numeric score per label.

    Returns:
        A ``(text, mapping)`` tuple. ``text`` holds one numbered line per kept
        label, with the score formatted to four decimals; ``mapping`` is a
        ``label -> score`` dict for the same entries. Fewer than three
        entries are returned when fewer are supplied.
    """
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    top = ranked[:3]
    lines = [
        f"{rank}. {label} ({score:.4f})"
        for rank, (label, score) in enumerate(top, start=1)
    ]
    return "\n".join(lines), dict(top)
# Load the fine-tuned classifier once, at import time. A failure is captured
# in ``custom_error`` instead of being raised, so the app still starts and
# predict_custom() can show the reason to the user.
try:
    custom_processor = AutoImageProcessor.from_pretrained(CUSTOM_MODEL_ID)
    custom_model = AutoModelForImageClassification.from_pretrained(CUSTOM_MODEL_ID).to(DEVICE)
    custom_model.eval()
    custom_error = None
except Exception as exc:
    custom_processor, custom_model, custom_error = None, None, str(exc)
# Load the CLIP zero-shot baseline once, at import time. A failure is captured
# in ``clip_error`` instead of being raised, so the app still starts and
# predict_clip() can show the reason to the user.
try:
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)
    clip_model.eval()
    clip_error = None
except Exception as exc:
    clip_processor, clip_model, clip_error = None, None, str(exc)
def predict_custom(image):
    """Classify ``image`` with the fine-tuned model.

    Args:
        image: The uploaded image; it is converted to RGB before processing.

    Returns:
        The ``(text, scores)`` pair produced by ``format_top3``. When the
        model could not be loaded, ``text`` is an error message containing
        the load error and ``scores`` is an empty dict.
    """
    if custom_model is None:
        return f"Custom model could not be loaded.\n\n{custom_error}", {}

    inputs = custom_processor(images=rgb(image), return_tensors="pt")
    # Move every input tensor to the device the model was loaded on.
    inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}
    with torch.no_grad():
        logits = custom_model(**inputs).logits
        # Index 0 selects the only image in the batch.
        probs = torch.softmax(logits, dim=-1)[0].tolist()

    # Normalise the model's class names for display: underscores become
    # spaces and everything is lower-cased.
    id2label = custom_model.config.id2label
    labels = [str(id2label[i]).replace("_", " ").lower() for i in range(len(probs))]
    return format_top3(labels, probs)
def predict_clip(image):
    """Classify ``image`` zero-shot with CLIP against the character prompts.

    Args:
        image: The uploaded image; it is converted to RGB before processing.

    Returns:
        The ``(text, scores)`` pair produced by ``format_top3``. When CLIP
        could not be loaded, ``text`` is an error message containing the load
        error and ``scores`` is an empty dict.
    """
    if clip_model is None:
        return f"CLIP model could not be loaded.\n\n{clip_error}", {}

    inputs = clip_processor(
        text=[f"anime character {p}" for p in PROMPTS],
        images=rgb(image),
        return_tensors="pt",
        padding=True,
    )
    # Move every input tensor to the device the model was loaded on.
    inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}
    with torch.no_grad():
        # Row 0 holds the image-to-prompt logits for the single input image.
        image_logits = clip_model(**inputs).logits_per_image[0]
        probs = torch.softmax(image_logits, dim=-1).tolist()

    # Scores follow the order of PROMPTS, which lines up with LABELS.
    return format_top3(LABELS, probs)
def predict_openai(image):
    """Ask the configured OpenAI model to pick one label for ``image``.

    Args:
        image: The uploaded image; it is converted to RGB and sent as a
            base64-encoded JPEG data URL.

    Returns:
        The model's stripped text reply, or a human-readable message when the
        ``openai`` package is unavailable, ``OPENAI_API_KEY`` is not set, or
        the request fails. No exception is propagated to the caller.
    """
    if OpenAI is None:
        return "OpenAI package is not installed."

    # Read the key once and reuse it for both the check and the client.
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "OPENAI_API_KEY is not set."

    try:
        buffer = BytesIO()
        rgb(image).save(buffer, format="JPEG")
        img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        client = OpenAI(api_key=api_key)
        response = client.responses.create(
            model=OPENAI_MODEL,
            input=[{
                "role": "user",
                "content": [
                    {
                        "type": "input_text",
                        "text": (
                            "Choose exactly one anime character label from this list: "
                            f"{', '.join(LABELS)}. "
                            "Return exactly:\nlabel: <label>\nreason: <short reason>"
                        ),
                    },
                    {"type": "input_image", "image_url": f"data:image/jpeg;base64,{img_b64}"},
                ],
            }],
        )
        return response.output_text.strip()
    except Exception as e:
        # Best effort: show the failure in the UI rather than crashing the
        # whole comparison.
        return f"OpenAI prediction failed: {e}"
def compare(image):
    """Run all three classifiers on ``image`` for the UI.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A 5-tuple ``(custom_text, custom_scores, clip_text, clip_scores,
        openai_text)``, in the order of the output components wired to the
        button. When ``image`` is ``None``, every text slot carries an
        upload prompt and both score dicts are empty.
    """
    if image is None:
        msg = "Please upload an image."
        return msg, {}, msg, {}, msg

    custom_text, custom_scores = predict_custom(image)
    clip_text, clip_scores = predict_clip(image)
    openai_text = predict_openai(image)
    return custom_text, custom_scores, clip_text, clip_scores, openai_text
# Build the UI: one image input, a run button, three text outputs side by
# side, and two score panels for the models that return probabilities.
with gr.Blocks() as demo:
    gr.Markdown("# Anime Character Classifier")
    # The label count comes from LABELS so the description cannot drift from
    # the actual list (it previously said 9 while LABELS has 11 entries).
    gr.Markdown(
        "Compare a fine-tuned model, CLIP zero-shot, and OpenAI vision "
        f"on {len(LABELS)} anime character labels."
    )
    image = gr.Image(type="pil", label="Upload image")
    button = gr.Button("Run comparison")
    with gr.Row():
        out_custom = gr.Textbox(label="Fine-tuned model", lines=6)
        out_clip = gr.Textbox(label="CLIP zero-shot", lines=6)
        out_openai = gr.Textbox(label="OpenAI vision", lines=6)
    with gr.Row():
        scores_custom = gr.Label(label="Fine-tuned scores")
        scores_clip = gr.Label(label="CLIP scores")
    # The output order must match the tuple returned by compare().
    button.click(compare, image, [out_custom, scores_custom, out_clip, scores_clip, out_openai])
    gr.Examples(EXAMPLES, inputs=image)

demo.launch()
|