Zhen Ye Claude Opus 4.6 (1M context) commited on
Commit
3d7acee
Β·
1 Parent(s): 1bdac0d

feat(inspection): add frame extraction and mask retrieval API endpoints

Browse files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Files changed (3) hide show
  1. app.py +2 -0
  2. inspection/router.py +228 -0
  3. jobs/storage.py +2 -2
app.py CHANGED
@@ -58,6 +58,7 @@ from jobs.storage import (
58
  )
59
  from models.segmenters.model_loader import get_segmenter_detector
60
  from pydantic import BaseModel
 
61
 
62
  logging.basicConfig(level=logging.INFO)
63
 
@@ -82,6 +83,7 @@ async def lifespan(_: FastAPI):
82
 
83
 
84
  app = FastAPI(title="Video Object Detection", lifespan=lifespan)
 
85
  app.add_middleware(
86
  CORSMiddleware,
87
  allow_origins=["*"],
 
58
  )
59
  from models.segmenters.model_loader import get_segmenter_detector
60
  from pydantic import BaseModel
61
+ from inspection.router import router as inspection_router
62
 
63
  logging.basicConfig(level=logging.INFO)
64
 
 
83
 
84
 
85
  app = FastAPI(title="Video Object Detection", lifespan=lifespan)
86
+ app.include_router(inspection_router)
87
  app.add_middleware(
88
  CORSMiddleware,
89
  allow_origins=["*"],
inspection/router.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI router for Object Deep-Inspection endpoints.
2
+
3
+ All endpoints are on-demand β€” they do not affect the main inference pipeline.
4
+ Endpoints are mounted at /inspect in app.py.
5
+ """
6
+
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ from fastapi import APIRouter, HTTPException, Query
12
+ from fastapi.responses import JSONResponse, Response
13
+
14
+ from jobs.storage import get_job_storage
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ router = APIRouter(prefix="/inspect", tags=["inspection"])
19
+
20
+
21
+ def _get_job_or_404(job_id: str):
22
+ """Retrieve a job from storage or raise 404."""
23
+ job = get_job_storage().get(job_id)
24
+ if not job:
25
+ raise HTTPException(status_code=404, detail="Job not found or expired.")
26
+ return job
27
+
28
+
29
+ def _validate_frame_idx(video_path: str, frame_idx: int) -> None:
30
+ """Raise 400 if frame_idx is out of range for the video."""
31
+ import cv2
32
+
33
+ cap = cv2.VideoCapture(video_path)
34
+ if not cap.isOpened():
35
+ raise HTTPException(status_code=404, detail="Input video not found.")
36
+ total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
37
+ cap.release()
38
+ if frame_idx < 0 or frame_idx >= total:
39
+ raise HTTPException(
40
+ status_code=400,
41
+ detail=f"frame_idx {frame_idx} out of range [0, {total}).",
42
+ )
43
+
44
+
45
+ # ── Frame extraction ──────────────────────────────────────────────
46
+
47
+ @router.get("/frame/{job_id}/{frame_idx}")
48
+ async def get_frame(
49
+ job_id: str,
50
+ frame_idx: int,
51
+ track_id: Optional[str] = Query(None, description="Track ID to crop to, e.g. 'T01'"),
52
+ padding: float = Query(0.15, ge=0.0, le=2.0, description="Padding ratio around bbox"),
53
+ max_size: int = Query(1920, ge=64, le=4096, description="Max dimension for output"),
54
+ ):
55
+ """Extract a raw frame from the input video, optionally cropped to a track.
56
+
57
+ Returns a JPEG image. If track_id is provided and found in the frame's
58
+ track data, the image is cropped to that track's bounding box with
59
+ the specified padding ratio.
60
+ """
61
+ import asyncio
62
+ import cv2
63
+
64
+ from inspection.frames import extract_frame, crop_frame, frame_to_jpeg
65
+
66
+ job = _get_job_or_404(job_id)
67
+ input_path = job.input_video_path
68
+ if not input_path or not Path(input_path).exists():
69
+ raise HTTPException(status_code=404, detail="Input video not found on disk.")
70
+
71
+ _validate_frame_idx(input_path, frame_idx)
72
+
73
+ # Extract frame in thread pool (cv2 seek can block)
74
+ frame = await asyncio.to_thread(extract_frame, input_path, frame_idx)
75
+
76
+ # Optionally crop to track bbox
77
+ if track_id is not None:
78
+ from jobs.storage import get_track_data
79
+
80
+ tracks = get_track_data(job_id, frame_idx)
81
+ target = None
82
+ # Parse "T01" -> 1 for instance_id matching
83
+ instance_id = int(track_id.replace("T", "")) if track_id.startswith("T") else int(track_id)
84
+ for t in tracks:
85
+ tid = t.get("instance_id") or t.get("track_id")
86
+ if tid == instance_id or tid == track_id:
87
+ target = t
88
+ break
89
+ if target and "bbox" in target:
90
+ frame = crop_frame(frame, target["bbox"], padding=padding)
91
+ else:
92
+ raise HTTPException(
93
+ status_code=404,
94
+ detail=f"Track {track_id} not found in frame {frame_idx}.",
95
+ )
96
+
97
+ # Resize if larger than max_size
98
+ h, w = frame.shape[:2]
99
+ if max(h, w) > max_size:
100
+ scale = max_size / max(h, w)
101
+ new_w = int(w * scale)
102
+ new_h = int(h * scale)
103
+ frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
104
+
105
+ jpeg_bytes = frame_to_jpeg(frame, quality=90)
106
+ return Response(content=jpeg_bytes, media_type="image/jpeg")
107
+
108
+
109
+ # ── Mask retrieval ────────────────────────────────────────────────
110
+
111
+ @router.get("/mask/{job_id}/{frame_idx}/{track_id}")
112
+ async def get_mask(
113
+ job_id: str,
114
+ frame_idx: int,
115
+ track_id: str,
116
+ format: str = Query("json", description="Response format: 'json' or 'png'"),
117
+ ):
118
+ """Get the segmentation mask for a specific object at a specific frame.
119
+
120
+ Only available for jobs run in segmentation mode.
121
+ track_id is a string like "T01".
122
+
123
+ Returns either:
124
+ - JSON with RLE-encoded mask, bbox, area, label, width, height, color, mask_format (default)
125
+ - PNG image of the mask (white on black) if format=png
126
+ """
127
+ from jobs.storage import get_mask_data, get_track_data
128
+ from inspection.masks import mask_area, rle_decode, mask_to_png_bytes
129
+
130
+ job = _get_job_or_404(job_id)
131
+ if job.mode != "segmentation":
132
+ raise HTTPException(
133
+ status_code=400,
134
+ detail="Mask data is only available for segmentation mode jobs.",
135
+ )
136
+
137
+ # Parse track_id: accept "T01" or "1", store as int internally
138
+ instance_id = int(track_id.replace("T", "")) if isinstance(track_id, str) and track_id.startswith("T") else int(track_id)
139
+
140
+ rle = get_mask_data(job_id, frame_idx, instance_id)
141
+ if rle is None:
142
+ raise HTTPException(
143
+ status_code=404,
144
+ detail=f"No mask found for track {track_id} at frame {frame_idx}.",
145
+ )
146
+
147
+ if format == "png":
148
+ mask = rle_decode(rle)
149
+ png_bytes = mask_to_png_bytes(mask)
150
+ return Response(content=png_bytes, media_type="image/png")
151
+
152
+ # JSON response: include track metadata per unified contract
153
+ label = ""
154
+ bbox = None
155
+ tracks = get_track_data(job_id, frame_idx)
156
+ for t in tracks:
157
+ tid = t.get("instance_id")
158
+ if tid == instance_id:
159
+ label = t.get("label", "")
160
+ bbox = t.get("bbox")
161
+ break
162
+
163
+ h, w = rle["size"]
164
+
165
+ # Deterministic color per track ID
166
+ TRACK_COLORS = [
167
+ [255, 0, 128], [0, 255, 128], [128, 0, 255], [255, 128, 0],
168
+ [0, 128, 255], [128, 255, 0], [255, 0, 0], [0, 255, 0],
169
+ [0, 0, 255], [255, 255, 0], [255, 0, 255], [0, 255, 255],
170
+ ]
171
+ color = TRACK_COLORS[instance_id % len(TRACK_COLORS)]
172
+
173
+ return JSONResponse({
174
+ "track_id": track_id,
175
+ "frame_idx": frame_idx,
176
+ "label": label,
177
+ "width": w,
178
+ "height": h,
179
+ "mask_format": "rle",
180
+ "rle": rle,
181
+ "bbox": bbox,
182
+ "area": mask_area(rle),
183
+ "color": color,
184
+ })
185
+
186
+
187
+ @router.get("/masks/{job_id}/{frame_idx}")
188
+ async def get_all_masks(job_id: str, frame_idx: int):
189
+ """Get all segmentation masks for a frame.
190
+
191
+ Returns a list of {track_id, label, rle, bbox, area} for every
192
+ object detected in the given frame.
193
+ """
194
+ from jobs.storage import get_all_masks_for_frame, get_track_data
195
+ from inspection.masks import mask_area
196
+
197
+ job = _get_job_or_404(job_id)
198
+ if job.mode != "segmentation":
199
+ raise HTTPException(
200
+ status_code=400,
201
+ detail="Mask data is only available for segmentation mode jobs.",
202
+ )
203
+
204
+ masks = get_all_masks_for_frame(job_id, frame_idx)
205
+ if not masks:
206
+ return JSONResponse([])
207
+
208
+ # Enrich with track metadata
209
+ tracks = get_track_data(job_id, frame_idx)
210
+ track_lookup = {}
211
+ for t in tracks:
212
+ tid = t.get("instance_id")
213
+ if tid is not None:
214
+ track_lookup[tid] = t
215
+
216
+ results = []
217
+ for tid, rle in masks.items():
218
+ t = track_lookup.get(tid, {})
219
+ results.append({
220
+ "track_id": tid,
221
+ "frame_idx": frame_idx,
222
+ "label": t.get("label", ""),
223
+ "rle": rle,
224
+ "bbox": t.get("bbox"),
225
+ "area": mask_area(rle),
226
+ })
227
+
228
+ return JSONResponse(results)
jobs/storage.py CHANGED
@@ -87,7 +87,7 @@ class JobStorage:
87
  key = f"{frame_idx}:{track_id}"
88
  self._mask_data[job_id][key] = rle
89
 
90
- def get_mask_data(self, job_id: str, frame_idx: int, track_id: int) -> dict | None:
91
  """Retrieve RLE mask for a specific object at a specific frame."""
92
  with self._lock:
93
  key = f"{frame_idx}:{track_id}"
@@ -162,7 +162,7 @@ def get_latest_frame(job_id: str):
162
  def set_mask_data(job_id: str, frame_idx: int, track_id: int, rle: dict) -> None:
163
  get_job_storage().set_mask_data(job_id, frame_idx, track_id, rle)
164
 
165
- def get_mask_data(job_id: str, frame_idx: int, track_id: int) -> dict | None:
166
  return get_job_storage().get_mask_data(job_id, frame_idx, track_id)
167
 
168
  def get_all_masks_for_frame(job_id: str, frame_idx: int) -> dict:
 
87
  key = f"{frame_idx}:{track_id}"
88
  self._mask_data[job_id][key] = rle
89
 
90
+ def get_mask_data(self, job_id: str, frame_idx: int, track_id: int) -> Optional[dict]:
91
  """Retrieve RLE mask for a specific object at a specific frame."""
92
  with self._lock:
93
  key = f"{frame_idx}:{track_id}"
 
162
  def set_mask_data(job_id: str, frame_idx: int, track_id: int, rle: dict) -> None:
163
  get_job_storage().set_mask_data(job_id, frame_idx, track_id, rle)
164
 
165
+ def get_mask_data(job_id: str, frame_idx: int, track_id: int) -> Optional[dict]:
166
  return get_job_storage().get_mask_data(job_id, frame_idx, track_id)
167
 
168
  def get_all_masks_for_frame(job_id: str, frame_idx: int) -> dict: