from typing import List

import torch
import torch.nn as nn


def load_modulation(
    modulate_type: str,
    hidden_size: int,
    factor: int,
    act_layer=nn.SiLU,
    dtype=None,
    device=None,
):
    """Build the modulation layer selected by ``modulate_type``.

    Args:
        modulate_type: Name of the modulation variant. Only ``'wanx'`` is
            recognised.
        hidden_size: Size of the last axis of the modulation table.
        factor: Number of modulation tensors the layer produces.
        act_layer: Accepted for signature compatibility; this function does
            not use it.
        dtype: Optional dtype forwarded to the layer's parameter.
        device: Optional device forwarded to the layer's parameter.

    Returns:
        A ``ModulateWan`` instance.

    Raises:
        ValueError: If ``modulate_type`` is not ``'wanx'``.
    """
    factory_kwargs = {"dtype": dtype, "device": device}
    if modulate_type == 'wanx':
        return ModulateWan(hidden_size, factor, **factory_kwargs)
    raise ValueError(
        f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")


class ModulateWan(nn.Module):
    """Modulation layer for WanX.

    Holds a learnable table of shape ``(1, factor, hidden_size)`` that is
    added to the input and then split into ``factor`` separate tensors.
    """

    def __init__(
        self,
        hidden_size: int,
        factor: int,
        dtype=None,
        device=None,
    ):
        """Create the learnable modulation table.

        Args:
            hidden_size: Size of the last axis of the table.
            factor: Number of entries along axis 1 of the table, and the
                number of chunks ``forward`` produces.
            dtype: Optional dtype of the table.
            device: Optional device of the table.
        """
        super().__init__()
        self.factor = factor
        # NOTE(review): the table starts as zeros, so dividing by
        # sqrt(hidden_size) has no numerical effect. The expression is kept
        # unchanged; presumably it is a leftover from a random init - confirm.
        self.modulate_table = nn.Parameter(
            torch.zeros(1, factor, hidden_size, dtype=dtype, device=device)
            / hidden_size**0.5,
            requires_grad=True,
        )

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Add the modulation table to ``x`` and split the result.

        Args:
            x: Input tensor. When it is not 3-D, a singleton axis is inserted
                at position 1 before the table is added (broadcasting).

        Returns:
            A list of tensors obtained by chunking the sum into
            ``self.factor`` pieces along axis 1 and squeezing that axis from
            each piece.
        """
        if x.dim() != 3:
            x = x.unsqueeze(1)
        modulated = self.modulate_table + x
        return [piece.squeeze(1) for piece in modulated.chunk(self.factor, dim=1)]


def modulate(x, shift=None, scale=None):
    """Modulate ``x`` by an optional scale and an optional shift.

    Computes ``x * (1 + scale) + shift``, skipping whichever term is ``None``.
    ``shift`` and ``scale`` each get a singleton axis inserted at position 1
    so that they broadcast against ``x``.

    Args:
        x (torch.Tensor): input tensor.
        shift (torch.Tensor, optional): shift tensor. Defaults to None.
        scale (torch.Tensor, optional): scale tensor. Defaults to None.

    Returns:
        torch.Tensor: the modulated tensor; ``x`` itself when both ``shift``
        and ``scale`` are None.
    """
    # Scale first, then shift, matching x * (1 + scale) + shift.
    if scale is not None:
        x = x * (1 + scale.unsqueeze(1))
    if shift is not None:
        x = x + shift.unsqueeze(1)
    return x


def apply_gate(x, gate=None, tanh=False):
    """Multiply ``x`` by a gate tensor, optionally squashed with tanh.

    ``gate`` gets a singleton axis inserted at position 1 so that it
    broadcasts against ``x``.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor. Defaults to None.
        tanh (bool, optional): whether to pass the gate through tanh before
            multiplying. Defaults to False.

    Returns:
        torch.Tensor: the gated tensor; ``x`` itself when ``gate`` is None.
    """
    if gate is None:
        return x
    gate = gate.unsqueeze(1)
    if tanh:
        gate = gate.tanh()
    return x * gate