Upload app.py
Browse files
app.py
CHANGED
|
@@ -1,177 +1,164 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
-
import
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
import os
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
|
| 8 |
-
# --- 1.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
REPO_ID = "uyen1109/eth-fraud-gnn-uyenuyen-v3"
|
| 10 |
TOKEN = os.getenv("HF_TOKEN")
|
| 11 |
|
| 12 |
-
print("
|
| 13 |
|
| 14 |
-
#
|
| 15 |
-
df_scores = pd.DataFrame()
|
| 16 |
-
df_edges = pd.DataFrame()
|
| 17 |
-
|
| 18 |
-
# 1.1 Tải file điểm số (Scores)
|
| 19 |
try:
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
df_scores.set_index("address", inplace=True)
|
| 25 |
-
print(f"✅ Loaded {len(df_scores)} node scores.")
|
| 26 |
except Exception as e:
|
| 27 |
-
print(f"⚠️
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
try:
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
except Exception as e:
|
| 38 |
-
print(f"
|
|
|
|
| 39 |
|
| 40 |
-
# ---
|
| 41 |
|
| 42 |
-
def
|
| 43 |
-
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
label_map = {0: "Benign (0)", 1: "Criminal (1)"}
|
| 52 |
-
label_val = row.get("label", float('nan'))
|
| 53 |
-
label_str = label_map.get(label_val, "Unknown")
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
if is_in_edges:
|
| 62 |
-
return None, "Unknown", "UNSCORED_BUT_FOUND"
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
# Lọc các giao dịch liên quan
|
| 71 |
-
subset = df_edges[(df_edges["src"] == address) | (df_edges["dst"] == address)].head(30)
|
| 72 |
-
|
| 73 |
-
if subset.empty:
|
| 74 |
-
fig, ax = plt.subplots(figsize=(6, 6))
|
| 75 |
-
ax.text(0.5, 0.5, "No transactions found", ha='center')
|
| 76 |
-
ax.axis('off')
|
| 77 |
-
return fig
|
| 78 |
-
|
| 79 |
-
# Tạo đồ thị
|
| 80 |
-
G = nx.from_pandas_edgelist(subset, source="src", target="dst", edge_attr="edge_type", create_using=nx.DiGraph())
|
| 81 |
-
|
| 82 |
-
pos = nx.spring_layout(G, seed=42, k=0.8)
|
| 83 |
-
|
| 84 |
-
plt.figure(figsize=(8, 8))
|
| 85 |
-
|
| 86 |
-
# Tô màu node
|
| 87 |
-
node_colors = []
|
| 88 |
-
node_sizes = []
|
| 89 |
-
for node in G.nodes():
|
| 90 |
-
if node == address:
|
| 91 |
-
node_colors.append("#FF4500") # Target: OrangeRed
|
| 92 |
-
node_sizes.append(400)
|
| 93 |
-
else:
|
| 94 |
-
node_colors.append("#1E90FF") # Neighbor: DodgerBlue
|
| 95 |
-
node_sizes.append(150)
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
# --- 3. CORE LOGIC ---
|
| 110 |
-
|
| 111 |
-
def analyze_wallet(address):
|
| 112 |
-
if not address:
|
| 113 |
-
return "Please enter an address.", "N/A", None
|
| 114 |
-
|
| 115 |
-
address = address.strip()
|
| 116 |
-
score, label_str, status = get_node_info(address)
|
| 117 |
-
|
| 118 |
-
# Xử lý kết quả hiển thị
|
| 119 |
-
if status == "NOT_FOUND":
|
| 120 |
-
return f"❌ Address {address} not found in any transaction data.", "Unknown", None
|
| 121 |
-
|
| 122 |
-
plot = draw_ego_graph(address)
|
| 123 |
-
|
| 124 |
-
if status == "UNSCORED_BUT_FOUND":
|
| 125 |
-
return (
|
| 126 |
-
f"⚠️ **Not Scored via GNN**\n\n"
|
| 127 |
-
f"This address exists in the transaction list (`edges_all.csv`) but was filtered out during the GNN training graph construction (likely an isolated node or missing features).\n"
|
| 128 |
-
f"Therefore, the model did not assign a risk score.",
|
| 129 |
-
"Not Scored",
|
| 130 |
-
plot
|
| 131 |
)
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
elif score > 0.2: risk_level = "MEDIUM 🟡"
|
| 138 |
-
|
| 139 |
-
result_text = (
|
| 140 |
-
f"### 🎯 Risk Score: {score:.4f}\n"
|
| 141 |
-
f"**Label:** {label_str}\n"
|
| 142 |
-
f"**Status:** Analyzed by GraphSAGE\n"
|
| 143 |
-
)
|
| 144 |
-
|
| 145 |
-
return result_text, risk_level, plot
|
| 146 |
|
| 147 |
-
# --- 4. UI ---
|
| 148 |
|
| 149 |
-
with gr.Blocks(title="
|
| 150 |
-
gr.Markdown("#
|
| 151 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
|
|
|
| 153 |
with gr.Row():
|
| 154 |
-
with gr.Column(
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
| 175 |
|
| 176 |
if __name__ == "__main__":
|
| 177 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch_geometric.nn import SAGEConv, BatchNorm
|
| 6 |
+
import json
|
| 7 |
import pandas as pd
|
| 8 |
+
import numpy as np
|
|
|
|
| 9 |
import os
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
|
| 12 |
+
# --- 1. ĐỊNH NGHĨA MODEL ARCHITECTURE ---
|
| 13 |
+
# Phải khớp chính xác với kiến trúc đã dùng để train trong notebook (Cell 16, trang 25)
|
| 14 |
+
class SAGE(nn.Module):
|
| 15 |
+
def __init__(self, in_dim, h=128, out_dim=2, p_drop=0.3):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.conv1 = SAGEConv(in_dim, h, bias=True)
|
| 18 |
+
self.bn1 = BatchNorm(h)
|
| 19 |
+
self.conv2 = SAGEConv(h, h, bias=True)
|
| 20 |
+
self.bn2 = BatchNorm(h)
|
| 21 |
+
self.head = nn.Linear(h, out_dim)
|
| 22 |
+
self.drop = nn.Dropout(p_drop)
|
| 23 |
+
|
| 24 |
+
def forward(self, x, edge_index):
|
| 25 |
+
# Layer 1
|
| 26 |
+
x = self.conv1(x, edge_index)
|
| 27 |
+
x = self.bn1(x)
|
| 28 |
+
x = F.relu(x)
|
| 29 |
+
x = self.drop(x)
|
| 30 |
+
|
| 31 |
+
# Layer 2
|
| 32 |
+
x = self.conv2(x, edge_index)
|
| 33 |
+
x = self.bn2(x)
|
| 34 |
+
x = F.relu(x)
|
| 35 |
+
x = self.drop(x)
|
| 36 |
+
|
| 37 |
+
# Output
|
| 38 |
+
return self.head(x)
|
| 39 |
+
|
| 40 |
+
# --- 2. SETUP & LOAD MODEL ---
|
| 41 |
REPO_ID = "uyen1109/eth-fraud-gnn-uyenuyen-v3"
|
| 42 |
TOKEN = os.getenv("HF_TOKEN")
|
| 43 |
|
| 44 |
+
print("⏳ Downloading model artifacts...")
|
| 45 |
|
| 46 |
+
# 2.1 Tải danh sách Features (để biết thứ tự nhập liệu)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
try:
|
| 48 |
+
cols_path = hf_hub_download(repo_id=REPO_ID, filename="hf_export/feature_columns.json", token=TOKEN)
|
| 49 |
+
with open(cols_path, 'r') as f:
|
| 50 |
+
FEATURE_COLS = json.load(f)
|
| 51 |
+
print(f"✅ Loaded {len(FEATURE_COLS)} feature columns.")
|
|
|
|
|
|
|
| 52 |
except Exception as e:
|
| 53 |
+
print(f"⚠️ Could not load feature_columns.json. Using default fallback list. Error: {e}")
|
| 54 |
+
# Fallback danh sách feature dựa trên notebook (Cell 8, 11, 12)
|
| 55 |
+
FEATURE_COLS = [
|
| 56 |
+
'out_deg', 'in_deg', 'eth_out_sum', 'eth_in_sum',
|
| 57 |
+
'unique_dst_cnt', 'unique_src_cnt', 'first_seen_ts', 'last_seen_ts',
|
| 58 |
+
'pr', 'clust_coef', 'betw'
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
# 2.2 Tải trọng số Model (pytorch_model.bin)
|
| 62 |
try:
|
| 63 |
+
model_path = hf_hub_download(repo_id=REPO_ID, filename="hf_export/pytorch_model.bin", token=TOKEN)
|
| 64 |
+
|
| 65 |
+
# Khởi tạo model
|
| 66 |
+
# in_dim phải bằng số lượng feature
|
| 67 |
+
model = SAGE(in_dim=len(FEATURE_COLS), h=128, out_dim=2, p_drop=0.3)
|
| 68 |
+
|
| 69 |
+
# Load weights (map_location='cpu' để chạy trên không gian không có GPU)
|
| 70 |
+
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
| 71 |
+
model.load_state_dict(state_dict)
|
| 72 |
+
model.eval() # Chuyển sang chế độ Inference (tắt Dropout, v.v.)
|
| 73 |
+
print("✅ Model loaded successfully!")
|
| 74 |
except Exception as e:
|
| 75 |
+
print(f"❌ Critical Error loading model: {e}")
|
| 76 |
+
model = None
|
| 77 |
|
| 78 |
+
# --- 3. INFERENCE FUNCTION ---
|
| 79 |
|
| 80 |
+
def predict_custom_node(*features):
|
| 81 |
+
if model is None:
|
| 82 |
+
return "Model not loaded correctly.", "Error"
|
| 83 |
|
| 84 |
+
try:
|
| 85 |
+
# 1. Chuyển list features nhập từ UI thành Tensor
|
| 86 |
+
# features là một tuple các giá trị
|
| 87 |
+
feat_values = [float(f) for f in features]
|
| 88 |
+
x = torch.tensor([feat_values], dtype=torch.float) # Shape: [1, num_features]
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
+
# 2. Tạo cạnh giả (Dummy Edge Index)
|
| 91 |
+
# Vì GraphSAGE cần edge_index để chạy, nhưng với 1 node đơn lẻ (Inductive trên node mới),
|
| 92 |
+
# ta không có thông tin hàng xóm.
|
| 93 |
+
# Ta truyền vào edge_index rỗng. SAGEConv sẽ hoạt động dựa trên feature của chính node đó (Self-loop logic).
|
| 94 |
+
edge_index = torch.tensor([[], []], dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
+
# 3. Forward pass
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
logits = model(x, edge_index)
|
| 99 |
+
probs = torch.softmax(logits, dim=1)
|
| 100 |
+
prob_criminal = probs[0][1].item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
+
# 4. Xử lý kết quả
|
| 103 |
+
label = "CRIMINAL 🔴" if prob_criminal > 0.5 else "BENIGN 🟢"
|
| 104 |
+
score_percent = f"{prob_criminal * 100:.2f}%"
|
| 105 |
+
|
| 106 |
+
explanation = (
|
| 107 |
+
f"### Prediction Result\n"
|
| 108 |
+
f"- **Probability of Fraud:** {score_percent}\n"
|
| 109 |
+
f"- **Verdict:** {label}\n\n"
|
| 110 |
+
f"### Debug Info\n"
|
| 111 |
+
f"- Input Shape: {x.shape}\n"
|
| 112 |
+
f"- Raw Logits: {logits.numpy()}\n"
|
| 113 |
+
f"- Model Architecture: GraphSAGE (2 layers, 128 hidden units)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
)
|
| 115 |
|
| 116 |
+
return explanation, label
|
| 117 |
+
|
| 118 |
+
except Exception as e:
|
| 119 |
+
return f"Error during inference: {str(e)}", "Error"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
+
# --- 4. GRADIO UI ---
|
| 122 |
|
| 123 |
+
with gr.Blocks(title="Inductive Fraud Prediction") as demo:
|
| 124 |
+
gr.Markdown("# 🧠 Inductive GraphSAGE Prediction")
|
| 125 |
+
gr.Markdown(
|
| 126 |
+
"""
|
| 127 |
+
Demo này thể hiện tính **Inductive** của mô hình: Bạn có thể nhập thông số của một ví **hoàn toàn mới** (không có trong tập dữ liệu cũ) và mô hình sẽ dự đoán dựa trên những gì nó đã học được.
|
| 128 |
+
|
| 129 |
+
*Lưu ý: Vì nhập liệu thủ công, ta đang mô phỏng node này như một node cô lập (không có thông tin hàng xóm).*
|
| 130 |
+
"""
|
| 131 |
+
)
|
| 132 |
|
| 133 |
+
inputs = []
|
| 134 |
with gr.Row():
|
| 135 |
+
with gr.Column():
|
| 136 |
+
gr.Markdown("### 1. Nhập Features (Đặc trưng) của Ví")
|
| 137 |
+
# Tự động tạo ô nhập liệu dựa trên danh sách FEATURE_COLS
|
| 138 |
+
for col in FEATURE_COLS:
|
| 139 |
+
# Gợi ý giá trị mặc định để dễ test
|
| 140 |
+
default_val = 0.0
|
| 141 |
+
if "ts" in col: default_val = 1600000000 # Timestamp
|
| 142 |
+
|
| 143 |
+
inp = gr.Number(label=col, value=default_val)
|
| 144 |
+
inputs.append(inp)
|
| 145 |
+
|
| 146 |
+
with gr.Column():
|
| 147 |
+
gr.Markdown("### 2. Kết quả Dự đoán")
|
| 148 |
+
btn_predict = gr.Button("Run Inference", variant="primary")
|
| 149 |
+
lbl_result = gr.Label(label="Prediction")
|
| 150 |
+
out_log = gr.Markdown()
|
| 151 |
|
| 152 |
+
# Nút Clear để reset
|
| 153 |
+
btn_clear = gr.Button("Clear Inputs")
|
| 154 |
+
|
| 155 |
+
# Sự kiện click
|
| 156 |
+
btn_predict.click(fn=predict_custom_node, inputs=inputs, outputs=[out_log, lbl_result])
|
| 157 |
+
|
| 158 |
+
# Reset tất cả về 0
|
| 159 |
+
def clear_fn():
|
| 160 |
+
return [0.0] * len(inputs)
|
| 161 |
+
btn_clear.click(fn=clear_fn, inputs=None, outputs=inputs)
|
| 162 |
|
| 163 |
if __name__ == "__main__":
|
| 164 |
demo.launch()
|