Spaces:
Running
Running
| """ | |
| Utilities used by app.py. | |
| This is a Space-local subset of the project's `utils.py` — only the helpers | |
| needed for Stage 2 fusion (image I/O, decoder loading, PSNR). | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from torchvision import transforms | |
| from decoder import Decoder | |
def load_image(image_path: str, size: int = 224) -> torch.Tensor:
    """Load an image file as a ``(1, 3, size, size)`` float tensor in ``[0, 1]``.

    The image is converted to RGB and resized to exactly ``size x size``
    (aspect ratio is NOT preserved). ``ToTensor`` maps 8-bit pixel values to
    floats in ``[0, 1]``, and a leading batch dimension of 1 is added.

    Args:
        image_path: Path to the image file.
        size: Side length of the square output, in pixels.

    Returns:
        Tensor of shape ``(1, 3, size, size)``.
    """
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    # Use a context manager so the underlying file handle is always released;
    # for multi-frame formats Pillow keeps the file open after load().
    # convert() returns an independent, fully-loaded copy, so it stays valid
    # after the source image is closed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0)
def load_decoder(path: str, embed_dim: int = 512, device: torch.device = None) -> Decoder:
    """Build an AnyAttack ``Decoder`` and load checkpoint weights into it.

    The checkpoint may be either a dict holding the weights under
    ``"decoder_state_dict"`` or a bare state dict. Parameter names are
    rewritten before loading:

    * a leading ``"module."`` is stripped (presumably left by
      DataParallel/DDP wrapping -- confirm against the training code);
    * ``"upsample_blocks."`` becomes ``"blocks."``;
    * ``"final_conv."`` becomes ``"head."``.

    ``load_state_dict`` is called with its default strict matching, so any
    missing or unexpected key raises.

    Args:
        path: Filesystem path to the checkpoint.
        embed_dim: Embedding size passed to the ``Decoder`` constructor.
        device: Target device for the model; ``None`` leaves it where
            ``Decoder`` created it.

    Returns:
        The decoder in eval mode with the checkpoint weights loaded.
    """
    model = Decoder(embed_dim=embed_dim).to(device).eval()

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this at trusted checkpoints.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    raw_state = checkpoint.get("decoder_state_dict", checkpoint)

    # Renames are applied in order, as substring replacements anywhere in the key.
    renames = (
        ("upsample_blocks.", "blocks."),
        ("final_conv.", "head."),
    )

    def _translate(name: str) -> str:
        name = name.removeprefix("module.")
        for old, new in renames:
            name = name.replace(old, new)
        return name

    model.load_state_dict({_translate(name): tensor for name, tensor in raw_state.items()})
    return model
def compute_psnr(img1: torch.Tensor, img2: torch.Tensor, data_range: float = 1.0) -> float:
    """Compute the peak signal-to-noise ratio (dB) between two image tensors.

    PSNR = 10 * log10(data_range**2 / MSE), where MSE is the mean squared
    difference over all elements. With the default ``data_range=1.0`` this
    matches images scaled to ``[0, 1]``; pass ``data_range=255`` for 8-bit
    pixel values.

    Args:
        img1: First image tensor.
        img2: Second image tensor. Shapes should match; torch broadcasting
            is applied and is not otherwise checked.
        data_range: Peak possible pixel value (max - min of the value range).

    Returns:
        PSNR in decibels, or ``float("inf")`` when the images are identical.

    Raises:
        ValueError: If ``data_range`` is not positive.
    """
    import math

    if data_range <= 0:
        raise ValueError(f"data_range must be positive, got {data_range!r}")

    # Integer tensors (e.g. uint8) would wrap around on subtraction, and
    # torch.mean rejects integer dtypes, so promote them to float first.
    # Floating-point inputs are used as-is.
    a = img1 if img1.is_floating_point() else img1.float()
    b = img2 if img2.is_floating_point() else img2.float()
    mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        return float("inf")
    # Evaluate the log on the Python float (double precision). Splitting the
    # ratio into two log terms avoids overflow for very small MSE.
    return 10 * math.log10(data_range ** 2) - 10 * math.log10(mse)