Zhen Ye committed on
Commit
45eb65b
·
1 Parent(s): b2e7d79

feat(inference): enable full multi-GPU support for all models

Browse files

- Update inference.py to parallelize detection, segmentation, and depth estimation across all available GPUs
- Update detectors (YOLOv8, DETR, GroundingDINO, DroneYolo) to accept device argument
- Update SAM3 and DepthAnythingV2 to accept device argument
- Add device-specific model loading to all model loaders
- Remove OwlV2 support

app.py CHANGED
@@ -228,6 +228,8 @@ async def detect_endpoint(
228
  output_path,
229
  query_list,
230
  detector_name=detector_name,
 
 
231
  )
232
  except ValueError as exc:
233
  logging.exception("Video processing failed.")
@@ -261,6 +263,7 @@ async def detect_async_endpoint(
261
  detector: str = Form("hf_yolov8"),
262
  segmenter: str = Form("sam3"),
263
  depth_estimator: str = Form("depth"),
 
264
  ):
265
  if mode not in VALID_MODES:
266
  raise HTTPException(
@@ -313,6 +316,7 @@ async def detect_async_endpoint(
313
  detector_name=detector_name,
314
  segmenter_name=segmenter,
315
  depth_estimator_name=depth_estimator,
 
316
  )
317
  cv2.imwrite(str(first_frame_path), processed_frame)
318
  except Exception:
@@ -332,6 +336,7 @@ async def detect_async_endpoint(
332
  first_frame_path=str(first_frame_path),
333
  first_frame_detections=detections,
334
  depth_estimator_name=depth_estimator,
 
335
  depth_output_path=str(depth_output_path),
336
  first_frame_depth_path=str(first_frame_depth_path),
337
  )
 
228
  output_path,
229
  query_list,
230
  detector_name=detector_name,
231
+ depth_estimator_name="depth", # Synch endpoint default
232
+ depth_scale=1.0,
233
  )
234
  except ValueError as exc:
235
  logging.exception("Video processing failed.")
 
263
  detector: str = Form("hf_yolov8"),
264
  segmenter: str = Form("sam3"),
265
  depth_estimator: str = Form("depth"),
266
+ depth_scale: float = Form(1.0),
267
  ):
268
  if mode not in VALID_MODES:
269
  raise HTTPException(
 
316
  detector_name=detector_name,
317
  segmenter_name=segmenter,
318
  depth_estimator_name=depth_estimator,
319
+ depth_scale=depth_scale,
320
  )
321
  cv2.imwrite(str(first_frame_path), processed_frame)
322
  except Exception:
 
336
  first_frame_path=str(first_frame_path),
337
  first_frame_detections=detections,
338
  depth_estimator_name=depth_estimator,
339
+ depth_scale=float(depth_scale),
340
  depth_output_path=str(depth_output_path),
341
  first_frame_depth_path=str(first_frame_depth_path),
342
  )
inference.py CHANGED
@@ -5,8 +5,13 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple
5
 
6
  import cv2
7
  import numpy as np
8
- from models.model_loader import load_detector
9
- from models.segmenters.model_loader import load_segmenter
 
 
 
 
 
10
  from utils.video import extract_frames, write_video
11
 
12
 
@@ -186,14 +191,25 @@ def _attach_depth_metrics(
186
  detections: List[Dict[str, Any]],
187
  depth_estimator_name: Optional[str],
188
  depth_scale: float,
 
189
  ) -> None:
190
- if not detections or not depth_estimator_name:
191
  return
192
 
193
  from models.depth_estimators.model_loader import load_depth_estimator
194
 
195
- estimator = load_depth_estimator(depth_estimator_name)
196
- lock = _get_model_lock("depth", estimator.name)
 
 
 
 
 
 
 
 
 
 
197
  with lock:
198
  depth_result = estimator.predict(frame)
199
 
@@ -246,25 +262,56 @@ def infer_frame(
246
  frame: np.ndarray,
247
  queries: Sequence[str],
248
  detector_name: Optional[str] = None,
 
 
 
 
249
  ) -> tuple[np.ndarray, List[Dict[str, Any]]]:
250
- detector = load_detector(detector_name)
 
 
 
 
251
  text_queries = list(queries) or ["object"]
252
  try:
253
- lock = _get_model_lock("detector", detector.name)
 
 
 
 
254
  with lock:
255
  result = detector.predict(frame, text_queries)
