""" Utilities used by app.py. This is a Space-local subset of the project's `utils.py` — only the helpers needed for Stage 2 fusion (image I/O, decoder loading, PSNR). """ import torch import torch.nn.functional as F from PIL import Image from torchvision import transforms from decoder import Decoder def load_image(image_path: str, size: int = 224) -> torch.Tensor: """Load an image as a (1, 3, H, W) tensor in [0, 1].""" img = Image.open(image_path).convert("RGB") transform = transforms.Compose([ transforms.Resize((size, size)), transforms.ToTensor(), ]) return transform(img).unsqueeze(0) def load_decoder(path: str, embed_dim: int = 512, device: torch.device = None) -> Decoder: """Load AnyAttack Decoder weights with state dict key remapping.""" decoder = Decoder(embed_dim=embed_dim).to(device).eval() ckpt = torch.load(path, map_location="cpu", weights_only=False) state = ckpt.get("decoder_state_dict", ckpt) remapped = {} for k, v in state.items(): k = k.removeprefix("module.") k = k.replace("upsample_blocks.", "blocks.") k = k.replace("final_conv.", "head.") remapped[k] = v decoder.load_state_dict(remapped) return decoder def compute_psnr(img1: torch.Tensor, img2: torch.Tensor) -> float: """Compute PSNR between two image tensors in [0, 1].""" mse = torch.mean((img1 - img2) ** 2).item() if mse == 0: return float("inf") return -10 * torch.log10(torch.tensor(mse)).item()