# DemoGraph / app.py
# uyen1109's picture
# Upload app.py
# 4051c4e verified
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, BatchNorm
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import os
import json
from huggingface_hub import hf_hub_download
# ==========================================
# 1. MODEL DEFINITION
# ==========================================
class SAGE(nn.Module):
    """Two-layer GraphSAGE node classifier.

    Each of the two layers applies SAGEConv -> BatchNorm -> ReLU -> Dropout;
    a final linear head maps the hidden features to ``out_dim`` logits.

    Args:
        in_dim: Number of input features per node.
        h: Hidden width shared by both convolution layers.
        out_dim: Number of output logits per node.
        p_drop: Dropout probability applied after each layer.
    """

    def __init__(self, in_dim, h=128, out_dim=2, p_drop=0.3):
        super().__init__()
        # Attribute names double as state_dict keys, so they must not change.
        self.conv1 = SAGEConv(in_dim, h, bias=True)
        self.bn1 = BatchNorm(h)
        self.conv2 = SAGEConv(h, h, bias=True)
        self.bn2 = BatchNorm(h)
        self.head = nn.Linear(h, out_dim)
        self.drop = nn.Dropout(p_drop)

    def forward(self, x, edge_index):
        """Return per-node logits for features ``x`` and edges ``edge_index``."""
        hidden = x
        layers = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in layers:
            hidden = conv(hidden, edge_index)
            hidden = self.drop(F.relu(norm(hidden)))
        return self.head(hidden)
# ==========================================
# 2. RESOURCE MANAGEMENT
# ==========================================
# Hugging Face Hub repository the resource files are downloaded from.
REPO_ID = "uyen1109/eth-fraud-gnn-uyenuyen-v3"

# Optional access token taken from the environment; None when HF_TOKEN is unset.
TOKEN = os.environ.get("HF_TOKEN")

# Process-wide mutable state shared by the loader, the handlers and the UI.
GLOBAL_DATA = dict(
    model=None,                # loaded model instance, or None
    df_scores=pd.DataFrame(),  # node-score table
    df_edges=pd.DataFrame(),   # edge list used for graph drawing
    feature_cols=[],           # feature names shown as inputs on the Predict tab
    status="Initializing...",  # human-readable loading log
)
def smart_load_file(filename):
    """Download ``filename`` from ``REPO_ID`` and return its local path.

    The file is looked up at the repository root first and then under
    ``hf_export/``. Each location is tried with the configured ``TOKEN`` and,
    if that fails, once more with authentication disabled.

    Args:
        filename: Name of the file inside the repository.

    Returns:
        The local path returned by ``hf_hub_download``, or ``None`` when every
        attempt failed (the collected error messages are printed).
    """
    # Root first, then the export sub-folder.
    candidates = (filename, f"hf_export/{filename}")
    # ``token=False`` is what disables authentication in huggingface_hub.
    # ``token=None`` falls back to the token stored on the machine or in the
    # environment, which would simply repeat the authenticated attempt.
    attempts = (("Token", TOKEN), ("No-Token", False))
    errs = []
    for path_in_repo in candidates:
        for label, token in attempts:
            try:
                return hf_hub_download(repo_id=REPO_ID, filename=path_in_repo, token=token)
            except Exception as exc:  # best-effort: record and try the next option
                errs.append(f"{label} fail {path_in_repo}: {exc}")
    print(f"⚠️ Failed to load {filename}. Details: {errs}")
    return None
def _load_scores(logs):
    """Download the node-score CSV and store it, indexed by address, in GLOBAL_DATA."""
    path = smart_load_file("node_scores_with_labels.csv")
    if not path:
        logs.append("❌ 'node_scores_with_labels.csv' download failed.")
        return
    try:
        df = pd.read_csv(path)
        # Find the address column case-insensitively; default to the first column.
        cols_lower = [c.lower() for c in df.columns]
        if "address" in cols_lower:
            addr_col = df.columns[cols_lower.index("address")]
        else:
            addr_col = df.columns[0]
        df[addr_col] = df[addr_col].astype(str).str.lower().str.strip()
        df.set_index(addr_col, inplace=True)
        GLOBAL_DATA["df_scores"] = df
        logs.append(f"✅ Loaded Scores: {len(df)} rows.")
    except Exception as e:
        logs.append(f"❌ Error parsing scores csv: {e}")


