Spaces:
Running
Running
File size: 1,525 Bytes
e1887f1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 | """
Utilities used by app.py.
This is a Space-local subset of the project's `utils.py` — only the helpers
needed for Stage 2 fusion (image I/O, decoder loading, PSNR).
"""
import math

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from decoder import Decoder
def load_image(image_path: str, size: int = 224) -> torch.Tensor:
    """Load an image file as a ``(1, 3, size, size)`` float tensor in ``[0, 1]``.

    Args:
        image_path: Path to an image file that Pillow can open.
        size: Side length the image is resized to. Both height and width are
            set to ``size``, so the aspect ratio is not preserved.

    Returns:
        A batch-of-one RGB tensor produced by ``transforms.ToTensor``.
    """
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    # Open under a context manager so the file handle is always released.
    # Pillow keeps the file open after load() for multi-frame formats.
    # convert() returns an independent, fully loaded copy, so it remains
    # usable after the source image is closed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0)
def load_decoder(
    path: str,
    embed_dim: int = 512,
    device: "torch.device | str | None" = None,
) -> Decoder:
    """Load AnyAttack Decoder weights, remapping state-dict keys to this Decoder.

    Args:
        path: Checkpoint file readable by ``torch.load``. It may be either a
            bare state dict or a dict holding the weights under
            ``"decoder_state_dict"``.
        embed_dim: Forwarded to ``Decoder(embed_dim=...)``.
        device: Target passed to ``Decoder.to``. ``None`` leaves the module on
            the device it was constructed on.

    Returns:
        The Decoder, in eval mode, with the checkpoint weights loaded.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict mode) when the remapped
            keys or tensor shapes do not match the Decoder.
    """
    decoder = Decoder(embed_dim=embed_dim).to(device).eval()
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, i.e. run code embedded in the file. Only point this at
    # trusted checkpoints; consider weights_only=True if the file holds only
    # tensors and plain containers.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    state = ckpt.get("decoder_state_dict", ckpt)

    def _remap(key: str) -> str:
        # Drop a leading "module." (presumably from a DataParallel/DDP
        # wrapper), then rename submodules from the checkpoint's naming
        # to this Decoder's attribute names.
        key = key.removeprefix("module.")
        key = key.replace("upsample_blocks.", "blocks.")
        return key.replace("final_conv.", "head.")

    decoder.load_state_dict({_remap(k): v for k, v in state.items()})
    return decoder
def compute_psnr(img1: torch.Tensor, img2: torch.Tensor, max_val: float = 1.0) -> float:
    """Compute the peak signal-to-noise ratio (in dB) between two image tensors.

    PSNR = 10 * log10(max_val**2 / MSE). It is evaluated here as
    20 * log10(max_val) - 10 * log10(MSE), which avoids overflowing the ratio
    for tiny errors.

    Args:
        img1: First image tensor.
        img2: Second image tensor. The squared error is averaged over the
            broadcast of ``img1 - img2``, so shapes should normally match.
        max_val: Largest possible pixel value. The default 1.0 suits tensors
            in ``[0, 1]``; pass 255.0 for 8-bit images.

    Returns:
        PSNR in decibels, or ``float("inf")`` when the images are identical.

    Raises:
        ValueError: If ``max_val`` is not positive.
    """
    if max_val <= 0:
        raise ValueError(f"max_val must be positive, got {max_val!r}")
    # Integer tensors (e.g. uint8) would wrap around on subtraction and are
    # rejected by torch.mean, so promote them to float before differencing.
    # Floating inputs are left at their own precision.
    a = img1 if img1.is_floating_point() else img1.float()
    b = img2 if img2.is_floating_point() else img2.float()
    mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        return float("inf")
    # math.log10 operates on the Python float in double precision, avoiding
    # a round-trip through a float32 tensor.
    return 20.0 * math.log10(max_val) - 10.0 * math.log10(mse)
|