william1324 committed on
Commit
37ea271
·
verified ·
1 Parent(s): 1087b36

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +398 -0
app.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+
6
+ from sklearn.model_selection import train_test_split
7
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
8
+ from sklearn.impute import SimpleImputer
9
+
10
+ from sklearn.neighbors import KNeighborsClassifier
11
+ from sklearn.tree import DecisionTreeClassifier
12
+ from sklearn.ensemble import RandomForestClassifier
13
+ from sklearn.linear_model import LogisticRegression
14
+ from sklearn.svm import SVC
15
+
16
+ from sklearn.metrics import (
17
+ accuracy_score,
18
+ classification_report,
19
+ confusion_matrix,
20
+ ConfusionMatrixDisplay,
21
+ roc_curve,
22
+ auc
23
+ )
24
+
25
+
26
def load_data(file_obj):
    """Read an uploaded CSV or Excel file into a DataFrame.

    Args:
        file_obj: The uploaded file, given either as a path string or as an
            object that exposes the path through a ``name`` attribute.

    Returns:
        The DataFrame produced by ``pd.read_csv`` for ``.csv`` files or by
        ``pd.read_excel`` for ``.xlsx`` / ``.xls`` files.

    Raises:
        ValueError: If ``file_obj`` is ``None`` or the file extension is not
            one of ``.csv``, ``.xlsx``, ``.xls`` (compared case-insensitively).
    """
    if file_obj is None:
        raise ValueError("請先上傳 CSV 或 Excel 檔案。")

    # Accept a bare path string as well as a wrapper object carrying ``.name``.
    # NOTE(review): which of the two the File component passes appears to
    # depend on the Gradio version / ``type`` setting; both are handled here.
    file_path = file_obj if isinstance(file_obj, str) else file_obj.name
    lower_name = file_path.lower()

    if lower_name.endswith(".csv"):
        return pd.read_csv(file_path)
    if lower_name.endswith((".xlsx", ".xls")):
        return pd.read_excel(file_path)

    raise ValueError("只支援 CSV、XLSX、XLS 檔案。")
39
+
40
+
41
def preprocess_data(df, target_column):
    """Split a DataFrame into an imputed, one-hot encoded X and a target y.

    Steps, in order:
      1. Work on a copy; drop rows that are entirely empty and rows whose
         target value is missing (they cannot be used for supervised training).
      2. Fill missing numeric feature values with the column median. A numeric
         column with no observed value at all has no median and is filled
         with 0 so the column count stays stable.
      3. Fill missing non-numeric feature values with the most frequent value
         of the column, then one-hot encode those columns with
         ``drop_first=True``.

    NOTE(review): the fill values are computed over every row passed in. If
    the caller splits into train/test afterwards, test rows contribute to the
    statistics.

    Args:
        df: Input DataFrame; it is not modified.
        target_column: Name of the column to use as the target.

    Returns:
        ``(X, y)`` where ``X`` is the processed feature DataFrame and ``y`` is
        the target Series, both restricted to the rows that were kept.

    Raises:
        KeyError: If ``target_column`` is not a column of ``df``.
    """
    df = df.copy()
    df = df.dropna(how="all")
    df = df.dropna(subset=[target_column])

    y = df[target_column]
    X = df.drop(columns=[target_column])

    numeric_cols = X.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = X.select_dtypes(exclude=[np.number]).columns.tolist()

    if numeric_cols:
        # An all-NaN column yields a NaN median; fall back to 0 for it.
        fill_values = X[numeric_cols].median().fillna(0)
        X[numeric_cols] = X[numeric_cols].fillna(fill_values).astype(float)

    if categorical_cols:
        for col in categorical_cols:
            modes = X[col].mode(dropna=True)
            # ``mode`` is empty when the column has no observed value; such a
            # column is left as-is and contributes no dummy columns below.
            if not modes.empty:
                X[col] = X[col].fillna(modes.iloc[0])
        X = pd.get_dummies(X, columns=categorical_cols, drop_first=True)

    return X, y
