|
|
""" |
|
|
Utility functions for image compression. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from typing import Tuple, Dict |
|
|
from .roi_tic import ModifiedTIC |
|
|
|
|
|
|
|
|
def compute_padding(in_h: int, in_w: int, min_div: int = 256) -> Tuple[Tuple[int, int, int, int], Tuple[int, int, int, int]]:
    """
    Compute padding to make dimensions divisible by min_div.

    Each dimension is rounded up to the next multiple of ``min_div`` and the
    extra pixels are split between the two sides; when the amount is odd,
    the right/bottom side receives the extra pixel.

    Args:
        in_h: input height (must be >= 0)
        in_w: input width (must be >= 0)
        min_div: minimum divisor (default 256 for TIC); must be >= 1

    Returns:
        pad: (left, right, top, bottom) padding
        unpad: the same four values negated, for cropping back

    Raises:
        ValueError: if ``min_div`` is not positive or a dimension is negative.
    """
    # A zero divisor would raise ZeroDivisionError and a negative one would
    # silently yield negative "padding"; reject both with a clear message.
    if min_div <= 0:
        raise ValueError(f"min_div must be a positive integer, got {min_div}")
    if in_h < 0 or in_w < 0:
        raise ValueError(
            f"in_h and in_w must be non-negative, got in_h={in_h}, in_w={in_w}"
        )

    # Round each dimension up to the nearest multiple of min_div.
    out_h = (in_h + min_div - 1) // min_div * min_div
    out_w = (in_w + min_div - 1) // min_div * min_div

    # Centre the original content inside the padded canvas.
    left = (out_w - in_w) // 2
    right = out_w - in_w - left
    top = (out_h - in_h) // 2
    bottom = out_h - in_h - top

    pad = (left, right, top, bottom)
    unpad = (-left, -right, -top, -bottom)

    return pad, unpad
|
|
|
|
|
|
|
|
def compress_image(
    image: Image.Image,
    mask: np.ndarray,
    model: ModifiedTIC,
    sigma: float = 0.3,
    device: str = 'cuda'
) -> Dict:
    """
    Compress image with ROI-based quality control.

    The image and mask are zero-padded so both spatial dimensions are
    multiples of 256, passed through ``model`` without gradient tracking,
    and the reconstruction is cropped back to the original size.

    Args:
        image: PIL Image; modes other than RGB are converted to RGB
        mask: Binary mask (H, W) with 1 for ROI, 0 for background; must
            match the image height and width
        model: Loaded ModifiedTIC model
        sigma: Background quality (0.01-1.0, lower = more compression)
        device: 'cuda' or 'cpu'

    Returns:
        dict with:
            - compressed: PIL Image of compressed result
            - bpp: Bits per pixel
            - original_size: Original image dimensions as (width, height)
            - mask_used: The mask that was used

    Raises:
        ValueError: if the image is empty or the mask shape does not match
            the image's (height, width).
    """
    # The tensor layout below assumes an H x W x 3 array; grayscale or RGBA
    # input would otherwise fail at permute() or carry the wrong channel count.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # HWC uint8 -> 1 x C x H x W float32 in [0, 1]
    img_array = np.array(image).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).to(device)

    _, _, h, w = img_tensor.shape
    if h == 0 or w == 0:
        # Also protects the bpp computation below from dividing by zero.
        raise ValueError("image must have non-zero width and height")
    if tuple(mask.shape) != (h, w):
        # The mask receives the same padding as the image, so any shape
        # mismatch would reach the model as misaligned inputs.
        raise ValueError(
            f"mask shape {tuple(mask.shape)} does not match image (H, W) = ({h}, {w})"
        )

    # Pad image and mask identically so they stay spatially aligned.
    pad, unpad = compute_padding(h, w, min_div=256)
    img_padded = F.pad(img_tensor, pad, mode='constant', value=0)

    # NOTE(review): the mask keeps its NumPy dtype here; confirm the model
    # accepts that dtype (a float mask may be required).
    mask_tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(device)
    mask_padded = F.pad(mask_tensor, pad, mode='constant', value=0)

    with torch.no_grad():
        out = model(img_padded, mask_padded, sigma=sigma)

    # Negative padding crops the reconstruction back to the original H x W.
    x_hat = F.pad(out['x_hat'], unpad)

    # 1 x C x H x W -> H x W x C, then quantise to 8 bits. Round to nearest
    # before casting: a bare astype(uint8) truncates and darkens the output.
    x_hat_np = x_hat.squeeze(0).permute(1, 2, 0).cpu().numpy()
    x_hat_np = np.clip(np.rint(x_hat_np * 255), 0, 255).astype(np.uint8)
    compressed_img = Image.fromarray(x_hat_np)

    # Bits per pixel: -sum(log2(likelihood)) for the 'y' and 'z' entries,
    # normalised by the original (unpadded) pixel count.
    num_pixels = h * w
    likelihoods = out['likelihoods']
    bpp_y = torch.log(likelihoods['y']).sum() / (-np.log(2) * num_pixels)
    bpp_z = torch.log(likelihoods['z']).sum() / (-np.log(2) * num_pixels)
    bpp = (bpp_y + bpp_z).item()

    return {
        'compressed': compressed_img,
        'bpp': bpp,
        'original_size': (w, h),
        'mask_used': mask
    }
|
|
|