# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 """Video generation inference script. Supports: - Text-to-video (T2V): Wan 2.1/2.2 - Image-to-video (I2V): Wan I2V - Video-to-video (V2V): VACE Wan, Self-Forcing - Video2World: Cosmos Predict2 Examples: # T2V: eval teacher only (Wan) PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ scripts/inference/video_model_inference.py --do_student_sampling False \\ --config fastgen/configs/experiments/WanT2V/config_dmd2.py \\ - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 log_config.name=wan_t2v_inference # I2V: image-to-video (Wan I2V) PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ scripts/inference/video_model_inference.py --do_student_sampling False \\ --input_image_file scripts/inference/prompts/source_image_paths.txt \\ --config fastgen/configs/experiments/WanI2V/config_dmd2_14b.py \\ - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 log_config.name=wan_i2v_inference # V2V: video-to-video with VACE PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ scripts/inference/video_model_inference.py --do_student_sampling False \\ --source_video_file scripts/inference/prompts/source_video_paths.txt \\ --config fastgen/configs/experiments/WanV2V/config_sft_latent.py \\ - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 log_config.name=vace_wan_inference # Video2World: Cosmos Predict2 PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ scripts/inference/video_model_inference.py --do_student_sampling False \\ --input_image_file scripts/inference/prompts/source_image_paths.txt --num_conditioning_frames 1 \\ --config fastgen/configs/experiments/CosmosPredict2/config_sft.py \\ - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 model.net.is_video2world=True \\ log_config.name=cosmos_v2w_inference # Eval with skip-layer guidance (SLG) PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ scripts/inference/video_model_inference.py --do_student_sampling False \\ --config fastgen/configs/experiments/WanT2V/config_dmd2.py \\ - trainer.seed=1 trainer.ddp=True model.guidance_scale=6.0 model.skip_layers=[10] \\ log_config.name=wan_slg_inference # Eval student and teacher together PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \\ scripts/inference/video_model_inference.py --ckpt_path /path/to/checkpoint.pth \\ --do_student_sampling True --do_teacher_sampling True \\ --config fastgen/configs/experiments/WanT2V/config_dmd2.py \\ - trainer.seed=1 trainer.ddp=True log_config.name=wan_student_teacher_inference """ from __future__ import annotations import argparse import gc import time from pathlib import Path from typing import TYPE_CHECKING, Optional, Sequence import imageio.v3 as iio import numpy as np import torch from tqdm.auto import tqdm from fastgen.configs.config import BaseConfig from fastgen.networks.WanI2V import WanI2V from fastgen.networks.cosmos_predict2 import CosmosPredict2 import fastgen.utils.logging_utils as logger from fastgen.utils.distributed import clean_up, is_rank0, world_size from fastgen.utils import basic_utils from fastgen.utils.scripts import parse_args, setup from fastgen.third_party.wan_prompt_expand.prompt_expand import QwenPromptExpander from fastgen.datasets.wds_dataloaders import transform_video from scripts.inference.inference_utils import ( expand_path, load_prompts, init_model, init_checkpointer, load_checkpoint, cleanup_unused_modules, setup_inference_modules, add_common_args, ) if TYPE_CHECKING: from fastgen.methods import FastGenModel def load_video_frames(video_path: str, num_frames: int, height: int, width: int) -> Optional[torch.Tensor]: """ Load video, align spatial preprocessing with dataset pipeline via transform_video, and return tensor shaped [1, C, T, H, W] in [-1, 1]. """ try: frames_np = iio.imread(video_path, plugin="pyav") # [T, H, W, C], uint8 except Exception as e: logger.error(f"Failed to read video file: {video_path} with error {e}") return None if frames_np is None or len(frames_np) == 0: logger.error(f"No frames decoded from video file: {video_path}") return None # Ensure enough frames by padding with the last frame (avoid tiling entire clip) T = len(frames_np) if T < num_frames: pad_count = num_frames - T last = frames_np[-1:] frames_np = np.concatenate([frames_np, np.repeat(last, pad_count, axis=0)], axis=0) else: # Use a centered segment to better match training decode behavior start = max(0, (T - num_frames) // 2) frames_np = frames_np[start : start + num_frames] # Convert to torch and apply the same preprocessing as training frames_t = torch.from_numpy(frames_np) # [T, H, W, C], uint8 out = transform_video(frames_t, sequence_length=num_frames, img_size=(width, height)) frames_tensor = out["real"] # [C, T, H, W], float in [-1, 1] return frames_tensor.unsqueeze(0) def load_conditioning_image( image_path: str, height: int, width: int, num_latent_frames: int = 1 ) -> Optional[torch.Tensor]: """ Load an image as conditioning frames for image-to-video generation. The image is replicated to create the pixel frames needed by temporal VAE (which has 4x temporal compression). For N latent frames, we need (N-1)*4+1 pixel frames. Args: image_path: Path to the input image. height: Target height in pixels. width: Target width in pixels. num_latent_frames: Number of latent frames to generate from the image (default 1). Returns: Tensor of shape [1, C, T, H, W] in [-1, 1] range, where T is the number of pixel frames needed for the requested latent frames. Returns None on failure. """ try: # Load image using imageio img_np = iio.imread(image_path) # [H, W, C], uint8 except Exception as e: logger.error(f"Failed to read image file: {image_path} with error {e}") return None if img_np is None: logger.error(f"Failed to load image: {image_path}") return None # Handle grayscale images if img_np.ndim == 2: img_np = np.stack([img_np] * 3, axis=-1) elif img_np.shape[-1] == 4: # RGBA img_np = img_np[..., :3] # For temporal VAE with 4x compression, we need (T-1)*4+1 pixel frames for T latent frames # For 1 latent frame: 1 pixel frame; For 2 latent frames: 5 pixel frames, etc. num_pixel_frames = (num_latent_frames - 1) * 4 + 1 if num_latent_frames > 1 else 1 # Replicate image to create video-like input # Shape: [T, H, W, C] frames_np = np.stack([img_np] * num_pixel_frames, axis=0) # Convert to torch and apply preprocessing frames_t = torch.from_numpy(frames_np) # [T, H, W, C], uint8 out = transform_video(frames_t, sequence_length=num_pixel_frames, img_size=(width, height)) frames_tensor = out["real"] # [C, T, H, W], float in [-1, 1] return frames_tensor.unsqueeze(0) def prepare_wani2v_condition( conditioning_frames: torch.Tensor, conditioning_latents: torch.Tensor, condition: torch.Tensor, neg_condition: Optional[torch.Tensor], model: FastGenModel, vae: torch.nn.Module, t_latent: int, use_concat_mask: bool, ) -> tuple: """Prepare condition dicts for WanI2V models. Args: conditioning_frames: Raw pixel frames [B, C, T, H, W] in [-1, 1] conditioning_latents: VAE-encoded latents of conditioning frames condition: Text embeddings for positive prompt neg_condition: Text embeddings for negative prompt model: The model instance (for precision and device info) vae: VAE model for encoding t_latent: Total number of latent frames use_concat_mask: Whether model uses concat mask (Wan 2.1) or frame replacement (Wan 2.2) Returns: Tuple of (condition_dict, neg_condition_dict, i2v_tag) """ if use_concat_mask: # Wan 2.1 14B: first_frame_cond must be created in pixel space then encoded # This matches training: [first_frame, zeros, zeros, ...] -> VAE encode num_pixel_frames = (t_latent - 1) * 4 + 1 B, C_pixel, _, H_pixel, W_pixel = conditioning_frames.shape pixel_cond = torch.zeros( B, C_pixel, num_pixel_frames, H_pixel, W_pixel, device=conditioning_frames.device, dtype=conditioning_frames.dtype, ) pixel_cond[:, :, 0] = conditioning_frames[:, :, 0] # First frame only # Encode through VAE (zeros become VAE-encoded zeros, not latent zeros) with basic_utils.inference_mode(vae, precision_amp=model.precision_amp_infer, device_type=model.device.type): first_frame_cond = vae.encode(pixel_cond) logger.info(f"Wan 2.1 I2V: created first_frame_cond via VAE, shape {first_frame_cond.shape}") else: # Wan 2.2 5B: pad with zeros in latent space (simpler, no concat) first_frame_cond = torch.zeros( 1, conditioning_latents.shape[1], t_latent, conditioning_latents.shape[3], conditioning_latents.shape[4], device=conditioning_latents.device, dtype=conditioning_latents.dtype, ) first_frame_cond[:, :, : conditioning_latents.shape[2]] = conditioning_latents condition_dict = {"text_embeds": condition, "first_frame_cond": first_frame_cond} neg_condition_dict = {"text_embeds": neg_condition, "first_frame_cond": first_frame_cond} # Add image encoder embeddings if available (Wan 2.1 14B I2V) if hasattr(model.net, "image_encoder"): with basic_utils.inference_mode( model.net.image_encoder, precision_amp=model.precision_amp_infer, device_type=model.device.type ): # Use first pixel frame for image encoder first_pixel_frame = conditioning_frames[:, :, 0:1] # [B, C, 1, H, W] img_embeds = model.net.image_encoder.encode(first_pixel_frame[:, :, 0]) # Ensure embeddings are on the correct device and dtype img_embeds = img_embeds.to(device=model.device, dtype=model.precision) condition_dict["encoder_hidden_states_image"] = img_embeds neg_condition_dict["encoder_hidden_states_image"] = img_embeds return condition_dict, neg_condition_dict, "_i2v" def prepare_cosmos_v2w_condition( conditioning_latents: torch.Tensor, condition: torch.Tensor, neg_condition: Optional[torch.Tensor], latent_shape: Sequence[int], num_conditioning_frames: int, ) -> tuple: """Prepare condition dicts for CosmosPredict2 video2world mode. Args: conditioning_latents: VAE-encoded latents of conditioning frames condition: Text embeddings for positive prompt neg_condition: Text embeddings for negative prompt latent_shape: Shape of latent tensor [C, T, H, W] num_conditioning_frames: Number of frames to condition on Returns: Tuple of (condition_dict, neg_condition_dict, i2v_tag) """ t_latent, h_latent, w_latent = latent_shape[1], latent_shape[2], latent_shape[3] # Create condition mask: 1 for conditioning frames, 0 for generated condition_mask = torch.zeros( 1, 1, t_latent, h_latent, w_latent, device=conditioning_latents.device, dtype=conditioning_latents.dtype ) condition_mask[:, :, :num_conditioning_frames] = 1.0 # Build condition dict for forward() compatibility condition_dict = { "text_embeds": condition, "conditioning_latents": conditioning_latents, "condition_mask": condition_mask, } neg_condition_dict = { "text_embeds": neg_condition, "conditioning_latents": conditioning_latents, "condition_mask": condition_mask, } return condition_dict, neg_condition_dict, f"_v2w{num_conditioning_frames}" def prepare_vacewan_condition( source_video_path: str, depth_latent_path: Optional[str], model: FastGenModel, latent_shape: Sequence[int], condition: torch.Tensor, neg_condition: Optional[torch.Tensor], ctx: dict, ) -> tuple: """Prepare condition dicts for VACE Wan models (depth-to-video). Args: source_video_path: Path to the source video for conditioning depth_latent_path: Optional path to precomputed depth latents model: The model instance latent_shape: Shape of latent tensor [C, T, H, W] condition: Text embeddings for positive prompt neg_condition: Text embeddings for negative prompt ctx: Device/dtype context dict Returns: Tuple of (condition_dict, neg_condition_dict) """ t_latent = latent_shape[1] target_frames = (t_latent - 1) * 4 + 1 # VAE spatial compression factor: 16 for Wan 2.2 (48ch latents), 8 for others vae_spatial_factor = 16 if latent_shape[0] == 48 else 8 height = latent_shape[2] * vae_spatial_factor width = latent_shape[3] * vae_spatial_factor video = load_video_frames(source_video_path, num_frames=target_frames, height=height, width=width) # [-1, 1] video = video.to(**ctx) if depth_latent_path is None: depth_latent = model.net.prepare_vid_conditioning(video=video, condition_latents=None) else: depth_latent = torch.load(depth_latent_path) depth_latent = depth_latent[:, :t_latent] depth_latent = depth_latent.unsqueeze(0) depth_latent = depth_latent.to(**ctx) depth_latent = model.net.prepare_vid_conditioning(video=video, condition_latents=depth_latent) condition_dict = {"text_embeds": condition, "vid_context": depth_latent} neg_condition_dict = {"text_embeds": neg_condition, "vid_context": depth_latent} return condition_dict, neg_condition_dict def prepare_i2v_condition( input_image_path: str, model: FastGenModel, vae: Optional[torch.nn.Module], latent_shape: Sequence[int], condition: torch.Tensor, neg_condition: Optional[torch.Tensor], num_conditioning_frames: int, ctx: dict, ) -> tuple: """Load and prepare I2V/video2world conditioning from an input image. Args: input_image_path: Path to the input image model: The model instance vae: VAE model for encoding latent_shape: Shape of latent tensor [C, T, H, W] condition: Text embeddings for positive prompt neg_condition: Text embeddings for negative prompt num_conditioning_frames: Number of frames to condition on ctx: Device/dtype context dict Returns: Tuple of (condition, neg_condition, i2v_tag) where condition/neg_condition may be updated dicts for I2V mode, or unchanged if loading fails. """ i2v_tag = "" if not input_image_path or not Path(input_image_path).exists(): if input_image_path: logger.warning(f"Conditioning image not found: {input_image_path}") return condition, neg_condition, i2v_tag # VAE spatial compression factor: 16 for Wan 2.2 (48ch latents), 8 for others vae_spatial_factor = 16 if latent_shape[0] == 48 else 8 height = latent_shape[2] * vae_spatial_factor width = latent_shape[3] * vae_spatial_factor # Load and preprocess input image conditioning_frames = load_conditioning_image( input_image_path, height=height, width=width, num_latent_frames=num_conditioning_frames, ) if conditioning_frames is None or vae is None: logger.warning(f"Failed to encode conditioning image: {input_image_path}") return condition, neg_condition, i2v_tag conditioning_frames = conditioning_frames.to(**ctx) with basic_utils.inference_mode(vae, precision_amp=model.precision_amp_infer, device_type=model.device.type): conditioning_latents = vae.encode(conditioning_frames) logger.info(f"I2V: encoded image to latents shape {conditioning_latents.shape}") if getattr(model.net, "is_i2v", False): assert isinstance(model.net, WanI2V), f"Expected WanI2V model but got {type(model.net).__name__}" # WanI2V model use_concat_mask = getattr(model.net, "concat_mask", False) return prepare_wani2v_condition( conditioning_frames=conditioning_frames, conditioning_latents=conditioning_latents, condition=condition, neg_condition=neg_condition, model=model, vae=vae, t_latent=latent_shape[1], use_concat_mask=use_concat_mask, ) elif getattr(model.net, "is_video2world", False): assert isinstance( model.net, CosmosPredict2 ), f"Expected CosmosPredict2 model but got {type(model.net).__name__}" # CosmosPredict2 video2world return prepare_cosmos_v2w_condition( conditioning_latents=conditioning_latents, condition=condition, neg_condition=neg_condition, latent_shape=latent_shape, num_conditioning_frames=num_conditioning_frames, ) else: raise NotImplementedError(f"I2V mode not implemented for {type(model.net).__name__}") def expand_prompts_with_qwen( prompts: list[str], model_name: str, device: torch.device, seed: int, ) -> list[str]: """Expand prompts using Qwen model on rank 0 and broadcast to all ranks. Args: prompts: List of prompts to expand model_name: Qwen model name device: Device to run on seed: Random seed for prompt expansion Returns: List of expanded prompts """ logger.info("Expanding prompts on rank 0 ...") if is_rank0(): prompt_expander = QwenPromptExpander( model_name=model_name, is_vl=False, device=device, ) for prompt_idx in tqdm(range(len(prompts))): logger.debug(f"Expanding prompt {prompts[prompt_idx]} with seed {seed}") basic_utils.set_random_seed(seed) prompt_output = prompt_expander(prompts[prompt_idx], tar_lang="en", seed=seed) logger.info(f"Expanded prompt: {prompt_output.prompt}") prompts[prompt_idx] = prompt_output.prompt # Free memory del prompt_expander gc.collect() torch.cuda.empty_cache() else: prompts = [None] * len(prompts) if world_size() > 1: torch.distributed.broadcast_object_list(prompts, src=0) return prompts def main(args, config: BaseConfig): # Load prompts pos_prompt_set = load_prompts(args.prompt_file, relative_to="cwd") # Prompt expansion if specified if args.prompt_expand_model: pos_prompt_set = expand_prompts_with_qwen( pos_prompt_set, args.prompt_expand_model, torch.device(config.model.device), args.prompt_expand_model_seed, ) # Load depth latent paths depth_latent_paths = None if args.depth_latent_file is not None: depth_latent_path = expand_path(args.depth_latent_file, relative_to="cwd") if depth_latent_path.is_file(): with depth_latent_path.open("r") as f: depth_latent_paths = [line.strip() for line in f.readlines()] else: raise FileNotFoundError(f"depth_latent_file: {depth_latent_path} not found!") # Load source video paths source_video_paths = None if args.source_video_file is not None: source_video_path = expand_path(args.source_video_file, relative_to="cwd") if source_video_path.is_file(): with source_video_path.open("r") as f: source_video_paths = [line.strip() for line in f.readlines()] else: raise FileNotFoundError(f"source_video_path: {source_video_path} not found!") # Load input images for I2V mode (or video2world mode in cosmos) input_image_paths = None if args.input_image_file is not None: input_image_file_path = expand_path(args.input_image_file, relative_to="cwd") if input_image_file_path.is_file(): with input_image_file_path.open("r") as f: input_image_paths = [line.strip() for line in f.readlines() if line.strip()] # Align with prompts: repeat last image if fewer images than prompts num_prompts = len(pos_prompt_set) num_images = len(input_image_paths) if num_images < num_prompts: last_image = input_image_paths[-1] if input_image_paths else "" input_image_paths.extend([last_image] * (num_prompts - num_images)) logger.info(f"I2V: extended {num_images} images to {num_prompts} by repeating last image") elif num_images > num_prompts: input_image_paths = input_image_paths[:num_prompts] logger.info(f"I2V: truncated {num_images} images to {num_prompts} prompts") logger.info(f"I2V mode: {len(input_image_paths)} input images for {num_prompts} prompts") else: raise FileNotFoundError(f"input_image_file_path: {input_image_file_path} not found!") # Fix sampling seeds seed = basic_utils.set_random_seed(config.trainer.seed, by_rank=True) # Initialize model and checkpointer model = init_model(config) checkpointer = init_checkpointer(config) # Load checkpoint ckpt_iter, save_dir = load_checkpoint(checkpointer, model, args.ckpt_path, config) if ckpt_iter is None and args.do_student_sampling: logger.warning(f"Performing {model.config.student_sample_steps}-step generation on the non-distilled model") if args.video_save_dir: # overwrite the save_dir save_dir = args.video_save_dir logger.info(f"video_save_dir: {save_dir}") save_dir = Path(save_dir) prompt_name = Path(args.prompt_file).stem if args.prompt_expand_model: prompt_name += f"_{args.prompt_expand_model}" save_dir = save_dir / prompt_name save_video_kwargs = {"precision_amp": model.precision_amp_infer, "save_as_gif": args.save_as_gif, "fps": args.fps} if args.save_high_quality: save_video_kwargs = { "quality": 18, "preset": "medium", "fps": args.fps, } save_dir = save_dir.parent / (save_dir.name + "_hq") # Remove unused modules cleanup_unused_modules(model, args.do_teacher_sampling) # Get precision and set up inference modules teacher, student, vae = setup_inference_modules( model, config, args.do_teacher_sampling, args.do_student_sampling, model.precision ) ctx = {"dtype": model.precision, "device": model.device} # Check if we have at least one valid sampling path has_teacher_sampling = teacher is not None and hasattr(teacher, "sample") has_student_sampling = student is not None and hasattr(model, "generator_fn") assert ( has_teacher_sampling or has_student_sampling ), "At least one of teacher or student (with generator_fn) must be provided for sampling" # Load negative condition neg_condition = None if args.neg_prompt_file is not None: neg_condition = load_prompts(args.neg_prompt_file, relative_to="cwd") if len(neg_condition) > 0: neg_condition = neg_condition[:1] logger.warning(f"Found {len(neg_condition)} negative prompts, only using the first one.") logger.debug(f"Loaded negative prompt: {neg_condition[0]}") if hasattr(model.net, "text_encoder"): with basic_utils.inference_mode( model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type ): neg_condition = basic_utils.to(model.net.text_encoder.encode(neg_condition), **ctx) slg_tag = "" if config.model.skip_layers is None else f"_slg{'_'.join([str(x) for x in config.model.skip_layers])}" # Fix noise for all generated samples noise = torch.randn( [1, *config.model.input_shape], **ctx, ) for i, prompt in enumerate(pos_prompt_set): logger.info(f"[{i+1}/{len(pos_prompt_set)}] Generating: {prompt[:80]}...") # Encode prompt condition = [prompt] if hasattr(model.net, "text_encoder"): with basic_utils.inference_mode( model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type ): condition = basic_utils.to(model.net.text_encoder.encode(condition), **ctx) # VACE Wan mode: depth-to-video conditioning is_net_v2v = hasattr(model.net, "prepare_vid_conditioning") if source_video_paths is not None and i < len(source_video_paths) and is_net_v2v: depth_latent_path = depth_latent_paths[i] if depth_latent_paths is not None else None condition, neg_condition_sample = prepare_vacewan_condition( source_video_path=source_video_paths[i], depth_latent_path=depth_latent_path, model=model, latent_shape=config.model.input_shape, condition=condition, neg_condition=neg_condition, ctx=ctx, ) else: neg_condition_sample = neg_condition # Image-to-video / Video2world mode: load and encode conditioning image # Skip if already using VACE video conditioning (model has prepare_vid_conditioning) i2v_tag = "" is_net_i2v = getattr(model.net, "is_i2v", False) or getattr(model.net, "is_video2world", False) if input_image_paths is not None and i < len(input_image_paths) and is_net_i2v: condition, neg_condition_sample, i2v_tag = prepare_i2v_condition( input_image_path=input_image_paths[i], model=model, vae=vae, latent_shape=config.model.input_shape, condition=condition, neg_condition=neg_condition, num_conditioning_frames=args.num_conditioning_frames, ctx=ctx, ) # Student sampling if has_student_sampling: use_extrapolation = args.num_segments != 1 or args.overlap_frames != 0 start_time = time.time() # Build student sampling kwargs student_kwargs = { "condition": condition, "neg_condition": neg_condition_sample, "student_sample_steps": model.config.student_sample_steps, "student_sample_type": model.config.student_sample_type, "t_list": model.config.sample_t_cfg.t_list, "precision_amp": model.precision_amp_infer, } if use_extrapolation: if not hasattr(model, "generator_fn_extrapolation"): raise RuntimeError("Extrapolation is only supported for causal autoregressive networks") if not hasattr(model.net, "vae"): raise RuntimeError("VAE is required for extrapolation but was not initialized") student_kwargs["num_segments"] = args.num_segments student_kwargs["overlap_frames"] = args.overlap_frames video_student = model.generator_fn_extrapolation(student, noise, **student_kwargs) else: video_student = model.generator_fn(student, noise, **student_kwargs) sampling_time = time.time() - start_time logger.info(f"Student sampling time: {sampling_time:.2f}s") save_path = save_dir / f"student_step{model.config.student_sample_steps}{i2v_tag}_{i:04d}_seed{seed}.mp4" basic_utils.save_media(video_student, str(save_path), vae=vae, **save_video_kwargs) # Teacher sampling if has_teacher_sampling: start_time = time.time() teacher_kwargs = { "condition": condition, "neg_condition": neg_condition_sample, "num_steps": args.num_steps, "second_order": False, "precision_amp": model.precision_amp_infer, "fps": torch.full((noise.shape[0],), float(args.fps), device=noise.device), } if config.model.skip_layers is not None: teacher_kwargs["skip_layers"] = config.model.skip_layers video_teacher = model.sample(teacher, noise, **teacher_kwargs) sampling_time = time.time() - start_time logger.info(f"Teacher sampling time: {sampling_time:.2f}s") save_path = ( save_dir / f"teacher_cfg{config.model.guidance_scale}_steps{args.num_steps}{slg_tag}{i2v_tag}_{i:04d}_seed{seed}.mp4" ) basic_utils.save_media(video_teacher, str(save_path), vae=vae, **save_video_kwargs) # ---------------------------------------------------------------------------- if __name__ == "__main__": parser = argparse.ArgumentParser(description="Video model inference") # Add common args add_common_args(parser) # Video-specific args parser.add_argument( "--save_as_gif", default=True, type=basic_utils.str2bool, help="Whether to save videos as GIF (True) or MP4 (False)", ) parser.add_argument( "--fps", default=16, type=int, help="Frames per second for saved video and model temporal encoding (default: 16, matches Wan base_fps)", ) parser.add_argument( "--save_high_quality", default=False, type=basic_utils.str2bool, help="Whether to save videos in high-quality (codec: libx265 vs libx264)", ) parser.add_argument( "--prompt_file", default="scripts/inference/prompts/validation_aug_qwen_2_5_14b_seed42.txt", type=str, help="File containing prompts (one per line). Relative paths are resolved from script directory.", ) parser.add_argument( "--neg_prompt_file", default="scripts/inference/prompts/negative_prompt.txt", type=str, help="The file containing the negative prompt to use for CFG.", ) parser.add_argument( "--prompt_expand_model", type=str, help="If specified, perform prompt expansion using the specified Qwen model.", choices=["QwenVL2.5_3B", "QwenVL2.5_7B", "Qwen2.5_3B", "Qwen2.5_7B", "Qwen2.5_14B"], ) parser.add_argument( "--prompt_expand_model_seed", type=int, help="Seed for prompt expansion.", default=0, ) parser.add_argument( "--depth_latent_file", default=None, type=str, help="The file containing the depth latent paths to use for sampling.", ) parser.add_argument( "--source_video_file", default="scripts/inference/prompts/source_video_paths.txt", type=str, help="The file containing the source video paths to use for sampling.", ) parser.add_argument( "--num_segments", type=int, default=1, help="Number of autoregressive segments to generate when using extrapolation (default: 1)", ) parser.add_argument( "--overlap_frames", type=int, default=0, help="Number of latent frames to overlap between segments when extrapolating (default: 0)", ) parser.add_argument( "--video_save_dir", type=str, help="Path to the video save directory.", default=None, ) parser.add_argument( "--num_steps", default=50, type=int, help="Number of sampling steps for teacher (default: 50)", ) # I2V arguments parser.add_argument( "--input_image_file", type=str, default="scripts/inference/prompts/source_image_paths.txt", help="File containing paths to input images (one per line) for I2V mode (or video2world mode in cosmos). " "Images are aligned with prompts; if fewer images than prompts, the last image is repeated.", ) parser.add_argument( "--num_conditioning_frames", type=int, default=1, help="Number of latent frames to condition on for I2V mode (default: 1).", ) parser.add_argument( "--conditional_frame_timestep", type=float, default=0.0, help="Timestep value for conditioning frames in I2V mode. " "Use 0.0 (default) to indicate clean conditioning frames. " "Use -1.0 to disable timestep modification. " "Use small positive value (e.g., 0.1) for noisy conditioning.", ) args = parse_args(parser) config = setup(args, evaluation=True) main(args, config) clean_up() # ----------------------------------------------------------------------------