Spaces:
Paused
Paused
Zhen Ye
committed on
Commit
·
45eb65b
1
Parent(s):
b2e7d79
feat(inference): enable full multi-GPU support for all models
Browse files- Update inference.py to parallelize detection, segmentation, and depth estimation across all available GPUs
- Update detectors (YOLOv8, DETR, GroundingDINO, DroneYolo) to accept device argument
- Update SAM3 and DepthAnythingV2 to accept device argument
- Add device-specific model loading to all model loaders
- Remove OwlV2 support
- app.py +5 -0
- inference.py +292 -46
- jobs/background.py +2 -0
- jobs/models.py +1 -0
- models/depth_estimators/depth_anything_v2.py +5 -2
- models/depth_estimators/model_loader.py +7 -2
- models/detectors/detr.py +5 -2
- models/detectors/drone_yolo.py +5 -2
- models/detectors/grounding_dino.py +5 -2
- models/detectors/yolov8.py +5 -2
- models/model_loader.py +7 -2
- models/segmenters/model_loader.py +10 -2
app.py
CHANGED
|
@@ -228,6 +228,8 @@ async def detect_endpoint(
|
|
| 228 |
output_path,
|
| 229 |
query_list,
|
| 230 |
detector_name=detector_name,
|
|
|
|
|
|
|
| 231 |
)
|
| 232 |
except ValueError as exc:
|
| 233 |
logging.exception("Video processing failed.")
|
|
@@ -261,6 +263,7 @@ async def detect_async_endpoint(
|
|
| 261 |
detector: str = Form("hf_yolov8"),
|
| 262 |
segmenter: str = Form("sam3"),
|
| 263 |
depth_estimator: str = Form("depth"),
|
|
|
|
| 264 |
):
|
| 265 |
if mode not in VALID_MODES:
|
| 266 |
raise HTTPException(
|
|
@@ -313,6 +316,7 @@ async def detect_async_endpoint(
|
|
| 313 |
detector_name=detector_name,
|
| 314 |
segmenter_name=segmenter,
|
| 315 |
depth_estimator_name=depth_estimator,
|
|
|
|
| 316 |
)
|
| 317 |
cv2.imwrite(str(first_frame_path), processed_frame)
|
| 318 |
except Exception:
|
|
@@ -332,6 +336,7 @@ async def detect_async_endpoint(
|
|
| 332 |
first_frame_path=str(first_frame_path),
|
| 333 |
first_frame_detections=detections,
|
| 334 |
depth_estimator_name=depth_estimator,
|
|
|
|
| 335 |
depth_output_path=str(depth_output_path),
|
| 336 |
first_frame_depth_path=str(first_frame_depth_path),
|
| 337 |
)
|
|
|
|
| 228 |
output_path,
|
| 229 |
query_list,
|
| 230 |
detector_name=detector_name,
|
| 231 |
+
depth_estimator_name="depth", # Sync endpoint default
|
| 232 |
+
depth_scale=1.0,
|
| 233 |
)
|
| 234 |
except ValueError as exc:
|
| 235 |
logging.exception("Video processing failed.")
|
|
|
|
| 263 |
detector: str = Form("hf_yolov8"),
|
| 264 |
segmenter: str = Form("sam3"),
|
| 265 |
depth_estimator: str = Form("depth"),
|
| 266 |
+
depth_scale: float = Form(1.0),
|
| 267 |
):
|
| 268 |
if mode not in VALID_MODES:
|
| 269 |
raise HTTPException(
|
|
|
|
| 316 |
detector_name=detector_name,
|
| 317 |
segmenter_name=segmenter,
|
| 318 |
depth_estimator_name=depth_estimator,
|
| 319 |
+
depth_scale=depth_scale,
|
| 320 |
)
|
| 321 |
cv2.imwrite(str(first_frame_path), processed_frame)
|
| 322 |
except Exception:
|
|
|
|
| 336 |
first_frame_path=str(first_frame_path),
|
| 337 |
first_frame_detections=detections,
|
| 338 |
depth_estimator_name=depth_estimator,
|
| 339 |
+
depth_scale=float(depth_scale),
|
| 340 |
depth_output_path=str(depth_output_path),
|
| 341 |
first_frame_depth_path=str(first_frame_depth_path),
|
| 342 |
)
|
inference.py
CHANGED
|
@@ -5,8 +5,13 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple
|
|
| 5 |
|
| 6 |
import cv2
|
| 7 |
import numpy as np
|
| 8 |
-
|
| 9 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from utils.video import extract_frames, write_video
|
| 11 |
|
| 12 |
|
|
@@ -186,14 +191,25 @@ def _attach_depth_metrics(
|
|
| 186 |
detections: List[Dict[str, Any]],
|
| 187 |
depth_estimator_name: Optional[str],
|
| 188 |
depth_scale: float,
|
|
|
|
| 189 |
) -> None:
|
| 190 |
-
if not detections or not depth_estimator_name:
|
| 191 |
return
|
| 192 |
|
| 193 |
from models.depth_estimators.model_loader import load_depth_estimator
|
| 194 |
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
with lock:
|
| 198 |
depth_result = estimator.predict(frame)
|
| 199 |
|
|
@@ -246,25 +262,56 @@ def infer_frame(
|
|
| 246 |
frame: np.ndarray,
|
| 247 |
queries: Sequence[str],
|
| 248 |
detector_name: Optional[str] = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
) -> tuple[np.ndarray, List[Dict[str, Any]]]:
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
text_queries = list(queries) or ["object"]
|
| 252 |
try:
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
with lock:
|
| 255 |
result = detector.predict(frame, text_queries)
|
| 256 |
detections = _build_detection_records(
|
| 257 |
result.boxes, result.scores, result.labels, text_queries, result.label_names
|
| 258 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
except Exception:
|
| 260 |
logging.exception("Inference failed for queries %s", text_queries)
|
| 261 |
raise
|
| 262 |
return draw_boxes(
|
| 263 |
frame,
|
| 264 |
result.boxes,
|
| 265 |
-
labels=
|
| 266 |
-
queries=
|
| 267 |
-
label_names=
|
| 268 |
), detections
|
| 269 |
|
| 270 |
|
|
@@ -272,9 +319,19 @@ def infer_segmentation_frame(
|
|
| 272 |
frame: np.ndarray,
|
| 273 |
text_queries: Optional[List[str]] = None,
|
| 274 |
segmenter_name: Optional[str] = None,
|
|
|
|
| 275 |
) -> tuple[np.ndarray, Any]:
|
| 276 |
-
|
| 277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
with lock:
|
| 279 |
result = segmenter.predict(frame, text_prompts=text_queries)
|
| 280 |
labels = text_queries or []
|
|
@@ -335,6 +392,8 @@ def run_inference(
|
|
| 335 |
max_frames: Optional[int] = None,
|
| 336 |
detector_name: Optional[str] = None,
|
| 337 |
job_id: Optional[str] = None,
|
|
|
|
|
|
|
| 338 |
) -> str:
|
| 339 |
"""
|
| 340 |
Run object detection inference on a video.
|
|
@@ -346,9 +405,8 @@ def run_inference(
|
|
| 346 |
max_frames: Optional frame limit for testing
|
| 347 |
detector_name: Detector to use (default: hf_yolov8)
|
| 348 |
job_id: Optional job ID for cancellation support
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
Path to processed output video
|
| 352 |
"""
|
| 353 |
try:
|
| 354 |
frames, fps, width, height = extract_frames(input_video_path)
|
|
@@ -367,17 +425,102 @@ def run_inference(
|
|
| 367 |
active_detector = detector_name or "hf_yolov8"
|
| 368 |
logging.info("Using detector: %s", active_detector)
|
| 369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
# Process frames
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
# Write output video
|
| 383 |
write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
|
|
@@ -403,16 +546,64 @@ def run_segmentation(
|
|
| 403 |
active_segmenter = segmenter_name or "sam3"
|
| 404 |
logging.info("Using segmenter: %s with queries: %s", active_segmenter, queries)
|
| 405 |
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
|
| 417 |
write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
|
| 418 |
logging.info("Segmented video written to: %s", output_video_path)
|
|
@@ -490,25 +681,80 @@ def process_frames_depth(
|
|
| 490 |
Returns:
|
| 491 |
List of depth visualization frames (HxWx3 RGB uint8)
|
| 492 |
"""
|
| 493 |
-
from models.depth_estimators.model_loader import load_depth_estimator
|
| 494 |
-
|
| 495 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
|
| 497 |
# First pass: Compute all depth maps and find global range
|
| 498 |
-
|
| 499 |
all_values = []
|
| 500 |
-
|
| 501 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
-
|
| 508 |
-
|
|
|
|
| 509 |
|
| 510 |
-
|
| 511 |
-
|
|
|
|
|
|
|
|
|
|
| 512 |
|
| 513 |
# Compute global min/max (using percentiles to handle outliers)
|
| 514 |
all_depths = np.concatenate(all_values).astype(np.float32, copy=False)
|
|
|
|
| 5 |
|
| 6 |
import cv2
|
| 7 |
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 10 |
+
from threading import RLock
|
| 11 |
+
from models.detectors.base import ObjectDetector
|
| 12 |
+
from models.model_loader import load_detector, load_detector_on_device
|
| 13 |
+
from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
|
| 14 |
+
from models.depth_estimators.model_loader import load_depth_estimator_on_device
|
| 15 |
from utils.video import extract_frames, write_video
|
| 16 |
|
| 17 |
|
|
|
|
| 191 |
detections: List[Dict[str, Any]],
|
| 192 |
depth_estimator_name: Optional[str],
|
| 193 |
depth_scale: float,
|
| 194 |
+
estimator_instance: Optional[Any] = None,
|
| 195 |
) -> None:
|
| 196 |
+
if not detections or (not depth_estimator_name and not estimator_instance):
|
| 197 |
return
|
| 198 |
|
| 199 |
from models.depth_estimators.model_loader import load_depth_estimator
|
| 200 |
|
| 201 |
+
if estimator_instance:
|
| 202 |
+
estimator = estimator_instance
|
| 203 |
+
# Use instance lock if available, or create one
|
| 204 |
+
if hasattr(estimator, "lock"):
|
| 205 |
+
lock = estimator.lock
|
| 206 |
+
else:
|
| 207 |
+
# Fallback (shouldn't happen with our new setup but safe)
|
| 208 |
+
lock = _get_model_lock("depth", estimator.name)
|
| 209 |
+
else:
|
| 210 |
+
estimator = load_depth_estimator(depth_estimator_name)
|
| 211 |
+
lock = _get_model_lock("depth", estimator.name)
|
| 212 |
+
|
| 213 |
with lock:
|
| 214 |
depth_result = estimator.predict(frame)
|
| 215 |
|
|
|
|
| 262 |
frame: np.ndarray,
|
| 263 |
queries: Sequence[str],
|
| 264 |
detector_name: Optional[str] = None,
|
| 265 |
+
depth_estimator_name: Optional[str] = None,
|
| 266 |
+
depth_scale: float = 1.0,
|
| 267 |
+
detector_instance: Optional[ObjectDetector] = None,
|
| 268 |
+
depth_estimator_instance: Optional[Any] = None,
|
| 269 |
) -> tuple[np.ndarray, List[Dict[str, Any]]]:
|
| 270 |
+
if detector_instance:
|
| 271 |
+
detector = detector_instance
|
| 272 |
+
else:
|
| 273 |
+
detector = load_detector(detector_name)
|
| 274 |
+
|
| 275 |
text_queries = list(queries) or ["object"]
|
| 276 |
try:
|
| 277 |
+
if hasattr(detector, "lock"):
|
| 278 |
+
lock = detector.lock
|
| 279 |
+
else:
|
| 280 |
+
lock = _get_model_lock("detector", detector.name)
|
| 281 |
+
|
| 282 |
with lock:
|
| 283 |
result = detector.predict(frame, text_queries)
|
| 284 |
detections = _build_detection_records(
|
| 285 |
result.boxes, result.scores, result.labels, text_queries, result.label_names
|
| 286 |
)
|
| 287 |
+
|
| 288 |
+
if depth_estimator_name or depth_estimator_instance:
|
| 289 |
+
try:
|
| 290 |
+
_attach_depth_metrics(
|
| 291 |
+
frame, detections, depth_estimator_name, depth_scale, estimator_instance=depth_estimator_instance
|
| 292 |
+
)
|
| 293 |
+
except Exception:
|
| 294 |
+
logging.exception("Depth estimation failed for frame")
|
| 295 |
+
|
| 296 |
+
# Re-build display labels to include depth if available
|
| 297 |
+
display_labels = []
|
| 298 |
+
for i, det in enumerate(detections):
|
| 299 |
+
label = det["label"]
|
| 300 |
+
if det.get("depth_valid") and det.get("depth_est_m") is not None:
|
| 301 |
+
# Add depth to label, e.g. "car 12m"
|
| 302 |
+
depth_str = f"{int(det['depth_est_m'])}m"
|
| 303 |
+
label = f"{label} {depth_str}"
|
| 304 |
+
display_labels.append(label)
|
| 305 |
+
|
| 306 |
except Exception:
|
| 307 |
logging.exception("Inference failed for queries %s", text_queries)
|
| 308 |
raise
|
| 309 |
return draw_boxes(
|
| 310 |
frame,
|
| 311 |
result.boxes,
|
| 312 |
+
labels=None, # Use custom labels
|
| 313 |
+
queries=None,
|
| 314 |
+
label_names=display_labels,
|
| 315 |
), detections
|
| 316 |
|
| 317 |
|
|
|
|
| 319 |
frame: np.ndarray,
|
| 320 |
text_queries: Optional[List[str]] = None,
|
| 321 |
segmenter_name: Optional[str] = None,
|
| 322 |
+
segmenter_instance: Optional[Any] = None,
|
| 323 |
) -> tuple[np.ndarray, Any]:
|
| 324 |
+
if segmenter_instance:
|
| 325 |
+
segmenter = segmenter_instance
|
| 326 |
+
# Use instance lock if available
|
| 327 |
+
if hasattr(segmenter, "lock"):
|
| 328 |
+
lock = segmenter.lock
|
| 329 |
+
else:
|
| 330 |
+
lock = _get_model_lock("segmenter", segmenter.name)
|
| 331 |
+
else:
|
| 332 |
+
segmenter = load_segmenter(segmenter_name)
|
| 333 |
+
lock = _get_model_lock("segmenter", segmenter.name)
|
| 334 |
+
|
| 335 |
with lock:
|
| 336 |
result = segmenter.predict(frame, text_prompts=text_queries)
|
| 337 |
labels = text_queries or []
|
|
|
|
| 392 |
max_frames: Optional[int] = None,
|
| 393 |
detector_name: Optional[str] = None,
|
| 394 |
job_id: Optional[str] = None,
|
| 395 |
+
depth_estimator_name: Optional[str] = None,
|
| 396 |
+
depth_scale: float = 1.0,
|
| 397 |
) -> str:
|
| 398 |
"""
|
| 399 |
Run object detection inference on a video.
|
|
|
|
| 405 |
max_frames: Optional frame limit for testing
|
| 406 |
detector_name: Detector to use (default: hf_yolov8)
|
| 407 |
job_id: Optional job ID for cancellation support
|
| 408 |
+
depth_estimator_name: Optional depth estimator name
|
| 409 |
+
depth_scale: Scale factor for depth estimation
|
|
|
|
| 410 |
"""
|
| 411 |
try:
|
| 412 |
frames, fps, width, height = extract_frames(input_video_path)
|
|
|
|
| 425 |
active_detector = detector_name or "hf_yolov8"
|
| 426 |
logging.info("Using detector: %s", active_detector)
|
| 427 |
|
| 428 |
+
# Detect GPUs
|
| 429 |
+
num_gpus = torch.cuda.device_count()
|
| 430 |
+
detectors = None
|
| 431 |
+
depth_estimators = None
|
| 432 |
+
|
| 433 |
+
if num_gpus > 1:
|
| 434 |
+
logging.info("Detected %d GPUs. Enabling Multi-GPU inference.", num_gpus)
|
| 435 |
+
# Initialize one detector per GPU
|
| 436 |
+
detectors = []
|
| 437 |
+
depth_estimators = []
|
| 438 |
+
for i in range(num_gpus):
|
| 439 |
+
device_str = f"cuda:{i}"
|
| 440 |
+
logging.info("Loading detector/depth on %s", device_str)
|
| 441 |
+
|
| 442 |
+
# Detector
|
| 443 |
+
det = load_detector_on_device(active_detector, device_str)
|
| 444 |
+
det.lock = RLock()
|
| 445 |
+
detectors.append(det)
|
| 446 |
+
|
| 447 |
+
# Depth (if requested)
|
| 448 |
+
if depth_estimator_name:
|
| 449 |
+
depth = load_depth_estimator_on_device(depth_estimator_name, device_str)
|
| 450 |
+
depth.lock = RLock()
|
| 451 |
+
depth_estimators.append(depth)
|
| 452 |
+
else:
|
| 453 |
+
depth_estimators.append(None)
|
| 454 |
+
|
| 455 |
+
else:
|
| 456 |
+
logging.info("Single device detected. Using standard inference.")
|
| 457 |
+
detectors = None
|
| 458 |
+
|
| 459 |
+
processed_frames_map = {}
|
| 460 |
+
|
| 461 |
# Process frames
|
| 462 |
+
if detectors:
|
| 463 |
+
# Multi-GPU Parallel Processing
|
| 464 |
+
def process_frame_task(frame_idx: int, frame_data: np.ndarray) -> tuple[int, np.ndarray]:
|
| 465 |
+
# Determine which GPU to use based on frame index (round-robin)
|
| 466 |
+
gpu_idx = frame_idx % len(detectors)
|
| 467 |
+
detector_instance = detectors[gpu_idx]
|
| 468 |
+
depth_instance = depth_estimators[gpu_idx] if depth_estimators else None
|
| 469 |
+
|
| 470 |
+
# Run depth estimation every 3 frames if configured
|
| 471 |
+
active_depth_name = depth_estimator_name if (frame_idx % 3 == 0) else None
|
| 472 |
+
active_depth_instance = depth_instance if (frame_idx % 3 == 0) else None
|
| 473 |
+
|
| 474 |
+
processed, _ = infer_frame(
|
| 475 |
+
frame_data,
|
| 476 |
+
queries,
|
| 477 |
+
detector_name=None, # Use instance
|
| 478 |
+
depth_estimator_name=active_depth_name,
|
| 479 |
+
depth_scale=depth_scale,
|
| 480 |
+
detector_instance=detector_instance,
|
| 481 |
+
depth_estimator_instance=active_depth_instance
|
| 482 |
+
)
|
| 483 |
+
return frame_idx, processed
|
| 484 |
+
|
| 485 |
+
# Thread pool with more workers than GPUs to keep them fed
|
| 486 |
+
max_workers = min(len(detectors) * 2, 8)
|
| 487 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 488 |
+
futures = []
|
| 489 |
+
for idx, frame in enumerate(frames):
|
| 490 |
+
_check_cancellation(job_id)
|
| 491 |
+
if max_frames is not None and idx >= max_frames:
|
| 492 |
+
break
|
| 493 |
+
futures.append(executor.submit(process_frame_task, idx, frame))
|
| 494 |
+
|
| 495 |
+
for future in futures:
|
| 496 |
+
idx, result_frame = future.result() # Wait for completion; frames are re-inserted by index below to preserve order
|
| 497 |
+
processed_frames_map[idx] = result_frame
|
| 498 |
+
|
| 499 |
+
# Reassemble in order
|
| 500 |
+
processed_frames = [processed_frames_map[i] for i in range(len(processed_frames_map))]
|
| 501 |
|
| 502 |
+
else:
|
| 503 |
+
# Standard Single-Threaded Loop
|
| 504 |
+
processed_frames = []
|
| 505 |
+
for idx, frame in enumerate(frames):
|
| 506 |
+
# Check for cancellation every frame
|
| 507 |
+
_check_cancellation(job_id)
|
| 508 |
+
|
| 509 |
+
if max_frames is not None and idx >= max_frames:
|
| 510 |
+
break
|
| 511 |
+
logging.debug("Processing frame %d", idx)
|
| 512 |
+
|
| 513 |
+
# Run depth estimation every 3 frames if configured
|
| 514 |
+
active_depth = depth_estimator_name if (idx % 3 == 0) else None
|
| 515 |
+
|
| 516 |
+
processed_frame, _ = infer_frame(
|
| 517 |
+
frame,
|
| 518 |
+
queries,
|
| 519 |
+
detector_name=active_detector,
|
| 520 |
+
depth_estimator_name=active_depth,
|
| 521 |
+
depth_scale=depth_scale
|
| 522 |
+
)
|
| 523 |
+
processed_frames.append(processed_frame)
|
| 524 |
|
| 525 |
# Write output video
|
| 526 |
write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
|
|
|
|
| 546 |
active_segmenter = segmenter_name or "sam3"
|
| 547 |
logging.info("Using segmenter: %s with queries: %s", active_segmenter, queries)
|
| 548 |
|
| 549 |
+
# Detect GPUs
|
| 550 |
+
num_gpus = torch.cuda.device_count()
|
| 551 |
+
segmenters = None
|
| 552 |
+
if num_gpus > 1:
|
| 553 |
+
logging.info("Detected %d GPUs. Enabling Multi-GPU segmentation.", num_gpus)
|
| 554 |
+
segmenters = []
|
| 555 |
+
for i in range(num_gpus):
|
| 556 |
+
device_str = f"cuda:{i}"
|
| 557 |
+
logging.info("Loading segmenter on %s", device_str)
|
| 558 |
+
seg = load_segmenter_on_device(active_segmenter, device_str)
|
| 559 |
+
seg.lock = RLock()
|
| 560 |
+
segmenters.append(seg)
|
| 561 |
+
else:
|
| 562 |
+
logging.info("Single device detected. Using standard segmentation.")
|
| 563 |
+
segmenters = None
|
| 564 |
+
|
| 565 |
+
processed_frames_map = {}
|
| 566 |
+
|
| 567 |
+
if segmenters:
|
| 568 |
+
# Multi-GPU Parallel Processing
|
| 569 |
+
def process_segmentation_task(frame_idx: int, frame_data: np.ndarray) -> tuple[int, np.ndarray]:
|
| 570 |
+
gpu_idx = frame_idx % len(segmenters)
|
| 571 |
+
segmenter_instance = segmenters[gpu_idx]
|
| 572 |
+
|
| 573 |
+
processed, _ = infer_segmentation_frame(
|
| 574 |
+
frame_data,
|
| 575 |
+
text_queries=queries,
|
| 576 |
+
segmenter_name=None,
|
| 577 |
+
segmenter_instance=segmenter_instance
|
| 578 |
+
)
|
| 579 |
+
return frame_idx, processed
|
| 580 |
+
|
| 581 |
+
max_workers = min(len(segmenters) * 2, 8)
|
| 582 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 583 |
+
futures = []
|
| 584 |
+
for idx, frame in enumerate(frames):
|
| 585 |
+
_check_cancellation(job_id)
|
| 586 |
+
if max_frames is not None and idx >= max_frames:
|
| 587 |
+
break
|
| 588 |
+
futures.append(executor.submit(process_segmentation_task, idx, frame))
|
| 589 |
+
|
| 590 |
+
for future in futures:
|
| 591 |
+
idx, result_frame = future.result()
|
| 592 |
+
processed_frames_map[idx] = result_frame
|
| 593 |
+
|
| 594 |
+
processed_frames = [processed_frames_map[i] for i in range(len(processed_frames_map))]
|
| 595 |
+
|
| 596 |
+
else:
|
| 597 |
+
processed_frames: List[np.ndarray] = []
|
| 598 |
+
for idx, frame in enumerate(frames):
|
| 599 |
+
# Check for cancellation every frame
|
| 600 |
+
_check_cancellation(job_id)
|
| 601 |
|
| 602 |
+
if max_frames is not None and idx >= max_frames:
|
| 603 |
+
break
|
| 604 |
+
logging.debug("Processing frame %d", idx)
|
| 605 |
+
processed_frame, _ = infer_segmentation_frame(frame, text_queries=queries, segmenter_name=active_segmenter)
|
| 606 |
+
processed_frames.append(processed_frame)
|
| 607 |
|
| 608 |
write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
|
| 609 |
logging.info("Segmented video written to: %s", output_video_path)
|
|
|
|
| 681 |
Returns:
|
| 682 |
List of depth visualization frames (HxWx3 RGB uint8)
|
| 683 |
"""
|
| 684 |
+
from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
|
| 685 |
+
|
| 686 |
+
# Detect GPUs
|
| 687 |
+
num_gpus = torch.cuda.device_count()
|
| 688 |
+
estimators = None
|
| 689 |
+
if num_gpus > 1:
|
| 690 |
+
logging.info("Detected %d GPUs. Enabling Multi-GPU depth estimation.", num_gpus)
|
| 691 |
+
estimators = []
|
| 692 |
+
for i in range(num_gpus):
|
| 693 |
+
device_str = f"cuda:{i}"
|
| 694 |
+
logging.info("Loading depth estimator on %s", device_str)
|
| 695 |
+
est = load_depth_estimator_on_device(depth_estimator_name, device_str)
|
| 696 |
+
est.lock = RLock()
|
| 697 |
+
estimators.append(est)
|
| 698 |
+
else:
|
| 699 |
+
logging.info("Single device detected. Using standard depth estimation.")
|
| 700 |
+
estimators = None
|
| 701 |
+
# Fallback to single estimator
|
| 702 |
+
single_estimator = load_depth_estimator(depth_estimator_name)
|
| 703 |
|
| 704 |
# First pass: Compute all depth maps and find global range
|
| 705 |
+
depth_maps_map = {}
|
| 706 |
all_values = []
|
| 707 |
+
|
| 708 |
+
if estimators:
|
| 709 |
+
# Multi-GPU Parallel Processing
|
| 710 |
+
def compute_depth_task(frame_idx: int, frame_data: np.ndarray) -> tuple[int, Any]:
|
| 711 |
+
gpu_idx = frame_idx % len(estimators)
|
| 712 |
+
estimator_instance = estimators[gpu_idx]
|
| 713 |
+
|
| 714 |
+
# Use instance lock
|
| 715 |
+
if hasattr(estimator_instance, "lock"):
|
| 716 |
+
lock = estimator_instance.lock
|
| 717 |
+
else:
|
| 718 |
+
# Should have been assigned above
|
| 719 |
+
lock = RLock()
|
| 720 |
+
|
| 721 |
+
with lock:
|
| 722 |
+
result = estimator_instance.predict(frame_data)
|
| 723 |
+
return frame_idx, result
|
| 724 |
+
|
| 725 |
+
max_workers = min(len(estimators) * 2, 8)
|
| 726 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 727 |
+
futures = []
|
| 728 |
+
for idx, frame in enumerate(frames):
|
| 729 |
+
_check_cancellation(job_id)
|
| 730 |
+
futures.append(executor.submit(compute_depth_task, idx, frame))
|
| 731 |
+
|
| 732 |
+
for future in futures:
|
| 733 |
+
idx, res = future.result()
|
| 734 |
+
depth_maps_map[idx] = res.depth_map
|
| 735 |
+
# We need to collect values for global min/max.
|
| 736 |
+
# Values are collected after all futures complete, keeping the worker tasks lightweight
|
| 737 |
+
|
| 738 |
+
# Reassemble
|
| 739 |
+
depth_maps = [depth_maps_map[i] for i in range(len(depth_maps_map))]
|
| 740 |
+
all_values = [dm.ravel() for dm in depth_maps]
|
| 741 |
|
| 742 |
+
else:
|
| 743 |
+
# Single threaded
|
| 744 |
+
estimator = single_estimator
|
| 745 |
+
depth_maps = []
|
| 746 |
+
for idx, frame in enumerate(frames):
|
| 747 |
+
_check_cancellation(job_id)
|
| 748 |
|
| 749 |
+
lock = _get_model_lock("depth", estimator.name)
|
| 750 |
+
with lock:
|
| 751 |
+
depth_result = estimator.predict(frame)
|
| 752 |
|
| 753 |
+
depth_maps.append(depth_result.depth_map)
|
| 754 |
+
all_values.append(depth_result.depth_map.ravel())
|
| 755 |
+
|
| 756 |
+
if idx % 10 == 0:
|
| 757 |
+
logging.debug("Computed depth for frame %d/%d", idx + 1, len(frames))
|
| 758 |
|
| 759 |
# Compute global min/max (using percentiles to handle outliers)
|
| 760 |
all_depths = np.concatenate(all_values).astype(np.float32, copy=False)
|
jobs/background.py
CHANGED
|
@@ -41,6 +41,8 @@ async def process_video_async(job_id: str) -> None:
|
|
| 41 |
None,
|
| 42 |
job.detector_name,
|
| 43 |
job_id,
|
|
|
|
|
|
|
| 44 |
)
|
| 45 |
|
| 46 |
# Try to run depth estimation
|
|
|
|
| 41 |
None,
|
| 42 |
job.detector_name,
|
| 43 |
job_id,
|
| 44 |
+
job.depth_estimator_name,
|
| 45 |
+
job.depth_scale,
|
| 46 |
)
|
| 47 |
|
| 48 |
# Try to run depth estimation
|
jobs/models.py
CHANGED
|
@@ -28,6 +28,7 @@ class JobInfo:
|
|
| 28 |
first_frame_detections: List[Dict[str, Any]] = field(default_factory=list)
|
| 29 |
# Depth estimation fields
|
| 30 |
depth_estimator_name: str = "depth"
|
|
|
|
| 31 |
depth_output_path: Optional[str] = None
|
| 32 |
first_frame_depth_path: Optional[str] = None
|
| 33 |
partial_success: bool = False # True if one component failed but job completed
|
|
|
|
| 28 |
first_frame_detections: List[Dict[str, Any]] = field(default_factory=list)
|
| 29 |
# Depth estimation fields
|
| 30 |
depth_estimator_name: str = "depth"
|
| 31 |
+
depth_scale: float = 1.0
|
| 32 |
depth_output_path: Optional[str] = None
|
| 33 |
first_frame_depth_path: Optional[str] = None
|
| 34 |
partial_success: bool = False # True if one component failed but job completed
|
models/depth_estimators/depth_anything_v2.py
CHANGED
|
@@ -13,10 +13,13 @@ class DepthAnythingV2Estimator(DepthEstimator):
|
|
| 13 |
|
| 14 |
name = "depth"
|
| 15 |
|
| 16 |
-
def __init__(self) -> None:
|
| 17 |
logging.info("Loading Depth-Anything model from Hugging Face (transformers)...")
|
| 18 |
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
model_id = "LiheYoung/depth-anything-large-hf"
|
| 22 |
self.image_processor = AutoImageProcessor.from_pretrained(model_id)
|
|
|
|
| 13 |
|
| 14 |
name = "depth"
|
| 15 |
|
| 16 |
+
def __init__(self, device: str = None) -> None:
|
| 17 |
logging.info("Loading Depth-Anything model from Hugging Face (transformers)...")
|
| 18 |
|
| 19 |
+
if device:
|
| 20 |
+
self.device = torch.device(device)
|
| 21 |
+
else:
|
| 22 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
|
| 24 |
model_id = "LiheYoung/depth-anything-large-hf"
|
| 25 |
self.image_processor = AutoImageProcessor.from_pretrained(model_id)
|
models/depth_estimators/model_loader.py
CHANGED
|
@@ -27,7 +27,7 @@ def _get_cached_depth_estimator(name: str) -> DepthEstimator:
|
|
| 27 |
return _create_depth_estimator(name)
|
| 28 |
|
| 29 |
|
| 30 |
-
def _create_depth_estimator(name: str) -> DepthEstimator:
|
| 31 |
"""
|
| 32 |
Create depth estimator instance.
|
| 33 |
|
|
@@ -46,7 +46,7 @@ def _create_depth_estimator(name: str) -> DepthEstimator:
|
|
| 46 |
)
|
| 47 |
|
| 48 |
estimator_class = _REGISTRY[name]
|
| 49 |
-
return estimator_class()
|
| 50 |
|
| 51 |
|
| 52 |
def load_depth_estimator(name: str = "depth") -> DepthEstimator:
|
|
@@ -62,6 +62,11 @@ def load_depth_estimator(name: str = "depth") -> DepthEstimator:
|
|
| 62 |
return _get_cached_depth_estimator(name)
|
| 63 |
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
def list_depth_estimators() -> list[str]:
|
| 66 |
"""Return list of available depth estimator names."""
|
| 67 |
return list(_REGISTRY.keys())
|
|
|
|
| 27 |
return _create_depth_estimator(name)
|
| 28 |
|
| 29 |
|
| 30 |
+
def _create_depth_estimator(name: str, **kwargs) -> DepthEstimator:
|
| 31 |
"""
|
| 32 |
Create depth estimator instance.
|
| 33 |
|
|
|
|
| 46 |
)
|
| 47 |
|
| 48 |
estimator_class = _REGISTRY[name]
|
| 49 |
+
return estimator_class(**kwargs)
|
| 50 |
|
| 51 |
|
| 52 |
def load_depth_estimator(name: str = "depth") -> DepthEstimator:
|
|
|
|
| 62 |
return _get_cached_depth_estimator(name)
|
| 63 |
|
| 64 |
|
| 65 |
+
def load_depth_estimator_on_device(name: str, device: str) -> DepthEstimator:
|
| 66 |
+
"""Create a new depth estimator instance on the specified device (no caching)."""
|
| 67 |
+
return _create_depth_estimator(name, device=device)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
def list_depth_estimators() -> list[str]:
|
| 71 |
"""Return list of available depth estimator names."""
|
| 72 |
return list(_REGISTRY.keys())
|
models/detectors/detr.py
CHANGED
|
@@ -13,10 +13,13 @@ class DetrDetector(ObjectDetector):
|
|
| 13 |
|
| 14 |
MODEL_NAME = "facebook/detr-resnet-50"
|
| 15 |
|
| 16 |
-
def __init__(self, score_threshold: float = 0.3) -> None:
|
| 17 |
self.name = "detr_resnet50"
|
| 18 |
self.score_threshold = score_threshold
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
| 20 |
logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
|
| 21 |
self.processor = DetrImageProcessor.from_pretrained(self.MODEL_NAME)
|
| 22 |
self.model = DetrForObjectDetection.from_pretrained(self.MODEL_NAME)
|
|
|
|
| 13 |
|
| 14 |
MODEL_NAME = "facebook/detr-resnet-50"
|
| 15 |
|
| 16 |
+
def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
|
| 17 |
self.name = "detr_resnet50"
|
| 18 |
self.score_threshold = score_threshold
|
| 19 |
+
if device:
|
| 20 |
+
self.device = torch.device(device)
|
| 21 |
+
else:
|
| 22 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
|
| 24 |
self.processor = DetrImageProcessor.from_pretrained(self.MODEL_NAME)
|
| 25 |
self.model = DetrForObjectDetection.from_pretrained(self.MODEL_NAME)
|
models/detectors/drone_yolo.py
CHANGED
|
@@ -16,10 +16,13 @@ class DroneYoloDetector(ObjectDetector):
|
|
| 16 |
REPO_ID = "rujutashashikanjoshi/yolo12-drone-detection-0205-100m"
|
| 17 |
DEFAULT_WEIGHT = "best.pt"
|
| 18 |
|
| 19 |
-
def __init__(self, score_threshold: float = 0.3) -> None:
|
| 20 |
self.name = "drone_yolo"
|
| 21 |
self.score_threshold = score_threshold
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
| 23 |
weight_file = os.getenv("DRONE_YOLO_WEIGHT", self.DEFAULT_WEIGHT)
|
| 24 |
logging.info(
|
| 25 |
"Loading drone YOLO weights %s/%s onto %s",
|
|
|
|
| 16 |
REPO_ID = "rujutashashikanjoshi/yolo12-drone-detection-0205-100m"
|
| 17 |
DEFAULT_WEIGHT = "best.pt"
|
| 18 |
|
| 19 |
+
def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
|
| 20 |
self.name = "drone_yolo"
|
| 21 |
self.score_threshold = score_threshold
|
| 22 |
+
if device:
|
| 23 |
+
self.device = device
|
| 24 |
+
else:
|
| 25 |
+
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 26 |
weight_file = os.getenv("DRONE_YOLO_WEIGHT", self.DEFAULT_WEIGHT)
|
| 27 |
logging.info(
|
| 28 |
"Loading drone YOLO weights %s/%s onto %s",
|
models/detectors/grounding_dino.py
CHANGED
|
@@ -13,11 +13,14 @@ class GroundingDinoDetector(ObjectDetector):
|
|
| 13 |
|
| 14 |
MODEL_NAME = "IDEA-Research/grounding-dino-base"
|
| 15 |
|
| 16 |
-
def __init__(self, box_threshold: float = 0.35, text_threshold: float = 0.25) -> None:
|
| 17 |
self.name = "grounding_dino"
|
| 18 |
self.box_threshold = box_threshold
|
| 19 |
self.text_threshold = text_threshold
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
| 21 |
logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
|
| 22 |
self.processor = GroundingDinoProcessor.from_pretrained(self.MODEL_NAME)
|
| 23 |
self.model = GroundingDinoForObjectDetection.from_pretrained(self.MODEL_NAME)
|
|
|
|
| 13 |
|
| 14 |
MODEL_NAME = "IDEA-Research/grounding-dino-base"
|
| 15 |
|
| 16 |
+
def __init__(self, box_threshold: float = 0.35, text_threshold: float = 0.25, device: str = None) -> None:
|
| 17 |
self.name = "grounding_dino"
|
| 18 |
self.box_threshold = box_threshold
|
| 19 |
self.text_threshold = text_threshold
|
| 20 |
+
if device:
|
| 21 |
+
self.device = torch.device(device)
|
| 22 |
+
else:
|
| 23 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
|
| 25 |
self.processor = GroundingDinoProcessor.from_pretrained(self.MODEL_NAME)
|
| 26 |
self.model = GroundingDinoForObjectDetection.from_pretrained(self.MODEL_NAME)
|
models/detectors/yolov8.py
CHANGED
|
@@ -15,10 +15,13 @@ class HuggingFaceYoloV8Detector(ObjectDetector):
|
|
| 15 |
REPO_ID = "Ultralytics/YOLOv8"
|
| 16 |
WEIGHT_FILE = "yolov8s.pt"
|
| 17 |
|
| 18 |
-
def __init__(self, score_threshold: float = 0.3) -> None:
|
| 19 |
self.name = "hf_yolov8"
|
| 20 |
self.score_threshold = score_threshold
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
| 22 |
logging.info(
|
| 23 |
"Loading Hugging Face YOLOv8 weights %s/%s onto %s",
|
| 24 |
self.REPO_ID,
|
|
|
|
| 15 |
REPO_ID = "Ultralytics/YOLOv8"
|
| 16 |
WEIGHT_FILE = "yolov8s.pt"
|
| 17 |
|
| 18 |
+
def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
|
| 19 |
self.name = "hf_yolov8"
|
| 20 |
self.score_threshold = score_threshold
|
| 21 |
+
if device:
|
| 22 |
+
self.device = device
|
| 23 |
+
else:
|
| 24 |
+
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 25 |
logging.info(
|
| 26 |
"Loading Hugging Face YOLOv8 weights %s/%s onto %s",
|
| 27 |
self.REPO_ID,
|
models/model_loader.py
CHANGED
|
@@ -18,13 +18,13 @@ _REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
|
|
| 18 |
}
|
| 19 |
|
| 20 |
|
| 21 |
-
def _create_detector(name: str) -> ObjectDetector:
|
| 22 |
try:
|
| 23 |
factory = _REGISTRY[name]
|
| 24 |
except KeyError as exc:
|
| 25 |
available = ", ".join(sorted(_REGISTRY))
|
| 26 |
raise ValueError(f"Unknown detector '{name}'. Available: {available}") from exc
|
| 27 |
-
return factory()
|
| 28 |
|
| 29 |
|
| 30 |
@lru_cache(maxsize=None)
|
|
@@ -38,6 +38,11 @@ def load_detector(name: Optional[str] = None) -> ObjectDetector:
|
|
| 38 |
return _get_cached_detector(detector_name)
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
# Backwards compatibility for existing callers.
|
| 42 |
def load_model():
|
| 43 |
return load_detector()
|
|
|
|
| 18 |
}
|
| 19 |
|
| 20 |
|
| 21 |
+
def _create_detector(name: str, **kwargs) -> ObjectDetector:
|
| 22 |
try:
|
| 23 |
factory = _REGISTRY[name]
|
| 24 |
except KeyError as exc:
|
| 25 |
available = ", ".join(sorted(_REGISTRY))
|
| 26 |
raise ValueError(f"Unknown detector '{name}'. Available: {available}") from exc
|
| 27 |
+
return factory(**kwargs)
|
| 28 |
|
| 29 |
|
| 30 |
@lru_cache(maxsize=None)
|
|
|
|
| 38 |
return _get_cached_detector(detector_name)
|
| 39 |
|
| 40 |
|
| 41 |
+
def load_detector_on_device(name: str, device: str) -> ObjectDetector:
|
| 42 |
+
"""Create a new detector instance on the specified device (no caching)."""
|
| 43 |
+
return _create_detector(name, device=device)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
# Backwards compatibility for existing callers.
|
| 47 |
def load_model():
|
| 48 |
return load_detector()
|
models/segmenters/model_loader.py
CHANGED
|
@@ -12,7 +12,7 @@ _REGISTRY: Dict[str, Callable[[], Segmenter]] = {
|
|
| 12 |
}
|
| 13 |
|
| 14 |
|
| 15 |
-
def _create_segmenter(name: str) -> Segmenter:
|
| 16 |
"""Create a new segmenter instance."""
|
| 17 |
try:
|
| 18 |
factory = _REGISTRY[name]
|
|
@@ -21,7 +21,7 @@ def _create_segmenter(name: str) -> Segmenter:
|
|
| 21 |
raise ValueError(
|
| 22 |
f"Unknown segmenter '{name}'. Available: {available}"
|
| 23 |
) from exc
|
| 24 |
-
return factory()
|
| 25 |
|
| 26 |
|
| 27 |
@lru_cache(maxsize=None)
|
|
@@ -42,3 +42,11 @@ def load_segmenter(name: Optional[str] = None) -> Segmenter:
|
|
| 42 |
"""
|
| 43 |
segmenter_name = name or os.getenv("SEGMENTER", DEFAULT_SEGMENTER)
|
| 44 |
return _get_cached_segmenter(segmenter_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
}
|
| 13 |
|
| 14 |
|
| 15 |
+
def _create_segmenter(name: str, **kwargs) -> Segmenter:
|
| 16 |
"""Create a new segmenter instance."""
|
| 17 |
try:
|
| 18 |
factory = _REGISTRY[name]
|
|
|
|
| 21 |
raise ValueError(
|
| 22 |
f"Unknown segmenter '{name}'. Available: {available}"
|
| 23 |
) from exc
|
| 24 |
+
return factory(**kwargs)
|
| 25 |
|
| 26 |
|
| 27 |
@lru_cache(maxsize=None)
|
|
|
|
| 42 |
"""
|
| 43 |
segmenter_name = name or os.getenv("SEGMENTER", DEFAULT_SEGMENTER)
|
| 44 |
return _get_cached_segmenter(segmenter_name)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_segmenter_on_device(name: str, device: str) -> Segmenter:
|
| 48 |
+
"""Create a new segmenter instance on the specified device (no caching)."""
|
| 49 |
+
# bypass cache by calling private creator directly
|
| 50 |
+
# Note: _create_segmenter calls factory() which needs to accept device now.
|
| 51 |
+
# We need to update _create_segmenter to pass kwargs too.
|
| 52 |
+
return _create_segmenter(name, device=device)
|