import os
import tempfile
import time
import uuid
import warnings

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.ensemble import RandomForestRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
from sklearn.linear_model import LinearRegression
from xgboost import XGBRegressor
|
|
| |
# --- Data & feature configuration -------------------------------------------
# CSV with the process training data; must sit next to this script.
DATA_FILE = "ITRI-PY-1140908-training.csv"
# Input features fed to every model, in the column order the UI widgets use.
FEATURES = ['MW01', 'MW02', 'MW03', 'flux', 'MWA']
# Regression target column ('%' signs are stripped during preprocessing).
TARGET = 'yield'


# Per-model configuration: where the trained estimator and its pre-computed
# SHAP data are persisted, plus the SHAP explainer class matching the model.
MODEL_CONFIG = {
    'Random Forest (RF)': {
        'model_file': "rf_model.joblib",
        'explainer_type': shap.TreeExplainer,
        'data_file': "rf_shap_data.joblib"
    },
    'Gaussian Process (GP)': {
        'model_file': "gp_model.joblib",
        # GP has no tree/linear structure, so the model-agnostic
        # KernelExplainer (with background data) is required.
        'explainer_type': shap.KernelExplainer,
        'data_file': "gp_shap_data.joblib"
    },
    'XGBoost (XGB)': {
        'model_file': "xgb_model.joblib",
        'explainer_type': shap.TreeExplainer,
        'data_file': "xgb_shap_data.joblib"
    },
    'Linear Regression (LR)': {
        'model_file': "lr_model.joblib",
        'explainer_type': shap.LinearExplainer,
        'data_file': "lr_shap_data.joblib"
    }
}
# Model pre-selected in every UI tab.
DEFAULT_MODEL = 'Random Forest (RF)'
# Shared state filled by train_and_save_models()/preprocess_data():
#   'X'          -> full training feature frame
#   'Background' -> background sample used by KernelExplainer
#   'Quartiles'  -> per-feature quartiles backing the Low/Mid/High UI hints
GLOBAL_DATA = {'X': None, 'Background': None, 'Quartiles': {}}
# Number of rows sampled when drawing the global decision plot.
SAMPLE_SIZE_GLOBAL_DECISION = 100
|
|
| |
| |
# Keep minus signs renderable on plots when non-default fonts are active.
plt.rcParams['axes.unicode_minus'] = False
|
|
| |
def save_plot_to_file(plt_fig, prefix):
    """Save a matplotlib figure to a unique temporary PNG and return its path.

    Args:
        plt_fig: The matplotlib Figure to persist.
        prefix: Filename prefix identifying the plot type/model.

    Returns:
        str: Absolute path of the saved PNG file.
    """
    # Use the platform temp directory (portable, unlike a hard-coded "/tmp")
    # and a uuid suffix so concurrent requests can never collide on a name
    # the way second-resolution time.time() stamps could.
    filename = os.path.join(tempfile.gettempdir(), f"{prefix}_{uuid.uuid4().hex}.png")
    plt_fig.savefig(filename, bbox_inches='tight')
    plt.close(plt_fig)  # release figure memory; Gradio reads the file from disk
    return filename
|
|
| |
def preprocess_data(df):
    """Clean up and convert the yield column, and calculate quartiles for UI hints."""
    global GLOBAL_DATA

    # Strip '%' signs when the target column was read in as text.
    if df[TARGET].dtype == 'object':
        print(f"--- Cleaning up '%' symbols in {TARGET} column and converting to float... ---")
        stripped = df[TARGET].astype(str).str.replace('%', '', regex=False)
        df[TARGET] = stripped.astype(float)
        print(f"--- {TARGET} column successfully converted to float. ---")

    # Cache per-feature quartiles so the UI can tag inputs as Low/Mid/High.
    for col in FEATURES:
        GLOBAL_DATA['Quartiles'][col] = df[col].quantile([0.25, 0.5, 0.75]).to_dict()

    return df
