File size: 3,445 Bytes
3e6a4eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
816b68d
3e6a4eb
 
 
 
 
 
816b68d
 
 
 
 
 
 
 
 
 
3e6a4eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c981ebe
 
 
 
 
3e6a4eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import base64
import json
import mimetypes
import os

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI
from transformers import pipeline

# Load variables from a local .env file (if present) before reading settings.
load_dotenv()

# Model name is overridable through the environment; the key has no default.
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4.1-mini")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Without a key the client stays None, which downstream code treats as
# "OpenAI classification disabled".
if OPENAI_API_KEY:
    openai_client = OpenAI(api_key=OPENAI_API_KEY)
else:
    openai_client = None

# Build both Hugging Face pipelines once at import time so every request
# reuses the same loaded models.
vit_classifier = pipeline(
    task="image-classification",
    model="adisaljusi/cifar10-vit",
)
clip_detector = pipeline(
    "zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)

# The ten CIFAR-10 class names; used as CLIP candidate labels and listed in
# the OpenAI prompt.
labels_cifar10 = (
    "airplane automobile bird cat deer "
    "dog frog horse ship truck"
).split()


def encode_image(image_path):
    """Return the contents of the file at *image_path* as a base64 string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def _parse_json_reply(raw_text):
    """Decode the model's text reply into a dict.

    A reply wrapped in a Markdown code fence (for example one opening with
    a "```json" line) is unwrapped before decoding.

    Returns:
        The decoded JSON object, or a dict with ``raw_response`` and
        ``warning`` keys when the reply cannot be used as a JSON object.
    """
    text = raw_text.strip()
    if text.startswith("```"):
        # Drop the opening fence line, then everything from the closing fence.
        _, _, text = text.partition("\n")
        text = text.rsplit("```", 1)[0].strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return {
            "raw_response": raw_text,
            "warning": "OpenAI response was not valid JSON.",
        }

    # The prompt asks for an object; a bare string/number/list is not usable
    # as a label/confidence/reasoning record.
    if not isinstance(parsed, dict):
        return {
            "raw_response": raw_text,
            "warning": "OpenAI response was valid JSON but not an object.",
        }
    return parsed


def classify_with_openai(image_path):
    """Ask the configured OpenAI model to classify the image at *image_path*.

    The model is prompted to pick one of ``labels_cifar10`` and answer with a
    JSON object containing ``label``, ``confidence`` and ``reasoning``.

    Args:
        image_path: Path of the image file to send.

    Returns:
        The decoded JSON object on success. Otherwise a dict describing the
        problem: ``{"error": ...}`` when no API key is configured or the API
        request fails, or ``{"raw_response": ..., "warning": ...}`` when the
        reply is not a JSON object.
    """
    if openai_client is None:
        return {
            "error": "Missing OPENAI_API_KEY. Add it to your environment or .env file to enable OpenAI classification."
        }

    # Function-scope import: only needed once a client is configured.
    from openai import OpenAIError

    prompt = (
        "Classify the object in this image. Choose the best matching label from this list: "
        f"{', '.join(labels_cifar10)}. "
        "Return valid JSON with exactly these keys: "
        "label, confidence, reasoning. "
        "The confidence must be a number between 0 and 1."
    )

    # Uploads are not always JPEG; label the data URL with the type implied
    # by the file extension and fall back to JPEG only when it is unknown.
    mime_type, _ = mimetypes.guess_type(image_path)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    image_url = f"data:{mime_type};base64,{encode_image(image_path)}"

    try:
        response = openai_client.responses.create(
            model=OPENAI_MODEL,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {"type": "input_image", "image_url": image_url},
                    ],
                }
            ],
        )
    except OpenAIError as exc:
        # Report the failure in the same shape as the missing-key case so the
        # caller can still show the other models' results.
        return {"error": f"OpenAI request failed: {exc}"}

    return _parse_json_reply(response.output_text)


def classify_image(image):
    """Classify *image* with the ViT, CLIP and OpenAI models.

    Args:
        image: The image to classify; it is handed unchanged to each of the
            three classifiers.

    Returns:
        A dict with one entry per model. The ViT and CLIP entries map each
        predicted label to its score; the OpenAI entry is whatever
        ``classify_with_openai`` returns.
    """

    def scores_by_label(predictions):
        return {entry["label"]: entry["score"] for entry in predictions}

    comparison = {}
    comparison["ViT Classification"] = scores_by_label(vit_classifier(image))
    comparison["CLIP Zero-Shot Classification"] = scores_by_label(
        clip_detector(image, candidate_labels=labels_cifar10)
    )
    comparison["OpenAI Vision Classification"] = classify_with_openai(image)
    return comparison


# Sample images offered in the UI; each row holds the single value for the
# interface's one image input.
example_images = [
    [f"example_images/{subject}.jpg"]
    for subject in ("airplane", "automobile", "cat", "dog", "horse", "ship")
]

# Gradio UI: one image input (delivered to classify_image as a file path) and
# a JSON view of the per-model results.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs=gr.JSON(),
    title="CIFAR-10 Classification Comparison",
    description=(
        "Upload an image and compare classification results from three models: "
        # Name the model actually in use; OPENAI_MODEL can be overridden via
        # the environment, so a hard-coded name would go stale.
        f"a fine-tuned ViT model, a zero-shot CLIP model, and OpenAI {OPENAI_MODEL} vision."
    ),
    examples=example_images,
)

iface.launch()