Zhen Ye committed on
Commit
b8fe2b6
·
1 Parent(s): 9803004

added apple depth pro

Browse files
app.py CHANGED
@@ -19,6 +19,8 @@ from inference import process_first_frame, run_inference, run_segmentation
19
  from jobs.background import process_video_async
20
  from jobs.models import JobInfo, JobStatus
21
  from jobs.storage import (
 
 
22
  get_first_frame_path,
23
  get_input_video_path,
24
  get_job_directory,
@@ -272,6 +274,8 @@ async def detect_async_endpoint(
272
  input_path = get_input_video_path(job_id)
273
  output_path = get_output_video_path(job_id)
274
  first_frame_path = get_first_frame_path(job_id)
 
 
275
 
276
  try:
277
  _save_upload_to_path(video, input_path)
@@ -314,6 +318,9 @@ async def detect_async_endpoint(
314
  output_video_path=str(output_path),
315
  first_frame_path=str(first_frame_path),
316
  first_frame_detections=detections,
 
 
 
317
  )
318
  get_job_storage().create(job)
319
  asyncio.create_task(process_video_async(job_id))
@@ -321,8 +328,10 @@ async def detect_async_endpoint(
321
  return {
322
  "job_id": job_id,
323
  "first_frame_url": f"/detect/first-frame/{job_id}",
 
324
  "status_url": f"/detect/status/{job_id}",
325
  "video_url": f"/detect/video/{job_id}",
 
326
  "status": job.status.value,
327
  "first_frame_detections": detections,
328
  }
@@ -396,5 +405,54 @@ async def detect_video(job_id: str):
396
  )
397
 
398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  if __name__ == "__main__":
400
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
 
19
  from jobs.background import process_video_async
20
  from jobs.models import JobInfo, JobStatus
21
  from jobs.storage import (
22
+ get_depth_output_path,
23
+ get_first_frame_depth_path,
24
  get_first_frame_path,
25
  get_input_video_path,
26
  get_job_directory,
 
274
  input_path = get_input_video_path(job_id)
275
  output_path = get_output_video_path(job_id)
276
  first_frame_path = get_first_frame_path(job_id)
277
+ depth_output_path = get_depth_output_path(job_id)
278
+ first_frame_depth_path = get_first_frame_depth_path(job_id)
279
 
280
  try:
281
  _save_upload_to_path(video, input_path)
 
318
  output_video_path=str(output_path),
319
  first_frame_path=str(first_frame_path),
320
  first_frame_detections=detections,
321
+ depth_estimator_name="depth_pro",
322
+ depth_output_path=str(depth_output_path),
323
+ first_frame_depth_path=str(first_frame_depth_path),
324
  )
325
  get_job_storage().create(job)
326
  asyncio.create_task(process_video_async(job_id))
 
328
  return {
329
  "job_id": job_id,
330
  "first_frame_url": f"/detect/first-frame/{job_id}",
331
+ "first_frame_depth_url": f"/detect/first-frame-depth/{job_id}",
332
  "status_url": f"/detect/status/{job_id}",
333
  "video_url": f"/detect/video/{job_id}",
334
+ "depth_video_url": f"/detect/depth-video/{job_id}",
335
  "status": job.status.value,
336
  "first_frame_detections": detections,
337
  }
 
405
  )
406
 
407
 
408
@app.get("/detect/depth-video/{job_id}")
async def detect_depth_video(job_id: str):
    """Serve the rendered depth-estimation video for a job.

    Guard order: unknown job -> missing depth output -> terminal/in-flight
    status -> file existence. While the job is processing, a 202 JSON body
    is returned so clients can keep polling.
    """
    job = get_job_storage().get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found or expired.")

    if not job.depth_output_path:
        # Depth may have failed even though the job as a whole completed.
        if job.partial_success and job.depth_error:
            raise HTTPException(status_code=404, detail=f"Depth unavailable: {job.depth_error}")
        raise HTTPException(status_code=404, detail="No depth video for this job.")

    status = job.status
    if status == JobStatus.FAILED:
        raise HTTPException(status_code=500, detail=f"Job failed: {job.error}")
    if status == JobStatus.CANCELLED:
        raise HTTPException(status_code=410, detail="Job was cancelled")
    if status == JobStatus.PROCESSING:
        # Not an error: the client should retry later.
        return JSONResponse(
            status_code=202,
            content={"detail": "Video still processing", "status": "processing"},
        )

    depth_file = Path(job.depth_output_path)
    if not depth_file.exists():
        raise HTTPException(status_code=404, detail="Depth video file not found.")

    return FileResponse(
        path=job.depth_output_path,
        media_type="video/mp4",
        filename="depth.mp4",
    )
435
+
436
+
437
@app.get("/detect/first-frame-depth/{job_id}")
async def detect_first_frame_depth(job_id: str):
    """Serve the depth visualization of the video's first frame."""
    job = get_job_storage().get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found or expired.")

    frame_path = job.first_frame_depth_path
    if not frame_path:
        # Distinguish "depth failed" (partial success) from "never produced".
        if job.partial_success and job.depth_error:
            raise HTTPException(status_code=404, detail=f"Depth unavailable: {job.depth_error}")
        raise HTTPException(status_code=404, detail="First frame depth not found.")

    if not Path(frame_path).exists():
        raise HTTPException(status_code=404, detail="First frame depth file not found.")

    return FileResponse(
        path=frame_path,
        media_type="image/jpeg",
        filename="first_frame_depth.jpg",
    )
455
+
456
+
457
  if __name__ == "__main__":
458
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
demo.html CHANGED
@@ -238,6 +238,20 @@
238
  display: block;
239
  }
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  .download-btn {
242
  margin-top: 12px;
243
  padding: 10px 16px;
@@ -271,6 +285,12 @@
271
  text-align: center;
272
  }
