Spaces:
Running
Running
Sync from GitHub (tests passed)
Browse files
- app/ai_engine.py +3 -0
- app/inference.py +55 -7
app/ai_engine.py
CHANGED
|
@@ -688,8 +688,11 @@ def train_xgboost_model(
|
|
| 688 |
model.save_model(str(latest_path))
|
| 689 |
|
| 690 |
# Save metrics (including training symbols audit)
|
|
|
|
|
|
|
| 691 |
metrics = {
|
| 692 |
"target_symbol": target_symbol,
|
|
|
|
| 693 |
"trained_at": datetime.now(timezone.utc).isoformat(),
|
| 694 |
"train_samples": len(X_train),
|
| 695 |
"val_samples": len(X_val),
|
|
|
|
| 688 |
model.save_model(str(latest_path))
|
| 689 |
|
| 690 |
# Save metrics (including training symbols audit)
|
| 691 |
+
# TARGET_TYPE: "simple_return" means model predicts next-day return, not price
|
| 692 |
+
# This MUST be read by inference to correctly compute predicted_price
|
| 693 |
metrics = {
|
| 694 |
"target_symbol": target_symbol,
|
| 695 |
+
"target_type": "simple_return", # Model predicts: close(t+1)/close(t) - 1
|
| 696 |
"trained_at": datetime.now(timezone.utc).isoformat(),
|
| 697 |
"train_samples": len(X_train),
|
| 698 |
"val_samples": len(X_val),
|
app/inference.py
CHANGED
|
@@ -277,6 +277,8 @@ def generate_analysis_report(
|
|
| 277 |
Returns:
|
| 278 |
Dict with analysis data matching the API schema
|
| 279 |
"""
|
|
|
|
|
|
|
| 280 |
# Load model
|
| 281 |
model = load_model(target_symbol)
|
| 282 |
if model is None:
|
|
@@ -287,23 +289,42 @@ def generate_analysis_report(
|
|
| 287 |
metadata = load_model_metadata(target_symbol)
|
| 288 |
features = metadata.get("features", [])
|
| 289 |
importance = metadata.get("importance", [])
|
|
|
|
| 290 |
|
| 291 |
if not features:
|
| 292 |
logger.error("No feature list found for model")
|
| 293 |
return None
|
| 294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
# Get current price (for display - may be live yfinance or DB fallback)
|
| 296 |
current_price = get_current_price(session, target_symbol)
|
|
|
|
|
|
|
| 297 |
if current_price is None:
|
| 298 |
logger.error(f"No price data for {target_symbol}")
|
| 299 |
return None
|
| 300 |
|
| 301 |
-
# Get latest DB close price for prediction base
|
| 302 |
# Model predicts based on historical closes, not intraday prices
|
| 303 |
latest_bar = session.query(PriceBar).filter(
|
| 304 |
PriceBar.symbol == target_symbol
|
| 305 |
).order_by(PriceBar.date.desc()).first()
|
| 306 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
|
| 308 |
# Get current sentiment
|
| 309 |
current_sentiment = get_current_sentiment(session)
|
|
@@ -318,10 +339,31 @@ def generate_analysis_report(
|
|
| 318 |
|
| 319 |
# Make prediction
|
| 320 |
dmatrix = xgb.DMatrix(X, feature_names=features)
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
#
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
# Calculate confidence band
|
| 327 |
conf_lower, conf_upper = calculate_confidence_band(
|
|
@@ -354,18 +396,24 @@ def generate_analysis_report(
|
|
| 354 |
"description": desc,
|
| 355 |
})
|
| 356 |
|
| 357 |
-
# Build report
|
| 358 |
report = {
|
| 359 |
"symbol": target_symbol,
|
| 360 |
"current_price": round(current_price, 4),
|
|
|
|
|
|
|
| 361 |
"predicted_return": round(predicted_return, 6),
|
|
|
|
| 362 |
"predicted_price": round(predicted_price, 4),
|
|
|
|
|
|
|
| 363 |
"confidence_lower": round(conf_lower, 4),
|
| 364 |
"confidence_upper": round(conf_upper, 4),
|
| 365 |
"sentiment_index": round(current_sentiment, 4),
|
| 366 |
"sentiment_label": get_sentiment_label(current_sentiment),
|
| 367 |
"top_influencers": top_influencers,
|
| 368 |
"data_quality": data_quality,
|
|
|
|
| 369 |
"generated_at": datetime.now(timezone.utc).isoformat(),
|
| 370 |
}
|
| 371 |
|
|
|
|
| 277 |
Returns:
|
| 278 |
Dict with analysis data matching the API schema
|
| 279 |
"""
|
| 280 |
+
settings = get_settings()
|
| 281 |
+
|
| 282 |
# Load model
|
| 283 |
model = load_model(target_symbol)
|
| 284 |
if model is None:
|
|
|
|
| 289 |
metadata = load_model_metadata(target_symbol)
|
| 290 |
features = metadata.get("features", [])
|
| 291 |
importance = metadata.get("importance", [])
|
| 292 |
+
metrics = metadata.get("metrics", {})
|
| 293 |
|
| 294 |
if not features:
|
| 295 |
logger.error("No feature list found for model")
|
| 296 |
return None
|
| 297 |
|
| 298 |
+
# CRITICAL: Verify target_type is explicitly set
|
| 299 |
+
# Do NOT guess - misreading a return as a price (or vice versa) makes the forecast meaningless
|
| 300 |
+
target_type = metrics.get("target_type")
|
| 301 |
+
if target_type not in ("simple_return", "log_return", "price"):
|
| 302 |
+
logger.error(f"Invalid or missing target_type in model metadata: {target_type}")
|
| 303 |
+
logger.error("Model must be retrained with explicit target_type. Cannot generate forecast.")
|
| 304 |
+
return None
|
| 305 |
+
|
| 306 |
# Get current price (for display - may be live yfinance or DB fallback)
|
| 307 |
current_price = get_current_price(session, target_symbol)
|
| 308 |
+
price_source = "yfinance_live" # Default assumption
|
| 309 |
+
|
| 310 |
if current_price is None:
|
| 311 |
logger.error(f"No price data for {target_symbol}")
|
| 312 |
return None
|
| 313 |
|
| 314 |
+
# Get latest DB close price for prediction base (baseline_price)
|
| 315 |
# Model predicts based on historical closes, not intraday prices
|
| 316 |
latest_bar = session.query(PriceBar).filter(
|
| 317 |
PriceBar.symbol == target_symbol
|
| 318 |
).order_by(PriceBar.date.desc()).first()
|
| 319 |
+
|
| 320 |
+
if latest_bar:
|
| 321 |
+
baseline_price = latest_bar.close
|
| 322 |
+
baseline_price_date = latest_bar.date.strftime("%Y-%m-%d") if latest_bar.date else None
|
| 323 |
+
price_source = "yfinance_db_close"
|
| 324 |
+
else:
|
| 325 |
+
baseline_price = current_price
|
| 326 |
+
baseline_price_date = None
|
| 327 |
+
price_source = "yfinance_live_fallback"
|
| 328 |
|
| 329 |
# Get current sentiment
|
| 330 |
current_sentiment = get_current_sentiment(session)
|
|
|
|
| 339 |
|
| 340 |
# Make prediction
|
| 341 |
dmatrix = xgb.DMatrix(X, feature_names=features)
|
| 342 |
+
model_output = float(model.predict(dmatrix)[0])
|
| 343 |
+
|
| 344 |
+
# Compute predicted_return and predicted_price based on target_type
|
| 345 |
+
if target_type == "simple_return":
|
| 346 |
+
predicted_return = model_output
|
| 347 |
+
predicted_price = baseline_price * (1 + predicted_return)
|
| 348 |
+
elif target_type == "log_return":
|
| 349 |
+
import math
|
| 350 |
+
predicted_return = math.exp(model_output) - 1
|
| 351 |
+
predicted_price = baseline_price * math.exp(model_output)
|
| 352 |
+
elif target_type == "price":
|
| 353 |
+
predicted_price = model_output
|
| 354 |
+
predicted_return = (predicted_price / baseline_price) - 1 if baseline_price > 0 else 0
|
| 355 |
+
|
| 356 |
+
# Validate prediction (do not clamp by default - expose issues)
|
| 357 |
+
prediction_invalid = False
|
| 358 |
+
if predicted_return < -1.0:
|
| 359 |
+
logger.error(f"Invalid prediction: return {predicted_return:.4f} < -100%")
|
| 360 |
+
prediction_invalid = True
|
| 361 |
+
if predicted_price <= 0:
|
| 362 |
+
logger.error(f"Invalid prediction: price {predicted_price:.4f} <= 0")
|
| 363 |
+
prediction_invalid = True
|
| 364 |
+
|
| 365 |
+
if prediction_invalid:
|
| 366 |
+
return None
|
| 367 |
|
| 368 |
# Calculate confidence band
|
| 369 |
conf_lower, conf_upper = calculate_confidence_band(
|
|
|
|
| 396 |
"description": desc,
|
| 397 |
})
|
| 398 |
|
| 399 |
+
# Build report with explicit baseline_price and target_type
|
| 400 |
report = {
|
| 401 |
"symbol": target_symbol,
|
| 402 |
"current_price": round(current_price, 4),
|
| 403 |
+
"baseline_price": round(baseline_price, 4),
|
| 404 |
+
"baseline_price_date": baseline_price_date,
|
| 405 |
"predicted_return": round(predicted_return, 6),
|
| 406 |
+
"predicted_return_pct": round(predicted_return * 100, 2),
|
| 407 |
"predicted_price": round(predicted_price, 4),
|
| 408 |
+
"target_type": target_type,
|
| 409 |
+
"price_source": price_source,
|
| 410 |
"confidence_lower": round(conf_lower, 4),
|
| 411 |
"confidence_upper": round(conf_upper, 4),
|
| 412 |
"sentiment_index": round(current_sentiment, 4),
|
| 413 |
"sentiment_label": get_sentiment_label(current_sentiment),
|
| 414 |
"top_influencers": top_influencers,
|
| 415 |
"data_quality": data_quality,
|
| 416 |
+
"training_symbols_hash": settings.training_symbols_hash,
|
| 417 |
"generated_at": datetime.now(timezone.utc).isoformat(),
|
| 418 |
}
|
| 419 |
|