# Modified from https://github.com/Zheng-Chong/CatVTON/blob/edited/eval.py
"""Evaluate predicted images against ground-truth images.

Unpaired metrics (FID, KID) are always computed between the two folders.
With ``--paired``, SSIM, LPIPS and DISTS are also computed on image pairs
matched by a numeric ID in their filenames. Results are printed as a table
and can optionally be appended to a CSV file.
"""
import csv
import os

import torch
from cleanfid import fid as FID
from DISTS_pytorch import DISTS
from PIL import Image
from prettytable import PrettyTable
from torch.utils.data import Dataset
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchvision import transforms
from tqdm import tqdm

from module.utils import scan_files_in_dir

# File extensions treated as images throughout this script.
IMAGE_EXTENSIONS = ('.jpg', '.png')


def _default_device():
    """Return "cuda" when a GPU is available, otherwise "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def _list_images(folder):
    """Return the sorted names of image files directly inside *folder*."""
    return sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(folder, name))
    )


class EvalDataset(Dataset):
    """Paired (ground truth, prediction) images matched by a filename ID.

    Each item is a ``(gt, pred)`` pair of tensors produced by
    ``transforms.ToTensor`` from RGB images resized to ``height`` pixels
    (aspect ratio preserved).

    NOTE(review): paired metrics need both tensors to have the same width,
    i.e. GT and prediction must share an aspect ratio.
    """

    def __init__(self, gt_folder, pred_folder, height=1024):
        self.gt_folder = gt_folder
        self.pred_folder = pred_folder
        self.height = height
        self.data = self.prepare_data()
        self.to_tensor = transforms.ToTensor()

    def extract_id_from_filename(self, filename):
        """Return up to 8 characters starting at the first digit of the stem.

        The extension is stripped first, so short IDs (e.g. ``"00001.jpg"``)
        do not absorb extension characters and still match across
        ``.jpg``/``.png``.

        Raises:
            ValueError: if the filename stem contains no digit.
        """
        stem = os.path.splitext(filename)[0]
        for start_i, char in enumerate(stem):
            if char.isdigit():
                return stem[start_i:start_i + 8]
        raise ValueError(f"Cannot find number in filename {filename}")

    def prepare_data(self):
        """Return ``(gt_path, pred_path)`` tuples for predictions with a GT match.

        Predictions whose ID has no ground-truth counterpart are reported and
        skipped.
        """
        gt_files = scan_files_in_dir(self.gt_folder, postfix=set(IMAGE_EXTENSIONS))
        gt_dict = {self.extract_id_from_filename(file.name): file for file in gt_files}
        pred_files = scan_files_in_dir(self.pred_folder, postfix=set(IMAGE_EXTENSIONS))
        tuples = []
        for pred_file in pred_files:
            pred_id = self.extract_id_from_filename(pred_file.name)
            if pred_id not in gt_dict:
                print(f"Cannot find gt file for {pred_file}")
            else:
                tuples.append((gt_dict[pred_id].path, pred_file.path))
        return tuples

    def resize(self, img):
        """Resize *img* to ``self.height`` pixels tall, keeping aspect ratio."""
        w, h = img.size
        new_w = int(w * self.height / h)
        return img.resize((new_w, self.height), Image.LANCZOS)

    def _load(self, path):
        """Load *path* as an RGB tensor at the target height.

        The file handle is closed before returning.
        """
        with Image.open(path) as raw:
            # convert() forces a full load, so the data survives closing.
            # It also unifies RGBA/L/P inputs to 3 channels.
            img = raw.convert("RGB")
        if img.height != self.height:
            img = self.resize(img)
        return self.to_tensor(img)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        gt_path, pred_path = self.data[idx]
        return self._load(gt_path), self._load(pred_path)


def copy_resize_gt(gt_folder, height):
    """Copy the images of *gt_folder* into ``<gt_folder>_<height>``, resized.

    Each image is resized to *height* pixels tall with its aspect ratio
    preserved. Files already present in the destination are skipped, so
    re-running only processes missing images.

    Returns:
        The path of the resized folder.
    """
    # normpath drops a trailing separator. Otherwise "gt/" would become
    # "gt/_1024", a folder nested inside the source folder.
    new_folder = f"{os.path.normpath(gt_folder)}_{height}"
    os.makedirs(new_folder, exist_ok=True)
    for file in tqdm(_list_images(gt_folder)):
        dst = os.path.join(new_folder, file)
        if os.path.exists(dst):
            continue
        with Image.open(os.path.join(gt_folder, file)) as img:
            w, h = img.size
            new_w = int(w * height / h)
            resized = img.resize((new_w, height), Image.LANCZOS)
        resized.save(dst)
    return new_folder


@torch.no_grad()
def ssim(dataloader, device=None):
    """Return the dataset-mean SSIM of predictions vs. ground truth.

    Args:
        dataloader: yields ``(gt, pred)`` batches with values in ``[0, 1]``.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        A 0-dim tensor holding the size-weighted mean over all images.
    """
    device = device or _default_device()
    metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    ssim_score = 0
    for gt, pred in tqdm(dataloader, desc="Calculating SSIM"):
        batch_size = gt.size(0)
        gt, pred = gt.to(device), pred.to(device)
        # The default reduction returns the batch mean. Weighting by batch
        # size keeps a short last batch from skewing the average.
        ssim_score += metric(pred, gt) * batch_size
    return ssim_score / len(dataloader.dataset)


@torch.no_grad()
def lpips(dataloader, device=None):
    """Return the dataset-mean LPIPS (SqueezeNet) of predictions vs. GT.

    Args:
        dataloader: yields ``(gt, pred)`` batches with values in ``[0, 1]``.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        A 0-dim tensor holding the size-weighted mean over all images.
    """
    device = device or _default_device()
    metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze').to(device)
    score = 0
    for gt, pred in tqdm(dataloader, desc="Calculating LPIPS"):
        batch_size = gt.size(0)
        pred = pred.to(device)
        gt = gt.to(device)
        # LPIPS needs the images to be in the [-1, 1] range.
        gt = (gt * 2) - 1
        pred = (pred * 2) - 1
        score += metric(gt, pred) * batch_size
    return score / len(dataloader.dataset)


@torch.no_grad()
def dists(dataloader, device=None):
    """Return the dataset-mean DISTS score of predictions vs. ground truth.

    Args:
        dataloader: yields ``(gt, pred)`` batches with values in ``[0, 1]``.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        The mean score as a Python float.
    """
    device = device or _default_device()
    model = DISTS().to(device)
    score = 0
    for gt, pred in tqdm(dataloader, desc="Calculating DISTS"):
        batch_size = gt.size(0)
        pred = pred.to(device)
        gt = gt.to(device)
        # DISTS expects images in [0, 1] range, which matches ToTensor output
        dists_value = model(pred, gt)
        if dists_value.dim() > 0:
            # Per-image scores: sum them.
            score += dists_value.sum().item()
        else:
            # Scalar: treat it as the batch mean (also the shape a squeezed
            # single-image batch takes).
            score += dists_value.item() * batch_size
    return score / len(dataloader.dataset)


def eval(args):
    """Compute and report metrics for ``args.pred_folder`` vs. ``args.gt_folder``.

    If the first GT image's height differs from the first prediction's, the
    GT folder is resized into a sibling folder. ``args.gt_folder`` is then
    updated to point at it.

    Args:
        args: namespace with ``gt_folder``, ``pred_folder``, ``paired``,
            ``batch_size``, ``num_workers`` and ``results_csv``.

    Returns:
        dict mapping metric name ("FID", "KID", plus "SSIM", "LPIPS",
        "DISTS" when paired) to a float.

    Raises:
        FileNotFoundError: if either folder contains no images.
    """
    pred_images = _list_images(args.pred_folder)
    gt_images = _list_images(args.gt_folder)
    if not pred_images or not gt_images:
        raise FileNotFoundError(
            f"No {IMAGE_EXTENSIONS} images found in "
            f"{args.pred_folder if not pred_images else args.gt_folder}"
        )

    # Check gt_folder has images with target height, resize if not
    with Image.open(os.path.join(args.pred_folder, pred_images[0])) as img:
        pred_height = img.height
    with Image.open(os.path.join(args.gt_folder, gt_images[0])) as gt_img:
        gt_height = gt_img.height
    if pred_height != gt_height:
        title = "--" * 30 + f"Resizing GT Images to height {pred_height}" + "--" * 30
        print(title)
        args.gt_folder = copy_resize_gt(args.gt_folder, pred_height)
        print("-" * len(title))

    # Form dataset
    dataset = EvalDataset(args.gt_folder, args.pred_folder, pred_height)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        drop_last=False,
    )

    # Unpaired metrics are always computed; KID is reported scaled by 1000.
    header = ["FID", "KID"]
    row = [
        FID.compute_fid(args.gt_folder, args.pred_folder),
        FID.compute_kid(args.gt_folder, args.pred_folder) * 1000,
    ]
    if args.paired:
        header += ["SSIM", "LPIPS", "DISTS"]
        row += [ssim(dataloader), lpips(dataloader), dists(dataloader)]
    # Convert 0-dim tensors / numpy scalars to plain floats, so the table
    # and CSV show numbers rather than "tensor(..., device=...)".
    row = [float(value) for value in row]
    results_dict = dict(zip(header, row))

    # Print Results
    print("GT Folder : ", args.gt_folder)
    print("Pred Folder: ", args.pred_folder)
    table = PrettyTable()
    table.field_names = header
    table.add_row(row)
    print(table)

    # Save results to CSV if specified
    if args.results_csv and results_dict:
        file_exists = os.path.exists(args.results_csv)
        # Expected layout: .../<method>/<dataset>/<setting>. normpath strips a
        # trailing separator so basename() does not return "".
        pred_path = os.path.normpath(args.pred_folder)
        method = os.path.basename(os.path.dirname(os.path.dirname(pred_path)))
        dataset_name = os.path.basename(os.path.dirname(pred_path))
        setting = os.path.basename(pred_path)
        with open(args.results_csv, 'a', newline='') as f:
            writer = csv.writer(f)
            if not file_exists:
                writer.writerow(["Method", "Dataset", "Setting"] + list(results_dict))
            writer.writerow([method, dataset_name, setting] + list(results_dict.values()))
    return results_dict


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--gt_folder", type=str, required=True)
    parser.add_argument("--pred_folder", type=str, required=True)
    parser.add_argument("--paired", action="store_true")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--results_csv", type=str, default=None, help="CSV file to save results")
    args = parser.parse_args()
    eval(args)