# 3d_model/ylff/utils/inference_optimizer.py
# Author: Azan
# Clean deployment build (squashed), commit 7a87926
"""
Inference optimization utilities: batching, caching, and async processing.
"""
import hashlib
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
logger = logging.getLogger(__name__)
class BatchedInference:
    """
    Batch multiple inference requests together for better GPU utilization.

    Instead of processing sequences one-by-one, collects multiple sequences
    and processes them in a single batch for 2-5x speedup.
    """

    def __init__(self, model, batch_size: int = 4):
        """
        Args:
            model: DA3 model for inference (must expose ``inference(images)``)
            batch_size: Number of sequences to batch together before processing
        """
        self.model = model
        self.batch_size = batch_size
        # Pending (images, sequence_id) pairs, kept in arrival order.
        self.queue: List[Tuple[List[np.ndarray], str]] = []

    def add(self, images: List[np.ndarray], sequence_id: str) -> Optional[List[Dict]]:
        """
        Add a sequence to the batch queue.

        Args:
            images: List of images for the sequence
            sequence_id: Identifier for the sequence

        Returns:
            List of result dicts (one per queued sequence, in arrival order)
            if the batch became full and was processed, None otherwise.
            (Annotation fixed: this returns the full batch list, not a
            single dict.)
        """
        self.queue.append((images, sequence_id))
        if len(self.queue) >= self.batch_size:
            return self.process_batch()
        return None

    def process_batch(self) -> List[Dict]:
        """
        Process all queued sequences in a single batch.

        Returns:
            List of result dicts, one per sequence, in the order the
            sequences were queued. On model failure every dict carries an
            "error" key with None for all tensor fields.
        """
        if not self.queue:
            return []

        # Flatten all sequences into one image list, remembering each
        # sequence's [start, end) slice so outputs can be split back out.
        all_images = []
        sequence_boundaries = []
        idx = 0
        for images, seq_id in self.queue:
            all_images.extend(images)
            sequence_boundaries.append((idx, idx + len(images), seq_id))
            idx += len(images)

        # Lazy %-style args: the message is only formatted if actually emitted.
        logger.debug(
            "Processing batch of %d sequences (%d total images)",
            len(self.queue),
            len(all_images),
        )
        with torch.no_grad():
            try:
                outputs = self.model.inference(all_images)
            except Exception as e:
                logger.error("Batch inference failed: %s", e)
                # Surface the failure per-sequence instead of raising so
                # callers can handle partial batches uniformly. "depth" is
                # included so error dicts share the success-path schema.
                results = [
                    {
                        "sequence_id": seq_id,
                        "error": str(e),
                        "extrinsics": None,
                        "intrinsics": None,
                        "depth": None,
                    }
                    for _, _, seq_id in sequence_boundaries
                ]
                self.queue = []
                return results

        # Split the batched outputs back into per-sequence slices. The
        # hasattr guards tolerate model outputs that lack optional fields.
        results = []
        for start, end, seq_id in sequence_boundaries:
            results.append(
                {
                    "sequence_id": seq_id,
                    "extrinsics": (
                        outputs.extrinsics[start:end] if hasattr(outputs, "extrinsics") else None
                    ),
                    "intrinsics": (
                        outputs.intrinsics[start:end] if hasattr(outputs, "intrinsics") else None
                    ),
                    "depth": outputs.depth[start:end] if hasattr(outputs, "depth") else None,
                }
            )
        self.queue = []
        return results

    def flush(self) -> List[Dict]:
        """Process any remaining queued sequences; returns [] when queue is empty."""
        if self.queue:
            return self.process_batch()
        return []
