AndrewKof commited on
Commit
c34dda4
·
1 Parent(s): 361c20d

🚀 Update: new UI design, classification plot, and attention improvements

Browse files
app/Inference.py CHANGED
@@ -1,29 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
 
2
  import torch
3
- from transformers import AutoProcessor, Dinov2ForImageClassification
 
4
  from PIL import Image
5
- from torch.nn.functional import softmax
6
-
7
- # --- Load mapping ---
8
- with open("id2name.json", "r") as f:
9
- id2name = json.load(f)
10
-
11
- # --- Load model ---
12
- model_name = "Arew99/dinov2-costum"
13
- processor = AutoProcessor.from_pretrained(model_name)
14
- model = Dinov2ForImageClassification.from_pretrained(model_name)
15
- model.eval()
16
-
17
- # --- Load image (example) ---
18
- image = Image.open("sample_fish.jpg").convert("RGB")
19
- inputs = processor(images=image, return_tensors="pt")
20
-
21
- # --- Inference ---
22
- with torch.no_grad():
23
- logits = model(**inputs).logits.squeeze(0)
24
- probs, idxs = softmax(logits, dim=0).topk(5)
25
-
26
- print("\nTop-5 predictions:")
27
- for p, i in zip(probs.tolist(), idxs.tolist()):
28
- label = id2name[str(i)]
29
- print(f"{label:30s} {p*100:.2f}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # # app/inference.py
2
+ # import os
3
+ # import io
4
+ # import json
5
+ # import torch
6
+ # import torch.nn.functional as F
7
+ # from PIL import Image
8
+ # from transformers import AutoImageProcessor, AutoModelForImageClassification
9
+
10
+ # # ─────────────────────────────────────────────
11
+ # # CONFIG
12
+ # # ─────────────────────────────────────────────
13
+ # # Hugging Face repo for the trained checkpoint
14
+ # MODEL_REPO = "Arew99/dinov2-costum"
15
+
16
+ # # optional: path to local label map (bundled in your repo)
17
+ # ID2NAME_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")
18
+
19
+ # # detect device automatically
20
+ # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ # print(f"🧠 Using device: {DEVICE}")
22
+
23
+ # # global cache (so model loads only once)
24
+ # _model = None
25
+ # _processor = None
26
+ # _id2name = None
27
+
28
+
29
+ # # ─────────────────────────────────────────────
30
+ # # HELPER: load label map
31
+ # # ─────────────────────────────────────────────
32
+ # def _load_id2name():
33
+ # if os.path.exists(ID2NAME_PATH):
34
+ # with open(ID2NAME_PATH, "r") as f:
35
+ # data = json.load(f)
36
+ # # ensure integer keys
37
+ # return {int(k): v for k, v in data.items()}
38
+ # print("⚠️ id2name.json not found β€” using placeholder labels.")
39
+ # return {i: f"Class {i}" for i in range(101)} # fallback
40
+
41
+
42
+ # # ─────────────────────────────────────────────
43
+ # # INIT: load model & processor once
44
+ # # ─────────────────────────────────────────────
45
+ # def load_classification_model():
46
+ # global _model, _processor, _id2name
47
+
48
+ # if _model is not None:
49
+ # return _model, _processor, _id2name
50
+
51
+ # print(f"πŸ” Loading model from Hugging Face repo: {MODEL_REPO}")
52
+ # _processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
53
+ # _model = AutoModelForImageClassification.from_pretrained(
54
+ # MODEL_REPO,
55
+ # ignore_mismatched_sizes=True,
56
+ # ).to(DEVICE)
57
+ # _model.eval()
58
+ # _id2name = _load_id2name()
59
+
60
+ # print(f"βœ… Model loaded and ready on {DEVICE}")
61
+ # return _model, _processor, _id2name
62
+
63
+
64
+ # # ─────────────────────────────────────────────
65
+ # # INFERENCE: classify raw image bytes
66
+ # # ─────────────────────────────────────────────
67
+ # def classify_bytes(image_bytes: bytes):
68
+ # model, processor, id2name = load_classification_model()
69
+
70
+ # image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
71
+ # inputs = processor(images=image, return_tensors="pt").to(DEVICE)
72
+
73
+ # with torch.no_grad():
74
+ # outputs = model(**inputs)
75
+ # probs = F.softmax(outputs.logits, dim=-1)
76
+
77
+ # topk = torch.topk(probs, k=5)
78
+ # indices = topk.indices[0].tolist()
79
+ # values = topk.values[0].tolist()
80
+
81
+ # results = []
82
+ # for rank, (idx, prob) in enumerate(zip(indices, values), 1):
83
+ # label = id2name.get(int(idx), f"Class {idx}")
84
+ # results.append({
85
+ # "rank": rank,
86
+ # "id": int(idx),
87
+ # "label": label,
88
+ # "score": float(prob),
89
+ # })
90
+
91
+ # # concise summary for API
92
+ # return {
93
+ # "top1": results[0],
94
+ # "top5": results,
95
+ # }
96
+
97
+ # app/inference.py
98
+ import os
99
+ import io
100
  import json
101
+ import base64
102
  import torch
103
+ import torch.nn.functional as F
104
+ import matplotlib.pyplot as plt
105
  from PIL import Image
106
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
107
+
108
+ # ─────────────────────────────────────────────
109
+ # CONFIG
110
+ # ─────────────────────────────────────────────
111
+ MODEL_REPO = "Arew99/dinov2-costum" # your Hugging Face repo
112
+ ID2NAME_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")
113
+
114
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
115
+ print(f"🧠 Using device: {DEVICE}")
116
+
117
+ _model = None
118
+ _processor = None
119
+ _id2name = None
120
+
121
+
122
+ # ─────────────────────────────────────────────
123
+ # HELPER β€” load id2name mapping
124
+ # ─────────────────────────────────────────────
125
def _load_id2name():
    """Load the id -> class-name mapping from ID2NAME_PATH.

    Returns a dict with int keys. Falls back to generic placeholder
    labels when the JSON file is missing.
    """
    if not os.path.exists(ID2NAME_PATH):
        print("⚠️ id2name.json not found β€” using placeholder labels.")
        return {i: f"Class {i}" for i in range(101)}
    with open(ID2NAME_PATH, "r") as f:
        raw = json.load(f)
    # JSON object keys are strings; normalise them to integers.
    return {int(key): name for key, name in raw.items()}
132
+
133
+
134
+ # ─────────────────────────────────────────────
135
+ # LOAD MODEL (cached globally)
136
+ # ─────────────────────────────────────────────
137
def load_classification_model():
    """Return (model, processor, id2name), loading them on first use.

    The three objects are cached in module globals so the Hugging Face
    checkpoint is downloaded and initialised only once per process.
    """
    global _model, _processor, _id2name

    if _model is None:
        print(f"πŸ” Loading model from Hugging Face repo: {MODEL_REPO}")
        _processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
        classifier = AutoModelForImageClassification.from_pretrained(
            MODEL_REPO,
            ignore_mismatched_sizes=True,
        )
        _model = classifier.to(DEVICE)
        _model.eval()
        _id2name = _load_id2name()
        print(f"βœ… Model loaded and ready on {DEVICE}")

    return _model, _processor, _id2name
154
+
155
+
156
+ # ─────────────────────────────────────────────
157
+ # CLASSIFY IMAGE BYTES
158
+ # ─────────────────────────────────────────────
159
def classify_bytes(image_bytes: bytes):
    """Classify an image given as raw bytes.

    Returns a dict with:
      - "top1": best prediction {rank, id, label, score}
      - "top5": up to five predictions, best first (fewer when the
        model has fewer than five classes)
      - "plot": "data:image/png;base64,..." URI showing the input image
        and a top-3 probability bar chart
    """
    model, processor, id2name = load_classification_model()

    # Load and preprocess image
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)

    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1)

    # FIX: torch.topk raises if k exceeds the number of classes, so clamp
    # k instead of hard-coding 5.
    k = min(5, probs.shape[-1])
    topk = torch.topk(probs, k=k)
    indices = topk.indices[0].tolist()
    values = topk.values[0].tolist()

    results = []
    for rank, (idx, prob) in enumerate(zip(indices, values), 1):
        label = id2name.get(int(idx), f"Class {idx}")
        results.append({
            "rank": rank,
            "id": int(idx),
            "label": label,
            "score": float(prob),
        })

    # Plot at most the top 3 (fewer if fewer predictions exist).
    plot_b64 = _render_top_plot(image, results[:3])

    return {
        "top1": results[0],
        "top5": results,
        "plot": f"data:image/png;base64,{plot_b64}"
    }


