| """ |
| Evaluate tokenizer performance by computing reconstruction metrics. |
| |
| Metrics include: |
| - rFID (Reconstruction FID) |
| - PSNR (Peak Signal-to-Noise Ratio) |
| - LPIPS (Learned Perceptual Image Patch Similarity) |
| - SSIM (Structural Similarity Index) |
| |
| by Jingfeng Yao |
| from HUST-VL |
| """ |
|
|
| import os |
| import torch, yaml |
| import numpy as np |
| from tqdm import tqdm |
| from PIL import Image |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from omegaconf import OmegaConf |
| from torch.utils.data import DataLoader, DistributedSampler |
| from tools.calculate_fid import calculate_fid_given_paths |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from torchmetrics import StructuralSimilarityIndexMeasure |
| from models.lpips import LPIPS |
| from torchvision.datasets import ImageFolder |
| from torchvision import transforms |
| from diffusers.models import AutoencoderKL |
| from tokenizer.sdvae import Diffusers_AutoencoderKL |
| from tokenizer import models_mae |
|
|
def load_config(config_path):
    """Parse the YAML file at ``config_path`` and return the loaded object.

    The file is read with ``yaml.safe_load``, so only plain YAML types are
    constructed.
    """
    with open(config_path, "r") as stream:
        return yaml.safe_load(stream)
|
|
def print_with_prefix(content, prefix='Tokenizer Evaluation', rank=0):
    """Print ``content`` behind a blue ``[prefix]`` tag, but only when ``rank`` is 0.

    Args:
        content: Object to print after the tag (formatted with ``str``).
        prefix: Text placed inside the square brackets.
        rank: Rank of the calling process; any non-zero value suppresses output.
    """
    if rank != 0:
        return
    # ANSI escape 34 switches the foreground to blue, 0 resets it.
    tag = f"\033[34m[{prefix}]\033[0m"
    print(f"{tag} {content}")
|
|
def save_image(image, filename):
    """Convert the array ``image`` to a PIL image and write it to ``filename``."""
    pil_image = Image.fromarray(image)
    pil_image.save(filename)
