# Source: Hugging Face Hub upload by multimodalart ("Upload folder using huggingface_hub",
# commit 5c93746, 43.3 kB).
"""
Refactored multi-rank inference pipeline with communication abstractions.
This is a refactored version of inference_pipe_multi.py that uses the new
communication abstraction layers for better code organization and maintainability.
"""
from models.wan.causal_stream_inference import CausalStreamInferencePipeline
from models.util import set_seed
from diffusers.utils import export_to_video
from models.data import TextDataset
import argparse
from dataclasses import dataclass
import torch
import torch.distributed as dist
import os
import time
import numpy as np
import logging
try:
from streamv2v.inference import compute_noise_scale_and_step
from streamv2v.communication import (
DistributedCommunicator,
ModelDataTransfer,
BufferManager,
KVCacheManager,
CommunicationConfig,
init_distributed,
setup_logging,
compute_balanced_split
)
from streamv2v.inference_common import (
load_generator_state_dict,
load_mp4_as_tensor,
merge_cli_config,
)
except ModuleNotFoundError:
from inference import compute_noise_scale_and_step
from communication import (
DistributedCommunicator,
ModelDataTransfer,
BufferManager,
KVCacheManager,
CommunicationConfig,
init_distributed,
setup_logging,
compute_balanced_split
)
from inference_common import (
load_generator_state_dict,
load_mp4_as_tensor,
merge_cli_config,
)
LOGGER = logging.getLogger(__name__)  # module-level logger; used by main()


def compute_default_block_distribution(total_blocks: int, world_size: int) -> list[list[int]]:
    """Return one contiguous ``[start, end)`` transformer-block range per rank.

    With two ranks the split point is ``total_blocks // 2``. Otherwise blocks are
    spread as evenly as possible: the first ``total_blocks % world_size`` ranks get
    one extra block, and the final range is pinned to end at ``total_blocks``.
    """
    if world_size == 2:
        half = total_blocks // 2
        return [[0, half], [half, total_blocks]]

    quotient, remainder = divmod(total_blocks, world_size)
    ranges: list[list[int]] = []
    cursor = 0
    for idx in range(world_size):
        width = quotient + 1 if idx < remainder else quotient
        ranges.append([cursor, cursor + width])
        cursor += width
    if ranges:
        # The last rank always ends exactly at total_blocks.
        ranges[-1][1] = total_blocks
    return ranges
@dataclass
class MultiGPUDemoInputSession:
    """Mutable rank-0 state for one demo stream in the multi-GPU pipeline.

    Created by ``InferencePipelineManager.start_demo_input_stream_session`` and updated
    in place by ``prepare_demo_input_batch``, ``run_demo_input_step``,
    ``maybe_refresh_demo_input_window`` and ``advance_demo_input_stream_session``.
    """
    # Text prompt the pipeline was prepared with for this stream.
    prompt: str
    # Current noise mixing ratio; recomputed for each batch by compute_noise_scale_and_step.
    noise_scale: float
    # Noise ratio the stream started with; passed to compute_noise_scale_and_step as the baseline.
    init_noise_scale: float
    # Input frames per streaming chunk (base_chunk_size * num_frame_per_block).
    chunk_size: int
    # KV-cache start/end positions of the current chunk (multiples of pipeline.frame_seq_length).
    current_start: int
    current_end: int
    # Last input frame of the previous batch (images[:, :, [-1]]); prepended to the next
    # batch when computing its noise scale.
    last_image: torch.Tensor
    # Number of chunks advanced so far.
    chunk_idx: int = 0
    # Remaining per-step slices of noisy_latents; counts down in run_demo_input_step.
    input_batch: int = 0
    # Step value returned by compute_noise_scale_and_step and passed to pipeline.inference.
    current_step: int = 0
    # Noised latents for the current batch; None until prepare_demo_input_batch runs.
    noisy_latents: torch.Tensor | None = None
