BGLab committed on
Commit
c231d6a
·
verified ·
1 Parent(s): 4a68f07

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +248 -0
app.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import streamlit as st
6
+ import torch
7
+ import timm
8
+ import pandas as pd
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ # =========================
14
+ # Config
15
+ # =========================
16
# Every setting below can be overridden through an environment variable of
# the same name; the second argument is the default used when it is unset.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "Shivani98/ViT-L_Insect_Classifier")
MODEL_FILE = os.getenv("MODEL_FILE", "vit_l_518.pth")
NUM_CLASSES = int(os.getenv("NUM_CLASSES", "3747"))
IMG_SIZE = int(os.getenv("IMG_SIZE", "518"))
CPU_THREADS = int(os.getenv("CPU_THREADS", "2"))
# An empty HF_TOKEN is treated the same as an unset one, so anonymous access
# is used instead of sending an empty credential.
HF_TOKEN = os.getenv("HF_TOKEN") or None

# Per-channel normalization statistics applied to input images.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Spreadsheet mapping class indices to names/taxonomy.
# expects: class_idx, Scientific Name, Common Name, Order, Family
# A relative path is resolved against the current working directory.
MAPPING_XLSX = Path(os.getenv("MAPPING_XLSX", "class_mapping_4k.xlsx"))
27
+
28
+ # =========================
29
+ # Streamlit basics
30
+ # =========================
31
# Page configuration is issued before any other Streamlit command.
st.set_page_config(page_title="ViT-L InsectNet Classifier", layout="centered")
st.title("🪲 InsectNet v2 — ViT-L Classifier")

