nenzilea's picture
Upload 2 files
2d4062d verified
import gradio as gr
from transformers import pipeline
import openai
import base64
import os
import json
# ---------------------------------------------------------------------------
# Car brands — must match the classes the ViT model was trained on
# ---------------------------------------------------------------------------
# Single shared label set: it sizes the ViT top-k, is passed to CLIP as the
# zero-shot candidate labels, is listed in the GPT-4o prompt, and defines the
# keys of GPT-4o's returned scores.
CAR_BRANDS = [
    "BMW",
    "Dodge",
    "Ferrari",
    "Ford",
    "Jeep",
    "Lamborghini",
    "Porsche",
    "Rolls-Royce",
    "Toyota",
]
# ---------------------------------------------------------------------------
# Models — constructed once at import time and reused by every request
# ---------------------------------------------------------------------------
# Fine-tuned ViT brand classifier pulled from the Hugging Face Hub. Per the
# author it was trained on Stanford Cars; its label names are expected to be
# the entries of CAR_BRANDS (9 brands) — confirm against the model card.
# Point `model` at a different Hub ID to try another checkpoint.
vit_classifier = pipeline(
    task="image-classification",
    model="nenzilea/car-classification",
)

# CLIP zero-shot classifier: it has no fixed label set; the candidate labels
# are supplied on each call (classify_car passes CAR_BRANDS).
clip_classifier = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)
# ---------------------------------------------------------------------------
# OpenAI helper
# ---------------------------------------------------------------------------
def encode_image_to_base64(image_path: str) -> str:
    """Return the raw bytes of the file at *image_path* as base64 text.

    The result is used as the payload of a ``data:<mime>;base64,...`` URL.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def _parse_brand_scores(text: str) -> dict:
    """Turn a model reply containing a JSON object into one score per brand.

    The outermost ``{...}`` span in *text* is parsed, so prose or code fences
    around the object are tolerated. Keys are matched to CAR_BRANDS
    case-insensitively (surrounding whitespace ignored). Brands missing from the
    reply score 0.0, and negative values are clamped to 0.0. When the total is
    positive, scores are rescaled to sum to 1.0. The result always has exactly
    the keys of CAR_BRANDS, in that order.

    Raises:
        ValueError: no JSON object found, the JSON is not an object, the JSON is
            malformed, or a score string is not numeric.
        TypeError: a score value is not convertible to float (e.g. a list).
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object in model response")
    parsed = json.loads(text[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError("model response JSON is not an object")

    # The model sometimes varies key casing/spacing ("bmw", " Ferrari ").
    by_name = {str(key).strip().lower(): value for key, value in parsed.items()}
    scores = {
        brand: max(0.0, float(by_name.get(brand.lower(), 0.0)))
        for brand in CAR_BRANDS
    }

    total = sum(scores.values())
    if total > 0:
        # The prompt demands scores summing to 1.0, but compliance is not
        # guaranteed; rescale so this column is comparable to ViT/CLIP outputs.
        scores = {brand: score / total for brand, score in scores.items()}
    return scores


def classify_with_openai(image_path: str) -> dict:
    """Send image to GPT-4o and ask it to return confidence scores per brand.

    Args:
        image_path: Path to a local image file (the Gradio input uses
            ``type="filepath"``).

    Returns:
        ``{brand: score}`` for every entry in CAR_BRANDS on success. On any
        failure, a single-entry ``{"Error: ...": 1.0}`` dict is returned
        instead of raising, so the Gradio Label displays the problem in place
        of the scores.
    """
    try:
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            return {"Error: OPENAI_API_KEY not set": 1.0}
        client = openai.OpenAI(api_key=api_key)

        ext = os.path.splitext(image_path)[1].lower().lstrip(".")
        # An extension-less path would otherwise yield the invalid MIME type
        # "image/"; fall back to JPEG in that case.
        mime_type = "image/jpeg" if ext in ("", "jpg", "jpeg") else f"image/{ext}"
        base64_image = encode_image_to_base64(image_path)

        prompt = (
            "You are a car classification expert. Classify the car brand shown in this image. "
            f"The possible classes are: {', '.join(CAR_BRANDS)}. "
            "Respond ONLY with a valid JSON object where each key is a brand name from the list and "
            "each value is a confidence score between 0.0 and 1.0. All scores must sum to 1.0. "
            'Example format: {"BMW": 0.05, "Ferrari": 0.85, "Ford": 0.02, ...}'
        )

        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:{mime_type};base64,{base64_image}"},
                        },
                    ],
                }
            ],
            max_tokens=300,
            # JSON mode: the reply content is constrained to a JSON object.
            response_format={"type": "json_object"},
        )
        # `content` may be None; treat that as an empty reply so parsing
        # reports a clear error instead of an AttributeError.
        text = response.choices[0].message.content or ""
        return _parse_brand_scores(text)
    except openai.AuthenticationError:
        return {"Error: Invalid OpenAI API key": 1.0}
    except Exception as e:
        return {f"Error: {str(e)[:60]}": 1.0}
