"""
Utility functions for image compression.
"""
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from typing import Tuple, Dict

from .roi_tic import ModifiedTIC


def compute_padding(
    in_h: int, in_w: int, min_div: int = 256
) -> Tuple[Tuple[int, int, int, int], Tuple[int, int, int, int]]:
    """
    Compute padding to make dimensions divisible by min_div.

    Each dimension is rounded up to the next multiple of ``min_div``. The
    extra pixels are split as evenly as possible between the two sides; when
    the amount is odd, the extra pixel goes to the right/bottom side.

    Args:
        in_h: input height
        in_w: input width
        min_div: minimum divisor (default 256 for TIC)

    Returns:
        pad: (left, right, top, bottom) padding, the order that
            ``torch.nn.functional.pad`` uses for the last two dimensions
        unpad: the same amounts negated; passing it to ``F.pad`` crops a
            padded tensor back to ``(in_h, in_w)``
    """
    out_h = (in_h + min_div - 1) // min_div * min_div
    out_w = (in_w + min_div - 1) // min_div * min_div
    left = (out_w - in_w) // 2
    right = out_w - in_w - left
    top = (out_h - in_h) // 2
    bottom = out_h - in_h - top
    pad = (left, right, top, bottom)
    unpad = (-left, -right, -top, -bottom)
    return pad, unpad


def _tensor_to_image(x: torch.Tensor) -> Image.Image:
    """Convert a (1, C, H, W) tensor with values nominally in [0, 1] to a PIL image."""
    arr = x.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    # Round to the nearest level instead of truncating; truncation biases
    # every pixel downward by up to one intensity level.
    arr = np.clip(np.rint(arr * 255.0), 0, 255).astype(np.uint8)
    return Image.fromarray(arr)


def compress_image(
    image: Image.Image,
    mask: np.ndarray,
    model: ModifiedTIC,
    sigma: float = 0.3,
    device: str = 'cuda'
) -> Dict:
    """
    Compress image with ROI-based quality control.

    The image and mask are zero-padded so both spatial dimensions are
    multiples of 256. Both are passed through ``model``, and the
    reconstruction is cropped back to the original size.

    Args:
        image: PIL Image. Non-RGB modes (e.g. 'L', 'RGBA', 'P') are
            converted to RGB first.
        mask: Binary mask (H, W) with 1 for ROI, 0 for background. It must
            match the image's height and width. Any numeric or bool dtype is
            accepted and converted to float32.
        model: Loaded ModifiedTIC model
        sigma: Background quality (0.01-1.0, lower = more compression)
        device: 'cuda' or 'cpu'

    Returns:
        dict with:
            - compressed: PIL Image of compressed result
            - bpp: Bits per pixel, estimated from the 'y' and 'z'
              likelihoods and normalised by the *unpadded* pixel count
            - original_size: Original image dimensions as (width, height)
            - mask_used: The mask that was used (the object passed in)

    Raises:
        ValueError: if ``mask`` is not shaped (H, W) matching ``image``.
    """
    # Convert image to tensor. Forcing RGB guarantees an (H, W, 3) array;
    # grayscale/palette images would otherwise give a 2-D array that cannot
    # be permuted to CHW, and RGBA would add a fourth channel.
    img_array = np.asarray(image.convert('RGB'), dtype=np.float32) / 255.0
    h, w = img_array.shape[:2]
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).to(device)

    # The mask is padded with the image's padding amounts, so it must have
    # exactly the image's spatial size or the ROI ends up misaligned.
    mask_array = np.asarray(mask)
    if mask_array.shape != (h, w):
        raise ValueError(
            f"mask shape {mask_array.shape} does not match image size (H, W) = {(h, w)}"
        )

    # Pad image
    pad, unpad = compute_padding(h, w, min_div=256)
    img_padded = F.pad(img_tensor, pad, mode='constant', value=0)

    # Prepare mask: a contiguous float32 copy matches the image tensor's dtype
    # and avoids torch.from_numpy failing on negative-stride (flipped) arrays.
    mask_float = np.ascontiguousarray(mask_array, dtype=np.float32)
    mask_tensor = torch.from_numpy(mask_float).unsqueeze(0).unsqueeze(0).to(device)
    mask_padded = F.pad(mask_tensor, pad, mode='constant', value=0)

    # Compress
    with torch.no_grad():
        # NOTE: `ModifiedTIC.forward()` handles mask downsampling internally.
        out = model(img_padded, mask_padded, sigma=sigma)

    # Unpad result (negative padding crops), then convert back to an image
    x_hat = F.pad(out['x_hat'], unpad)
    compressed_img = _tensor_to_image(x_hat)

    # Calculate BPP: total bits = -sum(log2(p)) over both latents, divided by
    # the original (unpadded) pixel count.
    num_pixels = h * w
    likelihoods = out['likelihoods']
    bpp_y = torch.log(likelihoods['y']).sum() / (-np.log(2) * num_pixels)
    bpp_z = torch.log(likelihoods['z']).sum() / (-np.log(2) * num_pixels)
    bpp = (bpp_y + bpp_z).item()

    return {
        'compressed': compressed_img,
        'bpp': bpp,
        'original_size': (w, h),
        'mask_used': mask
    }