Spaces:
Sleeping
Sleeping
| """Compound (Octuple-style) GPT for bach-gpt. | |
| Each input position carries N_AXES feature ids that are embedded in | |
| parallel and summed before the transformer. Output is N_AXES separate | |
| classification heads with their own softmaxes; the training loss is the | |
| (optionally weighted) sum of per-axis cross-entropies. | |
| Usage | |
| ----- | |
| from compound import AXIS_SIZES, AXIS_NAMES, STEP_PAD | |
| from compound_model import CompoundGPT, CompoundGPTConfig, compound_loss | |
| cfg = CompoundGPTConfig() # axis_sizes default to compound.AXIS_SIZES | |
| model = CompoundGPT(cfg) | |
| # idx: (B, T, N_AXES) long; targets: same shape, shifted by one step | |
| logits = model(idx) # list of (B, T, axis_size_a) tensors | |
| loss = compound_loss(logits, targets) | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from compound import AXIS_NAMES, AXIS_SIZES, N_AXES, STEP_PAD | |
| from model import GPT, GPTConfig, TransformerBlock | |
@dataclass
class CompoundGPTConfig:
    """Hyperparameters for ``CompoundGPT``.

    This must be a dataclass: ``axis_sizes`` uses
    ``field(default_factory=...)``, which only produces a real default when
    processed by ``@dataclass``. Without the decorator the attribute is a bare
    ``dataclasses.Field`` object and keyword construction is unavailable.
    """

    # Vocabulary size of each input axis; CompoundGPT builds one input
    # embedding table and one output head per entry. Defaults to a copy of
    # ``compound.AXIS_SIZES`` (evaluated lazily at instantiation time).
    axis_sizes: Tuple[int, ...] = field(default_factory=lambda: tuple(AXIS_SIZES))
    # Maximum sequence length; also the number of rows in the learned
    # position-embedding table.
    block_size: int = 1024
    # Hidden width used by the embeddings, the blocks and the output heads.
    d_model: int = 512
    # Number of TransformerBlock layers.
    n_layers: int = 6
    # Heads per block (forwarded to TransformerBlock via GPTConfig).
    n_heads: int = 8
    # Inner MLP width per block (forwarded to TransformerBlock via GPTConfig).
    d_ff: int = 2048
    # Dropout probability for the embedding dropout and the blocks.
    dropout: float = 0.1
    # Optional per-axis loss weighting at training time. None = uniform.
    axis_loss_weights: Optional[Tuple[float, ...]] = None
def default_compound_config() -> CompoundGPTConfig:
    """Build a ``CompoundGPTConfig`` populated entirely with its defaults."""
    config = CompoundGPTConfig()
    return config
class CompoundGPT(nn.Module):
    """Decoder-only transformer over compound (multi-axis) inputs.

    Every input position carries one feature id per axis. Each axis has its
    own embedding table; the per-axis embeddings are summed, a learned
    absolute position embedding is added, and the result runs through a stack
    of ``TransformerBlock`` layers before one bias-free linear head per axis
    produces that axis's logits.
    """

    def __init__(self, config: CompoundGPTConfig) -> None:
        super().__init__()
        self.config = config
        self.n_axes = len(config.axis_sizes)
        if self.n_axes != N_AXES:
            raise ValueError(
                f"axis_sizes has {self.n_axes} entries, expected {N_AXES}"
            )

        width = config.d_model
        # Attribute names below are state_dict keys; keep them (and their
        # construction order) stable for checkpoint compatibility.

        # One embedding table per axis; forward() sums their outputs.
        self.input_embeds = nn.ModuleList(
            [nn.Embedding(vocab, width) for vocab in config.axis_sizes]
        )
        self.wpe = nn.Embedding(config.block_size, width)
        self.drop = nn.Dropout(config.dropout)

        # TransformerBlock is configured through a GPTConfig. Only the
        # attention/MLP shape and dropout fields matter here, so vocab_size
        # is a placeholder.
        shared_cfg = GPTConfig(
            vocab_size=1,
            block_size=config.block_size,
            d_model=width,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            d_ff=config.d_ff,
            dropout=config.dropout,
        )
        self.blocks = nn.ModuleList(
            [TransformerBlock(shared_cfg) for _layer in range(config.n_layers)]
        )
        self.ln_f = nn.LayerNorm(width)

        # Untied output projections: each axis classifies over its own
        # vocabulary, so there is no input/output weight sharing.
        self.heads = nn.ModuleList(
            [nn.Linear(width, vocab, bias=False) for vocab in config.axis_sizes]
        )
        self.apply(GPT._init_weights)

    def forward(
        self,
        idx: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        return_attn: bool = False,
        return_hidden: bool = False,
        use_cache: bool = False,
        past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
    ) -> List[torch.Tensor] | Tuple[List[torch.Tensor], List[torch.Tensor]] | torch.Tensor:
        """Run the compound transformer.

        Exactly one input must be given:
          * ``idx`` — (B, T, n_axes) long feature ids, one column per axis.
          * ``inputs_embeds`` — (B, T, d_model) precomputed embeddings, used
            in place of the summed per-axis embeddings.

        ``position_ids`` may be (T,) or (B, T); it defaults to ``0..T-1``.

        Returns, checked in this order:
          * ``return_hidden=True`` — final-LayerNorm hidden states
            (B, T, d_model); the heads are skipped and ``return_attn`` is
            ignored. Used by the contrastive MIDI encoder for pooling.
          * ``return_attn=True`` — ``(logits_per_axis, attn_list)``.
          * otherwise — ``logits_per_axis``: one (B, T, axis_size) tensor
            per axis.

        Raises:
            NotImplementedError: if ``use_cache`` is set or
                ``past_key_values`` is given (no KV-cache support).
            ValueError: for missing/duplicate inputs, bad shapes, sequences
                longer than ``block_size``, or malformed ``position_ids``.
        """
        # The cache arguments exist only for API compatibility with GPT-style
        # call sites; reject them loudly rather than silently ignoring them.
        if use_cache or past_key_values is not None:
            raise NotImplementedError(
                "CompoundGPT does not currently support KV caching."
            )
        from_ids = idx is not None
        if from_ids == (inputs_embeds is not None):
            raise ValueError("Provide exactly one of idx or inputs_embeds.")

        if from_ids:
            if idx.dim() != 3 or idx.shape[-1] != self.n_axes:
                raise ValueError(
                    f"Expected idx of shape (B, T, {self.n_axes}); got {tuple(idx.shape)}"
                )
            source = idx
        else:
            if inputs_embeds.dim() != 3 or inputs_embeds.size(-1) != self.config.d_model:
                raise ValueError(
                    "Expected inputs_embeds shape "
                    f"(B, T, {self.config.d_model}); got {tuple(inputs_embeds.shape)}"
                )
            source = inputs_embeds
        batch, seq_len = source.shape[:2]
        device = source.device

        if seq_len > self.config.block_size:
            raise ValueError(
                f"Sequence length {seq_len} exceeds block_size {self.config.block_size}"
            )

        if from_ids:
            # Sum the per-axis embeddings, accumulating left to right.
            tables = list(self.input_embeds)
            per_axis = idx.unbind(dim=-1)
            x = tables[0](per_axis[0])
            for table, ids in zip(tables[1:], per_axis[1:]):
                x = x + table(ids)
        else:
            x = inputs_embeds

        if position_ids is None:
            pos = torch.arange(seq_len, device=device, dtype=torch.long)[None, :]
        else:
            rank = position_ids.dim()
            if rank not in (1, 2):
                raise ValueError(
                    f"position_ids must be shape (T,) or (B, T); got {tuple(position_ids.shape)}"
                )
            pos = position_ids[None, :] if rank == 1 else position_ids
            if pos.size(1) != seq_len:
                raise ValueError(
                    f"position_ids length {pos.size(1)} must equal sequence length {seq_len}"
                )
            if pos.size(0) not in (1, batch):
                raise ValueError(
                    f"position_ids batch dim {pos.size(0)} must be 1 or {batch}"
                )
            pos = pos.to(device=device, dtype=torch.long)

        # (1, T) or (B, T) position embeddings broadcast over the batch.
        x = self.drop(x + self.wpe(pos))

        attn_list: List[torch.Tensor] = []
        for layer in self.blocks:
            # The block's third output is not used here.
            x, weights, _ = layer(x, return_attn=return_attn)
            if weights is not None:
                attn_list.append(weights)
        hidden = self.ln_f(x)

        if return_hidden:
            return hidden
        logits_per_axis = [head(hidden) for head in self.heads]
        return (logits_per_axis, attn_list) if return_attn else logits_per_axis

    def count_parameters(self) -> int:
        """Return the total number of scalar elements across all parameters."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total
