import argparse
import importlib
from contextlib import contextmanager
from typing import List, NamedTuple, Optional, Tuple

import einops
import imageio
import numpy as np
import torch
import torchvision.transforms.functional as transforms_F

from Cosmos.model_t2w import DiffusionT2WModel
from Cosmos.model_v2w import DiffusionV2WModel
from Cosmos.utils import log, misc
from Cosmos.utils.config_helper import get_config_module, override
from Cosmos.utils.io import load_from_fileobj

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 11):
    from torch.ao import quantization
    from torch.ao.quantization import FakeQuantizeBase, ObserverBase
elif (
    TORCH_VERSION >= (1, 8)
    and hasattr(torch.quantization, "FakeQuantizeBase")
    and hasattr(torch.quantization, "ObserverBase")
):
    from torch import quantization
    from torch.quantization import FakeQuantizeBase, ObserverBase

DEFAULT_AUGMENT_SIGMA = 0.001


def add_common_arguments(parser):
    """Add common command line arguments for text2world and video2world generation.

    Args:
        parser (ArgumentParser): Argument parser to add arguments to

    The arguments include:
        - checkpoint_dir: Base directory containing model weights
        - tokenizer_dir: Directory containing tokenizer weights
        - video_save_name: Output video filename for single video generation
        - video_save_folder: Output directory for batch video generation
        - prompt: Text prompt for single video generation
        - batch_input_path: Path to JSONL file with input prompts for batch video generation
        - negative_prompt: Text prompt describing undesired attributes
        - num_steps: Number of diffusion sampling steps
        - guidance: Classifier-free guidance scale
        - num_video_frames: Number of frames to generate
        - height/width: Output video dimensions
        - fps: Output video frame rate
        - seed: Random seed for reproducibility
        - Various model offloading flags
    """
    parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
    )
    parser.add_argument(
        "--tokenizer_dir",
        type=str,
        default="Cosmos-1.0-Tokenizer-CV8x8x8",
        help="Tokenizer weights directory relative to checkpoint_dir",
    )
    parser.add_argument(
        "--video_save_name",
        type=str,
        default="output",
        help="Output filename for generating a single video",
    )
    parser.add_argument(
        "--video_save_folder",
        type=str,
        default="outputs/",
        help="Output folder for generating a batch of videos",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        help="Text prompt for generating a single video",
    )
    parser.add_argument(
        "--batch_input_path",
        type=str,
        help="Path to a JSONL file of input prompts for generating a batch of videos",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
        "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
        "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
        "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special "
        "effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and "
        "flickering. Overall, the video is of poor quality.",
        help="Negative prompt for the video",
    )
    parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps")
    parser.add_argument("--guidance", type=float, default=7, help="Guidance scale value")
    parser.add_argument("--num_video_frames", type=int, default=121, help="Number of video frames to sample")
    parser.add_argument("--height", type=int, default=704, help="Height of video to sample")
    parser.add_argument("--width", type=int, default=1280, help="Width of video to sample")
    parser.add_argument("--fps", type=int, default=24, help="FPS of the sampled video")
    parser.add_argument("--seed", type=int, default=1, help="Random seed")
    parser.add_argument(
        "--disable_prompt_upsampler",
        action="store_true",
        help="Disable prompt upsampling",
    )
    parser.add_argument(
        "--offload_diffusion_transformer",
        action="store_true",
        help="Offload DiT after inference",
    )
    parser.add_argument(
        "--offload_tokenizer",
        action="store_true",
        help="Offload tokenizer after inference",
    )
    parser.add_argument(
        "--offload_text_encoder_model",
        action="store_true",
        help="Offload text encoder model after inference",
    )
    parser.add_argument(
        "--offload_prompt_upsampler",
        action="store_true",
        help="Offload prompt upsampler after inference",
    )
    parser.add_argument(
        "--offload_guardrail_models",
        action="store_true",
        help="Offload guardrail models after inference",
    )
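

# Minimal usage sketch for the argument helpers below (hedged: the entry-point script and the
# extra --input_image_or_video_path flag are illustrative assumptions, not defined in this module):
#
#     parser = argparse.ArgumentParser(description="World generation")
#     add_common_arguments(parser)
#     parser.add_argument("--input_image_or_video_path", type=str)  # video2world-style extra flag
#     args = parser.parse_args()
#     validate_args(args, "video2world")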


def validate_args(args: argparse.Namespace, inference_type: str) -> None:
    """Validate command line arguments for text2world and video2world generation."""
    assert inference_type in [
        "text2world",
        "video2world",
    ], "Invalid inference_type, must be 'text2world' or 'video2world'"

    # Validate prompt/input requirements
    if inference_type == "text2world" or (inference_type == "video2world" and args.disable_prompt_upsampler):
        assert args.prompt or args.batch_input_path, "--prompt or --batch_input_path must be provided."
    if inference_type == "video2world" and not args.batch_input_path:
        assert (
            args.input_image_or_video_path
        ), "--input_image_or_video_path must be provided for single video generation."


class _IncompatibleKeys(
    NamedTuple(
        "IncompatibleKeys",
        [
            ("missing_keys", List[str]),
            ("unexpected_keys", List[str]),
            ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
        ],
    )
):
    pass


def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
    """Load a model checkpoint with non-strict matching, handling shape mismatches.

    Args:
        model (torch.nn.Module): Model to load weights into
        checkpoint_state_dict (dict): State dict from checkpoint

    Returns:
        _IncompatibleKeys: Named tuple containing:
            - missing_keys: Keys present in model but missing from checkpoint
            - unexpected_keys: Keys present in checkpoint but not in model
            - incorrect_shapes: Keys with mismatched tensor shapes

    The function handles special cases like:
        - Uninitialized parameters
        - Quantization observers
        - TransformerEngine FP8 states
    """
    model_state_dict = model.state_dict()
    incorrect_shapes = []
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            if "_extra_state" in k:
                log.debug(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.")
                continue
            model_param = model_state_dict[k]
            # Allow mismatch for uninitialized (lazy) parameters
            if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter):
                continue
            if not isinstance(model_param, torch.Tensor):
                raise ValueError(
                    f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not."
                )

            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                # Quantization observers and fake-quantize modules may legitimately change
                # parameter/buffer shapes; they handle mismatches in _load_from_state_dict.
                has_observer_base_classes = (
                    TORCH_VERSION >= (1, 8)
                    and hasattr(quantization, "ObserverBase")
                    and hasattr(quantization, "FakeQuantizeBase")
                )
                if has_observer_base_classes:

                    def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
                        # Map a state-dict key such as "foo.bar.weight" to the module "foo.bar".
                        key_parts = key.split(".")[:-1]
                        cur_module = model
                        for key_part in key_parts:
                            cur_module = getattr(cur_module, key_part)
                        return cur_module

                    cls_to_skip = (
                        ObserverBase,
                        FakeQuantizeBase,
                    )
                    target_module = _get_module_for_key(model, k)
                    if isinstance(target_module, cls_to_skip):
                        # Keep these keys in the checkpoint; the module's own loading logic
                        # resolves the shape mismatch.
                        continue

                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)
    incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
    # Ignore TransformerEngine FP8 bookkeeping keys when reporting incompatibilities.
    missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k]
    unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k]
    return _IncompatibleKeys(
        missing_keys=missing_keys,
        unexpected_keys=unexpected_keys,
        incorrect_shapes=incorrect_shapes,
    )
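

# Typical call site (a sketch; `model` and `ckpt_path` are placeholders):
#
#     state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
#     result = non_strict_load_model(model, state_dict)
#     log.debug(result)  # inspect missing_keys / unexpected_keys / incorrect_shapes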


@contextmanager
def skip_init_linear():
    """Temporarily disable nn.Linear weight initialization to speed up model construction."""
    # Monkey-patch reset_parameters and xavier_uniform_ to no-ops while the context is active.
    orig_reset_parameters = torch.nn.Linear.reset_parameters
    torch.nn.Linear.reset_parameters = lambda x: x
    xavier_uniform_ = torch.nn.init.xavier_uniform_
    torch.nn.init.xavier_uniform_ = lambda x: x
    try:
        yield
    finally:
        # Restore the original initializers even if construction fails.
        torch.nn.Linear.reset_parameters = orig_reset_parameters
        torch.nn.init.xavier_uniform_ = xavier_uniform_
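

# Example (a sketch): build layers without paying for random init, then load real weights.
# `pretrained_linear_state_dict` is a hypothetical placeholder.
#
#     with skip_init_linear():
#         layer = torch.nn.Linear(4096, 4096)  # weights left uninitialized here
#     layer.load_state_dict(pretrained_linear_state_dict)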


def load_model_by_config(
    config_job_name,
    config_file="projects/cosmos_video/config/config.py",
    model_class=DiffusionT2WModel,
):
    config_module = get_config_module(config_file)
    config = importlib.import_module(config_module).make_config()

    config = override(config, ["--", f"experiment={config_job_name}"])

    # Check that the config is valid, then freeze it so it cannot change afterwards.
    config.validate()
    config.freeze()

    # Instantiate the model, skipping nn.Linear init since weights are loaded later.
    with skip_init_linear():
        model = model_class(config.model)
    return model


def load_network_model(model: DiffusionT2WModel, ckpt_path: str):
    with skip_init_linear():
        model.set_up_model()
    net_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    log.debug(non_strict_load_model(model.model, net_state_dict))
    model.cuda()


def load_tokenizer_model(model: DiffusionT2WModel, tokenizer_dir: str):
    with skip_init_linear():
        model.set_up_tokenizer(tokenizer_dir)
    model.cuda()
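

# Typical setup sequence (a sketch; the experiment name and paths are placeholders, not values
# defined in this module):
#
#     model = load_model_by_config("my_text2world_experiment", model_class=DiffusionT2WModel)
#     load_network_model(model, "checkpoints/model.pt")
#     load_tokenizer_model(model, "checkpoints/Cosmos-1.0-Tokenizer-CV8x8x8")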


def prepare_data_batch(
    height: int,
    width: int,
    num_frames: int,
    fps: int,
    prompt_embedding: torch.Tensor,
    negative_prompt_embedding: Optional[torch.Tensor] = None,
):
    """Prepare input batch tensors for video generation.

    Args:
        height (int): Height of video frames
        width (int): Width of video frames
        num_frames (int): Number of frames to generate
        fps (int): Frames per second
        prompt_embedding (torch.Tensor): Encoded text prompt embeddings
        negative_prompt_embedding (torch.Tensor, optional): Encoded negative prompt embeddings

    Returns:
        dict: Batch dictionary containing:
            - video: Zero tensor of target video shape
            - t5_text_mask: Attention mask for text embeddings
            - image_size: Target frame dimensions
            - fps: Target frame rate
            - num_frames: Number of frames
            - padding_mask: Frame padding mask
            - t5_text_embeddings: Prompt embeddings
            - neg_t5_text_embeddings: Negative prompt embeddings (if provided)
            - neg_t5_text_mask: Mask for negative embeddings (if provided)
    """
    # Create a blank data batch with the requested dimensions
    data_batch = {
        "video": torch.zeros((1, 3, num_frames, height, width), dtype=torch.uint8).cuda(),
        "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(),
        "image_size": torch.tensor([[height, width, height, width]] * 1, dtype=torch.bfloat16).cuda(),
        "fps": torch.tensor([fps] * 1, dtype=torch.bfloat16).cuda(),
        "num_frames": torch.tensor([num_frames] * 1, dtype=torch.bfloat16).cuda(),
        "padding_mask": torch.zeros((1, 1, height, width), dtype=torch.bfloat16).cuda(),
    }

    # Attach the text prompt embeddings
    t5_embed = prompt_embedding.to(dtype=torch.bfloat16).cuda()
    data_batch["t5_text_embeddings"] = t5_embed

    # Attach the negative prompt embeddings and their mask, if provided
    if negative_prompt_embedding is not None:
        neg_t5_embed = negative_prompt_embedding.to(dtype=torch.bfloat16).cuda()
        data_batch["neg_t5_text_embeddings"] = neg_t5_embed
        data_batch["neg_t5_text_mask"] = torch.ones(1, 512, dtype=torch.bfloat16).cuda()

    return data_batch


def get_video_batch(model, prompt_embedding, negative_prompt_embedding, height, width, fps, num_video_frames):
    """Prepare complete input batch for video generation including latent dimensions.

    Args:
        model: Diffusion model instance
        prompt_embedding (torch.Tensor): Text prompt embeddings
        negative_prompt_embedding (torch.Tensor): Negative prompt embeddings
        height (int): Output video height
        width (int): Output video width
        fps (int): Output video frame rate
        num_video_frames (int): Number of frames to generate

    Returns:
        tuple:
            - data_batch (dict): Complete model input batch
            - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression
    """
    raw_video_batch = prepare_data_batch(
        height=height,
        width=width,
        num_frames=num_video_frames,
        fps=fps,
        prompt_embedding=prompt_embedding,
        negative_prompt_embedding=negative_prompt_embedding,
    )
    state_shape = [
        model.tokenizer.channel,
        model.tokenizer.get_latent_num_frames(num_video_frames),
        height // model.tokenizer.spatial_compression_factor,
        width // model.tokenizer.spatial_compression_factor,
    ]
    return raw_video_batch, state_shape
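

# Worked example (a sketch; it assumes a tokenizer with 16 latent channels and 8x spatial /
# 8x temporal compression -- the exact attribute values depend on the tokenizer checkpoint):
#
#     height, width, num_video_frames = 704, 1280, 121
#     # state_shape would then be roughly:
#     #   C = 16
#     #   T = get_latent_num_frames(121)   -> about (121 - 1) // 8 + 1 = 16
#     #   H = 704 // 8 = 88
#     #   W = 1280 // 8 = 160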


def generate_world_from_text(
    model: DiffusionT2WModel,
    state_shape: list[int],
    is_negative_prompt: bool,
    data_batch: dict,
    guidance: float,
    num_steps: int,
    seed: int,
):
    """Generate video from text prompt using diffusion model.

    Args:
        model (DiffusionT2WModel): Text-to-video diffusion model
        state_shape (list[int]): Latent state dimensions [C,T,H,W]
        is_negative_prompt (bool): Whether negative prompt is provided
        data_batch (dict): Model input batch with embeddings
        guidance (float): Classifier-free guidance scale
        num_steps (int): Number of diffusion sampling steps
        seed (int): Random seed for reproducibility

    Returns:
        np.ndarray: Generated video frames [T,H,W,C], range [0,255]

    The function:
        1. Initializes random latent with maximum noise
        2. Performs guided diffusion sampling
        3. Decodes latents to pixel space
    """
    # Initialize the latent at the maximum noise level sigma_max
    x_sigma_max = (
        misc.arch_invariant_rand(
            (1,) + tuple(state_shape),
            torch.float32,
            model.tensor_kwargs["device"],
            seed,
        )
        * model.sde.sigma_max
    )

    # Run guided diffusion sampling
    sample = model.generate_samples_from_batch(
        data_batch,
        guidance=guidance,
        state_shape=state_shape,
        num_steps=num_steps,
        is_negative_prompt=is_negative_prompt,
        seed=seed,
        x_sigma_max=x_sigma_max,
    )

    return sample


def generate_world_from_video(
    model: DiffusionV2WModel,
    state_shape: list[int],
    is_negative_prompt: bool,
    data_batch: dict,
    guidance: float,
    num_steps: int,
    seed: int,
    condition_latent: torch.Tensor,
    num_input_frames: int,
):
    """Generate video using a conditioning video/image input.

    Args:
        model (DiffusionV2WModel): The diffusion model instance
        state_shape (list[int]): Shape of the latent state [C,T,H,W]
        is_negative_prompt (bool): Whether negative prompt is provided
        data_batch (dict): Batch containing model inputs including text embeddings
        guidance (float): Classifier-free guidance scale for sampling
        num_steps (int): Number of diffusion sampling steps
        seed (int): Random seed for generation
        condition_latent (torch.Tensor): Latent tensor from conditioning video/image file
        num_input_frames (int): Number of input frames

    Returns:
        np.ndarray: Generated video frames in shape [T,H,W,C], range [0,255]
    """
    assert not model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i, "not supported"
    augment_sigma = DEFAULT_AUGMENT_SIGMA

    if condition_latent.shape[2] < state_shape[1]:
        # Zero-pad the condition latent along the temporal dimension to match the state shape
        b, c, t, h, w = condition_latent.shape
        condition_latent = torch.cat(
            [
                condition_latent,
                condition_latent.new_zeros(b, c, state_shape[1] - t, h, w),
            ],
            dim=2,
        ).contiguous()
    num_of_latent_condition = compute_num_latent_frames(model, num_input_frames)

    # Initialize the latent at the maximum noise level sigma_max
    x_sigma_max = (
        misc.arch_invariant_rand(
            (1,) + tuple(state_shape),
            torch.float32,
            model.tensor_kwargs["device"],
            seed,
        )
        * model.sde.sigma_max
    )

    sample = model.generate_samples_from_batch(
        data_batch,
        guidance=guidance,
        state_shape=state_shape,
        num_steps=num_steps,
        is_negative_prompt=is_negative_prompt,
        seed=seed,
        condition_latent=condition_latent,
        num_condition_t=num_of_latent_condition,
        condition_video_augment_sigma_in_inference=augment_sigma,
        x_sigma_max=x_sigma_max,
    )
    return sample


def read_video_or_image_into_frames_BCTHW(
    input_path: str,
    input_path_format: str = "mp4",
    H: int = None,
    W: int = None,
    normalize: bool = True,
    max_frames: int = -1,
    also_return_fps: bool = False,
) -> torch.Tensor:
    """Read video or image file and convert to tensor format.

    Args:
        input_path (str): Path to input video/image file
        input_path_format (str): Format of input file (default: "mp4")
        H (int, optional): Height to resize frames to
        W (int, optional): Width to resize frames to
        normalize (bool): Whether to normalize pixel values to [-1,1] (default: True)
        max_frames (int): Maximum number of frames to read (-1 for all frames)
        also_return_fps (bool): Whether to return fps along with frames

    Returns:
        torch.Tensor | tuple: Video tensor in shape [B,C,T,H,W], optionally with fps if requested
    """
    log.debug(f"Reading video from {input_path}")

    loaded_data = load_from_fileobj(input_path, format=input_path_format)
    frames, meta_data = loaded_data
    if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"):
        frames = np.array(frames[0])
        if frames.shape[-1] > 3:
            # Flatten an RGBA image by compositing it onto a white background
            rgb_channels = frames[..., :3]
            alpha_channel = frames[..., 3] / 255.0  # normalize alpha to [0, 1]

            white_bg = np.ones_like(rgb_channels) * 255

            frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype(
                np.uint8
            )
        frames = [frames]
        fps = 0
    else:
        fps = int(meta_data.get("fps"))
    if max_frames != -1:
        frames = frames[:max_frames]
    input_tensor = np.stack(frames, axis=0)
    input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w")
    if normalize:
        input_tensor = input_tensor / 128.0 - 1.0
        input_tensor = torch.from_numpy(input_tensor).bfloat16()
        log.debug(f"Raw data shape: {input_tensor.shape}")
        if H is not None and W is not None:
            input_tensor = transforms_F.resize(
                input_tensor,
                size=(H, W),
                interpolation=transforms_F.InterpolationMode.BICUBIC,
                antialias=True,
            )
    input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1)
    if normalize:
        input_tensor = input_tensor.to("cuda")
    log.debug(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}")
    if also_return_fps:
        return input_tensor, fps
    return input_tensor


def compute_num_latent_frames(model: DiffusionV2WModel, num_input_frames: int, downsample_factor=8) -> int:
    """This function computes the number of latent frames given the number of input frames.
    Args:
        model (DiffusionV2WModel): video generation model
        num_input_frames (int): number of input frames
        downsample_factor (int): downsample factor for temporal reduce
    Returns:
        int: number of latent frames
    """
    # Count the latent frames contributed by whole VAE chunks
    num_latent_frames = (
        num_input_frames
        // model.tokenizer.video_vae.pixel_chunk_duration
        * model.tokenizer.video_vae.latent_chunk_duration
    )
    # Handle any leftover pixel frames that do not fill a whole chunk
    if num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration == 1:
        num_latent_frames += 1
    elif num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration > 1:
        assert (
            num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1
        ) % downsample_factor == 0, f"num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 must be divisible by {downsample_factor}"
        num_latent_frames += (
            1 + (num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1) // downsample_factor
        )

    return num_latent_frames
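

# Worked example (a sketch; the chunk durations are illustrative assumptions, not values defined
# in this module -- e.g. pixel_chunk_duration=121, latent_chunk_duration=16, downsample_factor=8):
#
#     num_input_frames = 9
#     # 9 // 121 * 16 = 0 latent frames from whole chunks
#     # remainder 9 > 1, and (9 - 1) % 8 == 0, so add 1 + (9 - 1) // 8 = 2
#     # -> compute_num_latent_frames(model, 9) == 2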


def create_condition_latent_from_input_frames(
    model: DiffusionV2WModel,
    input_frames: torch.Tensor,
    num_frames_condition: int = 25,
):
    """Create condition latent for video generation from input frames.

    Takes the last num_frames_condition frames from input as conditioning.

    Args:
        model (DiffusionV2WModel): Video generation model
        input_frames (torch.Tensor): Input video tensor [B,C,T,H,W], range [-1,1]
        num_frames_condition (int): Number of frames to use for conditioning

    Returns:
        tuple: (condition_latent, encode_input_frames) where:
            - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W]
            - encode_input_frames (torch.Tensor): Padded input frames used for encoding
    """
    B, C, T, H, W = input_frames.shape
    num_frames_encode = model.tokenizer.pixel_chunk_duration  # Frames the tokenizer encodes per chunk
    log.debug(
        f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}"
    )

    log.debug(
        f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}"
    )

    assert (
        input_frames.shape[2] >= num_frames_condition
    ), f"input_frames not enough for condition, require at least {num_frames_condition}, get {input_frames.shape[2]}, {input_frames.shape}"
    assert (
        num_frames_encode >= num_frames_condition
    ), f"num_frames_encode should be larger than num_frames_condition, get {num_frames_encode}, {num_frames_condition}"

    # Put the conditioning frames at the beginning of the clip and zero-pad the end to num_frames_encode
    condition_frames = input_frames[:, :, -num_frames_condition:]
    padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W)
    encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2)

    log.debug(
        f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end"
    )
    latent = model.encode(encode_input_frames)
    return latent, encode_input_frames


def get_condition_latent(
    model: DiffusionV2WModel,
    input_image_or_video_path: str,
    num_input_frames: int = 1,
    state_shape: list[int] = None,
):
    """Get condition latent from input image/video file.

    Args:
        model (DiffusionV2WModel): Video generation model
        input_image_or_video_path (str): Path to conditioning image/video
        num_input_frames (int): Number of input frames for video2world prediction
        state_shape (list[int], optional): Latent state shape; defaults to model.state_shape

    Returns:
        torch.Tensor: Encoded latent condition [B,C,T,H,W]
    """
    if state_shape is None:
        state_shape = model.state_shape
    assert num_input_frames > 0, "num_input_frames must be greater than 0"

    # Derive the target pixel resolution from the latent state shape
    H, W = (
        state_shape[-2] * model.tokenizer.spatial_compression_factor,
        state_shape[-1] * model.tokenizer.spatial_compression_factor,
    )

    input_path_format = input_image_or_video_path.split(".")[-1]
    input_frames = read_video_or_image_into_frames_BCTHW(
        input_image_or_video_path,
        input_path_format=input_path_format,
        H=H,
        W=W,
    )

    condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_input_frames)
    condition_latent = condition_latent.to(torch.bfloat16)

    return condition_latent


def check_input_frames(input_path: str, required_frames: int) -> bool:
    """Check if input video/image has sufficient frames.

    Args:
        input_path: Path to input video or image
        required_frames: Number of required frames

    Returns:
        bool: True if the input has enough frames, False otherwise
    """
    if input_path.endswith((".jpg", ".jpeg", ".png")):
        if required_frames > 1:
            log.error(f"Input ({input_path}) is an image but {required_frames} frames are required")
            return False
        return True
    # Video input: count frames with imageio/ffmpeg
    try:
        vid = imageio.get_reader(input_path, "ffmpeg")
        frame_count = vid.count_frames()

        if frame_count < required_frames:
            log.error(f"Input video has {frame_count} frames but {required_frames} frames are required")
            return False
        else:
            return True
    except Exception as e:
        log.error(f"Error reading video file {input_path}: {e}")
        return False
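

# End-to-end video2world conditioning sketch (hedged: it assumes a DiffusionV2WModel already set up
# via load_network_model / load_tokenizer_model, a data_batch/state_shape from get_video_batch, and
# an existing "input.mp4"; the frame count 9 is illustrative):
#
#     if check_input_frames("input.mp4", required_frames=9):
#         condition_latent = get_condition_latent(model, "input.mp4", num_input_frames=9)
#         sample = generate_world_from_video(
#             model, state_shape, is_negative_prompt=True, data_batch=data_batch,
#             guidance=7, num_steps=35, seed=1,
#             condition_latent=condition_latent, num_input_frames=9,
#         )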