# elliptic / app.py — Gradio Space entry point: GraphSAGE fraud-detection demo on the Elliptic dataset
import os, json
import gradio as gr
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from huggingface_hub import HfApi, hf_hub_download
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph
# ====== CONFIG ======
MODEL_REPO = os.getenv("MODEL_REPO", "uyen1109/elliptic")  # HF model repo holding the checkpoint
DATASET_REPO = os.getenv("DATASET_REPO", "")  # (optional) Elliptic dataset repo used to auto-download the 3 CSVs
EXPECTED_IN = os.getenv("EXPECTED_IN")  # (optional) force the input feature dimension, e.g. "165"
TOPK_DEMO = int(os.getenv("TOPK_DEMO", "5000"))  # number of labeled nodes to show as a sample (currently unused here)
# ====== MODEL DEF ======
class GraphSAGEClassifier(nn.Module):
    """GraphSAGE node classifier.

    Stacks ``num_layers`` SAGEConv layers (in -> hidden -> ... -> out) with
    ReLU + dropout between hidden layers.  The last layer emits raw logits
    and the output is flattened to 1-D, so with ``out_channels=1`` the
    result is one logit per node (apply sigmoid outside the model).

    Args:
        in_channels: input feature dimension per node.
        hidden: hidden layer width.
        num_layers: total number of SAGEConv layers (>= 1).
        dropout: dropout probability after each hidden activation.
        out_channels: logits per node (1 = binary head).
    """
    def __init__(self, in_channels, hidden=128, num_layers=3, dropout=0.3, out_channels=1):
        super().__init__()
        if num_layers < 1:
            raise ValueError("num_layers must be >= 1")
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()
        # Honor num_layers (the original accepted it but hard-coded 3 layers).
        # num_layers=3 reproduces the original module names and shapes exactly,
        # so existing checkpoints load unchanged.
        if num_layers == 1:
            convs = [SAGEConv(in_channels, out_channels)]
        else:
            convs = [SAGEConv(in_channels, hidden)]
            convs += [SAGEConv(hidden, hidden) for _ in range(num_layers - 2)]
            convs.append(SAGEConv(hidden, out_channels))
        self.layers = nn.ModuleList(convs)

    def forward(self, x, edge_index):
        """Return flattened logits of shape (num_nodes * out_channels,)."""
        h = x
        for i, conv in enumerate(self.layers):
            h = conv(h, edge_index)
            if i < len(self.layers) - 1:  # no activation/dropout after the output layer
                h = self.activation(h)
                h = self.dropout(h)
        return h.view(-1)
# ====== LOAD CHECKPOINT (auto-pick latest .pt) ======
def load_checkpoint(repo_id):
    """Download the newest ``checkpoints/*.pt`` file from *repo_id* and
    return its normalized weights.

    "Newest" is the lexicographically last path, which matches timestamped
    or zero-padded checkpoint names.

    Returns:
        (state_dict, meta, remote_path): ``state_dict`` with keys remapped
        to current PyG SAGEConv naming, the optional ``meta`` dict stored
        in the checkpoint, and the remote file path that was loaded.

    Raises:
        RuntimeError: if the repo contains no ``checkpoints/*.pt`` file.
    """
    api = HfApi()
    files = api.list_repo_files(repo_id, repo_type="model")
    pt_files = [f for f in files if f.startswith("checkpoints/") and f.endswith(".pt")]
    if not pt_files:
        raise RuntimeError("Không tìm thấy file .pt trong repo model.")
    remote = sorted(pt_files)[-1]
    path = hf_hub_download(repo_id=repo_id, filename=remote, repo_type="model")
    # NOTE(security): torch.load unpickles arbitrary objects -- only load
    # checkpoints from repos you trust (weights_only=True would be safer
    # where the stored meta dict allows it).
    ckpt = torch.load(path, map_location="cpu")
    sd = ckpt.get("state_dict", ckpt)
    # Remap legacy SAGEConv keys (linear_self/linear_neigh -> lin_l/lin_r)
    # and strip wrapper prefixes.  Prefixes may be stacked (e.g.
    # "module.model.layers.0..." from DataParallel around a wrapper module),
    # so strip repeatedly until none remains -- the original stripped each
    # prefix at most once, leaving "module.model.*" keys half-stripped.
    remapped = {}
    for k, v in sd.items():
        k2 = k.replace("linear_self", "lin_l").replace("linear_neigh", "lin_r")
        stripped = True
        while stripped:
            stripped = False
            for prefix in ("module.", "model."):
                if k2.startswith(prefix):
                    k2 = k2[len(prefix):]
                    stripped = True
        remapped[k2] = v
    meta = ckpt.get("meta", {})
    return remapped, meta, remote
