uyen1109 committed on
Commit
ddd9671
·
verified ·
1 Parent(s): c8829d2

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +12 -11
  2. app.py +195 -0
  3. requirements.txt +18 -0
README.md CHANGED
@@ -1,12 +1,13 @@
1
- ---
2
- title: Elliptic
3
- emoji: 📈
4
- colorFrom: green
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.47.2
8
- app_file: app.py
9
- pinned: false
10
- ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
1
+ # Elliptic Fraud Demo (GraphSAGE)
 
 
 
 
 
 
 
 
 
2
 
3
+ - Model repo: `uyen1109/elliptic` (auto pick latest `checkpoints/<TS>/*.pt`)
4
+ - Inference: CPU-only, PyG 2.6.1
5
+
6
+ ## Use
7
+ - Tab **txId**: cần 3 CSV trong `data/` hoặc set `DATASET_REPO` tới HF dataset repo có `elliptic/elliptic_txs_*.csv`.
8
+ - Tab **feature vector**: dán vector `165` chiều (hoặc set `EXPECTED_IN` nếu khác).
9
+
10
+ ## Env vars
11
+ - `MODEL_REPO` (default `uyen1109/elliptic`)
12
+ - `DATASET_REPO` (optional)
13
+ - `EXPECTED_IN` (optional, e.g. `165`)
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ import pandas as pd
6
+ import numpy as np
7
+ from huggingface_hub import HfApi, hf_hub_download
8
+ from torch_geometric.nn import SAGEConv
9
+ from torch_geometric.data import Data
10
+ from torch_geometric.utils import k_hop_subgraph
11
+
12
+ # ====== CONFIG ======
13
# Hub model repo that holds the checkpoint files.
MODEL_REPO = os.environ.get("MODEL_REPO", "uyen1109/elliptic")

# Optional Hub dataset repo used to fetch the three Elliptic CSVs automatically.
DATASET_REPO = os.environ.get("DATASET_REPO", "")

# Optional override for the input feature dimension, e.g. "165"; None when unset.
EXPECTED_IN = os.environ.get("EXPECTED_IN")

