waltertaya committed on
Commit
3b50aa4
·
verified ·
1 Parent(s): 327f9ba

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +394 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import tempfile
4
+ from typing import List, Dict, Optional, Tuple
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from huggingface_hub import hf_hub_download
13
+ from sklearn.metrics import average_precision_score, precision_recall_fscore_support, roc_auc_score
14
+ from torch_geometric.data import Data
15
+ from torch_geometric.nn import SAGEConv
16
+
17
+
18
# Hub repo that hosts the trained checkpoint; override with the
# MODEL_REPO_ID environment variable to point at a different model.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "walter-taya/aml-gnn-ibm-baseline-medium")
MODEL_REPO_TYPE = "model"
# Prefer GPU when the host provides one; fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Candidate header names used by find_column() to auto-detect the schema of
# an uploaded transaction CSV. Matching happens after normalization
# (lowercase, spaces -> underscores), so entries are written in that form.
TARGET_COL_CANDIDATES = ["is_laundering", "is laundering", "is_sar", "label", "target", "y"]
SRC_COL_CANDIDATES = ["from_account", "originator", "sender", "nameorig", "account", "src", "source"]
DST_COL_CANDIDATES = ["to_account", "beneficiary", "receiver", "namedest", "account.1", "dst", "target_account"]
AMOUNT_COL_CANDIDATES = ["amount", "amount_paid", "payment_amount", "transaction_amount", "amt"]
TIME_COL_CANDIDATES = ["timestamp", "step", "time", "date", "tran_date", "transaction_date"]
CURRENCY_COL_CANDIDATES = ["currency", "payment_currency", "ccy", "cur"]
PAYMENT_TYPE_COL_CANDIDATES = ["payment_type", "type", "transaction_type", "channel", "payment_method"]
CHANNEL_COL_CANDIDATES = ["channel", "delivery_channel", "device_channel", "network"]
30
+
31
+
32
+ def _normalize_name(col: str) -> str:
33
+ return str(col).strip().lower().replace(" ", "_")
34
+
35
+
36
def find_column(df: pd.DataFrame, candidates: List[str], required: bool = True) -> Optional[str]:
    """Return the first DataFrame column matching any candidate name.

    Matching is case/whitespace-insensitive via _normalize_name. Raises
    KeyError when nothing matches and required is True; returns None when
    nothing matches and required is False.
    """
    by_norm = {_normalize_name(original): original for original in df.columns}
    match = next(
        (by_norm[_normalize_name(cand)] for cand in candidates if _normalize_name(cand) in by_norm),
        None,
    )
    if match is None and required:
        raise KeyError(f"None of the candidate columns found: {candidates}")
    return match
45
+
46
+
47
def to_binary_label(series: pd.Series) -> pd.Series:
    """Coerce an arbitrary label column to a 0/1 integer Series.

    Booleans map directly, numeric values are positive -> 1, and anything
    else is string-matched (case-insensitively) against known "positive"
    tokens such as "sar" or "laundering".
    """
    if series.dtype == bool:
        return series.astype(int)
    if np.issubdtype(series.dtype, np.number):
        return (series > 0).astype(int)
    positive_tokens = {"1", "true", "yes", "sar", "laundering", "suspicious", "fraud"}
    normalized = series.astype(str).str.strip().str.lower()
    return normalized.isin(positive_tokens).astype(int)
55
+
56
+
57
def safe_amount(x):
    """Parse a monetary value leniently; NaN/None/unparseable input -> 0.0.

    Thousands separators (commas) are stripped before conversion.
    """
    if pd.isna(x):
        return 0.0
    cleaned = str(x).replace(",", "")
    try:
        parsed = float(cleaned)
    except Exception:
        parsed = 0.0
    return parsed
