"""
Single GPU Inference Pipeline - Refactored from inference_pipe.py

This file extracts core logic from multi-GPU inference code to implement a
complete inference pipeline on a single GPU:

1. VAE encode input video
2. DiT inference (using input mode, processing all 30 blocks)
3. VAE decode output video
"""

from models.wan.causal_stream_inference import CausalStreamInferencePipeline
from models.util import set_seed
from diffusers.utils import export_to_video
from models.data import TextDataset
import argparse
from dataclasses import dataclass
import torch
import os
import time
import numpy as np
import logging
from typing import List

try:
    from streamv2v.inference_common import (
        load_generator_state_dict,
        load_mp4_as_tensor,
        merge_cli_config,
    )
except ModuleNotFoundError:
    # Fallback for running the script from inside the package directory.
    from inference_common import (
        load_generator_state_dict,
        load_mp4_as_tensor,
        merge_cli_config,
    )

LOGGER = logging.getLogger(__name__)


@dataclass
class SingleGPUStreamSession:
    """Mutable state carried between calls of the streaming API."""

    prompt: str
    # Noise scale used for the most recent batch (updated adaptively).
    noise_scale: float
    # Noise scale the session started with; anchor for the adaptive update.
    init_noise_scale: float
    # Number of input (pixel) frames per DiT chunk.
    chunk_size: int
    # Token window [current_start, current_end) for the next chunk, expressed
    # in multiples of pipeline.frame_seq_length.
    current_start: int
    current_end: int
    # Last input frame of the previous batch, so the next motion estimate can
    # difference across the batch boundary.
    last_image: torch.Tensor
    # Number of chunks pushed through inference_stream in this session.
    processed: int = 0


def compute_noise_scale_and_step(
    input_video_original: torch.Tensor,
    end_idx: int,
    chunk_size: int,
    noise_scale: float,
    init_noise_scale: float,
):
    """Compute an adaptive noise scale and denoising start step from frame motion.

    The ``chunk_size`` frames ending at ``end_idx`` (along dim 2) are
    differenced against the same window shifted back by one frame. The
    root-mean-square difference is taken per frame, its maximum is divided
    by 0.2 and clamped to [0, 1], giving a motion score ``m``. Then::

        new_noise_scale = 0.9 * (init_noise_scale - 0.1 * m) + 0.1 * noise_scale
        current_step    = int(1000 * new_noise_scale) - 100

    Assumes frames lie on dim 2 (presumably a (B, C, T, H, W) layout -- confirm
    with callers) and that ``end_idx - chunk_size - 1 >= 0``.

    Returns:
        ``(new_noise_scale, current_step)``.
    """
    window = input_video_original[:, :, end_idx - chunk_size:end_idx]
    shifted = input_video_original[:, :, end_idx - chunk_size - 1:end_idx - 1]
    squared_diff = (window - shifted) ** 2
    motion = (torch.sqrt(squared_diff.mean(dim=(0, 1, 3, 4))).max() / 0.2).clamp(0, 1)
    new_noise_scale = (init_noise_scale - 0.1 * motion.item()) * 0.9 + noise_scale * 0.1
    current_step = int(1000 * new_noise_scale) - 100
    return new_noise_scale, current_step


