leonsimon23 commited on
Commit
abb49ca
·
verified ·
1 Parent(s): d324713

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -33
app.py CHANGED
@@ -16,82 +16,127 @@ from sklearn.metrics import mean_absolute_error, mean_squared_error
16
  import seaborn as sns
17
  from joblib import Parallel, delayed
18
 
 
19
  warnings.filterwarnings('ignore')
20
 
21
  # --- 在Hugging Face上配置中文字体 ---
22
  # 确保你已经上传了 SimHei.ttf 字体文件
23
-
24
- #font_path = 'SimHei.ttf'
25
- #try:
26
- # matplotlib.font_manager.fontManager.addfont(font_path)
27
- # plt.rcParams['font.sans-serif'] = ['SimHei']
28
- # plt.rcParams['axes.unicode_minus'] = False
29
- #except Exception as e:
30
- # print(f"加载中文字体失败: {e}")
31
- # 如果失败,Gradio界面中的图表中文可能显示为方块
32
-
33
- # 替换为下面这行,更简单稳定:
34
  plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei']
35
- plt.rcParams['axes.unicode_minus'] = False
36
 
37
  # =============================================================================
38
  # 将所有分析步骤封装在一个主函数中
39
  # =============================================================================
40
  def run_analysis(file_obj):
 
 
 
 
 
 
 
 
 
 
 
 
41
  if file_obj is None:
42
  raise gr.Error("请先上传一个Excel文件!")
43
 
 
 
44
  # ----- 1. 数据预处理 -----
 
 
45
  df = pd.read_excel(file_obj.name)
 
46
  df['Date'] = pd.to_datetime(df['Date'])
 
47
  df.sort_values('Date', inplace=True)
 
48
  df.set_index('Date', inplace=True)
 
49
  df['Value'] = df['Value'].replace(0, np.nan)
 
50
  df['Value'].interpolate(method='linear', inplace=True)
 
51
 
52
  # ----- 2. 平稳性与白噪声检验 (仅输出结果) -----
 
 
53
  ts_stationary, d_order = make_stationary_silent(df['Value'])
 
54
  lb_p_value = acorr_ljungbox(ts_stationary, lags=[10], return_df=True)['lb_pvalue'].iloc[0]
 
55
 
56
  # ----- 3. 季节性检验 -----
57
- D_order = nsdiffs(df['Value'], m=7, test='ch')
 
 
 
58
 
59
  # STL分解图
 
60
  stl = STL(df['Value'], period=7, robust=True)
61
  res = stl.fit()
62
  fig_stl = res.plot()
63
  fig_stl.suptitle("STL季节性分解 (周期=7)", y=1.02)
 
64
 
65
  # ----- 4. 窗口优化 (简化版,避免超时) -----
 
66
  # 在线上环境中,完整的窗口优化太慢,这里使用一个固定的值或一个快速的估计
67
- # 为了演示,我们仍然运行一个简化版的
68
  ts_values = df['Value'].values
69
- window_sizes = range(100, 181, 20) # 减少评估点
 
70
  results = Parallel(n_jobs=-1)(delayed(evaluate_window)(ws, ts_values) for ws in window_sizes)
71
  results_df = pd.DataFrame(results)
 
72
  OPTIMAL_WINDOW = int(results_df.loc[results_df['mae'].idxmin()]['window_size']) if not results_df.empty else 120
 
73
 
74
  # ----- 5. 模型预测 -----
75
- final_train_end_date = df.index.max() - timedelta(days=28)
 
 
76
  final_test_start_date = final_train_end_date + timedelta(days=1)
 
 
77
  train_sarima = df[final_train_end_date - timedelta(days=OPTIMAL_WINDOW - 1) : final_train_end_date]
 
78
  train_full = df[:final_train_end_date]
 
79
  test_data = df[final_test_start_date:]
 
80
 
 
81
  sarima_pred = sarima_model(train_sarima, d_order, D_order, h=28)
 
 
 
82
  prophet_pred = prophet_model(train_full, h=28)
 
 
 
83
  weighted_pred, sw, pw = weighted_average_model(train_full, d_order, D_order, h=28)
 
