Spaces:
Sleeping
Sleeping
Added files
Browse files- README.md +69 -14
- export_graph_state.py +63 -0
- model.py +46 -0
README.md
CHANGED
|
@@ -1,14 +1,69 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AML Detection – Gradio App
|
| 2 |
+
|
| 3 |
+
GraphSAGE-based Anti-Money Laundering transaction classifier.
|
| 4 |
+
|
| 5 |
+
## File Structure
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
aml_app/
|
| 9 |
+
├── app.py                  # Gradio UI
|
| 10 |
+
├── inference.py            # Feature engineering + inference engine
|
| 11 |
+
├── model.py                # EdgeGNN architecture (matches training notebook exactly)
|
| 12 |
+
├── requirements.txt
|
| 13 |
+
├── export_graph_state.py   # Run in notebook to export artifacts
|
| 14 |
+
└── artifacts/              # ← place your trained files here
|
| 15 |
+
    ├── best_model.pt       # trained model weights
|
| 16 |
+
    ├── config.json         # training config (threshold, hidden_dim, etc.)
|
| 17 |
+
    └── graph_state.pt      # optional: historical graph for richer predictions
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
## Quick Start
|
| 21 |
+
|
| 22 |
+
### 1. Install dependencies
|
| 23 |
+
```bash
|
| 24 |
+
pip install -r requirements.txt
|
| 25 |
+
# PyTorch Geometric also needs:
|
| 26 |
+
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
### 2. Export artifacts from your notebook
|
| 30 |
+
After training, run `export_graph_state.py` inside your Kaggle/Lightning AI notebook:
|
| 31 |
+
```python
|
| 32 |
+
exec(open('export_graph_state.py').read())
|
| 33 |
+
```
|
| 34 |
+
Then download the `artifacts/` folder and place it next to `app.py`.
|
| 35 |
+
|
| 36 |
+
### 3. Run
|
| 37 |
+
```bash
|
| 38 |
+
python app.py
|
| 39 |
+
# → http://localhost:7860
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## Running without trained weights (Demo Mode)
|
| 43 |
+
|
| 44 |
+
The app runs in **demo mode** if `artifacts/best_model.pt` is missing —
|
| 45 |
+
it uses random weights so the interface is fully functional but predictions
|
| 46 |
+
are meaningless. Good for testing the UI.
|
| 47 |
+
|
| 48 |
+
## Inference Modes
|
| 49 |
+
|
| 50 |
+
| Mode | When | Quality |
|
| 51 |
+
|---|---|---|
|
| 52 |
+
| `graph-lookup` | Account exists in `graph_state.pt` | Best — uses full neighbourhood |
|
| 53 |
+
| `cold-start` | New account / no `graph_state.pt` | Reduced — single edge only |
|
| 54 |
+
|
| 55 |
+
## Environment Variables
|
| 56 |
+
|
| 57 |
+
| Variable | Default | Description |
|
| 58 |
+
|---|---|---|
|
| 59 |
+
| `MODEL_PATH` | `artifacts/best_model.pt` | Path to model weights |
|
| 60 |
+
| `CONFIG_PATH` | `artifacts/config.json` | Path to config JSON |
|
| 61 |
+
| `GRAPH_PATH` | `artifacts/graph_state.pt` | Path to graph state |
|
| 62 |
+
| `PORT` | `7860` | Server port |
|
| 63 |
+
|
| 64 |
+
## Deploying to Hugging Face Spaces
|
| 65 |
+
|
| 66 |
+
1. Create a new Space (Gradio SDK)
|
| 67 |
+
2. Upload all files in this directory
|
| 68 |
+
3. Upload your `artifacts/` folder
|
| 69 |
+
4. The Space will auto-install requirements and launch `app.py`
|
export_graph_state.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
export_graph_state.py
─────────────────────
Run this inside your Kaggle / Lightning AI notebook AFTER Cell 7
to export everything the Gradio app needs for full graph-lookup inference.

Usage (paste into a new notebook cell after c07-graph):
    %run export_graph_state.py

Or inline:
    exec(open('export_graph_state.py').read())

Expects these notebook globals to exist:
    best_state, data, account_to_id, HIDDEN_DIM, DROPOUT,
    best_thr, USE_FOCAL_LOSS, FOCAL_GAMMA, SAMPLE_FRAC
