uyen1109 commited on
Commit
d4b1511
·
verified ·
1 Parent(s): 672f23f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -139
app.py CHANGED
@@ -12,7 +12,7 @@ import json
12
  from huggingface_hub import hf_hub_download
13
 
14
  # ==========================================
15
- # 1. ĐỊNH NGHĨA KIẾN TRÚC MODEL (GraphSAGE)
16
  # ==========================================
17
  class SAGE(nn.Module):
18
  def __init__(self, in_dim, h=128, out_dim=2, p_drop=0.3):
@@ -25,251 +25,271 @@ class SAGE(nn.Module):
25
  self.drop = nn.Dropout(p_drop)
26
 
27
  def forward(self, x, edge_index):
28
- # Layer 1
29
  x = self.conv1(x, edge_index)
30
  x = self.bn1(x)
31
  x = F.relu(x)
32
  x = self.drop(x)
33
- # Layer 2
34
  x = self.conv2(x, edge_index)
35
  x = self.bn2(x)
36
  x = F.relu(x)
37
  x = self.drop(x)
38
- # Output
39
  return self.head(x)
40
 
41
  # ==========================================
42
- # 2. QUẢN LÝ DỮ LIỆU & TẢI MODEL
43
  # ==========================================
44
  REPO_ID = "uyen1109/eth-fraud-gnn-uyenuyen-v3"
45
  TOKEN = os.getenv("HF_TOKEN")
46
 
47
- # Biến toàn cục lưu trữ dữ liệu
 
 
 
 
 
 
48
  GLOBAL_DATA = {
49
  "model": None,
50
  "df_scores": pd.DataFrame(),
51
  "df_edges": pd.DataFrame(),
52
- "feature_cols": [],
53
- "input_dim": 0
54
  }
55
 
56
  def smart_load_file(filename):
57
- """Tìm file root hoặc thư mục hf_export"""
58
- possible_paths = [f"hf_export/{filename}", filename]
59
- for p in possible_paths:
60
  try:
 
61
  return hf_hub_download(repo_id=REPO_ID, filename=p, token=TOKEN)
62
  except:
63
- continue
 
 
 
64
  return None
65
 
66
  def load_resources():
67
- print("⏳ Loading resources...")
 
68
 
69
- # 1. Load Scores (Lookup Data)
70
  try:
71
  path = smart_load_file("scores/node_scores_with_labels.csv") or smart_load_file("node_scores_with_labels.csv")
72
  if path:
73
  df = pd.read_csv(path)
74
- if "address" in df.columns:
75
- df["address"] = df["address"].astype(str).str.lower().str.strip()
76
- df.set_index("address", inplace=True)
 
 
77
  GLOBAL_DATA["df_scores"] = df
78
- print(f"✅ Loaded {len(df)} scores.")
 
 
 
 
79
  except Exception as e:
80
- print(f"⚠️ Score load error: {e}")
81
 
82
- # 2. Load Edges (Graph Visualization)
83
  try:
84
  path = smart_load_file("graph/edges_all.csv") or smart_load_file("edges_all.csv")
85
  if path:
86
- # Chỉ load cột cần thiết để tiết kiệm RAM
87
- GLOBAL_DATA["df_edges"] = pd.read_csv(path, usecols=["src", "dst", "edge_type"])
88
- GLOBAL_DATA["df_edges"]["src"] = GLOBAL_DATA["df_edges"]["src"].astype(str).str.lower().str.strip()
89
- GLOBAL_DATA["df_edges"]["dst"] = GLOBAL_DATA["df_edges"]["dst"].astype(str).str.lower().str.strip()
90
- print(f"✅ Loaded edges.")
91
- except Exception as e:
92
- print(f"⚠️ Edge load error: {e}")
93
 
94
- # 3. Load Model Weights & Detect Dimension
95
  try:
96
  model_path = smart_load_file("pytorch_model.bin")
97
  if model_path:
98
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
99
- # Tự động phát hiện input dimension từ trọng số
100
  detected_dim = state_dict['conv1.lin_l.weight'].shape[1]
101
- GLOBAL_DATA["input_dim"] = detected_dim
102
 
103
- # Khởi tạo model
104
  model = SAGE(in_dim=detected_dim, h=128, out_dim=2, p_drop=0.3)
105
  model.load_state_dict(state_dict)
106
  model.eval()
