File size: 5,529 Bytes
4194462
75bcf2a
 
4194462
 
 
 
75bcf2a
4194462
 
 
 
 
 
 
75bcf2a
4011ff2
75bcf2a
 
 
 
 
 
4011ff2
 
 
 
4194462
 
75bcf2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4194462
 
75bcf2a
 
 
4011ff2
 
75bcf2a
 
4011ff2
 
75bcf2a
 
 
 
4194462
 
75bcf2a
 
 
 
 
 
 
4194462
75bcf2a
 
 
 
 
 
 
4011ff2
4194462
75bcf2a
 
 
4194462
75bcf2a
 
4194462
75bcf2a
4194462
75bcf2a
 
 
4194462
 
75bcf2a
 
 
4194462
75bcf2a
 
 
 
 
 
 
4194462
75bcf2a
4194462
75bcf2a
4194462
 
75bcf2a
4194462
75bcf2a
 
 
4194462
 
75bcf2a
 
 
4194462
75bcf2a
4194462
4011ff2
75bcf2a
 
 
 
 
 
 
 
 
 
 
 
 
 
4194462
 
75bcf2a
 
4194462
 
75bcf2a
4194462
75bcf2a
4011ff2
75bcf2a
 
 
 
4194462
 
 
4011ff2
75bcf2a
4194462
75bcf2a
 
4194462
75bcf2a
 
 
 
4194462
 
75bcf2a
 
 
 
 
4194462
 
75bcf2a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import base64
from io import BytesIO

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification, CLIPModel, CLIPProcessor

try:
    from openai import OpenAI
except Exception:
    OpenAI = None


# One row per character: (label, example image file name, CLIP text prompt).
# Keeping all three in a single table guarantees LABELS, EXAMPLES and PROMPTS
# stay index-aligned: predict_clip scores PROMPTS[i] and reports the result
# under LABELS[i], so a reordering in only one list would mislabel results.
_CHARACTERS = [
    ("eren", "eren.JPG", "Eren Yeager from Attack on Titan"),
    ("naruto", "naruto.webp", "Naruto Uzumaki from Naruto"),
    ("totoro", "totoro.webp", "Totoro from My Neighbor Totoro"),
    ("sakura", "sakura.webp", "Sakura Haruno from Naruto"),
    ("mikasa", "mikasa.JPG", "Mikasa Ackerman from Attack on Titan"),
    ("luffy", "luffy.webp", "Luffy from One Piece"),
    ("cherry", "cherry.webp", "Cherry Magic :3"),
    ("kirito", "kirito.webp", "Kirito from Sword Art Online"),
    ("doraemon", "doraemon.webp", "Doraemon"),
    ("asuna", "asuna.webp", "Asuna Yuuki from Sword Art Online"),
    ("chihiro", "chihiro.webp", "Chihiro Ogino from Spirited Away"),
]

# Short lowercase class names shown in the UI and offered to the OpenAI model.
LABELS = [label for label, _, _ in _CHARACTERS]

# gr.Examples expects one inner list per example (one entry per input component).
EXAMPLES = [[f"example_images/{filename}"] for _, filename, _ in _CHARACTERS]

# Descriptive text prompts for CLIP zero-shot scoring, aligned with LABELS.
PROMPTS = [prompt for _, _, prompt in _CHARACTERS]

# Use the GPU when torch reports CUDA is available; otherwise run on CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Model id/path handed to from_pretrained for the fine-tuned classifier.
# The default is a placeholder; set CUSTOM_MODEL_ID in the environment.
CUSTOM_MODEL_ID = os.environ.get("CUSTOM_MODEL_ID", "your-username/your-model-name")

# OpenAI model name used for the vision comparison; override with OPENAI_MODEL.
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4.1-mini")


def rgb(image):
    """Return *image* in RGB mode.

    An image whose ``mode`` is already ``"RGB"`` is returned unchanged (the
    same object); anything else is passed through ``convert("RGB")``.
    """
    if image.mode == "RGB":
        return image
    return image.convert("RGB")


def format_top3(labels, scores):
    """Rank label/score pairs and keep the three highest-scoring ones.

    Args:
        labels: Label names, parallel to *scores*.
        scores: Numeric scores (e.g. probabilities), parallel to *labels*.

    Returns:
        A ``(text, confidences)`` pair: *text* is a newline-separated
        ranking like ``"1. name (0.1234)"``; *confidences* maps each of the
        (at most three) kept labels to its score. Ties keep input order.
    """
    ranked = sorted(zip(labels, scores), key=lambda item: item[1], reverse=True)
    best = ranked[:3]
    text = "\n".join(
        f"{rank}. {name} ({prob:.4f})" for rank, (name, prob) in enumerate(best, start=1)
    )
    confidences = {name: prob for name, prob in best}
    return text, confidences


# Load the fine-tuned classifier once at import time. Any failure (for
# example an unreachable or placeholder model id) is captured as text in
# custom_error instead of crashing the app; predict_custom shows it to the user.
custom_error = None
try:
    custom_processor = AutoImageProcessor.from_pretrained(CUSTOM_MODEL_ID)
    custom_model = AutoModelForImageClassification.from_pretrained(CUSTOM_MODEL_ID).to(DEVICE)
    custom_model.eval()
except Exception as exc:
    # Reset both handles so a half-loaded pair is never used.
    custom_processor = custom_model = None
    custom_error = str(exc)

# Checkpoint for the zero-shot CLIP baseline. Reading it once guarantees the
# processor and the model always come from the same checkpoint; it can be
# overridden via the CLIP_MODEL_ID environment variable.
CLIP_MODEL_ID = os.getenv("CLIP_MODEL_ID", "openai/clip-vit-base-patch32")

