Spaces:
Sleeping
Sleeping
| # app.py (Final Corrected Version for Parallel Processing) | |
| # ======================== 步骤 1: 导入所有库 ======================== | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| from scipy import stats | |
| import warnings | |
| import os | |
| import logging | |
| from datetime import timedelta | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from matplotlib.ticker import FuncFormatter | |
| import shutil | |
| import zipfile | |
| # 时间序列核心库 | |
| from statsmodels.tsa.stattools import adfuller | |
| from statsmodels.stats.diagnostic import acorr_ljungbox | |
| from statsmodels.tsa.seasonal import STL | |
| import pmdarima as pm | |
| from prophet import Prophet | |
| # 机器学习和评估 | |
| from sklearn.metrics import mean_absolute_error, mean_squared_error | |
| from joblib import Parallel, delayed | |
# --- Global runtime configuration ---
# Silence library warnings so they do not clutter the analysis log.
warnings.filterwarnings("ignore")
# Prophet and its Stan backend (cmdstanpy) log verbosely; only surface errors.
for _noisy_logger in ('prophet', 'cmdstanpy'):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
del _noisy_logger
# Ensure negative signs render correctly on plot tick labels.
plt.rcParams.update({'axes.unicode_minus': False})
# --- Output folder ---
# All figures and the text report are written here; the folder is later zipped for download.
OUTPUT_DIR = 'analysis_output'
# ======================== Step 2: helper functions ========================
def calculate_metrics(actual, predicted):
    """Compute MAE, RMSE, MAPE and sMAPE between two forecast sequences.

    The inputs are paired **by position**, not by index label. Callers pass a
    date-indexed ``pd.Series`` of actual values together with model output
    whose index may be a plain integer range (e.g. the tail of Prophet's
    ``yhat`` column); aligning such objects by label pairs nothing and would
    silently turn every metric into NaN.

    Args:
        actual: Observed values (any 1-D array-like).
        predicted: Forecast values, same length as ``actual``.

    Returns:
        dict with keys ``'MAE'``, ``'RMSE'``, ``'MAPE'`` (percent) and
        ``'sMAPE'`` (percent, 0-200 scale). Every value is NaN when no pair of
        finite values remains.

    Raises:
        ValueError: If ``actual`` and ``predicted`` differ in length.
    """
    actual_arr = np.asarray(actual, dtype=float).ravel()
    predicted_arr = np.asarray(predicted, dtype=float).ravel()
    if actual_arr.shape != predicted_arr.shape:
        raise ValueError(
            f"actual and predicted must have the same length "
            f"({actual_arr.size} != {predicted_arr.size})"
        )

    # Drop any pair where either side is NaN/inf so one bad point cannot poison the averages.
    valid = np.isfinite(actual_arr) & np.isfinite(predicted_arr)
    if not valid.any():
        return {'MAE': np.nan, 'RMSE': np.nan, 'MAPE': np.nan, 'sMAPE': np.nan}

    clean_actual = actual_arr[valid]
    clean_predicted = predicted_arr[valid]
    abs_err = np.abs(clean_actual - clean_predicted)

    mae = float(np.mean(abs_err))
    rmse = float(np.sqrt(np.mean(abs_err ** 2)))
    # Zero actuals are replaced by a tiny epsilon to avoid division by zero (this can inflate MAPE).
    actual_safe = np.where(clean_actual == 0, 1e-6, clean_actual)
    mape = float(np.mean(abs_err / np.abs(actual_safe)) * 100)
    # Symmetric MAPE; an all-zero pair would give 0/0, so its denominator is guarded the same way.
    denominator = np.abs(clean_actual) + np.abs(clean_predicted)
    denominator_safe = np.where(denominator == 0, 1e-6, denominator)
    smape = float(200 * np.mean(abs_err / denominator_safe))
    return {'MAE': mae, 'RMSE': rmse, 'MAPE': mape, 'sMAPE': smape}
