# diffsynth/pipelines/wan_video_new_determine.py
import glob
import os
import time
import types
import warnings
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import torch
from einops import rearrange, reduce, repeat
# from modelscope import snapshot_download
from huggingface_hub import snapshot_download
from PIL import Image
from tqdm import tqdm
from typing_extensions import Literal
from ..models import ModelManager, load_state_dict
from ..models.wan_video_dit import RMSNorm, WanModel, sinusoidal_embedding_1d
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_text_encoder import (T5LayerNorm, T5RelativeEmbedding,
WanTextEncoder)
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_vae import (CausalConv3d, RMS_norm, Upsample,
WanVideoVAE)
from ..schedulers.flow_match import FlowMatchScheduler
# from ..prompters import WanPrompter
from ..vram_management import (AutoWrappedLinear, AutoWrappedModule,
WanAutoCastLayerNorm, enable_vram_management)
class BasePipeline(torch.nn.Module):
def __init__(
self,
device="cuda",
torch_dtype=torch.float16,
height_division_factor=64,
width_division_factor=64,
time_division_factor=None,
time_division_remainder=None,
):
super().__init__()
        # The device and torch_dtype are used for intermediate variables, not for the models themselves.
self.device = device
self.torch_dtype = torch_dtype
# The following parameters are used for shape check.
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
self.vram_management_enabled = False
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
*args, **kwargs
)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def check_resize_height_width(self, height, width, num_frames=None):
# Shape check
# print(
# f"height, width, time division factor: {self.height_division_factor}, {self.width_division_factor}, {self.time_division_factor}, time division remainder: {self.time_division_remainder}"
# )
assert (
height % self.height_division_factor == 0
), f"height {height} is not divisible by {self.height_division_factor}."
assert (
width % self.width_division_factor == 0
), f"width {width} is not divisible by {self.width_division_factor}."
        assert num_frames is not None and (
            num_frames % self.time_division_factor == self.time_division_remainder
        ), f"num_frames {num_frames} must satisfy num_frames % {self.time_division_factor} == {self.time_division_remainder}."
return height, width, num_frames
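    # Example: WanVideoPipeline uses height/width division factor 16 and a temporal
    # factor/remainder of 4/1, so (height=480, width=720, num_frames=41) passes this
    # check, while num_frames=40 would fail because 40 % 4 != 1.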
def preprocess_image(
self,
image,
torch_dtype=None,
device=None,
pattern="B C H W",
min_value=-1,
max_value=1,
):
        # Transform a PIL.Image or a torch.Tensor into a tensor in [min_value, max_value].
if isinstance(image, torch.Tensor):
# C H W
# print(f"Image shape {image.shape}")
assert (len(image.shape) == 3 and image.shape[0] == 3) or (
len(image.shape) == 4 and image.shape[1] == 3
), "Image tensor must be in 3 H W or B 3 H W format."
image = image.to(
dtype=torch_dtype or self.torch_dtype, device=device or self.device
)
image = image * ((max_value - min_value)) + min_value
if len(image.shape) == 3:
image = image.unsqueeze(0) # Add batch dimension
else:
image = torch.Tensor(np.array(image, dtype=np.float32))
image = image.to(
dtype=torch_dtype or self.torch_dtype, device=device or self.device
)
image = image * ((max_value - min_value) / 255) + min_value
image = repeat(
image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {})
)
return image
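    # Note: PIL inputs are assumed to be 8-bit (values in [0, 255]) and are mapped to
    # [min_value, max_value]; tensor inputs are assumed to already be scaled to [0, 1].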
def preprocess_video(
self,
video,
torch_dtype=None,
device=None,
pattern="B C T H W",
min_value=-1,
max_value=1,
):
video = [
self.preprocess_image(
image,
torch_dtype=torch_dtype,
device=device,
min_value=min_value,
max_value=max_value,
)
for image in video
]
video = torch.stack(video, dim=pattern.index("T") // 2)
return video
def vae_output_to_image(
self, vae_output, pattern="B C H W", min_value=-1, max_value=1
):
        # Transform a torch.Tensor into an H W C float numpy array in [0, 255]; PIL conversion is left to the caller.
if pattern != "H W C":
vae_output = reduce(
vae_output, f"{pattern} -> H W C", reduction="mean")
image = (vae_output - min_value) * (255.0 / (max_value - min_value))
image = image.clamp(0.0, 255.0)
image = image.to(device="cpu", dtype=torch.float32)
image = image.numpy()
# image = Image.fromarray(image.numpy())
return image
def vae_output_to_video(
self, vae_output, pattern="B C T H W", min_value=-1, max_value=1
):
        # Transform a torch.Tensor (B C T H W) into a B T H W C float numpy array in [0, 1].
# if pattern != "T H W C":
# vae_output = reduce(
# vae_output, f"{pattern} -> T H W C", reduction="mean")
if vae_output.ndim == 5: # B C T H W
assert (
vae_output.shape[1] == 3
), f"vae_output shape {vae_output.shape} is not valid. Expected 5D tensor with 3 channels on the second dimension."
vae_output = vae_output.permute(0, 2, 3, 4, 1)
# print(f"vae_output shape after permute: {vae_output.shape}")
video = vae_output.to(device="cpu", dtype=torch.float32).numpy()
        video = (video - min_value) / (max_value - min_value)
# print(f"Video range before clip: {video.min()} to {video.max()}")
video = video.clip(0.0, 1.0)
# for _video in vae_output:
# video.append(
# [
# self.vae_output_to_image(
# image,
# pattern="H W C",
# min_value=min_value,
# max_value=max_value,
# )
# for image in _video
# ]
# )
# else:
# raise ValueError(
# f"Invalid vae_output shape {vae_output.shape}. Expected 5D tensor."
# )
return video
    def load_models_to_device(self, model_names=()):
if self.vram_management_enabled:
# offload models
for name, model in self.named_children():
if name not in model_names:
if (
hasattr(model, "vram_management_enabled")
and model.vram_management_enabled
):
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
else:
model.cpu()
torch.cuda.empty_cache()
# onload models
for name, model in self.named_children():
if name in model_names:
if (
hasattr(model, "vram_management_enabled")
and model.vram_management_enabled
):
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
else:
model.to(self.device)
def generate_noise(
self,
shape,
seed=None,
rand_device="cpu",
rand_torch_dtype=torch.float32,
device=None,
torch_dtype=None,
):
# Initialize Gaussian noise
generator = (
None if seed is None else torch.Generator(
rand_device).manual_seed(seed)
)
# TODO multi-res noise
noise = torch.randn(
shape, generator=generator, device=rand_device, dtype=rand_torch_dtype
)
noise = noise.to(
dtype=torch_dtype or self.torch_dtype, device=device or self.device
)
return noise
def enable_cpu_offload(self):
warnings.warn(
"`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`."
)
self.vram_management_enabled = True
def get_vram(self):
return torch.cuda.mem_get_info(self.device)[1] / (1024**3)
def freeze_except(self, model_names):
for name, model in self.named_children():
if name in model_names:
print(f"Unfreezing model {name}.")
print(
f"Model parameters: {sum(p.numel() for p in model.parameters())}")
model.train()
model.requires_grad_(True)
else:
print(f"Freezing model {name}.")
print(
f"Model parameters: {sum(p.numel() for p in model.parameters())}")
model.eval()
model.requires_grad_(False)
@dataclass
class ModelConfig:
    path: Optional[Union[str, list[str]]] = None
    model_id: Optional[str] = None
    origin_file_pattern: Optional[Union[str, list[str]]] = None
    download_resource: str = "ModelScope"
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None
def download_if_necessary(
self, local_model_path="./models", skip_download=False, use_usp=False
):
if self.path is None:
# Check model_id and origin_file_pattern
if self.model_id is None:
raise ValueError(
f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`."""
)
# Skip if not in rank 0
if use_usp:
import torch.distributed as dist
skip_download = dist.get_rank() != 0
# Check whether the origin path is a folder
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.origin_file_pattern = ""
allow_file_pattern = None
is_folder = True
elif isinstance(
self.origin_file_pattern, str
) and self.origin_file_pattern.endswith("/"):
allow_file_pattern = self.origin_file_pattern + "*"
is_folder = True
else:
allow_file_pattern = self.origin_file_pattern
is_folder = False
# Download
if not skip_download:
downloaded_files = glob.glob(
self.origin_file_pattern,
root_dir=os.path.join(local_model_path, self.model_id),
)
                snapshot_download(
                    self.model_id,
                    repo_type="model",  # use "dataset" when downloading a dataset repo
                    local_dir=os.path.join(local_model_path, self.model_id),
                    allow_patterns=allow_file_pattern,
                    ignore_patterns=downloaded_files,  # skip files that are already present locally
                    local_files_only=False,
                    resume_download=True,  # resume partially downloaded files
                )
# Let rank 1, 2, ... wait for rank 0
if use_usp:
import torch.distributed as dist
dist.barrier(device_ids=[dist.get_rank()])
# Return downloaded files
if is_folder:
self.path = os.path.join(
local_model_path, self.model_id, self.origin_file_pattern
)
else:
self.path = glob.glob(
os.path.join(
local_model_path, self.model_id, self.origin_file_pattern
)
)
if isinstance(self.path, list) and len(self.path) == 1:
self.path = self.path[0]
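# Usage sketch (the file pattern below is hypothetical, shown for illustration only):
#   cfg = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B",
#                     origin_file_pattern="diffusion_pytorch_model*.safetensors")
#   cfg.download_if_necessary(local_model_path="./models")
#   print(cfg.path)  # resolved local path(s) after the snapshot download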
class WanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
super().__init__(
device=device,
torch_dtype=torch_dtype,
height_division_factor=16,
width_division_factor=16,
time_division_factor=4,
time_division_remainder=1,
)
self.scheduler = FlowMatchScheduler(
shift=5, sigma_min=0.0, extra_one_step=True)
# self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
self.text_encoder: WanTextEncoder = None
self.image_encoder: WanImageEncoder = None
# self.pose_encoder: CameraPoseEncoder = None
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.motion_controller: WanMotionControllerModel = None
self.vace: VaceWanModel = None
self.in_iteration_models = ("dit", "motion_controller", "vace")
self.unit_runner = PipelineUnitRunner()
self.units = [
            WanVideoUnit_ShapeChecker(),  # check that the requested shape is valid
# WanVideoUnit_NoiseInitializer(),
WanVideoUnit_InputVideoEmbedder(),
WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedder(),
# WanVideoUnit_FunReference(),
# WanVideoUnit_CameraPoseEmbedder(),
# WanVideoUnit_SpeedControl(),
# WanVideoUnit_VACE(),
WanVideoUnit_UnifiedSequenceParallel(),
# WanVideoUnit_TeaCache(),
# WanVideoUnit_CfgMerger(),
]
self.model_fn = model_fn_wan_video
def training_predict(self, **inputs):
timestep_id = torch.tensor([0])
# print(f"timestep_id: {timestep_id}")
timestep = self.scheduler.timesteps[timestep_id].to(
dtype=self.torch_dtype, device=self.device
)
# print(f"Selected timestep {timestep}")
inputs["latents"] = inputs['rgb_latents']
training_target = self.scheduler.training_target(
inputs["depth_latents"], inputs["rgb_latents"], timestep
)
noise_pred = self.model_fn(**inputs, timestep=timestep)
return {
'rgb_gt': inputs['rgb_latents'],
"depth_gt": training_target,
"pred": noise_pred,
"weight": self.scheduler.training_weight(timestep),
}
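    # A typical training step (a sketch, not prescribed by this file) would combine the
    # returned fields as: loss = weight * mse(pred, depth_gt), computed in latent space.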
def enable_vram_management(
self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5
):
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None:
dtype = next(iter(self.text_encoder.parameters())).dtype
enable_vram_management(
self.text_encoder,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
T5RelativeEmbedding: AutoWrappedModule,
T5LayerNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.dit is not None:
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: WanAutoCastLayerNorm,
RMSNorm: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
self.vae,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
RMS_norm: AutoWrappedModule,
CausalConv3d: AutoWrappedModule,
Upsample: AutoWrappedModule,
torch.nn.SiLU: AutoWrappedModule,
torch.nn.Dropout: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
if self.image_encoder is not None:
dtype = next(iter(self.image_encoder.parameters())).dtype
enable_vram_management(
self.image_encoder,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=dtype,
computation_device=self.device,
),
)
if self.motion_controller is not None:
dtype = next(iter(self.motion_controller.parameters())).dtype
enable_vram_management(
self.motion_controller,
module_map={
torch.nn.Linear: AutoWrappedLinear,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=dtype,
computation_device=self.device,
),
)
        if self.vace is not None:
            dtype = next(iter(self.vace.parameters())).dtype
            device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.vace,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
def initialize_usp(self):
import torch.distributed as dist
from xfuser.core.distributed import (init_distributed_environment,
initialize_model_parallel)
dist.init_process_group(backend="nccl", init_method="env://")
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size()
)
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=1,
ulysses_degree=dist.get_world_size(),
)
torch.cuda.set_device(dist.get_rank())
def enable_usp(self):
from xfuser.core.distributed import get_sequence_parallel_world_size
from ..distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
for block in self.dit.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn
)
self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(
model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"
),
local_model_path: str = "./models",
skip_download: bool = False,
redirect_common_files: bool = True,
use_usp=False,
):
# Redirect model path
if redirect_common_files:
redirect_dict = {
"models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
}
for model_config in model_configs:
if (
model_config.origin_file_pattern is None
or model_config.model_id is None
):
continue
if (
model_config.origin_file_pattern in redirect_dict
and model_config.model_id
!= redirect_dict[model_config.origin_file_pattern]
):
print(
f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection."
)
model_config.model_id = redirect_dict[
model_config.origin_file_pattern
]
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
if use_usp:
pipe.initialize_usp()
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
model_config.download_if_necessary(
local_model_path, skip_download=skip_download, use_usp=use_usp
)
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
torch_dtype=model_config.offload_dtype or torch_dtype,
)
# Load models
# pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
pipe.dit = model_manager.fetch_model("wan_video_dit")
pipe.vae = model_manager.fetch_model("wan_video_vae")
pipe.image_encoder = model_manager.fetch_model(
"wan_video_image_encoder")
pipe.motion_controller = model_manager.fetch_model(
"wan_video_motion_controller"
)
pipe.vace = model_manager.fetch_model("wan_video_vace")
# Initialize tokenizer
tokenizer_config.download_if_necessary(
local_model_path, skip_download=skip_download
)
# pipe.prompter.fetch_models(pipe.text_encoder)
# pipe.prompter.fetch_tokenizer(tokenizer_config.path)
# Unified Sequence Parallel
if use_usp:
pipe.enable_usp()
return pipe
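    # Usage sketch (the model_configs and file patterns are placeholders, adapt them to your checkpoints):
    #   pipe = WanVideoPipeline.from_pretrained(
    #       torch_dtype=torch.bfloat16,
    #       device="cuda",
    #       model_configs=[
    #           ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="*.safetensors"),
    #           ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
    #       ],
    #   )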
# @torch.no_grad()
@torch.inference_mode()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: Optional[str] = "",
# Image-to-video
input_image: Optional[Image.Image] = None,
# First-last-frame-to-video
end_image: Optional[Image.Image] = None,
# Video-to-video
        input_video: Optional[list[list[Image.Image]]] = None,
denoising_strength: Optional[float] = 1.0,
# ControlNet
reference_image: Optional[Image.Image] = None,
        extra_images: Optional[List[List[Image.Image]]] = None,
        extra_image_frame_index: Optional[torch.Tensor] = None,
# VACE
vace_video: Optional[list[Image.Image]] = None,
vace_video_mask: Optional[Image.Image] = None,
vace_reference_image: Optional[Image.Image] = None,
vace_scale: Optional[float] = 1.0,
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
mode: Optional[str] = "regression",
batch_size: Optional[int] = 1,
height: Optional[int] = 480,
width: Optional[int] = 720,
frame_mask: Optional[torch.Tensor] = None,
num_frames=41,
# Classifier-free guidance
cfg_scale: Optional[float] = 1,
cfg_merge: Optional[bool] = False,
# Scheduler
num_inference_steps: Optional[int] = 1,
sigma_shift: Optional[float] = 5.0,
denoise_step=1,
# Speed control
motion_bucket_id: Optional[int] = None,
# VAE tiling
tiled: Optional[bool] = False,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# Sliding window
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
# Teacache
tea_cache_l1_thresh: Optional[float] = None,
tea_cache_model_id: Optional[str] = "",
# progress_bar
progress_bar_cmd=tqdm,
):
self.scheduler.set_timesteps(
num_inference_steps=num_inference_steps,
denoising_strength=denoising_strength,
shift=sigma_shift,
denoise_step=denoise_step,
)
# Inputs
inputs_posi = {
"prompt": prompt,
"prompt_num": batch_size,
"tea_cache_l1_thresh": tea_cache_l1_thresh,
"tea_cache_model_id": tea_cache_model_id,
"num_inference_steps": num_inference_steps,
}
inputs_nega = {
"negative_prompt": negative_prompt,
"prompt_num": batch_size,
"tea_cache_l1_thresh": tea_cache_l1_thresh,
"tea_cache_model_id": tea_cache_model_id,
"num_inference_steps": num_inference_steps,
}
inputs_shared = {
"batch_size": batch_size,
"input_image": input_image,
"end_image": end_image,
"input_video": input_video,
"denoising_strength": denoising_strength,
"reference_image": reference_image,
"vace_video": vace_video,
"vace_video_mask": vace_video_mask,
"vace_reference_image": vace_reference_image,
"vace_scale": vace_scale,
"seed": seed,
"rand_device": rand_device,
'mode': mode,
"height": height,
"width": width,
"frame_mask": frame_mask,
"num_frames": num_frames,
"cfg_scale": cfg_scale,
"cfg_merge": cfg_merge,
"sigma_shift": sigma_shift,
"motion_bucket_id": motion_bucket_id,
"tiled": tiled,
"tile_size": tile_size,
"tile_stride": tile_stride,
"sliding_window_size": sliding_window_size,
"sliding_window_stride": sliding_window_stride,
"extra_images": extra_images,
"extra_image_frame_index": extra_image_frame_index,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
models = {name: getattr(self, name)
for name in self.in_iteration_models}
for timestep in self.scheduler.timesteps:
timestep = timestep.unsqueeze(0).to(
dtype=self.torch_dtype, device=self.device
)
# torch.cuda.synchronize()
# start_time = time.time()
noise_pred_posi = self.model_fn(
**models, **inputs_shared, **inputs_posi, timestep=timestep
)
# torch.cuda.synchronize()
# end_time = time.time()
# print(f"Model forward time: {end_time - start_time}")
noise_pred = noise_pred_posi
inputs_shared["latents"] = self.scheduler.step(
model_output=noise_pred,
sample=inputs_shared["latents"],
)
rgb, depth = None, None
if isinstance(inputs_shared['latents'], tuple):
rgb, depth = inputs_shared['latents']
else:
depth = inputs_shared['latents']
# VACE (TODO: remove it)
if vace_reference_image is not None:
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
# torch.cuda.synchronize()
# start_time = time.time()
depth_video = self.vae.decode(
depth,
device=self.device,
tiled=tiled,
tile_size=tile_size,
tile_stride=tile_stride,
)
# torch.cuda.synchronize()
# end_time = time.time()
# print(f"VAE decoding time: {end_time - start_time}")
depth_video = self.vae_output_to_video(depth_video)
rgb_video = None
if rgb is not None:
rgb_video = self.vae.decode(
                rgb,
device=self.device,
tiled=tiled,
tile_size=tile_size,
tile_stride=tile_stride,
)
rgb_video = self.vae_output_to_video(rgb_video)
return {
'depth': depth_video,
'rgb': rgb_video
}
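    # Inference sketch: with num_inference_steps=1 the pipeline regresses depth from an
    # input RGB clip, e.g.
    #   out = pipe(prompt="", input_video=[frames], mode="regression", num_frames=41,
    #              height=480, width=720, seed=0)
    #   depth_video, rgb_video = out["depth"], out["rgb"]  # float arrays in [0, 1]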
class PipelineUnit:
def __init__(
self,
seperate_cfg: bool = False,
take_over: bool = False,
input_params: tuple[str] = None,
input_params_posi: dict[str, str] = None,
input_params_nega: dict[str, str] = None,
onload_model_names: tuple[str] = None,
):
self.seperate_cfg = seperate_cfg
self.take_over = take_over
self.input_params = input_params
self.input_params_posi = input_params_posi
self.input_params_nega = input_params_nega
self.onload_model_names = onload_model_names
def process(
self, pipe: WanVideoPipeline, inputs: dict, positive=True, **kwargs
) -> dict:
raise NotImplementedError("`process` is not implemented.")
class PipelineUnitRunner:
def __init__(self):
pass
def __call__(
self,
unit: PipelineUnit,
pipe: WanVideoPipeline,
inputs_shared: dict,
inputs_posi: dict,
inputs_nega: dict,
    ) -> tuple[dict, dict, dict]:
if unit.take_over:
# Let the pipeline unit take over this function.
inputs_shared, inputs_posi, inputs_nega = unit.process(
pipe,
inputs_shared=inputs_shared,
inputs_posi=inputs_posi,
inputs_nega=inputs_nega,
)
elif unit.seperate_cfg:
# Positive side
processor_inputs = {
name: inputs_posi.get(name_)
for name, name_ in unit.input_params_posi.items()
}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_posi.update(processor_outputs)
# Negative side
if inputs_shared["cfg_scale"] != 1:
processor_inputs = {
name: inputs_nega.get(name_)
for name, name_ in unit.input_params_nega.items()
}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_nega.update(processor_outputs)
else:
inputs_nega.update(processor_outputs)
else:
processor_inputs = {
name: inputs_shared.get(name) for name in unit.input_params
}
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_shared.update(processor_outputs)
return inputs_shared, inputs_posi, inputs_nega
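# Unit dispatch summary: `take_over` units receive and return all three input dicts,
# `seperate_cfg` units run once per CFG branch (the negative side only when cfg_scale != 1),
# and plain units read from and write into the shared inputs.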
class WanVideoUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width", "num_frames"))
def process(self, pipe: WanVideoPipeline, height, width, num_frames):
# print(
# f"Init WanVideoPipeline with height={height}, width={width}, num_frames={num_frames}."
# )
height, width, num_frames = pipe.check_resize_height_width(
height, width, num_frames
)
# print(
# f"Resized WanVideoPipeline to height={height}, width={width}, num_frames={num_frames}."
# )
return {"height": height, "width": width, "num_frames": num_frames}
class WanVideoUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=(
"batch_size",
"height",
"width",
"num_frames",
"seed",
"rand_device",
"vace_reference_image",
)
)
def process(
self,
pipe: WanVideoPipeline,
batch_size,
height,
width,
num_frames,
seed,
rand_device,
vace_reference_image,
):
# print(f"num frames {num_frames}")
length = (num_frames - 1) // 4 + 1
if vace_reference_image is not None:
length += 1
# TODO
noise = pipe.generate_noise(
(batch_size, 16, length, height // 8, width // 8),
seed=seed,
rand_device=rand_device,
)
# print(f"Noise shape {noise.shape} ")
return {"noise": noise, "latents": noise}
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):  # Encodes input RGB/disparity videos into VAE latents (training and inference)
def __init__(self):
super().__init__(
input_params=(
'mode',
'seed',
'rand_device',
"batch_size",
"height",
"width",
"num_frames",
"input_video",
"input_disp",
"noise",
"tiled",
"tile_size",
"tile_stride",
"vace_reference_image",
),
onload_model_names=("vae",),
)
def process(
self,
pipe,
mode,
seed,
rand_device,
batch_size,
height,
width,
num_frames,
input_video,
input_disp,
noise,
tiled,
tile_size,
tile_stride,
vace_reference_image,
):
assert mode in ['generation',
'regression'], f"mode {mode} is not supported"
length = (num_frames - 1) // 4 + 1
# inference part
if not pipe.scheduler.training:
if mode == 'generation':
# only need noise
noise = pipe.generate_noise(
(batch_size, 16, length, height // 8, width // 8),
seed=seed,
rand_device=rand_device,
)
return {'latents': noise}
else:
# only need rgb latent
video_list = []
for _input_video in input_video:
_preprocessed_video = pipe.preprocess_video(_input_video)
video_list.append(_preprocessed_video)
videos_tensor = torch.cat(video_list, dim=0)
# print(f"videos_tensor shape: {videos_tensor.shape}")
input_rgb_latents = pipe.vae.encode(
videos_tensor,
device=pipe.device,
tiled=tiled,
tile_size=tile_size,
tile_stride=tile_stride,
).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"latents": input_rgb_latents}
disp_list = []
for _input_disp in input_disp:
_preprocessed_disp = pipe.preprocess_video(_input_disp)
disp_list.append(_preprocessed_disp)
disp_tensor = torch.cat(disp_list, dim=0)
input_disp_latents = pipe.vae.encode(
disp_tensor,
device=pipe.device,
tiled=tiled,
tile_size=tile_size,
tile_stride=tile_stride,
).to(dtype=pipe.torch_dtype, device=pipe.device)
# Training
if mode == 'generation':
# need noise + depth
noise = pipe.generate_noise(
(batch_size, 16, length, height // 8, width // 8),
seed=seed,
rand_device=rand_device,
)
return {'rgb_latents': noise, 'depth_latents': input_disp_latents}
else:
# need rgb + depth
video_list = []
for _input_video in input_video:
_preprocessed_video = pipe.preprocess_video(_input_video)
video_list.append(_preprocessed_video)
videos_tensor = torch.cat(video_list, dim=0)
input_rgb_latents = pipe.vae.encode(
videos_tensor,
device=pipe.device,
tiled=tiled,
tile_size=tile_size,
tile_stride=tile_stride,
).to(dtype=pipe.torch_dtype, device=pipe.device)
# del videos_tensor
return {
"rgb_latents": input_rgb_latents,
"depth_latents": input_disp_latents,
}
class WanVideoUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={
"prompt": "prompt",
"positive": "positive",
"prompt_num": "prompt_num",
},
input_params_nega={
"prompt": "negative_prompt",
"positive": "positive",
"prompt_num": "prompt_num",
},
onload_model_names=("text_encoder",),
)
def process(self, pipe: WanVideoPipeline, prompt, positive, prompt_num) -> dict:
# pipe.load_models_to_device(self.onload_model_names)
prompt_emb = []
# print(f"Encoding prompt: {prompt}")
# if isinstance(prompt, str):
# prompt = [prompt] * prompt_num
# prompt_emb = None
# for _prompt in prompt:
# _prompt_emb = pipe.prompter.encode_prompt(
# _prompt, positive=positive, device=pipe.device
# )
# prompt_emb = _prompt_emb
# break
# prompt_emb = prompt_emb.repeat(prompt_num,1,1)
# # prompt_emb = torch.cat(prompt_emb, dim=0)
# prompt_emb = prompt_emb.to(dtype=pipe.torch_dtype, device=pipe.device)
# print(f"Prompt embedding shape: {prompt_emb.shape}")
zero_pad = torch.zeros([prompt_num, 512, 4096])
zero_pad = zero_pad.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"context": zero_pad}
# return {"context": prompt_emb}
class WanVideoUnit_ImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=(
"input_image",
"end_image",
"num_frames",
"height",
"width",
"tiled",
"tile_size",
"tile_stride",
"extra_images",
"extra_image_frame_index",
),
onload_model_names=("image_encoder", "vae"),
)
def process(
self,
pipe: WanVideoPipeline,
input_image,
end_image,
num_frames,
height,
width,
tiled,
tile_size,
tile_stride,
extra_images,
extra_image_frame_index,
):
# print(f"input image shape{input_image.shape} ")
if not pipe.dit.has_image_input:
return {}
if input_image is None:
return {}
# pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image).to(pipe.device) # B C H W
batch_size = image.shape[0]
clip_context = pipe.image_encoder.encode_image([image])
msk = torch.ones(
batch_size, num_frames, height // 8, width // 8, device=pipe.device
)
# print(
# f"tiled, tile size, tile stride: {tiled}, {tile_size}, {tile_stride}")
        # Assume an input image is always provided
vae_input = torch.concat(
[
image.unsqueeze(2), # B C 1 H W
torch.zeros(batch_size, 3, num_frames - 1, height, width).to(
image.device
),
],
dim=2,
) # B C F H W
vae_input = vae_input.permute(0, 2, 1, 3, 4).contiguous() # B F C H W
if (
extra_images is not None
and extra_image_frame_index is not None
):
# print(f"extra images shape {extra_images.shape}")
for _videoid, _video in enumerate(extra_images):
# _video F C H W
for idx, image in enumerate(_video):
if idx == 0:
continue
image = pipe.preprocess_image(
image).to(pipe.device) # 1 C H W
vae_input[_videoid, idx] = image.squeeze(0)
mask = extra_image_frame_index[:, :, None, None].to(
pipe.device) # B F 1 1
mask = mask.expand(
batch_size, mask.shape[1], height // 8, width // 8
) # B F H W
msk = msk * mask
else:
msk[:, 1:] = 0
msk = torch.concat(
[torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1
)
msk = msk.view(
batch_size, msk.shape[1] // 4, 4, height // 8, width // 8
) # B F C(4) H W
msk = msk.transpose(1, 2)
vae_input = vae_input.permute(0, 2, 1, 3, 4).contiguous() # B C F H W
y = pipe.vae.encode(
vae_input.to(dtype=pipe.torch_dtype, device=pipe.device),
device=pipe.device,
tiled=tiled,
tile_size=tile_size,
tile_stride=tile_stride,
)
# print(f"y shape after VAE encode: {y.shape}")
# print(f"after VAE encode, y shape: {y.shape}")
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
# print()
        y = torch.concat([msk, y], dim=1)  # B 4+16 F H W (mask channels first, then VAE latents)
# print(f"after concat, y shape: {y.shape}")
# y = y.unsqueeze(0)
clip_context = clip_context.to(
dtype=pipe.torch_dtype, device=pipe.device)
# print(f"clip context shape: {clip_context.shape}")
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"clip_feature": clip_context, "y": y}
class WanVideoUnit_VACE(PipelineUnit):
def __init__(self):
super().__init__(
input_params=(
"vace_video",
"vace_video_mask",
"vace_reference_image",
"vace_scale",
"height",
"width",
"num_frames",
"tiled",
"tile_size",
"tile_stride",
),
onload_model_names=("vae",),
)
def process(
self,
pipe: WanVideoPipeline,
vace_video,
vace_video_mask,
vace_reference_image,
vace_scale,
height,
width,
num_frames,
tiled,
tile_size,
tile_stride,
):
if (
vace_video is not None
or vace_video_mask is not None
or vace_reference_image is not None
):
# pipe.load_models_to_device(["vae"])
if vace_video is None:
vace_video = torch.zeros(
(1, 3, num_frames, height, width),
dtype=pipe.torch_dtype,
device=pipe.device,
)
else:
vace_video = pipe.preprocess_video(vace_video)
if vace_video_mask is None:
vace_video_mask = torch.ones_like(vace_video)
else:
vace_video_mask = pipe.preprocess_video(
vace_video_mask, min_value=0, max_value=1
)
inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
inactive = pipe.vae.encode(
inactive,
device=pipe.device,
tiled=tiled,
tile_size=tile_size,
tile_stride=tile_stride,
).to(dtype=pipe.torch_dtype, device=pipe.device)
reactive = pipe.vae.encode(
reactive,
device=pipe.device,
tiled=tiled,
tile_size=tile_size,
tile_stride=tile_stride,
).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_video_latents = torch.concat((inactive, reactive), dim=1)
vace_mask_latents = rearrange(
vace_video_mask[0, 0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8
)
vace_mask_latents = torch.nn.functional.interpolate(
vace_mask_latents,
size=(
(vace_mask_latents.shape[2] + 3) // 4,
vace_mask_latents.shape[3],
vace_mask_latents.shape[4],
),
mode="nearest-exact",
)
if vace_reference_image is None:
pass
else:
vace_reference_image = pipe.preprocess_video(
[vace_reference_image])
vace_reference_latents = pipe.vae.encode(
vace_reference_image,
device=pipe.device,
tiled=tiled,
tile_size=tile_size,
tile_stride=tile_stride,
).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = torch.concat(
(vace_reference_latents, torch.zeros_like(
vace_reference_latents)),
dim=1,
)
vace_video_latents = torch.concat(
(vace_reference_latents, vace_video_latents), dim=2
)
vace_mask_latents = torch.concat(
(torch.zeros_like(
vace_mask_latents[:, :, :1]), vace_mask_latents),
dim=2,
)
vace_context = torch.concat(
(vace_video_latents, vace_mask_latents), dim=1)
return {"vace_context": vace_context, "vace_scale": vace_scale}
else:
# print(f"No VACE video, mask or reference image provided, skipping VACE.")
return {"vace_context": None, "vace_scale": vace_scale}
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
def __init__(self):
super().__init__(input_params=())
def process(self, pipe: WanVideoPipeline):
if hasattr(pipe, "use_unified_sequence_parallel"):
if pipe.use_unified_sequence_parallel:
return {"use_unified_sequence_parallel": True}
return {}
class WanVideoUnit_CfgMerger(PipelineUnit):
def __init__(self):
super().__init__(take_over=True)
self.concat_tensor_names = ["context",
"clip_feature", "y", "reference_latents"]
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if not inputs_shared["cfg_merge"]:
# print(f"Skipping CFG merge, cfg_merge is set to False.")
return inputs_shared, inputs_posi, inputs_nega
for name in self.concat_tensor_names:
tensor_posi = inputs_posi.get(name)
tensor_nega = inputs_nega.get(name)
tensor_shared = inputs_shared.get(name)
if tensor_posi is not None and tensor_nega is not None:
inputs_shared[name] = torch.concat(
(tensor_posi, tensor_nega), dim=0)
elif tensor_shared is not None:
inputs_shared[name] = torch.concat(
(tensor_shared, tensor_shared), dim=0
)
inputs_posi.clear()
inputs_nega.clear()
return inputs_shared, inputs_posi, inputs_nega
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [
-5.21862437e04,
9.23041404e03,
-5.28275948e02,
1.36987616e01,
-4.99875664e-02,
],
"Wan2.1-T2V-14B": [
-3.03318725e05,
4.90537029e04,
-2.65530556e03,
5.87365115e01,
-3.15583525e-01,
],
"Wan2.1-I2V-14B-480P": [
2.57151496e05,
-3.54229917e04,
1.40286849e03,
-1.35890334e01,
1.32517977e-01,
],
"Wan2.1-I2V-14B-720P": [
8.10705460e03,
2.13393892e03,
-3.72934672e02,
1.66203073e01,
-4.17769401e-02,
],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join(
[i for i in self.coefficients_dict])
raise ValueError(
f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids})."
)
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(
(
(modulated_inp - self.previous_modulated_input).abs().mean()
/ self.previous_modulated_input.abs().mean()
)
.cpu()
.item()
)
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step += 1
if self.step == self.num_inference_steps:
self.step = 0
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
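# TeaCache protocol (as used in model_fn_wan_video below): check() returns True when the
# DiT blocks can be skipped, in which case update() adds the cached residual to the hidden
# states; otherwise the blocks run and store() caches the new residual for later steps.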
class TemporalTiler_BCTHW:
def __init__(self):
pass
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip(
(torch.arange(border_width) + 1) / border_width, dims=(0,)
)
return x
def build_mask(self, data, is_bound, border_width):
_, _, T, _, _ = data.shape
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
mask = repeat(t, "T -> 1 1 T 1 1")
return mask
def run(
self,
model_fn,
sliding_window_size,
sliding_window_stride,
computation_device,
computation_dtype,
model_kwargs,
tensor_names,
batch_size=None,
):
tensor_names = [
tensor_name
for tensor_name in tensor_names
if model_kwargs.get(tensor_name) is not None
]
tensor_dict = {
tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names
}
B, C, T, H, W = tensor_dict[tensor_names[0]].shape
if batch_size is not None:
B *= batch_size
data_device, data_dtype = (
tensor_dict[tensor_names[0]].device,
tensor_dict[tensor_names[0]].dtype,
)
value = torch.zeros(
(B, C, T, H, W), device=data_device, dtype=data_dtype)
weight = torch.zeros(
(1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
for t in range(0, T, sliding_window_stride):
if (
t - sliding_window_stride >= 0
and t - sliding_window_stride + sliding_window_size >= T
):
continue
t_ = min(t + sliding_window_size, T)
model_kwargs.update(
{
                    tensor_name: tensor_dict[tensor_name][:, :, t:t_].to(
device=computation_device, dtype=computation_dtype
)
for tensor_name in tensor_names
}
)
model_output = model_fn(**model_kwargs).to(
device=data_device, dtype=data_dtype
)
mask = self.build_mask(
model_output,
is_bound=(t == 0, t_ == T),
border_width=(sliding_window_size - sliding_window_stride,),
).to(device=data_device, dtype=data_dtype)
value[:, :, t:t_, :, :] += model_output * mask
weight[:, :, t:t_, :, :] += mask
value /= weight
model_kwargs.update(tensor_dict)
return value
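# Sliding-window temporal tiling: each window of `sliding_window_size` latent frames is
# denoised separately and blended back with linearly ramped weights at the window borders,
# so overlapping regions are averaged rather than stitched with hard seams.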
def model_fn_wan_video(
dit: WanModel,
motion_controller: WanMotionControllerModel = None,
vace: VaceWanModel = None,
latents: torch.Tensor = None,
timestep: torch.Tensor = None,
context: torch.Tensor = None,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
reference_latents=None,
vace_context=None,
vace_scale=1.0,
tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False,
motion_bucket_id: Optional[torch.Tensor] = None,
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
cfg_merge: bool = False,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
model_kwargs = dict(
dit=dit,
motion_controller=motion_controller,
vace=vace,
latents=latents,
timestep=timestep,
context=context,
clip_feature=clip_feature,
y=y,
reference_latents=reference_latents,
vace_context=vace_context,
vace_scale=vace_scale,
tea_cache=tea_cache,
use_unified_sequence_parallel=use_unified_sequence_parallel,
motion_bucket_id=motion_bucket_id,
)
return TemporalTiler_BCTHW().run(
model_fn_wan_video,
sliding_window_size,
sliding_window_stride,
latents.device,
latents.dtype,
model_kwargs=model_kwargs,
tensor_names=["latents", "y"],
batch_size=2 if cfg_merge else 1,
)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
# x = latents
# print(f"Receving x with shape{x.shape}")
# print(f"timesteps {timestep}", end=" ")
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
# print(f"t_mod shape: {t_mod.shape}")
# print(f"first ten element{t_mod[0][:10]}")
if motion_bucket_id is not None and motion_controller is not None:
t_mod = t_mod + \
motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
context = dit.text_embedding(context)
# c_b, c_c, c_f, c_h, c_w = x.shape
# Merged cfg
if latents.shape[0] != context.shape[0]:
latents = torch.concat([latents] * context.shape[0], dim=0)
# print(f"Merging x to shape {x.shape}")
if timestep.shape[0] != context.shape[0]:
timestep = torch.concat([timestep] * context.shape[0], dim=0)
# import pdb
# pdb.set_trace()
if dit.has_image_input:
latents = torch.cat([latents, y], dim=1) # (b, c_x + c_y, f, h, w)
        clip_embedding = dit.img_emb(clip_feature)
        context = torch.cat([clip_embedding, context], dim=1)
latents, (f, h, w) = dit.patchify(latents, None)
_shortcut = latents
freqs = (
torch.cat(
[
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
)
.reshape(f * h * w, 1, -1)
.to(latents.device)
)
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, latents, t_mod)
else:
tea_cache_update = False
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
latents = torch.chunk(latents, get_sequence_parallel_world_size(), dim=1)[
get_sequence_parallel_rank()
]
if tea_cache_update:
latents = tea_cache.update(latents)
else:
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
for idx, block in enumerate(dit.blocks):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
latents = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
latents,
context,
t_mod,
freqs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
latents = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
latents,
context,
t_mod,
freqs,
use_reentrant=False,
)
else:
latents = block(latents, context, t_mod, freqs)
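            # NOTE: `vace_hints` is not computed anywhere in this file; upstream DiffSynth
            # derives it from the VACE model before this loop (roughly
            # `vace_hints = vace(latents, vace_context, context, t_mod, freqs)`, signature
            # assumed). With the VACE unit disabled above, vace_context stays None and this
            # branch is never taken.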
if vace_context is not None and idx in vace.vace_layers_mapping:
current_vace_hint = vace_hints[vace.vace_layers_mapping[idx]]
if (
use_unified_sequence_parallel
and dist.is_initialized()
and dist.get_world_size() > 1
):
current_vace_hint = torch.chunk(
current_vace_hint, get_sequence_parallel_world_size(), dim=1
)[get_sequence_parallel_rank()]
latents = latents + current_vace_hint * vace_scale
if tea_cache is not None:
tea_cache.store(latents)
latents = dit.head(latents, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
latents = get_sp_group().all_gather(latents, dim=1)
# Remove reference latents
if reference_latents is not None:
latents = latents[:, reference_latents.shape[1]:]
f -= 1
latents = dit.unpatchify(latents, (f, h, w))
return latents
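

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the original pipeline): the model ids and file
    # patterns below are placeholders and must be adapted to the checkpoints you actually use;
    # running this requires a CUDA GPU and the corresponding downloaded weights.
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="*.safetensors"),
            ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
        ],
    )
    # A dummy 41-frame clip at 480x720; real usage would load frames from a video file.
    frames = [Image.new("RGB", (720, 480)) for _ in range(41)]
    out = pipe(prompt="", input_video=[frames], mode="regression",
               num_frames=41, height=480, width=720, seed=0)
    print(out["depth"].shape)  # B T H W C array of depth frames in [0, 1]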