import logging
import gc
import io

import torch
import numpy as np
import torch.distributed as dist
from einops import rearrange, repeat
from tqdm import tqdm
from transformers import get_scheduler
import zmq
import msgpack
from PIL import Image
import torchvision.transforms as transforms

from algorithms.common.base_pytorch_algo import BasePytorchAlgo
from utils.video_utils import numpy_to_mp4_bytes
from utils.distributed_utils import is_rank_zero
from .modules.model import WanModel, WanAttentionBlock
from .modules.t5 import umt5_xxl, T5CrossAttention, T5SelfAttention
from .modules.tokenizers import HuggingfaceTokenizer
from .modules.vae import video_vae_factory
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


def print_module_hierarchy(model, indent=0):
    """Recursively print the module tree of a model for debugging."""
    for name, module in model.named_children():
        print(" " * indent + f"{name}: {type(module)}")
        print_module_hierarchy(module, indent + 2)


class WanTextToVideo(BasePytorchAlgo):
    """
    Flow-matching text-to-video algorithm built around the Wan video diffusion model.
    """

    def __init__(self, cfg):
        self.num_train_timesteps = cfg.num_train_timesteps
        self.height = cfg.height
        self.width = cfg.width
        self.n_frames = cfg.n_frames
        self.gradient_checkpointing_rate = cfg.gradient_checkpointing_rate
        self.sample_solver = cfg.sample_solver
        self.sample_steps = cfg.sample_steps
        self.sample_shift = cfg.sample_shift
        self.lang_guidance = cfg.lang_guidance
        self.neg_prompt = cfg.neg_prompt
        self.hist_guidance = cfg.hist_guidance
        self.sliding_hist = cfg.sliding_hist
        self.diffusion_forcing = cfg.diffusion_forcing
        self.vae_stride = cfg.vae.stride
        self.patch_size = cfg.model.patch_size
        self.diffusion_type = cfg.diffusion_type  # "discrete" or "continuous"
        self.lat_h = self.height // self.vae_stride[1]
        self.lat_w = self.width // self.vae_stride[2]
        self.lat_t = 1 + (self.n_frames - 1) // self.vae_stride[0]
        self.lat_c = cfg.vae.z_dim
        self.max_area = self.height * self.width
        self.max_tokens = (
            self.lat_t * self.lat_h * self.lat_w
            // (self.patch_size[1] * self.patch_size[2])
        )
        self.load_prompt_embed = cfg.load_prompt_embed
        self.load_video_latent = cfg.load_video_latent
        self.socket = None
        if (self.sliding_hist - 1) % self.vae_stride[0] != 0:
            raise ValueError(
                "sliding_hist - 1 must be a multiple of vae_stride[0] due to temporal "
                f"vae. Got {self.sliding_hist} and vae stride {self.vae_stride[0]}"
            )
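        # Illustrative note: with a temporal vae stride of 4, valid sliding_hist
        # values are 1, 5, 9, ... (i.e. 1 + k * vae_stride[0]).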
        if self.load_video_latent:
            raise NotImplementedError("Loading video latent is not implemented yet")
        super().__init__(cfg)

    @staticmethod
    def classes_to_shard():
        classes = {WanAttentionBlock, T5CrossAttention, T5SelfAttention}
        return classes

    @property
    def is_inference(self) -> bool:
        return self._trainer is None or not self.trainer.training

    def configure_model(self):
        logging.info("Building model...")

        # Initialize text encoder
        if not self.cfg.load_prompt_embed:
            text_encoder = (
                umt5_xxl(
                    encoder_only=True,
                    return_tokenizer=False,
                    dtype=torch.bfloat16 if self.is_inference else self.dtype,
                    device=torch.device("cpu"),
                )
                .eval()
                .requires_grad_(False)
            )
            if self.cfg.text_encoder.ckpt_path is not None:
                text_encoder.load_state_dict(
                    torch.load(
                        self.cfg.text_encoder.ckpt_path,
                        map_location="cpu",
                        weights_only=True,
                        # mmap=True,
                    )
                )
            if self.cfg.text_encoder.compile:
                text_encoder = torch.compile(text_encoder)
        else:
            text_encoder = None
        self.text_encoder = text_encoder

        # Initialize tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=self.cfg.text_encoder.name,
            seq_len=self.cfg.text_encoder.text_len,
            clean="whitespace",
        )

        # Initialize VAE
        self.vae = (
            video_vae_factory(
                pretrained_path=self.cfg.vae.ckpt_path,
                z_dim=self.cfg.vae.z_dim,
            )
            .eval()
            .requires_grad_(False)
        ).to(self.dtype)
        self.register_buffer(
            "vae_mean", torch.tensor(self.cfg.vae.mean, dtype=self.dtype)
        )
        self.register_buffer(
            "vae_inv_std", 1.0 / torch.tensor(self.cfg.vae.std, dtype=self.dtype)
        )
        self.vae_scale = [self.vae_mean, self.vae_inv_std]
        if self.cfg.vae.compile:
            self.vae = torch.compile(self.vae)

        # Initialize main diffusion model
        if self.cfg.model.tuned_ckpt_path is None:
            self.model = WanModel.from_pretrained(self.cfg.model.ckpt_path)
        else:
            print("Loading model from config")
            from accelerate import init_empty_weights, load_checkpoint_and_dispatch

            with init_empty_weights():
                self.model = WanModel.from_config(
                    WanModel._dict_from_json_file(
                        self.cfg.model.ckpt_path + "/config.json"
                    )
                )
            print("Loading state dict")
            self.model = load_checkpoint_and_dispatch(
                self.model,
                self.cfg.model.tuned_ckpt_path,
                device_map="auto",
                dtype=torch.bfloat16,
                no_split_module_classes=["WanAttentionBlock"],
            )
            print("State dict loaded successfully")
        # self.model = WanModel(
        #     model_type=self.cfg.model.model_type,
        #     patch_size=self.cfg.model.patch_size,
        #     text_len=self.cfg.text_encoder.text_len,
        #     in_dim=self.cfg.model.in_dim,
        #     dim=self.cfg.model.dim,
        #     ffn_dim=self.cfg.model.ffn_dim,
        #     freq_dim=self.cfg.model.freq_dim,
        #     text_dim=self.cfg.text_encoder.text_dim,
        #     out_dim=self.cfg.model.out_dim,
        #     num_heads=self.cfg.model.num_heads,
        #     num_layers=self.cfg.model.num_layers,
        #     window_size=self.cfg.model.window_size,
        #     qk_norm=self.cfg.model.qk_norm,
        #     cross_attn_norm=self.cfg.model.cross_attn_norm,
        #     eps=self.cfg.model.eps,
        # )

        if not self.is_inference:
            self.model.to(self.dtype).train()
            if self.gradient_checkpointing_rate > 0:
                self.model.gradient_checkpointing_enable(
                    p=self.gradient_checkpointing_rate
                )
        if self.cfg.model.compile:
            self.model = torch.compile(self.model)

        self.training_scheduler, self.training_timesteps = self.build_scheduler(True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            [
                {"params": self.model.parameters(), "lr": self.cfg.lr},
                {"params": self.vae.parameters(), "lr": 0},
            ],
            weight_decay=self.cfg.weight_decay,
            betas=self.cfg.betas,
        )
        # optimizer = torch.optim.AdamW(
        #     self.model.parameters(),
        #     lr=self.cfg.lr,
        #     weight_decay=self.cfg.weight_decay,
        #     betas=self.cfg.betas,
        # )
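        # Lightning-style scheduler config: interval="step" advances the LR scheduler
        # once per optimizer step rather than once per epoch.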
        lr_scheduler_config = {
            "scheduler": get_scheduler(
                optimizer=optimizer,
                **self.cfg.lr_scheduler,
            ),
            "interval": "step",
            "frequency": 1,
        }
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler_config,
        }

    def _load_tuned_state_dict(self, prefix="model."):
        ckpt = torch.load(
            self.cfg.model.tuned_ckpt_path,
            mmap=True,
            map_location="cpu",
            weights_only=True,
        )
        return ckpt

    def build_scheduler(self, is_training=True):
        # Solver
        if self.sample_solver == "unipc":
            scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=self.sample_shift,
                use_dynamic_shifting=False,
            )
            if not is_training:
                scheduler.set_timesteps(
                    self.sample_steps, device=self.device, shift=self.sample_shift
                )
            timesteps = scheduler.timesteps
        elif self.sample_solver == "dpm++":
            scheduler = FlowDPMSolverMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=self.sample_shift,
                use_dynamic_shifting=False,
            )
            if not is_training:
                sampling_sigmas = get_sampling_sigmas(
                    self.sample_steps, self.sample_shift
                )
                timesteps, _ = retrieve_timesteps(
                    scheduler, device=self.device, sigmas=sampling_sigmas
                )
            else:
                timesteps = scheduler.timesteps
        else:
            raise NotImplementedError("Unsupported solver.")
        return scheduler, timesteps

    def encode_text(self, texts):
        ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(self.device)
        mask = mask.to(self.device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.text_encoder(ids, mask)
        # Trim each embedding to its true (unpadded) length.
        return [u[:v] for u, v in zip(context, seq_lens)]

    def encode_video(self, videos):
        """videos: [B, C, T, H, W]"""
        return self.vae.encode(videos, self.vae_scale)

    def decode_video(self, zs):
        return self.vae.decode(zs, self.vae_scale).clamp_(-1, 1)

    def clone_batch(self, batch):
        new_batch = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                new_batch[k] = v.clone()
            else:
                new_batch[k] = v
        return new_batch

    @torch.no_grad()
    def prepare_embeds(self, batch):
        videos = batch["videos"]
        prompts = batch["prompts"]
        batch_size, t, _, h, w = videos.shape
        if t != self.n_frames:
            raise ValueError(f"Number of frames in videos must be {self.n_frames}")
        if h != self.height or w != self.width:
            raise ValueError(
                f"Height and width of videos must be {self.height} and {self.width}"
            )
        if not self.cfg.load_prompt_embed:
            prompt_embeds = self.encode_text(prompts)
        else:
            prompt_embeds = batch["prompt_embeds"].to(self.dtype)
            prompt_embed_len = batch["prompt_embed_len"]
            prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, prompt_embed_len)]
        video_lat = self.encode_video(rearrange(videos, "b t c h w -> b c t h w"))
        # video_lat ~ (b, lat_c, lat_t, lat_h, lat_w)
        batch["prompt_embeds"] = prompt_embeds
        batch["video_lat"] = video_lat
        batch["image_embeds"] = None
        batch["clip_embeds"] = None
        return batch

    def add_training_noise(self, video_lat):
        b, _, f = video_lat.shape[:3]
        device = video_lat.device
        if self.diffusion_type == "discrete":
            video_lat = rearrange(video_lat, "b c f h w -> (b f) c h w")
            noise = torch.randn_like(video_lat)
            timesteps = self.num_train_timesteps
            if self.diffusion_forcing.enabled:
                match self.diffusion_forcing.mode:
                    case "independent":
                        t = np.random.randint(timesteps, size=(b, f))
                        if np.random.rand() < self.diffusion_forcing.clean_hist_prob:
                            t[:, 0] = timesteps - 1
                    case "rand_history":
                        # Support history lengths 1 through 6, with length 1 sampled
                        # half of the time.
                        possible_hist_lengths = [1, 2, 3, 4, 5, 6]
                        hist_length_probs = [0.5, 0.1, 0.1, 0.1, 0.1, 0.1]
                        t = np.zeros((b, f), dtype=np.int64)
                        for i in range(b):
                            hist_len_idx = np.random.choice(
                                len(possible_hist_lengths), p=hist_length_probs
                            )
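                            # All history frames share one noise level and all future
                            # frames share another, so the model learns to predict the
                            # future from a (possibly noisy) history prefix.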
                            hist_len = possible_hist_lengths[hist_len_idx]
                            history_t = np.random.randint(timesteps)
                            future_t = np.random.randint(timesteps)
                            t[i, :hist_len] = history_t
                            t[i, hist_len:] = future_t
                            if (
                                np.random.rand()
                                < self.diffusion_forcing.clean_hist_prob
                            ):
                                t[i, :hist_len] = timesteps - 1
                t = self.training_timesteps[t.flatten()].reshape(b, f)
                t_expanded = t.flatten()
            else:
                t = np.random.randint(timesteps, size=(b,))
                t_expanded = repeat(t, "b -> (b f)", f=f)
                t = self.training_timesteps[t]
                t_expanded = self.training_timesteps[t_expanded]
            noisy_lat = self.training_scheduler.add_noise(video_lat, noise, t_expanded)
            noisy_lat = rearrange(noisy_lat, "(b f) c h w -> b c f h w", b=b, f=f)
            noise = rearrange(noise, "(b f) c h w -> b c f h w", b=b, f=f)
        elif self.diffusion_type == "continuous":
            # Continuous time steps:
            # 1. sample t ~ U[0, 1]
            # 2. shift t with: t = t * self.sample_shift / (1 + (self.sample_shift - 1) * t)
            # 3. expand t to [b, 1, f, 1, 1] (or [b, 1, 1, 1, 1])
            # 4. compute noisy_lat = video_lat * (1.0 - t_expanded) + noise * t_expanded
            # 5. scale t to [0, num_train_timesteps]
            # Returns:
            #   t: in [0, num_train_timesteps], shape [b, f] or [b,], dtype torch.float32
            #   video_lat: shape [b, c, f, h, w]
            #   noise: shape [b, c, f, h, w]
            uniform_dist = torch.distributions.uniform.Uniform(0, 1)
            noise = torch.randn_like(video_lat)  # [b, c, f, h, w]
            if self.diffusion_forcing.enabled:
                match self.diffusion_forcing.mode:
                    case "independent":
                        t = uniform_dist.sample((b, f)).to(device)
                        if np.random.rand() < self.diffusion_forcing.clean_hist_prob:
                            t[:, 0] = 0.0
                    case "rand_history":
                        # Support history lengths 1 through 6, with length 1 sampled
                        # half of the time.
                        possible_hist_lengths = [1, 2, 3, 4, 5, 6]
                        hist_length_probs = [0.5, 0.1, 0.1, 0.1, 0.1, 0.1]
                        t = np.zeros((b, f), dtype=np.float32)
                        for i in range(b):
                            hist_len_idx = np.random.choice(
                                len(possible_hist_lengths), p=hist_length_probs
                            )
                            hist_len = possible_hist_lengths[hist_len_idx]
                            history_t = np.random.uniform(0, 1)
                            future_t = np.random.uniform(0, 1)
                            t[i, :hist_len] = history_t
                            t[i, hist_len:] = future_t
                            if (
                                np.random.rand()
                                < self.diffusion_forcing.clean_hist_prob
                            ):
                                t[i, :hist_len] = 0
                        # Cast dtype of t.
                        t = torch.from_numpy(t).to(device)
                        t = t.float()
                # t is [b, f] in range [0, 1], of dtype torch.float32; 0 indicates clean.
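                # Standard flow-matching timestep shift:
                # t' = shift * t / (1 + (shift - 1) * t); for shift > 1 this pushes
                # timesteps toward 1 (the pure-noise end of the schedule).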
                t = t * self.sample_shift / (1 + (self.sample_shift - 1) * t)
                t_expanded = (
                    t.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # [b, f] -> [b, 1, f, 1, 1]
                # [b, c, f, h, w] * [b, 1, f, 1, 1] + [b, c, f, h, w] * [b, 1, f, 1, 1]
                noisy_lat = video_lat * (1.0 - t_expanded) + noise * t_expanded
                t = t * self.num_train_timesteps  # [b, f] -> [b, f]
                # now t is in [0, num_train_timesteps] of shape [b, f]
            else:
                t = uniform_dist.sample((b,)).to(device)
                t = t * self.sample_shift / (1 + (self.sample_shift - 1) * t)
                t_expanded = t.view(-1, 1, 1, 1, 1)
                noisy_lat = video_lat * (1.0 - t_expanded) + noise * t_expanded
                t = t * self.num_train_timesteps  # [b,]
                # now t is in [0, num_train_timesteps] of shape [b,]
        else:
            raise NotImplementedError("Unsupported time step type.")
        return noisy_lat, noise, t

    def remove_noise(self, flow_pred, t, video_pred_lat):
        b, _, f = video_pred_lat.shape[:3]
        video_pred_lat = rearrange(video_pred_lat, "b c f h w -> (b f) c h w")
        flow_pred = rearrange(flow_pred, "b c f h w -> (b f) c h w")
        if t.ndim == 1:
            t = repeat(t, "b -> (b f)", f=f)
        elif t.ndim == 2:
            t = t.flatten()
        video_pred_lat = self.inference_scheduler.step(
            flow_pred,
            t,
            video_pred_lat,
            return_dict=False,
        )[0]
        video_pred_lat = rearrange(video_pred_lat, "(b f) c h w -> b c f h w", b=b)
        return video_pred_lat

    def training_step(self, batch, batch_idx=None):
        batch = self.prepare_embeds(batch)
        clip_embeds = batch["clip_embeds"]
        image_embeds = batch["image_embeds"]
        prompt_embeds = batch["prompt_embeds"]
        video_lat = batch["video_lat"]
        noisy_lat, noise, t = self.add_training_noise(video_lat)
        flow = noise - video_lat
        flow_pred = self.model(
            noisy_lat,
            t=t,
            context=prompt_embeds,
            clip_fea=clip_embeds,
            seq_len=self.max_tokens,
            y=image_embeds,
        )
        loss = torch.nn.functional.mse_loss(flow_pred, flow)
        if self.global_step % self.cfg.logging.loss_freq == 0:
            self.log("train/loss", loss, sync_dist=True)
        return loss

    @torch.no_grad()
    def sample_seq(self, batch, hist_len=1, pbar=None):
        """
        Main sampling loop. Only the first hist_len frames are used for conditioning.

        batch: dict
            batch["videos"]: [B, T, C, H, W]
            batch["prompts"]: [B]
        """
        if (hist_len - 1) % self.vae_stride[0] != 0:
            raise ValueError(
                "hist_len - 1 must be a multiple of vae_stride[0] due to temporal vae. "
                f"Got {hist_len} and vae stride {self.vae_stride[0]}"
            )
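        # Convert the history length from pixel frames to latent frames: the causal
        # video VAE maps frames 1, 1 + s, 1 + 2s, ... (s = vae_stride[0]) onto
        # consecutive latent frames.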
" f"Got {hist_len} and vae stride {self.vae_stride[0]}" ) hist_len = (hist_len - 1) // self.vae_stride[0] + 1 # length in latent self.inference_scheduler, self.inference_timesteps = self.build_scheduler(False) lang_guidance = self.lang_guidance if self.lang_guidance else 0 hist_guidance = self.hist_guidance if self.hist_guidance else 0 batch = self.prepare_embeds(batch) clip_embeds = batch["clip_embeds"] image_embeds = batch["image_embeds"] prompt_embeds = batch["prompt_embeds"] video_lat = batch["video_lat"] batch_size = video_lat.shape[0] video_pred_lat = torch.randn_like(video_lat) if self.lang_guidance: neg_prompt_embeds = self.encode_text( [self.neg_prompt] * len(batch["prompts"]) ) if pbar is None: pbar = tqdm(range(len(self.inference_timesteps)), desc="Sampling") for t in self.inference_timesteps: if self.diffusion_forcing.enabled: video_pred_lat[:, :, :hist_len] = video_lat[:, :, :hist_len] t_expanded = torch.full((batch_size, self.lat_t), t, device=self.device) t_expanded[:, :hist_len] = self.inference_timesteps[-1] else: t_expanded = torch.full((batch_size,), t, device=self.device) # normal conditional sampling flow_pred = self.model( video_pred_lat, t=t_expanded, context=prompt_embeds, seq_len=self.max_tokens, clip_fea=clip_embeds, y=image_embeds, ) if lang_guidance and hist_guidance and self.diffusion_forcing.enabled and lang_guidance == hist_guidance: # efficient guidance in case language and history guidance have the same strength no_hist_video_pred_lat = video_pred_lat.clone() no_hist_video_pred_lat[:, :, :hist_len] = torch.randn_like( no_hist_video_pred_lat[:, :, :hist_len] ) t_expanded[:, :hist_len] = self.inference_timesteps[0] no_cond_flow_pred = self.model( no_hist_video_pred_lat, t=t_expanded, context=neg_prompt_embeds, seq_len=self.max_tokens, clip_fea=clip_embeds, y=image_embeds, ) flow_pred = flow_pred * (1 + lang_guidance) - lang_guidance * no_cond_flow_pred else: # language unconditional sampling if lang_guidance: no_lang_flow_pred = self.model( video_pred_lat, t=t_expanded, context=neg_prompt_embeds, seq_len=self.max_tokens, clip_fea=clip_embeds, y=image_embeds, ) else: no_lang_flow_pred = torch.zeros_like(flow_pred) # history guidance sampling: if hist_guidance and self.diffusion_forcing.enabled: no_hist_video_pred_lat = video_pred_lat.clone() no_hist_video_pred_lat[:, :, :hist_len] = torch.randn_like( no_hist_video_pred_lat[:, :, :hist_len] ) t_expanded[:, :hist_len] = self.inference_timesteps[0] no_hist_flow_pred = self.model( no_hist_video_pred_lat, t=t_expanded, context=prompt_embeds, seq_len=self.max_tokens, clip_fea=clip_embeds, y=image_embeds, ) else: no_hist_flow_pred = torch.zeros_like(flow_pred) flow_pred = flow_pred * (1 + lang_guidance + hist_guidance) flow_pred = ( flow_pred - lang_guidance * no_lang_flow_pred - hist_guidance * no_hist_flow_pred ) video_pred_lat = self.remove_noise(flow_pred, t, video_pred_lat) pbar.update(1) video_pred_lat[:, :, :hist_len] = video_lat[:, :, :hist_len] video_pred = self.decode_video(video_pred_lat) video_pred = rearrange(video_pred, "b c t h w -> b t c h w") return video_pred def validation_step(self, batch, batch_idx=None): video_pred = self.sample_seq(batch) self.visualize(video_pred, batch) def visualize(self, video_pred, batch): video_gt = batch["videos"] if self.cfg.logging.video_type == "single": video_vis = video_pred.cpu() else: video_vis = torch.cat([video_pred, video_gt], dim=-1).cpu() video_vis = video_vis * 0.5 + 0.5 video_vis = rearrange(self.all_gather(video_vis), "p b ... 
-> (p b) ...") all_prompts = [None for _ in range(dist.get_world_size())] dist.all_gather_object(all_prompts, batch["prompts"]) all_prompts = [item for sublist in all_prompts for item in sublist] if is_rank_zero: if self.cfg.logging.video_type == "single": for i in range(min(len(video_vis), 16)): self.log_video( f"validation_vis/video_pred_{i}", video_vis[i], fps=self.cfg.logging.fps, caption=all_prompts[i], ) else: self.log_video( "validation_vis/video_pred", video_vis[:16], fps=self.cfg.logging.fps, step=self.global_step, ) def maybe_reset_socket(self): if not self.socket: ctx = zmq.Context() socket = ctx.socket(zmq.ROUTER) socket.setsockopt(zmq.ROUTER_HANDOVER, 1) socket.bind(f"tcp://*:{self.cfg.serving.port}") self.socket = socket print(f"Server ready on port {self.cfg.serving.port}...") @torch.no_grad() def test_step(self, batch, batch_idx): """ This function is used to test the model. It will receive an image and a prompt from remote gradio and generate a video. The remote client shall run scripts/inference_client.py to send requests to this server. """ # Only rank zero sets up the socket if is_rank_zero: self.maybe_reset_socket() print(f"Waiting for request on local rank: {dist.get_rank()}") if is_rank_zero: ident, payload = self.socket.recv_multipart() request = msgpack.unpackb(payload, raw=False) print(f"Received request with prompt: {request['prompt']}") # Prepare data to broadcast image_bytes = request["image"] prompt = request["prompt"] data_to_broadcast = [image_bytes, prompt] else: data_to_broadcast = [None, None] # Broadcast the image and prompt to all ranks dist.broadcast_object_list(data_to_broadcast, src=0) image_bytes, prompt = data_to_broadcast transform = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), transforms.RandomResizedCrop( size=(self.height, self.width), scale=(1.0, 1.0), ratio=(self.width / self.height, self.width / self.height), interpolation=transforms.InterpolationMode.BICUBIC, ), ] ) pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB") image = transform(pil_image) batch["videos"][:, 0] = image[None] prompt_segments = prompt.split("") hist_len = 1 videos = batch["videos"][:, :hist_len] for i, prompt in enumerate(prompt_segments): # extending the video until all prompt segments are used print(f"Generating task {i+1} out of {len(prompt_segments)} sub-tasks") batch["prompts"] = [prompt] * batch["videos"].shape[0] batch["videos"][:, :hist_len] = videos[:, -hist_len:] videos = torch.cat([videos, self.sample_seq(batch, hist_len)], dim=1) videos = torch.clamp(videos, -1, 1) hist_len = self.sliding_hist videos = rearrange(self.all_gather(videos), "p b t c h w -> (p b) t h w c") videos = videos.float().cpu().numpy() # Only rank zero sends the reply if is_rank_zero: videos = np.clip(videos * 0.5 + 0.5, 0, 1) videos = (videos * 255).astype(np.uint8) # Convert videos to mp4 bytes using the utility function video_bytes_list = [ numpy_to_mp4_bytes(video, fps=self.cfg.logging.fps) for video in videos ] # Send the reply reply = {"videos": video_bytes_list} self.socket.send_multipart([ident, msgpack.packb(reply)]) print(f"Sent reply to {ident}") self.log_video( "test_vis/video_pred", rearrange(videos, "b t h w c -> b t c h w"), fps=self.cfg.logging.fps, caption="\n".join(prompt_segments), )