# lfm2-transaction-encoder / encoder/src/model/transaction_encoder.py
# cdotsanghvi · "initial transaction co-pilot deployment" · commit b3112c7 · 24.3 kB
"""Per-transaction encoder.
For each transaction independently, looks up per-feature embeddings, concats
them, and projects through a small MLP to a single `d_encoder`-dim vector.
The encoder produces one continuous embedding per transaction — analogous to
how LFM2.5-Audio's FastConformer produces one continuous frame per audio
window. Cross-transaction reasoning is delegated entirely to the LFM2.5
backbone.
Shape contract:
(B, T_tx, F) int64 → (B, T_tx, d_encoder) float
Design choices and their reasoning:
- **Concat (not sum) over features.** Concat preserves feature identity in
early layers; the MLP can collapse to a sum if it wants but cannot recover
what sum already destroyed. ADR default; research confirmed.
- **Per-feature embedding dim `d_feat=32`.** Concat width 15*32 = 480 keeps
the first Linear modestly sized. Smaller `d_feat` (e.g. 16) makes the
high-vocab features (merchant_id with 10003 values) under-parameterized;
larger d_feat balloons the merchant_id table and dominates parameter count.
32 is the elbow.
- **2-layer MLP with SiLU.** Same activation as LFM2's SwiGLU MLPs in the
backbone — keeps the encoder feeling like a continuation of the LFM
architecture rather than a foreign module. (The downstream projector uses
GELU to match LFM2-VL's `Lfm2VlMultiModalProjector` exactly.)
- **Embedding tables not initialized from any pretrained source.** The
parent's structured-feature backbone has per-feature value tables, but
those were trained on a different objective and at a different `d_hidden`.
Reusing them would be a science experiment we haven't designed; starting
from a fresh init keeps the comparison clean.
"""
from __future__ import annotations
import torch
import torch.nn as nn
from src.data.schema import SchemaConfig
class TransactionEncoder(nn.Module):
    """Per-feature embeddings → concat → MLP per-transaction.

    Optionally adds learnable, zero-initialized "marker" bias vectors to the
    encoded pseudo-token at one position of interest per batch element
    (``disputed_idx``). Most markers are scaled by a score in [0, 1] that is
    computed from the raw feature ids across the whole sequence.

    Args:
        schema: parent's SchemaConfig (defines `num_features` and per-feature
            `vocab_size`).
        d_feat: per-feature embedding dimension. Default 32 balances
            small-vocab and large-vocab features.
        d_encoder: output dimension. Default 256 matches the parent
            backbone's `d_hidden` so feature-level expressivity is
            comparable across the head-to-head.
        mlp_hidden: intermediate hidden size of the 2-layer encoder MLP.
        enable_collections_markers: when True, `forward` also adds the five
            collections markers (velocity, subscription burden, merchant
            diversity, large-amount count, volatility) at `disputed_idx`.
        enable_fraud_markers: when True, `forward` also adds the five
            fraud-pattern markers (probe cluster, post-attack, novel device,
            signature clean, recent authorize) at `disputed_idx`.

    Forward:
        feature_ids: (B, T_tx, F) int64
        disputed_idx: optional (B,) int64 position of interest
        returns: (B, T_tx, d_encoder) float (dtype matches embedding tables)
    """

    # Column of the customer-merchant-count feature
    # (FEATURE_CUSTOMER_MERCHANT_COUNT). A class-level constant rather than a
    # registered buffer so the set of state_dict keys stays unchanged.
    _CUSTOMER_MERCHANT_COUNT_IDX = 6

    def __init__(
        self,
        schema: SchemaConfig,
        d_feat: int = 32,
        d_encoder: int = 256,
        mlp_hidden: int = 384,
        enable_collections_markers: bool = False,
        enable_fraud_markers: bool = False,
    ) -> None:
        super().__init__()
        self.num_features = schema.num_features
        self.d_feat = d_feat
        self.d_encoder = d_encoder
        self.enable_collections_markers = enable_collections_markers
        self.enable_fraud_markers = enable_fraud_markers

        # One embedding table per feature, sized to that feature's full vocab
        # (includes MASK/OOV/NULL reserved tokens). We use padding_idx=None
        # so all tokens get learned embeddings — MASK is a real token in this
        # data (transaction generator never emits MASK as a real value, but
        # tokenization-time NULL handling can; we let the table learn it).
        self.feature_embeddings = nn.ModuleList(
            [nn.Embedding(f.vocab_size, d_feat) for f in schema.features],
        )
        concat_dim = self.num_features * d_feat

        # MLP applied independently per transaction (nn.Linear broadcasts
        # over leading dims). No layer norm here — LayerNorm comes in the
        # downstream ProjectionAdapter to match LFM2-VL's projector shape.
        self.mlp = nn.Sequential(
            nn.Linear(concat_dim, mlp_hidden),
            nn.SiLU(),
            nn.Linear(mlp_hidden, d_encoder),
        )

        # ----- Dispute markers (applied whenever disputed_idx is given) -----
        # Learnable "disputed transaction" marker, added to the encoded
        # pseudo-token at the disputed position. Required for the
        # dispute-legitimacy task — the label, attribution, and LM
        # reasoning are all conditional on which transaction is being
        # disputed, and that signal needs to be visible to the model.
        # Initialized to zeros so day-1 behavior matches a model without
        # the marker; training learns to populate it.
        self.disputed_marker = nn.Parameter(torch.zeros(d_encoder))

        # Cross-position signals made local at the disputed position.
        # The label rule depends on (a) how many times the disputed
        # merchant appears in the customer's history (subscription
        # pattern) and (b) whether the disputed country appears
        # elsewhere in history (exotic-country signal). Both are
        # cross-position properties that a 350M LFM2.5 with LoRA
        # struggles to learn on 4-5K examples (data-distribution-
        # doctrine §3 pathology 7: "no amount of rebalancing can teach
        # a model to see signals that aren't easily readable from the
        # input"). We make them readable by computing them in the
        # encoder forward and scaling learnable bias vectors by the
        # resulting scalar scores at the disputed position.
        self.subscription_marker = nn.Parameter(torch.zeros(d_encoder))
        self.exotic_country_marker = nn.Parameter(torch.zeros(d_encoder))

        # Feature column indices for the cross-position computations.
        # These match the schema in data/schema.yaml:
        #   merchant_id: index 5
        #   country:     index 10
        # We keep them as buffers (not parameters) so they ship with
        # the checkpoint as integers.
        self.register_buffer(
            "_merchant_feature_idx", torch.tensor(5, dtype=torch.long),
        )
        self.register_buffer(
            "_country_feature_idx", torch.tensor(10, dtype=torch.long),
        )

        # Normalizer for subscription_score: divide raw merchant_match
        # count (after subtracting the disputed position itself) by 10.
        # A count of 10+ same-merchant transactions saturates the score
        # at 1.0; lower counts produce linear partial scores. The
        # rationale: the synthesizer's subscription threshold is >=5
        # occurrences, so 4-5 of those plus the disputed produces a
        # score around 0.4-0.5, deeply-subscribed customers (10+) max
        # out at 1.0.
        self._subscription_norm = 10.0

        # ----- Collections-specific markers (opt-in via flag) -----
        #
        # Same lesson 2 / Ottoguard pattern as the dispute markers: the
        # collections label rule depends on cross-position signals
        # (recent velocity, subscription burden, merchant diversity,
        # large-amount count, spending volatility) that a 350M backbone
        # can't reliably learn from ~5K examples. Each marker is a
        # learnable d_encoder bias scaled by the computed signal at
        # the "position of interest" (context_idx, = 63 in collections v1).
        #
        # Markers are constructed unconditionally so the parameter
        # count is stable across surfaces; the forward path only adds
        # them when enable_collections_markers=True. For dispute, the
        # flag is False and the markers stay at their zero init —
        # behaviorally identical to the pre-collections encoder.
        self.velocity_marker = nn.Parameter(torch.zeros(d_encoder))
        self.sub_burden_marker = nn.Parameter(torch.zeros(d_encoder))
        self.merchant_diversity_marker = nn.Parameter(torch.zeros(d_encoder))
        self.large_amount_marker = nn.Parameter(torch.zeros(d_encoder))
        self.volatility_marker = nn.Parameter(torch.zeros(d_encoder))

        # Feature column indices for the collections signals (must
        # match data/schema.yaml).
        self.register_buffer(
            "_days_since_last_idx", torch.tensor(2, dtype=torch.long),
        )
        self.register_buffer(
            "_is_recurring_idx", torch.tensor(3, dtype=torch.long),
        )
        self.register_buffer(
            "_amount_idx", torch.tensor(8, dtype=torch.long),
        )

        # Reserved-offset for is_recurring=True (token 4 = recurring).
        # Amount-threshold for "large" (token >= 153 ≈ $150+).
        self._is_recurring_true_token = 4
        self._amount_large_threshold = 153

        # Velocity normalizer: mean days_since_last token in the last 16
        # positions; p90 of population is ~11 (token-space). We divide
        # by 14 so very-active customers (token ~3) → ~0.2 and
        # genuinely-dormant customers (token ~14+) → ~1.0.
        self._velocity_norm = 14.0
        # Sub-burden normalizer: count of recurring positions in 64-tx
        # history; p90 ~18. We divide by 20.
        self._sub_burden_norm = 20.0
        # Merchant diversity normalizer: unique merchants in history;
        # p90 ~19, max 64. We divide by 20.
        self._diversity_norm = 20.0
        # Large-amount normalizer: count of amount-token >= threshold;
        # p90 ~23. We divide by 25.
        self._large_amount_norm = 25.0
        # Volatility normalizer: std of amount tokens; p90 ~70. We
        # divide by 80.
        self._volatility_norm = 80.0
        # Recent window for velocity computation.
        self._recent_window = 16

        # ----- Fraud-pattern markers (opt-in via flag) -----
        # The fraud-pattern label rule depends on cross-position
        # properties at the flagged transaction:
        #   probe_density: small-CNP cluster preceding flagged
        #   post_attack_density: large-unfamiliar cluster around flagged
        #   novel_device: flagged device_hash appears nowhere else
        #   signature_clean: country/CVV/AVS/merchant all match customer's normal
        #   recent_authorize_density: CNP-to-unfamiliar in last 16 (scam pattern)
        # Each gets a learnable d_encoder bias scaled at the flagged
        # position. Zero-initialized so day-1 behavior is identical to
        # a marker-less encoder; training learns to populate.
        self.probe_cluster_marker = nn.Parameter(torch.zeros(d_encoder))
        self.post_attack_marker = nn.Parameter(torch.zeros(d_encoder))
        self.novel_device_marker = nn.Parameter(torch.zeros(d_encoder))
        self.signature_clean_marker = nn.Parameter(torch.zeros(d_encoder))
        self.recent_authorize_marker = nn.Parameter(torch.zeros(d_encoder))

        # Feature column indices for fraud signals (match data/schema.yaml).
        self.register_buffer(
            "_entry_mode_idx", torch.tensor(7, dtype=torch.long),
        )
        self.register_buffer(
            "_avs_idx", torch.tensor(11, dtype=torch.long),
        )
        self.register_buffer(
            "_cvv_idx", torch.tensor(12, dtype=torch.long),
        )
        self.register_buffer(
            "_device_hash_idx", torch.tensor(13, dtype=torch.long),
        )

        # Probe-density window: PROBE_WINDOW = 6 tx before flagged.
        self._probe_window = 6
        # Post-attack window: 6 tx from flagged forward.
        self._post_attack_window = 6
        # Token thresholds (mirror synthetic_fraud_pattern.py constants).
        self._entry_cnp_token = 4  # ENTRY_CNP = RESERVED_OFFSET + 1
        self._amount_small_token = 11  # AMOUNT_SMALL_THRESH
        self._amount_large_token = 153  # AMOUNT_LARGE_THRESH
        self._cmc_unfamiliar_thresh = 4  # CMC_UNFAMILIAR + 1
        self._cmc_familiar_thresh = 8  # CMC_FAMILIAR
        self._cvv_match_token = 3
        self._avs_match_token = 3
        # Normalizers for score → [0, 1].
        self._probe_norm = 6.0
        self._post_attack_norm = 6.0
        self._recent_authorize_norm = 16.0
        self._recent_authorize_window = 16

    def forward(
        self,
        feature_ids: torch.Tensor,
        disputed_idx: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode each transaction into a pseudo-token.

        Args:
            feature_ids: (B, T_tx, F) int64 feature ids per transaction.
            disputed_idx: optional (B,) int64 — position of interest per
                batch element (the disputed / flagged transaction). When
                provided, the disputed, subscription and exotic-country
                markers are added to the encoded output at that position,
                followed by the collections and fraud markers if the
                corresponding `enable_*` flag is set. When None, no marker
                of any kind is applied, regardless of the flags.

        Returns:
            (B, T_tx, d_encoder) float — one pseudo-token per transaction.
        """
        # We embed each feature column with its own table, then concat.
        # An alternative is to use a single nn.Embedding with a globally
        # offset vocab — that's a single kernel but loses per-feature
        # vocab boundaries and complicates init scale. Per-feature tables
        # is the LFM-pattern-aligned choice (structured-feature backbone
        # uses the same pattern).
        per_feat = [
            self.feature_embeddings[f](feature_ids[:, :, f])
            for f in range(self.num_features)
        ]  # each entry: (B, T_tx, d_feat)
        encoded = self.mlp(torch.cat(per_feat, dim=-1))  # (B, T_tx, d_encoder)

        if disputed_idx is None:
            return encoded

        # Scores are always computed in the MLP output dtype. `base` itself
        # may be promoted once a parameter of another dtype is added, so the
        # dtype is captured here instead of being read back from `base`.
        score_dtype = encoded.dtype
        batch_idx = torch.arange(encoded.shape[0], device=encoded.device)

        # Advanced indexing returns a copy; it is written back below.
        base = encoded[batch_idx, disputed_idx]  # (B, d_encoder)
        base = self._add_dispute_markers(
            base, feature_ids, batch_idx, disputed_idx, score_dtype,
        )
        if self.enable_collections_markers:
            base = self._add_collections_markers(base, feature_ids, score_dtype)
        if self.enable_fraud_markers:
            base = self._add_fraud_markers(
                base, feature_ids, batch_idx, disputed_idx, score_dtype,
            )

        encoded[batch_idx, disputed_idx] = base
        return encoded  # (B, T_tx, d_encoder)

    @staticmethod
    def _column(index_buffer: torch.Tensor) -> int:
        """Return a registered feature-column buffer as a plain Python int."""
        return int(index_buffer.item())

    def _add_dispute_markers(
        self,
        base: torch.Tensor,
        feature_ids: torch.Tensor,
        batch_idx: torch.Tensor,
        disputed_idx: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Add the disputed, subscription and exotic-country markers to `base`."""
        # subscription_score: count of disputed-merchant matches in the
        # customer's history, normalized to [0, 1]. The disputed position
        # itself is subtracted so the score reflects PRIOR usage, not the
        # disputed transaction's own merchant appearing once.
        m_idx = self._column(self._merchant_feature_idx)
        disp_merchants = feature_ids[batch_idx, disputed_idx, m_idx]  # (B,)
        merchant_matches = (
            feature_ids[:, :, m_idx] == disp_merchants.unsqueeze(1)
        ).to(dtype)  # (B, T_tx)
        sub_score = (merchant_matches.sum(dim=1) - 1.0) / self._subscription_norm
        sub_score = sub_score.clamp(min=0.0, max=1.0)  # (B,)

        # exotic_country_score: 1.0 if the disputed country appears nowhere
        # else in the customer's history, else 0.0.
        c_idx = self._column(self._country_feature_idx)
        disp_countries = feature_ids[batch_idx, disputed_idx, c_idx]  # (B,)
        country_appearances = (
            feature_ids[:, :, c_idx] == disp_countries.unsqueeze(1)
        ).to(dtype).sum(dim=1)  # (B,)
        exotic_score = (country_appearances <= 1.0).to(dtype)  # (B,)

        # Each bias is a learnable d_encoder vector; for the scaled markers
        # the score scales the bias's magnitude per batch element.
        base = base + self.disputed_marker
        base = base + self.subscription_marker * sub_score.unsqueeze(-1)
        base = base + self.exotic_country_marker * exotic_score.unsqueeze(-1)
        return base

    def _add_collections_markers(
        self,
        base: torch.Tensor,
        feature_ids: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Add the five collections markers, each scaled by its [0, 1] score."""
        # recent_velocity: mean days_since_last token over the last
        # `_recent_window` positions of the sequence. Higher = more dormant.
        dsl_idx = self._column(self._days_since_last_idx)
        recent_dsl = feature_ids[:, -self._recent_window:, dsl_idx].to(dtype)
        velocity_score = (
            recent_dsl.mean(dim=1) / self._velocity_norm
        ).clamp(min=0.0, max=1.0)  # (B,)

        # subscription_burden: count of positions flagged as recurring.
        ir_idx = self._column(self._is_recurring_idx)
        sub_burden_raw = (
            feature_ids[:, :, ir_idx] == self._is_recurring_true_token
        ).to(dtype).sum(dim=1)  # (B,)
        sub_burden_score = (
            sub_burden_raw / self._sub_burden_norm
        ).clamp(min=0.0, max=1.0)

        # merchant_diversity: a cheap proxy for the number of unique
        # merchants — the count of positions whose merchant_id differs from
        # the previous position. It equals (unique - 1) when no merchant
        # reappears after the sequence has moved on from it, and 0 when
        # every position has the same merchant.
        m_idx = self._column(self._merchant_feature_idx)
        merchants = feature_ids[:, :, m_idx]  # (B, T_tx)
        changes = (merchants[:, 1:] != merchants[:, :-1]).to(dtype)
        diversity_score = (
            changes.sum(dim=1) / self._diversity_norm
        ).clamp(min=0.0, max=1.0)

        # large_amount_count: count of amount tokens >= threshold.
        amt_idx = self._column(self._amount_idx)
        amount_tokens = feature_ids[:, :, amt_idx]  # (B, T_tx)
        large_amt_raw = (
            amount_tokens >= self._amount_large_threshold
        ).to(dtype).sum(dim=1)
        large_amt_score = (
            large_amt_raw / self._large_amount_norm
        ).clamp(min=0.0, max=1.0)

        # spending_volatility: population std of amount tokens, computed in
        # the encoder dtype to keep bf16/fp32 dtype consistency.
        vol_raw = amount_tokens.to(dtype).std(dim=1, unbiased=False)  # (B,)
        volatility_score = (
            vol_raw / self._volatility_norm
        ).clamp(min=0.0, max=1.0)

        base = base + self.velocity_marker * velocity_score.unsqueeze(-1)
        base = base + self.sub_burden_marker * sub_burden_score.unsqueeze(-1)
        base = base + self.merchant_diversity_marker * diversity_score.unsqueeze(-1)
        base = base + self.large_amount_marker * large_amt_score.unsqueeze(-1)
        base = base + self.volatility_marker * volatility_score.unsqueeze(-1)
        return base

    def _add_fraud_markers(
        self,
        base: torch.Tensor,
        feature_ids: torch.Tensor,
        batch_idx: torch.Tensor,
        disputed_idx: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Add the five fraud-pattern markers, each scaled by its [0, 1] score."""
        em_idx = self._column(self._entry_mode_idx)
        amt_idx = self._column(self._amount_idx)
        dev_idx = self._column(self._device_hash_idx)
        cvv_idx = self._column(self._cvv_idx)
        avs_idx = self._column(self._avs_idx)
        country_idx = self._column(self._country_feature_idx)
        cmc_count_idx = self._CUSTOMER_MERCHANT_COUNT_IDX

        # Position grid for window masking around the flagged position.
        # Windows that run past either end of the sequence are simply
        # truncated, because only positions in [0, T_tx) exist in the grid.
        positions = torch.arange(
            feature_ids.shape[1], device=base.device,
        ).unsqueeze(0)  # (1, T_tx)
        flagged = disputed_idx.unsqueeze(1)  # (B, 1)

        # (B, T_tx) — True for positions in [flagged - PROBE_WINDOW, flagged)
        pre_mask = (
            (positions >= flagged - self._probe_window) & (positions < flagged)
        ).to(dtype)
        # (B, T_tx) — True for positions in [flagged, flagged + POST_WINDOW)
        post_mask = (
            (positions >= flagged)
            & (positions < flagged + self._post_attack_window)
        ).to(dtype)
        # (B, T_tx) — True for positions in
        # [flagged - RECENT_AUTHORIZE_WINDOW, flagged], flagged included.
        recent_mask = (
            (positions >= flagged - self._recent_authorize_window)
            & (positions <= flagged)
        ).to(dtype)

        # Per-position indicator features, all (B, T_tx).
        amount_tokens = feature_ids[:, :, amt_idx]
        is_cnp = (feature_ids[:, :, em_idx] == self._entry_cnp_token).to(dtype)
        is_small = (amount_tokens <= self._amount_small_token).to(dtype)
        is_large = (amount_tokens >= self._amount_large_token).to(dtype)
        is_unfamiliar = (
            feature_ids[:, :, cmc_count_idx] <= self._cmc_unfamiliar_thresh
        ).to(dtype)

        # probe_density: small-CNP count in the pre window.
        probe_raw = (is_cnp * is_small * pre_mask).sum(dim=1)
        probe_score = (probe_raw / self._probe_norm).clamp(min=0.0, max=1.0)

        # post_attack_density: large-unfamiliar count in the post window.
        post_raw = (is_large * is_unfamiliar * post_mask).sum(dim=1)
        post_score = (post_raw / self._post_attack_norm).clamp(min=0.0, max=1.0)

        # novel_device: device at the flagged position appears nowhere else
        # (the only match is the flagged position itself).
        flagged_device = feature_ids[batch_idx, disputed_idx, dev_idx]  # (B,)
        device_matches = (
            feature_ids[:, :, dev_idx] == flagged_device.unsqueeze(1)
        ).to(dtype).sum(dim=1)  # (B,)
        novel_device_score = (device_matches <= 1.0).to(dtype)

        # signature_clean: flagged country equals the row's most frequent
        # country AND CVV matches AND AVS matches AND merchant is familiar.
        mode_country = feature_ids[:, :, country_idx].mode(dim=1).values  # (B,)
        flagged_country = feature_ids[batch_idx, disputed_idx, country_idx]
        country_match = (flagged_country == mode_country).to(dtype)
        cvv_match = (
            feature_ids[batch_idx, disputed_idx, cvv_idx] == self._cvv_match_token
        ).to(dtype)
        avs_match = (
            feature_ids[batch_idx, disputed_idx, avs_idx] == self._avs_match_token
        ).to(dtype)
        merchant_familiar = (
            feature_ids[batch_idx, disputed_idx, cmc_count_idx]
            >= self._cmc_familiar_thresh
        ).to(dtype)
        sig_clean_score = country_match * cvv_match * avs_match * merchant_familiar

        # recent_authorize_density: CNP × unfamiliar in the recent window.
        recent_auth_raw = (is_cnp * is_unfamiliar * recent_mask).sum(dim=1)
        recent_auth_score = (
            recent_auth_raw / self._recent_authorize_norm
        ).clamp(min=0.0, max=1.0)

        base = base + self.probe_cluster_marker * probe_score.unsqueeze(-1)
        base = base + self.post_attack_marker * post_score.unsqueeze(-1)
        base = base + self.novel_device_marker * novel_device_score.unsqueeze(-1)
        base = base + self.signature_clean_marker * sig_clean_score.unsqueeze(-1)
        base = base + self.recent_authorize_marker * recent_auth_score.unsqueeze(-1)
        return base

    def num_embedding_params(self) -> int:
        """Total params across feature embedding tables (sanity check)."""
        return sum(e.weight.numel() for e in self.feature_embeddings)