""" Inference optimization utilities: batching, caching, and async processing. """ import hashlib import logging from pathlib import Path from typing import Dict, List, Optional, Tuple import numpy as np import torch logger = logging.getLogger(__name__) class BatchedInference: """ Batch multiple inference requests together for better GPU utilization. Instead of processing sequences one-by-one, collects multiple sequences and processes them in a single batch for 2-5x speedup. """ def __init__(self, model, batch_size: int = 4): """ Args: model: DA3 model for inference batch_size: Number of sequences to batch together """ self.model = model self.batch_size = batch_size self.queue: List[Tuple[List[np.ndarray], str]] = [] def add(self, images: List[np.ndarray], sequence_id: str) -> Optional[Dict]: """ Add a sequence to the batch queue. Args: images: List of images for the sequence sequence_id: Identifier for the sequence Returns: Results dict if batch is full and processed, None otherwise """ self.queue.append((images, sequence_id)) if len(self.queue) >= self.batch_size: return self.process_batch() return None def process_batch(self) -> List[Dict]: """ Process all queued sequences in a single batch. Returns: List of result dicts, one per sequence """ if not self.queue: return [] # Combine all images from all sequences all_images = [] sequence_boundaries = [] idx = 0 for images, seq_id in self.queue: all_images.extend(images) sequence_boundaries.append((idx, idx + len(images), seq_id)) idx += len(images) # Run batched inference logger.debug( f"Processing batch of {len(self.queue)} sequences ({len(all_images)} total images)" ) with torch.no_grad(): try: outputs = self.model.inference(all_images) except Exception as e: logger.error(f"Batch inference failed: {e}") # Return None for all sequences in batch results = [] for _, _, seq_id in sequence_boundaries: results.append( { "sequence_id": seq_id, "error": str(e), "extrinsics": None, "intrinsics": None, } ) self.queue = [] return results # Split results back to individual sequences results = [] for start, end, seq_id in sequence_boundaries: result = { "sequence_id": seq_id, "extrinsics": ( outputs.extrinsics[start:end] if hasattr(outputs, "extrinsics") else None ), "intrinsics": ( outputs.intrinsics[start:end] if hasattr(outputs, "intrinsics") else None ), "depth": outputs.depth[start:end] if hasattr(outputs, "depth") else None, } results.append(result) self.queue = [] return results def flush(self) -> List[Dict]: """Process any remaining queued sequences.""" if self.queue: return self.process_batch() return [] class CachedInference: """ Cache inference results to avoid recomputing for identical inputs. Uses content-based hashing to detect duplicate sequences. """ def __init__(self, model, cache_dir: Optional[Path] = None, max_cache_size: int = 1000): """ Args: model: DA3 model for inference cache_dir: Directory to persist cache (None = in-memory only) max_cache_size: Maximum number of cached entries """ self.model = model self.cache_dir = cache_dir self.max_cache_size = max_cache_size self.cache: Dict[str, Dict] = {} if cache_dir: cache_dir.mkdir(parents=True, exist_ok=True) self._load_cache() def _hash_images(self, images: List[np.ndarray]) -> str: """Create hash from image content.""" # Sample pixels from each image for faster hashing combined = [] for img in images: # Sample every 100th pixel to create signature sampled = img[::100, ::100].flatten()[:1000] combined.append(sampled) combined_array = np.concatenate(combined) return hashlib.md5(combined_array.tobytes()).hexdigest() def _load_cache(self): """Load cache from disk if available.""" if not self.cache_dir: return cache_file = self.cache_dir / "inference_cache.pkl" if cache_file.exists(): try: import pickle with open(cache_file, "rb") as f: self.cache = pickle.load(f) logger.info(f"Loaded {len(self.cache)} cached inference results") except Exception as e: logger.warning(f"Failed to load cache: {e}") def _save_cache(self): """Save cache to disk.""" if not self.cache_dir: return cache_file = self.cache_dir / "inference_cache.pkl" try: import pickle with open(cache_file, "wb") as f: pickle.dump(self.cache, f) except Exception as e: logger.warning(f"Failed to save cache: {e}") def inference(self, images: List[np.ndarray], sequence_id: Optional[str] = None) -> Dict: """ Run inference with caching. Args: images: List of input images sequence_id: Optional sequence identifier for logging Returns: Inference result dict """ cache_key = self._hash_images(images) # Check cache if cache_key in self.cache: logger.debug(f"Cache hit for sequence {sequence_id}") return self.cache[cache_key] # Run inference logger.debug(f"Cache miss for sequence {sequence_id}, running inference...") with torch.no_grad(): output = self.model.inference(images) # Store result result = { "extrinsics": output.extrinsics if hasattr(output, "extrinsics") else None, "intrinsics": output.intrinsics if hasattr(output, "intrinsics") else None, "depth": output.depth if hasattr(output, "depth") else None, } # Manage cache size if len(self.cache) >= self.max_cache_size: # Remove oldest entry (simple FIFO) oldest_key = next(iter(self.cache)) del self.cache[oldest_key] self.cache[cache_key] = result # Periodically save cache if len(self.cache) % 100 == 0: self._save_cache() return result def clear_cache(self): """Clear the cache.""" self.cache = {} if self.cache_dir: cache_file = self.cache_dir / "inference_cache.pkl" if cache_file.exists(): cache_file.unlink() logger.info("Cache cleared") class OptimizedInference: """ Combined batched and cached inference for maximum efficiency. """ def __init__( self, model, batch_size: int = 4, use_cache: bool = True, cache_dir: Optional[Path] = None, max_cache_size: int = 1000, ): """ Args: model: DA3 model for inference batch_size: Batch size for batching use_cache: Enable caching cache_dir: Cache directory max_cache_size: Maximum cache size """ self.model = model self.batcher = BatchedInference(model, batch_size=batch_size) self.cache = ( CachedInference(model, cache_dir=cache_dir, max_cache_size=max_cache_size) if use_cache else None ) def inference( self, images: List[np.ndarray], sequence_id: Optional[str] = None, force_batch: bool = False, ) -> Dict: """ Run optimized inference (cached + batched). Args: images: List of input images sequence_id: Optional sequence identifier force_batch: Force immediate batch processing Returns: Inference result dict """ # Check cache first if self.cache: cache_key = self.cache._hash_images(images) if cache_key in self.cache.cache: return self.cache.cache[cache_key] # Add to batch queue if force_batch: # Process immediately results = self.batcher.add(images, sequence_id or "unknown") if results: return results[0] # If batch not full, flush it results = self.batcher.flush() if results: return results[0] else: result = self.batcher.add(images, sequence_id or "unknown") if result: return result[0] # If we get here, item is queued but not processed yet # For immediate results, flush the batch results = self.batcher.flush() if results: return results[0] # Fallback to direct inference logger.warning("Falling back to direct inference") with torch.no_grad(): output = self.model.inference(images) result = { "extrinsics": output.extrinsics if hasattr(output, "extrinsics") else None, "intrinsics": output.intrinsics if hasattr(output, "intrinsics") else None, "depth": output.depth if hasattr(output, "depth") else None, } # Cache result if self.cache: cache_key = self.cache._hash_images(images) self.cache.cache[cache_key] = result return result def flush(self) -> List[Dict]: """Process any queued batches.""" return self.batcher.flush()