| |
| |
|
|
| """Video generation inference script. |
| |
| Supports: |
| - Text-to-video (T2V): Wan 2.1/2.2 |
| - Image-to-video (I2V): Wan I2V |
| - Video-to-video (V2V): VACE Wan, Self-Forcing |
| - Video2World: Cosmos Predict2 |
| |
| Examples: |
| |
| # T2V: eval teacher only (Wan) |
| PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ |
| scripts/inference/video_model_inference.py --do_student_sampling False \\ |
| --config fastgen/configs/experiments/WanT2V/config_dmd2.py \\ |
| - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 log_config.name=wan_t2v_inference |
| |
| # I2V: image-to-video (Wan I2V) |
| PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ |
| scripts/inference/video_model_inference.py --do_student_sampling False \\ |
| --input_image_file scripts/inference/prompts/source_image_paths.txt \\ |
| --config fastgen/configs/experiments/WanI2V/config_dmd2_14b.py \\ |
| - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 log_config.name=wan_i2v_inference |
| |
| # V2V: video-to-video with VACE |
| PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ |
| scripts/inference/video_model_inference.py --do_student_sampling False \\ |
| --source_video_file scripts/inference/prompts/source_video_paths.txt \\ |
| --config fastgen/configs/experiments/WanV2V/config_sft_latent.py \\ |
| - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 log_config.name=vace_wan_inference |
| |
| # Video2World: Cosmos Predict2 |
| PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ |
| scripts/inference/video_model_inference.py --do_student_sampling False \\ |
| --input_image_file scripts/inference/prompts/source_image_paths.txt --num_conditioning_frames 1 \\ |
| --config fastgen/configs/experiments/CosmosPredict2/config_sft.py \\ |
| - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 model.net.is_video2world=True \\ |
| log_config.name=cosmos_v2w_inference |
| |
| # Eval with skip-layer guidance (SLG) |
| PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ |
| scripts/inference/video_model_inference.py --do_student_sampling False \\ |
| --config fastgen/configs/experiments/WanT2V/config_dmd2.py \\ |
| - trainer.seed=1 trainer.ddp=True model.guidance_scale=6.0 model.skip_layers=[10] \\ |
| log_config.name=wan_slg_inference |
| |
| # Eval student and teacher together |
| PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ |
| scripts/inference/video_model_inference.py --ckpt_path /path/to/checkpoint.pth \\ |
| --do_student_sampling True --do_teacher_sampling True \\ |
| --config fastgen/configs/experiments/WanT2V/config_dmd2.py \\ |
| - trainer.seed=1 trainer.ddp=True log_config.name=wan_student_teacher_inference |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import gc |
| import time |
| from pathlib import Path |
| from typing import TYPE_CHECKING, Optional, Sequence |
|
|
| import imageio.v3 as iio |
| import numpy as np |
| import torch |
| from tqdm.auto import tqdm |
|
|
| from fastgen.configs.config import BaseConfig |
| from fastgen.networks.WanI2V import WanI2V |
| from fastgen.networks.cosmos_predict2 import CosmosPredict2 |
| import fastgen.utils.logging_utils as logger |
| from fastgen.utils.distributed import clean_up, is_rank0, world_size |
| from fastgen.utils import basic_utils |
| from fastgen.utils.scripts import parse_args, setup |
| from fastgen.third_party.wan_prompt_expand.prompt_expand import QwenPromptExpander |
| from fastgen.datasets.wds_dataloaders import transform_video |
| from scripts.inference.inference_utils import ( |
| expand_path, |
| load_prompts, |
| init_model, |
| init_checkpointer, |
| load_checkpoint, |
| cleanup_unused_modules, |
| setup_inference_modules, |
| add_common_args, |
| ) |
|
|
| if TYPE_CHECKING: |
| from fastgen.methods import FastGenModel |
|
|
|
|
def load_video_frames(video_path: str, num_frames: int, height: int, width: int) -> Optional[torch.Tensor]:
    """Decode a video file and return frames as a [1, C, T, H, W] tensor in [-1, 1].

    Spatial preprocessing is delegated to ``transform_video`` so it matches the
    dataset pipeline. Short videos are padded by repeating the last frame; long
    videos are cropped to a centered window of ``num_frames`` frames.

    Returns:
        Batched frame tensor, or None when the video cannot be decoded.
    """
    try:
        raw = iio.imread(video_path, plugin="pyav")
    except Exception as e:
        logger.error(f"Failed to read video file: {video_path} with error {e}")
        return None

    if raw is None or len(raw) == 0:
        logger.error(f"No frames decoded from video file: {video_path}")
        return None

    total = len(raw)
    if total < num_frames:
        # Too short: repeat the final frame until we reach num_frames.
        tail = np.repeat(raw[-1:], num_frames - total, axis=0)
        raw = np.concatenate([raw, tail], axis=0)
    else:
        # Long enough: take a centered temporal window.
        offset = max(0, (total - num_frames) // 2)
        raw = raw[offset : offset + num_frames]

    out = transform_video(torch.from_numpy(raw), sequence_length=num_frames, img_size=(width, height))
    return out["real"].unsqueeze(0)
|
|
|
|
def load_conditioning_image(
    image_path: str, height: int, width: int, num_latent_frames: int = 1
) -> Optional[torch.Tensor]:
    """
    Load an image as conditioning frames for image-to-video generation.

    The image is replicated to create the pixel frames needed by the temporal
    VAE (which has 4x temporal compression): N latent frames require
    (N-1)*4 + 1 pixel frames. The formula already yields 1 for N == 1, so no
    special case is needed (the previous `if N > 1` branch was redundant).

    Args:
        image_path: Path to the input image.
        height: Target height in pixels.
        width: Target width in pixels.
        num_latent_frames: Number of latent frames to generate from the image (default 1).

    Returns:
        Tensor of shape [1, C, T, H, W] in [-1, 1] range, where T is the number of
        pixel frames needed for the requested latent frames. Returns None on failure.

    Raises:
        ValueError: If num_latent_frames is less than 1.
    """
    if num_latent_frames < 1:
        raise ValueError(f"num_latent_frames must be >= 1, got {num_latent_frames}")

    try:
        img_np = iio.imread(image_path)
    except Exception as e:
        logger.error(f"Failed to read image file: {image_path} with error {e}")
        return None

    if img_np is None:
        logger.error(f"Failed to load image: {image_path}")
        return None

    # Normalize to 3-channel RGB: grayscale -> replicate, RGBA -> drop alpha.
    if img_np.ndim == 2:
        img_np = np.stack([img_np] * 3, axis=-1)
    elif img_np.shape[-1] == 4:
        img_np = img_np[..., :3]

    # Pixel frames required by the 4x temporally-compressing VAE.
    num_pixel_frames = (num_latent_frames - 1) * 4 + 1

    # Replicate the still image along the temporal axis.
    frames_np = np.stack([img_np] * num_pixel_frames, axis=0)

    # Match the dataset pipeline's spatial preprocessing.
    frames_t = torch.from_numpy(frames_np)
    out = transform_video(frames_t, sequence_length=num_pixel_frames, img_size=(width, height))
    frames_tensor = out["real"]
    return frames_tensor.unsqueeze(0)
|
|
|
|
def prepare_wani2v_condition(
    conditioning_frames: torch.Tensor,
    conditioning_latents: torch.Tensor,
    condition: torch.Tensor,
    neg_condition: Optional[torch.Tensor],
    model: FastGenModel,
    vae: torch.nn.Module,
    t_latent: int,
    use_concat_mask: bool,
) -> tuple:
    """Build positive/negative condition dicts for WanI2V models.

    Args:
        conditioning_frames: Raw pixel frames [B, C, T, H, W] in [-1, 1].
        conditioning_latents: VAE-encoded latents of the conditioning frames.
        condition: Text embeddings for the positive prompt.
        neg_condition: Text embeddings for the negative prompt.
        model: Model instance (supplies precision/device settings).
        vae: VAE used to encode the pixel-space conditioning clip.
        t_latent: Total number of latent frames.
        use_concat_mask: True for the Wan 2.1 concat-mask scheme, False for
            the Wan 2.2 frame-replacement scheme.

    Returns:
        Tuple of (condition_dict, neg_condition_dict, i2v_tag).
    """
    if use_concat_mask:
        # Wan 2.1: VAE-encode a pixel clip whose first frame is the
        # conditioning image and whose remaining frames are zeros.
        num_pixel_frames = (t_latent - 1) * 4 + 1
        batch, channels, _, pix_h, pix_w = conditioning_frames.shape
        pixel_clip = torch.zeros(
            batch,
            channels,
            num_pixel_frames,
            pix_h,
            pix_w,
            device=conditioning_frames.device,
            dtype=conditioning_frames.dtype,
        )
        pixel_clip[:, :, 0] = conditioning_frames[:, :, 0]
        with basic_utils.inference_mode(vae, precision_amp=model.precision_amp_infer, device_type=model.device.type):
            first_frame_cond = vae.encode(pixel_clip)
            logger.info(f"Wan 2.1 I2V: created first_frame_cond via VAE, shape {first_frame_cond.shape}")
    else:
        # Wan 2.2: zero latent volume with the encoded conditioning latents
        # copied into the leading temporal slots (frame replacement).
        first_frame_cond = torch.zeros(
            1,
            conditioning_latents.shape[1],
            t_latent,
            conditioning_latents.shape[3],
            conditioning_latents.shape[4],
            device=conditioning_latents.device,
            dtype=conditioning_latents.dtype,
        )
        first_frame_cond[:, :, : conditioning_latents.shape[2]] = conditioning_latents

    condition_dict = {"text_embeds": condition, "first_frame_cond": first_frame_cond}
    neg_condition_dict = {"text_embeds": neg_condition, "first_frame_cond": first_frame_cond}

    # Some WanI2V variants also condition on image-encoder embeddings of the
    # first pixel frame; both branches share the same embedding.
    if hasattr(model.net, "image_encoder"):
        with basic_utils.inference_mode(
            model.net.image_encoder, precision_amp=model.precision_amp_infer, device_type=model.device.type
        ):
            img_embeds = model.net.image_encoder.encode(conditioning_frames[:, :, 0])
            img_embeds = img_embeds.to(device=model.device, dtype=model.precision)
            condition_dict["encoder_hidden_states_image"] = img_embeds
            neg_condition_dict["encoder_hidden_states_image"] = img_embeds

    return condition_dict, neg_condition_dict, "_i2v"
|
|
|
|
def prepare_cosmos_v2w_condition(
    conditioning_latents: torch.Tensor,
    condition: torch.Tensor,
    neg_condition: Optional[torch.Tensor],
    latent_shape: Sequence[int],
    num_conditioning_frames: int,
) -> tuple:
    """Build positive/negative condition dicts for CosmosPredict2 video2world.

    Args:
        conditioning_latents: VAE-encoded latents of the conditioning frames.
        condition: Text embeddings for the positive prompt.
        neg_condition: Text embeddings for the negative prompt.
        latent_shape: Latent tensor shape [C, T, H, W].
        num_conditioning_frames: Number of leading latent frames to condition on.

    Returns:
        Tuple of (condition_dict, neg_condition_dict, i2v_tag).
    """
    t_latent, h_latent, w_latent = latent_shape[1], latent_shape[2], latent_shape[3]

    # Binary temporal mask: 1 on the first `num_conditioning_frames` latent
    # frames, 0 elsewhere. new_zeros inherits device/dtype from the latents.
    mask = conditioning_latents.new_zeros(1, 1, t_latent, h_latent, w_latent)
    mask[:, :, :num_conditioning_frames] = 1.0

    shared = {"conditioning_latents": conditioning_latents, "condition_mask": mask}
    condition_dict = {"text_embeds": condition, **shared}
    neg_condition_dict = {"text_embeds": neg_condition, **shared}

    return condition_dict, neg_condition_dict, f"_v2w{num_conditioning_frames}"
|
|
|
|
def prepare_vacewan_condition(
    source_video_path: str,
    depth_latent_path: Optional[str],
    model: FastGenModel,
    latent_shape: Sequence[int],
    condition: torch.Tensor,
    neg_condition: Optional[torch.Tensor],
    ctx: dict,
) -> tuple:
    """Prepare condition dicts for VACE Wan models (depth-to-video).

    Args:
        source_video_path: Path to the source video for conditioning.
        depth_latent_path: Optional path to precomputed depth latents.
        model: The model instance.
        latent_shape: Shape of latent tensor [C, T, H, W].
        condition: Text embeddings for positive prompt.
        neg_condition: Text embeddings for negative prompt.
        ctx: Device/dtype context dict.

    Returns:
        Tuple of (condition_dict, neg_condition_dict).

    Raises:
        ValueError: If the source video cannot be decoded.
    """
    t_latent = latent_shape[1]
    # Pixel frames required by the 4x temporally-compressing VAE.
    target_frames = (t_latent - 1) * 4 + 1
    # Infer pixel resolution from the latent shape (48-channel latents use a
    # 16x spatial factor, otherwise 8x — same convention as prepare_i2v_condition).
    vae_spatial_factor = 16 if latent_shape[0] == 48 else 8
    height = latent_shape[2] * vae_spatial_factor
    width = latent_shape[3] * vae_spatial_factor

    video = load_video_frames(source_video_path, num_frames=target_frames, height=height, width=width)
    if video is None:
        # load_video_frames returns None on decode failure; fail with a clear
        # error instead of crashing on `None.to(...)` below.
        raise ValueError(f"Failed to load source video for conditioning: {source_video_path}")
    video = video.to(**ctx)

    if depth_latent_path is None:
        # Let the network derive the depth conditioning from the raw video.
        depth_latent = model.net.prepare_vid_conditioning(video=video, condition_latents=None)
    else:
        # Use precomputed depth latents, cropped to the target length and batched.
        depth_latent = torch.load(depth_latent_path)
        depth_latent = depth_latent[:, :t_latent]
        depth_latent = depth_latent.unsqueeze(0)
        depth_latent = depth_latent.to(**ctx)
        depth_latent = model.net.prepare_vid_conditioning(video=video, condition_latents=depth_latent)

    condition_dict = {"text_embeds": condition, "vid_context": depth_latent}
    neg_condition_dict = {"text_embeds": neg_condition, "vid_context": depth_latent}

    return condition_dict, neg_condition_dict
|
|
|
|
def prepare_i2v_condition(
    input_image_path: str,
    model: FastGenModel,
    vae: Optional[torch.nn.Module],
    latent_shape: Sequence[int],
    condition: torch.Tensor,
    neg_condition: Optional[torch.Tensor],
    num_conditioning_frames: int,
    ctx: dict,
) -> tuple:
    """Load an input image and build I2V/video2world conditioning from it.

    Args:
        input_image_path: Path to the input image.
        model: The model instance.
        vae: VAE used to encode the conditioning frames.
        latent_shape: Latent tensor shape [C, T, H, W].
        condition: Text embeddings for the positive prompt.
        neg_condition: Text embeddings for the negative prompt.
        num_conditioning_frames: Number of latent frames to condition on.
        ctx: Device/dtype context dict.

    Returns:
        Tuple of (condition, neg_condition, i2v_tag). On any loading/encoding
        failure the inputs are returned unchanged with an empty tag, so the
        caller falls back to unconditioned (T2V) sampling.
    """
    i2v_tag = ""

    # Missing path or nonexistent file: fall back to the text-only condition.
    if not input_image_path or not Path(input_image_path).exists():
        if input_image_path:
            logger.warning(f"Conditioning image not found: {input_image_path}")
        return condition, neg_condition, i2v_tag

    # Pixel resolution derived from latent shape (48-ch latents -> 16x, else 8x).
    spatial = 16 if latent_shape[0] == 48 else 8
    conditioning_frames = load_conditioning_image(
        input_image_path,
        height=latent_shape[2] * spatial,
        width=latent_shape[3] * spatial,
        num_latent_frames=num_conditioning_frames,
    )

    if conditioning_frames is None or vae is None:
        logger.warning(f"Failed to encode conditioning image: {input_image_path}")
        return condition, neg_condition, i2v_tag

    conditioning_frames = conditioning_frames.to(**ctx)
    with basic_utils.inference_mode(vae, precision_amp=model.precision_amp_infer, device_type=model.device.type):
        conditioning_latents = vae.encode(conditioning_frames)
        logger.info(f"I2V: encoded image to latents shape {conditioning_latents.shape}")

    net = model.net
    if getattr(net, "is_i2v", False):
        assert isinstance(net, WanI2V), f"Expected WanI2V model but got {type(net).__name__}"
        # Wan 2.1 concatenates a mask channel; Wan 2.2 replaces frames.
        return prepare_wani2v_condition(
            conditioning_frames=conditioning_frames,
            conditioning_latents=conditioning_latents,
            condition=condition,
            neg_condition=neg_condition,
            model=model,
            vae=vae,
            t_latent=latent_shape[1],
            use_concat_mask=getattr(net, "concat_mask", False),
        )
    if getattr(net, "is_video2world", False):
        assert isinstance(net, CosmosPredict2), f"Expected CosmosPredict2 model but got {type(net).__name__}"
        return prepare_cosmos_v2w_condition(
            conditioning_latents=conditioning_latents,
            condition=condition,
            neg_condition=neg_condition,
            latent_shape=latent_shape,
            num_conditioning_frames=num_conditioning_frames,
        )
    raise NotImplementedError(f"I2V mode not implemented for {type(net).__name__}")
|
|
|
|
def expand_prompts_with_qwen(
    prompts: list[str],
    model_name: str,
    device: torch.device,
    seed: int,
) -> list[str]:
    """Expand prompts with a Qwen model on rank 0 and broadcast to all ranks.

    Args:
        prompts: Prompts to expand (modified in place on rank 0).
        model_name: Qwen model name.
        device: Device to run the expander on.
        seed: Random seed for prompt expansion.

    Returns:
        The list of expanded prompts (identical on every rank after broadcast).
    """
    logger.info("Expanding prompts on rank 0 ...")
    if is_rank0():
        expander = QwenPromptExpander(model_name=model_name, is_vl=False, device=device)
        for idx, original in enumerate(tqdm(prompts)):
            logger.debug(f"Expanding prompt {original} with seed {seed}")
            # Reseed per prompt so expansion is reproducible independent of order.
            basic_utils.set_random_seed(seed)
            expanded = expander(original, tar_lang="en", seed=seed)
            logger.info(f"Expanded prompt: {expanded.prompt}")
            prompts[idx] = expanded.prompt
        # Free the expander's memory before the diffusion model is loaded.
        del expander
        gc.collect()
        torch.cuda.empty_cache()
    else:
        # Non-zero ranks provide placeholders to receive the broadcast into.
        prompts = [None] * len(prompts)

    if world_size() > 1:
        torch.distributed.broadcast_object_list(prompts, src=0)

    return prompts
|
|
|
|
def main(args, config: BaseConfig):
    """Run video-model inference over a prompt list.

    Loads prompts (optionally expanding them with Qwen), resolves optional
    conditioning inputs (depth latents, source videos, input images), restores
    the model checkpoint, and samples videos with the student and/or teacher.

    Fix: the negative-prompt warning previously fired after truncating the
    list to one element, so it always reported "Found 1 negative prompts" and
    warned even when exactly one prompt was given. It now reports the original
    count and only warns when more than one prompt was supplied.
    """
    pos_prompt_set = load_prompts(args.prompt_file, relative_to="cwd")

    # Optionally rewrite prompts with a Qwen prompt-expansion model.
    if args.prompt_expand_model:
        pos_prompt_set = expand_prompts_with_qwen(
            pos_prompt_set,
            args.prompt_expand_model,
            torch.device(config.model.device),
            args.prompt_expand_model_seed,
        )

    # Optional precomputed depth latents (V2V conditioning).
    depth_latent_paths = None
    if args.depth_latent_file is not None:
        depth_latent_path = expand_path(args.depth_latent_file, relative_to="cwd")
        if depth_latent_path.is_file():
            with depth_latent_path.open("r") as f:
                depth_latent_paths = [line.strip() for line in f.readlines()]
        else:
            raise FileNotFoundError(f"depth_latent_file: {depth_latent_path} not found!")

    # Optional source videos (V2V conditioning).
    source_video_paths = None
    if args.source_video_file is not None:
        source_video_path = expand_path(args.source_video_file, relative_to="cwd")
        if source_video_path.is_file():
            with source_video_path.open("r") as f:
                source_video_paths = [line.strip() for line in f.readlines()]
        else:
            raise FileNotFoundError(f"source_video_path: {source_video_path} not found!")

    # Optional input images (I2V / video2world), aligned 1:1 with prompts.
    input_image_paths = None
    if args.input_image_file is not None:
        input_image_file_path = expand_path(args.input_image_file, relative_to="cwd")
        if input_image_file_path.is_file():
            with input_image_file_path.open("r") as f:
                input_image_paths = [line.strip() for line in f.readlines() if line.strip()]

            # Pad with the last image / truncate so images match prompt count.
            num_prompts = len(pos_prompt_set)
            num_images = len(input_image_paths)
            if num_images < num_prompts:
                last_image = input_image_paths[-1] if input_image_paths else ""
                input_image_paths.extend([last_image] * (num_prompts - num_images))
                logger.info(f"I2V: extended {num_images} images to {num_prompts} by repeating last image")
            elif num_images > num_prompts:
                input_image_paths = input_image_paths[:num_prompts]
                logger.info(f"I2V: truncated {num_images} images to {num_prompts} prompts")

            logger.info(f"I2V mode: {len(input_image_paths)} input images for {num_prompts} prompts")
        else:
            raise FileNotFoundError(f"input_image_file_path: {input_image_file_path} not found!")

    # Seed varies by rank so different ranks draw different noise.
    seed = basic_utils.set_random_seed(config.trainer.seed, by_rank=True)

    model = init_model(config)
    checkpointer = init_checkpointer(config)

    ckpt_iter, save_dir = load_checkpoint(checkpointer, model, args.ckpt_path, config)

    if ckpt_iter is None and args.do_student_sampling:
        logger.warning(f"Performing {model.config.student_sample_steps}-step generation on the non-distilled model")

    if args.video_save_dir:
        save_dir = args.video_save_dir
        logger.info(f"video_save_dir: {save_dir}")

    save_dir = Path(save_dir)
    prompt_name = Path(args.prompt_file).stem
    if args.prompt_expand_model:
        prompt_name += f"_{args.prompt_expand_model}"
    save_dir = save_dir / prompt_name

    save_video_kwargs = {"precision_amp": model.precision_amp_infer, "save_as_gif": args.save_as_gif, "fps": args.fps}
    if args.save_high_quality:
        # High-quality export: CRF-18 MP4 written into a separate "_hq" directory.
        save_video_kwargs = {
            "quality": 18,
            "preset": "medium",
            "fps": args.fps,
        }
        save_dir = save_dir.parent / (save_dir.name + "_hq")

    # Drop modules that are unused for the requested sampling modes.
    cleanup_unused_modules(model, args.do_teacher_sampling)

    teacher, student, vae = setup_inference_modules(
        model, config, args.do_teacher_sampling, args.do_student_sampling, model.precision
    )
    ctx = {"dtype": model.precision, "device": model.device}

    has_teacher_sampling = teacher is not None and hasattr(teacher, "sample")
    has_student_sampling = student is not None and hasattr(model, "generator_fn")
    assert (
        has_teacher_sampling or has_student_sampling
    ), "At least one of teacher or student (with generator_fn) must be provided for sampling"

    # Encode the (single) negative prompt once; it is shared across all samples.
    neg_condition = None
    if args.neg_prompt_file is not None:
        neg_condition = load_prompts(args.neg_prompt_file, relative_to="cwd")
        # Warn with the original count, before truncating to one prompt.
        if len(neg_condition) > 1:
            logger.warning(f"Found {len(neg_condition)} negative prompts, only using the first one.")
        if len(neg_condition) > 0:
            neg_condition = neg_condition[:1]
            logger.debug(f"Loaded negative prompt: {neg_condition[0]}")
        if hasattr(model.net, "text_encoder"):
            with basic_utils.inference_mode(
                model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type
            ):
                neg_condition = basic_utils.to(model.net.text_encoder.encode(neg_condition), **ctx)

    slg_tag = "" if config.model.skip_layers is None else f"_slg{'_'.join([str(x) for x in config.model.skip_layers])}"

    # One shared noise tensor: every prompt starts from the same initial noise.
    noise = torch.randn(
        [1, *config.model.input_shape],
        **ctx,
    )

    for i, prompt in enumerate(pos_prompt_set):
        logger.info(f"[{i+1}/{len(pos_prompt_set)}] Generating: {prompt[:80]}...")

        # Encode the positive prompt.
        condition = [prompt]
        if hasattr(model.net, "text_encoder"):
            with basic_utils.inference_mode(
                model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type
            ):
                condition = basic_utils.to(model.net.text_encoder.encode(condition), **ctx)

        # V2V conditioning (VACE): augment the condition with video context.
        is_net_v2v = hasattr(model.net, "prepare_vid_conditioning")
        if source_video_paths is not None and i < len(source_video_paths) and is_net_v2v:
            depth_latent_path = depth_latent_paths[i] if depth_latent_paths is not None else None
            condition, neg_condition_sample = prepare_vacewan_condition(
                source_video_path=source_video_paths[i],
                depth_latent_path=depth_latent_path,
                model=model,
                latent_shape=config.model.input_shape,
                condition=condition,
                neg_condition=neg_condition,
                ctx=ctx,
            )
        else:
            neg_condition_sample = neg_condition

        # I2V / video2world conditioning from an input image.
        i2v_tag = ""
        is_net_i2v = getattr(model.net, "is_i2v", False) or getattr(model.net, "is_video2world", False)
        if input_image_paths is not None and i < len(input_image_paths) and is_net_i2v:
            condition, neg_condition_sample, i2v_tag = prepare_i2v_condition(
                input_image_path=input_image_paths[i],
                model=model,
                vae=vae,
                latent_shape=config.model.input_shape,
                condition=condition,
                neg_condition=neg_condition,
                num_conditioning_frames=args.num_conditioning_frames,
                ctx=ctx,
            )

        if has_student_sampling:
            use_extrapolation = args.num_segments != 1 or args.overlap_frames != 0
            start_time = time.time()

            student_kwargs = {
                "condition": condition,
                "neg_condition": neg_condition_sample,
                "student_sample_steps": model.config.student_sample_steps,
                "student_sample_type": model.config.student_sample_type,
                "t_list": model.config.sample_t_cfg.t_list,
                "precision_amp": model.precision_amp_infer,
            }

            if use_extrapolation:
                # Autoregressive extrapolation over multiple latent segments.
                if not hasattr(model, "generator_fn_extrapolation"):
                    raise RuntimeError("Extrapolation is only supported for causal autoregressive networks")
                if not hasattr(model.net, "vae"):
                    raise RuntimeError("VAE is required for extrapolation but was not initialized")
                student_kwargs["num_segments"] = args.num_segments
                student_kwargs["overlap_frames"] = args.overlap_frames
                video_student = model.generator_fn_extrapolation(student, noise, **student_kwargs)
            else:
                video_student = model.generator_fn(student, noise, **student_kwargs)

            sampling_time = time.time() - start_time
            logger.info(f"Student sampling time: {sampling_time:.2f}s")
            save_path = save_dir / f"student_step{model.config.student_sample_steps}{i2v_tag}_{i:04d}_seed{seed}.mp4"
            basic_utils.save_media(video_student, str(save_path), vae=vae, **save_video_kwargs)

        if has_teacher_sampling:
            start_time = time.time()
            teacher_kwargs = {
                "condition": condition,
                "neg_condition": neg_condition_sample,
                "num_steps": args.num_steps,
                "second_order": False,
                "precision_amp": model.precision_amp_infer,
                "fps": torch.full((noise.shape[0],), float(args.fps), device=noise.device),
            }
            if config.model.skip_layers is not None:
                teacher_kwargs["skip_layers"] = config.model.skip_layers

            # NOTE(review): the guard above checks teacher.sample but sampling
            # goes through model.sample(teacher, ...) — confirm this is intended.
            video_teacher = model.sample(teacher, noise, **teacher_kwargs)
            sampling_time = time.time() - start_time
            logger.info(f"Teacher sampling time: {sampling_time:.2f}s")
            save_path = (
                save_dir
                / f"teacher_cfg{config.model.guidance_scale}_steps{args.num_steps}{slg_tag}{i2v_tag}_{i:04d}_seed{seed}.mp4"
            )
            basic_utils.save_media(video_teacher, str(save_path), vae=vae, **save_video_kwargs)
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Video model inference")

    # Shared flags (checkpoint path, student/teacher sampling toggles, config)
    # are registered by the inference_utils helper; only video-specific options
    # are added below.
    add_common_args(parser)

    # --- Output format options ---
    parser.add_argument(
        "--save_as_gif",
        default=True,
        type=basic_utils.str2bool,
        help="Whether to save videos as GIF (True) or MP4 (False)",
    )
    parser.add_argument(
        "--fps",
        default=16,
        type=int,
        help="Frames per second for saved video and model temporal encoding (default: 16, matches Wan base_fps)",
    )
    parser.add_argument(
        "--save_high_quality",
        default=False,
        type=basic_utils.str2bool,
        help="Whether to save videos in high-quality (codec: libx265 vs libx264)",
    )
    # --- Prompt options ---
    parser.add_argument(
        "--prompt_file",
        default="scripts/inference/prompts/validation_aug_qwen_2_5_14b_seed42.txt",
        type=str,
        help="File containing prompts (one per line). Relative paths are resolved from script directory.",
    )
    parser.add_argument(
        "--neg_prompt_file",
        default="scripts/inference/prompts/negative_prompt.txt",
        type=str,
        help="The file containing the negative prompt to use for CFG.",
    )
    parser.add_argument(
        "--prompt_expand_model",
        type=str,
        help="If specified, perform prompt expansion using the specified Qwen model.",
        choices=["QwenVL2.5_3B", "QwenVL2.5_7B", "Qwen2.5_3B", "Qwen2.5_7B", "Qwen2.5_14B"],
    )
    parser.add_argument(
        "--prompt_expand_model_seed",
        type=int,
        help="Seed for prompt expansion.",
        default=0,
    )
    # --- V2V (VACE) conditioning options ---
    parser.add_argument(
        "--depth_latent_file",
        default=None,
        type=str,
        help="The file containing the depth latent paths to use for sampling.",
    )
    parser.add_argument(
        "--source_video_file",
        default="scripts/inference/prompts/source_video_paths.txt",
        type=str,
        help="The file containing the source video paths to use for sampling.",
    )
    # --- Autoregressive extrapolation options (student sampling only) ---
    parser.add_argument(
        "--num_segments",
        type=int,
        default=1,
        help="Number of autoregressive segments to generate when using extrapolation (default: 1)",
    )
    parser.add_argument(
        "--overlap_frames",
        type=int,
        default=0,
        help="Number of latent frames to overlap between segments when extrapolating (default: 0)",
    )
    # --- Output location / teacher sampling options ---
    parser.add_argument(
        "--video_save_dir",
        type=str,
        help="Path to the video save directory.",
        default=None,
    )
    parser.add_argument(
        "--num_steps",
        default=50,
        type=int,
        help="Number of sampling steps for teacher (default: 50)",
    )
    # --- I2V / video2world conditioning options ---
    parser.add_argument(
        "--input_image_file",
        type=str,
        default="scripts/inference/prompts/source_image_paths.txt",
        help="File containing paths to input images (one per line) for I2V mode (or video2world mode in cosmos). "
        "Images are aligned with prompts; if fewer images than prompts, the last image is repeated.",
    )
    parser.add_argument(
        "--num_conditioning_frames",
        type=int,
        default=1,
        help="Number of latent frames to condition on for I2V mode (default: 1).",
    )
    # NOTE(review): --conditional_frame_timestep is registered but not read in
    # main() above — presumably consumed elsewhere via args; confirm.
    parser.add_argument(
        "--conditional_frame_timestep",
        type=float,
        default=0.0,
        help="Timestep value for conditioning frames in I2V mode. "
        "Use 0.0 (default) to indicate clean conditioning frames. "
        "Use -1.0 to disable timestep modification. "
        "Use small positive value (e.g., 0.1) for noisy conditioning.",
    )

    # Parse CLI + config overrides, build the experiment config, and run.
    args = parse_args(parser)
    config = setup(args, evaluation=True)
    main(args, config)

    # Tear down distributed resources (process group etc.).
    clean_up()
|
|
| |
|
|