Spaces:
Sleeping
Sleeping
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.preprocessing import StandardScaler, LabelEncoder | |
| from sklearn.impute import SimpleImputer | |
| from sklearn.neighbors import KNeighborsClassifier | |
| from sklearn.tree import DecisionTreeClassifier | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.svm import SVC | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| precision_score, | |
| recall_score, | |
| f1_score, | |
| classification_report, | |
| confusion_matrix, | |
| ConfusionMatrixDisplay, | |
| roc_curve, | |
| auc | |
| ) | |
| # ========================= | |
| # 基本工具函式 | |
| # ========================= | |
def load_data(file_obj):
    """Read an uploaded CSV or Excel file into a DataFrame.

    Args:
        file_obj: The value produced by ``gr.File``. Either an object exposing
            the temp-file path as ``.name`` (older Gradio file wrappers) or the
            path itself as a ``str`` / ``os.PathLike`` (newer Gradio default).

    Returns:
        pandas.DataFrame with the file contents.

    Raises:
        ValueError: if nothing was uploaded or the extension is not
            ``.csv``, ``.xlsx`` or ``.xls`` (case-insensitive).
    """
    import os

    if file_obj is None:
        raise ValueError("請先上傳 CSV 或 Excel 檔案。")
    # Resolve the path first: a pathlib.Path also has a `.name` attribute, but
    # that is only the basename, so plain paths must be handled before `.name`.
    if isinstance(file_obj, (str, os.PathLike)):
        file_path = os.fspath(file_obj)
    else:
        file_path = file_obj.name
    if not file_path:
        raise ValueError("請先上傳 CSV 或 Excel 檔案。")
    lower_name = str(file_path).lower()
    if lower_name.endswith(".csv"):
        return pd.read_csv(file_path)
    if lower_name.endswith((".xlsx", ".xls")):
        return pd.read_excel(file_path)
    raise ValueError("目前只支援 .csv、.xlsx、.xls 檔案。")
def build_model(
    model_name,
    knn_k,
    dt_criterion,
    dt_max_depth,
    rf_estimators,
    rf_max_depth,
    lr_c,
    svm_kernel,
    svm_c
):
    """Create an untrained sklearn classifier from the UI selections.

    Only the hyper-parameters belonging to ``model_name`` are read; the others
    are ignored. Tree depth sliders use ``0`` to mean "no depth limit".
    Stochastic models are seeded with ``random_state=42``.

    Raises:
        ValueError: if ``model_name`` is not one of "KNN", "Decision Tree",
            "Random Forest", "Logistic Regression", "SVM".
    """
    def depth_limit(raw):
        # Slider value 0 is the sentinel for an unbounded tree.
        depth = int(raw)
        return None if depth == 0 else depth

    # Each factory is lazy so only the selected model's parameters are converted.
    factories = {
        "KNN": lambda: KNeighborsClassifier(n_neighbors=int(knn_k)),
        "Decision Tree": lambda: DecisionTreeClassifier(
            criterion=dt_criterion,
            max_depth=depth_limit(dt_max_depth),
            random_state=42,
        ),
        "Random Forest": lambda: RandomForestClassifier(
            n_estimators=int(rf_estimators),
            max_depth=depth_limit(rf_max_depth),
            random_state=42,
        ),
        "Logistic Regression": lambda: LogisticRegression(
            C=float(lr_c),
            max_iter=2000,
            random_state=42,
        ),
        # probability=True is needed so the caller can use predict_proba for ROC/AUC.
        "SVM": lambda: SVC(
            kernel=svm_kernel,
            C=float(svm_c),
            probability=True,
            random_state=42,
        ),
    }

    factory = factories.get(model_name)
    if factory is None:
        raise ValueError("不支援的模型。")
    return factory()
