Kolesnikov Dmitry committed on
Commit
6e8ed89
·
1 Parent(s): 7c4a2f3

feat: Красивые графики

Browse files
Files changed (1) hide show
  1. src/lab3_pipeline.py +111 -1
src/lab3_pipeline.py CHANGED
@@ -67,6 +67,11 @@ try:
67
  except Exception:
68
  SCIPY_AVAILABLE = False
69
 
 
 
 
 
 
70
 
71
  # -------------------------------------------------------------------------
72
  # Metrics
@@ -520,6 +525,7 @@ def generate_report_html(out_path: str, plots: List[plt.Figure], tables: Dict[st
520
 
521
  # Закрываем рисунок чтобы освободить память
522
  plt.close(fig)
 
523
 
524
  html_parts.append("</body></html>")
525
 
@@ -530,8 +536,26 @@ def generate_report_html(out_path: str, plots: List[plt.Figure], tables: Dict[st
530
  # -------------------------------------------------------------------------
531
  # Main runner that orchestrates everything
532
  # -------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  def run_pipeline(data_path: str, timestamp_col: str, target_col: str,
534
  out_report: str = 'lab3_report.html', freq: str = 'D'):
 
535
  """
536
  Главная точка запуска pipeline.
537
  """
@@ -673,6 +697,50 @@ def run_pipeline(data_path: str, timestamp_col: str, target_col: str,
673
  except Exception as e:
674
  print("VAR failed:", e)
675
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
  # 3.6 Diagnostics later for top models
677
  # 3.7 Evaluate on test set
678
  eval_rows = []
@@ -810,6 +878,11 @@ def run_pipeline(data_path: str, timestamp_col: str, target_col: str,
810
  'pred': pd.Series(pred_values, index=pred_dates)
811
  })
812
 
 
 
 
 
 
813
 
814
  # -------------------------
815
  # helpers used in the pipeline but defined later
@@ -842,4 +915,41 @@ def create_forecast_index(last_train_date: pd.Timestamp, steps: int, freq: str =
842
  except Exception as e:
843
  print(f"Warning: could not create proper date index: {e}")
844
  # Fallback: числовой индекс
845
- return pd.RangeIndex(start=0, stop=steps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  except Exception:
68
  SCIPY_AVAILABLE = False
69
 
70
+ try:
71
+ from tbats import TBATS
72
+ TBATS_AVAILABLE = True
73
+ except ImportError:
74
+ TBATS_AVAILABLE = False
75
 
76
  # -------------------------------------------------------------------------
77
  # Metrics
 
525
 
526
  # Закрываем рисунок чтобы освободить память
527
  plt.close(fig)
528
+ break
529
 
530
  html_parts.append("</body></html>")
531
 
 
536
  # -------------------------------------------------------------------------
537
  # Main runner that orchestrates everything
538
  # -------------------------------------------------------------------------
539
def evaluate_with_cv(models_dict, X, y, cv_method='expanding', n_splits=5, h=30):
    """Score every model in ``models_dict`` with time-series cross-validation.

    Args:
        models_dict: Mapping of model name -> callable that is handed to the
            cross-validation helper unchanged.
        X: Feature container; only ``len(X)`` is read here, the object itself
            is forwarded to the helper.
        y: Target container, forwarded to the helper.
        cv_method: ``'expanding'`` dispatches to ``expanding_window_cv``;
            any other value dispatches to ``rolling_window_cv``.
        n_splits: Number of splits, forwarded to the helper.
        h: Forecast horizon forwarded to the helper. Defaults to 30, the
            value that used to be hard-coded.

    Returns:
        dict mapping each model name to whatever the selected helper returned
        for it.
    """
    # Both strategies start from half of the available history; compute once.
    half = len(X) // 2
    use_expanding = cv_method == 'expanding'

    cv_results = {}
    for name, model_func in models_dict.items():
        if use_expanding:
            cv_results[name] = expanding_window_cv(
                X, y, model_func,
                initial_train_size=half,
                h=h, n_splits=n_splits)
        else:
            cv_results[name] = rolling_window_cv(
                X, y, model_func,
                window=half,
                h=h, n_splits=n_splits)

    return cv_results
555
+
556
  def run_pipeline(data_path: str, timestamp_col: str, target_col: str,
557
  out_report: str = 'lab3_report.html', freq: str = 'D'):
558
+
559
  """
560
  Главная точка запуска pipeline.
561
  """
 
697
  except Exception as e:
698
  print("VAR failed:", e)
699
 
700
+ # TBATS модель
701
+ if TBATS_AVAILABLE:
702
+ try:
703
+ tbats_model = TBATS(seasonal_periods=[7, 30], use_arma_errors=True)
704
+ tbats_fitted = tbats_model.fit(y_train)
705
+ for h in horizons:
706
+ tbats_pred = tbats_fitted.forecast(steps=h)
707
+ pred_dates = create_forecast_index(y_train.index[-1], h, inferred_freq)
708
+ results.append({
709
+ 'model': 'TBATS',
710
+ 'h': h,
711
+ 'pred': pd.Series(tbats_pred, index=pred_dates)
712
+ })
713
+ except Exception as e:
714
+ print("TBATS failed:", e)
715
+
716
+ # Prophet модель
717
+ if PROPHET_AVAILABLE:
718
+ try:
719
+ prophet_df = y_train.reset_index()
720
+ prophet_df.columns = ['ds', 'y']
721
+ prophet_model = Prophet()
722
+ prophet_model.fit(prophet_df)
723
+ future = prophet_model.make_future_dataframe(periods=max(horizons), freq=inferred_freq)
724
+ forecast = prophet_model.predict(future)
725
+ for h in horizons:
726
+ prophet_pred = forecast.tail(h)['yhat'].values
727
+ pred_dates = create_forecast_index(y_train.index[-1], h, inferred_freq)
728
+ results.append({
729
+ 'model': 'Prophet',
730
+ 'h': h,
731
+ 'pred': pd.Series(prophet_pred, index=pred_dates)
732
+ })
733
+ except Exception as e:
734
+ print("Prophet failed:", e)
735
+
736
+ # GARCH на остатках SARIMAX
737
+ if ARCH_AVAILABLE and 'sar' in locals():
738
+ try:
739
+ garch_model = fit_garch_on_residuals(sar.resid, p=1, q=1)
740
+ # Прогноз волатильности можно добавить в анализ
741
+ except Exception as e:
742
+ print("GARCH failed:", e)
743
+
744
  # 3.6 Diagnostics later for top models
745
  # 3.7 Evaluate on test set
746
  eval_rows = []
 
878
  'pred': pd.Series(pred_values, index=pred_dates)
879
  })
880
 
881
+ cv_results = evaluate_with_cv({
882
+ 'SARIMAX': lambda X, y, h: forecast_recursive(fit_sarimax_simple(y), y, h),
883
+ 'AutoARIMA': lambda X, y, h: forecast_recursive(fit_auto_arima(y), y, h)
884
+ }, df_all.drop(columns=[target_col]), df_all[target_col])
885
+
886
 
887
  # -------------------------
888
  # helpers used in the pipeline but defined later
 
915
  except Exception as e:
916
  print(f"Warning: could not create proper date index: {e}")
917
  # Fallback: числовой индекс
918
+ return pd.RangeIndex(start=0, stop=steps)
919
+
920
+
921
def forecast_recursive(model, series, steps, freq='D'):
    """Return a ``steps``-long point forecast from an already fitted model.

    The whole horizon is requested from the model in one call. Asking an
    unchanged model for a one-step forecast ``steps`` times yields the same
    value every time, so the multi-step request is what actually produces a
    path over the horizon.
    NOTE(review): this assumes the model feeds its own predictions forward
    internally for multi-step requests (expected for ARIMA-style models) -
    confirm for any other model type passed in.

    Args:
        model: Fitted forecaster exposing ``forecast(steps=...)`` or
            ``predict(n_periods=...)``.
        series: Training series. Kept for interface compatibility; not read.
        steps: Number of periods to forecast. Values <= 0 give an empty array.
        freq: Kept for interface compatibility; not read.

    Returns:
        1-D ``numpy.ndarray`` with ``steps`` forecast values.

    Raises:
        AttributeError: if ``model`` has neither ``forecast`` nor ``predict``.
    """
    if steps <= 0:
        return np.array([])

    # Prefer ``forecast`` when it exists.
    # NOTE(review): statsmodels results objects appear to expose both
    # ``forecast`` and ``predict``, and their ``predict`` does not take
    # ``n_periods`` - confirm against the objects returned by the fit helpers.
    if hasattr(model, 'forecast'):
        raw = model.forecast(steps=steps)
    else:
        raw = model.predict(n_periods=steps)

    # Convert positionally: the result may be a date-indexed Series, where
    # ``raw[0]`` would be a label lookup rather than "first element".
    return np.asarray(raw).ravel()
937
+
938
+
939
def forecast_direct(train_series, test_features, model_factory, steps):
    """Direct forecasting strategy: fit one separate model per horizon.

    For each horizon ``h`` in ``1..steps`` a fresh model is trained to map the
    value at time ``t`` to the value at time ``t + h``, then asked to predict
    from the last observed value of ``train_series``.

    Args:
        train_series: pandas Series with the training history.
        test_features: Not read; kept for interface compatibility.
        model_factory: Zero-argument callable returning an unfitted model with
            ``fit(X, y)`` and ``predict(X)``; ``X`` is 2-D with one column.
        steps: Number of horizons to forecast. ``steps <= 0`` gives an empty
            array.

    Returns:
        1-D ``numpy.ndarray`` with one prediction per horizon.

    Raises:
        ValueError: if some horizon leaves no usable (feature, target) pair,
            e.g. ``h`` is not smaller than the number of observations.
    """
    # The prediction input is the same for every horizon, so build it once.
    last_obs = train_series.values[-1:].reshape(1, -1)

    predictions = []
    for h in range(1, steps + 1):
        # Target for horizon h is the series shifted h steps back in time.
        target = train_series.shift(-h)

        # Pair feature and target row by row and keep only complete pairs.
        # Dropping NaNs from the target alone and then slicing the features
        # by length would misalign the pairs whenever the series has gaps.
        usable = (train_series.notna() & target.notna()).values
        if not usable.any():
            raise ValueError(
                f"not enough observations to train a model for horizon {h}")

        X_h = train_series.values[usable].reshape(-1, 1)
        y_h = target.values[usable]

        # Train a dedicated model for this horizon.
        model = model_factory()
        model.fit(X_h, y_h)

        pred = model.predict(last_obs)
        predictions.append(pred[0])

    return np.array(predictions)