import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import DropPath
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION


def rope(x, dim):
    """Applies Rotary Position Embedding to input tensor.

    Args:
        x (torch.Tensor): Input tensor.
        dim (int | list[int]): The spatial dimension(s) to apply
            rotary position embedding.

    Returns:
        torch.Tensor: The tensor after applying rotary position
            embedding.

    Reference:
        `RoFormer: Enhanced Transformer with Rotary
        Position Embedding <https://arxiv.org/abs/2104.09864>`_
    """
    shape = x.shape
    if isinstance(dim, int):
        dim = [dim]

    spatial_shape = [shape[i] for i in dim]
    total_len = 1
    for i in spatial_shape:
        total_len *= i

    position = torch.reshape(
        torch.arange(total_len, dtype=torch.int, device=x.device),
        spatial_shape)

    for i in range(dim[-1] + 1, len(shape) - 1, 1):
        position = torch.unsqueeze(position, dim=-1)

    half_size = shape[-1] // 2
    freq_seq = -torch.arange(
        half_size, dtype=torch.int, device=x.device) / float(half_size)
    inv_freq = 10000**-freq_seq

    sinusoid = position[..., None] * inv_freq[None, None, :]

    sin = torch.sin(sinusoid)
    cos = torch.cos(sinusoid)
    x1, x2 = torch.chunk(x, 2, dim=-1)

    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
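
# Usage sketch (illustrative, not exercised by the module itself): rope()
# keeps the input shape and rotates each (x1, x2) feature pair of the last
# dimension by a position-dependent angle along the chosen spatial dim, e.g.
#
#   feats = torch.randn(2, 17, 128)   # [batch, tokens, channels]
#   feats = rope(feats, dim=1)        # -> torch.Size([2, 17, 128])
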
class Scale(nn.Module):
    """Scale vector by element multiplications.

    Args:
        dim (int): The dimension of the scale vector.
        init_value (float, optional): The initial value of the scale vector.
            Defaults to 1.0.
        trainable (bool, optional): Whether the scale vector is trainable.
            Defaults to True.
    """

    def __init__(self, dim, init_value=1., trainable=True):
        super().__init__()
        self.scale = nn.Parameter(
            init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        """Forward function."""

        return x * self.scale
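
# Usage sketch (illustrative): Scale multiplies the last dimension by a
# learnable per-channel vector, e.g.
#
#   layer = Scale(dim=256)
#   out = layer(torch.randn(2, 17, 256))   # -> torch.Size([2, 17, 256])
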
class ScaleNorm(nn.Module):
    """Scale Norm.

    Args:
        dim (int): The dimension of the scale vector.
        eps (float, optional): The minimum value in clamp. Defaults to 1e-5.

    Reference:
        `Transformers without Tears: Improving the Normalization
        of Self-Attention <https://arxiv.org/abs/1910.05895>`_
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The tensor after applying scale norm.
        """

        norm = torch.norm(x, dim=2, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g
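
# Usage sketch (illustrative): ScaleNorm divides each token vector by its
# L2 norm scaled by dim**-0.5 and applies a single learnable gain `g`. The
# norm is taken over dim=2, so inputs are expected to be
# [batch, tokens, channels], e.g.
#
#   norm = ScaleNorm(dim=256)
#   out = norm(torch.randn(2, 17, 256))    # -> torch.Size([2, 17, 256])
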
class RTMCCBlock(nn.Module):
    """Gated Attention Unit (GAU) in RTMBlock.

    Args:
        num_token (int): The number of tokens.
        in_token_dims (int): The input token dimension.
        out_token_dims (int): The output token dimension.
        expansion_factor (int, optional): The expansion factor of the
            intermediate token dimension. Defaults to 2.
        s (int, optional): The self-attention feature dimension.
            Defaults to 128.
        eps (float, optional): The minimum value in clamp. Defaults to 1e-5.
        dropout_rate (float, optional): The dropout rate. Defaults to 0.0.
        drop_path (float, optional): The drop path rate. Defaults to 0.0.
        attn_type (str, optional): Type of attention which should be one of
            the following options:

            - 'self-attn': Self-attention.
            - 'cross-attn': Cross-attention.

            Defaults to 'self-attn'.
        act_fn (str, optional): The activation function which should be one
            of the following options:

            - 'ReLU': ReLU activation.
            - 'SiLU': SiLU activation.

            Defaults to 'SiLU'.
        bias (bool, optional): Whether to use bias in linear layers.
            Defaults to False.
        use_rel_bias (bool, optional): Whether to use relative bias.
            Defaults to True.
        pos_enc (bool, optional): Whether to use rotary position
            embedding. Defaults to False.

    Reference:
        `Transformer Quality in Linear Time
        <https://arxiv.org/abs/2202.10447>`_
    """
    def __init__(self,
                 num_token,
                 in_token_dims,
                 out_token_dims,
                 expansion_factor=2,
                 s=128,
                 eps=1e-5,
                 dropout_rate=0.,
                 drop_path=0.,
                 attn_type='self-attn',
                 act_fn='SiLU',
                 bias=False,
                 use_rel_bias=True,
                 pos_enc=False):

        super(RTMCCBlock, self).__init__()
        self.s = s
        self.num_token = num_token
        self.use_rel_bias = use_rel_bias
        self.attn_type = attn_type
        self.pos_enc = pos_enc
        self.drop_path = DropPath(drop_path) \
            if drop_path > 0. else nn.Identity()

        self.e = int(in_token_dims * expansion_factor)
        if use_rel_bias:
            if attn_type == 'self-attn':
                self.w = nn.Parameter(
                    torch.rand([2 * num_token - 1], dtype=torch.float))
            else:
                self.a = nn.Parameter(torch.rand([1, s], dtype=torch.float))
                self.b = nn.Parameter(torch.rand([1, s], dtype=torch.float))
        self.o = nn.Linear(self.e, out_token_dims, bias=bias)

        if attn_type == 'self-attn':
            self.uv = nn.Linear(in_token_dims, 2 * self.e + self.s, bias=bias)
            self.gamma = nn.Parameter(torch.rand((2, self.s)))
            self.beta = nn.Parameter(torch.rand((2, self.s)))
        else:
            self.uv = nn.Linear(in_token_dims, self.e + self.s, bias=bias)
            self.k_fc = nn.Linear(in_token_dims, self.s, bias=bias)
            self.v_fc = nn.Linear(in_token_dims, self.e, bias=bias)
            nn.init.xavier_uniform_(self.k_fc.weight)
            nn.init.xavier_uniform_(self.v_fc.weight)

        self.ln = ScaleNorm(in_token_dims, eps=eps)

        nn.init.xavier_uniform_(self.uv.weight)

        if act_fn == 'SiLU':
            assert digit_version(TORCH_VERSION) >= digit_version('1.7.0'), \
                'SiLU activation requires PyTorch version >= 1.7'

            self.act_fn = nn.SiLU(True)
        else:
            self.act_fn = nn.ReLU(True)

        if in_token_dims == out_token_dims:
            self.shortcut = True
            self.res_scale = Scale(in_token_dims)
        else:
            self.shortcut = False

        self.sqrt_s = math.sqrt(s)

        self.dropout_rate = dropout_rate

        if dropout_rate > 0.:
            self.dropout = nn.Dropout(dropout_rate)
    def rel_pos_bias(self, seq_len, k_len=None):
        """Add relative position bias."""

        if self.attn_type == 'self-attn':
            t = F.pad(self.w[:2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
            t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
            r = (2 * seq_len - 1) // 2
            t = t[..., r:-r]
        else:
            a = rope(self.a.repeat(seq_len, 1), dim=0)
            b = rope(self.b.repeat(k_len, 1), dim=0)
            t = torch.bmm(a, b.permute(0, 2, 1))
        return t
    def _forward(self, inputs):
        """GAU Forward function."""

        if self.attn_type == 'self-attn':
            x = inputs
        else:
            x, k, v = inputs

        x = self.ln(x)

        # [B, K, in_token_dims] -> [B, K, e + e + s] (self-attn)
        # or [B, K, e + s] (cross-attn)
        uv = self.uv(x)
        uv = self.act_fn(uv)

        if self.attn_type == 'self-attn':
            # [B, K, e + e + s] -> [B, K, e], [B, K, e], [B, K, s]
            u, v, base = torch.split(uv, [self.e, self.e, self.s], dim=2)

            # q and k are two cheap affine views of the shared `base`:
            # [B, K, 1, s] * [1, 1, 2, s] + [2, s] -> [B, K, 2, s]
            base = base.unsqueeze(2) * self.gamma[None, None, :] + self.beta

            if self.pos_enc:
                base = rope(base, dim=1)

            # [B, K, 2, s] -> [B, K, s], [B, K, s]
            q, k = torch.unbind(base, dim=2)

        else:
            # [B, K, e + s] -> [B, K, e], [B, K, s]
            u, q = torch.split(uv, [self.e, self.s], dim=2)

            # keys and values are projected from the second and third inputs
            k = self.k_fc(k)  # [B, K_k, in_token_dims] -> [B, K_k, s]
            v = self.v_fc(v)  # [B, K_k, in_token_dims] -> [B, K_k, e]

            if self.pos_enc:
                q = rope(q, 1)
                k = rope(k, 1)

        # [B, K, s] x [B, s, K_k] -> [B, K, K_k]
        qk = torch.bmm(q, k.permute(0, 2, 1))

        if self.use_rel_bias:
            if self.attn_type == 'self-attn':
                bias = self.rel_pos_bias(q.size(1))
            else:
                bias = self.rel_pos_bias(q.size(1), k.size(1))
            qk += bias[:, :q.size(1), :k.size(1)]

        # squared-ReLU attention kernel in place of softmax
        kernel = torch.square(F.relu(qk / self.sqrt_s))

        if self.dropout_rate > 0.:
            kernel = self.dropout(kernel)

        # gate the attended values with u:
        # [B, K, e] * ([B, K, K_k] x [B, K_k, e]) -> [B, K, e]
        x = u * torch.bmm(kernel, v)

        # [B, K, e] -> [B, K, out_token_dims]
        x = self.o(x)

        return x
    def forward(self, x):
        """Forward function."""

        if self.shortcut:
            if self.attn_type == 'cross-attn':
                res_shortcut = x[0]
            else:
                res_shortcut = x
            main_branch = self.drop_path(self._forward(x))
            return self.res_scale(res_shortcut) + main_branch
        else:
            return self.drop_path(self._forward(x))
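

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only; the shapes and hyper-parameters
    # below are example values, not ones required by the module).

    # Self-attention GAU: a single [batch, tokens, channels] input whose token
    # count matches `num_token` (needed for the relative position bias).
    self_attn = RTMCCBlock(num_token=17, in_token_dims=256, out_token_dims=256)
    feats = torch.randn(2, 17, 256)
    print(self_attn(feats).shape)  # torch.Size([2, 17, 256])

    # Cross-attention GAU: the input is a (query, key, value) tuple, where key
    # and value may carry a different token count than the query.
    cross_attn = RTMCCBlock(
        num_token=17,
        in_token_dims=256,
        out_token_dims=256,
        attn_type='cross-attn')
    query = torch.randn(2, 17, 256)
    key = value = torch.randn(2, 32, 256)
    print(cross_attn((query, key, value)).shape)  # torch.Size([2, 17, 256])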