william1324 committed on
Commit
2975e51
·
verified ·
1 Parent(s): 0731034
Files changed (1) hide show
  1. app.py +273 -0
app.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+
6
+ from sklearn.model_selection import train_test_split
7
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
8
+ from sklearn.impute import SimpleImputer
9
+
10
+ from sklearn.neighbors import KNeighborsClassifier
11
+ from sklearn.tree import DecisionTreeClassifier
12
+ from sklearn.ensemble import RandomForestClassifier
13
+ from sklearn.linear_model import LogisticRegression
14
+ from sklearn.svm import SVC
15
+
16
+ from sklearn.metrics import (
17
+ accuracy_score,
18
+ classification_report,
19
+ confusion_matrix,
20
+ ConfusionMatrixDisplay,
21
+ roc_curve,
22
+ auc
23
+ )
24
+
25
+
26
# Page-level configuration followed by the static heading and intro text.
st.set_page_config(
    layout="wide",
    page_title="機器學習模型訓練工具",
)

st.title("機器學習模型訓練工具開發")
st.write("支援資料上傳、前處理、模型訓練、模型評估與視覺化。")
30
+
31
+
32
def load_data(uploaded_file):
    """Read an uploaded CSV or Excel file into a DataFrame.

    Args:
        uploaded_file: File-like object exposing a ``name`` attribute; the
            (case-insensitive) file extension selects the pandas reader.

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` when the extension is
        not ``.csv``, ``.xlsx`` or ``.xls``.
    """
    file_name = uploaded_file.name.lower()

    if file_name.endswith(".csv"):
        return pd.read_csv(uploaded_file)

    # ``str.endswith`` accepts a tuple, so both Excel suffixes need one call.
    if file_name.endswith((".xlsx", ".xls")):
        return pd.read_excel(uploaded_file)

    return None
41
+
42
+
43
def preprocess_data(df, target_column):
    """Split a DataFrame into an imputed, one-hot-encoded feature matrix and target.

    Steps, in order:
      1. Work on a copy; drop rows that are entirely empty.
      2. Drop rows whose target value is missing (they cannot be used for
         supervised training and break label encoding / stratified splits).
      3. Drop feature columns that contain no values at all (there is no
         statistic to impute them with).
      4. Fill numeric columns with the column median and non-numeric columns
         with the most frequent value (smallest value on ties).
      5. One-hot encode non-numeric columns with ``drop_first=True``.

    Args:
        df: Input DataFrame containing the features and the target column.
        target_column: Name of the column to use as the target.

    Returns:
        Tuple ``(X, y)``: ``X`` is the processed feature DataFrame and ``y``
        the target Series, both sharing the same (filtered) row index.

    Raises:
        KeyError: If ``target_column`` is not a column of ``df``.
    """
    df = df.copy()
    df = df.dropna(how="all")
    df = df.dropna(subset=[target_column])

    y = df[target_column]
    X = df.drop(columns=[target_column])

    # A column with no observed values has no median/mode; remove it up front.
    X = X.dropna(axis="columns", how="all")

    numeric_cols = X.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = X.select_dtypes(exclude=[np.number]).columns.tolist()

    if numeric_cols:
        # fillna with a Series fills each column named in the Series index.
        X = X.fillna(X[numeric_cols].median())

    if categorical_cols:
        # Row 0 of DataFrame.mode() holds the first mode of every column.
        most_frequent = X[categorical_cols].mode(dropna=True).iloc[0]
        X = X.fillna(most_frequent)
        X = pd.get_dummies(X, columns=categorical_cols, drop_first=True)

    return X, y
