"""Gradio app comparing CIFAR-10 predictions from ViT, CLIP, and an OpenAI vision model."""

import base64
import json
import mimetypes
import os

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI, OpenAIError
from transformers import pipeline

load_dotenv()

OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Only build a client when a key is configured; classify_with_openai reports
# the missing key to the user instead of the app failing at startup.
openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None

# Load models
vit_classifier = pipeline("image-classification", model="adisaljusi/cifar10-vit")
clip_detector = pipeline(
    model="openai/clip-vit-large-patch14",
    task="zero-shot-image-classification",
)

labels_cifar10 = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]


def encode_image(image_path):
    """Return the contents of the file at *image_path* as a base64 ASCII string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def _image_data_url(image_path):
    """Build a base64 ``data:`` URL for the image, labelled with its real MIME type."""
    mime_type, _ = mimetypes.guess_type(image_path)
    if not mime_type or not mime_type.startswith("image/"):
        # Unknown or non-image extension: fall back to JPEG, the type this
        # app previously assumed for every input.
        mime_type = "image/jpeg"
    return f"data:{mime_type};base64,{encode_image(image_path)}"


def _strip_code_fence(text):
    """Remove a surrounding Markdown code fence (e.g. ```json ... ```) if present."""
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence line (it may carry a language tag) and the
        # closing fence. A fence with no newline yields "" and fails JSON parsing.
        _, _, text = text.partition("\n")
        text = text.rsplit("```", 1)[0].strip()
    return text


def classify_with_openai(image_path):
    """Ask the configured OpenAI vision model to pick a CIFAR-10 label for an image.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        A dict. On success, the JSON value parsed from the model's reply. The
        prompt requests the keys ``label``, ``confidence`` and ``reasoning``,
        but they are not validated. If the reply is not valid JSON, a dict
        with ``raw_response`` and ``warning``. If no API key is configured or
        the API request fails, a dict with a single ``error`` message.
    """
    if openai_client is None:
        return {
            "error": "Missing OPENAI_API_KEY. Add it to your environment or .env file to enable OpenAI classification."
        }

    prompt = (
        "Classify the object in this image. Choose the best matching label from this list: "
        f"{', '.join(labels_cifar10)}. "
        "Return valid JSON with exactly these keys: "
        "label, confidence, reasoning. "
        "The confidence must be a number between 0 and 1."
    )

    try:
        response = openai_client.responses.create(
            model=OPENAI_MODEL,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {
                            "type": "input_image",
                            "image_url": _image_data_url(image_path),
                        },
                    ],
                }
            ],
        )
    except OpenAIError as exc:
        # Report API/network failures in the UI, like the missing-key case
        # above, instead of aborting the ViT/CLIP results as well.
        return {"error": f"OpenAI request failed: {exc}"}

    raw_text = response.output_text
    try:
        return json.loads(_strip_code_fence(raw_text))
    except json.JSONDecodeError:
        return {
            "raw_response": raw_text,
            "warning": "OpenAI response was not valid JSON.",
        }


def classify_image(image):
    """Run the ViT, CLIP and OpenAI classifiers on one image.

    Args:
        image: Filepath of the input image (the Gradio input below uses
            ``type="filepath"``), or None when no image was provided.

    Returns:
        A dict keyed by model name. The ViT and CLIP entries map label to
        score; the OpenAI entry is whatever ``classify_with_openai`` returned.
        When no image is provided, a dict with a single ``error`` message.
    """
    if image is None:
        return {"error": "Please upload an image to classify."}

    vit_results = vit_classifier(image)
    vit_output = {result["label"]: result["score"] for result in vit_results}

    clip_results = clip_detector(image, candidate_labels=labels_cifar10)
    clip_output = {result["label"]: result["score"] for result in clip_results}

    openai_output = classify_with_openai(image)

    return {
        "ViT Classification": vit_output,
        "CLIP Zero-Shot Classification": clip_output,
        "OpenAI Vision Classification": openai_output,
    }


example_images = [
    ["example_images/airplane.jpg"],
    ["example_images/automobile.jpg"],
    ["example_images/cat.jpg"],
    ["example_images/dog.jpg"],
    ["example_images/horse.jpg"],
    ["example_images/ship.jpg"],
]

iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs=gr.JSON(),
    title="CIFAR-10 Classification Comparison",
    description=(
        "Upload an image and compare classification results from three models: "
        f"a fine-tuned ViT model, a zero-shot CLIP model, and OpenAI {OPENAI_MODEL} vision."
    ),
    examples=example_images,
)

iface.launch()