| """Two-head categorical output for the Fraud Pattern surface. |
| |
| Where dispute has a single 3-class output and collections has K=4 |
| treatments × 3 bands, Fraud has TWO independent categoricals at |
| DIFFERENT cardinalities: |
| |
| stage_logits (B, NUM_STAGES=5) pre_attack / probing / monetization / |
| exfiltration / dormant |
| type_logits (B, NUM_TYPES=4) victim_fraud / account_takeover / |
| scam_redirected / declined_legit |
| |
| Both heads share the mean-pooled representation; they are independent |
| classifications, not a joint distribution over the 20 cells. This |
| mirrors how a fraud analyst thinks: "what stage is the attack at?" is |
| orthogonal to "what kind of attack is it?" |
| |
| To fit the model's existing prob_head interface (`forward → tensor`, |
| `compute_loss(logits, labels)`), the head emits a FLAT (B, 9) tensor — |
| the concatenation of stage_logits and type_logits, in that order — and |
| the caller splits it back with `split_logits` (the first NUM_STAGES |
| columns are stage, the remaining columns are type). Targets are (B, 2) |
| int64 with column 0 = stage and column 1 = type. |
| |
| This pattern generalizes: any number of independent categorical |
| classifiers can be packed into one flat logits tensor + a target |
| matrix, as long as compute_loss knows the layout. Generalizing is |
| overkill for two heads though; we keep the structure surface-specific. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class FraudPatternHeadConfig:
    """Configuration for the two-head (stage + type) categorical output.

    Attributes:
        name: identifier for logging.
        num_stages: number of stage classes (5 for the Fraud surface).
        num_types: number of type classes (4 for the Fraud surface).
        hidden_dim: backbone hidden size (1024 for LFM2.5-350M).
        mlp_hidden: intermediate width of the per-head MLP. Each head
            owns its own MLP so the representations don't entangle.
        dropout: dropout probability between the MLP layers, in [0, 1].
        num_tx_positions: number of leading transaction positions to
            mean-pool over; must be at least 1.
        stage_class_weights: optional cross-entropy weights, one per
            stage class; used to up-weight rare classes
            (PROBING / MONETIZATION / EXFILTRATION).
        type_class_weights: optional cross-entropy weights, one per
            type class.

    Raises:
        ValueError: if a size field is smaller than 1 or ``dropout`` is
            outside [0, 1].
    """

    name: str
    num_stages: int = 5
    num_types: int = 4
    hidden_dim: int = 1024
    mlp_hidden: int = 256
    dropout: float = 0.1
    num_tx_positions: int = 64
    stage_class_weights: list[float] | None = None
    type_class_weights: list[float] | None = None

    def __post_init__(self) -> None:
        """Reject values that would only misbehave later, at train time."""
        # A zero pooling window means "mean over nothing" (NaN logits) and a
        # negative one silently turns the slice into "all but the last k"
        # positions, so catch both here instead of downstream.
        size_fields = {
            "num_stages": self.num_stages,
            "num_types": self.num_types,
            "hidden_dim": self.hidden_dim,
            "mlp_hidden": self.mlp_hidden,
            "num_tx_positions": self.num_tx_positions,
        }
        for field_name, value in size_fields.items():
            if value < 1:
                raise ValueError(f"{field_name} must be >= 1, got {value}")

        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(
                f"dropout must be in [0, 1], got {self.dropout}",
            )
|
|
|
|
class FraudPatternHead(nn.Module):
    """Two MLPs (stage + type) over a shared mean-pooled representation.

    ``forward`` emits one flat ``(B, num_stages + num_types)`` logits
    tensor with the stage columns first; ``split_logits`` recovers the two
    heads and ``compute_loss`` consumes the flat tensor directly.

    Parameter count at defaults (D=1024, mlp_hidden=256, 5 stages,
    4 types): 263,685 + 263,428 = 527,113 (~530K). Same order of magnitude
    as the dispute probability head, well below the LoRA budget.
    """

    def __init__(self, config: FraudPatternHeadConfig) -> None:
        """Build both MLPs and register the optional class-weight buffers.

        Args:
            config: head configuration; only its attributes are read.

        Raises:
            ValueError: if a class-weight list does not have exactly one
                entry per class of its head.
        """
        super().__init__()
        self.config = config

        # Each head gets its own Linear -> ReLU -> Dropout -> Linear stack;
        # only the pooled input is shared between them.
        self.stage_mlp = nn.Sequential(
            nn.Linear(config.hidden_dim, config.mlp_hidden),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.mlp_hidden, config.num_stages),
        )
        self.type_mlp = nn.Sequential(
            nn.Linear(config.hidden_dim, config.mlp_hidden),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.mlp_hidden, config.num_types),
        )

        self._register_class_weights(
            "stage_weights", "stage_class_weights",
            config.stage_class_weights, config.num_stages,
        )
        self._register_class_weights(
            "type_weights", "type_class_weights",
            config.type_class_weights, config.num_types,
        )

    def _register_class_weights(
        self,
        buffer_name: str,
        field_name: str,
        weights: list[float] | None,
        num_classes: int,
    ) -> None:
        """Register ``weights`` as a float32 buffer, or a ``None`` buffer."""
        if weights is None:
            # Still register the name so `self.<buffer_name>` always exists;
            # non-persistent, so it is left out of the state dict.
            self.register_buffer(buffer_name, None, persistent=False)
            return
        if len(weights) != num_classes:
            raise ValueError(
                f"{field_name} must have length {num_classes}, "
                f"got {len(weights)}",
            )
        self.register_buffer(
            buffer_name,
            torch.tensor(weights, dtype=torch.float32),
        )

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Mean-pool the leading ``num_tx_positions`` positions.

        Args:
            hidden_states: rank-3 tensor indexed as (batch, position,
                feature). If it has fewer than ``num_tx_positions``
                positions, all of them are averaged.

        Returns:
            (batch, feature) tensor in the dtype of the head's parameters.
            No padding mask is applied: every position inside the window
            contributes equally to the mean.
        """
        # Cast to the head's own dtype so hidden states produced in a
        # different precision can be fed to these Linear layers.
        head_dtype = next(self.stage_mlp.parameters()).dtype
        tx_slice = hidden_states[:, : self.config.num_tx_positions, :].to(head_dtype)
        return tx_slice.mean(dim=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return flat ``(B, num_stages + num_types)`` logits.

        Stage logits occupy the first ``num_stages`` columns and type
        logits the rest. Use ``split_logits`` to separate them again;
        ``compute_loss`` does that split internally.
        """
        pooled = self.pool(hidden_states)
        stage_logits = self.stage_mlp(pooled)
        type_logits = self.type_mlp(pooled)
        return torch.cat([stage_logits, type_logits], dim=-1)

    def split_logits(
        self, logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Slice the flat ``(..., S+T)`` tensor back into (stage, type).

        Raises:
            ValueError: if the last dimension is not
                ``num_stages + num_types``. Without this check a mis-sized
                tensor would silently produce a wrong-width type slice.
        """
        ns = self.config.num_stages
        expected_width = ns + self.config.num_types
        if logits.shape[-1] != expected_width:
            raise ValueError(
                f"logits last dim must be {expected_width} "
                f"(num_stages + num_types), got {logits.shape[-1]}",
            )
        return logits[..., :ns], logits[..., ns:]

    @staticmethod
    def _class_weighted_ce(
        logits: torch.Tensor,
        targets: torch.Tensor,
        weights: torch.Tensor | None,
    ) -> torch.Tensor:
        """Mean-reduced cross-entropy with optional per-class weights."""
        if weights is not None:
            # The weight buffer is created as float32; logits may arrive in
            # bf16/fp16/float64. Promote both to one common dtype so the
            # loss kernel does not reject the pair. When the dtypes already
            # agree this is a no-op.
            common_dtype = torch.promote_types(logits.dtype, weights.dtype)
            logits = logits.to(common_dtype)
            weights = weights.to(device=logits.device, dtype=common_dtype)
        return F.cross_entropy(
            logits, targets, weight=weights, reduction="mean",
        )

    def compute_loss(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """Sum of CE on stage + CE on type, each mean-reduced.

        Args:
            logits: (B, num_stages + num_types) flat logits.
            targets: (B, 2) integer tensor — col 0 stage idx, col 1 type
                idx. Values are cast to int64 before use.

        Returns:
            Scalar loss = stage_CE + type_CE with equal weight on the two
            heads. When class weights are registered they are applied
            per class inside each head's cross-entropy.

        Raises:
            ValueError: if ``logits`` does not have the packed width.
        """
        stage_logits, type_logits = self.split_logits(logits)
        stage_targets = targets[:, 0].long()
        type_targets = targets[:, 1].long()

        stage_loss = self._class_weighted_ce(
            stage_logits, stage_targets, self.stage_weights,
        )
        type_loss = self._class_weighted_ce(
            type_logits, type_targets, self.type_weights,
        )
        return stage_loss + type_loss

    @torch.no_grad()
    def score(self, logits: torch.Tensor) -> torch.Tensor:
        """Headline "under attack" score: total non-DORMANT probability.

        Computed as ``1 - P(last stage)``, which equals the SUM (not the
        max) of the softmax probabilities of every other stage. Higher
        means less probability mass on the last stage. For demo logging
        only; the actual UI consumes split logits directly.

        NOTE(review): this relies on DORMANT being the last stage index,
        matching the stage order listed in the module docstring — confirm
        if the stage ordering ever changes.
        """
        stage_logits, _ = self.split_logits(logits)
        stage_probs = F.softmax(stage_logits, dim=-1)
        return 1.0 - stage_probs[..., -1]

    @torch.no_grad()
    def predict_band(self, logits: torch.Tensor) -> torch.Tensor:
        """For uniformity with ProbabilityHead.predict_band.

        Returns:
            (B, 2) int64 — argmax stage index in column 0, argmax type
            index in column 1.
        """
        stage_logits, type_logits = self.split_logits(logits)
        return torch.stack([
            stage_logits.argmax(dim=-1),
            type_logits.argmax(dim=-1),
        ], dim=-1)

    @torch.no_grad()
    def stage_probabilities(self, logits: torch.Tensor) -> torch.Tensor:
        """Softmax over the stage columns of the flat logits."""
        stage_logits, _ = self.split_logits(logits)
        return F.softmax(stage_logits, dim=-1)

    @torch.no_grad()
    def type_probabilities(self, logits: torch.Tensor) -> torch.Tensor:
        """Softmax over the type columns of the flat logits."""
        _, type_logits = self.split_logits(logits)
        return F.softmax(type_logits, dim=-1)

    def num_parameters(self) -> int:
        """Total number of elements across all parameters of the head."""
        return sum(p.numel() for p in self.parameters())
|
|