temporal-twins-code / scripts /train_node_benchmark.py
temporal-twins-anon's picture
Add anonymous Temporal Twins code release
a3682cf verified
"""
UPI-Sim Benchmark Runner
=========================
Node-level temporal fraud risk prediction benchmark.
Runs: 3 difficulties × 5 seeds × (TGN + GNN + baselines + ablations)
Reports: mean ± std for ROC-AUC, PR-AUC, Brier Score, ECE, Expected Cost
"""
import os
import sys
import pickle
import time
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler
from src.core.config_loader import load_config
from src.generators.user_generator import generate_users
from src.generators.transaction_generator import generate_transactions
from src.fraud.fraud_engine import FraudEngine
from src.risk.risk_engine import apply_risk_engine
from src.graph.dataset_builder import build_graph_dataset
from src.tgn.train import train_tgn
from src.tgn.memory import Memory
from src.tgn.time_encoding import TimeEncoding
from src.gnn.train import train_gnn
# =========================
# HELPERS
# =========================
def temporal_split(df, train_ratio=0.7):
    """Return the early portion of ``df`` and the timestamp that bounds it.

    The cut-off is the ``train_ratio`` quantile of the ``timestamp`` column.
    Rows whose timestamp is at or before the cut-off are kept.

    Args:
        df: DataFrame with a ``timestamp`` column.
        train_ratio: Quantile (0..1) of ``timestamp`` used as the cut-off.

    Returns:
        ``(past, split_time)``: the retained rows, ordered by timestamp, and
        the cut-off value.
    """
    ordered = df.sort_values("timestamp")
    timestamps = ordered["timestamp"]
    cutoff = timestamps.quantile(train_ratio)
    return ordered.loc[timestamps <= cutoff], cutoff
def build_node_features(df_past, all_nodes):
    """Return an all-zero ``(len(all_nodes), 2)`` float32 feature matrix.

    ``df_past`` is accepted for interface symmetry but is never read: static
    node signal is removed on purpose, so the feature matrix carries no
    information and any discriminative power has to come from elsewhere
    (the original author's intent: the TGN temporal memory).
    """
    n_nodes = len(all_nodes)
    return np.zeros(shape=(n_nodes, 2), dtype=np.float32)
def build_node_labels(df, split_time, all_nodes, horizon=0.05):
    """Label each node by whether it sends a fraud transaction soon after the split.

    The label window is ``(split_time, split_time + horizon * remaining]``,
    where ``remaining`` is the span from ``split_time`` to the latest
    timestamp in ``df``.

    Args:
        df: DataFrame with ``timestamp``, ``sender_id`` and ``is_fraud`` columns.
        split_time: Start of the label window (exclusive).
        all_nodes: Node ids, in the order the labels should be returned.
        horizon: Fraction of the post-split time span covered by the window.

    Returns:
        float32 array aligned with ``all_nodes``: the per-sender maximum of
        ``is_fraud`` inside the window, or 0 for nodes with no transaction
        as sender in the window.
    """
    ts = df["timestamp"]
    window_end = split_time + horizon * (ts.max() - split_time)
    in_window = (ts > split_time) & (ts <= window_end)
    fraud_by_sender = df[in_window].groupby("sender_id")["is_fraud"].max()
    labels = [fraud_by_sender.get(node, 0) for node in all_nodes]
    return np.array(labels, dtype=np.float32)
