import logging
import gc
import torch
import numpy as np
import torch.distributed as dist
from einops import rearrange, repeat
from tqdm import tqdm
from algorithms.common.base_pytorch_algo import BasePytorchAlgo
from transformers import get_scheduler
import zmq
import msgpack
import io
from PIL import Image
import torchvision.transforms as transforms
from utils.video_utils import numpy_to_mp4_bytes

from .modules.model import WanModel, WanAttentionBlock
from .modules.t5 import umt5_xxl, T5CrossAttention, T5SelfAttention
from .modules.tokenizers import HuggingfaceTokenizer
from .modules.vae import video_vae_factory
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from utils.distributed_utils import is_rank_zero


def print_module_hierarchy(model, indent=0):
    for name, module in model.named_children():
        print(" " * indent + f"{name}: {type(module)}")
        print_module_hierarchy(module, indent + 2)


class WanTextToVideo(BasePytorchAlgo):
    """
    Text-to-video flow-matching algorithm built around the Wan transformer
    (WanModel), a UMT5 text encoder, and a video VAE. Supports standard
    full-sequence denoising as well as diffusion forcing with per-frame noise
    levels for history-conditioned, autoregressive generation.
    """

    def __init__(self, cfg):
        self.num_train_timesteps = cfg.num_train_timesteps
        self.height = cfg.height
        self.width = cfg.width
        self.n_frames = cfg.n_frames
        self.gradient_checkpointing_rate = cfg.gradient_checkpointing_rate
        self.sample_solver = cfg.sample_solver
        self.sample_steps = cfg.sample_steps
        self.sample_shift = cfg.sample_shift
        self.lang_guidance = cfg.lang_guidance
        self.neg_prompt = cfg.neg_prompt
        self.hist_guidance = cfg.hist_guidance
        self.sliding_hist = cfg.sliding_hist
        self.diffusion_forcing = cfg.diffusion_forcing
        self.vae_stride = cfg.vae.stride
        self.patch_size = cfg.model.patch_size
        self.diffusion_type = cfg.diffusion_type
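        # Latent-space geometry: the VAE keeps the first frame and downsamples
        # the remaining temporal axis by vae_stride[0], and downsamples space
        # by vae_stride[1] x vae_stride[2]; the transformer then patchifies
        # each latent frame, so the token count is lat_t * lat_h * lat_w
        # divided by the spatial patch area.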
        self.lat_h = self.height // self.vae_stride[1]
        self.lat_w = self.width // self.vae_stride[2]
        self.lat_t = 1 + (self.n_frames - 1) // self.vae_stride[0]
        self.lat_c = cfg.vae.z_dim
        self.max_area = self.height * self.width
        self.max_tokens = (
            self.lat_t
            * self.lat_h
            * self.lat_w
            // (self.patch_size[1] * self.patch_size[2])
        )

        self.load_prompt_embed = cfg.load_prompt_embed
        self.load_video_latent = cfg.load_video_latent
        self.socket = None
        if (self.sliding_hist - 1) % self.vae_stride[0] != 0:
            raise ValueError(
                "sliding_hist - 1 must be a multiple of vae_stride[0] due to temporal "
                f"vae. Got {self.sliding_hist} and vae stride {self.vae_stride[0]}"
            )
        if self.load_video_latent:
            raise NotImplementedError("Loading video latent is not implemented yet")
        super().__init__(cfg)

    @staticmethod
    def classes_to_shard():
        classes = {WanAttentionBlock, T5CrossAttention, T5SelfAttention}
        return classes

    @property
    def is_inference(self) -> bool:
        return self._trainer is None or not self.trainer.training

    def configure_model(self):
        logging.info("Building model...")

        if not self.cfg.load_prompt_embed:
            text_encoder = (
                umt5_xxl(
                    encoder_only=True,
                    return_tokenizer=False,
                    dtype=torch.bfloat16 if self.is_inference else self.dtype,
                    device=torch.device("cpu"),
                )
                .eval()
                .requires_grad_(False)
            )
            if self.cfg.text_encoder.ckpt_path is not None:
                text_encoder.load_state_dict(
                    torch.load(
                        self.cfg.text_encoder.ckpt_path,
                        map_location="cpu",
                        weights_only=True,
                    )
                )
            if self.cfg.text_encoder.compile:
                text_encoder = torch.compile(text_encoder)
        else:
            text_encoder = None
        self.text_encoder = text_encoder

        self.tokenizer = HuggingfaceTokenizer(
            name=self.cfg.text_encoder.name,
            seq_len=self.cfg.text_encoder.text_len,
            clean="whitespace",
        )

        self.vae = (
            video_vae_factory(
                pretrained_path=self.cfg.vae.ckpt_path,
                z_dim=self.cfg.vae.z_dim,
            )
            .eval()
            .requires_grad_(False)
        ).to(self.dtype)
        self.register_buffer(
            "vae_mean", torch.tensor(self.cfg.vae.mean, dtype=self.dtype)
        )
        self.register_buffer(
            "vae_inv_std", 1.0 / torch.tensor(self.cfg.vae.std, dtype=self.dtype)
        )
        self.vae_scale = [self.vae_mean, self.vae_inv_std]
        if self.cfg.vae.compile:
            self.vae = torch.compile(self.vae)
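        # Two loading paths: the base checkpoint is loaded directly via
        # from_pretrained, while a fine-tuned checkpoint is built on the meta
        # device (init_empty_weights) and materialized with
        # load_checkpoint_and_dispatch so the weights are never allocated twice.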
        if self.cfg.model.tuned_ckpt_path is None:
            self.model = WanModel.from_pretrained(self.cfg.model.ckpt_path)
        else:
            print("Loading model from config")
            from accelerate import init_empty_weights, load_checkpoint_and_dispatch

            with init_empty_weights():
                self.model = WanModel.from_config(
                    WanModel._dict_from_json_file(
                        self.cfg.model.ckpt_path + "/config.json"
                    )
                )
            print("Loading state dict")
            self.model = load_checkpoint_and_dispatch(
                self.model,
                self.cfg.model.tuned_ckpt_path,
                device_map="auto",
                dtype=torch.bfloat16,
                no_split_module_classes=["WanAttentionBlock"],
            )
            print("State dict loaded successfully")

        if not self.is_inference:
            self.model.to(self.dtype).train()
            if self.gradient_checkpointing_rate > 0:
                self.model.gradient_checkpointing_enable(
                    p=self.gradient_checkpointing_rate
                )
        if self.cfg.model.compile:
            self.model = torch.compile(self.model)

        self.training_scheduler, self.training_timesteps = self.build_scheduler(True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            [
                {"params": self.model.parameters(), "lr": self.cfg.lr},
                {"params": self.vae.parameters(), "lr": 0},
            ],
            weight_decay=self.cfg.weight_decay,
            betas=self.cfg.betas,
        )

        lr_scheduler_config = {
            "scheduler": get_scheduler(
                optimizer=optimizer,
                **self.cfg.lr_scheduler,
            ),
            "interval": "step",
            "frequency": 1,
        }

        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler_config,
        }

    def _load_tuned_state_dict(self, prefix="model."):
        ckpt = torch.load(
            self.cfg.model.tuned_ckpt_path,
            mmap=True,
            map_location="cpu",
            weights_only=True,
        )
        return ckpt

    def build_scheduler(self, is_training=True):
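        """
        Build the flow-matching solver and its timestep grid. During training
        the scheduler's default grid is returned; at inference the grid is
        re-discretized to sample_steps steps with the configured shift.
        """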
        if self.sample_solver == "unipc":
            scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=self.sample_shift,
                use_dynamic_shifting=False,
            )
            if not is_training:
                scheduler.set_timesteps(
                    self.sample_steps, device=self.device, shift=self.sample_shift
                )
            timesteps = scheduler.timesteps
        elif self.sample_solver == "dpm++":
            scheduler = FlowDPMSolverMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=self.sample_shift,
                use_dynamic_shifting=False,
            )
            if not is_training:
                sampling_sigmas = get_sampling_sigmas(
                    self.sample_steps, self.sample_shift
                )
                timesteps, _ = retrieve_timesteps(
                    scheduler, device=self.device, sigmas=sampling_sigmas
                )
        else:
            raise NotImplementedError("Unsupported solver.")
        return scheduler, timesteps

    def encode_text(self, texts):
        ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(self.device)
        mask = mask.to(self.device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.text_encoder(ids, mask)
        return [u[:v] for u, v in zip(context, seq_lens)]

    def encode_video(self, videos):
        """videos: [B, C, T, H, W]"""
        return self.vae.encode(videos, self.vae_scale)

    def decode_video(self, zs):
        return self.vae.decode(zs, self.vae_scale).clamp_(-1, 1)

    def clone_batch(self, batch):
        new_batch = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                new_batch[k] = v.clone()
            else:
                new_batch[k] = v
        return new_batch

    @torch.no_grad()
    def prepare_embeds(self, batch):
        videos = batch["videos"]
        prompts = batch["prompts"]

        batch_size, t, _, h, w = videos.shape

        if t != self.n_frames:
            raise ValueError(f"Number of frames in videos must be {self.n_frames}")
        if h != self.height or w != self.width:
            raise ValueError(
                f"Height and width of videos must be {self.height} and {self.width}"
            )

        if not self.cfg.load_prompt_embed:
            prompt_embeds = self.encode_text(prompts)
        else:
            prompt_embeds = batch["prompt_embeds"].to(self.dtype)
            prompt_embed_len = batch["prompt_embed_len"]
            prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, prompt_embed_len)]

        video_lat = self.encode_video(rearrange(videos, "b t c h w -> b c t h w"))

        batch["prompt_embeds"] = prompt_embeds
        batch["video_lat"] = video_lat
        batch["image_embeds"] = None
        batch["clip_embeds"] = None

        return batch

    def add_training_noise(self, video_lat):
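        """
        Corrupt clean video latents [B, C, F, H, W] for training.

        Returns (noisy_lat, noise, t), where t follows the model's
        [0, num_train_timesteps] timestep convention and is per-sample, or
        per-frame when diffusion forcing is enabled.
        """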
        b, _, f = video_lat.shape[:3]
        device = video_lat.device
        if self.diffusion_type == "discrete":
            video_lat = rearrange(video_lat, "b c f h w -> (b f) c h w")
            noise = torch.randn_like(video_lat)
            timesteps = self.num_train_timesteps
            if self.diffusion_forcing.enabled:
                match self.diffusion_forcing.mode:
                    case "independent":
                        t = np.random.randint(timesteps, size=(b, f))
                        if np.random.rand() < self.diffusion_forcing.clean_hist_prob:
                            t[:, 0] = timesteps - 1
                    case "rand_history":
                        possible_hist_lengths = [1, 2, 3, 4, 5, 6]
                        hist_length_probs = [0.5, 0.1, 0.1, 0.1, 0.1, 0.1]
                        t = np.zeros((b, f), dtype=np.int64)
                        for i in range(b):
                            hist_len_idx = np.random.choice(
                                len(possible_hist_lengths), p=hist_length_probs
                            )
                            hist_len = possible_hist_lengths[hist_len_idx]
                            history_t = np.random.randint(timesteps)
                            future_t = np.random.randint(timesteps)
                            t[i, :hist_len] = history_t
                            t[i, hist_len:] = future_t
                            if (
                                np.random.rand()
                                < self.diffusion_forcing.clean_hist_prob
                            ):
                                t[i, :hist_len] = timesteps - 1
                t = self.training_timesteps[t.flatten()].reshape(b, f)
                t_expanded = t.flatten()
            else:
                t = np.random.randint(timesteps, size=(b,))
                t_expanded = repeat(t, "b -> (b f)", f=f)
                t = self.training_timesteps[t]
                t_expanded = self.training_timesteps[t_expanded]

            noisy_lat = self.training_scheduler.add_noise(video_lat, noise, t_expanded)
            noisy_lat = rearrange(noisy_lat, "(b f) c h w -> b c f h w", b=b, f=f)
            noise = rearrange(noise, "(b f) c h w -> b c f h w", b=b, f=f)
        elif self.diffusion_type == "continuous":
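            # Continuous-time flow matching: sample t ~ U(0, 1) per sample (or
            # per frame under diffusion forcing), warp it with the shift
            # t <- shift * t / (1 + (shift - 1) * t), interpolate
            # x_t = (1 - t) * x_0 + t * noise, then rescale t by
            # num_train_timesteps for the model's timestep embedding.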
            uniform_dist = torch.distributions.uniform.Uniform(0, 1)
            noise = torch.randn_like(video_lat)

            if self.diffusion_forcing.enabled:
                match self.diffusion_forcing.mode:
                    case "independent":
                        t = uniform_dist.sample((b, f)).to(device)
                        if np.random.rand() < self.diffusion_forcing.clean_hist_prob:
                            t[:, 0] = 0.0
                    case "rand_history":
                        possible_hist_lengths = [1, 2, 3, 4, 5, 6]
                        hist_length_probs = [0.5, 0.1, 0.1, 0.1, 0.1, 0.1]
                        t = np.zeros((b, f), dtype=np.float32)
                        for i in range(b):
                            hist_len_idx = np.random.choice(
                                len(possible_hist_lengths), p=hist_length_probs
                            )
                            hist_len = possible_hist_lengths[hist_len_idx]
                            history_t = np.random.uniform(0, 1)
                            future_t = np.random.uniform(0, 1)
                            t[i, :hist_len] = history_t
                            t[i, hist_len:] = future_t
                            if (
                                np.random.rand()
                                < self.diffusion_forcing.clean_hist_prob
                            ):
                                t[i, :hist_len] = 0

                        t = torch.from_numpy(t).to(device)
                t = t.float()
                t = t * self.sample_shift / (1 + (self.sample_shift - 1) * t)
                t_expanded = t.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)

                noisy_lat = video_lat * (1.0 - t_expanded) + noise * t_expanded
                t = t * self.num_train_timesteps
            else:
                t = uniform_dist.sample((b,)).to(device)
                t = t * self.sample_shift / (1 + (self.sample_shift - 1) * t)
                t_expanded = t.view(-1, 1, 1, 1, 1)

                noisy_lat = video_lat * (1.0 - t_expanded) + noise * t_expanded
                t = t * self.num_train_timesteps
        else:
            raise NotImplementedError(
                f"Unsupported diffusion type: {self.diffusion_type}"
            )

        return noisy_lat, noise, t

    def remove_noise(self, flow_pred, t, video_pred_lat):
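        """
        Advance the noisy latents by one solver step using the predicted flow.
        t may be per-sample (ndim == 1) or per-frame (ndim == 2); frames are
        folded into the batch dimension so each frame is stepped with its own
        timestep.
        """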
        b, _, f = video_pred_lat.shape[:3]
        video_pred_lat = rearrange(video_pred_lat, "b c f h w -> (b f) c h w")
        flow_pred = rearrange(flow_pred, "b c f h w -> (b f) c h w")
        if t.ndim == 1:
            t = repeat(t, "b -> (b f)", f=f)
        elif t.ndim == 2:
            t = t.flatten()
        video_pred_lat = self.inference_scheduler.step(
            flow_pred,
            t,
            video_pred_lat,
            return_dict=False,
        )[0]
        video_pred_lat = rearrange(video_pred_lat, "(b f) c h w -> b c f h w", b=b)
        return video_pred_lat

    def training_step(self, batch, batch_idx=None):
        batch = self.prepare_embeds(batch)
        clip_embeds = batch["clip_embeds"]
        image_embeds = batch["image_embeds"]
        prompt_embeds = batch["prompt_embeds"]
        video_lat = batch["video_lat"]
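        # Flow-matching target: with x_t = (1 - t) * x_0 + t * noise, the
        # velocity d x_t / d t is (noise - x_0), which the model regresses.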
        noisy_lat, noise, t = self.add_training_noise(video_lat)
        flow = noise - video_lat

        flow_pred = self.model(
            noisy_lat,
            t=t,
            context=prompt_embeds,
            clip_fea=clip_embeds,
            seq_len=self.max_tokens,
            y=image_embeds,
        )
        loss = torch.nn.functional.mse_loss(flow_pred, flow)

        if self.global_step % self.cfg.logging.loss_freq == 0:
            self.log("train/loss", loss, sync_dist=True)

        return loss

    @torch.no_grad()
    def sample_seq(self, batch, hist_len=1, pbar=None):
        """
        Main sampling loop. Only the first hist_len frames are used for conditioning.
        batch: dict
            batch["videos"]: [B, T, C, H, W]
            batch["prompts"]: [B]
        """
        if (hist_len - 1) % self.vae_stride[0] != 0:
            raise ValueError(
                "hist_len - 1 must be a multiple of vae_stride[0] due to temporal vae. "
                f"Got {hist_len} and vae stride {self.vae_stride[0]}"
            )
        hist_len = (hist_len - 1) // self.vae_stride[0] + 1

        self.inference_scheduler, self.inference_timesteps = self.build_scheduler(False)
        lang_guidance = self.lang_guidance if self.lang_guidance else 0
        hist_guidance = self.hist_guidance if self.hist_guidance else 0

        batch = self.prepare_embeds(batch)
        clip_embeds = batch["clip_embeds"]
        image_embeds = batch["image_embeds"]
        prompt_embeds = batch["prompt_embeds"]
        video_lat = batch["video_lat"]

        batch_size = video_lat.shape[0]

        video_pred_lat = torch.randn_like(video_lat)
        if self.lang_guidance:
            neg_prompt_embeds = self.encode_text(
                [self.neg_prompt] * len(batch["prompts"])
            )
        if pbar is None:
            pbar = tqdm(range(len(self.inference_timesteps)), desc="Sampling")
        for t in self.inference_timesteps:
            if self.diffusion_forcing.enabled:
                video_pred_lat[:, :, :hist_len] = video_lat[:, :, :hist_len]
                t_expanded = torch.full((batch_size, self.lat_t), t, device=self.device)
                t_expanded[:, :hist_len] = self.inference_timesteps[-1]
            else:
                t_expanded = torch.full((batch_size,), t, device=self.device)

            flow_pred = self.model(
                video_pred_lat,
                t=t_expanded,
                context=prompt_embeds,
                seq_len=self.max_tokens,
                clip_fea=clip_embeds,
                y=image_embeds,
            )
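            # Classifier-free guidance over both the text prompt and the
            # history frames:
            #   flow = (1 + w_lang + w_hist) * cond
            #          - w_lang * flow(no text) - w_hist * flow(no history).
            # When both scales are equal and diffusion forcing is on, a single
            # pass with neither condition is used as the unconditional branch.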
            if (
                lang_guidance
                and hist_guidance
                and self.diffusion_forcing.enabled
                and lang_guidance == hist_guidance
            ):
                # Joint guidance: one extra pass with neither text nor history.
                no_hist_video_pred_lat = video_pred_lat.clone()
                no_hist_video_pred_lat[:, :, :hist_len] = torch.randn_like(
                    no_hist_video_pred_lat[:, :, :hist_len]
                )
                t_expanded[:, :hist_len] = self.inference_timesteps[0]
                no_cond_flow_pred = self.model(
                    no_hist_video_pred_lat,
                    t=t_expanded,
                    context=neg_prompt_embeds,
                    seq_len=self.max_tokens,
                    clip_fea=clip_embeds,
                    y=image_embeds,
                )
                flow_pred = (
                    flow_pred * (1 + lang_guidance)
                    - lang_guidance * no_cond_flow_pred
                )
            else:
                # Separate unconditional passes for text and history.
                if lang_guidance:
                    no_lang_flow_pred = self.model(
                        video_pred_lat,
                        t=t_expanded,
                        context=neg_prompt_embeds,
                        seq_len=self.max_tokens,
                        clip_fea=clip_embeds,
                        y=image_embeds,
                    )
                else:
                    no_lang_flow_pred = torch.zeros_like(flow_pred)

                if hist_guidance and self.diffusion_forcing.enabled:
                    no_hist_video_pred_lat = video_pred_lat.clone()
                    no_hist_video_pred_lat[:, :, :hist_len] = torch.randn_like(
                        no_hist_video_pred_lat[:, :, :hist_len]
                    )
                    t_expanded[:, :hist_len] = self.inference_timesteps[0]
                    no_hist_flow_pred = self.model(
                        no_hist_video_pred_lat,
                        t=t_expanded,
                        context=prompt_embeds,
                        seq_len=self.max_tokens,
                        clip_fea=clip_embeds,
                        y=image_embeds,
                    )
                else:
                    no_hist_flow_pred = torch.zeros_like(flow_pred)

                flow_pred = flow_pred * (1 + lang_guidance + hist_guidance)
                flow_pred = (
                    flow_pred
                    - lang_guidance * no_lang_flow_pred
                    - hist_guidance * no_hist_flow_pred
                )

            video_pred_lat = self.remove_noise(flow_pred, t, video_pred_lat)
            pbar.update(1)

        video_pred_lat[:, :, :hist_len] = video_lat[:, :, :hist_len]
        video_pred = self.decode_video(video_pred_lat)
        video_pred = rearrange(video_pred, "b c t h w -> b t c h w")
        return video_pred

    def validation_step(self, batch, batch_idx=None):
        video_pred = self.sample_seq(batch)
        self.visualize(video_pred, batch)

    def visualize(self, video_pred, batch):
        video_gt = batch["videos"]

        if self.cfg.logging.video_type == "single":
            video_vis = video_pred.cpu()
        else:
            video_vis = torch.cat([video_pred, video_gt], dim=-1).cpu()
        video_vis = video_vis * 0.5 + 0.5
        video_vis = rearrange(self.all_gather(video_vis), "p b ... -> (p b) ...")

        all_prompts = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(all_prompts, batch["prompts"])
        all_prompts = [item for sublist in all_prompts for item in sublist]

        if is_rank_zero:
            if self.cfg.logging.video_type == "single":
                for i in range(min(len(video_vis), 16)):
                    self.log_video(
                        f"validation_vis/video_pred_{i}",
                        video_vis[i],
                        fps=self.cfg.logging.fps,
                        caption=all_prompts[i],
                    )
            else:
                self.log_video(
                    "validation_vis/video_pred",
                    video_vis[:16],
                    fps=self.cfg.logging.fps,
                    step=self.global_step,
                )

    def maybe_reset_socket(self):
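        # Lazily create the serving socket. A ZMQ ROUTER socket keeps the
        # identity of each client, so the reply in test_step can be routed back
        # to the right requester; ROUTER_HANDOVER lets a reconnecting client
        # with the same identity take over its old connection.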
        if not self.socket:
            ctx = zmq.Context()
            socket = ctx.socket(zmq.ROUTER)
            socket.setsockopt(zmq.ROUTER_HANDOVER, 1)
            socket.bind(f"tcp://*:{self.cfg.serving.port}")
            self.socket = socket

            print(f"Server ready on port {self.cfg.serving.port}...")

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        """
        Serve the model for interactive testing.
        Each call receives an image and a prompt from a remote Gradio client and
        replies with a generated video. The remote client should run
        scripts/inference_client.py to send requests to this server.
        """
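        # Request/reply protocol: the client sends a msgpack dict with raw
        # image bytes under "image" and a text prompt under "prompt"; the reply
        # is a msgpack dict with a list of MP4-encoded videos under "videos".
        # Only rank 0 talks to the socket; the request is broadcast to the
        # other ranks so every process runs the same sampling loop.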
        if is_rank_zero:
            self.maybe_reset_socket()

        print(f"Waiting for request on local rank: {dist.get_rank()}")
        if is_rank_zero:
            ident, payload = self.socket.recv_multipart()
            request = msgpack.unpackb(payload, raw=False)
            print(f"Received request with prompt: {request['prompt']}")

            image_bytes = request["image"]
            prompt = request["prompt"]
            data_to_broadcast = [image_bytes, prompt]
        else:
            data_to_broadcast = [None, None]

        # Broadcast the request from rank 0 to all ranks.
        dist.broadcast_object_list(data_to_broadcast, src=0)
        image_bytes, prompt = data_to_broadcast
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                transforms.RandomResizedCrop(
                    size=(self.height, self.width),
                    scale=(1.0, 1.0),
                    ratio=(self.width / self.height, self.width / self.height),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
            ]
        )
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        image = transform(pil_image)
        batch["videos"][:, 0] = image[None]
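        # Autoregressive generation: the prompt may contain several segments
        # separated by "<sep>". Each segment generates one clip, conditioned on
        # the last frames of the previous clip (1 frame for the first segment,
        # sliding_hist frames afterwards).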
        prompt_segments = prompt.split("<sep>")
        hist_len = 1
        videos = batch["videos"][:, :hist_len]
        for i, prompt in enumerate(prompt_segments):
            print(f"Generating task {i + 1} out of {len(prompt_segments)} sub-tasks")
            batch["prompts"] = [prompt] * batch["videos"].shape[0]
            batch["videos"][:, :hist_len] = videos[:, -hist_len:]
            videos = torch.cat([videos, self.sample_seq(batch, hist_len)], dim=1)
            videos = torch.clamp(videos, -1, 1)
            hist_len = self.sliding_hist
        videos = rearrange(self.all_gather(videos), "p b t c h w -> (p b) t h w c")
        videos = videos.float().cpu().numpy()

        if is_rank_zero:
            # Convert from [-1, 1] to uint8 frames and encode each video as MP4.
            videos = np.clip(videos * 0.5 + 0.5, 0, 1)
            videos = (videos * 255).astype(np.uint8)

            video_bytes_list = [
                numpy_to_mp4_bytes(video, fps=self.cfg.logging.fps) for video in videos
            ]

            # Route the reply back to the requesting client.
            reply = {"videos": video_bytes_list}
            self.socket.send_multipart([ident, msgpack.packb(reply)])
            print(f"Sent reply to {ident}")

            self.log_video(
                "test_vis/video_pred",
                rearrange(videos, "b t h w c -> b t c h w"),
                fps=self.cfg.logging.fps,
                caption="<sep>\n".join(prompt_segments),
            )