JoyAI-Image-Edit-Space / src /modules /models /mmdit /dit /modulate_layers.py
stevengrove's picture
Initial commit with Xet-tracked image assets
fcfea15
import torch
import torch.nn as nn
def load_modulation(
        modulate_type: str,
        hidden_size: int,
        factor: int,
        act_layer=nn.SiLU,
        dtype=None,
        device=None):
    """Build the modulation layer selected by ``modulate_type``.

    Args:
        modulate_type (str): name of the modulation variant. Only ``'wanx'``
            is recognised.
        hidden_size (int): forwarded to the layer constructor.
        factor (int): forwarded to the layer constructor.
        act_layer: accepted for signature compatibility; this function does
            not use it.
        dtype: optional dtype forwarded to the layer constructor.
        device: optional device forwarded to the layer constructor.

    Returns:
        ModulateWan: the freshly constructed modulation layer.

    Raises:
        ValueError: if ``modulate_type`` is anything other than ``'wanx'``.
    """
    # Reject unsupported variants up front so the happy path stays flat.
    if modulate_type != 'wanx':
        raise ValueError(
            f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")
    return ModulateWan(hidden_size, factor, dtype=dtype, device=device)
class ModulateWan(nn.Module):
    """Modulation layer for WanX.

    Owns a trainable table of shape ``(1, factor, hidden_size)``. The forward
    pass adds this table to the input and splits the sum along dimension 1
    into ``factor`` chunks.

    Args:
        hidden_size (int): size of the last dimension of the table.
        factor (int): number of rows in the table and number of chunks
            requested in ``forward``.
        dtype: optional dtype passed to ``torch.zeros``.
        device: optional device passed to ``torch.zeros``.
    """

    def __init__(
        self,
        hidden_size: int,
        factor: int,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.factor = factor
        # The table starts as all zeros. The 1/sqrt(hidden_size) scaling is
        # kept from the original expression, but dividing an all-zero tensor
        # leaves it all zero, so it has no effect on the initial values.
        initial_table = torch.zeros(
            1, factor, hidden_size, dtype=dtype, device=device
        ) / hidden_size**0.5
        self.modulate_table = nn.Parameter(initial_table, requires_grad=True)

    def forward(self, x: torch.Tensor) -> "list[torch.Tensor]":
        """Add the table to ``x`` and split the result along dimension 1.

        Args:
            x (torch.Tensor): input tensor. If it is not 3-D, a singleton
                axis is inserted at position 1 so it broadcasts against the
                ``(1, factor, hidden_size)`` table.

        Returns:
            list[torch.Tensor]: the chunks of ``modulate_table + x`` taken
            along dimension 1, each with ``squeeze(1)`` applied (which drops
            that axis only when its size is 1).
        """
        if x.dim() != 3:
            x = x.unsqueeze(1)
        shifted = self.modulate_table + x
        return [piece.squeeze(1) for piece in shifted.chunk(self.factor, dim=1)]
def modulate(x, shift=None, scale=None):
    """Apply an optional scale and an optional shift to ``x``.

    Both ``scale`` and ``shift`` receive a new axis at position 1 before use
    so that they broadcast against ``x``. The scale is applied first, as
    ``x * (1 + scale)``, and the shift is added afterwards.

    Args:
        x (torch.Tensor): input tensor.
        shift (torch.Tensor, optional): additive term. Defaults to None.
        scale (torch.Tensor, optional): multiplicative term, applied as
            ``1 + scale``. Defaults to None.

    Returns:
        torch.Tensor: the modulated tensor, or ``x`` itself when neither
        ``shift`` nor ``scale`` is given.
    """
    result = x
    # Each step is optional; skipping both leaves the input object untouched.
    if scale is not None:
        result = result * (1 + scale.unsqueeze(1))
    if shift is not None:
        result = result + shift.unsqueeze(1)
    return result
def apply_gate(x, gate=None, tanh=False):
    """Multiply ``x`` by a gate, optionally passing the gate through tanh.

    The gate receives a new axis at position 1 before the multiplication so
    that it broadcasts against ``x``.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor. Defaults to None.
        tanh (bool, optional): apply ``tanh`` to the gate before multiplying.
            Defaults to False.

    Returns:
        torch.Tensor: the gated tensor, or ``x`` itself when ``gate`` is None.
    """
    if gate is None:
        return x
    factor = gate.unsqueeze(1)
    if tanh:
        factor = factor.tanh()
    return x * factor