# Number of labeled nodes to show as samples.
TOPK_DEMO = int(os.environ.get("TOPK_DEMO", "5000"))
17
+
18
+ # ====== MODEL DEF ======
19
class GraphSAGEClassifier(nn.Module):
    """Stack of ``SAGEConv`` layers producing per-node logits.

    Args:
        in_channels: Size of each input node feature vector.
        hidden: Width of every intermediate layer.
        num_layers: Total number of ``SAGEConv`` layers (must be >= 1). The
            default of 3 yields ``in -> hidden -> hidden -> out``, i.e. the
            parameter names ``layers.0`` .. ``layers.2``.
        dropout: Dropout probability applied after every non-final layer.
        out_channels: Number of logits per node.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden=128, num_layers=3, dropout=0.3, out_channels=1):
        super().__init__()
        if num_layers < 1:
            raise ValueError("num_layers must be >= 1")
        self.out_channels = out_channels
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()
        # Layer widths: input, (num_layers - 1) hidden widths, output.
        dims = [in_channels] + [hidden] * (num_layers - 1) + [out_channels]
        self.layers = nn.ModuleList(
            [SAGEConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x, edge_index):
        """Return logits for every node.

        The result has shape ``(num_nodes,)`` when ``out_channels == 1`` and
        ``(num_nodes, out_channels)`` otherwise, so that a softmax over the
        last dimension is meaningful in the multi-class case.
        """
        h = x
        last = len(self.layers) - 1
        for i, conv in enumerate(self.layers):
            h = conv(h, edge_index)
            # Activation and dropout are skipped on the final (logit) layer.
            if i < last:
                h = self.dropout(self.activation(h))
        if self.out_channels == 1:
            return h.view(-1)
        return h
37
+
38
+ # ====== LOAD CHECKPOINT (auto-pick latest .pt) ======
39
def _strip_wrapper_prefixes(key):
    """Remove any leading ``model.`` / ``module.`` wrappers from a state-dict key."""
    changed = True
    while changed:
        changed = False
        for prefix in ("model.", "module."):
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key


def load_checkpoint(repo_id):
    """Download the last ``checkpoints/*.pt`` file of a Hub model repo.

    Args:
        repo_id: Hub model repository id.

    Returns:
        A tuple ``(state_dict, meta, remote_path)``: the state dict with
        normalised keys, the checkpoint's ``"meta"`` entry (``{}`` when
        absent) and the path of the chosen file inside the repo.

    Raises:
        RuntimeError: If the repo has no ``checkpoints/...pt`` file or the
            loaded object is not a dict.
    """
    api = HfApi()
    files = api.list_repo_files(repo_id, repo_type="model")
    pt_files = [f for f in files if f.startswith("checkpoints/") and f.endswith(".pt")]
    if not pt_files:
        raise RuntimeError("Không tìm thấy file .pt trong repo model.")
    # "Latest" means the lexicographically greatest path.
    remote = max(pt_files)
    path = hf_hub_download(repo_id=repo_id, filename=remote, repo_type="model")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only point MODEL_REPO at trusted repos, or switch to
    # weights_only=True once the checkpoint contents are known to allow it.
    ckpt = torch.load(path, map_location="cpu")
    if not isinstance(ckpt, dict):
        raise RuntimeError(
            f"Unsupported checkpoint format in {remote}: expected a dict, got {type(ckpt).__name__}."
        )
    # Accept both {"state_dict": ..., "meta": ...} and a bare state dict.
    sd = ckpt.get("state_dict", ckpt)
    # Rename custom layer names to the SAGEConv parameter names.
    # NOTE(review): PyG's SAGEConv applies lin_l to the aggregated neighbours
    # and lin_r to the root node, so linear_self -> lin_l may be swapped.
    # Confirm against the training code before changing it.
    remapped = {}
    for k, v in sd.items():
        k2 = k.replace("linear_self", "lin_l").replace("linear_neigh", "lin_r")
        # Wrappers may be stacked in any order (e.g. "module.model.").
        remapped[_strip_wrapper_prefixes(k2)] = v
    meta = ckpt.get("meta", {})
    return remapped, meta, remote
58
+
59
+ STATE_DICT, META, REMOTE_PATH = load_checkpoint(MODEL_REPO)
60
+
61
+ # ====== DATA LOADING (CSV) ======
62
def ensure_local_dataset():
    """Make sure the three Elliptic CSVs exist directly under ``data/``.

    Files already present in ``data/`` are used as they are. Missing files
    are downloaded from ``DATASET_REPO`` (path ``elliptic/<name>`` in that
    dataset repo) when it is set.

    Returns:
        True when all three CSVs are available in ``data/``, False when some
        are missing and no ``DATASET_REPO`` is configured.
    """
    os.makedirs("data", exist_ok=True)
    names = (
        "elliptic_txs_features.csv",
        "elliptic_txs_classes.csv",
        "elliptic_txs_edgelist.csv",
    )
    missing = [n for n in names if not os.path.exists(os.path.join("data", n))]
    if not missing:
        return True
    if not DATASET_REPO:
        return False  # no CSVs yet and nowhere to fetch them from

    for name in missing:
        # With local_dir="data" the file lands at data/elliptic/<name>, i.e.
        # the repo-relative path is preserved. `local_dir_use_symlinks` is
        # deprecated in recent huggingface_hub releases and is not passed.
        fetched = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=f"elliptic/{name}",
            repo_type="dataset",
            local_dir="data",
        )
        # The loader reads data/<name>, so move the file to that location.
        target = os.path.join("data", name)
        if os.path.abspath(fetched) != os.path.abspath(target):
            os.replace(fetched, target)
    return True
83
+
84
def load_graph_from_csv():
    """Build the transaction graph from the three CSVs under ``data/``.

    Returns:
        A tuple ``(data, id2idx, tx_ids)``: a ``Data`` object holding ``x``,
        ``edge_index`` and ``y``; a dict mapping txId strings to row
        indices; and the array of txId strings in row order.
    """
    # The features file has no header row; column 0 is the txId.
    feats = pd.read_csv("data/elliptic_txs_features.csv", header=None)
    tx_ids = feats.iloc[:, 0].astype(str).values
    X = torch.tensor(feats.iloc[:, 1:].astype(np.float32).values, dtype=torch.float)
    id2idx = {tid: i for i, tid in enumerate(tx_ids)}

    # Labels: class 2 -> 1, class 1 -> 0, anything else -> -1.
    # NOTE(review): the public Elliptic description lists class 1 as illicit
    # and class 2 as licit, so y == 1 would mean licit here. Confirm against
    # the training code. y is not read by the inference code in this file.
    classes = pd.read_csv("data/elliptic_txs_classes.csv")
    node_pos = classes["txId"].astype(str).map(id2idx)
    raw_cls = pd.to_numeric(classes["class"].astype(str).str.strip(), errors="coerce").to_numpy()
    labels = np.where(raw_cls == 2, 1, np.where(raw_cls == 1, 0, -1)).astype(np.int64)
    in_graph = node_pos.notna().to_numpy()
    y_np = np.full(X.shape[0], -1, dtype=np.int64)
    y_np[node_pos.to_numpy(dtype=np.float64)[in_graph].astype(np.int64)] = labels[in_graph]
    y = torch.from_numpy(y_np)

    # Edges: keep only pairs whose endpoints both appear in the features file.
    edges = pd.read_csv("data/elliptic_txs_edgelist.csv")
    src = edges["txId1"].astype(str).map(id2idx).to_numpy(dtype=np.float64)
    dst = edges["txId2"].astype(str).map(id2idx).to_numpy(dtype=np.float64)
    keep = ~(np.isnan(src) | np.isnan(dst))
    edge_index = torch.from_numpy(np.stack([src[keep], dst[keep]]).astype(np.int64))

    data = Data(x=X, edge_index=edge_index, y=y)
    return data, id2idx, tx_ids
108
+
109
# Load the graph once at import time when the CSVs are available; otherwise
# keep None placeholders so the txId tab can report the missing dataset.
HAVE_CSV = ensure_local_dataset()
if HAVE_CSV:
    DATA_OBJ, ID2IDX, TXIDS = load_graph_from_csv()
else:
    DATA_OBJ = ID2IDX = TXIDS = None
113
+
114
+ # ====== BUILD MODEL INSTANCE ======
115
# Models already built, keyed by input width, so weights are loaded only once.
_MODEL_CACHE = {}


def build_model(in_channels: int):
    """Return an eval-mode ``GraphSAGEClassifier`` loaded with ``STATE_DICT``.

    Hyper-parameters come from ``META`` (``hidden_dim``, ``out_channels``,
    ``dropout``) and fall back to 128 / 1 / 0.3. The instance is cached per
    ``in_channels`` so repeated predictions do not rebuild the network.

    Args:
        in_channels: Width of the node feature vectors.

    Returns:
        A tuple ``(model, out_channels)``.
    """
    import warnings

    out_ch = int(META.get("out_channels", 1))
    cached = _MODEL_CACHE.get(in_channels)
    if cached is not None:
        return cached, out_ch

    hidden = int(META.get("hidden_dim", 128))
    dropout = float(META.get("dropout", 0.3))
    model = GraphSAGEClassifier(in_channels=in_channels, hidden=hidden, dropout=dropout, out_channels=out_ch)
    # strict=False tolerates small naming differences between PyG versions,
    # but a silent mismatch would leave layers randomly initialised, so any
    # mismatch is reported as a warning.
    result = model.load_state_dict(STATE_DICT, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            "Checkpoint/model key mismatch - "
            f"missing: {list(result.missing_keys)}, unexpected: {list(result.unexpected_keys)}"
        )
    model.eval()
    _MODEL_CACHE[in_channels] = model
    return model, out_ch
124
+
125
+ # ====== INFERENCE HELPERS ======
126
@torch.no_grad()
def predict_from_features(vec: np.ndarray):
    """Score a single feature vector as a one-node graph.

    Args:
        vec: Feature values for one transaction. Its length must equal
            ``EXPECTED_IN`` when that environment override is set.

    Returns:
        The sigmoid of the single logit when the model has one output
        channel, otherwise the softmax probability of class index 1.

    Raises:
        ValueError: If the vector is empty or its length differs from
            ``EXPECTED_IN``.
    """
    vec = np.asarray(vec, dtype=np.float32).reshape(-1)
    in_channels = int(EXPECTED_IN) if EXPECTED_IN else len(vec)
    if len(vec) == 0:
        raise ValueError("Feature vector is empty.")
    if len(vec) != in_channels:
        raise ValueError(f"Expected {in_channels} features, got {len(vec)}.")
    x = torch.tensor(vec, dtype=torch.float).view(1, in_channels)
    # A single self-loop on node 0 gives the convolution one edge to aggregate.
    edge_index = torch.tensor([[0], [0]], dtype=torch.long)
    model, out_ch = build_model(in_channels)
    # Normalise to (1, out_ch) whether or not the model flattens its output.
    logits = model(x, edge_index).reshape(1, -1)
    if out_ch == 1:
        return torch.sigmoid(logits)[0, 0].item()
    return torch.softmax(logits, dim=-1)[0, 1].item()
140
+
141
@torch.no_grad()
def predict_from_txid(txid: str, hops: int = 2):
    """Score one transaction using its k-hop subgraph from the loaded dataset.

    Args:
        txid: Transaction id as a string; must be a key of ``ID2IDX``.
        hops: Number of hops passed to ``k_hop_subgraph``.

    Returns:
        A tuple ``(probability, num_subgraph_nodes, num_subgraph_edges)``.

    Raises:
        RuntimeError: If the dataset is not loaded or ``txid`` is unknown.
    """
    if not HAVE_CSV or DATA_OBJ is None:
        raise RuntimeError("Chưa có dataset CSV. Hãy upload 3 file vào thư mục /data của Space hoặc đặt DATASET_REPO.")
    if txid not in ID2IDX:
        raise RuntimeError(f"txId {txid} không có trong dataset.")
    center = ID2IDX[txid]
    # Pass num_nodes explicitly. Otherwise it is inferred from the largest
    # index in edge_index, and a node beyond that index raises IndexError.
    subset, sub_edge_index, mapping, _ = k_hop_subgraph(
        center,
        hops,
        DATA_OBJ.edge_index,
        relabel_nodes=True,
        num_nodes=DATA_OBJ.x.shape[0],
    )
    sub_x = DATA_OBJ.x[subset]
    model, out_ch = build_model(sub_x.shape[1])
    # Normalise to (num_nodes, out_ch) before selecting the centre node's row.
    logits = model(sub_x, sub_edge_index).reshape(sub_x.shape[0], -1)
    center_logits = logits[mapping].reshape(-1)
    if out_ch == 1:
        prob = torch.sigmoid(center_logits[0]).item()
    else:
        prob = torch.softmax(center_logits, dim=-1)[1].item()
    return float(prob), int(sub_x.shape[0]), int(sub_edge_index.shape[1])
159
+
160
+ # ====== GRADIO UI ======
161
def ui_predict_txid(txid, hops):
    """Gradio callback: format the txId prediction, or the error text on failure."""
    try:
        outcome = predict_from_txid(txid.strip(), int(hops))
    except Exception as exc:
        # UI boundary: show any failure to the user instead of raising.
        return f"Error: {exc}"
    prob, n_nodes, n_edges = outcome
    report = (
        f"txId: {txid}",
        f"Hops: {hops}",
        f"Subgraph nodes: {n_nodes}, edges: {n_edges}",
        f"Illicit probability: {prob:.4f}",
    )
    return "\n".join(report)
167
+
168
def ui_predict_vector(feat_str):
    """Gradio callback: parse a pasted feature vector and report its score.

    Values may be separated by commas, spaces, tabs or newlines. Any failure
    (empty input, unparsable number, model error) is returned as an
    ``"Error: ..."`` string rather than raised.
    """
    try:
        # Treat commas as whitespace so one split handles every separator.
        tokens = feat_str.replace(",", " ").split()
        if not tokens:
            raise ValueError("empty feature vector")
        vec = np.array([float(tok) for tok in tokens], dtype=np.float32)
        prob = predict_from_features(vec)
        return f"Vector len: {len(vec)}\nIllicit probability: {prob:.4f}"
    except Exception as e:
        # UI boundary: show any failure to the user instead of raising.
        return f"Error: {e}"
176
+
177
# Two-tab interface: prediction by txId (needs the dataset) and prediction
# from a pasted feature vector.
with gr.Blocks(title="Elliptic Fraud Demo (GraphSAGE)") as demo:
    gr.Markdown(f"### Elliptic Fraud Demo • Model: `{MODEL_REPO}`\nCheckpoint: `{REMOTE_PATH}`")

    with gr.Tab("Predict by txId (needs dataset)"):
        gr.Markdown("Cần 3 CSV Elliptic trong thư mục **/data** của Space, hoặc set env `DATASET_REPO` để auto tải.")
        txid_in = gr.Textbox(label="txId")
        hops_in = gr.Slider(minimum=1, maximum=3, value=2, step=1, label="K-hop subgraph")
        out_tx = gr.Textbox(label="Result")
        btn_tx = gr.Button("Predict")
        btn_tx.click(fn=ui_predict_txid, inputs=[txid_in, hops_in], outputs=out_tx)

    with gr.Tab("Predict by feature vector"):
        gr.Markdown("Dán vector đặc trưng (comma-separated). Nếu khác số chiều khi train, set env `EXPECTED_IN`.")
        feat_in = gr.Textbox(label="feature1, feature2, ...")
        out_vec = gr.Textbox(label="Result")
        btn_vec = gr.Button("Predict")
        btn_vec.click(fn=ui_predict_vector, inputs=[feat_in], outputs=out_vec)


if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
+ torch==2.4.1+cpu
3
+ torchvision==0.19.1+cpu
4
+ torchaudio==2.4.1+cpu
5
+
6
+ # PyG wheels (must match torch 2.4.1 + cpu)
7
+ -f https://data.pyg.org/whl/torch-2.4.1+cpu.html
8
+ pyg_lib
9
+ torch_scatter
10
+ torch_sparse
11
+ torch_cluster
12
+ torch_spline_conv
13
+ torch-geometric==2.6.1
14
+
15
+ gradio>=4.44
16
+ huggingface_hub>=0.24.6
17
+ pandas
18
+ numpy