def compute_ece(y_true, y_prob, n_bins=10):
    """Expected Calibration Error over equal-width confidence bins.

    Predictions are grouped into ``n_bins`` equal-width bins on [0, 1]. For
    each non-empty bin the absolute gap between mean predicted probability
    and observed positive rate is weighted by the share of samples in that
    bin, and the weighted gaps are summed.

    Args:
        y_true: Binary labels (array-like of 0/1).
        y_prob: Predicted probabilities in [0, 1], same length as ``y_true``.
        n_bins: Number of equal-width bins.

    Returns:
        The ECE as a float; 0.0 when there are no samples.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    edges = np.linspace(0, 1, n_bins + 1)
    n_samples = len(y_true)
    last_bin = n_bins - 1

    ece = 0.0
    for idx, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        # Bins are half-open [lo, hi) except the last one, which is closed on
        # the right; otherwise predictions equal to exactly 1.0 would belong
        # to no bin and be silently left out of the score.
        below_hi = (y_prob <= hi) if idx == last_bin else (y_prob < hi)
        mask = (y_prob >= lo) & below_hi
        count = mask.sum()
        if count == 0:
            continue
        frac = count / n_samples
        avg_conf = y_prob[mask].mean()
        avg_acc = y_true[mask].mean()
        ece += frac * abs(avg_conf - avg_acc)
    return float(ece)
def evaluate_metrics(y_true, y_prob):
    """Score probabilistic predictions against binary labels.

    Args:
        y_true: Array of 0/1 labels.
        y_prob: Array of predicted probabilities, aligned with ``y_true``.

    Returns:
        Dict with keys ``roc`` (ROC-AUC), ``pr`` (average precision),
        ``brier`` (Brier score), ``ece`` (expected calibration error) and
        ``cost`` (mean asymmetric expected cost).
    """
    # Asymmetric per-sample cost: probability mass missed on a fraud case is
    # weighted 5x, probability mass assigned to a non-fraud case 1x.
    missed_fraud = (y_true == 1) * (1 - y_prob) * 5
    false_alarm = (y_true == 0) * y_prob * 1
    expected_cost = (missed_fraud + false_alarm).mean()

    return dict(
        roc=roc_auc_score(y_true, y_prob),
        pr=average_precision_score(y_true, y_prob),
        brier=brier_score_loss(y_true, y_prob),
        ece=compute_ece(y_true, y_prob),
        cost=expected_cost,
    )
# =========================
# TGN NODE CLASSIFIER
# =========================
def train_node_classifier(model, memory, x_node, y_node, num_epochs=100):
    """Fit only ``model.node_classifier`` on detached memory plus node features.

    All parameters of ``model`` are frozen except the classifier head, which
    is trained full-batch with Adam (lr=1e-3) and a class-weighted BCE loss.
    On exit every parameter of ``model`` has ``requires_grad=True`` and the
    model is left in training mode.

    Args:
        model: Module exposing ``parameters()``, ``train()`` and a
            ``node_classifier`` sub-module.
        memory: Object whose ``memory`` attribute is a tensor with one row per
            node; it is concatenated with the features along dim 1.
        x_node: Node feature array; z-scored per column before use.
        y_node: Binary node labels.
        num_epochs: Number of full-batch optimisation steps.
    """
    device = torch.device("cpu")
    feats = torch.tensor(x_node, dtype=torch.float32).to(device)
    feats = (feats - feats.mean(dim=0)) / (feats.std(dim=0) + 1e-6)
    targets = torch.tensor(y_node, dtype=torch.float32).to(device)

    # Freeze everything, then re-enable gradients for the classifier head only.
    head = model.node_classifier
    for p in model.parameters():
        p.requires_grad = False
    for p in head.parameters():
        p.requires_grad = True

    optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)

    # Positive-class weight = negatives / positives, capped at 10.
    n_neg = (targets == 0).sum().float()
    n_pos = (targets == 1).sum().float()
    pos_weight = torch.clamp(n_neg / n_pos, max=10.0)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    model.train()
    for _ in range(num_epochs):
        # Memory is detached so no gradient flows back into it.
        inputs = torch.cat([memory.memory.detach(), feats], dim=1)
        loss = criterion(head(inputs).squeeze(-1), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    for p in model.parameters():
        p.requires_grad = True
def evaluate_tgn_node(model, memory, x_node, y_node, ablation=None):
    """Score the TGN node classifier, optionally with one input zeroed out.

    Args:
        model: Module exposing ``eval()`` and a ``node_classifier`` sub-module.
        memory: Object whose ``memory`` attribute is a tensor with one row per
            node.
        x_node: Node feature array; z-scored per column before use.
        y_node: Binary node labels (copied, never modified).
        ablation: ``"no_memory"`` replaces the memory embeddings with zeros,
            ``"no_features"`` replaces the normalised features with zeros;
            any other value leaves both inputs untouched.

    Returns:
        The metrics dict produced by ``evaluate_metrics``.
    """
    device = torch.device("cpu")
    feats = torch.tensor(x_node, dtype=torch.float32).to(device)
    feats = (feats - feats.mean(dim=0)) / (feats.std(dim=0) + 1e-6)
    labels = y_node.copy()

    model.eval()
    with torch.no_grad():
        embeddings = memory.memory.clone()

        # Ablations: blank out exactly one of the two classifier inputs.
        if ablation == "no_memory":
            embeddings = torch.zeros_like(embeddings)
        elif ablation == "no_features":
            feats = torch.zeros_like(feats)

        classifier_in = torch.cat([embeddings, feats], dim=1)
        raw_scores = model.node_classifier(classifier_in).squeeze(-1)
        probs = torch.sigmoid(raw_scores).cpu().numpy()

    return evaluate_metrics(labels, probs)
def evaluate_gnn_node(model, graph_data, x_node, y_node):
    """Turn the GNN's per-edge outputs into per-node scores and evaluate them.

    The model is called on z-scored node features and edge attributes. Its
    output is passed through a sigmoid and averaged over the edges grouped
    by the node in the first row of ``edge_index`` (presumably the
    sender — confirm against the graph builder). Nodes with no such edge
    get a score of 0.

    Args:
        model: Edge-level model called as
            ``model(x, edge_index, edge_attr, edge_index[0], edge_index[1])``.
        graph_data: Mapping with ``edge_index`` and ``edge_attr`` entries.
        x_node: Node feature array.
        y_node: Binary node labels (copied, never modified).

    Returns:
        The metrics dict produced by ``evaluate_metrics``.
    """
    device = torch.device("cpu")

    def _zscore(t):
        # Column-wise standardisation with a small epsilon against zero std.
        return (t - t.mean(dim=0)) / (t.std(dim=0) + 1e-6)

    edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long).to(device)
    edge_attr = _zscore(
        torch.tensor(graph_data["edge_attr"], dtype=torch.float32).to(device)
    )
    feats = _zscore(torch.tensor(x_node, dtype=torch.float32).to(device))
    labels = y_node.copy()

    model.eval()
    with torch.no_grad():
        src, dst = edge_index[0], edge_index[1]
        n_nodes = feats.shape[0]
        edge_probs = torch.sigmoid(model(feats, edge_index, edge_attr, src, dst))

        # Mean edge probability per first-row node: sum, then divide by count.
        score_sum = torch.zeros(n_nodes, device=device)
        score_sum.index_add_(0, src, edge_probs)
        edge_count = torch.bincount(src, minlength=n_nodes).float() + 1e-6
        node_scores = score_sum / edge_count

    return evaluate_metrics(labels, node_scores.cpu().numpy())
# =========================
# BASELINES
# =========================
def run_baselines(x_node, y_node):
    """Fit three non-graph classifiers on the node features and score them.

    Features are standardised once, then each estimator is fitted on the full
    set and scored on the same rows it was fitted on (there is no held-out
    split inside this function).

    Args:
        x_node: Node feature array.
        y_node: Binary node labels.

    Returns:
        Dict mapping ``"LogReg"``, ``"XGBoost"`` and ``"MLP"`` to the metrics
        dict produced by ``evaluate_metrics``.
    """
    features = StandardScaler().fit_transform(x_node)

    estimators = {
        "LogReg": LogisticRegression(max_iter=500, class_weight="balanced"),
        # Reported under the "XGBoost" key, but the estimator is
        # scikit-learn's GradientBoostingClassifier.
        "XGBoost": GradientBoostingClassifier(
            n_estimators=100, max_depth=4, random_state=42
        ),
        "MLP": MLPClassifier(
            hidden_layer_sizes=(64, 32), max_iter=300, random_state=42
        ),
    }

    results = {}
    for name, estimator in estimators.items():
        estimator.fit(features, y_node)
        positive_probs = estimator.predict_proba(features)[:, 1]
        results[name] = evaluate_metrics(y_node, positive_probs)
    return results
# =========================
# SINGLE DIFFICULTY RUN
# =========================
def run_single(difficulty, config, users, seed=42):
    """Run the full benchmark once for one difficulty and one seed.

    Seeds torch and NumPy, generates transactions, applies the risk and fraud
    engines, builds the graph dataset and node-level features and labels, then
    trains and evaluates the TGN (plus ablations), the GNN and the baselines.

    Args:
        difficulty: Difficulty setting forwarded to ``FraudEngine``.
        config: Configuration forwarded to the generators and risk engine.
        users: User table forwarded to the generators and graph builder.
        seed: Seed for torch, NumPy and ``FraudEngine``.

    Returns:
        Dict with ``"node_fraud"`` (mean node label) and one metrics dict per
        model name.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Data pipeline: generate -> risk engine -> fraud injection -> time order.
    transactions = generate_transactions(users, config)
    transactions = apply_risk_engine(transactions, users, config)
    fraud_engine = FraudEngine(seed=seed, difficulty=difficulty)
    transactions = fraud_engine.apply(transactions)
    transactions = transactions.sort_values("timestamp").reset_index(drop=True)

    graph_data = build_graph_dataset(transactions, users)
    past, split_time = temporal_split(transactions)
    all_nodes = sorted(transactions["sender_id"].unique())
    x_node = build_node_features(past, all_nodes)
    y_node = build_node_labels(transactions, split_time, all_nodes, horizon=0.05)

    results = {"node_fraud": y_node.mean()}

    # ----- TGN and its ablations -----
    tgn_model, memory, _, _ = train_tgn(graph_data, num_epochs=3)
    train_node_classifier(tgn_model, memory, x_node, y_node, num_epochs=100)
    tgn_variants = {
        "TGN": None,
        "TGN-no-mem": "no_memory",
        "TGN-no-feat": "no_features",
    }
    for name, ablation in tgn_variants.items():
        results[name] = evaluate_tgn_node(
            tgn_model, memory, x_node, y_node, ablation=ablation
        )

    # ----- GNN -----
    gnn_model = train_gnn(graph_data)
    results["GNN"] = evaluate_gnn_node(gnn_model, graph_data, x_node, y_node)

    # ----- Baselines -----
    results.update(run_baselines(x_node, y_node))
    return results
