# 3d_model/ylff/utils/pipeline_parallel.py
# Author: Azan
# Clean deployment build (squashed) — commit 7a87926
"""
GPU/CPU pipeline parallelism utilities.
Allows overlapping GPU inference with CPU-bound operations (like BA validation)
for better resource utilization.
"""
import logging
import time
from queue import Empty, Queue
from threading import Lock, Thread
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch
logger = logging.getLogger(__name__)
class PipelineProcessor:
    """
    Pipeline processor that overlaps GPU and CPU work.

    GPU worker: Runs model inference
    CPU worker: Runs CPU-bound operations (BA validation, etc.)

    Completed items land in ``results`` keyed by item id; callers retrieve
    them with :meth:`get_result`.
    """

    def __init__(
        self,
        gpu_worker_fn: Callable,
        cpu_worker_fn: Callable,
        gpu_queue_size: int = 10,
        cpu_queue_size: int = 10,
    ):
        """
        Args:
            gpu_worker_fn: Function to run on GPU (takes images, returns output)
            cpu_worker_fn: Function to run on CPU (takes GPU output, images and
                metadata, returns result)
            gpu_queue_size: Size of GPU work queue
            cpu_queue_size: Size of CPU work queue
        """
        self.gpu_worker_fn = gpu_worker_fn
        self.cpu_worker_fn = cpu_worker_fn
        self.gpu_queue: Queue = Queue(maxsize=gpu_queue_size)
        self.cpu_queue: Queue = Queue(maxsize=cpu_queue_size)
        self.gpu_thread: Optional[Thread] = None
        self.cpu_thread: Optional[Thread] = None
        self.running = False
        # item_id -> result. Written by both worker threads and popped by the
        # caller in get_result(), so every access is guarded by _results_lock.
        self.results: Dict[str, Any] = {}
        self._results_lock = Lock()

    def start(self):
        """Start GPU and CPU worker threads. No-op if already running."""
        if self.running:
            logger.warning("Pipeline processor already running")
            return
        self.running = True
        self.gpu_thread = Thread(target=self._gpu_worker, daemon=True)
        self.cpu_thread = Thread(target=self._cpu_worker, daemon=True)
        self.gpu_thread.start()
        self.cpu_thread.start()
        logger.info("Pipeline processor started (GPU + CPU workers)")

    def stop(self):
        """Stop worker threads, waiting up to 5 seconds for each to exit."""
        self.running = False
        # Sentinels wake workers that are blocked in Queue.get().
        self.gpu_queue.put(None)
        self.cpu_queue.put(None)
        if self.gpu_thread:
            self.gpu_thread.join(timeout=5.0)
        if self.cpu_thread:
            self.cpu_thread.join(timeout=5.0)
        logger.info("Pipeline processor stopped")

    def submit(
        self, item_id: str, images: List[np.ndarray], metadata: Optional[Dict] = None
    ) -> str:
        """
        Submit work item to pipeline. Blocks while the GPU queue is full.

        Args:
            item_id: Unique identifier for this item
            images: Input images
            metadata: Optional metadata

        Returns:
            Item ID
        """
        self.gpu_queue.put((item_id, images, metadata))
        return item_id

    def get_result(self, item_id: str, timeout: Optional[float] = None) -> Optional[Any]:
        """
        Get result for submitted item (polls every 10 ms).

        Args:
            item_id: Item ID
            timeout: Timeout in seconds (None = wait indefinitely)

        Returns:
            Result or None if timeout
        """
        # Fix: the original tested `if timeout and ...`, so timeout=0 waited
        # forever; a monotonic deadline is also immune to wall-clock changes.
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            with self._results_lock:
                if item_id in self.results:
                    return self.results.pop(item_id)
            if deadline is not None and time.monotonic() > deadline:
                return None
            time.sleep(0.01)  # Small sleep to avoid busy waiting

    def _gpu_worker(self):
        """GPU worker thread: runs inference and feeds outputs to the CPU queue."""
        while self.running:
            try:
                item = self.gpu_queue.get(timeout=1.0)
            except Empty:
                continue
            try:
                if item is None:  # Sentinel
                    break
                item_id, images, metadata = item
                # Run GPU inference
                with torch.no_grad():
                    try:
                        output = self.gpu_worker_fn(images)
                        self.cpu_queue.put((item_id, output, images, metadata))
                    except Exception as e:
                        logger.error("GPU worker error for %s: %s", item_id, e)
                        with self._results_lock:
                            self.results[item_id] = {"error": str(e)}
            except Exception as e:
                logger.error("GPU worker thread error: %s", e)
            finally:
                # Fix: also acknowledge the sentinel so Queue.join() can't hang.
                self.gpu_queue.task_done()

    def _cpu_worker(self):
        """CPU worker thread: processes GPU outputs and stores results."""
        while self.running:
            try:
                item = self.cpu_queue.get(timeout=1.0)
            except Empty:
                continue
            try:
                if item is None:  # Sentinel
                    break
                item_id, gpu_output, images, metadata = item
                # Run CPU processing
                try:
                    result = self.cpu_worker_fn(gpu_output, images, metadata)
                    with self._results_lock:
                        self.results[item_id] = result
                except Exception as e:
                    logger.error("CPU worker error for %s: %s", item_id, e)
                    with self._results_lock:
                        self.results[item_id] = {"error": str(e)}
            except Exception as e:
                logger.error("CPU worker thread error: %s", e)
            finally:
                # Fix: also acknowledge the sentinel so Queue.join() can't hang.
                self.cpu_queue.task_done()