61
+
62
+
63
def build_model(
    model_name,
    knn_k,
    dt_criterion,
    dt_max_depth,
    rf_estimators,
    rf_max_depth,
    lr_c,
    svm_kernel,
    svm_c
):
    """Instantiate the classifier selected by ``model_name``.

    Only the hyper-parameters belonging to the selected model are read and
    converted; the others are ignored.

    Args:
        model_name: One of "KNN", "Decision Tree", "Random Forest",
            "Logistic Regression" or "SVM".
        knn_k: Number of neighbours for KNN.
        dt_criterion: Split criterion for the decision tree.
        dt_max_depth: Decision tree depth limit; 0 means no limit.
        rf_estimators: Number of trees in the random forest.
        rf_max_depth: Random forest depth limit; 0 means no limit.
        lr_c: ``C`` value for logistic regression.
        svm_kernel: Kernel name for the SVM.
        svm_c: ``C`` value for the SVM.

    Returns:
        An unfitted scikit-learn classifier.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """

    def depth_limit(raw_depth):
        # A depth of 0 stands for "unlimited", which is passed on as None.
        limit = int(raw_depth)
        return limit if limit != 0 else None

    def make_knn():
        return KNeighborsClassifier(n_neighbors=int(knn_k))

    def make_tree():
        return DecisionTreeClassifier(
            criterion=dt_criterion,
            max_depth=depth_limit(dt_max_depth),
            random_state=42,
        )

    def make_forest():
        return RandomForestClassifier(
            n_estimators=int(rf_estimators),
            max_depth=depth_limit(rf_max_depth),
            random_state=42,
        )

    def make_logistic():
        return LogisticRegression(C=float(lr_c), max_iter=1000, random_state=42)

    def make_svm():
        # probability=True so the fitted model exposes predict_proba.
        return SVC(
            kernel=svm_kernel,
            C=float(svm_c),
            probability=True,
            random_state=42,
        )

    factories = (
        ("KNN", make_knn),
        ("Decision Tree", make_tree),
        ("Random Forest", make_forest),
        ("Logistic Regression", make_logistic),
        ("SVM", make_svm),
    )

    # Factories are invoked lazily, so only the chosen model's arguments
    # are ever converted.
    for supported_name, factory in factories:
        if model_name == supported_name:
            return factory()

    raise ValueError("不支援的模型。")
109
+
110
+
111
def plot_confusion(y_true, y_pred):
    """Draw the confusion matrix of a set of predictions.

    Args:
        y_true: True class labels.
        y_pred: Predicted class labels, same length as ``y_true``.

    Returns:
        The Matplotlib figure containing the confusion-matrix plot.
    """
    # Use the labels that actually occur so the axis ticks show the real
    # class values rather than positional indices 0..n-1.
    labels = np.unique(np.concatenate([np.asarray(y_true), np.asarray(y_pred)]))
    matrix = confusion_matrix(y_true, y_pred, labels=labels)

    fig, ax = plt.subplots(figsize=(5, 4))
    disp = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=labels)
    disp.plot(ax=ax)

    # Lay out this figure explicitly instead of pyplot's "current" figure.
    fig.tight_layout()
    # Remove the figure from pyplot's global registry so repeated calls do
    # not accumulate open figures; the returned object can still be rendered.
    plt.close(fig)
    return fig
117
+
118
+
119
def plot_roc(y_true, y_prob):
    """Plot the ROC curve for a set of scores and compute its AUC.

    Args:
        y_true: True binary labels, passed unchanged to ``roc_curve``.
        y_prob: Scores for the positive class, passed unchanged to
            ``roc_curve``.

    Returns:
        ``(fig, roc_auc)``: the Matplotlib figure and the area under the
        curve as computed by ``auc`` from the ROC points.
    """
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(fpr, tpr, label=f"AUC = {roc_auc:.4f}")
    # Dashed diagonal as a visual reference line.
    ax.plot([0, 1], [0, 1], linestyle="--")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curve")
    ax.legend(loc="lower right")

    # Lay out this figure explicitly instead of pyplot's "current" figure.
    fig.tight_layout()
    # Remove the figure from pyplot's global registry so repeated calls do
    # not accumulate open figures; the returned object can still be rendered.
    plt.close(fig)
    return fig, roc_auc
132
+
133
+
134
def analyze_file(file_obj):
    """Summarise an uploaded file for display in the UI.

    Args:
        file_obj: The uploaded file, forwarded to ``load_data``.

    Returns:
        A 5-tuple: the first 10 rows, a table of column names and dtypes,
        a table of per-column missing-value counts, a one-line shape summary,
        and a ``gr.update`` that fills the target dropdown with the column
        names (first column preselected). If anything fails, the three
        tables are empty DataFrames, the summary carries the error message
        and the dropdown is cleared.
    """
    try:
        frame = load_data(file_obj)
        n_rows, n_cols = frame.shape
        columns = list(frame.columns)

        dtype_table = pd.DataFrame({
            "欄位名稱": frame.columns,
            "資料型態": list(map(str, frame.dtypes)),
        })
        missing_table = pd.DataFrame({
            "欄位名稱": frame.columns,
            "缺失值數量": frame.isnull().sum().values,
        })

        default_target = columns[0] if columns else None
        dropdown_update = gr.update(choices=columns, value=default_target)

        return (
            frame.head(10),
            dtype_table,
            missing_table,
            f"資料維度:{n_rows} 筆 × {n_cols} 欄",
            dropdown_update,
        )

    except Exception as e:
        # UI boundary: report the failure in the summary box instead of raising.
        blank = pd.DataFrame()
        return blank, blank, blank, f"錯誤:{e}", gr.update(choices=[], value=None)
