Coda / src /contrastive_model.py
Prajanya Gupta
initial deploy
6b7b403
"""Contrastive MIDI-text model architecture.
Components:
1) Frozen MIDI GPT encoder + masked mean pooling
2) Frozen sentence-transformers MiniLM text encoder
3) Two trainable projection heads to a shared embedding space
4) Learnable log-temperature scalar
"""
from __future__ import annotations
import math
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from model import GPT
# CompoundGPT lives in compound_model. The compound variant below refers to it
# only through a string annotation, so it is never imported in this module;
# that avoids a circular dependency when only GPT is used.
def symmetric_infonce_loss(
    midi_embeds: torch.Tensor,
    text_embeds: torch.Tensor,
    temperature: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Symmetric InfoNCE over a batch of aligned MIDI/text embedding pairs.

    Row ``i`` of ``midi_embeds`` is treated as the positive match of row ``i``
    of ``text_embeds``; every other row in the batch acts as a negative.

    Args:
        midi_embeds: MIDI embeddings of shape (N, D). Callers are expected to
            pass L2-normalized rows; this function does not normalize.
        text_embeds: Text embeddings of shape (N, D), same expectation.
        temperature: Scalar tensor the similarity matrix is divided by.

    Returns:
        A dictionary with these keys:
            - ``logits_midi_to_text``: (N, N) scaled similarities.
            - ``logits_text_to_midi``: transpose of the above.
            - ``loss_midi_to_text``: cross-entropy over rows (scalar).
            - ``loss_text_to_midi``: cross-entropy over columns (scalar).
            - ``loss``: mean of the two directional losses.
            - ``acc_midi_to_text``: top-1 accuracy over rows.
            - ``acc_text_to_midi``: top-1 accuracy over columns.

    Raises:
        ValueError: If either input is not rank-2, or the shapes differ.
    """
    both_rank_two = midi_embeds.ndim == 2 and text_embeds.ndim == 2
    if not both_rank_two:
        raise ValueError("midi_embeds and text_embeds must both be rank-2 tensors.")
    if midi_embeds.shape != text_embeds.shape:
        raise ValueError("midi_embeds and text_embeds must have identical shape.")

    # (N, N) similarity matrix; entry (i, j) compares MIDI i with text j.
    sim_m2t = (midi_embeds @ text_embeds.t()) / temperature
    sim_t2m = sim_m2t.t()

    # The matching pair for every row/column sits on the diagonal.
    targets = torch.arange(sim_m2t.size(0), device=sim_m2t.device)

    ce_m2t = F.cross_entropy(sim_m2t, targets)
    ce_t2m = F.cross_entropy(sim_t2m, targets)
    total = 0.5 * (ce_m2t + ce_t2m)

    # Retrieval accuracy is a metric only; keep it out of the autograd graph.
    with torch.no_grad():
        hit_rate_m2t = (sim_m2t.argmax(dim=1) == targets).float().mean()
        hit_rate_t2m = (sim_t2m.argmax(dim=1) == targets).float().mean()

    result: Dict[str, torch.Tensor] = {}
    result["logits_midi_to_text"] = sim_m2t
    result["logits_text_to_midi"] = sim_t2m
    result["loss_midi_to_text"] = ce_m2t
    result["loss_text_to_midi"] = ce_t2m
    result["loss"] = total
    result["acc_midi_to_text"] = hit_rate_m2t
    result["acc_text_to_midi"] = hit_rate_t2m
    return result
class ProjectionHead(nn.Module):
    """Two-layer projection MLP: Linear -> GELU -> LayerNorm -> Linear.

    Args:
        input_dim: Width of the incoming feature vectors.
        hidden_dim: Width of the intermediate layer.
        out_dim: Width of the projected output.
    """

    def __init__(
        self, input_dim: int, hidden_dim: int = 512, out_dim: int = 256
    ):
        super().__init__()
        # Layers are created in forward order and stored under ``net`` so the
        # state-dict keys stay ``net.0``, ``net.2`` and ``net.3``.
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dimension ``input_dim``) to ``out_dim``."""
        projected = self.net(x)
        return projected
class MidiTextContrastiveModel(nn.Module):
    """MIDI-text contrastive architecture with frozen base encoders.

    A MIDI GPT encoder (masked mean-pooled) and a sentence-transformers text
    encoder each feed a trainable ``ProjectionHead``; the projected,
    L2-normalized embeddings are scored with ``symmetric_infonce_loss`` using
    a learnable log-temperature.

    Args:
        midi_gpt: MIDI encoder. Must expose ``config.d_model``; the hook
            fallback in ``_extract_midi_last_hidden`` additionally needs
            ``blocks`` and ``ln_f``.
        text_model_name: Name passed to ``SentenceTransformer``.
        embed_dim: Width of the shared embedding space.
        init_temperature: Initial temperature; stored as its logarithm.
        min_temperature: Lower clamp applied by ``get_temperature``.
        max_temperature: Upper clamp applied by ``get_temperature``.
        device: If given, the whole module is moved there before freezing.
    """

    # Width previously hard-coded for the default MiniLM text model; used only
    # when the loaded text encoder does not report its own embedding size.
    _FALLBACK_TEXT_DIM = 384

    def __init__(
        self,
        midi_gpt: "GPT",
        text_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        embed_dim: int = 256,
        init_temperature: float = 0.07,
        min_temperature: float = 0.01,
        max_temperature: float = 1.0,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.midi_encoder = midi_gpt
        self.text_encoder = SentenceTransformer(text_model_name)
        self.embed_dim = embed_dim
        self.min_temperature = min_temperature
        self.max_temperature = max_temperature
        # Scratch slot filled by the forward hook in the fallback path.
        self._last_hidden: Optional[torch.Tensor] = None
        self.midi_projection = ProjectionHead(
            input_dim=midi_gpt.config.d_model,
            hidden_dim=512,
            out_dim=embed_dim,
        )
        # Size the text head from the encoder itself so that a non-default
        # ``text_model_name`` with a different width still works.
        self.text_projection = ProjectionHead(
            input_dim=self._text_embedding_dim(),
            hidden_dim=512,
            out_dim=embed_dim,
        )
        # Learnable log-temperature (CLIP-style init at log(0.07)).
        self.log_temperature = nn.Parameter(
            torch.tensor(math.log(init_temperature))
        )
        if device is not None:
            self.to(device)
        self.freeze_midi_encoder()
        self.freeze_text_encoder()

    def _text_embedding_dim(self) -> int:
        """Return the text encoder's embedding width, or the 384 fallback."""
        getter = getattr(
            self.text_encoder, "get_sentence_embedding_dimension", None
        )
        reported = getter() if callable(getter) else None
        return int(reported) if reported else self._FALLBACK_TEXT_DIM

    def train(self, mode: bool = True) -> "MidiTextContrastiveModel":
        """Set training mode while keeping fully frozen encoders in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would undo
        the ``eval()`` applied by the ``freeze_*`` methods and re-enable any
        dropout inside an encoder whose weights are frozen. An encoder with
        no trainable parameter is therefore switched back to eval here.
        """
        super().train(mode)
        for encoder in (self.midi_encoder, self.text_encoder):
            if not any(p.requires_grad for p in encoder.parameters()):
                encoder.eval()
        return self

    def freeze_midi_encoder(self) -> None:
        """Disable gradients for the MIDI encoder and put it in eval mode."""
        for p in self.midi_encoder.parameters():
            p.requires_grad = False
        self.midi_encoder.eval()

    def freeze_text_encoder(self) -> None:
        """Disable gradients for the text encoder and put it in eval mode."""
        for p in self.text_encoder.parameters():
            p.requires_grad = False
        self.text_encoder.eval()

    def unfreeze_text_encoder(self) -> None:
        """Enable gradients for the text encoder and put it in train mode."""
        for p in self.text_encoder.parameters():
            p.requires_grad = True
        self.text_encoder.train()

    def _capture_last_hidden_hook(self, _module, _inputs, output) -> None:
        """Forward hook: stash the hooked block's output (first tuple item)."""
        x = output[0] if isinstance(output, tuple) else output
        self._last_hidden = x

    def _extract_midi_last_hidden(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the MIDI encoder's final hidden states for ``input_ids``.

        Raises:
            RuntimeError: If the hook fallback captured nothing.
        """
        # Preferred route: try native hidden-state support if available.
        try:
            out = self.midi_encoder(input_ids, output_hidden_states=True)
            if isinstance(out, dict) and "hidden_states" in out:
                hs = out["hidden_states"]
                if isinstance(hs, (list, tuple)) and hs:
                    return hs[-1]
            if hasattr(out, "hidden_states") and out.hidden_states:
                return out.hidden_states[-1]
        except TypeError:
            # Expected for this repo's custom GPT; use hook fallback below.
            pass
        # Hook fallback uses instance state and is not concurrency-safe across
        # overlapping forward passes on the same model instance.
        self._last_hidden = None
        hook = self.midi_encoder.blocks[-1].register_forward_hook(
            self._capture_last_hidden_hook
        )
        try:
            _ = self.midi_encoder(input_ids)
        finally:
            hook.remove()
        captured = self._last_hidden
        # Drop the instance reference so the activation tensor is not kept
        # alive on the model between calls.
        self._last_hidden = None
        if captured is None:
            raise RuntimeError("Failed to capture MIDI last hidden states.")
        # The hook sees the last block's raw output; apply the final norm.
        return self.midi_encoder.ln_f(captured)

    @staticmethod
    def _masked_mean_pool(
        hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Mean over the time axis counting only positions where mask is set.

        Args:
            hidden_states: Tensor of shape (B, T, D).
            attention_mask: Tensor of shape (B, T); used as multiplicative
                weights after casting to ``hidden_states.dtype``.

        Returns:
            Tensor of shape (B, D). Rows whose mask sums to less than one are
            divided by one instead, so an all-zero mask yields zeros.
        """
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp_min(1.0)
        return summed / denom

    def encode_midi(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return pooled MIDI features of shape (B, d_model), without grad."""
        with torch.no_grad():
            hidden = self._extract_midi_last_hidden(input_ids)
            pooled = self._masked_mean_pool(hidden, attention_mask)
        return pooled

    def encode_text(
        self, captions: List[str], device: torch.device
    ) -> torch.Tensor:
        """Embed ``captions`` with the text encoder.

        When any text-encoder parameter requires grad, the captions are
        tokenized and run through the module directly so gradients flow;
        otherwise ``encode`` is used under ``no_grad``.
        """
        text_trainable = any(
            p.requires_grad for p in self.text_encoder.parameters()
        )
        if text_trainable:
            features = self.text_encoder.tokenize(captions)
            features = {
                k: v.to(device) if hasattr(v, "to") else v
                for k, v in features.items()
            }
            out = self.text_encoder(features)
            emb = out.get("sentence_embedding")
            if emb is None:
                emb = out.get("sentence_embeddings")
            if emb is None:
                # Last resort: take whatever the first output entry is.
                emb = next(iter(out.values()))
            return emb
        with torch.no_grad():
            emb = self.text_encoder.encode(
                captions,
                convert_to_tensor=True,
                device=str(device),
                normalize_embeddings=False,
            )
        # SentenceTransformer.encode may use inference_mode internally; clone so
        # trainable projection heads can participate in autograd.
        return emb.clone()

    def get_temperature(self) -> torch.Tensor:
        """Return ``exp(log_temperature)`` clamped to the configured range."""
        return torch.exp(self.log_temperature).clamp(
            self.min_temperature, self.max_temperature
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        captions: List[str],
    ) -> Dict[str, torch.Tensor]:
        """Encode a paired batch and compute the symmetric InfoNCE loss.

        Returns:
            The entries of ``symmetric_infonce_loss`` plus ``midi_embeds``,
            ``text_embeds`` (both L2-normalized) and ``temperature``.
        """
        device = input_ids.device
        midi_features = self.encode_midi(input_ids, attention_mask)
        text_features = self.encode_text(captions, device=device)
        midi_proj = self.midi_projection(midi_features)
        text_proj = self.text_projection(text_features)
        midi_embeds = F.normalize(midi_proj, p=2, dim=-1)
        text_embeds = F.normalize(text_proj, p=2, dim=-1)
        temperature = self.get_temperature()
        loss_out = symmetric_infonce_loss(
            midi_embeds=midi_embeds,
            text_embeds=text_embeds,
            temperature=temperature,
        )
        return {
            "midi_embeds": midi_embeds,
            "text_embeds": text_embeds,
            "temperature": temperature,
            **loss_out,
        }

    def trainable_parameters(self):
        """Return projection-head parameters plus the log-temperature.

        Text-encoder parameters are not included, even after
        ``unfreeze_text_encoder``.
        """
        # Stage A: only projection heads + temperature are trainable.
        return list(self.midi_projection.parameters()) + list(
            self.text_projection.parameters()
        ) + [self.log_temperature]

    def make_optimizer_param_groups(
        self,
        proj_lr: float,
        weight_decay: float = 0.0,
        temperature_lr_scale: float = 0.1,
    ) -> List[Dict[str, Any]]:
        """Build optimizer param groups for the heads and the temperature.

        Args:
            proj_lr: Learning rate for both projection heads.
            weight_decay: Weight decay for both projection heads.
            temperature_lr_scale: Multiplier on ``proj_lr`` for the
                temperature group, which always gets zero weight decay.

        Returns:
            Three groups (MIDI head, text head, temperature). ``params`` are
            real lists, so the groups can be inspected or reused without
            exhausting a one-shot generator.
        """
        return [
            {
                "params": list(self.midi_projection.parameters()),
                "lr": proj_lr,
                "weight_decay": weight_decay,
            },
            {
                "params": list(self.text_projection.parameters()),
                "lr": proj_lr,
                "weight_decay": weight_decay,
            },
            {
                "params": [self.log_temperature],
                "lr": proj_lr * temperature_lr_scale,
                "weight_decay": 0.0,
            },
        ]
# --- Compound (Octuple) variant ----------------------------------------------
class CompoundMidiTextContrastiveModel(nn.Module):
    """CLIP-style contrastive model with a CompoundGPT MIDI encoder.

    This is the "Option C" hybrid: the contrastive encoder uses compound
    (per-axis) MIDI inputs to learn structured musical representations,
    while the autoregressive decoder side of the project remains the 1D +
    BPE GPT (loaded separately at prefix-tuning time).

    Inputs to ``forward`` differ from ``MidiTextContrastiveModel``:
      - ``compound_input`` of shape (B, T, N_AXES) instead of
        (input_ids, attention_mask)
      - The pad mask is derived from the step-type axis (== STEP_PAD).

    Args:
        midi_compound_gpt: MIDI encoder. Must expose ``config.d_model`` and
            accept ``return_hidden=True`` in its forward call.
        text_model_name: Name passed to ``SentenceTransformer``.
        embed_dim: Width of the shared embedding space.
        init_temperature: Initial temperature; stored as its logarithm.
        min_temperature: Lower clamp applied by ``get_temperature``.
        max_temperature: Upper clamp applied by ``get_temperature``.
        device: If given, the whole module is moved there before freezing.
    """

    # Width previously hard-coded for the default MiniLM text model; used only
    # when the loaded text encoder does not report its own embedding size.
    _FALLBACK_TEXT_DIM = 384

    def __init__(
        self,
        midi_compound_gpt: "CompoundGPT",
        text_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        embed_dim: int = 256,
        init_temperature: float = 0.07,
        min_temperature: float = 0.01,
        max_temperature: float = 1.0,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        from compound import STEP_PAD  # local import to avoid cycles

        self._step_pad = STEP_PAD
        self.midi_encoder = midi_compound_gpt
        self.text_encoder = SentenceTransformer(text_model_name)
        self.embed_dim = embed_dim
        self.min_temperature = min_temperature
        self.max_temperature = max_temperature
        self.midi_projection = ProjectionHead(
            input_dim=midi_compound_gpt.config.d_model,
            hidden_dim=512,
            out_dim=embed_dim,
        )
        # Size the text head from the encoder itself so that a non-default
        # ``text_model_name`` with a different width still works.
        self.text_projection = ProjectionHead(
            input_dim=self._text_embedding_dim(),
            hidden_dim=512,
            out_dim=embed_dim,
        )
        # Learnable temperature, stored in log space.
        self.log_temperature = nn.Parameter(
            torch.tensor(math.log(init_temperature))
        )
        if device is not None:
            self.to(device)
        self.freeze_midi_encoder()
        self.freeze_text_encoder()

    def _text_embedding_dim(self) -> int:
        """Return the text encoder's embedding width, or the 384 fallback."""
        getter = getattr(
            self.text_encoder, "get_sentence_embedding_dimension", None
        )
        reported = getter() if callable(getter) else None
        return int(reported) if reported else self._FALLBACK_TEXT_DIM

    def train(self, mode: bool = True) -> "CompoundMidiTextContrastiveModel":
        """Set training mode while keeping fully frozen encoders in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would undo
        the ``eval()`` applied by the ``freeze_*`` methods and re-enable any
        dropout inside an encoder whose weights are frozen. An encoder with
        no trainable parameter is therefore switched back to eval here.
        """
        super().train(mode)
        for encoder in (self.midi_encoder, self.text_encoder):
            if not any(p.requires_grad for p in encoder.parameters()):
                encoder.eval()
        return self

    def freeze_midi_encoder(self) -> None:
        """Disable gradients for the MIDI encoder and put it in eval mode."""
        for p in self.midi_encoder.parameters():
            p.requires_grad = False
        self.midi_encoder.eval()

    def freeze_text_encoder(self) -> None:
        """Disable gradients for the text encoder and put it in eval mode."""
        for p in self.text_encoder.parameters():
            p.requires_grad = False
        self.text_encoder.eval()

    def unfreeze_text_encoder(self) -> None:
        """Enable gradients for the text encoder and put it in train mode."""
        for p in self.text_encoder.parameters():
            p.requires_grad = True
        self.text_encoder.train()

    def encode_midi(self, compound_input: torch.Tensor) -> torch.Tensor:
        """compound_input: (B, T, N_AXES) long. Returns pooled (B, d_model).

        Pad steps (step-axis == STEP_PAD) are excluded from the mean pool.
        A sequence consisting only of pad steps pools to zeros because the
        divisor is clamped to at least one.
        """
        with torch.no_grad():
            hidden = self.midi_encoder(compound_input, return_hidden=True)
            # Axis 0 carries the step type; 1.0 marks a real (non-pad) step.
            mask = (compound_input[..., 0] != self._step_pad).to(hidden.dtype)
            mask = mask.unsqueeze(-1)
            summed = (hidden * mask).sum(dim=1)
            denom = mask.sum(dim=1).clamp_min(1.0)
            pooled = summed / denom
        return pooled

    def encode_text(
        self, captions: List[str], device: torch.device
    ) -> torch.Tensor:
        """Embed ``captions`` with the text encoder.

        When any text-encoder parameter requires grad, the captions are
        tokenized and run through the module directly so gradients flow;
        otherwise ``encode`` is used under ``no_grad``.
        """
        text_trainable = any(
            p.requires_grad for p in self.text_encoder.parameters()
        )
        if text_trainable:
            features = self.text_encoder.tokenize(captions)
            features = {
                k: v.to(device) if hasattr(v, "to") else v
                for k, v in features.items()
            }
            out = self.text_encoder(features)
            emb = out.get("sentence_embedding")
            if emb is None:
                emb = out.get("sentence_embeddings")
            if emb is None:
                # Last resort: take whatever the first output entry is.
                emb = next(iter(out.values()))
            return emb
        with torch.no_grad():
            emb = self.text_encoder.encode(
                captions,
                convert_to_tensor=True,
                device=str(device),
                normalize_embeddings=False,
            )
        # SentenceTransformer.encode may use inference_mode internally; clone so
        # trainable projection heads can participate in autograd.
        return emb.clone()

    def get_temperature(self) -> torch.Tensor:
        """Return ``exp(log_temperature)`` clamped to the configured range."""
        return torch.exp(self.log_temperature).clamp(
            self.min_temperature, self.max_temperature
        )

    def forward(
        self,
        compound_input: torch.Tensor,
        captions: List[str],
    ) -> Dict[str, torch.Tensor]:
        """Encode a paired batch and compute the symmetric InfoNCE loss.

        Returns:
            The entries of ``symmetric_infonce_loss`` plus ``midi_embeds``,
            ``text_embeds`` (both L2-normalized) and ``temperature``.
        """
        device = compound_input.device
        midi_features = self.encode_midi(compound_input)
        text_features = self.encode_text(captions, device=device)
        midi_proj = self.midi_projection(midi_features)
        text_proj = self.text_projection(text_features)
        midi_embeds = F.normalize(midi_proj, p=2, dim=-1)
        text_embeds = F.normalize(text_proj, p=2, dim=-1)
        temperature = self.get_temperature()
        loss_out = symmetric_infonce_loss(
            midi_embeds=midi_embeds,
            text_embeds=text_embeds,
            temperature=temperature,
        )
        return {
            "midi_embeds": midi_embeds,
            "text_embeds": text_embeds,
            "temperature": temperature,
            **loss_out,
        }

    def trainable_parameters(self):
        """Return projection-head parameters plus the log-temperature.

        Text-encoder parameters are not included, even after
        ``unfreeze_text_encoder``.
        """
        return (
            list(self.midi_projection.parameters())
            + list(self.text_projection.parameters())
            + [self.log_temperature]
        )

    def make_optimizer_param_groups(
        self,
        proj_lr: float,
        weight_decay: float = 0.0,
        temperature_lr_scale: float = 0.1,
    ) -> List[Dict[str, Any]]:
        """Build optimizer param groups for the heads and the temperature.

        Args:
            proj_lr: Learning rate for both projection heads.
            weight_decay: Weight decay for both projection heads.
            temperature_lr_scale: Multiplier on ``proj_lr`` for the
                temperature group, which always gets zero weight decay.

        Returns:
            Three groups (MIDI head, text head, temperature). ``params`` are
            real lists, so the groups can be inspected or reused without
            exhausting a one-shot generator.
        """
        return [
            {
                "params": list(self.midi_projection.parameters()),
                "lr": proj_lr,
                "weight_decay": weight_decay,
            },
            {
                "params": list(self.text_projection.parameters()),
                "lr": proj_lr,
                "weight_decay": weight_decay,
            },
            {
                "params": [self.log_temperature],
                "lr": proj_lr * temperature_lr_scale,
                "weight_decay": 0.0,
            },
        ]