107
  GLOBAL_DATA["model"] = model
108
- print(f"✅ Model loaded (Input Dim: {detected_dim})")
109
 
110
- # Load tên cột features (nếu có)
111
  cols_path = smart_load_file("feature_columns.json")
112
  if cols_path:
113
  with open(cols_path, 'r') as f:
114
  cols = json.load(f)
115
- # Cắt hoặc thêm cho khớp với detected_dim
116
- if len(cols) > detected_dim:
117
  GLOBAL_DATA["feature_cols"] = cols[:detected_dim]
118
- elif len(cols) < detected_dim:
119
- GLOBAL_DATA["feature_cols"] = cols + [f"Feat_{i}" for i in range(detected_dim - len(cols))]
120
  else:
121
- GLOBAL_DATA["feature_cols"] = cols
122
  else:
 
123
  GLOBAL_DATA["feature_cols"] = [f"Feature_{i}" for i in range(detected_dim)]
 
 
 
 
124
  except Exception as e:
125
- print(f"❌ Critical Model Error: {e}")
 
 
126
 
127
- # Gọi hàm load khi khởi động
128
  load_resources()
129
 
130
  # ==========================================
131
- # 3. LOGIC XỬ LÝ (HANDLERS)
132
  # ==========================================
133
 
134
- # --- Xử lý Tab 1: Tra cứu ---
135
- def draw_ego_graph(address):
136
- """Vẽ đồ thị mạng lưới giao dịch cục bộ"""
137
  df = GLOBAL_DATA["df_edges"]
138
  if df.empty: return None
139
 
140
- # Lấy 20 giao dịch gần nhất
141
- subset = df[(df["src"] == address) | (df["dst"] == address)].head(20)
142
  if subset.empty: return None
143
 
144
  G = nx.from_pandas_edgelist(subset, "src", "dst", edge_attr="edge_type", create_using=nx.DiGraph())
145
- pos = nx.spring_layout(G, seed=42, k=0.9)
146
 
147
  plt.figure(figsize=(8, 8))
148
- colors = ["#FF4500" if n == address else "#1E90FF" for n in G.nodes()]
149
- sizes = [400 if n == address else 100 for n in G.nodes()]
150
 
151
- nx.draw_networkx_nodes(G, pos, node_color=colors, node_size=sizes, alpha=0.9)
152
- nx.draw_networkx_edges(G, pos, alpha=0.4, arrowstyle='->')
153
- nx.draw_networkx_labels(G, pos, labels={n: n[:4] for n in G.nodes()}, font_size=8)
154
 
155
- plt.title(f"Transaction Graph: {address[:6]}...")
 
 
 
 
 
156
  plt.axis('off')
157
  return plt.gcf()
158
 
159
- def lookup_address(address):
160
- address = address.lower().strip()
 
 
161
  df = GLOBAL_DATA["df_scores"]
 
162
 
163
- # 1. Tìm trong Database
164
- if not df.empty and address in df.index:
165
- row = df.loc[address]
166
- score = float(row.get("prob_criminal", row.get("susp", 0.0)))
167
- label_val = row.get("label", -1)
168
-
169
- status = "CRIMINAL 🔴" if score > 0.5 else "BENIGN 🟢"
170
- if label_val == 1: status += " (Verified Criminal)"
171
- elif label_val == 0: status += " (Verified Benign)"
172
 
173
- plot = draw_ego_graph(address)
174
- return (
175
- f"### ✅ Found in Database\n**Score:** {score:.4f}\n**Status:** {status}",
176
- plot,
177
- gr.update(visible=False) # Ẩn thông báo chuyển tab
178
- )
179
-
180
- # 2. Nếu không tìm thấy
181
- msg = (
182
- f"### ❌ Address Not Found\n"
183
- f"Địa chỉ `{address}` chưa được model chấm điểm trước đó.\n"
184
- f"Vui lòng chuyển sang tab **'Inductive Prediction'** để nhập thông số và dự đoán."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  )
186
- return msg, None, gr.update(visible=True)
187
 
188
- # --- Xử lý Tab 2: Dự đoán ---
189
- def predict_manual(*features):
190
  model = GLOBAL_DATA["model"]
191
  if model is None:
192
- return "❌ Model not loaded."
193
 
194
  try:
