import argparse
import os
import sys

sys.path.append(os.getcwd())

import warnings

warnings.filterwarnings("ignore")

import time
from contextlib import nullcontext

from omegaconf import OmegaConf
from torch import autocast
from tqdm import tqdm

import numpy as np
import torch
from einops import rearrange
from lightning.pytorch import seed_everything

from vidtok.data.vidtok import VidTokValDataset
from vidtok.modules.lpips import LPIPS
from vidtok.modules.util import (compute_psnr, compute_ssim,
                                 instantiate_from_config, print0)


def load_model_from_config(config, ckpt, ignore_keys=[], verbose=False):
    # Build the model from a YAML config and load weights from the checkpoint.
    config = OmegaConf.load(config)
    config.model.params.ckpt_path = ckpt
    config.model.params.ignore_keys = ignore_keys
    config.model.params.verbose = verbose
    model = instantiate_from_config(config.model)
    return model


class MultiVideoDataset(VidTokValDataset):
    def __init__(
        self,
        data_dir,
        meta_path=None,
        input_height=256,
        input_width=256,
        sample_fps=30,
        chunk_size=16,
        is_causal=True,
        read_long_video=False,
    ):
        super().__init__(
            data_dir=data_dir,
            meta_path=meta_path,
            video_params={
                "input_height": input_height,
                "input_width": input_width,
                # Causal models consume one extra frame per chunk.
                "sample_num_frames": chunk_size + 1 if is_causal else chunk_size,
                "sample_fps": sample_fps,
            },
            pre_load_frames=True,
            last_frames_handle="repeat",
            read_long_video=read_long_video,
            chunk_size=chunk_size,
            is_causal=is_causal,
        )

    def __getitem__(self, idx):
        frames = super().__getitem__(idx)["jpg"]
        return frames


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        choices=["full", "autocast"],
        default="full",
        help="evaluate at this precision",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/vidtok_kl_causal_488_4chn.yaml",
        help="path to config which constructs the model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="checkpoints/vidtok_kl_causal_488_4chn.ckpt",
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./",
        help="root folder",
    )
    parser.add_argument(
        "--meta_path",
        type=str,
        default=None,
        help="path to the .csv meta file",
    )
    parser.add_argument(
        "--input_height",
        type=int,
        default=256,
        help="height of the input video",
    )
    parser.add_argument(
        "--input_width",
        type=int,
        default=256,
        help="width of the input video",
    )
    parser.add_argument(
        "--sample_fps",
        type=int,
        default=30,
        help="sample fps",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=16,
        help="the size of a chunk - a long video is split into several chunks",
    )
    parser.add_argument(
        "--read_long_video",
        action="store_true",
        help="read long videos (requires tiling inference)",
    )
    args = parser.parse_args()

    seed_everything(args.seed)
    print0(f"[bold red]\\[scripts.inference_evaluate][/bold red] Evaluating model {args.ckpt}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    precision_scope = autocast if args.precision == "autocast" else nullcontext

    model = load_model_from_config(args.config, args.ckpt)
    model.to(device).eval()
    # The chunk size must align with the model's temporal downsampling factor.
    assert args.chunk_size % model.encoder.time_downsample_factor == 0

    if args.read_long_video:
        assert hasattr(model, "use_tiling"), "Tiling inference is needed to conduct long video reconstruction."
print(f"Using tiling inference to save memory usage...") model.enable_tiling() model.t_chunk_enc = args.chunk_size model.t_chunk_dec = model.t_chunk_enc // model.encoder.time_downsample_factor if args.input_width > 256: model.enable_tiling() dataset = MultiVideoDataset( data_dir=args.data_dir, meta_path=args.meta_path, input_height=args.input_height, input_width=args.input_width, sample_fps=args.sample_fps, chunk_size=args.chunk_size, is_causal=model.is_causal, read_long_video=args.read_long_video ) dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False) perceptual_loss = LPIPS().eval() perceptual_loss = perceptual_loss.to(device) psnrs, ssims, lpipss = [], [], [] with torch.no_grad(), precision_scope("cuda"): tic = time.time() for i, input in tqdm(enumerate(dataloader)): input = input.to(device) _, output, reg_log = model(input) output = output.clamp(-1, 1) input, output = map(lambda x: (x + 1) / 2, (input, output)) if input.dim() == 5: input = rearrange(input, "b c t h w -> (b t) c h w") assert output.dim() == 5 output = rearrange(output, "b c t h w -> (b t) c h w") for inp, out in zip(torch.split(input, 16), torch.split(output, 16)): psnrs += [compute_psnr(inp, out).item()] * inp.shape[0] ssims += [compute_ssim(inp, out).item()] * inp.shape[0] lpipss += [perceptual_loss(inp * 2 - 1, out * 2 - 1).mean().item()] * inp.shape[0] toc = time.time() print0( f"[bold red]\[scripts.inference_evaluate][/bold red] PSNR: {np.mean(psnrs):.4f}, SSIM: {np.mean(ssims):.4f}, LPIPS: {np.mean(lpipss):.4f}" ) print0(f"[bold red]\[scripts.inference_evaluate][/bold red] Time taken: {toc - tic:.2f}s") if __name__ == "__main__": main()