273
 
 
 
 
 
 
 
274
  .spinner {
275
  border: 4px solid #e5e7eb;
276
  border-top: 4px solid #1f2933;
@@ -402,6 +422,16 @@
402
  <img id="firstFrameImage" class="frame-preview" alt="First frame preview">
403
  </div>
404
  </div>
 
 
 
 
 
 
 
 
 
 
405
  <div class="video-card">
406
  <div class="video-card-header">Original Video</div>
407
  <div class="video-card-body">
@@ -417,6 +447,16 @@
417
  </a>
418
  </div>
419
  </div>
 
 
 
 
 
 
 
 
 
 
420
  </div>
421
  </div>
422
  </div>
@@ -444,6 +484,12 @@
444
  const processedVideo = document.getElementById('processedVideo');
445
  const firstFrameImage = document.getElementById('firstFrameImage');
446
  const downloadBtn = document.getElementById('downloadBtn');
 
 
 
 
 
 
447
  let statusPoller = null;
448
  const statusLine = document.getElementById('statusLine');
449
  // Mode selection handler
@@ -512,9 +558,19 @@
512
  statusPoller = null;
513
  }
514
  firstFrameImage.removeAttribute('src');
 
 
 
 
515
  processedVideo.removeAttribute('src');
516
  processedVideo.load();
517
  downloadBtn.removeAttribute('href');
 
 
 
 
 
 
518
  statusLine.classList.add('hidden');
519
  statusLine.textContent = '';
520
 
@@ -568,6 +624,8 @@
568
  const videoUrl = URL.createObjectURL(blob);
569
  processedVideo.src = videoUrl;
570
  downloadBtn.href = videoUrl;
 
 
571
  } else if (statusData.status === 'failed') {
572
  clearInterval(statusPoller);
573
  statusPoller = null;
@@ -593,6 +651,47 @@
593
  }
594
  });
595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596
  </script>
597
  </body>
598
  </html>
 
238
  display: block;
239
  }
240
 
