|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import types
|
| import os
|
| import time
|
| from typing import Optional, Tuple, Literal
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import numpy as np
|
| from einops import rearrange
|
| from PIL import Image
|
| from tqdm import tqdm
|
| import types
|
| import imageio
|
|
|
|
|
| from ..models import ModelManager
|
| from ..models.utils import clean_vram, FrameStreamBuffer, TensorAsBuffer, tensor_to_imageio_frame
|
| from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
| from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
| from ..schedulers.flow_match import FlowMatchScheduler
|
| from .base import BasePipeline
|
|
|
|
|
|
|
|
|
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the per-sample, per-channel mean and standard deviation of *feat*.

    Statistics are reduced over the flattened trailing (H, W) dimensions.
    ``eps`` is added to the biased variance before taking the square root.

    Args:
        feat: Tensor of shape (N, C, H, W).
        eps: Constant added to the variance for numerical stability.

    Returns:
        ``(mean, std)``, each reshaped to (N, C, 1, 1).

    Raises:
        AssertionError: If ``feat`` is not 4-dimensional.
    """
    assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
    N, C = feat.shape[:2]
    # reshape (not view) so non-contiguous inputs such as permuted tensors
    # are accepted; flatten once and reuse for both statistics.
    flat = feat.reshape(N, C, -1)
    var = flat.var(dim=2, unbiased=False) + eps
    std = var.sqrt().view(N, C, 1, 1)
    mean = flat.mean(dim=2).view(N, C, 1, 1)
    return mean, std
|
|
|
|
|
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
    """Re-normalise *content_feat* to the channel statistics of *style_feat*.

    The content is standardised with its own per-channel mean/std and then
    scaled and shifted with the style's per-channel std/mean.

    Args:
        content_feat: Tensor whose statistics are replaced, (N, C, H, W).
        style_feat: Tensor providing the target statistics; must share N and C.

    Returns:
        Tensor with the same shape as ``content_feat``.
    """
    assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
    target_mean, target_std = _calc_mean_std(style_feat)
    source_mean, source_std = _calc_mean_std(content_feat)
    # The (N, C, 1, 1) statistics broadcast across H and W.
    whitened = (content_feat - source_mean) / source_std
    return whitened * target_std + target_mean
|
|
|
|
|
|
|
|
|
|
|
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
    """Build the fixed 3x3 smoothing kernel (weights sum to 1).

    Args:
        dtype: dtype of the returned tensor.
        device: Device on which the tensor is created.

    Returns:
        A (3, 3) tensor with 0.25 at the centre, 0.125 on the edges and
        0.0625 in the corners.
    """
    corner, edge, center = 0.0625, 0.125, 0.25
    rows = (
        (corner, edge, corner),
        (edge, center, edge),
        (corner, edge, corner),
    )
    return torch.tensor(rows, dtype=dtype, device=device)
|
|
|
|
|
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
    """Blur each channel of *x* with the 3x3 kernel dilated by *radius*.

    The input is replicate-padded by ``radius`` on every side, so the
    dilated 3x3 convolution returns a tensor of the same spatial size.

    Args:
        x: Tensor of shape (N, C, H, W).
        radius: Dilation of the kernel and amount of padding per side.

    Returns:
        The blurred tensor.
    """
    assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
    channels = x.shape[1]
    kernel = _make_gaussian3x3_kernel(x.dtype, x.device)
    # One copy of the kernel per channel; groups=channels keeps channels independent.
    depthwise_weight = kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    padded = F.pad(x, (radius, radius, radius, radius), mode='replicate')
    return F.conv2d(
        padded,
        depthwise_weight,
        bias=None,
        stride=1,
        padding=0,
        dilation=radius,
        groups=channels,
    )
|
|
|
|
|
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split *x* into accumulated high-frequency detail and a low-frequency base.

    At level ``i`` the running low-pass image is blurred with radius ``2 ** i``;
    the difference before/after the blur is added to the detail component.

    Args:
        x: Tensor of shape (N, C, H, W).
        levels: Number of blur levels.

    Returns:
        ``(high, low)``: the summed detail and the final blurred image.
    """
    assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
    detail = torch.zeros_like(x)
    approx = x
    for level in range(levels):
        smoothed = _wavelet_blur(approx, 2 ** level)
        detail = detail + (approx - smoothed)
        approx = smoothed
    return detail, approx