def _render_top_plot(image, top_results):
    """Render the input image plus a horizontal probability bar chart.

    Returns the figure as a base64-encoded PNG string.
    """
    labels = [p["label"] for p in top_results]
    percentages = [p["score"] * 100 for p in top_results]

    # NOTE(review): pyplot keeps global state and is not thread-safe; fine
    # for a single-worker server, but switch to the object-oriented Figure
    # API if this endpoint is ever served concurrently.
    plt.style.use("seaborn-v0_8-whitegrid")
    fig, ax = plt.subplots(1, 2, figsize=(9, 4))

    # Left: input image
    ax[0].imshow(image)
    ax[0].axis("off")
    ax[0].set_title("Input Image", fontsize=12, weight="bold")

    # Right: horizontal bar chart (reversed so the best bar sits on top)
    bars = ax[1].barh(labels[::-1], percentages[::-1],
                      color=["#C44E52", "#55A868", "#4C72B0"][:len(labels)],
                      edgecolor="none", height=0.6)
    ax[1].set_xlim(0, 100)
    ax[1].set_xlabel("Probability (%)", fontsize=11)
    ax[1].set_title("Top-3 Predicted Species", fontsize=12, weight="bold")

    for bar, pct in zip(bars, percentages[::-1]):
        ax[1].text(pct + 1, bar.get_y() + bar.get_height() / 2,
                   f"{pct:.1f}%", va="center", fontsize=10, weight="bold")

    plt.tight_layout()

    # Encode plot as base64; close the figure to avoid leaking memory.
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    encoded = base64.b64encode(buf.read()).decode("utf-8")
    buf.close()
    return encoded
231
+
232
+
233
+ # ─────────────────────────────────────────────
234
+ # LOCAL TEST
235
+ # ─────────────────────────────────────────────
236
if __name__ == "__main__":
    # Quick local smoke test: classify one sample image from disk and
    # print the predictions plus the size of the rendered plot.
    sample_path = "sample3.jpg"
    with open(sample_path, "rb") as sample_file:
        payload = sample_file.read()
    prediction = classify_bytes(payload)
    print(json.dumps(prediction["top5"], indent=2))
    print("\nPlot base64 length:", len(prediction["plot"]))
