# P2-ETF-DQN-ENGINE / predict.py
# Last commit: c309d08 by P2SAMAPA (verified) — "[auto] Deploy to HF Space from GitHub"
# predict.py
# Generates next-trading-day ETF signal from saved DQN weights.
# Usage:
# python predict.py --tsl 10 --z 1.1
import argparse
import json
import os
import shutil
from datetime import datetime, date, timedelta, timezone
import numpy as np
import pandas as pd
import config
from data_download import load_local
from features import build_features
from agent import DQNAgent
# Artifact locations: trained weights and the training summary live under
# config.MODELS_DIR; the prediction JSON is written to the working directory.
WEIGHTS_PATH, SUMMARY_PATH = (
    os.path.join(config.MODELS_DIR, fname)
    for fname in ("dqn_best.pt", "training_summary.json")
)
PRED_PATH = "latest_prediction.json"
def next_trading_day(from_date=None) -> date:
    """Return the first NYSE trading day strictly after ``from_date``.

    Uses ``pandas_market_calendars`` so exchange holidays are honoured without
    a hard-coded holiday list. If the calendar package cannot be imported or
    the lookup fails, falls back to skipping weekends only (holidays are NOT
    skipped in that case) and prints a warning saying so.

    Args:
        from_date: Reference date; defaults to ``date.today()`` when None.

    Returns:
        The next trading day as a ``datetime.date``.
    """
    start = from_date or date.today()
    try:
        import pandas_market_calendars as mcal
        nyse = mcal.get_calendar("NYSE")
        # A 14-day window comfortably spans any run of consecutive closures.
        sched = nyse.schedule(
            start_date=start.strftime("%Y-%m-%d"),
            end_date=(start + timedelta(days=14)).strftime("%Y-%m-%d"),
        )
        # The schedule has one row per session, indexed by session date, so
        # no timezone conversion of open/close timestamps is needed.
        for session in sched.index:
            d = session.date()
            if d > start:
                return d
    except Exception as exc:  # deliberate best-effort: fall back below
        print(f" WARN: NYSE calendar unavailable ({exc}); skipping weekends only")
    # Fallback: weekend skip only
    d = start + timedelta(days=1)
    while d.weekday() >= 5:
        d += timedelta(days=1)
    return d
def _q_zscore(q_vals: np.ndarray) -> np.ndarray:
mu = q_vals.mean()
std = q_vals.std() + 1e-9
return (q_vals - mu) / std
def download_from_hf():
    """Pull price/macro data and model artifacts from the HF Dataset repo.

    Copies ``data/etf_prices.parquet`` and ``data/macro.parquet`` into
    ``config.DATA_DIR`` and ``models/dqn_best.pt`` /
    ``models/training_summary.json`` into ``config.MODELS_DIR``, overwriting
    any local copies.

    Best-effort: a failure on an individual file is printed and that file is
    skipped; a failure to import ``huggingface_hub`` (or to read config) is
    printed rather than raised. Every file reports success or failure.
    """
    try:
        from huggingface_hub import hf_hub_download
        token = config.HF_TOKEN or None
        os.makedirs(config.DATA_DIR, exist_ok=True)
        os.makedirs(config.MODELS_DIR, exist_ok=True)
        # (path inside the dataset repo, local directory to copy it into)
        targets = [
            ("data/etf_prices.parquet", config.DATA_DIR),
            ("data/macro.parquet", config.DATA_DIR),
            ("models/dqn_best.pt", config.MODELS_DIR),
            ("models/training_summary.json", config.MODELS_DIR),
        ]
        for remote, local_dir in targets:
            try:
                # hf_hub_download returns a path in the HF cache; copy it to
                # the location the rest of the pipeline reads from.
                cached = hf_hub_download(repo_id=config.HF_DATASET_REPO,
                                         filename=remote,
                                         repo_type="dataset", token=token)
                shutil.copy(cached, os.path.join(local_dir, os.path.basename(remote)))
                print(f" ✓ {remote}")
            except Exception as e:
                print(f" {remote}: {e}")
    except Exception as e:
        print(f" HF download failed: {e}")
