Spaces:
Build error
Build error
| """ | |
| MIDI Transformer Model | |
| Conditional autoregressive transformer for MIDI generation. | |
| Supports tempo, instrument, and mood conditioning. | |
| """ | |
| import math | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
@dataclass
class MIDITransformerConfig:
    """Hyperparameters for the MIDI Transformer model.

    Declared as a dataclass so that instances can be created with keyword
    overrides (e.g. ``MIDITransformerConfig(d_model=512, n_heads=8)``) and so
    that every instance carries its own copy of each field.  Without the
    decorator the annotated names are only class attributes and the class
    accepts no constructor arguments.
    """

    # Model architecture
    vocab_size: int = 512  # rows of the token embedding / width of the output layer
    max_seq_len: int = 2048  # length of the RoPE cos/sin cache that is precomputed
    d_model: int = 768  # hidden size of the residual stream
    n_heads: int = 12  # attention heads; head size is d_model // n_heads
    n_layers: int = 12  # number of transformer blocks
    d_ff: int = 3072  # nominal FFN width (the SwiGLU FFN uses int(d_ff * 2 / 3))
    dropout: float = 0.1  # dropout probability used throughout the model

    # Conditioning
    # NOTE(review): these three sizes are not read by the model code in this
    # file section — presumably consumed by the tokenizer/data pipeline; verify.
    n_tempo_tokens: int = 32
    n_instrument_tokens: int = 17
    n_mood_tokens: int = 16

    # Training
    tie_weights: bool = True  # share the output projection with the input embedding
    use_flash_attention: bool = True  # use F.scaled_dot_product_attention when present
    gradient_checkpointing: bool = False  # checkpoint each block during training

    # Initialization
    init_std: float = 0.02  # std of the normal init for Linear/Embedding weights