class InferencePipelineManager:
"""
Manages the inference pipeline with communication abstractions.
This class encapsulates the main inference logic and uses the communication
abstractions for distributed operations.
"""
def __init__(self, config, device: torch.device, rank: int, world_size: int):
"""
Initialize the inference pipeline manager.
Args:
config: Configuration object
device: GPU device
rank: Current rank
world_size: Total number of ranks
"""
self.config = config
self.device = device
self.rank = rank
self.world_size = world_size
self.com_stream = torch.cuda.Stream()
self.control_stream = torch.cuda.Stream()
# Setup logging
self.logger = setup_logging(rank)
# Initialize communication components
comm_config = CommunicationConfig(
max_outstanding=config.get('max_outstanding', 1),
buffer_pool_size=config.get('buffer_pool_size', 10),
enable_buffer_reuse=config.get('enable_buffer_reuse', True)
)
self.communicator = DistributedCommunicator(rank, world_size, device, comm_config)
self.buffer_manager = BufferManager(device, comm_config)
# Initialize pipeline
self.pipeline = CausalStreamInferencePipeline(config, device=str(device))
self.pipeline.to(device=str(device), dtype=torch.bfloat16)
# Initialize KV cache manager
self.kv_cache_manager = KVCacheManager(self.pipeline, device)
# Initialize model data transfer
self.data_transfer = ModelDataTransfer(
self.communicator,
self.buffer_manager,
self.kv_cache_manager,
comm_config
)
# Performance tracking
self.t_dit = 100.0
self.t_total = 100.0
self.processed = 0
self.schedule_step = (self.world_size + len(config.denoising_step_list)) * 2
self.processed_offset = 3
self.base_chunk_size = 4
self.t_refresh = 50
self.profile = bool(config.get('profile', False))
self.encode_fps_list: list[float] = []
self.decode_fps_list: list[float] = []
self.logger.info(f"Initialized InferencePipelineManager for rank {rank}")
def load_model(self, checkpoint_folder: str):
"""Load the model from checkpoint."""
ckpt_path, state_dict = load_generator_state_dict(checkpoint_folder)
try:
self.pipeline.generator.load_state_dict(state_dict, strict=True)
except RuntimeError as exc:
self.logger.warning(f"Strict load_state_dict failed: {exc}; retrying with strict=False")
self.pipeline.generator.load_state_dict(state_dict, strict=False)
self.logger.info(f"Model loaded successfully from {ckpt_path}")
def prepare_pipeline(self, text_prompts: list, noise: torch.Tensor,
block_mode: str, current_start: int, current_end: int, block_num: torch.Tensor):
"""Prepare the pipeline for inference."""
denoised_pred = self.pipeline.prepare(
text_prompts=text_prompts,
device=self.device,
dtype=torch.bfloat16,
noise=noise,
block_mode=block_mode,
current_start=current_start,
current_end=current_end,
block_num=block_num
)
# Broadcast the prepared result from rank 0
self.data_transfer.broadcast_tensor(denoised_pred, src=0)
return denoised_pred
def _wait_for_outstanding(self, outstanding: list) -> None:
"""Keep the number of queued async sends bounded."""
while len(outstanding) >= self.config.get('max_outstanding', 1):
oldest = outstanding.pop(0)
for work in oldest:
work.wait()
def _drain_outstanding(self, outstanding: list) -> None:
"""Wait for all queued async sends to complete."""
while outstanding:
oldest = outstanding.pop(0)
for work in oldest:
work.wait()
def _maybe_schedule_blocks(self, schedule_block: bool, threshold: int, block_num: torch.Tensor, total_blocks: int) -> bool:
"""Run one-time block rebalancing when the warmup threshold is reached."""
if schedule_block and self.processed >= threshold:
self._handle_block_scheduling(block_num, total_blocks)
return False
return schedule_block
def _receive_latent_data(self, previous_latent_data, num_steps: int):
"""Release the previous payload and receive the next one from the upstream rank."""
with torch.cuda.stream(self.com_stream):
if previous_latent_data is not None:
self.data_transfer.release_latent_data(previous_latent_data)
latent_data = self.data_transfer.receive_latent_data_async(num_steps)
torch.cuda.current_stream().wait_stream(self.com_stream)
return latent_data
def _run_worker_stage(self, role: str, latent_data, block_num: torch.Tensor):
"""Execute the local DiT blocks for a middle or output rank."""
return self.pipeline.inference(
noise=latent_data.original_latents,
current_start=latent_data.current_start,
current_end=latent_data.current_end,
current_step=latent_data.current_step,
block_mode=role,
block_num=block_num,
patched_x_shape=latent_data.patched_x_shape,
block_x=latent_data.latents,
)
def _send_worker_result(self, role: str, outstanding: list, latent_data, denoised_pred: torch.Tensor) -> None:
"""Forward the payload that should continue around the pipeline ring."""
if role == 'output':
latents = latent_data.latents
original_latents = denoised_pred
else:
latents = denoised_pred
original_latents = latent_data.original_latents
with torch.cuda.stream(self.com_stream):
work_objects = self.data_transfer.send_latent_data_async(
chunk_idx=latent_data.chunk_idx,
latents=latents,
original_latents=original_latents,
patched_x_shape=latent_data.patched_x_shape,
current_start=latent_data.current_start,
current_end=latent_data.current_end,
current_step=latent_data.current_step
)
outstanding.append(work_objects)
def _decode_prediction(self, denoised_pred: torch.Tensor) -> np.ndarray:
"""Decode the newest latent prediction into pixel-space frames."""
video = self._timed_stream_decode(denoised_pred[[-1]])
video = (video * 0.5 + 0.5).clamp(0, 1)
video = video[0].permute(0, 2, 3, 1).contiguous()
return video.cpu().float().numpy()
def _rank_loop_complete(self, num_chunks: int, num_steps: int) -> bool:
"""Return whether a non-output rank has processed all required chunks."""
return (
self.processed + self.processed_offset
>= num_chunks + num_steps * self.world_size + self.world_size - self.rank - 1
)
def _safe_mean(self, values: list) -> float:
if not values:
return 0.0
return float(np.mean(np.array(values)))
def _record_stage_fps(self, values: list[float], num_frames: int, elapsed: float) -> None:
if self.profile and elapsed > 0 and num_frames > 0:
values.append(num_frames / elapsed)
def _timing_enabled(self, schedule_block: bool = False) -> bool:
"""Only force GPU synchronization when profiling or schedule calibration needs it."""
return self.profile or schedule_block
def _sync_for_timing(self, schedule_block: bool = False) -> None:
if self._timing_enabled(schedule_block):
torch.cuda.synchronize()
def _timed_stream_encode(self, images: torch.Tensor) -> torch.Tensor:
self._sync_for_timing()
start_time = time.time()
latents = self.pipeline.vae.stream_encode(images)
self._sync_for_timing()
self._record_stage_fps(self.encode_fps_list, int(images.shape[2]), time.time() - start_time)
return latents
def _timed_stream_decode(self, denoised_pred: torch.Tensor) -> torch.Tensor:
self._sync_for_timing()
start_time = time.time()
video = self.pipeline.vae.stream_decode_to_pixel(denoised_pred)
self._sync_for_timing()
self._record_stage_fps(self.decode_fps_list, int(video.shape[1]), time.time() - start_time)
return video
def reset_stream_state(self, reset_encode: bool = False, reset_decode: bool = False) -> None:
"""Reset cached inference state before starting a new prompt/session."""
self.pipeline.kv_cache1 = None
self.pipeline.crossattn_cache = None
self.pipeline.block_x = None
self.pipeline.hidden_states = None
self.processed = 0
if reset_encode:
self.pipeline.vae.model.first_encode = True
if reset_decode:
self.pipeline.vae.model.first_decode = True
def _broadcast_initial_noise(self, noisy_latents: torch.Tensor) -> None:
latents_shape = torch.tensor(noisy_latents.shape, dtype=torch.int64, device=self.device)
self.communicator.broadcast_tensor(latents_shape, src=0)
self.communicator.broadcast_tensor(noisy_latents, src=0)
def _receive_initial_noise(self) -> torch.Tensor:
latents_shape = torch.zeros(5, dtype=torch.int64, device=self.device)
self.communicator.broadcast_tensor(latents_shape, src=0)
noisy_latents = torch.zeros(tuple(latents_shape.tolist()), dtype=torch.bfloat16, device=self.device)
self.communicator.broadcast_tensor(noisy_latents, src=0)
return noisy_latents
def get_demo_chunk_size(self) -> int:
"""Return the demo stream chunk size in frames."""
return self.base_chunk_size * self.pipeline.num_frame_per_block
def get_demo_first_batch_num_frames(self) -> int:
"""Return the number of frames required to initialize a demo stream."""
return 1 + self.get_demo_chunk_size()
def prepare_demo_input_session(self, images: torch.Tensor, prompt: str, block_num: torch.Tensor, noise_scale: float) -> None:
"""Initialize rank 0 for demo streaming and broadcast the first noisy latents."""
self.reset_stream_state(reset_encode=True)
torch.cuda.empty_cache()
latents = self._timed_stream_encode(images)
latents = latents.transpose(2, 1).contiguous().to(dtype=torch.bfloat16)
noise = torch.randn_like(latents)
noisy_latents = noise * noise_scale + latents * (1 - noise_scale)
self._broadcast_initial_noise(noisy_latents)
self.prepare_pipeline(
text_prompts=[prompt],
noise=noisy_latents,
block_mode='input',
current_start=0,
current_end=self.pipeline.frame_seq_length * 2,
block_num=block_num,
)
torch.cuda.empty_cache()
dist.barrier()
def start_demo_input_stream_session(
self,
prompt: str,
images: torch.Tensor,
block_num: torch.Tensor,
noise_scale: float,
) -> MultiGPUDemoInputSession:
"""Initialize rank 0 and return the demo stream session state."""
chunk_size = self.get_demo_chunk_size()
self.prepare_demo_input_session(images, prompt, block_num, noise_scale)
current_start = self.pipeline.frame_seq_length * (1 + chunk_size // self.base_chunk_size)
current_end = current_start + (chunk_size // self.base_chunk_size) * self.pipeline.frame_seq_length
return MultiGPUDemoInputSession(
prompt=prompt,
noise_scale=noise_scale,
init_noise_scale=noise_scale,
chunk_size=chunk_size,
current_start=current_start,
current_end=current_end,
last_image=images[:, :, [-1]],
)
def prepare_demo_worker_session(self, prompt: str, block_mode: str, block_num: torch.Tensor, decode_initial: bool = False):
"""Initialize a non-input rank for demo streaming from the broadcast first chunk."""
self.reset_stream_state(reset_decode=(block_mode == 'output'))
torch.cuda.empty_cache()
noisy_latents = self._receive_initial_noise()
denoised_pred = self.prepare_pipeline(
text_prompts=[prompt],
noise=noisy_latents,
block_mode=block_mode,
current_start=0,
current_end=self.pipeline.frame_seq_length * 2,
block_num=block_num,
)
torch.cuda.empty_cache()
dist.barrier()
if decode_initial:
return self._decode_prediction(denoised_pred)
return None
def maybe_refresh_demo_input_window(self, session: MultiGPUDemoInputSession) -> None:
"""Wrap the KV-cache window once the streaming refresh threshold is reached."""
if session.current_start // self.pipeline.frame_seq_length >= self.t_refresh:
session.current_start = self.pipeline.kv_cache_length - self.pipeline.frame_seq_length
session.current_end = session.current_start + (session.chunk_size // self.base_chunk_size) * self.pipeline.frame_seq_length
def prepare_demo_input_batch(self, session: MultiGPUDemoInputSession, images: torch.Tensor) -> None:
"""Encode one demo chunk and update the session with the current denoising step."""
num_frames = images.shape[2]
session.input_batch = num_frames // session.chunk_size
session.noise_scale, session.current_step = compute_noise_scale_and_step(
input_video_original=torch.cat([session.last_image, images], dim=2),
end_idx=num_frames + 1,
chunk_size=num_frames,
noise_scale=float(session.noise_scale),
init_noise_scale=float(session.init_noise_scale),
)
latents = self._timed_stream_encode(images)
latents = latents.transpose(2, 1).contiguous().to(dtype=torch.bfloat16)
noise = torch.randn_like(latents)
session.noisy_latents = noise * session.noise_scale + latents * (1 - session.noise_scale)
def run_demo_input_step(
self,
session: MultiGPUDemoInputSession,
block_num: torch.Tensor,
previous_latent_data=None,
):
"""Run one rank-0 demo step from the current session batch."""
if session.noisy_latents is None or session.input_batch <= 0:
raise RuntimeError("demo input batch was not prepared before run_demo_input_step")
denoised_pred, patched_x_shape = self.run_input_stage(
noisy_latents=session.noisy_latents[:, -session.input_batch].unsqueeze(1),
current_start=session.current_start,
current_end=session.current_end,
current_step=session.current_step,
block_num=block_num,
previous_latent_data=previous_latent_data,
)
session.input_batch -= 1
return denoised_pred, patched_x_shape
def advance_demo_input_stream_session(self, session: MultiGPUDemoInputSession, images: torch.Tensor) -> None:
"""Advance the demo stream session after a chunk has been queued downstream."""
session.last_image = images[:, :, [-1]]
session.chunk_idx += 1
session.current_start = session.current_end
session.current_end += (session.chunk_size // self.base_chunk_size) * self.pipeline.frame_seq_length
def send_demo_input_prompt_update(
self,
prompt: str,
device: torch.device,
num_steps: int,
chunk_idx: int,
denoised_pred: torch.Tensor,
patched_x_shape: torch.Tensor,
current_step: int,
) -> None:
"""Signal a prompt restart from rank 0 and drain in-flight returns from downstream ranks."""
with torch.cuda.stream(self.com_stream):
self.data_transfer.send_latent_data_async(
chunk_idx=-1,
latents=denoised_pred.new_zeros([1] * denoised_pred.ndim),
original_latents=self.pipeline.hidden_states.new_zeros([1] * self.pipeline.hidden_states.ndim),
patched_x_shape=patched_x_shape,
current_start=self.pipeline.kv_cache_starts,
current_end=self.pipeline.kv_cache_ends,
current_step=int(current_step),
)
self.data_transfer.send_prompt_async(prompt, device)
for _ in range(min(chunk_idx, self.world_size - 1)):
pending_data = self.data_transfer.receive_latent_data_async(num_steps)
self.data_transfer.release_latent_data(pending_data)
def send_demo_middle_prompt_update(
self,
prompt: str,
device: torch.device,
denoised_pred: torch.Tensor | None,
latent_data,
) -> None:
"""Forward a prompt restart from a middle rank to the next rank."""
sentinel_source = denoised_pred if denoised_pred is not None else latent_data.latents
with torch.cuda.stream(self.com_stream):
self.data_transfer.send_latent_data_async(
chunk_idx=-1,
latents=sentinel_source.new_zeros([1] * sentinel_source.ndim),
original_latents=latent_data.original_latents,
patched_x_shape=latent_data.patched_x_shape,
current_start=latent_data.current_start,
current_end=latent_data.current_end,
current_step=int(latent_data.current_step),
)
self.data_transfer.send_prompt_async(prompt, device)
def run_input_stage(self, noisy_latents: torch.Tensor, current_start: int, current_end: int, current_step: int, block_num: torch.Tensor, previous_latent_data=None):
"""Run the rank-0 stage for one streaming chunk."""
if previous_latent_data is not None and self.processed >= self.world_size:
self.pipeline.hidden_states.copy_(previous_latent_data.original_latents)
self.pipeline.kv_cache_starts.copy_(previous_latent_data.current_start)
self.pipeline.kv_cache_ends.copy_(previous_latent_data.current_end)
return self.pipeline.inference(
noise=noisy_latents,
current_start=current_start,
current_end=current_end,
current_step=current_step,
block_mode='input',
block_num=block_num,
)
def run_rank_0_loop(self, input_video_original: torch.Tensor, prompts: list,
num_chunks: int, num_steps: int, chunk_size: int,
block_num: torch.Tensor, noise_scale: float,
schedule_block: bool, total_blocks: int):
"""
Run the main loop for rank 0 (encoder + async send).
This method encapsulates the rank 0 logic using the communication abstractions.
"""
self.logger.info("Starting rank 0 inference loop")
# Initialize variables
start_idx = 0
end_idx = 1 + chunk_size
current_start = 0
current_end = self.pipeline.frame_seq_length * (1+chunk_size//self.base_chunk_size)
init_noise_scale = noise_scale
outstanding = []
latent_data = None
self._sync_for_timing(schedule_block)
start_time = time.time()
while True:
# Process new chunk if available
start_idx = end_idx
end_idx = end_idx + chunk_size
current_start = current_end
current_end = current_end + (chunk_size // self.base_chunk_size) * self.pipeline.frame_seq_length
if schedule_block:
self._sync_for_timing(schedule_block)
start_vae = time.time()
if end_idx <= input_video_original.shape[2]:
inp = input_video_original[:, :, start_idx:end_idx]
noise_scale, current_step = compute_noise_scale_and_step(
input_video_original, end_idx, chunk_size, noise_scale, init_noise_scale
)
latents = self._timed_stream_encode(inp)
latents = latents.transpose(2, 1).contiguous().to(dtype=torch.bfloat16)
noise = torch.randn_like(latents)
noisy_latents = noise * noise_scale + latents * (1 - noise_scale)
# if current_start//self.pipeline.frame_seq_length >= self.t_refresh:
# current_start = self.pipeline.kv_cache_length - self.pipeline.frame_seq_length
# current_end = current_start + (chunk_size // self.base_chunk_size) * self.pipeline.frame_seq_length
# Measure DiT time if scheduling is enabled
if schedule_block:
self._sync_for_timing(schedule_block)
start_dit = time.time()
t_vae = start_dit - start_vae
# Run inference
denoised_pred, patched_x_shape = self.pipeline.inference(
noise=noisy_latents,
current_start=current_start,
current_end=current_end,
current_step=current_step,
block_mode='input',
block_num=block_num[self.rank],
)
# Update DiT timing
if schedule_block:
self._sync_for_timing(schedule_block)
temp = time.time() - start_dit
if temp < self.t_dit:
self.t_dit = temp
self.processed += 1
with torch.cuda.stream(self.com_stream):
if self.processed >= self.world_size:
if latent_data is not None:
self.data_transfer.release_latent_data(latent_data)
# Receive data from previous rank
latent_data = self.data_transfer.receive_latent_data_async(num_steps)
torch.cuda.current_stream().wait_stream(self.com_stream)
# Wait for outstanding operations
self._wait_for_outstanding(outstanding)
# Send data to next rank
with torch.cuda.stream(self.com_stream):
work_objects = self.data_transfer.send_latent_data_async(
chunk_idx=start_idx,
latents=denoised_pred,
original_latents=self.pipeline.hidden_states,
patched_x_shape=patched_x_shape,
current_start=self.pipeline.kv_cache_starts,
current_end=self.pipeline.kv_cache_ends,
current_step=current_step
)
outstanding.append(work_objects)
# Handle block scheduling
if schedule_block and self.processed >= self.schedule_step:
self._handle_block_scheduling(block_num, total_blocks)
schedule_block = False
# Update timing and check completion
if self._timing_enabled(schedule_block):
self._sync_for_timing(schedule_block)
end_time = time.time()
t = end_time - start_time
self.logger.info(f"Encode {self.processed}, time: {t:.4f} s, fps: {inp.shape[2]/t:.4f}")
if schedule_block:
t_total = self.t_dit + t_vae
if t_total < self.t_total:
self.t_total = t_total
start_time = end_time
if self.processed >= self.world_size:
self.pipeline.hidden_states.copy_(latent_data.original_latents)
self.pipeline.kv_cache_starts.copy_(latent_data.current_start)
self.pipeline.kv_cache_ends.copy_(latent_data.current_end)
if self.processed + self.processed_offset >= num_chunks + num_steps * self.world_size + self.world_size - self.rank - 1:
break
if latent_data is not None:
self.data_transfer.release_latent_data(latent_data)
self._drain_outstanding(outstanding)
self.logger.info(f"VAE Encode Average FPS: {self._safe_mean(self.encode_fps_list):.4f}")
self.logger.info("Rank 0 inference loop completed")
def run_final_rank_loop(self, num_chunks: int, num_steps: int, chunk_size: int,
block_num: torch.Tensor, output_folder: str, fps: int,
schedule_block: bool, total_blocks: int, results: dict):
"""Run the worker loop for the output rank."""
self.run_worker_rank_loop(
role='output',
num_chunks=num_chunks,
num_steps=num_steps,
chunk_size=chunk_size,
block_num=block_num,
schedule_block=schedule_block,
total_blocks=total_blocks,
output_folder=output_folder,
fps=fps,
results=results,
)
def run_middle_rank_loop(self, num_chunks: int, num_steps: int, chunk_size: int,
block_num: torch.Tensor, schedule_block: bool, total_blocks: int):
"""Run the worker loop for a middle rank."""
self.run_worker_rank_loop(
role='middle',
num_chunks=num_chunks,
num_steps=num_steps,
chunk_size=chunk_size,
block_num=block_num,
schedule_block=schedule_block,
total_blocks=total_blocks,
)
def run_worker_rank_loop(
self,
role: str,
num_chunks: int,
num_steps: int,
chunk_size: int,
block_num: torch.Tensor,
schedule_block: bool,
total_blocks: int,
output_folder: str = None,
fps: int = None,
results: dict = None,
):
"""Run the shared receive -> infer -> forward loop for middle and output ranks."""
if role not in {'middle', 'output'}:
raise ValueError(f"Unsupported worker role: {role}")
self.logger.info(f"Starting {role} rank inference loop")
if role == 'output':
if output_folder is None or fps is None or results is None:
raise ValueError("output rank requires output_folder, fps, and results")
os.makedirs(output_folder, exist_ok=True)
save_results = 1
outstanding = []
fps_list = []
latent_data = None
self._sync_for_timing(schedule_block)
start_time = time.time()
while True:
latent_data = self._receive_latent_data(latent_data, num_steps)
schedule_block = self._maybe_schedule_blocks(
schedule_block,
self.schedule_step - self.rank,
block_num,
total_blocks,
)
if schedule_block:
self._sync_for_timing(schedule_block)
start_dit = time.time()
denoised_pred, _ = self._run_worker_stage(role, latent_data, block_num[self.rank])
if schedule_block:
self._sync_for_timing(schedule_block)
temp = time.time() - start_dit
if temp < self.t_dit:
self.t_dit = temp
self.processed += 1
self._wait_for_outstanding(outstanding)
self._send_worker_result(role, outstanding, latent_data, denoised_pred)
if role == 'output':
if self.processed >= num_steps * self.world_size - 1:
if schedule_block:
self._sync_for_timing(schedule_block)
start_vae = time.time()
video = self._timed_stream_decode(denoised_pred[[-1]])
video = (video * 0.5 + 0.5).clamp(0, 1)
video = video[0].permute(0, 2, 3, 1).contiguous()
results[save_results] = video.cpu().float().numpy()
if self._timing_enabled(schedule_block):
self._sync_for_timing(schedule_block)
end_time = time.time()
elapsed = end_time - start_time
fps_test = video.shape[0] / elapsed
if self.processed > self.schedule_step:
fps_list.append(fps_test)
self.logger.info(f"Decode {self.processed}, time: {elapsed:.4f} s, FPS: {fps_test:.4f}")
if schedule_block:
t_vae = end_time - start_vae
t_total = t_vae + self.t_dit
if t_total < self.t_total:
self.t_total = t_total
start_time = end_time
save_results += 1
if save_results >= num_chunks:
break
else:
if self._timing_enabled(schedule_block):
self._sync_for_timing(schedule_block)
end_time = time.time()
elapsed = end_time - start_time
fps_test = chunk_size / elapsed
if self.processed > self.schedule_step:
fps_list.append(fps_test)
if schedule_block:
t_total = self.t_dit
if t_total < self.t_total:
self.t_total = t_total
self.logger.info(f"Middle {self.processed}, time: {elapsed:.4f} s, fps: {fps_test:.4f}")
start_time = end_time
if self._rank_loop_complete(num_chunks, num_steps):
break
if latent_data is not None:
self.data_transfer.release_latent_data(latent_data)
self._drain_outstanding(outstanding)
if role == 'output':
video_list = [results[i] for i in range(num_chunks)]
video = np.concatenate(video_list, axis=0)
fps_avg = self._safe_mean(fps_list)
self.logger.info(f"Video shape: {video.shape}, Average FPS: {fps_avg:.4f}")
self.logger.info(f"VAE Decode Average FPS: {self._safe_mean(self.decode_fps_list):.4f}")
output_path = os.path.join(output_folder, f"output_{0:03d}.mp4")
export_to_video(video, output_path, fps=fps)
self.logger.info(f"Video saved to: {output_path} (Press Ctrl+C to force exit)")
return
self.logger.info(f"DiT Average FPS: {self._safe_mean(fps_list):.4f}")
self.logger.info(f"Rank {self.rank} inference loop completed")
def _handle_block_scheduling(self, block_num: torch.Tensor, total_blocks: int):
"""Handle block scheduling and rebalancing."""
self.logger.info(f"Scheduling block in {self.processed}")
# Gather timing information from all ranks
t_total_tensor = torch.tensor(self.t_total, dtype=torch.float32, device=self.device)
t_dit_tensor = torch.tensor(self.t_dit, dtype=torch.float32, device=self.device)
gather_blocks = [torch.zeros_like(t_dit_tensor, dtype=torch.float32, device=self.device)
for _ in range(self.world_size)]
dist.all_gather(gather_blocks, t_dit_tensor)
t_dit_list = [t_dit_i.item() for t_dit_i in gather_blocks]
dist.all_gather(gather_blocks, t_total_tensor)
t_list = [t_i.item() for t_i in gather_blocks]
# Compute new block distribution
new_block_num = torch.tensor(
compute_balanced_split(total_blocks, t_list, t_dit_list, block_num.tolist()),
dtype=torch.int64, device=self.device
)
self.logger.info(f"New block distribution: {new_block_num[self.rank].tolist()}")
# Broadcast new block distribution
dist.broadcast(new_block_num, src=self.world_size - 1)
# Rebalance KV cache
self.data_transfer.rebalance_kv_cache(block_num, new_block_num, total_blocks)
# Update block_num
block_num.copy_(new_block_num)
start_block, end_block = block_num[self.rank][0].item(), block_num[self.rank][1].item()
blocks_to_keep = list(range(start_block, end_block))
for i in range(self.pipeline.num_transformer_blocks):
if i not in blocks_to_keep:
self.pipeline.kv_cache1[i]['k'] = self.pipeline.kv_cache1[i]['k'].cpu()
self.pipeline.kv_cache1[i]['v'] = self.pipeline.kv_cache1[i]['v'].cpu()
self.logger.info("Block scheduling completed")
def cleanup(self):
"""Clean up resources."""
self.data_transfer.cleanup()
self.logger.info("InferencePipelineManager cleanup completed")
def main():
    """Run one rank of the multi-GPU streaming inference pipeline.

    Parses CLI arguments and joins the process group. Loads the config, the checkpoint,
    the first prompt from ``--prompt_file_path``, and the input video. Rank 0 encodes and
    noises the first frames and broadcasts them, and every rank prepares its pipeline
    from them. Then the rank-specific loop runs: rank 0 encodes, middle ranks forward
    activations, and the last rank decodes and writes the output video.

    Raises:
        ValueError: if fewer than two ranks are launched.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str)
    parser.add_argument("--checkpoint_folder", type=str)
    parser.add_argument("--output_folder", type=str)
    parser.add_argument("--prompt_file_path", type=str)
    parser.add_argument("--video_path", type=str)
    parser.add_argument("--noise_scale", type=float, default=0.8)
    parser.add_argument("--height", type=int, default=480)
    parser.add_argument("--width", type=int, default=832)
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--max_outstanding", type=int, default=1, help="max number of outstanding sends/recv to keep")
    parser.add_argument("--dit_fsdp", action="store_true", default=False)
    parser.add_argument("--t5_fsdp", action="store_true", default=False)
    parser.add_argument("--ulysses_size", type=int, default=1)
    parser.add_argument("--ring_size", type=int, default=1)
    parser.add_argument("--step", type=int, default=2)
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--schedule_block", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False, help="Enable synchronized throughput logging")
    parser.add_argument("--t2v", action="store_true", default=False)
    parser.add_argument("--model_type", type=str, default="T2V-1.3B", help="Model type (e.g., T2V-1.3B)")
    parser.add_argument("--use_taehv", action="store_true", default=False, help="Use the lightweight TAEHV VAE for encode/decode")
    parser.add_argument("--use_tensorrt", "--use_taehv_tensorrt", dest="use_tensorrt", action="store_true", default=False, help="Enable available TensorRT acceleration paths")
    parser.add_argument("--fast", action="store_true", default=False, help="Enable the fast path: --use_taehv --use_tensorrt")
    args = parser.parse_args()

    torch.set_grad_enabled(False)
    init_distributed()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if world_size < 2:
        raise ValueError(f"world_size must be at least 2, got {world_size}")
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    # Load configuration
    config = merge_cli_config(args.config_path, args)
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    LOGGER.info("Denoising Step List: %s", list(config.denoising_step_list))
    set_seed(args.seed)

    # Load the input video with a leading batch dimension (b, c, t, h, w).
    input_video_original = load_mp4_as_tensor(args.video_path, resize_hw=(args.height, args.width)).unsqueeze(0)
    # Cast and move unconditionally: a clip that is already bfloat16 must still be
    # placed on this rank's GPU before it is encoded.
    input_video_original = input_video_original.to(device=device, dtype=torch.bfloat16)
    LOGGER.info("Input video tensor shape: %s", tuple(input_video_original.shape))
    num_frames = input_video_original.shape[2]

    # Rank 0 decides the chunk count and broadcasts it so every rank agrees.
    chunk_size = 4 * config.num_frame_per_block
    num_chunks = (num_frames - 1) // chunk_size if rank == 0 else 0
    num_chunks_tensor = torch.tensor([num_chunks], dtype=torch.int64, device=device)
    dist.broadcast(num_chunks_tensor, src=0)
    num_chunks = int(num_chunks_tensor.item())

    # Initialize pipeline manager
    pipeline_manager = InferencePipelineManager(config, device, rank, world_size)
    pipeline_manager.load_model(args.checkpoint_folder)

    # Load prompts (only the first prompt of the dataset is used).
    dataset = TextDataset(args.prompt_file_path)
    prompts = [dataset[0]]
    num_steps = len(pipeline_manager.pipeline.denoising_step_list)

    # Determine this rank's role in the pipeline ring.
    if rank == 0:
        block_mode = 'input'
    elif rank == world_size - 1:
        block_mode = 'output'
    else:
        block_mode = 'middle'

    # Setup block distribution
    total_blocks = pipeline_manager.pipeline.num_transformer_blocks
    total_block_num = compute_default_block_distribution(total_blocks, world_size)
    block_num = torch.tensor(total_block_num, dtype=torch.int64, device=device)

    # NOTE(review): the warm-up window is fixed at 5 frames and 2 * frame_seq_length KV
    # positions. This matches the 1 + chunk_size frames that run_rank_0_loop assumes were
    # already consumed only when num_frame_per_block == 1.
    warmup_end_idx = 5
    current_start = 0
    current_end = pipeline_manager.pipeline.frame_seq_length * 2

    # Only rank 0 performs the VAE encode; the other ranks receive the noised latents.
    if rank == 0:
        inp = input_video_original[:, :, 0:warmup_end_idx]
        latents = pipeline_manager._timed_stream_encode(inp)
        latents = latents.transpose(2, 1).contiguous().to(dtype=torch.bfloat16)
        noise = torch.randn_like(latents)
        noisy_latents = noise * args.noise_scale + latents * (1 - args.noise_scale)
        # Broadcasts the shape first, then the tensor itself.
        pipeline_manager._broadcast_initial_noise(noisy_latents)
    else:
        noisy_latents = pipeline_manager._receive_initial_noise()

    denoised_pred = pipeline_manager.prepare_pipeline(
        text_prompts=prompts,
        noise=noisy_latents,
        block_mode=block_mode,
        current_start=current_start,
        current_end=current_end,
        block_num=block_num[rank],
    )
    # Clear unused GPU memory
    torch.cuda.empty_cache()

    # The final rank stores the decoded warm-up chunk as results[0].
    results = {}
    if rank == world_size - 1:
        video = pipeline_manager._timed_stream_decode(denoised_pred)
        video = (video * 0.5 + 0.5).clamp(0, 1)
        video = video[0].permute(0, 2, 3, 1).contiguous()
        results[0] = video.cpu().float().numpy()

    dist.barrier()
    pipeline_manager.logger.info(f"Prepared, Block num: {block_num[rank].tolist()}")
    used_mem = torch.cuda.memory_allocated(device) / 1024 / 1024 / 1024
    total_mem = torch.cuda.get_device_properties(device).total_memory / 1024 / 1024 / 1024
    pipeline_manager.logger.info(f"Current GPU memory usage: {used_mem:.2f} GB / {total_mem:.2f} GB")

    # Run appropriate loop based on rank
    try:
        if rank == 0:
            pipeline_manager.run_rank_0_loop(
                input_video_original, prompts, num_chunks, num_steps, chunk_size,
                block_num, args.noise_scale, args.schedule_block, total_blocks
            )
        elif rank == world_size - 1:
            pipeline_manager.run_final_rank_loop(
                num_chunks, num_steps, chunk_size, block_num, args.output_folder,
                args.fps, args.schedule_block, total_blocks, results
            )
        else:
            pipeline_manager.run_middle_rank_loop(
                num_chunks, num_steps, chunk_size, block_num, args.schedule_block, total_blocks
            )
    finally:
        # Cleanup
        pipeline_manager.cleanup()
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    main()