cdotsanghvi's picture
add multi-head demo as 4th-6th tabs; restore Why Liquid + Integration
083b138
Raw
History Blame Contribute Delete
10.2 kB
"""Downstream task heads for multi-head fine-tuning.
Four heads sharing the pretrained LFM2Small backbone:
1. fraud -- P(fraud) from last-transaction pool (BCE)
2. next_merchant -- merchant_id of last tx from prefix (CE, self-supervised)
3. amount -- amount bucket of last tx from prefix (CE, self-supervised)
4. mcc -- MCC of last tx from prefix (CE, self-supervised)
Two head implementations:
- DownstreamHead -- fresh 2-layer MLP. Doesn't inherit pretrained
knowledge stored in the tied LM-head weights.
Default for fraud (which has no LM-head analog).
- TiedEmbeddingHead -- pool -> adapter -> matmul through the backbone's
tied embedding table for the target feature.
Recovers the pretrained next-feature signal that
a fresh MLP cannot. Use when the target feature
has a corresponding weight-tied LM head in
pretraining (anything with target_type "feature:N").
Pool strategies:
last_tx_mean -- mean of last transaction's 15 positions. These have seen
the full sequence (causal masking), so the representation
encodes full-sequence context. Used for sequence-level tasks.
pre_last_tx -- hidden state at position S - num_features - 1 (end of tx 62).
Sees tx 0-62 only, appropriate for predicting tx 63's features
without target leakage. Verified safe: both causal attention
and left-padded Conv1d respect this boundary.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class HeadConfig:
    """Single head configuration, loaded from finetune.yaml.

    One instance describes one downstream head: the size of its output, how
    it pools the backbone's hidden states, where its targets come from, and
    how strongly its loss counts in the multi-head total.
    """

    # Head identifier (presumably the key used in the heads dict, e.g.
    # "fraud" or "mcc" -- confirm against the config loader).
    name: str
    # Width of DownstreamHead's final linear layer. TiedEmbeddingHead does
    # not read it; its output width is the tied table's vocabulary size.
    output_dim: int
    # "bce" | "ce". DownstreamHead uses BCE-with-logits for "bce" and
    # cross-entropy for any other value; TiedEmbeddingHead always uses CE.
    loss_type: str
    # "last_tx_mean" | "pre_last_tx". Any other value makes pool() raise.
    pool_strategy: str
    # "sequence_label" | "feature:<idx>". DownstreamHead.extract_targets
    # additionally accepts "amount_range" (read from aux_targets).
    target_type: str
    # Multiplier applied to this head's loss in the weighted sum.
    weight: float = 1.0
    # Hidden width of DownstreamHead's 2-layer MLP.
    mlp_hidden: int = 128
    # Dropout probability used inside the head's MLP / adapter.
    dropout: float = 0.1
class DownstreamHead(nn.Module):
    """Pool hidden states -> 2-layer MLP -> task prediction.

    The MLP is Linear(hidden_dim, mlp_hidden) -> ReLU -> Dropout ->
    Linear(mlp_hidden, output_dim), so the parameter count per head is
    ``hidden_dim * mlp_hidden + mlp_hidden + mlp_hidden * output_dim +
    output_dim``; it grows linearly with the number of output classes.
    """

    def __init__(self, config: HeadConfig, hidden_dim: int, num_features: int) -> None:
        """Build the head.

        Args:
            config: Head settings; ``mlp_hidden``, ``dropout`` and
                ``output_dim`` size the MLP.
            hidden_dim: Width of the backbone hidden states fed to the head.
            num_features: Number of sequence positions per transaction; used
                by :meth:`pool` to locate the last transaction.
        """
        super().__init__()
        self.config = config
        self.num_features = num_features
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, config.mlp_hidden),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.mlp_hidden, config.output_dim),
        )

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """(B, S, D) -> (B, D) via head-specific pooling.

        Raises:
            ValueError: If ``config.pool_strategy`` is not a known strategy.
        """
        nf = self.num_features
        if self.config.pool_strategy == "last_tx_mean":
            # Average over the final transaction's nf positions.
            return hidden_states[:, -nf:, :].mean(dim=1)
        if self.config.pool_strategy == "pre_last_tx":
            # Single position immediately before the final transaction.
            return hidden_states[:, -(nf + 1), :]
        raise ValueError(f"Unknown pool strategy: {self.config.pool_strategy}")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pool the hidden states and return the MLP's logits."""
        return self.mlp(self.pool(hidden_states))

    def compute_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the loss selected by ``config.loss_type``.

        "bce" applies BCE-with-logits after dropping the trailing logit
        dimension and casting targets to float; every other value uses
        cross-entropy.
        """
        if self.config.loss_type == "bce":
            return F.binary_cross_entropy_with_logits(
                logits.squeeze(-1), targets.float(),
            )
        return F.cross_entropy(logits, targets)

    def extract_targets(
        self,
        token_ids: torch.Tensor,
        sequence_labels: torch.Tensor | None,
        aux_targets: dict[str, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Get this head's targets from input data.

        Args:
            token_ids: (B, T, F).
            sequence_labels: (B,) binary fraud labels, or None.
            aux_targets: optional dict of auxiliary per-sequence targets
                (e.g. "amount_range" from amount_range_labels.npy).

        Returns:
            ``sequence_labels`` for "sequence_label", the "amount_range"
            entry of ``aux_targets`` for "amount_range", otherwise feature
            ``<idx>`` of the last transaction for "<prefix>:<idx>".

        Raises:
            ValueError: If the data required by ``config.target_type`` is
                missing, or the target type cannot be interpreted.
        """
        target_type = self.config.target_type

        # Explicit raises instead of asserts: asserts vanish under ``-O``.
        if target_type == "sequence_label":
            if sequence_labels is None:
                raise ValueError(
                    "target_type 'sequence_label' requires sequence_labels, got None",
                )
            return sequence_labels

        if target_type == "amount_range":
            if aux_targets is None or "amount_range" not in aux_targets:
                raise ValueError(
                    "target_type 'amount_range' requires "
                    "aux_targets['amount_range']",
                )
            return aux_targets["amount_range"]

        parts = target_type.split(":")
        if len(parts) < 2:
            # Previously surfaced as a bare IndexError from the split lookup.
            raise ValueError(f"Unknown target_type: {target_type!r}")
        feat_idx = int(parts[1])
        return token_ids[:, -1, feat_idx]
