File size: 6,554 Bytes
50709d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import time
import numpy as np
from PIL import Image
import tensorflow as tf
import gradio as gr

# ─── Constants ────────────────────────────────────────────────────────────────
# Resolve paths relative to this file so the app works from any working dir.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")
# Input resolution fed to every model (width, height).
IMG_SIZE = (224, 224)

# (display name, .tflite filename) pairs. Order here fixes both the load order
# of INTERPRETERS and the left-to-right order of the UI result cards.
MODEL_CONFIGS = [
    ("MobileNetV3 Small",  "mobilenetv3_small.tflite"),
    ("MobileNetV3 Large",  "mobilenetv3_large.tflite"),
    ("EfficientNetV2-B0",  "efficientnetv2b0.tflite"),
]

# Rank markers for the top-3 rows of each result card (index 0 = best).
MEDAL = ["πŸ₯‡", "πŸ₯ˆ", "πŸ₯‰"]

# ─── Load labels ──────────────────────────────────────────────────────────────
def _read_labels(path: str) -> list[str]:
    """Return one class name per non-blank line of the labels file."""
    with open(path) as fh:
        return [ln.strip() for ln in fh if ln.strip()]


CLASS_NAMES = _read_labels(os.path.join(MODELS_DIR, "labels.txt"))


# ─── Load TFLite interpreters at startup ──────────────────────────────────────
def _load_interpreter(filename: str) -> tf.lite.Interpreter:
    """Build and allocate a TFLite interpreter for a model file in MODELS_DIR."""
    interp = tf.lite.Interpreter(model_path=os.path.join(MODELS_DIR, filename))
    interp.allocate_tensors()
    return interp


# Loaded once at import time so every request reuses the same interpreters.
INTERPRETERS: list[tuple[str, tf.lite.Interpreter]] = [
    (display_name, _load_interpreter(filename))
    for display_name, filename in MODEL_CONFIGS
]

print(f"Loaded {len(INTERPRETERS)} TFLite models.")


# ─── Inference helpers ────────────────────────────────────────────────────────
def preprocess(pil_image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a float32 batch of one, shape (1, 224, 224, 3).

    The image is forced to RGB and bilinearly resized to IMG_SIZE. Pixel
    values are kept in their raw [0, 255] range — no normalisation here.
    """
    resized = pil_image.convert("RGB").resize(IMG_SIZE, Image.BILINEAR)
    # float32 conversion + leading batch axis in one step.
    return np.asarray(resized, dtype=np.float32)[np.newaxis, ...]


def run_model(interp: tf.lite.Interpreter, img_array: np.ndarray) -> tuple[np.ndarray, float]:
    """Run one TFLite inference. Returns (probs, latency_ms).

    Only the invoke() call itself is timed; tensor setup and readback are
    deliberately excluded from the reported latency.
    """
    in_index = interp.get_input_details()[0]["index"]
    out_index = interp.get_output_details()[0]["index"]
    interp.set_tensor(in_index, img_array)

    start = time.perf_counter()
    interp.invoke()
    elapsed_ms = 1000 * (time.perf_counter() - start)

    # Strip the batch axis: (1, num_classes) -> (num_classes,)
    return interp.get_tensor(out_index)[0], elapsed_ms


def fmt_model_card(display_name: str, probs: np.ndarray, latency_ms: float) -> str:
    """Render one model's top-3 predictions as a Markdown result card."""
    header = [f"### {display_name}", f"⏱ `{latency_ms:.1f} ms`", ""]
    rows: list[str] = []
    # Last three indices of the ascending argsort, reversed = top-3 best-first.
    for rank, cls in enumerate(np.argsort(probs)[-3:][::-1]):
        pct = 100 * probs[cls]
        blocks = "β–ˆ" * int(pct / 5)  # one block per 5% — at most 20 chars
        rows += [
            f"{MEDAL[rank]} **{CLASS_NAMES[cls]}** β€” {pct:.1f}%",
            f"`{blocks:<20}` ",
            "",
        ]
    return "\n".join(header + rows)


# ─── Main prediction function ─────────────────────────────────────────────────
def predict(pil_image):
    """Run every loaded model on the uploaded image.

    Returns three Markdown strings (one per model, in MODEL_CONFIGS order).
    When no image has been provided yet, all three slots get a placeholder.
    """
    if pil_image is None:
        placeholder = "_(upload an image and press Run Inference)_"
        return placeholder, placeholder, placeholder

    batch = preprocess(pil_image)
    cards = [
        fmt_model_card(display_name, *run_model(interp, batch))
        for display_name, interp in INTERPRETERS
    ]
    return cards[0], cards[1], cards[2]


# ─── Gradio UI ────────────────────────────────────────────────────────────────
# Custom page CSS: styles the three result columns (matched via
# elem_classes="model-card") and hides the default Gradio footer.
CSS = """

.model-card { background: var(--block-background-fill); border-radius: 12px; padding: 16px; }

footer { display: none !important; }

"""

# theme/css belong on the gr.Blocks() constructor — Blocks.launch() does not
# accept them, so passing them to launch() (as before) meant the Soft theme
# and custom CSS were never applied (or raised a TypeError, depending on the
# Gradio version).
with gr.Blocks(title="Animal Toy Classifier", theme=gr.themes.Soft(), css=CSS) as demo:
    gr.Markdown("# 🦁 Animal Toy Classifier\nUpload a photo of an animal toy β€” three models will predict the class and vote on the best answer.")

    with gr.Row():
        # ── Left: upload + button ──────────────────────────────────────────────
        with gr.Column(scale=1, min_width=260):
            img_in = gr.Image(
                type="pil",
                sources=["upload"],
                label="Input Image",
                height=280,
            )
            with gr.Row():
                run_btn   = gr.Button("β–Ά  Run Inference", variant="primary", scale=3)
                # ClearButton wipes the image component itself; the result
                # cards are reset separately by the click handler below.
                clear_btn = gr.ClearButton(
                    components=[img_in],
                    value="βœ• Clear",
                    scale=1,
                )

        # ── Right: per-model result cards ──────────────────────────────────────
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column(elem_classes="model-card"):
                    out_small = gr.Markdown("### MobileNetV3 Small\n\n_(waiting)_")
                with gr.Column(elem_classes="model-card"):
                    out_large = gr.Markdown("### MobileNetV3 Large\n\n_(waiting)_")
                with gr.Column(elem_classes="model-card"):
                    out_eff   = gr.Markdown("### EfficientNetV2-B0\n\n_(waiting)_")

    # ── Wire up ────────────────────────────────────────────────────────────────
    run_btn.click(
        fn=predict,
        inputs=[img_in],
        outputs=[out_small, out_large, out_eff],
    )

    # Reset outputs when Clear is pressed
    clear_btn.click(
        fn=lambda: (
            "### MobileNetV3 Small\n\n_(waiting)_",
            "### MobileNetV3 Large\n\n_(waiting)_",
            "### EfficientNetV2-B0\n\n_(waiting)_",
        ),
        inputs=[],
        outputs=[out_small, out_large, out_eff],
    )


if __name__ == "__main__":
    demo.launch()