# Load once at import time so the Space fails fast if the checkpoint is missing.
STATE_DICT, META, REMOTE_PATH = load_checkpoint(MODEL_REPO)
# ====== DATA LOADING (CSV) ======
def ensure_local_dataset():
    """Make sure the three Elliptic CSVs exist under ./data.

    Preference order:
      1. If the Space repo already ships data/*.csv, use those.
      2. Otherwise, if DATASET_REPO is set, download the three CSVs from it.

    Returns:
        True when all three CSVs are available locally, False otherwise.
    """
    os.makedirs("data", exist_ok=True)
    required = (
        "elliptic_txs_features.csv",
        "elliptic_txs_classes.csv",
        "elliptic_txs_edgelist.csv",
    )

    def _have_all():
        return all(os.path.exists(os.path.join("data", f)) for f in required)

    if _have_all():
        return True
    if DATASET_REPO:
        for f in required:
            # Files live under "elliptic/" inside the dataset repo, so with
            # local_dir="data" they land at data/elliptic/<f>.  The loader
            # reads data/<f>, so move each downloaded file into place (the
            # original left them under data/elliptic/ where the loader
            # never found them).
            src = hf_hub_download(
                repo_id=DATASET_REPO,
                filename=f"elliptic/{f}",
                repo_type="dataset",
                local_dir="data",
                local_dir_use_symlinks=False,  # deprecated upstream, kept for older hub versions
            )
            dst = os.path.join("data", f)
            if os.path.abspath(src) != os.path.abspath(dst):
                os.replace(src, dst)
        return _have_all()
    return False  # CSVs not available
def load_graph_from_csv():
    """Build a PyG ``Data`` graph from the three Elliptic CSVs in ./data.

    Returns:
        (data, id2idx, tx_ids): ``data`` holds node features ``x``,
        ``edge_index`` and labels ``y`` (1 for class "2", 0 for class "1",
        -1 for unknown), ``id2idx`` maps txId -> row index, and ``tx_ids``
        is the array of txIds in row order.
    """
    feats = pd.read_csv("data/elliptic_txs_features.csv", header=None)
    tx_ids = feats.iloc[:, 0].astype(str).values
    # Everything after txId: (time_step + features...) -- usually 166 columns.
    arr = feats.iloc[:, 1:].astype(np.float32).values
    # Infer the expected feature dimension from the checkpoint, unless the
    # EXPECTED_IN env var overrides it.
    exp_in = os.getenv("EXPECTED_IN")
    if exp_in:
        expected_in = int(exp_in)
    else:
        expected_in = None
        for k, v in STATE_DICT.items():
            if k.endswith("lin_l.weight") or k.endswith("linear_self.weight"):
                expected_in = int(v.shape[1])
                break
    if expected_in is not None:
        if arr.shape[1] == expected_in + 1:
            # Off by exactly one column: treat the first as time_step, drop it.
            arr = arr[:, 1:]
        elif arr.shape[1] != expected_in:
            # Rare fallback: truncate or zero-pad to match the checkpoint.
            if arr.shape[1] > expected_in:
                arr = arr[:, :expected_in]
            else:
                arr = np.pad(arr, ((0, 0), (0, expected_in - arr.shape[1])), "constant")
    X = torch.tensor(arr, dtype=torch.float)
    id2idx = {tid: i for i, tid in enumerate(tx_ids)}
    # Labels: "2" -> 1, "1" -> 0, anything else (e.g. "unknown") -> -1.
    # Vectorized; the original iterrows() loop over ~200k rows was very slow.
    classes = pd.read_csv("data/elliptic_txs_classes.csv")
    cls_ids = classes["txId"].astype(str)
    cls_lbl = (
        classes["class"].astype(str).str.strip().map({"2": 1, "1": 0}).fillna(-1).astype(np.int64)
    )
    y = torch.full((X.shape[0],), -1, dtype=torch.long)
    pos = cls_ids.map(id2idx)          # row index per labeled txId (NaN if unknown txId)
    known = pos.notna()
    y[torch.as_tensor(pos[known].astype(np.int64).values)] = torch.as_tensor(cls_lbl[known].values)
    # Edges: keep only pairs whose endpoints both exist, remapped to row indices.
    edges = pd.read_csv("data/elliptic_txs_edgelist.csv")
    src = edges["txId1"].astype(str).map(id2idx)
    dst = edges["txId2"].astype(str).map(id2idx)
    ok = src.notna() & dst.notna()
    edge_index = torch.tensor(
        np.stack([src[ok].astype(np.int64).values, dst[ok].astype(np.int64).values]),
        dtype=torch.long,
    )
    data = Data(x=X, edge_index=edge_index, y=y)
    return data, id2idx, tx_ids