def _eastern_now() -> datetime:
    """Return the current New York wall-clock time as a naive datetime (DST-aware)."""
    # A fixed UTC-5 offset is an hour off during daylight saving time.
    return pd.Timestamp.now(tz="America/New_York").tz_localize(None).to_pydatetime()


def _signal_date(now_et: datetime) -> date:
    """Return the session a signal applies to: today if NYSE trades today and it is before 16:00 ET, else the next session."""
    today = now_et.date()
    try:
        import pandas_market_calendars as mcal
        sched = mcal.get_calendar("NYSE").schedule(
            start_date=today.strftime("%Y-%m-%d"),
            end_date=today.strftime("%Y-%m-%d"),
        )
        is_trading_day = not sched.empty
    except Exception:
        # Calendar unavailable: treat every weekday as a trading day.
        is_trading_day = today.weekday() < 5
    if is_trading_day and now_et.hour < 16:
        return today
    return next_trading_day(today)


def _build_state(feat_df: pd.DataFrame, lookback: int) -> np.ndarray:
    """Flatten the last ``lookback`` feature rows (zero-padded at the front) and append a CASH one-hot position."""
    window = feat_df.iloc[-lookback:].values.astype(np.float32)
    if len(window) < lookback:
        pad = np.zeros((lookback - len(window), feat_df.shape[1]), dtype=np.float32)
        window = np.vstack([pad, window])
    # Position one-hot over N_ACTIONS; inference assumes we start from CASH (index 0).
    position = np.zeros(config.N_ACTIONS, dtype=np.float32)
    position[0] = 1.0  # CASH
    return np.concatenate([window.flatten(), position])


def _tsl_status(etf_prices: pd.DataFrame, best_idx: int, best_z: float,
                tsl_pct: float, z_reentry: float):
    """Trailing-stop check for the chosen ETF; returns (two_day_ret_pct, tsl_triggered, in_cash)."""
    if best_idx == 0:  # the model already picked CASH
        return 0.0, False, False
    etf = config.ACTIONS[best_idx]
    if etf not in etf_prices.columns:
        return 0.0, False, False
    # Three most recent *valid* closes span two daily returns; dropping NaNs
    # avoids a NaN return (and invalid JSON) when the latest row is missing.
    last3 = etf_prices[etf].dropna().iloc[-3:]
    if len(last3) < 2:
        return 0.0, False, False
    two_day_ret = float(last3.iloc[-1] / last3.iloc[0] - 1) * 100
    tsl_triggered = two_day_ret <= -tsl_pct
    # After a stop-out, stay in cash until conviction clears the re-entry bar.
    in_cash = tsl_triggered and best_z < z_reentry
    return two_day_ret, tsl_triggered, in_cash


