"""Encoder-specific downstream head subclasses with multi-position pooling.

Extends the parent's `DownstreamHead` and `TiedEmbeddingHead` with one new
pool strategy, "pre_last_tx_mean": take the entire tx_(T-1) stripe (the last
num_features positions BEFORE tx_T) and mean across those positions.

Why this exists:
  The encoder's non-compress mode emits 15 pseudo-tokens per transaction. The
  parent's `pre_last_tx` strategy reads a SINGLE position,
  `hidden[:, -(num_features+1), :]`. For our seq_len=960, num_features=15
  layout (64 tx × 15 features) that is position 944, the last feature of
  tx 62 (`customer_tenure` in schema order). That single position has:

  - Full causal context through tx 62 (good: attention has propagated
    everything relevant).
  - Direct semantics of `customer_tenure` (its value table and
    type-embedding offset).

  For predicting tx 63's mcc / merchant / amount, a representation anchored
  on `customer_tenure` is suboptimal. The first nocompress run showed exactly
  this: fraud ROC-AUC 0.96 (great), but auxiliary-head top-1 noticeably below
  the compress run, whose single per-tx pseudo-token is a holistic tx summary.
  The information is in the stripe; we are just reading it through the wrong
  feature.

  `pre_last_tx_mean` pools across all 15 positions of tx 62, giving a
  holistic tx-62 summary much closer to what the compress run's single per-tx
  pseudo-token represents.

When the two strategies coincide:
  For num_features=1 (compress mode), `pre_last_tx` reads `hidden[:, -2, :]`
  and `pre_last_tx_mean` reads `mean(hidden[:, -2:-1, :])`, i.e. the same
  single position. Compress mode can therefore use either strategy without a
  behavior change; we still default to `pre_last_tx` there for clarity.
"""

from __future__ import annotations

import torch
import torch.nn as nn

from src.model.task_heads import DownstreamHead, HeadConfig, TiedEmbeddingHead


class EncoderDownstreamHead(DownstreamHead):
    """DownstreamHead with the `pre_last_tx_mean` pool strategy added.

    `forward`, `extract_targets`, and `compute_loss` are inherited unchanged;
    only `pool()` is overridden to handle the new strategy.
    """

    def __init__(
        self,
        config: HeadConfig,
        hidden_dim: int,
        num_features: int,
    ) -> None:
        super().__init__(config, hidden_dim, num_features)

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """(B, S, D) → (B, D) via head-specific pooling.

        Adds `pre_last_tx_mean` to the parent's strategies. Every other
        strategy (e.g. `last_tx_mean`, `pre_last_tx`) falls through to the
        parent.
        """
        nf = self.num_features
        if self.config.pool_strategy == "pre_last_tx_mean":
            # tx_(T-1) stripe = positions [-2*nf, -nf).
            # seq_len=960, nf=15: positions 930..944 inclusive (the 15
            #   features of tx 62).
            # seq_len=64, nf=1: position 62 only (single-element mean).
            stripe = hidden_states[:, -(2 * nf):-nf, :]
            return stripe.mean(dim=1)
        return super().pool(hidden_states)


class EncoderTiedEmbeddingHead(TiedEmbeddingHead):
    """TiedEmbeddingHead with the `pre_last_tx_mean` pool strategy added.

    Why this exists for the encoder:
      The encoder's per-feature value tables (e.g. merchant_id at
      `vocab_size=10003, dim=d_lfm=1024`) are exactly the right projection
      matrix for a high-cardinality classifier head. Instead of learning a
      fresh `Linear(128, 10003)` from a 128-dim bottleneck (which gave only
      7.78% top-1), we share weights with the encoder's merchant_id table.
      This recovers the pattern the parent observed: a tied embedding head on
      merchant_id gave 20.8% top-1 vs 13.8% for the non-tied head.

    What this changes mechanically:
      - The classifier matrix `weight` is the encoder's value table (shape
        `(vocab_size, d_lfm)`).
      - Forward pass: pool hidden → adapter MLP (d_lfm → d_lfm → d_lfm) →
        matmul with the value table (logits = h @ weight.T) → logits over
        that feature's vocab.
      - Gradients flow back into the value table from BOTH the
        input-embedding side (every transaction's forward lookup) AND this
        head (via the loss), keeping the table consistent.

    The only difference from the parent's TiedEmbeddingHead is the set of
    supported pool strategies: this class adds `pre_last_tx_mean` for the
    encoder's 15-token-per-tx layout.

    Constraints inherited from the parent:
      - `target_type` must start with "feature:N" (N selects the feature,
        i.e. the value table, this head is tied to).
      - `output_dim` is implicitly that feature's vocab_size; the head
        config's `output_dim` is unused.
    """

    def __init__(
        self,
        config: HeadConfig,
        hidden_dim: int,
        num_features: int,
        value_tables: nn.ModuleList,
    ) -> None:
        super().__init__(config, hidden_dim, num_features, value_tables)

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """(B, S, D) → (B, D).

        Same `pre_last_tx_mean` logic as `EncoderDownstreamHead.pool`; every
        other strategy defers to the parent.
        """
        nf = self.num_features
        if self.config.pool_strategy == "pre_last_tx_mean":
            # Mean over the tx_(T-1) stripe, positions [-2*nf, -nf).
            stripe = hidden_states[:, -(2 * nf):-nf, :]
            return stripe.mean(dim=1)
        return super().pool(hidden_states)
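

# Index sanity check (illustrative sketch, not part of the parent API).
#
# The two helpers below are hypothetical additions: they replay the slice
# arithmetic of `pool()` and of the parent's `pre_last_tx` on plain Python
# ranges, so the position claims in the docstrings can be checked without
# torch. They assume the flat layout described in the module docstring,
# where tx k occupies positions [k * num_features, (k + 1) * num_features).
#
#   _pre_last_tx_positions(960, 15) == list(range(930, 945))   # tx 62
#   _pre_last_tx_position(960, 15) == 944                      # last of tx 62
#   _pre_last_tx_positions(64, 1) == [62] == [_pre_last_tx_position(64, 1)]


def _pre_last_tx_positions(seq_len: int, num_features: int) -> list[int]:
    """Positions averaged by `pre_last_tx_mean` (hypothetical helper)."""
    # Same slice as `hidden_states[:, -(2 * nf):-nf, :]`, applied to indices.
    return list(range(seq_len)[-(2 * num_features):-num_features])


def _pre_last_tx_position(seq_len: int, num_features: int) -> int:
    """Single position read by the parent's `pre_last_tx` (hypothetical)."""
    # Same index as `hidden[:, -(num_features + 1), :]`; it is always the
    # last element of the stripe above, i.e. the last feature of tx_(T-1).
    return range(seq_len)[-(num_features + 1)]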