Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py (Final Corrected Version for Parallel Processing)

# ======================== Step 1: Import all libraries ========================
import gradio as gr
import pandas as pd
import numpy as np
from scipy import stats
import warnings
import os
import logging
from datetime import timedelta
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FuncFormatter  # NOTE(review): appears unused in this file — confirm before removing
import shutil
import zipfile

# Core time-series libraries
from statsmodels.tsa.stattools import adfuller
from statsmodels.stats.diagnostic import acorr_ljungbox
from statsmodels.tsa.seasonal import STL
import pmdarima as pm
from prophet import Prophet

# Machine learning and evaluation
from sklearn.metrics import mean_absolute_error, mean_squared_error
from joblib import Parallel, delayed

# --- Global settings ---
# Silence the very chatty warning/log output from prophet and cmdstanpy.
warnings.filterwarnings("ignore")
logging.getLogger('prophet').setLevel(logging.ERROR)
logging.getLogger('cmdstanpy').setLevel(logging.ERROR)
plt.rcParams['axes.unicode_minus'] = False  # make sure minus signs render correctly in plots

# --- Output folder setting ---
# All generated figures, reports and the final zip are rooted here.
OUTPUT_DIR = 'analysis_output'

# ======================== Step 2: Helper functions ========================
| 40 |
+
def calculate_metrics(actual, predicted):
    """Compute MAE, RMSE, MAPE and sMAPE for a pair of aligned series.

    Args:
        actual: iterable of observed values (may contain NaN).
        predicted: iterable of forecast values (may contain NaN).

    Returns:
        dict with keys 'MAE', 'RMSE', 'MAPE', 'sMAPE'. All values are NaN
        when no row survives the NaN filter.
    """
    # Align the two series row-wise and drop any pair where either side is missing.
    pairs = pd.DataFrame({'actual': actual, 'predicted': predicted}).dropna()
    if pairs.empty:
        return {'MAE': np.nan, 'RMSE': np.nan, 'MAPE': np.nan, 'sMAPE': np.nan}

    a = pairs['actual'].to_numpy(dtype=float)
    p = pairs['predicted'].to_numpy(dtype=float)
    err = a - p

    # MAE / RMSE computed directly in numpy (same definitions sklearn uses),
    # keeping the whole function consistent and dependency-free.
    mae = np.mean(np.abs(err))
    rmse = np.sqrt(np.mean(err ** 2))

    # Guard divisors against exact zeros so the percentage errors stay finite.
    actual_safe = np.where(a == 0, 1e-6, a)
    mape = np.mean(np.abs(err / actual_safe)) * 100
    denominator = np.abs(a) + np.abs(p)
    denominator_safe = np.where(denominator == 0, 1e-6, denominator)
    smape = 200 * np.mean(np.abs(err) / denominator_safe)

    return {'MAE': mae, 'RMSE': rmse, 'MAPE': mape, 'sMAPE': smape}
|
| 56 |
+
|
| 57 |
+
def dynamic_split(data, current_date, window_size, expanding=False):
    """Split *data* into a training window ending at *current_date* and a
    four-week test window starting the day after.

    With ``expanding=True`` the training window reaches back to the start of
    the series; otherwise it covers exactly *window_size* days (inclusive of
    *current_date*).
    """
    if expanding:
        first_train_day = data.index.min()
    else:
        first_train_day = current_date - timedelta(days=window_size - 1)

    in_train = (data.index >= first_train_day) & (data.index <= current_date)
    test_lo = current_date + timedelta(days=1)
    test_hi = current_date + timedelta(weeks=4)
    in_test = (data.index >= test_lo) & (data.index <= test_hi)
    return data[in_train], data[in_test]
|
| 67 |
+
|
| 68 |
+
def create_zip_archive(output_dir=None):
    """Pack an output folder into ``<folder>.zip`` and return the archive path.

    Args:
        output_dir: folder to archive. Defaults to the module-level
            ``OUTPUT_DIR`` so existing callers keep their behavior.

    Returns:
        Path of the created zip file (written next to the folder).
    """
    target_dir = OUTPUT_DIR if output_dir is None else output_dir
    zip_path = f"{target_dir}.zip"
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _, files in os.walk(target_dir):
            for file in files:
                full_path = os.path.join(root, file)
                # Store paths relative to the folder root so the archive
                # unpacks without the leading output-directory component.
                zipf.write(full_path, os.path.relpath(full_path, target_dir))
    return zip_path
