| from typing import Callable |
|
|
| import torch |
| import torch.nn as nn |
| import math |
|
|
class ModulateDiT(nn.Module):
    """Modulation layer for DiT.

    Applies an activation followed by a linear projection that widens the
    last dimension from ``hidden_size`` to ``factor * hidden_size``. The
    projection's weight and bias are zero-initialised.
    """

    def __init__(
        self,
        hidden_size: int,
        factor: int,
        act_layer: Callable,
        dtype=None,
        device=None,
    ):
        """Build the activation and the zero-initialised projection.

        Args:
            hidden_size: input feature size of the linear projection.
            factor: output width multiplier (``factor * hidden_size``).
            act_layer: zero-argument callable; its result is applied to the
                input before the projection.
            dtype: forwarded to ``nn.Linear``.
            device: forwarded to ``nn.Linear``.
        """
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.act = act_layer()
        self.linear = nn.Linear(
            hidden_size, factor * hidden_size, bias=True, **factory_kwargs
        )
        # Zero init: the projection contributes nothing until it is trained
        # or its parameters are loaded.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(
        self, x: torch.Tensor, condition_type=None, token_replace_vec=None
    ) -> "torch.Tensor | tuple[torch.Tensor, torch.Tensor]":
        """Project ``x`` (and optionally ``token_replace_vec``).

        Args:
            x: input tensor whose last dimension is ``hidden_size``.
            condition_type: when equal to ``"token_replace"``,
                ``token_replace_vec`` is projected as well.
            token_replace_vec: second input, run through the same activation
                and projection; required in ``"token_replace"`` mode.

        Returns:
            The projected ``x``; in ``"token_replace"`` mode a tuple
            ``(projected_x, projected_token_replace_vec)``.

        Raises:
            ValueError: if ``condition_type == "token_replace"`` and
                ``token_replace_vec`` is None.
        """
        if condition_type != "token_replace":
            return self.linear(self.act(x))

        # Fail early with a clear message instead of letting the activation
        # choke on None.
        if token_replace_vec is None:
            raise ValueError(
                "token_replace_vec is required when condition_type == 'token_replace'"
            )
        projected = self.linear(self.act(x))
        projected_replace = self.linear(self.act(token_replace_vec))
        return projected, projected_replace
|
|
def modulate(x, shift=None, scale=None):
    """Apply an optional scale-then-shift modulation to ``x``.

    NOTE: this module defines ``modulate`` a second time further down; that
    later definition rebinds the name, so this version is the one in effect
    only until the later ``def`` executes.

    Args:
        x (torch.Tensor): input tensor.
        shift (torch.Tensor, optional): added to the (scaled) input after a
            singleton axis is inserted at dim 1. Defaults to None (no shift).
        scale (torch.Tensor, optional): the input is multiplied by
            ``1 + scale`` after a singleton axis is inserted at dim 1.
            Defaults to None (no scaling).

    Returns:
        torch.Tensor: ``x * (1 + scale) + shift`` with missing terms skipped;
        ``x`` itself when both ``shift`` and ``scale`` are None.
    """
    result = x
    if scale is not None:
        result = result * (1 + scale.unsqueeze(1))
    if shift is not None:
        result = result + shift.unsqueeze(1)
    return result
|
|
def modulate_(x, shift=None, scale=None):
    """In-place scale-then-shift modulation of ``x``.

    Computes ``x * (1 + scale) + shift`` and writes the result back into
    ``x``. ``shift`` and ``scale`` get a singleton axis inserted at dim 1
    before being combined with ``x``; neither of them is modified.

    Args:
        x (torch.Tensor): tensor to modulate; overwritten with the result.
        shift (torch.Tensor, optional): additive term. Defaults to None.
        scale (torch.Tensor, optional): multiplicative term, applied as
            ``1 + scale``. Defaults to None.

    Returns:
        torch.Tensor: ``x`` itself, after in-place modification (unchanged
        when both ``shift`` and ``scale`` are None).
    """
    if scale is None and shift is None:
        return x

    if shift is None:
        # ``scale + 1`` allocates a new tensor, so the caller's scale is
        # left untouched.
        return x.mul_((scale + 1).unsqueeze(1))

    if scale is None:
        # Previously ``x + shift.unsqueeze(1)``, which returned a new tensor
        # and left ``x`` unmodified -- inconsistent with every other branch
        # of this in-place function.
        return x.add_(shift.unsqueeze(1))

    factor = (scale + 1).unsqueeze(1)
    # shift + x * factor, written directly into x's storage.
    torch.addcmul(shift.unsqueeze(1), x, factor, out=x)
    return x
| |
def _scale_then_shift(x, shift, scale):
    """Return ``x * (1 + scale) + shift``, skipping any term that is None."""
    if scale is not None:
        x = x * (1 + scale.unsqueeze(1))
    if shift is not None:
        x = x + shift.unsqueeze(1)
    return x


