Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| from fastapi import APIRouter, Depends, HTTPException, Query, Request, status | |
| from sqlalchemy import select | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from api.deps import get_current_user, get_portfolio_or_404 | |
| from api.utils import as_float | |
| from core.config import settings | |
| from core.database import SessionLocal, get_db | |
| from core.models import ModelRegistry, Portfolio, PortfolioTicker, User | |
| from core.schemas import ( | |
| MessageResponse, | |
| ModelAccuracyOut, | |
| ModelCoverageTickerOut, | |
| ModelRetrainAllOut, | |
| ModelOut, | |
| ModelOverviewOut, | |
| ModelOverviewSummaryOut, | |
| ModelPortfolioReferenceOut, | |
| ) | |
| from ml.evaluator import get_accuracy_series, get_latest_forward_accuracy | |
| from ml.trainer import get_training_errors, train_many_tickers, train_ticker_models | |
router = APIRouter(tags=["models"])

# Supported model families. Coverage/overview output iterates this tuple, so
# its order is the order clients see.
MODEL_TYPES: tuple[str, ...] = ("lstm", "xgboost")

# Shared by both background training helpers; with a count of 1 it lets at
# most one background training run (single-ticker or bulk) proceed at a time
# within this process.
_background_train_semaphore: asyncio.Semaphore = asyncio.Semaphore(value=1)
async def _train_ticker_background(ticker: str) -> None:
    """Train every model for ``ticker`` outside of a request.

    Waits on ``_background_train_semaphore`` and then opens a dedicated
    session from ``SessionLocal``. Any ``Exception`` raised by
    ``train_ticker_models`` is logged with its traceback and not re-raised.
    """
    async with _background_train_semaphore, SessionLocal() as session:
        try:
            await train_ticker_models(session, ticker)
        except Exception as exc:
            logging.exception("Model training failed for %s: %s", ticker, exc)
async def _train_many_background(tickers: list[str]) -> None:
    """Retrain models for a batch of tickers outside of a request.

    An empty batch is a no-op. Otherwise the call waits on
    ``_background_train_semaphore``, opens a dedicated ``SessionLocal``
    session, and runs ``train_many_tickers``. Any ``Exception`` it raises is
    logged with its traceback and not re-raised.
    """
    if not tickers:
        return
    async with _background_train_semaphore, SessionLocal() as session:
        try:
            await train_many_tickers(session, tickers)
        except Exception as exc:
            logging.exception("Bulk model retrain failed: %s", exc)
def _normalize_model_type(value: str) -> str | None:
    """Canonicalise a model-type label.

    The label is lower-cased and stripped of surrounding whitespace. The
    result is returned only if it is one of ``MODEL_TYPES``; otherwise the
    function returns ``None``.
    """
    candidate = value.lower().strip()
    return candidate if candidate in MODEL_TYPES else None
async def _get_scoped_tickers(
    db: AsyncSession,
    request: Request,
    portfolio_id: str | None = None,
) -> list[str]:
    """Return the sorted, de-duplicated, upper-cased tickers held in portfolios.

    When ``portfolio_id`` is given, the portfolio is resolved through
    ``get_portfolio_or_404`` first, and only its tickers are returned.
    Otherwise every ``PortfolioTicker`` row is considered.
    """
    # NOTE(review): the unscoped query applies no per-user filter; confirm
    # that exposing/acting on every portfolio's tickers is intended.
    query = select(PortfolioTicker.ticker).distinct()
    if portfolio_id:
        scoped_portfolio = await get_portfolio_or_404(portfolio_id, request, db)
        query = query.where(PortfolioTicker.portfolio_id == scoped_portfolio.id)
    result = await db.scalars(query)
    return sorted({str(symbol).upper() for symbol in result.all()})
