ROI-VAE Image Compression - Copilot Instructions
Project Overview
ROI-based VAE image compression using TIC (Transformer-based Image Compression). The system preserves quality in Regions of Interest (ROI) while aggressively compressing backgrounds using configurable quality factors.
Architecture
Core Pipeline
1. Segmentation (`segmentation/` module) → 2. Compression (`vae/` module) → 3. Output
- Segmentation creates binary masks (1=ROI, 0=background)
- Compression applies variable quality based on the mask via the `sigma` parameter (see the sketch below)
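The same flow in code, using the module APIs documented later in this guide (a minimal sketch; file paths are placeholders):

```python
from PIL import Image
from segmentation import create_segmenter
from vae import load_checkpoint, compress_image

# 1. Segmentation: binary mask (1=ROI, 0=background)
segmenter = create_segmenter('yolo', device='cuda')
image = Image.open('input.jpg')
mask = segmenter(image, target_classes=['car'])

# 2. Compression: ROI kept at quality 1.0, background scaled by sigma
model = load_checkpoint('checkpoints/tic_lambda_0.0483.pth.tar', device='cuda')
result = compress_image(image, mask, model, sigma=0.3, device='cuda')

# 3. Output
result['compressed'].save('output.jpg')  # PIL Image
print(result['bpp'])                     # bits per pixel
```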
Key Components
Segmentation Module (segmentation/):
- Abstract base class `BaseSegmenter` defines the common interface
- Implementations:
  - `SegFormerSegmenter` - Cityscapes semantic segmentation (19 classes: road, car, building, person, etc.)
  - `YOLOSegmenter` - COCO instance segmentation (80 classes)
  - `Mask2FormerSegmenter` - Swin Transformer-based panoptic/semantic segmentation (COCO: 133 classes, ADE20K: 150 classes)
  - `MaskRCNNSegmenter` - ResNet50-FPN instance segmentation (COCO: 80 classes)
  - `SAM3Segmenter` - Prompt-based segmentation (natural language prompt → mask via text-conditioned detector + SAM)
  - `FakeSegmenter` - Detection + tracking → bbox masks (fast, non-pixel-perfect)
- Fake Segmentation (NEW): Detection-based segmentation for speed
  - Creates rectangular masks from detection bounding boxes
  - Uses object tracking for temporal consistency (ByteTrack, BoTSORT, SimpleTracker)
  - Available methods: `fake_yolo` (default, ByteTrack), `fake_yolo_botsort`, `fake_detr`, `fake_fasterrcnn`, `fake_retinanet`, `fake_fcos`, `fake_deformable_detr`, `fake_grounding_dino`
  - Much faster than pixel-perfect segmentation (~60-100 fps vs 10-30 fps)
  - Memory estimates in `gpu_memory.py`: 120-200 MB per frame (vs 180-500 MB for full segmentation)
- Factory pattern: `create_segmenter('yolo', device='cuda')` or `create_segmenter('fake_yolo', device='cuda')` - extensible for future models
- Utils: `visualize_mask()`, `save_mask()`, `calculate_roi_stats()`
Compression Module (`vae/`):
- `tic_model.py`: `BaseTIC` class - Transformer-based VAE with encoder, decoder, hyperprior
- `RSTB.py`: Residual Swin Transformer Blocks and attention modules
- `transformer_layers.py`: Generic transformer components (MLP, attention, drop path)
- `roi_tic.py`: `ModifiedTIC` class extending base TIC with an ROI-aware forward pass
- `utils.py`: `compress_image()`, `compute_padding()` for image processing
- `visualization.py`: `highlight_roi()`, `create_comparison_grid()` for results
- Handles checkpoint loading with compressai version compatibility fixes
Detection Module (`detection/`):
- Abstract base class `BaseDetector` defines the common interface
- Factory pattern: `create_detector('yolo', device='cuda')`
- Implementations:
  - `YOLODetector` - Ultralytics YOLO (closed-vocabulary COCO weights)
  - Torchvision: Faster R-CNN, RetinaNet, SSD, FCOS
  - Transformers: DETR, Deformable DETR
  - `EfficientDetDetector` - optional via `effdet`
  - `YOLOWorldDetector` - open-vocabulary detection (Ultralytics YOLO-World; requires prompts)
  - `GroundingDINODetector` - open-vocabulary detection (Transformers; requires prompts)
- CLI: `roi_detection_eval.py` evaluates detection retention before vs after ROI compression
TIC Model (`vae/tic_model.py`):
- Transformer-based VAE with encoder (`g_a`), decoder (`g_s`), and hyperprior (`h_a`, `h_s`)
- Uses RSTB (Residual Swin Transformer Blocks) for feature extraction
- Channels: N=192, M=192 (expansion layer)
- Critical: Images must be padded to multiples of 256 (use `compute_padding()`)
ModifiedTIC (`vae/roi_tic.py`):
- Extends base TIC with an ROI-aware forward pass
- Takes mask + sigma parameter to create quality factors
- Applies a `similarity_loss` tensor: 1.0 for ROI pixels, sigma for background
- Integrates the mask through the `simi_net` and `sub_impor_net` branches
Critical Conventions
Model Cache Locations
- By default, auto-downloaded model artifacts are kept inside `checkpoints/`:
  - Hugging Face cache: `checkpoints/hf/`
  - Torch/torchvision cache: `checkpoints/torch/`
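One way to make the libraries honor those locations is to point the standard cache environment variables at `checkpoints/` before importing the model libraries (an illustrative sketch; the repo may already set these internally):

```python
import os

# HF_HOME is honored by huggingface_hub/transformers; TORCH_HOME by torch.hub
# and the torchvision model zoo. Set them before importing those libraries.
os.environ.setdefault("HF_HOME", "checkpoints/hf")
os.environ.setdefault("TORCH_HOME", "checkpoints/torch")
```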
Checkpoint Loading Pattern
```python
from vae import load_checkpoint

# Automatically handles compressai version mismatch
model = load_checkpoint('checkpoints/tic_lambda_0.0483.pth.tar', N=192, M=192, device='cuda')
# Note: model.update(force=True) is called automatically
```
Manual loading:
```python
# Fix compressai version mismatch - required for all checkpoint loading
state_dict = checkpoint["state_dict"]
new_state_dict = {}
for k, v in state_dict.items():
    if "entropy_bottleneck._matrix" in k:
        # e.g. "entropy_bottleneck._matrix0" -> "entropy_bottleneck.matrices.0"
        k = k.replace("entropy_bottleneck._matrix", "entropy_bottleneck.matrices.")
    # ... similar replacements for _bias, _factor
    new_state_dict[k] = v
```

Always call `model.update(force=True)` after loading checkpoints.
Image Preprocessing
- Convert PIL to torch tensor: `x = torch.from_numpy(np.array(img)).float() / 255.0`
- Permute to [B, C, H, W]: `x = x.permute(2, 0, 1).unsqueeze(0)`
- Pad to multiples of 256 using `compute_padding(h, w, min_div=256)`
- Apply the mask at the same resolution as the input image (see the sketch below)
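Putting those steps together (a minimal sketch; it assumes `compute_padding` is exported from `vae` and returns `(pad, unpad)` tuples usable with `F.pad`, as in compressai's helper of the same name):

```python
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from vae import compute_padding  # assumed export; see vae/utils.py

img = Image.open('input.jpg').convert('RGB')
x = torch.from_numpy(np.array(img)).float() / 255.0  # [H, W, C] in [0, 1]
x = x.permute(2, 0, 1).unsqueeze(0)                  # [B, C, H, W]

h, w = x.shape[2:]
# Assumption: (pad, unpad) tuples in F.pad order, like compressai's helper
pad, unpad = compute_padding(h, w, min_div=256)
x_padded = F.pad(x, pad, mode='constant', value=0)
```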
Sigma Parameter
- Range: 0.01 - 1.0 (lower = more background compression)
- Default: 0.3
- ROI pixels always get quality factor 1.0
- Applied via `torch.where(mask > 0.5, 1.0, sigma)`
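A concrete illustration of that rule (a minimal sketch; the tensor shapes and mask contents are hypothetical, only the `torch.where` pattern comes from this guide):

```python
import torch

sigma = 0.3
mask = torch.zeros(1, 1, 256, 256)     # hypothetical binary mask, 1=ROI
mask[:, :, 64:192, 64:192] = 1.0

# ROI pixels get quality factor 1.0, background gets sigma
# (ones_like/full_like used to keep dtype/device consistent)
quality = torch.where(mask > 0.5,
                      torch.ones_like(mask),
                      torch.full_like(mask, sigma))
```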
Available Checkpoints
Located in the `checkpoints/` directory with different lambda (rate-distortion) values:
- `tic_lambda_0.0035.pth.tar` - Lowest bitrate (highest compression)
- `tic_lambda_0.013.pth.tar` - Low bitrate (N=128, M=192)
- `tic_lambda_0.025.pth.tar` - Medium-low bitrate
- `tic_lambda_0.0483.pth.tar` - Default - medium bitrate
- `tic_lambda_0.0932.pth.tar` - High bitrate (better quality)
- `yolo26x-seg.pt` - YOLO segmentation model
Development Workflows
Using Segmentation Module (New)
```python
from segmentation import create_segmenter

# Available methods: segformer, yolo, mask2former, maskrcnn, sam3
# Fake methods: fake_yolo, fake_yolo_botsort, fake_detr, fake_fasterrcnn, etc.
segmenter = create_segmenter('mask2former', device='cuda', model_type='coco')

# Segment image
mask = segmenter(image, target_classes=['car', 'person'])

# Fast segmentation with detection + tracking (non-pixel-perfect)
fake_seg = create_segmenter('fake_yolo', device='cuda')
mask = fake_seg(image, target_classes=['person'])  # Uses ByteTrack tracking
# Much faster: ~60-100 fps vs 10-30 fps for pixel-perfect segmentation

# Add new segmentation method
from segmentation import register_segmenter, BaseSegmenter

class MySegmenter(BaseSegmenter):
    def load_model(self): ...
    def segment(self, image, target_classes, **kwargs): ...
    def get_available_classes(self): ...

register_segmenter('my_method', MySegmenter)
```
Using Compression Module (New)
```python
from vae import load_checkpoint, compress_image
from PIL import Image
import numpy as np

# Load model
model = load_checkpoint('checkpoints/tic_lambda_0.0483.pth.tar', device='cuda')

# Compress image with mask
image = Image.open('input.jpg')
mask = np.zeros((image.height, image.width))  # Your mask here
result = compress_image(image, mask, model, sigma=0.3, device='cuda')
compressed = result['compressed']  # PIL Image
bpp = result['bpp']                # Bits per pixel

# Visualize results
from vae import create_comparison_grid
grid = create_comparison_grid(image, compressed, mask, bpp, sigma=0.3, lambda_val=0.0483)
grid.save('comparison.jpg')
```
Using Detection Module (New)
```python
from detection import create_detector

# Closed-vocabulary
det = create_detector('yolo', device='cuda', model_path='checkpoints/yolo26x.pt')
dets = det(image, conf_threshold=0.25)

# Open-vocabulary (must pass prompts/classes)
det_ov = create_detector('yolo_world', device='cuda')
dets_ov = det_ov(image, conf_threshold=0.25, classes='person,car')
```
Detection Eval (CLI)
```bash
# Compare before vs after (already-compressed)
python roi_detection_eval.py \
  --before images/car/0016cf15fa4d4e16.jpg \
  --after results/compressed.jpg \
  --detectors yolo detr \
  --viz-dir results/det_viz

# Open-vocabulary eval (YOLO-World requires prompts)
python roi_detection_eval.py \
  --before images/person/kodim04.png \
  --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
  --sigma 0.3 \
  --seg-method yolo --seg-classes person \
  --detectors yolo_world \
  --open-vocab-classes "person,car" \
  --viz-dir results/det_viz
```
Running Compression (CLI)
```bash
# Basic compression with segmentation
python roi_compressor.py \
  --input images/car/0016cf15fa4d4e16.jpg \
  --output results/compressed.jpg \
  --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
  --sigma 0.3 \
  --seg-classes car \
  --seg-method yolo

# Fast compression with detection-based fake segmentation (~3x faster)
python roi_compressor.py \
  --input images/car/0016cf15fa4d4e16.jpg \
  --output results/compressed.jpg \
  --checkpoint checkpoints/tic_lambda_0.0483.pth.tar \
  --sigma 0.3 \
  --seg-classes car \
  --seg-method fake_yolo

# With comparison grid (original, compressed, ROI highlighted)
python roi_compressor.py ... --highlight
```
Standalone Segmentation (CLI)
```bash
# Using Mask2Former with COCO panoptic
python roi_segmenter.py \
  --input images/car/0016cf15fa4d4e16.jpg \
  --output results/mask.png \
  --method mask2former \
  --classes car building person \
  --visualize

# Fast segmentation with detection + ByteTrack tracking
python roi_segmenter.py \
  --input data/videos/Person_doing_handstand.mp4 \
  --output results/masks.mp4 \
  --method fake_yolo \
  --classes person \
  --resize-height 480 \
  --smooth-patience 10 \
  --visualize

# Other fake methods (detection + tracking):
#   fake_yolo_botsort (YOLO + BoTSORT)
#   fake_detr (DETR + SimpleTracker)
#   fake_fasterrcnn, fake_retinanet, fake_fcos, etc.
```
Adding New Segmentation Models
1. Create a new file in `segmentation/` (e.g., `sam.py`)
2. Extend `BaseSegmenter` and implement the abstract methods:
   - `load_model()`: Load model weights
   - `segment()`: Generate mask from image
   - `get_available_classes()`: Return supported classes/capabilities
3. Register in `segmentation/__init__.py` or use `register_segmenter()`
4. Use via `create_segmenter('your_method', ...)` (see the skeleton below)
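A minimal skeleton of those steps (illustrative only; the class name and dummy mask logic are hypothetical, while `BaseSegmenter`, `register_segmenter`, and the three abstract methods come from this guide):

```python
import numpy as np
from segmentation import BaseSegmenter, register_segmenter

class CenterBoxSegmenter(BaseSegmenter):
    """Hypothetical example: marks a centered box as the ROI."""

    def load_model(self):
        # No weights needed for this toy example
        self.model = None

    def segment(self, image, target_classes=None, **kwargs):
        # Binary mask at input resolution: 1=ROI, 0=background
        h, w = image.height, image.width
        mask = np.zeros((h, w), dtype=np.uint8)
        mask[h // 4: 3 * h // 4, w // 4: 3 * w // 4] = 1
        return mask

    def get_available_classes(self):
        return []  # class-agnostic

register_segmenter('center_box', CenterBoxSegmenter)
# segmenter = create_segmenter('center_box', device='cpu')
```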
Testing Examples
- `roi_segmenter.py`: CLI tool for standalone segmentation
- `roi_compressor.py`: CLI tool for ROI-based image compression
- `vae_compress.py`: Legacy ROI compression script (updated to use the modules)
- `*.bak`: Backup files from pre-modularization (tic_model, RSTB, etc.)
Dependencies
- PyTorch + torchvision for models
- compressai for entropy models (version sensitive - see checkpoint loading)
- transformers for SegFormer + DETR/Deformable DETR + Grounding DINO
- ultralytics for YOLO + YOLO-World
- effdet (optional) for EfficientDet detector
- timm for model layers
Common Pitfalls
- Padding: Forgetting to pad images to multiples of 256 causes dimension mismatches
- Checkpoint keys: Old checkpoints use `_matrix`/`_bias`/`_factor` naming that must be converted
- Mask resolution: The mask must match the input image size; it is automatically downsampled in the forward pass
- Mask downsampling: In ModifiedTIC, the mask is downsampled to 1/2 resolution before `simi_net`, which downsamples a further 8x to match the 16x-downsampled latent
- Device mismatch: Ensure the mask, sigma tensor, and model are on the same device (see the sketch below)
- Model update: Must call `model.update(force=True)` after loading for the entropy models
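A small sketch of the device and resolution conventions above (shapes and variable names are hypothetical; the 1/2 → 1/16 chain is as described in the pitfalls):

```python
import torch
import torch.nn.functional as F

device = 'cuda'
x = torch.rand(1, 3, 512, 768, device=device)      # padded input (256 multiples)
mask = torch.zeros(1, 1, 512, 768, device=device)  # same resolution and device

# Resolution chain: the mask is halved before simi_net, which downsamples
# a further 8x -> 16x total, matching the 16x-downsampled latent.
mask_half = F.interpolate(mask, scale_factor=0.5, mode='nearest')  # [1, 1, 256, 384]
```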
Project Structure
- `.github/copilot-instructions.md`: This file - comprehensive development guide
- `examples.sh`: Example commands for running compression and segmentation
- `README.md`: Project overview and quick start guide
- `requirements.txt`: Python dependencies
CLI Tools:
- `roi_segmenter.py`: CLI tool for standalone segmentation
- `roi_compressor.py`: CLI tool for ROI-based image compression
- `app.py`: Gradio demo with Image and Video tabs
Core Modules:
- `segmentation/`: Modular segmentation with abstract base class
  - `base.py`: `BaseSegmenter` abstract class
  - `segformer.py`: Cityscapes semantic segmentation (19 classes)
  - `yolo.py`: COCO instance segmentation (80 classes)
  - `mask2former.py`: Swin-based panoptic/semantic (COCO: 133, ADE20K: 150 classes)
  - `maskrcnn.py`: ResNet50-FPN instance segmentation (COCO: 80 classes)
  - `sam3.py`: Prompt-based segmentation
  - `factory.py`: Factory pattern for creating segmenters
  - `utils.py`: Visualization and I/O utilities
- `vae/`: Modular compression with ROI support
  - `tic_model.py`: `BaseTIC` class (Transformer-based VAE)
  - `RSTB.py`: Residual Swin Transformer Blocks
  - `transformer_layers.py`: Generic transformer components
  - `roi_tic.py`: `ModifiedTIC` class and checkpoint loading
  - `utils.py`: `compress_image()`, `compute_padding()`
  - `visualization.py`: `highlight_roi()`, `create_comparison_grid()`
- `video/`: Video compression with streaming support
  - `video_processor.py`: `VideoProcessor` class for video compression
  - `motion_analyzer.py`: `MotionAnalyzer` for scene complexity estimation
  - `chunk_compressor.py`: `ChunkCompressor` and `BandwidthController`
- `detection/`: Object detection and tracking
  - `tracker.py`: `SimpleTracker` IoU-based multi-object tracker
  - `utils.py`: `draw_detections()`, `draw_tracks()`
Video Processing
Video Module Usage
```python
from video import VideoProcessor, CompressionSettings

# Create processor
processor = VideoProcessor(device='cuda')
processor.load_models(
    quality_level=4,
    segmentation_method='sam3',
    detection_method='yolo',
    enable_tracking=True,
)

# Static mode (fixed settings)
settings = CompressionSettings(
    mode='static',
    quality_level=4,
    sigma=0.3,
    output_fps=15.0,
    target_classes=['person', 'car'],
)
for chunk in processor.process_static('input.mp4', settings):
    # Stream chunks in real-time
    print(f"Chunk {chunk.chunk_index}: {len(chunk.frames)} frames at {chunk.fps} FPS")

# Dynamic mode (bandwidth-adaptive)
settings = CompressionSettings(
    mode='dynamic',
    target_bandwidth_kbps=500,
    min_fps=5,
    max_fps=30,
    chunk_duration_sec=1.0,
    target_classes=['person', 'car'],
)
for chunk in processor.process_dynamic('input.mp4', settings):
    # Adaptive FPS and quality per chunk based on motion
    print(f"Chunk {chunk.chunk_index}: fps={chunk.fps:.1f}, quality={chunk.quality_level}")
```
Motion-Adaptive Compression
The dynamic mode analyzes each chunk for:
- Motion magnitude: Mean pixel change between frames
- Motion coverage: Fraction of pixels with significant motion
- Scene complexity: Edge density and texture variance
- Scene changes: Large global differences
High-motion scenes get:
- More frames (higher FPS)
- Higher spatial compression (lower quality/sigma) to stay within bandwidth
Low-motion scenes get:
- Fewer frames (lower FPS)
- Better spatial quality (higher quality/sigma); see the illustrative heuristic below
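The actual policy lives in `ChunkCompressor`/`BandwidthController`; as a rough illustration of the trade-off described above (hypothetical function, thresholds, and coefficients, not the repo's implementation):

```python
def pick_chunk_settings(motion_score, min_fps=5, max_fps=30):
    """Hypothetical heuristic: trade frame rate against spatial quality.

    motion_score in [0, 1] (e.g., normalized mean pixel change per chunk).
    Illustrates the policy above; not the repo's implementation.
    """
    # High motion -> more frames, heavier spatial compression (lower sigma)
    fps = min_fps + motion_score * (max_fps - min_fps)
    sigma = 0.5 - 0.3 * motion_score              # 0.5 (static) down to 0.2 (busy)
    quality_level = 5 - round(3 * motion_score)   # 5 (static) down to 2 (busy)
    return fps, sigma, quality_level

fps, sigma, quality = pick_chunk_settings(motion_score=0.7)
```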
Object Tracking
```python
from detection import SimpleTracker, draw_tracks

tracker = SimpleTracker(iou_threshold=0.3, max_age=30)
for frame_detections in frame_by_frame_detections:
    tracks = tracker.update(frame_detections)
    # tracks contains track_id, label, bbox, history

# Draw tracks with trails
img = draw_tracks(frame, tracks, show_id=True, show_trail=True)
```
Coding Guidelines
- Don't create unnecessary files—focus on core functionality.
- Ensure all scripts have clear argument parsing and help messages.
- Maintain consistent coding style and comments for clarity.
- Validate inputs (image paths, checkpoint paths, segmentation classes).
- Include error handling for common issues (file not found, dimension mismatches).
- Document all functions and classes with docstrings.
- Write modular code to facilitate testing and future extensions.
- Use ipynb files for prototyping but keep main logic in .py files.