# copper-mind / app/inference.py
# Synced from GitHub (commit aa191f1, tests passed)
"""
Inference module: Live prediction and snapshot generation.
Handles:
- Loading trained model
- Running inference on current data
- Generating analysis report
- Saving snapshots for caching
"""
import json
import logging
import re
# Suppress httpx request logging to prevent API keys in URLs from appearing in logs
logging.getLogger("httpx").setLevel(logging.WARNING)
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
import xgboost as xgb
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.db import SessionLocal
from app.models import (
PriceBar,
DailySentiment,
DailySentimentV2,
AnalysisSnapshot,
NewsArticle,
NewsSentiment,
NewsRaw,
NewsProcessed,
NewsSentimentV2,
)
from app.settings import get_settings
from app.features import (
load_price_data,
load_sentiment_data,
generate_symbol_features,
align_to_target_calendar,
get_feature_descriptions,
)
from app.ai_engine import load_model, load_model_metadata
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# Feature Alignment Helpers (Train/Inference compatibility)
# =============================================================================
def _sanitize_symbol(sym: str) -> str:
"""Convert symbol to safe column prefix (HG=F -> HG_F)."""
return re.sub(r"[^A-Za-z0-9]+", "_", sym).strip("_")
def _rename_sanitized_to_raw(df: pd.DataFrame, symbols: list[str]) -> pd.DataFrame:
"""
Rename sanitized column prefixes back to raw symbol names.
Example: HG_F_ret1 -> HG=F_ret1
"""
rename_map = {}
cols = list(df.columns)
for sym in symbols:
sanitized = _sanitize_symbol(sym)
if sanitized == sym:
continue # No change needed
sanitized_prefix = sanitized + "_"
raw_prefix = sym + "_"
for col in cols:
if col.startswith(sanitized_prefix):
new_name = raw_prefix + col[len(sanitized_prefix):]
rename_map[col] = new_name
if rename_map:
logger.debug(f"Renaming {len(rename_map)} columns from sanitized to raw")
return df.rename(columns=rename_map)
return df
def _align_features_to_model(df: pd.DataFrame, expected_features: list[str]) -> pd.DataFrame:
"""
Align DataFrame columns to match model's expected feature names.
- Missing features are filled with 0.0
- Extra features are dropped
- Column order matches expected_features
"""
if not expected_features:
logger.warning("No expected features provided; skipping alignment")
return df
present = set(df.columns)
expected = set(expected_features)
missing = expected - present
extra = present - expected
if missing or extra:
logger.info(
f"Feature alignment: expected={len(expected_features)} present={len(df.columns)} "
f"missing={len(missing)} extra={len(extra)}"
)
if missing:
logger.debug(f"Missing features (first 10): {list(missing)[:10]}")
if extra:
logger.debug(f"Extra features (first 10): {list(extra)[:10]}")
return df.reindex(columns=expected_features, fill_value=0.0)
def get_current_price(session: Session, symbol: str) -> Optional[float]:
    """
    Get the current price for a symbol.
    Priority:
    1. Twelve Data API (most reliable, no rate limit issues)
    2. yfinance live data (15-min delayed)
    3. Database fallback

    Returns:
        The latest price as a float, or None when every source failed.
    """
    # Local imports keep optional network clients off the module import path.
    import httpx
    import yfinance as yf
    from app.settings import get_settings
    settings = get_settings()
    # Try Twelve Data first (for XCU/USD copper)
    # NOTE(review): this branch always quotes "XCU/USD" regardless of the
    # `symbol` argument — presumably intentional because the app tracks
    # copper only; confirm before reusing this function for other symbols.
    if settings.twelvedata_api_key:
        try:
            with httpx.Client(timeout=10.0) as client:
                response = client.get(
                    "https://api.twelvedata.com/price",
                    params={
                        "symbol": "XCU/USD",
                        "apikey": settings.twelvedata_api_key,
                    }
                )
                if response.status_code == 200:
                    data = response.json()
                    price = data.get("price")
                    # Falsy check also skips empty-string / zero responses.
                    if price:
                        logger.info(f"Using Twelve Data price for copper: ${float(price):.4f}")
                        return float(price)
        except Exception as e:
            # Mask the key: httpx error text can embed the full request URL.
            from app.settings import mask_api_key
            logger.debug(f"Twelve Data price fetch failed: {mask_api_key(str(e))}")
    # Try yfinance as fallback
    try:
        ticker = yf.Ticker(symbol)
        info = ticker.info
        # Either field may be populated depending on the instrument type.
        live_price = info.get('regularMarketPrice') or info.get('currentPrice')
        if live_price is not None:
            logger.info(f"Using yfinance price for {symbol}: ${live_price:.4f}")
            return float(live_price)
    except Exception as e:
        logger.debug(f"yfinance price fetch failed for {symbol}: {e}")
    # Final fallback to database
    latest = session.query(PriceBar).filter(
        PriceBar.symbol == symbol
    ).order_by(PriceBar.date.desc()).first()
    if latest:
        logger.info(f"Using DB price for {symbol}: ${latest.close:.4f}")
        return latest.close
    return None
