waltertaya committed on
Commit a29c713 · verified · 1 Parent(s): 1e4e13b

Added files
Files changed (3)
  1. README.md +69 -14
  2. export_graph_state.py +63 -0
  3. model.py +46 -0
README.md CHANGED
@@ -1,14 +1,69 @@
- ---
- title: Aml App
- emoji: 🐠
- colorFrom: purple
- colorTo: red
- sdk: gradio
- sdk_version: 6.9.0
- app_file: app.py
- pinned: false
- license: mit
- short_description: AML Detect app
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # AML Detection — Gradio App
+
+ GraphSAGE-based Anti-Money Laundering transaction classifier.
+
+ ## File Structure
+
+ ```
+ aml_app/
+ ├── app.py                  # Gradio UI
+ ├── inference.py            # Feature engineering + inference engine
+ ├── model.py                # EdgeGNN architecture (matches training notebook exactly)
+ ├── requirements.txt
+ ├── export_graph_state.py   # Run in notebook to export artifacts
+ └── artifacts/              # ← place your trained files here
+     ├── best_model.pt       # trained model weights
+     ├── config.json         # training config (threshold, hidden_dim, etc.)
+     └── graph_state.pt      # optional: historical graph for richer predictions
+ ```
+
+ ## Quick Start
+
+ ### 1. Install dependencies
+ ```bash
+ pip install -r requirements.txt
+ # PyTorch Geometric also needs:
+ pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
+ ```
+
+ ### 2. Export artifacts from your notebook
+ After training, run `export_graph_state.py` inside your Kaggle/Lightning AI notebook:
+ ```python
+ exec(open('export_graph_state.py').read())
+ ```
+ Then download the `artifacts/` folder and place it next to `app.py`.
+
+ ### 3. Run
+ ```bash
+ python app.py
+ # → http://localhost:7860
+ ```
+
+ ## Running without trained weights (Demo Mode)
+
+ If `artifacts/best_model.pt` is missing, the app falls back to **demo mode**:
+ it uses random weights, so the interface is fully functional but predictions
+ are meaningless. Useful for testing the UI.
+
+ ## Inference Modes
+
+ | Mode | When | Quality |
+ |---|---|---|
+ | `graph-lookup` | Account exists in `graph_state.pt` | Best — uses full neighbourhood |
+ | `cold-start` | New account / no `graph_state.pt` | Reduced — single edge only |
+
+ ## Environment Variables
+
+ | Variable | Default | Description |
+ |---|---|---|
+ | `MODEL_PATH` | `artifacts/best_model.pt` | Path to model weights |
+ | `CONFIG_PATH` | `artifacts/config.json` | Path to config JSON |
+ | `GRAPH_PATH` | `artifacts/graph_state.pt` | Path to graph state |
+ | `PORT` | `7860` | Server port |
+
+ ## Deploying to Hugging Face Spaces
+
+ 1. Create a new Space (Gradio SDK)
+ 2. Upload all files in this directory
+ 3. Upload your `artifacts/` folder
+ 4. The Space will auto-install requirements and launch `app.py`
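
The environment-variable table above can be resolved into runtime settings with a small helper. A minimal sketch, assuming only the documented defaults (the `resolve_settings` name and return shape are illustrative, not part of `app.py`):

```python
import os

# Defaults mirror the Environment Variables table in the README;
# matching environment variables override them.
DEFAULTS = {
    "MODEL_PATH": "artifacts/best_model.pt",
    "CONFIG_PATH": "artifacts/config.json",
    "GRAPH_PATH": "artifacts/graph_state.pt",
    "PORT": "7860",
}

def resolve_settings(env=os.environ):
    """Merge environment overrides onto the documented defaults."""
    settings = {key: env.get(key, default) for key, default in DEFAULTS.items()}
    settings["PORT"] = int(settings["PORT"])  # server port must be an int
    # Demo mode (random weights) kicks in when the weights file is absent.
    settings["DEMO_MODE"] = not os.path.exists(settings["MODEL_PATH"])
    return settings
```
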
export_graph_state.py ADDED
@@ -0,0 +1,63 @@
+ """
+ export_graph_state.py
+ ─────────────────────
+ Run this inside your Kaggle / Lightning AI notebook AFTER Cell 7
+ to export everything the Gradio app needs for full graph-lookup inference.
+
+ Usage (paste into a new notebook cell after c07-graph):
+     %run export_graph_state.py
+
+ Or inline:
+     exec(open('export_graph_state.py').read())
+ """
+ import json
+ import os
+ from datetime import datetime, timezone
+ from pathlib import Path
+
+ import torch
+
+ EXPORT_DIR = Path(os.environ.get("EXPORT_DIR", "/kaggle/working/artifacts/aml-gnn"))
+ EXPORT_DIR.mkdir(parents=True, exist_ok=True)
+
+ # ── 1. Best model weights ─────────────────────────────────────────────────────
+ # (already saved by c09-train if you ran it — just copy)
+ model_src = EXPORT_DIR / "best_model.pt"
+ if model_src.exists():
+     print(f"Model weights already at {model_src}")
+ elif best_state is not None:  # best_state comes from the training cell
+     torch.save(best_state, model_src)
+     print(f"Saved model weights → {model_src}")
+ else:
+     print("WARNING: no saved weights and best_state is None — run the training cell first.")
+
+ # ── 2. Graph state (for graph-lookup inference) ───────────────────────────────
+ graph_path = EXPORT_DIR / "graph_state.pt"
+ torch.save({
+     "data": data,                    # PyG Data object (node_x, edge_index, edge_attr)
+     "account_to_id": account_to_id,  # dict: account_str → node_id int
+     "edge_columns": list(edge_feat_df.columns) if "edge_feat_df" in dir() else [],
+ }, graph_path)
+ print(f"Saved graph state → {graph_path}")
+ print(f"   Nodes : {data.x.shape[0]:,}")
+ print(f"   Edges : {data.edge_index.shape[1]:,}")
+
+ # ── 3. Config ─────────────────────────────────────────────────────────────────
+ config_path = EXPORT_DIR / "config.json"
+ cfg = {
+     "model_class"    : "EdgeGNN",
+     "in_dim"         : int(data.x.shape[1]),
+     "edge_dim"       : int(data.edge_attr.shape[1]),
+     "hidden_dim"     : int(HIDDEN_DIM),
+     "dropout"        : float(DROPOUT),
+     "best_threshold" : float(best_thr),
+     "use_focal_loss" : bool(USE_FOCAL_LOSS),
+     "focal_gamma"    : float(FOCAL_GAMMA),
+     "sample_frac"    : float(SAMPLE_FRAC),
+     "timestamp_utc"  : datetime.now(timezone.utc).isoformat(),
+ }
+ config_path.write_text(json.dumps(cfg, indent=2))
+ print(f"Saved config → {config_path}")
+
+ print("\n✅ Export complete. Copy the artifacts/ folder to your Gradio app directory.")
+ print(f"   {EXPORT_DIR}/")
+ print("   ├── best_model.pt")
+ print("   ├── graph_state.pt")
+ print("   └── config.json")
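
On the loading side, the app can sanity-check the exported `config.json` before constructing the model. A sketch assuming only the keys the script above writes (`load_config` and `REQUIRED_KEYS` are illustrative names, not part of the repo):

```python
import json
from pathlib import Path

# Keys the EdgeGNN constructor and the decision threshold depend on.
REQUIRED_KEYS = {"in_dim", "edge_dim", "hidden_dim", "dropout", "best_threshold"}

def load_config(path):
    """Read an exported config.json and fail fast if keys are missing."""
    cfg = json.loads(Path(path).read_text())
    missing = REQUIRED_KEYS - cfg.keys()
    if missing:
        raise ValueError(f"config.json missing keys: {sorted(missing)}")
    return cfg
```
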
model.py ADDED
@@ -0,0 +1,46 @@
+ """
+ model.py — EdgeGNN definition, exactly matching the training notebook.
+ """
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch_geometric.nn import SAGEConv
+
+
+ class EdgeGNN(nn.Module):
+     """2-layer GraphSAGE node encoder + MLP edge classifier."""
+
+     def __init__(self, in_dim: int, edge_dim: int, hidden_dim: int = 64, dropout: float = 0.2):
+         super().__init__()
+         self.conv1 = SAGEConv(in_dim, hidden_dim)
+         self.conv2 = SAGEConv(hidden_dim, hidden_dim)
+         self.dropout = dropout
+         # Edge classifier over [src embedding | dst embedding | edge features]
+         self.edge_mlp = nn.Sequential(
+             nn.Linear(hidden_dim * 2 + edge_dim, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim, 1),
+         )
+
+     def encode_nodes(self, x, edge_index):
+         h = F.relu(self.conv1(x, edge_index))
+         h = F.dropout(h, p=self.dropout, training=self.training)
+         return self.conv2(h, edge_index)
+
+     def edge_logits(self, node_emb, edge_index, edge_attr, local_idx=None):
+         # local_idx optionally restricts scoring to a subset of edges.
+         if local_idx is not None:
+             s = edge_index[0, local_idx]
+             d = edge_index[1, local_idx]
+             ea = edge_attr[local_idx]
+         else:
+             s = edge_index[0]
+             d = edge_index[1]
+             ea = edge_attr
+         return self.edge_mlp(
+             torch.cat([node_emb[s], node_emb[d], ea], dim=1)
+         ).squeeze(-1)
+
+     def forward(self, x, edge_index, edge_attr):
+         return self.edge_logits(
+             self.encode_nodes(x, edge_index), edge_index, edge_attr
+         )
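
The model emits one raw logit per edge; at serving time that logit is passed through a sigmoid and compared against the `best_threshold` exported in `config.json`. A minimal sketch in plain Python (the `classify` helper is illustrative, not part of the repo):

```python
import math

def classify(logit, threshold):
    """Sigmoid the edge logit and flag it against the tuned threshold."""
    prob = 1.0 / (1.0 + math.exp(-logit))
    return prob, prob >= threshold

# A strongly positive logit with a 0.5 threshold is flagged as suspicious.
prob, flagged = classify(2.0, 0.5)
```
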