| """ |
| CLI for ROI-based image/video compression using a modular compression framework. |
| """ |
|
|
| import argparse |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| from PIL import Image |
| import torch |
|
|
| from segmentation import create_segmenter |
| from vae import load_checkpoint, compress_image |
| from vae.visualization import create_comparison_grid |
| from video import VideoProcessor, CompressionSettings |
| from video.video_processor import frames_to_video_bytes |
|
|
|
|
| |
def main():
    """Parse the command line and dispatch to the image or video pipeline.

    The input is treated as a video when its path ends (case-insensitively)
    with one of ``.mp4``, ``.avi``, ``.mov`` or ``.mkv``; every other input
    is handled as a still image.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Each entry is (flag, keyword arguments for ``add_argument``). The order
    # is significant: it is the order in which options appear in ``--help``.
    option_specs = [
        # --- Input / output / model -------------------------------------
        ("--input", dict(
            required=True,
            help="Path to input image or video file.",
        )),
        ("--output", dict(
            required=True,
            help="Path to save compressed output file.",
        )),
        ("--checkpoint", dict(
            default="checkpoints/tic_lambda_0.0483.pth.tar",
            help="Path to VAE model checkpoint.",
        )),
        ("--device", dict(
            default=default_device,
            help="Device to run on.",
        )),
        # --- Compression quality ----------------------------------------
        ("--quality-level", dict(
            type=int,
            default=4,
            choices=range(7),
            help="Base quality level (0-6, higher is better). Affects VAE model selection.",
        )),
        ("--sigma", dict(
            type=float,
            default=0.3,
            help="Background quality factor (0.01-1.0). Lower means more background compression.",
        )),
        # --- Segmentation -----------------------------------------------
        ("--seg-method", dict(
            default="yolo",
            help="Segmentation method to use.",
        )),
        ("--seg-classes", dict(
            nargs="+",
            required=True,
            help="Classes to segment as ROI.",
        )),
        ("--seg-model", dict(
            help="Path to a specific segmentation model checkpoint (optional).",
        )),
        # --- Video-specific settings ------------------------------------
        ("--video-mode", dict(
            choices=['static', 'dynamic'],
            default='static',
            help="Video compression mode: 'static' for fixed settings, 'dynamic' for adaptive.",
        )),
        ("--target-bandwidth-kbps", dict(
            type=int,
            default=1000,
            help="[Dynamic Mode] Target bandwidth in kbps.",
        )),
        ("--output-fps", dict(
            type=float,
            default=15.0,
            help="[Static Mode] Output framerate.",
        )),
        ("--min-fps", dict(
            type=float,
            default=5.0,
            help="[Dynamic Mode] Minimum framerate.",
        )),
        ("--max-fps", dict(
            type=float,
            default=30.0,
            help="[Dynamic Mode] Maximum framerate.",
        )),
        ("--chunk-duration-sec", dict(
            type=float,
            default=1.0,
            help="[Dynamic Mode] Duration of video chunks for analysis.",
        )),
        # --- Detection / tracking ---------------------------------------
        ("--detection-method", dict(
            default="yolo",
            help="Object detection method for video.",
        )),
        ("--enable-tracking", dict(
            action="store_true",
            help="Enable object tracking in video.",
        )),
        # --- Visualization ----------------------------------------------
        ("--highlight", dict(
            action="store_true",
            help="Create a comparison grid image (for image input only).",
        )),
        ("--viz-dir", dict(
            help="Directory to save visualization artifacts (e.g., masks).",
        )),
    ]

    cli = argparse.ArgumentParser(description="ROI-based Image/Video Compressor.")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    # Decide the pipeline purely from the (lower-cased) file extension.
    video_suffixes = ('.mp4', '.avi', '.mov', '.mkv')
    if args.input.lower().endswith(video_suffixes):
        print("Processing video input...")
        process_video(args)
        return

    print("Processing image input...")
    process_image(args)
|
|
def process_image(args):
    """Compress a single image and save the result.

    Loads the VAE checkpoint and the segmenter, segments the input image for
    the requested classes, compresses the image with the resulting mask and
    writes the output. Optionally saves the mask and a comparison grid.

    Args:
        args: Parsed CLI namespace. Reads ``checkpoint``, ``device``,
            ``seg_method``, ``seg_model``, ``input``, ``seg_classes``,
            ``viz_dir``, ``sigma``, ``output`` and ``highlight``.

    Raises:
        ValueError: If ``--highlight`` is set and the checkpoint filename does
            not end in ``_<number>.pth.tar``, so no lambda value can be
            parsed from it.
    """
    print(f"Loading VAE model from {args.checkpoint}...")
    model = load_checkpoint(args.checkpoint, device=args.device)
    model.eval()

    print(f"Loading segmenter '{args.seg_method}'...")
    segmenter = create_segmenter(
        args.seg_method,
        device=args.device,
        model_path=args.seg_model
    )

    print(f"Loading image from {args.input}...")
    image = Image.open(args.input).convert("RGB")

    print(f"Segmenting image for classes: {args.seg_classes}...")
    # The segmenter returns a pair; only the first element (the mask) is used.
    mask, _ = segmenter(image, target_classes=args.seg_classes)

    if args.viz_dir:
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(args.viz_dir, exist_ok=True)
        mask_path = os.path.join(args.viz_dir, "mask.png")
        # NOTE(review): assumes `mask` is a NumPy array with values in
        # [0, 1] so that scaling by 255 yields a valid 8-bit image - confirm.
        Image.fromarray((mask * 255).astype(np.uint8)).save(mask_path)
        print(f"Saved segmentation mask to {mask_path}")

    print(f"Compressing image with sigma={args.sigma}...")
    result = compress_image(
        image,
        mask,
        model,
        sigma=args.sigma,
        device=args.device
    )
    compressed_img = result['compressed']
    bpp = result['bpp']

    print(f"Saving compressed image to {args.output} (BPP: {bpp:.4f})")
    compressed_img.save(args.output)

    if args.highlight:
        print("Creating comparison grid...")
        # The lambda value is encoded as the last "_"-separated token of the
        # checkpoint filename, e.g. "tic_lambda_0.0483.pth.tar" -> 0.0483.
        lambda_token = os.path.basename(args.checkpoint).split('_')[-1].replace('.pth.tar', '')
        try:
            lambda_val = float(lambda_token)
        except ValueError as err:
            raise ValueError(
                f"Cannot parse a lambda value from checkpoint name "
                f"{os.path.basename(args.checkpoint)!r}; expected a name "
                f"ending in '_<lambda>.pth.tar'."
            ) from err
        grid = create_comparison_grid(image, compressed_img, mask, bpp, args.sigma, lambda_val)
        # Swap only the final extension. A plain str.replace() on the whole
        # path would also rewrite any directory component that happens to
        # contain the extension text, and misbehaves on an empty extension.
        grid_path = os.path.splitext(args.output)[0] + "_comparison.jpg"
        grid.save(grid_path)
        print(f"Saved comparison grid to {grid_path}")
|
|
def process_video(args):
    """Compress a video using static or dynamic settings.

    Builds a ``VideoProcessor``, loads its models, processes the input in
    chunks, collects every produced frame, re-encodes the frames with
    ``frames_to_video_bytes`` and writes the bytes to ``args.output``.

    Args:
        args: Parsed CLI namespace. Reads ``device``, ``quality_level``,
            ``seg_method``, ``detection_method``, ``enable_tracking``,
            ``video_mode``, ``sigma``, ``output_fps``, ``seg_classes``,
            ``target_bandwidth_kbps``, ``min_fps``, ``max_fps``,
            ``chunk_duration_sec``, ``input`` and ``output``.

    Raises:
        ValueError: In dynamic mode, if ``chunk_duration_sec`` is not
            positive (it is used as a divisor for the output frame rate).
    """
    # Imported at call time so that image-only runs never import tqdm.
    from tqdm import tqdm

    is_static = args.video_mode == 'static'

    # Fail before the expensive model loading rather than after every chunk
    # has been processed: the dynamic-mode FPS estimate divides by this value.
    if not is_static and args.chunk_duration_sec <= 0:
        raise ValueError(
            f"--chunk-duration-sec must be positive, got {args.chunk_duration_sec}"
        )

    processor = VideoProcessor(device=args.device)
    print("Loading models for video processing...")
    processor.load_models(
        quality_level=args.quality_level,
        segmentation_method=args.seg_method,
        detection_method=args.detection_method,
        enable_tracking=args.enable_tracking,
    )

    if is_static:
        settings = CompressionSettings(
            mode='static',
            quality_level=args.quality_level,
            sigma=args.sigma,
            output_fps=args.output_fps,
            target_classes=args.seg_classes,
        )
        print(f"Starting STATIC video compression with FPS={settings.output_fps}, Sigma={settings.sigma}...")
        chunks = processor.process_static(args.input, settings)
    else:
        settings = CompressionSettings(
            mode='dynamic',
            target_bandwidth_kbps=args.target_bandwidth_kbps,
            min_fps=args.min_fps,
            max_fps=args.max_fps,
            chunk_duration_sec=args.chunk_duration_sec,
            target_classes=args.seg_classes,
            quality_level=args.quality_level,
        )
        print(f"Starting DYNAMIC video compression with Target Bandwidth={settings.target_bandwidth_kbps} kbps...")
        chunks = processor.process_dynamic(args.input, settings)

    all_frames = []
    # Tracked explicitly instead of relying on the loop variable surviving
    # the loop.
    chunk_count = 0

    print("Processing video chunks...")

    with tqdm(desc="Compressing Chunks") as pbar:
        for index, chunk in enumerate(chunks):
            all_frames.extend(chunk.frames)
            chunk_count = index + 1
            pbar.update(1)
            pbar.set_postfix({
                "chunk": index,
                "frames": len(chunk.frames),
                "fps": f"{chunk.fps:.1f}",
                "bpp": f"{chunk.avg_bpp:.3f}"
            })

    if not all_frames:
        print("No frames were processed. Exiting.")
        return

    if is_static:
        final_fps = args.output_fps
    else:
        # Average frames per chunk divided by the nominal chunk duration.
        # NOTE(review): this assumes every chunk (including the last one)
        # spans exactly chunk_duration_sec - confirm against VideoProcessor.
        final_fps = len(all_frames) / args.chunk_duration_sec / chunk_count

    print(f"\nRe-encoding {len(all_frames)} frames into final video at ~{final_fps:.2f} FPS...")
    video_bytes = frames_to_video_bytes(all_frames, fps=final_fps)

    print(f"Saving compressed video to {args.output}...")
    with open(args.output, "wb") as f:
        f.write(video_bytes)
    print("Done.")
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|