Spaces:
Running on Zero
Running on Zero
| """ | |
| Refactored multi-rank inference pipeline with communication abstractions. | |
| This is a refactored version of inference_pipe_multi.py that uses the new | |
| communication abstraction layers for better code organization and maintainability. | |
| """ | |
| from models.wan.causal_stream_inference import CausalStreamInferencePipeline | |
| from models.util import set_seed | |
| from diffusers.utils import export_to_video | |
| from models.data import TextDataset | |
| import argparse | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.distributed as dist | |
| import os | |
| import time | |
| import numpy as np | |
| import logging | |
| try: | |
| from streamv2v.inference import compute_noise_scale_and_step | |
| from streamv2v.communication import ( | |
| DistributedCommunicator, | |
| ModelDataTransfer, | |
| BufferManager, | |
| KVCacheManager, | |
| CommunicationConfig, | |
| init_distributed, | |
| setup_logging, | |
| compute_balanced_split | |
| ) | |
| from streamv2v.inference_common import ( | |
| load_generator_state_dict, | |
| load_mp4_as_tensor, | |
| merge_cli_config, | |
| ) | |
| except ModuleNotFoundError: | |
| from inference import compute_noise_scale_and_step | |
| from communication import ( | |
| DistributedCommunicator, | |
| ModelDataTransfer, | |
| BufferManager, | |
| KVCacheManager, | |
| CommunicationConfig, | |
| init_distributed, | |
| setup_logging, | |
| compute_balanced_split | |
| ) | |
| from inference_common import ( | |
| load_generator_state_dict, | |
| load_mp4_as_tensor, | |
| merge_cli_config, | |
| ) | |
| LOGGER = logging.getLogger(__name__) | |
| def compute_default_block_distribution(total_blocks: int, world_size: int) -> list[list[int]]: | |
| """Split transformer blocks into contiguous ranges for each rank.""" | |
| if world_size == 2: | |
| midpoint = total_blocks // 2 | |
| return [[0, midpoint], [midpoint, total_blocks]] | |
| base = total_blocks // world_size | |
| rem = total_blocks % world_size | |
| start = 0 | |
| block_ranges = [] | |
| for rank in range(world_size): | |
| size = base + (1 if rank < rem else 0) | |
| end = start + size if rank < world_size - 1 else total_blocks | |
| block_ranges.append([start, end]) | |
| start = end | |
| return block_ranges | |
| class MultiGPUDemoInputSession: | |
| prompt: str | |
| noise_scale: float | |
| init_noise_scale: float | |
| chunk_size: int | |
| current_start: int | |
| current_end: int | |
| last_image: torch.Tensor | |
| chunk_idx: int = 0 | |
| input_batch: int = 0 | |
| current_step: int = 0 | |
| noisy_latents: torch.Tensor | None = None | |
| class InferencePipelineManager: | |
| """ | |
| Manages the inference pipeline with communication abstractions. | |
| This class encapsulates the main inference logic and uses the communication | |
| abstractions for distributed operations. | |
| """ | |
| def __init__(self, config, device: torch.device, rank: int, world_size: int): | |
| """ | |
| Initialize the inference pipeline manager. | |
| Args: | |
| config: Configuration object | |
| device: GPU device | |
| rank: Current rank | |
| world_size: Total number of ranks | |
| """ | |
| self.config = config | |
| self.device = device | |
| self.rank = rank | |
| self.world_size = world_size | |
| self.com_stream = torch.cuda.Stream() | |
| self.control_stream = torch.cuda.Stream() | |
| # Setup logging | |
| self.logger = setup_logging(rank) | |
| # Initialize communication components | |
| comm_config = CommunicationConfig( | |
| max_outstanding=config.get('max_outstanding', 1), | |
| buffer_pool_size=config.get('buffer_pool_size', 10), | |
| enable_buffer_reuse=config.get('enable_buffer_reuse', True) | |
| ) | |
| self.communicator = DistributedCommunicator(rank, world_size, device, comm_config) | |
| self.buffer_manager = BufferManager(device, comm_config) | |
| # Initialize pipeline | |
| self.pipeline = CausalStreamInferencePipeline(config, device=str(device)) | |
| self.pipeline.to(device=str(device), dtype=torch.bfloat16) | |
| # Initialize KV cache manager | |
| self.kv_cache_manager = KVCacheManager(self.pipeline, device) | |
| # Initialize model data transfer | |
| self.data_transfer = ModelDataTransfer( | |
| self.communicator, | |
| self.buffer_manager, | |
| self.kv_cache_manager, | |
| comm_config | |
| ) | |
| # Performance tracking | |
| self.t_dit = 100.0 | |
| self.t_total = 100.0 | |
| self.processed = 0 | |
| self.schedule_step = (self.world_size + len(config.denoising_step_list)) * 2 | |
| self.processed_offset = 3 | |
| self.base_chunk_size = 4 | |
| self.t_refresh = 50 | |
| self.profile = bool(config.get('profile', False)) | |
| self.encode_fps_list: list[float] = [] | |
| self.decode_fps_list: list[float] = [] | |
| self.logger.info(f"Initialized InferencePipelineManager for rank {rank}") | |
| def load_model(self, checkpoint_folder: str): | |
| """Load the model from checkpoint.""" | |
| ckpt_path, state_dict = load_generator_state_dict(checkpoint_folder) | |
| try: | |
| self.pipeline.generator.load_state_dict(state_dict, strict=True) | |
| except RuntimeError as exc: | |
| self.logger.warning(f"Strict load_state_dict failed: {exc}; retrying with strict=False") | |
| self.pipeline.generator.load_state_dict(state_dict, strict=False) | |
| self.logger.info(f"Model loaded successfully from {ckpt_path}") | |
| def prepare_pipeline(self, text_prompts: list, noise: torch.Tensor, | |
| block_mode: str, current_start: int, current_end: int, block_num: torch.Tensor): | |
| """Prepare the pipeline for inference.""" | |
| denoised_pred = self.pipeline.prepare( | |
| text_prompts=text_prompts, | |
| device=self.device, | |
| dtype=torch.bfloat16, | |
| noise=noise, | |
| block_mode=block_mode, | |
| current_start=current_start, | |
| current_end=current_end, | |
| block_num=block_num | |
| ) | |
| # Broadcast the prepared result from rank 0 | |
| self.data_transfer.broadcast_tensor(denoised_pred, src=0) | |
| return denoised_pred | |
| def _wait_for_outstanding(self, outstanding: list) -> None: | |
| """Keep the number of queued async sends bounded.""" | |
| while len(outstanding) >= self.config.get('max_outstanding', 1): | |
| oldest = outstanding.pop(0) | |
| for work in oldest: | |
| work.wait() | |
| def _drain_outstanding(self, outstanding: list) -> None: | |
| """Wait for all queued async sends to complete.""" | |
| while outstanding: | |
| oldest = outstanding.pop(0) | |
| for work in oldest: | |
| work.wait() | |
| def _maybe_schedule_blocks(self, schedule_block: bool, threshold: int, block_num: torch.Tensor, total_blocks: int) -> bool: | |
| """Run one-time block rebalancing when the warmup threshold is reached.""" | |
| if schedule_block and self.processed >= threshold: | |
| self._handle_block_scheduling(block_num, total_blocks) | |
| return False | |
| return schedule_block | |
| def _receive_latent_data(self, previous_latent_data, num_steps: int): | |
| """Release the previous payload and receive the next one from the upstream rank.""" | |
| with torch.cuda.stream(self.com_stream): | |
| if previous_latent_data is not None: | |
| self.data_transfer.release_latent_data(previous_latent_data) | |
| latent_data = self.data_transfer.receive_latent_data_async(num_steps) | |
| torch.cuda.current_stream().wait_stream(self.com_stream) | |
| return latent_data | |
| def _run_worker_stage(self, role: str, latent_data, block_num: torch.Tensor): | |
| """Execute the local DiT blocks for a middle or output rank.""" | |
| return self.pipeline.inference( | |
| noise=latent_data.original_latents, | |
| current_start=latent_data.current_start, | |
| current_end=latent_data.current_end, | |
| current_step=latent_data.current_step, | |
| block_mode=role, | |
| block_num=block_num, | |
| patched_x_shape=latent_data.patched_x_shape, | |
| block_x=latent_data.latents, | |
| ) | |
| def _send_worker_result(self, role: str, outstanding: list, latent_data, denoised_pred: torch.Tensor) -> None: | |
| """Forward the payload that should continue around the pipeline ring.""" | |
| if role == 'output': | |
| latents = latent_data.latents | |
| original_latents = denoised_pred | |
| else: | |
| latents = denoised_pred | |
| original_latents = latent_data.original_latents | |
| with torch.cuda.stream(self.com_stream): | |
| work_objects = self.data_transfer.send_latent_data_async( | |
| chunk_idx=latent_data.chunk_idx, | |
| latents=latents, | |
| original_latents=original_latents, | |
| patched_x_shape=latent_data.patched_x_shape, | |
| current_start=latent_data.current_start, | |
| current_end=latent_data.current_end, | |
| current_step=latent_data.current_step | |
| ) | |
| outstanding.append(work_objects) | |
| def _decode_prediction(self, denoised_pred: torch.Tensor) -> np.ndarray: | |
| """Decode the newest latent prediction into pixel-space frames.""" | |
| video = self._timed_stream_decode(denoised_pred[[-1]]) | |
| video = (video * 0.5 + 0.5).clamp(0, 1) | |
| video = video[0].permute(0, 2, 3, 1).contiguous() | |
| return video.cpu().float().numpy() | |
| def _rank_loop_complete(self, num_chunks: int, num_steps: int) -> bool: | |
| """Return whether a non-output rank has processed all required chunks.""" | |
| return ( | |
| self.processed + self.processed_offset | |
| >= num_chunks + num_steps * self.world_size + self.world_size - self.rank - 1 | |
| ) | |
| def _safe_mean(self, values: list) -> float: | |
| if not values: | |
| return 0.0 | |
| return float(np.mean(np.array(values))) | |
| def _record_stage_fps(self, values: list[float], num_frames: int, elapsed: float) -> None: | |
| if self.profile and elapsed > 0 and num_frames > 0: | |
| values.append(num_frames / elapsed) | |
| def _timing_enabled(self, schedule_block: bool = False) -> bool: | |
| """Only force GPU synchronization when profiling or schedule calibration needs it.""" | |
| return self.profile or schedule_block | |
| def _sync_for_timing(self, schedule_block: bool = False) -> None: | |
| if self._timing_enabled(schedule_block): | |
| torch.cuda.synchronize() | |
| def _timed_stream_encode(self, images: torch.Tensor) -> torch.Tensor: | |
| self._sync_for_timing() | |
| start_time = time.time() | |
| latents = self.pipeline.vae.stream_encode(images) | |
| self._sync_for_timing() | |
| self._record_stage_fps(self.encode_fps_list, int(images.shape[2]), time.time() - start_time) | |
| return latents | |
| def _timed_stream_decode(self, denoised_pred: torch.Tensor) -> torch.Tensor: | |
| self._sync_for_timing() | |
| start_time = time.time() | |
| video = self.pipeline.vae.stream_decode_to_pixel(denoised_pred) | |
| self._sync_for_timing() | |
| self._record_stage_fps(self.decode_fps_list, int(video.shape[1]), time.time() - start_time) | |
| return video | |
| def reset_stream_state(self, reset_encode: bool = False, reset_decode: bool = False) -> None: | |
| """Reset cached inference state before starting a new prompt/session.""" | |
| self.pipeline.kv_cache1 = None | |
| self.pipeline.crossattn_cache = None | |
| self.pipeline.block_x = None | |
| self.pipeline.hidden_states = None | |
| self.processed = 0 | |
| if reset_encode: | |
| self.pipeline.vae.model.first_encode = True | |
| if reset_decode: | |
| self.pipeline.vae.model.first_decode = True | |
| def _broadcast_initial_noise(self, noisy_latents: torch.Tensor) -> None: | |
| latents_shape = torch.tensor(noisy_latents.shape, dtype=torch.int64, device=self.device) | |
| self.communicator.broadcast_tensor(latents_shape, src=0) | |
| self.communicator.broadcast_tensor(noisy_latents, src=0) | |
| def _receive_initial_noise(self) -> torch.Tensor: | |
| latents_shape = torch.zeros(5, dtype=torch.int64, device=self.device) | |
| self.communicator.broadcast_tensor(latents_shape, src=0) | |
| noisy_latents = torch.zeros(tuple(latents_shape.tolist()), dtype=torch.bfloat16, device=self.device) | |
| self.communicator.broadcast_tensor(noisy_latents, src=0) | |
| return noisy_latents | |
| def get_demo_chunk_size(self) -> int: | |
| """Return the demo stream chunk size in frames.""" | |
| return self.base_chunk_size * self.pipeline.num_frame_per_block | |
| def get_demo_first_batch_num_frames(self) -> int: | |
| """Return the number of frames required to initialize a demo stream.""" | |
| return 1 + self.get_demo_chunk_size() | |
| def prepare_demo_input_session(self, images: torch.Tensor, prompt: str, block_num: torch.Tensor, noise_scale: float) -> None: | |
| """Initialize rank 0 for demo streaming and broadcast the first noisy latents.""" | |
| self.reset_stream_state(reset_encode=True) | |
| torch.cuda.empty_cache() | |
| latents = self._timed_stream_encode(images) | |
| latents = latents.transpose(2, 1).contiguous().to(dtype=torch.bfloat16) | |
| noise = torch.randn_like(latents) | |
| noisy_latents = noise * noise_scale + latents * (1 - noise_scale) | |
| self._broadcast_initial_noise(noisy_latents) | |
| self.prepare_pipeline( | |
| text_prompts=[prompt], | |
| noise=noisy_latents, | |
| block_mode='input', | |
| current_start=0, | |
| current_end=self.pipeline.frame_seq_length * 2, | |
| block_num=block_num, | |
| ) | |
| torch.cuda.empty_cache() | |
| dist.barrier() | |
| def start_demo_input_stream_session( | |
| self, | |
| prompt: str, | |
| images: torch.Tensor, | |
| block_num: torch.Tensor, | |
| noise_scale: float, | |
| ) -> MultiGPUDemoInputSession: | |
| """Initialize rank 0 and return the demo stream session state.""" | |
| chunk_size = self.get_demo_chunk_size() | |
| self.prepare_demo_input_session(images, prompt, block_num, noise_scale) | |
| current_start = self.pipeline.frame_seq_length * (1 + chunk_size // self.base_chunk_size) | |
| current_end = current_start + (chunk_size // self.base_chunk_size) * self.pipeline.frame_seq_length | |
| return MultiGPUDemoInputSession( | |
| prompt=prompt, | |
| noise_scale=noise_scale, | |
| init_noise_scale=noise_scale, | |
| chunk_size=chunk_size, | |
| current_start=current_start, | |
| current_end=current_end, | |
| last_image=images[:, :, [-1]], | |
| ) | |
| def prepare_demo_worker_session(self, prompt: str, block_mode: str, block_num: torch.Tensor, decode_initial: bool = False): | |
| """Initialize a non-input rank for demo streaming from the broadcast first chunk.""" | |
| self.reset_stream_state(reset_decode=(block_mode == 'output')) | |
| torch.cuda.empty_cache() | |
| noisy_latents = self._receive_initial_noise() | |
| denoised_pred = self.prepare_pipeline( | |
| text_prompts=[prompt], | |
| noise=noisy_latents, | |
| block_mode=block_mode, | |
| current_start=0, | |
| current_end=self.pipeline.frame_seq_length * 2, | |
| block_num=block_num, | |
| ) | |
| torch.cuda.empty_cache() | |
| dist.barrier() | |
| if decode_initial: | |
| return self._decode_prediction(denoised_pred) | |
| return None | |
| def maybe_refresh_demo_input_window(self, session: MultiGPUDemoInputSession) -> None: | |
| """Wrap the KV-cache window once the streaming refresh threshold is reached.""" | |
| if session.current_start // self.pipeline.frame_seq_length >= self.t_refresh: | |
| session.current_start = self.pipeline.kv_cache_length - self.pipeline.frame_seq_length | |
| session.current_end = session.current_start + (session.chunk_size // self.base_chunk_size) * self.pipeline.frame_seq_length | |
| def prepare_demo_input_batch(self, session: MultiGPUDemoInputSession, images: torch.Tensor) -> None: | |
| """Encode one demo chunk and update the session with the current denoising step.""" | |
| num_frames = images.shape[2] | |
| session.input_batch = num_frames // session.chunk_size | |
| session.noise_scale, session.current_step = compute_noise_scale_and_step( | |
| input_video_original=torch.cat([session.last_image, images], dim=2), | |
| end_idx=num_frames + 1, | |
| chunk_size=num_frames, | |
| noise_scale=float(session.noise_scale), | |
| init_noise_scale=float(session.init_noise_scale), | |
| ) | |
| latents = self._timed_stream_encode(images) | |
| latents = latents.transpose(2, 1).contiguous().to(dtype=torch.bfloat16) | |
| noise = torch.randn_like(latents) | |
| session.noisy_latents = noise * session.noise_scale + latents * (1 - session.noise_scale) | |
| def run_demo_input_step( | |
| self, | |
| session: MultiGPUDemoInputSession, | |
| block_num: torch.Tensor, | |
| previous_latent_data=None, | |
| ): | |
| """Run one rank-0 demo step from the current session batch.""" | |
| if session.noisy_latents is None or session.input_batch <= 0: | |
| raise RuntimeError("demo input batch was not prepared before run_demo_input_step") | |
| denoised_pred, patched_x_shape = self.run_input_stage( | |
| noisy_latents=session.noisy_latents[:, -session.input_batch].unsqueeze(1), | |
| current_start=session.current_start, | |
| current_end=session.current_end, | |
| current_step=session.current_step, | |
| block_num=block_num, | |
| previous_latent_data=previous_latent_data, | |
| ) | |
| session.input_batch -= 1 | |
| return denoised_pred, patched_x_shape | |
| def advance_demo_input_stream_session(self, session: MultiGPUDemoInputSession, images: torch.Tensor) -> None: | |
| """Advance the demo stream session after a chunk has been queued downstream.""" | |
| session.last_image = images[:, :, [-1]] | |
| session.chunk_idx += 1 | |
| session.current_start = session.current_end | |
| session.current_end += (session.chunk_size // self.base_chunk_size) * self.pipeline.frame_seq_length | |
| def send_demo_input_prompt_update( | |
| self, | |
| prompt: str, | |
| device: torch.device, | |
| num_steps: int, | |
| chunk_idx: int, | |
| denoised_pred: torch.Tensor, | |
| patched_x_shape: torch.Tensor, | |
| current_step: int, | |
| ) -> None: | |
| """Signal a prompt restart from rank 0 and drain in-flight returns from downstream ranks.""" | |
| with torch.cuda.stream(self.com_stream): | |
| self.data_transfer.send_latent_data_async( | |
| chunk_idx=-1, | |
| latents=denoised_pred.new_zeros([1] * denoised_pred.ndim), | |
| original_latents=self.pipeline.hidden_states.new_zeros([1] * self.pipeline.hidden_states.ndim), | |
| patched_x_shape=patched_x_shape, | |
| current_start=self.pipeline.kv_cache_starts, | |
| current_end=self.pipeline.kv_cache_ends, | |
| current_step=int(current_step), | |
| ) | |
| self.data_transfer.send_prompt_async(prompt, device) | |
| for _ in range(min(chunk_idx, self.world_size - 1)): | |
| pending_data = self.data_transfer.receive_latent_data_async(num_steps) | |
| self.data_transfer.release_latent_data(pending_data) | |
| def send_demo_middle_prompt_update( | |
| self, | |
| prompt: str, | |
| device: torch.device, | |
| denoised_pred: torch.Tensor | None, | |
| latent_data, | |
| ) -> None: | |
| """Forward a prompt restart from a middle rank to the next rank.""" | |
| sentinel_source = denoised_pred if denoised_pred is not None else latent_data.latents | |
| with torch.cuda.stream(self.com_stream): | |
| self.data_transfer.send_latent_data_async( | |
| chunk_idx=-1, | |
| latents=sentinel_source.new_zeros([1] * sentinel_source.ndim), | |
| original_latents=latent_data.original_latents, | |
| patched_x_shape=latent_data.patched_x_shape, | |
| current_start=latent_data.current_start, | |
| current_end=latent_data.current_end, | |
| current_step=int(latent_data.current_step), | |
| ) | |
| self.data_transfer.send_prompt_async(prompt, device) | |
| def run_input_stage(self, noisy_latents: torch.Tensor, current_start: int, current_end: int, current_step: int, block_num: torch.Tensor, previous_latent_data=None): | |
| """Run the rank-0 stage for one streaming chunk.""" | |
| if previous_latent_data is not None and self.processed >= self.world_size: | |
| self.pipeline.hidden_states.copy_(previous_latent_data.original_latents) | |
| self.pipeline.kv_cache_starts.copy_(previous_latent_data.current_start) | |
| self.pipeline.kv_cache_ends.copy_(previous_latent_data.current_end) | |
| return self.pipeline.inference( | |
| noise=noisy_latents, | |
| current_start=current_start, | |
| current_end=current_end, | |
| current_step=current_step, | |
| block_mode='input', | |
| block_num=block_num, | |
| ) | |
| def run_rank_0_loop(self, input_video_original: torch.Tensor, prompts: list, | |
| num_chunks: int, num_steps: int, chunk_size: int, | |
| block_num: torch.Tensor, noise_scale: float, | |
| schedule_block: bool, total_blocks: int): | |
| """ | |
| Run the main loop for rank 0 (encoder + async send). | |
| This method encapsulates the rank 0 logic using the communication abstractions. | |
| """ | |
| self.logger.info("Starting rank 0 inference loop") | |
| # Initialize variables | |
| start_idx = 0 | |
| end_idx = 1 + chunk_size | |
| current_start = 0 | |
| current_end = self.pipeline.frame_seq_length * (1+chunk_size//self.base_chunk_size) | |
| init_noise_scale = noise_scale | |
| outstanding = [] | |
| latent_data = None | |
| self._sync_for_timing(schedule_block) | |
| start_time = time.time() | |
| while True: | |
| # Process new chunk if available | |
| start_idx = end_idx | |
| end_idx = end_idx + chunk_size | |
| current_start = current_end | |
| current_end = current_end + (chunk_size // self.base_chunk_size) * self.pipeline.frame_seq_length | |
| if schedule_block: | |
| self._sync_for_timing(schedule_block) | |
| start_vae = time.time() | |
| if end_idx <= input_video_original.shape[2]: | |
| inp = input_video_original[:, :, start_idx:end_idx] | |
| noise_scale, current_step = compute_noise_scale_and_step( | |
| input_video_original, end_idx, chunk_size, noise_scale, init_noise_scale | |
| ) | |
| latents = self._timed_stream_encode(inp) | |
| latents = latents.transpose(2, 1).contiguous().to(dtype=torch.bfloat16) | |
| noise = torch.randn_like(latents) | |
| noisy_latents = noise * noise_scale + latents * (1 - noise_scale) | |
| # if current_start//self.pipeline.frame_seq_length >= self.t_refresh: | |
| # current_start = self.pipeline.kv_cache_length - self.pipeline.frame_seq_length | |
| # current_end = current_start + (chunk_size // self.base_chunk_size) * self.pipeline.frame_seq_length | |
| # Measure DiT time if scheduling is enabled | |
| if schedule_block: | |
| self._sync_for_timing(schedule_block) | |
| start_dit = time.time() | |
| t_vae = start_dit - start_vae | |
| # Run inference | |
| denoised_pred, patched_x_shape = self.pipeline.inference( | |
| noise=noisy_latents, | |
| current_start=current_start, | |
| current_end=current_end, | |
| current_step=current_step, | |
| block_mode='input', | |
| block_num=block_num[self.rank], | |
| ) | |
| # Update DiT timing | |
| if schedule_block: | |
| self._sync_for_timing(schedule_block) | |
| temp = time.time() - start_dit | |
| if temp < self.t_dit: | |
| self.t_dit = temp | |
| self.processed += 1 | |
| with torch.cuda.stream(self.com_stream): | |
| if self.processed >= self.world_size: | |
| if latent_data is not None: | |
| self.data_transfer.release_latent_data(latent_data) | |
| # Receive data from previous rank | |
| latent_data = self.data_transfer.receive_latent_data_async(num_steps) | |
| torch.cuda.current_stream().wait_stream(self.com_stream) | |
| # Wait for outstanding operations | |
| self._wait_for_outstanding(outstanding) | |
| # Send data to next rank | |
| with torch.cuda.stream(self.com_stream): | |
| work_objects = self.data_transfer.send_latent_data_async( | |
| chunk_idx=start_idx, | |
| latents=denoised_pred, | |
| original_latents=self.pipeline.hidden_states, | |
| patched_x_shape=patched_x_shape, | |
| current_start=self.pipeline.kv_cache_starts, | |
| current_end=self.pipeline.kv_cache_ends, | |
| current_step=current_step | |
| ) | |
| outstanding.append(work_objects) | |
| # Handle block scheduling | |
| if schedule_block and self.processed >= self.schedule_step: | |
| self._handle_block_scheduling(block_num, total_blocks) | |
| schedule_block = False | |
| # Update timing and check completion | |
| if self._timing_enabled(schedule_block): | |
| self._sync_for_timing(schedule_block) | |
| end_time = time.time() | |
| t = end_time - start_time | |
| self.logger.info(f"Encode {self.processed}, time: {t:.4f} s, fps: {inp.shape[2]/t:.4f}") | |
| if schedule_block: | |
| t_total = self.t_dit + t_vae | |
| if t_total < self.t_total: | |
| self.t_total = t_total | |
| start_time = end_time | |
| if self.processed >= self.world_size: | |
| self.pipeline.hidden_states.copy_(latent_data.original_latents) | |
| self.pipeline.kv_cache_starts.copy_(latent_data.current_start) | |
| self.pipeline.kv_cache_ends.copy_(latent_data.current_end) | |
| if self.processed + self.processed_offset >= num_chunks + num_steps * self.world_size + self.world_size - self.rank - 1: | |
| break | |
| if latent_data is not None: | |
| self.data_transfer.release_latent_data(latent_data) | |
| self._drain_outstanding(outstanding) | |
| self.logger.info(f"VAE Encode Average FPS: {self._safe_mean(self.encode_fps_list):.4f}") | |
| self.logger.info("Rank 0 inference loop completed") | |
| def run_final_rank_loop(self, num_chunks: int, num_steps: int, chunk_size: int, | |
| block_num: torch.Tensor, output_folder: str, fps: int, | |
| schedule_block: bool, total_blocks: int, results: dict): | |
| """Run the worker loop for the output rank.""" | |
| self.run_worker_rank_loop( | |
| role='output', | |
| num_chunks=num_chunks, | |
| num_steps=num_steps, | |
| chunk_size=chunk_size, | |
| block_num=block_num, | |
| schedule_block=schedule_block, | |
| total_blocks=total_blocks, | |
| output_folder=output_folder, | |
| fps=fps, | |
| results=results, | |
| ) | |
| def run_middle_rank_loop(self, num_chunks: int, num_steps: int, chunk_size: int, | |
| block_num: torch.Tensor, schedule_block: bool, total_blocks: int): | |
| """Run the worker loop for a middle rank.""" | |
| self.run_worker_rank_loop( | |
| role='middle', | |
| num_chunks=num_chunks, | |
| num_steps=num_steps, | |
| chunk_size=chunk_size, | |
| block_num=block_num, | |
| schedule_block=schedule_block, | |
| total_blocks=total_blocks, | |
| ) | |
| def run_worker_rank_loop( | |
| self, | |
| role: str, | |
| num_chunks: int, | |
| num_steps: int, | |
| chunk_size: int, | |
| block_num: torch.Tensor, | |
| schedule_block: bool, | |
| total_blocks: int, | |
| output_folder: str = None, | |
| fps: int = None, | |
| results: dict = None, | |
| ): | |
| """Run the shared receive -> infer -> forward loop for middle and output ranks.""" | |
| if role not in {'middle', 'output'}: | |
| raise ValueError(f"Unsupported worker role: {role}") | |
| self.logger.info(f"Starting {role} rank inference loop") | |
| if role == 'output': | |
| if output_folder is None or fps is None or results is None: | |
| raise ValueError("output rank requires output_folder, fps, and results") | |
| os.makedirs(output_folder, exist_ok=True) | |
| save_results = 1 | |
| outstanding = [] | |
| fps_list = [] | |
| latent_data = None | |
| self._sync_for_timing(schedule_block) | |
| start_time = time.time() | |
| while True: | |
| latent_data = self._receive_latent_data(latent_data, num_steps) | |
| schedule_block = self._maybe_schedule_blocks( | |
| schedule_block, | |
| self.schedule_step - self.rank, | |
| block_num, | |
| total_blocks, | |
| ) | |
| if schedule_block: | |
| self._sync_for_timing(schedule_block) | |
| start_dit = time.time() | |
| denoised_pred, _ = self._run_worker_stage(role, latent_data, block_num[self.rank]) | |
| if schedule_block: | |
| self._sync_for_timing(schedule_block) | |
| temp = time.time() - start_dit | |
| if temp < self.t_dit: | |
| self.t_dit = temp | |
| self.processed += 1 | |
| self._wait_for_outstanding(outstanding) | |
| self._send_worker_result(role, outstanding, latent_data, denoised_pred) | |
| if role == 'output': | |
| if self.processed >= num_steps * self.world_size - 1: | |
| if schedule_block: | |
| self._sync_for_timing(schedule_block) | |
| start_vae = time.time() | |
| video = self._timed_stream_decode(denoised_pred[[-1]]) | |
| video = (video * 0.5 + 0.5).clamp(0, 1) | |
| video = video[0].permute(0, 2, 3, 1).contiguous() | |
| results[save_results] = video.cpu().float().numpy() | |
| if self._timing_enabled(schedule_block): | |
| self._sync_for_timing(schedule_block) | |
| end_time = time.time() | |
| elapsed = end_time - start_time | |
| fps_test = video.shape[0] / elapsed | |
| if self.processed > self.schedule_step: | |
| fps_list.append(fps_test) | |
| self.logger.info(f"Decode {self.processed}, time: {elapsed:.4f} s, FPS: {fps_test:.4f}") | |
| if schedule_block: | |
| t_vae = end_time - start_vae | |
| t_total = t_vae + self.t_dit | |
| if t_total < self.t_total: | |
| self.t_total = t_total | |
| start_time = end_time | |
| save_results += 1 | |
| if save_results >= num_chunks: | |
| break | |
| else: | |
| if self._timing_enabled(schedule_block): | |
| self._sync_for_timing(schedule_block) | |
| end_time = time.time() | |
| elapsed = end_time - start_time | |
| fps_test = chunk_size / elapsed | |
| if self.processed > self.schedule_step: | |
| fps_list.append(fps_test) | |
| if schedule_block: | |
| t_total = self.t_dit | |
| if t_total < self.t_total: | |
| self.t_total = t_total | |
| self.logger.info(f"Middle {self.processed}, time: {elapsed:.4f} s, fps: {fps_test:.4f}") | |
| start_time = end_time | |
| if self._rank_loop_complete(num_chunks, num_steps): | |
| break | |
| if latent_data is not None: | |
| self.data_transfer.release_latent_data(latent_data) | |
| self._drain_outstanding(outstanding) | |
| if role == 'output': | |
| video_list = [results[i] for i in range(num_chunks)] | |
| video = np.concatenate(video_list, axis=0) | |
| fps_avg = self._safe_mean(fps_list) | |
| self.logger.info(f"Video shape: {video.shape}, Average FPS: {fps_avg:.4f}") | |
| self.logger.info(f"VAE Decode Average FPS: {self._safe_mean(self.decode_fps_list):.4f}") | |
| output_path = os.path.join(output_folder, f"output_{0:03d}.mp4") | |
| export_to_video(video, output_path, fps=fps) | |
| self.logger.info(f"Video saved to: {output_path} (Press Ctrl+C to force exit)") | |
| return | |
| self.logger.info(f"DiT Average FPS: {self._safe_mean(fps_list):.4f}") | |
| self.logger.info(f"Rank {self.rank} inference loop completed") | |
| def _handle_block_scheduling(self, block_num: torch.Tensor, total_blocks: int): | |
| """Handle block scheduling and rebalancing.""" | |
| self.logger.info(f"Scheduling block in {self.processed}") | |
| # Gather timing information from all ranks | |
| t_total_tensor = torch.tensor(self.t_total, dtype=torch.float32, device=self.device) | |
| t_dit_tensor = torch.tensor(self.t_dit, dtype=torch.float32, device=self.device) | |
| gather_blocks = [torch.zeros_like(t_dit_tensor, dtype=torch.float32, device=self.device) | |
| for _ in range(self.world_size)] | |
| dist.all_gather(gather_blocks, t_dit_tensor) | |
| t_dit_list = [t_dit_i.item() for t_dit_i in gather_blocks] | |
| dist.all_gather(gather_blocks, t_total_tensor) | |
| t_list = [t_i.item() for t_i in gather_blocks] | |
| # Compute new block distribution | |
| new_block_num = torch.tensor( | |
| compute_balanced_split(total_blocks, t_list, t_dit_list, block_num.tolist()), | |
| dtype=torch.int64, device=self.device | |
| ) | |
| self.logger.info(f"New block distribution: {new_block_num[self.rank].tolist()}") | |
| # Broadcast new block distribution | |
| dist.broadcast(new_block_num, src=self.world_size - 1) | |
| # Rebalance KV cache | |
| self.data_transfer.rebalance_kv_cache(block_num, new_block_num, total_blocks) | |
| # Update block_num | |
| block_num.copy_(new_block_num) | |
| start_block, end_block = block_num[self.rank][0].item(), block_num[self.rank][1].item() | |
| blocks_to_keep = list(range(start_block, end_block)) | |
| for i in range(self.pipeline.num_transformer_blocks): | |
| if i not in blocks_to_keep: | |
| self.pipeline.kv_cache1[i]['k'] = self.pipeline.kv_cache1[i]['k'].cpu() | |
| self.pipeline.kv_cache1[i]['v'] = self.pipeline.kv_cache1[i]['v'].cpu() | |
| self.logger.info("Block scheduling completed") | |
| def cleanup(self): | |
| """Clean up resources.""" | |
| self.data_transfer.cleanup() | |
| self.logger.info("InferencePipelineManager cleanup completed") | |
| def main(): | |
| """Main function for the refactored inference pipeline.""" | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--config_path", type=str) | |
| parser.add_argument("--checkpoint_folder", type=str) | |
| parser.add_argument("--output_folder", type=str) | |
| parser.add_argument("--prompt_file_path", type=str) | |
| parser.add_argument("--video_path", type=str) | |
| parser.add_argument("--noise_scale", type=float, default=0.8) | |
| parser.add_argument("--height", type=int, default=480) | |
| parser.add_argument("--width", type=int, default=832) | |
| parser.add_argument("--fps", type=int, default=30) | |
| parser.add_argument("--max_outstanding", type=int, default=1, help="max number of outstanding sends/recv to keep") | |
| parser.add_argument("--dit_fsdp", action="store_true", default=False) | |
| parser.add_argument("--t5_fsdp", action="store_true", default=False) | |
| parser.add_argument("--ulysses_size", type=int, default=1) | |
| parser.add_argument("--ring_size", type=int, default=1) | |
| parser.add_argument("--step", type=int, default=2) | |
| parser.add_argument("--seed", type=int, default=0, help="Random seed") | |
| parser.add_argument("--schedule_block", action="store_true", default=False) | |
| parser.add_argument("--profile", action="store_true", default=False, help="Enable synchronized throughput logging") | |
| parser.add_argument("--t2v", action="store_true", default=False) | |
| parser.add_argument("--model_type", type=str, default="T2V-1.3B", help="Model type (e.g., T2V-1.3B)") | |
| parser.add_argument("--use_taehv", action="store_true", default=False, help="Use the lightweight TAEHV VAE for encode/decode") | |
| parser.add_argument("--use_tensorrt", "--use_taehv_tensorrt", dest="use_tensorrt", action="store_true", default=False, help="Enable available TensorRT acceleration paths") | |
| parser.add_argument("--fast", action="store_true", default=False, help="Enable the fast path: --use_taehv --use_tensorrt") | |
| args = parser.parse_args() | |
| torch.set_grad_enabled(False) | |
| init_distributed() | |
| rank = dist.get_rank() | |
| world_size = dist.get_world_size() | |
| local_rank = int(os.environ.get("LOCAL_RANK", rank)) | |
| assert world_size >= 2, "world_size must be at least 2" | |
| torch.cuda.set_device(local_rank) | |
| device = torch.device(f"cuda:{local_rank}") | |
| # Load configuration | |
| config = merge_cli_config(args.config_path, args) | |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") | |
| LOGGER.info("Denoising Step List: %s", list(config.denoising_step_list)) | |
| set_seed(args.seed) | |
| # Load input video | |
| input_video_original = load_mp4_as_tensor(args.video_path, resize_hw=(args.height, args.width)).unsqueeze(0) | |
| if input_video_original.dtype != torch.bfloat16: | |
| input_video_original = input_video_original.to(dtype=torch.bfloat16).to(device) | |
| LOGGER.info("Input video tensor shape: %s", tuple(input_video_original.shape)) | |
| b, c, t, h, w = input_video_original.shape | |
| # Calculate number of chunks | |
| chunk_size = 4 * config.num_frame_per_block | |
| if rank == 0: | |
| num_chunks = (t - 1) // chunk_size | |
| else: | |
| num_chunks = 0 | |
| num_chunks_tensor = torch.tensor([num_chunks], dtype=torch.int64, device=device) | |
| dist.broadcast(num_chunks_tensor, src=0) | |
| num_chunks = int(num_chunks_tensor.item()) | |
| # Initialize pipeline manager | |
| pipeline_manager = InferencePipelineManager(config, device, rank, world_size) | |
| pipeline_manager.load_model(args.checkpoint_folder) | |
| # Load prompts | |
| dataset = TextDataset(args.prompt_file_path) | |
| prompts = [dataset[0]] | |
| num_steps = len(pipeline_manager.pipeline.denoising_step_list) | |
| # Determine block mode and setup block distribution | |
| if rank == 0: | |
| block_mode = 'input' | |
| elif rank == world_size - 1: | |
| block_mode = 'output' | |
| else: | |
| block_mode = 'middle' | |
| # Setup block distribution | |
| total_blocks = pipeline_manager.pipeline.num_transformer_blocks | |
| total_block_num = compute_default_block_distribution(total_blocks, world_size) | |
| block_num = torch.tensor(total_block_num, dtype=torch.int64, device=device) | |
| # Prepare pipeline | |
| start_idx = 0 | |
| end_idx = 5 | |
| current_start = 0 | |
| current_end = pipeline_manager.pipeline.frame_seq_length * 2 | |
| inp = input_video_original[:, :, start_idx:end_idx] | |
| # Only rank 0 performs VAE encoding operation | |
| if rank == 0: | |
| latents = pipeline_manager._timed_stream_encode(inp) | |
| latents = latents.transpose(2, 1).contiguous().to(dtype=torch.bfloat16) | |
| noise = torch.randn_like(latents) | |
| noisy_latents = noise * args.noise_scale + latents * (1 - args.noise_scale) | |
| # First broadcast the shape information | |
| latents_shape = torch.tensor(latents.shape, dtype=torch.int64, device=device) | |
| pipeline_manager.communicator.broadcast_tensor(latents_shape, src=0) | |
| # Then broadcast noisy_latents | |
| pipeline_manager.communicator.broadcast_tensor(noisy_latents, src=0) | |
| else: | |
| # Other ranks receive shape info first | |
| latents_shape = torch.zeros(5, dtype=torch.int64, device=device) | |
| pipeline_manager.communicator.broadcast_tensor(latents_shape, src=0) | |
| # Create tensor with same shape for receiving broadcast data | |
| noisy_latents = torch.zeros(tuple(latents_shape.tolist()), dtype=torch.bfloat16, device=device) | |
| # Receive the broadcasted noisy_latents | |
| pipeline_manager.communicator.broadcast_tensor(noisy_latents, src=0) | |
| denoised_pred = pipeline_manager.prepare_pipeline( | |
| text_prompts=prompts, | |
| noise=noisy_latents, | |
| block_mode=block_mode, | |
| current_start=current_start, | |
| current_end=current_end, | |
| block_num=block_num[rank], | |
| ) | |
| # Clear unused GPU memory | |
| torch.cuda.empty_cache() | |
| # Save initial result for final rank | |
| if rank == world_size - 1: | |
| results = {} | |
| video = pipeline_manager._timed_stream_decode(denoised_pred) | |
| video = (video * 0.5 + 0.5).clamp(0, 1) | |
| video = video[0].permute(0, 2, 3, 1).contiguous() | |
| results[0] = video.cpu().float().numpy() | |
| dist.barrier() | |
| pipeline_manager.logger.info(f"Prepared, Block num: {block_num[rank].tolist()}") | |
| used_mem = torch.cuda.memory_allocated(device) / 1024 / 1024 / 1024 | |
| total_mem = torch.cuda.get_device_properties(device).total_memory / 1024 / 1024 / 1024 | |
| pipeline_manager.logger.info(f"Current GPU memory usage: {used_mem:.2f} GB / {total_mem:.2f} GB") | |
| # Run appropriate loop based on rank | |
| try: | |
| if rank == 0: | |
| pipeline_manager.run_rank_0_loop( | |
| input_video_original, prompts, num_chunks, num_steps, chunk_size, | |
| block_num, args.noise_scale, args.schedule_block, total_blocks | |
| ) | |
| elif rank == world_size - 1: | |
| pipeline_manager.run_final_rank_loop( | |
| num_chunks, num_steps, chunk_size, block_num, args.output_folder, | |
| args.fps, args.schedule_block, total_blocks, results | |
| ) | |
| else: | |
| pipeline_manager.run_middle_rank_loop( | |
| num_chunks, num_steps, chunk_size, block_num, args.schedule_block, total_blocks | |
| ) | |
| finally: | |
| # Cleanup | |
| pipeline_manager.cleanup() | |
| dist.barrier() | |
| dist.destroy_process_group() | |
| if __name__ == "__main__": | |
| main() | |