# PredictDrugNew / app.py
# Uploaded by leonsimon23 ("Create app.py", commit 4b86963, verified)
# app.py (Final Corrected Version for Parallel Processing)
# ======================== 步骤 1: 导入所有库 ========================
import gradio as gr
import pandas as pd
import numpy as np
from scipy import stats
import warnings
import os
import logging
from datetime import timedelta
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FuncFormatter
import shutil
import zipfile
# 时间序列核心库
from statsmodels.tsa.stattools import adfuller
from statsmodels.stats.diagnostic import acorr_ljungbox
from statsmodels.tsa.seasonal import STL
import pmdarima as pm
from prophet import Prophet
# 机器学习和评估
from sklearn.metrics import mean_absolute_error, mean_squared_error
from joblib import Parallel, delayed
# --- 全局设置 ---
warnings.filterwarnings("ignore")
logging.getLogger('prophet').setLevel(logging.ERROR)
logging.getLogger('cmdstanpy').setLevel(logging.ERROR)
plt.rcParams['axes.unicode_minus'] = False # 确保负号可以正常显示
# --- 输出文件夹设置 ---
OUTPUT_DIR = 'analysis_output'
# ======================== 步骤 2: 定义辅助函数 ========================
def calculate_metrics(actual, predicted):
    """Compute MAE, RMSE, MAPE and sMAPE for a pair of aligned series.

    Rows where either value is NaN are dropped before scoring.  MAE and
    RMSE are computed directly with numpy (equivalent to the sklearn
    helpers, but keeps this utility dependency-free).

    Parameters
    ----------
    actual, predicted : array-like
        Observed and forecast values of equal length.

    Returns
    -------
    dict
        Keys ``'MAE'``, ``'RMSE'``, ``'MAPE'``, ``'sMAPE'`` (the last two
        in percent).  All values are NaN when no valid pairs remain.
    """
    metrics_df = pd.DataFrame({'actual': actual, 'predicted': predicted}).dropna()
    if metrics_df.empty:
        return {'MAE': np.nan, 'RMSE': np.nan, 'MAPE': np.nan, 'sMAPE': np.nan}
    a = metrics_df['actual'].to_numpy(dtype=float)
    p = metrics_df['predicted'].to_numpy(dtype=float)
    abs_err = np.abs(a - p)
    mae = abs_err.mean()
    rmse = np.sqrt(np.mean((a - p) ** 2))
    # Guard the MAPE denominator: a zero actual would divide by zero.
    actual_safe = np.where(a == 0, 1e-6, a)
    mape = np.mean(np.abs((a - p) / actual_safe)) * 100
    # Same guard for sMAPE when both actual and predicted are zero.
    denominator = np.abs(a) + np.abs(p)
    denominator_safe = np.where(denominator == 0, 1e-6, denominator)
    smape = 200 * np.mean(abs_err / denominator_safe)
    return {'MAE': mae, 'RMSE': rmse, 'MAPE': mape, 'sMAPE': smape}
def dynamic_split(data, current_date, window_size, expanding=False):
    """Split *data* at *current_date* into train and 4-week test slices.

    With ``expanding=True`` the training slice starts at the first index
    of *data*; otherwise it is a sliding window of *window_size* days
    ending at *current_date* (inclusive).  The test slice covers the 28
    days immediately after *current_date*.
    """
    if expanding:
        lower_bound = data.index.min()
    else:
        lower_bound = current_date - timedelta(days=window_size - 1)
    in_train = (data.index >= lower_bound) & (data.index <= current_date)
    train_data = data[in_train]
    first_test_day = current_date + timedelta(days=1)
    last_test_day = current_date + timedelta(weeks=4)
    in_test = (data.index >= first_test_day) & (data.index <= last_test_day)
    test_data = data[in_test]
    return train_data, test_data
