KKingzor commited on
Commit
70a5d73
·
verified ·
1 Parent(s): 5914e95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -55
app.py CHANGED
@@ -8,67 +8,65 @@ from sklearn.tree import DecisionTreeClassifier
8
  from sklearn.ensemble import RandomForestClassifier
9
  from sklearn.metrics import accuracy_score
10
 
11
- # Streamlit 標題
12
- st.title('分類模型選擇器')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # 輸入數據集的路徑
15
- file_path = st.text_input('請輸入數據集的路徑 (CSV 格式)')
 
 
 
 
16
 
17
- # 如果路徑不為空,則加載數據
18
- if file_path:
19
- try:
20
- # 加載數據集
21
- def load_data(file_path):
22
- data = pd.read_csv(file_path)
23
- X = data.iloc[:, :-1].values # 特徵變量
24
- y = data.iloc[:, -1].values # 目標變量
25
- return X, y
26
-
27
- X, y = load_data(file_path)
28
 
29
- # 分割數據集
30
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
31
 
32
- # 下拉選單讓使用者選擇分類方法
33
- classifier_name = st.selectbox(
34
- '選擇分類器',
35
- ('SVM', '堆疊法', '軟投票', '硬投票')
36
- )
 
37
 
38
- # 定義每個分類方法的模型
39
- def get_classifier(classifier_name):
40
- if classifier_name == 'SVM':
41
- clf = SVC(kernel='linear', probability=True)
42
- elif classifier_name == '堆疊法':
43
- estimators = [
44
- ('lr', LogisticRegression()),
45
- ('rf', RandomForestClassifier()),
46
- ('dt', DecisionTreeClassifier())
47
- ]
48
- clf = StackingClassifier(estimators=estimators, final_estimator=SVC())
49
- elif classifier_name == '軟投票':
50
- clf1 = LogisticRegression()
51
- clf2 = RandomForestClassifier()
52
- clf3 = SVC(probability=True)
53
- clf = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2), ('svc', clf3)], voting='soft')
54
- else: # 硬投票
55
- clf1 = LogisticRegression()
56
- clf2 = RandomForestClassifier()
57
- clf3 = SVC()
58
- clf = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2), ('svc', clf3)], voting='hard')
59
- return clf
60
 
61
- clf = get_classifier(classifier_name)
 
 
62
 
63
- # 訓練模型
64
- clf.fit(X_train, y_train)
65
 
66
- # 預測測試集
67
- y_pred = clf.predict(X_test)
68
 
69
- # 顯示模型準確率
70
- acc = accuracy_score(y_test, y_pred)
71
- st.write(f'分類器 = {classifier_name}')
72
- st.write(f'準確率 = {acc:.2f}')
73
- except Exception as e:
74
- st.error(f"加載數據集時出錯: {e}")
 
8
  from sklearn.ensemble import RandomForestClassifier
9
  from sklearn.metrics import accuracy_score
10
 
11
+ # 定義每個分類方法模型
12
+ def get_classifier(classifier_name):
13
+ if classifier_name == 'SVM':
14
+ clf = SVC(kernel='linear', probability=True)
15
+ elif classifier_name == '堆疊法':
16
+ estimators = [
17
+ ('lr', LogisticRegression()),
18
+ ('rf', RandomForestClassifier()),
19
+ ('dt', DecisionTreeClassifier())
20
+ ]
21
+ clf = StackingClassifier(estimators=estimators, final_estimator=SVC())
22
+ elif classifier_name == '軟投票':
23
+ clf1 = LogisticRegression()
24
+ clf2 = RandomForestClassifier()
25
+ clf3 = SVC(probability=True)
26
+ clf = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2), ('svc', clf3)], voting='soft')
27
+ else: # 硬投票
28
+ clf1 = LogisticRegression()
29
+ clf2 = RandomForestClassifier()
30
+ clf3 = SVC()
31
+ clf = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2), ('svc', clf3)], voting='hard')
32
+ return clf
33
 
34
+ # 加載數據集
35
+ def load_data(file):
36
+ data = pd.read_csv(file)
37
+ X = data.iloc[:, :-1].values # 特徵變量
38
+ y = data.iloc[:, -1].values # 目標變量
39
+ return X, y
40
 
41
+ # 設定 Streamlit 標題
42
+ st.title("分類模型選擇器")
 
 
 
 
 
 
 
 
 
43
 
44
+ # 上傳 CSV 文件
45
+ uploaded_file = st.file_uploader("上傳 CSV 文件作為數據集", type="csv")
46
 
47
+ if uploaded_file is not None:
48
+ # 加載數據集
49
+ X, y = load_data(uploaded_file)
50
+
51
+ # 分割數據集
52
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
53
 
54
+ # 選擇分類
55
+ classifier_name = st.selectbox(
56
+ '選擇分類器',
57
+ ('SVM', '堆疊法', '軟投票', '硬投票')
58
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ # 獲取並訓練模型
61
+ clf = get_classifier(classifier_name)
62
+ clf.fit(X_train, y_train)
63
 
64
+ # 預測結果
65
+ y_pred = clf.predict(X_test)
66
 
67
+ # 計算準確率
68
+ acc = accuracy_score(y_test, y_pred)
69
 
70
+ # 顯示結果
71
+ st.write(f'分類器 = {classifier_name}')
72
+ st.write(f'準確率 = {acc:.2f}')