|
|
|
|
|
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
    """Combine the high-frequency part of *content* with the low-frequency part of *style*.

    Args:
        content: Tensor supplying the detail component, (N, C, H, W).
        style: Tensor supplying the low-frequency component, (N, C, H, W).
        levels: Number of decomposition levels applied to both inputs.

    Returns:
        ``content_high + style_low``.
    """
    detail = _wavelet_decompose(content, levels=levels)[0]
    base = _wavelet_decompose(style, levels=levels)[1]
    return detail + base
|
|
|
|
|
|
|
|
|
|
|
class TorchColorCorrectorWavelet(nn.Module):
    """Frame-wise colour correction of a video against a reference video.

    Each frame of the high-quality video is corrected with the matching frame
    of the low-quality video, either by wavelet recombination or by AdaIN.
    """

    def __init__(self, levels: int = 5):
        """Store the number of wavelet levels used by the ``'wavelet'`` method."""
        super().__init__()
        self.levels = levels

    @staticmethod
    def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """Fold the frame axis into the batch axis: (B, C, f, H, W) -> (B*f, C, H, W).

        Returns the folded tensor together with ``B`` and ``f``.
        """
        assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
        B, C, f, H, W = x.shape
        y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
        return y, B, f

    @staticmethod
    def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
        """Inverse of ``_flatten_time``: (B*f, C, H, W) -> (B, C, f, H, W)."""
        BF, C, H, W = y.shape
        assert BF == B * f
        return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)

    def _correct_clip(
        self,
        hq_clip: torch.Tensor,
        lq_clip: torch.Tensor,
        clip_range: Tuple[float, float],
        method: str,
    ) -> torch.Tensor:
        """Correct one (B, C, f, H, W) clip and clamp the result to ``clip_range``."""
        hq4, B, f = self._flatten_time(hq_clip)
        lq4, _, _ = self._flatten_time(lq_clip)
        if method == 'wavelet':
            out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
        elif method == 'adain':
            out4 = _adain(hq4, lq4)
        else:
            raise ValueError(f"未知 method: {method}")
        out4 = torch.clamp(out4, *clip_range)
        return self._unflatten_time(out4, B, f)

    def forward(
        self,
        hq_image: torch.Tensor,
        lq_image: torch.Tensor,
        clip_range: Tuple[float, float] = (-1.0, 1.0),
        method: Literal['wavelet', 'adain'] = 'wavelet',
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Colour-correct ``hq_image`` using ``lq_image`` as the reference.

        Args:
            hq_image: Video to correct, shape (B, 3, f, H, W).
            lq_image: Reference video with exactly the same shape.
            clip_range: ``(min, max)`` the output is clamped to.
            method: ``'wavelet'`` or ``'adain'``.
            chunk_size: If given and smaller than ``f``, frames are processed
                in consecutive chunks of this many frames and concatenated
                along the frame axis; otherwise all frames are processed at once.

        Returns:
            Corrected video of shape (B, 3, f, H, W).

        Raises:
            AssertionError: If the shapes differ or are not (B, 3, f, H, W).
            ValueError: If ``method`` is unknown, or ``chunk_size`` is not
                positive while chunking is required.
        """
        assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
        assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"

        f = hq_image.shape[2]
        if chunk_size is None or chunk_size >= f:
            return self._correct_clip(hq_image, lq_image, clip_range, method)

        if chunk_size <= 0:
            # range() would reject a zero step and yield nothing for a negative
            # one; fail early with a message that names the real problem.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        # Slicing past ``f`` is clipped, so the last chunk may be shorter.
        outs = [
            self._correct_clip(
                hq_image[:, :, start:start + chunk_size],
                lq_image[:, :, start:start + chunk_size],
                clip_range,
                method,
            )
            for start in range(0, f, chunk_size)
        ]
        return torch.cat(outs, dim=2)
|
|
|
|
|
|
|
|
|
|
|
class FlashVSRTinyLongPipeline(BasePipeline):
    """Streaming video super-resolution pipeline around a Wan DiT.

    The low-quality (LQ) video is consumed in chunks, each chunk is denoised
    in a single step by the DiT, decoded by ``self.TCDecoder`` and optionally
    colour-corrected against the LQ frames.

    NOTE: ``self.TCDecoder`` is never assigned inside this class; it has to be
    attached by the caller before ``decode_video``, ``offload_model`` or
    ``__call__`` are used.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        """Create an empty pipeline; models are attached later via ``fetch_models``.

        Args:
            device: Device used for computation.
            torch_dtype: dtype used for computation.
        """
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        self.model_names = ['dit', 'vae']
        self.height_division_factor = 16
        self.width_division_factor = 16
        self.use_unified_sequence_parallel = False
        # Filled by init_cross_kv(): {'context': tensor, 'stats': "load" | "offload"}.
        self.prompt_emb_posi = None
        self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)

        print(r"""
███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗
██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗
█████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗
██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝
██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝
╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝
""")

    def enable_vram_management(self, num_persistent_param_in_dit=None):
        """Wrap the DiT's layers for CPU offloading and enable CPU offload.

        Linear, Conv3d, LayerNorm and RMSNorm modules of ``self.dit`` are
        wrapped. ``num_persistent_param_in_dit`` is forwarded as
        ``max_num_param``; modules beyond that budget use the overflow
        configuration, which keeps them on the CPU between computations.
        """
        dtype = next(iter(self.dit.parameters())).dtype
        from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
        enable_vram_management(
            self.dit,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                RMSNorm: AutoWrappedModule,
            },
            module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        self.enable_cpu_offload()

    def fetch_models(self, model_manager: ModelManager):
        """Fetch the ``wan_video_dit`` and ``wan_video_vae`` models from the manager."""
        self.dit = model_manager.fetch_model("wan_video_dit")
        self.vae = model_manager.fetch_model("wan_video_vae")

    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
        """Build a pipeline from a model manager.

        ``device`` and ``torch_dtype`` default to the manager's values.
        ``use_usp`` is accepted but ignored: unified sequence parallelism is
        always disabled for this pipeline.
        """
        if device is None:
            device = model_manager.device
        if torch_dtype is None:
            torch_dtype = model_manager.torch_dtype
        pipe = FlashVSRTinyLongPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)

        pipe.use_unified_sequence_parallel = False
        return pipe

    def denoising_model(self):
        """Return the model used for denoising (the DiT)."""
        return self.dit

    def init_cross_kv(
        self,
        context_tensor: Optional[torch.Tensor] = None,
        prompt_path=None
    ):
        """Initialise the cross-attention KV cache of the DiT from a text context.

        Must be called explicitly once before ``__call__``. It also
        precomputes the timestep embedding (``self.t``), its modulation
        (``self.t_mod``) and configures the scheduler for a single step.

        Args:
            context_tensor: Precomputed context. Takes precedence over
                ``prompt_path``.
            prompt_path: File loaded with ``torch.load`` when no
                ``context_tensor`` is given.

        Raises:
            RuntimeError: If ``self.dit`` has not been set.
            ValueError: If neither ``context_tensor`` nor ``prompt_path`` is given.
            AttributeError: If the DiT has no ``reinit_cross_kv`` method.
        """
        # Validate before touching devices so a missing model yields the
        # intended error rather than a failure inside the loader.
        if self.dit is None:
            raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")

        self.load_models_to_device(["dit"])

        if context_tensor is None:
            if prompt_path is None:
                raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
            # NOTE(security): torch.load unpickles the file; only pass trusted paths.
            ctx = torch.load(prompt_path, map_location=self.device)
        else:
            ctx = context_tensor

        ctx = ctx.to(dtype=self.torch_dtype, device=self.device)

        if self.prompt_emb_posi is None:
            self.prompt_emb_posi = {}
        self.prompt_emb_posi['context'] = ctx
        self.prompt_emb_posi['stats'] = "load"

        if hasattr(self.dit, "reinit_cross_kv"):
            self.dit.reinit_cross_kv(ctx)
        else:
            raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
        self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
        self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
        self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))

        self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
        self.load_models_to_device([])

    def prepare_unified_sequence_parallel(self):
        """Return the sequence-parallel flag as keyword arguments."""
        return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}

    def prepare_extra_input(self, latents=None):
        """Return extra model inputs; this pipeline has none."""
        return {}

    def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Encode a video to latents with the VAE."""
        latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode latents to frames with the VAE (not the TCDecoder)."""
        frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return frames

    def decode_video(self, latents, cond=None, **kwargs):
        """Decode latents with ``self.TCDecoder`` and map the result via ``x * 2 - 1``.

        Axes 1 and 2 are swapped before and after decoding. ``cond`` is
        forwarded to the decoder; extra keyword arguments are ignored.
        """
        frames = self.TCDecoder.decode_video(
            latents.transpose(1, 2),
            parallel=False,
            show_progress_bar=False,
            cond=cond
        ).transpose(1, 2).mul_(2).sub_(1)

        return frames

    def offload_model(self, keep_vae=False):
        """Clear the cross-attention cache and move the models off the device.

        The state is marked ``"offload"`` so the next ``__call__`` rebuilds
        the cross-attention cache.

        Args:
            keep_vae: If true, ``self.TCDecoder`` is left where it is.
        """
        self.dit.clear_cross_kv()
        # init_cross_kv() may not have run yet; there is then no state to mark.
        if self.prompt_emb_posi is not None:
            self.prompt_emb_posi['stats'] = "offload"
        self.load_models_to_device([])
        if hasattr(self.dit, "LQ_proj_in"):
            self.dit.LQ_proj_in.to('cpu')
        if not keep_vae:
            self.TCDecoder.to('cpu')

    @torch.no_grad()
    def __call__(
        self,
        prompt=None,
        negative_prompt="",
        denoising_strength=1.0,
        seed=None,
        rand_device="gpu",
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        num_inference_steps=50,
        sigma_shift=5.0,
        tiled=True,
        tile_size=(60, 104),
        tile_stride=(30, 52),
        tea_cache_l1_thresh=None,
        tea_cache_model_id="Wan2.1-T2V-1.3B",
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
        LQ_video=None,
        buffer_size: int = 40,
        is_full_block=False,
        if_buffer=False,
        topk_ratio=2.0,
        kv_ratio=3.0,
        local_range=9,
        color_fix=True,
        unload_dit=False,
        force_offload=False,
        fps=30,
        quality=6,
        output_path=None,
        **kwargs,
    ):
        """Super-resolve ``LQ_video`` chunk by chunk.

        ``init_cross_kv`` must have been called beforehand.

        Args:
            LQ_video: Low-quality input, either a ``torch.Tensor`` or a
                generator of frames.
            buffer_size: Buffer size used when ``LQ_video`` is a generator.
            height, width: Output size; adjusted by ``check_resize_height_width``.
            num_frames: Frame count; rounded so that ``num_frames % 4 == 1``.
            cfg_scale: Must be exactly 1.0. The default of 5.0 fails this
                check, so callers have to pass ``cfg_scale=1.0`` explicitly.
            seed: Seed for the initial noise.
            if_buffer: Selects the latent frame count of the initial noise.
            is_full_block, topk_ratio, kv_ratio, local_range: Forwarded to
                ``model_fn_wan_video``.
            color_fix: Apply AdaIN colour correction against the LQ frames.
                A failure here is reported and the uncorrected frames are kept.
            unload_dit: Delete per-chunk temporaries and call ``clean_vram``
                after every chunk.
            force_offload: Call ``offload_model`` once all chunks are done.
            output_path: If given, frames are written there with imageio
                (using ``fps`` and ``quality``) instead of being returned.
            progress_bar_cmd: Wrapper applied to the chunk iterator.
            prompt, negative_prompt, denoising_strength, rand_device,
            num_inference_steps, sigma_shift, tiled, tile_size, tile_stride,
            tea_cache_l1_thresh, tea_cache_model_id, progress_bar_st, kwargs:
                Accepted for interface compatibility; unused by this method.

        Returns:
            * ``True`` when ``output_path`` is given and processing finished.
            * A tensor of shape (1, C, F, H, W) with values in [-1, 1] when
              no ``output_path`` is given, or ``None`` if no frame was produced.
            * ``False`` if an exception occurred while processing chunks
              (the error is printed).

        Raises:
            AssertionError: If ``cfg_scale != 1.0``.
            RuntimeError: If the cross-attention context is not initialised.
            TypeError: If ``LQ_video`` is neither a tensor nor a generator.
        """
        assert cfg_scale == 1.0, "cfg_scale must be 1.0"

        if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
            raise RuntimeError(
                "Cross-Attention KV not initialized. Please call __call__ only after:\n"
                " pipe.init_cross_kv()\n"
                "Or provide a custom context:\n"
                " pipe.init_cross_kv(context_tensor=your_context_tensor)"
            )

        if isinstance(LQ_video, types.GeneratorType):
            lq_buffer = FrameStreamBuffer(LQ_video, buffer_size=buffer_size, device=self.device, dtype=self.torch_dtype)
        elif isinstance(LQ_video, torch.Tensor):
            lq_buffer = TensorAsBuffer(LQ_video)
        else:
            raise TypeError(f"LQ_video must be torch.Tensor or Generator,But get {type(LQ_video)}")

        height, width = self.check_resize_height_width(height, width)
        if num_frames % 4 != 1:
            num_frames = (num_frames + 2) // 4 * 4 + 1

        if if_buffer:
            noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
        else:
            noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)

        latents = noise

        use_file_writer = output_path is not None
        collected_frames = [] if not use_file_writer else None

        process_total_num = (num_frames - 1) // 8 - 2
        is_stream = True

        # After offload_model() the cross-attention cache is gone; rebuild it.
        if self.prompt_emb_posi['stats'] == "offload":
            self.init_cross_kv(context_tensor=self.prompt_emb_posi['context'])
        # init_cross_kv() ends by unloading every model, so load unconditionally.
        self.load_models_to_device(["dit"])
        self.dit.LQ_proj_in.to(self.device)
        self.TCDecoder.to(self.device)

        if hasattr(self.dit, "LQ_proj_in"):
            self.dit.LQ_proj_in.clear_cache()

        LQ_pre_idx = 0
        LQ_cur_idx = 0
        self.TCDecoder.clean_mem()

        # Open the writer immediately before the try/finally that closes it,
        # so a failure during model loading above cannot leak the file handle.
        writer = imageio.get_writer(output_path, fps=fps, quality=quality) if use_file_writer else None
        try:
            with torch.no_grad():
                for cur_process_idx in progress_bar_cmd(range(process_total_num)):
                    if cur_process_idx == 0:
                        # First chunk: start with empty KV caches and feed
                        # seven LQ sub-chunks through the LQ projector.
                        pre_cache_k = [None] * len(self.dit.blocks)
                        pre_cache_v = [None] * len(self.dit.blocks)
                        LQ_latents_list = []
                        inner_loop_num = 7
                        for inner_idx in range(inner_loop_num):
                            start, end = max(0, inner_idx * 4 - 3), (inner_idx + 1) * 4 - 3
                            lq_chunk = lq_buffer.get_chunk(start, end)

                            cur = self.denoising_model().LQ_proj_in.stream_forward(lq_chunk)
                            if cur is None:
                                continue
                            LQ_latents_list.append(cur)

                        LQ_latents = [torch.cat([l[i] for l in LQ_latents_list], dim=1) for i in range(len(LQ_latents_list[0]))]
                        LQ_cur_idx = (inner_loop_num - 1) * 4 - 3
                        cur_latents = latents[:, :, :6, :, :]
                    else:
                        # Later chunks: two LQ sub-chunks of four frames each.
                        LQ_latents_list = []
                        inner_loop_num = 2
                        for inner_idx in range(inner_loop_num):
                            start = cur_process_idx * 8 + 17 + inner_idx * 4
                            end = cur_process_idx * 8 + 21 + inner_idx * 4
                            lq_chunk = lq_buffer.get_chunk(start, end)

                            cur = self.denoising_model().LQ_proj_in.stream_forward(lq_chunk)
                            if cur is None:
                                continue
                            LQ_latents_list.append(cur)

                        LQ_latents = [torch.cat([l[i] for l in LQ_latents_list], dim=1) for i in range(len(LQ_latents_list[0]))]
                        LQ_cur_idx = cur_process_idx * 8 + 21 + (inner_loop_num - 2) * 4
                        cur_latents = latents[:, :, 4 + cur_process_idx * 2 : 6 + cur_process_idx * 2, :, :]

                    noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
                        self.dit,
                        x=cur_latents,
                        timestep=self.timestep,
                        context=None,
                        tea_cache=None,
                        use_unified_sequence_parallel=False,
                        LQ_latents=LQ_latents,
                        is_full_block=is_full_block,
                        is_stream=is_stream,
                        pre_cache_k=pre_cache_k,
                        pre_cache_v=pre_cache_v,
                        topk_ratio=topk_ratio,
                        kv_ratio=kv_ratio,
                        cur_process_idx=cur_process_idx,
                        t_mod=self.t_mod,
                        t=self.t,
                        local_range=local_range,
                    )

                    # Single denoising step: subtract the prediction from the noise.
                    cur_latents = cur_latents - noise_pred_posi

                    # LQ frames covered by this chunk condition the decoder.
                    cur_LQ_frame = lq_buffer.get_chunk(LQ_pre_idx, LQ_cur_idx).to(self.device)
                    cur_frames = self.decode_video(cur_latents, cond=cur_LQ_frame)

                    if color_fix:
                        # Best-effort: on failure keep the uncorrected frames,
                        # but report the reason instead of hiding it.
                        try:
                            cur_frames = self.ColorCorrector(
                                cur_frames.to(device=self.device),
                                cur_LQ_frame,
                                clip_range=(-1, 1),
                                chunk_size=None,
                                method='adain'
                            )
                        except Exception as e:
                            print(f"Warning: color fix skipped for chunk {cur_process_idx}: {e}")

                    num_frames_in_chunk = cur_frames.shape[2]
                    for i in range(num_frames_in_chunk):
                        single_frame_tensor = cur_frames[0, :, i, :, :]
                        if use_file_writer:
                            imageio_frame = tensor_to_imageio_frame(single_frame_tensor)
                            writer.append_data(imageio_frame)
                        else:
                            # Stored as (H, W, C) in [0, 1] on the CPU.
                            frame_01 = (single_frame_tensor.clamp(-1, 1) + 1) / 2
                            collected_frames.append(frame_01.permute(1, 2, 0).cpu())

                    LQ_pre_idx = LQ_cur_idx

                    if unload_dit:
                        del noise_pred_posi, cur_frames, cur_latents, cur_LQ_frame
                        clean_vram()

            if force_offload:
                self.offload_model()

        except Exception as e:
            print(f"Error: {e}")
            return False

        finally:
            if use_file_writer and writer is not None:
                writer.close()

        if use_file_writer:
            return True

        if collected_frames:
            # (F, H, W, C) -> (1, C, F, H, W), mapped back from [0, 1] to [-1, 1].
            stacked = torch.stack(collected_frames, dim=0)
            output_tensor = stacked.permute(3, 0, 1, 2).unsqueeze(0)
            output_tensor = output_tensor * 2 - 1
            return output_tensor
        return None
