import gc
import os
from typing import List, Optional, Union

import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from einops import rearrange
from tqdm import tqdm

import wandb
from fastvideo.distill.solver import PCMFMScheduler
from fastvideo.models.mochi_hf.pipeline_mochi import (
    linear_quadratic_schedule, retrieve_timesteps)
from fastvideo.utils.communications import all_gather
from fastvideo.utils.load import load_vae
from fastvideo.utils.parallel_states import (get_sequence_parallel_state,
                                             nccl_info)


def prepare_latents(
    batch_size,
    num_channels_latents,
    height,
    width,
    num_frames,
    dtype,
    device,
    generator,
    vae_spatial_scale_factor,
    vae_temporal_scale_factor,
):
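    # Map pixel-space height/width and the frame count to latent-space sizes
    # using the VAE spatial/temporal compression factors.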
    height = height // vae_spatial_scale_factor
    width = width // vae_spatial_scale_factor
    num_frames = (num_frames - 1) // vae_temporal_scale_factor + 1

    shape = (batch_size, num_channels_latents, num_frames, height, width)

    latents = randn_tensor(shape,
                           generator=generator,
                           device=device,
                           dtype=dtype)
    return latents


def sample_validation_video(
    model_type,
    transformer,
    vae,
    scheduler,
    scheduler_type="euler",
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: int = 16,
    num_inference_steps: int = 28,
    timesteps: Optional[List[int]] = None,
    guidance_scale: float = 4.5,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    vae_spatial_scale_factor=8,
    vae_temporal_scale_factor=6,
    num_channels_latents=12,
):
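    """Sample one validation video from precomputed prompt embeddings and
    return it as a one-element tuple, mirroring the diffusers convention."""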
    device = vae.device

    batch_size = prompt_embeds.shape[0]

    # Classifier-free guidance batches the negative and positive prompts so
    # the transformer runs once per denoising step.
    do_classifier_free_guidance = guidance_scale > 1.0
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds],
                                  dim=0)
        prompt_attention_mask = torch.cat(
            [negative_prompt_attention_mask, prompt_attention_mask], dim=0)

    # Prepare latent variables.
    latents = prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        num_frames,
        prompt_embeds.dtype,
        device,
        generator,
        vae_spatial_scale_factor,
        vae_temporal_scale_factor,
    )
    # Under sequence parallelism, keep only this rank's shard of the sharded
    # dimension; the shards are re-gathered after the denoising loop.
    world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group
    if get_sequence_parallel_state():
        latents = rearrange(latents,
                            "b t (n s) h w -> b t n s h w",
                            n=world_size).contiguous()
        latents = latents[:, :, rank, :, :, :]
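
    # Prepare timesteps: Mochi with the Euler scheduler uses a
    # linear-quadratic sigma schedule; other configurations use the
    # scheduler's default spacing.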
    if scheduler_type == "euler" and model_type == "mochi":
        threshold_noise = 0.025
        sigmas = linear_quadratic_schedule(num_inference_steps,
                                           threshold_noise)
        sigmas = np.array(sigmas)
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
        )
    else:
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler,
            num_inference_steps,
            device,
        )
    num_warmup_steps = max(
        len(timesteps) - num_inference_steps * scheduler.order, 0)

    # Denoising loop.
    with tqdm(
            total=num_inference_steps,
            disable=nccl_info.rank_within_group != 0,
            desc="Validation sampling...",
    ) as progress_bar:
        for i, t in enumerate(timesteps):
            # Duplicate latents for the unconditional branch under CFG.
            latent_model_input = (torch.cat([latents] * 2)
                                  if do_classifier_free_guidance else latents)
            timestep = t.expand(latent_model_input.shape[0])
            with torch.autocast("cuda", dtype=torch.bfloat16):
                noise_pred = transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    encoder_attention_mask=prompt_attention_mask,
                    return_dict=False,
                )[0]

            # Guidance combination and the scheduler step run in FP32 for
            # numerical stability.
            noise_pred = noise_pred.to(torch.float32)
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_{t-1}.
            latents_dtype = latents.dtype
            latents = scheduler.step(noise_pred,
                                     t,
                                     latents.to(torch.float32),
                                     return_dict=False)[0]
            latents = latents.to(latents_dtype)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
                                           (i + 1) % scheduler.order == 0):
                progress_bar.update()

    # Gather the latent shards from all sequence-parallel ranks.
    if get_sequence_parallel_state():
        latents = all_gather(latents, dim=2)
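
    # Either return raw latents or denormalize, decode with the VAE, and
    # post-process the frames to the requested output type.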
    if output_type == "latent":
        video = latents
    else:
        # Denormalize: latents * latents_std / scaling_factor + latents_mean.
        has_latents_mean = (hasattr(vae.config, "latents_mean")
                            and vae.config.latents_mean is not None)
        has_latents_std = (hasattr(vae.config, "latents_std")
                           and vae.config.latents_std is not None)
        if has_latents_mean and has_latents_std:
            latents_mean = (torch.tensor(vae.config.latents_mean).view(
                1, num_channels_latents, 1, 1,
                1).to(latents.device, latents.dtype))
            latents_std = (torch.tensor(vae.config.latents_std).view(
                1, num_channels_latents, 1, 1,
                1).to(latents.device, latents.dtype))
            latents = latents * latents_std / vae.config.scaling_factor + latents_mean
        else:
            latents = latents / vae.config.scaling_factor
        with torch.autocast("cuda", dtype=vae.dtype):
            video = vae.decode(latents, return_dict=False)[0]
        video_processor = VideoProcessor(
            vae_scale_factor=vae_spatial_scale_factor)
        video = video_processor.postprocess_video(video,
                                                  output_type=output_type)

    return (video, )


@torch.no_grad()
@torch.autocast("cuda", dtype=torch.bfloat16)
def log_validation(
    args,
    transformer,
    device,
    weight_dtype,
    global_step,
    scheduler_type="euler",
    shift=1.0,
    num_euler_timesteps=100,
    linear_quadratic_threshold=0.025,
    linear_range=0.5,
    ema=False,
):
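    """Sample validation videos for every (sampling-steps, guidance-scale)
    pair from args, export them to disk, and log them to wandb."""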
    print("Running validation...\n")
    # Per-model VAE compression factors and latent channel counts.
    if args.model_type == "mochi":
        vae_spatial_scale_factor = 8
        vae_temporal_scale_factor = 6
        num_channels_latents = 12
    elif args.model_type in ("hunyuan", "hunyuan_hf"):
        vae_spatial_scale_factor = 8
        vae_temporal_scale_factor = 4
        num_channels_latents = 16
    else:
        raise ValueError(f"Model type {args.model_type} not supported")
    vae, autocast_type, fps = load_vae(args.model_type,
                                       args.pretrained_model_name_or_path)
    vae.enable_tiling()
    if scheduler_type == "euler":
        scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
    else:
        linear_quadratic = scheduler_type == "pcm_linear_quadratic"
        scheduler = PCMFMScheduler(
            1000,
            shift,
            num_euler_timesteps,
            linear_quadratic,
            linear_quadratic_threshold,
            linear_range,
        )

    validation_guidance_scale_ls = [
        float(scale) for scale in args.validation_guidance_scale.split(",")
    ]
    for validation_sampling_step in args.validation_sampling_steps.split(","):
        validation_sampling_step = int(validation_sampling_step)
        for validation_guidance_scale in validation_guidance_scale_ls:
            videos = []

            # Precomputed prompt embeddings and attention masks live under
            # the validation prompt directory.
            embed_dir = os.path.join(args.validation_prompt_dir,
                                     "prompt_embed")
            mask_dir = os.path.join(args.validation_prompt_dir,
                                    "prompt_attention_mask")
            embeds = sorted(os.listdir(embed_dir))
            masks = sorted(os.listdir(mask_dir))
            num_embeds = len(embeds)
            validation_prompt_ids = list(range(num_embeds))
            num_sp_groups = int(os.getenv("WORLD_SIZE",
                                          "1")) // nccl_info.sp_size
            # Pad the prompt list so every sequence-parallel group receives
            # the same number of prompts; the padded duplicates are dropped
            # after gathering.
            if num_embeds % num_sp_groups != 0:
                validation_prompt_ids += [0] * (num_sp_groups -
                                                num_embeds % num_sp_groups)
            num_embeds_per_group = len(validation_prompt_ids) // num_sp_groups
            local_prompt_ids = validation_prompt_ids[
                nccl_info.group_id * num_embeds_per_group:
                (nccl_info.group_id + 1) * num_embeds_per_group]

            for i in local_prompt_ids:
                prompt_embed_path = os.path.join(embed_dir, embeds[i])
                prompt_mask_path = os.path.join(mask_dir, masks[i])
                prompt_embeds = (torch.load(
                    prompt_embed_path, map_location="cpu",
                    weights_only=True).to(device).unsqueeze(0))
                prompt_attention_mask = (torch.load(
                    prompt_mask_path, map_location="cpu",
                    weights_only=True).to(device).unsqueeze(0))
                # The negative prompt is an all-zero embedding with an
                # all-False attention mask.
                negative_prompt_embeds = torch.zeros(
                    256, 4096).to(device).unsqueeze(0)
                negative_prompt_attention_mask = (
                    torch.zeros(256).bool().to(device).unsqueeze(0))
                # Fixed seed so samples stay comparable across checkpoints.
                generator = torch.Generator(device="cpu").manual_seed(12345)
                video = sample_validation_video(
                    args.model_type,
                    transformer,
                    vae,
                    scheduler,
                    scheduler_type=scheduler_type,
                    num_frames=args.num_frames,
                    height=args.num_height,
                    width=args.num_width,
                    num_inference_steps=validation_sampling_step,
                    guidance_scale=validation_guidance_scale,
                    generator=generator,
                    prompt_embeds=prompt_embeds,
                    prompt_attention_mask=prompt_attention_mask,
                    negative_prompt_embeds=negative_prompt_embeds,
                    negative_prompt_attention_mask=
                    negative_prompt_attention_mask,
                    vae_spatial_scale_factor=vae_spatial_scale_factor,
                    vae_temporal_scale_factor=vae_temporal_scale_factor,
                    num_channels_latents=num_channels_latents,
                )[0]
                if nccl_info.rank_within_group == 0:
                    videos.append(video[0])

            gc.collect()
            torch.cuda.empty_cache()

            # Gather videos from every rank; only global rank 0 writes files
            # and logs to wandb.
            torch.distributed.barrier()
            all_videos = [
                None for _ in range(int(os.getenv("WORLD_SIZE", "1")))
            ]
            torch.distributed.all_gather_object(all_videos, videos)
            if nccl_info.global_rank == 0:
                videos = [
                    video for rank_videos in all_videos
                    for video in rank_videos
                ]
                # Drop videos produced for padded prompt ids.
                videos = videos[:num_embeds]

                video_filenames = []
                for i, video in enumerate(videos):
                    filename = os.path.join(
                        args.output_dir,
                        f"validation_step_{global_step}_sample_{validation_sampling_step}_guidance_{validation_guidance_scale}_video_{i}.mp4",
                    )
                    export_to_video(video, filename, fps=fps)
                    video_filenames.append(filename)

                logs = {
                    f"{'ema_' if ema else ''}validation_sample_{validation_sampling_step}_guidance_{validation_guidance_scale}":
                    [wandb.Video(filename) for filename in video_filenames]
                }
                wandb.log(logs, step=global_step)
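

# A minimal sketch (not part of the original file) of how log_validation
# might be wired into a training loop; args.validation_steps,
# args.scheduler_type, and args.shift are assumed argparse flags:
#
#     if global_step % args.validation_steps == 0:
#         log_validation(args, transformer, device, torch.bfloat16,
#                        global_step, scheduler_type=args.scheduler_type,
#                        shift=args.shift, ema=False)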