| """ |
| model.py — Music-conditioned causal transformer for pose generation. |
| |
| Architecture: |
| - Audio encoder: linear projection of mel features → d_model |
| - Pose embedding: linear projection of previous pose → d_model |
| - Causal decoder: GPT-style transformer (masked self-attention) |
| attends to both past poses and current+past audio via cross-attention |
| - Output head: linear → pose_dim |
| |
| At inference: autoregressive — feed one frame at a time. |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class SinusoidalPE(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    A ``(1, max_len, d_model)`` table is built once at construction and
    registered as the buffer ``pe``.  ``forward`` adds the first ``T`` rows of
    that table to an input whose dimension 1 has length ``T`` and whose last
    dimension is ``d_model``.

    Args
    ----
    d_model : size of the feature (last) dimension
    max_len : number of positions pre-computed; longer inputs are rejected
    """

    def __init__(self, d_model: int, max_len: int = 4096):
        super().__init__()
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = positions * inv_freq  # (max_len, ceil(d_model / 2))

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so trim the angles to the number of odd-indexed slots.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading singleton dimension lets the table broadcast over the batch.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the positional table added along dimension 1.

        Raises
        ------
        ValueError : if ``x.size(1)`` exceeds the pre-computed ``max_len``
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Without this check the addition below fails with an opaque
            # broadcasting error that does not mention max_len.
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding "
                f"max_len {max_len}"
            )
        return x + self.pe[:, :seq_len]
|
|
|
|
class Music2PoseTransformer(nn.Module):
    """Music-conditioned causal transformer decoder that predicts pose frames.

    Audio frames are projected to ``d_model`` and used as the decoder memory;
    pose frames are projected to ``d_model`` and used as the decoder target.
    Both self-attention and cross-attention are causally masked, so output
    frame ``i`` depends only on pose frames ``0..i`` and audio frames ``0..i``.

    Args
    ----
    audio_dim   : number of audio features per frame (e.g. 82 = 80 mel + 2)
    pose_dim    : flattened pose vector size (e.g. 99 = 33 kpts × 3);
                  note that the default here is 297
    d_model     : transformer hidden size
    nhead       : attention heads
    num_layers  : decoder layers
    dropout     : dropout probability
    max_seq_len : number of positions in the positional-encoding table; this
                  bounds the sequence length at training AND inference,
                  longer inputs make ``pos_enc`` fail
    """

    def __init__(
        self,
        audio_dim: int = 82,
        pose_dim: int = 297,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 6,
        dropout: float = 0.1,
        max_seq_len: int = 512,
    ):
        super().__init__()
        self.pose_dim = pose_dim
        self.d_model = d_model

        # Input embeddings: per-frame linear projection followed by LayerNorm.
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, d_model),
            nn.LayerNorm(d_model),
        )
        self.pose_proj = nn.Sequential(
            nn.Linear(pose_dim, d_model),
            nn.LayerNorm(d_model),
        )

        # One positional encoder shared by the audio and pose streams.
        self.pos_enc = SinusoidalPE(d_model, max_len=max_seq_len)

        # Pre-norm decoder stack operating on (batch, time, feature) tensors.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Output head: d_model -> pose_dim.
        self.out_head = nn.Linear(d_model, pose_dim)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dim.

        One-dimensional parameters (biases, LayerNorm scale/shift) keep the
        values their modules assigned at construction.
        """
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        """Return a ``(seq_len, seq_len)`` bool mask, True where j > i.

        True marks a (query i, key j) pair that must NOT be attended to, so
        position i only sees positions 0..i.
        """
        idx = torch.arange(seq_len, device=device)
        # Row index is the query position, column index the key position.
        return idx.unsqueeze(0) > idx.unsqueeze(1)

    def forward(
        self,
        audio: torch.Tensor,
        poses: torch.Tensor,
    ) -> torch.Tensor:
        """Predict one pose vector per input frame.

        Args
        ----
        audio : (batch, T, audio_dim) audio features
        poses : (batch, T, pose_dim) pose inputs, frame-aligned with ``audio``
                (presumably the previous-frame poses under teacher forcing;
                confirm against the training loop)

        Returns
        -------
        (batch, T, pose_dim) tensor of predicted poses.

        Raises
        ------
        ValueError : if ``audio`` and ``poses`` differ in sequence length
        """
        seq_len = poses.size(1)
        if audio.size(1) != seq_len:
            # The causal alignment below is only defined for frame-aligned
            # streams; fail early instead of deep inside attention.
            raise ValueError(
                f"audio and poses must have the same sequence length, "
                f"got {audio.size(1)} and {seq_len}"
            )

        memory = self.pos_enc(self.audio_proj(audio))
        tgt = self.pos_enc(self.pose_proj(poses))

        causal_mask = self._causal_mask(seq_len, tgt.device)

        # The same mask is used for cross-attention so frame i cannot look at
        # audio frames > i. Without it, training sees future audio that is
        # not available when generating frame by frame.
        out = self.decoder(
            tgt,
            memory,
            tgt_mask=causal_mask,
            memory_mask=causal_mask,
        )
        return self.out_head(out)

    def step(
        self,
        audio_ctx: torch.Tensor,
        pose_ctx: torch.Tensor,
    ) -> torch.Tensor:
        """Return the predicted pose for the NEXT frame given context.

        Runs ``forward`` on the whole context without gradient tracking and
        returns only the last time step, shape (batch, pose_dim).

        NOTE: this does not switch the module to eval mode, so dropout stays
        active unless the caller has called ``.eval()`` beforehand.
        """
        with torch.no_grad():
            pred = self.forward(audio_ctx, pose_ctx)
        return pred[:, -1, :]
|
|