uyen1109 commited on
Commit
4051c4e
·
verified ·
1 Parent(s): d4b1511

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -151
app.py CHANGED
@@ -36,38 +36,42 @@ class SAGE(nn.Module):
36
  return self.head(x)
37
 
38
  # ==========================================
39
- # 2. QUẢN LÝ RESOURCE (DATA & MODEL)
40
  # ==========================================
41
  REPO_ID = "uyen1109/eth-fraud-gnn-uyenuyen-v3"
42
  TOKEN = os.getenv("HF_TOKEN")
43
 
44
- # Danh sách feature mặc định (Fallback) để UI không bị trống nếu lỗi load file
45
- DEFAULT_FEATURES = [
46
- 'out_deg', 'in_deg', 'eth_out_sum', 'eth_in_sum',
47
- 'unique_dst_cnt', 'unique_src_cnt', 'first_seen_ts', 'last_seen_ts',
48
- 'pr', 'clust_coef', 'betw', 'feat_11', 'feat_12', 'feat_13', 'feat_14'
49
- ]
50
-
51
  GLOBAL_DATA = {
52
  "model": None,
53
  "df_scores": pd.DataFrame(),
54
  "df_edges": pd.DataFrame(),
55
- "feature_cols": DEFAULT_FEATURES, # Luôn có giá trị mặc định
56
  "status": "Initializing..."
57
  }
58
 
59
  def smart_load_file(filename):
60
- """Thử tải file từ repo, ưu tiên hf_export"""
61
- paths = [f"hf_export/{filename}", filename]
 
 
 
 
 
 
62
  for p in paths:
63
  try:
64
- # Thử tải với token trước, nếu lỗi thử không token (public repo)
65
  return hf_hub_download(repo_id=REPO_ID, filename=p, token=TOKEN)
66
- except:
 
67
  try:
 
68
  return hf_hub_download(repo_id=REPO_ID, filename=p, token=None)
69
- except:
 
70
  continue
 
 
71
  return None
72
 
73
  def load_resources():
@@ -75,72 +79,83 @@ def load_resources():
75
  print("⏳ Starting Resource Loading...")
76
 
77
  # 1. Load Scores
78
- try:
79
- path = smart_load_file("scores/node_scores_with_labels.csv") or smart_load_file("node_scores_with_labels.csv")
80
- if path:
81
  df = pd.read_csv(path)
82
- # Chuẩn hóa cột địa chỉ: tìm cột chứa chữ 'address' hoặc 'id'
83
- addr_col = next((c for c in df.columns if 'addr' in c.lower() or 'id' in c.lower()), df.columns[0])
 
 
 
 
 
84
  df[addr_col] = df[addr_col].astype(str).str.lower().str.strip()
85
- # Đặt index là địa chỉ để tra cứu nhanh
86
  df.set_index(addr_col, inplace=True)
87
  GLOBAL_DATA["df_scores"] = df
88
- msg = f"✅ Loaded Scores: {len(df)} rows (Index col: {addr_col})"
89
- print(msg)
90
- logs.append(msg)
91
- else:
92
- logs.append("⚠️ Scores CSV not found.")
93
- except Exception as e:
94
- logs.append(f"❌ Error loading scores: {str(e)}")
95
 
96
  # 2. Load Edges
97
- try:
98
- path = smart_load_file("graph/edges_all.csv") or smart_load_file("edges_all.csv")
99
- if path:
100
- df = pd.read_csv(path, usecols=["src", "dst", "edge_type"])
101
- df["src"] = df["src"].astype(str).str.lower().str.strip()
102
- df["dst"] = df["dst"].astype(str).str.lower().str.strip()
103
- GLOBAL_DATA["df_edges"] = df
104
  print("✅ Loaded Edges.")
105
- except:
106
- print("⚠️ Edges CSV not found (Graph viz will be disabled).")
 
 
107
 