# =========================
# MAIN
# =========================
# Seeds 42..46: one full benchmark run per (difficulty, seed) pair.
SEEDS = list(range(42, 47))

# Difficulty settings, iterated in this order.
DIFFICULTIES = [
    "easy",
    "medium",
    "hard",
]

# Result keys reported per run, in table row order.
MODELS = [
    "TGN",
    "TGN-no-mem",
    "TGN-no-feat",
    "GNN",
    "LogReg",
    "XGBoost",
    "MLP",
]

# Metric keys reported per model, in table column order.
METRICS = [
    "roc",
    "pr",
    "brier",
    "ece",
    "cost",
]
def main():
    """Run every (difficulty, seed) combination and print the summary tables.

    Loads ``config/default.yaml``, generates the user table once, then calls
    ``run_single`` for each difficulty in ``DIFFICULTIES`` and each seed in
    ``SEEDS``. Per-run metrics are collected as
    ``{difficulty: {model: {metric: [values]}}}`` and reported as mean ± std.
    """
    config = load_config("config/default.yaml")
    users = generate_users(config)

    # Store all results: {difficulty: {model: {metric: [values]}}}
    all_results = {}
    for diff in DIFFICULTIES:
        all_results[diff] = {m: {k: [] for k in METRICS} for m in MODELS}
        fraud_rates = []
        for seed in SEEDS:
            print(f"\n{'='*50}")
            print(f" {diff.upper()} | seed={seed}")
            print(f"{'='*50}")
            r = run_single(diff, config, users, seed=seed)
            fraud_rates.append(r["node_fraud"])
            for model in MODELS:
                for metric in METRICS:
                    all_results[diff][model][metric].append(r[model][metric])
        avg_fraud = np.mean(fraud_rates)
        print(f"\n {diff} avg node fraud: {avg_fraud:.1%}")

    _print_results_table(all_results)
    _print_gap_summary(all_results)