# ---------------------------------------------------------------------------
# Main classification function
# ---------------------------------------------------------------------------
def classify_car(image):
    """Score *image* with all three models.

    Args:
        image: Value of the ``gr.Image(type="filepath")`` input — a file path,
            or None when nothing has been uploaded.

    Returns:
        A ``(vit, clip, gpt4o)`` tuple of ``{label: score}`` dicts, in the same
        order as the three output Labels. All three dicts are empty when
        *image* is None.
    """
    if image is None:
        return {}, {}, {}

    def rounded_scores(predictions):
        # Pipeline output is a list of {"label": ..., "score": ...} dicts;
        # gr.Label wants a single {label: confidence} mapping.
        return {pred["label"]: round(pred["score"], 4) for pred in predictions}

    # Custom ViT — fine-tuned on Stanford Cars; ask for one score per brand.
    vit_scores = rounded_scores(vit_classifier(image, top_k=len(CAR_BRANDS)))
    # CLIP — zero-shot, with the brand names themselves as candidate labels.
    clip_scores = rounded_scores(clip_classifier(image, candidate_labels=CAR_BRANDS))
    # GPT-4o vision; on failure it returns an error entry rather than raising.
    gpt_scores = classify_with_openai(image)

    return vit_scores, clip_scores, gpt_scores
# ---------------------------------------------------------------------------
# Example images (add representative car images to example_images/)
# ---------------------------------------------------------------------------
_img_dir = os.path.join(os.path.dirname(__file__), "example_images")
example_images = [
[os.path.join(_img_dir, f)]
for f in sorted(os.listdir(_img_dir))
if f.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))
]
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
css = """
.title { text-align: center; margin-bottom: 0.25rem; }
.subtitle { text-align: center; color: #6b7280; margin-bottom: 1.5rem; font-size: 0.95rem; }
.model-header { font-weight: 600; font-size: 1rem; margin-bottom: 0.25rem; padding: 0.4rem 0.75rem;
border-radius: 6px; background: #f3f4f6; }
.classify-btn { max-width: 200px; margin: 0 auto; }
footer { display: none !important; }
"""

with gr.Blocks(title="Car Brand Classification", css=css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.Markdown("# 🚗 Car Brand Classification", elem_classes="title")
    gr.Markdown(
        "Upload a car image and compare predictions from three models side by side.",
        elem_classes="subtitle",
    )

    with gr.Row():
        # Input + button
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="filepath",  # classify_car receives a file path (or None)
                label="Car Image",
                height=280,
            )
            classify_btn = gr.Button("Classify", variant="primary", elem_classes="classify-btn")

        # Results: one Label per model, in the order classify_car returns them.
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("**Custom ViT** — fine-tuned", elem_classes="model-header")
                    vit_output = gr.Label(num_top_classes=5, label="")
                with gr.Column():
                    gr.Markdown("**CLIP** — zero-shot", elem_classes="model-header")
                    clip_output = gr.Label(num_top_classes=5, label="")
                with gr.Column():
                    gr.Markdown("**GPT-4o** — vision LLM", elem_classes="model-header")
                    openai_output = gr.Label(num_top_classes=5, label="")

    # Examples: only rendered when at least one example image was found.
    if example_images:
        gr.Examples(
            examples=example_images,
            inputs=input_image,
            label="Example Images",
        )

    classify_btn.click(
        fn=classify_car,
        inputs=[input_image],
        outputs=[vit_output, clip_output, openai_output],
    )

# Start the server only when run as a script, so importing this module has no
# side effect beyond building the UI.
if __name__ == "__main__":
    demo.launch()