CarolinaSMarques committed on
Commit
8e67d41
·
verified ·
1 Parent(s): 3f42174

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +434 -0
app.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Vertebrates Track Classifier (EfficientNet-B0)
4
+
5
+ - Input: one or more photographs
6
+ - Output: top-3 most probable classes + probabilities
7
+ - Classes:
8
+ Bear, Coyote, Deer, Fox, Turkey, Otter,
9
+ Squirrel, Raccoon, Sauropod, Theropod
10
+ """
11
+
12
+ import os
13
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # safe no-op in most environments
14
+
15
+ import tempfile
16
+ import html
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torchvision import models, transforms
21
+
22
+ from PIL import Image
23
+ import numpy as np
24
+ import pandas as pd
25
+ import gradio as gr
26
+
27
+ # =========================
28
+ # Config
29
+ # =========================
30
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Side length (pixels) every image is resized to before inference.
# Must equal the IMAGE_SIZE of the training script (typically 224 for
# EfficientNet-B0).
IMAGE_SIZE = 224

# NOTE(review): `criterion` is not referenced by any inference code in this
# file; it is kept only so the module's top-level names stay unchanged.
criterion = nn.CrossEntropyLoss()

# ----- Class names -----
# IMPORTANT: the order MUST match the class order used during training.
# If you used torchvision.datasets.ImageFolder, this is the alphabetical
# order of your training subfolders.
CLASS_NAMES = [
    "Bear",
    "Coyote",
    "Deer",
    "Fox",
    "Otter",
    "Raccoon",
    "Sauropod",
    "Squirrel",
    "Theropod",
    "Turkey",
]

NUM_CLASSES = len(CLASS_NAMES)

# ---- Checkpoint path (Hugging Face: relative 'checkpoints' folder) ----
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
CHECKPOINT_PATH = os.path.join(THIS_DIR, "checkpoints", "model_checkpoint_8.pth")
# Put your .pth file in: ./checkpoints/model_checkpoint_8.pth
# or change the filename above to match your checkpoint.

# =========================
# Preprocessing (matches training)
# =========================
# Resize to a square IMAGE_SIZE x IMAGE_SIZE input, convert to a tensor and
# normalise each channel with the ImageNet statistics.
INFER_TRANSFORM = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean
        std=[0.229, 0.224, 0.225],   # ImageNet std
    ),
])
71
+
72
+
73
+ # =========================
74
+ # Model definitions
75
+ # =========================
76
def create_efficientnet_b0(num_classes: int, pretrained: bool = True) -> nn.Module:
    """
    Build an EfficientNet-B0 whose classifier head emits ``num_classes`` logits.

    The stock classifier is replaced by ``Dropout(p=0.2)`` followed by
    ``Linear(in_features, num_classes)``, where ``in_features`` is read from
    the original classifier's linear layer.

    Args:
        num_classes: Number of output classes of the new head.
        pretrained: When True (the default, and the original behaviour) the
            backbone starts from the ImageNet weights shipped with
            torchvision, which may trigger a download. Pass False to skip
            that when the weights will be overwritten by a checkpoint anyway.

    Returns:
        The model with the replaced classifier head (not moved to any device,
        still in training mode).
    """
    try:
        # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
        model = models.efficientnet_b0(
            weights="IMAGENET1K_V1" if pretrained else None
        )
    except TypeError:
        # Older torchvision releases only understand the legacy flag.
        model = models.efficientnet_b0(pretrained=pretrained)

    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2),
        nn.Linear(in_features, num_classes),
    )
    return model
94
+
95
+
96
def _safe_torch_load(path: str):
    """
    Load a checkpoint onto the CPU across PyTorch versions.

    PyTorch 2.6+ defaults ``torch.load`` to ``weights_only=True``; this helper
    explicitly asks for full unpickling and falls back to the plain call on
    releases whose ``torch.load`` does not accept that keyword.

    SECURITY NOTE: ``weights_only=False`` unpickles arbitrary Python objects,
    so only load checkpoints that come from a trusted source.
    """
    load_kwargs = {"map_location": "cpu"}
    try:
        # Newer PyTorch: opt out of the weights-only default explicitly.
        return torch.load(path, weights_only=False, **load_kwargs)
    except TypeError:
        # Older PyTorch: the `weights_only` keyword does not exist.
        return torch.load(path, **load_kwargs)
106
+
107
+
108
def load_model(checkpoint_path: str) -> nn.Module:
    """
    Build the EfficientNet-B0 classifier and load its weights for inference.

    The checkpoint may be either a training checkpoint dict that stores the
    weights under ``"model_state_dict"`` or a bare state dict. Any optimizer
    state in the checkpoint is ignored: it is not needed for inference, and
    reading it unconditionally used to raise ``KeyError`` for checkpoints
    saved without it.

    Args:
        checkpoint_path: Path to the ``.pth`` checkpoint file.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found: {checkpoint_path}\n"
            "Make sure the .pth is in the 'checkpoints' folder."
        )

    model = create_efficientnet_b0(NUM_CLASSES)
    ckpt = _safe_torch_load(checkpoint_path)

    # Training checkpoints wrap the weights; plain exports are the weights.
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt

    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()  # disable dropout / use running batch-norm statistics
    return model