|
|
|
|
|
|
|
|
|
|
|
class TeaCache:
    """Decides per step whether the transformer blocks can be skipped.

    The relative L1 change of the modulation input between consecutive steps
    is rescaled by a model-specific polynomial and accumulated; while the
    accumulated value stays below ``rel_l1_thresh`` the step is skipped and
    a stored residual is re-applied instead.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        """Set up the cache state and pick the polynomial for ``model_id``.

        Raises:
            ValueError: If ``model_id`` has no registered coefficients.
        """
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

        # Polynomial coefficients (highest degree first) per supported model.
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit: WanModel, x, t_mod):
        """Advance one step and report whether the blocks may be skipped.

        The first and last step of every cycle are always computed. When a
        step is computed, a copy of ``x`` is kept so ``store`` can derive the
        residual afterwards.

        Returns:
            ``True`` if the cached residual should be used (skip the blocks),
            ``False`` if the blocks must be evaluated.
        """
        modulated_inp = t_mod.clone()
        last_step = self.num_inference_steps - 1
        if self.step == 0 or self.step == last_step:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            previous = self.previous_modulated_input
            relative_change = ((modulated_inp - previous).abs().mean() / previous.abs().mean()).cpu().item()
            rescale = np.poly1d(self.coefficients)
            self.accumulated_rel_l1_distance += rescale(relative_change)
            should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
            if should_calc:
                self.accumulated_rel_l1_distance = 0

        self.previous_modulated_input = modulated_inp
        self.step = (self.step + 1) % self.num_inference_steps
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual between the block output and the saved input."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Return ``hidden_states`` with the stored residual added."""
        return hidden_states + self.previous_residual