class SingleGPUInferencePipeline:
    """
    Single GPU Inference Pipeline Manager

    Wraps a ``CausalStreamInferencePipeline`` and drives VAE encoding, DiT
    inference and VAE decoding on one device, either as an offline run over
    a whole video (``run_inference``) or as an incremental streaming session
    (``start_stream_session`` followed by ``run_stream_batch``).
    """

    def __init__(self, config, device: torch.device):
        """
        Initialize the single GPU inference pipeline manager.

        Args:
            config: Configuration object. Must expose ``t2v`` as an attribute
                and support ``config.get("profile", False)``.
            device: GPU device
        """
        self.config = config
        self.device = device

        # Setup logging
        self.logger = logging.getLogger("SingleGPUInference")
        self.logger.setLevel(logging.INFO)
        # Prevent messages from propagating to the root logger (avoid double prints)
        self.logger.propagate = False
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        # Initialize pipeline
        self.pipeline = CausalStreamInferencePipeline(config, device=str(device))
        self.pipeline.to(device=str(device), dtype=torch.bfloat16)

        # Performance tracking / scheduling state
        self.t_dit = 100.0
        self.t_total = 100.0
        # Number of chunks pushed through inference_stream so far.
        self.processed = 0
        # DiT throughput is only recorded once `processed` exceeds this (warm-up).
        self.processed_offset = 3
        # Input-frames-per-latent-frame ratio used in the chunk/token arithmetic
        # (presumably the VAE's temporal compression factor -- confirm).
        self.base_chunk_size = 4
        # Streaming mode rewinds the token window once current_start reaches
        # this many multiples of frame_seq_length.
        self.t_refresh = 50
        self.t2v = config.t2v
        self.profile = bool(config.get("profile", False))
        self.encode_fps_list: list[float] = []
        self.decode_fps_list: list[float] = []

        self.logger.info("Single GPU inference pipeline manager initialized")

    def load_model(self, checkpoint_folder: str):
        """Load generator weights from ``checkpoint_folder``.

        A strict load is attempted first. If it raises ``RuntimeError`` the
        load is retried with ``strict=False`` and the number of missing and
        unexpected keys is logged, so a partial load is never silent.
        """
        ckpt_path, state_dict = load_generator_state_dict(checkpoint_folder)
        self.logger.info(f"Loading checkpoint from {ckpt_path}")
        try:
            self.pipeline.generator.load_state_dict(state_dict, strict=True)
        except RuntimeError as e:
            # Fall back to a non-strict load, but report what did not match.
            self.logger.warning(f"Strict load_state_dict failed: {e}; retrying with strict=False")
            incompatible = self.pipeline.generator.load_state_dict(state_dict, strict=False)
            # nn.Module.load_state_dict returns (missing_keys, unexpected_keys);
            # use getattr in case the generator overrides it.
            missing = getattr(incompatible, "missing_keys", None) or []
            unexpected = getattr(incompatible, "unexpected_keys", None) or []
            if missing or unexpected:
                self.logger.warning(
                    "Non-strict load left %d missing and %d unexpected keys",
                    len(missing),
                    len(unexpected),
                )

    def prepare_pipeline(self, text_prompts: list, noise: torch.Tensor, current_start: int, current_end: int):
        """Initialise the pipeline on the first chunk and return its denoised latents."""
        # Use the original prepare method which now handles distributed environment gracefully
        denoised_pred = self.pipeline.prepare(
            text_prompts=text_prompts,
            device=self.device,
            dtype=torch.bfloat16,
            block_mode='input',
            noise=noise,
            current_start=current_start,
            current_end=current_end,
        )
        return denoised_pred

    def _sync_for_timing(self):
        """Wait for outstanding CUDA work so wall-clock timings are meaningful.

        No-op unless profiling is enabled and CUDA is available, so profiling
        on a CPU-only host does not crash.
        """
        if self.profile and torch.cuda.is_available():
            torch.cuda.synchronize()

    def _record_stage_fps(self, values: list[float], num_frames: int, elapsed: float) -> None:
        """Append ``num_frames / elapsed`` to ``values`` when profiling and both are positive."""
        if self.profile and elapsed > 0 and num_frames > 0:
            values.append(num_frames / elapsed)

    def _timed_stream_encode(self, images: torch.Tensor) -> torch.Tensor:
        """VAE stream-encode ``images`` and record encode throughput (frames on dim 2)."""
        self._sync_for_timing()
        start_time = time.time()
        latents = self.pipeline.vae.stream_encode(images)
        self._sync_for_timing()
        self._record_stage_fps(self.encode_fps_list, int(images.shape[2]), time.time() - start_time)
        return latents

    def _timed_stream_decode(self, denoised_pred: torch.Tensor) -> torch.Tensor:
        """VAE stream-decode latents to pixels and record decode throughput (frames on dim 1)."""
        self._sync_for_timing()
        start_time = time.time()
        video = self.pipeline.vae.stream_decode_to_pixel(denoised_pred)
        self._sync_for_timing()
        self._record_stage_fps(self.decode_fps_list, int(video.shape[1]), time.time() - start_time)
        return video

    def reset_stream_state(self, reset_vae_flags: bool = True) -> None:
        """Reset cached model state before starting a new streaming session."""
        if reset_vae_flags:
            self.pipeline.vae.model.first_encode = True
            self.pipeline.vae.model.first_decode = True
        self.pipeline.kv_cache1 = None
        self.pipeline.crossattn_cache = None
        self.pipeline.block_x = None
        self.pipeline.hidden_states = None
        self.processed = 0

    def _encode_noisy_latents(self, images: torch.Tensor, noise_scale: float) -> torch.Tensor:
        """Encode ``images`` and blend with Gaussian noise: ``noise*s + latents*(1-s)``.

        The encoder output has dims 1 and 2 swapped before blending; the
        result is bfloat16.
        """
        latents = self._timed_stream_encode(images)
        latents = latents.transpose(2, 1).contiguous().to(dtype=torch.bfloat16)
        noise = torch.randn_like(latents)
        return noise * noise_scale + latents * (1 - noise_scale)

    def _random_latents(self) -> torch.Tensor:
        """Draw one block of pure-noise latents.

        Shape is ``(1, num_frame_per_block, 16, pipeline.height, pipeline.width)``
        in bfloat16. Used for T2V and for flushing the pipeline once the input
        video is exhausted.
        """
        return torch.randn(
            1,
            self.pipeline.num_frame_per_block,
            16,
            self.pipeline.height,
            self.pipeline.width,
            device=self.device,
            dtype=torch.bfloat16,
        )

    def _decode_video_array(self, denoised_pred: torch.Tensor, last_frame_only: bool = False) -> np.ndarray:
        """Decode latents to a float32 numpy array of frames with values in [0, 1].

        If ``last_frame_only`` is set, only ``denoised_pred[[-1]]`` is decoded.
        The result is the first batch item permuted to (frames, H, W, C).
        """
        if last_frame_only:
            denoised_pred = denoised_pred[[-1]]
        video = self._timed_stream_decode(denoised_pred)
        video = (video * 0.5 + 0.5).clamp(0, 1)
        video = video[0].permute(0, 2, 3, 1).contiguous()
        return video.detach().cpu().float().numpy()

    def start_stream_session(self, prompt: str, images: torch.Tensor, noise_scale: float) -> tuple[SingleGPUStreamSession, np.ndarray]:
        """Initialize a streaming session and return the first decoded frames.

        Resets all cached state, runs ``prepare_pipeline`` on ``images`` and
        returns the new session together with the decoded initial frames.
        The token window is sized for ``1 + chunk_size // base_chunk_size``
        latent frames, so ``images`` is presumably ``1 + chunk_size`` frames
        long (confirm with the caller).
        """
        self.reset_stream_state(reset_vae_flags=True)
        chunk_size = self.base_chunk_size * self.pipeline.num_frame_per_block
        latents_per_chunk = chunk_size // self.base_chunk_size
        frame_seq_length = self.pipeline.frame_seq_length

        current_start = 0
        current_end = frame_seq_length * (1 + latents_per_chunk)
        noisy_latents = self._encode_noisy_latents(images, noise_scale)
        denoised_pred = self.prepare_pipeline(
            text_prompts=[prompt],
            noise=noisy_latents,
            current_start=current_start,
            current_end=current_end,
        )
        initial_video = self._decode_video_array(denoised_pred, last_frame_only=False)

        session = SingleGPUStreamSession(
            prompt=prompt,
            noise_scale=noise_scale,
            init_noise_scale=noise_scale,
            chunk_size=chunk_size,
            current_start=current_end,
            current_end=current_end + latents_per_chunk * frame_seq_length,
            last_image=images[:, :, [-1]],
            processed=0,
        )
        return session, initial_video

    def run_stream_batch(self, session: SingleGPUStreamSession, images: torch.Tensor, queue_wait_time: float | None = None) -> List[np.ndarray]:
        """Process one or more chunk-aligned frame groups for an active streaming session.

        ``images`` holds frames on dim 2. All frames are VAE-encoded once, then
        each full chunk of ``session.chunk_size`` frames is pushed through
        ``inference_stream``. Once ``num_steps`` chunks have been processed in
        the session, each call also yields the decoded last frame(s).
        Trailing frames that do not fill a whole chunk are encoded but not
        denoised; a warning is logged in that case.

        Args:
            session: State returned by ``start_stream_session``; updated in place.
            images: New input frames.
            queue_wait_time: Currently unused; accepted for interface compatibility.

        Returns:
            Decoded frame arrays, one per chunk that produced output.
        """
        num_frames = images.shape[2]
        if num_frames == 0:
            # Nothing to do; the motion estimate below would fail on an empty window.
            return []
        if num_frames % session.chunk_size:
            self.logger.warning(
                "run_stream_batch received %d frames, not a multiple of chunk_size=%d; "
                "trailing frames are encoded but not denoised",
                num_frames,
                session.chunk_size,
            )
        input_batch = num_frames // session.chunk_size

        # Prepend the previous batch's last frame so motion is measured across the boundary.
        noise_scale, current_step = compute_noise_scale_and_step(
            input_video_original=torch.cat([session.last_image, images], dim=2),
            end_idx=num_frames + 1,
            chunk_size=num_frames,
            noise_scale=float(session.noise_scale),
            init_noise_scale=float(session.init_noise_scale),
        )
        noisy_latents = self._encode_noisy_latents(images, noise_scale)

        outputs: List[np.ndarray] = []
        num_steps = len(self.pipeline.denoising_step_list)
        frame_seq_length = self.pipeline.frame_seq_length
        # Latent frames per chunk; the token window advances by this many frames.
        latents_per_chunk = session.chunk_size // self.base_chunk_size

        for batch_idx in range(input_batch):
            if session.current_start // frame_seq_length >= self.t_refresh:
                # Rewind the token window once it has grown past t_refresh.
                session.current_start = self.pipeline.kv_cache_length - frame_seq_length
                session.current_end = session.current_start + latents_per_chunk * frame_seq_length
            # Feed a whole chunk of latents (not just one latent frame), matching
            # the latents_per_chunk-sized advance of the token window below.
            chunk_latents = noisy_latents[:, batch_idx * latents_per_chunk:(batch_idx + 1) * latents_per_chunk]
            denoised_pred = self.pipeline.inference_stream(
                noise=chunk_latents,
                current_start=session.current_start,
                current_end=session.current_end,
                current_step=current_step,
            )
            session.processed += 1
            self.processed = session.processed
            # Output only appears once the first chunk has traversed all denoising steps.
            if session.processed >= num_steps:
                outputs.append(self._decode_video_array(denoised_pred, last_frame_only=True))
            session.current_start = session.current_end
            session.current_end += latents_per_chunk * frame_seq_length

        session.last_image = images[:, :, [-1]]
        session.noise_scale = noise_scale
        return outputs

    def run_inference(
        self,
        input_video_original: torch.Tensor,
        prompts: list,
        num_chunks: int,
        chunk_size: int,
        noise_scale: float,
        output_folder: str,
        fps: int,
        target_fps: int,
        num_steps: int,
    ):
        """
        Run the complete single GPU inference pipeline and write the result to disk.

        The first chunk initialises the pipeline via ``prepare_pipeline`` and is
        decoded in full. The remaining chunks go through ``inference_stream``
        one at a time, and decoding starts once ``num_steps`` chunks have been
        processed. After the input video is exhausted (or in T2V mode),
        pure-noise latents are fed with the default step schedule. The first
        ``num_chunks`` decoded segments are concatenated and written to
        ``<output_folder>/output_000.mp4``.

        Args:
            input_video_original: Input video tensor with frames on dim 2, or
                ``None`` in T2V mode.
            prompts: Text prompts passed to ``prepare_pipeline``.
            num_chunks: Number of decoded segments to keep in the output.
            chunk_size: Input frames per chunk.
            noise_scale: Initial noise scale for V2V.
            output_folder: Directory for the output video (created if missing).
            fps: Frame rate of the written video.
            target_fps: If set and profiling is enabled, the chunk size is
                enlarged once when measured throughput falls below it.
            num_steps: Length of the denoising step list; decoding starts once
                this many chunks have been processed.
        """
        self.logger.info("Starting single GPU inference pipeline")
        os.makedirs(output_folder, exist_ok=True)

        # The loop below is driven by self.processed; start from zero so a
        # repeated call (or a call after a streaming session) runs fully.
        self.processed = 0

        results: list[np.ndarray] = []
        fps_list: list[float] = []
        dit_fps_list: list[float] = []
        self.encode_fps_list = []
        self.decode_fps_list = []

        if target_fps is not None and not self.profile:
            self.logger.warning(
                "target_fps=%s requested but profiling is disabled; chunk-size "
                "adaptation needs throughput measurements and will not run",
                target_fps,
            )

        base = self.base_chunk_size
        frame_seq_length = self.pipeline.frame_seq_length

        # Initialize indices: the first V2V chunk has 1 + chunk_size frames.
        start_idx = 0
        end_idx = 1 + chunk_size - base if self.t2v else 1 + chunk_size
        current_start = 0
        current_end = frame_seq_length * (1 + (end_idx - 1) // base)

        self._sync_for_timing()
        start_time = time.time()
        end_time = start_time

        # Process first chunk (initialization)
        if self.t2v:
            noisy_latents = self._random_latents()
        else:
            noisy_latents = self._encode_noisy_latents(
                input_video_original[:, :, start_idx:end_idx], noise_scale
            )
        denoised_pred = self.prepare_pipeline(
            text_prompts=prompts,
            noise=noisy_latents,
            current_start=current_start,
            current_end=current_end,
        )
        first_segment = self._decode_video_array(denoised_pred, last_frame_only=False)
        results.append(first_segment)
        self.logger.info(
            "Prepared initial chunk: start=%s, end=%s, start_idx=%s, save_results=%s, frames=%s",
            current_start,
            current_end,
            start_idx,
            0,
            first_segment.shape[0],
        )

        init_noise_scale = noise_scale

        # Process remaining chunks
        while self.processed < num_chunks + num_steps - 1:
            start_idx = end_idx
            end_idx += chunk_size
            current_start = current_end
            current_end += (chunk_size // base) * frame_seq_length

            if not self.t2v and end_idx <= input_video_original.shape[2]:
                noise_scale, current_step = compute_noise_scale_and_step(
                    input_video_original, end_idx, chunk_size, noise_scale, init_noise_scale
                )
                noisy_latents = self._encode_noisy_latents(
                    input_video_original[:, :, start_idx:end_idx], noise_scale
                )
            else:
                # T2V, or input exhausted: feed pure noise to flush the pipeline.
                noisy_latents = self._random_latents()
                current_step = None  # Use default steps

            self._sync_for_timing()
            dit_start_time = time.time()
            # DiT inference - using input mode to process all 30 blocks
            denoised_pred = self.pipeline.inference_stream(
                noise=noisy_latents,
                current_start=current_start,
                current_end=current_end,
                current_step=current_step,
            )
            if self.processed > self.processed_offset:
                self._sync_for_timing()
                if self.profile:
                    dit_fps_list.append(chunk_size / (time.time() - dit_start_time))
            self.processed += 1

            # VAE decoding - only start decoding after num_steps
            if self.processed >= num_steps:
                if self.t2v and self.processed == num_steps:
                    # T2V: the first stream output is discarded.
                    continue
                results.append(self._decode_video_array(denoised_pred, last_frame_only=True))

            # Update timing (throughput is only measured when profiling)
            fps_test = None
            if self.profile:
                self._sync_for_timing()
                end_time = time.time()
                t = end_time - start_time
                fps_test = chunk_size / t
                fps_list.append(fps_test)
                self.logger.info(f"Processed {self.processed}, time: {t:.4f} s, FPS: {fps_test:.4f}")

            # One-shot adaptation: if throughput after warm-up is below target,
            # grow the chunk to the largest size the KV cache allows.
            if (
                self.processed == num_steps + self.processed_offset
                and target_fps is not None
                and fps_test is not None
                and fps_test < target_fps
            ):
                max_chunk_size = (self.pipeline.num_kv_cache - self.pipeline.num_sink_tokens - 1) * base
                repeat = max_chunk_size // chunk_size
                if repeat < 1:
                    # Would divide by zero below; the cache cannot hold a larger chunk.
                    self.logger.warning(
                        "Cannot adjust chunk size: max chunk size %d is smaller than current %d",
                        max_chunk_size,
                        chunk_size,
                    )
                else:
                    num_chunks = (num_chunks - self.processed - num_steps + 1) // repeat + self.processed - num_steps + 1
                    self.pipeline.hidden_states = self.pipeline.hidden_states.repeat(1, repeat, 1, 1, 1)
                    chunk_size = max_chunk_size
                    self.logger.info(f"Adjust chunk size to {chunk_size}")

            if self.profile:
                start_time = end_time

        # Save final video
        if len(results) < num_chunks:
            self.logger.warning(
                "Only %d decoded segments available, expected %d; writing what is available",
                len(results),
                num_chunks,
            )
        video = np.concatenate(results[:num_chunks], axis=0)

        if self.profile and fps_list:
            fps_avg = np.mean(np.array(fps_list))
            dit_avg = np.mean(np.array(dit_fps_list)) if dit_fps_list else 0.0
            encode_avg = np.mean(np.array(self.encode_fps_list)) if self.encode_fps_list else 0.0
            decode_avg = np.mean(np.array(self.decode_fps_list)) if self.decode_fps_list else 0.0
            self.logger.info(f"VAE Encode Average FPS: {encode_avg:.4f}")
            self.logger.info(f"DiT Average FPS: {dit_avg:.4f}")
            self.logger.info(f"VAE Decode Average FPS: {decode_avg:.4f}")
            self.logger.info(f"Video shape: {video.shape}, Average FPS: {fps_avg:.4f}")
        else:
            self.logger.info(f"Video shape: {video.shape}")

        output_path = os.path.join(output_folder, f"output_{0:03d}.mp4")
        export_to_video(video, output_path, fps=fps)
        self.logger.info(f"Video saved to: {output_path}")
        self.logger.info("Single GPU inference pipeline completed")


def main():
    """Main function for the single GPU inference pipeline."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True, help="Configuration file path")
    parser.add_argument("--checkpoint_folder", type=str, required=True, help="Checkpoint folder path")
    parser.add_argument("--output_folder", type=str, required=True, help="Output folder path")
    parser.add_argument("--prompt_file_path", type=str, required=True, help="Prompt file path")
    parser.add_argument("--video_path", type=str, required=False, default=None, help="Input video path")
    parser.add_argument("--noise_scale", type=float, default=0.8, help="Noise scale")
    parser.add_argument("--height", type=int, default=480, help="Video height")
    parser.add_argument("--width", type=int, default=832, help="Video width")
    parser.add_argument("--fps", type=int, default=16, help="Output video fps")
    parser.add_argument("--step", type=int, default=2, help="Step")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--gpu_id", type=int, default=None, help="CUDA device index for single-GPU inference")
    parser.add_argument("--model_type", type=str, default="T2V-1.3B", help="Model type (e.g., T2V-1.3B)")
    parser.add_argument("--num_frames", type=int, default=81, help="Video length (number of frames)")
    parser.add_argument("--fixed_noise_scale", action="store_true", default=False)
    parser.add_argument("--t2v", action="store_true", default=False)
    parser.add_argument(
        "--target_fps",
        type=int,
        required=False,
        default=None,
        help="Target throughput in FPS; if profiled throughput is lower, the chunk size is increased",
    )
    parser.add_argument("--profile", action="store_true", default=False, help="Enable synchronized throughput logging")
    parser.add_argument("--use_taehv", action="store_true", default=False, help="Use the lightweight TAEHV VAE for encode/decode")
    parser.add_argument(
        "--use_tensorrt",
        "--use_taehv_tensorrt",
        dest="use_tensorrt",
        action="store_true",
        default=False,
        help="Enable available TensorRT acceleration paths",
    )
    parser.add_argument("--fast", action="store_true", default=False, help="Enable the fast path: --use_taehv --use_tensorrt")
    args = parser.parse_args()
    # NOTE(review): --fixed_noise_scale, --use_taehv, --use_tensorrt and --fast are
    # not read directly in this file; presumably merge_cli_config folds them into
    # the config -- confirm.

    # V2V mode needs an input video; fail with a clear CLI error instead of
    # passing None to the video loader.
    if not args.t2v and args.video_path is None:
        parser.error("--video_path is required unless --t2v is set")

    torch.set_grad_enabled(False)

    # Auto-detect device
    if torch.cuda.is_available():
        if args.gpu_id is not None:
            torch.cuda.set_device(args.gpu_id)
            device = torch.device(f"cuda:{args.gpu_id}")
        else:
            device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Load configuration
    config = merge_cli_config(args.config_path, args)
    set_seed(args.seed)

    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    LOGGER.info("Denoising Step List: %s", list(config.denoising_step_list))

    # Load input video
    if not args.t2v:
        input_video_original = load_mp4_as_tensor(args.video_path, resize_hw=(args.height, args.width)).unsqueeze(0)
        LOGGER.info("Input video tensor shape: %s", tuple(input_video_original.shape))
        t = input_video_original.shape[2]
        # Always move to the target device; previously a tensor that was already
        # bfloat16 was left on the CPU.
        input_video_original = input_video_original.to(device=device, dtype=torch.bfloat16)
    else:
        input_video_original = None
        t = args.num_frames

    # Calculate number of chunks
    chunk_size = 4 * config.num_frame_per_block
    num_chunks = (t - 1) // chunk_size
    if args.t2v:
        num_chunks += 1

    # Initialize pipeline manager
    pipeline_manager = SingleGPUInferencePipeline(config, device)
    pipeline_manager.load_model(args.checkpoint_folder)

    # Load prompts
    dataset = TextDataset(args.prompt_file_path)
    prompts = [dataset[0]]
    num_steps = len(pipeline_manager.pipeline.denoising_step_list)

    # Run inference
    try:
        pipeline_manager.run_inference(
            input_video_original,
            prompts,
            num_chunks,
            chunk_size,
            args.noise_scale,
            args.output_folder,
            args.fps,
            args.target_fps,
            num_steps,
        )
    except Exception as e:
        LOGGER.exception("Error occurred during inference: %s", e)
        raise


if __name__ == "__main__":
    main()