# Coda / src /compound_model.py
# Prajanya Gupta
# initial deploy
# 6b7b403
"""Compound (Octuple-style) GPT for bach-gpt.
Each input position carries N_AXES feature ids that are embedded in
parallel and summed before the transformer. Output is N_AXES separate
classification heads with their own softmaxes; the training loss is the
(optionally weighted) sum of per-axis cross-entropies.
Usage
-----
from compound import AXIS_SIZES, AXIS_NAMES, STEP_PAD
from compound_model import CompoundGPT, CompoundGPTConfig, compound_loss
cfg = CompoundGPTConfig() # axis_sizes default to compound.AXIS_SIZES
model = CompoundGPT(cfg)
# idx: (B, T, N_AXES) long; targets: same shape, shifted by one step
logits = model(idx) # list of (B, T, axis_size_a) tensors
loss = compound_loss(logits, targets)
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from compound import AXIS_NAMES, AXIS_SIZES, N_AXES, STEP_PAD
from model import GPT, GPTConfig, TransformerBlock
@dataclass
class CompoundGPTConfig:
    """Hyperparameters consumed by ``CompoundGPT``.

    Every field has a default, so ``CompoundGPTConfig()`` is a complete
    configuration. The dataclass is not frozen: fields may be reassigned
    after construction (the smoke test at the bottom of this file does so).
    """

    # Vocabulary size of each feature axis. CompoundGPT builds one input
    # embedding table and one output head per entry. The default is a tuple
    # copy of ``compound.AXIS_SIZES`` taken when the config is instantiated.
    axis_sizes: Tuple[int, ...] = field(default_factory=lambda: tuple(AXIS_SIZES))
    # Number of rows in the learned position-embedding table;
    # CompoundGPT.forward rejects sequences longer than this.
    block_size: int = 1024
    # Width of the input embeddings, the final LayerNorm and the inputs to
    # the per-axis heads.
    d_model: int = 512
    # Number of TransformerBlock layers stacked by CompoundGPT.
    n_layers: int = 6
    # Forwarded unchanged to the GPTConfig handed to each TransformerBlock;
    # presumably the attention head count — confirm against model.py.
    n_heads: int = 8
    # Forwarded unchanged to the GPTConfig handed to each TransformerBlock;
    # presumably the feed-forward hidden width — confirm against model.py.
    d_ff: int = 2048
    # Dropout probability applied to the summed token + position embeddings
    # and also forwarded to the transformer blocks.
    dropout: float = 0.1
    # Optional per-axis loss weighting at training time. None = uniform.
    # Not read by CompoundGPT in this file; presumably the training loop
    # passes it to compound_loss(axis_weights=...) — TODO confirm.
    axis_loss_weights: Optional[Tuple[float, ...]] = None
def default_compound_config() -> CompoundGPTConfig:
    """Build a ``CompoundGPTConfig`` in which every field keeps its default."""
    config = CompoundGPTConfig()
    return config
class CompoundGPT(nn.Module):
    """Decoder-only transformer over compound (multi-axis) inputs.

    Every position carries one id per feature axis. The per-axis embeddings
    are summed, a learned position embedding is added, and the result runs
    through a stack of ``TransformerBlock`` layers and a final LayerNorm.
    One bias-free linear head per axis then produces that axis' logits.
    """

    def __init__(self, config: CompoundGPTConfig) -> None:
        """Create the embedding tables, transformer stack and per-axis heads.

        Raises:
            ValueError: if ``config.axis_sizes`` does not contain exactly
                ``N_AXES`` entries.
        """
        super().__init__()
        self.config = config
        self.n_axes = len(config.axis_sizes)
        if self.n_axes != N_AXES:
            raise ValueError(
                f"axis_sizes has {self.n_axes} entries, expected {N_AXES}"
            )

        width = config.d_model

        # One embedding table per axis; forward() adds their outputs together.
        self.input_embeds = nn.ModuleList(
            [nn.Embedding(vocab, width) for vocab in config.axis_sizes]
        )
        self.wpe = nn.Embedding(config.block_size, width)
        self.drop = nn.Dropout(config.dropout)

        # TransformerBlock is configured through a GPTConfig. Per the original
        # author's note it only needs the attention/MLP shape parameters, so
        # vocab_size is a placeholder value here.
        shape_cfg = GPTConfig(
            vocab_size=1,
            block_size=config.block_size,
            d_model=width,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            d_ff=config.d_ff,
            dropout=config.dropout,
        )
        self.blocks = nn.ModuleList(
            [TransformerBlock(shape_cfg) for _ in range(config.n_layers)]
        )
        self.ln_f = nn.LayerNorm(width)

        # Independent classifier per axis; weights are deliberately not tied
        # to the input embeddings.
        self.heads = nn.ModuleList(
            [nn.Linear(width, vocab, bias=False) for vocab in config.axis_sizes]
        )

        self.apply(GPT._init_weights)

    def _sum_axis_embeddings(self, idx: torch.Tensor) -> torch.Tensor:
        """Embed each axis of ``idx`` with its own table and add the results."""
        axis_ids = idx.unbind(dim=-1)
        summed = self.input_embeds[0](axis_ids[0])
        for axis in range(1, self.n_axes):
            summed = summed + self.input_embeds[axis](axis_ids[axis])
        return summed

    @staticmethod
    def _resolve_positions(
        position_ids: Optional[torch.Tensor],
        batch: int,
        seq_len: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Return long position indices of shape (1, T) or (B, T) on ``device``.

        Without ``position_ids`` the positions are ``0 .. T-1``. Otherwise a
        1-D tensor is treated as a single row and a 2-D tensor is used as
        given, after checking its length and batch dimension.
        """
        if position_ids is None:
            return torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)

        rank = position_ids.dim()
        if rank == 1:
            pos = position_ids.unsqueeze(0)
        elif rank == 2:
            pos = position_ids
        else:
            raise ValueError(
                f"position_ids must be shape (T,) or (B, T); got {tuple(position_ids.shape)}"
            )

        if pos.size(1) != seq_len:
            raise ValueError(
                f"position_ids length {pos.size(1)} must equal sequence length {seq_len}"
            )
        if pos.size(0) not in (1, batch):
            raise ValueError(
                f"position_ids batch dim {pos.size(0)} must be 1 or {batch}"
            )
        return pos.to(device=device, dtype=torch.long)

    def forward(
        self,
        idx: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        return_attn: bool = False,
        return_hidden: bool = False,
        use_cache: bool = False,
        past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
    ) -> List[torch.Tensor] | Tuple[List[torch.Tensor], List[torch.Tensor]] | torch.Tensor:
        """Run the model on feature ids or on precomputed embeddings.

        Exactly one of the two inputs must be given:

        - ``idx``: (B, T, N_AXES) feature ids, one column per axis.
        - ``inputs_embeds``: (B, T, d_model) embeddings that replace the
          summed per-axis embeddings (the position embedding is still added).

        ``position_ids`` may be omitted (positions ``0 .. T-1``), a (T,)
        tensor, or a (1, T) / (B, T) tensor.

        Returns:
            - ``return_hidden=True``: the post-LayerNorm hidden states
              (B, T, d_model); no heads are evaluated. This takes precedence
              over ``return_attn``.
            - ``return_attn=True``: ``(logits_per_axis, attn_list)`` where
              ``attn_list`` holds the non-None attention outputs returned by
              the blocks.
            - otherwise: a list with one logits tensor per axis.

        Raises:
            NotImplementedError: if ``use_cache`` is truthy or
                ``past_key_values`` is given.
            ValueError: on a wrong number of inputs, a bad input shape, a
                sequence longer than ``block_size``, or malformed
                ``position_ids``.
        """
        if use_cache or past_key_values is not None:
            # These arguments exist only so GPT-style call sites keep working.
            raise NotImplementedError(
                "CompoundGPT does not currently support KV caching."
            )

        if (idx is None) == (inputs_embeds is None):
            raise ValueError("Provide exactly one of idx or inputs_embeds.")

        # Validate whichever input was supplied and read (B, T) from it.
        if idx is not None:
            if idx.dim() != 3 or idx.shape[-1] != self.n_axes:
                raise ValueError(
                    f"Expected idx of shape (B, T, {self.n_axes}); got {tuple(idx.shape)}"
                )
            source = idx
        else:
            assert inputs_embeds is not None
            if inputs_embeds.dim() != 3 or inputs_embeds.size(-1) != self.config.d_model:
                raise ValueError(
                    "Expected inputs_embeds shape "
                    f"(B, T, {self.config.d_model}); got {tuple(inputs_embeds.shape)}"
                )
            source = inputs_embeds
        batch, seq_len = source.shape[0], source.shape[1]

        if seq_len > self.config.block_size:
            raise ValueError(
                f"Sequence length {seq_len} exceeds block_size {self.config.block_size}"
            )

        if idx is not None:
            x = self._sum_axis_embeddings(idx)
        else:
            x = inputs_embeds
        device = source.device

        pos = self._resolve_positions(position_ids, batch, seq_len, device)
        x = self.drop(x + self.wpe(pos))

        attn_list: List[torch.Tensor] = []
        for layer in self.blocks:
            x, attn_weights, _ = layer(x, return_attn=return_attn)
            if attn_weights is not None:
                attn_list.append(attn_weights)

        x = self.ln_f(x)
        if return_hidden:
            return x

        logits_per_axis = [head(x) for head in self.heads]
        if return_attn:
            return logits_per_axis, attn_list
        return logits_per_axis

    @torch.no_grad()
    def count_parameters(self) -> int:
        """Return the total number of scalar parameters in the model."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total
def compound_loss(
    logits_per_axis: List[torch.Tensor],
    targets: torch.Tensor,
    axis_weights: Optional[Tuple[float, ...]] = None,
    pad_step_value: int = STEP_PAD,
    ignore_pad_steps: bool = True,
) -> torch.Tensor:
    """Return the weighted sum of per-axis mean cross-entropies.

    Args:
        logits_per_axis: one logits tensor per axis. Each is flattened to
            ``(-1, logits.size(-1))``, so the ``(B, T, axis_size)`` tensors
            returned by ``CompoundGPT.forward`` fit directly.
        targets: class ids of shape ``(B, T, N_AXES)``; axis ``a`` is scored
            against ``targets[..., a]``.
        axis_weights: optional multiplier per axis; ``None`` weights every
            axis by 1.0.
        pad_step_value: value on target axis 0 that marks a padded position.
        ignore_pad_steps: when True, positions whose axis-0 target equals
            ``pad_step_value`` are excluded from every axis' mean.

    Returns:
        A 0-dim tensor. If every position is padding the value is 0, and the
        result stays connected to the logits' autograd graph so calling
        ``.backward()`` on it yields zero gradients instead of raising.

    Raises:
        ValueError: if ``targets`` is not 3-D or ``axis_weights`` does not
            have one entry per logits tensor.
    """
    if targets.dim() != 3:
        raise ValueError(
            f"targets must be (B, T, N_AXES); got {tuple(targets.shape)}"
        )
    n_axes = len(logits_per_axis)
    if axis_weights is None:
        axis_weights = (1.0,) * n_axes
    if len(axis_weights) != n_axes:
        raise ValueError(
            f"axis_weights length {len(axis_weights)} != {n_axes}"
        )

    total = torch.zeros((), device=targets.device, dtype=torch.float32)

    # The padding mask is the same for every axis, so resolve it to row
    # indices once instead of re-reducing and re-masking inside the loop
    # (each of those is a host sync on CUDA).
    keep: Optional[torch.Tensor] = None
    if ignore_pad_steps:
        flat_valid = (targets[..., 0] != pad_step_value).reshape(-1)
        keep = torch.nonzero(flat_valid, as_tuple=True)[0]
        if keep.numel() == 0:
            # Nothing to score. Summing an empty slice of each logits tensor
            # contributes exactly 0 while keeping the result attached to the
            # graph, so an all-padding batch does not break backward().
            for logits in logits_per_axis:
                total = total + logits[:0].sum()
            return total

    for axis, (logits, weight) in enumerate(zip(logits_per_axis, axis_weights)):
        flat_logits = logits.reshape(-1, logits.size(-1))  # (B*T, A_axis)
        flat_tgt = targets[..., axis].reshape(-1)  # (B*T,)
        if keep is not None:
            flat_logits = flat_logits[keep]
            flat_tgt = flat_tgt[keep]
        loss_a = F.cross_entropy(flat_logits, flat_tgt, reduction="mean")
        total = total + weight * loss_a
    return total
# --- Smoke test --------------------------------------------------------------
if __name__ == "__main__":
    # Shrink the default configuration so the check runs quickly.
    cfg = default_compound_config()
    cfg.block_size, cfg.n_layers = 64, 2
    cfg.d_model, cfg.d_ff = 128, 512
    model = CompoundGPT(cfg)

    n_params = model.count_parameters()
    print(f"Compound axes ({N_AXES}): {AXIS_NAMES}")
    print(f"Axis sizes: {AXIS_SIZES}")
    print(
        f"Parameters: {n_params:,} "
        f"(~{n_params/1e6:.2f}M)"
    )

    # Random in-range ids for every axis, stacked into (B, T, N_AXES).
    B, T = 2, 32
    per_axis_ids = [torch.randint(0, AXIS_SIZES[a], (B, T)) for a in range(N_AXES)]
    idx = torch.stack(per_axis_ids, dim=-1).long()

    logits = model(idx)
    assert len(logits) == N_AXES
    for axis, axis_logits in enumerate(logits):
        expected_shape = (B, T, AXIS_SIZES[axis])
        assert axis_logits.shape == expected_shape, (
            axis,
            axis_logits.shape,
            AXIS_SIZES[axis],
        )

    # The inputs double as targets here; this only checks that the loss runs.
    loss = compound_loss(logits, idx)
    print(f"Forward + per-axis loss OK. loss={loss.item():.4f}")