# Hugging Face Spaces page header at capture time: "Sleeping".
import base64
import json
import logging
import mimetypes
import os

import gradio as gr
import openai
from transformers import pipeline
# ---------------------------------------------------------------------------
# Car brands — must match the classes the ViT model was trained on
# ---------------------------------------------------------------------------
# This one list is used in three places:
#   - as CLIP's zero-shot candidate labels,
#   - as the label set GPT-4o is asked to score,
#   - to size the ViT top_k request.
# Keep the spelling identical to the fine-tuned model's label names.
CAR_BRANDS = [
    "BMW",
    "Dodge",
    "Ferrari",
    "Ford",
    "Jeep",
    "Lamborghini",
    "Porsche",
    "Rolls-Royce",
    "Toyota",
]
# ---------------------------------------------------------------------------
# Load models (loaded once at startup)
# ---------------------------------------------------------------------------
# Both pipelines are built at import time and reused by every request.

# Fine-tuned ViT brand classifier hosted on the Hugging Face Hub.
# NOTE(review): an earlier comment described this model as trained on Stanford
# Cars with 8 brand classes, while CAR_BRANDS lists 9 — confirm against the
# model's id2label mapping.
vit_classifier = pipeline(
    task="image-classification",
    model="nenzilea/car-classification",
)

# CLIP used zero-shot: the brand names are supplied as candidate labels at
# call time rather than being baked into the model.
clip_classifier = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)
# ---------------------------------------------------------------------------
# OpenAI helper
# ---------------------------------------------------------------------------
def encode_image_to_base64(image_path: str) -> str:
    """Return the raw bytes of the file at *image_path* as base64 text.

    The result is an ASCII string suitable for embedding in a ``data:`` URL.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def _guess_mime_type(image_path: str) -> str:
    """Return the MIME type to put in the image data URL for *image_path*.

    Lookup order:
      1. the standard-library mapping (e.g. ``.jpg`` -> ``image/jpeg``);
      2. ``image/<extension>``;
      3. ``image/jpeg`` when the path has no extension at all, because a bare
         ``image/`` is not a valid type.
    """
    guessed, _encoding = mimetypes.guess_type(image_path)
    if guessed and guessed.startswith("image/"):
        return guessed
    ext = os.path.splitext(image_path)[1].lower().lstrip(".")
    if ext in ("", "jpg", "jpeg"):
        return "image/jpeg"
    return f"image/{ext}"


def _parse_brand_scores(text, brands: list) -> dict:
    """Turn the model's reply *text* into a ``{brand: score}`` dict over *brands*.

    Parsing:
      - The span from the first ``{`` to the last ``}`` is parsed as JSON.
        Replies wrapped in prose or Markdown code fences therefore still work.
      - Keys are matched to *brands* case-insensitively.

    Scoring:
      - Missing, non-numeric, negative or non-finite values count as 0.0.
      - When any score is positive, the scores are rescaled to sum to 1.0.
        The model does not reliably honour the "must sum to 1.0" instruction.

    Raises:
        ValueError: if *text* is empty/None or contains no JSON object
            (``json.JSONDecodeError`` is also a ValueError).
    """
    if not text:
        raise ValueError("empty response from model")
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object in model response")
    parsed = json.loads(text[start : end + 1])

    by_lower = {str(key).strip().lower(): value for key, value in parsed.items()}
    scores = {}
    for brand in brands:
        try:
            value = float(by_lower.get(brand.lower(), 0.0))
        except (TypeError, ValueError):
            value = 0.0
        # `0 < v < inf` also rejects NaN, since every comparison with NaN is False.
        scores[brand] = value if 0.0 < value < float("inf") else 0.0

    total = sum(scores.values())
    if total > 0.0:
        scores = {brand: value / total for brand, value in scores.items()}
    return scores


def classify_with_openai(image_path: str) -> dict:
    """Send image to GPT-4o and ask it to return confidence scores per brand.

    Args:
        image_path: Path to the image file (the Gradio input uses type="filepath").

    Returns:
        On success, a ``{brand: score}`` dict with one entry per CAR_BRANDS
        item. Scores are normalised to sum to 1.0, or are all 0.0 if the model
        gave no positive score.
        On failure, a single-entry ``{"Error: ...": 1.0}`` dict, so the
        message is shown in the ``gr.Label`` output instead of raising into
        the UI.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return {"Error: OPENAI_API_KEY not set": 1.0}
    try:
        client = openai.OpenAI(api_key=api_key)
        mime_type = _guess_mime_type(image_path)
        base64_image = encode_image_to_base64(image_path)
        prompt = (
            f"You are a car classification expert. Classify the car brand shown in this image. "
            f"The possible classes are: {', '.join(CAR_BRANDS)}. "
            "Respond ONLY with a valid JSON object where each key is a brand name from the list and "
            "each value is a confidence score between 0.0 and 1.0. All scores must sum to 1.0. "
            'Example format: {"BMW": 0.05, "Ferrari": 0.85, "Ford": 0.02, ...}'
        )
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:{mime_type};base64,{base64_image}"},
                        },
                    ],
                }
            ],
            max_tokens=300,
        )
        # content can be None (e.g. a refusal); the parser turns that into a ValueError.
        text = response.choices[0].message.content
        return _parse_brand_scores(text, CAR_BRANDS)
    except openai.AuthenticationError:
        return {"Error: Invalid OpenAI API key": 1.0}
    except Exception as exc:
        # UI boundary: keep the full traceback in the logs, show a short message.
        logging.getLogger(__name__).exception("GPT-4o classification failed")
        return {f"Error: {str(exc)[:60]}": 1.0}
