Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| os.environ["CUDA_VISIBLE_DEVICES"] = "0" | |
| import torch | |
| from omegaconf import OmegaConf | |
| from tqdm import tqdm | |
| from torchvision import transforms | |
| from torchvision.io import write_video | |
| from einops import rearrange | |
| import torch.distributed as dist | |
| from torch.utils.data import DataLoader, SequentialSampler | |
| from torch.utils.data.distributed import DistributedSampler | |
| from pipeline import ( | |
| CausalDiffusionInferencePipeline, | |
| CausalInferencePipeline, | |
| ) | |
| from utils.dataset import TextDataset, TextImagePairDataset | |
| from utils.misc import set_seed | |
| from hydra import initialize, compose | |
| from hydra.core.global_hydra import GlobalHydra | |
| from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller | |
| from pathlib import Path | |
# Hard-coded run settings. They are turned into command-line arguments and
# appended to sys.argv so that the argparse parser defined further down picks
# them up as if they had been typed on the command line.
import sys

config_name = "self_forcing_dmd_vsink"
output_chunk_number = 21
output_latent_frame_number = 21
# output_latent_frame_number = 81
seed = 42

# Options that take a value, in the order they are appended to sys.argv.
# Alternative output root used for test runs:
#   f"outputs-test/{output_latent_frame_number}-{config_name}-seed{seed}"
_valued_options = {
    "--output_folder": f"outputs/{output_latent_frame_number}-{config_name}-seed{seed}",
    "--config_dir": "configs",
    "--config_name": config_name,
    "--num_output_frames": str(output_latent_frame_number),
    "--data_path": "prompts/MovieGenVideoBench_extended.txt",
    "--checkpoint_path": "./checkpoints/self_forcing_dmd.pt",
}
for _flag, _value in _valued_options.items():
    sys.argv += [_flag, _value]

# Boolean switch first, then the seed, matching the original ordering.
sys.argv += ["--use_ema", "--seed", str(seed)]

print(f"{sys.argv = }")
# Command-line interface for the inference script. Parsing happens at import
# time and reads the process argv, so the resulting `args` namespace is
# available to every statement that follows.
parser = argparse.ArgumentParser()
parser.add_argument("--config_dir", type=str, help="Directory to the config file")
parser.add_argument("--config_name", type=str, help="Name to the config file")
# The value is handed straight to torch.load, so it must point at a file.
parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint file")
parser.add_argument("--data_path", type=str, help="Path to the dataset")
parser.add_argument("--extended_prompt_path", type=str, help="Path to the extended prompt")
parser.add_argument("--output_folder", type=str, help="Output folder")
# This value sizes the frame axis of the sampled noise tensor; the previous
# help text described an unrelated sliding-window overlap option.
parser.add_argument(
    "--num_output_frames",
    type=int,
    default=21,
    help="Number of latent frames in each generated video",
)
parser.add_argument(
    "--i2v",
    action="store_true",
    help="Whether to perform I2V (or T2V by default)",
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA parameters")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
    "--num_samples",
    type=int,
    default=1,
    help="Number of samples to generate per prompt",
)
args = parser.parse_args()
# Set up the execution context: either one process per GPU (detected through
# the LOCAL_RANK environment variable) or a plain single-process run.
if "LOCAL_RANK" not in os.environ:
    # Single-process run: one implicit rank on the default CUDA device.
    local_rank, world_size = 0, 1
    device = torch.device("cuda")
else:
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    world_size = dist.get_world_size()

# Offset the seed by the rank so each process starts from a different RNG
# state (the offset is zero in the single-process case).
set_seed(args.seed + local_rank)

print(f"Free VRAM {get_cuda_free_memory_gb(gpu)} GB")
# Below 40 GB of free memory the script switches to its low-memory code path.
low_memory = 40 > get_cuda_free_memory_gb(gpu)

# Inference only: disable autograd globally.
torch.set_grad_enabled(False)
# Compose the Hydra configuration named on the command line. Any Hydra state
# left over from an earlier initialization is cleared first so that
# `initialize` can be entered again.
_global_hydra = GlobalHydra.instance()
if _global_hydra.is_initialized():
    _global_hydra.clear()

with initialize(version_base=None, config_path=args.config_dir):
    config = compose(config_name=args.config_name)

print(f"{config = }")
# Build the inference pipeline. A config that defines `denoising_step_list`
# selects the few-step pipeline; otherwise the multi-step diffusion pipeline
# is used.
_pipeline_cls = (
    CausalInferencePipeline
    if hasattr(config, "denoising_step_list")
    else CausalDiffusionInferencePipeline
)
pipeline = _pipeline_cls(config, device=device)

if args.checkpoint_path:
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
    # --use_ema selects the EMA copy of the generator weights.
    _weights_key = "generator_ema" if args.use_ema else "generator"
    pipeline.generator.load_state_dict(state_dict[_weights_key])

pipeline = pipeline.to(dtype=torch.bfloat16)

# With enough free VRAM the text encoder is moved to the GPU outright;
# otherwise it is handed to DynamicSwapInstaller instead of being moved here.
if not low_memory:
    pipeline.text_encoder.to(device=gpu)
else:
    DynamicSwapInstaller.install_model(pipeline.text_encoder, device=gpu)

pipeline.generator.to(device=gpu)
pipeline.vae.to(device=gpu)
# Build the prompt dataset: text-only by default, text+image pairs for I2V.
if not args.i2v:
    dataset = TextDataset(
        prompt_path=args.data_path,
        extended_prompt_path=args.extended_prompt_path,
    )
