import pandas as pd
import numpy as np
import shap
import joblib
import gradio as gr
import matplotlib.pyplot as plt
import warnings
import os
import tempfile
import time
import traceback
from sklearn.ensemble import RandomForestRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
from xgboost import XGBRegressor
from sklearn.linear_model import LinearRegression

# --- 1. Data and Model Configuration ---
DATA_FILE = "ITRI-PY-1140908-training.csv"
FEATURES = ['MW01', 'MW02', 'MW03', 'flux', 'MWA']
TARGET = 'yield'

MODEL_CONFIG = {
    'Random Forest (RF)': {
        'model_file': "rf_model.joblib",
        'explainer_type': shap.TreeExplainer,
        'data_file': "rf_shap_data.joblib"
    },
    'Gaussian Process (GP)': {
        'model_file': "gp_model.joblib",
        'explainer_type': shap.KernelExplainer,
        'data_file': "gp_shap_data.joblib"
    },
    'XGBoost (XGB)': {
        'model_file': "xgb_model.joblib",
        'explainer_type': shap.TreeExplainer,
        'data_file': "xgb_shap_data.joblib"
    },
    'Linear Regression (LR)': {
        'model_file': "lr_model.joblib",
        'explainer_type': shap.LinearExplainer,
        'data_file': "lr_shap_data.joblib"
    }
}

DEFAULT_MODEL = 'Random Forest (RF)'
# Shared state filled in at training time:
#   'X'          - feature DataFrame the models were trained on
#   'Background' - background rows passed to KernelExplainer
#   'Quartiles'  - {feature: {0.25: q1, 0.5: median, 0.75: q3}} for UI hints
GLOBAL_DATA = {'X': None, 'Background': None, 'Quartiles': {}}
# Number of rows drawn for the multi-sample (global) decision plot.
SAMPLE_SIZE_GLOBAL_DECISION = 100

# --- 2. Matplotlib Configuration ---
# Intentionally no CJK font is configured, to keep plotting stable; minus signs
# are rendered with a plain ASCII hyphen.
plt.rcParams['axes.unicode_minus'] = False


# --- 3. Helper Functions: Plot Saving ---
def save_plot_to_file(plt_fig, prefix):
    """Save a Matplotlib figure as a uniquely named PNG and return its path.

    The file is created in the platform's temporary directory (not a
    hard-coded ``/tmp``) with a collision-free name. The figure is always
    closed afterwards, even if saving fails, so a long-running server does
    not accumulate open figures.

    Args:
        plt_fig: Matplotlib figure to save.
        prefix: Filename prefix identifying the model / plot type.

    Returns:
        str: Path of the written PNG file.
    """
    try:
        fd, filename = tempfile.mkstemp(prefix=f"{prefix}_", suffix=".png")
        os.close(fd)  # savefig reopens the path itself
        plt_fig.savefig(filename, bbox_inches='tight')
    finally:
        plt.close(plt_fig)
    return filename


# --- 4. Data Preprocessing and Quartile Calculation ---
def preprocess_data(df):
    """Convert the target column to float and record feature quartiles.

    A non-numeric target (e.g. "85%") has its '%' symbols stripped and is cast
    to float. The 25/50/75% quantiles of every feature are stored in
    ``GLOBAL_DATA['Quartiles']`` for the Low/Mid/High UI hints.

    Args:
        df: Raw training DataFrame containing FEATURES and TARGET.

    Returns:
        The same DataFrame, with TARGET converted in place when needed.
    """
    # Check "non-numeric" rather than `== 'object'` so pandas' dedicated
    # string dtype is cleaned as well.
    if not pd.api.types.is_numeric_dtype(df[TARGET]):
        print(f"--- Cleaning up '%' symbols in {TARGET} column and converting to float... ---")
        df[TARGET] = df[TARGET].astype(str).str.replace('%', '', regex=False).astype(float)
        print(f"--- {TARGET} column successfully converted to float. ---")
    for feature in FEATURES:
        GLOBAL_DATA['Quartiles'][feature] = df[feature].quantile([0.25, 0.5, 0.75]).to_dict()
    return df