# ---------------------------------------------------------------------------
# Main classification function
# ---------------------------------------------------------------------------
def classify_car(image):
    """Classify one image with all three models, in the order ViT, CLIP, GPT-4o.

    Args:
        image: Filepath from the Gradio image input, or None when nothing is
            uploaded.

    Returns:
        A 3-tuple of ``{label: score}`` dicts ``(vit, clip, gpt4o)``, one per
        ``gr.Label`` output. ViT and CLIP scores are rounded to 4 decimals.
        When *image* is None, all three dicts are empty.
    """
    if image is None:
        return {}, {}, {}

    def as_label_dict(predictions):
        # Pipeline output is a list of {"label": ..., "score": ...} entries.
        return {item["label"]: round(item["score"], 4) for item in predictions}

    # Custom ViT — ask for a score for every known brand.
    vit_scores = as_label_dict(vit_classifier(image, top_k=len(CAR_BRANDS)))
    # CLIP — zero-shot, with the brand names as candidate labels.
    clip_scores = as_label_dict(clip_classifier(image, candidate_labels=CAR_BRANDS))
    # GPT-4o vision — returns its own label dict (or an error entry).
    gpt_scores = classify_with_openai(image)

    return vit_scores, clip_scores, gpt_scores
# ---------------------------------------------------------------------------
# Example images (add representative car images to example_images/)
# ---------------------------------------------------------------------------
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")
_img_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "example_images")

# One single-element row per image, the row format gr.Examples takes for a
# single input component.
# The folder is optional: without it, fall back to no examples instead of
# letting os.listdir raise FileNotFoundError and crash the app at import.
example_images = (
    [
        [os.path.join(_img_dir, name)]
        for name in sorted(os.listdir(_img_dir))
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]
    if os.path.isdir(_img_dir)
    else []
)
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Custom stylesheet passed to gr.Blocks(css=...).
#   .title, .subtitle, .model-header, .classify-btn
#       Attached to components via their `elem_classes` arguments.
#   footer
#       Hides every <footer> element on the page; presumably the goal is to
#       hide Gradio's built-in page footer. Confirm if the layout changes.
css = """
.title { text-align: center; margin-bottom: 0.25rem; }
.subtitle { text-align: center; color: #6b7280; margin-bottom: 1.5rem; font-size: 0.95rem; }
.model-header { font-weight: 600; font-size: 1rem; margin-bottom: 0.25rem; padding: 0.4rem 0.75rem;
border-radius: 6px; background: #f3f4f6; }
.classify-btn { max-width: 200px; margin: 0 auto; }
footer { display: none !important; }
"""
# Page layout:
#   - header (title + subtitle);
#   - the image input beside three result panels;
#   - optional example images;
#   - a Classify button wired to classify_car.
with gr.Blocks(title="Car Brand Classification", css=css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.Markdown("# 🚗 Car Brand Classification", elem_classes="title")
    gr.Markdown(
        "Upload a car image and compare predictions from three models side by side.",
        elem_classes="subtitle",
    )

    with gr.Row():
        # Input + button
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="filepath",  # classify_car / classify_with_openai expect a path
                label="Car Image",
                height=280,
            )
            classify_btn = gr.Button("Classify", variant="primary", elem_classes="classify-btn")

        # Results: one gr.Label per model, side by side
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("**Custom ViT** — fine-tuned", elem_classes="model-header")
                    vit_output = gr.Label(num_top_classes=5, label="")
                with gr.Column():
                    gr.Markdown("**CLIP** — zero-shot", elem_classes="model-header")
                    clip_output = gr.Label(num_top_classes=5, label="")
                with gr.Column():
                    gr.Markdown("**GPT-4o** — vision LLM", elem_classes="model-header")
                    openai_output = gr.Label(num_top_classes=5, label="")

    # Examples: skip the panel entirely when example_images/ yielded no files,
    # rather than rendering an empty one.
    if example_images:
        gr.Examples(
            examples=example_images,
            inputs=input_image,
            label="Example Images",
        )

    # Output order must match the tuple order returned by classify_car.
    classify_btn.click(
        fn=classify_car,
        inputs=[input_image],
        outputs=[vit_output, clip_output, openai_output],
    )

# Launch only when run as a script (`python app.py`), so importing this module
# (e.g. for tests or `gradio app.py` reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()