|
|
| |
def train_and_save_models():
    """Load data, train ALL models, compute their SHAP values, and save.

    For each entry in MODEL_CONFIG: trains the estimator on the full dataset,
    persists it with joblib, builds the matching SHAP explainer, and stores
    the pre-computed SHAP values/base value for the global plots. Also fills
    GLOBAL_DATA['X'] and GLOBAL_DATA['Background'] for later explainer reuse.
    Errors on a single model are logged and skipped, not fatal.
    """
    global GLOBAL_DATA
    try:
        df = pd.read_csv(DATA_FILE)
    except FileNotFoundError:
        print(f"Error: Data file {DATA_FILE} not found. Please ensure the file exists.")
        return

    df = preprocess_data(df)

    X = df[FEATURES]
    y = df[TARGET]

    # KernelExplainer needs a small background dataset; prefer a k-means
    # summary, fall back to random sampling if that fails.
    print("Preparing background data for KernelExplainer...")
    try:
        background_data = shap.maskers.KMeans(X, 10).data.values
    except Exception:  # was a bare except:, which also swallowed KeyboardInterrupt
        background_data = X.sample(min(100, len(X)), random_state=42).values

    if background_data.ndim == 1:
        background_data = background_data.reshape(-1, 1)

    GLOBAL_DATA['X'] = X
    GLOBAL_DATA['Background'] = background_data

    for name, config in MODEL_CONFIG.items():
        print(f"\n--- Processing Model: {name} ---")

        if name == 'Random Forest (RF)':
            model = RandomForestRegressor(n_estimators=100, random_state=42)
        elif name == 'Gaussian Process (GP)':
            kernel = C(1.0, (1e-3, 1e3)) * RBF(10, (1e-2, 1e2))
            model = GaussianProcessRegressor(kernel=kernel, random_state=42, normalize_y=True)
        elif name == 'XGBoost (XGB)':
            # FIX: dropped 'use_label_encoder=False' -- it is a deprecated,
            # classifier-only option; recent xgboost releases warn or reject
            # it, and it never had any effect on a regressor.
            model = XGBRegressor(n_estimators=100, random_state=42, eval_metric='rmse')
        elif name == 'Linear Regression (LR)':
            model = LinearRegression()
        else:
            continue

        print(f"Training {name}...")
        try:
            model.fit(X, y)
            joblib.dump(model, config['model_file'])
        except Exception as e:
            print(f"Error training {name}: {e}")
            continue

        print(f"Calculating SHAP values for {name}...")
        try:
            explainer_type = config['explainer_type']

            if explainer_type == shap.TreeExplainer:
                explainer = shap.TreeExplainer(model)
                shap_values = explainer.shap_values(X)
            elif explainer_type == shap.KernelExplainer:
                # KernelExplainer is model-agnostic: wraps model.predict.
                explainer = shap.KernelExplainer(model.predict, background_data)
                shap_values = explainer.shap_values(X.values)
            elif explainer_type == shap.LinearExplainer:
                explainer = shap.LinearExplainer(model, X)
                shap_values = explainer.shap_values(X)
            else:
                # Previously an unknown explainer type fell through with
                # 'shap_values' unbound and raised a confusing NameError.
                raise ValueError(f"Unsupported explainer type: {explainer_type}")
            expected_value = explainer.expected_value

            # Multi-output explainers return a list; keep the first output.
            if isinstance(shap_values, list):
                if len(shap_values) > 0 and isinstance(shap_values[0], np.ndarray):
                    shap_values = shap_values[0]

            if not isinstance(shap_values, np.ndarray):
                shap_values = np.array(shap_values)

            joblib.dump({'X': X, 'shap_values': shap_values, 'base_value': expected_value}, config['data_file'])
            print(f"SHAP data for {name} saved.")

        except Exception as e:
            print(f"Error calculating SHAP for {name}: {e}")
            import traceback
            traceback.print_exc()
