from __future__ import annotations from dataclasses import dataclass from typing import Tuple import numpy as np import torch import torch.nn.functional as F from PIL import Image, ImageOps @dataclass(frozen=True) class PreparedImage: full_pil: Image.Image model_pil: Image.Image full_tensor: torch.Tensor model_tensor: torch.Tensor def normalize_pil(image: Image.Image) -> Image.Image: image = ImageOps.exif_transpose(image) return image.convert("RGB") def resize_max_side(image: Image.Image, max_side: int) -> Image.Image: max_side = int(max(256, min(max_side, 2048))) width, height = image.size longest = max(width, height) if longest <= max_side: return image.copy() scale = max_side / float(longest) new_size = (max(1, round(width * scale)), max(1, round(height * scale))) return image.resize(new_size, Image.Resampling.LANCZOS) def pil_to_tensor(image: Image.Image) -> torch.Tensor: arr = np.asarray(image, dtype=np.float32) / 255.0 tensor = torch.from_numpy(arr).permute(2, 0, 1).contiguous() return tensor.unsqueeze(0) def tensor_to_pil(tensor: torch.Tensor) -> Image.Image: tensor = tensor.detach().float().cpu().clamp(0, 1) if tensor.ndim == 4: tensor = tensor[0] arr = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype(np.uint8) return Image.fromarray(arr, mode="RGB") def prepare_image(image: Image.Image, max_side: int, model_size: int = 512) -> PreparedImage: full_pil = resize_max_side(normalize_pil(image), max_side) model_pil = full_pil.resize((model_size, model_size), Image.Resampling.BICUBIC) return PreparedImage( full_pil=full_pil, model_pil=model_pil, full_tensor=pil_to_tensor(full_pil), model_tensor=pil_to_tensor(model_pil), ) def match_tensor_size(tensor: torch.Tensor, target_hw: Tuple[int, int]) -> torch.Tensor: if tuple(tensor.shape[-2:]) == tuple(target_hw): return tensor return F.interpolate(tensor, size=target_hw, mode="bilinear", align_corners=False) def blend_strength(input_tensor: torch.Tensor, output_tensor: torch.Tensor, strength: float) -> torch.Tensor: strength = float(max(0.0, min(2.0, strength))) output_tensor = match_tensor_size(output_tensor, input_tensor.shape[-2:]) return (input_tensor + strength * (output_tensor - input_tensor)).clamp(0, 1)