app/MyInference.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Standalone local demo: classify one image with a local checkpoint and
plot the top-3 predictions next to the input image."""
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch.nn.functional as F
import json
import matplotlib.pyplot as plt

# ===== Load label names =====
# Keys in id2name.json are strings ("0", "1", ...); indexed via str(idx) below.
with open("id2name.json", "r") as f:
    id2name = json.load(f)

# ===== Paths =====
model_dir = "ckpt_merged_large"
image_path = "sample3.jpg"

# ===== Auto-detect device =====
# FIX: device was hard-coded to "cpu" even though this section promised
# auto-detection; pick CUDA when it is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# ===== Load processor & model =====
processor = AutoImageProcessor.from_pretrained(model_dir)
model = AutoModelForImageClassification.from_pretrained(model_dir)
model.to(device)
model.eval()

# ===== Preprocess image =====
image = Image.open(image_path).convert("RGB")
inputs = processor(images=image, return_tensors="pt").to(device)

# ===== Inference =====
with torch.no_grad():
    outputs = model(**inputs)
    probs = F.softmax(outputs.logits, dim=-1)

# ===== Top-5 predictions =====
# Clamp k so checkpoints with fewer than 5 classes don't crash torch.topk.
k = min(5, probs.shape[-1])
topk = torch.topk(probs, k=k)
indices = topk.indices[0].tolist()
values = topk.values[0].tolist()

print("\nTop-5 Predictions:")
for rank, (idx, prob) in enumerate(zip(indices, values), 1):
    label = id2name[str(idx)]
    print(f"{rank}. {label:<30} ({prob*100:.2f}%)")

# ===== Prepare Top-3 for plotting =====
n_plot = min(3, len(indices))
top3_labels = [id2name[str(indices[i])] for i in range(n_plot)]
top3_probs = [values[i] * 100 for i in range(n_plot)]

# ===== Styled plot =====
plt.style.use("seaborn-v0_8-whitegrid")
fig, ax = plt.subplots(1, 2, figsize=(10, 4))

# -- Left: input image --
ax[0].imshow(image)
ax[0].axis("off")
ax[0].set_title("Input Image", fontsize=13, weight="bold")

# -- Right: bar chart (reversed so the best bar sits on top) --
bars = ax[1].barh(top3_labels[::-1], top3_probs[::-1],
                  color=["#4C72B0", "#55A868", "#C44E52"][:n_plot],
                  edgecolor="none", height=0.6)

ax[1].set_xlim(0, 100)
ax[1].set_xlabel("Probability (%)", fontsize=12)
ax[1].set_title("Top-3 Predicted Species", fontsize=13, weight="bold")

# Add percentage labels next to bars
for bar, prob in zip(bars, top3_probs[::-1]):
    ax[1].text(prob + 1, bar.get_y() + bar.get_height() / 2,
               f"{prob:.1f}%", va="center", fontsize=11, weight="bold", color="#333")

plt.tight_layout()
plt.show()
app/__pycache__/inference.cpython-310.pyc ADDED
Binary file (4.16 kB). View file
 
app/main.py CHANGED
@@ -1,10 +1,11 @@
1
  # app/main.py
2
- import os
3
- from fastapi import FastAPI, File, UploadFile
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from fastapi.responses import HTMLResponse
6
  from fastapi.staticfiles import StaticFiles
 
7
  from app.model import load_model, predict_from_bytes
 
 
 
8
 
9
 
10
  # ──────────────────────────────────────────────
@@ -12,13 +13,13 @@ from app.model import load_model, predict_from_bytes
12
  # ──────────────────────────────────────────────
13
  app = FastAPI(title="NEMO Tools")
14
 
15
- app.add_middleware(
16
- CORSMiddleware,
17
- allow_origins=["*"],
18
- allow_credentials=True,
19
- allow_methods=["*"],
20
- allow_headers=["*"],
21
- )
22
 
23
  # ──────────────────────────────────────────────
24
  # Static Frontend
@@ -42,6 +43,17 @@ print("πŸš€ Loading DINOv2 custom model...")
42
  model_device_tuple = load_model()
43
  print("βœ… Model loaded and ready for inference!")
44
 
 
 
 
 
 
 
 
 
 
 
 
45
  # ──────────────────────────────────────────────
46
  # API Endpoints
47
  # ──────────────────────────────────────────────
@@ -52,6 +64,11 @@ async def generate_attention(file: UploadFile = File(...)):
52
  result = predict_from_bytes(model_device_tuple, image_bytes)
53
  return result
54
 
 
 
 
 
 
55
  @app.get("/api")
56
  def api_root():
57
  return {"message": "NEMO Tools backend running."}
 
1
  # app/main.py
2
+ from fastapi import FastAPI, UploadFile, File
 
 
 
3
  from fastapi.staticfiles import StaticFiles
4
+ from fastapi.responses import HTMLResponse, JSONResponse
5
  from app.model import load_model, predict_from_bytes
6
+ from app.inference import load_classification_model, classify_bytes
7
+ import json, os
8
+
9
 
10
 
11
  # ──────────────────────────────────────────────
 
13
  # ──────────────────────────────────────────────
14
  app = FastAPI(title="NEMO Tools")
15
 
16
+ # app.add_middleware(
17
+ # CORSMiddleware,
18
+ # allow_origins=["*"],
19
+ # allow_credentials=True,
20
+ # allow_methods=["*"],
21
+ # allow_headers=["*"],
22
+ # )
23
 
24
  # ──────────────────────────────────────────────
25
  # Static Frontend
 
43
  model_device_tuple = load_model()
44
  print("βœ… Model loaded and ready for inference!")
45
 
46
+ # warm-up on startup
47
+ load_classification_model()
48
+
49
+ # --- Load classification model & labels once at startup ---
50
+ MAP_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")
51
+ with open(MAP_PATH, "r") as f:
52
+ ID2NAME = json.load(f)
53
+
54
+ cls_model = load_model()
55
+ print("βœ… Classification model loaded and ready for inference!")
56
+
57
  # ──────────────────────────────────────────────
58
  # API Endpoints
59
  # ──────────────────────────────────────────────
 
64
  result = predict_from_bytes(model_device_tuple, image_bytes)
65
  return result
66
 
67
+ @app.post("/classify")
68
+ async def classify(file: UploadFile = File(...)):
69
+ image_bytes = await file.read()
70
+ return classify_bytes(image_bytes)
71
+
72
  @app.get("/api")
73
  def api_root():
74
  return {"message": "NEMO Tools backend running."}
app/model.py CHANGED
@@ -25,7 +25,8 @@ CKPT_PATH = hf_hub_download(
25
  )
26
 
27
  PATCH_SIZE = 14
28
- IMAGE_SIZE = (1000,1000)
 
29
 
30
 
31
  # -------------------------------------------------------
@@ -46,7 +47,6 @@ def load_model():
46
  # Load weights
47
  state_dict = load_file(CKPT_PATH)
48
  keys_list = list(state_dict.keys())
49
- print(f"Loaded {len(state_dict.keys())} weights from {CKPT_PATH}")
50
 
51
  # Handle "model." prefix if present
52
  if keys_list and "model." in keys_list[0]:
@@ -81,10 +81,6 @@ def preprocess_image(image_bytes):
81
  img = img[:, :w, :h].unsqueeze(0)
82
  return img, (w, h)
83
 
84
-
85
- # -------------------------------------------------------
86
- # Prediction logic (generate attention map)
87
- # -------------------------------------------------------
88
  def predict_from_bytes(model_device_tuple, image_bytes):
89
  model, device = model_device_tuple
90
  img, (w, h) = preprocess_image(image_bytes)
@@ -96,6 +92,7 @@ def predict_from_bytes(model_device_tuple, image_bytes):
96
  attentions = model.get_last_self_attention(x)
97
  nh = attentions.shape[1] # number of heads
98
 
 
99
  attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
100
  attentions = attentions.reshape(nh, w_featmap, h_featmap)
101
  attentions = nn.functional.interpolate(
@@ -104,22 +101,34 @@ def predict_from_bytes(model_device_tuple, image_bytes):
104
  mode="nearest"
105
  )[0].cpu().numpy()
106
 
107
- # Mean attention map
108
- mean_attention = np.mean(attentions, axis=0)
 
109
 
110
- # Normalize to [0,1] for visualization
111
- mean_attention_norm = (mean_attention - mean_attention.min()) / (
112
- mean_attention.max() - mean_attention.min() + 1e-8
113
- )
 
114
 
115
- # Apply colormap (viridis) and resize to match input image size
116
- heatmap = (cm.viridis(mean_attention_norm)[:, :, :3] * 255).astype(np.uint8)
117
- heatmap_img = Image.fromarray(heatmap).resize(IMAGE_SIZE)
 
 
 
 
 
 
 
 
118
 
119
- # Convert to base64
120
  buf = BytesIO()
121
- heatmap_img.save(buf, format="PNG")
122
  buf.seek(0)
123
- img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
124
 
125
- return {"attention_map": img_base64}
 
 
 
 
25
  )
26
 
27
  PATCH_SIZE = 14
28
+ IMAGE_SIZE = (800,800)
29
+
30
 
31
 
32
  # -------------------------------------------------------
 
47
  # Load weights
48
  state_dict = load_file(CKPT_PATH)
49
  keys_list = list(state_dict.keys())
 
50
 
51
  # Handle "model." prefix if present
52
  if keys_list and "model." in keys_list[0]:
 
81
  img = img[:, :w, :h].unsqueeze(0)
82
  return img, (w, h)
83
 
 
 
 
 
84
  def predict_from_bytes(model_device_tuple, image_bytes):
85
  model, device = model_device_tuple
86
  img, (w, h) = preprocess_image(image_bytes)
 
92
  attentions = model.get_last_self_attention(x)
93
  nh = attentions.shape[1] # number of heads
94
 
95
+ # Reshape attention maps
96
  attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
97
  attentions = attentions.reshape(nh, w_featmap, h_featmap)
98
  attentions = nn.functional.interpolate(
 
101
  mode="nearest"
102
  )[0].cpu().numpy()
103
 
104
+ # --- Normalize and visualize ---
105
+ all_heads_base64 = []
106
+ original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
107
 
108
+ for i in range(nh):
109
+ head_attn = attentions[i]
110
+ head_norm = (head_attn - head_attn.min()) / (head_attn.max() - head_attn.min() + 1e-8)
111
+ heatmap = (cm.viridis(head_norm)[:, :, :3] * 255).astype(np.uint8)
112
+ heatmap_img = Image.fromarray(heatmap).resize(original_image.size, Image.BILINEAR)
113
 
114
+ buf = BytesIO()
115
+ heatmap_img.save(buf, format="PNG")
116
+ buf.seek(0)
117
+ head_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
118
+ all_heads_base64.append(head_b64)
119
+
120
+ # --- Mean attention map ---
121
+ mean_attention = np.mean(attentions, axis=0)
122
+ mean_norm = (mean_attention - mean_attention.min()) / (mean_attention.max() - mean_attention.min() + 1e-8)
123
+ heatmap = (cm.viridis(mean_norm)[:, :, :3] * 255).astype(np.uint8)
124
+ mean_img = Image.fromarray(heatmap).resize(original_image.size, Image.BILINEAR)
125
 
 
126
  buf = BytesIO()
127
+ mean_img.save(buf, format="PNG")
128
  buf.seek(0)
129
+ mean_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
130
 
131
+ return {
132
+ "mean_attention_map": mean_b64,
133
+ "head_attention_maps": all_heads_base64,
134
+ }
app/static/Correctindex.html ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="utf-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
6
+ <title>NEMO Tools</title>
7
+
8
+ <!-- TailwindCSS CDN -->
9
+ <script src="https://cdn.tailwindcss.com"></script>
10
+ <style>
11
+ body {
12
+ background-image: url('static/background.jpg');
13
+ background-size: cover;
14
+ background-position: center;
15
+ background-attachment: fixed;
16
+ background-repeat: no-repeat;
17
+ color: #f9fafb;
18
+ }
19
+ body::before {
20
+ content: "";
21
+ position: fixed;
22
+ top: 0;
23
+ left: 0;
24
+ width: 100%;
25
+ height: 100%;
26
+ background: rgba(0, 10, 20, 0.3);
27
+ z-index: -1;
28
+ }
29
+ .card {
30
+ background: rgba(255, 255, 255, 0.12);
31
+ backdrop-filter: blur(10px);
32
+ border-radius: 20px;
33
+ box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2);
34
+ color: #f1f1f1;
35
+ }
36
+ .nav-link {
37
+ color: #e0e0e0;
38
+ }
39
+ .nav-link.active {
40
+ color: #60a5fa;
41
+ border-bottom: 2px solid #60a5fa;
42
+ }
43
+ </style>
44
+ </head>
45
+
46
+ <body class="bg-gray-100 min-h-screen">
47
+ <header class="bg-white shadow-sm">
48
+ <div class="max-w-6xl mx-auto px-4 py-4 flex items-center justify-between">
49
+ <div class="flex items-center gap-3">
50
+ <img src="/static/assets/logo.png" alt="NEMO logo" class="h-10 w-10 rounded-full shadow-sm" />
51
+ <div>
52
+ <h1 class="text-lg font-bold text-indigo-600">NEMO tools</h1>
53
+ <p class="text-xs text-gray-400">DINOv2 visualisation sandbox</p>
54
+ </div>
55
+ </div>
56
+
57
+ <nav class="flex gap-3">
58
+ <button id="tab-research" class="tab-btn text-gray-500 hover:text-indigo-600 px-3 py-1 rounded-md text-sm font-medium" onclick="showTab('research')">
59
+ Research
60
+ </button>
61
+ <button id="tab-people" class="tab-btn text-gray-500 hover:text-indigo-600 px-3 py-1 rounded-md text-sm font-medium" onclick="showTab('people')">
62
+ People
63
+ </button>
64
+ <button id="tab-tools" class="tab-btn text-indigo-600 bg-indigo-50 px-3 py-1 rounded-md text-sm font-medium" onclick="showTab('tools')">
65
+ Tools
66
+ </button>
67
+ </nav>
68
+ </div>
69
+ </header>
70
+
71
+ <main class="max-w-6xl mx-auto px-4 py-8">
72
+ <section id="page-research" class="hidden">
73
+ <h2 class="text-2xl font-semibold text-gray-800 mb-4">Research</h2>
74
+ <p class="text-gray-500 mb-4">We can list publications, datasets, and experiment notes here later.</p>
75
+ <div class="bg-white rounded-xl shadow p-6 text-gray-400 text-sm">(placeholder) Add your papers, abstracts, or GitHub repos here.</div>
76
+ </section>
77
+
78
+ <section id="page-people" class="hidden">
79
+ <h2 class="text-2xl font-semibold text-gray-800 mb-4">People</h2>
80
+ <p class="text-gray-500 mb-4">We can add your name, collaborators, and links to profiles here.</p>
81
+ <div class="grid gap-4 md:grid-cols-3">
82
+ <div class="bg-white rounded-xl shadow p-5">
83
+ <h3 class="font-semibold text-gray-700">You</h3>
84
+ <p class="text-gray-400 text-sm">Project owner</p>
85
+ </div>
86
+ <div class="bg-white rounded-xl shadow p-5">
87
+ <h3 class="font-semibold text-gray-700">To add</h3>
88
+ <p class="text-gray-400 text-sm">Collaborators / advisors</p>
89
+ </div>
90
+ <div class="bg-white rounded-xl shadow p-5">
91
+ <h3 class="font-semibold text-gray-700">Contact</h3>
92
+ <p class="text-gray-400 text-sm">Add email / GitHub here</p>
93
+ </div>
94
+ </div>
95
+ </section>
96
+
97
+ <section id="page-tools">
98
+ <div class="bg-white shadow-lg rounded-2xl p-8 w-full">
99
+ <h2 class="text-2xl font-bold text-indigo-600 mb-6 flex items-center gap-2"><span>🧰 Tools</span></h2>
100
+ <div class="flex gap-3 mb-6 border-b pb-2">
101
+ <button id="sub-attention" class="subtab-btn text-indigo-600 font-medium border-b-2 border-indigo-600 pb-1" onclick="showSubTool('attention')">🧠 Attention Maps</button>
102
+ <button id="sub-classification" class="subtab-btn text-gray-500 hover:text-indigo-600 pb-1" onclick="showSubTool('classification')">πŸ” Run Classification</button>
103
+ </div>
104
+
105
+ <!-- Attention Tool -->
106
+ <div id="tool-attention">
107
+ <div class="flex flex-col items-center gap-4 mb-6 justify-center">
108
+ <input id="file" type="file" accept="image/*"
109
+ class="block w-full md:w-auto text-sm text-gray-600
110
+ file:mr-4 file:py-2 file:px-4
111
+ file:rounded-full file:border-0
112
+ file:text-sm file:font-semibold
113
+ file:bg-indigo-50 file:text-indigo-600
114
+ hover:file:bg-indigo-100"
115
+ onchange="onAttentionImageSelected()" />
116
+ </div>
117
+
118
+ <div id="attention-extra" class="hidden flex flex-col items-center gap-6">
119
+ <button id="runButton" onclick="runActiveTool()" class="px-8 py-3 bg-indigo-600 text-white text-lg font-semibold rounded-full shadow-md hover:bg-indigo-700 transition">
120
+ ▢️ Run Attention
121
+ </button>
122
+ <div class="flex flex-col md:flex-row justify-center items-start gap-6">
123
+ <div class="flex flex-col items-center">
124
+ <h4 class="text-gray-600 mb-2 font-medium">Original Image</h4>
125
+ <img id="original" class="rounded-lg shadow-md max-w-xs hidden" />
126
+ </div>
127
+ <div class="flex flex-col items-center">
128
+ <h4 class="text-gray-600 mb-2 font-medium">Mean Attention Map</h4>
129
+ <img id="output" class="rounded-lg shadow-md max-w-xs hidden" />
130
+ </div>
131
+ </div>
132
+ <div id="headsContainer" class="hidden mt-8">
133
+ <h4 class="text-gray-600 mb-3 font-medium text-center">All Attention Heads</h4>
134
+ <div id="headsGrid" class="flex flex-wrap justify-center gap-3"></div>
135
+ </div>
136
+ <p id="status" class="text-center text-gray-500 mt-2 text-sm"></p>
137
+ </div>
138
+ </div>
139
+
140
+ <!-- Classification Tool -->
141
+ <div id="tool-classification" class="hidden">
142
+ <div class="flex flex-col items-center gap-4 mb-6 justify-center">
143
+
144
+ <input id="cls-file" type="file" accept="image/*"
145
+ class="block w-full md:w-auto text-sm text-gray-600
146
+ file:mr-4 file:py-2 file:px-4
147
+ file:rounded-full file:border-0
148
+ file:text-sm file:font-semibold
149
+ file:bg-indigo-50 file:text-indigo-600
150
+ hover:file:bg-indigo-100" />
151
+ <button id="cls-run" onclick="runClassification()" style="display:none;"
152
+ class="px-8 py-3 bg-green-600 text-white text-lg font-semibold rounded-full shadow-md hover:bg-green-700 transition">
153
+ ▢️ Run Classification
154
+ </button>
155
+ </div>
156
+ <div id="cls-result" class="text-center text-gray-700 mt-4 text-lg font-medium"></div>
157
+ </div>
158
+ </div>
159
+ </section>
160
+ </main>
161
+
162
+ <script>
163
+ function showTab(name) {
164
+ const tabs = ["research", "people", "tools"];
165
+ tabs.forEach(t => {
166
+ document.getElementById("page-" + t).classList.add("hidden");
167
+ document.getElementById("tab-" + t).classList.remove("bg-indigo-50", "text-indigo-600");
168
+ document.getElementById("tab-" + t).classList.add("text-gray-500");
169
+ });
170
+ document.getElementById("page-" + name).classList.remove("hidden");
171
+ document.getElementById("tab-" + name).classList.add("bg-indigo-50", "text-indigo-600");
172
+ document.getElementById("tab-" + name).classList.remove("text-gray-500");
173
+ }
174
+
175
+ let activeTool = "attention";
176
+ const runButton = () => document.getElementById("runButton");
177
+
178
+ function showSubTool(name) {
179
+ const subs = ["attention", "classification"];
180
+ subs.forEach(s => {
181
+ document.getElementById("tool-" + s).classList.add("hidden");
182
+ document.getElementById("sub-" + s).classList.remove("text-indigo-600", "font-medium", "border-b-2", "border-indigo-600");
183
+ document.getElementById("sub-" + s).classList.add("text-gray-500");
184
+ });
185
+ document.getElementById("tool-" + name).classList.remove("hidden");
186
+ document.getElementById("sub-" + name).classList.add("text-indigo-600", "font-medium", "border-b-2", "border-indigo-600");
187
+ document.getElementById("sub-" + name).classList.remove("text-gray-500");
188
+ activeTool = name;
189
+ }
190
+
191
+ async function runActiveTool() {
192
+ const btn = runButton();
193
+ btn.disabled = true;
194
+ btn.textContent = "βŒ›οΈ Running...";
195
+ btn.classList.add("opacity-70", "cursor-not-allowed");
196
+ try {
197
+ if (activeTool === "attention") await sendAttention();
198
+ else await runClassification();
199
+ } catch (err) {
200
+ console.error(err);
201
+ alert("❌ Error while running model: " + err.message);
202
+ }
203
+ btn.disabled = false;
204
+ btn.classList.remove("opacity-70", "cursor-not-allowed");
205
+ btn.textContent = activeTool === "attention" ? "▢️ Run Attention" : "▢️ Run Classification";
206
+ }
207
+
208
+ function onAttentionImageSelected() {
209
+ const fileInput = document.getElementById("file");
210
+ const extra = document.getElementById("attention-extra");
211
+ const original = document.getElementById("original");
212
+ if (fileInput.files.length > 0) {
213
+ extra.classList.remove("hidden");
214
+ const reader = new FileReader();
215
+ reader.onload = e => {
216
+ original.src = e.target.result;
217
+ original.classList.remove("hidden");
218
+ };
219
+ reader.readAsDataURL(fileInput.files[0]);
220
+ } else {
221
+ extra.classList.add("hidden");
222
+ original.classList.add("hidden");
223
+ }
224
+ }
225
+
226
+ async function sendAttention() {
227
+ const fileInput = document.getElementById("file");
228
+ const file = fileInput.files[0];
229
+ const output = document.getElementById("output");
230
+ const status = document.getElementById("status");
231
+ const headsContainer = document.getElementById("headsContainer");
232
+ const headsGrid = document.getElementById("headsGrid");
233
+ if (!file) {
234
+ alert("Please choose an image first!");
235
+ return;
236
+ }
237
+ output.classList.add("hidden");
238
+ headsContainer.classList.add("hidden");
239
+ headsGrid.innerHTML = "";
240
+ status.textContent = "βŒ›οΈ Model is running...";
241
+ const fd = new FormData();
242
+ fd.append("file", file);
243
+ try {
244
+ const res = await fetch("/attention", { method: "POST", body: fd });
245
+ if (!res.ok) throw new Error(`Server error: ${res.status}`);
246
+ const json = await res.json();
247
+ output.src = "data:image/png;base64," + json.mean_attention_map;
248
+ output.classList.remove("hidden");
249
+ if (json.head_attention_maps && json.head_attention_maps.length > 0) {
250
+ json.head_attention_maps.forEach((headB64, i) => {
251
+ const img = document.createElement("img");
252
+ img.src = "data:image/png;base64," + headB64;
253
+ img.alt = `Head ${i + 1}`;
254
+ img.className = "rounded-md shadow-sm";
255
+ img.style.width = "120px";
256
+ img.style.transition = "transform 0.2s";
257
+ img.onmouseenter = () => (img.style.transform = "scale(1.1)");
258
+ img.onmouseleave = () => (img.style.transform = "scale(1)");
259
+ headsGrid.appendChild(img);
260
+ });
261
+ headsContainer.classList.remove("hidden");
262
+ }
263
+ status.textContent = "βœ… Done!";
264
+ } catch (err) {
265
+ console.error(err);
266
+ status.textContent = "❌ Error: " + err.message;
267
+ }
268
+ }
269
+
270
+ async function runClassification() {
271
+ const fileInput = document.getElementById("cls-file");
272
+ const file = fileInput.files[0];
273
+ const result = document.getElementById("cls-result");
274
+ const btn = document.getElementById("cls-run");
275
+
276
+ if (!file) return alert("Please choose an image to classify!");
277
+
278
+ // Disable button + show "Running..."
279
+ btn.disabled = true;
280
+ btn.textContent = "⏳ Running...";
281
+ btn.classList.add("opacity-70", "cursor-not-allowed");
282
+
283
+ result.textContent = "βŒ›οΈ Model is running...";
284
+
285
+ const fd = new FormData();
286
+ fd.append("file", file);
287
+
288
+ try {
289
+ const res = await fetch("/classify", { method: "POST", body: fd });
290
+ if (!res.ok) throw new Error(`Server error: ${res.status}`);
291
+ const json = await res.json();
292
+
293
+ if (json.top5) {
294
+ result.innerHTML = `
295
+ <h3 class="font-semibold text-indigo-600 mb-2">Top-5 Predictions</h3>
296
+ ${json.top5.map(p => `
297
+ <div class="flex justify-between border-b py-1">
298
+ <span>${p.rank}. ${p.label}</span>
299
+ <span class="text-gray-500">${(p.score * 100).toFixed(2)}%</span>
300
+ </div>
301
+ `).join("")}
302
+ `;
303
+ } else {
304
+ result.textContent = "No predictions returned.";
305
+ }
306
+
307
+ if (json.plot) {
308
+ const plotImg = document.createElement("img");
309
+ plotImg.src = json.plot;
310
+ plotImg.alt = "Top-3 Predicted Species";
311
+ plotImg.style.display = "block";
312
+ plotImg.style.margin = "20px auto";
313
+ plotImg.style.maxWidth = "800px";
314
+ result.appendChild(plotImg);
315
+ }
316
+ } catch (err) {
317
+ console.error(err);
318
+ result.textContent = "❌ Error: " + err.message;
319
+ } finally {
320
+ // Restore button state
321
+ btn.disabled = false;
322
+ btn.classList.remove("opacity-70", "cursor-not-allowed");
323
+ btn.textContent = "▢️ Run Classification";
324
+ }
325
+ }
326
+
327
+
328
+ // Show classification button only after image is selected
329
+ document.addEventListener("DOMContentLoaded", () => {
330
+ const fileInput = document.getElementById("cls-file");
331
+ const runBtn = document.getElementById("cls-run");
332
+ fileInput.addEventListener("change", () => {
333
+ if (fileInput.files && fileInput.files.length > 0) {
334
+ runBtn.style.display = "inline-block";
335
+ } else {
336
+ runBtn.style.display = "none";
337
+ }
338
+ });
339
+ });
340
+
341
+ showTab("tools");
342
+ showSubTool("attention");
343
+ </script>
344
+ </body>
345
+ </html>
app/static/assets/logo.png ADDED

Git LFS Details

  • SHA256: 41b842d01c04524042e79e0ff5d0861b4902b910f53c4b7165bc019837cccfe1
  • Pointer size: 131 Bytes
  • Size of remote file: 159 kB
app/static/background.jpg ADDED
app/static/index.html CHANGED
@@ -1,364 +1,266 @@
1
  <!doctype html>
2
  <html lang="en">
3
- <head>
4
- <meta charset="utf-8" />
5
- <meta name="viewport" content="width=device-width, initial-scale=1" />
6
- <title>NEMO Tools</title>
7
-
8
- <!-- TailwindCSS CDN -->
9
- <script src="https://cdn.tailwindcss.com"></script>
10
- <style>
11
- body {
12
- background-image: url('/background.jpg');
13
- background-size: cover;
14
- background-position: center;
15
- background-attachment: fixed;
16
- background-repeat: no-repeat;
17
- color: #f9fafb; /* soft white text for better contrast */
18
- }
19
-
20
- /* Add slight overlay to make text pop */
21
- body::before {
22
- content: "";
23
- position: fixed;
24
- top: 0;
25
- left: 0;
26
- width: 100%;
27
- height: 100%;
28
- background: rgba(0, 10, 20, 0.3); /* deep ocean tint */
29
- z-index: -1;
30
- }
31
-
32
- .card {
33
- background: rgba(255, 255, 255, 0.12);
34
- backdrop-filter: blur(10px);
35
- border-radius: 20px;
36
- box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2);
37
- color: #f1f1f1;
38
- }
39
-
40
- .nav-link {
41
- color: #e0e0e0;
42
- }
43
-
44
- .nav-link.active {
45
- color: #60a5fa;
46
- border-bottom: 2px solid #60a5fa;
47
- }
48
- </style>
49
-
50
- </head>
51
-
52
- <body class="bg-gray-100 min-h-screen">
53
- <!-- Header -->
54
- <header class="bg-white shadow-sm">
55
- <div class="max-w-6xl mx-auto px-4 py-4 flex items-center justify-between">
56
- <!-- Logo and title -->
57
- <div class="flex items-center gap-3">
58
- <img src="/static/assets/logo.png" alt="NEMO logo" class="h-10 w-10 rounded-full shadow-sm" />
59
- <div>
60
- <h1 class="text-lg font-bold text-indigo-600">NEMO tools</h1>
61
- <p class="text-xs text-gray-400">DINOv2 visualisation sandbox</p>
62
- </div>
63
  </div>
64
-
65
- <!-- Top navigation -->
66
- <nav class="flex gap-3">
67
- <button id="tab-research" class="tab-btn text-gray-500 hover:text-indigo-600 px-3 py-1 rounded-md text-sm font-medium" onclick="showTab('research')">
68
- Research
69
- </button>
70
- <button id="tab-people" class="tab-btn text-gray-500 hover:text-indigo-600 px-3 py-1 rounded-md text-sm font-medium" onclick="showTab('people')">
71
- People
72
- </button>
73
- <button id="tab-tools" class="tab-btn text-indigo-600 bg-indigo-50 px-3 py-1 rounded-md text-sm font-medium" onclick="showTab('tools')">
74
- Tools
75
- </button>
76
- </nav>
77
  </div>
78
- </header>
79
-
80
- <!-- Main content -->
81
- <main class="max-w-6xl mx-auto px-4 py-8">
82
- <!-- Research tab -->
83
- <section id="page-research" class="hidden">
84
- <h2 class="text-2xl font-semibold text-gray-800 mb-4">Research</h2>
85
- <p class="text-gray-500 mb-4">
86
- We can list publications, datasets, and experiment notes here later.
87
- </p>
88
- <div class="bg-white rounded-xl shadow p-6 text-gray-400 text-sm">
89
- (placeholder) Add your papers, abstracts, or GitHub repos here.
90
- </div>
91
- </section>
92
-
93
- <!-- People tab -->
94
- <section id="page-people" class="hidden">
95
- <h2 class="text-2xl font-semibold text-gray-800 mb-4">People</h2>
96
- <p class="text-gray-500 mb-4">
97
- We can add your name, collaborators, and links to profiles here.
98
- </p>
99
- <div class="grid gap-4 md:grid-cols-3">
100
- <div class="bg-white rounded-xl shadow p-5">
101
- <h3 class="font-semibold text-gray-700">You</h3>
102
- <p class="text-gray-400 text-sm">Project owner</p>
103
- </div>
104
- <div class="bg-white rounded-xl shadow p-5">
105
- <h3 class="font-semibold text-gray-700">To add</h3>
106
- <p class="text-gray-400 text-sm">Collaborators / advisors</p>
107
- </div>
108
- <div class="bg-white rounded-xl shadow p-5">
109
- <h3 class="font-semibold text-gray-700">Contact</h3>
110
- <p class="text-gray-400 text-sm">Add email / GitHub here</p>
111
- </div>
112
  </div>
113
- </section>
114
 
115
- <!-- Tools tab -->
116
- <section id="page-tools">
117
- <div class="bg-white shadow-lg rounded-2xl p-8 w-full">
118
- <h2 class="text-2xl font-bold text-indigo-600 mb-6 flex items-center gap-2">
119
- <span>🧰 Tools</span>
120
- </h2>
121
-
122
- <!-- Sub-tab buttons -->
123
- <div class="flex gap-3 mb-6 border-b pb-2">
124
- <button id="sub-attention" class="subtab-btn text-indigo-600 font-medium border-b-2 border-indigo-600 pb-1"
125
- onclick="showSubTool('attention')">
126
- 🧠 Mean Attention Map
127
- </button>
128
- <button id="sub-classification" class="subtab-btn text-gray-500 hover:text-indigo-600 pb-1"
129
- onclick="showSubTool('classification')">
130
- πŸ” Run Classification
131
- </button>
132
- </div>
133
-
134
-
135
- <!-- Attention Map tool -->
136
- <div id="tool-attention">
137
- <!-- 1) always visible file input -->
138
- <div class="flex flex-col items-center gap-4 mb-6 justify-center">
139
- <input id="file" type="file" accept="image/*"
140
- class="block w-full md:w-auto text-sm text-gray-600
141
- file:mr-4 file:py-2 file:px-4
142
- file:rounded-full file:border-0
143
- file:text-sm file:font-semibold
144
- file:bg-indigo-50 file:text-indigo-600
145
- hover:file:bg-indigo-100"
146
- onchange="onAttentionImageSelected()" />
147
  </div>
 
148
 
149
- <!-- 2) hidden until user selects an image -->
150
- <div id="attention-extra" class="hidden flex flex-col items-center gap-6">
151
- <!-- run button -->
152
- <button id="runButton"
153
- onclick="runActiveTool()"
154
- class="px-8 py-3 bg-indigo-600 text-white text-lg font-semibold rounded-full shadow-md hover:bg-indigo-700 transition">
155
- ▢️ Run Attention
156
- </button>
157
 
158
- <!-- image area -->
159
- <div class="flex flex-col md:flex-row justify-center items-start gap-6">
160
- <div class="flex flex-col items-center">
161
- <h4 class="text-gray-600 mb-2 font-medium">Original Image</h4>
162
- <img id="original" class="rounded-lg shadow-md max-w-xs hidden" />
163
- </div>
164
- <div class="flex flex-col items-center">
165
- <h4 class="text-gray-600 mb-2 font-medium">Mean Attention Map</h4>
166
- <!-- start empty, we'll fill after run -->
167
- <img id="output" class="rounded-lg shadow-md max-w-xs hidden" />
168
- </div>
169
  </div>
170
-
171
- <p id="status" class="text-center text-gray-500 mt-2 text-sm"></p>
172
  </div>
173
- </div>
174
 
175
- <!-- Classification tool -->
176
- <!-- Classification tool -->
177
- <div id="tool-classification" class="hidden">
178
- <div class="flex flex-col md:flex-row md:items-center gap-4 mb-6 justify-center">
179
- <input id="cls-file" type="file" accept="image/*"
180
- class="block w-full md:w-auto text-sm text-gray-600
181
- file:mr-4 file:py-2 file:px-4
182
- file:rounded-full file:border-0
183
- file:text-sm file:font-semibold
184
- file:bg-indigo-50 file:text-indigo-600
185
- hover:file:bg-indigo-100" />
186
-
187
- <!-- βœ… Add this button -->
188
- <button id="cls-run-btn"
189
- onclick="runClassification()"
190
- class="px-8 py-3 bg-green-600 text-white text-lg font-semibold rounded-full shadow-md hover:bg-green-700 transition">
191
- ▢️ Run Classification
192
- </button>
193
  </div>
194
-
195
- <div id="cls-result" class="text-center text-gray-700 mt-4 text-lg font-medium"></div>
196
  </div>
197
-
198
-
199
-
200
- <!-- <div id="cls-result" class="text-center text-gray-700 mt-4 text-lg font-medium"></div> -->
201
- </div>
202
  </div>
203
- </section>
204
- </main>
205
-
206
- <!-- JS -->
207
- <script>
208
- // --- top navigation ---
209
- function showTab(name) {
210
- const tabs = ["research", "people", "tools"];
211
- tabs.forEach(t => {
212
- document.getElementById("page-" + t).classList.add("hidden");
213
- document.getElementById("tab-" + t).classList.remove("bg-indigo-50", "text-indigo-600");
214
- document.getElementById("tab-" + t).classList.add("text-gray-500");
215
- });
216
- document.getElementById("page-" + name).classList.remove("hidden");
217
- document.getElementById("tab-" + name).classList.add("bg-indigo-50", "text-indigo-600");
218
- document.getElementById("tab-" + name).classList.remove("text-gray-500");
219
- }
220
 
221
- // --- tools logic ---
222
- let activeTool = "attention";
223
- const runButton = () => document.getElementById("runButton");
224
-
225
- function showSubTool(name) {
226
- const subs = ["attention", "classification"];
227
- subs.forEach(s => {
228
- document.getElementById("tool-" + s).classList.add("hidden");
229
- document.getElementById("sub-" + s).classList.remove("text-indigo-600", "font-medium", "border-b-2", "border-indigo-600");
230
- document.getElementById("sub-" + s).classList.add("text-gray-500");
231
- });
232
- document.getElementById("tool-" + name).classList.remove("hidden");
233
- document.getElementById("sub-" + name).classList.add("text-indigo-600", "font-medium", "border-b-2", "border-indigo-600");
234
- document.getElementById("sub-" + name).classList.remove("text-gray-500");
235
 
236
- activeTool = name;
237
 
238
- const btn = runButton();
239
- if (name === "attention") {
240
- btn.textContent = "▢️ Run Attention";
241
- btn.className = "px-8 py-3 bg-indigo-600 text-white text-lg font-semibold rounded-full shadow-md hover:bg-indigo-700 transition";
242
- } else {
243
- btn.textContent = "▢️ Run Classification";
244
- btn.className = "px-8 py-3 bg-green-600 text-white text-lg font-semibold rounded-full shadow-md hover:bg-green-700 transition";
245
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  }
247
-
248
- async function runActiveTool() {
249
- const btn = runButton();
250
- btn.disabled = true;
251
- btn.textContent = "βŒ›οΈ Running...";
252
- btn.classList.add("opacity-70", "cursor-not-allowed");
253
-
254
- try {
255
- if (activeTool === "attention") await sendAttention();
256
- else await runClassification();
257
- } catch (err) {
258
- console.error(err);
259
- alert("❌ Error while running model: " + err.message);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  }
261
-
 
 
 
262
  btn.disabled = false;
 
263
  btn.classList.remove("opacity-70", "cursor-not-allowed");
264
- btn.textContent = activeTool === "attention" ? "▢️ Run Attention" : "▢️ Run Classification";
265
  }
266
-
267
- // --- Attention Map Tool ---
268
- // called when the user picks a file in the Attention tool
269
- function onAttentionImageSelected() {
270
- const fileInput = document.getElementById("file");
271
- const extra = document.getElementById("attention-extra");
272
- const original = document.getElementById("original");
273
-
274
- if (fileInput.files.length > 0) {
275
- // show the extra area (run button + image boxes)
276
- extra.classList.remove("hidden");
277
-
278
- // preview original immediately
279
- const reader = new FileReader();
280
- reader.onload = e => {
281
- original.src = e.target.result;
282
- original.classList.remove("hidden");
283
- };
284
- reader.readAsDataURL(fileInput.files[0]);
285
- } else {
286
- extra.classList.add("hidden");
287
- original.classList.add("hidden");
288
- }
289
  }
290
-
291
- // unified run button will call this when attention is active
292
- async function sendAttention() {
293
- const fileInput = document.getElementById("file");
294
- const file = fileInput.files[0];
295
- const output = document.getElementById("output");
296
- const status = document.getElementById("status");
297
-
298
- if (!file) {
299
- alert("Please choose an image first!");
300
- return;
301
- }
302
-
303
- // show loading
304
- status.textContent = "βŒ›οΈ Model is running...";
305
- output.classList.add("hidden");
306
-
307
- const fd = new FormData();
308
- fd.append("file", file);
309
-
310
- try {
311
- const res = await fetch("/attention", { method: "POST", body: fd });
312
- if (!res.ok) throw new Error(`Server error: ${res.status}`);
313
- const json = await res.json();
314
-
315
- // show result
316
- output.src = "data:image/png;base64," + json.attention_map;
317
- output.classList.remove("hidden");
318
- status.textContent = "βœ… Done!";
319
- } catch (err) {
320
- console.error(err);
321
- status.textContent = "❌ Error: " + err.message;
 
 
 
322
  }
 
 
 
 
 
 
323
  }
 
324
 
 
 
 
 
325
 
326
- // --- Classification Tool ---
327
- async function runClassification() {
328
- const fileInput = document.getElementById("cls-file");
329
- const file = fileInput.files[0];
330
- const result = document.getElementById("cls-result");
331
-
332
- if (!file) return alert("Please choose an image to classify!");
333
-
334
- result.textContent = "βŒ›οΈ Model is running...";
335
-
336
- const fd = new FormData();
337
- fd.append("file", file);
338
-
339
- try {
340
- const res = await fetch("/attention", { method: "POST", body: fd }); // βœ… must match FastAPI route
341
- if (!res.ok) throw new Error(`Server error: ${res.status}`);
342
- const json = await res.json();
343
-
344
- // βœ… display top-5 predictions if available
345
- if (json.predictions) {
346
- result.innerHTML = "<h3 class='font-semibold text-indigo-600 mb-2'>Top-5 Predictions:</h3>" +
347
- json.predictions.map(p =>
348
- `<div>${p.label} β€” ${(p.confidence * 100).toFixed(2)}%</div>`
349
- ).join("");
350
- } else {
351
- result.textContent = "βœ… Predicted class: " + json.label;
352
- }
353
- } catch (err) {
354
- console.error(err);
355
- result.textContent = "❌ Error: " + err.message;
356
- }
357
- }
358
-
359
- // Initialize default tab
360
- showTab("tools");
361
- showSubTool("attention");
362
- </script>
363
- </body>
364
  </html>
 
1
  <!doctype html>
2
  <html lang="en">
3
+ <head>
4
+ <meta charset="utf-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
6
+ <title>NEMO Tools</title>
7
+
8
+ <!-- TailwindCSS -->
9
+ <script src="https://cdn.tailwindcss.com"></script>
10
+ <style>
11
+ body {
12
+ background-image: url('/static/background.jpg');
13
+ background-size: cover;
14
+ background-position: center;
15
+ background-attachment: fixed;
16
+ background-repeat: no-repeat;
17
+ color: #f9fafb;
18
+ }
19
+ body::before {
20
+ content: "";
21
+ position: fixed;
22
+ top: 0; left: 0;
23
+ width: 100%; height: 100%;
24
+ background: rgba(0, 10, 20, 0.3);
25
+ z-index: -1;
26
+ }
27
+ </style>
28
+ </head>
29
+
30
+ <body class="bg-gray-100 min-h-screen">
31
+
32
+ <!-- Header -->
33
+ <header class="bg-white shadow-sm">
34
+ <div class="max-w-6xl mx-auto px-4 py-4 flex items-center justify-between">
35
+ <div class="flex items-center gap-3">
36
+ <img src="/static/assets/logo.png" alt="NEMO logo" class="h-10 w-10 rounded-full shadow-sm" />
37
+ <div>
38
+ <h1 class="text-lg font-bold text-indigo-600">NEMO tools</h1>
39
+ <p class="text-xs text-gray-400">DINOv2 visualisation sandbox</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  </div>
42
+ <nav class="flex gap-3">
43
+ <button id="tab-research" class="tab-btn text-gray-500 hover:text-indigo-600 px-3 py-1 rounded-md text-sm font-medium" onclick="showTab('research')">Research</button>
44
+ <button id="tab-people" class="tab-btn text-gray-500 hover:text-indigo-600 px-3 py-1 rounded-md text-sm font-medium" onclick="showTab('people')">People</button>
45
+ <button id="tab-tools" class="tab-btn text-indigo-600 bg-indigo-50 px-3 py-1 rounded-md text-sm font-medium" onclick="showTab('tools')">Tools</button>
46
+ </nav>
47
+ </div>
48
+ </header>
49
+
50
+ <!-- Main Content -->
51
+ <main class="max-w-6xl mx-auto px-4 py-8">
52
+ <section id="page-tools">
53
+ <div class="bg-white shadow-lg rounded-2xl p-8 w-full">
54
+ <h2 class="text-2xl font-bold text-indigo-600 mb-6 flex items-center gap-2">🧰 Tools</h2>
55
+
56
+ <!-- Subtabs -->
57
+ <div class="flex gap-3 mb-6 border-b pb-2">
58
+ <button id="sub-attention" class="subtab-btn text-indigo-600 font-medium border-b-2 border-indigo-600 pb-1" onclick="showSubTool('attention')">🧠 Mean Attention Map</button>
59
+ <button id="sub-classification" class="subtab-btn text-gray-500 hover:text-indigo-600 pb-1" onclick="showSubTool('classification')">πŸ” Run Classification</button>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  </div>
 
61
 
62
+ <!-- 🧠 Attention Tool -->
63
+ <div id="tool-attention">
64
+ <div class="flex flex-col items-center gap-4 mb-6 justify-center">
65
+ <div class="w-full max-w-md mx-auto text-center border-2 border-dashed border-gray-300 rounded-xl p-6 bg-gray-50 hover:bg-gray-100 transition">
66
+ <p class="text-gray-700 font-semibold mb-2">Upload Image</p>
67
+ <p class="text-gray-400 text-sm mb-4">Supported formats: JPG, PNG</p>
68
+ <label for="file" class="inline-block bg-indigo-600 hover:bg-indigo-700 text-white font-semibold py-2 px-6 rounded-full cursor-pointer shadow-md transition">Choose File</label>
69
+ <input id="file" type="file" accept="image/*" class="hidden" onchange="onAttentionFileSelected()" />
70
+ <p id="attention-filename" class="text-gray-500 mt-3 text-sm"></p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  </div>
72
+ </div>
73
 
74
+ <div id="attention-extra" class="hidden flex flex-col items-center gap-6">
75
+ <button id="runButton" onclick="runActiveTool()" class="px-8 py-3 bg-indigo-600 text-white text-lg font-semibold rounded-full shadow-md hover:bg-indigo-700 transition">▢️ Run Attention</button>
 
 
 
 
 
 
76
 
77
+ <div class="flex flex-col md:flex-row justify-center items-start gap-6">
78
+ <div class="flex flex-col items-center">
79
+ <h4 class="text-gray-600 mb-2 font-medium">Original Image</h4>
80
+ <img id="original" class="rounded-lg shadow-md max-w-xs hidden" />
81
+ </div>
82
+ <div class="flex flex-col items-center">
83
+ <h4 class="text-gray-600 mb-2 font-medium">Mean Attention Map</h4>
84
+ <img id="output" class="rounded-lg shadow-md max-w-xs hidden" />
 
 
 
85
  </div>
 
 
86
  </div>
 
87
 
88
+ <div id="headsContainer" class="hidden mt-8">
89
+ <h4 class="text-gray-600 mb-3 font-medium text-center">All Attention Heads</h4>
90
+ <div id="headsGrid" class="flex flex-wrap justify-center gap-3"></div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  </div>
92
+ <p id="status" class="text-center text-gray-500 mt-2 text-sm"></p>
 
93
  </div>
 
 
 
 
 
94
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ <!-- πŸ” Classification Tool -->
97
+ <div id="tool-classification" class="hidden flex flex-col items-center">
98
+ <div class="w-full max-w-md mx-auto text-center border-2 border-dashed border-gray-300 rounded-xl p-6 bg-gray-50 hover:bg-gray-100 transition">
99
+ <p class="text-gray-700 font-semibold mb-2">Upload Image</p>
100
+ <p class="text-gray-400 text-sm mb-4">Supported formats: JPG, PNG</p>
101
+ <label for="cls-file" class="inline-block bg-green-600 hover:bg-green-700 text-white font-semibold py-2 px-6 rounded-full cursor-pointer shadow-md transition">Choose File</label>
102
+ <input id="cls-file" type="file" accept="image/*" class="hidden" onchange="onClsFileSelected()" />
103
+ <p id="cls-filename" class="text-gray-500 mt-3 text-sm"></p>
104
+ </div>
 
 
 
 
 
105
 
106
+ <button id="cls-run" onclick="runClassification()" style="display:none;" class="mt-5 px-8 py-3 bg-green-600 text-white text-lg font-semibold rounded-full shadow-md hover:bg-green-700 transition">▢️ Run Classification</button>
107
 
108
+ <div id="cls-result" class="text-center text-gray-700 mt-6 text-lg font-medium"></div>
109
+ </div>
110
+ </div>
111
+ </section>
112
+ </main>
113
+
114
+ <!-- Scripts -->
115
+ <script>
116
+ let activeTool = "attention";
117
+
118
+ function showSubTool(name) {
119
+ const subs = ["attention", "classification"];
120
+ subs.forEach(s => {
121
+ document.getElementById("tool-" + s).classList.add("hidden");
122
+ document.getElementById("sub-" + s).classList.remove("text-indigo-600", "font-medium", "border-b-2", "border-indigo-600");
123
+ document.getElementById("sub-" + s).classList.add("text-gray-500");
124
+ });
125
+ document.getElementById("tool-" + name).classList.remove("hidden");
126
+ document.getElementById("sub-" + name).classList.add("text-indigo-600", "font-medium", "border-b-2", "border-indigo-600");
127
+ document.getElementById("sub-" + name).classList.remove("text-gray-500");
128
+ activeTool = name;
129
+ }
130
+
131
+ // 🧠 Attention Tool
132
+ function onAttentionFileSelected() {
133
+ const fileInput = document.getElementById("file");
134
+ const extra = document.getElementById("attention-extra");
135
+ const original = document.getElementById("original");
136
+ const nameEl = document.getElementById("attention-filename");
137
+
138
+ if (fileInput.files.length > 0) {
139
+ extra.classList.remove("hidden");
140
+ const file = fileInput.files[0];
141
+ nameEl.textContent = file.name;
142
+
143
+ const reader = new FileReader();
144
+ reader.onload = e => {
145
+ original.src = e.target.result;
146
+ original.classList.remove("hidden");
147
+ };
148
+ reader.readAsDataURL(file);
149
+ } else {
150
+ extra.classList.add("hidden");
151
+ original.classList.add("hidden");
152
+ nameEl.textContent = "";
153
  }
154
+ }
155
+
156
+ async function sendAttention() {
157
+ const file = document.getElementById("file").files[0];
158
+ const output = document.getElementById("output");
159
+ const status = document.getElementById("status");
160
+ const headsContainer = document.getElementById("headsContainer");
161
+ const headsGrid = document.getElementById("headsGrid");
162
+ const btn = document.getElementById("runButton");
163
+
164
+ if (!file) return alert("Please choose an image first!");
165
+ output.classList.add("hidden");
166
+ headsContainer.classList.add("hidden");
167
+ headsGrid.innerHTML = "";
168
+ // status.textContent = "βŒ›οΈ Model is running...";
169
+
170
+ btn.disabled = true;
171
+ btn.textContent = "⏳ Running...";
172
+ btn.classList.add("opacity-70", "cursor-not-allowed");
173
+
174
+ const fd = new FormData();
175
+ fd.append("file", file);
176
+
177
+ try {
178
+ const res = await fetch("/attention", { method: "POST", body: fd });
179
+ const json = await res.json();
180
+ output.src = "data:image/png;base64," + json.mean_attention_map;
181
+ output.classList.remove("hidden");
182
+ if (json.head_attention_maps) {
183
+ json.head_attention_maps.forEach((b64, i) => {
184
+ const img = document.createElement("img");
185
+ img.src = "data:image/png;base64," + b64;
186
+ img.className = "rounded-md shadow-sm w-[120px]";
187
+ headsGrid.appendChild(img);
188
+ });
189
+ headsContainer.classList.remove("hidden");
190
  }
191
+ status.textContent = "βœ… Done!";
192
+ } catch (err) {
193
+ status.textContent = "❌ Error: " + err.message;
194
+ } finally {
195
  btn.disabled = false;
196
+ btn.textContent = "▢️ Run Attention";
197
  btn.classList.remove("opacity-70", "cursor-not-allowed");
 
198
  }
199
+ }
200
+
201
+ // πŸ” Classification Tool
202
+ function onClsFileSelected() {
203
+ const fileInput = document.getElementById("cls-file");
204
+ const fileName = document.getElementById("cls-filename");
205
+ const runBtn = document.getElementById("cls-run");
206
+ if (fileInput.files.length > 0) {
207
+ fileName.textContent = fileInput.files[0].name;
208
+ runBtn.style.display = "block";
209
+ } else {
210
+ fileName.textContent = "";
211
+ runBtn.style.display = "none";
 
 
 
 
 
 
 
 
 
 
212
  }
213
+ }
214
+
215
+ async function runClassification() {
216
+ const file = document.getElementById("cls-file").files[0];
217
+ const result = document.getElementById("cls-result");
218
+ const btn = document.getElementById("cls-run");
219
+
220
+ if (!file) return alert("Please choose an image to classify!");
221
+
222
+ btn.disabled = true;
223
+ btn.textContent = "⏳ Running...";
224
+ btn.classList.add("opacity-70", "cursor-not-allowed");
225
+ // result.textContent = "βŒ›οΈ Model is running...";
226
+
227
+ const fd = new FormData();
228
+ fd.append("file", file);
229
+
230
+ try {
231
+ const res = await fetch("/classify", { method: "POST", body: fd });
232
+ const json = await res.json();
233
+ if (json.top5) {
234
+ result.innerHTML = `
235
+ <h3 class="font-semibold text-indigo-600 mb-2">Top-5 Predictions</h3>
236
+ ${json.top5.map(p => `
237
+ <div class="flex justify-between border-b py-1">
238
+ <span>${p.rank}. ${p.label}</span>
239
+ <span class="text-gray-500">${(p.score * 100).toFixed(2)}%</span>
240
+ </div>`).join("")}`;
241
+ } else result.textContent = "No predictions returned.";
242
+
243
+ if (json.plot) {
244
+ const plotImg = document.createElement("img");
245
+ plotImg.src = json.plot;
246
+ plotImg.className = "block mx-auto mt-6 max-w-2xl";
247
+ result.appendChild(plotImg);
248
  }
249
+ } catch (err) {
250
+ result.textContent = "❌ Error: " + err.message;
251
+ } finally {
252
+ btn.disabled = false;
253
+ btn.textContent = "▢️ Run Classification";
254
+ btn.classList.remove("opacity-70", "cursor-not-allowed");
255
  }
256
+ }
257
 
258
+ async function runActiveTool() {
259
+ if (activeTool === "attention") await sendAttention();
260
+ else await runClassification();
261
+ }
262
 
263
+ showSubTool("attention");
264
+ </script>
265
+ </body>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  </html>
app/test.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from safetensors.torch import load_file
2
+ from pprint import pprint
3
+ import os
4
+
5
+ path = os.path.expanduser("~/.cache/huggingface/hub/models--Arew99--dinov2-costum/snapshots/055a10af249d426a5b9a6ac07550f011e5739bbf/model.safetensors")
6
+ print(f"Loading checkpoint: {path}")
7
+ sd = load_file(path)
8
+ print(f"βœ… Loaded {len(sd)} tensors")
9
+ print("Sample keys:")
10
+ pprint(list(sd.keys())[:30])