def compound_loss(
    logits_per_axis: List[torch.Tensor],
    targets: torch.Tensor,
    axis_weights: Optional[Tuple[float, ...]] = None,
    pad_step_value: int = STEP_PAD,
    ignore_pad_steps: bool = True,
) -> torch.Tensor:
    """Sum of (optionally weighted) per-axis cross-entropies.

    Args:
        logits_per_axis: one logits tensor per axis, each (B, T, axis_size),
            e.g. as returned by ``CompoundGPT.forward``.
        targets: (B, T, N_AXES) target ids; the last dimension must match
            ``len(logits_per_axis)``. Axis 0 is the step axis used for the
            padding mask.
        axis_weights: per-axis multipliers; ``None`` means all 1.0.
        pad_step_value: step-axis id that marks a padding position.
        ignore_pad_steps: if True, positions whose step-axis target equals
            ``pad_step_value`` are excluded from every axis's loss (standard
            padding mask).

    Returns:
        A scalar float32 tensor. Each axis's cross-entropy is averaged over
        the unmasked positions before weighting. If every position is
        padding, the result is zero but remains connected to the logits'
        autograd graph, so ``loss.backward()`` still works.

    Raises:
        ValueError: if ``targets`` is not 3-D, its last dimension does not
            match the number of logits tensors, or ``axis_weights`` has the
            wrong length.
    """
    if targets.dim() != 3:
        raise ValueError(
            f"targets must be (B, T, N_AXES); got {tuple(targets.shape)}"
        )
    n_axes = len(logits_per_axis)
    if targets.size(-1) != n_axes:
        raise ValueError(
            f"targets has {targets.size(-1)} axes but {n_axes} logits tensors were given"
        )
    if axis_weights is None:
        axis_weights = (1.0,) * n_axes
    if len(axis_weights) != n_axes:
        raise ValueError(
            f"axis_weights length {len(axis_weights)} != {n_axes}"
        )

    zero = torch.zeros((), device=targets.device, dtype=torch.float32)

    # The mask depends only on the step axis, so build it (and test it for
    # emptiness, which forces a host sync) once rather than once per axis.
    flat_mask: Optional[torch.Tensor] = None
    if ignore_pad_steps:
        flat_mask = (targets[..., 0] != pad_step_value).reshape(-1)
        if not bool(flat_mask.any()):
            # Fully padded batch: a plain torch.zeros(()) has no grad_fn and
            # would make the caller's backward() raise. Multiplying each
            # logits sum by zero keeps the graph (and every head) attached.
            return sum((logits.sum() * 0.0 for logits in logits_per_axis), zero)

    total = zero
    for weight, logits, tgt in zip(
        axis_weights, logits_per_axis, targets.unbind(dim=-1)
    ):
        flat_logits = logits.reshape(-1, logits.size(-1))  # (B*T, A_a)
        flat_tgt = tgt.reshape(-1)  # (B*T,)
        if flat_mask is not None:
            flat_logits = flat_logits[flat_mask]
            flat_tgt = flat_tgt[flat_mask]
        loss_a = F.cross_entropy(flat_logits, flat_tgt, reduction="mean")
        total = total + weight * loss_a
    return total