64
+
65
+
66
def add_time_features(df: pd.DataFrame) -> pd.DataFrame:
    """Derive integer "hour" and "dayofweek" columns from a "time" column.

    Numeric time values are treated as simulation steps (step % 24 is the
    hour, step // 24 cycles through a 7-day week). Non-numeric values are
    parsed as datetimes, with unparseable entries defaulting to 0. When no
    "time" column exists, both features are constant zero.

    Mutates and returns *df*.
    """
    if "time" not in df.columns:
        df["hour"] = 0
        df["dayofweek"] = 0
        return df

    t = df["time"]
    if np.issubdtype(t.dtype, np.number):
        as_float = t.astype(float)
        df["hour"] = (as_float % 24).fillna(0).astype(int)
        df["dayofweek"] = ((as_float // 24) % 7).fillna(0).astype(int)
    else:
        parsed = pd.to_datetime(t, errors="coerce")
        df["hour"] = parsed.dt.hour.fillna(0).astype(int)
        df["dayofweek"] = parsed.dt.dayofweek.fillna(0).astype(int)
    return df
84
+
85
+
86
class EdgeGNN(nn.Module):
    """Edge-level binary classifier over a transaction graph.

    Two GraphSAGE layers produce node embeddings; each edge is then scored
    by an MLP over [src_embedding || dst_embedding || edge_features].
    """

    def __init__(self, in_dim: int, edge_dim: int, hidden_dim: int, dropout: float = 0.2):
        super().__init__()
        # Attribute names (conv1/conv2/edge_mlp) must stay as-is: they are
        # the state_dict keys of the published checkpoint.
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.dropout = dropout
        self.edge_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2 + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def encode_nodes(self, x, edge_index):
        """Return per-node embeddings of width hidden_dim."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)

    def edge_logits(self, node_emb, edge_index, edge_attr):
        """Score each edge by concatenating endpoint embeddings with its features."""
        senders, receivers = edge_index[0], edge_index[1]
        combined = torch.cat([node_emb[senders], node_emb[receivers], edge_attr], dim=1)
        return self.edge_mlp(combined).squeeze(-1)

    def forward(self, x, edge_index, edge_attr):
        """Full pass: node encoding then edge scoring. Returns raw logits."""
        return self.edge_logits(self.encode_nodes(x, edge_index), edge_index, edge_attr)
115
+
116
+
117
def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Dict[str, float]:
    """Binary-classification metrics at a fixed decision threshold.

    Always reports precision/recall/f1; ROC-AUC and PR-AUC are included
    only when both classes are present (they are undefined otherwise).
    """
    predictions = (y_prob >= threshold).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, predictions, average="binary", zero_division=0
    )
    results = {"precision": float(precision), "recall": float(recall), "f1": float(f1)}
    both_classes_present = len(np.unique(y_true)) > 1
    if both_classes_present:
        results["roc_auc"] = float(roc_auc_score(y_true, y_prob))
        results["pr_auc"] = float(average_precision_score(y_true, y_prob))
    return results
125
+
126
+
127
def prepare_graph(df: pd.DataFrame) -> Tuple[pd.DataFrame, Data, torch.Tensor, Dict[str, str]]:
    """Convert a raw transaction DataFrame into a PyG graph for scoring.

    Detects the CSV schema via the *_COL_CANDIDATES lists, builds one node
    per account and one directed edge per transaction, and engineers edge
    features (log amount, normalized hour/day, one-hot categoricals) plus
    per-account aggregate node features.

    Returns:
        (work_df, data, edge_attr, mapping) where work_df is the cleaned
        per-transaction frame (row-aligned with data.edge_attr), data is
        the torch_geometric Data object, and mapping records which original
        columns were detected.

    Raises:
        KeyError: if no source/destination/amount column can be detected.
    """
    # Required columns: a KeyError here aborts scoring with a clear message.
    src_col = find_column(df, SRC_COL_CANDIDATES)
    dst_col = find_column(df, DST_COL_CANDIDATES)
    amount_col = find_column(df, AMOUNT_COL_CANDIDATES)

    # Optional columns: None when absent.
    label_col = find_column(df, TARGET_COL_CANDIDATES, required=False)
    time_col = find_column(df, TIME_COL_CANDIDATES, required=False)
    currency_col = find_column(df, CURRENCY_COL_CANDIDATES, required=False)
    payment_type_col = find_column(df, PAYMENT_TYPE_COL_CANDIDATES, required=False)
    channel_col = find_column(df, CHANNEL_COL_CANDIDATES, required=False)

    # Build parallel lists of (original column, canonical name) to copy.
    selected = [src_col, dst_col, amount_col]
    rename_to = ["src", "dst", "amount"]

    if label_col:
        selected.append(label_col)
        rename_to.append("label")
    if time_col:
        selected.append(time_col)
        rename_to.append("time")
    if currency_col:
        selected.append(currency_col)
        rename_to.append("currency")
    if payment_type_col:
        selected.append(payment_type_col)
        rename_to.append("payment_type")
    if channel_col:
        selected.append(channel_col)
        rename_to.append("channel")

    work_df = df[selected].copy()
    work_df.columns = rename_to
    work_df["amount"] = work_df["amount"].apply(safe_amount).astype(float)
    # Rows without both endpoints cannot form an edge.
    work_df = work_df.dropna(subset=["src", "dst"]).reset_index(drop=True)
    work_df["src"] = work_df["src"].astype(str)
    work_df["dst"] = work_df["dst"].astype(str)

    if "label" in work_df.columns:
        work_df["label"] = to_binary_label(work_df["label"]).astype(int)

    for col in ["currency", "payment_type", "channel"]:
        if col in work_df.columns:
            # NOTE(review): astype(str) converts NaN to the string "nan"
            # BEFORE fillna runs, so "UNK" is never substituted — confirm
            # whether missing categoricals were meant to become "UNK".
            work_df[col] = work_df[col].astype(str).fillna("UNK")

    work_df = add_time_features(work_df)

    # Map every distinct account (as source or destination) to a node id.
    all_accounts = pd.Index(work_df["src"]).append(pd.Index(work_df["dst"])).unique()
    account_to_id = {acc: i for i, acc in enumerate(all_accounts)}
    work_df["src_id"] = work_df["src"].map(account_to_id)
    work_df["dst_id"] = work_df["dst"].map(account_to_id)

    # Shape (2, num_edges): one directed edge per transaction row.
    edge_index = torch.tensor(work_df[["src_id", "dst_id"]].to_numpy().T, dtype=torch.long)

    # Continuous edge features: log1p-compressed amount, hour/day scaled to [0, 1].
    edge_cont_cols = ["amount", "hour", "dayofweek"]
    edge_cont = work_df[edge_cont_cols].copy()
    edge_cont["amount"] = np.log1p(edge_cont["amount"].clip(lower=0.0))
    edge_cont["hour"] = edge_cont["hour"] / 23.0
    edge_cont["dayofweek"] = edge_cont["dayofweek"] / 6.0

    # One-hot encode whichever categorical columns were detected.
    # NOTE(review): get_dummies output width depends on the categories in
    # the uploaded CSV, so edge_attr width may not equal the edge_dim the
    # checkpoint was trained with — verify against the model config.
    cat_cols = [c for c in ["currency", "payment_type", "channel"] if c in work_df.columns]
    if cat_cols:
        edge_cat = pd.get_dummies(work_df[cat_cols], prefix=cat_cols, dummy_na=True)
        edge_feat_df = pd.concat([edge_cont, edge_cat], axis=1)
    else:
        edge_feat_df = edge_cont

    edge_attr = torch.tensor(edge_feat_df.to_numpy(), dtype=torch.float32)

    # Node features: per-account aggregates over outgoing/incoming edges.
    # reindex(..., fill_value=...) gives zeros to accounts with no edges in
    # that direction, keeping rows aligned with node ids 0..num_nodes-1.
    num_nodes = len(all_accounts)
    node_df = pd.DataFrame(index=np.arange(num_nodes))
    out_count = work_df.groupby("src_id").size().reindex(node_df.index, fill_value=0)
    in_count = work_df.groupby("dst_id").size().reindex(node_df.index, fill_value=0)
    out_amt_sum = work_df.groupby("src_id")["amount"].sum().reindex(node_df.index, fill_value=0.0)
    in_amt_sum = work_df.groupby("dst_id")["amount"].sum().reindex(node_df.index, fill_value=0.0)
    out_amt_mean = work_df.groupby("src_id")["amount"].mean().reindex(node_df.index, fill_value=0.0)
    in_amt_mean = work_df.groupby("dst_id")["amount"].mean().reindex(node_df.index, fill_value=0.0)
    out_hour_mean = work_df.groupby("src_id")["hour"].mean().reindex(node_df.index, fill_value=0.0)
    in_hour_mean = work_df.groupby("dst_id")["hour"].mean().reindex(node_df.index, fill_value=0.0)
    # Fraction of an account's transactions in the 00:00-05:59 window.
    out_night_ratio = (
        work_df.assign(night=work_df["hour"].isin([0, 1, 2, 3, 4, 5]).astype(int))
        .groupby("src_id")["night"]
        .mean()
        .reindex(node_df.index, fill_value=0.0)
    )
    in_night_ratio = (
        work_df.assign(night=work_df["hour"].isin([0, 1, 2, 3, 4, 5]).astype(int))
        .groupby("dst_id")["night"]
        .mean()
        .reindex(node_df.index, fill_value=0.0)
    )

    node_df["out_count"] = out_count
    node_df["in_count"] = in_count
    node_df["out_amt_sum"] = out_amt_sum
    node_df["in_amt_sum"] = in_amt_sum
    node_df["out_amt_mean"] = out_amt_mean
    node_df["in_amt_mean"] = in_amt_mean
    node_df["out_hour_mean"] = out_hour_mean
    node_df["in_hour_mean"] = in_hour_mean
    node_df["out_night_ratio"] = out_night_ratio
    node_df["in_night_ratio"] = in_night_ratio

    # Compress heavy-tailed count/amount features; ratios/means stay raw.
    for col in ["out_count", "in_count", "out_amt_sum", "in_amt_sum", "out_amt_mean", "in_amt_mean"]:
        node_df[col] = np.log1p(node_df[col].clip(lower=0.0))

    node_x = torch.tensor(node_df.to_numpy(), dtype=torch.float32)
    data = Data(x=node_x, edge_index=edge_index, edge_attr=edge_attr)

    # Record what was auto-detected so the UI can show it to the user.
    mapping = {
        "src": src_col,
        "dst": dst_col,
        "amount": amount_col,
        "label": label_col or "(not found)",
        "time": time_col or "(not found)",
    }
    return work_df, data, edge_attr, mapping
243
+
244
+
245
def load_model_and_config(repo_id: str):
    """Download the checkpoint from the HF Hub and build an EdgeGNN.

    Expects `config.json` (with in_dim/edge_dim/hidden_dim and optional
    dropout/best_threshold keys) and `pytorch_model.bin` in *repo_id*.

    Returns:
        (model, config_dict, default_threshold) with the model moved to
        DEVICE and switched to eval mode.

    Raises:
        Whatever hf_hub_download raises on network/auth failure, KeyError
        on a malformed config, or a torch error on a bad state dict.
    """
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json", repo_type=MODEL_REPO_TYPE)
    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", repo_type=MODEL_REPO_TYPE)

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    model = EdgeGNN(
        in_dim=int(config["in_dim"]),
        edge_dim=int(config["edge_dim"]),
        hidden_dim=int(config["hidden_dim"]),
        dropout=float(config.get("dropout", 0.2)),
    ).to(DEVICE)
    # weights_only=True restricts unpickling to tensor data, guarding
    # against arbitrary-code-execution payloads in a downloaded checkpoint
    # (a plain state_dict loads fine under this mode; torch>=2.2 per
    # requirements supports the flag).
    state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # inference-only app: disable dropout

    # Threshold tuned at training time, if published; otherwise 0.5.
    default_threshold = float(config.get("best_threshold", 0.5))
    return model, config, default_threshold
264
+
265
+
266
# Module-level model state, populated once at import time so a single
# download/load serves every request.
MODEL = None
MODEL_CONFIG = None
DEFAULT_THRESHOLD = 0.5
MODEL_LOAD_ERROR = None  # non-None means loading failed; UI reports it

try:
    MODEL, MODEL_CONFIG, DEFAULT_THRESHOLD = load_model_and_config(MODEL_REPO_ID)
except Exception as ex:
    # Defer the failure: the app still starts, and score_transactions()
    # surfaces this message instead of crashing at startup.
    MODEL_LOAD_ERROR = str(ex)
275
+
276
+
277
def score_transactions(file_obj, threshold: float, top_k: int):
    """Gradio handler: score an uploaded transaction CSV with the GNN.

    Args:
        file_obj: uploaded file value from gr.File (exposes a .name path),
            or None when nothing was uploaded.
        threshold: probability cutoff for flagging a transaction.
        top_k: number of rows to show in each preview table.

    Returns:
        (markdown_summary, top_transactions_df, top_accounts_df, csv_path).
        On any failure the dataframes are empty and csv_path is None.
    """
    # Startup model-load failure is reported per-request instead of crashing.
    if MODEL_LOAD_ERROR:
        return (
            f"❌ Model load failed from `{MODEL_REPO_ID}`: {MODEL_LOAD_ERROR}",
            pd.DataFrame(),
            pd.DataFrame(),
            None,
        )

    if file_obj is None:
        return "Please upload a CSV file.", pd.DataFrame(), pd.DataFrame(), None

    try:
        raw_df = pd.read_csv(file_obj.name)
        work_df, data, _, mapping = prepare_graph(raw_df)

        # Full-graph forward pass; no_grad keeps memory flat for inference.
        with torch.no_grad():
            logits = MODEL(
                data.x.to(DEVICE),
                data.edge_index.to(DEVICE),
                data.edge_attr.to(DEVICE),
            )
        probs = torch.sigmoid(logits).detach().cpu().numpy()

        # Per-transaction scores, highest risk first.
        result_df = work_df.copy()
        result_df["pred_prob"] = probs
        result_df["pred_label"] = (result_df["pred_prob"] >= threshold).astype(int)
        result_df = result_df.sort_values("pred_prob", ascending=False).reset_index(drop=True)

        # Stack src and dst views so each transaction contributes its score
        # to both endpoint accounts.
        account_alerts = pd.concat(
            [
                result_df[["src", "pred_prob", "pred_label"]].rename(columns={"src": "account"}),
                result_df[["dst", "pred_prob", "pred_label"]].rename(columns={"dst": "account"}),
            ],
            axis=0,
            ignore_index=True,
        )

        # Roll transaction scores up to account level.
        account_risk = (
            account_alerts.groupby("account")
            .agg(
                max_txn_risk=("pred_prob", "max"),
                mean_txn_risk=("pred_prob", "mean"),
                txn_count=("pred_prob", "size"),
                pred_account_alert=("pred_label", "max"),
            )
            .reset_index()
            .sort_values("max_txn_risk", ascending=False)
        )

        # Evaluation is best-effort: only when a usable label column exists.
        metrics_block = ""
        if "label" in result_df.columns:
            y_true = result_df["label"].to_numpy().astype(int)
            if len(np.unique(y_true)) > 1:
                metrics = compute_metrics(y_true, result_df["pred_prob"].to_numpy(), threshold=threshold)
                metrics_block = "\n".join([f"- **{k}**: {v:.4f}" for k, v in metrics.items()])
            else:
                metrics_block = "- Ground-truth label has only one class; metrics skipped."
        else:
            metrics_block = "- Ground-truth label column not found; showing inference-only outputs."

        summary = (
            f"✅ Scored **{len(result_df):,}** transactions from `{os.path.basename(file_obj.name)}`\n\n"
            f"**Model repo**: `{MODEL_REPO_ID}` \n"
            f"**Threshold**: `{threshold:.3f}` \n"
            f"**Detected schema**: `{mapping}`\n\n"
            f"**Metrics (if label available)**\n{metrics_block}"
        )

        # max(1, ...) guards against a zero/negative slider value.
        top_txn = result_df.head(max(1, int(top_k)))
        top_accounts = account_risk.head(max(1, int(top_k)))

        # delete=False so the file survives this handler and can be served
        # as a download. NOTE(review): the temp file is never cleaned up,
        # and writing via tmp.name while the handle is open is POSIX-only —
        # fine on Linux hosts, would fail on Windows.
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp:
            result_df.to_csv(tmp.name, index=False)
            out_path = tmp.name

        return summary, top_txn, top_accounts, out_path

    except Exception as ex:
        # Deliberate catch-all: any failure becomes a UI message rather
        # than an unhandled server error.
        return f"❌ Inference failed: {ex}", pd.DataFrame(), pd.DataFrame(), None
357
+
358
+
359
# Declarative UI: component creation order determines on-page layout.
with gr.Blocks(title="AML GNN Inference Space") as demo:
    gr.Markdown(
        """
# AML Transaction Risk Scoring (GNN)
Upload a transaction CSV and score suspicious transactions using the model from Hugging Face Hub.

**Expected columns (flexible names supported):** source account, destination account, amount.
Optional: label, timestamp/time, currency, payment type, channel.
"""
    )

    # Inputs: file upload plus the two knobs score_transactions() consumes.
    with gr.Row():
        file_input = gr.File(label="Upload transaction CSV", file_types=[".csv"])
        # Default comes from the model config's best_threshold when available.
        threshold = gr.Slider(0.05, 0.95, value=DEFAULT_THRESHOLD, step=0.01, label="Decision threshold")
        top_k = gr.Slider(5, 100, value=20, step=1, label="Top rows to display")

    run_btn = gr.Button("Score Transactions", variant="primary")
    summary = gr.Markdown(label="Run summary")

    # Result tables, one tab per view.
    with gr.Tab("Top suspicious transactions"):
        top_txn_df = gr.Dataframe(label="Top scored transactions")

    with gr.Tab("Top risky accounts"):
        top_acc_df = gr.Dataframe(label="Top account risks")

    download_file = gr.File(label="Download full scored CSV")

    # Output order must match score_transactions()'s 4-tuple return.
    run_btn.click(
        fn=score_transactions,
        inputs=[file_input, threshold, top_k],
        outputs=[summary, top_txn_df, top_acc_df, download_file],
    )


if __name__ == "__main__":
    # Launch the app only when run as a script (not when imported).
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=5.0.0
2
+ pandas>=2.0.0
3
+ numpy>=1.24.0
4
+ scikit-learn>=1.3.0
5
+ torch>=2.2.0
6
+ torch-geometric>=2.5.0
7
+ huggingface_hub>=0.24.0