khhamid committed on
Commit
3e0639f
·
verified ·
1 Parent(s): 4d490e5

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ from PIL import Image
4
+ from ai_edge_litert.interpreter import Interpreter
5
+ import time
6
+ import os
7
# Redirect Streamlit's on-disk cache to a writable location (useful on
# read-only container filesystems, e.g. Hugging Face Spaces).
# NOTE(review): streamlit is already imported above — confirm this env var
# is still honored when set after the import.
os.environ["STREAMLIT_CACHE_DIR"] = "/tmp/streamlit-cache"

# Page header and usage hint shown on every rerun.
st.title("🌿 MobileNet TFLite Image Classifier")
st.write("Upload an image to test your quantized MobileNet model.")
14
def load_labels(path):
    """Read class names, one per line, from *path*.

    Returns the stripped lines as a list, or ``None`` when the file is
    missing (the caller then falls back to numeric class ids).
    """
    try:
        with open(path, "r") as names_file:
            return [name.strip() for name in names_file]
    except FileNotFoundError:
        return None

labels = load_labels("class_names.txt")
23
def preprocess(image_array: np.ndarray) -> np.ndarray:
    """Scale pixel values to [-1, 1].

    Replicates keras.applications.mobilenet_v3.preprocess_input.
    """
    # Keep the exact operation order (cast, divide, subtract) so the
    # float32 results are bit-identical to the original implementation.
    scaled = image_array.astype(np.float32) / 127.5
    return scaled - 1.0
28
+
29
@st.cache_resource
def load_tflite_model():
    """Load and prepare the TFLite interpreter once per process.

    ``@st.cache_resource`` keeps a single interpreter instance alive
    across Streamlit reruns, so the model file is parsed only once.

    Returns:
        The ai_edge_litert Interpreter with tensors already allocated.
    """
    interpreter = Interpreter(model_path="models/mobilenet_int8.tflite")
    interpreter.allocate_tensors()
    return interpreter
34
# --- Model wiring + interactive prediction flow (top-level script) ---
interpreter = load_tflite_model()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


uploaded_file = st.file_uploader("📸 Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", width="stretch")

    if st.button("🔍 Predict"):
        with st.spinner("Analyzing image..."):
            # MobileNet expects a 224x224 RGB input.
            img = np.array(image.resize((224, 224)))
            input_data = preprocess(img)
            # BUG FIX: the original expanded the RAW uint8 array
            # (np.expand_dims(img, ...)), silently discarding the
            # preprocess() result. Expand the scaled data instead.
            input_data = np.expand_dims(input_data, axis=0).astype(np.float32)
            # NOTE(review): the model file is int8-quantized; this assumes
            # the export kept float32 input/output tensors — confirm
            # against input_details[0]['dtype'].

            start = time.time()
            interpreter.set_tensor(input_details[0]['index'], input_data)
            interpreter.invoke()
            preds = interpreter.get_tensor(output_details[0]['index'])[0]
            inference_time = (time.time() - start) * 1000  # milliseconds

            # Top-3 classes, highest score first.
            top_k = preds.argsort()[-3:][::-1]
            st.markdown("### 🌱 Predictions:")
            for i in top_k:
                label = labels[i] if labels else f"Class {i}"
                st.write(f"**{label}** — {preds[i] * 100:.2f}%")
                # Stop listing once a fully-confident class is shown.
                # NOTE(review): exact float equality — this only fires when
                # the score is exactly 1.0; verify that is intended.
                if preds[i] == 1:
                    break
            st.info(f"⚡ Inference Time: {inference_time:.2f} ms")