ifieryarrows commited on
Commit
782299f
·
verified ·
1 Parent(s): 67a99b9

Sync from GitHub (tests passed)

Browse files
Files changed (2) hide show
  1. app/ai_engine.py +3 -0
  2. app/inference.py +55 -7
app/ai_engine.py CHANGED
@@ -688,8 +688,11 @@ def train_xgboost_model(
688
  model.save_model(str(latest_path))
689
 
690
  # Save metrics (including training symbols audit)
 
 
691
  metrics = {
692
  "target_symbol": target_symbol,
 
693
  "trained_at": datetime.now(timezone.utc).isoformat(),
694
  "train_samples": len(X_train),
695
  "val_samples": len(X_val),
 
688
  model.save_model(str(latest_path))
689
 
690
  # Save metrics (including training symbols audit)
691
+ # TARGET_TYPE: "simple_return" means model predicts next-day return, not price
692
+ # This MUST be read by inference to correctly compute predicted_price
693
  metrics = {
694
  "target_symbol": target_symbol,
695
+ "target_type": "simple_return", # Model predicts: close(t+1)/close(t) - 1
696
  "trained_at": datetime.now(timezone.utc).isoformat(),
697
  "train_samples": len(X_train),
698
  "val_samples": len(X_val),
app/inference.py CHANGED
@@ -277,6 +277,8 @@ def generate_analysis_report(
277
  Returns:
278
  Dict with analysis data matching the API schema
279
  """
 
 
280
  # Load model
281
  model = load_model(target_symbol)
282
  if model is None:
@@ -287,23 +289,42 @@ def generate_analysis_report(
287
  metadata = load_model_metadata(target_symbol)
288
  features = metadata.get("features", [])
289
  importance = metadata.get("importance", [])
 
290
 
291
  if not features:
292
  logger.error("No feature list found for model")
293
  return None
294
 
 
 
 
 
 
 
 
 
295
  # Get current price (for display - may be live yfinance or DB fallback)
296
  current_price = get_current_price(session, target_symbol)
 
 
297
  if current_price is None:
298
  logger.error(f"No price data for {target_symbol}")
299
  return None
300
 
301
- # Get latest DB close price for prediction base
302
  # Model predicts based on historical closes, not intraday prices
303
  latest_bar = session.query(PriceBar).filter(
304
  PriceBar.symbol == target_symbol
305
  ).order_by(PriceBar.date.desc()).first()
306
- prediction_base = latest_bar.close if latest_bar else current_price
 
 
 
 
 
 
 
 
307
 
308
  # Get current sentiment
309
  current_sentiment = get_current_sentiment(session)
@@ -318,10 +339,31 @@ def generate_analysis_report(
318
 
319
  # Make prediction
320
  dmatrix = xgb.DMatrix(X, feature_names=features)
321
- predicted_return = float(model.predict(dmatrix)[0])
322
-
323
- # Calculate predicted price using DB close as base (model trains on closes)
324
- predicted_price = prediction_base * (1 + predicted_return)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
  # Calculate confidence band
327
  conf_lower, conf_upper = calculate_confidence_band(
@@ -354,18 +396,24 @@ def generate_analysis_report(
354
  "description": desc,
355
  })
356
 
357
- # Build report
358
  report = {
359
  "symbol": target_symbol,
360
  "current_price": round(current_price, 4),
 
 
361
  "predicted_return": round(predicted_return, 6),
 
362
  "predicted_price": round(predicted_price, 4),
 
 
363
  "confidence_lower": round(conf_lower, 4),
364
  "confidence_upper": round(conf_upper, 4),
365
  "sentiment_index": round(current_sentiment, 4),
366
  "sentiment_label": get_sentiment_label(current_sentiment),
367
  "top_influencers": top_influencers,
368
  "data_quality": data_quality,
 
369
  "generated_at": datetime.now(timezone.utc).isoformat(),
370
  }
371
 
 
277
  Returns:
278
  Dict with analysis data matching the API schema
279
  """
280
+ settings = get_settings()
281
+
282
  # Load model
283
  model = load_model(target_symbol)
284
  if model is None:
 
289
  metadata = load_model_metadata(target_symbol)
290
  features = metadata.get("features", [])
291
  importance = metadata.get("importance", [])
292
+ metrics = metadata.get("metrics", {})
293
 
294
  if not features:
295
  logger.error("No feature list found for model")
296
  return None
297
 
298
+ # CRITICAL: Verify target_type is explicitly set
299
+ # Do NOT guess - wrong interpretation inverts prediction meaning
300
+ target_type = metrics.get("target_type")
301
+ if target_type not in ("simple_return", "log_return", "price"):
302
+ logger.error(f"Invalid or missing target_type in model metadata: {target_type}")
303
+ logger.error("Model must be retrained with explicit target_type. Cannot generate forecast.")
304
+ return None
305
+
306
  # Get current price (for display - may be live yfinance or DB fallback)
307
  current_price = get_current_price(session, target_symbol)
308
+ price_source = "yfinance_live" # Default assumption
309
+
310
  if current_price is None:
311
  logger.error(f"No price data for {target_symbol}")
312
  return None
313
 
314
+ # Get latest DB close price for prediction base (baseline_price)
315
  # Model predicts based on historical closes, not intraday prices
316
  latest_bar = session.query(PriceBar).filter(
317
  PriceBar.symbol == target_symbol
318
  ).order_by(PriceBar.date.desc()).first()
319
+
320
+ if latest_bar:
321
+ baseline_price = latest_bar.close
322
+ baseline_price_date = latest_bar.date.strftime("%Y-%m-%d") if latest_bar.date else None
323
+ price_source = "yfinance_db_close"
324
+ else:
325
+ baseline_price = current_price
326
+ baseline_price_date = None
327
+ price_source = "yfinance_live_fallback"
328
 
329
  # Get current sentiment
330
  current_sentiment = get_current_sentiment(session)
 
339
 
340
  # Make prediction
341
  dmatrix = xgb.DMatrix(X, feature_names=features)
342
+ model_output = float(model.predict(dmatrix)[0])
343
+
344
+ # Compute predicted_return and predicted_price based on target_type
345
+ if target_type == "simple_return":
346
+ predicted_return = model_output
347
+ predicted_price = baseline_price * (1 + predicted_return)
348
+ elif target_type == "log_return":
349
+ import math
350
+ predicted_return = math.exp(model_output) - 1
351
+ predicted_price = baseline_price * math.exp(model_output)
352
+ elif target_type == "price":
353
+ predicted_price = model_output
354
+ predicted_return = (predicted_price / baseline_price) - 1 if baseline_price > 0 else 0
355
+
356
+ # Validate prediction (do not clamp by default - expose issues)
357
+ prediction_invalid = False
358
+ if predicted_return < -1.0:
359
+ logger.error(f"Invalid prediction: return {predicted_return:.4f} < -100%")
360
+ prediction_invalid = True
361
+ if predicted_price <= 0:
362
+ logger.error(f"Invalid prediction: price {predicted_price:.4f} <= 0")
363
+ prediction_invalid = True
364
+
365
+ if prediction_invalid:
366
+ return None
367
 
368
  # Calculate confidence band
369
  conf_lower, conf_upper = calculate_confidence_band(
 
396
  "description": desc,
397
  })
398
 
399
+ # Build report with explicit baseline_price and target_type
400
  report = {
401
  "symbol": target_symbol,
402
  "current_price": round(current_price, 4),
403
+ "baseline_price": round(baseline_price, 4),
404
+ "baseline_price_date": baseline_price_date,
405
  "predicted_return": round(predicted_return, 6),
406
+ "predicted_return_pct": round(predicted_return * 100, 2),
407
  "predicted_price": round(predicted_price, 4),
408
+ "target_type": target_type,
409
+ "price_source": price_source,
410
  "confidence_lower": round(conf_lower, 4),
411
  "confidence_upper": round(conf_upper, 4),
412
  "sentiment_index": round(current_sentiment, 4),
413
  "sentiment_label": get_sentiment_label(current_sentiment),
414
  "top_influencers": top_influencers,
415
  "data_quality": data_quality,
416
+ "training_symbols_hash": settings.training_symbols_hash,
417
  "generated_at": datetime.now(timezone.utc).isoformat(),
418
  }
419