| | from typing import Callable |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
class ModulateDiT(nn.Module):
    """Modulation layer for DiT.

    Applies an activation to the input, then a ``nn.Linear`` projection from
    ``hidden_size`` to ``factor * hidden_size`` features. The projection's
    weight and bias are zero-initialized, so a freshly built layer outputs
    all zeros regardless of its input.

    Args:
        hidden_size (int): number of input features.
        factor (int): multiplier applied to ``hidden_size`` to get the
            number of output features.
        act_layer (Callable): zero-argument factory (e.g. an ``nn.Module``
            class such as ``nn.SiLU``) that builds the activation module.
        dtype (torch.dtype, optional): forwarded to ``nn.Linear``.
        device (torch.device, optional): forwarded to ``nn.Linear``.
    """

    def __init__(
        self,
        hidden_size: int,
        factor: int,
        act_layer: Callable,
        dtype=None,
        device=None,
    ):
        super().__init__()
        linear_kwargs = dict(dtype=dtype, device=device)
        self.act = act_layer()
        self.linear = nn.Linear(
            hidden_size, factor * hidden_size, bias=True, **linear_kwargs
        )
        # Zero-initialize the projection so the layer starts out emitting zeros.
        for param in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(param)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``linear(act(x))``."""
        activated = self.act(x)
        return self.linear(activated)
| |
|
| |
|
def modulate(x, shift=None, scale=None):
    """Apply an optional scale and then an optional shift to ``x``.

    Both ``scale`` and ``shift`` are unsqueezed at dim 1 before broadcasting
    against ``x``. (Presumably ``x`` is ``(batch, tokens, dim)`` and the
    modulation tensors are ``(batch, dim)``; confirm against callers.)

    Args:
        x (torch.Tensor): input tensor.
        shift (torch.Tensor, optional): added after scaling. Defaults to None.
        scale (torch.Tensor, optional): multiplies ``x`` as ``1 + scale``.
            Defaults to None.

    Returns:
        torch.Tensor: the modulated tensor, or ``x`` itself when both
        ``shift`` and ``scale`` are None.
    """
    out = x
    if scale is not None:
        out = out * (1 + scale.unsqueeze(1))
    if shift is not None:
        out = out + shift.unsqueeze(1)
    return out
| |
|
| |
|
def apply_gate(x, gate=None, tanh=False):
    """Multiply ``x`` by a gate that is broadcast along dim 1.

    The gate is unsqueezed at dim 1 before multiplying, and is optionally
    squashed through ``tanh`` first.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor. Defaults to None.
        tanh (bool, optional): apply ``tanh`` to the gate before
            multiplying. Defaults to False.

    Returns:
        torch.Tensor: the gated tensor, or ``x`` itself when ``gate`` is None.
    """
    if gate is None:
        return x
    expanded = gate.unsqueeze(1)
    return x * (expanded.tanh() if tanh else expanded)
| |
|
| |
|
def ckpt_wrapper(module):
    """Wrap ``module`` in a plain function that forwards positional args.

    Args:
        module: any callable (e.g. an ``nn.Module``).

    Returns:
        Callable: a function ``ckpt_forward(*inputs)`` that returns
        ``module(*inputs)``.
    """
    def ckpt_forward(*args):
        return module(*args)

    return ckpt_forward
| |
|