# --- Smoke test --------------------------------------------------------------
if __name__ == "__main__":
    # Build a deliberately tiny model so the check finishes quickly.
    cfg = default_compound_config()
    cfg.block_size, cfg.n_layers, cfg.d_model, cfg.d_ff = 64, 2, 128, 512
    model = CompoundGPT(cfg)

    n_params = model.count_parameters()
    print(f"Compound axes ({N_AXES}): {AXIS_NAMES}")
    print(f"Axis sizes: {AXIS_SIZES}")
    print(f"Parameters: {n_params:,} (~{n_params / 1e6:.2f}M)")

    # Random in-range ids per axis, stacked into a (B, T, N_AXES) batch.
    batch, seq_len = 2, 32
    columns = [
        torch.randint(0, AXIS_SIZES[axis], (batch, seq_len))
        for axis in range(N_AXES)
    ]
    idx = torch.stack(columns, dim=-1).long()

    logits = model(idx)
    assert len(logits) == N_AXES
    for axis, axis_logits in enumerate(logits):
        assert axis_logits.shape == (batch, seq_len, AXIS_SIZES[axis]), (
            axis, axis_logits.shape, AXIS_SIZES[axis],
        )

    # The inputs double as targets: this only checks the loss runs end to end.
    loss = compound_loss(logits, idx)
    print(f"Forward + per-axis loss OK. loss={loss.item():.4f}")