195
- # Chuyển input thành tensor
196
- feat_vals = [float(f) for f in features]
197
- x = torch.tensor([feat_vals], dtype=torch.float)
198
- # Edge index giả (node cô lập)
199
  edge_index = torch.tensor([[], []], dtype=torch.long)
200
 
201
  with torch.no_grad():
202
  logits = model(x, edge_index)
203
- probs = torch.softmax(logits, dim=1)
204
- prob_crim = probs[0][1].item()
205
 
206
- label = "CRIMINAL 🔴" if prob_crim > 0.5 else "BENIGN 🟢"
207
  return (
208
- f"### 🧠 Inductive Prediction Result\n"
209
- f"- **Fraud Probability:** {prob_crim*100:.2f}%\n"
210
- f"- **Model Verdict:** {label}\n"
211
- f"*(Dự đoán dựa trên {len(feat_vals)} features đầu vào)*"
212
  )
213
  except Exception as e:
214
- return f"⚠️ Error: {str(e)}"
215
 
216
  # ==========================================
217
- # 4. GIAO DIỆN GRADIO (UI)
218
  # ==========================================
219
- with gr.Blocks(title="ETH Fraud GNN Hybrid") as demo:
220
- gr.Markdown("# 🕵️‍♀️ Ethereum Fraud Detection System (Hybrid)")
221
- gr.Markdown("Kết hợp tra cứu dữ liệu lịch sử và khả năng dự đoán (Inductive) trên dữ liệu mới.")
222
 
 
 
 
 
223
  with gr.Tabs():
224
- # === TAB 1: TRA CỨU ===
225
  with gr.TabItem("🔍 Lookup Address"):
226
  with gr.Row():
227
- with gr.Column():
228
- inp_addr = gr.Textbox(label="Ethereum Address", placeholder="0x...")
229
- btn_lookup = gr.Button("Search", variant="primary")
230
- with gr.Column():
231
- out_lookup_text = gr.Markdown()
232
- out_plot = gr.Plot(label="Graph")
233
- # Thông báo hướng dẫn chuyển tab (ẩn mặc định)
234
- notice_box = gr.Markdown("👉 **Tip:** Use the 'Inductive Prediction' tab to predict unknown addresses.", visible=False)
235
-
236
- btn_lookup.click(
237
- lookup_address,
238
- inputs=inp_addr,
239
- outputs=[out_lookup_text, out_plot, notice_box]
240
- )
241
 
242
- # === TAB 2: DỰ ĐOÁN THỦ CÔNG ===
243
  with gr.TabItem("🧠 Inductive Prediction"):
244
- gr.Markdown("### Predict New/Unknown Address")
245
- gr.Markdown("Nhập các chỉ số feature của ví để model dự đoán rủi ro.")
246
 