def _print_results_table(all_results):
    """Print one mean ± std table per difficulty, one row per model."""
    print("\n")
    print("=" * 100)
    print(" UPI-Sim BENCHMARK: Node-Level Fraud Risk Prediction")
    # The seed count is taken from SEEDS so the banner cannot drift from the
    # number of runs actually aggregated.
    print(f" Task: predict user fraud in future window | {len(SEEDS)} seeds | mean ± std")
    print("=" * 100)
    for diff in DIFFICULTIES:
        print(f"\n--- {diff.upper()} ---")
        print(f"{'Model':<14} {'ROC-AUC':>14} {'PR-AUC':>14} {'Brier':>14} {'ECE':>14} {'Cost':>14}")
        print("-" * 88)
        for model in MODELS:
            row = []
            for metric in METRICS:
                vals = all_results[diff][model][metric]
                m, s = np.mean(vals), np.std(vals)
                row.append(f"{m:.4f}±{s:.4f}")
            print(f"{model:<14} {row[0]:>14} {row[1]:>14} {row[2]:>14} {row[3]:>14} {row[4]:>14}")


def _print_gap_summary(all_results):
    """Print the per-difficulty ROC-AUC gap of TGN over GNN and over XGBoost."""
    print(f"\n{'='*65}")
    print(f" DIFFICULTY SCALING LAW: TGN Advantage (Δ ROC-AUC)")
    print(f"{'='*65}")
    print(f"{'Difficulty':<14} | {'Δ(TGN - GNN)':>15} | {'Δ(TGN - XGBoost)':>15}")
    print("-" * 52)
    for diff in DIFFICULTIES:
        tgn_rocs = all_results[diff]["TGN"]["roc"]
        gnn_rocs = all_results[diff]["GNN"]["roc"]
        xgb_rocs = all_results[diff]["XGBoost"]["roc"]
        # Gaps are paired per seed before averaging.
        gaps_gnn = [t - g for t, g in zip(tgn_rocs, gnn_rocs)]
        gaps_xgb = [t - x for t, x in zip(tgn_rocs, xgb_rocs)]
        gnn_str = f"{np.mean(gaps_gnn):+.4f} ± {np.std(gaps_gnn):.4f}"
        xgb_str = f"{np.mean(gaps_xgb):+.4f} ± {np.std(gaps_xgb):.4f}"
        print(f"{diff:<14} | {gnn_str:>15} | {xgb_str:>15}")


if __name__ == "__main__":
    main()