Spaces:
No application file
No application file
Kolesnikov Dmitry
committed on
Commit
·
6e8ed89
1
Parent(s):
7c4a2f3
feat: Красивые графики
Browse files- src/lab3_pipeline.py +111 -1
src/lab3_pipeline.py
CHANGED
|
@@ -67,6 +67,11 @@ try:
|
|
| 67 |
except Exception:
|
| 68 |
SCIPY_AVAILABLE = False
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
# -------------------------------------------------------------------------
|
| 72 |
# Metrics
|
|
@@ -520,6 +525,7 @@ def generate_report_html(out_path: str, plots: List[plt.Figure], tables: Dict[st
|
|
| 520 |
|
| 521 |
# Закрываем рисунок чтобы освободить память
|
| 522 |
plt.close(fig)
|
|
|
|
| 523 |
|
| 524 |
html_parts.append("</body></html>")
|
| 525 |
|
|
@@ -530,8 +536,26 @@ def generate_report_html(out_path: str, plots: List[plt.Figure], tables: Dict[st
|
|
| 530 |
# -------------------------------------------------------------------------
|
| 531 |
# Main runner that orchestrates everything
|
| 532 |
# -------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
def run_pipeline(data_path: str, timestamp_col: str, target_col: str,
|
| 534 |
out_report: str = 'lab3_report.html', freq: str = 'D'):
|
|
|
|
| 535 |
"""
|
| 536 |
Главная точка запуска pipeline.
|
| 537 |
"""
|
|
@@ -673,6 +697,50 @@ def run_pipeline(data_path: str, timestamp_col: str, target_col: str,
|
|
| 673 |
except Exception as e:
|
| 674 |
print("VAR failed:", e)
|
| 675 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 676 |
# 3.6 Diagnostics later for top models
|
| 677 |
# 3.7 Evaluate on test set
|
| 678 |
eval_rows = []
|
|
@@ -810,6 +878,11 @@ def run_pipeline(data_path: str, timestamp_col: str, target_col: str,
|
|
| 810 |
'pred': pd.Series(pred_values, index=pred_dates)
|
| 811 |
})
|
| 812 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 813 |
|
| 814 |
# -------------------------
|
| 815 |
# helpers used in the pipeline but defined later
|
|
@@ -842,4 +915,41 @@ def create_forecast_index(last_train_date: pd.Timestamp, steps: int, freq: str =
|
|
| 842 |
except Exception as e:
|
| 843 |
print(f"Warning: could not create proper date index: {e}")
|
| 844 |
# Fallback: числовой индекс
|
| 845 |
-
return pd.RangeIndex(start=0, stop=steps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
except Exception:
|
| 68 |
SCIPY_AVAILABLE = False
|
| 69 |
|
| 70 |
+
try:
|
| 71 |
+
from tbats import TBATS
|
| 72 |
+
TBATS_AVAILABLE = True
|
| 73 |
+
except ImportError:
|
| 74 |
+
TBATS_AVAILABLE = False
|
| 75 |
|
| 76 |
# -------------------------------------------------------------------------
|
| 77 |
# Metrics
|
|
|
|
| 525 |
|
| 526 |
# Закрываем рисунок чтобы освободить память
|
| 527 |
plt.close(fig)
|
| 528 |
+
break
|
| 529 |
|
| 530 |
html_parts.append("</body></html>")
|
| 531 |
|
|
|
|
| 536 |
# -------------------------------------------------------------------------
|
| 537 |
# Main runner that orchestrates everything
|
| 538 |
# -------------------------------------------------------------------------
|
| 539 |
+
def evaluate_with_cv(models_dict, X, y, cv_method='expanding', n_splits=5, h=30):
    """Evaluate several forecasting models with time-series cross-validation.

    Args:
        models_dict: Mapping of model name to a model callable. Each callable
            is handed to the cross-validation helper unchanged.
        X: Feature container; only ``len(X)`` is used here, the object itself
            is forwarded to the helper.
        y: Target container, forwarded to the helper.
        cv_method: ``'expanding'`` selects ``expanding_window_cv``; any other
            value selects ``rolling_window_cv``.
        n_splits: Number of CV splits forwarded to the helper.
        h: Forecast horizon forwarded to the helper. Defaults to 30, the value
            that used to be hard-coded.

    Returns:
        dict mapping each model name to whatever the selected CV helper
        returned for it.
    """
    # Half of the data is used both as the initial training size (expanding)
    # and as the window length (rolling); it does not depend on the model.
    half = len(X) // 2
    use_expanding = cv_method == 'expanding'

    cv_results = {}
    for name, model_func in models_dict.items():
        if use_expanding:
            cv_results[name] = expanding_window_cv(
                X, y, model_func,
                initial_train_size=half,
                h=h, n_splits=n_splits,
            )
        else:
            cv_results[name] = rolling_window_cv(
                X, y, model_func,
                window=half,
                h=h, n_splits=n_splits,
            )

    return cv_results
|
| 555 |
+
|
| 556 |
def run_pipeline(data_path: str, timestamp_col: str, target_col: str,
|
| 557 |
out_report: str = 'lab3_report.html', freq: str = 'D'):
|
| 558 |
+
|
| 559 |
"""
|
| 560 |
Главная точка запуска pipeline.
|
| 561 |
"""
|
|
|
|
| 697 |
except Exception as e:
|
| 698 |
print("VAR failed:", e)
|
| 699 |
|
| 700 |
+
# TBATS модель
|
| 701 |
+
if TBATS_AVAILABLE:
|
| 702 |
+
try:
|
| 703 |
+
tbats_model = TBATS(seasonal_periods=[7, 30], use_arma_errors=True)
|
| 704 |
+
tbats_fitted = tbats_model.fit(y_train)
|
| 705 |
+
for h in horizons:
|
| 706 |
+
tbats_pred = tbats_fitted.forecast(steps=h)
|
| 707 |
+
pred_dates = create_forecast_index(y_train.index[-1], h, inferred_freq)
|
| 708 |
+
results.append({
|
| 709 |
+
'model': 'TBATS',
|
| 710 |
+
'h': h,
|
| 711 |
+
'pred': pd.Series(tbats_pred, index=pred_dates)
|
| 712 |
+
})
|
| 713 |
+
except Exception as e:
|
| 714 |
+
print("TBATS failed:", e)
|
| 715 |
+
|
| 716 |
+
# Prophet модель
|
| 717 |
+
if PROPHET_AVAILABLE:
|
| 718 |
+
try:
|
| 719 |
+
prophet_df = y_train.reset_index()
|
| 720 |
+
prophet_df.columns = ['ds', 'y']
|
| 721 |
+
prophet_model = Prophet()
|
| 722 |
+
prophet_model.fit(prophet_df)
|
| 723 |
+
future = prophet_model.make_future_dataframe(periods=max(horizons), freq=inferred_freq)
|
| 724 |
+
forecast = prophet_model.predict(future)
|
| 725 |
+
for h in horizons:
|
| 726 |
+
prophet_pred = forecast.tail(h)['yhat'].values
|
| 727 |
+
pred_dates = create_forecast_index(y_train.index[-1], h, inferred_freq)
|
| 728 |
+
results.append({
|
| 729 |
+
'model': 'Prophet',
|
| 730 |
+
'h': h,
|
| 731 |
+
'pred': pd.Series(prophet_pred, index=pred_dates)
|
| 732 |
+
})
|
| 733 |
+
except Exception as e:
|
| 734 |
+
print("Prophet failed:", e)
|
| 735 |
+
|
| 736 |
+
# GARCH на остатках SARIMAX
|
| 737 |
+
if ARCH_AVAILABLE and 'sar' in locals():
|
| 738 |
+
try:
|
| 739 |
+
garch_model = fit_garch_on_residuals(sar.resid, p=1, q=1)
|
| 740 |
+
# Прогноз волатильности можно добавить в анализ
|
| 741 |
+
except Exception as e:
|
| 742 |
+
print("GARCH failed:", e)
|
| 743 |
+
|
| 744 |
# 3.6 Diagnostics later for top models
|
| 745 |
# 3.7 Evaluate on test set
|
| 746 |
eval_rows = []
|
|
|
|
| 878 |
'pred': pd.Series(pred_values, index=pred_dates)
|
| 879 |
})
|
| 880 |
|
| 881 |
+
cv_results = evaluate_with_cv({
|
| 882 |
+
'SARIMAX': lambda X, y, h: forecast_recursive(fit_sarimax_simple(y), y, h),
|
| 883 |
+
'AutoARIMA': lambda X, y, h: forecast_recursive(fit_auto_arima(y), y, h)
|
| 884 |
+
}, df_all.drop(columns=[target_col]), df_all[target_col])
|
| 885 |
+
|
| 886 |
|
| 887 |
# -------------------------
|
| 888 |
# helpers used in the pipeline but defined later
|
|
|
|
| 915 |
except Exception as e:
|
| 916 |
print(f"Warning: could not create proper date index: {e}")
|
| 917 |
# Fallback: числовой индекс
|
| 918 |
+
return pd.RangeIndex(start=0, stop=steps)
|
| 919 |
+
|
| 920 |
+
|
| 921 |
+
def forecast_recursive(model, series, steps, freq='D'):
    """Return a ``steps``-long multi-step forecast from a fitted model.

    The previous implementation requested a one-step forecast ``steps`` times
    without ever feeding the prediction back into the model, so it returned
    the same value repeated. It also failed on series whose index is not
    datetime-like. The model is now asked for the whole horizon in one call;
    the model is then responsible for propagating its own earlier forecasts.

    Args:
        model: Fitted model. If it has a ``forecast`` method it is called as
            ``forecast(steps=steps)``; otherwise ``predict(n_periods=steps)``
            is used.
        series: Training series. Kept for interface compatibility; the fitted
            model already holds its own state, so it is not read here.
        steps: Number of periods to forecast. Values ``<= 0`` yield an empty
            array.
        freq: Kept for interface compatibility; unused because only the
            forecast values (not their dates) are returned.

    Returns:
        1-D ``numpy.ndarray`` with ``steps`` forecast values.
    """
    if steps <= 0:
        return np.array([])

    # Prefer ``forecast`` when available: objects that expose both methods
    # do not necessarily accept ``n_periods`` in ``predict``.
    if hasattr(model, 'forecast'):
        pred = model.forecast(steps=steps)
    else:
        pred = model.predict(n_periods=steps)

    # np.asarray drops any pandas index, so positional order is preserved
    # regardless of how the model labels its output.
    return np.asarray(pred).ravel()[:steps]
|
| 937 |
+
|
| 938 |
+
|
| 939 |
+
def forecast_direct(train_series, test_features, model_factory, steps):
    """Forecast with the direct strategy: one model per horizon.

    For every horizon ``h`` in ``1..steps`` a fresh model is trained to map
    the value at time ``t`` to the value at time ``t + h``, and is then asked
    to predict from the last observed value.

    Args:
        train_series: pandas Series with the training values.
        test_features: Unused; kept for interface compatibility.
        model_factory: Zero-argument callable returning an unfitted model with
            ``fit(X, y)`` and ``predict(X)`` methods, where ``X`` is 2-D.
        steps: Number of horizons to forecast. Values ``<= 0`` yield an empty
            array.

    Returns:
        1-D ``numpy.ndarray`` with one prediction per horizon.

    Raises:
        ValueError: If no valid (feature, target) pair exists for a horizon,
            e.g. when the horizon is not shorter than the series.
    """
    values = train_series.values
    last_obs = values[-1:].reshape(1, -1)

    predictions = []
    for h in range(1, steps + 1):
        # Pair x[t] with x[t + h] positionally so features and targets stay
        # aligned, then drop every pair in which either side is missing.
        feats = values[:-h]
        targets = values[h:]
        valid = ~(pd.isna(feats) | pd.isna(targets))
        feats, targets = feats[valid], targets[valid]

        if len(targets) == 0:
            raise ValueError(
                f"Not enough observations to train a model for horizon {h}"
            )

        # Train a dedicated model for horizon h.
        model = model_factory()
        model.fit(feats.reshape(-1, 1), targets)

        # Predict horizon h from the most recent observation.
        predictions.append(model.predict(last_obs)[0])

    return np.array(predictions)
|