# Load the dataset once at startup; txId-based prediction stays disabled
# (DATA_OBJ is None) when the CSVs are not available.
HAVE_CSV = ensure_local_dataset()
DATA_OBJ, ID2IDX, TXIDS = (None, None, None)
if HAVE_CSV:
    DATA_OBJ, ID2IDX, TXIDS = load_graph_from_csv()
# ====== BUILD MODEL INSTANCE ======
def build_model(in_channels: int):
    """Instantiate the classifier from checkpoint meta and load the weights.

    Hyper-parameters come from the global ``META`` dict with the training
    defaults as fallback.

    Returns:
        (model, out_channels) with the model already in eval mode.
    """
    out_ch = int(META.get("out_channels", 1))
    model = GraphSAGEClassifier(
        in_channels=in_channels,
        hidden=int(META.get("hidden_dim", 128)),
        dropout=float(META.get("dropout", 0.3)),
        out_channels=out_ch,
    )
    # strict=False keeps loading robust to slight PyG version differences.
    model.load_state_dict(STATE_DICT, strict=False)
    model.eval()
    return model, out_ch
# ====== INFERENCE HELPERS ======
@torch.no_grad()
def predict_from_features(vec: np.ndarray):
    """Score a single transaction from a raw feature vector.

    Args:
        vec: 1-D feature array; its length must equal the model input
             dimension (EXPECTED_IN when set, otherwise len(vec) is used).

    Returns:
        Illicit probability in [0, 1]: sigmoid for a 1-logit head,
        softmax class-1 probability otherwise.

    Raises:
        ValueError: when EXPECTED_IN is set and len(vec) does not match it.
    """
    in_channels = int(EXPECTED_IN) if EXPECTED_IN else len(vec)
    if len(vec) != in_channels:
        # Fail with a clear message instead of a cryptic tensor-view error.
        raise ValueError(f"Expected {in_channels} features, got {len(vec)}")
    x = torch.tensor(vec, dtype=torch.float).view(1, in_channels)
    # Single self-loop so SAGEConv has at least one edge to aggregate over
    # (the original appended a no-op .repeat(1, 1)).
    edge_index = torch.tensor([[0], [0]], dtype=torch.long)
    model, out_ch = build_model(in_channels)
    logits = model(x, edge_index)
    if out_ch == 1:
        return torch.sigmoid(logits)[0].item()
    # Multi-logit head: the model flattens its output, so reshape to
    # (1, out_ch) before softmax -- the original indexed the flat 1-D
    # tensor with [0, 1], which raised IndexError.
    return torch.softmax(logits.view(1, -1), dim=-1)[0, 1].item()