169
+
170
+
171
def train_model(
    file_obj,
    target_column,
    use_count_as_target,
    test_size,
    use_scaling,
    model_name,
    knn_k,
    dt_criterion,
    dt_max_depth,
    rf_estimators,
    rf_max_depth,
    lr_c,
    svm_kernel,
    svm_c
):
    """Train the selected binary classifier on the uploaded file and evaluate it.

    Args:
        file_obj: The uploaded file, forwarded to ``load_data``.
        target_column: Name of the target column; ignored when
            ``use_count_as_target`` is true.
        use_count_as_target: If true, derive a binary ``label`` column from
            ``count`` (1 when above the median, else 0) and use it as target.
        test_size: Fraction of rows held out for testing.
        use_scaling: If true, standardise features with ``StandardScaler``
            fitted on the training split only.
        model_name: Model to train; see ``build_model``.
        knn_k, dt_criterion, dt_max_depth, rf_estimators, rf_max_depth,
        lr_c, svm_kernel, svm_c: Hyper-parameters forwarded to
            ``build_model``.

    Returns:
        ``(result_text, report_df, cm_fig, roc_fig)``: a text summary with
        accuracy and AUC, the classification report as a DataFrame, the
        confusion-matrix figure and the ROC figure. On any failure the text
        carries the error message, the report is an empty DataFrame and both
        figures are ``None``.
    """
    try:
        df = load_data(file_obj)

        if use_count_as_target:
            if "count" not in df.columns:
                raise ValueError("你勾選了用 count 轉二元分類,但資料中沒有 count 欄位。")
            median_value = df["count"].median()
            df["label"] = (df["count"] > median_value).astype(int)
            # The label is a deterministic function of ``count``; keeping
            # ``count`` as a feature would leak the answer to the model.
            df = df.drop(columns=["count"])
            target_column = "label"

        if target_column is None or target_column not in df.columns:
            raise ValueError("請先選擇正確的目標欄位。")

        # Rows without a target value cannot be used for supervised training.
        df = df.dropna(subset=[target_column])

        X, y = preprocess_data(df, target_column)

        # Always encode the target: any pair of labels (strings, booleans,
        # categories, numbers such as 1/2) becomes 0/1, which is what the
        # ROC computation and the ``[:, 1]`` probability column assume.
        encoder = LabelEncoder()
        y = encoder.fit_transform(y)

        if len(encoder.classes_) != 2:
            raise ValueError("目前此版本只支援二元分類,因為需要輸出 ROC / AUC。")

        X_train, X_test, y_train, y_test = train_test_split(
            X,
            y,
            test_size=float(test_size),
            random_state=42,
            stratify=y
        )

        if use_scaling:
            # Fit on the training split only, then apply to the test split.
            scaler = StandardScaler()
            X_train = scaler.fit_transform(X_train)
            X_test = scaler.transform(X_test)
        else:
            X_train = X_train.values
            X_test = X_test.values

        model = build_model(
            model_name=model_name,
            knn_k=knn_k,
            dt_criterion=dt_criterion,
            dt_max_depth=dt_max_depth,
            rf_estimators=rf_estimators,
            rf_max_depth=rf_max_depth,
            lr_c=lr_c,
            svm_kernel=svm_kernel,
            svm_c=svm_c
        )

        model.fit(X_train, y_train)

        y_pred = model.predict(X_test)
        # Column 1 holds the probability of the class encoded as 1.
        y_prob = model.predict_proba(X_test)[:, 1] if hasattr(model, "predict_proba") else None

        acc = accuracy_score(y_test, y_pred)
        report_df = pd.DataFrame(
            classification_report(y_test, y_pred, output_dict=True)
        ).transpose()

        cm_fig = plot_confusion(y_test, y_pred)

        if y_prob is not None:
            roc_fig, roc_auc = plot_roc(y_test, y_prob)
            auc_text = f"AUC:{roc_auc:.4f}"
        else:
            roc_fig = None
            auc_text = "AUC:無法計算"

        result_text = f"模型:{model_name}\nAccuracy:{acc:.4f}\n{auc_text}"

        return result_text, report_df, cm_fig, roc_fig

    except Exception as e:
        # UI boundary: surface the failure in the result box instead of raising.
        empty_df = pd.DataFrame()
        return f"錯誤:{e}", empty_df, None, None
