Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
def load_modulation(
    modulate_type: str,
    hidden_size: int,
    factor: int,
    act_layer=nn.SiLU,
    dtype=None,
    device=None,
):
    """Build the modulation layer selected by ``modulate_type``.

    Args:
        modulate_type: Name of the modulation variant. Only ``'wanx'`` is
            recognised.
        hidden_size: Forwarded to ``ModulateWan`` as its hidden size.
        factor: Forwarded to ``ModulateWan`` as its factor.
        act_layer: Accepted for signature compatibility; this function does
            not use it.
        dtype: Forwarded to ``ModulateWan``.
        device: Forwarded to ``ModulateWan``.

    Returns:
        A ``ModulateWan`` instance when ``modulate_type == 'wanx'``.

    Raises:
        ValueError: For any other ``modulate_type``.
    """
    # Reject unsupported variants up front so the happy path stays flat.
    if modulate_type != 'wanx':
        raise ValueError(
            f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")
    return ModulateWan(hidden_size, factor, dtype=dtype, device=device)
class ModulateWan(nn.Module):
    """Modulation layer for WanX.

    Holds a learnable table of shape ``(1, factor, hidden_size)`` that is
    added to the input and then split into ``factor`` pieces along dim 1.
    """

    def __init__(
        self,
        hidden_size: int,
        factor: int,
        dtype=None,
        device=None,
    ):
        """Create the learnable modulation table.

        Args:
            hidden_size: Size of the last dimension of the table.
            factor: Number of rows in the table, and the number of chunks
                ``forward`` asks for.
            dtype: Optional dtype passed to ``torch.zeros``.
            device: Optional device passed to ``torch.zeros``.
        """
        super().__init__()
        self.factor = factor
        # The table is zero-initialised. The previous code additionally
        # divided the zeros by sqrt(hidden_size), which cannot change an
        # all-zero tensor, so the division is dropped.
        self.modulate_table = nn.Parameter(
            torch.zeros(1, factor, hidden_size, dtype=dtype, device=device),
            requires_grad=True,
        )

    def forward(self, x: torch.Tensor) -> "list[torch.Tensor]":
        """Add the table to ``x`` and split the sum along dim 1.

        Args:
            x: Input tensor. If it is not 3-D, a singleton dimension is
                inserted at index 1 before the addition. It must broadcast
                against the ``(1, factor, hidden_size)`` table.

        Returns:
            A list with the chunks of ``modulate_table + x`` taken along
            dim 1 (``self.factor`` chunks requested), each with dim 1
            squeezed away when that dim has size 1.
        """
        if x.dim() != 3:
            x = x.unsqueeze(1)
        shifted = self.modulate_table + x
        return [piece.squeeze(1) for piece in shifted.chunk(self.factor, dim=1)]
def modulate(x, shift=None, scale=None):
    """Apply an optional scale and an optional shift to ``x``.

    Computes ``x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)``, skipping
    whichever term has a ``None`` operand.

    Args:
        x (torch.Tensor): input tensor.
        shift (torch.Tensor, optional): additive term; a dimension is
            inserted at index 1 before it is added. Defaults to None.
        scale (torch.Tensor, optional): multiplicative term applied as
            ``1 + scale``; a dimension is inserted at index 1 first.
            Defaults to None.

    Returns:
        torch.Tensor: the modulated tensor, or ``x`` itself when both
        ``shift`` and ``scale`` are None.
    """
    result = x
    # Scale first, then shift, matching the combined formula above.
    if scale is not None:
        result = result * (1 + scale.unsqueeze(1))
    if shift is not None:
        result = result + shift.unsqueeze(1)
    return result
def apply_gate(x, gate=None, tanh=False):
    """Multiply ``x`` by a gate, optionally squashing the gate with tanh.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor; a dimension is inserted
            at index 1 before the multiplication. Defaults to None.
        tanh (bool, optional): when truthy, ``tanh`` is applied to the
            unsqueezed gate before multiplying. Defaults to False.

    Returns:
        torch.Tensor: ``x`` unchanged when ``gate`` is None, otherwise the
        gated product.
    """
    if gate is None:
        return x
    expanded = gate.unsqueeze(1)
    return x * (expanded.tanh() if tanh else expanded)