| """Per-transaction encoder. |
| |
| For each transaction independently, looks up per-feature embeddings, concats |
| them, and projects through a small MLP to a single `d_encoder`-dim vector. |
| The encoder produces one continuous embedding per transaction — analogous to |
| how LFM2.5-Audio's FastConformer produces one continuous frame per audio |
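| window. Cross-transaction reasoning is delegated to the LFM2.5 backbone, |
| with one exception: when `forward` is given `disputed_idx`, learnable |
| marker vectors are added to the disputed transaction's embedding, each |
| scaled by a hand-computed score over other transactions in the sequence |
| (merchant/country repeats, plus optional collections and fraud signals). |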
| |
| Shape contract: |
| (B, T_tx, F) int64 → (B, T_tx, d_encoder) float |
| |
| Design choices and their reasoning: |
| |
| - **Concat (not sum) over features.** Concat preserves feature identity in |
| early layers; the MLP can collapse to a sum if it wants but cannot recover |
| what sum already destroyed. ADR default; research confirmed. |
| |
| - **Per-feature embedding dim `d_feat=32`.** Concat width 15*32 = 480 keeps |
| the first Linear modestly sized. Smaller `d_feat` (e.g. 16) makes the |
| high-vocab features (merchant_id with 10003 values) under-parameterized; |
| larger d_feat balloons the merchant_id table and dominates parameter count. |
| 32 is the elbow. |
| |
| - **2-layer MLP with SiLU.** Same activation as LFM2's SwiGLU MLPs in the |
| backbone — keeps the encoder feeling like a continuation of the LFM |
| architecture rather than a foreign module. (The downstream projector uses |
| GELU to match LFM2-VL's `Lfm2VlMultiModalProjector` exactly.) |
| |
| - **Embedding tables not initialized from any pretrained source.** The |
| parent's structured-feature backbone has per-feature value tables, but |
| those were trained on a different objective and at a different `d_hidden`. |
| Reusing them would be a science experiment we haven't designed; starting |
| from a fresh init keeps the comparison clean. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from src.data.schema import SchemaConfig |
|
|
|
|
class TransactionEncoder(nn.Module):
    """Per-feature embeddings → concat → MLP per-transaction.

    Every transaction is encoded independently by the embedding + MLP path.
    When :meth:`forward` receives ``disputed_idx``, learnable marker vectors
    are additionally added to the disputed transaction's embedding, each
    scaled by a scalar score computed from the sequence's feature ids.
    All markers are zero-initialised, so they have no effect until trained.

    Args:
        schema: parent's SchemaConfig (defines `num_features` and per-feature
            `vocab_size`).
        d_feat: per-feature embedding dimension. Default 32 balances
            small-vocab and large-vocab features.
        d_encoder: output dimension. Default 256 matches the parent
            backbone's `d_hidden` so feature-level expressivity is
            comparable across the head-to-head.
        mlp_hidden: intermediate hidden size of the 2-layer encoder MLP.
        enable_collections_markers: also add the velocity,
            subscription-burden, merchant-diversity, large-amount and
            volatility markers at the disputed position.
        enable_fraud_markers: also add the probe-cluster, post-attack,
            novel-device, clean-signature and recent-authorize markers at
            the disputed position.

    All marker parameters are created regardless of the enable flags, so the
    state_dict layout does not depend on them. The feature column indices
    used by the markers are hard-coded in ``__init__`` and assume the parent
    schema's column order.

    Forward:
        feature_ids: (B, T_tx, F) int64
        returns: (B, T_tx, d_encoder) float (dtype of the MLP output — the
            embedding-table dtype unless autocast is active)
    """

    def __init__(
        self,
        schema: SchemaConfig,
        d_feat: int = 32,
        d_encoder: int = 256,
        mlp_hidden: int = 384,
        enable_collections_markers: bool = False,
        enable_fraud_markers: bool = False,
    ) -> None:
        super().__init__()
        self.num_features = schema.num_features
        self.d_feat = d_feat
        self.d_encoder = d_encoder
        self.enable_collections_markers = enable_collections_markers
        self.enable_fraud_markers = enable_fraud_markers

        # One fresh embedding table per feature (no pretrained init).
        self.feature_embeddings = nn.ModuleList(
            [nn.Embedding(f.vocab_size, d_feat) for f in schema.features],
        )

        concat_dim = self.num_features * d_feat
        self.mlp = nn.Sequential(
            nn.Linear(concat_dim, mlp_hidden),
            nn.SiLU(),
            nn.Linear(mlp_hidden, d_encoder),
        )

        # ------------------------------------------------------------------
        # Always-on disputed-transaction markers (used whenever forward gets
        # `disputed_idx`). NOTE: the creation order of every nn.Parameter and
        # buffer below is kept fixed so state_dict / optimizer layouts match
        # existing checkpoints.
        # ------------------------------------------------------------------
        self.disputed_marker = nn.Parameter(torch.zeros(d_encoder))
        self.subscription_marker = nn.Parameter(torch.zeros(d_encoder))
        self.exotic_country_marker = nn.Parameter(torch.zeros(d_encoder))

        # Feature column indices are stored as buffers, so they are saved in
        # state_dict. NOTE(review): forward reads them back with `.item()`,
        # which forces a host-device sync per read on CUDA.
        self.register_buffer(
            "_merchant_feature_idx", torch.tensor(5, dtype=torch.long),
        )
        self.register_buffer(
            "_country_feature_idx", torch.tensor(10, dtype=torch.long),
        )
        # Number of same-merchant repeats at which the subscription score
        # saturates to 1.
        self._subscription_norm = 10.0

        # ------------------------------------------------------------------
        # Collections markers (applied only if enable_collections_markers).
        # ------------------------------------------------------------------
        self.velocity_marker = nn.Parameter(torch.zeros(d_encoder))
        self.sub_burden_marker = nn.Parameter(torch.zeros(d_encoder))
        self.merchant_diversity_marker = nn.Parameter(torch.zeros(d_encoder))
        self.large_amount_marker = nn.Parameter(torch.zeros(d_encoder))
        self.volatility_marker = nn.Parameter(torch.zeros(d_encoder))

        self.register_buffer(
            "_days_since_last_idx", torch.tensor(2, dtype=torch.long),
        )
        self.register_buffer(
            "_is_recurring_idx", torch.tensor(3, dtype=torch.long),
        )
        self.register_buffer(
            "_amount_idx", torch.tensor(8, dtype=torch.long),
        )

        # Token ids compared against the feature columns above; their
        # meaning is taken from the names — verify against the schema vocab.
        self._is_recurring_true_token = 4
        self._amount_large_threshold = 153

        # Saturation points: each raw statistic is divided by its norm and
        # clamped to [0, 1].
        self._velocity_norm = 14.0
        self._sub_burden_norm = 20.0
        self._diversity_norm = 20.0
        self._large_amount_norm = 25.0
        self._volatility_norm = 80.0

        # Velocity averages over the last this-many positions of the sequence.
        self._recent_window = 16

        # ------------------------------------------------------------------
        # Fraud markers (applied only if enable_fraud_markers).
        # ------------------------------------------------------------------
        self.probe_cluster_marker = nn.Parameter(torch.zeros(d_encoder))
        self.post_attack_marker = nn.Parameter(torch.zeros(d_encoder))
        self.novel_device_marker = nn.Parameter(torch.zeros(d_encoder))
        self.signature_clean_marker = nn.Parameter(torch.zeros(d_encoder))
        self.recent_authorize_marker = nn.Parameter(torch.zeros(d_encoder))

        self.register_buffer(
            "_entry_mode_idx", torch.tensor(7, dtype=torch.long),
        )
        self.register_buffer(
            "_avs_idx", torch.tensor(11, dtype=torch.long),
        )
        self.register_buffer(
            "_cvv_idx", torch.tensor(12, dtype=torch.long),
        )
        self.register_buffer(
            "_device_hash_idx", torch.tensor(13, dtype=torch.long),
        )
        # Column whose low/high values the fraud markers treat as an
        # unfamiliar/familiar merchant. Plain int (not a buffer) so the
        # state_dict keys stay unchanged.
        self._cmc_count_feature_idx = 6

        # Window sizes (in positions) relative to the disputed index.
        self._probe_window = 6
        self._post_attack_window = 6
        self._recent_authorize_window = 16

        # Token ids / thresholds; meaning taken from the names — verify
        # against the schema vocab.
        self._entry_cnp_token = 4
        self._amount_small_token = 11
        self._amount_large_token = 153
        self._cmc_unfamiliar_thresh = 4
        self._cmc_familiar_thresh = 8
        self._cvv_match_token = 3
        self._avs_match_token = 3

        # Saturation points for the count-based fraud scores.
        self._probe_norm = 6.0
        self._post_attack_norm = 6.0
        self._recent_authorize_norm = 16.0

    def forward(
        self,
        feature_ids: torch.Tensor,
        disputed_idx: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode each transaction into a pseudo-token.

        Args:
            feature_ids: (B, T_tx, F) int64 feature ids per transaction.
            disputed_idx: optional (B,) int64 — position of the disputed
                transaction per batch element. When provided, the learnable
                disputed, subscription and exotic-country markers (plus the
                collections / fraud markers if enabled) are added to the
                encoded output at that position. None for tasks that have no
                notion of a disputed transaction.

        Returns:
            (B, T_tx, d_encoder) float — one pseudo-token per transaction.
        """
        per_feat = [
            self.feature_embeddings[f](feature_ids[:, :, f])
            for f in range(self.num_features)
        ]
        encoded = self.mlp(torch.cat(per_feat, dim=-1))

        if disputed_idx is None:
            return encoded

        # NOTE(review): no padding mask is applied; if sequences are padded,
        # padded positions take part in every sequence-level statistic below.
        dtype = encoded.dtype
        batch_idx = torch.arange(encoded.shape[0], device=encoded.device)

        base = encoded[batch_idx, disputed_idx]
        base = self._add_disputed_markers(
            base, feature_ids, batch_idx, disputed_idx, dtype,
        )
        if self.enable_collections_markers:
            base = self._add_collections_markers(base, feature_ids, dtype)
        if self.enable_fraud_markers:
            base = self._add_fraud_markers(
                base, feature_ids, batch_idx, disputed_idx, dtype,
            )

        # The markers are parameters in their own dtype, so `base` can be
        # promoted (e.g. fp32 markers + bf16 MLP output under autocast).
        # Index assignment requires matching dtypes, so cast back first.
        encoded[batch_idx, disputed_idx] = base.to(encoded.dtype)
        return encoded

    def _add_disputed_markers(
        self,
        base: torch.Tensor,
        feature_ids: torch.Tensor,
        batch_idx: torch.Tensor,
        disputed_idx: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Add the disputed, subscription and exotic-country markers to `base`."""
        # Subscription score: number of *other* positions sharing the disputed
        # transaction's merchant id (the -1 removes the disputed one itself).
        m_idx = int(self._merchant_feature_idx.item())
        disp_merchants = feature_ids[batch_idx, disputed_idx, m_idx]
        merchant_matches = (
            feature_ids[:, :, m_idx] == disp_merchants.unsqueeze(1)
        ).to(dtype)
        sub_score = (merchant_matches.sum(dim=1) - 1.0) / self._subscription_norm
        sub_score = sub_score.clamp(min=0.0, max=1.0)

        # Exotic-country flag: 1 when the disputed country id appears only
        # once in the sequence (i.e. only at the disputed position).
        c_idx = int(self._country_feature_idx.item())
        disp_countries = feature_ids[batch_idx, disputed_idx, c_idx]
        country_appearances = (
            feature_ids[:, :, c_idx] == disp_countries.unsqueeze(1)
        ).to(dtype).sum(dim=1)
        exotic_score = (country_appearances <= 1.0).to(dtype)

        base = base + self.disputed_marker
        base = base + self.subscription_marker * sub_score.unsqueeze(-1)
        base = base + self.exotic_country_marker * exotic_score.unsqueeze(-1)
        return base

    def _add_collections_markers(
        self,
        base: torch.Tensor,
        feature_ids: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Add the five collections markers, each scaled by a [0, 1] score."""
        # Velocity: mean days-since-last token id over the last
        # `_recent_window` positions of the sequence (not relative to the
        # disputed position).
        dsl_idx = int(self._days_since_last_idx.item())
        recent_dsl = feature_ids[:, -self._recent_window:, dsl_idx].to(dtype)
        velocity_score = (recent_dsl.mean(dim=1) / self._velocity_norm).clamp(
            min=0.0, max=1.0,
        )

        # Subscription burden: count of positions flagged recurring.
        ir_idx = int(self._is_recurring_idx.item())
        sub_burden_raw = (
            feature_ids[:, :, ir_idx] == self._is_recurring_true_token
        ).to(dtype).sum(dim=1)
        sub_burden_score = (sub_burden_raw / self._sub_burden_norm).clamp(
            min=0.0, max=1.0,
        )

        # Merchant diversity: number of merchant changes between adjacent
        # positions.
        m_idx = int(self._merchant_feature_idx.item())
        merchants = feature_ids[:, :, m_idx]
        changes = (merchants[:, 1:] != merchants[:, :-1]).to(dtype)
        diversity_score = (changes.sum(dim=1) / self._diversity_norm).clamp(
            min=0.0, max=1.0,
        )

        # Large amounts: count of amount ids at or above the threshold.
        amt_idx = int(self._amount_idx.item())
        large_amt_raw = (
            feature_ids[:, :, amt_idx] >= self._amount_large_threshold
        ).to(dtype).sum(dim=1)
        large_amt_score = (large_amt_raw / self._large_amount_norm).clamp(
            min=0.0, max=1.0,
        )

        # Volatility: population (biased) std of the amount token ids.
        amounts = feature_ids[:, :, amt_idx].to(dtype)
        volatility_score = (
            amounts.std(dim=1, unbiased=False) / self._volatility_norm
        ).clamp(min=0.0, max=1.0)

        base = base + self.velocity_marker * velocity_score.unsqueeze(-1)
        base = base + self.sub_burden_marker * sub_burden_score.unsqueeze(-1)
        base = base + self.merchant_diversity_marker * diversity_score.unsqueeze(-1)
        base = base + self.large_amount_marker * large_amt_score.unsqueeze(-1)
        base = base + self.volatility_marker * volatility_score.unsqueeze(-1)
        return base

    def _add_fraud_markers(
        self,
        base: torch.Tensor,
        feature_ids: torch.Tensor,
        batch_idx: torch.Tensor,
        disputed_idx: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Add the five fraud markers, each scaled by a [0, 1] score."""
        em_idx = int(self._entry_mode_idx.item())
        amt_idx = int(self._amount_idx.item())
        dev_idx = int(self._device_hash_idx.item())
        cvv_idx = int(self._cvv_idx.item())
        avs_idx = int(self._avs_idx.item())
        country_idx = int(self._country_feature_idx.item())
        cmc_count_idx = self._cmc_count_feature_idx

        # (B, T_tx) position masks, relative to each row's disputed index `d`.
        positions = torch.arange(
            feature_ids.shape[1], device=base.device,
        ).unsqueeze(0)
        d = disputed_idx.unsqueeze(1)
        # [d - probe_window, d): positions just before the disputed one.
        pre_mask = (positions >= d - self._probe_window) & (positions < d)
        # [d, d + post_attack_window): the disputed position and after it.
        post_mask = (positions >= d) & (positions < d + self._post_attack_window)
        # [d - recent_authorize_window, d]: inclusive of the disputed position.
        recent_mask = (
            (positions >= d - self._recent_authorize_window) & (positions <= d)
        )

        # Probe cluster: small card-not-present amounts before the dispute.
        is_cnp = (feature_ids[:, :, em_idx] == self._entry_cnp_token).to(dtype)
        is_small = (
            feature_ids[:, :, amt_idx] <= self._amount_small_token
        ).to(dtype)
        probe_raw = (is_cnp * is_small * pre_mask.to(dtype)).sum(dim=1)
        probe_score = (probe_raw / self._probe_norm).clamp(min=0.0, max=1.0)

        # Post attack: large amounts at unfamiliar merchants from the dispute on.
        is_large = (
            feature_ids[:, :, amt_idx] >= self._amount_large_token
        ).to(dtype)
        is_unfamiliar = (
            feature_ids[:, :, cmc_count_idx] <= self._cmc_unfamiliar_thresh
        ).to(dtype)
        post_raw = (is_large * is_unfamiliar * post_mask.to(dtype)).sum(dim=1)
        post_score = (post_raw / self._post_attack_norm).clamp(min=0.0, max=1.0)

        # Features of the disputed transaction itself, (B, F).
        disputed_row = feature_ids[batch_idx, disputed_idx]

        # Novel device: the disputed device id appears only once in the sequence.
        device_matches = (
            feature_ids[:, :, dev_idx] == disputed_row[:, dev_idx].unsqueeze(1)
        ).to(dtype).sum(dim=1)
        novel_device_score = (device_matches <= 1.0).to(dtype)

        # Clean signature: the disputed transaction's country equals the
        # sequence's most frequent country AND CVV/AVS match AND the merchant
        # is familiar. Product of 0/1 flags, so the score is 0 or 1.
        mode_country = feature_ids[:, :, country_idx].mode(dim=1).values
        country_match = (disputed_row[:, country_idx] == mode_country).to(dtype)
        cvv_match = (disputed_row[:, cvv_idx] == self._cvv_match_token).to(dtype)
        avs_match = (disputed_row[:, avs_idx] == self._avs_match_token).to(dtype)
        merchant_familiar = (
            disputed_row[:, cmc_count_idx] >= self._cmc_familiar_thresh
        ).to(dtype)
        sig_clean_score = country_match * cvv_match * avs_match * merchant_familiar

        # Recent authorize: CNP at unfamiliar merchants up to the dispute.
        recent_auth_raw = (
            is_cnp * is_unfamiliar * recent_mask.to(dtype)
        ).sum(dim=1)
        recent_auth_score = (
            recent_auth_raw / self._recent_authorize_norm
        ).clamp(min=0.0, max=1.0)

        base = base + self.probe_cluster_marker * probe_score.unsqueeze(-1)
        base = base + self.post_attack_marker * post_score.unsqueeze(-1)
        base = base + self.novel_device_marker * novel_device_score.unsqueeze(-1)
        base = base + self.signature_clean_marker * sig_clean_score.unsqueeze(-1)
        base = base + self.recent_authorize_marker * recent_auth_score.unsqueeze(-1)
        return base

    def num_embedding_params(self) -> int:
        """Total params across feature embedding tables (sanity check)."""
        return sum(e.weight.numel() for e in self.feature_embeddings)
|
|