# --- 5. Model Training and SHAP Calculation ---
def _build_background_data(X, n_clusters=10, max_samples=100):
    """Summarise X into a small background set (ndarray) for KernelExplainer."""
    try:
        # shap.kmeans returns a DenseData whose `.data` holds the centroids.
        # The per-cluster weights are not kept, because the background is
        # stored as a plain array.
        background = shap.kmeans(X, n_clusters).data
    except Exception as exc:  # e.g. fewer rows than clusters
        print(f"KMeans summarisation failed ({exc}); falling back to random sampling.")
        background = X.sample(min(max_samples, len(X)), random_state=42).values
    background = np.asarray(background)
    if background.ndim == 1:
        background = background.reshape(-1, 1)
    return background


def _build_model(name):
    """Return a fresh, unfitted estimator for the MODEL_CONFIG key *name* (None if unknown)."""
    if name == 'Random Forest (RF)':
        return RandomForestRegressor(n_estimators=100, random_state=42)
    if name == 'Gaussian Process (GP)':
        kernel = C(1.0, (1e-3, 1e3)) * RBF(10, (1e-2, 1e2))
        return GaussianProcessRegressor(kernel=kernel, random_state=42, normalize_y=True)
    if name == 'XGBoost (XGB)':
        # `use_label_encoder` is a deprecated classifier-only option with no
        # effect on a regressor, so it is not passed.
        return XGBRegressor(n_estimators=100, random_state=42, eval_metric='rmse')
    if name == 'Linear Regression (LR)':
        return LinearRegression()
    return None


def _compute_shap_values(model, explainer_type, X, background_data):
    """Return (shap_values ndarray, expected_value) for *model* over all rows of X."""
    if explainer_type == shap.TreeExplainer:
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(X)
    elif explainer_type == shap.KernelExplainer:
        explainer = shap.KernelExplainer(model.predict, background_data)
        shap_values = explainer.shap_values(X.values)
    elif explainer_type == shap.LinearExplainer:
        explainer = shap.LinearExplainer(model, X)
        shap_values = explainer.shap_values(X)
    else:
        raise ValueError(f"Unsupported explainer type: {explainer_type!r}")
    # Some explainers wrap single-output results in a list.
    if isinstance(shap_values, list) and shap_values and isinstance(shap_values[0], np.ndarray):
        shap_values = shap_values[0]
    return np.asarray(shap_values), explainer.expected_value


def train_and_save_models():
    """Train every model in MODEL_CONFIG, compute its SHAP values and persist both.

    For each model, the fitted estimator is written to ``model_file``, and
    ``{'X', 'shap_values', 'base_value'}`` is written to ``data_file``.
    ``GLOBAL_DATA['X']`` and ``GLOBAL_DATA['Background']`` are populated as
    a side effect.

    Failures are printed and the affected model is skipped. A missing
    DATA_FILE aborts the whole run.
    """
    try:
        df = pd.read_csv(DATA_FILE)
    except FileNotFoundError:
        print(f"Error: Data file {DATA_FILE} not found. Please ensure the file exists.")
        return

    df = preprocess_data(df)
    X = df[FEATURES]
    y = df[TARGET]

    print("Preparing background data for KernelExplainer...")
    background_data = _build_background_data(X)
    GLOBAL_DATA['X'] = X
    GLOBAL_DATA['Background'] = background_data

    for name, config in MODEL_CONFIG.items():
        print(f"\n--- Processing Model: {name} ---")
        model = _build_model(name)
        if model is None:
            continue

        print(f"Training {name}...")
        try:
            model.fit(X, y)
            joblib.dump(model, config['model_file'])
        except Exception as e:
            print(f"Error training {name}: {e}")
            continue

        print(f"Calculating SHAP values for {name}...")
        try:
            shap_values, expected_value = _compute_shap_values(
                model, config['explainer_type'], X, background_data
            )
            joblib.dump({'X': X, 'shap_values': shap_values, 'base_value': expected_value}, config['data_file'])
            print(f"SHAP data for {name} saved.")
        except Exception as e:
            print(f"Error calculating SHAP for {name}: {e}")
            traceback.print_exc()


