Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- app.py +394 -0
- requirements.txt +7 -0
app.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import tempfile
|
| 4 |
+
from typing import List, Dict, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
from sklearn.metrics import average_precision_score, precision_recall_fscore_support, roc_auc_score
|
| 14 |
+
from torch_geometric.data import Data
|
| 15 |
+
from torch_geometric.nn import SAGEConv
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Hub repository holding the trained GNN weights + config; overridable via env.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "walter-taya/aml-gnn-ibm-baseline-medium")
MODEL_REPO_TYPE = "model"
# Run on GPU when the Space has one; everything is moved to DEVICE explicitly.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Flexible schema detection: each list contains known aliases (matched
# case/space-insensitively by find_column) for one logical column across
# common AML datasets (e.g. PaySim-style "nameorig"/"namedest").
# List order expresses match priority.
TARGET_COL_CANDIDATES = ["is_laundering", "is laundering", "is_sar", "label", "target", "y"]
SRC_COL_CANDIDATES = ["from_account", "originator", "sender", "nameorig", "account", "src", "source"]
DST_COL_CANDIDATES = ["to_account", "beneficiary", "receiver", "namedest", "account.1", "dst", "target_account"]
AMOUNT_COL_CANDIDATES = ["amount", "amount_paid", "payment_amount", "transaction_amount", "amt"]
TIME_COL_CANDIDATES = ["timestamp", "step", "time", "date", "tran_date", "transaction_date"]
CURRENCY_COL_CANDIDATES = ["currency", "payment_currency", "ccy", "cur"]
PAYMENT_TYPE_COL_CANDIDATES = ["payment_type", "type", "transaction_type", "channel", "payment_method"]
CHANNEL_COL_CANDIDATES = ["channel", "delivery_channel", "device_channel", "network"]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _normalize_name(col: str) -> str:
|
| 33 |
+
return str(col).strip().lower().replace(" ", "_")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def find_column(df: pd.DataFrame, candidates: List[str], required: bool = True) -> Optional[str]:
    """Resolve the first candidate that matches a DataFrame column.

    Matching is case- and space-insensitive (spaces become underscores).
    Returns the ORIGINAL column name so the caller can index `df` directly.
    Raises KeyError when nothing matches and `required` is True; otherwise
    returns None.
    """
    def canon(name) -> str:
        return str(name).strip().lower().replace(" ", "_")

    lookup = {canon(original): original for original in df.columns}
    hit = next((lookup[canon(c)] for c in candidates if canon(c) in lookup), None)
    if hit is None and required:
        raise KeyError(f"None of the candidate columns found: {candidates}")
    return hit
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def to_binary_label(series: pd.Series) -> pd.Series:
    """Coerce a heterogeneous label column to 0/1 integers.

    Booleans map directly, numeric values are positive -> 1, and anything
    else is stringified and matched case-insensitively against a vocabulary
    of known "positive" tokens.
    """
    if series.dtype == bool:
        return series.astype(int)
    if np.issubdtype(series.dtype, np.number):
        return (series > 0).astype(int)
    positive_tokens = {"1", "true", "yes", "sar", "laundering", "suspicious", "fraud"}
    normalized = series.astype(str).str.strip().str.lower()
    return normalized.isin(positive_tokens).astype(int)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def safe_amount(x):
    """Best-effort float parse of an amount cell.

    Thousands separators (commas) are stripped; NaN/None or anything
    unparseable yields 0.0 so downstream arithmetic never sees NaN.
    """
    if pd.isna(x):
        return 0.0
    try:
        value = float(str(x).replace(",", ""))
    except Exception:
        value = 0.0
    return value
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def add_time_features(df: pd.DataFrame) -> pd.DataFrame:
    """Derive integer `hour` and `dayofweek` columns from `time` (mutates df).

    Numeric times are treated as simulation steps (step % 24 -> hour,
    step // 24 % 7 -> day); anything else is parsed as a datetime with
    unparseable values coerced to 0. A missing `time` column yields
    constant-zero features.
    """
    if "time" not in df.columns:
        df["hour"] = 0
        df["dayofweek"] = 0
        return df

    times = df["time"]
    if np.issubdtype(times.dtype, np.number):
        steps = times.astype(float)
        df["hour"] = (steps % 24).fillna(0).astype(int)
        df["dayofweek"] = ((steps // 24) % 7).fillna(0).astype(int)
    else:
        parsed = pd.to_datetime(times, errors="coerce")
        df["hour"] = parsed.dt.hour.fillna(0).astype(int)
        df["dayofweek"] = parsed.dt.dayofweek.fillna(0).astype(int)
    return df
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class EdgeGNN(nn.Module):
    """GraphSAGE node encoder + MLP edge scorer for transaction-level AML.

    Two SAGEConv layers embed account nodes; each edge (transaction) is then
    scored from the concatenation [src_emb || dst_emb || edge_features].
    NOTE: the attribute names (conv1, conv2, edge_mlp) and the Sequential
    layout are load-bearing — they define the state_dict keys of the
    published checkpoint and must not change.
    """

    def __init__(self, in_dim: int, edge_dim: int, hidden_dim: int, dropout: float = 0.2):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.dropout = dropout
        self.edge_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2 + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def encode_nodes(self, x, edge_index):
        """Two-layer SAGE embedding; dropout is active only in training mode."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)

    def edge_logits(self, node_emb, edge_index, edge_attr):
        """Raw per-edge score from endpoint embeddings plus edge features."""
        sources, targets = edge_index[0], edge_index[1]
        combined = torch.cat([node_emb[sources], node_emb[targets], edge_attr], dim=1)
        return self.edge_mlp(combined).squeeze(-1)

    def forward(self, x, edge_index, edge_attr):
        """Full pass: embed nodes, then score every edge. Returns logits of shape [E]."""
        return self.edge_logits(self.encode_nodes(x, edge_index), edge_index, edge_attr)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Dict[str, float]:
    """Binary classification metrics at the given probability threshold.

    Always reports precision/recall/f1 (zero_division=0 guards degenerate
    predictions). ROC-AUC and PR-AUC are added only when both classes occur
    in `y_true`, since they are undefined for a single class.
    """
    predictions = (y_prob >= threshold).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, predictions, average="binary", zero_division=0
    )
    results: Dict[str, float] = {
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }
    if len(np.unique(y_true)) > 1:
        results["roc_auc"] = float(roc_auc_score(y_true, y_prob))
        results["pr_auc"] = float(average_precision_score(y_true, y_prob))
    return results
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def prepare_graph(df: pd.DataFrame) -> Tuple[pd.DataFrame, Data, torch.Tensor, Dict[str, str]]:
    """Build a PyG graph from a raw transaction table.

    Accounts become nodes (with aggregate in/out activity features); each
    transaction becomes a directed edge carrying amount/time/categorical
    features.

    Returns:
        (work_df, data, edge_attr, mapping) — work_df is the normalized
        transaction frame (one row per edge, same order as edge_index
        columns), data is the PyG Data object, edge_attr its edge-feature
        tensor, and mapping records which raw columns were detected.

    Raises:
        KeyError: if no source/destination/amount column can be found.
    """
    src_col = find_column(df, SRC_COL_CANDIDATES)
    dst_col = find_column(df, DST_COL_CANDIDATES)
    amount_col = find_column(df, AMOUNT_COL_CANDIDATES)

    label_col = find_column(df, TARGET_COL_CANDIDATES, required=False)
    time_col = find_column(df, TIME_COL_CANDIDATES, required=False)
    currency_col = find_column(df, CURRENCY_COL_CANDIDATES, required=False)
    payment_type_col = find_column(df, PAYMENT_TYPE_COL_CANDIDATES, required=False)
    channel_col = find_column(df, CHANNEL_COL_CANDIDATES, required=False)

    # Select and rename only the columns we use; optional ones may be absent.
    selected = [src_col, dst_col, amount_col]
    rename_to = ["src", "dst", "amount"]
    for col, alias in [
        (label_col, "label"),
        (time_col, "time"),
        (currency_col, "currency"),
        (payment_type_col, "payment_type"),
        (channel_col, "channel"),
    ]:
        if col:
            selected.append(col)
            rename_to.append(alias)

    work_df = df[selected].copy()
    work_df.columns = rename_to
    work_df["amount"] = work_df["amount"].apply(safe_amount).astype(float)
    work_df = work_df.dropna(subset=["src", "dst"]).reset_index(drop=True)
    work_df["src"] = work_df["src"].astype(str)
    work_df["dst"] = work_df["dst"].astype(str)

    if "label" in work_df.columns:
        work_df["label"] = to_binary_label(work_df["label"]).astype(int)

    for col in ["currency", "payment_type", "channel"]:
        if col in work_df.columns:
            # BUGFIX: fillna must run BEFORE astype(str) — astype(str) turns
            # NaN into the literal string "nan", which silently bypassed the
            # "UNK" bucket in the original order of operations.
            work_df[col] = work_df[col].fillna("UNK").astype(str)

    work_df = add_time_features(work_df)

    # Dense integer node ids over the union of source and destination accounts.
    all_accounts = pd.Index(work_df["src"]).append(pd.Index(work_df["dst"])).unique()
    account_to_id = {acc: i for i, acc in enumerate(all_accounts)}
    work_df["src_id"] = work_df["src"].map(account_to_id)
    work_df["dst_id"] = work_df["dst"].map(account_to_id)

    edge_index = torch.tensor(work_df[["src_id", "dst_id"]].to_numpy().T, dtype=torch.long)

    # Continuous edge features: log-compressed amount plus normalized
    # hour-of-day / day-of-week.
    edge_cont = work_df[["amount", "hour", "dayofweek"]].copy()
    edge_cont["amount"] = np.log1p(edge_cont["amount"].clip(lower=0.0))
    edge_cont["hour"] = edge_cont["hour"] / 23.0
    edge_cont["dayofweek"] = edge_cont["dayofweek"] / 6.0

    cat_cols = [c for c in ["currency", "payment_type", "channel"] if c in work_df.columns]
    if cat_cols:
        edge_cat = pd.get_dummies(work_df[cat_cols], prefix=cat_cols, dummy_na=True)
        edge_feat_df = pd.concat([edge_cont, edge_cat], axis=1)
    else:
        edge_feat_df = edge_cont

    # Convert via an explicit float32 ndarray instead of relying on pandas'
    # implicit common-dtype promotion (pandas>=2 one-hot columns are bool,
    # and extension dtypes can otherwise surface as object arrays that
    # torch.tensor() rejects).
    edge_attr = torch.tensor(edge_feat_df.to_numpy(dtype=np.float32), dtype=torch.float32)

    # Per-account aggregates, aligned to the dense node id range.
    num_nodes = len(all_accounts)
    node_df = pd.DataFrame(index=np.arange(num_nodes))
    node_index = node_df.index
    out_grp = work_df.groupby("src_id")
    in_grp = work_df.groupby("dst_id")
    # "Night" = transactions between 00:00 and 05:59; computed once and
    # reused for both directions (the original recomputed it twice).
    night_df = work_df.assign(night=work_df["hour"].isin([0, 1, 2, 3, 4, 5]).astype(int))

    node_df["out_count"] = out_grp.size().reindex(node_index, fill_value=0)
    node_df["in_count"] = in_grp.size().reindex(node_index, fill_value=0)
    node_df["out_amt_sum"] = out_grp["amount"].sum().reindex(node_index, fill_value=0.0)
    node_df["in_amt_sum"] = in_grp["amount"].sum().reindex(node_index, fill_value=0.0)
    node_df["out_amt_mean"] = out_grp["amount"].mean().reindex(node_index, fill_value=0.0)
    node_df["in_amt_mean"] = in_grp["amount"].mean().reindex(node_index, fill_value=0.0)
    node_df["out_hour_mean"] = out_grp["hour"].mean().reindex(node_index, fill_value=0.0)
    node_df["in_hour_mean"] = in_grp["hour"].mean().reindex(node_index, fill_value=0.0)
    node_df["out_night_ratio"] = night_df.groupby("src_id")["night"].mean().reindex(node_index, fill_value=0.0)
    node_df["in_night_ratio"] = night_df.groupby("dst_id")["night"].mean().reindex(node_index, fill_value=0.0)

    # Log-compress the heavy-tailed count/amount features.
    for col in ["out_count", "in_count", "out_amt_sum", "in_amt_sum", "out_amt_mean", "in_amt_mean"]:
        node_df[col] = np.log1p(node_df[col].clip(lower=0.0))

    node_x = torch.tensor(node_df.to_numpy(dtype=np.float32), dtype=torch.float32)
    data = Data(x=node_x, edge_index=edge_index, edge_attr=edge_attr)

    mapping = {
        "src": src_col,
        "dst": dst_col,
        "amount": amount_col,
        "label": label_col or "(not found)",
        "time": time_col or "(not found)",
    }
    return work_df, data, edge_attr, mapping
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def load_model_and_config(repo_id: str):
    """Download config + weights from the Hub and return a ready-to-use model.

    Returns:
        (model, config, default_threshold) — model is on DEVICE in eval mode;
        default_threshold falls back to 0.5 if the config carries no tuned
        "best_threshold".

    Raises:
        KeyError: if config.json lacks in_dim/edge_dim/hidden_dim.
        Exception: network/auth/missing-file errors from hf_hub_download
            propagate unchanged (the caller captures them).
    """
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json", repo_type=MODEL_REPO_TYPE)
    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", repo_type=MODEL_REPO_TYPE)

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    model = EdgeGNN(
        in_dim=int(config["in_dim"]),
        edge_dim=int(config["edge_dim"]),
        hidden_dim=int(config["hidden_dim"]),
        dropout=float(config.get("dropout", 0.2)),
    ).to(DEVICE)
    # weights_only=True: the checkpoint is a plain state_dict, and this blocks
    # arbitrary-code-execution pickles in a downloaded file (supported since
    # torch 1.13; requirements pin torch>=2.2).
    state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    default_threshold = float(config.get("best_threshold", 0.5))
    return model, config, default_threshold
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# Module-level state populated once at import time so the Gradio callback can
# reuse the loaded model across requests.
MODEL = None
MODEL_CONFIG = None
DEFAULT_THRESHOLD = 0.5
MODEL_LOAD_ERROR = None

# Load eagerly at startup. Failures are captured instead of raised so the UI
# still comes up and can show the error message rather than crashing the Space.
try:
    MODEL, MODEL_CONFIG, DEFAULT_THRESHOLD = load_model_and_config(MODEL_REPO_ID)
except Exception as ex:
    MODEL_LOAD_ERROR = str(ex)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def score_transactions(file_obj, threshold: float, top_k: int):
    """Score an uploaded transaction CSV and build all UI outputs.

    Args:
        file_obj: the Gradio upload — depending on Gradio version/config this
            is either a plain filepath string or an object with a `.name`
            path attribute; both are accepted.
        threshold: probability cutoff for flagging a transaction.
        top_k: number of rows to show in each preview table.

    Returns:
        (summary_markdown, top_transactions_df, top_accounts_df, csv_path).
        On any failure the frames are empty and csv_path is None.
    """
    if MODEL_LOAD_ERROR:
        return (
            f"❌ Model load failed from `{MODEL_REPO_ID}`: {MODEL_LOAD_ERROR}",
            pd.DataFrame(),
            pd.DataFrame(),
            None,
        )

    if file_obj is None:
        return "Please upload a CSV file.", pd.DataFrame(), pd.DataFrame(), None

    # BUGFIX: gr.File with its default type="filepath" (Gradio 4/5) passes a
    # str, so the original `file_obj.name` raised AttributeError on every run.
    csv_path = file_obj if isinstance(file_obj, str) else file_obj.name

    try:
        raw_df = pd.read_csv(csv_path)
        work_df, data, _, mapping = prepare_graph(raw_df)

        with torch.no_grad():
            logits = MODEL(
                data.x.to(DEVICE),
                data.edge_index.to(DEVICE),
                data.edge_attr.to(DEVICE),
            )
        probs = torch.sigmoid(logits).detach().cpu().numpy()

        result_df = work_df.copy()
        result_df["pred_prob"] = probs
        result_df["pred_label"] = (result_df["pred_prob"] >= threshold).astype(int)
        result_df = result_df.sort_values("pred_prob", ascending=False).reset_index(drop=True)

        # Account-level view: every account inherits the risk of each
        # transaction it touches, on either side.
        account_alerts = pd.concat(
            [
                result_df[["src", "pred_prob", "pred_label"]].rename(columns={"src": "account"}),
                result_df[["dst", "pred_prob", "pred_label"]].rename(columns={"dst": "account"}),
            ],
            axis=0,
            ignore_index=True,
        )

        account_risk = (
            account_alerts.groupby("account")
            .agg(
                max_txn_risk=("pred_prob", "max"),
                mean_txn_risk=("pred_prob", "mean"),
                txn_count=("pred_prob", "size"),
                pred_account_alert=("pred_label", "max"),
            )
            .reset_index()
            .sort_values("max_txn_risk", ascending=False)
        )

        # Metrics are only meaningful when a ground-truth label with both
        # classes is present; otherwise explain why they are skipped.
        metrics_block = ""
        if "label" in result_df.columns:
            y_true = result_df["label"].to_numpy().astype(int)
            if len(np.unique(y_true)) > 1:
                metrics = compute_metrics(y_true, result_df["pred_prob"].to_numpy(), threshold=threshold)
                metrics_block = "\n".join([f"- **{k}**: {v:.4f}" for k, v in metrics.items()])
            else:
                metrics_block = "- Ground-truth label has only one class; metrics skipped."
        else:
            metrics_block = "- Ground-truth label column not found; showing inference-only outputs."

        summary = (
            f"✅ Scored **{len(result_df):,}** transactions from `{os.path.basename(csv_path)}`\n\n"
            f"**Model repo**: `{MODEL_REPO_ID}` \n"
            f"**Threshold**: `{threshold:.3f}` \n"
            f"**Detected schema**: `{mapping}`\n\n"
            f"**Metrics (if label available)**\n{metrics_block}"
        )

        top_txn = result_df.head(max(1, int(top_k)))
        top_accounts = account_risk.head(max(1, int(top_k)))

        # Persist the full scored table so the download widget can serve it.
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp:
            result_df.to_csv(tmp.name, index=False)
            out_path = tmp.name

        return summary, top_txn, top_accounts, out_path

    except Exception as ex:
        # Broad catch is deliberate at this UI boundary: surface the error as
        # a message instead of crashing the Space.
        return f"❌ Inference failed: {ex}", pd.DataFrame(), pd.DataFrame(), None
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
# --- Gradio UI -------------------------------------------------------------
# Declarative layout: one upload + two sliders in, a markdown summary, two
# tabbed result tables, and a download widget out.
with gr.Blocks(title="AML GNN Inference Space") as demo:
    gr.Markdown(
        """
        # AML Transaction Risk Scoring (GNN)
        Upload a transaction CSV and score suspicious transactions using the model from Hugging Face Hub.

        **Expected columns (flexible names supported):** source account, destination account, amount.
        Optional: label, timestamp/time, currency, payment type, channel.
        """
    )

    # Inputs: file upload plus scoring controls. The threshold slider starts
    # at the model's tuned value (DEFAULT_THRESHOLD) when one was loaded.
    with gr.Row():
        file_input = gr.File(label="Upload transaction CSV", file_types=[".csv"])
        threshold = gr.Slider(0.05, 0.95, value=DEFAULT_THRESHOLD, step=0.01, label="Decision threshold")
        top_k = gr.Slider(5, 100, value=20, step=1, label="Top rows to display")

    run_btn = gr.Button("Score Transactions", variant="primary")
    summary = gr.Markdown(label="Run summary")

    with gr.Tab("Top suspicious transactions"):
        top_txn_df = gr.Dataframe(label="Top scored transactions")

    with gr.Tab("Top risky accounts"):
        top_acc_df = gr.Dataframe(label="Top account risks")

    download_file = gr.File(label="Download full scored CSV")

    # Wire the button to the scoring callback; output order must match the
    # tuple returned by score_transactions.
    run_btn.click(
        fn=score_transactions,
        inputs=[file_input, threshold, top_k],
        outputs=[summary, top_txn_df, top_acc_df, download_file],
    )
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# Launch the Gradio server when run directly (Spaces runs this as a script).
if __name__ == "__main__":
    demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=5.0.0
|
| 2 |
+
pandas>=2.0.0
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
scikit-learn>=1.3.0
|
| 5 |
+
torch>=2.2.0
|
| 6 |
+
torch-geometric>=2.5.0
|
| 7 |
+
huggingface_hub>=0.24.0
|