def _load_edges(logs):
    """Download the edge list and store it, with normalized endpoints, in GLOBAL_DATA."""
    path = smart_load_file("edges_all.csv")
    if not path:
        logs.append("⚠️ 'edges_all.csv' download failed.")
        return
    try:
        edges = pd.read_csv(path, usecols=["src", "dst", "edge_type"])
        # Light normalization so endpoints compare equal to the lookup address.
        for col in ("src", "dst"):
            edges[col] = edges[col].astype(str).str.lower().str.strip()
        # Publish only after normalization succeeded, so no half-processed
        # frame is left behind on error.
        GLOBAL_DATA["df_edges"] = edges
        logs.append("✅ Loaded Edges.")
    except Exception as e:
        logs.append(f"⚠️ Edge parsing error: {e}")


def _load_feature_names(detected_dim, logs):
    """Return exactly ``detected_dim`` feature names, padding or truncating the JSON list."""
    placeholder = [f"Feature_{i}" for i in range(detected_dim)]
    cols_path = smart_load_file("feature_columns.json")
    if not cols_path:
        logs.append("⚠️ Using Dummy Feature Names (json missing)")
        return placeholder
    try:
        with open(cols_path, 'r', encoding="utf-8") as f:
            cols = json.load(f)
        # Match the number of names to the model's input width.
        kept = cols[:detected_dim]
        padding = [f"Feat_{i}" for i in range(len(kept), detected_dim)]
        return kept + padding
    except Exception as e:
        # A bad names file must not be reported as a model failure: the model
        # is usable, only the input labels fall back to placeholders.
        logs.append(f"⚠️ Using Dummy Feature Names (json unreadable: {e})")
        return placeholder


def _load_model(logs):
    """Download the weights, build the SAGE model and resolve the feature names."""
    model_path = smart_load_file("pytorch_model.bin")
    if not model_path:
        logs.append("❌ 'pytorch_model.bin' NOT FOUND. Please upload it to Repo Root.")
        # Fallback feature list so the Predict tab still renders its inputs.
        GLOBAL_DATA["feature_cols"] = [
            'out_deg', 'in_deg', 'eth_out_sum', 'eth_in_sum',
            'unique_dst_cnt', 'unique_src_cnt', 'first_seen_ts', 'last_seen_ts',
            'pr', 'clust_coef', 'betw', 'feat_11', 'feat_12', 'feat_13', 'feat_14'
        ]
        return
    try:
        # NOTE(security): torch.load unpickles the downloaded file. On PyTorch
        # versions that support it, weights_only=True would restrict loading
        # to tensor data - TODO confirm the minimum torch version and enable.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        # Infer the input width from the first layer's weight matrix.
        detected_dim = state_dict['conv1.lin_l.weight'].shape[1]
        model = SAGE(in_dim=detected_dim, h=128, out_dim=2, p_drop=0.3)
        model.load_state_dict(state_dict)
        model.eval()
    except Exception as e:
        logs.append(f"❌ Model Init Error: {e}")
        return
    GLOBAL_DATA["model"] = model
    logs.append(f"✅ Model Loaded (Input Dim: {detected_dim})")
    GLOBAL_DATA["feature_cols"] = _load_feature_names(detected_dim, logs)


def load_resources():
    """Download scores, edges and the model, filling ``GLOBAL_DATA``.

    Every step is best-effort: a failure is recorded in the status log and
    the remaining steps still run. The joined log is stored in
    ``GLOBAL_DATA["status"]`` and printed.
    """
    logs = []
    print("⏳ Starting Resource Loading...")
    _load_scores(logs)
    _load_edges(logs)
    _load_model(logs)
    GLOBAL_DATA["status"] = "\n".join(logs)
    print(GLOBAL_DATA["status"])
# Populate GLOBAL_DATA once at import time; the UI construction further down
# reads GLOBAL_DATA["feature_cols"] while building its inputs.
load_resources()
# ==========================================
# 3. PROCESSING LOGIC
# ==========================================
def draw_graph(address):
    """Plot up to 20 edges touching ``address`` as a directed ego graph.

    Args:
        address: Node identifier compared against the ``src`` and ``dst``
            columns of ``GLOBAL_DATA["df_edges"]``.

    Returns:
        A matplotlib ``Figure``, or ``None`` when no edge data is loaded or
        no edge touches ``address``.
    """
    df = GLOBAL_DATA["df_edges"]
    if df.empty:
        return None
    subset = df[(df["src"] == address) | (df["dst"] == address)].head(20)
    if subset.empty:
        return None
    G = nx.from_pandas_edgelist(subset, "src", "dst", edge_attr="edge_type", create_using=nx.DiGraph())
    # Draw on an explicit figure/axes instead of pyplot's implicit "current
    # figure", so concurrent requests cannot draw onto each other's plot.
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        pos = nx.spring_layout(G, k=0.9, seed=42)
        # Highlight the queried address; neighbours share one colour.
        node_colors = ["#FF4500" if n == address else "#1E90FF" for n in G.nodes()]
        nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors, node_size=200, alpha=0.9)
        nx.draw_networkx_edges(G, pos, ax=ax, alpha=0.3, arrowstyle='->')
        nx.draw_networkx_labels(G, pos, ax=ax, labels={n: n[:4] for n in G.nodes()}, font_size=8)
        ax.set_title(f"Ego Graph: {address[:6]}...")
        ax.axis('off')
    finally:
        # Unregister the figure from pyplot so a long-running server does not
        # accumulate one open figure per request; the returned Figure object
        # itself stays usable for rendering.
        plt.close(fig)
    return fig
