"""
Predictor — loads saved XGBoost + LightGBM models and generates forecasts at
inference time. Runs entirely on CPU.

Usage:
    python model/predictor.py --symbol ZW=F
    python model/predictor.py --all
"""
import argparse
import json
import logging
import pickle
import sys
from datetime import date, datetime, timedelta
from pathlib import Path

import numpy as np
import pandas as pd

sys.path.insert(0, str(Path(__file__).parent.parent))

from model.feature_builder import build_prediction_features
from data.db import get_conn

log = logging.getLogger(__name__)

MODELS_DIR = Path(__file__).parent.parent / "models"

SYMBOL_NAMES: dict[str, str] = {
    "CL=F": "Crude Oil",
    "NG=F": "Natural Gas",
    "GC=F": "Gold",
    "ZW=F": "Wheat",
    "ZC=F": "Corn",
    "ZS=F": "Soybeans",
    "CT=F": "Cotton",
    "SB=F": "Sugar",
    "USDINR=X": "USD/INR",
    "HG=F": "Copper",
}

# Human-readable labels for SHAP feature display
FEATURE_LABELS: dict[str, str] = {
    "rsi_14": "RSI (14-day)",
    "macd_signal": "MACD crossover",
    "bb_position": "Bollinger Band position",
    "atr_14": "Average True Range",
    "atr_pct": "Volatility %",
    "sma_20_50_cross": "SMA 20/50 crossover",
    "return_1d": "1-day return %",
    "return_7d": "7-day return %",
    "return_30d": "30-day return %",
    "momentum_score": "Momentum score",
    "month_sin": "Seasonal cycle (sin)",
    "month_cos": "Seasonal cycle (cos)",
    "harvest_season_flag": "Harvest season",
    "days_to_opec_meeting": "Days to OPEC meeting",
    "oil_gold_ratio": "Oil/Gold ratio",
    "dxy_proxy": "USD strength proxy",
    "sentiment_score_1d": "News sentiment (24h)",
    "sentiment_3d": "News sentiment (3-day)",
    "sentiment_7d": "News sentiment (7-day)",
    "article_count_7d": "Article volume (7-day)",
    "positive_ratio_7d": "Positive news ratio",
    "bullish_events_7d": "Bullish events (7-day)",
    "bearish_events_7d": "Bearish events (7-day)",
    "max_severity_7d": "Max event severity",
    "direction_score_7d": "Net event direction",
    "supply_shock_flag": "Supply shock detected",
    "policy_change_flag": "Policy change detected",
    "risk_score_7d": "Geopolitical risk (7-day)",
    "risk_score_30d": "Geopolitical risk (30-day)",
    "drought_index": "Drought index",
    "heat_stress_days": "Heat stress days",
    "precip_anomaly_pct": "Precipitation anomaly %",
}

# Expected return (in percent) by predicted direction; used to centre the
# 7d price range. The same base values apply to every commodity.
DIRECTION_EXPECTED_RETURN: dict[str, float] = {
    "UP": 3.0,
    "STABLE": 0.0,
    "DOWN": -3.0,
}

# Encoded class index -> direction label (0=DOWN, 1=STABLE, 2=UP).
# NOTE(review): must match the label encoding used at training time.
DIRECTION_MAP: dict[int, str] = {0: "DOWN", 1: "STABLE", 2: "UP"}

# Symbols where the 7d model has known accuracy issues — surface a warning.
UNRELIABLE_7D: frozenset[str] = frozenset({"ZC=F", "HG=F"})

# ── model cache (loaded once per process) ─────────────────────────────────────
# Keyed by "<symbol>_<horizon>". Only successful loads are cached, so a model
# trained after a miss is picked up on the next call.
_model_cache: dict[str, dict] = {}


def _load_models(symbol: str, horizon: str = "7d") -> dict | None:
    """
    Load XGBoost, LightGBM, scaler, and feature names for a symbol/horizon.

    Caches the bundle in memory for the process lifetime.

    Returns:
        Dict with keys "xgb", "lgbm", "scaler", "features", or None if any of
        the four artifact files is missing (model not trained yet).
    """
    cache_key = f"{symbol}_{horizon}"
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    xgb_path = MODELS_DIR / f"xgb_{symbol}_{horizon}.pkl"
    lgbm_path = MODELS_DIR / f"lgbm_{symbol}_{horizon}.pkl"
    scaler_path = MODELS_DIR / f"scaler_{symbol}_{horizon}.pkl"
    feat_path = MODELS_DIR / f"feature_names_{symbol}_{horizon}.json"

    if not all(p.exists() for p in [xgb_path, lgbm_path, scaler_path, feat_path]):
        log.warning("Models not found for %s %s — run model/trainer.py first", symbol, horizon)
        return None

    # SECURITY: pickle.load executes arbitrary code. These files must only come
    # from our own trainer output in MODELS_DIR — never from untrusted sources.
    with open(xgb_path, "rb") as f:
        xgb_model = pickle.load(f)
    with open(lgbm_path, "rb") as f:
        lgbm_model = pickle.load(f)
    with open(scaler_path, "rb") as f:
        scaler = pickle.load(f)
    with open(feat_path, encoding="utf-8") as f:
        feature_names = json.load(f)

    bundle = {
        "xgb": xgb_model,
        "lgbm": lgbm_model,
        "scaler": scaler,
        "features": feature_names,
    }
    _model_cache[cache_key] = bundle
    return bundle