def get_current_sentiment(session: Session) -> Optional[float]:
    """Return the most recent daily sentiment index, or None if none exists."""
    settings = get_settings()
    source = str(getattr(settings, "scoring_source", "news_articles")).strip().lower()
    if source == "news_processed":
        newest_v2 = (
            session.query(DailySentimentV2)
            .order_by(DailySentimentV2.date.desc())
            .first()
        )
        if newest_v2 is not None:
            return newest_v2.sentiment_index
        logger.warning("No rows in daily_sentiments_v2; falling back to legacy daily_sentiments")
    newest = (
        session.query(DailySentiment)
        .order_by(DailySentiment.date.desc())
        .first()
    )
    if newest is None:
        return None
    return newest.sentiment_index
def get_data_quality_stats(
    session: Session,
    symbol: str,
    days: int = 7
) -> dict:
    """Get data quality statistics for the report.

    Args:
        session: Open database session.
        symbol: Price symbol whose bar coverage is measured.
        days: Lookback window in calendar days.

    Returns:
        Dict with news/scored article counts, missing trading days, and
        price-bar coverage percentage over the window.
    """
    settings = get_settings()
    source = str(getattr(settings, "scoring_source", "news_articles")).strip().lower()
    cutoff = datetime.now(timezone.utc) - timedelta(days=days)
    if source == "news_processed":
        # V2 pipeline: count processed articles, and sentiment scores for the
        # configured horizon only (scores exist per horizon_days value).
        horizon_days = max(1, int(getattr(settings, "sentiment_horizon_days", 5)))
        news_count = session.query(func.count(NewsProcessed.id)).join(
            NewsRaw, NewsProcessed.raw_id == NewsRaw.id
        ).filter(
            NewsRaw.published_at >= cutoff
        ).scalar()
        scored_count = session.query(func.count(NewsSentimentV2.id)).join(
            NewsProcessed,
            NewsSentimentV2.news_processed_id == NewsProcessed.id
        ).join(
            NewsRaw,
            NewsProcessed.raw_id == NewsRaw.id
        ).filter(
            NewsRaw.published_at >= cutoff,
            NewsSentimentV2.horizon_days == horizon_days,
        ).scalar()
    else:
        # Legacy article-level stats
        news_count = session.query(func.count(NewsArticle.id)).filter(
            NewsArticle.published_at >= cutoff
        ).scalar()
        scored_count = session.query(func.count(NewsSentiment.id)).join(
            NewsArticle,
            NewsSentiment.news_article_id == NewsArticle.id
        ).filter(NewsArticle.published_at >= cutoff).scalar()
    # Price bar coverage
    expected_days = days
    actual_bars = session.query(func.count(PriceBar.id)).filter(
        PriceBar.symbol == symbol,
        PriceBar.date >= cutoff
    ).scalar()
    # Account for weekends (roughly 5/7 of days should have bars)
    expected_trading_days = int(expected_days * 5 / 7)
    coverage_pct = min(100, int(actual_bars / max(1, expected_trading_days) * 100))
    # Missing days calculation
    missing_days = max(0, expected_trading_days - actual_bars)
    return {
        # Keys keep the "_7d" suffix even when `days` differs from 7; callers
        # depend on these exact names.
        "news_count_7d": news_count,
        "scored_count_7d": scored_count,
        "missing_days": missing_days,
        "coverage_pct": coverage_pct,
    }