65
+
66
+
67
def build_model(model_name, params):
    """Instantiate the classifier selected by ``model_name``.

    Args:
        model_name: One of "KNN", "Decision Tree", "Random Forest",
            "Logistic Regression" or "SVM".
        params: Dict of hyperparameters; only the keys needed by the
            selected model are read.

    Returns:
        An unfitted estimator, or ``None`` when ``model_name`` is not one of
        the names above.
    """
    # Factories are lazy so that ``params`` is only indexed for the chosen model.
    factories = {
        "KNN": lambda: KNeighborsClassifier(n_neighbors=params["n_neighbors"]),
        "Decision Tree": lambda: DecisionTreeClassifier(
            criterion=params["criterion"],
            max_depth=params["max_depth"],
            random_state=42,
        ),
        "Random Forest": lambda: RandomForestClassifier(
            n_estimators=params["n_estimators"],
            max_depth=params["max_depth"],
            random_state=42,
        ),
        "Logistic Regression": lambda: LogisticRegression(
            C=params["C"],
            max_iter=1000,
            random_state=42,
        ),
        "SVM": lambda: SVC(
            kernel=params["kernel"],
            C=params["C"],
            probability=True,
            random_state=42,
        ),
    }

    make = factories.get(model_name)
    if make is None:
        return None
    return make()
101
+
102
+
103
def plot_confusion_matrix(y_true, y_pred):
    """Render the confusion matrix of ``y_pred`` against ``y_true`` in the app.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
    """
    matrix = confusion_matrix(y_true, y_pred)

    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        ConfusionMatrixDisplay(confusion_matrix=matrix).plot(ax=ax)
        st.pyplot(fig)
    finally:
        # Figures made through pyplot stay registered until closed; without
        # this every rerun of the script would keep another figure alive.
        plt.close(fig)
108
+
109
+
110
def plot_roc_curve(y_true, y_prob):
    """Render the ROC curve in the app and return its area under the curve.

    Args:
        y_true: Ground-truth binary labels.
        y_prob: Scores passed to ``roc_curve`` as the positive-class score.

    Returns:
        The AUC computed from the ROC curve points.
    """
    # Compute the curve first so a metric error cannot leave a figure open.
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.plot(fpr, tpr, label=f"AUC = {roc_auc:.4f}")
        ax.plot([0, 1], [0, 1], linestyle="--")  # chance-level diagonal
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title("ROC Curve")
        ax.legend(loc="lower right")
        st.pyplot(fig)
    finally:
        # Release the figure from pyplot's registry to avoid leaking one
        # figure per script rerun.
        plt.close(fig)

    return roc_auc
124
+
125
+
126
# ---------------------------------------------------------------------------
# Main page flow: upload -> preview -> target selection -> model settings ->
# train/evaluate. Everything below runs top to bottom on each script run.
# ---------------------------------------------------------------------------
st.sidebar.header("操作區")
uploaded_file = st.sidebar.file_uploader("請上傳 CSV 或 Excel 檔", type=["csv", "xlsx", "xls"])

