dubattim's picture
expand to 20 classes (incl. neo + more airliners)
0a7efb5
"""
Aircraft classifier — comparison of three models:
1. Custom fine-tuned ViT (transfer learning, 20 aircraft classes)
2. CLIP zero-shot (open-source: openai/clip-vit-large-patch14)
3. OpenAI GPT-4o vision (closed-source)
"""
import base64
import io
import os
import gradio as gr
from PIL import Image
from transformers import pipeline
# Model identifier handed to the image-classification pipeline.
CUSTOM_MODEL_ID = "dubattim/aircraft-vit-fs26"

# Class key -> human-readable display name. Insertion order is the display
# order used everywhere else in the app.
PRETTY = dict(
    ah64_apache="AH-64 Apache",
    airbus_a220="Airbus A220",
    airbus_a320="Airbus A320",
    airbus_a321neo="Airbus A321neo",
    airbus_a330="Airbus A330",
    airbus_a350="Airbus A350",
    airbus_a380="Airbus A380",
    atr_72="ATR 72",
    b2_spirit="Northrop B-2 Spirit",
    boeing_737="Boeing 737",
    boeing_747="Boeing 747",
    boeing_777="Boeing 777",
    boeing_787_dreamliner="Boeing 787 Dreamliner",
    cessna_172="Cessna 172",
    concorde="Concorde",
    embraer_e190="Embraer E190",
    f16_fighting_falcon="F-16 Fighting Falcon",
    mig21="MiG-21",
    sr71_blackbird="SR-71 Blackbird",
    v22_osprey="V-22 Osprey",
)

# Class keys, in the same order as the mapping above.
AIRCRAFT_LABELS = list(PRETTY)
# Both pipelines are built once at import time and reused for every request.

# Fine-tuned ViT image classifier.
custom_classifier = pipeline(task="image-classification", model=CUSTOM_MODEL_ID)

# CLIP zero-shot classifier; the candidate labels are supplied on each call.
clip_classifier = pipeline(
    "zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)
# Optional OpenAI client. It stays None whenever the SDK cannot be imported or
# the client cannot be constructed, so the rest of the app keeps working
# without the GPT-4o comparison.
openai_client = None
try:
    from openai import OpenAI

    openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
except Exception:
    # Deliberate best-effort: callers check for None before using the client.
    pass
def _encode_image(path: str) -> str:
    """Return the image at *path* as base64 text of its PNG encoding.

    The image is opened, converted to RGB, written as PNG into an in-memory
    buffer, and the buffer contents are base64-encoded and decoded as UTF-8.
    """
    png_buffer = io.BytesIO()
    with Image.open(path) as source:
        source.convert("RGB").save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def classify_with_openai(image_path: str) -> dict:
    """Ask GPT-4o which known aircraft is in the image and score each class.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A dict mapping every display name in ``PRETTY`` to ``1.0`` when that
        name occurs (case-insensitively, as a substring) in the model's reply
        and ``0.0`` otherwise. If the reply names none of the classes, or is
        empty, every score is ``0.0``. When no client or API key is available,
        ``{"error": "OPENAI_API_KEY not configured"}`` is returned instead.
    """
    if openai_client is None or not os.environ.get("OPENAI_API_KEY"):
        # NOTE(review): this value is a string while the success path maps
        # labels to floats, and the result is shown in a gr.Label — confirm
        # the component renders this shape as intended.
        return {"error": "OPENAI_API_KEY not configured"}

    b64 = _encode_image(image_path)
    options = ", ".join(PRETTY.values())
    prompt = (
        "Identify the aircraft in this image. Respond with EXACTLY one of "
        f"these labels and nothing else: {options}."
    )
    resp = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{b64}"},
                    },
                ],
            }
        ],
        max_tokens=20,
    )
    # The SDK types `content` as optional, so a reply without text would make
    # `.strip()` raise AttributeError. Treat a missing reply as an empty
    # answer, which scores every class 0.0.
    answer = (resp.choices[0].message.content or "").strip().lower()
    return {
        PRETTY[k]: (1.0 if PRETTY[k].lower() in answer else 0.0)
        for k in AIRCRAFT_LABELS
    }
def classify(image_path):
    """Run the three classifiers on one image and return their scores.

    Returns a 3-tuple of dicts in the order (custom ViT, CLIP zero-shot,
    OpenAI). The first two map a label to a float score; custom ViT labels
    are translated through ``PRETTY`` when a display name exists. When
    *image_path* is ``None``, three empty dicts are returned.
    """
    if image_path is None:
        return {}, {}, {}

    vit_predictions = custom_classifier(image_path)
    vit_scores = {
        PRETTY.get(pred["label"], pred["label"]): float(pred["score"])
        for pred in vit_predictions
    }

    display_names = list(PRETTY.values())
    clip_predictions = clip_classifier(image_path, candidate_labels=display_names)
    clip_scores = {pred["label"]: float(pred["score"]) for pred in clip_predictions}

    gpt_scores = classify_with_openai(image_path)
    return vit_scores, clip_scores, gpt_scores
# One example row per class key: a single-element list holding that class's
# image path under example_images/.
examples = [[f"example_images/{c}.jpg"] for c in AIRCRAFT_LABELS]
# Gradio UI: an image input, one Label output per model, a button that runs
# `classify`, and a set of example inputs.
# NOTE(review): the indentation of this section is missing in this copy of the
# file, so the intended nesting of the Row / Column / button under the `with`
# statements cannot be read off here — confirm against the original layout
# before re-indenting.
with gr.Blocks(title="Aircraft Classifier — Model Comparison") as demo:
# Page header: title, short description, and the display names from PRETTY.
gr.Markdown(
"# Aircraft Classifier — Model Comparison\n"
"Upload an aircraft image and compare three approaches: a fine-tuned "
"ViT (transfer learning), CLIP zero-shot, and OpenAI GPT-4o vision.\n\n"
f"**Classes:** {', '.join(PRETTY.values())}"
)
with gr.Row():
# `classify` treats its input as a file path, matching type="filepath" here.
inp = gr.Image(type="filepath", label="Input image")
with gr.Column():
# One Label per model, each limited to 5 classes via num_top_classes.
out_custom = gr.Label(label="Custom ViT (fine-tuned)", num_top_classes=5)
out_clip = gr.Label(label="CLIP zero-shot", num_top_classes=5)
out_openai = gr.Label(label="OpenAI GPT-4o", num_top_classes=5)
btn = gr.Button("Classify", variant="primary")
# The button runs `classify` on the input image; its three return values go
# to the three Label outputs in the order listed.
btn.click(classify, inputs=inp, outputs=[out_custom, out_clip, out_openai])
# Example inputs wired to the image component.
gr.Examples(examples=examples, inputs=inp)
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
demo.launch()