def calculate_confidence_band(
    session: Session,
    symbol: str,
    predicted_price: float,
    lookback_days: int = 30
) -> tuple[float, float]:
    """
    Derive a confidence band around the predicted price from recent volatility.

    The band half-width is one standard deviation of daily returns over the
    lookback window; with fewer than 10 price rows a flat 5% band is used.

    Returns:
        Tuple of (lower_bound, upper_bound)
    """
    since = datetime.now(timezone.utc) - timedelta(days=lookback_days)
    rows = (
        session.query(PriceBar.close)
        .filter(PriceBar.symbol == symbol, PriceBar.date >= since)
        .order_by(PriceBar.date.asc())
        .all()
    )
    if len(rows) < 10:
        # Too little history for a meaningful volatility estimate.
        return predicted_price * 0.95, predicted_price * 1.05
    daily_returns = pd.Series([row[0] for row in rows]).pct_change().dropna()
    band = daily_returns.std()
    return predicted_price * (1 - band), predicted_price * (1 + band)
def get_sentiment_label(sentiment_index: float) -> str:
    """Map a sentiment index to a human-readable label.

    Values above 0.1 are "Bullish", below -0.1 are "Bearish", and anything
    in between (inclusive of the thresholds) is "Neutral".
    """
    if sentiment_index > 0.1:
        return "Bullish"
    if sentiment_index < -0.1:
        return "Bearish"
    return "Neutral"
def _sign(value: float) -> int:
"""Return numeric sign (-1, 0, 1)."""
if value > 0:
return 1
if value < 0:
return -1
return 0
def _clamp(value: float, lower: float, upper: float) -> float:
"""Clamp value to [lower, upper]."""
return max(lower, min(upper, value))
def _apply_sentiment_adjustment(
    raw_predicted_return: float,
    sentiment_index: float,
    news_count_7d: int,
) -> tuple[float, float, bool, bool]:
    """
    Apply aggressive-but-capped sentiment multiplier to raw predicted return.

    Args:
        raw_predicted_return: The model's unadjusted predicted return.
        sentiment_index: Current daily sentiment index (its sign sets direction).
        news_count_7d: Article volume over the last 7 days; scales how much
            influence the sentiment signal gets.

    Returns:
        (adjusted_return, multiplier, adjustment_applied, capped)
    """
    settings = get_settings()
    # Reference scales from settings, floored so the divisions below are safe.
    news_ref = max(1.0, float(settings.inference_sentiment_news_ref))
    power_ref = max(1e-6, float(settings.inference_sentiment_power_ref))
    news_floor = max(1, int(round(news_ref * 0.4)))  # default: 12 when ref is 30
    # Intensity in [0, 1]: saturates once news volume reaches news_ref.
    news_intensity = min(1.0, max(0.0, float(news_count_7d) / news_ref))
    # Strength in [0, 1): tanh squashes extreme sentiment magnitudes.
    sentiment_power = float(np.tanh(abs(float(sentiment_index)) / power_ref))
    raw_sign = _sign(float(raw_predicted_return))
    sentiment_sign = _sign(float(sentiment_index))
    # Agreeing signs (or a zero raw signal) amplify; disagreement dampens.
    direction = 1.0 if raw_sign == 0 or raw_sign == sentiment_sign else -1.0
    multiplier = 1.0 + (direction * sentiment_power * news_intensity)
    multiplier = _clamp(
        multiplier,
        float(settings.inference_sentiment_multiplier_min),
        float(settings.inference_sentiment_multiplier_max),
    )
    # When the model signal is negligible but sentiment is strong and backed
    # by enough news volume, use a small sentiment-signed floor return
    # instead of scaling near-zero noise.
    use_tiny_floor = (
        abs(float(raw_predicted_return)) < float(settings.inference_tiny_signal_threshold)
        and abs(float(sentiment_index)) >= power_ref
        and int(news_count_7d) >= news_floor
    )
    if use_tiny_floor:
        adjusted_return = float(sentiment_sign) * float(settings.inference_tiny_signal_floor)
    else:
        adjusted_return = float(raw_predicted_return) * multiplier
    # Hard symmetric cap on the final adjusted return.
    cap = abs(float(settings.inference_return_cap))
    capped = False
    if adjusted_return > cap:
        adjusted_return = cap
        capped = True
    elif adjusted_return < -cap:
        adjusted_return = -cap
        capped = True
    # "Applied" means any of: tiny-floor override, cap hit, or a multiplier
    # that measurably differs from 1.0.
    adjustment_applied = use_tiny_floor or capped or abs(multiplier - 1.0) > 1e-9
    return adjusted_return, multiplier, adjustment_applied, capped