# --- 6. Helper Functions: Model Loading ---
def get_model_explainer_and_data(model_name):
    """Load a trained model and its stored SHAP data, and build a fresh explainer.

    Args:
        model_name: Key of MODEL_CONFIG.

    Returns:
        tuple: ``(model, explainer, X, shap_values, base_value)`` read from the
        files written by train_and_save_models(). Returns five Nones if the
        name is unknown or if anything fails while loading or while building
        the explainer; the error is printed.
    """
    if model_name not in MODEL_CONFIG:
        return None, None, None, None, None
    config = MODEL_CONFIG[model_name]
    try:
        # joblib.load unpickles; these files are produced locally by training.
        model = joblib.load(config['model_file'])
        shap_data = joblib.load(config['data_file'])
        X = shap_data['X']
        shap_values = shap_data['shap_values']
        base_value = shap_data['base_value']

        explainer_type = config['explainer_type']
        if explainer_type == shap.TreeExplainer:
            explainer = shap.TreeExplainer(model)
        elif explainer_type == shap.KernelExplainer:
            # The background set only exists in memory after training has run
            # in this process.
            if GLOBAL_DATA['Background'] is None:
                raise RuntimeError("Global background data is missing. Rerun training.")
            explainer = shap.KernelExplainer(model.predict, GLOBAL_DATA['Background'])
        elif explainer_type == shap.LinearExplainer:
            explainer = shap.LinearExplainer(model, X)
        else:
            explainer = None
        return model, explainer, X, shap_values, base_value
    except Exception as e:
        print(f"Error loading or creating explainer for {model_name}: {e}")
        return None, None, None, None, None
