"""Gradio demo: classify an animal-toy photo with three TFLite models side by side.

Each model runs on the same preprocessed image; per-model top-3 predictions and
inference latency are rendered as Markdown "cards" in the UI.
"""

import os
import time

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image

# ─── Constants ────────────────────────────────────────────────────────────────
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")
IMG_SIZE = (224, 224)  # all three models share this input resolution

# (display name shown in the UI, model filename under MODELS_DIR)
MODEL_CONFIGS = [
    ("MobileNetV3 Small", "mobilenetv3_small.tflite"),
    ("MobileNetV3 Large", "mobilenetv3_large.tflite"),
    ("EfficientNetV2-B0", "efficientnetv2b0.tflite"),
]
MEDAL = ["🥇", "🥈", "🥉"]  # rank markers for the top-3 rows

# ─── Load labels ──────────────────────────────────────────────────────────────
# One class name per non-blank line; index in this list == model output index.
with open(os.path.join(MODELS_DIR, "labels.txt"), encoding="utf-8") as f:
    CLASS_NAMES = [line.strip() for line in f if line.strip()]

# ─── Load TFLite interpreters at startup ──────────────────────────────────────
# Loaded once at import time so per-request latency covers inference only.
INTERPRETERS: list[tuple[str, tf.lite.Interpreter]] = []
for display_name, filename in MODEL_CONFIGS:
    path = os.path.join(MODELS_DIR, filename)
    interp = tf.lite.Interpreter(model_path=path)
    interp.allocate_tensors()
    INTERPRETERS.append((display_name, interp))
print(f"Loaded {len(INTERPRETERS)} TFLite models.")


# ─── Inference helpers ────────────────────────────────────────────────────────
def preprocess(pil_image: Image.Image) -> np.ndarray:
    """Resize, convert to float32 array, add batch dim. Values in [0, 255].

    NOTE(review): no mean/std normalization is applied here — assumes the
    TFLite models embed their own rescaling; confirm against model export.
    """
    img = pil_image.convert("RGB").resize(IMG_SIZE, Image.BILINEAR)
    arr = np.array(img, dtype=np.float32)  # (224, 224, 3)
    return np.expand_dims(arr, axis=0)  # (1, 224, 224, 3)


def run_model(
    interp: tf.lite.Interpreter, img_array: np.ndarray
) -> tuple[np.ndarray, float]:
    """Run a single TFLite inference.

    Args:
        interp: an interpreter whose tensors are already allocated.
        img_array: float32 batch of shape (1, 224, 224, 3).

    Returns:
        (probs, latency_ms) where ``probs`` is the first (only) row of the
        output tensor, shape (num_classes,), and ``latency_ms`` times only
        the ``invoke()`` call — tensor set/get is excluded.
    """
    input_idx = interp.get_input_details()[0]["index"]
    output_idx = interp.get_output_details()[0]["index"]
    interp.set_tensor(input_idx, img_array)
    t0 = time.perf_counter()
    interp.invoke()
    latency_ms = (time.perf_counter() - t0) * 1000
    probs = interp.get_tensor(output_idx)[0]  # shape (num_classes,)
    return probs, latency_ms


def fmt_model_card(display_name: str, probs: np.ndarray, latency_ms: float) -> str:
    """Build a Markdown string for one model's result card.

    Shows the model name, latency, and the top-3 classes with a text bar
    proportional to each probability (assumed already in [0, 1]).
    """
    top3_idx = np.argsort(probs)[::-1][:3]
    lines = [f"### {display_name}", f"⏱ `{latency_ms:.1f} ms`", ""]
    for rank, idx in enumerate(top3_idx):
        pct = probs[idx] * 100
        bar = "█" * int(pct / 5)  # simple bar, max 20 chars
        lines.append(f"{MEDAL[rank]} **{CLASS_NAMES[idx]}** — {pct:.1f}%")
        lines.append(f"`{bar:<20}` ")
        lines.append("")
    return "\n".join(lines)


# ─── Main prediction function ─────────────────────────────────────────────────
def predict(pil_image):
    """Run all three models on ``pil_image``; return one Markdown card each.

    Returns a 3-tuple (one string per model, in MODEL_CONFIGS order) so it
    can be wired directly to the three output components. A ``None`` image
    (nothing uploaded) yields a placeholder prompt in every card.
    """
    if pil_image is None:
        empty = "_(upload an image and press Run Inference)_"
        return empty, empty, empty
    img_array = preprocess(pil_image)
    cards: list[str] = []
    for display_name, interp in INTERPRETERS:
        probs, latency_ms = run_model(interp, img_array)
        cards.append(fmt_model_card(display_name, probs, latency_ms))
    return cards[0], cards[1], cards[2]


# ─── Gradio UI ────────────────────────────────────────────────────────────────
CSS = """
.model-card { background: var(--block-background-fill); border-radius: 12px; padding: 16px; }
footer { display: none !important; }
"""

# FIX: theme and css are gr.Blocks() constructor arguments, not launch()
# arguments — Blocks.launch() accepts neither, so passing them there raises
# TypeError (or silently drops them) and the theme/CSS never applied.
with gr.Blocks(title="Animal Toy Classifier", theme=gr.themes.Soft(), css=CSS) as demo:
    gr.Markdown(
        "# 🦁 Animal Toy Classifier\nUpload a photo of an animal toy — three models will predict the class and vote on the best answer."
    )
    with gr.Row():
        # ── Left: upload + button ────────────────────────────────────────────
        with gr.Column(scale=1, min_width=260):
            img_in = gr.Image(
                type="pil",
                sources=["upload"],
                label="Input Image",
                height=280,
            )
            with gr.Row():
                run_btn = gr.Button("▶ Run Inference", variant="primary", scale=3)
                clear_btn = gr.ClearButton(
                    components=[img_in],
                    value="✕ Clear",
                    scale=1,
                )
        # ── Right: per-model result cards ────────────────────────────────────
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column(elem_classes="model-card"):
                    out_small = gr.Markdown("### MobileNetV3 Small\n\n_(waiting)_")
                with gr.Column(elem_classes="model-card"):
                    out_large = gr.Markdown("### MobileNetV3 Large\n\n_(waiting)_")
                with gr.Column(elem_classes="model-card"):
                    out_eff = gr.Markdown("### EfficientNetV2-B0\n\n_(waiting)_")

    # ── Wire up ──────────────────────────────────────────────────────────────
    run_btn.click(
        fn=predict,
        inputs=[img_in],
        outputs=[out_small, out_large, out_eff],
    )
    # Reset outputs when Clear is pressed (ClearButton only clears img_in itself)
    clear_btn.click(
        fn=lambda: (
            "### MobileNetV3 Small\n\n_(waiting)_",
            "### MobileNetV3 Large\n\n_(waiting)_",
            "### EfficientNetV2-B0\n\n_(waiting)_",
        ),
        inputs=[],
        outputs=[out_small, out_large, out_eff],
    )

if __name__ == "__main__":
    demo.launch()