def build_features_for_prediction(
    session: Session,
    target_symbol: str,
    feature_names: list[str]
) -> Optional[pd.DataFrame]:
    """
    Build feature vector for live prediction.
    Uses the most recent available data.
    MUST use training_symbols to match the model's training data.
    Includes robust alignment to handle:
    - Sanitized vs raw symbol name differences (HG_F vs HG=F)
    - Missing/extra features between training and inference

    Args:
        session: Open database session.
        target_symbol: Symbol the model predicts for.
        feature_names: The model's expected feature columns, in order.

    Returns:
        Single-row float DataFrame aligned to feature_names, or None when no
        price data exists for the target symbol.
    """
    settings = get_settings()
    # Use training_symbols (not symbols_list) to match model training
    symbols = settings.training_symbols
    # Load recent data (need enough for feature calculation)
    end_date = datetime.now(timezone.utc)
    start_date = end_date - timedelta(days=60)  # Need history for indicators
    # Load target
    target_df = load_price_data(session, target_symbol, start_date, end_date)
    if target_df.empty:
        logger.error(f"No price data for {target_symbol}")
        return None
    # Load other symbols; symbols with no data in the window are skipped.
    other_dfs = {}
    for symbol in symbols:
        if symbol != target_symbol:
            df = load_price_data(session, symbol, start_date, end_date)
            if not df.empty:
                other_dfs[symbol] = df
    # Align auxiliary symbols onto the target's trading calendar
    # (forward-filling gaps up to 3 days).
    aligned = align_to_target_calendar(target_df, other_dfs, max_ffill=3)
    # Generate features for the target first, then left-join each aligned
    # symbol's features onto the same index.
    all_features = generate_symbol_features(target_df, target_symbol)
    for symbol, df in aligned.items():
        if not df.empty:
            symbol_features = generate_symbol_features(df, symbol)
            all_features = all_features.join(symbol_features, how="left")
    # Add sentiment (use concat to avoid fragmentation warning)
    sentiment_df = load_sentiment_data(session, start_date, end_date)
    sentiment_parts = []
    if not sentiment_df.empty:
        # Re-index sentiment to the target calendar; gaps forward-fill up to
        # 3 days, remaining holes get the configured missing-fill value.
        sentiment_aligned = sentiment_df.reindex(target_df.index).ffill(limit=3)
        sentiment_parts.append(
            sentiment_aligned["sentiment_index"].fillna(settings.sentiment_missing_fill).rename("sentiment__index")
        )
        sentiment_parts.append(
            sentiment_aligned["news_count"].fillna(0).rename("sentiment__news_count")
        )
    else:
        # No sentiment rows at all: emit constant placeholder columns so the
        # feature schema stays stable.
        sentiment_parts.append(
            pd.Series(settings.sentiment_missing_fill, index=all_features.index, name="sentiment__index")
        )
        sentiment_parts.append(
            pd.Series(0, index=all_features.index, name="sentiment__news_count")
        )
    all_features = pd.concat([all_features] + sentiment_parts, axis=1)
    # Get latest row (double brackets keep it a one-row DataFrame).
    latest = all_features.iloc[[-1]].copy()
    # STEP 1: Rename sanitized prefixes to raw symbol names if needed
    # This handles cases where feature generation used sanitized names (HG_F)
    # but model was trained with raw names (HG=F)
    all_symbols = [target_symbol] + list(symbols)
    latest = _rename_sanitized_to_raw(latest, all_symbols)
    # STEP 2: Align to model's expected features
    # - Missing features get 0.0 (same as missing data handling in training)
    # - Extra features are dropped
    # - Column order matches expected feature_names
    latest = _align_features_to_model(latest, feature_names)
    # Ensure float dtype for XGBoost
    latest = latest.astype(float)
    return latest
