import random

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F

from beartype.typing import Tuple

from einops import rearrange, repeat, reduce

from modules.audio2motion.cfm.utils import (
    exists, default, coin_flip,
    prob_mask_like, reduce_masks_with_and, interpolate_1d, mask_from_frac_lengths
)
from modules.audio2motion.cfm.module import ConvPositionEmbed, LearnedSinusoidalPosEmb, Transformer


class InContextTransformerAudio2Motion(Module):
    """
    Transformer that denoises motion latents conditioned on audio features,
    using in-context (masked-span) conditioning for conditional flow matching.
    """

    def __init__(
        self,
        *,
        dim_in = 64,
        dim_audio_in = 1024,
        dim = 1024,
        depth = 24,
        dim_head = 64,
        heads = 16,
        ff_mult = 4,
        ff_dropout = 0.,
        time_hidden_dim = None,
        conv_pos_embed_kernel_size = 31,
        conv_pos_embed_groups = None,
        attn_dropout = 0.,
        attn_flash = False,
        attn_qk_norm = True,
        use_gateloop_layers = False,
        num_register_tokens = 16,
        frac_lengths_mask: Tuple[float, float] = (0.7, 1.),
    ):
        super().__init__()
        dim_in = default(dim_in, dim)

        time_hidden_dim = default(time_hidden_dim, dim * 4)

        self.proj_in = nn.Identity()

        # time conditioning: sinusoidal embedding -> MLP, fed to the adaptive RMSNorm layers
        self.sinu_pos_emb = nn.Sequential(
            LearnedSinusoidalPosEmb(dim),
            nn.Linear(dim, time_hidden_dim),
            nn.SiLU()
        )

        # project audio features down to the motion latent dimension if they differ
        self.dim_audio_in = dim_audio_in
        if self.dim_audio_in != dim_in:
            self.to_cond_emb = nn.Linear(self.dim_audio_in, dim_in)
        else:
            self.to_cond_emb = nn.Identity()

        # range of span fractions masked out (to be predicted) during training
        self.frac_lengths_mask = frac_lengths_mask

        # fuse (noised input, audio embedding, context), each of width dim_in, into the model dimension
        self.to_embed = nn.Linear(dim_in * 3, dim)

        # learned embedding that replaces the context when conditioning is dropped (classifier-free guidance)
        self.null_cond = nn.Parameter(torch.zeros(dim_in))

        self.conv_embed = ConvPositionEmbed(
            dim = dim,
            kernel_size = conv_pos_embed_kernel_size,
            groups = conv_pos_embed_groups
        )

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            ff_mult = ff_mult,
            ff_dropout = ff_dropout,
            attn_dropout = attn_dropout,
            attn_flash = attn_flash,
            attn_qk_norm = attn_qk_norm,
            num_register_tokens = num_register_tokens,
            adaptive_rmsnorm = True,
            adaptive_rmsnorm_cond_dim_in = time_hidden_dim,
            use_gateloop_layers = use_gateloop_layers
        )

        dim_out = dim_in
        self.to_pred = nn.Linear(dim, dim_out, bias = False)

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.inference_mode()
    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        # classifier-free guidance: blend conditional and unconditional predictions
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1.:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        x,
        *,
        times,
        cond_audio,
        self_attn_mask = None,
        cond_drop_prob = 0.1,
        target = None,
        cond = None,
        cond_mask = None,
        ret = None
    ):
        if ret is None:
            ret = {}

        x = self.proj_in(x)

        if exists(cond):
            cond = self.proj_in(cond)

        cond = default(cond, x)

        batch, seq_len, cond_dim = cond.shape
        assert cond_dim == x.shape[-1]

        # broadcast a scalar or single-element time to the batch dimension
        if times.ndim == 0:
            times = repeat(times, '-> b', b = cond.shape[0])

        if times.ndim == 1 and times.shape[0] == 1:
            times = repeat(times, '1 -> b', b = cond.shape[0])

        # construct the conditioning mask: True marks frames to be predicted,
        # False marks frames kept as in-context conditioning
        if self.training:
            if not exists(cond_mask):
                if coin_flip():
                    # mask a contiguous span covering a random fraction of the sequence
                    frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
                    cond_mask = mask_from_frac_lengths(seq_len, frac_lengths)
                else:
                    # mask random individual frames with a probability drawn from the same range
                    p_drop_prob = self.frac_lengths_mask[0] + random.random() * (self.frac_lengths_mask[1] - self.frac_lengths_mask[0])
                    cond_mask = prob_mask_like((batch, seq_len), p_drop_prob, self.device)
        else:
            if not exists(cond_mask):
                # at inference, predict every frame by default
                cond_mask = torch.ones((batch, seq_len), device = cond.device, dtype = torch.bool)

        cond_mask_with_pad_dim = rearrange(cond_mask, '... -> ... 1')

        # keep the noised input only where it is to be predicted,
        # and the clean context only where it serves as conditioning
        x = x * cond_mask_with_pad_dim
        cond = cond * ~cond_mask_with_pad_dim

        # classifier-free guidance: randomly replace the context with the learned null embedding
        if cond_drop_prob > 0.:
            cond_drop_mask = prob_mask_like(cond.shape[:1], cond_drop_prob, self.device)

            cond = torch.where(
                rearrange(cond_drop_mask, '... -> ... 1 1'),
                self.null_cond,
                cond
            )

        # embed the audio, resampling it (and the attention mask) to the motion sequence length if needed
        cond_audio_emb = self.to_cond_emb(cond_audio)
        cond_audio_emb_length = cond_audio_emb.shape[-2]
        if cond_audio_emb_length != seq_len:
            cond_audio_emb = rearrange(cond_audio_emb, 'b n d -> b d n')
            cond_audio_emb = interpolate_1d(cond_audio_emb, seq_len)
            cond_audio_emb = rearrange(cond_audio_emb, 'b d n -> b n d')
            if exists(self_attn_mask):
                self_attn_mask = interpolate_1d(self_attn_mask, seq_len)

        # concatenate noised input, audio embedding and context along the feature dimension
        to_concat = [*filter(exists, (x, cond_audio_emb, cond))]
        embed = torch.cat(to_concat, dim = -1)

        x = self.to_embed(embed)

        x = self.conv_embed(x) + x

        time_emb = self.sinu_pos_emb(times)

        x = self.transformer(
            x,
            mask = self_attn_mask,
            adaptive_rmsnorm_cond = time_emb
        )

        x = self.to_pred(x)

        ret['pred'] = x

        # without a target, return the raw prediction (inference)
        if not exists(target):
            return x

        # with a target, return an MSE restricted to the predicted (masked) frames
        loss_mask = reduce_masks_with_and(cond_mask, self_attn_mask)
        if not exists(loss_mask):
            loss = F.mse_loss(x, target)
            ret['mse'] = loss
            return loss

        ret['loss_mask'] = loss_mask
        loss = F.mse_loss(x, target, reduction = 'none')

        loss = reduce(loss, 'b n d -> b n', 'mean')
        loss = loss.masked_fill(~loss_mask, 0.)

        # average the per-frame loss over only the unmasked frames of each sequence
        num = reduce(loss, 'b n -> b', 'sum')
        den = loss_mask.sum(dim = -1).clamp(min = 1e-5)
        loss = num / den
        loss = loss.mean()
        ret['mse'] = loss
        return loss
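
# A minimal sampling sketch (an assumption, not part of this module's API):
# it integrates the learned flow with a fixed-step Euler solver from t = 0
# (noise) to t = 1 (data), treating the network output as the velocity field,
# as is conventional for conditional flow matching. The helper name
# `sample_euler` and the `num_steps` knob are hypothetical; the surrounding
# project may use a different solver or time schedule. Assumes `model.eval()`
# has been called so no random training masks are drawn.
@torch.inference_mode()
def sample_euler(model, cond_audio, cond = None, cond_mask = None, num_steps = 32, cond_scale = 1.):
    batch, seq_len = cond_audio.shape[0], cond_audio.shape[1]

    # start from Gaussian noise in the motion latent space (dim_in = null_cond size)
    x = torch.randn(batch, seq_len, model.null_cond.shape[0], device = model.device)

    times = torch.linspace(0., 1., num_steps + 1, device = model.device)

    for t, t_next in zip(times[:-1], times[1:]):
        # forward() broadcasts a 0-dim time over the batch
        flow = model.forward_with_cond_scale(
            x,
            times = t,
            cond_audio = cond_audio,
            cond = cond,
            cond_mask = cond_mask,
            cond_scale = cond_scale
        )
        x = x + (t_next - t) * flow  # Euler step along the predicted flow

    return x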


if __name__ == '__main__':
    model = InContextTransformerAudio2Motion()
    # eval mode keeps the two guidance passes from drawing different random training masks
    model.eval()

    # batch of 2 sequences: 125 frames of 64-dim motion latents and 1024-dim audio features
    input_tensor = torch.randn(2, 125, 64)
    time_tensor = torch.rand(2)
    audio_tensor = torch.rand(2, 125, 1024)

    # classifier-free-guidance forward pass; output has the same shape as the input
    output = model.forward_with_cond_scale(input_tensor, times = time_tensor, cond_audio = audio_tensor, cond = input_tensor)

    print(output.shape)
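
    # A hedged training-style sketch: with `target` supplied, forward()
    # returns the masked MSE instead of the prediction. The random `target`
    # below is a stand-in; in real training, `x` would be the noised
    # interpolant and `target` the flow-matching regression target.
    model.train()
    target_tensor = torch.randn(2, 125, 64)
    loss = model(input_tensor, times = time_tensor, cond_audio = audio_tensor, target = target_tensor)
    print(loss)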