"""Minimal decoder-only GPT for MIDI token LM (interpretability-friendly).
Architecture (Pre-LN, GPT-2 style):
tok_emb + pos_emb
→ repeat: x += attn(LN(x)); x += mlp(LN(x))
→ LN → logits
Causal self-attention is implemented explicitly so attention weights can be
returned for probing (``return_attn=True``).
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizer import (
ID2TOKEN,
N_POS_BINS,
PITCH_MAX,
PITCH_MIN,
VOCAB_SIZE,
)
SCALE_DEGREE_NONE = 12 # sentinel for "no key context / not a pitch"
@dataclass
class GPTConfig:
vocab_size: int = VOCAB_SIZE
block_size: int = 1024
d_model: int = 512
n_layers: int = 6
n_heads: int = 8
d_ff: int = 2048
dropout: float = 0.1
# Compound-embedding axes (additive on top of token embedding).
use_pitch_class_embed: bool = True # adds 13 (pc 0..11 + sentinel)
use_octave_embed: bool = True # adds 9 (oct 0..7 + sentinel)
    use_interval_embed: bool = True  # adds 28 (-13..13 + sentinel) for melodic interval
use_beat_cyclic_embed: bool = True # adds N_POS_BINS+1 for beat-within-bar
use_scale_degree_embed: bool = True # adds 13 (chromatic 0..11 + sentinel) relative to current key
def default_gpt_config() -> GPTConfig:
"""Recommended starter config (~10M params with weight tying)."""
return GPTConfig()
# --- Static per-token feature lookups -----------------------------------------
# Sentinel indices meaning "no pitch class" / "no octave" (non-pitch tokens).
PC_NONE = 12
OCT_NONE = 8
def _build_token_pitch_feature_tables() -> Tuple[torch.Tensor, torch.Tensor]:
"""Per-token-id buffers giving (pitch_class, octave) for pitch tokens
and sentinels otherwise.
"""
pc = torch.full((VOCAB_SIZE,), PC_NONE, dtype=torch.long)
oct_ = torch.full((VOCAB_SIZE,), OCT_NONE, dtype=torch.long)
for tid, name in ID2TOKEN.items():
if (
name.startswith("P")
and not name.startswith("POS")
and name[1:].isdigit()
):
midi = int(name[1:])
if PITCH_MIN <= midi <= PITCH_MAX:
pc[tid] = midi % 12
oct_[tid] = max(0, min(7, midi // 12 - 1))
return pc, oct_
def _is_pitch_token_mask() -> torch.Tensor:
"""Boolean mask of length VOCAB_SIZE: True for pitch tokens."""
mask = torch.zeros(VOCAB_SIZE, dtype=torch.bool)
for tid, name in ID2TOKEN.items():
if (
name.startswith("P")
and not name.startswith("POS")
and name[1:].isdigit()
):
mask[tid] = True
return mask
def _midi_for_pitch_token(tid: int) -> int:
"""MIDI number for a pitch token id, or -1 if not a pitch token."""
name = ID2TOKEN.get(tid, "")
if name.startswith("P") and not name.startswith("POS") and name[1:].isdigit():
return int(name[1:])
return -1
def _build_pitch_to_midi() -> torch.Tensor:
"""Per-token-id MIDI value for pitch tokens, -1 elsewhere."""
arr = torch.full((VOCAB_SIZE,), -1, dtype=torch.long)
for tid in range(VOCAB_SIZE):
arr[tid] = _midi_for_pitch_token(tid)
return arr
def _build_pos_token_value() -> torch.Tensor:
"""For each token id, the POS bin value if it's a POS token, else -1."""
arr = torch.full((VOCAB_SIZE,), -1, dtype=torch.long)
for tid, name in ID2TOKEN.items():
if name.startswith("POS") and name[3:].isdigit():
arr[tid] = int(name[3:])
return arr
def _build_key_token_root() -> torch.Tensor:
"""For each token id, the root pitch class for KEY tokens (0..11),
else -1. KEY_0..11 are major keys C..B; KEY_12..23 are minor keys C..B.
"""
arr = torch.full((VOCAB_SIZE,), -1, dtype=torch.long)
for tid, name in ID2TOKEN.items():
if name.startswith("KEY_") and name[4:].isdigit():
arr[tid] = int(name[4:]) % 12
return arr
# Interval embedding: clipped to [-13..13] with 27 = sentinel (no interval).
INTERVAL_RANGE = 13
INTERVAL_NONE = 2 * INTERVAL_RANGE + 1 # = 27
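# Interval id layout: a clipped melodic delta d in [-13, 13] maps to id d + 13,
# so -13 -> 0, 0 (repeated pitch) -> 13, +13 -> 26, and 27 is the "no interval" sentinel.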
class CausalSelfAttention(nn.Module):
"""Multi-head causal self-attention with optional weight return."""
def __init__(
self,
d_model: int,
n_heads: int,
block_size: int,
dropout: float,
) -> None:
super().__init__()
if d_model % n_heads != 0:
raise ValueError("d_model must be divisible by n_heads")
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(d_model, 3 * d_model)
self.proj = nn.Linear(d_model, d_model)
self.attn_drop = nn.Dropout(dropout)
self.resid_drop = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
return_attn: bool = False,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[Tuple[torch.Tensor, torch.Tensor]],
]:
B, Tq, C = x.shape
qkv = self.qkv(x)
qkv = qkv.view(B, Tq, 3, self.n_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
if past_kv is not None:
past_k, past_v = past_kv
k = torch.cat([past_k, k], dim=2)
v = torch.cat([past_v, v], dim=2)
att = (q @ k.transpose(-2, -1)) * self.scale
Tk = k.size(2)
past_len = Tk - Tq
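        # With a cache of past_len previous steps, keys/values span Tk = past_len + Tq
        # positions; query i sits at global position past_len + i and may attend to
        # keys 0..past_len + i, which the mask below enforces.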
key_pos = torch.arange(Tk, device=x.device).unsqueeze(0)
query_pos = (
torch.arange(Tq, device=x.device).unsqueeze(1) + past_len
)
causal = key_pos <= query_pos
att = att.masked_fill(~causal.unsqueeze(0).unsqueeze(0), float("-inf"))
att_weights = F.softmax(att, dim=-1)
att_weights = self.attn_drop(att_weights)
out = att_weights @ v
out = out.transpose(1, 2).contiguous().view(B, Tq, C)
out = self.resid_drop(self.proj(out))
present = (k, v) if use_cache else None
if return_attn:
return out, att_weights, present
return out, None, present
class TransformerBlock(nn.Module):
def __init__(self, config: GPTConfig) -> None:
super().__init__()
self.ln1 = nn.LayerNorm(config.d_model)
self.attn = CausalSelfAttention(
d_model=config.d_model,
n_heads=config.n_heads,
block_size=config.block_size,
dropout=config.dropout,
)
self.ln2 = nn.LayerNorm(config.d_model)
self.mlp = nn.Sequential(
nn.Linear(config.d_model, config.d_ff),
nn.GELU(),
nn.Linear(config.d_ff, config.d_model),
nn.Dropout(config.dropout),
)
def forward(
self,
x: torch.Tensor,
return_attn: bool = False,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[Tuple[torch.Tensor, torch.Tensor]],
]:
h, attn_w, present = self.attn(
self.ln1(x),
return_attn=return_attn,
use_cache=use_cache,
past_kv=past_kv,
)
x = x + h
x = x + self.mlp(self.ln2(x))
return x, attn_w, present
class GPT(nn.Module):
"""Decoder-only transformer LM with optional attention outputs."""
def __init__(self, config: GPTConfig) -> None:
super().__init__()
self.config = config
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.wpe = nn.Embedding(config.block_size, config.d_model)
self.drop = nn.Dropout(config.dropout)
self.blocks = nn.ModuleList(
TransformerBlock(config) for _ in range(config.n_layers)
)
self.ln_f = nn.LayerNorm(config.d_model)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.lm_head.weight = self.wte.weight
# --- Compound (2D) embedding axes ------------------------------------
# Static per-token-id feature lookups (computed from tokenizer vocab).
pc_tab, oct_tab = _build_token_pitch_feature_tables()
self.register_buffer("tok_to_pc", pc_tab, persistent=False)
self.register_buffer("tok_to_octave", oct_tab, persistent=False)
self.register_buffer(
"tok_to_midi", _build_pitch_to_midi(), persistent=False
)
self.register_buffer(
"tok_to_pos_value", _build_pos_token_value(), persistent=False
)
self.register_buffer(
"tok_to_key_root", _build_key_token_root(), persistent=False
)
self.register_buffer(
"tok_is_pitch", _is_pitch_token_mask(), persistent=False
)
if config.use_pitch_class_embed:
self.embed_pc = nn.Embedding(PC_NONE + 1, config.d_model)
if config.use_octave_embed:
self.embed_octave = nn.Embedding(OCT_NONE + 1, config.d_model)
if config.use_interval_embed:
self.embed_interval = nn.Embedding(INTERVAL_NONE + 1, config.d_model)
if config.use_beat_cyclic_embed:
# N_POS_BINS bins + sentinel for "no bar context".
self.embed_beat = nn.Embedding(N_POS_BINS + 1, config.d_model)
if config.use_scale_degree_embed:
self.embed_scale_degree = nn.Embedding(SCALE_DEGREE_NONE + 1, config.d_model)
self.apply(self._init_weights)
@staticmethod
def _init_weights(module: nn.Module) -> None:
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
idx: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
return_attn: bool = False,
use_cache: bool = False,
past_key_values: Optional[
List[Tuple[torch.Tensor, torch.Tensor]]
] = None,
) -> Union[
torch.Tensor,
Tuple[torch.Tensor, List[torch.Tensor]],
Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]],
Tuple[
torch.Tensor,
List[torch.Tensor],
List[Tuple[torch.Tensor, torch.Tensor]],
],
]:
"""Compute logits for token inputs.
Provide exactly one of:
- ``idx``: token ids of shape (B, T), or
- ``inputs_embeds``: precomputed embeddings of shape (B, T, d_model)
        If ``return_attn`` is True, also returns a list of per-layer attention
        weights, each shaped (B, n_heads, T, T_k) after softmax, where T_k is
        T plus the length of any supplied KV cache. If ``use_cache`` is True,
        the per-layer (key, value) tensors are returned as well; when decoding
        incrementally with ``past_key_values``, pass ``position_ids`` for the
        new tokens so positional embeddings stay aligned.
        """
if (idx is None) == (inputs_embeds is None):
raise ValueError("Provide exactly one of idx or inputs_embeds.")
if inputs_embeds is not None:
B, T, C = inputs_embeds.shape
if C != self.config.d_model:
raise ValueError(
"inputs_embeds last dim "
f"{C} != d_model {self.config.d_model}"
)
else:
assert idx is not None
B, T = idx.shape
if T > self.config.block_size:
raise ValueError(
f"Sequence length {T} exceeds block_size {self.config.block_size}"
)
if position_ids is None:
if idx is not None:
device = idx.device
else:
assert inputs_embeds is not None
device = inputs_embeds.device
pos = torch.arange(0, T, device=device, dtype=torch.long)
pos_e = self.wpe(pos).unsqueeze(0)
else:
if position_ids.shape[-1] != T:
raise ValueError(
"position_ids length must match sequence length."
)
pos_e = self.wpe(position_ids)
if pos_e.dim() == 2:
pos_e = pos_e.unsqueeze(0)
if pos_e.shape[0] == 1 and B > 1:
pos_e = pos_e.expand(B, -1, -1)
tok = self.wte(idx) if idx is not None else inputs_embeds
x = tok + pos_e
if idx is not None:
x = x + self._compound_embeds(idx)
x = self.drop(x)
attn_list: List[torch.Tensor] = []
present_key_values: List[Tuple[torch.Tensor, torch.Tensor]] = []
for block in self.blocks:
block_idx = len(present_key_values)
past_kv = None
if past_key_values is not None and block_idx < len(past_key_values):
past_kv = past_key_values[block_idx]
x, aw, present = block(
x,
return_attn=return_attn,
use_cache=use_cache,
past_kv=past_kv,
)
if aw is not None:
attn_list.append(aw)
if present is not None:
present_key_values.append(present)
x = self.ln_f(x)
logits = self.lm_head(x)
if return_attn and use_cache:
return logits, attn_list, present_key_values
if return_attn:
return logits, attn_list
if use_cache:
return logits, present_key_values
return logits
def _compound_embeds(self, idx: torch.Tensor) -> torch.Tensor:
"""Sum of all enabled compound (2D) embedding axes for a batch of
token ids. Returns a tensor of shape (B, T, d_model) — zero if all
axes are disabled."""
B, T = idx.shape
out = torch.zeros(B, T, self.config.d_model, device=idx.device, dtype=self.wte.weight.dtype)
if self.config.use_pitch_class_embed:
out = out + self.embed_pc(self.tok_to_pc[idx])
if self.config.use_octave_embed:
out = out + self.embed_octave(self.tok_to_octave[idx])
if self.config.use_interval_embed:
out = out + self.embed_interval(self._compute_interval_ids(idx))
if self.config.use_beat_cyclic_embed:
out = out + self.embed_beat(self._compute_beat_ids(idx))
if self.config.use_scale_degree_embed:
out = out + self.embed_scale_degree(self._compute_scale_degree_ids(idx))
return out
def _compute_scale_degree_ids(self, idx: torch.Tensor) -> torch.Tensor:
"""For each pitch position, the chromatic scale degree
(pitch_class - current_key_root) mod 12 — where current key is the
most recent KEY token seen. Non-pitch positions and positions
before the first KEY token get the sentinel.
"""
B, T = idx.shape
pc = self.tok_to_pc[idx] # (B, T) PC_NONE if not pitch
key_root = self.tok_to_key_root[idx] # (B, T) -1 if not KEY
arange = torch.arange(T, device=idx.device).expand(B, T)
cand = torch.where(key_root >= 0, arange, torch.full_like(arange, -1))
last_key_idx = cand.cummax(dim=1).values
safe_idx = last_key_idx.clamp(min=0)
cur_root = torch.gather(key_root, 1, safe_idx)
# Compute (pc - root) mod 12 for pitch positions with a known key.
is_pitch = pc != PC_NONE
sd = (pc - cur_root) % 12
valid = is_pitch & (last_key_idx >= 0)
return torch.where(
valid, sd, torch.full_like(sd, SCALE_DEGREE_NONE)
)
def _compute_interval_ids(self, idx: torch.Tensor) -> torch.Tensor:
"""For each pitch-token position, the clipped melodic interval to the
previous pitch token in the same row. Non-pitch positions and the
first pitch get the sentinel INTERVAL_NONE. Vectorized via cummax.
"""
B, T = idx.shape
midi = self.tok_to_midi[idx] # (B, T) -1 if not pitch
is_pitch = midi >= 0 # (B, T)
arange = torch.arange(T, device=idx.device).expand(B, T)
cand = torch.where(is_pitch, arange, torch.full_like(arange, -1))
# Shift right by 1: previous-pitch-up-to-t-1
shifted = torch.cat(
[torch.full_like(cand[:, :1], -1), cand[:, :-1]], dim=1
)
last_idx = shifted.cummax(dim=1).values # (B, T)
safe_idx = last_idx.clamp(min=0)
prev_midi = torch.gather(midi, 1, safe_idx)
delta = (midi - prev_midi).clamp(-INTERVAL_RANGE, INTERVAL_RANGE) + INTERVAL_RANGE
valid = is_pitch & (last_idx >= 0)
return torch.where(valid, delta, torch.full_like(delta, INTERVAL_NONE))
def _compute_beat_ids(self, idx: torch.Tensor) -> torch.Tensor:
"""For each position, the most recent POS<n> bin value seen so far,
or N_POS_BINS (sentinel) if no POS token has been emitted yet.
Vectorized via cummax over POS positions."""
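        # Worked example (hypothetical row, assuming POS0/POS8 tokens exist):
        # [POS0, P60, POS8, P64] -> beat ids [0, 0, 8, 8]; any position before
        # the first POS token would get the sentinel N_POS_BINS.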
B, T = idx.shape
pos_val = self.tok_to_pos_value[idx] # (B, T) -1 if not POS
arange = torch.arange(T, device=idx.device).expand(B, T)
cand = torch.where(pos_val >= 0, arange, torch.full_like(arange, -1))
last_idx = cand.cummax(dim=1).values # (B, T)
safe_idx = last_idx.clamp(min=0)
gathered = torch.gather(pos_val, 1, safe_idx)
return torch.where(
last_idx >= 0, gathered, torch.full_like(gathered, N_POS_BINS)
)
@torch.no_grad()
def count_parameters(self) -> int:
return sum(p.numel() for p in self.parameters())
if __name__ == "__main__":
cfg = default_gpt_config()
model = GPT(cfg)
n_params = model.count_parameters()
print(f"Config: {cfg}")
print(f"Parameter count: {n_params:,} (~{n_params / 1e6:.2f}M)")
x = torch.randint(0, cfg.vocab_size, (2, min(64, cfg.block_size)))
logits = model(x)
assert logits.shape == (2, x.shape[1], cfg.vocab_size)
logits2, attn = model(x, return_attn=True)
assert len(attn) == cfg.n_layers
assert attn[0].shape == (2, cfg.n_heads, x.shape[1], x.shape[1])
print("Forward + return_attn smoke test OK.")