def _get_shap_top5(xgb_model, X_row: np.ndarray, feature_names: list[str], pred_class: int) -> list[dict]:
    """
    Compute SHAP values for the XGBoost model and return the top 5 features by
    |shap_value| for the predicted class.

    Best-effort: any failure (including `shap` not being installed) is logged
    at debug level and an empty list is returned.

    Note: each entry's "value" is taken from `X_row`, i.e. whatever matrix the
    caller passes (the scaled feature row in `predict`), not the raw feature.
    """
    try:
        import shap

        explainer = shap.TreeExplainer(xgb_model)
        shap_vals = explainer.shap_values(X_row)

        # The shap_values layout varies by shap/XGBoost version: a list with one
        # (1, n_features) array per class, a 3-D (1, n_features, n_classes)
        # array, or a 2-D (1, n_features) array.
        if isinstance(shap_vals, list):
            vals = shap_vals[pred_class][0]
        else:
            vals = shap_vals[0, :, pred_class] if shap_vals.ndim == 3 else shap_vals[0]

        top_idx = np.argsort(np.abs(vals))[::-1][:5]

        result = []
        for i in top_idx:
            fname = feature_names[i] if i < len(feature_names) else f"feature_{i}"
            fval = float(X_row[0][i])
            shap_v = float(vals[i])
            result.append({
                "feature": fname,
                "label": FEATURE_LABELS.get(fname, fname),
                "value": round(fval, 4),
                "impact": "BULLISH" if shap_v > 0 else "BEARISH",
                "weight": round(abs(shap_v), 4),
            })
        return result

    except Exception as exc:
        log.debug("SHAP error: %s", exc)
        return []


def _get_current_price(symbol: str, as_of: str | None = None) -> tuple[float, float]:
    """
    Return (close, atr_pct) from the latest non-NULL price row(s) in the DB.

    Args:
        symbol: Commodity ticker.
        as_of:  Optional ISO date. When given, only rows dated on or before it
                are considered, so historical forecasts use the price of that
                day rather than today's.

    Returns:
        (0.0, 0.02) when no price row exists. atr_pct is a rough volatility
        proxy, |latest - previous| / latest, defaulting to 0.02 when fewer
        than two rows exist or the latest close is not positive.
    """
    if as_of is None:
        sql = (
            "SELECT close FROM prices WHERE symbol = ? AND close IS NOT NULL "
            "ORDER BY date DESC LIMIT 2"
        )
        params: list = [symbol]
    else:
        sql = (
            "SELECT close FROM prices WHERE symbol = ? AND close IS NOT NULL "
            "AND date <= ? ORDER BY date DESC LIMIT 2"
        )
        params = [symbol, as_of]

    conn = get_conn()
    try:
        rows = conn.execute(sql, params).fetchall()
    finally:
        # Close even if the query raises, so connections are not leaked.
        conn.close()

    if not rows:
        return 0.0, 0.02

    close = float(rows[0][0])
    # Rough ATR proxy: |today - yesterday| / today
    if len(rows) > 1 and close > 0:
        atr_pct = abs(close - float(rows[1][0])) / close
    else:
        atr_pct = 0.02
    return close, atr_pct


def _ensemble_predict(bundle: dict, features: pd.Series) -> tuple[int, np.ndarray, np.ndarray]:
    """
    Run the XGBoost + LightGBM soft-voting ensemble for one horizon.

    Features are aligned to the bundle's trained feature order; features the
    model expects but `features` lacks are filled with 0.

    Returns:
        (predicted_class_index, averaged_class_probabilities, scaled_feature_row)
    """
    feat_names = bundle["features"]
    x_raw = features.reindex(feat_names, fill_value=0).values.reshape(1, -1)
    x_scaled = bundle["scaler"].transform(pd.DataFrame(x_raw, columns=feat_names))
    x_df = pd.DataFrame(x_scaled, columns=feat_names)

    xgb_proba = bundle["xgb"].predict_proba(x_df)[0]
    lgbm_proba = bundle["lgbm"].predict_proba(x_df)[0]
    ensemble_proba = (xgb_proba + lgbm_proba) / 2
    return int(ensemble_proba.argmax()), ensemble_proba, x_scaled