(and optionally edge_feat_df).
"""
import json
import os
from datetime import datetime, timezone
from pathlib import Path

import torch

# Output directory for all artifacts; override with the EXPORT_DIR env var.
EXPORT_DIR = Path(os.environ.get("EXPORT_DIR", "/kaggle/working/artifacts/aml-gnn"))
EXPORT_DIR.mkdir(parents=True, exist_ok=True)

# ── 1. Best model weights ─────────────────────────────────────────────────────
# (already saved by c09-train if you ran it — just copy)
model_src = EXPORT_DIR / "best_model.pt"
if model_src.exists():
    print(f"Model weights already at {model_src}")
elif globals().get("best_state") is not None:
    torch.save(best_state, model_src)
    print(f"Saved model weights → {model_src}")
else:
    # BUG FIX: the original fell into the "already at ..." message here even
    # though the file does not exist and there is no in-memory state to save.
    print(f"WARNING: {model_src} not found and `best_state` is unavailable — "
          "no model weights were exported.")

# ── 2. Graph state (for graph-lookup inference) ───────────────────────────────
graph_path = EXPORT_DIR / "graph_state.pt"
torch.save({
    "data": data,                    # PyG Data object (node_x, edge_index, edge_attr)
    "account_to_id": account_to_id,  # dict: account_str → node_id int
    # `dir()` only lists the *local* scope; the notebook variables live in
    # globals(), so check there explicitly.
    "edge_columns": list(edge_feat_df.columns) if "edge_feat_df" in globals() else [],
}, graph_path)
print(f"Saved graph state → {graph_path}")
print(f"  Nodes : {data.x.shape[0]:,}")
print(f"  Edges : {data.edge_index.shape[1]:,}")

# ── 3. Config ─────────────────────────────────────────────────────────────────
config_path = EXPORT_DIR / "config.json"
cfg = {
    "model_class"    : "EdgeGNN",
    "in_dim"         : int(data.x.shape[1]),
    "edge_dim"       : int(data.edge_attr.shape[1]),
    "hidden_dim"     : int(HIDDEN_DIM),
    "dropout"        : float(DROPOUT),
    "best_threshold" : float(best_thr),
    "use_focal_loss" : bool(USE_FOCAL_LOSS),
    "focal_gamma"    : float(FOCAL_GAMMA),
    "sample_frac"    : float(SAMPLE_FRAC),
    "timestamp_utc"  : datetime.now(timezone.utc).isoformat(),
}
config_path.write_text(json.dumps(cfg, indent=2))
print(f"Saved config → {config_path}")

print("\n✅ Export complete. Copy the artifacts/ folder to your Gradio app directory.")
print(f"  {EXPORT_DIR}/")
print("  ├── best_model.pt")
print("  ├── graph_state.pt")
print("  └── config.json")
|
model.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
model.py — EdgeGNN definition, exactly matching the training notebook.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


class EdgeGNN(nn.Module):
    """2-layer GraphSAGE node encoder + MLP edge classifier."""

    def __init__(self, in_dim: int, edge_dim: int, hidden_dim: int = 64, dropout: float = 0.2):
        """Build the two SAGE layers and the edge-scoring MLP.

        Args:
            in_dim:     node feature dimension.
            edge_dim:   edge (transaction) feature dimension.
            hidden_dim: width of the node embeddings and MLP hidden layer.
            dropout:    dropout probability used in both encoder and MLP.
        """
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.dropout = dropout
        # Scores one edge from the concatenation [src_emb | dst_emb | edge_attr].
        self.edge_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2 + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def encode_nodes(self, x, edge_index):
        """Return hidden_dim node embeddings after two rounds of message passing."""
        hidden = self.conv1(x, edge_index)
        hidden = F.dropout(F.relu(hidden), p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)

    def edge_logits(self, node_emb, edge_index, edge_attr, local_idx=None):
        """Compute one raw logit per edge.

        When `local_idx` is given, only that subset of edges is scored;
        otherwise every edge in `edge_index` gets a logit.
        """
        if local_idx is None:
            src, dst, feats = edge_index[0], edge_index[1], edge_attr
        else:
            src = edge_index[0, local_idx]
            dst = edge_index[1, local_idx]
            feats = edge_attr[local_idx]
        pair = torch.cat([node_emb[src], node_emb[dst], feats], dim=1)
        return self.edge_mlp(pair).squeeze(-1)

    def forward(self, x, edge_index, edge_attr):
        """End-to-end pass: encode all nodes, then score every edge."""
        emb = self.encode_nodes(x, edge_index)
        return self.edge_logits(emb, edge_index, edge_attr)