| """Downstream task heads for multi-head fine-tuning. |
| |
| Four heads sharing the pretrained LFM2Small backbone: |
| |
| 1. fraud -- P(fraud) from last-transaction pool (BCE) |
| 2. next_merchant -- merchant_id of last tx from prefix (CE, self-supervised) |
| 3. amount -- amount bucket of last tx from prefix (CE, self-supervised) |
| 4. mcc -- MCC of last tx from prefix (CE, self-supervised) |
| |
| Two head implementations: |
| - DownstreamHead -- fresh 2-layer MLP. Doesn't inherit pretrained |
| knowledge stored in the tied LM-head weights. |
| Default for fraud (which has no LM-head analog). |
| - TiedEmbeddingHead -- pool -> adapter -> matmul through the backbone's |
| tied embedding table for the target feature. |
| Recovers the pretrained next-feature signal that |
| a fresh MLP cannot. Use when the target feature |
| has a corresponding weight-tied LM head in |
| pretraining (anything with target_type "feature:N"). |
| |
| Pool strategies: |
| last_tx_mean -- mean of last transaction's 15 positions. These have seen |
| the full sequence (causal masking), so the representation |
| encodes full-sequence context. Used for sequence-level tasks. |
| pre_last_tx -- hidden state at position S - num_features - 1 (end of tx 62). |
| Sees tx 0-62 only, appropriate for predicting tx 63's features |
| without target leakage. Verified safe: both causal attention |
| and left-padded Conv1d respect this boundary. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class HeadConfig:
    """Settings for one downstream head (one entry of finetune.yaml)."""

    # Identifier of the head.
    name: str
    # Width of the final layer of the fresh-MLP head (number of logits).
    output_dim: int
    # "bce" selects binary cross-entropy with logits in the MLP head; any
    # other value falls through to categorical cross-entropy there.
    loss_type: str
    # Pooling over the sequence axis: "last_tx_mean" or "pre_last_tx".
    pool_strategy: str
    # Where targets come from: "sequence_label", "amount_range" or "feature:N".
    target_type: str
    # Multiplier applied to this head's loss in the weighted multi-head sum.
    weight: float = 1.0
    # Hidden width of the fresh-MLP head.
    mlp_hidden: int = 128
    # Dropout probability used inside the head.
    dropout: float = 0.1
|
|
|
|
class DownstreamHead(nn.Module):
    """Pool backbone hidden states, then predict with a fresh 2-layer MLP.

    The MLP is ``Linear(hidden_dim, mlp_hidden) -> ReLU -> Dropout ->
    Linear(mlp_hidden, output_dim)``, so the head owns
    ``hidden_dim * mlp_hidden + mlp_hidden + mlp_hidden * output_dim +
    output_dim`` trainable parameters.

    Args:
        config: Head settings (pooling, loss, target source, MLP sizes).
        hidden_dim: Width D of the backbone hidden states.
        num_features: Number of trailing sequence positions that make up the
            last transaction; both pooling strategies are expressed relative
            to that window.
    """

    def __init__(self, config: HeadConfig, hidden_dim: int, num_features: int) -> None:
        super().__init__()
        self.config = config
        self.num_features = num_features

        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, config.mlp_hidden),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.mlp_hidden, config.output_dim),
        )

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Reduce (B, S, D) hidden states to (B, D) per ``config.pool_strategy``.

        Raises:
            ValueError: If the configured pool strategy is not recognised.
        """
        strategy = self.config.pool_strategy
        window = self.num_features
        if strategy == "last_tx_mean":
            # Average over the positions of the final transaction.
            return hidden_states[:, -window:, :].mean(dim=1)
        if strategy == "pre_last_tx":
            # Single position immediately before the final transaction, so
            # the pooled state does not include the positions being predicted.
            return hidden_states[:, -(window + 1), :]
        raise ValueError(f"Unknown pool strategy: {self.config.pool_strategy}")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (B, output_dim) for (B, S, D) hidden states."""
        return self.mlp(self.pool(hidden_states))

    def compute_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the scalar loss selected by ``config.loss_type``.

        ``"bce"`` uses binary cross-entropy with logits (the trailing logit
        dimension is squeezed and targets are cast to float). Every other
        loss type uses categorical cross-entropy.
        """
        if self.config.loss_type == "bce":
            return F.binary_cross_entropy_with_logits(
                logits.squeeze(-1), targets.float(),
            )
        # cross_entropy wants int64 class indices; token ids and label arrays
        # are often stored as int32 or smaller. Floating-point targets (class
        # probabilities) are passed through untouched.
        if not targets.is_floating_point():
            targets = targets.long()
        return F.cross_entropy(logits, targets)

    def extract_targets(
        self,
        token_ids: torch.Tensor,
        sequence_labels: torch.Tensor | None,
        aux_targets: dict[str, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Get this head's targets from input data.

        Args:
            token_ids: (B, T, F).
            sequence_labels: (B,) binary fraud labels, or None.
            aux_targets: optional dict of auxiliary per-sequence targets
                (e.g. "amount_range" from amount_range_labels.npy).

        Returns:
            ``sequence_labels`` for target type "sequence_label",
            ``aux_targets["amount_range"]`` for "amount_range", otherwise
            feature N of the last transaction for a "<prefix>:N" target type.

        Raises:
            ValueError: If the data required by the target type is missing,
                or the target type cannot be parsed.
        """
        target_type = self.config.target_type

        # Explicit raises instead of asserts: asserts vanish under ``python -O``
        # and would let a None target flow into the loss.
        if target_type == "sequence_label":
            if sequence_labels is None:
                raise ValueError(
                    "target_type 'sequence_label' requires sequence_labels, got None",
                )
            return sequence_labels

        if target_type == "amount_range":
            if aux_targets is None or "amount_range" not in aux_targets:
                raise ValueError(
                    "target_type 'amount_range' requires aux_targets['amount_range']",
                )
            return aux_targets["amount_range"]

        try:
            feat_idx = int(target_type.split(":")[1])
        except (IndexError, ValueError) as err:
            raise ValueError(
                f"Cannot parse feature index from target_type {target_type!r}; "
                "expected 'sequence_label', 'amount_range' or 'feature:N'",
            ) from err
        return token_ids[:, -1, feat_idx]
|
|
|
|
class TiedEmbeddingHead(nn.Module):
    """Downstream head that projects through the backbone's tied embedding table.

    The pretrained LM head for feature F is the value-embedding table for F
    (weight-tied). The fresh-MLP DownstreamHead discards that projection by
    learning a new one from scratch. This head preserves it: pool the backbone
    hidden states, run a small adapter, then matmul through the same embedding
    table the backbone reads from at input time.

    Why this matters:
    - At fine-tune time the adapter learns a small distribution shift, not
      a hidden_dim -> vocab_size projection from scratch.
    - Gradients flow back to the embedding table from both the input-embed
      side and this head, just as during pretraining.
    - When the backbone is frozen, the embedding table is frozen too, so
      this head reduces to "small adapter on top of pretrained features."

    Constraints:
    - Only valid for target_type "feature:N". The tied table is keyed to
      a specific feature index.
    - Output dim is implicitly the vocab_size of that feature
      (``config.output_dim`` is not consulted).
    - The selected table's embedding width must equal ``hidden_dim``.

    Args:
        config: Head settings; ``target_type`` must be "feature:N".
        hidden_dim: Width D of the backbone hidden states and of the adapter.
        num_features: Number of trailing sequence positions that make up the
            last transaction (used by pooling).
        value_tables: The backbone's per-feature embedding tables; entry N
            must expose a ``weight`` tensor.

    Raises:
        ValueError: If ``target_type`` is not "feature:N", N does not index
            ``value_tables``, or the table width differs from ``hidden_dim``.
    """

    def __init__(
        self,
        config: HeadConfig,
        hidden_dim: int,
        num_features: int,
        value_tables: nn.ModuleList,
    ) -> None:
        super().__init__()
        self.config = config
        self.num_features = num_features

        if not config.target_type.startswith("feature:"):
            raise ValueError(
                f"TiedEmbeddingHead requires target_type 'feature:N', got "
                f"{config.target_type!r}",
            )
        self.feature_idx = int(config.target_type.split(":")[1])

        # Fail at construction rather than at the first forward pass, where
        # the same mistakes would show up as a bare IndexError or an opaque
        # shape mismatch inside F.linear.
        table_count = len(value_tables)
        if not -table_count <= self.feature_idx < table_count:
            raise ValueError(
                f"target_type {config.target_type!r} selects feature "
                f"{self.feature_idx}, but only {table_count} value tables "
                f"were provided",
            )
        table_width = value_tables[self.feature_idx].weight.shape[-1]
        if table_width != hidden_dim:
            raise ValueError(
                f"Value table {self.feature_idx} has embedding width "
                f"{table_width}, expected hidden_dim={hidden_dim}",
            )

        # Keep a reference to the shared tables, not a copy. Because this is
        # an nn.ModuleList, the assignment also registers the tables as a
        # submodule of this head, so they show up in this head's
        # parameters() and state_dict().
        self.value_tables = value_tables

        # Adapter keeps the hidden width so its output can be multiplied
        # directly against the (vocab_size, hidden_dim) embedding matrix.
        self.adapter = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Reduce (B, S, D) hidden states to (B, D) per ``config.pool_strategy``.

        Raises:
            ValueError: If the configured pool strategy is not recognised.
        """
        strategy = self.config.pool_strategy
        window = self.num_features
        if strategy == "last_tx_mean":
            return hidden_states[:, -window:, :].mean(dim=1)
        if strategy == "pre_last_tx":
            return hidden_states[:, -(window + 1), :]
        raise ValueError(f"Unknown pool strategy: {self.config.pool_strategy}")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return (B, vocab_size) logits for (B, S, D) hidden states."""
        adapted = self.adapter(self.pool(hidden_states))
        # Look the weight up on every call so the head always uses the
        # table's current values.
        table_weight = self.value_tables[self.feature_idx].weight
        return F.linear(adapted, table_weight)

    def compute_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return categorical cross-entropy; ``config.loss_type`` is ignored."""
        # cross_entropy wants int64 class indices; token ids are often stored
        # as int32 or smaller. Floating-point targets are left untouched.
        if not targets.is_floating_point():
            targets = targets.long()
        return F.cross_entropy(logits, targets)

    def extract_targets(
        self,
        token_ids: torch.Tensor,
        sequence_labels: torch.Tensor | None,
        aux_targets: dict[str, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Return the tied feature of the last transaction from (B, T, F) ids.

        ``sequence_labels`` and ``aux_targets`` are accepted so the signature
        matches DownstreamHead.extract_targets; they are not used.
        """
        return token_ids[:, -1, self.feature_idx]
|
|
|
|
| |
| |
# Either head implementation; used to annotate the heads handed to
# MultiHeadModel. Evaluated at import time, so the ``|`` between classes
# needs an interpreter that supports it.
AnyHead = DownstreamHead | TiedEmbeddingHead
|
|
|
|
class MultiHeadModel(nn.Module):
    """Pretrained backbone + downstream task heads.

    Calls ``backbone.backbone_forward()`` once, then each head pools and
    predicts independently. Backbone may be frozen or slow-learned.

    Args:
        backbone: Module exposing ``backbone_forward(token_ids)``.
        heads: Mapping from head name to head module. Each head must provide
            ``config.weight``, ``extract_targets`` and ``compute_loss``.
    """

    def __init__(
        self, backbone: nn.Module, heads: dict[str, AnyHead],
    ) -> None:
        super().__init__()
        self.backbone = backbone
        self.heads = nn.ModuleDict(heads)

    def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the backbone once and return each head's output keyed by name."""
        hidden = self.backbone.backbone_forward(token_ids)
        return {name: head(hidden) for name, head in self.heads.items()}

    def compute_losses(
        self,
        predictions: dict[str, torch.Tensor],
        token_ids: torch.Tensor,
        sequence_labels: torch.Tensor | None,
        aux_targets: dict[str, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Weighted sum of per-head losses.

        Args:
            predictions: Output of ``forward``; must hold an entry per head.
            token_ids: Token ids passed on to each head's ``extract_targets``.
            sequence_labels: Sequence-level labels, or None.
            aux_targets: Optional auxiliary per-sequence targets.

        Returns:
            ``(total, per_head)``: ``total`` is the sum of
            ``head.config.weight * loss`` over all heads, and ``per_head``
            maps head name to its unweighted loss as a Python float.
        """
        # Targets are moved to wherever the predictions live. With no
        # predictions at all, fall back to CPU instead of leaking a bare
        # StopIteration out of next().
        first_prediction = next(iter(predictions.values()), None)
        if first_prediction is None:
            device = torch.device("cpu")
        else:
            device = first_prediction.device

        total = torch.tensor(0.0, device=device)
        per_head: dict[str, float] = {}

        for name, head in self.heads.items():
            targets = head.extract_targets(
                token_ids, sequence_labels, aux_targets,
            ).to(device)
            loss = head.compute_loss(predictions[name], targets)
            total = total + head.config.weight * loss
            per_head[name] = loss.item()

        return total, per_head

    def head_param_count(self) -> dict[str, int]:
        """Parameter count per head (excludes shared backbone).

        A head that holds a reference to backbone modules (for example tied
        embedding tables stored as an ``nn.ModuleList`` attribute) reports
        those parameters through ``parameters()`` as well, so anything that
        is also a backbone parameter is filtered out by identity.
        """
        backbone_param_ids = {id(p) for p in self.backbone.parameters()}
        return {
            name: sum(
                p.numel()
                for p in head.parameters()
                if id(p) not in backbone_param_ids
            )
            for name, head in self.heads.items()
        }
|
|