def _confidence(prob: float) -> str:
    """
    Map an ensemble probability to a base confidence tier.

    Thresholds were tuned via a 3.5yr walk-forward backtest: HIGH fires ~9% of
    the time with ~74% accuracy vs 45% overall.
    """
    if prob >= 0.55:
        return "HIGH"
    if prob >= 0.45:
        return "MEDIUM"
    return "LOW"


def _confirmed_confidence(prob: float, direction: str, feat: pd.Series) -> str:
    """
    Confidence tier gated on independent confirming signals.

    A LOW base tier (see `_confidence`) stays LOW. Otherwise the result depends
    only on how many confirming signals agree with `direction`:
    2+ -> HIGH, 1 -> MEDIUM, 0 -> LOW. A MEDIUM base can therefore be promoted
    to HIGH.

    Only "UP"/"DOWN" can be confirmed, so a "STABLE" direction always
    yields LOW. Missing or falsy feature values count as 0.
    """
    if _confidence(prob) == "LOW":
        return "LOW"

    confirming = 0

    # Signal 1: price momentum agrees
    mom = float(feat.get("momentum_score", 0) or 0)
    ret7 = float(feat.get("return_7d", 0) or 0)
    if direction == "UP" and (mom > 0 or ret7 > 0):
        confirming += 1
    if direction == "DOWN" and (mom < 0 or ret7 < 0):
        confirming += 1

    # Signal 2: COT commercial positioning agrees (commercials = smart money)
    cot_net = float(feat.get("cot_commercial_net_pct", 0) or 0)
    cot_chg = float(feat.get("cot_commercial_chg_1w", 0) or 0)
    if direction == "UP" and (cot_net > 0.05 or cot_chg > 0):
        confirming += 1
    if direction == "DOWN" and (cot_net < -0.05 or cot_chg < 0):
        confirming += 1

    # Signal 3: EIA supply signal agrees (for CL=F and NG=F). Crude features
    # take precedence; natgas features are used when crude ones are falsy.
    eia_draw = float(feat.get("eia_crude_draw", 0) or feat.get("eia_natgas_draw", 0) or 0)
    eia_vs5yr = float(feat.get("eia_crude_vs_5yr", 0) or feat.get("eia_natgas_vs_5yr", 0) or 0)
    if direction == "UP" and (eia_draw > 0 or eia_vs5yr < -0.5):
        confirming += 1
    if direction == "DOWN" and eia_vs5yr > 0.5:
        confirming += 1

    # Signal 4: USDA crop condition trend agrees (for grain/ag symbols)
    crop_chg = float(feat.get("usda_crop_good_exc_chg", 0) or 0)
    if direction == "DOWN" and crop_chg < -2:
        confirming += 1
    if direction == "UP" and crop_chg > 2:
        confirming += 1

    # 2+ signals -> HIGH; 0 signals -> LOW; 1 signal -> MEDIUM
    if confirming >= 2:
        return "HIGH"
    if confirming == 0:
        return "LOW"
    return "MEDIUM"


# ── public API ─────────────────────────────────────────────────────────────────