if uploaded_file is not None:
    df = load_data(uploaded_file)

    if df is None:
        st.error("檔案格式不支援。")
        st.stop()

    st.subheader("原始資料預覽")
    st.dataframe(df.head())

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("資料基本資訊")
        st.write(f"資料維度:{df.shape[0]} 筆 × {df.shape[1]} 欄")
        st.write("欄位型態:")
        st.dataframe(pd.DataFrame(df.dtypes, columns=["dtype"]))

    with col2:
        st.subheader("缺失值統計")
        st.dataframe(pd.DataFrame(df.isnull().sum(), columns=["missing_count"]))

    st.subheader("欄位選擇")
    all_columns = df.columns.tolist()

    # Columns the target was derived from; they must be excluded from the
    # features, otherwise the model can read the answer directly.
    leakage_columns = []

    if "count" in all_columns:
        st.info("偵測到 count 欄位,可依作業需求轉為二元分類標籤。")
        use_count_as_target = st.checkbox(
            "將 count 轉為二元分類標籤(大於中位數=1,否則=0)",
            value=True
        )

        if use_count_as_target:
            median_value = df["count"].median()
            df["label"] = (df["count"] > median_value).astype(int)
            target_column = "label"
            # `label` is a pure function of `count`, so `count` cannot stay
            # in the feature matrix.
            # NOTE(review): other columns may also determine `count` in some
            # datasets; that cannot be detected from here.
            leakage_columns = ["count"]
            st.write(f"`count` 中位數 = {median_value}")
            st.write("已建立新目標欄位:`label`")
            st.write("為避免資料洩漏,訓練時不會將 `count` 作為特徵。")
        else:
            target_column = st.selectbox("請選擇目標欄位", all_columns)
    else:
        target_column = st.selectbox("請選擇目標欄位", all_columns)

    st.subheader("目標欄位分布")
    st.write(df[target_column].value_counts())

    test_size = st.sidebar.slider("測試集比例 (Test Size)", 0.1, 0.5, 0.2, 0.1)
    use_scaling = st.sidebar.checkbox("使用 StandardScaler", value=True)

    model_name = st.sidebar.selectbox(
        "選擇模型",
        ["KNN", "Decision Tree", "Random Forest", "Logistic Regression", "SVM"]
    )

    # Hyperparameters collected from the sidebar for the chosen model only.
    params = {}

    if model_name == "KNN":
        params["n_neighbors"] = st.sidebar.slider("k 值", 1, 15, 5)

    elif model_name == "Decision Tree":
        params["criterion"] = st.sidebar.selectbox("criterion", ["gini", "entropy"])
        max_depth_input = st.sidebar.number_input("max_depth(0 代表不限)", min_value=0, value=5, step=1)
        # 0 in the UI means "no depth limit", which the estimators spell as None.
        params["max_depth"] = None if max_depth_input == 0 else int(max_depth_input)

    elif model_name == "Random Forest":
        params["n_estimators"] = st.sidebar.slider("n_estimators", 10, 300, 100, 10)
        max_depth_input = st.sidebar.number_input("max_depth(0 代表不限)", min_value=0, value=5, step=1)
        params["max_depth"] = None if max_depth_input == 0 else int(max_depth_input)

    elif model_name == "Logistic Regression":
        params["C"] = st.sidebar.slider("C", 0.01, 10.0, 1.0, 0.01)

    elif model_name == "SVM":
        params["kernel"] = st.sidebar.selectbox("kernel", ["linear", "rbf"])
        params["C"] = st.sidebar.slider("C", 0.01, 10.0, 1.0, 0.01)

    run_button = st.sidebar.button("開始訓練模型")

    if run_button:
        try:
            X, y = preprocess_data(df.drop(columns=leakage_columns), target_column)

            # Always map the target to 0..n_classes-1. Encoding only
            # object-dtype targets left binary targets such as {1, 2},
            # booleans or non-object string dtypes unencoded, which
            # roc_curve cannot interpret without an explicit pos_label.
            y = LabelEncoder().fit_transform(y)

            unique_classes = np.unique(y)
            if len(unique_classes) != 2:
                st.error("目前程式設計為二元分類評估(ROC/AUC)。請選擇二元分類目標欄位。")
                st.stop()

            X_train, X_test, y_train, y_test = train_test_split(
                X, y,
                test_size=test_size,
                random_state=42,
                stratify=y
            )

            if use_scaling:
                # Fit the scaler on the training split only, then apply it
                # to the test split.
                scaler = StandardScaler()
                X_train = scaler.fit_transform(X_train)
                X_test = scaler.transform(X_test)
            else:
                X_train = X_train.values
                X_test = X_test.values

            model = build_model(model_name, params)
            model.fit(X_train, y_train)

            y_pred = model.predict(X_test)
            # Column 1 is the probability of the encoded class 1.
            y_prob = model.predict_proba(X_test)[:, 1] if hasattr(model, "predict_proba") else None

            st.success("模型訓練完成")

            col3, col4 = st.columns(2)

            with col3:
                st.subheader("Accuracy")
                acc = accuracy_score(y_test, y_pred)
                st.write(f"{acc:.4f}")

            with col4:
                if y_prob is not None:
                    fpr, tpr, _ = roc_curve(y_test, y_prob)
                    roc_auc = auc(fpr, tpr)
                    st.subheader("AUC")
                    st.write(f"{roc_auc:.4f}")

            st.subheader("Classification Report")
            report = classification_report(y_test, y_pred, output_dict=True)
            report_df = pd.DataFrame(report).transpose()
            st.dataframe(report_df)

            st.subheader("Confusion Matrix")
            plot_confusion_matrix(y_test, y_pred)

            if y_prob is not None:
                st.subheader("ROC Curve")
                plot_roc_curve(y_test, y_prob)

        except Exception as e:
            # Top-level UI boundary: surface any training/evaluation failure
            # to the user instead of crashing the page.
            st.error(f"執行時發生錯誤:{e}")

else:
    st.info("請先在左側上傳資料檔案。")