| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch_geometric.nn import SAGEConv, BatchNorm |
| import pandas as pd |
| import numpy as np |
| import networkx as nx |
| import matplotlib.pyplot as plt |
| import os |
| import json |
| from huggingface_hub import hf_hub_download |
|
|
| |
| |
| |
class SAGE(nn.Module):
    """Two-layer GraphSAGE network with a linear classification head.

    Each layer applies ``SAGEConv`` -> ``BatchNorm`` -> ReLU -> dropout; the
    head then maps the hidden representation to ``out_dim`` logits.

    Args:
        in_dim: Number of input features per node.
        h: Hidden width shared by both convolution layers.
        out_dim: Number of logits produced by the head.
        p_drop: Dropout probability applied after each activation.
    """

    def __init__(self, in_dim, h=128, out_dim=2, p_drop=0.3):
        super().__init__()
        # The attribute names are also the state-dict key prefixes (the loader
        # reads 'conv1.lin_l.weight'), so they must not be renamed.
        self.conv1 = SAGEConv(in_dim, h, bias=True)
        self.bn1 = BatchNorm(h)
        self.conv2 = SAGEConv(h, h, bias=True)
        self.bn2 = BatchNorm(h)
        self.head = nn.Linear(h, out_dim)
        self.drop = nn.Dropout(p_drop)

    def forward(self, x, edge_index):
        """Return the head's logits for features ``x`` on graph ``edge_index``."""
        hidden = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = self.drop(F.relu(norm(conv(hidden, edge_index))))
        return self.head(hidden)
|
|
| |
| |
| |
# Hugging Face Hub repository that hosts the model and data files.
REPO_ID = "uyen1109/eth-fraud-gnn-uyenuyen-v3"
# Optional access token taken from the environment; None when unset.
TOKEN = os.environ.get("HF_TOKEN")


# Process-wide state filled in by the resource loader and read by the
# UI handlers. Starts out empty so the handlers can detect missing data.
GLOBAL_DATA = dict(
    model=None,
    df_scores=pd.DataFrame(),
    df_edges=pd.DataFrame(),
    feature_cols=[],
    status="Initializing...",
)
|
|
def smart_load_file(filename):
    """Download ``filename`` from the ``REPO_ID`` Hub repo and return its local path.

    The repo root is tried first, then the ``hf_export/`` sub-folder. Each
    location is attempted with ``TOKEN`` and, if that fails, anonymously.

    Args:
        filename: File name relative to the repo root (or to ``hf_export/``).

    Returns:
        The local path returned by ``hf_hub_download``, or ``None`` when every
        attempt failed (the collected errors are printed).
    """
    candidates = (filename, f"hf_export/{filename}")
    # In huggingface_hub, ``token=None`` means "use whatever credentials are
    # configured on this machine", so it is not an anonymous request and would
    # simply repeat the first attempt. ``token=False`` disables authentication,
    # which is what the "no token" retry is meant to do (e.g. when the
    # configured token is invalid but the repo is public).
    attempts = (("Token", TOKEN), ("No-Token", False))

    errs = []
    for repo_path in candidates:
        for label, token in attempts:
            try:
                return hf_hub_download(repo_id=REPO_ID, filename=repo_path, token=token)
            except Exception as exc:
                # Deliberately broad: any failure (missing file, auth, network)
                # just moves on to the next attempt; details are reported below.
                errs.append(f"{label} fail {repo_path}: {exc}")

    print(f"⚠️ Failed to load {filename}. Details: {errs}")
    return None
|
|
def _fit_feature_names(names, dim):
    """Return exactly ``dim`` names: truncate extras, pad with ``Feat_<i>``."""
    fitted = [str(name) for name in names][:dim]
    return fitted + [f"Feat_{i}" for i in range(len(fitted), dim)]


def _load_scores(logs):
    """Load the score table into ``GLOBAL_DATA["df_scores"]``, indexed by address."""
    path = smart_load_file("node_scores_with_labels.csv")
    if not path:
        logs.append("❌ 'node_scores_with_labels.csv' download failed.")
        return
    try:
        df = pd.read_csv(path)
        # Use the first column named "address" (case-insensitive); otherwise
        # fall back to the first column of the file.
        addr_col = next(
            (c for c in df.columns if c.lower() == "address"),
            df.columns[0],
        )
        df[addr_col] = df[addr_col].astype(str).str.lower().str.strip()
        df.set_index(addr_col, inplace=True)
        GLOBAL_DATA["df_scores"] = df
        logs.append(f"✅ Loaded Scores: {len(df)} rows.")
    except Exception as e:
        logs.append(f"❌ Error parsing scores csv: {e}")