247
- input_comps = []
248
- if GLOBAL_DATA["model"]:
249
- # Tạo lưới nhập liệu động
250
- cols = GLOBAL_DATA["feature_cols"]
251
- with gr.Row():
252
- # Chia làm 3 cột cho gọn
253
- with gr.Column():
254
- for c in cols[:len(cols)//3]:
255
- input_comps.append(gr.Number(label=c, value=0.0))
256
- with gr.Column():
257
- for c in cols[len(cols)//3 : 2*len(cols)//3]:
258
- input_comps.append(gr.Number(label=c, value=0.0))
259
- with gr.Column():
260
- for c in cols[2*len(cols)//3:]:
261
- input_comps.append(gr.Number(label=c, value=0.0))
262
-
263
- btn_predict = gr.Button("Run GraphSAGE Inference", variant="primary")
264
- out_pred_text = gr.Markdown()
265
 
266
- btn_predict.click(
267
- predict_manual,
268
- inputs=input_comps,
269
- outputs=out_pred_text
270
- )
271
- else:
272
- gr.Error("Model failed to load. Cannot render inputs.")
 
 
 
273
 
274
  if __name__ == "__main__":
275
  demo.launch()
 
12
  from huggingface_hub import hf_hub_download
13
 
14
  # ==========================================
15
+ # 1. ĐỊNH NGHĨA MODEL
16
  # ==========================================
17
  class SAGE(nn.Module):
18
  def __init__(self, in_dim, h=128, out_dim=2, p_drop=0.3):
 
25
  self.drop = nn.Dropout(p_drop)
26
 
27
  def forward(self, x, edge_index):
 
28
  x = self.conv1(x, edge_index)
29
  x = self.bn1(x)
30
  x = F.relu(x)
31
  x = self.drop(x)
 
32
  x = self.conv2(x, edge_index)
33
  x = self.bn2(x)
34
  x = F.relu(x)
35
  x = self.drop(x)
 
36
  return self.head(x)
37
 
38
  # ==========================================
39
+ # 2. QUẢN LÝ RESOURCE (DATA & MODEL)
40
  # ==========================================
41
  REPO_ID = "uyen1109/eth-fraud-gnn-uyenuyen-v3"
42
  TOKEN = os.getenv("HF_TOKEN")
43
 
44
+ # Danh sách feature mặc định (Fallback) để UI không bị trống nếu lỗi load file
45
+ DEFAULT_FEATURES = [
46
+ 'out_deg', 'in_deg', 'eth_out_sum', 'eth_in_sum',
47
+ 'unique_dst_cnt', 'unique_src_cnt', 'first_seen_ts', 'last_seen_ts',
48
+ 'pr', 'clust_coef', 'betw', 'feat_11', 'feat_12', 'feat_13', 'feat_14'
49
+ ]
50
+
51
  GLOBAL_DATA = {
52
  "model": None,
53
  "df_scores": pd.DataFrame(),
54
  "df_edges": pd.DataFrame(),
55
+ "feature_cols": DEFAULT_FEATURES, # Luôn có giá trị mặc định
56
+ "status": "Initializing..."
57
  }
58
 
59
  def smart_load_file(filename):
60
+ """Thử tải file từ repo, ưu tiên hf_export"""
61
+ paths = [f"hf_export/{filename}", filename]
62
+ for p in paths:
63
  try:
64
+ # Thử tải với token trước, nếu lỗi thử không token (public repo)
65
  return hf_hub_download(repo_id=REPO_ID, filename=p, token=TOKEN)
66
  except:
67
+ try:
68
+ return hf_hub_download(repo_id=REPO_ID, filename=p, token=None)
69
+ except:
70
+ continue
71
  return None
72
 
73
  def load_resources():
74
+ logs = []
75
+ print("⏳ Starting Resource Loading...")
76
 
77
+ # 1. Load Scores
78
  try:
79
  path = smart_load_file("scores/node_scores_with_labels.csv") or smart_load_file("node_scores_with_labels.csv")
80
  if path:
81
  df = pd.read_csv(path)
82
+ # Chuẩn hóa cột địa chỉ: tìm cột chứa chữ 'address' hoặc 'id'
83
+ addr_col = next((c for c in df.columns if 'addr' in c.lower() or 'id' in c.lower()), df.columns[0])
84
+ df[addr_col] = df[addr_col].astype(str).str.lower().str.strip()
85
+ # Đặt index là địa chỉ để tra cứu nhanh
86
+ df.set_index(addr_col, inplace=True)
87
  GLOBAL_DATA["df_scores"] = df
88
+ msg = f"✅ Loaded Scores: {len(df)} rows (Index col: {addr_col})"
89
+ print(msg)
90
+ logs.append(msg)
91
+ else:
92
+ logs.append("⚠️ Scores CSV not found.")
93
  except Exception as e:
94
+ logs.append(f" Error loading scores: {str(e)}")
95
 
96
+ # 2. Load Edges
97
  try:
98
  path = smart_load_file("graph/edges_all.csv") or smart_load_file("edges_all.csv")
99
  if path:
100
+ df = pd.read_csv(path, usecols=["src", "dst", "edge_type"])
101
+ df["src"] = df["src"].astype(str).str.lower().str.strip()
102
+ df["dst"] = df["dst"].astype(str).str.lower().str.strip()
103
+ GLOBAL_DATA["df_edges"] = df
104
+ print("✅ Loaded Edges.")
105
+ except:
106
+ print("⚠️ Edges CSV not found (Graph viz will be disabled).")
107
 
108
+ # 3. Load Model
109
  try:
110
  model_path = smart_load_file("pytorch_model.bin")
111
  if model_path:
112
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
113
+ # Tự động phát hiện input dim
114
  detected_dim = state_dict['conv1.lin_l.weight'].shape[1]
 
115
 
 
116
  model = SAGE(in_dim=detected_dim, h=128, out_dim=2, p_drop=0.3)
117
  model.load_state_dict(state_dict)
118
  model.eval()
119
  GLOBAL_DATA["model"] = model
 
120
 
121
+ # Cập nhật danh sách feature cột nếu có
122
  cols_path = smart_load_file("feature_columns.json")
123
  if cols_path:
124
  with open(cols_path, 'r') as f:
125
  cols = json.load(f)
126
+ # Điều chỉnh cho khớp detected_dim
127
+ if len(cols) >= detected_dim:
128
  GLOBAL_DATA["feature_cols"] = cols[:detected_dim]
 
 
129
  else:
130
+ GLOBAL_DATA["feature_cols"] = cols + [f"F_{i}" for i in range(len(cols), detected_dim)]
131
  else:
132
+ # Nếu không có file json, tạo dummy name cho đủ số lượng
133
  GLOBAL_DATA["feature_cols"] = [f"Feature_{i}" for i in range(detected_dim)]
134
+
135
+ logs.append(f"✅ Model Loaded (Input Dim: {detected_dim})")
136
+ else:
137
+ logs.append("❌ pytorch_model.bin not found.")
138
  except Exception as e:
139
+ logs.append(f"❌ Model Load Error: {str(e)}")
140
+
141
+ GLOBAL_DATA["status"] = "\n".join(logs)
142
 
143
+ # Chạy load ngay lập tức
144
  load_resources()
145
 
146
  # ==========================================
147
+ # 3. LOGIC XỬ LÝ
148
  # ==========================================
149
 
150
+ def draw_graph(address):
 
 
151
  df = GLOBAL_DATA["df_edges"]
152
  if df.empty: return None
153
 
154
+ # Tìm giao dịch liên quan (cả in và out)
155
+ subset = df[(df["src"] == address) | (df["dst"] == address)].head(30)
156
  if subset.empty: return None
157
 
158
  G = nx.from_pandas_edgelist(subset, "src", "dst", edge_attr="edge_type", create_using=nx.DiGraph())
 
159
 
160
  plt.figure(figsize=(8, 8))
161
+ pos = nx.spring_layout(G, k=0.8, seed=42)
 
162
 
163
+ # Màu sắc: Target màu đỏ, Neighbor màu xanh
164
+ node_colors = ["#FF4500" if n == address else "#1E90FF" for n in G.nodes()]
165
+ node_sizes = [400 if n == address else 150 for n in G.nodes()]
166
 
167
+ nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes, alpha=0.9)
168
+ nx.draw_networkx_edges(G, pos, alpha=0.3, arrowstyle='->')
169
+ # Label ngắn gọn
170
+ nx.draw_networkx_labels(G, pos, labels={n: n[:4] + ".." for n in G.nodes()}, font_size=8)
171
+
172
+ plt.title(f"Ego Graph: {address[:6]}...")
173
  plt.axis('off')
174
  return plt.gcf()
175
 
176
+ def lookup_handler(address):
177
+ # Chuẩn hóa input cực mạnh để khớp với index CSV
178
+ raw_addr = str(address).strip().lower()
179
+
180
  df = GLOBAL_DATA["df_scores"]
181
+ found_row = None
182
 
183
+ # Thử các trường hợp khớp
184
+ if raw_addr in df.index:
185
+ found_row = df.loc[raw_addr]
186
+ elif raw_addr.startswith("0x") and raw_addr[2:] in df.index: # Thử bỏ 0x
187
+ found_row = df.loc[raw_addr[2:]]
188
+ elif ("0x" + raw_addr) in df.index: # Thử thêm 0x
189
+ found_row = df.loc["0x" + raw_addr]
 
 
190
 
191
+ if found_row is not None:
192
+ # Lấy điểm số
193
+ try:
194
+ # Xử lý trường hợp duplicate index hoặc series
195
+ if isinstance(found_row, pd.DataFrame):
196
+ found_row = found_row.iloc[0]
197
+
198
+ score = float(found_row.get("prob_criminal", found_row.get("susp", 0.0)))
199
+ label = int(found_row.get("label", -1))
200
+
201
+ risk = "CRITICAL 🔴" if score > 0.8 else ("HIGH 🟠" if score > 0.5 else "LOW 🟢")
202
+ label_text = "Unknown"
203
+ if label == 1: label_text = "Criminal (True Label)"
204
+ elif label == 0: label_text = "Benign (True Label)"
205
+
206
+ info = (
207
+ f"### ✅ Address Found\n"
208
+ f"- **Risk Score:** {score:.4f}\n"
209
+ f"- **Risk Level:** {risk}\n"
210
+ f"- **Dataset Label:** {label_text}"
211
+ )
212
+ return info, draw_graph(raw_addr)
213
+ except Exception as e:
214
+ return f"Error parsing row: {e}", None
215
+
216
+ # Nếu không tìm thấy
217
+ return (
218
+ f"### ❌ Not Found in Database\n"
219
+ f"Address `{raw_addr}` does not exist in `node_scores_with_labels.csv`.\n"
220
+ f"Please verify the address or use the **Inductive Prediction** tab.",
221
+ None
222
  )
 
223
 
224
+ def predict_handler(*features):
 
225
  model = GLOBAL_DATA["model"]
226
  if model is None:
227
+ return f"❌ Model failed to load properly.\n\nLogs:\n{GLOBAL_DATA['status']}"
228
 
229
  try:
230
+ x = torch.tensor([[float(f) for f in features]], dtype=torch.float)
 
 
 
231
  edge_index = torch.tensor([[], []], dtype=torch.long)
232
 
233
  with torch.no_grad():
234
  logits = model(x, edge_index)
235
+ prob = torch.softmax(logits, dim=1)[0][1].item()
 
236
 
237
+ verdict = "CRIMINAL 🔴" if prob > 0.5 else "BENIGN 🟢"
238
  return (
239
+ f"### 🧠 Prediction Result\n"
240
+ f"- **Fraud Probability:** {prob*100:.2f}%\n"
241
+ f"- **Verdict:** {verdict}"
 
242
  )
243
  except Exception as e:
244
+ return f"Prediction Error: {str(e)}"
245
 
246
  # ==========================================
247
+ # 4. UI SETUP
248
  # ==========================================
249
+ with gr.Blocks(title="ETH Fraud GNN") as demo:
250
+ gr.Markdown("# 🕵️‍♀️ Ethereum Fraud GNN (Hybrid V3)")
 
251
 
252
+ # Hiển thị trạng thái load hệ thống (ẩn đi nếu muốn gọn)
253
+ with gr.Accordion("System Status / Logs", open=False):
254
+ gr.Markdown(GLOBAL_DATA["status"])
255
+
256
  with gr.Tabs():
257
+ # TAB 1: LOOKUP
258
  with gr.TabItem("🔍 Lookup Address"):
259
  with gr.Row():
260
+ inp_addr = gr.Textbox(label="Enter Address", placeholder="0x...")
261
+ btn_search = gr.Button("Search", variant="primary")
262
+
263
+ with gr.Row():
264
+ out_info = gr.Markdown()
265
+ out_plot = gr.Plot()
266
+
267
+ btn_search.click(lookup_handler, inputs=inp_addr, outputs=[out_info, out_plot])
 
 
 
 
 
 
268
 
269
+ # TAB 2: INDUCTIVE
270
  with gr.TabItem("🧠 Inductive Prediction"):
271
+ gr.Markdown("### Predict New Address")
272
+ gr.Markdown("Enter extracted features manually:")
273
 
274
+ # TẠO INPUT ĐỘNG: Dù model có load được hay không, UI vẫn sẽ render dựa trên GLOBAL_DATA["feature_cols"]
275
+ # Điều này fix lỗi giao diện trống trơn.
276
+ feat_inputs = []
277
+ cols = GLOBAL_DATA["feature_cols"]
278
+
279
+ # Chia layout thành 3 cột
280
+ with gr.Row():
281
+ col1, col2, col3 = gr.Column(), gr.Column(), gr.Column()
 
 
 
 
 
 
 
 
 
 
282
 
283
+ # Phân phối input vào 3 cột
284
+ for i, c in enumerate(cols):
285
+ target_col = col1 if i % 3 == 0 else (col2 if i % 3 == 1 else col3)
286
+ with target_col:
287
+ feat_inputs.append(gr.Number(label=c, value=0.0))
288
+
289
+ btn_predict = gr.Button("Run Inference", variant="primary")
290
+ out_pred = gr.Markdown()
291
+
292
+ btn_predict.click(predict_handler, inputs=feat_inputs, outputs=out_pred)
293
 
294
  if __name__ == "__main__":
295
  demo.launch()