uyen1109 commited on
Commit
466a439
·
verified ·
1 Parent(s): d6f6917

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -148
app.py CHANGED
@@ -1,177 +1,164 @@
1
  import gradio as gr
 
 
 
 
 
2
  import pandas as pd
3
- import networkx as nx
4
- import matplotlib.pyplot as plt
5
  import os
6
  from huggingface_hub import hf_hub_download
7
 
8
- # --- 1. SETUP & DATA LOADING ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  REPO_ID = "uyen1109/eth-fraud-gnn-uyenuyen-v3"
10
  TOKEN = os.getenv("HF_TOKEN")
11
 
12
- print("Loading data from Hugging Face Hub...")
13
 
14
- # Global variables
15
- df_scores = pd.DataFrame()
16
- df_edges = pd.DataFrame()
17
-
18
- # 1.1 Tải file điểm số (Scores)
19
  try:
20
- scores_path = hf_hub_download(repo_id=REPO_ID, filename="scores/node_scores_with_labels.csv", repo_type="model", token=TOKEN)
21
- df_scores = pd.read_csv(scores_path)
22
- if "address" in df_scores.columns:
23
- df_scores["address"] = df_scores["address"].astype(str).str.lower().str.strip()
24
- df_scores.set_index("address", inplace=True)
25
- print(f"✅ Loaded {len(df_scores)} node scores.")
26
  except Exception as e:
27
- print(f"⚠️ Error loading scores: {e}")
28
-
29
- # 1.2 Tải file cạnh (Edges)
 
 
 
 
 
 
30
  try:
31
- edges_path = hf_hub_download(repo_id=REPO_ID, filename="graph/edges_all.csv", repo_type="model", token=TOKEN)
32
- # Chỉ tải các cột cần thiết để tiết kiệm RAM
33
- df_edges = pd.read_csv(edges_path, usecols=["src", "dst", "edge_type"])
34
- df_edges["src"] = df_edges["src"].astype(str).str.lower().str.strip()
35
- df_edges["dst"] = df_edges["dst"].astype(str).str.lower().str.strip()
36
- print(f"✅ Loaded {len(df_edges)} edges.")
 
 
 
 
 
37
  except Exception as e:
38
- print(f"⚠️ Error loading edges: {e}")
 
39
 
40
- # --- 2. HELPER FUNCTIONS ---
41
 
42
- def get_node_info(address):
43
- """Lấy thông tin điểm số và nhãn"""
44
- address = address.lower().strip()
45
 
46
- # Case 1: Có trong bảng điểm (Model đã chấm điểm)
47
- if address in df_scores.index:
48
- row = df_scores.loc[address]
49
- score = float(row.get("prob_criminal", row.get("susp", 0.0)))
50
-
51
- label_map = {0: "Benign (0)", 1: "Criminal (1)"}
52
- label_val = row.get("label", float('nan'))
53
- label_str = label_map.get(label_val, "Unknown")
54
 
55
- return score, label_str, "SCORED"
56
-
57
- # Case 2: Không điểm, kiểm tra xem có trong giao dịch không
58
- # (Lưu ý: Kiểm tra này hơi chậm nếu df lớn, nhưng chấp nhận được cho demo)
59
- is_in_edges = ((df_edges["src"] == address) | (df_edges["dst"] == address)).any()
60
-
61
- if is_in_edges:
62
- return None, "Unknown", "UNSCORED_BUT_FOUND"
63
 
64
- return None, "Unknown", "NOT_FOUND"
65
-
66
- def draw_ego_graph(address):
67
- """Vẽ đồ thị 1-hop"""
68
- address = address.lower().strip()
69
-
70
- # Lọc các giao dịch liên quan
71
- subset = df_edges[(df_edges["src"] == address) | (df_edges["dst"] == address)].head(30)
72
-
73
- if subset.empty:
74
- fig, ax = plt.subplots(figsize=(6, 6))
75
- ax.text(0.5, 0.5, "No transactions found", ha='center')
76
- ax.axis('off')
77
- return fig
78
-
79
- # Tạo đồ thị
80
- G = nx.from_pandas_edgelist(subset, source="src", target="dst", edge_attr="edge_type", create_using=nx.DiGraph())
81
-
82
- pos = nx.spring_layout(G, seed=42, k=0.8)
83
-
84
- plt.figure(figsize=(8, 8))
85
-
86
- # Tô màu node
87
- node_colors = []
88
- node_sizes = []
89
- for node in G.nodes():
90
- if node == address:
91
- node_colors.append("#FF4500") # Target: OrangeRed
92
- node_sizes.append(400)
93
- else:
94
- node_colors.append("#1E90FF") # Neighbor: DodgerBlue
95
- node_sizes.append(150)
96
 
