"""Inference plumbing for the Fraud Pattern Co-Pilot demo. Mirror of copilot_inference_collections.py adapted to the two-head Fraud output: - probability head returns (B, 9) flat logits = 5 stage + 4 type - we split into per-stage and per-type softmax distributions - the reasoning template grounds in (stage, type, cross-position signals) """ from __future__ import annotations import json from dataclasses import dataclass from pathlib import Path from typing import Iterator import numpy as np import torch import yaml from src.data.schema import SchemaConfig, load_schema from encoder.src.data.mixed_modality import MixedModalityBatch, tokenize_texts from encoder.src.data.synthetic_fraud_pattern import ( NUM_STAGES, NUM_TYPES, STAGE_DORMANT, STAGE_EXFILTRATION, STAGE_MONETIZATION, STAGE_NAMES, STAGE_PRE_ATTACK, STAGE_PROBING, TYPE_ACCOUNT_TAKEOVER, TYPE_DECLINED_LEGIT, TYPE_NAMES, TYPE_SCAM_REDIRECTED, TYPE_VICTIM_FRAUD, signal_novel_device, signal_post_attack_density, signal_probe_density, signal_recent_authorize_density, signal_signature_clean, ) from encoder.src.model.transaction_fm_multisurface import ( TransactionMultiSurfaceModel, build_transaction_multisurface, ) DEMO_SEED = 42 # Per-stage display color. Severity scale: DORMANT (green, OK) → # PRE_ATTACK (yellow) → PROBING (orange) → MONETIZATION (red) → # EXFILTRATION (dark red). STAGE_COLORS: dict[int, str] = { STAGE_PRE_ATTACK: "#eab308", # yellow STAGE_PROBING: "#f59e0b", # amber STAGE_MONETIZATION: "#ef4444", # red STAGE_EXFILTRATION: "#b91c1c", # dark red STAGE_DORMANT: "#22c55e", # green } # Per-type display color. Distinguish by attack-type semantics. TYPE_COLORS: dict[int, str] = { TYPE_VICTIM_FRAUD: "#7c3aed", # purple — customer manipulated TYPE_ACCOUNT_TAKEOVER: "#dc2626", # red — compromise TYPE_SCAM_REDIRECTED: "#a855f7", # violet — scam TYPE_DECLINED_LEGIT: "#22c55e", # green — false positive } @dataclass class FraudPatternCastMember: pattern: str display_name: str customer_idx: int flagged_idx: int stage_label: int stage_label_name: str type_label: int type_label_name: str description: str context_text: str @dataclass class FraudPatternResult: """One inference call result. Attributes: stage_probs: (NUM_STAGES,) softmax over stage logits. type_probs: (NUM_TYPES,) softmax over type logits. predicted_stage: argmax stage index. predicted_type: argmax type index. attribution_probs: (64,) per-position contribution. top_k_positions: (k,) sorted top contributors. flagged_idx: the upstream-flagged position (passed through). """ stage_probs: np.ndarray type_probs: np.ndarray predicted_stage: int predicted_type: int attribution_probs: np.ndarray top_k_positions: np.ndarray flagged_idx: int class FraudPatternCopilotModel: """Encapsulates the Fraud Pattern multi-surface model + cast + tokenizer.""" def __init__( self, model: TransactionMultiSurfaceModel, schema: SchemaConfig, histories: np.ndarray, cast: list[FraudPatternCastMember], device: torch.device, ) -> None: self.model = model self.schema = schema self.histories = histories self.cast = cast self.device = device self.tokenizer = model.backbone.tokenizer self.model.eval() @classmethod def from_paths( cls, checkpoint_path: Path, model_config_path: Path, schema_path: Path, histories_path: Path, cast_path: Path, device: torch.device = torch.device("cpu"), ) -> "FraudPatternCopilotModel": schema = load_schema(schema_path) histories = np.load(histories_path, mmap_mode="r") cast = _load_cast(cast_path) with model_config_path.open() as f: mcfg = yaml.safe_load(f) dtype = torch.float32 if device.type == "cpu" else torch.bfloat16 model = build_transaction_multisurface( schema=schema, model_path=mcfg["backbone"]["hf_path"], encoder_cfg=mcfg.get("encoder"), projector_cfg=mcfg.get("projector"), head_cfg=mcfg.get("heads"), lora_cfg=mcfg["backbone"].get("lora"), dtype=dtype, device_map=None if device.type == "cpu" else "auto", ) ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) if not ckpt.get("model_state_dict_slim"): raise ValueError(f"Expected a slim checkpoint at {checkpoint_path}.") state = { k: v.to(dtype) if v.is_floating_point() else v for k, v in ckpt["model_state_dict"].items() } missing, unexpected = model.load_state_dict(state, strict=False) if unexpected: raise RuntimeError(f"Unexpected state_dict keys: {unexpected[:5]} ...") model.to(device) return cls(model=model, schema=schema, histories=histories, cast=cast, device=device) @torch.inference_mode() def predict( self, member: FraudPatternCastMember, top_k: int = 5, ) -> FraudPatternResult: torch.manual_seed(DEMO_SEED) batch = self._build_batch(member) out = self.model.predict(batch) # prob_logits: (1, 9). First 5 = stage logits, last 4 = type logits. flat = out["probability_logits"][0].float().cpu() stage_logits = flat[:NUM_STAGES] type_logits = flat[NUM_STAGES:] stage_probs = torch.softmax(stage_logits, dim=-1).numpy() type_probs = torch.softmax(type_logits, dim=-1).numpy() attr_logits = out["attribution_logits"][0].float().cpu() attr_probs = torch.sigmoid(attr_logits).numpy() top_k_positions = torch.topk(attr_logits, k=top_k, dim=-1).indices.numpy() return FraudPatternResult( stage_probs=stage_probs, type_probs=type_probs, predicted_stage=int(np.argmax(stage_probs)), predicted_type=int(np.argmax(type_probs)), attribution_probs=attr_probs, top_k_positions=top_k_positions, flagged_idx=member.flagged_idx, ) def build_reasoning_text( self, member: FraudPatternCastMember, result: FraudPatternResult, ) -> str: history = np.asarray(self.histories[member.customer_idx]) return _render_reasoning(history, member.flagged_idx, result) def stream_reasoning( self, member: FraudPatternCastMember, result: FraudPatternResult, chunk_chars: int = 6, ) -> Iterator[str]: text = self.build_reasoning_text(member, result) for i in range(chunk_chars, len(text) + chunk_chars, chunk_chars): yield text[: min(i, len(text))] def _build_batch(self, member: FraudPatternCastMember) -> MixedModalityBatch: history = np.asarray(self.histories[member.customer_idx]).copy() feature_ids = torch.from_numpy(history).long().unsqueeze(0).to(self.device) input_ids, attn_mask, lengths = tokenize_texts( self.tokenizer, [member.context_text], max_length=256, ) input_ids = input_ids.to(self.device) attn_mask = attn_mask.to(self.device) lengths = lengths.to(self.device) flagged_idx = torch.tensor( [member.flagged_idx], dtype=torch.long, device=self.device, ) return MixedModalityBatch( feature_ids=feature_ids, text_input_ids=input_ids, text_attention_mask=attn_mask, text_lengths=lengths, head_target="probability", disputed_idx=flagged_idx, ) def _load_cast(cast_path: Path) -> list[FraudPatternCastMember]: payload = json.loads(cast_path.read_text()) return [ FraudPatternCastMember( pattern=m["pattern"], display_name=m["display_name"], customer_idx=int(m["customer_idx"]), flagged_idx=int(m["flagged_idx"]), stage_label=int(m["stage_label"]), stage_label_name=m["stage_label_name"], type_label=int(m["type_label"]), type_label_name=m["type_label_name"], description=m["description"], context_text=m["context_text"], ) for m in payload["cast"] ] def _render_reasoning( history: np.ndarray, flagged_idx: int, result: FraudPatternResult, ) -> str: probe_count = signal_probe_density(history, flagged_idx) post_count = signal_post_attack_density(history, flagged_idx) novel_device = signal_novel_device(history, flagged_idx) sig_clean = signal_signature_clean(history, flagged_idx) recent_auth = signal_recent_authorize_density(history, flagged_idx) stage = result.predicted_stage ptype = result.predicted_type stage_p = float(result.stage_probs[stage]) type_p = float(result.type_probs[ptype]) parts: list[str] = [] parts.append( f"Verdict: stage={STAGE_NAMES[stage]} (P={stage_p:.2f}), " f"type={TYPE_NAMES[ptype]} (P={type_p:.2f})." ) parts.append( f"Cross-position signals — probe-cluster: {probe_count} small-CNP " f"preceding tx{flagged_idx}, post-attack density: {post_count} large " f"unfamiliar charges around the flag, novel-device: {novel_device}, " f"signature-clean: {sig_clean}, recent-authorize density: {recent_auth}." ) if stage == STAGE_PROBING: parts.append( "Probing pattern detected. The attacker is testing whether the card " "works via small charges before escalating. Recommend immediate " "containment + step-up auth on subsequent transactions." ) elif stage == STAGE_MONETIZATION: parts.append( "Monetization pattern: probes succeeded and the attacker is " "converting access into value. The flagged transaction is the " "first big charge after the probe phase. Recommend full block + " "customer outreach within the hour." ) elif stage == STAGE_EXFILTRATION: parts.append( "Exfiltration pattern: multiple large unfamiliar charges around " "the flag indicate a mature attack. Freeze card + notify customer " "immediately. Damage may already be material." ) elif stage == STAGE_DORMANT: parts.append( "Dormant pattern: the flagged tx sits inside the customer's " "normal signature. Upstream detector is most likely false-positive. " "Release with low-priority follow-up." ) else: parts.append( "Pre-attack stage: a single anomalous transaction with no chain " "evidence yet. Recommend step-up auth and observe next 24 hours." ) if ptype == TYPE_ACCOUNT_TAKEOVER: parts.append( "Type=account_takeover: the device fingerprint at the flag is one " "the customer has not used elsewhere. Treat as credential/device " "compromise — disable existing sessions and reset auth." ) elif ptype == TYPE_VICTIM_FRAUD: parts.append( "Type=victim_fraud: device matches the customer's own, but the " "transaction pattern is anomalous. Likely the customer was tricked " "into authorizing — engage with care, the customer's perception " "may be that this is legitimate." ) elif ptype == TYPE_SCAM_REDIRECTED: parts.append( "Type=scam_redirected: recent history shows many customer-authorized " "CNP payments to unfamiliar merchants. Consistent with romance / " "impostor scam. Engage customer-care team, not fraud team." ) else: parts.append( "Type=declined_legit: the flagged tx looks like a normal customer " "purchase. The upstream score is plausibly a rules false-positive." ) top_pos = ", ".join(str(int(p)) for p in result.top_k_positions[:5]) parts.append( f"Top contributing transactions: positions {top_pos}." ) return " ".join(parts)