"""Downstream task heads for multi-head fine-tuning. Four heads sharing the pretrained LFM2Small backbone: 1. fraud -- P(fraud) from last-transaction pool (BCE) 2. next_merchant -- merchant_id of last tx from prefix (CE, self-supervised) 3. amount -- amount bucket of last tx from prefix (CE, self-supervised) 4. mcc -- MCC of last tx from prefix (CE, self-supervised) Two head implementations: - DownstreamHead -- fresh 2-layer MLP. Doesn't inherit pretrained knowledge stored in the tied LM-head weights. Default for fraud (which has no LM-head analog). - TiedEmbeddingHead -- pool -> adapter -> matmul through the backbone's tied embedding table for the target feature. Recovers the pretrained next-feature signal that a fresh MLP cannot. Use when the target feature has a corresponding weight-tied LM head in pretraining (anything with target_type "feature:N"). Pool strategies: last_tx_mean -- mean of last transaction's 15 positions. These have seen the full sequence (causal masking), so the representation encodes full-sequence context. Used for sequence-level tasks. pre_last_tx -- hidden state at position S - num_features - 1 (end of tx 62). Sees tx 0-62 only, appropriate for predicting tx 63's features without target leakage. Verified safe: both causal attention and left-padded Conv1d respect this boundary. """ from __future__ import annotations from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F @dataclass class HeadConfig: """Single head configuration, loaded from finetune.yaml.""" name: str output_dim: int loss_type: str # "bce" | "ce" pool_strategy: str # "last_tx_mean" | "pre_last_tx" target_type: str # "sequence_label" | "feature:" weight: float = 1.0 mlp_hidden: int = 128 dropout: float = 0.1 class DownstreamHead(nn.Module): """Pool hidden states -> 2-layer MLP -> task prediction. Param count per head: ~33K (small output) to ~65K (merchant_id, 5003 classes). """ def __init__(self, config: HeadConfig, hidden_dim: int, num_features: int) -> None: super().__init__() self.config = config self.num_features = num_features self.mlp = nn.Sequential( nn.Linear(hidden_dim, config.mlp_hidden), nn.ReLU(), nn.Dropout(config.dropout), nn.Linear(config.mlp_hidden, config.output_dim), ) def pool(self, hidden_states: torch.Tensor) -> torch.Tensor: """(B, S, D) -> (B, D) via head-specific pooling.""" nf = self.num_features if self.config.pool_strategy == "last_tx_mean": return hidden_states[:, -nf:, :].mean(dim=1) if self.config.pool_strategy == "pre_last_tx": return hidden_states[:, -(nf + 1), :] raise ValueError(f"Unknown pool strategy: {self.config.pool_strategy}") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.mlp(self.pool(hidden_states)) def compute_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: if self.config.loss_type == "bce": return F.binary_cross_entropy_with_logits( logits.squeeze(-1), targets.float(), ) return F.cross_entropy(logits, targets) def extract_targets( self, token_ids: torch.Tensor, sequence_labels: torch.Tensor | None, aux_targets: dict[str, torch.Tensor] | None = None, ) -> torch.Tensor: """Get this head's targets from input data. Args: token_ids: (B, T, F). sequence_labels: (B,) binary fraud labels, or None. aux_targets: optional dict of auxiliary per-sequence targets (e.g. "amount_range" from amount_range_labels.npy). """ if self.config.target_type == "sequence_label": assert sequence_labels is not None return sequence_labels if self.config.target_type == "amount_range": assert aux_targets is not None and "amount_range" in aux_targets return aux_targets["amount_range"] feat_idx = int(self.config.target_type.split(":")[1]) return token_ids[:, -1, feat_idx] class TiedEmbeddingHead(nn.Module): """Downstream head that projects through the backbone's tied embedding table. The pretrained LM head for feature F is the value-embedding table for F (weight-tied). The fresh-MLP DownstreamHead discards that projection by learning a new one from scratch. This head preserves it: pool the backbone hidden states, run a small adapter, then matmul through the same embedding table the backbone reads from at input time. Why this matters: - At fine-tune time the adapter learns a small distribution shift, not a 256-d -> vocab_size projection from scratch. - Gradients flow back to the embedding table from both the input-embed side and this head, just as during pretraining. - When the backbone is frozen, the embedding table is frozen too, so this head reduces to "small adapter on top of pretrained features." Constraints: - Only valid for target_type "feature:N". The tied table is keyed to a specific feature index. - Output dim is implicitly the vocab_size of that feature. """ def __init__( self, config: HeadConfig, hidden_dim: int, num_features: int, value_tables: nn.ModuleList, ) -> None: super().__init__() self.config = config self.num_features = num_features if not config.target_type.startswith("feature:"): raise ValueError( f"TiedEmbeddingHead requires target_type 'feature:N', got " f"{config.target_type!r}", ) self.feature_idx = int(config.target_type.split(":")[1]) # value_tables is the same nn.ModuleList in StructuredEmbedding. By # assigning it as a module attribute we share parameters with the # backbone — no parameter duplication in state_dict. self.value_tables = value_tables # The adapter is the only thing this head learns from scratch. Two # layers with a nonlinearity gives enough flexibility for a small # distribution-shift correction without obscuring the tied projection. self.adapter = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(config.dropout), nn.Linear(hidden_dim, hidden_dim), ) def pool(self, hidden_states: torch.Tensor) -> torch.Tensor: nf = self.num_features if self.config.pool_strategy == "last_tx_mean": return hidden_states[:, -nf:, :].mean(dim=1) if self.config.pool_strategy == "pre_last_tx": return hidden_states[:, -(nf + 1), :] raise ValueError(f"Unknown pool strategy: {self.config.pool_strategy}") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: pooled = self.pool(hidden_states) # (B, D) adapted = self.adapter(pooled) # (B, D) # Project through the tied embedding table. F.linear computes # adapted @ weight.T, exactly matching how the pretrained LM head # produces logits. No bias, matching PerFeatureLMHeads. weight = self.value_tables[self.feature_idx].weight # (vocab_f, D) return F.linear(adapted, weight) def compute_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: # Tied heads only support CE — they predict over a tied vocabulary. return F.cross_entropy(logits, targets) def extract_targets( self, token_ids: torch.Tensor, sequence_labels: torch.Tensor | None, aux_targets: dict[str, torch.Tensor] | None = None, ) -> torch.Tensor: return token_ids[:, -1, self.feature_idx] # DownstreamHead and TiedEmbeddingHead share a duck-typed interface. # Union type for the head dict in MultiHeadModel. AnyHead = DownstreamHead | TiedEmbeddingHead class MultiHeadModel(nn.Module): """Pretrained backbone + downstream task heads. Calls backbone.backbone_forward() once, then each head pools and predicts independently. Backbone may be frozen or slow-learned. """ def __init__( self, backbone: nn.Module, heads: dict[str, AnyHead], ) -> None: super().__init__() self.backbone = backbone self.heads = nn.ModuleDict(heads) def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]: hidden = self.backbone.backbone_forward(token_ids) return {name: head(hidden) for name, head in self.heads.items()} def compute_losses( self, predictions: dict[str, torch.Tensor], token_ids: torch.Tensor, sequence_labels: torch.Tensor | None, aux_targets: dict[str, torch.Tensor] | None = None, ) -> tuple[torch.Tensor, dict[str, float]]: """Weighted sum of per-head losses.""" device = next(iter(predictions.values())).device total = torch.tensor(0.0, device=device) per_head: dict[str, float] = {} for name, head in self.heads.items(): targets = head.extract_targets( token_ids, sequence_labels, aux_targets, ).to(device) loss = head.compute_loss(predictions[name], targets) total = total + head.config.weight * loss per_head[name] = loss.item() return total, per_head def head_param_count(self) -> dict[str, int]: """Parameter count per head (excludes shared backbone).""" return { name: sum(p.numel() for p in head.parameters()) for name, head in self.heads.items() }