97
- nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes, alpha=0.9)
98
- nx.draw_networkx_edges(G, pos, alpha=0.4, arrowstyle='->', arrowsize=15, edge_color="gray")
99
-
100
- # Label ngắn gọn
101
- labels = {n: (n[:5] + ".." if n != address else "TARGET") for n in G.nodes()}
102
- nx.draw_networkx_labels(G, pos, labels=labels, font_size=9, font_color="black")
103
-
104
- plt.title(f"Ego Graph: {address[:10]}...\n(Showing top {len(subset)} txs)")
105
- plt.axis('off')
106
-
107
- return plt.gcf()
108
-
109
- # --- 3. CORE LOGIC ---
110
-
111
- def analyze_wallet(address):
112
- if not address:
113
- return "Please enter an address.", "N/A", None
114
-
115
- address = address.strip()
116
- score, label_str, status = get_node_info(address)
117
-
118
- # Xử lý kết quả hiển thị
119
- if status == "NOT_FOUND":
120
- return f"❌ Address {address} not found in any transaction data.", "Unknown", None
121
-
122
- plot = draw_ego_graph(address)
123
-
124
- if status == "UNSCORED_BUT_FOUND":
125
- return (
126
- f"⚠️ **Not Scored via GNN**\n\n"
127
- f"This address exists in the transaction list (`edges_all.csv`) but was filtered out during the GNN training graph construction (likely an isolated node or missing features).\n"
128
- f"Therefore, the model did not assign a risk score.",
129
- "Not Scored",
130
- plot
131
  )
132
 
133
- # Nếu có điểm (SCORED)
134
- risk_level = "LOW 🟢"
135
- if score > 0.8: risk_level = "CRITICAL 🔴"
136
- elif score > 0.5: risk_level = "HIGH 🟠"
137
- elif score > 0.2: risk_level = "MEDIUM 🟡"
138
-
139
- result_text = (
140
- f"### 🎯 Risk Score: {score:.4f}\n"
141
- f"**Label:** {label_str}\n"
142
- f"**Status:** Analyzed by GraphSAGE\n"
143
- )
144
-
145
- return result_text, risk_level, plot
146
 
147
- # --- 4. UI ---
148
 
149
- with gr.Blocks(title="ETH Fraud Inspector") as demo:
150
- gr.Markdown("# 🕵️‍♀️ Ethereum Fraud Inspector (GraphSAGE v3)")
151
- gr.Markdown("Investigate Ethereum wallets using Graph Neural Networks. Even if a wallet wasn't scored by the model, we will visualize its transaction history.")
 
 
 
 
 
 
152
 
 
153
  with gr.Row():
154
- with gr.Column(scale=1):
155
- inp_addr = gr.Textbox(label="Ethereum Address", placeholder="0x...", lines=1)
156
- btn = gr.Button("🔍 Analyze", variant="primary")
157
-
158
- gr.Markdown("### 💡 Try these addresses:")
159
- # Lấy mẫu 1 ví có điểm (Criminal) và 1 ví chỉ có trong edges
160
- examples = []
161
- if not df_scores.empty:
162
- # Lấy 1 ví criminal
163
- crim_example = df_scores[df_scores['label'] == 1].index[0] if 1 in df_scores['label'].values else df_scores.index[0]
164
- examples.append(crim_example)
 
 
 
 
 
165
 