|
|
| |
def get_model_explainer_and_data(model_name):
    """Load a saved model and its pre-computed SHAP data; build a fresh explainer.

    Returns a (model, explainer, X, shap_values, base_value) tuple, or five
    Nones when the name is unknown or any artifact fails to load.
    """
    config = MODEL_CONFIG.get(model_name)
    if config is None:
        return None, None, None, None, None

    try:
        model = joblib.load(config['model_file'])

        cached = joblib.load(config['data_file'])
        X = cached['X']
        shap_values = cached['shap_values']
        base_value = cached['base_value']

        kind = config['explainer_type']
        if kind == shap.TreeExplainer:
            explainer = shap.TreeExplainer(model)
        elif kind == shap.KernelExplainer:
            # KernelExplainer must be rebuilt with the background sample
            # produced during training.
            if GLOBAL_DATA['Background'] is None:
                raise Exception("Global background data is missing. Rerun training.")
            explainer = shap.KernelExplainer(model.predict, GLOBAL_DATA['Background'])
        elif kind == shap.LinearExplainer:
            explainer = shap.LinearExplainer(model, X)
        else:
            explainer = None

        return model, explainer, X, shap_values, base_value

    except Exception as e:
        print(f"Error loading or creating explainer for {model_name}: {e}")
        return None, None, None, None, None
