# contextual-communication-demo / roi_compressor.py
# Last commit: e6997e4 "Added video compression" (raheebhassan)
"""
CLI for ROI-based image/video compression built on a modular compression framework.
"""
import argparse
import os
import sys
from pathlib import Path
import numpy as np
from PIL import Image
import torch
from segmentation import create_segmenter
from vae import load_checkpoint, compress_image
from vae.visualization import create_comparison_grid
from video import VideoProcessor, CompressionSettings
from video.video_processor import frames_to_video_bytes
# Command-line interface
def main():
    """Parse command-line arguments and dispatch to image or video compression.

    The input is treated as a video when its name ends (case-insensitively)
    in ``.mp4``, ``.avi``, ``.mov`` or ``.mkv``; anything else is processed
    as a single image. Arguments outside the ranges documented in ``--help``
    are rejected through ``parser.error`` (prints usage, exits with status 2).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # --- Input Type Check ---
    # str.endswith accepts a tuple, so a single call covers every extension.
    video_extensions = ('.mp4', '.avi', '.mov', '.mkv')
    is_video = args.input.lower().endswith(video_extensions)

    _validate_args(parser, args, is_video)

    if is_video:
        print("Processing video input...")
        process_video(args)
    else:
        print("Processing image input...")
        process_image(args)


def _build_arg_parser():
    """Return the ArgumentParser describing every option of this CLI."""
    parser = argparse.ArgumentParser(description="ROI-based Image/Video Compressor.")
    # I/O
    parser.add_argument("--input", required=True, help="Path to input image or video file.")
    parser.add_argument("--output", required=True, help="Path to save compressed output file.")
    parser.add_argument("--checkpoint", default="checkpoints/tic_lambda_0.0483.pth.tar", help="Path to VAE model checkpoint.")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on.")
    # General Compression Settings
    parser.add_argument("--quality-level", type=int, default=4, choices=range(7), help="Base quality level (0-6, higher is better). Affects VAE model selection.")
    parser.add_argument("--sigma", type=float, default=0.3, help="Background quality factor (0.01-1.0). Lower means more background compression.")
    # Segmentation Settings
    parser.add_argument("--seg-method", default="yolo", help="Segmentation method to use.")
    parser.add_argument("--seg-classes", nargs="+", required=True, help="Classes to segment as ROI.")
    parser.add_argument("--seg-model", help="Path to a specific segmentation model checkpoint (optional).")
    # Video-Specific Settings
    parser.add_argument("--video-mode", choices=['static', 'dynamic'], default='static', help="Video compression mode: 'static' for fixed settings, 'dynamic' for adaptive.")
    parser.add_argument("--target-bandwidth-kbps", type=int, default=1000, help="[Dynamic Mode] Target bandwidth in kbps.")
    parser.add_argument("--output-fps", type=float, default=15.0, help="[Static Mode] Output framerate.")
    parser.add_argument("--min-fps", type=float, default=5.0, help="[Dynamic Mode] Minimum framerate.")
    parser.add_argument("--max-fps", type=float, default=30.0, help="[Dynamic Mode] Maximum framerate.")
    parser.add_argument("--chunk-duration-sec", type=float, default=1.0, help="[Dynamic Mode] Duration of video chunks for analysis.")
    # Detection/Tracking Settings (for video)
    parser.add_argument("--detection-method", default="yolo", help="Object detection method for video.")
    parser.add_argument("--enable-tracking", action="store_true", help="Enable object tracking in video.")
    # Visualization
    parser.add_argument("--highlight", action="store_true", help="Create a comparison grid image (for image input only).")
    parser.add_argument("--viz-dir", help="Directory to save visualization artifacts (e.g., masks).")
    return parser


def _validate_args(parser, args, is_video):
    """Exit with a usage error when an argument is outside its documented range.

    Video-only options are checked only for video input, and only for the
    mode (static or dynamic) that actually consumes them, so irrelevant
    defaults never cause a rejection.
    """
    if not 0.01 <= args.sigma <= 1.0:
        parser.error(f"--sigma must be between 0.01 and 1.0, got {args.sigma}.")
    if not is_video:
        return
    if args.video_mode == 'static':
        if args.output_fps <= 0:
            parser.error(f"--output-fps must be positive, got {args.output_fps}.")
        return
    # Dynamic mode. process_video divides by chunk_duration_sec when it
    # computes the final FPS, so zero or negative values must be rejected here.
    if args.chunk_duration_sec <= 0:
        parser.error(f"--chunk-duration-sec must be positive, got {args.chunk_duration_sec}.")
    if args.target_bandwidth_kbps <= 0:
        parser.error(f"--target-bandwidth-kbps must be positive, got {args.target_bandwidth_kbps}.")
    if not 0 < args.min_fps <= args.max_fps:
        parser.error(
            f"--min-fps and --max-fps must satisfy 0 < min <= max, "
            f"got {args.min_fps} and {args.max_fps}."
        )
def process_image(args):
    """Compress a single image with ROI-aware quality and save the result.

    Loads the VAE checkpoint and the segmenter, segments ``args.input`` for
    ``args.seg_classes``, compresses it with ``args.sigma`` as the background
    quality factor, and writes the compressed image to ``args.output``.
    If ``args.viz_dir`` is set, the segmentation mask is also saved there as
    ``mask.png``. With ``args.highlight``, a comparison grid is saved next to
    the output as ``<output stem>_comparison.jpg``.

    Args:
        args: Parsed CLI namespace produced by ``main``.
    """
    print(f"Loading VAE model from {args.checkpoint}...")
    model = load_checkpoint(args.checkpoint, device=args.device)
    model.eval()

    print(f"Loading segmenter '{args.seg_method}'...")
    segmenter = create_segmenter(
        args.seg_method,
        device=args.device,
        model_path=args.seg_model
    )

    print(f"Loading image from {args.input}...")
    # Close the source file handle; convert() returns an independent in-memory copy.
    with Image.open(args.input) as source:
        image = source.convert("RGB")

    print(f"Segmenting image for classes: {args.seg_classes}...")
    mask, _ = segmenter(image, target_classes=args.seg_classes)

    if args.viz_dir:
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(args.viz_dir, exist_ok=True)
        mask_path = os.path.join(args.viz_dir, "mask.png")
        # NOTE(review): assumes mask is a NumPy array with values in [0, 1] — confirm.
        Image.fromarray((mask * 255).astype(np.uint8)).save(mask_path)
        print(f"Saved segmentation mask to {mask_path}")

    print(f"Compressing image with sigma={args.sigma}...")
    result = compress_image(
        image,
        mask,
        model,
        sigma=args.sigma,
        device=args.device
    )
    compressed_img = result['compressed']
    bpp = result['bpp']

    print(f"Saving compressed image to {args.output} (BPP: {bpp:.4f})")
    compressed_img.save(args.output)

    if args.highlight:
        print("Creating comparison grid...")
        lambda_val = _lambda_from_checkpoint(args.checkpoint)
        grid = create_comparison_grid(image, compressed_img, mask, bpp, args.sigma, lambda_val)
        # Swap only the trailing extension. str.replace would also rewrite
        # earlier matches (e.g. in a directory name), and with an empty
        # extension it would insert the suffix between every character.
        grid_path = os.path.splitext(args.output)[0] + "_comparison.jpg"
        grid.save(grid_path)
        print(f"Saved comparison grid to {grid_path}")


def _lambda_from_checkpoint(checkpoint_path):
    """Parse the lambda value embedded in a checkpoint file name.

    Expects names like ``tic_lambda_0.0483.pth.tar`` (the text after the last
    underscore, minus ``.pth.tar``). For names that do not follow that
    pattern, prints a warning and returns NaN instead of raising. This way a
    custom checkpoint name does not abort the run after the compressed image
    has already been saved.
    """
    token = os.path.basename(checkpoint_path).split('_')[-1].replace('.pth.tar', '')
    try:
        return float(token)
    except ValueError:
        # NOTE(review): assumes create_comparison_grid only displays lambda_val,
        # so NaN renders as "nan" rather than breaking the grid — confirm.
        print(f"Warning: could not parse lambda from checkpoint name '{checkpoint_path}'; using NaN.")
        return float("nan")
def process_video(args):
    """Compress a video in static or dynamic mode and write it to ``args.output``.

    Static mode uses a fixed ``args.output_fps`` and ``args.sigma``. Dynamic
    mode passes ``args.target_bandwidth_kbps``, the FPS bounds and
    ``args.chunk_duration_sec`` to the processor. The frames of every chunk
    the processor yields are concatenated and re-encoded into one video file.

    Args:
        args: Parsed CLI namespace produced by ``main``.
    """
    # Imported lazily (only this path needs a progress bar); done up front so a
    # missing package fails before any model is loaded.
    from tqdm import tqdm

    processor = VideoProcessor(device=args.device)
    print("Loading models for video processing...")
    processor.load_models(
        quality_level=args.quality_level,
        segmentation_method=args.seg_method,
        detection_method=args.detection_method,
        enable_tracking=args.enable_tracking,
    )

    if args.video_mode == 'static':
        settings = CompressionSettings(
            mode='static',
            quality_level=args.quality_level,
            sigma=args.sigma,
            output_fps=args.output_fps,
            target_classes=args.seg_classes,
        )
        print(f"Starting STATIC video compression with FPS={settings.output_fps}, Sigma={settings.sigma}...")
        chunks = processor.process_static(args.input, settings)
    else:  # dynamic (argparse restricts --video-mode to these two choices)
        settings = CompressionSettings(
            mode='dynamic',
            target_bandwidth_kbps=args.target_bandwidth_kbps,
            min_fps=args.min_fps,
            max_fps=args.max_fps,
            chunk_duration_sec=args.chunk_duration_sec,
            target_classes=args.seg_classes,
            quality_level=args.quality_level,
        )
        print(f"Starting DYNAMIC video compression with Target Bandwidth={settings.target_bandwidth_kbps} kbps...")
        chunks = processor.process_dynamic(args.input, settings)

    all_frames = []
    # Counted explicitly rather than reading the loop variable after the loop.
    num_chunks = 0
    print("Processing video chunks...")
    # This part can be slow because chunks are handled one at a time; the
    # progress bar shows that work is still happening.
    with tqdm(desc="Compressing Chunks") as pbar:
        for index, chunk in enumerate(chunks):
            all_frames.extend(chunk.frames)
            num_chunks = index + 1
            pbar.update(1)
            pbar.set_postfix({
                "chunk": index,
                "frames": len(chunk.frames),
                "fps": f"{chunk.fps:.1f}",
                "bpp": f"{chunk.avg_bpp:.3f}"
            })

    if not all_frames:
        print("No frames were processed. Exiting.")
        return

    # Final frame rate: the requested rate in static mode. In dynamic mode it
    # is the average over all chunks, treating each chunk as spanning
    # chunk_duration_sec of source video.
    # NOTE(review): a shorter final chunk is counted as a full chunk_duration_sec,
    # which slightly underestimates the average — confirm against VideoProcessor.
    if args.video_mode == 'static':
        final_fps = args.output_fps
    else:
        final_fps = len(all_frames) / (args.chunk_duration_sec * num_chunks)

    print(f"\nRe-encoding {len(all_frames)} frames into final video at ~{final_fps:.2f} FPS...")
    video_bytes = frames_to_video_bytes(all_frames, fps=final_fps)

    print(f"Saving compressed video to {args.output}...")
    with open(args.output, "wb") as f:
        f.write(video_bytes)
    print("Done.")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()