166
- gr.Examples(examples=examples, inputs=inp_addr)
167
-
168
- with gr.Column(scale=2):
169
- with gr.Row():
170
- lbl_risk = gr.Label(label="Risk Level")
171
- out_text = gr.Markdown(label="Analysis Report")
172
- out_plot = gr.Plot(label="Transaction Graph")
173
-
174
- btn.click(fn=analyze_wallet, inputs=inp_addr, outputs=[out_text, lbl_risk, out_plot])
 
175
 
176
  if __name__ == "__main__":
177
  demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch_geometric.nn import SAGEConv, BatchNorm
6
+ import json
7
  import pandas as pd
8
+ import numpy as np
 
9
  import os
10
  from huggingface_hub import hf_hub_download
11
 
12
+ # --- 1. ĐỊNH NGHĨA MODEL ARCHITECTURE ---
13
+ # Phải khớp chính xác với kiến trúc đã dùng để train trong notebook (Cell 16, trang 25)
14
+ class SAGE(nn.Module):
15
+ def __init__(self, in_dim, h=128, out_dim=2, p_drop=0.3):
16
+ super().__init__()
17
+ self.conv1 = SAGEConv(in_dim, h, bias=True)
18
+ self.bn1 = BatchNorm(h)
19
+ self.conv2 = SAGEConv(h, h, bias=True)
20
+ self.bn2 = BatchNorm(h)
21
+ self.head = nn.Linear(h, out_dim)
22
+ self.drop = nn.Dropout(p_drop)
23
+
24
+ def forward(self, x, edge_index):
25
+ # Layer 1
26
+ x = self.conv1(x, edge_index)
27
+ x = self.bn1(x)
28
+ x = F.relu(x)
29
+ x = self.drop(x)
30
+
31
+ # Layer 2
32
+ x = self.conv2(x, edge_index)
33
+ x = self.bn2(x)
34
+ x = F.relu(x)
35
+ x = self.drop(x)
36
+
37
+ # Output
38
+ return self.head(x)
39
+
40
+ # --- 2. SETUP & LOAD MODEL ---
41
  REPO_ID = "uyen1109/eth-fraud-gnn-uyenuyen-v3"
42
  TOKEN = os.getenv("HF_TOKEN")
43
 
44
+ print(" Downloading model artifacts...")
45
 
46
+ # 2.1 Tải danh sách Features (để biết thứ tự nhập liệu)
 
 
 
 
47
  try:
48
+ cols_path = hf_hub_download(repo_id=REPO_ID, filename="hf_export/feature_columns.json", token=TOKEN)
49
+ with open(cols_path, 'r') as f:
50
+ FEATURE_COLS = json.load(f)
51
+ print(f" Loaded {len(FEATURE_COLS)} feature columns.")
 
 
52
  except Exception as e:
53
+ print(f"⚠️ Could not load feature_columns.json. Using default fallback list. Error: {e}")
54
+ # Fallback danh sách feature dựa trên notebook (Cell 8, 11, 12)
55
+ FEATURE_COLS = [
56
+ 'out_deg', 'in_deg', 'eth_out_sum', 'eth_in_sum',
57
+ 'unique_dst_cnt', 'unique_src_cnt', 'first_seen_ts', 'last_seen_ts',
58
+ 'pr', 'clust_coef', 'betw'
59
+ ]
60
+
61
+ # 2.2 Tải trọng số Model (pytorch_model.bin)
62
  try:
63
+ model_path = hf_hub_download(repo_id=REPO_ID, filename="hf_export/pytorch_model.bin", token=TOKEN)
64
+
65
+ # Khởi tạo model
66
+ # in_dim phải bằng số lượng feature
67
+ model = SAGE(in_dim=len(FEATURE_COLS), h=128, out_dim=2, p_drop=0.3)
68
+
69
+ # Load weights (map_location='cpu' để chạy trên không gian không có GPU)
70
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
71
+ model.load_state_dict(state_dict)
72
+ model.eval() # Chuyển sang chế độ Inference (tắt Dropout, v.v.)
73
+ print("✅ Model loaded successfully!")
74
  except Exception as e:
75
+ print(f" Critical Error loading model: {e}")
76
+ model = None
77
 
78
+ # --- 3. INFERENCE FUNCTION ---
79
 
80
+ def predict_custom_node(*features):
81
+ if model is None:
82
+ return "Model not loaded correctly.", "Error"
83
 
