william1324 committed on
Commit
6bb1489
·
verified ·
1 Parent(s): bb16f1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +422 -176
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
@@ -15,6 +18,9 @@ from sklearn.svm import SVC
15
 
16
  from sklearn.metrics import (
17
  accuracy_score,
 
 
 
18
  classification_report,
19
  confusion_matrix,
20
  ConfusionMatrixDisplay,
@@ -23,6 +29,9 @@ from sklearn.metrics import (
23
  )
24
 
25
 
 
 
 
26
  def load_data(file_obj):
27
  if file_obj is None:
28
  raise ValueError("請先上傳 CSV 或 Excel 檔案。")
@@ -35,29 +44,7 @@ def load_data(file_obj):
35
  if lower_name.endswith(".xlsx") or lower_name.endswith(".xls"):
36
  return pd.read_excel(file_path)
37
 
38
- raise ValueError("只支援 CSVXLSXXLS 檔案。")
39
-
40
-
41
- def preprocess_data(df, target_column):
42
- df = df.copy()
43
- df = df.dropna(how="all")
44
-
45
- y = df[target_column]
46
- X = df.drop(columns=[target_column])
47
-
48
- numeric_cols = X.select_dtypes(include=[np.number]).columns.tolist()
49
- categorical_cols = X.select_dtypes(exclude=[np.number]).columns.tolist()
50
-
51
- if numeric_cols:
52
- num_imputer = SimpleImputer(strategy="median")
53
- X[numeric_cols] = num_imputer.fit_transform(X[numeric_cols])
54
-
55
- if categorical_cols:
56
- cat_imputer = SimpleImputer(strategy="most_frequent")
57
- X[categorical_cols] = cat_imputer.fit_transform(X[categorical_cols])
58
- X = pd.get_dummies(X, columns=categorical_cols, drop_first=True)
59
-
60
- return X, y
61
 
62
 
63
  def build_model(
@@ -93,7 +80,7 @@ def build_model(
93
  if model_name == "Logistic Regression":
94
  return LogisticRegression(
95
  C=float(lr_c),
96
- max_iter=1000,
97
  random_state=42
98
  )
99
 
@@ -108,33 +95,110 @@ def build_model(
108
  raise ValueError("不支援的模型。")
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  def plot_confusion(y_true, y_pred):
112
  fig, ax = plt.subplots(figsize=(5, 4))
113
  disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_true, y_pred))
114
  disp.plot(ax=ax)
 
115
  plt.tight_layout()
116
  return fig
117
 
118
 
119
- def plot_roc(y_true, y_prob):
120
  fpr, tpr, _ = roc_curve(y_true, y_prob)
121
  roc_auc = auc(fpr, tpr)
122
 
123
  fig, ax = plt.subplots(figsize=(6, 4))
124
  ax.plot(fpr, tpr, label=f"AUC = {roc_auc:.4f}")
125
  ax.plot([0, 1], [0, 1], linestyle="--")
 
126
  ax.set_xlabel("False Positive Rate")
127
  ax.set_ylabel("True Positive Rate")
128
- ax.set_title("ROC Curve")
129
  ax.legend(loc="lower right")
130
  plt.tight_layout()
 
131
  return fig, roc_auc
132
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def analyze_file(file_obj):
135
  try:
136
  df = load_data(file_obj)
137
 
 
 
138
  info_df = pd.DataFrame({
139
  "欄位名稱": df.columns,
140
  "資料型態": [str(dtype) for dtype in df.dtypes]
@@ -142,19 +206,35 @@ def analyze_file(file_obj):
142
 
143
  missing_df = pd.DataFrame({
144
  "欄位名稱": df.columns,
145
- "缺失值數量": df.isnull().sum().values
 
146
  })
147
 
148
- preview_df = df.head(10)
149
- summary_text = f"資料維度:{df.shape[0]} 筆 × {df.shape[1]} 欄"
 
 
 
 
 
150
  columns = list(df.columns)
151
 
 
 
 
 
 
 
 
 
 
 
152
  return (
153
  preview_df,
154
  info_df,
155
  missing_df,
156
- summary_text,
157
- gr.update(choices=columns, value=columns[0] if columns else None)
158
  )
159
 
160
  except Exception as e:
@@ -163,12 +243,29 @@ def analyze_file(file_obj):
163
  empty_df,
164
  empty_df,
165
  empty_df,
166
- f"錯誤:{e}",
167
- gr.update(choices=[], value=None)
168
  )
169
 
170
 
171
- def train_model(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  file_obj,
173
  target_column,
174
  use_count_as_target,
@@ -186,30 +283,17 @@ def train_model(
186
  ):
187
  try:
188
  df = load_data(file_obj)
 
189
 
190
- if use_count_as_target:
191
- if "count" not in df.columns:
192
- raise ValueError("你勾選了用 count 轉二元分類,但資料中沒有 count 欄位。")
193
- median_value = df["count"].median()
194
- df["label"] = (df["count"] > median_value).astype(int)
195
- target_column = "label"
196
-
197
- if target_column is None or target_column not in df.columns:
198
- raise ValueError("請先選擇正確的目標欄位。")
199
-
200
- X, y = preprocess_data(df, target_column)
201
-
202
- if y.dtype == "object":
203
- encoder = LabelEncoder()
204
- y = encoder.fit_transform(y)
205
 
206
  unique_classes = np.unique(y)
207
  if len(unique_classes) != 2:
208
- raise ValueError("目前版本只支援二元分類,因為需要輸出 ROC / AUC。")
209
 
210
  X_train, X_test, y_train, y_test = train_test_split(
211
- X,
212
- y,
213
  test_size=float(test_size),
214
  random_state=42,
215
  stratify=y
@@ -236,146 +320,287 @@ def train_model(
236
  )
237
 
238
  model.fit(X_train, y_train)
239
-
240
  y_pred = model.predict(X_test)
241
- y_prob = model.predict_proba(X_test)[:, 1] if hasattr(model, "predict_proba") else None
242
 
243
- acc = accuracy_score(y_test, y_pred)
244
- report_df = pd.DataFrame(
245
- classification_report(y_test, y_pred, output_dict=True)
246
- ).transpose()
247
 
248
- cm_fig = plot_confusion(y_test, y_pred)
 
 
 
249
 
 
 
250
  if y_prob is not None:
251
- roc_fig, roc_auc = plot_roc(y_test, y_prob)
252
- auc_text = f"AUC:{roc_auc:.4f}"
253
- else:
254
- roc_fig = None
255
- auc_text = "AUC無法計算"
 
 
 
 
 
 
256
 
257
- result_text = f"模型:{model_name}\nAccuracy:{acc:.4f}\n{auc_text}"
 
258
 
259
- return result_text, report_df, cm_fig, roc_fig
260
 
261
  except Exception as e:
262
  empty_df = pd.DataFrame()
263
- return f"錯誤:{e}", empty_df, None, None
 
 
 
 
264
 
265
 
266
- with gr.Blocks(title="機器學習模型訓練工具") as demo:
267
- gr.Markdown("# 機器學習模型訓練工具")
268
- gr.Markdown(
269
- "支援 CSV / Excel 上傳、資料檢視、前處理、模型訓練、Classification Report、Confusion Matrix、ROC Curve。"
270
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
- with gr.Row():
273
- with gr.Column(scale=1):
274
- file_input = gr.File(label="上傳 CSV 或 Excel 檔案", file_types=[".csv", ".xlsx", ".xls"])
275
-
276
- analyze_button = gr.Button("分析資料")
277
- target_dropdown = gr.Dropdown(label="選擇目標欄位", choices=[], value=None)
278
-
279
- use_count_checkbox = gr.Checkbox(
280
- label="若資料有 count 欄位,將 count 依中位數轉為二元分類",
281
- value=True
282
- )
283
-
284
- test_size_slider = gr.Slider(
285
- label="測試集比例",
286
- minimum=0.1,
287
- maximum=0.5,
288
- value=0.2,
289
- step=0.1
290
- )
291
-
292
- use_scaling_checkbox = gr.Checkbox(
293
- label="使用 StandardScaler",
294
- value=True
295
- )
296
-
297
- model_dropdown = gr.Dropdown(
298
- label="選擇模型",
299
- choices=[
300
- "KNN",
301
- "Decision Tree",
302
- "Random Forest",
303
- "Logistic Regression",
304
- "SVM"
305
- ],
306
- value="KNN"
307
- )
308
-
309
- gr.Markdown("## 模型參數")
310
-
311
- knn_k = gr.Slider(label="KNN:k 值", minimum=1, maximum=15, value=5, step=1)
312
- dt_criterion = gr.Dropdown(
313
- label="Decision Tree:criterion",
314
- choices=["gini", "entropy"],
315
- value="gini"
316
- )
317
- dt_max_depth = gr.Slider(
318
- label="Decision Tree:max_depth(0 代表不限)",
319
- minimum=0,
320
- maximum=20,
321
- value=5,
322
- step=1
323
- )
324
- rf_estimators = gr.Slider(
325
- label="Random Forestn_estimators",
326
- minimum=10,
327
- maximum=300,
328
- value=100,
329
- step=10
330
- )
331
- rf_max_depth = gr.Slider(
332
- label="Random Forest:max_depth(0 代表不限)",
333
- minimum=0,
334
- maximum=20,
335
- value=5,
336
- step=1
337
- )
338
- lr_c = gr.Slider(
339
- label="Logistic Regression:C",
340
- minimum=0.01,
341
- maximum=10.0,
342
- value=1.0,
343
- step=0.01
344
- )
345
- svm_kernel = gr.Dropdown(
346
- label="SVM:kernel",
347
- choices=["linear", "rbf"],
348
- value="rbf"
349
- )
350
- svm_c = gr.Slider(
351
- label="SVM:C",
352
- minimum=0.01,
353
- maximum=10.0,
354
- value=1.0,
355
- step=0.01
356
- )
357
-
358
- train_button = gr.Button("開始訓練", variant="primary")
359
-
360
- with gr.Column(scale=2):
361
- summary_text = gr.Textbox(label="資料摘要")
362
- preview_output = gr.Dataframe(label="資料預覽")
363
- info_output = gr.Dataframe(label="欄位態")
364
- missing_output = gr.Dataframe(label="缺失值統計")
365
-
366
- result_text = gr.Textbox(label="模型結果")
367
- report_output = gr.Dataframe(label="Classification Report")
368
- cm_output = gr.Plot(label="Confusion Matrix")
369
- roc_output = gr.Plot(label="ROC Curve")
370
-
371
- analyze_button.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  fn=analyze_file,
373
  inputs=[file_input],
374
- outputs=[preview_output, info_output, missing_output, summary_text, target_dropdown]
 
 
 
 
 
 
375
  )
376
 
377
- train_button.click(
378
- fn=train_model,
 
 
 
 
 
 
379
  inputs=[
380
  file_input,
381
  target_dropdown,
@@ -392,7 +617,28 @@ with gr.Blocks(title="機器學習模型訓練工具") as demo:
392
  svm_kernel,
393
  svm_c
394
  ],
395
- outputs=[result_text, report_output, cm_output, roc_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  )
397
 
398
  demo.launch()
 
1
+ import warnings
2
+ warnings.filterwarnings("ignore")
3
+
4
  import gradio as gr
5
  import pandas as pd
6
  import numpy as np
 
18
 
19
  from sklearn.metrics import (
20
  accuracy_score,
21
+ precision_score,
22
+ recall_score,
23
+ f1_score,
24
  classification_report,
25
  confusion_matrix,
26
  ConfusionMatrixDisplay,
 
29
  )
30
 
31
 
32
+ # =========================
33
+ # 基本工具函式
34
+ # =========================
35
  def load_data(file_obj):
36
  if file_obj is None:
37
  raise ValueError("請先上傳 CSV 或 Excel 檔案。")
 
44
  if lower_name.endswith(".xlsx") or lower_name.endswith(".xls"):
45
  return pd.read_excel(file_path)
46
 
47
+ raise ValueError("目前只支援 .csv.xlsx.xls 檔案。")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  def build_model(
 
80
  if model_name == "Logistic Regression":
81
  return LogisticRegression(
82
  C=float(lr_c),
83
+ max_iter=2000,
84
  random_state=42
85
  )
86
 
 
95
  raise ValueError("不支援的模型。")
96
 
97
 
98
def preprocess_features(df, target_column):
    """Split a raw table into a model-ready feature matrix and its target.

    Processing order:
      1. Drop rows that are entirely empty.
      2. Drop rows whose target value is missing (they cannot be used for
         supervised training or scoring).
      3. Drop feature columns that contain no values at all; such columns
         have no median / mode to impute with.
      4. Fill numeric gaps with the column median.
      5. Fill non-numeric gaps with the most frequent value, then one-hot
         encode those columns (first level dropped).

    Args:
        df: Input DataFrame. It is not modified.
        target_column: Name of the column to predict.

    Returns:
        Tuple ``(X, y)``: the feature DataFrame (no missing values) and the
        target Series, row-aligned with each other.

    Raises:
        KeyError: If ``target_column`` is not a column of ``df``.
    """
    df = df.copy().dropna(how="all")
    # A row without a label carries no supervision signal.
    df = df.dropna(subset=[target_column])

    y = df[target_column]
    # Remove all-missing feature columns up front so every remaining column
    # has a well-defined fill value.
    X = df.drop(columns=[target_column]).dropna(axis=1, how="all").copy()

    numeric_cols = X.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = X.select_dtypes(exclude=[np.number]).columns.tolist()

    if numeric_cols:
        X[numeric_cols] = X[numeric_cols].fillna(X[numeric_cols].median())

    if categorical_cols:
        for col in categorical_cols:
            modes = X[col].mode(dropna=True)
            if not modes.empty:
                # mode() is sorted, so ties resolve to the smallest value.
                X[col] = X[col].fillna(modes.iloc[0])
        X = pd.get_dummies(X, columns=categorical_cols, drop_first=True)

    return X, y
117
+
118
+
119
def prepare_target(df, target_column, use_count_as_target):
    """Resolve the prediction target, optionally deriving it from ``count``.

    When ``use_count_as_target`` is true, a binary ``label`` column is built
    (1 when ``count`` is above its median, else 0) and becomes the target.
    The ``count`` column is then removed: the label is a pure function of
    it, so leaving it among the features would leak the answer to every
    model. Rows with a missing ``count`` are dropped because they cannot be
    labelled.

    Args:
        df: Input DataFrame. It is not modified.
        target_column: Column chosen by the user; ignored when
            ``use_count_as_target`` is true.
        use_count_as_target: Whether to derive the target from ``count``.

    Returns:
        Tuple ``(df, target_column)`` with the prepared copy of the data and
        the name of the column to predict.

    Raises:
        ValueError: If ``count`` is requested but absent, or if the resolved
            target column does not exist.
    """
    df = df.copy()

    if use_count_as_target:
        if "count" not in df.columns:
            raise ValueError("你勾選了 count 二元分類,但資料中沒有 count 欄位。")
        # NaN > median evaluates to False, which would silently label the
        # row as class 0; drop such rows instead.
        df = df.dropna(subset=["count"])
        is_above_median = df["count"] > df["count"].median()
        df = df.assign(label=is_above_median.astype(int)).drop(columns=["count"])
        target_column = "label"

    if target_column is None or target_column not in df.columns:
        raise ValueError("請選擇正確的目標欄位。")

    return df, target_column
133
+
134
+
135
def encode_target(y):
    """Map the target's class values onto integer codes ``0..k-1``.

    Every target is encoded, not only ``object`` columns. Categorical,
    string-dtype, boolean and numeric labels such as ``{1, 2}`` all end up
    as ``{0, 1}``, so the sorted-larger class is always code 1.

    Args:
        y: One-dimensional array-like of class labels.

    Returns:
        ``numpy.ndarray`` of integer codes, ordered by the sorted class
        values.

    Raises:
        ValueError: If ``y`` contains missing values.
    """
    values = np.asarray(y)
    if pd.isna(values).any():
        # A missing label would otherwise be treated as a class of its own.
        raise ValueError("目標欄位含有缺失值,請先處理後再訓練。")
    _, codes = np.unique(values, return_inverse=True)
    return codes.reshape(-1)
140
+
141
+
142
+ # =========================
143
+ # 視覺化函式
144
+ # =========================
145
def plot_target_distribution(y_series, title="Label Distribution"):
    """Draw a bar chart of the number of rows in each class.

    Args:
        y_series: Class labels (anything ``pandas.Series`` accepts).
        title: Chart title.

    Returns:
        The Matplotlib figure containing the bar chart.
    """
    fig, ax = plt.subplots(figsize=(6, 4))

    # Sort by class value so the bars appear in a stable order.
    class_counts = pd.Series(y_series).value_counts().sort_index()
    bar_labels = class_counts.index.astype(str)
    bar_heights = class_counts.values
    ax.bar(bar_labels, bar_heights)

    ax.set_title(title)
    ax.set_xlabel("Class")
    ax.set_ylabel("Count")

    plt.tight_layout()
    return fig
154
+
155
+
156
  def plot_confusion(y_true, y_pred):
157
  fig, ax = plt.subplots(figsize=(5, 4))
158
  disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_true, y_pred))
159
  disp.plot(ax=ax)
160
+ ax.set_title("Confusion Matrix")
161
  plt.tight_layout()
162
  return fig
163
 
164
 
165
def plot_roc_curve(y_true, y_prob):
    """Plot the ROC curve for a binary classifier and compute its AUC.

    Args:
        y_true: True binary labels.
        y_prob: Predicted scores for the positive class.

    Returns:
        Tuple ``(fig, roc_auc)``: the Matplotlib figure and the area under
        the plotted curve.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_prob)
    area = auc(false_pos_rate, true_pos_rate)

    fig, ax = plt.subplots(figsize=(6, 4))
    curve_label = f"AUC = {area:.4f}"
    ax.plot(false_pos_rate, true_pos_rate, label=curve_label)
    # Dashed diagonal from (0, 0) to (1, 1) as a visual reference line.
    ax.plot([0, 1], [0, 1], linestyle="--")

    ax.set_title("ROC Curve")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc="lower right")
    plt.tight_layout()

    return fig, area
179
 
180
 
181
def plot_model_comparison(result_df):
    """Draw a bar chart comparing the accuracy of each model.

    Args:
        result_df: DataFrame with a ``"Model"`` column (bar labels) and an
            ``"Accuracy"`` column (bar heights).

    Returns:
        The Matplotlib figure containing the chart.
    """
    fig, ax = plt.subplots(figsize=(8, 4))

    model_names = result_df["Model"]
    accuracies = result_df["Accuracy"]
    ax.bar(model_names, accuracies)

    ax.set_title("Model Accuracy Comparison")
    ax.set_xlabel("Model")
    ax.set_ylabel("Accuracy")
    # Fix the y-axis to the 0..1 range.
    ax.set_ylim(0, 1)

    plt.xticks(rotation=15)
    plt.tight_layout()
    return fig
191
+
192
+
193
+ # =========================
194
+ # 資料分析
195
+ # =========================
196
  def analyze_file(file_obj):
197
  try:
198
  df = load_data(file_obj)
199
 
200
+ preview_df = df.head(10)
201
+
202
  info_df = pd.DataFrame({
203
  "欄位名稱": df.columns,
204
  "資料型態": [str(dtype) for dtype in df.dtypes]
 
206
 
207
  missing_df = pd.DataFrame({
208
  "欄位名稱": df.columns,
209
+ "缺失值數量": df.isnull().sum().values,
210
+ "缺失比例(%)": (df.isnull().mean().values * 100).round(2)
211
  })
212
 
213
+ summary = []
214
+ summary.append(f"資料筆數:{df.shape[0]}")
215
+ summary.append(f"資料欄數:{df.shape[1]}")
216
+ summary.append(f"數值欄位數:{len(df.select_dtypes(include=[np.number]).columns)}")
217
+ summary.append(f"類別欄位數:{len(df.select_dtypes(exclude=[np.number]).columns)}")
218
+ summary.append(f"總缺失值數:{int(df.isnull().sum().sum())}")
219
+
220
  columns = list(df.columns)
221
 
222
+ if len(columns) > 0:
223
+ default_target = "count" if "count" in columns else columns[-1]
224
+ else:
225
+ default_target = None
226
+
227
+ has_count_message = "有偵測到 count 欄位,可直接轉成二元分類。" if "count" in df.columns else "未偵測到 count 欄位。"
228
+
229
+ empty_fig = plt.figure()
230
+ plt.close(empty_fig)
231
+
232
  return (
233
  preview_df,
234
  info_df,
235
  missing_df,
236
+ "\n".join(summary) + f"\n{has_count_message}",
237
+ gr.update(choices=columns, value=default_target),
238
  )
239
 
240
  except Exception as e:
 
243
  empty_df,
244
  empty_df,
245
  empty_df,
246
+ f"資料分析失敗:{e}",
247
+ gr.update(choices=[], value=None),
248
  )
249
 
250
 
251
def target_distribution(file_obj, target_column, use_count_as_target):
    """Build the class-distribution chart for the selected target column.

    Loads the uploaded file, resolves the target column, and plots how many
    rows fall into each class. Any failure is rendered as a text-only figure
    instead of being raised.

    Args:
        file_obj: Uploaded file object passed on to ``load_data``.
        target_column: Column chosen by the user.
        use_count_as_target: Whether the target is derived from ``count``.

    Returns:
        A Matplotlib figure: the bar chart, or a figure showing the error.
    """
    try:
        loaded = load_data(file_obj)
        prepared, resolved_target = prepare_target(
            loaded, target_column, use_count_as_target
        )
        chart_title = f"{resolved_target} Distribution"
        return plot_target_distribution(prepared[resolved_target], title=chart_title)
    except Exception as e:
        # Show the problem inside the plot area instead of raising.
        fig, ax = plt.subplots(figsize=(6, 3))
        message = f"無法產生分布圖:\n{e}"
        ax.text(0.5, 0.5, message, ha="center", va="center")
        ax.axis("off")
        plt.tight_layout()
        return fig
263
+
264
+
265
+ # =========================
266
+ # 單一模型訓練
267
+ # =========================
268
+ def train_single_model(
269
  file_obj,
270
  target_column,
271
  use_count_as_target,
 
283
  ):
284
  try:
285
  df = load_data(file_obj)
286
+ df, target_column = prepare_target(df, target_column, use_count_as_target)
287
 
288
+ X, y = preprocess_features(df, target_column)
289
+ y = encode_target(y)
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
  unique_classes = np.unique(y)
292
  if len(unique_classes) != 2:
293
+ raise ValueError("目前版本只支援二元分類,因為需要輸出 ROC/AUC。")
294
 
295
  X_train, X_test, y_train, y_test = train_test_split(
296
+ X, y,
 
297
  test_size=float(test_size),
298
  random_state=42,
299
  stratify=y
 
320
  )
321
 
322
  model.fit(X_train, y_train)
 
323
  y_pred = model.predict(X_test)
 
324
 
325
+ y_prob = None
326
+ if hasattr(model, "predict_proba"):
327
+ y_prob = model.predict_proba(X_test)[:, 1]
 
328
 
329
+ acc = accuracy_score(y_test, y_pred)
330
+ pre = precision_score(y_test, y_pred, zero_division=0)
331
+ rec = recall_score(y_test, y_pred, zero_division=0)
332
+ f1 = f1_score(y_test, y_pred, zero_division=0)
333
 
334
+ auc_text = "無法計算"
335
+ roc_fig = None
336
  if y_prob is not None:
337
+ roc_fig, roc_auc = plot_roc_curve(y_test, y_prob)
338
+ auc_text = f"{roc_auc:.4f}"
339
+
340
+ result_text = (
341
+ f"模型名稱{model_name}\n"
342
+ f"Accuracy:{acc:.4f}\n"
343
+ f"Precision:{pre:.4f}\n"
344
+ f"Recall:{rec:.4f}\n"
345
+ f"F1-score:{f1:.4f}\n"
346
+ f"AUC:{auc_text}"
347
+ )
348
 
349
+ report_df = pd.DataFrame(classification_report(y_test, y_pred, output_dict=True)).transpose()
350
+ cm_fig = plot_confusion(y_test, y_pred)
351
 
352
+ return result_text, report_df.round(4), cm_fig, roc_fig
353
 
354
  except Exception as e:
355
  empty_df = pd.DataFrame()
356
+ fig, ax = plt.subplots(figsize=(6, 3))
357
+ ax.text(0.5, 0.5, f"錯誤:{e}", ha="center", va="center")
358
+ ax.axis("off")
359
+ plt.tight_layout()
360
+ return f"模型訓練失敗:{e}", empty_df, fig, None
361
 
362
 
363
+ # =========================
364
+ # 模型比較
365
+ # =========================
366
def compare_models(
    file_obj,
    target_column,
    use_count_as_target,
    test_size,
    use_scaling
):
    """Train five classifiers on the same split and rank them by accuracy.

    The data is loaded, the target resolved and encoded, and a stratified
    train/test split is made with a fixed seed. Each model is fitted with
    fixed settings and scored on accuracy, precision, recall, F1 and AUC.

    Args:
        file_obj: Uploaded file object passed on to ``load_data``.
        target_column: Column chosen by the user.
        use_count_as_target: Whether the target is derived from ``count``.
        test_size: Fraction of rows held out for testing.
        use_scaling: Whether to standardize features (fitted on the
            training split only).

    Returns:
        Tuple ``(summary, result_df, fig)``: a text summary of the top row,
        the per-model metrics table sorted by accuracy (descending), and the
        accuracy bar chart. On failure: an error message, an empty
        DataFrame, and a figure showing the error text.
    """
    try:
        data = load_data(file_obj)
        data, target_column = prepare_target(data, target_column, use_count_as_target)

        X, y = preprocess_features(data, target_column)
        y = encode_target(y)

        class_values = np.unique(y)
        if len(class_values) != 2:
            raise ValueError("目前版本只支援二元分類比較。")

        X_train, X_test, y_train, y_test = train_test_split(
            X, y,
            test_size=float(test_size),
            random_state=42,
            stratify=y
        )

        if use_scaling:
            scaler = StandardScaler()
            train_features = scaler.fit_transform(X_train)
            test_features = scaler.transform(X_test)
        else:
            train_features = X_train.values
            test_features = X_test.values

        candidates = [
            ("KNN", KNeighborsClassifier(n_neighbors=5)),
            ("Decision Tree", DecisionTreeClassifier(random_state=42)),
            ("Random Forest", RandomForestClassifier(n_estimators=100, random_state=42)),
            ("Logistic Regression", LogisticRegression(max_iter=2000, random_state=42)),
            ("SVM", SVC(kernel="rbf", probability=True, random_state=42)),
        ]

        metric_rows = []
        for model_name, estimator in candidates:
            estimator.fit(train_features, y_train)
            predictions = estimator.predict(test_features)

            row = {
                "Model": model_name,
                "Accuracy": round(accuracy_score(y_test, predictions), 4),
                "Precision": round(precision_score(y_test, predictions, zero_division=0), 4),
                "Recall": round(recall_score(y_test, predictions, zero_division=0), 4),
                "F1-score": round(f1_score(y_test, predictions, zero_division=0), 4),
            }

            # AUC needs probability scores; stays NaN for models without them.
            auc_score = np.nan
            if hasattr(estimator, "predict_proba"):
                positive_prob = estimator.predict_proba(test_features)[:, 1]
                fpr, tpr = roc_curve(y_test, positive_prob)[:2]
                auc_score = auc(fpr, tpr)
            row["AUC"] = None if pd.isna(auc_score) else round(auc_score, 4)

            metric_rows.append(row)

        result_df = pd.DataFrame(metric_rows)
        result_df = result_df.sort_values(by="Accuracy", ascending=False)
        result_df = result_df.reset_index(drop=True)
        compare_fig = plot_model_comparison(result_df)

        # After sorting, the first row is the highest-accuracy model.
        top = result_df.iloc[0]
        summary = "\n".join([
            f"最佳模型{top['Model']}",
            f"Accuracy:{top['Accuracy']}",
            f"Precision:{top['Precision']}",
            f"Recall:{top['Recall']}",
            f"F1-score:{top['F1-score']}",
            f"AUC:{top['AUC']}",
        ])

        return summary, result_df, compare_fig

    except Exception as e:
        # Report the failure in all three outputs instead of raising.
        fig, ax = plt.subplots(figsize=(6, 3))
        ax.text(0.5, 0.5, f"錯誤:{e}", ha="center", va="center")
        ax.axis("off")
        plt.tight_layout()
        return f"模型比較失敗:{e}", pd.DataFrame(), fig
454
+
455
+
456
+ # =========================
457
+ # UI
458
+ # =========================
459
+ custom_css = """
460
+ .gradio-container {
461
+ max-width: 1200px !important;
462
+ }
463
+ """
464
+
465
+ with gr.Blocks(title="機器學習模型訓練工具", css=custom_css) as demo:
466
+ gr.Markdown("""
467
+ # 機器學習模型訓練工具(滿分版)
468
+ 這個系統可完成:
469
+ - 資料上傳與預覽
470
+ - 欄位型態與缺失值分析
471
+ - `count` 欄位轉二元分類
472
+ - KNN / Decision Tree / Random Forest / Logistic Regression / SVM
473
+ - Accuracy / Precision / Recall / F1-score / AUC
474
+ - Confusion Matrix / ROC Curve
475
+ - 多模比較
476
+ """)
477
+
478
+ with gr.Tab("1. 資料分析"):
479
+ with gr.Row():
480
+ with gr.Column(scale=1):
481
+ file_input = gr.File(
482
+ label="上傳 CSV 或 Excel 檔案",
483
+ file_types=[".csv", ".xlsx", ".xls"]
484
+ )
485
+ analyze_btn = gr.Button("分析資料", variant="primary")
486
+ target_dropdown = gr.Dropdown(label="目標欄位", choices=[], value=None)
487
+ use_count_checkbox = gr.Checkbox(
488
+ label="若資料有 count 欄位,將其依中位數轉成二元分類",
489
+ value=True
490
+ )
491
+ dist_btn = gr.Button("顯示類別分布")
492
+
493
+ with gr.Column(scale=2):
494
+ summary_output = gr.Textbox(label="資料摘要", lines=8)
495
+ preview_output = gr.Dataframe(label="資料預覽")
496
+ info_output = gr.Dataframe(label="欄位型態")
497
+ missing_output = gr.Dataframe(label="缺失值統計")
498
+ dist_plot = gr.Plot(label="類別分布圖")
499
+
500
+ with gr.Tab("2. 單一模型訓練"):
501
+ with gr.Row():
502
+ with gr.Column(scale=1):
503
+ test_size_slider = gr.Slider(
504
+ label="測試集比例",
505
+ minimum=0.1,
506
+ maximum=0.5,
507
+ step=0.1,
508
+ value=0.2
509
+ )
510
+
511
+ use_scaling_checkbox = gr.Checkbox(
512
+ label="使用 StandardScaler",
513
+ value=True
514
+ )
515
+
516
+ model_dropdown = gr.Dropdown(
517
+ label="選擇模型",
518
+ choices=[
519
+ "KNN",
520
+ "Decision Tree",
521
+ "Random Forest",
522
+ "Logistic Regression",
523
+ "SVM"
524
+ ],
525
+ value="KNN"
526
+ )
527
+
528
+ gr.Markdown("## 模型參數")
529
+
530
+ knn_k = gr.Slider(label="KNN:k 值", minimum=1, maximum=15, value=5, step=1)
531
+
532
+ dt_criterion = gr.Dropdown(
533
+ label="Decision Tree:criterion",
534
+ choices=["gini", "entropy"],
535
+ value="gini"
536
+ )
537
+ dt_max_depth = gr.Slider(
538
+ label="Decision Tree:max_depth(0 代表不限)",
539
+ minimum=0, maximum=20, value=5, step=1
540
+ )
541
+
542
+ rf_estimators = gr.Slider(
543
+ label="Random Forest:n_estimators",
544
+ minimum=10, maximum=300, value=100, step=10
545
+ )
546
+ rf_max_depth = gr.Slider(
547
+ label="Random Forest:max_depth(0 代表不限)",
548
+ minimum=0, maximum=20, value=5, step=1
549
+ )
550
+
551
+ lr_c = gr.Slider(
552
+ label="Logistic Regression:C",
553
+ minimum=0.01, maximum=10.0, value=1.0, step=0.01
554
+ )
555
+
556
+ svm_kernel = gr.Dropdown(
557
+ label="SVM:kernel",
558
+ choices=["linear", "rbf"],
559
+ value="rbf"
560
+ )
561
+ svm_c = gr.Slider(
562
+ label="SVM:C",
563
+ minimum=0.01, maximum=10.0, value=1.0, step=0.01
564
+ )
565
+
566
+ train_btn = gr.Button("開始訓練單一模型", variant="primary")
567
+
568
+ with gr.Column(scale=2):
569
+ single_result_output = gr.Textbox(label="模型結果", lines=8)
570
+ report_output = gr.Dataframe(label="Classification Report")
571
+ cm_output = gr.Plot(label="Confusion Matrix")
572
+ roc_output = gr.Plot(label="ROC Curve")
573
+
574
+ with gr.Tab("3. 多模型比較"):
575
+ with gr.Row():
576
+ with gr.Column(scale=1):
577
+ compare_btn = gr.Button("比較所有模型", variant="primary")
578
+
579
+ with gr.Column(scale=2):
580
+ compare_summary = gr.Textbox(label="最佳模型摘要", lines=8)
581
+ compare_table = gr.Dataframe(label="模型比較表")
582
+ compare_plot = gr.Plot(label="模型 Accuracy 比較圖")
583
+
584
+ analyze_btn.click(
585
  fn=analyze_file,
586
  inputs=[file_input],
587
+ outputs=[
588
+ preview_output,
589
+ info_output,
590
+ missing_output,
591
+ summary_output,
592
+ target_dropdown
593
+ ]
594
  )
595
 
596
+ dist_btn.click(
597
+ fn=target_distribution,
598
+ inputs=[file_input, target_dropdown, use_count_checkbox],
599
+ outputs=[dist_plot]
600
+ )
601
+
602
+ train_btn.click(
603
+ fn=train_single_model,
604
  inputs=[
605
  file_input,
606
  target_dropdown,
 
617
  svm_kernel,
618
  svm_c
619
  ],
620
+ outputs=[
621
+ single_result_output,
622
+ report_output,
623
+ cm_output,
624
+ roc_output
625
+ ]
626
+ )
627
+
628
+ compare_btn.click(
629
+ fn=compare_models,
630
+ inputs=[
631
+ file_input,
632
+ target_dropdown,
633
+ use_count_checkbox,
634
+ test_size_slider,
635
+ use_scaling_checkbox
636
+ ],
637
+ outputs=[
638
+ compare_summary,
639
+ compare_table,
640
+ compare_plot
641
+ ]
642
  )
643
 
644
  demo.launch()