def preprocess_features(df, target_column):
    """Split ``df`` into an imputed, one-hot encoded feature matrix and a target.

    Steps:
      * drop rows that are entirely empty or whose target value is missing
        (they cannot be used for supervised training);
      * drop feature columns that are entirely empty (nothing to impute from);
      * impute numeric features with the median and non-numeric features with
        the most frequent value;
      * one-hot encode non-numeric features with ``drop_first=True``.

    Args:
        df: Input DataFrame (not modified).
        target_column: Name of the label column.

    Returns:
        ``(X, y)`` where ``X`` is the processed feature DataFrame and ``y`` is
        the target Series aligned with ``X``.
    """
    # Missing targets would otherwise break label encoding and stratified splits.
    df = df.dropna(how="all").dropna(subset=[target_column])
    y = df[target_column]
    X = df.drop(columns=[target_column])
    # SimpleImputer discards all-NaN columns, which would make the column
    # assignment below fail with a length mismatch, so remove them up front.
    X = X.dropna(axis=1, how="all")

    numeric_cols = X.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = X.select_dtypes(exclude=[np.number]).columns.tolist()

    if numeric_cols:
        num_imputer = SimpleImputer(strategy="median")
        X[numeric_cols] = num_imputer.fit_transform(X[numeric_cols])
    if categorical_cols:
        cat_imputer = SimpleImputer(strategy="most_frequent")
        X[categorical_cols] = cat_imputer.fit_transform(X[categorical_cols])
        # Only encode when there is something to encode; get_dummies on a
        # frame with zero columns raises instead of returning it unchanged.
        X = pd.get_dummies(X, columns=categorical_cols, drop_first=True)
    return X, y
def prepare_target(df, target_column, use_count_as_target):
    """Choose (and optionally derive) the classification target.

    When ``use_count_as_target`` is true, a binary ``label`` column is built as
    ``count > median(count)`` and used instead of ``target_column``. Rows with
    a missing ``count`` are dropped, and the ``count`` column itself is
    removed, because the label is a deterministic function of it.

    Args:
        df: Input DataFrame (not modified).
        target_column: Column selected in the UI. It is ignored when
            ``use_count_as_target`` is true.
        use_count_as_target: Whether to derive the binary label from ``count``.

    Returns:
        ``(df_copy, target_column_name)``.

    Raises:
        ValueError: if ``count`` is requested but absent, or the target column
            is missing from the data.
    """
    df = df.copy()
    if use_count_as_target:
        if "count" not in df.columns:
            raise ValueError("你勾選了 count 二元分類,但資料中沒有 count 欄位。")
        # NaN > median evaluates to False, which would silently mislabel
        # unknown counts as class 0; such rows cannot be labelled at all.
        df = df.dropna(subset=["count"])
        median_value = df["count"].median()
        df["label"] = (df["count"] > median_value).astype(int)
        # Keeping `count` as a feature would leak the answer into the model.
        df = df.drop(columns=["count"])
        target_column = "label"
    if target_column is None or target_column not in df.columns:
        raise ValueError("請選擇正確的目標欄位。")
    return df, target_column
def encode_target(y):
    """Map class labels to consecutive integer codes ``0..k-1``.

    Codes follow sorted label order (the same mapping ``LabelEncoder`` uses),
    so for a binary target the larger label becomes the positive class ``1``.
    This covers string, categorical, boolean and numeric labels alike. For
    example, a target with values ``{1, 2}`` becomes ``{0, 1}``, which the
    downstream metrics (default ``pos_label=1``) and ``roc_curve`` require.

    Args:
        y: Label values (Series or array-like).

    Returns:
        numpy integer array of codes, aligned with ``y``.

    Raises:
        ValueError: if any label is missing.
    """
    labels = pd.Series(y)
    if labels.isna().any():
        raise ValueError("目標欄位含有缺失值,請先處理。")
    codes, _ = pd.factorize(labels, sort=True)
    return codes
