leonsimon23 commited on
Commit
4b86963
·
verified ·
1 Parent(s): 09c81a4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -0
app.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py (Final Corrected Version for Parallel Processing)
2
+
3
+ # ======================== 步骤 1: 导入所有库 ========================
4
+ import gradio as gr
5
+ import pandas as pd
6
+ import numpy as np
7
+ from scipy import stats
8
+ import warnings
9
+ import os
10
+ import logging
11
+ from datetime import timedelta
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+ from matplotlib.ticker import FuncFormatter
15
+ import shutil
16
+ import zipfile
17
+
18
+ # 时间序列核心库
19
+ from statsmodels.tsa.stattools import adfuller
20
+ from statsmodels.stats.diagnostic import acorr_ljungbox
21
+ from statsmodels.tsa.seasonal import STL
22
+ import pmdarima as pm
23
+ from prophet import Prophet
24
+
25
+ # 机器学习和评估
26
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
27
+ from joblib import Parallel, delayed
28
+
29
# --- Global settings ---
# Third-party libraries (statsmodels, prophet, pmdarima) are very chatty;
# silence warnings and reduce fitter log noise for a cleaner Gradio log.
warnings.filterwarnings("ignore")
logging.getLogger('prophet').setLevel(logging.ERROR)
logging.getLogger('cmdstanpy').setLevel(logging.ERROR)
plt.rcParams['axes.unicode_minus'] = False  # ensure minus signs render correctly in plots

# --- Output directory ---
# All plots, the text report and the downloadable ZIP are written under here;
# run_analysis wipes and recreates it on every run.
OUTPUT_DIR = 'analysis_output'
38
# ======================== Step 2: helper functions ========================

def calculate_metrics(actual, predicted):
    """Compute MAE, RMSE, MAPE and sMAPE for one forecast horizon.

    Parameters
    ----------
    actual, predicted : array-like (list, ndarray or pandas Series)
        Observed and forecast values of equal length. Pairs where either
        side is NaN are dropped before computing the metrics.

    Returns
    -------
    dict
        Keys 'MAE', 'RMSE', 'MAPE', 'sMAPE' (MAPE/sMAPE in percent).
        All values are NaN when no valid pairs remain.

    Notes
    -----
    Inputs are paired POSITIONALLY. The previous implementation built a
    DataFrame from the raw Series, which lets pandas align by index; callers
    pass date-indexed actuals against Prophet forecasts carrying a
    RangeIndex, so index alignment could silently drop every pair and yield
    all-NaN metrics. Coercing to numpy arrays removes that failure mode.
    Zero denominators are replaced by 1e-6, so MAPE can be very large when
    actuals contain zeros — this mirrors the original behaviour.
    """
    actual_arr = np.asarray(actual, dtype=float)
    predicted_arr = np.asarray(predicted, dtype=float)
    metrics_df = pd.DataFrame({'actual': actual_arr, 'predicted': predicted_arr}).dropna()
    if metrics_df.empty:
        return {'MAE': np.nan, 'RMSE': np.nan, 'MAPE': np.nan, 'sMAPE': np.nan}

    clean_actual = metrics_df['actual'].to_numpy()
    clean_predicted = metrics_df['predicted'].to_numpy()
    abs_err = np.abs(clean_actual - clean_predicted)

    # MAE / RMSE computed directly (identical to sklearn's definitions),
    # keeping this helper free of the sklearn dependency.
    mae = float(np.mean(abs_err))
    rmse = float(np.sqrt(np.mean(abs_err ** 2)))

    # Guard percentage metrics against division by zero.
    actual_safe = np.where(clean_actual == 0, 1e-6, clean_actual)
    mape = float(np.mean(np.abs((clean_actual - clean_predicted) / actual_safe)) * 100)
    denominator = np.abs(clean_actual) + np.abs(clean_predicted)
    denominator_safe = np.where(denominator == 0, 1e-6, denominator)
    smape = float(200 * np.mean(abs_err / denominator_safe))

    return {'MAE': mae, 'RMSE': rmse, 'MAPE': mape, 'sMAPE': smape}
