raheebhassan committed on
Commit 4fec4e4 · 1 Parent(s): fa7e75c

Initial Commit

This view is limited to 50 files because it contains too many changes. See raw diff.
.gitattributes CHANGED
@@ -1,35 +1,8 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ checkpoints/tic_lambda_0.0932.pth.tar filter=lfs diff=lfs merge=lfs -text
+ checkpoints/tic_lambda_0.0035.pth.tar filter=lfs diff=lfs merge=lfs -text
+ checkpoints/tic_lambda_0.013.pth.tar filter=lfs diff=lfs merge=lfs -text
+ checkpoints/tic_lambda_0.025.pth.tar filter=lfs diff=lfs merge=lfs -text
+ checkpoints/tic_lambda_0.0483.pth.tar filter=lfs diff=lfs merge=lfs -text
+ images/*/*.jpg filter=lfs diff=lfs merge=lfs -text
+ images/*/*.jpeg filter=lfs diff=lfs merge=lfs -text
+ images/*/*.png filter=lfs diff=lfs merge=lfs -text
.github/copilot-instructions.md ADDED
@@ -0,0 +1,417 @@
+ # ROI-VAE Image Compression - Copilot Instructions
+
+ ## Project Overview
+ ROI-based VAE image compression using TIC (Transformer-based Image Compression). The system preserves quality in Regions of Interest (ROI) while aggressively compressing backgrounds using configurable quality factors.
+
+ ## Architecture
+
+ ### Core Pipeline
+ 1. **Segmentation** (`segmentation/` module) → 2. **Compression** (`vae/` module) → 3. **Output**
+ - Segmentation creates binary masks (1=ROI, 0=background)
+ - Compression applies variable quality based on mask using `sigma` parameter
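+
+ A minimal end-to-end sketch wiring these stages together with the module APIs documented below (illustrative, not a tested script):
+
+ ```python
+ from PIL import Image
+ from segmentation import create_segmenter
+ from vae import load_checkpoint, compress_image
+
+ image = Image.open('input.jpg')
+
+ # 1. Segmentation: binary mask (1=ROI, 0=background)
+ segmenter = create_segmenter('yolo', device='cuda')
+ mask = segmenter(image, target_classes=['car', 'person'])
+
+ # 2. Compression: ROI at full quality, background scaled by sigma
+ model = load_checkpoint('checkpoints/tic_lambda_0.0483.pth.tar', device='cuda')
+ result = compress_image(image, mask, model, sigma=0.3, device='cuda')
+
+ # 3. Output
+ result['compressed'].save('compressed.jpg')
+ print(f"{result['bpp']:.4f} bpp")
+ ```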
+
+ ### Key Components
+
+ **Segmentation Module** (`segmentation/`):
+ - Abstract base class `BaseSegmenter` defines common interface
+ - Implementations:
+   - `SegFormerSegmenter` - Cityscapes semantic segmentation (19 classes: road, car, building, person, etc.)
+   - `YOLOSegmenter` - COCO instance segmentation (80 classes)
+   - `Mask2FormerSegmenter` - Swin Transformer-based panoptic/semantic segmentation (COCO: 133 classes, ADE20K: 150 classes)
+   - `MaskRCNNSegmenter` - ResNet50-FPN instance segmentation (COCO: 80 classes)
+   - `SAM3Segmenter` - Prompt-based segmentation (natural language prompt → mask via text-conditioned detector + SAM)
+   - `FakeSegmenter` - Detection + tracking → bbox masks (fast, non-pixel-perfect)
+ - **Fake Segmentation** (NEW): Detection-based segmentation for speed
+   - Creates rectangular masks from detection bounding boxes
+   - Uses object tracking for temporal consistency (ByteTrack, BoTSORT, SimpleTracker)
+   - Available methods: `fake_yolo` (default, ByteTrack), `fake_yolo_botsort`, `fake_detr`, `fake_fasterrcnn`, `fake_retinanet`, `fake_fcos`, `fake_deformable_detr`, `fake_grounding_dino`
+   - Much faster than pixel-perfect segmentation (~60-100 fps vs 10-30 fps)
+   - Memory estimates in `gpu_memory.py`: 120-200 MB per frame (vs 180-500 MB for full segmentation)
+ - Factory pattern: `create_segmenter('yolo', device='cuda')` or `create_segmenter('fake_yolo', device='cuda')`
+ - Extensible for future models
+ - Utils: `visualize_mask()`, `save_mask()`, `calculate_roi_stats()`
+
+ **Compression Module** (`vae/`):
+ - `tic_model.py`: Base `TIC` class - Transformer-based VAE with encoder, decoder, hyperprior
+ - `RSTB.py`: Residual Swin Transformer Blocks and attention modules
+ - `transformer_layers.py`: Generic transformer components (MLP, attention, drop path)
+ - `roi_tic.py`: `ModifiedTIC` class extending base TIC with ROI-aware forward pass
+ - `utils.py`: `compress_image()`, `compute_padding()` for image processing
+ - `visualization.py`: `highlight_roi()`, `create_comparison_grid()` for results
+ - Handles checkpoint loading with compressai version compatibility fixes
+
+ **Detection Module** (`detection/`):
+ - Abstract base class `BaseDetector` defines common interface
+ - Factory pattern: `create_detector('yolo', device='cuda')`
+ - Implementations:
+   - `YOLODetector` - Ultralytics YOLO (closed-vocabulary COCO weights)
+   - Torchvision: Faster R-CNN, RetinaNet, SSD, FCOS
+   - Transformers: DETR, Deformable DETR
+   - `EfficientDetDetector` - optional via `effdet`
+   - `YOLOWorldDetector` - open-vocabulary detection (Ultralytics YOLO-World; requires prompts)
+   - `GroundingDINODetector` - open-vocabulary detection (Transformers; requires prompts)
+ - CLI: `roi_detection_eval.py` evaluates detection retention before vs after ROI compression
+
+ **TIC Model** (`vae/tic_model.py`):
+ - Transformer-based VAE with encoder (`g_a`), decoder (`g_s`), and hyperprior (`h_a`, `h_s`)
+ - Uses RSTB (Residual Swin Transformer Blocks) for feature extraction
+ - Channels: N=192, M=192 (expansion layer)
+ - Critical: Images must be padded to multiples of 256 (use `compute_padding()`)
+
+ **ModifiedTIC** (`vae/roi_tic.py`):
+ - Extends base TIC with ROI-aware forward pass
+ - Takes mask + sigma parameter to create quality factors
+ - Applies `similarity_loss` tensor: 1.0 for ROI pixels, sigma for background
+ - Integrates mask through `simi_net` and `sub_impor_net` branches
+
+ ## Critical Conventions
+
+ ### Model Cache Locations
+ - By default, auto-downloaded model artifacts are kept inside `checkpoints/`:
+   - Hugging Face cache: `checkpoints/hf/`
+   - Torch/torchvision cache: `checkpoints/torch/`
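+
+ To reproduce this redirection outside the app, the standard environment variables are `HF_HOME` and `TORCH_HOME` (a sketch; whether the app already sets these itself is an assumption to verify in `app.py`):
+
+ ```python
+ import os
+
+ # Must run before importing transformers / torchvision so downloads land in checkpoints/
+ os.environ.setdefault("HF_HOME", "checkpoints/hf")
+ os.environ.setdefault("TORCH_HOME", "checkpoints/torch")
+ ```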
+
+ ### Checkpoint Loading Pattern
+ ```python
+ from vae import load_checkpoint
+
+ # Automatically handles compressai version mismatch
+ model = load_checkpoint('checkpoints/tic_lambda_0.0483.pth.tar', N=192, M=192, device='cuda')
+ # Note: model.update(force=True) is called automatically
+ ```
+
+ Manual loading:
+ ```python
+ # Fix compressai version mismatch - required for all checkpoint loading.
+ # Old checkpoints name entropy-bottleneck parameters `_matrix`/`_bias`/`_factor`;
+ # newer compressai expects `matrices.`/`biases.`/`factors.`.
+ state_dict = checkpoint["state_dict"]
+ new_state_dict = {}
+ for k, v in state_dict.items():
+     new_key = k
+     for old, new in [("_matrix", "matrices."), ("_bias", "biases."), ("_factor", "factors.")]:
+         new_key = new_key.replace(f"entropy_bottleneck.{old}", f"entropy_bottleneck.{new}")
+     new_state_dict[new_key] = v
+ ```
+ Always call `model.update(force=True)` after loading checkpoints.
+
+ ### Image Preprocessing
+ 1. Convert PIL to torch tensor: `x = torch.from_numpy(np.array(img)).float() / 255.0`
+ 2. Permute to [B, C, H, W]: `x = x.permute(2, 0, 1).unsqueeze(0)`
+ 3. Pad to 256 multiples using `compute_padding(h, w, min_div=256)`
+ 4. Apply mask at same resolution as input image
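+
+ The same sequence as a runnable sketch (assuming `compute_padding` follows the compressai convention of returning `(pad, unpad)` tuples usable with `F.pad`):
+
+ ```python
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ from vae.utils import compute_padding
+
+ img = Image.open('input.jpg').convert('RGB')
+ x = torch.from_numpy(np.array(img)).float() / 255.0  # [H, W, C], values in [0, 1]
+ x = x.permute(2, 0, 1).unsqueeze(0)                  # [1, C, H, W]
+
+ h, w = x.shape[2], x.shape[3]
+ pad, unpad = compute_padding(h, w, min_div=256)      # assumed return convention
+ x_padded = F.pad(x, pad, mode='constant', value=0)
+
+ # The mask is supplied at the original (unpadded) image resolution
+ mask = torch.zeros(1, 1, h, w)
+ ```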
+
+ ### Sigma Parameter
+ - Range: 0.01 - 1.0 (lower = more background compression)
+ - Default: 0.3
+ - ROI pixels always get quality factor 1.0
+ - Applied via `torch.where(mask > 0.5, 1.0, sigma)`
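+
+ As a self-contained illustration of that rule (the mask here is a hypothetical square ROI):
+
+ ```python
+ import torch
+
+ sigma = 0.3
+ mask = torch.zeros(1, 1, 256, 256)
+ mask[:, :, 64:192, 64:192] = 1.0               # ROI region
+ # Scalar form as quoted above; on older PyTorch wrap the scalars in torch.tensor()
+ quality = torch.where(mask > 0.5, 1.0, sigma)  # 1.0 inside ROI, sigma elsewhere
+ print(quality.unique())                        # tensor([0.3000, 1.0000])
+ ```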
+
+ ### Available Checkpoints
+ Located in `checkpoints/` directory with different lambda (rate-distortion) values:
+ - `tic_lambda_0.0035.pth.tar` - Lowest bitrate (highest compression)
+ - `tic_lambda_0.013.pth.tar` - Low bitrate (N=128, M=192)
+ - `tic_lambda_0.025.pth.tar` - Medium-low bitrate
+ - `tic_lambda_0.0483.pth.tar` - **Default** - Medium bitrate
+ - `tic_lambda_0.0932.pth.tar` - High bitrate (better quality)
+ - `yolo26x-seg.pt` - YOLO segmentation model
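+
+ These lambdas line up with the 1-5 quality levels documented in `API.md`; a hypothetical helper for resolving a level to a checkpoint path:
+
+ ```python
+ LAMBDA_BY_QUALITY = {1: 0.0035, 2: 0.013, 3: 0.025, 4: 0.0483, 5: 0.0932}
+
+ def checkpoint_path(quality: int) -> str:
+     """Map an API quality level (1-5) to its checkpoint file."""
+     return f"checkpoints/tic_lambda_{LAMBDA_BY_QUALITY[quality]}.pth.tar"
+ ```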
+
+ ## Development Workflows
+
+ ### Using Segmentation Module (New)
+ ```python
+ from segmentation import create_segmenter
+
+ # Available methods: segformer, yolo, mask2former, maskrcnn, sam3
+ # Fake methods: fake_yolo, fake_yolo_botsort, fake_detr, fake_fasterrcnn, etc.
+ segmenter = create_segmenter('mask2former', device='cuda', model_type='coco')
+
+ # Segment image
+ mask = segmenter(image, target_classes=['car', 'person'])
+
+ # Fast segmentation with detection + tracking (non-pixel-perfect)
+ fake_seg = create_segmenter('fake_yolo', device='cuda')
+ mask = fake_seg(image, target_classes=['person'])  # Uses ByteTrack tracking
+ # Much faster: ~60-100 fps vs 10-30 fps for pixel-perfect segmentation
+
+ # Add new segmentation method
+ from segmentation import register_segmenter, BaseSegmenter
+
+ class MySegmenter(BaseSegmenter):
+     def load_model(self): ...
+     def segment(self, image, target_classes, **kwargs): ...
+     def get_available_classes(self): ...
+
+ register_segmenter('my_method', MySegmenter)
+ ```
+
+ ### Using Compression Module (New)
+ ```python
+ from vae import load_checkpoint, compress_image
+ from PIL import Image
+ import numpy as np
+
+ # Load model
+ model = load_checkpoint('checkpoints/tic_lambda_0.0483.pth.tar', device='cuda')
+
+ # Compress image with mask
+ image = Image.open('input.jpg')
+ mask = np.zeros((image.height, image.width))  # Your mask here
+
+ result = compress_image(image, mask, model, sigma=0.3, device='cuda')
+ compressed = result['compressed']  # PIL Image
+ bpp = result['bpp']  # Bits per pixel
+
+ # Visualize results
+ from vae import create_comparison_grid
+ grid = create_comparison_grid(image, compressed, mask, bpp, sigma=0.3, lambda_val=0.0483)
+ grid.save('comparison.jpg')
+ ```
+
+ ### Using Detection Module (New)
+ ```python
+ from detection import create_detector
+
+ # Closed-vocabulary
+ det = create_detector('yolo', device='cuda', model_path='checkpoints/yolo26x.pt')
+ dets = det(image, conf_threshold=0.25)
+
+ # Open-vocabulary (must pass prompts/classes)
+ det_ov = create_detector('yolo_world', device='cuda')
+ dets_ov = det_ov(image, conf_threshold=0.25, classes='person,car')
+ ```
+
+ ### Detection Eval (CLI)
+ ```bash
+ # Compare before vs after (already-compressed)
+ python roi_detection_eval.py \
+     --before images/car/0016cf15fa4d4e16.jpg \
+     --after results/compressed.jpg \
+     --detectors yolo detr \
+     --viz-dir results/det_viz
+
+ # Open-vocabulary eval (YOLO-World requires prompts)
+ python roi_detection_eval.py \
+     --before images/person/kodim04.png \
+     --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
+     --sigma 0.3 \
+     --seg-method yolo --seg-classes person \
+     --detectors yolo_world \
+     --open-vocab-classes "person,car" \
+     --viz-dir results/det_viz
+ ```
+
+ ### Running Compression (CLI)
+ ```bash
+ # Basic compression with segmentation
+ python roi_compressor.py \
+     --input images/car/0016cf15fa4d4e16.jpg \
+     --output results/compressed.jpg \
+     --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
+     --sigma 0.3 \
+     --seg-classes car \
+     --seg-method yolo
+
+ # Fast compression with detection-based fake segmentation (~3x faster)
+ python roi_compressor.py \
+     --input images/car/0016cf15fa4d4e16.jpg \
+     --output results/compressed.jpg \
+     --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
+     --sigma 0.3 \
+     --seg-classes car \
+     --seg-method fake_yolo
+
+ # With comparison grid (original, compressed, ROI highlighted)
+ python roi_compressor.py ... --highlight
+ ```
+
+ ### Standalone Segmentation (CLI)
+ ```bash
+ # Using Mask2Former with COCO panoptic
+ python roi_segmenter.py \
+     --input images/car/0016cf15fa4d4e16.jpg \
+     --output results/mask.png \
+     --method mask2former \
+     --classes car building person \
+     --visualize
+
+ # Fast segmentation with detection + ByteTrack tracking
+ python roi_segmenter.py \
+     --input data/videos/Person_doing_handstand.mp4 \
+     --output results/masks.mp4 \
+     --method fake_yolo \
+     --classes person \
+     --resize-height 480 \
+     --smooth-patience 10 \
+     --visualize
+
+ # Other fake methods (detection + tracking)
+ # fake_yolo_botsort (YOLO + BoTSORT)
+ # fake_detr (DETR + SimpleTracker)
+ # fake_fasterrcnn, fake_retinanet, fake_fcos, etc.
+ ```
+
+ ### Adding New Segmentation Models
+ 1. Create new file in `segmentation/` (e.g., `sam.py`)
+ 2. Extend `BaseSegmenter` and implement abstract methods:
+    - `load_model()`: Load model weights
+    - `segment()`: Generate mask from image
+    - `get_available_classes()`: Return supported classes/capabilities
+ 3. Register in `segmentation/__init__.py` or use `register_segmenter()`
+ 4. Use via `create_segmenter('your_method', ...)`
+
+ ### Testing Examples
+ - `roi_segmenter.py`: CLI tool for standalone segmentation
+ - `roi_compressor.py`: CLI tool for ROI-based image compression
+
+ ## Dependencies
+ - PyTorch + torchvision for model
+ - compressai for entropy models (version sensitive - see checkpoint loading)
+ - transformers for SegFormer + DETR/Deformable DETR + Grounding DINO
+ - ultralytics for YOLO + YOLO-World
+ - effdet (optional) for EfficientDet detector
+ - timm for model layers
+
+ ## Common Pitfalls
+ 1. **Padding**: Forgetting to pad images to 256 multiples causes dimension mismatches
+ 2. **Checkpoint keys**: Old checkpoints use `_matrix/_bias/_factor` naming that must be converted
+ 3. **Mask resolution**: Mask must match input image size; it's automatically downsampled in forward pass
+ 4. **Mask downsampling**: In ModifiedTIC, mask is downsampled to 1/2 resolution before simi_net (which further downsamples 8x to match 16x16 latent)
+ 5. **Device mismatch**: Ensure mask, sigma tensor, and model are on same device
+ 6. **Model update**: Must call `model.update(force=True)` after loading for entropy models
+
+ ## Project Structure
+
+ - `.github/copilot-instructions.md`: This file - comprehensive development guide
+ - `examples.sh`: Example commands for running compression and segmentation
+ - `README.md`: Project overview and quick start guide
+ - `requirements.txt`: Python dependencies
+
+ **CLI Tools:**
+ - `roi_segmenter.py`: CLI tool for standalone segmentation
+ - `roi_compressor.py`: CLI tool for ROI-based image compression
+ - `app.py`: Gradio demo with Image and Video tabs
+
+ **Core Modules:**
+ - `segmentation/`: Modular segmentation with abstract base class
+   - `base.py`: `BaseSegmenter` abstract class
+   - `segformer.py`: Cityscapes semantic segmentation (19 classes)
+   - `yolo.py`: COCO instance segmentation (80 classes)
+   - `mask2former.py`: Swin-based panoptic/semantic (COCO: 133, ADE20K: 150 classes)
+   - `maskrcnn.py`: ResNet50-FPN instance segmentation (COCO: 80 classes)
+   - `sam3.py`: Prompt-based segmentation
+   - `factory.py`: Factory pattern for creating segmenters
+   - `utils.py`: Visualization and I/O utilities
+ - `vae/`: Modular compression with ROI support
+   - `tic_model.py`: Base `TIC` class (Transformer-based VAE)
+   - `RSTB.py`: Residual Swin Transformer Blocks
+   - `transformer_layers.py`: Generic transformer components
+   - `roi_tic.py`: `ModifiedTIC` class and checkpoint loading
+   - `utils.py`: `compress_image()`, `compute_padding()`
+   - `visualization.py`: `highlight_roi()`, `create_comparison_grid()`
+ - `video/`: Video compression with streaming support
+   - `video_processor.py`: `VideoProcessor` class for video compression
+   - `motion_analyzer.py`: `MotionAnalyzer` for scene complexity estimation
+   - `chunk_compressor.py`: `ChunkCompressor` and `BandwidthController`
+ - `detection/`: Object detection and tracking
+   - `tracker.py`: `SimpleTracker` IoU-based multi-object tracker
+   - `utils.py`: `draw_detections()`, `draw_tracks()`
+
+ **Legacy:**
+ - `vae_compress.py`: Legacy ROI compression script (updated to use modules)
+ - `*.bak`: Backup files from pre-modularization (tic_model, RSTB, etc.)
+
+ ## Video Processing
+
+ ### Video Module Usage
+ ```python
+ from video import VideoProcessor, CompressionSettings
+
+ # Create processor
+ processor = VideoProcessor(device='cuda')
+ processor.load_models(
+     quality_level=4,
+     segmentation_method='sam3',
+     detection_method='yolo',
+     enable_tracking=True,
+ )
+
+ # Static mode (fixed settings)
+ settings = CompressionSettings(
+     mode='static',
+     quality_level=4,
+     sigma=0.3,
+     output_fps=15.0,
+     target_classes=['person', 'car'],
+ )
+
+ for chunk in processor.process_static('input.mp4', settings):
+     # Stream chunks in real-time
+     print(f"Chunk {chunk.chunk_index}: {len(chunk.frames)} frames at {chunk.fps} FPS")
+
+ # Dynamic mode (bandwidth-adaptive)
+ settings = CompressionSettings(
+     mode='dynamic',
+     target_bandwidth_kbps=500,
+     min_fps=5,
+     max_fps=30,
+     chunk_duration_sec=1.0,
+     target_classes=['person', 'car'],
+ )
+
+ for chunk in processor.process_dynamic('input.mp4', settings):
+     # Adaptive FPS and quality per chunk based on motion
+     print(f"Chunk {chunk.chunk_index}: fps={chunk.fps:.1f}, quality={chunk.quality_level}")
+ ```
+
+ ### Motion-Adaptive Compression
+ The dynamic mode analyzes each chunk for:
+ - **Motion magnitude**: Mean pixel change between frames
+ - **Motion coverage**: Fraction of pixels with significant motion
+ - **Scene complexity**: Edge density and texture variance
+ - **Scene changes**: Large global differences
+
+ High-motion scenes get:
+ - More frames (higher FPS)
+ - Higher spatial compression (lower quality/sigma) to stay within bandwidth
+
+ Low-motion scenes get:
+ - Fewer frames (lower FPS)
+ - Better spatial quality (higher quality/sigma)
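+
+ A toy sketch of that trade-off (purely illustrative; the real logic lives in `MotionAnalyzer` and `ChunkCompressor`, whose interfaces may differ):
+
+ ```python
+ def adapt_chunk_settings(motion: float, min_fps=5.0, max_fps=30.0,
+                          min_sigma=0.1, max_sigma=0.5):
+     """motion in [0, 1] -> (fps, sigma) under a fixed bandwidth budget."""
+     fps = min_fps + motion * (max_fps - min_fps)          # more motion -> more frames
+     sigma = max_sigma - motion * (max_sigma - min_sigma)  # more motion -> cheaper background
+     return fps, sigma
+ ```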
+
+ ### Object Tracking
+ ```python
+ from detection import SimpleTracker, draw_tracks
+
+ tracker = SimpleTracker(iou_threshold=0.3, max_age=30)
+
+ for frame_detections in frame_by_frame_detections:
+     tracks = tracker.update(frame_detections)
+     # tracks contains track_id, label, bbox, history
+
+     # Draw tracks with trails
+     img = draw_tracks(frame, tracks, show_id=True, show_trail=True)
+ ```
+
+ ## Coding Guidelines
+ - Don't create unnecessary files; focus on core functionality.
+ - Ensure all scripts have clear argument parsing and help messages.
+ - Maintain consistent coding style and comments for clarity.
+ - Validate inputs (image paths, checkpoint paths, segmentation classes).
+ - Include error handling for common issues (file not found, dimension mismatches).
+ - Document all functions and classes with docstrings.
+ - Write modular code to facilitate testing and future extensions.
+ - Use ipynb files for prototyping but keep main logic in .py files.
.gitignore ADDED
@@ -0,0 +1,69 @@
+ # --- Python ---
+ __pycache__/
+ *.py[cod]
+ *.pyd
+ *.so
+ *.egg-info/
+ dist/
+ build/
+ .eggs/
+ .pytest_cache/
+ .mypy_cache/
+ .ruff_cache/
+ .pytype/
+ .coverage
+ coverage.xml
+ htmlcov/
+
+ # --- Virtual environments ---
+ venv/
+ .venv/
+ env/
+ ENV/
+
+ # --- Jupyter ---
+ .ipynb_checkpoints/
+
+ # --- OS / editor ---
+ .DS_Store
+ Thumbs.db
+ .vscode/
+ .idea/
+
+ # --- Secrets / local config ---
+ .env
+ .env.*
+ *.key
+ *.pem
+
+ # --- Logs / temp ---
+ *.log
+ logs/
+ tmp/
+ .cache/
+
+ # --- Gradio / HF Spaces artifacts ---
+ flagged/
+ gradio_cached_examples/
+ .gradio/
+
+ # --- Data / media ---
+ data/
+
+ # --- Model + framework caches (keep your curated checkpoints, ignore auto-download caches) ---
+ checkpoints/hf/
+ checkpoints/torch/
+ checkpoints/yolo*.pt
+
+ # --- Common ML experiment outputs ---
+ runs/
+ wandb/
+ outputs/
+ results/
+
+ # --- Large artifacts (uncomment if you don't want binaries tracked) ---
+ # *.pth
+ # *.pt
+ # *.pth.tar
+ # *.onnx
+ # *.ckpt
API.md ADDED
@@ -0,0 +1,1029 @@
+ # API Documentation
+
+ This document describes the Gradio API endpoints exposed by the ROI-VAE image and video compression application. The API allows programmatic access to segmentation, compression, detection, and full pipeline processing for both images and videos.
+
+ **Live Demo:** https://biaslab2025-contextual-communication-demo.hf.space
+
+ ## Table of Contents
+
+ - [Quick Start](#quick-start)
+ - [Important Notes](#important-notes)
+ - [Image API Endpoints](#image-api-endpoints)
+   - [/segment](#1-segment---generate-roi-mask)
+   - [/compress](#2-compress---compress-image)
+   - [/detect](#3-detect---object-detection)
+   - [/detect_overlay](#31-detect_overlay---detection-with-visualization)
+   - [/process](#4-process---full-image-pipeline)
+ - [Video API Endpoints](#video-api-endpoints)
+   - [/segment_video](#1-segment_video---segment-video)
+   - [/compress_video](#2-compress_video---compress-video)
+   - [/detect_video](#3-detect_video---video-detection)
+   - [/process_video](#4-process_video---full-video-pipeline)
+ - [Class Reference](#class-reference)
+ - [Error Handling](#error-handling)
+ - [GPU Quota Handling](#handling-gpu-quota-on-hf-spaces)
+ - [cURL Examples](#using-with-curl)
+ - [Example Scripts](#example-scripts)
+
+ ---
+
+ ## Quick Start
+
+ ### Installation
+
+ ```bash
+ pip install gradio_client
+ ```
+
+ ### Image Processing
+
+ ```python
+ from gradio_client import Client, handle_file
+
+ # Connect to the API
+ client = Client("https://biaslab2025-contextual-communication-demo.hf.space")
+ # Or local: client = Client("http://localhost:7860")
+
+ # Full pipeline: segment → compress → detect
+ compressed, mask, bpp, ratio, coverage, detections_json = client.predict(
+     handle_file("path/to/image.jpg"),
+     "car, person",  # segmentation prompt
+     "sam3",         # segmentation method
+     4,              # quality level (1-5)
+     0.3,            # sigma (background compression)
+     True,           # run detection
+     "yolo",         # detection method
+     "",             # detection classes (empty for closed-vocab)
+     api_name="/process"
+ )
+
+ print(f"Compression: {bpp:.4f} bpp ({ratio:.2f}x)")
+ ```
+
+ ### Video Processing
+
+ ```python
+ from gradio_client import Client, handle_file
+ import json
+
+ client = Client("http://localhost:7860")
+
+ # Full pipeline with static settings
+ output_video, stats_json = client.predict(
+     handle_file("path/to/video.mp4"),
+     "person, car",  # segmentation classes
+     "sam3",         # segmentation method
+     "static",       # mode: "static" or "dynamic"
+     4,              # quality level (1-5)
+     0.3,            # sigma
+     15.0,           # output FPS
+     500,            # bandwidth (dynamic mode)
+     5,              # min_fps (dynamic mode)
+     30,             # max_fps (dynamic mode)
+     0.5,            # aggressiveness (dynamic mode)
+     False,          # run detection
+     "yolo",         # detection method
+     None,           # mask_file_path (optional)
+     api_name="/process_video"
+ )
+
+ stats = json.loads(stats_json)
+ print(f"Compressed video: {output_video}")
+ print(f"Total frames: {stats['total_frames']}")
+ ```
+
+ ---
+
+ ## Important Notes
+
+ ### File Handling
+
+ Always wrap file paths with `handle_file()` when using `gradio_client`:
+
+ ```python
+ from gradio_client import handle_file
+
+ # ✅ Correct
+ client.predict(handle_file("image.jpg"), ...)
+
+ # ❌ Incorrect - will fail with validation error
+ client.predict("image.jpg", ...)
+ ```
+
+ ### Detection Output Format
+
+ All detection endpoints return JSON strings with this structure:
+
+ ```python
+ import json
+
+ detections = json.loads(detections_json)
+ # Each detection has:
+ # - label: str (class name)
+ # - score: float (confidence 0-1)
+ # - bbox_xyxy: list[float] (bounding box [x1, y1, x2, y2])
+ ```
+
+ ### Open-Vocabulary Detectors
+
+ The following detectors require a `classes` parameter:
+ - `yolo_world` - YOLO-World
+ - `grounding_dino` - Grounding DINO
+
+ Closed-vocabulary detectors (`yolo`, `detr`, `faster_rcnn`, etc.) use pretrained COCO classes and ignore the `classes` parameter.
+
+ ---
+
+ ## Image API Endpoints
+
+ ### 1. `/segment` - Generate ROI Mask
+
+ Segments an image to create a Region of Interest (ROI) mask.
+
+ **Parameters:**
+
+ | Parameter | Type | Default | Description |
+ |-----------|------|---------|-------------|
+ | `image` | Image | required | Input image file |
+ | `prompt` | str | `"object"` | Comma-separated classes or natural language prompt |
+ | `method` | str | `"sam3"` | Segmentation method (see [methods](#segmentation-methods)) |
+ | `return_overlay` | bool | `False` | If `True`, returns image with ROI highlighted instead of mask |
+
+ **Returns:**
+
+ | Output | Type | Description |
+ |--------|------|-------------|
+ | `result_image` | Image | Grayscale mask OR image with ROI overlay (if `return_overlay=True`) |
+ | `roi_coverage` | float | Fraction of image covered by ROI (0.0-1.0) |
+ | `classes_used` | str | JSON list of classes/prompts used |
+
+ **Example:**
+
+ ```python
+ # Get binary mask (default)
+ mask, coverage, classes = client.predict(
+     handle_file("car_scene.jpg"),
+     "car, road",
+     "sam3",
+     False,  # return_overlay
+     api_name="/segment"
+ )
+ print(f"ROI covers {coverage*100:.2f}% of image")
+
+ # Get image with ROI highlighted
+ highlighted, coverage, classes = client.predict(
+     handle_file("car_scene.jpg"),
+     "car, road",
+     "sam3",
+     True,  # return_overlay=True
+     api_name="/segment"
+ )
+ ```
+
+ ---
+
+ ### 2. `/compress` - Compress Image
+
+ Compresses an image using TIC VAE, optionally with an ROI mask for variable quality.
+
+ **Parameters:**
+
+ | Parameter | Type | Default | Description |
+ |-----------|------|---------|-------------|
+ | `image` | Image | required | Input image file |
+ | `mask_image` | Image | `None` | ROI mask (white=ROI, black=background) |
+ | `quality` | int | `4` | Quality level 1-5 |
+ | `sigma` | float | `0.3` | Background preservation (0.01-1.0) |
+
+ **Quality Levels:**
+
+ | Level | Lambda | Description |
+ |-------|--------|-------------|
+ | 1 | 0.0035 | Smallest file |
+ | 2 | 0.013 | Smaller file |
+ | 3 | 0.025 | Balanced |
+ | 4 | 0.0483 | Higher quality (default) |
+ | 5 | 0.0932 | Best quality |
+
+ **Returns:**
+
+ | Output | Type | Description |
+ |--------|------|-------------|
+ | `compressed_image` | Image | Compressed output image |
+ | `bpp` | float | Bits per pixel |
+ | `compression_ratio` | float | Compression ratio (24/bpp) |
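+
+ For example, an output of 0.24 bpp corresponds to a ratio of 24 / 0.24 = 100x relative to uncompressed 24-bit RGB.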
+
+ **Example:**
+
+ ```python
+ # Compress without mask (uniform quality)
+ compressed, bpp, ratio = client.predict(
+     handle_file("image.jpg"),
+     None,  # no mask
+     4,     # quality
+     0.3,   # sigma (ignored without mask)
+     api_name="/compress"
+ )
+
+ # Compress with ROI mask
+ mask, _, _ = client.predict(handle_file("image.jpg"), "person", "yolo", False, api_name="/segment")
+
+ compressed, bpp, ratio = client.predict(
+     handle_file("image.jpg"),
+     handle_file(mask),
+     4,
+     0.2,  # aggressive background compression
+     api_name="/compress"
+ )
+ ```
+
+ ---
+
+ ### 3. `/detect` - Object Detection
+
+ Runs object detection on an image and returns detection results as JSON.
+
+ **Parameters:**
+
+ | Parameter | Type | Default | Description |
+ |-----------|------|---------|-------------|
+ | `image` | Image | required | Input image file |
+ | `method` | str | `"yolo"` | Detection method (see [methods](#detection-methods)) |
+ | `classes` | str | `""` | Comma-separated classes (required for open-vocab detectors) |
+ | `confidence` | float | `0.25` | Confidence threshold (0.0-1.0) |
+
+ **Returns:**
+
+ | Output | Type | Description |
+ |--------|------|-------------|
+ | `detections_json` | str | JSON string of detection results |
+
+ **Example - Closed-Vocabulary:**
+
+ ```python
+ import json
+
+ # YOLO detection (COCO classes)
+ dets_json = client.predict(
+     handle_file("street_scene.jpg"),
+     "yolo",
+     "",  # no classes needed
+     0.25,
+     api_name="/detect"
+ )
+
+ detections = json.loads(dets_json)
+ for det in detections:
+     print(f"{det['label']}: {det['score']:.2f}")
+ ```
+
+ **Example - Open-Vocabulary:**
+
+ ```python
+ # YOLO-World with custom classes
+ dets_json = client.predict(
+     handle_file("image.jpg"),
+     "yolo_world",
+     "hat, backpack, umbrella",  # custom classes required
+     0.25,
+     api_name="/detect"
+ )
+ ```
+
+ ---
+
+ ### 3.1. `/detect_overlay` - Detection with Visualization
+
+ Runs object detection and returns the image with bounding boxes drawn.
+
+ **Parameters:**
+
+ | Parameter | Type | Default | Description |
+ |-----------|------|---------|-------------|
+ | `image` | Image | required | Input image file |
+ | `method` | str | `"yolo"` | Detection method (see [methods](#detection-methods)) |
+ | `classes` | str | `""` | Comma-separated classes (required for open-vocab detectors) |
+ | `confidence` | float | `0.25` | Confidence threshold (0.0-1.0) |
+
+ **Returns:**
+
+ | Output | Type | Description |
+ |--------|------|-------------|
+ | `result_image` | Image | Image with detection bounding boxes |
+ | `detections_json` | str | JSON string of detection results |
+
+ **Example:**
+
+ ```python
+ import json
+
+ # Get image with detection boxes
+ result_img, dets_json = client.predict(
+     handle_file("street_scene.jpg"),
+     "yolo",
+     "",
+     0.25,
+     api_name="/detect_overlay"
+ )
+
+ # result_img is a file path to the image with boxes drawn
+ print(f"Image with boxes: {result_img}")
+ detections = json.loads(dets_json)
+ ```
+
+ ---
+
+ ### 4. `/process` - Full Image Pipeline
+
+ Runs the complete pipeline: segmentation → compression → optional detection.
+
+ **Parameters:**
+
+ | Parameter | Type | Default | Description |
+ |-----------|------|---------|-------------|
+ | `image` | Image | required | Input image file |
+ | `prompt` | str | `"object"` | Segmentation prompt/classes |
+ | `segmentation_method` | str | `"sam3"` | ROI segmentation method |
+ | `quality` | int | `4` | Compression quality (1-5) |
+ | `sigma` | float | `0.3` | Background preservation (0.01-1.0) |
+ | `run_detection` | bool | `False` | Whether to run detection on output |
+ | `detection_method` | str | `"yolo"` | Detector to use |
+ | `detection_classes` | str | `""` | Classes for open-vocab detectors |
+
+ **Returns:**
+
+ | Output | Type | Description |
+ |--------|------|-------------|
+ | `compressed_image` | Image | Compressed output image |
+ | `mask_image` | Image | Generated ROI mask |
+ | `bpp` | float | Bits per pixel |
+ | `compression_ratio` | float | Compression ratio |
+ | `roi_coverage` | float | Fraction of image covered by ROI (0.0-1.0) |
+ | `detections_json` | str | JSON detections (empty list if `run_detection=False`) |
+
+ **Example:**
+
+ ```python
+ import json
+
+ compressed, mask, bpp, ratio, coverage, dets_json = client.predict(
+     handle_file("street.jpg"),
+     "car, person, road",
+     "sam3",
+     4,
+     0.3,
+     True,  # run detection
+     "yolo",
+     "",
+     api_name="/process"
+ )
+
+ print(f"ROI Coverage: {coverage*100:.2f}%")
+ print(f"Compression: {bpp:.4f} bpp ({ratio:.2f}x)")
+ print(f"Detections: {len(json.loads(dets_json))}")
+ ```
+
+ ---
+
+ ## Video API Endpoints
+
+ ### 1. `/segment_video` - Segment Video
+
+ Segments a video to find ROI regions, returning either a mask file or overlay video.
+
+ **Parameters:**
+
+ | Parameter | Type | Default | Description |
+ |-----------|------|---------|-------------|
+ | `video_path` | Video | required | Input video file |
+ | `prompt` | str | `"object"` | Comma-separated classes or natural language prompt |
+ | `method` | str | `"sam3"` | Segmentation method |
+ | `return_overlay` | bool | `False` | If `True`, returns video with ROI highlighted |
+ | `output_fps` | float | `15.0` | Output framerate (max 30) |
+
+ **Returns:**
+
+ | Output | Type | Description |
+ |--------|------|-------------|
+ | `result_path` | File/Video | Mask file (NPZ) OR video with ROI overlay |
+ | `stats_json` | str | JSON with frame count, coverage, and classes |
+
+ **Example:**
+
+ ```python
+ import json
+
+ # Get mask file for reuse in compression
+ mask_file, stats_json = client.predict(
+     handle_file("video.mp4"),
+     "person, car",
+     "sam3",
+     False,  # return masks file
+     15.0,   # fps
+     api_name="/segment_video"
+ )
+
+ stats = json.loads(stats_json)
+ print(f"Processed {stats['total_frames']} frames")
+ print(f"Avg ROI coverage: {stats['avg_roi_coverage']*100:.2f}%")
+
+ # Get video with ROI overlay for visualization
+ overlay_video, _ = client.predict(
+     handle_file("video.mp4"),
+     "person, car",
+     "sam3",
+     True,  # return overlay video
+     15.0,
+     api_name="/segment_video"
+ )
+ ```
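+
+ The returned mask file is an NPZ archive; a quick way to inspect it locally (the array key names are not specified in this doc, so read `data.files` rather than assuming them):
+
+ ```python
+ import numpy as np
+
+ data = np.load(mask_file)   # NpzFile, behaves like a dict of arrays
+ print(data.files)           # list the stored array names
+ first = data[data.files[0]]
+ print(first.shape, first.dtype)
+ ```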
+
+ ---
+
+ ### 2. `/compress_video` - Compress Video
+
+ Compresses a video with optional ROI mask preservation.
+
+ **Parameters:**
+
+ | Parameter | Type | Default | Description |
+ |-----------|------|---------|-------------|
+ | `video_path` | Video | required | Input video file |
+ | `mask_file_path` | str | `None` | Path to pre-computed masks (from `/segment_video`) |
+ | `quality` | int | `4` | Quality level (1-5) |
+ | `sigma` | float | `0.3` | Background preservation (0.01-1.0) |
+ | `output_fps` | float | `15.0` | Target output framerate |
+
+ **Returns:**
+
+ | Output | Type | Description |
+ |--------|------|-------------|
+ | `compressed_video` | Video | Compressed output video |
+ | `stats_json` | str | JSON with compression statistics |
+
+ **Example:**
+
+ ```python
+ import json
+
+ # First, segment to get masks
+ mask_file, _ = client.predict(
+     handle_file("video.mp4"), "person", "sam3", False, 15.0,
+     api_name="/segment_video"
+ )
+
+ # Then compress with cached masks (3-5x faster!)
+ compressed, stats_json = client.predict(
+     handle_file("video.mp4"),
+     mask_file,  # reuse masks
+     4,          # quality
+     0.3,        # sigma
+     15.0,       # fps
+     api_name="/compress_video"
+ )
+
+ stats = json.loads(stats_json)
+ print(f"Compression ratio: {stats['compression_ratio']}x")
+ print(f"Total size: {stats['total_size_kb']} KB")
+ ```
+
+ ---
+
+ ### 3. `/detect_video` - Video Detection
+
+ Runs object detection on each frame of a video.
+
+ **Parameters:**
+
+ | Parameter | Type | Default | Description |
+ |-----------|------|---------|-------------|
+ | `video_path` | Video | required | Input video file |
+ | `method` | str | `"yolo"` | Detection method |
+ | `classes` | str | `""` | Comma-separated classes (required for open-vocab) |
+ | `confidence` | float | `0.25` | Confidence threshold (0.0-1.0) |
+ | `return_overlay` | bool | `False` | If `True`, returns video with detection boxes |
+ | `output_fps` | float | `15.0` | Output framerate (max 30) |
+
+ **Returns:**
+
+ | Output | Type | Description |
+ |--------|------|-------------|
+ | `result_video` | Video | Video with detection boxes (if `return_overlay=True`), None otherwise |
+ | `detections_json` | str | JSON with per-frame detections |
+
+ **Example:**
+
+ ```python
+ import json
+
+ # Get per-frame detections JSON
+ _, dets_json = client.predict(
+     handle_file("video.mp4"),
+     "yolo",
+     "",
+     0.25,
+     False,  # return JSON only
+     15.0,
+     api_name="/detect_video"
+ )
+
+ data = json.loads(dets_json)
+ print(f"Total detections: {data['total_detections']}")
+ print(f"Avg per frame: {data['avg_detections_per_frame']}")
+
+ # Get video with detection overlays
+ det_video, _ = client.predict(
+     handle_file("video.mp4"),
+     "yolo",
+     "",
+     0.25,
+     True,  # return overlay video
+     15.0,
+     api_name="/detect_video"
+ )
+ ```
+
+ ---
+
+ ### 4. `/process_video` - Full Video Pipeline
+
+ Processes a video with ROI-based compression (segment → compress), with optional detection.
+
+ **Parameters:**
+
+ | Parameter | Type | Default | Description |
+ |-----------|------|---------|-------------|
+ | `video_path` | Video | required | Input video file |
+ | `prompt` | str | `"object"` | Segmentation prompt/classes |
+ | `segmentation_method` | str | `"sam3"` | ROI segmentation method |
+ | `mode` | str | `"static"` | `"static"` or `"dynamic"` |
+ | `quality` | int | `4` | Quality level 1-5 (static mode) |
+ | `sigma` | float | `0.3` | Background preservation (static mode) |
+ | `output_fps` | float | `15.0` | Target framerate (static mode) |
+ | `bandwidth_kbps` | float | `500.0` | Target bandwidth (dynamic mode) |
+ | `min_fps` | float | `5.0` | Minimum framerate (dynamic mode) |
+ | `max_fps` | float | `30.0` | Maximum framerate (dynamic mode) |
+ | `aggressiveness` | float | `0.5` | Bandwidth savings strategy (dynamic mode): `0.0` = use full bandwidth (high FPS always), `0.5` = moderate savings, `1.0` = maximum savings (aggressive FPS reduction for low motion) |
+ | `run_detection` | bool | `False` | Whether to run detection/tracking |
+ | `detection_method` | str | `"yolo"` | Detector to use |
+ | `mask_file_path` | str | `None` | Path to pre-computed masks (skips segmentation) |
+
+ **Returns:**
+
+ | Output | Type | Description |
+ |--------|------|-------------|
+ | `output_video` | Video | Compressed video |
+ | `stats_json` | str | JSON with detailed statistics |
+
+ **Example - Static Mode:**
+
+ ```python
+ import json
+
+ output, stats_json = client.predict(
+     handle_file("video.mp4"),
+     "person, car",
+     "sam3",
+     "static",
+     4, 0.3, 15.0,     # static: quality, sigma, fps
+     500, 5, 30, 0.5,  # dynamic: bandwidth, min_fps, max_fps, aggressiveness (ignored)
+     False, "yolo", None,
+     api_name="/process_video"
+ )
+
+ stats = json.loads(stats_json)
+ print(f"Processed {stats['total_frames']} frames")
+ ```
+
+ **Example - Dynamic Mode:**
+
+ ```python
+ output, stats_json = client.predict(
+     handle_file("video.mp4"),
+     "person",
+     "yolo",
+     "dynamic",
+     4, 0.3, 15.0,  # static settings (ignored)
+     750,           # target bandwidth 750 kbps
+     8,             # min FPS
+     30,            # max FPS
+     0.5,           # aggressiveness
+     True, "yolo", None,
+     api_name="/process_video"
+ )
+ ```
+
+ ---
+
+ ## Class Reference
+
+ ### Segmentation Methods
+
+ **Pixel-Perfect Segmentation:**
+
+ | Method | Description | Classes |
+ |--------|-------------|---------|
+ | `sam3` | Prompt-based (natural language) | Any text prompt |
+ | `yolo` | YOLO instance segmentation | 80 COCO classes |
+ | `segformer` | Cityscapes semantic segmentation | 19 classes |
+ | `mask2former` | Swin-based panoptic/semantic | 133 COCO / 150 ADE20K |
+ | `maskrcnn` | ResNet50-FPN instance segmentation | 80 COCO classes |
+
+ **Fast Detection-Based Segmentation:**
+
+ | Method | Description | Classes |
+ |--------|-------------|---------|
+ | `fake_yolo` | Fast bbox-based (YOLO + ByteTrack) | 80 COCO classes |
+ | `fake_yolo_botsort` | Fast bbox-based (YOLO + BoTSORT) | 80 COCO classes |
+ | `fake_detr` | Fast bbox-based (DETR + ByteTrack) | 80 COCO classes |
+ | `fake_fasterrcnn` | Fast bbox-based (Faster R-CNN + ByteTrack) | 80 COCO classes |
+ | `fake_retinanet` | Fast bbox-based (RetinaNet + ByteTrack) | 80 COCO classes |
+ | `fake_fcos` | Fast bbox-based (FCOS + ByteTrack) | 80 COCO classes |
+ | `fake_deformable_detr` | Fast bbox-based (Deformable DETR + ByteTrack) | 80 COCO classes |
+ | `fake_grounding_dino` | Fast bbox-based (Grounding DINO + ByteTrack) | Requires prompt |
+
+ **Note:** `fake_*` methods create rectangular masks from detection bounding boxes with object tracking. Faster than pixel-perfect segmentation, suitable for video when precise boundaries aren't critical.
+
+ ### Detection Methods
+
+ **Closed-Vocabulary (COCO pretrained):**
+
+ | Method | Description |
+ |--------|-------------|
+ | `yolo` | Ultralytics YOLO |
+ | `detr` | Facebook DETR |
+ | `faster_rcnn` | Faster R-CNN |
+ | `retinanet` | RetinaNet |
+ | `fcos` | FCOS |
+ | `ssd` | SSD300 |
+ | `deformable_detr` | Deformable DETR |
+
+ **Open-Vocabulary (requires `classes` parameter):**
+
+ | Method | Description |
+ |--------|-------------|
+ | `yolo_world` | YOLO-World |
+ | `grounding_dino` | Grounding DINO |
+
+ ### COCO Classes (80)
+
+ ```
+ person, bicycle, car, motorcycle, airplane, bus, train, truck, boat,
+ traffic light, fire hydrant, stop sign, parking meter, bench, bird, cat,
+ dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella,
+ handbag, tie, suitcase, frisbee, skis, snowboard, sports ball, kite,
+ baseball bat, baseball glove, skateboard, surfboard, tennis racket, bottle,
+ wine glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange,
+ broccoli, carrot, hot dog, pizza, donut, cake, chair, couch, potted plant,
+ bed, dining table, toilet, tv, laptop, mouse, remote, keyboard, cell phone,
+ microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors,
+ teddy bear, hair drier, toothbrush
+ ```
+
+ ### Cityscapes Classes (19)
+
+ ```
+ road, sidewalk, building, wall, fence, pole, traffic light, traffic sign,
+ vegetation, terrain, sky, person, rider, car, truck, bus, train, motorcycle,
+ bicycle
+ ```
+
+ ---
+
+ ## Error Handling
+
+ ```python
+ try:
+     result = client.predict(
+         handle_file("image.jpg"),
+         ...,
+         api_name="/endpoint"
+     )
+ except Exception as e:
+     print(f"API Error: {e}")
+ ```
+
+ **Common Errors:**
+
+ | Error | Cause | Solution |
+ |-------|-------|----------|
+ | Validation error for ImageData | Missing `handle_file()` | Wrap file paths with `handle_file()` |
+ | File does not exist | Invalid path | Check file path is correct |
+ | Empty detection classes | Open-vocab detector without classes | Provide classes for `yolo_world`, `grounding_dino` |
+ | GPU quota exceeded | HF Spaces limit | Wait and retry (see below) |
+
+ ---
+
+ ## Handling GPU Quota on HF Spaces
+
+ When using Hugging Face Spaces with ZeroGPU, you may encounter quota limits:
+
+ ```
+ You have exceeded your GPU quota (60s requested vs. 0s left). Try again in 0:05:30
+ ```
+
+ ### Automatic Retry with Backoff
+
+ ```python
+ import time
+ import re
+
+ def extract_wait_time(error_msg):
+     """Extract wait time from GPU quota error message."""
+     match = re.search(r'Try again in (\d+):(\d+)(?::(\d+))?', error_msg)
+     if match:
+         if match.group(3):  # HH:MM:SS
+             return int(match.group(1)) * 3600 + int(match.group(2)) * 60 + int(match.group(3))
+         else:  # MM:SS
+             return int(match.group(1)) * 60 + int(match.group(2))
+     return 60
+
+ def call_with_retry(client, *args, api_name, max_retries=5):
+     """Call API with exponential backoff retry."""
+     delay = 10
+
+     for attempt in range(max_retries):
+         try:
+             return client.predict(*args, api_name=api_name)
+         except Exception as e:
+             error_msg = str(e)
+             if "exceeded your GPU quota" in error_msg:
+                 wait_time = extract_wait_time(error_msg)
+                 actual_delay = max(delay, wait_time + 5)
+                 print(f"⏳ GPU quota exhausted. Waiting {actual_delay}s... (attempt {attempt + 1})")
+                 time.sleep(actual_delay)
+                 delay *= 2
+             else:
+                 raise
+     raise Exception("Max retries reached")
+
+ # Usage (eight positional args, matching the /process signature)
+ result = call_with_retry(
+     client,
+     handle_file("image.jpg"),
+     "car", "sam3", 4, 0.3, False, "yolo", "",
+     api_name="/process"
+ )
+ ```
+
+ ---
+
+ ## Using with cURL
+
+ ### Upload File First
+
+ ```bash
+ # Upload image
+ FILE_URL=$(curl -s -X POST http://localhost:7860/upload \
+   -F "files=@image.jpg" | \
+   python3 -c "import sys, json; print(json.load(sys.stdin)[0])")
+ ```
+
+ ### Call Endpoints
+
+ ```bash
+ # Segment
+ curl -X POST http://localhost:7860/api/segment \
+   -H "Content-Type: application/json" \
+   -d "{\"data\": [\"$FILE_URL\", \"car, person\", \"sam3\", false]}"
+
+ # Compress (no mask)
+ curl -X POST http://localhost:7860/api/compress \
+   -H "Content-Type: application/json" \
+   -d "{\"data\": [\"$FILE_URL\", null, 4, 0.3]}"
+
+ # Detect
+ curl -X POST http://localhost:7860/api/detect \
+   -H "Content-Type: application/json" \
+   -d "{\"data\": [\"$FILE_URL\", \"yolo\", \"\", 0.25]}"
+
+ # Full pipeline
+ curl -X POST http://localhost:7860/api/process \
+   -H "Content-Type: application/json" \
+   -d "{\"data\": [\"$FILE_URL\", \"car, person\", \"sam3\", 4, 0.3, true, \"yolo\", \"\"]}"
+ ```
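+
+ The same call can be made from Python without `gradio_client`, mirroring the JSON shape used above (a sketch; the exact response format depends on the Gradio version, so inspect the returned JSON before relying on its fields):
+
+ ```python
+ import requests
+
+ resp = requests.post(
+     "http://localhost:7860/api/detect",
+     json={"data": [file_url, "yolo", "", 0.25]},  # file_url from the upload step
+     timeout=300,
+ )
+ resp.raise_for_status()
+ print(resp.json())
+ ```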
+
+ ---
+
+ ## Performance Guide
+
+ ### Choosing Segmentation Methods
+
+ **Use Pixel-Perfect Segmentation when:**
+ - You need precise object boundaries
+ - Working with single images or small videos
+ - Quality is more important than speed
+ - Computing time/power is not constrained
+
+ **Use Fast Segmentation (fake_*) when:**
+ - Processing large videos or real-time streams
+ - Speed is critical (2-3x faster)
+ - Rectangular masks are acceptable
+ - Need temporal consistency (tracking maintains object IDs)
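+
+ A tiny helper encoding those rules of thumb (purely illustrative; it just returns a method name for the `segmentation_method` parameter):
+
+ ```python
+ def choose_seg_method(realtime: bool, precise_boundaries: bool,
+                       prompt_based: bool = False) -> str:
+     """Heuristic from the guidance above, not an official API."""
+     if realtime and not precise_boundaries:
+         return "fake_yolo"   # bbox masks + ByteTrack, fastest
+     if prompt_based:
+         return "sam3"        # natural-language prompts
+     return "yolo"            # pixel-perfect instance masks
+ ```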
+
+ ### Performance Benchmarks
+
+ **Video Processing (480p, 30 frames):**
+
+ | Method | Speed | Use Case |
+ |--------|-------|----------|
+ | `fake_yolo` | ~70 fps | Real-time video, fastest |
+ | `fake_yolo_botsort` | ~65 fps | Real-time with robust tracking |
+ | `fake_detr` | ~40 fps | Good speed + accuracy balance |
+ | `fake_fasterrcnn` | ~30 fps | Accurate detection |
+ | `yolo` (pixel-perfect) | ~30 fps | Instance segmentation |
+ | `sam3` | ~15 fps | Prompt-based, highest flexibility |
+ | `mask2former` | ~20 fps | Panoptic segmentation |
+
+ **Detection Performance (with batch support):**
+
+ | Detector | Single-Frame | Batch (30 frames) | Speedup |
+ |----------|--------------|-------------------|---------|
+ | YOLO26x | ~40 fps | ~70 fps | 1.75x |
+ | DETR | ~15 fps | ~40 fps | 2.67x |
+ | Faster R-CNN | ~12 fps | ~30 fps | 2.50x |
+
+ ### Example: Fast Video Processing
+
+ ```python
+ from gradio_client import Client, handle_file
+ import json
+ import time
+
+ client = Client("http://localhost:7860")
+
+ # Method 1: Fast fake segmentation (recommended for video)
+ start = time.time()
+ output1, stats1 = client.predict(
+     handle_file("long_video.mp4"),
+     "person, car",
+     "fake_yolo",  # Fast detection + tracking
+     "static",
+     4,
+     0.3,
+     15.0,
+     500, 5, 30, 0.5, False, "yolo", None,
+     api_name="/process_video"
+ )
+ fast_time = time.time() - start
+
+ # Method 2: Pixel-perfect segmentation
+ start = time.time()
+ output2, stats2 = client.predict(
+     handle_file("long_video.mp4"),
+     "person, car",
+     "yolo",  # Pixel-perfect YOLO26x-seg
+     "static",
+     4,
+     0.3,
+     15.0,
+     500, 5, 30, 0.5, False, "yolo", None,
+     api_name="/process_video"
+ )
+ perfect_time = time.time() - start
+
+ stats1_data = json.loads(stats1)
+ stats2_data = json.loads(stats2)
+
+ print(f"Fast segmentation: {fast_time:.2f}s")
+ print(f"Pixel-perfect: {perfect_time:.2f}s")
+ print(f"Speedup: {perfect_time/fast_time:.2f}x faster")
+ print(f"Compression ratio (fast): {stats1_data['compression_ratio']:.2f}x")
+ print(f"Compression ratio (perfect): {stats2_data['compression_ratio']:.2f}x")
+ ```
+
+ ### Example: Tracker Comparison
+
+ ```python
+ # Test different trackers with same detector
+ trackers = {
+     "ByteTrack (default)": "fake_yolo",
+     "BoTSORT": "fake_yolo_botsort",
+ }
+
+ for name, method in trackers.items():
+     output, stats = client.predict(
+         handle_file("test_video.mp4"),
+         "person",
+         method,
+         "static",
+         4, 0.3, 15.0,
+         500, 5, 30, 0.5, False, "yolo", None,
+         api_name="/process_video"
+     )
+
+     stats_data = json.loads(stats)
+     print(f"{name}: {stats_data['avg_roi_coverage']:.2f}% avg coverage")
+ ```
+
+ ---
921
+
922
+ ## Example Scripts
923
+
924
+ ### Batch Image Processing
925
+
926
+ ```python
927
+ from gradio_client import Client, handle_file
928
+ from pathlib import Path
929
+
930
+ client = Client("http://localhost:7860")
931
+ output_dir = Path("compressed_output")
932
+ output_dir.mkdir(exist_ok=True)
933
+
934
+ for img_path in Path("images").glob("*.jpg"):
935
+ print(f"Processing {img_path.name}...")
936
+
937
+ compressed, mask, bpp, ratio, coverage, _ = client.predict(
938
+ handle_file(str(img_path)),
939
+ "car, person",
940
+ "sam3",
941
+ 4, 0.3,
942
+ False, "", "",
943
+ api_name="/process"
944
+ )
945
+
946
+ # Save compressed image
947
+ output_path = output_dir / f"compressed_{img_path.name}"
948
+ with open(compressed, "rb") as src, open(output_path, "wb") as dst:
949
+ dst.write(src.read())
950
+
951
+ print(f" BPP: {bpp:.4f}, Ratio: {ratio:.2f}x, ROI: {coverage*100:.2f}%")
952
+ ```
953
+
954
+ ### Video Processing with Mask Caching
955
+
956
+ ```python
957
+ from gradio_client import Client, handle_file
958
+ import json
959
+
960
+ client = Client("http://localhost:7860")
961
+ video_path = "input_video.mp4"
962
+
963
+ # Step 1: Segment video (one-time cost)
964
+ mask_file, seg_stats = client.predict(
965
+ handle_file(video_path),
966
+ "person, car",
967
+ "sam3",
968
+ False, # return mask file
969
+ 15.0,
970
+ api_name="/segment_video"
971
+ )
972
+ print(f"Segmented video, masks saved to: {mask_file}")
973
+
974
+ # Step 2: Compress with different settings, reusing masks
975
+ for quality in [3, 4, 5]:
976
+ compressed, comp_stats = client.predict(
977
+ handle_file(video_path),
978
+ mask_file, # reuse cached masks
979
+ quality,
980
+ 0.3,
981
+ 15.0,
982
+ api_name="/compress_video"
983
+ )
984
+ stats = json.loads(comp_stats)
985
+ print(f"Quality {quality}: {stats['compression_ratio']}x compression")
986
+ ```
987
+
988
+ ### Detection Comparison (Original vs Compressed)
989
+
990
+ ```python
991
+ from gradio_client import Client, handle_file
992
+ import json
993
+
994
+ client = Client("http://localhost:7860")
995
+ image = "street_scene.jpg"
996
+
997
+ # Detect on original
998
+ _, dets_orig = client.predict(
999
+ handle_file(image), "yolo", "", 0.25, False,
1000
+ api_name="/detect"
1001
+ )
1002
+ orig_count = len(json.loads(dets_orig))
1003
+ print(f"Original: {orig_count} detections")
1004
+
1005
+ # Compress and detect
1006
+ compressed, _, bpp, ratio, _, dets_comp = client.predict(
1007
+ handle_file(image),
1008
+ "car, person, road",
1009
+ "sam3",
1010
+ 4, 0.3,
1011
+ True, "yolo", "",
1012
+ api_name="/process"
1013
+ )
1014
+ comp_count = len(json.loads(dets_comp))
1015
+
1016
+ retention = comp_count / orig_count * 100 if orig_count else 0
1017
+ print(f"Compressed ({ratio:.2f}x): {comp_count} detections")
1018
+ print(f"Detection retention: {retention:.1f}%")
1019
+ ```
1020
+
1021
+ ---
1022
+
1023
+ ## Additional Resources
1024
+
1025
+ - **Web UI**: Visit `http://localhost:7860` for the interactive interface
1026
+ - **GitHub**: See the repository for source code and examples
1027
+ - **Model Checkpoints**: Available in the `checkpoints/` directory
1028
+ - **Test Images**: Sample images in the `data/images/` directory
1029
+
README.md CHANGED
@@ -1,13 +1,513 @@
1
  ---
2
- title: Contextual Communication Demo 2
3
- emoji: 🦀
4
- colorFrom: purple
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 6.5.1
8
- python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Contextual Communication Demo
3
+ emoji: "📡"
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: "6.2.0"
 
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
198
+ # Contextual Communication Demo
199
+
200
+ An interactive demo for **contextual communication in bandwidth-degraded environments** (e.g., ISR collection from drones). The core idea is **context-aware compression**: transmit an extremely compact latent representation while ensuring the **decoded output remains useful for downstream decision-making** (e.g., object detection).
201
+
202
+ This repository implements **contextual spatial compression** for EO/IR-style imagery using an ROI-aware learned image compression model (TIC-style VAE) guided by segmentation masks.
203
+
204
+ ## Features
205
+
206
+ - **Contextual (ROI) compression**: Preserves fidelity in mission-relevant regions while aggressively compressing non-relevant background.
207
+ - **Mission-driven context extraction**: A mission prompt can be mapped to ROI masks via multiple segmentation strategies:
208
+ - **Class-based segmentation** (e.g., SegFormer / YOLO / Mask2Former / Mask R-CNN)
209
+ - **Prompt/referring segmentation** (SAM3)
210
+ - Optional **object detection overlays** to evaluate task retention on decoded outputs
211
+ - **Two operator knobs** for bandwidth adaptation (see the sketch after this list):
212
+ - **Background preservation** (`sigma`, 0.01–1.0): lower = more background degradation
213
+ - **Overall quality level** (checkpoint/lambda selection): higher = larger file / better reconstruction
214
+ - **Visualization**: Compare input vs decoded output and optionally highlight context regions.
215
+ - **CLI tools**: Scripts for segmentation, ROI compression, and before/after detection eval.
216
+
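+ A minimal sketch of how the two knobs map onto the compression call (this uses the `segmentation` and `vae` helpers documented under "Modular API" below; the image path is one of the bundled samples, and the PIL input type is an assumption):
+
+ ```python
+ from PIL import Image
+ from segmentation import create_segmenter
+ from vae import load_checkpoint, compress_image
+
+ image = Image.open("data/images/car/0016cf15fa4d4e16.jpg").convert("RGB")
+ mask = create_segmenter("yolo", device="cuda")(image, target_classes=["car"])
+
+ # Knob 1: overall quality level = which lambda checkpoint is loaded
+ model = load_checkpoint("checkpoints/tic_lambda_0.013.pth.tar", device="cuda")
+
+ # Knob 2: background preservation = sigma (lower = harsher background compression)
+ for sigma in (0.05, 0.3, 0.8):
+     out = compress_image(image, mask, model, sigma=sigma, device="cuda")
+     print(f"sigma={sigma}: bpp={out['bpp']:.4f}")
+ ```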
217
+ ## Setup
218
+
219
+ 1. **Install Dependencies**:
220
+ ```bash
221
+ pip install -r requirements.txt
222
+ ```
223
+
224
+ 2. **Model Checkpoints**:
225
+ Checkpoints are located in `checkpoints/` directory. Main checkpoint: `checkpoints/tic_lambda_0.0483.pth.tar`
226
+
227
+ By default, model weights/caches downloaded by detection/segmentation backends are also stored under `checkpoints/`
228
+ (Hugging Face models under `checkpoints/hf/`, torchvision weights under `checkpoints/torch/`).
229
+
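+ To sanity-check the setup, list the bundled LFS checkpoints (the five lambda values below are the ones shipped in this repo):
+
+ ```bash
+ ls -lh checkpoints/*.pth.tar
+ # tic_lambda_0.0035, tic_lambda_0.013, tic_lambda_0.025,
+ # tic_lambda_0.0483, tic_lambda_0.0932
+ ```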
230
+ ## Usage
231
+
232
+ ### Interactive Demo (Hugging Face Spaces / Local)
233
+
234
+ This repo includes a Gradio app intended for Hugging Face Spaces (`app_file: app.py`). To run locally:
235
+
236
+ ```bash
237
+ python app.py
238
+ ```
239
+
240
+ In the UI:
241
+
242
+ - Enter a **Mission** and choose a **Context Extraction Method (ROI)**.
243
+ - Tune the two knobs to match bandwidth constraints:
244
+ - **Transmission quality** (checkpoint selection)
245
+ - **Background preservation** ($\sigma$)
246
+ - Optionally enable **object detection overlays** to visualize task retention on the decoded image.
247
+
248
+ Note: the app includes a **Video** tab placeholder (inactive).
249
+
250
+ ### Contextual Spatial Compression (Images)
251
+
252
+ Run the compression script with an input image:
253
+
254
+ ```bash
255
+ python roi_compressor.py \
256
+ --input data/images/car/0016cf15fa4d4e16.jpg \
257
+ --output results/compressed.jpg \
258
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
259
+ --sigma 0.3 \
260
+ --seg-classes car \
261
+ --highlight
262
+ ```
263
+
264
+ **Arguments:**
265
+ - `--input`: Path to input image.
266
+ - `--output`: Path to save compressed image.
267
+ - `--checkpoint`: Path to model checkpoint.
268
+ - `--sigma`: Background quality factor (lower = more compression). Default: 0.3.
269
+ - `--lambda`: Rate-distortion tradeoff parameter (default: 0.0483).
270
+ - `--seg-method`: Segmentation method (`segformer`, `yolo`, `mask2former`, `maskrcnn`). Default: `segformer`.
271
+ - `--seg-classes`: List of classes to treat as ROI (e.g., `car`, `person`).
272
+ - `--highlight`: Save a comparison grid with ROI highlighted.
273
+
274
+ Tip: you can bypass segmentation by providing `--load-mask`.
275
+
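+ For example (a sketch reusing a mask previously written by `roi_segmenter.py`):
+
+ ```bash
+ python roi_compressor.py \
+   --input data/images/car/0016cf15fa4d4e16.jpg \
+   --output results/compressed.jpg \
+   --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
+   --sigma 0.3 \
+   --load-mask results/mask.png
+ ```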
276
+ ### Segmentation Only
277
+
278
+ Generate segmentation masks without compression:
279
+
280
+ ```bash
281
+ python roi_segmenter.py \
282
+ --input data/images/car/0016cf15fa4d4e16.jpg \
283
+ --output results/mask.png \
284
+ --method segformer \
285
+ --classes car \
286
+ --visualize
287
+ ```
288
+
289
+ Prompt-based segmentation (SAM3):
290
+
291
+ ```bash
292
+ python roi_segmenter.py \
293
+ --input data/images/car/0016cf15fa4d4e16.jpg \
294
+ --output results/mask.png \
295
+ --method sam3 \
296
+ --prompt "a car" \
297
+ --visualize
298
+ ```
299
+
300
+ ## Project Structure
301
+
302
+ ```
303
+ .
304
+ ├── app.py # Gradio demo (Hugging Face Spaces)
305
+ ├── README.md
306
+ ├── requirements.txt
307
+ ├── model_cache.py # Cache routing to `checkpoints/`
308
+ ├── examples.sh # Example CLI commands
309
+ ├── _segmentation_comparison.ipynb
310
+ ├── roi_compressor.py # CLI: contextual (ROI) image compression
311
+ ├── roi_segmenter.py # CLI: ROI mask generation
312
+ ├── roi_detection_eval.py # CLI: before/after detection retention
313
+ ├── checkpoints/ # Compression checkpoints + model caches
314
+ ├── data/images/ # Sample images
315
+ ├── segmentation/ # Segmenters + factory
316
+ ├── detection/ # Detectors + factory
317
+ └── vae/ # ROI-aware TIC model + compression utils
318
+ ```
319
+
320
+ ## Modular API
321
+
322
+ ### Using Segmentation Module
323
+
324
+ ```python
325
+ from segmentation import create_segmenter
326
+
327
+ # Create a segmenter
328
+ segmenter = create_segmenter('yolo', device='cuda', conf_threshold=0.3)
329
+
330
+ # Segment image
331
+ mask = segmenter(image, target_classes=['car', 'person'])
332
+ ```
333
+
334
+ ### Using Compression Module
335
+
336
+ ```python
337
+ from vae import load_checkpoint, compress_image
338
+ from PIL import Image
339
+
340
+ # Load model
341
+ model = load_checkpoint('checkpoints/tic_lambda_0.0483.pth.tar', device='cuda')
342
+
343
+ # Compress with ROI mask
344
+ result = compress_image(image, mask, model, sigma=0.3, device='cuda')
345
+ compressed_img = result['compressed']
346
+ bpp = result['bpp']
347
+ ```
+
348
+ ## Object Detection (New)
349
+
350
+ An extendable object detection module is available in `detection/` with multiple implemented backends:
351
+
352
+ - YOLO (Ultralytics)
353
+ - YOLO-World (Ultralytics, open-vocabulary)
354
+ - Faster R-CNN (torchvision)
355
+ - RetinaNet (torchvision)
356
+ - SSD (torchvision)
357
+ - FCOS (torchvision)
358
+ - DETR (transformers)
359
+ - Deformable DETR (transformers, if supported by your installed version)
360
+ - EfficientDet (optional, requires `effdet`)
361
+ - Grounding DINO (transformers, open-vocabulary)
362
+
363
+ Open-vocabulary detectors (YOLO-World / Grounding DINO) require text prompts/classes at runtime.
364
+
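+ A minimal usage sketch with the detector factory (it mirrors the docstrings in `detection/`; open-vocabulary backends take their classes as `prompts=` at call time):
+
+ ```python
+ from PIL import Image
+ from detection import create_detector, get_available_detectors
+
+ print(get_available_detectors())
+
+ image = Image.open("data/images/car/0016cf15fa4d4e16.jpg")
+
+ # Closed-vocabulary: the class list is fixed by the model
+ detector = create_detector("fasterrcnn", device="cuda")
+ for det in detector(image, conf_threshold=0.25):
+     print(det.label, det.score, det.bbox_xyxy)
+
+ # Open-vocabulary: classes come from text prompts
+ ovd = create_detector("grounding_dino", device="cuda")
+ dets = ovd(image, conf_threshold=0.25, prompts=["car", "person"])
+ ```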
365
+ ### Evaluate Detection Before/After ROI Compression
366
+
367
+ Compare an original image vs an already-compressed image:
368
+
369
+ ```bash
370
+ python roi_detection_eval.py \
371
+ --before data/images/car/0016cf15fa4d4e16.jpg \
372
+ --after results/compressed.jpg \
373
+ --detectors yolo fasterrcnn detr \
374
+ --viz-dir results/det_viz
375
+ ```
376
+
377
+ Or generate the "after" image via ROI compression and then evaluate:
378
+
379
+ ```bash
380
+ python roi_detection_eval.py \
381
+ --before data/images/car/0016cf15fa4d4e16.jpg \
382
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
383
+ --sigma 0.3 \
384
+ --seg-method yolo --seg-classes car \
385
+ --detectors yolo fasterrcnn \
386
+ --save-after results/after.jpg \
387
+ --viz-dir results/det_viz
389
+ ```
390
+
391
+ Open-vocabulary example (YOLO-World):
392
+
393
+ ```bash
394
+ python roi_detection_eval.py \
395
+ --before data/images/person/kodim04.png \
396
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
397
+ --sigma 0.3 \
398
+ --seg-method yolo --seg-classes person \
399
+ --detectors yolo_world \
400
+ --open-vocab-classes "person,car" \
401
+ --viz-dir results/det_viz
402
+ ```
403
+
404
+ Open-vocabulary example (Grounding DINO):
405
+
406
+ ```bash
407
+ python roi_detection_eval.py \
408
+ --before data/images/car/0016cf15fa4d4e16.jpg \
409
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
410
+ --sigma 0.3 \
411
+ --seg-method yolo --seg-classes car \
412
+ --detectors grounding_dino \
413
+ --open-vocab-classes "car,person" \
414
+ --viz-dir results/det_viz
415
+ ```
416
+
417
+ ## Programmatic API
418
+
419
+ The application exposes a Gradio API for programmatic access to all features:
420
+
421
+ ### Image API
422
+ - `/segment` - Segment image → mask or overlay
423
+ - `/compress` - Compress image with optional ROI mask
424
+ - `/detect` - Run object detection → JSON or overlay
425
+ - `/process` - Full pipeline: segment → compress → detect
426
+
427
+ ### Video API (Buffered)
428
+ - `/segment_video` - Segment video → mask file or overlay video
429
+ - `/compress_video` - Compress video with optional cached masks
430
+ - `/detect_video` - Run detection on video → JSON or overlay video
431
+ - `/process_video` - Full pipeline with static/dynamic modes
432
+
433
+ ### Video API (Streaming - NEW!)
434
+ - `/stream_process_video` - Stream compressed chunks progressively (HLS-style)
435
+ - `/stream_compress_video` - Stream chunks with pre-computed masks
436
+
437
+ **Key difference**: Streaming endpoints yield chunks as they're produced (low latency, ~1 second for first chunk) instead of buffering the entire video. Perfect for real-time streaming applications.
438
+
439
+ See [API.md](API.md) for complete documentation with examples.
440
+ See [STREAMING_API.md](STREAMING_API.md) for streaming API guide and comparison.
441
+
442
+ ### Quick Example
443
+
444
+ ```python
445
+ from gradio_client import Client, handle_file
+ import json
446
+
447
+ client = Client("http://localhost:7860")
448
+
449
+ # Image: segment → compress → detect
450
+ compressed, mask, bpp, ratio, coverage, detections = client.predict(
451
+ handle_file("image.jpg"),
452
+ "car, person", # mission prompt
453
+ "sam3", # ROI method
454
+ 4, # quality level (1-5)
455
+ 0.3, # sigma (background preservation)
456
+ True, # run detection
457
+ "yolo", # detection method
458
+ "", # detection classes
459
+ api_name="/process"
460
+ )
461
+
462
+ # Video: streaming compression (chunk-by-chunk)
463
+ chunk_stream = client.submit(
464
+ handle_file("video.mp4"),
465
+ "person, car",
466
+ "sam3", "static",
467
+ 4, 0.3, 15.0,
468
+ api_name="/stream_process_video"
469
+ )
470
+
471
+ for chunk_json in chunk_stream:
472
+ chunk = json.loads(chunk_json)
473
+ if chunk.get("status") == "complete":
474
+ break
475
+ print(f"Chunk {chunk['chunk_index']}: {len(chunk['frames'])} frames")
476
+ ```
477
+
478
+ ### JavaScript/Frontend Integration
479
+
480
+ **Yes, streaming works great with JavaScript!** The `@gradio/client` package fully supports async iterators for streaming:
481
+
482
+ ```javascript
483
+ import { Client } from "@gradio/client";
484
+
485
+ const client = await Client.connect("http://localhost:7860");
486
+ const stream = client.submit("/stream_process_video", {
487
+ video_path: videoFile,
488
+ prompt: "person, car",
489
+ segmentation_method: "sam3",
490
+ mode: "static",
491
+ quality: 4,
492
+ sigma: 0.3,
493
+ output_fps: 15.0,
494
+ frame_format: "jpeg",
495
+ frame_quality: 85
496
+ });
497
+
498
+ for await (const msg of stream) {
499
+ const chunk = JSON.parse(msg.data);
500
+ if (chunk.status === "complete") break;
501
+
502
+ // Display frames immediately
503
+ displayFrame(`data:image/jpeg;base64,${chunk.frames[0]}`);
504
+ }
505
+ ```
506
+
507
+ **Complete examples available:**
508
+ - [examples/streaming_demo.html](examples/streaming_demo.html) - Standalone HTML demo
509
+ - [examples/streaming_client.ts](examples/streaming_client.ts) - React/TypeScript/Vanilla JS examples
510
+
511
+ See [STREAMING_API.md](STREAMING_API.md) for a detailed streaming guide.
+
+ ## Notes
+
+ - OpenCV is included via `opencv-python-headless` (recommended for server/Spaces environments).
+ - Some backends download weights on first use; caches are routed under `checkpoints/`.
+ - Output directories like `results/` are created at runtime by the CLIs.
_segmentation_comparison.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
The diff for this file is too large to render. See raw diff
 
checkpoints/tic_lambda_0.0035.pth.tar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03f2627ba49ddc117c2235868e179e48cd7fbb343b0ab637f6bac575f447b44f
3
+ size 93931400
checkpoints/tic_lambda_0.013.pth.tar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:821df1b78110d8bb6d01809a55b64ea1da8f17ed5e893d33ed258ee180054385
3
+ size 93922614
checkpoints/tic_lambda_0.025.pth.tar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:867f34e9949cc02800d1c8eff48c4f67ff7436d0e8c8208aec572fd78d83affb
3
+ size 168141319
checkpoints/tic_lambda_0.0483.pth.tar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9686371eed833dc2a6fdd7c07cf3695290c482f53b75db5be662892a45a71430
3
+ size 168229460
checkpoints/tic_lambda_0.0932.pth.tar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b7ddfa2f5f1b5082d087135a806f11ae65e8f8da5f724c5e510db264d3816c0
3
+ size 168141383
detection/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ """Object detection module.
2
+
3
+ Provides an extendable interface + factory similar to `segmentation/`.
4
+
5
+ Detectors:
6
+ - yolo (Ultralytics)
7
+ - yolo_world (Ultralytics YOLO-World; open-vocabulary)
8
+ - fasterrcnn, retinanet, ssd, fcos (torchvision)
9
+ - efficientdet (optional: effdet)
10
+ - detr, deformable_detr (transformers)
11
+ - grounding_dino (transformers; open-vocabulary)
12
+
13
+ Tracking:
14
+ - SimpleTracker (IoU-based multi-object tracking)
15
+ """
16
+
17
+ from .base import BaseDetector, Detection
18
+ from .factory import create_detector, register_detector, get_available_detectors
19
+ from .tracker import SimpleTracker, Track, draw_tracks
20
+
21
+ __all__ = [
22
+ "BaseDetector",
23
+ "Detection",
24
+ "create_detector",
25
+ "register_detector",
26
+ "get_available_detectors",
27
+ "SimpleTracker",
28
+ "Track",
29
+ "draw_tracks",
30
+ ]
detection/base.py ADDED
@@ -0,0 +1,83 @@
1
+ """Abstract base class for object detection models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, List, Optional, Sequence, Union
8
+
9
+ from PIL import Image
10
+
11
+
12
+ @dataclass(frozen=True)
13
+ class Detection:
14
+ """Single detection result."""
15
+
16
+ label: str
17
+ score: float
18
+ # [x1, y1, x2, y2] in pixel coordinates
19
+ bbox_xyxy: List[float]
20
+
21
+
22
+ class BaseDetector(ABC):
23
+ """Common interface for all object detectors."""
24
+
25
+ def __init__(self, device: str = "cuda", **kwargs: Any):
26
+ self.device = device
27
+ self.model = None
28
+ self._is_loaded = False
29
+
30
+ @abstractmethod
31
+ def load_model(self) -> None:
32
+ """Load weights and prepare for inference."""
33
+
34
+ @abstractmethod
35
+ def detect(
36
+ self,
37
+ image: Image.Image,
38
+ conf_threshold: float = 0.25,
39
+ **kwargs: Any,
40
+ ) -> List[Detection]:
41
+ """Run detection and return a list of detections."""
42
+
43
+ @abstractmethod
44
+ def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
45
+ """Return the class list (or mapping) supported by this detector.
46
+
47
+ For open-vocabulary / prompt-based detectors, return None.
48
+ """
49
+
50
+ def ensure_loaded(self) -> None:
51
+ if not self._is_loaded:
52
+ self.load_model()
53
+ self._is_loaded = True
54
+
55
+ def __call__(
56
+ self,
57
+ image: Image.Image,
58
+ conf_threshold: float = 0.25,
59
+ **kwargs: Any,
60
+ ) -> List[Detection]:
61
+ self.ensure_loaded()
62
+ return self.detect(image=image, conf_threshold=conf_threshold, **kwargs)
63
+
64
+ def _classes_to_list(self) -> List[str]:
65
+ avail = self.get_available_classes()
66
+ if avail is None:
67
+ return []
68
+ if isinstance(avail, dict):
69
+ return list(avail.keys())
70
+ if isinstance(avail, (list, tuple, set)):
71
+ return list(avail)
72
+ try:
73
+ return list(avail) # type: ignore[arg-type]
74
+ except Exception:
75
+ return []
76
+
77
+ def supports_label(self, label: str) -> bool:
78
+ """Best-effort check whether the detector has a given label."""
79
+
80
+ classes = [c.lower() for c in self._classes_to_list()]
81
+ if not classes:
82
+ return True
83
+ return label.lower() in classes
detection/bytetrack.py ADDED
@@ -0,0 +1,358 @@
1
+ """ByteTrack: Multi-object tracker using high/low confidence detection matching.
2
+
3
+ Based on: https://github.com/ifzhang/ByteTrack
4
+ Simple, fast, and strong multi-object tracker without ReID features.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import Dict, List, Optional, Tuple
11
+
12
+ import numpy as np
13
+
14
+
15
+ def linear_assignment(cost_matrix: np.ndarray) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
16
+ """Solve linear assignment problem using LAP (Hungarian algorithm).
17
+
18
+ Args:
19
+ cost_matrix: Cost matrix (N x M)
20
+
21
+ Returns:
22
+ matches: Array of (row, col) pairs
23
+ unmatched: Tuple of (unmatched_rows, unmatched_cols)
24
+ """
25
+ try:
26
+ import lap
27
+ _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=1e6)
28
+ matches = np.array([[idx, x[idx]] for idx in range(len(x)) if x[idx] >= 0])
29
+ unmatched_a = np.where(x < 0)[0]
30
+ unmatched_b = np.where(y < 0)[0]
31
+ return matches, (unmatched_a, unmatched_b)
32
+ except ImportError:
33
+ # Fallback to greedy matching
34
+ return _greedy_assignment(cost_matrix)
35
+
36
+
37
+ def _greedy_assignment(cost_matrix: np.ndarray) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
38
+ """Greedy assignment fallback."""
39
+ matches = []
40
+ matched_rows = set()
41
+ matched_cols = set()
42
+
43
+ n_rows, n_cols = cost_matrix.shape
44
+ pairs = []
45
+ for i in range(n_rows):
46
+ for j in range(n_cols):
47
+ pairs.append((cost_matrix[i, j], i, j))
48
+
49
+ pairs.sort(key=lambda x: x[0])
50
+
51
+ for cost, i, j in pairs:
52
+ if i not in matched_rows and j not in matched_cols:
53
+ matches.append([i, j])
54
+ matched_rows.add(i)
55
+ matched_cols.add(j)
56
+
57
+ unmatched_rows = np.array([i for i in range(n_rows) if i not in matched_rows])
58
+ unmatched_cols = np.array([j for j in range(n_cols) if j not in matched_cols])
59
+
60
+ return np.array(matches) if matches else np.empty((0, 2)), (unmatched_rows, unmatched_cols)
61
+
62
+
63
+ @dataclass
64
+ class STrack:
65
+ """Single track for ByteTrack."""
66
+
67
+ track_id: int
68
+ label: str
69
+ tlbr: np.ndarray # [x1, y1, x2, y2]
70
+ score: float
71
+
72
+ # State
73
+ state: str = "tracked" # tracked, lost, removed
74
+ frame_id: int = 0
75
+ tracklet_len: int = 0
76
+
77
+ # History
78
+ history: List[Tuple[int, np.ndarray, float]] = field(default_factory=list)
79
+
80
+ @property
81
+ def tlwh(self) -> np.ndarray:
82
+ """Top-left-width-height format."""
83
+ x1, y1, x2, y2 = self.tlbr
84
+ return np.array([x1, y1, x2 - x1, y2 - y1])
85
+
86
+ def activate(self, frame_id: int):
87
+ """Activate new track."""
88
+ self.track_id = self.next_id()
89
+ self.tracklet_len = 0
90
+ self.state = "tracked"
91
+ self.frame_id = frame_id
92
+ self.history = [(frame_id, self.tlbr.copy(), self.score)]
93
+
94
+ def re_activate(self, new_track: 'STrack', frame_id: int):
95
+ """Reactivate lost track."""
96
+ self.tlbr = new_track.tlbr
97
+ self.score = new_track.score
98
+ self.tracklet_len = 0
99
+ self.state = "tracked"
100
+ self.frame_id = frame_id
101
+ self.history.append((frame_id, self.tlbr.copy(), self.score))
102
+
103
+ def update(self, new_track: 'STrack', frame_id: int):
104
+ """Update with new detection."""
105
+ self.tlbr = new_track.tlbr
106
+ self.score = new_track.score
107
+ self.tracklet_len += 1
108
+ self.state = "tracked"
109
+ self.frame_id = frame_id
110
+ self.history.append((frame_id, self.tlbr.copy(), self.score))
111
+
112
+ def mark_lost(self):
113
+ """Mark as lost."""
114
+ self.state = "lost"
115
+
116
+ def mark_removed(self):
117
+ """Mark as removed."""
118
+ self.state = "removed"
119
+
120
+ _count = 0
121
+
122
+ @staticmethod
123
+ def next_id() -> int:
124
+ STrack._count += 1
125
+ return STrack._count
126
+
127
+ @staticmethod
128
+ def reset_id():
129
+ STrack._count = 0
130
+
131
+
132
+ class ByteTracker:
133
+ """ByteTrack: Multi-object tracker using high/low confidence matching.
134
+
135
+ Key features:
136
+ - Two-stage matching with high/low confidence detections
137
+ - Hungarian algorithm for optimal assignment
138
+ - No ReID features needed (fast and simple)
139
+ """
140
+
141
+ def __init__(
142
+ self,
143
+ track_thresh: float = 0.5,
144
+ match_thresh: float = 0.8,
145
+ track_buffer: int = 30,
146
+ frame_rate: int = 30,
147
+ ):
148
+ """
149
+ Args:
150
+ track_thresh: High confidence threshold for first matching
151
+ match_thresh: IoU threshold for matching
152
+ track_buffer: Frames to keep lost tracks
153
+ frame_rate: Video frame rate
154
+ """
155
+ self.track_thresh = track_thresh
156
+ self.match_thresh = match_thresh
157
+ self.track_buffer = track_buffer
158
+ self.frame_rate = frame_rate
159
+
160
+ self.tracked_stracks: List[STrack] = []
161
+ self.lost_stracks: List[STrack] = []
162
+ self.removed_stracks: List[STrack] = []
163
+
164
+ self.frame_id = 0
165
+ self.max_time_lost = int(frame_rate / 30.0 * track_buffer)
166
+
167
+ def reset(self):
168
+ """Reset tracker state."""
169
+ self.tracked_stracks = []
170
+ self.lost_stracks = []
171
+ self.removed_stracks = []
172
+ self.frame_id = 0
173
+ STrack.reset_id()
174
+
175
+ def update(self, detections: List[Dict]) -> List[Dict]:
176
+ """Update with new detections.
177
+
178
+ Args:
179
+ detections: List of dicts with label, score, bbox_xyxy
180
+
181
+ Returns:
182
+ List of track dicts with track_id, label, bbox_xyxy, score
183
+ """
184
+ self.frame_id += 1
185
+ activated_stracks = []
186
+ refind_stracks = []
187
+ lost_stracks = []
188
+ removed_stracks = []
189
+
190
+ # Separate high and low confidence detections
191
+ remain_high_inds = [i for i, d in enumerate(detections) if d["score"] >= self.track_thresh]
192
+ remain_low_inds = [i for i, d in enumerate(detections) if d["score"] < self.track_thresh]
193
+
194
+ dets_high = [detections[i] for i in remain_high_inds]
195
+ dets_low = [detections[i] for i in remain_low_inds]
196
+
197
+ # Convert to STrack format
198
+ detections_high = [self._det_to_strack(d) for d in dets_high]
199
+ detections_low = [self._det_to_strack(d) for d in dets_low]
200
+
201
+ # ---- Step 1: Match high-confidence detections with tracked tracks ----
202
+ strack_pool = self.tracked_stracks
203
+
204
+ # Compute IoU distance
205
+ dists = self._iou_distance(strack_pool, detections_high)
206
+
207
+ # Hungarian matching
208
+ matches, u_track, u_detection = self._matching(dists, strack_pool, detections_high)
209
+
210
+ # Update matched tracks
211
+ for itracked, idet in matches:
212
+ track = strack_pool[itracked]
213
+ det = detections_high[idet]
214
+ track.update(det, self.frame_id)
215
+ activated_stracks.append(track)
216
+
217
+ # ---- Step 2: Match unmatched tracks with low-confidence detections ----
218
+ # Second association uses all remaining low-confidence detections
219
+ r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == "tracked"]
220
+
221
+ dists = self._iou_distance(r_tracked_stracks, detections_low)
222
+ matches, u_track, u_detection_low = self._matching(dists, r_tracked_stracks, detections_low)
223
+
224
+ for itracked, idet in matches:
225
+ track = r_tracked_stracks[itracked]
226
+ det = detections_low[idet]
227
+ track.update(det, self.frame_id)
228
+ activated_stracks.append(track)
229
+
230
+ # Unmatched tracks become lost
231
+ for it in u_track:
232
+ track = r_tracked_stracks[it]
233
+ track.mark_lost()
234
+ lost_stracks.append(track)
235
+
236
+ # ---- Step 3: Match unmatched high-confidence detections with lost tracks ----
237
+ detections_high_second = [detections_high[i] for i in u_detection]
238
+ dists = self._iou_distance(self.lost_stracks, detections_high_second)
239
+ matches, u_lost, u_detection = self._matching(dists, self.lost_stracks, detections_high_second)
240
+
241
+ for ilost, idet in matches:
242
+ track = self.lost_stracks[ilost]
243
+ det = detections_high_second[idet]
244
+ track.re_activate(det, self.frame_id)
245
+ refind_stracks.append(track)
246
+
247
+ # ---- Step 4: Initialize new tracks ----
248
+ for inew in u_detection:
249
+ track = detections_high_second[inew]
250
+ if track.score >= self.track_thresh:
251
+ track.activate(self.frame_id)
252
+ activated_stracks.append(track)
253
+
254
+ # ---- Step 5: Remove long-lost tracks ----
255
+ for track in self.lost_stracks:
256
+ if self.frame_id - track.frame_id > self.max_time_lost:
257
+ track.mark_removed()
258
+ removed_stracks.append(track)
259
+
260
+ # Update state lists
262
+ self.tracked_stracks = activated_stracks + refind_stracks
263
+ self.lost_stracks = [t for t in self.lost_stracks if t.state == "lost"]
264
+ self.lost_stracks.extend(lost_stracks)
265
+ self.lost_stracks = [t for t in self.lost_stracks if t not in removed_stracks]
266
+ self.removed_stracks.extend(removed_stracks)
267
+
268
+ # Convert to output format
269
+ return self._stracks_to_output(self.tracked_stracks)
270
+
271
+ def _det_to_strack(self, det: Dict) -> STrack:
272
+ """Convert detection dict to STrack."""
273
+ return STrack(
274
+ track_id=-1,
275
+ label=det["label"],
276
+ tlbr=np.array(det["bbox_xyxy"]),
277
+ score=det["score"],
278
+ )
279
+
280
+ def _iou_distance(self, atracks: List[STrack], btracks: List[STrack]) -> np.ndarray:
281
+ """Compute IoU distance matrix."""
282
+ if not atracks or not btracks:
283
+ return np.zeros((len(atracks), len(btracks)))
284
+
285
+ atlbrs = np.array([track.tlbr for track in atracks])
286
+ btlbrs = np.array([track.tlbr for track in btracks])
287
+
288
+ ious = self._batch_iou(atlbrs, btlbrs)
289
+ cost_matrix = 1 - ious
290
+
291
+ return cost_matrix
292
+
293
+ def _batch_iou(self, boxes_a: np.ndarray, boxes_b: np.ndarray) -> np.ndarray:
294
+ """Batch IoU computation."""
295
+ area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
296
+ area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
297
+
298
+ iw = np.minimum(boxes_a[:, None, 2], boxes_b[:, 2]) - np.maximum(boxes_a[:, None, 0], boxes_b[:, 0])
299
+ ih = np.minimum(boxes_a[:, None, 3], boxes_b[:, 3]) - np.maximum(boxes_a[:, None, 1], boxes_b[:, 1])
300
+
301
+ iw = np.maximum(iw, 0)
302
+ ih = np.maximum(ih, 0)
303
+
304
+ inter = iw * ih
305
+ union = area_a[:, None] + area_b - inter
306
+
307
+ ious = inter / np.maximum(union, 1e-6)
308
+ return ious
309
+
310
+ def _matching(
311
+ self,
312
+ dists: np.ndarray,
313
+ atracks: List[STrack],
314
+ btracks: List[STrack],
315
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
316
+ """Match tracks using Hungarian algorithm."""
317
+ if not atracks or not btracks:
318
+ return np.empty((0, 2), dtype=int), np.arange(len(atracks)), np.arange(len(btracks))
319
+
320
+ # Filter by threshold
321
+ dists[dists > 1 - self.match_thresh] = 1e6
322
+
323
+ matches, (u_track, u_detection) = linear_assignment(dists)
324
+
325
+ return matches, u_track, u_detection
326
+
327
+ def _stracks_to_output(self, stracks: List[STrack]) -> List[Dict]:
328
+ """Convert STracks to output dict format."""
329
+ result = []
330
+ for track in stracks:
331
+ result.append({
332
+ "track_id": track.track_id,
333
+ "label": track.label,
334
+ "bbox_xyxy": track.tlbr.tolist(),
335
+ "score": track.score,
336
+ "frame_id": track.frame_id,
337
+ "tracklet_len": track.tracklet_len,
338
+ })
339
+ return result
340
+
341
+
342
+ class BoTSORT(ByteTracker):
343
+ """BoTSORT: ByteTrack with camera motion compensation and ReID.
344
+
345
+ For simplicity, this is a lightweight version without ReID features.
346
+ Adds Kalman filter for state prediction over ByteTrack.
347
+ """
348
+
349
+ def __init__(
350
+ self,
351
+ track_thresh: float = 0.5,
352
+ match_thresh: float = 0.8,
353
+ track_buffer: int = 30,
354
+ frame_rate: int = 30,
355
+ ):
356
+ super().__init__(track_thresh, match_thresh, track_buffer, frame_rate)
357
+ # BoTSORT would add Kalman filter here, but for detection-based tracking
358
+ # we keep it simple and inherit ByteTrack behavior
detection/detr.py ADDED
@@ -0,0 +1,215 @@
1
+ """DETR and Deformable DETR via Hugging Face Transformers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, List, Optional, Union
6
+
7
+ import torch
8
+ from PIL import Image
9
+
10
+ from .base import BaseDetector, Detection
11
+ from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs
12
+
13
+
14
+ class DETRDetector(BaseDetector):
15
+ def __init__(self, device: str = "cuda", model_name: str = "facebook/detr-resnet-50", **kwargs):
16
+ super().__init__(device=device, **kwargs)
17
+ self.model_name = model_name
18
+ self._device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
19
+ self.processor = None
20
+ self._id2label: Dict[int, str] = {}
21
+
22
+ def load_model(self) -> None:
23
+ from transformers import DetrForObjectDetection, DetrImageProcessor
24
+
25
+ ensure_default_checkpoint_dirs()
26
+ cache_dir = str(hf_cache_dir())
27
+ self.processor = DetrImageProcessor.from_pretrained(self.model_name, cache_dir=cache_dir)
28
+ self.model = DetrForObjectDetection.from_pretrained(self.model_name, cache_dir=cache_dir).to(self._device).eval()
29
+ self._id2label = dict(getattr(self.model.config, "id2label", {}) or {})
30
+
31
+ def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
32
+ if not self._id2label:
33
+ return None
34
+ return {name: int(i) for i, name in self._id2label.items()}
35
+
36
+ def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs) -> List[Detection]:
37
+ assert self.model is not None and self.processor is not None
38
+
39
+ img = image.convert("RGB")
40
+ inputs = self.processor(images=img, return_tensors="pt").to(self._device)
41
+
42
+ with torch.no_grad():
43
+ outputs = self.model(**inputs)
44
+
45
+ target_sizes = torch.tensor([img.size[::-1]], device=self._device) # (h, w)
46
+ results = self.processor.post_process_object_detection(
47
+ outputs,
48
+ threshold=float(conf_threshold),
49
+ target_sizes=target_sizes,
50
+ )[0]
51
+
52
+ dets: List[Detection] = []
53
+ for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
54
+ lid = int(label_id)
55
+ label = self._id2label.get(lid, str(lid))
56
+ b = [float(v) for v in box.tolist()]
57
+ dets.append(Detection(label=label, score=float(score), bbox_xyxy=b))
58
+ return dets
59
+
60
+ def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs) -> List[List[Detection]]:
61
+ """Batch detection for video processing - faster than frame-by-frame.
62
+
63
+ Args:
64
+ images: List of PIL Images
65
+ conf_threshold: Confidence threshold
66
+
67
+ Returns:
68
+ List of detection lists, one per image
69
+ """
70
+ assert self.model is not None and self.processor is not None
71
+
72
+ if not images:
73
+ return []
74
+
75
+ # Convert all images to RGB
76
+ imgs = [img.convert("RGB") for img in images]
77
+
78
+ # Batch process images
79
+ inputs = self.processor(images=imgs, return_tensors="pt").to(self._device)
80
+
81
+ with torch.no_grad():
82
+ outputs = self.model(**inputs)
83
+
84
+ # Target sizes for each image (h, w)
85
+ target_sizes = torch.tensor([img.size[::-1] for img in imgs], device=self._device)
86
+ results = self.processor.post_process_object_detection(
87
+ outputs,
88
+ threshold=float(conf_threshold),
89
+ target_sizes=target_sizes,
90
+ )
91
+
92
+ # Parse results for each image
93
+ all_detections = []
94
+ for result in results:
95
+ frame_dets = []
96
+ for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
97
+ lid = int(label_id)
98
+ label = self._id2label.get(lid, str(lid))
99
+ b = [float(v) for v in box.tolist()]
100
+ frame_dets.append(Detection(label=label, score=float(score), bbox_xyxy=b))
101
+ all_detections.append(frame_dets)
102
+
103
+ return all_detections
104
+
105
+
106
+ class DeformableDETRDetector(BaseDetector):
107
+ """Deformable DETR wrapper.
108
+
109
+ This relies on transformers' Deformable DETR implementations and checkpoints.
110
+ If your transformers build doesn't include Deformable DETR, this will raise.
111
+
112
+ Default checkpoint is `SenseTime/deformable-detr`.
113
+ """
114
+
115
+ def __init__(self, device: str = "cuda", model_name: str = "SenseTime/deformable-detr", **kwargs):
116
+ super().__init__(device=device, **kwargs)
117
+ self.model_name = model_name
118
+ self._device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
119
+ self.processor = None
120
+ self._id2label: Dict[int, str] = {}
121
+
122
+ def load_model(self) -> None:
123
+ try:
124
+ from transformers import AutoImageProcessor, AutoModelForObjectDetection
125
+ except Exception as e:
126
+ raise ImportError("transformers is required for Deformable DETR") from e
127
+
128
+ ensure_default_checkpoint_dirs()
129
+ cache_dir = str(hf_cache_dir())
130
+ self.processor = AutoImageProcessor.from_pretrained(self.model_name, cache_dir=cache_dir)
131
+ self.model = AutoModelForObjectDetection.from_pretrained(self.model_name, cache_dir=cache_dir).to(self._device).eval()
132
+ self._id2label = dict(getattr(self.model.config, "id2label", {}) or {})
133
+
134
+ def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
135
+ if not self._id2label:
136
+ return None
137
+ return {name: int(i) for i, name in self._id2label.items()}
138
+
139
+ def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs) -> List[Detection]:
140
+ assert self.model is not None and self.processor is not None
141
+
142
+ img = image.convert("RGB")
143
+ inputs = self.processor(images=img, return_tensors="pt").to(self._device)
144
+
145
+ with torch.no_grad():
146
+ outputs = self.model(**inputs)
147
+
148
+ # Prefer standard post_process if provided
149
+ if hasattr(self.processor, "post_process_object_detection"):
150
+ target_sizes = torch.tensor([img.size[::-1]], device=self._device)
151
+ results = self.processor.post_process_object_detection(
152
+ outputs,
153
+ threshold=float(conf_threshold),
154
+ target_sizes=target_sizes,
155
+ )[0]
156
+
157
+ dets: List[Detection] = []
158
+ for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
159
+ lid = int(label_id)
160
+ label = self._id2label.get(lid, str(lid))
161
+ b = [float(v) for v in box.tolist()]
162
+ dets.append(Detection(label=label, score=float(score), bbox_xyxy=b))
163
+ return dets
164
+
165
+ # Fallback: no post_process available
166
+ raise RuntimeError("This Deformable DETR processor does not support post_process_object_detection")
167
+
168
+ def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs) -> List[List[Detection]]:
169
+ """Batch detection for video processing - faster than frame-by-frame.
170
+
171
+ Args:
172
+ images: List of PIL Images
173
+ conf_threshold: Confidence threshold
174
+
175
+ Returns:
176
+ List of detection lists, one per image
177
+ """
178
+ assert self.model is not None and self.processor is not None
179
+
180
+ if not images:
181
+ return []
182
+
183
+ # Convert all images to RGB
184
+ imgs = [img.convert("RGB") for img in images]
185
+
186
+ # Batch process images
187
+ inputs = self.processor(images=imgs, return_tensors="pt").to(self._device)
188
+
189
+ with torch.no_grad():
190
+ outputs = self.model(**inputs)
191
+
192
+ # Post-process requires method support
193
+ if not hasattr(self.processor, "post_process_object_detection"):
194
+ raise RuntimeError("This Deformable DETR processor does not support post_process_object_detection")
195
+
196
+ # Target sizes for each image (h, w)
197
+ target_sizes = torch.tensor([img.size[::-1] for img in imgs], device=self._device)
198
+ results = self.processor.post_process_object_detection(
199
+ outputs,
200
+ threshold=float(conf_threshold),
201
+ target_sizes=target_sizes,
202
+ )
203
+
204
+ # Parse results for each image
205
+ all_detections = []
206
+ for result in results:
207
+ frame_dets = []
208
+ for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
209
+ lid = int(label_id)
210
+ label = self._id2label.get(lid, str(lid))
211
+ b = [float(v) for v in box.tolist()]
212
+ frame_dets.append(Detection(label=label, score=float(score), bbox_xyxy=b))
213
+ all_detections.append(frame_dets)
214
+
215
+ return all_detections
detection/factory.py ADDED
@@ -0,0 +1,50 @@
1
+ """Factory + registry for object detectors."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, List, Type
6
+
7
+ from .base import BaseDetector
8
+ from .yolo import YOLODetector
9
+ from .torchvision_detectors import (
10
+ FasterRCNNDetector,
11
+ RetinaNetDetector,
12
+ SSDDetector,
13
+ EfficientDetDetector,
14
+ FCOSDetector,
15
+ )
16
+ from .detr import DETRDetector, DeformableDETRDetector
17
+ from .grounding_dino import GroundingDINODetector
18
+ from .yolo_world import YOLOWorldDetector
19
+
20
+
21
+ DETECTOR_REGISTRY: Dict[str, Type[BaseDetector]] = {
22
+ "yolo": YOLODetector,
23
+ "yolo_world": YOLOWorldDetector,
24
+ "fasterrcnn": FasterRCNNDetector,
25
+ "retinanet": RetinaNetDetector,
26
+ "ssd": SSDDetector,
27
+ "efficientdet": EfficientDetDetector,
28
+ "fcos": FCOSDetector,
29
+ "detr": DETRDetector,
30
+ "deformable_detr": DeformableDETRDetector,
31
+ "grounding_dino": GroundingDINODetector,
32
+ }
33
+
34
+
35
+ def register_detector(name: str, detector_class: Type[BaseDetector]) -> None:
36
+ if not issubclass(detector_class, BaseDetector):
37
+ raise ValueError(f"{detector_class} must extend BaseDetector")
38
+ DETECTOR_REGISTRY[name.lower()] = detector_class
39
+
40
+
41
+ def create_detector(method: str, device: str = "cuda", **kwargs) -> BaseDetector:
42
+ method_lower = method.lower()
43
+ if method_lower not in DETECTOR_REGISTRY:
44
+ available = ", ".join(sorted(DETECTOR_REGISTRY.keys()))
45
+ raise ValueError(f"Unknown detector: '{method}'. Available: {available}")
46
+ return DETECTOR_REGISTRY[method_lower](device=device, **kwargs)
47
+
48
+
49
+ def get_available_detectors() -> List[str]:
50
+ return sorted(DETECTOR_REGISTRY.keys())
detection/grounding_dino.py ADDED
@@ -0,0 +1,215 @@
1
+ """Grounding DINO open-vocabulary object detection via Hugging Face Transformers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, List, Optional, Sequence, Union
6
+
7
+ import torch
8
+ from PIL import Image
9
+
10
+ from .base import BaseDetector, Detection
11
+ from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs
12
+
13
+
14
+ def _normalize_prompts(prompts: Optional[Union[str, Sequence[str]]]) -> Optional[List[str]]:
15
+ if prompts is None:
16
+ return None
17
+ if isinstance(prompts, str):
18
+ # allow comma-separated convenience
19
+ parts = [p.strip() for p in prompts.split(",")]
20
+ parts = [p for p in parts if p]
21
+ return parts or None
22
+ out = [str(p).strip() for p in prompts]
23
+ out = [p for p in out if p]
24
+ return out or None
25
+
26
+
27
+ class GroundingDINODetector(BaseDetector):
28
+ """Grounding DINO wrapper.
29
+
30
+ Usage:
31
+ detector = create_detector('grounding_dino', device='cuda')
32
+ dets = detector(image, prompts=['person', 'car'])
33
+
34
+ Notes:
35
+ - This is an open-vocabulary model; `get_available_classes()` returns None.
36
+ - Pass prompts via `prompts=` or `classes=`; you may also pass raw `text=`.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ device: str = "cuda",
42
+ model_name: str = "IDEA-Research/grounding-dino-base",
43
+ **kwargs: Any,
44
+ ):
45
+ super().__init__(device=device, **kwargs)
46
+ self.model_name = model_name
47
+ self._device = torch.device(device if (torch.cuda.is_available() or device == "cpu") else "cpu")
48
+ self.processor = None
49
+
50
+ def load_model(self) -> None:
51
+ try:
52
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
53
+ except Exception as e:
54
+ raise ImportError(
55
+ "transformers>=4.36 is required for Grounding DINO (AutoProcessor + AutoModelForZeroShotObjectDetection)"
56
+ ) from e
57
+
58
+ ensure_default_checkpoint_dirs()
59
+ cache_dir = str(hf_cache_dir())
60
+ self.processor = AutoProcessor.from_pretrained(self.model_name, cache_dir=cache_dir)
61
+ self.model = AutoModelForZeroShotObjectDetection.from_pretrained(self.model_name, cache_dir=cache_dir).to(self._device).eval()
62
+
63
+ def get_available_classes(self):
64
+ return None
65
+
66
+ def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs: Any) -> List[Detection]:
67
+ assert self.model is not None and self.processor is not None
68
+
69
+ img = image.convert("RGB")
70
+ prompts = _normalize_prompts(kwargs.get("prompts") or kwargs.get("classes") or kwargs.get("labels"))
71
+ text: Optional[str] = kwargs.get("text")
72
+
73
+ if text is None:
74
+ if not prompts:
75
+ raise ValueError("GroundingDINO requires `prompts`/`classes` (or raw `text`) for open-vocabulary detection")
76
+ # GroundingDINO expects a single string; period-delimited works well
77
+ text = " . ".join(prompts)
78
+ if not text.endswith("."):
79
+ text = text + " ."
80
+
81
+ threshold = float(kwargs.get("threshold", conf_threshold))
82
+
83
+ inputs = self.processor(images=img, text=text, return_tensors="pt")
84
+ inputs = {k: v.to(self._device) for k, v in inputs.items()}
85
+
86
+ with torch.no_grad():
87
+ outputs = self.model(**inputs)
88
+
89
+ target_sizes = torch.tensor([img.size[::-1]], device=self._device) # (h, w)
90
+
91
+ # Prefer the processor's helper if available.
92
+ if hasattr(self.processor, "post_process_grounded_object_detection"):
93
+ # Move input_ids to CPU for post-processing (some versions have device mismatch bugs)
94
+ input_ids_for_postprocess = inputs["input_ids"].cpu()
95
+ results = self.processor.post_process_grounded_object_detection(
96
+ outputs,
97
+ input_ids_for_postprocess,
98
+ threshold=threshold,
99
+ target_sizes=target_sizes.cpu(),
100
+ )[0]
101
+ else:
102
+ raise RuntimeError("This processor does not support post_process_grounded_object_detection")
103
+
104
+ boxes = results.get("boxes")
105
+ scores = results.get("scores")
106
+ text_labels = results.get("text_labels") or results.get("labels")
107
+
108
+ if boxes is None or scores is None:
109
+ return []
110
+
111
+ dets: List[Detection] = []
112
+ for i in range(int(scores.shape[0])):
113
+ score = float(scores[i])
114
+ box = boxes[i]
115
+ b = [float(v) for v in box.tolist()]
116
+
117
+ label: str
118
+ if isinstance(text_labels, (list, tuple)) and i < len(text_labels):
119
+ label = str(text_labels[i])
120
+ elif torch.is_tensor(text_labels) and i < int(text_labels.shape[0]):
121
+ lid = int(text_labels[i])
122
+ if prompts and 0 <= lid < len(prompts):
123
+ label = prompts[lid]
124
+ else:
125
+ label = str(lid)
126
+ else:
127
+ label = "object"
128
+
129
+ dets.append(Detection(label=label, score=score, bbox_xyxy=b))
130
+
131
+ return dets
132
+
133
+ def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs: Any) -> List[List[Detection]]:
134
+ """Batch detection for video processing.
135
+
136
+ Note: Grounding DINO requires the same text prompt for all images in batch.
137
+
138
+ Args:
139
+ images: List of PIL Images
140
+ conf_threshold: Confidence threshold
141
+ **kwargs: Should include 'prompts'/'classes' or 'text'
142
+
143
+ Returns:
144
+ List of detection lists, one per image
145
+ """
146
+ assert self.model is not None and self.processor is not None
147
+
148
+ if not images:
149
+ return []
150
+
151
+ imgs = [img.convert("RGB") for img in images]
152
+ prompts = _normalize_prompts(kwargs.get("prompts") or kwargs.get("classes") or kwargs.get("labels"))
153
+ text: Optional[str] = kwargs.get("text")
154
+
155
+ if text is None:
156
+ if not prompts:
157
+ raise ValueError("GroundingDINO requires `prompts`/`classes` (or raw `text`) for open-vocabulary detection")
158
+ text = " . ".join(prompts)
159
+ if not text.endswith("."):
160
+ text = text + " ."
161
+
162
+ threshold = float(kwargs.get("threshold", conf_threshold))
163
+
164
+ # Batch process images with same text prompt
165
+ inputs = self.processor(images=imgs, text=[text] * len(imgs), return_tensors="pt")
166
+ inputs = {k: v.to(self._device) for k, v in inputs.items()}
167
+
168
+ with torch.no_grad():
169
+ outputs = self.model(**inputs)
170
+
171
+ target_sizes = torch.tensor([img.size[::-1] for img in imgs], device=self._device)
172
+
173
+ if not hasattr(self.processor, "post_process_grounded_object_detection"):
174
+ raise RuntimeError("This processor does not support post_process_grounded_object_detection")
175
+
176
+ # Post-process
177
+ input_ids_for_postprocess = inputs["input_ids"].cpu()
178
+ results = self.processor.post_process_grounded_object_detection(
179
+ outputs,
180
+ input_ids_for_postprocess,
181
+ threshold=threshold,
182
+ target_sizes=target_sizes.cpu(),
183
+ )
184
+
185
+ # Parse results for each image
186
+ all_detections = []
187
+ for result in results:
188
+ boxes = result.get("boxes")
189
+ scores = result.get("scores")
190
+ text_labels = result.get("text_labels") or result.get("labels")
191
+
192
+ frame_dets = []
193
+ if boxes is not None and scores is not None:
194
+ for i in range(int(scores.shape[0])):
195
+ score = float(scores[i])
196
+ box = boxes[i]
197
+ b = [float(v) for v in box.tolist()]
198
+
199
+ label: str
200
+ if isinstance(text_labels, (list, tuple)) and i < len(text_labels):
201
+ label = str(text_labels[i])
202
+ elif torch.is_tensor(text_labels) and i < int(text_labels.shape[0]):
203
+ lid = int(text_labels[i])
204
+ if prompts and 0 <= lid < len(prompts):
205
+ label = prompts[lid]
206
+ else:
207
+ label = str(lid)
208
+ else:
209
+ label = "object"
210
+
211
+ frame_dets.append(Detection(label=label, score=score, bbox_xyxy=b))
212
+
213
+ all_detections.append(frame_dets)
214
+
215
+ return all_detections
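One subtlety shared by detect() and detect_batch(): PIL reports image.size as (width, height), while the post-processor's target_sizes expects (height, width), hence the [::-1]. A quick check:

    from PIL import Image

    img = Image.new("RGB", (640, 480))  # width=640, height=480
    print(img.size)                     # (640, 480)
    print(img.size[::-1])               # (480, 640) -> the (h, w) order used for target_sizes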
detection/torchvision_detectors.py ADDED
@@ -0,0 +1,300 @@
1
+ """Torchvision-based object detectors: Faster R-CNN, RetinaNet, SSD, FCOS.
2
+
3
+ Also includes an EfficientDet wrapper (optional dependency: effdet).
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from model_cache import ensure_default_checkpoint_dirs
9
+
10
+ from typing import Dict, List, Optional, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ from PIL import Image
15
+ from torchvision.transforms.functional import to_tensor
16
+
17
+ from .base import BaseDetector, Detection
18
+
19
+
20
+ # Ensure torchvision/torch hub downloads land under `checkpoints/` by default.
21
+ ensure_default_checkpoint_dirs()
22
+
23
+
24
+ class _TorchvisionCOCODetector(BaseDetector):
25
+ _categories: Optional[List[str]] = None
26
+
27
+ def __init__(self, device: str = "cuda", **kwargs):
28
+ super().__init__(device=device, **kwargs)
29
+ self._device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
30
+
31
+ def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
32
+ if not self._categories:
33
+ return None
34
+ return {name: int(i) for i, name in enumerate(self._categories)}
35
+
36
+ def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs) -> List[Detection]:
37
+ assert self.model is not None
38
+
39
+ img = image.convert("RGB")
40
+ x = to_tensor(img).to(self._device)
41
+
42
+ with torch.no_grad():
43
+ out = self.model([x])[0]
44
+
45
+ boxes = out.get("boxes")
46
+ labels = out.get("labels")
47
+ scores = out.get("scores")
48
+
49
+ if boxes is None or labels is None or scores is None:
50
+ return []
51
+
52
+ boxes = boxes.detach().cpu().numpy()
53
+ labels = labels.detach().cpu().numpy().astype(int)
54
+ scores = scores.detach().cpu().numpy()
55
+
56
+ dets: List[Detection] = []
57
+ for b, l, s in zip(boxes, labels, scores):
58
+ if float(s) < float(conf_threshold):
59
+ continue
60
+ label_name = str(int(l))
61
+ if self._categories and 0 <= int(l) < len(self._categories):
62
+ label_name = self._categories[int(l)]
63
+ dets.append(
64
+ Detection(
65
+ label=label_name,
66
+ score=float(s),
67
+ bbox_xyxy=[float(v) for v in b.tolist()],
68
+ )
69
+ )
70
+ return dets
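For readers unfamiliar with torchvision's detection API: each element of the model's output list is a dict of tensors, which is what the parsing above consumes. A hand-built sketch (values are illustrative only):

    import torch

    out = {
        "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]]),  # xyxy in absolute pixels
        "labels": torch.tensor([3]),                          # index into weights.meta["categories"]
        "scores": torch.tensor([0.92]),
    }
    keep = out["scores"] >= 0.25
    print(out["boxes"][keep], out["labels"][keep])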
71
+
72
+ def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs) -> List[List[Detection]]:
73
+ """Batch detection for video processing - much faster than frame-by-frame.
74
+
75
+ Args:
76
+ images: List of PIL Images
77
+ conf_threshold: Confidence threshold
78
+
79
+ Returns:
80
+ List of detection lists, one per image
81
+ """
82
+ assert self.model is not None
83
+
84
+ if not images:
85
+ return []
86
+
87
+ # Convert all images to tensors
88
+ imgs = [img.convert("RGB") for img in images]
89
+ tensors = [to_tensor(img).to(self._device) for img in imgs]
90
+
91
+ # Batch inference - torchvision models accept list of tensors
92
+ with torch.no_grad():
93
+ outputs = self.model(tensors)
94
+
95
+ # Parse results for each image
96
+ all_detections = []
97
+ for out in outputs:
98
+ boxes = out.get("boxes")
99
+ labels = out.get("labels")
100
+ scores = out.get("scores")
101
+
102
+ frame_dets = []
103
+ if boxes is not None and labels is not None and scores is not None:
104
+ boxes = boxes.detach().cpu().numpy()
105
+ labels = labels.detach().cpu().numpy().astype(int)
106
+ scores = scores.detach().cpu().numpy()
107
+
108
+ for b, l, s in zip(boxes, labels, scores):
109
+ if float(s) < float(conf_threshold):
110
+ continue
111
+ label_name = str(int(l))
112
+ if self._categories and 0 <= int(l) < len(self._categories):
113
+ label_name = self._categories[int(l)]
114
+ frame_dets.append(
115
+ Detection(
116
+ label=label_name,
117
+ score=float(s),
118
+ bbox_xyxy=[float(v) for v in b.tolist()],
119
+ )
120
+ )
121
+
122
+ all_detections.append(frame_dets)
123
+
124
+ return all_detections
125
+
126
+
127
+ class FasterRCNNDetector(_TorchvisionCOCODetector):
128
+ def load_model(self) -> None:
129
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn
130
+ from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
131
+
132
+ weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
133
+ self._categories = list(weights.meta.get("categories", []))
134
+ self.model = fasterrcnn_resnet50_fpn(weights=weights).to(self._device).eval()
135
+
136
+
137
+ class RetinaNetDetector(_TorchvisionCOCODetector):
138
+ def load_model(self) -> None:
139
+ from torchvision.models.detection import retinanet_resnet50_fpn
140
+ from torchvision.models.detection import RetinaNet_ResNet50_FPN_Weights
141
+
142
+ weights = RetinaNet_ResNet50_FPN_Weights.DEFAULT
143
+ self._categories = list(weights.meta.get("categories", []))
144
+ self.model = retinanet_resnet50_fpn(weights=weights).to(self._device).eval()
145
+
146
+
147
+ class SSDDetector(_TorchvisionCOCODetector):
148
+ def load_model(self) -> None:
149
+ from torchvision.models.detection import ssd300_vgg16
150
+ from torchvision.models.detection import SSD300_VGG16_Weights
151
+
152
+ weights = SSD300_VGG16_Weights.DEFAULT
153
+ self._categories = list(weights.meta.get("categories", []))
154
+ self.model = ssd300_vgg16(weights=weights).to(self._device).eval()
155
+
156
+
157
+ class FCOSDetector(_TorchvisionCOCODetector):
158
+ def load_model(self) -> None:
159
+ from torchvision.models.detection import fcos_resnet50_fpn
160
+ from torchvision.models.detection import FCOS_ResNet50_FPN_Weights
161
+
162
+ weights = FCOS_ResNet50_FPN_Weights.DEFAULT
163
+ self._categories = list(weights.meta.get("categories", []))
164
+ self.model = fcos_resnet50_fpn(weights=weights).to(self._device).eval()
165
+
166
+
167
+ class EfficientDetDetector(BaseDetector):
168
+ """EfficientDet via `effdet` (optional dependency).
169
+
170
+ This is implemented so the module is extendable, but requires:
171
+ - `pip install effdet`
172
+
173
+ We default to `tf_efficientdet_d0`.
174
+ """
175
+
176
+ def __init__(self, device: str = "cuda", model_name: str = "tf_efficientdet_d0", **kwargs):
177
+ super().__init__(device=device, **kwargs)
178
+ self.model_name = model_name
179
+ self._device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
180
+ self._categories: Optional[List[str]] = None
181
+ self._input_size_hw: Optional[tuple[int, int]] = None
182
+
183
+ def load_model(self) -> None:
184
+ try:
185
+ from effdet import create_model
186
+ from effdet.config import get_efficientdet_config
187
+ except Exception as e:
188
+ raise ImportError(
189
+ "EfficientDet requires the optional package 'effdet'. "
190
+ "Install it with: pip install effdet"
191
+ ) from e
192
+
193
+ self.model = create_model(
194
+ self.model_name,
195
+ pretrained=True,
196
+ bench_task="predict",
197
+ ).to(self._device).eval()
198
+
199
+ # effdet prediction bench expects a fixed input resolution.
200
+ cfg = get_efficientdet_config(self.model_name)
201
+ # cfg.image_size is [H, W]
202
+ self._input_size_hw = (int(cfg.image_size[0]), int(cfg.image_size[1]))
203
+
204
+ # effdet uses COCO labels by default for pretrained models
205
+ # Keep class list unknown here to avoid hard-coding.
206
+ self._categories = None
207
+
208
+ def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
209
+ return self._categories
210
+
211
+ def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs) -> List[Detection]:
212
+ assert self.model is not None
213
+
214
+ img = image.convert("RGB")
215
+ orig_w, orig_h = img.size
216
+
217
+ # Resize to the model's configured input size. This avoids internal shape mismatch
218
+ # issues in effdet when given arbitrary resolutions.
219
+ target_h, target_w = self._input_size_hw or (512, 512)
220
+ resized = img.resize((target_w, target_h))
221
+ x0 = to_tensor(resized)
222
+
223
+ in_h, in_w = target_h, target_w
224
+ x = x0.unsqueeze(0).to(self._device)
225
+
226
+ with torch.no_grad():
227
+ pred = self.model(x)
228
+
229
+ # With bench_task='predict', effdet returns a detections tensor of shape [B, max_det, 6]: (x1, y1, x2, y2, score, class); index the batch dim below
230
+ det = pred[0].detach().cpu().numpy()
231
+ out: List[Detection] = []
232
+ sx = orig_w / float(in_w) if in_w > 0 else 1.0
233
+ sy = orig_h / float(in_h) if in_h > 0 else 1.0
234
+ for row in det:
235
+ x1, y1, x2, y2, score, cls = row.tolist()
236
+ if float(score) < float(conf_threshold):
237
+ continue
238
+ out.append(
239
+ Detection(
240
+ label=str(int(cls)),
241
+ score=float(score),
242
+ bbox_xyxy=[float(x1) * sx, float(y1) * sy, float(x2) * sx, float(y2) * sy],
243
+ )
244
+ )
245
+ return out
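The back-scaling above is plain proportional mapping from the fixed effdet input resolution to the original frame. Worked numbers, assuming a 512x512 input and a 1920x1080 source:

    orig_w, orig_h, in_w, in_h = 1920, 1080, 512, 512
    sx, sy = orig_w / in_w, orig_h / in_h        # 3.75, 2.109375
    x1, y1, x2, y2 = 100.0, 100.0, 200.0, 300.0  # box in 512x512 space
    print([x1 * sx, y1 * sy, x2 * sx, y2 * sy])  # [375.0, 210.9375, 750.0, 632.8125]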
246
+
247
+ def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs) -> List[List[Detection]]:
248
+ """Batch detection for video processing.
249
+
250
+ Args:
251
+ images: List of PIL Images
252
+ conf_threshold: Confidence threshold
253
+
254
+ Returns:
255
+ List of detection lists, one per image
256
+ """
257
+ assert self.model is not None
258
+
259
+ if not images:
260
+ return []
261
+
262
+ # Convert and resize all images
263
+ target_h, target_w = self._input_size_hw or (512, 512)
264
+ tensors = []
265
+ scales = []
266
+
267
+ for img in images:
268
+ img_rgb = img.convert("RGB")
269
+ orig_w, orig_h = img_rgb.size
270
+ resized = img_rgb.resize((target_w, target_h))
271
+ tensors.append(to_tensor(resized))
272
+
273
+ sx = orig_w / float(target_w) if target_w > 0 else 1.0
274
+ sy = orig_h / float(target_h) if target_h > 0 else 1.0
275
+ scales.append((sx, sy))
276
+
277
+ # Batch inference
278
+ batch = torch.stack(tensors).to(self._device)
279
+ with torch.no_grad():
280
+ preds = self.model(batch)
281
+
282
+ # Parse results for each image
283
+ all_detections = []
284
+ for pred, (sx, sy) in zip(preds, scales):
285
+ det = pred.detach().cpu().numpy()
286
+ frame_dets = []
287
+ for row in det:
288
+ x1, y1, x2, y2, score, cls = row.tolist()
289
+ if float(score) < float(conf_threshold):
290
+ continue
291
+ frame_dets.append(
292
+ Detection(
293
+ label=str(int(cls)),
294
+ score=float(score),
295
+ bbox_xyxy=[float(x1) * sx, float(y1) * sy, float(x2) * sx, float(y2) * sy],
296
+ )
297
+ )
298
+ all_detections.append(frame_dets)
299
+
300
+ return all_detections
detection/tracker.py ADDED
@@ -0,0 +1,387 @@
1
+ """Simple IoU-based object tracker for video processing.
2
+
3
+ Provides basic multi-object tracking using detection-to-track association
4
+ based on IoU overlap and optional appearance features.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import Dict, List, Optional, Sequence, Tuple
11
+
12
+ import numpy as np
13
+
14
+
15
+ @dataclass
16
+ class Track:
17
+ """A tracked object across frames."""
18
+
19
+ track_id: int
20
+ label: str
21
+
22
+ # History of bounding boxes [(frame_idx, bbox_xyxy, score)]
23
+ history: List[Tuple[int, List[float], float]] = field(default_factory=list)
24
+
25
+ # Current state
26
+ last_bbox: List[float] = field(default_factory=list)
27
+ last_score: float = 0.0
28
+ last_frame: int = -1
29
+
30
+ # Track status
31
+ age: int = 0 # Frames since first detection
32
+ hits: int = 0 # Total detection matches
33
+ time_since_update: int = 0 # Frames since last match
34
+
35
+ # Color for visualization (RGB)
36
+ color: Tuple[int, int, int] = (0, 255, 0)
37
+
38
+
39
+ class SimpleTracker:
40
+ """IoU-based multi-object tracker.
41
+
42
+ Associates detections to existing tracks using IoU overlap.
43
+ Creates new tracks for unmatched detections.
44
+ Removes tracks that haven't been updated for too long.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ iou_threshold: float = 0.3,
50
+ max_age: int = 30,
51
+ min_hits: int = 3,
52
+ label_match: bool = True,
53
+ ):
54
+ """
55
+ Args:
56
+ iou_threshold: Minimum IoU for detection-track association
57
+ max_age: Maximum frames a track survives without update
58
+ min_hits: Minimum detections before track is confirmed
59
+ label_match: Require label match for association
60
+ """
61
+ self.iou_threshold = iou_threshold
62
+ self.max_age = max_age
63
+ self.min_hits = min_hits
64
+ self.label_match = label_match
65
+
66
+ self._tracks: Dict[int, Track] = {}
67
+ self._next_id: int = 1
68
+ self._frame_idx: int = 0
69
+
70
+ # Color palette for tracks
71
+ self._colors = [
72
+ (255, 0, 0), (0, 255, 0), (0, 0, 255),
73
+ (255, 255, 0), (255, 0, 255), (0, 255, 255),
74
+ (255, 128, 0), (128, 0, 255), (0, 255, 128),
75
+ (255, 0, 128), (128, 255, 0), (0, 128, 255),
76
+ ]
77
+
78
+ def reset(self):
79
+ """Reset tracker state."""
80
+ self._tracks = {}
81
+ self._next_id = 1
82
+ self._frame_idx = 0
83
+
84
+ def update(
85
+ self,
86
+ detections: List[Dict],
87
+ frame_idx: Optional[int] = None,
88
+ ) -> List[Dict]:
89
+ """Update tracks with new detections.
90
+
91
+ Args:
92
+ detections: List of detection dicts with label, score, bbox_xyxy
93
+ frame_idx: Optional frame index (auto-incremented if None)
94
+
95
+ Returns:
96
+ List of track dicts with track_id, label, bbox_xyxy, score
97
+ """
98
+ if frame_idx is None:
99
+ self._frame_idx += 1
100
+ else:
101
+ self._frame_idx = frame_idx
102
+
103
+ # Age every track: one more frame lived, one more frame since last update
104
+ for track in self._tracks.values():
105
+ track.age += 1
+ track.time_since_update += 1
106
+
107
+ if not detections:
108
+ # No detections - just age tracks
109
+ self._remove_old_tracks()
110
+ return self._get_active_tracks()
111
+
112
+ # Build cost matrix (negative IoU for Hungarian matching)
113
+ track_ids = list(self._tracks.keys())
114
+ det_indices = list(range(len(detections)))
115
+
116
+ if track_ids and det_indices:
117
+ # Compute IoU matrix
118
+ iou_matrix = self._compute_iou_matrix(
119
+ [self._tracks[tid].last_bbox for tid in track_ids],
120
+ [d["bbox_xyxy"] for d in detections],
121
+ )
122
+
123
+ # Greedy matching (could use Hungarian for optimal)
124
+ matches, unmatched_tracks, unmatched_dets = self._greedy_match(
125
+ iou_matrix,
126
+ track_ids,
127
+ det_indices,
128
+ detections,
129
+ )
130
+ else:
131
+ matches = []
132
+ unmatched_tracks = track_ids
133
+ unmatched_dets = det_indices
134
+
135
+ # Update matched tracks
136
+ for track_id, det_idx in matches:
137
+ det = detections[det_idx]
138
+ track = self._tracks[track_id]
139
+ track.last_bbox = det["bbox_xyxy"]
140
+ track.last_score = det["score"]
141
+ track.last_frame = self._frame_idx
142
+ track.hits += 1
143
+ track.time_since_update = 0
144
+ track.history.append((self._frame_idx, det["bbox_xyxy"], det["score"]))
145
+
146
+ # Create new tracks for unmatched detections
147
+ for det_idx in unmatched_dets:
148
+ det = detections[det_idx]
149
+ color = self._colors[self._next_id % len(self._colors)]
150
+
151
+ track = Track(
152
+ track_id=self._next_id,
153
+ label=det["label"],
154
+ last_bbox=det["bbox_xyxy"],
155
+ last_score=det["score"],
156
+ last_frame=self._frame_idx,
157
+ age=1,
158
+ hits=1,
159
+ time_since_update=0,
160
+ color=color,
161
+ )
162
+ track.history.append((self._frame_idx, det["bbox_xyxy"], det["score"]))
163
+
164
+ self._tracks[self._next_id] = track
165
+ self._next_id += 1
166
+
167
+ # Remove old tracks
168
+ self._remove_old_tracks()
169
+
170
+ return self._get_active_tracks()
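A minimal usage sketch (assuming this file is importable as detection.tracker). With min_hits=1 a track is reported from its first frame, and the IoU match carries the same track_id across frames:

    from detection.tracker import SimpleTracker

    tracker = SimpleTracker(iou_threshold=0.3, max_age=30, min_hits=1)
    frame1 = [{"label": "car", "score": 0.9, "bbox_xyxy": [0.0, 0.0, 10.0, 10.0]}]
    frame2 = [{"label": "car", "score": 0.8, "bbox_xyxy": [1.0, 0.0, 11.0, 10.0]}]
    print(tracker.update(frame1))  # one confirmed track, track_id == 1
    print(tracker.update(frame2))  # same track_id: shifted box still overlaps (IoU ~ 0.82)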
171
+
172
+ def update_batch(
173
+ self,
174
+ detections_list: List[List[Dict]],
175
+ ) -> List[Dict]:
176
+ """Update tracks with a batch of frames.
177
+
178
+ Args:
179
+ detections_list: List of detection lists, one per frame
180
+
181
+ Returns:
182
+ List of all track dicts with full history
183
+ """
184
+ for dets in detections_list:
185
+ self.update(dets)
186
+
187
+ return self._get_all_tracks_with_history()
188
+
189
+ def _compute_iou_matrix(
190
+ self,
191
+ boxes_a: List[List[float]],
192
+ boxes_b: List[List[float]],
193
+ ) -> np.ndarray:
194
+ """Compute IoU matrix between two sets of boxes."""
195
+ n_a = len(boxes_a)
196
+ n_b = len(boxes_b)
197
+
198
+ if n_a == 0 or n_b == 0:
199
+ return np.zeros((n_a, n_b))
200
+
201
+ iou_matrix = np.zeros((n_a, n_b))
202
+
203
+ for i, box_a in enumerate(boxes_a):
204
+ for j, box_b in enumerate(boxes_b):
205
+ iou_matrix[i, j] = self._box_iou(box_a, box_b)
206
+
207
+ return iou_matrix
208
+
209
+ def _box_iou(self, a: List[float], b: List[float]) -> float:
210
+ """Compute IoU between two boxes."""
211
+ ax1, ay1, ax2, ay2 = a
212
+ bx1, by1, bx2, by2 = b
213
+
214
+ inter_x1 = max(ax1, bx1)
215
+ inter_y1 = max(ay1, by1)
216
+ inter_x2 = min(ax2, bx2)
217
+ inter_y2 = min(ay2, by2)
218
+
219
+ inter_w = max(0.0, inter_x2 - inter_x1)
220
+ inter_h = max(0.0, inter_y2 - inter_y1)
221
+ inter = inter_w * inter_h
222
+
223
+ area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1)
224
+ area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
225
+ union = area_a + area_b - inter
226
+
227
+ return float(inter / union) if union > 0 else 0.0
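The IoU arithmetic at work, on two 10x10 boxes overlapping by half:

    a = [0.0, 0.0, 10.0, 10.0]
    b = [5.0, 0.0, 15.0, 10.0]
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))  # 5.0
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))  # 10.0
    inter = inter_w * inter_h                              # 50.0
    union = 100.0 + 100.0 - inter                          # 150.0
    print(inter / union)                                   # 0.333...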
228
+
229
+ def _greedy_match(
230
+ self,
231
+ iou_matrix: np.ndarray,
232
+ track_ids: List[int],
233
+ det_indices: List[int],
234
+ detections: List[Dict],
235
+ ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
236
+ """Greedy matching based on IoU."""
237
+ matches: List[Tuple[int, int]] = []
238
+ matched_tracks = set()
239
+ matched_dets = set()
240
+
241
+ # Sort by IoU (descending)
242
+ n_tracks, n_dets = iou_matrix.shape
243
+ pairs = []
244
+ for i in range(n_tracks):
245
+ for j in range(n_dets):
246
+ if iou_matrix[i, j] >= self.iou_threshold:
247
+ pairs.append((iou_matrix[i, j], i, j))
248
+
249
+ pairs.sort(reverse=True, key=lambda x: x[0])
250
+
251
+ for iou_val, track_idx, det_idx in pairs:
252
+ if track_idx in matched_tracks or det_idx in matched_dets:
253
+ continue
254
+
255
+ track_id = track_ids[track_idx]
256
+ det = detections[det_idx]
257
+
258
+ # Check label match if required
259
+ if self.label_match:
260
+ if self._tracks[track_id].label != det["label"]:
261
+ continue
262
+
263
+ matches.append((track_id, det_idx))
264
+ matched_tracks.add(track_idx)
265
+ matched_dets.add(det_idx)
266
+
267
+ unmatched_tracks = [track_ids[i] for i in range(n_tracks) if i not in matched_tracks]
268
+ unmatched_dets = [j for j in range(n_dets) if j not in matched_dets]
269
+
270
+ return matches, unmatched_tracks, unmatched_dets
271
+
272
+ def _remove_old_tracks(self):
273
+ """Remove tracks that haven't been updated recently."""
274
+ to_remove = []
275
+ for track_id, track in self._tracks.items():
276
+ if track.time_since_update > self.max_age:
277
+ to_remove.append(track_id)
278
+
279
+ for track_id in to_remove:
280
+ del self._tracks[track_id]
281
+
282
+ def _get_active_tracks(self) -> List[Dict]:
283
+ """Get currently active (confirmed) tracks."""
284
+ result = []
285
+ for track in self._tracks.values():
286
+ # Only return confirmed tracks
287
+ if track.hits >= self.min_hits:
288
+ result.append({
289
+ "track_id": track.track_id,
290
+ "label": track.label,
291
+ "bbox_xyxy": track.last_bbox,
292
+ "score": track.last_score,
293
+ "age": track.age,
294
+ "color": track.color,
295
+ })
296
+ return result
297
+
298
+ def _get_all_tracks_with_history(self) -> List[Dict]:
299
+ """Get all tracks with full history."""
300
+ result = []
301
+ for track in self._tracks.values():
302
+ if track.hits >= self.min_hits:
303
+ result.append({
304
+ "track_id": track.track_id,
305
+ "label": track.label,
306
+ "bbox_xyxy": track.last_bbox,
307
+ "score": track.last_score,
308
+ "age": track.age,
309
+ "hits": track.hits,
310
+ "color": track.color,
311
+ "history": [
312
+ {"frame": h[0], "bbox_xyxy": h[1], "score": h[2]}
313
+ for h in track.history
314
+ ],
315
+ })
316
+ return result
317
+
318
+
319
+ def draw_tracks(
320
+ image,
321
+ tracks: List[Dict],
322
+ show_id: bool = True,
323
+ show_trail: bool = False,
324
+ trail_length: int = 10,
325
+ ):
326
+ """Draw tracks on image.
327
+
328
+ Args:
329
+ image: PIL Image
330
+ tracks: List of track dicts
331
+ show_id: Show track ID
332
+ show_trail: Show movement trail
333
+ trail_length: Number of trail points
334
+
335
+ Returns:
336
+ Image with tracks drawn (PIL Image)
337
+ """
338
+ from PIL import Image, ImageDraw, ImageFont
339
+
340
+ img = image.copy()
341
+ draw = ImageDraw.Draw(img)
342
+
343
+ try:
344
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
345
+ except Exception:
346
+ font = ImageFont.load_default()
347
+
348
+ for track in tracks:
349
+ bbox = track["bbox_xyxy"]
350
+ color = track.get("color", (0, 255, 0))
351
+ track_id = track.get("track_id", 0)
352
+ label = track.get("label", "")
353
+ score = track.get("score", 0.0)
354
+
355
+ # Draw bounding box
356
+ x1, y1, x2, y2 = [int(c) for c in bbox]
357
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
358
+
359
+ # Draw label with track ID
360
+ if show_id:
361
+ text = f"#{track_id} {label}"
362
+ else:
363
+ text = f"{label} {score:.2f}"
364
+
365
+ # Text background
366
+ text_bbox = draw.textbbox((x1, y1 - 20), text, font=font)
367
+ draw.rectangle(text_bbox, fill=color)
368
+ draw.text((x1, y1 - 20), text, fill=(255, 255, 255), font=font)
369
+
370
+ # Draw trail if history available
371
+ if show_trail and "history" in track:
372
+ history = track["history"][-trail_length:]
373
+ if len(history) > 1:
374
+ centers = []
375
+ for h in history:
376
+ hbbox = h["bbox_xyxy"]
377
+ cx = (hbbox[0] + hbbox[2]) / 2
378
+ cy = (hbbox[1] + hbbox[3]) / 2
379
+ centers.append((int(cx), int(cy)))
380
+
381
+ for i in range(len(centers) - 1):
382
+ # Fade trail
383
+ alpha = (i + 1) / len(centers)
384
+ trail_color = tuple(int(c * alpha) for c in color)
385
+ draw.line([centers[i], centers[i + 1]], fill=trail_color, width=2)
386
+
387
+ return img
detection/utils.py ADDED
@@ -0,0 +1,124 @@
1
+ """Detection utilities: IoU, matching, drawing, metrics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import asdict
6
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
7
+
8
+ import numpy as np
9
+ from PIL import Image, ImageDraw, ImageFont
10
+
11
+ from .base import Detection
12
+
13
+
14
+ def box_iou_xyxy(a: Sequence[float], b: Sequence[float]) -> float:
15
+ ax1, ay1, ax2, ay2 = a
16
+ bx1, by1, bx2, by2 = b
17
+
18
+ inter_x1 = max(ax1, bx1)
19
+ inter_y1 = max(ay1, by1)
20
+ inter_x2 = min(ax2, bx2)
21
+ inter_y2 = min(ay2, by2)
22
+
23
+ inter_w = max(0.0, inter_x2 - inter_x1)
24
+ inter_h = max(0.0, inter_y2 - inter_y1)
25
+ inter = inter_w * inter_h
26
+
27
+ area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1)
28
+ area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
29
+ union = area_a + area_b - inter
30
+
31
+ return float(inter / union) if union > 0 else 0.0
32
+
33
+
34
+ def greedy_match_detections(
35
+ before: List[Detection],
36
+ after: List[Detection],
37
+ iou_threshold: float = 0.5,
38
+ require_same_label: bool = True,
39
+ ) -> List[Tuple[int, int, float]]:
40
+ """Greedy match: for each `before` detection (sorted by score), pick best IoU `after`.
41
+
42
+ Returns list of (before_idx, after_idx, iou).
43
+ """
44
+
45
+ before_order = sorted(range(len(before)), key=lambda i: before[i].score, reverse=True)
46
+ used_after = set()
47
+ matches: List[Tuple[int, int, float]] = []
48
+
49
+ for bi in before_order:
50
+ best = None
51
+ best_iou = 0.0
52
+ for ai in range(len(after)):
53
+ if ai in used_after:
54
+ continue
55
+ if require_same_label and before[bi].label != after[ai].label:
56
+ continue
57
+ iou = box_iou_xyxy(before[bi].bbox_xyxy, after[ai].bbox_xyxy)
58
+ if iou >= iou_threshold and iou > best_iou:
59
+ best = ai
60
+ best_iou = iou
61
+ if best is not None:
62
+ used_after.add(best)
63
+ matches.append((bi, best, float(best_iou)))
64
+
65
+ return matches
66
+
67
+
68
+ def summarize_before_after(
69
+ before: List[Detection],
70
+ after: List[Detection],
71
+ iou_threshold: float = 0.5,
72
+ ) -> Dict:
73
+ matches = greedy_match_detections(before, after, iou_threshold=iou_threshold, require_same_label=True)
74
+
75
+ matched_before = {bi for (bi, _, _) in matches}
76
+ matched_after = {ai for (_, ai, _) in matches}
77
+
78
+ retention = (len(matches) / len(before)) if before else 1.0
79
+ emergence = (len(after) - len(matches))
80
+
81
+ ious = [iou for (_, _, iou) in matches]
82
+ score_before = [before[bi].score for (bi, _, _) in matches]
83
+ score_after = [after[ai].score for (_, ai, _) in matches]
84
+
85
+ return {
86
+ "num_before": len(before),
87
+ "num_after": len(after),
88
+ "matched": len(matches),
89
+ "retention": float(retention),
90
+ "new_after": int(emergence),
91
+ "mean_iou_matched": float(np.mean(ious)) if ious else None,
92
+ "mean_score_before_matched": float(np.mean(score_before)) if score_before else None,
93
+ "mean_score_after_matched": float(np.mean(score_after)) if score_after else None,
94
+ }
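Worked numbers for the summary fields: with 5 detections before compression, 4 after, and 3 matched, retention is 3/5 and one post-compression detection counts as new.

    num_before, num_after, matched = 5, 4, 3
    retention = matched / num_before  # 0.6
    new_after = num_after - matched   # 1
    print(retention, new_after)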
95
+
96
+
97
+ def detections_to_dict(dets: List[Detection]) -> List[Dict]:
98
+ return [asdict(d) for d in dets]
99
+
100
+
101
+ def draw_detections(
102
+ image: Image.Image,
103
+ detections: List[Detection],
104
+ max_dets: Optional[int] = 50,
105
+ color: Tuple[int, int, int] = (0, 255, 0),
106
+ ) -> Image.Image:
107
+ img = image.copy()
108
+ draw = ImageDraw.Draw(img)
109
+
110
+ try:
111
+ font = ImageFont.load_default()
112
+ except Exception:
113
+ font = None
114
+
115
+ dets = detections[: max_dets or len(detections)]
116
+ for d in dets:
117
+ x1, y1, x2, y2 = d.bbox_xyxy
118
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
119
+ label = f"{d.label} {d.score:.2f}"
120
+ if font is not None:
121
+ draw.text((x1 + 2, y1 + 2), label, fill=color, font=font)
122
+ else:
123
+ draw.text((x1 + 2, y1 + 2), label, fill=color)
124
+ return img
detection/yolo.py ADDED
@@ -0,0 +1,98 @@
1
+ """YOLO detector via Ultralytics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ from .base import BaseDetector, Detection
11
+ from model_cache import default_checkpoint_path, ensure_default_checkpoint_dirs
12
+
13
+
14
+ class YOLODetector(BaseDetector):
15
+ def __init__(self, device: str = "cuda", model_path: str = default_checkpoint_path("yolo26x.pt"), **kwargs):
16
+ super().__init__(device=device, **kwargs)
17
+ self.model_path = model_path
18
+ self._names: Dict[int, str] = {}
19
+
20
+ def load_model(self) -> None:
21
+ ensure_default_checkpoint_dirs()
22
+ from ultralytics import YOLO
23
+
24
+ self.model = YOLO(self.model_path)
25
+ # ultralytics takes a GPU index (0 for the first GPU) or the string 'cpu' as its device argument
26
+ self._device_arg = 0 if self.device.startswith("cuda") else "cpu"
27
+ self._names = dict(getattr(self.model, "names", {}) or {})
28
+
29
+ def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
30
+ if not self._names:
31
+ return None
32
+ # map name->id
33
+ return {name: int(idx) for idx, name in self._names.items()}
34
+
35
+ def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs) -> List[Detection]:
36
+ assert self.model is not None
37
+
38
+ img = np.asarray(image.convert("RGB"))
39
+ res = self.model.predict(source=img, conf=float(conf_threshold), device=self._device_arg, verbose=False)
40
+ if not res:
41
+ return []
42
+
43
+ r0 = res[0]
44
+ if getattr(r0, "boxes", None) is None:
45
+ return []
46
+
47
+ boxes = r0.boxes
48
+ xyxy = boxes.xyxy.detach().cpu().numpy()
49
+ conf = boxes.conf.detach().cpu().numpy()
50
+ cls = boxes.cls.detach().cpu().numpy().astype(int)
51
+
52
+ out: List[Detection] = []
53
+ for (b, s, c) in zip(xyxy, conf, cls):
54
+ label = self._names.get(int(c), str(int(c)))
55
+ out.append(Detection(label=label, score=float(s), bbox_xyxy=[float(x) for x in b.tolist()]))
56
+ return out
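A minimal usage sketch. Constructing the class directly and calling load_model() first mirrors what the base class presumably does lazily; the import path and the sample image path are assumptions based on this commit's layout.

    from PIL import Image
    from detection.yolo import YOLODetector  # import path assumed from this file's location

    detector = YOLODetector(device="cpu")
    detector.load_model()
    image = Image.open("data/images/car/0016cf15fa4d4e16.jpg")
    for d in detector.detect(image, conf_threshold=0.25):
        print(d.label, round(d.score, 2), d.bbox_xyxy)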
57
+
58
+ def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs) -> List[List[Detection]]:
59
+ """Batch detection for video processing - much faster than frame-by-frame.
60
+
61
+ Args:
62
+ images: List of PIL Images
63
+ conf_threshold: Confidence threshold
64
+
65
+ Returns:
66
+ List of detection lists, one per image
67
+ """
68
+ assert self.model is not None
69
+
70
+ if not images:
71
+ return []
72
+
73
+ # Convert all images to numpy arrays
74
+ imgs = [np.asarray(img.convert("RGB")) for img in images]
75
+
76
+ # TRUE batch inference - YOLO processes all frames together
77
+ results = self.model.predict(source=imgs, conf=float(conf_threshold), device=self._device_arg, verbose=False)
78
+
79
+ # Parse results for each frame
80
+ all_detections = []
81
+ for result in results:
82
+ if not result or getattr(result, "boxes", None) is None:
83
+ all_detections.append([])
84
+ continue
85
+
86
+ boxes = result.boxes
87
+ xyxy = boxes.xyxy.detach().cpu().numpy()
88
+ conf = boxes.conf.detach().cpu().numpy()
89
+ cls = boxes.cls.detach().cpu().numpy().astype(int)
90
+
91
+ frame_dets = []
92
+ for (b, s, c) in zip(xyxy, conf, cls):
93
+ label = self._names.get(int(c), str(int(c)))
94
+ frame_dets.append(Detection(label=label, score=float(s), bbox_xyxy=[float(x) for x in b.tolist()]))
95
+
96
+ all_detections.append(frame_dets)
97
+
98
+ return all_detections
detection/yolo_world.py ADDED
@@ -0,0 +1,188 @@
1
+ """YOLO-World open-vocabulary object detection via Ultralytics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional, Sequence, Union
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ from .base import BaseDetector, Detection
11
+ from model_cache import default_checkpoint_path, ensure_default_checkpoint_dirs
12
+
13
+
14
+ def _normalize_classes(classes: Optional[Union[str, Sequence[str]]]) -> Optional[List[str]]:
15
+ if classes is None:
16
+ return None
17
+ if isinstance(classes, str):
18
+ parts = [c.strip() for c in classes.split(",")]
19
+ parts = [c for c in parts if c]
20
+ return parts or None
21
+ out = [str(c).strip() for c in classes]
22
+ out = [c for c in out if c]
23
+ return out or None
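What _normalize_classes accepts and returns, at a glance (runnable once this module is importable; the import path is an assumption):

    from detection.yolo_world import _normalize_classes

    print(_normalize_classes("car, person"))   # ['car', 'person']
    print(_normalize_classes(["", " boat "]))  # ['boat']
    print(_normalize_classes(None))            # None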
24
+
25
+
26
+ class YOLOWorldDetector(BaseDetector):
27
+ """YOLO-World wrapper.
28
+
29
+ Usage:
30
+ detector = create_detector('yolo_world', device='cuda')
31
+ dets = detector(image, classes=['person', 'car'])
32
+
33
+ Notes:
34
+ - This is open-vocabulary; `get_available_classes()` returns None.
35
+ - Requires the Ultralytics `YOLOWorld` class (available in ultralytics>=8).
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ device: str = "cuda",
41
+ model_path: str = default_checkpoint_path("yolo26s-world.pt"),
42
+ **kwargs: Any,
43
+ ):
44
+ super().__init__(device=device, **kwargs)
45
+ self.model_path = model_path
46
+ self._device_arg = "cpu"
47
+ self._names: Dict[int, str] = {}
48
+
49
+ def load_model(self) -> None:
50
+ ensure_default_checkpoint_dirs()
51
+ try:
52
+ from ultralytics import YOLOWorld
53
+ except Exception as e:
54
+ raise ImportError("ultralytics is required for YOLO-World detection") from e
55
+
56
+ self.model = YOLOWorld(self.model_path)
57
+ self._device_arg = 0 if self.device.startswith("cuda") else "cpu"
58
+
59
+ def get_available_classes(self):
60
+ return None
61
+
62
+ def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs: Any) -> List[Detection]:
63
+ assert self.model is not None
64
+
65
+ classes = _normalize_classes(kwargs.get("classes") or kwargs.get("prompts") or kwargs.get("labels"))
66
+ if not classes:
67
+ raise ValueError("YOLOWorld requires `classes`/`prompts` (comma-separated string or list) for open-vocabulary detection")
68
+
69
+ # Tell the model which classes to look for.
70
+ # Note: set_classes triggers CLIP text encoding which can have device mismatch issues.
71
+ # We work around this by temporarily moving to CPU for set_classes, then back to GPU for inference.
72
+ if hasattr(self.model, "set_classes"):
73
+ import torch
74
+ try:
75
+ # Try setting classes directly first
76
+ self.model.set_classes(classes)
77
+ except RuntimeError as e:
78
+ if "device" in str(e).lower():
79
+ # Device mismatch - try moving model to CPU temporarily
80
+ if hasattr(self.model, 'model'):
81
+ original_device = next(self.model.model.parameters()).device
82
+ self.model.model.to('cpu')
83
+ self.model.set_classes(classes)
84
+ self.model.model.to(original_device)
85
+ else:
86
+ raise
87
+
88
+ img = np.asarray(image.convert("RGB"))
89
+ res = self.model.predict(source=img, conf=float(conf_threshold), device=self._device_arg, verbose=False)
90
+ if not res:
91
+ return []
92
+
93
+ r0 = res[0]
94
+ # Handle names attribute - can be dict, list, or other types
95
+ raw_names = getattr(r0, "names", {})
96
+ if isinstance(raw_names, dict):
97
+ names = raw_names
98
+ elif isinstance(raw_names, (list, tuple)):
99
+ names = {i: str(n) for i, n in enumerate(raw_names)}
100
+ else:
101
+ names = {}
102
+
103
+ if getattr(r0, "boxes", None) is None:
104
+ return []
105
+
106
+ boxes = r0.boxes
107
+ xyxy = boxes.xyxy.detach().cpu().numpy()
108
+ conf = boxes.conf.detach().cpu().numpy()
109
+ cls = boxes.cls.detach().cpu().numpy().astype(int)
110
+
111
+ out: List[Detection] = []
112
+ for (b, s, c) in zip(xyxy, conf, cls):
113
+ label = names.get(int(c))
114
+ if not label:
115
+ # fallback: index into the user-provided classes if the mapping is missing
116
+ label = classes[int(c)] if 0 <= int(c) < len(classes) else str(int(c))
117
+ out.append(Detection(label=str(label), score=float(s), bbox_xyxy=[float(x) for x in b.tolist()]))
118
+ return out
119
+
120
+ def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs: Any) -> List[List[Detection]]:
121
+ """Batch detection for video processing - much faster than frame-by-frame.
122
+
123
+ Args:
124
+ images: List of PIL Images
125
+ conf_threshold: Confidence threshold
126
+ **kwargs: Should include 'classes'/'prompts' for open-vocabulary detection
127
+
128
+ Returns:
129
+ List of detection lists, one per image
130
+ """
131
+ assert self.model is not None
132
+
133
+ if not images:
134
+ return []
135
+
136
+ classes = _normalize_classes(kwargs.get("classes") or kwargs.get("prompts") or kwargs.get("labels"))
137
+ if not classes:
138
+ raise ValueError("YOLOWorld requires `classes`/`prompts` (comma-separated string or list) for open-vocabulary detection")
139
+
140
+ # Set classes for the model
141
+ if hasattr(self.model, "set_classes"):
142
+ import torch
143
+ try:
144
+ self.model.set_classes(classes)
145
+ except RuntimeError as e:
146
+ if "device" in str(e).lower():
147
+ if hasattr(self.model, 'model'):
148
+ original_device = next(self.model.model.parameters()).device
149
+ self.model.model.to('cpu')
150
+ self.model.set_classes(classes)
151
+ self.model.model.to(original_device)
152
+ else:
153
+ raise
154
+
155
+ # Convert all images to numpy arrays
156
+ imgs = [np.asarray(img.convert("RGB")) for img in images]
157
+
158
+ # TRUE batch inference
159
+ results = self.model.predict(source=imgs, conf=float(conf_threshold), device=self._device_arg, verbose=False)
160
+
161
+ # Parse results for each frame
162
+ all_detections = []
163
+ for result in results:
164
+ # Handle names attribute
165
+ raw_names = getattr(result, "names", {})
166
+ if isinstance(raw_names, dict):
167
+ names = raw_names
168
+ elif isinstance(raw_names, (list, tuple)):
169
+ names = {i: str(n) for i, n in enumerate(raw_names)}
170
+ else:
171
+ names = {}
172
+
173
+ frame_dets = []
174
+ if getattr(result, "boxes", None) is not None:
175
+ boxes = result.boxes
176
+ xyxy = boxes.xyxy.detach().cpu().numpy()
177
+ conf = boxes.conf.detach().cpu().numpy()
178
+ cls = boxes.cls.detach().cpu().numpy().astype(int)
179
+
180
+ for (b, s, c) in zip(xyxy, conf, cls):
181
+ label = names.get(int(c))
182
+ if not label:
183
+ label = classes[int(c)] if 0 <= int(c) < len(classes) else str(int(c))
184
+ frame_dets.append(Detection(label=str(label), score=float(s), bbox_xyxy=[float(x) for x in b.tolist()]))
185
+
186
+ all_detections.append(frame_dets)
187
+
188
+ return all_detections
examples.sh ADDED
@@ -0,0 +1,272 @@
1
+ #!/bin/bash
2
+ # Example commands for ROI-VAE compression
3
+ # Covers: Car, Building, Person, Boat
4
+
5
+ # Make sure you're in the roi-vae directory
6
+ cd "$(dirname "$0")"
7
+
8
+ # Create results directory if it doesn't exist
9
+ mkdir -p results
10
+
11
+ echo "ROI-VAE Compression Examples"
12
+ echo "============================="
13
+ echo
14
+
15
+ # ---------------------------------------------------------
16
+ # 1. CAR
17
+ # ---------------------------------------------------------
18
+ echo "1. CAR (YOLO)"
19
+ echo "-------------"
20
+
21
+ # 1a. Only Image
22
+ echo "a) Compressing (Image Only)..."
23
+ python roi_compressor.py \
24
+ --input data/images/car/0016cf15fa4d4e16.jpg \
25
+ --output results/car_compressed.jpg \
26
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
27
+ --sigma 0.3 \
28
+ --seg-method yolo \
29
+ --seg-classes car
30
+
31
+ # 1b. With Comparison
32
+ echo "b) Compressing (With Comparison)..."
33
+ python roi_compressor.py \
34
+ --input data/images/car/0016cf15fa4d4e16.jpg \
35
+ --output results/car_comparison.jpg \
36
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
37
+ --sigma 0.3 \
38
+ --seg-method yolo \
39
+ --seg-classes car \
40
+ --highlight
41
+
42
+ echo
43
+ echo "---------------------------------------------------------"
44
+ echo
45
+
46
+ # ---------------------------------------------------------
47
+ # 2. BUILDING
48
+ # ---------------------------------------------------------
49
+ echo "2. BUILDING (SegFormer)"
50
+ echo "-----------------------"
51
+
52
+ # 2a. Only Image
53
+ echo "a) Compressing (Image Only)..."
54
+ python roi_compressor.py \
55
+ --input data/images/building/000571767ec7a593.jpg \
56
+ --output results/building_compressed.jpg \
57
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
58
+ --sigma 0.3 \
59
+ --seg-classes building
60
+
61
+ # 2b. With Comparison
62
+ echo "b) Compressing (With Comparison)..."
63
+ python roi_compressor.py \
64
+ --input data/images/building/000571767ec7a593.jpg \
65
+ --output results/building_comparison.jpg \
66
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
67
+ --sigma 0.3 \
68
+ --seg-classes building \
69
+ --highlight
70
+
71
+ echo
72
+ echo "---------------------------------------------------------"
73
+ echo
74
+
75
+ # ---------------------------------------------------------
76
+ # 3. PERSON
77
+ # ---------------------------------------------------------
78
+ echo "3. PERSON (YOLO)"
79
+ echo "----------------"
80
+
81
+ # 3a. Only Image
82
+ echo "a) Compressing (Image Only)..."
83
+ python roi_compressor.py \
84
+ --input data/images/person/kodim04.png \
85
+ --output results/person_compressed.jpg \
86
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
87
+ --sigma 0.3 \
88
+ --seg-method yolo \
89
+ --seg-classes person
90
+
91
+ # 3b. With Comparison
92
+ echo "b) Compressing (With Comparison)..."
93
+ python roi_compressor.py \
94
+ --input data/images/person/kodim04.png \
95
+ --output results/person_comparison.jpg \
96
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
97
+ --sigma 0.3 \
98
+ --seg-method yolo \
99
+ --seg-classes person \
100
+ --highlight
101
+
102
+ echo
103
+ echo "---------------------------------------------------------"
104
+ echo
105
+
106
+ # ---------------------------------------------------------
107
+ # 4. BOAT
108
+ # ---------------------------------------------------------
109
+ echo "4. BOAT (YOLO)"
110
+ echo "--------------"
111
+
112
+ # 4a. Only Image
113
+ echo "a) Compressing (Image Only)..."
114
+ python roi_compressor.py \
115
+ --input data/images/boat/kodim06.png \
116
+ --output results/boat_compressed.jpg \
117
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
118
+ --sigma 0.3 \
119
+ --seg-method yolo \
120
+ --seg-classes boat
121
+
122
+ # 4b. With Comparison
123
+ echo "b) Compressing (With Comparison)..."
124
+ python roi_compressor.py \
125
+ --input data/images/boat/kodim06.png \
126
+ --output results/boat_comparison.jpg \
127
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
128
+ --sigma 0.3 \
129
+ --seg-method yolo \
130
+ --seg-classes boat \
131
+ --highlight
132
+
133
+ echo
134
+ echo "---------------------------------------------------------"
135
+ echo
136
+
137
+ # ---------------------------------------------------------
138
+ # 5. PARAMETER COMPARISON
139
+ # ---------------------------------------------------------
140
+ echo "5. PARAMETER COMPARISON"
141
+ echo "-----------------------"
142
+
143
+ # 5a. Sigma Comparison (Background Quality)
144
+ echo "a) Sigma Comparison (Background Quality)..."
145
+ # Low Sigma (High Compression)
146
+ echo " - Low Sigma (0.1)"
147
+ python roi_compressor.py \
148
+ --input data/images/car/0016cf15fa4d4e16.jpg \
149
+ --output results/sigma_low.jpg \
150
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
151
+ --sigma 0.1 \
152
+ --seg-method yolo \
153
+ --seg-classes car \
154
+ --highlight
155
+
156
+ # High Sigma (Low Compression)
157
+ echo " - High Sigma (0.9)"
158
+ python roi_compressor.py \
159
+ --input data/images/car/0016cf15fa4d4e16.jpg \
160
+ --output results/sigma_high.jpg \
161
+ --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
162
+ --sigma 0.9 \
163
+ --seg-method yolo \
164
+ --seg-classes car \
165
+ --highlight
166
+
167
+ # 5b. Lambda Comparison (Rate-Distortion)
168
+ echo "b) Lambda Comparison (Rate-Distortion)..."
169
+ # Low Lambda
170
+ echo " - Low Lambda (0.013)"
171
+ python roi_compressor.py \
172
+ --input data/images/car/0016cf15fa4d4e16.jpg \
173
+ --output results/lambda_low.jpg \
174
+ --checkpoint checkpoints/tic_lambda_0.013.pth.tar \
175
+ --sigma 0.3 \
179
+ --seg-method yolo \
180
+ --seg-classes car \
181
+ --highlight
182
+
183
+ # High Lambda
184
+ echo " - High Lambda (0.0932)"
185
+ python roi_compressor.py \
186
+ --input data/images/car/0016cf15fa4d4e16.jpg \
187
+ --output results/lambda_high.jpg \
188
+ --checkpoint checkpoints/tic_lambda_0.0932.pth.tar \
189
+ --sigma 0.3 \
191
+ --seg-method yolo \
192
+ --seg-classes car \
193
+ --highlight
194
+
195
+ echo
196
+ echo "---------------------------------------------------------"
197
+ echo
198
+
199
+ # -------------------------------------------------
200
+ # 6. Standalone Segmentation (Mask R-CNN)
201
+ # -------------------------------------------------
202
+ echo "--- Running Standalone Segmentation (Mask R-CNN) ---"
203
+ python roi_segmenter.py \
204
+ --input images/car/0016cf15fa4d4e16.jpg \
205
+ --output results/mask_rcnn.png \
206
+ --method maskrcnn \
207
+ --classes car \
208
+ --visualize
209
+
210
+ # -------------------------------------------------
211
+ # 7. Video Compression (Static Mode)
212
+ # -------------------------------------------------
213
+ echo "--- Running Video Compression (Static Mode) ---"
214
+ python roi_compressor.py \
215
+ --input data/videos/traffic.mp4 \
216
+ --output results/video_static.mp4 \
217
+ --quality-level 4 \
218
+ --sigma 0.4 \
219
+ --seg-method yolo \
220
+ --seg-classes car person \
221
+ --video-mode static \
222
+ --output-fps 10
223
+
224
+ # -------------------------------------------------
225
+ # 8. Video Compression (Dynamic Mode)
226
+ # -------------------------------------------------
227
+ echo "--- Running Video Compression (Dynamic Mode) ---"
228
+ python roi_compressor.py \
229
+ --input data/videos/traffic.mp4 \
230
+ --output results/video_dynamic.mp4 \
231
+ --quality-level 4 \
232
+ --seg-method yolo \
233
+ --seg-classes car person \
234
+ --video-mode dynamic \
235
+ --target-bandwidth-kbps 800 \
236
+ --min-fps 5 \
237
+ --max-fps 20
238
+
239
+ # -------------------------------------------------
240
+ # 9. Video Detection Evaluation (Static Mode)
241
+ # -------------------------------------------------
242
+ echo "--- Running Video Detection Evaluation (Static Mode) ---"
243
+ python roi_detection_eval.py \
244
+ --before data/videos/traffic.mp4 \
245
+ --video-mode static \
246
+ --sigma 0.3 \
247
+ --output-fps 10 \
248
+ --quality-level 4 \
249
+ --seg-method yolo \
250
+ --seg-classes car person \
251
+ --detectors yolo \
252
+ --max-video-frames 30 \
253
+ --video-sample-interval 3 \
254
+ --save-after results/video_eval_compressed.mp4 \
255
+ --out results/video_detection_eval.json \
256
+ --viz-dir results/video_detection_viz
257
+
258
+ # -------------------------------------------------
259
+ # 10. Video Detection Evaluation (Comparing Two Videos)
260
+ # -------------------------------------------------
261
+ echo "--- Running Video Detection Evaluation (Compare Two Videos) ---"
262
+ python roi_detection_eval.py \
263
+ --before data/videos/traffic.mp4 \
264
+ --after results/video_static.mp4 \
265
+ --detectors yolo \
266
+ --max-video-frames 30 \
267
+ --video-sample-interval 3 \
268
+ --out results/video_compare_eval.json
269
+
270
+ echo
271
+ echo "============================="
272
+ echo "All examples complete! Check results/ directory"
model_cache.py ADDED
@@ -0,0 +1,62 @@
1
+ """Centralized defaults for model checkpoint/cache locations.
2
+
3
+ Goal: keep all auto-downloaded model artifacts inside this repo's `checkpoints/`
4
+ directory by default (instead of user-wide cache dirs or repo root).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ from pathlib import Path
11
+
12
+
13
+ PROJECT_ROOT = Path(__file__).resolve().parent
14
+ CHECKPOINTS_DIR = PROJECT_ROOT / "checkpoints"
15
+
16
+ # Hugging Face will create subfolders like `hub/`, `datasets/`, etc under HF_HOME.
17
+ HF_HOME_DIR = CHECKPOINTS_DIR / "hf"
18
+
19
+ # Torchvision uses torch.hub.load_state_dict_from_url which respects TORCH_HOME.
20
+ TORCH_HOME_DIR = CHECKPOINTS_DIR / "torch"
21
+
22
+
23
+ def ensure_default_checkpoint_dirs() -> None:
24
+ """Ensure checkpoint dirs exist and set cache-related env vars.
25
+
26
+ This is intentionally a best-effort helper. If the user has explicitly set
27
+ env vars already, we do not override them.
28
+ """
29
+
30
+ CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
31
+ HF_HOME_DIR.mkdir(parents=True, exist_ok=True)
32
+ TORCH_HOME_DIR.mkdir(parents=True, exist_ok=True)
33
+
34
+ # Hugging Face
35
+ os.environ.setdefault("HF_HOME", str(HF_HOME_DIR))
36
+ # Compatibility env vars across transformers/huggingface-hub versions.
37
+ os.environ.setdefault("TRANSFORMERS_CACHE", str(HF_HOME_DIR / "hub"))
38
+ os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(HF_HOME_DIR / "hub"))
39
+
40
+ # Torch / torchvision
41
+ os.environ.setdefault("TORCH_HOME", str(TORCH_HOME_DIR))
42
+
43
+
44
+ def hf_cache_dir() -> Path:
45
+ ensure_default_checkpoint_dirs()
46
+ return HF_HOME_DIR
47
+
48
+
49
+ def torch_home_dir() -> Path:
50
+ ensure_default_checkpoint_dirs()
51
+ return TORCH_HOME_DIR
52
+
53
+
54
+ def checkpoints_dir() -> Path:
55
+ ensure_default_checkpoint_dirs()
56
+ return CHECKPOINTS_DIR
57
+
58
+
59
+ def default_checkpoint_path(filename: str) -> str:
60
+ """Return an absolute path under `checkpoints/` for a given filename."""
61
+
62
+ return str(checkpoints_dir() / filename)
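A quick sketch of the intended effect: after calling the helper, Hugging Face and torch caches resolve under checkpoints/ unless the user already set the corresponding env vars.

    import os
    from model_cache import ensure_default_checkpoint_dirs, default_checkpoint_path

    ensure_default_checkpoint_dirs()
    print(os.environ["HF_HOME"])                  # <repo>/checkpoints/hf (unless pre-set)
    print(os.environ["TORCH_HOME"])               # <repo>/checkpoints/torch (unless pre-set)
    print(default_checkpoint_path("yolo26x.pt"))  # <repo>/checkpoints/yolo26x.pt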
requirements.txt ADDED
@@ -0,0 +1,26 @@
1
+ # Core dependencies
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ numpy>=1.24.0
5
+ pillow>=9.5.0
6
+ matplotlib>=3.7.0
7
+ # Prefer headless OpenCV for servers / Hugging Face Spaces.
8
+ opencv-python-headless>=4.7.0
9
+
10
+ # Model dependencies
11
+ compressai>=1.2.4
12
+ timm>=0.9.0
13
+
14
+ # Segmentation
15
+ transformers>=4.36.0
16
+ ultralytics>=8.0.0
17
+
18
+ # Detection (optional)
19
+ effdet>=0.4.1
20
+
21
+ # Demo app (Hugging Face Spaces)
22
+ gradio>=6.2.0,<7
23
+ openai>=1.0.0
24
+ gradio_client
25
+ scipy
26
+ tqdm
roi_compressor.py ADDED
@@ -0,0 +1,183 @@
1
+ """
2
+ CLI for ROI-based image/video compression using modular compression framework.
3
+ """
+
+ import argparse
+ import os
+ import sys
+ from pathlib import Path
+
+ import numpy as np
+ from PIL import Image
+ import torch
+
+ from segmentation import create_segmenter
+ from vae import load_checkpoint, compress_image
+ from vae.visualization import create_comparison_grid
+ from video import VideoProcessor, CompressionSettings
+ from video.video_processor import frames_to_video_bytes
+
+
+ # Command-line interface
+ def main():
+     parser = argparse.ArgumentParser(description="ROI-based Image/Video Compressor.")
+     # I/O
+     parser.add_argument("--input", required=True, help="Path to input image or video file.")
+     parser.add_argument("--output", required=True, help="Path to save compressed output file.")
+     parser.add_argument("--checkpoint", default="checkpoints/tic_lambda_0.0483.pth.tar", help="Path to VAE model checkpoint.")
+     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on.")
+
+     # General Compression Settings
+     parser.add_argument("--quality-level", type=int, default=4, choices=range(7), help="Base quality level (0-6, higher is better). Affects VAE model selection.")
+     parser.add_argument("--sigma", type=float, default=0.3, help="Background quality factor (0.01-1.0). Lower means more background compression.")
+
+     # Segmentation Settings
+     parser.add_argument("--seg-method", default="yolo", help="Segmentation method to use.")
+     parser.add_argument("--seg-classes", nargs="+", required=True, help="Classes to segment as ROI.")
+     parser.add_argument("--seg-model", help="Path to a specific segmentation model checkpoint (optional).")
+
+     # Video-Specific Settings
+     parser.add_argument("--video-mode", choices=['static', 'dynamic'], default='static', help="Video compression mode: 'static' for fixed settings, 'dynamic' for adaptive.")
+     parser.add_argument("--target-bandwidth-kbps", type=int, default=1000, help="[Dynamic Mode] Target bandwidth in kbps.")
+     parser.add_argument("--output-fps", type=float, default=15.0, help="[Static Mode] Output framerate.")
+     parser.add_argument("--min-fps", type=float, default=5.0, help="[Dynamic Mode] Minimum framerate.")
+     parser.add_argument("--max-fps", type=float, default=30.0, help="[Dynamic Mode] Maximum framerate.")
+     parser.add_argument("--chunk-duration-sec", type=float, default=1.0, help="[Dynamic Mode] Duration of video chunks for analysis.")
+
+     # Detection/Tracking Settings (for video)
+     parser.add_argument("--detection-method", default="yolo", help="Object detection method for video.")
+     parser.add_argument("--enable-tracking", action="store_true", help="Enable object tracking in video.")
+
+     # Visualization
+     parser.add_argument("--highlight", action="store_true", help="Create a comparison grid image (for image input only).")
+     parser.add_argument("--viz-dir", help="Directory to save visualization artifacts (e.g., masks).")
+
+     args = parser.parse_args()
+
+     # --- Input Type Check ---
+     input_path = args.input.lower()
+     is_video = any(input_path.endswith(ext) for ext in ['.mp4', '.avi', '.mov', '.mkv'])
+
+     if is_video:
+         print("Processing video input...")
+         process_video(args)
+     else:
+         print("Processing image input...")
+         process_image(args)
+
+
+ def process_image(args):
+     """Compresses a single image."""
+     print(f"Loading VAE model from {args.checkpoint}...")
+     model = load_checkpoint(args.checkpoint, device=args.device)
+     model.eval()
+
+     print(f"Loading segmenter '{args.seg_method}'...")
+     seg_kwargs = dict(device=args.device)
+     if args.seg_model:
+         seg_kwargs['model_path'] = args.seg_model
+     segmenter = create_segmenter(args.seg_method, **seg_kwargs)
+
+     print(f"Loading image from {args.input}...")
+     image = Image.open(args.input).convert("RGB")
+
+     print(f"Segmenting image for classes: {args.seg_classes}...")
+     mask = segmenter(image, target_classes=args.seg_classes)
+
+     if args.viz_dir:
+         if not os.path.exists(args.viz_dir):
+             os.makedirs(args.viz_dir)
+         mask_path = os.path.join(args.viz_dir, "mask.png")
+         Image.fromarray((mask * 255).astype(np.uint8)).save(mask_path)
+         print(f"Saved segmentation mask to {mask_path}")
+
+     print(f"Compressing image with sigma={args.sigma}...")
+     result = compress_image(
+         image,
+         mask,
+         model,
+         sigma=args.sigma,
+         device=args.device
+     )
+     compressed_img = result['compressed']
+     bpp = result['bpp']
+
+     print(f"Saving compressed image to {args.output} (BPP: {bpp:.4f})")
+     compressed_img.save(args.output)
+
+     if args.highlight:
+         print("Creating comparison grid...")
+         lambda_val = float(os.path.basename(args.checkpoint).split('_')[-1].replace('.pth.tar', ''))
+         grid = create_comparison_grid(image, compressed_img, mask, bpp, args.sigma, lambda_val)
+         # Build the grid path from the output stem (avoids str.replace hitting
+         # the extension string elsewhere in the path)
+         grid_path = os.path.splitext(args.output)[0] + "_comparison.jpg"
+         grid.save(grid_path)
+         print(f"Saved comparison grid to {grid_path}")
+
+
+ def process_video(args):
+     """Compresses a video using static or dynamic settings."""
+     processor = VideoProcessor(device=args.device)
+     print("Loading models for video processing...")
+     processor.load_models(
+         quality_level=args.quality_level,
+         segmentation_method=args.seg_method,
+         detection_method=args.detection_method,
+         enable_tracking=args.enable_tracking,
+     )
+
+     # Simple progress callback
+     def progress_callback(current, total, message):
+         percent = int(100 * current / max(1, total))
+         print(f"[{percent:3d}%] {message}")
+
+     if args.video_mode == 'static':
+         settings = CompressionSettings(
+             mode='static',
+             quality_level=args.quality_level,
+             sigma=args.sigma,
+             output_fps=args.output_fps,
+             target_classes=args.seg_classes,
+         )
+         print(f"Starting STATIC video compression with FPS={settings.output_fps}, Sigma={settings.sigma}...")
+         print("Using offline batch processing (GPU memory optimized)...")
+         chunks = processor.process_static_offline(args.input, settings, progress_callback=progress_callback)
+     else:  # dynamic
+         settings = CompressionSettings(
+             mode='dynamic',
+             target_bandwidth_kbps=args.target_bandwidth_kbps,
+             min_fps=args.min_fps,
+             max_fps=args.max_fps,
+             chunk_duration_sec=args.chunk_duration_sec,
+             target_classes=args.seg_classes,
+             quality_level=args.quality_level,
+         )
+         print(f"Starting DYNAMIC video compression with Target Bandwidth={settings.target_bandwidth_kbps} kbps...")
+         print("Using offline batch processing (GPU memory optimized)...")
+         chunks = processor.process_dynamic_offline(args.input, settings, progress_callback=progress_callback)
+
+     if not chunks:
+         print("No frames were processed. Exiting.")
+         return
+
+     # Collect all frames from chunks
+     all_frames = []
+     for chunk in chunks:
+         all_frames.extend(chunk.frames)
+
+     # Determine the final video's FPS
+     if args.video_mode == 'static':
+         final_fps = args.output_fps
+     else:
+         # For dynamic, use weighted average FPS from chunks
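+         # e.g. (illustrative numbers): a 30-frame chunk at 15 fps (2 s) plus a
+         # 30-frame chunk at 5 fps (6 s) gives 60 frames / 8 s = 7.5 fps overall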
+         total_frames = sum(len(c.frames) for c in chunks)
+         total_duration = sum(len(c.frames) / c.fps for c in chunks)
+         final_fps = total_frames / total_duration if total_duration > 0 else args.max_fps
+
+     print(f"\nRe-encoding {len(all_frames)} frames into final video at ~{final_fps:.2f} FPS...")
+     video_bytes = frames_to_video_bytes(all_frames, fps=final_fps)
+
+     print(f"Saving compressed video to {args.output}...")
+     with open(args.output, "wb") as f:
+         f.write(video_bytes)
+     print("Done.")
+
+
+ if __name__ == "__main__":
+     main()
roi_detection_eval.py ADDED
@@ -0,0 +1,639 @@
+ """CLI: evaluate object detection before vs after ROI compression.
+
+ Supports both images and videos.
+
+ Image modes:
+ 1) Compare two images:
+    python roi_detection_eval.py --before img.jpg --after img_compressed.jpg --detectors yolo fasterrcnn
+
+ 2) Create the "after" image via ROI compression, then evaluate:
+    python roi_detection_eval.py --before img.jpg \
+        --checkpoint checkpoints/tic_lambda_0.0483.pth.tar --sigma 0.3 \
+        --seg-method yolo --seg-classes car person \
+        --detectors yolo fasterrcnn
+
+ Video modes:
+ 3) Compare two videos:
+    python roi_detection_eval.py --before video.mp4 --after video_compressed.mp4 --detectors yolo
+
+ 4) Create the "after" video via ROI compression, then evaluate:
+    python roi_detection_eval.py --before video.mp4 \
+        --video-mode static --sigma 0.3 --output-fps 10 \
+        --seg-method yolo --seg-classes car person \
+        --detectors yolo
+
+ Outputs JSON summary + optional visualization images/videos.
+ """
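+ # Shape of the JSON summary this script writes (abridged sketch; keys taken
+ # from the results dicts built below, values illustrative):
+ # {
+ #   "type": "image",
+ #   "det_conf": 0.25,
+ #   "iou_threshold": 0.5,
+ #   "detectors": {
+ #     "yolo": {
+ #       "summary": {...},              # from summarize_before_after()
+ #       "before_detections": [...],
+ #       "after_detections": [...]
+ #     }
+ #   }
+ # }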
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import tempfile
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+
+ import numpy as np
+ from PIL import Image
+
+ from detection import create_detector, get_available_detectors
+ from detection.utils import (
+     detections_to_dict,
+     draw_detections,
+     summarize_before_after,
+ )
+
+
+ VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.webm'}
+
+
+ def _is_video(path: str) -> bool:
+     """Check if file is a video based on extension."""
+     return Path(path).suffix.lower() in VIDEO_EXTENSIONS
+
+
+ def _load_image(path: str) -> Image.Image:
+     return Image.open(path).convert("RGB")
+
+
+ def _normalize_open_vocab_classes(raw: Optional[str]) -> Optional[str]:
+     if raw is None:
+         return None
+     s = str(raw).strip()
+     return s or None
+
+
+ def _maybe_make_after_image(
+     before_img: Image.Image,
+     args,
+ ) -> Image.Image:
+     """Create after image via ROI compression if --after not provided."""
+     if args.after:
+         return _load_image(args.after)
+
+     if not args.checkpoint:
+         raise SystemExit("Provide either --after or --checkpoint (to generate after via ROI compression).")
+
+     from vae import load_checkpoint, compress_image
+     from segmentation import create_segmenter
+     from segmentation.utils import load_mask
+
+     if args.mask:
+         mask = load_mask(args.mask)
+     else:
+         if not args.seg_method:
+             raise SystemExit("When not using --mask, provide --seg-method.")
+
+         seg_kwargs = {}
+         if args.seg_method == "yolo":
+             seg_kwargs["conf_threshold"] = args.seg_conf
+         if args.seg_method == "mask2former":
+             seg_kwargs["model_type"] = args.seg_model_type
+
+         segmenter = create_segmenter(args.seg_method, device=args.device, **seg_kwargs)
+
+         if args.seg_method == "sam3":
+             if not args.seg_prompt:
+                 raise SystemExit("For --seg-method sam3, provide --seg-prompt.")
+             mask = segmenter(before_img, target_classes=[args.seg_prompt])
+         else:
+             if not args.seg_classes:
+                 raise SystemExit("Provide --seg-classes (or use --seg-method sam3 + --seg-prompt).")
+             mask = segmenter(before_img, target_classes=args.seg_classes)
+
+     model = load_checkpoint(args.checkpoint, N=args.N, M=args.M, device=args.device)
+     out = compress_image(before_img, mask, model, sigma=float(args.sigma), device=args.device)
+     after_img = out["compressed"]
+
+     if args.save_after:
+         Path(args.save_after).parent.mkdir(parents=True, exist_ok=True)
+         after_img.save(args.save_after)
+
+     return after_img
+
+
+ def _extract_video_frames(
+     video_path: str,
+     max_frames: Optional[int] = None,
+     sample_interval: int = 1,
+ ) -> Tuple[List[Image.Image], float]:
+     """Extract frames from a video file.
+
+     Args:
+         video_path: Path to video file
+         max_frames: Max frames to extract (None = all)
+         sample_interval: Extract every Nth frame
+
+     Returns:
+         (frames, fps)
+     """
+     try:
+         import cv2
+     except ImportError:
+         raise ImportError("opencv-python required for video processing")
+
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise ValueError(f"Cannot open video: {video_path}")
+
+     fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+     frames = []
+     frame_idx = 0
+
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         if frame_idx % sample_interval == 0:
+             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frames.append(Image.fromarray(frame_rgb))
+
+         frame_idx += 1
+
+         if max_frames is not None and len(frames) >= max_frames:
+             break
+
+     cap.release()
+     return frames, fps
+
+
+ def _maybe_make_after_video(
+     before_path: str,
+     args,
+ ) -> Tuple[str, List[Image.Image], List[Image.Image]]:
+     """Create after video via ROI compression if --after not provided.
+
+     Returns:
+         (after_video_path, before_frames, after_frames)
+     """
+     from video import VideoProcessor, CompressionSettings
+     from video.video_processor import frames_to_video_bytes
+
+     # Extract frames from before video
+     sample_interval = max(1, int(args.video_sample_interval))
+     before_frames, fps = _extract_video_frames(
+         before_path,
+         max_frames=args.max_video_frames,
+         sample_interval=sample_interval,
+     )
+     print(f"Extracted {len(before_frames)} frames from before video (sampled every {sample_interval} frames)")
+
+     if args.after:
+         # Load after video frames
+         after_frames, _ = _extract_video_frames(
+             args.after,
+             max_frames=args.max_video_frames,
+             sample_interval=sample_interval,
+         )
+         print(f"Extracted {len(after_frames)} frames from after video")
+         return args.after, before_frames, after_frames
+
+     # Create compressed video
+     print("Creating compressed video via ROI compression...")
+
+     processor = VideoProcessor(device=args.device)
+     processor.load_models(
+         quality_level=args.quality_level,
+         segmentation_method=args.seg_method or "yolo",
+         detection_method="yolo",
+         enable_tracking=False,
+     )
+
+     if args.video_mode == "static":
+         settings = CompressionSettings(
+             mode="static",
+             quality_level=args.quality_level,
+             sigma=args.sigma,
+             output_fps=args.output_fps,
+             target_classes=args.seg_classes or [],
+             enable_tracking=False,
+         )
+         chunks = processor.process_static(before_path, settings)
+     else:  # dynamic
+         settings = CompressionSettings(
+             mode="dynamic",
+             target_bandwidth_kbps=args.target_bandwidth_kbps,
+             min_fps=args.min_fps,
+             max_fps=args.max_fps,
+             chunk_duration_sec=args.chunk_duration_sec,
+             target_classes=args.seg_classes or [],
+             quality_level=args.quality_level,
+             enable_tracking=False,
+         )
+         chunks = processor.process_dynamic(before_path, settings)
+
+     # Collect compressed frames
+     all_compressed_frames = []
+     for chunk in chunks:
+         all_compressed_frames.extend(chunk.frames)
+         if args.max_video_frames and len(all_compressed_frames) >= args.max_video_frames:
+             all_compressed_frames = all_compressed_frames[:args.max_video_frames]
+             break
+
+     print(f"Compressed {len(all_compressed_frames)} frames")
+
+     # Sample the after frames at the same interval as before frames
+     # Note: The compression may change frame count, so we align by time proportion
+     after_frames = []
+     if len(all_compressed_frames) >= len(before_frames):
+         # Sample at regular intervals
+         step = len(all_compressed_frames) / len(before_frames)
+         for i in range(len(before_frames)):
+             idx = min(int(i * step), len(all_compressed_frames) - 1)
+             after_frames.append(all_compressed_frames[idx])
+     else:
+         # Use all available frames
+         after_frames = all_compressed_frames
+         # Truncate before_frames to match
+         before_frames = before_frames[:len(after_frames)]
+
+     # Save compressed video if requested
+     after_path = args.save_after
+     if after_path:
+         video_bytes = frames_to_video_bytes(all_compressed_frames, fps=args.output_fps)
+         Path(after_path).parent.mkdir(parents=True, exist_ok=True)
+         with open(after_path, "wb") as f:
+             f.write(video_bytes)
+         print(f"Saved compressed video to {after_path}")
+     else:
+         # Create temp file
+         with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
+             after_path = tmp.name
+         video_bytes = frames_to_video_bytes(all_compressed_frames, fps=args.output_fps)
+         with open(after_path, "wb") as f:
+             f.write(video_bytes)
+
+     return after_path, before_frames, after_frames
+
+
+ def _evaluate_video_detections(
+     before_frames: List[Image.Image],
+     after_frames: List[Image.Image],
+     detector,
+     det_kwargs: Dict,
+     iou_threshold: float,
+ ) -> Dict:
+     """Evaluate detection across video frames.
+
+     Returns:
+         Summary dict with per-frame and aggregate stats
+     """
+     frame_results = []
+     total_before_dets = 0
+     total_after_dets = 0
+     total_matched = 0
+     total_lost = 0
+     total_new = 0
+
+     for i, (before_frame, after_frame) in enumerate(zip(before_frames, after_frames)):
+         before_dets = detector(before_frame, **det_kwargs)
+         after_dets = detector(after_frame, **det_kwargs)
+
+         summary = summarize_before_after(before_dets, after_dets, iou_threshold=iou_threshold)
+
+         # Extract values using correct keys from summarize_before_after
+         num_before = summary["num_before"]
+         num_after = summary["num_after"]
+         matched = summary["matched"]
+         lost = num_before - matched
+         new_dets = summary["new_after"]
+
+         frame_results.append({
+             "frame_index": i,
+             "before_count": num_before,
+             "after_count": num_after,
+             "matched": matched,
+             "lost": lost,
+             "new_detections": new_dets,
+         })
+
+         total_before_dets += num_before
+         total_after_dets += num_after
+         total_matched += matched
+         total_lost += lost
+         total_new += new_dets
+
+     retention_rate = total_matched / max(total_before_dets, 1)
+
+     return {
+         "total_frames": len(before_frames),
+         "total_before_detections": total_before_dets,
+         "total_after_detections": total_after_dets,
+         "total_matched": total_matched,
+         "total_lost": total_lost,
+         "total_new": total_new,
+         "retention_rate": retention_rate,
+         "avg_before_per_frame": total_before_dets / max(len(before_frames), 1),
+         "avg_after_per_frame": total_after_dets / max(len(after_frames), 1),
+         "frame_results": frame_results,
+     }
+
+
+ def _create_video_visualization(
+     before_frames: List[Image.Image],
+     after_frames: List[Image.Image],
+     before_dets_list: List[List],
+     after_dets_list: List[List],
+     output_dir: Path,
+     method: str,
+     fps: float = 10.0,
+ ) -> None:
+     """Create visualization videos with detections drawn."""
+     from video.video_processor import frames_to_video_bytes
+
+     # Draw detections on frames
+     before_viz = []
+     after_viz = []
+
+     for i, (bf, af) in enumerate(zip(before_frames, after_frames)):
+         b_dets = before_dets_list[i] if i < len(before_dets_list) else []
+         a_dets = after_dets_list[i] if i < len(after_dets_list) else []
+
+         before_viz.append(draw_detections(bf, b_dets, color=(0, 255, 0)))
+         after_viz.append(draw_detections(af, a_dets, color=(255, 0, 0)))
+
+     # Save visualization videos
+     before_bytes = frames_to_video_bytes(before_viz, fps=fps)
+     after_bytes = frames_to_video_bytes(after_viz, fps=fps)
+
+     with open(output_dir / f"{method}_before.mp4", "wb") as f:
+         f.write(before_bytes)
+     with open(output_dir / f"{method}_after.mp4", "wb") as f:
+         f.write(after_bytes)
+
+
+ def main_image(args) -> None:
+     """Run detection evaluation on images."""
+     before_img = _load_image(args.before)
+     after_img = _maybe_make_after_image(before_img, args)
+
+     det_methods = args.detectors
+     if len(det_methods) == 1 and det_methods[0].lower() == "all":
+         det_methods = get_available_detectors()
+
+     results: Dict = {
+         "type": "image",
+         "before": str(Path(args.before)),
+         "after": str(Path(args.after)) if args.after else None,
+         "det_conf": float(args.det_conf),
+         "iou_threshold": float(args.iou),
+         "open_vocab_classes": _normalize_open_vocab_classes(args.open_vocab_classes),
+         "detectors": {},
+     }
+
+     viz_dir = Path(args.viz_dir) if args.viz_dir else None
+     if viz_dir:
+         viz_dir.mkdir(parents=True, exist_ok=True)
+
+     for method in det_methods:
+         det_kwargs = _get_detector_kwargs(method, args)
+         detector = create_detector(method, device=args.device, **det_kwargs)
+
+         call_kwargs = {"conf_threshold": args.det_conf}
+         if method in {"yolo_world", "grounding_dino"}:
+             ov = _normalize_open_vocab_classes(args.open_vocab_classes)
+             if not ov:
+                 raise SystemExit(
+                     f"Detector '{method}' is open-vocabulary; provide --open-vocab-classes (e.g. 'person,car')."
+                 )
+             call_kwargs["classes"] = ov
+
+         before_dets = detector(before_img, **call_kwargs)
+         after_dets = detector(after_img, **call_kwargs)
+
+         summary = summarize_before_after(before_dets, after_dets, iou_threshold=args.iou)
+
+         results["detectors"][method] = {
+             "summary": summary,
+             "before_detections": detections_to_dict(before_dets),
+             "after_detections": detections_to_dict(after_dets),
+         }
+
+         if viz_dir:
+             b = draw_detections(before_img, before_dets, color=(0, 255, 0))
+             a = draw_detections(after_img, after_dets, color=(255, 0, 0))
+             b.save(viz_dir / f"{method}_before.png")
+             a.save(viz_dir / f"{method}_after.png")
+
+     out_path = Path(args.out)
+     out_path.parent.mkdir(parents=True, exist_ok=True)
+     out_path.write_text(json.dumps(results, indent=2))
+
+     print(f"Wrote: {out_path}")
+     if viz_dir:
+         print(f"Wrote visualizations to: {viz_dir}")
+
+
+ def main_video(args) -> None:
+     """Run detection evaluation on videos."""
+     print(f"Video detection evaluation: {args.before}")
+
+     after_path, before_frames, after_frames = _maybe_make_after_video(args.before, args)
+
+     det_methods = args.detectors
+     if len(det_methods) == 1 and det_methods[0].lower() == "all":
+         det_methods = get_available_detectors()
+
+     results: Dict = {
+         "type": "video",
+         "before": str(Path(args.before)),
+         "after": str(Path(after_path)),
+         "det_conf": float(args.det_conf),
+         "iou_threshold": float(args.iou),
+         "open_vocab_classes": _normalize_open_vocab_classes(args.open_vocab_classes),
+         "video_settings": {
+             "mode": args.video_mode,
+             "sigma": args.sigma,
+             "output_fps": args.output_fps,
+             "quality_level": args.quality_level,
+             "frames_evaluated": len(before_frames),
+         },
+         "detectors": {},
+     }
+
+     viz_dir = Path(args.viz_dir) if args.viz_dir else None
+     if viz_dir:
+         viz_dir.mkdir(parents=True, exist_ok=True)
+
+     for method in det_methods:
+         print(f"Evaluating with detector: {method}")
+         det_kwargs = _get_detector_kwargs(method, args)
+         detector = create_detector(method, device=args.device, **det_kwargs)
+
+         call_kwargs = {"conf_threshold": args.det_conf}
+         if method in {"yolo_world", "grounding_dino"}:
+             ov = _normalize_open_vocab_classes(args.open_vocab_classes)
+             if not ov:
+                 raise SystemExit(
+                     f"Detector '{method}' is open-vocabulary; provide --open-vocab-classes (e.g. 'person,car')."
+                 )
+             call_kwargs["classes"] = ov
+
+         # Evaluate across all frames
+         video_summary = _evaluate_video_detections(
+             before_frames,
+             after_frames,
+             detector,
+             call_kwargs,
+             args.iou,
+         )
+
+         results["detectors"][method] = {
+             "summary": {
+                 "total_frames": video_summary["total_frames"],
+                 "retention_rate": video_summary["retention_rate"],
+                 "total_before_detections": video_summary["total_before_detections"],
+                 "total_after_detections": video_summary["total_after_detections"],
+                 "total_matched": video_summary["total_matched"],
+                 "total_lost": video_summary["total_lost"],
+                 "avg_before_per_frame": video_summary["avg_before_per_frame"],
+                 "avg_after_per_frame": video_summary["avg_after_per_frame"],
+             },
+             "per_frame_results": video_summary["frame_results"],
+         }
+
+         print(f" Retention rate: {video_summary['retention_rate']:.2%}")
+         print(f" Avg detections: {video_summary['avg_before_per_frame']:.1f} before, {video_summary['avg_after_per_frame']:.1f} after")
+
+         if viz_dir:
+             # Create visualization videos
+             print(" Creating visualization videos...")
+             before_dets_list = []
+             after_dets_list = []
+             for bf, af in zip(before_frames, after_frames):
+                 before_dets_list.append(detector(bf, **call_kwargs))
+                 after_dets_list.append(detector(af, **call_kwargs))
+
+             _create_video_visualization(
+                 before_frames,
+                 after_frames,
+                 before_dets_list,
+                 after_dets_list,
+                 viz_dir,
+                 method,
+                 fps=args.output_fps,
+             )
+
+     out_path = Path(args.out)
+     out_path.parent.mkdir(parents=True, exist_ok=True)
+     out_path.write_text(json.dumps(results, indent=2))
+
+     print(f"\nWrote: {out_path}")
+     if viz_dir:
+         print(f"Wrote visualizations to: {viz_dir}")
+
+
+ def _get_detector_kwargs(method: str, args) -> Dict:
+     """Get detector-specific kwargs."""
+     det_kwargs = {}
+     if method == "yolo":
+         det_kwargs["model_path"] = args.yolo_weights
+     if method == "yolo_world":
+         det_kwargs["model_path"] = args.yolo_world_weights
+     if method == "efficientdet":
+         det_kwargs["model_name"] = args.efficientdet_name
+     if method == "detr":
+         det_kwargs["model_name"] = args.detr_model
+     if method == "deformable_detr":
+         det_kwargs["model_name"] = args.deformable_detr_model
+     if method == "grounding_dino":
+         det_kwargs["model_name"] = args.grounding_dino_model
+     return det_kwargs
+
+
+ def main() -> None:
+     import argparse
+
+     from model_cache import default_checkpoint_path
+
+     parser = argparse.ArgumentParser(
+         description="Evaluate object detection before vs after ROI compression (images or videos)"
+     )
+
+     # Input/Output
+     parser.add_argument("--before", required=True, help="Path to original (before) image or video")
+     parser.add_argument("--after", help="Path to already-compressed (after) image or video")
+     parser.add_argument("--out", default="results/detection_eval.json", help="Where to write JSON results")
+     parser.add_argument("--viz-dir", default=None, help="If set, write visualization images/videos here")
+     parser.add_argument("--save-after", help="Save generated after image/video here")
+
+     # ROI compression (if --after is not provided)
+     parser.add_argument("--checkpoint", help="TIC checkpoint to generate after image (images only)")
+     parser.add_argument("--sigma", type=float, default=0.3, help="Background quality for ROI compression")
+     parser.add_argument("--mask", help="Optional mask path to use for ROI compression (images only)")
+     parser.add_argument("--seg-method", default=None, help="Segmentation method to build mask")
+     parser.add_argument("--seg-classes", nargs="+", default=None, help="Segmentation classes")
+     parser.add_argument("--seg-prompt", default=None, help="Segmentation prompt (sam3)")
+     parser.add_argument("--seg-conf", type=float, default=0.25, help="Segmentation conf for yolo")
+     parser.add_argument("--seg-model-type", default="coco", help="mask2former model_type")
+
+     # Video-specific settings
+     parser.add_argument("--video-mode", choices=["static", "dynamic"], default="static",
+                         help="Video compression mode")
+     parser.add_argument("--quality-level", type=int, default=4, help="Quality level (1-5)")
+     parser.add_argument("--output-fps", type=float, default=10.0, help="Output FPS for compressed video")
+     parser.add_argument("--target-bandwidth-kbps", type=float, default=500.0,
+                         help="Target bandwidth for dynamic mode")
+     parser.add_argument("--min-fps", type=float, default=5.0, help="Min FPS for dynamic mode")
+     parser.add_argument("--max-fps", type=float, default=30.0, help="Max FPS for dynamic mode")
+     parser.add_argument("--chunk-duration-sec", type=float, default=1.0, help="Chunk duration for dynamic mode")
+     parser.add_argument("--max-video-frames", type=int, default=100,
+                         help="Max frames to evaluate from video (for efficiency)")
+     parser.add_argument("--video-sample-interval", type=int, default=5,
+                         help="Sample every Nth frame from videos")
+
+     # Detection settings
+     parser.add_argument(
+         "--detectors",
+         nargs="+",
+         default=["yolo"],
+         help=f"Detectors to run (or 'all'). Available: {', '.join(get_available_detectors())}",
+     )
+     parser.add_argument("--det-conf", type=float, default=0.25, help="Detection confidence threshold")
+     parser.add_argument("--iou", type=float, default=0.5, help="IoU threshold for matching before↔after detections")
+
+     # Open-vocabulary detection
+     parser.add_argument(
+         "--open-vocab-classes",
+         default=None,
+         help="Comma-separated class prompts for open-vocabulary detectors",
+     )
+
+     parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="Device")
+     parser.add_argument("--N", type=int, default=192)
+     parser.add_argument("--M", type=int, default=192)
+
+     # Per-detector config
+     parser.add_argument(
+         "--yolo-weights",
+         default=default_checkpoint_path("yolo26x.pt"),
+         help="Ultralytics YOLO weights path/name",
+     )
+     parser.add_argument(
+         "--yolo-world-weights",
+         default=default_checkpoint_path("yolo26s-world.pt"),
+         help="Ultralytics YOLO-World weights path/name",
+     )
+     parser.add_argument("--efficientdet-name", default="tf_efficientdet_d0", help="EfficientDet model name")
+     parser.add_argument("--detr-model", default="facebook/detr-resnet-50", help="DETR model name")
+     parser.add_argument("--deformable-detr-model", default="SenseTime/deformable-detr", help="Deformable DETR model")
+     parser.add_argument(
+         "--grounding-dino-model",
+         default="IDEA-Research/grounding-dino-base",
+         help="GroundingDINO model name",
+     )
+
+     args = parser.parse_args()
+
+     # Determine if input is image or video
+     if _is_video(args.before):
+         main_video(args)
+     else:
+         main_image(args)
+
+
+ if __name__ == "__main__":
+     main()
roi_segmenter.py ADDED
@@ -0,0 +1,355 @@
+ """
+ CLI for ROI segmentation using modular segmentation framework.
+ Supports both image and video input with batched processing.
+ """
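+ # Example invocations (file names are illustrative):
+ #   python roi_segmenter.py --input street.jpg --output mask.png \
+ #       --method yolo --classes car person --visualize
+ #   python roi_segmenter.py --input drive.mp4 --output masks.mp4 \
+ #       --method yolo --classes car --resize-height 480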
+
+ import sys
+ import time
+ from pathlib import Path
+ import numpy as np
+ from PIL import Image
+ import torch
+
+ from segmentation import create_segmenter
+ from segmentation.utils import visualize_mask, save_mask, calculate_roi_stats
+ from video import estimate_batch_sizes, smooth_masks_sdf
+
+
+ VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.webm'}
+
+
+ def is_video(path: str) -> bool:
+     """Check if file is a video based on extension."""
+     return Path(path).suffix.lower() in VIDEO_EXTENSIONS
+
+
+ def extract_video_frames(video_path: str, target_height: int = None) -> tuple[list[Image.Image], float]:
+     """Extract frames from video file.
+
+     Args:
+         video_path: Path to video file
+         target_height: Optional height to resize frames to (maintains aspect ratio)
+
+     Returns:
+         (frames, fps)
+     """
+     try:
+         import cv2
+     except ImportError:
+         raise ImportError("opencv-python required for video processing. Install with: pip install opencv-python")
+
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise ValueError(f"Cannot open video: {video_path}")
+
+     fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+     frames = []
+
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         if target_height:
+             h, w = frame.shape[:2]
+             scale = target_height / h
+             new_w = max(1, int(w * scale))
+             frame = cv2.resize(frame, (new_w, target_height))
+
+         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         frames.append(Image.fromarray(frame_rgb))
+
+     cap.release()
+     return frames, fps
+
+
+ def save_masks_as_video(masks: list[np.ndarray], output_path: str, fps: float = 30.0) -> None:
+     """Save mask sequence as video file."""
+     try:
+         import cv2
+     except ImportError:
+         raise ImportError("opencv-python required. Install with: pip install opencv-python")
+
+     if not masks:
+         raise ValueError("No masks to save")
+
+     h, w = masks[0].shape
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h), isColor=False)
+
+     for mask in masks:
+         mask_uint8 = (mask * 255).astype(np.uint8)
+         writer.write(mask_uint8)
+
+     writer.release()
+
+
+ def process_image(args):
+     """Process single image input."""
+     print(f"Segmenting {args.input}...")
+     print(f" Method: {args.method}")
+     if args.method == 'sam3':
+         print(f" Prompt: {args.prompt or ' '.join(args.classes)}")
+     else:
+         print(f" Classes: {args.classes}")
+
+     # Load image
+     image = Image.open(args.input).convert('RGB')
+     print(f" Image size: {image.size}")
+
+     # Create segmenter and mask
+     seg_kwargs = {}
+     if args.method == 'yolo':
+         seg_kwargs['conf_threshold'] = args.conf_threshold
+
+     segmenter = create_segmenter(args.method, device=args.device, **seg_kwargs)
+
+     targets = args.classes
+     if args.method == 'sam3' and args.prompt:
+         targets = [args.prompt]
+
+     mask = segmenter(image, target_classes=targets)
+
+     # Calculate statistics
+     stats = calculate_roi_stats(mask)
+     print(f" ROI coverage: {stats['roi_percentage']:.2f}% "
+           f"({stats['roi_pixels']}/{stats['total_pixels']} pixels)")
+
+     # Save mask
+     save_mask(mask, args.output)
+     print(f" Mask saved: {args.output}")
+
+     # Save visualization if requested
+     if args.visualize:
+         # Insert "_overlay" before the file extension (a plain str.replace on '.'
+         # would also rewrite dots elsewhere in the path)
+         out_path = Path(args.output)
+         viz_path = str(out_path.with_name(f"{out_path.stem}_overlay{out_path.suffix}"))
+         viz_img = visualize_mask(image, mask, alpha=0.5, color=(255, 0, 0))
+         viz_img.save(viz_path)
+         print(f" Visualization saved: {viz_path}")
+
+     print("Done!")
+
+
+ def process_video(args):
+     """Process video input with batched segmentation."""
+     print(f"Processing video: {args.input}")
+     print(f" Method: {args.method}")
+     print(f" Classes: {args.classes}")
+
+     # Extract frames
+     print("Extracting frames...")
+     frames, fps = extract_video_frames(args.input, target_height=args.resize_height)
+     print(f" Extracted {len(frames)} frames at {fps:.2f} FPS")
+
+     # Create segmenter
+     seg_kwargs = {}
+     if args.method == 'yolo':
+         seg_kwargs['conf_threshold'] = args.conf_threshold
+
+     segmenter = create_segmenter(args.method, device=args.device, **seg_kwargs)
+
+     targets = args.classes
+     if args.method == 'sam3' and args.prompt:
+         targets = [args.prompt]
+
+     # Check if batched segmentation is supported
+     supports_batch = getattr(segmenter, 'supports_batch', False)
+
+     print(f" Segmenter supports batching: {supports_batch}")
+
+     # Estimate optimal batch size if GPU is being used
+     if args.device == 'cuda' and supports_batch:
+         try:
+             # Get first frame dimensions for estimation
+             sample_w, sample_h = frames[0].size
+             batch_info = estimate_batch_sizes(
+                 device=args.device,
+                 seg_method=args.method,
+                 frame_width=sample_w,
+                 frame_height=sample_h,
+                 total_frames=len(frames),
+             )
+             recommended_batch = batch_info.seg_batch_size
+             print(" GPU Memory Estimation:")
+             print(f" Free VRAM: {batch_info.free_vram_bytes / (1024**3):.2f} GB")
+             print(f" Recommended batch size: {recommended_batch}")
+             if batch_info.notes:
+                 print(f" {batch_info.notes}")
+         except Exception as e:
+             print(f" Warning: Could not estimate GPU batch size: {e}")
+             recommended_batch = None
+     else:
+         recommended_batch = None
+
+     # Segment frames with OOM retry logic
+     t0 = time.perf_counter()
+     if supports_batch and hasattr(segmenter, 'segment_batch'):
+         print(f" Segmenting {len(frames)} frames in batches...")
+
+         # OOM retry logic
+         max_retries = 7
+         retry_count = 0
+         masks = None
+         current_batch_size = recommended_batch  # None means use segmenter's default
+
+         while retry_count <= max_retries:
+             try:
+                 if current_batch_size is not None:
+                     # Segment in manual batches
+                     masks = []
+                     for i in range(0, len(frames), current_batch_size):
+                         batch = frames[i:i + current_batch_size]
+                         print(f" Batch {i//current_batch_size + 1}: frames {i}-{i+len(batch)-1} (batch_size={len(batch)})")
+                         batch_masks = segmenter.segment_batch(batch, target_classes=targets)
+                         masks.extend(batch_masks)
+                 else:
+                     # Let segmenter handle batching internally
+                     masks = segmenter.segment_batch(frames, target_classes=targets)
+
+                 # Success - break retry loop
+                 break
+
+             except torch.cuda.OutOfMemoryError:
+                 retry_count += 1
+                 if retry_count > max_retries:
+                     print(f" ERROR: Out of memory after {max_retries} retries. Try reducing --resize-height.")
+                     raise
+
+                 # Halve the batch size
+                 if current_batch_size is None:
+                     # Initial OOM with default batching - start with a reasonable size
+                     current_batch_size = max(1, len(frames) // 4)
+                 else:
+                     current_batch_size = max(1, current_batch_size // 2)
+
+                 print(f" Out of memory! Retry {retry_count}/{max_retries} with batch_size={current_batch_size}")
+                 torch.cuda.empty_cache()
+                 masks = None  # Reset
+
+         if masks is None:
+             raise RuntimeError("Segmentation failed after all retries")
+     else:
+         print(f" Segmenting {len(frames)} frames sequentially...")
+         masks = []
+         for i, frame in enumerate(frames):
+             if (i + 1) % 10 == 0 or i == 0:
+                 print(f" Frame {i+1}/{len(frames)}")
+             masks.append(segmenter(frame, target_classes=targets))
+
+     t1 = time.perf_counter()
+     total_time = t1 - t0
+
+     print(f" Total segmentation time: {total_time:.3f} s")
+     print(f" Average per-frame: {total_time/len(frames):.4f} s ({len(frames)/total_time:.2f} fps)")
+
+     # Apply SDF temporal smoothing to reduce jitter
+     if not args.no_smooth and len(masks) > 2:
+         print(f" Applying SDF temporal smoothing (alpha={args.smooth_alpha}, patience={args.smooth_patience})...")
+         t_smooth_start = time.perf_counter()
+         masks = smooth_masks_sdf(
+             masks,
+             alpha=args.smooth_alpha,
+             empty_thresh=10,
+             patience=args.smooth_patience,
+         )
+         t_smooth_end = time.perf_counter()
+         print(f" Smoothing time: {t_smooth_end - t_smooth_start:.3f} s")
+
+     # Calculate aggregate statistics
+     total_roi_pixels = sum(m.sum() for m in masks)
+     total_pixels = sum(m.size for m in masks)
+     avg_coverage = (total_roi_pixels / total_pixels * 100) if total_pixels > 0 else 0.0
+     print(f" Average ROI coverage: {avg_coverage:.2f}%")
+
+     # Save output
+     output_path = Path(args.output)
+     if args.save_frames:
+         # Save individual mask frames
+         output_dir = output_path.parent / f"{output_path.stem}_frames"
+         output_dir.mkdir(parents=True, exist_ok=True)
+         print(f" Saving {len(masks)} mask frames to {output_dir}/")
+         for i, mask in enumerate(masks):
+             frame_path = output_dir / f"mask_{i:06d}.png"
+             save_mask(mask, str(frame_path))
+         print(f" Saved {len(masks)} frames")
+     else:
+         # Save as video
+         print(f" Saving mask video to {args.output}")
+         save_masks_as_video(masks, str(output_path), fps=fps)
+         print(" Saved mask video")
+
+     # Save visualization if requested
+     if args.visualize:
+         viz_dir = output_path.parent / f"{output_path.stem}_viz"
+         viz_dir.mkdir(parents=True, exist_ok=True)
+         print(" Creating visualization video...")
+
+         viz_frames = [visualize_mask(frame, mask, alpha=0.5, color=(255, 0, 0))
+                       for frame, mask in zip(frames, masks)]
+
+         # Save as video
+         try:
+             import cv2
+             viz_path = viz_dir / "overlay.mp4"
+             w, h = viz_frames[0].size  # PIL size is (width, height)
+             fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+             writer = cv2.VideoWriter(str(viz_path), fourcc, fps, (w, h), isColor=True)
+
+             for vf in viz_frames:
+                 frame_np = np.array(vf)
+                 frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
+                 writer.write(frame_bgr)
+
+             writer.release()
+             print(f" Visualization saved: {viz_path}")
+         except Exception as e:
+             print(f" Warning: Could not save visualization video: {e}")
+
+     print("Done!")
+
+
+ # Command-line interface
+ if __name__ == '__main__':
+     import argparse
+     from segmentation import get_available_methods
+
+     # Get available segmentation methods dynamically
+     available_methods = get_available_methods()
+
+     parser = argparse.ArgumentParser(description='Segment objects in images or videos')
+     parser.add_argument('--input', required=True, help='Input image or video path')
+     parser.add_argument('--output', required=True, help='Output mask path (image/video) or directory')
+     parser.add_argument('--method', default='yolo',
+                         choices=available_methods,
+                         help='Segmentation method (including fake_* methods for detection+tracking)')
+     parser.add_argument('--classes', nargs='+', default=['car'],
+                         help='Target classes to segment')
+     parser.add_argument('--prompt', type=str, default=None,
+                         help='Natural language prompt (use with --method sam3)')
+     parser.add_argument('--conf-threshold', type=float, default=0.25,
+                         help='Confidence threshold')
+     parser.add_argument('--visualize', action='store_true',
+                         help='Save visualization with overlay')
+     parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'],
+                         help='Device to run on')
+
+     # Video-specific options
+     parser.add_argument('--resize-height', type=int, default=None,
+                         help='Resize video frames to this height (maintains aspect ratio)')
+     parser.add_argument('--save-frames', action='store_true',
+                         help='For videos: save masks as individual frames instead of video')
+
+     # Temporal smoothing options
+     parser.add_argument('--no-smooth', action='store_true',
+                         help='Disable temporal smoothing (may cause jitter)')
+     parser.add_argument('--smooth-alpha', type=float, default=0.5,
+                         help='SDF smoothing factor (0.1=slow/viscous, 0.9=fast/reactive)')
+     parser.add_argument('--smooth-patience', type=int, default=5,
+                         help='Frames to tolerate dropouts before decay (0=immediate, 5=conservative, 15=aggressive)')
+
+     args = parser.parse_args()
+
+     # Determine if input is image or video
+     if is_video(args.input):
+         process_video(args)
+     else:
+         process_image(args)
segmentation/__init__.py ADDED
@@ -0,0 +1,33 @@
+ """
+ ROI Segmentation Module
+
+ Provides abstract base class and concrete implementations for various
+ segmentation models (YOLO, SegFormer, Mask2Former, Mask R-CNN, SAM3, etc.) used in ROI-based compression.
+ """
+
+ from .base import BaseSegmenter
+ from .segformer import SegFormerSegmenter
+ from .yolo import YOLOSegmenter
+ from .mask2former import Mask2FormerSegmenter
+ from .maskrcnn import MaskRCNNSegmenter
+ from .sam3 import SAM3Segmenter
+ from .fake import FakeSegmenter
+ from .factory import create_segmenter, get_available_methods, register_segmenter
+ from .utils import visualize_mask, save_mask, load_mask, calculate_roi_stats
+
+ __all__ = [
+     'BaseSegmenter',
+     'SegFormerSegmenter',
+     'YOLOSegmenter',
+     'Mask2FormerSegmenter',
+     'MaskRCNNSegmenter',
+     'SAM3Segmenter',
+     'FakeSegmenter',
+     'create_segmenter',
+     'get_available_methods',
+     'register_segmenter',
+     'visualize_mask',
+     'save_mask',
+     'load_mask',
+     'calculate_roi_stats',
+ ]
segmentation/base.py ADDED
@@ -0,0 +1,169 @@
+ """
+ Abstract base class for segmentation models.
+ """
+
+ from abc import ABC, abstractmethod
+ from typing import List, Optional, Union, Dict, Any
+ import numpy as np
+ from PIL import Image
+
+
+ class BaseSegmenter(ABC):
+     """
+     Abstract base class for all segmentation models.
+
+     This class defines the common interface that all segmentation models
+     must implement. Subclasses can handle different types of inputs:
+     - Class-based segmentation (YOLO, SegFormer): List of class names
+     - Natural language segmentation (SAM, CLIP-based): Text prompts
+     - Point/box-based segmentation (SAM): Coordinates
+     """
+
+     def __init__(self, device: str = 'cuda', **kwargs):
+         """
+         Initialize the segmenter.
+
+         Args:
+             device: Device to run inference on ('cuda' or 'cpu')
+             **kwargs: Model-specific parameters
+         """
+         self.device = device
+         self.model = None
+         self._is_loaded = False
+
+     @abstractmethod
+     def load_model(self):
+         """
+         Load the segmentation model.
+
+         This method should load the model weights and prepare the model
+         for inference. Called automatically before first use.
+         """
+         pass
+
+     @abstractmethod
+     def segment(
+         self,
+         image: Image.Image,
+         target_classes: Optional[List[str]] = None,
+         **kwargs
+     ) -> np.ndarray:
+         """
+         Create binary segmentation mask for ROI.
+
+         Args:
+             image: PIL Image to segment
+             target_classes: List of target classes or text prompts
+             **kwargs: Model-specific parameters (e.g., confidence threshold)
+
+         Returns:
+             Binary mask as numpy array (H, W) with values 0 or 1
+             - 1: Region of Interest (ROI)
+             - 0: Background
+         """
+         pass
+
+     @abstractmethod
+     def get_available_classes(self) -> Union[List[str], Dict[str, int]]:
+         """
+         Get list or mapping of classes this model can segment.
+
+         Returns:
+             List of class names or dict mapping class names to IDs
+         """
+         pass
+
+     def validate_classes(self, target_classes: Optional[List[str]]) -> List[str]:
+         """
+         Validate and filter target classes against available classes.
+
+         Args:
+             target_classes: List of requested class names
+
+         Returns:
+             List of valid class names
+         """
+         if target_classes is None:
+             return self.get_default_classes()
+
+         available_classes = self.get_available_classes()
+         if isinstance(available_classes, dict):
+             available_classes = list(available_classes.keys())
+
+         valid_classes = []
+         for cls in target_classes:
+             cls_lower = cls.lower()
+             if cls_lower in [c.lower() for c in available_classes]:
+                 valid_classes.append(cls)
+             else:
+                 print(f"Warning: '{cls}' not in {self.__class__.__name__} classes.")
+
+         if not valid_classes:
+             print("Warning: No valid classes found. Using defaults.")
+             valid_classes = self.get_default_classes()
+
+         return valid_classes
+
+     def segment_batch(
+         self,
+         images: List[Image.Image],
+         target_classes: Optional[List[str]] = None,
+         **kwargs
+     ) -> List[np.ndarray]:
+         """
+         Segment a batch of images.
+
+         Default implementation processes images sequentially. Subclasses
+         should override this with a true batched forward pass when the
+         underlying model supports it.
+
+         Args:
+             images: List of PIL Images (ideally same resolution)
+             target_classes: List of target classes or text prompts
+             **kwargs: Model-specific parameters
+
+         Returns:
+             List of binary masks (H, W) - one per input image
+         """
+         self.ensure_loaded()
+         return [self.segment(img, target_classes, **kwargs) for img in images]
+
+     # Whether the model supports true GPU-batched inference.
+     # Subclasses should set this to True if segment_batch uses a
+     # single forward pass rather than a loop.
+     supports_batch: bool = False
+
+     def get_default_classes(self) -> List[str]:
+         """
+         Get default classes to segment if none specified.
+
+         Returns:
+             List of default class names
+         """
+         return ['car']  # Default fallback
+
+     def ensure_loaded(self):
+         """Ensure model is loaded before use."""
+         if not self._is_loaded:
+             self.load_model()
+             self._is_loaded = True
+
+     def __call__(
+         self,
+         image: Image.Image,
+         target_classes: Optional[List[str]] = None,
+         **kwargs
+     ) -> np.ndarray:
+         """
+         Convenience method to call segment().
+
+         Args:
+             image: PIL Image to segment
+             target_classes: List of target classes or text prompts
+             **kwargs: Model-specific parameters
+
+         Returns:
+             Binary mask as numpy array (H, W)
+         """
+         self.ensure_loaded()
+         return self.segment(image, target_classes, **kwargs)
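+
+
+ # Minimal subclass sketch (illustrative only; ThresholdSegmenter is not a
+ # shipped segmenter, just a demonstration of the required interface):
+ #
+ #     class ThresholdSegmenter(BaseSegmenter):
+ #         def load_model(self):
+ #             pass  # nothing to load
+ #
+ #         def segment(self, image, target_classes=None, **kwargs):
+ #             gray = np.array(image.convert('L'), dtype=np.float32) / 255.0
+ #             return (gray > 0.5).astype(np.float32)  # bright pixels as ROI
+ #
+ #         def get_available_classes(self):
+ #             return ['foreground']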
segmentation/factory.py ADDED
@@ -0,0 +1,124 @@
+ """
+ Factory for creating segmentation models.
+ """
+
+ from typing import Dict, Type, Optional, List, Callable, Union
+ from .base import BaseSegmenter
+ from .segformer import SegFormerSegmenter
+ from .yolo import YOLOSegmenter
+ from .mask2former import Mask2FormerSegmenter
+ from .maskrcnn import MaskRCNNSegmenter
+ from .sam3 import SAM3Segmenter
+ from .fake import FakeSegmenter
+
+
+ # Registry of available segmentation methods
+ # Can be either a class or a factory function
+ SEGMENTER_REGISTRY: Dict[str, Union[Type[BaseSegmenter], Callable]] = {
+     'segformer': SegFormerSegmenter,
+     'yolo': YOLOSegmenter,
+     'mask2former': Mask2FormerSegmenter,
+     'maskrcnn': MaskRCNNSegmenter,
+     'sam3': SAM3Segmenter,
+ }
+
+
+ def _register_fake_methods():
+     """Register fake segmentation methods (detection + tracking → bbox masks)."""
+     fake_configs = [
+         ('fake_yolo', 'yolo', 'bytetrack'),
+         ('fake_yolo_botsort', 'yolo', 'botsort'),
+         ('fake_detr', 'detr', 'bytetrack'),
+         ('fake_deformable_detr', 'deformable_detr', 'bytetrack'),
+         ('fake_fasterrcnn', 'fasterrcnn', 'bytetrack'),
+         ('fake_retinanet', 'retinanet', 'bytetrack'),
+         ('fake_fcos', 'fcos', 'bytetrack'),
+         ('fake_grounding_dino', 'grounding_dino', 'bytetrack'),
+     ]
+
+     for name, detector, tracker in fake_configs:
+         # Create a factory function for each config
+         # Use default arguments to capture values properly in closure
+         def make_factory(det=detector, default_tracker=tracker):
+             def factory(**kwargs):
+                 # Allow overriding tracker_type, otherwise use default
+                 if 'tracker_type' not in kwargs:
+                     kwargs['tracker_type'] = default_tracker
+                 return FakeSegmenter(detector_name=det, **kwargs)
+             return factory
+
+         SEGMENTER_REGISTRY[name] = make_factory()
+
+
+ # Register fake methods on import
+ _register_fake_methods()
+
+
+ def register_segmenter(name: str, segmenter_class: Type[BaseSegmenter]):
+     """
+     Register a new segmentation method.
+
+     Args:
+         name: Method name (e.g., 'sam', 'drone_detector')
+         segmenter_class: Segmenter class that extends BaseSegmenter
+     """
+     if not issubclass(segmenter_class, BaseSegmenter):
+         raise ValueError(f"{segmenter_class} must extend BaseSegmenter")
+     SEGMENTER_REGISTRY[name.lower()] = segmenter_class
+
+
+ def create_segmenter(
+     method: str,
+     device: str = 'cuda',
+     **kwargs
+ ) -> BaseSegmenter:
+     """
+     Factory function to create a segmentation model.
+
+     Args:
+         method: Segmentation method name ('segformer', 'yolo', 'fake_yolo', etc.)
+         device: Device to run on ('cuda' or 'cpu')
+         **kwargs: Method-specific parameters
+
+     Returns:
+         Instance of the requested segmenter
+
+     Raises:
+         ValueError: If method is not recognized
+
+     Example:
+         >>> segmenter = create_segmenter('yolo', device='cuda', conf_threshold=0.3)
+         >>> mask = segmenter(image, target_classes=['car', 'person'])
+
+         >>> # Use detection-based fake segmentation with tracking
+         >>> fake_seg = create_segmenter('fake_yolo', device='cuda')
+         >>> mask = fake_seg(image, target_classes=['person'])
+     """
+     method_lower = method.lower()
+
+     if method_lower not in SEGMENTER_REGISTRY:
+         available = ', '.join(sorted(SEGMENTER_REGISTRY.keys()))
+         raise ValueError(
+             f"Unknown segmentation method: '{method}'. "
+             f"Available methods: {available}"
+         )
+
+     factory = SEGMENTER_REGISTRY[method_lower]
+
+     # Handle both class constructors and factory functions
+     if callable(factory) and not isinstance(factory, type):
+         # It's a factory function (for fake segmenters)
+         return factory(device=device, **kwargs)
+     else:
+         # It's a class constructor
+         return factory(device=device, **kwargs)
+
+
+ def get_available_methods() -> List[str]:
+     """
+     Get list of available segmentation methods.
+
+     Returns:
+         List of method names
+     """
+     return list(SEGMENTER_REGISTRY.keys())
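+
+
+ # Example: plugging in a custom segmenter (illustrative; MySegmenter is a
+ # hypothetical class that extends BaseSegmenter):
+ #   from segmentation import register_segmenter, create_segmenter
+ #   register_segmenter('my_method', MySegmenter)
+ #   seg = create_segmenter('my_method', device='cpu')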
segmentation/fake.py ADDED
@@ -0,0 +1,412 @@
1
+ """Detection-based 'fake' segmentation using bounding boxes + object tracking.
2
+
3
+ Creates rectangular masks from detection bounding boxes and maintains object
4
+ identity across frames using tracking (ByteTrack, BoTSORT, or SimpleTracker).
5
+
6
+ Now supports batch detection for efficiency: detects all frames in batches,
7
+ then runs tracking sequentially on the results.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import numpy as np
13
+ from PIL import Image
14
+ from typing import List, Optional, Dict, Any
15
+
16
+ from .base import BaseSegmenter
17
+
18
+
19
+ class FakeSegmenter(BaseSegmenter):
20
+ """
21
+ Detection-based segmentation that creates rectangular masks from bboxes.
22
+
23
+ Uses object tracking to maintain consistent masks across video frames:
24
+ - ByteTrack (default for YOLO)
25
+ - BoTSORT (available for YOLO)
26
+ - SimpleTracker (fallback for non-YOLO detectors)
27
+
28
+ This is useful for fast ROI extraction when pixel-perfect masks aren't needed.
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ device: str = 'cuda',
34
+ detector_name: str = 'yolo',
35
+ tracker_type: str = 'bytetrack',
36
+ conf_threshold: float = 0.25,
37
+ model_path: Optional[str] = None,
38
+ **kwargs
39
+ ):
40
+ """
41
+ Initialize fake segmenter.
42
+
43
+ Args:
44
+ device: Device to run on ('cuda' or 'cpu')
45
+ detector_name: Detection method ('yolo', 'detr', 'faster_rcnn', etc.)
46
+ tracker_type: Tracker type ('bytetrack', 'botsort', 'simple')
47
+ conf_threshold: Confidence threshold for detections
48
+ model_path: Optional path to detector weights
49
+ **kwargs: Additional parameters
50
+ """
51
+ super().__init__(device=device, **kwargs)
52
+ self.detector_name = detector_name.lower()
53
+ self.tracker_type = tracker_type.lower()
54
+ self.conf_threshold = conf_threshold
55
+ self.model_path = model_path
56
+
57
+ # Detector (set in load_model())
58
+ self.detector = None
59
+
60
+ # State for tracking
61
+ self._tracker = None
62
+ self._frame_count = 0
63
+
64
+ # Batch support
65
+ self.supports_batch = True
66
+
67
+ def load_model(self):
68
+ """Load detector model and initialize tracker."""
69
+ from detection import create_detector
70
+
71
+ # Create detector
72
+ if self.model_path:
73
+ self.detector = create_detector(
74
+ self.detector_name,
75
+ device=self.device,
76
+ model_path=self.model_path,
77
+ )
78
+ else:
79
+ self.detector = create_detector(
80
+ self.detector_name,
81
+ device=self.device,
82
+ )
83
+
84
+ # Ensure detector model is loaded
85
+ if not self.detector._is_loaded:
86
+ self.detector.load_model()
87
+
88
+ # Initialize tracker (now all trackers work with any detector!)
89
+ if self.tracker_type == 'bytetrack':
90
+ from detection.bytetrack import ByteTracker
91
+ self._tracker = ByteTracker(
92
+ track_thresh=0.5,
93
+ match_thresh=0.8,
94
+ track_buffer=30,
95
+ frame_rate=30,
96
+ )
97
+ elif self.tracker_type == 'botsort':
98
+ from detection.bytetrack import BoTSORT
99
+ self._tracker = BoTSORT(
100
+ track_thresh=0.5,
101
+ match_thresh=0.8,
102
+ track_buffer=30,
103
+ frame_rate=30,
104
+ )
105
+ else: # simple
106
+ from detection.tracker import SimpleTracker
107
+ self._tracker = SimpleTracker(
108
+ iou_threshold=0.3,
109
+ max_age=30,
110
+ min_hits=1,
111
+ label_match=True,
112
+ )
113
+
114
+ print(f"Loaded FakeSegmenter: {self.detector_name} + {self.tracker_type} tracking")
115
+
116
+ def reset_tracking(self):
117
+ """Reset tracker state (call between videos)."""
118
+ self._frame_count = 0
119
+ if self._tracker is not None:
120
+ self._tracker.reset()
121
+
122
+ def _create_bbox_mask(
123
+ self,
124
+ width: int,
125
+ height: int,
126
+ detections: List[Dict],
127
+ ) -> np.ndarray:
128
+ """Create binary mask from bounding boxes.
129
+
130
+ Args:
131
+ width: Image width
132
+ height: Image height
133
+ detections: List of detections with 'bbox_xyxy' key
134
+
135
+ Returns:
136
+ Binary mask (H, W) with 1.0 where bboxes are
137
+ """
138
+ mask = np.zeros((height, width), dtype=np.float32)
139
+
140
+ for det in detections:
141
+ bbox = det.get('bbox_xyxy', det.get('bbox'))
142
+ if bbox is None:
143
+ continue
144
+
145
+ x1, y1, x2, y2 = bbox
146
+ x1 = int(max(0, min(x1, width - 1)))
147
+ y1 = int(max(0, min(y1, height - 1)))
148
+ x2 = int(max(0, min(x2, width - 1)))
149
+ y2 = int(max(0, min(y2, height - 1)))
150
+
151
+ if x2 > x1 and y2 > y1:
152
+ mask[y1:y2, x1:x2] = 1.0
153
+
154
+ return mask
155
+
156
+ def segment(
157
+ self,
158
+ image: Image.Image,
159
+ target_classes: Optional[List[str]] = None,
160
+ conf_threshold: Optional[float] = None,
161
+ **kwargs
162
+ ) -> np.ndarray:
163
+ """
164
+ Create segmentation mask from detections.
165
+
166
+ Args:
167
+ image: Input PIL Image
168
+ target_classes: Classes to detect (None = all classes)
169
+ conf_threshold: Override default confidence threshold
170
+ **kwargs: Additional parameters
171
+
172
+ Returns:
173
+ Binary mask (H, W) as float32
174
+ """
175
+ if conf_threshold is None:
176
+ conf_threshold = self.conf_threshold
177
+
178
+ width, height = image.size
179
+
180
+ # Run detection
181
+ # Pass classes for open-vocabulary detectors (Grounding DINO, YOLO-World)
182
+ detect_kwargs = {"conf_threshold": conf_threshold}
183
+ if target_classes:
184
+ detect_kwargs["classes"] = target_classes
185
+ detections = self.detector.detect(image, **detect_kwargs)
186
+
187
+ # Convert to dict format
188
+ det_dicts = [
189
+ {
190
+ 'label': d.label,
191
+ 'score': d.score,
192
+ 'bbox_xyxy': d.bbox_xyxy,
193
+ }
194
+ for d in detections
195
+ ]
196
+
197
+ # Update tracker
198
+ if self._tracker is not None:
199
+ tracks = self._tracker.update(det_dicts)
200
+ # Convert tracks to detection format
201
+ detections = track_dicts_to_detections(tracks)
202
+ else:
203
+ detections = det_dicts
204
+
205
+ # Filter by target classes if specified
206
+ if target_classes:
207
+ target_lower = [tc.lower() for tc in target_classes]
208
+ detections = [
209
+ d for d in detections
210
+ if any(tc in d['label'].lower() for tc in target_lower)
211
+ ]
212
+
213
+ # Create mask from bboxes
214
+ mask = self._create_bbox_mask(width, height, detections)
215
+ self._frame_count += 1
216
+
217
+ return mask
218
+
219
+ def _detect_with_yolo_tracking(
220
+ self,
221
+ image: Image.Image,
222
+ conf_threshold: float,
223
+ ) -> List[Dict]:
224
+ """Optional helper: run YOLO detection with ultralytics' built-in tracking.
225
+
226
+ Returns:
227
+ List of detections with track IDs
228
+ """
229
+ img = np.asarray(image.convert('RGB'))
230
+
231
+ # Determine device argument
232
+ device_arg = 0 if self.detector.device.startswith('cuda') else 'cpu'
233
+
234
+ # Use YOLO's .track() method instead of .predict()
235
+ results = self.detector.model.track(
236
+ source=img,
237
+ conf=conf_threshold,
238
+ device=device_arg,
239
+ verbose=False,
240
+ tracker=f'{self.tracker_type}.yaml', # bytetrack.yaml or botsort.yaml
241
+ persist=True, # Persist tracks between frames
242
+ )
243
+
244
+ if not results:
245
+ return []
246
+
247
+ r0 = results[0]
248
+ if not hasattr(r0, 'boxes') or r0.boxes is None:
249
+ return []
250
+
251
+ boxes = r0.boxes
252
+ xyxy = boxes.xyxy.detach().cpu().numpy()
253
+ conf = boxes.conf.detach().cpu().numpy()
254
+ cls = boxes.cls.detach().cpu().numpy().astype(int)
255
+
256
+ # Get track IDs if available
257
+ track_ids = None
258
+ if hasattr(boxes, 'id') and boxes.id is not None:
259
+ track_ids = boxes.id.detach().cpu().numpy().astype(int)
260
+
261
+ detections = []
262
+ for i, (bbox, score, class_id) in enumerate(zip(xyxy, conf, cls)):
263
+ label = self.detector._names.get(int(class_id), str(int(class_id)))
264
+ det = {
265
+ 'label': label,
266
+ 'score': float(score),
267
+ 'bbox_xyxy': [float(x) for x in bbox.tolist()],
268
+ }
269
+ if track_ids is not None:
270
+ det['track_id'] = int(track_ids[i])
271
+ detections.append(det)
272
+
273
+ return detections
274
+
275
+ def segment_batch(
276
+ self,
277
+ images: List[Image.Image],
278
+ target_classes: Optional[List[str]] = None,
279
+ conf_threshold: Optional[float] = None,
280
+ **kwargs
281
+ ) -> List[np.ndarray]:
282
+ """
283
+ Batch segmentation for video processing using offline batch detection + sequential tracking.
284
+
285
+ This is much more efficient:
286
+ 1. Batch detect all frames at once (or in batches if memory limited)
287
+ 2. Run tracker sequentially on detection results
288
+ 3. Create masks from tracked detections
289
+
290
+ Args:
291
+ images: List of PIL Images
292
+ target_classes: Classes to detect
293
+ conf_threshold: Confidence threshold
294
+ **kwargs: Additional parameters
295
+
296
+ Returns:
297
+ List of binary masks
298
+ """
299
+ # Ensure model is loaded
300
+ self.ensure_loaded()
301
+
302
+ if conf_threshold is None:
303
+ conf_threshold = self.conf_threshold
304
+
305
+ if not images:
306
+ return []
307
+
308
+ # Step 1: Batch detect all frames (TRUE batch inference for speed)
309
+ all_detections = []
310
+ # Prepare detection kwargs (classes for open-vocabulary detectors)
311
+ detect_kwargs = {"conf_threshold": conf_threshold}
312
+ if target_classes:
313
+ detect_kwargs["classes"] = target_classes
314
+
315
+ if hasattr(self.detector, 'detect_batch'):
316
+ # Use batch detection for efficiency (GPU parallelization)
317
+ batch_dets = self.detector.detect_batch(images, **detect_kwargs)
318
+ for dets in batch_dets:
319
+ det_dicts = [
320
+ {
321
+ 'label': d.label,
322
+ 'score': d.score,
323
+ 'bbox_xyxy': d.bbox_xyxy,
324
+ }
325
+ for d in dets
326
+ ]
327
+ all_detections.append(det_dicts)
328
+ else:
329
+ # Fallback to frame-by-frame for detectors without batch support
330
+ for image in images:
331
+ dets = self.detector.detect(image, **detect_kwargs)
332
+ det_dicts = [
333
+ {
334
+ 'label': d.label,
335
+ 'score': d.score,
336
+ 'bbox_xyxy': d.bbox_xyxy,
337
+ }
338
+ for d in dets
339
+ ]
340
+ all_detections.append(det_dicts)
341
+
342
+ # Step 2: Run tracker sequentially on all detections
343
+ tracked_detections = []
344
+ if self._tracker is not None:
345
+ for frame_dets in all_detections:
346
+ tracks = self._tracker.update(frame_dets)
347
+ # Convert tracks to detection format
348
+ frame_tracked = track_dicts_to_detections(tracks)
349
+ tracked_detections.append(frame_tracked)
350
+ else:
351
+ tracked_detections = all_detections
352
+
353
+ # Step 3: Filter by target classes and create masks
354
+ masks = []
355
+ for i, (image, detections) in enumerate(zip(images, tracked_detections)):
356
+ width, height = image.size
357
+
358
+ # Filter by target classes if specified
359
+ if target_classes:
360
+ target_lower = [tc.lower() for tc in target_classes]
361
+ detections = [
362
+ d for d in detections
363
+ if any(tc in d['label'].lower() for tc in target_lower)
364
+ ]
365
+
366
+ # Create mask from bboxes
367
+ mask = self._create_bbox_mask(width, height, detections)
368
+ masks.append(mask)
369
+ self._frame_count += 1
370
+
371
+ return masks
372
+
373
+ def get_available_classes(self) -> List[str]:
374
+ """Get list of classes the detector can detect."""
375
+ classes = self.detector.get_available_classes()
376
+
377
+ if isinstance(classes, dict):
378
+ return sorted(classes.keys())
379
+ elif isinstance(classes, list):
380
+ return classes
381
+ else:
382
+ return []
383
+
384
+ def get_default_classes(self) -> List[str]:
385
+ """Get default classes for common use cases."""
386
+ # Common COCO classes
387
+ return ['person', 'car', 'truck', 'bus', 'bicycle', 'motorcycle']
388
+
389
+
390
+ def track_dicts_to_detections(tracks: List[Dict]) -> List[Dict]:
391
+ """Convert tracker output to detection format.
392
+
393
+ Args:
394
+ tracks: List of track dicts from tracker
395
+
396
+ Returns:
397
+ List of detection dicts
398
+ """
399
+ detections = []
400
+ for track in tracks:
401
+ det = {
402
+ 'label': track.get('label', ''),
403
+ 'score': track.get('score', track.get('last_score', 0.0)),
404
+ 'bbox_xyxy': track.get('bbox_xyxy', track.get('last_bbox', [])),
405
+ }
406
+ # Add track_id if available
407
+ if 'track_id' in track:
408
+ det['track_id'] = track['track_id']
409
+
410
+ detections.append(det)
411
+
412
+ return detections
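
A minimal sketch of running `FakeSegmenter` over a short frame sequence, assuming the package layout shown in this commit and that the frame files exist; the detector itself is created by the underlying `detection` factory.

```python
from PIL import Image

from segmentation.fake import FakeSegmenter

seg = FakeSegmenter(device='cpu', detector_name='yolo', tracker_type='simple')
seg.load_model()

# Batch path: detect all frames first, then track sequentially over the results.
frames = [Image.open(f'frames/{i:04d}.jpg') for i in range(8)]  # illustrative paths
masks = seg.segment_batch(frames, target_classes=['person', 'car'])
print(len(masks), masks[0].shape)

seg.reset_tracking()  # call between videos so track IDs do not carry over
```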
segmentation/mask2former.py ADDED
@@ -0,0 +1,310 @@
1
+ """
2
+ Mask2Former segmentation using Swin Transformer backbone.
3
+ Supports both COCO and ADE20K pre-trained models.
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
9
+ from PIL import Image
10
+ from typing import List, Optional, Dict
11
+ from .base import BaseSegmenter
12
+ from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs
13
+
14
+
15
+ class Mask2FormerSegmenter(BaseSegmenter):
16
+ """
17
+ Mask2Former segmentation with Swin Transformer backbone.
18
+
19
+ Supports both COCO (133 classes) and ADE20K (150 classes) datasets.
20
+ """
21
+
22
+ # COCO panoptic categories (133 classes including stuff)
23
+ COCO_CLASSES = {
24
+ 'person': 0, 'bicycle': 1, 'car': 2, 'motorcycle': 3, 'airplane': 4,
25
+ 'bus': 5, 'train': 6, 'truck': 7, 'boat': 8, 'traffic light': 9,
26
+ 'fire hydrant': 10, 'stop sign': 11, 'parking meter': 12, 'bench': 13,
27
+ 'bird': 14, 'cat': 15, 'dog': 16, 'horse': 17, 'sheep': 18, 'cow': 19,
28
+ 'elephant': 20, 'bear': 21, 'zebra': 22, 'giraffe': 23, 'backpack': 24,
29
+ 'umbrella': 25, 'handbag': 26, 'tie': 27, 'suitcase': 28, 'frisbee': 29,
30
+ 'skis': 30, 'snowboard': 31, 'sports ball': 32, 'kite': 33, 'baseball bat': 34,
31
+ 'baseball glove': 35, 'skateboard': 36, 'surfboard': 37, 'tennis racket': 38,
32
+ 'bottle': 39, 'wine glass': 40, 'cup': 41, 'fork': 42, 'knife': 43,
33
+ 'spoon': 44, 'bowl': 45, 'banana': 46, 'apple': 47, 'sandwich': 48,
34
+ 'orange': 49, 'broccoli': 50, 'carrot': 51, 'hot dog': 52, 'pizza': 53,
35
+ 'donut': 54, 'cake': 55, 'chair': 56, 'couch': 57, 'potted plant': 58,
36
+ 'bed': 59, 'dining table': 60, 'toilet': 61, 'tv': 62, 'laptop': 63,
37
+ 'mouse': 64, 'remote': 65, 'keyboard': 66, 'cell phone': 67, 'microwave': 68,
38
+ 'oven': 69, 'toaster': 70, 'sink': 71, 'refrigerator': 72, 'book': 73,
39
+ 'clock': 74, 'vase': 75, 'scissors': 76, 'teddy bear': 77, 'hair drier': 78,
40
+ 'toothbrush': 79, 'banner': 80, 'blanket': 81, 'bridge': 82, 'cardboard': 83,
41
+ 'counter': 84, 'curtain': 85, 'door-stuff': 86, 'floor-wood': 87,
42
+ 'flower': 88, 'fruit': 89, 'gravel': 90, 'house': 91, 'light': 92,
43
+ 'mirror-stuff': 93, 'net': 94, 'pillow': 95, 'platform': 96, 'playingfield': 97,
44
+ 'railroad': 98, 'river': 99, 'road': 100, 'roof': 101, 'sand': 102,
45
+ 'sea': 103, 'shelf': 104, 'snow': 105, 'stairs': 106, 'tent': 107,
46
+ 'towel': 108, 'wall-brick': 109, 'wall-stone': 110, 'wall-tile': 111,
47
+ 'wall-wood': 112, 'water': 113, 'window-blind': 114, 'window': 115,
48
+ 'tree': 116, 'fence': 117, 'ceiling': 118, 'sky': 119, 'cabinet': 120,
49
+ 'table': 121, 'floor': 122, 'pavement': 123, 'mountain': 124, 'grass': 125,
50
+ 'dirt': 126, 'paper': 127, 'food': 128, 'building': 129, 'rock': 130,
51
+ 'wall': 131, 'rug': 132
52
+ }
53
+
54
+ # Common ADE20K classes (subset of 150)
55
+ ADE20K_CLASSES = {
56
+ 'wall': 0, 'building': 1, 'sky': 2, 'floor': 3, 'tree': 4,
57
+ 'ceiling': 5, 'road': 6, 'bed': 7, 'windowpane': 8, 'grass': 9,
58
+ 'cabinet': 10, 'sidewalk': 11, 'person': 12, 'earth': 13, 'door': 14,
59
+ 'table': 15, 'mountain': 16, 'plant': 17, 'curtain': 18, 'chair': 19,
60
+ 'car': 20, 'water': 21, 'painting': 22, 'sofa': 23, 'shelf': 24,
61
+ 'house': 25, 'sea': 26, 'mirror': 27, 'rug': 28, 'field': 29,
62
+ 'armchair': 30, 'seat': 31, 'fence': 32, 'desk': 33, 'rock': 34,
63
+ 'wardrobe': 35, 'lamp': 36, 'bathtub': 37, 'railing': 38, 'cushion': 39,
64
+ 'base': 40, 'box': 41, 'column': 42, 'signboard': 43, 'chest of drawers': 44,
65
+ 'counter': 45, 'sand': 46, 'sink': 47, 'skyscraper': 48, 'fireplace': 49,
66
+ }
67
+
68
+ def __init__(
69
+ self,
70
+ device: str = 'cuda',
71
+ conf_threshold: float = 0.5,
72
+ model_type: str = 'coco', # 'coco' or 'ade20k'
73
+ **kwargs
74
+ ):
75
+ """
76
+ Initialize Mask2Former segmenter.
77
+
78
+ Args:
79
+ device: Device to run model on
80
+ conf_threshold: Confidence threshold for predictions
81
+ model_type: Which pre-trained model to use ('coco' or 'ade20k')
82
+ **kwargs: Additional arguments
83
+ """
84
+ super().__init__(device, **kwargs)
85
+ self.conf_threshold = conf_threshold
86
+ self.model_type = model_type.lower()
87
+
88
+ if self.model_type not in ['coco', 'ade20k']:
89
+ raise ValueError(f"model_type must be 'coco' or 'ade20k', got {self.model_type}")
90
+
91
+ self.class_map = self.COCO_CLASSES if self.model_type == 'coco' else self.ADE20K_CLASSES
92
+ self.model = None
93
+ self.processor = None
94
+
95
+ def load_model(self):
96
+ """Load Mask2Former model and processor from HuggingFace."""
97
+ try:
98
+ from transformers import Mask2FormerForUniversalSegmentation, AutoImageProcessor
99
+ except ImportError:
100
+ raise ImportError(
101
+ "Mask2Former requires transformers. Install with: pip install transformers"
102
+ )
103
+
104
+ if self.model_type == 'coco':
105
+ model_name = "facebook/mask2former-swin-large-coco-panoptic"
106
+ else: # ade20k
107
+ model_name = "facebook/mask2former-swin-large-ade-semantic"
108
+
109
+ print(f"Loading Mask2Former ({self.model_type}) model...")
110
+ ensure_default_checkpoint_dirs()
111
+ cache_dir = str(hf_cache_dir())
112
+ self.processor = AutoImageProcessor.from_pretrained(model_name, cache_dir=cache_dir)
113
+ self.model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name, cache_dir=cache_dir)
114
+ self.model = self.model.to(self.device)
115
+ self.model.eval()
116
+ print(f"✓ Mask2Former loaded: {model_name}")
117
+
118
+ def segment(
119
+ self,
120
+ image: Image.Image,
121
+ target_classes: Optional[List[str]] = None,
122
+ **kwargs
123
+ ) -> np.ndarray:
124
+ """
125
+ Segment image using Mask2Former.
126
+
127
+ Args:
128
+ image: PIL Image
129
+ target_classes: List of class names to segment (None for all)
130
+ **kwargs: Additional arguments (unused)
131
+
132
+ Returns:
133
+ Binary mask as numpy array [H, W] with 1 for ROI, 0 for background
134
+ """
135
+ if self.model is None:
136
+ self.load_model()
137
+
138
+ # Prepare image
139
+ inputs = self.processor(images=image, return_tensors="pt")
140
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
141
+
142
+ # Run inference
143
+ with torch.no_grad():
144
+ outputs = self.model(**inputs)
145
+
146
+ # Post-process
147
+ if self.model_type == 'coco':
148
+ # Panoptic segmentation
149
+ result = self.processor.post_process_panoptic_segmentation(
150
+ outputs,
151
+ target_sizes=[image.size[::-1]]
152
+ )[0]
153
+ segmentation = result['segmentation'].cpu().numpy()
154
+ segments_info = result['segments_info']
155
+ else:
156
+ # Semantic segmentation
157
+ result = self.processor.post_process_semantic_segmentation(
158
+ outputs,
159
+ target_sizes=[image.size[::-1]]
160
+ )[0]
161
+ segmentation = result.cpu().numpy()
162
+ segments_info = None
163
+
164
+ # Create binary mask
165
+ mask = np.zeros(segmentation.shape, dtype=np.uint8)
166
+
167
+ if target_classes is None:
168
+ # Return all detected objects
169
+ mask = (segmentation > 0).astype(np.uint8)
170
+ else:
171
+ # Filter by target classes
172
+ target_ids = []
173
+ for class_name in target_classes:
174
+ class_lower = class_name.lower()
175
+ if class_lower in self.class_map:
176
+ target_ids.append(self.class_map[class_lower])
177
+ else:
178
+ # Try fuzzy matching
179
+ for key, val in self.class_map.items():
180
+ if class_lower in key or key in class_lower:
181
+ target_ids.append(val)
182
+ break
183
+
184
+ if self.model_type == 'coco' and segments_info:
185
+ # Use segment info for panoptic
186
+ for segment in segments_info:
187
+ if segment['label_id'] in target_ids:
188
+ mask[segmentation == segment['id']] = 1
189
+ else:
190
+ # Use class IDs directly for semantic
191
+ for target_id in target_ids:
192
+ mask[segmentation == target_id] = 1
193
+
194
+ return mask
195
+
196
+ def get_available_classes(self) -> List[str]:
197
+ """
198
+ Get list of available class names.
199
+
200
+ Returns:
201
+ List of class names supported by the model
202
+ """
203
+ return sorted(self.class_map.keys())
204
+
205
+ # Mask2Former can be batched through the HF processor for semantic mode.
206
+ # Panoptic post-processing is per-image, but the forward pass is batched.
207
+ supports_batch: bool = True
208
+
209
+ def segment_batch(
210
+ self,
211
+ images: List[Image.Image],
212
+ target_classes: Optional[List[str]] = None,
213
+ **kwargs,
214
+ ) -> List[np.ndarray]:
215
+ """Segment a batch of images via Mask2Former.
216
+
217
+ The HuggingFace processor accepts a list of images. The forward
218
+ pass runs on the full batch; post-processing is per-image.
219
+
220
+ Args:
221
+ images: List of PIL Images
222
+ target_classes: Class names to include in ROI
223
+ **kwargs: unused
224
+
225
+ Returns:
226
+ List of binary masks (H, W) float32
227
+ """
228
+ if not images:
229
+ return []
230
+
231
+ if self.model is None:
232
+ self.load_model()
233
+
234
+ # Resolve target class IDs
235
+ target_ids = []
236
+ if target_classes:
237
+ for cn in target_classes:
238
+ cl = cn.lower()
239
+ if cl in self.class_map:
240
+ target_ids.append(self.class_map[cl])
241
+ else:
242
+ for key, val in self.class_map.items():
243
+ if cl in key or key in cl:
244
+ target_ids.append(val)
245
+ break
246
+
247
+ # Batch preprocess
248
+ inputs = self.processor(images=images, return_tensors="pt", padding=True)
249
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
250
+
251
+ with torch.no_grad():
252
+ outputs = self.model(**inputs)
253
+
254
+ # Collect all target sizes for batch post-processing
255
+ target_sizes = [img.size[::-1] for img in images] # [(H1, W1), (H2, W2), ...]
256
+
257
+ masks: List[np.ndarray] = []
258
+
259
+ if self.model_type == 'coco':
260
+ # Post-process entire batch at once for panoptic segmentation
261
+ results = self.processor.post_process_panoptic_segmentation(
262
+ outputs,
263
+ target_sizes=target_sizes,
264
+ )
265
+
266
+ for i, (result, img) in enumerate(zip(results, images)):
267
+ segmentation = result['segmentation'].cpu().numpy()
268
+ segments_info = result['segments_info']
269
+
270
+ mask = np.zeros(segmentation.shape, dtype=np.float32)
271
+ if target_classes is None:
272
+ mask = (segmentation > 0).astype(np.float32)
273
+ else:
274
+ for seg in segments_info:
275
+ if seg['label_id'] in target_ids:
276
+ mask[segmentation == seg['id']] = 1.0
277
+
278
+ masks.append(mask)
279
+ else:
280
+ # Post-process entire batch at once for semantic segmentation
281
+ results = self.processor.post_process_semantic_segmentation(
282
+ outputs,
283
+ target_sizes=target_sizes,
284
+ )
285
+
286
+ for i, (result, img) in enumerate(zip(results, images)):
287
+ segmentation = result.cpu().numpy()
288
+ mask = np.zeros(segmentation.shape, dtype=np.float32)
289
+ if target_classes is None:
290
+ mask = (segmentation > 0).astype(np.float32)
291
+ else:
292
+ for tid in target_ids:
293
+ mask[segmentation == tid] = 1.0
294
+
295
+ masks.append(mask)
296
+
297
+ return masks
298
+
299
+ def get_class_info(self) -> Dict:
300
+ """
301
+ Get detailed class information.
302
+
303
+ Returns:
304
+ Dictionary with model_type, num_classes, and the class name to ID mapping
305
+ """
306
+ return {
307
+ 'model_type': self.model_type,
308
+ 'num_classes': len(self.class_map),
309
+ 'classes': self.class_map.copy()
310
+ }
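
A minimal sketch of Mask2Former usage, assuming the weights download from HuggingFace on first use (slow on CPU) and an illustrative input path.

```python
from PIL import Image

from segmentation.mask2former import Mask2FormerSegmenter

img = Image.open('street.jpg')  # illustrative path

# COCO panoptic mode: the mask is built from per-segment instances.
seg = Mask2FormerSegmenter(device='cpu', model_type='coco')
mask = seg.segment(img, target_classes=['person', 'car'])
print(mask.shape, int(mask.max()))  # (H, W), 1 where a person/car was found
```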
segmentation/maskrcnn.py ADDED
@@ -0,0 +1,310 @@
1
+ """
2
+ Mask R-CNN segmentation using torchvision pre-trained models.
3
+ Supports COCO-trained instance segmentation.
4
+ """
5
+
6
+ from model_cache import ensure_default_checkpoint_dirs
7
+
8
+ import torch
9
+ import numpy as np
10
+ import cv2
11
+ from PIL import Image
12
+ from typing import List, Optional, Dict
13
+ from .base import BaseSegmenter
14
+
15
+
16
+ # Ensure torchvision/torch hub downloads land under `checkpoints/` by default.
17
+ ensure_default_checkpoint_dirs()
18
+
19
+
20
+ class MaskRCNNSegmenter(BaseSegmenter):
21
+ """
22
+ Mask R-CNN instance segmentation from torchvision.
23
+
24
+ Uses pre-trained ResNet50-FPN backbone on COCO dataset (80 classes).
25
+ """
26
+
27
+ # COCO class names (80 classes)
28
+ COCO_CLASSES = [
29
+ '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
30
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
31
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
32
+ 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
33
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
34
+ 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
35
+ 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
36
+ 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
37
+ 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
38
+ 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
39
+ 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
40
+ 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
41
+ ]
42
+
43
+ def __init__(
44
+ self,
45
+ device: str = 'cuda',
46
+ conf_threshold: float = 0.5,
47
+ backbone: str = 'resnet50', # 'resnet50' or 'mobilenet'
48
+ **kwargs
49
+ ):
50
+ """
51
+ Initialize Mask R-CNN segmenter.
52
+
53
+ Args:
54
+ device: Device to run model on
55
+ conf_threshold: Confidence threshold for detections
56
+ backbone: Model backbone ('resnet50' or 'mobilenet')
57
+ **kwargs: Additional arguments
58
+ """
59
+ super().__init__(device, **kwargs)
60
+ self.conf_threshold = conf_threshold
61
+ self.backbone = backbone.lower()
62
+
63
+ if self.backbone not in ['resnet50', 'mobilenet']:
64
+ raise ValueError(f"backbone must be 'resnet50' or 'mobilenet', got {self.backbone}")
65
+
66
+ self.model = None
67
+
68
+ def load_model(self):
69
+ """Load Mask R-CNN model from torchvision."""
70
+ try:
71
+ # Only the ResNet50 V2 weights are used here; the MobileNet variant is imported lazily below.
73
+ from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
74
+ except ImportError:
75
+ raise ImportError(
76
+ "Mask R-CNN requires torchvision. Install with: pip install torchvision"
77
+ )
78
+
79
+ print(f"Loading Mask R-CNN ({self.backbone}) model...")
80
+
81
+ if self.backbone == 'resnet50':
82
+ # Use newer V2 weights for better performance
83
+ self.model = maskrcnn_resnet50_fpn_v2(weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
84
+ else:
85
+ # MobileNet version (lighter but less accurate)
86
+ from torchvision.models.detection import maskrcnn_mobilenet_v3_large_fpn
87
+ from torchvision.models.detection import MaskRCNN_MobileNet_V3_Large_FPN_Weights
88
+ self.model = maskrcnn_mobilenet_v3_large_fpn(weights=MaskRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
89
+
90
+ self.model = self.model.to(self.device)
91
+ self.model.eval()
92
+ print(f"✓ Mask R-CNN loaded: {self.backbone}")
93
+
94
+ def segment(
95
+ self,
96
+ image: Image.Image,
97
+ target_classes: Optional[List[str]] = None,
98
+ **kwargs
99
+ ) -> np.ndarray:
100
+ """
101
+ Segment image using Mask R-CNN.
102
+
103
+ Args:
104
+ image: PIL Image
105
+ target_classes: List of class names to segment (None for all)
106
+ **kwargs: Additional arguments (can override conf_threshold)
107
+
108
+ Returns:
109
+ Binary mask as numpy array [H, W] with 1 for ROI, 0 for background
110
+ """
111
+ if self.model is None:
112
+ self.load_model()
113
+
114
+ # Get confidence threshold from kwargs or use default
115
+ conf_threshold = kwargs.get('conf_threshold', self.conf_threshold)
116
+
117
+ # Prepare image
118
+ img_array = np.array(image)
119
+ if len(img_array.shape) == 2: # Grayscale
120
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
121
+ elif img_array.shape[2] == 4: # RGBA
122
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
123
+
124
+ # Convert to tensor [3, H, W] and normalize
125
+ img_tensor = torch.from_numpy(img_array).float().permute(2, 0, 1) / 255.0
126
+ img_tensor = img_tensor.to(self.device)
127
+
128
+ # Run inference
129
+ with torch.no_grad():
130
+ predictions = self.model([img_tensor])[0]
131
+
132
+ # Get predictions
133
+ boxes = predictions['boxes'].cpu().numpy()
134
+ labels = predictions['labels'].cpu().numpy()
135
+ scores = predictions['scores'].cpu().numpy()
136
+ masks = predictions['masks'].cpu().numpy()
137
+
138
+ # Filter by confidence
139
+ keep_indices = scores >= conf_threshold
140
+ boxes = boxes[keep_indices]
141
+ labels = labels[keep_indices]
142
+ scores = scores[keep_indices]
143
+ masks = masks[keep_indices]
144
+
145
+ # Create combined mask
146
+ h, w = img_array.shape[:2]
147
+ combined_mask = np.zeros((h, w), dtype=np.uint8)
148
+
149
+ if target_classes is None:
150
+ # Combine all high-confidence masks
151
+ for mask in masks:
152
+ binary_mask = (mask[0] > 0.5).astype(np.uint8)
153
+ combined_mask = np.maximum(combined_mask, binary_mask)
154
+ else:
155
+ # Filter by target classes
156
+ target_indices = []
157
+ for class_name in target_classes:
158
+ class_lower = class_name.lower()
159
+ for idx, coco_class in enumerate(self.COCO_CLASSES):
160
+ if coco_class.lower() == class_lower or class_lower in coco_class.lower():
161
+ target_indices.append(idx)
162
+
163
+ # Combine masks for target classes
164
+ for i, label in enumerate(labels):
165
+ if label in target_indices:
166
+ binary_mask = (masks[i][0] > 0.5).astype(np.uint8)
167
+ combined_mask = np.maximum(combined_mask, binary_mask)
168
+
169
+ return combined_mask
170
+
171
+ def get_available_classes(self) -> List[str]:
172
+ """
173
+ Get list of available class names.
174
+
175
+ Returns:
176
+ List of COCO class names (excluding 'N/A' and '__background__')
177
+ """
178
+ return [cls for cls in self.COCO_CLASSES if cls not in ['N/A', '__background__']]
179
+
180
+ # torchvision Mask R-CNN natively accepts a list of tensors.
181
+ supports_batch: bool = True
182
+
183
+ def segment_batch(
184
+ self,
185
+ images: List[Image.Image],
186
+ target_classes: Optional[List[str]] = None,
187
+ **kwargs,
188
+ ) -> List[np.ndarray]:
189
+ """Segment a batch of images via Mask R-CNN.
190
+
191
+ torchvision's Mask R-CNN forward accepts ``List[Tensor]`` so we
192
+ can pass all images in one call. Post-processing is per-image.
193
+
194
+ Args:
195
+ images: List of PIL Images
196
+ target_classes: COCO class names to include in mask
197
+ **kwargs: May include ``conf_threshold``
198
+
199
+ Returns:
200
+ List of binary masks (H, W) float32
201
+ """
202
+ if not images:
203
+ return []
204
+
205
+ if self.model is None:
206
+ self.load_model()
207
+
208
+ conf_threshold = kwargs.get('conf_threshold', self.conf_threshold)
209
+
210
+ # Resolve target class indices
211
+ target_indices: Optional[List[int]] = None
212
+ if target_classes is not None:
213
+ target_indices = []
214
+ for cn in target_classes:
215
+ cl = cn.lower()
216
+ for idx, cc in enumerate(self.COCO_CLASSES):
217
+ if cc.lower() == cl or cl in cc.lower():
218
+ target_indices.append(idx)
219
+
220
+ # Build list of tensors (varying sizes are ok for torchvision)
221
+ tensors = []
222
+ for img in images:
223
+ arr = np.array(img)
224
+ if len(arr.shape) == 2:
225
+ arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
226
+ elif arr.shape[2] == 4:
227
+ arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
228
+ t = torch.from_numpy(arr).float().permute(2, 0, 1) / 255.0
229
+ tensors.append(t.to(self.device))
230
+
231
+ with torch.no_grad():
232
+ predictions_list = self.model(tensors)
233
+
234
+ masks_out: List[np.ndarray] = []
235
+ for i, preds in enumerate(predictions_list):
236
+ h, w = images[i].height, images[i].width
237
+ combined = np.zeros((h, w), dtype=np.float32)
238
+
239
+ labels = preds['labels'].cpu().numpy()
240
+ scores = preds['scores'].cpu().numpy()
241
+ pred_masks = preds['masks'].cpu().numpy()
242
+
243
+ keep = scores >= conf_threshold
244
+ labels = labels[keep]
245
+ pred_masks = pred_masks[keep]
246
+
247
+ for j, lbl in enumerate(labels):
248
+ if target_indices is not None and int(lbl) not in target_indices:
249
+ continue
250
+ binary = (pred_masks[j][0] > 0.5).astype(np.float32)
251
+ combined = np.maximum(combined, binary)
252
+
253
+ masks_out.append(combined)
254
+
255
+ return masks_out
256
+
257
+ def get_detection_info(
258
+ self,
259
+ image: Image.Image,
260
+ conf_threshold: Optional[float] = None
261
+ ) -> List[Dict]:
262
+ """
263
+ Get detailed detection information for all objects.
264
+
265
+ Args:
266
+ image: PIL Image
267
+ conf_threshold: Confidence threshold (uses default if None)
268
+
269
+ Returns:
270
+ List of dictionaries with detection info (class, score, bbox, mask)
271
+ """
272
+ if self.model is None:
273
+ self.load_model()
274
+
275
+ if conf_threshold is None:
276
+ conf_threshold = self.conf_threshold
277
+
278
+ # Prepare image
279
+ img_array = np.array(image)
280
+ if len(img_array.shape) == 2:
281
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
282
+ elif img_array.shape[2] == 4:
283
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
284
+
285
+ img_tensor = torch.from_numpy(img_array).float().permute(2, 0, 1) / 255.0
286
+ img_tensor = img_tensor.to(self.device)
287
+
288
+ # Run inference
289
+ with torch.no_grad():
290
+ predictions = self.model([img_tensor])[0]
291
+
292
+ # Get predictions
293
+ boxes = predictions['boxes'].cpu().numpy()
294
+ labels = predictions['labels'].cpu().numpy()
295
+ scores = predictions['scores'].cpu().numpy()
296
+ masks = predictions['masks'].cpu().numpy()
297
+
298
+ # Filter and format results
299
+ detections = []
300
+ for i in range(len(scores)):
301
+ if scores[i] >= conf_threshold:
302
+ detections.append({
303
+ 'class': self.COCO_CLASSES[labels[i]],
304
+ 'class_id': int(labels[i]),
305
+ 'score': float(scores[i]),
306
+ 'bbox': boxes[i].tolist(),
307
+ 'mask': (masks[i][0] > 0.5).astype(np.uint8)
308
+ })
309
+
310
+ return detections
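
A minimal sketch of per-object inspection via `get_detection_info()`; the input path is illustrative and torchvision downloads the weights on first use.

```python
from PIL import Image

from segmentation.maskrcnn import MaskRCNNSegmenter

seg = MaskRCNNSegmenter(device='cpu', conf_threshold=0.6)
for d in seg.get_detection_info(Image.open('street.jpg')):  # illustrative path
    x1, y1, x2, y2 = d['bbox']
    print(f"{d['class']:<12} score={d['score']:.2f} "
          f"box=({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f})")
```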
segmentation/sam3.py ADDED
@@ -0,0 +1,187 @@
1
+ """SAM3-style promptable segmentation.
2
+
3
+ This integrates a prompt-driven segmentation method into the existing
4
+ class-based segmentation interface. Instead of relying on a fixed class
5
+ vocabulary, it accepts natural-language prompts (e.g., "a red car", "the person").
6
+
7
+ Implementation approach (lightweight, no custom training):
8
+ - Use a text-conditioned detector (OWL-ViT) to propose bounding boxes from text.
9
+ - Use SAM (Segment Anything) to convert boxes into masks.
10
+
11
+ Notes:
12
+ - This is not "SAM 3" in the sense of an official model release; it is a
13
+ prompt-to-mask pipeline exposed as a single segmenter named "sam3".
14
+ - If required dependencies/models are missing, this segmenter raises a clear
15
+ error message.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ from PIL import Image
26
+
27
+ from .base import BaseSegmenter
28
+ from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs
29
+
30
+
31
+ @dataclass
32
+ class _SAM3Config:
33
+ detector_model: str = "google/owlvit-base-patch32"
34
+ sam_model: str = "facebook/sam-vit-base"
35
+ box_threshold: float = 0.02
36
+ max_boxes: int = 5
37
+
38
+
39
+ class SAM3Segmenter(BaseSegmenter):
40
+ """Prompt-driven segmentation via (text detector → SAM).
41
+
42
+ Use `target_classes` to pass natural language prompts:
43
+ - `['car']`, `['a car']`, `['the person']`, etc.
44
+
45
+ Returns a binary mask (H, W) with 1 for predicted ROI.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ device: str = "cuda",
51
+ detector_model: str = _SAM3Config.detector_model,
52
+ sam_model: str = _SAM3Config.sam_model,
53
+ box_threshold: float = _SAM3Config.box_threshold,
54
+ max_boxes: int = _SAM3Config.max_boxes,
55
+ **kwargs,
56
+ ):
57
+ super().__init__(device=device, **kwargs)
58
+ self.detector_model_name = detector_model
59
+ self.sam_model_name = sam_model
60
+ self.box_threshold = float(box_threshold)
61
+ self.max_boxes = int(max_boxes)
62
+
63
+ self._detector = None
64
+ self._sam_model = None
65
+ self._sam_processor = None
66
+
67
+ def load_model(self):
68
+ try:
69
+ from transformers import pipeline, SamModel, SamProcessor
70
+ except Exception as e: # pragma: no cover
71
+ raise ImportError(
72
+ "SAM3Segmenter requires `transformers` with SAM support. "
73
+ "Try: pip install -U transformers"
74
+ ) from e
75
+
76
+ # Make sure any HF downloads (including pipeline internals) land under `checkpoints/`.
77
+ ensure_default_checkpoint_dirs()
78
+
79
+ # Configure device for HF pipeline
80
+ if self.device.startswith("cuda") and torch.cuda.is_available():
81
+ pipeline_device = 0
82
+ else:
83
+ pipeline_device = -1
84
+
85
+ self._detector = pipeline(
86
+ task="zero-shot-object-detection",
87
+ model=self.detector_model_name,
88
+ device=pipeline_device,
89
+ )
90
+
91
+ cache_dir = str(hf_cache_dir())
92
+
93
+ self._sam_processor = SamProcessor.from_pretrained(self.sam_model_name, cache_dir=cache_dir)
94
+ self._sam_model = SamModel.from_pretrained(self.sam_model_name, cache_dir=cache_dir)
95
+ self._sam_model = self._sam_model.to(self.device)
96
+ self._sam_model.eval()
97
+
98
+ # Keep BaseSegmenter.model set for consistency
99
+ self.model = self._sam_model
100
+
101
+ def segment(
102
+ self,
103
+ image: Image.Image,
104
+ target_classes: Optional[List[str]] = None,
105
+ **kwargs,
106
+ ) -> np.ndarray:
107
+ self.ensure_loaded()
108
+
109
+ prompts: List[str]
110
+ if target_classes is None or len(target_classes) == 0:
111
+ prompts = ["object"]
112
+ else:
113
+ # Treat provided "classes" as free-form text prompts.
114
+ prompts = [str(p).strip() for p in target_classes if str(p).strip()]
115
+ if not prompts:
116
+ prompts = ["object"]
117
+
118
+ box_threshold = float(kwargs.get("box_threshold", self.box_threshold))
119
+ max_boxes = int(kwargs.get("max_boxes", self.max_boxes))
120
+
121
+ detections = self._detector(image, candidate_labels=prompts)
122
+
123
+ # HF pipeline may return dict (single) or list
124
+ if isinstance(detections, dict):
125
+ detections = [detections]
126
+
127
+ boxes: List[List[float]] = []
128
+ for det in detections:
129
+ score = float(det.get("score", 0.0))
130
+ if score < box_threshold:
131
+ continue
132
+ b = det.get("box") or {}
133
+ xmin = float(b.get("xmin", 0.0))
134
+ ymin = float(b.get("ymin", 0.0))
135
+ xmax = float(b.get("xmax", 0.0))
136
+ ymax = float(b.get("ymax", 0.0))
137
+ # Sanity clamp
138
+ xmin, ymin = max(0.0, xmin), max(0.0, ymin)
139
+ xmax, ymax = max(xmin + 1.0, xmax), max(ymin + 1.0, ymax)
140
+ boxes.append([xmin, ymin, xmax, ymax])
141
+
142
+ if not boxes:
143
+ return np.zeros((image.height, image.width), dtype=np.float32)
144
+
145
+ boxes = boxes[:max_boxes]
146
+
147
+ # SAM expects a batch; provide one image with N boxes
148
+ inputs = self._sam_processor(
149
+ image,
150
+ input_boxes=[boxes],
151
+ return_tensors="pt",
152
+ )
153
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
154
+
155
+ with torch.no_grad():
156
+ outputs = self._sam_model(**inputs)
157
+
158
+ # Post-process masks back to original image size
159
+ # Returns list (batch) of tensors: [num_boxes, H, W]
160
+ post = self._sam_processor.image_processor.post_process_masks(
161
+ outputs.pred_masks.detach().cpu(),
162
+ inputs["original_sizes"].detach().cpu(),
163
+ inputs["reshaped_input_sizes"].detach().cpu(),
164
+ )
165
+
166
+ masks0 = post[0]
167
+ if isinstance(masks0, (list, tuple)):
168
+ # Defensive: some versions may nest
169
+ masks0 = torch.stack([m.squeeze(0) if m.ndim == 3 else m for m in masks0], dim=0)
170
+
171
+ # masks0: [num_boxes, H, W] or [num_boxes, 1, H, W]
172
+ if masks0.ndim == 4:
173
+ masks0 = masks0[:, 0]
174
+
175
+ combined = (masks0 > 0.5).any(dim=0).to(torch.float32)
176
+ return combined.numpy()
177
+
178
+ def get_available_classes(self) -> Union[List[str], dict]:
179
+ # Prompt-based model: not a fixed class list.
180
+ return []
181
+
182
+ def get_default_classes(self) -> List[str]:
183
+ return ["object"]
184
+
185
+ # SAM3 (OWL-ViT detector → SAM masks) is inherently sequential;
186
+ # the two-stage pipeline does not support batched inference.
187
+ supports_batch: bool = False
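
A minimal sketch of the prompt-driven path, assuming the OWL-ViT and SAM weights download on first use and an illustrative input path. Unlike the class-based segmenters, `target_classes` here is free-form text.

```python
from PIL import Image

from segmentation.sam3 import SAM3Segmenter

seg = SAM3Segmenter(device='cpu', box_threshold=0.05, max_boxes=3)
mask = seg.segment(Image.open('street.jpg'), target_classes=['a red car'])
print(int(mask.sum()), 'ROI pixels')  # 0 if no box cleared the threshold
```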
segmentation/segformer.py ADDED
@@ -0,0 +1,193 @@
1
+ """
2
+ SegFormer-based segmentation implementation.
3
+ Uses Cityscapes-trained model for semantic segmentation.
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ import cv2
9
+ from PIL import Image
10
+ from typing import List, Optional, Dict
11
+ from .base import BaseSegmenter
12
+ from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs
13
+
14
+
15
+ class SegFormerSegmenter(BaseSegmenter):
16
+ """
17
+ SegFormer segmentation using Cityscapes classes.
18
+
19
+ Supports semantic segmentation of 19 urban scene classes including
20
+ vehicles, pedestrians, buildings, and road infrastructure.
21
+ """
22
+
23
+ # Cityscapes class mapping
24
+ CITYSCAPES_CLASSES = {
25
+ 'road': 0, 'sidewalk': 1, 'building': 2, 'wall': 3, 'fence': 4,
26
+ 'pole': 5, 'traffic light': 6, 'traffic sign': 7, 'vegetation': 8,
27
+ 'terrain': 9, 'sky': 10, 'person': 11, 'rider': 12, 'car': 13,
28
+ 'truck': 14, 'bus': 15, 'train': 16, 'motorcycle': 17, 'bicycle': 18
29
+ }
30
+
31
+ def __init__(
32
+ self,
33
+ device: str = 'cuda',
34
+ model_name: str = "nvidia/segformer-b4-finetuned-cityscapes-1024-1024",
35
+ **kwargs
36
+ ):
37
+ """
38
+ Initialize SegFormer segmenter.
39
+
40
+ Args:
41
+ device: Device to run on ('cuda' or 'cpu')
42
+ model_name: HuggingFace model identifier
43
+ **kwargs: Additional parameters
44
+ """
45
+ super().__init__(device=device, **kwargs)
46
+ self.model_name = model_name
47
+ self.processor = None
48
+
49
+ def load_model(self):
50
+ """Load SegFormer model from HuggingFace."""
51
+ from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
52
+
53
+ print(f"Loading SegFormer model: {self.model_name}")
54
+ ensure_default_checkpoint_dirs()
55
+ cache_dir = str(hf_cache_dir())
56
+ self.processor = SegformerImageProcessor.from_pretrained(self.model_name, cache_dir=cache_dir)
57
+ self.model = SegformerForSemanticSegmentation.from_pretrained(
58
+ self.model_name,
59
+ cache_dir=cache_dir,
60
+ ).to(self.device)
61
+ self.model.eval()
62
+ print("SegFormer model loaded successfully")
63
+
64
+ def segment(
65
+ self,
66
+ image: Image.Image,
67
+ target_classes: Optional[List[str]] = None,
68
+ **kwargs
69
+ ) -> np.ndarray:
70
+ """
71
+ Create segmentation mask using SegFormer.
72
+
73
+ Args:
74
+ image: PIL Image
75
+ target_classes: List of class names (e.g., ['car', 'building', 'person'])
76
+ **kwargs: Additional parameters (unused)
77
+
78
+ Returns:
79
+ Binary mask (H, W) with 1 for target classes, 0 for background
80
+ """
81
+ # Validate classes
82
+ if target_classes is None:
83
+ target_classes = self.get_default_classes()
84
+
85
+ # Get target class IDs
86
+ target_ids = []
87
+ for cls in target_classes:
88
+ cls_lower = cls.lower()
89
+ if cls_lower in self.CITYSCAPES_CLASSES:
90
+ target_ids.append(self.CITYSCAPES_CLASSES[cls_lower])
91
+ else:
92
+ print(f"Warning: '{cls}' not in Cityscapes classes. "
93
+ f"Available: {list(self.CITYSCAPES_CLASSES.keys())}")
94
+
95
+ if not target_ids:
96
+ print("Warning: No valid classes found. Using 'car' as default.")
97
+ target_ids = [self.CITYSCAPES_CLASSES['car']]
98
+
99
+ # Process image
100
+ orig_size = image.size
101
+ inputs = self.processor(images=image, return_tensors="pt")
102
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
103
+
104
+ with torch.no_grad():
105
+ outputs = self.model(**inputs)
106
+ logits = outputs.logits
107
+
108
+ # Get segmentation map
109
+ seg_map = torch.argmax(logits, dim=1)[0].cpu().numpy()
110
+
111
+ # Resize to original size
112
+ seg_map_resized = cv2.resize(
113
+ seg_map.astype(np.uint8),
114
+ orig_size,
115
+ interpolation=cv2.INTER_NEAREST
116
+ )
117
+
118
+ # Create binary mask for target classes
119
+ mask = np.zeros_like(seg_map_resized, dtype=np.float32)
120
+ for class_id in target_ids:
121
+ mask[seg_map_resized == class_id] = 1.0
122
+
123
+ return mask
124
+
125
+ def get_available_classes(self) -> Dict[str, int]:
126
+ """Get Cityscapes class mapping."""
127
+ return self.CITYSCAPES_CLASSES
128
+
129
+ def get_default_classes(self) -> List[str]:
130
+ """Default to car segmentation."""
131
+ return ['car']
132
+
133
+ # SegFormer supports batched inference via the HF processor.
134
+ supports_batch: bool = True
135
+
136
+ def segment_batch(
137
+ self,
138
+ images: List[Image.Image],
139
+ target_classes: Optional[List[str]] = None,
140
+ **kwargs,
141
+ ) -> List[np.ndarray]:
142
+ """Segment a batch of images in a single forward pass.
143
+
144
+ The HuggingFace SegFormer preprocessor natively accepts a list of
145
+ PIL images and returns a batched tensor.
146
+
147
+ Args:
148
+ images: List of PIL Images (should be same resolution for padding)
149
+ target_classes: Cityscapes class names to include in ROI mask
150
+ **kwargs: unused
151
+
152
+ Returns:
153
+ List of binary masks (H, W) float32
154
+ """
155
+ if not images:
156
+ return []
157
+
158
+ self.ensure_loaded()
159
+
160
+ if target_classes is None:
161
+ target_classes = self.get_default_classes()
162
+
163
+ target_ids = []
164
+ for cls in target_classes:
165
+ cls_lower = cls.lower()
166
+ if cls_lower in self.CITYSCAPES_CLASSES:
167
+ target_ids.append(self.CITYSCAPES_CLASSES[cls_lower])
168
+
169
+ if not target_ids:
170
+ target_ids = [self.CITYSCAPES_CLASSES['car']]
171
+
172
+ # Batch preprocess
173
+ inputs = self.processor(images=images, return_tensors="pt", padding=True)
174
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
175
+
176
+ with torch.no_grad():
177
+ outputs = self.model(**inputs)
178
+ logits = outputs.logits # (B, num_classes, h', w')
179
+
180
+ masks: List[np.ndarray] = []
181
+ for i, img in enumerate(images):
182
+ seg_map = torch.argmax(logits[i], dim=0).cpu().numpy()
183
+ seg_resized = cv2.resize(
184
+ seg_map.astype(np.uint8),
185
+ img.size,
186
+ interpolation=cv2.INTER_NEAREST,
187
+ )
188
+ mask = np.zeros_like(seg_resized, dtype=np.float32)
189
+ for cid in target_ids:
190
+ mask[seg_resized == cid] = 1.0
191
+ masks.append(mask)
192
+
193
+ return masks
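
A minimal sketch of Cityscapes semantic segmentation on a driving frame, assuming the HuggingFace model downloads on first use and an illustrative input path.

```python
from PIL import Image

from segmentation.segformer import SegFormerSegmenter

seg = SegFormerSegmenter(device='cpu')
mask = seg.segment(Image.open('dashcam.jpg'), target_classes=['car', 'person'])
print(mask.shape)  # same (H, W) as the input image, 1.0 on car/person pixels
```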
segmentation/utils.py ADDED
@@ -0,0 +1,90 @@
1
+ """
2
+ Utility functions for segmentation visualization and I/O.
3
+ """
4
+
5
+ import numpy as np
6
+ import cv2
7
+ from PIL import Image
8
+ from typing import Tuple
9
+
10
+
11
+ def visualize_mask(
12
+ image: Image.Image,
13
+ mask: np.ndarray,
14
+ alpha: float = 0.5,
15
+ color: Tuple[int, int, int] = (255, 0, 0)
16
+ ) -> Image.Image:
17
+ """
18
+ Overlay segmentation mask on image.
19
+
20
+ Args:
21
+ image: PIL Image
22
+ mask: Binary mask (H, W)
23
+ alpha: Transparency (0-1)
24
+ color: RGB color tuple for mask
25
+
26
+ Returns:
27
+ Image with mask overlay
28
+ """
29
+ # Convert to numpy
30
+ img_array = np.array(image)
31
+
32
+ # Create colored mask
33
+ colored_mask = np.zeros_like(img_array)
34
+ colored_mask[mask > 0.5] = color
35
+
36
+ # Blend
37
+ result = cv2.addWeighted(img_array, 1.0, colored_mask, alpha, 0)
38
+
39
+ return Image.fromarray(result)
40
+
41
+
42
+ def save_mask(mask: np.ndarray, output_path: str):
43
+ """
44
+ Save mask as image (white for ROI, black for background).
45
+
46
+ Args:
47
+ mask: Binary mask array (H, W)
48
+ output_path: Path to save mask image
49
+ """
50
+ mask_img = (mask * 255).astype(np.uint8)
51
+ Image.fromarray(mask_img).save(output_path)
52
+
53
+
54
+ def load_mask(mask_path: str) -> np.ndarray:
55
+ """
56
+ Load mask from image file.
57
+
58
+ Args:
59
+ mask_path: Path to mask image
60
+
61
+ Returns:
62
+ Binary mask as numpy array (H, W) with values 0 or 1
63
+ """
64
+ mask_img = Image.open(mask_path).convert('L')
65
+ mask = np.array(mask_img).astype(np.float32) / 255.0
66
+ return mask
67
+
68
+
69
+ def calculate_roi_stats(mask: np.ndarray) -> dict:
70
+ """
71
+ Calculate statistics about ROI coverage.
72
+
73
+ Args:
74
+ mask: Binary mask (H, W)
75
+
76
+ Returns:
77
+ Dictionary with statistics:
78
+ - roi_pixels: Number of ROI pixels
79
+ - total_pixels: Total number of pixels
80
+ - roi_percentage: Percentage of image covered by ROI
81
+ """
82
+ roi_pixels = int(np.sum(mask > 0.5))
83
+ total_pixels = int(mask.size)
84
+ roi_percentage = (roi_pixels / total_pixels) * 100
85
+
86
+ return {
87
+ 'roi_pixels': roi_pixels,
88
+ 'total_pixels': total_pixels,
89
+ 'roi_percentage': roi_percentage
90
+ }
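
A minimal sketch tying the utilities together; the image path is illustrative and the mask here is synthetic.

```python
import numpy as np
from PIL import Image

from segmentation.utils import visualize_mask, save_mask, calculate_roi_stats

img = Image.open('street.jpg')  # illustrative path
mask = np.zeros((img.height, img.width), dtype=np.float32)
mask[100:300, 200:500] = 1.0  # synthetic ROI block for demonstration

visualize_mask(img, mask, alpha=0.4, color=(0, 255, 0)).save('overlay.png')
save_mask(mask, 'mask.png')
print(calculate_roi_stats(mask))  # roi_pixels / total_pixels / roi_percentage
```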
segmentation/yolo.py ADDED
@@ -0,0 +1,188 @@
1
+ """
2
+ YOLO-based instance segmentation implementation.
3
+ Uses YOLO26 (or YOLOv8 fallback) for COCO object detection and segmentation.
4
+ Supports true batch inference for video processing.
5
+ """
6
+
7
+ import numpy as np
8
+ import cv2
9
+ from PIL import Image
10
+ from typing import List, Optional
11
+ from .base import BaseSegmenter
12
+
13
+
14
+ class YOLOSegmenter(BaseSegmenter):
15
+ """
16
+ YOLO instance segmentation using COCO classes.
17
+
18
+ Defaults to YOLO26x-seg for best accuracy.
19
+ Supports batch inference via ``segment_batch()`` for video pipelines.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ device: str = 'cuda',
25
+ model_path: str = 'checkpoints/yolo26x-seg.pt',
26
+ conf_threshold: float = 0.25,
27
+ **kwargs
28
+ ):
29
+ """
30
+ Initialize YOLO segmenter.
31
+
32
+ Args:
33
+ device: Device to run on ('cuda' or 'cpu')
34
+ model_path: Path to YOLO weights file (yolo26x-seg.pt default)
35
+ conf_threshold: Confidence threshold for detections
36
+ **kwargs: Additional parameters
37
+ """
38
+ super().__init__(device=device, **kwargs)
39
+ self.model_path = model_path
40
+ self.conf_threshold = conf_threshold
41
+
42
+ def load_model(self):
43
+ """Load YOLO model."""
44
+ from ultralytics import YOLO
45
+
46
+ print(f"Loading YOLO model: {self.model_path}")
47
+ self.model = YOLO(self.model_path)
48
+ print("YOLO model loaded successfully")
49
+
50
+ def _extract_mask(
51
+ self,
52
+ result,
53
+ width: int,
54
+ height: int,
55
+ target_classes: List[str],
56
+ ) -> np.ndarray:
57
+ """Extract combined binary mask from a single YOLO result object.
58
+
59
+ Args:
60
+ result: Single ultralytics Results object
61
+ width: Target mask width
62
+ height: Target mask height
63
+ target_classes: Classes to include
64
+
65
+ Returns:
66
+ Binary mask (H, W) float32
67
+ """
68
+ mask = np.zeros((height, width), dtype=np.float32)
69
+
70
+ if result.masks is None:
71
+ return mask
72
+
73
+ masks_data = result.masks.data
74
+ boxes = result.boxes
75
+
76
+ for idx, box in enumerate(boxes):
77
+ class_id = int(box.cls[0])
78
+ class_name = self.model.names[class_id].lower()
79
+
80
+ if any(target.lower() in class_name for target in target_classes):
81
+ instance_mask = masks_data[idx].cpu().numpy()
82
+ instance_mask_resized = cv2.resize(
83
+ instance_mask,
84
+ (width, height),
85
+ interpolation=cv2.INTER_LINEAR,
86
+ )
87
+ mask = np.maximum(mask, instance_mask_resized)
88
+
89
+ return (mask > 0.5).astype(np.float32)
90
+
91
+ def segment(
92
+ self,
93
+ image: Image.Image,
94
+ target_classes: Optional[List[str]] = None,
95
+ conf_threshold: Optional[float] = None,
96
+ **kwargs
97
+ ) -> np.ndarray:
98
+ """
99
+ Create segmentation mask using YOLO.
100
+
101
+ Args:
102
+ image: PIL Image
103
+ target_classes: List of class names (e.g., ['car', 'person'])
104
+ conf_threshold: Override default confidence threshold
105
+ **kwargs: Additional parameters
106
+
107
+ Returns:
108
+ Binary mask (H, W) with 1 for target instances, 0 for background
109
+ """
110
+ self.ensure_loaded()  # lazy-load YOLO weights on first use
+ threshold = conf_threshold if conf_threshold is not None else self.conf_threshold
111
+ if target_classes is None:
112
+ target_classes = self.get_default_classes()
113
+
114
+ results = self.model(image, verbose=False, conf=threshold, device=self.device)
115
+
116
+ if not results:
117
+ return np.zeros((image.height, image.width), dtype=np.float32)
118
+
119
+ return self._extract_mask(results[0], image.width, image.height, target_classes)
120
+
121
+ def segment_batch(
122
+ self,
123
+ images: List[Image.Image],
124
+ target_classes: Optional[List[str]] = None,
125
+ conf_threshold: Optional[float] = None,
126
+ **kwargs,
127
+ ) -> List[np.ndarray]:
128
+ """Segment a batch of images in a single YOLO forward pass.
129
+
130
+ Args:
131
+ images: List of PIL Images (should be the same resolution)
132
+ target_classes: Class names to include in ROI mask
133
+ conf_threshold: Override default confidence threshold
134
+
135
+ Returns:
136
+ List of binary masks (H, W), one per input image
137
+ """
138
+ if not images:
139
+ return []
140
+
141
+ self.ensure_loaded()
142
+
143
+ threshold = conf_threshold if conf_threshold is not None else self.conf_threshold
144
+ if target_classes is None:
145
+ target_classes = self.get_default_classes()
146
+
147
+ # Ultralytics accepts a list of PIL images for batch inference
148
+ results = self.model(images, verbose=False, conf=threshold, device=self.device)
149
+
150
+ masks = []
151
+ for i, result in enumerate(results):
152
+ img = images[i]
153
+ masks.append(self._extract_mask(result, img.width, img.height, target_classes))
154
+
155
+ return masks
156
+
157
+ def get_available_classes(self) -> List[str]:
158
+ """
159
+ Get COCO class names.
160
+
161
+ Returns:
162
+ List of COCO class names (80 classes)
163
+ """
164
+ if self.model is None:
165
+ # Return common COCO classes if model not loaded
166
+ return [
167
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
168
+ 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
169
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
170
+ 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
171
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
172
+ 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
173
+ 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
174
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
175
+ 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
176
+ 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
177
+ 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
178
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
179
+ 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
180
+ ]
181
+ return list(self.model.names.values())
182
+
183
+ def get_default_classes(self) -> List[str]:
184
+ """Default to car segmentation."""
185
+ return ['car']
186
+
187
+ # YOLO natively supports batched inference via ultralytics.
188
+ supports_batch: bool = True
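
A minimal sketch contrasting the single-image and batched paths, assuming the segmentation weights referenced by `model_path` are present under `checkpoints/` and the frame paths exist.

```python
from PIL import Image

from segmentation.yolo import YOLOSegmenter

seg = YOLOSegmenter(device='cpu', conf_threshold=0.3)
frames = [Image.open(f'frames/{i:04d}.jpg') for i in range(4)]  # illustrative

single = seg.segment(frames[0], target_classes=['car'])
batched = seg.segment_batch(frames, target_classes=['car'])
assert single.shape == batched[0].shape  # both (H, W) binary masks
```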
vae/RSTB.py ADDED
@@ -0,0 +1,813 @@
1
+ # Copyright (c) 2021-2022, InterDigital Communications, Inc
2
+ # All rights reserved.
3
+
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted (subject to the limitations in the disclaimer
6
+ # below) provided that the following conditions are met:
7
+
8
+ # * Redistributions of source code must retain the above copyright notice,
9
+ # this list of conditions and the following disclaimer.
10
+ # * Redistributions in binary form must reproduce the above copyright notice,
11
+ # this list of conditions and the following disclaimer in the documentation
12
+ # and/or other materials provided with the distribution.
13
+ # * Neither the name of InterDigital Communications, Inc nor the names of its
14
+ # contributors may be used to endorse or promote products derived from this
15
+ # software without specific prior written permission.
16
+
17
+ # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
18
+ # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
19
+ # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
20
+ # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
21
+ # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
22
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
25
+ # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
26
+ # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
27
+ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
28
+ # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+
30
+ from typing import Any
31
+
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ from torch import Tensor
36
+ from torch.autograd import Function
37
+ import torch.utils.checkpoint as checkpoint
38
+
39
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
40
+
41
+ from compressai.layers import GDN
42
+
43
+ __all__ = [
44
+ "AttentionBlock",
45
+ "MaskedConv2d",
46
+ "ResidualBlock",
47
+ "ResidualBlockUpsample",
48
+ "ResidualBlockWithStride",
49
+ "conv3x3",
50
+ "subpel_conv3x3",
51
+ "QReLU",
52
+ "RSTB",
53
+ "CausalAttentionModule",
54
+ ]
55
+
56
+
57
+ class MaskedConv2d(nn.Conv2d):
58
+ r"""Masked 2D convolution implementation, mask future "unseen" pixels.
59
+ Useful for building auto-regressive network components.
60
+
61
+ Introduced in `"Conditional Image Generation with PixelCNN Decoders"
62
+ <https://arxiv.org/abs/1606.05328>`_.
63
+
64
+ Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the
65
+ first layer (which also masks the "current pixel"), `mask_type='B'` for the
66
+ following layers.
67
+ """
68
+
69
+ def __init__(self, *args: Any, mask_type: str = "A", **kwargs: Any):
70
+ super().__init__(*args, **kwargs)
71
+
72
+ if mask_type not in ("A", "B"):
73
+ raise ValueError(f'Invalid "mask_type" value "{mask_type}"')
74
+
75
+ self.register_buffer("mask", torch.ones_like(self.weight.data))
76
+ _, _, h, w = self.mask.size()
77
+ self.mask[:, :, h // 2, w // 2 + (mask_type == "B"):] = 0
78
+ self.mask[:, :, h // 2 + 1:] = 0
79
+
80
+ def forward(self, x: Tensor) -> Tensor:
81
+ # TODO(begaintj): weight assigment is not supported by torchscript
82
+ self.weight.data *= self.mask
83
+ return super().forward(x)
84
+
85
+
86
+ def conv3x3(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
87
+ """3x3 convolution with padding."""
88
+ return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
89
+
90
+
91
+ def subpel_conv3x3(in_ch: int, out_ch: int, r: int = 1) -> nn.Sequential:
92
+ """3x3 sub-pixel convolution for up-sampling."""
93
+ return nn.Sequential(
94
+ nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r)
95
+ )
96
+
97
+
98
+ def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
99
+ """1x1 convolution."""
100
+ return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)
101
+
102
+
103
+ class ResidualBlockWithStride(nn.Module):
104
+ """Residual block with a stride on the first convolution.
105
+
106
+ Args:
107
+ in_ch (int): number of input channels
108
+ out_ch (int): number of output channels
109
+ stride (int): stride value (default: 2)
110
+ """
111
+
112
+ def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
113
+ super().__init__()
114
+ self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
115
+ self.leaky_relu = nn.LeakyReLU(inplace=True)
116
+ self.conv2 = conv3x3(out_ch, out_ch)
117
+ self.gdn = GDN(out_ch)
118
+ if stride != 1 or in_ch != out_ch:
119
+ self.skip = conv1x1(in_ch, out_ch, stride=stride)
120
+ else:
121
+ self.skip = None
122
+
123
+ def forward(self, x: Tensor) -> Tensor:
124
+ identity = x
125
+ out = self.conv1(x)
126
+ out = self.leaky_relu(out)
127
+ out = self.conv2(out)
128
+ out = self.gdn(out)
129
+
130
+ if self.skip is not None:
131
+ identity = self.skip(x)
132
+
133
+ out += identity
134
+ return out
135
+
136
+
137
+ class ResidualBlockUpsample(nn.Module):
138
+ """Residual block with sub-pixel upsampling on the last convolution.
139
+
140
+ Args:
141
+ in_ch (int): number of input channels
142
+ out_ch (int): number of output channels
143
+ upsample (int): upsampling factor (default: 2)
144
+ """
145
+
146
+ def __init__(self, in_ch: int, out_ch: int, upsample: int = 2):
147
+ super().__init__()
148
+ self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
149
+ self.leaky_relu = nn.LeakyReLU(inplace=True)
150
+ self.conv = conv3x3(out_ch, out_ch)
151
+ self.igdn = GDN(out_ch, inverse=True)
152
+ self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)
153
+
154
+ def forward(self, x: Tensor) -> Tensor:
155
+ identity = x
156
+ out = self.subpel_conv(x)
157
+ out = self.leaky_relu(out)
158
+ out = self.conv(out)
159
+ out = self.igdn(out)
160
+ identity = self.upsample(x)
161
+ out += identity
162
+ return out
163
+
164
+
165
+ class ResidualBlock(nn.Module):
166
+ """Simple residual block with two 3x3 convolutions.
167
+
168
+ Args:
169
+ in_ch (int): number of input channels
170
+ out_ch (int): number of output channels
171
+ """
172
+
173
+ def __init__(self, in_ch: int, out_ch: int):
174
+ super().__init__()
175
+ self.conv1 = conv3x3(in_ch, out_ch)
176
+ self.leaky_relu = nn.LeakyReLU(inplace=True)
177
+ self.conv2 = conv3x3(out_ch, out_ch)
178
+ if in_ch != out_ch:
179
+ self.skip = conv1x1(in_ch, out_ch)
180
+ else:
181
+ self.skip = None
182
+
183
+ def forward(self, x: Tensor) -> Tensor:
184
+ identity = x
185
+
186
+ out = self.conv1(x)
187
+ out = self.leaky_relu(out)
188
+ out = self.conv2(out)
189
+ out = self.leaky_relu(out)
190
+
191
+ if self.skip is not None:
192
+ identity = self.skip(x)
193
+
194
+ out = out + identity
195
+ return out
196
+
197
+
198
+ class AttentionBlock(nn.Module):
199
+ """Self attention block.
200
+
201
+ Simplified variant from `"Learned Image Compression with
202
+ Discretized Gaussian Mixture Likelihoods and Attention Modules"
203
+ <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
204
+ Takeuchi, Jiro Katto.
205
+
206
+ Args:
207
+ N (int): Number of channels
208
+ """
209
+
210
+ def __init__(self, N: int):
211
+ super().__init__()
212
+
213
+ class ResidualUnit(nn.Module):
214
+ """Simple residual unit."""
215
+
216
+ def __init__(self):
217
+ super().__init__()
218
+ self.conv = nn.Sequential(
219
+ conv1x1(N, N // 2),
220
+ nn.ReLU(inplace=True),
221
+ conv3x3(N // 2, N // 2),
222
+ nn.ReLU(inplace=True),
223
+ conv1x1(N // 2, N),
224
+ )
225
+ self.relu = nn.ReLU(inplace=True)
226
+
227
+ def forward(self, x: Tensor) -> Tensor:
228
+ identity = x
229
+ out = self.conv(x)
230
+ out += identity
231
+ out = self.relu(out)
232
+ return out
233
+
234
+ self.conv_a = nn.Sequential(ResidualUnit(), ResidualUnit(), ResidualUnit())
235
+
236
+ self.conv_b = nn.Sequential(
237
+ ResidualUnit(),
238
+ ResidualUnit(),
239
+ ResidualUnit(),
240
+ conv1x1(N, N),
241
+ )
242
+
243
+ def forward(self, x: Tensor) -> Tensor:
244
+ identity = x
245
+ a = self.conv_a(x)
246
+ b = self.conv_b(x)
247
+ out = a * torch.sigmoid(b)
248
+ out += identity
249
+ return out
250
+
251
+
252
+ class QReLU(Function):
253
+ """QReLU
254
+
255
+ Clamps the input to the given bit-depth range. The input is assumed to be
+ integer-valued (as produced by an integer network); non-integer inputs are
+ simply clamped, without any rounding.
259
+
260
+ Pre-computed scale with gamma function is used for backward computation.
261
+
262
+ More details can be found in
263
+ `"Integer networks for data compression with latent-variable models"
264
+ <https://openreview.net/pdf?id=S1zz2i0cY7>`_,
265
+ by Johannes Ballé, Nick Johnston and David Minnen, ICLR in 2019
266
+
267
+ Args:
268
+ input: a tensor data
269
+ bit_depth: source bit-depth (used for clamping)
270
+ beta: a parameter for modeling the gradient during backward computation
271
+ """
272
+
273
+ @staticmethod
274
+ def forward(ctx, input, bit_depth, beta):
275
+ # TODO(choih): allow to use adaptive scale instead of
276
+ # pre-computed scale with gamma function
277
+ ctx.alpha = 0.9943258522851727
278
+ ctx.beta = beta
279
+ ctx.max_value = 2 ** bit_depth - 1
280
+ ctx.save_for_backward(input)
281
+
282
+ return input.clamp(min=0, max=ctx.max_value)
283
+
284
+ @staticmethod
285
+ def backward(ctx, grad_output):
286
+ grad_input = None
287
+ (input,) = ctx.saved_tensors
288
+
289
+ grad_input = grad_output.clone()
290
+ grad_sub = (
291
+ torch.exp(
292
+ (-ctx.alpha ** ctx.beta)
293
+ * torch.abs(2.0 * input / ctx.max_value - 1) ** ctx.beta
294
+ )
295
+ * grad_output.clone()
296
+ )
297
+
298
+ grad_input[input < 0] = grad_sub[input < 0]
299
+ grad_input[input > ctx.max_value] = grad_sub[input > ctx.max_value]
300
+
301
+ return grad_input, None, None
302
+
303
+
304
+ class PatchEmbed(nn.Module):
305
+ def __init__(self):
306
+ super().__init__()
307
+
308
+ def forward(self, x):
309
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
310
+ return x
311
+
312
+ def flops(self):
313
+ flops = 0
314
+ return flops
315
+
316
+
317
+ class PatchUnEmbed(nn.Module):
318
+ def __init__(self):
319
+ super().__init__()
320
+
321
+ def forward(self, x, x_size):
322
+ B, HW, C = x.shape
323
+ x = x.transpose(1, 2).view(B, -1, x_size[0], x_size[1])
324
+ return x
325
+
326
+ def flops(self):
327
+ flops = 0
328
+ return flops
329
+
330
+
331
+ class Mlp(nn.Module):
332
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
333
+ super().__init__()
334
+ out_features = out_features or in_features
335
+ hidden_features = hidden_features or in_features
336
+ self.fc1 = nn.Linear(in_features, hidden_features)
337
+ self.act = act_layer()
338
+ self.fc2 = nn.Linear(hidden_features, out_features)
339
+ self.drop = nn.Dropout(drop)
340
+
341
+ def forward(self, x):
342
+ x = self.fc1(x)
343
+ x = self.act(x)
344
+ x = self.drop(x)
345
+ x = self.fc2(x)
346
+ x = self.drop(x)
347
+ return x
348
+
349
+
350
+ def window_partition(x, window_size):
351
+ """
352
+ Args:
353
+ x: (B, H, W, C)
354
+ window_size (int): window size
355
+ Returns:
356
+ windows: (num_windows*B, window_size, window_size, C)
357
+ """
358
+ B, H, W, C = x.shape
359
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
360
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
361
+ return windows
362
+
363
+
364
+ def window_reverse(windows, window_size, H, W):
365
+ """
366
+ Args:
367
+ windows: (num_windows*B, window_size, window_size, C)
368
+ window_size (int): Window size
369
+ H (int): Height of image
370
+ W (int): Width of image
371
+ Returns:
372
+ x: (B, H, W, C)
373
+ """
374
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
375
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
376
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
377
+ return x
378
+
379
+
380
+ class WindowAttention(nn.Module):
381
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
382
+ It supports both of shifted and non-shifted window.
383
+ Args:
384
+ dim (int): Number of input channels.
385
+ window_size (tuple[int]): The height and width of the window.
386
+ num_heads (int): Number of attention heads.
387
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
388
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
389
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
390
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
391
+ """
392
+
393
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
394
+
395
+ super().__init__()
396
+ self.dim = dim
397
+ self.window_size = window_size # Wh, Ww
398
+ self.num_heads = num_heads
399
+ head_dim = dim // num_heads
400
+ self.scale = qk_scale or head_dim ** -0.5
401
+
402
+ # define a parameter table of relative position bias
403
+ self.relative_position_bias_table = nn.Parameter(
404
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
405
+
406
+ # get pair-wise relative position index for each token inside the window
407
+ coords_h = torch.arange(self.window_size[0])
408
+ coords_w = torch.arange(self.window_size[1])
409
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
410
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
411
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
412
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
413
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
414
+ relative_coords[:, :, 1] += self.window_size[1] - 1
415
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
416
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
417
+ self.register_buffer("relative_position_index", relative_position_index)
418
+
419
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
420
+ self.attn_drop = nn.Dropout(attn_drop)
421
+ self.proj = nn.Linear(dim, dim)
422
+
423
+ self.proj_drop = nn.Dropout(proj_drop)
424
+
425
+ trunc_normal_(self.relative_position_bias_table, std=.02)
426
+ self.softmax = nn.Softmax(dim=-1)
427
+
428
+ def forward(self, x, mask=None):
429
+ """
430
+ Args:
431
+ x: input features with shape of (num_windows*B, N, C)
432
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
433
+ """
434
+ B_, N, C = x.shape
435
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
436
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
437
+
438
+ q = q * self.scale
439
+ attn = (q @ k.transpose(-2, -1))
440
+
441
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
442
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
443
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
444
+ attn = attn + relative_position_bias.unsqueeze(0)
445
+
446
+ if mask is not None:
447
+ nW = mask.shape[0]
448
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
449
+ attn = attn.view(-1, self.num_heads, N, N)
450
+ attn = self.softmax(attn)
451
+ else:
452
+ attn = self.softmax(attn)
453
+
454
+ attn = self.attn_drop(attn)
455
+
456
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
457
+ x = self.proj(x)
458
+ x = self.proj_drop(x)
459
+ return x
460
+
461
+ def extra_repr(self) -> str:
462
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
463
+
464
+ def flops(self, N):
465
+ # calculate flops for 1 window with token length of N
466
+ flops = 0
467
+ # qkv = self.qkv(x)
468
+ flops += N * self.dim * 3 * self.dim
469
+ # attn = (q @ k.transpose(-2, -1))
470
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
471
+ # x = (attn @ v)
472
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
473
+ # x = self.proj(x)
474
+ flops += N * self.dim * self.dim
475
+ return flops
476
+
477
+
478
+ class SwinTransformerBlock(nn.Module):
479
+ r""" Swin Transformer Block.
480
+ Args:
481
+ dim (int): Number of input channels.
482
+ input_resolution (tuple[int]): Input resolution.
483
+ num_heads (int): Number of attention heads.
484
+ window_size (int): Window size.
485
+ shift_size (int): Shift size for SW-MSA.
486
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
487
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
488
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
489
+ drop (float, optional): Dropout rate. Default: 0.0
490
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
491
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
492
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
493
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
494
+ """
495
+
496
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
497
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
498
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
499
+ super().__init__()
500
+ self.dim = dim
501
+ self.input_resolution = input_resolution
502
+ self.num_heads = num_heads
503
+ self.window_size = window_size
504
+ self.shift_size = shift_size
505
+ self.mlp_ratio = mlp_ratio
506
+ if min(self.input_resolution) <= self.window_size:
507
+ # if window size is larger than input resolution, we don't partition windows
508
+ self.shift_size = 0
509
+ self.window_size = min(self.input_resolution)
510
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
511
+
512
+ self.norm1 = norm_layer(dim)
513
+ self.attn = WindowAttention(
514
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
515
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
516
+
517
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
518
+ self.norm2 = norm_layer(dim)
519
+ mlp_hidden_dim = int(dim * mlp_ratio)
520
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
521
+
522
+ if self.shift_size > 0:
523
+ attn_mask = self.calculate_mask(self.input_resolution)
524
+ else:
525
+ attn_mask = None
526
+
527
+ self.register_buffer("attn_mask", attn_mask)
528
+
529
+ def calculate_mask(self, x_size):
530
+ # calculate attention mask for SW-MSA
531
+ H, W = x_size
532
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
533
+ h_slices = (slice(0, -self.window_size),
534
+ slice(-self.window_size, -self.shift_size),
535
+ slice(-self.shift_size, None))
536
+ w_slices = (slice(0, -self.window_size),
537
+ slice(-self.window_size, -self.shift_size),
538
+ slice(-self.shift_size, None))
539
+ cnt = 0
540
+ for h in h_slices:
541
+ for w in w_slices:
542
+ img_mask[:, h, w, :] = cnt
543
+ cnt += 1
544
+
545
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
546
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
547
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
548
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
549
+
550
+ return attn_mask
551
+
552
+ def forward(self, x, x_size):
553
+ H, W = x_size
554
+ B, L, C = x.shape
555
+ # assert L == H * W, "input feature has wrong size"
556
+
557
+ shortcut = x
558
+ x = self.norm1(x)
559
+ x = x.view(B, H, W, C)
560
+
561
+ # cyclic shift
562
+ if self.shift_size > 0:
563
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
564
+ else:
565
+ shifted_x = x
566
+
567
+ # partition windows
568
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
569
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
570
+
571
+ # W-MSA/SW-MSA (recompute the attention mask when the runtime input size differs from the resolution the block was built for, e.g. when testing on images whose shapes are not multiples of the window size)
572
+ if self.input_resolution == x_size:
573
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
574
+ else:
575
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
576
+
577
+ # merge windows
578
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
579
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
580
+
581
+ # reverse cyclic shift
582
+ if self.shift_size > 0:
583
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
584
+ else:
585
+ x = shifted_x
586
+ x = x.view(B, H * W, C)
587
+
588
+ # FFN
589
+ x = shortcut + self.drop_path(x)
590
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
591
+
592
+ return x
593
+
594
+ def extra_repr(self) -> str:
595
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
596
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
597
+
598
+ def flops(self):
599
+ flops = 0
600
+ H, W = self.input_resolution
601
+ # norm1
602
+ flops += self.dim * H * W
603
+ # W-MSA/SW-MSA
604
+ nW = H * W / self.window_size / self.window_size
605
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
606
+ # mlp
607
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
608
+ # norm2
609
+ flops += self.dim * H * W
610
+ return flops
611
+
612
+
613
+ class BasicLayer(nn.Module):
614
+ """ A basic Swin Transformer layer for one stage.
615
+ Args:
616
+ dim (int): Number of input channels.
617
+ input_resolution (tuple[int]): Input resolution.
618
+ depth (int): Number of blocks.
619
+ num_heads (int): Number of attention heads.
620
+ window_size (int): Local window size.
621
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
622
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
623
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
624
+ drop (float, optional): Dropout rate. Default: 0.0
625
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
626
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
627
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
628
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
629
+ """
630
+
631
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
632
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
633
+ drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False):
634
+
635
+ super().__init__()
636
+ self.dim = dim
637
+ self.input_resolution = input_resolution
638
+ self.depth = depth
639
+ self.use_checkpoint = use_checkpoint
640
+
641
+ # build blocks
642
+ self.blocks = nn.ModuleList([
643
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
644
+ num_heads=num_heads, window_size=window_size,
645
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
646
+ mlp_ratio=mlp_ratio,
647
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
648
+ drop=drop, attn_drop=attn_drop,
649
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
650
+ norm_layer=norm_layer)
651
+ for i in range(depth)])
652
+
653
+ def forward(self, x, x_size):
654
+ for blk in self.blocks:
655
+ if self.use_checkpoint:
656
+ x = checkpoint.checkpoint(blk, x, x_size)  # x_size must be forwarded when checkpointing
657
+ else:
658
+ x = blk(x, x_size)
659
+ return x
660
+
661
+ def extra_repr(self) -> str:
662
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
663
+
664
+ def flops(self):
665
+ flops = 0
666
+ for blk in self.blocks:
667
+ flops += blk.flops()
668
+ return flops
669
+
670
+
671
+ class RSTB(nn.Module):
672
+ """Residual Swin Transformer Block (RSTB).
673
+ Args:
674
+ dim (int): Number of input channels.
675
+ input_resolution (tuple[int]): Input resolution.
676
+ depth (int): Number of blocks.
677
+ num_heads (int): Number of attention heads.
678
+ window_size (int): Local window size.
679
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
680
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
681
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
682
+ drop (float, optional): Dropout rate. Default: 0.0
683
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
684
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
685
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
686
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
687
+ """
688
+
689
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
690
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
691
+ drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False):
692
+ super(RSTB, self).__init__()
693
+
694
+ self.dim = dim
695
+ self.input_resolution = input_resolution
696
+
697
+ self.residual_group = BasicLayer(dim=dim,
698
+ input_resolution=input_resolution,
699
+ depth=depth,
700
+ num_heads=num_heads,
701
+ window_size=window_size,
702
+ mlp_ratio=mlp_ratio,
703
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
704
+ drop=drop, attn_drop=attn_drop,
705
+ drop_path=drop_path,
706
+ norm_layer=norm_layer,
707
+ use_checkpoint=use_checkpoint
708
+ )
709
+
710
+ self.patch_embed = PatchEmbed()
711
+ self.patch_unembed = PatchUnEmbed()
712
+
713
+ def forward(self, x, x_size):
714
+ return self.patch_unembed(self.residual_group(self.patch_embed(x), x_size), x_size) + x
715
+
716
+ def flops(self):
717
+ flops = 0
718
+ flops += self.residual_group.flops()
719
+ flops += self.patch_embed.flops()
720
+ flops += self.patch_unembed.flops()
721
+
722
+ return flops
723
+
724
+
725
+ class CausalAttentionModule(nn.Module):
726
+ r""" Causal multi-head self attention module.
727
+
728
+ Args:
729
+ dim (int): Number of input channels.
730
+ window_size (tuple[int]): The height and width of the window.
731
+ num_heads (int): Number of attention heads.
732
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
733
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
734
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
735
+ """
736
+
737
+ def __init__(self, dim, out_dim, block_len=5, num_heads=16, mlp_ratio=4., qkv_bias=True, qk_scale=None,
738
+ attn_drop=0.):
739
+ super().__init__()
740
+ assert dim % num_heads == 0
741
+ self.dim = dim
742
+ self.num_heads = num_heads
743
+ head_dim = dim // num_heads
744
+ self.block_size = block_len * block_len
745
+ self.scale = qk_scale or head_dim ** -0.5
746
+ self.attn_drop = nn.Dropout(attn_drop)
747
+
748
+ self.norm1 = nn.LayerNorm(dim)
749
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
750
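+ # Causal mask over the flattened 5x5 neighborhood (block_len=5): the first 12
+ # entries (the pixels preceding the center in raster order) are kept, while the
+ # center and everything after it are zeroed out.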
+ self.mask = torch.Tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).view(1, self.block_size, 1)
753
+
754
+ # define a parameter table of relative position bias
755
+ self.relative_position_bias_table = nn.Parameter(
756
+ torch.zeros((2 * block_len - 1) * (2 * block_len - 1), num_heads)) # 2*P-1 * 2*P-1, num_heads
757
+
758
+ # get pair-wise relative position index for each token inside the window
759
+ coords_h = torch.arange(block_len)
760
+ coords_w = torch.arange(block_len)
761
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, P, P
762
+ coords_flatten = torch.flatten(coords, 1) # 2, P*P
763
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, PP, PP
764
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # PP, PP, 2
765
+ relative_coords[:, :, 0] += block_len - 1 # shift to start from 0
766
+ relative_coords[:, :, 1] += block_len - 1
767
+ relative_coords[:, :, 0] *= 2 * block_len - 1
768
+ relative_position_index = relative_coords.sum(-1) # PP, PP
769
+ self.register_buffer("relative_position_index", relative_position_index)
770
+
771
+ self.softmax = nn.Softmax(dim=-1)
772
+
773
+ self.norm2 = nn.LayerNorm(dim)
774
+ mlp_hidden_dim = int(dim * mlp_ratio)
775
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU, drop=attn_drop)
776
+ self.proj = nn.Linear(dim, out_dim)
777
+
778
+ def forward(self, x):
779
+ B, C, H, W = x.shape
780
+ x_unfold = F.unfold(x, kernel_size=(5, 5), padding=2) # B, CPP, HW
781
+ x_unfold = x_unfold.reshape(B, C, self.block_size, H * W).permute(0, 3, 2, 1).contiguous().view(-1, self.block_size, C)  # BHW, PP, C
784
+
785
+ x_masked = x_unfold * self.mask.to(x_unfold.device)
786
+ out = self.norm1(x_masked)
787
+ qkv = self.qkv(out).reshape(B * H * W, self.block_size, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # 3, BHW, num_heads, PP, C
790
+ q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple); each is BHW, num_heads, PP, C//num_heads
792
+ q = q * self.scale
793
+ attn = (q @ k.transpose(-2, -1)) # BHW, num_heads, PP, PP
794
+
795
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
796
+ self.block_size, self.block_size, -1) # PP, PP, num_heads
797
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # num_heads, PP, PP
798
+ attn = attn + relative_position_bias.unsqueeze(0)
799
+
800
+ attn = self.softmax(attn)
801
+ attn = self.attn_drop(attn)
802
+ out = (attn @ v).transpose(1, 2).reshape(B * H * W, self.block_size, C)  # [BHW, num_heads, PP, PP] @ [BHW, num_heads, PP, C//num_heads]
804
+ out += x_masked
805
+ out_sumed = torch.sum(out, dim=1).reshape(B, H * W, C)
806
+ out = self.norm2(out_sumed)
807
+ out = self.mlp(out)
808
+ out += out_sumed
809
+
810
+ out = self.proj(out)
811
+ out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2) # B, C_out, H, W
812
+
813
+ return out
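As a quick sanity check of the shape contracts in this file, a small sketch (assuming the package layout of this commit, with torch/timm/compressai installed):

```python
import torch

from vae.RSTB import RSTB, window_partition, window_reverse

# window_partition / window_reverse are exact inverses whenever H and W are
# multiples of the window size.
x = torch.randn(2, 32, 32, 96)            # B, H, W, C
windows = window_partition(x, 8)          # (2 * 16, 8, 8, 96)
assert torch.equal(window_reverse(windows, 8, 32, 32), x)

# RSTB takes a (B, C, H, W) feature map plus an explicit spatial size,
# flattens it to tokens internally, and returns the same (B, C, H, W) shape,
# so it can sit directly between strided conv stages.
block = RSTB(dim=96, input_resolution=(32, 32), depth=2, num_heads=4,
             window_size=8)
feat = torch.randn(2, 96, 32, 32)
out = block(feat, (32, 32))
assert out.shape == feat.shape
```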
vae/__init__.py ADDED
@@ -0,0 +1,57 @@
+ """Local ROI-VAE compression package.
+
+ This package was previously named `compression`, but was renamed to `vae` to
+ avoid colliding with optional third-party imports (e.g. via Gradio/fsspec).
+
+ It intentionally uses lazy attribute loading to keep import-time overhead low.
+ """
+
+ from __future__ import annotations
+
+ from importlib import import_module
+ from typing import TYPE_CHECKING
+
+
+ __all__ = [
+     "TIC",
+     "ModifiedTIC",
+     "load_checkpoint",
+     "compute_padding",
+     "compress_image",
+     "highlight_roi",
+     "create_comparison_grid",
+     "RSTB",
+     "CausalAttentionModule",
+ ]
+
+
+ _EXPORTS: dict[str, tuple[str, str]] = {
+     "TIC": (".tic_model", "TIC"),
+     "ModifiedTIC": (".roi_tic", "ModifiedTIC"),
+     "load_checkpoint": (".roi_tic", "load_checkpoint"),
+     "compute_padding": (".utils", "compute_padding"),
+     "compress_image": (".utils", "compress_image"),
+     "highlight_roi": (".visualization", "highlight_roi"),
+     "create_comparison_grid": (".visualization", "create_comparison_grid"),
+     "RSTB": (".RSTB", "RSTB"),
+     "CausalAttentionModule": (".RSTB", "CausalAttentionModule"),
+ }
+
+
+ def __getattr__(name: str):
+     if name not in _EXPORTS:
+         raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+     module_name, attr_name = _EXPORTS[name]
+     module = import_module(module_name, package=__name__)
+     value = getattr(module, attr_name)
+     globals()[name] = value
+     return value
+
+
+ if TYPE_CHECKING:
+     from .RSTB import CausalAttentionModule, RSTB
+     from .roi_tic import ModifiedTIC, load_checkpoint
+     from .tic_model import TIC
+     from .utils import compress_image, compute_padding
+     from .visualization import create_comparison_grid, highlight_roi
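The module-level `__getattr__` above is the PEP 562 lazy-import pattern: `import vae` stays cheap because torch/compressai are only pulled in when one of the exported names is first accessed. A quick sketch of the observable behavior (assuming the package is importable):

```python
import sys

import vae  # fast: does not import torch or any submodule yet

assert "vae.tic_model" not in sys.modules
_ = vae.TIC                            # first access triggers the real import
assert "vae.tic_model" in sys.modules  # ...and the value is cached in globals()
```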
vae/roi_tic.py ADDED
@@ -0,0 +1,132 @@
+ """
+ ROI-aware TIC model for region-based compression.
+ """
+
+ import torch
+ import torch.nn.functional as F
+ from .tic_model import TIC
+
+
+ class ModifiedTIC(TIC):
+     """Modified TIC that uses a pre-computed binary mask with a sigma parameter for ROI-based compression."""
+
+     def __init__(self, N=192, M=192):
+         super().__init__(N=N, M=M)
+         # Cache for pre-allocated constant tensors (device -> tensor)
+         self._ones_cache = {}
+         self._sigma_cache = {}
+
+     def forward(self, x, mask, sigma=0.3):
+         """
+         Forward pass with ROI mask and quality factor.
+
+         Args:
+             x: input image tensor [B, C, H, W]
+             mask: ROI mask (1 for ROI, 0 for background) [B, 1, H, W]
+             sigma: quality factor for the background (0.01-1.0, lower = more compression).
+                 Can be a scalar (same for all frames) or a tensor [B] for per-frame values.
+
+         Returns:
+             dict with compression outputs (y_hat, y, similarity, x_hat, likelihoods)
+         """
+         x_size = (x.shape[2], x.shape[3])
+         batch_size = x.shape[0]
+         device = mask.device
+
+         # Get or create the cached ones tensor for this device
+         if device not in self._ones_cache:
+             self._ones_cache[device] = torch.tensor(1.0, device=device)
+         ones_tensor = self._ones_cache[device]
+
+         # Convert sigma to a tensor suitable for broadcasting
+         if isinstance(sigma, (int, float)):
+             # Scalar sigma - use the cache
+             cache_key = (device, sigma)
+             if cache_key not in self._sigma_cache:
+                 self._sigma_cache[cache_key] = torch.tensor(sigma, device=device)
+             sigma_tensor = self._sigma_cache[cache_key]
+         else:
+             # Batched sigma - reshape to [B, 1, 1, 1]
+             if sigma.dim() == 1:
+                 sigma_tensor = sigma.view(batch_size, 1, 1, 1)
+             else:
+                 sigma_tensor = sigma
+
+         # Convert the binary mask to quality factors (broadcasting handles per-frame sigma)
+         similarity_loss = torch.where(mask > 0.5, ones_tensor, sigma_tensor)
+         similarity_imp = torch.where(mask > 0.5, ones_tensor, sigma_tensor)
+
+         # Downsample the mask to 1/2 resolution for simi_net.
+         # simi_net has 3 stride-2 convolutions (8x downsampling), so an input at 1/2
+         # yields an output at 1/16, which matches the y_codec_a6 dimensions
+         # (g_a's three stride-2 stages plus g_a6's stride-2 stage = 16x).
+         # Nearest-neighbor is used for binary masks (faster, no quality loss for binary data).
+         similarity_down = F.interpolate(similarity_imp, scale_factor=0.5, mode='nearest')
+
+         similarity_up = F.interpolate(similarity_loss, scale_factor=2, mode='nearest')
+         similarity_up_repeated = similarity_up.repeat(1, 3, 1, 1)
+
+         # simi_net downsamples by 8x: 128x128 -> 16x16 to match y_codec_a6
+         similarities_channel = self.simi_net(similarity_down)
+         similarities_sigmoid = torch.sigmoid(similarities_channel)
+
+         y_codec = self.g_a(x, x_size)
+         y_codec_a6 = self.g_a6(y_codec)
+
+         y_import = self.sub_impor_net(y_codec)
+         y_tanh = self.tanh(y_import)
+         y_soft = self.softsign(y_tanh)
+
+         y_imp = y_soft + similarities_sigmoid
+         y = y_codec_a6 * y_imp
+
+         z = self.h_a(y, x_size)
+         z_hat, z_likelihoods = self.entropy_bottleneck(z)
+         params = self.h_s(z_hat, x_size)
+
+         y_hat = self.gaussian_conditional.quantize(
+             y, "noise" if self.training else "dequantize"
+         )
+         ctx_params = self.context_prediction(y_hat)
+         gaussian_params = self.entropy_parameters(
+             torch.cat((params, ctx_params), dim=1)
+         )
+         scales_hat, means_hat = gaussian_params.chunk(2, 1)
+         _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
+         x_hat = self.g_s(y_hat, x_size)
+
+         return {
+             "y_hat": y_hat,
+             "y": y,
+             "similarity": similarity_up_repeated,
+             "x_hat": x_hat,
+             "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
+         }
+
+
+ def load_checkpoint(checkpoint_path: str, N: int = 192, M: int = 192, device: str = 'cuda') -> ModifiedTIC:
+     """
+     Load a TIC model from a checkpoint.
+
+     Args:
+         checkpoint_path: Path to the .pth.tar checkpoint
+         N: Number of channels (default 192)
+         M: Number of channels in the expansion layers (default 192)
+         device: 'cuda' or 'cpu'
+
+     Returns:
+         Loaded ModifiedTIC model in eval mode
+     """
+     model = ModifiedTIC(N=N, M=M).to(device)
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     # The checkpoint is already stored in the (older compressai) format that the
+     # model expects, so the state dict is used as-is without key conversion.
+     state_dict = checkpoint["state_dict"]
+
+     model.load_state_dict(state_dict)
+     model.eval()
+     model.update(force=True)
+     return model
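Putting the pieces together, a sketch of ROI-aware inference with this class (untested here: the checkpoint name matches the LFS entries tracked in `.gitattributes`, while the 256x256 input and the hand-drawn center mask are illustrative assumptions):

```python
import math

import torch

from vae import load_checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_checkpoint("checkpoints/tic_lambda_0.013.pth.tar", device=device)

x = torch.rand(1, 3, 256, 256, device=device)      # image in [0, 1]
mask = torch.zeros(1, 1, 256, 256, device=device)  # 1 = ROI, 0 = background
mask[..., 64:192, 64:192] = 1.0                    # keep the center region sharp

with torch.no_grad():
    out = model(x, mask, sigma=0.05)               # aggressive background compression

x_hat = out["x_hat"].clamp(0, 1)                   # reconstruction
bpp = sum(
    t.log().sum() / (-math.log(2) * 256 * 256)     # -log2(likelihood) per pixel
    for t in out["likelihoods"].values()
)
print(f"estimated bpp: {float(bpp):.4f}")
```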
vae/tic_model.py ADDED
@@ -0,0 +1,989 @@
1
+ import math
2
+ import time
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from compressai.entropy_models import EntropyBottleneck, GaussianConditional
9
+ from .RSTB import RSTB, CausalAttentionModule
10
+ from compressai.ans import BufferedRansEncoder, RansDecoder
11
+ from timm.models.layers import trunc_normal_
12
+ from compressai.models.utils import conv, deconv, update_registered_buffers
13
+ from compressai.layers import AttentionBlock
14
+ from PIL import Image
15
+ import numpy as np
16
+ import matplotlib
17
+ import matplotlib.pyplot as plt
18
+ from matplotlib.colors import ListedColormap
19
+
20
+ # from lseg.lseg_net import LSegNet
21
+ # import cv2
22
+ # import random
23
+ import itertools
24
+
25
+
26
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ # From Balle's tensorflow compression examples
29
+ SCALES_MIN = 0.11
30
+ SCALES_MAX = 256
31
+ SCALES_LEVELS = 64
32
+
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+
35
+ def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
36
+ return torch.exp(torch.linspace(math.log(min), math.log(max), levels))
37
+
38
+
39
+ class Binarizer(torch.autograd.Function):
40
+ """
41
+ An elementwise straight-through function. Despite the name, the output is
+ not strictly binary: values above a threshold of 0.9 map to 1.0 and all
+ other values map to 0.2.
+
+ Input: a tensor with values in the range (0, 1)
+
+ Returns: a tensor containing only the values 1.0 and 0.2
+
+ Equation (1) in the paper
51
+ """
52
+ @staticmethod
53
+ def forward(ctx, i):
54
+ result = torch.where(i > 0.9, torch.ones_like(i), torch.full_like(i, 0.2))  # *_like keeps the constants on the input's device
55
+ return result
56
+
57
+ @staticmethod
58
+ def backward(ctx, grad_output):
59
+ return grad_output
60
+
61
+ def bin_values(x):
62
+ return Binarizer.apply(x)
63
+
64
+
65
+
66
+
67
+ class TIC(nn.Module):
68
+ """Neural image compression framework from
69
+ Lu Ming and Guo, Peiyao and Shi, Huiqing and Cao, Chuntong and Ma, Zhan:
70
+ `"Transformer-based Image Compression" <https://arxiv.org/abs/2111.06707>`, (DCC 2022).
71
+
72
+ Args:
73
+ N (int): Number of channels
74
+ M (int): Number of channels in the expansion layers (last layer of the
75
+ encoder and last layer of the hyperprior decoder)
76
+ input_resolution (int): Just used for window partition decision
77
+ """
78
+
79
+ def __init__(self, N=192, M=192):
80
+ super().__init__()
81
+
82
+ depths = [1, 2, 3, 1, 1]
83
+ num_heads = [4, 8, 16, 16, 16]
84
+ window_size = 8
85
+ mlp_ratio = 4.
86
+ qkv_bias = True
87
+ qk_scale = None
88
+ drop_rate = 0.
89
+ attn_drop_rate = 0.
90
+ drop_path_rate = 0.2
91
+ norm_layer = nn.LayerNorm
92
+ use_checkpoint = False
93
+
94
+ # stochastic depth
95
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
96
+ self.align_corners = True
97
+
98
+
99
+ self.g_a0 = conv(3, N, kernel_size=5, stride=2)
100
+ self.g_a1 = RSTB(dim=N,
101
+ input_resolution=(128, 128),
102
+ depth=depths[0],
103
+ num_heads=num_heads[0],
104
+ window_size=window_size,
105
+ mlp_ratio=mlp_ratio,
106
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
107
+ drop=drop_rate, attn_drop=attn_drop_rate,
108
+ drop_path=dpr[sum(depths[:0]):sum(depths[:1])],
109
+ norm_layer=norm_layer,
110
+ use_checkpoint=use_checkpoint,
111
+ )
112
+ self.g_a2 = conv(N, N, kernel_size=3, stride=2)
113
+ self.g_a3 = RSTB(dim=N,
114
+ input_resolution=(64, 64),
115
+ depth=depths[1],
116
+ num_heads=num_heads[1],
117
+ window_size=window_size,
118
+ mlp_ratio=mlp_ratio,
119
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
120
+ drop=drop_rate, attn_drop=attn_drop_rate,
121
+ drop_path=dpr[sum(depths[:1]):sum(depths[:2])],
122
+ norm_layer=norm_layer,
123
+ use_checkpoint=use_checkpoint,
124
+ )
125
+ self.g_a4 = conv(N, N, kernel_size=3, stride=2)
126
+ self.g_a5 = RSTB(dim=N,
127
+ input_resolution=(32, 32),
128
+ depth=depths[2],
129
+ num_heads=num_heads[2],
130
+ window_size=window_size,
131
+ mlp_ratio=mlp_ratio,
132
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
133
+ drop=drop_rate, attn_drop=attn_drop_rate,
134
+ drop_path=dpr[sum(depths[:2]):sum(depths[:3])],
135
+ norm_layer=norm_layer,
136
+ use_checkpoint=use_checkpoint,
137
+ )
138
+
139
+ self.g_a6 = conv(N, M, kernel_size=3, stride=2)
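+ # Cumulative stride of the analysis transform: g_a0/g_a2/g_a4 each halve the
+ # resolution (8x together) and g_a6 halves it once more, so y is 16x smaller
+ # than the input image.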
140
+
141
+ self.h_a0 = conv(M, N, kernel_size=3, stride=1)
142
+ self.h_a1 = RSTB(dim=N,
143
+ input_resolution=(16, 16),
144
+ depth=depths[3],
145
+ num_heads=num_heads[3],
146
+ window_size=window_size // 2,
147
+ mlp_ratio=mlp_ratio,
148
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
149
+ drop=drop_rate, attn_drop=attn_drop_rate,
150
+ drop_path=dpr[sum(depths[:3]):sum(depths[:4])],
151
+ norm_layer=norm_layer,
152
+ use_checkpoint=use_checkpoint,
153
+ )
154
+ self.h_a2 = conv(N, N, kernel_size=3, stride=2)
155
+ self.h_a3 = RSTB(dim=N,
156
+ input_resolution=(8, 8),
157
+ depth=depths[4],
158
+ num_heads=num_heads[4],
159
+ window_size=window_size // 2,
160
+ mlp_ratio=mlp_ratio,
161
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
162
+ drop=drop_rate, attn_drop=attn_drop_rate,
163
+ drop_path=dpr[sum(depths[:4]):sum(depths[:5])],
164
+ norm_layer=norm_layer,
165
+ use_checkpoint=use_checkpoint,
166
+ )
167
+ self.h_a4 = conv(N, N, kernel_size=3, stride=2)
168
+
169
+ depths = depths[::-1]
170
+ num_heads = num_heads[::-1]
171
+ self.h_s0 = deconv(N, N, kernel_size=3, stride=2)
172
+ self.h_s1 = RSTB(dim=N,
173
+ input_resolution=(8, 8),
174
+ depth=depths[0],
175
+ num_heads=num_heads[0],
176
+ window_size=window_size // 2,
177
+ mlp_ratio=mlp_ratio,
178
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
179
+ drop=drop_rate, attn_drop=attn_drop_rate,
180
+ drop_path=dpr[sum(depths[:0]):sum(depths[:1])],
181
+ norm_layer=norm_layer,
182
+ use_checkpoint=use_checkpoint,
183
+ )
184
+ self.h_s2 = deconv(N, N, kernel_size=3, stride=2)
185
+ self.h_s3 = RSTB(dim=N,
186
+ input_resolution=(16, 16),
187
+ depth=depths[1],
188
+ num_heads=num_heads[1],
189
+ window_size=window_size // 2,
190
+ mlp_ratio=mlp_ratio,
191
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
192
+ drop=drop_rate, attn_drop=attn_drop_rate,
193
+ drop_path=dpr[sum(depths[:1]):sum(depths[:2])],
194
+ norm_layer=norm_layer,
195
+ use_checkpoint=use_checkpoint,
196
+ )
197
+ self.h_s4 = conv(N, M * 2, kernel_size=3, stride=1)
198
+
199
+ self.g_s0 = deconv(M, N, kernel_size=3, stride=2)
200
+ self.g_s1 = RSTB(dim=N,
201
+ input_resolution=(32, 32),
202
+ depth=depths[2],
203
+ num_heads=num_heads[2],
204
+ window_size=window_size,
205
+ mlp_ratio=mlp_ratio,
206
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
207
+ drop=drop_rate, attn_drop=attn_drop_rate,
208
+ drop_path=dpr[sum(depths[:2]):sum(depths[:3])],
209
+ norm_layer=norm_layer,
210
+ use_checkpoint=use_checkpoint,
211
+ )
212
+ self.g_s2 = deconv(N, N, kernel_size=3, stride=2)
213
+ self.g_s3 = RSTB(dim=N,
214
+ input_resolution=(64, 64),
215
+ depth=depths[3],
216
+ num_heads=num_heads[3],
217
+ window_size=window_size,
218
+ mlp_ratio=mlp_ratio,
219
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
220
+ drop=drop_rate, attn_drop=attn_drop_rate,
221
+ drop_path=dpr[sum(depths[:3]):sum(depths[:4])],
222
+ norm_layer=norm_layer,
223
+ use_checkpoint=use_checkpoint,
224
+ )
225
+ self.g_s4 = deconv(N, N, kernel_size=3, stride=2)
226
+ self.g_s5 = RSTB(dim=N,
227
+ input_resolution=(128, 128),
228
+ depth=depths[4],
229
+ num_heads=num_heads[4],
230
+ window_size=window_size,
231
+ mlp_ratio=mlp_ratio,
232
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
233
+ drop=drop_rate, attn_drop=attn_drop_rate,
234
+ drop_path=dpr[sum(depths[:4]):sum(depths[:5])],
235
+ norm_layer=norm_layer,
236
+ use_checkpoint=use_checkpoint,
237
+ )
238
+ self.g_s6 = deconv(N, 3, kernel_size=5, stride=2)
239
+
240
+ self.entropy_bottleneck = EntropyBottleneck(N)
241
+ self.gaussian_conditional = GaussianConditional(None)
242
+ self.context_prediction = CausalAttentionModule(M, M * 2)
243
+ # self.attetionmap = AttentionBlock(M)
244
+
245
+ self.entropy_parameters = nn.Sequential(
246
+ nn.Conv2d(M * 12 // 3, M * 10 // 3, 1),
247
+ nn.GELU(),
248
+ nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
249
+ nn.GELU(),
250
+ nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),
251
+ )
252
+
253
+ self.sub_net_leaky = nn.Sequential(
254
+ conv(N,N,kernel_size=3,stride=2),
255
+ nn.LeakyReLU()
256
+ )
257
+
258
+ self.sub_net0 = nn.Sequential(
259
+ conv(N,64,kernel_size=1,stride=1),
260
+ nn.ReLU()
261
+ )
262
+ self.sub_net1 = nn.Sequential(
263
+ conv(64,64,kernel_size=3,stride=1),
264
+ nn.ReLU()
265
+ )
266
+ self.sub_net2 = conv(64,N,kernel_size=1,stride=1)
267
+
268
+ self.sub_net_channel = conv(N,M,kernel_size=1,stride=1)
269
+
270
+ self.simi_net = nn.Sequential(
271
+ conv(1,64,kernel_size=3,stride=2),
272
+ nn.ReLU(),
273
+ conv(64,128,kernel_size=3,stride=2),
274
+ nn.ReLU(),
275
+ conv(128, M, kernel_size=3, stride=2),
276
+ )
277
+
278
+ # self.net_lseg = LSegNet(
279
+ # backbone="clip_vitl16_384",
280
+ # features=256,
281
+ # crop_size=256,
282
+ # arch_option=0,
283
+ # block_depth=0,
284
+ # activation="lrelu",
285
+ # )
286
+
287
+ self.cosine_similarity = torch.nn.CosineSimilarity(dim=1)
288
+
289
+
290
+
291
+
292
+ # self.con3_3 = conv(192,192,kernel_size=3,stride=1)
293
+
294
+ self.tanh = nn.Tanh()
295
+ self.softsign = nn.Softsign()
296
+ self.relu = nn.ReLU()
297
+ self.sigmoid = nn.Sigmoid()
298
+
299
+
300
+ self.apply(self._init_weights)
301
+
302
+
303
+
304
+ def g_a(self, x, x_size=None):
305
+ if x_size is None:
306
+ x_size = x.shape[2:4]
307
+ x = self.g_a0(x)
308
+ x = self.g_a1(x, (x_size[0] // 2, x_size[1] // 2))
309
+ x = self.g_a2(x)
310
+ x = self.g_a3(x, (x_size[0] // 4, x_size[1] // 4))
311
+ x = self.g_a4(x)
312
+ x = self.g_a5(x, (x_size[0] // 8, x_size[1] // 8))
313
+ # x = self.g_a6(x)
314
+ return x
315
+
316
+ def g_s(self, x, x_size=None):
317
+ if x_size is None:
318
+ x_size = (x.shape[2] * 16, x.shape[3] * 16)
319
+ x = self.g_s0(x)
320
+ x = self.g_s1(x, (x_size[0] // 8, x_size[1] // 8))
321
+ x = self.g_s2(x)
322
+ x = self.g_s3(x, (x_size[0] // 4, x_size[1] // 4))
323
+ x = self.g_s4(x)
324
+ x = self.g_s5(x, (x_size[0] // 2, x_size[1] // 2))
325
+ x = self.g_s6(x)
326
+ return x
327
+
328
+ def h_a(self, x, x_size=None):
329
+ if x_size is None:
330
+ x_size = (x.shape[2] * 16, x.shape[3] * 16)
331
+ x = self.h_a0(x)
332
+ x = self.h_a1(x, (x_size[0] // 16, x_size[1] // 16))
333
+ x = self.h_a2(x)
334
+ x = self.h_a3(x, (x_size[0] // 32, x_size[1] // 32))
335
+ x = self.h_a4(x)
336
+ return x
337
+
338
+ def h_s(self, x, x_size=None):
339
+ if x_size is None:
340
+ x_size = (x.shape[2] * 64, x.shape[3] * 64)
341
+ x = self.h_s0(x)
342
+ x = self.h_s1(x, (x_size[0] // 32, x_size[1] // 32))
343
+ x = self.h_s2(x)
344
+ x = self.h_s3(x, (x_size[0] // 16, x_size[1] // 16))
345
+ x = self.h_s4(x)
346
+ return x
347
+
348
+
349
+ def sub_impor_net(self,x): # important map
350
+ x1 = self.sub_net_leaky(x)
351
+
352
+ x2 = self.sub_net0(x1)
353
+ x2 = self.sub_net1(x2)
354
+ x2 = self.sub_net2(x2)
355
+
356
+ x2 = x1 + x2
357
+ x3 = self.sub_net0(x2)
358
+ x3 = self.sub_net1(x3)
359
+ x3 = self.sub_net2(x3)
360
+
361
+ x3 = x2 + x3
362
+ x4 = self.sub_net0(x3)
363
+ x4 = self.sub_net1(x4)
364
+ x4 = self.sub_net2(x4)
365
+
366
+ x_out = x4 + x3
367
+ x_out = self.sub_net_channel(x_out)
368
+
369
+ return x_out
370
+
371
+
372
+
373
+ def aux_loss(self):
374
+ """Return the aggregated loss over the auxiliary entropy bottleneck
375
+ module(s).
376
+ """
377
+ aux_loss = sum(
378
+ m.loss() for m in self.modules() if isinstance(m, EntropyBottleneck)
379
+ )
380
+ return aux_loss
381
+
382
+ def _init_weights(self, m):
383
+ if isinstance(m, nn.Linear):
384
+ trunc_normal_(m.weight, std=.02)
385
+ if isinstance(m, nn.Linear) and m.bias is not None:
386
+ nn.init.constant_(m.bias, 0)
387
+ elif isinstance(m, nn.LayerNorm):
388
+ nn.init.constant_(m.bias, 0)
389
+ nn.init.constant_(m.weight, 1.0)
390
+
391
+
392
+ @torch.jit.ignore
393
+ def no_weight_decay_keywords(self):
394
+ return {'relative_position_bias_table'}
395
+
396
+ def forward(self, x, similarity):
397
+
398
+ x_size = (x.shape[2], x.shape[3])
399
+
400
+ h, w = x.size(2), x.size(3)
401
+
402
+ similarity_loss = torch.where(similarity > 0.85, torch.ones_like(similarity), torch.full_like(similarity, 0.01))
+ similarity_imp = torch.where(similarity > 0.85, torch.ones_like(similarity), torch.full_like(similarity, 0.01))
404
+
405
+ similarity_up = F.interpolate(similarity_loss, scale_factor=2, mode='bilinear')
406
+ similarity_up_repeated = similarity_up.repeat(1, 3, 1, 1)
407
+
408
+ similarities_channel = self.simi_net(similarity_imp)
409
+ similarities_sigmoid = torch.sigmoid(similarities_channel)
410
+
411
+
412
+ y_codec = self.g_a(x, x_size) # y
413
+ y_codec_a6 = self.g_a6(y_codec)
414
+
415
+
416
+ y_import = self.sub_impor_net(y_codec)
417
+ y_tanh = self.tanh(y_import)
418
+ y_soft = self.softsign(y_tanh)
419
+
420
+
421
+
422
+ y_imp = y_soft + similarities_sigmoid
423
+ y = y_codec_a6 * y_imp
424
+
425
+
426
+ z = self.h_a(y, x_size)
427
+ z_hat, z_likelihoods = self.entropy_bottleneck(z)
428
+ params = self.h_s(z_hat, x_size)
429
+
430
+ y_hat = self.gaussian_conditional.quantize(
431
+ y, "noise" if self.training else "dequantize"
432
+ )
433
+ ctx_params = self.context_prediction(y_hat)
434
+ gaussian_params = self.entropy_parameters(
435
+ torch.cat((params, ctx_params), dim=1)
436
+ )
437
+ scales_hat, means_hat = gaussian_params.chunk(2, 1)
438
+ _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
439
+ x_hat = self.g_s(y_hat, x_size)
440
+
441
+ return {
442
+ "y_hat": y_hat,
443
+ "y": y,
444
+ "similarity":similarity_up_repeated,
445
+ "x_hat": x_hat,
446
+ "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
447
+ }
448
+
449
+
450
+ def update(self, scale_table=None, force=False):
451
+ """Updates the entropy bottleneck(s) CDF values.
452
+
453
+ Needs to be called once after training to be able to later perform the
454
+ evaluation with an actual entropy coder.
455
+
456
+ Args:
457
+ scale_table (bool): (default: None)
458
+ force (bool): overwrite previous values (default: False)
459
+
460
+ Returns:
461
+ updated (bool): True if one of the EntropyBottlenecks was updated.
462
+
463
+ """
464
+ if scale_table is None:
465
+ scale_table = get_scale_table()
466
+ self.gaussian_conditional.update_scale_table(scale_table, force=force)
467
+
468
+ updated = False
469
+ for m in self.children():
470
+ if not isinstance(m, EntropyBottleneck):
471
+ continue
472
+ rv = m.update(force=force)
473
+ updated |= rv
474
+ return updated
475
+
476
+ def load_state_dict(self, state_dict, strict=True):
477
+ # Dynamically update the entropy bottleneck buffers related to the CDFs
478
+ update_registered_buffers(
479
+ self.entropy_bottleneck,
480
+ "entropy_bottleneck",
481
+ ["_quantized_cdf", "_offset", "_cdf_length"],
482
+ state_dict,
483
+ )
484
+ update_registered_buffers(
485
+ self.gaussian_conditional,
486
+ "gaussian_conditional",
487
+ ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
488
+ state_dict,
489
+ )
490
+ super().load_state_dict(state_dict, strict=strict)
491
+
492
+ @classmethod
493
+ def from_state_dict(cls, state_dict):
494
+ """Return a new model instance from `state_dict`."""
495
+ N = state_dict["g_a0.weight"].size(0)
496
+ M = state_dict["g_a6.weight"].size(0)
497
+ net = cls(N, M)
498
+ net.load_state_dict(state_dict)
499
+ return net
500
+
501
+ # def compress(self, x,similarity):
502
+ def compress(self, x):
503
+ x = x.cuda()
504
+ # similarity = similarity.to(device)
505
+ x_size = (x.shape[2], x.shape[3])
506
+
507
+ # start_1 = time.time()
508
+ #
509
+ # img_feat = self.net_lseg.forward(x)
510
+ # img_feat_norm = torch.nn.functional.normalize(img_feat, dim=1)
511
+ # #
512
+ # prompt = clip.tokenize(similarity).cuda()
513
+ # text_feat = self.net_lseg.clip_pretrained.encode_text(prompt) # 1, 512
514
+ # text_feat_norm = torch.nn.functional.normalize(text_feat, dim=1)
515
+ # #
516
+ # similarity = self.cosine_similarity(
517
+ # img_feat_norm, text_feat_norm.unsqueeze(-1).unsqueeze(-1)
518
+ # )
519
+ # similarity = similarity.unsqueeze(0)
520
+ #
521
+ # torch.cuda.synchronize()
522
+ #
523
+ # inf_time = time.time() - start_1
524
+ #
525
+ # print(inf_time)
526
+
527
+ # ##### here
528
+
529
+ start = time.time()
530
+ # similarity_down_1 = torch.where(similarity > 0.9, torch.tensor(1.0), torch.tensor(1.0))
531
+ # similarities_repeated = self.simi_net(similarity_down_1)
532
+ # similarities_repeated = torch.sigmoid(similarities_repeated)
533
+
534
+
535
+
536
+
537
+
538
+ y_codec = self.g_a(x, x_size) # y
539
+
540
+ # y_import = self.sub_impor_net(y_codec)
541
+ # y_tanh = self.tanh(y_import)
542
+ #
543
+ # y_soft = self.softsign(y_tanh)
544
+
545
+
546
+ y_codec_a6 = self.g_a6(y_codec)
547
+
548
+
549
+ # y_imp = y_soft + similarities_repeated # 相似度* important map
550
+ # y = y_codec_a6 * y_imp
551
+ y = y_codec_a6
552
+ # y = y_imp * y_codec_a6
553
+ # y = self.sub_net_channel(y)
554
+
555
+ # y = y_codec_a6 * similarities_repeated
556
+
557
+ z = self.h_a(y)
558
+
559
+ z_strings = self.entropy_bottleneck.compress(z)
560
+ z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
561
+
562
+ params = self.h_s(z_hat)
563
+
564
+ s = 4 # scaling factor between z and y
565
+ kernel_size = 5 # context prediction kernel size
566
+ padding = (kernel_size - 1) // 2
567
+
568
+ y_height = z_hat.size(2) * s
569
+ y_width = z_hat.size(3) * s
570
+
571
+ y_hat = F.pad(y, (padding, padding, padding, padding))
572
+
573
+ # pylint: disable=protected-access
574
+ cdf = self.gaussian_conditional._quantized_cdf.tolist()
575
+ cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
576
+ offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
577
+ # pylint: enable=protected-access
578
+ # print(cdf, cdf_lengths, offsets)
579
+ y_strings = []
580
+ for i in range(y.size(0)):
581
+ encoder = BufferedRansEncoder()
582
+ # Warning, this is slow...
583
+ # TODO: profile the calls to the bindings...
584
+ symbols_list = []
585
+ indexes_list = []
586
+ y_q_ = torch.zeros_like(y)
587
+ indexes_ = torch.zeros_like(y)
588
+ for h in range(y_height):
589
+ for w in range(y_width):
590
+ y_crop = y_hat[
591
+ i: i + 1, :, h: h + kernel_size, w: w + kernel_size
592
+ ]
593
+ ctx_p = self.context_prediction(y_crop)
594
+ # 1x1 conv for the entropy parameters prediction network, so
595
+ # we only keep the elements in the "center"
596
+ p = params[i: i + 1, :, h: h + 1, w: w + 1]
597
+ gaussian_params = self.entropy_parameters(
598
+ torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
599
+ )
600
+ scales_hat, means_hat = gaussian_params.chunk(2, 1)
601
+
602
+ indexes = self.gaussian_conditional.build_indexes(scales_hat)
603
+ y_q = torch.round(y_crop - means_hat)
604
+ y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
605
+ i, :, padding, padding
606
+ ]
607
+ y_q_[i,:, h, w] = y_q[i, :, padding, padding]
608
+ indexes_[i,:, h, w] = indexes[i, :,0,0]
609
+
610
+ flag = np.array(np.zeros(y_q_.shape[1]))
611
+ for idx in range(y_q_.shape[1]):
612
+ if torch.sum(torch.abs(y_q_[:, idx, :, :])) > 0: # 全部大于0就设置标志位是1
613
+ flag[idx] = 1
614
+ y_q_ = y_q_[:,np.nonzero(flag),...].squeeze()
615
+ indexes_ = indexes_[:,np.nonzero(flag),...].squeeze()
616
+ for h in range(y_height):
617
+ for w in range(y_width):
618
+ # encoder.encode_with_indexes(
619
+ # y_q_[:,np.nonzero(flag),h,w].squeeze().int().tolist(),
620
+ # indexes_[:,np.nonzero(flag),h,w].squeeze().int().tolist(), cdf, cdf_lengths, offsets
621
+ # )
622
+ symbols_list.extend(y_q_[:,h,w].int().tolist())
623
+ indexes_list.extend(indexes_[:,h,w].squeeze().int().tolist())
624
+ encoder.encode_with_indexes(
625
+ symbols_list, indexes_list, cdf, cdf_lengths, offsets
626
+ )
627
+ string = encoder.flush()
628
+ y_strings.append(string)
629
+ print(flag.sum())
630
+
631
+ torch.cuda.synchronize() # 确保 model2 真正跑完
632
+ t2 = time.time() - start
633
+ # print(t2)
634
+
635
+ return {"strings": [y_strings, z_strings], "shape": z.size()[-2:],"flag":flag}
636
+
637
+ # return {"test":similarity}
638
+
639
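A pure-NumPy toy sketch of the channel-pruning trick above (hypothetical shapes): only channels with at least one nonzero quantized symbol are entropy-coded, and `flag` tells the decoder where to scatter them back:

```python
import numpy as np

# Hypothetical latent: 4 channels, 2x2 spatial; channels 1 and 3 are all-zero.
y_q = np.zeros((1, 4, 2, 2))
y_q[0, 0] = 1
y_q[0, 2] = -2

# Flag channels that carry at least one nonzero symbol.
flag = (np.abs(y_q).sum(axis=(0, 2, 3)) > 0).astype(int)  # [1, 0, 1, 0]

# The encoder transmits only the flagged channels...
kept = y_q[:, np.nonzero(flag)[0], ...]                   # shape (1, 2, 2, 2)

# ...and the decoder scatters them back into a zero tensor.
restored = np.zeros_like(y_q)
restored[:, np.nonzero(flag)[0], ...] = kept
assert (restored == y_q).all()
```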
+ def compress_1(self, x,similarity):
640
+ # def compress_1(self, x):
641
+ x = x.cuda()
642
+ x_size = (x.shape[2], x.shape[3])
643
+
644
+ similarity = similarity.cuda()
645
+ # # #
646
+ similarity_down_1 = torch.where(similarity == 0, torch.tensor(1e-4), torch.tensor(1.0))
647
+ #
648
+ #
649
+
650
+
651
+ # similarity_down_1 = torch.where(similarity > 0.9, torch.tensor(1.0), torch.tensor(1e-4))
652
+ #
653
+ similarity_down_1 = F.interpolate(similarity_down_1, scale_factor=0.5, mode='bilinear')
654
+ similarities_repeated = self.simi_net(similarity_down_1)
655
+ similarities_repeated = torch.sigmoid(similarities_repeated)
656
+
657
+
658
+ y_codec = self.g_a(x, x_size) # y
659
+
660
+ y_import = self.sub_impor_net(y_codec)
661
+ y_tanh = self.tanh(y_import)
662
+ y_soft = self.softsign(y_tanh) # important2
663
+ # y_soft = self.sigmoid(y_soft)
664
+
665
+ y_codec_a6 = self.g_a6(y_codec)
666
+ # y_codec_a6 = self.attetionmap(y_codec_a6)
667
+
668
+
669
+ y_imp = similarities_repeated + y_soft # similarity map + importance map
670
+
671
+
672
+ y = y_codec_a6 * y_imp
673
+ #
674
+ # y= y_codec_a6 * y_tanh
675
+
676
+ # cmap = ListedColormap(['yellow'])
677
+ #
678
+ # similarity_image = torch.where(similarity > 0.9, torch.tensor(1.0), torch.tensor(0.1))
679
+ # similarity_image = F.interpolate(similarity_image, scale_factor=2, mode='bilinear')
680
+ # abs = torch.abs(similarity_image)
681
+ # mean = torch.mean(abs, axis=1, keepdims=True)
682
+ # viz = mean.detach().cpu().numpy()
683
+ # viz = viz[0]
684
+ # viz = viz.squeeze()
685
+ # plt.imshow(viz)
686
+ # # # save the visualization image
687
+ # plt.imsave('/mnt/disk10T/xfx/CLIP/bird.png', viz)
688
+
689
+ z = self.h_a(y)
690
+
691
+ z_strings = self.entropy_bottleneck.compress(z)
692
+ z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
693
+
694
+ params = self.h_s(z_hat)
695
+
696
+ s = 4 # scaling factor between z and y
697
+ kernel_size = 5 # context prediction kernel size
698
+ padding = (kernel_size - 1) // 2
699
+
700
+ y_height = z_hat.size(2) * s
701
+ y_width = z_hat.size(3) * s
702
+
703
+ y_hat = F.pad(y, (padding, padding, padding, padding))
704
+
705
+ # pylint: disable=protected-access
706
+ cdf = self.gaussian_conditional._quantized_cdf.tolist()
707
+ cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
708
+ offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
709
+ # pylint: enable=protected-access
710
+ # print(cdf, cdf_lengths, offsets)
711
+
712
+ y_strings = []
713
+ for i in range(y.size(0)):
714
+ encoder = BufferedRansEncoder()
715
+ # Warning, this is slow...
716
+ # TODO: profile the calls to the bindings...
717
+ symbols_list = []
718
+ indexes_list = []
719
+ for h in range(y_height):
720
+ for w in range(y_width):
721
+ y_crop = y_hat[
722
+ i: i + 1, :, h: h + kernel_size, w: w + kernel_size
723
+ ]
724
+ ctx_p = self.context_prediction(y_crop)
725
+ # 1x1 conv for the entropy parameters prediction network, so
726
+ # we only keep the elements in the "center"
727
+ p = params[i: i + 1, :, h: h + 1, w: w + 1]
728
+ gaussian_params = self.entropy_parameters(
729
+ torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
730
+ )
731
+ scales_hat, means_hat = gaussian_params.chunk(2, 1)
732
+
733
+ indexes = self.gaussian_conditional.build_indexes(scales_hat)
734
+ y_q = torch.round(y_crop - means_hat)
735
+ y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
736
+ i, :, padding, padding
737
+ ]
738
+
739
+ symbols_list.extend(y_q[i, :, padding, padding].int().tolist())
740
+ indexes_list.extend(indexes[i, :].squeeze().int().tolist())
741
+
742
+ encoder.encode_with_indexes(
743
+ symbols_list, indexes_list, cdf, cdf_lengths, offsets
744
+ )
745
+
746
+ string = encoder.flush()
747
+ y_strings.append(string)
748
+
749
+ return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
750
+
751
+ def compress_2(self, x,similarity):
752
+ # def compress_1(self, x):
753
+ x = x.cuda()
754
+ x_size = (x.shape[2], x.shape[3])
755
+
756
+ ##### ROI similarity weighting starts here
757
+ similarity_down_1 = torch.where(similarity > 0.9, torch.tensor(1.0), torch.tensor(0.1))
758
+ similarities_repeated = self.simi_net(similarity_down_1)
759
+ similarities_repeated = torch.sigmoid(similarities_repeated)
760
+
761
+ y_codec = self.g_a(x, x_size) # y
762
+
763
+ y_import = self.sub_impor_net(y_codec)
764
+ y_tanh = self.tanh(y_import)
765
+
766
+ y_codec_a6 = self.g_a6(y_codec)
767
+
768
+
769
+
770
+ y_imp = similarities_repeated + y_tanh # similarity map + importance map
771
+
772
+ y = y_codec_a6 * y_imp
773
+
774
+ z = self.h_a(y)
775
+
776
+ z_strings = self.entropy_bottleneck.compress(z)
777
+ z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
778
+
779
+ params = self.h_s(z_hat)
780
+
781
+ s = 4 # scaling factor between z and y
782
+ kernel_size = 5 # context prediction kernel size
783
+ padding = (kernel_size - 1) // 2
784
+
785
+ y_height = z_hat.size(2) * s
786
+ y_width = z_hat.size(3) * s
787
+
788
+ y_hat = F.pad(y, (padding, padding, padding, padding))
789
+
790
+ # pylint: disable=protected-access
791
+ cdf = self.gaussian_conditional._quantized_cdf.tolist()
792
+ cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
793
+ offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
794
+ # pylint: enable=protected-access
795
+ # print(cdf, cdf_lengths, offsets)
796
+
797
+ y_strings = []
798
+ for i in range(y.size(0)):
799
+ encoder = BufferedRansEncoder()
800
+ # Warning, this is slow...
801
+ # TODO: profile the calls to the bindings...
802
+ symbols_list = []
803
+ indexes_list = []
804
+ for h in range(y_height):
805
+ for w in range(y_width):
806
+ y_crop = y_hat[
807
+ i: i + 1, :, h: h + kernel_size, w: w + kernel_size
808
+ ]
809
+ ctx_p = self.context_prediction(y_crop)
810
+ # 1x1 conv for the entropy parameters prediction network, so
811
+ # we only keep the elements in the "center"
812
+ p = params[i: i + 1, :, h: h + 1, w: w + 1]
813
+ gaussian_params = self.entropy_parameters(
814
+ torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
815
+ )
816
+ scales_hat, means_hat = gaussian_params.chunk(2, 1)
817
+
818
+ indexes = self.gaussian_conditional.build_indexes(scales_hat)
819
+ y_q = torch.round(y_crop - means_hat)
820
+ y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
821
+ i, :, padding, padding
822
+ ]
823
+
824
+ symbols_list.extend(y_q[i, :, padding, padding].int().tolist())
825
+ indexes_list.extend(indexes[i, :].squeeze().int().tolist())
826
+
827
+ encoder.encode_with_indexes(
828
+ symbols_list, indexes_list, cdf, cdf_lengths, offsets
829
+ )
830
+
831
+ string = encoder.flush()
832
+ y_strings.append(string)
833
+
834
+ return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
835
+
836
+ def decompress(self, strings, shape, flag):
837
+ # def decompress(self, strings, shape):
838
+ flag = np.nonzero(flag)
839
+ assert isinstance(strings, list) and len(strings) == 2
840
+ # FIXME: we don't respect the default entropy coder and directly call the
841
+ # range ANS decoder
842
+
843
+ z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
844
+ params = self.h_s(z_hat)
845
+
846
+ s = 4 # scaling factor between z and y
847
+ kernel_size = 5 # context prediction kernel size
848
+ padding = (kernel_size - 1) // 2
849
+
850
+ y_height = z_hat.size(2) * s
851
+ y_width = z_hat.size(3) * s
852
+
853
+ # initialize y_hat to zeros, and pad it so we can directly work with
854
+ # sub-tensors of size (N, C, kernel size, kernel_size)
855
+ y_hat = torch.zeros(
856
+ (z_hat.size(0), 192, y_height + 2 * padding, y_width + 2 * padding),
857
+ device=z_hat.device,
858
+ )
859
+ decoder = RansDecoder()
860
+
861
+ # pylint: disable=protected-access
862
+ cdf = self.gaussian_conditional._quantized_cdf.tolist()
863
+ cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
864
+ offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
865
+
866
+ # Warning: this is slow due to the auto-regressive nature of the
867
+ # decoding... See more recent publication where they use an
868
+ # auto-regressive module on chunks of channels for faster decoding...
869
+ for i, y_string in enumerate(strings[0]):
870
+ decoder.set_stream(y_string)
871
+
872
+ for h in range(y_height):
873
+ for w in range(y_width):
874
+ # only perform the 5x5 convolution on a cropped tensor
875
+ # centered in (h, w)
876
+ y_crop = y_hat[
877
+ i: i + 1, :, h: h + kernel_size, w: w + kernel_size
878
+ ]
879
+ ctx_p = self.context_prediction(y_crop)
880
+ # 1x1 conv for the entropy parameters prediction network, so
881
+ # we only keep the elements in the "center"
882
+ p = params[i: i + 1, :, h: h + 1, w: w + 1]
883
+ gaussian_params = self.entropy_parameters(
884
+ torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
885
+ )
886
+ scales_hat, means_hat = gaussian_params.chunk(2, 1)
887
+
888
+ indexes = self.gaussian_conditional.build_indexes(scales_hat)
889
+ rv = decoder.decode_stream(
890
+ indexes[i, flag].squeeze().int().tolist(),
891
+ # indexes[i, :].squeeze().int().tolist(),
892
+ cdf,
893
+ cdf_lengths,
894
+ offsets,
895
+ )
896
+ # rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
+ # Scatter the decoded (pruned) channels back into a full 192-channel tensor,
+ # keeping everything on the same device as the hyperprior output so the
+ # dequantization below does not mix CPU and CUDA tensors.
+ rv = torch.Tensor(rv).reshape(1, -1, 1, 1).to(z_hat.device)
+ tmp = torch.zeros((1, 192, 1, 1), device=z_hat.device)
+ tmp[:, flag, ...] = rv
+ rv = self.gaussian_conditional._dequantize(tmp, means_hat)
+ # rv = self.gaussian_conditional._dequantize(rv, means_hat)
902
+
903
+ y_hat[
904
+ i,
905
+ :,
906
+ h + padding: h + padding + 1,
907
+ w + padding: w + padding + 1,
908
+ ] = rv
909
+
910
+
911
+ y_hat = y_hat[:, :, padding:-padding, padding:-padding]
912
+ # pylint: enable=protected-access
913
+
914
+ x_hat = self.g_s(y_hat).clamp_(0, 1)
915
+ return {"x_hat": x_hat,}
916
+
917
+ def decompress_1(self, strings, shape):
918
+ assert isinstance(strings, list) and len(strings) == 2
919
+ # FIXME: we don't respect the default entropy coder and directly call the
920
+ # range ANS decoder
921
+
922
+ z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
923
+ params = self.h_s(z_hat)
924
+
925
+ s = 4 # scaling factor between z and y
926
+ kernel_size = 5 # context prediction kernel size
927
+ padding = (kernel_size - 1) // 2
928
+
929
+ y_height = z_hat.size(2) * s
930
+ y_width = z_hat.size(3) * s
931
+
932
+ # initialize y_hat to zeros, and pad it so we can directly work with
933
+ # sub-tensors of size (N, C, kernel size, kernel_size)
934
+ y_hat = torch.zeros(
935
+ (z_hat.size(0), 192, y_height + 2 * padding, y_width + 2 * padding),
936
+ device=z_hat.device,
937
+ )
938
+ decoder = RansDecoder()
939
+
940
+ # pylint: disable=protected-access
941
+ cdf = self.gaussian_conditional._quantized_cdf.tolist()
942
+ cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
943
+ offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
944
+
945
+ # Warning: this is slow due to the auto-regressive nature of the
946
+ # decoding... See more recent publication where they use an
947
+ # auto-regressive module on chunks of channels for faster decoding...
948
+ for i, y_string in enumerate(strings[0]):
949
+ decoder.set_stream(y_string)
950
+
951
+ for h in range(y_height):
952
+ for w in range(y_width):
953
+ # only perform the 5x5 convolution on a cropped tensor
954
+ # centered in (h, w)
955
+ y_crop = y_hat[
956
+ i: i + 1, :, h: h + kernel_size, w: w + kernel_size
957
+ ]
958
+ ctx_p = self.context_prediction(y_crop)
959
+ # 1x1 conv for the entropy parameters prediction network, so
960
+ # we only keep the elements in the "center"
961
+ p = params[i: i + 1, :, h: h + 1, w: w + 1]
962
+ gaussian_params = self.entropy_parameters(
963
+ torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
964
+ )
965
+ scales_hat, means_hat = gaussian_params.chunk(2, 1)
966
+
967
+ indexes = self.gaussian_conditional.build_indexes(scales_hat)
968
+
969
+ rv = decoder.decode_stream(
970
+ indexes[i, :].squeeze().int().tolist(),
971
+ cdf,
972
+ cdf_lengths,
973
+ offsets,
974
+ )
975
+ rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
976
+
977
+ rv = self.gaussian_conditional._dequantize(rv, means_hat)
978
+
979
+ y_hat[
980
+ i,
981
+ :,
982
+ h + padding: h + padding + 1,
983
+ w + padding: w + padding + 1,
984
+ ] = rv
985
+ y_hat = y_hat[:, :, padding:-padding, padding:-padding]
986
+ # pylint: enable=protected-access
987
+
988
+ x_hat = self.g_s(y_hat).clamp_(0, 1)
989
+ return {"x_hat": x_hat}
vae/transformer_layers.py ADDED
@@ -0,0 +1,250 @@
+ import torch.nn as nn
+ import torch
+ import numpy as np
+
+
+ def create_look_ahead_mask(size):
+ """Creates a lookahead mask for autoregressive masking."""
+ mask = np.triu(np.ones((size, size), np.float32), 1)
+ return torch.Tensor(mask)
+
+
+ class StochasticDepth(nn.Module):
+ """Creates a stochastic depth layer."""
+
+ def __init__(self, stochastic_depth_drop_rate):
+ """Initializes a stochastic depth layer.
+
+ Args:
+ stochastic_depth_drop_rate: A `float` drop rate.
+
+ Returns:
+ An output tensor with the same shape as the input.
+ """
+ super().__init__()
+ self._drop_rate = stochastic_depth_drop_rate
+
+ def forward(self, inputs):
+ if not self.training or self._drop_rate == 0.:
+ return inputs
+ keep_prob = 1.0 - self._drop_rate
+ batch_size = inputs.shape[0]
+ # Per-sample Bernoulli keep mask, broadcast over all non-batch dims.
+ # (The TensorFlow-style `torch.rand_like([...])` / `inputs.shape.rank`
+ # calls did not exist in PyTorch and would have raised at runtime.)
+ shape = [batch_size] + [1] * (inputs.dim() - 1)
+ random_tensor = keep_prob + torch.rand(
+ shape, dtype=inputs.dtype, device=inputs.device
+ )
+ binary_tensor = torch.floor(random_tensor)
+ output = torch.div(inputs, keep_prob) * binary_tensor
+ return output
+
+
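For intuition, a quick sketch of the mask produced for `size=4`; a 1 marks a future position to be blocked in the masked self-attention:

```python
import numpy as np
import torch

mask = torch.Tensor(np.triu(np.ones((4, 4), np.float32), 1))
print(mask)
# tensor([[0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])
```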
+ class MLP(nn.Module):
42
+ """MLP head for transformer."""
43
+
44
+ def __init__(self, n_channel,expansion_rate, act, dropout_rate):
45
+ super().__init__()
46
+ self._expansion_rate = expansion_rate
47
+ self._act = act
48
+ self._dropout_rate = dropout_rate
49
+ self._fc1 = nn.Linear(
50
+ n_channel,
51
+ self._expansion_rate * n_channel)
52
+ self.act1 = self._act()
53
+ self._fc2 = nn.Linear(
54
+ self._expansion_rate * n_channel,
55
+ n_channel)
56
+ self.act2 = self._act()
57
+ self._drop = nn.Dropout(self._dropout_rate)
58
+
59
+ def forward(self, features):
60
+ """Forward pass."""
61
+ features = self.act1(self._fc1(features))
62
+ features = self._drop(features)
63
+ features = self.act2(self._fc2(features))
64
+ features = self._drop(features)
65
+ return features
66
+
67
+
68
+ class TransformerBlock(nn.Module):
69
+ """Transformer block that is similar to the Swin encoder block.
70
+
71
+ However, an important difference is that we _do not_ shift the windows
72
+ for the second Attention layer. Instead, we _feed the encoder outputs_
73
+ as Keys and Values. This allows for autoregressive applications.
74
+
75
+ If `style == "encoder"`, no autoregression is happening.
76
+
77
+ Also, this class operates on windowed tensor, see `call` docstring.
78
+ """
79
+
80
+ def __init__(
81
+ self,
82
+ *,
83
+ d_model,
84
+ seq_len,
85
+ num_head = 4,
86
+ mlp_expansion = 4,
87
+ mlp_act = nn.GELU,
88
+ drop_out_rate = 0.1,
89
+ drop_path_rate = 0.1,
90
+ style = "decoder",
91
+ ):
92
+ super().__init__()
93
+ self._style = style
94
+ if style == "decoder":
95
+ # Register as a buffer so moving the module moves the mask too.
96
+ self.register_buffer(
97
+ "look_ahead_mask",
98
+ create_look_ahead_mask(seq_len),
99
+ persistent=False,
100
+ )
101
+ elif style == "encoder":
102
+ self.look_ahead_mask = None
103
+ else:
104
+ raise ValueError(f"Invalid style: {style}")
105
+
106
+ # self._norm1a = nn.LayerNorm(
107
+ # axis=-1, epsilon=1e-5, name="mhsa_normalization1")
108
+ self._norm1a = nn.LayerNorm(d_model)
109
+ # self._norm1b = tf.keras.layers.LayerNormalization(
110
+ # axis=-1, epsilon=1e-5, name="ffn_normalization1")
111
+ self._norm1b = nn.LayerNorm(d_model,eps=1e-5)
112
+
113
+ # self._norm2a = tf.keras.layers.LayerNormalization(
114
+ # axis=-1, epsilon=1e-5, name="mhsa_normalization2")
115
+ self._norm2a = nn.LayerNorm(d_model, eps=1e-5)
116
+ # self._norm2b = tf.keras.layers.LayerNormalization(
117
+ # axis=-1, epsilon=1e-5, name="ffn_normalization2")
118
+ self._norm2b = nn.LayerNorm(d_model, eps=1e-5)
119
+ self._attn1 = nn.MultiheadAttention(
120
+ d_model,
121
+ num_head,
122
+ dropout=drop_out_rate
123
+ )
124
+
125
+ self._attn2 = nn.MultiheadAttention(
126
+ d_model,
127
+ num_head,
128
+ dropout=drop_out_rate
129
+ )
130
+
131
+ self._mlp1 = MLP(
132
+ d_model,
133
+ expansion_rate=mlp_expansion,
134
+ act=mlp_act,
135
+ dropout_rate=drop_out_rate)
136
+ self._mlp2 = MLP(
137
+ d_model,
138
+ expansion_rate=mlp_expansion,
139
+ act=mlp_act,
140
+ dropout_rate=drop_out_rate)
141
+
142
+ # No weights, so we share for both blocks.
143
+ self._drop_path = StochasticDepth(drop_path_rate)
144
+
145
+ def forward(self, features, enc_output):
146
+ if enc_output is None:
147
+ if self._style == "decoder":
148
+ raise ValueError("Need `enc_output` when running decoder.")
149
+ else:
150
+ assert enc_output.shape[0] == features.shape[0] and enc_output.shape[2] == features.shape[2]
151
+
152
+ # First Block ---
153
+ shortcut = features
154
+ features = self._norm1a(features)
155
+ # Masked self-attention.
156
+ features = features.permute(1, 0, 2) # NLD -> LND
157
+ features, _ = self._attn1(
158
+ value=features,
159
+ key=features,
160
+ query=features,
161
+ attn_mask=self.look_ahead_mask)
162
+ features = features.permute(1, 0, 2) # LND -> NLD
163
+
164
+ assert features.shape == shortcut.shape
165
+ features = shortcut + self._drop_path(features)
166
+
167
+ features = features + self._drop_path(
168
+ self._mlp1(self._norm1b(features)))
169
+
170
+ # Second Block ---
171
+ shortcut = features
172
+ features = self._norm2a(features)
173
+ # Unmasked "lookup" into enc_output, no need for mask.
174
+
175
+ features = features.permute(1, 0, 2) # NLD -> LND
176
+ if enc_output is not None:
177
+ enc_output = enc_output.permute(1, 0, 2) # NLD -> LND
178
+ features, _ = self._attn2( # pytype: disable=wrong-arg-types # dynamic-method-lookup
179
+ value=enc_output if enc_output is not None else features,
180
+ key=enc_output if enc_output is not None else features,
181
+ query=features,
182
+ attn_mask=None)
183
+ features = features.permute(1, 0, 2) # LND -> NLD
184
+
185
+ features = shortcut + self._drop_path(features)
186
+ output = features + self._drop_path(
187
+ self._mlp2(self._norm2b(features)))
188
+
189
+ return output
190
+
191
+
192
+ class Transformer(nn.Module):
193
+ """A stack of transformer blocks, useable for encoding or decoding."""
194
+
195
+ def __init__(
196
+ self,
197
+ is_decoder,
198
+ num_layers = 4,
199
+ d_model = 192,
200
+ seq_len = 16,
201
+ num_head = 4,
202
+ mlp_expansion = 4,
203
+ drop_out = 0.1
204
+ ):
205
+ super().__init__()
206
+ self.is_decoder = is_decoder
207
+
208
+ # IMPORTANT: use ModuleList so parameters/buffers are registered and moved
209
+ # correctly with `.to(device)`.
210
+ self.layers = nn.ModuleList(
211
+ [
212
+ TransformerBlock(
213
+ d_model=d_model,
214
+ seq_len=seq_len,
215
+ num_head=num_head,
216
+ mlp_expansion=mlp_expansion,
217
+ drop_out_rate=drop_out,
218
+ drop_path_rate=drop_out,
219
+ style="decoder" if is_decoder else "encoder",
220
+ )
221
+ for _ in range(num_layers)
222
+ ]
223
+ )
224
+
225
+ def forward(
226
+ self, latent, enc_output
227
+ ):
228
+ """Forward pass.
229
+
230
+ For decoder, this predicts distribution of `latent` given `enc_output`.
231
+
232
+ We assume that `latent` has already been embedded in a d_model-dimensional
233
+ space.
234
+
235
+ Args:
236
+ latent: (B', seq_len, C) latent.
237
+ enc_output: (B', seq_len_enc, C) concatenated encoder output.
239
+
240
+ Returns:
241
+ Decoder output of shape (B', seq_len, C).
242
+ """
243
+ assert len(latent.shape) == 3, latent.shape
244
+ if enc_output is not None:
245
+ assert latent.shape[-1] == enc_output.shape[-1], (latent.shape,
246
+ enc_output.shape)
247
+ for layer in self.layers:
248
+ latent = layer(features=latent, enc_output=enc_output)
249
+ return latent
250
+
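A minimal usage sketch for the stack; the shapes below echo the constructor defaults and are illustrative, not prescriptive:

```python
import torch
from vae.transformer_layers import Transformer

enc = Transformer(is_decoder=False, num_layers=2, d_model=192, seq_len=16)
dec = Transformer(is_decoder=True, num_layers=2, d_model=192, seq_len=16)

latent = torch.rand(8, 16, 192)            # (B', seq_len, C)
enc_out = enc(latent, enc_output=None)     # encoder: no cross-attention input
dec_out = dec(latent, enc_output=enc_out)  # decoder: causal over latent
print(dec_out.shape)                       # torch.Size([8, 16, 192])
```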
vae/utils.py ADDED
@@ -0,0 +1,102 @@
+ """
+ Utility functions for image compression.
+ """
+
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from PIL import Image
+ from typing import Tuple, Dict
+ from .roi_tic import ModifiedTIC
+
+
+ def compute_padding(in_h: int, in_w: int, min_div: int = 256) -> Tuple[Tuple[int, int, int, int], Tuple[int, int, int, int]]:
+ """
+ Compute padding to make dimensions divisible by min_div.
+
+ Args:
+ in_h: input height
+ in_w: input width
+ min_div: minimum divisor (default 256 for TIC)
+
+ Returns:
+ pad: (left, right, top, bottom) padding
+ unpad: negative padding for cropping back
+ """
+ out_h = (in_h + min_div - 1) // min_div * min_div
+ out_w = (in_w + min_div - 1) // min_div * min_div
+
+ left = (out_w - in_w) // 2
+ right = out_w - in_w - left
+ top = (out_h - in_h) // 2
+ bottom = out_h - in_h - top
+
+ pad = (left, right, top, bottom)
+ unpad = (-left, -right, -top, -bottom)
+
+ return pad, unpad
+
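A quick worked example of the padding math:

```python
# A 375x500 (W x H) input with min_div=256 is padded up to 512x512:
pad, unpad = compute_padding(in_h=500, in_w=375, min_div=256)
print(pad)    # (68, 69, 6, 6)    -> (left, right, top, bottom)
print(unpad)  # (-68, -69, -6, -6)
```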
+
+
+ def compress_image(
+ image: Image.Image,
+ mask: np.ndarray,
+ model: ModifiedTIC,
+ sigma: float = 0.3,
+ device: str = 'cuda'
+ ) -> Dict:
+ """
+ Compress image with ROI-based quality control.
+
+ Args:
+ image: PIL Image (RGB)
+ mask: Binary mask (H, W) with 1 for ROI, 0 for background
+ model: Loaded ModifiedTIC model
+ sigma: Background quality (0.01-1.0, lower = more compression)
+ device: 'cuda' or 'cpu'
+
+ Returns:
+ dict with:
+ - compressed: PIL Image of compressed result
+ - bpp: Bits per pixel
+ - original_size: Original image dimensions
+ - mask_used: The mask that was used
+ """
+ # Convert image to tensor
+ img_array = np.array(image).astype(np.float32) / 255.0
+ img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).to(device)
+
+ # Pad image
+ _, _, h, w = img_tensor.shape
+ pad, unpad = compute_padding(h, w, min_div=256)
+ img_padded = F.pad(img_tensor, pad, mode='constant', value=0)
+
+ # Prepare mask (cast to float32 in case the caller passed a float64 array)
+ mask_tensor = torch.from_numpy(mask).float().unsqueeze(0).unsqueeze(0).to(device)
+ mask_padded = F.pad(mask_tensor, pad, mode='constant', value=0)
+
+ # Compress
+ with torch.no_grad():
+ # NOTE: `ModifiedTIC.forward()` handles mask downsampling internally.
+ out = model(img_padded, mask_padded, sigma=sigma)
+
+ # Unpad result
+ x_hat = F.pad(out['x_hat'], unpad)
+
+ # Convert back to image
+ x_hat_np = x_hat.squeeze(0).permute(1, 2, 0).cpu().numpy()
+ x_hat_np = np.clip(x_hat_np * 255, 0, 255).astype(np.uint8)
+ compressed_img = Image.fromarray(x_hat_np)
+
+ # Calculate BPP from the likelihoods: bpp = -sum(log2 p) / num_pixels
+ num_pixels = h * w
+ likelihoods = out['likelihoods']
+ bpp_y = torch.log(likelihoods['y']).sum() / (-np.log(2) * num_pixels)
+ bpp_z = torch.log(likelihoods['z']).sum() / (-np.log(2) * num_pixels)
+ bpp = (bpp_y + bpp_z).item()
+
+ return {
+ 'compressed': compressed_img,
+ 'bpp': bpp,
+ 'original_size': (w, h),
+ 'mask_used': mask
+ }
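And a hedged end-to-end sketch; the image path is hypothetical, and `model` is assumed to be a `ModifiedTIC` restored from one of the bundled checkpoints:

```python
import numpy as np
from PIL import Image
from vae.utils import compress_image

image = Image.open("images/example/input.jpg").convert("RGB")  # hypothetical path
mask = np.zeros((image.height, image.width), dtype=np.float32)
mask[100:300, 150:400] = 1.0  # mark a rectangular ROI

result = compress_image(image, mask, model, sigma=0.1, device="cuda")
print(f"BPP: {result['bpp']:.3f}")
result["compressed"].save("compressed.png")
```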
vae/visualization.py ADDED
@@ -0,0 +1,94 @@
1
+ """
2
+ Visualization utilities for compression results.
3
+ """
4
+
5
+ import numpy as np
6
+ import cv2
7
+ from PIL import Image
8
+ from typing import Tuple
9
+ import matplotlib.pyplot as plt
10
+
11
+
12
+ def highlight_roi(
13
+ image: Image.Image,
14
+ mask: np.ndarray,
15
+ alpha: float = 0.3,
16
+ color: Tuple[int, int, int] = (0, 255, 0)
17
+ ) -> Image.Image:
18
+ """
19
+ Highlight ROI regions in image with colored overlay.
20
+
21
+ Args:
22
+ image: PIL Image
23
+ mask: Binary mask (H, W)
24
+ alpha: Overlay transparency (0-1)
25
+ color: RGB color tuple for ROI highlight
26
+
27
+ Returns:
28
+ Image with ROI highlighted
29
+ """
30
+ img_array = np.array(image)
31
+
32
+ # Create colored overlay
33
+ overlay = img_array.copy()
34
+ overlay[mask > 0.5] = color
35
+
36
+ # Blend
37
+ result = cv2.addWeighted(img_array, 1 - alpha, overlay, alpha, 0)
38
+
39
+ return Image.fromarray(result)
40
+
41
+
42
+ def create_comparison_grid(
43
+ original: Image.Image,
44
+ compressed: Image.Image,
45
+ mask: np.ndarray,
46
+ bpp: float,
47
+ sigma: float,
48
+ lambda_val: float,
49
+ highlight: bool = True
50
+ ) -> Image.Image:
51
+ """
52
+ Create side-by-side comparison of original and compressed images.
53
+
54
+ Args:
55
+ original: Original PIL Image
56
+ compressed: Compressed PIL Image
57
+ mask: Binary mask used
58
+ bpp: Bits per pixel
59
+ sigma: Sigma value used
60
+ lambda_val: Lambda value used
61
+ highlight: Whether to show ROI overlay
62
+
63
+ Returns:
64
+ Combined comparison image
65
+ """
66
+ fig, axes = plt.subplots(1, 3 if highlight else 2, figsize=(15 if highlight else 10, 5))
67
+
68
+ # Original
69
+ axes[0].imshow(original)
70
+ axes[0].set_title('Original', fontsize=14)
71
+ axes[0].axis('off')
72
+
73
+ # Compressed
74
+ axes[1].imshow(compressed)
75
+ axes[1].set_title(f'Compressed (σ={sigma:.2f}, λ={lambda_val}, BPP={bpp:.3f})', fontsize=14)
76
+ axes[1].axis('off')
77
+
78
+ # ROI overlay
79
+ if highlight:
80
+ highlighted = highlight_roi(original, mask, alpha=0.4, color=(0, 255, 0))
81
+ axes[2].imshow(highlighted)
82
+ axes[2].set_title('ROI Mask (green)', fontsize=14)
83
+ axes[2].axis('off')
84
+
85
+ plt.tight_layout()
86
+
87
+ # Convert to PIL Image
88
+ fig.canvas.draw()
89
+ img_array = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
90
+ img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (4,))
91
+ img_array = img_array[:, :, :3] # Remove alpha channel
92
+ plt.close(fig)
93
+
94
+ return Image.fromarray(img_array)
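A short sketch of how the two helpers compose; the inputs reuse the names from the `compress_image` sketch above:

```python
from vae.visualization import create_comparison_grid

grid = create_comparison_grid(
    original=image,
    compressed=result["compressed"],
    mask=mask,
    bpp=result["bpp"],
    sigma=0.1,
    lambda_val=0.013,
    highlight=True,
)
grid.save("comparison.png")
```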
video/__init__.py ADDED
@@ -0,0 +1,55 @@
1
+ """Video processing module for ROI-based video compression.
2
+
3
+ Provides:
4
+ - VideoProcessor: Frame extraction, motion analysis, adaptive compression
5
+ - MotionAnalyzer: Optical flow and scene complexity estimation
6
+ - ChunkCompressor: Chunk-by-chunk compression with bandwidth targeting
7
+ - Temporal smoothing utilities for mask stabilization
8
+ - Mask caching for video segmentation reuse
9
+ """
10
+
11
+ from .video_processor import (
12
+ VideoProcessor,
13
+ VideoFrame,
14
+ CompressedChunk,
15
+ CompressionSettings,
16
+ ChunkPlan,
17
+ )
18
+ from .motion_analyzer import MotionAnalyzer
19
+ from .chunk_compressor import (
20
+ ChunkCompressor,
21
+ BandwidthController,
22
+ smooth_masks_temporal,
23
+ smooth_masks_temporal_fast,
24
+ smooth_masks_sdf,
25
+ )
26
+ from .mask_cache import (
27
+ save_video_masks,
28
+ load_video_masks,
29
+ get_mask_info,
30
+ )
31
+ from .gpu_memory import (
32
+ estimate_batch_sizes,
33
+ BatchSizeEstimate,
34
+ )
35
+ from .sdf_smoother import SDFSmoother
36
+
37
+ __all__ = [
38
+ "VideoProcessor",
39
+ "VideoFrame",
40
+ "CompressedChunk",
41
+ "CompressionSettings",
42
+ "ChunkPlan",
43
+ "MotionAnalyzer",
44
+ "ChunkCompressor",
45
+ "BandwidthController",
46
+ "smooth_masks_temporal",
47
+ "smooth_masks_temporal_fast",
48
+ "smooth_masks_sdf",
49
+ "save_video_masks",
50
+ "load_video_masks",
51
+ "get_mask_info",
52
+ "estimate_batch_sizes",
53
+ "BatchSizeEstimate",
54
+ "SDFSmoother",
55
+ ]
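The re-exports above make pipeline wiring a one-liner; a minimal sketch, where the `model` and `segmenter` objects are assumed to be constructed elsewhere:

```python
from video import BandwidthController, ChunkCompressor, MotionAnalyzer

controller = BandwidthController(target_bandwidth_kbps=500.0, base_fps=30.0)
analyzer = MotionAnalyzer()
compressor = ChunkCompressor(
    compression_model=model,   # loaded TIC model (assumption)
    segmenter=segmenter,       # optional ROI segmenter (assumption)
    target_classes=["person"],
)
```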
video/chunk_compressor.py ADDED
@@ -0,0 +1,877 @@
1
+ """Chunk-based video compression with bandwidth targeting.
2
+
3
+ Implements dynamic compression that balances framerate and spatial quality
4
+ to meet bandwidth constraints while prioritizing motion-heavy scenes.
5
+
6
+ Includes temporal smoothing for segmentation masks to reduce flickering.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass, field
12
+ from typing import List, Optional, Tuple, Generator, Dict, Any
13
+ import math
14
+
15
+ import numpy as np
16
+ from PIL import Image
17
+ from scipy import ndimage
18
+
19
+ from .motion_analyzer import MotionAnalyzer, MotionMetrics
20
+ from .sdf_smoother import SDFSmoother
21
+
22
+
23
+ def smooth_masks_temporal(
24
+ masks: List[np.ndarray],
25
+ window_size: int = 5,
26
+ threshold_appear: float = 0.4,
27
+ threshold_disappear: float = 0.2,
28
+ spatial_smooth: bool = True,
29
+ spatial_kernel_size: int = 5,
30
+ ) -> List[np.ndarray]:
31
+ """Apply temporal smoothing to a sequence of segmentation masks.
32
+
33
+ Reduces flickering by:
34
+ 1. Applying temporal median/mean filtering across frames
35
+ 2. Using hysteresis thresholding (different thresholds for appearing vs disappearing)
36
+ 3. Optional spatial smoothing to clean up edges
37
+
38
+ Args:
39
+ masks: List of binary/float masks (H, W), values in [0, 1]
40
+ window_size: Number of frames to consider for temporal filtering (odd number)
41
+ threshold_appear: Confidence threshold for a pixel to become ROI (higher = stricter)
42
+ threshold_disappear: Confidence threshold for a pixel to stop being ROI (lower = stickier)
43
+ spatial_smooth: Whether to apply spatial Gaussian smoothing
44
+ spatial_kernel_size: Size of spatial smoothing kernel
45
+
46
+ Returns:
47
+ List of smoothed masks
48
+ """
49
+ if not masks or len(masks) < 2:
50
+ return masks
51
+
52
+ # Convert to numpy array for efficient processing: (T, H, W)
53
+ h, w = masks[0].shape
54
+ mask_stack = np.stack([m.astype(np.float32) for m in masks], axis=0)
55
+ num_frames = mask_stack.shape[0]
56
+
57
+ # Pad temporally for filtering
58
+ half_window = window_size // 2
59
+ padded = np.pad(mask_stack, ((half_window, half_window), (0, 0), (0, 0)), mode='edge')
60
+
61
+ # Apply temporal filtering using a sliding window
62
+ smoothed = np.zeros_like(mask_stack)
63
+ for t in range(num_frames):
64
+ # Get window of frames
65
+ window = padded[t:t + window_size] # (window_size, H, W)
66
+ # Use weighted mean - center frame gets more weight. A triangular
+ # window generalizes the previously hard-coded [1, 2, 3, 2, 1]
+ # (which only covered window_size <= 5) to any odd window_size.
+ weights = window_size // 2 + 1 - np.abs(np.arange(window_size) - window_size // 2)
+ weights = weights / weights.sum()
69
+ weighted_mean = np.average(window, axis=0, weights=weights)
70
+ smoothed[t] = weighted_mean
71
+
72
+ # Apply hysteresis thresholding
73
+ # A pixel becomes ROI if confidence > threshold_appear
74
+ # A pixel stays ROI until confidence < threshold_disappear
75
+ result = np.zeros_like(smoothed)
76
+
77
+ # Initialize with first frame using appear threshold
78
+ result[0] = (smoothed[0] > threshold_appear).astype(np.float32)
79
+
80
+ for t in range(1, num_frames):
81
+ # Pixels that were ROI in previous frame
82
+ was_roi = result[t - 1] > 0.5
83
+
84
+ # New ROI: high confidence (above appear threshold)
85
+ new_roi = smoothed[t] > threshold_appear
86
+
87
+ # Continuing ROI: was ROI and still above disappear threshold
88
+ continuing_roi = was_roi & (smoothed[t] > threshold_disappear)
89
+
90
+ # Combine: either new or continuing
91
+ result[t] = (new_roi | continuing_roi).astype(np.float32)
92
+
93
+ # Optional: apply spatial smoothing to clean up edges
94
+ if spatial_smooth:
95
+ for t in range(num_frames):
96
+ # Gaussian blur then threshold
97
+ blurred = ndimage.gaussian_filter(result[t], sigma=spatial_kernel_size / 4)
98
+ result[t] = (blurred > 0.5).astype(np.float32)
99
+
100
+ # Convert back to list
101
+ return [result[t] for t in range(num_frames)]
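To see the hysteresis in action, a toy sketch on 1x1 "masks": a single-frame confidence dip does not switch an ROI pixel off, because the disappear threshold is stickier than the appear threshold:

```python
import numpy as np
from video.chunk_compressor import smooth_masks_temporal

# Steady ROI with one noisy dip, then steady background.
confidences = [1.0, 1.0, 0.3, 1.0, 1.0, 0.0, 0.0, 0.0]
masks = [np.full((1, 1), c, dtype=np.float32) for c in confidences]

smoothed = smooth_masks_temporal(masks, window_size=3, spatial_smooth=False)
print([float(m[0, 0]) for m in smoothed])  # the dip at frame 2 stays ROI
```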
102
+
103
+
104
+ def smooth_masks_temporal_fast(
105
+ masks: List[np.ndarray],
106
+ alpha: float = 0.3,
107
+ threshold: float = 0.5,
108
+ ) -> List[np.ndarray]:
109
+ """Fast temporal smoothing using exponential moving average.
110
+
111
+ Simpler and faster than full temporal smoothing, good for real-time.
112
+
113
+ Args:
114
+ masks: List of binary/float masks (H, W)
115
+ alpha: Smoothing factor (0-1). Lower = more smoothing, more lag.
116
+ threshold: Threshold for final binary mask
117
+
118
+ Returns:
119
+ List of smoothed masks
120
+ """
121
+ if not masks or len(masks) < 2:
122
+ return masks
123
+
124
+ result = []
125
+ ema = masks[0].astype(np.float32).copy()
126
+ result.append((ema > threshold).astype(np.float32))
127
+
128
+ for i in range(1, len(masks)):
129
+ current = masks[i].astype(np.float32)
130
+ # Exponential moving average
131
+ ema = alpha * current + (1 - alpha) * ema
132
+ result.append((ema > threshold).astype(np.float32))
133
+
134
+ return result
135
+
136
+
137
+ def smooth_masks_sdf(
138
+ masks: List[np.ndarray],
139
+ alpha: float = 0.5,
140
+ empty_thresh: int = 10,
141
+ patience: int = 5,
142
+ ) -> List[np.ndarray]:
143
+ """Smooth masks using Signed Distance Field temporal filtering.
144
+
145
+ Uses SDF representation for fluid, jitter-free transitions while
146
+ preserving sharp boundaries. More sophisticated than simple EMA.
147
+
148
+ Args:
149
+ masks: List of binary/float masks (H, W)
150
+ alpha: Smoothing factor (0.1 = slow/viscous, 0.9 = fast/reactive)
151
+ empty_thresh: Min pixel count to consider mask "valid"
152
+ patience: Frames to tolerate empty masks before decay
153
+ (0 = immediate, 5 = conservative, 15 = aggressive)
154
+
155
+ Returns:
156
+ List of smoothed masks
157
+ """
158
+ if not masks or len(masks) < 2:
159
+ return masks
160
+
161
+ smoother = SDFSmoother(alpha=alpha, empty_thresh=empty_thresh, patience=patience)
162
+ result = []
163
+
164
+ for mask in masks:
165
+ smoothed = smoother.update(mask.astype(np.float32))
166
+ result.append(smoothed)
167
+
168
+ return result
169
+
170
+
171
+ @dataclass
172
+ class CompressionResult:
173
+ """Result of compressing a single frame."""
174
+
175
+ compressed_image: Image.Image
176
+ bpp: float
177
+ original_size: Tuple[int, int]
178
+ roi_coverage: float
179
+
180
+
181
+ @dataclass
182
+ class ChunkResult:
183
+ """Result of compressing a video chunk."""
184
+
185
+ # Compressed frames for this chunk
186
+ frames: List[Image.Image]
187
+
188
+ # Frame indices from original video that were kept
189
+ frame_indices: List[int]
190
+
191
+ # Effective framerate for this chunk
192
+ effective_fps: float
193
+
194
+ # Quality level used (1-5)
195
+ quality_level: int
196
+
197
+ # Sigma (background preservation) used
198
+ sigma: float
199
+
200
+ # Average bits per pixel
201
+ avg_bpp: float
202
+
203
+ # Estimated chunk size in bytes
204
+ estimated_bytes: int
205
+
206
+ # Motion metrics for this chunk
207
+ motion_metrics: Optional[MotionMetrics] = None
208
+
209
+ # Chunk index
210
+ chunk_index: int = 0
211
+
212
+ # Total number of original frames in chunk
213
+ original_frame_count: int = 0
214
+
215
+
216
+ class BandwidthController:
217
+ """Controls compression parameters to meet bandwidth targets.
218
+
219
+ Dynamically adjusts:
220
+ - Frame sampling rate (effective FPS)
221
+ - Spatial compression quality
222
+ - Background preservation (sigma)
223
+
224
+ Based on:
225
+ - Target bandwidth constraint
226
+ - Motion complexity of current chunk
227
+ - Smooth transitions between settings
228
+ """
229
+
230
+ # Quality level presets (lambda, expected_bpp_range)
231
+ QUALITY_PRESETS = [
232
+ (1, 0.05, 0.15), # Lowest quality: ~0.05-0.15 bpp
233
+ (2, 0.10, 0.25), # Low quality: ~0.10-0.25 bpp
234
+ (3, 0.15, 0.40), # Medium quality: ~0.15-0.40 bpp
235
+ (4, 0.25, 0.60), # High quality: ~0.25-0.60 bpp
236
+ (5, 0.40, 1.00), # Best quality: ~0.40-1.00 bpp
237
+ ]
238
+
239
+ def __init__(
240
+ self,
241
+ target_bandwidth_kbps: float = 500.0,
242
+ base_fps: float = 30.0,
243
+ min_fps: float = 5.0,
244
+ max_fps: float = 60.0,
245
+ chunk_duration_sec: float = 1.0,
246
+ smoothing_factor: float = 0.3,
247
+ aggressiveness: float = 0.5,
248
+ ):
249
+ """
250
+ Args:
251
+ target_bandwidth_kbps: Target bandwidth in kilobits per second
252
+ base_fps: Original video framerate
253
+ min_fps: Minimum allowed effective framerate
254
+ max_fps: Maximum allowed effective framerate
255
+ chunk_duration_sec: Duration of each chunk in seconds
256
+ smoothing_factor: How much to smooth parameter transitions (0-1)
257
+ aggressiveness: Bandwidth savings strategy (0.0=use full bandwidth, 1.0=maximum savings)
258
+ """
259
+ self.target_bandwidth_kbps = target_bandwidth_kbps
260
+ self.base_fps = base_fps
261
+ self.min_fps = min_fps
262
+ self.max_fps = max_fps
263
+ self.chunk_duration_sec = chunk_duration_sec
264
+ self.smoothing_factor = smoothing_factor
265
+ self.aggressiveness = max(0.0, min(1.0, aggressiveness))
266
+
267
+ # State for smooth transitions
268
+ self._prev_fps: Optional[float] = None
269
+ self._prev_quality: Optional[int] = None
270
+ self._prev_sigma: Optional[float] = None
271
+
272
+ def reset(self):
273
+ """Reset state for new video."""
274
+ self._prev_fps = None
275
+ self._prev_quality = None
276
+ self._prev_sigma = None
277
+
278
+ def compute_settings(
279
+ self,
280
+ frame_width: int,
281
+ frame_height: int,
282
+ motion_metrics: MotionMetrics,
283
+ roi_coverage: float = 0.3,
284
+ ) -> Tuple[float, int, float]:
285
+ """Compute compression settings for a chunk.
286
+
287
+ Args:
288
+ frame_width: Frame width in pixels
289
+ frame_height: Frame height in pixels
290
+ motion_metrics: Motion analysis results
291
+ roi_coverage: Fraction of frame covered by ROI
292
+
293
+ Returns:
294
+ Tuple of (effective_fps, quality_level, sigma)
295
+ """
296
+ num_pixels = frame_width * frame_height
297
+
298
+ # Target bits per chunk
299
+ target_bits_per_chunk = self.target_bandwidth_kbps * 1000 * self.chunk_duration_sec
300
+
301
+ # Framerate adjustment based on motion
302
+ # High motion -> more frames, lower quality per frame
303
+ # Low motion -> fewer frames, higher quality per frame
304
+ fr_factor = motion_metrics.framerate_factor
305
+
306
+ # For dynamic mode with high aggressiveness, use a higher baseline
307
+ # to ensure motion-based variation actually produces meaningful FPS range
308
+ if self.aggressiveness > 0.5:
309
+ # Use average of min/max as effective base for motion scaling
310
+ effective_base_fps = (self.min_fps + self.max_fps) / 2
311
+ else:
312
+ # Use video's native FPS
313
+ effective_base_fps = self.base_fps
314
+
315
+ # Base framerate adjusted by motion
316
+ motion_fps = effective_base_fps * fr_factor
317
+ motion_fps = np.clip(motion_fps, self.min_fps, self.max_fps)
318
+
319
+ # Estimate frames in chunk at this FPS
320
+ frames_in_chunk = int(motion_fps * self.chunk_duration_sec)
321
+ frames_in_chunk = max(1, frames_in_chunk)
322
+
323
+ # Target bits per frame
324
+ target_bits_per_frame = target_bits_per_chunk / frames_in_chunk
325
+ target_bpp = target_bits_per_frame / num_pixels
326
+
327
+ # Find quality level that matches target BPP
328
+ quality_level = self._find_quality_for_bpp(target_bpp, roi_coverage)
329
+
330
+ # Compute sigma based on motion and quality
331
+ # Higher motion -> lower sigma (more background compression to save bits for motion)
332
+ # Lower quality -> lower sigma (aggressive background compression)
333
+ base_sigma = 0.3
334
+ motion_adjustment = (1.0 - motion_metrics.motion_magnitude) * 0.4
335
+ quality_adjustment = (quality_level - 3) * 0.1
336
+ sigma = np.clip(base_sigma + motion_adjustment + quality_adjustment, 0.05, 0.8)
337
+
338
+ # Final FPS adjustment based on actual quality/bpp achieved
339
+ expected_bpp = self._estimate_bpp(quality_level, sigma, roi_coverage)
340
+ expected_bits_per_frame = expected_bpp * num_pixels
341
+
342
+ current_bandwidth = (motion_fps * expected_bits_per_frame) / 1000 # kbps
343
+ bandwidth_ratio = current_bandwidth / self.target_bandwidth_kbps if self.target_bandwidth_kbps > 0 else 1.0
344
+
345
+ # At high aggressiveness, prioritize motion over bandwidth constraints
346
+ # At low aggressiveness, prioritize bandwidth target
347
+ if self.aggressiveness > 0.7:
348
+ # Very aggressive: trust motion analysis, ignore bandwidth (may exceed target)
349
+ # Only enforce absolute max_fps constraint, not bandwidth
350
+ final_fps = motion_fps
351
+ elif self.aggressiveness > 0.5:
352
+ # Moderately aggressive: allow bandwidth excursions for high motion
353
+ # Only reduce FPS if significantly over budget (>1.5x)
354
+ if bandwidth_ratio > 1.5:
355
+ reduction_factor = 1.5 / bandwidth_ratio
356
+ final_fps = motion_fps * reduction_factor
357
+ else:
358
+ final_fps = motion_fps
359
+ else:
360
+ # Conservative mode: enforce bandwidth target strictly
361
+ override_threshold = 1.2 + self.aggressiveness * 0.5 # 1.2 at agg=0, 1.45 at agg=0.5
362
+
363
+ if bandwidth_ratio > override_threshold:
364
+ # Bandwidth too high, reduce FPS
365
+ fps_reduction = (bandwidth_ratio - 1.0) * 0.5
366
+ final_fps = max(self.min_fps, motion_fps / (1 + fps_reduction))
367
+ else:
368
+ # Keep motion-based FPS
369
+ final_fps = motion_fps
370
+
371
+ final_fps = np.clip(final_fps, self.min_fps, self.max_fps)
372
+
373
+ # Smooth transitions (disable smoothing at high aggressiveness for dramatic variation)
374
+ # At agg>0.7, no FPS smoothing - follow motion exactly
375
+ # At agg≤0.7, apply smoothing
376
+ if self.aggressiveness > 0.7:
377
+ # No FPS smoothing - allow dramatic jumps
378
+ smoothing_fps = 0.0
379
+ smoothing_other = 0.1 # Minimal smoothing for quality/sigma
380
+ else:
381
+ smoothing_fps = 0.3 - self.aggressiveness * 0.3 # 0.3 at agg=0, 0.09 at agg=0.7
382
+ smoothing_other = smoothing_fps
383
+
384
+ quality_change_limit = 1 + int(self.aggressiveness * 2) # 1 at agg=0, 3 at agg=1.0
385
+
386
+ if self._prev_fps is not None and smoothing_fps > 0:
387
+ final_fps = self._smooth(final_fps, self._prev_fps, smoothing_fps)
388
+ if self._prev_quality is not None:
389
+ # Quality can change by quality_change_limit steps at a time
390
+ quality_level = int(np.clip(
391
+ quality_level,
392
+ self._prev_quality - quality_change_limit,
393
+ self._prev_quality + quality_change_limit
394
+ ))
395
+ if self._prev_sigma is not None and smoothing_other > 0:
396
+ sigma = self._smooth(sigma, self._prev_sigma, smoothing_other)
397
+
398
+ # Update state
399
+ self._prev_fps = final_fps
400
+ self._prev_quality = quality_level
401
+ self._prev_sigma = sigma
402
+
403
+ return float(final_fps), int(quality_level), float(sigma)
404
+
405
+ def _find_quality_for_bpp(self, target_bpp: float, roi_coverage: float) -> int:
406
+ """Find quality level that approximately matches target BPP."""
407
+ # ROI coverage affects effective BPP (more ROI = higher BPP for same quality)
408
+ adjusted_target = target_bpp / (0.5 + roi_coverage * 0.5)
409
+
410
+ for level, min_bpp, max_bpp in self.QUALITY_PRESETS:
411
+ mid_bpp = (min_bpp + max_bpp) / 2
412
+ if adjusted_target <= mid_bpp:
413
+ return level
414
+
415
+ return 5 # Max quality if target is very high
416
+
417
+ def _estimate_bpp(self, quality_level: int, sigma: float, roi_coverage: float) -> float:
418
+ """Estimate BPP for given settings."""
419
+ _, min_bpp, max_bpp = self.QUALITY_PRESETS[quality_level - 1]
420
+ base_bpp = (min_bpp + max_bpp) / 2
421
+
422
+ # Sigma affects background compression
423
+ # Lower sigma = more compression = lower BPP
424
+ sigma_factor = 0.5 + sigma * 0.5
425
+
426
+ # ROI coverage affects total BPP
427
+ # More ROI = more high-quality pixels = higher BPP
428
+ roi_factor = 0.7 + roi_coverage * 0.3
429
+
430
+ return base_bpp * sigma_factor * roi_factor
431
+
432
+ def _smooth(self, new_value: float, prev_value: float, factor: Optional[float] = None) -> float:
433
+ """Apply exponential smoothing.
434
+
435
+ Args:
436
+ new_value: Target value
437
+ prev_value: Previous value
438
+ factor: Smoothing factor (uses self.smoothing_factor if None)
439
+ """
440
+ if factor is None:
441
+ factor = self.smoothing_factor
442
+ return prev_value + factor * (new_value - prev_value)
443
+
444
+
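For a feel of the budget math in `compute_settings` above, a back-of-the-envelope sketch (the numbers are illustrative):

```python
# 500 kbps target, 1 s chunks, 640x480 frames, 30 fps after motion scaling:
target_bits_per_chunk = 500.0 * 1000 * 1.0             # 500_000 bits
target_bpp = target_bits_per_chunk / 30 / (640 * 480)
print(f"{target_bpp:.3f} bpp")                          # ~0.054
# With roi_coverage=0.3, _find_quality_for_bpp adjusts this to
# 0.054 / 0.65 ~= 0.083, which lands in quality level 1 of QUALITY_PRESETS.
```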
445
+ class ChunkCompressor:
446
+ """Compresses video chunks with adaptive settings.
447
+
448
+ Each chunk can have different:
449
+ - Frame sampling rate
450
+ - Compression quality
451
+ - Background preservation
452
+
453
+ Based on motion analysis and bandwidth constraints.
454
+ """
455
+
456
+ def __init__(
457
+ self,
458
+ compression_model,
459
+ segmenter=None,
460
+ target_classes: Optional[List[str]] = None,
461
+ device: str = "cuda",
462
+ ):
463
+ """
464
+ Args:
465
+ compression_model: Loaded TIC compression model
466
+ segmenter: Optional segmenter for ROI extraction
467
+ target_classes: Classes to segment as ROI
468
+ device: Compute device
469
+ """
470
+ self.compression_model = compression_model
471
+ self.segmenter = segmenter
472
+ self.target_classes = target_classes or []
473
+ self.device = device
474
+
475
+ self.motion_analyzer = MotionAnalyzer()
476
+
477
+ def reset(self):
478
+ """Reset state for new video."""
479
+ self.motion_analyzer.reset()
480
+
481
+ def compress_frame(
482
+ self,
483
+ frame: Image.Image,
484
+ mask: Optional[np.ndarray],
485
+ sigma: float,
486
+ ) -> CompressionResult:
487
+ """Compress a single frame.
488
+
489
+ Args:
490
+ frame: PIL Image frame
491
+ mask: ROI mask (or None for no ROI)
492
+ sigma: Background preservation factor
493
+
494
+ Returns:
495
+ CompressionResult with compressed image and stats
496
+ """
497
+ import vae
498
+
499
+ if mask is None:
500
+ mask = np.zeros((frame.height, frame.width), dtype=np.float32)
501
+
502
+ result = vae.compress_image(
503
+ image=frame,
504
+ mask=mask,
505
+ model=self.compression_model,
506
+ sigma=sigma,
507
+ device=self.device,
508
+ )
509
+
510
+ roi_coverage = float(mask.mean()) if mask is not None else 0.0
511
+
512
+ return CompressionResult(
513
+ compressed_image=result["compressed"],
514
+ bpp=result["bpp"],
515
+ original_size=(frame.width, frame.height),
516
+ roi_coverage=roi_coverage,
517
+ )
518
+
519
+ def compress_chunk(
520
+ self,
521
+ frames: List[Image.Image],
522
+ chunk_index: int,
523
+ effective_fps: float,
524
+ base_fps: float,
525
+ quality_level: int,
526
+ sigma: float,
527
+ roi_masks: Optional[List[np.ndarray]] = None,
528
+ ) -> ChunkResult:
529
+ """Compress a chunk of frames with given settings.
530
+
531
+ Args:
532
+ frames: List of PIL Image frames
533
+ chunk_index: Index of this chunk
534
+ effective_fps: Target effective framerate
535
+ base_fps: Original video framerate
536
+ quality_level: Compression quality 1-5
537
+ sigma: Background preservation factor
538
+ roi_masks: Optional list of ROI masks
539
+
540
+ Returns:
541
+ ChunkResult with compressed frames and stats
542
+ """
543
+ if not frames:
544
+ return ChunkResult(
545
+ frames=[],
546
+ frame_indices=[],
547
+ effective_fps=0.0,
548
+ quality_level=quality_level,
549
+ sigma=sigma,
550
+ avg_bpp=0.0,
551
+ estimated_bytes=0,
552
+ chunk_index=chunk_index,
553
+ original_frame_count=0,
554
+ )
555
+
556
+ # Compute frame sampling to achieve target FPS
557
+ frame_step = max(1, int(base_fps / effective_fps))
558
+ sampled_indices = list(range(0, len(frames), frame_step))
559
+
560
+ # Ensure at least one frame
561
+ if not sampled_indices:
562
+ sampled_indices = [0]
563
+
564
+ # Compress sampled frames
565
+ compressed_frames: List[Image.Image] = []
566
+ bpps: List[float] = []
567
+
568
+ for idx in sampled_indices:
569
+ frame = frames[idx]
570
+ mask = roi_masks[idx] if roi_masks and idx < len(roi_masks) else None
571
+
572
+ result = self.compress_frame(frame, mask, sigma)
573
+ compressed_frames.append(result.compressed_image)
574
+ bpps.append(result.bpp)
575
+
576
+ # Compute stats
577
+ avg_bpp = float(np.mean(bpps)) if bpps else 0.0
578
+ frame_pixels = frames[0].width * frames[0].height
579
+ total_bits = avg_bpp * frame_pixels * len(compressed_frames)
580
+ estimated_bytes = int(total_bits / 8)
581
+
582
+ actual_fps = len(sampled_indices) / (len(frames) / base_fps) if len(frames) > 0 else 0
583
+
584
+ return ChunkResult(
585
+ frames=compressed_frames,
586
+ frame_indices=sampled_indices,
587
+ effective_fps=float(actual_fps),
588
+ quality_level=quality_level,
589
+ sigma=sigma,
590
+ avg_bpp=avg_bpp,
591
+ estimated_bytes=estimated_bytes,
592
+ chunk_index=chunk_index,
593
+ original_frame_count=len(frames),
594
+ )
595
+
596
+ def segment_frames(
597
+ self,
598
+ frames: List[Image.Image],
599
+ temporal_smoothing: bool = True,
600
+ smoothing_alpha: float = 0.3,
601
+ ) -> List[np.ndarray]:
602
+ """Segment multiple frames to get ROI masks (sequential).
603
+
604
+ Args:
605
+ frames: List of PIL Images
606
+ temporal_smoothing: Whether to apply temporal smoothing
607
+ smoothing_alpha: Alpha for fast EMA smoothing (0-1, lower=smoother)
608
+
609
+ Returns:
610
+ List of ROI masks (temporally smoothed if enabled)
611
+ """
612
+ if self.segmenter is None:
613
+ return [np.zeros((f.height, f.width), dtype=np.float32) for f in frames]
614
+
615
+ masks = []
616
+ for frame in frames:
617
+ mask = self.segmenter(frame, target_classes=self.target_classes)
618
+ masks.append(mask.astype(np.float32))
619
+
620
+ # Apply temporal smoothing to reduce flickering
621
+ if temporal_smoothing and len(masks) > 2:
622
+ masks = smooth_masks_sdf(
623
+ masks,
624
+ alpha=smoothing_alpha,
625
+ empty_thresh=10,
626
+ patience=5,
627
+ )
628
+
629
+ return masks
630
+
+     def segment_frames_batch(
+         self,
+         frames: List[Image.Image],
+         batch_size: int = 4,
+         temporal_smoothing: bool = True,
+         smoothing_window: int = 5,
+         smoothing_alpha: float = 0.3,
+     ) -> List[np.ndarray]:
+         """Segment multiple frames with batch processing and temporal smoothing.
+
+         Processes frames in batches for better GPU utilization.
+         Falls back to sequential processing if batching is not supported.
+
+         Optionally applies temporal smoothing to reduce mask flickering.
+
+         Args:
+             frames: List of PIL Images
+             batch_size: Number of frames to process at once
+             temporal_smoothing: Whether to apply temporal smoothing
+             smoothing_window: Window size for windowed smoothing (odd number;
+                 unused by the SDF smoother below, kept for API compatibility)
+             smoothing_alpha: Alpha for fast EMA smoothing (0-1, lower=smoother)
+
+         Returns:
+             List of ROI masks (temporally smoothed if enabled)
+         """
+         if self.segmenter is None:
+             return [np.zeros((f.height, f.width), dtype=np.float32) for f in frames]
+
+         import torch
+
+         max_retries = 7
+         bs = batch_size
+         masks = None
+         for attempt in range(max_retries + 1):
+             try:
+                 masks = []
+                 # Use the batch API when the segmenter advertises support
+                 if hasattr(self.segmenter, 'segment_batch') and getattr(self.segmenter, 'supports_batch', False):
+                     for i in range(0, len(frames), bs):
+                         batch = frames[i:i + bs]
+                         batch_masks = self.segmenter.segment_batch(batch, target_classes=self.target_classes)
+                         masks.extend([m.astype(np.float32) for m in batch_masks])
+                 else:
+                     for frame in frames:
+                         mask = self.segmenter(frame, target_classes=self.target_classes)
+                         masks.append(mask.astype(np.float32))
+                 break  # success
+             except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
+                 if 'out of memory' in str(e).lower() and attempt < max_retries:
+                     bs = max(1, bs // 2)
+                     # Aggressive memory cleanup before retrying
+                     masks = None
+                     import gc
+                     gc.collect()
+                     torch.cuda.empty_cache()
+                     torch.cuda.synchronize()
+                     print(f"segment_frames_batch: OOM, retrying batch_size={bs} (attempt {attempt + 1}/{max_retries})")
+                     continue
+                 raise
+         if masks is None:
+             raise RuntimeError("Segmentation failed after OOM retries")
+
+         # Apply temporal smoothing to reduce flickering
+         if temporal_smoothing and len(masks) > 2:
+             # Use SDF smoothing for jitter-free, fluid transitions
+             masks = smooth_masks_sdf(
+                 masks,
+                 alpha=smoothing_alpha,
+                 empty_thresh=10,
+                 patience=5,
+             )
+
+         return masks
+
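The OOM handling above (halve the batch, clear the CUDA cache, retry) is a generic pattern; factored into a standalone helper it looks roughly like this (a sketch, the name `run_with_oom_backoff` is hypothetical and not part of this repo):

```python
import gc
from typing import Callable, TypeVar

import torch

T = TypeVar("T")

def run_with_oom_backoff(fn: Callable[[int], T], batch_size: int, max_retries: int = 7) -> T:
    """Call fn(batch_size); on CUDA OOM, halve the batch size and retry."""
    for attempt in range(max_retries + 1):
        try:
            return fn(batch_size)
        except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
            # On recent PyTorch, OutOfMemoryError is a RuntimeError subclass;
            # re-raise anything that is not an OOM, or once retries run out.
            if "out of memory" not in str(e).lower() or attempt == max_retries:
                raise
            batch_size = max(1, batch_size // 2)
            gc.collect()
            torch.cuda.empty_cache()
    raise RuntimeError("unreachable")
```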
+     def compress_frames_batch(
+         self,
+         frames: List[Image.Image],
+         masks: List[np.ndarray],
+         sigma: "float | List[float]",
+         batch_size: int = 4,
+     ) -> List[CompressionResult]:
+         """Compress multiple frames with true batch processing.
+
+         Processes frames in batches for better GPU utilization.
+         Batches padding operations and CPU transfers for efficiency.
+
+         Args:
+             frames: List of PIL Images
+             masks: List of ROI masks
+             sigma: Background preservation factor – a single float (uniform)
+                 or a per-frame list of floats.
+             batch_size: Number of frames to process at once
+
+         Returns:
+             List of CompressionResult
+         """
+         import torch
+
+         if not frames:
+             return []
+
+         # Normalise sigma to a per-frame list
+         if isinstance(sigma, (int, float)):
+             sigmas = [float(sigma)] * len(frames)
+         else:
+             sigmas = [float(s) for s in sigma]
+             if len(sigmas) != len(frames):
+                 raise ValueError(f"sigma list length {len(sigmas)} != frame count {len(frames)}")
+
+         results = []
+         bs = batch_size
+
+         # Process in batches; on CUDA OOM, halve the batch size and retry.
+         # The retry loop is bounded: once bs reaches 1, a further OOM re-raises.
+         pos = 0
+         while pos < len(frames):
+             batch_end = min(pos + bs, len(frames))
+             batch_frames = frames[pos:batch_end]
+             batch_masks = masks[pos:batch_end]
+
+             try:
+                 batch_results = self._compress_batch_inner(
+                     batch_frames, batch_masks, sigmas, pos, pad_cache=None,
+                 )
+                 results.extend(batch_results)
+                 pos = batch_end
+             except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
+                 if 'out of memory' in str(e).lower() and bs > 1:
+                     bs = max(1, bs // 2)
+                     torch.cuda.empty_cache()
+                     print(f"compress_frames_batch: OOM, retrying with batch_size={bs}")
+                     continue
+                 raise
+
+         # Clear GPU memory after all batches
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         return results
+
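Putting the two batch stages together at a call site (the `pipeline`, `frames`, and batch-size values here are illustrative):

```python
# Hypothetical call site: segment once, then compress with a uniform sigma.
masks = pipeline.segment_frames_batch(frames, batch_size=8)
results = pipeline.compress_frames_batch(frames, masks, sigma=0.05, batch_size=4)

avg_bpp = sum(r.bpp for r in results) / len(results)
print(f"avg bpp: {avg_bpp:.3f}")
```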
+     def _compress_batch_inner(
+         self,
+         batch_frames: List[Image.Image],
+         batch_masks: List[np.ndarray],
+         sigmas: List[float],
+         global_offset: int,
+         pad_cache: Optional[Any] = None,
+     ) -> List[CompressionResult]:
+         """Compress a single batch of frames (inner helper).
+
+         Args:
+             batch_frames: Frames in this batch.
+             batch_masks: Masks in this batch.
+             sigmas: Per-frame sigma values (full list, indexed by global_offset + i).
+             global_offset: Start index into the full sigmas list.
+             pad_cache: Unused, reserved for future padding reuse.
+
+         Returns:
+             List of CompressionResult for this batch.
+         """
+         import torch
+         import torch.nn.functional as F
+         from vae.utils import compute_padding
+
+         batch_results = []
+         batch_tensors = []
+         batch_mask_tensors = []
+         original_sizes = []
+
+         for frame, mask in zip(batch_frames, batch_masks):
+             img_array = np.array(frame).astype(np.float32) / 255.0
+             img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)
+
+             if mask is None:
+                 mask = np.zeros((frame.height, frame.width), dtype=np.float32)
+             mask_tensor = torch.from_numpy(mask).unsqueeze(0)
+
+             batch_tensors.append(img_tensor)
+             batch_mask_tensors.append(mask_tensor)
+             original_sizes.append((frame.width, frame.height))
+
+         # Padding is computed from the first frame: all frames in a batch are
+         # assumed to share the same resolution.
+         _, h, w = batch_tensors[0].shape
+         pad, unpad = compute_padding(h, w, min_div=256)
+
+         padded_batch = torch.stack([
+             F.pad(img_t, pad, mode='constant', value=0) for img_t in batch_tensors
+         ]).to(self.device)
+
+         padded_masks = torch.stack([
+             F.pad(mask_t, pad, mode='constant', value=0) for mask_t in batch_mask_tensors
+         ]).to(self.device)
+
+         with torch.no_grad():
+             # Extract per-frame sigma values for this batch
+             batch_sigmas = [sigmas[global_offset + i] for i in range(len(batch_frames))]
+
+             # Call the model once with the entire batch (true batching)
+             if all(s == batch_sigmas[0] for s in batch_sigmas):
+                 # All frames have the same sigma - use a scalar
+                 out = self.compression_model(
+                     padded_batch,
+                     padded_masks,
+                     sigma=batch_sigmas[0],
+                 )
+             else:
+                 # Different sigma values - pass as a tensor
+                 sigma_tensor = torch.tensor(batch_sigmas, device=self.device, dtype=torch.float32)
+                 out = self.compression_model(
+                     padded_batch,
+                     padded_masks,
+                     sigma=sigma_tensor,
+                 )
+
+             # Unpad all frames at once (unpad holds negative pads, i.e. a crop)
+             x_hat_padded = out['x_hat']
+             x_hat_batch = F.pad(x_hat_padded, unpad)
+
+             # Move to CPU and convert to numpy once
+             x_hat_cpu = x_hat_batch.cpu().numpy()
+
+             # Extract per-frame results
+             for i in range(len(batch_frames)):
+                 x_hat_np = x_hat_cpu[i].transpose(1, 2, 0)
+                 x_hat_np = np.clip(x_hat_np * 255, 0, 255).astype(np.uint8)
+                 compressed_img = Image.fromarray(x_hat_np)
+
+                 # Calculate BPP for this frame from the entropy-model likelihoods
+                 num_pixels = h * w
+                 likelihoods = out['likelihoods']
+                 bpp_y = torch.log(likelihoods['y'][i:i + 1]).sum() / (-np.log(2) * num_pixels)
+                 bpp_z = torch.log(likelihoods['z'][i:i + 1]).sum() / (-np.log(2) * num_pixels)
+                 bpp = (bpp_y + bpp_z).item()
+
+                 mask_for_coverage = batch_masks[i] if i < len(batch_masks) else None
+                 roi_coverage = float(mask_for_coverage.mean()) if mask_for_coverage is not None else 0.0
+
+                 batch_results.append(CompressionResult(
+                     compressed_image=compressed_img,
+                     bpp=bpp,
+                     original_size=original_sizes[i],
+                     roi_coverage=roi_coverage,
+                 ))
+
+         del padded_batch, padded_masks, x_hat_batch, x_hat_cpu
+         return batch_results
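For reference, the per-frame bpp computed above is the standard entropy-model rate estimate over both latents, evaluated with natural logs and rescaled by $-\ln 2$:

$$
\mathrm{bpp} \;=\; \frac{-\sum_i \log_2 p_y(y_i) \;-\; \sum_j \log_2 p_z(z_j)}{H\,W}
\;=\; \frac{\sum_i \ln p_y(y_i) + \sum_j \ln p_z(z_j)}{-\ln 2 \cdot H\,W},
$$

with $H\,W$ the unpadded pixel count of the frame.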
video/gpu_memory.py ADDED
@@ -0,0 +1,194 @@
+ """GPU memory estimation and automatic batch-size selection.
+
+ Queries free VRAM and uses per-model heuristics to pick the largest safe
+ batch size for both segmentation and compression stages. Falls back to
+ conservative defaults (compression batch = 1) on CPU or when CUDA info
+ is unavailable.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+
+ # ---------------------------------------------------------------------------
+ # Per-frame memory cost heuristics (bytes, float32, conservative)
+ # These are empirical estimates measured on a mix of 480p frames.
+ # ---------------------------------------------------------------------------
+
+ # Segmentation models: model key -> per-frame activation bytes at 480p.
+ # Parameter/weight overhead is listed separately in _SEG_MODEL_OVERHEAD.
+ _SEG_MEMORY_PER_FRAME: dict[str, float] = {
+     # YOLO-X seg – large model (71M params) with conv layers
+     "yolo": 180 * 1024**2,  # ~180 MB per frame at 480p (X model)
+     # SegFormer – transformer encoder, moderate
+     "segformer": 200 * 1024**2,  # ~200 MB per frame
+     # Mask2Former – Swin-Large + mask decoder, heavier
+     "mask2former": 350 * 1024**2,  # ~350 MB per frame
+     # Mask R-CNN – ResNet50-FPN, moderate
+     "maskrcnn": 250 * 1024**2,  # ~250 MB per frame
+     # SAM3 (OWL-ViT + SAM) – not truly batchable, treat as sequential
+     "sam3": 500 * 1024**2,  # ~500 MB (single image pipeline)
+     # Fake segmentation (detection + tracking → bbox masks) – much lighter
+     "fake_yolo": 120 * 1024**2,  # ~120 MB per frame (YOLO detection)
+     "fake_yolo_botsort": 120 * 1024**2,  # Same as fake_yolo
+     "fake_detr": 150 * 1024**2,  # ~150 MB per frame (DETR transformer)
+     "fake_deformable_detr": 170 * 1024**2,  # ~170 MB per frame
+     "fake_fasterrcnn": 140 * 1024**2,  # ~140 MB per frame (ResNet50 backbone)
+     "fake_retinanet": 140 * 1024**2,  # ~140 MB per frame
+     "fake_fcos": 130 * 1024**2,  # ~130 MB per frame
+     "fake_grounding_dino": 800 * 1024**2,  # ~800 MB per frame (very large: BERT text + Swin-T vision + cross-attention)
+ }
+
+ _SEG_MODEL_OVERHEAD: dict[str, float] = {
+     "yolo": 450 * 1024**2,  # YOLO-X model overhead (~71M params)
+     "segformer": 400 * 1024**2,
+     "mask2former": 800 * 1024**2,
+     "maskrcnn": 300 * 1024**2,
+     "sam3": 600 * 1024**2,
+     # Fake segmentation overhead (detector weights + tracker state)
+     "fake_yolo": 350 * 1024**2,  # YOLO-X detector weights
+     "fake_yolo_botsort": 350 * 1024**2,
+     "fake_detr": 200 * 1024**2,  # DETR weights
+     "fake_deformable_detr": 250 * 1024**2,
+     "fake_fasterrcnn": 160 * 1024**2,  # Faster R-CNN weights
+     "fake_retinanet": 160 * 1024**2,
+     "fake_fcos": 150 * 1024**2,
+     "fake_grounding_dino": 1200 * 1024**2,  # Grounding DINO weights (very large: ~700M params)
+ }
+
+ # TIC transformer compression: very heavy activations (attention is O(N²)).
+ # Measured empirically on 854×480 frames:
+ #   - Model params + entropy structures ≈ 1–3 GB
+ #   - Per-frame activations ≈ 2–3 GB (float32)
+ _TIC_MODEL_OVERHEAD: float = 2.0 * 1024**3  # ~2 GB for params + entropy buffers
+ _TIC_PER_FRAME_480P: float = 2.5 * 1024**3  # ~2.5 GB per frame
+
+ # Reference resolution for the estimates above
+ _REF_HEIGHT = 480
+ _REF_WIDTH = 854
+ _REF_PIXELS = _REF_HEIGHT * _REF_WIDTH
+
+ # Safety margin – leave at least this fraction of free memory unused
+ _SAFETY_MARGIN = 0.15  # 15%
+
+
+ @dataclass
+ class BatchSizeEstimate:
+     """Result of automatic batch-size estimation."""
+
+     seg_batch_size: int
+     compress_batch_size: int
+     free_vram_bytes: int
+     device: str
+     notes: str = ""
+
+
+ def _get_free_vram(device: str = "cuda") -> Optional[int]:
+     """Return free VRAM in bytes, or None if unavailable."""
+     try:
+         import torch
+
+         if not torch.cuda.is_available() or device == "cpu":
+             return None
+         dev_idx = 0
+         if ":" in device:
+             dev_idx = int(device.split(":")[1])
+         free, _total = torch.cuda.mem_get_info(dev_idx)
+         return int(free)
+     except Exception:
+         return None
+
+
+ def _scale_memory(base_bytes: float, frame_h: int, frame_w: int) -> float:
+     """Scale a 480p memory estimate to an arbitrary resolution."""
+     pixels = frame_h * frame_w
+     # Feature-map memory is roughly O(pixels); attention matrices add an
+     # O(tokens²) term. Use a slightly super-linear exponent to cover both.
+     ratio = pixels / _REF_PIXELS
+     return base_bytes * (ratio ** 1.2)
+
+
+ def estimate_batch_sizes(
+     frame_height: int = 480,
+     frame_width: int = 854,
+     seg_method: str = "yolo",
+     device: str = "cuda",
+     total_frames: int = 300,
+ ) -> BatchSizeEstimate:
+     """Estimate optimal batch sizes for segmentation and compression.
+
+     The function queries free VRAM and computes how many frames can be
+     processed in a single batch for each stage, independently.
+
+     Args:
+         frame_height: Height of (pre-processed) frames.
+         frame_width: Width of (pre-processed) frames.
+         seg_method: Segmentation method key (yolo, segformer, mask2former, etc.)
+         device: Torch device string.
+         total_frames: Total number of frames in the video.
+
+     Returns:
+         BatchSizeEstimate with recommended batch sizes.
+     """
+     free = _get_free_vram(device)
+     notes_parts: list[str] = []
+
+     if free is None:
+         # CPU fallback
+         return BatchSizeEstimate(
+             seg_batch_size=min(16, total_frames),
+             compress_batch_size=1,
+             free_vram_bytes=0,
+             device=device,
+             notes="CPU mode – using sequential compression, modest seg batches.",
+         )
+
+     usable = int(free * (1.0 - _SAFETY_MARGIN))
+
+     # --- Segmentation batch size ---
+     seg_key = seg_method.lower()
+     per_frame_seg = _scale_memory(
+         _SEG_MEMORY_PER_FRAME.get(seg_key, 200 * 1024**2),
+         frame_height,
+         frame_width,
+     )
+     model_overhead_seg = _SEG_MODEL_OVERHEAD.get(seg_key, 400 * 1024**2)
+     available_for_seg_frames = usable - model_overhead_seg
+     if available_for_seg_frames < per_frame_seg:
+         seg_batch = 1
+     else:
+         seg_batch = int(available_for_seg_frames / per_frame_seg)
+
+     # SAM3 is not truly batchable – cap to 1
+     if seg_key == "sam3":
+         seg_batch = 1
+         notes_parts.append("SAM3 is sequential (OWL-ViT+SAM pipeline).")
+
+     seg_batch = max(1, min(seg_batch, total_frames))
+
+     # --- Compression batch size ---
+     per_frame_compress = _scale_memory(_TIC_PER_FRAME_480P, frame_height, frame_width)
+     available_for_compress = usable - _TIC_MODEL_OVERHEAD
+     if available_for_compress < per_frame_compress:
+         compress_batch = 1
+     else:
+         compress_batch = int(available_for_compress / per_frame_compress)
+
+     compress_batch = max(1, min(compress_batch, total_frames))
+
+     notes_parts.append(
+         f"Free VRAM: {free / 1024**3:.1f} GB, "
+         f"usable: {usable / 1024**3:.1f} GB. "
+         f"Seg per-frame est: {per_frame_seg / 1024**2:.0f} MB, "
+         f"compress per-frame est: {per_frame_compress / 1024**2:.0f} MB."
+     )
+
+     return BatchSizeEstimate(
+         seg_batch_size=seg_batch,
+         compress_batch_size=compress_batch,
+         free_vram_bytes=free,
+         device=device,
+         notes=" ".join(notes_parts),
+     )
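A typical call, e.g. before processing a 720p video with the SegFormer segmenter (values illustrative; the import path assumes the `video/` layout above). Note the super-linear resolution scaling in `_scale_memory`: going from 480p (854×480) to 1080p multiplies the per-frame estimates by (2073600 / 409920)^1.2 ≈ 7.0×.

```python
from video.gpu_memory import estimate_batch_sizes

est = estimate_batch_sizes(
    frame_height=720,
    frame_width=1280,
    seg_method="segformer",
    device="cuda",
    total_frames=600,
)
print(est.seg_batch_size, est.compress_batch_size)
print(est.notes)
```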
video/mask_cache.py ADDED
@@ -0,0 +1,69 @@
+ """Utilities for saving and loading video segmentation masks."""
+
+ import os
+ import tempfile
+ from typing import List, Optional
+
+ import numpy as np
+
+
+ def save_video_masks(masks: List[np.ndarray], output_path: Optional[str] = None) -> str:
+     """Save video segmentation masks to a file.
+
+     Args:
+         masks: List of mask arrays (H, W) for each frame
+         output_path: Optional output path. If None, creates a temp file.
+
+     Returns:
+         Path to the saved mask file
+     """
+     if output_path is None:
+         # Create a temporary file; close the fd since savez_compressed reopens by path
+         fd, output_path = tempfile.mkstemp(suffix='.masks.npz', prefix='video_masks_')
+         os.close(fd)
+
+     # Stack masks and save
+     mask_array = np.stack(masks, axis=0)  # (T, H, W)
+     np.savez_compressed(output_path, masks=mask_array)
+
+     return output_path
+
+
+ def load_video_masks(mask_path: str) -> List[np.ndarray]:
+     """Load video segmentation masks from a file.
+
+     Args:
+         mask_path: Path to the saved mask file
+
+     Returns:
+         List of mask arrays (H, W) for each frame
+     """
+     data = np.load(mask_path)
+     mask_array = data['masks']  # (T, H, W)
+
+     # Convert back to a list of per-frame arrays
+     masks = [mask_array[i] for i in range(mask_array.shape[0])]
+
+     return masks
+
+
+ def get_mask_info(mask_path: str) -> dict:
+     """Get shape/dtype information about saved masks.
+
+     Note: this still reads the mask array from disk; the compressed npz
+     format does not expose array metadata without decompression.
+
+     Args:
+         mask_path: Path to the saved mask file
+
+     Returns:
+         Dictionary with mask metadata
+     """
+     data = np.load(mask_path)
+     mask_array = data['masks']
+
+     return {
+         'num_frames': mask_array.shape[0],
+         'height': mask_array.shape[1],
+         'width': mask_array.shape[2],
+         'dtype': str(mask_array.dtype),
+         'size_mb': mask_array.nbytes / (1024 * 1024),
+     }
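Round-trip usage of the cache helpers (shapes illustrative):

```python
import numpy as np
from video.mask_cache import save_video_masks, load_video_masks, get_mask_info

masks = [np.random.rand(480, 854).astype(np.float32) for _ in range(10)]
path = save_video_masks(masks)   # temp *.masks.npz when no output path is given
print(get_mask_info(path))       # {'num_frames': 10, 'height': 480, 'width': 854, ...}

restored = load_video_masks(path)
assert np.allclose(masks[0], restored[0])
```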