class TiedEmbeddingHead(nn.Module):
    """Downstream head that projects through the backbone's tied embedding table.

    The pretrained LM head for feature F is the value-embedding table for F
    (weight-tied). The fresh-MLP DownstreamHead discards that projection by
    learning a new one from scratch. This head preserves it: pool the backbone
    hidden states, run a small adapter, then matmul through the same embedding
    table the backbone reads from at input time.

    Why this matters:
      - At fine-tune time the adapter learns a small distribution shift, not
        a 256-d -> vocab_size projection from scratch.
      - Gradients flow back to the embedding table from both the input-embed
        side and this head, just as during pretraining.
      - When the backbone is frozen, the embedding table is frozen too, so
        this head reduces to "small adapter on top of pretrained features."

    Constraints:
      - Only valid for target_type "feature:N". The tied table is keyed to
        a specific feature index.
      - Output dim is implicitly the vocab_size of that feature;
        ``config.output_dim`` is not read.
      - The loss is always cross-entropy; ``config.loss_type`` is not read.
    """

    def __init__(
        self,
        config: HeadConfig,
        hidden_dim: int,
        num_features: int,
        value_tables: nn.ModuleList,
    ) -> None:
        """Build the head around the backbone's embedding tables.

        Args:
            config: Head settings; ``target_type`` must be "feature:N".
            hidden_dim: Width of the backbone hidden states (and of the
                embedding rows the adapter output is multiplied against).
            num_features: Number of sequence positions per transaction.
            value_tables: The backbone's per-feature embedding tables,
                indexed by feature.

        Raises:
            ValueError: If ``target_type`` is not "feature:N", N is not an
                integer, or N does not index into ``value_tables``.
        """
        super().__init__()
        self.config = config
        self.num_features = num_features

        if not config.target_type.startswith("feature:"):
            raise ValueError(
                f"TiedEmbeddingHead requires target_type 'feature:N', got "
                f"{config.target_type!r}",
            )
        feature_idx = int(config.target_type.split(":")[1])

        # Fail at construction instead of at the first forward pass. The
        # accepted range mirrors what ModuleList indexing itself accepts.
        num_tables = len(value_tables)
        if not -num_tables <= feature_idx < num_tables:
            raise ValueError(
                f"TiedEmbeddingHead feature index {feature_idx} is out of range "
                f"for {num_tables} value tables",
            )
        self.feature_idx = feature_idx

        # value_tables is the same nn.ModuleList the backbone owns. Registering
        # it here shares the parameter tensors (nothing is copied). Note that
        # it is still a registered submodule of this head: head.parameters()
        # yields the tables, and state_dict() lists them under this head's
        # prefix in addition to the backbone's.
        self.value_tables = value_tables

        # The adapter is the only thing this head learns from scratch. Two
        # layers with a nonlinearity gives enough flexibility for a small
        # distribution-shift correction without obscuring the tied projection.
        self.adapter = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """(B, S, D) -> (B, D) via head-specific pooling.

        Raises:
            ValueError: If ``config.pool_strategy`` is not a known strategy.
        """
        nf = self.num_features
        if self.config.pool_strategy == "last_tx_mean":
            # Average over the final transaction's nf positions.
            return hidden_states[:, -nf:, :].mean(dim=1)
        if self.config.pool_strategy == "pre_last_tx":
            # Single position immediately before the final transaction.
            return hidden_states[:, -(nf + 1), :]
        raise ValueError(f"Unknown pool strategy: {self.config.pool_strategy}")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return logits over the tied feature's vocabulary, shape (B, vocab_f)."""
        pooled = self.pool(hidden_states)  # (B, D)
        adapted = self.adapter(pooled)  # (B, D)
        # Project through the tied embedding table. F.linear computes
        # adapted @ weight.T with no bias.
        weight = self.value_tables[self.feature_idx].weight  # (vocab_f, D)
        return F.linear(adapted, weight)

    def compute_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Cross-entropy between the vocabulary logits and target ids."""
        # Tied heads only support CE — they predict over a tied vocabulary.
        return F.cross_entropy(logits, targets)

    def extract_targets(
        self,
        token_ids: torch.Tensor,
        sequence_labels: torch.Tensor | None,
        aux_targets: dict[str, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Return the tied feature's value in the last transaction.

        ``sequence_labels`` and ``aux_targets`` are accepted only to match
        DownstreamHead's signature; they are not used.
        """
        return token_ids[:, -1, self.feature_idx]
