"""Gradio demo that compares three car-brand classifiers side by side.

An uploaded image is scored by a fine-tuned ViT image classifier, a CLIP
zero-shot classifier, and GPT-4o (via the OpenAI API); each result is shown
in its own ``gr.Label`` panel.
"""

import base64
import json
import os

import gradio as gr
import openai
from transformers import pipeline

# ---------------------------------------------------------------------------
# Car brands — must match the classes the ViT model was trained on
# ---------------------------------------------------------------------------
CAR_BRANDS = [
    'BMW',
    'Dodge',
    'Ferrari',
    'Ford',
    'Jeep',
    'Lamborghini',
    'Porsche',
    'Rolls-Royce',
    'Toyota',
]

# ---------------------------------------------------------------------------
# Load models (loaded once at startup)
# ---------------------------------------------------------------------------
# Custom fine-tuned ViT model (trained on Stanford Cars, one class per entry
# in CAR_BRANDS).
# Replace with your actual Hugging Face model ID after training and pushing
vit_classifier = pipeline(
    "image-classification",
    model="nenzilea/car-classification"
)

# CLIP zero-shot classifier
clip_classifier = pipeline(
    model="openai/clip-vit-large-patch14",
    task="zero-shot-image-classification"
)


# ---------------------------------------------------------------------------
# OpenAI helper
# ---------------------------------------------------------------------------
def encode_image_to_base64(image_path: str) -> str:
    """Return the contents of the file at *image_path* as a base64 string."""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


def _extract_brand_scores(text: str) -> dict:
    """Parse the outermost ``{...}`` span in *text* into per-brand scores.

    Returns a dict with one float entry per brand in ``CAR_BRANDS``; brands
    missing from the parsed object get ``0.0``.

    Raises:
        ValueError: if *text* contains no ``{...}`` span, if the span is not
            valid JSON (``json.JSONDecodeError`` is a ``ValueError``), or if a
            score cannot be converted to ``float``.
    """
    start = text.find("{")
    end = text.rfind("}") + 1
    # find() yields -1 and rfind() + 1 yields 0 when a brace is missing; without
    # this check the slice would be empty and json.loads would fail with an
    # unhelpful "Expecting value" message.
    if start == -1 or end <= start:
        raise ValueError("No JSON object found in model response")
    scores = json.loads(text[start:end])
    return {brand: float(scores.get(brand, 0.0)) for brand in CAR_BRANDS}


def classify_with_openai(image_path: str) -> dict:
    """Send image to GPT-4o and ask it to return confidence scores per brand.

    Returns a dict mapping each brand in ``CAR_BRANDS`` to a float score. Any
    failure (missing or invalid API key, request error, unparsable reply) is
    reported as a single-entry dict whose key is an ``"Error: ..."`` message
    with value ``1.0``, so the message is displayed in the label output
    instead of raising.
    """
    try:
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            return {"Error: OPENAI_API_KEY not set": 1.0}

        client = openai.OpenAI(api_key=api_key)

        ext = os.path.splitext(image_path)[1].lower().lstrip(".")
        # A path without an extension would otherwise produce the invalid MIME
        # type "image/"; fall back to JPEG in that case as a best-effort guess.
        if ext and ext not in ("jpg", "jpeg"):
            mime_type = f"image/{ext}"
        else:
            mime_type = "image/jpeg"
        base64_image = encode_image_to_base64(image_path)

        prompt = (
            f"You are a car classification expert. Classify the car brand shown in this image. "
            f"The possible classes are: {', '.join(CAR_BRANDS)}. "
            "Respond ONLY with a valid JSON object where each key is a brand name from the list and "
            "each value is a confidence score between 0.0 and 1.0. All scores must sum to 1.0. "
            'Example format: {"BMW": 0.05, "Ferrari": 0.85, "Ford": 0.02, ...}'
        )

        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:{mime_type};base64,{base64_image}"},
                        },
                    ],
                }
            ],
            max_tokens=300,
        )

        # The message content may be None; treat that as an empty reply so the
        # parser reports it as a missing JSON object.
        text = response.choices[0].message.content or ""
        return _extract_brand_scores(text)

    except openai.AuthenticationError:
        return {"Error: Invalid OpenAI API key": 1.0}
    except Exception as e:
        # UI boundary: show a truncated message in the label rather than crash.
        return {f"Error: {str(e)[:60]}": 1.0}