108
- # 3. Load Model
109
- try:
110
- model_path = smart_load_file("pytorch_model.bin")
111
- if model_path:
112
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
113
- # Tự động phát hiện input dim
114
  detected_dim = state_dict['conv1.lin_l.weight'].shape[1]
115
 
116
  model = SAGE(in_dim=detected_dim, h=128, out_dim=2, p_drop=0.3)
117
  model.load_state_dict(state_dict)
118
  model.eval()
119
  GLOBAL_DATA["model"] = model
 
120
 
121
- # Cập nhật danh sách feature cột nếu có
122
  cols_path = smart_load_file("feature_columns.json")
123
  if cols_path:
124
  with open(cols_path, 'r') as f:
125
  cols = json.load(f)
126
- # Điều chỉnh cho khớp detected_dim
127
- if len(cols) >= detected_dim:
128
- GLOBAL_DATA["feature_cols"] = cols[:detected_dim]
129
- else:
130
- GLOBAL_DATA["feature_cols"] = cols + [f"F_{i}" for i in range(len(cols), detected_dim)]
 
 
131
  else:
132
- # Nếu không có file json, tạo dummy name cho đủ số lượng
133
  GLOBAL_DATA["feature_cols"] = [f"Feature_{i}" for i in range(detected_dim)]
 
134
 
135
- logs.append(f"✅ Model Loaded (Input Dim: {detected_dim})")
136
- else:
137
- logs.append("❌ pytorch_model.bin not found.")
138
- except Exception as e:
139
- logs.append(f"❌ Model Load Error: {str(e)}")
 
 
 
 
 
140
 
141
  GLOBAL_DATA["status"] = "\n".join(logs)
 
142
 
143
- # Chạy load ngay lập tức
144
  load_resources()
145
 
146
  # ==========================================
@@ -151,145 +166,97 @@ def draw_graph(address):
151
  df = GLOBAL_DATA["df_edges"]
152
  if df.empty: return None
153
 
154
- # Tìm giao dịch liên quan (cả in và out)
155
- subset = df[(df["src"] == address) | (df["dst"] == address)].head(30)
156
  if subset.empty: return None
157
 
158
  G = nx.from_pandas_edgelist(subset, "src", "dst", edge_attr="edge_type", create_using=nx.DiGraph())
 
 
159
 
160
- plt.figure(figsize=(8, 8))
161
- pos = nx.spring_layout(G, k=0.8, seed=42)
162
-
163
- # Màu sắc: Target màu đỏ, Neighbor màu xanh
164
  node_colors = ["#FF4500" if n == address else "#1E90FF" for n in G.nodes()]
165
- node_sizes = [400 if n == address else 150 for n in G.nodes()]
166
-
167
- nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes, alpha=0.9)
168
  nx.draw_networkx_edges(G, pos, alpha=0.3, arrowstyle='->')
169
- # Label ngắn gọn
170
- nx.draw_networkx_labels(G, pos, labels={n: n[:4] + ".." for n in G.nodes()}, font_size=8)
171
 
172
  plt.title(f"Ego Graph: {address[:6]}...")
173
  plt.axis('off')
174
  return plt.gcf()
175
 
176
  def lookup_handler(address):
177
- # Chuẩn hóa input cực mạnh để khớp với index CSV
178
- raw_addr = str(address).strip().lower()
179
 
 
180
  df = GLOBAL_DATA["df_scores"]
181
- found_row = None
182
 