def lookup_handler(address):
    """Look up an address in the score table and describe the result.

    Args:
        address: Address text entered by the user.

    Returns:
        A ``(markdown, figure)`` tuple. The figure is ``None`` when the
        input is empty or the address is not in the score table.
    """
    if not address:
        return "Please enter an address.", None

    raw_addr = str(address).strip().lower()
    scores = GLOBAL_DATA["df_scores"]

    match = None
    if not scores.empty:
        # Try the address as typed first, then with "0x" removed.
        for key in (raw_addr, raw_addr.replace("0x", "")):
            if key in scores.index:
                match = scores.loc[key]
                break

    if match is None:
        return (
            f"### ❌ Not Found\nAddress `{raw_addr}` not in database.\nStatus Logs:\n{GLOBAL_DATA['status']}",
            None
        )

    # A duplicated index label yields a DataFrame; use its first row.
    if isinstance(match, pd.DataFrame):
        match = match.iloc[0]
    fallback = match.get("susp", 0.0)
    score = float(match.get("prob_criminal", fallback))
    verdict = 'CRITICAL 🔴' if score > 0.5 else 'BENIGN 🟢'
    message = f"### ✅ Found\n**Score:** {score:.4f}\n**Status:** {verdict}"
    return (message, draw_graph(raw_addr))
def predict_handler(*features):
    """Score a single feature vector with the loaded model.

    Args:
        *features: One numeric value per model input feature.

    Returns:
        Markdown text with the class-1 probability as a percentage, or an
        error message when the model is not loaded or inference fails.
    """
    model = GLOBAL_DATA["model"]
    if model is None:
        return "❌ Model Error: pytorch_model.bin missing.\nPlease check 'System Status' below."
    try:
        row = [float(value) for value in features]
        x = torch.tensor([row], dtype=torch.float)
        # A single node with no edges: edge_index is an empty 2 x 0 tensor.
        edge_index = torch.tensor([[], []], dtype=torch.long)
        with torch.no_grad():
            logits = model(x, edge_index)
            probs = torch.softmax(logits, dim=1)
        prob = probs[0][1].item()
        return f"### Result\n**Fraud Probability:** {prob*100:.2f}%"
    except Exception as e:
        return f"Error: {e}"
# ==========================================
# 4. UI SETUP
# ==========================================
# Build the Gradio interface: a collapsible status panel plus a "Lookup" tab
# (address search) and a "Predict" tab (manual feature entry).
with gr.Blocks(title="ETH Fraud GNN") as demo:
    gr.Markdown("# 🕵️‍♀️ Ethereum Fraud Inspector")
    with gr.Accordion("System Status (Click to Debug)", open=False):
        # A callable is passed instead of a fixed string so the text is read
        # from GLOBAL_DATA when evaluated - presumably on each page load;
        # confirm against the Gradio version in use.
        gr.Markdown(lambda: GLOBAL_DATA["status"]) # Dynamic update
    with gr.Tabs():
        with gr.TabItem("🔍 Lookup"):
            with gr.Row():
                inp = gr.Textbox(label="Address")
                btn = gr.Button("Search", variant="primary")
            with gr.Row():
                out_txt = gr.Markdown()
                out_plt = gr.Plot()
            # lookup_handler returns (markdown, figure), matching the two outputs.
            btn.click(lookup_handler, inputs=inp, outputs=[out_txt, out_plt])
        with gr.TabItem("🧠 Predict"):
            gr.Markdown("### Inductive Prediction (Simulated)")
            # Render one numeric input per loaded feature column.
            cols = GLOBAL_DATA["feature_cols"]
            inputs = []
            with gr.Row():
                # Split the inputs across two columns: even indices left, odd right.
                c1, c2 = gr.Column(), gr.Column()
                for i, c in enumerate(cols):
                    with (c1 if i % 2 == 0 else c2):
                        inputs.append(gr.Number(label=c, value=0.0))
            btn2 = gr.Button("Predict", variant="primary")
            out2 = gr.Markdown()
            # The Number components are passed positionally to predict_handler(*features).
            btn2.click(predict_handler, inputs=inputs, outputs=out2)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()