Spaces:
Sleeping
Sleeping
Zhen Ye commited on
Commit ·
9a35d39
1
Parent(s): 8c16918
Optimize model loading to prevent reload loop
Browse files- inference.py +17 -5
inference.py
CHANGED
|
@@ -11,7 +11,7 @@ from threading import RLock
|
|
| 11 |
from models.detectors.base import ObjectDetector
|
| 12 |
from models.model_loader import load_detector, load_detector_on_device
|
| 13 |
from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
|
| 14 |
-
from models.depth_estimators.model_loader import load_depth_estimator_on_device
|
| 15 |
from utils.video import extract_frames, write_video
|
| 16 |
|
| 17 |
|
|
@@ -529,6 +529,15 @@ def run_inference(
|
|
| 529 |
|
| 530 |
else:
|
| 531 |
# Standard Single-Threaded Loop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
processed_frames = []
|
| 533 |
all_detections = []
|
| 534 |
for idx, frame in enumerate(frames):
|
|
@@ -540,14 +549,17 @@ def run_inference(
|
|
| 540 |
logging.debug("Processing frame %d", idx)
|
| 541 |
|
| 542 |
# Run depth estimation every 3 frames if configured
|
| 543 |
-
|
|
|
|
| 544 |
|
| 545 |
processed_frame, frame_dets = infer_frame(
|
| 546 |
frame,
|
| 547 |
queries,
|
| 548 |
-
detector_name=
|
| 549 |
-
depth_estimator_name=
|
| 550 |
-
depth_scale=depth_scale
|
|
|
|
|
|
|
| 551 |
)
|
| 552 |
processed_frames.append(processed_frame)
|
| 553 |
all_detections.append(frame_dets)
|
|
|
|
| 11 |
from models.detectors.base import ObjectDetector
|
| 12 |
from models.model_loader import load_detector, load_detector_on_device
|
| 13 |
from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
|
| 14 |
+
from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
|
| 15 |
from utils.video import extract_frames, write_video
|
| 16 |
|
| 17 |
|
|
|
|
| 529 |
|
| 530 |
else:
|
| 531 |
# Standard Single-Threaded Loop
|
| 532 |
+
# Pre-load models to ensure they are loaded once
|
| 533 |
+
detector_instance = load_detector(active_detector)
|
| 534 |
+
detector_instance.lock = _get_model_lock("detector", detector_instance.name)
|
| 535 |
+
|
| 536 |
+
depth_estimator_instance = None
|
| 537 |
+
if depth_estimator_name:
|
| 538 |
+
depth_estimator_instance = load_depth_estimator(depth_estimator_name)
|
| 539 |
+
depth_estimator_instance.lock = _get_model_lock("depth", depth_estimator_instance.name)
|
| 540 |
+
|
| 541 |
processed_frames = []
|
| 542 |
all_detections = []
|
| 543 |
for idx, frame in enumerate(frames):
|
|
|
|
| 549 |
logging.debug("Processing frame %d", idx)
|
| 550 |
|
| 551 |
# Run depth estimation every 3 frames if configured
|
| 552 |
+
active_depth_name = depth_estimator_name if (idx % 3 == 0) else None
|
| 553 |
+
active_depth_instance = depth_estimator_instance if (idx % 3 == 0) else None
|
| 554 |
|
| 555 |
processed_frame, frame_dets = infer_frame(
|
| 556 |
frame,
|
| 557 |
queries,
|
| 558 |
+
detector_name=None,
|
| 559 |
+
depth_estimator_name=active_depth_name,
|
| 560 |
+
depth_scale=depth_scale,
|
| 561 |
+
detector_instance=detector_instance,
|
| 562 |
+
depth_estimator_instance=active_depth_instance
|
| 563 |
)
|
| 564 |
processed_frames.append(processed_frame)
|
| 565 |
all_detections.append(frame_dets)
|