| # ========================= | |
| # 視覺化函式 | |
| # ========================= | |
def plot_target_distribution(y_series, title="Label Distribution"):
    """Draw a bar chart of how many samples fall into each class.

    Args:
        y_series: Class labels (anything ``pd.Series`` accepts).
        title: Chart title.

    Returns:
        matplotlib Figure with one bar per distinct label, ordered by label.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    counts = pd.Series(y_series).value_counts().sort_index()
    ax.bar(counts.index.astype(str), counts.values)
    ax.set_title(title)
    ax.set_xlabel("Class")
    ax.set_ylabel("Count")
    fig.tight_layout()
    # Drop pyplot's reference so repeated requests in the long-running app do
    # not accumulate open figures; the returned Figure can still be rendered.
    plt.close(fig)
    return fig
def plot_confusion(y_true, y_pred):
    """Render the confusion matrix of ``y_pred`` against ``y_true``.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.

    Returns:
        matplotlib Figure containing a ``ConfusionMatrixDisplay`` plot.
    """
    fig, ax = plt.subplots(figsize=(5, 4))
    disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_true, y_pred))
    disp.plot(ax=ax)
    ax.set_title("Confusion Matrix")
    fig.tight_layout()
    # Drop pyplot's reference so figures do not pile up across requests; the
    # returned Figure can still be rendered.
    plt.close(fig)
    return fig
def plot_roc_curve(y_true, y_prob):
    """Plot the ROC curve for binary scores and compute its AUC.

    Args:
        y_true: True binary labels.
        y_prob: Score or probability of the positive class for each sample.

    Returns:
        ``(fig, roc_auc)``: the matplotlib Figure and the area under the curve.
    """
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(fpr, tpr, label=f"AUC = {roc_auc:.4f}")
    # Diagonal reference line (a classifier that guesses randomly).
    ax.plot([0, 1], [0, 1], linestyle="--")
    ax.set_title("ROC Curve")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc="lower right")
    fig.tight_layout()
    # Drop pyplot's reference so figures do not pile up across requests; the
    # returned Figure can still be rendered.
    plt.close(fig)
    return fig, roc_auc
def plot_model_comparison(result_df):
    """Bar chart of test accuracy per model.

    Args:
        result_df: DataFrame with at least ``Model`` and ``Accuracy`` columns.

    Returns:
        matplotlib Figure with the y-axis fixed to [0, 1].
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(result_df["Model"], result_df["Accuracy"])
    ax.set_title("Model Accuracy Comparison")
    ax.set_xlabel("Model")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(0, 1)
    # Rotate on this axes explicitly rather than via pyplot's "current axes".
    ax.tick_params(axis="x", labelrotation=15)
    fig.tight_layout()
    # Drop pyplot's reference so figures do not pile up across requests; the
    # returned Figure can still be rendered.
    plt.close(fig)
    return fig
| # ========================= | |
| # 資料分析 | |
| # ========================= | |
def analyze_file(file_obj):
    """Summarise an uploaded dataset for the data-analysis tab.

    Args:
        file_obj: Upload value from ``gr.File``, passed to ``load_data``.

    Returns:
        A 5-tuple ``(preview_df, info_df, missing_df, summary_text,
        dropdown_update)``:
          * the first 10 rows;
          * column names with their dtypes;
          * per-column missing counts and percentages;
          * a text summary;
          * a ``gr.update`` that refills the target dropdown, preselecting
            ``count`` when present and otherwise the last column.
        On any failure the exception is not raised. Instead the function
        returns three empty DataFrames, an error message, and an update that
        clears the dropdown.
    """
    try:
        df = load_data(file_obj)
        preview_df = df.head(10)
        info_df = pd.DataFrame({
            "欄位名稱": df.columns,
            "資料型態": [str(dtype) for dtype in df.dtypes]
        })
        is_missing = df.isnull()
        missing_df = pd.DataFrame({
            "欄位名稱": df.columns,
            "缺失值數量": is_missing.sum().values,
            "缺失比例(%)": (is_missing.mean().values * 100).round(2)
        })
        summary = [
            f"資料筆數:{df.shape[0]}",
            f"資料欄數:{df.shape[1]}",
            f"數值欄位數:{len(df.select_dtypes(include=[np.number]).columns)}",
            f"類別欄位數:{len(df.select_dtypes(exclude=[np.number]).columns)}",
            f"總缺失值數:{int(is_missing.sum().sum())}",
        ]
        columns = list(df.columns)
        if columns:
            default_target = "count" if "count" in columns else columns[-1]
        else:
            default_target = None
        has_count_message = "有偵測到 count 欄位,可直接轉成二元分類。" if "count" in df.columns else "未偵測到 count 欄位。"
        return (
            preview_df,
            info_df,
            missing_df,
            "\n".join(summary) + f"\n{has_count_message}",
            gr.update(choices=columns, value=default_target),
        )
    except Exception as e:
        # UI boundary: report the failure in the summary box instead of raising.
        empty_df = pd.DataFrame()
        return (
            empty_df,
            empty_df,
            empty_df,
            f"資料分析失敗:{e}",
            gr.update(choices=[], value=None),
        )
