| """Encoder-specific downstream head subclass with multi-position pooling. |
| |
| Extends parent's `DownstreamHead` with one new pool strategy: |
| |
| "pre_last_tx_mean" — pool the entire tx_(T-1) stripe (last num_features |
| positions BEFORE tx_T), then mean across positions. |
| |
| Why this exists: |
| |
| The encoder non-compress mode emits 15 pseudo-tokens per transaction. |
| The parent's `pre_last_tx` strategy reads a SINGLE position |
`hidden[:, -(num_features+1), :]` — for our layout (960 positions =
64 txs × num_features=15) that's position 944 = the last feature of
tx 62 (`customer_tenure` per schema ordering).
| |
| That single position has: |
| - Full causal context through tx 62 (good — attention has propagated |
| everything relevant). |
| - Direct semantics around `customer_tenure` (its value table and |
| type-embedding offset). |
| |
| For predicting tx 63's mcc / merchant / amount, the |
| `customer_tenure`-anchored representation is suboptimal. The first |
non-compress run showed exactly this: fraud ROC-AUC 0.96 (great), but
| auxiliary head top-1 noticeably below the compress run's holistic-tx |
| summary. The information is in the stripe; we're just reading it |
| through the wrong feature. |
| |
| `pre_last_tx_mean` pools across all 15 positions of tx 62 — a |
| holistic tx-62 summary, much closer to what the compress run's |
| single per-tx pseudo-token represented. |
| |
| When `pre_last_tx` collapses to the same thing: |
| |
| For `num_features=1` (compress mode), `pre_last_tx` reads |
| `hidden[:, -2, :]` and `pre_last_tx_mean` reads |
| `mean(hidden[:, -2:-1, :])` = same single position. So compress mode |
| can use either strategy without behavior change; we still default to |
| `pre_last_tx` there for clarity. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from src.model.task_heads import DownstreamHead, HeadConfig, TiedEmbeddingHead |
|
|
|
|
class EncoderDownstreamHead(DownstreamHead):
    """DownstreamHead with `pre_last_tx_mean` pool strategy added.

    Forward + extract_targets + compute_loss are all inherited unchanged —
    only `pool()` is overridden to handle the new strategy.
    """

    def __init__(
        self,
        config: HeadConfig,
        hidden_dim: int,
        num_features: int,
    ) -> None:
        # Same signature as the parent; kept explicit so this subclass's
        # constructor contract is visible without opening the parent.
        super().__init__(config, hidden_dim, num_features)

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """(B, S, D) → (B, D) via head-specific pooling.

        Adds `pre_last_tx_mean` to parent's strategies: the mean over the
        `num_features` positions immediately preceding the final
        `num_features` positions, i.e. `hidden_states[:, -2*nf:-nf, :]`.
        Every other strategy (e.g. `last_tx_mean`, `pre_last_tx`) falls
        through to the parent's `pool()`.

        Args:
            hidden_states: encoder output, (B, S, D).

        Returns:
            The pooled (B, D) tensor.

        Raises:
            ValueError: for `pre_last_tx_mean`, when `num_features < 1` or
                the sequence has fewer than `2 * num_features` positions.
                Without this check the slice is either empty (the mean
                silently becomes NaN) or covers only part of the stripe.
        """
        if self.config.pool_strategy != "pre_last_tx_mean":
            return super().pool(hidden_states)

        nf = self.num_features
        seq_len = hidden_states.shape[1]
        if nf < 1 or seq_len < 2 * nf:
            raise ValueError(
                "pre_last_tx_mean pooling needs num_features >= 1 and at "
                f"least 2 * num_features = {2 * nf} positions; got "
                f"num_features={nf}, seq_len={seq_len}"
            )
        # The last nf positions are the final tx's stripe; the nf positions
        # before them are the previous tx's full stripe, which we average.
        # NOTE(review): assumes the sequence ends exactly on a tx boundary.
        stripe = hidden_states[:, -(2 * nf):-nf, :]
        return stripe.mean(dim=1)
|
|
|
|
class EncoderTiedEmbeddingHead(TiedEmbeddingHead):
    """TiedEmbeddingHead with `pre_last_tx_mean` pool strategy added.

    Why this exists for the encoder:

    The encoder's per-feature value tables (e.g. merchant_id at
    `vocab_size=10003, dim=d_lfm=1024`) are exactly the right
    projection matrix for a high-cardinality classifier head. Instead
    of learning a fresh `Linear(128, 10003)` from a 128-dim
    bottleneck — which gave only 7.78% top-1 — we share weights with
    the encoder's merchant_id table. This recovers the pattern parent
    observed: a tied embedding head on merchant_id gave 20.8% top-1
    vs 13.8% for the non-tied head.

    What this changes mechanically:

    - The classifier matrix `weight` is the encoder's value table
      (shape `(vocab_size, d_lfm)`).
    - Forward pass: pool hidden → adapter MLP (d_lfm → d_lfm → d_lfm)
      → matmul through the value table → logits over that feature's
      vocab.
    - Gradients flow back into the value table from BOTH the input-
      embedding side (during forward of every transaction) AND this
      head (during loss), keeping the table consistent.

    Only differs from parent's TiedEmbeddingHead in the supported pool
    strategies — adds `pre_last_tx_mean` for the encoder's 15-token-per-tx
    layout.

    Constraints inherited from parent:
    - `target_type` must start with "feature:N" (the value-table
      dimension this head is tied to).
    - `output_dim` is implicitly the vocab_size of that feature; the
      head config's `output_dim` is unused.
    """

    def __init__(
        self,
        config: HeadConfig,
        hidden_dim: int,
        num_features: int,
        value_tables: nn.ModuleList,
    ) -> None:
        # Same signature as the parent; `value_tables` is forwarded so the
        # parent can tie its classifier weight to one of these tables.
        super().__init__(config, hidden_dim, num_features, value_tables)

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """(B, S, D) → (B, D) via head-specific pooling.

        Adds `pre_last_tx_mean` to parent's strategies: the mean over the
        `num_features` positions immediately preceding the final
        `num_features` positions, i.e. `hidden_states[:, -2*nf:-nf, :]`.
        Every other strategy falls through to the parent's `pool()`.

        Args:
            hidden_states: encoder output, (B, S, D).

        Returns:
            The pooled (B, D) tensor.

        Raises:
            ValueError: for `pre_last_tx_mean`, when `num_features < 1` or
                the sequence has fewer than `2 * num_features` positions.
                Without this check the slice is either empty (the mean
                silently becomes NaN) or covers only part of the stripe.
        """
        if self.config.pool_strategy != "pre_last_tx_mean":
            return super().pool(hidden_states)

        nf = self.num_features
        seq_len = hidden_states.shape[1]
        if nf < 1 or seq_len < 2 * nf:
            raise ValueError(
                "pre_last_tx_mean pooling needs num_features >= 1 and at "
                f"least 2 * num_features = {2 * nf} positions; got "
                f"num_features={nf}, seq_len={seq_len}"
            )
        # Skip the final tx's nf positions and average the previous tx's
        # full stripe. NOTE(review): assumes the sequence ends exactly on a
        # tx boundary.
        stripe = hidden_states[:, -(2 * nf):-nf, :]
        return stripe.mean(dim=1)
|
|