|
|
| |
def get_local_explanation(model_name, MW01, MW02, MW03, flux, MWA):
    """Predict and generate SHAP Waterfall Plot for a single input."""
    if model_name is None:
        model_name = DEFAULT_MODEL

    model, explainer, _, _, _ = get_model_explainer_and_data(model_name)
    if explainer is None:
        return f"Error: Model or Explainer for {model_name} failed to load/create.", None

    sample = pd.DataFrame([[MW01, MW02, MW03, flux, MWA]], columns=FEATURES)

    # KernelExplainer expects raw ndarrays; the other explainers take frames.
    if isinstance(explainer, shap.KernelExplainer):
        shap_values = explainer.shap_values(sample.values)[0]
    else:
        shap_values = explainer.shap_values(sample)[0]

    prediction = model.predict(sample)[0]
    base_value = explainer.expected_value

    plt.figure(figsize=(8, 6))
    shap_object = shap.Explanation(
        values=shap_values,
        base_values=base_value,
        data=sample.iloc[0].values,
        feature_names=FEATURES,
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.plots.waterfall(shap_object, show=False)
    plt.title(f"SHAP Waterfall Plot: {model_name} Explanation")
    plt.tight_layout()

    file_path = save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_waterfall")
    result_text = f"{model_name} Prediction for {TARGET}: {prediction:.4f}"
    return result_text, file_path
|
|
| |
def get_summary_plot(model_name):
    """Generate SHAP Summary Plot for global feature importance."""
    if model_name is None:
        model_name = DEFAULT_MODEL

    # Only the cached dataset and SHAP matrix are needed for the summary view.
    _, _, X, shap_values, _ = get_model_explainer_and_data(model_name)
    if shap_values is None:
        return None

    plt.figure(figsize=(10, 8))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.summary_plot(shap_values, X, show=False)
    plt.title(f"SHAP Summary Plot ({model_name}): Global Importance")
    plt.tight_layout()
    return save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_summary")
|
|
| |
def get_dependence_plot(model_name, feature_to_plot, interaction_feature):
    """Generate SHAP Dependence Plot."""
    if model_name is None:
        model_name = DEFAULT_MODEL

    _, _, X, shap_values, _ = get_model_explainer_and_data(model_name)
    if shap_values is None:
        return None

    # The dropdown sends the literal string 'None' when no interaction is chosen.
    if interaction_feature == 'None':
        interaction_feature = None

    plt.figure(figsize=(8, 6))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.dependence_plot(
            feature_to_plot,
            shap_values,
            X,
            interaction_index=interaction_feature,
            show=False,
        )
    plt.title(f"SHAP Dependence Plot ({model_name}): Effect of {feature_to_plot}")
    plt.tight_layout()
    return save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_dependence")
|
|
| |
def get_decision_plot(model_name, MW01, MW02, MW03, flux, MWA):
    """Generate SHAP Decision Plot."""
    if model_name is None:
        model_name = DEFAULT_MODEL

    _, explainer, _, _, _ = get_model_explainer_and_data(model_name)
    if explainer is None:
        return None

    row = pd.DataFrame([[MW01, MW02, MW03, flux, MWA]], columns=FEATURES)

    # KernelExplainer works on raw ndarrays; the others accept DataFrames.
    if isinstance(explainer, shap.KernelExplainer):
        shap_values = explainer.shap_values(row.values)
    else:
        shap_values = explainer.shap_values(row)

    plt.figure(figsize=(10, 8))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.decision_plot(
            explainer.expected_value,
            shap_values[0],
            row.iloc[0],
            show=False,
            feature_names=FEATURES,
            title=f"SHAP Decision Plot ({model_name}): Prediction Path",
        )
    plt.tight_layout()
    return save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_decision")
|
|
| |
def get_static_force_plot(model_name, MW01, MW02, MW03, flux, MWA):
    """Generate SHAP Static Force Plot (Matplotlib)."""
    if model_name is None:
        model_name = DEFAULT_MODEL

    _, explainer, _, _, _ = get_model_explainer_and_data(model_name)
    if explainer is None:
        return None

    row = pd.DataFrame([[MW01, MW02, MW03, flux, MWA]], columns=FEATURES)

    # KernelExplainer works on raw ndarrays; the others accept DataFrames.
    if isinstance(explainer, shap.KernelExplainer):
        shap_values = explainer.shap_values(row.values)
    else:
        shap_values = explainer.shap_values(row)

    plt.figure(figsize=(12, 4))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # matplotlib=True renders a static image instead of the JS widget.
        shap.force_plot(
            explainer.expected_value,
            shap_values[0],
            row.iloc[0],
            show=False,
            matplotlib=True,
        )
    plt.title(f"SHAP Force Plot ({model_name}) - Static Version")
    plt.tight_layout()
    return save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_force")
| |
| |
def get_global_decision_plot(model_name):
    """Generate a SHAP Decision Plot showing paths for a sample of the entire dataset.

    Samples up to SAMPLE_SIZE_GLOBAL_DECISION rows, slices the cached SHAP
    matrix for them (recomputing on failure), and renders one decision path
    per sampled row. Returns the PNG path, or None on load failure.
    """
    if model_name is None:
        model_name = DEFAULT_MODEL
    model, explainer, X, shap_values_full, base_value = get_model_explainer_and_data(model_name)
    if explainer is None or X is None:
        return None

    n_samples = min(SAMPLE_SIZE_GLOBAL_DECISION, len(X))
    X_sample = X.sample(n_samples, random_state=42)
    # BUG FIX: shap_values_full is a plain ndarray aligned with X by POSITION.
    # The old code indexed it with pandas index *labels*, which is only
    # correct for a default RangeIndex; translate labels to positions first.
    positions = X.index.get_indexer(X_sample.index)
    try:
        shap_values_sample = shap_values_full[positions]
    except Exception:  # narrowed from bare except:
        # Fall back to recomputing SHAP values just for the sample.
        if isinstance(explainer, shap.KernelExplainer):
            shap_values_sample = explainer.shap_values(X_sample.values)
        else:
            shap_values_sample = explainer.shap_values(X_sample)

    plt.figure(figsize=(10, 10))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.decision_plot(
            base_value,
            shap_values_sample,
            X_sample,
            show=False,
            feature_names=FEATURES,
            title=f"SHAP Global Decision Plot ({model_name}): {n_samples} Samples",
        )
    plt.tight_layout()
    return save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_global_decision")
|
|
| |
def get_feature_label(feature, value, quartiles=None):
    """Return a feature label annotated with the value's quartile position.

    Args:
        feature: Feature name.
        value: The user-supplied value for the feature.
        quartiles: Optional mapping {0.25: q1, 0.5: median, 0.75: q3}. When
            None (the default) the pre-computed quartiles in
            GLOBAL_DATA['Quartiles'] are used, preserving the original
            behavior; passing it explicitly decouples the function from
            module-level state.

    Returns:
        str: e.g. "flux (2.0) [Above Q3 (High)]"; just "feature (value)"
        when no quartiles are known for this feature.
    """
    if quartiles is None:
        quartiles = GLOBAL_DATA['Quartiles'].get(feature)
    if quartiles is None:
        return f"{feature} ({value})"

    q25, q50, q75 = quartiles[0.25], quartiles[0.5], quartiles[0.75]

    # Bucket the value by its position among the dataset quartiles.
    if value < q25:
        status = "[Below Q1 (Low)]"
    elif value <= q50:
        status = "[Near Median]"
    elif value <= q75:
        status = "[Near Q3]"
    else:
        status = "[Above Q3 (High)]"

    return f"{feature} ({value}) {status}"
|
|
|
|
def get_local_plots_combined(model_name, MW01, MW02, MW03, flux, MWA):
    """Calculates prediction, Waterfall, and Decision plots simultaneously."""
    if model_name is None:
        model_name = DEFAULT_MODEL

    model, explainer, _, _, _ = get_model_explainer_and_data(model_name)
    if explainer is None:
        return f"Error: Model or Explainer for {model_name} failed to load/create.", None, None

    values = [MW01, MW02, MW03, flux, MWA]
    row = pd.DataFrame([values], columns=FEATURES)

    # KernelExplainer expects raw ndarrays; the other explainers take frames.
    if isinstance(explainer, shap.KernelExplainer):
        shap_values = explainer.shap_values(row.values)
    else:
        shap_values = explainer.shap_values(row)

    prediction = model.predict(row)[0]
    base_value = explainer.expected_value

    # Feature labels annotated with each value's quartile position (Low/Mid/High).
    labels = [get_feature_label(feat, val) for feat, val in zip(FEATURES, values)]

    # Plot 1: waterfall breakdown of this single prediction.
    plt.figure(figsize=(8, 6))
    explanation = shap.Explanation(
        values=shap_values[0],
        base_values=base_value,
        data=row.iloc[0].values,
        feature_names=labels,
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.plots.waterfall(explanation, show=False)
    plt.title(f"1. Waterfall Plot: {model_name} Prediction Breakdown")
    plt.tight_layout()
    waterfall_path = save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_combined_waterfall")

    # Plot 2: decision path from the base value to the final prediction.
    plt.figure(figsize=(8, 6))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.decision_plot(
            base_value,
            shap_values[0],
            row.iloc[0],
            show=False,
            feature_names=labels,
            title=f"2. Decision Plot: {model_name} Prediction Path",
        )
    plt.tight_layout()
    decision_path = save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_combined_decision")

    return f"Predicted {TARGET}: {prediction:.4f}", waterfall_path, decision_path
|
|
|
|
| |
# Train (or retrain) all models and pre-compute SHAP data at import time so
# every Gradio tab can load its artifacts from disk.
train_and_save_models()
print("All models trained and SHAP data prepared.")


# Model names offered by every model-selection radio button.
model_choices = list(MODEL_CONFIG.keys())
|
|
| |
# --- Tab 1: integrated optimal-prediction workflow (global -> local) --------
with gr.Blocks() as optimal_block:

    # User-facing intro (zh-TW): explains the iterate-and-validate workflow.
    # NOTE: this is a runtime string rendered by Gradio; do not translate here.
    gr.Markdown(
        """
## 最佳預測尋找器(整合分析與優化工作流程)
**目的:** 此分頁引導您透過迭代的工作流程,結合「全域」和「局部」SHAP 分析,來尋找最佳的特徵設定,以最大化目標變數(`yield`,良率)。

### 功能函數 (Functionality):
1. **全域情境 (`get_summary_plot`):** 先執行此步驟,以識別哪些特徵最重要,以及哪些數值(高/低)會將預測值推高(紅點)或推低(藍點)。
2. **迭代輸入:** 根據全域情境調整輸入參數。圖表中的特徵標籤會動態顯示您輸入值在數據集中的百分位數範圍(例如:`[Below Q1 (Low)]`, `[Near Median]` 等),以協助微調。
3. **局部驗證 (`get_local_plots_combined`):** 同時生成預測值和局部解釋圖,以驗證您的新設定是否正在正面地推動預測值。

### 輸出解讀 (Interpretation of Output):
* **瀑布圖 (Waterfall Plot):** 精確地顯示每個特徵如何對最終預測值做出貢獻,將其從 `Base Value`(平均預測值)推向 `Predicted Yield`。為了優化,目標是讓重要變數產生**正面**(紅色)的貢獻。
* **決策圖 (Decision Plot):** 將每個特徵的累積影響視覺化。路徑向右移動表示預測值更高。
"""
    )
    gr.Markdown("---")

    gr.Markdown("### Step 1-2: Select Model and Get Global Context")

    # Step 1-2 widgets: model choice + global summary plot trigger.
    with gr.Row():
        optimal_model_selector = gr.Radio(model_choices, value=DEFAULT_MODEL, label="1. Select Optimal Model (Highest Accuracy)")
        summary_btn_opt = gr.Button("2. Generate Global Summary Plot (Reference: Find key variable ideal trends)")

    summary_image_opt = gr.Image(label="Global Feature Importance Context (Summary Plot)", type="filepath", interactive=False, height=250)

    summary_btn_opt.click(fn=get_summary_plot, inputs=[optimal_model_selector], outputs=summary_image_opt)

    gr.Markdown("---")
    gr.Markdown("### Step 3-4: Iterate Input, Predict, and Locally Validate")
    gr.Markdown("Based on the ideal trend shown in the Summary Plot (blue/red dot distribution), **fine-tune the variable values**. The plot labels will now hint at your input value's position within the dataset (Low/Mid/High).")

    # Step 3 widgets: one control per feature, in FEATURES order.
    with gr.Row():
        optimal_input_MW01 = gr.Number(label="MW01", minimum=0, maximum=1000, value=800)
        optimal_input_MW02 = gr.Number(label="MW02", minimum=0, maximum=1000, value=800)
        optimal_input_MW03 = gr.Number(label="MW03", minimum=0, maximum=1000, value=300)
        optimal_input_flux = gr.Slider(label="flux", minimum=0.5, maximum=2, step=0.5, value=2)
        optimal_input_MWA = gr.Slider(label="MWA", minimum=5, maximum=20, step=5, value=5)

    # Input order must match get_local_plots_combined's signature.
    optimal_inputs_list = [optimal_model_selector, optimal_input_MW01, optimal_input_MW02, optimal_input_MW03, optimal_input_flux, optimal_input_MWA]

    optimize_btn = gr.Button("4. Run Prediction & Generate Local Explanation (Validate/Fine-tune)")

    result_text_opt = gr.Textbox(label="Final Prediction Result")

    with gr.Row():
        waterfall_image_opt = gr.Image(label="Validation Plot 1: Waterfall Plot (Check if variable contributions are positive as expected)", type="filepath", interactive=False, height=350)
        decision_image_opt = gr.Image(label="Validation Plot 2: Decision Plot (Check prediction path and baseline offset)", type="filepath", interactive=False, height=350)

    # Step 4: one click produces the prediction text plus both local plots.
    optimize_btn.click(
        fn=get_local_plots_combined,
        inputs=optimal_inputs_list,
        outputs=[result_text_opt, waterfall_image_opt, decision_image_opt]
    )
|
|
|
|
| |
# --- Tab 2: local (single-instance) explanation playground ------------------
with gr.Blocks() as local_block:

    # User-facing intro (zh-TW); runtime string rendered by Gradio.
    gr.Markdown(
        """
## 局部解釋(多圖表)
**目的:** 此分頁作為測試區,用於生成各種局部(單一實例)SHAP 圖,以理解特定預測的生成原因。它允許單獨控制每種圖表的生成。

### 功能函數 (Functionality):
* **預測與瀑布圖 (`get_local_explanation`):** 預測目標值並生成顯示特徵貢獻的瀑布圖。
* **決策圖 (`get_decision_plot`):** 顯示從基線(Base Value)開始的預測路徑。
* **力圖 (`get_static_force_plot`):** 提供正面/負面貢獻的高層次可視化。

### 輸出解讀 (Interpretation of Output):
* **瀑布圖:** 將輸出推高的特徵顯示為**紅色**。推低的顯示為**藍色**。所有貢獻的總和等於預測值與基線值之間的差異。
* **決策圖:** 水平路徑(X 軸)代表預測值。較粗的線顯示特徵值如何將預測值從平均值(基線)轉移到最終預測值。
* **力圖:** 紅色特徵將預測值推**高**(向右);藍色特徵將其推**低**(向左)。
"""
    )
    gr.Markdown("---")

    local_model_selector = gr.Radio(model_choices, value=DEFAULT_MODEL, label="Select Model for SHAP Analysis")

    # One control per feature, in FEATURES order.
    with gr.Row():
        local_input_MW01 = gr.Number(label="MW01", minimum=0, maximum=1000, value=800)
        local_input_MW02 = gr.Number(label="MW02", minimum=0, maximum=1000, value=800)
        local_input_MW03 = gr.Number(label="MW03", minimum=0, maximum=1000, value=300)
        local_input_flux = gr.Slider(label="flux", minimum=0.5, maximum=2, step=0.5, value=2)
        local_input_MWA = gr.Slider(label="MWA", minimum=5, maximum=20, step=5, value=5)

    # Input order must match the plotting functions' signatures.
    local_inputs_list = [local_model_selector, local_input_MW01, local_input_MW02, local_input_MW03, local_input_flux, local_input_MWA]

    result_text = gr.Textbox(label="Model Prediction Result")

    # One button per plot type so each can be (re)generated independently.
    with gr.Row():
        waterfall_btn = gr.Button("1. Generate Waterfall Plot")
        decision_btn = gr.Button("2. Generate Decision Plot")
        force_btn = gr.Button("3. Generate Static Force Plot")

    gr.Markdown("---")

    with gr.Row():
        waterfall_image = gr.Image(label="1. SHAP Waterfall Plot", type="filepath", interactive=False, height=300)

    with gr.Row():
        decision_image = gr.Image(label="2. SHAP Decision Plot", type="filepath", interactive=False, height=400)

    with gr.Row():
        force_static_image = gr.Image(label="3. SHAP Force Plot (Static)", type="filepath", interactive=False, height=200)

    # Only the waterfall handler also refreshes the prediction textbox.
    waterfall_btn.click(fn=get_local_explanation, inputs=local_inputs_list, outputs=[result_text, waterfall_image])
    decision_btn.click(fn=get_decision_plot, inputs=local_inputs_list, outputs=decision_image)
    force_btn.click(fn=get_static_force_plot, inputs=local_inputs_list, outputs=force_static_image)
|
|
|
|
| |
# --- Tab 3: global analysis (summary + sampled decision paths) --------------
with gr.Blocks() as global_summary_block:

    # User-facing intro (zh-TW); runtime string rendered by Gradio.
    gr.Markdown(
        """
## 全域分析:特徵重要性與決策路徑
**目的:** 此分頁提供特徵重要性的全域概覽,以及特徵如何影響整個數據集的目標變數,有助於確立特徵的優先順序和理想的數值範圍。

### 功能函數 (Functionality):
* **摘要圖 (`get_summary_plot`):** 計算數據集中所有實例的 SHAP 數值平均幅度。
* **全域決策圖 (`get_global_decision_plot`):** 顯示大量樣本數據的預測路徑,揭示整體趨勢和異常值。

### 輸出解讀 (Interpretation of Output):
* **摘要圖:**
    * **Y 軸:** 特徵依全域重要性排序(頂部最重要)。
    * **X 軸:** SHAP 數值(對模型輸出的影響)。
    * **顏色(紅/藍):** 特徵值(紅色 = 高,藍色 = 低)。
    * **關鍵見解:** 尋找高數值(紅色)一致地導致高正向 SHAP 數值的特徵,反之亦然,以確定最大化的理想數值範圍。
* **全域決策圖:** 這是許多預測路徑的集合。水平寬度表示變異性,路徑聚集的位置顯示了常見的預測值。
"""
    )
    gr.Markdown("---")

    global_model_selector_sum = gr.Radio(model_choices, value=DEFAULT_MODEL, label="Select Model for Global Analysis")
    # FIX: removed a stray no-op line that just evaluated
    # 'global_model_selector_sum' as a bare expression (dead code).

    gr.Markdown("### 1. Feature Importance Summary")
    summary_btn = gr.Button("Generate Summary Plot")
    summary_image = gr.Image(label="SHAP Summary Plot: Global Importance", type="filepath", interactive=False, height=500)
    summary_btn.click(fn=get_summary_plot, inputs=[global_model_selector_sum], outputs=summary_image)

    gr.Markdown("### 2. Global Decision Path (Sampled)")
    global_decision_btn = gr.Button(f"Generate Global Decision Plot (Sampling {SAMPLE_SIZE_GLOBAL_DECISION} points)")
    global_decision_image = gr.Image(label="SHAP Decision Plot: Paths of Sampled Data", type="filepath", interactive=False, height=600)
    global_decision_btn.click(fn=get_global_decision_plot, inputs=[global_model_selector_sum], outputs=global_decision_image)
|
|
|
|
| |
# --- Tab 4: feature dependence / interaction plots --------------------------
with gr.Blocks() as dependence_interface:

    # User-facing intro (zh-TW); runtime string rendered by Gradio.
    gr.Markdown(
        """
## 特徵依賴圖
**目的:** 此圖表有助於分析單一特徵對預測的邊際效應,以及此效應如何被第二個「互動」特徵所修改。

### 功能函數 (Functionality):
* **依賴圖 (`get_dependence_plot`):** 繪製特徵的實際數值(X 軸)與其 SHAP 數值(Y 軸)之間的關係,並可選用一個顏色特徵來顯示互動效應。

### 輸出解讀 (Interpretation of Output):
* **主要特徵(X 軸):** 顯示**趨勢**。如果隨著 X 軸數值的增加,點向上移動,則該特徵對輸出有正向的邊際影響。
* **顏色(互動特徵):** 如果著色特徵呈現出明顯的模式(例如,所有高顏色點聚集在更高或更低的位置),則表示這兩個特徵之間存在強烈的互動關係。
"""
    )
    gr.Markdown("---")

    # Auto-generated form: model, X-axis feature, and optional color feature
    # ('None' sentinel means no interaction coloring; see get_dependence_plot).
    gr.Interface(
        fn=get_dependence_plot,
        inputs=[
            gr.Radio(model_choices, value=DEFAULT_MODEL, label="Select Model for SHAP Analysis"),
            gr.Dropdown(label="Select Feature for X-axis", choices=FEATURES, value=FEATURES[0]),
            gr.Dropdown(label="Select Interaction Feature (Color)", choices=['None'] + FEATURES, value='None')
        ],
        outputs=gr.Image(label="SHAP Dependence Plot: Feature Dependence", type="filepath", interactive=False, height=350),
        flagging_mode="never",
        title="Feature Dependence on SHAP Value"
    )
|
|
| |
# Assemble the four analysis pages into a single tabbed app.
full_interface = gr.TabbedInterface(
    [optimal_block, local_block, global_summary_block, dependence_interface],
    ["Optimal Prediction Finder (Integrated)", "Local Explanation (Multi-Plot)", "Global Analysis (Summary & Decision)", "Feature Dependence Plot"]
)


if __name__ == "__main__":
    # Launch the Gradio server only when run as a script (not on import).
    full_interface.launch()