Spaces:
Running
Running
# predict.py
# Generates the next-trading-day ETF signal from saved DQN weights.
# Usage:
#   python predict.py --tsl 10 --z 1.1
| import argparse | |
| import json | |
| import os | |
| import shutil | |
| from datetime import datetime, date, timedelta, timezone | |
| import numpy as np | |
| import pandas as pd | |
| import config | |
| from data_download import load_local | |
| from features import build_features | |
| from agent import DQNAgent | |
# Model weights and training metadata, both looked up under config.MODELS_DIR.
WEIGHTS_PATH = os.path.join(config.MODELS_DIR, "dqn_best.pt")
SUMMARY_PATH = os.path.join(config.MODELS_DIR, "training_summary.json")

# Output file for the most recent prediction (relative to the working directory).
PRED_PATH = "latest_prediction.json"
def next_trading_day(from_date=None) -> date:
    """Return the first trading day strictly after ``from_date``.

    The NYSE calendar from ``pandas_market_calendars`` is consulted when that
    package is importable, so exchange holidays are honoured without a
    hardcoded list. If the package is missing or the lookup fails, the
    function falls back to skipping weekends only (holidays are NOT skipped
    in that mode).

    Args:
        from_date: Reference day; defaults to today. A ``datetime`` (or any
            subclass of it) is accepted and reduced to its calendar date.

    Returns:
        A ``datetime.date`` later than ``from_date``.
    """
    start = from_date or date.today()
    # datetime is a subclass of date: comparing it with plain dates raises
    # TypeError, and the fallback below would otherwise return a datetime.
    if isinstance(start, datetime):
        start = start.date()

    try:
        import pandas_market_calendars as mcal

        nyse = mcal.get_calendar("NYSE")
        # Two weeks always contains at least one session after `start`.
        sched = nyse.schedule(
            start_date=start.strftime("%Y-%m-%d"),
            end_date=(start + timedelta(days=14)).strftime("%Y-%m-%d"),
        )
        for stamp in mcal.date_range(sched, frequency="1D"):
            session = stamp.date()
            if session > start:
                return session
    except Exception:
        # Deliberate best-effort: the calendar package is optional, so any
        # failure here degrades to the weekend-only rule below.
        pass

    # Fallback: weekend skip only.
    candidate = start + timedelta(days=1)
    while candidate.weekday() >= 5:
        candidate += timedelta(days=1)
    return candidate
def _q_zscore(q_vals: np.ndarray) -> np.ndarray:
    """Standardise Q-values: subtract their mean, divide by their std.

    A small epsilon (1e-9) is added to the standard deviation so a vector of
    identical Q-values maps to zeros instead of dividing by zero.
    """
    centred = q_vals - q_vals.mean()
    spread = q_vals.std() + 1e-9
    return centred / spread
def download_from_hf():
    """Pull data and model files from the HF dataset repo into local dirs.

    Copies ``data/etf_prices.parquet`` and ``data/macro.parquet`` into
    ``config.DATA_DIR``, and ``models/dqn_best.pt`` plus
    ``models/training_summary.json`` into ``config.MODELS_DIR``.

    The whole operation is best-effort: each file is fetched independently,
    and failures are printed rather than raised, so callers must check
    afterwards whether the files they need actually exist.
    """
    try:
        from huggingface_hub import hf_hub_download

        token = config.HF_TOKEN or None  # empty token -> anonymous access
        os.makedirs(config.DATA_DIR, exist_ok=True)
        os.makedirs(config.MODELS_DIR, exist_ok=True)

        # (label used in log lines, path inside the repo, local destination)
        targets = [
            (
                f"data/{stem}",
                f"data/{stem}.parquet",
                os.path.join(config.DATA_DIR, f"{stem}.parquet"),
            )
            for stem in ("etf_prices", "macro")
        ]
        targets += [
            (
                f"models/{name}",
                f"models/{name}",
                os.path.join(config.MODELS_DIR, name),
            )
            for name in ("dqn_best.pt", "training_summary.json")
        ]

        for label, remote, local in targets:
            try:
                cached = hf_hub_download(
                    repo_id=config.HF_DATASET_REPO,
                    filename=remote,
                    repo_type="dataset",
                    token=token,
                )
                shutil.copy(cached, local)
                # Report success for every file, not only the model files.
                print(f" ✓ {label}")
            except Exception as e:
                # One missing file must not prevent the others from loading.
                print(f" {label}: {e}")
    except Exception as e:
        # Covers a missing huggingface_hub package and directory/config errors.
        print(f" HF download failed: {e}")
