Raheeb Hassan
Add code + LFS attributes
398659b
"""
Utility functions for image compression.
"""
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from typing import Tuple, Dict
from .roi_tic import ModifiedTIC
def compute_padding(in_h: int, in_w: int, min_div: int = 256) -> Tuple[Tuple[int, int, int, int], Tuple[int, int, int, int]]:
    """
    Compute padding to make dimensions divisible by min_div.

    Each dimension is rounded up to the next multiple of min_div. The extra
    pixels are split between the two sides of that dimension; when the extra
    amount is odd, the right / bottom side receives the additional pixel.

    Args:
        in_h: input height (must be >= 0)
        in_w: input width (must be >= 0)
        min_div: minimum divisor (default 256 for TIC), must be >= 1

    Returns:
        pad: (left, right, top, bottom) padding
        unpad: negative padding for cropping back

    Raises:
        ValueError: if min_div is not positive or a dimension is negative
    """
    # Reject inputs that would otherwise divide by zero or silently produce
    # negative "padding" (which a padding routine would treat as cropping).
    if min_div <= 0:
        raise ValueError(f"min_div must be a positive integer, got {min_div}")
    if in_h < 0 or in_w < 0:
        raise ValueError(
            f"image dimensions must be non-negative, got ({in_h}, {in_w})"
        )

    # Distance from each dimension up to the next multiple of min_div
    # (0 when the dimension is already a multiple).
    extra_h = -in_h % min_div
    extra_w = -in_w % min_div

    left = extra_w // 2
    right = extra_w - left
    top = extra_h // 2
    bottom = extra_h - top

    pad = (left, right, top, bottom)
    unpad = (-left, -right, -top, -bottom)
    return pad, unpad
def compress_image(
    image: Image.Image,
    mask: np.ndarray,
    model: ModifiedTIC,
    sigma: float = 0.3,
    device: str = 'cuda'
) -> Dict:
    """
    Compress image with ROI-based quality control.

    The image is zero-padded so both sides are multiples of 256, run through
    the model together with the identically padded mask, and the
    reconstruction is cropped back to the original size.

    Args:
        image: PIL Image; converted to RGB if it is in another mode
        mask: Binary mask (H, W) with 1 for ROI, 0 for background; must
            match the image height and width
        model: Loaded ModifiedTIC model
        sigma: Background quality (0.01-1.0, lower = more compression)
        device: 'cuda' or 'cpu'

    Returns:
        dict with:
            - compressed: PIL Image of compressed result
            - bpp: Bits per pixel, normalised by the original (unpadded)
              pixel count
            - original_size: Original image dimensions as (width, height)
            - mask_used: The mask that was used

    Raises:
        ValueError: if the mask shape does not match the image (H, W)
    """
    # Force three channels: grayscale / palette images would give a 2-D array
    # (breaking the permute below) and RGBA would give a 4-channel tensor.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Convert image to a (1, C, H, W) float tensor in [0, 1]
    img_array = np.array(image).astype(np.float32) / 255.0
    h, w = img_array.shape[:2]

    # The mask is padded with the image's padding, so it must share the
    # image's spatial size; fail early with a clear message otherwise.
    if tuple(mask.shape) != (h, w):
        raise ValueError(
            f"mask shape {tuple(mask.shape)} does not match image size (H, W) = {(h, w)}"
        )

    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).to(device)

    # Pad image
    pad, unpad = compute_padding(h, w, min_div=256)
    img_padded = F.pad(img_tensor, pad, mode='constant', value=0)

    # Prepare mask as (1, 1, H, W); padded area is treated as background (0).
    # NOTE(review): the mask dtype is passed through unchanged - confirm that
    # ModifiedTIC accepts the dtype callers supply.
    mask_tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(device)
    mask_padded = F.pad(mask_tensor, pad, mode='constant', value=0)

    # Compress
    with torch.no_grad():
        # NOTE: `ModifiedTIC.forward()` handles mask downsampling internally.
        out = model(img_padded, mask_padded, sigma=sigma)

    # Unpad result (negative padding crops back to the original H x W)
    x_hat = F.pad(out['x_hat'], unpad)

    # Convert back to image
    x_hat_np = x_hat.squeeze(0).permute(1, 2, 0).cpu().numpy()
    x_hat_np = np.clip(x_hat_np * 255, 0, 255).astype(np.uint8)
    compressed_img = Image.fromarray(x_hat_np)

    # Calculate BPP: -sum(log2(likelihood)) / pixels, using the original
    # pixel count so padding does not dilute the rate.
    num_pixels = h * w
    likelihoods = out['likelihoods']
    bpp_y = torch.log(likelihoods['y']).sum() / (-np.log(2) * num_pixels)
    bpp_z = torch.log(likelihoods['z']).sum() / (-np.log(2) * num_pixels)
    bpp = (bpp_y + bpp_z).item()

    return {
        'compressed': compressed_img,
        'bpp': bpp,
        'original_size': (w, h),
        'mask_used': mask
    }