"""
Aircraft classifier — comparison of three models:

1. Custom fine-tuned ViT (transfer learning, 6 aircraft classes)
2. CLIP zero-shot (open-source: openai/clip-vit-large-patch14)
3. OpenAI GPT-4o vision (closed-source)

Importing this module loads both Hugging Face pipelines and builds the
Gradio ``demo`` app; running it as a script launches the app.
"""

import base64
import io
import os

import gradio as gr
from PIL import Image
from transformers import pipeline

CUSTOM_MODEL_ID = "dubattim/aircraft-vit-fs26"

# Machine-readable class identifiers, also used as example image file stems.
# NOTE(review): the module docstring mentions 6 aircraft classes while this
# list has 20 entries — confirm which one is current.
AIRCRAFT_LABELS = [
    "ah64_apache",
    "airbus_a220",
    "airbus_a320",
    "airbus_a321neo",
    "airbus_a330",
    "airbus_a350",
    "airbus_a380",
    "atr_72",
    "b2_spirit",
    "boeing_737",
    "boeing_747",
    "boeing_777",
    "boeing_787_dreamliner",
    "cessna_172",
    "concorde",
    "embraer_e190",
    "f16_fighting_falcon",
    "mig21",
    "sr71_blackbird",
    "v22_osprey",
]

# Identifier -> human-readable name. The readable names double as the CLIP
# candidate labels and as the answer options offered to GPT-4o.
PRETTY = {
    "ah64_apache": "AH-64 Apache",
    "airbus_a220": "Airbus A220",
    "airbus_a320": "Airbus A320",
    "airbus_a321neo": "Airbus A321neo",
    "airbus_a330": "Airbus A330",
    "airbus_a350": "Airbus A350",
    "airbus_a380": "Airbus A380",
    "atr_72": "ATR 72",
    "b2_spirit": "Northrop B-2 Spirit",
    "boeing_737": "Boeing 737",
    "boeing_747": "Boeing 747",
    "boeing_777": "Boeing 777",
    "boeing_787_dreamliner": "Boeing 787 Dreamliner",
    "cessna_172": "Cessna 172",
    "concorde": "Concorde",
    "embraer_e190": "Embraer E190",
    "f16_fighting_falcon": "F-16 Fighting Falcon",
    "mig21": "MiG-21",
    "sr71_blackbird": "SR-71 Blackbird",
    "v22_osprey": "V-22 Osprey",
}

# Both pipelines are created once at import time and reused for every request.
custom_classifier = pipeline("image-classification", model=CUSTOM_MODEL_ID)
clip_classifier = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)

# The OpenAI backend is optional: if the package cannot be imported or the
# client cannot be constructed, the app still runs with the other two models.
try:
    from openai import OpenAI

    openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
except Exception:
    openai_client = None


def _encode_image(path: str) -> str:
    """Return the image at *path* re-encoded as a base64 PNG string.

    The image is converted to RGB before being saved as PNG, so the result
    always matches the ``data:image/png`` URL it is embedded in.
    """
    with Image.open(path) as img:
        rgb = img.convert("RGB")
        buf = io.BytesIO()
        rgb.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def classify_with_openai(image_path: str) -> dict:
    """Ask GPT-4o which of the known aircraft is shown in the image.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        On success, a dict mapping every readable label in ``PRETTY`` to
        ``1.0`` if that label appears in the model's answer, else ``0.0``.
        On failure, ``{"error": <message>}``: when the client or API key is
        missing, when the API request fails, or when the answer names none
        of the known labels.
    """
    if openai_client is None or not os.environ.get("OPENAI_API_KEY"):
        return {"error": "OPENAI_API_KEY not configured"}

    # Safe to import here: a non-None client means the package imported fine.
    from openai import OpenAIError

    b64 = _encode_image(image_path)
    options = ", ".join(PRETTY.values())
    prompt = (
        f"Identify the aircraft in this image. Respond with EXACTLY one of "
        f"these labels and nothing else: {options}."
    )
    try:
        resp = openai_client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{b64}"},
                        },
                    ],
                }
            ],
            max_tokens=20,
        )
    except OpenAIError as exc:
        # Report the failure instead of raising so the caller can still show
        # the results of the other two models.
        return {"error": f"OpenAI request failed: {exc}"}

    # ``content`` may be None (e.g. when the model returns no text).
    answer = (resp.choices[0].message.content or "").strip().lower()
    scores = {
        PRETTY[k]: (1.0 if PRETTY[k].lower() in answer else 0.0)
        for k in AIRCRAFT_LABELS
    }
    if not any(scores.values()):
        # An all-zero dict would be displayed with its first entry as the
        # top "prediction"; surface the unmatched answer instead.
        return {"error": f"OpenAI answer did not match any label: {answer!r}"}
    return scores


def classify(image_path):
    """Run all three classifiers on one image for the Gradio UI.

    Args:
        image_path: File path of the uploaded image, or ``None`` when no
            image has been provided.

    Returns:
        A 3-tuple ``(custom, clip, openai_out)``. ``custom`` and ``clip`` are
        label -> score dicts. ``openai_out`` is a label -> score dict, or the
        error message as a plain string when the OpenAI backend reported an
        error. All three are empty dicts when *image_path* is ``None``.
    """
    if image_path is None:
        return {}, {}, {}

    # Map the fine-tuned model's raw labels to readable names where known.
    custom = {
        PRETTY.get(r["label"], r["label"]): float(r["score"])
        for r in custom_classifier(image_path)
    }
    clip = {
        r["label"]: float(r["score"])
        for r in clip_classifier(image_path, candidate_labels=list(PRETTY.values()))
    }
    openai_out = classify_with_openai(image_path)
    if "error" in openai_out:
        # gr.Label expects float confidences as dict values; pass the message
        # as a bare string so it is shown as the label text instead.
        openai_out = openai_out["error"]
    return custom, clip, openai_out


# One example image per class, expected under example_images/<identifier>.jpg.
examples = [[f"example_images/{c}.jpg"] for c in AIRCRAFT_LABELS]

with gr.Blocks(title="Aircraft Classifier — Model Comparison") as demo:
    gr.Markdown(
        "# Aircraft Classifier — Model Comparison\n"
        "Upload an aircraft image and compare three approaches: a fine-tuned "
        "ViT (transfer learning), CLIP zero-shot, and OpenAI GPT-4o vision.\n\n"
        f"**Classes:** {', '.join(PRETTY.values())}"
    )
    with gr.Row():
        inp = gr.Image(type="filepath", label="Input image")
        with gr.Column():
            out_custom = gr.Label(label="Custom ViT (fine-tuned)", num_top_classes=5)
            out_clip = gr.Label(label="CLIP zero-shot", num_top_classes=5)
            out_openai = gr.Label(label="OpenAI GPT-4o", num_top_classes=5)
    btn = gr.Button("Classify", variant="primary")
    btn.click(classify, inputs=inp, outputs=[out_custom, out_clip, out_openai])
    gr.Examples(examples=examples, inputs=inp)

if __name__ == "__main__":
    demo.launch()