def generate_analysis_report(
    session: Session,
    target_symbol: str = "HG=F"
) -> Optional[dict]:
    """
    Generate a full analysis report.

    Loads the trained model, builds a live feature row, predicts the next
    return, applies the sentiment adjustment, and assembles the report dict.
    Returns None on any unrecoverable condition (missing model/metadata,
    missing price data, invalid prediction).

    Returns:
        Dict with analysis data matching the API schema
    """
    settings = get_settings()
    # Load model
    model = load_model(target_symbol)
    if model is None:
        logger.error(f"No model found for {target_symbol}")
        return None
    # Load metadata
    metadata = load_model_metadata(target_symbol)
    features = metadata.get("features", [])
    importance = metadata.get("importance", [])
    metrics = metadata.get("metrics", {})
    if not features:
        logger.error("No feature list found for model")
        return None
    # CRITICAL: Verify target_type is explicitly set
    # Do NOT guess - wrong interpretation inverts prediction meaning
    target_type = metrics.get("target_type")
    if target_type not in ("simple_return", "log_return", "price"):
        logger.error(f"Invalid or missing target_type in model metadata: {target_type}")
        logger.error("Model must be retrained with explicit target_type. Cannot generate forecast.")
        return None
    # Get current price (for display - may be live yfinance or DB fallback)
    current_price = get_current_price(session, target_symbol)
    price_source = "yfinance_live"  # Default assumption
    if current_price is None:
        logger.error(f"No price data for {target_symbol}")
        return None
    # Get latest DB close price for prediction base (baseline_price)
    # Model predicts based on historical closes, not intraday prices
    latest_bar = session.query(PriceBar).filter(
        PriceBar.symbol == target_symbol
    ).order_by(PriceBar.date.desc()).first()
    if latest_bar:
        baseline_price = latest_bar.close
        baseline_price_date = latest_bar.date.strftime("%Y-%m-%d") if latest_bar.date else None
        price_source = "yfinance_db_close"
    else:
        # No DB bars at all: fall back to the live price as the baseline.
        baseline_price = current_price
        baseline_price_date = None
        price_source = "yfinance_live_fallback"
    # Get current sentiment; treat "no data" as neutral (0.0).
    current_sentiment = get_current_sentiment(session)
    if current_sentiment is None:
        current_sentiment = 0.0
    # Build features for prediction
    X = build_features_for_prediction(session, target_symbol, features)
    if X is None or X.empty:
        logger.error("Could not build features for prediction")
        return None
    # Make prediction (single-row DMatrix -> single scalar output).
    dmatrix = xgb.DMatrix(X, feature_names=features)
    model_output = float(model.predict(dmatrix)[0])
    logger.info(f"Model prediction: raw_output={model_output:.6f}, target_type={target_type}")
    # Compute raw predicted return based on target_type
    if target_type == "simple_return":
        raw_predicted_return = model_output
    elif target_type == "log_return":
        import math
        # Convert log return to simple return: r = exp(log_r) - 1.
        raw_predicted_return = math.exp(model_output) - 1
    elif target_type == "price":
        # Model predicted an absolute price; express it as a return vs baseline.
        raw_predicted_return = (model_output / baseline_price) - 1 if baseline_price > 0 else 0
    else:
        # Unreachable: target_type was validated above.
        raw_predicted_return = 0.0
    # Data quality feeds sentiment multiplier intensity.
    data_quality = get_data_quality_stats(session, target_symbol)
    news_count_7d = int(data_quality.get("news_count_7d") or 0)
    predicted_return, sentiment_multiplier, adjustment_applied, predicted_return_capped = (
        _apply_sentiment_adjustment(
            raw_predicted_return=float(raw_predicted_return),
            sentiment_index=float(current_sentiment),
            news_count_7d=news_count_7d,
        )
    )
    logger.info(
        "Sentiment adjustment: raw=%.6f adjusted=%.6f multiplier=%.4f applied=%s capped=%s news_count_7d=%s sentiment=%.4f",
        raw_predicted_return,
        predicted_return,
        sentiment_multiplier,
        adjustment_applied,
        predicted_return_capped,
        news_count_7d,
        current_sentiment,
    )
    predicted_price = baseline_price * (1 + predicted_return)
    # Validate prediction after sentiment adjustment/cap.
    prediction_invalid = False
    if predicted_return < -1.0:
        logger.error(f"Invalid prediction: return {predicted_return:.4f} < -100%")
        prediction_invalid = True
    if predicted_price <= 0:
        logger.error(f"Invalid prediction: price {predicted_price:.4f} <= 0")
        prediction_invalid = True
    if prediction_invalid:
        return None
    # Calculate confidence band
    conf_lower, conf_upper = calculate_confidence_band(
        session, target_symbol, predicted_price
    )
    # Build influencer descriptions for the top-10 features by importance.
    descriptions = get_feature_descriptions()
    top_influencers = []
    for item in importance[:10]:
        feat = item["feature"]
        # Try to find description via substring match against known keys.
        desc = None
        for key, value in descriptions.items():
            if key in feat:
                desc = value
                break
        if desc is None:
            # Build from feature name
            # NOTE(review): the second .replace(" ", " ") is a no-op as
            # written (space -> space); possibly meant to collapse double
            # spaces — confirm intent.
            desc = feat.replace("_", " ").replace(" ", " ").title()
        top_influencers.append({
            "feature": feat,
            "importance": item["importance"],
            "description": desc,
        })
    # Build report with explicit baseline_price and target_type
    report = {
        "symbol": target_symbol,
        "current_price": round(current_price, 4),
        "baseline_price": round(baseline_price, 4),
        "baseline_price_date": baseline_price_date,
        "predicted_return": round(predicted_return, 6),
        "raw_predicted_return": round(raw_predicted_return, 6),
        "sentiment_multiplier": round(sentiment_multiplier, 4),
        "sentiment_adjustment_applied": bool(adjustment_applied),
        "predicted_return_capped": bool(predicted_return_capped),
        "predicted_return_pct": round(predicted_return * 100, 2),
        "predicted_price": round(predicted_price, 4),
        "target_type": target_type,
        "price_source": price_source,
        "confidence_lower": round(conf_lower, 4),
        "confidence_upper": round(conf_upper, 4),
        "sentiment_index": round(current_sentiment, 4),
        "sentiment_label": get_sentiment_label(current_sentiment),
        "top_influencers": top_influencers,
        "data_quality": data_quality,
        "training_symbols_hash": settings.training_symbols_hash,
        "generated_at": datetime.now(timezone.utc).isoformat(),
    }
    return report
