"""Two-head categorical output for the Fraud Pattern surface.

Where dispute has a single 3-class output and collections has K=4
treatments × 3 bands, Fraud has TWO independent categoricals at DIFFERENT
cardinalities:

    stage_logits (B, NUM_STAGES=5)
        pre_attack / probing / monetization / exfiltration / dormant
    type_logits  (B, NUM_TYPES=4)
        victim_fraud / account_takeover / scam_redirected / declined_legit

Both heads share the mean-pooled representation; they are independent
classifications, not a joint distribution over the 20 cells. This mirrors
how a fraud analyst thinks: "what stage is the attack at?" is orthogonal to
"what kind of attack is it?"

To fit the model's existing prob_head interface (`forward → tensor`,
`compute_loss(logits, labels)`), the head emits a FLAT
(B, num_stages + num_types) tensor — (B, 9) at the defaults — holding
stage_logits followed by type_logits. Callers recover the two parts with
`FraudPatternHead.split_logits`. Targets are (B, 2) int64 with column 0 =
stage and column 1 = type.

This pattern generalizes: any number of independent categorical classifiers
can be packed into one flat logits tensor + a target matrix, as long as
compute_loss knows the layout. Generalizing is overkill for two heads
though; we keep the structure surface-specific.
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class FraudPatternHeadConfig:
    """Two-head categorical config.

    Attributes:
        name: identifier for logging.
        num_stages: 5 for the Fraud surface.
        num_types: 4 for the Fraud surface.
        hidden_dim: backbone hidden size (1024 for LFM2.5-350M).
        mlp_hidden: intermediate hidden of the per-head MLP. Each head owns
            its own MLP so the representations don't entangle.
        dropout: dropout between MLP layers.
        num_tx_positions: leading transaction positions to mean-pool over.
        stage_class_weights: optional CE weights per stage band; up-weight
            rare classes (PROBING / MONETIZATION / EXFILTRATION). Must have
            length ``num_stages`` (checked by FraudPatternHead).
        type_class_weights: optional CE weights per type band. Must have
            length ``num_types`` (checked by FraudPatternHead).
    """

    name: str
    num_stages: int = 5
    num_types: int = 4
    hidden_dim: int = 1024
    mlp_hidden: int = 256
    dropout: float = 0.1
    num_tx_positions: int = 64
    stage_class_weights: list[float] | None = None
    type_class_weights: list[float] | None = None


class FraudPatternHead(nn.Module):
    """Two MLPs (stage + type) over a shared mean-pooled representation.

    Parameter count at defaults (D=1024, mlp_hidden=256): ~264K (stage) +
    ~263K (type) = ~527K. Same order of magnitude as the dispute probability
    head, well below the LoRA budget.
    """

    def __init__(self, config: FraudPatternHeadConfig) -> None:
        super().__init__()
        self.config = config
        self.stage_mlp = nn.Sequential(
            nn.Linear(config.hidden_dim, config.mlp_hidden),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.mlp_hidden, config.num_stages),
        )
        self.type_mlp = nn.Sequential(
            nn.Linear(config.hidden_dim, config.mlp_hidden),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.mlp_hidden, config.num_types),
        )
        # Buffers are registered after the MLPs to keep the original
        # state_dict key order.
        self._register_class_weights(
            "stage_weights",
            "stage_class_weights",
            config.stage_class_weights,
            config.num_stages,
        )
        self._register_class_weights(
            "type_weights",
            "type_class_weights",
            config.type_class_weights,
            config.num_types,
        )

    def _register_class_weights(
        self,
        buffer_name: str,
        field_name: str,
        weights: list[float] | None,
        expected_len: int,
    ) -> None:
        """Validate optional per-class CE weights and register them as a buffer.

        When ``weights`` is None, a non-persistent ``None`` buffer is
        registered so ``self.<buffer_name>`` still resolves (to None).

        Raises:
            ValueError: if ``weights`` does not have ``expected_len`` entries.
        """
        if weights is None:
            self.register_buffer(buffer_name, None, persistent=False)
            return
        if len(weights) != expected_len:
            raise ValueError(
                f"{field_name} must have length {expected_len}, "
                f"got {len(weights)}",
            )
        self.register_buffer(
            buffer_name,
            torch.tensor(weights, dtype=torch.float32),
        )

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Mean-pool the leading ``num_tx_positions`` positions.

        Assumes ``hidden_states`` is (B, L, D); the result is (B, D) cast to
        the head's parameter dtype. NOTE(review): no attention mask is
        applied, so any padding inside the window is averaged in — confirm
        callers never pad within the first ``num_tx_positions`` positions.
        """
        head_dtype = next(self.stage_mlp.parameters()).dtype
        tx_slice = hidden_states[:, : self.config.num_tx_positions, :].to(head_dtype)
        return tx_slice.mean(dim=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return flat (B, num_stages + num_types) logits.

        Columns ``[:num_stages]`` are the stage logits and the remaining
        ``num_types`` columns are the type logits; use ``split_logits`` to
        separate them. ``compute_loss`` splits internally.
        """
        pooled = self.pool(hidden_states)
        stage_logits = self.stage_mlp(pooled)  # (B, num_stages)
        type_logits = self.type_mlp(pooled)  # (B, num_types)
        return torch.cat([stage_logits, type_logits], dim=-1)  # (B, S+T)

    def split_logits(
        self,
        logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Slice the flat (..., S+T) tensor back into (stage, type).

        Raises:
            ValueError: if the last dimension is not num_stages + num_types,
                which would otherwise silently mis-assign columns to heads.
        """
        ns = self.config.num_stages
        expected = ns + self.config.num_types
        if logits.size(-1) != expected:
            raise ValueError(
                f"logits last dim must be num_stages + num_types = {expected}, "
                f"got {logits.size(-1)}",
            )
        return logits[..., :ns], logits[..., ns:]

    def compute_loss(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """Sum of CE on stage + CE on type, mean-reduced.

        Args:
            logits: (B, num_stages + num_types).
            targets: (B, 2) integer tensor — col 0 stage idx, col 1 type idx.

        Returns:
            Scalar float32 loss = stage_CE + type_CE. Equal weight by
            default; per-class weights up-weight rare bands via the
            registered buffers.

        Raises:
            ValueError: if ``targets`` is not (B, 2) or ``logits`` has the
                wrong width.
        """
        if targets.dim() != 2 or targets.size(-1) != 2:
            raise ValueError(
                f"targets must have shape (B, 2), got {tuple(targets.shape)}",
            )
        stage_logits, type_logits = self.split_logits(logits)
        # CE is computed in float32: it is numerically safer than bf16/fp16
        # and keeps the (float32) class-weight buffers dtype-consistent with
        # the logits even when the head runs in reduced precision.
        stage_weights = None if self.stage_weights is None else self.stage_weights.float()
        type_weights = None if self.type_weights is None else self.type_weights.float()
        stage_loss = F.cross_entropy(
            stage_logits.float(),
            targets[:, 0].long(),
            weight=stage_weights,
            reduction="mean",
        )
        type_loss = F.cross_entropy(
            type_logits.float(),
            targets[:, 1].long(),
            weight=type_weights,
            reduction="mean",
        )
        return stage_loss + type_loss

    @torch.no_grad()
    def score(self, logits: torch.Tensor) -> torch.Tensor:
        """Active-attack score for the headline UI: 1 - P(DORMANT).

        This is the total stage probability mass on the non-DORMANT stages
        (not the max of them). Higher = more confident the customer is under
        active attack. Assumes DORMANT is the LAST stage index, as in the
        default 5-stage layout. For demo logging only; the actual UI
        consumes split logits directly.
        """
        stage_logits, _ = self.split_logits(logits)
        stage_probs = F.softmax(stage_logits, dim=-1)  # (B, num_stages)
        return 1.0 - stage_probs[..., -1]

    @torch.no_grad()
    def predict_band(self, logits: torch.Tensor) -> torch.Tensor:
        """For uniformity with ProbabilityHead.predict_band.

        Returns (B, 2) int64 — argmax stage in col 0, argmax type in col 1.
        """
        stage_logits, type_logits = self.split_logits(logits)
        return torch.stack(
            [stage_logits.argmax(dim=-1), type_logits.argmax(dim=-1)],
            dim=-1,
        )

    @torch.no_grad()
    def stage_probabilities(self, logits: torch.Tensor) -> torch.Tensor:
        """Softmax over the stage logits: (B, num_stages)."""
        stage_logits, _ = self.split_logits(logits)
        return F.softmax(stage_logits, dim=-1)

    @torch.no_grad()
    def type_probabilities(self, logits: torch.Tensor) -> torch.Tensor:
        """Softmax over the type logits: (B, num_types)."""
        _, type_logits = self.split_logits(logits)
        return F.softmax(type_logits, dim=-1)

    def num_parameters(self) -> int:
        """Total number of parameter elements in both MLPs."""
        return sum(p.numel() for p in self.parameters())