def target_distribution(file_obj, target_column, use_count_as_target):
    """Plot the class distribution of the selected (or derived) target.

    Args:
        file_obj: Upload value from ``gr.File``.
        target_column: Column selected in the dropdown.
        use_count_as_target: Whether to derive a binary label from ``count``
            (see ``prepare_target``).

    Returns:
        matplotlib Figure. On any failure, a figure showing the error message
        is returned instead of raising.
    """
    try:
        df = load_data(file_obj)
        df, target_column = prepare_target(df, target_column, use_count_as_target)
        return plot_target_distribution(df[target_column], title=f"{target_column} Distribution")
    except Exception as e:
        fig, ax = plt.subplots(figsize=(6, 3))
        ax.text(0.5, 0.5, f"無法產生分布圖:\n{e}", ha="center", va="center")
        ax.axis("off")
        fig.tight_layout()
        # Drop pyplot's reference so error figures do not accumulate.
        plt.close(fig)
        return fig
| # ========================= | |
| # 單一模型訓練 | |
| # ========================= | |
def train_single_model(
    file_obj,
    target_column,
    use_count_as_target,
    test_size,
    use_scaling,
    model_name,
    knn_k,
    dt_criterion,
    dt_max_depth,
    rf_estimators,
    rf_max_depth,
    lr_c,
    svm_kernel,
    svm_c
):
    """Train and evaluate one user-configured binary classifier.

    The pipeline is: load the file, prepare and encode the target,
    impute/one-hot the features, make a stratified train/test split
    (``random_state=42``), optionally standardise, then fit the model built by
    ``build_model``.

    Returns:
        ``(result_text, report_df, cm_fig, roc_fig)``:
          * a metrics summary;
          * the classification report rounded to 4 decimals;
          * the confusion-matrix figure;
          * the ROC figure, or ``None`` if the model lacks ``predict_proba``.
        On any failure the function returns an error message, an empty
        DataFrame, a figure showing the error, and ``None``.
    """
    try:
        df = load_data(file_obj)
        df, target_column = prepare_target(df, target_column, use_count_as_target)
        X, y = preprocess_features(df, target_column)
        y = encode_target(y)

        if len(np.unique(y)) != 2:
            raise ValueError("目前版本只支援二元分類,因為需要輸出 ROC/AUC。")

        X_train, X_test, y_train, y_test = train_test_split(
            X, y,
            test_size=float(test_size),
            random_state=42,
            stratify=y
        )

        if use_scaling:
            # Fit on the training split only so test statistics do not leak in.
            scaler = StandardScaler()
            X_train = scaler.fit_transform(X_train)
            X_test = scaler.transform(X_test)
        else:
            X_train = X_train.values
            X_test = X_test.values

        model = build_model(
            model_name=model_name,
            knn_k=knn_k,
            dt_criterion=dt_criterion,
            dt_max_depth=dt_max_depth,
            rf_estimators=rf_estimators,
            rf_max_depth=rf_max_depth,
            lr_c=lr_c,
            svm_kernel=svm_kernel,
            svm_c=svm_c
        )
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        y_prob = None
        if hasattr(model, "predict_proba"):
            # Column 1 is the probability of model.classes_[1] (the larger label).
            y_prob = model.predict_proba(X_test)[:, 1]

        acc = accuracy_score(y_test, y_pred)
        pre = precision_score(y_test, y_pred, zero_division=0)
        rec = recall_score(y_test, y_pred, zero_division=0)
        f1 = f1_score(y_test, y_pred, zero_division=0)

        auc_text = "無法計算"
        roc_fig = None
        if y_prob is not None:
            roc_fig, roc_auc = plot_roc_curve(y_test, y_prob)
            auc_text = f"{roc_auc:.4f}"

        result_text = (
            f"模型名稱:{model_name}\n"
            f"Accuracy:{acc:.4f}\n"
            f"Precision:{pre:.4f}\n"
            f"Recall:{rec:.4f}\n"
            f"F1-score:{f1:.4f}\n"
            f"AUC:{auc_text}"
        )
        # zero_division=0 matches the summary metrics above for classes that
        # receive no predictions.
        report_df = pd.DataFrame(
            classification_report(y_test, y_pred, output_dict=True, zero_division=0)
        ).transpose()
        cm_fig = plot_confusion(y_test, y_pred)
        return result_text, report_df.round(4), cm_fig, roc_fig
    except Exception as e:
        # UI boundary: show the error in the outputs instead of raising.
        empty_df = pd.DataFrame()
        fig, ax = plt.subplots(figsize=(6, 3))
        ax.text(0.5, 0.5, f"錯誤:{e}", ha="center", va="center")
        ax.axis("off")
        fig.tight_layout()
        # Drop pyplot's reference so error figures do not accumulate.
        plt.close(fig)
        return f"模型訓練失敗:{e}", empty_df, fig, None