class CachedInference:
    """
    Cache inference results to avoid recomputing for identical inputs.

    Uses content-based hashing of subsampled pixel data to detect duplicate
    sequences. Entries are evicted FIFO (dict insertion order) once
    max_cache_size is reached; the cache is optionally persisted to disk.
    """

    # Persist the cache every this many insertions.
    _SAVE_INTERVAL = 100

    def __init__(self, model, cache_dir: Optional[Path] = None, max_cache_size: int = 1000):
        """
        Args:
            model: DA3 model for inference
            cache_dir: Directory to persist cache (None = in-memory only)
            max_cache_size: Maximum number of cached entries
        """
        self.model = model
        self.cache_dir = cache_dir
        self.max_cache_size = max_cache_size
        self.cache: Dict[str, Dict] = {}
        # Counts insertions so periodic saves fire reliably even once the
        # cache is at capacity (where len(self.cache) stops changing and a
        # `len % 100` test would either always or never trigger).
        self._insert_count = 0
        if cache_dir:
            cache_dir.mkdir(parents=True, exist_ok=True)
            self._load_cache()

    def _hash_images(self, images: List[np.ndarray]) -> str:
        """Create an md5 content hash from subsampled image pixels (not crypto)."""
        if not images:
            # np.concatenate rejects an empty sequence; give the empty input
            # a stable, distinct key instead of crashing.
            return hashlib.md5(b"").hexdigest()
        combined = []
        for img in images:
            # Sample every 100th pixel to create a cheap signature.
            sampled = img[::100, ::100].flatten()[:1000]
            combined.append(sampled)
        combined_array = np.concatenate(combined)
        return hashlib.md5(combined_array.tobytes()).hexdigest()

    def _load_cache(self):
        """Load cache from disk if available (best-effort; failures are logged)."""
        if not self.cache_dir:
            return
        cache_file = self.cache_dir / "inference_cache.pkl"
        if cache_file.exists():
            try:
                import pickle

                # NOTE(review): pickle.load is only safe if cache_dir is
                # trusted (never load a cache file from untrusted sources).
                with open(cache_file, "rb") as f:
                    self.cache = pickle.load(f)
                logger.info("Loaded %d cached inference results", len(self.cache))
            except Exception as e:
                logger.warning("Failed to load cache: %s", e)

    def _save_cache(self):
        """Save cache to disk (best-effort; failures are logged)."""
        if not self.cache_dir:
            return
        cache_file = self.cache_dir / "inference_cache.pkl"
        try:
            import pickle

            with open(cache_file, "wb") as f:
                pickle.dump(self.cache, f)
        except Exception as e:
            logger.warning("Failed to save cache: %s", e)

    def inference(self, images: List[np.ndarray], sequence_id: Optional[str] = None) -> Dict:
        """
        Run inference with caching.

        Args:
            images: List of input images
            sequence_id: Optional sequence identifier for logging

        Returns:
            Inference result dict. NOTE: cache hits return the shared cached
            dict; callers should not mutate it.
        """
        cache_key = self._hash_images(images)
        # Check cache
        if cache_key in self.cache:
            logger.debug("Cache hit for sequence %s", sequence_id)
            return self.cache[cache_key]
        # Run inference
        logger.debug("Cache miss for sequence %s, running inference...", sequence_id)
        with torch.no_grad():
            output = self.model.inference(images)
        result = {
            "extrinsics": output.extrinsics if hasattr(output, "extrinsics") else None,
            "intrinsics": output.intrinsics if hasattr(output, "intrinsics") else None,
            "depth": output.depth if hasattr(output, "depth") else None,
        }
        # Evict the oldest entry (FIFO via dict insertion order) at capacity.
        if len(self.cache) >= self.max_cache_size:
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
        self.cache[cache_key] = result
        self._insert_count += 1
        # Periodically persist so results survive interruption.
        if self._insert_count % self._SAVE_INTERVAL == 0:
            self._save_cache()
        return result

    def clear_cache(self):
        """Clear the in-memory cache and remove any on-disk cache file."""
        self.cache = {}
        self._insert_count = 0
        if self.cache_dir:
            cache_file = self.cache_dir / "inference_cache.pkl"
            if cache_file.exists():
                cache_file.unlink()
        logger.info("Cache cleared")
class OptimizedInference:
    """
    Combined batched and cached inference for maximum efficiency.

    Requests are looked up in the content-addressed cache first; on a miss
    the request is routed through the batcher and the (successful) result
    is written back to the cache so later identical inputs hit.
    """

    def __init__(
        self,
        model,
        batch_size: int = 4,
        use_cache: bool = True,
        cache_dir: Optional[Path] = None,
        max_cache_size: int = 1000,
    ):
        """
        Args:
            model: DA3 model for inference
            batch_size: Batch size for batching
            use_cache: Enable caching
            cache_dir: Cache directory (None = in-memory only)
            max_cache_size: Maximum cache size
        """
        self.model = model
        self.batcher = BatchedInference(model, batch_size=batch_size)
        self.cache = (
            CachedInference(model, cache_dir=cache_dir, max_cache_size=max_cache_size)
            if use_cache
            else None
        )

    def inference(
        self,
        images: List[np.ndarray],
        sequence_id: Optional[str] = None,
        force_batch: bool = False,
    ) -> Dict:
        """
        Run optimized inference (cached + batched).

        Args:
            images: List of input images
            sequence_id: Optional sequence identifier
            force_batch: Kept for backward compatibility; every call is
                processed immediately (queued, then flushed), so this flag
                has no additional effect.

        Returns:
            Inference result dict for THIS sequence.
        """
        # Check the cache first (content hash of the input images).
        cache_key = None
        if self.cache:
            cache_key = self.cache._hash_images(images)
            cached = self.cache.cache.get(cache_key)
            if cached is not None:
                return cached

        # Queue the request; if that fills the batch it is processed,
        # otherwise flush so the caller gets an immediate result.
        results = self.batcher.add(images, sequence_id or "unknown")
        if not results:
            results = self.batcher.flush()

        if results:
            # BUGFIX: batch results are ordered by arrival and our sequence
            # was appended last — the original returned results[0], which
            # could belong to a different, earlier-queued sequence.
            result = results[-1]
            # BUGFIX: cache the batch-path result (the original only cached
            # in the unreachable fallback, so hits never occurred). Error
            # results are not cached — a transient failure must not become
            # permanent.
            if self.cache and cache_key is not None and not result.get("error"):
                self.cache.cache[cache_key] = result
            return result

        # Defensive fallback: should be unreachable because we just queued
        # an item, but run the model directly rather than return nothing.
        logger.warning("Falling back to direct inference")
        with torch.no_grad():
            output = self.model.inference(images)
        result = {
            "extrinsics": output.extrinsics if hasattr(output, "extrinsics") else None,
            "intrinsics": output.intrinsics if hasattr(output, "intrinsics") else None,
            "depth": output.depth if hasattr(output, "depth") else None,
        }
        if self.cache and cache_key is not None:
            self.cache.cache[cache_key] = result
        return result

    def flush(self) -> List[Dict]:
        """Process any queued batches and return their results."""
        return self.batcher.flush()