# ToyGame / app.py
# Commit 50709d2 by ronalhung (verified): "Upload app.py with huggingface_hub"
import os
import time
import numpy as np
from PIL import Image
import tensorflow as tf
import gradio as gr
# ─── Constants ────────────────────────────────────────────────────────────────
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory holding labels.txt and the .tflite model files.
MODELS_DIR = os.path.join(BASE_DIR, "models")
# Target size handed to PIL's Image.resize in preprocess().
IMG_SIZE = (224, 224)
# (name shown in the UI, .tflite filename inside MODELS_DIR), in display order.
MODEL_CONFIGS = [
    ("MobileNetV3 Small", "mobilenetv3_small.tflite"),
    ("MobileNetV3 Large", "mobilenetv3_large.tflite"),
    ("EfficientNetV2-B0", "efficientnetv2b0.tflite"),
]
# Rank markers for the top-3 predictions (gold, silver, bronze). Written as
# \N{...} escapes so the source stays ASCII and an encoding round-trip cannot
# turn the emoji into mojibake.
MEDAL = [
    "\N{FIRST PLACE MEDAL}",
    "\N{SECOND PLACE MEDAL}",
    "\N{THIRD PLACE MEDAL}",
]
# ─── Load labels ──────────────────────────────────────────────────────────────
# Read the class labels, one per line. The encoding is explicit so that
# non-ASCII labels decode the same way regardless of the host locale.
with open(os.path.join(MODELS_DIR, "labels.txt"), encoding="utf-8") as f:
    # Blank lines are skipped (e.g. a trailing newline). Line order is assumed
    # to match the models' output indices, since fmt_model_card looks labels
    # up by output index.
    CLASS_NAMES = [line.strip() for line in f if line.strip()]
# ─── Load TFLite interpreters at startup ──────────────────────────────────────
def _load_interpreter(filename: str) -> tf.lite.Interpreter:
    """Open MODELS_DIR/<filename> as a TFLite interpreter with tensors allocated."""
    interpreter = tf.lite.Interpreter(model_path=os.path.join(MODELS_DIR, filename))
    interpreter.allocate_tensors()
    return interpreter