241
+ .frame-placeholder {
242
+ width: 100%;
243
+ border-radius: 8px;
244
+ background: #f3f4f6;
245
+ color: #6b7280;
246
+ display: flex;
247
+ align-items: center;
248
+ justify-content: center;
249
+ min-height: 200px;
250
+ font-size: 0.95rem;
251
+ text-align: center;
252
+ padding: 16px;
253
+ }
254
+
255
  .download-btn {
256
  margin-top: 12px;
257
  padding: 10px 16px;
 
285
  text-align: center;
286
  }
287
 
288
+ .depth-status {
289
+ margin-top: 8px;
290
+ font-size: 0.85rem;
291
+ color: #6b7280;
292
+ }
293
+
294
  .spinner {
295
  border: 4px solid #e5e7eb;
296
  border-top: 4px solid #1f2933;
 
422
  <img id="firstFrameImage" class="frame-preview" alt="First frame preview">
423
  </div>
424
  </div>
425
+ <div class="video-card">
426
+ <div class="video-card-header">First Frame (Depth)</div>
427
+ <div class="video-card-body">
428
+ <div id="depthFramePlaceholder" class="frame-placeholder">
429
+ Depth preview will appear after processing.
430
+ </div>
431
+ <img id="depthFrameImage" class="frame-preview hidden" alt="First frame depth preview">
432
+ <div id="depthFrameStatus" class="depth-status"></div>
433
+ </div>
434
+ </div>
435
  <div class="video-card">
436
  <div class="video-card-header">Original Video</div>
437
  <div class="video-card-body">
 
447
  </a>
448
  </div>
449
  </div>
450
+ <div class="video-card">
451
+ <div class="video-card-header">Depth Video</div>
452
+ <div class="video-card-body">
453
+ <video id="depthVideo" controls autoplay loop class="hidden"></video>
454
+ <a id="depthDownloadBtn" class="download-btn hidden" download="depth.mp4">
455
+ Download Depth Video
456
+ </a>
457
+ <div id="depthVideoStatus" class="depth-status"></div>
458
+ </div>
459
+ </div>
460
  </div>
461
  </div>
462
  </div>
 
484
  const processedVideo = document.getElementById('processedVideo');
485
  const firstFrameImage = document.getElementById('firstFrameImage');
486
  const downloadBtn = document.getElementById('downloadBtn');
487
+ const depthFrameImage = document.getElementById('depthFrameImage');
488
+ const depthFramePlaceholder = document.getElementById('depthFramePlaceholder');
489
+ const depthFrameStatus = document.getElementById('depthFrameStatus');
490
+ const depthVideo = document.getElementById('depthVideo');
491
+ const depthDownloadBtn = document.getElementById('depthDownloadBtn');
492
+ const depthVideoStatus = document.getElementById('depthVideoStatus');
493
  let statusPoller = null;
494
  const statusLine = document.getElementById('statusLine');
495
  // Mode selection handler
 
558
  statusPoller = null;
559
  }
560
  firstFrameImage.removeAttribute('src');
561
+ depthFrameImage.removeAttribute('src');
562
+ depthFrameImage.classList.add('hidden');
563
+ depthFramePlaceholder.classList.remove('hidden');
564
+ depthFrameStatus.textContent = '';
565
  processedVideo.removeAttribute('src');
566
  processedVideo.load();
567
  downloadBtn.removeAttribute('href');
568
+ depthVideo.removeAttribute('src');
569
+ depthVideo.load();
570
+ depthVideo.classList.add('hidden');
571
+ depthDownloadBtn.removeAttribute('href');
572
+ depthDownloadBtn.classList.add('hidden');
573
+ depthVideoStatus.textContent = '';
574
  statusLine.classList.add('hidden');
575
  statusLine.textContent = '';
576
 
 
624
  const videoUrl = URL.createObjectURL(blob);
625
  processedVideo.src = videoUrl;
626
  downloadBtn.href = videoUrl;