# DownstreamHead and TiedEmbeddingHead share a duck-typed interface
# (pool / forward / compute_loss / extract_targets, plus a ``config``
# attribute), so MultiHeadModel can treat them interchangeably.
# Union type for the head dict in MultiHeadModel.
# NOTE(review): unlike the annotations in this file, this expression is
# evaluated at import time; ``X | Y`` between classes needs Python 3.10+
# (PEP 604) -- confirm the minimum supported interpreter.
AnyHead = DownstreamHead | TiedEmbeddingHead
class MultiHeadModel(nn.Module):
    """Pretrained backbone + downstream task heads.

    Calls backbone.backbone_forward() once, then each head pools and
    predicts independently. Backbone may be frozen or slow-learned.
    """

    def __init__(
        self, backbone: nn.Module, heads: dict[str, AnyHead],
    ) -> None:
        """Register the backbone and wrap the heads in an ``nn.ModuleDict``.

        Args:
            backbone: Module exposing ``backbone_forward(token_ids)``.
            heads: Mapping from head name to head module.
        """
        super().__init__()
        self.backbone = backbone
        self.heads = nn.ModuleDict(heads)

    def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the backbone once and return each head's output keyed by name."""
        hidden = self.backbone.backbone_forward(token_ids)
        return {name: head(hidden) for name, head in self.heads.items()}

    def compute_losses(
        self,
        predictions: dict[str, torch.Tensor],
        token_ids: torch.Tensor,
        sequence_labels: torch.Tensor | None,
        aux_targets: dict[str, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Weighted sum of per-head losses.

        Args:
            predictions: Head name -> that head's output, as returned by
                :meth:`forward`.
            token_ids: Passed through to each head's ``extract_targets``.
            sequence_labels: Passed through to ``extract_targets``.
            aux_targets: Passed through to ``extract_targets``.

        Returns:
            ``(total, per_head)`` where ``total`` is the sum of
            ``head.config.weight * loss`` over all heads and ``per_head``
            maps each head name to its unweighted loss as a Python float.
        """
        # Losses live wherever the predictions do; with no predictions at
        # all, fall back to the inputs' device rather than raising
        # StopIteration.
        if predictions:
            device = next(iter(predictions.values())).device
        else:
            device = token_ids.device

        total = torch.tensor(0.0, device=device)
        per_head: dict[str, float] = {}
        for name, head in self.heads.items():
            targets = head.extract_targets(
                token_ids, sequence_labels, aux_targets,
            ).to(device)
            loss = head.compute_loss(predictions[name], targets)
            total = total + head.config.weight * loss
            per_head[name] = loss.item()
        return total, per_head

    def head_param_count(self) -> dict[str, int]:
        """Parameter count per head (excludes shared backbone).

        A head may register backbone modules as its own submodules (a tied
        head holds the backbone's embedding tables), so parameters that also
        belong to the backbone are filtered out by identity.
        """
        backbone_param_ids = {id(p) for p in self.backbone.parameters()}
        return {
            name: sum(
                p.numel()
                for p in head.parameters()
                if id(p) not in backbone_param_ids
            )
            for name, head in self.heads.items()
        }