135
+
136
+
137
# Process-wide model cache; stays None until the first prediction needs it.
_MODEL = None


def get_model() -> nn.Module:
    """
    Return the shared model, loading it from CHECKPOINT_PATH on first use.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL
    _MODEL = load_model(CHECKPOINT_PATH)
    return _MODEL
146
+
147
+
148
+ # =========================
149
+ # Prediction helpers
150
+ # =========================
151
@torch.no_grad()
def predict_top3_from_pil(pil_img: Image.Image):
    """
    Classify one PIL image and return its three most probable classes.

    Returns:
        (top3_class_names, top3_probs): two lists ordered from most to least
        probable; the probabilities are Python floats taken from a softmax
        over the model's logits.
    """
    net = get_model()

    # Force 3 channels, apply the inference transform and add a batch axis.
    rgb = pil_img.convert("RGB")
    batch = INFER_TRANSFORM(rgb).unsqueeze(0).to(DEVICE)

    scores = net(batch)
    probs = torch.softmax(scores, dim=1)[0].cpu().numpy()

    # Negating the array makes argsort rank the highest probability first.
    best = np.argsort(-probs)[:3]

    top_classes = []
    top_probs = []
    for idx in best:
        top_classes.append(CLASS_NAMES[idx])
        top_probs.append(float(probs[idx]))

    return top_classes, top_probs
170
+
171
+
172
def df_to_html(df: pd.DataFrame) -> str:
    """
    Render the predictions DataFrame as an HTML table wrapped in a
    ``pred-table`` div.

    Headers and cell text are HTML-escaped. Floats are shown with three
    decimals; ``None`` and NaN become empty cells. An empty DataFrame yields
    a short placeholder paragraph instead of a table.
    """
    if df.empty:
        return "<p>No predictions to display yet.</p>"

    def _cell_text(value) -> str:
        # Text shown for a single cell, before HTML escaping.
        if value is None:
            return ""
        if isinstance(value, float):
            # Missing probabilities render blank; real ones are rounded.
            return "" if np.isnan(value) else f"{value:.3f}"
        return str(value)

    columns = df.columns.tolist()
    head = "".join(f"<th>{html.escape(str(name))}</th>" for name in columns)

    body_rows = [
        "<tr>"
        + "".join(
            f"<td>{html.escape(_cell_text(record[name]))}</td>"
            for name in columns
        )
        + "</tr>"
        for _, record in df.iterrows()
    ]

    return (
        "<div class='pred-table'>"
        "<table>"
        "<thead><tr>"
        + head
        + "</tr></thead>"
        "<tbody>"
        + "".join(body_rows)
        + "</tbody></table>"
        "</div>"
    )
212
+
213
+
214
def classify_batch(filepaths):
    """
    Gradio callback: classify every uploaded image and build the outputs.

    Args:
        filepaths: Paths of the uploaded files (may be None or empty).

    Returns:
        (html_table, status, csv_path): the predictions rendered as HTML, a
        status message, and the path of a CSV export (None when nothing was
        uploaded).

    An image that cannot be opened or classified does not abort the batch:
    its row carries ``"Error: ..."`` in ``top1_class`` and empty values in
    the other prediction columns.
    """
    cols = [
        "image_name",
        "top1_class", "top1_prob",
        "top2_class", "top2_prob",
        "top3_class", "top3_prob",
    ]

    if not filepaths:
        html_table = df_to_html(pd.DataFrame(columns=cols))
        return html_table, "Please upload at least one image.", None

    rows = []
    for path in filepaths:
        # Start from an all-None row so success and error rows share a shape.
        row = dict.fromkeys(cols)
        row["image_name"] = os.path.basename(str(path))
        try:
            # The context manager closes the file handle deterministically.
            with Image.open(path) as img:
                top_classes, top_probs = predict_top3_from_pil(img.convert("RGB"))
        except Exception as e:
            # Deliberately broad: one bad file must not fail the whole batch.
            row["top1_class"] = f"Error: {e}"
        else:
            ranked = zip(top_classes, top_probs)
            for rank, (name, prob) in enumerate(ranked, start=1):
                row[f"top{rank}_class"] = name
                row[f"top{rank}_prob"] = prob
        rows.append(row)

    # Explicit columns keep the CSV/table order independent of dict order.
    df = pd.DataFrame(rows, columns=cols)
    status = f"Processed {len(rows)} photograph(s)."

    # The CSV must outlive this call so Gradio can serve the download, hence
    # a fresh temp directory that is intentionally not removed here.
    tmpdir = tempfile.mkdtemp()
    csv_path = os.path.join(tmpdir, "predictions_vert_tracks.csv")
    df.to_csv(csv_path, index=False)

    html_table = df_to_html(df)
    return html_table, status, csv_path
266
+
267
+
268
+ # =========================
269
+ # Gradio UI (paleo + wildlife aesthetics)
270
+ # =========================
271
theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="amber",
    neutral_hue="gray",
)

# Custom stylesheet passed to gr.Blocks below.
_APP_CSS = """
.gradio-container {
    font-family: 'Georgia', 'Times New Roman', serif;
}

