# ROOM / modelw / model.py
# solo363614's picture
# Upload folder using huggingface_hub
# aed1d05 verified
"""
MIDI Transformer Model
Conditional autoregressive transformer for MIDI generation.
Supports tempo, instrument, and mood conditioning.
"""
import math
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
@dataclass
class MIDITransformerConfig:
    """Configuration for MIDI Transformer model.

    The head geometry is validated on construction so that an unusable
    configuration fails immediately instead of deep inside the attention
    layer's reshape or rotary-embedding broadcast.

    Raises:
        ValueError: If ``n_heads`` is not positive, ``d_model`` is not a
            multiple of ``n_heads``, or the resulting per-head dimension is
            odd (rotary embeddings rotate the head dimension in two halves).
    """

    # Model architecture
    vocab_size: int = 512  # rows of the token embedding / width of the LM head
    max_seq_len: int = 2048  # initial length of the precomputed RoPE cache
    d_model: int = 768  # hidden width of the residual stream
    n_heads: int = 12  # attention heads; head_dim = d_model // n_heads
    n_layers: int = 12  # number of transformer blocks
    d_ff: int = 3072  # nominal FFN width; FeedForward uses int(d_ff * 2 / 3)
    dropout: float = 0.1

    # Conditioning
    # NOTE(review): these three counts are not read by any code visible in
    # this file; presumably the tokenizer reserves vocab slots for them.
    n_tempo_tokens: int = 32
    n_instrument_tokens: int = 17
    n_mood_tokens: int = 16

    # Training
    tie_weights: bool = True  # share the embedding matrix with the LM head
    use_flash_attention: bool = True  # use F.scaled_dot_product_attention if present
    gradient_checkpointing: bool = False  # only honoured in training mode

    # Initialization
    init_std: float = 0.02  # std of the normal init for Linear/Embedding weights

    def __post_init__(self) -> None:
        if self.n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {self.n_heads}")
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        head_dim = self.d_model // self.n_heads
        if head_dim % 2 != 0:
            raise ValueError(
                f"head dimension ({head_dim}) must be even for rotary embeddings"
            )
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) lookup tables.

    Holds the inverse frequencies as a persistent buffer and keeps
    non-persistent cos/sin tables that are rebuilt whenever a longer
    sequence than the cached one is requested.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim, self.max_seq_len, self.base = dim, max_seq_len, base

        # One inverse frequency per pair of channels.
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))

        # Fill the cos/sin tables up to the configured length.
        self._update_cache(max_seq_len)

    def _update_cache(self, seq_len: int):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        positions = torch.arange(
            seq_len,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        # Outer product position x frequency -> (seq_len, dim / 2).
        angles = positions[:, None] * self.inv_freq[None, :]
        # Duplicate so the table spans both halves of the head dimension.
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (
            ("cos_cached", table.cos()),
            ("sin_cached", table.sin()),
        ):
            self.register_buffer(buffer_name, values, persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(cos, sin)`` tables truncated to ``seq_len`` rows.

        ``x`` is accepted for interface symmetry but is not read.
        """
        cached_len = self.cos_cached.shape[0]
        if cached_len < seq_len:
            self._update_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension split as ``[a, b]`` the result is ``[-b, a]``.
    """
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Rotate queries and keys with the same cos/sin tables.

    Returns:
        Tuple ``(q_rotated, k_rotated)``.
    """

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Standard RoPE form: t * cos + rotate_half(t) * sin.
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with RoPE and optional Flash Attention.

    Keys and values are cached *before* rotation, so every call re-applies
    the rotary embedding to the full key sequence while queries take the
    last ``q_len`` positions of that sequence.

    Masking is expressed as one boolean "may attend" mask that combines the
    causal constraint (offset by the cache length) with the optional key
    padding mask. The fused ``is_causal`` flag is only used when no explicit
    mask is needed, because ``scaled_dot_product_attention`` rejects an
    explicit ``attn_mask`` combined with ``is_causal=True``.
    """

    def __init__(self, config: MIDITransformerConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.rope = RotaryPositionalEmbedding(self.head_dim, config.max_seq_len)
        self.dropout = nn.Dropout(config.dropout)

    @staticmethod
    def _allowed_mask(
        attention_mask: Optional[torch.Tensor],
        q_len: int,
        k_len: int,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """Build a bool mask (True = may attend) broadcastable to (b, h, q, k).

        Returns ``None`` when nothing has to be hidden (a single query token
        and no padding mask).
        """
        allowed = None
        if q_len > 1:
            # Query i sits at absolute position (k_len - q_len) + i and may
            # see every key up to and including that position.
            allowed = torch.ones(
                q_len, k_len, dtype=torch.bool, device=device
            ).tril(diagonal=k_len - q_len)[None, None]
        if attention_mask is not None:
            # Accept bool or 0/1 integer masks; non-zero means "real token".
            key_ok = attention_mask.to(device=device, dtype=torch.bool)[:, None, None, :]
            allowed = key_ok if allowed is None else allowed & key_ok
        return allowed

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[tuple] = None,
    ) -> tuple[torch.Tensor, Optional[tuple]]:
        """Run self-attention over ``x``.

        Args:
            x: Hidden states of shape (batch, seq_len, d_model).
            attention_mask: Optional key padding mask whose last dimension
                must cover every key position (cached plus new); truthy
                entries are attended to. A query row whose keys are all
                masked yields NaNs, as with any softmax over an empty set.
            use_cache: When True, also return the un-rotated (k, v) pair.
            past_kv: Previously returned (k, v) pair to prepend.

        Returns:
            Tuple of the attention output (batch, seq_len, d_model) and the
            cache pair (or ``None`` when ``use_cache`` is False).
        """
        batch_size, q_len, _ = x.shape

        # Project and split into heads: (batch, heads, seq, head_dim).
        q = rearrange(self.q_proj(x), "b s (h d) -> b h s d", h=self.n_heads)
        k = rearrange(self.k_proj(x), "b s (h d) -> b h s d", h=self.n_heads)
        v = rearrange(self.v_proj(x), "b s (h d) -> b h s d", h=self.n_heads)

        # Handle KV cache for generation
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present_kv = (k, v) if use_cache else None
        k_len = k.shape[2]

        # Apply RoPE: queries occupy the trailing q_len positions of the keys.
        cos, sin = self.rope(q, k_len)
        cos_q = cos[-q_len:].unsqueeze(0).unsqueeze(0)
        sin_q = sin[-q_len:].unsqueeze(0).unsqueeze(0)
        cos_k = cos[:k_len].unsqueeze(0).unsqueeze(0)
        sin_k = sin[:k_len].unsqueeze(0).unsqueeze(0)
        q = (q * cos_q) + (rotate_half(q) * sin_q)
        k = (k * cos_k) + (rotate_half(k) * sin_k)

        use_sdpa = self.config.use_flash_attention and hasattr(
            F, "scaled_dot_product_attention"
        )
        # The fused causal flag is only valid without an explicit mask and
        # without a cache (it aligns the triangle to the top-left corner).
        fused_causal = use_sdpa and attention_mask is None and k_len == q_len
        allowed = (
            None
            if fused_causal
            else self._allowed_mask(attention_mask, q_len, k_len, x.device)
        )

        if use_sdpa:
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=allowed,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=fused_causal,
            )
        else:
            # Standard attention
            scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
            if allowed is not None:
                scores = scores.masked_fill(~allowed, float("-inf"))
            weights = self.dropout(F.softmax(scores, dim=-1))
            out = torch.matmul(weights, v)

        # Merge heads and project back to d_model.
        out = rearrange(out, "b h s d -> b s (h d)")
        out = self.out_proj(out)
        return out, present_kv
class FeedForward(nn.Module):
    """Gated feed-forward network (SwiGLU) followed by dropout."""

    def __init__(self, config: MIDITransformerConfig):
        super().__init__()
        # The gated variant uses two thirds of the nominal FFN width.
        inner_width = int(config.d_ff * 2 / 3)
        model_width = config.d_model
        self.gate_proj = nn.Linear(model_width, inner_width, bias=False)
        self.up_proj = nn.Linear(model_width, inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, model_width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``dropout(down(silu(gate(x)) * up(x)))``."""
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention then feed-forward.

    Each sub-layer normalises its input with RMSNorm and adds its output
    back onto the un-normalised residual stream.
    """

    def __init__(self, config: MIDITransformerConfig):
        super().__init__()
        self.config = config
        self.attn_norm = nn.RMSNorm(config.d_model)
        self.attn = MultiHeadAttention(config)
        self.ff_norm = nn.RMSNorm(config.d_model)
        self.ff = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[tuple] = None,
    ) -> tuple[torch.Tensor, Optional[tuple]]:
        """Apply the layer and return ``(hidden_states, present_kv)``."""
        # Attention sub-layer with residual connection.
        attn_out, present_kv = self.attn(
            self.attn_norm(x), attention_mask, use_cache, past_kv
        )
        hidden = x + attn_out

        # Feed-forward sub-layer with residual connection.
        hidden = hidden + self.ff(self.ff_norm(hidden))
        return hidden, present_kv
class MIDITransformer(nn.Module):
    """
    Conditional MIDI Transformer for music generation.

    Architecture:
    - Token embeddings with conditioning tokens
    - Rotary positional embeddings (RoPE)
    - Pre-norm transformer blocks with SwiGLU
    - Tied input/output embeddings (optional)
    """

    def __init__(self, config: MIDITransformerConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.embed_dropout = nn.Dropout(config.dropout)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        # Output
        self.norm = nn.RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Tie weights: the LM head reuses the embedding matrix.
        if config.tie_weights:
            self.lm_head.weight = self.embed_tokens.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """Normal(0, init_std) for Linear/Embedding weights, zeros for biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_key_values: Optional[list] = None,
    ) -> dict:
        """
        Forward pass.

        Args:
            input_ids: Token IDs (batch, seq_len)
            attention_mask: Attention mask (batch, seq_len)
            labels: Target token IDs for loss (batch, seq_len)
            use_cache: Whether to return KV cache
            past_key_values: Cached KV for generation

        Returns:
            Dict with logits, loss, and optional cache
        """
        # Embed tokens
        x = self.embed_tokens(input_ids)
        x = self.embed_dropout(x)

        use_checkpointing = self.config.gradient_checkpointing and self.training
        if use_checkpointing:
            # Imported explicitly: `import torch` alone is not guaranteed to
            # expose the torch.utils.checkpoint submodule.
            from torch.utils.checkpoint import checkpoint

        # Apply transformer blocks
        present_key_values = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values is not None else None
            if use_checkpointing:
                x, present_kv = checkpoint(
                    layer, x, attention_mask, use_cache, past_kv,
                    use_reentrant=False,
                )
            else:
                x, present_kv = layer(x, attention_mask, use_cache, past_kv)
            if use_cache:
                present_key_values.append(present_kv)

        # Output projection
        x = self.norm(x)
        logits = self.lm_head(x)

        # Compute loss if labels provided
        # NOTE: Dataloaders already offset labels by +1 (labels[i] = next token
        # for input_ids[i]), so no extra shift is needed here.
        loss = None
        if labels is not None:
            # reshape (not view) so non-contiguous label tensors also work.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
                ignore_index=-100,
            )

        return {
            "logits": logits,
            "loss": loss,
            "past_key_values": present_key_values,
        }

    @staticmethod
    def _apply_repetition_penalty(
        logits: torch.Tensor, token_ids: torch.Tensor, penalty: float
    ) -> torch.Tensor:
        """Make every token already present in ``token_ids`` less likely."""
        seen = logits.gather(1, token_ids)
        # Dividing a negative logit would *raise* its probability, so negative
        # scores are multiplied by the penalty instead.
        seen = torch.where(seen < 0, seen * penalty, seen / penalty)
        return logits.scatter(1, token_ids, seen)

    @staticmethod
    def _filter_top_k_top_p(
        logits: torch.Tensor, top_k: int, top_p: float
    ) -> torch.Tensor:
        """Mask logits outside the top-k set and outside the top-p nucleus."""
        # Top-k filtering (k is clamped so it can never exceed the vocabulary).
        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))

        # Top-p (nucleus) filtering
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            drop_sorted = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept,
            # and never drop the single most likely token.
            drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
            drop_sorted[..., 0] = False
            drop = drop_sorted.scatter(1, sorted_indices, drop_sorted)
            logits = logits.masked_fill(drop, float("-inf"))
        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 512,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.1,
        eos_token_id: int = 2,
        pad_token_id: int = 0,
    ) -> torch.Tensor:
        """
        Generate MIDI tokens autoregressively.

        Puts the model in eval mode (it is not switched back afterwards).
        Sequences that have produced ``eos_token_id`` are continued with
        ``pad_token_id`` until every sequence in the batch has finished or
        ``max_length`` is reached.

        Args:
            input_ids: Conditioning tokens (batch, prefix_len)
            max_length: Maximum total sequence length (prefix included)
            temperature: Sampling temperature (must be > 0)
            top_k: Top-k sampling (0 disables; clamped to the vocabulary size)
            top_p: Nucleus sampling threshold (>= 1.0 disables)
            repetition_penalty: Penalty for repeating tokens (1.0 disables)
            eos_token_id: End of sequence token
            pad_token_id: Padding token emitted for already-finished rows

        Returns:
            Generated token sequence (batch, seq_len)

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        self.eval()
        batch_size = input_ids.shape[0]
        device = input_ids.device
        generated = input_ids.clone()
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        past_key_values = None

        for _ in range(max_length - input_ids.shape[1]):
            # First step consumes the whole prefix; later steps only the
            # newest token, relying on the KV cache for the rest.
            curr_input = generated if past_key_values is None else generated[:, -1:]
            outputs = self.forward(
                curr_input,
                use_cache=True,
                past_key_values=past_key_values,
            )
            logits = outputs["logits"][:, -1, :]
            past_key_values = outputs["past_key_values"]

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                logits = self._apply_repetition_penalty(
                    logits, generated, repetition_penalty
                )

            # Apply temperature, then restrict the candidate set.
            logits = logits / temperature
            logits = self._filter_top_k_top_p(logits, top_k, top_p)

            # Sample
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Rows that already ended keep emitting padding.
            next_token = next_token.masked_fill(finished.unsqueeze(1), pad_token_id)
            generated = torch.cat([generated, next_token], dim=-1)

            # Check for EOS per sequence; stop once all rows are done.
            finished = finished | (next_token.squeeze(1) == eos_token_id)
            if finished.all():
                break

        return generated

    def get_num_params(self, non_embedding: bool = True) -> int:
        """Get number of parameters.

        Args:
            non_embedding: When True, subtract the token-embedding matrix
                from the total.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.embed_tokens.weight.numel()
        return n_params
# Model size configurations, keyed by preset name.
# Each row is (name, d_model, n_heads, n_layers, d_ff); every other field
# keeps its MIDITransformerConfig default.
# NOTE(review): "small" and "base" use identical hyperparameters, so their
# differing parameter estimates (~85M vs ~125M) cannot both hold - confirm.
MODEL_CONFIGS = {
    preset: MIDITransformerConfig(
        d_model=width,
        n_heads=heads,
        n_layers=depth,
        d_ff=ff_width,
    )
    for preset, width, heads, depth, ff_width in (
        ("tiny", 512, 8, 8, 2048),  # ~45M params - fast prototyping
        ("small", 768, 12, 12, 3072),  # ~85M params - model_s_raw (raw data learning)
        ("base", 768, 12, 12, 3072),  # ~125M params - balanced
        ("medium", 1024, 16, 16, 4096),  # ~193M params - model_m_augmented (augmented data)
        ("large", 1024, 16, 24, 4096),  # ~350M params - for 1M+ synthetic files
        ("xl", 1536, 24, 32, 6144),  # ~770M params - for 10M+ files
    )
}
def create_model(
    size: str = "base",
    vocab_size: int = 512,
    **kwargs
) -> MIDITransformer:
    """Create a model with predefined configuration.

    The preset stored in ``MODEL_CONFIGS`` is copied before any override is
    applied, so the shared preset is never modified and one call cannot leak
    its ``vocab_size`` or keyword overrides into a later call.

    Args:
        size: Preset name; unknown names fall back to ``"base"``.
        vocab_size: Vocabulary size to set on the configuration.
        **kwargs: Extra configuration overrides. Keys that are not
            attributes of the configuration are ignored.

    Returns:
        A freshly constructed ``MIDITransformer``.
    """
    import copy  # local import: keeps this fix independent of the file header

    preset = MODEL_CONFIGS.get(size, MODEL_CONFIGS["base"])
    # Shallow copy is sufficient: every config field is an immutable scalar.
    config = copy.copy(preset)
    config.vocab_size = vocab_size
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)
    return MIDITransformer(config)