|
| 77 |
+
|
| 78 |
+
# **修改点 1**: 改变函数签名,直接接收3个参数
|
| 79 |
+
def _parallel_evaluate_window(window_size, ts_values, d_order):
|
| 80 |
+
"""一个独立的、可被 joblib 并行调用的函数。"""
|
| 81 |
+
n = len(ts_values)
|
| 82 |
+
if window_size >= n:
|
| 83 |
+
return np.inf
|
| 84 |
+
errors = []
|
| 85 |
+
eval_len = min(n, 365)
|
| 86 |
+
for i in range(eval_len - window_size):
|
| 87 |
+
train = ts_values[i:(i + window_size)]
|
| 88 |
+
test = ts_values[i + window_size]
|
| 89 |
+
try:
|
| 90 |
+
model = pm.auto_arima(train, d=d_order, seasonal=False, stepwise=True, suppress_warnings=True, error_action='ignore')
|
| 91 |
+
errors.append(test - model.predict(n_periods=1)[0])
|
| 92 |
+
except Exception:
|
| 93 |
+
continue
|
| 94 |
+
return np.mean(np.abs(errors)) if errors else np.inf
|
| 95 |
+
|
| 96 |
+
# ======================== Step 3: Core analysis function ========================
def run_analysis(uploaded_file, progress=gr.Progress(track_tqdm=True)):
    """Run the full time-series analysis pipeline as a Gradio generator.

    Parameters
    ----------
    uploaded_file : gradio file value
        Uploaded XLSX file; must contain 'Date' and 'Value' columns.
    progress : gr.Progress
        Gradio progress tracker (``track_tqdm=True`` also wraps tqdm loops).

    Yields
    ------
    dict
        Component -> ``gr.update`` mappings built by ``update_ui`` so the UI
        refreshes after every pipeline step.

    NOTE(review): this function references Gradio components
    (status_display, plot_*, summary_table, download_results) that are
    created later at module level inside ``gr.Blocks``; it is only invoked
    after the app is built, so the late binding works — but the coupling is
    worth keeping in mind when refactoring.
    """

    # --- Initialization: start from a clean output directory ---
    progress(0, desc="开始分析流程...")
    if os.path.exists(OUTPUT_DIR): shutil.rmtree(OUTPUT_DIR)
    os.makedirs(OUTPUT_DIR)

    def update_ui(status, stl_plot=None, window_plot=None, weight_plot=None, error_comp_plot=None, radar_plot=None, box_plot=None, final_plot=None, summary_df=None, download_file=None):
        # Build the component->update dict Gradio expects from a generator;
        # the download button becomes visible only once a file is available.
        return {
            status_display: gr.update(value=status), plot_stl: gr.update(value=stl_plot),
            plot_window: gr.update(value=window_plot), plot_weights: gr.update(value=weight_plot),
            plot_error_comp: gr.update(value=error_comp_plot), plot_radar: gr.update(value=radar_plot),
            plot_boxplot: gr.update(value=box_plot), plot_final_forecast: gr.update(value=final_plot),
            summary_table: gr.update(value=summary_df), download_results: gr.update(value=download_file, visible=download_file is not None)
        }

    yield update_ui("分析已开始,正在准备环境...")

    # --- 1. Data cleaning ---
    status_message = "### 步骤 1: 数据清洗\n"
    try:
        progress(0.05, desc="正在读取和清洗数据...")
        yield update_ui(status_message + "▶️ 正在读取上传的 Excel 文件...")
        df = pd.read_excel(uploaded_file.name)
        df['Date'] = pd.to_datetime(df['Date'])
        # De-duplicate by date and restore chronological order.
        df = df.drop_duplicates(subset=['Date']).sort_values('Date').reset_index(drop=True)
        status_message += f"✔️ 数据读取成功,共 {len(df)} 行,时间范围: {df['Date'].min().date()} 到 {df['Date'].max().date()}.\n"
        yield update_ui(status_message)
    except Exception as e:
        # NOTE(review): gr.Error is constructed but not raised — Gradio only
        # shows the error popup when the exception is *raised*; confirm intent.
        gr.Error(f"文件读取失败: {e}. 请确保上传的 XLSX 文件包含 'Date' 和 'Value' 列。"); return
    # Treat zeros as missing, then fill all gaps by linear interpolation.
    df['Value'] = df['Value'].replace(0, np.nan).interpolate(method='linear', limit_direction='both')
    status_message += "✔️ 零值替换与线性插值完成。\n"
    # Flag |z| > 3 points as potential outliers (reported only, not removed).
    z_scores = np.abs(stats.zscore(df['Value'])); outliers = df[z_scores > 3]
    status_message += f"✔️ Z-score 异常值检测完成,发现 {len(outliers)} 个潜在异常值。\n"
    yield update_ui(status_message)
    df.set_index('Date', inplace=True)
    # Force a daily frequency; interpolate any calendar gaps that creates.
    ts_data = df['Value'].asfreq('D').interpolate(method='linear')

    # --- 2. Stationarity test and differencing ---
    status_message += "\n### 步骤 2: 平稳性检验\n"
    progress(0.1, desc="进行平稳性检验...")
    def make_stationary(data):
        # Difference repeatedly until the ADF test rejects a unit root (p < 0.05).
        current_data, log = data.copy(), ""
        for d in range(4):
            p_value = adfuller(current_data.dropna())[1]
            log += f"差分阶数 d={d}, ADF检验 p-value: {p_value:.4f}\n"
            if p_value < 0.05:
                log += f"✔️ 序列在 d={d} 阶差分后达到平稳."; return current_data.dropna(), d, log
            current_data = current_data.diff()
        # NOTE(review): on fall-through the series has been differenced 4 times
        # but d=3 is reported — verify this cap is intended.
        log += f"⚠️ 警告:在最大差分阶数内未达到平稳."; return current_data.dropna(), 3, log
    ts_stationary, d_order, stationarity_log = make_stationary(ts_data)
    status_message += stationarity_log; yield update_ui(status_message)

    # --- 3. White-noise (Ljung-Box) test ---
    status_message += "\n\n### 步骤 3: 白噪声检验\n"; progress(0.15, desc="进行白噪声检验...")
    lags = min(10, len(ts_stationary) // 5)
    lb_p_value = acorr_ljungbox(ts_stationary, lags=[lags], return_df=True)['lb_pvalue'].iloc[0]
    status_message += f"✔️ 通过白噪声检验 (p-value = {lb_p_value:.4f}),序列为非白噪声。" if lb_p_value <= 0.05 else f"⚠️ 序列为白噪声 (p-value = {lb_p_value:.4f}),可能不适合复杂建模。"
    yield update_ui(status_message)

    # --- 4. Seasonality analysis (preliminary) ---
    status_message += "\n\n### 步骤 4: 季节性分析 (初步分析)\n"; progress(0.2, desc="进行季节性分解...")
    # Canova-Hansen test for the weekly (m=7) seasonal differencing order.
    D_weekly = pm.arima.nsdiffs(ts_data, m=7, test='ch')
    status_message += f"✔️ 每周季节性差分阶数 D(m=7): {D_weekly} (通过 Canova-Hansen 测试确定)。\n"
    plt.figure(figsize=(14, 10)); stl_year = STL(ts_data, period=365).fit(); stl_year.plot()
    plt.suptitle('STL Decomposition (Annual Seasonality)', y=1.02); plt.savefig(os.path.join(OUTPUT_DIR, '4_stl_decomposition.png'))
    stl_fig_for_ui = plt.gcf(); plt.close()
    status_message += "✔️ STL分解图已生成."; yield update_ui(status_message, stl_plot=stl_fig_for_ui)

    # --- 5. Window-size optimization (parallel) ---
    status_message += "\n\n### 步骤 5: 窗口大小优化\n"
    status_message += "▶️ **正在启动多进程并行计算**以评估不同窗口大小的MAE。\n▶️ 下方的进度条将显示总体进度。计算完成后会显示详细结果。\n"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui)

    # Candidate sliding windows: 70..210 days in two-week steps.
    ts_values = ts_data.values; window_sizes = range(70, 211, 14)
    tasks = [(ws, ts_values, d_order) for ws in window_sizes]

    # Fan the evaluations out to all cores; *task unpacks (ws, values, d_order)
    # into the worker function's three positional parameters.
    results_maes = Parallel(n_jobs=-1)(
        delayed(_parallel_evaluate_window)(*task) for task in progress.tqdm(tasks, desc="优化窗口大小")
    )

    results_df = pd.DataFrame({'window_size': window_sizes, 'mae': results_maes})
    status_message += "---\n✔️ 并行计算完成,结果如下:\n"
    for _, row in results_df.iterrows():
        status_message += f"窗口大小: {int(row['window_size']):<4} -> MAE: {row['mae']:.4f}\n"
    status_message += "---\n"
    if results_df.empty or results_df['mae'].isna().all():
        # NOTE(review): same un-raised gr.Error pattern as in step 1.
        gr.Error("窗口优化失败,所有窗口大小均未产出有效MAE。"); return
    best_mae_window = int(results_df.loc[results_df['mae'].idxmin()]['window_size'])
    plt.figure(figsize=(10, 6)); sns.lineplot(data=results_df, x='window_size', y='mae', marker='o')
    plt.title('Effect of Window Size on Prediction Error (MAE)'); plt.xlabel('Training Window Days'); plt.ylabel('Mean Absolute Error (MAE)')
    plt.grid(True); plt.savefig(os.path.join(OUTPUT_DIR, '5_window_size_optimization.png'))
    window_fig_for_ui = plt.gcf(); plt.close()
    # NOTE(review): the literal below contains mojibake ("完成���最优") —
    # likely a corrupted "完成。最优"; fix the string in a separate change.
    status_message += f"✔️ 窗口大小优化完成���最优滑动窗口大小: **{best_mae_window}** 天。"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)

    # --- 6. Model definitions (closures over d_order / D_weekly) ---
    status_message += "\n\n### 步骤 6: 定义所有模型\n✔️ SARIMA, Prophet, 加权平均模型已定义。"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)

    def sarima_model(train_data, h=28):
        # Weekly-seasonal auto-ARIMA, forecasting h days ahead.
        return pm.auto_arima(train_data['Value'], d=d_order, D=D_weekly, m=7, seasonal=True, stepwise=True, suppress_warnings=True, error_action='ignore', max_p=3, max_q=3).predict(n_periods=h)
    def prophet_model_default(train_data, h=28):
        # Prophet with default priors, yearly + weekly seasonality.
        df_prophet = train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'}); model = Prophet(yearly_seasonality=True, weekly_seasonality=True).fit(df_prophet); future = model.make_future_dataframe(periods=h, freq='D'); return model.predict(future)['yhat'].tail(h)
    def prophet_model_tuned(train_data, h=28):
        # Grid-search changepoint/seasonality priors on a 28-day holdout,
        # then refit the best combination on the full training data.
        df_prophet = train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'}); param_grid = {'changepoint_prior_scale': [0.01, 0.1, 0.5], 'seasonality_prior_scale': [1.0, 5.0, 10.0]}; best_mae, best_params = float('inf'), {}; train_len = len(df_prophet); split_point = train_len - 28
        if split_point < 10: return prophet_model_default(train_data, h)
        train_sub, val_sub = df_prophet.iloc[:split_point], df_prophet.iloc[split_point:]
        for cps in param_grid['changepoint_prior_scale']:
            for sps in param_grid['seasonality_prior_scale']:
                m = Prophet(changepoint_prior_scale=cps, seasonality_prior_scale=sps).fit(train_sub); fc_val = m.predict(m.make_future_dataframe(periods=len(val_sub))).tail(len(val_sub)); mae = mean_absolute_error(val_sub['y'], fc_val['yhat'])
                if mae < best_mae: best_mae, best_params = mae, {'changepoint_prior_scale': cps, 'seasonality_prior_scale': sps}
        final_model = Prophet(**best_params).fit(df_prophet); return final_model.predict(final_model.make_future_dataframe(periods=h))['yhat'].tail(h)
    def weighted_average_model(sarima_pred, prophet_pred, train_data):
        # Blend the two forecasts with inverse-validation-MAE weights;
        # falls back to a 50/50 average when validation data is too short.
        val_end_date = train_data.index.max(); val_start_date = val_end_date - timedelta(weeks=4) + timedelta(days=1); validation_data = train_data[train_data.index >= val_start_date]
        if len(validation_data) < 28: return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        train_for_val = train_data[train_data.index < val_start_date]
        if len(train_for_val) < 20: return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        sarima_val_pred, prophet_val_pred = sarima_model(train_for_val, h=28), prophet_model_tuned(train_for_val, h=28); sarima_mae, prophet_mae = mean_absolute_error(validation_data['Value'], sarima_val_pred), mean_absolute_error(validation_data['Value'], prophet_val_pred)
        if sarima_mae + prophet_mae == 0: return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        inv_err_s, inv_err_p = (1 / sarima_mae if sarima_mae > 1e-9 else 1e9), (1 / prophet_mae if prophet_mae > 1e-9 else 1e9); total_inv_err = inv_err_s + inv_err_p; w_s, w_p = inv_err_s / total_inv_err, inv_err_p / total_inv_err
        return w_s * sarima_pred + w_p * prophet_pred, w_s, w_p

    # --- 7. Rolling-forecast model evaluation ---
    status_message += "\n\n### 步骤 7: 模型性能评估 (滚动预测)\n"
    # Weekly evaluation anchors, leaving room for one full training window
    # before the first anchor and a 4-week horizon after the last one.
    start_date = ts_data.index.min() + timedelta(days=best_mae_window); end_date = ts_data.index.max() - timedelta(weeks=4); current_dates = pd.date_range(start=start_date, end=end_date, freq='W')
    status_message += f"▶️ 将进行 {len(current_dates)} 次滚动预测,这会是耗时最长的步骤...\n"; yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)
    all_metrics, weight_history = [], []
    for i, date in enumerate(progress.tqdm(current_dates, desc="滚动预测评估")):
        # SARIMA trains on a sliding window, Prophet on an expanding one.
        train_sliding, test_data = dynamic_split(ts_data.to_frame('Value'), date, best_mae_window, False); train_expanding, _ = dynamic_split(ts_data.to_frame('Value'), date, best_mae_window, True); actual_values = test_data['Value']; h = len(actual_values)
        if h == 0: continue
        sarima_pred_sliding, prophet_pred_tuned = sarima_model(train_sliding, h), prophet_model_tuned(train_expanding, h); weighted_pred, w_s, w_p = weighted_average_model(sarima_pred_sliding, prophet_pred_tuned, train_sliding); weight_history.append({'Date': date, 'SARIMA_Weight': w_s, 'Prophet_Weight': w_p}); models = {'SARIMA (Sliding)': sarima_pred_sliding, 'Prophet (Tuned)': prophet_pred_tuned, 'Weighted Average': weighted_pred}
        for name, pred in models.items():
            metrics = calculate_metrics(actual_values, pred); metrics['Model'], metrics['Date'] = name, date; all_metrics.append(metrics)
    metrics_df, weight_history_df = pd.DataFrame(all_metrics), pd.DataFrame(weight_history)
    status_message += "✔️ 滚动预测评估完成."; yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)

    # --- 8. Model-comparison visualizations ---
    status_message += "\n\n### 步骤 8: 模型对比可视化\n▶️ 正在生成对比图表...\n"; progress(0.85, desc="生成对比图表...")
    # 8.1: evolution of the blend weights across the rolling evaluation.
    plt.figure(figsize=(12, 6)); sns.lineplot(data=weight_history_df, x='Date', y='SARIMA_Weight', label='SARIMA Weight'); sns.lineplot(data=weight_history_df, x='Date', y='Prophet_Weight', label='Prophet Weight'); plt.title('Weight Change of Sub-models'); plt.savefig(os.path.join(OUTPUT_DIR, '8_1_weight_change.png')); weight_fig_for_ui = plt.gcf(); plt.close()
    # 8.2: per-metric error curves over time, one facet per metric.
    metrics_long = metrics_df.melt(id_vars=['Date', 'Model'], value_vars=['MAE', 'RMSE', 'MAPE', 'sMAPE'], var_name='Metric', value_name='Value'); g = sns.FacetGrid(metrics_long, col='Metric', hue='Model', col_wrap=2, height=4, aspect=1.5, sharey=False); g.map(sns.lineplot, 'Date', 'Value', marker='o', markersize=4).add_legend(title="Model"); g.fig.suptitle('Comparison of Prediction Error Metrics', y=1.03); g.set_titles("{col_name}"); g.fig.savefig(os.path.join(OUTPUT_DIR, '8_2_error_comparison.png')); error_comp_fig_for_ui = g.fig; plt.close()
    # 8.3: radar chart of per-metric averages, min-max normalized per metric.
    avg_metrics = metrics_df.groupby('Model')[['MAE', 'RMSE', 'MAPE', 'sMAPE']].mean(); avg_metrics_normalized = avg_metrics.apply(lambda x: (x - x.min()) / (x.max() - x.min())).reset_index(); labels = avg_metrics_normalized.columns[1:]; num_vars = len(labels); angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist() + [0]; fig_radar, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True));
    for i, row in avg_metrics_normalized.iterrows(): values = row.drop('Model').tolist() + [row.drop('Model').tolist()[0]]; ax.plot(angles, values, label=row['Model'], linewidth=2); ax.fill(angles, values, alpha=0.2)
    ax.set_xticks(angles[:-1]); ax.set_xticklabels(labels); plt.title('Average Model Performance Radar Chart (Normalized)'); plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1)); fig_radar.savefig(os.path.join(OUTPUT_DIR, '8_3_radar_chart.png')); radar_fig_for_ui = fig_radar; plt.close()
    # 8.4: per-metric error distributions as box plots.
    fig_box, axes = plt.subplots(2, 2, figsize=(16, 10))
    for i, metric in enumerate(['MAE', 'RMSE', 'MAPE', 'sMAPE']): ax = axes[i//2, i%2]; sns.boxplot(data=metrics_df, x='Model', y=metric, hue='Model', ax=ax, palette='viridis', legend=False); ax.set_title(f'{metric} Distribution'); ax.tick_params(axis='x', rotation=30)
    plt.tight_layout(); fig_box.savefig(os.path.join(OUTPUT_DIR, '8_4_error_boxplot.png')); boxplot_fig_for_ui = fig_box; plt.close()
    # 8.5: final 4-week forecast vs actuals, anchored on the last rolling date.
    last_date = current_dates[-1]; train_sliding, test_data = dynamic_split(ts_data.to_frame('Value'), last_date, best_mae_window, False); h = len(test_data); preds = {'SARIMA (Sliding)': sarima_model(train_sliding, h), 'Prophet (Tuned)': prophet_model_tuned(train_sliding, h)}; preds['Weighted Average'], _, _ = weighted_average_model(preds['SARIMA (Sliding)'], preds['Prophet (Tuned)'], train_sliding); fig_final, ax_final = plt.subplots(figsize=(16, 8)); ax_final.plot(test_data.index, test_data['Value'], label='Actual Value', color='black', linewidth=2.5, marker='o')
    for name, pred in preds.items(): ax_final.plot(test_data.index, pred.values, label=name, linestyle='--')
    ax_final.set_title(f'Final 4-Week Forecast Comparison (Start Date: {last_date.date()})'); ax_final.legend(); fig_final.savefig(os.path.join(OUTPUT_DIR, '8_5_final_forecast.png')); final_fig_for_ui = fig_final; plt.close()
    status_message += "✔️ 所有对比图表已生成."
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui, weight_plot=weight_fig_for_ui, error_comp_plot=error_comp_fig_for_ui, radar_plot=radar_fig_for_ui, box_plot=boxplot_fig_for_ui, final_plot=final_fig_for_ui)

    # --- 9. Final performance summary + downloadable archive ---
    status_message += "\n\n### 步骤 9: 最终性能总结\n"; progress(0.95, desc="生成最终报告...")
    avg_perf = metrics_df.groupby('Model')[['MAE', 'RMSE', 'MAPE', 'sMAPE']].mean().sort_values('MAE'); status_message += "✔️ 平均性能计算完成。\n"
    summary_path = os.path.join(OUTPUT_DIR, 'performance_summary.txt')
    with open(summary_path, 'w', encoding='utf-8') as f: f.write("Time Series Forecasting Model Comparison Report\n" + "="*50 + "\n\n"); f.write(f"1. Optimal Sliding Window Size: {best_mae_window} days\n\n"); f.write(f"2. Stationarity: Stationary after {d_order} order(s) of differencing.\n\n"); f.write(f"3. Seasonality: Weekly seasonal differencing D(m=7) = {D_weekly}.\n\n"); f.write("4. Average Model Performance (sorted by MAE):\n" + avg_perf.to_string(float_format="%.4f"))
    status_message += f"✔️ 性能总结报告已保存到 {summary_path}\n"
    zip_archive_path = create_zip_archive(); status_message += f"✔️ 所有结果已打包到 {zip_archive_path}。"
    progress(1, desc="分析完成!"); status_message += "\n\n## ✅ 分析流程执行完毕!"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui, weight_plot=weight_fig_for_ui, error_comp_plot=error_comp_fig_for_ui, radar_plot=radar_fig_for_ui, box_plot=boxplot_fig_for_ui, final_plot=final_fig_for_ui, summary_df=avg_perf.reset_index(), download_file=zip_archive_path)
