"""Encoder-specific downstream head subclass with multi-position pooling.

Extends the parent's `DownstreamHead` with one new pool strategy:

    "pre_last_tx_mean": pool the entire stripe of the second-to-last
                        transaction (the `num_features` positions
                        immediately before the last transaction's stripe),
                        then take the mean across positions.

Why this exists:
    The encoder's non-compress mode emits 15 pseudo-tokens per transaction.
    The parent's `pre_last_tx` strategy reads a SINGLE position,
    `hidden[:, -(num_features+1), :]`. For our sequence length T=960 with
    num_features=15 (64 transactions, indexed 0..63), that is position 944,
    the last feature of tx 62 (`customer_tenure` per schema ordering).

    That single position has:
      - Full causal context through tx 62 (good: attention has propagated
        everything relevant).
      - Direct semantics around `customer_tenure` (its value table and
        type-embedding offset).

    For predicting tx 63's mcc / merchant / amount, the
    `customer_tenure`-anchored representation is suboptimal. The first
    nocompress run showed exactly this: fraud ROC-AUC 0.96 (great), but
    auxiliary head top-1 noticeably below the compress run's holistic-tx
    summary. The information is in the stripe; we're just reading it
    through the wrong feature.

    `pre_last_tx_mean` pools across all 15 positions of tx 62 (positions
    930..944), giving a holistic tx-62 summary that is much closer to what
    the compress run's single per-tx pseudo-token represented.

When `pre_last_tx` collapses to the same thing:
    For `num_features=1` (compress mode), `pre_last_tx` reads
    `hidden[:, -2, :]` and `pre_last_tx_mean` reads
    `mean(hidden[:, -2:-1, :])`, which is the same single position. So
    compress mode can use either strategy without behavior change; we still
    default to `pre_last_tx` there for clarity.
"""

from __future__ import annotations

import torch
import torch.nn as nn

from src.model.task_heads import DownstreamHead, HeadConfig, TiedEmbeddingHead


class EncoderDownstreamHead(DownstreamHead):
    """DownstreamHead with `pre_last_tx_mean` pool strategy added.

    Forward, extract_targets, and compute_loss are all inherited unchanged;
    only `pool()` is overridden to handle the new strategy.
    """

    def __init__(
        self,
        config: HeadConfig,
        hidden_dim: int,
        num_features: int,
    ) -> None:
        super().__init__(config, hidden_dim, num_features)

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """(B, S, D) → (B, D) via head-specific pooling.

        Adds `pre_last_tx_mean` to the parent's strategies. Falls through to
        the parent for `last_tx_mean` and `pre_last_tx`.
        """
        nf = self.num_features
        if self.config.pool_strategy == "pre_last_tx_mean":
            # Second-to-last transaction's stripe = positions [-2*nf, -nf).
            # For T=960, nf=15: positions 930..944 inclusive (the 15
            # features of tx 62).
            # For T=64, nf=1: position 62 only (single-element mean).
            stripe = hidden_states[:, -(2 * nf):-nf, :]
            return stripe.mean(dim=1)
        return super().pool(hidden_states)


class EncoderTiedEmbeddingHead(TiedEmbeddingHead):
    """TiedEmbeddingHead with `pre_last_tx_mean` pool strategy added.

    Why this exists for the encoder:
        The encoder's per-feature value tables (e.g. merchant_id at
        `vocab_size=10003, dim=d_lfm=1024`) are exactly the right
        projection matrix for a high-cardinality classifier head. Instead
        of learning a fresh `Linear(128, 10003)` from a 128-dim
        bottleneck, which gave only 7.78% top-1, we share weights with
        the encoder's merchant_id table. This recovers the pattern the
        parent observed: a tied embedding head on merchant_id gave 20.8%
        top-1 vs 13.8% for the non-tied head.

    What this changes mechanically:
        - The classifier matrix `weight` is the encoder's value table
          (shape `(vocab_size, d_lfm)`).
        - Forward pass: pool hidden → adapter MLP (d_lfm → d_lfm → d_lfm)
          → matmul through the value table → logits over that feature's
          vocab.
        - Gradients flow back into the value table from BOTH the
          input-embedding side (during forward of every transaction) AND
          this head (during loss), keeping the table consistent.

    Only differs from the parent's TiedEmbeddingHead in the supported pool
    strategies: it adds `pre_last_tx_mean` for the encoder's
    15-token-per-tx layout.

    Constraints inherited from parent:
        - `target_type` must start with "feature:N" (the value-table
          dimension this head is tied to).
        - `output_dim` is implicitly the vocab_size of that feature; the
          head config's `output_dim` is unused.
    """

    def __init__(
        self,
        config: HeadConfig,
        hidden_dim: int,
        num_features: int,
        value_tables: nn.ModuleList,
    ) -> None:
        super().__init__(config, hidden_dim, num_features, value_tables)

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """(B, S, D) → (B, D); same stripe pooling as EncoderDownstreamHead."""
        nf = self.num_features
        if self.config.pool_strategy == "pre_last_tx_mean":
            # Mean over the second-to-last transaction's stripe,
            # positions [-2*nf, -nf).
            stripe = hidden_states[:, -(2 * nf):-nf, :]
            return stripe.mean(dim=1)
        return super().pool(hidden_states)
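

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module. The helper name
# `_pre_last_tx_stripe_positions` is hypothetical. It only restates the slice
# used by `pool()` above, `hidden_states[:, -(2 * nf):-nf, :]`, as absolute
# positions, so the numbers quoted in the comments (930..944 for T=960,
# nf=15; position 62 for T=64, nf=1) can be checked without building a model.
# It assumes the sequence holds at least two whole transactions.
# ---------------------------------------------------------------------------
def _pre_last_tx_stripe_positions(seq_len: int, num_features: int) -> range:
    """Absolute positions covered by the `pre_last_tx_mean` stripe."""
    if num_features < 1 or seq_len < 2 * num_features:
        raise ValueError("sequence must hold at least two transactions")
    # Negative slice bounds [-2*nf, -nf) map to [S - 2*nf, S - nf).
    return range(seq_len - 2 * num_features, seq_len - num_features)


# Usage (under the same assumptions):
#   list(_pre_last_tx_stripe_positions(64, 1))    # -> [62]
#   _pre_last_tx_stripe_positions(960, 15)        # -> range(930, 945)
# The single position read by `pre_last_tx`, S - (nf + 1), is the last
# element of this range in both layouts.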