183
- # Thử các trường hợp khớp
184
- if raw_addr in df.index:
185
- found_row = df.loc[raw_addr]
186
- elif raw_addr.startswith("0x") and raw_addr[2:] in df.index: # Thử bỏ 0x
187
- found_row = df.loc[raw_addr[2:]]
188
- elif ("0x" + raw_addr) in df.index: # Thử thêm 0x
189
- found_row = df.loc["0x" + raw_addr]
190
-
191
- if found_row is not None:
192
- # Lấy điểm số
193
- try:
194
- # Xử lý trường hợp duplicate index hoặc series
195
- if isinstance(found_row, pd.DataFrame):
196
- found_row = found_row.iloc[0]
197
-
198
- score = float(found_row.get("prob_criminal", found_row.get("susp", 0.0)))
199
- label = int(found_row.get("label", -1))
200
-
201
- risk = "CRITICAL 🔴" if score > 0.8 else ("HIGH 🟠" if score > 0.5 else "LOW 🟢")
202
- label_text = "Unknown"
203
- if label == 1: label_text = "Criminal (True Label)"
204
- elif label == 0: label_text = "Benign (True Label)"
205
-
206
- info = (
207
- f"### ✅ Address Found\n"
208
- f"- **Risk Score:** {score:.4f}\n"
209
- f"- **Risk Level:** {risk}\n"
210
- f"- **Dataset Label:** {label_text}"
211
- )
212
- return info, draw_graph(raw_addr)
213
- except Exception as e:
214
- return f"Error parsing row: {e}", None
215
-
216
- # Nếu không tìm thấy
217
  return (
218
- f"### ❌ Not Found in Database\n"
219
- f"Address `{raw_addr}` does not exist in `node_scores_with_labels.csv`.\n"
220
- f"Please verify the address or use the **Inductive Prediction** tab.",
221
  None
222
  )
223
 
224
  def predict_handler(*features):
225
- model = GLOBAL_DATA["model"]
226
- if model is None:
227
- return f"❌ Model failed to load properly.\n\nLogs:\n{GLOBAL_DATA['status']}"
228
 
229
  try:
230
  x = torch.tensor([[float(f) for f in features]], dtype=torch.float)
231
  edge_index = torch.tensor([[], []], dtype=torch.long)
232
-
233
  with torch.no_grad():
234
- logits = model(x, edge_index)
235
- prob = torch.softmax(logits, dim=1)[0][1].item()
236
-
237
- verdict = "CRIMINAL 🔴" if prob > 0.5 else "BENIGN 🟢"
238
- return (
239
- f"### 🧠 Prediction Result\n"
240
- f"- **Fraud Probability:** {prob*100:.2f}%\n"
241
- f"- **Verdict:** {verdict}"
242
- )
243
  except Exception as e:
244
- return f"Prediction Error: {str(e)}"
245
 
246
  # ==========================================
247
  # 4. UI SETUP
248
  # ==========================================
249
  with gr.Blocks(title="ETH Fraud GNN") as demo:
250
- gr.Markdown("# 🕵️‍♀️ Ethereum Fraud GNN (Hybrid V3)")
251
 
252
- # Hiển thị trạng thái load hệ thống (ẩn đi nếu muốn gọn)
253
- with gr.Accordion("System Status / Logs", open=False):
254
- gr.Markdown(GLOBAL_DATA["status"])
255
 
256
  with gr.Tabs():
257
- # TAB 1: LOOKUP
258
- with gr.TabItem("🔍 Lookup Address"):
259
  with gr.Row():
260
- inp_addr = gr.Textbox(label="Enter Address", placeholder="0x...")
261
- btn_search = gr.Button("Search", variant="primary")
262
-
263
  with gr.Row():
264
- out_info = gr.Markdown()
265
- out_plot = gr.Plot()
266
-
267
- btn_search.click(lookup_handler, inputs=inp_addr, outputs=[out_info, out_plot])
268
 
269
- # TAB 2: INDUCTIVE
270
- with gr.TabItem("🧠 Inductive Prediction"):
271
- gr.Markdown("### Predict New Address")
272
- gr.Markdown("Enter extracted features manually:")
273
 
274
- # TẠO INPUT ĐỘNG: model load được hay không, UI vẫn sẽ render dựa trên GLOBAL_DATA["feature_cols"]
275
- # Điều này fix lỗi giao diện trống trơn.
276
- feat_inputs = []
277
  cols = GLOBAL_DATA["feature_cols"]