56
+
57
def dynamic_split(data, current_date, window_size, expanding=False):
    """Split `data` at `current_date` into train and test slices.

    The training slice ends on `current_date` (inclusive). With
    ``expanding=True`` it starts at the first date in the index; otherwise
    it is a sliding window of `window_size` days. The test slice is the
    4 weeks immediately following `current_date`.
    """
    if expanding:
        first_train_day = data.index.min()
    else:
        first_train_day = current_date - timedelta(days=window_size - 1)
    train_mask = (data.index >= first_train_day) & (data.index <= current_date)
    test_lo = current_date + timedelta(days=1)
    test_hi = current_date + timedelta(weeks=4)
    test_mask = (data.index >= test_lo) & (data.index <= test_hi)
    return data[train_mask], data[test_mask]
67
+
68
def create_zip_archive(output_dir=None):
    """Zip the contents of `output_dir` into '<output_dir>.zip'.

    Parameters
    ----------
    output_dir : str, optional
        Directory to archive. Defaults to the module-level OUTPUT_DIR,
        preserving the original no-argument call sites.

    Returns
    -------
    str
        Path of the created archive, i.e. ``f"{output_dir}.zip"``.

    Notes
    -----
    Entries are stored with paths relative to `output_dir`, so unpacking
    reproduces the folder contents without the leading directory name.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR
    zip_path = f"{output_dir}.zip"
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _, files in os.walk(output_dir):
            for name in files:
                full_path = os.path.join(root, name)
                zipf.write(full_path, os.path.relpath(full_path, output_dir))
    return zip_path
77
+
78
# Change 1 (from the original commit notes): the function takes 3 plain
# arguments so joblib can pickle the call for its worker processes — it must
# stay a module-level function for that reason.
def _parallel_evaluate_window(window_size, ts_values, d_order):
    """Joblib worker: mean absolute one-step-ahead error for one window size.

    Slides a `window_size`-long training window over (at most) the first 365
    observations, fits a non-seasonal auto-ARIMA on each window and forecasts
    the next point. Returns np.inf when the window does not fit the series or
    when no window produced a usable forecast.
    """
    n_obs = len(ts_values)
    if window_size >= n_obs:
        return np.inf

    eval_limit = min(n_obs, 365)
    one_step_errors = []
    for start in range(eval_limit - window_size):
        history = ts_values[start:(start + window_size)]
        target = ts_values[start + window_size]
        try:
            fitted = pm.auto_arima(history, d=d_order, seasonal=False,
                                   stepwise=True, suppress_warnings=True,
                                   error_action='ignore')
            one_step_errors.append(target - fitted.predict(n_periods=1)[0])
        except Exception:
            # A failed fit simply skips this window.
            continue

    if not one_step_errors:
        return np.inf
    return np.mean(np.abs(one_step_errors))
95
+
96
# ======================== Step 3: core analysis function ========================
def run_analysis(uploaded_file, progress=gr.Progress(track_tqdm=True)):
    """End-to-end analysis pipeline, written as a generator for Gradio.

    Each ``yield`` emits a dict mapping output components (the module-level
    variables created in the gr.Blocks section below) to ``gr.update``
    payloads, so the UI refreshes incrementally during the long-running run.

    Pipeline: read/clean the uploaded XLSX ('Date' + 'Value' columns) ->
    ADF stationarity check -> Ljung-Box white-noise check -> seasonality
    (Canova-Hansen + STL) -> parallel sliding-window-size search -> rolling
    evaluation of SARIMA / tuned Prophet / inverse-MAE weighted ensemble ->
    comparison plots, a text report, and a downloadable ZIP of OUTPUT_DIR.
    """

    # --- Initialisation: start every run from a clean output directory ---
    progress(0, desc="开始分析流程...")
    if os.path.exists(OUTPUT_DIR): shutil.rmtree(OUTPUT_DIR)
    os.makedirs(OUTPUT_DIR)

    # Builds the component->update dict for every yield. Plot arguments left
    # as None blank out the corresponding gr.Plot slot; the download link is
    # only made visible once a file path is supplied.
    def update_ui(status, stl_plot=None, window_plot=None, weight_plot=None, error_comp_plot=None, radar_plot=None, box_plot=None, final_plot=None, summary_df=None, download_file=None):
        return {
            status_display: gr.update(value=status), plot_stl: gr.update(value=stl_plot),
            plot_window: gr.update(value=window_plot), plot_weights: gr.update(value=weight_plot),
            plot_error_comp: gr.update(value=error_comp_plot), plot_radar: gr.update(value=radar_plot),
            plot_boxplot: gr.update(value=box_plot), plot_final_forecast: gr.update(value=final_plot),
            summary_table: gr.update(value=summary_df), download_results: gr.update(value=download_file, visible=download_file is not None)
        }

    yield update_ui("分析已开始,正在准备环境...")

    # --- 1. Data cleaning ---
    status_message = "### 步骤 1: 数据清洗\n"
    try:
        progress(0.05, desc="正在读取和清洗数据...")
        yield update_ui(status_message + "▶️ 正在读取上传的 Excel 文件...")
        df = pd.read_excel(uploaded_file.name)
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.drop_duplicates(subset=['Date']).sort_values('Date').reset_index(drop=True)
        status_message += f"✔️ 数据读取成功,共 {len(df)} 行,时间范围: {df['Date'].min().date()} 到 {df['Date'].max().date()}.\n"
        yield update_ui(status_message)
    except Exception as e:
        # NOTE(review): gr.Error is instantiated but not raised; depending on
        # the Gradio version this may never surface to the user — confirm
        # whether `raise gr.Error(...)` was intended.
        gr.Error(f"文件读取失败: {e}. 请确保上传的 XLSX 文件包含 'Date' 和 'Value' 列。"); return
    # Zeros are treated as missing, then gap-filled by linear interpolation.
    df['Value'] = df['Value'].replace(0, np.nan).interpolate(method='linear', limit_direction='both')
    status_message += "✔️ 零值替换与线性插值完成。\n"
    # Outliers (|z| > 3) are only counted and reported, not removed.
    z_scores = np.abs(stats.zscore(df['Value'])); outliers = df[z_scores > 3]
    status_message += f"✔️ Z-score 异常值检测完成,发现 {len(outliers)} 个潜在异常值。\n"
    yield update_ui(status_message)
    df.set_index('Date', inplace=True)
    # Force a daily frequency; dates missing from the file become NaN and are interpolated.
    ts_data = df['Value'].asfreq('D').interpolate(method='linear')

    # --- 2. Stationarity test and differencing ---
    status_message += "\n### 步骤 2: 平稳性检验\n"
    progress(0.1, desc="进行平稳性检验...")
    def make_stationary(data):
        # Difference up to 3 times, stopping at the first order whose ADF
        # p-value drops below 0.05; returns (stationary series, d, log text).
        current_data, log = data.copy(), ""
        for d in range(4):
            p_value = adfuller(current_data.dropna())[1]
            log += f"差分阶数 d={d}, ADF检验 p-value: {p_value:.4f}\n"
            if p_value < 0.05:
                log += f"✔️ 序列在 d={d} 阶差分后达到平稳."; return current_data.dropna(), d, log
            current_data = current_data.diff()
        log += f"⚠️ 警告:在最大差分阶数内未达到平稳."; return current_data.dropna(), 3, log
    ts_stationary, d_order, stationarity_log = make_stationary(ts_data)
    status_message += stationarity_log; yield update_ui(status_message)

    # --- 3. White-noise (Ljung-Box) test — informational only, does not abort ---
    status_message += "\n\n### 步骤 3: 白噪声检验\n"; progress(0.15, desc="进行白噪声检验...")
    lags = min(10, len(ts_stationary) // 5)
    lb_p_value = acorr_ljungbox(ts_stationary, lags=[lags], return_df=True)['lb_pvalue'].iloc[0]
    status_message += f"✔️ 通过白噪声检验 (p-value = {lb_p_value:.4f}),序列为非白噪声。" if lb_p_value <= 0.05 else f"⚠️ 序列为白噪声 (p-value = {lb_p_value:.4f}),可能不适合复杂建模。"
    yield update_ui(status_message)

    # --- 4. Seasonality analysis ---
    status_message += "\n\n### 步骤 4: 季节性分析 (初步分析)\n"; progress(0.2, desc="进行季节性分解...")
    # Canova-Hansen test picks the weekly seasonal differencing order used by SARIMA later.
    D_weekly = pm.arima.nsdiffs(ts_data, m=7, test='ch')
    status_message += f"✔️ 每周季节性差分阶数 D(m=7): {D_weekly} (通过 Canova-Hansen 测试确定)。\n"
    plt.figure(figsize=(14, 10)); stl_year = STL(ts_data, period=365).fit(); stl_year.plot()
    plt.suptitle('STL Decomposition (Annual Seasonality)', y=1.02); plt.savefig(os.path.join(OUTPUT_DIR, '4_stl_decomposition.png'))
    stl_fig_for_ui = plt.gcf(); plt.close()
    status_message += "✔️ STL分解图已生成."; yield update_ui(status_message, stl_plot=stl_fig_for_ui)

    # --- 5. Window-size optimisation (multi-process, via joblib) ---
    status_message += "\n\n### 步骤 5: 窗口大小优化\n"
    status_message += "▶️ **正在启动多进程并行计算**以评估不同窗口大小的MAE。\n▶️ 下方的进度条将显示总体进度。计算完成后会显示详细结果。\n"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui)

    # Candidate sliding windows: 70..210 days in two-week steps.
    ts_values = ts_data.values; window_sizes = range(70, 211, 14)
    tasks = [(ws, ts_values, d_order) for ws in window_sizes]

    # Change 2 (from the original commit notes): unpack each tuple with *task
    # into the module-level worker so joblib can pickle the call.
    results_maes = Parallel(n_jobs=-1)(
        delayed(_parallel_evaluate_window)(*task) for task in progress.tqdm(tasks, desc="优化窗口大小")
    )

    results_df = pd.DataFrame({'window_size': window_sizes, 'mae': results_maes})
    status_message += "---\n✔️ 并行计算完成,结果如下:\n"
    for _, row in results_df.iterrows():
        status_message += f"窗口大小: {int(row['window_size']):<4} -> MAE: {row['mae']:.4f}\n"
    status_message += "---\n"
    if results_df.empty or results_df['mae'].isna().all():
        # NOTE(review): gr.Error is not raised here either — see the note in step 1.
        gr.Error("窗口优化失败,所有窗口大小均未产出有效MAE。"); return
    best_mae_window = int(results_df.loc[results_df['mae'].idxmin()]['window_size'])
    plt.figure(figsize=(10, 6)); sns.lineplot(data=results_df, x='window_size', y='mae', marker='o')
    plt.title('Effect of Window Size on Prediction Error (MAE)'); plt.xlabel('Training Window Days'); plt.ylabel('Mean Absolute Error (MAE)')
    plt.grid(True); plt.savefig(os.path.join(OUTPUT_DIR, '5_window_size_optimization.png'))
    window_fig_for_ui = plt.gcf(); plt.close()
    # NOTE(review): the status string below contains U+FFFD mojibake
    # ("完成���最优") — likely a punctuation mark lost in an encoding
    # round-trip; left byte-identical here since it is a runtime string.
    status_message += f"✔️ 窗口大小优化完成���最优滑动窗口大小: **{best_mae_window}** 天。"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)

    # --- 6. Model definitions (closures capturing d_order / D_weekly) ---
    status_message += "\n\n### 步骤 6: 定义所有模型\n✔️ SARIMA, Prophet, 加权平均模型已定义。"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)

    def sarima_model(train_data, h=28):
        # Weekly-seasonal auto-ARIMA (m=7, orders capped at 3); returns an h-step forecast.
        return pm.auto_arima(train_data['Value'], d=d_order, D=D_weekly, m=7, seasonal=True, stepwise=True, suppress_warnings=True, error_action='ignore', max_p=3, max_q=3).predict(n_periods=h)
    def prophet_model_default(train_data, h=28):
        # Prophet with default priors; fallback when there is too little data to tune.
        df_prophet = train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'}); model = Prophet(yearly_seasonality=True, weekly_seasonality=True).fit(df_prophet); future = model.make_future_dataframe(periods=h, freq='D'); return model.predict(future)['yhat'].tail(h)
    def prophet_model_tuned(train_data, h=28):
        # 3x3 grid search over changepoint/seasonality priors, validated on the
        # last 28 days of the training slice, then refit on the full slice.
        df_prophet = train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'}); param_grid = {'changepoint_prior_scale': [0.01, 0.1, 0.5], 'seasonality_prior_scale': [1.0, 5.0, 10.0]}; best_mae, best_params = float('inf'), {}; train_len = len(df_prophet); split_point = train_len - 28
        if split_point < 10: return prophet_model_default(train_data, h)
        train_sub, val_sub = df_prophet.iloc[:split_point], df_prophet.iloc[split_point:]
        for cps in param_grid['changepoint_prior_scale']:
            for sps in param_grid['seasonality_prior_scale']:
                m = Prophet(changepoint_prior_scale=cps, seasonality_prior_scale=sps).fit(train_sub); fc_val = m.predict(m.make_future_dataframe(periods=len(val_sub))).tail(len(val_sub)); mae = mean_absolute_error(val_sub['y'], fc_val['yhat'])
                if mae < best_mae: best_mae, best_params = mae, {'changepoint_prior_scale': cps, 'seasonality_prior_scale': sps}
        final_model = Prophet(**best_params).fit(df_prophet); return final_model.predict(final_model.make_future_dataframe(periods=h))['yhat'].tail(h)
    def weighted_average_model(sarima_pred, prophet_pred, train_data):
        # Blend the two forecasts with inverse-validation-MAE weights; falls
        # back to a 50/50 average when there is not enough history to validate.
        val_end_date = train_data.index.max(); val_start_date = val_end_date - timedelta(weeks=4) + timedelta(days=1); validation_data = train_data[train_data.index >= val_start_date]
        if len(validation_data) < 28: return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        train_for_val = train_data[train_data.index < val_start_date]
        if len(train_for_val) < 20: return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        sarima_val_pred, prophet_val_pred = sarima_model(train_for_val, h=28), prophet_model_tuned(train_for_val, h=28); sarima_mae, prophet_mae = mean_absolute_error(validation_data['Value'], sarima_val_pred), mean_absolute_error(validation_data['Value'], prophet_val_pred)
        if sarima_mae + prophet_mae == 0: return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        inv_err_s, inv_err_p = (1 / sarima_mae if sarima_mae > 1e-9 else 1e9), (1 / prophet_mae if prophet_mae > 1e-9 else 1e9); total_inv_err = inv_err_s + inv_err_p; w_s, w_p = inv_err_s / total_inv_err, inv_err_p / total_inv_err
        return w_s * sarima_pred + w_p * prophet_pred, w_s, w_p

    # --- 7. Rolling-forecast evaluation (slowest step) ---
    status_message += "\n\n### 步骤 7: 模型性能评估 (滚动预测)\n"
    # One split per week, leaving 4 weeks of actuals after each split point.
    start_date = ts_data.index.min() + timedelta(days=best_mae_window); end_date = ts_data.index.max() - timedelta(weeks=4); current_dates = pd.date_range(start=start_date, end=end_date, freq='W')
    status_message += f"▶️ 将进行 {len(current_dates)} 次滚动预测,这会是耗时最长的步骤...\n"; yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)
    all_metrics, weight_history = [], []
    for i, date in enumerate(progress.tqdm(current_dates, desc="滚动预测评估")):
        # SARIMA trains on the sliding window, tuned Prophet on the expanding one.
        train_sliding, test_data = dynamic_split(ts_data.to_frame('Value'), date, best_mae_window, False); train_expanding, _ = dynamic_split(ts_data.to_frame('Value'), date, best_mae_window, True); actual_values = test_data['Value']; h = len(actual_values)
        if h == 0: continue
        sarima_pred_sliding, prophet_pred_tuned = sarima_model(train_sliding, h), prophet_model_tuned(train_expanding, h); weighted_pred, w_s, w_p = weighted_average_model(sarima_pred_sliding, prophet_pred_tuned, train_sliding); weight_history.append({'Date': date, 'SARIMA_Weight': w_s, 'Prophet_Weight': w_p}); models = {'SARIMA (Sliding)': sarima_pred_sliding, 'Prophet (Tuned)': prophet_pred_tuned, 'Weighted Average': weighted_pred}
        for name, pred in models.items():
            metrics = calculate_metrics(actual_values, pred); metrics['Model'], metrics['Date'] = name, date; all_metrics.append(metrics)
    metrics_df, weight_history_df = pd.DataFrame(all_metrics), pd.DataFrame(weight_history)
    status_message += "✔️ 滚动预测评估完成."; yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)

    # --- 8. Comparison visualisations (each figure is also saved to OUTPUT_DIR) ---
    status_message += "\n\n### 步骤 8: 模型对比可视化\n▶️ 正在生成对比图表...\n"; progress(0.85, desc="生成对比图表...")
    # 8.1 ensemble weights over time
    plt.figure(figsize=(12, 6)); sns.lineplot(data=weight_history_df, x='Date', y='SARIMA_Weight', label='SARIMA Weight'); sns.lineplot(data=weight_history_df, x='Date', y='Prophet_Weight', label='Prophet Weight'); plt.title('Weight Change of Sub-models'); plt.savefig(os.path.join(OUTPUT_DIR, '8_1_weight_change.png')); weight_fig_for_ui = plt.gcf(); plt.close()
    # 8.2 per-metric error curves, one facet per metric
    metrics_long = metrics_df.melt(id_vars=['Date', 'Model'], value_vars=['MAE', 'RMSE', 'MAPE', 'sMAPE'], var_name='Metric', value_name='Value'); g = sns.FacetGrid(metrics_long, col='Metric', hue='Model', col_wrap=2, height=4, aspect=1.5, sharey=False); g.map(sns.lineplot, 'Date', 'Value', marker='o', markersize=4).add_legend(title="Model"); g.fig.suptitle('Comparison of Prediction Error Metrics', y=1.03); g.set_titles("{col_name}"); g.fig.savefig(os.path.join(OUTPUT_DIR, '8_2_error_comparison.png')); error_comp_fig_for_ui = g.fig; plt.close()
    # 8.3 radar chart of min-max-normalised average metrics
    avg_metrics = metrics_df.groupby('Model')[['MAE', 'RMSE', 'MAPE', 'sMAPE']].mean(); avg_metrics_normalized = avg_metrics.apply(lambda x: (x - x.min()) / (x.max() - x.min())).reset_index(); labels = avg_metrics_normalized.columns[1:]; num_vars = len(labels); angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist() + [0]; fig_radar, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True));
    for i, row in avg_metrics_normalized.iterrows(): values = row.drop('Model').tolist() + [row.drop('Model').tolist()[0]]; ax.plot(angles, values, label=row['Model'], linewidth=2); ax.fill(angles, values, alpha=0.2)
    ax.set_xticks(angles[:-1]); ax.set_xticklabels(labels); plt.title('Average Model Performance Radar Chart (Normalized)'); plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1)); fig_radar.savefig(os.path.join(OUTPUT_DIR, '8_3_radar_chart.png')); radar_fig_for_ui = fig_radar; plt.close()
    # 8.4 error distribution box plots
    fig_box, axes = plt.subplots(2, 2, figsize=(16, 10))
    for i, metric in enumerate(['MAE', 'RMSE', 'MAPE', 'sMAPE']): ax = axes[i//2, i%2]; sns.boxplot(data=metrics_df, x='Model', y=metric, hue='Model', ax=ax, palette='viridis', legend=False); ax.set_title(f'{metric} Distribution'); ax.tick_params(axis='x', rotation=30)
    plt.tight_layout(); fig_box.savefig(os.path.join(OUTPUT_DIR, '8_4_error_boxplot.png')); boxplot_fig_for_ui = fig_box; plt.close()
    # 8.5 actual-vs-forecast overlay for the last evaluation window
    last_date = current_dates[-1]; train_sliding, test_data = dynamic_split(ts_data.to_frame('Value'), last_date, best_mae_window, False); h = len(test_data); preds = {'SARIMA (Sliding)': sarima_model(train_sliding, h), 'Prophet (Tuned)': prophet_model_tuned(train_sliding, h)}; preds['Weighted Average'], _, _ = weighted_average_model(preds['SARIMA (Sliding)'], preds['Prophet (Tuned)'], train_sliding); fig_final, ax_final = plt.subplots(figsize=(16, 8)); ax_final.plot(test_data.index, test_data['Value'], label='Actual Value', color='black', linewidth=2.5, marker='o')
    for name, pred in preds.items(): ax_final.plot(test_data.index, pred.values, label=name, linestyle='--')
    ax_final.set_title(f'Final 4-Week Forecast Comparison (Start Date: {last_date.date()})'); ax_final.legend(); fig_final.savefig(os.path.join(OUTPUT_DIR, '8_5_final_forecast.png')); final_fig_for_ui = fig_final; plt.close()
    status_message += "✔️ 所有对比图表已生成."
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui, weight_plot=weight_fig_for_ui, error_comp_plot=error_comp_fig_for_ui, radar_plot=radar_fig_for_ui, box_plot=boxplot_fig_for_ui, final_plot=final_fig_for_ui)

    # --- 9. Final summary: text report + ZIP of the whole output folder ---
    status_message += "\n\n### 步骤 9: 最终性能总结\n"; progress(0.95, desc="生成最终报告...")
    avg_perf = metrics_df.groupby('Model')[['MAE', 'RMSE', 'MAPE', 'sMAPE']].mean().sort_values('MAE'); status_message += "✔️ 平均性能计算完成。\n"
    summary_path = os.path.join(OUTPUT_DIR, 'performance_summary.txt')
    with open(summary_path, 'w', encoding='utf-8') as f: f.write("Time Series Forecasting Model Comparison Report\n" + "="*50 + "\n\n"); f.write(f"1. Optimal Sliding Window Size: {best_mae_window} days\n\n"); f.write(f"2. Stationarity: Stationary after {d_order} order(s) of differencing.\n\n"); f.write(f"3. Seasonality: Weekly seasonal differencing D(m=7) = {D_weekly}.\n\n"); f.write("4. Average Model Performance (sorted by MAE):\n" + avg_perf.to_string(float_format="%.4f"))
    status_message += f"✔️ 性能总结报告已保存到 {summary_path}\n"
    zip_archive_path = create_zip_archive(); status_message += f"✔️ 所有结果已打包到 {zip_archive_path}。"
    progress(1, desc="分析完成!"); status_message += "\n\n## ✅ 分析流程执行完毕!"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui, weight_plot=weight_fig_for_ui, error_comp_plot=error_comp_fig_for_ui, radar_plot=radar_fig_for_ui, box_plot=boxplot_fig_for_ui, final_plot=final_fig_for_ui, summary_df=avg_perf.reset_index(), download_file=zip_archive_path)
260
+
261
# ======================== Step 4: build the Gradio UI ========================
# The component variables created here (status_display, plot_*, summary_table,
# download_results) are module-level names that run_analysis's update_ui
# closure keys its yielded update dicts on.
with gr.Blocks(theme=gr.themes.Soft(), title="时间序列预测分析平台") as app:
    gr.Markdown("# 📈 时间序列预测分析平台"); gr.Markdown("上传您的时间序列数据 (XLSX格式,包含 'Date' 和 'Value' 两列),然后点击“开始分析”按钮。系统将自动完成数据清洗、检验、模型训练、评估和对比的全过程。")
    with gr.Row():
        # Left column: file input, run button and the streaming progress log.
        with gr.Column(scale=1):
            # NOTE(review): value="gmqrkl.xlsx" pre-selects a bundled sample
            # file — confirm that file actually ships alongside app.py.
            file_input = gr.File(label="上传 XLSX 数据文件", file_types=['.xlsx'], value="gmqrkl.xlsx"); run_button = gr.Button("🚀 开始分析", variant="primary"); gr.Markdown("---"); gr.Markdown("### 📝 分析日志与进度"); status_display = gr.Markdown("请上传文件并点击开始...")
        # Right column: tabbed result plots plus the summary table/download.
        with gr.Column(scale=3):
            gr.Markdown("### 📊 分析结果可视化")
            with gr.Tabs():
                with gr.TabItem("初步分析"): plot_stl = gr.Plot(label="STL 季节性分解")
                with gr.TabItem("优化窗口"): plot_window = gr.Plot(label="最优窗口大小分析")
                with gr.TabItem("模型性能对比"): plot_error_comp = gr.Plot(label="各模型误差指标对比"); plot_boxplot = gr.Plot(label="误差分布箱线图")
                with gr.TabItem("综合评估"): plot_radar = gr.Plot(label="模型平均性能雷达图"); plot_weights = gr.Plot(label="组合模型权重变化")
                with gr.TabItem("最终预测"): plot_final_forecast = gr.Plot(label="最后四周预测对比图")
            # Download stays hidden until run_analysis supplies the ZIP path.
            gr.Markdown("---"); gr.Markdown("### 🏆 最终性能总结"); summary_table = gr.DataFrame(label="各模型平均性能对比 (按 MAE 升序)"); download_results = gr.File(label="下载全部分析结果 (ZIP)", visible=False)

    # Every component that update_ui may address must be listed here so the
    # dict-style updates yielded by run_analysis are accepted by Gradio.
    outputs_list = [status_display, plot_stl, plot_window, plot_weights, plot_error_comp, plot_radar, plot_boxplot, plot_final_forecast, summary_table, download_results]
    run_button.click(fn=run_analysis, inputs=[file_input], outputs=outputs_list)

if __name__ == "__main__":
    app.launch()