def dynamic_split(data, current_date, window_size, expanding=False):
    """Split ``data`` into a training slice ending at ``current_date`` and a 4-week test slice.

    Args:
        data: DataFrame or Series indexed by dates.
        current_date: Last date (inclusive) of the training slice.
        window_size: Training length in days when ``expanding`` is False.
        expanding: If True, training starts at the earliest date in ``data``.

    Returns:
        ``(train_data, test_data)``: rows in ``[train_start, current_date]`` and rows in
        ``[current_date + 1 day, current_date + 4 weeks]``.
    """
    idx = data.index
    first_train_day = idx.min() if expanding else current_date - timedelta(days=window_size - 1)
    in_train = (idx >= first_train_day) & (idx <= current_date)

    first_test_day = current_date + timedelta(days=1)
    last_test_day = current_date + timedelta(weeks=4)
    in_test = (idx >= first_test_day) & (idx <= last_test_day)

    return data[in_train], data[in_test]
def create_zip_archive(output_dir=None):
    """Pack every file under ``output_dir`` into ``<output_dir>.zip``.

    Args:
        output_dir: Folder to archive. Defaults to the module-level ``OUTPUT_DIR``
            so existing zero-argument callers keep working.

    Returns:
        Path of the written zip file. It is a sibling of the folder (not inside
        it), so the archive never tries to include itself.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR
    # Strip trailing separators so "out/" yields "out.zip" rather than "out/.zip" inside the folder.
    output_dir = os.path.normpath(output_dir)
    zip_path = f"{output_dir}.zip"
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(output_dir):
            dirs.sort()  # in-place sort fixes traversal order -> reproducible archive listing
            for file in sorted(files):
                full_path = os.path.join(root, file)
                # Store paths relative to the folder so the zip unpacks without the top-level dir.
                zipf.write(full_path, os.path.relpath(full_path, output_dir))
    return zip_path
| # **修改点 1**: 改变函数签名,直接接收3个参数 | |
| def _parallel_evaluate_window(window_size, ts_values, d_order): | |
| """一个独立的、可被 joblib 并行调用的函数。""" | |
| n = len(ts_values) | |
| if window_size >= n: | |
| return np.inf | |
| errors = [] | |
| eval_len = min(n, 365) | |
| for i in range(eval_len - window_size): | |
| train = ts_values[i:(i + window_size)] | |
| test = ts_values[i + window_size] | |
| try: | |
| model = pm.auto_arima(train, d=d_order, seasonal=False, stepwise=True, suppress_warnings=True, error_action='ignore') | |
| errors.append(test - model.predict(n_periods=1)[0]) | |
| except Exception: | |
| continue | |
| return np.mean(np.abs(errors)) if errors else np.inf | |
# ======================== Step 3: core analysis pipeline ========================
def run_analysis(uploaded_file, progress=gr.Progress(track_tqdm=True)):
    """Run the full forecasting pipeline on an uploaded XLSX file, streaming UI updates.

    This is a generator used as a Gradio event handler: every ``yield`` is a dict
    mapping output components to ``gr.update`` values, so the log and the plots
    fill in progressively. Steps: data cleaning, ADF stationarity test, Ljung-Box
    white-noise test, seasonality analysis (Canova-Hansen + STL), parallel
    window-size search, rolling-origin evaluation of SARIMA / tuned Prophet /
    inverse-MAE weighted average, comparison plots, and a zipped report.

    Args:
        uploaded_file: Value of the ``gr.File`` input; either a filepath string or
            a tempfile-like object exposing ``.name`` (both are accepted).
        progress: Progress tracker injected by Gradio.

    Raises:
        gr.Error: When the file cannot be read or the data are too short for a
            step; raised so Gradio can show the message to the user.
    """
    # --- Initialisation: start from an empty output folder ---
    progress(0, desc="开始分析流程...")
    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)
    os.makedirs(OUTPUT_DIR)

    def update_ui(status, stl_plot=None, window_plot=None, weight_plot=None, error_comp_plot=None,
                  radar_plot=None, box_plot=None, final_plot=None, summary_df=None, download_file=None):
        # One update for every output component; plots not passed are cleared (None).
        return {
            status_display: gr.update(value=status),
            plot_stl: gr.update(value=stl_plot),
            plot_window: gr.update(value=window_plot),
            plot_weights: gr.update(value=weight_plot),
            plot_error_comp: gr.update(value=error_comp_plot),
            plot_radar: gr.update(value=radar_plot),
            plot_boxplot: gr.update(value=box_plot),
            plot_final_forecast: gr.update(value=final_plot),
            summary_table: gr.update(value=summary_df),
            download_results: gr.update(value=download_file, visible=download_file is not None),
        }

    yield update_ui("分析已开始,正在准备环境...")

    # --- 1. Data cleaning ---
    status_message = "### 步骤 1: 数据清洗\n"
    # Newer Gradio passes a filepath string; older versions pass a tempfile wrapper with ``.name``.
    file_path = getattr(uploaded_file, 'name', uploaded_file)
    try:
        if file_path is None:
            raise ValueError("no file uploaded")
        progress(0.05, desc="正在读取和清洗数据...")
        yield update_ui(status_message + "▶️ 正在读取上传的 Excel 文件...")
        df = pd.read_excel(file_path)
        if 'Value' not in df.columns:
            raise KeyError('Value')
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.drop_duplicates(subset=['Date']).sort_values('Date').reset_index(drop=True)
        status_message += f"✔️ 数据读取成功,共 {len(df)} 行,时间范围: {df['Date'].min().date()} 到 {df['Date'].max().date()}.\n"
        yield update_ui(status_message)
    except Exception as e:
        # gr.Error only reaches the user when raised; merely constructing it does nothing.
        raise gr.Error(f"文件读取失败: {e}. 请确保上传的 XLSX 文件包含 'Date' 和 'Value' 列。") from e

    # Zeros are treated as missing and filled by linear interpolation.
    df['Value'] = df['Value'].replace(0, np.nan).interpolate(method='linear', limit_direction='both')
    status_message += "✔️ 零值替换与线性插值完成。\n"
    z_scores = np.abs(stats.zscore(df['Value']))
    outliers = df[z_scores > 3]
    status_message += f"✔️ Z-score 异常值检测完成,发现 {len(outliers)} 个潜在异常值。\n"
    yield update_ui(status_message)
    df.set_index('Date', inplace=True)
    # Regularise to a daily calendar; days missing from the file are interpolated.
    ts_data = df['Value'].asfreq('D').interpolate(method='linear')

    # --- 2. Stationarity test and differencing ---
    status_message += "\n### 步骤 2: 平稳性检验\n"
    progress(0.1, desc="进行平稳性检验...")

    def make_stationary(data):
        # Difference up to 3 times until the ADF test rejects a unit root at the 5% level.
        current_data, log = data.copy(), ""
        for d in range(4):
            p_value = adfuller(current_data.dropna())[1]
            log += f"差分阶数 d={d}, ADF检验 p-value: {p_value:.4f}\n"
            if p_value < 0.05:
                log += f"✔️ 序列在 d={d} 阶差分后达到平稳."
                return current_data.dropna(), d, log
            if d < 3:
                # Do not difference after the last test, so the returned series matches d=3.
                current_data = current_data.diff()
        log += "⚠️ 警告:在最大差分阶数内未达到平稳."
        return current_data.dropna(), 3, log

    ts_stationary, d_order, stationarity_log = make_stationary(ts_data)
    status_message += stationarity_log
    yield update_ui(status_message)

    # --- 3. White-noise (Ljung-Box) test ---
    status_message += "\n\n### 步骤 3: 白噪声检验\n"
    progress(0.15, desc="进行白噪声检验...")
    # Ljung-Box needs at least one lag; very short series would otherwise give lags=0.
    lags = max(1, min(10, len(ts_stationary) // 5))
    lb_p_value = acorr_ljungbox(ts_stationary, lags=[lags], return_df=True)['lb_pvalue'].iloc[0]
    if lb_p_value <= 0.05:
        status_message += f"✔️ 通过白噪声检验 (p-value = {lb_p_value:.4f}),序列为非白噪声。"
    else:
        status_message += f"⚠️ 序列为白噪声 (p-value = {lb_p_value:.4f}),可能不适合复杂建模。"
    yield update_ui(status_message)

    # --- 4. Seasonality analysis ---
    status_message += "\n\n### 步骤 4: 季节性分析 (初步分析)\n"
    progress(0.2, desc="进行季节性分解...")
    D_weekly = pm.arima.nsdiffs(ts_data, m=7, test='ch')
    status_message += f"✔️ 每周季节性差分阶数 D(m=7): {D_weekly} (通过 Canova-Hansen 测试确定)。\n"
    # DecomposeResult.plot() creates its own figure: size and title that figure directly
    # instead of opening an extra empty figure that would never be closed.
    stl_year = STL(ts_data, period=365).fit()
    stl_fig_for_ui = stl_year.plot()
    stl_fig_for_ui.set_size_inches(14, 10)
    stl_fig_for_ui.suptitle('STL Decomposition (Annual Seasonality)', y=1.02)
    stl_fig_for_ui.savefig(os.path.join(OUTPUT_DIR, '4_stl_decomposition.png'))
    plt.close(stl_fig_for_ui)
    status_message += "✔️ STL分解图已生成."
    yield update_ui(status_message, stl_plot=stl_fig_for_ui)

    # --- 5. Window-size optimisation (parallel) ---
    status_message += "\n\n### 步骤 5: 窗口大小优化\n"
    status_message += "▶️ **正在启动多进程并行计算**以评估不同窗口大小的MAE。\n▶️ 下方的进度条将显示总体进度。计算完成后会显示详细结果。\n"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui)
    ts_values = ts_data.values
    window_sizes = list(range(70, 211, 14))
    results_maes = Parallel(n_jobs=-1)(
        delayed(_parallel_evaluate_window)(ws, ts_values, d_order)
        for ws in progress.tqdm(window_sizes, desc="优化窗口大小")
    )
    results_df = pd.DataFrame({'window_size': window_sizes, 'mae': results_maes})
    status_message += "---\n✔️ 并行计算完成,结果如下:\n"
    for _, row in results_df.iterrows():
        status_message += f"窗口大小: {int(row['window_size']):<4} -> MAE: {row['mae']:.4f}\n"
    status_message += "---\n"
    # _parallel_evaluate_window signals failure with np.inf (not NaN), so test for finiteness.
    finite_maes = results_df['mae'].replace([np.inf, -np.inf], np.nan)
    if finite_maes.isna().all():
        raise gr.Error("窗口优化失败,所有窗口大小均未产出有效MAE。")
    best_mae_window = int(results_df.loc[finite_maes.idxmin(), 'window_size'])

    window_fig_for_ui = plt.figure(figsize=(10, 6))
    # Plot only finite MAEs; failed windows (inf) would break the axis scaling.
    sns.lineplot(data=results_df.assign(mae=finite_maes), x='window_size', y='mae', marker='o')
    plt.title('Effect of Window Size on Prediction Error (MAE)')
    plt.xlabel('Training Window Days')
    plt.ylabel('Mean Absolute Error (MAE)')
    plt.grid(True)
    window_fig_for_ui.savefig(os.path.join(OUTPUT_DIR, '5_window_size_optimization.png'))
    plt.close(window_fig_for_ui)
    status_message += f"✔️ 窗口大小优化完成。最优滑动窗口大小: **{best_mae_window}** 天。"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)

    # --- 6. Model definitions ---
    status_message += "\n\n### 步骤 6: 定义所有模型\n✔️ SARIMA, Prophet, 加权平均模型已定义。"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)

    def _as_forecast(values, train_data, h):
        # Label a forecast with the h calendar days following the training data. Raw model
        # output may carry an integer index (e.g. Prophet's yhat tail), which would make
        # label-aligned arithmetic and comparisons against date-indexed actuals yield NaN.
        future_index = pd.date_range(train_data.index.max() + timedelta(days=1), periods=h, freq='D')
        return pd.Series(np.asarray(values, dtype=float), index=future_index)

    def sarima_model(train_data, h=28):
        # Seasonal ARIMA with weekly period; d and D come from the tests above.
        model = pm.auto_arima(train_data['Value'], d=d_order, D=D_weekly, m=7, seasonal=True, stepwise=True,
                              suppress_warnings=True, error_action='ignore', max_p=3, max_q=3)
        return _as_forecast(model.predict(n_periods=h), train_data, h)

    def prophet_model_default(train_data, h=28):
        df_prophet = train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'})
        model = Prophet(yearly_seasonality=True, weekly_seasonality=True).fit(df_prophet)
        future = model.make_future_dataframe(periods=h, freq='D')
        return _as_forecast(model.predict(future)['yhat'].tail(h), train_data, h)

    def prophet_model_tuned(train_data, h=28):
        # Grid-search two prior scales on a 28-day hold-out, then refit on all of train_data.
        df_prophet = train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'})
        param_grid = {'changepoint_prior_scale': [0.01, 0.1, 0.5], 'seasonality_prior_scale': [1.0, 5.0, 10.0]}
        split_point = len(df_prophet) - 28
        if split_point < 10:
            return prophet_model_default(train_data, h)
        train_sub, val_sub = df_prophet.iloc[:split_point], df_prophet.iloc[split_point:]
        best_mae, best_params = float('inf'), {}
        for cps in param_grid['changepoint_prior_scale']:
            for sps in param_grid['seasonality_prior_scale']:
                m = Prophet(changepoint_prior_scale=cps, seasonality_prior_scale=sps).fit(train_sub)
                fc_val = m.predict(m.make_future_dataframe(periods=len(val_sub))).tail(len(val_sub))
                mae = mean_absolute_error(val_sub['y'], fc_val['yhat'])
                if mae < best_mae:
                    best_mae, best_params = mae, {'changepoint_prior_scale': cps, 'seasonality_prior_scale': sps}
        final_model = Prophet(**best_params).fit(df_prophet)
        forecast = final_model.predict(final_model.make_future_dataframe(periods=h))['yhat'].tail(h)
        return _as_forecast(forecast, train_data, h)

    def weighted_average_model(sarima_pred, prophet_pred, train_data):
        # Weights are inverse validation MAEs, obtained by refitting both models on train_data
        # minus its last 4 weeks; falls back to an equal 50/50 blend when that is not possible.
        val_end_date = train_data.index.max()
        val_start_date = val_end_date - timedelta(weeks=4) + timedelta(days=1)
        validation_data = train_data[train_data.index >= val_start_date]
        if len(validation_data) < 28:
            return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        train_for_val = train_data[train_data.index < val_start_date]
        if len(train_for_val) < 20:
            return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        sarima_val_pred = sarima_model(train_for_val, h=28)
        prophet_val_pred = prophet_model_tuned(train_for_val, h=28)
        sarima_mae = mean_absolute_error(validation_data['Value'], sarima_val_pred)
        prophet_mae = mean_absolute_error(validation_data['Value'], prophet_val_pred)
        if sarima_mae + prophet_mae == 0:
            return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        # Cap the inverse error of a (near-)perfect model instead of dividing by zero.
        inv_err_s = 1 / sarima_mae if sarima_mae > 1e-9 else 1e9
        inv_err_p = 1 / prophet_mae if prophet_mae > 1e-9 else 1e9
        total_inv_err = inv_err_s + inv_err_p
        w_s, w_p = inv_err_s / total_inv_err, inv_err_p / total_inv_err
        return w_s * sarima_pred + w_p * prophet_pred, w_s, w_p

    # --- 7. Rolling-origin evaluation ---
    status_message += "\n\n### 步骤 7: 模型性能评估 (滚动预测)\n"
    start_date = ts_data.index.min() + timedelta(days=best_mae_window)
    end_date = ts_data.index.max() - timedelta(weeks=4)
    current_dates = pd.date_range(start=start_date, end=end_date, freq='W')
    if len(current_dates) == 0:
        raise gr.Error("数据长度不足,无法进行滚动预测评估。")
    status_message += f"▶️ 将进行 {len(current_dates)} 次滚动预测,这会是耗时最长的步骤...\n"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)

    value_frame = ts_data.to_frame('Value')  # built once instead of twice per iteration
    all_metrics, weight_history = [], []
    latest_step = None  # (date, test_data, forecasts) of the most recent evaluated origin
    for date in progress.tqdm(current_dates, desc="滚动预测评估"):
        train_sliding, test_data = dynamic_split(value_frame, date, best_mae_window, False)
        train_expanding, _ = dynamic_split(value_frame, date, best_mae_window, True)
        actual_values = test_data['Value']
        h = len(actual_values)
        if h == 0:
            continue
        sarima_pred_sliding = sarima_model(train_sliding, h)
        prophet_pred_tuned = prophet_model_tuned(train_expanding, h)
        weighted_pred, w_s, w_p = weighted_average_model(sarima_pred_sliding, prophet_pred_tuned, train_sliding)
        weight_history.append({'Date': date, 'SARIMA_Weight': w_s, 'Prophet_Weight': w_p})
        models = {
            'SARIMA (Sliding)': sarima_pred_sliding,
            'Prophet (Tuned)': prophet_pred_tuned,
            'Weighted Average': weighted_pred,
        }
        for name, pred in models.items():
            metrics = calculate_metrics(actual_values, pred)
            metrics['Model'], metrics['Date'] = name, date
            all_metrics.append(metrics)
        latest_step = (date, test_data, models)
    if not all_metrics:
        raise gr.Error("数据长度不足,无法进行滚动预测评估。")
    metrics_df = pd.DataFrame(all_metrics)
    weight_history_df = pd.DataFrame(weight_history)
    status_message += "✔️ 滚动预测评估完成."
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)

    # --- 8. Comparison plots ---
    status_message += "\n\n### 步骤 8: 模型对比可视化\n▶️ 正在生成对比图表...\n"
    progress(0.85, desc="生成对比图表...")
    metric_cols = ['MAE', 'RMSE', 'MAPE', 'sMAPE']

    # 8.1 Weight history of the blended model.
    weight_fig_for_ui = plt.figure(figsize=(12, 6))
    sns.lineplot(data=weight_history_df, x='Date', y='SARIMA_Weight', label='SARIMA Weight')
    sns.lineplot(data=weight_history_df, x='Date', y='Prophet_Weight', label='Prophet Weight')
    plt.title('Weight Change of Sub-models')
    weight_fig_for_ui.savefig(os.path.join(OUTPUT_DIR, '8_1_weight_change.png'))
    plt.close(weight_fig_for_ui)

    # 8.2 Each error metric over time, one facet per metric.
    metrics_long = metrics_df.melt(id_vars=['Date', 'Model'], value_vars=metric_cols,
                                   var_name='Metric', value_name='Value')
    g = sns.FacetGrid(metrics_long, col='Metric', hue='Model', col_wrap=2, height=4, aspect=1.5, sharey=False)
    g.map(sns.lineplot, 'Date', 'Value', marker='o', markersize=4).add_legend(title="Model")
    g.figure.suptitle('Comparison of Prediction Error Metrics', y=1.03)
    g.set_titles("{col_name}")
    g.figure.savefig(os.path.join(OUTPUT_DIR, '8_2_error_comparison.png'))
    error_comp_fig_for_ui = g.figure
    plt.close(g.figure)

    # 8.3 Radar chart of min-max normalised average metrics.
    avg_metrics = metrics_df.groupby('Model')[metric_cols].mean()
    # A metric on which all models tie would divide by zero; map it to 0 instead of NaN.
    spread = (avg_metrics.max() - avg_metrics.min()).replace(0, np.nan)
    avg_metrics_normalized = ((avg_metrics - avg_metrics.min()) / spread).fillna(0).reset_index()
    labels = avg_metrics_normalized.columns[1:]
    num_vars = len(labels)
    # Repeat the first angle so each polygon closes.
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist() + [0]
    fig_radar, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    for _, row in avg_metrics_normalized.iterrows():
        values = row.drop('Model').tolist()
        values += values[:1]
        ax.plot(angles, values, label=row['Model'], linewidth=2)
        ax.fill(angles, values, alpha=0.2)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(labels)
    ax.set_title('Average Model Performance Radar Chart (Normalized)')
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    fig_radar.savefig(os.path.join(OUTPUT_DIR, '8_3_radar_chart.png'))
    radar_fig_for_ui = fig_radar
    plt.close(fig_radar)

    # 8.4 Distribution of each metric per model.
    fig_box, axes = plt.subplots(2, 2, figsize=(16, 10))
    for ax_box, metric in zip(axes.flat, metric_cols):
        sns.boxplot(data=metrics_df, x='Model', y=metric, hue='Model', ax=ax_box, palette='viridis', legend=False)
        ax_box.set_title(f'{metric} Distribution')
        ax_box.tick_params(axis='x', rotation=30)
    fig_box.tight_layout()
    fig_box.savefig(os.path.join(OUTPUT_DIR, '8_4_error_boxplot.png'))
    boxplot_fig_for_ui = fig_box
    plt.close(fig_box)

    # 8.5 Forecasts of the last rolling origin. Reusing the evaluated forecasts keeps the plot
    # consistent with the scored models (Prophet on the expanding window) and avoids refitting.
    last_date, last_test_data, last_preds = latest_step
    fig_final, ax_final = plt.subplots(figsize=(16, 8))
    ax_final.plot(last_test_data.index, last_test_data['Value'], label='Actual Value',
                  color='black', linewidth=2.5, marker='o')
    for name, pred in last_preds.items():
        ax_final.plot(last_test_data.index, np.asarray(pred), label=name, linestyle='--')
    ax_final.set_title(f'Final 4-Week Forecast Comparison (Start Date: {last_date.date()})')
    ax_final.legend()
    fig_final.savefig(os.path.join(OUTPUT_DIR, '8_5_final_forecast.png'))
    final_fig_for_ui = fig_final
    plt.close(fig_final)

    status_message += "✔️ 所有对比图表已生成."
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui,
                    weight_plot=weight_fig_for_ui, error_comp_plot=error_comp_fig_for_ui,
                    radar_plot=radar_fig_for_ui, box_plot=boxplot_fig_for_ui, final_plot=final_fig_for_ui)

    # --- 9. Final performance summary ---
    status_message += "\n\n### 步骤 9: 最终性能总结\n"
    progress(0.95, desc="生成最终报告...")
    # Same per-model means already computed for the radar chart (before normalisation).
    avg_perf = avg_metrics.sort_values('MAE')
    status_message += "✔️ 平均性能计算完成。\n"
    summary_path = os.path.join(OUTPUT_DIR, 'performance_summary.txt')
    with open(summary_path, 'w', encoding='utf-8') as f:
        f.write("Time Series Forecasting Model Comparison Report\n" + "=" * 50 + "\n\n")
        f.write(f"1. Optimal Sliding Window Size: {best_mae_window} days\n\n")
        f.write(f"2. Stationarity: Stationary after {d_order} order(s) of differencing.\n\n")
        f.write(f"3. Seasonality: Weekly seasonal differencing D(m=7) = {D_weekly}.\n\n")
        f.write("4. Average Model Performance (sorted by MAE):\n" + avg_perf.to_string(float_format="%.4f"))
    status_message += f"✔️ 性能总结报告已保存到 {summary_path}\n"
    zip_archive_path = create_zip_archive()
    status_message += f"✔️ 所有结果已打包到 {zip_archive_path}。"
    progress(1, desc="分析完成!")
    status_message += "\n\n## ✅ 分析流程执行完毕!"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui,
                    weight_plot=weight_fig_for_ui, error_comp_plot=error_comp_fig_for_ui,
                    radar_plot=radar_fig_for_ui, box_plot=boxplot_fig_for_ui, final_plot=final_fig_for_ui,
                    summary_df=avg_perf.reset_index(), download_file=zip_archive_path)
# ======================== Step 4: Gradio interface ========================
# Component creation order inside each `with` determines the on-screen layout.
with gr.Blocks(theme=gr.themes.Soft(), title="时间序列预测分析平台") as app:
    gr.Markdown("# 📈 时间序列预测分析平台")
    gr.Markdown("上传您的时间序列数据 (XLSX格式,包含 'Date' 和 'Value' 两列),然后点击“开始分析”按钮。系统将自动完成数据清洗、检验、模型训练、评估和对比的全过程。")

    with gr.Row():
        # Left column: data input, run button and the streaming analysis log.
        with gr.Column(scale=1):
            file_input = gr.File(label="上传 XLSX 数据文件", file_types=['.xlsx'], value="gmqrkl.xlsx")
            run_button = gr.Button("🚀 开始分析", variant="primary")
            gr.Markdown("---")
            gr.Markdown("### 📝 分析日志与进度")
            status_display = gr.Markdown("请上传文件并点击开始...")

        # Right column: result figures grouped into tabs.
        with gr.Column(scale=3):
            gr.Markdown("### 📊 分析结果可视化")
            with gr.Tabs():
                with gr.TabItem("初步分析"):
                    plot_stl = gr.Plot(label="STL 季节性分解")
                with gr.TabItem("优化窗口"):
                    plot_window = gr.Plot(label="最优窗口大小分析")
                with gr.TabItem("模型性能对比"):
                    plot_error_comp = gr.Plot(label="各模型误差指标对比")
                    plot_boxplot = gr.Plot(label="误差分布箱线图")
                with gr.TabItem("综合评估"):
                    plot_radar = gr.Plot(label="模型平均性能雷达图")
                    plot_weights = gr.Plot(label="组合模型权重变化")
                with gr.TabItem("最终预测"):
                    plot_final_forecast = gr.Plot(label="最后四周预测对比图")

            # NOTE(review): the original indentation was lost in the pasted source; this summary
            # section is assumed to sit in the right column below the tabs -- confirm placement.
            gr.Markdown("---")
            gr.Markdown("### 🏆 最终性能总结")
            summary_table = gr.DataFrame(label="各模型平均性能对比 (按 MAE 升序)")
            download_results = gr.File(label="下载全部分析结果 (ZIP)", visible=False)

    # Order must match the keys produced by run_analysis' update_ui.
    outputs_list = [
        status_display,
        plot_stl,
        plot_window,
        plot_weights,
        plot_error_comp,
        plot_radar,
        plot_boxplot,
        plot_final_forecast,
        summary_table,
        download_results,
    ]
    run_button.click(fn=run_analysis, inputs=[file_input], outputs=outputs_list)

if __name__ == "__main__":
    app.launch()