# One (display name, ready-to-run interpreter) pair per MODEL_CONFIGS entry,
# built eagerly at import time and kept in the same order as MODEL_CONFIGS.
INTERPRETERS: list[tuple[str, tf.lite.Interpreter]] = [
    (display_name, _load_interpreter(filename))
    for display_name, filename in MODEL_CONFIGS
]
print(f"Loaded {len(INTERPRETERS)} TFLite models.")
# ─── Inference helpers ────────────────────────────────────────────────────────
def preprocess(pil_image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a one-image float32 batch.

    The image is forced to RGB and resized to ``IMG_SIZE`` with bilinear
    resampling. A leading batch axis is added, so a 224x224 target yields
    shape ``(1, 224, 224, 3)``. Pixel values are not rescaled and stay in
    the 0-255 range.
    """
    rgb = pil_image.convert("RGB")
    resized = rgb.resize(IMG_SIZE, Image.BILINEAR)
    pixels = np.asarray(resized, dtype=np.float32)
    return pixels[np.newaxis, ...]
def run_model(interp: tf.lite.Interpreter, img_array: np.ndarray) -> tuple[np.ndarray, float]:
    """Push one batch through a TFLite interpreter and time the call.

    Only ``invoke()`` is inside the timed window; copying the input in and
    reading the output back are excluded from the reported latency.

    Returns:
        ``(probs, latency_ms)``, where ``probs`` is row 0 of the model's first
        output tensor (the batch is assumed to hold a single image).
    """
    # NOTE(review): interpreters are shared module-level objects with no lock;
    # this presumably relies on Gradio running one predict() at a time (its
    # default per-event concurrency). Confirm before raising concurrency.
    in_slot = interp.get_input_details()[0]["index"]
    out_slot = interp.get_output_details()[0]["index"]
    interp.set_tensor(in_slot, img_array)

    start = time.perf_counter()
    interp.invoke()
    elapsed_ms = 1000 * (time.perf_counter() - start)

    batch_out = interp.get_tensor(out_slot)
    return batch_out[0], elapsed_ms
def fmt_model_card(display_name: str, probs: np.ndarray, latency_ms: float) -> str:
    """Build a Markdown string for one model's result card.

    The card contains:

    - a heading;
    - the measured latency;
    - the three highest-scoring classes, each shown with its ``MEDAL``
      marker, its percentage, and a text bar of one block per 5 %
      (at most 20 blocks).

    Args:
        display_name: Model name used as the card heading.
        probs: 1-D per-class scores in the same order as ``CLASS_NAMES``.
            NOTE(review): treated as probabilities in [0, 1] (multiplied by
            100 for display); confirm the exported models end in a softmax.
        latency_ms: Inference time to display, in milliseconds.

    Returns:
        The card as a newline-joined Markdown string.
    """
    # \N{...} escapes keep these symbols immune to source-encoding mix-ups.
    bar_char = "\N{FULL BLOCK}"
    dash = "\N{EM DASH}"
    bar_width = 20  # 20 blocks x 5 % per block = 100 %

    top3_idx = np.argsort(probs)[::-1][:3]
    lines = [f"### {display_name}", f"⏱ `{latency_ms:.1f} ms`", ""]
    for rank, idx in enumerate(top3_idx):
        pct = probs[idx] * 100
        # Clamp so scores above 1 (e.g. unnormalised outputs) cannot push the
        # bar past its fixed width.
        bar = bar_char * min(int(pct / 5), bar_width)
        # If labels.txt lists fewer classes than the model outputs, show the
        # raw index instead of raising IndexError mid-render.
        label = CLASS_NAMES[idx] if idx < len(CLASS_NAMES) else f"class {idx}"
        lines.append(f"{MEDAL[rank]} **{label}** {dash} {pct:.1f}%")
        lines.append(f"`{bar:<{bar_width}}` ")
        lines.append("")
    return "\n".join(lines)
# ─── Main prediction function ─────────────────────────────────────────────────
def predict(pil_image):
    """Run the loaded models on one image and return their Markdown cards.

    Returns a 3-tuple of Markdown strings: the cards for the first three
    entries of ``INTERPRETERS``, in load order, one per output component in
    the UI. Without an image, every slot gets the same upload prompt.
    """
    if pil_image is not None:
        batch = preprocess(pil_image)
        rendered = [
            fmt_model_card(name, *run_model(model, batch))
            for name, model in INTERPRETERS
        ]
        return rendered[0], rendered[1], rendered[2]
    placeholder = "_(upload an image and press Run Inference)_"
    return (placeholder,) * 3
# ─── Gradio UI ────────────────────────────────────────────────────────────────
# Custom stylesheet:
# - styles the result columns tagged with elem_classes="model-card";
# - hides Gradio's footer.
CSS = (
    "\n"
    ".model-card { background: var(--block-background-fill); border-radius: 12px; padding: 16px; }\n"
    "footer { display: none !important; }\n"
)
# Placeholder Markdown for each result card, in MODEL_CONFIGS order. It is
# used as the initial card value and restored by Clear. Deriving it from
# MODEL_CONFIGS keeps the headings in sync with the names predict() renders.
_WAITING_CARDS = [f"### {name}\n\n_(waiting)_" for name, _ in MODEL_CONFIGS]

with gr.Blocks(title="Animal Toy Classifier") as demo:
    # Non-ASCII punctuation uses \N{...} escapes so it cannot be mangled into
    # mojibake by an encoding round-trip.
    gr.Markdown(
        "# 🦁 Animal Toy Classifier\n"
        "Upload a photo of an animal toy \N{EM DASH} three models will each "
        "predict the class."
    )
    with gr.Row():
        # ── Left: upload + buttons ─────────────────────────────────────────
        with gr.Column(scale=1, min_width=260):
            img_in = gr.Image(
                type="pil",
                sources=["upload"],
                label="Input Image",
                height=280,
            )
            with gr.Row():
                run_btn = gr.Button(
                    "\N{BLACK RIGHT-POINTING TRIANGLE} Run Inference",
                    variant="primary",
                    scale=3,
                )
                clear_btn = gr.ClearButton(
                    components=[img_in],
                    value="\N{MULTIPLICATION X} Clear",
                    scale=1,
                )
        # ── Right: per-model result cards (same order as MODEL_CONFIGS) ────
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column(elem_classes="model-card"):
                    out_small = gr.Markdown(_WAITING_CARDS[0])
                with gr.Column(elem_classes="model-card"):
                    out_large = gr.Markdown(_WAITING_CARDS[1])
                with gr.Column(elem_classes="model-card"):
                    out_eff = gr.Markdown(_WAITING_CARDS[2])

    # ── Wire up ────────────────────────────────────────────────────────────
    run_btn.click(
        fn=predict,
        inputs=[img_in],
        outputs=[out_small, out_large, out_eff],
    )
    # ClearButton empties the image. This handler additionally puts each
    # card back to its "waiting" placeholder.
    clear_btn.click(
        fn=lambda: tuple(_WAITING_CARDS),
        inputs=[],
        outputs=[out_small, out_large, out_eff],
    )
if __name__ == "__main__":
    # NOTE(review): theme and css are passed to launch() rather than to
    # gr.Blocks(). This presumably targets Gradio 6+, which moved these
    # options to launch(); confirm the pinned gradio version.
    soft_theme = gr.themes.Soft()
    demo.launch(theme=soft_theme, css=CSS)