278
-
279
- # Chia layout thành 3 cột
280
  with gr.Row():
281
- col1, col2, col3 = gr.Column(), gr.Column(), gr.Column()
282
-
283
- # Phân phối input vào 3 cột
284
  for i, c in enumerate(cols):
285
- target_col = col1 if i % 3 == 0 else (col2 if i % 3 == 1 else col3)
286
- with target_col:
287
- feat_inputs.append(gr.Number(label=c, value=0.0))
288
-
289
- btn_predict = gr.Button("Run Inference", variant="primary")
290
- out_pred = gr.Markdown()
291
 
292
- btn_predict.click(predict_handler, inputs=feat_inputs, outputs=out_pred)
 
 
293
 
294
  if __name__ == "__main__":
295
  demo.launch()
 
36
  return self.head(x)
37
 
38
  # ==========================================
39
+ # 2. QUẢN LÝ RESOURCE
40
  # ==========================================
41
  REPO_ID = "uyen1109/eth-fraud-gnn-uyenuyen-v3"
42
  TOKEN = os.getenv("HF_TOKEN")
43
 
 
 
 
 
 
 
 
44
  GLOBAL_DATA = {
45
  "model": None,
46
  "df_scores": pd.DataFrame(),
47
  "df_edges": pd.DataFrame(),
48
+ "feature_cols": [],
49
  "status": "Initializing..."
50
  }
51
 
52
  def smart_load_file(filename):
53
+ """
54
+ Ưu tiên tìm ở root (theo hình ảnh user cung cấp).
55
+ Thử có token -> không token.
56
+ """
57
+ # Đảo ngược thứ tự: Tìm ở root trước vì hình ảnh cho thấy file ở root
58
+ paths = [filename, f"hf_export/{filename}"]
59
+
60
+ errs = []
61
  for p in paths:
62
  try:
63
+ # Cách 1: Dùng Token (cho Private Repo hoặc LFS)
64
  return hf_hub_download(repo_id=REPO_ID, filename=p, token=TOKEN)
65
+ except Exception as e1:
66
+ errs.append(f"Token fail {p}: {e1}")
67
  try:
68
+ # Cách 2: Không dùng Token (cho Public Repo)
69
  return hf_hub_download(repo_id=REPO_ID, filename=p, token=None)
70
+ except Exception as e2:
71
+ errs.append(f"No-Token fail {p}: {e2}")
72
  continue
73
+
74
+ print(f"⚠️ Failed to load {filename}. Details: {errs}")
75
  return None
76
 
77
  def load_resources():
 
79
  print("⏳ Starting Resource Loading...")
80
 
81
  # 1. Load Scores
82
+ path = smart_load_file("node_scores_with_labels.csv")
83
+ if path:
84
+ try:
85
  df = pd.read_csv(path)
86
+ # Tìm cột địa chỉ linh hoạt
87
+ cols_lower = [c.lower() for c in df.columns]
88
+ if "address" in cols_lower:
89
+ addr_col = df.columns[cols_lower.index("address")]
90
+ else:
91
+ addr_col = df.columns[0]
92
+
93
  df[addr_col] = df[addr_col].astype(str).str.lower().str.strip()
 
94
  df.set_index(addr_col, inplace=True)
95
  GLOBAL_DATA["df_scores"] = df
96
+ logs.append(f"✅ Loaded Scores: {len(df)} rows.")
97
+ except Exception as e:
98
+ logs.append(f"❌ Error parsing scores csv: {e}")
99
+ else:
100
+ logs.append(" 'node_scores_with_labels.csv' download failed.")
 
 
101
 
102
  # 2. Load Edges