264
+
265
+
266
# Gradio UI: controls in the left column, data summary and results on the right.
with gr.Blocks(title="機器學習模型訓練工具") as demo:
    gr.Markdown("# 機器學習模型訓練工具")
    gr.Markdown(
        "支援 CSV / Excel 上傳、資料檢視、前處理、模型訓練、Classification Report、Confusion Matrix、ROC Curve。"
    )

    with gr.Row():
        # ---- Left column: upload, target selection, training options ----
        with gr.Column(scale=1):
            file_input = gr.File(
                label="上傳 CSV 或 Excel 檔案",
                file_types=[".csv", ".xlsx", ".xls"],
            )

            analyze_button = gr.Button("分析資料")
            # Choices are filled in by analyze_file once a file is analysed.
            target_dropdown = gr.Dropdown(
                label="選擇目標欄位",
                choices=[],
                value=None,
            )

            use_count_checkbox = gr.Checkbox(
                label="若資料有 count 欄位,將 count 依中位數轉為二元分類", value=True
            )

            test_size_slider = gr.Slider(
                label="測試集比例", minimum=0.1, maximum=0.5, value=0.2, step=0.1
            )

            use_scaling_checkbox = gr.Checkbox(label="使用 StandardScaler", value=True)

            model_dropdown = gr.Dropdown(
                label="選擇模型",
                choices=["KNN", "Decision Tree", "Random Forest", "Logistic Regression", "SVM"],
                value="KNN",
            )

            gr.Markdown("## 模型參數")

            # Every model's hyper-parameters are always shown; train_model
            # forwards all of them and build_model reads the relevant ones.
            knn_k = gr.Slider(
                label="KNN:k 值",
                minimum=1,
                maximum=15,
                value=5,
                step=1,
            )
            dt_criterion = gr.Dropdown(
                label="Decision Tree:criterion", choices=["gini", "entropy"], value="gini"
            )
            dt_max_depth = gr.Slider(
                label="Decision Tree:max_depth(0 代表不限)", minimum=0, maximum=20, value=5, step=1
            )
            rf_estimators = gr.Slider(
                label="Random Forest:n_estimators", minimum=10, maximum=300, value=100, step=10
            )
            rf_max_depth = gr.Slider(
                label="Random Forest:max_depth(0 代表不限)", minimum=0, maximum=20, value=5, step=1
            )
            lr_c = gr.Slider(
                label="Logistic Regression:C", minimum=0.01, maximum=10.0, value=1.0, step=0.01
            )
            svm_kernel = gr.Dropdown(
                label="SVM:kernel", choices=["linear", "rbf"], value="rbf"
            )
            svm_c = gr.Slider(
                label="SVM:C", minimum=0.01, maximum=10.0, value=1.0, step=0.01
            )

            train_button = gr.Button("開始訓練", variant="primary")

        # ---- Right column: data overview followed by training results ----
        with gr.Column(scale=2):
            summary_text = gr.Textbox(label="資料摘要")
            preview_output = gr.Dataframe(label="資料預覽")
            info_output = gr.Dataframe(label="欄位型態")
            missing_output = gr.Dataframe(label="缺失值統計")

            result_text = gr.Textbox(label="模型結果")
            report_output = gr.Dataframe(label="Classification Report")
            cm_output = gr.Plot(label="Confusion Matrix")
            roc_output = gr.Plot(label="ROC Curve")

    # Output order must match the tuple returned by analyze_file.
    analyze_button.click(
        fn=analyze_file,
        inputs=[file_input],
        outputs=[
            preview_output,
            info_output,
            missing_output,
            summary_text,
            target_dropdown,
        ],
    )

    # Input order must match the positional parameters of train_model.
    train_button.click(
        fn=train_model,
        inputs=[
            file_input,
            target_dropdown,
            use_count_checkbox,
            test_size_slider,
            use_scaling_checkbox,
            model_dropdown,
            knn_k,
            dt_criterion,
            dt_max_depth,
            rf_estimators,
            rf_max_depth,
            lr_c,
            svm_kernel,
            svm_c,
        ],
        outputs=[result_text, report_output, cm_output, roc_output],
    )
397
+
398
if __name__ == "__main__":
    # Start the web server only when this file is run as a script, so that
    # importing the module (e.g. to reuse or test its functions) has no
    # side effects.
    demo.launch()