async def list_models(
    request: Request,
    portfolio_id: str | None = Query(default=None),
    _: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List registered models, optionally restricted to one portfolio's tickers.

    Rows are ordered by ticker, then model type. Rows whose ``model_type`` is
    not one of ``MODEL_TYPES`` are skipped. For LSTM rows, the latest forward
    accuracy from ``get_latest_forward_accuracy`` is reported when it is not
    ``None``; otherwise the stored registry accuracy is used.

    Returns:
        list[ModelOut]. The list is empty when ``portfolio_id`` refers to a
        portfolio with no tickers.
    """
    stmt = select(ModelRegistry)
    if portfolio_id:
        tickers = await _get_scoped_tickers(db, request, portfolio_id)
        if not tickers:
            return []
        stmt = stmt.where(ModelRegistry.ticker.in_(tickers))
    stmt = stmt.order_by(
        ModelRegistry.ticker.asc(),
        ModelRegistry.model_type.asc(),
    )
    rows = (await db.scalars(stmt)).all()

    output: list[ModelOut] = []
    for row in rows:
        # Share the module's normalisation rule so this endpoint stays in sync
        # with MODEL_TYPES instead of keeping its own hard-coded list.
        model_type = _normalize_model_type(row.model_type)
        if model_type is None:
            continue
        accuracy = as_float(row.accuracy)
        forward_accuracy = None
        if model_type == "lstm":
            # NOTE(review): one extra query per LSTM row; consider batching if
            # the registry grows large.
            forward_accuracy = await get_latest_forward_accuracy(
                db,
                ticker=row.ticker,
                model_type=model_type,
            )
        output.append(
            ModelOut(
                id=row.id,
                ticker=row.ticker,
                model_type=model_type,
                accuracy=forward_accuracy if forward_accuracy is not None else accuracy,
                training_rows=row.training_rows or 0,
                trained_at=row.trained_at,
                is_active=row.is_active,
            )
        )
    return output
def _empty_models_overview() -> ModelOverviewOut:
    """Build the overview returned when no portfolio tickers are in scope."""
    return ModelOverviewOut(
        summary=ModelOverviewSummaryOut(
            tracked_tickers=0,
            referenced_portfolios=0,
            trained_model_count=0,
            fully_trained_tickers=0,
            missing_model_count=0,
        ),
        available_model_types=list(MODEL_TYPES),
        coverage=[],
    )


async def get_models_overview(
    request: Request,
    portfolio_id: str | None = Query(default=None),
    _: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Summarise model-training coverage for tickers referenced by portfolios.

    The scope is every ticker in a portfolio, or only the tickers of
    ``portfolio_id`` when given. For each ticker the overview reports:

    - which ``MODEL_TYPES`` have an active registry row, and which are missing;
    - the latest non-null ``trained_at`` among those active rows;
    - the last training error from ``get_training_errors()``.

    It also reports summary counts over all tickers.

    Raises:
        Whatever ``get_portfolio_or_404`` raises for an unknown or
        inaccessible ``portfolio_id`` (presumably an HTTP 404; confirm in
        ``api.deps``).
    """
    ticker_stmt = (
        select(
            PortfolioTicker.ticker,
            Portfolio.id,
            Portfolio.name,
            Portfolio.is_active,
        )
        .join(Portfolio, Portfolio.id == PortfolioTicker.portfolio_id)
        .order_by(PortfolioTicker.ticker.asc(), Portfolio.name.asc())
    )
    # NOTE(review): without portfolio_id no per-user filter is applied, so
    # every portfolio's tickers and names are included; confirm intended.
    if portfolio_id:
        portfolio = await get_portfolio_or_404(portfolio_id, request, db)
        ticker_stmt = ticker_stmt.where(PortfolioTicker.portfolio_id == portfolio.id)
    ticker_rows = (await db.execute(ticker_stmt)).all()
    if not ticker_rows:
        return _empty_models_overview()

    # Group portfolio references by upper-cased ticker symbol.
    scoped_tickers: set[str] = set()
    portfolio_ids: set[str] = set()
    coverage_map: dict[str, list[ModelPortfolioReferenceOut]] = {}
    for ticker, ref_id, name, is_active in ticker_rows:
        symbol = str(ticker).upper()
        scoped_tickers.add(symbol)
        portfolio_ids.add(str(ref_id))
        coverage_map.setdefault(symbol, []).append(
            ModelPortfolioReferenceOut(
                id=ref_id,
                name=name,
                is_active=bool(is_active),
            )
        )

    model_stmt = (
        select(ModelRegistry)
        .where(ModelRegistry.ticker.in_(sorted(scoped_tickers)))
        .order_by(ModelRegistry.ticker.asc(), ModelRegistry.model_type.asc())
    )
    model_rows = (await db.scalars(model_stmt)).all()

    # Only active rows with a recognised model type count as "trained".
    # Normalise each row once here rather than once per ticker and type below.
    trained_by_ticker: dict[str, set[str]] = {}
    trained_at_by_ticker: dict[str, list] = {}
    for row in model_rows:
        model_type = _normalize_model_type(row.model_type)
        if not model_type or not row.is_active:
            continue
        symbol = row.ticker.upper()
        trained_by_ticker.setdefault(symbol, set()).add(model_type)
        # Skip missing timestamps: max() cannot compare None with datetimes.
        if row.trained_at is not None:
            trained_at_by_ticker.setdefault(symbol, []).append(row.trained_at)

    coverage: list[ModelCoverageTickerOut] = []
    trained_model_count = 0
    fully_trained_tickers = 0
    missing_model_count = 0
    training_errors = get_training_errors()
    for ticker in sorted(scoped_tickers):
        present = trained_by_ticker.get(ticker, set())
        # Iterate MODEL_TYPES so both lists keep its canonical order.
        trained_types = [model_type for model_type in MODEL_TYPES if model_type in present]
        missing_types = [model_type for model_type in MODEL_TYPES if model_type not in present]
        last_trained_at = max(trained_at_by_ticker.get(ticker, []), default=None)

        trained_model_count += len(trained_types)
        missing_model_count += len(missing_types)
        if not missing_types:
            fully_trained_tickers += 1
        coverage.append(
            ModelCoverageTickerOut(
                ticker=ticker,
                portfolios=coverage_map.get(ticker, []),
                trained_model_types=trained_types,
                missing_model_types=missing_types,
                coverage_pct=round((len(trained_types) / len(MODEL_TYPES)) * 100, 2),
                is_fully_trained=not missing_types,
                last_trained_at=last_trained_at,
                last_training_error=training_errors.get(ticker),
            )
        )

    return ModelOverviewOut(
        summary=ModelOverviewSummaryOut(
            tracked_tickers=len(scoped_tickers),
            referenced_portfolios=len(portfolio_ids),
            trained_model_count=trained_model_count,
            fully_trained_tickers=fully_trained_tickers,
            missing_model_count=missing_model_count,
        ),
        available_model_types=list(MODEL_TYPES),
        coverage=coverage,
    )
# Strong references to fire-and-forget background tasks. The event loop only
# keeps weak references to tasks, so an unreferenced task may be
# garbage-collected before it finishes. Each task removes itself when done.
_background_tasks: set[asyncio.Task] = set()


def _spawn_background(coro) -> asyncio.Task:
    """Schedule coroutine ``coro`` as a task and hold a reference until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def train_model(
    ticker: str,
    request: Request,
    _: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Train all model types for a single ticker.

    In production (``settings.is_production``) training is scheduled in the
    background and the response returns immediately. Otherwise training runs
    inline within the request.

    Raises:
        HTTPException: 400 if the ticker is blank, or if inline training
            raises ``ValueError``.
    """
    symbol = ticker.upper().strip()
    if not symbol:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Ticker is required",
        )
    if settings.is_production:
        # Run training in the background so the request does not time out on hosted deployments.
        _spawn_background(_train_ticker_background(symbol))
        return MessageResponse(message=f"Training started for {symbol}")
    try:
        await train_ticker_models(db, symbol)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc
    return MessageResponse(message=f"Training completed for {symbol}")


