Commit ·
4fec4e4
1
Parent(s): fa7e75c
Initial Commit
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +8 -35
- .github/copilot-instructions.md +417 -0
- .gitignore +69 -0
- API.md +1029 -0
- README.md +507 -7
- _segmentation_comparison.ipynb +0 -0
- app.py +0 -0
- checkpoints/tic_lambda_0.0035.pth.tar +3 -0
- checkpoints/tic_lambda_0.013.pth.tar +3 -0
- checkpoints/tic_lambda_0.025.pth.tar +3 -0
- checkpoints/tic_lambda_0.0483.pth.tar +3 -0
- checkpoints/tic_lambda_0.0932.pth.tar +3 -0
- detection/__init__.py +30 -0
- detection/base.py +83 -0
- detection/bytetrack.py +358 -0
- detection/detr.py +215 -0
- detection/factory.py +50 -0
- detection/grounding_dino.py +215 -0
- detection/torchvision_detectors.py +300 -0
- detection/tracker.py +387 -0
- detection/utils.py +124 -0
- detection/yolo.py +98 -0
- detection/yolo_world.py +188 -0
- examples.sh +272 -0
- model_cache.py +62 -0
- requirements.txt +26 -0
- roi_compressor.py +183 -0
- roi_detection_eval.py +639 -0
- roi_segmenter.py +355 -0
- segmentation/__init__.py +33 -0
- segmentation/base.py +169 -0
- segmentation/factory.py +124 -0
- segmentation/fake.py +412 -0
- segmentation/mask2former.py +310 -0
- segmentation/maskrcnn.py +310 -0
- segmentation/sam3.py +187 -0
- segmentation/segformer.py +193 -0
- segmentation/utils.py +90 -0
- segmentation/yolo.py +188 -0
- vae/RSTB.py +813 -0
- vae/__init__.py +57 -0
- vae/roi_tic.py +132 -0
- vae/tic_model.py +989 -0
- vae/transformer_layers.py +250 -0
- vae/utils.py +102 -0
- vae/visualization.py +94 -0
- video/__init__.py +55 -0
- video/chunk_compressor.py +877 -0
- video/gpu_memory.py +194 -0
- video/mask_cache.py +69 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,8 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
*.
|
| 7 |
-
*.
|
| 8 |
-
*.
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
checkpoints/tic_lambda_0.0932.pth.tar filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
checkpoints/tic_lambda_0.0035.pth.tar filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
checkpoints/tic_lambda_0.013.pth.tar filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
checkpoints/tic_lambda_0.025.pth.tar filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
checkpoints/tic_lambda_0.0483.pth.tar filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
images/*/*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
images/*/*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
images/*/*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.github/copilot-instructions.md
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ROI-VAE Image Compression - Copilot Instructions
|
| 2 |
+
|
| 3 |
+
## Project Overview
|
| 4 |
+
ROI-based VAE image compression using TIC (Transformer-based Image Compression). The system preserves quality in Regions of Interest (ROI) while aggressively compressing backgrounds using configurable quality factors.
|
| 5 |
+
|
| 6 |
+
## Architecture
|
| 7 |
+
|
| 8 |
+
### Core Pipeline
|
| 9 |
+
1. **Segmentation** (`segmentation/` module) → 2. **Compression** (`vae/` module) → 3. **Output**
|
| 10 |
+
- Segmentation creates binary masks (1=ROI, 0=background)
|
| 11 |
+
- Compression applies variable quality based on mask using `sigma` parameter
|
| 12 |
+
|
| 13 |
+
### Key Components
|
| 14 |
+
|
| 15 |
+
**Segmentation Module** (`segmentation/`):
|
| 16 |
+
- Abstract base class `BaseSegmenter` defines common interface
|
| 17 |
+
- Implementations:
|
| 18 |
+
- `SegFormerSegmenter` - Cityscapes semantic segmentation (19 classes: road, car, building, person, etc.)
|
| 19 |
+
- `YOLOSegmenter` - COCO instance segmentation (80 classes)
|
| 20 |
+
- `Mask2FormerSegmenter` - Swin Transformer-based panoptic/semantic segmentation (COCO: 133 classes, ADE20K: 150 classes)
|
| 21 |
+
- `MaskRCNNSegmenter` - ResNet50-FPN instance segmentation (COCO: 80 classes)
|
| 22 |
+
- `SAM3Segmenter` - Prompt-based segmentation (natural language prompt → mask via text-conditioned detector + SAM)
|
| 23 |
+
- `FakeSegmenter` - Detection + tracking → bbox masks (fast, non-pixel-perfect)
|
| 24 |
+
- **Fake Segmentation** (NEW): Detection-based segmentation for speed
|
| 25 |
+
- Creates rectangular masks from detection bounding boxes
|
| 26 |
+
- Uses object tracking for temporal consistency (ByteTrack, BoTSORT, SimpleTracker)
|
| 27 |
+
- Available methods: `fake_yolo` (default, ByteTrack), `fake_yolo_botsort`, `fake_detr`, `fake_fasterrcnn`, `fake_retinanet`, `fake_fcos`, `fake_deformable_detr`, `fake_grounding_dino`
|
| 28 |
+
- Much faster than pixel-perfect segmentation (~60-100 fps vs 10-30 fps)
|
| 29 |
+
- Memory estimates in `gpu_memory.py`: 120-200 MB per frame (vs 180-500 MB for full segmentation)
|
| 30 |
+
- Factory pattern: `create_segmenter('yolo', device='cuda')` or `create_segmenter('fake_yolo', device='cuda')`
|
| 31 |
+
- Extensible for future models
|
| 32 |
+
- Utils: `visualize_mask()`, `save_mask()`, `calculate_roi_stats()`
|
| 33 |
+
|
| 34 |
+
**Compression Module** (`vae/`):
|
| 35 |
+
- `tic_model.py`: Base `TIC` class - Transformer-based VAE with encoder, decoder, hyperprior
|
| 36 |
+
- `RSTB.py`: Residual Swin Transformer Blocks and attention modules
|
| 37 |
+
- `transformer_layers.py`: Generic transformer components (MLP, attention, drop path)
|
| 38 |
+
- `roi_tic.py`: `ModifiedTIC` class extending base TIC with ROI-aware forward pass
|
| 39 |
+
- `utils.py`: `compress_image()`, `compute_padding()` for image processing
|
| 40 |
+
- `visualization.py`: `highlight_roi()`, `create_comparison_grid()` for results
|
| 41 |
+
- Handles checkpoint loading with compressai version compatibility fixes
|
| 42 |
+
|
| 43 |
+
**Detection Module** (`detection/`):
|
| 44 |
+
- Abstract base class `BaseDetector` defines common interface
|
| 45 |
+
- Factory pattern: `create_detector('yolo', device='cuda')`
|
| 46 |
+
- Implementations:
|
| 47 |
+
- `YOLODetector` - Ultralytics YOLO (closed-vocabulary COCO weights)
|
| 48 |
+
- Torchvision: Faster R-CNN, RetinaNet, SSD, FCOS
|
| 49 |
+
- Transformers: DETR, Deformable DETR
|
| 50 |
+
- `EfficientDetDetector` - optional via `effdet`
|
| 51 |
+
- `YOLOWorldDetector` - open-vocabulary detection (Ultralytics YOLO-World; requires prompts)
|
| 52 |
+
- `GroundingDINODetector` - open-vocabulary detection (Transformers; requires prompts)
|
| 53 |
+
- CLI: `roi_detection_eval.py` evaluates detection retention before vs after ROI compression
|
| 54 |
+
|
| 55 |
+
**TIC Model** (`vae/tic_model.py`):
|
| 56 |
+
- Transformer-based VAE with encoder (`g_a`), decoder (`g_s`), and hyperprior (`h_a`, `h_s`)
|
| 57 |
+
- Uses RSTB (Residual Swin Transformer Blocks) for feature extraction
|
| 58 |
+
- Channels: N=192, M=192 (expansion layer)
|
| 59 |
+
- Critical: Images must be padded to multiples of 256 (use `compute_padding()`)
|
| 60 |
+
|
| 61 |
+
**ModifiedTIC** (`vae/roi_tic.py`):
|
| 62 |
+
- Extends base TIC with ROI-aware forward pass
|
| 63 |
+
- Takes mask + sigma parameter to create quality factors
|
| 64 |
+
- Applies `similarity_loss` tensor: 1.0 for ROI pixels, sigma for background
|
| 65 |
+
- Integrates mask through `simi_net` and `sub_impor_net` branches
|
| 66 |
+
|
| 67 |
+
## Critical Conventions
|
| 68 |
+
|
| 69 |
+
### Model Cache Locations
|
| 70 |
+
- By default, auto-downloaded model artifacts are kept inside `checkpoints/`:
|
| 71 |
+
- Hugging Face cache: `checkpoints/hf/`
|
| 72 |
+
- Torch/torchvision cache: `checkpoints/torch/`
|
| 73 |
+
|
| 74 |
+
### Checkpoint Loading Pattern
|
| 75 |
+
```python
|
| 76 |
+
from vae import load_checkpoint
|
| 77 |
+
|
| 78 |
+
# Automatically handles compressai version mismatch
|
| 79 |
+
model = load_checkpoint('checkpoints/tic_lambda_0.0483.pth.tar', N=192, M=192, device='cuda')
|
| 80 |
+
# Note: model.update(force=True) is called automatically
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
Manual loading:
|
| 84 |
+
```python
|
| 85 |
+
# Fix compressai version mismatch - required for all checkpoint loading
|
| 86 |
+
state_dict = checkpoint["state_dict"]
|
| 87 |
+
new_state_dict = {}
|
| 88 |
+
for k, v in state_dict.items():
|
| 89 |
+
if "entropy_bottleneck._matrix" in k:
|
| 90 |
+
new_key = k.replace("entropy_bottleneck._matrix", "entropy_bottleneck.matrices.")
|
| 91 |
+
# ... similar replacements for _bias, _factor
|
| 92 |
+
```
|
| 93 |
+
Always call `model.update(force=True)` after loading checkpoints.
|
| 94 |
+
|
| 95 |
+
### Image Preprocessing
|
| 96 |
+
1. Convert PIL to torch tensor: `x = torch.from_numpy(np.array(img)).float() / 255.0`
|
| 97 |
+
2. Permute to [B, C, H, W]: `x = x.permute(2, 0, 1).unsqueeze(0)`
|
| 98 |
+
3. Pad to 256 multiples using `compute_padding(h, w, min_div=256)`
|
| 99 |
+
4. Apply mask at same resolution as input image
|
| 100 |
+
|
| 101 |
+
### Sigma Parameter
|
| 102 |
+
- Range: 0.01 - 1.0 (lower = more background compression)
|
| 103 |
+
- Default: 0.3
|
| 104 |
+
- ROI pixels always get quality factor 1.0
|
| 105 |
+
- Applied via `torch.where(mask > 0.5, 1.0, sigma)`
|
| 106 |
+
|
| 107 |
+
### Available Checkpoints
|
| 108 |
+
Located in `checkpoints/` directory with different lambda (rate-distortion) values:
|
| 109 |
+
- `tic_lambda_0.0035.pth.tar` - Lowest bitrate (highest compression)
|
| 110 |
+
- `tic_lambda_0.013.pth.tar` - Low bitrate (N=128, M=192)
|
| 111 |
+
- `tic_lambda_0.025.pth.tar` - Medium-low bitrate
|
| 112 |
+
- `tic_lambda_0.0483.pth.tar` - **Default** - Medium bitrate
|
| 113 |
+
- `tic_lambda_0.0932.pth.tar` - High bitrate (better quality)
|
| 114 |
+
- `yolo26x-seg.pt` - YOLO segmentation model
|
| 115 |
+
|
| 116 |
+
## Development Workflows
|
| 117 |
+
|
| 118 |
+
### Using Segmentation Module (New)
|
| 119 |
+
```python
|
| 120 |
+
from segmentation import create_segmenter
|
| 121 |
+
|
| 122 |
+
# Available methods: segformer, yolo, mask2former, maskrcnn, sam3
|
| 123 |
+
# Fake methods: fake_yolo, fake_yolo_botsort, fake_detr, fake_fasterrcnn, etc.
|
| 124 |
+
segmenter = create_segmenter('mask2former', device='cuda', model_type='coco')
|
| 125 |
+
|
| 126 |
+
# Segment image
|
| 127 |
+
mask = segmenter(image, target_classes=['car', 'person'])
|
| 128 |
+
|
| 129 |
+
# Fast segmentation with detection + tracking (non-pixel-perfect)
|
| 130 |
+
fake_seg = create_segmenter('fake_yolo', device='cuda')
|
| 131 |
+
mask = fake_seg(image, target_classes=['person']) # Uses ByteTrack tracking
|
| 132 |
+
# Much faster: ~60-100 fps vs 10-30 fps for pixel-perfect segmentation
|
| 133 |
+
|
| 134 |
+
# Add new segmentation method
|
| 135 |
+
from segmentation import register_segmenter, BaseSegmenter
|
| 136 |
+
|
| 137 |
+
class MySegmenter(BaseSegmenter):
|
| 138 |
+
def load_model(self): ...
|
| 139 |
+
def segment(self, image, target_classes, **kwargs): ...
|
| 140 |
+
def get_available_classes(self): ...
|
| 141 |
+
|
| 142 |
+
register_segmenter('my_method', MySegmenter)
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
### Using Compression Module (New)
|
| 146 |
+
```python
|
| 147 |
+
from vae import load_checkpoint, compress_image
|
| 148 |
+
from PIL import Image
|
| 149 |
+
import numpy as np
|
| 150 |
+
|
| 151 |
+
# Load model
|
| 152 |
+
model = load_checkpoint('checkpoints/tic_lambda_0.0483.pth.tar', device='cuda')
|
| 153 |
+
|
| 154 |
+
# Compress image with mask
|
| 155 |
+
image = Image.open('input.jpg')
|
| 156 |
+
mask = np.zeros((image.height, image.width)) # Your mask here
|
| 157 |
+
|
| 158 |
+
result = compress_image(image, mask, model, sigma=0.3, device='cuda')
|
| 159 |
+
compressed = result['compressed'] # PIL Image
|
| 160 |
+
bpp = result['bpp'] # Bits per pixel
|
| 161 |
+
|
| 162 |
+
# Visualize results
|
| 163 |
+
from vae import create_comparison_grid
|
| 164 |
+
grid = create_comparison_grid(image, compressed, mask, bpp, sigma=0.3, lambda_val=0.0483)
|
| 165 |
+
grid.save('comparison.jpg')
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
### Using Detection Module (New)
|
| 169 |
+
```python
|
| 170 |
+
from detection import create_detector
|
| 171 |
+
|
| 172 |
+
# Closed-vocabulary
|
| 173 |
+
det = create_detector('yolo', device='cuda', model_path='checkpoints/yolo26x.pt')
|
| 174 |
+
dets = det(image, conf_threshold=0.25)
|
| 175 |
+
|
| 176 |
+
# Open-vocabulary (must pass prompts/classes)
|
| 177 |
+
det_ov = create_detector('yolo_world', device='cuda')
|
| 178 |
+
dets_ov = det_ov(image, conf_threshold=0.25, classes='person,car')
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
### Detection Eval (CLI)
|
| 182 |
+
```bash
|
| 183 |
+
# Compare before vs after (already-compressed)
|
| 184 |
+
python roi_detection_eval.py \
|
| 185 |
+
--before images/car/0016cf15fa4d4e16.jpg \
|
| 186 |
+
--after results/compressed.jpg \
|
| 187 |
+
--detectors yolo detr \
|
| 188 |
+
--viz-dir results/det_viz
|
| 189 |
+
|
| 190 |
+
# Open-vocabulary eval (YOLO-World requires prompts)
|
| 191 |
+
python roi_detection_eval.py \
|
| 192 |
+
--before images/person/kodim04.png \
|
| 193 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 194 |
+
--sigma 0.3 \
|
| 195 |
+
--seg-method yolo --seg-classes person \
|
| 196 |
+
--detectors yolo_world \
|
| 197 |
+
--open-vocab-classes "person,car" \
|
| 198 |
+
--viz-dir results/det_viz
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
### Running Compression (CLI)
|
| 202 |
+
```bash
|
| 203 |
+
# Basic compression with segmentation
|
| 204 |
+
python roi_compressor.py \
|
| 205 |
+
--input images/car/0016cf15fa4d4e16.jpg \
|
| 206 |
+
--output results/compressed.jpg \
|
| 207 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 208 |
+
--sigma 0.3 \
|
| 209 |
+
--seg-classes car \
|
| 210 |
+
--seg-method yolo
|
| 211 |
+
|
| 212 |
+
# Fast compression with detection-based fake segmentation (~3x faster)
|
| 213 |
+
python roi_compressor.py \
|
| 214 |
+
--input images/car/0016cf15fa4d4e16.jpg \
|
| 215 |
+
--output results/compressed.jpg \
|
| 216 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 217 |
+
--sigma 0.3 \
|
| 218 |
+
--seg-classes car \
|
| 219 |
+
--seg-method fake_yolo
|
| 220 |
+
|
| 221 |
+
# With comparison grid (original, compressed, ROI highlighted)
|
| 222 |
+
python roi_compressor.py ... --highlight
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
### Standalone Segmentation (CLI)
|
| 226 |
+
```bash
|
| 227 |
+
# Using Mask2Former with COCO panoptic
|
| 228 |
+
python roi_segmenter.py \
|
| 229 |
+
--input images/car/0016cf15fa4d4e16.jpg \
|
| 230 |
+
--output results/mask.png \
|
| 231 |
+
--method mask2former \
|
| 232 |
+
--classes car building person \
|
| 233 |
+
--visualize
|
| 234 |
+
|
| 235 |
+
# Fast segmentation with detection + ByteTrack tracking
|
| 236 |
+
python roi_segmenter.py \
|
| 237 |
+
--input data/videos/Person_doing_handstand.mp4 \
|
| 238 |
+
--output results/masks.mp4 \
|
| 239 |
+
--method fake_yolo \
|
| 240 |
+
--classes person \
|
| 241 |
+
--resize-height 480 \
|
| 242 |
+
--smooth-patience 10 \
|
| 243 |
+
--visualize
|
| 244 |
+
|
| 245 |
+
# Other fake methods (detection + tracking)
|
| 246 |
+
# fake_yolo_botsort (YOLO + BoTSORT)
|
| 247 |
+
# fake_detr (DETR + SimpleTracker)
|
| 248 |
+
# fake_fasterrcnn, fake_retinanet, fake_fcos, etc.
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
### Adding New Segmentation Models
|
| 252 |
+
1. Create new file in `segmentation/` (e.g., `sam.py`)
|
| 253 |
+
2. Extend `BaseSegmenter` and implement abstract methods:
|
| 254 |
+
- `load_model()`: Load model weights
|
| 255 |
+
- `segment()`: Generate mask from image
|
| 256 |
+
- `get_available_classes()`: Return supported classes/capabilities
|
| 257 |
+
3. Register in `segmentation/__init__.py` or use `register_segmenter()`
|
| 258 |
+
4. Use via `create_segmenter('your_method', ...)`
|
| 259 |
+
|
| 260 |
+
### Testing Examples
|
| 261 |
+
- `roi_segmenter.py`: CLI tool for standalone segmentation
|
| 262 |
+
- `roi_compressor.py`: CLI tool for ROI-based image compression
|
| 265 |
+
|
| 266 |
+
- `segmentation/`: Modular segmentation with abstract base class
|
| 267 |
+
- `base.py`: `BaseSegmenter` abstract class
|
| 268 |
+
- `segformer.py`: Cityscapes semantic segmentation
|
| 269 |
+
- `yolo.py`: COCO instance segmentation
|
| 270 |
+
- `factory.py`: Factory pattern for creating segmenters
|
| 271 |
+
- `utils.py`: Visualization and I/O utilities
|
| 272 |
+
- `vae/`: Modular compression with ROI support
|
| 273 |
+
- `tic_model.py`: Base `TIC` class (Transformer-based VAE)
|
| 274 |
+
- `RSTB.py`: Residual Swin Transformer Blocks
|
| 275 |
+
- `transformer_layers.py`: Generic transformer components
|
| 276 |
+
- `roi_tic.py`: `ModifiedTIC` class and checkpoint loading
|
| 277 |
+
- `utils.py`: `compress_image()`, `compute_padding()`
|
| 278 |
+
- `visualization.py`: `highlight_roi()`, `create_comparison_grid()`
|
| 279 |
+
- `roi_segmenter.py`: CLI tool for standalone segmentation
|
| 280 |
+
- `roi_compressor.py`: CLI tool for ROI-based compression
|
| 281 |
+
- `vae_compress.py`: Legacy ROI compression script (updated to use modules)
|
| 282 |
+
- `*.bak`: Backup files from pre-modularization (tic_model, RSTB, etc.)
|
| 283 |
+
|
| 284 |
+
## Dependencies
|
| 285 |
+
- PyTorch + torchvision for model
|
| 286 |
+
- compressai for entropy models (version sensitive - see checkpoint loading)
|
| 287 |
+
- transformers for SegFormer + DETR/Deformable DETR + Grounding DINO
|
| 288 |
+
- ultralytics for YOLO + YOLO-World
|
| 289 |
+
- effdet (optional) for EfficientDet detector
|
| 290 |
+
- timm for model layers
|
| 291 |
+
|
| 292 |
+
## Common Pitfalls
|
| 293 |
+
1. **Padding**: Forgetting to pad images to 256 multiples causes dimension mismatches
|
| 294 |
+
2. **Checkpoint keys**: Old checkpoints use `_matrix/_bias/_factor` naming that must be converted
|
| 295 |
+
3. **Mask resolution**: Mask must match input image size; it's automatically downsampled in forward pass
|
| 296 |
+
4. **Mask downsampling**: In ModifiedTIC, mask is downsampled to 1/2 resolution before simi_net (which further downsamples 8x to match 16x16 latent)
|
| 297 |
+
5. **Device mismatch**: Ensure mask, sigma tensor, and model are on same device
|
| 298 |
+
6. **Model update**: Must call `model.update(force=True)` after loading for entropy models
|
| 299 |
+
|
| 300 |
+
## Project Structure
|
| 301 |
+
|
| 302 |
+
- `.github/copilot-instructions.md`: This file - comprehensive development guide
|
| 303 |
+
- `examples.sh`: Example commands for running compression and segmentation
|
| 304 |
+
- `README.md`: Project overview and quick start guide
|
| 305 |
+
- `requirements.txt`: Python dependencies
|
| 306 |
+
|
| 307 |
+
**CLI Tools:**
|
| 308 |
+
- `roi_segmenter.py`: CLI tool for standalone segmentation
|
| 309 |
+
- `roi_compressor.py`: CLI tool for ROI-based image compression
|
| 310 |
+
- `app.py`: Gradio demo with Image and Video tabs
|
| 311 |
+
|
| 312 |
+
**Core Modules:**
|
| 313 |
+
- `segmentation/`: Modular segmentation with abstract base class
|
| 314 |
+
- `base.py`: `BaseSegmenter` abstract class
|
| 315 |
+
- `segformer.py`: Cityscapes semantic segmentation (19 classes)
|
| 316 |
+
- `yolo.py`: COCO instance segmentation (80 classes)
|
| 317 |
+
- `mask2former.py`: Swin-based panoptic/semantic (COCO: 133, ADE20K: 150 classes)
|
| 318 |
+
- `maskrcnn.py`: ResNet50-FPN instance segmentation (COCO: 80 classes)
|
| 319 |
+
- `sam3.py`: Prompt-based segmentation
|
| 320 |
+
- `factory.py`: Factory pattern for creating segmenters
|
| 321 |
+
- `utils.py`: Visualization and I/O utilities
|
| 322 |
+
- `vae/`: Modular compression with ROI support
|
| 323 |
+
- `tic_model.py`: Base `TIC` class (Transformer-based VAE)
|
| 324 |
+
- `RSTB.py`: Residual Swin Transformer Blocks
|
| 325 |
+
- `transformer_layers.py`: Generic transformer components
|
| 326 |
+
- `roi_tic.py`: `ModifiedTIC` class and checkpoint loading
|
| 327 |
+
- `utils.py`: `compress_image()`, `compute_padding()`
|
| 328 |
+
- `visualization.py`: `highlight_roi()`, `create_comparison_grid()`
|
| 329 |
+
- `video/`: Video compression with streaming support
|
| 330 |
+
- `video_processor.py`: `VideoProcessor` class for video compression
|
| 331 |
+
- `motion_analyzer.py`: `MotionAnalyzer` for scene complexity estimation
|
| 332 |
+
- `chunk_compressor.py`: `ChunkCompressor` and `BandwidthController`
|
| 333 |
+
- `detection/`: Object detection and tracking
|
| 334 |
+
- `tracker.py`: `SimpleTracker` IoU-based multi-object tracker
|
| 335 |
+
- `utils.py`: `draw_detections()`, `draw_tracks()`
|
| 336 |
+
|
| 337 |
+
## Video Processing
|
| 338 |
+
|
| 339 |
+
### Video Module Usage
|
| 340 |
+
```python
|
| 341 |
+
from video import VideoProcessor, CompressionSettings
|
| 342 |
+
|
| 343 |
+
# Create processor
|
| 344 |
+
processor = VideoProcessor(device='cuda')
|
| 345 |
+
processor.load_models(
|
| 346 |
+
quality_level=4,
|
| 347 |
+
segmentation_method='sam3',
|
| 348 |
+
detection_method='yolo',
|
| 349 |
+
enable_tracking=True,
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
# Static mode (fixed settings)
|
| 353 |
+
settings = CompressionSettings(
|
| 354 |
+
mode='static',
|
| 355 |
+
quality_level=4,
|
| 356 |
+
sigma=0.3,
|
| 357 |
+
output_fps=15.0,
|
| 358 |
+
target_classes=['person', 'car'],
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
for chunk in processor.process_static('input.mp4', settings):
|
| 362 |
+
# Stream chunks in real-time
|
| 363 |
+
print(f"Chunk {chunk.chunk_index}: {len(chunk.frames)} frames at {chunk.fps} FPS")
|
| 364 |
+
|
| 365 |
+
# Dynamic mode (bandwidth-adaptive)
|
| 366 |
+
settings = CompressionSettings(
|
| 367 |
+
mode='dynamic',
|
| 368 |
+
target_bandwidth_kbps=500,
|
| 369 |
+
min_fps=5,
|
| 370 |
+
max_fps=30,
|
| 371 |
+
chunk_duration_sec=1.0,
|
| 372 |
+
target_classes=['person', 'car'],
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
for chunk in processor.process_dynamic('input.mp4', settings):
|
| 376 |
+
# Adaptive FPS and quality per chunk based on motion
|
| 377 |
+
print(f"Chunk {chunk.chunk_index}: fps={chunk.fps:.1f}, quality={chunk.quality_level}")
|
| 378 |
+
```
|
| 379 |
+
|
| 380 |
+
### Motion-Adaptive Compression
|
| 381 |
+
The dynamic mode analyzes each chunk for:
|
| 382 |
+
- **Motion magnitude**: Mean pixel change between frames
|
| 383 |
+
- **Motion coverage**: Fraction of pixels with significant motion
|
| 384 |
+
- **Scene complexity**: Edge density and texture variance
|
| 385 |
+
- **Scene changes**: Large global differences
|
| 386 |
+
|
| 387 |
+
High-motion scenes get:
|
| 388 |
+
- More frames (higher FPS)
|
| 389 |
+
- Higher spatial compression (lower quality/sigma) to stay within bandwidth
|
| 390 |
+
|
| 391 |
+
Low-motion scenes get:
|
| 392 |
+
- Fewer frames (lower FPS)
|
| 393 |
+
- Better spatial quality (higher quality/sigma)
|
| 394 |
+
|
| 395 |
+
### Object Tracking
|
| 396 |
+
```python
|
| 397 |
+
from detection import SimpleTracker, draw_tracks
|
| 398 |
+
|
| 399 |
+
tracker = SimpleTracker(iou_threshold=0.3, max_age=30)
|
| 400 |
+
|
| 401 |
+
for frame_detections in frame_by_frame_detections:
|
| 402 |
+
tracks = tracker.update(frame_detections)
|
| 403 |
+
# tracks contains track_id, label, bbox, history
|
| 404 |
+
|
| 405 |
+
# Draw tracks with trails
|
| 406 |
+
img = draw_tracks(frame, tracks, show_id=True, show_trail=True)
|
| 407 |
+
```
|
| 408 |
+
|
| 409 |
+
## Coding Guidelines
|
| 410 |
+
- Don't create unnecessary files—focus on core functionality.
|
| 411 |
+
- Ensure all scripts have clear argument parsing and help messages.
|
| 412 |
+
- Maintain consistent coding style and comments for clarity.
|
| 413 |
+
- Validate inputs (image paths, checkpoint paths, segmentation classes).
|
| 414 |
+
- Include error handling for common issues (file not found, dimension mismatches).
|
| 415 |
+
- Document all functions and classes with docstrings.
|
| 416 |
+
- Write modular code to facilitate testing and future extensions.
|
| 417 |
+
- Use ipynb files for prototyping but keep main logic in .py files.
|
.gitignore
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- Python ---
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.pyd
|
| 5 |
+
*.so
|
| 6 |
+
*.egg-info/
|
| 7 |
+
dist/
|
| 8 |
+
build/
|
| 9 |
+
.eggs/
|
| 10 |
+
.pytest_cache/
|
| 11 |
+
.mypy_cache/
|
| 12 |
+
.ruff_cache/
|
| 13 |
+
.pytype/
|
| 14 |
+
.coverage
|
| 15 |
+
coverage.xml
|
| 16 |
+
htmlcov/
|
| 17 |
+
|
| 18 |
+
# --- Virtual environments ---
|
| 19 |
+
venv/
|
| 20 |
+
.venv/
|
| 21 |
+
env/
|
| 22 |
+
ENV/
|
| 23 |
+
|
| 24 |
+
# --- Jupyter ---
|
| 25 |
+
.ipynb_checkpoints/
|
| 26 |
+
|
| 27 |
+
# --- OS / editor ---
|
| 28 |
+
.DS_Store
|
| 29 |
+
Thumbs.db
|
| 30 |
+
.vscode/
|
| 31 |
+
.idea/
|
| 32 |
+
|
| 33 |
+
# --- Secrets / local config ---
|
| 34 |
+
.env
|
| 35 |
+
.env.*
|
| 36 |
+
*.key
|
| 37 |
+
*.pem
|
| 38 |
+
|
| 39 |
+
# --- Logs / temp ---
|
| 40 |
+
*.log
|
| 41 |
+
logs/
|
| 42 |
+
tmp/
|
| 43 |
+
.cache/
|
| 44 |
+
|
| 45 |
+
# --- Gradio / HF Spaces artifacts ---
|
| 46 |
+
flagged/
|
| 47 |
+
gradio_cached_examples/
|
| 48 |
+
.gradio/
|
| 49 |
+
|
| 50 |
+
# --- Data / media ---
|
| 51 |
+
data/
|
| 52 |
+
|
| 53 |
+
# --- Model + framework caches (keep your curated checkpoints, ignore auto-download caches) ---
|
| 54 |
+
checkpoints/hf/
|
| 55 |
+
checkpoints/torch/
|
| 56 |
+
checkpoints/yolo*.pt
|
| 57 |
+
|
| 58 |
+
# --- Common ML experiment outputs ---
|
| 59 |
+
runs/
|
| 60 |
+
wandb/
|
| 61 |
+
outputs/
|
| 62 |
+
results/
|
| 63 |
+
|
| 64 |
+
# --- Large artifacts (uncomment if you don't want binaries tracked) ---
|
| 65 |
+
# *.pth
|
| 66 |
+
# *.pt
|
| 67 |
+
# *.pth.tar
|
| 68 |
+
# *.onnx
|
| 69 |
+
# *.ckpt
|
API.md
ADDED
|
@@ -0,0 +1,1029 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# API Documentation
|
| 2 |
+
|
| 3 |
+
This document describes the Gradio API endpoints exposed by the ROI-VAE image and video compression application. The API allows programmatic access to segmentation, compression, detection, and full pipeline processing for both images and videos.
|
| 4 |
+
|
| 5 |
+
**Live Demo:** https://biaslab2025-contextual-communication-demo.hf.space
|
| 6 |
+
|
| 7 |
+
## Table of Contents
|
| 8 |
+
|
| 9 |
+
- [Quick Start](#quick-start)
|
| 10 |
+
- [Important Notes](#important-notes)
|
| 11 |
+
- [Image API Endpoints](#image-api-endpoints)
|
| 12 |
+
- [/segment](#1-segment---generate-roi-mask)
|
| 13 |
+
- [/compress](#2-compress---compress-image)
|
| 14 |
+
- [/detect](#3-detect---object-detection)
|
| 15 |
+
- [/detect_overlay](#31-detect_overlay---detection-with-visualization)
|
| 16 |
+
- [/process](#4-process---full-image-pipeline)
|
| 17 |
+
- [Video API Endpoints](#video-api-endpoints)
|
| 18 |
+
- [/segment_video](#1-segment_video---segment-video)
|
| 19 |
+
- [/compress_video](#2-compress_video---compress-video)
|
| 20 |
+
- [/detect_video](#3-detect_video---video-detection)
|
| 21 |
+
- [/process_video](#4-process_video---full-video-pipeline)
|
| 22 |
+
- [Streaming Video API Endpoints](#streaming-video-api-endpoints)
|
| 23 |
+
- [/stream_process_video](#1-stream_process_video---full-streaming-pipeline)
|
| 24 |
+
- [/stream_compress_video](#2-stream_compress_video---simplified-streaming-compression)
|
| 25 |
+
- [Class Reference](#class-reference)
|
| 26 |
+
- [Error Handling](#error-handling)
|
| 27 |
+
- [GPU Quota Handling](#handling-gpu-quota-on-hf-spaces)
|
| 28 |
+
- [cURL Examples](#using-with-curl)
|
| 29 |
+
- [Example Scripts](#example-scripts)
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## Quick Start
|
| 34 |
+
|
| 35 |
+
### Installation
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
pip install gradio_client
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### Image Processing
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
from gradio_client import Client, handle_file
|
| 45 |
+
|
| 46 |
+
# Connect to the API
|
| 47 |
+
client = Client("https://biaslab2025-contextual-communication-demo.hf.space")
|
| 48 |
+
# Or local: client = Client("http://localhost:7860")
|
| 49 |
+
|
| 50 |
+
# Full pipeline: segment → compress → detect
|
| 51 |
+
compressed, mask, bpp, ratio, coverage, detections_json = client.predict(
|
| 52 |
+
handle_file("path/to/image.jpg"),
|
| 53 |
+
"car, person", # segmentation prompt
|
| 54 |
+
"sam3", # segmentation method
|
| 55 |
+
4, # quality level (1-5)
|
| 56 |
+
0.3, # sigma (background compression)
|
| 57 |
+
True, # run detection
|
| 58 |
+
"yolo", # detection method
|
| 59 |
+
"", # detection classes (empty for closed-vocab)
|
| 60 |
+
api_name="/process"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
print(f"Compression: {bpp:.4f} bpp ({ratio:.2f}x)")
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### Video Processing
|
| 67 |
+
|
| 68 |
+
```python
|
| 69 |
+
from gradio_client import Client, handle_file
|
| 70 |
+
import json
|
| 71 |
+
|
| 72 |
+
client = Client("http://localhost:7860")
|
| 73 |
+
|
| 74 |
+
# Full pipeline with static settings
|
| 75 |
+
output_video, stats_json = client.predict(
|
| 76 |
+
handle_file("path/to/video.mp4"),
|
| 77 |
+
"person, car", # segmentation classes
|
| 78 |
+
"sam3", # segmentation method
|
| 79 |
+
"static", # mode: "static" or "dynamic"
|
| 80 |
+
4, # quality level (1-5)
|
| 81 |
+
0.3, # sigma
|
| 82 |
+
15.0, # output FPS
|
| 83 |
+
500, # bandwidth (dynamic mode)
|
| 84 |
+
5, # min_fps (dynamic mode)
|
| 85 |
+
    30, # max_fps (dynamic mode)
    0.5, # aggressiveness (dynamic mode)
|
| 86 |
+
False, # run detection
|
| 87 |
+
"yolo", # detection method
|
| 88 |
+
None, # mask_file_path (optional)
|
| 89 |
+
api_name="/process_video"
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
stats = json.loads(stats_json)
|
| 93 |
+
print(f"Compressed video: {output_video}")
|
| 94 |
+
print(f"Total frames: {stats['total_frames']}")
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
## Important Notes
|
| 100 |
+
|
| 101 |
+
### File Handling
|
| 102 |
+
|
| 103 |
+
Always wrap file paths with `handle_file()` when using `gradio_client`:
|
| 104 |
+
|
| 105 |
+
```python
|
| 106 |
+
from gradio_client import handle_file
|
| 107 |
+
|
| 108 |
+
# ✅ Correct
|
| 109 |
+
client.predict(handle_file("image.jpg"), ...)
|
| 110 |
+
|
| 111 |
+
# ❌ Incorrect - will fail with validation error
|
| 112 |
+
client.predict("image.jpg", ...)
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
### Detection Output Format
|
| 116 |
+
|
| 117 |
+
All detection endpoints return JSON strings with this structure:
|
| 118 |
+
|
| 119 |
+
```python
|
| 120 |
+
import json
|
| 121 |
+
|
| 122 |
+
detections = json.loads(detections_json)
|
| 123 |
+
# Each detection has:
|
| 124 |
+
# - label: str (class name)
|
| 125 |
+
# - score: float (confidence 0-1)
|
| 126 |
+
# - bbox_xyxy: list[float] (bounding box [x1, y1, x2, y2])
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
### Open-Vocabulary Detectors
|
| 130 |
+
|
| 131 |
+
The following detectors require a `classes` parameter:
|
| 132 |
+
- `yolo_world` - YOLO-World
|
| 133 |
+
- `grounding_dino` - Grounding DINO
|
| 134 |
+
|
| 135 |
+
Closed-vocabulary detectors (`yolo`, `detr`, `faster_rcnn`, etc.) use pretrained COCO classes and ignore the `classes` parameter.
|
| 136 |
+
|
| 137 |
+
---
|
| 138 |
+
|
| 139 |
+
## Image API Endpoints
|
| 140 |
+
|
| 141 |
+
### 1. `/segment` - Generate ROI Mask
|
| 142 |
+
|
| 143 |
+
Segments an image to create a Region of Interest (ROI) mask.
|
| 144 |
+
|
| 145 |
+
**Parameters:**
|
| 146 |
+
|
| 147 |
+
| Parameter | Type | Default | Description |
|
| 148 |
+
|-----------|------|---------|-------------|
|
| 149 |
+
| `image` | Image | required | Input image file |
|
| 150 |
+
| `prompt` | str | `"object"` | Comma-separated classes or natural language prompt |
|
| 151 |
+
| `method` | str | `"sam3"` | Segmentation method (see [methods](#segmentation-methods)) |
|
| 152 |
+
| `return_overlay` | bool | `False` | If `True`, returns image with ROI highlighted instead of mask |
|
| 153 |
+
|
| 154 |
+
**Returns:**
|
| 155 |
+
|
| 156 |
+
| Output | Type | Description |
|
| 157 |
+
|--------|------|-------------|
|
| 158 |
+
| `result_image` | Image | Grayscale mask OR image with ROI overlay (if `return_overlay=True`) |
|
| 159 |
+
| `roi_coverage` | float | Fraction of image covered by ROI (0.0-1.0) |
|
| 160 |
+
| `classes_used` | str | JSON list of classes/prompts used |
|
| 161 |
+
|
| 162 |
+
**Example:**
|
| 163 |
+
|
| 164 |
+
```python
|
| 165 |
+
# Get binary mask (default)
|
| 166 |
+
mask, coverage, classes = client.predict(
|
| 167 |
+
handle_file("car_scene.jpg"),
|
| 168 |
+
"car, road",
|
| 169 |
+
"sam3",
|
| 170 |
+
False, # return_overlay
|
| 171 |
+
api_name="/segment"
|
| 172 |
+
)
|
| 173 |
+
print(f"ROI covers {coverage*100:.2f}% of image")
|
| 174 |
+
|
| 175 |
+
# Get image with ROI highlighted
|
| 176 |
+
highlighted, coverage, classes = client.predict(
|
| 177 |
+
handle_file("car_scene.jpg"),
|
| 178 |
+
"car, road",
|
| 179 |
+
"sam3",
|
| 180 |
+
True, # return_overlay=True
|
| 181 |
+
api_name="/segment"
|
| 182 |
+
)
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
---
|
| 186 |
+
|
| 187 |
+
### 2. `/compress` - Compress Image
|
| 188 |
+
|
| 189 |
+
Compresses an image using TIC VAE, optionally with an ROI mask for variable quality.
|
| 190 |
+
|
| 191 |
+
**Parameters:**
|
| 192 |
+
|
| 193 |
+
| Parameter | Type | Default | Description |
|
| 194 |
+
|-----------|------|---------|-------------|
|
| 195 |
+
| `image` | Image | required | Input image file |
|
| 196 |
+
| `mask_image` | Image | `None` | ROI mask (white=ROI, black=background) |
|
| 197 |
+
| `quality` | int | `4` | Quality level 1-5 |
|
| 198 |
+
| `sigma` | float | `0.3` | Background preservation (0.01-1.0) |
|
| 199 |
+
|
| 200 |
+
**Quality Levels:**
|
| 201 |
+
|
| 202 |
+
| Level | Lambda | Description |
|
| 203 |
+
|-------|--------|-------------|
|
| 204 |
+
| 1 | 0.0035 | Smallest file |
|
| 205 |
+
| 2 | 0.013 | Smaller file |
|
| 206 |
+
| 3 | 0.025 | Balanced |
|
| 207 |
+
| 4 | 0.0483 | Higher quality (default) |
|
| 208 |
+
| 5 | 0.0932 | Best quality |
|
| 209 |
+
|
| 210 |
+
**Returns:**
|
| 211 |
+
|
| 212 |
+
| Output | Type | Description |
|
| 213 |
+
|--------|------|-------------|
|
| 214 |
+
| `compressed_image` | Image | Compressed output image |
|
| 215 |
+
| `bpp` | float | Bits per pixel |
|
| 216 |
+
| `compression_ratio` | float | Compression ratio (24/bpp) |
|
| 217 |
+
|
| 218 |
+
**Example:**
|
| 219 |
+
|
| 220 |
+
```python
|
| 221 |
+
# Compress without mask (uniform quality)
|
| 222 |
+
compressed, bpp, ratio = client.predict(
|
| 223 |
+
handle_file("image.jpg"),
|
| 224 |
+
None, # no mask
|
| 225 |
+
4, # quality
|
| 226 |
+
0.3, # sigma (ignored without mask)
|
| 227 |
+
api_name="/compress"
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Compress with ROI mask
|
| 231 |
+
mask, _, _ = client.predict(handle_file("image.jpg"), "person", "yolo", False, api_name="/segment")
|
| 232 |
+
|
| 233 |
+
compressed, bpp, ratio = client.predict(
|
| 234 |
+
handle_file("image.jpg"),
|
| 235 |
+
handle_file(mask),
|
| 236 |
+
4,
|
| 237 |
+
0.2, # aggressive background compression
|
| 238 |
+
api_name="/compress"
|
| 239 |
+
)
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
---
|
| 243 |
+
|
| 244 |
+
### 3. `/detect` - Object Detection
|
| 245 |
+
|
| 246 |
+
Runs object detection on an image and returns detection results as JSON.
|
| 247 |
+
|
| 248 |
+
**Parameters:**
|
| 249 |
+
|
| 250 |
+
| Parameter | Type | Default | Description |
|
| 251 |
+
|-----------|------|---------|-------------|
|
| 252 |
+
| `image` | Image | required | Input image file |
|
| 253 |
+
| `method` | str | `"yolo"` | Detection method (see [methods](#detection-methods)) |
|
| 254 |
+
| `classes` | str | `""` | Comma-separated classes (required for open-vocab detectors) |
|
| 255 |
+
| `confidence` | float | `0.25` | Confidence threshold (0.0-1.0) |
|
| 256 |
+
|
| 257 |
+
**Returns:**
|
| 258 |
+
|
| 259 |
+
| Output | Type | Description |
|
| 260 |
+
|--------|------|-------------|
|
| 261 |
+
| `detections_json` | str | JSON string of detection results |
|
| 262 |
+
|
| 263 |
+
**Example - Closed-Vocabulary:**
|
| 264 |
+
|
| 265 |
+
```python
|
| 266 |
+
import json
|
| 267 |
+
|
| 268 |
+
# YOLO detection (COCO classes)
|
| 269 |
+
dets_json = client.predict(
|
| 270 |
+
handle_file("street_scene.jpg"),
|
| 271 |
+
"yolo",
|
| 272 |
+
"", # no classes needed
|
| 273 |
+
0.25,
|
| 274 |
+
api_name="/detect"
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
detections = json.loads(dets_json)
|
| 278 |
+
for det in detections:
|
| 279 |
+
print(f"{det['label']}: {det['score']:.2f}")
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
**Example - Open-Vocabulary:**
|
| 283 |
+
|
| 284 |
+
```python
|
| 285 |
+
# YOLO-World with custom classes
|
| 286 |
+
dets_json = client.predict(
|
| 287 |
+
handle_file("image.jpg"),
|
| 288 |
+
"yolo_world",
|
| 289 |
+
"hat, backpack, umbrella", # custom classes required
|
| 290 |
+
0.25,
|
| 291 |
+
api_name="/detect"
|
| 292 |
+
)
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
---
|
| 296 |
+
|
| 297 |
+
### 3.1. `/detect_overlay` - Detection with Visualization
|
| 298 |
+
|
| 299 |
+
Runs object detection and returns the image with bounding boxes drawn.
|
| 300 |
+
|
| 301 |
+
**Parameters:**
|
| 302 |
+
|
| 303 |
+
| Parameter | Type | Default | Description |
|
| 304 |
+
|-----------|------|---------|-------------|
|
| 305 |
+
| `image` | Image | required | Input image file |
|
| 306 |
+
| `method` | str | `"yolo"` | Detection method (see [methods](#detection-methods)) |
|
| 307 |
+
| `classes` | str | `""` | Comma-separated classes (required for open-vocab detectors) |
|
| 308 |
+
| `confidence` | float | `0.25` | Confidence threshold (0.0-1.0) |
|
| 309 |
+
|
| 310 |
+
**Returns:**
|
| 311 |
+
|
| 312 |
+
| Output | Type | Description |
|
| 313 |
+
|--------|------|-------------|
|
| 314 |
+
| `result_image` | Image | Image with detection bounding boxes |
|
| 315 |
+
| `detections_json` | str | JSON string of detection results |
|
| 316 |
+
|
| 317 |
+
**Example:**
|
| 318 |
+
|
| 319 |
+
```python
|
| 320 |
+
import json
|
| 321 |
+
|
| 322 |
+
# Get image with detection boxes
|
| 323 |
+
result_img, dets_json = client.predict(
|
| 324 |
+
handle_file("street_scene.jpg"),
|
| 325 |
+
"yolo",
|
| 326 |
+
"",
|
| 327 |
+
0.25,
|
| 328 |
+
api_name="/detect_overlay"
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
# result_img is a file path to the image with boxes drawn
|
| 332 |
+
print(f"Image with boxes: {result_img}")
|
| 333 |
+
detections = json.loads(dets_json)
|
| 334 |
+
```
|
| 335 |
+
|
| 336 |
+
---
|
| 337 |
+
|
| 338 |
+
### 4. `/process` - Full Image Pipeline
|
| 339 |
+
|
| 340 |
+
Runs the complete pipeline: segmentation → compression → optional detection.
|
| 341 |
+
|
| 342 |
+
**Parameters:**
|
| 343 |
+
|
| 344 |
+
| Parameter | Type | Default | Description |
|
| 345 |
+
|-----------|------|---------|-------------|
|
| 346 |
+
| `image` | Image | required | Input image file |
|
| 347 |
+
| `prompt` | str | `"object"` | Segmentation prompt/classes |
|
| 348 |
+
| `segmentation_method` | str | `"sam3"` | ROI segmentation method |
|
| 349 |
+
| `quality` | int | `4` | Compression quality (1-5) |
|
| 350 |
+
| `sigma` | float | `0.3` | Background preservation (0.01-1.0) |
|
| 351 |
+
| `run_detection` | bool | `False` | Whether to run detection on output |
|
| 352 |
+
| `detection_method` | str | `"yolo"` | Detector to use |
|
| 353 |
+
| `detection_classes` | str | `""` | Classes for open-vocab detectors |
|
| 354 |
+
|
| 355 |
+
**Returns:**
|
| 356 |
+
|
| 357 |
+
| Output | Type | Description |
|
| 358 |
+
|--------|------|-------------|
|
| 359 |
+
| `compressed_image` | Image | Compressed output image |
|
| 360 |
+
| `mask_image` | Image | Generated ROI mask |
|
| 361 |
+
| `bpp` | float | Bits per pixel |
|
| 362 |
+
| `compression_ratio` | float | Compression ratio |
|
| 363 |
+
| `roi_coverage` | float | Fraction of image covered by ROI (0.0-1.0) |
|
| 364 |
+
| `detections_json` | str | JSON detections (empty list if `run_detection=False`) |
|
| 365 |
+
|
| 366 |
+
**Example:**
|
| 367 |
+
|
| 368 |
+
```python
|
| 369 |
+
import json
|
| 370 |
+
|
| 371 |
+
compressed, mask, bpp, ratio, coverage, dets_json = client.predict(
|
| 372 |
+
handle_file("street.jpg"),
|
| 373 |
+
"car, person, road",
|
| 374 |
+
"sam3",
|
| 375 |
+
4,
|
| 376 |
+
0.3,
|
| 377 |
+
True, # run detection
|
| 378 |
+
"yolo",
|
| 379 |
+
"",
|
| 380 |
+
api_name="/process"
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
print(f"ROI Coverage: {coverage*100:.2f}%")
|
| 384 |
+
print(f"Compression: {bpp:.4f} bpp ({ratio:.2f}x)")
|
| 385 |
+
print(f"Detections: {len(json.loads(dets_json))}")
|
| 386 |
+
```
|
| 387 |
+
|
| 388 |
+
---
|
| 389 |
+
|
| 390 |
+
## Video API Endpoints
|
| 391 |
+
|
| 392 |
+
### 1. `/segment_video` - Segment Video
|
| 393 |
+
|
| 394 |
+
Segments a video to find ROI regions, returning either a mask file or overlay video.
|
| 395 |
+
|
| 396 |
+
**Parameters:**
|
| 397 |
+
|
| 398 |
+
| Parameter | Type | Default | Description |
|
| 399 |
+
|-----------|------|---------|-------------|
|
| 400 |
+
| `video_path` | Video | required | Input video file |
|
| 401 |
+
| `prompt` | str | `"object"` | Comma-separated classes or natural language prompt |
|
| 402 |
+
| `method` | str | `"sam3"` | Segmentation method |
|
| 403 |
+
| `return_overlay` | bool | `False` | If `True`, returns video with ROI highlighted |
|
| 404 |
+
| `output_fps` | float | `15.0` | Output framerate (max 30) |
|
| 405 |
+
|
| 406 |
+
**Returns:**
|
| 407 |
+
|
| 408 |
+
| Output | Type | Description |
|
| 409 |
+
|--------|------|-------------|
|
| 410 |
+
| `result_path` | File/Video | Mask file (NPZ) OR video with ROI overlay |
|
| 411 |
+
| `stats_json` | str | JSON with frame count, coverage, and classes |
|
| 412 |
+
|
| 413 |
+
**Example:**
|
| 414 |
+
|
| 415 |
+
```python
|
| 416 |
+
import json
|
| 417 |
+
|
| 418 |
+
# Get mask file for reuse in compression
|
| 419 |
+
mask_file, stats_json = client.predict(
|
| 420 |
+
handle_file("video.mp4"),
|
| 421 |
+
"person, car",
|
| 422 |
+
"sam3",
|
| 423 |
+
False, # return masks file
|
| 424 |
+
15.0, # fps
|
| 425 |
+
api_name="/segment_video"
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
stats = json.loads(stats_json)
|
| 429 |
+
print(f"Processed {stats['total_frames']} frames")
|
| 430 |
+
print(f"Avg ROI coverage: {stats['avg_roi_coverage']*100:.2f}%")
|
| 431 |
+
|
| 432 |
+
# Get video with ROI overlay for visualization
|
| 433 |
+
overlay_video, _ = client.predict(
|
| 434 |
+
handle_file("video.mp4"),
|
| 435 |
+
"person, car",
|
| 436 |
+
"sam3",
|
| 437 |
+
True, # return overlay video
|
| 438 |
+
15.0,
|
| 439 |
+
api_name="/segment_video"
|
| 440 |
+
)
|
| 441 |
+
```
|
| 442 |
+
|
| 443 |
+
---
|
| 444 |
+
|
| 445 |
+
### 2. `/compress_video` - Compress Video
|
| 446 |
+
|
| 447 |
+
Compresses a video with optional ROI mask preservation.
|
| 448 |
+
|
| 449 |
+
**Parameters:**
|
| 450 |
+
|
| 451 |
+
| Parameter | Type | Default | Description |
|
| 452 |
+
|-----------|------|---------|-------------|
|
| 453 |
+
| `video_path` | Video | required | Input video file |
|
| 454 |
+
| `mask_file_path` | str | `None` | Path to pre-computed masks (from `/segment_video`) |
|
| 455 |
+
| `quality` | int | `4` | Quality level (1-5) |
|
| 456 |
+
| `sigma` | float | `0.3` | Background preservation (0.01-1.0) |
|
| 457 |
+
| `output_fps` | float | `15.0` | Target output framerate |
|
| 458 |
+
|
| 459 |
+
**Returns:**
|
| 460 |
+
|
| 461 |
+
| Output | Type | Description |
|
| 462 |
+
|--------|------|-------------|
|
| 463 |
+
| `compressed_video` | Video | Compressed output video |
|
| 464 |
+
| `stats_json` | str | JSON with compression statistics |
|
| 465 |
+
|
| 466 |
+
**Example:**
|
| 467 |
+
|
| 468 |
+
```python
|
| 469 |
+
import json
|
| 470 |
+
|
| 471 |
+
# First, segment to get masks
|
| 472 |
+
mask_file, _ = client.predict(
|
| 473 |
+
handle_file("video.mp4"), "person", "sam3", False, 15.0,
|
| 474 |
+
api_name="/segment_video"
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# Then compress with cached masks (3-5x faster!)
|
| 478 |
+
compressed, stats_json = client.predict(
|
| 479 |
+
handle_file("video.mp4"),
|
| 480 |
+
mask_file, # reuse masks
|
| 481 |
+
4, # quality
|
| 482 |
+
0.3, # sigma
|
| 483 |
+
15.0, # fps
|
| 484 |
+
api_name="/compress_video"
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
stats = json.loads(stats_json)
|
| 488 |
+
print(f"Compression ratio: {stats['compression_ratio']}x")
|
| 489 |
+
print(f"Total size: {stats['total_size_kb']} KB")
|
| 490 |
+
```
|
| 491 |
+
|
| 492 |
+
---
|
| 493 |
+
|
| 494 |
+
### 3. `/detect_video` - Video Detection
|
| 495 |
+
|
| 496 |
+
Runs object detection on each frame of a video.
|
| 497 |
+
|
| 498 |
+
**Parameters:**
|
| 499 |
+
|
| 500 |
+
| Parameter | Type | Default | Description |
|
| 501 |
+
|-----------|------|---------|-------------|
|
| 502 |
+
| `video_path` | Video | required | Input video file |
|
| 503 |
+
| `method` | str | `"yolo"` | Detection method |
|
| 504 |
+
| `classes` | str | `""` | Comma-separated classes (required for open-vocab) |
|
| 505 |
+
| `confidence` | float | `0.25` | Confidence threshold (0.0-1.0) |
|
| 506 |
+
| `return_overlay` | bool | `False` | If `True`, returns video with detection boxes |
|
| 507 |
+
| `output_fps` | float | `15.0` | Output framerate (max 30) |
|
| 508 |
+
|
| 509 |
+
**Returns:**
|
| 510 |
+
|
| 511 |
+
| Output | Type | Description |
|
| 512 |
+
|--------|------|-------------|
|
| 513 |
+
| `result_video` | Video | Video with detection boxes (if `return_overlay=True`), None otherwise |
|
| 514 |
+
| `detections_json` | str | JSON with per-frame detections |
|
| 515 |
+
|
| 516 |
+
**Example:**
|
| 517 |
+
|
| 518 |
+
```python
|
| 519 |
+
import json
|
| 520 |
+
|
| 521 |
+
# Get per-frame detections JSON
|
| 522 |
+
_, dets_json = client.predict(
|
| 523 |
+
handle_file("video.mp4"),
|
| 524 |
+
"yolo",
|
| 525 |
+
"",
|
| 526 |
+
0.25,
|
| 527 |
+
False, # return JSON only
|
| 528 |
+
15.0,
|
| 529 |
+
api_name="/detect_video"
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
data = json.loads(dets_json)
|
| 533 |
+
print(f"Total detections: {data['total_detections']}")
|
| 534 |
+
print(f"Avg per frame: {data['avg_detections_per_frame']}")
|
| 535 |
+
|
| 536 |
+
# Get video with detection overlays
|
| 537 |
+
det_video, _ = client.predict(
|
| 538 |
+
handle_file("video.mp4"),
|
| 539 |
+
"yolo",
|
| 540 |
+
"",
|
| 541 |
+
0.25,
|
| 542 |
+
True, # return overlay video
|
| 543 |
+
15.0,
|
| 544 |
+
api_name="/detect_video"
|
| 545 |
+
)
|
| 546 |
+
```
|
| 547 |
+
|
| 548 |
+
---
|
| 549 |
+
|
| 550 |
+
### 4. `/process_video` - Full Video Pipeline
|
| 551 |
+
|
| 552 |
+
Processes a video with ROI-based compression (segment → compress), with optional detection.
|
| 553 |
+
|
| 554 |
+
**Parameters:**
|
| 555 |
+
|
| 556 |
+
| Parameter | Type | Default | Description |
|
| 557 |
+
|-----------|------|---------|-------------|
|
| 558 |
+
| `video_path` | Video | required | Input video file |
|
| 559 |
+
| `prompt` | str | `"object"` | Segmentation prompt/classes |
|
| 560 |
+
| `segmentation_method` | str | `"sam3"` | ROI segmentation method |
|
| 561 |
+
| `mode` | str | `"static"` | `"static"` or `"dynamic"` |
|
| 562 |
+
| `quality` | int | `4` | Quality level 1-5 (static mode) |
|
| 563 |
+
| `sigma` | float | `0.3` | Background preservation (static mode) |
|
| 564 |
+
| `output_fps` | float | `15.0` | Target framerate (static mode) |
|
| 565 |
+
| `bandwidth_kbps` | float | `500.0` | Target bandwidth (dynamic mode) |
|
| 566 |
+
| `min_fps` | float | `5.0` | Minimum framerate (dynamic mode) |
|
| 567 |
+
| `max_fps` | float | `30.0` | Maximum framerate (dynamic mode) |
|
| 568 |
+
| `aggressiveness` | float | `0.5` | Bandwidth savings strategy (dynamic mode): `0.0` = use full bandwidth (high FPS always), `0.5` = moderate savings, `1.0` = maximum savings (aggressive FPS reduction for low motion) |
|
| 569 |
+
| `run_detection` | bool | `False` | Whether to run detection/tracking |
|
| 570 |
+
| `detection_method` | str | `"yolo"` | Detector to use |
|
| 571 |
+
| `mask_file_path` | str | `None` | Path to pre-computed masks (skips segmentation) |
|
| 572 |
+
|
| 573 |
+
**Returns:**
|
| 574 |
+
|
| 575 |
+
| Output | Type | Description |
|
| 576 |
+
|--------|------|-------------|
|
| 577 |
+
| `output_video` | Video | Compressed video |
|
| 578 |
+
| `stats_json` | str | JSON with detailed statistics |
|
| 579 |
+
|
| 580 |
+
**Example - Static Mode:**
|
| 581 |
+
|
| 582 |
+
```python
|
| 583 |
+
import json
|
| 584 |
+
|
| 585 |
+
output, stats_json = client.predict(
|
| 586 |
+
handle_file("video.mp4"),
|
| 587 |
+
"person, car",
|
| 588 |
+
"sam3",
|
| 589 |
+
"static",
|
| 590 |
+
4, 0.3, 15.0, # static: quality, sigma, fps
|
| 591 |
+
500, 5, 30, # dynamic: bandwidth, min_fps, max_fps (ignored)
|
| 592 |
+
False, "yolo", None,
|
| 593 |
+
api_name="/process_video"
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
stats = json.loads(stats_json)
|
| 597 |
+
print(f"Processed {stats['total_frames']} frames")
|
| 598 |
+
```
|
| 599 |
+
|
| 600 |
+
**Example - Dynamic Mode:**
|
| 601 |
+
|
| 602 |
+
```python
|
| 603 |
+
output, stats_json = client.predict(
|
| 604 |
+
handle_file("video.mp4"),
|
| 605 |
+
"person",
|
| 606 |
+
"yolo",
|
| 607 |
+
"dynamic",
|
| 608 |
+
4, 0.3, 15.0, # static settings (ignored)
|
| 609 |
+
750, # target bandwidth 750 kbps
|
| 610 |
+
8, # min FPS
|
| 611 |
+
30, # max FPS
|
| 612 |
+
True, "yolo", None,
|
| 613 |
+
api_name="/process_video"
|
| 614 |
+
)
|
| 615 |
+
```
|
| 616 |
+
|
| 617 |
+
---
|
| 618 |
+
|
| 619 |
+
## Class Reference
|
| 620 |
+
|
| 621 |
+
### Segmentation Methods
|
| 622 |
+
|
| 623 |
+
**Pixel-Perfect Segmentation:**
|
| 624 |
+
|
| 625 |
+
| Method | Description | Classes |
|
| 626 |
+
|--------|-------------|---------|
|
| 627 |
+
| `sam3` | Prompt-based (natural language) | Any text prompt |
|
| 633 |
+
| `yolo` | YOLO instance segmentation | 80 COCO classes |
|
| 634 |
+
| `segformer` | Cityscapes semantic segmentation | 19 classes |
|
| 635 |
+
| `mask2former` | Swin-based panoptic/semantic | 133 COCO / 150 ADE20K |
|
| 636 |
+
| `maskrcnn` | ResNet50-FPN instance segmentation | 80 COCO classes |
|
| 637 |
+
| `fake_yolo` | Fast bbox-based (YOLO + ByteTrack) | 80 COCO classes |
|
| 638 |
+
| `fake_yolo_botsort` | Fast bbox-based (YOLO + BoTSORT) | 80 COCO classes |
|
| 639 |
+
| `fake_detr` | Fast bbox-based (DETR + ByteTrack) | 80 COCO classes |
|
| 640 |
+
| `fake_fasterrcnn` | Fast bbox-based (Faster R-CNN + ByteTrack) | 80 COCO classes |
|
| 641 |
+
| `fake_retinanet` | Fast bbox-based (RetinaNet + ByteTrack) | 80 COCO classes |
|
| 642 |
+
| `fake_fcos` | Fast bbox-based (FCOS + ByteTrack) | 80 COCO classes |
|
| 643 |
+
| `fake_deformable_detr` | Fast bbox-based (Deformable DETR + ByteTrack) | 80 COCO classes |
|
| 644 |
+
| `fake_grounding_dino` | Fast bbox-based (Grounding DINO + ByteTrack) | Requires prompt |
|
| 645 |
+
|
| 646 |
+
**Note:** `fake_*` methods create rectangular masks from detection bounding boxes with object tracking. Faster than pixel-perfect segmentation, suitable for video when precise boundaries aren't critical.
|
| 647 |
+
|
| 648 |
+
### Detection Methods
|
| 649 |
+
|
| 650 |
+
**Closed-Vocabulary (COCO pretrained):**
|
| 651 |
+
|
| 652 |
+
| Method | Description |
|
| 653 |
+
|--------|-------------|
|
| 654 |
+
| `yolo` | Ultralytics YOLO |
|
| 655 |
+
| `detr` | Facebook DETR |
|
| 656 |
+
| `faster_rcnn` | Faster R-CNN |
|
| 657 |
+
| `retinanet` | RetinaNet |
|
| 658 |
+
| `fcos` | FCOS |
|
| 659 |
+
| `ssd` | SSD300 |
|
| 660 |
+
| `deformable_detr` | Deformable DETR |
|
| 661 |
+
|
| 662 |
+
**Open-Vocabulary (requires `classes` parameter):**
|
| 663 |
+
|
| 664 |
+
| Method | Description |
|
| 665 |
+
|--------|-------------|
|
| 666 |
+
| `yolo_world` | YOLO-World |
|
| 667 |
+
| `grounding_dino` | Grounding DINO |
|
| 668 |
+
|
| 669 |
+
### COCO Classes (80)
|
| 670 |
+
|
| 671 |
+
```
|
| 672 |
+
person, bicycle, car, motorcycle, airplane, bus, train, truck, boat,
|
| 673 |
+
traffic light, fire hydrant, stop sign, parking meter, bench, bird, cat,
|
| 674 |
+
dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella,
|
| 675 |
+
handbag, tie, suitcase, frisbee, skis, snowboard, sports ball, kite,
|
| 676 |
+
baseball bat, baseball glove, skateboard, surfboard, tennis racket, bottle,
|
| 677 |
+
wine glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange,
|
| 678 |
+
broccoli, carrot, hot dog, pizza, donut, cake, chair, couch, potted plant,
|
| 679 |
+
bed, dining table, toilet, tv, laptop, mouse, remote, keyboard, cell phone,
|
| 680 |
+
microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors,
|
| 681 |
+
teddy bear, hair drier, toothbrush
|
| 682 |
+
```
|
| 683 |
+
|
| 684 |
+
### Cityscapes Classes (19)
|
| 685 |
+
|
| 686 |
+
```
|
| 687 |
+
road, sidewalk, building, wall, fence, pole, traffic light, traffic sign,
|
| 688 |
+
vegetation, terrain, sky, person, rider, car, truck, bus, train, motorcycle,
|
| 689 |
+
bicycle
|
| 690 |
+
```
|
| 691 |
+
|
| 692 |
+
---
|
| 693 |
+
|
| 694 |
+
## Error Handling
|
| 695 |
+
|
| 696 |
+
```python
|
| 697 |
+
try:
|
| 698 |
+
result = client.predict(
|
| 699 |
+
handle_file("image.jpg"),
|
| 700 |
+
...,
|
| 701 |
+
api_name="/endpoint"
|
| 702 |
+
)
|
| 703 |
+
except Exception as e:
|
| 704 |
+
print(f"API Error: {e}")
|
| 705 |
+
```
|
| 706 |
+
|
| 707 |
+
**Common Errors:**
|
| 708 |
+
|
| 709 |
+
| Error | Cause | Solution |
|
| 710 |
+
|-------|-------|----------|
|
| 711 |
+
| Validation error for ImageData | Missing `handle_file()` | Wrap file paths with `handle_file()` |
|
| 712 |
+
| File does not exist | Invalid path | Check file path is correct |
|
| 713 |
+
| Empty detection classes | Open-vocab detector without classes | Provide classes for `yolo_world`, `grounding_dino` |
|
| 714 |
+
| GPU quota exceeded | HF Spaces limit | Wait and retry (see below) |
|
| 715 |
+
|
| 716 |
+
---
|
| 717 |
+
|
| 718 |
+
## Handling GPU Quota on HF Spaces
|
| 719 |
+
|
| 720 |
+
When using Hugging Face Spaces with ZeroGPU, you may encounter quota limits:
|
| 721 |
+
|
| 722 |
+
```
|
| 723 |
+
You have exceeded your GPU quota (60s requested vs. 0s left). Try again in 0:05:30
|
| 724 |
+
```
|
| 725 |
+
|
| 726 |
+
### Automatic Retry with Backoff
|
| 727 |
+
|
| 728 |
+
```python
|
| 729 |
+
import time
|
| 730 |
+
import re
|
| 731 |
+
|
| 732 |
+
def extract_wait_time(error_msg):
|
| 733 |
+
"""Extract wait time from GPU quota error message."""
|
| 734 |
+
match = re.search(r'Try again in (\d+):(\d+)(?::(\d+))?', error_msg)
|
| 735 |
+
if match:
|
| 736 |
+
if match.group(3): # HH:MM:SS
|
| 737 |
+
return int(match.group(1)) * 3600 + int(match.group(2)) * 60 + int(match.group(3))
|
| 738 |
+
else: # MM:SS
|
| 739 |
+
return int(match.group(1)) * 60 + int(match.group(2))
|
| 740 |
+
return 60
|
| 741 |
+
|
| 742 |
+
def call_with_retry(client, *args, api_name, max_retries=5):
|
| 743 |
+
"""Call API with exponential backoff retry."""
|
| 744 |
+
delay = 10
|
| 745 |
+
|
| 746 |
+
for attempt in range(max_retries):
|
| 747 |
+
try:
|
| 748 |
+
return client.predict(*args, api_name=api_name)
|
| 749 |
+
except Exception as e:
|
| 750 |
+
error_msg = str(e)
|
| 751 |
+
if "exceeded your GPU quota" in error_msg:
|
| 752 |
+
wait_time = extract_wait_time(error_msg)
|
| 753 |
+
actual_delay = max(delay, wait_time + 5)
|
| 754 |
+
print(f"⏳ GPU quota exhausted. Waiting {actual_delay}s... (attempt {attempt + 1})")
|
| 755 |
+
time.sleep(actual_delay)
|
| 756 |
+
delay *= 2
|
| 757 |
+
else:
|
| 758 |
+
raise
|
| 759 |
+
raise Exception("Max retries reached")
|
| 760 |
+
|
| 761 |
+
# Usage
|
| 762 |
+
result = call_with_retry(
|
| 763 |
+
client,
|
| 764 |
+
handle_file("image.jpg"),
|
| 765 |
+
"car", "sam3", False, 4, 0.3, False, "yolo", "",
|
| 766 |
+
api_name="/process"
|
| 767 |
+
)
|
| 768 |
+
```
|
| 769 |
+
|
| 770 |
+
---
|
| 771 |
+
|
| 772 |
+
## Using with cURL
|
| 773 |
+
|
| 774 |
+
### Upload File First
|
| 775 |
+
|
| 776 |
+
```bash
|
| 777 |
+
# Upload image
|
| 778 |
+
FILE_URL=$(curl -s -X POST http://localhost:7860/upload \
|
| 779 |
+
-F "files=@image.jpg" | \
|
| 780 |
+
python3 -c "import sys, json; print(json.load(sys.stdin)[0])")
|
| 781 |
+
```
|
| 782 |
+
|
| 783 |
+
### Call Endpoints
|
| 784 |
+
|
| 785 |
+
```bash
|
| 786 |
+
# Segment
|
| 787 |
+
curl -X POST http://localhost:7860/api/segment \
|
| 788 |
+
-H "Content-Type: application/json" \
|
| 789 |
+
-d "{\"data\": [\"$FILE_URL\", \"car, person\", \"sam3\", false]}"
|
| 790 |
+
|
| 791 |
+
# Compress (no mask)
|
| 792 |
+
curl -X POST http://localhost:7860/api/compress \
|
| 793 |
+
-H "Content-Type: application/json" \
|
| 794 |
+
-d "{\"data\": [\"$FILE_URL\", null, 4, 0.3]}"
|
| 795 |
+
|
| 796 |
+
# Detect
|
| 797 |
+
curl -X POST http://localhost:7860/api/detect \
|
| 798 |
+
-H "Content-Type: application/json" \
|
| 799 |
+
-d "{\"data\": [\"$FILE_URL\", \"yolo\", \"\", 0.25, false]}"
|
| 800 |
+
|
| 801 |
+
# Full pipeline
|
| 802 |
+
curl -X POST http://localhost:7860/api/process \
|
| 803 |
+
-H "Content-Type: application/json" \
|
| 804 |
+
-d "{\"data\": [\"$FILE_URL\", \"car, person\", \"sam3\", 4, 0.3, true, \"yolo\", \"\"]}"
|
| 805 |
+
```
|
| 806 |
+
|
| 807 |
+
---
|
| 808 |
+
|
| 809 |
+
## Performance Guide
|
| 810 |
+
|
| 811 |
+
### Choosing Segmentation Methods
|
| 812 |
+
|
| 813 |
+
**Use Pixel-Perfect Segmentation when:**
|
| 814 |
+
- You need precise object boundaries
|
| 815 |
+
- Working with single images or small videos
|
| 816 |
+
- Quality is more important than speed
|
| 817 |
+
- Computing time/power is not constrained
|
| 818 |
+
|
| 819 |
+
**Use Fast Segmentation (fake_*) when:**
|
| 820 |
+
- Processing large videos or real-time streams
|
| 821 |
+
- Speed is critical (2-3x faster)
|
| 822 |
+
- Rectangular masks are acceptable
|
| 823 |
+
- Need temporal consistency (tracking maintains object IDs)
|
| 824 |
+
|
| 825 |
+
### Performance Benchmarks
|
| 826 |
+
|
| 827 |
+
**Video Processing (480p, 30 frames):**
|
| 828 |
+
|
| 829 |
+
| Method | Speed | Use Case |
|
| 830 |
+
|--------|-------|----------|
|
| 831 |
+
| `fake_yolo` | ~70 fps | Real-time video, fastest |
|
| 832 |
+
| `fake_yolo_botsort` | ~65 fps | Real-time with robust tracking |
|
| 833 |
+
| `fake_detr` | ~40 fps | Good speed + accuracy balance |
|
| 834 |
+
| `fake_fasterrcnn` | ~30 fps | Accurate detection |
|
| 835 |
+
| `yolo` (pixel-perfect) | ~30 fps | Instance segmentation |
|
| 836 |
+
| `sam3` | ~15 fps | Prompt-based, highest flexibility |
|
| 837 |
+
| `mask2former` | ~20 fps | Panoptic segmentation |
|
| 838 |
+
|
| 839 |
+
**Detection Performance (with batch support):**
|
| 840 |
+
|
| 841 |
+
| Detector | Single-Frame | Batch (30 frames) | Speedup |
|
| 842 |
+
|----------|--------------|-------------------|---------|
|
| 843 |
+
| YOLO26x | ~40 fps | ~70 fps | 1.75x |
|
| 844 |
+
| DETR | ~15 fps | ~40 fps | 2.67x |
|
| 845 |
+
| Faster R-CNN | ~12 fps | ~30 fps | 2.50x |
|
| 846 |
+
|
| 847 |
+
### Example: Fast Video Processing
|
| 848 |
+
|
| 849 |
+
```python
|
| 850 |
+
from gradio_client import Client, handle_file
|
| 851 |
+
import json
|
| 852 |
+
import time
|
| 853 |
+
|
| 854 |
+
client = Client("http://localhost:7860")
|
| 855 |
+
|
| 856 |
+
# Method 1: Fast fake segmentation (recommended for video)
|
| 857 |
+
start = time.time()
|
| 858 |
+
output1, stats1 = client.predict(
|
| 859 |
+
handle_file("long_video.mp4"),
|
| 860 |
+
"person, car",
|
| 861 |
+
"fake_yolo", # Fast detection + tracking
|
| 862 |
+
"static",
|
| 863 |
+
4,
|
| 864 |
+
0.3,
|
| 865 |
+
15.0,
|
| 866 |
+
500, 5, 30, False, "yolo", None,
|
| 867 |
+
api_name="/process_video"
|
| 868 |
+
)
|
| 869 |
+
fast_time = time.time() - start
|
| 870 |
+
|
| 871 |
+
# Method 2: Pixel-perfect segmentation
|
| 872 |
+
start = time.time()
|
| 873 |
+
output2, stats2 = client.predict(
|
| 874 |
+
handle_file("long_video.mp4"),
|
| 875 |
+
"person, car",
|
| 876 |
+
"yolo", # Pixel-perfect YOLO26x-seg
|
| 877 |
+
"static",
|
| 878 |
+
4,
|
| 879 |
+
0.3,
|
| 880 |
+
15.0,
|
| 881 |
+
500, 5, 30, False, "yolo", None,
|
| 882 |
+
api_name="/process_video"
|
| 883 |
+
)
|
| 884 |
+
perfect_time = time.time() - start
|
| 885 |
+
|
| 886 |
+
stats1_data = json.loads(stats1)
|
| 887 |
+
stats2_data = json.loads(stats2)
|
| 888 |
+
|
| 889 |
+
print(f"Fast segmentation: {fast_time:.2f}s")
|
| 890 |
+
print(f"Pixel-perfect: {perfect_time:.2f}s")
|
| 891 |
+
print(f"Speedup: {perfect_time/fast_time:.2f}x faster")
|
| 892 |
+
print(f"Compression ratio (fast): {stats1_data['compression_ratio']:.2f}x")
|
| 893 |
+
print(f"Compression ratio (perfect): {stats2_data['compression_ratio']:.2f}x")
|
| 894 |
+
```
|
| 895 |
+
|
| 896 |
+
### Example: Tracker Comparison
|
| 897 |
+
|
| 898 |
+
```python
|
| 899 |
+
# Test different trackers with same detector
|
| 900 |
+
trackers = {
|
| 901 |
+
"ByteTrack (default)": "fake_yolo",
|
| 902 |
+
"BoTSORT": "fake_yolo_botsort",
|
| 903 |
+
}
|
| 904 |
+
|
| 905 |
+
for name, method in trackers.items():
|
| 906 |
+
output, stats = client.predict(
|
| 907 |
+
handle_file("test_video.mp4"),
|
| 908 |
+
"person",
|
| 909 |
+
method,
|
| 910 |
+
"static",
|
| 911 |
+
4, 0.3, 15.0,
|
| 912 |
+
500, 5, 30, False, "yolo", None,
|
| 913 |
+
api_name="/process_video"
|
| 914 |
+
)
|
| 915 |
+
|
| 916 |
+
stats_data = json.loads(stats)
|
| 917 |
+
print(f"{name}: {stats_data['avg_roi_coverage']:.2f}% avg coverage")
|
| 918 |
+
```
|
| 919 |
+
|
| 920 |
+
---
|
| 921 |
+
|
| 922 |
+
## Example Scripts
|
| 923 |
+
|
| 924 |
+
### Batch Image Processing
|
| 925 |
+
|
| 926 |
+
```python
|
| 927 |
+
from gradio_client import Client, handle_file
|
| 928 |
+
from pathlib import Path
|
| 929 |
+
|
| 930 |
+
client = Client("http://localhost:7860")
|
| 931 |
+
output_dir = Path("compressed_output")
|
| 932 |
+
output_dir.mkdir(exist_ok=True)
|
| 933 |
+
|
| 934 |
+
for img_path in Path("images").glob("*.jpg"):
|
| 935 |
+
print(f"Processing {img_path.name}...")
|
| 936 |
+
|
| 937 |
+
compressed, mask, bpp, ratio, coverage, _ = client.predict(
|
| 938 |
+
handle_file(str(img_path)),
|
| 939 |
+
"car, person",
|
| 940 |
+
"sam3",
|
| 941 |
+
4, 0.3,
|
| 942 |
+
False, "", "",
|
| 943 |
+
api_name="/process"
|
| 944 |
+
)
|
| 945 |
+
|
| 946 |
+
# Save compressed image
|
| 947 |
+
output_path = output_dir / f"compressed_{img_path.name}"
|
| 948 |
+
with open(output_path, "wb") as f:
|
| 949 |
+
f.write(Path(compressed).read_bytes())
|
| 950 |
+
|
| 951 |
+
print(f" BPP: {bpp:.4f}, Ratio: {ratio:.2f}x, ROI: {coverage*100:.2f}%")
|
| 952 |
+
```
|
| 953 |
+
|
| 954 |
+
### Video Processing with Mask Caching
|
| 955 |
+
|
| 956 |
+
```python
|
| 957 |
+
from gradio_client import Client, handle_file
|
| 958 |
+
import json
|
| 959 |
+
|
| 960 |
+
client = Client("http://localhost:7860")
|
| 961 |
+
video_path = "input_video.mp4"
|
| 962 |
+
|
| 963 |
+
# Step 1: Segment video (one-time cost)
|
| 964 |
+
mask_file, seg_stats = client.predict(
|
| 965 |
+
handle_file(video_path),
|
| 966 |
+
"person, car",
|
| 967 |
+
"sam3",
|
| 968 |
+
False, # return mask file
|
| 969 |
+
15.0,
|
| 970 |
+
api_name="/segment_video"
|
| 971 |
+
)
|
| 972 |
+
print(f"Segmented video, masks saved to: {mask_file}")
|
| 973 |
+
|
| 974 |
+
# Step 2: Compress with different settings, reusing masks
|
| 975 |
+
for quality in [3, 4, 5]:
|
| 976 |
+
compressed, comp_stats = client.predict(
|
| 977 |
+
handle_file(video_path),
|
| 978 |
+
mask_file, # reuse cached masks
|
| 979 |
+
quality,
|
| 980 |
+
0.3,
|
| 981 |
+
15.0,
|
| 982 |
+
api_name="/compress_video"
|
| 983 |
+
)
|
| 984 |
+
stats = json.loads(comp_stats)
|
| 985 |
+
print(f"Quality {quality}: {stats['compression_ratio']}x compression")
|
| 986 |
+
```
|
| 987 |
+
|
| 988 |
+
### Detection Comparison (Original vs Compressed)
|
| 989 |
+
|
| 990 |
+
```python
|
| 991 |
+
from gradio_client import Client, handle_file
|
| 992 |
+
import json
|
| 993 |
+
|
| 994 |
+
client = Client("http://localhost:7860")
|
| 995 |
+
image = "street_scene.jpg"
|
| 996 |
+
|
| 997 |
+
# Detect on original
|
| 998 |
+
_, dets_orig = client.predict(
|
| 999 |
+
handle_file(image), "yolo", "", 0.25, False,
|
| 1000 |
+
api_name="/detect"
|
| 1001 |
+
)
|
| 1002 |
+
orig_count = len(json.loads(dets_orig))
|
| 1003 |
+
print(f"Original: {orig_count} detections")
|
| 1004 |
+
|
| 1005 |
+
# Compress and detect
|
| 1006 |
+
compressed, _, bpp, ratio, _, dets_comp = client.predict(
|
| 1007 |
+
handle_file(image),
|
| 1008 |
+
"car, person, road",
|
| 1009 |
+
"sam3",
|
| 1010 |
+
4, 0.3,
|
| 1011 |
+
True, "yolo", "",
|
| 1012 |
+
api_name="/process"
|
| 1013 |
+
)
|
| 1014 |
+
comp_count = len(json.loads(dets_comp))
|
| 1015 |
+
|
| 1016 |
+
retention = comp_count / orig_count * 100 if orig_count else 0
|
| 1017 |
+
print(f"Compressed ({ratio:.2f}x): {comp_count} detections")
|
| 1018 |
+
print(f"Detection retention: {retention:.1f}%")
|
| 1019 |
+
```
|
| 1020 |
+
|
| 1021 |
+
---
|
| 1022 |
+
|
| 1023 |
+
## Additional Resources
|
| 1024 |
+
|
| 1025 |
+
- **Web UI**: Visit `http://localhost:7860` for interactive interface
|
| 1026 |
+
- **GitHub**: See repository for source code and examples
|
| 1027 |
+
- **Model Checkpoints**: Available in `checkpoints/` directory
|
| 1028 |
+
- **Test Images**: Sample images in `data/images/` directory
|
| 1029 |
+
|
README.md
CHANGED
|
@@ -1,13 +1,513 @@
|
|
| 1 |
---
|
| 2 |
-
title: Contextual Communication Demo
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.
|
| 8 |
-
python_version: '3.12'
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Contextual Communication Demo
|
| 3 |
+
emoji: "📡"
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "6.2.0"
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# Contextual Communication Demo
|
| 13 |
+
|
| 14 |
+
An interactive demo for **contextual communication in bandwidth-degraded environments** (e.g., ISR collection from drones). The core idea is **context-aware compression**: transmit an extremely compact latent representation while ensuring the **decoded output remains useful for downstream decision-making** (e.g., object detection).
|
| 15 |
+
|
| 16 |
+
This repository implements **contextual spatial compression** for EO/IR-style imagery using an ROI-aware learned image compression model (TIC-style VAE) guided by segmentation masks.
|
| 17 |
+
|
| 18 |
+
## Features
|
| 19 |
+
|
| 20 |
+
- **Contextual (ROI) compression**: preserves fidelity in mission-relevant regions while aggressively compressing non-relevant background.
|
| 21 |
+
- **Mission-driven context extraction**: map a mission prompt to ROI masks via multiple segmentation strategies:
|
| 22 |
+
- Class-based segmentation (SegFormer / YOLO / Mask2Former / Mask R-CNN)
|
| 23 |
+
- Prompt/referring segmentation (SAM3)
|
| 24 |
+
- Optional object detection overlays to visualize task retention on the decoded image
|
| 25 |
+
- **Two operator knobs** for bandwidth adaptation:
|
| 26 |
+
- **Background preservation** ($\sigma$, 0.01–1.0): lower = more background degradation
|
| 27 |
+
- **Transmission quality** (checkpoint/lambda selection): higher = larger payload / better reconstruction
|
| 28 |
+
- **CLI tools** for segmentation, ROI compression, and before/after detection retention.
|
| 29 |
+
|
| 30 |
+
## Setup
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
pip install -r requirements.txt
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
Checkpoints are expected under `checkpoints/` (e.g., `checkpoints/tic_lambda_0.0483.pth.tar`).
|
| 37 |
+
|
| 38 |
+
By default, model weights/caches downloaded by detection/segmentation backends are also stored under `checkpoints/`:
|
| 39 |
+
|
| 40 |
+
- Hugging Face models under `checkpoints/hf/`
|
| 41 |
+
- Torch/torchvision weights under `checkpoints/torch/`
|
| 42 |
+
|
| 43 |
+
## Usage
|
| 44 |
+
|
| 45 |
+
### Interactive Demo (Hugging Face Spaces / Local)
|
| 46 |
+
|
| 47 |
+
This repo includes a Gradio app intended for Hugging Face Spaces (`app_file: app.py`). To run locally:
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
python app.py
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
In the UI:
|
| 54 |
+
|
| 55 |
+
- Enter a **Mission** and choose a **Context Extraction Method (ROI)**.
|
| 56 |
+
- Tune the two knobs to match bandwidth constraints:
|
| 57 |
+
- **Transmission quality** (checkpoint selection)
|
| 58 |
+
- **Background preservation** ($\sigma$)
|
| 59 |
+
- Optionally enable **object detection overlays**.
|
| 60 |
+
|
| 61 |
+
Note: the app includes a **Video** tab placeholder (inactive).
|
| 62 |
+
|
| 63 |
+
### CLI: Contextual Spatial Compression (Images)
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
python roi_compressor.py \
|
| 67 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 68 |
+
--output results/compressed.jpg \
|
| 69 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 70 |
+
--sigma 0.3 \
|
| 71 |
+
--seg-method yolo \
|
| 72 |
+
--seg-classes car \
|
| 73 |
+
--highlight
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
Key arguments:
|
| 77 |
+
|
| 78 |
+
- `--sigma`: background quality (lower = more compression)
|
| 79 |
+
- `--seg-method`: `segformer`, `yolo`, `mask2former`, `maskrcnn`
|
| 80 |
+
- `--load-mask`: bypass segmentation using a precomputed mask
|
| 81 |
+
|
| 82 |
+
### CLI: Segmentation Only
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
python roi_segmenter.py \
|
| 86 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 87 |
+
--output results/mask.png \
|
| 88 |
+
--method segformer \
|
| 89 |
+
--classes car \
|
| 90 |
+
--visualize
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
Prompt-based segmentation (SAM3):
|
| 94 |
+
|
| 95 |
+
```bash
|
| 96 |
+
python roi_segmenter.py \
|
| 97 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 98 |
+
--output results/mask.png \
|
| 99 |
+
--method sam3 \
|
| 100 |
+
--prompt "a car" \
|
| 101 |
+
--visualize
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
### CLI: Detection Retention (Before vs After)
|
| 105 |
+
|
| 106 |
+
Compare original vs already-compressed:
|
| 107 |
+
|
| 108 |
+
```bash
|
| 109 |
+
python roi_detection_eval.py \
|
| 110 |
+
--before data/images/car/0016cf15fa4d4e16.jpg \
|
| 111 |
+
--after results/compressed.jpg \
|
| 112 |
+
--detectors yolo fasterrcnn detr \
|
| 113 |
+
--viz-dir results/det_viz
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
Or generate the "after" image via ROI compression and then evaluate:
|
| 117 |
+
|
| 118 |
+
```bash
|
| 119 |
+
python roi_detection_eval.py \
|
| 120 |
+
--before data/images/car/0016cf15fa4d4e16.jpg \
|
| 121 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 122 |
+
--sigma 0.3 \
|
| 123 |
+
--seg-method yolo --seg-classes car \
|
| 124 |
+
--detectors yolo fasterrcnn \
|
| 125 |
+
--save-after results/after.jpg \
|
| 126 |
+
--viz-dir results/det_viz
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
Open-vocabulary example (YOLO-World):
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
python roi_detection_eval.py \
|
| 133 |
+
--before data/images/person/kodim04.png \
|
| 134 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 135 |
+
--sigma 0.3 \
|
| 136 |
+
--seg-method yolo --seg-classes person \
|
| 137 |
+
--detectors yolo_world \
|
| 138 |
+
--open-vocab-classes "person,car" \
|
| 139 |
+
--viz-dir results/det_viz
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
## Project Structure
|
| 143 |
+
|
| 144 |
+
```
|
| 145 |
+
.
|
| 146 |
+
├── app.py # Gradio demo (Hugging Face Spaces)
|
| 147 |
+
├── model_cache.py # Cache routing to `checkpoints/`
|
| 148 |
+
├── roi_compressor.py # CLI: contextual (ROI) image compression
|
| 149 |
+
├── roi_segmenter.py # CLI: ROI mask generation
|
| 150 |
+
├── roi_detection_eval.py # CLI: before/after detection retention
|
| 151 |
+
├── segmentation/ # Segmenters + factory
|
| 152 |
+
├── detection/ # Detectors + factory
|
| 153 |
+
├── vae/ # ROI-aware TIC model + compression utils
|
| 154 |
+
├── checkpoints/ # Compression checkpoints + model caches
|
| 155 |
+
├── data/images/ # Sample images
|
| 156 |
+
├── examples.sh
|
| 157 |
+
└── _segmentation_comparison.ipynb
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
## Modular API
|
| 161 |
+
|
| 162 |
+
Segmentation:
|
| 163 |
+
|
| 164 |
+
```python
|
| 165 |
+
from segmentation import create_segmenter
|
| 166 |
+
|
| 167 |
+
segmenter = create_segmenter("yolo", device="cuda", conf_threshold=0.3)
|
| 168 |
+
mask = segmenter(image, target_classes=["car", "person"])
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
Compression:
|
| 172 |
+
|
| 173 |
+
```python
|
| 174 |
+
from vae import load_checkpoint, compress_image
|
| 175 |
+
|
| 176 |
+
model = load_checkpoint("checkpoints/tic_lambda_0.0483.pth.tar", device="cuda")
|
| 177 |
+
out = compress_image(image, mask, model, sigma=0.3, device="cuda")
|
| 178 |
+
compressed = out["compressed"]
|
| 179 |
+
bpp = out["bpp"]
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
## Notes
|
| 183 |
+
|
| 184 |
+
- OpenCV is included via `opencv-python-headless` (recommended for server/Spaces environments).
|
| 185 |
+
- Some backends download weights on first use; caches are routed under `checkpoints/`.
|
| 186 |
+
- Output directories like `results/` are created at runtime by the CLIs.
|
| 187 |
+
---
|
| 188 |
+
title: Contextual Communication Demo
|
| 189 |
+
emoji: "📡"
|
| 190 |
+
colorFrom: blue
|
| 191 |
+
colorTo: purple
|
| 192 |
+
sdk: gradio
|
| 193 |
+
sdk_version: "6.2.0"
|
| 194 |
+
app_file: app.py
|
| 195 |
+
pinned: false
|
| 196 |
+
---
|
| 197 |
+
|
| 198 |
+
# Contextual Communication Demo
|
| 199 |
+
|
| 200 |
+
An interactive demo for **contextual communication in bandwidth-degraded environments** (e.g., ISR collection from drones). The core idea is **context-aware compression**: transmit an extremely compact latent representation while ensuring the **decoded output remains useful for downstream decision-making** (e.g., object detection).
|
| 201 |
+
|
| 202 |
+
This repository implements **contextual spatial compression** for EO/IR-style imagery using an ROI-aware learned image compression model (TIC-style VAE) guided by segmentation masks.
|
| 203 |
+
|
| 204 |
+
## Features
|
| 205 |
+
|
| 206 |
+
- **Contextual (ROI) compression**: Preserves fidelity in mission-relevant regions while aggressively compressing non-relevant background.
|
| 207 |
+
- **Mission-driven context extraction**: A mission prompt can be mapped to ROI masks via multiple segmentation strategies:
|
| 208 |
+
- **Class-based segmentation** (e.g., SegFormer / YOLO / Mask2Former / Mask R-CNN)
|
| 209 |
+
- **Prompt/referring segmentation** (SAM3)
|
| 210 |
+
- Optional **object detection overlays** to evaluate task retention on decoded outputs
|
| 211 |
+
- **Two operator knobs** for bandwidth adaptation:
|
| 212 |
+
- **Background preservation** (`sigma`, 0.01–1.0): lower = more background degradation
|
| 213 |
+
- **Overall quality level** (checkpoint/lambda selection): higher = larger file / better reconstruction
|
| 214 |
+
- **Visualization**: Compare input vs decoded output and optionally highlight context regions.
|
| 215 |
+
- **CLI tools**: Scripts for segmentation, ROI compression, and before/after detection eval.
|
| 216 |
+
|
| 217 |
+
## Setup
|
| 218 |
+
|
| 219 |
+
1. **Install Dependencies**:
|
| 220 |
+
```bash
|
| 221 |
+
pip install -r requirements.txt
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
2. **Model Checkpoints**:
|
| 225 |
+
Checkpoints are located in `checkpoints/` directory. Main checkpoint: `checkpoints/tic_lambda_0.0483.pth.tar`
|
| 226 |
+
|
| 227 |
+
By default, model weights/caches downloaded by detection/segmentation backends are also stored under `checkpoints/`
|
| 228 |
+
(Hugging Face models under `checkpoints/hf/`, torchvision weights under `checkpoints/torch/`).
|
| 229 |
+
|
| 230 |
+
## Usage
|
| 231 |
+
|
| 232 |
+
### Interactive Demo (Hugging Face Spaces / Local)
|
| 233 |
+
|
| 234 |
+
This repo includes a Gradio app intended for Hugging Face Spaces (`app_file: app.py`). To run locally:
|
| 235 |
+
|
| 236 |
+
```bash
|
| 237 |
+
python app.py
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
In the UI:
|
| 241 |
+
|
| 242 |
+
- Enter a **Mission** and choose a **Context Extraction Method (ROI)**.
|
| 243 |
+
- Tune the two knobs to match bandwidth constraints:
|
| 244 |
+
- **Transmission quality** (checkpoint selection)
|
| 245 |
+
- **Background preservation** ($\sigma$)
|
| 246 |
+
- Optionally enable **object detection overlays** to visualize task retention on the decoded image.
|
| 247 |
+
|
| 248 |
+
Note: the app includes a **Video** tab placeholder (inactive).
|
| 249 |
+
|
| 250 |
+
### Contextual Spatial Compression (Images)
|
| 251 |
+
|
| 252 |
+
Run the compression script with an input image:
|
| 253 |
+
|
| 254 |
+
```bash
|
| 255 |
+
python roi_compressor.py \
|
| 256 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 257 |
+
--output results/compressed.jpg \
|
| 258 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 259 |
+
--sigma 0.3 \
|
| 260 |
+
--seg-classes car \
|
| 261 |
+
--highlight
|
| 262 |
+
```
|
| 263 |
+
|
| 264 |
+
**Arguments:**
|
| 265 |
+
- `--input`: Path to input image.
|
| 266 |
+
- `--output`: Path to save compressed image.
|
| 267 |
+
- `--checkpoint`: Path to model checkpoint.
|
| 268 |
+
- `--sigma`: Background quality factor (lower = more compression). Default: 0.3.
|
| 269 |
+
- `--lambda`: Rate-distortion tradeoff parameter (default: 0.0483).
|
| 270 |
+
- `--seg-method`: Segmentation method (`segformer`, `yolo`, `mask2former`, `maskrcnn`). Default: `segformer`.
|
| 271 |
+
- `--seg-classes`: List of classes to treat as ROI (e.g., `car`, `person`).
|
| 272 |
+
- `--highlight`: Save a comparison grid with ROI highlighted.
|
| 273 |
+
|
| 274 |
+
Tip: you can bypass segmentation by providing `--load-mask`.
|
| 275 |
+
|
| 276 |
+
### Segmentation Only
|
| 277 |
+
|
| 278 |
+
Generate segmentation masks without compression:
|
| 279 |
+
|
| 280 |
+
```bash
|
| 281 |
+
python roi_segmenter.py \
|
| 282 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 283 |
+
--output results/mask.png \
|
| 284 |
+
--method segformer \
|
| 285 |
+
--classes car \
|
| 286 |
+
--visualize
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
Prompt-based segmentation (SAM3):
|
| 290 |
+
|
| 291 |
+
```bash
|
| 292 |
+
python roi_segmenter.py \
|
| 293 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 294 |
+
--output results/mask.png \
|
| 295 |
+
--method sam3 \
|
| 296 |
+
--prompt "a car" \
|
| 297 |
+
--visualize
|
| 298 |
+
```
|
| 299 |
+
|
| 300 |
+
## Project Structure
|
| 301 |
+
|
| 302 |
+
```
|
| 303 |
+
.
|
| 304 |
+
├── app.py # Gradio demo (Hugging Face Spaces)
|
| 305 |
+
├── README.md
|
| 306 |
+
├── requirements.txt
|
| 307 |
+
├── model_cache.py # Cache routing to `checkpoints/`
|
| 308 |
+
├── examples.sh # Example CLI commands
|
| 309 |
+
├── _segmentation_comparison.ipynb
|
| 310 |
+
├── roi_compressor.py # CLI: contextual (ROI) image compression
|
| 311 |
+
├── roi_segmenter.py # CLI: ROI mask generation
|
| 312 |
+
├── roi_detection_eval.py # CLI: before/after detection retention
|
| 313 |
+
├── checkpoints/ # Compression checkpoints + model caches
|
| 314 |
+
├── data/images/ # Sample images
|
| 315 |
+
├── segmentation/ # Segmenters + factory
|
| 316 |
+
├── detection/ # Detectors + factory
|
| 317 |
+
└── vae/ # ROI-aware TIC model + compression utils
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
## Modular API
|
| 321 |
+
|
| 322 |
+
### Using Segmentation Module
|
| 323 |
+
|
| 324 |
+
```python
|
| 325 |
+
from segmentation import create_segmenter
|
| 326 |
+
|
| 327 |
+
# Create a segmenter
|
| 328 |
+
segmenter = create_segmenter('yolo', device='cuda', conf_threshold=0.3)
|
| 329 |
+
|
| 330 |
+
# Segment image
|
| 331 |
+
mask = segmenter(image, target_classes=['car', 'person'])
|
| 332 |
+
```
|
| 333 |
+
|
| 334 |
+
### Using Compression Module
|
| 335 |
+
|
| 336 |
+
```python
|
| 337 |
+
from vae import load_checkpoint, compress_image
|
| 338 |
+
from PIL import Image
|
| 339 |
+
|
| 340 |
+
# Load model
|
| 341 |
+
model = load_checkpoint('checkpoints/tic_lambda_0.0483.pth.tar', device='cuda')
|
| 342 |
+
|
| 343 |
+
# Compress with ROI mask
|
| 344 |
+
result = compress_image(image, mask, model, sigma=0.3, device='cuda')
|
| 345 |
+
compressed_img = result['compressed']
|
| 346 |
+
bpp = result['bpp']
|
| 347 |
+
```
|
| 348 |
+
## Object Detection (New)
|
| 349 |
+
|
| 350 |
+
An extendable object detection module is available in `detection/` with multiple implemented backends:
|
| 351 |
+
|
| 352 |
+
- YOLO (Ultralytics)
|
| 353 |
+
- YOLO-World (Ultralytics, open-vocabulary)
|
| 354 |
+
- Faster R-CNN (torchvision)
|
| 355 |
+
- RetinaNet (torchvision)
|
| 356 |
+
- SSD (torchvision)
|
| 357 |
+
- FCOS (torchvision)
|
| 358 |
+
- DETR (transformers)
|
| 359 |
+
- Deformable DETR (transformers, if supported by your installed version)
|
| 360 |
+
- EfficientDet (optional, requires `effdet`)
|
| 361 |
+
- Grounding DINO (transformers, open-vocabulary)
|
| 362 |
+
|
| 363 |
+
Open-vocabulary detectors (YOLO-World / Grounding DINO) require text prompts/classes at runtime.
|
| 364 |
+
|
| 365 |
+
### Evaluate Detection Before/After ROI Compression
|
| 366 |
+
|
| 367 |
+
Compare an original image vs an already-compressed image:
|
| 368 |
+
|
| 369 |
+
```bash
|
| 370 |
+
python roi_detection_eval.py \
|
| 371 |
+
--before data/images/car/0016cf15fa4d4e16.jpg \
|
| 372 |
+
--after results/compressed.jpg \
|
| 373 |
+
--detectors yolo fasterrcnn detr \
|
| 374 |
+
--viz-dir results/det_viz
|
| 375 |
+
```
|
| 376 |
+
|
| 377 |
+
Or generate the "after" image via ROI compression and then evaluate:
|
| 378 |
+
|
| 379 |
+
```bash
|
| 380 |
+
python roi_detection_eval.py \
|
| 381 |
+
--before data/images/car/0016cf15fa4d4e16.jpg \
|
| 382 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 383 |
+
--sigma 0.3 \
|
| 384 |
+
--seg-method yolo --seg-classes car \
|
| 385 |
+
--detectors yolo fasterrcnn \
|
| 386 |
+
--save-after results/after.jpg \
|
| 387 |
+
--viz-dir results/det_viz
|
| 388 |
+
|
| 389 |
+
```
|
| 390 |
+
|
| 391 |
+
Open-vocabulary example (YOLO-World):
|
| 392 |
+
|
| 393 |
+
```bash
|
| 394 |
+
python roi_detection_eval.py \
|
| 395 |
+
--before data/images/person/kodim04.png \
|
| 396 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 397 |
+
--sigma 0.3 \
|
| 398 |
+
--seg-method yolo --seg-classes person \
|
| 399 |
+
--detectors yolo_world \
|
| 400 |
+
--open-vocab-classes "person,car" \
|
| 401 |
+
--viz-dir results/det_viz
|
| 402 |
+
```
|
| 403 |
+
|
| 404 |
+
Open-vocabulary example (Grounding DINO):
|
| 405 |
+
|
| 406 |
+
```bash
|
| 407 |
+
python roi_detection_eval.py \
|
| 408 |
+
--before data/images/car/0016cf15fa4d4e16.jpg \
|
| 409 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 410 |
+
--sigma 0.3 \
|
| 411 |
+
--seg-method yolo --seg-classes car \
|
| 412 |
+
--detectors grounding_dino \
|
| 413 |
+
--open-vocab-classes "car,person" \
|
| 414 |
+
--viz-dir results/det_viz
|
| 415 |
+
```
|
| 416 |
+
|
| 417 |
+
## Programmatic API
|
| 418 |
+
|
| 419 |
+
The application exposes a Gradio API for programmatic access to all features:
|
| 420 |
+
|
| 421 |
+
### Image API
|
| 422 |
+
- `/segment` - Segment image → mask or overlay
|
| 423 |
+
- `/compress` - Compress image with optional ROI mask
|
| 424 |
+
- `/detect` - Run object detection → JSON or overlay
|
| 425 |
+
- `/process` - Full pipeline: segment → compress → detect
|
| 426 |
+
|
| 427 |
+
### Video API (Buffered)
|
| 428 |
+
- `/segment_video` - Segment video → mask file or overlay video
|
| 429 |
+
- `/compress_video` - Compress video with optional cached masks
|
| 430 |
+
- `/detect_video` - Run detection on video → JSON or overlay video
|
| 431 |
+
- `/process_video` - Full pipeline with static/dynamic modes
|
| 432 |
+
|
| 433 |
+
### Video API (Streaming - NEW!)
|
| 434 |
+
- `/stream_process_video` - Stream compressed chunks progressively (HLS-style)
|
| 435 |
+
- `/stream_compress_video` - Stream chunks with pre-computed masks
|
| 436 |
+
|
| 437 |
+
**Key difference**: Streaming endpoints yield chunks as they're produced (low latency, ~1 second for first chunk) instead of buffering the entire video. Perfect for real-time streaming applications.
|
| 438 |
+
|
| 439 |
+
See [API.md](API.md) for complete documentation with examples.
|
| 440 |
+
See [STREAMING_API.md](STREAMING_API.md) for streaming API guide and comparison.
|
| 441 |
+
|
| 442 |
+
### Quick Example
|
| 443 |
+
|
| 444 |
+
```python
|
| 445 |
+
from gradio_client import Client, handle_file
|
| 446 |
+
|
| 447 |
+
client = Client("http://localhost:7860")
|
| 448 |
+
|
| 449 |
+
# Image: segment → compress → detect
|
| 450 |
+
compressed, mask, bpp, ratio, coverage, detections = client.predict(
|
| 451 |
+
handle_file("image.jpg"),
|
| 452 |
+
"car, person", # mission prompt
|
| 453 |
+
"sam3", # ROI method
|
| 454 |
+
4, # quality level (1-5)
|
| 455 |
+
0.3, # sigma (background preservation)
|
| 456 |
+
True, # run detection
|
| 457 |
+
"yolo", # detection method
|
| 458 |
+
"", # detection classes
|
| 459 |
+
api_name="/process"
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
# Video: streaming compression (chunk-by-chunk)
|
| 463 |
+
chunk_stream = client.submit(
|
| 464 |
+
handle_file("video.mp4"),
|
| 465 |
+
"person, car",
|
| 466 |
+
"sam3", "static",
|
| 467 |
+
4, 0.3, 15.0,
|
| 468 |
+
api_name="/stream_process_video"
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
for chunk_json in chunk_stream:
|
| 472 |
+
chunk = json.loads(chunk_json)
|
| 473 |
+
if chunk.get("status") == "complete":
|
| 474 |
+
break
|
| 475 |
+
print(f"Chunk {chunk['chunk_index']}: {len(chunk['frames'])} frames")
|
| 476 |
+
```
|
| 477 |
+
|
| 478 |
+
### JavaScript/Frontend Integration
|
| 479 |
+
|
| 480 |
+
**Yes, streaming works great with JavaScript!** The `@gradio/client` package fully supports async iterators for streaming:
|
| 481 |
+
|
| 482 |
+
```javascript
|
| 483 |
+
import { Client } from "@gradio/client";
|
| 484 |
+
|
| 485 |
+
const client = await Client.connect("http://localhost:7860");
|
| 486 |
+
const stream = client.submit("/stream_process_video", {
|
| 487 |
+
video_path: videoFile,
|
| 488 |
+
prompt: "person, car",
|
| 489 |
+
segmentation_method: "sam3",
|
| 490 |
+
mode: "static",
|
| 491 |
+
quality: 4,
|
| 492 |
+
sigma: 0.3,
|
| 493 |
+
output_fps: 15.0,
|
| 494 |
+
frame_format: "jpeg",
|
| 495 |
+
frame_quality: 85
|
| 496 |
+
});
|
| 497 |
+
|
| 498 |
+
for await (const msg of stream) {
|
| 499 |
+
const chunk = JSON.parse(msg.data);
|
| 500 |
+
if (chunk.status === "complete") break;
|
| 501 |
+
|
| 502 |
+
// Display frames immediately
|
| 503 |
+
displayFrame(`data:image/jpeg;base64,${chunk.frames[0]}`);
|
| 504 |
+
}
|
| 505 |
+
```
|
| 506 |
+
|
| 507 |
+
**Complete examples available:**
|
| 508 |
+
- [examples/streaming_demo.html](examples/streaming_demo.html) - Standalone HTML demo
|
| 509 |
+
- [examples/streaming_client.ts](examples/streaming_client.ts) - React/TypeScript/Vanilla JS examples
|
| 510 |
+
|
| 511 |
+
See [STREAMING_API.md](STREAMING_API.md) for detailed streaming guide.
|
| 512 |
+
|
| 513 |
+
````
|
_segmentation_comparison.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
checkpoints/tic_lambda_0.0035.pth.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:03f2627ba49ddc117c2235868e179e48cd7fbb343b0ab637f6bac575f447b44f
|
| 3 |
+
size 93931400
|
checkpoints/tic_lambda_0.013.pth.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:821df1b78110d8bb6d01809a55b64ea1da8f17ed5e893d33ed258ee180054385
|
| 3 |
+
size 93922614
|
checkpoints/tic_lambda_0.025.pth.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:867f34e9949cc02800d1c8eff48c4f67ff7436d0e8c8208aec572fd78d83affb
|
| 3 |
+
size 168141319
|
checkpoints/tic_lambda_0.0483.pth.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9686371eed833dc2a6fdd7c07cf3695290c482f53b75db5be662892a45a71430
|
| 3 |
+
size 168229460
|
checkpoints/tic_lambda_0.0932.pth.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0b7ddfa2f5f1b5082d087135a806f11ae65e8f8da5f724c5e510db264d3816c0
|
| 3 |
+
size 168141383
|
detection/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Object detection module.
|
| 2 |
+
|
| 3 |
+
Provides an extendable interface + factory similar to `segmentation/`.
|
| 4 |
+
|
| 5 |
+
Detectors:
|
| 6 |
+
- yolo (Ultralytics)
|
| 7 |
+
- yolo_world (Ultralytics YOLO-World; open-vocabulary)
|
| 8 |
+
- fasterrcnn, retinanet, ssd, fcos (torchvision)
|
| 9 |
+
- efficientdet (optional: effdet)
|
| 10 |
+
- detr, deformable_detr (transformers)
|
| 11 |
+
- grounding_dino (transformers; open-vocabulary)
|
| 12 |
+
|
| 13 |
+
Tracking:
|
| 14 |
+
- SimpleTracker (IoU-based multi-object tracking)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from .base import BaseDetector, Detection
|
| 18 |
+
from .factory import create_detector, register_detector, get_available_detectors
|
| 19 |
+
from .tracker import SimpleTracker, Track, draw_tracks
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"BaseDetector",
|
| 23 |
+
"Detection",
|
| 24 |
+
"create_detector",
|
| 25 |
+
"register_detector",
|
| 26 |
+
"get_available_detectors",
|
| 27 |
+
"SimpleTracker",
|
| 28 |
+
"Track",
|
| 29 |
+
"draw_tracks",
|
| 30 |
+
]
|
detection/base.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Abstract base class for object detection models."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
| 8 |
+
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass(frozen=True)
class Detection:
    """Single detection result.

    Attributes:
        label: Predicted class name.
        score: Detector confidence score.
        bbox_xyxy: Bounding box as [x1, y1, x2, y2] in pixel coordinates.
    """

    label: str
    score: float
    # [x1, y1, x2, y2] in pixel coordinates
    bbox_xyxy: List[float]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class BaseDetector(ABC):
    """Common interface implemented by every object-detector backend.

    Subclasses provide :meth:`load_model`, :meth:`detect` and
    :meth:`get_available_classes`. Weights are loaded lazily: the first
    call through :meth:`__call__` triggers :meth:`load_model` exactly once.
    """

    def __init__(self, device: str = "cuda", **kwargs: Any):
        self.device = device        # requested compute device string
        self.model = None           # backend model object, set by load_model()
        self._is_loaded = False     # lazy-load guard flag

    @abstractmethod
    def load_model(self) -> None:
        """Load weights and prepare for inference."""

    @abstractmethod
    def detect(
        self,
        image: Image.Image,
        conf_threshold: float = 0.25,
        **kwargs: Any,
    ) -> List[Detection]:
        """Run detection and return a list of detections."""

    @abstractmethod
    def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
        """Return the class list (or mapping) supported by this detector.

        For open-vocabulary / prompt-based detectors, return None.
        """

    def ensure_loaded(self) -> None:
        """Load the model on first use; subsequent calls are no-ops."""
        if self._is_loaded:
            return
        self.load_model()
        self._is_loaded = True

    def __call__(
        self,
        image: Image.Image,
        conf_threshold: float = 0.25,
        **kwargs: Any,
    ) -> List[Detection]:
        """Convenience entry point: lazy-load, then run detection."""
        self.ensure_loaded()
        return self.detect(image=image, conf_threshold=conf_threshold, **kwargs)

    def _classes_to_list(self) -> List[str]:
        """Normalise get_available_classes() output into a plain list of names."""
        available = self.get_available_classes()
        if available is None:
            return []
        if isinstance(available, dict):
            # Keys are the class names.
            return list(available)
        # Lists, tuples, sets and any other iterable all collapse the same way.
        try:
            return list(available)  # type: ignore[arg-type]
        except Exception:
            return []

    def supports_label(self, label: str) -> bool:
        """Best-effort check whether the detector has a given label."""
        known = {name.lower() for name in self._classes_to_list()}
        if not known:
            # Unknown vocabulary (e.g. open-vocab detectors): assume supported.
            return True
        return label.lower() in known
|
detection/bytetrack.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ByteTrack: Multi-object tracker using high/low confidence detection matching.
|
| 2 |
+
|
| 3 |
+
Based on: https://github.com/ifzhang/ByteTrack
|
| 4 |
+
Simple, fast, and strong multi-object tracker without ReID features.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Dict, List, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def linear_assignment(cost_matrix: np.ndarray) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Solve the linear assignment problem for a tracking cost matrix.

    Uses the optional `lap` package (Jonker-Volgenant / Hungarian) when it is
    installed and falls back to a greedy matcher otherwise. In both paths,
    pairs whose cost is >= 1e6 are treated as forbidden (gated) and are
    never matched.

    Args:
        cost_matrix: Cost matrix of shape (N, M).

    Returns:
        matches: Array of matched (row, col) index pairs, shape (K, 2).
        unmatched: Tuple of (unmatched_row_indices, unmatched_col_indices).
    """
    try:
        import lap
    except ImportError:
        # Bug fix: the fallback must honour the same 1e6 gating that
        # lap.lapjv enforces via cost_limit, otherwise gated (forbidden)
        # pairs could still be matched when `lap` is not installed.
        return _greedy_assignment(cost_matrix, cost_limit=1e6)

    _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=1e6)
    matches = np.array([[idx, x[idx]] for idx in range(len(x)) if x[idx] >= 0])
    unmatched_a = np.where(x < 0)[0]
    unmatched_b = np.where(y < 0)[0]
    return matches, (unmatched_a, unmatched_b)


def _greedy_assignment(
    cost_matrix: np.ndarray,
    cost_limit: float = float("inf"),
) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Greedy assignment fallback: repeatedly take the cheapest free pair.

    Args:
        cost_matrix: Cost matrix of shape (N, M).
        cost_limit: Pairs with cost >= cost_limit are never matched.
            Defaults to +inf (no gating) for backward compatibility.

    Returns:
        matches: Array of matched (row, col) pairs.
        unmatched: Tuple of (unmatched_row_indices, unmatched_col_indices).
    """
    matches = []
    matched_rows = set()
    matched_cols = set()

    n_rows, n_cols = cost_matrix.shape
    pairs = sorted(
        (cost_matrix[i, j], i, j)
        for i in range(n_rows)
        for j in range(n_cols)
    )

    for cost, i, j in pairs:
        if cost >= cost_limit:
            # Pairs are sorted by cost, so every remaining pair is gated too.
            break
        if i not in matched_rows and j not in matched_cols:
            matches.append([i, j])
            matched_rows.add(i)
            matched_cols.add(j)

    unmatched_rows = np.array([i for i in range(n_rows) if i not in matched_rows])
    unmatched_cols = np.array([j for j in range(n_cols) if j not in matched_cols])

    return np.array(matches) if matches else np.empty((0, 2)), (unmatched_rows, unmatched_cols)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
class STrack:
    """Book-keeping unit for a single tracked object in ByteTrack."""

    track_id: int
    label: str
    tlbr: np.ndarray  # box as [x1, y1, x2, y2]
    score: float

    # Lifecycle state: one of "tracked", "lost", "removed".
    state: str = "tracked"
    frame_id: int = 0
    tracklet_len: int = 0

    # Per-frame snapshots of (frame_id, tlbr copy, score).
    history: List[Tuple[int, np.ndarray, float]] = field(default_factory=list)

    @property
    def tlwh(self) -> np.ndarray:
        """Box in top-left / width / height format."""
        left, top, right, bottom = self.tlbr
        return np.array([left, top, right - left, bottom - top])

    def _absorb(self, new_track: "STrack") -> None:
        """Copy box and confidence from a matched detection."""
        self.tlbr = new_track.tlbr
        self.score = new_track.score

    def _record(self, frame_id: int) -> None:
        """Append the current box/score to the history log."""
        self.history.append((frame_id, self.tlbr.copy(), self.score))

    def activate(self, frame_id: int):
        """Start a brand-new track: assign a fresh id and reset history."""
        self.track_id = self.next_id()
        self.tracklet_len = 0
        self.state = "tracked"
        self.frame_id = frame_id
        self.history = []
        self._record(frame_id)

    def re_activate(self, new_track: 'STrack', frame_id: int):
        """Revive a lost track using a newly matched detection."""
        self._absorb(new_track)
        self.tracklet_len = 0
        self.state = "tracked"
        self.frame_id = frame_id
        self._record(frame_id)

    def update(self, new_track: 'STrack', frame_id: int):
        """Refresh an active track with a matched detection."""
        self._absorb(new_track)
        self.tracklet_len += 1
        self.state = "tracked"
        self.frame_id = frame_id
        self._record(frame_id)

    def mark_lost(self):
        """Flag the track as temporarily lost."""
        self.state = "lost"

    def mark_removed(self):
        """Flag the track as permanently removed."""
        self.state = "removed"

    # Global id counter shared by all STrack instances (not a dataclass field).
    _count = 0

    @staticmethod
    def next_id() -> int:
        """Return the next globally unique track id."""
        STrack._count += 1
        return STrack._count

    @staticmethod
    def reset_id():
        """Reset the global id counter (e.g. between videos)."""
        STrack._count = 0
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class ByteTracker:
    """ByteTrack: multi-object tracker using two-stage confidence matching.

    Per frame, detections are split by score into high/low confidence sets:
      1. High-confidence detections are matched to active tracks (IoU).
      2. Remaining active tracks get a second chance against ALL
         low-confidence detections (the core ByteTrack idea).
      3. Leftover high-confidence detections are matched to lost tracks.
      4. Still-unmatched high-confidence detections start new tracks.
      5. Tracks lost for longer than `max_time_lost` frames are removed.

    No ReID features are used, which keeps the tracker fast and simple.
    """

    def __init__(
        self,
        track_thresh: float = 0.5,
        match_thresh: float = 0.8,
        track_buffer: int = 30,
        frame_rate: int = 30,
    ):
        """
        Args:
            track_thresh: Score threshold separating high/low confidence detections.
            match_thresh: Minimum IoU required for an association.
            track_buffer: Frames (at 30 fps) to keep lost tracks alive.
            frame_rate: Video frame rate, used to scale the lost-track budget.
        """
        self.track_thresh = track_thresh
        self.match_thresh = match_thresh
        self.track_buffer = track_buffer
        self.frame_rate = frame_rate

        self.tracked_stracks: List[STrack] = []
        self.lost_stracks: List[STrack] = []
        self.removed_stracks: List[STrack] = []

        self.frame_id = 0
        # Scale the buffer so the wall-clock persistence of lost tracks is
        # independent of the video frame rate.
        self.max_time_lost = int(frame_rate / 30.0 * track_buffer)

    def reset(self):
        """Reset tracker state (call between videos)."""
        self.tracked_stracks = []
        self.lost_stracks = []
        self.removed_stracks = []
        self.frame_id = 0
        STrack.reset_id()

    def update(self, detections: List[Dict]) -> List[Dict]:
        """Advance the tracker by one frame.

        Args:
            detections: List of dicts with keys "label", "score", "bbox_xyxy".

        Returns:
            List of active-track dicts with track_id, label, bbox_xyxy,
            score, frame_id and tracklet_len.
        """
        self.frame_id += 1
        activated_stracks: List[STrack] = []
        refind_stracks: List[STrack] = []
        lost_stracks: List[STrack] = []
        removed_stracks: List[STrack] = []

        # Separate high and low confidence detections.
        dets_high = [d for d in detections if d["score"] >= self.track_thresh]
        dets_low = [d for d in detections if d["score"] < self.track_thresh]

        detections_high = [self._det_to_strack(d) for d in dets_high]
        detections_low = [self._det_to_strack(d) for d in dets_low]

        # ---- Step 1: Match high-confidence detections with tracked tracks ----
        strack_pool = self.tracked_stracks
        dists = self._iou_distance(strack_pool, detections_high)
        matches, u_track, u_detection = self._matching(dists, strack_pool, detections_high)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            track.update(detections_high[idet], self.frame_id)
            activated_stracks.append(track)

        # ---- Step 2: Match remaining tracks with low-confidence detections ----
        # Bug fix: the second association must consider ALL low-confidence
        # detections. The previous code filtered `detections_low` using
        # `u_detection`, which holds indices into `detections_high`, so valid
        # low-confidence detections were discarded essentially at random.
        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == "tracked"]

        dists = self._iou_distance(r_tracked_stracks, detections_low)
        matches, u_track, _u_detection_low = self._matching(dists, r_tracked_stracks, detections_low)

        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            track.update(detections_low[idet], self.frame_id)
            activated_stracks.append(track)

        # Tracks unmatched after both stages become lost.
        for it in u_track:
            track = r_tracked_stracks[it]
            track.mark_lost()
            lost_stracks.append(track)

        # ---- Step 3: Match unmatched high-confidence detections with lost tracks ----
        detections_high_second = [detections_high[i] for i in u_detection]
        dists = self._iou_distance(self.lost_stracks, detections_high_second)
        matches, u_lost, u_detection = self._matching(dists, self.lost_stracks, detections_high_second)

        for ilost, idet in matches:
            track = self.lost_stracks[ilost]
            track.re_activate(detections_high_second[idet], self.frame_id)
            refind_stracks.append(track)

        # ---- Step 4: Initialize new tracks ----
        for inew in u_detection:
            track = detections_high_second[inew]
            if track.score >= self.track_thresh:
                track.activate(self.frame_id)
                activated_stracks.append(track)

        # ---- Step 5: Remove long-lost tracks ----
        for track in self.lost_stracks:
            if self.frame_id - track.frame_id > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # Rebuild state lists. (A redundant reassignment of
        # `self.tracked_stracks` that was immediately overwritten was removed.)
        self.tracked_stracks = activated_stracks + refind_stracks
        self.lost_stracks = [t for t in self.lost_stracks if t.state == "lost"]
        self.lost_stracks.extend(lost_stracks)
        # Bug fix: filter by identity instead of `in`, which would invoke the
        # STrack dataclass __eq__ and can raise ValueError when comparing the
        # NumPy `tlbr` fields of distinct tracks.
        removed_ids = {id(t) for t in removed_stracks}
        self.lost_stracks = [t for t in self.lost_stracks if id(t) not in removed_ids]
        self.removed_stracks.extend(removed_stracks)

        # Convert to output format.
        return self._stracks_to_output(self.tracked_stracks)

    def _det_to_strack(self, det: Dict) -> STrack:
        """Wrap a raw detection dict in an (unregistered) STrack."""
        return STrack(
            track_id=-1,  # real id is assigned by activate()
            label=det["label"],
            tlbr=np.array(det["bbox_xyxy"]),
            score=det["score"],
        )

    def _iou_distance(self, atracks: List[STrack], btracks: List[STrack]) -> np.ndarray:
        """Compute the IoU distance matrix (1 - IoU) between two track lists."""
        if not atracks or not btracks:
            return np.zeros((len(atracks), len(btracks)))

        atlbrs = np.array([track.tlbr for track in atracks])
        btlbrs = np.array([track.tlbr for track in btracks])

        return 1 - self._batch_iou(atlbrs, btlbrs)

    def _batch_iou(self, boxes_a: np.ndarray, boxes_b: np.ndarray) -> np.ndarray:
        """Vectorised pairwise IoU between two (N, 4) / (M, 4) box arrays."""
        area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
        area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])

        iw = np.minimum(boxes_a[:, None, 2], boxes_b[:, 2]) - np.maximum(boxes_a[:, None, 0], boxes_b[:, 0])
        ih = np.minimum(boxes_a[:, None, 3], boxes_b[:, 3]) - np.maximum(boxes_a[:, None, 1], boxes_b[:, 1])

        iw = np.maximum(iw, 0)
        ih = np.maximum(ih, 0)

        inter = iw * ih
        union = area_a[:, None] + area_b - inter

        # Guard against zero-area unions.
        return inter / np.maximum(union, 1e-6)

    def _matching(
        self,
        dists: np.ndarray,
        atracks: List[STrack],
        btracks: List[STrack],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Gate the cost matrix by the IoU threshold, then solve assignment.

        Note: `dists` is modified in place by the gating step.
        """
        if not atracks or not btracks:
            return np.empty((0, 2), dtype=int), np.arange(len(atracks)), np.arange(len(btracks))

        # Forbid pairs with IoU below match_thresh (distance above 1 - thresh).
        dists[dists > 1 - self.match_thresh] = 1e6

        matches, (u_track, u_detection) = linear_assignment(dists)
        return matches, u_track, u_detection

    def _stracks_to_output(self, stracks: List[STrack]) -> List[Dict]:
        """Convert STracks to the public output dict format."""
        return [
            {
                "track_id": track.track_id,
                "label": track.label,
                "bbox_xyxy": track.tlbr.tolist(),
                "score": track.score,
                "frame_id": track.frame_id,
                "tracklet_len": track.tracklet_len,
            }
            for track in stracks
        ]
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class BoTSORT(ByteTracker):
    """Lightweight BoT-SORT variant layered on top of ByteTrack.

    The full BoT-SORT adds camera-motion compensation and ReID embeddings;
    this version intentionally omits both for simplicity and reuses the
    ByteTrack association logic unchanged.
    """

    def __init__(
        self,
        track_thresh: float = 0.5,
        match_thresh: float = 0.8,
        track_buffer: int = 30,
        frame_rate: int = 30,
    ):
        super().__init__(
            track_thresh,
            match_thresh,
            track_buffer,
            frame_rate,
        )
        # A Kalman filter / global-motion-compensation module would be
        # initialised here in a full BoT-SORT implementation; for
        # detection-based tracking we deliberately keep ByteTrack behaviour.
|
detection/detr.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DETR and Deformable DETR via Hugging Face Transformers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from .base import BaseDetector, Detection
|
| 11 |
+
from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DETRDetector(BaseDetector):
    """DETR object detector backed by Hugging Face Transformers.

    Wraps `DetrForObjectDetection` + `DetrImageProcessor`. Weights are
    downloaded into the project-local HF cache (see `model_cache`).
    Single-image and batch paths share one postprocessing helper so the
    two code paths cannot drift apart.
    """

    def __init__(self, device: str = "cuda", model_name: str = "facebook/detr-resnet-50", **kwargs):
        super().__init__(device=device, **kwargs)
        self.model_name = model_name
        # Fall back to CPU when CUDA was requested but is unavailable.
        self._device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
        self.processor = None
        self._id2label: Dict[int, str] = {}

    def load_model(self) -> None:
        """Load processor + model weights and move the model to the device."""
        from transformers import DetrForObjectDetection, DetrImageProcessor

        ensure_default_checkpoint_dirs()
        cache_dir = str(hf_cache_dir())
        self.processor = DetrImageProcessor.from_pretrained(self.model_name, cache_dir=cache_dir)
        self.model = DetrForObjectDetection.from_pretrained(self.model_name, cache_dir=cache_dir).to(self._device).eval()
        self._id2label = dict(getattr(self.model.config, "id2label", {}) or {})

    def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
        """Return {class_name: id} for the loaded checkpoint, or None before load."""
        if not self._id2label:
            return None
        return {name: int(i) for i, name in self._id2label.items()}

    def _parse_result(self, result: Dict) -> List[Detection]:
        """Convert one post-processed HF result dict into Detection objects."""
        dets: List[Detection] = []
        for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
            lid = int(label_id)
            label = self._id2label.get(lid, str(lid))
            dets.append(
                Detection(
                    label=label,
                    score=float(score),
                    bbox_xyxy=[float(v) for v in box.tolist()],
                )
            )
        return dets

    def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs) -> List[Detection]:
        """Detect objects in a single image (delegates to `detect_batch`)."""
        return self.detect_batch([image], conf_threshold=conf_threshold, **kwargs)[0]

    def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs) -> List[List[Detection]]:
        """Batch detection for video processing - faster than frame-by-frame.

        Args:
            images: List of PIL Images
            conf_threshold: Confidence threshold

        Returns:
            List of detection lists, one per image
        """
        assert self.model is not None and self.processor is not None

        if not images:
            return []

        imgs = [img.convert("RGB") for img in images]
        inputs = self.processor(images=imgs, return_tensors="pt").to(self._device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # PIL `size` is (w, h); the HF postprocessor expects (h, w).
        target_sizes = torch.tensor([img.size[::-1] for img in imgs], device=self._device)
        results = self.processor.post_process_object_detection(
            outputs,
            threshold=float(conf_threshold),
            target_sizes=target_sizes,
        )

        return [self._parse_result(result) for result in results]
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class DeformableDETRDetector(BaseDetector):
    """Deformable DETR wrapper.

    This relies on transformers' Deformable DETR implementations and checkpoints.
    If your transformers build doesn't include Deformable DETR, this will raise.

    Default checkpoint is `SenseTime/deformable-detr`.
    Single-image and batch paths share one postprocessing helper.
    """

    def __init__(self, device: str = "cuda", model_name: str = "SenseTime/deformable-detr", **kwargs):
        super().__init__(device=device, **kwargs)
        self.model_name = model_name
        # Fall back to CPU when CUDA was requested but is unavailable.
        self._device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
        self.processor = None
        self._id2label: Dict[int, str] = {}

    def load_model(self) -> None:
        """Load the Auto processor/model pair; raises ImportError without transformers."""
        try:
            from transformers import AutoImageProcessor, AutoModelForObjectDetection
        except Exception as e:
            raise ImportError("transformers is required for Deformable DETR") from e

        ensure_default_checkpoint_dirs()
        cache_dir = str(hf_cache_dir())
        self.processor = AutoImageProcessor.from_pretrained(self.model_name, cache_dir=cache_dir)
        self.model = AutoModelForObjectDetection.from_pretrained(self.model_name, cache_dir=cache_dir).to(self._device).eval()
        self._id2label = dict(getattr(self.model.config, "id2label", {}) or {})

    def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
        """Return {class_name: id} for the loaded checkpoint, or None before load."""
        if not self._id2label:
            return None
        return {name: int(i) for i, name in self._id2label.items()}

    def _parse_result(self, result: Dict) -> List[Detection]:
        """Convert one post-processed HF result dict into Detection objects."""
        dets: List[Detection] = []
        for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
            lid = int(label_id)
            label = self._id2label.get(lid, str(lid))
            dets.append(
                Detection(
                    label=label,
                    score=float(score),
                    bbox_xyxy=[float(v) for v in box.tolist()],
                )
            )
        return dets

    def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs) -> List[Detection]:
        """Detect objects in a single image (delegates to `detect_batch`)."""
        return self.detect_batch([image], conf_threshold=conf_threshold, **kwargs)[0]

    def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs) -> List[List[Detection]]:
        """Batch detection for video processing - faster than frame-by-frame.

        Args:
            images: List of PIL Images
            conf_threshold: Confidence threshold

        Returns:
            List of detection lists, one per image

        Raises:
            RuntimeError: if the processor lacks `post_process_object_detection`.
        """
        assert self.model is not None and self.processor is not None

        if not images:
            return []

        imgs = [img.convert("RGB") for img in images]
        inputs = self.processor(images=imgs, return_tensors="pt").to(self._device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Some processor builds lack the standard postprocessor; fail loudly.
        if not hasattr(self.processor, "post_process_object_detection"):
            raise RuntimeError("This Deformable DETR processor does not support post_process_object_detection")

        # PIL `size` is (w, h); the HF postprocessor expects (h, w).
        target_sizes = torch.tensor([img.size[::-1] for img in imgs], device=self._device)
        results = self.processor.post_process_object_detection(
            outputs,
            threshold=float(conf_threshold),
            target_sizes=target_sizes,
        )

        return [self._parse_result(result) for result in results]
|
detection/factory.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Factory + registry for object detectors."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Type
|
| 6 |
+
|
| 7 |
+
from .base import BaseDetector
|
| 8 |
+
from .yolo import YOLODetector
|
| 9 |
+
from .torchvision_detectors import (
|
| 10 |
+
FasterRCNNDetector,
|
| 11 |
+
RetinaNetDetector,
|
| 12 |
+
SSDDetector,
|
| 13 |
+
EfficientDetDetector,
|
| 14 |
+
FCOSDetector,
|
| 15 |
+
)
|
| 16 |
+
from .detr import DETRDetector, DeformableDETRDetector
|
| 17 |
+
from .grounding_dino import GroundingDINODetector
|
| 18 |
+
from .yolo_world import YOLOWorldDetector
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
#: Maps a lowercase method name to its detector class.
DETECTOR_REGISTRY: Dict[str, Type[BaseDetector]] = {
    "yolo": YOLODetector,
    "yolo_world": YOLOWorldDetector,
    "fasterrcnn": FasterRCNNDetector,
    "retinanet": RetinaNetDetector,
    "ssd": SSDDetector,
    "efficientdet": EfficientDetDetector,
    "fcos": FCOSDetector,
    "detr": DETRDetector,
    "deformable_detr": DeformableDETRDetector,
    "grounding_dino": GroundingDINODetector,
}


def register_detector(name: str, detector_class: Type[BaseDetector]) -> None:
    """Register *detector_class* in the registry under the lowercased *name*."""
    if not issubclass(detector_class, BaseDetector):
        raise ValueError(f"{detector_class} must extend BaseDetector")
    DETECTOR_REGISTRY[name.lower()] = detector_class


def create_detector(method: str, device: str = "cuda", **kwargs) -> BaseDetector:
    """Instantiate the detector registered under *method* (case-insensitive)."""
    key = method.lower()
    if key in DETECTOR_REGISTRY:
        return DETECTOR_REGISTRY[key](device=device, **kwargs)
    available = ", ".join(sorted(DETECTOR_REGISTRY.keys()))
    raise ValueError(f"Unknown detector: '{method}'. Available: {available}")


def get_available_detectors() -> List[str]:
    """Return all registered detector names, sorted alphabetically."""
    return sorted(DETECTOR_REGISTRY.keys())
|
detection/grounding_dino.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Grounding DINO open-vocabulary object detection via Hugging Face Transformers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, List, Optional, Sequence, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from .base import BaseDetector, Detection
|
| 11 |
+
from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _normalize_prompts(prompts: Optional[Union[str, Sequence[str]]]) -> Optional[List[str]]:
|
| 15 |
+
if prompts is None:
|
| 16 |
+
return None
|
| 17 |
+
if isinstance(prompts, str):
|
| 18 |
+
# allow comma-separated convenience
|
| 19 |
+
parts = [p.strip() for p in prompts.split(",")]
|
| 20 |
+
parts = [p for p in parts if p]
|
| 21 |
+
return parts or None
|
| 22 |
+
out = [str(p).strip() for p in prompts]
|
| 23 |
+
out = [p for p in out if p]
|
| 24 |
+
return out or None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class GroundingDINODetector(BaseDetector):
    """Grounding DINO wrapper.

    Usage:
        detector = create_detector('grounding_dino', device='cuda')
        dets = detector(image, prompts=['person', 'car'])

    Notes:
        - This is an open-vocabulary model; `get_available_classes()` returns None.
        - Pass prompts via `prompts=` or `classes=`; you may also pass raw `text=`.
        - Single-image and batch paths share the prompt-resolution and
          result-parsing helpers so they cannot drift apart.
    """

    def __init__(
        self,
        device: str = "cuda",
        model_name: str = "IDEA-Research/grounding-dino-base",
        **kwargs: Any,
    ):
        super().__init__(device=device, **kwargs)
        self.model_name = model_name
        # Fall back to CPU when CUDA was requested but is unavailable.
        self._device = torch.device(device if (torch.cuda.is_available() or device == "cpu") else "cpu")
        self.processor = None

    def load_model(self) -> None:
        """Load the zero-shot detection processor/model pair into the HF cache."""
        try:
            from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
        except Exception as e:
            raise ImportError(
                "transformers>=4.36 is required for Grounding DINO (AutoProcessor + AutoModelForZeroShotObjectDetection)"
            ) from e

        ensure_default_checkpoint_dirs()
        cache_dir = str(hf_cache_dir())
        self.processor = AutoProcessor.from_pretrained(self.model_name, cache_dir=cache_dir)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(self.model_name, cache_dir=cache_dir).to(self._device).eval()

    def get_available_classes(self):
        """Open-vocabulary model: there is no fixed class list."""
        return None

    @staticmethod
    def _resolve_text(kwargs: Any):
        """Build the model's text query from call kwargs.

        Returns (text, prompts): the final query string and the normalized
        prompt list (None when raw `text=` was supplied).

        Raises:
            ValueError: when neither prompts/classes/labels nor text was given.
        """
        prompts = _normalize_prompts(kwargs.get("prompts") or kwargs.get("classes") or kwargs.get("labels"))
        text: Optional[str] = kwargs.get("text")

        if text is None:
            if not prompts:
                raise ValueError("GroundingDINO requires `prompts`/`classes` (or raw `text`) for open-vocabulary detection")
            # GroundingDINO expects a single string; period-delimited works well
            text = " . ".join(prompts)
            if not text.endswith("."):
                text = text + " ."
        return text, prompts

    @staticmethod
    def _label_for(text_labels: Any, i: int, prompts: Optional[List[str]]) -> str:
        """Best-effort label for detection *i* across HF output variants.

        Newer transformers return string labels; older ones return an index
        tensor which we map back into the prompt list when possible.
        """
        if isinstance(text_labels, (list, tuple)) and i < len(text_labels):
            return str(text_labels[i])
        if torch.is_tensor(text_labels) and i < int(text_labels.shape[0]):
            lid = int(text_labels[i])
            if prompts and 0 <= lid < len(prompts):
                return prompts[lid]
            return str(lid)
        return "object"

    def _result_to_detections(self, result: Any, prompts: Optional[List[str]]) -> List[Detection]:
        """Convert one post-processed result dict into Detection objects."""
        boxes = result.get("boxes")
        scores = result.get("scores")
        text_labels = result.get("text_labels") or result.get("labels")

        if boxes is None or scores is None:
            return []

        dets: List[Detection] = []
        for i in range(int(scores.shape[0])):
            b = [float(v) for v in boxes[i].tolist()]
            label = self._label_for(text_labels, i, prompts)
            dets.append(Detection(label=label, score=float(scores[i]), bbox_xyxy=b))
        return dets

    def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs: Any) -> List[Detection]:
        """Detect objects described by the text prompt in a single image.

        Raises:
            RuntimeError: if the processor lacks grounded postprocessing.
        """
        assert self.model is not None and self.processor is not None

        img = image.convert("RGB")
        text, prompts = self._resolve_text(kwargs)
        threshold = float(kwargs.get("threshold", conf_threshold))

        inputs = self.processor(images=img, text=text, return_tensors="pt")
        inputs = {k: v.to(self._device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        target_sizes = torch.tensor([img.size[::-1]], device=self._device)  # (h, w)

        if not hasattr(self.processor, "post_process_grounded_object_detection"):
            raise RuntimeError("This processor does not support post_process_grounded_object_detection")

        # Move input_ids to CPU for post-processing (some versions have device mismatch bugs)
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs["input_ids"].cpu(),
            threshold=threshold,
            target_sizes=target_sizes.cpu(),
        )[0]

        return self._result_to_detections(results, prompts)

    def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs: Any) -> List[List[Detection]]:
        """Batch detection for video processing.

        Note: Grounding DINO requires the same text prompt for all images in batch.

        Args:
            images: List of PIL Images
            conf_threshold: Confidence threshold
            **kwargs: Should include 'prompts'/'classes' or 'text'

        Returns:
            List of detection lists, one per image
        """
        assert self.model is not None and self.processor is not None

        if not images:
            return []

        imgs = [img.convert("RGB") for img in images]
        text, prompts = self._resolve_text(kwargs)
        threshold = float(kwargs.get("threshold", conf_threshold))

        # Batch process images with same text prompt
        inputs = self.processor(images=imgs, text=[text] * len(imgs), return_tensors="pt")
        inputs = {k: v.to(self._device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        target_sizes = torch.tensor([img.size[::-1] for img in imgs], device=self._device)

        if not hasattr(self.processor, "post_process_grounded_object_detection"):
            raise RuntimeError("This processor does not support post_process_grounded_object_detection")

        # Move input_ids to CPU for post-processing (some versions have device mismatch bugs)
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs["input_ids"].cpu(),
            threshold=threshold,
            target_sizes=target_sizes.cpu(),
        )

        return [self._result_to_detections(result, prompts) for result in results]
|
detection/torchvision_detectors.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Torchvision-based object detectors: Faster R-CNN, RetinaNet, SSD, FCOS.
|
| 2 |
+
|
| 3 |
+
Also includes an EfficientDet wrapper (optional dependency: effdet).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from model_cache import ensure_default_checkpoint_dirs
|
| 9 |
+
|
| 10 |
+
from typing import Dict, List, Optional, Union
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from torchvision.transforms.functional import to_tensor
|
| 16 |
+
|
| 17 |
+
from .base import BaseDetector, Detection
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Ensure torchvision/torch hub downloads land under `checkpoints/` by default.
|
| 21 |
+
ensure_default_checkpoint_dirs()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class _TorchvisionCOCODetector(BaseDetector):
    """Shared base for torchvision detectors pretrained on COCO.

    Subclasses implement `load_model()`, setting `self.model` and
    `self._categories` (category names indexed by the model's label ids).
    Single-image and batch paths share one output-parsing helper.
    """

    _categories: Optional[List[str]] = None

    def __init__(self, device: str = "cuda", **kwargs):
        super().__init__(device=device, **kwargs)
        # Fall back to CPU when CUDA was requested but is unavailable.
        self._device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")

    def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
        """Return {category_name: index} from the loaded weights, or None before load."""
        if not self._categories:
            return None
        return {name: int(i) for i, name in enumerate(self._categories)}

    def _parse_output(self, out: Dict, conf_threshold: float) -> List[Detection]:
        """Convert one torchvision output dict (boxes/labels/scores) to Detections."""
        boxes = out.get("boxes")
        labels = out.get("labels")
        scores = out.get("scores")

        if boxes is None or labels is None or scores is None:
            return []

        boxes = boxes.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy().astype(int)
        scores = scores.detach().cpu().numpy()

        dets: List[Detection] = []
        for b, l, s in zip(boxes, labels, scores):
            if float(s) < float(conf_threshold):
                continue
            # Fall back to the numeric id when the category list is missing/short.
            label_name = str(int(l))
            if self._categories and 0 <= int(l) < len(self._categories):
                label_name = self._categories[int(l)]
            dets.append(
                Detection(
                    label=label_name,
                    score=float(s),
                    bbox_xyxy=[float(v) for v in b.tolist()],
                )
            )
        return dets

    def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs) -> List[Detection]:
        """Detect objects in a single image (delegates to `detect_batch`)."""
        return self.detect_batch([image], conf_threshold=conf_threshold, **kwargs)[0]

    def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs) -> List[List[Detection]]:
        """Batch detection for video processing - much faster than frame-by-frame.

        Args:
            images: List of PIL Images
            conf_threshold: Confidence threshold

        Returns:
            List of detection lists, one per image
        """
        assert self.model is not None

        if not images:
            return []

        # torchvision detection models accept a list of 3D tensors of varying size.
        tensors = [to_tensor(img.convert("RGB")).to(self._device) for img in images]

        with torch.no_grad():
            outputs = self.model(tensors)

        return [self._parse_output(out, conf_threshold) for out in outputs]
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class FasterRCNNDetector(_TorchvisionCOCODetector):
    """COCO-pretrained Faster R-CNN (ResNet-50 FPN) from torchvision."""

    def load_model(self) -> None:
        """Instantiate the pretrained model and cache its COCO category names."""
        from torchvision.models.detection import (
            FasterRCNN_ResNet50_FPN_Weights,
            fasterrcnn_resnet50_fpn,
        )

        w = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
        self._categories = list(w.meta.get("categories", []))
        net = fasterrcnn_resnet50_fpn(weights=w)
        self.model = net.to(self._device).eval()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class RetinaNetDetector(_TorchvisionCOCODetector):
    """COCO-pretrained RetinaNet (ResNet-50 FPN) from torchvision."""

    def load_model(self) -> None:
        """Instantiate the pretrained model and cache its COCO category names."""
        from torchvision.models.detection import (
            RetinaNet_ResNet50_FPN_Weights,
            retinanet_resnet50_fpn,
        )

        w = RetinaNet_ResNet50_FPN_Weights.DEFAULT
        self._categories = list(w.meta.get("categories", []))
        net = retinanet_resnet50_fpn(weights=w)
        self.model = net.to(self._device).eval()
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class SSDDetector(_TorchvisionCOCODetector):
    """COCO-pretrained SSD300 (VGG-16 backbone) from torchvision."""

    def load_model(self) -> None:
        """Instantiate the pretrained model and cache its COCO category names."""
        from torchvision.models.detection import (
            SSD300_VGG16_Weights,
            ssd300_vgg16,
        )

        w = SSD300_VGG16_Weights.DEFAULT
        self._categories = list(w.meta.get("categories", []))
        net = ssd300_vgg16(weights=w)
        self.model = net.to(self._device).eval()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class FCOSDetector(_TorchvisionCOCODetector):
    """COCO-pretrained FCOS (ResNet-50 FPN) from torchvision."""

    def load_model(self) -> None:
        """Instantiate the pretrained model and cache its COCO category names."""
        from torchvision.models.detection import (
            FCOS_ResNet50_FPN_Weights,
            fcos_resnet50_fpn,
        )

        w = FCOS_ResNet50_FPN_Weights.DEFAULT
        self._categories = list(w.meta.get("categories", []))
        net = fcos_resnet50_fpn(weights=w)
        self.model = net.to(self._device).eval()
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class EfficientDetDetector(BaseDetector):
    """EfficientDet via `effdet` (optional dependency).

    This is implemented so the module is extendable, but requires:
    - `pip install effdet`

    We default to `tf_efficientdet_d0`.
    """

    def __init__(self, device: str = "cuda", model_name: str = "tf_efficientdet_d0", **kwargs):
        super().__init__(device=device, **kwargs)
        self.model_name = model_name
        # Fall back to CPU when CUDA was requested but is unavailable.
        self._device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
        # Category names are never populated for effdet (see load_model); labels
        # are therefore reported as numeric class-id strings.
        self._categories: Optional[List[str]] = None
        # (H, W) the effdet prediction bench was configured for; set in load_model.
        self._input_size_hw: Optional[tuple[int, int]] = None

    def load_model(self) -> None:
        """Build the effdet prediction bench; raises ImportError without `effdet`."""
        try:
            from effdet import create_model
            from effdet.config import get_efficientdet_config
        except Exception as e:
            raise ImportError(
                "EfficientDet requires the optional package 'effdet'. "
                "Install it with: pip install effdet"
            ) from e

        self.model = create_model(
            self.model_name,
            pretrained=True,
            bench_task="predict",
        ).to(self._device).eval()

        # effdet prediction bench expects a fixed input resolution.
        cfg = get_efficientdet_config(self.model_name)
        # cfg.image_size is [H, W]
        self._input_size_hw = (int(cfg.image_size[0]), int(cfg.image_size[1]))

        # effdet uses COCO labels by default for pretrained models
        # Keep class list unknown here to avoid hard-coding.
        self._categories = None

    def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
        # Always None in practice, since load_model never fills _categories.
        return self._categories

    def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs) -> List[Detection]:
        """Detect objects in a single image, rescaling boxes back to its size."""
        assert self.model is not None

        img = image.convert("RGB")
        orig_w, orig_h = img.size

        # Resize to the model's configured input size. This avoids internal shape mismatch
        # issues in effdet when given arbitrary resolutions.
        target_h, target_w = self._input_size_hw or (512, 512)
        resized = img.resize((target_w, target_h))
        x0 = to_tensor(resized)

        in_h, in_w = target_h, target_w
        x = x0.unsqueeze(0).to(self._device)

        with torch.no_grad():
            pred = self.model(x)

        # pred is a dict-like with 'detections': [B, max_det, 6]
        # NOTE(review): code indexes pred[0] directly rather than pred['detections'];
        # assumes each row is (x1, y1, x2, y2, score, class) — confirm against effdet.
        det = pred[0].detach().cpu().numpy()
        out: List[Detection] = []
        # Scale factors to map boxes from model-input space back to the original image.
        sx = orig_w / float(in_w) if in_w > 0 else 1.0
        sy = orig_h / float(in_h) if in_h > 0 else 1.0
        for row in det:
            x1, y1, x2, y2, score, cls = row.tolist()
            if float(score) < float(conf_threshold):
                continue
            out.append(
                Detection(
                    label=str(int(cls)),
                    score=float(score),
                    bbox_xyxy=[float(x1) * sx, float(y1) * sy, float(x2) * sx, float(y2) * sy],
                )
            )
        return out
|
| 246 |
+
|
| 247 |
+
def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs) -> List[List[Detection]]:
|
| 248 |
+
"""Batch detection for video processing.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
images: List of PIL Images
|
| 252 |
+
conf_threshold: Confidence threshold
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
List of detection lists, one per image
|
| 256 |
+
"""
|
| 257 |
+
assert self.model is not None
|
| 258 |
+
|
| 259 |
+
if not images:
|
| 260 |
+
return []
|
| 261 |
+
|
| 262 |
+
# Convert and resize all images
|
| 263 |
+
target_h, target_w = self._input_size_hw or (512, 512)
|
| 264 |
+
tensors = []
|
| 265 |
+
scales = []
|
| 266 |
+
|
| 267 |
+
for img in images:
|
| 268 |
+
img_rgb = img.convert("RGB")
|
| 269 |
+
orig_w, orig_h = img_rgb.size
|
| 270 |
+
resized = img_rgb.resize((target_w, target_h))
|
| 271 |
+
tensors.append(to_tensor(resized))
|
| 272 |
+
|
| 273 |
+
sx = orig_w / float(target_w) if target_w > 0 else 1.0
|
| 274 |
+
sy = orig_h / float(target_h) if target_h > 0 else 1.0
|
| 275 |
+
scales.append((sx, sy))
|
| 276 |
+
|
| 277 |
+
# Batch inference
|
| 278 |
+
batch = torch.stack(tensors).to(self._device)
|
| 279 |
+
with torch.no_grad():
|
| 280 |
+
preds = self.model(batch)
|
| 281 |
+
|
| 282 |
+
# Parse results for each image
|
| 283 |
+
all_detections = []
|
| 284 |
+
for pred, (sx, sy) in zip(preds, scales):
|
| 285 |
+
det = pred.detach().cpu().numpy()
|
| 286 |
+
frame_dets = []
|
| 287 |
+
for row in det:
|
| 288 |
+
x1, y1, x2, y2, score, cls = row.tolist()
|
| 289 |
+
if float(score) < float(conf_threshold):
|
| 290 |
+
continue
|
| 291 |
+
frame_dets.append(
|
| 292 |
+
Detection(
|
| 293 |
+
label=str(int(cls)),
|
| 294 |
+
score=float(score),
|
| 295 |
+
bbox_xyxy=[float(x1) * sx, float(y1) * sy, float(x2) * sx, float(y2) * sy],
|
| 296 |
+
)
|
| 297 |
+
)
|
| 298 |
+
all_detections.append(frame_dets)
|
| 299 |
+
|
| 300 |
+
return all_detections
|
detection/tracker.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Simple IoU-based object tracker for video processing.
|
| 2 |
+
|
| 3 |
+
Provides basic multi-object tracking using detection-to-track association
|
| 4 |
+
based on IoU overlap and optional appearance features.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class Track:
    """A tracked object across frames."""

    track_id: int
    label: str

    # History of bounding boxes [(frame_idx, bbox_xyxy, score)]
    history: List[Tuple[int, List[float], float]] = field(default_factory=list)

    # Current state (most recent matched detection)
    last_bbox: List[float] = field(default_factory=list)
    last_score: float = 0.0
    last_frame: int = -1

    # Track status
    age: int = 0  # Frames since first detection
    hits: int = 0  # Total detection matches
    time_since_update: int = 0  # Frames since last match

    # Color for visualization (RGB)
    color: Tuple[int, int, int] = (0, 255, 0)


class SimpleTracker:
    """IoU-based multi-object tracker.

    Associates detections to existing tracks using IoU overlap.
    Creates new tracks for unmatched detections.
    Removes tracks that haven't been updated for too long.
    """

    def __init__(
        self,
        iou_threshold: float = 0.3,
        max_age: int = 30,
        min_hits: int = 3,
        label_match: bool = True,
    ):
        """
        Args:
            iou_threshold: Minimum IoU for detection-track association
            max_age: Maximum frames a track survives without update
            min_hits: Minimum detections before track is confirmed
            label_match: Require label match for association
        """
        self.iou_threshold = iou_threshold
        self.max_age = max_age
        self.min_hits = min_hits
        self.label_match = label_match

        self._tracks: Dict[int, Track] = {}
        self._next_id: int = 1
        self._frame_idx: int = 0

        # Color palette for tracks (cycled by track id)
        self._colors = [
            (255, 0, 0), (0, 255, 0), (0, 0, 255),
            (255, 255, 0), (255, 0, 255), (0, 255, 255),
            (255, 128, 0), (128, 0, 255), (0, 255, 128),
            (255, 0, 128), (128, 255, 0), (0, 128, 255),
        ]

    def reset(self):
        """Reset tracker state."""
        self._tracks = {}
        self._next_id = 1
        self._frame_idx = 0

    def update(
        self,
        detections: List[Dict],
        frame_idx: Optional[int] = None,
    ) -> List[Dict]:
        """Update tracks with new detections.

        Args:
            detections: List of detection dicts with label, score, bbox_xyxy
            frame_idx: Optional frame index (auto-incremented if None)

        Returns:
            List of track dicts with track_id, label, bbox_xyxy, score
        """
        if frame_idx is None:
            self._frame_idx += 1
        else:
            self._frame_idx = frame_idx

        # Age every live track. `age` counts frames since first detection
        # (fix: previously never incremented past its initial value of 1);
        # `time_since_update` counts frames since the last successful match.
        for track in self._tracks.values():
            track.age += 1
            track.time_since_update += 1

        if not detections:
            # No detections - just age tracks
            self._remove_old_tracks()
            return self._get_active_tracks()

        # Build cost matrix (negative IoU for Hungarian matching)
        track_ids = list(self._tracks.keys())
        det_indices = list(range(len(detections)))

        if track_ids and det_indices:
            # Compute IoU matrix between last known track boxes and detections
            iou_matrix = self._compute_iou_matrix(
                [self._tracks[tid].last_bbox for tid in track_ids],
                [d["bbox_xyxy"] for d in detections],
            )

            # Greedy matching (could use Hungarian for optimal)
            matches, unmatched_tracks, unmatched_dets = self._greedy_match(
                iou_matrix,
                track_ids,
                det_indices,
                detections,
            )
        else:
            matches = []
            unmatched_tracks = track_ids
            unmatched_dets = det_indices

        # Update matched tracks with their associated detection
        for track_id, det_idx in matches:
            det = detections[det_idx]
            track = self._tracks[track_id]
            track.last_bbox = det["bbox_xyxy"]
            track.last_score = det["score"]
            track.last_frame = self._frame_idx
            track.hits += 1
            track.time_since_update = 0
            track.history.append((self._frame_idx, det["bbox_xyxy"], det["score"]))

        # Create new tracks for unmatched detections
        for det_idx in unmatched_dets:
            det = detections[det_idx]
            color = self._colors[self._next_id % len(self._colors)]

            track = Track(
                track_id=self._next_id,
                label=det["label"],
                last_bbox=det["bbox_xyxy"],
                last_score=det["score"],
                last_frame=self._frame_idx,
                age=1,
                hits=1,
                time_since_update=0,
                color=color,
            )
            track.history.append((self._frame_idx, det["bbox_xyxy"], det["score"]))

            self._tracks[self._next_id] = track
            self._next_id += 1

        # Remove tracks that exceeded max_age without a match
        self._remove_old_tracks()

        return self._get_active_tracks()

    def update_batch(
        self,
        detections_list: List[List[Dict]],
    ) -> List[Dict]:
        """Update tracks with a batch of frames.

        Args:
            detections_list: List of detection lists, one per frame

        Returns:
            List of all track dicts with full history

        Note:
            Only tracks still alive after the final frame are returned;
            tracks pruned by max_age along the way lose their history.
        """
        for dets in detections_list:
            self.update(dets)

        return self._get_all_tracks_with_history()

    def _compute_iou_matrix(
        self,
        boxes_a: List[List[float]],
        boxes_b: List[List[float]],
    ) -> np.ndarray:
        """Compute IoU matrix between two sets of boxes."""
        n_a = len(boxes_a)
        n_b = len(boxes_b)

        if n_a == 0 or n_b == 0:
            return np.zeros((n_a, n_b))

        iou_matrix = np.zeros((n_a, n_b))

        for i, box_a in enumerate(boxes_a):
            for j, box_b in enumerate(boxes_b):
                iou_matrix[i, j] = self._box_iou(box_a, box_b)

        return iou_matrix

    def _box_iou(self, a: List[float], b: List[float]) -> float:
        """Compute IoU between two xyxy boxes (0.0 for empty union)."""
        ax1, ay1, ax2, ay2 = a
        bx1, by1, bx2, by2 = b

        inter_x1 = max(ax1, bx1)
        inter_y1 = max(ay1, by1)
        inter_x2 = min(ax2, bx2)
        inter_y2 = min(ay2, by2)

        inter_w = max(0.0, inter_x2 - inter_x1)
        inter_h = max(0.0, inter_y2 - inter_y1)
        inter = inter_w * inter_h

        area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1)
        area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
        union = area_a + area_b - inter

        return float(inter / union) if union > 0 else 0.0

    def _greedy_match(
        self,
        iou_matrix: np.ndarray,
        track_ids: List[int],
        det_indices: List[int],
        detections: List[Dict],
    ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
        """Greedy matching based on IoU (highest-overlap pairs first)."""
        matches: List[Tuple[int, int]] = []
        matched_tracks = set()
        matched_dets = set()

        # Collect candidate pairs above the threshold, sorted by IoU (desc)
        n_tracks, n_dets = iou_matrix.shape
        pairs = []
        for i in range(n_tracks):
            for j in range(n_dets):
                if iou_matrix[i, j] >= self.iou_threshold:
                    pairs.append((iou_matrix[i, j], i, j))

        pairs.sort(reverse=True, key=lambda x: x[0])

        for iou_val, track_idx, det_idx in pairs:
            if track_idx in matched_tracks or det_idx in matched_dets:
                continue

            track_id = track_ids[track_idx]
            det = detections[det_idx]

            # Check label match if required (mismatched pair is skipped but
            # both sides stay available for other candidates)
            if self.label_match:
                if self._tracks[track_id].label != det["label"]:
                    continue

            matches.append((track_id, det_idx))
            matched_tracks.add(track_idx)
            matched_dets.add(det_idx)

        unmatched_tracks = [track_ids[i] for i in range(n_tracks) if i not in matched_tracks]
        unmatched_dets = [j for j in range(n_dets) if j not in matched_dets]

        return matches, unmatched_tracks, unmatched_dets

    def _remove_old_tracks(self):
        """Remove tracks that haven't been updated recently."""
        to_remove = []
        for track_id, track in self._tracks.items():
            if track.time_since_update > self.max_age:
                to_remove.append(track_id)

        for track_id in to_remove:
            del self._tracks[track_id]

    def _get_active_tracks(self) -> List[Dict]:
        """Get currently active (confirmed, i.e. hits >= min_hits) tracks."""
        result = []
        for track in self._tracks.values():
            # Only return confirmed tracks
            if track.hits >= self.min_hits:
                result.append({
                    "track_id": track.track_id,
                    "label": track.label,
                    "bbox_xyxy": track.last_bbox,
                    "score": track.last_score,
                    "age": track.age,
                    "color": track.color,
                })
        return result

    def _get_all_tracks_with_history(self) -> List[Dict]:
        """Get all confirmed tracks with full per-frame history."""
        result = []
        for track in self._tracks.values():
            if track.hits >= self.min_hits:
                result.append({
                    "track_id": track.track_id,
                    "label": track.label,
                    "bbox_xyxy": track.last_bbox,
                    "score": track.last_score,
                    "age": track.age,
                    "hits": track.hits,
                    "color": track.color,
                    "history": [
                        {"frame": h[0], "bbox_xyxy": h[1], "score": h[2]}
                        for h in track.history
                    ],
                })
        return result
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def draw_tracks(
    image,
    tracks: List[Dict],
    show_id: bool = True,
    show_trail: bool = False,
    trail_length: int = 10,
):
    """Render tracked objects onto a copy of the given image.

    Args:
        image: PIL Image
        tracks: List of track dicts
        show_id: Show track ID
        show_trail: Show movement trail
        trail_length: Number of trail points

    Returns:
        Image with tracks drawn (PIL Image)
    """
    from PIL import Image, ImageDraw, ImageFont

    canvas = image.copy()
    painter = ImageDraw.Draw(canvas)

    # Prefer a real TTF font; fall back to PIL's built-in bitmap font.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
    except Exception:
        font = ImageFont.load_default()

    for trk in tracks:
        box = trk["bbox_xyxy"]
        box_color = trk.get("color", (0, 255, 0))
        tid = trk.get("track_id", 0)
        name = trk.get("label", "")
        conf = trk.get("score", 0.0)

        # Bounding box outline
        x1, y1, x2, y2 = (int(v) for v in box)
        painter.rectangle([x1, y1, x2, y2], outline=box_color, width=2)

        # Caption: either "#id label" or "label score"
        caption = f"#{tid} {name}" if show_id else f"{name} {conf:.2f}"

        # Filled background behind the caption for readability
        bg = painter.textbbox((x1, y1 - 20), caption, font=font)
        painter.rectangle(bg, fill=box_color)
        painter.text((x1, y1 - 20), caption, fill=(255, 255, 255), font=font)

        # Optional fading motion trail built from the track history
        if show_trail and "history" in trk:
            recent = trk["history"][-trail_length:]
            if len(recent) > 1:
                centers = []
                for entry in recent:
                    bb = entry["bbox_xyxy"]
                    centers.append((int((bb[0] + bb[2]) / 2), int((bb[1] + bb[3]) / 2)))

                total = len(centers)
                for idx, (p0, p1) in enumerate(zip(centers, centers[1:])):
                    # Older segments are darker; newest is full color.
                    fade = (idx + 1) / total
                    faded = tuple(int(ch * fade) for ch in box_color)
                    painter.line([p0, p1], fill=faded, width=2)

    return canvas
|
detection/utils.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Detection utilities: IoU, matching, drawing, metrics."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import asdict
|
| 6 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 10 |
+
|
| 11 |
+
from .base import Detection
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def box_iou_xyxy(a: Sequence[float], b: Sequence[float]) -> float:
    """Intersection-over-union of two axis-aligned (x1, y1, x2, y2) boxes.

    Degenerate boxes and empty unions yield 0.0.
    """
    # Intersection rectangle corners
    ix1 = max(a[0], b[0])
    iy1 = max(a[1], b[1])
    ix2 = min(a[2], b[2])
    iy2 = min(a[3], b[3])

    # Clamp to zero so disjoint boxes contribute no overlap.
    overlap = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)

    area_a = max(0.0, a[2] - a[0]) * max(0.0, a[3] - a[1])
    area_b = max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1])
    denom = area_a + area_b - overlap

    if denom <= 0:
        return 0.0
    return float(overlap / denom)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def greedy_match_detections(
    before: List[Detection],
    after: List[Detection],
    iou_threshold: float = 0.5,
    require_same_label: bool = True,
) -> List[Tuple[int, int, float]]:
    """Greedy match: for each `before` detection (sorted by score), pick best IoU `after`.

    Returns list of (before_idx, after_idx, iou).
    """
    taken = set()
    pairs: List[Tuple[int, int, float]] = []

    # Process higher-confidence `before` detections first so they get
    # first pick of the `after` candidates.
    by_score = sorted(range(len(before)), key=lambda i: before[i].score, reverse=True)

    for b_idx in by_score:
        candidate = None
        candidate_iou = 0.0
        for a_idx in range(len(after)):
            if a_idx in taken:
                continue
            if require_same_label and before[b_idx].label != after[a_idx].label:
                continue
            overlap = box_iou_xyxy(before[b_idx].bbox_xyxy, after[a_idx].bbox_xyxy)
            if overlap >= iou_threshold and overlap > candidate_iou:
                candidate = a_idx
                candidate_iou = overlap
        if candidate is not None:
            taken.add(candidate)
            pairs.append((b_idx, candidate, float(candidate_iou)))

    return pairs
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def summarize_before_after(
    before: List[Detection],
    after: List[Detection],
    iou_threshold: float = 0.5,
) -> Dict:
    """Summarize how detections change between two passes over the same image.

    Matches `before` to `after` detections greedily (same label, IoU >=
    `iou_threshold`) and reports counts, retention ratio, and mean IoU/score
    statistics over the matched pairs.

    Args:
        before: Detections from the first pass (e.g. pre-compression).
        after: Detections from the second pass (e.g. post-compression).
        iou_threshold: Minimum IoU for a before/after pair to count as matched.

    Returns:
        Dict with num_before, num_after, matched, retention, new_after, and
        mean IoU/score stats (None when there are no matches).
    """
    matches = greedy_match_detections(before, after, iou_threshold=iou_threshold, require_same_label=True)

    # Retention defaults to 1.0 when there was nothing to retain.
    retention = (len(matches) / len(before)) if before else 1.0
    # Detections present only in `after` (newly "emerged").
    emergence = len(after) - len(matches)

    ious = [iou for (_, _, iou) in matches]
    score_before = [before[bi].score for (bi, _, _) in matches]
    score_after = [after[ai].score for (_, ai, _) in matches]

    return {
        "num_before": len(before),
        "num_after": len(after),
        "matched": len(matches),
        "retention": float(retention),
        "new_after": int(emergence),
        "mean_iou_matched": float(np.mean(ious)) if ious else None,
        "mean_score_before_matched": float(np.mean(score_before)) if score_before else None,
        "mean_score_after_matched": float(np.mean(score_after)) if score_after else None,
    }
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def detections_to_dict(dets: List[Detection]) -> List[Dict]:
    """Serialize Detection dataclasses into plain dictionaries."""
    return list(map(asdict, dets))
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def draw_detections(
    image: Image.Image,
    detections: List[Detection],
    max_dets: Optional[int] = 50,
    color: Tuple[int, int, int] = (0, 255, 0),
) -> Image.Image:
    """Draw detection boxes and labels on a copy of *image*.

    Args:
        image: Source PIL image (not modified).
        detections: Detections to render.
        max_dets: Draw at most this many detections; None means no limit.
        color: RGB color for boxes and text.

    Returns:
        A new PIL image with the detections drawn.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    # `max_dets is None` means "no limit". Fix: the previous
    # `detections[: max_dets or len(detections)]` treated max_dets=0 as
    # "draw everything" instead of "draw nothing".
    dets = detections if max_dets is None else detections[:max_dets]
    for d in dets:
        x1, y1, x2, y2 = d.bbox_xyxy
        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
        label = f"{d.label} {d.score:.2f}"
        if font is not None:
            draw.text((x1 + 2, y1 + 2), label, fill=color, font=font)
        else:
            draw.text((x1 + 2, y1 + 2), label, fill=color)
    return img
|
detection/yolo.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""YOLO detector via Ultralytics."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from .base import BaseDetector, Detection
|
| 11 |
+
from model_cache import default_checkpoint_path, ensure_default_checkpoint_dirs
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class YOLODetector(BaseDetector):
    """Closed-vocabulary YOLO detector backed by Ultralytics."""

    def __init__(self, device: str = "cuda", model_path: str = default_checkpoint_path("yolo26x.pt"), **kwargs):
        super().__init__(device=device, **kwargs)
        self.model_path = model_path
        # class-id -> class-name mapping, populated in load_model().
        self._names: Dict[int, str] = {}

    def load_model(self) -> None:
        """Load YOLO weights and cache the model's class-name mapping."""
        ensure_default_checkpoint_dirs()
        from ultralytics import YOLO

        self.model = YOLO(self.model_path)
        # ultralytics uses 'cuda:0' style; device index 0 selects the first GPU
        self._device_arg = 0 if self.device.startswith("cuda") else "cpu"
        self._names = dict(getattr(self.model, "names", {}) or {})

    def get_available_classes(self) -> Union[List[str], Dict[str, int], None]:
        """Return a name -> class-id mapping, or None before load_model()."""
        if not self._names:
            return None
        # map name->id
        return {name: int(idx) for idx, name in self._names.items()}

    def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs) -> List[Detection]:
        """Detect objects in a single image.

        Delegates to `detect_batch` so single-image and batched code paths
        share one result-parsing implementation.
        """
        return self.detect_batch([image], conf_threshold=conf_threshold, **kwargs)[0]

    def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs) -> List[List[Detection]]:
        """Batch detection for video processing - much faster than frame-by-frame.

        Args:
            images: List of PIL Images
            conf_threshold: Confidence threshold

        Returns:
            List of detection lists, one per image
        """
        assert self.model is not None

        if not images:
            return []

        # Convert all images to numpy arrays
        imgs = [np.asarray(img.convert("RGB")) for img in images]

        # TRUE batch inference - YOLO processes all frames together
        results = self.model.predict(source=imgs, conf=float(conf_threshold), device=self._device_arg, verbose=False)

        # Parse results for each frame
        all_detections = []
        for result in results:
            if not result or getattr(result, "boxes", None) is None:
                all_detections.append([])
                continue

            boxes = result.boxes
            xyxy = boxes.xyxy.detach().cpu().numpy()
            conf = boxes.conf.detach().cpu().numpy()
            cls = boxes.cls.detach().cpu().numpy().astype(int)

            frame_dets = []
            for (b, s, c) in zip(xyxy, conf, cls):
                # Fall back to the numeric id string when the name is unknown.
                label = self._names.get(int(c), str(int(c)))
                frame_dets.append(Detection(label=label, score=float(s), bbox_xyxy=[float(x) for x in b.tolist()]))

            all_detections.append(frame_dets)

        return all_detections
|
detection/yolo_world.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""YOLO-World open-vocabulary object detection via Ultralytics."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from .base import BaseDetector, Detection
|
| 11 |
+
from model_cache import default_checkpoint_path, ensure_default_checkpoint_dirs
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _normalize_classes(classes: Optional[Union[str, Sequence[str]]]) -> Optional[List[str]]:
|
| 15 |
+
if classes is None:
|
| 16 |
+
return None
|
| 17 |
+
if isinstance(classes, str):
|
| 18 |
+
parts = [c.strip() for c in classes.split(",")]
|
| 19 |
+
parts = [c for c in parts if c]
|
| 20 |
+
return parts or None
|
| 21 |
+
out = [str(c).strip() for c in classes]
|
| 22 |
+
out = [c for c in out if c]
|
| 23 |
+
return out or None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class YOLOWorldDetector(BaseDetector):
    """YOLO-World wrapper.

    Usage:
        detector = create_detector('yolo_world', device='cuda')
        dets = detector(image, classes=['person', 'car'])

    Notes:
        - This is open-vocabulary; `get_available_classes()` returns None.
        - Requires the Ultralytics `YOLOWorld` class (available in ultralytics>=8).
    """

    def __init__(
        self,
        device: str = "cuda",
        model_path: Optional[str] = None,
        **kwargs: Any,
    ):
        """Create the detector.

        Args:
            device: Torch device string ("cuda" / "cpu").
            model_path: Path to the YOLO-World weights. Defaults to
                `checkpoints/yolo26s-world.pt` resolved lazily here so that
                merely importing this module does not create checkpoint dirs.
            **kwargs: Forwarded to `BaseDetector`.
        """
        super().__init__(device=device, **kwargs)
        # Resolve the default lazily (the previous eager default argument ran
        # `default_checkpoint_path` at import time, creating directories).
        self.model_path = model_path if model_path is not None else default_checkpoint_path("yolo26s-world.pt")
        self._device_arg: Union[int, str] = "cpu"
        self._names: Dict[int, str] = {}

    def load_model(self) -> None:
        """Load the Ultralytics YOLOWorld model onto the configured device."""
        ensure_default_checkpoint_dirs()
        try:
            from ultralytics import YOLOWorld
        except Exception as e:
            raise ImportError("ultralytics is required for YOLO-World detection") from e

        self.model = YOLOWorld(self.model_path)
        self._device_arg = 0 if self.device.startswith("cuda") else "cpu"

    def get_available_classes(self):
        # Open-vocabulary: there is no fixed class list.
        return None

    def _require_classes(self, kwargs: Dict[str, Any]) -> List[str]:
        """Extract the mandatory class prompts from detect() kwargs."""
        classes = _normalize_classes(kwargs.get("classes") or kwargs.get("prompts") or kwargs.get("labels"))
        if not classes:
            raise ValueError("YOLOWorld requires `classes`/`prompts` (comma-separated string or list) for open-vocabulary detection")
        return classes

    def _apply_classes(self, classes: List[str]) -> None:
        """Tell the model which classes to look for, tolerating device issues.

        `set_classes` triggers CLIP text encoding which can hit a device
        mismatch; we work around it by temporarily moving the model to CPU,
        retrying, then restoring the original device.
        """
        if not hasattr(self.model, "set_classes"):
            return
        try:
            self.model.set_classes(classes)
        except RuntimeError as e:
            if "device" in str(e).lower() and hasattr(self.model, "model"):
                original_device = next(self.model.model.parameters()).device
                self.model.model.to("cpu")
                try:
                    self.model.set_classes(classes)
                finally:
                    # Restore the device even if the CPU retry fails.
                    self.model.model.to(original_device)
            else:
                # Fix: previously a device-mismatch error on a wrapper without
                # `.model` was silently swallowed; re-raise instead.
                raise

    @staticmethod
    def _result_to_detections(result: Any, classes: List[str]) -> List[Detection]:
        """Convert one Ultralytics result object into a list of Detections."""
        # `names` can be a dict, list, or something else depending on version.
        raw_names = getattr(result, "names", {})
        if isinstance(raw_names, dict):
            names = raw_names
        elif isinstance(raw_names, (list, tuple)):
            names = {i: str(n) for i, n in enumerate(raw_names)}
        else:
            names = {}

        boxes = getattr(result, "boxes", None)
        if boxes is None:
            return []

        xyxy = boxes.xyxy.detach().cpu().numpy()
        conf = boxes.conf.detach().cpu().numpy()
        cls = boxes.cls.detach().cpu().numpy().astype(int)

        dets: List[Detection] = []
        for (b, s, c) in zip(xyxy, conf, cls):
            idx = int(c)
            # Fallback: index into the user-provided classes if the mapping is missing.
            label = names.get(idx) or (classes[idx] if 0 <= idx < len(classes) else str(idx))
            dets.append(Detection(label=str(label), score=float(s), bbox_xyxy=[float(x) for x in b.tolist()]))
        return dets

    def detect(self, image: Image.Image, conf_threshold: float = 0.25, **kwargs: Any) -> List[Detection]:
        """Run open-vocabulary detection on a single image.

        Args:
            image: Input PIL image (converted to RGB internally).
            conf_threshold: Confidence threshold.
            **kwargs: Must include `classes`/`prompts`/`labels`.

        Returns:
            List of Detection objects.

        Raises:
            ValueError: If no class prompts are supplied.
        """
        assert self.model is not None

        classes = self._require_classes(kwargs)
        self._apply_classes(classes)

        img = np.asarray(image.convert("RGB"))
        res = self.model.predict(source=img, conf=float(conf_threshold), device=self._device_arg, verbose=False)
        if not res:
            return []
        return self._result_to_detections(res[0], classes)

    def detect_batch(self, images: List[Image.Image], conf_threshold: float = 0.25, **kwargs: Any) -> List[List[Detection]]:
        """Batch detection for video processing - much faster than frame-by-frame.

        Args:
            images: List of PIL Images
            conf_threshold: Confidence threshold
            **kwargs: Should include 'classes'/'prompts' for open-vocabulary detection

        Returns:
            List of detection lists, one per image
        """
        assert self.model is not None

        if not images:
            return []

        classes = self._require_classes(kwargs)
        self._apply_classes(classes)

        # TRUE batch inference over all frames at once.
        imgs = [np.asarray(img.convert("RGB")) for img in images]
        results = self.model.predict(source=imgs, conf=float(conf_threshold), device=self._device_arg, verbose=False)

        return [self._result_to_detections(result, classes) for result in results]
|
examples.sh
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Example commands for ROI-VAE compression
|
| 3 |
+
# Covers: Car, Building, Person, Boat
|
| 4 |
+
|
| 5 |
+
# Make sure you're in the roi-vae directory
|
| 6 |
+
cd "$(dirname "$0")"
|
| 7 |
+
|
| 8 |
+
# Create results directory if it doesn't exist
|
| 9 |
+
mkdir -p results
|
| 10 |
+
|
| 11 |
+
echo "ROI-VAE Compression Examples"
|
| 12 |
+
echo "============================="
|
| 13 |
+
echo
|
| 14 |
+
|
| 15 |
+
# ---------------------------------------------------------
|
| 16 |
+
# 1. CAR
|
| 17 |
+
# ---------------------------------------------------------
|
| 18 |
+
echo "1. CAR (YOLO)"
|
| 19 |
+
echo "-------------"
|
| 20 |
+
|
| 21 |
+
# 1a. Only Image
|
| 22 |
+
echo "a) Compressing (Image Only)..."
|
| 23 |
+
python roi_compressor.py \
|
| 24 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 25 |
+
--output results/car_compressed.jpg \
|
| 26 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 27 |
+
--sigma 0.3 \
|
| 28 |
+
--seg-method yolo \
|
| 29 |
+
--seg-classes car
|
| 30 |
+
|
| 31 |
+
# 1b. With Comparison
|
| 32 |
+
echo "b) Compressing (With Comparison)..."
|
| 33 |
+
python roi_compressor.py \
|
| 34 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 35 |
+
--output results/car_comparison.jpg \
|
| 36 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 37 |
+
--sigma 0.3 \
|
| 38 |
+
--seg-method yolo \
|
| 39 |
+
--seg-classes car \
|
| 40 |
+
--highlight
|
| 41 |
+
|
| 42 |
+
echo
|
| 43 |
+
echo "---------------------------------------------------------"
|
| 44 |
+
echo
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------
|
| 47 |
+
# 2. BUILDING
|
| 48 |
+
# ---------------------------------------------------------
|
| 49 |
+
echo "2. BUILDING (SegFormer)"
|
| 50 |
+
echo "-----------------------"
|
| 51 |
+
|
| 52 |
+
# 2a. Only Image
|
| 53 |
+
echo "a) Compressing (Image Only)..."
|
| 54 |
+
python roi_compressor.py \
|
| 55 |
+
--input data/images/building/000571767ec7a593.jpg \
|
| 56 |
+
--output results/building_compressed.jpg \
|
| 57 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 58 |
+
--sigma 0.3 \
|
| 59 |
+
--seg-classes building
|
| 60 |
+
|
| 61 |
+
# 2b. With Comparison
|
| 62 |
+
echo "b) Compressing (With Comparison)..."
|
| 63 |
+
python roi_compressor.py \
|
| 64 |
+
--input data/images/building/000571767ec7a593.jpg \
|
| 65 |
+
--output results/building_comparison.jpg \
|
| 66 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 67 |
+
--sigma 0.3 \
|
| 68 |
+
--seg-classes building \
|
| 69 |
+
--highlight
|
| 70 |
+
|
| 71 |
+
echo
|
| 72 |
+
echo "---------------------------------------------------------"
|
| 73 |
+
echo
|
| 74 |
+
|
| 75 |
+
# ---------------------------------------------------------
|
| 76 |
+
# 3. PERSON
|
| 77 |
+
# ---------------------------------------------------------
|
| 78 |
+
echo "3. PERSON (YOLO)"
|
| 79 |
+
echo "----------------"
|
| 80 |
+
|
| 81 |
+
# 3a. Only Image
|
| 82 |
+
echo "a) Compressing (Image Only)..."
|
| 83 |
+
python roi_compressor.py \
|
| 84 |
+
--input data/images/person/kodim04.png \
|
| 85 |
+
--output results/person_compressed.jpg \
|
| 86 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 87 |
+
--sigma 0.3 \
|
| 88 |
+
--seg-method yolo \
|
| 89 |
+
--seg-classes person
|
| 90 |
+
|
| 91 |
+
# 3b. With Comparison
|
| 92 |
+
echo "b) Compressing (With Comparison)..."
|
| 93 |
+
python roi_compressor.py \
|
| 94 |
+
--input data/images/person/kodim04.png \
|
| 95 |
+
--output results/person_comparison.jpg \
|
| 96 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 97 |
+
--sigma 0.3 \
|
| 98 |
+
--seg-method yolo \
|
| 99 |
+
--seg-classes person \
|
| 100 |
+
--highlight
|
| 101 |
+
|
| 102 |
+
echo
|
| 103 |
+
echo "---------------------------------------------------------"
|
| 104 |
+
echo
|
| 105 |
+
|
| 106 |
+
# ---------------------------------------------------------
|
| 107 |
+
# 4. BOAT
|
| 108 |
+
# ---------------------------------------------------------
|
| 109 |
+
echo "4. BOAT (YOLO)"
|
| 110 |
+
echo "--------------"
|
| 111 |
+
|
| 112 |
+
# 4a. Only Image
|
| 113 |
+
echo "a) Compressing (Image Only)..."
|
| 114 |
+
python roi_compressor.py \
|
| 115 |
+
--input data/images/boat/kodim06.png \
|
| 116 |
+
--output results/boat_compressed.jpg \
|
| 117 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 118 |
+
--sigma 0.3 \
|
| 119 |
+
--seg-method yolo \
|
| 120 |
+
--seg-classes boat
|
| 121 |
+
|
| 122 |
+
# 4b. With Comparison
|
| 123 |
+
echo "b) Compressing (With Comparison)..."
|
| 124 |
+
python roi_compressor.py \
|
| 125 |
+
--input data/images/boat/kodim06.png \
|
| 126 |
+
--output results/boat_comparison.jpg \
|
| 127 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 128 |
+
--sigma 0.3 \
|
| 129 |
+
--seg-method yolo \
|
| 130 |
+
--seg-classes boat \
|
| 131 |
+
--highlight
|
| 132 |
+
|
| 133 |
+
echo
|
| 134 |
+
echo "---------------------------------------------------------"
|
| 135 |
+
echo
|
| 136 |
+
|
| 137 |
+
# ---------------------------------------------------------
|
| 138 |
+
# 5. PARAMETER COMPARISON
|
| 139 |
+
# ---------------------------------------------------------
|
| 140 |
+
echo "5. PARAMETER COMPARISON"
|
| 141 |
+
echo "-----------------------"
|
| 142 |
+
|
| 143 |
+
# 5a. Sigma Comparison (Background Quality)
|
| 144 |
+
echo "a) Sigma Comparison (Background Quality)..."
|
| 145 |
+
# Low Sigma (High Compression)
|
| 146 |
+
echo " - Low Sigma (0.1)"
|
| 147 |
+
python roi_compressor.py \
|
| 148 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 149 |
+
--output results/sigma_low.jpg \
|
| 150 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 151 |
+
--sigma 0.1 \
|
| 152 |
+
--seg-method yolo \
|
| 153 |
+
--seg-classes car \
|
| 154 |
+
--highlight
|
| 155 |
+
|
| 156 |
+
# High Sigma (Low Compression)
|
| 157 |
+
echo " - High Sigma (0.9)"
|
| 158 |
+
python roi_compressor.py \
|
| 159 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 160 |
+
--output results/sigma_high.jpg \
|
| 161 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
|
| 162 |
+
--sigma 0.9 \
|
| 163 |
+
--seg-method yolo \
|
| 164 |
+
--seg-classes car \
|
| 165 |
+
--highlight
|
| 166 |
+
|
| 167 |
+
# 5b. Lambda Comparison (Rate-Distortion)
|
| 168 |
+
echo "b) Lambda Comparison (Rate-Distortion)..."
|
| 169 |
+
# Low Lambda
|
| 170 |
+
echo " - Low Lambda (0.013)"
|
| 171 |
+
python roi_compressor.py \
|
| 172 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 173 |
+
--output results/lambda_low.jpg \
|
| 174 |
+
--checkpoint checkpoints/tic_lambda_0.013.pth.tar \
|
| 175 |
+
--sigma 0.3 \
|
| 176 |
+
--lambda 0.013 \
|
| 177 |
+
--N 128 \
|
| 178 |
+
--M 192 \
|
| 179 |
+
--seg-method yolo \
|
| 180 |
+
--seg-classes car \
|
| 181 |
+
--highlight
|
| 182 |
+
|
| 183 |
+
# High Lambda
|
| 184 |
+
echo " - High Lambda (0.0932)"
|
| 185 |
+
python roi_compressor.py \
|
| 186 |
+
--input data/images/car/0016cf15fa4d4e16.jpg \
|
| 187 |
+
--output results/lambda_high.jpg \
|
| 188 |
+
--checkpoint checkpoints/tic_lambda_0.0932.pth.tar \
|
| 189 |
+
--sigma 0.3 \
|
| 190 |
+
--lambda 0.0932 \
|
| 191 |
+
--seg-method yolo \
|
| 192 |
+
--seg-classes car \
|
| 193 |
+
--highlight
|
| 194 |
+
|
| 195 |
+
echo
|
| 196 |
+
echo "---------------------------------------------------------"
|
| 197 |
+
echo
|
| 198 |
+
|
| 199 |
+
# -------------------------------------------------
# 5. Standalone Segmentation (Mask R-CNN)
# -------------------------------------------------
echo "--- Running Standalone Segmentation (Mask R-CNN) ---"
# Fix: input path made consistent with every other example in this script
# (they all read from data/images/..., not images/...).
python roi_segmenter.py \
    --input data/images/car/0016cf15fa4d4e16.jpg \
    --output results/mask_rcnn.png \
    --method maskrcnn \
    --classes car \
    --visualize
|
| 209 |
+
|
| 210 |
+
# -------------------------------------------------
|
| 211 |
+
# 6. Video Compression (Static Mode)
|
| 212 |
+
# -------------------------------------------------
|
| 213 |
+
echo "--- Running Video Compression (Static Mode) ---"
|
| 214 |
+
python roi_compressor.py \
|
| 215 |
+
--input data/videos/traffic.mp4 \
|
| 216 |
+
--output results/video_static.mp4 \
|
| 217 |
+
--quality-level 4 \
|
| 218 |
+
--sigma 0.4 \
|
| 219 |
+
--seg-method yolo \
|
| 220 |
+
--seg-classes car person \
|
| 221 |
+
--video-mode static \
|
| 222 |
+
--output-fps 10
|
| 223 |
+
|
| 224 |
+
# -------------------------------------------------
|
| 225 |
+
# 7. Video Compression (Dynamic Mode)
|
| 226 |
+
# -------------------------------------------------
|
| 227 |
+
echo "--- Running Video Compression (Dynamic Mode) ---"
|
| 228 |
+
python roi_compressor.py \
|
| 229 |
+
--input data/videos/traffic.mp4 \
|
| 230 |
+
--output results/video_dynamic.mp4 \
|
| 231 |
+
--quality-level 4 \
|
| 232 |
+
--seg-method yolo \
|
| 233 |
+
--seg-classes car person \
|
| 234 |
+
--video-mode dynamic \
|
| 235 |
+
--target-bandwidth-kbps 800 \
|
| 236 |
+
--min-fps 5 \
|
| 237 |
+
--max-fps 20
|
| 238 |
+
|
| 239 |
+
# -------------------------------------------------
|
| 240 |
+
# 8. Video Detection Evaluation (Static Mode)
|
| 241 |
+
# -------------------------------------------------
|
| 242 |
+
echo "--- Running Video Detection Evaluation (Static Mode) ---"
|
| 243 |
+
python roi_detection_eval.py \
|
| 244 |
+
--before data/videos/traffic.mp4 \
|
| 245 |
+
--video-mode static \
|
| 246 |
+
--sigma 0.3 \
|
| 247 |
+
--output-fps 10 \
|
| 248 |
+
--quality-level 4 \
|
| 249 |
+
--seg-method yolo \
|
| 250 |
+
--seg-classes car person \
|
| 251 |
+
--detectors yolo \
|
| 252 |
+
--max-video-frames 30 \
|
| 253 |
+
--video-sample-interval 3 \
|
| 254 |
+
--save-after results/video_eval_compressed.mp4 \
|
| 255 |
+
--out results/video_detection_eval.json \
|
| 256 |
+
--viz-dir results/video_detection_viz
|
| 257 |
+
|
| 258 |
+
# -------------------------------------------------
|
| 259 |
+
# 9. Video Detection Evaluation (Comparing Two Videos)
|
| 260 |
+
# -------------------------------------------------
|
| 261 |
+
echo "--- Running Video Detection Evaluation (Compare Two Videos) ---"
|
| 262 |
+
python roi_detection_eval.py \
|
| 263 |
+
--before data/videos/traffic.mp4 \
|
| 264 |
+
--after results/video_static.mp4 \
|
| 265 |
+
--detectors yolo \
|
| 266 |
+
--max-video-frames 30 \
|
| 267 |
+
--video-sample-interval 3 \
|
| 268 |
+
--out results/video_compare_eval.json
|
| 269 |
+
|
| 270 |
+
echo
|
| 271 |
+
echo "============================="
|
| 272 |
+
echo "All examples complete! Check results/ directory"
|
model_cache.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Centralized defaults for model checkpoint/cache locations.
|
| 2 |
+
|
| 3 |
+
Goal: keep all auto-downloaded model artifacts inside this repo's `checkpoints/`
|
| 4 |
+
directory by default (instead of user-wide cache dirs or repo root).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Repository root: the directory containing this file.
PROJECT_ROOT = Path(__file__).resolve().parent
# All auto-downloaded model artifacts live under here by default.
CHECKPOINTS_DIR = PROJECT_ROOT / "checkpoints"

# Hugging Face will create subfolders like `hub/`, `datasets/`, etc under HF_HOME.
HF_HOME_DIR = CHECKPOINTS_DIR / "hf"

# Torchvision uses torch.hub.load_state_dict_from_url which respects TORCH_HOME.
TORCH_HOME_DIR = CHECKPOINTS_DIR / "torch"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def ensure_default_checkpoint_dirs() -> None:
    """Create the checkpoint directories and point cache env vars at them.

    Best-effort helper: environment variables that the user has already set
    explicitly are left untouched (we only `setdefault`).
    """
    for directory in (CHECKPOINTS_DIR, HF_HOME_DIR, TORCH_HOME_DIR):
        directory.mkdir(parents=True, exist_ok=True)

    hub_cache = str(HF_HOME_DIR / "hub")
    # HF_HOME for Hugging Face proper; TRANSFORMERS_CACHE / HUGGINGFACE_HUB_CACHE
    # for compatibility across transformers/huggingface-hub versions; TORCH_HOME
    # for torch/torchvision's load_state_dict_from_url.
    defaults = {
        "HF_HOME": str(HF_HOME_DIR),
        "TRANSFORMERS_CACHE": hub_cache,
        "HUGGINGFACE_HUB_CACHE": hub_cache,
        "TORCH_HOME": str(TORCH_HOME_DIR),
    }
    for var, value in defaults.items():
        os.environ.setdefault(var, value)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def hf_cache_dir() -> Path:
    """Return the Hugging Face cache root, ensuring dirs/env vars are set up."""
    ensure_default_checkpoint_dirs()
    return HF_HOME_DIR
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def torch_home_dir() -> Path:
    """Return the TORCH_HOME directory, ensuring dirs/env vars are set up."""
    ensure_default_checkpoint_dirs()
    return TORCH_HOME_DIR
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def checkpoints_dir() -> Path:
    """Return the repo-local `checkpoints/` dir, ensuring dirs/env vars are set up."""
    ensure_default_checkpoint_dirs()
    return CHECKPOINTS_DIR
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def default_checkpoint_path(filename: str) -> str:
    """Return an absolute path under `checkpoints/` for a given filename.

    Note: calling this has the side effect of creating the checkpoint
    directories and setting cache env vars (via `checkpoints_dir`).
    """

    return str(checkpoints_dir() / filename)
|
requirements.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
torchvision>=0.15.0
|
| 4 |
+
numpy>=1.24.0
|
| 5 |
+
pillow>=9.5.0
|
| 6 |
+
matplotlib>=3.7.0
|
| 7 |
+
# Prefer headless OpenCV for servers / Hugging Face Spaces.
|
| 8 |
+
opencv-python-headless>=4.7.0
|
| 9 |
+
|
| 10 |
+
# Model dependencies
|
| 11 |
+
compressai>=1.2.4
|
| 12 |
+
timm>=0.9.0
|
| 13 |
+
|
| 14 |
+
# Segmentation
|
| 15 |
+
transformers>=4.36.0
|
| 16 |
+
ultralytics>=8.0.0
|
| 17 |
+
|
| 18 |
+
# Detection (optional)
|
| 19 |
+
effdet>=0.4.1
|
| 20 |
+
|
| 21 |
+
# Demo app (Hugging Face Spaces)
|
| 22 |
+
gradio>=6.2.0,<7
|
| 23 |
+
openai>=1.0.0
|
| 24 |
+
gradio_client
|
| 25 |
+
scipy
|
| 26 |
+
tqdm
|
roi_compressor.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CLI for ROI-based image/video compression using modular compression framework.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from segmentation import create_segmenter
|
| 15 |
+
from vae import load_checkpoint, compress_image
|
| 16 |
+
from vae.visualization import create_comparison_grid
|
| 17 |
+
from video import VideoProcessor, CompressionSettings
|
| 18 |
+
from video.video_processor import frames_to_video_bytes
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Command-line interface
|
| 22 |
+
def main():
    """CLI entry point: parse arguments and dispatch to image/video processing."""
    parser = argparse.ArgumentParser(description="ROI-based Image/Video Compressor.")
    # I/O
    parser.add_argument("--input", required=True, help="Path to input image or video file.")
    parser.add_argument("--output", required=True, help="Path to save compressed output file.")
    parser.add_argument("--checkpoint", default="checkpoints/tic_lambda_0.0483.pth.tar", help="Path to VAE model checkpoint.")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on.")

    # General Compression Settings
    parser.add_argument("--quality-level", type=int, default=4, choices=range(7), help="Base quality level (0-6, higher is better). Affects VAE model selection.")
    parser.add_argument("--sigma", type=float, default=0.3, help="Background quality factor (0.01-1.0). Lower means more background compression.")

    # Model metadata. Fix: examples.sh passes --lambda/--N/--M, which argparse
    # previously rejected as unrecognized arguments. They are accepted here for
    # compatibility; the checkpoint itself defines the architecture.
    parser.add_argument("--lambda", dest="lambda_val", type=float, default=None, help="Rate-distortion lambda of the checkpoint (informational, used for labeling).")
    parser.add_argument("--N", type=int, default=None, help="Model channel count N (informational).")
    parser.add_argument("--M", type=int, default=None, help="Model channel count M (informational).")

    # Segmentation Settings
    parser.add_argument("--seg-method", default="yolo", help="Segmentation method to use.")
    parser.add_argument("--seg-classes", nargs="+", required=True, help="Classes to segment as ROI.")
    parser.add_argument("--seg-model", help="Path to a specific segmentation model checkpoint (optional).")

    # Video-Specific Settings
    parser.add_argument("--video-mode", choices=['static', 'dynamic'], default='static', help="Video compression mode: 'static' for fixed settings, 'dynamic' for adaptive.")
    parser.add_argument("--target-bandwidth-kbps", type=int, default=1000, help="[Dynamic Mode] Target bandwidth in kbps.")
    parser.add_argument("--output-fps", type=float, default=15.0, help="[Static Mode] Output framerate.")
    parser.add_argument("--min-fps", type=float, default=5.0, help="[Dynamic Mode] Minimum framerate.")
    parser.add_argument("--max-fps", type=float, default=30.0, help="[Dynamic Mode] Maximum framerate.")
    parser.add_argument("--chunk-duration-sec", type=float, default=1.0, help="[Dynamic Mode] Duration of video chunks for analysis.")

    # Detection/Tracking Settings (for video)
    parser.add_argument("--detection-method", default="yolo", help="Object detection method for video.")
    parser.add_argument("--enable-tracking", action="store_true", help="Enable object tracking in video.")

    # Visualization
    parser.add_argument("--highlight", action="store_true", help="Create a comparison grid image (for image input only).")
    parser.add_argument("--viz-dir", help="Directory to save visualization artifacts (e.g., masks).")

    args = parser.parse_args()

    # --- Input Type Check ---
    # Dispatch on the file extension of the input path.
    is_video = Path(args.input).suffix.lower() in {'.mp4', '.avi', '.mov', '.mkv'}

    if is_video:
        print("Processing video input...")
        process_video(args)
    else:
        print("Processing image input...")
        process_image(args)
|
| 67 |
+
|
| 68 |
+
def process_image(args):
    """Compress a single image: segment the ROI classes, then run the ROI-VAE.

    Args:
        args: Parsed CLI namespace (see `main`). Uses input/output/checkpoint,
            device, sigma, seg-* options, viz-dir, and highlight.
    """
    print(f"Loading VAE model from {args.checkpoint}...")
    model = load_checkpoint(args.checkpoint, device=args.device)
    model.eval()

    print(f"Loading segmenter '{args.seg_method}'...")
    seg_kwargs = dict(device=args.device)
    if args.seg_model:
        seg_kwargs['model_path'] = args.seg_model
    segmenter = create_segmenter(args.seg_method, **seg_kwargs)

    print(f"Loading image from {args.input}...")
    image = Image.open(args.input).convert("RGB")

    print(f"Segmenting image for classes: {args.seg_classes}...")
    mask = segmenter(image, target_classes=args.seg_classes)

    if args.viz_dir:
        os.makedirs(args.viz_dir, exist_ok=True)
        mask_path = os.path.join(args.viz_dir, "mask.png")
        Image.fromarray((mask * 255).astype(np.uint8)).save(mask_path)
        print(f"Saved segmentation mask to {mask_path}")

    print(f"Compressing image with sigma={args.sigma}...")
    result = compress_image(
        image,
        mask,
        model,
        sigma=args.sigma,
        device=args.device
    )
    compressed_img = result['compressed']
    bpp = result['bpp']

    print(f"Saving compressed image to {args.output} (BPP: {bpp:.4f})")
    compressed_img.save(args.output)

    if args.highlight:
        print("Creating comparison grid...")
        lambda_val = _checkpoint_lambda(args)
        grid = create_comparison_grid(image, compressed_img, mask, bpp, args.sigma, lambda_val)
        # Fix: build the grid path from splitext instead of str.replace, which
        # replaced the first occurrence of the extension anywhere in the path
        # (and inserted "_comparison.jpg" everywhere for extensionless outputs).
        grid_path = os.path.splitext(args.output)[0] + "_comparison.jpg"
        grid.save(grid_path)
        print(f"Saved comparison grid to {grid_path}")


def _checkpoint_lambda(args) -> float:
    """Best-effort lambda value for labeling the comparison grid.

    Prefers an explicit --lambda CLI value when present; otherwise parses the
    `tic_lambda_<value>.pth.tar` naming convention, falling back to 0.0 instead
    of crashing on checkpoints with other names.
    """
    explicit = getattr(args, "lambda_val", None)
    if explicit is not None:
        return float(explicit)
    try:
        return float(os.path.basename(args.checkpoint).split('_')[-1].replace('.pth.tar', ''))
    except ValueError:
        return 0.0
|
| 114 |
+
|
| 115 |
+
def process_video(args):
    """Compresses a video using static or dynamic settings.

    Loads the segmentation/detection/compression models into a
    ``VideoProcessor``, runs offline (batch) compression in either
    'static' or 'dynamic' mode, then re-encodes all produced frames
    into a single output video written to ``args.output``.

    Args:
        args: Parsed CLI namespace. Reads: input, output, device,
            quality_level, seg_method, detection_method, enable_tracking,
            video_mode, sigma, output_fps, seg_classes,
            target_bandwidth_kbps, min_fps, max_fps, chunk_duration_sec.
    """
    processor = VideoProcessor(device=args.device)
    print("Loading models for video processing...")
    processor.load_models(
        quality_level=args.quality_level,
        segmentation_method=args.seg_method,
        detection_method=args.detection_method,
        enable_tracking=args.enable_tracking,
    )

    # Simple progress callback: prints "[ 42%] message" style lines.
    def progress_callback(current, total, message):
        # max(1, total) guards against division by zero when total == 0.
        percent = int(100 * current / max(1, total))
        print(f"[{percent:3d}%] {message}")

    if args.video_mode == 'static':
        # Static mode: fixed output FPS and fixed background sigma.
        settings = CompressionSettings(
            mode='static',
            quality_level=args.quality_level,
            sigma=args.sigma,
            output_fps=args.output_fps,
            target_classes=args.seg_classes,
        )
        print(f"Starting STATIC video compression with FPS={settings.output_fps}, Sigma={settings.sigma}...")
        print("Using offline batch processing (GPU memory optimized)...")
        chunks = processor.process_static_offline(args.input, settings, progress_callback=progress_callback)
    else:  # dynamic
        # Dynamic mode: the processor adapts FPS/quality per chunk to hit
        # the requested bandwidth budget.
        settings = CompressionSettings(
            mode='dynamic',
            target_bandwidth_kbps=args.target_bandwidth_kbps,
            min_fps=args.min_fps,
            max_fps=args.max_fps,
            chunk_duration_sec=args.chunk_duration_sec,
            target_classes=args.seg_classes,
            quality_level=args.quality_level,
        )
        print(f"Starting DYNAMIC video compression with Target Bandwidth={settings.target_bandwidth_kbps} kbps...")
        print("Using offline batch processing (GPU memory optimized)...")
        chunks = processor.process_dynamic_offline(args.input, settings, progress_callback=progress_callback)

    if not chunks:
        print("No frames were processed. Exiting.")
        return

    # Collect all frames from chunks into one flat list for re-encoding.
    all_frames = []
    for chunk in chunks:
        all_frames.extend(chunk.frames)

    # Determine the final video's FPS
    if args.video_mode == 'static':
        final_fps = args.output_fps
    else:
        # For dynamic, use weighted average FPS from chunks
        # (total frames divided by total playback duration).
        total_frames = sum(len(c.frames) for c in chunks)
        total_duration = sum(len(c.frames) / c.fps for c in chunks)
        # NOTE(review): assumes every chunk has c.fps > 0 — a zero-fps
        # chunk would raise ZeroDivisionError above; confirm upstream.
        final_fps = total_frames / total_duration if total_duration > 0 else args.max_fps

    print(f"\nRe-encoding {len(all_frames)} frames into final video at ~{final_fps:.2f} FPS...")
    video_bytes = frames_to_video_bytes(all_frames, fps=final_fps)

    print(f"Saving compressed video to {args.output}...")
    with open(args.output, "wb") as f:
        f.write(video_bytes)
    print("Done.")
|
| 181 |
+
|
| 182 |
+
# Script entry point: delegate to main() (defined earlier in this file).
if __name__ == "__main__":
    main()
|
roi_detection_eval.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI: evaluate object detection before vs after ROI compression.
|
| 2 |
+
|
| 3 |
+
Supports both images and videos.
|
| 4 |
+
|
| 5 |
+
Image modes:
|
| 6 |
+
1) Compare two images:
|
| 7 |
+
python roi_detection_eval.py --before img.jpg --after img_compressed.jpg --detectors yolo fasterrcnn
|
| 8 |
+
|
| 9 |
+
2) Create the "after" image via ROI compression, then evaluate:
|
| 10 |
+
python roi_detection_eval.py --before img.jpg \
|
| 11 |
+
--checkpoint checkpoints/tic_lambda_0.0483.pth.tar --sigma 0.3 \
|
| 12 |
+
--seg-method yolo --seg-classes car person \
|
| 13 |
+
--detectors yolo fasterrcnn
|
| 14 |
+
|
| 15 |
+
Video modes:
|
| 16 |
+
3) Compare two videos:
|
| 17 |
+
python roi_detection_eval.py --before video.mp4 --after video_compressed.mp4 --detectors yolo
|
| 18 |
+
|
| 19 |
+
4) Create the "after" video via ROI compression, then evaluate:
|
| 20 |
+
python roi_detection_eval.py --before video.mp4 \
|
| 21 |
+
--video-mode static --sigma 0.3 --output-fps 10 \
|
| 22 |
+
--seg-method yolo --seg-classes car person \
|
| 23 |
+
--detectors yolo
|
| 24 |
+
|
| 25 |
+
Outputs JSON summary + optional visualization images/videos.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import json
|
| 31 |
+
import os
|
| 32 |
+
import tempfile
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
from typing import Dict, List, Optional, Tuple
|
| 35 |
+
|
| 36 |
+
import numpy as np
|
| 37 |
+
from PIL import Image
|
| 38 |
+
|
| 39 |
+
from detection import create_detector, get_available_detectors
|
| 40 |
+
from detection.utils import (
|
| 41 |
+
detections_to_dict,
|
| 42 |
+
draw_detections,
|
| 43 |
+
summarize_before_after,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Lowercase file suffixes (including the dot) treated as video inputs
# by _is_video(); anything else is handled as a still image.
VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.webm'}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _is_video(path: str) -> bool:
    """Return True when *path* carries a recognized video file extension."""
    suffix = Path(path).suffix
    return suffix.lower() in VIDEO_EXTENSIONS
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _load_image(path: str) -> Image.Image:
    """Open *path* with PIL and return it converted to RGB."""
    img = Image.open(path)
    return img.convert("RGB")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _normalize_open_vocab_classes(raw: Optional[str]) -> Optional[str]:
|
| 60 |
+
if raw is None:
|
| 61 |
+
return None
|
| 62 |
+
s = str(raw).strip()
|
| 63 |
+
return s or None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _maybe_make_after_image(
    before_img: Image.Image,
    args,
) -> Image.Image:
    """Create after image via ROI compression if --after not provided.

    When ``args.after`` is set, simply loads and returns that image.
    Otherwise builds an ROI mask (from ``--mask`` or a segmenter), runs the
    TIC compression model from ``args.checkpoint``, and optionally saves the
    result to ``args.save_after``.

    Raises:
        SystemExit: when required CLI arguments for mask/checkpoint
            generation are missing.
    """
    if args.after:
        return _load_image(args.after)

    if not args.checkpoint:
        raise SystemExit("Provide either --after or --checkpoint (to generate after via ROI compression).")

    # Deferred imports: only needed on the compression path.
    from vae import load_checkpoint, compress_image
    from segmentation import create_segmenter
    from segmentation.utils import load_mask

    if args.mask:
        # Pre-computed mask supplied on the command line.
        mask = load_mask(args.mask)
    else:
        if not args.seg_method:
            raise SystemExit("When not using --mask, provide --seg-method.")

        # Per-method segmenter construction options.
        seg_kwargs = {}
        if args.seg_method == "yolo":
            seg_kwargs["conf_threshold"] = args.seg_conf
        if args.seg_method == "mask2former":
            seg_kwargs["model_type"] = args.seg_model_type

        segmenter = create_segmenter(args.seg_method, device=args.device, **seg_kwargs)

        if args.seg_method == "sam3":
            # sam3 is prompt-driven: a single free-text prompt instead of
            # a fixed class list.
            if not args.seg_prompt:
                raise SystemExit("For --seg-method sam3, provide --seg-prompt.")
            mask = segmenter(before_img, target_classes=[args.seg_prompt])
        else:
            if not args.seg_classes:
                raise SystemExit("Provide --seg-classes (or use --seg-method sam3 + --seg-prompt).")
            mask = segmenter(before_img, target_classes=args.seg_classes)

    # Run the ROI compression model; sigma controls background quality.
    model = load_checkpoint(args.checkpoint, N=args.N, M=args.M, device=args.device)
    out = compress_image(before_img, mask, model, sigma=float(args.sigma), device=args.device)
    after_img = out["compressed"]

    if args.save_after:
        Path(args.save_after).parent.mkdir(parents=True, exist_ok=True)
        after_img.save(args.save_after)

    return after_img
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _extract_video_frames(
    video_path: str,
    max_frames: Optional[int] = None,
    sample_interval: int = 1,
) -> Tuple[List[Image.Image], float]:
    """Decode frames from a video file as RGB PIL images.

    Args:
        video_path: Path to video file.
        max_frames: Stop once this many frames have been collected
            (None = no limit).
        sample_interval: Keep every Nth decoded frame.

    Returns:
        (frames, fps) where fps falls back to 30.0 if the container
        reports none.
    """
    try:
        import cv2
    except ImportError:
        raise ImportError("opencv-python required for video processing")

    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise ValueError(f"Cannot open video: {video_path}")

    fps = capture.get(cv2.CAP_PROP_FPS) or 30.0
    extracted: List[Image.Image] = []
    idx = 0

    while True:
        ok, bgr = capture.read()
        if not ok:
            break

        # OpenCV decodes as BGR; convert before wrapping in PIL.
        if idx % sample_interval == 0:
            extracted.append(Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)))

        idx += 1

        if max_frames is not None and len(extracted) >= max_frames:
            break

    capture.release()
    return extracted, fps
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _maybe_make_after_video(
    before_path: str,
    args,
) -> Tuple[str, List[Image.Image], List[Image.Image]]:
    """Create after video via ROI compression if --after not provided.

    Extracts (sub-sampled) frames from the "before" video; when
    ``args.after`` exists, extracts the same way from it. Otherwise runs
    the full video ROI-compression pipeline, aligns the compressed frames
    to the sampled before-frames by index proportion, and writes the
    compressed video either to ``args.save_after`` or a temp .mp4.

    Returns:
        (after_video_path, before_frames, after_frames)
    """
    from video import VideoProcessor, CompressionSettings
    from video.video_processor import frames_to_video_bytes

    # Extract frames from before video
    sample_interval = max(1, int(args.video_sample_interval))
    before_frames, fps = _extract_video_frames(
        before_path,
        max_frames=args.max_video_frames,
        sample_interval=sample_interval,
    )
    print(f"Extracted {len(before_frames)} frames from before video (sampled every {sample_interval} frames)")

    if args.after:
        # Load after video frames with identical sampling so pairs align.
        after_frames, _ = _extract_video_frames(
            args.after,
            max_frames=args.max_video_frames,
            sample_interval=sample_interval,
        )
        print(f"Extracted {len(after_frames)} frames from after video")
        return args.after, before_frames, after_frames

    # Create compressed video
    print("Creating compressed video via ROI compression...")

    processor = VideoProcessor(device=args.device)
    processor.load_models(
        quality_level=args.quality_level,
        segmentation_method=args.seg_method or "yolo",
        detection_method="yolo",
        enable_tracking=False,
    )

    if args.video_mode == "static":
        settings = CompressionSettings(
            mode="static",
            quality_level=args.quality_level,
            sigma=args.sigma,
            output_fps=args.output_fps,
            target_classes=args.seg_classes or [],
            enable_tracking=False,
        )
        chunks = processor.process_static(before_path, settings)
    else:  # dynamic
        settings = CompressionSettings(
            mode="dynamic",
            target_bandwidth_kbps=args.target_bandwidth_kbps,
            min_fps=args.min_fps,
            max_fps=args.max_fps,
            chunk_duration_sec=args.chunk_duration_sec,
            target_classes=args.seg_classes or [],
            quality_level=args.quality_level,
            enable_tracking=False,
        )
        chunks = processor.process_dynamic(before_path, settings)

    # Collect compressed frames (capped at max_video_frames).
    all_compressed_frames = []
    for chunk in chunks:
        all_compressed_frames.extend(chunk.frames)
        if args.max_video_frames and len(all_compressed_frames) >= args.max_video_frames:
            all_compressed_frames = all_compressed_frames[:args.max_video_frames]
            break

    print(f"Compressed {len(all_compressed_frames)} frames")

    # Sample the after frames at the same interval as before frames
    # Note: The compression may change frame count, so we align by time proportion
    after_frames = []
    if len(all_compressed_frames) >= len(before_frames):
        # Sample at regular intervals
        # NOTE(review): assumes before_frames is non-empty — an empty
        # "before" extraction would raise ZeroDivisionError here; confirm.
        step = len(all_compressed_frames) / len(before_frames)
        for i in range(len(before_frames)):
            idx = min(int(i * step), len(all_compressed_frames) - 1)
            after_frames.append(all_compressed_frames[idx])
    else:
        # Use all available frames
        after_frames = all_compressed_frames
        # Extend before_frames to match
        before_frames = before_frames[:len(after_frames)]

    # Save compressed video if requested
    # NOTE(review): the video is always encoded at args.output_fps, even
    # in dynamic mode where chunks may target a different fps — verify.
    after_path = args.save_after
    if after_path:
        video_bytes = frames_to_video_bytes(all_compressed_frames, fps=args.output_fps)
        Path(after_path).parent.mkdir(parents=True, exist_ok=True)
        with open(after_path, "wb") as f:
            f.write(video_bytes)
        print(f"Saved compressed video to {after_path}")
    else:
        # Create temp file (delete=False so the path outlives this scope).
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            after_path = tmp.name
        video_bytes = frames_to_video_bytes(all_compressed_frames, fps=args.output_fps)
        with open(after_path, "wb") as f:
            f.write(video_bytes)

    return after_path, before_frames, after_frames
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _evaluate_video_detections(
    before_frames: List[Image.Image],
    after_frames: List[Image.Image],
    detector,
    det_kwargs: Dict,
    iou_threshold: float,
) -> Dict:
    """Run the detector over paired before/after frames and tally results.

    For each frame pair, detections are matched via
    ``summarize_before_after`` at the given IoU threshold; per-frame
    stats plus aggregates (retention rate, averages) are returned.

    Returns:
        Summary dict with per-frame and aggregate stats.
    """
    per_frame: List[Dict] = []
    agg = {"before": 0, "after": 0, "matched": 0, "lost": 0, "new": 0}

    for frame_idx, (frame_b, frame_a) in enumerate(zip(before_frames, after_frames)):
        dets_b = detector(frame_b, **det_kwargs)
        dets_a = detector(frame_a, **det_kwargs)

        stats = summarize_before_after(dets_b, dets_a, iou_threshold=iou_threshold)

        n_before = stats["num_before"]
        n_after = stats["num_after"]
        n_matched = stats["matched"]
        n_lost = n_before - n_matched
        n_new = stats["new_after"]

        per_frame.append({
            "frame_index": frame_idx,
            "before_count": n_before,
            "after_count": n_after,
            "matched": n_matched,
            "lost": n_lost,
            "new_detections": n_new,
        })

        agg["before"] += n_before
        agg["after"] += n_after
        agg["matched"] += n_matched
        agg["lost"] += n_lost
        agg["new"] += n_new

    # max(..., 1) guards against division by zero on empty inputs.
    return {
        "total_frames": len(before_frames),
        "total_before_detections": agg["before"],
        "total_after_detections": agg["after"],
        "total_matched": agg["matched"],
        "total_lost": agg["lost"],
        "total_new": agg["new"],
        "retention_rate": agg["matched"] / max(agg["before"], 1),
        "avg_before_per_frame": agg["before"] / max(len(before_frames), 1),
        "avg_after_per_frame": agg["after"] / max(len(after_frames), 1),
        "frame_results": per_frame,
    }
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def _create_video_visualization(
    before_frames: List[Image.Image],
    after_frames: List[Image.Image],
    before_dets_list: List[List],
    after_dets_list: List[List],
    output_dir: Path,
    method: str,
    fps: float = 10.0,
) -> None:
    """Write `<method>_before.mp4` / `<method>_after.mp4` with detections drawn.

    Before-frame boxes are green, after-frame boxes are red; frames
    without a corresponding detection list are drawn with none.
    """
    from video.video_processor import frames_to_video_bytes

    annotated_before: List[Image.Image] = []
    annotated_after: List[Image.Image] = []

    # Annotate each paired frame.
    for idx, (frame_b, frame_a) in enumerate(zip(before_frames, after_frames)):
        dets_b = before_dets_list[idx] if idx < len(before_dets_list) else []
        dets_a = after_dets_list[idx] if idx < len(after_dets_list) else []
        annotated_before.append(draw_detections(frame_b, dets_b, color=(0, 255, 0)))
        annotated_after.append(draw_detections(frame_a, dets_a, color=(255, 0, 0)))

    # Encode and write the two visualization videos.
    with open(output_dir / f"{method}_before.mp4", "wb") as f:
        f.write(frames_to_video_bytes(annotated_before, fps=fps))
    with open(output_dir / f"{method}_after.mp4", "wb") as f:
        f.write(frames_to_video_bytes(annotated_after, fps=fps))
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def main_image(args) -> None:
    """Run detection evaluation on a single image pair.

    Loads (or generates via ROI compression) the after image, runs each
    requested detector on both images, matches detections at
    ``args.iou``, and writes a JSON report to ``args.out`` (plus
    optional annotated PNGs under ``args.viz_dir``).

    Raises:
        SystemExit: when an open-vocabulary detector is requested
            without ``--open-vocab-classes``.
    """
    # Fix: removed unused `from model_cache import default_checkpoint_path`
    # (never referenced in this function; main() imports it where needed).
    before_img = _load_image(args.before)
    after_img = _maybe_make_after_image(before_img, args)

    det_methods = args.detectors
    if len(det_methods) == 1 and det_methods[0].lower() == "all":
        det_methods = get_available_detectors()

    results: Dict = {
        "type": "image",
        "before": str(Path(args.before)),
        "after": str(Path(args.after)) if args.after else None,
        "det_conf": float(args.det_conf),
        "iou_threshold": float(args.iou),
        "open_vocab_classes": _normalize_open_vocab_classes(args.open_vocab_classes),
        "detectors": {},
    }

    viz_dir = Path(args.viz_dir) if args.viz_dir else None
    if viz_dir:
        viz_dir.mkdir(parents=True, exist_ok=True)

    for method in det_methods:
        det_kwargs = _get_detector_kwargs(method, args)
        detector = create_detector(method, device=args.device, **det_kwargs)

        call_kwargs = {"conf_threshold": args.det_conf}
        if method in {"yolo_world", "grounding_dino"}:
            # Open-vocabulary detectors need an explicit class prompt.
            ov = _normalize_open_vocab_classes(args.open_vocab_classes)
            if not ov:
                raise SystemExit(
                    f"Detector '{method}' is open-vocabulary; provide --open-vocab-classes (e.g. 'person,car')."
                )
            call_kwargs["classes"] = ov

        before_dets = detector(before_img, **call_kwargs)
        after_dets = detector(after_img, **call_kwargs)

        summary = summarize_before_after(before_dets, after_dets, iou_threshold=args.iou)

        results["detectors"][method] = {
            "summary": summary,
            "before_detections": detections_to_dict(before_dets),
            "after_detections": detections_to_dict(after_dets),
        }

        if viz_dir:
            # Green = before detections, red = after detections.
            b = draw_detections(before_img, before_dets, color=(0, 255, 0))
            a = draw_detections(after_img, after_dets, color=(255, 0, 0))
            b.save(viz_dir / f"{method}_before.png")
            a.save(viz_dir / f"{method}_after.png")

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(results, indent=2))

    print(f"Wrote: {out_path}")
    if viz_dir:
        print(f"Wrote visualizations to: {viz_dir}")
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def main_video(args) -> None:
    """Run detection evaluation on videos.

    Extracts/generates paired before/after frame lists, evaluates each
    requested detector across all frame pairs, prints per-detector
    retention stats, and writes a JSON report (plus optional annotated
    visualization videos when ``args.viz_dir`` is set).
    """
    print(f"Video detection evaluation: {args.before}")

    after_path, before_frames, after_frames = _maybe_make_after_video(args.before, args)

    det_methods = args.detectors
    if len(det_methods) == 1 and det_methods[0].lower() == "all":
        det_methods = get_available_detectors()

    results: Dict = {
        "type": "video",
        "before": str(Path(args.before)),
        "after": str(Path(after_path)),
        "det_conf": float(args.det_conf),
        "iou_threshold": float(args.iou),
        "open_vocab_classes": _normalize_open_vocab_classes(args.open_vocab_classes),
        "video_settings": {
            "mode": args.video_mode,
            "sigma": args.sigma,
            "output_fps": args.output_fps,
            "quality_level": args.quality_level,
            "frames_evaluated": len(before_frames),
        },
        "detectors": {},
    }

    viz_dir = Path(args.viz_dir) if args.viz_dir else None
    if viz_dir:
        viz_dir.mkdir(parents=True, exist_ok=True)

    for method in det_methods:
        print(f"Evaluating with detector: {method}")
        det_kwargs = _get_detector_kwargs(method, args)
        detector = create_detector(method, device=args.device, **det_kwargs)

        call_kwargs = {"conf_threshold": args.det_conf}
        if method in {"yolo_world", "grounding_dino"}:
            # Open-vocabulary detectors need an explicit class prompt.
            ov = _normalize_open_vocab_classes(args.open_vocab_classes)
            if not ov:
                raise SystemExit(
                    f"Detector '{method}' is open-vocabulary; provide --open-vocab-classes (e.g. 'person,car')."
                )
            call_kwargs["classes"] = ov

        # Evaluate across all frames
        video_summary = _evaluate_video_detections(
            before_frames,
            after_frames,
            detector,
            call_kwargs,
            args.iou,
        )

        results["detectors"][method] = {
            "summary": {
                "total_frames": video_summary["total_frames"],
                "retention_rate": video_summary["retention_rate"],
                "total_before_detections": video_summary["total_before_detections"],
                "total_after_detections": video_summary["total_after_detections"],
                "total_matched": video_summary["total_matched"],
                "total_lost": video_summary["total_lost"],
                "avg_before_per_frame": video_summary["avg_before_per_frame"],
                "avg_after_per_frame": video_summary["avg_after_per_frame"],
            },
            "per_frame_results": video_summary["frame_results"],
        }

        print(f" Retention rate: {video_summary['retention_rate']:.2%}")
        print(f" Avg detections: {video_summary['avg_before_per_frame']:.1f} before, {video_summary['avg_after_per_frame']:.1f} after")

        if viz_dir:
            # Create visualization videos
            # NOTE(review): this re-runs the detector on every frame pair,
            # duplicating the inference already done inside
            # _evaluate_video_detections — consider reusing those results.
            print(f" Creating visualization videos...")
            before_dets_list = []
            after_dets_list = []
            for bf, af in zip(before_frames, after_frames):
                before_dets_list.append(detector(bf, **call_kwargs))
                after_dets_list.append(detector(af, **call_kwargs))

            _create_video_visualization(
                before_frames,
                after_frames,
                before_dets_list,
                after_dets_list,
                viz_dir,
                method,
                fps=args.output_fps,
            )

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(results, indent=2))

    print(f"\nWrote: {out_path}")
    if viz_dir:
        print(f"Wrote visualizations to: {viz_dir}")
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def _get_detector_kwargs(method: str, args) -> Dict:
|
| 530 |
+
"""Get detector-specific kwargs."""
|
| 531 |
+
det_kwargs = {}
|
| 532 |
+
if method == "yolo":
|
| 533 |
+
det_kwargs["model_path"] = args.yolo_weights
|
| 534 |
+
if method == "yolo_world":
|
| 535 |
+
det_kwargs["model_path"] = args.yolo_world_weights
|
| 536 |
+
if method == "efficientdet":
|
| 537 |
+
det_kwargs["model_name"] = args.efficientdet_name
|
| 538 |
+
if method == "detr":
|
| 539 |
+
det_kwargs["model_name"] = args.detr_model
|
| 540 |
+
if method == "deformable_detr":
|
| 541 |
+
det_kwargs["model_name"] = args.deformable_detr_model
|
| 542 |
+
if method == "grounding_dino":
|
| 543 |
+
det_kwargs["model_name"] = args.grounding_dino_model
|
| 544 |
+
return det_kwargs
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def main() -> None:
    """CLI entry point.

    Defines and parses all command-line arguments, then dispatches to
    main_video() or main_image() based on the --before file extension.
    """
    import argparse

    # Imported here (not at module top) — provides default weight paths.
    from model_cache import default_checkpoint_path

    parser = argparse.ArgumentParser(
        description="Evaluate object detection before vs after ROI compression (images or videos)"
    )

    # Input/Output
    parser.add_argument("--before", required=True, help="Path to original (before) image or video")
    parser.add_argument("--after", help="Path to already-compressed (after) image or video")
    parser.add_argument("--out", default="results/detection_eval.json", help="Where to write JSON results")
    parser.add_argument("--viz-dir", default=None, help="If set, write visualization images/videos here")
    parser.add_argument("--save-after", help="Save generated after image/video here")

    # ROI compression (if --after is not provided)
    parser.add_argument("--checkpoint", help="TIC checkpoint to generate after image (images only)")
    parser.add_argument("--sigma", type=float, default=0.3, help="Background quality for ROI compression")
    parser.add_argument("--mask", help="Optional mask path to use for ROI compression (images only)")
    parser.add_argument("--seg-method", default=None, help="Segmentation method to build mask")
    parser.add_argument("--seg-classes", nargs="+", default=None, help="Segmentation classes")
    parser.add_argument("--seg-prompt", default=None, help="Segmentation prompt (sam3)")
    parser.add_argument("--seg-conf", type=float, default=0.25, help="Segmentation conf for yolo")
    parser.add_argument("--seg-model-type", default="coco", help="mask2former model_type")

    # Video-specific settings
    parser.add_argument("--video-mode", choices=["static", "dynamic"], default="static",
                        help="Video compression mode")
    parser.add_argument("--quality-level", type=int, default=4, help="Quality level (1-5)")
    parser.add_argument("--output-fps", type=float, default=10.0, help="Output FPS for compressed video")
    parser.add_argument("--target-bandwidth-kbps", type=float, default=500.0,
                        help="Target bandwidth for dynamic mode")
    parser.add_argument("--min-fps", type=float, default=5.0, help="Min FPS for dynamic mode")
    parser.add_argument("--max-fps", type=float, default=30.0, help="Max FPS for dynamic mode")
    parser.add_argument("--chunk-duration-sec", type=float, default=1.0, help="Chunk duration for dynamic mode")
    parser.add_argument("--max-video-frames", type=int, default=100,
                        help="Max frames to evaluate from video (for efficiency)")
    parser.add_argument("--video-sample-interval", type=int, default=5,
                        help="Sample every Nth frame from videos")

    # Detection settings
    parser.add_argument(
        "--detectors",
        nargs="+",
        default=["yolo"],
        help=f"Detectors to run (or 'all'). Available: {', '.join(get_available_detectors())}",
    )
    parser.add_argument("--det-conf", type=float, default=0.25, help="Detection confidence threshold")
    parser.add_argument("--iou", type=float, default=0.5, help="IoU threshold for matching before↔after detections")

    # Open-vocabulary detection
    parser.add_argument(
        "--open-vocab-classes",
        default=None,
        help="Comma-separated class prompts for open-vocabulary detectors",
    )

    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="Device")
    # N/M: TIC model channel dimensions, used when generating the after image.
    parser.add_argument("--N", type=int, default=192)
    parser.add_argument("--M", type=int, default=192)

    # Per-detector config
    parser.add_argument(
        "--yolo-weights",
        default=default_checkpoint_path("yolo26x.pt"),
        help="Ultralytics YOLO weights path/name",
    )
    parser.add_argument(
        "--yolo-world-weights",
        default=default_checkpoint_path("yolo26s-world.pt"),
        help="Ultralytics YOLO-World weights path/name",
    )
    parser.add_argument("--efficientdet-name", default="tf_efficientdet_d0", help="EfficientDet model name")
    parser.add_argument("--detr-model", default="facebook/detr-resnet-50", help="DETR model name")
    parser.add_argument("--deformable-detr-model", default="SenseTime/deformable-detr", help="Deformable DETR model")
    parser.add_argument(
        "--grounding-dino-model",
        default="IDEA-Research/grounding-dino-base",
        help="GroundingDINO model name",
    )

    args = parser.parse_args()

    # Determine if input is image or video
    if _is_video(args.before):
        main_video(args)
    else:
        main_image(args)
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
# Script entry point: delegate to main() (defined above).
if __name__ == "__main__":
    main()
|
roi_segmenter.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CLI for ROI segmentation using modular segmentation framework.
|
| 3 |
+
Supports both image and video input with batched processing.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from segmentation import create_segmenter
|
| 14 |
+
from segmentation.utils import visualize_mask, save_mask, calculate_roi_stats
|
| 15 |
+
from video import estimate_batch_sizes, smooth_masks_sdf
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# File extensions treated as video input by the CLI.
VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.webm'}


def is_video(path: str) -> bool:
    """Return True when *path* carries a recognised video file extension."""
    suffix = Path(path).suffix
    return suffix.lower() in VIDEO_EXTENSIONS
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def extract_video_frames(video_path: str, target_height: int = None) -> tuple[list[Image.Image], float]:
    """Extract frames from video file.

    Args:
        video_path: Path to video file
        target_height: Optional height to resize frames to (maintains aspect ratio)

    Returns:
        (frames, fps)
    """
    # cv2 is imported lazily so the module works without it for image input.
    try:
        import cv2
    except ImportError:
        raise ImportError("opencv-python required for video processing. Install with: pip install opencv-python")

    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise ValueError(f"Cannot open video: {video_path}")

    # Some containers report 0 FPS; fall back to a sane default.
    fps = capture.get(cv2.CAP_PROP_FPS) or 30.0

    frames: list[Image.Image] = []
    while True:
        ok, bgr = capture.read()
        if not ok:
            break

        if target_height:
            height, width = bgr.shape[:2]
            resized_width = max(1, int(width * (target_height / height)))
            bgr = cv2.resize(bgr, (resized_width, target_height))

        # OpenCV decodes BGR; PIL expects RGB.
        frames.append(Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)))

    capture.release()
    return frames, fps
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def save_masks_as_video(masks: list[np.ndarray], output_path: str, fps: float = 30.0) -> None:
    """Save mask sequence as video file."""
    try:
        import cv2
    except ImportError:
        raise ImportError("opencv-python required. Install with: pip install opencv-python")

    if not masks:
        raise ValueError("No masks to save")

    height, width = masks[0].shape
    writer = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        (width, height),
        isColor=False,
    )

    # Binary {0, 1} masks are scaled to {0, 255} so they are visible.
    for frame_mask in masks:
        writer.write((frame_mask * 255).astype(np.uint8))

    writer.release()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def process_image(args):
    """Segment a single image, report ROI stats, and save the mask.

    Args:
        args: Parsed CLI namespace providing input, output, method, classes,
            prompt, conf_threshold, visualize and device attributes.
    """
    print(f"Segmenting {args.input}...")
    print(f"  Method: {args.method}")
    if args.method == 'sam3':
        print(f"  Prompt: {args.prompt or ' '.join(args.classes)}")
    else:
        print(f"  Classes: {args.classes}")

    # Load image
    image = Image.open(args.input).convert('RGB')
    print(f"  Image size: {image.size}")

    # Create segmenter and mask
    seg_kwargs = {}
    if args.method == 'yolo':
        seg_kwargs['conf_threshold'] = args.conf_threshold

    segmenter = create_segmenter(args.method, device=args.device, **seg_kwargs)

    # SAM3 accepts a free-form natural-language prompt in place of class names.
    targets = args.classes
    if args.method == 'sam3' and args.prompt:
        targets = [args.prompt]

    mask = segmenter(image, target_classes=targets)

    # Calculate statistics
    stats = calculate_roi_stats(mask)
    print(f"  ROI coverage: {stats['roi_percentage']:.2f}% "
          f"({stats['roi_pixels']}/{stats['total_pixels']} pixels)")

    # Save mask
    save_mask(mask, args.output)
    print(f"  Mask saved: {args.output}")

    # Save visualization if requested
    if args.visualize:
        # Insert "_overlay" before the file extension. The previous
        # args.output.replace('.', '_overlay.') replaced EVERY dot and
        # corrupted paths such as "out.v2.png" or dotted directory names.
        out = Path(args.output)
        viz_path = str(out.with_name(f"{out.stem}_overlay{out.suffix}"))
        viz_img = visualize_mask(image, mask, alpha=0.5, color=(255, 0, 0))
        viz_img.save(viz_path)
        print(f"  Visualization saved: {viz_path}")

    print("Done!")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def process_video(args):
    """Process video input with batched segmentation.

    Pipeline: extract frames -> build segmenter -> (optionally) estimate a
    GPU batch size -> segment with OOM-halving retries -> optional SDF
    temporal smoothing -> save masks as frames or a video (+ optional
    overlay visualization).

    Args:
        args: Parsed CLI namespace (input, output, method, classes, prompt,
            conf_threshold, device, resize_height, save_frames, no_smooth,
            smooth_alpha, smooth_patience, visualize).
    """
    print(f"Processing video: {args.input}")
    print(f"  Method: {args.method}")
    print(f"  Classes: {args.classes}")

    # Extract frames
    print("Extracting frames...")
    frames, fps = extract_video_frames(args.input, target_height=args.resize_height)
    print(f"  Extracted {len(frames)} frames at {fps:.2f} FPS")

    # Create segmenter
    seg_kwargs = {}
    if args.method == 'yolo':
        seg_kwargs['conf_threshold'] = args.conf_threshold

    segmenter = create_segmenter(args.method, device=args.device, **seg_kwargs)

    # SAM3 takes a single free-form prompt instead of a class list.
    targets = args.classes
    if args.method == 'sam3' and args.prompt:
        targets = [args.prompt]

    # Check if batched segmentation is supported
    supports_batch = getattr(segmenter, 'supports_batch', False)

    print(f"  Segmenter supports batching: {supports_batch}")

    # Estimate optimal batch size if GPU is being used
    if args.device == 'cuda' and supports_batch:
        try:
            # Get first frame dimensions for estimation
            sample_w, sample_h = frames[0].size
            batch_info = estimate_batch_sizes(
                device=args.device,
                seg_method=args.method,
                frame_width=sample_w,
                frame_height=sample_h,
                total_frames=len(frames),
            )
            recommended_batch = batch_info.seg_batch_size
            print(f"  GPU Memory Estimation:")
            print(f"    Free VRAM: {batch_info.free_vram_bytes / (1024**3):.2f} GB")
            print(f"    Recommended batch size: {recommended_batch}")
            if batch_info.notes:
                print(f"    {batch_info.notes}")
        except Exception as e:
            # Estimation is best-effort; fall back to the segmenter's default.
            print(f"  Warning: Could not estimate GPU batch size: {e}")
            recommended_batch = None
    else:
        recommended_batch = None

    # Segment frames with OOM retry logic
    t0 = time.perf_counter()
    if supports_batch and hasattr(segmenter, 'segment_batch'):
        print(f"  Segmenting {len(frames)} frames in batches...")

        # OOM retry logic: each CUDA OOM halves the batch size until either
        # segmentation succeeds or max_retries is exhausted.
        max_retries = 7
        retry_count = 0
        masks = None
        current_batch_size = recommended_batch  # None means use segmenter's default

        while retry_count <= max_retries:
            try:
                if current_batch_size is not None:
                    # Segment in manual batches
                    masks = []
                    for i in range(0, len(frames), current_batch_size):
                        batch = frames[i:i + current_batch_size]
                        print(f"    Batch {i//current_batch_size + 1}: frames {i}-{i+len(batch)-1} (batch_size={len(batch)})")
                        batch_masks = segmenter.segment_batch(batch, target_classes=targets)
                        masks.extend(batch_masks)
                else:
                    # Let segmenter handle batching internally
                    masks = segmenter.segment_batch(frames, target_classes=targets)

                # Success - break retry loop
                break

            except torch.cuda.OutOfMemoryError as e:
                retry_count += 1
                if retry_count > max_retries:
                    print(f"  ERROR: Out of memory after {max_retries} retries. Try reducing --resize-height.")
                    raise

                # Halve the batch size
                if current_batch_size is None:
                    # Initial OOM with default batching - start with a reasonable size
                    current_batch_size = max(1, len(frames) // 4)
                else:
                    current_batch_size = max(1, current_batch_size // 2)

                print(f"  Out of memory! Retry {retry_count}/{max_retries} with batch_size={current_batch_size}")
                torch.cuda.empty_cache()
                masks = None  # Reset

        # Defensive: masks stays None only if the loop exited without a break
        # and without re-raising, which should not happen in practice.
        if masks is None:
            raise RuntimeError("Segmentation failed after all retries")
    else:
        print(f"  Segmenting {len(frames)} frames sequentially...")
        masks = []
        for i, frame in enumerate(frames):
            if (i + 1) % 10 == 0 or i == 0:
                print(f"    Frame {i+1}/{len(frames)}")
            masks.append(segmenter(frame, target_classes=targets))

    t1 = time.perf_counter()
    total_time = t1 - t0

    print(f"  Total segmentation time: {total_time:.3f} s")
    print(f"  Average per-frame: {total_time/len(frames):.4f} s ({len(frames)/total_time:.2f} fps)")

    # Apply SDF temporal smoothing to reduce jitter
    if not args.no_smooth and len(masks) > 2:
        print(f"  Applying SDF temporal smoothing (alpha={args.smooth_alpha}, patience={args.smooth_patience})...")
        t_smooth_start = time.perf_counter()
        masks = smooth_masks_sdf(
            masks,
            alpha=args.smooth_alpha,
            empty_thresh=10,
            patience=args.smooth_patience,
        )
        t_smooth_end = time.perf_counter()
        print(f"  Smoothing time: {t_smooth_end - t_smooth_start:.3f} s")

    # Calculate aggregate statistics
    total_roi_pixels = sum(m.sum() for m in masks)
    total_pixels = sum(m.size for m in masks)
    avg_coverage = (total_roi_pixels / total_pixels * 100) if total_pixels > 0 else 0.0
    print(f"  Average ROI coverage: {avg_coverage:.2f}%")

    # Save output
    output_path = Path(args.output)
    if args.save_frames:
        # Save individual mask frames
        output_dir = output_path.parent / f"{output_path.stem}_frames"
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"  Saving {len(masks)} mask frames to {output_dir}/")
        for i, mask in enumerate(masks):
            frame_path = output_dir / f"mask_{i:06d}.png"
            save_mask(mask, str(frame_path))
        print(f"  Saved {len(masks)} frames")
    else:
        # Save as video
        print(f"  Saving mask video to {args.output}")
        save_masks_as_video(masks, str(output_path), fps=fps)
        print(f"  Saved mask video")

    # Save visualization if requested
    if args.visualize:
        viz_dir = output_path.parent / f"{output_path.stem}_viz"
        viz_dir.mkdir(parents=True, exist_ok=True)
        print(f"  Creating visualization video...")

        viz_frames = [visualize_mask(frame, mask, alpha=0.5, color=(255, 0, 0))
                      for frame, mask in zip(frames, masks)]

        # Save as video (best-effort; failures only warn, they don't abort)
        try:
            import cv2
            viz_path = viz_dir / "overlay.mp4"
            # PIL .size is (width, height); cv2.VideoWriter wants (w, h) too.
            h, w = viz_frames[0].size[1], viz_frames[0].size[0]
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(str(viz_path), fourcc, fps, (w, h), isColor=True)

            for vf in viz_frames:
                frame_np = np.array(vf)
                frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
                writer.write(frame_bgr)

            writer.release()
            print(f"  Visualization saved: {viz_path}")
        except Exception as e:
            print(f"  Warning: Could not save visualization video: {e}")

    print("Done!")
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# Command-line interface: parse arguments, then dispatch to the image or
# video pipeline based on the input file's extension.
if __name__ == '__main__':
    import argparse
    from segmentation import get_available_methods

    # Get available segmentation methods dynamically
    available_methods = get_available_methods()

    parser = argparse.ArgumentParser(description='Segment objects in images or videos')
    parser.add_argument('--input', required=True, help='Input image or video path')
    parser.add_argument('--output', required=True, help='Output mask path (image/video) or directory')
    parser.add_argument('--method', default='yolo',
                        choices=available_methods,
                        help='Segmentation method (including fake_* methods for detection+tracking)')
    parser.add_argument('--classes', nargs='+', default=['car'],
                        help='Target classes to segment')
    parser.add_argument('--prompt', type=str, default=None,
                        help='Natural language prompt (use with --method sam3)')
    parser.add_argument('--conf-threshold', type=float, default=0.25,
                        help='Confidence threshold')
    parser.add_argument('--visualize', action='store_true',
                        help='Save visualization with overlay')
    parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'],
                        help='Device to run on')

    # Video-specific options
    parser.add_argument('--resize-height', type=int, default=None,
                        help='Resize video frames to this height (maintains aspect ratio)')
    parser.add_argument('--save-frames', action='store_true',
                        help='For videos: save masks as individual frames instead of video')

    # Temporal smoothing options
    parser.add_argument('--no-smooth', action='store_true',
                        help='Disable temporal smoothing (may cause jitter)')
    parser.add_argument('--smooth-alpha', type=float, default=0.5,
                        help='SDF smoothing factor (0.1=slow/viscous, 0.9=fast/reactive)')
    parser.add_argument('--smooth-patience', type=int, default=5,
                        help='Frames to tolerate dropouts before decay (0=immediate, 5=conservative, 15=aggressive)')

    args = parser.parse_args()

    # Determine if input is image or video
    if is_video(args.input):
        process_video(args)
    else:
        process_image(args)
|
segmentation/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ROI Segmentation Module
|
| 3 |
+
|
| 4 |
+
Provides abstract base class and concrete implementations for various
|
| 5 |
+
segmentation models (YOLO, SegFormer, Mask2Former, Mask R-CNN, SAM3, etc.) used in ROI-based compression.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .base import BaseSegmenter
|
| 9 |
+
from .segformer import SegFormerSegmenter
|
| 10 |
+
from .yolo import YOLOSegmenter
|
| 11 |
+
from .mask2former import Mask2FormerSegmenter
|
| 12 |
+
from .maskrcnn import MaskRCNNSegmenter
|
| 13 |
+
from .sam3 import SAM3Segmenter
|
| 14 |
+
from .fake import FakeSegmenter
|
| 15 |
+
from .factory import create_segmenter, get_available_methods, register_segmenter
|
| 16 |
+
from .utils import visualize_mask, save_mask, load_mask, calculate_roi_stats
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
'BaseSegmenter',
|
| 20 |
+
'SegFormerSegmenter',
|
| 21 |
+
'YOLOSegmenter',
|
| 22 |
+
'Mask2FormerSegmenter',
|
| 23 |
+
'MaskRCNNSegmenter',
|
| 24 |
+
'SAM3Segmenter',
|
| 25 |
+
'FakeSegmenter',
|
| 26 |
+
'create_segmenter',
|
| 27 |
+
'get_available_methods',
|
| 28 |
+
'register_segmenter',
|
| 29 |
+
'visualize_mask',
|
| 30 |
+
'save_mask',
|
| 31 |
+
'load_mask',
|
| 32 |
+
'calculate_roi_stats',
|
| 33 |
+
]
|
segmentation/base.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Abstract base class for segmentation models.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
from typing import List, Optional, Union, Dict, Any
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseSegmenter(ABC):
    """
    Abstract base class for all segmentation models.

    This class defines the common interface that all segmentation models
    must implement. Subclasses can handle different types of inputs:
    - Class-based segmentation (YOLO, SegFormer): List of class names
    - Natural language segmentation (SAM, CLIP-based): Text prompts
    - Point/box-based segmentation (SAM): Coordinates
    """

    # Whether the model supports true GPU-batched inference.
    # Subclasses should set this to True if segment_batch uses a
    # single forward pass rather than a loop.
    supports_batch: bool = False

    def __init__(self, device: str = 'cuda', **kwargs):
        """
        Initialize the segmenter.

        Args:
            device: Device to run inference on ('cuda' or 'cpu')
            **kwargs: Model-specific parameters
        """
        self.device = device
        self.model = None
        # Lazy-load flag: the heavy model weights are loaded on first use.
        self._is_loaded = False

    @abstractmethod
    def load_model(self):
        """
        Load the segmentation model.

        This method should load the model weights and prepare the model
        for inference. Called automatically before first use.
        """
        pass

    @abstractmethod
    def segment(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Create binary segmentation mask for ROI.

        Args:
            image: PIL Image to segment
            target_classes: List of target classes or text prompts
            **kwargs: Model-specific parameters (e.g., confidence threshold)

        Returns:
            Binary mask as numpy array (H, W) with values 0 or 1
            - 1: Region of Interest (ROI)
            - 0: Background
        """
        pass

    @abstractmethod
    def get_available_classes(self) -> Union[List[str], Dict[str, int]]:
        """
        Get list or mapping of classes this model can segment.

        Returns:
            List of class names or dict mapping class names to IDs
        """
        pass

    def validate_classes(self, target_classes: Optional[List[str]]) -> List[str]:
        """
        Validate and filter target classes against available classes.

        Matching is case-insensitive; the original casing of requested
        class names is preserved in the returned list.

        Args:
            target_classes: List of requested class names

        Returns:
            List of valid class names (defaults when none are valid)
        """
        if target_classes is None:
            return self.get_default_classes()

        available_classes = self.get_available_classes()
        if isinstance(available_classes, dict):
            available_classes = list(available_classes.keys())

        # Build the lowercase lookup once; the previous version rebuilt
        # the lowered list inside the loop (O(n*m) instead of O(n+m)).
        available_lower = {c.lower() for c in available_classes}

        valid_classes = []
        for cls in target_classes:
            if cls.lower() in available_lower:
                valid_classes.append(cls)
            else:
                print(f"Warning: '{cls}' not in {self.__class__.__name__} classes.")

        if not valid_classes:
            print(f"Warning: No valid classes found. Using defaults.")
            valid_classes = self.get_default_classes()

        return valid_classes

    def segment_batch(
        self,
        images: List[Image.Image],
        target_classes: Optional[List[str]] = None,
        **kwargs
    ) -> List[np.ndarray]:
        """
        Segment a batch of images.

        Default implementation processes images sequentially. Subclasses
        should override this with a true batched forward pass when the
        underlying model supports it.

        Args:
            images: List of PIL Images (ideally same resolution)
            target_classes: List of target classes or text prompts
            **kwargs: Model-specific parameters

        Returns:
            List of binary masks (H, W) – one per input image
        """
        self.ensure_loaded()
        return [self.segment(img, target_classes, **kwargs) for img in images]

    def get_default_classes(self) -> List[str]:
        """
        Get default classes to segment if none specified.

        Returns:
            List of default class names
        """
        return ['car']  # Default fallback

    def ensure_loaded(self):
        """Ensure model is loaded before use (loads at most once)."""
        if not self._is_loaded:
            self.load_model()
            self._is_loaded = True

    def __call__(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Convenience method to call segment().

        Args:
            image: PIL Image to segment
            target_classes: List of target classes or text prompts
            **kwargs: Model-specific parameters

        Returns:
            Binary mask as numpy array (H, W)
        """
        self.ensure_loaded()
        return self.segment(image, target_classes, **kwargs)
|
segmentation/factory.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Factory for creating segmentation models.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Dict, Type, Optional, List, Callable, Union
|
| 6 |
+
from .base import BaseSegmenter
|
| 7 |
+
from .segformer import SegFormerSegmenter
|
| 8 |
+
from .yolo import YOLOSegmenter
|
| 9 |
+
from .mask2former import Mask2FormerSegmenter
|
| 10 |
+
from .maskrcnn import MaskRCNNSegmenter
|
| 11 |
+
from .sam3 import SAM3Segmenter
|
| 12 |
+
from .fake import FakeSegmenter
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Registry of available segmentation methods
|
| 16 |
+
# Can be either a class or a factory function
|
| 17 |
+
SEGMENTER_REGISTRY: Dict[str, Union[Type[BaseSegmenter], Callable]] = {
|
| 18 |
+
'segformer': SegFormerSegmenter,
|
| 19 |
+
'yolo': YOLOSegmenter,
|
| 20 |
+
'mask2former': Mask2FormerSegmenter,
|
| 21 |
+
'maskrcnn': MaskRCNNSegmenter,
|
| 22 |
+
'sam3': SAM3Segmenter,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _register_fake_methods():
    """Register fake segmentation methods (detection + tracking → bbox masks)."""
    # (registry name, underlying detector, default tracker)
    fake_configs = [
        ('fake_yolo', 'yolo', 'bytetrack'),
        ('fake_yolo_botsort', 'yolo', 'botsort'),
        ('fake_detr', 'detr', 'bytetrack'),
        ('fake_deformable_detr', 'deformable_detr', 'bytetrack'),
        ('fake_fasterrcnn', 'fasterrcnn', 'bytetrack'),
        ('fake_retinanet', 'retinanet', 'bytetrack'),
        ('fake_fcos', 'fcos', 'bytetrack'),
        ('fake_grounding_dino', 'grounding_dino', 'bytetrack'),
    ]

    def _build_factory(det, default_tracker):
        # Binding det/default_tracker as parameters (rather than closing over
        # the loop variables) avoids the late-binding closure pitfall.
        def factory(**kwargs):
            # Caller-supplied tracker_type wins over the config default.
            kwargs.setdefault('tracker_type', default_tracker)
            return FakeSegmenter(detector_name=det, **kwargs)
        return factory

    for reg_name, detector, tracker in fake_configs:
        SEGMENTER_REGISTRY[reg_name] = _build_factory(detector, tracker)


# Register fake methods on import
_register_fake_methods()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def register_segmenter(name: str, segmenter_class: Type[BaseSegmenter]):
    """
    Register a new segmentation method under *name* (stored lowercased).

    Args:
        name: Method name (e.g., 'sam', 'drone_detector')
        segmenter_class: Segmenter class that extends BaseSegmenter

    Raises:
        ValueError: If the class does not extend BaseSegmenter.
    """
    if issubclass(segmenter_class, BaseSegmenter):
        SEGMENTER_REGISTRY[name.lower()] = segmenter_class
    else:
        raise ValueError(f"{segmenter_class} must extend BaseSegmenter")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def create_segmenter(
    method: str,
    device: str = 'cuda',
    **kwargs
) -> BaseSegmenter:
    """
    Factory function to create a segmentation model.

    Args:
        method: Segmentation method name ('segformer', 'yolo', 'fake_yolo', etc.)
        device: Device to run on ('cuda' or 'cpu')
        **kwargs: Method-specific parameters

    Returns:
        Instance of the requested segmenter

    Raises:
        ValueError: If method is not recognized

    Example:
        >>> segmenter = create_segmenter('yolo', device='cuda', conf_threshold=0.3)
        >>> mask = segmenter(image, target_classes=['car', 'person'])

        >>> # Use detection-based fake segmentation with tracking
        >>> fake_seg = create_segmenter('fake_yolo', device='cuda')
        >>> mask = fake_seg(image, target_classes=['person'])
    """
    method_lower = method.lower()

    if method_lower not in SEGMENTER_REGISTRY:
        available = ', '.join(sorted(SEGMENTER_REGISTRY.keys()))
        raise ValueError(
            f"Unknown segmentation method: '{method}'. "
            f"Available methods: {available}"
        )

    # The registry stores either a class or a factory function; both are
    # invoked the same way, so no type dispatch is needed. (The previous
    # version branched on isinstance(factory, type) but both branches
    # executed the identical call.)
    factory = SEGMENTER_REGISTRY[method_lower]
    return factory(device=device, **kwargs)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_available_methods() -> List[str]:
    """
    Get list of available segmentation methods.

    Returns:
        List of method names, in registration order.
    """
    return [method_name for method_name in SEGMENTER_REGISTRY]
|
segmentation/fake.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Detection-based 'fake' segmentation using bounding boxes + object tracking.
|
| 2 |
+
|
| 3 |
+
Creates rectangular masks from detection bounding boxes and maintains object
|
| 4 |
+
identity across frames using tracking (ByteTrack, BoTSORT, or SimpleTracker).
|
| 5 |
+
|
| 6 |
+
Now supports batch detection for efficiency: detects all frames in batches,
|
| 7 |
+
then runs tracking sequentially on the results.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from typing import List, Optional, Dict, Any
|
| 15 |
+
|
| 16 |
+
from .base import BaseSegmenter
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class FakeSegmenter(BaseSegmenter):
    """
    Detection-based segmentation that creates rectangular masks from bboxes.

    Uses object tracking to maintain consistent masks across video frames:
    - ByteTrack (default for YOLO)
    - BoTSORT (available for YOLO)
    - SimpleTracker (fallback for non-YOLO detectors)

    This is useful for fast ROI extraction when pixel-perfect masks aren't needed.
    Masks are float32 arrays of shape (H, W) with 1.0 inside detected boxes.
    """

    def __init__(
        self,
        device: str = 'cuda',
        detector_name: str = 'yolo',
        tracker_type: str = 'bytetrack',
        conf_threshold: float = 0.25,
        model_path: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize fake segmenter.

        Args:
            device: Device to run on ('cuda' or 'cpu')
            detector_name: Detection method ('yolo', 'detr', 'faster_rcnn', etc.)
            tracker_type: Tracker type ('bytetrack', 'botsort', 'simple')
            conf_threshold: Confidence threshold for detections
            model_path: Optional path to detector weights
            **kwargs: Additional parameters forwarded to BaseSegmenter
        """
        super().__init__(device=device, **kwargs)
        # Names are normalized to lowercase so comparisons in load_model()
        # against 'bytetrack'/'botsort' behave regardless of caller casing.
        self.detector_name = detector_name.lower()
        self.tracker_type = tracker_type.lower()
        self.conf_threshold = conf_threshold
        self.model_path = model_path

        # Detector (set in load_model())
        self.detector = None

        # State for tracking
        self._tracker = None
        self._frame_count = 0  # incremented once per processed frame

        # Batch support: segment_batch() performs true batched detection.
        self.supports_batch = True

    def load_model(self):
        """Load detector model and initialize tracker.

        Imported lazily so the detection package is only required when a
        FakeSegmenter is actually used.
        """
        from detection import create_detector

        # Create detector; model_path is only forwarded when provided so
        # detectors without a model_path parameter still work.
        if self.model_path:
            self.detector = create_detector(
                self.detector_name,
                device=self.device,
                model_path=self.model_path,
            )
        else:
            self.detector = create_detector(
                self.detector_name,
                device=self.device,
            )

        # Ensure detector model is loaded
        # NOTE(review): reaches into the detector's private `_is_loaded`
        # attribute — assumes every detector exposes it; confirm against
        # the detection base class.
        if not self.detector._is_loaded:
            self.detector.load_model()

        # Initialize tracker (now all trackers work with any detector!)
        # Tracker hyperparameters are hard-coded here; they mirror common
        # ByteTrack defaults (30 fps video, 30-frame track buffer).
        if self.tracker_type == 'bytetrack':
            from detection.bytetrack import ByteTracker
            self._tracker = ByteTracker(
                track_thresh=0.5,
                match_thresh=0.8,
                track_buffer=30,
                frame_rate=30,
            )
        elif self.tracker_type == 'botsort':
            from detection.bytetrack import BoTSORT
            self._tracker = BoTSORT(
                track_thresh=0.5,
                match_thresh=0.8,
                track_buffer=30,
                frame_rate=30,
            )
        else:  # simple
            from detection.tracker import SimpleTracker
            self._tracker = SimpleTracker(
                iou_threshold=0.3,
                max_age=30,
                min_hits=1,
                label_match=True,
            )

        print(f"Loaded FakeSegmenter: {self.detector_name} + {self.tracker_type} tracking")

    def reset_tracking(self):
        """Reset tracker state (call between videos)."""
        self._frame_count = 0
        if self._tracker is not None:
            self._tracker.reset()

    def _create_bbox_mask(
        self,
        width: int,
        height: int,
        detections: List[Dict],
    ) -> np.ndarray:
        """Create binary mask from bounding boxes.

        Args:
            width: Image width
            height: Image height
            detections: List of detections with 'bbox_xyxy' key
                (falls back to a 'bbox' key when 'bbox_xyxy' is absent)

        Returns:
            Binary mask (H, W) float32 with 1.0 where bboxes are
        """
        mask = np.zeros((height, width), dtype=np.float32)

        for det in detections:
            bbox = det.get('bbox_xyxy', det.get('bbox'))
            if bbox is None:
                continue

            # Clamp coordinates into the image so out-of-frame boxes
            # cannot index outside the mask.
            x1, y1, x2, y2 = bbox
            x1 = int(max(0, min(x1, width - 1)))
            y1 = int(max(0, min(y1, height - 1)))
            x2 = int(max(0, min(x2, width - 1)))
            y2 = int(max(0, min(y2, height - 1)))

            # Skip degenerate (zero-area) boxes.
            if x2 > x1 and y2 > y1:
                mask[y1:y2, x1:x2] = 1.0

        return mask

    def segment(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        conf_threshold: Optional[float] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Create segmentation mask from detections.

        NOTE(review): unlike segment_batch(), this method does not call
        ensure_loaded() itself — presumably BaseSegmenter.__call__ does it
        before dispatching here; confirm, otherwise self.detector is None.

        Args:
            image: Input PIL Image
            target_classes: Classes to detect (None = all classes)
            conf_threshold: Override default confidence threshold
            **kwargs: Additional parameters (unused)

        Returns:
            Binary mask (H, W) as float32
        """
        if conf_threshold is None:
            conf_threshold = self.conf_threshold

        width, height = image.size

        # Run detection
        # Pass classes for open-vocabulary detectors (Grounding DINO, YOLO-World)
        detect_kwargs = {"conf_threshold": conf_threshold}
        if target_classes:
            detect_kwargs["classes"] = target_classes
        detections = self.detector.detect(image, **detect_kwargs)

        # Convert detector result objects to plain dicts for the tracker.
        det_dicts = [
            {
                'label': d.label,
                'score': d.score,
                'bbox_xyxy': d.bbox_xyxy,
            }
            for d in detections
        ]

        # Update tracker
        if self._tracker is not None:
            tracks = self._tracker.update(det_dicts)
            # Convert tracks to detection format
            detections = track_dicts_to_detections(tracks)
        else:
            detections = det_dicts

        # Filter by target classes if specified.
        # Substring match ('car' also matches 'racecar') — intentional
        # loose matching across detector vocabularies.
        if target_classes:
            target_lower = [tc.lower() for tc in target_classes]
            detections = [
                d for d in detections
                if any(tc in d['label'].lower() for tc in target_lower)
            ]

        # Create mask from bboxes
        mask = self._create_bbox_mask(width, height, detections)
        self._frame_count += 1

        return mask

    def _detect_with_yolo_tracking(
        self,
        image: Image.Image,
        conf_threshold: float,
    ) -> List[Dict]:
        """Run YOLO detection with built-in tracking.

        NOTE(review): not called by segment()/segment_batch() in this class;
        appears to be an alternative path relying on Ultralytics internals
        (self.detector.model.track, self.detector._names) — only valid when
        the detector is a YOLO wrapper.

        Returns:
            List of detections with track IDs
        """
        img = np.asarray(image.convert('RGB'))

        # Determine device argument (Ultralytics expects 0 for first GPU).
        device_arg = 0 if self.detector.device.startswith('cuda') else 'cpu'

        # Use YOLO's .track() method instead of .predict()
        results = self.detector.model.track(
            source=img,
            conf=conf_threshold,
            device=device_arg,
            verbose=False,
            tracker=f'{self.tracker_type}.yaml',  # bytetrack.yaml or botsort.yaml
            persist=True,  # Persist tracks between frames
        )

        if not results:
            return []

        r0 = results[0]
        if not hasattr(r0, 'boxes') or r0.boxes is None:
            return []

        boxes = r0.boxes
        xyxy = boxes.xyxy.detach().cpu().numpy()
        conf = boxes.conf.detach().cpu().numpy()
        cls = boxes.cls.detach().cpu().numpy().astype(int)

        # Get track IDs if available (None until tracking has associated boxes)
        track_ids = None
        if hasattr(boxes, 'id') and boxes.id is not None:
            track_ids = boxes.id.detach().cpu().numpy().astype(int)

        detections = []
        for i, (bbox, score, class_id) in enumerate(zip(xyxy, conf, cls)):
            label = self.detector._names.get(int(class_id), str(int(class_id)))
            det = {
                'label': label,
                'score': float(score),
                'bbox_xyxy': [float(x) for x in bbox.tolist()],
            }
            if track_ids is not None:
                det['track_id'] = int(track_ids[i])
            detections.append(det)

        return detections

    def segment_batch(
        self,
        images: List[Image.Image],
        target_classes: Optional[List[str]] = None,
        conf_threshold: Optional[float] = None,
        **kwargs
    ) -> List[np.ndarray]:
        """
        Batch segmentation for video processing using offline batch detection + sequential tracking.

        This is much more efficient:
        1. Batch detect all frames at once (or in batches if memory limited)
        2. Run tracker sequentially on detection results
        3. Create masks from tracked detections

        Args:
            images: List of PIL Images
            target_classes: Classes to detect
            conf_threshold: Confidence threshold
            **kwargs: Additional parameters (unused)

        Returns:
            List of binary masks, one per input image
        """
        # Ensure model is loaded
        self.ensure_loaded()

        if conf_threshold is None:
            conf_threshold = self.conf_threshold

        if not images:
            return []

        # Step 1: Batch detect all frames (TRUE batch inference for speed)
        all_detections = []
        # Prepare detection kwargs (classes for open-vocabulary detectors)
        detect_kwargs = {"conf_threshold": conf_threshold}
        if target_classes:
            detect_kwargs["classes"] = target_classes

        if hasattr(self.detector, 'detect_batch'):
            # Use batch detection for efficiency (GPU parallelization)
            batch_dets = self.detector.detect_batch(images, **detect_kwargs)
            for dets in batch_dets:
                det_dicts = [
                    {
                        'label': d.label,
                        'score': d.score,
                        'bbox_xyxy': d.bbox_xyxy,
                    }
                    for d in dets
                ]
                all_detections.append(det_dicts)
        else:
            # Fallback to frame-by-frame for detectors without batch support
            for image in images:
                dets = self.detector.detect(image, **detect_kwargs)
                det_dicts = [
                    {
                        'label': d.label,
                        'score': d.score,
                        'bbox_xyxy': d.bbox_xyxy,
                    }
                    for d in dets
                ]
                all_detections.append(det_dicts)

        # Step 2: Run tracker sequentially on all detections.
        # Tracking is inherently order-dependent, so this loop stays serial.
        tracked_detections = []
        if self._tracker is not None:
            for frame_dets in all_detections:
                tracks = self._tracker.update(frame_dets)
                # Convert tracks to detection format
                frame_tracked = track_dicts_to_detections(tracks)
                tracked_detections.append(frame_tracked)
        else:
            tracked_detections = all_detections

        # Step 3: Filter by target classes and create masks
        masks = []
        for i, (image, detections) in enumerate(zip(images, tracked_detections)):
            width, height = image.size

            # Filter by target classes if specified (same loose substring
            # matching as segment()).
            if target_classes:
                target_lower = [tc.lower() for tc in target_classes]
                detections = [
                    d for d in detections
                    if any(tc in d['label'].lower() for tc in target_lower)
                ]

            # Create mask from bboxes
            mask = self._create_bbox_mask(width, height, detections)
            masks.append(mask)
            self._frame_count += 1

        return masks

    def get_available_classes(self) -> List[str]:
        """Get list of classes the detector can detect.

        Normalizes the detector's reported classes (dict of name->id or a
        plain list) into a list of names; returns [] for unknown shapes.
        """
        classes = self.detector.get_available_classes()

        if isinstance(classes, dict):
            return sorted(classes.keys())
        elif isinstance(classes, list):
            return classes
        else:
            return []

    def get_default_classes(self) -> List[str]:
        """Get default classes for common use cases."""
        # Common COCO classes
        return ['person', 'car', 'truck', 'bus', 'bicycle', 'motorcycle']
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def track_dicts_to_detections(tracks: List[Dict]) -> List[Dict]:
    """Convert tracker output dicts into plain detection dicts.

    Args:
        tracks: Track dictionaries produced by a tracker's ``update()``.
            May use either current keys ('score', 'bbox_xyxy') or the
            last-known variants ('last_score', 'last_bbox').

    Returns:
        Detection dicts with 'label', 'score', 'bbox_xyxy' and, when the
        track carries one, 'track_id'.
    """
    converted: List[Dict] = []
    for trk in tracks:
        entry = {
            'label': trk.get('label', ''),
            # Prefer the live score/bbox; fall back to last-known values.
            'score': trk.get('score', trk.get('last_score', 0.0)),
            'bbox_xyxy': trk.get('bbox_xyxy', trk.get('last_bbox', [])),
        }
        # Preserve identity information when the tracker provides it.
        if 'track_id' in trk:
            entry['track_id'] = trk['track_id']
        converted.append(entry)
    return converted
|
segmentation/mask2former.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mask2Former segmentation using Swin Transformer backbone.
|
| 3 |
+
Supports both COCO and ADE20K pre-trained models.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from typing import List, Optional, Dict
|
| 11 |
+
from .base import BaseSegmenter
|
| 12 |
+
from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Mask2FormerSegmenter(BaseSegmenter):
|
| 16 |
+
"""
|
| 17 |
+
Mask2Former segmentation with Swin Transformer backbone.
|
| 18 |
+
|
| 19 |
+
Supports both COCO (133 classes) and ADE20K (150 classes) datasets.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
# COCO panoptic categories (133 classes including stuff)
|
| 23 |
+
COCO_CLASSES = {
|
| 24 |
+
'person': 0, 'bicycle': 1, 'car': 2, 'motorcycle': 3, 'airplane': 4,
|
| 25 |
+
'bus': 5, 'train': 6, 'truck': 7, 'boat': 8, 'traffic light': 9,
|
| 26 |
+
'fire hydrant': 10, 'stop sign': 11, 'parking meter': 12, 'bench': 13,
|
| 27 |
+
'bird': 14, 'cat': 15, 'dog': 16, 'horse': 17, 'sheep': 18, 'cow': 19,
|
| 28 |
+
'elephant': 20, 'bear': 21, 'zebra': 22, 'giraffe': 23, 'backpack': 24,
|
| 29 |
+
'umbrella': 25, 'handbag': 26, 'tie': 27, 'suitcase': 28, 'frisbee': 29,
|
| 30 |
+
'skis': 30, 'snowboard': 31, 'sports ball': 32, 'kite': 33, 'baseball bat': 34,
|
| 31 |
+
'baseball glove': 35, 'skateboard': 36, 'surfboard': 37, 'tennis racket': 38,
|
| 32 |
+
'bottle': 39, 'wine glass': 40, 'cup': 41, 'fork': 42, 'knife': 43,
|
| 33 |
+
'spoon': 44, 'bowl': 45, 'banana': 46, 'apple': 47, 'sandwich': 48,
|
| 34 |
+
'orange': 49, 'broccoli': 50, 'carrot': 51, 'hot dog': 52, 'pizza': 53,
|
| 35 |
+
'donut': 54, 'cake': 55, 'chair': 56, 'couch': 57, 'potted plant': 58,
|
| 36 |
+
'bed': 59, 'dining table': 60, 'toilet': 61, 'tv': 62, 'laptop': 63,
|
| 37 |
+
'mouse': 64, 'remote': 65, 'keyboard': 66, 'cell phone': 67, 'microwave': 68,
|
| 38 |
+
'oven': 69, 'toaster': 70, 'sink': 71, 'refrigerator': 72, 'book': 73,
|
| 39 |
+
'clock': 74, 'vase': 75, 'scissors': 76, 'teddy bear': 77, 'hair drier': 78,
|
| 40 |
+
'toothbrush': 79, 'banner': 80, 'blanket': 81, 'bridge': 82, 'cardboard': 83,
|
| 41 |
+
'counter': 84, 'curtain': 85, 'door-stuff': 86, 'floor-wood': 87,
|
| 42 |
+
'flower': 88, 'fruit': 89, 'gravel': 90, 'house': 91, 'light': 92,
|
| 43 |
+
'mirror-stuff': 93, 'net': 94, 'pillow': 95, 'platform': 96, 'playingfield': 97,
|
| 44 |
+
'railroad': 98, 'river': 99, 'road': 100, 'roof': 101, 'sand': 102,
|
| 45 |
+
'sea': 103, 'shelf': 104, 'snow': 105, 'stairs': 106, 'tent': 107,
|
| 46 |
+
'towel': 108, 'wall-brick': 109, 'wall-stone': 110, 'wall-tile': 111,
|
| 47 |
+
'wall-wood': 112, 'water': 113, 'window-blind': 114, 'window': 115,
|
| 48 |
+
'tree': 116, 'fence': 117, 'ceiling': 118, 'sky': 119, 'cabinet': 120,
|
| 49 |
+
'table': 121, 'floor': 122, 'pavement': 123, 'mountain': 124, 'grass': 125,
|
| 50 |
+
'dirt': 126, 'paper': 127, 'food': 128, 'building': 129, 'rock': 130,
|
| 51 |
+
'wall': 131, 'rug': 132
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
# Common ADE20K classes (subset of 150)
|
| 55 |
+
ADE20K_CLASSES = {
|
| 56 |
+
'wall': 0, 'building': 1, 'sky': 2, 'floor': 3, 'tree': 4,
|
| 57 |
+
'ceiling': 5, 'road': 6, 'bed': 7, 'windowpane': 8, 'grass': 9,
|
| 58 |
+
'cabinet': 10, 'sidewalk': 11, 'person': 12, 'earth': 13, 'door': 14,
|
| 59 |
+
'table': 15, 'mountain': 16, 'plant': 17, 'curtain': 18, 'chair': 19,
|
| 60 |
+
'car': 20, 'water': 21, 'painting': 22, 'sofa': 23, 'shelf': 24,
|
| 61 |
+
'house': 25, 'sea': 26, 'mirror': 27, 'rug': 28, 'field': 29,
|
| 62 |
+
'armchair': 30, 'seat': 31, 'fence': 32, 'desk': 33, 'rock': 34,
|
| 63 |
+
'wardrobe': 35, 'lamp': 36, 'bathtub': 37, 'railing': 38, 'cushion': 39,
|
| 64 |
+
'base': 40, 'box': 41, 'column': 42, 'signboard': 43, 'chest of drawers': 44,
|
| 65 |
+
'counter': 45, 'sand': 46, 'sink': 47, 'skyscraper': 48, 'fireplace': 49,
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
def __init__(
|
| 69 |
+
self,
|
| 70 |
+
device: str = 'cuda',
|
| 71 |
+
conf_threshold: float = 0.5,
|
| 72 |
+
model_type: str = 'coco', # 'coco' or 'ade20k'
|
| 73 |
+
**kwargs
|
| 74 |
+
):
|
| 75 |
+
"""
|
| 76 |
+
Initialize Mask2Former segmenter.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
device: Device to run model on
|
| 80 |
+
conf_threshold: Confidence threshold for predictions
|
| 81 |
+
model_type: Which pre-trained model to use ('coco' or 'ade20k')
|
| 82 |
+
**kwargs: Additional arguments
|
| 83 |
+
"""
|
| 84 |
+
super().__init__(device, **kwargs)
|
| 85 |
+
self.conf_threshold = conf_threshold
|
| 86 |
+
self.model_type = model_type.lower()
|
| 87 |
+
|
| 88 |
+
if self.model_type not in ['coco', 'ade20k']:
|
| 89 |
+
raise ValueError(f"model_type must be 'coco' or 'ade20k', got {self.model_type}")
|
| 90 |
+
|
| 91 |
+
self.class_map = self.COCO_CLASSES if self.model_type == 'coco' else self.ADE20K_CLASSES
|
| 92 |
+
self.model = None
|
| 93 |
+
self.processor = None
|
| 94 |
+
|
| 95 |
+
def load_model(self):
|
| 96 |
+
"""Load Mask2Former model and processor from HuggingFace."""
|
| 97 |
+
try:
|
| 98 |
+
from transformers import Mask2FormerForUniversalSegmentation, AutoImageProcessor
|
| 99 |
+
except ImportError:
|
| 100 |
+
raise ImportError(
|
| 101 |
+
"Mask2Former requires transformers. Install with: pip install transformers"
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
if self.model_type == 'coco':
|
| 105 |
+
model_name = "facebook/mask2former-swin-large-coco-panoptic"
|
| 106 |
+
else: # ade20k
|
| 107 |
+
model_name = "facebook/mask2former-swin-large-ade-semantic"
|
| 108 |
+
|
| 109 |
+
print(f"Loading Mask2Former ({self.model_type}) model...")
|
| 110 |
+
ensure_default_checkpoint_dirs()
|
| 111 |
+
cache_dir = str(hf_cache_dir())
|
| 112 |
+
self.processor = AutoImageProcessor.from_pretrained(model_name, cache_dir=cache_dir)
|
| 113 |
+
self.model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name, cache_dir=cache_dir)
|
| 114 |
+
self.model = self.model.to(self.device)
|
| 115 |
+
self.model.eval()
|
| 116 |
+
print(f"✓ Mask2Former loaded: {model_name}")
|
| 117 |
+
|
| 118 |
+
def segment(
|
| 119 |
+
self,
|
| 120 |
+
image: Image.Image,
|
| 121 |
+
target_classes: Optional[List[str]] = None,
|
| 122 |
+
**kwargs
|
| 123 |
+
) -> np.ndarray:
|
| 124 |
+
"""
|
| 125 |
+
Segment image using Mask2Former.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
image: PIL Image
|
| 129 |
+
target_classes: List of class names to segment (None for all)
|
| 130 |
+
**kwargs: Additional arguments (unused)
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
Binary mask as numpy array [H, W] with 1 for ROI, 0 for background
|
| 134 |
+
"""
|
| 135 |
+
if self.model is None:
|
| 136 |
+
self.load_model()
|
| 137 |
+
|
| 138 |
+
# Prepare image
|
| 139 |
+
inputs = self.processor(images=image, return_tensors="pt")
|
| 140 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 141 |
+
|
| 142 |
+
# Run inference
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
outputs = self.model(**inputs)
|
| 145 |
+
|
| 146 |
+
# Post-process
|
| 147 |
+
if self.model_type == 'coco':
|
| 148 |
+
# Panoptic segmentation
|
| 149 |
+
result = self.processor.post_process_panoptic_segmentation(
|
| 150 |
+
outputs,
|
| 151 |
+
target_sizes=[image.size[::-1]]
|
| 152 |
+
)[0]
|
| 153 |
+
segmentation = result['segmentation'].cpu().numpy()
|
| 154 |
+
segments_info = result['segments_info']
|
| 155 |
+
else:
|
| 156 |
+
# Semantic segmentation
|
| 157 |
+
result = self.processor.post_process_semantic_segmentation(
|
| 158 |
+
outputs,
|
| 159 |
+
target_sizes=[image.size[::-1]]
|
| 160 |
+
)[0]
|
| 161 |
+
segmentation = result.cpu().numpy()
|
| 162 |
+
segments_info = None
|
| 163 |
+
|
| 164 |
+
# Create binary mask
|
| 165 |
+
mask = np.zeros(segmentation.shape, dtype=np.uint8)
|
| 166 |
+
|
| 167 |
+
if target_classes is None:
|
| 168 |
+
# Return all detected objects
|
| 169 |
+
mask = (segmentation > 0).astype(np.uint8)
|
| 170 |
+
else:
|
| 171 |
+
# Filter by target classes
|
| 172 |
+
target_ids = []
|
| 173 |
+
for class_name in target_classes:
|
| 174 |
+
class_lower = class_name.lower()
|
| 175 |
+
if class_lower in self.class_map:
|
| 176 |
+
target_ids.append(self.class_map[class_lower])
|
| 177 |
+
else:
|
| 178 |
+
# Try fuzzy matching
|
| 179 |
+
for key, val in self.class_map.items():
|
| 180 |
+
if class_lower in key or key in class_lower:
|
| 181 |
+
target_ids.append(val)
|
| 182 |
+
break
|
| 183 |
+
|
| 184 |
+
if self.model_type == 'coco' and segments_info:
|
| 185 |
+
# Use segment info for panoptic
|
| 186 |
+
for segment in segments_info:
|
| 187 |
+
if segment['label_id'] in target_ids:
|
| 188 |
+
mask[segmentation == segment['id']] = 1
|
| 189 |
+
else:
|
| 190 |
+
# Use class IDs directly for semantic
|
| 191 |
+
for target_id in target_ids:
|
| 192 |
+
mask[segmentation == target_id] = 1
|
| 193 |
+
|
| 194 |
+
return mask
|
| 195 |
+
|
| 196 |
+
def get_available_classes(self) -> List[str]:
|
| 197 |
+
"""
|
| 198 |
+
Get list of available class names.
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
List of class names supported by the model
|
| 202 |
+
"""
|
| 203 |
+
return sorted(self.class_map.keys())
|
| 204 |
+
|
| 205 |
+
# Mask2Former can be batched through the HF processor for semantic mode.
|
| 206 |
+
# Panoptic post-processing is per-image, but the forward pass is batched.
|
| 207 |
+
supports_batch: bool = True
|
| 208 |
+
|
| 209 |
+
def segment_batch(
|
| 210 |
+
self,
|
| 211 |
+
images: List[Image.Image],
|
| 212 |
+
target_classes: Optional[List[str]] = None,
|
| 213 |
+
**kwargs,
|
| 214 |
+
) -> List[np.ndarray]:
|
| 215 |
+
"""Segment a batch of images via Mask2Former.
|
| 216 |
+
|
| 217 |
+
The HuggingFace processor accepts a list of images. The forward
|
| 218 |
+
pass runs on the full batch; post-processing is per-image.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
images: List of PIL Images
|
| 222 |
+
target_classes: Class names to include in ROI
|
| 223 |
+
**kwargs: unused
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
List of binary masks (H, W) float32
|
| 227 |
+
"""
|
| 228 |
+
if not images:
|
| 229 |
+
return []
|
| 230 |
+
|
| 231 |
+
if self.model is None:
|
| 232 |
+
self.load_model()
|
| 233 |
+
|
| 234 |
+
# Resolve target class IDs
|
| 235 |
+
target_ids = []
|
| 236 |
+
if target_classes:
|
| 237 |
+
for cn in target_classes:
|
| 238 |
+
cl = cn.lower()
|
| 239 |
+
if cl in self.class_map:
|
| 240 |
+
target_ids.append(self.class_map[cl])
|
| 241 |
+
else:
|
| 242 |
+
for key, val in self.class_map.items():
|
| 243 |
+
if cl in key or key in cl:
|
| 244 |
+
target_ids.append(val)
|
| 245 |
+
break
|
| 246 |
+
|
| 247 |
+
# Batch preprocess
|
| 248 |
+
inputs = self.processor(images=images, return_tensors="pt", padding=True)
|
| 249 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 250 |
+
|
| 251 |
+
with torch.no_grad():
|
| 252 |
+
outputs = self.model(**inputs)
|
| 253 |
+
|
| 254 |
+
# Collect all target sizes for batch post-processing
|
| 255 |
+
target_sizes = [img.size[::-1] for img in images] # [(H1, W1), (H2, W2), ...]
|
| 256 |
+
|
| 257 |
+
masks: List[np.ndarray] = []
|
| 258 |
+
|
| 259 |
+
if self.model_type == 'coco':
|
| 260 |
+
# Post-process entire batch at once for panoptic segmentation
|
| 261 |
+
results = self.processor.post_process_panoptic_segmentation(
|
| 262 |
+
outputs,
|
| 263 |
+
target_sizes=target_sizes,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
for i, (result, img) in enumerate(zip(results, images)):
|
| 267 |
+
segmentation = result['segmentation'].cpu().numpy()
|
| 268 |
+
segments_info = result['segments_info']
|
| 269 |
+
|
| 270 |
+
mask = np.zeros(segmentation.shape, dtype=np.float32)
|
| 271 |
+
if target_classes is None:
|
| 272 |
+
mask = (segmentation > 0).astype(np.float32)
|
| 273 |
+
else:
|
| 274 |
+
for seg in segments_info:
|
| 275 |
+
if seg['label_id'] in target_ids:
|
| 276 |
+
mask[segmentation == seg['id']] = 1.0
|
| 277 |
+
|
| 278 |
+
masks.append(mask)
|
| 279 |
+
else:
|
| 280 |
+
# Post-process entire batch at once for semantic segmentation
|
| 281 |
+
results = self.processor.post_process_semantic_segmentation(
|
| 282 |
+
outputs,
|
| 283 |
+
target_sizes=target_sizes,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
for i, (result, img) in enumerate(zip(results, images)):
|
| 287 |
+
segmentation = result.cpu().numpy()
|
| 288 |
+
mask = np.zeros(segmentation.shape, dtype=np.float32)
|
| 289 |
+
if target_classes is None:
|
| 290 |
+
mask = (segmentation > 0).astype(np.float32)
|
| 291 |
+
else:
|
| 292 |
+
for tid in target_ids:
|
| 293 |
+
mask[segmentation == tid] = 1.0
|
| 294 |
+
|
| 295 |
+
masks.append(mask)
|
| 296 |
+
|
| 297 |
+
return masks
|
| 298 |
+
|
| 299 |
+
def get_class_info(self) -> Dict[str, int]:
    """Return a summary of the model's class vocabulary.

    Returns:
        Dictionary with the model type, the number of known classes, and a
        copy of the name -> id mapping (copied so callers cannot mutate
        the segmenter's internal state).
    """
    class_lookup = self.class_map.copy()
    info = {
        'model_type': self.model_type,
        'num_classes': len(class_lookup),
        'classes': class_lookup,
    }
    return info
|
segmentation/maskrcnn.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mask R-CNN segmentation using torchvision pre-trained models.
|
| 3 |
+
Supports COCO-trained instance segmentation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from model_cache import ensure_default_checkpoint_dirs
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
import cv2
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from typing import List, Optional, Dict
|
| 13 |
+
from .base import BaseSegmenter
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Ensure torchvision/torch hub downloads land under `checkpoints/` by default.
|
| 17 |
+
ensure_default_checkpoint_dirs()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MaskRCNNSegmenter(BaseSegmenter):
    """
    Mask R-CNN instance segmentation from torchvision.

    Uses a pre-trained ResNet50-FPN (V2 weights) or MobileNetV3-Large-FPN
    backbone trained on COCO (80 usable classes). Per-detection masks above
    the confidence threshold are merged into one binary ROI mask.
    """

    # COCO class names as indexed by torchvision's detection heads.
    # 'N/A' marks label ids that the COCO dataset never uses; id 0 is background.
    COCO_CLASSES = [
        '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
        'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
        'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
        'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
        'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
        'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
        'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
        'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
        'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]

    # torchvision Mask R-CNN natively accepts a list of tensors.
    supports_batch: bool = True

    def __init__(
        self,
        device: str = 'cuda',
        conf_threshold: float = 0.5,
        backbone: str = 'resnet50',  # 'resnet50' or 'mobilenet'
        **kwargs
    ):
        """
        Initialize Mask R-CNN segmenter.

        Args:
            device: Device to run model on
            conf_threshold: Confidence threshold for detections
            backbone: Model backbone ('resnet50' or 'mobilenet')
            **kwargs: Additional arguments forwarded to BaseSegmenter

        Raises:
            ValueError: If ``backbone`` is not one of the supported choices.
        """
        super().__init__(device, **kwargs)
        self.conf_threshold = conf_threshold
        self.backbone = backbone.lower()

        if self.backbone not in ['resnet50', 'mobilenet']:
            raise ValueError(f"backbone must be 'resnet50' or 'mobilenet', got {self.backbone}")

        # Created lazily in load_model() so constructing the segmenter is cheap.
        self.model = None

    @staticmethod
    def _to_rgb_array(image: Image.Image) -> np.ndarray:
        """Convert a PIL image to an RGB uint8 numpy array (handles grayscale/RGBA)."""
        arr = np.array(image)
        if len(arr.shape) == 2:  # Grayscale
            arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
        elif arr.shape[2] == 4:  # RGBA
            arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
        return arr

    def _resolve_target_indices(
        self,
        target_classes: Optional[List[str]],
    ) -> Optional[List[int]]:
        """Map class names to COCO label ids.

        Returns None when ``target_classes`` is None (meaning "keep all
        detections"). Matching is case-insensitive and accepts substrings,
        e.g. 'light' matches 'traffic light'.
        """
        if target_classes is None:
            return None
        indices: List[int] = []
        for class_name in target_classes:
            class_lower = class_name.lower()
            for idx, coco_class in enumerate(self.COCO_CLASSES):
                if coco_class.lower() == class_lower or class_lower in coco_class.lower():
                    indices.append(idx)
        return indices

    def load_model(self):
        """Load Mask R-CNN model from torchvision (downloads weights on first use)."""
        try:
            from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
        except ImportError:
            raise ImportError(
                "Mask R-CNN requires torchvision. Install with: pip install torchvision"
            )

        print(f"Loading Mask R-CNN ({self.backbone}) model...")

        if self.backbone == 'resnet50':
            # Use newer V2 weights for better performance
            self.model = maskrcnn_resnet50_fpn_v2(weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
        else:
            # MobileNet version (lighter but less accurate)
            from torchvision.models.detection import maskrcnn_mobilenet_v3_large_fpn
            from torchvision.models.detection import MaskRCNN_MobileNet_V3_Large_FPN_Weights
            self.model = maskrcnn_mobilenet_v3_large_fpn(weights=MaskRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)

        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"✓ Mask R-CNN loaded: {self.backbone}")

    def segment(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Segment image using Mask R-CNN.

        Args:
            image: PIL Image
            target_classes: List of class names to segment (None for all)
            **kwargs: Additional arguments (can override conf_threshold)

        Returns:
            Binary mask as numpy array [H, W] with 1 for ROI, 0 for background.
            NOTE: returns uint8 for backward compatibility, while
            ``segment_batch`` returns float32.
        """
        if self.model is None:
            self.load_model()

        # Per-call override of the confidence threshold
        conf_threshold = kwargs.get('conf_threshold', self.conf_threshold)

        img_array = self._to_rgb_array(image)

        # Convert to tensor [3, H, W] normalized to [0, 1]
        img_tensor = torch.from_numpy(img_array).float().permute(2, 0, 1) / 255.0
        img_tensor = img_tensor.to(self.device)

        with torch.no_grad():
            predictions = self.model([img_tensor])[0]

        labels = predictions['labels'].cpu().numpy()
        scores = predictions['scores'].cpu().numpy()
        masks = predictions['masks'].cpu().numpy()

        # Keep only high-confidence detections
        keep_indices = scores >= conf_threshold
        labels = labels[keep_indices]
        masks = masks[keep_indices]

        h, w = img_array.shape[:2]
        combined_mask = np.zeros((h, w), dtype=np.uint8)

        target_indices = self._resolve_target_indices(target_classes)
        for i, label in enumerate(labels):
            if target_indices is not None and int(label) not in target_indices:
                continue
            # Per-instance soft mask -> binary, then OR into the combined mask
            binary_mask = (masks[i][0] > 0.5).astype(np.uint8)
            combined_mask = np.maximum(combined_mask, binary_mask)

        return combined_mask

    def segment_batch(
        self,
        images: List[Image.Image],
        target_classes: Optional[List[str]] = None,
        **kwargs,
    ) -> List[np.ndarray]:
        """Segment a batch of images via Mask R-CNN.

        torchvision's Mask R-CNN forward accepts ``List[Tensor]`` so we
        can pass all images in one call. Post-processing is per-image.

        Args:
            images: List of PIL Images
            target_classes: COCO class names to include in mask
            **kwargs: May include ``conf_threshold``

        Returns:
            List of binary masks (H, W) float32
        """
        if not images:
            return []

        if self.model is None:
            self.load_model()

        conf_threshold = kwargs.get('conf_threshold', self.conf_threshold)
        target_indices = self._resolve_target_indices(target_classes)

        # Build list of tensors (varying sizes are ok for torchvision)
        tensors = []
        for img in images:
            arr = self._to_rgb_array(img)
            t = torch.from_numpy(arr).float().permute(2, 0, 1) / 255.0
            tensors.append(t.to(self.device))

        with torch.no_grad():
            predictions_list = self.model(tensors)

        masks_out: List[np.ndarray] = []
        for i, preds in enumerate(predictions_list):
            h, w = images[i].height, images[i].width
            combined = np.zeros((h, w), dtype=np.float32)

            labels = preds['labels'].cpu().numpy()
            scores = preds['scores'].cpu().numpy()
            pred_masks = preds['masks'].cpu().numpy()

            keep = scores >= conf_threshold
            labels = labels[keep]
            pred_masks = pred_masks[keep]

            for j, lbl in enumerate(labels):
                if target_indices is not None and int(lbl) not in target_indices:
                    continue
                binary = (pred_masks[j][0] > 0.5).astype(np.float32)
                combined = np.maximum(combined, binary)

            masks_out.append(combined)

        return masks_out

    def get_available_classes(self) -> List[str]:
        """
        Get list of available class names.

        Returns:
            List of COCO class names (excluding 'N/A' and '__background__')
        """
        return [cls for cls in self.COCO_CLASSES if cls not in ['N/A', '__background__']]

    def get_detection_info(
        self,
        image: Image.Image,
        conf_threshold: Optional[float] = None
    ) -> List[Dict]:
        """
        Get detailed detection information for all objects.

        Args:
            image: PIL Image
            conf_threshold: Confidence threshold (uses default if None)

        Returns:
            List of dictionaries with detection info (class, score, bbox, mask)
        """
        if self.model is None:
            self.load_model()

        if conf_threshold is None:
            conf_threshold = self.conf_threshold

        img_array = self._to_rgb_array(image)
        img_tensor = torch.from_numpy(img_array).float().permute(2, 0, 1) / 255.0
        img_tensor = img_tensor.to(self.device)

        with torch.no_grad():
            predictions = self.model([img_tensor])[0]

        boxes = predictions['boxes'].cpu().numpy()
        labels = predictions['labels'].cpu().numpy()
        scores = predictions['scores'].cpu().numpy()
        masks = predictions['masks'].cpu().numpy()

        # Filter and format results
        detections = []
        for i in range(len(scores)):
            if scores[i] >= conf_threshold:
                detections.append({
                    'class': self.COCO_CLASSES[labels[i]],
                    'class_id': int(labels[i]),
                    'score': float(scores[i]),
                    'bbox': boxes[i].tolist(),
                    'mask': (masks[i][0] > 0.5).astype(np.uint8)
                })

        return detections
|
segmentation/sam3.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAM3-style promptable segmentation.
|
| 2 |
+
|
| 3 |
+
This integrates a prompt-driven segmentation method into the existing
|
| 4 |
+
class-based segmentation interface. Instead of relying on a fixed class
|
| 5 |
+
vocabulary, it accepts natural-language prompts (e.g., "a red car", "the person").
|
| 6 |
+
|
| 7 |
+
Implementation approach (lightweight, no custom training):
|
| 8 |
+
- Use a text-conditioned detector (OWL-ViT) to propose bounding boxes from text.
|
| 9 |
+
- Use SAM (Segment Anything) to convert boxes into masks.
|
| 10 |
+
|
| 11 |
+
Notes:
|
| 12 |
+
- This is not "SAM 3" in the sense of an official model release; it is a
|
| 13 |
+
prompt-to-mask pipeline exposed as a single segmenter named "sam3".
|
| 14 |
+
- If required dependencies/models are missing, this segmenter raises a clear
|
| 15 |
+
error message.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import List, Optional, Union
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
from PIL import Image
|
| 26 |
+
|
| 27 |
+
from .base import BaseSegmenter
|
| 28 |
+
from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
class _SAM3Config:
    """Default model names and tuning knobs for SAM3Segmenter."""

    detector_model: str = "google/owlvit-base-patch32"  # text-conditioned box proposer
    sam_model: str = "facebook/sam-vit-base"  # converts proposed boxes into masks
    box_threshold: float = 0.02  # minimum detector confidence to keep a box
    max_boxes: int = 5  # cap on boxes forwarded to SAM per image
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SAM3Segmenter(BaseSegmenter):
    """Prompt-driven segmentation via (text detector → SAM).

    Use `target_classes` to pass natural language prompts:
    - `['car']`, `['a car']`, `['the person']`, etc.

    Returns a binary mask (H, W) with 1 for predicted ROI.
    """

    def __init__(
        self,
        device: str = "cuda",
        detector_model: str = _SAM3Config.detector_model,
        sam_model: str = _SAM3Config.sam_model,
        box_threshold: float = _SAM3Config.box_threshold,
        max_boxes: int = _SAM3Config.max_boxes,
        **kwargs,
    ):
        """Initialize the two-stage pipeline (models are loaded lazily).

        Args:
            device: Device string (e.g. 'cuda' or 'cpu')
            detector_model: HF id of the zero-shot object detector
            sam_model: HF id of the SAM mask model
            box_threshold: Minimum detector score to keep a proposal
            max_boxes: Maximum number of boxes forwarded to SAM
            **kwargs: Forwarded to BaseSegmenter
        """
        super().__init__(device=device, **kwargs)
        self.detector_model_name = detector_model
        self.sam_model_name = sam_model
        # Coerce to concrete types so kwargs-supplied strings don't leak through.
        self.box_threshold = float(box_threshold)
        self.max_boxes = int(max_boxes)

        # All three are created lazily in load_model().
        self._detector = None
        self._sam_model = None
        self._sam_processor = None

    def load_model(self):
        """Load the OWL-ViT detector pipeline and the SAM model/processor."""
        try:
            from transformers import pipeline, SamModel, SamProcessor
        except Exception as e:  # pragma: no cover
            raise ImportError(
                "SAM3Segmenter requires `transformers` with SAM support. "
                "Try: pip install -U transformers"
            ) from e

        # Make sure any HF downloads (including pipeline internals) land under `checkpoints/`.
        ensure_default_checkpoint_dirs()

        # Configure device for HF pipeline
        # (pipeline() takes a device index: 0 = first GPU, -1 = CPU).
        if self.device.startswith("cuda") and torch.cuda.is_available():
            pipeline_device = 0
        else:
            pipeline_device = -1

        self._detector = pipeline(
            task="zero-shot-object-detection",
            model=self.detector_model_name,
            device=pipeline_device,
        )

        cache_dir = str(hf_cache_dir())

        self._sam_processor = SamProcessor.from_pretrained(self.sam_model_name, cache_dir=cache_dir)
        self._sam_model = SamModel.from_pretrained(self.sam_model_name, cache_dir=cache_dir)
        self._sam_model = self._sam_model.to(self.device)
        self._sam_model.eval()

        # Keep BaseSegmenter.model set for consistency
        self.model = self._sam_model

    def segment(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        **kwargs,
    ) -> np.ndarray:
        """Produce a binary ROI mask for free-form text prompts.

        Args:
            image: PIL Image
            target_classes: Free-form text prompts; None/empty falls back
                to the generic prompt "object"
            **kwargs: May override ``box_threshold`` and ``max_boxes``

        Returns:
            float32 mask (H, W); all-zero when no box passes the threshold.
        """
        self.ensure_loaded()

        prompts: List[str]
        if target_classes is None or len(target_classes) == 0:
            prompts = ["object"]
        else:
            # Treat provided "classes" as free-form text prompts.
            prompts = [str(p).strip() for p in target_classes if str(p).strip()]
            if not prompts:
                prompts = ["object"]

        box_threshold = float(kwargs.get("box_threshold", self.box_threshold))
        max_boxes = int(kwargs.get("max_boxes", self.max_boxes))

        # Stage 1: text-conditioned box proposals.
        detections = self._detector(image, candidate_labels=prompts)

        # HF pipeline may return dict (single) or list
        if isinstance(detections, dict):
            detections = [detections]

        boxes: List[List[float]] = []
        for det in detections:
            score = float(det.get("score", 0.0))
            if score < box_threshold:
                continue
            b = det.get("box") or {}
            xmin = float(b.get("xmin", 0.0))
            ymin = float(b.get("ymin", 0.0))
            xmax = float(b.get("xmax", 0.0))
            ymax = float(b.get("ymax", 0.0))
            # Sanity clamp: non-negative origin and at least 1px extent,
            # so SAM never receives a degenerate box.
            xmin, ymin = max(0.0, xmin), max(0.0, ymin)
            xmax, ymax = max(xmin + 1.0, xmax), max(ymin + 1.0, ymax)
            boxes.append([xmin, ymin, xmax, ymax])

        if not boxes:
            # No proposals above threshold -> empty ROI.
            return np.zeros((image.height, image.width), dtype=np.float32)

        # NOTE(review): assumes pipeline output is already sorted by score,
        # so truncation keeps the best boxes — TODO confirm for this HF version.
        boxes = boxes[:max_boxes]

        # Stage 2: SAM expects a batch; provide one image with N boxes
        inputs = self._sam_processor(
            image,
            input_boxes=[boxes],
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self._sam_model(**inputs)

        # Post-process masks back to original image size
        # Returns list (batch) of tensors: [num_boxes, H, W]
        post = self._sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.detach().cpu(),
            inputs["original_sizes"].detach().cpu(),
            inputs["reshaped_input_sizes"].detach().cpu(),
        )

        masks0 = post[0]
        if isinstance(masks0, (list, tuple)):
            # Defensive: some versions may nest
            masks0 = torch.stack([m.squeeze(0) if m.ndim == 3 else m for m in masks0], dim=0)

        # masks0: [num_boxes, H, W] or [num_boxes, 1, H, W]
        if masks0.ndim == 4:
            masks0 = masks0[:, 0]

        # Union of all per-box masks -> single binary ROI.
        combined = (masks0 > 0.5).any(dim=0).to(torch.float32)
        return combined.numpy()

    def get_available_classes(self) -> Union[List[str], dict]:
        """Return an empty list: prompts are free-form, not a fixed vocabulary."""
        # Prompt-based model: not a fixed class list.
        return []

    def get_default_classes(self) -> List[str]:
        """Generic fallback prompt used when the caller supplies none."""
        return ["object"]

    # SAM3 (OWL-ViT detector → SAM masks) is inherently sequential;
    # the two-stage pipeline does not support batched inference.
    supports_batch: bool = False
|
segmentation/segformer.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SegFormer-based segmentation implementation.
|
| 3 |
+
Uses Cityscapes-trained model for semantic segmentation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from typing import List, Optional, Dict
|
| 11 |
+
from .base import BaseSegmenter
|
| 12 |
+
from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SegFormerSegmenter(BaseSegmenter):
    """
    SegFormer segmentation using Cityscapes classes.

    Supports semantic segmentation of 19 urban scene classes including
    vehicles, pedestrians, buildings, and road infrastructure.
    """

    # Cityscapes class mapping (name -> train id)
    CITYSCAPES_CLASSES = {
        'road': 0, 'sidewalk': 1, 'building': 2, 'wall': 3, 'fence': 4,
        'pole': 5, 'traffic light': 6, 'traffic sign': 7, 'vegetation': 8,
        'terrain': 9, 'sky': 10, 'person': 11, 'rider': 12, 'car': 13,
        'truck': 14, 'bus': 15, 'train': 16, 'motorcycle': 17, 'bicycle': 18
    }

    # SegFormer supports batched inference via the HF processor.
    supports_batch: bool = True

    def __init__(
        self,
        device: str = 'cuda',
        model_name: str = "nvidia/segformer-b4-finetuned-cityscapes-1024-1024",
        **kwargs
    ):
        """
        Initialize SegFormer segmenter.

        Args:
            device: Device to run on ('cuda' or 'cpu')
            model_name: HuggingFace model identifier
            **kwargs: Additional parameters forwarded to BaseSegmenter
        """
        super().__init__(device=device, **kwargs)
        self.model_name = model_name
        # Created lazily in load_model().
        self.processor = None

    def load_model(self):
        """Load SegFormer model from HuggingFace."""
        from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

        print(f"Loading SegFormer model: {self.model_name}")
        # Route HF downloads into the project checkpoints directory.
        ensure_default_checkpoint_dirs()
        cache_dir = str(hf_cache_dir())
        self.processor = SegformerImageProcessor.from_pretrained(self.model_name, cache_dir=cache_dir)
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            self.model_name,
            cache_dir=cache_dir,
        ).to(self.device)
        self.model.eval()
        print("SegFormer model loaded successfully")

    def _resolve_class_ids(self, target_classes: List[str], verbose: bool = False) -> List[int]:
        """Map class names to Cityscapes ids, falling back to ['car'].

        Args:
            target_classes: Class names (case-insensitive)
            verbose: Print a warning for unknown names / empty result

        Returns:
            Non-empty list of Cityscapes class ids.
        """
        target_ids = []
        for cls in target_classes:
            cls_lower = cls.lower()
            if cls_lower in self.CITYSCAPES_CLASSES:
                target_ids.append(self.CITYSCAPES_CLASSES[cls_lower])
            elif verbose:
                print(f"Warning: '{cls}' not in Cityscapes classes. "
                      f"Available: {list(self.CITYSCAPES_CLASSES.keys())}")

        if not target_ids:
            if verbose:
                print("Warning: No valid classes found. Using 'car' as default.")
            target_ids = [self.CITYSCAPES_CLASSES['car']]
        return target_ids

    def segment(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Create segmentation mask using SegFormer.

        Args:
            image: PIL Image
            target_classes: List of class names (e.g. ['car', 'building', 'person'])
            **kwargs: Additional parameters (unused)

        Returns:
            Binary mask (H, W) float32 with 1 for target classes, 0 for background
        """
        # Lazily load the model; previously only segment_batch did this,
        # so a direct segment() call before load_model() would crash.
        self.ensure_loaded()

        if target_classes is None:
            target_classes = self.get_default_classes()

        target_ids = self._resolve_class_ids(target_classes, verbose=True)

        # PIL .size is (W, H), which matches cv2.resize's (width, height) dsize.
        orig_size = image.size
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits

        # Per-pixel argmax over class logits -> class-id map at model resolution
        seg_map = torch.argmax(logits, dim=1)[0].cpu().numpy()

        # Nearest-neighbor keeps class ids intact while resizing back.
        seg_map_resized = cv2.resize(
            seg_map.astype(np.uint8),
            orig_size,
            interpolation=cv2.INTER_NEAREST
        )

        # Binary union of all requested classes
        mask = np.zeros_like(seg_map_resized, dtype=np.float32)
        for class_id in target_ids:
            mask[seg_map_resized == class_id] = 1.0

        return mask

    def get_available_classes(self) -> Dict[str, int]:
        """Get the Cityscapes class mapping (a copy, safe for callers to mutate)."""
        return dict(self.CITYSCAPES_CLASSES)

    def get_default_classes(self) -> List[str]:
        """Default to car segmentation."""
        return ['car']

    def segment_batch(
        self,
        images: List[Image.Image],
        target_classes: Optional[List[str]] = None,
        **kwargs,
    ) -> List[np.ndarray]:
        """Segment a batch of images in a single forward pass.

        The HuggingFace SegFormer preprocessor natively accepts a list of
        PIL images and returns a batched tensor.

        Args:
            images: List of PIL Images (should be same resolution for padding)
            target_classes: Cityscapes class names to include in ROI mask
            **kwargs: unused

        Returns:
            List of binary masks (H, W) float32
        """
        if not images:
            return []

        self.ensure_loaded()

        if target_classes is None:
            target_classes = self.get_default_classes()

        # Silent here (matches previous behavior: batch path did not warn).
        target_ids = self._resolve_class_ids(target_classes, verbose=False)

        # Batch preprocess
        inputs = self.processor(images=images, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits  # (B, num_classes, h', w')

        masks: List[np.ndarray] = []
        for i, img in enumerate(images):
            seg_map = torch.argmax(logits[i], dim=0).cpu().numpy()
            seg_resized = cv2.resize(
                seg_map.astype(np.uint8),
                img.size,
                interpolation=cv2.INTER_NEAREST,
            )
            mask = np.zeros_like(seg_resized, dtype=np.float32)
            for cid in target_ids:
                mask[seg_resized == cid] = 1.0
            masks.append(mask)

        return masks
|
segmentation/utils.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for segmentation visualization and I/O.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def visualize_mask(
    image: Image.Image,
    mask: np.ndarray,
    alpha: float = 0.5,
    color: Tuple[int, int, int] = (255, 0, 0)
) -> Image.Image:
    """
    Overlay a binary segmentation mask on an image.

    Args:
        image: PIL Image (any mode; converted to RGB internally so that
            grayscale or RGBA inputs do not break the 3-channel overlay)
        mask: Binary mask (H, W); values > 0.5 are treated as ROI
        alpha: Overlay strength (0-1) for the mask color
        color: RGB color tuple for the mask

    Returns:
        RGB PIL Image with the mask blended on top
    """
    # Fix: force RGB so `colored_mask[...] = color` is always valid.
    # The original used np.array(image) directly, which crashed for
    # grayscale (2-D array) or RGBA (4-channel) inputs.
    img_array = np.array(image.convert('RGB'))

    # Paint ROI pixels with the requested color on a black canvas.
    colored_mask = np.zeros_like(img_array)
    colored_mask[mask > 0.5] = color

    # Additive blend; cv2.addWeighted saturates at 255 for uint8 inputs.
    result = cv2.addWeighted(img_array, 1.0, colored_mask, alpha, 0)

    return Image.fromarray(result)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def save_mask(mask: np.ndarray, output_path: str):
    """
    Write a binary mask to disk as an 8-bit grayscale image.

    ROI pixels (value 1) become white (255); background stays black.

    Args:
        mask: Binary mask array (H, W)
        output_path: Destination path for the mask image
    """
    as_uint8 = np.uint8(mask * 255)
    Image.fromarray(as_uint8).save(output_path)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_mask(mask_path: str) -> np.ndarray:
    """
    Read a mask image from disk and normalize it to [0, 1].

    Args:
        mask_path: Path to mask image

    Returns:
        Float32 mask array (H, W); white pixels map to 1.0, black to 0.0
    """
    grayscale = Image.open(mask_path).convert('L')
    return np.asarray(grayscale, dtype=np.float32) / 255.0
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def calculate_roi_stats(mask: np.ndarray) -> dict:
    """
    Calculate statistics about ROI coverage.

    Args:
        mask: Binary mask (H, W); values > 0.5 count as ROI

    Returns:
        Dictionary with statistics:
            - roi_pixels: Number of ROI pixels
            - total_pixels: Total number of pixels
            - roi_percentage: Percentage of image covered by ROI
              (0.0 for an empty/zero-sized mask)
    """
    roi_pixels = int(np.sum(mask > 0.5))
    total_pixels = int(mask.size)
    # Fix: guard against a zero-sized mask (ZeroDivisionError before).
    roi_percentage = (roi_pixels / total_pixels) * 100 if total_pixels else 0.0

    return {
        'roi_pixels': roi_pixels,
        'total_pixels': total_pixels,
        'roi_percentage': roi_percentage
    }
|
segmentation/yolo.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
YOLO-based instance segmentation implementation.
|
| 3 |
+
Uses YOLO26 (or YOLOv8 fallback) for COCO object detection and segmentation.
|
| 4 |
+
Supports true batch inference for video processing.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from typing import List, Optional
|
| 11 |
+
from .base import BaseSegmenter
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class YOLOSegmenter(BaseSegmenter):
    """
    YOLO instance segmentation using COCO classes.

    Defaults to YOLO26x-seg for best accuracy.
    Supports batch inference via ``segment_batch()`` for video pipelines.
    """

    # YOLO natively supports batched inference via ultralytics.
    supports_batch: bool = True

    def __init__(
        self,
        device: str = 'cuda',
        model_path: str = 'checkpoints/yolo26x-seg.pt',
        conf_threshold: float = 0.25,
        **kwargs
    ):
        """
        Initialize YOLO segmenter.

        Args:
            device: Device to run on ('cuda' or 'cpu')
            model_path: Path to YOLO weights file (yolo26x-seg.pt default)
            conf_threshold: Confidence threshold for detections
            **kwargs: Additional parameters forwarded to BaseSegmenter
        """
        super().__init__(device=device, **kwargs)
        self.model_path = model_path
        self.conf_threshold = conf_threshold

    def load_model(self):
        """Load YOLO model (ultralytics is imported lazily to keep startup cheap)."""
        from ultralytics import YOLO

        print(f"Loading YOLO model: {self.model_path}")
        self.model = YOLO(self.model_path)
        print("YOLO model loaded successfully")

    def _extract_mask(
        self,
        result,
        width: int,
        height: int,
        target_classes: List[str],
    ) -> np.ndarray:
        """Extract combined binary mask from a single YOLO result object.

        Args:
            result: Single ultralytics Results object
            width: Target mask width
            height: Target mask height
            target_classes: Classes to include

        Returns:
            Binary mask (H, W) float32
        """
        mask = np.zeros((height, width), dtype=np.float32)

        if result.masks is None:
            return mask

        masks_data = result.masks.data
        boxes = result.boxes

        for idx, box in enumerate(boxes):
            class_id = int(box.cls[0])
            class_name = self.model.names[class_id].lower()

            # Substring match so e.g. 'car' also matches compound labels.
            if any(target.lower() in class_name for target in target_classes):
                instance_mask = masks_data[idx].cpu().numpy()
                # YOLO masks come at model resolution; resize to image size.
                instance_mask_resized = cv2.resize(
                    instance_mask,
                    (width, height),
                    interpolation=cv2.INTER_LINEAR,
                )
                mask = np.maximum(mask, instance_mask_resized)

        # Binarize after taking the union of all instance masks.
        return (mask > 0.5).astype(np.float32)

    def segment(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        conf_threshold: Optional[float] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Create segmentation mask using YOLO.

        Args:
            image: PIL Image
            target_classes: List of class names (e.g., ['car', 'person'])
            conf_threshold: Override default confidence threshold
            **kwargs: Additional parameters

        Returns:
            Binary mask (H, W) with 1 for target instances, 0 for background
        """
        # Fix: segment_batch() lazily loads the model but segment() did not,
        # which failed when segment() was the first call on a fresh instance.
        self.ensure_loaded()

        threshold = conf_threshold if conf_threshold is not None else self.conf_threshold
        if target_classes is None:
            target_classes = self.get_default_classes()

        results = self.model(image, verbose=False, conf=threshold, device=self.device)

        if not results:
            return np.zeros((image.height, image.width), dtype=np.float32)

        return self._extract_mask(results[0], image.width, image.height, target_classes)

    def segment_batch(
        self,
        images: List[Image.Image],
        target_classes: Optional[List[str]] = None,
        conf_threshold: Optional[float] = None,
        **kwargs,
    ) -> List[np.ndarray]:
        """Segment a batch of images in a single YOLO forward pass.

        Args:
            images: List of PIL Images (should be the same resolution)
            target_classes: Class names to include in ROI mask
            conf_threshold: Override default confidence threshold

        Returns:
            List of binary masks (H, W), one per input image
        """
        if not images:
            return []

        self.ensure_loaded()

        threshold = conf_threshold if conf_threshold is not None else self.conf_threshold
        if target_classes is None:
            target_classes = self.get_default_classes()

        # Ultralytics accepts a list of PIL images for batch inference
        results = self.model(images, verbose=False, conf=threshold, device=self.device)

        masks = []
        for i, result in enumerate(results):
            img = images[i]
            masks.append(self._extract_mask(result, img.width, img.height, target_classes))

        return masks

    def get_available_classes(self) -> List[str]:
        """
        Get COCO class names.

        Returns:
            List of COCO class names (80 classes)
        """
        if self.model is None:
            # Return common COCO classes if model not loaded
            return [
                'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
                'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
                'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
                'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
                'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
                'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
                'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
                'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
                'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
                'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
                'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
                'scissors', 'teddy bear', 'hair drier', 'toothbrush'
            ]
        return list(self.model.names.values())

    def get_default_classes(self) -> List[str]:
        """Default to car segmentation."""
        return ['car']
|
vae/RSTB.py
ADDED
|
@@ -0,0 +1,813 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021-2022, InterDigital Communications, Inc
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Redistribution and use in source and binary forms, with or without
|
| 5 |
+
# modification, are permitted (subject to the limitations in the disclaimer
|
| 6 |
+
# below) provided that the following conditions are met:
|
| 7 |
+
|
| 8 |
+
# * Redistributions of source code must retain the above copyright notice,
|
| 9 |
+
# this list of conditions and the following disclaimer.
|
| 10 |
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
| 11 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 12 |
+
# and/or other materials provided with the distribution.
|
| 13 |
+
# * Neither the name of InterDigital Communications, Inc nor the names of its
|
| 14 |
+
# contributors may be used to endorse or promote products derived from this
|
| 15 |
+
# software without specific prior written permission.
|
| 16 |
+
|
| 17 |
+
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
|
| 18 |
+
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
|
| 19 |
+
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
|
| 20 |
+
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
|
| 21 |
+
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
|
| 22 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
| 23 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
| 24 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
| 25 |
+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
|
| 26 |
+
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
|
| 27 |
+
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
|
| 28 |
+
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
|
| 30 |
+
from typing import Any
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
import torch.nn.functional as F
|
| 35 |
+
from torch import Tensor
|
| 36 |
+
from torch.autograd import Function
|
| 37 |
+
import torch.utils.checkpoint as checkpoint
|
| 38 |
+
|
| 39 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 40 |
+
|
| 41 |
+
from compressai.layers import GDN
|
| 42 |
+
|
| 43 |
+
# Names exported via ``from RSTB import *``.
# NOTE(review): ``conv1x1`` is defined in this module but not listed here,
# and ``RSTB`` / ``CausalAttentionModule`` are listed before their
# definitions appear — confirm both are intentional.
__all__ = [
    "AttentionBlock",
    "MaskedConv2d",
    "ResidualBlock",
    "ResidualBlockUpsample",
    "ResidualBlockWithStride",
    "conv3x3",
    "subpel_conv3x3",
    "QReLU",
    "RSTB",
    "CausalAttentionModule",
]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class MaskedConv2d(nn.Conv2d):
    r"""2D convolution whose kernel is masked so "future" pixels are unseen.

    Useful as a building block for auto-regressive models, as introduced in
    `"Conditional Image Generation with PixelCNN Decoders"
    <https://arxiv.org/abs/1606.05328>`_.

    Accepts the same arguments as ``nn.Conv2d``. ``mask_type='A'`` also masks
    the current (center) pixel and is used for the first layer;
    ``mask_type='B'`` keeps the center pixel and is used afterwards.
    """

    def __init__(self, *args: Any, mask_type: str = "A", **kwargs: Any):
        super().__init__(*args, **kwargs)

        if mask_type not in ("A", "B"):
            raise ValueError(f'Invalid "mask_type" value "{mask_type}"')

        self.register_buffer("mask", torch.ones_like(self.weight.data))
        _, _, h, w = self.mask.size()
        # Type "B" keeps the center tap; type "A" masks it as well.
        center_offset = 1 if mask_type == "B" else 0
        # Zero everything to the right of (and including, for "A") the center
        # of the middle row, and every row below the middle row.
        self.mask[:, :, h // 2, w // 2 + center_offset:] = 0
        self.mask[:, :, h // 2 + 1:] = 0

    def forward(self, x: Tensor) -> Tensor:
        # Re-apply the mask before every convolution so optimizer updates
        # cannot leak information through the masked taps.
        self.weight.data *= self.mask
        return super().forward(x)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def conv3x3(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
    """Build a 3x3 convolution whose padding of 1 preserves spatial size."""
    return nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=3,
        stride=stride,
        padding=1,
    )
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def subpel_conv3x3(in_ch: int, out_ch: int, r: int = 1) -> nn.Sequential:
    """Build a 3x3 conv + PixelShuffle pair that upsamples by factor ``r``."""
    # The conv produces r**2 channel groups that PixelShuffle rearranges
    # into an r-times larger spatial grid.
    conv = nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1)
    return nn.Sequential(conv, nn.PixelShuffle(r))
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
    """Build a pointwise (1x1) convolution, optionally strided."""
    return nn.Conv2d(in_ch, out_ch, stride=stride, kernel_size=1)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ResidualBlockWithStride(nn.Module):
    """Residual block whose first convolution may downsample via stride.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        stride (int): stride value (default: 2)
    """

    def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        self.gdn = GDN(out_ch)
        # A 1x1 projection shortcut is only needed when the tensor shape
        # changes (spatially via stride, or in channel count).
        needs_projection = stride != 1 or in_ch != out_ch
        self.skip = conv1x1(in_ch, out_ch, stride=stride) if needs_projection else None

    def forward(self, x: Tensor) -> Tensor:
        # Main path: conv -> LeakyReLU -> conv -> GDN.
        out = self.gdn(self.conv2(self.leaky_relu(self.conv1(x))))
        shortcut = x if self.skip is None else self.skip(x)
        out += shortcut
        return out
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class ResidualBlockUpsample(nn.Module):
    """Residual block with sub-pixel upsampling on the last convolution.

    Both the main path and the skip connection upsample by ``upsample``;
    each uses its own, independently learned, sub-pixel convolution
    (``subpel_conv`` for the main path, ``upsample`` for the shortcut).

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        upsample (int): upsampling factor (default: 2)
    """

    def __init__(self, in_ch: int, out_ch: int, upsample: int = 2):
        super().__init__()
        self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv = conv3x3(out_ch, out_ch)
        self.igdn = GDN(out_ch, inverse=True)
        self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)

    def forward(self, x: Tensor) -> Tensor:
        # Main path: subpel upsample -> LeakyReLU -> conv -> inverse GDN.
        out = self.subpel_conv(x)
        out = self.leaky_relu(out)
        out = self.conv(out)
        out = self.igdn(out)
        # Shortcut is always the upsampled input (the original's dead
        # `identity = x` assignment, immediately overwritten, was removed).
        identity = self.upsample(x)
        out += identity
        return out
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class ResidualBlock(nn.Module):
    """Simple residual block with two 3x3 convolutions.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        # Channel-matching projection only when in/out widths differ.
        self.skip = conv1x1(in_ch, out_ch) if in_ch != out_ch else None

    def forward(self, x: Tensor) -> Tensor:
        # Main path: conv -> LeakyReLU -> conv -> LeakyReLU (activation is
        # applied after the second conv as well, before the addition).
        out = self.leaky_relu(self.conv2(self.leaky_relu(self.conv1(x))))
        shortcut = x if self.skip is None else self.skip(x)
        return out + shortcut
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class AttentionBlock(nn.Module):
    """Self attention block.

    Simplified variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun,
    Masaru Takeuchi, Jiro Katto.

    Args:
        N (int): Number of channels)
    """

    def __init__(self, N: int):
        super().__init__()

        class ResidualUnit(nn.Module):
            """Bottleneck residual unit (N -> N/2 -> N) with ReLU output."""

            def __init__(self):
                super().__init__()
                self.conv = nn.Sequential(
                    conv1x1(N, N // 2),
                    nn.ReLU(inplace=True),
                    conv3x3(N // 2, N // 2),
                    nn.ReLU(inplace=True),
                    conv1x1(N // 2, N),
                )
                self.relu = nn.ReLU(inplace=True)

            def forward(self, x: Tensor) -> Tensor:
                return self.relu(self.conv(x) + x)

        # Trunk branch: three residual units.
        self.conv_a = nn.Sequential(*[ResidualUnit() for _ in range(3)])

        # Mask branch: same residual stack plus a 1x1 projection; its output
        # is squashed by a sigmoid in forward() to gate the trunk branch.
        self.conv_b = nn.Sequential(
            *[ResidualUnit() for _ in range(3)],
            conv1x1(N, N),
        )

    def forward(self, x: Tensor) -> Tensor:
        gated = self.conv_a(x) * torch.sigmoid(self.conv_b(x))
        return x + gated
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class QReLU(Function):
    """QReLU

    Clamping input with given bit-depth range.
    Suppose that input data presents integer through an integer network
    otherwise any precision of input will simply clamp without rounding
    operation.

    Pre-computed scale with gamma function is used for backward computation.

    More details can be found in
    `"Integer networks for data compression with latent-variable models"
    <https://openreview.net/pdf?id=S1zz2i0cY7>`_,
    by Johannes Ballé, Nick Johnston and David Minnen, ICLR in 2019

    Args:
        input: a tensor data
        bit_depth: source bit-depth (used for clamping)
        beta: a parameter for modeling the gradient during backward computation
    """

    @staticmethod
    def forward(ctx, input, bit_depth, beta):
        # TODO(choih): allow to use adaptive scale instead of
        # pre-computed scale with gamma function
        ctx.alpha = 0.9943258522851727
        ctx.beta = beta
        # Upper clamp bound for the given bit depth (e.g. 255 for 8-bit).
        ctx.max_value = 2 ** bit_depth - 1
        ctx.save_for_backward(input)

        # Forward pass is a plain clamp to [0, 2**bit_depth - 1].
        return input.clamp(min=0, max=ctx.max_value)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        (input,) = ctx.saved_tensors

        # Inside the valid range the gradient passes through unchanged.
        grad_input = grad_output.clone()
        # Outside the range, substitute a smooth surrogate gradient that
        # decays with distance from the range center (see paper above).
        # NOTE(review): `(-ctx.alpha ** ctx.beta)` parses as
        # `-(alpha ** beta)` by Python precedence — confirm intentional.
        grad_sub = (
            torch.exp(
                (-ctx.alpha ** ctx.beta)
                * torch.abs(2.0 * input / ctx.max_value - 1) ** ctx.beta
            )
            * grad_output.clone()
        )

        grad_input[input < 0] = grad_sub[input < 0]
        grad_input[input > ctx.max_value] = grad_sub[input > ctx.max_value]

        # bit_depth and beta are non-tensor hyperparameters: no gradients.
        return grad_input, None, None
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class PatchEmbed(nn.Module):
    """Flatten a (B, C, H, W) feature map into token form (B, H*W, C)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        return x.flatten(2).transpose(1, 2)

    def flops(self):
        # Pure reshaping: no multiply-accumulate operations.
        return 0
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class PatchUnEmbed(nn.Module):
    """Reshape tokens (B, H*W, C) back into a (B, C, H, W) feature map."""

    def __init__(self):
        super().__init__()

    def forward(self, x, x_size):
        batch = x.shape[0]
        # (B, HW, C) -> (B, C, HW) -> (B, C, H, W); channel dim inferred.
        return x.transpose(1, 2).view(batch, -1, x_size[0], x_size[1])

    def flops(self):
        # Pure reshaping: no multiply-accumulate operations.
        return 0
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear with dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Fall back to the input width when hidden/output sizes are omitted.
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Dropout is applied after each of the two linear layers.
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    n_h, n_w = H // window_size, W // window_size
    # Expose the window grid as explicit axes, then group (b, grid_h, grid_w)
    # together so each window becomes one leading-dim entry.
    tiles = x.view(B, n_h, window_size, n_w, window_size, C)
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, C)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def window_reverse(windows, window_size, H, W):
    """
    Stitch window tensors back into a full feature map (inverse of
    window_partition).

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    # Recover the batch size from the total number of windows.
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    grid = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(B, H, W, -1)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        # one learnable bias per (relative dh, relative dw, head) triple
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # row-major flattening of the 2-D relative offset into a single index
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # buffer (not a parameter): moves with the module but is not trained
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        # single projection produces q, k, v: (3, B_, num_heads, N, C//num_heads)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # look up the learned bias for each token pair and add it to the logits
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # broadcast the per-window shift mask over batch and heads
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    One attention + MLP block with optional cyclic shift (SW-MSA). The
    attention mask for the shifted case is precomputed for the nominal
    ``input_resolution`` and recomputed on the fly when the actual input
    size differs.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # Precompute the SW-MSA mask for the nominal resolution; None when no shift.
        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        # Label the 9 shifted regions with distinct ids; token pairs with
        # different ids must not attend to each other after the cyclic shift.
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # Non-zero difference => different region => large negative bias.
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x, x_size):
        """Apply (S)W-MSA + MLP to a token sequence.

        Args:
            x: tokens of shape (B, H*W, C).
            x_size: actual spatial size (H, W) of the tokens.

        Returns:
            Tensor of shape (B, H*W, C).
        """
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        else:
            # actual size differs from nominal: rebuild the mask for this size
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Stacks ``depth`` SwinTransformerBlocks, alternating non-shifted and
    shifted windows.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks: even blocks use regular windows, odd blocks use
        # windows shifted by window_size // 2 (SW-MSA)
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

    def forward(self, x, x_size):
        """Run the token sequence ``x`` (B, H*W, C) through all blocks."""
        for blk in self.blocks:
            if self.use_checkpoint:
                # Bug fix: the block's forward signature is blk(x, x_size);
                # the previous call checkpoint.checkpoint(blk, x) dropped
                # x_size and raised TypeError whenever use_checkpoint=True.
                x = checkpoint.checkpoint(blk, x, x_size)
            else:
                x = blk(x, x_size)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        return flops
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    Wraps a BasicLayer between patch embedding/unembedding and adds a
    residual connection around the whole group.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False):
        super(RSTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(
            dim=dim,
            input_resolution=input_resolution,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            use_checkpoint=use_checkpoint,
        )

        self.patch_embed = PatchEmbed()
        self.patch_unembed = PatchUnEmbed()

    def forward(self, x, x_size):
        """Embed -> transformer group -> unembed, with a residual skip."""
        tokens = self.patch_embed(x)
        tokens = self.residual_group(tokens, x_size)
        return self.patch_unembed(tokens, x_size) + x

    def flops(self):
        """Total FLOPs of the residual group plus the (un)embedding layers."""
        total = self.residual_group.flops()
        total += self.patch_embed.flops()
        total += self.patch_unembed.flops()
        return total
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
class CausalAttentionModule(nn.Module):
    r""" Causal multi-head self attention over local block_len x block_len patches.

    For every spatial location, a block_len x block_len neighbourhood is
    extracted (via unfold); tokens at or after the centre of the flattened
    patch are zeroed out (causal mask), windowed attention with a relative
    position bias is applied, the patch dimension is pooled, and the result
    is projected to ``out_dim`` channels.

    Args:
        dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        block_len (int): Side length of the local attention patch. Default: 5
        num_heads (int): Number of attention heads. Default: 16
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
    """

    def __init__(self, dim, out_dim, block_len=5, num_heads=16, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 attn_drop=0.):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.block_len = block_len
        self.block_size = block_len * block_len
        self.scale = qk_scale or head_dim ** -0.5
        self.attn_drop = nn.Dropout(attn_drop)

        self.norm1 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Causal mask over the flattened patch: keep (1) tokens strictly before
        # the centre position, zero (0) the centre and everything after it.
        # Fix: previously hard-coded as a 25-element literal, which silently
        # assumed block_len == 5; now derived from block_size (identical values
        # for the default block_len=5: 12 ones followed by 13 zeros).
        mask = torch.zeros(self.block_size)
        mask[: self.block_size // 2] = 1.0
        self.mask = mask.view(1, self.block_size, 1)

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * block_len - 1) * (2 * block_len - 1), num_heads))  # 2*P-1 * 2*P-1, num_heads

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(block_len)
        coords_w = torch.arange(block_len)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, P, P
        coords_flatten = torch.flatten(coords, 1)  # 2, P*P
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, PP, PP
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # PP, PP, 2
        relative_coords[:, :, 0] += block_len - 1  # shift to start from 0
        relative_coords[:, :, 1] += block_len - 1
        relative_coords[:, :, 0] *= 2 * block_len - 1
        relative_position_index = relative_coords.sum(-1)  # PP, PP
        self.register_buffer("relative_position_index", relative_position_index)

        self.softmax = nn.Softmax(dim=-1)

        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU, drop=attn_drop)
        self.proj = nn.Linear(dim, out_dim)

    def forward(self, x):
        """
        Args:
            x: input features with shape (B, C, H, W)

        Returns:
            Tensor of shape (B, out_dim, H, W).
        """
        B, C, H, W = x.shape
        P = self.block_len
        # Extract a P x P patch around every spatial position.
        # Fix: kernel_size/padding were hard-coded to (5, 5)/2; derive them
        # from block_len so non-default patch sizes work.
        x_unfold = F.unfold(x, kernel_size=(P, P), padding=P // 2)  # B, C*P*P, H*W
        x_unfold = x_unfold.reshape(B, C, self.block_size, H * W).permute(0, 3, 2, 1).contiguous().view(
            -1, self.block_size, C)  # B*H*W, PP, C

        # Zero the centre token and everything after it (causality).
        x_masked = x_unfold * self.mask.to(x_unfold.device)
        out = self.norm1(x_masked)
        qkv = self.qkv(out).reshape(B * H * W, self.block_size, 3, self.num_heads,
                                    C // self.num_heads).permute(2, 0, 3, 1, 4)  # 3, BHW, num_heads, PP, C//num_heads
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # BHW, num_heads, PP, PP

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.block_size, self.block_size, -1)  # PP, PP, num_heads
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # num_heads, PP, PP
        attn = attn + relative_position_bias.unsqueeze(0)

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = (attn @ v).transpose(1, 2).reshape(B * H * W, self.block_size, C)
        out += x_masked  # residual over the masked patch tokens
        # Pool the patch dimension, then MLP with a second residual.
        out_sumed = torch.sum(out, dim=1).reshape(B, H * W, C)
        out = self.norm2(out_sumed)
        out = self.mlp(out)
        out += out_sumed

        out = self.proj(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2)  # B, C_out, H, W

        return out
|
vae/__init__.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Local ROI-VAE compression package.
|
| 2 |
+
|
| 3 |
+
This package was previously named `compression`, but was renamed to `vae` to
|
| 4 |
+
avoid colliding with optional third-party imports (e.g. via Gradio/fsspec).
|
| 5 |
+
|
| 6 |
+
It intentionally uses lazy attribute loading to keep import-time overhead low.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from importlib import import_module
|
| 12 |
+
from typing import TYPE_CHECKING
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"TIC",
|
| 17 |
+
"ModifiedTIC",
|
| 18 |
+
"load_checkpoint",
|
| 19 |
+
"compute_padding",
|
| 20 |
+
"compress_image",
|
| 21 |
+
"highlight_roi",
|
| 22 |
+
"create_comparison_grid",
|
| 23 |
+
"RSTB",
|
| 24 |
+
"CausalAttentionModule",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
_EXPORTS: dict[str, tuple[str, str]] = {
|
| 29 |
+
"TIC": (".tic_model", "TIC"),
|
| 30 |
+
"ModifiedTIC": (".roi_tic", "ModifiedTIC"),
|
| 31 |
+
"load_checkpoint": (".roi_tic", "load_checkpoint"),
|
| 32 |
+
"compute_padding": (".utils", "compute_padding"),
|
| 33 |
+
"compress_image": (".utils", "compress_image"),
|
| 34 |
+
"highlight_roi": (".visualization", "highlight_roi"),
|
| 35 |
+
"create_comparison_grid": (".visualization", "create_comparison_grid"),
|
| 36 |
+
"RSTB": (".RSTB", "RSTB"),
|
| 37 |
+
"CausalAttentionModule": (".RSTB", "CausalAttentionModule"),
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def __getattr__(name: str):
    """Lazily resolve a public export on first attribute access (PEP 562).

    Looks the name up in ``_EXPORTS``, imports the owning submodule on
    demand, caches the attribute in module globals so later lookups bypass
    this hook, and returns it. Unknown names raise AttributeError.
    """
    target = _EXPORTS.get(name)
    if target is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    module_name, attr_name = target
    resolved = getattr(import_module(module_name, package=__name__), attr_name)
    # Cache the resolved object so __getattr__ is only hit once per name.
    globals()[name] = resolved
    return resolved
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if TYPE_CHECKING:
|
| 53 |
+
from .RSTB import CausalAttentionModule, RSTB
|
| 54 |
+
from .roi_tic import ModifiedTIC, load_checkpoint
|
| 55 |
+
from .tic_model import TIC
|
| 56 |
+
from .utils import compress_image, compute_padding
|
| 57 |
+
from .visualization import create_comparison_grid, highlight_roi
|
vae/roi_tic.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ROI-aware TIC model for region-based compression.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from .tic_model import TIC
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ModifiedTIC(TIC):
    """Modified TIC that uses a pre-computed binary mask with a sigma parameter
    for ROI-based compression.

    The binary ROI mask is turned into a per-pixel quality map (1.0 inside the
    ROI, ``sigma`` outside) that modulates the latent representation through an
    importance network before entropy coding.
    """

    def __init__(self, N=192, M=192):
        super().__init__(N=N, M=M)
        # Caches of pre-allocated constant scalars, keyed by device (and sigma
        # value) to avoid re-creating tensors on every forward pass.
        self._ones_cache = {}
        self._sigma_cache = {}

    def forward(self, x, mask, sigma=0.3):
        """
        Forward pass with ROI mask and quality factor.

        Args:
            x: input image tensor [B, C, H, W]
            mask: ROI mask (1 for ROI, 0 for background) [B, 1, H, W]
            sigma: quality factor for background (0.01-1.0, lower = more compression).
                Can be a scalar (same for all frames) or tensor [B] for per-frame values.

        Returns:
            dict with compression outputs (y_hat, y, similarity, x_hat, likelihoods)
        """
        x_size = (x.shape[2], x.shape[3])
        batch_size = x.shape[0]
        device = mask.device

        # Get or create the cached scalar 1.0 tensor for this device
        if device not in self._ones_cache:
            self._ones_cache[device] = torch.tensor(1.0, device=device)
        ones_tensor = self._ones_cache[device]

        # Convert sigma into a broadcastable tensor
        if isinstance(sigma, (int, float)):
            # Scalar sigma - use cache
            cache_key = (device, sigma)
            if cache_key not in self._sigma_cache:
                self._sigma_cache[cache_key] = torch.tensor(sigma, device=device)
            sigma_tensor = self._sigma_cache[cache_key]
        elif sigma.dim() == 1:
            # Batched sigma - reshape to [B, 1, 1, 1] for broadcasting
            sigma_tensor = sigma.view(batch_size, 1, 1, 1)
        else:
            sigma_tensor = sigma

        # Per-pixel quality map: 1.0 inside the ROI, sigma outside.
        # Fix: the original computed this twice into two identical tensors
        # (similarity_loss / similarity_imp); compute once and reuse.
        similarity = torch.where(mask > 0.5, ones_tensor, sigma_tensor)

        # Downsample mask to 1/2 resolution for simi_net.
        # simi_net has 3 stride-2 convolutions (8x downsampling), so input at
        # 1/2 gives output at 1/16, matching y_codec_a6 dimensions (g_a's 3x
        # downsampling + g_a6's 1x downsampling = 16x).
        # Nearest-neighbor is used for binary masks (fast, no quality loss).
        similarity_down = F.interpolate(similarity, scale_factor=0.5, mode='nearest')

        similarity_up = F.interpolate(similarity, scale_factor=2, mode='nearest')
        similarity_up_repeated = similarity_up.repeat(1, 3, 1, 1)

        # simi_net downsamples by 8x: 128x128 -> 16x16 to match y_codec_a6
        similarities_channel = self.simi_net(similarity_down)
        similarities_sigmoid = torch.sigmoid(similarities_channel)

        y_codec = self.g_a(x, x_size)
        y_codec_a6 = self.g_a6(y_codec)

        y_import = self.sub_impor_net(y_codec)
        y_tanh = self.tanh(y_import)
        y_soft = self.softsign(y_tanh)

        # Importance-weighted latent
        y_imp = y_soft + similarities_sigmoid
        y = y_codec_a6 * y_imp

        z = self.h_a(y, x_size)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        params = self.h_s(z_hat, x_size)

        # Additive noise during training (differentiable), hard
        # dequantization at evaluation time.
        y_hat = self.gaussian_conditional.quantize(
            y, "noise" if self.training else "dequantize"
        )
        ctx_params = self.context_prediction(y_hat)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1)
        )
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        x_hat = self.g_s(y_hat, x_size)

        return {
            "y_hat": y_hat,
            "y": y,
            "similarity": similarity_up_repeated,
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def load_checkpoint(checkpoint_path: str, N: int = 192, M: int = 192, device: str = 'cuda') -> ModifiedTIC:
    """
    Load TIC model from checkpoint.

    Args:
        checkpoint_path: Path to .pth.tar checkpoint
        N: Number of channels (default 192)
        M: Number of channels in expansion layers (default 192)
        device: 'cuda' or 'cpu'

    Returns:
        Loaded ModifiedTIC model in eval mode
    """
    model = ModifiedTIC(N=N, M=M).to(device)
    # weights_only=False keeps the pre-2.6 torch.load behaviour: torch >= 2.6
    # defaults to weights_only=True, which can refuse to unpickle checkpoint
    # dicts containing non-tensor objects. Only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # The checkpoint's state_dict is already in the (compressai) format the
    # model expects; no key conversion is needed.
    state_dict = checkpoint["state_dict"]

    model.load_state_dict(state_dict)
    model.eval()
    # Refresh the entropy-coder internals after loading weights
    # (compressai convention; force=True rebuilds even if already updated).
    model.update(force=True)
    return model
|
vae/tic_model.py
ADDED
|
@@ -0,0 +1,989 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
|
| 9 |
+
from .RSTB import RSTB, CausalAttentionModule
|
| 10 |
+
from compressai.ans import BufferedRansEncoder, RansDecoder
|
| 11 |
+
from timm.models.layers import trunc_normal_
|
| 12 |
+
from compressai.models.utils import conv, deconv, update_registered_buffers
|
| 13 |
+
from compressai.layers import AttentionBlock
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import numpy as np
|
| 16 |
+
import matplotlib
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
from matplotlib.colors import ListedColormap
|
| 19 |
+
|
| 20 |
+
# from lseg.lseg_net import LSegNet
|
| 21 |
+
# import cv2
|
| 22 |
+
# import random
|
| 23 |
+
import itertools
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
+
|
| 28 |
+
# From Balle's tensorflow compression examples
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64

device = "cuda"


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Build the default Gaussian scale table.

    Returns `levels` scales spaced geometrically (log-uniform) between
    `min` and `max`, matching the table used by CompressAI models.
    """
    log_lo, log_hi = math.log(min), math.log(max)
    return torch.linspace(log_lo, log_hi, levels).exp()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Binarizer(torch.autograd.Function):
|
| 40 |
+
"""
|
| 41 |
+
An elementwise function that bins values
|
| 42 |
+
to 0 or 1 depending on a threshold of
|
| 43 |
+
0.5
|
| 44 |
+
|
| 45 |
+
Input: a tensor with values in range(0,1)
|
| 46 |
+
|
| 47 |
+
Returns: a tensor with binary values: 0 or 1
|
| 48 |
+
based on a threshold of 0.5
|
| 49 |
+
|
| 50 |
+
Equation(1) in paper
|
| 51 |
+
"""
|
| 52 |
+
@staticmethod
|
| 53 |
+
def forward(ctx, i):
|
| 54 |
+
result = torch.where(i > 0.9, torch.tensor(1.0), torch.tensor(0.2))
|
| 55 |
+
return result
|
| 56 |
+
|
| 57 |
+
@staticmethod
|
| 58 |
+
def backward(ctx, grad_output):
|
| 59 |
+
return grad_output
|
| 60 |
+
|
| 61 |
+
def bin_values(x):
    """Apply the straight-through ``Binarizer`` elementwise to ``x``."""
    return Binarizer.apply(x)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class TIC(nn.Module):
|
| 68 |
+
"""Neural image compression framework from
|
| 69 |
+
Lu Ming and Guo, Peiyao and Shi, Huiqing and Cao, Chuntong and Ma, Zhan:
|
| 70 |
+
`"Transformer-based Image Compression" <https://arxiv.org/abs/2111.06707>`, (DCC 2022).
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
N (int): Number of channels
|
| 74 |
+
M (int): Number of channels in the expansion layers (last layer of the
|
| 75 |
+
encoder and last layer of the hyperprior decoder)
|
| 76 |
+
input_resolution (int): Just used for window partition decision
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(self, N=192, M=192):
    super().__init__()

    # Swin/RSTB hyper-parameters shared by every transformer stage below.
    depths = [1, 2, 3, 1, 1]        # RSTB blocks per stage (g_a: 0-2, h_a: 3-4)
    num_heads = [4, 8, 16, 16, 16]  # attention heads per stage
    window_size = 8                 # attention window; halved for hyper stages
    mlp_ratio = 4.
    qkv_bias = True
    qk_scale = None
    drop_rate = 0.
    attn_drop_rate = 0.
    drop_path_rate = 0.2
    norm_layer = nn.LayerNorm
    use_checkpoint = False

    # stochastic depth: linearly increasing drop-path rate over all blocks
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
    self.align_corners = True

    # ---- analysis transform g_a: image -> latent, x16 downsampling ----
    # NOTE(review): the input_resolution tuples assume a 256x256 input;
    # the forward path also passes the true size at call time -- confirm
    # RSTB treats these only as defaults.
    self.g_a0 = conv(3, N, kernel_size=5, stride=2)
    self.g_a1 = RSTB(dim=N,
                     input_resolution=(128, 128),
                     depth=depths[0],
                     num_heads=num_heads[0],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:0]):sum(depths[:1])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.g_a2 = conv(N, N, kernel_size=3, stride=2)
    self.g_a3 = RSTB(dim=N,
                     input_resolution=(64, 64),
                     depth=depths[1],
                     num_heads=num_heads[1],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:1]):sum(depths[:2])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.g_a4 = conv(N, N, kernel_size=3, stride=2)
    self.g_a5 = RSTB(dim=N,
                     input_resolution=(32, 32),
                     depth=depths[2],
                     num_heads=num_heads[2],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:2]):sum(depths[:3])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )

    # Final analysis conv; applied by callers separately from g_a()
    # so the pre-projection features can feed sub_impor_net.
    self.g_a6 = conv(N, M, kernel_size=3, stride=2)

    # ---- hyper-analysis h_a: latent y -> hyper-latent z (x4 further) ----
    self.h_a0 = conv(M, N, kernel_size=3, stride=1)
    self.h_a1 = RSTB(dim=N,
                     input_resolution=(16, 16),
                     depth=depths[3],
                     num_heads=num_heads[3],
                     window_size=window_size // 2,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:3]):sum(depths[:4])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.h_a2 = conv(N, N, kernel_size=3, stride=2)
    self.h_a3 = RSTB(dim=N,
                     input_resolution=(8, 8),
                     depth=depths[4],
                     num_heads=num_heads[4],
                     window_size=window_size // 2,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:4]):sum(depths[:5])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.h_a4 = conv(N, N, kernel_size=3, stride=2)

    # Reverse the stage configs for the synthesis side.
    # NOTE(review): dpr was computed from the ORIGINAL depth ordering but is
    # sliced below with the REVERSED depths, so synthesis drop-path slices
    # differ from the analysis ones -- mirrors the reference TIC code;
    # presumably intentional, confirm.
    depths = depths[::-1]
    num_heads = num_heads[::-1]
    # ---- hyper-synthesis h_s: z_hat -> entropy parameters (2M channels) ----
    self.h_s0 = deconv(N, N, kernel_size=3, stride=2)
    self.h_s1 = RSTB(dim=N,
                     input_resolution=(8, 8),
                     depth=depths[0],
                     num_heads=num_heads[0],
                     window_size=window_size // 2,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:0]):sum(depths[:1])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.h_s2 = deconv(N, N, kernel_size=3, stride=2)
    self.h_s3 = RSTB(dim=N,
                     input_resolution=(16, 16),
                     depth=depths[1],
                     num_heads=num_heads[1],
                     window_size=window_size // 2,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:1]):sum(depths[:2])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.h_s4 = conv(N, M * 2, kernel_size=3, stride=1)

    # ---- synthesis transform g_s: latent -> image, x16 upsampling ----
    self.g_s0 = deconv(M, N, kernel_size=3, stride=2)
    self.g_s1 = RSTB(dim=N,
                     input_resolution=(32, 32),
                     depth=depths[2],
                     num_heads=num_heads[2],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:2]):sum(depths[:3])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.g_s2 = deconv(N, N, kernel_size=3, stride=2)
    self.g_s3 = RSTB(dim=N,
                     input_resolution=(64, 64),
                     depth=depths[3],
                     num_heads=num_heads[3],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:3]):sum(depths[:4])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.g_s4 = deconv(N, N, kernel_size=3, stride=2)
    self.g_s5 = RSTB(dim=N,
                     input_resolution=(128, 128),
                     depth=depths[4],
                     num_heads=num_heads[4],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:4]):sum(depths[:5])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.g_s6 = deconv(N, 3, kernel_size=5, stride=2)

    # ---- entropy model ----
    self.entropy_bottleneck = EntropyBottleneck(N)
    self.gaussian_conditional = GaussianConditional(None)
    # masked/causal context model over the quantized latent
    self.context_prediction = CausalAttentionModule(M, M * 2)
    # self.attetionmap = AttentionBlock(M)

    # Fuses hyper-prior params (2M) + context params (2M) -> scales & means (2M).
    self.entropy_parameters = nn.Sequential(
        nn.Conv2d(M * 12 // 3, M * 10 // 3, 1),
        nn.GELU(),
        nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
        nn.GELU(),
        nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),
    )

    # ---- importance-map subnetwork pieces (see sub_impor_net) ----
    self.sub_net_leaky = nn.Sequential(
        conv(N,N,kernel_size=3,stride=2),
        nn.LeakyReLU()
    )

    # NOTE(review): sub_net0/sub_net1/sub_net2 are each applied THREE times
    # inside sub_impor_net, so the three residual stages share weights --
    # presumably intentional; confirm.
    self.sub_net0 = nn.Sequential(
        conv(N,64,kernel_size=1,stride=1),
        nn.ReLU()
    )
    self.sub_net1 = nn.Sequential(
        conv(64,64,kernel_size=3,stride=1),
        nn.ReLU()
    )
    self.sub_net2 = conv(64,N,kernel_size=1,stride=1)

    # channel projection N -> M for the importance map
    self.sub_net_channel = conv(N,M,kernel_size=1,stride=1)

    # Maps a 1-channel similarity/ROI map down to latent resolution (/8)
    # and M channels.
    self.simi_net = nn.Sequential(
        conv(1,64,kernel_size=3,stride=2),
        nn.ReLU(),
        conv(64,128,kernel_size=3,stride=2),
        nn.ReLU(),
        conv(128, M, kernel_size=3, stride=2),
    )

    # self.net_lseg = LSegNet(
    #     backbone="clip_vitl16_384",
    #     features=256,
    #     crop_size=256,
    #     arch_option=0,
    #     block_depth=0,
    #     activation="lrelu",
    # )

    self.cosine_similarity = torch.nn.CosineSimilarity(dim=1)

    # self.con3_3 = conv(192,192,kernel_size=3,stride=1)

    # element-wise activations reused by the importance-map path
    self.tanh = nn.Tanh()
    self.softsign = nn.Softsign()
    self.relu = nn.ReLU()
    self.sigmoid = nn.Sigmoid()

    self.apply(self._init_weights)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def g_a(self, x, x_size=None):
    """Analysis transform: image -> latent features (x16 downsampling).

    NOTE: the final stride-2 projection ``g_a6`` is intentionally NOT
    applied here; callers invoke it separately so these features can also
    feed the importance-map subnetwork.
    """
    if x_size is None:
        x_size = x.shape[2:4]
    height, width = x_size[0], x_size[1]
    x = self.g_a0(x)
    x = self.g_a1(x, (height // 2, width // 2))
    x = self.g_a2(x)
    x = self.g_a3(x, (height // 4, width // 4))
    x = self.g_a4(x)
    x = self.g_a5(x, (height // 8, width // 8))
    return x
|
| 315 |
+
|
| 316 |
+
def g_s(self, x, x_size=None):
    """Synthesis transform: quantized latent -> reconstructed image.

    ``x_size`` is the ORIGINAL image size (the latent is /16 of it).
    """
    if x_size is None:
        x_size = (x.shape[2] * 16, x.shape[3] * 16)
    height, width = x_size[0], x_size[1]
    x = self.g_s0(x)
    x = self.g_s1(x, (height // 8, width // 8))
    x = self.g_s2(x)
    x = self.g_s3(x, (height // 4, width // 4))
    x = self.g_s4(x)
    x = self.g_s5(x, (height // 2, width // 2))
    return self.g_s6(x)
|
| 327 |
+
|
| 328 |
+
def h_a(self, x, x_size=None):
    """Hyper-analysis: latent y -> hyper-latent z (a further x4 downsampling).

    ``x_size`` is the ORIGINAL image size (y is already /16 of it).
    """
    if x_size is None:
        x_size = (x.shape[2] * 16, x.shape[3] * 16)
    height, width = x_size[0], x_size[1]
    x = self.h_a0(x)
    x = self.h_a1(x, (height // 16, width // 16))
    x = self.h_a2(x)
    x = self.h_a3(x, (height // 32, width // 32))
    return self.h_a4(x)
|
| 337 |
+
|
| 338 |
+
def h_s(self, x, x_size=None):
    """Hyper-synthesis: hyper-latent z_hat -> entropy parameters (2M channels).

    ``x_size`` is the ORIGINAL image size (z is /64 of it).
    """
    if x_size is None:
        x_size = (x.shape[2] * 64, x.shape[3] * 64)
    height, width = x_size[0], x_size[1]
    x = self.h_s0(x)
    x = self.h_s1(x, (height // 32, width // 32))
    x = self.h_s2(x)
    x = self.h_s3(x, (height // 16, width // 16))
    return self.h_s4(x)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def sub_impor_net(self, x):
    """Importance-map subnetwork over the analysis features.

    One stride-2 conv, then three residual bottleneck stages, then a 1x1
    channel projection to M.  The three stages reuse the SAME submodules
    (``sub_net0/1/2``), i.e. their weights are shared.
    """
    out = self.sub_net_leaky(x)
    for _ in range(3):
        # 1x1 reduce -> 3x3 -> 1x1 expand, added back as a residual.
        residual = self.sub_net2(self.sub_net1(self.sub_net0(out)))
        out = out + residual
    return self.sub_net_channel(out)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def aux_loss(self):
    """Return the aggregated loss over the auxiliary entropy bottleneck
    module(s)."""
    total = 0
    for module in self.modules():
        if isinstance(module, EntropyBottleneck):
            total = total + module.loss()
    return total
|
| 381 |
+
|
| 382 |
+
def _init_weights(self, m):
    """Per-module weight init, applied via ``self.apply``:
    truncated-normal for Linear layers, ones/zeros for LayerNorm."""
    if isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
@torch.jit.ignore
def no_weight_decay_keywords(self):
    """Parameter-name keywords to exclude from weight decay."""
    return {'relative_position_bias_table'}
|
| 395 |
+
|
| 396 |
+
def forward(self, x, similarity):
    """Training/eval forward pass: ROI-weighted transform coding.

    Args:
        x: input image batch (B, 3, H, W).
        similarity: single-channel ROI/similarity map; it is upsampled x2
            for the loss mask and downsampled /8 by simi_net, so it is
            presumably at half the image resolution -- TODO confirm against
            the caller.

    Returns:
        dict with the quantized latent (``y_hat``), the ROI-weighted latent
        (``y``), the x2-upsampled ROI mask repeated over RGB
        (``similarity``), the reconstruction (``x_hat``) and the y/z
        ``likelihoods`` used for the rate term.
    """

    x_size = (x.shape[2], x.shape[3])

    h, w = x.size(2), x.size(3)  # (unused)

    # Binarize the ROI map: pixels with similarity > 0.85 -> 1.0, rest -> 0.01.
    # NOTE(review): the two tensors below are computed identically; the 0-dim
    # CPU tensors fed to torch.where may also mismatch CUDA inputs on older
    # PyTorch versions -- confirm.
    similarity_loss = torch.where(similarity > 0.85, torch.tensor(1.0), torch.tensor(0.01))
    similarity_imp = torch.where(similarity > 0.85, torch.tensor(1.0), torch.tensor(0.01))

    # x2-upsampled mask, repeated over RGB, returned for the distortion loss.
    similarity_up = F.interpolate(similarity_loss, scale_factor=2, mode='bilinear')
    similarity_up_repeated = similarity_up.repeat(1, 3, 1, 1)

    # Project the ROI mask to latent resolution (/8) and M channels, in (0,1).
    similarities_channel = self.simi_net(similarity_imp)
    similarities_sigmoid = torch.sigmoid(similarities_channel)

    y_codec = self.g_a(x, x_size)  # analysis features before the final conv
    y_codec_a6 = self.g_a6(y_codec)

    # Learned importance map from the pre-projection analysis features.
    y_import = self.sub_impor_net(y_codec)
    y_tanh = self.tanh(y_import)
    y_soft = self.softsign(y_tanh)

    # Combine learned importance with the ROI prior, then gate the latent.
    y_imp = y_soft + similarities_sigmoid
    y = y_codec_a6 * y_imp

    # Hyper-prior branch.
    z = self.h_a(y, x_size)
    z_hat, z_likelihoods = self.entropy_bottleneck(z)
    params = self.h_s(z_hat, x_size)

    # Quantize: additive uniform noise during training, rounding at eval.
    y_hat = self.gaussian_conditional.quantize(
        y, "noise" if self.training else "dequantize"
    )
    # Causal context model + fusion with hyper-prior params -> scales/means.
    ctx_params = self.context_prediction(y_hat)
    gaussian_params = self.entropy_parameters(
        torch.cat((params, ctx_params), dim=1)
    )
    scales_hat, means_hat = gaussian_params.chunk(2, 1)
    _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
    x_hat = self.g_s(y_hat, x_size)

    return {
        "y_hat": y_hat,
        "y": y,
        "similarity": similarity_up_repeated,
        "x_hat": x_hat,
        "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
    }
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def update(self, scale_table=None, force=False):
    """Updates the entropy bottleneck(s) CDF values.

    Needs to be called once after training to be able to later perform the
    evaluation with an actual entropy coder.

    Args:
        scale_table (torch.Tensor): table of scales for the Gaussian
            conditional (default: None, which builds the standard
            log-spaced table via ``get_scale_table()``)
        force (bool): overwrite previous values (default: False)

    Returns:
        updated (bool): True if any entropy model was updated.
    """
    if scale_table is None:
        scale_table = get_scale_table()
    # Bug fix: the boolean returned by update_scale_table was previously
    # discarded, so callers could see updated == False even though the
    # Gaussian conditional tables were rebuilt (cf. CompressAI's
    # CompressionModel.update).
    updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)

    for m in self.children():
        if not isinstance(m, EntropyBottleneck):
            continue
        rv = m.update(force=force)
        updated |= rv
    return updated
|
| 475 |
+
|
| 476 |
+
def load_state_dict(self, state_dict, strict=True):
    """Load a checkpoint, first resizing the CDF-related buffers of the
    entropy models to match the checkpoint (their shapes are data-dependent).

    Returns the ``NamedTuple`` of missing/unexpected keys from
    ``nn.Module.load_state_dict`` (previously discarded -- bug fix,
    backward compatible since callers saw ``None`` before).
    """
    # Dynamically update the entropy bottleneck buffers related to the CDFs
    update_registered_buffers(
        self.entropy_bottleneck,
        "entropy_bottleneck",
        ["_quantized_cdf", "_offset", "_cdf_length"],
        state_dict,
    )
    update_registered_buffers(
        self.gaussian_conditional,
        "gaussian_conditional",
        ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
        state_dict,
    )
    return super().load_state_dict(state_dict, strict=strict)
|
| 491 |
+
|
| 492 |
+
@classmethod
def from_state_dict(cls, state_dict):
    """Return a new model instance from `state_dict`."""
    # Infer channel counts from the checkpoint: first analysis conv's
    # output channels give N, the last analysis conv's give M.
    num_n = state_dict["g_a0.weight"].size(0)
    num_m = state_dict["g_a6.weight"].size(0)
    model = cls(num_n, num_m)
    model.load_state_dict(state_dict)
    return model
|
| 500 |
+
|
| 501 |
+
# def compress(self, x,similarity):
def compress(self, x):
    """Entropy-encode ``x`` with the serial context model, pruning all-zero
    latent channels.

    Variant WITHOUT ROI weighting (the similarity path is commented out).
    NOTE(review): the per-pixel indexing below (``y_q[i, :, ...]`` inside a
    batch-1 crop) and the channel-pruning squeeze() only work for batch
    size 1 -- confirm callers never pass larger batches.

    Returns:
        dict with the rANS ``strings`` for y and z, the hyper-latent
        ``shape``, and the per-channel keep ``flag`` needed to decompress.
    """
    x = x.cuda()
    # similarity = similarity.to(device)
    x_size = (x.shape[2], x.shape[3])

    # --- (disabled) LSeg text/image similarity computation ---
    # img_feat = self.net_lseg.forward(x)
    # img_feat_norm = torch.nn.functional.normalize(img_feat, dim=1)
    # prompt = clip.tokenize(similarity).cuda()
    # text_feat = self.net_lseg.clip_pretrained.encode_text(prompt)  # 1, 512
    # text_feat_norm = torch.nn.functional.normalize(text_feat, dim=1)
    # similarity = self.cosine_similarity(
    #     img_feat_norm, text_feat_norm.unsqueeze(-1).unsqueeze(-1)
    # )
    # similarity = similarity.unsqueeze(0)
    # torch.cuda.synchronize()
    # inf_time = time.time() - start_1
    # print(inf_time)

    # ##### (was: "在这里" = "here") -- ROI weighting insertion point.

    start = time.time()
    # similarity_down_1 = torch.where(similarity > 0.9, torch.tensor(1.0), torch.tensor(1.0))
    # similarities_repeated = self.simi_net(similarity_down_1)
    # similarities_repeated = torch.sigmoid(similarities_repeated)

    y_codec = self.g_a(x, x_size)  # y

    # y_import = self.sub_impor_net(y_codec)
    # y_tanh = self.tanh(y_import)
    # y_soft = self.softsign(y_tanh)

    y_codec_a6 = self.g_a6(y_codec)

    # y_imp = y_soft + similarities_repeated  # similarity * importance map
    # y = y_codec_a6 * y_imp
    y = y_codec_a6  # no importance gating in this variant
    # y = y_imp * y_codec_a6
    # y = self.sub_net_channel(y)
    # y = y_codec_a6 * similarities_repeated

    # Hyper-prior: encode/decode z, then derive per-pixel entropy params.
    z = self.h_a(y)

    z_strings = self.entropy_bottleneck.compress(z)
    z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

    params = self.h_s(z_hat)

    s = 4  # scaling factor between z and y
    kernel_size = 5  # context prediction kernel size
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    # Pad so every kernel_size x kernel_size causal crop is valid.
    y_hat = F.pad(y, (padding, padding, padding, padding))

    # pylint: disable=protected-access
    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
    # pylint: enable=protected-access
    # print(cdf, cdf_lengths, offsets)
    y_strings = []
    for i in range(y.size(0)):
        encoder = BufferedRansEncoder()
        # Warning, this is slow...
        # TODO: profile the calls to the bindings...
        symbols_list = []
        indexes_list = []
        # Dense buffers of quantized symbols and CDF indexes, filled by the
        # serial raster scan, then pruned per channel below.
        y_q_ = torch.zeros_like(y)
        indexes_ = torch.zeros_like(y)
        for h in range(y_height):
            for w in range(y_width):
                y_crop = y_hat[
                    i: i + 1, :, h: h + kernel_size, w: w + kernel_size
                ]
                ctx_p = self.context_prediction(y_crop)
                # 1x1 conv for the entropy parameters prediction network, so
                # we only keep the elements in the "center"
                p = params[i: i + 1, :, h: h + 1, w: w + 1]
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
                )
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)
                y_q = torch.round(y_crop - means_hat)
                # Write the dequantized center back so later context crops
                # see quantized values (sequential decoding parity).
                y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
                    i, :, padding, padding
                ]
                y_q_[i, :, h, w] = y_q[i, :, padding, padding]
                indexes_[i, :, h, w] = indexes[i, :, 0, 0]

        # Channel pruning: keep flag[idx] = 1 if channel idx has ANY
        # nonzero quantized symbol (was: "全部大于0就设置标志位是1").
        flag = np.array(np.zeros(y_q_.shape[1]))
        for idx in range(y_q_.shape[1]):
            if torch.sum(torch.abs(y_q_[:, idx, :, :])) > 0:
                flag[idx] = 1
        # NOTE(review): np.nonzero indexing + squeeze() assumes batch 1 and
        # at least two kept channels -- confirm.
        y_q_ = y_q_[:, np.nonzero(flag), ...].squeeze()
        indexes_ = indexes_[:, np.nonzero(flag), ...].squeeze()
        for h in range(y_height):
            for w in range(y_width):
                # encoder.encode_with_indexes(
                #     y_q_[:,np.nonzero(flag),h,w].squeeze().int().tolist(),
                #     indexes_[:,np.nonzero(flag),h,w].squeeze().int().tolist(), cdf, cdf_lengths, offsets
                # )
                symbols_list.extend(y_q_[:, h, w].int().tolist())
                indexes_list.extend(indexes_[:, h, w].squeeze().int().tolist())
        encoder.encode_with_indexes(
            symbols_list, indexes_list, cdf, cdf_lengths, offsets
        )
        string = encoder.flush()
        y_strings.append(string)
        print(flag.sum())  # debug: number of channels actually encoded

    torch.cuda.synchronize()  # make sure the GPU work really finished before timing
    t2 = time.time() - start
    # print(t2)

    return {"strings": [y_strings, z_strings], "shape": z.size()[-2:], "flag": flag}

    # return {"test":similarity}
|
| 638 |
+
|
| 639 |
+
def compress_1(self, x, similarity):
    # def compress_1(self, x):
    """Entropy-encode ``x`` with ROI weighting from ``similarity``
    (serial context model, no channel pruning).

    ``similarity`` is a single-channel ROI map; zero pixels are mapped to
    a tiny weight (1e-4) and nonzero pixels to 1.0 before being projected
    to the latent grid by ``simi_net``.
    NOTE(review): the per-pixel ``[i, :, ...]`` indexing in the scan below
    assumes batch size 1 -- confirm.

    Returns:
        dict with the rANS ``strings`` for y and z and the hyper-latent
        ``shape``.
    """
    x = x.cuda()
    x_size = (x.shape[2], x.shape[3])

    similarity = similarity.cuda()
    # Binary ROI prior: background (== 0) -> 1e-4, foreground -> 1.0.
    similarity_down_1 = torch.where(similarity == 0, torch.tensor(1e-4), torch.tensor(1.0))
    # similarity_down_1 = torch.where(similarity > 0.9, torch.tensor(1.0), torch.tensor(1e-4))

    # Downsample x0.5, then project to latent resolution/channels in (0,1).
    similarity_down_1 = F.interpolate(similarity_down_1, scale_factor=0.5, mode='bilinear')
    similarities_repeated = self.simi_net(similarity_down_1)
    similarities_repeated = torch.sigmoid(similarities_repeated)

    y_codec = self.g_a(x, x_size)  # y

    # Learned importance map.
    y_import = self.sub_impor_net(y_codec)
    y_tanh = self.tanh(y_import)
    y_soft = self.softsign(y_tanh)  # important2
    # y_soft = self.sigmoid(y_soft)

    y_codec_a6 = self.g_a6(y_codec)
    # y_codec_a6 = self.attetionmap(y_codec_a6)

    y_imp = similarities_repeated + y_soft  # similarity * importance map (was: "相似度* important map")

    y = y_codec_a6 * y_imp
    # y= y_codec_a6 * y_tanh

    # --- (disabled) ROI-mask visualization dump ---
    # cmap = ListedColormap(['yellow'])
    # similarity_image = torch.where(similarity > 0.9, torch.tensor(1.0), torch.tensor(0.1))
    # similarity_image = F.interpolate(similarity_image, scale_factor=2, mode='bilinear')
    # abs = torch.abs(similarity_image)
    # mean = torch.mean(abs, axis=1, keepdims=True)
    # viz = mean.detach().cpu().numpy()
    # viz = viz[0]
    # viz = viz.squeeze()
    # plt.imshow(viz)
    # # save the image (was: "保存图像")
    # plt.imsave('/mnt/disk10T/xfx/CLIP/bird.png', viz)

    # Hyper-prior: encode/decode z, then derive per-pixel entropy params.
    z = self.h_a(y)

    z_strings = self.entropy_bottleneck.compress(z)
    z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

    params = self.h_s(z_hat)

    s = 4  # scaling factor between z and y
    kernel_size = 5  # context prediction kernel size
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    # Pad so every kernel_size x kernel_size causal crop is valid.
    y_hat = F.pad(y, (padding, padding, padding, padding))

    # pylint: disable=protected-access
    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
    # pylint: enable=protected-access
    # print(cdf, cdf_lengths, offsets)

    y_strings = []
    for i in range(y.size(0)):
        encoder = BufferedRansEncoder()
        # Warning, this is slow...
        # TODO: profile the calls to the bindings...
        symbols_list = []
        indexes_list = []
        for h in range(y_height):
            for w in range(y_width):
                y_crop = y_hat[
                    i: i + 1, :, h: h + kernel_size, w: w + kernel_size
                ]
                ctx_p = self.context_prediction(y_crop)
                # 1x1 conv for the entropy parameters prediction network, so
                # we only keep the elements in the "center"
                p = params[i: i + 1, :, h: h + 1, w: w + 1]
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
                )
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)
                y_q = torch.round(y_crop - means_hat)
                # Write the dequantized center back so later context crops
                # see quantized values (sequential decoding parity).
                y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
                    i, :, padding, padding
                ]

                symbols_list.extend(y_q[i, :, padding, padding].int().tolist())
                indexes_list.extend(indexes[i, :].squeeze().int().tolist())

        encoder.encode_with_indexes(
            symbols_list, indexes_list, cdf, cdf_lengths, offsets
        )

        string = encoder.flush()
        y_strings.append(string)

    return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
|
| 750 |
+
|
| 751 |
+
def compress_2(self, x,similarity):
    """Compress `x` with a similarity-driven importance map applied to the latent.

    Pipeline: encode `x` to a latent, scale the latent channel-wise by an
    importance map derived from `similarity`, then entropy-code the latent
    autoregressively (serial context model, rANS).

    Args:
        x: Input image tensor, NCHW in [0, 1]. Moved to CUDA unconditionally.
        similarity: Per-region similarity tensor fed to `self.simi_net`.
            NOTE(review): shape expected by `simi_net` is not visible here —
            confirm against the caller.

    Returns:
        dict with "strings" ([y_strings, z_strings]) and "shape" (z spatial size),
        matching the CompressAI `compress()` convention.
    """
    # def compress_1(self, x):
    x = x.cuda()
    x_size = (x.shape[2], x.shape[3])

    ##### ROI weighting starts here.
    # Binarize-ish the similarity: confident regions (>0.9) keep weight 1.0,
    # everything else is damped to 0.1 before the similarity network.
    # NOTE(review): torch.tensor(1.0)/torch.tensor(0.1) are created on CPU;
    # if `similarity` is on GPU this relies on torch's 0-dim cross-device
    # promotion — confirm the torch version in use allows it.
    similarity_down_1 = torch.where(similarity > 0.9, torch.tensor(1.0), torch.tensor(0.1))
    similarities_repeated = self.simi_net(similarity_down_1)
    similarities_repeated = torch.sigmoid(similarities_repeated)

    y_codec = self.g_a(x, x_size)  # y: analysis-transform latent

    # Latent-derived importance map, squashed to (-1, 1).
    y_import = self.sub_impor_net(y_codec)
    y_tanh = self.tanh(y_import)

    y_codec_a6 = self.g_a6(y_codec)

    # Combined importance: similarity weight + importance map
    # (additive combination, despite the original "similarity * map" note).
    y_imp = similarities_repeated + y_tanh

    # Scale the latent by the importance map before entropy coding.
    y = y_codec_a6 * y_imp

    z = self.h_a(y)

    z_strings = self.entropy_bottleneck.compress(z)
    # Decompress z so encoder and decoder share identical hyperprior params.
    z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

    params = self.h_s(z_hat)

    s = 4  # scaling factor between z and y
    kernel_size = 5  # context prediction kernel size
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    # Pad y so a kernel_size crop centered at (h, w) is always in-bounds.
    y_hat = F.pad(y, (padding, padding, padding, padding))

    # pylint: disable=protected-access
    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
    # pylint: enable=protected-access
    # print(cdf, cdf_lengths, offsets)

    y_strings = []
    for i in range(y.size(0)):
        encoder = BufferedRansEncoder()
        # Warning, this is slow...
        # TODO: profile the calls to the bindings...
        symbols_list = []
        indexes_list = []
        for h in range(y_height):
            for w in range(y_width):
                # Crop for the masked 5x5 context convolution at (h, w).
                y_crop = y_hat[
                    i: i + 1, :, h: h + kernel_size, w: w + kernel_size
                ]
                ctx_p = self.context_prediction(y_crop)
                # 1x1 conv for the entropy parameters prediction network, so
                # we only keep the elements in the "center"
                p = params[i: i + 1, :, h: h + 1, w: w + 1]
                # NOTE(review): `ctx_p`, `y_q` have batch dim 1 (they come
                # from the `i:i+1` crop), yet are re-indexed with `i` below —
                # this only works for batch size 1; confirm callers never
                # pass a larger batch.
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
                )
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)
                y_q = torch.round(y_crop - means_hat)
                # Write back the quantized+dequantized center so later context
                # windows see decoder-side values.
                y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
                    i, :, padding, padding
                ]

                symbols_list.extend(y_q[i, :, padding, padding].int().tolist())
                indexes_list.extend(indexes[i, :].squeeze().int().tolist())

        encoder.encode_with_indexes(
            symbols_list, indexes_list, cdf, cdf_lengths, offsets
        )

        string = encoder.flush()
        y_strings.append(string)

    return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
|
| 835 |
+
|
| 836 |
+
def decompress(self, strings, shape, flag):
    """Autoregressively decode a bitstream into an image, restoring only
    the latent channels selected by `flag`.

    Args:
        strings: [y_strings, z_strings] as produced by the compress methods.
        shape: Spatial shape of the hyper-latent z.
        flag: Channel-selection array; non-zero entries mark the latent
            channels that were actually coded. Channels not in `flag`
            remain zero in the reconstructed latent.

    Returns:
        dict with "x_hat", the reconstructed image clamped to [0, 1].
    """
    # def decompress(self, strings, shape):
    # np.nonzero returns a tuple of index arrays; used below for advanced
    # indexing along the channel dimension.
    # NOTE(review): assumes `flag` is 1-D so the tuple has a single array —
    # confirm against callers.
    flag = np.nonzero(flag)
    assert isinstance(strings, list) and len(strings) == 2
    # FIXME: we don't respect the default entropy coder and directly call the
    # range ANS decoder

    z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
    params = self.h_s(z_hat)

    s = 4  # scaling factor between z and y
    kernel_size = 5  # context prediction kernel size
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    # initialize y_hat to zeros, and pad it so we can directly work with
    # sub-tensors of size (N, C, kernel size, kernel_size)
    y_hat = torch.zeros(
        (z_hat.size(0), 192, y_height + 2 * padding, y_width + 2 * padding),
        device=z_hat.device,
    )
    decoder = RansDecoder()

    # pylint: disable=protected-access
    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()

    # Warning: this is slow due to the auto-regressive nature of the
    # decoding... See more recent publication where they use an
    # auto-regressive module on chunks of channels for faster decoding...
    for i, y_string in enumerate(strings[0]):
        decoder.set_stream(y_string)

        for h in range(y_height):
            for w in range(y_width):
                # only perform the 5x5 convolution on a cropped tensor
                # centered in (h, w)
                y_crop = y_hat[
                    i: i + 1, :, h: h + kernel_size, w: w + kernel_size
                ]
                ctx_p = self.context_prediction(y_crop)
                # 1x1 conv for the entropy parameters prediction network, so
                # we only keep the elements in the "center"
                p = params[i: i + 1, :, h: h + 1, w: w + 1]
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
                )
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)
                # Only the `flag` channels were coded; decode just those.
                rv = decoder.decode_stream(
                    indexes[i, flag].squeeze().int().tolist(),
                    # indexes[i, :].squeeze().int().tolist(),
                    cdf,
                    cdf_lengths,
                    offsets,
                )
                # rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
                rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
                # Scatter decoded symbols into their channel slots; all other
                # channels stay at zero.
                # NOTE(review): `tmp` (and `rv`) are created on CPU, while
                # `means_hat` follows z_hat's device — this mixes devices if
                # z_hat is on GPU. Confirm this decode path runs on CPU.
                tmp = torch.zeros((1, 192, 1, 1))
                tmp[:, flag, ...] = rv
                rv = self.gaussian_conditional._dequantize(tmp, means_hat)
                # rv = self.gaussian_conditional._dequantize(rv, means_hat)

                y_hat[
                    i,
                    :,
                    h + padding: h + padding + 1,
                    w + padding: w + padding + 1,
                ] = rv

    # Strip the context-window padding before synthesis.
    y_hat = y_hat[:, :, padding:-padding, padding:-padding]
    # pylint: enable=protected-access

    x_hat = self.g_s(y_hat).clamp_(0, 1)
    return {"x_hat": x_hat,}
|
| 916 |
+
|
| 917 |
+
def decompress_1(self, strings, shape):
    """Autoregressively decode a bitstream into an image (all channels).

    Full-channel counterpart of `decompress`: every latent channel is
    decoded from the stream, with no `flag` selection.

    Args:
        strings: [y_strings, z_strings] as produced by the compress methods.
        shape: Spatial shape of the hyper-latent z.

    Returns:
        dict with "x_hat", the reconstructed image clamped to [0, 1].
    """
    assert isinstance(strings, list) and len(strings) == 2
    # FIXME: we don't respect the default entropy coder and directly call the
    # range ANS decoder

    z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
    params = self.h_s(z_hat)

    s = 4  # scaling factor between z and y
    kernel_size = 5  # context prediction kernel size
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    # initialize y_hat to zeros, and pad it so we can directly work with
    # sub-tensors of size (N, C, kernel size, kernel_size)
    y_hat = torch.zeros(
        (z_hat.size(0), 192, y_height + 2 * padding, y_width + 2 * padding),
        device=z_hat.device,
    )
    decoder = RansDecoder()

    # pylint: disable=protected-access
    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()

    # Warning: this is slow due to the auto-regressive nature of the
    # decoding... See more recent publication where they use an
    # auto-regressive module on chunks of channels for faster decoding...
    for i, y_string in enumerate(strings[0]):
        decoder.set_stream(y_string)

        for h in range(y_height):
            for w in range(y_width):
                # only perform the 5x5 convolution on a cropped tensor
                # centered in (h, w)
                y_crop = y_hat[
                    i: i + 1, :, h: h + kernel_size, w: w + kernel_size
                ]
                ctx_p = self.context_prediction(y_crop)
                # 1x1 conv for the entropy parameters prediction network, so
                # we only keep the elements in the "center"
                p = params[i: i + 1, :, h: h + 1, w: w + 1]
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
                )
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)

                rv = decoder.decode_stream(
                    indexes[i, :].squeeze().int().tolist(),
                    cdf,
                    cdf_lengths,
                    offsets,
                )
                # NOTE(review): `rv` is built on CPU while `means_hat`
                # follows z_hat's device — mixes devices if z_hat is on GPU.
                rv = torch.Tensor(rv).reshape(1, -1, 1, 1)

                rv = self.gaussian_conditional._dequantize(rv, means_hat)

                y_hat[
                    i,
                    :,
                    h + padding: h + padding + 1,
                    w + padding: w + padding + 1,
                ] = rv
    # Strip the context-window padding before synthesis.
    y_hat = y_hat[:, :, padding:-padding, padding:-padding]
    # pylint: enable=protected-access

    x_hat = self.g_s(y_hat).clamp_(0, 1)
    return {"x_hat": x_hat}
|
vae/transformer_layers.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def create_look_ahead_mask(size):
    """Return a (size, size) float mask with 1.0 strictly above the diagonal.

    Used as an attention mask for autoregressive decoding: a 1 at (i, j)
    marks position j > i as not attendable from position i.
    """
    return torch.triu(torch.ones(size, size), diagonal=1)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class StochasticDepth(nn.Module):
    """Stochastic depth (DropPath) layer.

    During training, each sample's whole branch is zeroed with probability
    `stochastic_depth_drop_rate`; surviving samples are scaled by
    1 / keep_prob so the expectation is unchanged. At eval time (or with a
    zero drop rate) this is the identity.
    """

    def __init__(self, stochastic_depth_drop_rate):
        """Initializes a stochastic depth layer.

        Args:
            stochastic_depth_drop_rate: A `float` drop probability in [0, 1).
        """
        super().__init__()
        self._drop_rate = stochastic_depth_drop_rate

    def forward(self, inputs):
        """Apply per-sample stochastic depth; returns a tensor shaped like `inputs`."""
        if not self.training or self._drop_rate == 0.:
            return inputs
        keep_prob = 1.0 - self._drop_rate
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        # BUGFIX: the original called `torch.rand_like` on a Python list and
        # used the TF-only `inputs.shape.rank`, which raised at train time.
        shape = (inputs.shape[0],) + (1,) * (inputs.dim() - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=inputs.dtype, device=inputs.device)
        # floor(keep_prob + U[0,1)) is 1 with probability keep_prob, else 0.
        binary_tensor = torch.floor(random_tensor)
        output = torch.div(inputs, keep_prob) * binary_tensor
        return output
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class MLP(nn.Module):
    """Two-layer feed-forward head used inside transformer blocks.

    Expands the channel dimension by `expansion_rate`, applies the given
    activation, projects back to `n_channel`, with dropout after each
    activation.
    """

    def __init__(self, n_channel,expansion_rate, act, dropout_rate):
        super().__init__()
        self._expansion_rate = expansion_rate
        self._act = act
        self._dropout_rate = dropout_rate
        hidden_dim = self._expansion_rate * n_channel
        self._fc1 = nn.Linear(n_channel, hidden_dim)
        self.act1 = self._act()
        self._fc2 = nn.Linear(hidden_dim, n_channel)
        self.act2 = self._act()
        self._drop = nn.Dropout(self._dropout_rate)

    def forward(self, features):
        """Apply fc1 -> act -> dropout -> fc2 -> act -> dropout."""
        hidden = self._drop(self.act1(self._fc1(features)))
        return self._drop(self.act2(self._fc2(hidden)))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class TransformerBlock(nn.Module):
    """Transformer block that is similar to the Swin encoder block.

    However, an important difference is that we _do not_ shift the windows
    for the second Attention layer. Instead, we _feed the encoder outputs_
    as Keys and Values. This allows for autoregressive applications.

    If `style == "encoder"`, no autoregression is happening.

    Also, this class operates on windowed tensor, see `call` docstring.
    """

    def __init__(
        self,
        *,
        d_model,
        seq_len,
        num_head = 4,
        mlp_expansion = 4,
        mlp_act = nn.GELU,
        drop_out_rate = 0.1,
        drop_path_rate = 0.1,
        style = "decoder",
    ):
        """Build the block.

        Args:
            d_model: Embedding/channel dimension of the tokens.
            seq_len: Sequence length; fixes the causal mask size for decoders.
            num_head: Number of attention heads.
            mlp_expansion: Hidden-dim expansion factor of the MLP heads.
            mlp_act: Activation class (not instance) for the MLPs.
            drop_out_rate: Dropout rate for attention and MLPs.
            drop_path_rate: Stochastic-depth rate for residual branches.
            style: "decoder" (causal self-attn + cross-attn into enc_output)
                or "encoder" (unmasked self-attention twice).
        """
        super().__init__()
        self._style = style
        if style == "decoder":
            # Register as a buffer so moving the module moves the mask too.
            self.register_buffer(
                "look_ahead_mask",
                create_look_ahead_mask(seq_len),
                persistent=False,
            )
        elif style == "encoder":
            # No causal masking for the encoder variant.
            self.look_ahead_mask = None
        else:
            raise ValueError(f"Invalid style: {style}")

        # Pre-norm layout: one LayerNorm before each attention and each MLP.
        self._norm1a = nn.LayerNorm(d_model)
        self._norm1b = nn.LayerNorm(d_model,eps=1e-5)

        self._norm2a = nn.LayerNorm(d_model, eps=1e-5)
        self._norm2b = nn.LayerNorm(d_model, eps=1e-5)
        # First attention: (masked) self-attention.
        self._attn1 = nn.MultiheadAttention(
            d_model,
            num_head,
            dropout=drop_out_rate
        )

        # Second attention: cross-attention into enc_output (decoder) or
        # plain self-attention (encoder).
        self._attn2 = nn.MultiheadAttention(
            d_model,
            num_head,
            dropout=drop_out_rate
        )

        self._mlp1 = MLP(
            d_model,
            expansion_rate=mlp_expansion,
            act=mlp_act,
            dropout_rate=drop_out_rate)
        self._mlp2 = MLP(
            d_model,
            expansion_rate=mlp_expansion,
            act=mlp_act,
            dropout_rate=drop_out_rate)

        # No weights, so we share for both blocks.
        self._drop_path = StochasticDepth(drop_path_rate)

    def forward(self, features, enc_output):
        """Run the block.

        Args:
            features: (B, L, D) token tensor (batch-first; permuted to
                L-first internally for nn.MultiheadAttention).
            enc_output: (B, L_enc, D) encoder output used as keys/values in
                the second attention, or None (encoder style only).

        Returns:
            (B, L, D) tensor.
        """
        if enc_output is None:
            if self._style == "decoder":
                raise ValueError("Need `enc_output` when running decoder.")
        else:
            assert enc_output.shape[0] == features.shape[0] and enc_output.shape[2] == features.shape[2]

        # First Block ---
        shortcut = features
        features = self._norm1a(features)
        # Masked self-attention.
        features = features.permute(1, 0, 2)  # NLD -> LND
        features, _ = self._attn1(
            value=features,
            key=features,
            query=features,
            attn_mask=self.look_ahead_mask)
        features = features.permute(1, 0, 2)  # LND -> NLD

        assert features.shape == shortcut.shape
        features = shortcut + self._drop_path(features)

        features = features + self._drop_path(
            self._mlp1(self._norm1b(features)))

        # Second Block ---
        shortcut = features
        features = self._norm2a(features)
        # Unmasked "lookup" into enc_output, no need for mask.

        features = features.permute(1, 0, 2)  # NLD -> LND
        if enc_output is not None:
            enc_output = enc_output.permute(1, 0, 2)  # NLD -> LND
        features, _ = self._attn2(  # pytype: disable=wrong-arg-types  # dynamic-method-lookup
            value=enc_output if enc_output is not None else features,
            key=enc_output if enc_output is not None else features,
            query=features,
            attn_mask=None)
        features = features.permute(1, 0, 2)  # LND -> NLD

        features = shortcut + self._drop_path(features)
        output = features + self._drop_path(
            self._mlp2(self._norm2b(features)))

        return output
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class Transformer(nn.Module):
    """A stack of transformer blocks, useable for encoding or decoding."""

    def __init__(
        self,
        is_decoder,
        num_layers = 4,
        d_model = 192,
        seq_len = 16,
        num_head = 4,
        mlp_expansion = 4,
        drop_out = 0.1
    ):
        super().__init__()
        self.is_decoder = is_decoder

        block_style = "decoder" if is_decoder else "encoder"
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                TransformerBlock(
                    d_model=d_model,
                    seq_len=seq_len,
                    num_head=num_head,
                    mlp_expansion=mlp_expansion,
                    drop_out_rate=drop_out,
                    drop_path_rate=drop_out,
                    style=block_style,
                )
            )
        # IMPORTANT: use ModuleList so parameters/buffers are registered and
        # moved correctly with `.to(device)`.
        self.layers = nn.ModuleList(blocks)

    def forward(
        self, latent, enc_output
    ):
        """Forward pass.

        For decoder, this predicts distribution of `latent` given `enc_output`.

        We assume that `latent` has already been embedded in a
        d_model-dimensional space.

        Args:
            latent: (B', seq_len, C) latent.
            enc_output: (B', seq_len_enc, C) result of concatenated encode
                output, or None for pure encoding.

        Returns:
            Output of shape (B', seq_len, C).
        """
        assert len(latent.shape) == 3, latent.shape
        if enc_output is not None:
            assert latent.shape[-1] == enc_output.shape[-1], (latent.shape,
                                                              enc_output.shape)
        tokens = latent
        for block in self.layers:
            tokens = block(features=tokens, enc_output=enc_output)
        return tokens
|
| 250 |
+
|
vae/utils.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for image compression.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from typing import Tuple, Dict
|
| 10 |
+
from .roi_tic import ModifiedTIC
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def compute_padding(in_h: int, in_w: int, min_div: int = 256) -> Tuple[Tuple[int, int, int, int], Tuple[int, int, int, int]]:
    """
    Compute symmetric padding making dimensions divisible by min_div.

    Args:
        in_h: input height
        in_w: input width
        min_div: minimum divisor (default 256 for TIC)

    Returns:
        pad: (left, right, top, bottom) padding amounts
        unpad: the same amounts negated, for cropping back with F.pad
    """
    # Round each dimension up to the next multiple of min_div.
    padded_h = ((in_h + min_div - 1) // min_div) * min_div
    padded_w = ((in_w + min_div - 1) // min_div) * min_div

    extra_w = padded_w - in_w
    extra_h = padded_h - in_h

    # Split the extra pixels as evenly as possible; the right/bottom side
    # takes the odd leftover pixel.
    left = extra_w // 2
    top = extra_h // 2
    pad = (left, extra_w - left, top, extra_h - top)
    unpad = tuple(-amount for amount in pad)

    return pad, unpad
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def compress_image(
    image: Image.Image,
    mask: np.ndarray,
    model: ModifiedTIC,
    sigma: float = 0.3,
    device: str = 'cuda'
) -> Dict:
    """
    Compress image with ROI-based quality control.

    Args:
        image: PIL Image (RGB)
        mask: Binary mask (H, W) with 1 for ROI, 0 for background
        model: Loaded ModifiedTIC model
        sigma: Background quality (0.01-1.0, lower = more compression)
        device: 'cuda' or 'cpu'

    Returns:
        dict with:
            - compressed: PIL Image of compressed result
            - bpp: Bits per pixel
            - original_size: Original image dimensions
            - mask_used: The mask that was used
    """
    # Image -> float32 tensor in [0, 1], NCHW layout.
    pixels = np.array(image).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)

    # Pad both image and mask to a size TIC accepts (multiple of 256).
    _, _, h, w = img_tensor.shape
    pad, unpad = compute_padding(h, w, min_div=256)
    img_padded = F.pad(img_tensor, pad, mode='constant', value=0)

    mask_tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(device)
    mask_padded = F.pad(mask_tensor, pad, mode='constant', value=0)

    # Run the model without gradients.
    with torch.no_grad():
        # NOTE: `ModifiedTIC.forward()` handles mask downsampling internally.
        out = model(img_padded, mask_padded, sigma=sigma)

    # Crop the reconstruction back to the original size (negative padding).
    x_hat = F.pad(out['x_hat'], unpad)

    # Tensor -> uint8 image.
    recon = x_hat.squeeze(0).permute(1, 2, 0).cpu().numpy()
    recon = np.clip(recon * 255, 0, 255).astype(np.uint8)
    compressed_img = Image.fromarray(recon)

    # Rate estimate from the likelihoods, normalized by the ORIGINAL pixel
    # count (standard CompressAI convention).
    num_pixels = h * w
    likelihoods = out['likelihoods']
    bpp_y = torch.log(likelihoods['y']).sum() / (-np.log(2) * num_pixels)
    bpp_z = torch.log(likelihoods['z']).sum() / (-np.log(2) * num_pixels)
    bpp = (bpp_y + bpp_z).item()

    return {
        'compressed': compressed_img,
        'bpp': bpp,
        'original_size': (w, h),
        'mask_used': mask
    }
|
vae/visualization.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Visualization utilities for compression results.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def highlight_roi(
    image: Image.Image,
    mask: np.ndarray,
    alpha: float = 0.3,
    color: Tuple[int, int, int] = (0, 255, 0)
) -> Image.Image:
    """
    Highlight ROI regions in image with colored overlay.

    Args:
        image: PIL Image
        mask: Binary mask (H, W)
        alpha: Overlay transparency (0-1)
        color: RGB color tuple for ROI highlight

    Returns:
        Image with ROI highlighted
    """
    # NOTE(review): assumes a 3-channel image matching the color tuple —
    # grayscale/RGBA inputs would fail the assignment below; confirm callers.
    base = np.array(image)

    # Paint ROI pixels with the solid highlight color on a copy.
    painted = base.copy()
    painted[mask > 0.5] = color

    # Blend the painted copy over the original: (1-alpha)*base + alpha*painted.
    blended = cv2.addWeighted(base, 1 - alpha, painted, alpha, 0)

    return Image.fromarray(blended)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def create_comparison_grid(
    original: Image.Image,
    compressed: Image.Image,
    mask: np.ndarray,
    bpp: float,
    sigma: float,
    lambda_val: float,
    highlight: bool = True
) -> Image.Image:
    """
    Create side-by-side comparison of original and compressed images.

    Args:
        original: Original PIL Image
        compressed: Compressed PIL Image
        mask: Binary mask used
        bpp: Bits per pixel
        sigma: Sigma value used
        lambda_val: Lambda value used
        highlight: Whether to show ROI overlay

    Returns:
        Combined comparison image
    """
    n_panels = 3 if highlight else 2
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))

    # Build (image, title) pairs for each panel.
    panels = [
        (original, 'Original'),
        (compressed, f'Compressed (σ={sigma:.2f}, λ={lambda_val}, BPP={bpp:.3f})'),
    ]
    if highlight:
        overlay = highlight_roi(original, mask, alpha=0.4, color=(0, 255, 0))
        panels.append((overlay, 'ROI Mask (green)'))

    for ax, (panel_img, title) in zip(axes, panels):
        ax.imshow(panel_img)
        ax.set_title(title, fontsize=14)
        ax.axis('off')

    plt.tight_layout()

    # Render the figure and grab its RGBA buffer as a numpy array.
    fig.canvas.draw()
    buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    buf = buf.reshape(fig.canvas.get_width_height()[::-1] + (4,))
    rgb = buf[:, :, :3]  # Remove alpha channel
    plt.close(fig)

    return Image.fromarray(rgb)
|
video/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Video processing module for ROI-based video compression.
|
| 2 |
+
|
| 3 |
+
Provides:
|
| 4 |
+
- VideoProcessor: Frame extraction, motion analysis, adaptive compression
|
| 5 |
+
- MotionAnalyzer: Optical flow and scene complexity estimation
|
| 6 |
+
- ChunkCompressor: Chunk-by-chunk compression with bandwidth targeting
|
| 7 |
+
- Temporal smoothing utilities for mask stabilization
|
| 8 |
+
- Mask caching for video segmentation reuse
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from .video_processor import (
|
| 12 |
+
VideoProcessor,
|
| 13 |
+
VideoFrame,
|
| 14 |
+
CompressedChunk,
|
| 15 |
+
CompressionSettings,
|
| 16 |
+
ChunkPlan,
|
| 17 |
+
)
|
| 18 |
+
from .motion_analyzer import MotionAnalyzer
|
| 19 |
+
from .chunk_compressor import (
|
| 20 |
+
ChunkCompressor,
|
| 21 |
+
BandwidthController,
|
| 22 |
+
smooth_masks_temporal,
|
| 23 |
+
smooth_masks_temporal_fast,
|
| 24 |
+
smooth_masks_sdf,
|
| 25 |
+
)
|
| 26 |
+
from .mask_cache import (
|
| 27 |
+
save_video_masks,
|
| 28 |
+
load_video_masks,
|
| 29 |
+
get_mask_info,
|
| 30 |
+
)
|
| 31 |
+
from .gpu_memory import (
|
| 32 |
+
estimate_batch_sizes,
|
| 33 |
+
BatchSizeEstimate,
|
| 34 |
+
)
|
| 35 |
+
from .sdf_smoother import SDFSmoother
|
| 36 |
+
|
| 37 |
+
__all__ = [
|
| 38 |
+
"VideoProcessor",
|
| 39 |
+
"VideoFrame",
|
| 40 |
+
"CompressedChunk",
|
| 41 |
+
"CompressionSettings",
|
| 42 |
+
"ChunkPlan",
|
| 43 |
+
"MotionAnalyzer",
|
| 44 |
+
"ChunkCompressor",
|
| 45 |
+
"BandwidthController",
|
| 46 |
+
"smooth_masks_temporal",
|
| 47 |
+
"smooth_masks_temporal_fast",
|
| 48 |
+
"smooth_masks_sdf",
|
| 49 |
+
"save_video_masks",
|
| 50 |
+
"load_video_masks",
|
| 51 |
+
"get_mask_info",
|
| 52 |
+
"estimate_batch_sizes",
|
| 53 |
+
"BatchSizeEstimate",
|
| 54 |
+
"SDFSmoother",
|
| 55 |
+
]
|
video/chunk_compressor.py
ADDED
|
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Chunk-based video compression with bandwidth targeting.
|
| 2 |
+
|
| 3 |
+
Implements dynamic compression that balances framerate and spatial quality
|
| 4 |
+
to meet bandwidth constraints while prioritizing motion-heavy scenes.
|
| 5 |
+
|
| 6 |
+
Includes temporal smoothing for segmentation masks to reduce flickering.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from typing import List, Optional, Tuple, Generator, Dict, Any
|
| 13 |
+
import math
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from scipy import ndimage
|
| 18 |
+
|
| 19 |
+
from .motion_analyzer import MotionAnalyzer, MotionMetrics
|
| 20 |
+
from .sdf_smoother import SDFSmoother
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def smooth_masks_temporal(
    masks: List[np.ndarray],
    window_size: int = 5,
    threshold_appear: float = 0.4,
    threshold_disappear: float = 0.2,
    spatial_smooth: bool = True,
    spatial_kernel_size: int = 5,
) -> List[np.ndarray]:
    """Apply temporal smoothing to a sequence of segmentation masks.

    Reduces flickering by:
    1. Applying a center-weighted temporal mean across frames
    2. Using hysteresis thresholding (different thresholds for appearing vs disappearing)
    3. Optional spatial smoothing to clean up edges

    Args:
        masks: List of binary/float masks (H, W), values in [0, 1]
        window_size: Number of frames in the temporal filter window (any size >= 1)
        threshold_appear: Confidence threshold for a pixel to become ROI (higher = stricter)
        threshold_disappear: Confidence threshold for a pixel to stop being ROI (lower = stickier)
        spatial_smooth: Whether to apply spatial Gaussian smoothing
        spatial_kernel_size: Size of spatial smoothing kernel

    Returns:
        List of smoothed binary (0/1 float32) masks
    """
    if not masks or len(masks) < 2:
        return masks

    # Stack to (T, H, W) for efficient vectorized processing.
    mask_stack = np.stack([m.astype(np.float32) for m in masks], axis=0)
    num_frames = mask_stack.shape[0]

    # Pad temporally so edge frames have a full window.
    half_window = window_size // 2
    padded = np.pad(mask_stack, ((half_window, half_window), (0, 0), (0, 0)), mode='edge')

    # Triangular weight kernel peaking at the window center, e.g. [1, 2, 3, 2, 1]
    # for window_size=5. The original implementation sliced a fixed 5-tap list,
    # which crashed for window_size > 5 and mis-weighted odd sizes < 5; this
    # generalization is identical for the default window_size=5.
    center = window_size // 2
    weights = (center + 1 - np.abs(np.arange(window_size) - center)).astype(np.float64)
    weights = weights / weights.sum()

    # Temporal weighted mean over a sliding window.
    smoothed = np.zeros_like(mask_stack)
    for t in range(num_frames):
        window = padded[t:t + window_size]  # (window_size, H, W)
        smoothed[t] = np.average(window, axis=0, weights=weights)

    # Hysteresis thresholding:
    # a pixel becomes ROI when confidence > threshold_appear and
    # stays ROI until confidence < threshold_disappear.
    result = np.zeros_like(smoothed)
    result[0] = (smoothed[0] > threshold_appear).astype(np.float32)

    for t in range(1, num_frames):
        was_roi = result[t - 1] > 0.5
        new_roi = smoothed[t] > threshold_appear
        continuing_roi = was_roi & (smoothed[t] > threshold_disappear)
        result[t] = (new_roi | continuing_roi).astype(np.float32)

    # Optional spatial cleanup: Gaussian blur then re-threshold.
    if spatial_smooth:
        for t in range(num_frames):
            blurred = ndimage.gaussian_filter(result[t], sigma=spatial_kernel_size / 4)
            result[t] = (blurred > 0.5).astype(np.float32)

    return [result[t] for t in range(num_frames)]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def smooth_masks_temporal_fast(
    masks: List[np.ndarray],
    alpha: float = 0.3,
    threshold: float = 0.5,
) -> List[np.ndarray]:
    """Fast temporal smoothing via an exponential moving average.

    A lightweight alternative to the full windowed smoothing, suitable
    for real-time use.

    Args:
        masks: List of binary/float masks (H, W)
        alpha: Smoothing factor in (0, 1]; lower values smooth more but lag more
        threshold: Cutoff applied to the running average to binarize each output

    Returns:
        List of binarized (0/1 float32) smoothed masks
    """
    if not masks or len(masks) < 2:
        return masks

    running = masks[0].astype(np.float32).copy()
    smoothed = [(running > threshold).astype(np.float32)]

    for frame_mask in masks[1:]:
        # EMA update: blend the new observation into the running estimate.
        running = alpha * frame_mask.astype(np.float32) + (1 - alpha) * running
        smoothed.append((running > threshold).astype(np.float32))

    return smoothed
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def smooth_masks_sdf(
    masks: List[np.ndarray],
    alpha: float = 0.5,
    empty_thresh: int = 10,
    patience: int = 5,
) -> List[np.ndarray]:
    """Smooth masks using Signed Distance Field temporal filtering.

    Uses an SDF representation for fluid, jitter-free transitions while
    preserving sharp boundaries. More sophisticated than simple EMA.

    Args:
        masks: List of binary/float masks (H, W)
        alpha: Smoothing factor (0.1 = slow/viscous, 0.9 = fast/reactive)
        empty_thresh: Min pixel count to consider a mask "valid"
        patience: Frames to tolerate empty masks before decay
            (0 = immediate, 5 = conservative, 15 = aggressive)

    Returns:
        List of smoothed masks
    """
    if not masks or len(masks) < 2:
        return masks

    # Feed every frame through a single stateful SDF filter in order.
    sdf_filter = SDFSmoother(alpha=alpha, empty_thresh=empty_thresh, patience=patience)
    return [sdf_filter.update(m.astype(np.float32)) for m in masks]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@dataclass
class CompressionResult:
    """Result of compressing a single frame."""

    # Compressed frame image as returned by the compression model.
    compressed_image: Image.Image
    # Bits per pixel achieved for this frame.
    bpp: float
    # (width, height) of the input frame before compression.
    original_size: Tuple[int, int]
    # Fraction of pixels covered by the ROI mask (mean of the mask values).
    roi_coverage: float
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@dataclass
class ChunkResult:
    """Result of compressing a video chunk."""

    # Compressed frames for this chunk
    frames: List[Image.Image]

    # Frame indices from original video that were kept
    frame_indices: List[int]

    # Effective framerate for this chunk (after frame sampling)
    effective_fps: float

    # Quality level used (1-5)
    quality_level: int

    # Sigma (background preservation) used
    sigma: float

    # Average bits per pixel across the compressed frames
    avg_bpp: float

    # Estimated chunk size in bytes (derived from avg_bpp, not measured)
    estimated_bytes: int

    # Motion metrics for this chunk, if motion analysis was performed
    motion_metrics: Optional[MotionMetrics] = None

    # Chunk index within the video
    chunk_index: int = 0

    # Total number of original frames in chunk (before sampling)
    original_frame_count: int = 0
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class BandwidthController:
    """Controls compression parameters to meet bandwidth targets.

    Dynamically adjusts:
    - Frame sampling rate (effective FPS)
    - Spatial compression quality
    - Background preservation (sigma)

    Based on:
    - Target bandwidth constraint
    - Motion complexity of current chunk
    - Smooth transitions between settings

    NOTE(review): this controller is stateful across chunks
    (_prev_fps/_prev_quality/_prev_sigma); call reset() between videos.
    """

    # Quality level presets: (level, min_expected_bpp, max_expected_bpp)
    QUALITY_PRESETS = [
        (1, 0.05, 0.15),  # Lowest quality: ~0.05-0.15 bpp
        (2, 0.10, 0.25),  # Low quality: ~0.10-0.25 bpp
        (3, 0.15, 0.40),  # Medium quality: ~0.15-0.40 bpp
        (4, 0.25, 0.60),  # High quality: ~0.25-0.60 bpp
        (5, 0.40, 1.00),  # Best quality: ~0.40-1.00 bpp
    ]

    def __init__(
        self,
        target_bandwidth_kbps: float = 500.0,
        base_fps: float = 30.0,
        min_fps: float = 5.0,
        max_fps: float = 60.0,
        chunk_duration_sec: float = 1.0,
        smoothing_factor: float = 0.3,
        aggressiveness: float = 0.5,
    ):
        """
        Args:
            target_bandwidth_kbps: Target bandwidth in kilobits per second
            base_fps: Original video framerate
            min_fps: Minimum allowed effective framerate
            max_fps: Maximum allowed effective framerate
            chunk_duration_sec: Duration of each chunk in seconds
            smoothing_factor: How much to smooth parameter transitions (0-1)
            aggressiveness: Bandwidth savings strategy (0.0=use full bandwidth, 1.0=maximum savings)
        """
        self.target_bandwidth_kbps = target_bandwidth_kbps
        self.base_fps = base_fps
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.chunk_duration_sec = chunk_duration_sec
        self.smoothing_factor = smoothing_factor
        # Clamp to [0, 1] so downstream threshold comparisons stay well-defined.
        self.aggressiveness = max(0.0, min(1.0, aggressiveness))

        # State for smooth transitions between consecutive chunks.
        self._prev_fps: Optional[float] = None
        self._prev_quality: Optional[int] = None
        self._prev_sigma: Optional[float] = None

    def reset(self):
        """Reset state for new video."""
        self._prev_fps = None
        self._prev_quality = None
        self._prev_sigma = None

    def compute_settings(
        self,
        frame_width: int,
        frame_height: int,
        motion_metrics: MotionMetrics,
        roi_coverage: float = 0.3,
    ) -> Tuple[float, int, float]:
        """Compute compression settings for a chunk.

        Args:
            frame_width: Frame width in pixels
            frame_height: Frame height in pixels
            motion_metrics: Motion analysis results
            roi_coverage: Fraction of frame covered by ROI

        Returns:
            Tuple of (effective_fps, quality_level, sigma)
        """
        num_pixels = frame_width * frame_height

        # Target bits per chunk
        target_bits_per_chunk = self.target_bandwidth_kbps * 1000 * self.chunk_duration_sec

        # Framerate adjustment based on motion
        # High motion -> more frames, lower quality per frame
        # Low motion -> fewer frames, higher quality per frame
        fr_factor = motion_metrics.framerate_factor

        # For dynamic mode with high aggressiveness, use a higher baseline
        # to ensure motion-based variation actually produces meaningful FPS range
        if self.aggressiveness > 0.5:
            # Use average of min/max as effective base for motion scaling
            effective_base_fps = (self.min_fps + self.max_fps) / 2
        else:
            # Use video's native FPS
            effective_base_fps = self.base_fps

        # Base framerate adjusted by motion
        motion_fps = effective_base_fps * fr_factor
        motion_fps = np.clip(motion_fps, self.min_fps, self.max_fps)

        # Estimate frames in chunk at this FPS
        frames_in_chunk = int(motion_fps * self.chunk_duration_sec)
        frames_in_chunk = max(1, frames_in_chunk)

        # Target bits per frame
        target_bits_per_frame = target_bits_per_chunk / frames_in_chunk
        target_bpp = target_bits_per_frame / num_pixels

        # Find quality level that matches target BPP
        quality_level = self._find_quality_for_bpp(target_bpp, roi_coverage)

        # Compute sigma based on motion and quality
        # Higher motion -> lower sigma (more background compression to save bits for motion)
        # Lower quality -> lower sigma (aggressive background compression)
        base_sigma = 0.3
        motion_adjustment = (1.0 - motion_metrics.motion_magnitude) * 0.4
        quality_adjustment = (quality_level - 3) * 0.1
        sigma = np.clip(base_sigma + motion_adjustment + quality_adjustment, 0.05, 0.8)

        # Final FPS adjustment based on actual quality/bpp achieved
        expected_bpp = self._estimate_bpp(quality_level, sigma, roi_coverage)
        expected_bits_per_frame = expected_bpp * num_pixels

        current_bandwidth = (motion_fps * expected_bits_per_frame) / 1000  # kbps
        bandwidth_ratio = current_bandwidth / self.target_bandwidth_kbps if self.target_bandwidth_kbps > 0 else 1.0

        # At high aggressiveness, prioritize motion over bandwidth constraints
        # At low aggressiveness, prioritize bandwidth target
        if self.aggressiveness > 0.7:
            # Very aggressive: trust motion analysis, ignore bandwidth (may exceed target)
            # Only enforce absolute max_fps constraint, not bandwidth
            final_fps = motion_fps
        elif self.aggressiveness > 0.5:
            # Moderately aggressive: allow bandwidth excursions for high motion
            # Only reduce FPS if significantly over budget (>1.5x)
            if bandwidth_ratio > 1.5:
                reduction_factor = 1.5 / bandwidth_ratio
                final_fps = motion_fps * reduction_factor
            else:
                final_fps = motion_fps
        else:
            # Conservative mode: enforce bandwidth target strictly
            override_threshold = 1.2 + self.aggressiveness * 0.5  # 1.2 at agg=0, 1.35 at agg=0.5

            if bandwidth_ratio > override_threshold:
                # Bandwidth too high, reduce FPS
                fps_reduction = (bandwidth_ratio - 1.0) * 0.5
                final_fps = max(self.min_fps, motion_fps / (1 + fps_reduction))
            else:
                # Keep motion-based FPS
                final_fps = motion_fps

        final_fps = np.clip(final_fps, self.min_fps, self.max_fps)

        # Smooth transitions (disable smoothing at high aggressiveness for dramatic variation)
        # At agg>0.7, no FPS smoothing - follow motion exactly
        # At agg≤0.7, apply smoothing
        if self.aggressiveness > 0.7:
            # No FPS smoothing - allow dramatic jumps
            smoothing_fps = 0.0
            smoothing_other = 0.1  # Minimal smoothing for quality/sigma
        else:
            smoothing_fps = 0.3 - self.aggressiveness * 0.3  # 0.3 at agg=0, 0.09 at agg=0.7
            smoothing_other = smoothing_fps

        quality_change_limit = 1 + int(self.aggressiveness * 2)  # 1 at agg=0, 3 at agg=1.0

        if self._prev_fps is not None and smoothing_fps > 0:
            final_fps = self._smooth(final_fps, self._prev_fps, smoothing_fps)
        if self._prev_quality is not None:
            # Quality can change by quality_change_limit steps at a time
            quality_level = int(np.clip(
                quality_level,
                self._prev_quality - quality_change_limit,
                self._prev_quality + quality_change_limit
            ))
        if self._prev_sigma is not None and smoothing_other > 0:
            sigma = self._smooth(sigma, self._prev_sigma, smoothing_other)

        # Update state
        self._prev_fps = final_fps
        self._prev_quality = quality_level
        self._prev_sigma = sigma

        return float(final_fps), int(quality_level), float(sigma)

    def _find_quality_for_bpp(self, target_bpp: float, roi_coverage: float) -> int:
        """Find quality level that approximately matches target BPP."""
        # ROI coverage affects effective BPP (more ROI = higher BPP for same quality)
        adjusted_target = target_bpp / (0.5 + roi_coverage * 0.5)

        # Return the first (lowest) level whose midpoint covers the target.
        for level, min_bpp, max_bpp in self.QUALITY_PRESETS:
            mid_bpp = (min_bpp + max_bpp) / 2
            if adjusted_target <= mid_bpp:
                return level

        return 5  # Max quality if target is very high

    def _estimate_bpp(self, quality_level: int, sigma: float, roi_coverage: float) -> float:
        """Estimate BPP for given settings."""
        # quality_level is 1-based; presets list is 0-based.
        _, min_bpp, max_bpp = self.QUALITY_PRESETS[quality_level - 1]
        base_bpp = (min_bpp + max_bpp) / 2

        # Sigma affects background compression
        # Lower sigma = more compression = lower BPP
        sigma_factor = 0.5 + sigma * 0.5

        # ROI coverage affects total BPP
        # More ROI = more high-quality pixels = higher BPP
        roi_factor = 0.7 + roi_coverage * 0.3

        return base_bpp * sigma_factor * roi_factor

    def _smooth(self, new_value: float, prev_value: float, factor: Optional[float] = None) -> float:
        """Apply exponential smoothing.

        Args:
            new_value: Target value
            prev_value: Previous value
            factor: Smoothing factor (uses self.smoothing_factor if None)
        """
        if factor is None:
            factor = self.smoothing_factor
        return prev_value + factor * (new_value - prev_value)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class ChunkCompressor:
|
| 446 |
+
"""Compresses video chunks with adaptive settings.
|
| 447 |
+
|
| 448 |
+
Each chunk can have different:
|
| 449 |
+
- Frame sampling rate
|
| 450 |
+
- Compression quality
|
| 451 |
+
- Background preservation
|
| 452 |
+
|
| 453 |
+
Based on motion analysis and bandwidth constraints.
|
| 454 |
+
"""
|
| 455 |
+
|
| 456 |
+
def __init__(
    self,
    compression_model,
    segmenter=None,
    target_classes: Optional[List[str]] = None,
    device: str = "cuda",
):
    """
    Args:
        compression_model: Loaded TIC compression model
        segmenter: Optional segmenter for ROI extraction
        target_classes: Classes to segment as ROI
        device: Compute device
    """
    self.compression_model = compression_model
    self.segmenter = segmenter
    # Default to an empty class list when none is given.
    self.target_classes = target_classes or []
    self.device = device

    # Per-video motion state; reset() clears it between videos.
    self.motion_analyzer = MotionAnalyzer()
|
| 476 |
+
|
| 477 |
+
def reset(self):
    """Reset state for a new video.

    Clears the motion analyzer's per-video state so analysis of the
    next video does not carry over history from the previous one.
    """
    self.motion_analyzer.reset()
|
| 480 |
+
|
| 481 |
+
def compress_frame(
    self,
    frame: Image.Image,
    mask: Optional[np.ndarray],
    sigma: float,
) -> CompressionResult:
    """Compress a single frame with ROI-aware quality allocation.

    Args:
        frame: PIL Image frame
        mask: ROI mask (H, W), or None for no ROI (treated as an
            all-zero mask, i.e. uniform background-level compression)
        sigma: Background preservation factor

    Returns:
        CompressionResult with compressed image and stats
    """
    # Imported lazily so the module can be imported without the codec.
    import vae

    if mask is None:
        mask = np.zeros((frame.height, frame.width), dtype=np.float32)

    result = vae.compress_image(
        image=frame,
        mask=mask,
        model=self.compression_model,
        sigma=sigma,
        device=self.device,
    )

    # mask is guaranteed non-None at this point (the None case was
    # replaced with a zero mask above), so the original redundant
    # `if mask is not None` guard has been dropped.
    return CompressionResult(
        compressed_image=result["compressed"],
        bpp=result["bpp"],
        original_size=(frame.width, frame.height),
        roi_coverage=float(mask.mean()),
    )
|
| 518 |
+
|
| 519 |
+
def compress_chunk(
    self,
    frames: List[Image.Image],
    chunk_index: int,
    effective_fps: float,
    base_fps: float,
    quality_level: int,
    sigma: float,
    roi_masks: Optional[List[np.ndarray]] = None,
) -> ChunkResult:
    """Compress a chunk of frames with given settings.

    Args:
        frames: List of PIL Image frames
        chunk_index: Index of this chunk
        effective_fps: Target effective framerate
        base_fps: Original video framerate
        quality_level: Compression quality 1-5
        sigma: Background preservation factor
        roi_masks: Optional list of ROI masks (indexed like ``frames``)

    Returns:
        ChunkResult with compressed frames and stats
    """
    if not frames:
        # Empty chunk: return an empty result so callers need no special case.
        return ChunkResult(
            frames=[],
            frame_indices=[],
            effective_fps=0.0,
            quality_level=quality_level,
            sigma=sigma,
            avg_bpp=0.0,
            estimated_bytes=0,
            chunk_index=chunk_index,
            original_frame_count=0,
        )

    # Frame sampling step to approximate the target FPS. Guard against
    # non-positive FPS values (previously a ZeroDivisionError when
    # effective_fps == 0) by falling back to keeping every frame.
    if effective_fps > 0 and base_fps > 0:
        frame_step = max(1, int(base_fps / effective_fps))
    else:
        frame_step = 1
    # frames is non-empty and frame_step >= 1, so this always contains index 0
    # (the original dead "if not sampled_indices" fallback has been removed).
    sampled_indices = list(range(0, len(frames), frame_step))

    # Compress the sampled frames.
    compressed_frames: List[Image.Image] = []
    bpps: List[float] = []

    for idx in sampled_indices:
        frame = frames[idx]
        mask = roi_masks[idx] if roi_masks and idx < len(roi_masks) else None

        result = self.compress_frame(frame, mask, sigma)
        compressed_frames.append(result.compressed_image)
        bpps.append(result.bpp)

    # Aggregate stats: estimated bytes derived from average bpp, not measured.
    avg_bpp = float(np.mean(bpps)) if bpps else 0.0
    frame_pixels = frames[0].width * frames[0].height
    total_bits = avg_bpp * frame_pixels * len(compressed_frames)
    estimated_bytes = int(total_bits / 8)

    # Achieved FPS = kept frames / chunk duration; guard base_fps == 0.
    actual_fps = len(sampled_indices) * base_fps / len(frames) if base_fps > 0 else 0.0

    return ChunkResult(
        frames=compressed_frames,
        frame_indices=sampled_indices,
        effective_fps=float(actual_fps),
        quality_level=quality_level,
        sigma=sigma,
        avg_bpp=avg_bpp,
        estimated_bytes=estimated_bytes,
        chunk_index=chunk_index,
        original_frame_count=len(frames),
    )
|
| 595 |
+
|
| 596 |
+
def segment_frames(
|
| 597 |
+
self,
|
| 598 |
+
frames: List[Image.Image],
|
| 599 |
+
temporal_smoothing: bool = True,
|
| 600 |
+
smoothing_alpha: float = 0.3,
|
| 601 |
+
) -> List[np.ndarray]:
|
| 602 |
+
"""Segment multiple frames to get ROI masks (sequential).
|
| 603 |
+
|
| 604 |
+
Args:
|
| 605 |
+
frames: List of PIL Images
|
| 606 |
+
temporal_smoothing: Whether to apply temporal smoothing
|
| 607 |
+
smoothing_alpha: Alpha for fast EMA smoothing (0-1, lower=smoother)
|
| 608 |
+
|
| 609 |
+
Returns:
|
| 610 |
+
List of ROI masks (temporally smoothed if enabled)
|
| 611 |
+
"""
|
| 612 |
+
if self.segmenter is None:
|
| 613 |
+
return [np.zeros((f.height, f.width), dtype=np.float32) for f in frames]
|
| 614 |
+
|
| 615 |
+
masks = []
|
| 616 |
+
for frame in frames:
|
| 617 |
+
mask = self.segmenter(frame, target_classes=self.target_classes)
|
| 618 |
+
masks.append(mask.astype(np.float32))
|
| 619 |
+
|
| 620 |
+
# Apply temporal smoothing to reduce flickering
|
| 621 |
+
if temporal_smoothing and len(masks) > 2:
|
| 622 |
+
masks = smooth_masks_sdf(
|
| 623 |
+
masks,
|
| 624 |
+
alpha=smoothing_alpha,
|
| 625 |
+
empty_thresh=10,
|
| 626 |
+
patience=5,
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
return masks
|
| 630 |
+
|
| 631 |
+
def segment_frames_batch(
    self,
    frames: List[Image.Image],
    batch_size: int = 4,
    temporal_smoothing: bool = True,
    smoothing_window: int = 5,
    smoothing_alpha: float = 0.3,
) -> List[np.ndarray]:
    """Segment multiple frames with batch processing and temporal smoothing.

    Processes frames in batches for better GPU utilization.
    Falls back to sequential processing if batch not supported.
    On CUDA out-of-memory, retries with a halved batch size (up to 7 times).

    Optionally applies temporal smoothing to reduce mask flickering.

    Args:
        frames: List of PIL Images
        batch_size: Number of frames to process at once
        temporal_smoothing: Whether to apply temporal smoothing
        smoothing_window: Window size for temporal smoothing (odd number).
            NOTE(review): currently unused — SDF smoothing below takes no
            window parameter; kept for interface compatibility.
        smoothing_alpha: Alpha for the temporal smoother (0-1, lower=smoother)

    Returns:
        List of ROI masks (temporally smoothed if enabled)
    """
    if self.segmenter is None:
        # No segmenter configured: all-zero masks (no ROI).
        return [np.zeros((f.height, f.width), dtype=np.float32) for f in frames]

    # Check if segmenter supports batch processing
    import torch
    max_retries = 7
    bs = batch_size
    masks = None
    # Retry loop: on CUDA OOM, halve the batch size and try again.
    for attempt in range(max_retries + 1):
        try:
            masks = []
            if hasattr(self.segmenter, 'segment_batch') and getattr(self.segmenter, 'supports_batch', False):
                # Batched path: process frames in groups of `bs`.
                for i in range(0, len(frames), bs):
                    batch = frames[i:i + bs]
                    batch_masks = self.segmenter.segment_batch(batch, target_classes=self.target_classes)
                    masks.extend([m.astype(np.float32) for m in batch_masks])
            else:
                # Sequential fallback: one frame at a time.
                for frame in frames:
                    mask = self.segmenter(frame, target_classes=self.target_classes)
                    masks.append(mask.astype(np.float32))
            break  # success
        except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
            # Only retry for out-of-memory errors; re-raise anything else
            # (and OOM once retries are exhausted).
            if 'out of memory' in str(e).lower() and attempt < max_retries:
                bs = max(1, bs // 2)
                # Aggressive memory cleanup: drop partial results and free
                # cached CUDA memory before retrying.
                masks = None
                import gc
                gc.collect()
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                print(f"segment_frames_batch: OOM, retrying batch_size={bs} (attempt {attempt+1}/{max_retries})")
                continue
            raise
    if masks is None:
        raise RuntimeError("Segmentation failed after OOM retries")

    # Apply temporal smoothing to reduce flickering
    if temporal_smoothing and len(masks) > 2:
        # Use SDF smoothing for jitter-free, fluid transitions
        masks = smooth_masks_sdf(
            masks,
            alpha=smoothing_alpha,
            empty_thresh=10,
            patience=5,
        )

    return masks
|
| 703 |
+
|
| 704 |
+
def compress_frames_batch(
    self,
    frames: List[Image.Image],
    masks: List[np.ndarray],
    sigma: "float | List[float]",
    batch_size: int = 4,
) -> List[CompressionResult]:
    """Compress multiple frames with true batch processing.

    Processes frames in batches for better GPU utilization. On a CUDA
    out-of-memory error the batch size is halved and the same window of
    frames is retried (``pos`` only advances after a successful batch).

    Args:
        frames: List of PIL Images.
        masks: List of ROI masks, one per frame.
        sigma: Background preservation factor – a single float (uniform)
            or a per-frame list of floats.
        batch_size: Initial number of frames to process at once.

    Returns:
        List of CompressionResult, one per input frame, in order.

    Raises:
        ValueError: If a per-frame sigma list does not match the frame count.
        RuntimeError: Re-raised for non-OOM errors, or when OOM persists
            even at batch size 1.
    """
    import torch

    if not frames:
        return []

    # Normalise sigma to a per-frame list
    if isinstance(sigma, (int, float)):
        sigmas = [float(sigma)] * len(frames)
    else:
        sigmas = [float(s) for s in sigma]
        if len(sigmas) != len(frames):
            raise ValueError(f"sigma list length {len(sigmas)} != frame count {len(frames)}")

    results = []
    bs = batch_size

    # Process in batches with OOM retry: the window [pos, batch_end) is
    # re-attempted with a smaller bs until it fits or bs reaches 1.
    pos = 0
    while pos < len(frames):
        batch_end = min(pos + bs, len(frames))
        batch_frames = frames[pos:batch_end]
        batch_masks = masks[pos:batch_end]

        try:
            batch_results = self._compress_batch_inner(
                batch_frames, batch_masks, sigmas, pos, pad_cache=None,
            )
            results.extend(batch_results)
            pos = batch_end
        except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
            # Match by message substring: some backends raise plain
            # RuntimeError for OOM conditions.
            if 'out of memory' in str(e).lower() and bs > 1:
                bs = max(1, bs // 2)
                torch.cuda.empty_cache()
                print(f"compress_frames_batch: OOM, retrying with batch_size={bs}")
                continue
            raise

    # Clear GPU memory after all batches
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return results
|
| 772 |
+
|
| 773 |
+
def _compress_batch_inner(
    self,
    batch_frames: List[Image.Image],
    batch_masks: List[np.ndarray],
    sigmas: List[float],
    global_offset: int,
    pad_cache: Optional[Any] = None,
) -> List[CompressionResult]:
    """Compress a single batch of frames (inner helper).

    NOTE(review): padding is computed from the FIRST frame's shape only, so
    all frames in a batch are assumed to share the same dimensions — confirm
    upstream frames are uniformly sized.

    Args:
        batch_frames: Frames in this batch.
        batch_masks: Masks in this batch (entries may be None → zero mask).
        sigmas: Per-frame sigma values (full list, indexed by global_offset + i).
        global_offset: Start index into the full sigmas list.
        pad_cache: Unused, reserved for future padding reuse.

    Returns:
        List of CompressionResult for this batch.
    """
    import torch
    import torch.nn.functional as F
    from vae.utils import compute_padding

    batch_results = []
    batch_tensors = []
    batch_mask_tensors = []
    original_sizes = []

    # Convert each frame to a float32 CHW tensor in [0, 1] and each mask to
    # a 1×H×W tensor; a missing mask becomes an all-zero (no-ROI) mask.
    for frame, mask in zip(batch_frames, batch_masks):
        img_array = np.array(frame).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)

        if mask is None:
            mask = np.zeros((frame.height, frame.width), dtype=np.float32)
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)

        batch_tensors.append(img_tensor)
        batch_mask_tensors.append(mask_tensor)
        original_sizes.append((frame.width, frame.height))

    # Padding derived from the first frame only (see NOTE above); min_div=256
    # pads H and W up to multiples of 256 for the compression model.
    _, h, w = batch_tensors[0].shape
    pad, unpad = compute_padding(h, w, min_div=256)

    padded_batch = torch.stack([
        F.pad(img_t, pad, mode='constant', value=0) for img_t in batch_tensors
    ]).to(self.device)

    padded_masks = torch.stack([
        F.pad(mask_t, pad, mode='constant', value=0) for mask_t in batch_mask_tensors
    ]).to(self.device)

    with torch.no_grad():
        # Extract per-frame sigma values for this batch
        batch_sigmas = [sigmas[global_offset + i] for i in range(len(batch_frames))]

        # Call model once with entire batch (TRUE BATCHING)
        if all(s == batch_sigmas[0] for s in batch_sigmas):
            # All frames have same sigma - use scalar
            out = self.compression_model(
                padded_batch,
                padded_masks,
                sigma=batch_sigmas[0],
            )
        else:
            # Different sigma values - pass as tensor
            sigma_tensor = torch.tensor(batch_sigmas, device=self.device, dtype=torch.float32)
            out = self.compression_model(
                padded_batch,
                padded_masks,
                sigma=sigma_tensor,
            )

        # Unpad all frames at once.
        # NOTE(review): `unpad` is presumably a tuple of negative pad values
        # so F.pad crops back to the original h×w — verify compute_padding.
        x_hat_padded = out['x_hat']
        x_hat_batch = F.pad(x_hat_padded, unpad)

        # Move to CPU and convert to numpy once
        x_hat_cpu = x_hat_batch.cpu().numpy()

        # Extract per-frame results
        for i in range(len(batch_frames)):
            # CHW float [0,1] → HWC uint8 image
            x_hat_np = x_hat_cpu[i].transpose(1, 2, 0)
            x_hat_np = np.clip(x_hat_np * 255, 0, 255).astype(np.uint8)
            compressed_img = Image.fromarray(x_hat_np)

            # Calculate BPP for this frame: sum of -log2(likelihood) over both
            # latent streams, normalised by the ORIGINAL (unpadded) pixel count.
            num_pixels = h * w
            likelihoods = out['likelihoods']
            bpp_y = torch.log(likelihoods['y'][i:i+1]).sum() / (-np.log(2) * num_pixels)
            bpp_z = torch.log(likelihoods['z'][i:i+1]).sum() / (-np.log(2) * num_pixels)
            bpp = (bpp_y + bpp_z).item()

            # ROI coverage = mean mask value (fraction of pixels inside the ROI,
            # assuming a binary/fractional mask in [0, 1]).
            mask_for_coverage = batch_masks[i] if i < len(batch_masks) else None
            roi_coverage = float(mask_for_coverage.mean()) if mask_for_coverage is not None else 0.0

            batch_results.append(CompressionResult(
                compressed_image=compressed_img,
                bpp=bpp,
                original_size=original_sizes[i],
                roi_coverage=roi_coverage,
            ))

    # Drop GPU/host buffers promptly to keep peak memory down between batches.
    del padded_batch, padded_masks, x_hat_batch, x_hat_cpu
    return batch_results
|
video/gpu_memory.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GPU memory estimation and automatic batch-size selection.
|
| 2 |
+
|
| 3 |
+
Queries free VRAM and uses per-model heuristics to pick the largest safe
|
| 4 |
+
batch size for both segmentation and compression stages. Falls back to
|
| 5 |
+
batch=1 on CPU or when CUDA info is unavailable.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Optional, Tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
# Per-frame memory cost heuristics (bytes, float32, conservative)
|
| 17 |
+
# These are empirical estimates measured on a mix of 480p frames.
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
# Segmentation models: (model_key, approx_param_bytes, per_frame_activation_bytes_480p)
|
| 21 |
+
_SEG_MEMORY_PER_FRAME: dict[str, float] = {
|
| 22 |
+
# YOLO-X seg – large model (71M params) with conv layers
|
| 23 |
+
"yolo": 180 * 1024**2, # ~180 MB per frame at 480p (X model)
|
| 24 |
+
# SegFormer – transformer encoder, moderate
|
| 25 |
+
"segformer": 200 * 1024**2, # ~200 MB per frame
|
| 26 |
+
# Mask2Former – Swin-Large + mask decoder, heavier
|
| 27 |
+
"mask2former": 350 * 1024**2, # ~350 MB per frame
|
| 28 |
+
# Mask R-CNN – ResNet50-FPN, moderate
|
| 29 |
+
"maskrcnn": 250 * 1024**2, # ~250 MB per frame
|
| 30 |
+
# SAM3 (OWL-ViT + SAM) – not truly batchable, treat as sequential
|
| 31 |
+
"sam3": 500 * 1024**2, # ~500 MB (single image pipeline)
|
| 32 |
+
# Fake segmentation (detection + tracking → bbox masks) - much lighter
|
| 33 |
+
"fake_yolo": 120 * 1024**2, # ~120 MB per frame (YOLO detection)
|
| 34 |
+
"fake_yolo_botsort": 120 * 1024**2, # Same as fake_yolo
|
| 35 |
+
"fake_detr": 150 * 1024**2, # ~150 MB per frame (DETR transformer)
|
| 36 |
+
"fake_deformable_detr": 170 * 1024**2, # ~170 MB per frame
|
| 37 |
+
"fake_fasterrcnn": 140 * 1024**2, # ~140 MB per frame (ResNet50 backbone)
|
| 38 |
+
"fake_retinanet": 140 * 1024**2, # ~140 MB per frame
|
| 39 |
+
"fake_fcos": 130 * 1024**2, # ~130 MB per frame
|
| 40 |
+
"fake_grounding_dino": 800 * 1024**2, # ~800 MB per frame (VERY large: BERT text + Swin-T vision + cross-attention)
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
_SEG_MODEL_OVERHEAD: dict[str, float] = {
|
| 44 |
+
"yolo": 450 * 1024**2, # YOLO-X model overhead (~71M params)
|
| 45 |
+
"segformer": 400 * 1024**2,
|
| 46 |
+
"mask2former": 800 * 1024**2,
|
| 47 |
+
"maskrcnn": 300 * 1024**2,
|
| 48 |
+
"sam3": 600 * 1024**2,
|
| 49 |
+
# Fake segmentation overhead (detector weights + tracker state)
|
| 50 |
+
"fake_yolo": 350 * 1024**2, # YOLO-X detector weights
|
| 51 |
+
"fake_yolo_botsort": 350 * 1024**2,
|
| 52 |
+
"fake_detr": 200 * 1024**2, # DETR weights
|
| 53 |
+
"fake_deformable_detr": 250 * 1024**2,
|
| 54 |
+
"fake_fasterrcnn": 160 * 1024**2, # Faster R-CNN weights
|
| 55 |
+
"fake_retinanet": 160 * 1024**2,
|
| 56 |
+
"fake_fcos": 150 * 1024**2,
|
| 57 |
+
"fake_grounding_dino": 1200 * 1024**2, # Grounding DINO weights (VERY large: ~700M params)
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
# TIC transformer compression: very heavy activations (attention is O(N²))
|
| 61 |
+
# Measured empirically on 854×480 frames:
|
| 62 |
+
# - Model params + entropy structures ≈ 1–3 GB
|
| 63 |
+
# - Per-frame activations ≈ 2–3 GB (float32)
|
| 64 |
+
_TIC_MODEL_OVERHEAD: float = 2.0 * 1024**3 # ~2 GB for params + entropy buffers
|
| 65 |
+
_TIC_PER_FRAME_480P: float = 2.5 * 1024**3 # ~2.5 GB per frame
|
| 66 |
+
|
| 67 |
+
# Reference resolution for the estimates above
|
| 68 |
+
_REF_HEIGHT = 480
|
| 69 |
+
_REF_WIDTH = 854
|
| 70 |
+
_REF_PIXELS = _REF_HEIGHT * _REF_WIDTH
|
| 71 |
+
|
| 72 |
+
# Safety margin – leave at least this fraction of free memory unused
|
| 73 |
+
_SAFETY_MARGIN = 0.15 # 15 %
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass
class BatchSizeEstimate:
    """Result of automatic batch-size estimation.

    Produced by :func:`estimate_batch_sizes`; consumers read the two batch
    sizes and may log ``notes`` for diagnostics.
    """

    # Recommended frames-per-batch for the segmentation stage (always >= 1).
    seg_batch_size: int
    # Recommended frames-per-batch for the compression stage (always >= 1).
    compress_batch_size: int
    # Free VRAM observed when the estimate was made (0 in CPU fallback mode).
    free_vram_bytes: int
    # Torch device string the estimate applies to (e.g. "cuda", "cuda:1", "cpu").
    device: str
    # Human-readable summary of how the numbers were derived.
    notes: str = ""
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _get_free_vram(device: str = "cuda") -> Optional[int]:
|
| 88 |
+
"""Return free VRAM in bytes, or None if unavailable."""
|
| 89 |
+
try:
|
| 90 |
+
import torch
|
| 91 |
+
|
| 92 |
+
if not torch.cuda.is_available() or device == "cpu":
|
| 93 |
+
return None
|
| 94 |
+
dev_idx = 0
|
| 95 |
+
if ":" in device:
|
| 96 |
+
dev_idx = int(device.split(":")[1])
|
| 97 |
+
free, _total = torch.cuda.mem_get_info(dev_idx)
|
| 98 |
+
return int(free)
|
| 99 |
+
except Exception:
|
| 100 |
+
return None
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _scale_memory(base_bytes: float, frame_h: int, frame_w: int) -> float:
    """Scale a 480p memory estimate to an arbitrary resolution.

    Feature-map memory grows roughly linearly with pixel count, while the
    attention matrices grow quadratically in token count; a mildly
    super-linear exponent (1.2) splits the difference conservatively.
    """
    ratio = (frame_h * frame_w) / _REF_PIXELS
    return base_bytes * ratio ** 1.2
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def estimate_batch_sizes(
    frame_height: int = 480,
    frame_width: int = 854,
    seg_method: str = "yolo",
    device: str = "cuda",
    total_frames: int = 300,
) -> BatchSizeEstimate:
    """Estimate optimal batch sizes for segmentation and compression.

    Queries free VRAM, subtracts a safety margin and each stage's model
    overhead, then divides the remaining budget by the per-frame memory
    heuristic (resolution-scaled) to size each stage's batch independently.

    Args:
        frame_height: Height of (pre-processed) frames.
        frame_width: Width of (pre-processed) frames.
        seg_method: Segmentation method key (yolo, segformer, mask2former, etc.)
        device: Torch device string.
        total_frames: Total number of frames in the video.

    Returns:
        BatchSizeEstimate with recommended batch sizes.
    """
    free_vram = _get_free_vram(device)

    # No CUDA info → CPU fallback: compression stays sequential.
    if free_vram is None:
        return BatchSizeEstimate(
            seg_batch_size=min(16, total_frames),
            compress_batch_size=1,
            free_vram_bytes=0,
            device=device,
            notes="CPU mode – using sequential compression, modest seg batches.",
        )

    note_list: list[str] = []
    budget = int(free_vram * (1.0 - _SAFETY_MARGIN))

    # --- Segmentation batch size ---
    key = seg_method.lower()
    seg_frame_cost = _scale_memory(
        _SEG_MEMORY_PER_FRAME.get(key, 200 * 1024**2),
        frame_height,
        frame_width,
    )
    seg_headroom = budget - _SEG_MODEL_OVERHEAD.get(key, 400 * 1024**2)
    seg_batch = 1 if seg_headroom < seg_frame_cost else int(seg_headroom / seg_frame_cost)

    # SAM3 is not truly batchable – cap to 1
    if key == "sam3":
        seg_batch = 1
        note_list.append("SAM3 is sequential (OWL-ViT+SAM pipeline).")

    seg_batch = max(1, min(seg_batch, total_frames))

    # --- Compression batch size ---
    comp_frame_cost = _scale_memory(_TIC_PER_FRAME_480P, frame_height, frame_width)
    comp_headroom = budget - _TIC_MODEL_OVERHEAD
    comp_batch = 1 if comp_headroom < comp_frame_cost else int(comp_headroom / comp_frame_cost)
    comp_batch = max(1, min(comp_batch, total_frames))

    note_list.append(
        f"Free VRAM: {free_vram / 1024**3:.1f} GB, "
        f"usable: {budget / 1024**3:.1f} GB. "
        f"Seg per-frame est: {seg_frame_cost / 1024**2:.0f} MB, "
        f"compress per-frame est: {comp_frame_cost / 1024**2:.0f} MB."
    )

    return BatchSizeEstimate(
        seg_batch_size=seg_batch,
        compress_batch_size=comp_batch,
        free_vram_bytes=free_vram,
        device=device,
        notes=" ".join(note_list),
    )
|
video/mask_cache.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities for saving and loading video segmentation masks."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import List, Optional
|
| 7 |
+
import pickle
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def save_video_masks(masks: List[np.ndarray], output_path: Optional[str] = None) -> str:
    """Save video segmentation masks to a compressed ``.npz`` file.

    Args:
        masks: Non-empty list of mask arrays (H, W), one per frame.
            All masks must share the same shape (required by np.stack).
        output_path: Optional output path. If None, creates a temp file.

    Returns:
        Path to saved mask file.

    Raises:
        ValueError: If ``masks`` is empty (previously surfaced as an
            opaque np.stack error).
    """
    import os

    if not masks:
        raise ValueError("masks must contain at least one frame mask")

    if output_path is None:
        # Create a named temp file; close the fd immediately because
        # np.savez_compressed reopens the path itself.
        fd, output_path = tempfile.mkstemp(suffix='.masks.npz', prefix='video_masks_')
        os.close(fd)

    # Stack masks into a single (T, H, W) array and save compressed.
    mask_array = np.stack(masks, axis=0)  # (T, H, W)
    np.savez_compressed(output_path, masks=mask_array)

    return output_path
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_video_masks(mask_path: str) -> List[np.ndarray]:
    """Load video segmentation masks from a file.

    Inverse of the save helper: reads the stacked (T, H, W) array from the
    ``.npz`` archive and splits it back into one (H, W) array per frame.

    Args:
        mask_path: Path to saved mask file

    Returns:
        List of mask arrays (H, W) for each frame
    """
    stacked = np.load(mask_path)['masks']  # (T, H, W)
    # Iterating the first axis yields one per-frame view, same as indexing.
    return list(stacked)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_mask_info(mask_path: str) -> dict:
    """Get information about saved masks without materializing them as a list.

    NOTE: the mask array is still read from the archive to inspect its
    shape/dtype; only the per-frame list conversion is skipped.

    Args:
        mask_path: Path to saved mask file

    Returns:
        Dictionary with mask metadata
    """
    arr = np.load(mask_path)['masks']
    shape = arr.shape
    return {
        'num_frames': shape[0],
        'height': shape[1],
        'width': shape[2],
        'dtype': str(arr.dtype),
        'size_mb': arr.nbytes / (1024 * 1024),
    }
|