627
+
628
+ await loadDepthAssets(data);
629
  } else if (statusData.status === 'failed') {
630
  clearInterval(statusPoller);
631
  statusPoller = null;
 
651
  }
652
  });
653
 
654
+ async function loadDepthAssets(jobData) {
655
+ if (!jobData.first_frame_depth_url || !jobData.depth_video_url) {
656
+ depthFrameStatus.textContent = 'Depth endpoints not available for this job.';
657
+ depthVideoStatus.textContent = 'Depth endpoints not available for this job.';
658
+ return;
659
+ }
660
+
661
+ try {
662
+ const frameResponse = await fetch(jobData.first_frame_depth_url);
663
+ if (frameResponse.ok) {
664
+ const frameBlob = await frameResponse.blob();
665
+ const frameUrl = URL.createObjectURL(frameBlob);
666
+ depthFrameImage.src = frameUrl;
667
+ depthFrameImage.classList.remove('hidden');
668
+ depthFramePlaceholder.classList.add('hidden');
669
+ } else {
670
+ const error = await frameResponse.json();
671
+ depthFrameStatus.textContent = error.detail || 'Depth preview unavailable.';
672
+ }
673
+ } catch (error) {
674
+ depthFrameStatus.textContent = 'Depth preview failed to load.';
675
+ }
676
+
677
+ try {
678
+ const depthResponse = await fetch(jobData.depth_video_url);
679
+ if (depthResponse.ok) {
680
+ const depthBlob = await depthResponse.blob();
681
+ const depthUrl = URL.createObjectURL(depthBlob);
682
+ depthVideo.src = depthUrl;
683
+ depthVideo.classList.remove('hidden');
684
+ depthDownloadBtn.href = depthUrl;
685
+ depthDownloadBtn.classList.remove('hidden');
686
+ } else {
687
+ const error = await depthResponse.json();
688
+ depthVideoStatus.textContent = error.detail || 'Depth video unavailable.';
689
+ }
690
+ } catch (error) {
691
+ depthVideoStatus.textContent = 'Depth video failed to load.';
692
+ }
693
+ }
694
+
695
  </script>
696
  </body>
697
  </html>