async def retrain_all_models(
    request: Request,
    portfolio_id: str | None = Query(default=None),
    _: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Retrain models for every tracked ticker, or for one portfolio's tickers.

    In production the retrain is queued in the background, and the returned
    counts are zero placeholders. Otherwise it runs inline, and the response
    reports how many tickers succeeded and which failed.
    """
    tickers = await _get_scoped_tickers(db, request, portfolio_id)
    scope = "portfolio" if portfolio_id else "tracked"

    if settings.is_production:
        # Schedule long-running retraining in the background to avoid request timeouts.
        if tickers:
            _spawn_background(_train_many_background(tickers))
        message = (
            f"Retraining queued for {len(tickers)} {scope} ticker(s). "
            "Check back shortly for updated models."
        )
        return ModelRetrainAllOut(
            message=message,
            total_tickers=len(tickers),
            trained_count=0,
            failed_count=0,
            failed=[],
        )

    trained_count = 0
    failed: list[dict] = []
    if tickers:
        result = await train_many_tickers(db, tickers)
        trained_count = len(result.get("trained", []))
        failed = result.get("failed", [])
    failed_count = len(failed)
    message = (
        f"Retraining finished for {len(tickers)} {scope} ticker(s): "
        f"{trained_count} succeeded, {failed_count} failed"
    )
    return ModelRetrainAllOut(
        message=message,
        total_tickers=len(tickers),
        trained_count=trained_count,
        failed_count=failed_count,
        failed=failed,
    )
async def get_model_accuracy(
    ticker: str,
    request: Request,
    model_type: str | None = Query(default=None),
    _: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return up to 180 accuracy points for ``ticker``.

    ``model_type`` is normalised with the shared ``_normalize_model_type``
    rule. Missing, blank, or unsupported values are passed to
    ``get_accuracy_series`` as ``None``. (A whitespace-only value previously
    slipped through as ``""``.) The ticker is upper-cased and stripped,
    matching ``train_model``.
    """
    normalized = _normalize_model_type(model_type) if model_type else None
    payload = await get_accuracy_series(
        db=db,
        ticker=ticker.upper().strip(),
        model_type=normalized,
        limit=180,
    )
    return ModelAccuracyOut(**payload)