| # ========================= | |
| # 多模型比較 | |
| # ========================= | |
def compare_models(
    file_obj,
    target_column,
    use_count_as_target,
    test_size,
    use_scaling
):
    """Train the five default classifiers on one split and compare them.

    Every model sees the same stratified train/test split
    (``random_state=42``) and, optionally, the same standardisation. Results
    are sorted by accuracy, highest first. Ties keep the order of the model
    list below.

    Returns:
        ``(summary_text, result_df, compare_fig)``: a text summary of the best
        model, the metrics table, and the accuracy bar chart. On any failure
        the function returns an error message, an empty DataFrame, and a
        figure showing the error.
    """
    try:
        df = load_data(file_obj)
        df, target_column = prepare_target(df, target_column, use_count_as_target)
        X, y = preprocess_features(df, target_column)
        y = encode_target(y)

        if len(np.unique(y)) != 2:
            raise ValueError("目前版本只支援二元分類比較。")

        X_train, X_test, y_train, y_test = train_test_split(
            X, y,
            test_size=float(test_size),
            random_state=42,
            stratify=y
        )

        if use_scaling:
            # Fit on the training split only so test statistics do not leak in.
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)
        else:
            X_train_scaled = X_train.values
            X_test_scaled = X_test.values

        models = [
            ("KNN", KNeighborsClassifier(n_neighbors=5)),
            ("Decision Tree", DecisionTreeClassifier(random_state=42)),
            ("Random Forest", RandomForestClassifier(n_estimators=100, random_state=42)),
            ("Logistic Regression", LogisticRegression(max_iter=2000, random_state=42)),
            ("SVM", SVC(kernel="rbf", probability=True, random_state=42)),
        ]

        rows = []
        for name, model in models:
            model.fit(X_train_scaled, y_train)
            y_pred = model.predict(X_test_scaled)
            acc = accuracy_score(y_test, y_pred)
            pre = precision_score(y_test, y_pred, zero_division=0)
            rec = recall_score(y_test, y_pred, zero_division=0)
            f1 = f1_score(y_test, y_pred, zero_division=0)
            auc_score = np.nan
            if hasattr(model, "predict_proba"):
                y_prob = model.predict_proba(X_test_scaled)[:, 1]
                auc_score = auc(*roc_curve(y_test, y_prob)[:2])
            rows.append({
                "Model": name,
                "Accuracy": round(acc, 4),
                "Precision": round(pre, 4),
                "Recall": round(rec, 4),
                "F1-score": round(f1, 4),
                "AUC": None if pd.isna(auc_score) else round(auc_score, 4)
            })

        # A stable sort makes the "best model" deterministic on accuracy ties.
        result_df = (
            pd.DataFrame(rows)
            .sort_values(by="Accuracy", ascending=False, kind="mergesort")
            .reset_index(drop=True)
        )
        compare_fig = plot_model_comparison(result_df)
        best_model = result_df.iloc[0]
        summary = (
            f"最佳模型:{best_model['Model']}\n"
            f"Accuracy:{best_model['Accuracy']}\n"
            f"Precision:{best_model['Precision']}\n"
            f"Recall:{best_model['Recall']}\n"
            f"F1-score:{best_model['F1-score']}\n"
            f"AUC:{best_model['AUC']}"
        )
        return summary, result_df, compare_fig
    except Exception as e:
        # UI boundary: show the error in the outputs instead of raising.
        empty_df = pd.DataFrame()
        fig, ax = plt.subplots(figsize=(6, 3))
        ax.text(0.5, 0.5, f"錯誤:{e}", ha="center", va="center")
        ax.axis("off")
        fig.tight_layout()
        # Drop pyplot's reference so error figures do not accumulate.
        plt.close(fig)
        return f"模型比較失敗:{e}", empty_df, fig
