ifieryarrows committed on
Commit
e998ea8
·
verified ·
1 Parent(s): f4e8f50

Sync from GitHub

Browse files
Files changed (2) hide show
  1. app/ai_engine.py +99 -1
  2. app/models.py +29 -0
app/ai_engine.py CHANGED
@@ -449,6 +449,20 @@ def train_xgboost_model(
449
  desc = descriptions.get(feat, feat)
450
  logger.info(f" {feat}: {imp:.4f} ({desc})")
451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
  return {
453
  "model_path": str(model_path),
454
  "metrics": metrics,
@@ -474,8 +488,92 @@ def load_model(target_symbol: str = "HG=F") -> Optional[xgb.Booster]:
474
  return model
475
 
476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  def load_model_metadata(target_symbol: str = "HG=F") -> dict:
478
- """Load metrics and feature info for a model."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  settings = get_settings()
480
  model_dir = Path(settings.model_dir)
481
 
 
449
  desc = descriptions.get(feat, feat)
450
  logger.info(f" {feat}: {imp:.4f} ({desc})")
451
 
452
+ # Save metadata to database for persistence across HF Space restarts
453
+ try:
454
+ from app.db import SessionLocal
455
+ with SessionLocal() as session:
456
+ save_model_metadata_to_db(
457
+ session=session,
458
+ symbol=target_symbol,
459
+ importance=normalized_importance,
460
+ features=feature_names,
461
+ metrics=metrics,
462
+ )
463
+ except Exception as e:
464
+ logger.warning(f"Could not save model metadata to DB: {e}")
465
+
466
  return {
467
  "model_path": str(model_path),
468
  "metrics": metrics,
 
488
  return model
489
 
490
 
491
def save_model_metadata_to_db(
    session,
    symbol: str,
    importance: list,
    features: list,
    metrics: dict
) -> None:
    """
    Save model metadata to database for persistence across restarts.
    Called after train_model=True pipeline runs.

    Upserts the ``ModelMetadata`` row for ``symbol``: an existing row is
    overwritten in place, otherwise a new row is added. The session is
    committed on success and rolled back if the query or commit fails.

    Args:
        session: Database session providing ``query``, ``add``, ``commit``
            and ``rollback``.
        symbol: Symbol whose metadata row is written.
        importance: Feature importance data; stored as a JSON string.
        features: Feature names; stored as a JSON string.
        metrics: Training metrics; stored as a JSON string.

    Raises:
        TypeError: If one of the payloads is not JSON-serializable.
        Exception: Any error raised by the session is re-raised after the
            transaction has been rolled back.
    """
    from .models import ModelMetadata
    from datetime import datetime, timezone

    # Serialize before touching the session so a bad payload cannot leave a
    # half-updated row pending in the transaction.
    payload = {
        "importance_json": json.dumps(importance),
        "features_json": json.dumps(features),
        "metrics_json": json.dumps(metrics),
    }

    # The trained_at column is declared timezone-aware, so write an aware UTC
    # timestamp on both the insert and the update path instead of a naive one.
    trained_at = datetime.now(timezone.utc)

    try:
        # Try to find existing record
        existing = session.query(ModelMetadata).filter(ModelMetadata.symbol == symbol).first()

        if existing:
            for column, value in payload.items():
                setattr(existing, column, value)
            existing.trained_at = trained_at
        else:
            session.add(ModelMetadata(symbol=symbol, trained_at=trained_at, **payload))

        session.commit()
    except Exception:
        # Leave the session usable for the caller and surface the failure.
        session.rollback()
        raise

    # Log only once the write has actually been committed.
    if existing:
        logger.info(f"Updated model metadata in DB for {symbol}")
    else:
        logger.info(f"Saved new model metadata to DB for {symbol}")
525
+
526
+
527
def load_model_metadata_from_db(session, symbol: str) -> dict:
    """
    Read the stored metadata row for ``symbol`` and decode its JSON columns.

    Returns a dict with the keys ``metrics``, ``features`` and ``importance``.
    A key stays ``None`` when no row exists, when its column is empty, or when
    decoding stopped at an earlier column because of invalid JSON.
    """
    from .models import ModelMetadata

    result = {"metrics": None, "features": None, "importance": None}

    row = session.query(ModelMetadata).filter(ModelMetadata.symbol == symbol).first()
    if not row:
        return result

    # (result key, column attribute) pairs, decoded in this order; a decode
    # error aborts the remaining columns but keeps the ones already parsed.
    decode_plan = (
        ("importance", "importance_json"),
        ("features", "features_json"),
        ("metrics", "metrics_json"),
    )

    try:
        for key, column in decode_plan:
            raw = getattr(row, column)
            if raw:
                result[key] = json.loads(raw)
        logger.info(f"Loaded model metadata from DB for {symbol}")
    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse model metadata from DB: {e}")

    return result
555
+
556
+
557
  def load_model_metadata(target_symbol: str = "HG=F") -> dict:
558
+ """
559
+ Load metrics and feature info for a model.
560
+
561
+ Priority:
562
+ 1. Database (survives HF Space restarts)
563
+ 2. Local JSON files (fallback for development)
564
+ """
565
+ from app.db import SessionLocal
566
+
567
+ # Try database first
568
+ try:
569
+ with SessionLocal() as session:
570
+ db_metadata = load_model_metadata_from_db(session, target_symbol)
571
+ if db_metadata.get("importance") and db_metadata.get("features"):
572
+ return db_metadata
573
+ except Exception as e:
574
+ logger.debug(f"Could not load metadata from DB: {e}")
575
+
576
+ # Fallback to local files
577
  settings = get_settings()
578
  model_dir = Path(settings.model_dir)
579
 
app/models.py CHANGED
@@ -230,3 +230,32 @@ class AICommentary(Base):
230
 
231
  def __repr__(self):
232
  return f"<AICommentary(symbol={self.symbol}, generated_at={self.generated_at})>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  def __repr__(self):
232
  return f"<AICommentary(symbol={self.symbol}, generated_at={self.generated_at})>"
233
+
234
+
235
class ModelMetadata(Base):
    """
    Database-backed copy of a trained XGBoost model's metadata.

    Holds the feature importance, the feature list and the training metrics
    as JSON text so they survive HF Space restarts. The ``symbol`` column is
    unique, i.e. there is one row per symbol, refreshed after each model
    training (train_model=True).
    """

    __tablename__ = "model_metadata"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)

    # Symbol the metadata belongs to; unique and indexed.
    symbol = Column(String(20), nullable=False, unique=True, index=True)

    # JSON text: feature importance, [{feature, importance}, ...]
    importance_json = Column(Text, nullable=True)

    # JSON text: list of feature names, ["feature1", "feature2", ...]
    features_json = Column(Text, nullable=True)

    # JSON text: training metrics, {train_mae, val_mae, etc}
    metrics_json = Column(Text, nullable=True)

    # Timestamp of the training run.
    # NOTE(review): the column is timezone-aware but the default produces a
    # naive datetime — confirm this is intended.
    trained_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=datetime.utcnow,
        index=True,
    )

    def __repr__(self):
        return f"<ModelMetadata(symbol={self.symbol}, trained_at={self.trained_at})>"