# Source path: copper-mind/deep_learning/data/future_frame.py
# Repository page header (author avatar: ifieryarrows)
# Commit message: Sync from GitHub (tests passed)
# Commit: d317049 (verified)
"""Future decoder row construction for TFT inference."""
from __future__ import annotations
import pandas as pd
from deep_learning.config import TFTASROConfig
from deep_learning.data.feature_store import _build_calendar_features
from deep_learning.data.regime_features import REGIME_FEATURES
def _future_business_dates(last_date, horizon: int) -> pd.DatetimeIndex:
    """Return the ``horizon`` business days that strictly follow ``last_date``.

    Args:
        last_date: Anything ``pd.Timestamp`` accepts; its time-of-day is
            discarded before stepping forward.
        horizon: Number of business days to generate.

    Returns:
        A business-day ``DatetimeIndex`` of length ``horizon`` whose first
        entry is the first business day after ``last_date``.
    """
    # Normalise to midnight so the generated index carries no intraday time.
    anchor = pd.Timestamp(last_date).normalize()
    # BDay(1) skips weekends: a Friday (or a weekend date) moves to Monday.
    first_future_day = anchor + pd.offsets.BDay(1)
    return pd.bdate_range(start=first_future_day, periods=horizon)
def build_future_decoder_rows(
    history_df: pd.DataFrame,
    horizon: int,
    cfg: TFTASROConfig,
) -> pd.DataFrame:
    """
    Build no-lookahead future decoder rows.

    Known future values are limited to calendar features. Unknown future
    covariates receive deterministic placeholders that never use future price,
    news, embedding, or event information.

    The last row of ``history_df`` is replicated ``horizon`` times onto the
    business days following ``history_df.index.max()`` and then selectively
    overwritten:

    * ``time_idx`` continues the integer sequence of the history.
    * Target columns named by ``cfg.forecast`` plus ``realized_vol_20d`` and
      ``material_move_5d`` are set to ``0.0``.
    * Columns also produced by ``_build_calendar_features`` are recomputed
      for the future dates.
    * News / sentiment / embedding / event columns and any column whose name
      contains ``ret``, ``roc`` or ``momentum`` are set to ``0.0``.
    * ``REGIME_FEATURES`` (except ``event_shock_score``) keep the last
      observed value.
    * ``days_since_last_material_news`` keeps counting up by one per future
      row and ``stale_sentiment_flag`` is derived from it (``>= 3``).

    Every other column simply repeats the last observed value.

    Args:
        history_df: Encoder history; must be non-empty. NOTE(review): the
            code takes the *last row* as the template but the *maximum index*
            as the anchor date, so it presumably expects the frame to be
            sorted chronologically -- confirm with callers.
        horizon: Number of future business-day rows to build; must be >= 1.
        cfg: Configuration; only ``cfg.forecast.model_daily_target_col``,
            ``cfg.forecast.auxiliary_target_col`` and
            ``cfg.forecast.primary_target_col`` are read.

    Returns:
        A new DataFrame with ``horizon`` rows, indexed by the future business
        dates, with exactly the columns of ``history_df`` in the same order.

    Raises:
        ValueError: If ``history_df`` is empty or ``horizon`` is smaller
            than 1.
    """
    if history_df.empty:
        raise ValueError("history_df must contain at least one encoder row")
    if horizon < 1:
        # Without this guard horizon == 0 builds a column-less frame and the
        # final column selection fails with an unrelated KeyError.
        raise ValueError(f"horizon must be a positive integer, got {horizon!r}")

    last = history_df.iloc[-1].copy()
    future_index = _future_business_dates(history_df.index.max(), horizon)
    # Start from `horizon` copies of the last observed row; everything below
    # overwrites the columns that must not simply be carried forward.
    # (group_id needs no handling: it is inherited from the replicated row.)
    future = pd.DataFrame([last.to_dict()] * horizon, index=future_index)

    if "time_idx" in future.columns:
        start_idx = int(history_df["time_idx"].iloc[-1]) + 1
        future["time_idx"] = range(start_idx, start_idx + horizon)

    # Targets are unknown in the future: fill with a neutral placeholder.
    target_cols = {
        cfg.forecast.model_daily_target_col,
        cfg.forecast.auxiliary_target_col,
        cfg.forecast.primary_target_col,
        "realized_vol_20d",
        "material_move_5d",
    }
    for col in target_cols:
        if col in future.columns:
            future[col] = 0.0

    # Calendar features are the only covariates genuinely known ahead of
    # time, so they are recomputed for the future dates.
    calendar = _build_calendar_features(future_index)
    for col in calendar.columns:
        if col in future.columns:
            future[col] = calendar[col].values

    # News-, sentiment- and event-derived columns are neutralised, as are
    # columns whose name suggests a return / rate-of-change / momentum value.
    neutral_exact = {
        "sentiment_index",
        "news_count",
        "material_news_count",
        "after_close_news_count",
        "event_shock_score",
        "sentiment_x_supply_shock",
        "sentiment_x_usd_pressure",
        "sentiment_x_risk_on",
        "event_shock_x_high_vol",
    }
    neutral_prefixes = ("emb_pca_", "evt_")
    neutral_tokens = ("ret", "roc", "momentum")
    for col in future.columns:
        lower = col.lower()
        is_news_derived = col in neutral_exact or col.startswith(neutral_prefixes)
        is_price_dynamics = any(token in lower for token in neutral_tokens)
        if is_news_derived or is_price_dynamics:
            future[col] = 0.0

    # Regime features persist at their last observed value. This runs after
    # the neutralisation above, so a regime column that was zeroed there is
    # restored, except event_shock_score, which stays neutral.
    for col in REGIME_FEATURES:
        if col in future.columns and col != "event_shock_score":
            future[col] = float(last[col])

    # 999.0 is the "no material news on record" sentinel. A NaN in the last
    # row is treated the same way; otherwise NaN would leak into the future
    # rows and the stale flag below would wrongly read as "fresh".
    prev_days = float(last.get("days_since_last_material_news", 999.0))
    if pd.isna(prev_days):
        prev_days = 999.0
    if "days_since_last_material_news" in future.columns:
        # No future news is assumed, so the counter grows by one per row.
        future["days_since_last_material_news"] = [
            prev_days + step for step in range(1, horizon + 1)
        ]
    if "stale_sentiment_flag" in future.columns:
        days_since = future.get(
            "days_since_last_material_news",
            pd.Series(999.0, index=future.index),
        )
        future["stale_sentiment_flag"] = (days_since >= 3).astype(float)

    # Return exactly the history's columns, in the history's order.
    return future[history_df.columns].copy()