# TunaDance/model/model.py
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import Tensor
from torch.nn import functional as F
from model.rotary_embedding_torch import RotaryEmbedding
from model.utils import PositionalEncoding, SinusoidalPosEmb, prob_mask_like
class DenseFiLM(nn.Module):
"""Feature-wise linear modulation (FiLM) generator."""
def __init__(self, embed_channels):
super().__init__()
self.embed_channels = embed_channels
self.block = nn.Sequential(
nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
)
def forward(self, position):
pos_encoding = self.block(position)
pos_encoding = rearrange(pos_encoding, "b c -> b 1 c")
scale_shift = pos_encoding.chunk(2, dim=-1)
return scale_shift
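# FiLM affine transform: out = (1 + scale) * x + shift, applied feature-wise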
def featurewise_affine(x, scale_shift):
scale, shift = scale_shift
return (scale + 1) * x + shift
class TransformerEncoderLayer(nn.Module):
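    """Transformer encoder layer with optional rotary position embeddings.

    Structured like ``nn.TransformerEncoderLayer``; when a rotary embedding is
    provided, queries and keys are rotated before self-attention.
    """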
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
        activation: Callable[[Tensor], Tensor] = F.relu,
layer_norm_eps: float = 1e-5,
batch_first: bool = False,
norm_first: bool = True,
device=None,
dtype=None,
rotary=None,
) -> None:
super().__init__()
self.self_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first
)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm_first = norm_first
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = activation
self.rotary = rotary
self.use_rotary = rotary is not None
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
x = src
if self.norm_first:
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
x = self.norm2(x + self._ff_block(x))
return x
# self-attention block
def _sa_block(
self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
) -> Tensor:
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
x = self.self_attn(
qk,
qk,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout1(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
class FiLMTransformerDecoderLayer(nn.Module):
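    """Transformer decoder layer whose sublayers are modulated by FiLM.

    Each sublayer (self-attention, cross-attention over the conditioning
    memory, and feed-forward) is followed by a FiLM scale/shift generated
    from the diffusion timestep embedding ``t`` before the residual add.
    """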
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward=2048,
dropout=0.1,
activation=F.relu,
layer_norm_eps=1e-5,
batch_first=False,
norm_first=True,
device=None,
dtype=None,
rotary=None,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first
)
self.multihead_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first
)
# Feedforward
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm_first = norm_first
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = activation
self.film1 = DenseFiLM(d_model)
self.film2 = DenseFiLM(d_model)
self.film3 = DenseFiLM(d_model)
self.rotary = rotary
self.use_rotary = rotary is not None
    # forward arguments: tgt = x (noisy motion), memory = cond tokens, t = diffusion timestep embedding
def forward(
self,
tgt,
memory,
t,
tgt_mask=None,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
):
x = tgt
if self.norm_first:
# self-attention -> film -> residual
x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
x = x + featurewise_affine(x_1, self.film1(t))
# cross-attention -> film -> residual
x_2 = self._mha_block(
self.norm2(x), memory, memory_mask, memory_key_padding_mask
)
x = x + featurewise_affine(x_2, self.film2(t))
# feedforward -> film -> residual
x_3 = self._ff_block(self.norm3(x))
x = x + featurewise_affine(x_3, self.film3(t))
else:
x = self.norm1(
x
+ featurewise_affine(
self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
)
)
x = self.norm2(
x
+ featurewise_affine(
self._mha_block(x, memory, memory_mask, memory_key_padding_mask),
self.film2(t),
)
)
x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
return x
# self-attention block
# qkv
def _sa_block(self, x, attn_mask, key_padding_mask):
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
x = self.self_attn(
qk,
qk,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout1(x)
# multihead attention block
# qkv
def _mha_block(self, x, mem, attn_mask, key_padding_mask):
q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
x = self.multihead_attn(
q,
k,
mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout2(x)
# feed forward block
def _ff_block(self, x):
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout3(x)
class DecoderLayerStack(nn.Module):
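    """Applies a stack of FiLM decoder layers in sequence, passing the
    conditioning tokens and timestep embedding to every layer."""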
def __init__(self, stack):
super().__init__()
self.stack = stack
def forward(self, x, cond, t):
for layer in self.stack:
x = layer(x, cond, t)
return x
class SeqModel(nn.Module):
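    """Two-stage sequence model: a body decoder followed by a hand decoder.

    The body network predicts body motion from the conditioning features;
    the hand network is then conditioned on both the original features and
    the predicted body motion, and the two outputs are concatenated.
    """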
def __init__(self,
nfeats: int,
seq_len: int = 150, # 5 seconds, 30 fps
latent_dim: int = 256,
ff_size: int = 1024,
num_layers: int = 4,
num_heads: int = 4,
dropout: float = 0.1,
cond_feature_dim: int = 35,
activation: Callable[[Tensor], Tensor] = F.gelu,
use_rotary=True,
**kwargs
) -> None:
super().__init__()
self.network = nn.ModuleDict()
self.network['body_net'] = DanceDecoder(
nfeats=4+3+22*6,
seq_len=seq_len,
latent_dim=latent_dim,
ff_size=ff_size,
num_layers=num_layers,
num_heads=num_heads,
dropout=dropout,
cond_feature_dim=cond_feature_dim,
activation=activation
)
self.network['hand_net'] = DanceDecoder(
nfeats=30*6,
seq_len=seq_len,
latent_dim=latent_dim,
ff_size=ff_size,
num_layers=num_layers,
num_heads=num_heads,
dropout=dropout,
            cond_feature_dim=35 + 139,  # music features (35) + body_net output features (4 + 3 + 22 * 6 = 139)
activation=activation
)
def forward(self, x: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0):
        # split input features: body channels (4 + 135 = 139 dims) and hand channels (remaining 30 * 6 dims)
        x_body_start = x[:, :, :4 + 135]
        x_hand_start = x[:, :, 4 + 135:]
body_output = self.network['body_net'](x_body_start, cond_embed, times, cond_drop_prob)
        # condition the hand network on the predicted body motion as well as the original conditioning
        cond_embed = torch.cat([body_output, cond_embed], dim=-1)
hand_output = self.network['hand_net'](x_hand_start, cond_embed, times, cond_drop_prob)
output = torch.cat([body_output, hand_output], dim=-1)
return output
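    # classifier-free guidance: blend the unconditional and conditional
    # predictions, extrapolating by guidance_weight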
def guided_forward(self, x, cond_embed, times, guidance_weight):
unc = self.forward(x, cond_embed, times, cond_drop_prob=1)
conditioned = self.forward(x, cond_embed, times, cond_drop_prob=0)
return unc + (conditioned - unc) * guidance_weight
class DanceDecoder(nn.Module):
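    """Diffusion denoising decoder for motion sequences.

    Projects noisy motion and conditioning features into a shared latent
    space, encodes the conditioning with a small Transformer, and decodes
    with FiLM-modulated cross-attention layers driven by the diffusion
    timestep. Supports classifier-free guidance via conditional dropout.
    """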
def __init__(
self,
nfeats: int,
seq_len: int = 150, # 5 seconds, 30 fps
latent_dim: int = 256,
ff_size: int = 1024,
num_layers: int = 4,
num_heads: int = 4,
dropout: float = 0.1,
cond_feature_dim: int = 35,
activation: Callable[[Tensor], Tensor] = F.gelu,
use_rotary=True,
**kwargs
) -> None:
super().__init__()
output_feats = nfeats
# positional embeddings
self.rotary = None
self.abs_pos_encoding = nn.Identity()
# if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
if use_rotary:
self.rotary = RotaryEmbedding(dim=latent_dim)
else:
self.abs_pos_encoding = PositionalEncoding(
latent_dim, dropout, batch_first=True
)
# time embedding processing
self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(latent_dim),  # fixed sinusoidal timestep embedding
nn.Linear(latent_dim, latent_dim * 4),
nn.Mish(),
)
self.to_time_cond = nn.Sequential(nn.Linear(latent_dim * 4, latent_dim),)
self.to_time_tokens = nn.Sequential(
nn.Linear(latent_dim * 4, latent_dim * 2), # 2 time tokens
Rearrange("b (r d) -> b r d", r=2),
)
# null embeddings for guidance dropout
self.null_cond_embed = nn.Parameter(torch.randn(1, seq_len, latent_dim))
self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim))
self.norm_cond = nn.LayerNorm(latent_dim)
# input projection
self.input_projection = nn.Linear(nfeats, latent_dim)
self.cond_encoder = nn.Sequential()
for _ in range(2):
self.cond_encoder.append(
TransformerEncoderLayer(
d_model=latent_dim,
nhead=num_heads,
dim_feedforward=ff_size,
dropout=dropout,
activation=activation,
batch_first=True,
rotary=self.rotary,
)
)
# conditional projection
        self.cond_projection = nn.Linear(cond_feature_dim, latent_dim)  # project conditioning features into the latent space
self.non_attn_cond_projection = nn.Sequential(
nn.LayerNorm(latent_dim),
nn.Linear(latent_dim, latent_dim),
nn.SiLU(),
nn.Linear(latent_dim, latent_dim),
)
# decoder
decoderstack = nn.ModuleList([])
for _ in range(num_layers):
decoderstack.append(
FiLMTransformerDecoderLayer(
latent_dim,
num_heads,
dim_feedforward=ff_size,
dropout=dropout,
activation=activation,
batch_first=True,
rotary=self.rotary,
)
)
self.seqTransDecoder = DecoderLayerStack(decoderstack)
self.final_layer = nn.Linear(latent_dim, output_feats)
def guided_forward(self, x, cond_embed, times, guidance_weight):
unc = self.forward(x, cond_embed, times, cond_drop_prob=1)
conditioned = self.forward(x, cond_embed, times, cond_drop_prob=0)
return unc + (conditioned - unc) * guidance_weight
def forward(
self, x: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0
):
batch_size, device = x.shape[0], x.device
# project to latent space
x = self.input_projection(x)
# add the positional embeddings of the input sequence to provide temporal information
x = self.abs_pos_encoding(x)
# create music conditional embedding with conditional dropout
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
cond_tokens = self.cond_projection(cond_embed)
# encode tokens
cond_tokens = self.abs_pos_encoding(cond_tokens)
cond_tokens = self.cond_encoder(cond_tokens)
null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
cond_tokens = torch.where(keep_mask_embed, cond_tokens, null_cond_embed)
mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)
# create the diffusion timestep embedding, add the extra music projection
t_hidden = self.time_mlp(times)
# project to attention and FiLM conditioning
t = self.to_time_cond(t_hidden)
t_tokens = self.to_time_tokens(t_hidden)
# FiLM conditioning
null_cond_hidden = self.null_cond_hidden.to(t.dtype)
cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
t += cond_hidden
# cross-attention conditioning
c = torch.cat((cond_tokens, t_tokens), dim=-2)
cond_tokens = self.norm_cond(c)
# Pass through the transformer decoder
# attending to the conditional embedding
output = self.seqTransDecoder(x, cond_tokens, t)
output = self.final_layer(output)
return output
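

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original repo): shapes assume
    # the defaults above — 139 body + 180 hand features per frame, 35-dim music
    # conditioning, and a 150-frame (5 s @ 30 fps) window.
    batch, frames = 2, 150
    model = SeqModel(nfeats=4 + 3 + 22 * 6 + 30 * 6)
    x = torch.randn(batch, frames, 4 + 3 + 22 * 6 + 30 * 6)  # noisy motion
    cond = torch.randn(batch, frames, 35)                     # music features
    times = torch.randint(0, 1000, (batch,))                  # diffusion timesteps (1-D, one per sample; integer steps assumed)
    out = model.guided_forward(x, cond, times, guidance_weight=2.0)
    print(out.shape)  # expected: (batch, frames, 139 + 180)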