import argparse
import importlib
from contextlib import contextmanager
from typing import List, NamedTuple, Optional, Tuple

import einops
import imageio
import numpy as np
import torch
import torchvision.transforms.functional as transforms_F

from .model_t2w import DiffusionT2WModel
from .model_v2w import DiffusionV2WModel
from .config_helper import get_config_module, override
from .utils_io import load_from_fileobj
from .log import log  # assumed sibling module; `log` is used throughout this file but was never imported
from .misc import misc

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 11):
    from torch.ao import quantization
    from torch.ao.quantization import FakeQuantizeBase, ObserverBase
elif (
    TORCH_VERSION >= (1, 8)
    and hasattr(torch.quantization, "FakeQuantizeBase")
    and hasattr(torch.quantization, "ObserverBase")
):
    from torch import quantization
    from torch.quantization import FakeQuantizeBase, ObserverBase

DEFAULT_AUGMENT_SIGMA = 0.001


def add_common_arguments(parser):
    """Add common command line arguments for text2world and video2world generation.

    Args:
        parser (ArgumentParser): Argument parser to add arguments to

    The arguments include:
        - checkpoint_dir: Base directory containing model weights
        - tokenizer_dir: Directory containing tokenizer weights
        - video_save_name: Output video filename for single video generation
        - video_save_folder: Output directory for batch video generation
        - prompt: Text prompt for single video generation
        - batch_input_path: Path to JSONL file with input prompts for batch video generation
        - negative_prompt: Text prompt describing undesired attributes
        - num_steps: Number of diffusion sampling steps
        - guidance: Classifier-free guidance scale
        - num_video_frames: Number of frames to generate
        - height/width: Output video dimensions
        - fps: Output video frame rate
        - seed: Random seed for reproducibility
        - Various model offloading flags
    """
    parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
    )
    parser.add_argument(
        "--tokenizer_dir",
        type=str,
        default="Cosmos-1.0-Tokenizer-CV8x8x8",
        help="Tokenizer weights directory relative to checkpoint_dir",
    )
    parser.add_argument(
        "--video_save_name",
        type=str,
        default="output",
        help="Output filename for generating a single video",
    )
    parser.add_argument(
        "--video_save_folder",
        type=str,
        default="outputs/",
        help="Output folder for generating a batch of videos",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        help="Text prompt for generating a single video",
    )
    parser.add_argument(
        "--batch_input_path",
        type=str,
        help="Path to a JSONL file of input prompts for generating a batch of videos",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
        "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
        "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
        "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special "
        "effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and "
        "flickering. Overall, the video is of poor quality.",
        help="Negative prompt for the video",
    )
    parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps")
    parser.add_argument("--guidance", type=float, default=7, help="Classifier-free guidance scale")
    parser.add_argument("--num_video_frames", type=int, default=121, help="Number of video frames to sample")
    parser.add_argument("--height", type=int, default=704, help="Height of video to sample")
    parser.add_argument("--width", type=int, default=1280, help="Width of video to sample")
    parser.add_argument("--fps", type=int, default=24, help="FPS of the sampled video")
    parser.add_argument("--seed", type=int, default=1, help="Random seed")
    parser.add_argument(
        "--disable_prompt_upsampler",
        action="store_true",
        help="Disable prompt upsampling",
    )
    parser.add_argument(
        "--offload_diffusion_transformer",
        action="store_true",
        help="Offload DiT after inference",
    )
    parser.add_argument(
        "--offload_tokenizer",
        action="store_true",
        help="Offload tokenizer after inference",
    )
    parser.add_argument(
        "--offload_text_encoder_model",
        action="store_true",
        help="Offload text encoder model after inference",
    )
    parser.add_argument(
        "--offload_prompt_upsampler",
        action="store_true",
        help="Offload prompt upsampler after inference",
    )
    parser.add_argument(
        "--offload_guardrail_models",
        action="store_true",
        help="Offload guardrail models after inference",
    )
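

# A minimal usage sketch (illustrative, not part of the module): wiring the
# common flags into a parser for a single text2world run.
#
#   parser = argparse.ArgumentParser(description="text2world generation")
#   add_common_arguments(parser)
#   args = parser.parse_args(["--prompt", "A robot arm stacks wooden blocks."])
#   validate_args(args, "text2world")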


def validate_args(args: argparse.Namespace, inference_type: str) -> None:
    """Validate command line arguments for text2world and video2world generation."""
    assert inference_type in [
        "text2world",
        "video2world",
    ], "Invalid inference_type, must be 'text2world' or 'video2world'"

    if inference_type == "text2world" or (inference_type == "video2world" and args.disable_prompt_upsampler):
        assert args.prompt or args.batch_input_path, "--prompt or --batch_input_path must be provided."
    if inference_type == "video2world" and not args.batch_input_path:
        assert (
            args.input_image_or_video_path
        ), "--input_image_or_video_path must be provided for single video generation."


class _IncompatibleKeys(
    NamedTuple(
        "IncompatibleKeys",
        [
            ("missing_keys", List[str]),
            ("unexpected_keys", List[str]),
            ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
        ],
    )
):
    """Result of a non-strict checkpoint load; see `non_strict_load_model`."""


def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
    """Load a model checkpoint with non-strict matching, handling shape mismatches.

    Args:
        model (torch.nn.Module): Model to load weights into
        checkpoint_state_dict (dict): State dict from checkpoint

    Returns:
        _IncompatibleKeys: Named tuple containing:
            - missing_keys: Keys present in model but missing from checkpoint
            - unexpected_keys: Keys present in checkpoint but not in model
            - incorrect_shapes: Keys with mismatched tensor shapes

    The function handles special cases like:
        - Uninitialized parameters
        - Quantization observers
        - TransformerEngine FP8 states
    """
    model_state_dict = model.state_dict()
    incorrect_shapes = []
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            if "_extra_state" in k:
                log.debug(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.")
                continue
            model_param = model_state_dict[k]
            # Allow mismatch for uninitialized parameters
            if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter):
                continue
            if not isinstance(model_param, torch.Tensor):
                raise ValueError(
                    f"Found non-tensor parameter {k} in the model "
                    f"(model type: {type(model_param)}, checkpoint type: {type(checkpoint_state_dict[k])}); "
                    "please check whether this key is safe to skip."
                )

            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                has_observer_base_classes = (
                    TORCH_VERSION >= (1, 8)
                    and hasattr(quantization, "ObserverBase")
                    and hasattr(quantization, "FakeQuantizeBase")
                )
                if has_observer_base_classes:

                    def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
                        # Drop the trailing parameter name and walk down to the owning module.
                        key_parts = key.split(".")[:-1]
                        cur_module = model
                        for key_part in key_parts:
                            cur_module = getattr(cur_module, key_part)
                        return cur_module

                    cls_to_skip = (
                        ObserverBase,
                        FakeQuantizeBase,
                    )
                    target_module = _get_module_for_key(model, k)
                    if isinstance(target_module, cls_to_skip):
                        # Quantization observers are allowed to change shape at load time.
                        continue

                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)
    incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
    # TransformerEngine FP8 "_extra_state" keys are not real weights; ignore them.
    missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k]
    unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k]
    return _IncompatibleKeys(
        missing_keys=missing_keys,
        unexpected_keys=unexpected_keys,
        incorrect_shapes=incorrect_shapes,
    )
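

# Illustrative usage (hypothetical module and path): report what did not line
# up after a partial load, e.g. when fine-tuning from a checkpoint with a
# different head.
#
#   state_dict = torch.load("checkpoints/model.pt", map_location="cpu")
#   result = non_strict_load_model(my_module, state_dict)
#   log.debug(f"missing: {result.missing_keys}, unexpected: {result.unexpected_keys}")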


@contextmanager
def skip_init_linear():
    """Temporarily disable `torch.nn.Linear` weight initialization.

    Useful to skip redundant random init when the weights are loaded from a
    checkpoint immediately afterwards.
    """
    orig_reset_parameters = torch.nn.Linear.reset_parameters
    torch.nn.Linear.reset_parameters = lambda x: x
    xavier_uniform_ = torch.nn.init.xavier_uniform_
    torch.nn.init.xavier_uniform_ = lambda x: x
    try:
        yield
    finally:
        # Restore the original initializers even if the body raises.
        torch.nn.Linear.reset_parameters = orig_reset_parameters
        torch.nn.init.xavier_uniform_ = xavier_uniform_
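

# Sketch: constructing a layer inside the context skips its (re)initialization,
# so the subsequent checkpoint load is the only write to the weights.
#
#   with skip_init_linear():
#       layer = torch.nn.Linear(4096, 4096)  # weights left uninitialized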


def load_model_by_config(
    config_job_name,
    config_file="projects/cosmos_video/config/config.py",
    model_class=DiffusionT2WModel,
):
    config_module = get_config_module(config_file)
    config = importlib.import_module(config_module).make_config()

    config = override(config, ["--", f"experiment={config_job_name}"])

    # Check that the config is valid, then freeze it to prevent subsequent modification.
    config.validate()
    config.freeze()

    # Initialize the model without random weights; they are loaded from a checkpoint later.
    with skip_init_linear():
        model = model_class(config.model)
    return model


def load_network_model(model: DiffusionT2WModel, ckpt_path: str):
    with skip_init_linear():
        model.set_up_model()
    net_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    log.debug(non_strict_load_model(model.model, net_state_dict))
    model.cuda()


def load_tokenizer_model(model: DiffusionT2WModel, tokenizer_dir: str):
    with skip_init_linear():
        model.set_up_tokenizer(tokenizer_dir)
    model.cuda()
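

# End-to-end loading sketch; the experiment name and checkpoint paths are
# illustrative placeholders, not guaranteed to exist.
#
#   model = load_model_by_config(
#       config_job_name="Cosmos_1_0_Diffusion_Text2World_7B",
#       model_class=DiffusionT2WModel,
#   )
#   load_network_model(model, "checkpoints/Cosmos-1.0-Diffusion-7B-Text2World/model.pt")
#   load_tokenizer_model(model, "checkpoints/Cosmos-1.0-Tokenizer-CV8x8x8")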


def prepare_data_batch(
    height: int,
    width: int,
    num_frames: int,
    fps: int,
    prompt_embedding: torch.Tensor,
    negative_prompt_embedding: Optional[torch.Tensor] = None,
):
    """Prepare input batch tensors for video generation.

    Args:
        height (int): Height of video frames
        width (int): Width of video frames
        num_frames (int): Number of frames to generate
        fps (int): Frames per second
        prompt_embedding (torch.Tensor): Encoded text prompt embeddings
        negative_prompt_embedding (torch.Tensor, optional): Encoded negative prompt embeddings

    Returns:
        dict: Batch dictionary containing:
            - video: Zero tensor of target video shape
            - t5_text_mask: Attention mask for text embeddings
            - image_size: Target frame dimensions
            - fps: Target frame rate
            - num_frames: Number of frames
            - padding_mask: Frame padding mask
            - t5_text_embeddings: Prompt embeddings
            - neg_t5_text_embeddings: Negative prompt embeddings (if provided)
            - neg_t5_text_mask: Mask for negative embeddings (if provided)
    """
    # Create the base data batch
    data_batch = {
        "video": torch.zeros((1, 3, num_frames, height, width), dtype=torch.uint8).cuda(),
        "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(),
        "image_size": torch.tensor([[height, width, height, width]] * 1, dtype=torch.bfloat16).cuda(),
        "fps": torch.tensor([fps] * 1, dtype=torch.bfloat16).cuda(),
        "num_frames": torch.tensor([num_frames] * 1, dtype=torch.bfloat16).cuda(),
        "padding_mask": torch.zeros((1, 1, height, width), dtype=torch.bfloat16).cuda(),
    }

    # Attach the text embeddings
    t5_embed = prompt_embedding.to(dtype=torch.bfloat16).cuda()
    data_batch["t5_text_embeddings"] = t5_embed

    if negative_prompt_embedding is not None:
        neg_t5_embed = negative_prompt_embedding.to(dtype=torch.bfloat16).cuda()
        data_batch["neg_t5_text_embeddings"] = neg_t5_embed
        data_batch["neg_t5_text_mask"] = torch.ones(1, 512, dtype=torch.bfloat16).cuda()

    return data_batch
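

# Sketch: building a batch from precomputed T5 embeddings. The 512-token mask
# above implies embeddings shaped [1, 512, D]; D=1024 is an assumption here.
#
#   prompt_embed = torch.randn(1, 512, 1024)  # stand-in for a real T5 encoding
#   batch = prepare_data_batch(
#       height=704, width=1280, num_frames=121, fps=24, prompt_embedding=prompt_embed
#   )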


def get_video_batch(model, prompt_embedding, negative_prompt_embedding, height, width, fps, num_video_frames):
    """Prepare complete input batch for video generation including latent dimensions.

    Args:
        model: Diffusion model instance
        prompt_embedding (torch.Tensor): Text prompt embeddings
        negative_prompt_embedding (torch.Tensor): Negative prompt embeddings
        height (int): Output video height
        width (int): Output video width
        fps (int): Output video frame rate
        num_video_frames (int): Number of frames to generate

    Returns:
        tuple:
            - data_batch (dict): Complete model input batch
            - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression
    """
    raw_video_batch = prepare_data_batch(
        height=height,
        width=width,
        num_frames=num_video_frames,
        fps=fps,
        prompt_embedding=prompt_embedding,
        negative_prompt_embedding=negative_prompt_embedding,
    )
    state_shape = [
        model.tokenizer.channel,
        model.tokenizer.get_latent_num_frames(num_video_frames),
        height // model.tokenizer.spatial_compression_factor,
        width // model.tokenizer.spatial_compression_factor,
    ]
    return raw_video_batch, state_shape


def generate_world_from_text(
    model: DiffusionT2WModel,
    state_shape: list[int],
    is_negative_prompt: bool,
    data_batch: dict,
    guidance: float,
    num_steps: int,
    seed: int,
):
    """Generate video from text prompt using diffusion model.

    Args:
        model (DiffusionT2WModel): Text-to-video diffusion model
        state_shape (list[int]): Latent state dimensions [C,T,H,W]
        is_negative_prompt (bool): Whether negative prompt is provided
        data_batch (dict): Model input batch with embeddings
        guidance (float): Classifier-free guidance scale
        num_steps (int): Number of diffusion sampling steps
        seed (int): Random seed for reproducibility

    Returns:
        torch.Tensor: Generated latent sample of shape [1, C, T, H, W]; the
            caller decodes it to pixel-space video frames.

    The function:
    1. Initializes a random latent at maximum noise level
    2. Performs guided diffusion sampling
    3. Returns the sampled latents for decoding
    """
    x_sigma_max = (
        misc.arch_invariant_rand(
            (1,) + tuple(state_shape),
            torch.float32,
            model.tensor_kwargs["device"],
            seed,
        )
        * model.sde.sigma_max
    )

    # Generate the video latents
    sample = model.generate_samples_from_batch(
        data_batch,
        guidance=guidance,
        state_shape=state_shape,
        num_steps=num_steps,
        is_negative_prompt=is_negative_prompt,
        seed=seed,
        x_sigma_max=x_sigma_max,
    )

    return sample
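

# Sketch of the text2world sampling path (model and embeddings assumed loaded
# as in the earlier sketches):
#
#   data_batch, state_shape = get_video_batch(
#       model, prompt_embedding, negative_prompt_embedding,
#       height=704, width=1280, fps=24, num_video_frames=121,
#   )
#   latents = generate_world_from_text(
#       model, state_shape, is_negative_prompt=True, data_batch=data_batch,
#       guidance=7, num_steps=35, seed=1,
#   )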


def generate_world_from_video(
    model: DiffusionV2WModel,
    state_shape: list[int],
    is_negative_prompt: bool,
    data_batch: dict,
    guidance: float,
    num_steps: int,
    seed: int,
    condition_latent: torch.Tensor,
    num_input_frames: int,
) -> torch.Tensor:
    """Generate video using a conditioning video/image input.

    Args:
        model (DiffusionV2WModel): The diffusion model instance
        state_shape (list[int]): Shape of the latent state [C,T,H,W]
        is_negative_prompt (bool): Whether negative prompt is provided
        data_batch (dict): Batch containing model inputs including text embeddings
        guidance (float): Classifier-free guidance scale for sampling
        num_steps (int): Number of diffusion sampling steps
        seed (int): Random seed for generation
        condition_latent (torch.Tensor): Latent tensor from conditioning video/image file
        num_input_frames (int): Number of input frames

    Returns:
        torch.Tensor: Generated latent sample of shape [1, C, T, H, W]; the
            caller decodes it to pixel-space video frames.
    """
    assert not model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i, "not supported"
    augment_sigma = DEFAULT_AUGMENT_SIGMA

    if condition_latent.shape[2] < state_shape[1]:
        # Pad the condition latent along the temporal axis up to the state shape
        b, c, t, h, w = condition_latent.shape
        condition_latent = torch.cat(
            [
                condition_latent,
                condition_latent.new_zeros(b, c, state_shape[1] - t, h, w),
            ],
            dim=2,
        ).contiguous()
    num_of_latent_condition = compute_num_latent_frames(model, num_input_frames)

    x_sigma_max = (
        misc.arch_invariant_rand(
            (1,) + tuple(state_shape),
            torch.float32,
            model.tensor_kwargs["device"],
            seed,
        )
        * model.sde.sigma_max
    )

    sample = model.generate_samples_from_batch(
        data_batch,
        guidance=guidance,
        state_shape=state_shape,
        num_steps=num_steps,
        is_negative_prompt=is_negative_prompt,
        seed=seed,
        condition_latent=condition_latent,
        num_condition_t=num_of_latent_condition,
        condition_video_augment_sigma_in_inference=augment_sigma,
        x_sigma_max=x_sigma_max,
    )
    return sample


def read_video_or_image_into_frames_BCTHW(
    input_path: str,
    input_path_format: str = "mp4",
    H: Optional[int] = None,
    W: Optional[int] = None,
    normalize: bool = True,
    max_frames: int = -1,
    also_return_fps: bool = False,
) -> torch.Tensor:
    """Read video or image file and convert to tensor format.

    Args:
        input_path (str): Path to input video/image file
        input_path_format (str): Format of input file (default: "mp4")
        H (int, optional): Height to resize frames to
        W (int, optional): Width to resize frames to
        normalize (bool): Whether to normalize pixel values to [-1,1] (default: True)
        max_frames (int): Maximum number of frames to read (-1 for all frames)
        also_return_fps (bool): Whether to return fps along with frames

    Returns:
        torch.Tensor | tuple: Video tensor in shape [B,C,T,H,W], optionally with fps if requested
    """
    log.debug(f"Reading video from {input_path}")

    loaded_data = load_from_fileobj(input_path, format=input_path_format)
    frames, meta_data = loaded_data
    if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"):
        frames = np.array(frames[0])
        if frames.shape[-1] > 3:
            # RGBA input: composite onto a white background using the alpha mask
            rgb_channels = frames[..., :3]
            alpha_channel = frames[..., 3] / 255.0

            # White background with the same shape as the RGB channels
            white_bg = np.ones_like(rgb_channels) * 255

            # Alpha-blend RGB over the white background
            frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype(
                np.uint8
            )
        frames = [frames]
        fps = 0
    else:
        fps = int(meta_data.get("fps"))
    if max_frames != -1:
        frames = frames[:max_frames]
    input_tensor = np.stack(frames, axis=0)
    input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w")
    if normalize:
        input_tensor = input_tensor / 128.0 - 1.0
        input_tensor = torch.from_numpy(input_tensor).bfloat16()  # TCHW
        log.debug(f"Raw data shape: {input_tensor.shape}")
        if H is not None and W is not None:
            input_tensor = transforms_F.resize(
                input_tensor,
                size=(H, W),
                interpolation=transforms_F.InterpolationMode.BICUBIC,
                antialias=True,
            )
    input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1)
    if normalize:
        input_tensor = input_tensor.to("cuda")
    log.debug(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}")
    if also_return_fps:
        return input_tensor, fps
    return input_tensor
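

# Sketch: reading a conditioning clip at the model's pixel resolution.
# "input.mp4" is a placeholder path.
#
#   frames_BCTHW, fps = read_video_or_image_into_frames_BCTHW(
#       "input.mp4", H=704, W=1280, also_return_fps=True
#   )  # -> [1, 3, T, 704, 1280], values in [-1, 1] on the GPU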


def compute_num_latent_frames(model: DiffusionV2WModel, num_input_frames: int, downsample_factor=8) -> int:
    """Compute the number of latent frames corresponding to a number of input (pixel) frames.

    Args:
        model (DiffusionV2WModel): video generation model
        num_input_frames (int): number of input frames
        downsample_factor (int): temporal downsample factor of the tokenizer
    Returns:
        int: number of latent frames
    """
    # Full pixel chunks map to full latent chunks...
    num_latent_frames = (
        num_input_frames
        // model.tokenizer.video_vae.pixel_chunk_duration
        * model.tokenizer.video_vae.latent_chunk_duration
    )
    # ...and the remaining pixel frames of a partial chunk map to
    # 1 + (remainder - 1) / downsample_factor latent frames.
    remainder = num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration
    if remainder == 1:
        num_latent_frames += 1
    elif remainder > 1:
        assert (
            remainder - 1
        ) % downsample_factor == 0, f"num_input_frames % pixel_chunk_duration - 1 must be divisible by {downsample_factor}"
        num_latent_frames += 1 + (remainder - 1) // downsample_factor

    return num_latent_frames
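

# Worked example (assuming the CV8x8x8 tokenizer: pixel_chunk_duration=121,
# latent_chunk_duration=16, downsample_factor=8):
#   num_input_frames = 9  -> 9 // 121 * 16 = 0 full-chunk latent frames,
#   remainder = 9          -> 1 + (9 - 1) // 8 = 2 extra -> 2 latent frames total.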


def create_condition_latent_from_input_frames(
    model: DiffusionV2WModel,
    input_frames: torch.Tensor,
    num_frames_condition: int = 25,
):
    """Create condition latent for video generation from input frames.

    Takes the last num_frames_condition frames from input as conditioning.

    Args:
        model (DiffusionV2WModel): Video generation model
        input_frames (torch.Tensor): Input video tensor [B,C,T,H,W], range [-1,1]
        num_frames_condition (int): Number of frames to use for conditioning

    Returns:
        tuple: (condition_latent, encode_input_frames) where:
            - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W]
            - encode_input_frames (torch.Tensor): Padded input frames used for encoding
    """
    B, C, T, H, W = input_frames.shape
    num_frames_encode = model.tokenizer.pixel_chunk_duration  # frames the VAE encodes per chunk
    log.debug(
        f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}"
    )

    log.debug(
        f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}"
    )

    assert (
        input_frames.shape[2] >= num_frames_condition
    ), f"input_frames not enough for condition, require at least {num_frames_condition}, got {input_frames.shape[2]}, {input_frames.shape}"
    assert (
        num_frames_encode >= num_frames_condition
    ), f"num_frames_encode should be larger than num_frames_condition, got {num_frames_encode}, {num_frames_condition}"

    # Put the conditioning frames at the start of the chunk and zero-pad the rest.
    condition_frames = input_frames[:, :, -num_frames_condition:]
    padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W)
    encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2)

    log.debug(
        f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end"
    )
    latent = model.encode(encode_input_frames)
    return latent, encode_input_frames


def get_condition_latent(
    model: DiffusionV2WModel,
    input_image_or_video_path: str,
    num_input_frames: int = 1,
    state_shape: Optional[list[int]] = None,
):
    """Get condition latent from input image/video file.

    Args:
        model (DiffusionV2WModel): Video generation model
        input_image_or_video_path (str): Path to conditioning image/video
        num_input_frames (int): Number of input frames for video2world prediction
        state_shape (list[int], optional): Latent state shape [C,T,H,W]; defaults to model.state_shape

    Returns:
        torch.Tensor: Encoded latent condition [B,C,T,H,W] in bfloat16
    """
    if state_shape is None:
        state_shape = model.state_shape
    assert num_input_frames > 0, "num_input_frames must be greater than 0"

    H, W = (
        state_shape[-2] * model.tokenizer.spatial_compression_factor,
        state_shape[-1] * model.tokenizer.spatial_compression_factor,
    )

    input_path_format = input_image_or_video_path.split(".")[-1]
    input_frames = read_video_or_image_into_frames_BCTHW(
        input_image_or_video_path,
        input_path_format=input_path_format,
        H=H,
        W=W,
    )

    condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_input_frames)
    condition_latent = condition_latent.to(torch.bfloat16)

    return condition_latent
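

# Sketch of the video2world conditioning path ("input.mp4" is a placeholder;
# data_batch and state_shape come from get_video_batch as sketched earlier):
#
#   condition_latent = get_condition_latent(model, "input.mp4", num_input_frames=9)
#   latents = generate_world_from_video(
#       model, state_shape, is_negative_prompt=True, data_batch=data_batch,
#       guidance=7, num_steps=35, seed=1,
#       condition_latent=condition_latent, num_input_frames=9,
#   )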


def check_input_frames(input_path: str, required_frames: int) -> bool:
    """Check if input video/image has sufficient frames.

    Args:
        input_path: Path to input video or image
        required_frames: Number of required frames

    Returns:
        bool: True if the input has enough frames, False otherwise
    """
    if input_path.endswith((".jpg", ".jpeg", ".png")):
        if required_frames > 1:
            log.error(f"Input ({input_path}) is an image but {required_frames} frames are required")
            return False
        return True

    try:
        vid = imageio.get_reader(input_path, "ffmpeg")
        frame_count = vid.count_frames()

        if frame_count < required_frames:
            log.error(f"Input video has {frame_count} frames but {required_frames} frames are required")
            return False
        else:
            return True
    except Exception as e:
        log.error(f"Error reading video file {input_path}: {e}")
        return False
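

# Sketch ("input.mp4" is a placeholder): gate a video2world run on frame count.
#
#   if check_input_frames("input.mp4", required_frames=9):
#       condition_latent = get_condition_latent(model, "input.mp4", num_input_frames=9)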