|
|
""" |
|
|
Inference optimization utilities: batching, caching, and async processing. |
|
|
""" |
|
|
|
|
|
import hashlib |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class BatchedInference:
    """
    Batch multiple inference requests together for better GPU utilization.

    Instead of processing sequences one-by-one, collects multiple sequences
    and processes them in a single batch for 2-5x speedup.
    """

    def __init__(self, model, batch_size: int = 4):
        """
        Args:
            model: DA3 model for inference
            batch_size: Number of sequences to batch together
        """
        self.model = model
        self.batch_size = batch_size
        # Pending work: one (images, sequence_id) entry per queued sequence.
        self.queue: List[Tuple[List[np.ndarray], str]] = []

    def add(self, images: List[np.ndarray], sequence_id: str) -> Optional[List[Dict]]:
        """
        Add a sequence to the batch queue.

        Args:
            images: List of images for the sequence
            sequence_id: Identifier for the sequence

        Returns:
            List of per-sequence result dicts (one per queued sequence) if the
            batch became full and was processed, None otherwise.
        """
        self.queue.append((images, sequence_id))

        if len(self.queue) >= self.batch_size:
            return self.process_batch()
        return None

    def process_batch(self) -> List[Dict]:
        """
        Process all queued sequences in a single batch.

        Returns:
            List of result dicts, one per sequence, in queue order.  On model
            failure, each dict carries an ``"error"`` key and None outputs.
            The queue is emptied in both cases.
        """
        if not self.queue:
            return []

        # Flatten all sequences into one image list, remembering the
        # [start, end) slice each sequence occupies in the batch.
        all_images = []
        sequence_boundaries = []
        idx = 0

        for images, seq_id in self.queue:
            all_images.extend(images)
            sequence_boundaries.append((idx, idx + len(images), seq_id))
            idx += len(images)

        logger.debug(
            f"Processing batch of {len(self.queue)} sequences ({len(all_images)} total images)"
        )
        with torch.no_grad():
            try:
                outputs = self.model.inference(all_images)
            except Exception as e:
                logger.error(f"Batch inference failed: {e}")

                # Best-effort: report the failure per sequence rather than
                # raising, so one bad batch does not kill the caller's loop.
                results = [
                    {
                        "sequence_id": seq_id,
                        "error": str(e),
                        "extrinsics": None,
                        "intrinsics": None,
                    }
                    for _, _, seq_id in sequence_boundaries
                ]
                self.queue = []
                return results

        # Slice the batched outputs back into per-sequence results.  Missing
        # output attributes (model-dependent) are reported as None.
        results = []
        for start, end, seq_id in sequence_boundaries:
            result = {
                "sequence_id": seq_id,
                "extrinsics": (
                    outputs.extrinsics[start:end] if hasattr(outputs, "extrinsics") else None
                ),
                "intrinsics": (
                    outputs.intrinsics[start:end] if hasattr(outputs, "intrinsics") else None
                ),
                "depth": outputs.depth[start:end] if hasattr(outputs, "depth") else None,
            }
            results.append(result)

        self.queue = []
        return results

    def flush(self) -> List[Dict]:
        """Process any remaining queued sequences."""
        if self.queue:
            return self.process_batch()
        return []