def _market_now() -> datetime:
    """Return the current New York wall-clock time as a naive datetime."""
    try:
        from zoneinfo import ZoneInfo  # stdlib (Python 3.9+)

        return datetime.now(ZoneInfo("America/New_York")).replace(tzinfo=None)
    except Exception:
        # zoneinfo or the tz database is unavailable: fall back to a fixed
        # UTC-5 offset (correct for EST, one hour off during daylight saving).
        return datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=5)


def _signal_date(now_ny: datetime) -> date:
    """Return today before 16:00 on a trading day, else the next trading day."""
    today = now_ny.date()
    try:
        import pandas_market_calendars as mcal

        day = today.strftime("%Y-%m-%d")
        sched = mcal.get_calendar("NYSE").schedule(start_date=day, end_date=day)
        is_trading_day = not sched.empty
    except Exception:
        # Calendar package is optional; approximate with a weekday check.
        is_trading_day = today.weekday() < 5
    if is_trading_day and now_ny.hour < 16:
        return today
    return next_trading_day(today)


def run_predict(tsl_pct: float = config.DEFAULT_TSL_PCT,
                z_reentry: float = config.DEFAULT_Z_REENTRY) -> dict:
    """Produce the next signal from the saved DQN weights and persist it.

    Loads local data and weights (downloading them when absent), builds the
    feature window, queries the agent for Q-values, applies the z-score
    re-entry rule, and writes the result as JSON to ``PRED_PATH``.

    Args:
        tsl_pct: Trailing-stop threshold in percent; the stop is flagged when
            the chosen asset's two-day return is at or below ``-tsl_pct``.
        z_reentry: Minimum z-score of the best action required to hold a
            non-cash position.

    Returns:
        The prediction dict that was written to disk, or an empty dict when
        data or weights could not be obtained.
    """
    print(f"\n{'='*60}")
    print(f" Predict — {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*60}")

    # ── Ensure data + weights ──
    data = load_local()
    if not data:
        print(" No local data — downloading from HF...")
        download_from_hf()
        data = load_local()
    if not data:
        print(" ERROR: No data available.")
        return {}
    if not os.path.exists(WEIGHTS_PATH):
        print(" No local weights — downloading from HF...")
        download_from_hf()
    if not os.path.exists(WEIGHTS_PATH):
        print(" ERROR: No weights available.")
        return {}

    # ── Load training metadata ──
    trained_from_year = None
    trained_at = None
    lookback = config.LOOKBACK_WINDOW
    if os.path.exists(SUMMARY_PATH):
        with open(SUMMARY_PATH, encoding="utf-8") as f:
            summary = json.load(f)
        trained_from_year = summary.get("start_year")
        trained_at = summary.get("trained_at")
        # Prefer the lookback recorded at training time over the config value.
        lookback = summary.get("lookback", config.LOOKBACK_WINDOW)

    # ── Build features ──
    etf_prices = data["etf_prices"]
    macro = data["macro"]
    feat_df = build_features(etf_prices, macro)
    n_features = feat_df.shape[1]

    # ── Load agent ──
    # State layout: flattened feature window followed by a one-hot position
    # vector of length N_ACTIONS (per the original note, this mirrors env.py).
    state_size = n_features * lookback + config.N_ACTIONS
    agent = DQNAgent(state_size=state_size)
    agent.load(WEIGHTS_PATH)

    # ── Build current state (last `lookback` rows) ──
    window = feat_df.iloc[-lookback:].values.astype(np.float32)
    if len(window) < lookback:
        # Not enough history: left-pad with zero rows to the expected length.
        pad = np.zeros((lookback - len(window), n_features), dtype=np.float32)
        window = np.vstack([pad, window])
    # Inference assumes the portfolio currently sits in CASH (action index 0).
    position = np.zeros(config.N_ACTIONS, dtype=np.float32)
    position[0] = 1.0
    state = np.concatenate([window.flatten(), position])

    # ── Inference ──
    q_values = agent.q_values(state)
    z_scores = _q_zscore(q_values)
    best_idx = int(q_values.argmax())
    best_z = float(z_scores[best_idx])

    # T-bill rate: latest macro value when present, otherwise a fixed default.
    tbill_rate = 3.6
    if "macro_TBILL_3M" in feat_df.columns:
        latest_tbill = feat_df["macro_TBILL_3M"].iloc[-1]
        if not np.isnan(latest_tbill):
            tbill_rate = float(latest_tbill)

    # TSL / re-entry check over the last two daily moves of the chosen asset.
    tsl_triggered = False
    in_cash = False
    two_day_ret = 0.0
    if best_idx != 0:
        etf = config.ACTIONS[best_idx]
        if etf in etf_prices.columns:
            # Three closes span two daily returns.
            recent = etf_prices[etf].iloc[-3:]
            two_day_ret = float((recent.iloc[-1] / recent.iloc[0]) - 1) * 100
            if two_day_ret <= -tsl_pct:
                tsl_triggered = True
        if best_z < z_reentry:
            in_cash = True
    final_signal = "CASH" if in_cash else config.ACTIONS[best_idx]

    # Signal date: based on New York time so the 16:00 cutoff tracks DST.
    signal_date = _signal_date(_market_now())

    # Per-action probabilities: softmax of Q-values, for display only.
    temperature = 0.1
    exp_q = np.exp((q_values - q_values.max()) / temperature)
    probs = exp_q / exp_q.sum()
    prob_dict = {config.ACTIONS[i]: round(float(probs[i]), 4)
                 for i in range(config.N_ACTIONS)}

    output = dict(
        as_of_date=str(signal_date),
        final_signal=final_signal,
        final_confidence=round(float(probs[best_idx]), 4),
        z_score=round(best_z, 3),
        q_values={config.ACTIONS[i]: round(float(q_values[i]), 4)
                  for i in range(config.N_ACTIONS)},
        probabilities=prob_dict,
        tbill_rate=round(tbill_rate, 3),
        tsl_status=dict(
            two_day_cumul_pct=round(two_day_ret, 2),
            tsl_triggered=tsl_triggered,
            in_cash=in_cash,
            z_reentry=z_reentry,
            tsl_pct=tsl_pct,
        ),
        trained_from_year=trained_from_year,
        trained_at=trained_at,
    )
    with open(PRED_PATH, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, default=str)

    print(f"\n Signal date : {signal_date}")
    print(f" Final signal : {final_signal}")
    print(f" Z-score : {best_z:.2f}σ")
    print(f" Confidence : {float(probs[best_idx]):.1%}")
    # prob_dict was built in action-index order, so enumerate lines up with q_values.
    for i, (act, p) in enumerate(prob_dict.items()):
        print(f" {act:<8} Q={q_values[i]:.3f} p={p:.3f}")
    print(f"\n Saved → {PRED_PATH}")
    return output
if __name__ == "__main__":
    # Command-line entry point: both flags default to the project-wide
    # settings in config and are forwarded unchanged to run_predict.
    cli = argparse.ArgumentParser()
    cli.add_argument("--tsl", type=float, default=config.DEFAULT_TSL_PCT)
    cli.add_argument("--z", type=float, default=config.DEFAULT_Z_REENTRY)
    opts = cli.parse_args()
    run_predict(tsl_pct=opts.tsl, z_reentry=opts.z)