Zhen Ye commited on
Commit
52536ca
·
1 Parent(s): 537aca9

added async first frame/video detection

Browse files
Files changed (7) hide show
  1. app.py +173 -2
  2. demo.html +64 -14
  3. inference.py +61 -6
  4. jobs/__init__.py +1 -0
  5. jobs/background.py +48 -0
  6. jobs/models.py +27 -0
  7. jobs/storage.py +72 -0
app.py CHANGED
@@ -1,18 +1,49 @@
 
1
  import logging
2
  import os
 
3
  import tempfile
 
 
 
4
  from pathlib import Path
5
 
 
6
  from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
9
  import uvicorn
10
 
11
- from inference import run_inference, run_segmentation
 
 
 
 
 
 
 
 
 
12
 
13
  logging.basicConfig(level=logging.INFO)
14
 
15
- app = FastAPI(title="Video Object Detection")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  app.add_middleware(
17
  CORSMiddleware,
18
  allow_origins=["*"],
@@ -36,6 +67,13 @@ def _save_upload_to_tmp(upload: UploadFile) -> str:
36
  return path
37
 
38
 
 
 
 
 
 
 
 
39
  def _safe_delete(path: str) -> None:
40
  """Safely delete a file, ignoring errors."""
41
  try:
@@ -54,6 +92,14 @@ def _schedule_cleanup(background_tasks: BackgroundTasks, path: str) -> None:
54
  background_tasks.add_task(_cleanup)
55
 
56
 
 
 
 
 
 
 
 
 
57
  @app.get("/", response_class=HTMLResponse)
58
  async def demo_page() -> str:
59
  """Serve the demo page."""
@@ -198,5 +244,130 @@ async def detect_endpoint(
198
  return response
199
 
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  if __name__ == "__main__":
202
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
 
1
+ import asyncio
2
  import logging
3
  import os
4
+ import shutil
5
  import tempfile
6
+ import uuid
7
+ from contextlib import asynccontextmanager
8
+ from datetime import timedelta
9
  from pathlib import Path
10
 
11
+ import cv2
12
  from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
15
  import uvicorn
16
 
17
+ from inference import process_first_frame, run_inference, run_segmentation
18
+ from jobs.background import process_video_async
19
+ from jobs.models import JobInfo, JobStatus
20
+ from jobs.storage import (
21
+ get_first_frame_path,
22
+ get_input_video_path,
23
+ get_job_directory,
24
+ get_job_storage,
25
+ get_output_video_path,
26
+ )
27
 
28
  logging.basicConfig(level=logging.INFO)
29
 
30
+
31
+ async def _periodic_cleanup() -> None:
32
+ while True:
33
+ await asyncio.sleep(600)
34
+ get_job_storage().cleanup_expired(timedelta(hours=1))
35
+
36
+
37
+ @asynccontextmanager
38
+ async def lifespan(_: FastAPI):
39
+ cleanup_task = asyncio.create_task(_periodic_cleanup())
40
+ try:
41
+ yield
42
+ finally:
43
+ cleanup_task.cancel()
44
+
45
+
46
+ app = FastAPI(title="Video Object Detection", lifespan=lifespan)
47
  app.add_middleware(
48
  CORSMiddleware,
49
  allow_origins=["*"],
 
67
  return path
68
 
69
 
70
+ def _save_upload_to_path(upload: UploadFile, path: Path) -> None:
71
+ path.parent.mkdir(parents=True, exist_ok=True)
72
+ with open(path, "wb") as buffer:
73
+ data = upload.file.read()
74
+ buffer.write(data)
75
+
76
+
77
  def _safe_delete(path: str) -> None:
78
  """Safely delete a file, ignoring errors."""
79
  try:
 
92
  background_tasks.add_task(_cleanup)
93
 
94
 
95
+ def _default_queries_for_mode(mode: str) -> list[str]:
96
+ if mode == "segmentation":
97
+ return ["object"]
98
+ if mode == "drone_detection":
99
+ return ["drone"]
100
+ return ["person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane"]
101
+
102
+
103
  @app.get("/", response_class=HTMLResponse)
104
  async def demo_page() -> str:
105
  """Serve the demo page."""
 
244
  return response
245
 
246
 
247
+ @app.post("/detect/async")
248
+ async def detect_async_endpoint(
249
+ video: UploadFile = File(...),
250
+ mode: str = Form(...),
251
+ queries: str = Form(""),
252
+ detector: str = Form("hf_yolov8"),
253
+ segmenter: str = Form("sam3"),
254
+ ):
255
+ if mode not in VALID_MODES:
256
+ raise HTTPException(
257
+ status_code=400,
258
+ detail=f"Invalid mode '{mode}'. Must be one of: {', '.join(VALID_MODES)}",
259
+ )
260
+
261
+ if video is None:
262
+ raise HTTPException(status_code=400, detail="Video file is required.")
263
+
264
+ job_id = uuid.uuid4().hex
265
+ job_dir = get_job_directory(job_id)
266
+ input_path = get_input_video_path(job_id)
267
+ output_path = get_output_video_path(job_id)
268
+ first_frame_path = get_first_frame_path(job_id)
269
+
270
+ try:
271
+ _save_upload_to_path(video, input_path)
272
+ except Exception:
273
+ logging.exception("Failed to save uploaded file.")
274
+ raise HTTPException(status_code=500, detail="Failed to save uploaded video.")
275
+ finally:
276
+ await video.close()
277
+
278
+ query_list = [q.strip() for q in queries.split(",") if q.strip()]
279
+ if not query_list:
280
+ query_list = _default_queries_for_mode(mode)
281
+
282
+ detector_name = detector
283
+ if mode == "drone_detection":
284
+ detector_name = "drone_yolo"
285
+
286
+ try:
287
+ processed_frame, detections = process_first_frame(
288
+ str(input_path),
289
+ query_list,
290
+ mode=mode,
291
+ detector_name=detector_name,
292
+ segmenter_name=segmenter,
293
+ )
294
+ cv2.imwrite(str(first_frame_path), processed_frame)
295
+ except Exception:
296
+ logging.exception("First-frame processing failed.")
297
+ shutil.rmtree(job_dir, ignore_errors=True)
298
+ raise HTTPException(status_code=500, detail="Failed to process first frame.")
299
+
300
+ job = JobInfo(
301
+ job_id=job_id,
302
+ status=JobStatus.PROCESSING,
303
+ mode=mode,
304
+ queries=query_list,
305
+ detector_name=detector_name,
306
+ segmenter_name=segmenter,
307
+ input_video_path=str(input_path),
308
+ output_video_path=str(output_path),
309
+ first_frame_path=str(first_frame_path),
310
+ first_frame_detections=detections,
311
+ )
312
+ get_job_storage().create(job)
313
+ asyncio.create_task(process_video_async(job_id))
314
+
315
+ return {
316
+ "job_id": job_id,
317
+ "first_frame_url": f"/detect/first-frame/{job_id}",
318
+ "status_url": f"/detect/status/{job_id}",
319
+ "video_url": f"/detect/video/{job_id}",
320
+ "status": job.status.value,
321
+ "first_frame_detections": detections,
322
+ }
323
+
324
+
325
+ @app.get("/detect/status/{job_id}")
326
+ async def detect_status(job_id: str):
327
+ job = get_job_storage().get(job_id)
328
+ if not job:
329
+ raise HTTPException(status_code=404, detail="Job not found or expired.")
330
+ return {
331
+ "job_id": job.job_id,
332
+ "status": job.status.value,
333
+ "created_at": job.created_at.isoformat(),
334
+ "completed_at": job.completed_at.isoformat() if job.completed_at else None,
335
+ "error": job.error,
336
+ }
337
+
338
+
339
+ @app.get("/detect/first-frame/{job_id}")
340
+ async def detect_first_frame(job_id: str):
341
+ job = get_job_storage().get(job_id)
342
+ if not job or not Path(job.first_frame_path).exists():
343
+ raise HTTPException(status_code=404, detail="First frame not found.")
344
+ return FileResponse(
345
+ path=job.first_frame_path,
346
+ media_type="image/jpeg",
347
+ filename="first_frame.jpg",
348
+ )
349
+
350
+
351
+ @app.get("/detect/video/{job_id}")
352
+ async def detect_video(job_id: str):
353
+ job = get_job_storage().get(job_id)
354
+ if not job:
355
+ raise HTTPException(status_code=404, detail="Job not found or expired.")
356
+ if job.status == JobStatus.FAILED:
357
+ raise HTTPException(status_code=500, detail=f"Job failed: {job.error}")
358
+ if job.status == JobStatus.PROCESSING:
359
+ return JSONResponse(
360
+ status_code=202,
361
+ content={"detail": "Video still processing", "status": "processing"},
362
+ )
363
+ if not job.output_video_path or not Path(job.output_video_path).exists():
364
+ raise HTTPException(status_code=404, detail="Video file not found.")
365
+ return FileResponse(
366
+ path=job.output_video_path,
367
+ media_type="video/mp4",
368
+ filename="processed.mp4",
369
+ )
370
+
371
+
372
  if __name__ == "__main__":
373
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
demo.html CHANGED
@@ -231,6 +231,13 @@
231
  background: #000;
232
  }
233
 
 
 
 
 
 
 
 
234
  .download-btn {
235
  margin-top: 12px;
236
  padding: 10px 16px;
@@ -381,6 +388,12 @@
381
  <div class="section hidden" id="resultsSection">
382
  <div class="section-title">Results</div>
383
  <div class="results-grid">
 
 
 
 
 
 
384
  <div class="video-card">
385
  <div class="video-card-header">Original Video</div>
386
  <div class="video-card-body">
@@ -421,7 +434,9 @@
421
  const resultsSection = document.getElementById('resultsSection');
422
  const originalVideo = document.getElementById('originalVideo');
423
  const processedVideo = document.getElementById('processedVideo');
 
424
  const downloadBtn = document.getElementById('downloadBtn');
 
425
  // Mode selection handler
426
  modeCards.forEach(card => {
427
  card.addEventListener('click', (e) => {
@@ -483,6 +498,13 @@
483
  processBtn.disabled = true;
484
  loading.classList.add('show');
485
  resultsSection.classList.add('hidden');
 
 
 
 
 
 
 
486
 
487
  // Prepare form data
488
  const formData = new FormData();
@@ -493,27 +515,55 @@
493
  formData.append('segmenter', document.getElementById('segmenter').value);
494
 
495
  try {
496
- const response = await fetch('/detect', {
497
  method: 'POST',
498
  body: formData
499
  });
500
 
501
- if (response.ok) {
502
- const contentType = response.headers.get('content-type') || '';
503
- if (contentType.includes('application/json')) {
504
- const data = await response.json();
505
- alert(data.message || 'Request completed.');
506
- return;
507
- }
508
- const blob = await response.blob();
509
- const videoUrl = URL.createObjectURL(blob);
510
- processedVideo.src = videoUrl;
511
- downloadBtn.href = videoUrl;
512
- resultsSection.classList.remove('hidden');
513
- } else {
514
  const error = await response.json();
515
  alert(`Error: ${error.detail || error.error || 'Processing failed'}`);
 
516
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  } catch (error) {
518
  console.error('Error:', error);
519
  alert('Network error: ' + error.message);
 
231
  background: #000;
232
  }
233
 
234
+ .frame-preview {
235
+ width: 100%;
236
+ border-radius: 8px;
237
+ background: #f3f4f6;
238
+ display: block;
239
+ }
240
+
241
  .download-btn {
242
  margin-top: 12px;
243
  padding: 10px 16px;
 
388
  <div class="section hidden" id="resultsSection">
389
  <div class="section-title">Results</div>
390
  <div class="results-grid">
391
+ <div class="video-card">
392
+ <div class="video-card-header">First Frame</div>
393
+ <div class="video-card-body">
394
+ <img id="firstFrameImage" class="frame-preview" alt="First frame preview">
395
+ </div>
396
+ </div>
397
  <div class="video-card">
398
  <div class="video-card-header">Original Video</div>
399
  <div class="video-card-body">
 
434
  const resultsSection = document.getElementById('resultsSection');
435
  const originalVideo = document.getElementById('originalVideo');
436
  const processedVideo = document.getElementById('processedVideo');
437
+ const firstFrameImage = document.getElementById('firstFrameImage');
438
  const downloadBtn = document.getElementById('downloadBtn');
439
+ let statusPoller = null;
440
  // Mode selection handler
441
  modeCards.forEach(card => {
442
  card.addEventListener('click', (e) => {
 
498
  processBtn.disabled = true;
499
  loading.classList.add('show');
500
  resultsSection.classList.add('hidden');
501
+ if (statusPoller) {
502
+ clearInterval(statusPoller);
503
+ statusPoller = null;
504
+ }
505
+ firstFrameImage.removeAttribute('src');
506
+ processedVideo.removeAttribute('src');
507
+ downloadBtn.removeAttribute('href');
508
 
509
  // Prepare form data
510
  const formData = new FormData();
 
515
  formData.append('segmenter', document.getElementById('segmenter').value);
516
 
517
  try {
518
+ const response = await fetch('/detect/async', {
519
  method: 'POST',
520
  body: formData
521
  });
522
 
523
+ if (!response.ok) {
 
 
 
 
 
 
 
 
 
 
 
 
524
  const error = await response.json();
525
  alert(`Error: ${error.detail || error.error || 'Processing failed'}`);
526
+ return;
527
  }
528
+
529
+ const data = await response.json();
530
+ firstFrameImage.src = `${data.first_frame_url}?t=${Date.now()}`;
531
+ resultsSection.classList.remove('hidden');
532
+
533
+ statusPoller = setInterval(async () => {
534
+ try {
535
+ const statusResponse = await fetch(data.status_url);
536
+ if (!statusResponse.ok) {
537
+ clearInterval(statusPoller);
538
+ statusPoller = null;
539
+ alert('Job expired. Please re-upload the video.');
540
+ return;
541
+ }
542
+ const statusData = await statusResponse.json();
543
+ if (statusData.status === 'completed') {
544
+ clearInterval(statusPoller);
545
+ statusPoller = null;
546
+ const videoResponse = await fetch(data.video_url);
547
+ if (!videoResponse.ok) {
548
+ alert('Failed to fetch processed video.');
549
+ return;
550
+ }
551
+ const blob = await videoResponse.blob();
552
+ const videoUrl = URL.createObjectURL(blob);
553
+ processedVideo.src = videoUrl;
554
+ downloadBtn.href = videoUrl;
555
+ } else if (statusData.status === 'failed') {
556
+ clearInterval(statusPoller);
557
+ statusPoller = null;
558
+ alert(statusData.error || 'Processing failed.');
559
+ }
560
+ } catch (pollError) {
561
+ clearInterval(statusPoller);
562
+ statusPoller = null;
563
+ console.error('Polling error:', pollError);
564
+ alert('Polling error: ' + pollError.message);
565
+ }
566
+ }, 2000);
567
  } catch (error) {
568
  console.error('Error:', error);
569
  alert('Network error: ' + error.message);
inference.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
- from typing import Any, Dict, List, Optional, Sequence
 
3
 
4
  import cv2
5
  import numpy as np
@@ -71,6 +72,20 @@ def _build_detection_records(
71
  return detections
72
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def infer_frame(
75
  frame: np.ndarray,
76
  queries: Sequence[str],
@@ -79,10 +94,12 @@ def infer_frame(
79
  detector = load_detector(detector_name)
80
  text_queries = list(queries) or ["object"]
81
  try:
82
- result = detector.predict(frame, text_queries)
83
- detections = _build_detection_records(
84
- result.boxes, result.scores, result.labels, text_queries, result.label_names
85
- )
 
 
86
  except Exception:
87
  logging.exception("Inference failed for queries %s", text_queries)
88
  raise
@@ -95,10 +112,48 @@ def infer_segmentation_frame(
95
  segmenter_name: Optional[str] = None,
96
  ) -> tuple[np.ndarray, Any]:
97
  segmenter = load_segmenter(segmenter_name)
98
- result = segmenter.predict(frame, text_prompts=text_queries)
 
 
99
  return draw_masks(frame, result.masks), result
100
 
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def run_inference(
103
  input_video_path: str,
104
  output_video_path: str,
 
1
  import logging
2
+ from threading import RLock
3
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
4
 
5
  import cv2
6
  import numpy as np
 
72
  return detections
73
 
74
 
75
+ _MODEL_LOCKS: Dict[str, RLock] = {}
76
+ _MODEL_LOCKS_GUARD = RLock()
77
+
78
+
79
+ def _get_model_lock(kind: str, name: str) -> RLock:
80
+ key = f"{kind}:{name}"
81
+ with _MODEL_LOCKS_GUARD:
82
+ lock = _MODEL_LOCKS.get(key)
83
+ if lock is None:
84
+ lock = RLock()
85
+ _MODEL_LOCKS[key] = lock
86
+ return lock
87
+
88
+
89
  def infer_frame(
90
  frame: np.ndarray,
91
  queries: Sequence[str],
 
94
  detector = load_detector(detector_name)
95
  text_queries = list(queries) or ["object"]
96
  try:
97
+ lock = _get_model_lock("detector", detector.name)
98
+ with lock:
99
+ result = detector.predict(frame, text_queries)
100
+ detections = _build_detection_records(
101
+ result.boxes, result.scores, result.labels, text_queries, result.label_names
102
+ )
103
  except Exception:
104
  logging.exception("Inference failed for queries %s", text_queries)
105
  raise
 
112
  segmenter_name: Optional[str] = None,
113
  ) -> tuple[np.ndarray, Any]:
114
  segmenter = load_segmenter(segmenter_name)
115
+ lock = _get_model_lock("segmenter", segmenter.name)
116
+ with lock:
117
+ result = segmenter.predict(frame, text_prompts=text_queries)
118
  return draw_masks(frame, result.masks), result
119
 
120
 
121
+ def extract_first_frame(video_path: str) -> Tuple[np.ndarray, float, int, int]:
122
+ cap = cv2.VideoCapture(video_path)
123
+ if not cap.isOpened():
124
+ raise ValueError("Unable to open video.")
125
+
126
+ fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
127
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
128
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
129
+ success, frame = cap.read()
130
+ cap.release()
131
+
132
+ if not success or frame is None:
133
+ raise ValueError("Video decode produced zero frames.")
134
+
135
+ return frame, fps, width, height
136
+
137
+
138
+ def process_first_frame(
139
+ video_path: str,
140
+ queries: List[str],
141
+ mode: str,
142
+ detector_name: Optional[str] = None,
143
+ segmenter_name: Optional[str] = None,
144
+ ) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
145
+ frame, _, _, _ = extract_first_frame(video_path)
146
+ if mode == "segmentation":
147
+ processed, _ = infer_segmentation_frame(
148
+ frame, text_queries=queries, segmenter_name=segmenter_name
149
+ )
150
+ return processed, []
151
+ processed, detections = infer_frame(
152
+ frame, queries, detector_name=detector_name
153
+ )
154
+ return processed, detections
155
+
156
+
157
  def run_inference(
158
  input_video_path: str,
159
  output_video_path: str,
jobs/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Job management package for async detection."""
jobs/background.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from datetime import datetime
4
+
5
+ from jobs.models import JobStatus
6
+ from jobs.storage import get_job_storage
7
+ from inference import run_inference, run_segmentation
8
+
9
+
10
+ async def process_video_async(job_id: str) -> None:
11
+ storage = get_job_storage()
12
+ job = storage.get(job_id)
13
+ if not job:
14
+ return
15
+
16
+ try:
17
+ if job.mode == "segmentation":
18
+ output_path = await asyncio.to_thread(
19
+ run_segmentation,
20
+ job.input_video_path,
21
+ job.output_video_path,
22
+ job.queries,
23
+ None,
24
+ job.segmenter_name,
25
+ )
26
+ else:
27
+ output_path = await asyncio.to_thread(
28
+ run_inference,
29
+ job.input_video_path,
30
+ job.output_video_path,
31
+ job.queries,
32
+ None,
33
+ job.detector_name,
34
+ )
35
+ storage.update(
36
+ job_id,
37
+ status=JobStatus.COMPLETED,
38
+ completed_at=datetime.utcnow(),
39
+ output_video_path=output_path,
40
+ )
41
+ except Exception as exc:
42
+ logging.exception("Background processing failed for job %s", job_id)
43
+ storage.update(
44
+ job_id,
45
+ status=JobStatus.FAILED,
46
+ completed_at=datetime.utcnow(),
47
+ error=str(exc),
48
+ )
jobs/models.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from datetime import datetime
3
+ from enum import Enum
4
+ from typing import Any, Dict, List, Optional
5
+
6
+
7
+ class JobStatus(str, Enum):
8
+ PROCESSING = "processing"
9
+ COMPLETED = "completed"
10
+ FAILED = "failed"
11
+
12
+
13
+ @dataclass
14
+ class JobInfo:
15
+ job_id: str
16
+ status: JobStatus
17
+ mode: str
18
+ queries: List[str]
19
+ detector_name: Optional[str]
20
+ segmenter_name: Optional[str]
21
+ input_video_path: str
22
+ output_video_path: Optional[str]
23
+ first_frame_path: str
24
+ created_at: datetime = field(default_factory=datetime.utcnow)
25
+ completed_at: Optional[datetime] = None
26
+ error: Optional[str] = None
27
+ first_frame_detections: List[Dict[str, Any]] = field(default_factory=list)
jobs/storage.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from datetime import datetime, timedelta
3
+ from pathlib import Path
4
+ from threading import RLock
5
+ from typing import Dict, Optional
6
+
7
+ from jobs.models import JobInfo, JobStatus
8
+
9
+ _BASE_DIR = Path("/tmp/detection_jobs")
10
+
11
+
12
+ def get_job_directory(job_id: str) -> Path:
13
+ return _BASE_DIR / job_id
14
+
15
+
16
+ def get_input_video_path(job_id: str) -> Path:
17
+ return get_job_directory(job_id) / "input.mp4"
18
+
19
+
20
+ def get_output_video_path(job_id: str) -> Path:
21
+ return get_job_directory(job_id) / "output.mp4"
22
+
23
+
24
+ def get_first_frame_path(job_id: str) -> Path:
25
+ return get_job_directory(job_id) / "first_frame.jpg"
26
+
27
+
28
+ class JobStorage:
29
+ def __init__(self) -> None:
30
+ self._jobs: Dict[str, JobInfo] = {}
31
+ self._lock = RLock()
32
+
33
+ def create(self, job: JobInfo) -> None:
34
+ with self._lock:
35
+ self._jobs[job.job_id] = job
36
+
37
+ def get(self, job_id: str) -> Optional[JobInfo]:
38
+ with self._lock:
39
+ return self._jobs.get(job_id)
40
+
41
+ def update(self, job_id: str, **updates) -> None:
42
+ with self._lock:
43
+ job = self._jobs.get(job_id)
44
+ if not job:
45
+ return
46
+ for key, value in updates.items():
47
+ setattr(job, key, value)
48
+
49
+ def delete(self, job_id: str) -> None:
50
+ with self._lock:
51
+ self._jobs.pop(job_id, None)
52
+ shutil.rmtree(get_job_directory(job_id), ignore_errors=True)
53
+
54
+ def cleanup_expired(self, max_age: timedelta) -> None:
55
+ cutoff = datetime.utcnow() - max_age
56
+ to_delete = []
57
+ with self._lock:
58
+ for job_id, job in self._jobs.items():
59
+ if job.status in {JobStatus.COMPLETED, JobStatus.FAILED} and job.created_at < cutoff:
60
+ to_delete.append(job_id)
61
+ for job_id in to_delete:
62
+ self.delete(job_id)
63
+
64
+
65
+ _STORAGE: Optional[JobStorage] = None
66
+
67
+
68
+ def get_job_storage() -> JobStorage:
69
+ global _STORAGE
70
+ if _STORAGE is None:
71
+ _STORAGE = JobStorage()
72
+ return _STORAGE