@torch.no_grad()
def predict_from_txid(txid: str, hops: int = 2):
    """Score a known transaction by running the model on its k-hop subgraph.

    Args:
        txid: transaction id present in the loaded dataset.
        hops: neighbourhood radius for the subgraph extraction.

    Returns:
        (illicit_probability, subgraph_node_count, subgraph_edge_count).

    Raises:
        RuntimeError: when the dataset CSVs are missing or txid is unknown.
    """
    if not HAVE_CSV or DATA_OBJ is None:
        raise RuntimeError("Chưa có dataset CSV. Hãy upload 3 file vào thư mục /data của Space hoặc đặt DATASET_REPO.")
    if txid not in ID2IDX:
        raise RuntimeError(f"txId {txid} không có trong dataset.")
    node_idx = ID2IDX[txid]
    # Extract the k-hop neighbourhood around the target node, relabelled to
    # local indices; `mapping` locates the target inside the subgraph.
    subset, sub_edges, mapping, _ = k_hop_subgraph(node_idx, hops, DATA_OBJ.edge_index, relabel_nodes=True)
    features = DATA_OBJ.x[subset]
    model, out_ch = build_model(features.shape[1])
    logits = model(features, sub_edges)
    target_logit = logits[mapping]
    if out_ch == 1:
        prob = torch.sigmoid(target_logit).item()
    else:
        prob = torch.softmax(target_logit.view(1, -1), dim=-1)[0, 1].item()
    return float(prob), int(features.shape[0]), int(sub_edges.shape[1])
# ====== GRADIO UI ======
def ui_predict_txid(txid, hops):
    """Gradio handler: predict by txId and return a readable report or error."""
    try:
        prob, n_nodes, n_edges = predict_from_txid(txid.strip(), int(hops))
    except Exception as e:
        return f"Error: {e}"
    return f"txId: {txid}\nHops: {hops}\nSubgraph nodes: {n_nodes}, edges: {n_edges}\nIllicit probability: {prob:.4f}"
def ui_predict_vector(feat_str):
    """Gradio handler: parse a pasted feature vector and score it.

    Accepts numbers separated by commas and/or whitespace (the original
    only split on commas, so plain space-separated input failed to parse).
    Returns a readable report, or an "Error: ..." string on failure.
    """
    try:
        tokens = feat_str.replace(",", " ").split()
        vec = np.array([float(t) for t in tokens], dtype=np.float32)
        prob = predict_from_features(vec)
        return f"Vector len: {len(vec)}\nIllicit probability: {prob:.4f}"
    except Exception as e:
        return f"Error: {e}"
# Two-tab demo UI: predict by txId (requires the local dataset) or by
# pasting a raw feature vector.
with gr.Blocks(title="Elliptic Fraud Demo (GraphSAGE)") as demo:
    gr.Markdown(f"### Elliptic Fraud Demo • Model: `{MODEL_REPO}`\nCheckpoint: `{REMOTE_PATH}`")
    with gr.Tab("Predict by txId (needs dataset)"):
        # Needs the three Elliptic CSVs under /data (or DATASET_REPO set).
        gr.Markdown("Cần 3 CSV Elliptic trong thư mục **/data** của Space, hoặc set env `DATASET_REPO` để auto tải.")
        txid_in = gr.Textbox(label="txId")
        hops_in = gr.Slider(1, 3, value=2, step=1, label="K-hop subgraph")
        out_tx = gr.Textbox(label="Result")
        btn_tx = gr.Button("Predict")
        btn_tx.click(fn=ui_predict_txid, inputs=[txid_in, hops_in], outputs=out_tx)
    with gr.Tab("Predict by feature vector"):
        # Paste a comma-separated vector; EXPECTED_IN overrides the dimension.
        gr.Markdown("Dán vector đặc trưng (comma-separated). Nếu khác số chiều khi train, set env `EXPECTED_IN`.")
        feat_in = gr.Textbox(label="feature1, feature2, ...")
        out_vec = gr.Textbox(label="Result")
        btn_vec = gr.Button("Predict")
        btn_vec.click(fn=ui_predict_vector, inputs=[feat_in], outputs=out_vec)

if __name__ == "__main__":
    demo.launch()