Spaces:
Running
Running
Sync from GitHub
Browse files- app/ai_engine.py +99 -1
- app/models.py +29 -0
app/ai_engine.py
CHANGED
|
@@ -449,6 +449,20 @@ def train_xgboost_model(
|
|
| 449 |
desc = descriptions.get(feat, feat)
|
| 450 |
logger.info(f" {feat}: {imp:.4f} ({desc})")
|
| 451 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
return {
|
| 453 |
"model_path": str(model_path),
|
| 454 |
"metrics": metrics,
|
|
@@ -474,8 +488,92 @@ def load_model(target_symbol: str = "HG=F") -> Optional[xgb.Booster]:
|
|
| 474 |
return model
|
| 475 |
|
| 476 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
def load_model_metadata(target_symbol: str = "HG=F") -> dict:
|
| 478 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 479 |
settings = get_settings()
|
| 480 |
model_dir = Path(settings.model_dir)
|
| 481 |
|
|
|
|
| 449 |
desc = descriptions.get(feat, feat)
|
| 450 |
logger.info(f" {feat}: {imp:.4f} ({desc})")
|
| 451 |
|
| 452 |
+
# Save metadata to database for persistence across HF Space restarts
|
| 453 |
+
try:
|
| 454 |
+
from app.db import SessionLocal
|
| 455 |
+
with SessionLocal() as session:
|
| 456 |
+
save_model_metadata_to_db(
|
| 457 |
+
session=session,
|
| 458 |
+
symbol=target_symbol,
|
| 459 |
+
importance=normalized_importance,
|
| 460 |
+
features=feature_names,
|
| 461 |
+
metrics=metrics,
|
| 462 |
+
)
|
| 463 |
+
except Exception as e:
|
| 464 |
+
logger.warning(f"Could not save model metadata to DB: {e}")
|
| 465 |
+
|
| 466 |
return {
|
| 467 |
"model_path": str(model_path),
|
| 468 |
"metrics": metrics,
|
|
|
|
| 488 |
return model
|
| 489 |
|
| 490 |
|
| 491 |
+
def save_model_metadata_to_db(
    session,
    symbol: str,
    importance: list,
    features: list,
    metrics: dict
) -> None:
    """
    Save model metadata to database for persistence across restarts.

    Inserts the ``ModelMetadata`` row for ``symbol`` if none exists,
    otherwise overwrites the existing row's JSON columns and refreshes
    its ``trained_at`` timestamp. The change is committed before returning.

    Args:
        session: Database session exposing ``query``, ``add``, ``commit``
            and ``rollback``.
        symbol: Symbol the metadata belongs to; used as the lookup key.
        importance: Feature importance payload, stored as JSON text.
        features: Feature name list, stored as JSON text.
        metrics: Training metrics, stored as JSON text.

    Raises:
        TypeError, ValueError: If a payload cannot be serialized to JSON.
            Raised before the session is touched.
        Exception: Any error from the query or commit is re-raised after
            the session has been rolled back.
    """
    from datetime import datetime, timezone

    from .models import ModelMetadata

    # Serialize everything up front: a non-serializable payload then fails
    # before any attribute of a persistent row has been modified.
    importance_json = json.dumps(importance)
    features_json = json.dumps(features)
    metrics_json = json.dumps(metrics)

    # trained_at is declared DateTime(timezone=True), so store an aware UTC
    # timestamp rather than the naive value datetime.utcnow() returns.
    trained_at = datetime.now(timezone.utc)

    try:
        existing = (
            session.query(ModelMetadata)
            .filter(ModelMetadata.symbol == symbol)
            .first()
        )

        if existing:
            existing.importance_json = importance_json
            existing.features_json = features_json
            existing.metrics_json = metrics_json
            existing.trained_at = trained_at
        else:
            session.add(
                ModelMetadata(
                    symbol=symbol,
                    importance_json=importance_json,
                    features_json=features_json,
                    metrics_json=metrics_json,
                    trained_at=trained_at,
                )
            )

        session.commit()
    except Exception:
        # Leave the caller's session usable instead of stuck in a failed
        # transaction, then let the caller decide how to report the error.
        session.rollback()
        raise

    # Log only after the commit succeeded so the message is never misleading.
    if existing:
        logger.info(f"Updated model metadata in DB for {symbol}")
    else:
        logger.info(f"Saved new model metadata to DB for {symbol}")
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def load_model_metadata_from_db(session, symbol: str) -> dict:
    """
    Fetch the stored metadata row for ``symbol`` and decode its JSON columns.

    Returns a dict with the keys ``"metrics"``, ``"features"`` and
    ``"importance"``. A key stays ``None`` when no row exists, when its
    column is empty, or when decoding stopped at an earlier malformed column.
    """
    from .models import ModelMetadata

    result = dict.fromkeys(("metrics", "features", "importance"))

    row = (
        session.query(ModelMetadata)
        .filter(ModelMetadata.symbol == symbol)
        .first()
    )
    if not row:
        return result

    try:
        # Decode in a fixed order; a malformed column aborts the remaining
        # ones but keeps whatever was decoded before it.
        for key in ("importance", "features", "metrics"):
            raw = getattr(row, f"{key}_json")
            if raw:
                result[key] = json.loads(raw)
    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse model metadata from DB: {e}")
    else:
        logger.info(f"Loaded model metadata from DB for {symbol}")

    return result
|
| 555 |
+
|
| 556 |
+
|
| 557 |
def load_model_metadata(target_symbol: str = "HG=F") -> dict:
|
| 558 |
+
"""
|
| 559 |
+
Load metrics and feature info for a model.
|
| 560 |
+
|
| 561 |
+
Priority:
|
| 562 |
+
1. Database (survives HF Space restarts)
|
| 563 |
+
2. Local JSON files (fallback for development)
|
| 564 |
+
"""
|
| 565 |
+
from app.db import SessionLocal
|
| 566 |
+
|
| 567 |
+
# Try database first
|
| 568 |
+
try:
|
| 569 |
+
with SessionLocal() as session:
|
| 570 |
+
db_metadata = load_model_metadata_from_db(session, target_symbol)
|
| 571 |
+
if db_metadata.get("importance") and db_metadata.get("features"):
|
| 572 |
+
return db_metadata
|
| 573 |
+
except Exception as e:
|
| 574 |
+
logger.debug(f"Could not load metadata from DB: {e}")
|
| 575 |
+
|
| 576 |
+
# Fallback to local files
|
| 577 |
settings = get_settings()
|
| 578 |
model_dir = Path(settings.model_dir)
|
| 579 |
|
app/models.py
CHANGED
|
@@ -230,3 +230,32 @@ class AICommentary(Base):
|
|
| 230 |
|
| 231 |
def __repr__(self):
|
| 232 |
return f"<AICommentary(symbol={self.symbol}, generated_at={self.generated_at})>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
def __repr__(self):
|
| 232 |
return f"<AICommentary(symbol={self.symbol}, generated_at={self.generated_at})>"
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def _trained_at_default():
    """Return the current time as a timezone-aware UTC datetime."""
    from datetime import datetime, timezone

    return datetime.now(timezone.utc)


class ModelMetadata(Base):
    """
    Persisted XGBoost model metadata.

    Stores feature importance, features list, and metrics in database
    so they survive HF Space restarts.
    One row per symbol, updated after each model training (train_model=True).
    """
    __tablename__ = "model_metadata"

    id = Column(Integer, primary_key=True, autoincrement=True)

    # Unique per symbol: the table holds at most one row for each symbol.
    symbol = Column(String(20), nullable=False, unique=True, index=True)

    # Feature importance as JSON [{feature, importance}, ...]
    importance_json = Column(Text, nullable=True)

    # Feature names list as JSON ["feature1", "feature2", ...]
    features_json = Column(Text, nullable=True)

    # Training metrics as JSON {train_mae, val_mae, etc}
    metrics_json = Column(Text, nullable=True)

    # When the model was trained. The column is timezone-aware, so the
    # default must produce an aware UTC datetime rather than a naive one.
    trained_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=_trained_at_default,
        index=True,
    )

    def __repr__(self):
        return f"<ModelMetadata(symbol={self.symbol}, trained_at={self.trained_at})>"
|