File size: 10,402 Bytes
7a87926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
"""
Inference optimization utilities: batching, caching, and async processing.
"""

import hashlib
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch

logger = logging.getLogger(__name__)


class BatchedInference:
    """
    Batch multiple inference requests together for better GPU utilization.

    Instead of processing sequences one-by-one, collects multiple sequences
    and processes them in a single batch for 2-5x speedup.
    """

    def __init__(self, model, batch_size: int = 4):
        """
        Args:
            model: DA3 model for inference
            batch_size: Number of sequences to batch together
        """
        self.model = model
        self.batch_size = batch_size
        # Pending (images, sequence_id) pairs awaiting a full batch.
        self.queue: List[Tuple[List[np.ndarray], str]] = []

    def add(self, images: List[np.ndarray], sequence_id: str) -> Optional[List[Dict]]:
        """
        Add a sequence to the batch queue.

        Args:
            images: List of images for the sequence
            sequence_id: Identifier for the sequence

        Returns:
            List of per-sequence result dicts (one per queued sequence, in
            queue order) if adding this sequence filled the batch and it was
            processed; None if the sequence was only queued.
        """
        self.queue.append((images, sequence_id))

        if len(self.queue) >= self.batch_size:
            return self.process_batch()
        return None

    def process_batch(self) -> List[Dict]:
        """
        Process all queued sequences in a single batch.

        Returns:
            List of result dicts, one per sequence in queue order. On
            inference failure, each dict carries an "error" key with None
            outputs instead of raising, so one bad batch does not abort the
            caller's pipeline.
        """
        if not self.queue:
            return []

        # Flatten all sequences into one image list, remembering each
        # sequence's [start, end) slice so outputs can be split back up.
        all_images = []
        sequence_boundaries = []
        idx = 0

        for images, seq_id in self.queue:
            all_images.extend(images)
            sequence_boundaries.append((idx, idx + len(images), seq_id))
            idx += len(images)

        # Run batched inference. Lazy %-args avoid building the message when
        # debug logging is disabled.
        logger.debug(
            "Processing batch of %d sequences (%d total images)",
            len(self.queue),
            len(all_images),
        )
        with torch.no_grad():
            try:
                outputs = self.model.inference(all_images)
            except Exception as e:
                logger.error("Batch inference failed: %s", e)
                # Report the failure per-sequence rather than raising.
                results = [
                    {
                        "sequence_id": seq_id,
                        "error": str(e),
                        "extrinsics": None,
                        "intrinsics": None,
                    }
                    for _, _, seq_id in sequence_boundaries
                ]
                self.queue = []
                return results

        # Split results back to individual sequences. Output fields are
        # optional on the model side, hence the hasattr guards.
        results = []
        for start, end, seq_id in sequence_boundaries:
            results.append(
                {
                    "sequence_id": seq_id,
                    "extrinsics": (
                        outputs.extrinsics[start:end] if hasattr(outputs, "extrinsics") else None
                    ),
                    "intrinsics": (
                        outputs.intrinsics[start:end] if hasattr(outputs, "intrinsics") else None
                    ),
                    "depth": outputs.depth[start:end] if hasattr(outputs, "depth") else None,
                }
            )

        self.queue = []
        return results

    def flush(self) -> List[Dict]:
        """Process any remaining queued sequences; returns [] if none."""
        if self.queue:
            return self.process_batch()
        return []