84
 
 
85
  predictions = pd.DataFrame({
86
  'Actual': test_data['Value'],
87
  'SARIMA': sarima_pred,
88
  'Prophet': prophet_pred,
89
  'Weighted': weighted_pred
90
- }).dropna()
91
 
92
  # ----- 6. 指标计算与可视化 -----
 
 
93
  metrics = {model: calculate_metrics(predictions['Actual'], predictions[model]) for model in ['SARIMA', 'Prophet', 'Weighted']}
94
  metrics_df = pd.DataFrame(metrics).T.reset_index().rename(columns={'index': 'Model'})
 
95
 
96
  # 4周预测对比图
97
  fig_forecast_4w = plt.figure(figsize=(12, 6))
@@ -101,9 +146,10 @@ def run_analysis(file_obj):
101
  plt.plot(predictions.index, predictions['Weighted'], label=f'加权平均 (S:{sw:.2f}, P:{pw:.2f})', color='green', linestyle=':')
102
  plt.title('未来4周预测对比', fontsize=16)
103
  plt.legend(); plt.grid(True)
 
104
 
105
  # 1周预测对比图和指标
106
- first_week_preds = predictions.head(7)
107
  first_week_metrics = pd.DataFrame({
108
  model: calculate_metrics(first_week_preds['Actual'], first_week_preds[model])
109
  for model in ['SARIMA', 'Prophet', 'Weighted']
@@ -115,7 +161,9 @@ def run_analysis(file_obj):
115
  plt.plot(first_week_preds.index, first_week_preds['Prophet'], label='Prophet预测', color='blue', linestyle='-.')
116
  plt.plot(first_week_preds.index, first_week_preds['Weighted'], label='加权平均预测', color='green', linestyle=':')
117
  plt.title('第一周预测结果对比'); plt.legend(); plt.grid(True)
 
118
 
 
119
  summary_text = (
120
  f"数据加载成功,共 {len(df)} 条记录。\n"
121
  f"最优差分阶数 d = {d_order}, 季节差分阶数 D = {D_order}\n"
@@ -123,6 +171,7 @@ def run_analysis(file_obj):
123
  f"计算得到的最优滑动窗口为: {OPTIMAL_WINDOW} 天\n"
124
  f"模型权重: SARIMA={sw:.2f}, Prophet={pw:.2f}"
125
  )
 
126
 
127
  # 返回所有结果
128
  return summary_text, metrics_df, fig_forecast_4w, first_week_metrics, fig_forecast_1w, fig_stl
@@ -131,54 +180,158 @@ def run_analysis(file_obj):
131
  # 辅助函数 (从主脚本中提取)
132
  # =============================================================================
133
  def make_stationary_silent(data_series, max_diff=2):
 
 
 
 
 
 
 
 
134
  diff_order = 0
135
  current_series = data_series.dropna()
136
  for d in range(max_diff + 1):
137
  if d > 0:
138
  current_series = current_series.diff().dropna()
139
- if adfuller(current_series)[1] < 0.05:
 
140
  return current_series, d
141
  return current_series, max_diff
142
 
143
  def evaluate_window(window_size, ts_values):
 
 
 
 
 
 
 
 
144
  n = len(ts_values)
145
- if window_size >= n: return {'mae': float('inf'), 'rmse': float('inf')}
146
- errors = [ts_values[i + window_size] - auto_arima(ts_values[i:(i + window_size)], d=0, stepwise=True, suppress_warnings=True, error_action='ignore').predict(n_periods=1)[0] for i in range(n - window_size - 1)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  return {'window_size': window_size, 'mae': np.mean(np.abs(errors)), 'rmse': np.sqrt(np.mean(np.square(errors)))}
148
 
149
  def sarima_model(train_data, d, D, h):
 
 
 
 
 
 
 
 
 
 
 
150
  model = auto_arima(train_data['Value'], d=d, D=D, m=7, seasonal=True, stepwise=True, suppress_warnings=True, error_action='ignore')
151
  return model.predict(n_periods=h)
152
 
153
  def prophet_model(train_data, h):
 
 
 
 
 
 
 
 
 
154
  df_prophet = train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'})
155
- model = Prophet(yearly_seasonality=True, weekly_seasonality=True).fit(df_prophet)
156
- return model.predict(model.make_future_dataframe(periods=h, freq='D'))['yhat'].tail(h)
 
 
 
 
 
157
 
158
  def weighted_average_model(train_data, d, D, h):
159
- validation_data = train_data.tail(28)
160
- train_for_val = train_data.iloc[:-28]
161
- if len(train_for_val) < 56: # 确保训练集足够大
162
- return prophet_model(train_data, h), 0.0, 1.0 # 如果数据太少,直接用prophet
163
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  sarima_fc_val = sarima_model(train_for_val, d, D, h=28)
165
  prophet_fc_val = prophet_model(train_for_val, h=28)
 
 
166
  sarima_mae = mean_absolute_error(validation_data['Value'], sarima_fc_val)
167
  prophet_mae = mean_absolute_error(validation_data['Value'], prophet_fc_val)
168
 
169
- if sarima_mae + prophet_mae == 0: sarima_weight, prophet_weight = 0.5, 0.5
 
 
170
  else:
171
  inv_sum = (1/sarima_mae) + (1/prophet_mae)
172
  sarima_weight = (1/sarima_mae) / inv_sum
173
  prophet_weight = (1/prophet_mae) / inv_sum
174
 
 
175
  sarima_pred = sarima_model(train_data, d, D, h=h)
176
  prophet_pred = prophet_model(train_data, h=h)
 
 
177
  weighted_avg = sarima_weight * sarima_pred.values + prophet_weight * prophet_pred.values
178
  return pd.Series(weighted_avg, index=sarima_pred.index), sarima_weight, prophet_weight
179
 
180
  def calculate_metrics(actual, predicted):
181
- return {'MAE': mean_absolute_error(actual, predicted), 'RMSE': np.sqrt(mean_squared_error(actual, predicted)), 'MAPE': np.mean(np.abs((actual - predicted) / actual)) * 100}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  # =============================================================================
184
  # 创建Gradio界面
@@ -205,7 +358,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="药品销量时序预测") as app:
205
 
206
  with gr.Row():
207
  plot_stl_output = gr.Plot(label="STL季节性分解图")
208
- gr.Markdown("") # 占位
209
 
210
  # 设置按钮点击事件
211
  run_button.click(
@@ -214,12 +367,13 @@ with gr.Blocks(theme=gr.themes.Soft(), title="药品销量时序预测") as app:
214
  outputs=[summary_output, metrics_4w_output, plot_4w_output, metrics_1w_output, plot_1w_output, plot_stl_output]
215
  )
216
 
 
217
  gr.Examples(
218
- [["gmqrkl.xlsx"]],
219
  inputs=[file_input],
220
  fn=run_analysis,
221
  outputs=[summary_output, metrics_4w_output, plot_4w_output, metrics_1w_output, plot_1w_output, plot_stl_output]
222
  )
223
 
 
224
  app.launch(server_name="0.0.0.0", server_port=7860)
225
- #app.launch()
 
16
  import seaborn as sns
17
  from joblib import Parallel, delayed
18
 
19
+ # 忽略所有警告,但这可能不会抑制所有来自底层库的FutureWarning
20
  warnings.filterwarnings('ignore')
21
 
22
  # --- 在Hugging Face上配置中文字体 ---
23
  # 确保你已经上传了 SimHei.ttf 字体文件
24
+ # 或者在环境中安装了 WenQuanYi Zen Hei 字体
 
 
 
 
 
 
 
 
 
 
25
  plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei']
26
+ plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
27
 
28
  # =============================================================================
29
  # 将所有分析步骤封装在一个主函数中
30
  # =============================================================================
31
  def run_analysis(file_obj):
32
+ """
33
+ 运行完整的时序预测分析流程。
34
+ 包括数据预处理、平稳性/白噪声/季节性检验、窗口优化、
35
+ SARIMA和Prophet模型预测、加权平均模型以及结果可视化。
36
+
37
+ 参数:
38
+ file_obj (gr.File): Gradio文件对象,包含上传的Excel文件。
39
+
40
+ 返回:
41
+ tuple: 包含分析摘要文本、四周预测指标DataFrame、四周预测图、
42
+ 一周预测指标DataFrame、一周预测图和STL分解图。
43
+ """
44
  if file_obj is None:
45
  raise gr.Error("请先上传一个Excel文件!")
46
 
47
+ print("----- 开始分析 -----")
48
+
49
  # ----- 1. 数据预处理 -----
50
+ print("----- 1. 数据预处理 -----")
51
+ # 读取Excel文件
52
  df = pd.read_excel(file_obj.name)
53
+ # 将'Date'列转换为日期时间格式
54
  df['Date'] = pd.to_datetime(df['Date'])
55
+ # 按日期排序
56
  df.sort_values('Date', inplace=True)
57
+ # 将'Date'列设置为索引
58
  df.set_index('Date', inplace=True)
59
+ # 将'Value'列中的0替换为NaN
60
  df['Value'] = df['Value'].replace(0, np.nan)
61
+ # 使用线性插值填充NaN值
62
  df['Value'].interpolate(method='linear', inplace=True)
63
+ print(f"数据加载成功,共 {len(df)} 条记录。")
64
 
65
  # ----- 2. 平稳性与白噪声检验 (仅输出结果) -----
66
+ print("----- 2. 平稳性与白噪声检验 -----")
67
+ # 对时间序列进行差分以使其平稳
68
  ts_stationary, d_order = make_stationary_silent(df['Value'])
69
+ # 进行Ljung-Box白噪声检验
70
  lb_p_value = acorr_ljungbox(ts_stationary, lags=[10], return_df=True)['lb_pvalue'].iloc[0]
71
+ print(f"最优差分阶数 d = {d_order}, 白噪声检验 P-value = {lb_p_value:.4f}")
72
 
73
  # ----- 3. 季节性检验 -----
74
+ print("----- 3. 季节性检验 -----")
75
+ # 使用Canova-Hansen检验确定季节差分阶数
76
+ D_order = nsdiffs(df['Value'], m=7, test='ch') # 假设周期为7天
77
+ print(f"季节差分阶数 D = {D_order}")
78
 
79
  # STL分解图
80
+ # 对时间序列进行STL分解,提取趋势、季节性和残差
81
  stl = STL(df['Value'], period=7, robust=True)
82
  res = stl.fit()
83
  fig_stl = res.plot()
84
  fig_stl.suptitle("STL季节性分解 (周期=7)", y=1.02)
85
+ print("STL季节性分解图已生成。")
86
 
87
  # ----- 4. 窗口优化 (简化版,避免超时) -----
88
+ print("----- 4. 窗口优化 (简化版) -----")
89
  # 在线上环境中,完整的窗口优化太慢,这里使用一个固定的值或一个快速的估计
90
+ # 为了演示,我们仍然运行一个简化版的窗口评估
91
  ts_values = df['Value'].values
92
+ window_sizes = range(100, 181, 20) # 减少评估点,评估100到180天窗口,步长20
93
+ # 并行评估不同窗口大小下的模型性能
94
  results = Parallel(n_jobs=-1)(delayed(evaluate_window)(ws, ts_values) for ws in window_sizes)
95
  results_df = pd.DataFrame(results)
96
+ # 确定最优窗口大小,如果结果为空则默认为120
97
  OPTIMAL_WINDOW = int(results_df.loc[results_df['mae'].idxmin()]['window_size']) if not results_df.empty else 120
98
+ print(f"计算得到的最优滑动窗口为: {OPTIMAL_WINDOW} 天")
99
 
100
  # ----- 5. 模型预测 -----
101
+ print("----- 5. 模型预测 -----")
102
+ # 定义训练集和测试集的日期范围
103
+ final_train_end_date = df.index.max() - timedelta(days=28) # 留出最后28天作为测试集
104
  final_test_start_date = final_train_end_date + timedelta(days=1)
105
+
106
+ # SARIMA模型训练数据:基于最优窗口
107
  train_sarima = df[final_train_end_date - timedelta(days=OPTIMAL_WINDOW - 1) : final_train_end_date]
108
+ # Prophet模型训练数据:使用所有可用数据直到训练结束日期
109
  train_full = df[:final_train_end_date]
110
+ # 测试数据
111
  test_data = df[final_test_start_date:]
112
+ print(f"训练数据截止日期: {final_train_end_date.strftime('%Y-%m-%d')}, 测试数据开始日期: {final_test_start_date.strftime('%Y-%m-%d')}")
113
 
114
+ print("正在生成 SARIMA 模型预测...")
115
  sarima_pred = sarima_model(train_sarima, d_order, D_order, h=28)
116
+ print("SARIMA 模型预测完成。")
117
+
118
+ print("正在生成 Prophet 模型预测...")
119
  prophet_pred = prophet_model(train_full, h=28)
120
+ print("Prophet 模型预测完成。")
121
+
122
+ print("正在生成加权平均模型预测...")
123
  weighted_pred, sw, pw = weighted_average_model(train_full, d_order, D_order, h=28)
124
+ print(f"加权平均模型预测完成。模型权重: SARIMA={sw:.2f}, Prophet={pw:.2f}")
125
 
126
+ # 将实际值和各模型预测值整合到DataFrame中
127
  predictions = pd.DataFrame({
128
  'Actual': test_data['Value'],
129
  'SARIMA': sarima_pred,
130
  'Prophet': prophet_pred,
131
  'Weighted': weighted_pred
132
+ }).dropna() # 丢弃包含NaN的行
133
 
134
  # ----- 6. 指标计算与可视化 -----
135
+ print("----- 6. 指标计算与可视化 -----")
136
+ # 计算所有模型的性能指标
137
  metrics = {model: calculate_metrics(predictions['Actual'], predictions[model]) for model in ['SARIMA', 'Prophet', 'Weighted']}
138
  metrics_df = pd.DataFrame(metrics).T.reset_index().rename(columns={'index': 'Model'})
139
+ print("性能指标计算完成。")
140
 
141
  # 4周预测对比图
142
  fig_forecast_4w = plt.figure(figsize=(12, 6))
 
146
  plt.plot(predictions.index, predictions['Weighted'], label=f'加权平均 (S:{sw:.2f}, P:{pw:.2f})', color='green', linestyle=':')
147
  plt.title('未来4周预测对比', fontsize=16)
148
  plt.legend(); plt.grid(True)
149
+ print("未来4周预测对比图已生成。")
150
 
151
  # 1周预测对比图和指标
152
+ first_week_preds = predictions.head(7) # 获取第一周的预测数据
153
  first_week_metrics = pd.DataFrame({
154
  model: calculate_metrics(first_week_preds['Actual'], first_week_preds[model])
155
  for model in ['SARIMA', 'Prophet', 'Weighted']
 
161
  plt.plot(first_week_preds.index, first_week_preds['Prophet'], label='Prophet预测', color='blue', linestyle='-.')
162
  plt.plot(first_week_preds.index, first_week_preds['Weighted'], label='加权平均预测', color='green', linestyle=':')
163
  plt.title('第一周预测结果对比'); plt.legend(); plt.grid(True)
164
+ print("第一周预测结果对比图已生成。")
165
 
166
+ # 生成分析摘要文本
167
  summary_text = (
168
  f"数据加载成功,共 {len(df)} 条记录。\n"
169
  f"最优差分阶数 d = {d_order}, 季节差分阶数 D = {D_order}\n"
 
171
  f"计算得到的最优滑动窗口为: {OPTIMAL_WINDOW} 天\n"
172
  f"模型权重: SARIMA={sw:.2f}, Prophet={pw:.2f}"
173
  )
174
+ print("----- 分析完成 -----")
175
 
176
  # 返回所有结果
177
  return summary_text, metrics_df, fig_forecast_4w, first_week_metrics, fig_forecast_1w, fig_stl
 
180
  # 辅助函数 (从主脚本中提取)
181
  # =============================================================================
182
  def make_stationary_silent(data_series, max_diff=2):
183
+ """
184
+ 通过差分使时间序列平稳。
185
+ 参数:
186
+ data_series (pd.Series): 输入时间序列。
187
+ max_diff (int): 最大差分阶数。
188
+ 返回:
189
+ tuple: 平稳后的时间序列和差分阶数。
190
+ """
191
  diff_order = 0
192
  current_series = data_series.dropna()
193
  for d in range(max_diff + 1):
194
  if d > 0:
195
  current_series = current_series.diff().dropna()
196
+ # 使用ADF检验检查平稳性
197
+ if adfuller(current_series)[1] < 0.05: # P-value小于0.05则认为平稳
198
  return current_series, d
199
  return current_series, max_diff
200
 
201
  def evaluate_window(window_size, ts_values):
202
+ """
203
+ 评估给定窗口大小下模型的性能。
204
+ 参数:
205
+ window_size (int): 滑动窗口大小。
206
+ ts_values (np.array): 时间序列值数组。
207
+ 返回:
208
+ dict: 包含窗口大小、MAE和RMSE的字典。
209
+ """
210
  n = len(ts_values)
211
+ if window_size >= n:
212
+ return {'mae': float('inf'), 'rmse': float('inf')}
213
+
214
+ # 在滑动窗口上训练auto_arima模型并进行单步预测
215
+ errors = []
216
+ for i in range(n - window_size - 1):
217
+ try:
218
+ model = auto_arima(ts_values[i:(i + window_size)], d=0, stepwise=True, suppress_warnings=True, error_action='ignore')
219
+ prediction = model.predict(n_periods=1)[0]
220
+ errors.append(ts_values[i + window_size] - prediction)
221
+ except Exception as e:
222
+ # 捕获auto_arima可能出现的错误,避免中断循环
223
+ # print(f"窗口评估中auto_arima出现错误: {e}") # 避免过多打印,只在必要时开启
224
+ errors.append(np.nan) # 标记为NaN,后续会处理
225
+
226
+ errors = np.array(errors)
227
+ errors = errors[~np.isnan(errors)] # 移除NaN错误
228
+
229
+ if len(errors) == 0:
230
+ return {'window_size': window_size, 'mae': float('inf'), 'rmse': float('inf')}
231
+
232
  return {'window_size': window_size, 'mae': np.mean(np.abs(errors)), 'rmse': np.sqrt(np.mean(np.square(errors)))}
233
 
234
  def sarima_model(train_data, d, D, h):
235
+ """
236
+ 训练SARIMA模型并进行预测。
237
+ 参数:
238
+ train_data (pd.DataFrame): 训练数据。
239
+ d (int): 非季节差分阶数。
240
+ D (int): 季节差分阶数。
241
+ h (int): 预测步长。
242
+ 返回:
243
+ pd.Series: SARIMA模型预测结果。
244
+ """
245
+ # 使用auto_arima自动选择最优SARIMA模型参数
246
  model = auto_arima(train_data['Value'], d=d, D=D, m=7, seasonal=True, stepwise=True, suppress_warnings=True, error_action='ignore')
247
  return model.predict(n_periods=h)
248
 
249
  def prophet_model(train_data, h):
250
+ """
251
+ 训练Prophet模型并进行预测。
252
+ 参数:
253
+ train_data (pd.DataFrame): 训练数据。
254
+ h (int): 预测步长。
255
+ 返回:
256
+ pd.Series: Prophet模型预测结果。
257
+ """
258
+ # 准备Prophet模型所需的数据格式
259
  df_prophet = train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'})
260
+ # 初始化并训练Prophet模型
261
+ model = Prophet(yearly_seasonality=True, weekly_seasonality=True)
262
+ model.fit(df_prophet)
263
+ # 创建未来预测的DataFrame
264
+ future = model.make_future_dataframe(periods=h, freq='D')
265
+ # 进行预测并返回预测值
266
+ return model.predict(future)['yhat'].tail(h)
267
 
268
  def weighted_average_model(train_data, d, D, h):
269
+ """
270
+ 训练加权平均模型(SARIMA和Prophet)并进行预测。
271
+ 权重基于验证集上的MAE反比。
272
+ 参数:
273
+ train_data (pd.DataFrame): 训练数据。
274
+ d (int): SARIMA非季节差分阶数。
275
+ D (int): SARIMA季节差分阶数。
276
+ h (int): 预测步长。
277
+ 返回:
278
+ tuple: 加权平均预测结果(pd.Series)、SARIMA权重、Prophet权重。
279
+ """
280
+ validation_data = train_data.tail(28) # 留出最后28天作为验证集
281
+ train_for_val = train_data.iloc[:-28] # 用于训练验证模型的子集
282
+
283
+ if len(train_for_val) < 56: # 确保训练集足够大,至少是验证集的两倍
284
+ # 如果数据太少,直接使用Prophet模型,并设置Prophet权重为1.0
285
+ prophet_pred_full = prophet_model(train_data, h)
286
+ return prophet_pred_full, 0.0, 1.0
287
+
288
+ # 在验证集上生成SARIMA和Prophet预测
289
  sarima_fc_val = sarima_model(train_for_val, d, D, h=28)
290
  prophet_fc_val = prophet_model(train_for_val, h=28)
291
+
292
+ # 计算各自在验证集上的MAE
293
  sarima_mae = mean_absolute_error(validation_data['Value'], sarima_fc_val)
294
  prophet_mae = mean_absolute_error(validation_data['Value'], prophet_fc_val)
295
 
296
+ # 根据MAE计算权重(MAE越小,权重越大)
297
+ if sarima_mae + prophet_mae == 0: # 避免除以零
298
+ sarima_weight, prophet_weight = 0.5, 0.5
299
  else:
300
  inv_sum = (1/sarima_mae) + (1/prophet_mae)
301
  sarima_weight = (1/sarima_mae) / inv_sum
302
  prophet_weight = (1/prophet_mae) / inv_sum
303
 
304
+ # 使用完整训练数据生成最终预测
305
  sarima_pred = sarima_model(train_data, d, D, h=h)
306
  prophet_pred = prophet_model(train_data, h=h)
307
+
308
+ # 计算加权平均预测
309
  weighted_avg = sarima_weight * sarima_pred.values + prophet_weight * prophet_pred.values
310
  return pd.Series(weighted_avg, index=sarima_pred.index), sarima_weight, prophet_weight
311
 
312
  def calculate_metrics(actual, predicted):
313
+ """
314
+ 计算预测性能指标:MAE, RMSE, MAPE。
315
+ 参数:
316
+ actual (pd.Series): 实际值。
317
+ predicted (pd.Series): 预测值。
318
+ 返回:
319
+ dict: 包含MAE, RMSE, MAPE的字典。
320
+ """
321
+ # 确保实际值和预测值长度一致
322
+ min_len = min(len(actual), len(predicted))
323
+ actual = actual.iloc[:min_len]
324
+ predicted = predicted.iloc[:min_len]
325
+
326
+ # 避免除以零或无穷大
327
+ mape = np.mean(np.abs((actual - predicted) / actual.replace(0, np.nan))) * 100
328
+ mape = np.nan_to_num(mape, nan=0.0, posinf=0.0, neginf=0.0) # 处理可能出现的NaN/Inf
329
+
330
+ return {
331
+ 'MAE': mean_absolute_error(actual, predicted),
332
+ 'RMSE': np.sqrt(mean_squared_error(actual, predicted)),
333
+ 'MAPE': mape
334
+ }
335
 
336
  # =============================================================================
337
  # 创建Gradio界面
 
358
 
359
  with gr.Row():
360
  plot_stl_output = gr.Plot(label="STL季节性分解图")
361
+ gr.Markdown("") # 占位符,保持布局
362
 
363
  # 设置按钮点击事件
364
  run_button.click(
 
367
  outputs=[summary_output, metrics_4w_output, plot_4w_output, metrics_1w_output, plot_1w_output, plot_stl_output]
368
  )
369
 
370
+ # 示例文件,方便用户测试
371
  gr.Examples(
372
+ [["gmqrkl.xlsx"]], # 请确保 'gmqrkl.xlsx' 文件在Gradio应用运行的环境中可访问
373
  inputs=[file_input],
374
  fn=run_analysis,
375
  outputs=[summary_output, metrics_4w_output, plot_4w_output, metrics_1w_output, plot_1w_output, plot_stl_output]
376
  )
377
 
378
+ # 明确指定 Gradio 监听的地址和端口,以确保在 Hugging Face Spaces 上正确启动
379
  app.launch(server_name="0.0.0.0", server_port=7860)