# PyTorch rejects non-positive thread counts, so a misconfigured
# CPU_THREADS (e.g. "0") is clamped to 1 instead of crashing at startup.
torch.set_num_threads(max(1, CPU_THREADS))
# The app only runs inference; turn autograd off globally.
torch.set_grad_enabled(False)
36
+
37
+ # =========================
38
+ # Cached: Load model + preprocess
39
+ # =========================
40
@st.cache_resource
def load_model_and_preprocess():
    """Build the classifier and its input transform (cached by Streamlit).

    Downloads the checkpoint named by ``MODEL_REPO_ID``/``MODEL_FILE``,
    creates the timm ViT-L model with ``NUM_CLASSES`` outputs, loads the
    checkpoint weights, attempts dynamic int8 quantization of the linear
    layers, and runs one warm-up forward pass.

    Returns:
        tuple: ``(model, preprocess)`` where ``model`` is in eval mode and
        ``preprocess`` maps a PIL image to a normalized
        ``3 x IMG_SIZE x IMG_SIZE`` tensor.
    """
    st.caption("✨ App loaded from `app.py` (Streamlit)")

    # Download checkpoint (cached by HF)
    ckpt_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILE,
        token=HF_TOKEN,
        cache_dir=str(Path.home() / ".cache" / "huggingface"),
    )

    # Build model
    # NOTE(review): pretrained=True also fetches the base weights; they are
    # kept for any key the checkpoint does not provide (strict=False below),
    # so this is left unchanged.
    model = timm.create_model(
        "vit_large_patch14_reg4_dinov2.lvd142m",
        pretrained=True,
        num_classes=NUM_CLASSES,
    )

    # Load checkpoint
    # NOTE(review): torch.load unpickles the file; only point MODEL_REPO_ID
    # at repositories you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if isinstance(ckpt, dict):
        # Accept either a wrapped checkpoint or a bare state dict.
        state = ckpt.get("model", ckpt.get("state_dict", ckpt))
    else:
        state = ckpt

    # strict=False tolerates key mismatches, so surface them: a checkpoint
    # whose keys do not line up would otherwise leave weights untouched
    # without any visible sign.
    load_result = model.load_state_dict(state, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        st.warning(
            "Checkpoint does not match the model exactly: "
            f"{len(load_result.missing_keys)} missing and "
            f"{len(load_result.unexpected_keys)} unexpected keys. "
            "Predictions may be unreliable."
        )

    # CPU speedup: dynamic quantization (best effort; fall back to the
    # unquantized model, but report why instead of hiding the failure).
    try:
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    except Exception as exc:
        st.warning(f"Dynamic quantization failed; using the unquantized model ({exc}).")

    model.eval()

    preprocess = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    # Warmup
    with torch.inference_mode():
        _ = model(torch.zeros(1, 3, IMG_SIZE, IMG_SIZE))

    return model, preprocess


model, preprocess = load_model_and_preprocess()
86
+
87
+ # =========================
88
+ # Cached: Load mapping (xlsx)
89
+ # =========================
90
@st.cache_resource
def load_mapping_table(mapping_path: Path):
    """Load the class-index → taxonomy spreadsheet (cached by Streamlit).

    Expects columns (matched case-insensitively, ignoring surrounding
    whitespace):
      - class_idx
      - Scientific Name
      - Common Name
      - Order
      - Family

    Args:
        mapping_path: Location of the ``.xlsx`` mapping file.

    Returns:
        ``None`` when the file is absent, unreadable, or lacks a required
        column; otherwise a dict with ``"df"`` (the table indexed by
        ``class_idx``) and ``"cols"`` (the original column names for
        ``scientific``, ``common``, ``order`` and ``family``).
    """
    if not mapping_path.exists():
        return None

    # An unreadable workbook should degrade to index-only output like a
    # missing one, rather than take the whole app down.
    try:
        df = pd.read_excel(mapping_path)
    except Exception as exc:
        st.warning(f"Mapping file could not be read ({exc}). Will fall back to raw indices.")
        return None

    # Normalized header -> original header. str() guards against
    # non-string headers; setdefault keeps the first matching column.
    normalized = {}
    for col in df.columns:
        normalized.setdefault(str(col).lower().strip(), col)

    wanted = ("class_idx", "scientific name", "common name", "order", "family")
    required = {key: normalized.get(key) for key in wanted}

    missing = [key for key, col in required.items() if col is None]
    if missing:
        st.warning(f"Mapping file found but missing columns: {missing}. Will fall back to raw indices.")
        return None

    # Set index to class_idx for direct lookup
    df = df.set_index(required["class_idx"])
    # Keep one row per class so df.loc[idx] always yields a single record.
    df = df[~df.index.duplicated(keep="first")]
    return {
        "df": df,
        "cols": {
            "scientific": required["scientific name"],
            "common": required["common name"],
            "order": required["order"],
            "family": required["family"],
        },
    }


mapping_store = load_mapping_table(MAPPING_XLSX)
139
+
140
+ # =========================
141
+ # Prediction util
142
+ # =========================
143
@torch.inference_mode()
def predict_indices(img: Image.Image, topk: int = 5):
    """Classify one image and return its most probable class indices.

    Args:
        img: Input image; it is passed through the module-level
            ``preprocess`` transform and the module-level ``model``.
        topk: How many of the highest-probability classes to return.
            Values are clamped to the range ``1..number of model outputs``.

    Returns:
        tuple: ``(top1_idx, top1_prob, idxs, probs)`` where ``idxs`` and
        ``probs`` are lists ordered from most to least probable and the
        first two items repeat their first entries.
    """
    batch = preprocess(img).unsqueeze(0)
    probs = torch.softmax(model(batch), dim=1).squeeze(0)

    # Clamp against the model's real output width rather than a config
    # constant, and never ask for fewer than one result (the top-1 values
    # below need at least one entry).
    k = max(1, min(topk, probs.numel()))
    best_probs, best_idx = torch.topk(probs, k=k)

    idx_list = [int(i) for i in best_idx.tolist()]
    prob_list = [float(p) for p in best_probs.tolist()]

    return idx_list[0], prob_list[0], idx_list, prob_list
159
+
160
+ # =========================
161
+ # Helpers to format rows
162
+ # =========================
163
def fmt_top1(idx: int, p: float):
    """Render the top-1 prediction in the Streamlit page.

    Args:
        idx: Predicted class index.
        p: Probability assigned to that class.

    With no mapping table loaded, only the index and probability are shown
    together with a hint naming the expected mapping file. If the index is
    absent from the table, a warning and the confidence are shown instead.
    """
    if mapping_store is None:
        # Name the file the app actually looks for.
        st.info(
            f"Top-1 index: **{idx}** — p={p:.3f}\n\n"
            f"(Place `{MAPPING_XLSX.name}` next to this script to show names/taxonomy.)"
        )
        return

    df = mapping_store["df"]
    cols = mapping_store["cols"]

    if idx not in df.index:
        st.warning(f"Top-1 index {idx} not found in mapping; showing raw index only.")
        st.write(f"Confidence: `{p:.3f}`")
        return

    row = df.loc[idx]
    sci = row[cols["scientific"]]
    com = row[cols["common"]]
    odr = row[cols["order"]]
    fam = row[cols["family"]]

    # No index displayed here by design
    st.subheader("🦋 Top-1 Prediction")
    fields = [
        f"**Scientific Name:** *{sci}*",
        f"**Common Name:** {com}",
        f"**Order:** {odr}",
        f"**Family:** {fam}",
        f"**Confidence:** `{p:.3f}`",
    ]
    # Two trailing spaces + newline is an explicit Markdown hard line break,
    # so each field is rendered on its own line.
    st.markdown("  \n".join(fields))
193
+
194
def fmt_top5(idxs, ps):
    """Render the ranked prediction list as Markdown bullets.

    Args:
        idxs: Class indices, most probable first.
        ps: Probabilities aligned with ``idxs``.

    Each bullet shows the scientific and common name when the index is in
    the mapping table, and the raw index otherwise.
    """
    st.markdown("### 🌿 Top-5 Predictions")

    # No mapping table at all: plain index bullets.
    if mapping_store is None:
        for class_idx, prob in zip(idxs, ps):
            st.write(f"- Index **{class_idx}** — p={prob:.3f}")
        return

    table = mapping_store["df"]
    sci_col = mapping_store["cols"]["scientific"]
    com_col = mapping_store["cols"]["common"]

    for class_idx, prob in zip(idxs, ps):
        if class_idx not in table.index:
            st.markdown(f"- Index **{class_idx}** — `{prob:.3f}`")
            continue
        record = table.loc[class_idx]
        # Only scientific + common for top-5
        st.markdown(f"- **{record[sci_col]}** (*{record[com_col]}*) — `{prob:.3f}`")
213
+
214
+ # =========================
215
+ # UI
216
+ # =========================
217
with st.sidebar:
    st.header("Settings")
    st.caption("Model: ViT-L DINOv2 head · Image size: {}".format(IMG_SIZE))
    if mapping_store is None:
        # Report the file name the app actually looks for.
        st.warning(f"No `{MAPPING_XLSX.name}` found. Top-1/Top-5 will show indices only.")

uploaded = st.file_uploader("Upload a JPG/PNG", type=["jpg", "jpeg", "png"])
if uploaded is not None:
    try:
        # Force 3-channel input regardless of the uploaded file's mode.
        img = Image.open(uploaded).convert("RGB")
    except Exception as e:
        st.error(f"Failed to read image: {e}")
        st.stop()

    st.image(img, caption="Input", use_container_width=True)

    with st.spinner("Predicting…"):
        top1_idx, top1_prob, top5_idx, top5_prob = predict_indices(img, topk=5)

    # Render: Top-1 (all attributes, no index), then Top-5 (name + common only)
    fmt_top1(top1_idx, top1_prob)
    fmt_top5(top5_idx, top5_prob)

    with st.expander("Advanced • Raw indices & probabilities"):
        st.write(f"Top-1 index: **{top1_idx}** — p={top1_prob:.4f}")
        for i, p in zip(top5_idx, top5_prob):
            st.write(f"- {i} : {p:.4f}")
else:
    st.info("Upload an image to see predictions.")

st.caption(
    f"Tip: place `{MAPPING_XLSX.name}` next to this script with columns: "
    "`class_idx, Scientific Name, Common Name, Order, Family`."
)