"""Per-transaction encoder. For each transaction independently, looks up per-feature embeddings, concats them, and projects through a small MLP to a single `d_encoder`-dim vector. The encoder produces one continuous embedding per transaction — analogous to how LFM2.5-Audio's FastConformer produces one continuous frame per audio window. Cross-transaction reasoning is delegated entirely to the LFM2.5 backbone. Shape contract: (B, T_tx, F) int64 → (B, T_tx, d_encoder) float Design choices and their reasoning: - **Concat (not sum) over features.** Concat preserves feature identity in early layers; the MLP can collapse to a sum if it wants but cannot recover what sum already destroyed. ADR default; research confirmed. - **Per-feature embedding dim `d_feat=32`.** Concat width 15*32 = 480 keeps the first Linear modestly sized. Smaller `d_feat` (e.g. 16) makes the high-vocab features (merchant_id with 10003 values) under-parameterized; larger d_feat balloons the merchant_id table and dominates parameter count. 32 is the elbow. - **2-layer MLP with SiLU.** Same activation as LFM2's SwiGLU MLPs in the backbone — keeps the encoder feeling like a continuation of the LFM architecture rather than a foreign module. (The downstream projector uses GELU to match LFM2-VL's `Lfm2VlMultiModalProjector` exactly.) - **Embedding tables not initialized from any pretrained source.** The parent's structured-feature backbone has per-feature value tables, but those were trained on a different objective and at a different `d_hidden`. Reusing them would be a science experiment we haven't designed; starting from a fresh init keeps the comparison clean. """ from __future__ import annotations import torch import torch.nn as nn from src.data.schema import SchemaConfig class TransactionEncoder(nn.Module): """Per-feature embeddings → concat → MLP per-transaction. Args: schema: parent's SchemaConfig (defines `num_features` and per-feature `vocab_size`). d_feat: per-feature embedding dimension. Default 32 balances small-vocab and large-vocab features. d_encoder: output dimension. Default 256 matches the parent backbone's `d_hidden` so feature-level expressivity is comparable across the head-to-head. mlp_hidden: intermediate hidden size of the 2-layer encoder MLP. Forward: feature_ids: (B, T_tx, F) int64 returns: (B, T_tx, d_encoder) float (dtype matches embedding tables) """ def __init__( self, schema: SchemaConfig, d_feat: int = 32, d_encoder: int = 256, mlp_hidden: int = 384, enable_collections_markers: bool = False, enable_fraud_markers: bool = False, ) -> None: super().__init__() self.num_features = schema.num_features self.d_feat = d_feat self.d_encoder = d_encoder self.enable_collections_markers = enable_collections_markers self.enable_fraud_markers = enable_fraud_markers # One embedding table per feature, sized to that feature's full vocab # (includes MASK/OOV/NULL reserved tokens). We use padding_idx=None # so all tokens get learned embeddings — MASK is a real token in this # data (transaction generator never emits MASK as a real value, but # tokenization-time NULL handling can; we let the table learn it). self.feature_embeddings = nn.ModuleList( [nn.Embedding(f.vocab_size, d_feat) for f in schema.features], ) concat_dim = self.num_features * d_feat # MLP applied independently per transaction (nn.Linear broadcasts # over leading dims). No layer norm here — LayerNorm comes in the # downstream ProjectionAdapter to match LFM2-VL's projector shape. self.mlp = nn.Sequential( nn.Linear(concat_dim, mlp_hidden), nn.SiLU(), nn.Linear(mlp_hidden, d_encoder), ) # Learnable "disputed transaction" marker, added to the encoded # pseudo-token at the disputed position. Required for the # dispute-legitimacy task — the label, attribution, and LM # reasoning are all conditional on which transaction is being # disputed, and that signal needs to be visible to the model. # Initialized to zeros so day-1 behavior matches a model without # the marker; training learns to populate it. self.disputed_marker = nn.Parameter(torch.zeros(d_encoder)) # Cross-position signals made local at the disputed position. # The label rule depends on (a) how many times the disputed # merchant appears in the customer's history (subscription # pattern) and (b) whether the disputed country appears # elsewhere in history (exotic-country signal). Both are # cross-position properties that a 350M LFM2.5 with LoRA # struggles to learn on 4-5K examples (data-distribution- # doctrine §3 pathology 7: "no amount of rebalancing can teach # a model to see signals that aren't easily readable from the # input"). We make them readable by computing them in the # encoder forward and scaling learnable bias vectors by the # resulting scalar scores at the disputed position. self.subscription_marker = nn.Parameter(torch.zeros(d_encoder)) self.exotic_country_marker = nn.Parameter(torch.zeros(d_encoder)) # Feature column indices for the cross-position computations. # These match the schema in data/schema.yaml: # merchant_id: index 5 # country: index 10 # We keep them as buffers (not parameters) so they ship with # the checkpoint as integers. self.register_buffer( "_merchant_feature_idx", torch.tensor(5, dtype=torch.long), ) self.register_buffer( "_country_feature_idx", torch.tensor(10, dtype=torch.long), ) # Normalizer for subscription_score: divide raw merchant_match # count (after subtracting the disputed position itself) by 10. # A count of 10+ same-merchant transactions saturates the score # at 1.0; lower counts produce linear partial scores. The # rationale: the synthesizer's subscription threshold is >=5 # occurrences, so 4-5 of those plus the disputed produces a # score around 0.4-0.5, deeply-subscribed customers (10+) max # out at 1.0. self._subscription_norm = 10.0 # ----- Collections-specific markers (opt-in via flag) ----- # # Same lesson 2 / Ottoguard pattern as the dispute markers: the # collections label rule depends on cross-position signals # (recent velocity, subscription burden, merchant diversity, # large-amount count, spending volatility) that a 350M backbone # can't reliably learn from ~5K examples. Each marker is a # learnable d_encoder bias scaled by the computed signal at # the "position of interest" (context_idx, = 63 in collections v1). # # Markers are constructed unconditionally so the parameter # count is stable across surfaces; the forward path only adds # them when enable_collections_markers=True. For dispute, the # flag is False and the markers stay at their zero init — # behaviorally identical to the pre-collections encoder. self.velocity_marker = nn.Parameter(torch.zeros(d_encoder)) self.sub_burden_marker = nn.Parameter(torch.zeros(d_encoder)) self.merchant_diversity_marker = nn.Parameter(torch.zeros(d_encoder)) self.large_amount_marker = nn.Parameter(torch.zeros(d_encoder)) self.volatility_marker = nn.Parameter(torch.zeros(d_encoder)) # Feature column indices for the collections signals (must # match data/schema.yaml). self.register_buffer( "_days_since_last_idx", torch.tensor(2, dtype=torch.long), ) self.register_buffer( "_is_recurring_idx", torch.tensor(3, dtype=torch.long), ) self.register_buffer( "_amount_idx", torch.tensor(8, dtype=torch.long), ) # Reserved-offset for is_recurring=True (token 4 = recurring). # Amount-threshold for "large" (token >= 153 ≈ $150+). self._is_recurring_true_token = 4 self._amount_large_threshold = 153 # Velocity normalizer: mean days_since_last token in the last 16 # positions; p90 of population is ~11 (token-space). We divide # by 14 so very-active customers (token ~3) → ~0.2 and # genuinely-dormant customers (token ~14+) → ~1.0. self._velocity_norm = 14.0 # Sub-burden normalizer: count of recurring positions in 64-tx # history; p90 ~18. We divide by 20. self._sub_burden_norm = 20.0 # Merchant diversity normalizer: unique merchants in history; # p90 ~19, max 64. We divide by 20. self._diversity_norm = 20.0 # Large-amount normalizer: count of amount-token >= threshold; # p90 ~23. We divide by 25. self._large_amount_norm = 25.0 # Volatility normalizer: std of amount tokens; p90 ~70. We # divide by 80. self._volatility_norm = 80.0 # Recent window for velocity computation. self._recent_window = 16 # ----- Fraud-pattern markers (opt-in via flag) ----- # The fraud-pattern label rule depends on cross-position # properties at the flagged transaction: # probe_density: small-CNP cluster preceding flagged # post_attack_density: large-unfamiliar cluster around flagged # novel_device: flagged device_hash appears nowhere else # signature_clean: country/CVV/AVS/merchant all match customer's normal # recent_authorize_density: CNP-to-unfamiliar in last 16 (scam pattern) # Each gets a learnable d_encoder bias scaled at the flagged # position. Zero-initialized so day-1 behavior is identical to # a marker-less encoder; training learns to populate. self.probe_cluster_marker = nn.Parameter(torch.zeros(d_encoder)) self.post_attack_marker = nn.Parameter(torch.zeros(d_encoder)) self.novel_device_marker = nn.Parameter(torch.zeros(d_encoder)) self.signature_clean_marker = nn.Parameter(torch.zeros(d_encoder)) self.recent_authorize_marker = nn.Parameter(torch.zeros(d_encoder)) # Feature column indices for fraud signals (match data/schema.yaml). self.register_buffer( "_entry_mode_idx", torch.tensor(7, dtype=torch.long), ) self.register_buffer( "_avs_idx", torch.tensor(11, dtype=torch.long), ) self.register_buffer( "_cvv_idx", torch.tensor(12, dtype=torch.long), ) self.register_buffer( "_device_hash_idx", torch.tensor(13, dtype=torch.long), ) # Probe-density window: PROBE_WINDOW = 6 tx before flagged. self._probe_window = 6 # Post-attack window: 6 tx from flagged forward. self._post_attack_window = 6 # Token thresholds (mirror synthetic_fraud_pattern.py constants). self._entry_cnp_token = 4 # ENTRY_CNP = RESERVED_OFFSET + 1 self._amount_small_token = 11 # AMOUNT_SMALL_THRESH self._amount_large_token = 153 # AMOUNT_LARGE_THRESH self._cmc_unfamiliar_thresh = 4 # CMC_UNFAMILIAR + 1 self._cmc_familiar_thresh = 8 # CMC_FAMILIAR self._cvv_match_token = 3 self._avs_match_token = 3 # Normalizers for score → [0, 1]. self._probe_norm = 6.0 self._post_attack_norm = 6.0 self._recent_authorize_norm = 16.0 self._recent_authorize_window = 16 def forward( self, feature_ids: torch.Tensor, disputed_idx: torch.Tensor | None = None, ) -> torch.Tensor: """Encode each transaction into a pseudo-token. Args: feature_ids: (B, T_tx, F) int64 feature ids per transaction. disputed_idx: optional (B,) int64 — position of the disputed transaction per batch element. When provided, the learnable disputed marker is added to the encoded output at that position. None for tasks that have no notion of a disputed transaction. Returns: (B, T_tx, d_encoder) float — one pseudo-token per transaction. """ # feature_ids: (B, T_tx, F) int64 # We embed each feature column with its own table, then concat. # An alternative is to use a single nn.Embedding with a globally # offset vocab — that's a single kernel but loses per-feature # vocab boundaries and complicates init scale. Per-feature tables # is the LFM-pattern-aligned choice (structured-feature backbone # uses the same pattern). per_feat = [ self.feature_embeddings[f](feature_ids[:, :, f]) for f in range(self.num_features) ] # Each entry: (B, T_tx, d_feat) x = torch.cat(per_feat, dim=-1) # x: (B, T_tx, F * d_feat) encoded = self.mlp(x) # (B, T_tx, d_encoder) # Add the disputed marker at the disputed position, plus two # cross-position bias markers scaled by computed signals. if disputed_idx is not None: B = encoded.shape[0] batch_idx = torch.arange(B, device=encoded.device) # --- subscription_score: count of disputed-merchant matches # in the customer's history, normalized to ~[0, 1]. The # disputed position itself is subtracted so the score # reflects PRIOR usage, not the disputed transaction's own # merchant appearing once. m_idx = int(self._merchant_feature_idx.item()) disp_merchants = feature_ids[batch_idx, disputed_idx, m_idx] # (B,) merchant_matches = ( feature_ids[:, :, m_idx] == disp_merchants.unsqueeze(1) ).to(encoded.dtype) # (B, T_tx) sub_score = (merchant_matches.sum(dim=1) - 1.0) / self._subscription_norm sub_score = sub_score.clamp(min=0.0, max=1.0) # (B,) # --- exotic_country_score: 1.0 if the disputed country # appears nowhere else in the customer's history, else 0.0. c_idx = int(self._country_feature_idx.item()) disp_countries = feature_ids[batch_idx, disputed_idx, c_idx] country_matches = ( feature_ids[:, :, c_idx] == disp_countries.unsqueeze(1) ).to(encoded.dtype) # (B, T_tx) country_appearances = country_matches.sum(dim=1) # (B,) exotic_score = (country_appearances <= 1.0).to(encoded.dtype) # (B,) # Apply all three biases at the disputed position. Each bias # is a learnable d_encoder vector; for the scaled markers # the score scales the bias's magnitude per batch element. base = encoded[batch_idx, disputed_idx] # (B, d_encoder) base = base + self.disputed_marker base = base + self.subscription_marker * sub_score.unsqueeze(-1) base = base + self.exotic_country_marker * exotic_score.unsqueeze(-1) # ----- Collections-specific markers (opt-in) ----- if self.enable_collections_markers: # recent_velocity: mean days_since_last token over the # last RECENT_WINDOW positions. Higher = more dormant. dsl_idx = int(self._days_since_last_idx.item()) recent_dsl = feature_ids[ :, -self._recent_window:, dsl_idx, ].to(encoded.dtype) velocity_raw = recent_dsl.mean(dim=1) # (B,) velocity_score = (velocity_raw / self._velocity_norm).clamp( min=0.0, max=1.0, ) # subscription_burden: count of is_recurring=1 positions ir_idx = int(self._is_recurring_idx.item()) sub_burden_raw = ( feature_ids[:, :, ir_idx] == self._is_recurring_true_token ).to(encoded.dtype).sum(dim=1) # (B,) sub_burden_score = ( sub_burden_raw / self._sub_burden_norm ).clamp(min=0.0, max=1.0) # merchant_diversity: unique merchants per customer. We # approximate with a histogram via one-hot bincount. For # batches this is most simply computed by sorting + # counting transitions; here we use a softer proxy # (entropy across positions is expensive). Use the count # of positions whose merchant_id differs from the prior # position — a fast diversity proxy. Equals # (unique-1) for monotonically-changing sequences and # 0 for all-same merchant. merchants = feature_ids[:, :, m_idx] # (B, T_tx) changes = (merchants[:, 1:] != merchants[:, :-1]).to(encoded.dtype) diversity_raw = changes.sum(dim=1) # (B,) diversity_score = ( diversity_raw / self._diversity_norm ).clamp(min=0.0, max=1.0) # large_amount_count: count of amount tokens >= threshold amt_idx = int(self._amount_idx.item()) large_amt_raw = ( feature_ids[:, :, amt_idx] >= self._amount_large_threshold ).to(encoded.dtype).sum(dim=1) large_amt_score = ( large_amt_raw / self._large_amount_norm ).clamp(min=0.0, max=1.0) # spending_volatility: std of amount tokens. Computed # in encoded.dtype to keep bf16/fp32 dtype consistency. amounts = feature_ids[:, :, amt_idx].to(encoded.dtype) vol_raw = amounts.std(dim=1, unbiased=False) # (B,) volatility_score = ( vol_raw / self._volatility_norm ).clamp(min=0.0, max=1.0) # Apply at the position of interest. base = base + self.velocity_marker * velocity_score.unsqueeze(-1) base = base + self.sub_burden_marker * sub_burden_score.unsqueeze(-1) base = base + self.merchant_diversity_marker * diversity_score.unsqueeze(-1) base = base + self.large_amount_marker * large_amt_score.unsqueeze(-1) base = base + self.volatility_marker * volatility_score.unsqueeze(-1) # ----- Fraud-pattern markers (opt-in) ----- if self.enable_fraud_markers: em_idx = int(self._entry_mode_idx.item()) amt_idx = int(self._amount_idx.item()) dev_idx = int(self._device_hash_idx.item()) cvv_idx = int(self._cvv_idx.item()) avs_idx = int(self._avs_idx.item()) cmc_idx = int(self._merchant_feature_idx.item()) # merchant_id cmc_count_idx = 6 # FEATURE_CUSTOMER_MERCHANT_COUNT country_idx = int(self._country_feature_idx.item()) B = encoded.shape[0] # Generate a position grid for window masking. # positions: (T_tx,) → broadcast for windows around disputed_idx. T_tx = feature_ids.shape[1] positions = torch.arange(T_tx, device=encoded.device) # (B, T_tx) — True for positions in [disputed-PROBE_WINDOW, disputed) pre_mask = ( (positions.unsqueeze(0) >= (disputed_idx - self._probe_window).unsqueeze(1)) & (positions.unsqueeze(0) < disputed_idx.unsqueeze(1)) ) # (B, T_tx) — True for positions in [disputed, disputed+POST_WINDOW) post_mask = ( (positions.unsqueeze(0) >= disputed_idx.unsqueeze(1)) & (positions.unsqueeze(0) < (disputed_idx + self._post_attack_window).unsqueeze(1)) ) # (B, T_tx) — True for positions in [disputed-RECENT_AUTHORIZE_WINDOW, disputed+1) recent_mask = ( (positions.unsqueeze(0) >= (disputed_idx - self._recent_authorize_window).unsqueeze(1)) & (positions.unsqueeze(0) <= disputed_idx.unsqueeze(1)) ) # --- probe_density: small-CNP count in pre_mask is_cnp = (feature_ids[:, :, em_idx] == self._entry_cnp_token).to(encoded.dtype) is_small = (feature_ids[:, :, amt_idx] <= self._amount_small_token).to(encoded.dtype) probe_raw = (is_cnp * is_small * pre_mask.to(encoded.dtype)).sum(dim=1) probe_score = (probe_raw / self._probe_norm).clamp(min=0.0, max=1.0) # --- post_attack_density: large-unfamiliar count in post_mask is_large = (feature_ids[:, :, amt_idx] >= self._amount_large_token).to(encoded.dtype) is_unfamiliar = ( feature_ids[:, :, cmc_count_idx] <= self._cmc_unfamiliar_thresh ).to(encoded.dtype) post_raw = (is_large * is_unfamiliar * post_mask.to(encoded.dtype)).sum(dim=1) post_score = (post_raw / self._post_attack_norm).clamp(min=0.0, max=1.0) # --- novel_device: device at flagged appears nowhere else flagged_device = feature_ids[batch_idx, disputed_idx, dev_idx] # (B,) device_matches = ( feature_ids[:, :, dev_idx] == flagged_device.unsqueeze(1) ).to(encoded.dtype).sum(dim=1) # (B,) novel_device_score = (device_matches <= 1.0).to(encoded.dtype) # --- signature_clean: country = mode + CVV match + AVS match + familiar merchant countries = feature_ids[:, :, country_idx] # (B, T_tx) # Mode country per row via torch.mode (returns values + indices). mode_country = countries.mode(dim=1).values # (B,) flagged_country = feature_ids[batch_idx, disputed_idx, country_idx] country_match = (flagged_country == mode_country).to(encoded.dtype) cvv_match = ( feature_ids[batch_idx, disputed_idx, cvv_idx] == self._cvv_match_token ).to(encoded.dtype) avs_match = ( feature_ids[batch_idx, disputed_idx, avs_idx] == self._avs_match_token ).to(encoded.dtype) merchant_familiar = ( feature_ids[batch_idx, disputed_idx, cmc_count_idx] >= self._cmc_familiar_thresh ).to(encoded.dtype) sig_clean_score = country_match * cvv_match * avs_match * merchant_familiar # --- recent_authorize_density: CNP × unfamiliar in recent_mask recent_auth_raw = ( is_cnp * is_unfamiliar * recent_mask.to(encoded.dtype) ).sum(dim=1) recent_auth_score = ( recent_auth_raw / self._recent_authorize_norm ).clamp(min=0.0, max=1.0) base = base + self.probe_cluster_marker * probe_score.unsqueeze(-1) base = base + self.post_attack_marker * post_score.unsqueeze(-1) base = base + self.novel_device_marker * novel_device_score.unsqueeze(-1) base = base + self.signature_clean_marker * sig_clean_score.unsqueeze(-1) base = base + self.recent_authorize_marker * recent_auth_score.unsqueeze(-1) encoded[batch_idx, disputed_idx] = base return encoded # → (B, T_tx, d_encoder) def num_embedding_params(self) -> int: """Total params across feature embedding tables (sanity check).""" return sum(e.weight.numel() for e in self.feature_embeddings)