103
+ path = smart_load_file("edges_all.csv")
104
+ if path:
105
+ try:
106
+ GLOBAL_DATA["df_edges"] = pd.read_csv(path, usecols=["src", "dst", "edge_type"])
107
+ # Chuẩn hóa nhẹ để vẽ hình
108
+ GLOBAL_DATA["df_edges"]["src"] = GLOBAL_DATA["df_edges"]["src"].astype(str).str.lower().str.strip()
109
+ GLOBAL_DATA["df_edges"]["dst"] = GLOBAL_DATA["df_edges"]["dst"].astype(str).str.lower().str.strip()
110
  print("✅ Loaded Edges.")
111
+ except Exception as e:
112
+ print(f"⚠️ Edge parsing error: {e}")
113
+ else:
114
+ print("⚠️ 'edges_all.csv' download failed.")
115
 
116
+ # 3. Load Model & Features
117
+ model_path = smart_load_file("pytorch_model.bin")
118
+ if model_path:
119
+ try:
120
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
 
121
  detected_dim = state_dict['conv1.lin_l.weight'].shape[1]
122
 
123
  model = SAGE(in_dim=detected_dim, h=128, out_dim=2, p_drop=0.3)
124
  model.load_state_dict(state_dict)
125
  model.eval()
126
  GLOBAL_DATA["model"] = model
127
+ logs.append(f"✅ Model Loaded (Input Dim: {detected_dim})")
128
 
129
+ # Load Feature Columns
130
  cols_path = smart_load_file("feature_columns.json")
131
  if cols_path:
132
  with open(cols_path, 'r') as f:
133
  cols = json.load(f)
134
+ # Khớp số lượng feature
135
+ if len(cols) == detected_dim:
136
+ GLOBAL_DATA["feature_cols"] = cols
137
+ elif len(cols) > detected_dim:
138
+ GLOBAL_DATA["feature_cols"] = cols[:detected_dim]
139
+ else:
140
+ GLOBAL_DATA["feature_cols"] = cols + [f"Feat_{i}" for i in range(len(cols), detected_dim)]
141
  else:
 
142
  GLOBAL_DATA["feature_cols"] = [f"Feature_{i}" for i in range(detected_dim)]
143
+ logs.append("⚠️ Using Dummy Feature Names (json missing)")
144
 
145
+ except Exception as e:
146
+ logs.append(f"❌ Model Init Error: {e}")
147
+ else:
148
+ logs.append("❌ 'pytorch_model.bin' NOT FOUND. Please upload it to Repo Root.")
149
+ # Fallback feature list để UI không bị lỗi (dựa trên log của bạn)
150
+ GLOBAL_DATA["feature_cols"] = [
151
+ 'out_deg', 'in_deg', 'eth_out_sum', 'eth_in_sum',
152
+ 'unique_dst_cnt', 'unique_src_cnt', 'first_seen_ts', 'last_seen_ts',
153
+ 'pr', 'clust_coef', 'betw', 'feat_11', 'feat_12', 'feat_13', 'feat_14'
154
+ ]
155
 
156
  GLOBAL_DATA["status"] = "\n".join(logs)
157
+ print(GLOBAL_DATA["status"])
158
 
 
159
  load_resources()
160
 
161
  # ==========================================
 
166
  df = GLOBAL_DATA["df_edges"]
167
  if df.empty: return None
168
 
169
+ subset = df[(df["src"] == address) | (df["dst"] == address)].head(20)
 
170
  if subset.empty: return None
171
 
172
  G = nx.from_pandas_edgelist(subset, "src", "dst", edge_attr="edge_type", create_using=nx.DiGraph())
173
+ plt.figure(figsize=(6, 6))
174
+ pos = nx.spring_layout(G, k=0.9, seed=42)
175
 
 
 
 
 
176
  node_colors = ["#FF4500" if n == address else "#1E90FF" for n in G.nodes()]
177
+ nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=200, alpha=0.9)
 
 
178
  nx.draw_networkx_edges(G, pos, alpha=0.3, arrowstyle='->')
179
+ nx.draw_networkx_labels(G, pos, labels={n: n[:4] for n in G.nodes()}, font_size=8)
 
180
 
