import os
import time
import numpy as np
from PIL import Image
import tensorflow as tf
import gradio as gr
# ─── Constants ────────────────────────────────────────────────────────────────
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")
IMG_SIZE = (224, 224)  # input resolution shared by all three models
# (display name, model filename) pairs; order defines the UI card order.
MODEL_CONFIGS = [
    ("MobileNetV3 Small", "mobilenetv3_small.tflite"),
    ("MobileNetV3 Large", "mobilenetv3_large.tflite"),
    ("EfficientNetV2-B0", "efficientnetv2b0.tflite"),
]
# Rank markers for the top-3 predictions. The original had these mojibake'd
# into three identical strings, which made ranks indistinguishable; restored
# to the obviously intended distinct medals (indexed by rank in fmt_model_card).
MEDAL = ["🥇", "🥈", "🥉"]
# ─── Load labels ──────────────────────────────────────────────────────────────
# One class name per non-blank line. Explicit UTF-8 so the label file decodes
# identically on every platform (the default encoding differs on Windows).
with open(os.path.join(MODELS_DIR, "labels.txt"), encoding="utf-8") as f:
    CLASS_NAMES = [line.strip() for line in f if line.strip()]
# ─── Load TFLite interpreters at startup ──────────────────────────────────────
# Eagerly build and tensor-allocate one interpreter per configured model so the
# first user request pays no load cost.
INTERPRETERS: list[tuple[str, tf.lite.Interpreter]] = []
for _label, _fname in MODEL_CONFIGS:
    _interp = tf.lite.Interpreter(model_path=os.path.join(MODELS_DIR, _fname))
    _interp.allocate_tensors()
    INTERPRETERS.append((_label, _interp))
print(f"Loaded {len(INTERPRETERS)} TFLite models.")
# ─── Inference helpers ────────────────────────────────────────────────────────
def preprocess(pil_image: Image.Image) -> np.ndarray:
    """Resize to IMG_SIZE, cast to float32, and add a leading batch axis.

    Pixel values are left in [0, 255]; no rescaling happens here.
    """
    resized = pil_image.convert("RGB").resize(IMG_SIZE, Image.BILINEAR)
    batch = np.asarray(resized, dtype=np.float32)[np.newaxis, ...]
    return batch  # shape (1, 224, 224, 3)
def run_model(interp: tf.lite.Interpreter, img_array: np.ndarray) -> tuple[np.ndarray, float]:
    """Execute one inference pass; return (class probabilities, latency in ms).

    Only Interpreter.invoke() is timed — tensor setup and readback are excluded.
    """
    in_index = interp.get_input_details()[0]["index"]
    out_index = interp.get_output_details()[0]["index"]
    interp.set_tensor(in_index, img_array)
    start = time.perf_counter()
    interp.invoke()
    elapsed_ms = 1000.0 * (time.perf_counter() - start)
    return interp.get_tensor(out_index)[0], elapsed_ms
def fmt_model_card(display_name: str, probs: np.ndarray, latency_ms: float) -> str:
    """Build a Markdown result card for one model.

    Shows the invoke latency plus the top-3 classes, each with a medal,
    a percentage, and a bar of 0–20 block characters (one per 5 points).

    The timer/bar/dash glyphs were mojibake in the original source; they are
    restored here to the clearly intended characters (⏱, █, —).
    """
    top3_idx = np.argsort(probs)[::-1][:3]
    lines = [f"### {display_name}", f"⏱ `{latency_ms:.1f} ms`", ""]
    for rank, idx in enumerate(top3_idx):
        pct = probs[idx] * 100
        bar = "█" * int(pct / 5)  # proportional bar, max 20 chars at 100%
        lines.append(f"{MEDAL[rank]} **{CLASS_NAMES[idx]}** — {pct:.1f}%")
        lines.append(f"`{bar:<20}` ")
        lines.append("")
    return "\n".join(lines)
# ─── Main prediction function ─────────────────────────────────────────────────
def predict(pil_image):
    """Run every loaded model on *pil_image*; return one Markdown card each.

    With no image, every card gets the same upload-prompt placeholder.
    """
    if pil_image is None:
        placeholder = "_(upload an image and press Run Inference)_"
        return placeholder, placeholder, placeholder
    batch = preprocess(pil_image)
    cards = [
        fmt_model_card(name, *run_model(interp, batch))
        for name, interp in INTERPRETERS
    ]
    return cards[0], cards[1], cards[2]
# ─── Gradio UI ────────────────────────────────────────────────────────────────
CSS = """
.model-card { background: var(--block-background-fill); border-radius: 12px; padding: 16px; }
footer { display: none !important; }
"""
# NOTE(review): several emoji in the original UI strings were mojibake; the
# glyphs used below (🧸 / ▶ / ✖) are best-guess restorations — confirm intent.
# Bug fix: theme/css are gr.Blocks(...) constructor arguments; Blocks.launch()
# does not accept them, so passing them there raised a TypeError.
with gr.Blocks(
    title="Animal Toy Classifier",
    theme=gr.themes.Soft(),
    css=CSS,
) as demo:
    gr.Markdown(
        "# 🧸 Animal Toy Classifier\nUpload a photo of an animal toy — three "
        "models will predict the class and vote on the best answer."
    )
    with gr.Row():
        # ── Left: upload + run/clear controls ──
        with gr.Column(scale=1, min_width=260):
            img_in = gr.Image(
                type="pil",
                sources=["upload"],
                label="Input Image",
                height=280,
            )
            with gr.Row():
                run_btn = gr.Button("▶ Run Inference", variant="primary", scale=3)
                clear_btn = gr.ClearButton(
                    components=[img_in],
                    value="✖ Clear",
                    scale=1,
                )
        # ── Right: per-model result cards ──
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column(elem_classes="model-card"):
                    out_small = gr.Markdown("### MobileNetV3 Small\n\n_(waiting)_")
                with gr.Column(elem_classes="model-card"):
                    out_large = gr.Markdown("### MobileNetV3 Large\n\n_(waiting)_")
                with gr.Column(elem_classes="model-card"):
                    out_eff = gr.Markdown("### EfficientNetV2-B0\n\n_(waiting)_")
    # ── Wire up events ──
    run_btn.click(
        fn=predict,
        inputs=[img_in],
        outputs=[out_small, out_large, out_eff],
    )
    # Reset the three result cards to their initial placeholder text on Clear.
    clear_btn.click(
        fn=lambda: (
            "### MobileNetV3 Small\n\n_(waiting)_",
            "### MobileNetV3 Large\n\n_(waiting)_",
            "### EfficientNetV2-B0\n\n_(waiting)_",
        ),
        inputs=[],
        outputs=[out_small, out_large, out_eff],
    )
if __name__ == "__main__":
    demo.launch()