class CachedInference:
    """
    Cache inference results to avoid recomputing for identical inputs.

    Uses content-based hashing to detect duplicate sequences.
    """

    def __init__(self, model, cache_dir: Optional[Path] = None, max_cache_size: int = 1000):
        """
        Args:
            model: DA3 model for inference
            cache_dir: Directory to persist cache (None = in-memory only)
            max_cache_size: Maximum number of cached entries
        """
        self.model = model
        self.cache_dir = cache_dir
        self.max_cache_size = max_cache_size
        self.cache: Dict[str, Dict] = {}
        # Monotonic count of insertions, used to trigger periodic saves.
        # len(self.cache) is unsuitable for this: once the cache is full,
        # FIFO eviction keeps the length constant, so a length-based check
        # would either fire on every insert or never again.
        self._insert_count = 0

        if cache_dir:
            cache_dir.mkdir(parents=True, exist_ok=True)
            self._load_cache()

    def _hash_images(self, images: List[np.ndarray]) -> str:
        """Create a cheap content hash for a list of images.

        Samples every 100th pixel (capped at 1000 values per image) so the
        key is fast to compute. md5 is used here as a content fingerprint,
        not for security.
        """
        combined = [img[::100, ::100].flatten()[:1000] for img in images]
        combined_array = np.concatenate(combined)
        return hashlib.md5(combined_array.tobytes()).hexdigest()

    def _load_cache(self):
        """Load cache from disk if available; failures are non-fatal."""
        if not self.cache_dir:
            return

        cache_file = self.cache_dir / "inference_cache.pkl"
        if cache_file.exists():
            try:
                import pickle

                # NOTE: unpickling is only safe because this file is written
                # by _save_cache below; never point cache_dir at an untrusted
                # location.
                with open(cache_file, "rb") as f:
                    self.cache = pickle.load(f)
                logger.info("Loaded %d cached inference results", len(self.cache))
            except Exception as e:
                logger.warning("Failed to load cache: %s", e)

    def _save_cache(self):
        """Save cache to disk; failures are logged, not raised."""
        if not self.cache_dir:
            return

        cache_file = self.cache_dir / "inference_cache.pkl"
        try:
            import pickle

            with open(cache_file, "wb") as f:
                pickle.dump(self.cache, f)
        except Exception as e:
            logger.warning("Failed to save cache: %s", e)

    def inference(self, images: List[np.ndarray], sequence_id: Optional[str] = None) -> Dict:
        """
        Run inference with caching.

        Args:
            images: List of input images
            sequence_id: Optional sequence identifier for logging

        Returns:
            Inference result dict with "extrinsics"/"intrinsics"/"depth"
            keys; each value may be None if the model output lacks it.
        """
        cache_key = self._hash_images(images)

        # Check cache
        if cache_key in self.cache:
            logger.debug("Cache hit for sequence %s", sequence_id)
            return self.cache[cache_key]

        # Run inference
        logger.debug("Cache miss for sequence %s, running inference...", sequence_id)
        with torch.no_grad():
            output = self.model.inference(images)

        result = {
            "extrinsics": output.extrinsics if hasattr(output, "extrinsics") else None,
            "intrinsics": output.intrinsics if hasattr(output, "intrinsics") else None,
            "depth": output.depth if hasattr(output, "depth") else None,
        }

        # Manage cache size: evict the oldest entry (FIFO; dicts preserve
        # insertion order, so the first key is the oldest).
        if len(self.cache) >= self.max_cache_size:
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]

        self.cache[cache_key] = result
        self._insert_count += 1

        # Persist every 100 insertions.
        if self._insert_count % 100 == 0:
            self._save_cache()

        return result

    def clear_cache(self):
        """Clear the in-memory cache and remove the on-disk cache file."""
        self.cache = {}
        if self.cache_dir:
            cache_file = self.cache_dir / "inference_cache.pkl"
            if cache_file.exists():
                cache_file.unlink()
        logger.info("Cache cleared")


class OptimizedInference:
    """
    Combined batched and cached inference for maximum efficiency.
    """

    def __init__(
        self,
        model,
        batch_size: int = 4,
        use_cache: bool = True,
        cache_dir: Optional[Path] = None,
        max_cache_size: int = 1000,
    ):
        """
        Args:
            model: DA3 model for inference
            batch_size: Batch size for batching
            use_cache: Enable caching
            cache_dir: Cache directory
            max_cache_size: Maximum cache size
        """
        self.model = model
        self.batcher = BatchedInference(model, batch_size=batch_size)
        self.cache = (
            CachedInference(model, cache_dir=cache_dir, max_cache_size=max_cache_size)
            if use_cache
            else None
        )

    def inference(
        self,
        images: List[np.ndarray],
        sequence_id: Optional[str] = None,
        force_batch: bool = False,
    ) -> Dict:
        """
        Run optimized inference (cached + batched).

        Args:
            images: List of input images
            sequence_id: Optional sequence identifier
            force_batch: Force immediate batch processing (kept for
                interface compatibility; every call currently flushes so
                the caller always gets an immediate result)

        Returns:
            Inference result dict for *this* sequence
        """
        # Check cache first
        cache_key = self.cache._hash_images(images) if self.cache else None
        if self.cache and cache_key in self.cache.cache:
            return self.cache.cache[cache_key]

        # Queue the sequence; if that filled the batch, the whole queue is
        # processed and `results` covers every queued sequence.
        results = self.batcher.add(images, sequence_id or "unknown")
        if results is None:
            # Batch not full yet — flush so the caller gets an immediate
            # result rather than waiting for future calls.
            results = self.batcher.flush()

        if results:
            # The sequence just added is the LAST one in queue order, so its
            # result is results[-1]. results[0] would belong to the first
            # (possibly earlier) queued sequence, returning the wrong output
            # whenever more than one sequence was in the queue.
            result = results[-1]
        else:
            # Fallback to direct inference
            logger.warning("Falling back to direct inference")
            with torch.no_grad():
                output = self.model.inference(images)

            result = {
                "extrinsics": output.extrinsics if hasattr(output, "extrinsics") else None,
                "intrinsics": output.intrinsics if hasattr(output, "intrinsics") else None,
                "depth": output.depth if hasattr(output, "depth") else None,
            }

        # Cache whichever path produced the result (previously only the
        # direct-inference fallback was cached, so batched results were
        # always recomputed).
        if self.cache:
            self.cache.cache[cache_key] = result

        return result

    def flush(self) -> List[Dict]:
        """Process any queued batches and return their results."""
        return self.batcher.flush()