File size: 4,616 Bytes
8247df2
eaebace
228da07
 
8247df2
 
eaebace
228da07
eaebace
8247df2
 
 
bb2ea71
8247df2
 
 
 
 
 
0a1174b
8247df2
 
 
 
 
 
 
228da07
 
602e64e
228da07
 
8247df2
 
 
 
228da07
 
 
 
 
8247df2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce2bb40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228da07
ce2bb40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228da07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8247df2
 
 
 
 
 
228da07
72d8624
e431a80
 
 
8247df2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228da07
 
8247df2
 
228da07
8247df2
 
 
 
bbff3e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import base64
import json
import logging
import mimetypes
import os
from pathlib import Path

import gradio as gr
from openai import OpenAI
from transformers import pipeline

# Directory containing this file; example images are looked up relative to it.
BASE_DIR = Path(__file__).resolve().parent
EXAMPLE_DIR = BASE_DIR / "example_images"

# Model identifier handed to the image-classification pipeline below
# (presumably a Hugging Face Hub repo id -- confirm).
MODEL_PATH = "DKatheesrupan/cat-vit"

# Label set shared by all three classifiers. The order matches the
# LABEL_0..LABEL_4 mapping used in normalize_custom_labels.
CAT_LABELS = ["cheetah", "leopard", "lion", "puma", "tiger"]

# Both pipelines are built once at import time and reused by classify_cat.
print("Loading custom model...")
vit_classifier = pipeline(
    "image-classification",
    model=MODEL_PATH
)

print("Loading CLIP model...")
clip_classifier = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-base-patch32"
)

# OpenAI key comes from Hugging Face Space secret: OPENAI_API_KEY
# NOTE(review): os.getenv returns None when the variable is unset; check how
# the OpenAI client behaves in that case before deploying without the secret.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


# ----------------------------
# Helper functions
# ----------------------------

def encode_image(image_path):
    """Return the contents of the file at *image_path* as base64 text.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file's bytes, as a ``str``.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()

    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def normalize_custom_labels(results):
    """Convert classifier results into a ``{label: score}`` dict.

    Generic ``LABEL_<n>`` names are translated to the matching cat name;
    any other label is simply lower-cased.

    Args:
        results: Iterable of mappings with a ``"label"`` and a ``"score"``
            entry.

    Returns:
        Dict mapping each normalized label to its score as a ``float``.
        If two entries normalize to the same label, the later one wins.
    """
    generic_to_name = {
        "LABEL_0": "cheetah",
        "LABEL_1": "leopard",
        "LABEL_2": "lion",
        "LABEL_3": "puma",
        "LABEL_4": "tiger",
    }

    normalized = {}
    for entry in results:
        raw_label, score = entry["label"], float(entry["score"])
        name = (
            generic_to_name[raw_label]
            if raw_label in generic_to_name
            else raw_label.lower()
        )
        normalized[name] = score

    return normalized


def classify_with_openai(image_path):
    """Classify an image with the OpenAI model and return per-label scores.

    The model is asked for exactly one label from ``CAT_LABELS`` plus a
    confidence. Because it reports only a single label, the remaining
    probability mass is spread evenly over the other labels so the result
    has the same shape as the other classifiers' outputs.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        Dict mapping every label in ``CAT_LABELS`` to a score: the reported
        confidence (clamped to ``[0, 1]``) for the chosen label and an equal
        share of the remainder for each other label. If the request fails or
        the reply cannot be parsed, returns ``{"unknown": 1.0}``; the failure
        is logged.

    Raises:
        OSError: If the image file cannot be read. Reading happens before the
            best-effort section, so this is not converted to ``"unknown"``.
    """
    base64_image = encode_image(image_path)

    # Uploads are not necessarily JPEGs; declare the type implied by the file
    # extension and fall back to JPEG only when it cannot be determined.
    mime_type, _ = mimetypes.guess_type(str(image_path))
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    prompt = f"""
You are a big cat classifier.

Classify the image into exactly one of these labels:

{CAT_LABELS}

Return ONLY valid JSON.
Do not use markdown.
Do not use code fences.
Do not add explanations.

Required format:
{{"label":"one_of_{CAT_LABELS}","confidence":0.0}}
"""

    try:
        response = client.responses.create(
            model="gpt-4.1-mini",
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {
                            "type": "input_image",
                            "image_url": f"data:{mime_type};base64,{base64_image}"
                        }
                    ]
                }
            ]
        )

        text = response.output_text.strip()
        # The prompt forbids code fences, but strip them in case the model
        # adds them anyway.
        text = text.replace("```json", "").replace("```", "").strip()

        # Keep only the outermost {...} span so stray text around the JSON
        # object does not break parsing.
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end != -1 and end > start:
            text = text[start:end + 1]

        result = json.loads(text)

        label = str(result["label"]).strip().lower()
        confidence = float(result["confidence"])

        if label not in CAT_LABELS:
            raise ValueError(f"Invalid label: {label}")

        confidence = max(0.0, min(1.0, confidence))
        num_other = len(CAT_LABELS) - 1
        # Guard the division so a single-label configuration cannot fail.
        other_share = (1.0 - confidence) / num_other if num_other else 0.0

        return {
            name: confidence if name == label else other_share
            for name in CAT_LABELS
        }

    except Exception:
        # Deliberate best-effort boundary: the other classifiers' results
        # should still be shown. Log the cause instead of discarding it so
        # API, key and parsing failures can be diagnosed.
        logging.getLogger(__name__).exception(
            "OpenAI classification failed for %s", image_path
        )
        return {"unknown": 1.0}


# ----------------------------
# Main function
# ----------------------------

def classify_cat(image):
    """Run the three classifiers on one image and return their scores.

    Args:
        image: The image to classify; it is passed unchanged to both
            pipelines and to ``classify_with_openai``.

    Returns:
        A 3-tuple ``(custom_model_scores, clip_scores, openai_scores)``,
        each a dict mapping a label to a float score.
    """
    # Custom model: translate the pipeline's labels to cat names.
    vit_output = normalize_custom_labels(vit_classifier(image))

    # CLIP zero-shot: wrap each label in a caption-style prompt, then strip
    # the prompt prefix from the returned labels again.
    candidate_prompts = [f"a photo of a {name}" for name in CAT_LABELS]
    clip_output = {
        hit["label"].replace("a photo of a ", "").lower(): float(hit["score"])
        for hit in clip_classifier(image, candidate_labels=candidate_prompts)
    }

    # OpenAI runs last, after both local models have finished.
    return vit_output, clip_output, classify_with_openai(image)


# ----------------------------
# Example images
# ----------------------------

# One single-element row per example file under EXAMPLE_DIR.
example_images = [
    [str(EXAMPLE_DIR / file_name)]
    for file_name in (
        "Cheetah_032.jpg",
        "Leopard_001.jpg",
        "Lion_003.jpg",
        "Puma_001.jpg",
        "Tiger_001.jpg",
    )
]


# ----------------------------
# Interface
# ----------------------------

# Gradio UI: one image input, three label outputs in the same order as the
# tuple returned by classify_cat (custom model, CLIP, OpenAI).
iface = gr.Interface(
    fn=classify_cat,
    # type="filepath": classify_with_openai opens its argument as a file,
    # so the callback needs a path rather than pixel data.
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Label(label="Custom Model"),
        gr.Label(label="CLIP"),
        gr.Label(label="OpenAI")
    ],
    title="Big Cat Classification",
    description="Compare Custom Model vs CLIP vs OpenAI",
    examples=example_images
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()