# ---------------------------------------------------------------------------
# Main classification function
# ---------------------------------------------------------------------------
def classify_car(image):
    """Classify *image* with all three models.

    *image* is the value of the ``gr.Image(type="filepath")`` input, or
    ``None`` when no image is provided.

    Returns a ``(vit, clip, openai)`` tuple of label-to-score dicts, or three
    empty dicts when *image* is ``None``.
    """
    if image is None:
        return {}, {}, {}

    # Custom ViT — fine-tuned on Stanford Cars
    vit_results = vit_classifier(image, top_k=len(CAR_BRANDS))
    vit_output = {r["label"]: round(r["score"], 4) for r in vit_results}

    # CLIP — zero-shot with brand names as candidate labels
    clip_results = clip_classifier(image, candidate_labels=CAR_BRANDS)
    clip_output = {r["label"]: round(r["score"], 4) for r in clip_results}

    # OpenAI GPT-4o Vision
    openai_output = classify_with_openai(image)

    return vit_output, clip_output, openai_output


# ---------------------------------------------------------------------------
# Example images (add representative car images to example_images/)
# ---------------------------------------------------------------------------
_img_dir = os.path.join(os.path.dirname(__file__), "example_images")
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")

# The example directory is optional: a missing folder must not crash the app
# at import time, so fall back to an empty example list.
if os.path.isdir(_img_dir):
    example_images = [
        [os.path.join(_img_dir, f)]
        for f in sorted(os.listdir(_img_dir))
        if f.lower().endswith(_IMAGE_EXTENSIONS)
    ]
else:
    example_images = []

# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
css = """
.title { text-align: center; margin-bottom: 0.25rem; }
.subtitle { text-align: center; color: #6b7280; margin-bottom: 1.5rem; font-size: 0.95rem; }
.model-header { font-weight: 600; font-size: 1rem; margin-bottom: 0.25rem; padding: 0.4rem 0.75rem; border-radius: 6px; background: #f3f4f6; }
.classify-btn { max-width: 200px; margin: 0 auto; }
footer { display: none !important; }
"""

with gr.Blocks(title="Car Brand Classification", css=css, theme=gr.themes.Soft()) as demo:

    # Header
    gr.Markdown("# 🚗 Car Brand Classification", elem_classes="title")
    gr.Markdown(
        "Upload a car image and compare predictions from three models side by side.",
        elem_classes="subtitle"
    )

    # Input + button
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="filepath",
                label="Car Image",
                height=280,
            )
            classify_btn = gr.Button("Classify", variant="primary", elem_classes="classify-btn")

        # Results
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("**Custom ViT** — fine-tuned", elem_classes="model-header")
                    vit_output = gr.Label(num_top_classes=5, label="")
                with gr.Column():
                    gr.Markdown("**CLIP** — zero-shot", elem_classes="model-header")
                    clip_output = gr.Label(num_top_classes=5, label="")
                with gr.Column():
                    gr.Markdown("**GPT-4o** — vision LLM", elem_classes="model-header")
                    openai_output = gr.Label(num_top_classes=5, label="")

    # Examples — only rendered when at least one example image is available
    if example_images:
        gr.Examples(
            examples=example_images,
            inputs=input_image,
            label="Example Images",
        )

    classify_btn.click(
        fn=classify_car,
        inputs=[input_image],
        outputs=[vit_output, clip_output, openai_output],
    )

demo.launch()