# NOTE(review): removed scrape artifact — the original page header read
# "Spaces: Runtime error" and was not part of the module source.
import torch
from PIL import Image
from torchvision import transforms

from .lpips import LPIPS
# Normalize image tensors
def normalize_tensor(images, norm_type):
    """Channel-normalize a batch of image tensors.

    Args:
        images: A list of CHW float tensors, or an NCHW batched tensor
            (3 channels, values in [0, 1] as produced by ``ToTensor``).
        norm_type: ``"imagenet"`` (standard ImageNet mean/std) or
            ``"naive"`` (maps [0, 1] to [-1, 1]).

    Returns:
        An NCHW tensor with each channel normalized as ``(x - mean) / std``.

    Raises:
        ValueError: If ``norm_type`` is not a supported convention.
    """
    # Two possible normalization conventions.
    if norm_type == "imagenet":
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    elif norm_type == "naive":
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    else:
        # Was `assert`, which is stripped under `python -O`; raise instead.
        raise ValueError(f"Unsupported norm_type: {norm_type!r}")
    if not torch.is_tensor(images):
        # Accept a list of per-image tensors, as the original stack-loop did.
        images = torch.stack(list(images))
    # Broadcast per-channel stats over H and W — equivalent to applying
    # transforms.Normalize(mean, std) to each image, but in one vectorized op.
    mean_t = torch.as_tensor(mean, dtype=images.dtype, device=images.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=images.dtype, device=images.device).view(-1, 1, 1)
    return (images - mean_t) / std_t
| def to_tensor(images, norm_type="naive"): | |
| assert isinstance(images, list) and all( | |
| [isinstance(image, Image.Image) for image in images] | |
| ) | |
| images = torch.stack([transforms.ToTensor()(image) for image in images]) | |
| if norm_type is not None: | |
| images = normalize_tensor(images, norm_type) | |
| return images | |
def load_perceptual_models(metric_name, mode, device=torch.device("cuda")):
    """Instantiate a perceptual-similarity model on the requested device.

    Args:
        metric_name: Only ``"lpips"`` is supported.
        mode: LPIPS backbone, ``"vgg"`` or ``"alex"``.
        device: Torch device the model is moved to (default CUDA).

    Returns:
        The perceptual model placed on ``device``.

    Raises:
        ValueError: If ``metric_name`` or ``mode`` is unsupported.
    """
    if metric_name == "lpips":
        # Was `assert mode in [...]`; assertions vanish under `python -O`.
        if mode not in ("vgg", "alex"):
            raise ValueError(f"Unsupported LPIPS mode: {mode!r}")
        perceptual_model = LPIPS(net=mode).to(device)
    else:
        raise ValueError(f"Unsupported metric: {metric_name!r}")
    return perceptual_model
# Compute metric between two images
def compute_metric(image1, image2, perceptual_model, device=torch.device("cuda")):
    """Return the scalar perceptual distance between two PIL images.

    Each image is converted to a 1-element batch (default "naive"
    normalization), moved to ``device``, and scored by ``perceptual_model``.
    """
    assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image)
    batch1, batch2 = (to_tensor([img]).to(device) for img in (image1, image2))
    distance = perceptual_model(batch1, batch2)
    return distance.cpu().item()
# Compute LPIPS distance between two images
def compute_lpips(image1, image2, mode="alex", device=torch.device("cuda")):
    """One-shot LPIPS: build the model for ``mode`` and score a single pair."""
    model = load_perceptual_models("lpips", mode, device)
    return compute_metric(image1, image2, model, device)
# Compute metrics between pairs of images
def compute_perceptual_metric_repeated(
    images1,
    images2,
    metric_name,
    mode,
    model,
    device,
):
    """Score each aligned pair (images1[i], images2[i]) with a perceptual metric.

    Args:
        images1: List of PIL images (first element of each pair).
        images2: List of PIL images, same length as ``images1``.
        metric_name: Metric to load when ``model`` is None (e.g. ``"lpips"``).
        mode: Backbone passed to the model loader (e.g. ``"alex"``).
        model: Pre-built perceptual model, or None to load one here.
        device: Torch device for tensors and the model.

    Returns:
        A list of float distances, one per image pair.
    """
    # Accept lists of PIL images; `all(...)` also tolerates empty lists,
    # unlike indexing [0] which raised IndexError.
    assert isinstance(images1, list) and all(isinstance(im, Image.Image) for im in images1)
    assert isinstance(images2, list) and all(isinstance(im, Image.Image) for im in images2)
    assert len(images1) == len(images2)
    if model is None:
        # Bug fix: pass `device` to the loader. The old call used the loader's
        # default CUDA device and then moved the model, which fails outright
        # on CPU-only hosts.
        model = load_perceptual_models(metric_name, mode, device)
    distances = model(to_tensor(images1).to(device), to_tensor(images2).to(device))
    return distances.detach().cpu().numpy().flatten().tolist()
# Compute LPIPS distance between pairs of images
def compute_lpips_repeated(
    images1,
    images2,
    mode="alex",
    model=None,
    device=torch.device("cuda"),
):
    """LPIPS for every aligned pair; loads a model when ``model`` is None."""
    return compute_perceptual_metric_repeated(
        images1,
        images2,
        metric_name="lpips",
        mode=mode,
        model=model,
        device=device,
    )