def _load_edges(logs):
    """Load the edge list into ``GLOBAL_DATA["df_edges"]`` with normalised endpoints."""
    path = smart_load_file("edges_all.csv")
    if not path:
        logs.append("⚠️ 'edges_all.csv' download failed.")
        return
    try:
        edges = pd.read_csv(path, usecols=["src", "dst", "edge_type"])
        for col in ("src", "dst"):
            edges[col] = edges[col].astype(str).str.lower().str.strip()
        # Publish only after normalisation succeeded, so handlers never see a
        # half-processed table.
        GLOBAL_DATA["df_edges"] = edges
        logs.append("✅ Loaded Edges.")
    except Exception as e:
        logs.append(f"⚠️ Edge parsing error: {e}")


def _load_feature_names(dim, logs):
    """Return ``dim`` feature names from ``feature_columns.json`` or dummy names."""
    cols_path = smart_load_file("feature_columns.json")
    if not cols_path:
        logs.append("⚠️ Using Dummy Feature Names (json missing)")
    else:
        try:
            with open(cols_path, "r", encoding="utf-8") as f:
                names = json.load(f)
            if isinstance(names, list):
                return _fit_feature_names(names, dim)
            logs.append("⚠️ Using Dummy Feature Names (json is not a list)")
        except (OSError, ValueError) as e:
            # ValueError covers both JSON syntax errors and decoding errors.
            logs.append(f"⚠️ Using Dummy Feature Names (json unreadable: {e})")
    return [f"Feature_{i}" for i in range(dim)]


def _load_model(logs):
    """Load the checkpoint into ``GLOBAL_DATA["model"]`` and set the feature names."""
    model_path = smart_load_file("pytorch_model.bin")
    if not model_path:
        logs.append("❌ 'pytorch_model.bin' NOT FOUND. Please upload it to Repo Root.")
        # Placeholder names so the Predict tab still has input fields.
        GLOBAL_DATA["feature_cols"] = [
            'out_deg', 'in_deg', 'eth_out_sum', 'eth_in_sum',
            'unique_dst_cnt', 'unique_src_cnt', 'first_seen_ts', 'last_seen_ts',
            'pr', 'clust_coef', 'betw', 'feat_11', 'feat_12', 'feat_13', 'feat_14'
        ]
        return

    try:
        # NOTE(review): torch.load unpickles the file. On torch versions where
        # ``weights_only`` defaults to False, a tampered checkpoint could run
        # arbitrary code; consider ``weights_only=True`` if the minimum
        # supported torch version accepts it.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        # The input width is read off the first layer's weight so the model
        # always matches the checkpoint.
        detected_dim = state_dict['conv1.lin_l.weight'].shape[1]

        model = SAGE(in_dim=detected_dim, h=128, out_dim=2, p_drop=0.3)
        model.load_state_dict(state_dict)
        model.eval()
    except Exception as e:
        logs.append(f"❌ Model Init Error: {e}")
        return

    GLOBAL_DATA["model"] = model
    logs.append(f"✅ Model Loaded (Input Dim: {detected_dim})")
    # Resolved separately from model init: a bad feature_columns.json must not
    # be reported as a model failure nor leave the Predict tab without inputs.
    GLOBAL_DATA["feature_cols"] = _load_feature_names(detected_dim, logs)


def load_resources():
    """Download scores, edges and the model, and fill ``GLOBAL_DATA``.

    Every step is best-effort: a failure is recorded in the status log and the
    remaining steps still run. The joined log is stored in
    ``GLOBAL_DATA["status"]`` and printed.
    """
    logs = []
    print("⏳ Starting Resource Loading...")

    _load_scores(logs)
    _load_edges(logs)
    _load_model(logs)

    GLOBAL_DATA["status"] = "\n".join(logs)
    print(GLOBAL_DATA["status"])


load_resources()
|
|
| |
| |
| |
|
|
def draw_graph(address):
    """Draw the ego graph of ``address`` from the loaded edge table.

    Only the first 20 edges that touch ``address`` (as source or destination)
    are drawn; the queried node is highlighted in a different colour.

    Args:
        address: Node identifier, compared as-is against the ``src`` and
            ``dst`` columns.

    Returns:
        A matplotlib ``Figure``, or ``None`` when no edge table is loaded or
        no edge touches ``address``.
    """
    edges = GLOBAL_DATA["df_edges"]
    if edges.empty:
        return None

    touches_address = (edges["src"] == address) | (edges["dst"] == address)
    subset = edges[touches_address].head(20)
    if subset.empty:
        return None

    G = nx.from_pandas_edgelist(subset, "src", "dst", edge_attr="edge_type", create_using=nx.DiGraph())
    pos = nx.spring_layout(G, k=0.9, seed=42)  # fixed seed -> stable layout
    node_colors = ["#FF4500" if n == address else "#1E90FF" for n in G.nodes()]

    # Draw on an explicit Axes instead of pyplot's implicit "current figure",
    # which is shared state when several requests are handled at once.
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors, node_size=200, alpha=0.9)
        nx.draw_networkx_edges(G, pos, ax=ax, alpha=0.3, arrowstyle='->')
        nx.draw_networkx_labels(G, pos, ax=ax, labels={n: n[:4] for n in G.nodes()}, font_size=8)

        ax.set_title(f"Ego Graph: {address[:6]}...")
        ax.axis('off')
    finally:
        # pyplot keeps every figure it creates alive until it is closed, so a
        # long-running server would otherwise leak one figure per lookup. The
        # Figure object itself stays usable by the caller after closing.
        plt.close(fig)
    return fig