class AsyncBAValidator:
    """
    Async BA validator that uses pipeline parallelism.
    Overlaps DA3 inference (GPU) with BA validation (CPU).
    """

    def __init__(
        self,
        model,
        ba_validator,
        queue_size: int = 10,
    ):
        """
        Args:
            model: DA3 model for inference
            ba_validator: BAValidator instance
            queue_size: Queue size for pipeline
        """
        self.model = model
        self.ba_validator = ba_validator

        def run_inference(batch):
            # GPU stage: model forward pass.
            return self.model.inference(batch)

        def run_validation(net_output, batch, meta):
            # CPU stage: bundle-adjustment validation of the predicted poses.
            return self.ba_validator.validate(
                images=batch,
                poses_model=net_output.extrinsics,
                intrinsics=getattr(net_output, "intrinsics", None),
            )

        self.pipeline = PipelineProcessor(
            gpu_worker_fn=run_inference,
            cpu_worker_fn=run_validation,
            gpu_queue_size=queue_size,
            cpu_queue_size=queue_size,
        )
        self.pipeline.start()

    def validate_async(
        self,
        images: List[np.ndarray],
        sequence_id: Optional[str] = None,
    ) -> str:
        """
        Submit validation request asynchronously.

        Args:
            images: Input images
            sequence_id: Sequence identifier

        Returns:
            Item ID for retrieving result
        """
        item_id = sequence_id if sequence_id else f"seq_{id(images)}"
        return self.pipeline.submit(
            item_id=item_id,
            images=images,
            metadata={"sequence_id": sequence_id},
        )

    def get_result(self, item_id: str, timeout: Optional[float] = None) -> Optional[Dict]:
        """Get validation result."""
        return self.pipeline.get_result(item_id, timeout=timeout)

    def validate_sync(
        self,
        images: List[np.ndarray],
        sequence_id: Optional[str] = None,
        timeout: float = 300.0,
    ) -> Dict:
        """
        Validate synchronously (submits and waits for result).

        Args:
            images: Input images
            sequence_id: Sequence identifier
            timeout: Timeout in seconds

        Returns:
            Validation result
        """
        ticket = self.validate_async(images, sequence_id)
        outcome = self.get_result(ticket, timeout=timeout)
        if outcome is None:
            raise TimeoutError(f"Validation timeout for {sequence_id}")
        if "error" in outcome:
            raise RuntimeError(f"Validation error: {outcome['error']}")
        return outcome

    def shutdown(self):
        """Shutdown pipeline processor."""
        self.pipeline.stop()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()