Wan2GP / plugins /wan2gp-flashvsr /src /pipelines /flashvsr_tiny_long.py
Egnalkram's picture
Upload folder using huggingface_hub
4689c2b verified
# Copyright 2025 Wan2GP Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
import os
import time
from typing import Optional, Tuple, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from PIL import Image
from tqdm import tqdm
import types
import imageio
# import pyfiglet
from ..models import ModelManager
from ..models.utils import clean_vram, FrameStreamBuffer, TensorAsBuffer, tensor_to_imageio_frame
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
# -----------------------------
# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet)
# -----------------------------
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
N, C = feat.shape[:2]
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
std = var.sqrt().view(N, C, 1, 1)
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return mean, std
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
size = content_feat.size()
style_mean, style_std = _calc_mean_std(style_feat)
content_mean, content_std = _calc_mean_std(content_feat)
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized * style_std.expand(size) + style_mean.expand(size)
# -----------------------------
# 小波式模糊与分解/重构(ColorCorrector 用)
# -----------------------------
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125 ],
[0.0625, 0.125, 0.0625],
]
return torch.tensor(vals, dtype=dtype, device=device)
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
N, C, H, W = x.shape
base = _make_gaussian3x3_kernel(x.dtype, x.device)
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
pad = radius
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
return out
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
high = torch.zeros_like(x)
low = x
for i in range(levels):
radius = 2 ** i
blurred = _wavelet_blur(low, radius)
high = high + (low - blurred)
low = blurred
return high, low
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
c_high, _ = _wavelet_decompose(content, levels=levels)
_, s_low = _wavelet_decompose(style, levels=levels)
return c_high + s_low
# -----------------------------
# 无状态颜色矫正模块(视频友好,默认 wavelet)
# -----------------------------
class TorchColorCorrectorWavelet(nn.Module):
def __init__(self, levels: int = 5):
super().__init__()
self.levels = levels
@staticmethod
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
B, C, f, H, W = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
return y, B, f
@staticmethod
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
BF, C, H, W = y.shape
assert BF == B * f
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
def forward(
self,
hq_image: torch.Tensor, # (B, C, f, H, W)
lq_image: torch.Tensor, # (B, C, f, H, W)
clip_range: Tuple[float, float] = (-1.0, 1.0),
method: Literal['wavelet', 'adain'] = 'wavelet',
chunk_size: Optional[int] = None,
) -> torch.Tensor:
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
B, C, f, H, W = hq_image.shape
if chunk_size is None or chunk_size >= f:
hq4, B, f = self._flatten_time(hq_image)
lq4, _, _ = self._flatten_time(lq_image)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out = self._unflatten_time(out4, B, f)
return out
outs = []
for start in range(0, f, chunk_size):
end = min(start + chunk_size, f)
hq_chunk = hq_image[:, :, start:end]
lq_chunk = lq_image[:, :, start:end]
hq4, B_, f_ = self._flatten_time(hq_chunk)
lq4, _, _ = self._flatten_time(lq_chunk)
if method == 'wavelet':
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
elif method == 'adain':
out4 = _adain(hq4, lq4)
else:
raise ValueError(f"未知 method: {method}")
out4 = torch.clamp(out4, *clip_range)
out_chunk = self._unflatten_time(out4, B_, f_)
outs.append(out_chunk)
out = torch.cat(outs, dim=2)
return out
# -----------------------------
# 简化版 Pipeline(仅 dit + vae)
# -----------------------------
class FlashVSRTinyLongPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
self.prompt_emb_posi = None
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
print(r"""
███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗
██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗
█████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗
██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝
██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝
╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝
""")
def enable_vram_management(self, num_persistent_param_in_dit=None):
# 仅管理 dit / vae
dtype = next(iter(self.dit.parameters())).dtype
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
enable_vram_management(
self.dit,
module_map={
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config=dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = FlashVSRTinyLongPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
# 可选:统一序列并行入口(此处默认关闭)
pipe.use_unified_sequence_parallel = False
return pipe
def denoising_model(self):
return self.dit
# -------------------------
# 新增:显式 KV 预初始化函数
# -------------------------
def init_cross_kv(
self,
context_tensor: Optional[torch.Tensor] = None,
prompt_path = None
):
self.load_models_to_device(["dit"])
"""
使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
必须在 __call__ 前显式调用一次。
"""
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
if self.dit is None:
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
if context_tensor is None:
if prompt_path is None:
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
ctx = torch.load(prompt_path, map_location=self.device)
else:
ctx = context_tensor
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
if self.prompt_emb_posi is None:
self.prompt_emb_posi = {}
self.prompt_emb_posi['context'] = ctx
self.prompt_emb_posi['stats'] = "load"
if hasattr(self.dit, "reinit_cross_kv"):
self.dit.reinit_cross_kv(ctx)
else:
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
# Scheduler
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
self.load_models_to_device([])
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_extra_input(self, latents=None):
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
def decode_video(self, latents, cond=None, **kwargs):
frames = self.TCDecoder.decode_video(
latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W)
parallel=False,
show_progress_bar=False,
cond=cond
).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1
return frames
def offload_model(self, keep_vae=False):
self.dit.clear_cross_kv()
self.prompt_emb_posi['stats'] = "offload"
self.load_models_to_device([])
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.to('cpu')
if not keep_vae:
self.TCDecoder.to('cpu')
@torch.no_grad()
def __call__(
self,
prompt=None,
negative_prompt="",
denoising_strength=1.0,
seed=None,
rand_device="gpu",
height=480,
width=832,
num_frames=81,
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
tiled=True,
tile_size=(60, 104),
tile_stride=(30, 52),
tea_cache_l1_thresh=None,
tea_cache_model_id="Wan2.1-T2V-1.3B",
progress_bar_cmd=tqdm,
progress_bar_st=None,
LQ_video=None,
buffer_size: int = 40,
is_full_block=False,
if_buffer=False,
topk_ratio=2.0,
kv_ratio=3.0,
local_range = 9,
color_fix = True,
unload_dit = False,
force_offload = False,
fps=30,
quality=6,
output_path=None,
**kwargs,
):
# 只接受 cfg=1.0(与原代码一致)
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
# 要求:必须先 init_cross_kv()
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
raise RuntimeError(
"Cross-Attention KV not initialized. Please call __call__ only after:\n"
" pipe.init_cross_kv()\n"
"Or provide a custom context:\n"
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
)
if isinstance(LQ_video, types.GeneratorType):
lq_buffer = FrameStreamBuffer(LQ_video, buffer_size=buffer_size, device=self.device, dtype=self.torch_dtype)
is_stream_input = True
elif isinstance(LQ_video, torch.Tensor):
lq_buffer = TensorAsBuffer(LQ_video)
is_stream_input = False
else:
raise TypeError(f"LQ_video must be torch.Tensor or Generator,But get {type(LQ_video)}")
# 尺寸修正
height, width = self.check_resize_height_width(height, width)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
# print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
# Tiler 参数
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# 初始化噪声
if if_buffer:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
else:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = noise
# Support both file-streaming mode (output_path provided) and in-memory mode (output_path=None)
use_file_writer = output_path is not None
writer = imageio.get_writer(output_path, fps=fps, quality=quality) if use_file_writer else None
collected_frames = [] if not use_file_writer else None
process_total_num = (num_frames - 1) // 8 - 2
is_stream = True
if self.prompt_emb_posi['stats'] == "offload":
self.init_cross_kv(context_tensor=self.prompt_emb_posi['context'])
self.load_models_to_device(["dit"])
self.dit.LQ_proj_in.to(self.device)
self.TCDecoder.to(self.device)
# 清理可能存在的 LQ_proj_in cache
if hasattr(self.dit, "LQ_proj_in"):
self.dit.LQ_proj_in.clear_cache()
frames_total = []
LQ_pre_idx = 0
LQ_cur_idx = 0
self.TCDecoder.clean_mem()
try:
with torch.no_grad():
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
if cur_process_idx == 0:
pre_cache_k = [None] * len(self.dit.blocks)
pre_cache_v = [None] * len(self.dit.blocks)
LQ_latents_list = []
inner_loop_num = 7
for inner_idx in range(inner_loop_num):
start, end = max(0, inner_idx * 4 - 3), (inner_idx + 1) * 4 - 3
lq_chunk = lq_buffer.get_chunk(start, end)
cur = self.denoising_model().LQ_proj_in.stream_forward(lq_chunk)
if cur is None:
continue
LQ_latents_list.append(cur)
# 合并 LQ latents
LQ_latents = [torch.cat([l[i] for l in LQ_latents_list], dim=1) for i in range(len(LQ_latents_list[0]))]
LQ_cur_idx = (inner_loop_num - 1) * 4 - 3
cur_latents = latents[:, :, :6, :, :]
else:
LQ_latents_list = []
inner_loop_num = 2
for inner_idx in range(inner_loop_num):
start = cur_process_idx * 8 + 17 + inner_idx * 4
end = cur_process_idx * 8 + 21 + inner_idx * 4
lq_chunk = lq_buffer.get_chunk(start, end)
cur = self.denoising_model().LQ_proj_in.stream_forward(lq_chunk)
if cur is None:
continue
LQ_latents_list.append(cur)
LQ_latents = [torch.cat([l[i] for l in LQ_latents_list], dim=1) for i in range(len(LQ_latents_list[0]))]
LQ_cur_idx = cur_process_idx * 8 + 21 + (inner_loop_num - 2) * 4
cur_latents = latents[:, :, 4 + cur_process_idx * 2 : 6 + cur_process_idx * 2, :, :]
# 推理(无 motion_controller / vace)
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
self.dit,
x=cur_latents,
timestep=self.timestep,
context=None,
tea_cache=None,
use_unified_sequence_parallel=False,
LQ_latents=LQ_latents,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k,
pre_cache_v=pre_cache_v,
topk_ratio=topk_ratio,
kv_ratio=kv_ratio,
cur_process_idx=cur_process_idx,
t_mod=self.t_mod,
t=self.t,
local_range = local_range,
)
# 更新 latent
cur_latents = cur_latents - noise_pred_posi
# Decode
cur_LQ_frame = lq_buffer.get_chunk(LQ_pre_idx, LQ_cur_idx).to(self.device)
cur_frames = self.decode_video(cur_latents, cond=cur_LQ_frame)
# 颜色校正(wavelet)
try:
if color_fix:
cur_frames = self.ColorCorrector(
cur_frames.to(device=self.device),
cur_LQ_frame,
clip_range=(-1, 1),
chunk_size=None,
method='adain'
)
except:
pass
num_frames_in_chunk = cur_frames.shape[2]
for i in range(num_frames_in_chunk):
single_frame_tensor = cur_frames[0, :, i, :, :]
if use_file_writer:
imageio_frame = tensor_to_imageio_frame(single_frame_tensor)
writer.append_data(imageio_frame)
else:
# Collect frame for in-memory return (same format as tensor_to_imageio_frame)
# Convert from [-1, 1] to [0, 1] range, format (C, H, W) -> (H, W, C)
frame_01 = (single_frame_tensor.clamp(-1, 1) + 1) / 2
collected_frames.append(frame_01.permute(1, 2, 0).cpu())
LQ_pre_idx = LQ_cur_idx
if unload_dit:
del noise_pred_posi, cur_frames, cur_latents, cur_LQ_frame
clean_vram()
# NOTE: Do NOT clear ANY state between iterations within the same pipeline call!
# The streaming design expects continuous state across iterations:
# - LQ_proj_in: clip_idx and caches must persist for correct shape matching
# - TCDecoder: mem must persist to avoid 3-frame trim penalty on each iteration
#
# State is cleared only at the START of a new pipeline call (already done above).
# Clearing inside the loop causes:
# - LQ_proj_in.clear_cache(): Warmup consumes iteration → shape mismatch
# - TCDecoder.clean_mem(): 3-frame trim per iteration → major frame loss
#
# Without clearing TCDecoder.mem inside the loop:
# - Iteration 0: trim 3 frames → 21 output frames
# - Iterations 1-49: no trim → 8 frames each → 392 frames
# - Total: ~413 frames (close to 417 input)
if force_offload:
self.offload_model()
except Exception as e:
print(f"Error: {e}")
return False
finally:
if use_file_writer and writer is not None:
writer.close()
# Return based on mode
if use_file_writer:
return True
else:
# Return stacked tensor in (B, C, F, H, W) format for compatibility with tensor2video
if collected_frames:
# Stack frames: list of (H, W, C) -> (F, H, W, C) -> (1, C, F, H, W)
stacked = torch.stack(collected_frames, dim=0) # (F, H, W, C)
# Convert to (1, C, F, H, W) format expected by tensor2video
output_tensor = stacked.permute(3, 0, 1, 2).unsqueeze(0) # (1, C, F, H, W)
# Scale from [0, 1] to [-1, 1] for tensor2video compatibility
output_tensor = output_tensor * 2 - 1
return output_tensor
else:
return None
# -----------------------------
# TeaCache(保留原逻辑;此处默认不启用)
# -----------------------------
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
if should_calc:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step = (self.step + 1) % self.num_inference_steps
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
# -----------------------------
# 简化版模型前向封装(无 vace / 无 motion_controller)
# -----------------------------
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
tea_cache: Optional[TeaCache] = None,
use_unified_sequence_parallel: bool = False,
LQ_latents: Optional[torch.Tensor] = None,
is_full_block: bool = False,
is_stream: bool = False,
pre_cache_k: Optional[list[torch.Tensor]] = None,
pre_cache_v: Optional[list[torch.Tensor]] = None,
topk_ratio: float = 2.0,
kv_ratio: float = 3.0,
cur_process_idx: int = 0,
t_mod : torch.Tensor = None,
t : torch.Tensor = None,
local_range: int = 9,
**kwargs,
):
# patchify
x, (f, h, w) = dit.patchify(x)
win = (2, 8, 8)
seqlen = f // win[0]
local_num = seqlen
window_size = win[0] * h * w // 128
square_num = window_size * window_size
topk = int(square_num * topk_ratio) - 1
kv_len = int(kv_ratio)
# RoPE 位置(分段)
if cur_process_idx == 0:
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
else:
freqs = torch.cat([
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# TeaCache(默认不启用)
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
# 统一序列并行(此处默认关闭)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Block 堆叠
if tea_cache_update:
x = tea_cache.update(x)
else:
for block_id, block in enumerate(dit.blocks):
if LQ_latents is not None and block_id < len(LQ_latents):
x = x + LQ_latents[block_id]
x, last_pre_cache_k, last_pre_cache_v = block(
x, context, t_mod, freqs, f, h, w,
local_num, topk,
block_id=block_id,
kv_len=kv_len,
is_full_block=is_full_block,
is_stream=is_stream,
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
local_range = local_range,
)
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
x = dit.head(x, t)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import get_sp_group
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x, pre_cache_k, pre_cache_v