|
|
|
|
|
|
|
|
class CachedInference:
    """
    Cache inference results to avoid recomputing for identical inputs.

    Uses content-based hashing to detect duplicate sequences.
    """

    # Persist the cache to disk once per this many insertions.
    SAVE_EVERY = 100

    def __init__(self, model, cache_dir: Optional[Path] = None, max_cache_size: int = 1000):
        """
        Args:
            model: DA3 model for inference
            cache_dir: Directory to persist cache (None = in-memory only)
            max_cache_size: Maximum number of cached entries
        """
        self.model = model
        self.cache_dir = cache_dir
        self.max_cache_size = max_cache_size
        self.cache: Dict[str, Dict] = {}
        # Number of entries inserted since construction.  len(self.cache)
        # cannot drive periodic saves: it stalls at max_cache_size once
        # eviction starts, so saves would fire every insert or never.
        self._insert_count = 0

        if cache_dir:
            cache_dir.mkdir(parents=True, exist_ok=True)
            self._load_cache()

    def _hash_images(self, images: List[np.ndarray]) -> str:
        """Create hash from image content.

        Subsamples each image (every 100th pixel per axis, capped at 1000
        values) so hashing stays cheap for large frames.  md5 is fine here:
        this is a cache key, not a security boundary.
        """
        combined = []
        for img in images:
            sampled = img[::100, ::100].flatten()[:1000]
            combined.append(sampled)

        combined_array = np.concatenate(combined)
        return hashlib.md5(combined_array.tobytes()).hexdigest()

    def _load_cache(self):
        """Load cache from disk if available.

        NOTE: unpickling is only safe because the cache file is produced by
        this process in a local directory — never point cache_dir at
        untrusted data.
        """
        if not self.cache_dir:
            return

        cache_file = self.cache_dir / "inference_cache.pkl"
        if cache_file.exists():
            try:
                import pickle

                with open(cache_file, "rb") as f:
                    self.cache = pickle.load(f)
                logger.info(f"Loaded {len(self.cache)} cached inference results")
            except Exception as e:
                # A corrupt cache is not fatal; start fresh.
                logger.warning(f"Failed to load cache: {e}")

    def _save_cache(self):
        """Save cache to disk (no-op when cache_dir is unset)."""
        if not self.cache_dir:
            return

        cache_file = self.cache_dir / "inference_cache.pkl"
        try:
            import pickle

            with open(cache_file, "wb") as f:
                pickle.dump(self.cache, f)
        except Exception as e:
            logger.warning(f"Failed to save cache: {e}")

    def inference(self, images: List[np.ndarray], sequence_id: Optional[str] = None) -> Dict:
        """
        Run inference with caching.

        Args:
            images: List of input images
            sequence_id: Optional sequence identifier for logging

        Returns:
            Inference result dict with "extrinsics"/"intrinsics"/"depth" keys
            (None for outputs the model does not expose)
        """
        cache_key = self._hash_images(images)

        if cache_key in self.cache:
            logger.debug(f"Cache hit for sequence {sequence_id}")
            return self.cache[cache_key]

        logger.debug(f"Cache miss for sequence {sequence_id}, running inference...")
        with torch.no_grad():
            output = self.model.inference(images)

        result = {
            "extrinsics": output.extrinsics if hasattr(output, "extrinsics") else None,
            "intrinsics": output.intrinsics if hasattr(output, "intrinsics") else None,
            "depth": output.depth if hasattr(output, "depth") else None,
        }

        # FIFO eviction: dicts preserve insertion order, so the first key is
        # the oldest entry.
        if len(self.cache) >= self.max_cache_size:
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]

        self.cache[cache_key] = result
        self._insert_count += 1

        # Periodic persistence, driven by total insertions so it keeps
        # working after the cache reaches max_cache_size.
        if self._insert_count % self.SAVE_EVERY == 0:
            self._save_cache()

        return result

    def clear_cache(self):
        """Clear the in-memory cache and delete any on-disk cache file."""
        self.cache = {}
        if self.cache_dir:
            cache_file = self.cache_dir / "inference_cache.pkl"
            if cache_file.exists():
                cache_file.unlink()
        logger.info("Cache cleared")
|
|
|
|
|
|
|
|
class OptimizedInference:
    """
    Combined batched and cached inference for maximum efficiency.

    Wraps a BatchedInference queue with an optional CachedInference
    content-hash cache.
    """

    def __init__(
        self,
        model,
        batch_size: int = 4,
        use_cache: bool = True,
        cache_dir: Optional[Path] = None,
        max_cache_size: int = 1000,
    ):
        """
        Args:
            model: DA3 model for inference
            batch_size: Batch size for batching
            use_cache: Enable caching
            cache_dir: Cache directory
            max_cache_size: Maximum cache size
        """
        self.model = model
        self.batcher = BatchedInference(model, batch_size=batch_size)
        # None when caching is disabled; every cache access must check this.
        self.cache = (
            CachedInference(model, cache_dir=cache_dir, max_cache_size=max_cache_size)
            if use_cache
            else None
        )

    @staticmethod
    def _pick_result(results: List[Dict], sequence_id: str) -> Optional[Dict]:
        """Return the entry in `results` matching `sequence_id`, else None.

        A processed batch holds one result per queued sequence, so taking
        `results[0]` blindly could hand back an earlier sequence's output.
        """
        for result in results:
            if result.get("sequence_id") == sequence_id:
                return result
        return None

    def inference(
        self,
        images: List[np.ndarray],
        sequence_id: Optional[str] = None,
        force_batch: bool = False,
    ) -> Dict:
        """
        Run optimized inference (cached + batched).

        Args:
            images: List of input images
            sequence_id: Optional sequence identifier
            force_batch: Force immediate batch processing (kept for backward
                compatibility; the queue is flushed immediately in both modes,
                matching the prior control flow)

        Returns:
            Inference result dict for this sequence
        """
        seq_id = sequence_id or "unknown"

        # 1) Cache lookup — compute the key once and reuse it for the store.
        cache_key = None
        if self.cache:
            cache_key = self.cache._hash_images(images)
            if cache_key in self.cache.cache:
                return self.cache.cache[cache_key]

        # 2) Batch path: queue the sequence; `add` processes a full batch,
        # otherwise flush whatever is queued (including this sequence).
        results = self.batcher.add(images, seq_id)
        if results is None:
            results = self.batcher.flush()

        if results:
            # Select the result for *this* sequence — the batch may also
            # contain results for sequences queued earlier.
            result = self._pick_result(results, seq_id) or results[0]
            if self.cache and cache_key is not None:
                self.cache.cache[cache_key] = result
            return result

        # 3) Defensive fallback: run the model directly.
        logger.warning("Falling back to direct inference")
        with torch.no_grad():
            output = self.model.inference(images)

        result = {
            "extrinsics": output.extrinsics if hasattr(output, "extrinsics") else None,
            "intrinsics": output.intrinsics if hasattr(output, "intrinsics") else None,
            "depth": output.depth if hasattr(output, "depth") else None,
        }

        if self.cache and cache_key is not None:
            self.cache.cache[cache_key] = result

        return result

    def flush(self) -> List[Dict]:
        """Process any queued batches."""
        return self.batcher.flush()
|
|
|