84
+ try:
85
+ # 1. Chuyển list features nhập từ UI thành Tensor
86
+ # features là một tuple các giá trị
87
+ feat_values = [float(f) for f in features]
88
+ x = torch.tensor([feat_values], dtype=torch.float) # Shape: [1, num_features]
 
 
 
89
 
90
+ # 2. Tạo cạnh giả (Dummy Edge Index)
91
+ # Vì GraphSAGE cần edge_index để chạy, nhưng với 1 node đơn lẻ (Inductive trên node mới),
92
+ # ta khôngthông tin hàng xóm.
93
+ # Ta truyền vào edge_index rỗng. SAGEConv sẽ hoạt động dựa trên feature của chính node đó (Self-loop logic).
94
+ edge_index = torch.tensor([[], []], dtype=torch.long)
 
 
 
95
 
96
+ # 3. Forward pass
97
+ with torch.no_grad():
98
+ logits = model(x, edge_index)
99
+ probs = torch.softmax(logits, dim=1)
100
+ prob_criminal = probs[0][1].item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ # 4. Xử kết quả
103
+ label = "CRIMINAL 🔴" if prob_criminal > 0.5 else "BENIGN 🟢"
104
+ score_percent = f"{prob_criminal * 100:.2f}%"
105
+
106
+ explanation = (
107
+ f"### Prediction Result\n"
108
+ f"- **Probability of Fraud:** {score_percent}\n"
109
+ f"- **Verdict:** {label}\n\n"
110
+ f"### Debug Info\n"
111
+ f"- Input Shape: {x.shape}\n"
112
+ f"- Raw Logits: {logits.numpy()}\n"
113
+ f"- Model Architecture: GraphSAGE (2 layers, 128 hidden units)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  )
115
 
116
+ return explanation, label
117
+
118
+ except Exception as e:
119
+ return f"Error during inference: {str(e)}", "Error"
 
 
 
 
 
 
 
 
 
120
 
121
+ # --- 4. GRADIO UI ---
122
 
123
+ with gr.Blocks(title="Inductive Fraud Prediction") as demo:
124
+ gr.Markdown("# 🧠 Inductive GraphSAGE Prediction")
125
+ gr.Markdown(
126
+ """
127
+ Demo này thể hiện tính **Inductive** của mô hình: Bạn có thể nhập thông số của một ví **hoàn toàn mới** (không có trong tập dữ liệu cũ) và mô hình sẽ dự đoán dựa trên những gì nó đã học được.
128
+
129
+ *Lưu ý: Vì nhập liệu thủ công, ta đang mô phỏng node này như một node cô lập (không có thông tin hàng xóm).*
130
+ """
131
+ )
132
 
133
+ inputs = []
134
  with gr.Row():
135
+ with gr.Column():
136
+ gr.Markdown("### 1. Nhập Features (Đặc trưng) của Ví")
137
+ # Tự động tạo ô nhập liệu dựa trên danh sách FEATURE_COLS
138
+ for col in FEATURE_COLS:
139
+ # Gợi ý giá trị mặc định để dễ test
140
+ default_val = 0.0
141
+ if "ts" in col: default_val = 1600000000 # Timestamp
142
+
143
+ inp = gr.Number(label=col, value=default_val)
144
+ inputs.append(inp)
145
+
146
+ with gr.Column():
147
+ gr.Markdown("### 2. Kết quả Dự đoán")
148
+ btn_predict = gr.Button("Run Inference", variant="primary")
149
+ lbl_result = gr.Label(label="Prediction")
150
+ out_log = gr.Markdown()
151
 
152
+ # Nút Clear để reset
153
+ btn_clear = gr.Button("Clear Inputs")
154
+
155
+ # Sự kiện click
156
+ btn_predict.click(fn=predict_custom_node, inputs=inputs, outputs=[out_log, lbl_result])
157
+
158
+ # Reset tất cả về 0
159
+ def clear_fn():
160
+ return [0.0] * len(inputs)
161
+ btn_clear.click(fn=clear_fn, inputs=None, outputs=inputs)
162
 
163
  if __name__ == "__main__":
164
  demo.launch()