import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision.transforms.functional import to_pil_image, to_tensor as tv_to_tensor

from src.Device import Device


def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert tensor [B,H,W,C] or [H,W,C] to PIL image using torchvision."""
    if image.dim() == 4:
        image = image[0]  # Take first from batch
    # HWC -> CHW for torchvision (only when the trailing dim is a channel count)
    if image.shape[-1] in [1, 3, 4]:
        image = image.permute(2, 0, 1)
    return to_pil_image(torch.clamp(image, 0, 1))


def general_tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """Resize tensor using bilinear interpolation. Expects [B,H,W,C].

    Returns a new [B,h,w,C] tensor; interpolate works on NCHW, hence the
    permutes on the way in and out.
    """
    image = image.permute(0, 3, 1, 2)
    image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear")
    return image.permute(0, 2, 3, 1)


def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert PIL image to tensor [1,H,W,C] using torchvision.

    Values are scaled to [0, 1] float by torchvision's to_tensor.
    """
    return tv_to_tensor(image).unsqueeze(0).permute(0, 2, 3, 1)


class TensorBatchBuilder:
    """Utility for building a batch of tensors by concatenation."""

    def __init__(self):
        # Accumulated batch; None until the first concat.
        self.tensor: torch.Tensor | None = None

    def concat(self, new_tensor: torch.Tensor) -> None:
        """Append new_tensor along the batch dimension (dim 0)."""
        self.tensor = new_tensor if self.tensor is None else torch.cat([self.tensor, new_tensor], dim=0)


LANCZOS = Image.Resampling.LANCZOS


def tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """Resize tensor [B,H,W,C] using LANCZOS (3+ channels) or bilinear.

    LANCZOS is only available via PIL, so RGB(A) batches round-trip through
    PIL image by image; fewer-channel tensors fall back to bilinear.
    """
    if image.shape[3] >= 3:
        scaled = TensorBatchBuilder()
        for single in image:
            pil = tensor2pil(single.unsqueeze(0))
            scaled.concat(pil2tensor(pil.resize((w, h), resample=LANCZOS)))
        return scaled.tensor
    return general_tensor_resize(image, w, h)


def tensor_paste(
    image1: torch.Tensor,
    image2: torch.Tensor,
    left_top: tuple[int, int],
    mask: torch.Tensor,
) -> None:
    """Paste image2 onto image1 at left_top position using mask.

    Modifies image1 in place; mask values blend image2 over image1.
    The pasted region is clipped to image1's bounds.
    """
    x, y = [int(round(c)) for c in left_top]
    _, h1, w1, _ = image1.shape
    _, h2, w2, _ = image2.shape
    w = min(w1, x + w2) - x
    h = min(h1, y + h2) - y
    # Fix: a left_top outside image1 previously yielded negative slice
    # extents and a silently broken paste; skip when the overlap is empty.
    if w <= 0 or h <= 0:
        return
    # Ensure all tensors are on the same device as image1
    device = image1.device
    mask = mask[:, :h, :w, :].to(device)
    image2 = image2[:, :h, :w, :].to(device)
    image1[:, y:y+h, x:x+w, :] = (1 - mask) * image1[:, y:y+h, x:x+w, :] + mask * image2


def tensor_convert_rgba(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
    """Add alpha channel (ones) to tensor.

    Fix: the alpha plane is created on image's device with image's dtype;
    torch.ones previously defaulted to CPU/float32, making the concat fail
    for CUDA or non-float32 inputs. (prefer_copy is kept for API
    compatibility; the result is always a new tensor.)
    """
    alpha = torch.ones((*image.shape[:-1], 1), device=image.device, dtype=image.dtype)
    return torch.cat((image, alpha), dim=-1)


def tensor_convert_rgb(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
    """Return image unchanged (already RGB)."""
    return image


def tensor_get_size(image: torch.Tensor) -> tuple[int, int]:
    """Return (width, height) of tensor [B,H,W,C]."""
    _, h, w, _ = image.shape
    return (w, h)


def tensor_putalpha(image: torch.Tensor, mask: torch.Tensor) -> None:
    """Set alpha channel from mask, in place.

    Assumes image's last channel is alpha and mask is [..., 1] —
    NOTE(review): callers must guarantee image already has an alpha channel.
    """
    image[..., -1] = mask[..., 0]


def tensor_gaussian_blur_mask(
    mask: torch.Tensor | np.ndarray,
    kernel_size: int,
    sigma: float = 10.0,
) -> torch.Tensor:
    """Apply Gaussian blur to mask using torchvision.

    Accepts a [H,W] or [B,H,W,1] mask (tensor or ndarray) and returns a
    blurred [B,H,W,1] tensor on the project device. kernel_size*2+1
    guarantees the odd kernel size GaussianBlur requires.
    """
    if isinstance(mask, np.ndarray):
        mask = torch.from_numpy(mask)
    if mask.ndim == 2:
        mask = mask[None, ..., None]  # [H,W] -> [1,H,W,1]
    device = Device.get_torch_device()
    mask = mask[:, None, ..., 0].to(device)  # [B,H,W,1] -> [B,1,H,W] for GaussianBlur
    blurred = torchvision.transforms.GaussianBlur(kernel_size=kernel_size*2+1, sigma=sigma)(mask)
    return blurred[:, 0, ..., None]


def to_tensor(image: np.ndarray) -> torch.Tensor:
    """Convert numpy array to tensor (shares memory with the array)."""
    return torch.from_numpy(image)