|
| 260 |
+
|
| 261 |
+
# ======================== Step 4: Build the Gradio UI ========================
with gr.Blocks(theme=gr.themes.Soft(), title="时间序列预测分析平台") as app:
    gr.Markdown("# 📈 时间序列预测分析平台"); gr.Markdown("上传您的时间序列数据 (XLSX格式,包含 'Date' 和 'Value' 两列),然后点击“开始分析”按钮。系统将自动完成数据清洗、检验、模型训练、评估和对比的全过程。")
    with gr.Row():
        # Left column: input controls plus the streaming analysis log.
        # NOTE(review): value="gmqrkl.xlsx" pre-selects a local sample file —
        # confirm it ships with the app, otherwise the default is broken.
        with gr.Column(scale=1):
            file_input = gr.File(label="上传 XLSX 数据文件", file_types=['.xlsx'], value="gmqrkl.xlsx"); run_button = gr.Button("🚀 开始分析", variant="primary"); gr.Markdown("---"); gr.Markdown("### 📝 分析日志与进度"); status_display = gr.Markdown("请上传文件并点击开始...")
        # Right column: tabbed result plots plus the summary table/download.
        with gr.Column(scale=3):
            gr.Markdown("### 📊 分析结果可视化")
            with gr.Tabs():
                with gr.TabItem("初步分析"): plot_stl = gr.Plot(label="STL 季节性分解")
                with gr.TabItem("优化窗口"): plot_window = gr.Plot(label="最优窗口大小分析")
                with gr.TabItem("模型性能对比"): plot_error_comp = gr.Plot(label="各模型误差指标对比"); plot_boxplot = gr.Plot(label="误差分布箱线图")
                with gr.TabItem("综合评估"): plot_radar = gr.Plot(label="模型平均性能雷达图"); plot_weights = gr.Plot(label="组合模型权重变化")
                with gr.TabItem("最终预测"): plot_final_forecast = gr.Plot(label="最后四周预测对比图")
            gr.Markdown("---"); gr.Markdown("### 🏆 最终性能总结"); summary_table = gr.DataFrame(label="各模型平均性能对比 (按 MAE 升序)"); download_results = gr.File(label="下载全部分析结果 (ZIP)", visible=False)

    # These components are the keys of the update-dicts yielded by
    # run_analysis/update_ui; the list must contain exactly those components.
    outputs_list = [status_display, plot_stl, plot_window, plot_weights, plot_error_comp, plot_radar, plot_boxplot, plot_final_forecast, summary_table, download_results]
    run_button.click(fn=run_analysis, inputs=[file_input], outputs=outputs_list)

if __name__ == "__main__":
    app.launch()
|