.app-wrapper {
    max-width: 1100px;
    margin: 0 auto;
    padding: 1.5rem 1rem 2rem 1rem;
}

.app-header {
    text-align: center;
    margin-bottom: 1.2rem;
}

.app-header h1 {
    font-size: 2.1rem;
    margin-bottom: 0.3rem;
}

.app-header h2 {
    font-size: 1.1rem;
    font-weight: normal;
    opacity: 0.9;
}

.app-panel {
    background: rgba(255, 255, 255, 0.85);
    border-radius: 14px;
    padding: 1.2rem 1.5rem;
    margin-bottom: 1rem;
    border: 1px solid rgba(120, 82, 45, 0.18);
}

/* === predictions table wrapper === */
.pred-table {
    width: 100%;
    overflow-x: auto; /* horizontal scrollbar if needed */
}

/* Styled table for predictions */
.pred-table table {
    width: 100%;
    min-width: 650px;
    border-collapse: collapse;
    margin-top: 0.5rem;
    font-size: 0.9rem;
}

.pred-table thead {
    background: #e0cfb3;
}

.pred-table th, .pred-table td {
    border: 1px solid #d0b897;
    padding: 0.4rem 0.6rem;
    text-align: center;
    color: #000000;
    white-space: nowrap;
}

.pred-table th {
    font-weight: 600;
}

.pred-table tbody tr:nth-child(even) {
    background: #f7eee2;
}

.pred-table tbody tr:nth-child(odd) {
    background: #fbf4ea;
}

