# lfm2-transaction-encoder / encoder / src / model / heads / fraud_pattern_head.py
# Repository file-viewer header captured with the source (not part of the module):
#   uploader: cdotsanghvi
#   commit:   b3112c7 "initial transaction co-pilot deployment"
#   size:     8.01 kB
"""Two-head categorical output for the Fraud Pattern surface.
Where dispute has a single 3-class output and collections has K=4
treatments × 3 bands, Fraud has TWO independent categoricals at
DIFFERENT cardinalities:

    stage_logits  (B, NUM_STAGES=5)  pre_attack / probing / monetization /
                                     exfiltration / dormant
    type_logits   (B, NUM_TYPES=4)   victim_fraud / account_takeover /
                                     scam_redirected / declined_legit

Both heads share the mean-pooled representation; they are independent
classifications, not a joint distribution over the 20 cells. This
mirrors how a fraud analyst thinks: "what stage is the attack at?" is
orthogonal to "what kind of attack is it?"

To fit the model's existing prob_head interface (`forward → tensor`,
`compute_loss(logits, labels)`), the head emits a FLAT (B, 9) tensor —
the concatenation of stage_logits and type_logits, i.e. 5 + 4 columns at
the default cardinalities. Callers recover the two pieces with
`split_logits`: the first `num_stages` columns are the stage logits and
the remaining columns are the type logits. Targets are (B, 2) int64 with
column 0 = stage and column 1 = type.

This pattern generalizes: any number of independent categorical
classifiers can be packed into one flat logits tensor + a target
matrix, as long as compute_loss knows the layout. Generalizing is
overkill for two heads though; we keep the structure surface-specific.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class FraudPatternHeadConfig:
    """Settings for the two-classifier (stage + type) Fraud Pattern head.

    Fields are listed in positional order; only ``name`` is required.

    Attributes:
        name: label identifying this head (used for logging).
        num_stages: number of attack-stage classes; 5 on the Fraud surface.
        num_types: number of attack-type classes; 4 on the Fraud surface.
        hidden_dim: width of the backbone hidden states fed to the head
            (1024 for LFM2.5-350M).
        mlp_hidden: width of the intermediate layer inside each classifier.
            The stage and type classifiers each get their own MLP, so
            their representations are not entangled.
        dropout: dropout probability applied between the MLP layers.
        num_tx_positions: how many leading transaction positions are
            mean-pooled into the shared representation.
        stage_class_weights: optional per-class cross-entropy weights for
            the stage classifier, e.g. to up-weight rare classes
            (PROBING / MONETIZATION / EXFILTRATION). ``None`` disables
            weighting.
        type_class_weights: optional per-class cross-entropy weights for
            the type classifier. ``None`` disables weighting.
    """

    # Identity.
    name: str

    # Output cardinalities of the two independent classifiers.
    num_stages: int = 5
    num_types: int = 4

    # Classifier architecture.
    hidden_dim: int = 1024
    mlp_hidden: int = 256
    dropout: float = 0.1

    # Pooling window over the leading sequence positions.
    num_tx_positions: int = 64

    # Optional loss re-weighting; None means unweighted cross-entropy.
    stage_class_weights: list[float] | None = None
    type_class_weights: list[float] | None = None
class FraudPatternHead(nn.Module):
    """Two MLPs (stage + type) over a shared mean-pooled representation.

    Each classifier is its own ``Linear -> ReLU -> Dropout -> Linear`` stack;
    the two share only the pooled input, never any weights. ``forward`` packs
    both sets of logits into one flat ``(B, num_stages + num_types)`` tensor
    with the stage columns first, and ``split_logits`` undoes that packing.

    Parameter count at defaults (D=1024, mlp_hidden=256): 263,685 (stage) +
    263,428 (type) = 527,113, i.e. ~0.53M. Same order of magnitude as the
    dispute probability head, well below the LoRA budget.
    """

    def __init__(self, config: FraudPatternHeadConfig) -> None:
        """Build both classifier MLPs and register optional class weights.

        Args:
            config: head configuration; read for ``hidden_dim``,
                ``mlp_hidden``, ``dropout``, ``num_stages``, ``num_types``,
                ``num_tx_positions`` and the two optional weight lists.

        Raises:
            ValueError: if a class-weight list is given whose length does
                not match the corresponding number of classes.
        """
        super().__init__()
        self.config = config

        # Stage is built before type so parameter creation order (and thus
        # seeded initialisation and state-dict layout) stays stable.
        self.stage_mlp = self._build_mlp(config, config.num_stages)
        self.type_mlp = self._build_mlp(config, config.num_types)

        self._register_class_weights(
            "stage_weights",
            "stage_class_weights",
            config.stage_class_weights,
            config.num_stages,
        )
        self._register_class_weights(
            "type_weights",
            "type_class_weights",
            config.type_class_weights,
            config.num_types,
        )

    @staticmethod
    def _build_mlp(config: FraudPatternHeadConfig, num_classes: int) -> nn.Sequential:
        """Return one ``Linear -> ReLU -> Dropout -> Linear`` classifier."""
        return nn.Sequential(
            nn.Linear(config.hidden_dim, config.mlp_hidden),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.mlp_hidden, num_classes),
        )

    def _register_class_weights(
        self,
        buffer_name: str,
        field_name: str,
        weights: list[float] | None,
        num_classes: int,
    ) -> None:
        """Register per-class CE weights as a buffer (``None`` when unset)."""
        if weights is None:
            # Registering None keeps the attribute defined so compute_loss
            # can pass it straight through as "no weighting".
            self.register_buffer(buffer_name, None, persistent=False)
            return
        if len(weights) != num_classes:
            raise ValueError(
                f"{field_name} must have length {num_classes}, "
                f"got {len(weights)}",
            )
        self.register_buffer(
            buffer_name,
            torch.tensor(weights, dtype=torch.float32),
        )

    def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Mean-pool the leading ``num_tx_positions`` positions.

        Args:
            hidden_states: tensor indexed as (batch, position, feature).

        Returns:
            The mean over the position axis of the leading slice, cast to
            the dtype of the head's parameters.

        Note:
            No mask is applied: every position inside the leading window
            contributes to the mean. If the input has fewer positions than
            ``num_tx_positions``, the mean is taken over what is present.
        """
        # Cast to the head's own dtype so the MLPs accept the pooled tensor
        # regardless of the precision of the incoming hidden states.
        head_dtype = next(self.stage_mlp.parameters()).dtype
        tx_slice = hidden_states[:, : self.config.num_tx_positions, :].to(head_dtype)
        return tx_slice.mean(dim=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return flat ``(B, num_stages + num_types)`` logits.

        Stage logits occupy the first ``num_stages`` columns and type logits
        the remainder; use ``split_logits`` to recover the two pieces.
        ``compute_loss`` performs that split internally.
        """
        pooled = self.pool(hidden_states)
        stage_logits = self.stage_mlp(pooled)  # (B, num_stages)
        type_logits = self.type_mlp(pooled)  # (B, num_types)
        return torch.cat([stage_logits, type_logits], dim=-1)  # (B, S+T)

    def split_logits(
        self, logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Slice the flat (..., S+T) tensor back into (stage, type)."""
        ns = self.config.num_stages
        return logits[..., :ns], logits[..., ns:]

    @staticmethod
    def _weighted_ce(
        logits: torch.Tensor,
        targets: torch.Tensor,
        weight: torch.Tensor | None,
    ) -> torch.Tensor:
        """Mean-reduced cross-entropy that tolerates mixed-precision inputs.

        The class-weight buffer is created as float32. When the logits come
        in a different dtype (e.g. bf16), both are promoted to their common
        dtype before the loss is computed, so a dtype mismatch between the
        two cannot occur. When the dtypes already agree, or no weight is
        set, the inputs are passed through untouched.
        """
        if weight is not None and weight.dtype != logits.dtype:
            common = torch.promote_types(logits.dtype, weight.dtype)
            logits = logits.to(common)
            weight = weight.to(common)
        return F.cross_entropy(logits, targets, weight=weight, reduction="mean")

    def compute_loss(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """Sum of CE on stage + CE on type, each mean-reduced.

        Args:
            logits: (B, num_stages + num_types) flat logits.
            targets: (B, 2) integer class indices — col 0 stage idx,
                col 1 type idx. Cast to int64 before use.

        Returns:
            Scalar loss = stage_CE + type_CE. The two terms are weighted
            equally; per-class weights (when configured) up-weight rare
            classes within each term via the registered buffers.
        """
        stage_logits, type_logits = self.split_logits(logits)
        stage_loss = self._weighted_ce(
            stage_logits, targets[:, 0].long(), self.stage_weights,
        )
        type_loss = self._weighted_ce(
            type_logits, targets[:, 1].long(), self.type_weights,
        )
        return stage_loss + type_loss

    @torch.no_grad()
    def score(self, logits: torch.Tensor) -> torch.Tensor:
        """Active-attack score: ``1 - P(DORMANT)``.

        This is the total probability mass on the non-DORMANT stages (their
        sum, not their maximum). Higher = more confident the customer is
        under active attack. For demo logging only; the actual UI consumes
        split logits directly.

        Returns:
            Tensor with the class axis removed, e.g. (B,) for (B, S+T) input.
        """
        stage_logits, _ = self.split_logits(logits)
        stage_probs = F.softmax(stage_logits, dim=-1)  # (B, num_stages)
        # The last stage index is DORMANT (index 4 at the default
        # num_stages=5), so everything else is "active attack".
        return 1.0 - stage_probs[..., -1]

    @torch.no_grad()
    def predict_band(self, logits: torch.Tensor) -> torch.Tensor:
        """For uniformity with ProbabilityHead.predict_band.

        Returns:
            (B, 2) int64 — argmax stage index in column 0, argmax type
            index in column 1.
        """
        stage_logits, type_logits = self.split_logits(logits)
        return torch.stack(
            [
                stage_logits.argmax(dim=-1),
                type_logits.argmax(dim=-1),
            ],
            dim=-1,
        )

    @torch.no_grad()
    def stage_probabilities(self, logits: torch.Tensor) -> torch.Tensor:
        """Softmax over the stage columns of the flat logits."""
        stage_logits, _ = self.split_logits(logits)
        return F.softmax(stage_logits, dim=-1)

    @torch.no_grad()
    def type_probabilities(self, logits: torch.Tensor) -> torch.Tensor:
        """Softmax over the type columns of the flat logits."""
        _, type_logits = self.split_logits(logits)
        return F.softmax(type_logits, dim=-1)

    def num_parameters(self) -> int:
        """Total number of elements across all parameters (buffers excluded)."""
        return sum(p.numel() for p in self.parameters())