def modulate(x, shift=None, scale=None, condition_type=None,
             tr_shift=None, tr_scale=None,
             frist_frame_token_num=None):
    """Modulate ``x`` by shift and scale, optionally per token segment.

    In the default mode the whole tensor is transformed as
    ``x * (1 + scale) + shift``. When ``condition_type == "token_replace"``
    the tensor is split along dim 1 at ``frist_frame_token_num``: the leading
    segment uses ``tr_shift`` / ``tr_scale``, the remainder uses ``shift`` /
    ``scale``, and the two results are concatenated back along dim 1.

    Args:
        x (torch.Tensor): input tensor.
        shift (torch.Tensor, optional): shift tensor. Defaults to None.
        scale (torch.Tensor, optional): scale tensor. Defaults to None.
        condition_type (str, optional): ``"token_replace"`` enables the
            two-segment mode; any other value uses the plain mode.
        tr_shift (torch.Tensor, optional): shift for the leading segment.
        tr_scale (torch.Tensor, optional): scale for the leading segment.
        frist_frame_token_num (int, optional): split index along dim 1
            (spelling kept for caller compatibility). Required in
            ``"token_replace"`` mode.

    Returns:
        torch.Tensor: the modulated tensor. In plain mode with both ``shift``
        and ``scale`` None, ``x`` itself is returned.

    Raises:
        ValueError: in ``"token_replace"`` mode when
            ``frist_frame_token_num`` is None.
    """
    if condition_type != "token_replace":
        return _scale_then_shift(x, shift, scale)

    # With a None index both slices below would cover the entire tensor and
    # the concatenation would silently double the length of dim 1.
    if frist_frame_token_num is None:
        raise ValueError(
            "frist_frame_token_num is required when condition_type == 'token_replace'"
        )

    leading = _scale_then_shift(x[:, :frist_frame_token_num], tr_shift, tr_scale)
    remainder = _scale_then_shift(x[:, frist_frame_token_num:], shift, scale)
    return torch.concat((leading, remainder), dim=1)
| |
def apply_gate(x, gate=None, tanh=False, condition_type=None, tr_gate=None, frist_frame_token_num=None):
    """Multiply ``x`` by a gate, optionally per token segment.

    In the default mode the whole tensor is multiplied by ``gate`` (with a
    singleton axis inserted at dim 1). When
    ``condition_type == "token_replace"`` the tensor is split along dim 1 at
    ``frist_frame_token_num``: the leading segment is multiplied by
    ``tr_gate``, the remainder by ``gate``, and the results are concatenated
    back along dim 1.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor. When None, ``x`` is
            returned unchanged in either mode. Defaults to None.
        tanh (bool, optional): whether to pass the gate(s) through ``tanh``
            before multiplying. Defaults to False.
        condition_type (str, optional): ``"token_replace"`` enables the
            two-segment mode; any other value uses the plain mode.
        tr_gate (torch.Tensor, optional): gate for the leading segment;
            required in ``"token_replace"`` mode when ``gate`` is given.
        frist_frame_token_num (int, optional): split index along dim 1
            (spelling kept for caller compatibility); required in
            ``"token_replace"`` mode when ``gate`` is given.

    Returns:
        torch.Tensor: the output tensor after applying the gate.

    Raises:
        ValueError: in ``"token_replace"`` mode with a gate, when ``tr_gate``
            or ``frist_frame_token_num`` is None.
    """
    if gate is None:
        return x

    def _factor(g):
        # Insert the broadcast axis, then squash only if requested.
        g = g.unsqueeze(1)
        return g.tanh() if tanh else g

    if condition_type != "token_replace":
        return x * _factor(gate)

    if tr_gate is None:
        raise ValueError(
            "tr_gate is required when condition_type == 'token_replace'"
        )
    # With a None index both slices below would cover the entire tensor and
    # the concatenation would silently double the length of dim 1.
    if frist_frame_token_num is None:
        raise ValueError(
            "frist_frame_token_num is required when condition_type == 'token_replace'"
        )

    leading = x[:, :frist_frame_token_num] * _factor(tr_gate)
    remainder = x[:, frist_frame_token_num:] * _factor(gate)
    return torch.concat((leading, remainder), dim=1)
| |
def apply_gate_and_accumulate_(accumulator, x, gate=None, tanh=False):
    """Add the gated ``x`` into ``accumulator`` in place.

    Args:
        accumulator (torch.Tensor): tensor updated in place with
            ``accumulator + x * gate``.
        x (torch.Tensor): tensor to gate and accumulate.
        gate (torch.Tensor, optional): gate tensor; a singleton axis is
            inserted at dim 1 before multiplying. When None, nothing is
            accumulated. Defaults to None.
        tanh (bool, optional): whether to pass the gate through ``tanh``
            first. Defaults to False.

    Returns:
        torch.Tensor: ``accumulator`` itself.
    """
    if gate is None:
        return accumulator

    weight = gate.unsqueeze(1)
    if tanh:
        weight = weight.tanh()
    return accumulator.addcmul_(x, weight)
| |
def ckpt_wrapper(module):
    """Wrap ``module`` in a plain function that forwards positional arguments.

    The returned function calls ``module(*inputs)`` and returns its result
    unchanged. Keyword arguments are not accepted by the wrapper.
    NOTE(review): the name suggests use with activation checkpointing --
    confirm at the call sites.

    Args:
        module: any callable.

    Returns:
        Callable: function taking ``*inputs`` and returning
        ``module(*inputs)``.
    """
    def ckpt_forward(*inputs):
        return module(*inputs)

    return ckpt_forward
|
|