Upload app.py with huggingface_hub
Browse files
app.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import tensorflow as tf
|
| 6 |
+
import gradio as gr
|
| 7 |
+
|
| 8 |
+
# βββ Constants ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 9 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 10 |
+
MODELS_DIR = os.path.join(BASE_DIR, "models")
|
| 11 |
+
IMG_SIZE = (224, 224)
|
| 12 |
+
|
| 13 |
+
MODEL_CONFIGS = [
|
| 14 |
+
("MobileNetV3 Small", "mobilenetv3_small.tflite"),
|
| 15 |
+
("MobileNetV3 Large", "mobilenetv3_large.tflite"),
|
| 16 |
+
("EfficientNetV2-B0", "efficientnetv2b0.tflite"),
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
MEDAL = ["π₯", "π₯", "π₯"]
|
| 20 |
+
|
| 21 |
+
# βββ Load labels ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
with open(os.path.join(MODELS_DIR, "labels.txt")) as f:
|
| 23 |
+
CLASS_NAMES = [line.strip() for line in f if line.strip()]
|
| 24 |
+
|
| 25 |
+
# βββ Load TFLite interpreters at startup ββββββββββββββββββββββββββββββββββββββ
|
| 26 |
+
INTERPRETERS: list[tuple[str, tf.lite.Interpreter]] = []
|
| 27 |
+
for display_name, filename in MODEL_CONFIGS:
|
| 28 |
+
path = os.path.join(MODELS_DIR, filename)
|
| 29 |
+
interp = tf.lite.Interpreter(model_path=path)
|
| 30 |
+
interp.allocate_tensors()
|
| 31 |
+
INTERPRETERS.append((display_name, interp))
|
| 32 |
+
|
| 33 |
+
print(f"Loaded {len(INTERPRETERS)} TFLite models.")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# βββ Inference helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
+
def preprocess(pil_image: Image.Image) -> np.ndarray:
|
| 38 |
+
"""Resize, convert to float32 array, add batch dim. Values in [0, 255]."""
|
| 39 |
+
img = pil_image.convert("RGB").resize(IMG_SIZE, Image.BILINEAR)
|
| 40 |
+
arr = np.array(img, dtype=np.float32) # (224, 224, 3)
|
| 41 |
+
return np.expand_dims(arr, axis=0) # (1, 224, 224, 3)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def run_model(interp: tf.lite.Interpreter, img_array: np.ndarray) -> tuple[np.ndarray, float]:
|
| 45 |
+
"""Run a single TFLite inference. Returns (probs, latency_ms)."""
|
| 46 |
+
input_idx = interp.get_input_details()[0]["index"]
|
| 47 |
+
output_idx = interp.get_output_details()[0]["index"]
|
| 48 |
+
interp.set_tensor(input_idx, img_array)
|
| 49 |
+
t0 = time.perf_counter()
|
| 50 |
+
interp.invoke()
|
| 51 |
+
latency_ms = (time.perf_counter() - t0) * 1000
|
| 52 |
+
probs = interp.get_tensor(output_idx)[0] # shape (num_classes,)
|
| 53 |
+
return probs, latency_ms
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def fmt_model_card(display_name: str, probs: np.ndarray, latency_ms: float) -> str:
|
| 57 |
+
"""Build a Markdown string for one model's result card."""
|
| 58 |
+
top3_idx = np.argsort(probs)[::-1][:3]
|
| 59 |
+
lines = [f"### {display_name}", f"β± `{latency_ms:.1f} ms`", ""]
|
| 60 |
+
for rank, idx in enumerate(top3_idx):
|
| 61 |
+
pct = probs[idx] * 100
|
| 62 |
+
bar = "β" * int(pct / 5) # simple bar, max 20 chars
|
| 63 |
+
lines.append(f"{MEDAL[rank]} **{CLASS_NAMES[idx]}** β {pct:.1f}%")
|
| 64 |
+
lines.append(f"`{bar:<20}` ")
|
| 65 |
+
lines.append("")
|
| 66 |
+
return "\n".join(lines)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# βββ Main prediction function βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
+
def predict(pil_image):
|
| 71 |
+
if pil_image is None:
|
| 72 |
+
empty = "_(upload an image and press Run Inference)_"
|
| 73 |
+
return empty, empty, empty
|
| 74 |
+
|
| 75 |
+
img_array = preprocess(pil_image)
|
| 76 |
+
cards: list[str] = []
|
| 77 |
+
|
| 78 |
+
for display_name, interp in INTERPRETERS:
|
| 79 |
+
probs, latency_ms = run_model(interp, img_array)
|
| 80 |
+
cards.append(fmt_model_card(display_name, probs, latency_ms))
|
| 81 |
+
|
| 82 |
+
return cards[0], cards[1], cards[2]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# βββ Gradio UI ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 86 |
+
CSS = """
|
| 87 |
+
.model-card { background: var(--block-background-fill); border-radius: 12px; padding: 16px; }
|
| 88 |
+
footer { display: none !important; }
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
with gr.Blocks(title="Animal Toy Classifier") as demo:
|
| 92 |
+
gr.Markdown("# π¦ Animal Toy Classifier\nUpload a photo of an animal toy β three models will predict the class and vote on the best answer.")
|
| 93 |
+
|
| 94 |
+
with gr.Row():
|
| 95 |
+
# ββ Left: upload + button ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 96 |
+
with gr.Column(scale=1, min_width=260):
|
| 97 |
+
img_in = gr.Image(
|
| 98 |
+
type="pil",
|
| 99 |
+
sources=["upload"],
|
| 100 |
+
label="Input Image",
|
| 101 |
+
height=280,
|
| 102 |
+
)
|
| 103 |
+
with gr.Row():
|
| 104 |
+
run_btn = gr.Button("βΆ Run Inference", variant="primary", scale=3)
|
| 105 |
+
clear_btn = gr.ClearButton(
|
| 106 |
+
components=[img_in],
|
| 107 |
+
value="β Clear",
|
| 108 |
+
scale=1,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# ββ Right: per-model result cards ββββββββββββββββββββββββββββββββββββββ
|
| 112 |
+
with gr.Column(scale=2):
|
| 113 |
+
with gr.Row():
|
| 114 |
+
with gr.Column(elem_classes="model-card"):
|
| 115 |
+
out_small = gr.Markdown("### MobileNetV3 Small\n\n_(waiting)_")
|
| 116 |
+
with gr.Column(elem_classes="model-card"):
|
| 117 |
+
out_large = gr.Markdown("### MobileNetV3 Large\n\n_(waiting)_")
|
| 118 |
+
with gr.Column(elem_classes="model-card"):
|
| 119 |
+
out_eff = gr.Markdown("### EfficientNetV2-B0\n\n_(waiting)_")
|
| 120 |
+
|
| 121 |
+
# ββ Wire up ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 122 |
+
run_btn.click(
|
| 123 |
+
fn=predict,
|
| 124 |
+
inputs=[img_in],
|
| 125 |
+
outputs=[out_small, out_large, out_eff],
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Reset outputs when Clear is pressed
|
| 129 |
+
clear_btn.click(
|
| 130 |
+
fn=lambda: (
|
| 131 |
+
"### MobileNetV3 Small\n\n_(waiting)_",
|
| 132 |
+
"### MobileNetV3 Large\n\n_(waiting)_",
|
| 133 |
+
"### EfficientNetV2-B0\n\n_(waiting)_",
|
| 134 |
+
),
|
| 135 |
+
inputs=[],
|
| 136 |
+
outputs=[out_small, out_large, out_eff],
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
if __name__ == "__main__":
|
| 141 |
+
demo.launch(theme=gr.themes.Soft(), css=CSS)
|