| """Inference plumbing for the Fraud Pattern Co-Pilot demo. |
| |
| Mirror of copilot_inference_collections.py adapted to the two-head |
| Fraud output: |
| - probability head returns (B, 9) flat logits = 5 stage + 4 type |
| - we split into per-stage and per-type softmax distributions |
| - the reasoning template grounds in (stage, type, cross-position signals) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Iterator |
|
|
| import numpy as np |
| import torch |
| import yaml |
|
|
| from src.data.schema import SchemaConfig, load_schema |
|
|
| from encoder.src.data.mixed_modality import MixedModalityBatch, tokenize_texts |
| from encoder.src.data.synthetic_fraud_pattern import ( |
| NUM_STAGES, |
| NUM_TYPES, |
| STAGE_DORMANT, |
| STAGE_EXFILTRATION, |
| STAGE_MONETIZATION, |
| STAGE_NAMES, |
| STAGE_PRE_ATTACK, |
| STAGE_PROBING, |
| TYPE_ACCOUNT_TAKEOVER, |
| TYPE_DECLINED_LEGIT, |
| TYPE_NAMES, |
| TYPE_SCAM_REDIRECTED, |
| TYPE_VICTIM_FRAUD, |
| signal_novel_device, |
| signal_post_attack_density, |
| signal_probe_density, |
| signal_recent_authorize_density, |
| signal_signature_clean, |
| ) |
| from encoder.src.model.transaction_fm_multisurface import ( |
| TransactionMultiSurfaceModel, |
| build_transaction_multisurface, |
| ) |
|
|
|
|
# Seed passed to torch.manual_seed at the start of every
# FraudPatternCopilotModel.predict call so repeated demo runs are reproducible.
DEMO_SEED = 42

# Hex colour per stage index.
# NOTE(review): not read in this file — presumably consumed by the demo UI;
# confirm with the caller.
STAGE_COLORS: dict[int, str] = {
    STAGE_PRE_ATTACK: "#eab308",  # yellow
    STAGE_PROBING: "#f59e0b",  # amber
    STAGE_MONETIZATION: "#ef4444",  # red
    STAGE_EXFILTRATION: "#b91c1c",  # dark red
    STAGE_DORMANT: "#22c55e",  # green
}

# Hex colour per fraud-type index.
# NOTE(review): not read in this file — presumably consumed by the demo UI;
# confirm with the caller.
TYPE_COLORS: dict[int, str] = {
    TYPE_VICTIM_FRAUD: "#7c3aed",  # violet
    TYPE_ACCOUNT_TAKEOVER: "#dc2626",  # red
    TYPE_SCAM_REDIRECTED: "#a855f7",  # purple
    TYPE_DECLINED_LEGIT: "#22c55e",  # green
}
|
|
|
|
@dataclass
class FraudPatternCastMember:
    """One demo scenario, as loaded from the cast JSON by ``_load_cast``.

    Only ``customer_idx``, ``flagged_idx`` and ``context_text`` are read by the
    inference code in this file. The remaining fields are carried through
    unchanged, presumably for display or comparison in the demo UI (confirm
    with the caller).
    """

    # Scenario identifier from the cast file (semantics not visible here).
    pattern: str
    # Name for the scenario (presumably shown in the UI — confirm).
    display_name: str
    # Row of the histories array that holds this customer's history.
    customer_idx: int
    # Position within that history flagged upstream. It is passed to the model
    # and to the cross-position signal helpers.
    flagged_idx: int
    # Labelled stage index and name. Not used for inference here; presumably
    # the ground truth to compare against the prediction — confirm.
    stage_label: int
    stage_label_name: str
    # Labelled fraud-type index and name (same caveat as the stage label).
    type_label: int
    type_label_name: str
    # Free-text scenario description (not read in this file).
    description: str
    # Text that is tokenized and fed to the model alongside the history.
    context_text: str
|
|
|
|
@dataclass
class FraudPatternResult:
    """Decoded output of one ``FraudPatternCopilotModel.predict`` call.

    Attributes:
        stage_probs: (NUM_STAGES,) softmax over the stage logits.
        type_probs: (NUM_TYPES,) softmax over the type logits.
        predicted_stage: argmax index into ``stage_probs``.
        predicted_type: argmax index into ``type_probs``.
        attribution_probs: sigmoid of the attribution logits, one entry per
            history position. Earlier docs said 64 positions; confirm against
            the schema's sequence length.
        top_k_positions: (k,) history positions with the largest attribution
            logits, ordered highest first (``torch.topk`` sorted output).
        flagged_idx: the upstream-flagged position, copied from the cast member.
    """

    stage_probs: np.ndarray
    type_probs: np.ndarray
    predicted_stage: int
    predicted_type: int
    attribution_probs: np.ndarray
    top_k_positions: np.ndarray
    flagged_idx: int
|
|
|
|
class FraudPatternCopilotModel:
    """Encapsulate the Fraud Pattern multi-surface model, demo cast and tokenizer.

    Wraps a ``TransactionMultiSurfaceModel``, which is switched to eval mode on
    construction, together with the schema, the customer histories array and
    the demo cast. It exposes:

    * ``predict``: one forward pass decoded into a ``FraudPatternResult``;
    * ``build_reasoning_text`` / ``stream_reasoning``: templated explanation
      text for a prediction.
    """

    def __init__(
        self,
        model: TransactionMultiSurfaceModel,
        schema: SchemaConfig,
        histories: np.ndarray,
        cast: list[FraudPatternCastMember],
        device: torch.device,
    ) -> None:
        """Store the components and put ``model`` into eval mode.

        Args:
            model: the multi-surface model. Its ``backbone.tokenizer`` is reused
                to tokenize cast context text.
            schema: schema the model was built against.
            histories: per-customer histories, indexed by
                ``FraudPatternCastMember.customer_idx``.
            cast: demo scenarios.
            device: device that every batch is moved to.
        """
        self.model = model
        self.schema = schema
        self.histories = histories
        self.cast = cast
        self.device = device
        self.tokenizer = model.backbone.tokenizer
        self.model.eval()

    @classmethod
    def from_paths(
        cls,
        checkpoint_path: Path,
        model_config_path: Path,
        schema_path: Path,
        histories_path: Path,
        cast_path: Path,
        device: torch.device = torch.device("cpu"),
    ) -> "FraudPatternCopilotModel":
        """Build a ready-to-use instance from files on disk.

        Args:
            checkpoint_path: slim training checkpoint. It must contain a truthy
                ``"model_state_dict_slim"`` entry and a ``"model_state_dict"``.
            model_config_path: YAML model config. ``backbone.hf_path`` is
                required; ``encoder``, ``projector``, ``heads`` and
                ``backbone.lora`` are optional.
            schema_path: schema file passed to ``load_schema``.
            histories_path: ``.npy`` histories array (memory-mapped read-only).
            cast_path: cast JSON, parsed by ``_load_cast``.
            device: target device. CPU runs in float32; any other device uses
                bfloat16 with ``device_map="auto"``.

        Returns:
            A new ``FraudPatternCopilotModel``.

        Raises:
            ValueError: the checkpoint is not flagged as slim.
            RuntimeError: the checkpoint contains keys the model does not have.
        """
        schema = load_schema(schema_path)
        histories = np.load(histories_path, mmap_mode="r")
        cast = _load_cast(cast_path)
        # Decode explicitly instead of relying on the platform locale encoding.
        with model_config_path.open(encoding="utf-8") as f:
            mcfg = yaml.safe_load(f)
        dtype = torch.float32 if device.type == "cpu" else torch.bfloat16
        model = build_transaction_multisurface(
            schema=schema,
            model_path=mcfg["backbone"]["hf_path"],
            encoder_cfg=mcfg.get("encoder"),
            projector_cfg=mcfg.get("projector"),
            head_cfg=mcfg.get("heads"),
            lora_cfg=mcfg["backbone"].get("lora"),
            dtype=dtype,
            device_map=None if device.type == "cpu" else "auto",
        )
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects. Only load checkpoints from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        if not ckpt.get("model_state_dict_slim"):
            raise ValueError(f"Expected a slim checkpoint at {checkpoint_path}.")
        # Cast floating-point tensors to the model dtype. Integer/bool tensors
        # are left untouched.
        state = {
            k: v.to(dtype) if v.is_floating_point() else v
            for k, v in ckpt["model_state_dict"].items()
        }
        # strict=False: a slim checkpoint need not cover every parameter, so
        # missing keys are tolerated. Unexpected keys mean the checkpoint does
        # not match this architecture and are treated as fatal.
        _missing, unexpected = model.load_state_dict(state, strict=False)
        if unexpected:
            raise RuntimeError(f"Unexpected state_dict keys: {unexpected[:5]} ...")
        # NOTE(review): with device_map="auto" the builder may already have
        # dispatched the model; confirm .to() is permitted in that case.
        model.to(device)
        return cls(
            model=model,
            schema=schema,
            histories=histories,
            cast=cast,
            device=device,
        )

    @torch.inference_mode()
    def predict(
        self,
        member: FraudPatternCastMember,
        top_k: int = 5,
    ) -> FraudPatternResult:
        """Run the model on one cast member and decode both heads.

        The global torch RNG is re-seeded with ``DEMO_SEED`` first, so repeated
        calls are reproducible.

        Args:
            member: scenario whose history and context text are fed to the model.
            top_k: number of attribution positions to return. Values larger than
                the number of positions are clamped to that number.

        Returns:
            A ``FraudPatternResult`` containing the per-stage and per-type
            softmax distributions, their argmaxes, the sigmoid attribution
            scores and the top-k positions.

        Raises:
            ValueError: ``top_k`` is negative, or the probability head does not
                emit exactly ``NUM_STAGES + NUM_TYPES`` logits.
        """
        torch.manual_seed(DEMO_SEED)
        batch = self._build_batch(member)
        out = self.model.predict(batch)

        # Batch size is 1, so take row 0. Upcast before converting to numpy,
        # which has no bfloat16 dtype.
        flat = out["probability_logits"][0].float().cpu()
        stage_probs, type_probs = self._split_probability_logits(
            flat, NUM_STAGES, NUM_TYPES,
        )
        attr_logits = out["attribution_logits"][0].float().cpu()
        attr_probs = torch.sigmoid(attr_logits).numpy()
        top_k_positions = self._top_k_positions(attr_logits, top_k)

        return FraudPatternResult(
            stage_probs=stage_probs,
            type_probs=type_probs,
            predicted_stage=int(np.argmax(stage_probs)),
            predicted_type=int(np.argmax(type_probs)),
            attribution_probs=attr_probs,
            top_k_positions=top_k_positions,
            flagged_idx=member.flagged_idx,
        )

    @staticmethod
    def _split_probability_logits(
        flat: torch.Tensor,
        num_stages: int,
        num_types: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Split flat ``[stage | type]`` logits into two softmax distributions.

        Raises:
            ValueError: ``flat`` does not hold exactly ``num_stages + num_types``
                logits. Without this check, extra logits would silently be
                folded into the type softmax.
        """
        expected = num_stages + num_types
        if flat.shape[-1] != expected:
            raise ValueError(
                f"Expected {expected} probability logits "
                f"({num_stages} stage + {num_types} type), got {flat.shape[-1]}."
            )
        stage_probs = torch.softmax(flat[:num_stages], dim=-1).numpy()
        type_probs = torch.softmax(flat[num_stages:expected], dim=-1).numpy()
        return stage_probs, type_probs

    @staticmethod
    def _top_k_positions(attr_logits: torch.Tensor, top_k: int) -> np.ndarray:
        """Return indices of the largest attribution logits, highest first.

        ``top_k`` is clamped to the number of positions, because ``torch.topk``
        raises when ``k`` exceeds the dimension size.

        Raises:
            ValueError: ``top_k`` is negative.
        """
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}.")
        k = min(top_k, attr_logits.shape[-1])
        return torch.topk(attr_logits, k=k, dim=-1).indices.numpy()

    def build_reasoning_text(
        self,
        member: FraudPatternCastMember,
        result: FraudPatternResult,
    ) -> str:
        """Render the full reasoning paragraph for ``member`` and its ``result``."""
        history = np.asarray(self.histories[member.customer_idx])
        return _render_reasoning(history, member.flagged_idx, result)

    def stream_reasoning(
        self,
        member: FraudPatternCastMember,
        result: FraudPatternResult,
        chunk_chars: int = 6,
    ) -> Iterator[str]:
        """Yield growing prefixes of the reasoning text for incremental display.

        Each yielded string is the text so far, ``chunk_chars`` characters
        longer than the previous one. The last yield is always the complete
        text. Nothing is yielded if the text is empty.

        Raises:
            ValueError: ``chunk_chars`` is less than 1. Because this is a
                generator, the error surfaces on the first iteration.
        """
        if chunk_chars < 1:
            raise ValueError(f"chunk_chars must be >= 1, got {chunk_chars}.")
        text = self.build_reasoning_text(member, result)
        for end in range(chunk_chars, len(text) + chunk_chars, chunk_chars):
            yield text[: min(end, len(text))]

    def _build_batch(self, member: FraudPatternCastMember) -> MixedModalityBatch:
        """Assemble a single-example ``MixedModalityBatch`` on ``self.device``."""
        # Copy the row out of the (possibly memory-mapped, read-only) histories
        # so torch.from_numpy receives a private, writable buffer.
        history = np.asarray(self.histories[member.customer_idx]).copy()
        feature_ids = torch.from_numpy(history).long().unsqueeze(0).to(self.device)
        input_ids, attn_mask, lengths = tokenize_texts(
            self.tokenizer, [member.context_text], max_length=256,
        )
        input_ids = input_ids.to(self.device)
        attn_mask = attn_mask.to(self.device)
        lengths = lengths.to(self.device)
        flagged_idx = torch.tensor(
            [member.flagged_idx], dtype=torch.long, device=self.device,
        )
        # The shared batch type names this field ``disputed_idx``; here it
        # carries the upstream-flagged position.
        return MixedModalityBatch(
            feature_ids=feature_ids,
            text_input_ids=input_ids,
            text_attention_mask=attn_mask,
            text_lengths=lengths,
            head_target="probability",
            disputed_idx=flagged_idx,
        )