|
|
|
|
|
|
|
|
|
|
|
def model_fn_wan_video(
    dit: WanModel,
    x: torch.Tensor,
    timestep: torch.Tensor,
    context: torch.Tensor,
    tea_cache: Optional[TeaCache] = None,
    use_unified_sequence_parallel: bool = False,
    LQ_latents: Optional[torch.Tensor] = None,
    is_full_block: bool = False,
    is_stream: bool = False,
    pre_cache_k: Optional[list[torch.Tensor]] = None,
    pre_cache_v: Optional[list[torch.Tensor]] = None,
    topk_ratio: float = 2.0,
    kv_ratio: float = 3.0,
    cur_process_idx: int = 0,
    t_mod : torch.Tensor = None,
    t : torch.Tensor = None,
    local_range: int = 9,
    **kwargs,
):
    """Run one forward pass of the DiT on a latent chunk.

    Args:
        dit: The transformer; must provide ``patchify``, ``freqs``,
            ``blocks``, ``head`` and ``unpatchify``.
        x: Latent chunk passed to ``dit.patchify``.
        timestep: Unused here; the precomputed ``t`` / ``t_mod`` are used.
        context: Forwarded to every block.
        tea_cache: Optional ``TeaCache``; when it reports a cache hit the
            blocks are skipped and the stored residual is applied instead.
        use_unified_sequence_parallel: Shard the token sequence across the
            sequence-parallel group (requires ``xfuser``).
        LQ_latents: Optional per-block tensors; entry ``i`` is added to the
            hidden states before block ``i``.
        is_full_block, is_stream, local_range: Forwarded to every block.
        pre_cache_k, pre_cache_v: Per-block caches; updated in place with the
            values each block returns.
        topk_ratio: Scales the ``topk`` argument passed to the blocks.
        kv_ratio: Truncated to an int and passed to the blocks as ``kv_len``.
        cur_process_idx: Index of the chunk; selects the temporal offset of
            the positional frequencies.
        t_mod: Time modulation passed to the blocks.
        t: Time embedding passed to ``dit.head``.
        **kwargs: Ignored.

    Returns:
        ``(x, pre_cache_k, pre_cache_v)`` with ``x`` unpatchified.
    """
    x, (f, h, w) = dit.patchify(x)

    win = (2, 8, 8)
    seqlen = f // win[0]
    local_num = seqlen
    window_size = win[0] * h * w // 128
    square_num = window_size * window_size
    topk = int(square_num * topk_ratio) - 1
    kv_len = int(kv_ratio)

    # The first chunk starts at temporal position 0; later chunks start at
    # 4 + 2 * cur_process_idx on the temporal frequency table.
    frame_offset = 0 if cur_process_idx == 0 else 4 + cur_process_idx * 2
    freqs = torch.cat([
        dit.freqs[0][frame_offset:frame_offset + f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False

    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]

    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        for block_id, block in enumerate(dit.blocks):
            if LQ_latents is not None and block_id < len(LQ_latents):
                x = x + LQ_latents[block_id]
            x, last_pre_cache_k, last_pre_cache_v = block(
                x, context, t_mod, freqs, f, h, w,
                local_num, topk,
                block_id=block_id,
                kv_len=kv_len,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
                pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
                local_range=local_range,
            )
            if pre_cache_k is not None:
                pre_cache_k[block_id] = last_pre_cache_k
            if pre_cache_v is not None:
                pre_cache_v[block_id] = last_pre_cache_v
        # Record the residual of this computed step; without it a later
        # cache hit would call tea_cache.update() with no residual stored.
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import get_sp_group
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
    x = dit.unpatchify(x, (f, h, w))
    return x, pre_cache_k, pre_cache_v