256
  detections = _build_detection_records(
257
  result.boxes, result.scores, result.labels, text_queries, result.label_names
258
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  except Exception:
260
  logging.exception("Inference failed for queries %s", text_queries)
261
  raise
262
  return draw_boxes(
263
  frame,
264
  result.boxes,
265
- labels=result.labels,
266
- queries=text_queries,
267
- label_names=result.label_names,
268
  ), detections
269
 
270
 
@@ -272,9 +319,19 @@ def infer_segmentation_frame(
272
  frame: np.ndarray,
273
  text_queries: Optional[List[str]] = None,
274
  segmenter_name: Optional[str] = None,
 
275
  ) -> tuple[np.ndarray, Any]:
276
- segmenter = load_segmenter(segmenter_name)
277
- lock = _get_model_lock("segmenter", segmenter.name)
 
 
 
 
 
 
 
 
 
278
  with lock:
279
  result = segmenter.predict(frame, text_prompts=text_queries)
280
  labels = text_queries or []
@@ -335,6 +392,8 @@ def run_inference(
335
  max_frames: Optional[int] = None,
336
  detector_name: Optional[str] = None,
337
  job_id: Optional[str] = None,
 
 
338
  ) -> str:
339
  """
340
  Run object detection inference on a video.
@@ -346,9 +405,8 @@ def run_inference(
346
  max_frames: Optional frame limit for testing
347
  detector_name: Detector to use (default: hf_yolov8)
348
  job_id: Optional job ID for cancellation support
349
-
350
- Returns:
351
- Path to processed output video
352
  """
353
  try:
354
  frames, fps, width, height = extract_frames(input_video_path)
@@ -367,17 +425,102 @@ def run_inference(
367
  active_detector = detector_name or "hf_yolov8"
368
  logging.info("Using detector: %s", active_detector)
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  # Process frames
371
- processed_frames: List[np.ndarray] = []
372
- for idx, frame in enumerate(frames):
373
- # Check for cancellation every frame
374
- _check_cancellation(job_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
- if max_frames is not None and idx >= max_frames:
377
- break
378
- logging.debug("Processing frame %d", idx)
379
- processed_frame, _ = infer_frame(frame, queries, detector_name=active_detector)
380
- processed_frames.append(processed_frame)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
  # Write output video
383
  write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
@@ -403,16 +546,64 @@ def run_segmentation(
403
  active_segmenter = segmenter_name or "sam3"
404
  logging.info("Using segmenter: %s with queries: %s", active_segmenter, queries)
405
 
406
- processed_frames: List[np.ndarray] = []
407
- for idx, frame in enumerate(frames):
408
- # Check for cancellation every frame
409
- _check_cancellation(job_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
- if max_frames is not None and idx >= max_frames:
412
- break
413
- logging.debug("Processing frame %d", idx)
414
- processed_frame, _ = infer_segmentation_frame(frame, text_queries=queries, segmenter_name=active_segmenter)
415
- processed_frames.append(processed_frame)
416
 
417
  write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
418
  logging.info("Segmented video written to: %s", output_video_path)
@@ -490,25 +681,80 @@ def process_frames_depth(
490
  Returns:
491
  List of depth visualization frames (HxWx3 RGB uint8)
492
  """
493
- from models.depth_estimators.model_loader import load_depth_estimator
494
-
495
- estimator = load_depth_estimator(depth_estimator_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
 
497
  # First pass: Compute all depth maps and find global range
498
- depth_maps = []
499
  all_values = []
500
- for idx, frame in enumerate(frames):
501
- _check_cancellation(job_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
 
503
- lock = _get_model_lock("depth", estimator.name)
504
- with lock:
505
- depth_result = estimator.predict(frame)
 
 
 
506
 
507
- depth_maps.append(depth_result.depth_map)
508
- all_values.append(depth_result.depth_map.ravel())
 
509
 
510
- if idx % 10 == 0:
511
- logging.debug("Computed depth for frame %d/%d", idx + 1, len(frames))
 
 
 
512
 
513
  # Compute global min/max (using percentiles to handle outliers)
514
  all_depths = np.concatenate(all_values).astype(np.float32, copy=False)
 
5
 
6
  import cv2
7
  import numpy as np
8
+ import torch
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ from threading import RLock
11
+ from models.detectors.base import ObjectDetector
12
+ from models.model_loader import load_detector, load_detector_on_device
13
+ from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
14
+ from models.depth_estimators.model_loader import load_depth_estimator_on_device
15
  from utils.video import extract_frames, write_video
16
 
17
 
 
191
  detections: List[Dict[str, Any]],
192
  depth_estimator_name: Optional[str],
193
  depth_scale: float,
194
+ estimator_instance: Optional[Any] = None,
195
  ) -> None:
196
+ if not detections or (not depth_estimator_name and not estimator_instance):
197
  return
198
 
199
  from models.depth_estimators.model_loader import load_depth_estimator
200
 
201
+ if estimator_instance:
202
+ estimator = estimator_instance
203
+ # Use instance lock if available, or create one
204
+ if hasattr(estimator, "lock"):
205
+ lock = estimator.lock
206
+ else:
207
+ # Fallback (shouldn't happen with our new setup but safe)
208
+ lock = _get_model_lock("depth", estimator.name)
209
+ else:
210
+ estimator = load_depth_estimator(depth_estimator_name)
211
+ lock = _get_model_lock("depth", estimator.name)
212
+
213
  with lock:
214
  depth_result = estimator.predict(frame)
215
 
 
262
  frame: np.ndarray,
263
  queries: Sequence[str],
264
  detector_name: Optional[str] = None,
265
+ depth_estimator_name: Optional[str] = None,
266
+ depth_scale: float = 1.0,
267
+ detector_instance: Optional[ObjectDetector] = None,
268
+ depth_estimator_instance: Optional[Any] = None,
269
  ) -> tuple[np.ndarray, List[Dict[str, Any]]]:
270
+ if detector_instance:
271
+ detector = detector_instance
272
+ else:
273
+ detector = load_detector(detector_name)
274
+
275
  text_queries = list(queries) or ["object"]
276
  try:
277
+ if hasattr(detector, "lock"):
278
+ lock = detector.lock
279
+ else:
280
+ lock = _get_model_lock("detector", detector.name)
281
+
282
  with lock:
283
  result = detector.predict(frame, text_queries)
284
  detections = _build_detection_records(
285
  result.boxes, result.scores, result.labels, text_queries, result.label_names
286
  )
287
+
288
+ if depth_estimator_name or depth_estimator_instance:
289
+ try:
290
+ _attach_depth_metrics(
291
+ frame, detections, depth_estimator_name, depth_scale, estimator_instance=depth_estimator_instance
292
+ )
293
+ except Exception:
294
+ logging.exception("Depth estimation failed for frame")
295
+
296
+ # Re-build display labels to include depth if available
297
+ display_labels = []
298
+ for i, det in enumerate(detections):
299
+ label = det["label"]
300
+ if det.get("depth_valid") and det.get("depth_est_m") is not None:
301
+ # Add depth to label, e.g. "car 12m"
302
+ depth_str = f"{int(det['depth_est_m'])}m"
303
+ label = f"{label} {depth_str}"
304
+ display_labels.append(label)
305
+
306
  except Exception:
307
  logging.exception("Inference failed for queries %s", text_queries)
308
  raise
309
  return draw_boxes(
310
  frame,
311
  result.boxes,
312
+ labels=None, # Use custom labels
313
+ queries=None,
314
+ label_names=display_labels,
315
  ), detections
316
 
317
 
 
319
  frame: np.ndarray,
320
  text_queries: Optional[List[str]] = None,
321
  segmenter_name: Optional[str] = None,
322
+ segmenter_instance: Optional[Any] = None,
323
  ) -> tuple[np.ndarray, Any]:
324
+ if segmenter_instance:
325
+ segmenter = segmenter_instance
326
+ # Use instance lock if available
327
+ if hasattr(segmenter, "lock"):
328
+ lock = segmenter.lock
329
+ else:
330
+ lock = _get_model_lock("segmenter", segmenter.name)
331
+ else:
332
+ segmenter = load_segmenter(segmenter_name)
333
+ lock = _get_model_lock("segmenter", segmenter.name)
334
+
335
  with lock:
336
  result = segmenter.predict(frame, text_prompts=text_queries)
337
  labels = text_queries or []
 
392
  max_frames: Optional[int] = None,
393
  detector_name: Optional[str] = None,
394
  job_id: Optional[str] = None,
395
+ depth_estimator_name: Optional[str] = None,
396
+ depth_scale: float = 1.0,
397
  ) -> str:
398
  """
399
  Run object detection inference on a video.
 
405
  max_frames: Optional frame limit for testing
406
  detector_name: Detector to use (default: hf_yolov8)
407
  job_id: Optional job ID for cancellation support
408
+ depth_estimator_name: Optional depth estimator name
409
+ depth_scale: Scale factor for depth estimation
 
410
  """
411
  try:
412
  frames, fps, width, height = extract_frames(input_video_path)
 
425
  active_detector = detector_name or "hf_yolov8"
426
  logging.info("Using detector: %s", active_detector)
427
 
428
+ # Detect GPUs
429
+ num_gpus = torch.cuda.device_count()
430
+ detectors = None
431
+ depth_estimators = None
432
+
433
+ if num_gpus > 1:
434
+ logging.info("Detected %d GPUs. Enabling Multi-GPU inference.", num_gpus)
435
+ # Initialize one detector per GPU
436
+ detectors = []
437
+ depth_estimators = []
438
+ for i in range(num_gpus):
439
+ device_str = f"cuda:{i}"
440
+ logging.info("Loading detector/depth on %s", device_str)
441
+
442
+ # Detector
443
+ det = load_detector_on_device(active_detector, device_str)
444
+ det.lock = RLock()
445
+ detectors.append(det)
446
+
447
+ # Depth (if requested)
448
+ if depth_estimator_name:
449
+ depth = load_depth_estimator_on_device(depth_estimator_name, device_str)
450
+ depth.lock = RLock()
451
+ depth_estimators.append(depth)
452
+ else:
453
+ depth_estimators.append(None)
454
+
455
+ else:
456
+ logging.info("Single device detected. Using standard inference.")
457
+ detectors = None
458
+
459
+ processed_frames_map = {}
460
+
461
  # Process frames
462
+ if detectors:
463
+ # Multi-GPU Parallel Processing
464
+ def process_frame_task(frame_idx: int, frame_data: np.ndarray) -> tuple[int, np.ndarray]:
465
+ # Determine which GPU to use based on frame index (round-robin)
466
+ gpu_idx = frame_idx % len(detectors)
467
+ detector_instance = detectors[gpu_idx]
468
+ depth_instance = depth_estimators[gpu_idx] if depth_estimators else None
469
+
470
+ # Run depth estimation every 3 frames if configured
471
+ active_depth_name = depth_estimator_name if (frame_idx % 3 == 0) else None
472
+ active_depth_instance = depth_instance if (frame_idx % 3 == 0) else None
473
+
474
+ processed, _ = infer_frame(
475
+ frame_data,
476
+ queries,
477
+ detector_name=None, # Use instance
478
+ depth_estimator_name=active_depth_name,
479
+ depth_scale=depth_scale,
480
+ detector_instance=detector_instance,
481
+ depth_estimator_instance=active_depth_instance
482
+ )
483
+ return frame_idx, processed
484
+
485
+ # Thread pool with more workers than GPUs to keep them fed
486
+ max_workers = min(len(detectors) * 2, 8)
487
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
488
+ futures = []
489
+ for idx, frame in enumerate(frames):
490
+ _check_cancellation(job_id)
491
+ if max_frames is not None and idx >= max_frames:
492
+ break
493
+ futures.append(executor.submit(process_frame_task, idx, frame))
494
+
495
+ for future in futures:
496
+ idx, result_frame = future.result() # Wait for completion (in order or not, but we verify order)
497
+ processed_frames_map[idx] = result_frame
498
+
499
+ # Reassemble in order
500
+ processed_frames = [processed_frames_map[i] for i in range(len(processed_frames_map))]
501
 
502
+ else:
503
+ # Standard Single-Threaded Loop
504
+ processed_frames = []
505
+ for idx, frame in enumerate(frames):
506
+ # Check for cancellation every frame
507
+ _check_cancellation(job_id)
508
+
509
+ if max_frames is not None and idx >= max_frames:
510
+ break
511
+ logging.debug("Processing frame %d", idx)
512
+
513
+ # Run depth estimation every 3 frames if configured
514
+ active_depth = depth_estimator_name if (idx % 3 == 0) else None
515
+
516
+ processed_frame, _ = infer_frame(
517
+ frame,
518
+ queries,
519
+ detector_name=active_detector,
520
+ depth_estimator_name=active_depth,
521
+ depth_scale=depth_scale
522
+ )
523
+ processed_frames.append(processed_frame)
524
 
525
  # Write output video
526
  write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
 
546
  active_segmenter = segmenter_name or "sam3"
547
  logging.info("Using segmenter: %s with queries: %s", active_segmenter, queries)
548
 
549
+ # Detect GPUs
550
+ num_gpus = torch.cuda.device_count()
551
+ segmenters = None
552
+ if num_gpus > 1:
553
+ logging.info("Detected %d GPUs. Enabling Multi-GPU segmentation.", num_gpus)
554
+ segmenters = []
555
+ for i in range(num_gpus):
556
+ device_str = f"cuda:{i}"
557
+ logging.info("Loading segmenter on %s", device_str)
558
+ seg = load_segmenter_on_device(active_segmenter, device_str)
559
+ seg.lock = RLock()
560
+ segmenters.append(seg)
561
+ else:
562
+ logging.info("Single device detected. Using standard segmentation.")
563
+ segmenters = None
564
+
565
+ processed_frames_map = {}
566
+
567
+ if segmenters:
568
+ # Multi-GPU Parallel Processing
569
+ def process_segmentation_task(frame_idx: int, frame_data: np.ndarray) -> tuple[int, np.ndarray]:
570
+ gpu_idx = frame_idx % len(segmenters)
571
+ segmenter_instance = segmenters[gpu_idx]
572
+
573
+ processed, _ = infer_segmentation_frame(
574
+ frame_data,
575
+ text_queries=queries,
576
+ segmenter_name=None,
577
+ segmenter_instance=segmenter_instance
578
+ )
579
+ return frame_idx, processed
580
+
581
+ max_workers = min(len(segmenters) * 2, 8)
582
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
583
+ futures = []
584
+ for idx, frame in enumerate(frames):
585
+ _check_cancellation(job_id)
586
+ if max_frames is not None and idx >= max_frames:
587
+ break
588
+ futures.append(executor.submit(process_segmentation_task, idx, frame))
589
+
590
+ for future in futures:
591
+ idx, result_frame = future.result()
592
+ processed_frames_map[idx] = result_frame
593
+
594
+ processed_frames = [processed_frames_map[i] for i in range(len(processed_frames_map))]
595
+
596
+ else:
597
+ processed_frames: List[np.ndarray] = []
598
+ for idx, frame in enumerate(frames):
599
+ # Check for cancellation every frame
600
+ _check_cancellation(job_id)
601
 
602
+ if max_frames is not None and idx >= max_frames:
603
+ break
604
+ logging.debug("Processing frame %d", idx)
605
+ processed_frame, _ = infer_segmentation_frame(frame, text_queries=queries, segmenter_name=active_segmenter)
606
+ processed_frames.append(processed_frame)
607
 
608
  write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
609
  logging.info("Segmented video written to: %s", output_video_path)
 
681
  Returns:
682
  List of depth visualization frames (HxWx3 RGB uint8)
683
  """
684
+ from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
685
+
686
+ # Detect GPUs
687
+ num_gpus = torch.cuda.device_count()
688
+ estimators = None
689
+ if num_gpus > 1:
690
+ logging.info("Detected %d GPUs. Enabling Multi-GPU depth estimation.", num_gpus)
691
+ estimators = []
692
+ for i in range(num_gpus):
693
+ device_str = f"cuda:{i}"
694
+ logging.info("Loading depth estimator on %s", device_str)
695
+ est = load_depth_estimator_on_device(depth_estimator_name, device_str)
696
+ est.lock = RLock()
697
+ estimators.append(est)
698
+ else:
699
+ logging.info("Single device detected. Using standard depth estimation.")
700
+ estimators = None
701
+ # Fallback to single estimator
702
+ single_estimator = load_depth_estimator(depth_estimator_name)
703
 
704
  # First pass: Compute all depth maps and find global range
705
+ depth_maps_map = {}
706
  all_values = []
707
+
708
+ if estimators:
709
+ # Multi-GPU Parallel Processing
710
+ def compute_depth_task(frame_idx: int, frame_data: np.ndarray) -> tuple[int, Any]:
711
+ gpu_idx = frame_idx % len(estimators)
712
+ estimator_instance = estimators[gpu_idx]
713
+
714
+ # Use instance lock
715
+ if hasattr(estimator_instance, "lock"):
716
+ lock = estimator_instance.lock
717
+ else:
718
+ # Should have been assigned above
719
+ lock = RLock()
720
+
721
+ with lock:
722
+ result = estimator_instance.predict(frame_data)
723
+ return frame_idx, result
724
+
725
+ max_workers = min(len(estimators) * 2, 8)
726
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
727
+ futures = []
728
+ for idx, frame in enumerate(frames):
729
+ _check_cancellation(job_id)
730
+ futures.append(executor.submit(compute_depth_task, idx, frame))
731
+
732
+ for future in futures:
733
+ idx, res = future.result()
734
+ depth_maps_map[idx] = res.depth_map
735
+ # We need to collect values for global min/max.
736
+ # This is done after the loop, once every future has completed, to keep the worker tasks minimal.
737
+
738
+ # Reassemble
739
+ depth_maps = [depth_maps_map[i] for i in range(len(depth_maps_map))]
740
+ all_values = [dm.ravel() for dm in depth_maps]
741
 
742
+ else:
743
+ # Single threaded
744
+ estimator = single_estimator
745
+ depth_maps = []
746
+ for idx, frame in enumerate(frames):
747
+ _check_cancellation(job_id)
748
 
749
+ lock = _get_model_lock("depth", estimator.name)
750
+ with lock:
751
+ depth_result = estimator.predict(frame)
752
 
753
+ depth_maps.append(depth_result.depth_map)
754
+ all_values.append(depth_result.depth_map.ravel())
755
+
756
+ if idx % 10 == 0:
757
+ logging.debug("Computed depth for frame %d/%d", idx + 1, len(frames))
758
 
759
  # Compute global min/max (using percentiles to handle outliers)
760
  all_depths = np.concatenate(all_values).astype(np.float32, copy=False)
jobs/background.py CHANGED
@@ -41,6 +41,8 @@ async def process_video_async(job_id: str) -> None:
41
  None,
42
  job.detector_name,
43
  job_id,
 
 
44
  )
45
 
46
  # Try to run depth estimation
 
41
  None,
42
  job.detector_name,
43
  job_id,
44
+ job.depth_estimator_name,
45
+ job.depth_scale,
46
  )
47
 
48
  # Try to run depth estimation
jobs/models.py CHANGED
@@ -28,6 +28,7 @@ class JobInfo:
28
  first_frame_detections: List[Dict[str, Any]] = field(default_factory=list)
29
  # Depth estimation fields
30
  depth_estimator_name: str = "depth"
 
31
  depth_output_path: Optional[str] = None
32
  first_frame_depth_path: Optional[str] = None
33
  partial_success: bool = False # True if one component failed but job completed
 
28
  first_frame_detections: List[Dict[str, Any]] = field(default_factory=list)
29
  # Depth estimation fields
30
  depth_estimator_name: str = "depth"
31
+ depth_scale: float = 1.0
32
  depth_output_path: Optional[str] = None
33
  first_frame_depth_path: Optional[str] = None
34
  partial_success: bool = False # True if one component failed but job completed
models/depth_estimators/depth_anything_v2.py CHANGED
@@ -13,10 +13,13 @@ class DepthAnythingV2Estimator(DepthEstimator):
13
 
14
  name = "depth"
15
 
16
- def __init__(self) -> None:
17
  logging.info("Loading Depth-Anything model from Hugging Face (transformers)...")
18
 
19
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
20
 
21
  model_id = "LiheYoung/depth-anything-large-hf"
22
  self.image_processor = AutoImageProcessor.from_pretrained(model_id)
 
13
 
14
  name = "depth"
15
 
16
+ def __init__(self, device: str = None) -> None:
17
  logging.info("Loading Depth-Anything model from Hugging Face (transformers)...")
18
 
19
+ if device:
20
+ self.device = torch.device(device)
21
+ else:
22
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
 
24
  model_id = "LiheYoung/depth-anything-large-hf"
25
  self.image_processor = AutoImageProcessor.from_pretrained(model_id)
models/depth_estimators/model_loader.py CHANGED
@@ -27,7 +27,7 @@ def _get_cached_depth_estimator(name: str) -> DepthEstimator:
27
  return _create_depth_estimator(name)
28
 
29
 
30
- def _create_depth_estimator(name: str) -> DepthEstimator:
31
  """
32
  Create depth estimator instance.
33
 
@@ -46,7 +46,7 @@ def _create_depth_estimator(name: str) -> DepthEstimator:
46
  )
47
 
48
  estimator_class = _REGISTRY[name]
49
- return estimator_class()
50
 
51
 
52
  def load_depth_estimator(name: str = "depth") -> DepthEstimator:
@@ -62,6 +62,11 @@ def load_depth_estimator(name: str = "depth") -> DepthEstimator:
62
  return _get_cached_depth_estimator(name)
63
 
64
 
 
 
 
 
 
65
  def list_depth_estimators() -> list[str]:
66
  """Return list of available depth estimator names."""
67
  return list(_REGISTRY.keys())
 
27
  return _create_depth_estimator(name)
28
 
29
 
30
+ def _create_depth_estimator(name: str, **kwargs) -> DepthEstimator:
31
  """
32
  Create depth estimator instance.
33
 
 
46
  )
47
 
48
  estimator_class = _REGISTRY[name]
49
+ return estimator_class(**kwargs)
50
 
51
 
52
  def load_depth_estimator(name: str = "depth") -> DepthEstimator:
 
62
  return _get_cached_depth_estimator(name)
63
 
64
 
65
+ def load_depth_estimator_on_device(name: str, device: str) -> DepthEstimator:
66
+ """Create a new depth estimator instance on the specified device (no caching)."""
67
+ return _create_depth_estimator(name, device=device)
68
+
69
+
70
  def list_depth_estimators() -> list[str]:
71
  """Return list of available depth estimator names."""
72
  return list(_REGISTRY.keys())
models/detectors/detr.py CHANGED
@@ -13,10 +13,13 @@ class DetrDetector(ObjectDetector):
13
 
14
  MODEL_NAME = "facebook/detr-resnet-50"
15
 
16
- def __init__(self, score_threshold: float = 0.3) -> None:
17
  self.name = "detr_resnet50"
18
  self.score_threshold = score_threshold
19
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
20
  logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
21
  self.processor = DetrImageProcessor.from_pretrained(self.MODEL_NAME)
22
  self.model = DetrForObjectDetection.from_pretrained(self.MODEL_NAME)
 
13
 
14
  MODEL_NAME = "facebook/detr-resnet-50"
15
 
16
+ def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
17
  self.name = "detr_resnet50"
18
  self.score_threshold = score_threshold
19
+ if device:
20
+ self.device = torch.device(device)
21
+ else:
22
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
  logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
24
  self.processor = DetrImageProcessor.from_pretrained(self.MODEL_NAME)
25
  self.model = DetrForObjectDetection.from_pretrained(self.MODEL_NAME)
models/detectors/drone_yolo.py CHANGED
@@ -16,10 +16,13 @@ class DroneYoloDetector(ObjectDetector):
16
  REPO_ID = "rujutashashikanjoshi/yolo12-drone-detection-0205-100m"
17
  DEFAULT_WEIGHT = "best.pt"
18
 
19
- def __init__(self, score_threshold: float = 0.3) -> None:
20
  self.name = "drone_yolo"
21
  self.score_threshold = score_threshold
22
- self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 
 
23
  weight_file = os.getenv("DRONE_YOLO_WEIGHT", self.DEFAULT_WEIGHT)
24
  logging.info(
25
  "Loading drone YOLO weights %s/%s onto %s",
 
16
  REPO_ID = "rujutashashikanjoshi/yolo12-drone-detection-0205-100m"
17
  DEFAULT_WEIGHT = "best.pt"
18
 
19
+ def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
20
  self.name = "drone_yolo"
21
  self.score_threshold = score_threshold
22
+ if device:
23
+ self.device = device
24
+ else:
25
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
26
  weight_file = os.getenv("DRONE_YOLO_WEIGHT", self.DEFAULT_WEIGHT)
27
  logging.info(
28
  "Loading drone YOLO weights %s/%s onto %s",
models/detectors/grounding_dino.py CHANGED
@@ -13,11 +13,14 @@ class GroundingDinoDetector(ObjectDetector):
13
 
14
  MODEL_NAME = "IDEA-Research/grounding-dino-base"
15
 
16
- def __init__(self, box_threshold: float = 0.35, text_threshold: float = 0.25) -> None:
17
  self.name = "grounding_dino"
18
  self.box_threshold = box_threshold
19
  self.text_threshold = text_threshold
20
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
21
  logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
22
  self.processor = GroundingDinoProcessor.from_pretrained(self.MODEL_NAME)
23
  self.model = GroundingDinoForObjectDetection.from_pretrained(self.MODEL_NAME)
 
13
 
14
  MODEL_NAME = "IDEA-Research/grounding-dino-base"
15
 
16
+ def __init__(self, box_threshold: float = 0.35, text_threshold: float = 0.25, device: str = None) -> None:
17
  self.name = "grounding_dino"
18
  self.box_threshold = box_threshold
19
  self.text_threshold = text_threshold
20
+ if device:
21
+ self.device = torch.device(device)
22
+ else:
23
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
25
  self.processor = GroundingDinoProcessor.from_pretrained(self.MODEL_NAME)
26
  self.model = GroundingDinoForObjectDetection.from_pretrained(self.MODEL_NAME)
models/detectors/yolov8.py CHANGED
@@ -15,10 +15,13 @@ class HuggingFaceYoloV8Detector(ObjectDetector):
15
  REPO_ID = "Ultralytics/YOLOv8"
16
  WEIGHT_FILE = "yolov8s.pt"
17
 
18
- def __init__(self, score_threshold: float = 0.3) -> None:
19
  self.name = "hf_yolov8"
20
  self.score_threshold = score_threshold
21
- self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 
 
22
  logging.info(
23
  "Loading Hugging Face YOLOv8 weights %s/%s onto %s",
24
  self.REPO_ID,
 
15
  REPO_ID = "Ultralytics/YOLOv8"
16
  WEIGHT_FILE = "yolov8s.pt"
17
 
18
+ def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
19
  self.name = "hf_yolov8"
20
  self.score_threshold = score_threshold
21
+ if device:
22
+ self.device = device
23
+ else:
24
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
25
  logging.info(
26
  "Loading Hugging Face YOLOv8 weights %s/%s onto %s",
27
  self.REPO_ID,
models/model_loader.py CHANGED
@@ -18,13 +18,13 @@ _REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
18
  }
19
 
20
 
21
- def _create_detector(name: str) -> ObjectDetector:
22
  try:
23
  factory = _REGISTRY[name]
24
  except KeyError as exc:
25
  available = ", ".join(sorted(_REGISTRY))
26
  raise ValueError(f"Unknown detector '{name}'. Available: {available}") from exc
27
- return factory()
28
 
29
 
30
  @lru_cache(maxsize=None)
@@ -38,6 +38,11 @@ def load_detector(name: Optional[str] = None) -> ObjectDetector:
38
  return _get_cached_detector(detector_name)
39
 
40
 
 
 
 
 
 
41
  # Backwards compatibility for existing callers.
42
  def load_model():
43
  return load_detector()
 
18
  }
19
 
20
 
21
+ def _create_detector(name: str, **kwargs) -> ObjectDetector:
22
  try:
23
  factory = _REGISTRY[name]
24
  except KeyError as exc:
25
  available = ", ".join(sorted(_REGISTRY))
26
  raise ValueError(f"Unknown detector '{name}'. Available: {available}") from exc
27
+ return factory(**kwargs)
28
 
29
 
30
  @lru_cache(maxsize=None)
 
38
  return _get_cached_detector(detector_name)
39
 
40
 
41
+ def load_detector_on_device(name: str, device: str) -> ObjectDetector:
42
+ """Create a new detector instance on the specified device (no caching)."""
43
+ return _create_detector(name, device=device)
44
+
45
+
46
  # Backwards compatibility for existing callers.
47
  def load_model():
48
  return load_detector()
models/segmenters/model_loader.py CHANGED
@@ -12,7 +12,7 @@ _REGISTRY: Dict[str, Callable[[], Segmenter]] = {
12
  }
13
 
14
 
15
- def _create_segmenter(name: str) -> Segmenter:
16
  """Create a new segmenter instance."""
17
  try:
18
  factory = _REGISTRY[name]
@@ -21,7 +21,7 @@ def _create_segmenter(name: str) -> Segmenter:
21
  raise ValueError(
22
  f"Unknown segmenter '{name}'. Available: {available}"
23
  ) from exc
24
- return factory()
25
 
26
 
27
  @lru_cache(maxsize=None)
@@ -42,3 +42,11 @@ def load_segmenter(name: Optional[str] = None) -> Segmenter:
42
  """
43
  segmenter_name = name or os.getenv("SEGMENTER", DEFAULT_SEGMENTER)
44
  return _get_cached_segmenter(segmenter_name)
 
 
 
 
 
 
 
 
 
12
  }
13
 
14
 
15
+ def _create_segmenter(name: str, **kwargs) -> Segmenter:
16
  """Create a new segmenter instance."""
17
  try:
18
  factory = _REGISTRY[name]
 
21
  raise ValueError(
22
  f"Unknown segmenter '{name}'. Available: {available}"
23
  ) from exc
24
+ return factory(**kwargs)
25
 
26
 
27
  @lru_cache(maxsize=None)
 
42
  """
43
  segmenter_name = name or os.getenv("SEGMENTER", DEFAULT_SEGMENTER)
44
  return _get_cached_segmenter(segmenter_name)
45
+
46
+
47
+ def load_segmenter_on_device(name: str, device: str) -> Segmenter:
48
+ """Create a new segmenter instance on the specified device (no caching)."""
49
+ # bypass cache by calling private creator directly
50
+ # Note: _create_segmenter calls factory() which needs to accept device now.
51
+ # _create_segmenter forwards **kwargs to the factory, so `device` is passed through.
52
+ return _create_segmenter(name, device=device)