| # ========================= | |
| # UI | |
| # ========================= | |
# Page-level CSS passed to gr.Blocks: caps the container width at 1200px.
custom_css = """
.gradio-container {
max-width: 1200px !important;
}
"""
# Three-tab Gradio interface. Components render in creation order inside each
# `with` context, so the statement order below defines the on-screen layout.
with gr.Blocks(title="機器學習模型訓練工具", css=custom_css) as demo:
    gr.Markdown("""
# 機器學習模型訓練
- 資料上傳與預覽
- 欄位型態與缺失值分析
- `count` 欄位轉二元分類
- KNN / Decision Tree / Random Forest / Logistic Regression / SVM
- Accuracy / Precision / Recall / F1-score / AUC
- Confusion Matrix / ROC Curve
- 多模型比較
""")
    # Tab 1: upload and inspect the data, choose the target column.
    with gr.Tab("1. 資料分析"):
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="上傳 CSV 或 Excel 檔案",
                    file_types=[".csv", ".xlsx", ".xls"]
                )
                analyze_btn = gr.Button("分析資料", variant="primary")
                # Starts empty; analyze_file returns a gr.update that fills it.
                target_dropdown = gr.Dropdown(label="目標欄位", choices=[], value=None)
                # When checked, prepare_target derives a binary label from `count`.
                use_count_checkbox = gr.Checkbox(
                    label="若資料有 count 欄位,將其依中位數轉成二元分類",
                    value=True
                )
                dist_btn = gr.Button("顯示類別分布")
            with gr.Column(scale=2):
                summary_output = gr.Textbox(label="資料摘要", lines=8)
                preview_output = gr.Dataframe(label="資料預覽")
                info_output = gr.Dataframe(label="欄位型態")
                missing_output = gr.Dataframe(label="缺失值統計")
                dist_plot = gr.Plot(label="類別分布圖")
    # Tab 2: train one model with user-chosen hyper-parameters.
    with gr.Tab("2. 單一模型訓練"):
        with gr.Row():
            with gr.Column(scale=1):
                # The split settings below are also reused by the comparison tab.
                test_size_slider = gr.Slider(
                    label="測試集比例",
                    minimum=0.1,
                    maximum=0.5,
                    step=0.1,
                    value=0.2
                )
                use_scaling_checkbox = gr.Checkbox(
                    label="使用 StandardScaler",
                    value=True
                )
                # Choices must match the names handled by build_model.
                model_dropdown = gr.Dropdown(
                    label="選擇模型",
                    choices=[
                        "KNN",
                        "Decision Tree",
                        "Random Forest",
                        "Logistic Regression",
                        "SVM"
                    ],
                    value="KNN"
                )
                gr.Markdown("## 模型參數")
                # All parameter controls are always visible; build_model reads
                # only those belonging to the selected model.
                knn_k = gr.Slider(label="KNN:k 值", minimum=1, maximum=15, value=5, step=1)
                dt_criterion = gr.Dropdown(
                    label="Decision Tree:criterion",
                    choices=["gini", "entropy"],
                    value="gini"
                )
                # 0 is mapped to max_depth=None (unlimited) in build_model.
                dt_max_depth = gr.Slider(
                    label="Decision Tree:max_depth(0 代表不限)",
                    minimum=0, maximum=20, value=5, step=1
                )
                rf_estimators = gr.Slider(
                    label="Random Forest:n_estimators",
                    minimum=10, maximum=300, value=100, step=10
                )
                rf_max_depth = gr.Slider(
                    label="Random Forest:max_depth(0 代表不限)",
                    minimum=0, maximum=20, value=5, step=1
                )
                lr_c = gr.Slider(
                    label="Logistic Regression:C",
                    minimum=0.01, maximum=10.0, value=1.0, step=0.01
                )
                svm_kernel = gr.Dropdown(
                    label="SVM:kernel",
                    choices=["linear", "rbf"],
                    value="rbf"
                )
                svm_c = gr.Slider(
                    label="SVM:C",
                    minimum=0.01, maximum=10.0, value=1.0, step=0.01
                )
                train_btn = gr.Button("開始訓練單一模型", variant="primary")
            with gr.Column(scale=2):
                single_result_output = gr.Textbox(label="模型結果", lines=8)
                report_output = gr.Dataframe(label="Classification Report")
                cm_output = gr.Plot(label="Confusion Matrix")
                roc_output = gr.Plot(label="ROC Curve")
    # Tab 3: compare all built-in models using their default hyper-parameters.
    with gr.Tab("3. 多模型比較"):
        with gr.Row():
            with gr.Column(scale=1):
                compare_btn = gr.Button("比較所有模型", variant="primary")
            with gr.Column(scale=2):
                compare_summary = gr.Textbox(label="最佳模型摘要", lines=8)
                compare_table = gr.Dataframe(label="模型比較表")
                compare_plot = gr.Plot(label="模型 Accuracy 比較圖")
    # Event wiring: input/output lists are positional and must match the
    # handler's parameter order and returned tuple order.
    analyze_btn.click(
        fn=analyze_file,
        inputs=[file_input],
        outputs=[
            preview_output,
            info_output,
            missing_output,
            summary_output,
            target_dropdown
        ]
    )
    dist_btn.click(
        fn=target_distribution,
        inputs=[file_input, target_dropdown, use_count_checkbox],
        outputs=[dist_plot]
    )
    train_btn.click(
        fn=train_single_model,
        inputs=[
            file_input,
            target_dropdown,
            use_count_checkbox,
            test_size_slider,
            use_scaling_checkbox,
            model_dropdown,
            knn_k,
            dt_criterion,
            dt_max_depth,
            rf_estimators,
            rf_max_depth,
            lr_c,
            svm_kernel,
            svm_c
        ],
        outputs=[
            single_result_output,
            report_output,
            cm_output,
            roc_output
        ]
    )
    # Reuses the split/scaling controls from tab 2.
    compare_btn.click(
        fn=compare_models,
        inputs=[
            file_input,
            target_dropdown,
            use_count_checkbox,
            test_size_slider,
            use_scaling_checkbox
        ],
        outputs=[
            compare_summary,
            compare_table,
            compare_plot
        ]
    )
# Start the Gradio server when the module is run.
demo.launch()