# --- 7. Local Explanation: Waterfall Plot ---
def get_local_explanation(model_name, MW01, MW02, MW03, flux, MWA):
    """Predict a single sample and render its SHAP waterfall plot.

    Args:
        model_name: MODEL_CONFIG key; DEFAULT_MODEL is used when None.
        MW01, MW02, MW03, flux, MWA: Feature values of the sample.

    Returns:
        tuple: ``(result_text, png_path)``, or ``(error_text, None)`` when the
        model / explainer could not be loaded.
    """
    chosen = DEFAULT_MODEL if model_name is None else model_name
    model, explainer, *_ = get_model_explainer_and_data(chosen)
    if explainer is None:
        return f"Error: Model or Explainer for {chosen} failed to load/create.", None

    row = pd.DataFrame([[MW01, MW02, MW03, flux, MWA]], columns=FEATURES)
    # KernelExplainer is fed a raw ndarray; the other explainers take the DataFrame.
    explainer_input = row.values if isinstance(explainer, shap.KernelExplainer) else row
    contributions = explainer.shap_values(explainer_input)[0]
    predicted = model.predict(row)[0]

    plt.figure(figsize=(8, 6))
    explanation = shap.Explanation(
        values=contributions,
        base_values=explainer.expected_value,
        data=row.iloc[0].values,
        feature_names=FEATURES,
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.plots.waterfall(explanation, show=False)
    plt.title(f"SHAP Waterfall Plot: {chosen} Explanation")
    plt.tight_layout()
    image_path = save_plot_to_file(plt.gcf(), f"{chosen.replace(' ', '_')}_waterfall")

    return f"{chosen} Prediction for {TARGET}: {predicted:.4f}", image_path


# --- 8. Global Explanation: Summary Plot ---
def get_summary_plot(model_name):
    """Render the SHAP summary (beeswarm) plot from the stored SHAP values.

    Args:
        model_name: MODEL_CONFIG key; DEFAULT_MODEL is used when None.

    Returns:
        Path to the saved PNG, or None when no SHAP data could be loaded.
    """
    chosen = DEFAULT_MODEL if model_name is None else model_name
    _, _, features_df, stored_shap, _ = get_model_explainer_and_data(chosen)
    if stored_shap is None:
        return None

    plt.figure(figsize=(10, 8))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.summary_plot(stored_shap, features_df, show=False)
    plt.title(f"SHAP Summary Plot ({chosen}): Global Importance")
    plt.tight_layout()
    return save_plot_to_file(plt.gcf(), f"{chosen.replace(' ', '_')}_summary")
# --- 9. Global Explanation: Dependence Plot ---
def get_dependence_plot(model_name, feature_to_plot, interaction_feature):
    """Render a SHAP dependence plot for one feature from the stored SHAP values.

    Args:
        model_name: MODEL_CONFIG key; DEFAULT_MODEL is used when None.
        feature_to_plot: Feature shown on the x-axis.
        interaction_feature: Feature used to colour the points, or the string
            'None' (from the UI dropdown) to disable interaction colouring.

    Returns:
        Path to the saved PNG, or None when no SHAP data could be loaded.
    """
    if model_name is None:
        model_name = DEFAULT_MODEL
    _, _, X, shap_values, _ = get_model_explainer_and_data(model_name)
    if shap_values is None:
        return None
    if interaction_feature == 'None':
        interaction_feature = None

    # shap.dependence_plot opens a brand-new figure unless it is given an `ax`.
    # A bare plt.figure() here would be orphaned and never closed, leaking one
    # figure per request. Draw into an explicit axes instead.
    fig, ax = plt.subplots(figsize=(8, 6))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.dependence_plot(
            feature_to_plot,
            shap_values,
            X,
            interaction_index=interaction_feature,
            ax=ax,
            show=False,
        )
    ax.set_title(f"SHAP Dependence Plot ({model_name}): Effect of {feature_to_plot}")
    fig.tight_layout()
    return save_plot_to_file(fig, f"{model_name.replace(' ', '_')}_dependence")


# --- 10. Local Explanation: Decision Plot ---
def get_decision_plot(model_name, MW01, MW02, MW03, flux, MWA):
    """Render a SHAP decision plot (cumulative prediction path) for one sample.

    Args:
        model_name: MODEL_CONFIG key; DEFAULT_MODEL is used when None.
        MW01, MW02, MW03, flux, MWA: Feature values of the sample.

    Returns:
        Path to the saved PNG, or None when the explainer could not be created.
    """
    if model_name is None:
        model_name = DEFAULT_MODEL
    _, explainer, _, _, _ = get_model_explainer_and_data(model_name)
    if explainer is None:
        return None

    input_data = pd.DataFrame([[MW01, MW02, MW03, flux, MWA]], columns=FEATURES)
    # KernelExplainer is fed a raw ndarray; the other explainers take the DataFrame.
    if isinstance(explainer, shap.KernelExplainer):
        shap_values = explainer.shap_values(input_data.values)
    else:
        shap_values = explainer.shap_values(input_data)

    plt.figure(figsize=(10, 8))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.decision_plot(
            explainer.expected_value,
            shap_values[0],
            input_data.iloc[0],
            show=False,
            feature_names=FEATURES,
            title=f"SHAP Decision Plot ({model_name}): Prediction Path",
        )
    plt.tight_layout()
    return save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_decision")
# --- 11. Local Explanation: Static Force Plot ---
def get_static_force_plot(model_name, MW01, MW02, MW03, flux, MWA):
    """Render a static (Matplotlib) SHAP force plot for one sample.

    Args:
        model_name: MODEL_CONFIG key; DEFAULT_MODEL is used when None.
        MW01, MW02, MW03, flux, MWA: Feature values of the sample.

    Returns:
        Path to the saved PNG, or None when the explainer could not be created.
    """
    if model_name is None:
        model_name = DEFAULT_MODEL
    _, explainer, _, _, _ = get_model_explainer_and_data(model_name)
    if explainer is None:
        return None

    input_data = pd.DataFrame([[MW01, MW02, MW03, flux, MWA]], columns=FEATURES)
    if isinstance(explainer, shap.KernelExplainer):
        shap_values = explainer.shap_values(input_data.values)
    else:
        shap_values = explainer.shap_values(input_data)

    # With matplotlib=True, shap.force_plot creates its own figure. Opening one
    # beforehand with plt.figure() would leave that figure orphaned and never
    # closed, so none is created here and the force-plot figure (current
    # afterwards) is the one saved and closed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.force_plot(explainer.expected_value, shap_values[0], input_data.iloc[0], show=False, matplotlib=True)
    plt.title(f"SHAP Force Plot ({model_name}) - Static Version")
    plt.tight_layout()
    return save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_force")


# --- 12. Global Decision Plot (Multi-sample) ---
def get_global_decision_plot(model_name):
    """Render a SHAP decision plot overlaying the paths of a sample of the dataset.

    Stored SHAP values are reused when they line up with X. Otherwise SHAP
    values are recomputed for the sampled rows only.

    Args:
        model_name: MODEL_CONFIG key; DEFAULT_MODEL is used when None.

    Returns:
        Path to the saved PNG, or None when the model data could not be loaded.
    """
    if model_name is None:
        model_name = DEFAULT_MODEL
    _, explainer, X, shap_values_full, base_value = get_model_explainer_and_data(model_name)
    if explainer is None or X is None:
        return None

    n_samples = min(SAMPLE_SIZE_GLOBAL_DECISION, len(X))
    # Sample row *positions* so they can index the stored (positional) SHAP
    # array directly. DataFrame.sample draws positions from the row count, so
    # for a fixed random_state these are the same rows as X.sample(...).
    positions = X.reset_index(drop=True).sample(n_samples, random_state=42).index.to_numpy()
    X_sample = X.iloc[positions]

    stored = None if shap_values_full is None else np.asarray(shap_values_full)
    if stored is not None and stored.ndim == 2 and stored.shape[0] == len(X):
        shap_values_sample = stored[positions]
    elif isinstance(explainer, shap.KernelExplainer):
        shap_values_sample = explainer.shap_values(X_sample.values)
    else:
        shap_values_sample = explainer.shap_values(X_sample)

    plt.figure(figsize=(10, 10))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.decision_plot(
            base_value,
            shap_values_sample,
            X_sample,
            show=False,
            feature_names=FEATURES,
            title=f"SHAP Global Decision Plot ({model_name}): {n_samples} Samples",
        )
    plt.tight_layout()
    return save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_global_decision")


# --- 13. Local Analysis Combination for Optimal Finder ---
def get_feature_label(feature, value):
    """Return ``"feature (value) [band]"``, where band places value among the training quartiles.

    Bands: below Q1 -> "[Below Q1 (Low)]", up to the median -> "[Near Median]",
    up to Q3 -> "[Near Q3]", otherwise "[Above Q3 (High)]". Without recorded
    quartiles, only ``"feature (value)"`` is returned.
    """
    if feature not in GLOBAL_DATA['Quartiles']:
        return f"{feature} ({value})"
    q = GLOBAL_DATA['Quartiles'][feature]
    q25, q50, q75 = q[0.25], q[0.5], q[0.75]
    if value < q25:
        status = "[Below Q1 (Low)]"
    elif value <= q50:
        status = "[Near Median]"
    elif value <= q75:
        status = "[Near Q3]"
    else:
        status = "[Above Q3 (High)]"
    return f"{feature} ({value}) {status}"


def get_local_plots_combined(model_name, MW01, MW02, MW03, flux, MWA):
    """Predict one sample and produce its waterfall and decision plots together.

    Feature labels in both plots carry the quartile band from get_feature_label().

    Returns:
        tuple: ``(result_text, waterfall_path, decision_path)``, or
        ``(error_text, None, None)`` when the explainer could not be created.
    """
    if model_name is None:
        model_name = DEFAULT_MODEL
    model, explainer, _, _, _ = get_model_explainer_and_data(model_name)
    if explainer is None:
        error_msg = f"Error: Model or Explainer for {model_name} failed to load/create."
        return error_msg, None, None

    input_values = [MW01, MW02, MW03, flux, MWA]
    input_data = pd.DataFrame([input_values], columns=FEATURES)
    if isinstance(explainer, shap.KernelExplainer):
        shap_values = explainer.shap_values(input_data.values)
    else:
        shap_values = explainer.shap_values(input_data)
    prediction = model.predict(input_data)[0]
    base_value = explainer.expected_value
    custom_feature_names = [get_feature_label(f, v) for f, v in zip(FEATURES, input_values)]

    # 3. Generate Waterfall Plot
    plt.figure(figsize=(8, 6))
    shap_object = shap.Explanation(
        values=shap_values[0],
        base_values=base_value,
        data=input_data.iloc[0].values,
        feature_names=custom_feature_names,
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.plots.waterfall(shap_object, show=False)
    plt.title(f"1. Waterfall Plot: {model_name} Prediction Breakdown")
    plt.tight_layout()
    waterfall_path = save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_combined_waterfall")

    # 4. Generate Decision Plot
    plt.figure(figsize=(8, 6))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap.decision_plot(
            base_value,
            shap_values[0],
            input_data.iloc[0],
            show=False,
            feature_names=custom_feature_names,
            title=f"2. Decision Plot: {model_name} Prediction Path",
        )
    plt.tight_layout()
    decision_path = save_plot_to_file(plt.gcf(), f"{model_name.replace(' ', '_')}_combined_decision")

    result_text = f"Predicted {TARGET}: {prediction:.4f}"
    return result_text, waterfall_path, decision_path


# --- 14. Execute Model Training ---
train_and_save_models()
print("All models trained and SHAP data prepared.")

# --- 15. Create Gradio Interface (Explanation Frames in Chinese) ---
model_choices = list(MODEL_CONFIG.keys())

# --- D. Optimal Prediction Finder (Integrated Analysis) ---
with gr.Blocks() as optimal_block:
    # --- Explanation Frame (Chinese) ---
    gr.Markdown(
        """
        ## 最佳預測尋找器(整合分析與優化工作流程)

        **目的:** 此分頁引導您透過迭代的工作流程,結合「全域」和「局部」SHAP 分析,來尋找最佳的特徵設定,以最大化目標變數(`yield`,良率)。

        ### 功能函數 (Functionality):

        1. **全域情境 (`get_summary_plot`):** 先執行此步驟,以識別哪些特徵最重要,以及哪些數值(高/低)會將預測值推高(紅點)或推低(藍點)。
        2. **迭代輸入:** 根據全域情境調整輸入參數。圖表中的特徵標籤會動態顯示您輸入值在數據集中的百分位數範圍(例如:`[Below Q1 (Low)]`, `[Near Median]` 等),以協助微調。
        3. **局部驗證 (`get_local_plots_combined`):** 同時生成預測值和局部解釋圖,以驗證您的新設定是否正在正面地推動預測值。

        ### 輸出解讀 (Interpretation of Output):

        * **瀑布圖 (Waterfall Plot):** 精確地顯示每個特徵如何對最終預測值做出貢獻,將其從 `Base Value`(平均預測值)推向 `Predicted Yield`。為了優化,目標是讓重要變數產生**正面**(紅色)的貢獻。
        * **決策圖 (Decision Plot):** 將每個特徵的累積影響視覺化。路徑向右移動表示預測值更高。
        """
    )
    gr.Markdown("---")
    gr.Markdown("### Step 1-2: Select Model and Get Global Context")
    with gr.Row():
        optimal_model_selector = gr.Radio(model_choices, value=DEFAULT_MODEL, label="1. Select Optimal Model (Highest Accuracy)")
        summary_btn_opt = gr.Button("2. Generate Global Summary Plot (Reference: Find key variable ideal trends)")
    summary_image_opt = gr.Image(label="Global Feature Importance Context (Summary Plot)", type="filepath", interactive=False, height=250)
    summary_btn_opt.click(fn=get_summary_plot, inputs=[optimal_model_selector], outputs=summary_image_opt)

    gr.Markdown("---")
    gr.Markdown("### Step 3-4: Iterate Input, Predict, and Locally Validate")
    gr.Markdown("Based on the ideal trend shown in the Summary Plot (blue/red dot distribution), **fine-tune the variable values**. The plot labels will now hint at your input value's position within the dataset (Low/Mid/High).")
    with gr.Row():
        optimal_input_MW01 = gr.Number(label="MW01", minimum=0, maximum=1000, value=800)
        optimal_input_MW02 = gr.Number(label="MW02", minimum=0, maximum=1000, value=800)
        optimal_input_MW03 = gr.Number(label="MW03", minimum=0, maximum=1000, value=300)
        optimal_input_flux = gr.Slider(label="flux", minimum=0.5, maximum=2, step=0.5, value=2)
        optimal_input_MWA = gr.Slider(label="MWA", minimum=5, maximum=20, step=5, value=5)
    optimal_inputs_list = [optimal_model_selector, optimal_input_MW01, optimal_input_MW02, optimal_input_MW03, optimal_input_flux, optimal_input_MWA]

    optimize_btn = gr.Button("4. Run Prediction & Generate Local Explanation (Validate/Fine-tune)")
    result_text_opt = gr.Textbox(label="Final Prediction Result")
    with gr.Row():
        waterfall_image_opt = gr.Image(label="Validation Plot 1: Waterfall Plot (Check if variable contributions are positive as expected)", type="filepath", interactive=False, height=350)
        decision_image_opt = gr.Image(label="Validation Plot 2: Decision Plot (Check prediction path and baseline offset)", type="filepath", interactive=False, height=350)
    optimize_btn.click(
        fn=get_local_plots_combined,
        inputs=optimal_inputs_list,
        outputs=[result_text_opt, waterfall_image_opt, decision_image_opt],
    )
# --- A. Local Explanation (Multi-Plot) ---
with gr.Blocks() as local_block:
    # Explanation frame (Chinese text is shown to users as-is).
    gr.Markdown(
        """
        ## 局部解釋(多圖表)

        **目的:** 此分頁作為測試區,用於生成各種局部(單一實例)SHAP 圖,以理解特定預測的生成原因。它允許單獨控制每種圖表的生成。

        ### 功能函數 (Functionality):

        * **預測與瀑布圖 (`get_local_explanation`):** 預測目標值並生成顯示特徵貢獻的瀑布圖。
        * **決策圖 (`get_decision_plot`):** 顯示從基線(Base Value)開始的預測路徑。
        * **力圖 (`get_static_force_plot`):** 提供正面/負面貢獻的高層次可視化。

        ### 輸出解讀 (Interpretation of Output):

        * **瀑布圖:** 將輸出推高的特徵顯示為**紅色**。推低的顯示為**藍色**。所有貢獻的總和等於預測值與基線值之間的差異。
        * **決策圖:** 水平路徑(X 軸)代表預測值。較粗的線顯示特徵值如何將預測值從平均值(基線)轉移到最終預測值。
        * **力圖:** 紅色特徵將預測值推**高**(向右);藍色特徵將其推**低**(向左)。
        """
    )
    gr.Markdown("---")

    local_model_selector = gr.Radio(model_choices, value=DEFAULT_MODEL, label="Select Model for SHAP Analysis")
    with gr.Row():
        local_input_MW01 = gr.Number(label="MW01", minimum=0, maximum=1000, value=800)
        local_input_MW02 = gr.Number(label="MW02", minimum=0, maximum=1000, value=800)
        local_input_MW03 = gr.Number(label="MW03", minimum=0, maximum=1000, value=300)
        local_input_flux = gr.Slider(label="flux", minimum=0.5, maximum=2, step=0.5, value=2)
        local_input_MWA = gr.Slider(label="MWA", minimum=5, maximum=20, step=5, value=5)
    # Every local-plot handler takes (model_name, MW01, MW02, MW03, flux, MWA).
    local_inputs_list = [
        local_model_selector,
        local_input_MW01,
        local_input_MW02,
        local_input_MW03,
        local_input_flux,
        local_input_MWA,
    ]

    result_text = gr.Textbox(label="Model Prediction Result")
    with gr.Row():
        waterfall_btn = gr.Button("1. Generate Waterfall Plot")
        decision_btn = gr.Button("2. Generate Decision Plot")
        force_btn = gr.Button("3. Generate Static Force Plot")
    gr.Markdown("---")

    # One image per row, stacked vertically in button order.
    with gr.Row():
        waterfall_image = gr.Image(label="1. SHAP Waterfall Plot", type="filepath", interactive=False, height=300)
    with gr.Row():
        decision_image = gr.Image(label="2. SHAP Decision Plot", type="filepath", interactive=False, height=400)
    with gr.Row():
        force_static_image = gr.Image(label="3. SHAP Force Plot (Static)", type="filepath", interactive=False, height=200)

    # Button -> (handler, output component(s)); all share local_inputs_list.
    for _button, _handler, _targets in (
        (waterfall_btn, get_local_explanation, [result_text, waterfall_image]),
        (decision_btn, get_decision_plot, decision_image),
        (force_btn, get_static_force_plot, force_static_image),
    ):
        _button.click(fn=_handler, inputs=local_inputs_list, outputs=_targets)

# --- B. Global Analysis (Summary + Decision Plot) ---
with gr.Blocks() as global_summary_block:
    # Explanation frame (Chinese text is shown to users as-is).
    gr.Markdown(
        """
        ## 全域分析:特徵重要性與決策路徑

        **目的:** 此分頁提供特徵重要性的全域概覽,以及特徵如何影響整個數據集的目標變數,有助於確立特徵的優先順序和理想的數值範圍。

        ### 功能函數 (Functionality):

        * **摘要圖 (`get_summary_plot`):** 計算數據集中所有實例的 SHAP 數值平均幅度。
        * **全域決策圖 (`get_global_decision_plot`):** 顯示大量樣本數據的預測路徑,揭示整體趨勢和異常值。

        ### 輸出解讀 (Interpretation of Output):

        * **摘要圖:**
            * **Y 軸:** 特徵依全域重要性排序(頂部最重要)。
            * **X 軸:** SHAP 數值(對模型輸出的影響)。
            * **顏色(紅/藍):** 特徵值(紅色 = 高,藍色 = 低)。
            * **關鍵見解:** 尋找高數值(紅色)一致地導致高正向 SHAP 數值的特徵,反之亦然,以確定最大化的理想數值範圍。
        * **全域決策圖:** 這是許多預測路徑的集合。水平寬度表示變異性,路徑聚集的位置顯示了常見的預測值。
        """
    )
    gr.Markdown("---")

    # A single model selector drives both global plots in this tab.
    global_model_selector_sum = gr.Radio(model_choices, value=DEFAULT_MODEL, label="Select Model for Global Analysis")

    gr.Markdown("### 1. Feature Importance Summary")
    summary_btn = gr.Button("Generate Summary Plot")
    summary_image = gr.Image(label="SHAP Summary Plot: Global Importance", type="filepath", interactive=False, height=500)
    summary_btn.click(fn=get_summary_plot, inputs=[global_model_selector_sum], outputs=summary_image)

    gr.Markdown("### 2. Global Decision Path (Sampled)")
    global_decision_btn = gr.Button(f"Generate Global Decision Plot (Sampling {SAMPLE_SIZE_GLOBAL_DECISION} points)")
    global_decision_image = gr.Image(label="SHAP Decision Plot: Paths of Sampled Data", type="filepath", interactive=False, height=600)
    global_decision_btn.click(fn=get_global_decision_plot, inputs=[global_model_selector_sum], outputs=global_decision_image)
# --- C. Feature Dependence Plot Interface ---
with gr.Blocks() as dependence_interface:
    # Explanation frame (Chinese text is shown to users as-is).
    gr.Markdown(
        """
        ## 特徵依賴圖

        **目的:** 此圖表有助於分析單一特徵對預測的邊際效應,以及此效應如何被第二個「互動」特徵所修改。

        ### 功能函數 (Functionality):

        * **依賴圖 (`get_dependence_plot`):** 繪製特徵的實際數值(X 軸)與其 SHAP 數值(Y 軸)之間的關係,並可選用一個顏色特徵來顯示互動效應。

        ### 輸出解讀 (Interpretation of Output):

        * **主要特徵(X 軸):** 顯示**趨勢**。如果隨著 X 軸數值的增加,點向上移動,則該特徵對輸出有正向的邊際影響。
        * **顏色(互動特徵):** 如果著色特徵呈現出明顯的模式(例如,所有高顏色點聚集在更高或更低的位置),則表示這兩個特徵之間存在強烈的互動關係。
        """
    )
    gr.Markdown("---")

    # Components are built in the same order as when passed inline to gr.Interface.
    dependence_inputs = [
        gr.Radio(model_choices, value=DEFAULT_MODEL, label="Select Model for SHAP Analysis"),
        gr.Dropdown(label="Select Feature for X-axis", choices=FEATURES, value=FEATURES[0]),
        # The literal string 'None' is mapped to "no interaction colouring" by get_dependence_plot.
        gr.Dropdown(label="Select Interaction Feature (Color)", choices=['None'] + FEATURES, value='None'),
    ]
    dependence_output = gr.Image(label="SHAP Dependence Plot: Feature Dependence", type="filepath", interactive=False, height=350)
    gr.Interface(
        fn=get_dependence_plot,
        inputs=dependence_inputs,
        outputs=dependence_output,
        flagging_mode="never",
        title="Feature Dependence on SHAP Value",
    )

# Combine all interfaces: (tab content, tab title) in display order.
_TABS = (
    (optimal_block, "Optimal Prediction Finder (Integrated)"),
    (local_block, "Local Explanation (Multi-Plot)"),
    (global_summary_block, "Global Analysis (Summary & Decision)"),
    (dependence_interface, "Feature Dependence Plot"),
)
full_interface = gr.TabbedInterface(
    [tab for tab, _ in _TABS],
    [title for _, title in _TABS],
)

if __name__ == "__main__":
    full_interface.launch()