/* first column (image name) left-aligned */
.pred-table td:first-child {
    text-align: left;
}
"""

# Title, subtitle and credits shown at the top of the page.
_HEADER_HTML = """
<div class="app-header">
<h1>🐾 Vertebrate Tracks Classifier</h1>
<h2>Deep-learning assisted ichnological identifications with EfficientNet-B0</h2>
Model finetuned from a model trained on data obtained by the
<a href="https://zenodo.org/records/15092442" target="_blank">Deep Tracks</a>
App.<br>
Developed by <b>Carolina S. Marques</b>
(<a href="https://orcid.org/0000-0002-5936-9342" target="_blank">ORCID</a>)
as part of her PhD research, funded by CEAUL through FCT - Fundação para a Ciência e Tecnologia
(<a href="https://doi.org/10.54499/UI/BD/154258/2022" target="_blank">DOI</a>).
</div>

"""

with gr.Blocks(theme=theme, css=_APP_CSS) as demo:

    # NOTE(review): every gr.HTML call creates its own component, so the
    # opening/closing <div> pairs emitted through separate gr.HTML calls in
    # this layout cannot actually wrap the components placed between them.
    # They are kept as-is to leave the rendered page unchanged; attaching the
    # classes with `elem_classes` on a container would be the way to make the
    # wrappers real — confirm the intended look (incl. dark mode) first.
    gr.HTML("<div class='app-wrapper'>")

    # ----- Header -----
    gr.HTML(_HEADER_HTML)

    with gr.Row():
        # Integer scales 5:7 keep the original 1:1.4 width ratio; Gradio
        # expects `scale` to be an integer.
        with gr.Column(scale=5):
            gr.HTML("<div class='app-panel'>")
            gr.Markdown(
                # The bold span was previously left unclosed ("Raccoon<b>").
                " This model distinguishes between footprints of <b>Bear, Coyote, Deer, Fox, Turkey, Otter, Squirrel, Raccoon</b> as well as dinosaur tracks attributed to <b>Sauropod</b> and <b>Theropod</b> trackmakers.\n"
                "#### 1. Upload track photographs\n"
                "You can upload one or more photos of footprints from different vertebrates. "
                "The network will estimate, for each image, the probability of belonging to each of the ten classes:\n\n"
                "- Bear, Coyote, Deer, Fox, Turkey, Otter, Squirrel, Raccoon\n"
                "- Sauropod, Theropod (dinosaur tracks)\n"
            )

            img_files = gr.Files(
                label="Track photographs (you can select multiple files)",
                file_types=["image"],
                file_count="multiple",
                type="filepath",
            )
            classify_btn = gr.Button("Run classification", variant="primary")
            gr.HTML("</div>")

        with gr.Column(scale=7):
            gr.HTML("<div class='app-panel'>")
            gr.Markdown("#### Predicted classes and probabilities")
            results_html = gr.HTML(label="Top-3 predictions per image")
            gr.Markdown(
                "_How to read the table:_\n"
                "- **top1_class** / **top1_prob**: class with the highest predicted probability for that image, and the corresponding probability.\n"
                "- **top2_class** / **top2_prob**: second most probable class and the corresponding probability.\n"
                "- **top3_class** / **top3_prob**: third most probable class and the corresponding probability.\n"
                "- Probabilities are between 0 and 1 and, for each image, they sum to 1 across all ten classes."
            )
            gr.HTML("</div>")

    gr.HTML("<div class='app-panel'>")
    status_md = gr.Markdown()
    df_file = gr.File(
        label="Download full predictions as CSV",
        file_types=[".csv"],
    )
    gr.Markdown(
        "_Note_: The CSV export is plain text, ready to be used in R, Python, or Excel "
        "for further analysis (e.g., confusion matrices, ROC curves, etc.)."
    )
    gr.HTML("</div>")

    gr.HTML("</div>")  # close app-wrapper

    # Wire the button: uploaded files in, (table HTML, status, CSV path) out.
    classify_btn.click(
        fn=classify_batch,
        inputs=[img_files],
        outputs=[results_html, status_md, df_file],
    )

# For local dev / Hugging Face Spaces:
if __name__ == "__main__":
    demo.queue()
    demo.launch()