File size: 3,076 Bytes
398659b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
Utility functions for image compression.
"""

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from typing import Tuple, Dict
from .roi_tic import ModifiedTIC


def compute_padding(in_h: int, in_w: int, min_div: int = 256) -> Tuple[Tuple[int, int, int, int], Tuple[int, int, int, int]]:
    """
    Compute symmetric padding that rounds (in_h, in_w) up to multiples of min_div.

    Each axis is padded up to the smallest multiple of ``min_div`` that is
    >= the input size. The extra pixels are split as evenly as possible;
    when the count is odd, the spare pixel goes on the right/bottom side.

    Args:
        in_h: input height
        in_w: input width
        min_div: divisor both padded dimensions must be a multiple of
            (default 256 for TIC)

    Returns:
        pad: (left, right, top, bottom) padding, i.e. last dimension first,
            as ``torch.nn.functional.pad`` takes it for the last two dims
        unpad: the same four amounts negated; passing it to ``F.pad`` crops
            a padded tensor back to the original size
    """
    def split(size: int) -> Tuple[int, int]:
        # Round `size` up to the next multiple of min_div, then halve the
        # surplus (floor on the leading side, remainder on the trailing side).
        target = (size + min_div - 1) // min_div * min_div
        surplus = target - size
        leading = surplus // 2
        return leading, surplus - leading

    top, bottom = split(in_h)
    left, right = split(in_w)

    pad = (left, right, top, bottom)
    return pad, tuple(-amount for amount in pad)


def compress_image(
    image: Image.Image,
    mask: np.ndarray,
    model: ModifiedTIC,
    sigma: float = 0.3,
    device: str = 'cuda'
) -> Dict:
    """
    Compress image with ROI-based quality control.

    The image and mask are zero-padded (via ``compute_padding``) so both
    spatial dimensions are multiples of 256. They are then run through
    ``model`` with gradient tracking disabled, and the reconstruction is
    cropped back to the original size.

    Args:
        image: PIL Image. It is converted to RGB first, so grayscale and
            palette images are accepted; an alpha channel is discarded.
        mask: 2-D array (H, W) with 1 for ROI, 0 for background. It is cast
            to float32 before being handed to the model.
        model: Loaded ModifiedTIC model
        sigma: Background quality (0.01-1.0, lower = more compression);
            forwarded to ``model`` unchanged.
        device: 'cuda' or 'cpu'

    Returns:
        dict with:
            - compressed: PIL Image (RGB) of the reconstructed result
            - bpp: Estimated bits per pixel from the model's 'y' and 'z'
              likelihoods, normalised by the original (unpadded) H * W
            - original_size: Original image dimensions as (width, height)
            - mask_used: The ``mask`` argument, exactly as passed in

    Raises:
        ValueError: if ``mask`` is not two-dimensional.
    """
    mask_arr = np.asarray(mask)
    if mask_arr.ndim != 2:
        # Two unsqueezes below assume (H, W); anything else would silently
        # reach the model as a 5-D (or 3-D) tensor.
        raise ValueError(f"mask must be a 2-D (H, W) array, got shape {mask_arr.shape}")

    # Convert image to a (1, 3, H, W) float32 tensor in [0, 1]. Forcing RGB
    # guarantees a trailing channel axis of size 3: grayscale ("L") images
    # would otherwise give a 2-D array and break the permute.
    img_array = np.asarray(image.convert('RGB'), dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).to(device)

    # Pad image
    _, _, h, w = img_tensor.shape
    pad, unpad = compute_padding(h, w, min_div=256)
    img_padded = F.pad(img_tensor, pad, mode='constant', value=0)

    # Prepare mask as (1, 1, H, W) float32, matching the image tensor's dtype
    # whatever the caller's dtype (bool / uint8 / float64). The contiguous
    # copy also makes negatively-strided arrays (e.g. from np.flip)
    # acceptable to torch.from_numpy.
    mask_f32 = np.ascontiguousarray(mask_arr, dtype=np.float32)
    mask_tensor = torch.from_numpy(mask_f32).unsqueeze(0).unsqueeze(0).to(device)
    mask_padded = F.pad(mask_tensor, pad, mode='constant', value=0)

    # Compress
    with torch.no_grad():
        # NOTE: `ModifiedTIC.forward()` handles mask downsampling internally.
        out = model(img_padded, mask_padded, sigma=sigma)

    # Negative padding amounts crop the reconstruction back to (H, W).
    x_hat = F.pad(out['x_hat'], unpad)

    # Convert back to an 8-bit image. Round instead of truncating: a bare
    # astype(uint8) floors every value, biasing pixels down by ~0.5 LSB.
    x_hat_np = x_hat.squeeze(0).permute(1, 2, 0).cpu().numpy()
    x_hat_np = np.clip(np.rint(x_hat_np * 255.0), 0, 255).astype(np.uint8)
    compressed_img = Image.fromarray(x_hat_np)

    # Calculate BPP: -log2(likelihood) summed over both latents, divided by
    # the number of pixels in the original (unpadded) image.
    likelihoods = out['likelihoods']
    denom = -np.log(2) * (h * w)
    bpp = sum(torch.log(likelihoods[key]).sum() / denom for key in ('y', 'z')).item()

    return {
        'compressed': compressed_img,
        'bpp': bpp,
        'original_size': (w, h),
        'mask_used': mask
    }