ronalhung committed on
Commit
50709d2
·
verified ·
1 Parent(s): a5f75a3

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import time

import numpy as np
from PIL import Image
import tensorflow as tf
import gradio as gr

# ─── Constants ────────────────────────────────────────────────────────────────
# Resolve all paths relative to this file so the app works from any CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")
# Input resolution fed to every model (width, height) — see preprocess().
IMG_SIZE = (224, 224)

# (display name, .tflite filename) pairs; the files live under MODELS_DIR.
MODEL_CONFIGS = [
    ("MobileNetV3 Small", "mobilenetv3_small.tflite"),
    ("MobileNetV3 Large", "mobilenetv3_large.tflite"),
    ("EfficientNetV2-B0", "efficientnetv2b0.tflite"),
]

# Rank emojis for the top-3 predictions on each result card.
MEDAL = ["🥇", "🥈", "🥉"]
# ─── Load labels ──────────────────────────────────────────────────────────────
# labels.txt holds one class name per line; blank lines are skipped.
# Explicit UTF-8 avoids locale-dependent decoding of non-ASCII label names.
with open(os.path.join(MODELS_DIR, "labels.txt"), encoding="utf-8") as f:
    CLASS_NAMES = [line.strip() for line in f if line.strip()]
# ─── Load TFLite interpreters at startup ──────────────────────────────────────
# Pair each human-readable model name with a ready-to-run interpreter.
# Tensors are allocated once here so per-request inference stays cheap.
INTERPRETERS: list[tuple[str, tf.lite.Interpreter]] = []
for label, tflite_file in MODEL_CONFIGS:
    model_path = os.path.join(MODELS_DIR, tflite_file)
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    INTERPRETERS.append((label, interpreter))

print(f"Loaded {len(INTERPRETERS)} TFLite models.")
35
+
# ─── Inference helpers ────────────────────────────────────────────────────────
def preprocess(pil_image: Image.Image) -> np.ndarray:
    """Convert a PIL image to a float32 batch of shape (1, 224, 224, 3).

    Pixel values are left in [0, 255] — any rescaling is expected to be
    baked into the models themselves.
    """
    resized = pil_image.convert("RGB").resize(IMG_SIZE, Image.BILINEAR)
    # Add the leading batch dimension: (224, 224, 3) -> (1, 224, 224, 3).
    return np.array(resized, dtype=np.float32)[np.newaxis, ...]
def run_model(interp: tf.lite.Interpreter, img_array: np.ndarray) -> tuple[np.ndarray, float]:
    """Run one TFLite inference; return (class probabilities, latency in ms).

    Only interp.invoke() is timed — tensor set/get is excluded from latency.
    """
    in_detail = interp.get_input_details()[0]
    out_detail = interp.get_output_details()[0]

    interp.set_tensor(in_detail["index"], img_array)
    start = time.perf_counter()
    interp.invoke()
    elapsed_ms = 1000 * (time.perf_counter() - start)

    # Drop the batch dimension: output is (1, num_classes).
    return interp.get_tensor(out_detail["index"])[0], elapsed_ms
def fmt_model_card(display_name: str, probs: np.ndarray, latency_ms: float) -> str:
    """Render one model's latency and top-3 predictions as a Markdown card."""
    lines = [f"### {display_name}", f"⏱ `{latency_ms:.1f} ms`", ""]
    ranked = np.argsort(probs)[::-1]  # class indices, most probable first
    for medal, idx in zip(MEDAL, ranked):  # zip caps the loop at len(MEDAL)=3
        pct = probs[idx] * 100
        bar = "█" * int(pct / 5)  # one block per 5%, so at most 20 chars
        lines.append(f"{medal} **{CLASS_NAMES[idx]}** — {pct:.1f}%")
        lines.append(f"`{bar:<20}` ")
        lines.append("")
    return "\n".join(lines)
# ─── Main prediction function ─────────────────────────────────────────────────
def predict(pil_image):
    """Run every loaded model on the image.

    Returns three Markdown strings, one per model in MODEL_CONFIGS order.
    With no image (None), returns three identical placeholder strings.
    """
    if pil_image is None:
        placeholder = "_(upload an image and press Run Inference)_"
        return placeholder, placeholder, placeholder

    batch = preprocess(pil_image)
    # run_model returns (probs, latency_ms); unpack straight into the card.
    cards = [
        fmt_model_card(name, *run_model(interp, batch))
        for name, interp in INTERPRETERS
    ]
    return cards[0], cards[1], cards[2]
# ─── Gradio UI ────────────────────────────────────────────────────────────────
CSS = """
.model-card { background: var(--block-background-fill); border-radius: 12px; padding: 16px; }
footer { display: none !important; }
"""

# Placeholder card text shown before inference and after Clear, one per model.
_WAITING = (
    "### MobileNetV3 Small\n\n_(waiting)_",
    "### MobileNetV3 Large\n\n_(waiting)_",
    "### EfficientNetV2-B0\n\n_(waiting)_",
)

# FIX: theme= and css= are gr.Blocks() constructor arguments, not launch()
# arguments — gr.Blocks.launch() does not accept them, so the Soft theme and
# the .model-card CSS were never applied (and current Gradio raises TypeError).
with gr.Blocks(title="Animal Toy Classifier", theme=gr.themes.Soft(), css=CSS) as demo:
    gr.Markdown("# 🦁 Animal Toy Classifier\nUpload a photo of an animal toy — three models will predict the class and vote on the best answer.")

    with gr.Row():
        # ── Left: upload + button ──────────────────────────────────────────────
        with gr.Column(scale=1, min_width=260):
            img_in = gr.Image(
                type="pil",
                sources=["upload"],
                label="Input Image",
                height=280,
            )
            with gr.Row():
                run_btn = gr.Button("▶ Run Inference", variant="primary", scale=3)
                # ClearButton wipes the image itself; outputs are reset below.
                clear_btn = gr.ClearButton(
                    components=[img_in],
                    value="✕ Clear",
                    scale=1,
                )

        # ── Right: per-model result cards ──────────────────────────────────────
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column(elem_classes="model-card"):
                    out_small = gr.Markdown(_WAITING[0])
                with gr.Column(elem_classes="model-card"):
                    out_large = gr.Markdown(_WAITING[1])
                with gr.Column(elem_classes="model-card"):
                    out_eff = gr.Markdown(_WAITING[2])

    # ── Wire up ────────────────────────────────────────────────────────────────
    run_btn.click(
        fn=predict,
        inputs=[img_in],
        outputs=[out_small, out_large, out_eff],
    )

    # Reset the three output cards to their placeholders when Clear is pressed.
    clear_btn.click(
        fn=lambda: _WAITING,
        inputs=[],
        outputs=[out_small, out_large, out_eff],
    )


if __name__ == "__main__":
    demo.launch()