def create_zip_archive(output_dir=None):
    """Pack every file under *output_dir* into ``<output_dir>.zip``.

    Parameters
    ----------
    output_dir : str, optional
        Directory to archive.  Defaults to the module-level
        ``OUTPUT_DIR``, preserving the original no-argument behavior.

    Returns
    -------
    str
        Path of the created zip file.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR
    zip_path = f"{output_dir}.zip"
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _, files in os.walk(output_dir):
            for file in files:
                full_path = os.path.join(root, file)
                # Store paths relative to output_dir so the archive has no
                # leading directory component.
                zipf.write(full_path, os.path.relpath(full_path, output_dir))
    return zip_path
# Defined at module level (not nested) so joblib can pickle it for workers.
def _parallel_evaluate_window(window_size, ts_values, d_order):
    """Score one candidate training-window length for joblib workers.

    Walks forward through at most the first 365 observations, fitting a
    non-seasonal auto-ARIMA on each window and recording the one-step-ahead
    forecast error.  Returns the mean absolute error, or ``np.inf`` when
    the window is too large or no forecast could be produced.
    """
    total = len(ts_values)
    if window_size >= total:
        return np.inf
    horizon = min(total, 365)
    one_step_errors = []
    for start in range(horizon - window_size):
        history = ts_values[start:start + window_size]
        target = ts_values[start + window_size]
        try:
            fitted = pm.auto_arima(
                history, d=d_order, seasonal=False, stepwise=True,
                suppress_warnings=True, error_action='ignore')
            one_step_errors.append(target - fitted.predict(n_periods=1)[0])
        except Exception:
            # A window auto_arima cannot fit simply contributes nothing.
            continue
    if not one_step_errors:
        return np.inf
    return np.mean(np.abs(one_step_errors))
# ======================== 步骤 3: 核心分析函数 ========================
def run_analysis(uploaded_file, progress=gr.Progress(track_tqdm=True)):
    """Run the full forecasting pipeline as a Gradio generator.

    Yields dicts of ``gr.update`` objects (built by the inner ``update_ui``)
    so the UI refreshes after each major step.  Pipeline: data cleaning ->
    stationarity / white-noise tests -> STL decomposition -> parallel
    window-size optimisation -> SARIMA / Prophet / weighted rolling
    evaluation -> comparison charts -> summary report + zip download.

    NOTE(review): ``update_ui`` closes over Gradio components
    (``status_display``, ``plot_stl``, ...) that are created at module
    level in the UI-construction section further down this file.
    """
    # --- Initialisation: start fresh by wiping any previous output dir ---
    progress(0, desc="开始分析流程...")
    if os.path.exists(OUTPUT_DIR): shutil.rmtree(OUTPUT_DIR)
    os.makedirs(OUTPUT_DIR)
    def update_ui(status, stl_plot=None, window_plot=None, weight_plot=None, error_comp_plot=None, radar_plot=None, box_plot=None, final_plot=None, summary_df=None, download_file=None):
        # Map every output component to its new value; the download link
        # only becomes visible once a file path is supplied.
        return {
            status_display: gr.update(value=status), plot_stl: gr.update(value=stl_plot),
            plot_window: gr.update(value=window_plot), plot_weights: gr.update(value=weight_plot),
            plot_error_comp: gr.update(value=error_comp_plot), plot_radar: gr.update(value=radar_plot),
            plot_boxplot: gr.update(value=box_plot), plot_final_forecast: gr.update(value=final_plot),
            summary_table: gr.update(value=summary_df), download_results: gr.update(value=download_file, visible=download_file is not None)
        }
    yield update_ui("分析已开始,正在准备环境...")
    # --- 1. Data cleaning ---
    status_message = "### 步骤 1: 数据清洗\n"
    try:
        progress(0.05, desc="正在读取和清洗数据...")
        yield update_ui(status_message + "▶️ 正在读取上传的 Excel 文件...")
        df = pd.read_excel(uploaded_file.name)
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.drop_duplicates(subset=['Date']).sort_values('Date').reset_index(drop=True)
        status_message += f"✔️ 数据读取成功,共 {len(df)} 行,时间范围: {df['Date'].min().date()}{df['Date'].max().date()}.\n"
        yield update_ui(status_message)
    except Exception as e:
        gr.Error(f"文件读取失败: {e}. 请确保上传的 XLSX 文件包含 'Date' 和 'Value' 列。"); return
    # Zeros are treated as missing values, then filled by interpolation.
    df['Value'] = df['Value'].replace(0, np.nan).interpolate(method='linear', limit_direction='both')
    status_message += "✔️ 零值替换与线性插值完成。\n"
    # Outliers (|z| > 3) are only counted and reported, never removed.
    z_scores = np.abs(stats.zscore(df['Value'])); outliers = df[z_scores > 3]
    status_message += f"✔️ Z-score 异常值检测完成,发现 {len(outliers)} 个潜在异常值。\n"
    yield update_ui(status_message)
    df.set_index('Date', inplace=True)
    # Force a daily frequency; gaps created by asfreq are interpolated.
    ts_data = df['Value'].asfreq('D').interpolate(method='linear')
    # --- 2. Stationarity test and differencing ---
    status_message += "\n### 步骤 2: 平稳性检验\n"
    progress(0.1, desc="进行平稳性检验...")
    def make_stationary(data):
        # Difference up to 3 times until the ADF test rejects a unit root
        # (p < 0.05); returns (stationary series, order used, log text).
        current_data, log = data.copy(), ""
        for d in range(4):
            p_value = adfuller(current_data.dropna())[1]
            log += f"差分阶数 d={d}, ADF检验 p-value: {p_value:.4f}\n"
            if p_value < 0.05:
                log += f"✔️ 序列在 d={d} 阶差分后达到平稳."; return current_data.dropna(), d, log
            current_data = current_data.diff()
        log += f"⚠️ 警告:在最大差分阶数内未达到平稳."; return current_data.dropna(), 3, log
    ts_stationary, d_order, stationarity_log = make_stationary(ts_data)
    status_message += stationarity_log; yield update_ui(status_message)
    # --- 3. White-noise (Ljung-Box) test ---
    status_message += "\n\n### 步骤 3: 白噪声检验\n"; progress(0.15, desc="进行白噪声检验...")
    lags = min(10, len(ts_stationary) // 5)
    lb_p_value = acorr_ljungbox(ts_stationary, lags=[lags], return_df=True)['lb_pvalue'].iloc[0]
    # p <= 0.05 => autocorrelation present, i.e. the series is NOT white noise.
    status_message += f"✔️ 通过白噪声检验 (p-value = {lb_p_value:.4f}),序列为非白噪声。" if lb_p_value <= 0.05 else f"⚠️ 序列为白噪声 (p-value = {lb_p_value:.4f}),可能不适合复杂建模。"
    yield update_ui(status_message)
    # --- 4. Seasonality analysis (preliminary) ---
    status_message += "\n\n### 步骤 4: 季节性分析 (初步分析)\n"; progress(0.2, desc="进行季节性分解...")
    # Canova-Hansen test decides the weekly seasonal differencing order D.
    D_weekly = pm.arima.nsdiffs(ts_data, m=7, test='ch')
    status_message += f"✔️ 每周季节性差分阶数 D(m=7): {D_weekly} (通过 Canova-Hansen 测试确定)。\n"
    # NOTE(review): STL.plot() draws on its own figure; plt.gcf() below
    # presumably grabs that figure for the UI — confirm against matplotlib.
    plt.figure(figsize=(14, 10)); stl_year = STL(ts_data, period=365).fit(); stl_year.plot()
    plt.suptitle('STL Decomposition (Annual Seasonality)', y=1.02); plt.savefig(os.path.join(OUTPUT_DIR, '4_stl_decomposition.png'))
    stl_fig_for_ui = plt.gcf(); plt.close()
    status_message += "✔️ STL分解图已生成."; yield update_ui(status_message, stl_plot=stl_fig_for_ui)
    # --- 5. Window-size optimisation (parallel) ---
    status_message += "\n\n### 步骤 5: 窗口大小优化\n"
    status_message += "▶️ **正在启动多进程并行计算**以评估不同窗口大小的MAE。\n▶️ 下方的进度条将显示总体进度。计算完成后会显示详细结果。\n"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui)
    # Candidate sliding windows: 70..210 days in two-week steps.
    ts_values = ts_data.values; window_sizes = range(70, 211, 14)
    tasks = [(ws, ts_values, d_order) for ws in window_sizes]
    # Unpack each (window_size, values, d_order) tuple with *task so the
    # module-level worker stays picklable for joblib's process pool.
    results_maes = Parallel(n_jobs=-1)(
        delayed(_parallel_evaluate_window)(*task) for task in progress.tqdm(tasks, desc="优化窗口大小")
    )
    results_df = pd.DataFrame({'window_size': window_sizes, 'mae': results_maes})
    status_message += "---\n✔️ 并行计算完成,结果如下:\n"
    for _, row in results_df.iterrows():
        status_message += f"窗口大小: {int(row['window_size']):<4} -> MAE: {row['mae']:.4f}\n"
    status_message += "---\n"
    if results_df.empty or results_df['mae'].isna().all():
        gr.Error("窗口优化失败,所有窗口大小均未产出有效MAE。"); return
    best_mae_window = int(results_df.loc[results_df['mae'].idxmin()]['window_size'])
    plt.figure(figsize=(10, 6)); sns.lineplot(data=results_df, x='window_size', y='mae', marker='o')
    plt.title('Effect of Window Size on Prediction Error (MAE)'); plt.xlabel('Training Window Days'); plt.ylabel('Mean Absolute Error (MAE)')
    plt.grid(True); plt.savefig(os.path.join(OUTPUT_DIR, '5_window_size_optimization.png'))
    window_fig_for_ui = plt.gcf(); plt.close()
    status_message += f"✔️ 窗口大小优化完成。最优滑动窗口大小: **{best_mae_window}** 天。"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)
    # --- 6. Model definitions ---
    status_message += "\n\n### 步骤 6: 定义所有模型\n✔️ SARIMA, Prophet, 加权平均模型已定义。"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)
    def sarima_model(train_data, h=28):
        # Weekly-seasonal auto-ARIMA; d and D were fixed by the tests above.
        return pm.auto_arima(train_data['Value'], d=d_order, D=D_weekly, m=7, seasonal=True, stepwise=True, suppress_warnings=True, error_action='ignore', max_p=3, max_q=3).predict(n_periods=h)
    def prophet_model_default(train_data, h=28):
        # Prophet with default priors, forecasting h days ahead.
        df_prophet = train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'}); model = Prophet(yearly_seasonality=True, weekly_seasonality=True).fit(df_prophet); future = model.make_future_dataframe(periods=h, freq='D'); return model.predict(future)['yhat'].tail(h)
    def prophet_model_tuned(train_data, h=28):
        # Small grid search over Prophet priors, validated on the last 28
        # days of the training data; falls back to defaults on short input.
        df_prophet = train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'}); param_grid = {'changepoint_prior_scale': [0.01, 0.1, 0.5], 'seasonality_prior_scale': [1.0, 5.0, 10.0]}; best_mae, best_params = float('inf'), {}; train_len = len(df_prophet); split_point = train_len - 28
        if split_point < 10: return prophet_model_default(train_data, h)
        train_sub, val_sub = df_prophet.iloc[:split_point], df_prophet.iloc[split_point:]
        for cps in param_grid['changepoint_prior_scale']:
            for sps in param_grid['seasonality_prior_scale']:
                m = Prophet(changepoint_prior_scale=cps, seasonality_prior_scale=sps).fit(train_sub); fc_val = m.predict(m.make_future_dataframe(periods=len(val_sub))).tail(len(val_sub)); mae = mean_absolute_error(val_sub['y'], fc_val['yhat'])
                if mae < best_mae: best_mae, best_params = mae, {'changepoint_prior_scale': cps, 'seasonality_prior_scale': sps}
        final_model = Prophet(**best_params).fit(df_prophet); return final_model.predict(final_model.make_future_dataframe(periods=h))['yhat'].tail(h)
    def weighted_average_model(sarima_pred, prophet_pred, train_data):
        # Blend the two forecasts, weighting each by its inverse MAE on a
        # 4-week hold-out taken from the end of the training window; falls
        # back to a 50/50 average when there is too little data.
        val_end_date = train_data.index.max(); val_start_date = val_end_date - timedelta(weeks=4) + timedelta(days=1); validation_data = train_data[train_data.index >= val_start_date]
        if len(validation_data) < 28: return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        train_for_val = train_data[train_data.index < val_start_date]
        if len(train_for_val) < 20: return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        sarima_val_pred, prophet_val_pred = sarima_model(train_for_val, h=28), prophet_model_tuned(train_for_val, h=28); sarima_mae, prophet_mae = mean_absolute_error(validation_data['Value'], sarima_val_pred), mean_absolute_error(validation_data['Value'], prophet_val_pred)
        if sarima_mae + prophet_mae == 0: return (sarima_pred + prophet_pred) / 2, 0.5, 0.5
        inv_err_s, inv_err_p = (1 / sarima_mae if sarima_mae > 1e-9 else 1e9), (1 / prophet_mae if prophet_mae > 1e-9 else 1e9); total_inv_err = inv_err_s + inv_err_p; w_s, w_p = inv_err_s / total_inv_err, inv_err_p / total_inv_err
        return w_s * sarima_pred + w_p * prophet_pred, w_s, w_p
    # --- 7. Model evaluation (rolling forecast) ---
    status_message += "\n\n### 步骤 7: 模型性能评估 (滚动预测)\n"
    # One forecast origin per week, leaving 4 weeks of actuals at the end.
    start_date = ts_data.index.min() + timedelta(days=best_mae_window); end_date = ts_data.index.max() - timedelta(weeks=4); current_dates = pd.date_range(start=start_date, end=end_date, freq='W')
    status_message += f"▶️ 将进行 {len(current_dates)} 次滚动预测,这会是耗时最长的步骤...\n"; yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)
    all_metrics, weight_history = [], []
    for i, date in enumerate(progress.tqdm(current_dates, desc="滚动预测评估")):
        # SARIMA trains on the sliding window, Prophet on the expanding one.
        train_sliding, test_data = dynamic_split(ts_data.to_frame('Value'), date, best_mae_window, False); train_expanding, _ = dynamic_split(ts_data.to_frame('Value'), date, best_mae_window, True); actual_values = test_data['Value']; h = len(actual_values)
        if h == 0: continue
        sarima_pred_sliding, prophet_pred_tuned = sarima_model(train_sliding, h), prophet_model_tuned(train_expanding, h); weighted_pred, w_s, w_p = weighted_average_model(sarima_pred_sliding, prophet_pred_tuned, train_sliding); weight_history.append({'Date': date, 'SARIMA_Weight': w_s, 'Prophet_Weight': w_p}); models = {'SARIMA (Sliding)': sarima_pred_sliding, 'Prophet (Tuned)': prophet_pred_tuned, 'Weighted Average': weighted_pred}
        for name, pred in models.items():
            metrics = calculate_metrics(actual_values, pred); metrics['Model'], metrics['Date'] = name, date; all_metrics.append(metrics)
    metrics_df, weight_history_df = pd.DataFrame(all_metrics), pd.DataFrame(weight_history)
    status_message += "✔️ 滚动预测评估完成."; yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui)
    # --- 8. Model-comparison visualisations ---
    status_message += "\n\n### 步骤 8: 模型对比可视化\n▶️ 正在生成对比图表...\n"; progress(0.85, desc="生成对比图表...")
    # 8.1 Sub-model weights over time.
    plt.figure(figsize=(12, 6)); sns.lineplot(data=weight_history_df, x='Date', y='SARIMA_Weight', label='SARIMA Weight'); sns.lineplot(data=weight_history_df, x='Date', y='Prophet_Weight', label='Prophet Weight'); plt.title('Weight Change of Sub-models'); plt.savefig(os.path.join(OUTPUT_DIR, '8_1_weight_change.png')); weight_fig_for_ui = plt.gcf(); plt.close()
    # 8.2 Per-metric error curves, one facet per metric.
    metrics_long = metrics_df.melt(id_vars=['Date', 'Model'], value_vars=['MAE', 'RMSE', 'MAPE', 'sMAPE'], var_name='Metric', value_name='Value'); g = sns.FacetGrid(metrics_long, col='Metric', hue='Model', col_wrap=2, height=4, aspect=1.5, sharey=False); g.map(sns.lineplot, 'Date', 'Value', marker='o', markersize=4).add_legend(title="Model"); g.fig.suptitle('Comparison of Prediction Error Metrics', y=1.03); g.set_titles("{col_name}"); g.fig.savefig(os.path.join(OUTPUT_DIR, '8_2_error_comparison.png')); error_comp_fig_for_ui = g.fig; plt.close()
    # 8.3 Radar chart of min-max-normalised average metrics (angle list is
    # closed by appending the first angle again).
    avg_metrics = metrics_df.groupby('Model')[['MAE', 'RMSE', 'MAPE', 'sMAPE']].mean(); avg_metrics_normalized = avg_metrics.apply(lambda x: (x - x.min()) / (x.max() - x.min())).reset_index(); labels = avg_metrics_normalized.columns[1:]; num_vars = len(labels); angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist() + [0]; fig_radar, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True));
    for i, row in avg_metrics_normalized.iterrows(): values = row.drop('Model').tolist() + [row.drop('Model').tolist()[0]]; ax.plot(angles, values, label=row['Model'], linewidth=2); ax.fill(angles, values, alpha=0.2)
    ax.set_xticks(angles[:-1]); ax.set_xticklabels(labels); plt.title('Average Model Performance Radar Chart (Normalized)'); plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1)); fig_radar.savefig(os.path.join(OUTPUT_DIR, '8_3_radar_chart.png')); radar_fig_for_ui = fig_radar; plt.close()
    # 8.4 Box plots of the per-origin error distributions.
    fig_box, axes = plt.subplots(2, 2, figsize=(16, 10))
    for i, metric in enumerate(['MAE', 'RMSE', 'MAPE', 'sMAPE']): ax = axes[i//2, i%2]; sns.boxplot(data=metrics_df, x='Model', y=metric, hue='Model', ax=ax, palette='viridis', legend=False); ax.set_title(f'{metric} Distribution'); ax.tick_params(axis='x', rotation=30)
    plt.tight_layout(); fig_box.savefig(os.path.join(OUTPUT_DIR, '8_4_error_boxplot.png')); boxplot_fig_for_ui = fig_box; plt.close()
    # 8.5 Final 4-week forecast from the last rolling origin.
    last_date = current_dates[-1]; train_sliding, test_data = dynamic_split(ts_data.to_frame('Value'), last_date, best_mae_window, False); h = len(test_data); preds = {'SARIMA (Sliding)': sarima_model(train_sliding, h), 'Prophet (Tuned)': prophet_model_tuned(train_sliding, h)}; preds['Weighted Average'], _, _ = weighted_average_model(preds['SARIMA (Sliding)'], preds['Prophet (Tuned)'], train_sliding); fig_final, ax_final = plt.subplots(figsize=(16, 8)); ax_final.plot(test_data.index, test_data['Value'], label='Actual Value', color='black', linewidth=2.5, marker='o')
    for name, pred in preds.items(): ax_final.plot(test_data.index, pred.values, label=name, linestyle='--')
    ax_final.set_title(f'Final 4-Week Forecast Comparison (Start Date: {last_date.date()})'); ax_final.legend(); fig_final.savefig(os.path.join(OUTPUT_DIR, '8_5_final_forecast.png')); final_fig_for_ui = fig_final; plt.close()
    status_message += "✔️ 所有对比图表已生成."
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui, weight_plot=weight_fig_for_ui, error_comp_plot=error_comp_fig_for_ui, radar_plot=radar_fig_for_ui, box_plot=boxplot_fig_for_ui, final_plot=final_fig_for_ui)
    # --- 9. Final performance summary ---
    status_message += "\n\n### 步骤 9: 最终性能总结\n"; progress(0.95, desc="生成最终报告...")
    avg_perf = metrics_df.groupby('Model')[['MAE', 'RMSE', 'MAPE', 'sMAPE']].mean().sort_values('MAE'); status_message += "✔️ 平均性能计算完成。\n"
    summary_path = os.path.join(OUTPUT_DIR, 'performance_summary.txt')
    with open(summary_path, 'w', encoding='utf-8') as f: f.write("Time Series Forecasting Model Comparison Report\n" + "="*50 + "\n\n"); f.write(f"1. Optimal Sliding Window Size: {best_mae_window} days\n\n"); f.write(f"2. Stationarity: Stationary after {d_order} order(s) of differencing.\n\n"); f.write(f"3. Seasonality: Weekly seasonal differencing D(m=7) = {D_weekly}.\n\n"); f.write("4. Average Model Performance (sorted by MAE):\n" + avg_perf.to_string(float_format="%.4f"))
    status_message += f"✔️ 性能总结报告已保存到 {summary_path}\n"
    # Bundle every artefact for the download link.
    zip_archive_path = create_zip_archive(); status_message += f"✔️ 所有结果已打包到 {zip_archive_path}。"
    progress(1, desc="分析完成!"); status_message += "\n\n## ✅ 分析流程执行完毕!"
    yield update_ui(status_message, stl_plot=stl_fig_for_ui, window_plot=window_fig_for_ui, weight_plot=weight_fig_for_ui, error_comp_plot=error_comp_fig_for_ui, radar_plot=radar_fig_for_ui, box_plot=boxplot_fig_for_ui, final_plot=final_fig_for_ui, summary_df=avg_perf.reset_index(), download_file=zip_archive_path)
# ======================== 步骤 4: 构建 Gradio 界面 ========================
# Build the Gradio interface: controls + live log on the left, tabbed
# result plots on the right, summary table and download link below.
with gr.Blocks(theme=gr.themes.Soft(), title="时间序列预测分析平台") as app:
    gr.Markdown("# 📈 时间序列预测分析平台")
    gr.Markdown("上传您的时间序列数据 (XLSX格式,包含 'Date' 和 'Value' 两列),然后点击“开始分析”按钮。系统将自动完成数据清洗、检验、模型训练、评估和对比的全过程。")
    with gr.Row():
        with gr.Column(scale=1):
            # Upload, run button, and the streaming analysis log.
            file_input = gr.File(label="上传 XLSX 数据文件", file_types=['.xlsx'], value="gmqrkl.xlsx")
            run_button = gr.Button("🚀 开始分析", variant="primary")
            gr.Markdown("---")
            gr.Markdown("### 📝 分析日志与进度")
            status_display = gr.Markdown("请上传文件并点击开始...")
        with gr.Column(scale=3):
            gr.Markdown("### 📊 分析结果可视化")
            with gr.Tabs():
                with gr.TabItem("初步分析"):
                    plot_stl = gr.Plot(label="STL 季节性分解")
                with gr.TabItem("优化窗口"):
                    plot_window = gr.Plot(label="最优窗口大小分析")
                with gr.TabItem("模型性能对比"):
                    plot_error_comp = gr.Plot(label="各模型误差指标对比")
                    plot_boxplot = gr.Plot(label="误差分布箱线图")
                with gr.TabItem("综合评估"):
                    plot_radar = gr.Plot(label="模型平均性能雷达图")
                    plot_weights = gr.Plot(label="组合模型权重变化")
                with gr.TabItem("最终预测"):
                    plot_final_forecast = gr.Plot(label="最后四周预测对比图")
    gr.Markdown("---")
    gr.Markdown("### 🏆 最终性能总结")
    summary_table = gr.DataFrame(label="各模型平均性能对比 (按 MAE 升序)")
    download_results = gr.File(label="下载全部分析结果 (ZIP)", visible=False)
    # Component order here must match the update dict built by run_analysis.
    outputs_list = [status_display, plot_stl, plot_window, plot_weights, plot_error_comp, plot_radar, plot_boxplot, plot_final_forecast, summary_table, download_results]
    run_button.click(fn=run_analysis, inputs=[file_input], outputs=outputs_list)
if __name__ == "__main__":
    app.launch()