def predict(symbol: str, as_of_date: str | None = None) -> dict:
    """
    Generate a forecast for a single commodity.

    Args:
        symbol:     Commodity ticker, e.g. "ZW=F"
        as_of_date: ISO date string. Defaults to today. When given explicitly,
                    the current price is read as of that date as well.

    Returns:
        Forecast dict with symbol, current price, 7d + 30d forecasts,
        top_signals, and confidence levels. Returns an error dict
        ({"symbol", "error", "as_of_date"}) if the 7d models are not trained
        or no features could be built.
    """
    as_of = as_of_date or date.today().isoformat()

    bundle_7d = _load_models(symbol, "7d")
    bundle_30d = _load_models(symbol, "30d")

    if bundle_7d is None:
        return {"symbol": symbol, "error": "models_not_trained", "as_of_date": as_of}

    # Build feature vector
    features_series = build_prediction_features(symbol, as_of)
    if features_series.empty:
        return {"symbol": symbol, "error": "no_features", "as_of_date": as_of}

    # Ensemble prediction — 7d (required)
    pred_class_7d, proba_7d, X_scaled_7d = _ensemble_predict(bundle_7d, features_series)
    direction_7d = DIRECTION_MAP[pred_class_7d]
    prob_7d = float(proba_7d[pred_class_7d])

    # Ensemble prediction — 30d (may not be trained). Without a 30d model a
    # neutral STABLE/0.5 placeholder is reported.
    direction_30d, prob_30d = "STABLE", 0.5
    if bundle_30d:
        pred_class_30d, proba_30d, _ = _ensemble_predict(bundle_30d, features_series)
        direction_30d = DIRECTION_MAP[pred_class_30d]
        prob_30d = float(proba_30d[pred_class_30d])

    # Price range: centred on the direction's expected return, widened by
    # +/- 1.5x the ATR proxy. Only an explicitly requested date filters the
    # price lookup; the default path reads the latest available price.
    current_price, atr_pct = _get_current_price(symbol, as_of_date)
    exp_ret = DIRECTION_EXPECTED_RETURN.get(direction_7d, 0.0) / 100
    price_range_low = round(current_price * (1 + exp_ret - 1.5 * atr_pct), 2)
    price_range_high = round(current_price * (1 + exp_ret + 1.5 * atr_pct), 2)

    # SHAP top signals (explains the 7d XGBoost model only)
    top_signals = _get_shap_top5(bundle_7d["xgb"], X_scaled_7d, bundle_7d["features"], pred_class_7d)

    conf_7d = _confirmed_confidence(prob_7d, direction_7d, features_series)
    conf_30d = _confirmed_confidence(prob_30d, direction_30d, features_series)

    model_warning = (
        "7d model accuracy is low for this symbol — use 30d forecast instead"
        if symbol in UNRELIABLE_7D else None
    )

    return {
        "symbol": symbol,
        "commodity_name": SYMBOL_NAMES.get(symbol, symbol),
        "as_of_date": as_of,
        "current_price": current_price,
        "forecast_7d": {
            "direction": direction_7d,
            "probability": round(prob_7d, 4),
            "price_range_low": price_range_low,
            "price_range_high": price_range_high,
            "confidence": conf_7d,
            "model_warning": model_warning,
        },
        "forecast_30d": {
            "direction": direction_30d,
            "probability": round(prob_30d, 4),
            "confidence": conf_30d,
        },
        "top_signals": top_signals,
    }


def predict_all(as_of_date: str | None = None) -> dict[str, dict]:
    """
    Generate forecasts for every symbol in `signals.price_features.ALL_SYMBOLS`
    and persist the successful ones to the database.

    A failure for one symbol is logged (with traceback) and recorded as an
    error dict; it does not abort the remaining symbols.

    Returns:
        Dict mapping symbol -> forecast dict (or error dict).
    """
    from signals.price_features import ALL_SYMBOLS

    results = {}
    for symbol in ALL_SYMBOLS:
        try:
            fc = predict(symbol, as_of_date)
            results[symbol] = fc
            if "error" not in fc:
                _save_forecast(fc)
        except Exception as exc:
            log.exception("predict %s failed: %s", symbol, exc)
            results[symbol] = {"symbol": symbol, "error": str(exc)}
    return results


def _save_forecast(fc: dict) -> None:
    """
    Persist the 7d forecast to `accuracy_log` for later accuracy tracking.

    actual_direction / was_correct are written as NULL to be filled in once the
    outcome is known. Best-effort: DB errors are logged at warning level and
    not re-raised.
    """
    conn = get_conn()
    try:
        conn.execute(
            """
            INSERT OR REPLACE INTO accuracy_log
                (date, symbol, forecast_direction, actual_direction, was_correct, confidence)
            VALUES (?, ?, ?, NULL, NULL, ?)
            """,
            [
                fc["as_of_date"],
                fc["symbol"],
                fc["forecast_7d"]["direction"],
                fc["forecast_7d"]["confidence"],
            ],
        )
    except Exception as exc:
        # Surface persistence failures instead of hiding them at debug level.
        log.warning("Forecast save error: %s", exc)
    finally:
        conn.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CommodiSense predictor")
    parser.add_argument("--symbol", default=None, help="Single symbol to predict")
    parser.add_argument("--all", action="store_true", help="Predict all symbols")
    parser.add_argument("--date", default=None, help="As-of date YYYY-MM-DD")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

    if args.all:
        results = predict_all(args.date)
        for sym, fc in results.items():
            if "error" in fc:
                # Report failures instead of silently omitting the symbol.
                print(f"{sym:<12} ERROR   {fc['error']}")
                continue
            d7 = fc["forecast_7d"]
            print(f"{sym:<12} {d7['direction']:<7} {d7['probability']:.0%} [{d7['confidence']}]")
    elif args.symbol:
        fc = predict(args.symbol, args.date)
        print(json.dumps(fc, indent=2, default=str))
    else:
        parser.print_help()