def save_analysis_snapshot(
    session: Session,
    report: dict,
    symbol: str
) -> AnalysisSnapshot:
    """Persist the report as today's snapshot, updating any existing one."""
    now = datetime.now(timezone.utc)
    # One snapshot per symbol per calendar day: anything since midnight UTC
    # counts as "today's" snapshot and gets overwritten in place.
    midnight = now.replace(hour=0, minute=0, second=0, microsecond=0)
    snapshot = (
        session.query(AnalysisSnapshot)
        .filter(
            AnalysisSnapshot.symbol == symbol,
            AnalysisSnapshot.as_of_date >= midnight,
        )
        .first()
    )
    if snapshot is not None:
        snapshot.report_json = report
        snapshot.generated_at = now
    else:
        snapshot = AnalysisSnapshot(
            symbol=symbol,
            as_of_date=now,
            report_json=report,
            generated_at=now,
        )
        session.add(snapshot)
    session.commit()
    logger.info(f"Snapshot saved for {symbol}")
    return snapshot
def get_latest_snapshot(
    session: Session,
    symbol: str,
    max_age_minutes: int = 30
) -> Optional[dict]:
    """
    Return the newest snapshot report for symbol if it was generated within
    the last max_age_minutes; otherwise None.
    """
    freshness_cutoff = datetime.now(timezone.utc) - timedelta(minutes=max_age_minutes)
    row = (
        session.query(AnalysisSnapshot)
        .filter(
            AnalysisSnapshot.symbol == symbol,
            AnalysisSnapshot.generated_at >= freshness_cutoff,
        )
        .order_by(AnalysisSnapshot.generated_at.desc())
        .first()
    )
    return row.report_json if row else None
def get_any_snapshot(
    session: Session,
    symbol: str
) -> Optional[dict]:
    """Return the most recent snapshot report for symbol, with no age limit."""
    row = (
        session.query(AnalysisSnapshot)
        .filter(AnalysisSnapshot.symbol == symbol)
        .order_by(AnalysisSnapshot.generated_at.desc())
        .first()
    )
    return row.report_json if row else None