import math
from functools import partial

import torch
from torch import nn, Tensor
from torch.nn import Module
import torch.nn.functional as F
from torch.cuda.amp import autocast

from beartype import beartype
from beartype.typing import Union

from einops import rearrange, repeat, pack, unpack

# the wildcard import is expected to provide the small helpers used below:
# exists, default, divisible_by, is_odd
from modules.audio2motion.cfm.utils import *
from modules.audio2motion.cfm.attend import Attend


class LearnedSinusoidalPosEmb(Module):
    """ learned sinusoidal positional embedding, used by @crowsonkb """

    def __init__(self, dim):
        super().__init__()
        assert divisible_by(dim, 2)
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        # concatenating sin and cos brings the feature dimension back up to `dim`
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        return fouriered


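# Example (illustrative sketch, not part of the original module): embed a batch
# of continuous timesteps into learned Fourier features.
#
#   time_emb = LearnedSinusoidalPosEmb(512)
#   t = torch.rand(4)          # (4,) timesteps in [0, 1)
#   emb = time_emb(t)          # (4, 512)

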
class RotaryEmbedding(Module):
    def __init__(self, dim, theta = 50000):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    @property
    def device(self):
        return self.inv_freq.device

    @autocast(enabled = False)
    @beartype
    def forward(self, t: Union[int, Tensor]):
        # an int is treated as a sequence length and expanded to positions 0..t-1
        if not torch.is_tensor(t):
            t = torch.arange(t, device = self.device)

        t = t.type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs


def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)


@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
    return t * pos.cos() + rotate_half(t) * pos.sin()


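# Example (illustrative sketch): rotate queries / keys by position before attention.
#
#   rotary = RotaryEmbedding(dim = 64)
#   freqs = rotary(128)                       # (128, 64) angles for 128 positions
#   q = torch.randn(2, 8, 128, 64)            # (batch, heads, seq, dim_head)
#   q = apply_rotary_pos_emb(freqs, q)        # same shape, position-rotated

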
class ConvPositionEmbed(Module):
    def __init__(
        self,
        dim,
        *,
        kernel_size,
        groups = None
    ):
        super().__init__()
        assert is_odd(kernel_size)
        # defaults to a depthwise convolution (one group per channel)
        groups = default(groups, dim)

        self.dw_conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
            nn.GELU()
        )

    def forward(self, x):
        x = rearrange(x, 'b n c -> b c n')
        x = self.dw_conv1d(x)
        return rearrange(x, 'b c n -> b n c')


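# Example (illustrative sketch): inject convolutional positional information.
# The residual add here is an assumption about how the caller uses the output;
# the module itself only returns the convolved features.
#
#   conv_pos = ConvPositionEmbed(dim = 512, kernel_size = 31)
#   x = torch.randn(2, 100, 512)              # (batch, seq, dim)
#   x = conv_pos(x) + x

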
class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma


class AdaptiveRMSNorm(Module):
    def __init__(
        self,
        dim,
        cond_dim = None
    ):
        super().__init__()
        cond_dim = default(cond_dim, dim)
        self.scale = dim ** 0.5

        self.to_gamma = nn.Linear(cond_dim, dim)
        self.to_beta = nn.Linear(cond_dim, dim)

        # initialized so the conditioned norm starts as an identity transform:
        # gamma = 1 and beta = 0 regardless of the conditioning input

        nn.init.zeros_(self.to_gamma.weight)
        nn.init.ones_(self.to_gamma.bias)

        nn.init.zeros_(self.to_beta.weight)
        nn.init.zeros_(self.to_beta.bias)

    def forward(self, x, *, cond):
        normed = F.normalize(x, dim = -1) * self.scale

        gamma, beta = self.to_gamma(cond), self.to_beta(cond)
        gamma, beta = map(lambda t: rearrange(t, 'b d -> b 1 d'), (gamma, beta))

        return normed * gamma + beta


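# Example (illustrative sketch): scale and shift the norm from a per-sample
# conditioning vector, e.g. a timestep embedding.
#
#   norm = AdaptiveRMSNorm(dim = 512, cond_dim = 512)
#   x = torch.randn(2, 100, 512)              # (batch, seq, dim)
#   cond = torch.randn(2, 512)                # (batch, cond_dim)
#   out = norm(x, cond = cond)                # (2, 100, 512)

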
class MultiheadRMSNorm(Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.scale = dim ** 0.5
        # one learned gain per attention head
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.gamma * self.scale


class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        flash = False,
        qk_norm = False,
        qk_norm_scale = 10
    ):
        super().__init__()
        self.heads = heads
        dim_inner = dim_head * heads

        # with qk rmsnorm, the usual 1/sqrt(d) scale is replaced by a fixed scale
        scale = qk_norm_scale if qk_norm else None

        self.attend = Attend(dropout, flash = flash, scale = scale)

        self.qk_norm = qk_norm

        if qk_norm:
            self.q_norm = MultiheadRMSNorm(dim_head, heads = heads)
            self.k_norm = MultiheadRMSNorm(dim_head, heads = heads)

        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(self, x, mask = None, rotary_emb = None):
        h = self.heads

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        if exists(rotary_emb):
            q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k))

        out = self.attend(q, k, v, mask = mask)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


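# Example (illustrative sketch): self-attention with a boolean padding mask,
# where True marks positions that may be attended to.
#
#   attn = Attention(dim = 512, dim_head = 64, heads = 8)
#   x = torch.randn(2, 100, 512)
#   mask = torch.ones(2, 100, dtype = torch.bool)
#   out = attn(x, mask = mask)                # (2, 100, 512)

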
class GEGLU(Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x


def FeedForward(dim, mult = 4, dropout = 0.):
    # the 2/3 factor keeps the parameter count of the gated variant roughly equal
    # to a plain feedforward of width dim * mult (GLU Variants, Shazeer 2020)
    dim_inner = int(dim * mult * 2 / 3)
    return nn.Sequential(
        nn.Linear(dim, dim_inner * 2),  # doubled width for the GEGLU value/gate split
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim)
    )


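# Example (illustrative sketch): the feedforward preserves the feature dimension.
#
#   ff = FeedForward(dim = 512, mult = 4)
#   x = torch.randn(2, 100, 512)
#   out = ff(x)                               # (2, 100, 512)

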
class Transformer(Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        num_register_tokens = 0,
        attn_flash = False,
        adaptive_rmsnorm = False,
        adaptive_rmsnorm_cond_dim_in = None,
        use_unet_skip_connection = False,
        skip_connect_scale = None,
        attn_qk_norm = False,
        use_gateloop_layers = False
    ):
        super().__init__()
        # depth must be even so the second half of the layers can consume
        # the skip connections saved by the first half
        assert divisible_by(depth, 2)
        self.layers = nn.ModuleList([])

        self.rotary_emb = RotaryEmbedding(dim = dim_head)

        self.num_register_tokens = num_register_tokens
        self.has_register_tokens = num_register_tokens > 0

        if self.has_register_tokens:
            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        if adaptive_rmsnorm:
            rmsnorm_klass = partial(AdaptiveRMSNorm, cond_dim = adaptive_rmsnorm_cond_dim_in)
        else:
            rmsnorm_klass = RMSNorm

        self.skip_connect_scale = default(skip_connect_scale, 2 ** -0.5)

        for ind in range(depth):
            layer = ind + 1
            has_skip = use_unet_skip_connection and layer > (depth // 2)

            self.layers.append(nn.ModuleList([
                nn.Linear(dim * 2, dim) if has_skip else None,
                # placeholder slot for an optional gateloop layer; `use_gateloop_layers`
                # is accepted but not wired up in this version, so the slot stays None
                None,
                rmsnorm_klass(dim = dim),
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash, qk_norm = attn_qk_norm),
                rmsnorm_klass(dim = dim),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        self.final_norm = RMSNorm(dim)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        x,
        mask = None,
        adaptive_rmsnorm_cond = None
    ):
        batch, seq_len, *_ = x.shape

        # prepend register tokens and extend the mask to cover them

        if self.has_register_tokens:
            register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)

            x, ps = pack([register_tokens, x], 'b * d')

            if exists(mask):
                mask = F.pad(mask, (self.num_register_tokens, 0), value = True)

        # keep track of unet-like skip connections

        skip_connects = []

        # rotary embeddings; register tokens get a large negative position
        # so they sit far away from the main sequence

        positions = seq_len

        if self.has_register_tokens:
            main_positions = torch.arange(seq_len, device = self.device, dtype = torch.long)
            register_positions = torch.full((self.num_register_tokens,), -10000, device = self.device, dtype = torch.long)
            positions = torch.cat((register_positions, main_positions))

        rotary_emb = self.rotary_emb(positions)

        # adaptive rmsnorm conditioning

        rmsnorm_kwargs = dict()
        if exists(adaptive_rmsnorm_cond):
            rmsnorm_kwargs = dict(cond = adaptive_rmsnorm_cond)

        # going through the attention layers

        for skip_combiner, maybe_gateloop, attn_prenorm, attn, ff_prenorm, ff in self.layers:

            # layers without a skip combiner save their input;
            # layers with one pop a saved activation, rescale it, and fuse it in

            if not exists(skip_combiner):
                skip_connects.append(x)
            else:
                skip_connect = skip_connects.pop() * self.skip_connect_scale
                x = torch.cat((x, skip_connect), dim = -1)
                x = skip_combiner(x)

            if exists(maybe_gateloop):
                x = maybe_gateloop(x) + x

            attn_input = attn_prenorm(x, **rmsnorm_kwargs)
            x = attn(attn_input, mask = mask, rotary_emb = rotary_emb) + x

            ff_input = ff_prenorm(x, **rmsnorm_kwargs)
            x = ff(ff_input) + x

        # remove the register tokens before the final norm

        if self.has_register_tokens:
            _, x = unpack(x, ps, 'b * d')

        return self.final_norm(x)


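# Example (illustrative sketch): a time-conditioned transformer with unet skip
# connections, in the style used by flow matching / diffusion models. The choice
# of 512 for both dim and the conditioning dim is an assumption for the sketch.
#
#   model = Transformer(
#       dim = 512,
#       depth = 6,
#       adaptive_rmsnorm = True,
#       adaptive_rmsnorm_cond_dim_in = 512,
#       use_unet_skip_connection = True
#   )
#   x = torch.randn(2, 100, 512)
#   t_emb = LearnedSinusoidalPosEmb(512)(torch.rand(2))   # (2, 512) time conditioning
#   out = model(x, adaptive_rmsnorm_cond = t_emb)         # (2, 100, 512)

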
if __name__ == '__main__':
    transformer = Transformer(dim = 512, depth = 6, dim_head = 64, heads = 8, ff_mult = 4)

    input_tensor = torch.randn(1, 10, 512)

    output = transformer(input_tensor)

    print(output.shape)  # torch.Size([1, 10, 512])
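
    # a second smoke test (illustrative addition): forward with a boolean padding
    # mask, where True marks positions that should be attended to
    mask = torch.ones(1, 10, dtype = torch.bool)
    mask[:, -2:] = False  # pretend the last two frames are padding
    masked_output = transformer(input_tensor, mask = mask)
    print(masked_output.shape)  # torch.Size([1, 10, 512])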