try:
    clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
    clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
    clip_model.eval()
    clip_error = None
except Exception as e:
    # Best-effort: keep the app usable without CLIP; predict_clip reports
    # clip_error to the user instead of a prediction.
    clip_processor, clip_model, clip_error = None, None, str(e)


def predict_custom(image):
    """Classify *image* with the fine-tuned model.

    Label names come from the model config's ``id2label``, with underscores
    replaced by spaces and lower-cased. Probabilities are the softmax of the
    model's logits for the single input image.

    Returns:
        The ``(text, scores)`` pair produced by ``format_top3``. If the model
        failed to load, or preprocessing/inference raises, returns an
        explanatory message and an empty dict instead. This keeps a failure
        here from aborting ``compare()`` and hiding the other predictors'
        results.
    """
    if custom_model is None:
        return f"Custom model could not be loaded.\n\n{custom_error}", {}

    try:
        inputs = custom_processor(images=rgb(image), return_tensors="pt")
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        with torch.no_grad():
            probs = torch.softmax(custom_model(**inputs).logits, dim=-1)[0].tolist()

        id2label = custom_model.config.id2label
        labels = [str(id2label[i]).replace("_", " ").lower() for i in range(len(probs))]
    except Exception as e:
        # Mirror predict_openai: surface the failure in the UI as text.
        return f"Custom model prediction failed: {e}", {}

    return format_top3(labels, probs)


def predict_clip(image):
    """Zero-shot classify *image* with CLIP against the PROMPTS texts.

    Each prompt is prefixed with "anime character " and scored against the
    image. The per-prompt logits (``logits_per_image[0]``) are softmaxed
    into probabilities and reported under the corresponding LABELS entry.

    Returns:
        The ``(text, scores)`` pair produced by ``format_top3``. If CLIP
        failed to load, or preprocessing/inference raises, returns an
        explanatory message and an empty dict. This keeps a failure here
        from aborting ``compare()`` and hiding the other predictors' results.
    """
    if clip_model is None:
        return f"CLIP model could not be loaded.\n\n{clip_error}", {}

    try:
        inputs = clip_processor(
            text=[f"anime character {p}" for p in PROMPTS],
            images=rgb(image),
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        with torch.no_grad():
            probs = torch.softmax(clip_model(**inputs).logits_per_image[0], dim=-1).tolist()
    except Exception as e:
        # Mirror predict_openai: surface the failure in the UI as text.
        return f"CLIP prediction failed: {e}", {}

    # probs[i] scores PROMPTS[i], which describes the character LABELS[i].
    return format_top3(LABELS, probs)


def predict_openai(image):
    """Ask an OpenAI vision model to name the character in *image*.

    The image is re-encoded as an RGB JPEG and sent inline as a base64 data
    URL. The model is instructed to pick exactly one entry from LABELS and
    reply as ``label: ...`` / ``reason: ...``. The reply is returned with
    surrounding whitespace stripped; it is not validated against LABELS.

    Returns a status string instead of a prediction when the ``openai``
    package is missing, ``OPENAI_API_KEY`` is unset, or the request raises.
    """
    if OpenAI is None:
        return "OpenAI package is not installed."
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "OPENAI_API_KEY is not set."

    try:
        jpeg_buffer = BytesIO()
        rgb(image).save(jpeg_buffer, format="JPEG")
        encoded = base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")

        instruction = (
            "Choose exactly one anime character label from this list: "
            f"{', '.join(LABELS)}. "
            "Return exactly:\nlabel: <label>\nreason: <short reason>"
        )
        user_message = {
            "role": "user",
            "content": [
                {"type": "input_text", "text": instruction},
                {"type": "input_image", "image_url": f"data:image/jpeg;base64,{encoded}"},
            ],
        }

        client = OpenAI(api_key=api_key)
        response = client.responses.create(model=OPENAI_MODEL, input=[user_message])
        return response.output_text.strip()
    except Exception as e:
        return f"OpenAI prediction failed: {e}"


def compare(image):
    """Run all three predictors on *image* for the Gradio click handler.

    Returns a 5-tuple in the order wired to the UI outputs:
    fine-tuned text, fine-tuned scores, CLIP text, CLIP scores, OpenAI text.
    With no image, each text slot carries an upload prompt and both score
    slots are empty dicts.
    """
    if image is not None:
        custom_text, custom_scores = predict_custom(image)
        clip_text, clip_scores = predict_clip(image)
        openai_text = predict_openai(image)
        return custom_text, custom_scores, clip_text, clip_scores, openai_text

    notice = "Please upload an image."
    return notice, {}, notice, {}, notice


# Gradio UI: one image input, three text outputs (one per predictor) and two
# score panels for the predictors that return probabilities.
with gr.Blocks() as demo:
    gr.Markdown("# Anime Character Classifier")
    # Derive the count from LABELS so this text cannot drift from the label list.
    gr.Markdown(
        "Compare a fine-tuned model, CLIP zero-shot, and OpenAI vision on "
        f"{len(LABELS)} anime character labels."
    )

    image = gr.Image(type="pil", label="Upload image")
    button = gr.Button("Run comparison")

    with gr.Row():
        out_custom = gr.Textbox(label="Fine-tuned model", lines=6)
        out_clip = gr.Textbox(label="CLIP zero-shot", lines=6)
        out_openai = gr.Textbox(label="OpenAI vision", lines=6)

    with gr.Row():
        scores_custom = gr.Label(label="Fine-tuned scores")
        scores_clip = gr.Label(label="CLIP scores")

    # Output order must match the 5-tuple returned by compare().
    button.click(compare, image, [out_custom, scores_custom, out_clip, scores_clip, out_openai])
    gr.Examples(EXAMPLES, inputs=image)


demo.launch()