inference.py CHANGED
@@ -347,3 +347,139 @@ def run_segmentation(
347
  logging.info("Segmented video written to: %s", output_video_path)
348
 
349
  return output_video_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  logging.info("Segmented video written to: %s", output_video_path)
348
 
349
  return output_video_path
350
+
351
+
352
def run_depth_inference(
    input_video_path: str,
    output_video_path: str,
    max_frames: Optional[int] = None,
    depth_estimator_name: str = "depth_pro",
    job_id: Optional[str] = None,
) -> str:
    """
    Run per-frame depth estimation on a video and write a colorized video.

    Args:
        input_video_path: Path to the input video.
        output_video_path: Path where the depth visualization video is written.
        max_frames: Optional cap on the number of frames processed (testing aid).
        depth_estimator_name: Registered depth estimator to use (default: depth_pro).
        job_id: Optional job ID; enables cooperative cancellation.

    Returns:
        ``output_video_path`` once the depth video has been written.

    Raises:
        ValueError: If the input video cannot be decoded.
    """
    try:
        frames, fps, width, height = extract_frames(input_video_path)
    except ValueError:
        # Log with traceback, then propagate unchanged to the caller.
        # (No `as exc` binding: the exception object was never used.)
        logging.exception("Failed to decode video at %s", input_video_path)
        raise

    logging.info("Using depth estimator: %s", depth_estimator_name)

    # Limit frames if requested (useful for quick tests).
    if max_frames is not None:
        frames = frames[:max_frames]

    # Two-pass depth processing with a video-wide normalization range
    # so the color scale stays stable across frames.
    processed_frames = process_frames_depth(frames, depth_estimator_name, job_id)

    write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
    logging.info("Depth video written to: %s", output_video_path)

    return output_video_path
392
+
393
+
394
def process_frames_depth(
    frames: List[np.ndarray],
    depth_estimator_name: str,
    job_id: Optional[str] = None,
) -> List[np.ndarray]:
    """
    Estimate depth for every frame and colorize with a video-wide range.

    Pass 1 runs the estimator on each frame and records the raw depth maps.
    Pass 2 colorizes every map against the global 1st/99th-percentile range,
    keeping the color scale fixed across the whole clip (avoids flicker).

    Args:
        frames: Frames as HxWx3 BGR uint8 arrays.
        depth_estimator_name: Name of the registered depth estimator.
        job_id: Optional job ID used for cooperative cancellation.

    Returns:
        One colorized depth frame (HxWx3 uint8) per input frame.
    """
    from models.depth_estimators.model_loader import load_depth_estimator

    estimator = load_depth_estimator(depth_estimator_name)

    # Pass 1: run inference and gather raw depth values for the global range.
    depth_maps = []
    all_values = []
    total = len(frames)
    for idx, frame in enumerate(frames):
        _check_cancellation(job_id)

        # Serialize access to the shared model instance.
        with _get_model_lock("depth", estimator.name):
            result = estimator.predict(frame)

        depth_maps.append(result.depth_map)
        all_values.append(result.depth_map.ravel())

        if idx % 10 == 0:
            logging.debug("Computed depth for frame %d/%d", idx + 1, total)

    # Percentile bounds clip outliers that would otherwise crush the range.
    all_depths = np.concatenate(all_values)
    global_min = np.percentile(all_depths, 1)
    global_max = np.percentile(all_depths, 99)

    logging.info(
        "Depth range: %.2f - %.2f meters (1st-99th percentile)",
        global_min,
        global_max,
    )

    # Pass 2: colorize with the stable global range.
    processed = []
    for idx, depth_map in enumerate(depth_maps):
        processed.append(colorize_depth_map(depth_map, global_min, global_max))

        if idx % 10 == 0:
            logging.debug("Colorized frame %d/%d", idx + 1, len(depth_maps))

    return processed
455
+
456
+
457
def colorize_depth_map(
    depth_map: np.ndarray,
    global_min: float,
    global_max: float,
) -> np.ndarray:
    """
    Convert a depth map to a color visualization using the TURBO colormap.

    Args:
        depth_map: HxW float32 depth in meters
        global_min: Minimum depth across entire video (for stable normalization)
        global_max: Maximum depth across entire video (for stable normalization)

    Returns:
        HxWx3 uint8 image in BGR channel order (``cv2.applyColorMap`` returns
        OpenCV's native BGR layout, not RGB).
    """
    import cv2

    if global_max - global_min < 1e-6:  # Handle uniform depth (avoid div-by-zero)
        depth_norm = np.zeros_like(depth_map, dtype=np.uint8)
    else:
        # Clip to the global range so outliers don't wrap or saturate oddly.
        depth_clipped = np.clip(depth_map, global_min, global_max)
        depth_norm = ((depth_clipped - global_min) / (global_max - global_min) * 255).astype(np.uint8)

    # TURBO gives a vibrant, near-perceptually-uniform mapping.
    colored = cv2.applyColorMap(depth_norm, cv2.COLORMAP_TURBO)

    return colored
jobs/background.py CHANGED
@@ -2,9 +2,11 @@ import asyncio
2
  import logging
3
  from datetime import datetime
4
 
 
 
5
  from jobs.models import JobStatus
6
- from jobs.storage import get_job_storage
7
- from inference import run_inference, run_segmentation
8
 
9
 
10
  async def process_video_async(job_id: str) -> None:
@@ -13,9 +15,15 @@ async def process_video_async(job_id: str) -> None:
13
  if not job:
14
  return
15
 
 
 
 
 
 
16
  try:
 
17
  if job.mode == "segmentation":
18
- output_path = await asyncio.to_thread(
19
  run_segmentation,
20
  job.input_video_path,
21
  job.output_video_path,
@@ -25,7 +33,7 @@ async def process_video_async(job_id: str) -> None:
25
  job_id,
26
  )
27
  else:
28
- output_path = await asyncio.to_thread(
29
  run_inference,
30
  job.input_video_path,
31
  job.output_video_path,
@@ -34,12 +42,52 @@ async def process_video_async(job_id: str) -> None:
34
  job.detector_name,
35
  job_id,
36
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  storage.update(
38
  job_id,
39
  status=JobStatus.COMPLETED,
40
  completed_at=datetime.utcnow(),
41
- output_video_path=output_path,
 
 
 
42
  )
 
43
  except RuntimeError as exc:
44
  # Handle cancellation specifically
45
  if "cancelled" in str(exc).lower():
 
2
  import logging
3
  from datetime import datetime
4
 
5
+ import torch
6
+
7
  from jobs.models import JobStatus
8
+ from jobs.storage import get_job_storage, get_depth_output_path
9
+ from inference import run_inference, run_segmentation, run_depth_inference
10
 
11
 
12
  async def process_video_async(job_id: str) -> None:
 
15
  if not job:
16
  return
17
 
18
+ detection_path = None
19
+ depth_path = None
20
+ depth_error = None
21
+ partial_success = False
22
+
23
  try:
24
+ # Run detection or segmentation first
25
  if job.mode == "segmentation":
26
+ detection_path = await asyncio.to_thread(
27
  run_segmentation,
28
  job.input_video_path,
29
  job.output_video_path,
 
33
  job_id,
34
  )
35
  else:
36
+ detection_path = await asyncio.to_thread(
37
  run_inference,
38
  job.input_video_path,
39
  job.output_video_path,
 
42
  job.detector_name,
43
  job_id,
44
  )
45
+
46
+ # Try to run depth estimation
47
+ try:
48
+ depth_path = await asyncio.to_thread(
49
+ run_depth_inference,
50
+ job.input_video_path,
51
+ str(get_depth_output_path(job_id)),
52
+ None, # max_frames
53
+ job.depth_estimator_name,
54
+ job_id,
55
+ )
56
+ logging.info("Depth estimation completed for job %s", job_id)
57
+ except (ImportError, ModuleNotFoundError) as exc:
58
+ logging.exception("Depth model not available for job %s", job_id)
59
+ depth_error = f"Depth model import failed: {exc}"
60
+ partial_success = True
61
+ except torch.cuda.OutOfMemoryError:
62
+ logging.exception("Depth estimation failed due to GPU OOM for job %s", job_id)
63
+ depth_error = "Depth estimation failed due to GPU memory limits"
64
+ partial_success = True
65
+ except RuntimeError as exc:
66
+ # Handle cancellation specifically for depth
67
+ if "cancelled" in str(exc).lower():
68
+ logging.info("Depth processing cancelled for job %s", job_id)
69
+ depth_error = "Depth processing cancelled"
70
+ partial_success = True
71
+ else:
72
+ logging.exception("Depth estimation failed for job %s", job_id)
73
+ depth_error = f"Depth processing error: {str(exc)}"
74
+ partial_success = True
75
+ except Exception as exc:
76
+ logging.exception("Depth estimation failed for job %s", job_id)
77
+ depth_error = f"Depth processing error: {str(exc)}"
78
+ partial_success = True
79
+
80
+ # Mark as completed (with or without depth)
81
  storage.update(
82
  job_id,
83
  status=JobStatus.COMPLETED,
84
  completed_at=datetime.utcnow(),
85
+ output_video_path=detection_path,
86
+ depth_output_path=depth_path,
87
+ partial_success=partial_success,
88
+ depth_error=depth_error,
89
  )
90
+
91
  except RuntimeError as exc:
92
  # Handle cancellation specifically
93
  if "cancelled" in str(exc).lower():
jobs/models.py CHANGED
@@ -26,3 +26,9 @@ class JobInfo:
26
  completed_at: Optional[datetime] = None
27
  error: Optional[str] = None
28
  first_frame_detections: List[Dict[str, Any]] = field(default_factory=list)
 
 
 
 
 
 
 
26
  completed_at: Optional[datetime] = None
27
  error: Optional[str] = None
28
  first_frame_detections: List[Dict[str, Any]] = field(default_factory=list)
29
+ # Depth estimation fields
30
+ depth_estimator_name: str = "depth_pro" # Always depth_pro for now
31
+ depth_output_path: Optional[str] = None
32
+ first_frame_depth_path: Optional[str] = None
33
+ partial_success: bool = False # True if one component failed but job completed
34
+ depth_error: Optional[str] = None # Error message if depth failed
jobs/storage.py CHANGED
@@ -25,6 +25,16 @@ def get_first_frame_path(job_id: str) -> Path:
25
  return get_job_directory(job_id) / "first_frame.jpg"
26
 
27
 
 
 
 
 
 
 
 
 
 
 
28
  class JobStorage:
29
  def __init__(self) -> None:
30
  self._jobs: Dict[str, JobInfo] = {}
 
25
  return get_job_directory(job_id) / "first_frame.jpg"
26
 
27
 
28
def get_depth_output_path(job_id: str) -> Path:
    """Return the path of the job's depth-estimation video output."""
    job_dir = get_job_directory(job_id)
    return job_dir / "depth.mp4"
31
+
32
+
33
def get_first_frame_depth_path(job_id: str) -> Path:
    """Return the path of the job's first-frame depth visualization."""
    job_dir = get_job_directory(job_id)
    return job_dir / "first_frame_depth.jpg"
36
+
37
+
38
  class JobStorage:
39
  def __init__(self) -> None:
40
  self._jobs: Dict[str, JobInfo] = {}
models/depth_estimators/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Depth estimation models for video processing."""
2
+
3
+ from .base import DepthEstimator, DepthResult
4
+ from .depth_pro import DepthProEstimator
5
+ from .model_loader import list_depth_estimators, load_depth_estimator
6
+
7
+ __all__ = [
8
+ "DepthEstimator",
9
+ "DepthResult",
10
+ "DepthProEstimator",
11
+ "load_depth_estimator",
12
+ "list_depth_estimators",
13
+ ]
models/depth_estimators/base.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import NamedTuple

import numpy as np


class DepthResult(NamedTuple):
    """Output of a single depth-estimation call.

    Fields:
        depth_map: HxW float32 array of per-pixel depth in meters.
        focal_length: Estimated focal length of the camera, in pixels.
    """
    depth_map: np.ndarray
    focal_length: float


class DepthEstimator:
    """Interface that all depth-estimation backends implement."""

    # Short identifier for the backend (e.g. "depth_pro"); set by subclasses.
    name: str

    def predict(self, frame: np.ndarray) -> DepthResult:
        """
        Estimate depth for a single frame.

        Args:
            frame: Input image as a numpy array (HxWxC, BGR format from OpenCV)

        Returns:
            DepthResult holding the depth map and estimated focal length.
        """
        raise NotImplementedError
models/depth_estimators/depth_pro.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+
7
+ from .base import DepthEstimator, DepthResult
8
+
9
+
10
class DepthProEstimator(DepthEstimator):
    """Apple Depth Pro monocular depth estimator."""

    name = "depth_pro"

    def __init__(self):
        """Load the Depth Pro model, preferring GPU when CUDA is available.

        Raises:
            ImportError: If the ``depth_pro`` package is not installed.
        """
        try:
            import depth_pro
        except ImportError as exc:
            raise ImportError(
                "depth_pro package not installed. "
                "Install with: pip install git+https://github.com/apple/ml-depth-pro.git"
            ) from exc

        logging.info("Loading Depth Pro model...")
        self.model, self.transform = depth_pro.create_model_and_transforms()
        self.model.eval()

        # Move model to GPU if available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            logging.info("Depth Pro model loaded on GPU")
        else:
            logging.warning("Depth Pro model loaded on CPU (no CUDA available)")

    def predict(self, frame: np.ndarray) -> DepthResult:
        """
        Run depth estimation on a single frame.

        Args:
            frame: HxWx3 BGR uint8 numpy array (OpenCV format)

        Returns:
            DepthResult with depth_map (HxW float32 in meters) and focal_length
        """
        # BGR -> RGB. The reversed slice is a negative-stride view, which
        # PIL cannot reliably consume directly, so force a contiguous copy.
        rgb_frame = np.ascontiguousarray(frame[:, :, ::-1])

        # Convert to PIL Image for the model's preprocessing transform.
        pil_image = Image.fromarray(rgb_frame)

        image_tensor = self.transform(pil_image)
        image_tensor = image_tensor.to(self.device)

        # Inference only; no gradients needed. f_px=None lets the model
        # estimate the focal length itself.
        with torch.no_grad():
            prediction = self.model.infer(image_tensor, f_px=None)

        # prediction is a dict: {"depth": tensor, "focallength_px": tensor}
        depth_tensor = prediction["depth"]
        focal_length_tensor = prediction.get("focallength_px")

        # Convert to numpy; squeeze removes a batch dimension if present.
        depth_map = depth_tensor.cpu().numpy().squeeze()

        # Fall back to 1.0 when the model reports no focal length.
        if focal_length_tensor is not None:
            focal_length = float(focal_length_tensor.cpu().item())
        else:
            focal_length = 1.0

        return DepthResult(depth_map=depth_map, focal_length=focal_length)
models/depth_estimators/model_loader.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Registry and loader for depth estimators."""
2
+
3
+ from functools import lru_cache
4
+ from typing import Callable, Dict
5
+
6
+ from .base import DepthEstimator
7
+ from .depth_pro import DepthProEstimator
8
+
9
+
10
+ # Registry of depth estimators
11
+ _REGISTRY: Dict[str, Callable[[], DepthEstimator]] = {
12
+ "depth_pro": DepthProEstimator,
13
+ }
14
+
15
+
16
@lru_cache(maxsize=None)
def _get_cached_depth_estimator(name: str) -> DepthEstimator:
    """Build a depth estimator once per name and reuse it afterwards.

    Args:
        name: Registered estimator name (e.g. "depth_pro").

    Returns:
        The cached estimator instance for ``name``.
    """
    return _create_depth_estimator(name)
28
+
29
+
30
def _create_depth_estimator(name: str) -> DepthEstimator:
    """Instantiate the estimator registered under ``name``.

    Args:
        name: Depth estimator name.

    Returns:
        A fresh depth estimator instance.

    Raises:
        KeyError: If ``name`` is not present in the registry.
    """
    factory = _REGISTRY.get(name)
    if factory is None:
        raise KeyError(
            f"Depth estimator '{name}' not found. Available: {list(_REGISTRY.keys())}"
        )
    return factory()
50
+
51
+
52
def load_depth_estimator(name: str = "depth_pro") -> DepthEstimator:
    """Return the shared, lazily-created depth estimator for ``name``.

    Args:
        name: Depth estimator name (default: "depth_pro").

    Returns:
        Cached depth estimator instance.
    """
    return _get_cached_depth_estimator(name)
63
+
64
+
65
def list_depth_estimators() -> list[str]:
    """Return the names of all registered depth estimators."""
    return [key for key in _REGISTRY]
requirements.txt CHANGED
@@ -11,3 +11,4 @@ huggingface-hub
11
  ultralytics
12
  timm
13
  ffmpeg-python
 
 
11
  ultralytics
12
  timm
13
  ffmpeg-python
14
+ depth-pro @ git+https://github.com/apple/ml-depth-pro.git