Zhen Ye commited on
Commit
9a35d39
·
1 Parent(s): 8c16918

Optimize model loading to prevent reload loop

Browse files
Files changed (1) hide show
  1. inference.py +17 -5
inference.py CHANGED
@@ -11,7 +11,7 @@ from threading import RLock
11
  from models.detectors.base import ObjectDetector
12
  from models.model_loader import load_detector, load_detector_on_device
13
  from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
14
- from models.depth_estimators.model_loader import load_depth_estimator_on_device
15
  from utils.video import extract_frames, write_video
16
 
17
 
@@ -529,6 +529,15 @@ def run_inference(
529
 
530
  else:
531
  # Standard Single-Threaded Loop
 
 
 
 
 
 
 
 
 
532
  processed_frames = []
533
  all_detections = []
534
  for idx, frame in enumerate(frames):
@@ -540,14 +549,17 @@ def run_inference(
540
  logging.debug("Processing frame %d", idx)
541
 
542
  # Run depth estimation every 3 frames if configured
543
- active_depth = depth_estimator_name if (idx % 3 == 0) else None
 
544
 
545
  processed_frame, frame_dets = infer_frame(
546
  frame,
547
  queries,
548
- detector_name=active_detector,
549
- depth_estimator_name=active_depth,
550
- depth_scale=depth_scale
 
 
551
  )
552
  processed_frames.append(processed_frame)
553
  all_detections.append(frame_dets)
 
11
  from models.detectors.base import ObjectDetector
12
  from models.model_loader import load_detector, load_detector_on_device
13
  from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
14
+ from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
15
  from utils.video import extract_frames, write_video
16
 
17
 
 
529
 
530
  else:
531
  # Standard Single-Threaded Loop
532
+ # Pre-load models to ensure they are loaded once
533
+ detector_instance = load_detector(active_detector)
534
+ detector_instance.lock = _get_model_lock("detector", detector_instance.name)
535
+
536
+ depth_estimator_instance = None
537
+ if depth_estimator_name:
538
+ depth_estimator_instance = load_depth_estimator(depth_estimator_name)
539
+ depth_estimator_instance.lock = _get_model_lock("depth", depth_estimator_instance.name)
540
+
541
  processed_frames = []
542
  all_detections = []
543
  for idx, frame in enumerate(frames):
 
549
  logging.debug("Processing frame %d", idx)
550
 
551
  # Run depth estimation every 3 frames if configured
552
+ active_depth_name = depth_estimator_name if (idx % 3 == 0) else None
553
+ active_depth_instance = depth_estimator_instance if (idx % 3 == 0) else None
554
 
555
  processed_frame, frame_dets = infer_frame(
556
  frame,
557
  queries,
558
+ detector_name=None,
559
+ depth_estimator_name=active_depth_name,
560
+ depth_scale=depth_scale,
561
+ detector_instance=detector_instance,
562
+ depth_estimator_instance=active_depth_instance
563
  )
564
  processed_frames.append(processed_frame)
565
  all_detections.append(frame_dets)