| from PIL import Image |
| import os |
| import numpy as np |
| from torchvision.transforms import functional as F |
| import torch |
| from torchmetrics.image.fid import FrechetInceptionDistance |
|
|
|
|
| |
# Input locations: the folder of generated try-on outputs to score, and the
# folder holding the reference images they are paired with by file name.
generated_dataset_path, original_dataset_path = (
    "output/tryon_results",
    "data/VITON-HD/test/image",
)
|
|
| |
# File suffixes treated as images when scanning the generated-output folder.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

# Collect only real image files, in a deterministic (sorted) order. os.listdir
# also returns hidden files (e.g. ".DS_Store"), sub-directories and logs, any
# of which would make Image.open below fail.
image_paths = sorted(
    os.path.join(generated_dataset_path, name)
    for name in os.listdir(generated_dataset_path)
    if not name.startswith(".")
    and name.lower().endswith(_IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(generated_dataset_path, name))
)

# Decode every generated image as an (H, W, 3) uint8 RGB array. Image.open is
# lazy and keeps the file handle open, so the with-block closes it as soon as
# the pixels have been copied out.
generated_images = []
for gen_file in image_paths:
    with Image.open(gen_file) as img:
        generated_images.append(np.array(img.convert("RGB")))
|
|
| |
# Generated files are expected to be named "<prefix><original file name>".
_TRYON_PREFIX = "tryon_"

# Load the reference image paired with each generated image, in the same
# order as image_paths, so index i of both lists refers to the same sample.
original_images = []
for gen_path in image_paths:
    base_name = os.path.basename(gen_path)
    # Strip only the leading prefix: str.replace would also delete "tryon_"
    # if it happened to appear elsewhere in the file name.
    if base_name.startswith(_TRYON_PREFIX):
        original_id = base_name[len(_TRYON_PREFIX):]
    else:
        original_id = base_name

    original_path = os.path.join(original_dataset_path, original_id)
    if not os.path.isfile(original_path):
        raise FileNotFoundError(
            f"No original image for {gen_path!r}: expected {original_path!r}"
        )
    # Image.open is lazy; the with-block closes the file once decoded.
    with Image.open(original_path) as img:
        original_images.append(np.array(img.convert("RGB")))
|
|
|
|
def preprocess_image(image, size=(1024, 768)):
    """Convert one RGB image array into a (1, 3, h, w) float tensor in [0, 1].

    The result is center-cropped to ``size``. torchvision zero-pads any edge
    that is smaller than the target, so every image fed to the FID metric has
    the same spatial shape.

    Args:
        image: RGB image of shape (H, W, 3) with 0-255 values (uint8 array).
        size: Target ``(height, width)``. torchvision's ``center_crop`` expects
            height first. The default assumes VITON-HD's 1024x768 portrait
            frames (TODO confirm against the images actually being scored).

    Returns:
        Float tensor of shape ``(1, 3, size[0], size[1])`` with values in [0, 1].
    """
    # HWC -> CHW, then add a batch dimension of 1 so results can be torch.cat'ed.
    tensor = torch.as_tensor(image).permute(2, 0, 1).unsqueeze(0)
    # Scale to [0, 1], the range expected when the metric uses normalize=True.
    tensor = tensor / 255.0
    # (height, width) order matters: a width-first tuple such as (768, 1024)
    # would crop the top/bottom of a 1024x768 portrait and pad black bars on
    # both sides, distorting the Inception features.
    return F.center_crop(tensor, size)
|
|
# Number of images pushed through the Inception network per update() call.
FID_BATCH_SIZE = 32

# FID needs at least two samples per distribution (it estimates a covariance),
# so fail early with a clear message rather than deep inside compute().
if len(original_images) < 2 or len(generated_images) < 2:
    raise ValueError(
        f"FID needs at least 2 images per set, got {len(original_images)} real "
        f"and {len(generated_images)} generated"
    )

fid = FrechetInceptionDistance(normalize=True)

# Stream images into the metric in fixed-size batches. Concatenating the whole
# split into one float32 tensor costs ~9 MB per image and forces a single huge
# Inception forward pass. The metric only accumulates feature statistics
# across update() calls, so batching does not change the result.
for is_real, images in ((True, original_images), (False, generated_images)):
    for start in range(0, len(images), FID_BATCH_SIZE):
        batch = torch.cat(
            [preprocess_image(image) for image in images[start:start + FID_BATCH_SIZE]]
        )
        fid.update(batch, real=is_real)

print(f"real images: {len(original_images)}, fake images: {len(generated_images)}")
print(f"FID: {float(fid.compute())}")