def run_predict(tsl_pct: float = config.DEFAULT_TSL_PCT,
                z_reentry: float = config.DEFAULT_Z_REENTRY) -> dict:
    """Generate the next-session ETF signal from saved DQN weights.

    Ensures data and weights exist locally (pulling them from the HF Dataset
    otherwise), builds the latest state, scores it with the agent, applies the
    trailing-stop / Z-score re-entry gate, writes the result to PRED_PATH and
    prints a summary.

    Args:
        tsl_pct: Trailing-stop threshold in percent; a two-day cumulative
            return at or below ``-tsl_pct`` triggers the stop.
        z_reentry: Minimum Z-score of the best Q-value required to stay
            invested once the stop has triggered.

    Returns:
        The prediction dict written to PRED_PATH, or ``{}`` when data or
        weights are unavailable.
    """
    print(f"\n{'='*60}")
    print(f" Predict — {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*60}")

    # ── Ensure data + weights ─────────────────────────────────────────────────
    data = load_local()
    if not data:
        print(" No local data — downloading from HF...")
        download_from_hf()
        data = load_local()
    if not data:
        print(" ERROR: No data available.")
        return {}
    if not os.path.exists(WEIGHTS_PATH):
        print(" No local weights — downloading from HF...")
        download_from_hf()
    if not os.path.exists(WEIGHTS_PATH):
        print(" ERROR: No weights available.")
        return {}

    # ── Load training metadata ────────────────────────────────────────────────
    trained_from_year = None
    trained_at = None
    lookback = config.LOOKBACK_WINDOW
    if os.path.exists(SUMMARY_PATH):
        with open(SUMMARY_PATH, encoding="utf-8") as f:
            summary = json.load(f)
        trained_from_year = summary.get("start_year")
        trained_at = summary.get("trained_at")
        # A null/0 lookback would break the window slice; fall back to config.
        lookback = int(summary.get("lookback") or config.LOOKBACK_WINDOW)

    # ── Build features ────────────────────────────────────────────────────────
    etf_prices = data["etf_prices"]
    feat_df = build_features(etf_prices, data["macro"])

    # ── Load agent ────────────────────────────────────────────────────────────
    # state_size must match the training environment's layout:
    # flattened feature window + one-hot position (N_ACTIONS).
    state_size = feat_df.shape[1] * lookback + config.N_ACTIONS
    agent = DQNAgent(state_size=state_size)
    agent.load(WEIGHTS_PATH)

    # ── Inference ─────────────────────────────────────────────────────────────
    state = _build_state(feat_df, lookback)
    q_values = np.asarray(agent.q_values(state), dtype=np.float64).reshape(-1)
    z_scores = _q_zscore(q_values)
    best_idx = int(q_values.argmax())
    best_z = float(z_scores[best_idx])

    # T-bill rate: fallback used when the macro series is absent or NaN
    # (presumably annualised percent — confirm against the macro source).
    tbill_rate = 3.6
    if "macro_TBILL_3M" in feat_df.columns:
        val = feat_df["macro_TBILL_3M"].iloc[-1]
        if pd.notna(val):
            tbill_rate = float(val)

    two_day_ret, tsl_triggered, in_cash = _tsl_status(
        etf_prices, best_idx, best_z, tsl_pct, z_reentry)
    final_signal = "CASH" if in_cash else config.ACTIONS[best_idx]

    signal_date = _signal_date(_eastern_now())

    # Per-action probabilities: temperature softmax of Q-values, for display only.
    temperature = 0.1
    exp_q = np.exp((q_values - q_values.max()) / temperature)
    probs = exp_q / exp_q.sum()
    prob_dict = {config.ACTIONS[i]: round(float(probs[i]), 4)
                 for i in range(config.N_ACTIONS)}

    output = dict(
        as_of_date=str(signal_date),
        final_signal=final_signal,
        final_confidence=round(float(probs[best_idx]), 4),
        z_score=round(best_z, 3),
        q_values={config.ACTIONS[i]: round(float(q_values[i]), 4)
                  for i in range(config.N_ACTIONS)},
        probabilities=prob_dict,
        tbill_rate=round(tbill_rate, 3),
        tsl_status=dict(
            two_day_cumul_pct=round(two_day_ret, 2),
            tsl_triggered=tsl_triggered,
            in_cash=in_cash,
            z_reentry=z_reentry,
            tsl_pct=tsl_pct,
        ),
        trained_from_year=trained_from_year,
        trained_at=trained_at,
    )
    with open(PRED_PATH, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, default=str)

    print(f"\n Signal date : {signal_date}")
    print(f" Final signal : {final_signal}")
    print(f" Z-score : {best_z:.2f}σ")
    print(f" Confidence : {float(probs[best_idx]):.1%}")
    # prob_dict was built in action-index order, so i is the Q-value index.
    for i, (act, p) in enumerate(prob_dict.items()):
        print(f" {act:<8} Q={q_values[i]:.3f} p={p:.3f}")
    print(f"\n Saved → {PRED_PATH}")
    return output
if __name__ == "__main__":
    # CLI: both knobs are floats and default to the values in config.
    cli = argparse.ArgumentParser()
    for flag, fallback in (("--tsl", config.DEFAULT_TSL_PCT),
                           ("--z", config.DEFAULT_Z_REENTRY)):
        cli.add_argument(flag, type=float, default=fallback)
    opts = cli.parse_args()
    run_predict(tsl_pct=opts.tsl, z_reentry=opts.z)