lfm2-transaction-encoder / encoder /src /demo /copilot_inference_fraud_pattern.py
cdotsanghvi's picture
initial transaction co-pilot deployment
b3112c7
Raw
History Blame Contribute Delete
12.5 kB
"""Inference plumbing for the Fraud Pattern Co-Pilot demo.
Mirror of copilot_inference_collections.py, adapted to the two-head
Fraud output:

- the probability head returns (B, 9) flat logits = 5 stage + 4 type
- we split them into per-stage and per-type softmax distributions
- the reasoning template is grounded in (stage, type, cross-position signals)
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator
import numpy as np
import torch
import yaml
from src.data.schema import SchemaConfig, load_schema
from encoder.src.data.mixed_modality import MixedModalityBatch, tokenize_texts
from encoder.src.data.synthetic_fraud_pattern import (
NUM_STAGES,
NUM_TYPES,
STAGE_DORMANT,
STAGE_EXFILTRATION,
STAGE_MONETIZATION,
STAGE_NAMES,
STAGE_PRE_ATTACK,
STAGE_PROBING,
TYPE_ACCOUNT_TAKEOVER,
TYPE_DECLINED_LEGIT,
TYPE_NAMES,
TYPE_SCAM_REDIRECTED,
TYPE_VICTIM_FRAUD,
signal_novel_device,
signal_post_attack_density,
signal_probe_density,
signal_recent_authorize_density,
signal_signature_clean,
)
from encoder.src.model.transaction_fm_multisurface import (
TransactionMultiSurfaceModel,
build_transaction_multisurface,
)
# Fixed RNG seed; FraudPatternCopilotModel.predict passes it to
# torch.manual_seed before every inference call.
DEMO_SEED = 42
# Per-stage display color. Severity scale: DORMANT (green, OK) →
# PRE_ATTACK (yellow) → PROBING (amber) → MONETIZATION (red) →
# EXFILTRATION (dark red).
# NOTE: the entries below are not listed in severity order (DORMANT is last).
STAGE_COLORS: dict[int, str] = {
    STAGE_PRE_ATTACK: "#eab308",  # yellow
    STAGE_PROBING: "#f59e0b",  # amber
    STAGE_MONETIZATION: "#ef4444",  # red
    STAGE_EXFILTRATION: "#b91c1c",  # dark red
    STAGE_DORMANT: "#22c55e",  # green
}
# Per-type display color. Distinguish by attack-type semantics.
# DECLINED_LEGIT uses the same green as STAGE_DORMANT.
TYPE_COLORS: dict[int, str] = {
    TYPE_VICTIM_FRAUD: "#7c3aed",  # purple — customer manipulated
    TYPE_ACCOUNT_TAKEOVER: "#dc2626",  # red — compromise
    TYPE_SCAM_REDIRECTED: "#a855f7",  # violet — scam
    TYPE_DECLINED_LEGIT: "#22c55e",  # green — false positive
}
@dataclass
class FraudPatternCastMember:
    """One entry of the demo cast, as loaded from the cast JSON file.

    Field order is part of the constructor signature and must not change.
    """

    # Identification / display.
    pattern: str
    display_name: str

    # Row index into the histories array held by the co-pilot model.
    customer_idx: int
    # Position of the upstream-flagged transaction inside that history.
    flagged_idx: int

    # Stage / type labels stored with the cast entry
    # (presumably the ground-truth labels — confirm against the cast builder).
    stage_label: int
    stage_label_name: str
    type_label: int
    type_label_name: str

    # Free-text description of the cast member.
    description: str
    # Text that is tokenized and fed to the model alongside the history.
    context_text: str
@dataclass
class FraudPatternResult:
    """Outcome of a single inference call.

    Attributes:
        stage_probs: (NUM_STAGES,) softmax distribution over the stage logits.
        type_probs: (NUM_TYPES,) softmax distribution over the type logits.
        predicted_stage: index of the most probable stage (argmax).
        predicted_type: index of the most probable type (argmax).
        attribution_probs: (64,) per-position contribution scores.
        top_k_positions: (k,) indices of the strongest contributors, sorted.
        flagged_idx: upstream-flagged position, passed through unchanged.
    """

    # Two-head classification output.
    stage_probs: np.ndarray
    type_probs: np.ndarray
    predicted_stage: int
    predicted_type: int

    # Attribution output.
    attribution_probs: np.ndarray
    top_k_positions: np.ndarray

    # Echo of the request.
    flagged_idx: int
class FraudPatternCopilotModel:
    """Encapsulates the Fraud Pattern multi-surface model + cast + tokenizer.

    Holds the model (switched to eval mode), the schema, the per-customer
    history array, the demo cast and the target device, and exposes
    prediction plus reasoning-text rendering for one cast member at a time.
    """

    def __init__(
        self,
        model: TransactionMultiSurfaceModel,
        schema: SchemaConfig,
        histories: np.ndarray,
        cast: list[FraudPatternCastMember],
        device: torch.device,
    ) -> None:
        """Store the components and put the model into eval mode.

        Args:
            model: Multi-surface model; its ``backbone.tokenizer`` is reused
                for tokenizing the context text.
            schema: Schema the model was built with.
            histories: Array indexed by ``customer_idx``.
            cast: Demo cast members.
            device: Device that batch tensors are moved to.
        """
        self.model = model
        self.schema = schema
        self.histories = histories
        self.cast = cast
        self.device = device
        self.tokenizer = model.backbone.tokenizer
        self.model.eval()

    @classmethod
    def from_paths(
        cls,
        checkpoint_path: Path,
        model_config_path: Path,
        schema_path: Path,
        histories_path: Path,
        cast_path: Path,
        device: torch.device = torch.device("cpu"),
    ) -> "FraudPatternCopilotModel":
        """Build a co-pilot model from on-disk artifacts.

        Args:
            checkpoint_path: Slim checkpoint file (loaded with ``torch.load``).
            model_config_path: YAML model configuration.
            schema_path: Schema file passed to ``load_schema``.
            histories_path: ``.npy`` file, opened memory-mapped read-only.
            cast_path: JSON file describing the demo cast.
            device: Target device; float32 is used on CPU, bfloat16 otherwise.

        Returns:
            A ready-to-use ``FraudPatternCopilotModel``.

        Raises:
            ValueError: If the checkpoint is not flagged as slim.
            RuntimeError: If the checkpoint contains keys the model lacks.
        """
        schema = load_schema(schema_path)
        histories = np.load(histories_path, mmap_mode="r")
        cast = _load_cast(cast_path)
        # Explicit encoding so config parsing does not depend on the locale.
        with model_config_path.open(encoding="utf-8") as f:
            mcfg = yaml.safe_load(f)
        dtype = torch.float32 if device.type == "cpu" else torch.bfloat16
        model = build_transaction_multisurface(
            schema=schema,
            model_path=mcfg["backbone"]["hf_path"],
            encoder_cfg=mcfg.get("encoder"),
            projector_cfg=mcfg.get("projector"),
            head_cfg=mcfg.get("heads"),
            lora_cfg=mcfg["backbone"].get("lora"),
            dtype=dtype,
            device_map=None if device.type == "cpu" else "auto",
        )
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only load checkpoints from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        if not ckpt.get("model_state_dict_slim"):
            raise ValueError(f"Expected a slim checkpoint at {checkpoint_path}.")
        # Cast floating-point tensors to the model dtype; leave integer
        # buffers untouched.
        state = {
            k: v.to(dtype) if v.is_floating_point() else v
            for k, v in ckpt["model_state_dict"].items()
        }
        # strict=False: missing keys are tolerated (presumably the slim
        # checkpoint omits weights that come from the pretrained backbone —
        # confirm), but unexpected keys indicate a model/checkpoint mismatch.
        _missing, unexpected = model.load_state_dict(state, strict=False)
        if unexpected:
            raise RuntimeError(f"Unexpected state_dict keys: {unexpected[:5]} ...")
        model.to(device)
        return cls(model=model, schema=schema, histories=histories,
                   cast=cast, device=device)

    @torch.inference_mode()
    def predict(
        self,
        member: FraudPatternCastMember,
        top_k: int = 5,
    ) -> FraudPatternResult:
        """Run one inference pass for ``member``.

        Args:
            member: Cast member whose history and context text are scored.
            top_k: Number of top attribution positions to return. Clamped to
                the number of available positions so an oversized value does
                not make ``torch.topk`` fail.

        Returns:
            Stage/type distributions, their argmax indices, per-position
            attribution scores and the top contributing positions.
        """
        # Re-seed the global torch RNG on every call with the fixed demo seed.
        torch.manual_seed(DEMO_SEED)
        batch = self._build_batch(member)
        out = self.model.predict(batch)
        # prob_logits: (1, 9). First 5 = stage logits, last 4 = type logits.
        flat = out["probability_logits"][0].float().cpu()
        stage_logits = flat[:NUM_STAGES]
        # Bounded slice so the type head always has exactly NUM_TYPES entries.
        type_logits = flat[NUM_STAGES:NUM_STAGES + NUM_TYPES]
        stage_probs = torch.softmax(stage_logits, dim=-1).numpy()
        type_probs = torch.softmax(type_logits, dim=-1).numpy()
        attr_logits = out["attribution_logits"][0].float().cpu()
        attr_probs = torch.sigmoid(attr_logits).numpy()
        # torch.topk raises when k exceeds the dimension size.
        k = min(top_k, attr_logits.shape[-1])
        top_k_positions = torch.topk(attr_logits, k=k, dim=-1).indices.numpy()
        return FraudPatternResult(
            stage_probs=stage_probs,
            type_probs=type_probs,
            predicted_stage=int(np.argmax(stage_probs)),
            predicted_type=int(np.argmax(type_probs)),
            attribution_probs=attr_probs,
            top_k_positions=top_k_positions,
            flagged_idx=member.flagged_idx,
        )

    def build_reasoning_text(
        self,
        member: FraudPatternCastMember,
        result: FraudPatternResult,
    ) -> str:
        """Render the full reasoning text for ``member`` given ``result``."""
        history = np.asarray(self.histories[member.customer_idx])
        return _render_reasoning(history, member.flagged_idx, result)

    def stream_reasoning(
        self,
        member: FraudPatternCastMember,
        result: FraudPatternResult,
        chunk_chars: int = 6,
    ) -> Iterator[str]:
        """Yield growing prefixes of the reasoning text.

        Each yielded string is the text so far (cumulative, not a delta),
        growing by ``chunk_chars`` characters per step; the last item is the
        complete text. An empty text yields nothing.

        Raises:
            ValueError: If ``chunk_chars`` is not positive. Because this is a
                generator, the error surfaces on first iteration.
        """
        if chunk_chars <= 0:
            raise ValueError(
                f"chunk_chars must be a positive integer, got {chunk_chars}"
            )
        text = self.build_reasoning_text(member, result)
        # Slicing clamps to len(text), so the final prefix is the whole text.
        for end in range(chunk_chars, len(text) + chunk_chars, chunk_chars):
            yield text[:end]

    def _build_batch(self, member: FraudPatternCastMember) -> MixedModalityBatch:
        """Assemble the single-item model input for ``member``."""
        # Copy first: histories may be a read-only memory map (see
        # from_paths), which torch.from_numpy does not handle cleanly.
        history = np.asarray(self.histories[member.customer_idx]).copy()
        feature_ids = torch.from_numpy(history).long().unsqueeze(0).to(self.device)
        input_ids, attn_mask, lengths = tokenize_texts(
            self.tokenizer, [member.context_text], max_length=256,
        )
        input_ids = input_ids.to(self.device)
        attn_mask = attn_mask.to(self.device)
        lengths = lengths.to(self.device)
        flagged_idx = torch.tensor(
            [member.flagged_idx], dtype=torch.long, device=self.device,
        )
        return MixedModalityBatch(
            feature_ids=feature_ids,
            text_input_ids=input_ids,
            text_attention_mask=attn_mask,
            text_lengths=lengths,
            head_target="probability",
            # The batch field is named disputed_idx; here it carries the
            # flagged position.
            disputed_idx=flagged_idx,
        )
def _load_cast(cast_path: Path) -> list[FraudPatternCastMember]:
    """Load the demo cast from a JSON file.

    The file must contain a top-level ``"cast"`` list whose items carry every
    ``FraudPatternCastMember`` field; index and label fields are coerced to
    ``int``.

    Args:
        cast_path: Path to the cast JSON file.

    Returns:
        Cast members in file order.

    Raises:
        KeyError: If ``"cast"`` or any required member field is missing.
    """
    # Decode explicitly as UTF-8 so non-ASCII descriptions/context text are
    # read identically regardless of the platform's locale encoding.
    payload = json.loads(cast_path.read_text(encoding="utf-8"))
    return [
        FraudPatternCastMember(
            pattern=m["pattern"],
            display_name=m["display_name"],
            customer_idx=int(m["customer_idx"]),
            flagged_idx=int(m["flagged_idx"]),
            stage_label=int(m["stage_label"]),
            stage_label_name=m["stage_label_name"],
            type_label=int(m["type_label"]),
            type_label_name=m["type_label_name"],
            description=m["description"],
            context_text=m["context_text"],
        )
        for m in payload["cast"]
    ]
def _stage_guidance(stage: int) -> str:
    """Return the recommendation paragraph for the predicted stage.

    Any stage other than probing / monetization / exfiltration / dormant
    falls through to the pre-attack wording.
    """
    if stage == STAGE_PROBING:
        return (
            "Probing pattern detected. The attacker is testing whether the card "
            "works via small charges before escalating. Recommend immediate "
            "containment + step-up auth on subsequent transactions."
        )
    if stage == STAGE_MONETIZATION:
        return (
            "Monetization pattern: probes succeeded and the attacker is "
            "converting access into value. The flagged transaction is the "
            "first big charge after the probe phase. Recommend full block + "
            "customer outreach within the hour."
        )
    if stage == STAGE_EXFILTRATION:
        return (
            "Exfiltration pattern: multiple large unfamiliar charges around "
            "the flag indicate a mature attack. Freeze card + notify customer "
            "immediately. Damage may already be material."
        )
    if stage == STAGE_DORMANT:
        return (
            "Dormant pattern: the flagged tx sits inside the customer's "
            "normal signature. Upstream detector is most likely false-positive. "
            "Release with low-priority follow-up."
        )
    return (
        "Pre-attack stage: a single anomalous transaction with no chain "
        "evidence yet. Recommend step-up auth and observe next 24 hours."
    )


def _type_guidance(ptype: int) -> str:
    """Return the explanation paragraph for the predicted attack type.

    Any type other than account-takeover / victim-fraud / scam-redirected
    falls through to the declined-legit wording.
    """
    if ptype == TYPE_ACCOUNT_TAKEOVER:
        return (
            "Type=account_takeover: the device fingerprint at the flag is one "
            "the customer has not used elsewhere. Treat as credential/device "
            "compromise — disable existing sessions and reset auth."
        )
    if ptype == TYPE_VICTIM_FRAUD:
        return (
            "Type=victim_fraud: device matches the customer's own, but the "
            "transaction pattern is anomalous. Likely the customer was tricked "
            "into authorizing — engage with care, the customer's perception "
            "may be that this is legitimate."
        )
    if ptype == TYPE_SCAM_REDIRECTED:
        return (
            "Type=scam_redirected: recent history shows many customer-authorized "
            "CNP payments to unfamiliar merchants. Consistent with romance / "
            "impostor scam. Engage customer-care team, not fraud team."
        )
    return (
        "Type=declined_legit: the flagged tx looks like a normal customer "
        "purchase. The upstream score is plausibly a rules false-positive."
    )


def _render_reasoning(
    history: np.ndarray,
    flagged_idx: int,
    result: FraudPatternResult,
) -> str:
    """Compose the templated reasoning text for one prediction.

    The text is five space-joined sentences/paragraphs: the verdict with its
    probabilities, the cross-position signal summary, stage guidance, type
    guidance, and the (at most five) top contributing positions.
    """
    # Cross-position signals, all evaluated around the flagged position.
    n_probes = signal_probe_density(history, flagged_idx)
    n_post_attack = signal_post_attack_density(history, flagged_idx)
    device_is_novel = signal_novel_device(history, flagged_idx)
    signature_is_clean = signal_signature_clean(history, flagged_idx)
    n_recent_auth = signal_recent_authorize_density(history, flagged_idx)

    stage_idx = result.predicted_stage
    type_idx = result.predicted_type
    p_stage = float(result.stage_probs[stage_idx])
    p_type = float(result.type_probs[type_idx])

    verdict = (
        f"Verdict: stage={STAGE_NAMES[stage_idx]} (P={p_stage:.2f}), "
        f"type={TYPE_NAMES[type_idx]} (P={p_type:.2f})."
    )
    signals = (
        f"Cross-position signals — probe-cluster: {n_probes} small-CNP "
        f"preceding tx{flagged_idx}, post-attack density: {n_post_attack} large "
        f"unfamiliar charges around the flag, novel-device: {device_is_novel}, "
        f"signature-clean: {signature_is_clean}, recent-authorize density: {n_recent_auth}."
    )
    positions = ", ".join(map(str, map(int, result.top_k_positions[:5])))
    contributors = f"Top contributing transactions: positions {positions}."

    return " ".join(
        [
            verdict,
            signals,
            _stage_guidance(stage_idx),
            _type_guidance(type_idx),
            contributors,
        ]
    )