|
|
|
|
def _load_cast(cast_path: Path) -> list[FraudPatternCastMember]:
    """Load the demo cast from a JSON file.

    The file must contain a top-level ``"cast"`` list. Each entry must carry
    every ``FraudPatternCastMember`` field; the index/label fields are coerced
    with ``int``.

    Args:
        cast_path: path to the cast JSON.

    Returns:
        The cast members, in file order.

    Raises:
        KeyError: the payload or an entry is missing a required key.
    """
    # Hand json.loads the raw bytes so it detects UTF-8/16/32 (and a UTF-8
    # BOM) itself. Path.read_text() without an encoding would decode with the
    # platform locale, which can garble non-ASCII names and descriptions.
    payload = json.loads(cast_path.read_bytes())
    return [
        FraudPatternCastMember(
            pattern=m["pattern"],
            display_name=m["display_name"],
            customer_idx=int(m["customer_idx"]),
            flagged_idx=int(m["flagged_idx"]),
            stage_label=int(m["stage_label"]),
            stage_label_name=m["stage_label_name"],
            type_label=int(m["type_label"]),
            type_label_name=m["type_label_name"],
            description=m["description"],
            context_text=m["context_text"],
        )
        for m in payload["cast"]
    ]
|
|
|
|
def _render_reasoning(
    history: np.ndarray,
    flagged_idx: int,
    result: FraudPatternResult,
) -> str:
    """Render the templated reasoning paragraph for one prediction.

    The paragraph consists of, in order:

    1. the stage/type verdict with its probabilities;
    2. five cross-position signals computed on ``history`` around
       ``flagged_idx``;
    3. a stage-specific recommendation;
    4. a type-specific recommendation;
    5. up to five top attribution positions.

    Args:
        history: the customer's history (one row of the histories array),
            passed unchanged to the ``signal_*`` helpers.
        flagged_idx: position of the upstream-flagged transaction in ``history``.
        result: the prediction to explain.

    Returns:
        The parts joined with single spaces.
    """
    probe_count = signal_probe_density(history, flagged_idx)
    post_count = signal_post_attack_density(history, flagged_idx)
    novel_device = signal_novel_device(history, flagged_idx)
    sig_clean = signal_signature_clean(history, flagged_idx)
    recent_auth = signal_recent_authorize_density(history, flagged_idx)

    stage = result.predicted_stage
    ptype = result.predicted_type
    stage_p = float(result.stage_probs[stage])
    type_p = float(result.type_probs[ptype])

    parts: list[str] = []
    parts.append(
        f"Verdict: stage={STAGE_NAMES[stage]} (P={stage_p:.2f}), "
        f"type={TYPE_NAMES[ptype]} (P={type_p:.2f})."
    )
    parts.append(
        f"Cross-position signals — probe-cluster: {probe_count} small-CNP "
        f"preceding tx{flagged_idx}, post-attack density: {post_count} large "
        f"unfamiliar charges around the flag, novel-device: {novel_device}, "
        f"signature-clean: {sig_clean}, recent-authorize density: {recent_auth}."
    )

    # Stage recommendation. STAGE_PRE_ATTACK (and any stage not matched
    # explicitly) falls through to the final branch.
    if stage == STAGE_PROBING:
        parts.append(
            "Probing pattern detected. The attacker is testing whether the card "
            "works via small charges before escalating. Recommend immediate "
            "containment + step-up auth on subsequent transactions."
        )
    elif stage == STAGE_MONETIZATION:
        parts.append(
            "Monetization pattern: probes succeeded and the attacker is "
            "converting access into value. The flagged transaction is the "
            "first big charge after the probe phase. Recommend full block + "
            "customer outreach within the hour."
        )
    elif stage == STAGE_EXFILTRATION:
        parts.append(
            "Exfiltration pattern: multiple large unfamiliar charges around "
            "the flag indicate a mature attack. Freeze card + notify customer "
            "immediately. Damage may already be material."
        )
    elif stage == STAGE_DORMANT:
        parts.append(
            "Dormant pattern: the flagged tx sits inside the customer's "
            "normal signature. Upstream detector is most likely false-positive. "
            "Release with low-priority follow-up."
        )
    else:
        parts.append(
            "Pre-attack stage: a single anomalous transaction with no chain "
            "evidence yet. Recommend step-up auth and observe next 24 hours."
        )

    # Type recommendation. TYPE_DECLINED_LEGIT (and any type not matched
    # explicitly) falls through to the final branch.
    if ptype == TYPE_ACCOUNT_TAKEOVER:
        parts.append(
            "Type=account_takeover: the device fingerprint at the flag is one "
            "the customer has not used elsewhere. Treat as credential/device "
            "compromise — disable existing sessions and reset auth."
        )
    elif ptype == TYPE_VICTIM_FRAUD:
        parts.append(
            "Type=victim_fraud: device matches the customer's own, but the "
            "transaction pattern is anomalous. Likely the customer was tricked "
            "into authorizing — engage with care, the customer's perception "
            "may be that this is legitimate."
        )
    elif ptype == TYPE_SCAM_REDIRECTED:
        parts.append(
            "Type=scam_redirected: recent history shows many customer-authorized "
            "CNP payments to unfamiliar merchants. Consistent with romance / "
            "impostor scam. Engage customer-care team, not fraud team."
        )
    else:
        parts.append(
            "Type=declined_legit: the flagged tx looks like a normal customer "
            "purchase. The upstream score is plausibly a rules false-positive."
        )

    # At most five positions are shown, whatever top_k predict() was given.
    top_positions = [str(int(p)) for p in result.top_k_positions[:5]]
    if top_positions:
        top_pos = ", ".join(top_positions)
        parts.append(
            f"Top contributing transactions: positions {top_pos}."
        )
    else:
        # Avoid the malformed sentence "positions ." when top_k was 0.
        parts.append("Top contributing transactions: none reported.")

    return " ".join(parts)
|
|