class RotaryPositionalEmbedding(nn.Module):
    """Lookup table of rotary position embedding (RoPE) cos/sin values.

    The table has one row per position and ``dim`` columns (for even ``dim``);
    the second half of the columns repeats the first half, matching the
    layout expected by a "rotate half" style rotation.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # One inverse frequency per channel pair: base ** -(2i / dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))

        # Fill the cos/sin tables up front for the configured length.
        self._update_cache(max_seq_len)

    def _update_cache(self, seq_len: int):
        """(Re)build the cos/sin tables so they cover ``seq_len`` positions."""
        positions = torch.arange(
            seq_len,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        # Outer product position x frequency -> rotation angle.
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        # Non-persistent: the tables are derived data and stay out of state_dict.
        for name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(name, values, persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(cos, sin)`` rows for positions ``0 .. seq_len - 1``.

        ``x`` is accepted for interface compatibility but is not read.  The
        cache grows on demand when ``seq_len`` exceeds its current length.
        """
        if self.cos_cached.shape[0] < seq_len:
            self._update_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [a, b]`` (split evenly along the last axis) this returns
    ``[-b, a]``.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Rotate queries and keys by the given RoPE tables.

    The same ``cos``/``sin`` tensors are applied to both ``q`` and ``k``, so
    they must broadcast against each of them.

    Returns:
        Tuple ``(q_rotated, k_rotated)``.
    """

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Standard RoPE combination: t * cos + rotate_half(t) * sin.
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
class MultiHeadAttention(nn.Module):
    """Multi-head causal self-attention with RoPE and an optional fused kernel.

    When ``config.use_flash_attention`` is set and
    ``F.scaled_dot_product_attention`` exists, the fused PyTorch kernel is
    used; otherwise attention is computed explicitly with matmul/softmax.
    Both paths use the same masking rules (see ``_build_attn_mask``).
    """

    def __init__(self, config: MIDITransformerConfig):
        super().__init__()
        if config.d_model % config.n_heads != 0:
            # Fail early with a readable message instead of a reshape error
            # deep inside forward().
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by "
                f"n_heads ({config.n_heads})"
            )
        self.config = config
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.rope = RotaryPositionalEmbedding(self.head_dim, config.max_seq_len)
        self.dropout = nn.Dropout(config.dropout)

    @staticmethod
    def _build_attn_mask(
        attention_mask: Optional[torch.Tensor],
        q_len: int,
        k_len: int,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """Combine causal and padding constraints into one boolean mask.

        Returns a tensor broadcastable to ``(batch, heads, q_len, k_len)``
        where ``True`` means "may attend", or ``None`` when nothing has to be
        masked (a single query token and no padding mask).
        """
        allowed = None
        if q_len > 1:
            # Queries are the *last* q_len positions of the key sequence, so
            # query i may see keys 0 .. (k_len - q_len + i).
            causal = torch.ones(q_len, k_len, dtype=torch.bool, device=device)
            allowed = causal.tril(diagonal=k_len - q_len).view(1, 1, q_len, k_len)
        if attention_mask is not None:
            # Accept 0/1 integer masks as well as boolean ones.
            padding = attention_mask.bool()[:, None, None, :]
            allowed = padding if allowed is None else allowed & padding
        return allowed

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[tuple] = None,
    ) -> tuple[torch.Tensor, Optional[tuple]]:
        """Apply causal self-attention to ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)``.
            attention_mask: Optional key padding mask; non-zero/True marks
                positions that may be attended to.  Its last dimension must
                match the total key length (cached + new tokens).
            use_cache: Whether to return the key/value tensors for reuse.
            past_kv: Optional ``(keys, values)`` returned by a previous call.

        Returns:
            Tuple ``(output, present_kv)`` where ``output`` has the shape of
            ``x`` and ``present_kv`` is ``(keys, values)`` or ``None``.
        """
        # Split projections into heads: (batch, heads, seq, head_dim).
        q = rearrange(self.q_proj(x), "b s (h d) -> b h s d", h=self.n_heads)
        k = rearrange(self.k_proj(x), "b s (h d) -> b h s d", h=self.n_heads)
        v = rearrange(self.v_proj(x), "b s (h d) -> b h s d", h=self.n_heads)

        # Prepend cached keys/values from earlier generation steps.
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        # The cache holds keys *before* rotation; RoPE is re-applied to the
        # whole key sequence below on every call.
        present_kv = (k, v) if use_cache else None

        q_len, k_len = q.shape[2], k.shape[2]

        # Keys sit at positions 0..k_len-1; queries are the last q_len of them.
        cos, sin = self.rope(q, k_len)
        cos_q = cos[k_len - q_len:].unsqueeze(0).unsqueeze(0)
        sin_q = sin[k_len - q_len:].unsqueeze(0).unsqueeze(0)
        cos_k = cos[:k_len].unsqueeze(0).unsqueeze(0)
        sin_k = sin[:k_len].unsqueeze(0).unsqueeze(0)
        q = (q * cos_q) + (rotate_half(q) * sin_q)
        k = (k * cos_k) + (rotate_half(k) * sin_k)

        if self.config.use_flash_attention and hasattr(F, "scaled_dot_product_attention"):
            # scaled_dot_product_attention rejects an explicit attn_mask
            # combined with is_causal=True, so the built-in causal flag is
            # only used when there is no padding mask and no cache offset;
            # otherwise causality is folded into the explicit mask.
            use_causal_kernel = attention_mask is None and q_len == k_len
            attn_mask = None
            if not use_causal_kernel:
                attn_mask = self._build_attn_mask(attention_mask, q_len, k_len, x.device)
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=use_causal_kernel,
            )
        else:
            # Explicit attention: softmax(QK^T * scale) V.
            scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
            allowed = self._build_attn_mask(attention_mask, q_len, k_len, x.device)
            if allowed is not None:
                scores = scores.masked_fill(~allowed, float("-inf"))
            weights = self.dropout(F.softmax(scores, dim=-1))
            out = torch.matmul(weights, v)

        # Merge heads back into the model dimension and project.
        out = rearrange(out, "b h s d -> b s (h d)")
        out = self.out_proj(out)
        return out, present_kv
class FeedForward(nn.Module):
    """Position-wise feed-forward layer using a SwiGLU gate."""

    def __init__(self, config: MIDITransformerConfig):
        super().__init__()
        model_dim = config.d_model
        # SwiGLU has three projections, so the inner width is scaled to 2/3
        # of the nominal FFN width.
        inner_dim = int(config.d_ff * 2 / 3)
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``dropout(down(silu(gate(x)) * up(x)))``."""
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention then feed-forward.

    Each sub-layer normalizes its input with RMSNorm and adds its output back
    onto the un-normalized residual stream.
    """

    def __init__(self, config: MIDITransformerConfig):
        super().__init__()
        self.config = config
        self.attn_norm = nn.RMSNorm(config.d_model)
        self.attn = MultiHeadAttention(config)
        self.ff_norm = nn.RMSNorm(config.d_model)
        self.ff = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[tuple] = None,
    ) -> tuple[torch.Tensor, Optional[tuple]]:
        """Run the block and return ``(hidden_states, present_kv)``.

        ``present_kv`` is whatever the attention sub-layer returns for its
        cache (``None`` unless ``use_cache`` is set).
        """
        # Attention sub-layer with residual connection.
        attn_out, present_kv = self.attn(
            self.attn_norm(x), attention_mask, use_cache, past_kv
        )
        hidden = x + attn_out

        # Feed-forward sub-layer with residual connection.
        ff_out = self.ff(self.ff_norm(hidden))
        return hidden + ff_out, present_kv
class MIDITransformer(nn.Module):
    """
    Conditional MIDI Transformer for music generation.

    Architecture:
    - Token embeddings with conditioning tokens
    - Rotary positional embeddings (RoPE)
    - Pre-norm transformer blocks with SwiGLU
    - Tied input/output embeddings (optional)
    """

    def __init__(self, config: MIDITransformerConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.embed_dropout = nn.Dropout(config.dropout)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        # Output
        self.norm = nn.RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Tie weights: the output projection reuses the embedding matrix.
        if config.tie_weights:
            self.lm_head.weight = self.embed_tokens.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """Normal(0, init_std) init for Linear/Embedding weights; zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_key_values: Optional[list] = None,
    ) -> dict:
        """
        Forward pass.

        Args:
            input_ids: Token IDs (batch, seq_len)
            attention_mask: Attention mask (batch, seq_len)
            labels: Target token IDs for loss (batch, seq_len)
            use_cache: Whether to return KV cache
            past_key_values: Cached KV for generation (one entry per layer)

        Returns:
            Dict with keys "logits", "loss" (None without labels) and
            "past_key_values" (None unless use_cache is set)
        """
        # Embed tokens
        x = self.embed_tokens(input_ids)
        x = self.embed_dropout(x)

        # Apply transformer blocks
        present_key_values = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values is not None else None
            if self.config.gradient_checkpointing and self.training:
                # Imported here so the submodule is guaranteed to be loaded
                # without relying on `import torch` exposing it as an attribute.
                from torch.utils.checkpoint import checkpoint

                x, present_kv = checkpoint(
                    layer, x, attention_mask, use_cache, past_kv,
                    use_reentrant=False,
                )
            else:
                x, present_kv = layer(x, attention_mask, use_cache, past_kv)
            if use_cache:
                present_key_values.append(present_kv)

        # Output projection
        x = self.norm(x)
        logits = self.lm_head(x)

        # Compute loss if labels provided
        # NOTE: Dataloaders already offset labels by +1 (labels[i] = next token
        # for input_ids[i]), so no extra shift is needed here.
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                labels.view(-1),
                ignore_index=-100,
            )

        return {
            "logits": logits,
            "loss": loss,
            "past_key_values": present_key_values,
        }

    @staticmethod
    def _penalize_repeats(
        logits: torch.Tensor, generated: torch.Tensor, penalty: float
    ) -> torch.Tensor:
        """Lower the score of every token that already occurs in ``generated``.

        Positive logits are divided by ``penalty`` and negative logits are
        multiplied by it, so for ``penalty > 1`` a seen token always becomes
        less likely (dividing a negative logit would raise its probability).
        """
        seen_scores = torch.gather(logits, 1, generated)
        seen_scores = torch.where(
            seen_scores < 0, seen_scores * penalty, seen_scores / penalty
        )
        # Duplicate token ids scatter the same value, so the result is
        # well defined.
        return logits.scatter(1, generated, seen_scores)

    @staticmethod
    def _filter_logits(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Apply top-k then nucleus (top-p) filtering; removed entries become -inf."""
        if top_k > 0:
            # Clamp so a top_k larger than the vocabulary does not crash topk.
            k = min(top_k, logits.size(-1))
            kth_score = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_score, float("-inf"))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove_sorted = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept,
            # and always keep the single most likely token.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            remove = remove_sorted.scatter(1, sorted_indices, remove_sorted)
            logits = logits.masked_fill(remove, float("-inf"))

        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 512,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.1,
        eos_token_id: int = 2,
        pad_token_id: int = 0,
    ) -> torch.Tensor:
        """
        Generate MIDI tokens autoregressively.

        Runs without gradient tracking and switches the module to eval mode
        (it is not switched back afterwards).

        Args:
            input_ids: Conditioning tokens (batch, prefix_len)
            max_length: Maximum total sequence length, prefix included
            temperature: Sampling temperature (must be > 0)
            top_k: Top-k sampling (0 disables)
            top_p: Nucleus sampling threshold (>= 1.0 disables)
            repetition_penalty: Penalty for repeating tokens (1.0 disables)
            eos_token_id: End of sequence token
            pad_token_id: Token written to rows that already emitted EOS

        Returns:
            Generated token sequence (batch, seq_len)

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        self.eval()
        batch_size = input_ids.shape[0]
        device = input_ids.device
        generated = input_ids.clone()
        # Rows that have produced EOS; they are padded from then on.
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        past_key_values = None

        for _ in range(max_length - input_ids.shape[1]):
            # First step feeds the whole prefix; later steps only the newest
            # token, relying on the KV cache for the rest.
            curr_input = generated if past_key_values is None else generated[:, -1:]
            outputs = self.forward(
                curr_input,
                use_cache=True,
                past_key_values=past_key_values,
            )
            logits = outputs["logits"][:, -1, :]
            past_key_values = outputs["past_key_values"]

            if repetition_penalty != 1.0:
                logits = self._penalize_repeats(logits, generated, repetition_penalty)

            logits = logits / temperature
            logits = self._filter_logits(logits, top_k, top_p)

            # Sample
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            # Rows that already ended keep receiving padding, not new samples.
            next_token = next_token.masked_fill(finished.unsqueeze(1), pad_token_id)
            generated = torch.cat([generated, next_token], dim=-1)

            # Stop once every row has emitted EOS (not necessarily this step).
            finished = finished | (next_token.squeeze(1) == eos_token_id)
            if finished.all():
                break

        return generated

    def get_num_params(self, non_embedding: bool = True) -> int:
        """Return the parameter count, optionally excluding the token embedding."""
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.embed_tokens.weight.numel()
        return n_params
# Model size configurations, keyed by preset name.
# The "~NM params" labels are the nominal sizes noted alongside each preset;
# they are not recomputed here.
# NOTE(review): "small" and "base" use identical hyperparameters.
MODEL_CONFIGS = {
    preset: MIDITransformerConfig(
        d_model=width,
        n_heads=heads,
        n_layers=depth,
        d_ff=ffn_width,
    )
    for preset, width, heads, depth, ffn_width in (
        # (name, d_model, n_heads, n_layers, d_ff)
        ("tiny", 512, 8, 8, 2048),       # ~45M params - fast prototyping
        ("small", 768, 12, 12, 3072),    # ~85M params - model_s_raw (raw data learning)
        ("base", 768, 12, 12, 3072),     # ~125M params - balanced
        ("medium", 1024, 16, 16, 4096),  # ~193M params - model_m_augmented (augmented data)
        ("large", 1024, 16, 24, 4096),   # ~350M params - for 1M+ synthetic files
        ("xl", 1536, 24, 32, 6144),      # ~770M params - for 10M+ files
    )
}
def create_model(
    size: str = "base",
    vocab_size: int = 512,
    **kwargs
) -> MIDITransformer:
    """Create a model from a predefined size preset.

    Args:
        size: Key into ``MODEL_CONFIGS``; an unknown name falls back to
            ``"base"``.
        vocab_size: Vocabulary size to set on the configuration.
        **kwargs: Extra configuration overrides.  Names that are not
            attributes of the configuration are ignored.

    Returns:
        A new ``MIDITransformer`` built from a private copy of the preset, so
        the entries of ``MODEL_CONFIGS`` are never modified and separate
        calls do not share a configuration object.
    """
    # Local import keeps this function self-contained with respect to the
    # file's import block.
    import copy

    preset = MODEL_CONFIGS.get(size, MODEL_CONFIGS["base"])
    # Copy before applying overrides: mutating the shared preset would leak
    # vocab_size/kwargs into every later create_model() call.
    config = copy.copy(preset)
    config.vocab_size = vocab_size
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)
    return MIDITransformer(config)