else:
    assert not dist.is_initialized(), "I2V does not support distributed inference yet"
    # Conditioning images are resized to 480x832 and normalized with
    # mean 0.5 / std 0.5.
    transform = transforms.Compose(
        [
            transforms.Resize((480, 832)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    dataset = TextImagePairDataset(args.data_path, transform=transform)

num_prompts = len(dataset)
print(f"Number of prompts: {num_prompts}")

# Under torch.distributed the prompts are split across ranks without
# shuffling; otherwise they are visited in order.
sampler = (
    DistributedSampler(dataset, shuffle=False, drop_last=True)
    if dist.is_initialized()
    else SequentialSampler(dataset)
)
dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)

# Only local rank 0 creates the output directory, to avoid a creation race;
# the other ranks wait at the barrier until it exists.
if local_rank == 0:
    os.makedirs(args.output_folder, exist_ok=True)
if dist.is_initialized():
    dist.barrier()
def encode(self, videos: torch.Tensor) -> torch.Tensor:
    """Encode every entry of ``videos`` separately and stack the results.

    ``self.mean`` and the reciprocal of ``self.std`` are moved to the device
    and dtype of the first video and passed to ``self.model.encode`` as a
    two-element ``scale`` list. Each video is given a leading singleton
    dimension for the call; the result is cast to float32 and that leading
    dimension is squeezed away again.

    Args:
        videos: Iterable collection of video tensors; the first entry
            determines the device and dtype used for the scale values.

    Returns:
        The per-video encodings stacked along a new leading dimension.
    """
    reference = videos[0]
    target = {"device": reference.device, "dtype": reference.dtype}
    shift = self.mean.to(**target)
    inverse_std = 1.0 / self.std.to(**target)
    # One shared list is reused for every call, as in the original.
    scale = [shift, inverse_std]

    encoded = []
    for clip in videos:
        latent = self.model.encode(clip.unsqueeze(0), scale)
        encoded.append(latent.float().squeeze(0))
    return torch.stack(encoded, dim=0)
# Characters replaced by "_" when a prompt prefix is embedded in the output
# filename: spaces (as before) plus path separators, which would otherwise be
# interpreted as sub-directories of the output folder and make the write fail.
_FILENAME_TRANSLATION = str.maketrans({" ": "_", "/": "_", "\\": "_"})

# Main generation loop: one prompt per iteration, `args.num_samples` videos
# per prompt, each written as an .mp4 into `args.output_folder`.
# tqdm wraps the dataloader directly so the progress bar knows the total;
# only local rank 0 displays it.
for batch_data in tqdm(dataloader, disable=(local_rank != 0)):
    # With batch_size=1 the DataLoader still hands back a batch container.
    # Normalize it to a dict BEFORE reading any field from it.
    if isinstance(batch_data, dict):
        batch = batch_data
    elif isinstance(batch_data, list):
        batch = batch_data[0]  # First (and only) item in the batch
    else:
        raise TypeError(f"Unsupported batch type: {type(batch_data).__name__}")
    idx = batch["idx"].item()

    # Reset the RNG so the sampled noise is the same for every prompt.
    set_seed(args.seed)

    prompt = batch["prompts"][0]
    if args.i2v:
        # Image-to-video: the batch carries the conditioning image and caption.
        prompts = [prompt] * args.num_samples
        image = (
            batch["image"]
            .squeeze(0)
            .unsqueeze(0)
            .unsqueeze(2)
            .to(device=device, dtype=torch.bfloat16)
        )
        # Encode the input image as the first latent and replicate it once
        # per requested sample.
        initial_latent = pipeline.vae.encode_to_latent(image).to(
            device=device, dtype=torch.bfloat16
        )
        initial_latent = initial_latent.repeat(args.num_samples, 1, 1, 1, 1)
        # The image latent supplies one frame, so one fewer is sampled.
        num_noise_frames = args.num_output_frames - 1
    else:
        # Text-to-video: prefer the extended prompt when the dataset has one.
        extended_prompt = batch["extended_prompts"][0] if "extended_prompts" in batch else None
        text = extended_prompt if extended_prompt is not None else prompt
        prompts = [text] * args.num_samples
        initial_latent = None
        num_noise_frames = args.num_output_frames

    sampled_noise = torch.randn(
        [args.num_samples, num_noise_frames, 16, 60, 104],
        device=device,
        dtype=torch.bfloat16,
    )

    # Re-seed before running the pipeline (presumably so that any sampling
    # inside `inference` does not depend on how much noise was drawn above).
    set_seed(args.seed)

    # Generate the video; the returned latents are not used here.
    video, _ = pipeline.inference(
        noise=sampled_noise,
        text_prompts=prompts,
        return_latents=True,
        initial_latent=initial_latent,
        low_memory=low_memory,
    )

    # Move channels last, bring the result to the CPU and scale by 255 for
    # writing (assumes the pipeline returns values in [0, 1] - TODO confirm).
    video = 255.0 * rearrange(video, "b t c h w -> b t h w c").cpu()

    # Clear VAE cache
    pipeline.vae.model.clear_cache()

    # Save the video if the current prompt is not a dummy prompt
    if idx < num_prompts:
        model = "ema" if args.use_ema else "regular"
        safe_prompt = prompt[:50].translate(_FILENAME_TRANSLATION)
        for seed_idx in range(args.num_samples):
            # All processes save their videos
            output_path = os.path.join(
                args.output_folder,
                f"{idx}-{safe_prompt}-{seed_idx}_{model}.mp4",
            )
            write_video(output_path, video[seed_idx], fps=16)