| import os
|
| import time
|
| import numpy as np
|
| from PIL import Image
|
| import tensorflow as tf
|
| import gradio as gr
|
|
|
|
|
# Resolve all paths relative to this file so the app works from any CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")

# All three models consume the same 224x224 RGB input (see preprocess()).
IMG_SIZE = (224, 224)

# (display name, .tflite filename under MODELS_DIR) for each model to load.
# The display names here are mirrored in the placeholder Markdown cards below.
MODEL_CONFIGS = [
    ("MobileNetV3 Small", "mobilenetv3_small.tflite"),
    ("MobileNetV3 Large", "mobilenetv3_large.tflite"),
    ("EfficientNetV2-B0", "efficientnetv2b0.tflite"),
]

# Rank markers prefixed to the top-3 rows of each result card.
# NOTE(review): all three literals render identically ("π₯") and look like
# mojibake — presumably gold/silver/bronze medal emoji originally. Confirm the
# file's encoding and restore the intended characters.
MEDAL = ["π₯", "π₯", "π₯"]
|
|
|
|
|
# Human-readable class names, one per line; index i of a model's output
# vector corresponds to CLASS_NAMES[i] (blank lines are skipped).
# Fix: pin the encoding — open()'s default is locale-dependent (e.g. cp1252
# on Windows), which would corrupt any non-ASCII label.
with open(os.path.join(MODELS_DIR, "labels.txt"), encoding="utf-8") as f:
    CLASS_NAMES = [line.strip() for line in f if line.strip()]
|
|
|
|
|
def _load_interpreter(filename: str) -> tf.lite.Interpreter:
    """Load one TFLite model from MODELS_DIR with its tensors allocated."""
    interp = tf.lite.Interpreter(model_path=os.path.join(MODELS_DIR, filename))
    interp.allocate_tensors()
    return interp


# Eagerly load every configured model at import time so the first request
# doesn't pay the allocation cost.
INTERPRETERS: list[tuple[str, tf.lite.Interpreter]] = [
    (display_name, _load_interpreter(filename))
    for display_name, filename in MODEL_CONFIGS
]

print(f"Loaded {len(INTERPRETERS)} TFLite models.")
|
|
|
|
|
|
|
def preprocess(pil_image: Image.Image) -> np.ndarray:
    """Convert a PIL image into a (1, H, W, 3) float32 batch.

    The image is forced to RGB, bilinearly resized to IMG_SIZE, and left in
    the raw [0, 255] range (no normalization is applied here).
    """
    rgb = pil_image.convert("RGB")
    resized = rgb.resize(IMG_SIZE, Image.BILINEAR)
    batch = np.asarray(resized, dtype=np.float32)[np.newaxis, ...]
    return batch
|
|
|
|
|
def run_model(interp: tf.lite.Interpreter, img_array: np.ndarray) -> tuple[np.ndarray, float]:
    """Execute one TFLite inference on *img_array*.

    Only interp.invoke() is timed; tensor setup/readout is excluded.
    Returns (probs, latency_ms) where probs is the first (only) row of the
    model's output tensor.
    """
    in_index = interp.get_input_details()[0]["index"]
    out_index = interp.get_output_details()[0]["index"]

    interp.set_tensor(in_index, img_array)
    start = time.perf_counter()
    interp.invoke()
    elapsed_ms = (time.perf_counter() - start) * 1000

    return interp.get_tensor(out_index)[0], elapsed_ms
|
|
|
|
|
def fmt_model_card(display_name: str, probs: np.ndarray, latency_ms: float) -> str:
    """Render one model's result as a Markdown card.

    Shows the model name, the inference latency, and the top-3 classes with a
    percentage and a 20-slot text bar (one block per 5%).
    """
    # Indices of the three highest-probability classes, best first.
    ranked = np.argsort(probs)[::-1][:3]

    parts = [f"### {display_name}", f"β± `{latency_ms:.1f} ms`", ""]
    for medal, idx in zip(MEDAL, ranked):
        pct = probs[idx] * 100
        parts.append(f"{medal} **{CLASS_NAMES[idx]}** β {pct:.1f}%")
        parts.append(f"`{'β' * int(pct / 5):<20}` ")
        parts.append("")
    return "\n".join(parts)
|
|
|
|
|
|
|
def predict(pil_image):
    """Run every loaded model on *pil_image*; return one Markdown card each.

    With no image selected, returns a placeholder string for all three cards.
    """
    if pil_image is None:
        placeholder = "_(upload an image and press Run Inference)_"
        return placeholder, placeholder, placeholder

    batch = preprocess(pil_image)
    cards = [
        fmt_model_card(display_name, *run_model(interp, batch))
        for display_name, interp in INTERPRETERS
    ]
    # One output slot per model card in the UI.
    return cards[0], cards[1], cards[2]
|
|
|
|
|
|
|
# Custom stylesheet: give each result column a rounded panel look and hide
# the default Gradio footer.
CSS = """
.model-card { background: var(--block-background-fill); border-radius: 12px; padding: 16px; }
footer { display: none !important; }
"""
|
|
|
# UI layout: input column (image + Run/Clear buttons) on the left, one
# Markdown "card" per model on the right.
# Fix: `theme` and `css` are gr.Blocks constructor arguments, not launch()
# arguments — gr.Blocks.launch() accepts neither keyword, so as written the
# custom styling was never applied and launch raised TypeError.
# NOTE(review): the intro copy says the models "vote on the best answer",
# but no voting/ensemble step exists in predict() — confirm whether the text
# or the feature is out of date.
with gr.Blocks(title="Animal Toy Classifier", theme=gr.themes.Soft(), css=CSS) as demo:
    gr.Markdown("# π¦ Animal Toy Classifier\nUpload a photo of an animal toy β three models will predict the class and vote on the best answer.")

    with gr.Row():

        # Left: image input plus the action buttons.
        with gr.Column(scale=1, min_width=260):
            img_in = gr.Image(
                type="pil",
                sources=["upload"],
                label="Input Image",
                height=280,
            )
            with gr.Row():
                run_btn = gr.Button("βΆ Run Inference", variant="primary", scale=3)
                clear_btn = gr.ClearButton(
                    components=[img_in],
                    value="β Clear",
                    scale=1,
                )

        # Right: one placeholder card per model, replaced by predict() output.
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column(elem_classes="model-card"):
                    out_small = gr.Markdown("### MobileNetV3 Small\n\n_(waiting)_")
                with gr.Column(elem_classes="model-card"):
                    out_large = gr.Markdown("### MobileNetV3 Large\n\n_(waiting)_")
                with gr.Column(elem_classes="model-card"):
                    out_eff = gr.Markdown("### EfficientNetV2-B0\n\n_(waiting)_")

    # Run: feed the uploaded image through all three models.
    run_btn.click(
        fn=predict,
        inputs=[img_in],
        outputs=[out_small, out_large, out_eff],
    )

    # Clear: ClearButton only resets the components it was given (the image),
    # so the result cards are reset to their placeholders via this callback.
    clear_btn.click(
        fn=lambda: (
            "### MobileNetV3 Small\n\n_(waiting)_",
            "### MobileNetV3 Large\n\n_(waiting)_",
            "### EfficientNetV2-B0\n\n_(waiting)_",
        ),
        inputs=[],
        outputs=[out_small, out_large, out_eff],
    )
|
|
|
|
|
if __name__ == "__main__":
    # Fix: launch() does not accept `theme`/`css` keywords — those are
    # gr.Blocks constructor options — so passing them here raised TypeError
    # at startup. Launch with defaults.
    demo.launch()
|
|
|