|
|
def evaluate_tokenizer(args, config_path, model_type, data_path, output_path):
    """Reconstruct a validation set with a tokenizer and report rFID, PSNR, LPIPS and SSIM.

    Must be launched as a distributed job: the NCCL process group is created
    here and destroyed before returning. Every rank writes its share of the
    reference and decoded PNGs under ``output_path``; rank 0 then computes and
    prints the final metrics.

    Args:
        args: Namespace providing ``epsilon``, the scale of the Gaussian noise
            added to the latents (multiplied by the cached latent std).
        config_path: YAML file read via ``load_config``; the keys
            ``vae.model_name``, ``vae.weight_path``, ``data.image_size`` and
            ``data.data_path`` are used.
        model_type: Accepted for interface compatibility only; the effective
            model type is always the prefix of ``vae.model_name`` in the config.
        data_path: Root of an ``ImageFolder``-style image directory.
        output_path: Directory that receives ``ref_images/`` and
            ``<model_type>_<epsilon>/decoded_images/``.

    Raises:
        ValueError: If the configured model type is not supported.
    """
    dist.init_process_group(backend='nccl')
    # The global rank identifies this process for data sharding and file
    # names, while the CUDA device index is per node. Launchers such as
    # torchrun export LOCAL_RANK; fall back to the global rank, which is the
    # same value on a single node.
    rank = dist.get_rank()
    local_rank = int(os.environ.get('LOCAL_RANK', rank))
    torch.cuda.set_device(local_rank)
    device = torch.device(f'cuda:{local_rank}')
    is_main = rank == 0

    train_config = load_config(config_path)
    # The config is authoritative: it overrides the ``model_type`` argument.
    model_type = train_config['vae']['model_name'].split("_")[0]

    if is_main:
        print_with_prefix(f"Loading model... {model_type.upper()} {args.epsilon}")

    if model_type == 'vmae':
        chkpt = train_config['vae']['weight_path']
        arch = 'mae_for_ldmae_f8d16_prev'
        model = getattr(models_mae, arch)(
            ldmae_mode=True,
            no_cls=True,
            kl_loss_weight=True,
            smooth_output=True,
            img_size=train_config['data']['image_size'],
        )
        checkpoint = torch.load(chkpt, map_location='cpu')
        model = model.to(device).eval()
        msg = model.load_state_dict(checkpoint['model'], strict=False)
    elif model_type in ('ae', 'dae', 'vae', 'sdv3'):
        model = Diffusers_AutoencoderKL(
            img_size=train_config['data']['image_size'],
            sample_size=128,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            latent_channels=16,
            norm_num_groups=32,
            act_fn="silu",
            block_out_channels=(128, 256, 512, 512),
            force_upcast=False,
            use_quant_conv=False,
            use_post_quant_conv=False,
            down_block_types=(
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
            ),
            up_block_types=(
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
            ),
        ).to(device).eval()
        chkpt_dir = train_config['vae']['weight_path']
        checkpoint = torch.load(chkpt_dir, map_location='cpu')
        msg = model.load_state_dict(checkpoint['model'], strict=False)
    else:
        # A bare ``raise`` here has no active exception and would surface as
        # an uninformative RuntimeError.
        raise ValueError(
            f"Unsupported tokenizer model type {model_type!r} "
            f"(from vae.model_name={train_config['vae']['model_name']!r}); "
            "expected one of 'vmae', 'ae', 'dae', 'vae', 'sdv3'."
        )
    # strict=False loading: show which keys were missing or unexpected.
    print(msg)

    # Images are mapped to [-1, 1]; the save paths below invert this mapping.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset = ImageFolder(root=data_path, transform=transform)
    distributed_sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=rank)
    val_dataloader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=False,
        num_workers=4,
        sampler=distributed_sampler
    )

    # Cached latent statistics live next to the training data; only the std
    # is needed, to scale the injected noise.
    if 'sample' in train_config['data']:
        train_config['data']['data_path'] += '_sample'
    latent_stats_cache_file = os.path.join(train_config['data']['data_path'], 'latents_stats.pt')
    latent_stats = torch.load(latent_stats_cache_file)
    latent_std = latent_stats['std'].clone().detach().to(device)

    folder_name = f"{model_type}_{args.epsilon}"

    save_dir = os.path.join(output_path, folder_name, 'decoded_images')
    ref_path = os.path.join(output_path, 'ref_images')
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(ref_path, exist_ok=True)

    if is_main:
        print_with_prefix(f"Output dir: {save_dir}")
        print_with_prefix(f"Reference dir: {ref_path}")

    # Write the reference images unless the folder already looks complete.
    ref_png_files = [f for f in os.listdir(ref_path) if f.endswith('.png')]
    if len(ref_png_files) < 50000:
        total_samples = 0
        for batch in val_dataloader:
            images = batch[0].to(device)
            for j in range(images.size(0)):
                img = torch.clamp(127.5 * images[j] + 128.0, 0, 255).cpu().permute(1, 2, 0).numpy().astype(np.uint8)
                Image.fromarray(img).save(os.path.join(ref_path, f"ref_image_rank_{rank}_{total_samples}.png"))
                total_samples += 1
                if total_samples % 100 == 0 and is_main:
                    print_with_prefix(f"Rank {rank}, Saved {total_samples} reference images")
    dist.barrier()

    lpips_values = []
    ssim_values = []
    lpips = LPIPS().to(device).eval()
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=(-1.0, 1.0)).to(device)

    if is_main:
        print_with_prefix("Generating reconstructions...")
    all_indices = 0
    # NOTE(review): when the decoded folder is already full this loop is
    # skipped and LPIPS/SSIM are averaged over empty lists - confirm that
    # re-running on a populated folder is intended to report only rFID/PSNR.
    if len(os.listdir(save_dir)) < 50000:
        for batch in val_dataloader:
            images = batch[0].to(device)
            latents = encode_images(model, images)
            epsilon = args.epsilon * torch.randn_like(latents)
            latents = latents + epsilon * latent_std

            # Pure evaluation: keep decoding and the metric networks out of
            # autograd, and store plain floats so no device tensors are kept
            # alive across batches.
            with torch.no_grad():
                decoded_images_tensor = model.decode(latents).sample
                decoded_images = torch.clamp(127.5 * decoded_images_tensor + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

                lpips_values.append(lpips(decoded_images_tensor, images).mean().item())
                ssim_values.append(ssim_metric(decoded_images_tensor, images).item())

            for i, img in enumerate(decoded_images):
                save_image(img, os.path.join(save_dir, f"decoded_image_rank_{rank}_{all_indices + i}.png"))
                if (all_indices + i) % 100 == 0 and is_main:
                    print_with_prefix(f"Rank {rank}, Processed {all_indices + i} images")
            all_indices += len(decoded_images)
    dist.barrier()

    # Average the per-batch metrics across ranks, then over batches.
    lpips_values = torch.tensor(lpips_values, device=device)
    ssim_values = torch.tensor(ssim_values, device=device)
    dist.all_reduce(lpips_values, op=dist.ReduceOp.AVG)
    dist.all_reduce(ssim_values, op=dist.ReduceOp.AVG)

    avg_lpips = lpips_values.mean().item()
    avg_ssim = ssim_values.mean().item()

    if is_main:
        print_with_prefix("Computing rFID...")
        fid = calculate_fid_given_paths([ref_path, save_dir], batch_size=50, dims=2048, device=device, num_workers=16)

        print_with_prefix("Computing PSNR...")
        psnr_values = calculate_psnr_between_folders(ref_path, save_dir)
        # The helper returns [] when the folders do not line up; report NaN
        # instead of dividing by zero, which would kill rank 0 while the
        # other ranks wait at the final barrier.
        avg_psnr = sum(psnr_values) / len(psnr_values) if psnr_values else float('nan')

        print_with_prefix("Final Metrics:")
        print_with_prefix(f"rFID: {fid:.3f}")
        print_with_prefix(f"PSNR: {avg_psnr:.3f}")
        print_with_prefix(f"LPIPS: {avg_lpips:.3f}")
        print_with_prefix(f"SSIM: {avg_ssim:.3f}")
    dist.barrier()
    dist.destroy_process_group()
|
|
def encode_images(model, images):
    """Encode ``images`` and return the mode of the latent distribution as float32.

    Runs without gradient tracking. ``model.encode(images)`` must expose a
    ``latent_dist`` attribute with a ``mode()`` method.
    """
    with torch.no_grad():
        latent_dist = model.encode(images).latent_dist
        latents = latent_dist.mode()
        return latents.to(torch.float32)
|
|
def decode_to_images(model, z):
    """Decode latents ``z`` into a uint8 image batch in channels-last layout.

    Args:
        model: Object with a ``decode`` method. The result may be the image
            tensor itself or an object carrying it in a ``sample`` attribute
            (the form ``evaluate_tokenizer`` relies on).
        z: Latent tensor passed to ``model.decode``.

    Returns:
        A NumPy ``uint8`` array obtained by mapping the decoded values with
        ``127.5 * x + 128``, clamping to ``[0, 255]`` and moving the channel
        axis (dim 1) to the end.
    """
    with torch.no_grad():
        decoded = model.decode(z)
        # Unwrap output objects that carry the tensor in ``.sample``; a bare
        # tensor has no such attribute and is used as-is.
        images = getattr(decoded, "sample", decoded)
        images = torch.clamp(127.5 * images + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
    return images
|
|
def calculate_psnr(original, processed):
    """Return the PSNR in dB between two image tensors on a 0-255 scale.

    Args:
        original: Reference image tensor.
        processed: Tensor to compare, broadcastable against ``original``.

    Returns:
        ``20 * log10(255 / sqrt(mse))`` as a Python float; ``inf`` when the
        inputs are identical (zero mean squared error).
    """
    # Integer tensors (e.g. raw uint8 pixels) wrap around on subtraction and
    # cannot be averaged by torch.mean, so promote them to float first.
    # Floating-point inputs are left untouched.
    if not torch.is_floating_point(original):
        original = original.to(torch.float32)
    if not torch.is_floating_point(processed):
        processed = processed.to(torch.float32)
    mse = torch.mean((original - processed) ** 2)
    return 20 * torch.log10(255.0 / torch.sqrt(mse)).item()
|
|
def load_image(image_path):
    """Load an image file as a float32 tensor in channels-first (C, H, W) order.

    The image is converted to RGB first; pixel values keep their 0-255 range.
    """
    rgb = Image.open(image_path).convert('RGB')
    # PIL yields (H, W, C); move the channel axis to the front.
    channels_first = np.array(rgb).transpose(2, 0, 1)
    return torch.tensor(channels_first, dtype=torch.float32)
|
|
def calculate_psnr_for_pair(original_path, processed_path):
    """Load the two image files and return the PSNR between them."""
    reference = load_image(original_path)
    candidate = load_image(processed_path)
    return calculate_psnr(reference, candidate)
|
|
| def calculate_psnr_between_folders(original_folder, processed_folder): |
| original_files = sorted(os.listdir(original_folder)) |
| processed_files = sorted(os.listdir(processed_folder)) |
|
|
| if len(original_files) != len(processed_files): |
| print("Warning: Mismatched number of images in folders") |
| return [] |
|
|
| with ThreadPoolExecutor() as executor: |
| futures = [ |
| executor.submit(calculate_psnr_for_pair, |
| os.path.join(original_folder, orig), |
| os.path.join(processed_folder, proc)) |
| for orig, proc in zip(original_files, processed_files) |
| ] |
| return [future.result() for future in as_completed(futures)] |
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser()
    # (flag, type, default) for the options that carry no help text.
    plain_options = (
        ('--config_path', str, 'tokenizer/configs/vavae_f16d32.yaml'),
        ('--model_type', str, 'vavae'),
        ('--data_path', str, '/data/dataset/imagenet/1K_dataset/val'),
        ('--output_path', str, './rfid'),
        ('--seed', int, 42),
    )
    for flag, value_type, default_value in plain_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    parser.add_argument(
        '--epsilon',
        type=float,
        default=0,
        help="Noise pertubation ratio for latent robustness experiment.",
    )
    args = parser.parse_args()

    evaluate_tokenizer(
        args,
        config_path=args.config_path,
        model_type=args.model_type,
        data_path=args.data_path,
        output_path=args.output_path,
    )