|
|
def lookup_handler(address):
    """Look up the stored score for ``address`` and draw its ego graph.

    The address is trimmed and lower-cased, then searched in the score table
    both as typed and, if it starts with ``0x``, without that prefix.

    Args:
        address: Address text entered by the user.

    Returns:
        A ``(markdown, figure)`` tuple. ``figure`` is ``None`` when the input
        is empty, the address is unknown, or no graph could be drawn.
    """
    if not address:
        return "Please enter an address.", None

    raw_addr = str(address).strip().lower()
    if not raw_addr:
        # Whitespace-only input is as empty as no input at all.
        return "Please enter an address.", None

    df = GLOBAL_DATA["df_scores"]

    # Candidate keys in priority order. Only a leading "0x" is stripped;
    # removing the substring anywhere could match a different key.
    candidates = [raw_addr]
    if raw_addr.startswith("0x"):
        candidates.append(raw_addr[2:])

    key = None
    if not df.empty:
        key = next((c for c in candidates if c in df.index), None)

    if key is None:
        return (
            f"### ❌ Not Found\nAddress `{raw_addr}` not in database.\nStatus Logs:\n{GLOBAL_DATA['status']}",
            None
        )

    found = df.loc[key]
    if isinstance(found, pd.DataFrame):
        # Duplicate index entries: use the first row.
        found = found.iloc[0]
    score = float(found.get("prob_criminal", found.get("susp", 0.0)))
    status = 'CRITICAL 🔴' if score > 0.5 else 'BENIGN 🟢'

    # Draw the graph for the key that actually matched; fall back to the text
    # as typed in case the edge table uses the prefixed form.
    figure = draw_graph(key)
    if figure is None and key != raw_addr:
        figure = draw_graph(raw_addr)

    return (
        f"### ✅ Found\n**Score:** {score:.4f}\n**Status:** {status}",
        figure
    )
|
|
def predict_handler(*features):
    """Run the loaded model on one feature vector and report class-1 probability.

    The features are treated as a single isolated node: the model is called
    with an edge index that contains no edges.

    Args:
        *features: Feature values, each convertible with ``float()``.

    Returns:
        A Markdown string with the probability as a percentage, or an error
        message when the model is missing or inference raises.
    """
    net = GLOBAL_DATA["model"]
    if net is None:
        return "❌ Model Error: pytorch_model.bin missing.\nPlease check 'System Status' below."

    try:
        row = [float(value) for value in features]
        x = torch.tensor([row], dtype=torch.float)
        no_edges = torch.tensor([[], []], dtype=torch.long)
        with torch.no_grad():
            logits = net(x, no_edges)
            prob = torch.softmax(logits, dim=1)[0][1].item()
        return f"### Result\n**Fraud Probability:** {prob*100:.2f}%"
    except Exception as e:
        return f"Error: {e}"
|
|
| |
| |
| |
# Gradio UI: a status panel plus two tabs (address lookup and manual prediction).
with gr.Blocks(title="ETH Fraud GNN") as demo:
    gr.Markdown("# 🕵️♀️ Ethereum Fraud Inspector")

    # Collapsed by default; the callable supplies the loader's status log.
    with gr.Accordion("System Status (Click to Debug)", open=False):
        gr.Markdown(lambda: GLOBAL_DATA["status"])

    with gr.Tabs():
        # Tab 1: look up an address in the score table and show its ego graph.
        with gr.TabItem("🔍 Lookup"):
            with gr.Row():
                inp = gr.Textbox(label="Address")
                btn = gr.Button("Search", variant="primary")
            with gr.Row():
                out_txt = gr.Markdown()
                out_plt = gr.Plot()
            btn.click(lookup_handler, inputs=inp, outputs=[out_txt, out_plt])

        # Tab 2: one numeric input per feature name, fed to the model.
        with gr.TabItem("🧠 Predict"):
            gr.Markdown("### Inductive Prediction (Simulated)")

            cols = GLOBAL_DATA["feature_cols"]
            inputs = []
            with gr.Row():
                c1, c2 = gr.Column(), gr.Column()
                # Alternate fields between the two columns: even indices go
                # left, odd indices go right.
                for i, c in enumerate(cols):
                    with (c2 if i % 2 else c1):
                        inputs.append(gr.Number(label=c, value=0.0))

            btn2 = gr.Button("Predict", variant="primary")
            out2 = gr.Markdown()
            btn2.click(predict_handler, inputs=inputs, outputs=out2)


if __name__ == "__main__":
    demo.launch()