Spaces:
Sleeping
Sleeping
| """ | |
| Aircraft classifier — comparison of three models: | |
| 1. Custom fine-tuned ViT (transfer learning, 6 aircraft classes) | |
| 2. CLIP zero-shot (open-source: openai/clip-vit-large-patch14) | |
| 3. OpenAI GPT-4o vision (closed-source) | |
| """ | |
| import base64 | |
| import io | |
| import os | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import pipeline | |
# Model identifier handed to the image-classification pipeline.
CUSTOM_MODEL_ID = "dubattim/aircraft-vit-fs26"

# Single table of (label slug, display name) pairs. Both public lookups below
# are derived from it, so the slug list and the name mapping cannot drift apart.
_AIRCRAFT = (
    ("ah64_apache", "AH-64 Apache"),
    ("airbus_a220", "Airbus A220"),
    ("airbus_a320", "Airbus A320"),
    ("airbus_a321neo", "Airbus A321neo"),
    ("airbus_a330", "Airbus A330"),
    ("airbus_a350", "Airbus A350"),
    ("airbus_a380", "Airbus A380"),
    ("atr_72", "ATR 72"),
    ("b2_spirit", "Northrop B-2 Spirit"),
    ("boeing_737", "Boeing 737"),
    ("boeing_747", "Boeing 747"),
    ("boeing_777", "Boeing 777"),
    ("boeing_787_dreamliner", "Boeing 787 Dreamliner"),
    ("cessna_172", "Cessna 172"),
    ("concorde", "Concorde"),
    ("embraer_e190", "Embraer E190"),
    ("f16_fighting_falcon", "F-16 Fighting Falcon"),
    ("mig21", "MiG-21"),
    ("sr71_blackbird", "SR-71 Blackbird"),
    ("v22_osprey", "V-22 Osprey"),
)

# Label slugs, in table order.
AIRCRAFT_LABELS = [slug for slug, _ in _AIRCRAFT]

# Slug -> human-readable display name, in table order.
PRETTY = dict(_AIRCRAFT)
# Fine-tuned ViT, served through the standard image-classification pipeline.
custom_classifier = pipeline(task="image-classification", model=CUSTOM_MODEL_ID)

# CLIP zero-shot pipeline; candidate labels are supplied on each call.
clip_classifier = pipeline(
    "zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)

# The OpenAI client is optional. Any failure while importing the package or
# constructing the client is deliberately tolerated and leaves
# `openai_client` as None so the rest of the module can still load.
try:
    from openai import OpenAI

    openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
except Exception:
    openai_client = None
def _encode_image(path: str) -> str:
    """Return the image at *path* as a base64-encoded PNG string.

    The image is converted to RGB before being re-encoded as PNG in memory;
    the PNG bytes are then base64-encoded and returned as text.
    """
    png_buffer = io.BytesIO()
    with Image.open(path) as source:
        source.convert("RGB").save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def classify_with_openai(image_path: str) -> dict:
    """Ask GPT-4o to pick one aircraft label for the image at *image_path*.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        ``{"error": "OPENAI_API_KEY not configured"}`` when no client is
        available or the ``OPENAI_API_KEY`` environment variable is unset.
        Otherwise a dict with one entry per label in ``AIRCRAFT_LABELS``,
        keyed by display name, whose value is ``1.0`` if that display name
        occurs (case-insensitively) in the model's reply and ``0.0`` if not.

    Exceptions raised by the API call itself are not caught here.
    """
    if openai_client is None or not os.environ.get("OPENAI_API_KEY"):
        return {"error": "OPENAI_API_KEY not configured"}

    b64 = _encode_image(image_path)
    options = ", ".join(PRETTY.values())
    prompt = (
        "Identify the aircraft in this image. Respond with EXACTLY one of "
        f"these labels and nothing else: {options}."
    )
    resp = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{b64}"},
                    },
                ],
            }
        ],
        max_tokens=20,
    )

    # The message content is optional in the API response; treat a missing
    # reply as an empty answer instead of failing on None.strip().
    content = resp.choices[0].message.content or ""
    answer = content.strip().lower()

    # Substring match rather than equality so trailing punctuation or extra
    # words in the reply do not prevent a hit.
    return {
        PRETTY[k]: (1.0 if PRETTY[k].lower() in answer else 0.0)
        for k in AIRCRAFT_LABELS
    }
def classify(image_path):
    """Run all three classifiers on one image and return their scores.

    Args:
        image_path: Path of the image to classify, or ``None`` when no image
            was provided.

    Returns:
        A 3-tuple ``(custom, clip, openai_out)`` of dicts mapping a label to
        a float score, in the order the UI wires its three Label outputs.
        All three are empty when *image_path* is ``None``.
    """
    if image_path is None:
        return {}, {}, {}

    # Map the fine-tuned model's raw labels to display names, falling back to
    # the raw label when it has no entry in PRETTY.
    custom = {
        PRETTY.get(r["label"], r["label"]): float(r["score"])
        for r in custom_classifier(image_path)
    }
    # CLIP is given the display names directly as candidate labels.
    clip = {
        r["label"]: float(r["score"])
        for r in clip_classifier(image_path, candidate_labels=list(PRETTY.values()))
    }

    openai_out = classify_with_openai(image_path)
    # classify_with_openai reports failure as {"error": "<message>"}. A Label
    # dict must map labels to numeric confidences, so surface the message as
    # a label with zero confidence instead of passing a string as the score.
    error = openai_out.get("error")
    if isinstance(error, str):
        openai_out = {error: 0.0}

    return custom, clip, openai_out
# One example image per known label, each wrapped in a single-input row.
examples = [[f"example_images/{label}.jpg"] for label in AIRCRAFT_LABELS]

# Intro text shown at the top of the page, ending with the class list.
_intro = (
    "# Aircraft Classifier — Model Comparison\n"
    "Upload an aircraft image and compare three approaches: a fine-tuned "
    "ViT (transfer learning), CLIP zero-shot, and OpenAI GPT-4o vision.\n\n"
    f"**Classes:** {', '.join(PRETTY.values())}"
)

with gr.Blocks(title="Aircraft Classifier — Model Comparison") as demo:
    gr.Markdown(_intro)

    # Input image beside a column holding one result panel per model.
    with gr.Row():
        inp = gr.Image(type="filepath", label="Input image")
        with gr.Column():
            out_custom = gr.Label(label="Custom ViT (fine-tuned)", num_top_classes=5)
            out_clip = gr.Label(label="CLIP zero-shot", num_top_classes=5)
            out_openai = gr.Label(label="OpenAI GPT-4o", num_top_classes=5)

    btn = gr.Button("Classify", variant="primary")
    # Output order must match the tuple returned by classify().
    btn.click(fn=classify, inputs=inp, outputs=[out_custom, out_clip, out_openai])

    gr.Examples(examples=examples, inputs=inp)

if __name__ == "__main__":
    demo.launch()