Zhen Ye Claude Opus 4.6 (1M context) committed on
Commit
5338c46
·
1 Parent(s): 6180bac

feat(inspection): add depth analysis and attention heatmap endpoints (Phase 2)

Browse files

Backend:
- inspection/depth.py: on-demand DepthAnythingV2 inference with LRU caching,
viridis colorization, stats computation, raw/json/colorized response formats
- inspection/attention.py: GradCAM for DETR/GDINO, saliency for YOLO,
gaussian fallback, overlay generation, per-request caching
- inspection/router.py: GET /inspect/depth and GET /inspect/attention endpoints
- 28 new tests across depth and attention modules

Frontend:
- inspection-api.js: CORS fallback for depth binary format, explicit format params
- inspection-renders.js: depth legend with min/max meters, track depth stats,
attention legend with peak/avg intensity, improved alpha blending
- inspection.js: mode-specific loading messages

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

frontend/js/api/inspection-api.js CHANGED
@@ -73,6 +73,7 @@ APP.api.inspection.generateMask = async function (jobId, frameIdx, trackId) {
73
  */
74
  APP.api.inspection.fetchDepth = async function (jobId, frameIdx) {
75
  const base = APP.core.state.hf.baseUrl;
 
76
  const url = `${base}/inspect/depth/${jobId}/${frameIdx}?format=raw`;
77
  const resp = await fetch(url);
78
  if (!resp.ok) throw new Error(`Depth fetch failed: ${resp.status}`);
@@ -80,30 +81,63 @@ APP.api.inspection.fetchDepth = async function (jobId, frameIdx) {
80
  const contentType = resp.headers.get("content-type") || "";
81
 
82
  if (contentType.includes("application/octet-stream")) {
83
- // Binary float32 format
84
  const w = parseInt(resp.headers.get("X-Depth-Width"), 10);
85
  const h = parseInt(resp.headers.get("X-Depth-Height"), 10);
86
  const minD = parseFloat(resp.headers.get("X-Depth-Min"));
87
  const maxD = parseFloat(resp.headers.get("X-Depth-Max"));
88
  const buf = await resp.arrayBuffer();
89
- return { width: w, height: h, min: minD, max: maxD, data: new Float32Array(buf) };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  } else {
91
  // JSON + base64 format
92
  const json = await resp.json();
93
- const raw = atob(json.data_b64);
94
- const buf = new ArrayBuffer(raw.length);
95
- const view = new Uint8Array(buf);
96
- for (let i = 0; i < raw.length; i++) view[i] = raw.charCodeAt(i);
97
- return {
98
- width: json.width,
99
- height: json.height,
100
- min: json.min_depth,
101
- max: json.max_depth,
102
- data: new Float32Array(buf)
103
- };
104
  }
105
  };
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  /**
108
  * Fetch attention heatmap for a specific track on a specific frame.
109
  * @param {string} jobId
@@ -113,7 +147,7 @@ APP.api.inspection.fetchDepth = async function (jobId, frameIdx) {
113
  */
114
  APP.api.inspection.fetchAttention = async function (jobId, frameIdx, trackId) {
115
  const base = APP.core.state.hf.baseUrl;
116
- const url = `${base}/inspect/attention/${jobId}/${frameIdx}/${encodeURIComponent(trackId)}`;
117
  const resp = await fetch(url);
118
  if (!resp.ok) throw new Error(`Attention fetch failed: ${resp.status}`);
119
 
@@ -126,7 +160,9 @@ APP.api.inspection.fetchAttention = async function (jobId, frameIdx, trackId) {
126
  return {
127
  width: json.width,
128
  height: json.height,
129
- data: new Float32Array(buf)
 
 
130
  };
131
  };
132
 
 
73
  */
74
  APP.api.inspection.fetchDepth = async function (jobId, frameIdx) {
75
  const base = APP.core.state.hf.baseUrl;
76
+ // Try binary format first; fall back to JSON if CORS strips custom headers
77
  const url = `${base}/inspect/depth/${jobId}/${frameIdx}?format=raw`;
78
  const resp = await fetch(url);
79
  if (!resp.ok) throw new Error(`Depth fetch failed: ${resp.status}`);
 
81
  const contentType = resp.headers.get("content-type") || "";
82
 
83
  if (contentType.includes("application/octet-stream")) {
84
+ // Binary float32 format with metadata in headers
85
  const w = parseInt(resp.headers.get("X-Depth-Width"), 10);
86
  const h = parseInt(resp.headers.get("X-Depth-Height"), 10);
87
  const minD = parseFloat(resp.headers.get("X-Depth-Min"));
88
  const maxD = parseFloat(resp.headers.get("X-Depth-Max"));
89
  const buf = await resp.arrayBuffer();
90
+ const data = new Float32Array(buf);
91
+
92
+ // If CORS stripped headers, infer dimensions from data length
93
+ if (isNaN(w) || isNaN(h)) {
94
+ // Fall back to JSON format
95
+ return await APP.api.inspection._fetchDepthJson(jobId, frameIdx);
96
+ }
97
+
98
+ return {
99
+ width: w,
100
+ height: h,
101
+ min: isNaN(minD) ? 0 : minD,
102
+ max: isNaN(maxD) ? 1 : maxD,
103
+ data: data
104
+ };
105
  } else {
106
  // JSON + base64 format
107
  const json = await resp.json();
108
+ return APP.api.inspection._decodeDepthJson(json);
 
 
 
 
 
 
 
 
 
 
109
  }
110
  };
111
 
112
/**
 * Request the depth map in JSON form (format=json) and decode it.
 * Used when the raw binary response cannot be interpreted.
 * @param {string} jobId
 * @param {number} frameIdx
 * @returns {Promise<Object>} decoded depth record from _decodeDepthJson
 * @throws {Error} when the server answers with a non-2xx status
 */
APP.api.inspection._fetchDepthJson = async function (jobId, frameIdx) {
  const root = APP.core.state.hf.baseUrl;
  const endpoint = `${root}/inspect/depth/${jobId}/${frameIdx}?format=json`;

  const response = await fetch(endpoint);
  if (!response.ok) {
    throw new Error(`Depth (JSON) fetch failed: ${response.status}`);
  }

  const payload = await response.json();
  return APP.api.inspection._decodeDepthJson(payload);
};
123
+
124
/**
 * Convert a JSON depth response into the standard
 * { width, height, min, max, data } record.
 *
 * The base64 payload in `json.data_b64` is decoded to bytes and those bytes
 * are reinterpreted as a Float32Array.
 * @param {Object} json — response with data_b64, width, height, min_depth, max_depth
 * @returns {{width: *, height: *, min: *, max: *, data: Float32Array}}
 */
APP.api.inspection._decodeDepthJson = function (json) {
  const binary = atob(json.data_b64);
  // One byte per decoded character; atob only yields char codes 0-255.
  const bytes = Uint8Array.from(binary, (ch) => ch.charCodeAt(0));

  return {
    width: json.width,
    height: json.height,
    min: json.min_depth,
    max: json.max_depth,
    data: new Float32Array(bytes.buffer)
  };
};
140
+
141
  /**
142
  * Fetch attention heatmap for a specific track on a specific frame.
143
  * @param {string} jobId
 
147
  */
148
  APP.api.inspection.fetchAttention = async function (jobId, frameIdx, trackId) {
149
  const base = APP.core.state.hf.baseUrl;
150
+ const url = `${base}/inspect/attention/${jobId}/${frameIdx}/${encodeURIComponent(trackId)}?format=json`;
151
  const resp = await fetch(url);
152
  if (!resp.ok) throw new Error(`Attention fetch failed: ${resp.status}`);
153
 
 
160
  return {
161
  width: json.width,
162
  height: json.height,
163
+ data: new Float32Array(buf),
164
+ trackId: json.track_id || trackId,
165
+ frameIdx: json.frame_idx || frameIdx
166
  };
167
  };
168
 
frontend/js/ui/inspection-renders.js CHANGED
@@ -226,7 +226,7 @@ APP.ui.inspectionRenders._renderEdge = function (canvas, frameImg, edgeData, tra
226
  };
227
 
228
  /**
229
- * Render depth colormap (viridis-like palette).
230
  */
231
  APP.ui.inspectionRenders._renderDepth = function (canvas, frameImg, depthData, track) {
232
  if (!frameImg) return;
@@ -265,7 +265,14 @@ APP.ui.inspectionRenders._renderDepth = function (canvas, frameImg, depthData, t
265
  const id = img.data;
266
 
267
  for (let i = 0; i < dd.length; i++) {
268
- const t = APP.core.utils.clamp((dd[i] - minD) / range, 0, 1);
 
 
 
 
 
 
 
269
  const rgb = APP.ui.inspectionRenders._viridis(t);
270
  const oi = i * 4;
271
  id[oi] = rgb[0];
@@ -279,11 +286,113 @@ APP.ui.inspectionRenders._renderDepth = function (canvas, frameImg, depthData, t
279
  // Draw scaled to canvas size
280
  ctx.drawImage(depthCanvas, 0, 0, w, h);
281
 
 
282
  APP.ui.inspectionRenders._drawBBoxHighlight(ctx, track, w, h);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  };
284
 
285
  /**
286
- * Render attention heatmap overlaid on the base frame.
 
287
  */
288
  APP.ui.inspectionRenders._renderAttention = function (canvas, frameImg, attentionData, track) {
289
  if (!frameImg) return;
@@ -294,11 +403,13 @@ APP.ui.inspectionRenders._renderAttention = function (canvas, frameImg, attentio
294
  canvas.height = h;
295
  const ctx = canvas.getContext("2d");
296
 
297
- // Draw base frame
298
  ctx.drawImage(frameImg, 0, 0);
 
 
299
 
300
  if (!attentionData || !attentionData.data) {
301
- ctx.fillStyle = "rgba(0,0,0,0.6)";
302
  ctx.fillRect(0, 0, w, h);
303
  ctx.fillStyle = "#aaa";
304
  ctx.font = "14px monospace";
@@ -307,7 +418,7 @@ APP.ui.inspectionRenders._renderAttention = function (canvas, frameImg, attentio
307
  return;
308
  }
309
 
310
- // Render attention as a hot colormap overlay
311
  const aw = attentionData.width;
312
  const ah = attentionData.height;
313
  const ad = attentionData.data;
@@ -319,20 +430,87 @@ APP.ui.inspectionRenders._renderAttention = function (canvas, frameImg, attentio
319
  const img = hctx.createImageData(aw, ah);
320
  const id = img.data;
321
 
 
 
 
 
322
  for (let i = 0; i < ad.length; i++) {
323
  const t = APP.core.utils.clamp(ad[i], 0, 1);
 
 
324
  const rgb = APP.ui.inspectionRenders._inferno(t);
325
  const oi = i * 4;
326
  id[oi] = rgb[0];
327
  id[oi + 1] = rgb[1];
328
  id[oi + 2] = rgb[2];
329
- id[oi + 3] = Math.round(t * 180); // semi-transparent based on intensity
 
 
330
  }
331
 
332
  hctx.putImageData(img, 0, 0);
 
 
 
 
333
  ctx.drawImage(heatCanvas, 0, 0, w, h);
334
 
 
335
  APP.ui.inspectionRenders._drawBBoxHighlight(ctx, track, w, h);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  };
337
 
338
  /**
 
226
  };
227
 
228
  /**
229
+ * Render depth colormap (viridis-like palette) with scale legend.
230
  */
231
  APP.ui.inspectionRenders._renderDepth = function (canvas, frameImg, depthData, track) {
232
  if (!frameImg) return;
 
265
  const id = img.data;
266
 
267
  for (let i = 0; i < dd.length; i++) {
268
+ const val = dd[i];
269
+ // Handle NaN/Inf values — render as black
270
+ if (!isFinite(val)) {
271
+ const oi = i * 4;
272
+ id[oi] = 0; id[oi + 1] = 0; id[oi + 2] = 0; id[oi + 3] = 255;
273
+ continue;
274
+ }
275
+ const t = APP.core.utils.clamp((val - minD) / range, 0, 1);
276
  const rgb = APP.ui.inspectionRenders._viridis(t);
277
  const oi = i * 4;
278
  id[oi] = rgb[0];
 
286
  // Draw scaled to canvas size
287
  ctx.drawImage(depthCanvas, 0, 0, w, h);
288
 
289
+ // Draw bbox highlight for the selected track
290
  APP.ui.inspectionRenders._drawBBoxHighlight(ctx, track, w, h);
291
+
292
+ // Draw depth scale legend (vertical gradient bar on the right)
293
+ APP.ui.inspectionRenders._drawDepthLegend(ctx, w, h, minD, maxD);
294
+
295
+ // If track is selected, compute and show average depth in the bbox region
296
+ if (track && track.bbox && dw > 0 && dh > 0) {
297
+ APP.ui.inspectionRenders._drawTrackDepthStats(ctx, track, w, h, depthData);
298
+ }
299
+ };
300
+
301
/**
 * Draw a vertical depth scale legend near the right edge of the canvas.
 *
 * The legend is a viridis gradient bar (minimum at the top, maximum at the
 * bottom) on a translucent panel, labelled with the min/max values and a
 * "DEPTH" title. Nothing is drawn when the canvas is too short to fit a bar.
 *
 * @param {CanvasRenderingContext2D} ctx
 * @param {number} canvasW — canvas width in pixels
 * @param {number} canvasH — canvas height in pixels
 * @param {number} minD — value labelled at the top of the bar
 * @param {number} maxD — value labelled at the bottom of the bar
 */
APP.ui.inspectionRenders._drawDepthLegend = function (ctx, canvasW, canvasH, minD, maxD) {
  const barW = 16;
  const barH = Math.min(180, canvasH - 60);

  // A bar needs at least two rows: with fewer, the gradient parameter below
  // divides by zero (or the rectangles get negative sizes).
  if (barH < 2) return;

  const x = canvasW - barW - 30;
  const y = 30;

  // Background panel
  ctx.fillStyle = "rgba(0, 0, 0, 0.55)";
  ctx.fillRect(x - 6, y - 20, barW + 52, barH + 40);
  ctx.strokeStyle = "rgba(255, 255, 255, 0.15)";
  ctx.lineWidth = 1;
  ctx.strokeRect(x - 6, y - 20, barW + 52, barH + 40);

  // Gradient bar, one pixel row at a time
  for (let py = 0; py < barH; py++) {
    const t = py / (barH - 1); // 0 = top (min), 1 = bottom (max)
    const rgb = APP.ui.inspectionRenders._viridis(t);
    ctx.fillStyle = `rgb(${rgb[0]}, ${rgb[1]}, ${rgb[2]})`;
    ctx.fillRect(x, y + py, barW, 1);
  }

  // Border around bar
  ctx.strokeStyle = "rgba(255, 255, 255, 0.3)";
  ctx.lineWidth = 1;
  ctx.strokeRect(x, y, barW, barH);

  // Labels. A non-numeric bound is shown as "n/a" instead of throwing from
  // toFixed and aborting the whole render.
  const label = function (value) {
    return typeof value === "number" ? `${value.toFixed(1)}m` : "n/a";
  };
  ctx.fillStyle = "rgba(255, 255, 255, 0.8)";
  ctx.font = "10px monospace";
  ctx.textAlign = "left";
  ctx.fillText(label(minD), x + barW + 4, y + 10);
  ctx.fillText(label(maxD), x + barW + 4, y + barH);

  // Title
  ctx.fillStyle = "rgba(255, 255, 255, 0.6)";
  ctx.font = "9px monospace";
  ctx.textAlign = "center";
  ctx.fillText("DEPTH", x + barW / 2, y - 6);
};
343
+
344
/**
 * Summarise the depth samples that fall inside the selected track's bounding
 * box and print "Depth: avg (min-max)" just below the box.
 *
 * The bbox fields (x, y, w, h) are multiplied by the depth-map and canvas
 * dimensions, i.e. they are treated as fractions of the frame size.
 * Non-finite samples are ignored; nothing is drawn if no finite sample exists.
 *
 * @param {CanvasRenderingContext2D} ctx
 * @param {Object} track — must carry a `bbox` with x, y, w, h
 * @param {number} canvasW
 * @param {number} canvasH
 * @param {{width: number, height: number, data: Float32Array}} depthData
 */
APP.ui.inspectionRenders._drawTrackDepthStats = function (ctx, track, canvasW, canvasH, depthData) {
  const box = track.bbox;
  const mapW = depthData.width;
  const mapH = depthData.height;
  const samples = depthData.data;

  // Inclusive sample window covered by the bbox, clamped to the map bounds.
  const colFirst = Math.max(0, Math.floor(box.x * mapW));
  const rowFirst = Math.max(0, Math.floor(box.y * mapH));
  const colLast = Math.min(mapW - 1, Math.floor((box.x + box.w) * mapW));
  const rowLast = Math.min(mapH - 1, Math.floor((box.y + box.h) * mapH));

  let total = 0;
  let finiteCount = 0;
  let lowest = Infinity;
  let highest = -Infinity;

  for (let row = rowFirst; row <= rowLast; row++) {
    for (let col = colFirst; col <= colLast; col++) {
      const sample = samples[row * mapW + col];
      if (!isFinite(sample)) continue;
      total += sample;
      finiteCount++;
      if (sample < lowest) lowest = sample;
      if (sample > highest) highest = sample;
    }
  }

  if (finiteCount === 0) return;

  const avgDepth = total / finiteCount;

  // Anchor the label at the bbox's bottom-left corner in canvas pixels.
  const labelX = box.x * canvasW;
  const labelY = (box.y + box.h) * canvasH;

  ctx.fillStyle = "rgba(0, 0, 0, 0.7)";
  ctx.fillRect(labelX, labelY + 4, 160, 18);

  ctx.fillStyle = "rgba(253, 231, 37, 0.9)"; // viridis yellow for contrast
  ctx.font = "bold 11px monospace";
  ctx.textAlign = "left";
  ctx.fillText(
    `Depth: ${avgDepth.toFixed(1)}m (${lowest.toFixed(1)}-${highest.toFixed(1)})`,
    labelX + 4,
    labelY + 16
  );
  };
392
 
393
  /**
394
+ * Render attention heatmap (GradCAM) overlaid on the base frame.
395
+ * Uses inferno colormap with semi-transparent blending.
396
  */
397
  APP.ui.inspectionRenders._renderAttention = function (canvas, frameImg, attentionData, track) {
398
  if (!frameImg) return;
 
403
  canvas.height = h;
404
  const ctx = canvas.getContext("2d");
405
 
406
+ // Draw base frame (slightly dimmed to make heatmap more visible)
407
  ctx.drawImage(frameImg, 0, 0);
408
+ ctx.fillStyle = "rgba(0, 0, 0, 0.2)";
409
+ ctx.fillRect(0, 0, w, h);
410
 
411
  if (!attentionData || !attentionData.data) {
412
+ ctx.fillStyle = "rgba(0,0,0,0.5)";
413
  ctx.fillRect(0, 0, w, h);
414
  ctx.fillStyle = "#aaa";
415
  ctx.font = "14px monospace";
 
418
  return;
419
  }
420
 
421
+ // Render attention as inferno colormap overlay
422
  const aw = attentionData.width;
423
  const ah = attentionData.height;
424
  const ad = attentionData.data;
 
430
  const img = hctx.createImageData(aw, ah);
431
  const id = img.data;
432
 
433
+ // Track peak attention value for stats
434
+ let peakVal = 0;
435
+ let sumVal = 0;
436
+
437
  for (let i = 0; i < ad.length; i++) {
438
  const t = APP.core.utils.clamp(ad[i], 0, 1);
439
+ if (t > peakVal) peakVal = t;
440
+ sumVal += t;
441
  const rgb = APP.ui.inspectionRenders._inferno(t);
442
  const oi = i * 4;
443
  id[oi] = rgb[0];
444
  id[oi + 1] = rgb[1];
445
  id[oi + 2] = rgb[2];
446
+ // Alpha: low values are very transparent, high values are semi-opaque
447
+ // Use a power curve for better visual contrast
448
+ id[oi + 3] = Math.round(Math.pow(t, 0.7) * 200);
449
  }
450
 
451
  hctx.putImageData(img, 0, 0);
452
+
453
+ // Use bilinear upscaling for smooth heatmap (attention resolution is typically low)
454
+ ctx.imageSmoothingEnabled = true;
455
+ ctx.imageSmoothingQuality = "high";
456
  ctx.drawImage(heatCanvas, 0, 0, w, h);
457
 
458
+ // Draw bbox highlight
459
  APP.ui.inspectionRenders._drawBBoxHighlight(ctx, track, w, h);
460
+
461
+ // Draw attention legend and stats
462
+ const avgVal = ad.length > 0 ? sumVal / ad.length : 0;
463
+ APP.ui.inspectionRenders._drawAttentionLegend(ctx, w, h, peakVal, avgVal);
464
+ };
465
+
466
/**
 * Draw an attention intensity legend: an inferno gradient bar (high at the
 * top, low at the bottom) on a translucent panel, with "High"/"Low" labels,
 * an "ATTENTION" title and peak/average percentages underneath.
 * Nothing is drawn when the canvas is too short to fit a bar.
 *
 * @param {CanvasRenderingContext2D} ctx
 * @param {number} canvasW — canvas width in pixels
 * @param {number} canvasH — canvas height in pixels
 * @param {number} peakVal — peak attention, shown as a percentage (value * 100)
 * @param {number} avgVal — average attention, shown as a percentage (value * 100)
 */
APP.ui.inspectionRenders._drawAttentionLegend = function (ctx, canvasW, canvasH, peakVal, avgVal) {
  const barW = 16;
  const barH = Math.min(140, canvasH - 60);

  // A bar needs at least two rows: with fewer, the gradient parameter below
  // divides by zero (or the rectangles get negative sizes).
  if (barH < 2) return;

  const x = canvasW - barW - 30;
  const y = 30;

  // Background panel (taller than the bar to leave room for the stats lines)
  ctx.fillStyle = "rgba(0, 0, 0, 0.55)";
  ctx.fillRect(x - 6, y - 20, barW + 60, barH + 56);
  ctx.strokeStyle = "rgba(255, 255, 255, 0.15)";
  ctx.lineWidth = 1;
  ctx.strokeRect(x - 6, y - 20, barW + 60, barH + 56);

  // Gradient bar (top = high attention, bottom = low)
  for (let py = 0; py < barH; py++) {
    const t = 1 - (py / (barH - 1)); // 1 = top (high), 0 = bottom (low)
    const rgb = APP.ui.inspectionRenders._inferno(t);
    ctx.fillStyle = `rgb(${rgb[0]}, ${rgb[1]}, ${rgb[2]})`;
    ctx.fillRect(x, y + py, barW, 1);
  }

  // Border around bar
  ctx.strokeStyle = "rgba(255, 255, 255, 0.3)";
  ctx.lineWidth = 1;
  ctx.strokeRect(x, y, barW, barH);

  // Labels
  ctx.fillStyle = "rgba(255, 255, 255, 0.8)";
  ctx.font = "10px monospace";
  ctx.textAlign = "left";
  ctx.fillText("High", x + barW + 4, y + 10);
  ctx.fillText("Low", x + barW + 4, y + barH);

  // Title
  ctx.fillStyle = "rgba(255, 255, 255, 0.6)";
  ctx.font = "9px monospace";
  ctx.textAlign = "center";
  ctx.fillText("ATTENTION", x + barW / 2, y - 6);

  // Stats
  ctx.fillStyle = "rgba(252, 255, 164, 0.8)"; // inferno yellow
  ctx.font = "10px monospace";
  ctx.textAlign = "left";
  ctx.fillText(`Peak: ${(peakVal * 100).toFixed(0)}%`, x - 2, y + barH + 18);
  ctx.fillText(`Avg: ${(avgVal * 100).toFixed(0)}%`, x - 2, y + barH + 32);
  };
515
 
516
  /**
frontend/js/ui/inspection.js CHANGED
@@ -206,12 +206,31 @@ APP.ui.inspection._clearCaches = function () {
206
  };
207
 
208
  /**
209
- * Internal: show/hide loading indicator.
 
 
210
  */
211
- APP.ui.inspection._setLoading = function (loading) {
212
  const { $ } = APP.core.utils;
213
  const el = $("#inspectionLoading");
214
- if (el) el.style.display = loading ? "flex" : "none";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  APP.core.state.inspection.loading = loading;
216
  };
217
 
@@ -254,7 +273,7 @@ APP.ui.inspection._loadAndRender = async function () {
254
  try {
255
  // --- Step 1: Ensure we have the base frame image (shared by most modes) ---
256
  if (!state.inspection._frameImg && mode !== "3d") {
257
- APP.ui.inspection._setLoading(true);
258
  const frameImg = await api.fetchFrame(jobId, frameIdx);
259
  state.inspection.frameImageUrl = frameImg.src;
260
  state.inspection._frameImg = frameImg;
@@ -262,7 +281,7 @@ APP.ui.inspection._loadAndRender = async function () {
262
 
263
  // --- Step 2: Fetch mode-specific data if not cached ---
264
  if (!cache[mode]) {
265
- APP.ui.inspection._setLoading(true);
266
 
267
  switch (mode) {
268
  case "seg":
 
206
  };
207
 
208
  /**
209
+ * Internal: show/hide loading indicator with optional mode-specific message.
210
+ * @param {boolean} loading
211
+ * @param {string} [mode] — if provided, shows a mode-specific message
212
  */
213
+ APP.ui.inspection._setLoading = function (loading, mode) {
214
  const { $ } = APP.core.utils;
215
  const el = $("#inspectionLoading");
216
+ if (el) {
217
+ el.style.display = loading ? "flex" : "none";
218
+ // Update loading message based on mode
219
+ const msgEl = el.querySelector("span");
220
+ if (msgEl && loading && mode) {
221
+ const modeMessages = {
222
+ seg: "Computing segmentation mask...",
223
+ edge: "Computing edges...",
224
+ depth: "Computing depth map...",
225
+ attention: "Generating attention heatmap...",
226
+ superres: "Enhancing resolution...",
227
+ "3d": "Generating point cloud..."
228
+ };
229
+ msgEl.textContent = modeMessages[mode] || "Loading...";
230
+ } else if (msgEl && !loading) {
231
+ msgEl.textContent = "Loading...";
232
+ }
233
+ }
234
  APP.core.state.inspection.loading = loading;
235
  };
236
 
 
273
  try {
274
  // --- Step 1: Ensure we have the base frame image (shared by most modes) ---
275
  if (!state.inspection._frameImg && mode !== "3d") {
276
+ APP.ui.inspection._setLoading(true, mode);
277
  const frameImg = await api.fetchFrame(jobId, frameIdx);
278
  state.inspection.frameImageUrl = frameImg.src;
279
  state.inspection._frameImg = frameImg;
 
281
 
282
  // --- Step 2: Fetch mode-specific data if not cached ---
283
  if (!cache[mode]) {
284
+ APP.ui.inspection._setLoading(true, mode);
285
 
286
  switch (mode) {
287
  case "seg":
inspection/attention.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GradCAM-style attention heatmap generation for detector models.
2
+
3
+ Produces per-object attention maps showing which regions of the input
4
+ image the detector model focused on when detecting a particular object.
5
+
6
+ For Transformers-based detectors (DETR, Grounding DINO) we hook the
7
+ backbone's last feature layer; the current implementation uses the
8
+ channel-wise activation norm as a GradCAM-style saliency proxy. For
9
+ Ultralytics YOLO models we also use an activation-based saliency map
10
+ (no gradients, since GradCAM is hard to apply to YOLO's detection head).
11
+ """
12
+
13
+ import base64
14
+ import logging
15
+ import threading
16
+ from typing import Dict, Optional, Tuple
17
+
18
+ import cv2
19
+ import numpy as np
20
+ import torch
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # ── In-memory attention cache ────────────────────────────────────
25
+ # Key: (job_id, frame_idx, track_id_str) Value: heatmap (HxW float32 0-1)
26
+ _attention_cache: Dict[Tuple[str, int, str], np.ndarray] = {}
27
+ _cache_lock = threading.RLock()
28
+ _MAX_CACHE_ENTRIES = 200
29
+
30
+
31
def get_cached_attention(
    job_id: str, frame_idx: int, track_id: str
) -> Optional[np.ndarray]:
    """Look up a previously stored attention heatmap.

    Returns the cached array for ``(job_id, frame_idx, track_id)``, or
    ``None`` when no entry exists. The lookup runs under the cache lock.
    """
    cache_key = (job_id, frame_idx, track_id)
    with _cache_lock:
        return _attention_cache.get(cache_key)
37
+
38
+
39
def set_cached_attention(
    job_id: str, frame_idx: int, track_id: str, heatmap: np.ndarray
) -> None:
    """Store an attention heatmap in the bounded in-memory cache.

    The cache holds at most ``_MAX_CACHE_ENTRIES`` entries and evicts in
    insertion order (oldest first). Re-storing a key that is already
    cached replaces its value and moves it to the newest position; it
    never evicts another entry, because the cache does not grow.

    Args:
        job_id: Job identifier (first component of the cache key).
        frame_idx: Frame index (second component of the cache key).
        track_id: Track identifier string (third component of the cache key).
        heatmap: Heatmap array to store.
    """
    key = (job_id, frame_idx, track_id)
    with _cache_lock:
        if key in _attention_cache:
            # Drop the old entry so the re-insert below lands at the end of
            # the dict's insertion order (i.e. becomes the newest entry).
            del _attention_cache[key]
        elif len(_attention_cache) >= _MAX_CACHE_ENTRIES:
            # Cache is full and this is a new key: evict the oldest entry.
            oldest = next(iter(_attention_cache))
            del _attention_cache[oldest]
        _attention_cache[key] = heatmap
48
+
49
+
50
def clear_attention_cache(job_id: Optional[str] = None) -> None:
    """Drop cached heatmaps.

    With ``job_id=None`` the whole cache is emptied; otherwise only the
    entries whose key starts with that job id are removed.
    """
    with _cache_lock:
        if job_id is None:
            _attention_cache.clear()
            return

        # Collect first: the dict must not change size while being iterated.
        stale_keys = [key for key in _attention_cache if key[0] == job_id]
        for key in stale_keys:
            del _attention_cache[key]
59
+
60
+
61
+ # ── GradCAM for HF Transformers models (DETR, Grounding DINO) ───
62
+
63
+
64
+ def _find_target_layer(model: torch.nn.Module) -> Optional[torch.nn.Module]:
65
+ """Find the last convolutional or attention layer suitable for GradCAM.
66
+
67
+ Tries several strategies in order:
68
+ 1. DETR ResNet backbone: model.model.backbone.conv_encoder.model.layer4
69
+ 2. Grounding DINO Swin backbone: last layer of backbone
70
+ 3. Generic: walk the model and find the last Conv2d layer
71
+ """
72
+ # Strategy 1: DETR backbone (ResNet)
73
+ try:
74
+ backbone = model.model.backbone
75
+ if hasattr(backbone, "conv_encoder"):
76
+ resnet = backbone.conv_encoder.model
77
+ if hasattr(resnet, "layer4"):
78
+ return resnet.layer4
79
+ except (AttributeError, TypeError):
80
+ pass
81
+
82
+ # Strategy 2: Grounding DINO Swin backbone
83
+ try:
84
+ backbone = model.model.backbone
85
+ if hasattr(backbone, "conv_encoder"):
86
+ swin = backbone.conv_encoder.model
87
+ if hasattr(swin, "layers"):
88
+ layers = list(swin.layers)
89
+ if layers:
90
+ return layers[-1]
91
+ if hasattr(swin, "encoder") and hasattr(swin.encoder, "layers"):
92
+ layers = list(swin.encoder.layers)
93
+ if layers:
94
+ return layers[-1]
95
+ except (AttributeError, TypeError):
96
+ pass
97
+
98
+ # Strategy 3: Generic — find the last Conv2d
99
+ last_conv = None
100
+ for module in model.modules():
101
+ if isinstance(module, torch.nn.Conv2d):
102
+ last_conv = module
103
+ return last_conv
104
+
105
+
106
class GradCAMExtractor:
    """Extract GradCAM-style heatmaps from a PyTorch model.

    Forward and backward hooks are registered on ``target_layer`` at
    construction time. ``generate`` currently derives the heatmap from the
    captured forward activations only (no backward pass is run), so the
    stored gradients stay ``None``.

    Usage:
        extractor = GradCAMExtractor(model, target_layer)
        heatmap = extractor.generate(inputs, target_bbox, frame_h, frame_w)
        extractor.release()  # remove hooks
    """

    def __init__(self, model: torch.nn.Module, target_layer: torch.nn.Module):
        self.model = model
        self.target_layer = target_layer
        self._activations: Optional[torch.Tensor] = None
        self._gradients: Optional[torch.Tensor] = None

        # Register hooks; they stay attached until release() is called.
        self._fwd_hook = target_layer.register_forward_hook(self._save_activation)
        self._bwd_hook = target_layer.register_full_backward_hook(self._save_gradient)

    def _save_activation(self, module, input, output):
        """Forward hook: keep a detached copy of the layer output."""
        if isinstance(output, torch.Tensor):
            self._activations = output.detach()
        elif isinstance(output, (tuple, list)) and len(output) > 0:
            # Layers returning a tuple/list: the first element is used.
            self._activations = output[0].detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: keep a detached copy of the output gradient."""
        if isinstance(grad_output, (tuple, list)) and len(grad_output) > 0:
            self._gradients = grad_output[0].detach()
        elif isinstance(grad_output, torch.Tensor):
            self._gradients = grad_output.detach()

    def generate(
        self,
        input_tensor: torch.Tensor,
        target_bbox: list,
        frame_h: int,
        frame_w: int,
    ) -> np.ndarray:
        """Generate a saliency heatmap focused on a target bounding box.

        Args:
            input_tensor: Preprocessed model inputs. Despite the name it is
                used as a mapping: its ``.items()`` are passed to the model
                as keyword arguments.
            target_bbox: [x1, y1, x2, y2] in original frame pixel coords.
            frame_h: Original frame height.
            frame_w: Original frame width.

        Returns:
            (frame_h, frame_w) float32 array. The feature-map saliency is
            normalized to [0, 1], upscaled to frame size, weighted by a
            bbox mask and re-normalized by its maximum. If the hook captured
            no activations, a uniform map of 0.5 is returned.
        """
        self.model.zero_grad()
        self._activations = None
        self._gradients = None

        # Run the forward pass in eval mode, then put the model back into
        # training mode if that is how the caller handed it to us. The
        # restore sits in ``finally`` so a failing forward pass cannot leave
        # the model silently switched to eval mode.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.enable_grad():
                self.model(**{k: v for k, v in input_tensor.items()})
        finally:
            if was_training:
                self.model.train()

        if self._activations is None:
            logger.warning("GradCAM: no activations captured; returning uniform map")
            return np.ones((frame_h, frame_w), dtype=np.float32) * 0.5

        # Use the activation map directly as a saliency proxy when
        # gradient-based targeting is unreliable (common with object
        # detection architectures where loss requires complex target
        # matching). We compute channel-wise L2 norm as the saliency.
        acts = self._activations
        if acts.dim() == 4:
            # (B, C, H, W) — standard conv feature map
            cam = torch.norm(acts[0], dim=0)  # (H, W)
        elif acts.dim() == 3:
            # (B, N, C) — transformer sequence; reshape to a square grid
            # when N is a perfect square, otherwise keep a single row.
            _, n_tokens, _ = acts.shape
            side = int(n_tokens ** 0.5)
            if side * side == n_tokens:
                cam = torch.norm(acts[0], dim=1).view(side, side)
            else:
                cam = torch.norm(acts[0], dim=1)  # (N,)
                cam = cam.unsqueeze(0)  # (1, N)
        else:
            # Any other rank collapses to a single value.
            cam = torch.norm(acts.flatten(), dim=0, keepdim=True)

        cam = cam.float()

        # Normalize to [0, 1]; a (near-)constant map becomes all zeros.
        cam_min = cam.min()
        cam_max = cam.max()
        if (cam_max - cam_min) > 1e-8:
            cam = (cam - cam_min) / (cam_max - cam_min)
        else:
            cam = torch.zeros_like(cam)

        cam_np = cam.cpu().numpy()

        # cv2.resize needs a 2-D array: lift 1-D maps to a grid or a row.
        if cam_np.ndim == 1:
            side = int(len(cam_np) ** 0.5)
            if side * side == len(cam_np):
                cam_np = cam_np.reshape(side, side)
            else:
                cam_np = cam_np.reshape(1, -1)

        # Upscale to frame resolution
        cam_resized = cv2.resize(
            cam_np, (frame_w, frame_h), interpolation=cv2.INTER_LINEAR
        )

        # Crop influence to the target bbox region (boost bbox, attenuate outside)
        x1, y1, x2, y2 = [int(c) for c in target_bbox]
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(frame_w, x2)
        y2 = min(frame_h, y2)

        # Weight 1.0 inside the bbox ...
        mask = np.zeros((frame_h, frame_w), dtype=np.float32)
        mask[y1:y2, x1:x2] = 1.0

        # ... 0.3 in a surrounding context band half the bbox's longer side
        # wide, and 0 everywhere else.
        pad = max(x2 - x1, y2 - y1) // 2
        py1 = max(0, y1 - pad)
        py2 = min(frame_h, y2 + pad)
        px1 = max(0, x1 - pad)
        px2 = min(frame_w, x2 + pad)
        mask[py1:py2, px1:px2] = np.maximum(mask[py1:py2, px1:px2], 0.3)

        cam_resized = cam_resized * mask

        # Re-normalize so the strongest remaining response is 1.0
        c_max = cam_resized.max()
        if c_max > 1e-8:
            cam_resized = cam_resized / c_max

        return cam_resized.astype(np.float32)

    def release(self):
        """Remove hooks from the model."""
        self._fwd_hook.remove()
        self._bwd_hook.remove()
252
+
253
+
254
+ # ── YOLO saliency (activation-based, no gradients) ──────────────
255
+
256
+
257
def _yolo_saliency(
    yolo_model,
    frame: np.ndarray,
    target_bbox: list,
) -> np.ndarray:
    """Generate an activation-based saliency map from a YOLO model.

    A forward hook is attached to the last SPPF/C2f/C3/Conv layer of the
    wrapped PyTorch model, one inference pass is run, and the per-pixel
    channel norm of the captured activation is used as a proxy for
    attention. If no usable activation is captured, a Gaussian centred on
    the bbox is used instead. In both cases the map is weighted towards the
    target bbox and normalized.

    Args:
        yolo_model: Ultralytics YOLO model instance.
        frame: HxWx3 BGR uint8 numpy array.
        target_bbox: [x1, y1, x2, y2] in original frame pixel coords.

    Returns:
        HxW float32 array normalized to [0, 1].
    """
    frame_h, frame_w = frame.shape[:2]
    cam = None

    try:
        # Access the PyTorch model inside the Ultralytics wrapper
        pt_model = yolo_model.model
        if hasattr(pt_model, "model"):
            layers = pt_model.model
            # Walk backwards to find the last feature-extraction layer.
            for i in range(len(layers) - 1, -1, -1):
                layer = layers[i]
                if type(layer).__name__ not in ("SPPF", "C2f", "C3", "Conv"):
                    continue

                activation = {}

                def hook_fn(module, inp, out, store=activation):
                    store["out"] = out.detach()

                handle = layer.register_forward_hook(hook_fn)
                try:
                    # Single inference pass; the hook captures the activation.
                    yolo_model.predict(
                        source=frame,
                        device=yolo_model.device if hasattr(yolo_model, "device") else None,
                        conf=0.1,
                        imgsz=640,
                        verbose=False,
                    )
                finally:
                    # The model is shared between requests: never leave the
                    # hook attached, even when inference raises.
                    handle.remove()

                feat = activation.get("out")
                if feat is not None and feat.dim() == 4:
                    cam = torch.norm(feat[0], dim=0).cpu().numpy()
                # Only the last matching layer is tried.
                break
    except Exception as e:
        logger.warning("YOLO feature extraction failed: %s", e)

    if cam is None:
        # Fallback: generate a simple Gaussian heatmap centered on the bbox
        cam = _gaussian_bbox_heatmap(frame_h, frame_w, target_bbox)
    else:
        # Resize to frame dimensions
        cam = cv2.resize(cam, (frame_w, frame_h), interpolation=cv2.INTER_LINEAR)

    # Focus on the target bbox region (clamped to the frame)
    x1, y1, x2, y2 = [int(c) for c in target_bbox]
    x1 = max(0, x1)
    y1 = max(0, y1)
    x2 = min(frame_w, x2)
    y2 = min(frame_h, y2)

    # Weight 1.0 inside the bbox, 0.3 in a surrounding context band, 0 elsewhere.
    mask = np.zeros((frame_h, frame_w), dtype=np.float32)
    mask[y1:y2, x1:x2] = 1.0
    pad = max(x2 - x1, y2 - y1) // 2
    py1, py2 = max(0, y1 - pad), min(frame_h, y2 + pad)
    px1, px2 = max(0, x1 - pad), min(frame_w, x2 + pad)
    mask[py1:py2, px1:px2] = np.maximum(mask[py1:py2, px1:px2], 0.3)

    cam = cam * mask

    # Normalize to [0, 1]; an all-zero map is left untouched.
    c_max = cam.max()
    if c_max > 1e-8:
        cam = cam / c_max

    return cam.astype(np.float32)
365
+
366
+
367
+ def _gaussian_bbox_heatmap(
368
+ frame_h: int, frame_w: int, bbox: list
369
+ ) -> np.ndarray:
370
+ """Generate a Gaussian heatmap centered on a bounding box.
371
+
372
+ Used as a fallback when feature extraction is not available.
373
+ """
374
+ x1, y1, x2, y2 = [int(c) for c in bbox]
375
+ cx = (x1 + x2) / 2
376
+ cy = (y1 + y2) / 2
377
+ sx = max((x2 - x1) / 2, 1.0)
378
+ sy = max((y2 - y1) / 2, 1.0)
379
+
380
+ y_coords, x_coords = np.mgrid[0:frame_h, 0:frame_w]
381
+ heatmap = np.exp(
382
+ -0.5 * (((x_coords - cx) / sx) ** 2 + ((y_coords - cy) / sy) ** 2)
383
+ )
384
+ return heatmap.astype(np.float32)
385
+
386
+
387
+ # ── Main entry point ─────────────────────────────────────────────
388
+
389
+
390
def generate_attention_map(
    frame: np.ndarray,
    bbox: list,
    detector_name: str,
    job_id: str,
    frame_idx: int,
    track_id: str,
) -> np.ndarray:
    """Produce (or fetch from cache) an attention heatmap for one object.

    YOLO detectors go through activation-based saliency, DETR and
    Grounding DINO through GradCAM on a target layer. Any failure, or an
    unknown detector, falls back to a Gaussian centred on the bbox. The
    result, including a fallback, is cached.

    Args:
        frame: HxWx3 BGR uint8 numpy array.
        bbox: [x1, y1, x2, y2] target object bounding box.
        detector_name: Name of the detector used for the job.
        job_id: Job identifier (cache key part).
        frame_idx: Frame index (cache key part).
        track_id: Track ID string (cache key part).

    Returns:
        HxW float32 heatmap normalized to [0, 1].
    """
    previous = get_cached_attention(job_id, frame_idx, track_id)
    if previous is not None:
        return previous

    frame_h, frame_w = frame.shape[:2]
    heatmap = None

    if detector_name in ("yolo11", "yolov8_visdrone"):
        # YOLO family: saliency from internal activations.
        try:
            from models.model_loader import load_detector

            yolo_detector = load_detector(detector_name)
            heatmap = _yolo_saliency(yolo_detector.model, frame, bbox)
        except Exception as e:
            logger.warning("YOLO saliency generation failed: %s", e)

    elif detector_name in ("detr_resnet50", "grounding_dino"):
        # Transformers family: GradCAM on the located target layer.
        try:
            from models.model_loader import load_detector

            hf_detector = load_detector(detector_name)
            hf_model = hf_detector.model
            target_layer = _find_target_layer(hf_model)

            if target_layer is None:
                logger.warning(
                    "No suitable target layer found for %s", detector_name
                )
            else:
                extractor = GradCAMExtractor(hf_model, target_layer)
                try:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    processor = hf_detector.processor
                    processor_kwargs = {"images": frame_rgb, "return_tensors": "pt"}
                    if detector_name == "grounding_dino":
                        # Grounding DINO also needs a text prompt.
                        processor_kwargs["text"] = "object."
                    encoded = processor(**processor_kwargs)
                    model_inputs = {
                        name: tensor.to(hf_detector.device)
                        for name, tensor in encoded.items()
                    }
                    heatmap = extractor.generate(
                        model_inputs, bbox, frame_h, frame_w
                    )
                finally:
                    # Always release the extractor's hooks.
                    extractor.release()
        except Exception as e:
            logger.warning("GradCAM generation failed for %s: %s", detector_name, e)

    if heatmap is None:
        # Nothing model-based was produced: use the Gaussian stand-in.
        logger.info(
            "Using Gaussian fallback for attention (detector=%s)", detector_name
        )
        heatmap = _gaussian_bbox_heatmap(frame_h, frame_w, bbox)

    set_cached_attention(job_id, frame_idx, track_id, heatmap)
    return heatmap
477
+
478
+
479
+ # ── Serialization / rendering ─────────────────────────────────────
480
+
481
+
482
def heatmap_to_base64(heatmap: np.ndarray) -> str:
    """Serialize a heatmap as an ASCII base64 string of its float32 bytes.

    Args:
        heatmap: Array to encode; it is converted to float32 first.

    Returns:
        Base64 text of the array's raw bytes.
    """
    as_float32 = heatmap.astype(np.float32)
    encoded = base64.b64encode(as_float32.tobytes())
    return encoded.decode("ascii")
486
+
487
+
488
def heatmap_overlay_jpeg(
    frame: np.ndarray,
    heatmap: np.ndarray,
    bbox: list,
    alpha: float = 0.5,
    quality: int = 85,
) -> bytes:
    """Render heatmap overlay on a cropped frame region as JPEG.

    The crop is the bbox clamped to the frame and enlarged by 15% of its
    longer side. When that region is empty (degenerate or fully
    out-of-frame bbox) the whole frame is rendered instead of failing
    inside OpenCV.

    Args:
        frame: HxWx3 BGR uint8 numpy array (full frame).
        heatmap: HxW float32 heatmap (same size as frame).
        bbox: [x1, y1, x2, y2] crop region.
        alpha: Blend factor for overlay (0=no overlay, 1=full overlay).
        quality: JPEG quality.

    Returns:
        JPEG bytes.

    Raises:
        RuntimeError: If JPEG encoding fails.
    """
    h, w = frame.shape[:2]
    x1, y1, x2, y2 = [int(c) for c in bbox]
    x1 = max(0, x1)
    y1 = max(0, y1)
    x2 = min(w, x2)
    y2 = min(h, y2)

    # Add some padding around the clamped box
    pad = int(max(x2 - x1, y2 - y1) * 0.15)
    cx1 = max(0, x1 - pad)
    cy1 = max(0, y1 - pad)
    cx2 = min(w, x2 + pad)
    cy2 = min(h, y2 + pad)

    if cx2 <= cx1 or cy2 <= cy1:
        # Empty region: OpenCV cannot colour-map a zero-size image, so
        # fall back to the full frame.
        cx1, cy1, cx2, cy2 = 0, 0, w, h

    crop = frame[cy1:cy2, cx1:cx2].copy()
    # NaN/inf have no defined uint8 value; map them to the range ends.
    heat_crop = np.nan_to_num(
        heatmap[cy1:cy2, cx1:cx2], nan=0.0, posinf=1.0, neginf=0.0
    )

    # Scale heatmap crop to 0-255 for the colormap
    heat_u8 = (heat_crop * 255).clip(0, 255).astype(np.uint8)
    colored = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)

    # Blend
    overlay = cv2.addWeighted(crop, 1 - alpha, colored, alpha, 0)

    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    success, buffer = cv2.imencode(".jpg", overlay, encode_param)
    if not success:
        raise RuntimeError("Failed to encode attention overlay as JPEG")
    return buffer.tobytes()
inspection/depth.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """On-demand depth inference, colorization, caching, and stats.
2
+
3
+ Uses DepthAnythingV2Estimator for single-frame depth estimation.
4
+ Results are cached in-memory per (job_id, frame_idx) to avoid
5
+ redundant GPU work when the same frame is requested multiple times.
6
+ """
7
+
8
+ import base64
9
+ import logging
10
+ import threading
11
+ from typing import Dict, Optional, Tuple
12
+
13
+ import cv2
14
+ import numpy as np
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # ── In-memory depth cache ────────────────────────────────────────
19
+ # Key: (job_id, frame_idx) Value: depth_map (HxW float32)
20
_depth_cache: Dict[Tuple[str, int], np.ndarray] = {}
_cache_lock = threading.RLock()

# Limit cache size to avoid OOM
_MAX_CACHE_ENTRIES = 200


def _cache_key(job_id: str, frame_idx: int) -> Tuple[str, int]:
    """Build the dictionary key used for one frame of one job."""
    return (job_id, frame_idx)


def get_cached_depth(job_id: str, frame_idx: int) -> Optional[np.ndarray]:
    """Return cached depth map or None."""
    with _cache_lock:
        return _depth_cache.get(_cache_key(job_id, frame_idx))


def set_cached_depth(job_id: str, frame_idx: int, depth_map: np.ndarray) -> None:
    """Store depth map in cache, evicting the oldest entry if over limit.

    Eviction follows insertion order (dicts preserve it). Replacing the
    value of a key that is already cached does not grow the cache, so no
    other entry is evicted in that case.
    """
    key = _cache_key(job_id, frame_idx)
    with _cache_lock:
        if key not in _depth_cache and len(_depth_cache) >= _MAX_CACHE_ENTRIES:
            # Evict the first (oldest inserted) entry
            oldest = next(iter(_depth_cache))
            del _depth_cache[oldest]
        _depth_cache[key] = depth_map


def clear_depth_cache(job_id: Optional[str] = None) -> None:
    """Clear depth cache for a specific job, or for all jobs when None."""
    with _cache_lock:
        if job_id is None:
            _depth_cache.clear()
        else:
            # Collect first: a dict must not change size while iterated.
            keys_to_remove = [k for k in _depth_cache if k[0] == job_id]
            for k in keys_to_remove:
                del _depth_cache[k]
56
+
57
+
58
+ # ── Lazy model singleton ─────────────────────────────────────────
59
+
60
+ _depth_estimator = None
61
+ _estimator_lock = threading.Lock()
62
+
63
+
64
def _get_depth_estimator():
    """Return the shared DepthAnythingV2 estimator, creating it on first use.

    The instance is checked once without the lock and once more while
    holding it, so only one estimator is constructed.
    """
    global _depth_estimator
    if _depth_estimator is not None:
        return _depth_estimator

    with _estimator_lock:
        # Re-check under the lock: another caller may have built it already.
        if _depth_estimator is None:
            from models.depth_estimators.depth_anything_v2 import (
                DepthAnythingV2Estimator,
            )

            _depth_estimator = DepthAnythingV2Estimator()
    return _depth_estimator
75
+
76
+
77
+ # ── Core inference ────────────────────────────────────────────────
78
+
79
def run_depth_on_frame(
    frame: np.ndarray,
    job_id: str,
    frame_idx: int,
) -> np.ndarray:
    """Return the depth map for one frame, running inference only on a cache miss.

    Args:
        frame: HxWx3 BGR uint8 numpy array.
        job_id: Job identifier (cache key part).
        frame_idx: Frame index (cache key part).

    Returns:
        HxW float32 depth map.
    """
    depth_map = get_cached_depth(job_id, frame_idx)
    if depth_map is None:
        # Miss: run the estimator and remember its output for next time.
        depth_map = _get_depth_estimator().predict(frame).depth_map
        set_cached_depth(job_id, frame_idx, depth_map)
    return depth_map
104
+
105
+
106
+ # ── Stats computation ─────────────────────────────────────────────
107
+
108
def compute_depth_stats(depth_map: np.ndarray) -> dict:
    """Summarize a depth map with its min, max, mean and median.

    Args:
        depth_map: HxW float32 depth array.

    Returns:
        Dict with keys min_m, max_m, mean_m, median_m holding Python floats.
    """
    reducers = (
        ("min_m", np.min),
        ("max_m", np.max),
        ("mean_m", np.mean),
        ("median_m", np.median),
    )
    return {key: float(reduce_fn(depth_map)) for key, reduce_fn in reducers}
123
+
124
+
125
+ # ── Crop depth to track bbox ─────────────────────────────────────
126
+
127
def crop_depth_to_bbox(
    depth_map: np.ndarray,
    bbox: list,
    padding: float = 0.0,
) -> np.ndarray:
    """Crop a depth map to a bounding box with optional padding.

    Coordinates are truncated to int (the same convention as the attention
    helpers), so float boxes are accepted. The padded box is clamped to the
    map on every side; a box entirely outside the map yields an empty array
    rather than wrapping around through negative slice indices.

    Args:
        depth_map: HxW float32 depth array.
        bbox: [x1, y1, x2, y2] in pixel coordinates.
        padding: Fractional padding around the bbox.

    Returns:
        Cropped HxW float32 depth array (a copy; possibly empty).
    """
    h, w = depth_map.shape[:2]
    x1, y1, x2, y2 = [int(c) for c in bbox]

    pad_x = int((x2 - x1) * padding)
    pad_y = int((y2 - y1) * padding)

    # Clamp lower bounds into [0, size] and keep upper >= lower so a
    # negative upper bound can never be read as "from the end".
    cx1 = min(max(0, x1 - pad_x), w)
    cy1 = min(max(0, y1 - pad_y), h)
    cx2 = max(cx1, min(w, x2 + pad_x))
    cy2 = max(cy1, min(h, y2 + pad_y))

    return depth_map[cy1:cy2, cx1:cx2].copy()
156
+
157
+
158
+ # ── Colorization (viridis) ───────────────────────────────────────
159
+
160
def colorize_depth(depth_map: np.ndarray, quality: int = 85) -> bytes:
    """Render a depth map as a viridis-coloured JPEG.

    The map is min-max scaled to 0-255 first; a map whose range is below
    1e-6 is rendered as all zeros.

    Args:
        depth_map: HxW float32 depth array.
        quality: JPEG quality (1-100).

    Returns:
        JPEG bytes with viridis-colored depth.

    Raises:
        RuntimeError: If JPEG encoding fails.
    """
    lowest = float(np.min(depth_map))
    highest = float(np.max(depth_map))
    span = highest - lowest

    if span < 1e-6:
        # Flat map: nothing to stretch.
        gray = np.zeros_like(depth_map, dtype=np.uint8)
    else:
        gray = ((depth_map - lowest) / span * 255).astype(np.uint8)

    tinted = cv2.applyColorMap(gray, cv2.COLORMAP_VIRIDIS)

    ok, encoded = cv2.imencode(
        ".jpg", tinted, [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    )
    if not ok:
        raise RuntimeError("Failed to encode colorized depth as JPEG")
    return encoded.tobytes()
186
+
187
+
188
+ # ── Serialization helpers ─────────────────────────────────────────
189
+
190
def depth_to_raw_bytes(depth_map: np.ndarray) -> bytes:
    """Convert depth map to raw float32 little-endian bytes.

    The byte order is forced to little-endian ("<f4") so the output matches
    the documented wire format regardless of the host's native byte order.

    Args:
        depth_map: HxW float32 depth array.

    Returns:
        Raw bytes (width * height * 4 bytes), row-major.
    """
    return depth_map.astype("<f4").tobytes()
200
+
201
+
202
def depth_to_base64(depth_map: np.ndarray) -> str:
    """Encode a depth map's float32 bytes as base64 text.

    Args:
        depth_map: HxW float32 depth array.

    Returns:
        ASCII base64 string of the bytes produced by depth_to_raw_bytes.
    """
    return base64.b64encode(depth_to_raw_bytes(depth_map)).decode("ascii")
inspection/router.py CHANGED
@@ -337,3 +337,201 @@ async def generate_mask(
337
  "color": color,
338
  "source": "sam2_ondemand",
339
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  "color": color,
338
  "source": "sam2_ondemand",
339
  })
340
+
341
+
342
+ # ── On-demand depth analysis ─────────────────────────────────────
343
+
344
@router.get("/depth/{job_id}/{frame_idx}")
async def get_depth(
    job_id: str,
    frame_idx: int,
    track_id: Optional[str] = Query(None, description="Track ID to crop depth to, e.g. 'T01'"),
    format: str = Query("raw", description="Response format: 'raw', 'json', or 'colorized'"),
):
    """Get depth data for a frame, computed on-demand and cached.

    Supports three response formats:
    - raw: binary float32 with X-Depth-* headers
    - json: JSON with base64-encoded float32 depth data and stats
    - colorized: JPEG image with viridis colormap

    When track_id is given, the depth map is cropped to that track's bbox.
    The track is resolved before any frame decoding or depth inference, so
    an unknown or malformed track ID is answered with 404 without doing
    that work.

    Raises:
        HTTPException: 400 for an invalid format; 404 when the input video,
            the track, or the track's visible area is missing.
    """
    import asyncio

    from inspection.frames import extract_frame
    from inspection.depth import (
        run_depth_on_frame,
        crop_depth_to_bbox,
        compute_depth_stats,
        depth_to_raw_bytes,
        depth_to_base64,
        colorize_depth,
    )

    if format not in ("raw", "json", "colorized"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid format '{format}'. Must be 'raw', 'json', or 'colorized'.",
        )

    job = _get_job_or_404(job_id)
    input_path = job.input_video_path
    if not input_path or not Path(input_path).exists():
        raise HTTPException(status_code=404, detail="Input video not found on disk.")

    _validate_frame_idx(input_path, frame_idx)

    # Resolve the optional track first
    target_bbox = None
    if track_id is not None:
        from jobs.storage import get_track_data

        tracks = get_track_data(job_id, frame_idx)
        try:
            instance_id = int(track_id.replace("T", "")) if track_id.startswith("T") else int(track_id)
        except ValueError:
            # Not numeric: it can still match a stored string ID below.
            instance_id = None

        target = None
        for t in tracks or []:
            # Compare against both ID fields; `is not None` keeps ID 0 valid.
            stored_ids = (t.get("instance_id"), t.get("track_id"))
            if any(
                sid is not None and (sid == track_id or sid == instance_id)
                for sid in stored_ids
            ):
                target = t
                break
        if not target or "bbox" not in target:
            raise HTTPException(
                status_code=404,
                detail=f"Track {track_id} not found in frame {frame_idx}.",
            )
        target_bbox = target["bbox"]

    # Extract frame and run depth (GPU work in thread pool)
    frame = await asyncio.to_thread(extract_frame, input_path, frame_idx)
    depth_map = await asyncio.to_thread(run_depth_on_frame, frame, job_id, frame_idx)

    if target_bbox is not None:
        depth_map = crop_depth_to_bbox(depth_map, target_bbox)
        if depth_map.size == 0:
            # min()/max() below are undefined on an empty crop.
            raise HTTPException(
                status_code=404,
                detail=f"Track {track_id} has no visible area in frame {frame_idx}.",
            )

    h, w = depth_map.shape[:2]
    d_min = float(depth_map.min())
    d_max = float(depth_map.max())

    if format == "raw":
        raw_bytes = depth_to_raw_bytes(depth_map)
        return Response(
            content=raw_bytes,
            media_type="application/octet-stream",
            headers={
                "X-Depth-Width": str(w),
                "X-Depth-Height": str(h),
                "X-Depth-Min": f"{d_min:.4f}",
                "X-Depth-Max": f"{d_max:.4f}",
            },
        )

    if format == "json":
        stats = compute_depth_stats(depth_map)
        data_b64 = depth_to_base64(depth_map)
        return JSONResponse({
            "width": w,
            "height": h,
            "min_depth": d_min,
            "max_depth": d_max,
            "data_b64": data_b64,
            "depth_stats": stats,
        })

    # format == "colorized"
    jpeg_bytes = await asyncio.to_thread(colorize_depth, depth_map)
    return Response(content=jpeg_bytes, media_type="image/jpeg")
439
+
440
+
441
+ # ── Attention / GradCAM heatmaps ─────────────────────────────────
442
+
443
@router.get("/attention/{job_id}/{frame_idx}/{track_id}")
async def get_attention(
    job_id: str,
    frame_idx: int,
    track_id: str,
    format: str = Query("json", description="Response format: 'json' or 'overlay'"),
):
    """Generate a GradCAM/saliency attention heatmap for a detected object.

    Computed on-demand using the detector model that produced the original
    job. Results are cached per (job_id, frame_idx, track_id).

    Supports two response formats:
    - json: base64-encoded float32 heatmap with metadata
    - overlay: JPEG image with heatmap blended onto the frame crop

    Raises:
        HTTPException: 400 for an invalid format or a job without a
            detector; 404 when the input video or the track is missing
            (a malformed track ID is treated as "not found").
    """
    import asyncio

    from inspection.frames import extract_frame
    from inspection.attention import (
        generate_attention_map,
        heatmap_to_base64,
        heatmap_overlay_jpeg,
    )

    if format not in ("json", "overlay"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid format '{format}'. Must be 'json' or 'overlay'.",
        )

    job = _get_job_or_404(job_id)
    input_path = job.input_video_path
    if not input_path or not Path(input_path).exists():
        raise HTTPException(status_code=404, detail="Input video not found on disk.")

    _validate_frame_idx(input_path, frame_idx)

    # Determine the detector used for this job
    detector_name = job.detector_name
    if not detector_name:
        raise HTTPException(
            status_code=400,
            detail="Attention maps require a detector model. This job has no detector_name.",
        )

    # Get the track's bounding box
    from jobs.storage import get_track_data

    tracks = get_track_data(job_id, frame_idx)
    try:
        instance_id = int(track_id.replace("T", "")) if track_id.startswith("T") else int(track_id)
    except ValueError:
        # Not numeric: it can still match a stored string ID below.
        instance_id = None

    target = None
    for t in tracks or []:
        # Compare against both ID fields; `is not None` keeps ID 0 valid.
        stored_ids = (t.get("instance_id"), t.get("track_id"))
        if any(
            sid is not None and (sid == track_id or sid == instance_id)
            for sid in stored_ids
        ):
            target = t
            break

    if not target or "bbox" not in target:
        raise HTTPException(
            status_code=404,
            detail=f"Track {track_id} not found in frame {frame_idx}.",
        )

    bbox = target["bbox"]

    # Extract frame and generate attention map (GPU work in thread pool)
    frame = await asyncio.to_thread(extract_frame, input_path, frame_idx)
    heatmap = await asyncio.to_thread(
        generate_attention_map,
        frame,
        bbox,
        detector_name,
        job_id,
        frame_idx,
        track_id,
    )

    if format == "json":
        h, w = heatmap.shape[:2]
        data_b64 = heatmap_to_base64(heatmap)
        return JSONResponse({
            "track_id": track_id,
            "frame_idx": frame_idx,
            "width": w,
            "height": h,
            "data_b64": data_b64,
            "format": "float32",
        })

    # format == "overlay"
    jpeg_bytes = await asyncio.to_thread(
        heatmap_overlay_jpeg, frame, heatmap, bbox
    )
    return Response(content=jpeg_bytes, media_type="image/jpeg")
tests/test_inspection_attention.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for inspection/attention.py — GradCAM attention heatmaps."""
2
+
3
+ import base64
4
+ import struct
5
+
6
+ import numpy as np
7
+ import pytest
8
+
9
+
10
+ # ── Unit tests for attention module functions ────────────────────
11
+
12
+
13
def test_gaussian_bbox_heatmap_shape():
    """The Gaussian fallback has the requested shape, float32 dtype and a [0, 1] range."""
    from inspection.attention import _gaussian_bbox_heatmap

    rows, cols = 100, 200
    result = _gaussian_bbox_heatmap(rows, cols, [50, 30, 150, 70])

    assert result.shape == (rows, cols)
    assert result.dtype == np.float32
    assert result.max() <= 1.0
    assert result.min() >= 0.0
23
+
24
+
25
def test_gaussian_bbox_heatmap_peak_location():
    """The value at the bbox centre exceeds the value at a frame corner."""
    from inspection.attention import _gaussian_bbox_heatmap

    result = _gaussian_bbox_heatmap(200, 200, [60, 80, 140, 120])

    # The box [60, 80, 140, 120] is centred on pixel (100, 100).
    at_center = result[100, 100]
    at_corner = result[0, 0]
    assert at_center > at_corner
38
+
39
+
40
def test_gaussian_bbox_heatmap_small_bbox():
    """A 2x2 bbox still yields a finite, non-empty heatmap."""
    from inspection.attention import _gaussian_bbox_heatmap

    result = _gaussian_bbox_heatmap(50, 50, [10, 10, 12, 12])

    assert result.shape == (50, 50)
    assert not np.isnan(result).any()
    assert result.max() > 0
49
+
50
+
51
def test_heatmap_to_base64():
    """The base64 payload decodes back to the original float32 values."""
    from inspection.attention import heatmap_to_base64

    expected = (0.0, 0.5, 0.75, 1.0)
    source = np.array([expected[:2], expected[2:]], dtype=np.float32)

    payload = base64.b64decode(heatmap_to_base64(source))

    assert len(payload) == 4 * 4  # 4 floats
    assert struct.unpack("<4f", payload) == pytest.approx(expected)
62
+
63
+
64
def test_heatmap_overlay_jpeg():
    """The overlay renderer returns bytes that start with the JPEG marker."""
    from inspection.attention import heatmap_overlay_jpeg

    height, width = 100, 200
    frame = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
    heatmap = np.random.rand(height, width).astype(np.float32)

    encoded = heatmap_overlay_jpeg(frame, heatmap, [50, 30, 150, 70])

    # JPEG magic bytes
    assert encoded[:2] == b"\xff\xd8"
    assert len(encoded) > 100
77
+
78
+
79
def test_heatmap_overlay_jpeg_edge_bbox():
    """A bbox that covers the whole frame renders without crashing."""
    from inspection.attention import heatmap_overlay_jpeg

    side = 50
    frame = np.random.randint(0, 255, (side, side, 3), dtype=np.uint8)
    heatmap = np.random.rand(side, side).astype(np.float32)
    full_frame_bbox = [0, 0, side, side]

    encoded = heatmap_overlay_jpeg(frame, heatmap, full_frame_bbox)
    assert encoded[:2] == b"\xff\xd8"
89
+
90
+
91
+ # ── Cache tests ──────────────────────────────────────────────────
92
+
93
+
94
def test_attention_cache_set_get():
    """A stored heatmap is returned for its exact key and for no other key."""
    from inspection.attention import (
        get_cached_attention,
        set_cached_attention,
        clear_attention_cache,
    )

    clear_attention_cache()

    stored = np.ones((10, 10), dtype=np.float32) * 0.5
    set_cached_attention("job1", 0, "T01", stored)

    fetched = get_cached_attention("job1", 0, "T01")
    assert fetched is not None
    np.testing.assert_array_equal(fetched, stored)

    # Changing any one key component must miss.
    for other_key in (("job1", 0, "T02"), ("job1", 1, "T01"), ("job2", 0, "T01")):
        assert get_cached_attention(*other_key) is None

    clear_attention_cache()
117
+
118
+
119
def test_attention_cache_clear_per_job():
    """Clearing one job's entries leaves other jobs' entries in place."""
    from inspection.attention import (
        get_cached_attention,
        set_cached_attention,
        clear_attention_cache,
    )

    clear_attention_cache()

    set_cached_attention("jobA", 0, "T01", np.ones((5, 5), dtype=np.float32))
    set_cached_attention("jobB", 0, "T01", np.ones((5, 5), dtype=np.float32) * 0.5)

    clear_attention_cache("jobA")

    assert get_cached_attention("jobA", 0, "T01") is None
    assert get_cached_attention("jobB", 0, "T01") is not None

    clear_attention_cache()
141
+
142
+
143
+ # ── Integration test for attention endpoint ──────────────────────
144
+
145
+
146
def _make_test_video(tmp_path, num_frames=5, width=64, height=48):
    """Create a tiny test video and return its path.

    Each frame is a solid gray level that rises by 40 per frame and is
    capped at 255, so any num_frames stays within uint8 range.

    Args:
        tmp_path: Directory (path-like supporting `/`) to write into.
        num_frames: Number of frames to write.
        width: Frame width in pixels.
        height: Frame height in pixels.

    Returns:
        Path of the written "test.mp4" as a string.
    """
    import cv2

    video_path = str(tmp_path / "test.mp4")
    writer = cv2.VideoWriter(
        video_path, cv2.VideoWriter_fourcc(*"mp4v"), 30, (width, height)
    )
    try:
        for i in range(num_frames):
            # Cap the gray level: i * 40 exceeds uint8 from the 7th frame on.
            level = min(i * 40, 255)
            frame = np.full((height, width, 3), level, dtype=np.uint8)
            writer.write(frame)
    finally:
        # Release the writer even if writing a frame fails.
        writer.release()
    return video_path
159
+
160
+
161
def test_attention_endpoint_json(tmp_path, monkeypatch):
    """GET /inspect/attention/{job_id}/{frame_idx}/{track_id}?format=json."""
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage, set_track_data

    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    job = JobInfo(
        job_id="test_attn_json",
        status=JobStatus.COMPLETED,
        mode="object_detection",
        queries=["person"],
        detector_name="yolo11",
        segmenter_name=None,
        input_video_path=video_path,
        output_video_path=None,
    )
    storage.create(job)

    # try/finally so a failing assertion cannot leave the job or cached
    # heatmaps behind for later tests.
    try:
        # Add track data
        set_track_data("test_attn_json", 0, [
            {"instance_id": 1, "label": "person", "bbox": [10, 10, 30, 30]},
        ])

        # Mock generate_attention_map to avoid loading real models
        def fake_generate(frame, bbox, det_name, job_id, frame_idx, track_id):
            h, w = frame.shape[:2]
            return np.random.rand(h, w).astype(np.float32)

        monkeypatch.setattr(
            "inspection.attention.generate_attention_map", fake_generate
        )

        from inspection.router import router
        from fastapi import FastAPI

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get("/inspect/attention/test_attn_json/0/T01?format=json")
        assert resp.status_code == 200

        data = resp.json()
        assert data["track_id"] == "T01"
        assert data["frame_idx"] == 0
        assert "width" in data
        assert "height" in data
        assert "data_b64" in data
        assert data["format"] == "float32"

        # Verify base64 decodes to correct size
        decoded = base64.b64decode(data["data_b64"])
        assert len(decoded) == data["width"] * data["height"] * 4
    finally:
        from inspection.attention import clear_attention_cache

        clear_attention_cache()
        storage.delete("test_attn_json")
222
+
223
+
224
def test_attention_endpoint_overlay(tmp_path, monkeypatch):
    """GET /inspect/attention/{job_id}/{frame_idx}/{track_id}?format=overlay."""
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage, set_track_data

    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    job = JobInfo(
        job_id="test_attn_overlay",
        status=JobStatus.COMPLETED,
        mode="object_detection",
        queries=["person"],
        detector_name="yolo11",
        segmenter_name=None,
        input_video_path=video_path,
        output_video_path=None,
    )
    storage.create(job)

    # try/finally so a failing assertion cannot leave the job or cached
    # heatmaps behind for later tests.
    try:
        set_track_data("test_attn_overlay", 0, [
            {"instance_id": 1, "label": "person", "bbox": [5, 5, 40, 35]},
        ])

        # Stand-in for the real generator so no model is loaded
        def fake_generate(frame, bbox, det_name, job_id, frame_idx, track_id):
            h, w = frame.shape[:2]
            return np.random.rand(h, w).astype(np.float32)

        monkeypatch.setattr(
            "inspection.attention.generate_attention_map", fake_generate
        )

        from inspection.router import router
        from fastapi import FastAPI

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get("/inspect/attention/test_attn_overlay/0/T01?format=overlay")
        assert resp.status_code == 200
        assert resp.headers["content-type"] == "image/jpeg"
        assert resp.content[:2] == b"\xff\xd8"
    finally:
        from inspection.attention import clear_attention_cache

        clear_attention_cache()
        storage.delete("test_attn_overlay")
273
+
274
+
275
def test_attention_endpoint_no_detector(tmp_path):
    """Attention for a job with no detector_name should return 400."""
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage

    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    job = JobInfo(
        job_id="test_attn_no_det",
        status=JobStatus.COMPLETED,
        mode="segmentation",
        queries=["object"],
        detector_name=None,
        segmenter_name="GSAM2-L",
        input_video_path=video_path,
        output_video_path=None,
    )
    storage.create(job)

    # try/finally so the job is removed even if the assertion fails.
    try:
        from inspection.router import router
        from fastapi import FastAPI

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get("/inspect/attention/test_attn_no_det/0/T01?format=json")
        assert resp.status_code == 400
    finally:
        storage.delete("test_attn_no_det")
308
+
309
+
310
def test_attention_endpoint_track_not_found(tmp_path):
    """Track not found in frame should return 404."""
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage, set_track_data

    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    job = JobInfo(
        job_id="test_attn_notrack",
        status=JobStatus.COMPLETED,
        mode="object_detection",
        queries=["person"],
        detector_name="yolo11",
        segmenter_name=None,
        input_video_path=video_path,
        output_video_path=None,
    )
    storage.create(job)

    # try/finally so the job is removed even if the assertion fails.
    try:
        # No track data for frame 0
        set_track_data("test_attn_notrack", 0, [])

        from inspection.router import router
        from fastapi import FastAPI

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get("/inspect/attention/test_attn_notrack/0/T01?format=json")
        assert resp.status_code == 404
    finally:
        storage.delete("test_attn_notrack")
346
+
347
+
348
def test_attention_endpoint_invalid_format(tmp_path):
    """Invalid format should return 400."""
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage

    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    job = JobInfo(
        job_id="test_attn_badfmt",
        status=JobStatus.COMPLETED,
        mode="object_detection",
        queries=["person"],
        detector_name="yolo11",
        segmenter_name=None,
        input_video_path=video_path,
        output_video_path=None,
    )
    storage.create(job)

    # try/finally so the job is removed even if the assertion fails.
    try:
        from inspection.router import router
        from fastapi import FastAPI

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get("/inspect/attention/test_attn_badfmt/0/T01?format=invalid")
        assert resp.status_code == 400
    finally:
        storage.delete("test_attn_badfmt")
tests/test_inspection_depth.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for inspection/depth.py — on-demand depth analysis."""
2
+
3
+ import base64
4
+ import struct
5
+
6
+ import numpy as np
7
+ import pytest
8
+
9
+
10
+ # ── Unit tests for depth module functions ────────────────────────
11
+
12
+
13
def test_compute_depth_stats():
    """compute_depth_stats should return min, max, mean, median."""
    from inspection.depth import compute_depth_stats

    stats = compute_depth_stats(
        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    )

    # Expected statistics of the values 1..4, checked key by key.
    expected = (
        ("min_m", 1.0),
        ("max_m", 4.0),
        ("mean_m", 2.5),
        ("median_m", 2.5),
    )
    for key, value in expected:
        assert stats[key] == pytest.approx(value)
24
+
25
+
26
def test_compute_depth_stats_uniform():
    """Stats on a uniform depth map."""
    from inspection.depth import compute_depth_stats

    uniform_value = 5.5
    stats = compute_depth_stats(
        np.full((10, 10), uniform_value, dtype=np.float32)
    )

    # Every statistic of a constant map equals that constant.
    for key in ("min_m", "max_m", "mean_m", "median_m"):
        assert stats[key] == pytest.approx(uniform_value)
37
+
38
+
39
def test_crop_depth_to_bbox():
    """crop_depth_to_bbox should extract the correct subregion."""
    from inspection.depth import crop_depth_to_bbox

    # 10x10 ramp, so depth[r, c] == r * 10 + c.
    depth = np.arange(100, dtype=np.float32).reshape(10, 10)
    x1, y1, x2, y2 = 2, 3, 5, 7

    cropped = crop_depth_to_bbox(depth, [x1, y1, x2, y2])

    # Shape is (rows, cols) == (y2 - y1, x2 - x1) == (4, 3).
    assert cropped.shape == (4, 3)
    # Top-left of the crop is depth[3, 2].
    assert cropped[0, 0] == pytest.approx(32.0)
49
+
50
+
51
def test_crop_depth_to_bbox_with_padding():
    """crop_depth_to_bbox with padding should expand the region."""
    from inspection.depth import crop_depth_to_bbox

    depth = np.arange(100, dtype=np.float32).reshape(10, 10)

    # The unpadded bbox [2, 3, 5, 7] is 4 rows by 3 columns; padding=0.5
    # must grow the crop beyond that in both dimensions.
    cropped = crop_depth_to_bbox(depth, [2, 3, 5, 7], padding=0.5)
    rows, cols = cropped.shape[0], cropped.shape[1]

    assert rows > 4
    assert cols > 3
62
+
63
+
64
def test_crop_depth_to_bbox_clamped():
    """Cropping near edges should clamp to image boundaries."""
    from inspection.depth import crop_depth_to_bbox

    depth = np.arange(100, dtype=np.float32).reshape(10, 10)

    # A bbox in the top-left corner with large padding would reach outside
    # the image if the crop were not clamped.
    cropped = crop_depth_to_bbox(depth, [0, 0, 2, 2], padding=1.0)
    rows, cols = cropped.shape[0], cropped.shape[1]

    assert rows >= 2
    assert cols >= 2
    # The crop must still start at the image origin, i.e. depth[0, 0].
    assert cropped[0, 0] == pytest.approx(0.0)
76
+
77
+
78
def test_depth_to_raw_bytes():
    """depth_to_raw_bytes should produce correct float32 bytes."""
    from inspection.depth import depth_to_raw_bytes

    raw = depth_to_raw_bytes(
        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    )

    # Four float32 values at four bytes each.
    assert len(raw) == 16
    # Decode as little-endian floats and compare in row-major order.
    unpacked = struct.unpack("<4f", raw)
    assert unpacked == pytest.approx((1.0, 2.0, 3.0, 4.0))
88
+
89
+
90
def test_depth_to_base64():
    """depth_to_base64 should produce decodable base64 float32 data."""
    from inspection.depth import depth_to_base64

    encoded = depth_to_base64(np.array([[1.5, 2.5]], dtype=np.float32))
    payload = base64.b64decode(encoded)

    # Two float32 values -> eight bytes.
    assert len(payload) == 8
    unpacked = struct.unpack("<2f", payload)
    assert unpacked == pytest.approx((1.5, 2.5))
101
+
102
+
103
def test_colorize_depth_returns_jpeg():
    """colorize_depth should return valid JPEG bytes.

    The input is drawn from a seeded generator so the test exercises the same
    depth map on every run and any failure is reproducible.
    """
    from inspection.depth import colorize_depth

    rng = np.random.default_rng(0)
    depth = rng.random((48, 64)).astype(np.float32) * 10.0
    jpeg_bytes = colorize_depth(depth)

    # JPEG magic bytes (start-of-image marker).
    assert jpeg_bytes[:2] == b"\xff\xd8"
    assert len(jpeg_bytes) > 100
113
+
114
+
115
def test_colorize_depth_uniform():
    """colorize_depth should handle uniform depth (no division by zero)."""
    from inspection.depth import colorize_depth

    # min == max here, so any range normalisation has a zero denominator.
    flat_depth = np.full((32, 32), 5.0, dtype=np.float32)

    encoded = colorize_depth(flat_depth)

    assert encoded[:2] == b"\xff\xd8"
122
+
123
+
124
+ # ── Cache tests ──────────────────────────────────────────────────
125
+
126
+
127
def test_cache_set_get():
    """Depth cache should store and retrieve depth maps.

    The cache is cleared before the test and again in a ``finally`` block, so
    entries written here never leak into other tests even when an assertion
    fails.
    """
    from inspection.depth import (
        get_cached_depth,
        set_cached_depth,
        clear_depth_cache,
    )

    clear_depth_cache()
    try:
        depth = np.ones((10, 10), dtype=np.float32)
        set_cached_depth("job1", 0, depth)

        result = get_cached_depth("job1", 0)
        assert result is not None
        np.testing.assert_array_equal(result, depth)

        # Different frame should return None
        assert get_cached_depth("job1", 1) is None
        # Different job should return None
        assert get_cached_depth("job2", 0) is None
    finally:
        clear_depth_cache()
150
+
151
+
152
def test_cache_clear_per_job():
    """clear_depth_cache(job_id) should only clear that job.

    Two jobs are cached, one is cleared by id, and the other must survive.
    The whole cache is reset in a ``finally`` block so a failing assertion
    cannot leave entries behind for later tests.
    """
    from inspection.depth import (
        get_cached_depth,
        set_cached_depth,
        clear_depth_cache,
    )

    clear_depth_cache()
    try:
        d1 = np.ones((5, 5), dtype=np.float32)
        d2 = np.ones((5, 5), dtype=np.float32) * 2

        set_cached_depth("jobA", 0, d1)
        set_cached_depth("jobB", 0, d2)

        clear_depth_cache("jobA")

        assert get_cached_depth("jobA", 0) is None
        assert get_cached_depth("jobB", 0) is not None
    finally:
        clear_depth_cache()
174
+
175
+
176
+ # ── Integration test for the endpoint (via TestClient) ───────────
177
+
178
+
179
def _make_test_video(tmp_path, num_frames=5, width=64, height=48):
    """Create a tiny test video and return its path.

    Args:
        tmp_path: Directory (``pathlib.Path``-like) in which ``test.mp4`` is
            written.
        num_frames: Number of solid-colour frames to write.
        width: Frame width in pixels.
        height: Frame height in pixels.

    Returns:
        The video path as a string.

    Raises:
        RuntimeError: If OpenCV cannot open the writer (for example when the
            ``mp4v`` codec is unavailable), instead of silently producing an
            unusable file.
    """
    import cv2

    video_path = str(tmp_path / "test.mp4")
    writer = cv2.VideoWriter(
        video_path, cv2.VideoWriter_fourcc(*"mp4v"), 30, (width, height)
    )
    if not writer.isOpened():
        writer.release()
        raise RuntimeError(f"Could not open video writer for {video_path}")

    try:
        for i in range(num_frames):
            # Frame brightness steps by 40 per frame; wrap into the uint8
            # range so more than 7 frames cannot overflow the dtype.
            frame = np.full((height, width, 3), (i * 40) % 256, dtype=np.uint8)
            writer.write(frame)
    finally:
        # Release even if writing fails so the file handle is not leaked.
        writer.release()
    return video_path
192
+
193
+
194
def test_depth_endpoint_raw(tmp_path, monkeypatch):
    """GET /inspect/depth/{job_id}/{frame_idx}?format=raw should return binary float32.

    A fake estimator replaces ``inspection.depth._depth_estimator`` to avoid
    loading the real model.  The response must be ``application/octet-stream``
    with the ``X-Depth-*`` headers, and its body must hold ``width * height``
    float32 values.  Cache and job cleanup run in ``finally`` so a failing
    assertion cannot leak state into other tests.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from inspection.depth import clear_depth_cache
    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage

    job_id = "test_depth_raw"

    # Create a test video
    video_path = _make_test_video(tmp_path)

    # Register a fake job
    storage = get_job_storage()
    storage.create(
        JobInfo(
            job_id=job_id,
            status=JobStatus.COMPLETED,
            mode="object_detection",
            queries=["person"],
            detector_name="yolo11",
            segmenter_name=None,
            input_video_path=video_path,
            output_video_path=None,
        )
    )

    try:
        # Mock the depth estimator to avoid loading the real model
        class FakeDepthResult:
            def __init__(self, h, w):
                self.depth_map = np.arange(h * w, dtype=np.float32).reshape(h, w)
                self.focal_length = 1.0

        class FakeEstimator:
            def predict(self, frame):
                h, w = frame.shape[:2]
                return FakeDepthResult(h, w)

        monkeypatch.setattr(
            "inspection.depth._depth_estimator", FakeEstimator()
        )

        # Import app after patching
        from inspection.router import router

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get("/inspect/depth/test_depth_raw/0?format=raw")
        assert resp.status_code == 200
        assert resp.headers["content-type"] == "application/octet-stream"
        for header in (
            "X-Depth-Width",
            "X-Depth-Height",
            "X-Depth-Min",
            "X-Depth-Max",
        ):
            assert header in resp.headers

        w = int(resp.headers["X-Depth-Width"])
        h = int(resp.headers["X-Depth-Height"])
        assert len(resp.content) == w * h * 4  # float32
    finally:
        # Cleanup
        clear_depth_cache()
        storage.delete(job_id)
258
+
259
+
260
def test_depth_endpoint_json(tmp_path, monkeypatch):
    """GET /inspect/depth/{job_id}/{frame_idx}?format=json should return proper JSON.

    The fake estimator returns a constant depth of 5.0, so the reported
    ``depth_stats`` min and max must both be 5.0 and the base64 payload must
    decode to ``width * height`` float32 values.  Cleanup runs in ``finally``
    so a failing assertion cannot leak the job or cached depth.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from inspection.depth import clear_depth_cache
    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage

    job_id = "test_depth_json"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    storage.create(
        JobInfo(
            job_id=job_id,
            status=JobStatus.COMPLETED,
            mode="object_detection",
            queries=["person"],
            detector_name="yolo11",
            segmenter_name=None,
            input_video_path=video_path,
            output_video_path=None,
        )
    )

    try:
        class FakeDepthResult:
            def __init__(self, h, w):
                self.depth_map = np.ones((h, w), dtype=np.float32) * 5.0
                self.focal_length = 1.0

        class FakeEstimator:
            def predict(self, frame):
                h, w = frame.shape[:2]
                return FakeDepthResult(h, w)

        monkeypatch.setattr("inspection.depth._depth_estimator", FakeEstimator())

        from inspection.router import router

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get("/inspect/depth/test_depth_json/0?format=json")
        assert resp.status_code == 200

        data = resp.json()
        for key in (
            "width",
            "height",
            "min_depth",
            "max_depth",
            "data_b64",
            "depth_stats",
        ):
            assert key in data
        assert data["depth_stats"]["min_m"] == pytest.approx(5.0)
        assert data["depth_stats"]["max_m"] == pytest.approx(5.0)

        # Verify base64 decodes to correct size
        decoded = base64.b64decode(data["data_b64"])
        assert len(decoded) == data["width"] * data["height"] * 4
    finally:
        clear_depth_cache()
        storage.delete(job_id)
321
+
322
+
323
def test_depth_endpoint_colorized(tmp_path, monkeypatch):
    """GET /inspect/depth/{job_id}/{frame_idx}?format=colorized should return JPEG.

    The fake estimator produces depth values from a seeded generator so each
    run sees the same input.  Cleanup runs in ``finally`` so a failing
    assertion cannot leak the job or cached depth into other tests.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from inspection.depth import clear_depth_cache
    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage

    job_id = "test_depth_color"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    storage.create(
        JobInfo(
            job_id=job_id,
            status=JobStatus.COMPLETED,
            mode="object_detection",
            queries=["person"],
            detector_name="yolo11",
            segmenter_name=None,
            input_video_path=video_path,
            output_video_path=None,
        )
    )

    try:
        rng = np.random.default_rng(0)

        class FakeDepthResult:
            def __init__(self, h, w):
                self.depth_map = rng.random((h, w)).astype(np.float32) * 10.0
                self.focal_length = 1.0

        class FakeEstimator:
            def predict(self, frame):
                h, w = frame.shape[:2]
                return FakeDepthResult(h, w)

        monkeypatch.setattr("inspection.depth._depth_estimator", FakeEstimator())

        from inspection.router import router

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get("/inspect/depth/test_depth_color/0?format=colorized")
        assert resp.status_code == 200
        assert resp.headers["content-type"] == "image/jpeg"
        assert resp.content[:2] == b"\xff\xd8"  # JPEG magic
    finally:
        clear_depth_cache()
        storage.delete(job_id)
372
+
373
+
374
def test_depth_endpoint_invalid_format(tmp_path, monkeypatch):
    """Invalid format should return 400.

    Registers a completed job and requests the depth endpoint with
    ``format=invalid``.  No estimator is patched in this test; the
    ``monkeypatch`` fixture is accepted but unused.  The job is deleted in a
    ``finally`` block so a failing assertion cannot leave it registered.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage

    job_id = "test_depth_bad_fmt"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    storage.create(
        JobInfo(
            job_id=job_id,
            status=JobStatus.COMPLETED,
            mode="object_detection",
            queries=["person"],
            detector_name="yolo11",
            segmenter_name=None,
            input_video_path=video_path,
            output_video_path=None,
        )
    )

    try:
        from inspection.router import router

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get("/inspect/depth/test_depth_bad_fmt/0?format=invalid")
        assert resp.status_code == 400
    finally:
        storage.delete(job_id)
407
+
408
+
409
def test_depth_endpoint_job_not_found():
    """Non-existent job should return 404."""
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from inspection.router import router

    # Bare app exposing only the inspection routes; no job is registered.
    application = FastAPI()
    application.include_router(router)

    response = TestClient(application).get(
        "/inspect/depth/nonexistent/0?format=json"
    )

    assert response.status_code == 404
422
+
423
+
424
def test_depth_endpoint_with_track_id(tmp_path, monkeypatch):
    """Depth with track_id should crop to bbox.

    Frame 0 carries one track with bbox ``[10, 10, 30, 30]``; requesting
    ``track_id=T01`` must yield a 20x20 result.  (Presumably ``T01`` resolves
    to ``instance_id`` 1 — the mapping lives in the router, not in this test.)
    Cleanup runs in ``finally`` so a failing assertion cannot leak the job or
    cached depth into other tests.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from inspection.depth import clear_depth_cache
    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage, set_track_data

    job_id = "test_depth_track"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    storage.create(
        JobInfo(
            job_id=job_id,
            status=JobStatus.COMPLETED,
            mode="object_detection",
            queries=["person"],
            detector_name="yolo11",
            segmenter_name=None,
            input_video_path=video_path,
            output_video_path=None,
        )
    )

    try:
        # Add track data for frame 0
        set_track_data(job_id, 0, [
            {"instance_id": 1, "label": "person", "bbox": [10, 10, 30, 30]},
        ])

        class FakeDepthResult:
            def __init__(self, h, w):
                self.depth_map = np.arange(h * w, dtype=np.float32).reshape(h, w)
                self.focal_length = 1.0

        class FakeEstimator:
            def predict(self, frame):
                h, w = frame.shape[:2]
                return FakeDepthResult(h, w)

        monkeypatch.setattr("inspection.depth._depth_estimator", FakeEstimator())

        from inspection.router import router

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get("/inspect/depth/test_depth_track/0?format=json&track_id=T01")
        assert resp.status_code == 200

        data = resp.json()
        # Cropped to bbox [10,10,30,30] => 20x20
        assert data["width"] == 20
        assert data["height"] == 20
    finally:
        clear_depth_cache()
        storage.delete(job_id)