181
  plt.title(f"Ego Graph: {address[:6]}...")
182
  plt.axis('off')
183
  return plt.gcf()
184
 
185
  def lookup_handler(address):
186
+ if not address: return "Please enter an address.", None
 
187
 
188
+ raw_addr = str(address).strip().lower()
189
  df = GLOBAL_DATA["df_scores"]
 
190
 
191
+ # Logic tìm kiếm mạnh mẽ hơn
192
+ found = None
193
+ if not df.empty:
194
+ if raw_addr in df.index:
195
+ found = df.loc[raw_addr]
196
+ elif raw_addr.replace("0x", "") in df.index:
197
+ found = df.loc[raw_addr.replace("0x", "")]
198
+
199
+ if found is not None:
200
+ if isinstance(found, pd.DataFrame): found = found.iloc[0]
201
+ score = float(found.get("prob_criminal", found.get("susp", 0.0)))
202
+ return (
203
+ f"### Found\n**Score:** {score:.4f}\n**Status:** {'CRITICAL 🔴' if score > 0.5 else 'BENIGN 🟢'}",
204
+ draw_graph(raw_addr)
205
+ )
206
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  return (
208
+ f"### ❌ Not Found\nAddress `{raw_addr}` not in database.\nStatus Logs:\n{GLOBAL_DATA['status']}",
 
 
209
  None
210
  )
211
 
212
  def predict_handler(*features):
213
+ if GLOBAL_DATA["model"] is None:
214
+ return f"❌ Model Error: pytorch_model.bin missing.\nPlease check 'System Status' below."
 
215
 
216
  try:
217
  x = torch.tensor([[float(f) for f in features]], dtype=torch.float)
218
  edge_index = torch.tensor([[], []], dtype=torch.long)
 
219
  with torch.no_grad():
220
+ prob = torch.softmax(GLOBAL_DATA["model"](x, edge_index), dim=1)[0][1].item()
221
+ return f"### Result\n**Fraud Probability:** {prob*100:.2f}%"
 
 
 
 
 
 
 
222
  except Exception as e:
223
+ return f"Error: {e}"
224
 
225
  # ==========================================
226
  # 4. UI SETUP
227
  # ==========================================
228
  with gr.Blocks(title="ETH Fraud GNN") as demo:
229
+ gr.Markdown("# 🕵️‍♀️ Ethereum Fraud Inspector")
230
 
231
+ with gr.Accordion("System Status (Click to Debug)", open=False):
232
+ gr.Markdown(lambda: GLOBAL_DATA["status"]) # Dynamic update
 
233
 
234
  with gr.Tabs():
235
+ with gr.TabItem("🔍 Lookup"):
 
236
  with gr.Row():
237
+ inp = gr.Textbox(label="Address")
238
+ btn = gr.Button("Search", variant="primary")
 
239
  with gr.Row():
240
+ out_txt = gr.Markdown()
241
+ out_plt = gr.Plot()
242
+ btn.click(lookup_handler, inputs=inp, outputs=[out_txt, out_plt])
 
243
 
244
+ with gr.TabItem("🧠 Predict"):
245
+ gr.Markdown("### Inductive Prediction (Simulated)")
 
 
246
 
247
+ # Render input dựa trên feature cols đã load
 
 
248
  cols = GLOBAL_DATA["feature_cols"]
249
+ inputs = []
 
250
  with gr.Row():
251
+ # Chia cột tự động
252
+ c1, c2 = gr.Column(), gr.Column()
 
253
  for i, c in enumerate(cols):
254
+ with (c1 if i % 2 == 0 else c2):
255
+ inputs.append(gr.Number(label=c, value=0.0))
 
 
 
 
256
 
257
+ btn2 = gr.Button("Predict", variant="primary")
258
+ out2 = gr.Markdown()
259
+ btn2.click(predict_handler, inputs=inputs, outputs=out2)
260
 
261
  if __name__ == "__main__":
262
  demo.launch()