Spaces:
Runtime error
feat(inspection): add depth analysis and attention heatmap endpoints (Phase 2)
Browse files
Backend:
- inspection/depth.py: on-demand DepthAnythingV2 inference with LRU caching,
viridis colorization, stats computation, raw/json/colorized response formats
- inspection/attention.py: GradCAM for DETR/GDINO, saliency for YOLO,
gaussian fallback, overlay generation, per-request caching
- inspection/router.py: GET /inspect/depth and GET /inspect/attention endpoints
- 28 new tests across depth and attention modules
Frontend:
- inspection-api.js: CORS fallback for depth binary format, explicit format params
- inspection-renders.js: depth legend with min/max meters, track depth stats,
attention legend with peak/avg intensity, improved alpha blending
- inspection.js: mode-specific loading messages
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- frontend/js/api/inspection-api.js +51 -15
- frontend/js/ui/inspection-renders.js +185 -7
- frontend/js/ui/inspection.js +24 -5
- inspection/attention.py +537 -0
- inspection/depth.py +212 -0
- inspection/router.py +198 -0
- tests/test_inspection_attention.py +380 -0
- tests/test_inspection_depth.py +480 -0
|
@@ -73,6 +73,7 @@ APP.api.inspection.generateMask = async function (jobId, frameIdx, trackId) {
|
|
| 73 |
*/
|
| 74 |
APP.api.inspection.fetchDepth = async function (jobId, frameIdx) {
|
| 75 |
const base = APP.core.state.hf.baseUrl;
|
|
|
|
| 76 |
const url = `${base}/inspect/depth/${jobId}/${frameIdx}?format=raw`;
|
| 77 |
const resp = await fetch(url);
|
| 78 |
if (!resp.ok) throw new Error(`Depth fetch failed: ${resp.status}`);
|
|
@@ -80,30 +81,63 @@ APP.api.inspection.fetchDepth = async function (jobId, frameIdx) {
|
|
| 80 |
const contentType = resp.headers.get("content-type") || "";
|
| 81 |
|
| 82 |
if (contentType.includes("application/octet-stream")) {
|
| 83 |
-
// Binary float32 format
|
| 84 |
const w = parseInt(resp.headers.get("X-Depth-Width"), 10);
|
| 85 |
const h = parseInt(resp.headers.get("X-Depth-Height"), 10);
|
| 86 |
const minD = parseFloat(resp.headers.get("X-Depth-Min"));
|
| 87 |
const maxD = parseFloat(resp.headers.get("X-Depth-Max"));
|
| 88 |
const buf = await resp.arrayBuffer();
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
} else {
|
| 91 |
// JSON + base64 format
|
| 92 |
const json = await resp.json();
|
| 93 |
-
|
| 94 |
-
const buf = new ArrayBuffer(raw.length);
|
| 95 |
-
const view = new Uint8Array(buf);
|
| 96 |
-
for (let i = 0; i < raw.length; i++) view[i] = raw.charCodeAt(i);
|
| 97 |
-
return {
|
| 98 |
-
width: json.width,
|
| 99 |
-
height: json.height,
|
| 100 |
-
min: json.min_depth,
|
| 101 |
-
max: json.max_depth,
|
| 102 |
-
data: new Float32Array(buf)
|
| 103 |
-
};
|
| 104 |
}
|
| 105 |
};
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
/**
|
| 108 |
* Fetch attention heatmap for a specific track on a specific frame.
|
| 109 |
* @param {string} jobId
|
|
@@ -113,7 +147,7 @@ APP.api.inspection.fetchDepth = async function (jobId, frameIdx) {
|
|
| 113 |
*/
|
| 114 |
APP.api.inspection.fetchAttention = async function (jobId, frameIdx, trackId) {
|
| 115 |
const base = APP.core.state.hf.baseUrl;
|
| 116 |
-
const url = `${base}/inspect/attention/${jobId}/${frameIdx}/${encodeURIComponent(trackId)}`;
|
| 117 |
const resp = await fetch(url);
|
| 118 |
if (!resp.ok) throw new Error(`Attention fetch failed: ${resp.status}`);
|
| 119 |
|
|
@@ -126,7 +160,9 @@ APP.api.inspection.fetchAttention = async function (jobId, frameIdx, trackId) {
|
|
| 126 |
return {
|
| 127 |
width: json.width,
|
| 128 |
height: json.height,
|
| 129 |
-
data: new Float32Array(buf)
|
|
|
|
|
|
|
| 130 |
};
|
| 131 |
};
|
| 132 |
|
|
|
|
| 73 |
*/
|
| 74 |
APP.api.inspection.fetchDepth = async function (jobId, frameIdx) {
|
| 75 |
const base = APP.core.state.hf.baseUrl;
|
| 76 |
+
// Try binary format first; fall back to JSON if CORS strips custom headers
|
| 77 |
const url = `${base}/inspect/depth/${jobId}/${frameIdx}?format=raw`;
|
| 78 |
const resp = await fetch(url);
|
| 79 |
if (!resp.ok) throw new Error(`Depth fetch failed: ${resp.status}`);
|
|
|
|
| 81 |
const contentType = resp.headers.get("content-type") || "";
|
| 82 |
|
| 83 |
if (contentType.includes("application/octet-stream")) {
|
| 84 |
+
// Binary float32 format with metadata in headers
|
| 85 |
const w = parseInt(resp.headers.get("X-Depth-Width"), 10);
|
| 86 |
const h = parseInt(resp.headers.get("X-Depth-Height"), 10);
|
| 87 |
const minD = parseFloat(resp.headers.get("X-Depth-Min"));
|
| 88 |
const maxD = parseFloat(resp.headers.get("X-Depth-Max"));
|
| 89 |
const buf = await resp.arrayBuffer();
|
| 90 |
+
const data = new Float32Array(buf);
|
| 91 |
+
|
| 92 |
+
// If CORS stripped the custom headers, width/height are unknown — refetch as JSON instead
|
| 93 |
+
if (isNaN(w) || isNaN(h)) {
|
| 94 |
+
// Fall back to JSON format
|
| 95 |
+
return await APP.api.inspection._fetchDepthJson(jobId, frameIdx);
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
return {
|
| 99 |
+
width: w,
|
| 100 |
+
height: h,
|
| 101 |
+
min: isNaN(minD) ? 0 : minD,
|
| 102 |
+
max: isNaN(maxD) ? 1 : maxD,
|
| 103 |
+
data: data
|
| 104 |
+
};
|
| 105 |
} else {
|
| 106 |
// JSON + base64 format
|
| 107 |
const json = await resp.json();
|
| 108 |
+
return APP.api.inspection._decodeDepthJson(json);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
}
|
| 110 |
};
|
| 111 |
|
| 112 |
+
/**
 * Fallback loader: request the depth map with `format=json` and decode it.
 *
 * Used when the raw binary response cannot be interpreted because its
 * metadata headers are unavailable.
 *
 * @param {string} jobId - Job identifier, interpolated into the URL path.
 * @param {number} frameIdx - Frame index, interpolated into the URL path.
 * @returns {Promise<object>} Result of `_decodeDepthJson` for the response body.
 * @throws {Error} When the HTTP response status is not OK.
 */
APP.api.inspection._fetchDepthJson = async function (jobId, frameIdx) {
  const endpoint = `${APP.core.state.hf.baseUrl}/inspect/depth/${jobId}/${frameIdx}?format=json`;
  const response = await fetch(endpoint);
  if (!response.ok) {
    throw new Error(`Depth (JSON) fetch failed: ${response.status}`);
  }
  const payload = await response.json();
  return APP.api.inspection._decodeDepthJson(payload);
};
|
| 123 |
+
|
| 124 |
+
/**
 * Convert a JSON depth response into the standard
 * `{ width, height, min, max, data }` shape.
 *
 * `json.data_b64` is base64-decoded byte-for-byte into a fresh buffer which
 * is then viewed as a Float32Array.
 *
 * @param {object} json - Parsed response with `data_b64`, `width`, `height`,
 *   `min_depth` and `max_depth` fields.
 * @returns {{width: *, height: *, min: *, max: *, data: Float32Array}}
 */
APP.api.inspection._decodeDepthJson = function (json) {
  // atob yields one character per byte (code units 0-255).
  const binary = atob(json.data_b64);
  const bytes = Uint8Array.from(binary, (ch) => ch.charCodeAt(0));
  const samples = new Float32Array(bytes.buffer);
  return {
    width: json.width,
    height: json.height,
    min: json.min_depth,
    max: json.max_depth,
    data: samples
  };
};
|
| 140 |
+
|
| 141 |
/**
|
| 142 |
* Fetch attention heatmap for a specific track on a specific frame.
|
| 143 |
* @param {string} jobId
|
|
|
|
| 147 |
*/
|
| 148 |
APP.api.inspection.fetchAttention = async function (jobId, frameIdx, trackId) {
|
| 149 |
const base = APP.core.state.hf.baseUrl;
|
| 150 |
+
const url = `${base}/inspect/attention/${jobId}/${frameIdx}/${encodeURIComponent(trackId)}?format=json`;
|
| 151 |
const resp = await fetch(url);
|
| 152 |
if (!resp.ok) throw new Error(`Attention fetch failed: ${resp.status}`);
|
| 153 |
|
|
|
|
| 160 |
return {
|
| 161 |
width: json.width,
|
| 162 |
height: json.height,
|
| 163 |
+
data: new Float32Array(buf),
|
| 164 |
+
trackId: json.track_id || trackId,
|
| 165 |
+
frameIdx: json.frame_idx || frameIdx
|
| 166 |
};
|
| 167 |
};
|
| 168 |
|
|
@@ -226,7 +226,7 @@ APP.ui.inspectionRenders._renderEdge = function (canvas, frameImg, edgeData, tra
|
|
| 226 |
};
|
| 227 |
|
| 228 |
/**
|
| 229 |
-
* Render depth colormap (viridis-like palette).
|
| 230 |
*/
|
| 231 |
APP.ui.inspectionRenders._renderDepth = function (canvas, frameImg, depthData, track) {
|
| 232 |
if (!frameImg) return;
|
|
@@ -265,7 +265,14 @@ APP.ui.inspectionRenders._renderDepth = function (canvas, frameImg, depthData, t
|
|
| 265 |
const id = img.data;
|
| 266 |
|
| 267 |
for (let i = 0; i < dd.length; i++) {
|
| 268 |
-
const
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
const rgb = APP.ui.inspectionRenders._viridis(t);
|
| 270 |
const oi = i * 4;
|
| 271 |
id[oi] = rgb[0];
|
|
@@ -279,11 +286,113 @@ APP.ui.inspectionRenders._renderDepth = function (canvas, frameImg, depthData, t
|
|
| 279 |
// Draw scaled to canvas size
|
| 280 |
ctx.drawImage(depthCanvas, 0, 0, w, h);
|
| 281 |
|
|
|
|
| 282 |
APP.ui.inspectionRenders._drawBBoxHighlight(ctx, track, w, h);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
};
|
| 284 |
|
| 285 |
/**
|
| 286 |
-
* Render attention heatmap overlaid on the base frame.
|
|
|
|
| 287 |
*/
|
| 288 |
APP.ui.inspectionRenders._renderAttention = function (canvas, frameImg, attentionData, track) {
|
| 289 |
if (!frameImg) return;
|
|
@@ -294,11 +403,13 @@ APP.ui.inspectionRenders._renderAttention = function (canvas, frameImg, attentio
|
|
| 294 |
canvas.height = h;
|
| 295 |
const ctx = canvas.getContext("2d");
|
| 296 |
|
| 297 |
-
// Draw base frame
|
| 298 |
ctx.drawImage(frameImg, 0, 0);
|
|
|
|
|
|
|
| 299 |
|
| 300 |
if (!attentionData || !attentionData.data) {
|
| 301 |
-
ctx.fillStyle = "rgba(0,0,0,0.
|
| 302 |
ctx.fillRect(0, 0, w, h);
|
| 303 |
ctx.fillStyle = "#aaa";
|
| 304 |
ctx.font = "14px monospace";
|
|
@@ -307,7 +418,7 @@ APP.ui.inspectionRenders._renderAttention = function (canvas, frameImg, attentio
|
|
| 307 |
return;
|
| 308 |
}
|
| 309 |
|
| 310 |
-
// Render attention as
|
| 311 |
const aw = attentionData.width;
|
| 312 |
const ah = attentionData.height;
|
| 313 |
const ad = attentionData.data;
|
|
@@ -319,20 +430,87 @@ APP.ui.inspectionRenders._renderAttention = function (canvas, frameImg, attentio
|
|
| 319 |
const img = hctx.createImageData(aw, ah);
|
| 320 |
const id = img.data;
|
| 321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
for (let i = 0; i < ad.length; i++) {
|
| 323 |
const t = APP.core.utils.clamp(ad[i], 0, 1);
|
|
|
|
|
|
|
| 324 |
const rgb = APP.ui.inspectionRenders._inferno(t);
|
| 325 |
const oi = i * 4;
|
| 326 |
id[oi] = rgb[0];
|
| 327 |
id[oi + 1] = rgb[1];
|
| 328 |
id[oi + 2] = rgb[2];
|
| 329 |
-
|
|
|
|
|
|
|
| 330 |
}
|
| 331 |
|
| 332 |
hctx.putImageData(img, 0, 0);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
ctx.drawImage(heatCanvas, 0, 0, w, h);
|
| 334 |
|
|
|
|
| 335 |
APP.ui.inspectionRenders._drawBBoxHighlight(ctx, track, w, h);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
};
|
| 337 |
|
| 338 |
/**
|
|
|
|
| 226 |
};
|
| 227 |
|
| 228 |
/**
|
| 229 |
+
* Render depth colormap (viridis-like palette) with scale legend.
|
| 230 |
*/
|
| 231 |
APP.ui.inspectionRenders._renderDepth = function (canvas, frameImg, depthData, track) {
|
| 232 |
if (!frameImg) return;
|
|
|
|
| 265 |
const id = img.data;
|
| 266 |
|
| 267 |
for (let i = 0; i < dd.length; i++) {
|
| 268 |
+
const val = dd[i];
|
| 269 |
+
// Handle NaN/Inf values — render as black
|
| 270 |
+
if (!isFinite(val)) {
|
| 271 |
+
const oi = i * 4;
|
| 272 |
+
id[oi] = 0; id[oi + 1] = 0; id[oi + 2] = 0; id[oi + 3] = 255;
|
| 273 |
+
continue;
|
| 274 |
+
}
|
| 275 |
+
const t = APP.core.utils.clamp((val - minD) / range, 0, 1);
|
| 276 |
const rgb = APP.ui.inspectionRenders._viridis(t);
|
| 277 |
const oi = i * 4;
|
| 278 |
id[oi] = rgb[0];
|
|
|
|
| 286 |
// Draw scaled to canvas size
|
| 287 |
ctx.drawImage(depthCanvas, 0, 0, w, h);
|
| 288 |
|
| 289 |
+
// Draw bbox highlight for the selected track
|
| 290 |
APP.ui.inspectionRenders._drawBBoxHighlight(ctx, track, w, h);
|
| 291 |
+
|
| 292 |
+
// Draw depth scale legend (vertical gradient bar on the right)
|
| 293 |
+
APP.ui.inspectionRenders._drawDepthLegend(ctx, w, h, minD, maxD);
|
| 294 |
+
|
| 295 |
+
// If track is selected, compute and show average depth in the bbox region
|
| 296 |
+
if (track && track.bbox && dw > 0 && dh > 0) {
|
| 297 |
+
APP.ui.inspectionRenders._drawTrackDepthStats(ctx, track, w, h, depthData);
|
| 298 |
+
}
|
| 299 |
+
};
|
| 300 |
+
|
| 301 |
+
/**
 * Draw a vertical depth scale legend near the right edge of the canvas.
 *
 * Paints a translucent panel, a viridis gradient bar (top = minD, bottom =
 * maxD), min/max labels suffixed with "m" and a "DEPTH" title. Nothing is
 * drawn when the canvas is too short to fit a bar at least two pixels tall.
 *
 * @param {CanvasRenderingContext2D} ctx - Target context; its fill, stroke,
 *   font and textAlign state is modified.
 * @param {number} canvasW - Canvas width in pixels.
 * @param {number} canvasH - Canvas height in pixels.
 * @param {number} minD - Depth value labelled at the top of the bar.
 * @param {number} maxD - Depth value labelled at the bottom of the bar.
 */
APP.ui.inspectionRenders._drawDepthLegend = function (ctx, canvasW, canvasH, minD, maxD) {
  const barW = 16;
  const barH = Math.min(180, canvasH - 60);
  // A bar shorter than 2px would mean a negative-size rectangle or a 0/0
  // gradient step below; also rejects NaN heights.
  if (!(barH >= 2)) return;

  const x = canvasW - barW - 30;
  const y = 30;

  // Missing or non-numeric depths (e.g. null from a JSON response) must not
  // throw from toFixed and abort the rest of the render.
  const formatDepth = (d) => (Number.isFinite(d) ? `${d.toFixed(1)}m` : "n/a");

  // Background panel
  ctx.fillStyle = "rgba(0, 0, 0, 0.55)";
  ctx.fillRect(x - 6, y - 20, barW + 52, barH + 40);
  ctx.strokeStyle = "rgba(255, 255, 255, 0.15)";
  ctx.lineWidth = 1;
  ctx.strokeRect(x - 6, y - 20, barW + 52, barH + 40);

  // Gradient bar, one pixel row at a time: t = 0 at the top, 1 at the bottom
  for (let py = 0; py < barH; py++) {
    const t = py / (barH - 1);
    const rgb = APP.ui.inspectionRenders._viridis(t);
    ctx.fillStyle = `rgb(${rgb[0]}, ${rgb[1]}, ${rgb[2]})`;
    ctx.fillRect(x, y + py, barW, 1);
  }

  // Border around bar
  ctx.strokeStyle = "rgba(255, 255, 255, 0.3)";
  ctx.lineWidth = 1;
  ctx.strokeRect(x, y, barW, barH);

  // Labels
  ctx.fillStyle = "rgba(255, 255, 255, 0.8)";
  ctx.font = "10px monospace";
  ctx.textAlign = "left";
  ctx.fillText(formatDepth(minD), x + barW + 4, y + 10);
  ctx.fillText(formatDepth(maxD), x + barW + 4, y + barH);

  // Title
  ctx.fillStyle = "rgba(255, 255, 255, 0.6)";
  ctx.font = "9px monospace";
  ctx.textAlign = "center";
  ctx.fillText("DEPTH", x + barW / 2, y - 6);
};
|
| 343 |
+
|
| 344 |
+
/**
 * Summarise depth inside the selected track's bounding box and draw the
 * result as a label just below the box.
 *
 * The bbox fields (x, y, w, h) are scaled by the depth-map size to pick the
 * sample region and by the canvas size to place the label, i.e. they are
 * treated as fractions of the frame. Non-finite samples are ignored; if no
 * finite sample is found nothing is drawn.
 *
 * @param {CanvasRenderingContext2D} ctx - Target 2D context.
 * @param {object} track - Track whose `bbox` has `x`, `y`, `w`, `h`.
 * @param {number} canvasW - Canvas width in pixels.
 * @param {number} canvasH - Canvas height in pixels.
 * @param {object} depthData - `{ width, height, data }` depth map.
 */
APP.ui.inspectionRenders._drawTrackDepthStats = function (ctx, track, canvasW, canvasH, depthData) {
  const box = track.bbox;
  const mapW = depthData.width;
  const mapH = depthData.height;
  const samples = depthData.data;

  // Sample region in depth-map pixels, clamped to the map bounds
  const left = Math.max(0, Math.floor(box.x * mapW));
  const top = Math.max(0, Math.floor(box.y * mapH));
  const right = Math.min(mapW - 1, Math.floor((box.x + box.w) * mapW));
  const bottom = Math.min(mapH - 1, Math.floor((box.y + box.h) * mapH));

  let total = 0;
  let used = 0;
  let lo = Infinity;
  let hi = -Infinity;

  for (let row = top; row <= bottom; row++) {
    const rowStart = row * mapW;
    for (let col = left; col <= right; col++) {
      const d = samples[rowStart + col];
      if (!isFinite(d)) continue;
      total += d;
      used++;
      if (d < lo) lo = d;
      if (d > hi) hi = d;
    }
  }

  if (used === 0) return;

  const mean = total / used;

  // Label anchor: bottom-left corner of the bbox in canvas pixels
  const labelX = box.x * canvasW;
  const labelY = (box.y + box.h) * canvasH;

  ctx.fillStyle = "rgba(0, 0, 0, 0.7)";
  ctx.fillRect(labelX, labelY + 4, 160, 18);

  ctx.fillStyle = "rgba(253, 231, 37, 0.9)"; // viridis yellow for contrast
  ctx.font = "bold 11px monospace";
  ctx.textAlign = "left";
  ctx.fillText(`Depth: ${mean.toFixed(1)}m (${lo.toFixed(1)}-${hi.toFixed(1)})`, labelX + 4, labelY + 16);
};
|
| 392 |
|
| 393 |
/**
|
| 394 |
+
* Render attention heatmap (GradCAM) overlaid on the base frame.
|
| 395 |
+
* Uses inferno colormap with semi-transparent blending.
|
| 396 |
*/
|
| 397 |
APP.ui.inspectionRenders._renderAttention = function (canvas, frameImg, attentionData, track) {
|
| 398 |
if (!frameImg) return;
|
|
|
|
| 403 |
canvas.height = h;
|
| 404 |
const ctx = canvas.getContext("2d");
|
| 405 |
|
| 406 |
+
// Draw base frame (slightly dimmed to make heatmap more visible)
|
| 407 |
ctx.drawImage(frameImg, 0, 0);
|
| 408 |
+
ctx.fillStyle = "rgba(0, 0, 0, 0.2)";
|
| 409 |
+
ctx.fillRect(0, 0, w, h);
|
| 410 |
|
| 411 |
if (!attentionData || !attentionData.data) {
|
| 412 |
+
ctx.fillStyle = "rgba(0,0,0,0.5)";
|
| 413 |
ctx.fillRect(0, 0, w, h);
|
| 414 |
ctx.fillStyle = "#aaa";
|
| 415 |
ctx.font = "14px monospace";
|
|
|
|
| 418 |
return;
|
| 419 |
}
|
| 420 |
|
| 421 |
+
// Render attention as inferno colormap overlay
|
| 422 |
const aw = attentionData.width;
|
| 423 |
const ah = attentionData.height;
|
| 424 |
const ad = attentionData.data;
|
|
|
|
| 430 |
const img = hctx.createImageData(aw, ah);
|
| 431 |
const id = img.data;
|
| 432 |
|
| 433 |
+
// Track peak attention value for stats
|
| 434 |
+
let peakVal = 0;
|
| 435 |
+
let sumVal = 0;
|
| 436 |
+
|
| 437 |
for (let i = 0; i < ad.length; i++) {
|
| 438 |
const t = APP.core.utils.clamp(ad[i], 0, 1);
|
| 439 |
+
if (t > peakVal) peakVal = t;
|
| 440 |
+
sumVal += t;
|
| 441 |
const rgb = APP.ui.inspectionRenders._inferno(t);
|
| 442 |
const oi = i * 4;
|
| 443 |
id[oi] = rgb[0];
|
| 444 |
id[oi + 1] = rgb[1];
|
| 445 |
id[oi + 2] = rgb[2];
|
| 446 |
+
// Alpha: low values are very transparent, high values are semi-opaque
|
| 447 |
+
// Use a power curve for better visual contrast
|
| 448 |
+
id[oi + 3] = Math.round(Math.pow(t, 0.7) * 200);
|
| 449 |
}
|
| 450 |
|
| 451 |
hctx.putImageData(img, 0, 0);
|
| 452 |
+
|
| 453 |
+
// Use bilinear upscaling for smooth heatmap (attention resolution is typically low)
|
| 454 |
+
ctx.imageSmoothingEnabled = true;
|
| 455 |
+
ctx.imageSmoothingQuality = "high";
|
| 456 |
ctx.drawImage(heatCanvas, 0, 0, w, h);
|
| 457 |
|
| 458 |
+
// Draw bbox highlight
|
| 459 |
APP.ui.inspectionRenders._drawBBoxHighlight(ctx, track, w, h);
|
| 460 |
+
|
| 461 |
+
// Draw attention legend and stats
|
| 462 |
+
const avgVal = ad.length > 0 ? sumVal / ad.length : 0;
|
| 463 |
+
APP.ui.inspectionRenders._drawAttentionLegend(ctx, w, h, peakVal, avgVal);
|
| 464 |
+
};
|
| 465 |
+
|
| 466 |
+
/**
 * Draw an attention intensity legend with an inferno colormap bar.
 *
 * Paints a translucent panel, a gradient bar (top = high, bottom = low),
 * "High"/"Low" labels, an "ATTENTION" title and peak/average percentages
 * below the bar. Nothing is drawn when the canvas is too short to fit a bar
 * at least two pixels tall.
 *
 * @param {CanvasRenderingContext2D} ctx - Target context; its fill, stroke,
 *   font and textAlign state is modified.
 * @param {number} canvasW - Canvas width in pixels.
 * @param {number} canvasH - Canvas height in pixels.
 * @param {number} peakVal - Peak attention value; shown multiplied by 100 as a percentage.
 * @param {number} avgVal - Average attention value; shown multiplied by 100 as a percentage.
 */
APP.ui.inspectionRenders._drawAttentionLegend = function (ctx, canvasW, canvasH, peakVal, avgVal) {
  const barW = 16;
  const barH = Math.min(140, canvasH - 60);
  // A bar shorter than 2px would mean a negative-size rectangle or a 0/0
  // gradient step below; also rejects NaN heights.
  if (!(barH >= 2)) return;

  const x = canvasW - barW - 30;
  const y = 30;

  // Background panel (taller than the bar to leave room for the stats lines)
  ctx.fillStyle = "rgba(0, 0, 0, 0.55)";
  ctx.fillRect(x - 6, y - 20, barW + 60, barH + 56);
  ctx.strokeStyle = "rgba(255, 255, 255, 0.15)";
  ctx.lineWidth = 1;
  ctx.strokeRect(x - 6, y - 20, barW + 60, barH + 56);

  // Gradient bar, one pixel row at a time: t = 1 at the top, 0 at the bottom
  for (let py = 0; py < barH; py++) {
    const t = 1 - (py / (barH - 1));
    const rgb = APP.ui.inspectionRenders._inferno(t);
    ctx.fillStyle = `rgb(${rgb[0]}, ${rgb[1]}, ${rgb[2]})`;
    ctx.fillRect(x, y + py, barW, 1);
  }

  // Border around bar
  ctx.strokeStyle = "rgba(255, 255, 255, 0.3)";
  ctx.lineWidth = 1;
  ctx.strokeRect(x, y, barW, barH);

  // Labels
  ctx.fillStyle = "rgba(255, 255, 255, 0.8)";
  ctx.font = "10px monospace";
  ctx.textAlign = "left";
  ctx.fillText("High", x + barW + 4, y + 10);
  ctx.fillText("Low", x + barW + 4, y + barH);

  // Title
  ctx.fillStyle = "rgba(255, 255, 255, 0.6)";
  ctx.font = "9px monospace";
  ctx.textAlign = "center";
  ctx.fillText("ATTENTION", x + barW / 2, y - 6);

  // Stats
  ctx.fillStyle = "rgba(252, 255, 164, 0.8)"; // inferno yellow
  ctx.font = "10px monospace";
  ctx.textAlign = "left";
  ctx.fillText(`Peak: ${(peakVal * 100).toFixed(0)}%`, x - 2, y + barH + 18);
  ctx.fillText(`Avg: ${(avgVal * 100).toFixed(0)}%`, x - 2, y + barH + 32);
};
|
| 515 |
|
| 516 |
/**
|
|
@@ -206,12 +206,31 @@ APP.ui.inspection._clearCaches = function () {
|
|
| 206 |
};
|
| 207 |
|
| 208 |
/**
|
| 209 |
-
* Internal: show/hide loading indicator.
|
|
|
|
|
|
|
| 210 |
*/
|
| 211 |
-
APP.ui.inspection._setLoading = function (loading) {
|
| 212 |
const { $ } = APP.core.utils;
|
| 213 |
const el = $("#inspectionLoading");
|
| 214 |
-
if (el)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
APP.core.state.inspection.loading = loading;
|
| 216 |
};
|
| 217 |
|
|
@@ -254,7 +273,7 @@ APP.ui.inspection._loadAndRender = async function () {
|
|
| 254 |
try {
|
| 255 |
// --- Step 1: Ensure we have the base frame image (shared by most modes) ---
|
| 256 |
if (!state.inspection._frameImg && mode !== "3d") {
|
| 257 |
-
APP.ui.inspection._setLoading(true);
|
| 258 |
const frameImg = await api.fetchFrame(jobId, frameIdx);
|
| 259 |
state.inspection.frameImageUrl = frameImg.src;
|
| 260 |
state.inspection._frameImg = frameImg;
|
|
@@ -262,7 +281,7 @@ APP.ui.inspection._loadAndRender = async function () {
|
|
| 262 |
|
| 263 |
// --- Step 2: Fetch mode-specific data if not cached ---
|
| 264 |
if (!cache[mode]) {
|
| 265 |
-
APP.ui.inspection._setLoading(true);
|
| 266 |
|
| 267 |
switch (mode) {
|
| 268 |
case "seg":
|
|
|
|
| 206 |
};
|
| 207 |
|
| 208 |
/**
 * Internal: toggle the loading overlay and keep its caption in sync.
 *
 * The `#inspectionLoading` element (if present) is shown as a flex box while
 * loading and hidden otherwise. Its first `<span>` gets a mode-specific
 * caption when loading starts with a mode, and is reset to "Loading..." when
 * loading ends. The flag is always mirrored into
 * `APP.core.state.inspection.loading`.
 *
 * @param {boolean} loading - Whether a load is in progress.
 * @param {string} [mode] - Inspection mode used to pick the caption.
 */
APP.ui.inspection._setLoading = function (loading, mode) {
  const { $ } = APP.core.utils;
  const overlay = $("#inspectionLoading");

  if (overlay) {
    overlay.style.display = loading ? "flex" : "none";

    const caption = overlay.querySelector("span");
    if (caption) {
      if (!loading) {
        // Reset so the next load without a mode shows the generic text
        caption.textContent = "Loading...";
      } else if (mode) {
        const captionsByMode = {
          seg: "Computing segmentation mask...",
          edge: "Computing edges...",
          depth: "Computing depth map...",
          attention: "Generating attention heatmap...",
          superres: "Enhancing resolution...",
          "3d": "Generating point cloud..."
        };
        caption.textContent = captionsByMode[mode] || "Loading...";
      }
    }
  }

  APP.core.state.inspection.loading = loading;
};
|
| 236 |
|
|
|
|
| 273 |
try {
|
| 274 |
// --- Step 1: Ensure we have the base frame image (shared by most modes) ---
|
| 275 |
if (!state.inspection._frameImg && mode !== "3d") {
|
| 276 |
+
APP.ui.inspection._setLoading(true, mode);
|
| 277 |
const frameImg = await api.fetchFrame(jobId, frameIdx);
|
| 278 |
state.inspection.frameImageUrl = frameImg.src;
|
| 279 |
state.inspection._frameImg = frameImg;
|
|
|
|
| 281 |
|
| 282 |
// --- Step 2: Fetch mode-specific data if not cached ---
|
| 283 |
if (!cache[mode]) {
|
| 284 |
+
APP.ui.inspection._setLoading(true, mode);
|
| 285 |
|
| 286 |
switch (mode) {
|
| 287 |
case "seg":
|
|
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GradCAM-style attention heatmap generation for detector models.
|
| 2 |
+
|
| 3 |
+
Produces per-object attention maps showing which regions of the input
|
| 4 |
+
image the detector model focused on when detecting a particular object.
|
| 5 |
+
|
| 6 |
+
For Transformers-based detectors (DETR, Grounding DINO) we use true
|
| 7 |
+
GradCAM by hooking the backbone's last feature layer. For Ultralytics
|
| 8 |
+
YOLO models we generate an activation-based saliency map from the
|
| 9 |
+
model's internal feature maps (no gradient needed since YOLO doesn't
|
| 10 |
+
easily support GradCAM due to its anchor-free detection head).
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import base64
|
| 14 |
+
import logging
|
| 15 |
+
import threading
|
| 16 |
+
from typing import Dict, Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import cv2
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
# ── In-memory attention cache ────────────────────────────────────
|
| 25 |
+
# Key: (job_id, frame_idx, track_id_str) Value: heatmap (HxW float32 0-1)
|
| 26 |
+
_attention_cache: Dict[Tuple[str, int, str], np.ndarray] = {}
|
| 27 |
+
_cache_lock = threading.RLock()
|
| 28 |
+
_MAX_CACHE_ENTRIES = 200
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_cached_attention(
    job_id: str, frame_idx: int, track_id: str
) -> Optional[np.ndarray]:
    """Look up a previously stored attention heatmap.

    Args:
        job_id: Job identifier part of the cache key.
        frame_idx: Frame index part of the cache key.
        track_id: Track identifier part of the cache key.

    Returns:
        The cached heatmap, or ``None`` when nothing is stored for the key.
    """
    cache_key = (job_id, frame_idx, track_id)
    # The lookup happens under the shared cache lock.
    with _cache_lock:
        cached = _attention_cache.get(cache_key)
    return cached
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def set_cached_attention(
    job_id: str, frame_idx: int, track_id: str, heatmap: np.ndarray
) -> None:
    """Store an attention heatmap in the bounded in-memory cache.

    When the cache is full, entries are evicted in insertion order (oldest
    first) until there is room. Re-storing a key that is already present
    replaces its value without evicting any other entry and makes it the
    newest entry.

    Args:
        job_id: Job identifier part of the cache key.
        frame_idx: Frame index part of the cache key.
        track_id: Track identifier part of the cache key.
        heatmap: Heatmap to store.
    """
    cache_key = (job_id, frame_idx, track_id)
    with _cache_lock:
        # Remove a stale copy first so an overwrite neither counts against
        # the size limit nor pushes out an unrelated entry.
        _attention_cache.pop(cache_key, None)
        # `while` (not `if`) also copes with a limit that was lowered; the
        # emptiness check avoids StopIteration when the limit is <= 0.
        while _attention_cache and len(_attention_cache) >= _MAX_CACHE_ENTRIES:
            oldest = next(iter(_attention_cache))
            del _attention_cache[oldest]
        _attention_cache[cache_key] = heatmap
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def clear_attention_cache(job_id: Optional[str] = None) -> None:
    """Drop cached attention heatmaps.

    Args:
        job_id: When given, only entries whose key starts with this job id
            are removed; when ``None``, the whole cache is emptied.
    """
    with _cache_lock:
        if job_id is not None:
            # Collect first: the dict must not change size while iterated.
            doomed = [key for key in _attention_cache if key[0] == job_id]
            for key in doomed:
                del _attention_cache[key]
            return
        _attention_cache.clear()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ── GradCAM for HF Transformers models (DETR, Grounding DINO) ───
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _find_target_layer(model: torch.nn.Module) -> Optional[torch.nn.Module]:
|
| 65 |
+
"""Find the last convolutional or attention layer suitable for GradCAM.
|
| 66 |
+
|
| 67 |
+
Tries several strategies in order:
|
| 68 |
+
1. DETR ResNet backbone: model.model.backbone.conv_encoder.model.layer4
|
| 69 |
+
2. Grounding DINO Swin backbone: last layer of backbone
|
| 70 |
+
3. Generic: walk the model and find the last Conv2d layer
|
| 71 |
+
"""
|
| 72 |
+
# Strategy 1: DETR backbone (ResNet)
|
| 73 |
+
try:
|
| 74 |
+
backbone = model.model.backbone
|
| 75 |
+
if hasattr(backbone, "conv_encoder"):
|
| 76 |
+
resnet = backbone.conv_encoder.model
|
| 77 |
+
if hasattr(resnet, "layer4"):
|
| 78 |
+
return resnet.layer4
|
| 79 |
+
except (AttributeError, TypeError):
|
| 80 |
+
pass
|
| 81 |
+
|
| 82 |
+
# Strategy 2: Grounding DINO Swin backbone
|
| 83 |
+
try:
|
| 84 |
+
backbone = model.model.backbone
|
| 85 |
+
if hasattr(backbone, "conv_encoder"):
|
| 86 |
+
swin = backbone.conv_encoder.model
|
| 87 |
+
if hasattr(swin, "layers"):
|
| 88 |
+
layers = list(swin.layers)
|
| 89 |
+
if layers:
|
| 90 |
+
return layers[-1]
|
| 91 |
+
if hasattr(swin, "encoder") and hasattr(swin.encoder, "layers"):
|
| 92 |
+
layers = list(swin.encoder.layers)
|
| 93 |
+
if layers:
|
| 94 |
+
return layers[-1]
|
| 95 |
+
except (AttributeError, TypeError):
|
| 96 |
+
pass
|
| 97 |
+
|
| 98 |
+
# Strategy 3: Generic — find the last Conv2d
|
| 99 |
+
last_conv = None
|
| 100 |
+
for module in model.modules():
|
| 101 |
+
if isinstance(module, torch.nn.Conv2d):
|
| 102 |
+
last_conv = module
|
| 103 |
+
return last_conv
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class GradCAMExtractor:
    """Capture a layer's activations and turn them into a bbox-focused heatmap.

    Despite the name, no backward pass is run: ``generate`` scores each
    spatial position by the channel-wise L2 norm of the hooked layer's output.
    A backward hook is still registered and stores the last gradient it sees
    in ``_gradients``, but nothing in this class reads that value.

    Usage:
        extractor = GradCAMExtractor(model, target_layer)
        try:
            heatmap = extractor.generate(inputs, target_bbox, frame_h, frame_w)
        finally:
            extractor.release()  # remove hooks
    """

    def __init__(self, model: torch.nn.Module, target_layer: torch.nn.Module):
        """Attach forward/backward hooks to ``target_layer``.

        Args:
            model: Network that will be run by ``generate``.
            target_layer: Module of ``model`` whose output is captured.
        """
        self.model = model
        self.target_layer = target_layer
        self._activations: Optional[torch.Tensor] = None
        self._gradients: Optional[torch.Tensor] = None

        # Register hooks
        self._fwd_hook = target_layer.register_forward_hook(self._save_activation)
        self._bwd_hook = target_layer.register_full_backward_hook(self._save_gradient)

    def _save_activation(self, module, input, output):
        """Forward hook: keep a detached copy of the layer output.

        When the layer returns a tuple/list, its first element is kept.
        """
        if isinstance(output, torch.Tensor):
            self._activations = output.detach()
        elif isinstance(output, (tuple, list)) and len(output) > 0:
            self._activations = output[0].detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: keep a detached copy of the first output gradient."""
        if isinstance(grad_output, (tuple, list)) and len(grad_output) > 0:
            self._gradients = grad_output[0].detach()
        elif isinstance(grad_output, torch.Tensor):
            self._gradients = grad_output.detach()

    @staticmethod
    def _activation_saliency(acts: torch.Tensor) -> np.ndarray:
        """Collapse captured activations into a 2-D float32 map in [0, 1].

        Only the first batch element is used. A constant map normalizes to
        all zeros.
        """
        if acts.dim() == 4:
            # (B, C, H, W): standard conv feature map, norm over channels.
            cam = torch.norm(acts[0], dim=0)
        elif acts.dim() == 3:
            # (B, N, C): token sequence. Fold it back into a square grid when
            # N is a perfect square, otherwise keep it as a single row.
            token_norms = torch.norm(acts[0], dim=1)
            n_tokens = token_norms.shape[0]
            side = int(n_tokens ** 0.5)
            if side * side == n_tokens:
                cam = token_norms.view(side, side)
            else:
                cam = token_norms.unsqueeze(0)
        else:
            # Any other rank collapses to one scalar, i.e. a uniform map.
            cam = torch.norm(acts.flatten(), dim=0, keepdim=True)

        cam = cam.float()
        cam_min = cam.min()
        cam_max = cam.max()
        if (cam_max - cam_min) > 1e-8:
            cam = (cam - cam_min) / (cam_max - cam_min)
        else:
            cam = torch.zeros_like(cam)

        cam_np = cam.cpu().numpy()
        if cam_np.ndim == 1:
            # Only the scalar fallback above is 1-D; the resize needs 2-D.
            cam_np = cam_np.reshape(1, -1)
        return cam_np

    @staticmethod
    def _soft_bbox_mask(target_bbox: list, frame_h: int, frame_w: int) -> np.ndarray:
        """Build a frame-sized weight mask that emphasizes the bbox.

        Weights are 1.0 inside the bbox, 0.3 in a surrounding margin of half
        the bbox's longer side, and 0.0 elsewhere.
        """
        x1, y1, x2, y2 = [int(c) for c in target_bbox]
        # Clamp both ends into the frame so a negative coordinate can never be
        # interpreted as a from-the-end slice index.
        x1 = min(max(0, x1), frame_w)
        y1 = min(max(0, y1), frame_h)
        x2 = min(max(0, x2), frame_w)
        y2 = min(max(0, y2), frame_h)

        mask = np.zeros((frame_h, frame_w), dtype=np.float32)
        mask[y1:y2, x1:x2] = 1.0

        # Expand mask slightly for context
        pad = max(x2 - x1, y2 - y1, 0) // 2
        py1 = max(0, y1 - pad)
        py2 = min(frame_h, y2 + pad)
        px1 = max(0, x1 - pad)
        px2 = min(frame_w, x2 + pad)
        mask[py1:py2, px1:px2] = np.maximum(mask[py1:py2, px1:px2], 0.3)
        return mask

    def generate(
        self,
        input_tensor: dict,
        target_bbox: list,
        frame_h: int,
        frame_w: int,
    ) -> np.ndarray:
        """Generate an activation-based heatmap focused on a bounding box.

        The model is run once in eval mode; every module's train/eval flag is
        put back afterwards, even if the forward pass raises.

        Args:
            input_tensor: Mapping of keyword arguments for the model's forward
                call (it is unpacked with ``**``).
            target_bbox: [x1, y1, x2, y2] in original frame pixel coords.
            frame_h: Original frame height.
            frame_w: Original frame width.

        Returns:
            ``frame_h`` x ``frame_w`` float32 array with values in [0, 1]. If
            the hooked layer produced no output during the forward pass, a
            uniform 0.5 map is returned instead.
        """
        self.model.zero_grad()
        self._activations = None
        self._gradients = None

        # Remember each module's own flag: restoring with a blanket
        # ``model.train()`` would also flip submodules that were deliberately
        # left in eval mode.
        saved_modes = [(module, module.training) for module in self.model.modules()]
        self.model.eval()
        try:
            with torch.enable_grad():
                self.model(**input_tensor)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        if self._activations is None:
            logger.warning("GradCAM: no activations captured; returning uniform map")
            return np.ones((frame_h, frame_w), dtype=np.float32) * 0.5

        # Use the activation map directly as a saliency proxy: gradient-based
        # targeting is unreliable for detection architectures, whose loss
        # needs complex target matching.
        cam_np = self._activation_saliency(self._activations)

        # Upscale to frame resolution
        cam_resized = cv2.resize(
            cam_np, (frame_w, frame_h), interpolation=cv2.INTER_LINEAR
        )

        # Boost the bbox, attenuate its surroundings, zero everything else.
        cam_resized = cam_resized * self._soft_bbox_mask(target_bbox, frame_h, frame_w)

        # Re-normalize so the strongest remaining response is 1.0
        c_max = cam_resized.max()
        if c_max > 1e-8:
            cam_resized = cam_resized / c_max

        return cam_resized.astype(np.float32)

    def release(self):
        """Remove hooks from the model."""
        self._fwd_hook.remove()
        self._bwd_hook.remove()
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# ── YOLO saliency (activation-based, no gradients) ──────────────
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def _yolo_saliency(
    yolo_model,
    frame: np.ndarray,
    target_bbox: list,
) -> np.ndarray:
    """Generate an activation-based saliency map from a YOLO model.

    Walks the wrapped network's layer list from the end, hooks the first layer
    whose class name is SPPF, C2f, C3 or Conv, runs one prediction and uses
    the channel-wise L2 norm of that layer's output as the heatmap. If no
    usable activation is captured (or anything raises on the way), a Gaussian
    centered on the bbox is used instead. Either map is then weighted towards
    the bbox and scaled so its maximum is 1.

    Args:
        yolo_model: Ultralytics YOLO model instance.
        frame: HxWx3 BGR uint8 numpy array.
        target_bbox: [x1, y1, x2, y2] in original frame pixel coords.

    Returns:
        HxW float32 array normalized to [0, 1].
    """
    frame_h, frame_w = frame.shape[:2]
    cam = None

    try:
        # Access the PyTorch model inside the Ultralytics wrapper
        pt_model = yolo_model.model
        if hasattr(pt_model, "model"):
            layers = pt_model.model
            # Search backwards for the last feature-extraction layer
            # (the detect head at the very end is skipped by the name filter).
            for idx in reversed(range(len(layers))):
                layer = layers[idx]
                if type(layer).__name__ not in ("SPPF", "C2f", "C3", "Conv"):
                    continue

                captured = {}

                def hook_fn(module, inp, out, store=captured):
                    store["out"] = out.detach()

                handle = layer.register_forward_hook(hook_fn)
                try:
                    # A single forward pass through the normal predict path,
                    # so the frame gets YOLO's own preprocessing.
                    yolo_model.predict(
                        source=frame,
                        device=getattr(yolo_model, "device", None),
                        conf=0.1,
                        imgsz=640,
                        verbose=False,
                    )
                finally:
                    # The model object is shared; never leave the hook behind,
                    # even when predict raises.
                    handle.remove()

                feat = captured.get("out")
                if feat is not None and feat.dim() == 4:
                    # Cast to float32: cv2.resize below cannot take float16.
                    cam = torch.norm(feat[0], dim=0).float().cpu().numpy()
                    break
    except Exception as e:
        logger.warning("YOLO feature extraction failed: %s", e)

    if cam is None:
        # Fallback: generate a simple Gaussian heatmap centered on the bbox
        cam = _gaussian_bbox_heatmap(frame_h, frame_w, target_bbox)
    else:
        # Resize to frame dimensions
        cam = cv2.resize(cam, (frame_w, frame_h), interpolation=cv2.INTER_LINEAR)

    # Focus on the target bbox region. Both ends are clamped into the frame so
    # a negative coordinate is never read as a from-the-end slice index.
    x1, y1, x2, y2 = [int(c) for c in target_bbox]
    x1 = min(max(0, x1), frame_w)
    y1 = min(max(0, y1), frame_h)
    x2 = min(max(0, x2), frame_w)
    y2 = min(max(0, y2), frame_h)

    # 1.0 inside the bbox, 0.3 in a margin of half its longer side, 0 elsewhere.
    mask = np.zeros((frame_h, frame_w), dtype=np.float32)
    mask[y1:y2, x1:x2] = 1.0
    pad = max(x2 - x1, y2 - y1, 0) // 2
    py1, py2 = max(0, y1 - pad), min(frame_h, y2 + pad)
    px1, px2 = max(0, x1 - pad), min(frame_w, x2 + pad)
    mask[py1:py2, px1:px2] = np.maximum(mask[py1:py2, px1:px2], 0.3)

    cam = cam * mask

    # Normalize to [0, 1]
    c_max = cam.max()
    if c_max > 1e-8:
        cam = cam / c_max

    return cam.astype(np.float32)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def _gaussian_bbox_heatmap(
    frame_h: int, frame_w: int, bbox: list
) -> np.ndarray:
    """Build a frame-sized 2-D Gaussian bump centered on ``bbox``.

    The peak (value 1.0) sits at the bbox center; the standard deviation along
    each axis is half the bbox extent on that axis, but never less than one
    pixel. Used as a fallback when feature extraction is not available.

    Args:
        frame_h: Output height in pixels.
        frame_w: Output width in pixels.
        bbox: [x1, y1, x2, y2]; coordinates are truncated to int.

    Returns:
        ``frame_h`` x ``frame_w`` float32 array.
    """
    left, top, right, bottom = (int(value) for value in bbox)
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    sigma_x = max((right - left) / 2, 1.0)
    sigma_y = max((bottom - top) / 2, 1.0)

    # Column vector of row indices and row vector of column indices; the sum
    # below broadcasts them to the full (frame_h, frame_w) grid.
    rows = np.arange(frame_h)[:, np.newaxis]
    cols = np.arange(frame_w)[np.newaxis, :]
    squared_distance = ((cols - center_x) / sigma_x) ** 2 + ((rows - center_y) / sigma_y) ** 2
    return np.exp(-0.5 * squared_distance).astype(np.float32)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
# ── Main entry point ─────────────────────────────────────────────
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def generate_attention_map(
    frame: np.ndarray,
    bbox: list,
    detector_name: str,
    job_id: str,
    frame_idx: int,
    track_id: str,
) -> np.ndarray:
    """Produce (or fetch from cache) the attention heatmap for one track.

    The detector named by ``detector_name`` is loaded and an activation-based
    saliency map focused on ``bbox`` is computed. Any failure along the way is
    logged and replaced by a Gaussian centered on the bbox; unknown detector
    names go straight to that fallback. Whatever map is produced is stored in
    the attention cache before returning.

    Args:
        frame: HxWx3 BGR uint8 numpy array.
        bbox: [x1, y1, x2, y2] target object bounding box.
        detector_name: Name of the detector used for the job.
        job_id: Job identifier (for caching).
        frame_idx: Frame index (for caching).
        track_id: Track ID string (for caching).

    Returns:
        HxW float32 heatmap normalized to [0, 1].
    """
    # Check cache first
    hit = get_cached_attention(job_id, frame_idx, track_id)
    if hit is not None:
        return hit

    frame_h, frame_w = frame.shape[:2]
    heatmap = None
    is_yolo = detector_name in ("yolo11", "yolov8_visdrone")
    is_transformer = detector_name in ("detr_resnet50", "grounding_dino")

    if is_yolo:
        # YOLO models: activation-based saliency
        try:
            from models.model_loader import load_detector

            heatmap = _yolo_saliency(load_detector(detector_name).model, frame, bbox)
        except Exception as e:
            logger.warning("YOLO saliency generation failed: %s", e)

    elif is_transformer:
        # Transformers models: hook a backbone layer
        try:
            from models.model_loader import load_detector

            detector = load_detector(detector_name)
            network = detector.model
            hooked_layer = _find_target_layer(network)

            if hooked_layer is None:
                logger.warning(
                    "No suitable target layer found for %s", detector_name
                )
            else:
                extractor = GradCAMExtractor(network, hooked_layer)
                try:
                    # Prepare input
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    processor = detector.processor
                    processor_kwargs = {"images": frame_rgb}
                    if detector_name == "grounding_dino":
                        # Grounding DINO's processor is also given a text
                        # prompt; a fixed generic one is used here.
                        processor_kwargs["text"] = "object."
                    processor_kwargs["return_tensors"] = "pt"
                    encoded = processor(**processor_kwargs)
                    inputs = {
                        name: tensor.to(detector.device)
                        for name, tensor in encoded.items()
                    }
                    heatmap = extractor.generate(inputs, bbox, frame_h, frame_w)
                finally:
                    # Hooks must come off the shared model no matter what.
                    extractor.release()
        except Exception as e:
            logger.warning("GradCAM generation failed for %s: %s", detector_name, e)

    # Fallback: Gaussian heatmap centered on bbox
    if heatmap is None:
        logger.info(
            "Using Gaussian fallback for attention (detector=%s)", detector_name
        )
        heatmap = _gaussian_bbox_heatmap(frame_h, frame_w, bbox)

    # Cache the result
    set_cached_attention(job_id, frame_idx, track_id, heatmap)
    return heatmap
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
# ── Serialization / rendering ─────────────────────────────────────
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def heatmap_to_base64(heatmap: np.ndarray) -> str:
    """Serialize a heatmap as base64 text.

    The array is cast to float32 and its raw bytes (row-major) are
    base64-encoded.

    Args:
        heatmap: Array to serialize.

    Returns:
        ASCII base64 string.
    """
    payload = heatmap.astype(np.float32).tobytes()
    return str(base64.b64encode(payload), "ascii")
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def heatmap_overlay_jpeg(
    frame: np.ndarray,
    heatmap: np.ndarray,
    bbox: list,
    alpha: float = 0.5,
    quality: int = 85,
) -> bytes:
    """Blend a JET-colored heatmap onto the frame crop around ``bbox``.

    The crop window is the bbox clipped to the frame, grown on every side by
    15% of its longer edge and clipped to the frame again.

    Args:
        frame: HxWx3 BGR uint8 numpy array (full frame).
        heatmap: HxW float32 heatmap (same size as frame).
        bbox: [x1, y1, x2, y2] crop region.
        alpha: Blend factor for overlay (0=no overlay, 1=full overlay).
        quality: JPEG quality.

    Returns:
        JPEG bytes.

    Raises:
        RuntimeError: If OpenCV reports that JPEG encoding failed.
    """
    frame_h, frame_w = frame.shape[:2]
    left, top, right, bottom = (int(value) for value in bbox)
    left, top = max(0, left), max(0, top)
    right, bottom = min(frame_w, right), min(frame_h, bottom)

    # Add some padding around the clipped bbox
    margin = int(max(right - left, bottom - top) * 0.15)
    rows = slice(max(0, top - margin), min(frame_h, bottom + margin))
    cols = slice(max(0, left - margin), min(frame_w, right + margin))

    crop = frame[rows, cols].copy()

    # Scale the heatmap crop to 0-255 for the colormap lookup
    heat_u8 = (heatmap[rows, cols] * 255).clip(0, 255).astype(np.uint8)
    colored = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)

    # Blend
    blended = cv2.addWeighted(crop, 1 - alpha, colored, alpha, 0)

    ok, encoded = cv2.imencode(
        ".jpg", blended, [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    )
    if not ok:
        raise RuntimeError("Failed to encode attention overlay as JPEG")
    return encoded.tobytes()
|
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""On-demand depth inference, colorization, caching, and stats.
|
| 2 |
+
|
| 3 |
+
Uses DepthAnythingV2Estimator for single-frame depth estimation.
|
| 4 |
+
Results are cached in-memory per (job_id, frame_idx) to avoid
|
| 5 |
+
redundant GPU work when the same frame is requested multiple times.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import base64
|
| 9 |
+
import logging
|
| 10 |
+
import threading
|
| 11 |
+
from typing import Dict, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
import cv2
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
# ── In-memory depth cache ────────────────────────────────────────
# Key: (job_id, frame_idx)   Value: depth_map (HxW float32)
# A plain dict preserves insertion order; every hit and every store re-inserts
# its key, so the first key is always the least recently used entry.
_depth_cache: Dict[Tuple[str, int], np.ndarray] = {}
_cache_lock = threading.RLock()

# Limit cache size to avoid OOM
_MAX_CACHE_ENTRIES = 200


def _cache_key(job_id: str, frame_idx: int) -> Tuple[str, int]:
    """Build the cache key for one frame of one job."""
    return (job_id, frame_idx)


def get_cached_depth(job_id: str, frame_idx: int) -> Optional[np.ndarray]:
    """Return the cached depth map for a frame, or None on a miss.

    A hit marks the entry as most recently used.
    """
    key = _cache_key(job_id, frame_idx)
    with _cache_lock:
        depth_map = _depth_cache.pop(key, None)
        if depth_map is not None:
            # Re-insert at the end so eviction reaches this entry last.
            _depth_cache[key] = depth_map
        return depth_map


def set_cached_depth(job_id: str, frame_idx: int, depth_map: np.ndarray) -> None:
    """Store a depth map, evicting least recently used entries if needed.

    Overwriting a key that is already cached replaces it in place (and marks
    it most recently used) without evicting anything else.
    """
    key = _cache_key(job_id, frame_idx)
    with _cache_lock:
        # Drop any previous value first so it does not count towards the
        # limit and the new value lands at the "most recent" end.
        _depth_cache.pop(key, None)
        while _depth_cache and len(_depth_cache) >= _MAX_CACHE_ENTRIES:
            del _depth_cache[next(iter(_depth_cache))]
        _depth_cache[key] = depth_map


def clear_depth_cache(job_id: Optional[str] = None) -> None:
    """Clear depth cache for a specific job, or for all jobs when None."""
    with _cache_lock:
        if job_id is None:
            _depth_cache.clear()
            return
        # Collect first: a dict cannot be resized while it is being iterated.
        stale_keys = [key for key in _depth_cache if key[0] == job_id]
        for key in stale_keys:
            del _depth_cache[key]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ── Lazy model singleton ─────────────────────────────────────────
|
| 59 |
+
|
| 60 |
+
_depth_estimator = None
|
| 61 |
+
_estimator_lock = threading.Lock()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _get_depth_estimator():
    """Return the shared DepthAnythingV2Estimator, creating it on first use.

    The import and construction happen at most once; the instance is
    re-checked under ``_estimator_lock`` so concurrent first callers do not
    each build their own.
    """
    global _depth_estimator

    # Fast path: already built, no lock needed.
    if _depth_estimator is not None:
        return _depth_estimator

    with _estimator_lock:
        # Another thread may have finished construction while we waited.
        if _depth_estimator is None:
            from models.depth_estimators.depth_anything_v2 import (
                DepthAnythingV2Estimator,
            )
            _depth_estimator = DepthAnythingV2Estimator()
        return _depth_estimator
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ── Core inference ────────────────────────────────────────────────
|
| 78 |
+
|
| 79 |
+
def run_depth_on_frame(
    frame: np.ndarray,
    job_id: str,
    frame_idx: int,
) -> np.ndarray:
    """Return the depth map for one frame, running inference only on a miss.

    On a cache miss the shared estimator's ``predict`` is called on ``frame``
    and the ``depth_map`` attribute of its result is cached under
    (job_id, frame_idx).

    Args:
        frame: HxWx3 BGR uint8 numpy array.
        job_id: Job identifier (for cache keying).
        frame_idx: Frame index (for cache keying).

    Returns:
        HxW float32 depth map.
    """
    depth_map = get_cached_depth(job_id, frame_idx)
    if depth_map is None:
        depth_map = _get_depth_estimator().predict(frame).depth_map  # HxW float32
        set_cached_depth(job_id, frame_idx, depth_map)
    return depth_map
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ── Stats computation ─────────────────────────────────────────────
|
| 107 |
+
|
| 108 |
+
def compute_depth_stats(depth_map: np.ndarray) -> dict:
    """Summarize a depth map with four scalar statistics.

    Args:
        depth_map: HxW float32 depth array.

    Returns:
        Dict with keys min_m, max_m, mean_m and median_m, each a Python float.
    """
    reducers = (
        ("min_m", np.min),
        ("max_m", np.max),
        ("mean_m", np.mean),
        ("median_m", np.median),
    )
    return {label: float(reduce(depth_map)) for label, reduce in reducers}
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ── Crop depth to track bbox ─────────────────────────────────────
|
| 126 |
+
|
| 127 |
+
def crop_depth_to_bbox(
    depth_map: np.ndarray,
    bbox: list,
    padding: float = 0.0,
) -> np.ndarray:
    """Crop a depth map to a bounding box with optional padding.

    Coordinates are truncated to int (so float boxes are accepted) and the
    padded window is clamped to the map on all four sides. A box that lies
    entirely outside the map yields an empty array.

    Args:
        depth_map: HxW float32 depth array.
        bbox: [x1, y1, x2, y2] in pixel coordinates.
        padding: Fractional padding around the bbox, relative to the bbox
            width (horizontally) and height (vertically).

    Returns:
        Cropped HxW float32 depth array (a copy, never a view).
    """
    h, w = depth_map.shape[:2]
    x1, y1, x2, y2 = (int(c) for c in bbox)

    pad_x = int((x2 - x1) * padding)
    pad_y = int((y2 - y1) * padding)

    # Clamp every edge to [0, size]: without the lower clamp on the far edges
    # a negative value would be read as a from-the-end slice index.
    cx1 = min(max(0, x1 - pad_x), w)
    cy1 = min(max(0, y1 - pad_y), h)
    cx2 = min(max(0, x2 + pad_x), w)
    cy2 = min(max(0, y2 + pad_y), h)

    return depth_map[cy1:cy2, cx1:cx2].copy()
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ── Colorization (viridis) ───────────────────────────────────────
|
| 159 |
+
|
| 160 |
+
def colorize_depth(depth_map: np.ndarray, quality: int = 85) -> bytes:
    """Render a depth map as a viridis-colored JPEG.

    Depth is min-max scaled to 0-255 before the colormap lookup; a map whose
    value range is below 1e-6 is rendered as all zeros.

    Args:
        depth_map: HxW float32 depth array.
        quality: JPEG quality (1-100).

    Returns:
        JPEG bytes with viridis-colored depth.

    Raises:
        RuntimeError: If OpenCV reports that JPEG encoding failed.
    """
    nearest = float(np.min(depth_map))
    farthest = float(np.max(depth_map))
    span = farthest - nearest

    if span < 1e-6:
        # Flat map: avoid dividing by (almost) zero.
        gray = np.zeros_like(depth_map, dtype=np.uint8)
    else:
        gray = ((depth_map - nearest) / span * 255).astype(np.uint8)

    # applyColorMap returns a BGR image, which is what imencode expects.
    colored = cv2.applyColorMap(gray, cv2.COLORMAP_VIRIDIS)

    ok, encoded = cv2.imencode(
        ".jpg", colored, [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    )
    if not ok:
        raise RuntimeError("Failed to encode colorized depth as JPEG")
    return encoded.tobytes()
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# ── Serialization helpers ─────────────────────────────────────────
|
| 189 |
+
|
| 190 |
+
def depth_to_raw_bytes(depth_map: np.ndarray) -> bytes:
    """Convert depth map to raw float32 little-endian bytes.

    The byte order is requested explicitly ("<f4") so the output is
    little-endian on every host, not just on little-endian machines.

    Args:
        depth_map: HxW float32 depth array.

    Returns:
        Raw bytes in row-major order (width * height * 4 bytes).
    """
    return depth_map.astype("<f4").tobytes()
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def depth_to_base64(depth_map: np.ndarray) -> str:
    """Serialize a depth map as base64 text.

    The bytes produced by ``depth_to_raw_bytes`` are base64-encoded.

    Args:
        depth_map: HxW float32 depth array.

    Returns:
        Base64-encoded string.
    """
    encoded = base64.b64encode(depth_to_raw_bytes(depth_map))
    return str(encoded, "ascii")
|
|
@@ -337,3 +337,201 @@ async def generate_mask(
|
|
| 337 |
"color": color,
|
| 338 |
"source": "sam2_ondemand",
|
| 339 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
"color": color,
|
| 338 |
"source": "sam2_ondemand",
|
| 339 |
})
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# ── On-demand depth analysis ─────────────────────────────────────
|
| 343 |
+
|
| 344 |
+
@router.get("/depth/{job_id}/{frame_idx}")
async def get_depth(
    job_id: str,
    frame_idx: int,
    track_id: Optional[str] = Query(None, description="Track ID to crop depth to, e.g. 'T01'"),
    format: str = Query("raw", description="Response format: 'raw', 'json', or 'colorized'"),
):
    """Get depth data for a frame, computed on-demand and cached.

    Supports three response formats:
    - raw: binary float32 with X-Depth-* headers
    - json: JSON with base64-encoded float32 depth data and stats
    - colorized: JPEG image with viridis colormap

    When ``track_id`` is given, the track is resolved before any inference
    runs and the depth map is cropped to that track's bounding box.

    Raises:
        HTTPException: 400 for an unknown ``format``; 404 when the input video
            is missing, the track is not present in the frame, or the track's
            bounding box covers no pixels of the frame.
    """
    import asyncio

    from inspection.frames import extract_frame
    from inspection.depth import (
        run_depth_on_frame,
        crop_depth_to_bbox,
        compute_depth_stats,
        depth_to_raw_bytes,
        depth_to_base64,
        colorize_depth,
    )

    if format not in ("raw", "json", "colorized"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid format '{format}'. Must be 'raw', 'json', or 'colorized'.",
        )

    job = _get_job_or_404(job_id)
    input_path = job.input_video_path
    if not input_path or not Path(input_path).exists():
        raise HTTPException(status_code=404, detail="Input video not found on disk.")

    _validate_frame_idx(input_path, frame_idx)

    # Resolve the track first: an unknown track should fail fast instead of
    # paying for frame extraction and depth inference.
    crop_bbox = None
    if track_id is not None:
        from jobs.storage import get_track_data

        # "T01" -> 1. A non-numeric id is not an error by itself; it can still
        # match a stored track id by exact string comparison below.
        try:
            instance_id = int(track_id.replace("T", "")) if track_id.startswith("T") else int(track_id)
        except ValueError:
            instance_id = None

        target = None
        for t in get_track_data(job_id, frame_idx):
            tid = t.get("instance_id") or t.get("track_id")
            if tid is None:
                continue
            if tid == track_id or (instance_id is not None and tid == instance_id):
                target = t
                break

        if not target or "bbox" not in target:
            raise HTTPException(
                status_code=404,
                detail=f"Track {track_id} not found in frame {frame_idx}.",
            )
        # Slicing needs integer pixel coordinates.
        crop_bbox = [int(c) for c in target["bbox"]]

    # Extract frame and run depth (GPU work in thread pool)
    frame = await asyncio.to_thread(extract_frame, input_path, frame_idx)
    depth_map = await asyncio.to_thread(run_depth_on_frame, frame, job_id, frame_idx)

    if crop_bbox is not None:
        depth_map = crop_depth_to_bbox(depth_map, crop_bbox)
        if depth_map.size == 0:
            # min()/max() below are undefined on an empty crop.
            raise HTTPException(
                status_code=404,
                detail=f"Track {track_id} has no pixels inside frame {frame_idx}.",
            )

    h, w = depth_map.shape[:2]
    d_min = float(depth_map.min())
    d_max = float(depth_map.max())

    if format == "raw":
        raw_bytes = depth_to_raw_bytes(depth_map)
        return Response(
            content=raw_bytes,
            media_type="application/octet-stream",
            headers={
                "X-Depth-Width": str(w),
                "X-Depth-Height": str(h),
                "X-Depth-Min": f"{d_min:.4f}",
                "X-Depth-Max": f"{d_max:.4f}",
            },
        )

    if format == "json":
        stats = compute_depth_stats(depth_map)
        data_b64 = depth_to_base64(depth_map)
        return JSONResponse({
            "width": w,
            "height": h,
            "min_depth": d_min,
            "max_depth": d_max,
            "data_b64": data_b64,
            "depth_stats": stats,
        })

    # format == "colorized"
    jpeg_bytes = await asyncio.to_thread(colorize_depth, depth_map)
    return Response(content=jpeg_bytes, media_type="image/jpeg")
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
# ── Attention / GradCAM heatmaps ─────────────────────────────────
|
| 442 |
+
|
| 443 |
+
@router.get("/attention/{job_id}/{frame_idx}/{track_id}")
async def get_attention(
    job_id: str,
    frame_idx: int,
    track_id: str,
    # NOTE: `format` shadows the builtin, but it is the public query-parameter
    # name that clients already send, so it is kept as-is.
    format: str = Query("json", description="Response format: 'json' or 'overlay'"),
):
    """Generate a GradCAM/saliency attention heatmap for a detected object.

    Computed on-demand using the detector model that produced the original
    job. Results are cached per (job_id, frame_idx, track_id).

    Supports two response formats:
    - json: base64-encoded float32 heatmap with metadata
    - overlay: JPEG image with heatmap blended onto the frame crop

    Args:
        job_id: Identifier of the job whose input video is inspected.
        frame_idx: Index of the frame inside the job's input video.
        track_id: Track identifier, either "T<number>" or a bare number. Any
            other string is only compared against stored track ids verbatim.
        format: "json" or "overlay".

    Raises:
        HTTPException: 400 when ``format`` is invalid or the job has no
            detector; 404 when the input video is missing on disk or the
            track (with a bbox) is not present in the requested frame.
            ``_get_job_or_404`` and ``_validate_frame_idx`` presumably raise
            their own HTTP errors for unknown jobs / out-of-range frames.
    """
    import asyncio

    from inspection.frames import extract_frame
    from inspection.attention import (
        generate_attention_map,
        heatmap_to_base64,
        heatmap_overlay_jpeg,
    )

    if format not in ("json", "overlay"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid format '{format}'. Must be 'json' or 'overlay'.",
        )

    job = _get_job_or_404(job_id)
    input_path = job.input_video_path
    if not input_path or not Path(input_path).exists():
        raise HTTPException(status_code=404, detail="Input video not found on disk.")

    _validate_frame_idx(input_path, frame_idx)

    # Determine the detector used for this job
    detector_name = job.detector_name
    if not detector_name:
        raise HTTPException(
            status_code=400,
            detail="Attention maps require a detector model. This job has no detector_name.",
        )

    # Get the track's bounding box
    from jobs.storage import get_track_data

    tracks = get_track_data(job_id, frame_idx)

    # A non-numeric track id can never equal an integer instance id, but it
    # may still equal a stored string id.  Parsing failures therefore must not
    # escape as ValueError (HTTP 500); they simply disable the numeric match.
    try:
        instance_id = (
            int(track_id.replace("T", "")) if track_id.startswith("T") else int(track_id)
        )
    except ValueError:
        instance_id = None

    target = None
    # `or []` guards against a missing frame entry, in case get_track_data
    # returns None for frames without tracks (not verifiable from here).
    for t in tracks or []:
        # Explicit None check: `a or b` would discard a legitimate id of 0.
        tid = t.get("instance_id")
        if tid is None:
            tid = t.get("track_id")
        if (instance_id is not None and tid == instance_id) or tid == track_id:
            target = t
            break

    if not target or "bbox" not in target:
        raise HTTPException(
            status_code=404,
            detail=f"Track {track_id} not found in frame {frame_idx}.",
        )

    bbox = target["bbox"]

    # Extract frame and generate attention map (GPU work in thread pool)
    frame = await asyncio.to_thread(extract_frame, input_path, frame_idx)
    heatmap = await asyncio.to_thread(
        generate_attention_map,
        frame,
        bbox,
        detector_name,
        job_id,
        frame_idx,
        track_id,
    )

    if format == "json":
        h, w = heatmap.shape[:2]
        data_b64 = heatmap_to_base64(heatmap)
        return JSONResponse({
            "track_id": track_id,
            "frame_idx": frame_idx,
            "width": w,
            "height": h,
            "data_b64": data_b64,
            "format": "float32",
        })

    # format == "overlay"
    jpeg_bytes = await asyncio.to_thread(
        heatmap_overlay_jpeg, frame, heatmap, bbox
    )
    return Response(content=jpeg_bytes, media_type="image/jpeg")
|
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for inspection/attention.py — GradCAM attention heatmaps."""
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import struct
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# ── Unit tests for attention module functions ────────────────────
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_gaussian_bbox_heatmap_shape():
    """The gaussian fallback returns a float32 (height, width) map within [0, 1]."""
    from inspection.attention import _gaussian_bbox_heatmap

    height, width = 100, 200
    result = _gaussian_bbox_heatmap(height, width, [50, 30, 150, 70])

    assert result.shape == (height, width)
    assert result.dtype == np.float32
    # Upper bound is checked before the lower bound, as in the original test.
    assert result.max() <= 1.0
    assert result.min() >= 0.0
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def test_gaussian_bbox_heatmap_peak_location():
    """The bbox centre must be hotter than the far top-left corner of the map."""
    from inspection.attention import _gaussian_bbox_heatmap

    bbox = [60, 80, 140, 120]
    # Midpoint of the box: ((60 + 140) / 2, (80 + 120) / 2) == (100, 100).
    center_x = (bbox[0] + bbox[2]) // 2
    center_y = (bbox[1] + bbox[3]) // 2

    attention = _gaussian_bbox_heatmap(200, 200, bbox)

    assert attention[center_y, center_x] > attention[0, 0]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_gaussian_bbox_heatmap_small_bbox():
    """A 2x2-pixel bbox still yields a NaN-free map with a positive peak."""
    from inspection.attention import _gaussian_bbox_heatmap

    tiny_bbox = [10, 10, 12, 12]
    result = _gaussian_bbox_heatmap(50, 50, tiny_bbox)

    assert result.shape == (50, 50)
    has_nan = np.isnan(result).any()
    assert not has_nan
    assert result.max() > 0
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def test_heatmap_to_base64():
    """The base64 payload decodes to the heatmap's little-endian float32 values."""
    from inspection.attention import heatmap_to_base64

    expected = (0.0, 0.5, 0.75, 1.0)
    grid = np.array(expected, dtype=np.float32).reshape(2, 2)

    payload = base64.b64decode(heatmap_to_base64(grid))

    assert len(payload) == 4 * 4  # four float32 values, 4 bytes each
    assert struct.unpack("<4f", payload) == pytest.approx(expected)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_heatmap_overlay_jpeg():
    """heatmap_overlay_jpeg should return valid JPEG bytes."""
    from inspection.attention import heatmap_overlay_jpeg

    # Seeded generator so a failure can be reproduced with the exact same
    # pixels; the exclusive upper bound of 256 also lets 255 occur.
    rng = np.random.default_rng(0)
    frame = rng.integers(0, 256, size=(100, 200, 3), dtype=np.uint8)
    heatmap = rng.random((100, 200), dtype=np.float32)
    bbox = [50, 30, 150, 70]

    jpeg_bytes = heatmap_overlay_jpeg(frame, heatmap, bbox)

    # JPEG magic bytes (start-of-image marker)
    assert jpeg_bytes[:2] == b"\xff\xd8"
    assert len(jpeg_bytes) > 100
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def test_heatmap_overlay_jpeg_edge_bbox():
    """Overlay with bbox near frame edges should not crash."""
    from inspection.attention import heatmap_overlay_jpeg

    # Seeded generator keeps the input identical from run to run.
    rng = np.random.default_rng(0)
    frame = rng.integers(0, 256, size=(50, 50, 3), dtype=np.uint8)
    heatmap = rng.random((50, 50), dtype=np.float32)
    bbox = [0, 0, 50, 50]  # Full frame bbox

    jpeg_bytes = heatmap_overlay_jpeg(frame, heatmap, bbox)
    assert jpeg_bytes[:2] == b"\xff\xd8"
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ── Cache tests ──────────────────────────────────────────────────
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def test_attention_cache_set_get():
    """Attention cache should store and retrieve heatmaps."""
    from inspection.attention import (
        get_cached_attention,
        set_cached_attention,
        clear_attention_cache,
    )

    clear_attention_cache()
    # try/finally: the cache is cleared even when an assertion fails, so a
    # failure here cannot leak entries into other tests.
    try:
        heatmap = np.ones((10, 10), dtype=np.float32) * 0.5
        set_cached_attention("job1", 0, "T01", heatmap)

        result = get_cached_attention("job1", 0, "T01")
        assert result is not None
        np.testing.assert_array_equal(result, heatmap)

        # Different params should return None
        assert get_cached_attention("job1", 0, "T02") is None
        assert get_cached_attention("job1", 1, "T01") is None
        assert get_cached_attention("job2", 0, "T01") is None
    finally:
        clear_attention_cache()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def test_attention_cache_clear_per_job():
    """clear_attention_cache(job_id) should only clear that job."""
    from inspection.attention import (
        get_cached_attention,
        set_cached_attention,
        clear_attention_cache,
    )

    clear_attention_cache()
    # try/finally: always leave the cache empty, even on assertion failure.
    try:
        h1 = np.ones((5, 5), dtype=np.float32)
        h2 = np.ones((5, 5), dtype=np.float32) * 0.5

        set_cached_attention("jobA", 0, "T01", h1)
        set_cached_attention("jobB", 0, "T01", h2)

        clear_attention_cache("jobA")

        assert get_cached_attention("jobA", 0, "T01") is None
        assert get_cached_attention("jobB", 0, "T01") is not None
    finally:
        clear_attention_cache()
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ── Integration test for attention endpoint ──────────────────────
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _make_test_video(tmp_path, num_frames=5, width=64, height=48):
    """Create a tiny test video and return its path.

    Writes ``num_frames`` solid-grey frames (grey level ``i * 40``, wrapped to
    the uint8 range) to ``tmp_path / "test.mp4"`` with the ``mp4v`` codec at
    30 fps.

    Raises:
        RuntimeError: if OpenCV cannot open the writer (e.g. codec missing),
            instead of silently leaving an unreadable file behind.
    """
    import cv2

    video_path = str(tmp_path / "test.mp4")
    writer = cv2.VideoWriter(
        video_path, cv2.VideoWriter_fourcc(*"mp4v"), 30, (width, height)
    )
    if not writer.isOpened():
        writer.release()
        raise RuntimeError(f"Could not open video writer for {video_path}")

    try:
        for i in range(num_frames):
            # Wrap so num_frames > 6 cannot exceed the uint8 range.
            level = (i * 40) % 256
            frame = np.full((height, width, 3), level, dtype=np.uint8)
            writer.write(frame)
    finally:
        # Release even if a write fails so the file handle is not leaked.
        writer.release()
    return video_path
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def test_attention_endpoint_json(tmp_path, monkeypatch):
    """GET /inspect/attention/{job_id}/{frame_idx}/{track_id}?format=json."""
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage, set_track_data
    from inspection.attention import clear_attention_cache

    job_id = "test_attn_json"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    job = JobInfo(
        job_id=job_id,
        status=JobStatus.COMPLETED,
        mode="object_detection",
        queries=["person"],
        detector_name="yolo11",
        segmenter_name=None,
        input_video_path=video_path,
        output_video_path=None,
    )
    storage.create(job)

    # try/finally: the job and cache are removed even when an assertion
    # fails, so a failure cannot contaminate the shared storage.
    try:
        # Add track data
        set_track_data(job_id, 0, [
            {"instance_id": 1, "label": "person", "bbox": [10, 10, 30, 30]},
        ])

        # Mock generate_attention_map to avoid loading real models
        def fake_generate(frame, bbox, det_name, job_id, frame_idx, track_id):
            h, w = frame.shape[:2]
            return np.random.rand(h, w).astype(np.float32)

        monkeypatch.setattr(
            "inspection.attention.generate_attention_map", fake_generate
        )

        from inspection.router import router
        from fastapi import FastAPI

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get(f"/inspect/attention/{job_id}/0/T01?format=json")
        assert resp.status_code == 200

        data = resp.json()
        assert data["track_id"] == "T01"
        assert data["frame_idx"] == 0
        assert "width" in data
        assert "height" in data
        assert "data_b64" in data
        assert data["format"] == "float32"

        # Verify base64 decodes to correct size
        decoded = base64.b64decode(data["data_b64"])
        assert len(decoded) == data["width"] * data["height"] * 4
    finally:
        clear_attention_cache()
        storage.delete(job_id)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def test_attention_endpoint_overlay(tmp_path, monkeypatch):
    """GET /inspect/attention/{job_id}/{frame_idx}/{track_id}?format=overlay."""
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage, set_track_data
    from inspection.attention import clear_attention_cache

    job_id = "test_attn_overlay"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    job = JobInfo(
        job_id=job_id,
        status=JobStatus.COMPLETED,
        mode="object_detection",
        queries=["person"],
        detector_name="yolo11",
        segmenter_name=None,
        input_video_path=video_path,
        output_video_path=None,
    )
    storage.create(job)

    # try/finally: cleanup must run even if an assertion fails.
    try:
        set_track_data(job_id, 0, [
            {"instance_id": 1, "label": "person", "bbox": [5, 5, 40, 35]},
        ])

        # Stand-in for the real model: random heatmap sized like the frame.
        def fake_generate(frame, bbox, det_name, job_id, frame_idx, track_id):
            h, w = frame.shape[:2]
            return np.random.rand(h, w).astype(np.float32)

        monkeypatch.setattr(
            "inspection.attention.generate_attention_map", fake_generate
        )

        from inspection.router import router
        from fastapi import FastAPI

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get(f"/inspect/attention/{job_id}/0/T01?format=overlay")
        assert resp.status_code == 200
        assert resp.headers["content-type"] == "image/jpeg"
        assert resp.content[:2] == b"\xff\xd8"
    finally:
        clear_attention_cache()
        storage.delete(job_id)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def test_attention_endpoint_no_detector(tmp_path):
    """Attention for a job with no detector_name should return 400."""
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage

    job_id = "test_attn_no_det"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    job = JobInfo(
        job_id=job_id,
        status=JobStatus.COMPLETED,
        mode="segmentation",
        queries=["object"],
        detector_name=None,
        segmenter_name="GSAM2-L",
        input_video_path=video_path,
        output_video_path=None,
    )
    storage.create(job)

    # try/finally: remove the job even when the assertion fails.
    try:
        from inspection.router import router
        from fastapi import FastAPI

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get(f"/inspect/attention/{job_id}/0/T01?format=json")
        assert resp.status_code == 400
    finally:
        storage.delete(job_id)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def test_attention_endpoint_track_not_found(tmp_path):
    """Track not found in frame should return 404."""
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage, set_track_data

    job_id = "test_attn_notrack"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    job = JobInfo(
        job_id=job_id,
        status=JobStatus.COMPLETED,
        mode="object_detection",
        queries=["person"],
        detector_name="yolo11",
        segmenter_name=None,
        input_video_path=video_path,
        output_video_path=None,
    )
    storage.create(job)

    # try/finally: remove the job even when the assertion fails.
    try:
        # No track data for frame 0
        set_track_data(job_id, 0, [])

        from inspection.router import router
        from fastapi import FastAPI

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get(f"/inspect/attention/{job_id}/0/T01?format=json")
        assert resp.status_code == 404
    finally:
        storage.delete(job_id)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def test_attention_endpoint_invalid_format(tmp_path):
    """Invalid format should return 400."""
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage

    job_id = "test_attn_badfmt"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    job = JobInfo(
        job_id=job_id,
        status=JobStatus.COMPLETED,
        mode="object_detection",
        queries=["person"],
        detector_name="yolo11",
        segmenter_name=None,
        input_video_path=video_path,
        output_video_path=None,
    )
    storage.create(job)

    # try/finally: remove the job even when the assertion fails.
    try:
        from inspection.router import router
        from fastapi import FastAPI

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get(f"/inspect/attention/{job_id}/0/T01?format=invalid")
        assert resp.status_code == 400
    finally:
        storage.delete(job_id)
|
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for inspection/depth.py — on-demand depth analysis."""
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import struct
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# ── Unit tests for depth module functions ────────────────────────
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_compute_depth_stats():
    """A 2x2 map holding 1..4 has min 1, max 4, and mean/median of 2.5."""
    from inspection.depth import compute_depth_stats

    stats = compute_depth_stats(
        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    )

    # Checked in the same order as the individual asserts they replace.
    expected = (
        ("min_m", 1.0),
        ("max_m", 4.0),
        ("mean_m", 2.5),
        ("median_m", 2.5),
    )
    for key, value in expected:
        assert stats[key] == pytest.approx(value)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_compute_depth_stats_uniform():
    """Every statistic of a constant depth map equals that constant."""
    from inspection.depth import compute_depth_stats

    constant = 5.5
    stats = compute_depth_stats(np.full((10, 10), constant, dtype=np.float32))

    for key in ("min_m", "max_m", "mean_m", "median_m"):
        assert stats[key] == pytest.approx(constant)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def test_crop_depth_to_bbox():
    """An unpadded crop returns exactly the rows/columns spanned by the bbox."""
    from inspection.depth import crop_depth_to_bbox

    grid = np.arange(100, dtype=np.float32).reshape(10, 10)
    x1, y1, x2, y2 = 2, 3, 5, 7

    region = crop_depth_to_bbox(grid, [x1, y1, x2, y2])

    assert region.shape == (y2 - y1, x2 - x1)  # (4, 3)
    assert region[0, 0] == pytest.approx(32.0)  # grid[3, 2]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def test_crop_depth_to_bbox_with_padding():
    """A padded crop is strictly larger than the bare 4x3 bbox region."""
    from inspection.depth import crop_depth_to_bbox

    grid = np.arange(100, dtype=np.float32).reshape(10, 10)
    bbox = [2, 3, 5, 7]  # x1, y1, x2, y2 -> 3 wide, 4 tall
    # With padding=0.5 the original author expected pad_x=1, pad_y=2, i.e.
    # region [1,1] to [6,9]; only "larger than unpadded" is asserted here.
    region = crop_depth_to_bbox(grid, bbox, padding=0.5)

    rows, cols = region.shape[0], region.shape[1]
    assert rows > 4
    assert cols > 3
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_crop_depth_to_bbox_clamped():
    """Padding that spills past the top-left corner is clamped to the image."""
    from inspection.depth import crop_depth_to_bbox

    grid = np.arange(100, dtype=np.float32).reshape(10, 10)

    # 2x2 box in the top-left corner with padding large enough to go negative.
    region = crop_depth_to_bbox(grid, [0, 0, 2, 2], padding=1.0)

    rows, cols = region.shape[0], region.shape[1]
    assert rows >= 2
    assert cols >= 2
    # A clamped crop starts at the image origin, i.e. grid[0, 0].
    assert region[0, 0] == pytest.approx(0.0)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def test_depth_to_raw_bytes():
    """Raw output is the map's values as little-endian float32, 4 bytes each."""
    from inspection.depth import depth_to_raw_bytes

    expected = (1.0, 2.0, 3.0, 4.0)
    grid = np.array(expected, dtype=np.float32).reshape(2, 2)

    payload = depth_to_raw_bytes(grid)

    assert len(payload) == 4 * 4  # four floats, 4 bytes each
    assert struct.unpack("<4f", payload) == pytest.approx(expected)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def test_depth_to_base64():
    """The base64 string decodes back to the original float32 values."""
    from inspection.depth import depth_to_base64

    expected = (1.5, 2.5)
    row = np.array([expected], dtype=np.float32)

    payload = base64.b64decode(depth_to_base64(row))

    assert len(payload) == 2 * 4  # two floats, 4 bytes each
    assert struct.unpack("<2f", payload) == pytest.approx(expected)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def test_colorize_depth_returns_jpeg():
    """Colorizing a random depth map yields a non-trivial JPEG payload."""
    from inspection.depth import colorize_depth

    random_depth = np.random.rand(48, 64).astype(np.float32) * 10.0
    encoded = colorize_depth(random_depth)

    starts_with_soi = encoded[:2] == b"\xff\xd8"  # JPEG start-of-image marker
    assert starts_with_soi
    assert len(encoded) > 100
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def test_colorize_depth_uniform():
    """A constant depth map (zero value range) must still encode to a JPEG."""
    from inspection.depth import colorize_depth

    flat_depth = np.full((32, 32), 5.0, dtype=np.float32)
    encoded = colorize_depth(flat_depth)

    assert encoded[:2] == b"\xff\xd8"
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# ── Cache tests ──────────────────────────────────────────────────
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def test_cache_set_get():
    """Depth cache should store and retrieve depth maps."""
    from inspection.depth import (
        get_cached_depth,
        set_cached_depth,
        clear_depth_cache,
    )

    clear_depth_cache()
    # try/finally: the cache is cleared even when an assertion fails, so a
    # failure here cannot leak entries into other tests.
    try:
        depth = np.ones((10, 10), dtype=np.float32)
        set_cached_depth("job1", 0, depth)

        result = get_cached_depth("job1", 0)
        assert result is not None
        np.testing.assert_array_equal(result, depth)

        # Different frame should return None
        assert get_cached_depth("job1", 1) is None
        # Different job should return None
        assert get_cached_depth("job2", 0) is None
    finally:
        clear_depth_cache()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def test_cache_clear_per_job():
    """clear_depth_cache(job_id) should only clear that job."""
    from inspection.depth import (
        get_cached_depth,
        set_cached_depth,
        clear_depth_cache,
    )

    clear_depth_cache()
    # try/finally: always leave the cache empty, even on assertion failure.
    try:
        d1 = np.ones((5, 5), dtype=np.float32)
        d2 = np.ones((5, 5), dtype=np.float32) * 2

        set_cached_depth("jobA", 0, d1)
        set_cached_depth("jobB", 0, d2)

        clear_depth_cache("jobA")

        assert get_cached_depth("jobA", 0) is None
        assert get_cached_depth("jobB", 0) is not None
    finally:
        clear_depth_cache()
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# ── Integration test for the endpoint (via TestClient) ───────────
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _make_test_video(tmp_path, num_frames=5, width=64, height=48):
    """Create a tiny test video and return its path.

    Writes ``num_frames`` solid-grey frames (grey level ``i * 40``, wrapped to
    the uint8 range) to ``tmp_path / "test.mp4"`` with the ``mp4v`` codec at
    30 fps.

    Raises:
        RuntimeError: if OpenCV cannot open the writer (e.g. codec missing),
            instead of silently leaving an unreadable file behind.
    """
    import cv2

    video_path = str(tmp_path / "test.mp4")
    writer = cv2.VideoWriter(
        video_path, cv2.VideoWriter_fourcc(*"mp4v"), 30, (width, height)
    )
    if not writer.isOpened():
        writer.release()
        raise RuntimeError(f"Could not open video writer for {video_path}")

    try:
        for i in range(num_frames):
            # Wrap so num_frames > 6 cannot exceed the uint8 range.
            level = (i * 40) % 256
            frame = np.full((height, width, 3), level, dtype=np.uint8)
            writer.write(frame)
    finally:
        # Release even if a write fails so the file handle is not leaked.
        writer.release()
    return video_path
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def test_depth_endpoint_raw(tmp_path, monkeypatch):
    """GET /inspect/depth/{job_id}/{frame_idx}?format=raw should return binary float32."""
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage
    from inspection.depth import clear_depth_cache

    job_id = "test_depth_raw"

    # Create a test video
    video_path = _make_test_video(tmp_path)

    # Register a fake job
    storage = get_job_storage()
    job = JobInfo(
        job_id=job_id,
        status=JobStatus.COMPLETED,
        mode="object_detection",
        queries=["person"],
        detector_name="yolo11",
        segmenter_name=None,
        input_video_path=video_path,
        output_video_path=None,
    )
    storage.create(job)

    # try/finally: the job and depth cache are cleaned up even when an
    # assertion fails, so a failure cannot contaminate the shared storage.
    try:
        # Mock the depth estimator to avoid loading the real model
        class FakeDepthResult:
            def __init__(self, h, w):
                self.depth_map = np.arange(h * w, dtype=np.float32).reshape(h, w)
                self.focal_length = 1.0

        class FakeEstimator:
            def predict(self, frame):
                h, w = frame.shape[:2]
                return FakeDepthResult(h, w)

        monkeypatch.setattr(
            "inspection.depth._depth_estimator", FakeEstimator()
        )

        # Import app after patching
        from inspection.router import router
        from fastapi import FastAPI

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get(f"/inspect/depth/{job_id}/0?format=raw")
        assert resp.status_code == 200
        assert resp.headers["content-type"] == "application/octet-stream"
        assert "X-Depth-Width" in resp.headers
        assert "X-Depth-Height" in resp.headers
        assert "X-Depth-Min" in resp.headers
        assert "X-Depth-Max" in resp.headers

        w = int(resp.headers["X-Depth-Width"])
        h = int(resp.headers["X-Depth-Height"])
        assert len(resp.content) == w * h * 4  # float32
    finally:
        # Cleanup
        clear_depth_cache()
        storage.delete(job_id)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def test_depth_endpoint_json(tmp_path, monkeypatch):
    """GET /inspect/depth/{job_id}/{frame_idx}?format=json should return proper JSON.

    Registers a completed job backed by a generated test video, replaces the
    module-level depth estimator with a fake that yields a constant 5.0 depth
    map, and checks the JSON payload: the expected keys are present, the
    reported min/max stats equal 5.0, and ``data_b64`` decodes to
    ``width * height * 4`` bytes (one float32 per pixel).

    Cleanup runs in ``finally`` so a failing assertion does not leave the job
    or cached depth behind for later tests.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from inspection.depth import clear_depth_cache
    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage

    job_id = "test_depth_json"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    storage.create(
        JobInfo(
            job_id=job_id,
            status=JobStatus.COMPLETED,
            mode="object_detection",
            queries=["person"],
            detector_name="yolo11",
            segmenter_name=None,
            input_video_path=video_path,
            output_video_path=None,
        )
    )

    try:
        # Stand-in for the real estimator so no model weights are loaded.
        class FakeDepthResult:
            def __init__(self, h, w):
                self.depth_map = np.ones((h, w), dtype=np.float32) * 5.0
                self.focal_length = 1.0

        class FakeEstimator:
            def predict(self, frame):
                h, w = frame.shape[:2]
                return FakeDepthResult(h, w)

        monkeypatch.setattr("inspection.depth._depth_estimator", FakeEstimator())

        # Router is imported only after the estimator has been patched.
        from inspection.router import router

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get(f"/inspect/depth/{job_id}/0?format=json")
        assert resp.status_code == 200

        data = resp.json()
        for key in (
            "width",
            "height",
            "min_depth",
            "max_depth",
            "data_b64",
            "depth_stats",
        ):
            assert key in data
        assert data["depth_stats"]["min_m"] == pytest.approx(5.0)
        assert data["depth_stats"]["max_m"] == pytest.approx(5.0)

        # Verify base64 decodes to correct size (4 bytes per float32 value).
        decoded = base64.b64decode(data["data_b64"])
        assert len(decoded) == data["width"] * data["height"] * 4
    finally:
        clear_depth_cache()
        storage.delete(job_id)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def test_depth_endpoint_colorized(tmp_path, monkeypatch):
    """GET /inspect/depth/{job_id}/{frame_idx}?format=colorized should return JPEG.

    Registers a completed job backed by a generated test video, replaces the
    module-level depth estimator with a fake that yields random depths scaled
    by 10.0, and checks that the response is served as ``image/jpeg`` and
    starts with the JPEG magic bytes.

    Cleanup runs in ``finally`` so a failing assertion does not leave the job
    or cached depth behind for later tests.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from inspection.depth import clear_depth_cache
    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage

    job_id = "test_depth_color"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    storage.create(
        JobInfo(
            job_id=job_id,
            status=JobStatus.COMPLETED,
            mode="object_detection",
            queries=["person"],
            detector_name="yolo11",
            segmenter_name=None,
            input_video_path=video_path,
            output_video_path=None,
        )
    )

    try:
        # Stand-in for the real estimator so no model weights are loaded.
        class FakeDepthResult:
            def __init__(self, h, w):
                self.depth_map = np.random.rand(h, w).astype(np.float32) * 10.0
                self.focal_length = 1.0

        class FakeEstimator:
            def predict(self, frame):
                h, w = frame.shape[:2]
                return FakeDepthResult(h, w)

        monkeypatch.setattr("inspection.depth._depth_estimator", FakeEstimator())

        # Router is imported only after the estimator has been patched.
        from inspection.router import router

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get(f"/inspect/depth/{job_id}/0?format=colorized")
        assert resp.status_code == 200
        assert resp.headers["content-type"] == "image/jpeg"
        assert resp.content[:2] == b"\xff\xd8"  # JPEG magic
    finally:
        clear_depth_cache()
        storage.delete(job_id)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def test_depth_endpoint_invalid_format(tmp_path, monkeypatch):
    """Invalid format should return 400.

    Registers a completed job backed by a generated test video and requests
    depth with ``format=invalid``. No estimator is patched here; the
    ``monkeypatch`` fixture is requested but not used by this test.

    The job is removed in ``finally`` so a failing assertion does not leave
    it registered for later tests.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage

    job_id = "test_depth_bad_fmt"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    storage.create(
        JobInfo(
            job_id=job_id,
            status=JobStatus.COMPLETED,
            mode="object_detection",
            queries=["person"],
            detector_name="yolo11",
            segmenter_name=None,
            input_video_path=video_path,
            output_video_path=None,
        )
    )

    try:
        from inspection.router import router

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get(f"/inspect/depth/{job_id}/0?format=invalid")
        assert resp.status_code == 400
    finally:
        storage.delete(job_id)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def test_depth_endpoint_job_not_found():
    """Non-existent job should return 404.

    Mounts the inspection router on a bare FastAPI app and requests depth for
    a job id that this test never registers.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from inspection.router import router

    api = FastAPI()
    api.include_router(router)

    response = TestClient(api).get("/inspect/depth/nonexistent/0?format=json")
    assert response.status_code == 404
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def test_depth_endpoint_with_track_id(tmp_path, monkeypatch):
    """Depth with track_id should crop to bbox.

    Registers a completed job backed by a generated test video, stores one
    track for frame 0 with bbox ``[10, 10, 30, 30]``, replaces the
    module-level depth estimator with a fake, and requests depth with
    ``track_id=T01``. The returned width and height must both be 20.

    Cleanup runs in ``finally`` so a failing assertion does not leave the job
    or cached depth behind for later tests.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from inspection.depth import clear_depth_cache
    from jobs.models import JobInfo, JobStatus
    from jobs.storage import get_job_storage, set_track_data

    job_id = "test_depth_track"
    video_path = _make_test_video(tmp_path)

    storage = get_job_storage()
    storage.create(
        JobInfo(
            job_id=job_id,
            status=JobStatus.COMPLETED,
            mode="object_detection",
            queries=["person"],
            detector_name="yolo11",
            segmenter_name=None,
            input_video_path=video_path,
            output_video_path=None,
        )
    )

    try:
        # Add track data for frame 0.
        # NOTE(review): the request below uses track_id "T01"; presumably the
        # endpoint maps that to instance_id 1 - confirm against the router.
        set_track_data(
            job_id,
            0,
            [
                {"instance_id": 1, "label": "person", "bbox": [10, 10, 30, 30]},
            ],
        )

        # Stand-in for the real estimator so no model weights are loaded.
        class FakeDepthResult:
            def __init__(self, h, w):
                self.depth_map = np.arange(h * w, dtype=np.float32).reshape(h, w)
                self.focal_length = 1.0

        class FakeEstimator:
            def predict(self, frame):
                h, w = frame.shape[:2]
                return FakeDepthResult(h, w)

        monkeypatch.setattr("inspection.depth._depth_estimator", FakeEstimator())

        # Router is imported only after the estimator has been patched.
        from inspection.router import router

        app = FastAPI()
        app.include_router(router)
        client = TestClient(app)

        resp = client.get(f"/inspect/depth/{job_id}/0?format=json&track_id=T01")
        assert resp.status_code == 200

        data = resp.json()
        # Cropped to bbox [10,10,30,30] => 20x20
        assert data["width"] == 20
        assert data["height"] == 20
    finally:
        clear_depth_cache()
        storage.delete(job_id)
|