Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| import math | |
| from datetime import datetime, timedelta, timezone | |
| import numpy as np | |
| from sqlalchemy import desc, select | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from core.models import ( | |
| AgentRun, | |
| Holding, | |
| Portfolio, | |
| PortfolioSnapshot, | |
| PortfolioTicker, | |
| Transaction, | |
| ) | |
| from core.schemas import HoldingOut, PerformanceStatsOut, PortfolioOut, TransactionOut | |
| from data.exceptions import MarketDataUnavailableError | |
| from data.market_data import get_current_price, get_provider_debug | |
| # Simple in-memory price cache with TTL (5 minutes) | |
| _price_cache: dict[str, tuple[float, float]] = {} # {ticker: (price, timestamp_utc_seconds)} | |
| _PRICE_CACHE_TTL = 300.0 # 5 minutes | |
| _PRICE_CACHE_STALE_TTL = 3600.0 # 1 hour | |
def _get_cached_price(ticker: str, *, allow_stale: bool = False) -> float | None:
    """Look up a cached price for ``ticker``.

    A price younger than ``_PRICE_CACHE_TTL`` is always returned. A price
    whose age lies between the fresh TTL and ``_PRICE_CACHE_STALE_TTL`` is
    returned only when ``allow_stale`` is true. A price older than the stale
    TTL is evicted and ``None`` is returned. Unknown tickers yield ``None``.
    """
    key = ticker.strip().upper()
    cached = _price_cache.get(key)
    if cached is None:
        return None

    value, stored_at = cached
    elapsed = datetime.now(timezone.utc).timestamp() - stored_at

    if elapsed <= _PRICE_CACHE_TTL:
        return value
    if elapsed > _PRICE_CACHE_STALE_TTL:
        # Too old to be useful even as a fallback: evict it.
        _price_cache.pop(key, None)
        return None
    return value if allow_stale else None
def _set_cached_price(ticker: str, price: float) -> None:
    """Store ``price`` for ``ticker`` (normalized to upper case), stamped with now.

    Expiry is not decided here. ``_get_cached_price`` compares the stored
    timestamp with its fresh and stale TTLs on every lookup.
    """
    stamped_at = datetime.now(timezone.utc).timestamp()
    _price_cache[ticker.strip().upper()] = (price, stamped_at)
def normalize_ticker(value: str) -> str:
    """Return ``value`` without surrounding whitespace and in upper case."""
    stripped = value.strip()
    return stripped.upper()
def as_float(value, default: float = 0.0) -> float:
    """Coerce ``value`` with ``float()``, mapping ``None`` to ``default``."""
    return default if value is None else float(value)
async def heal_stale_agent_runs(
    db: AsyncSession,
    *,
    portfolio_id=None,
    stale_after_minutes: int = 180,
) -> int:
    """Mark long-running agent runs as failed so they don't block new runs.

    This primarily defends against crashes/lock contention bugs that leave rows in
    a perpetual "running" state.

    A run with ``status == "running"`` and no ``completed_at`` is healed when
    either condition holds:

    * it started before ``now - stale_after_minutes``; or
    * ``portfolio_id`` is given and some completed run of that portfolio
      *started* after it.

    Healed runs get ``status = "failed"``, an explanatory note appended to
    ``summary``, and ``completed_at = now``. Changes are committed only when at
    least one run was healed.

    Args:
        db: Async SQLAlchemy session. It is committed if anything changes.
        portfolio_id: Restrict healing to this portfolio; ``None`` checks all
            running rows, using only the age threshold.
        stale_after_minutes: Age threshold in minutes.

    Returns:
        The number of runs marked as failed.
    """
    now = datetime.now(timezone.utc)
    cutoff = now - timedelta(minutes=stale_after_minutes)

    # Fetch all currently-running rows (for a portfolio if provided).
    running_stmt = (
        select(AgentRun)
        .where(AgentRun.status == "running")
        .where(AgentRun.completed_at.is_(None))
        .order_by(AgentRun.started_at.asc())
    )
    if portfolio_id is not None:
        running_stmt = running_stmt.where(AgentRun.portfolio_id == portfolio_id)
    running_runs = list((await db.scalars(running_stmt)).all())
    if not running_runs:
        return 0

    # If a run that STARTED later than this one has already completed, this
    # "running" row is impossible (per-portfolio advisory lock enforces
    # single-run execution). Compare start times, not completion times: runs
    # healed by this function receive a synthetic completed_at=now. That value
    # would otherwise make any run that began before the heal, including a live
    # one, look impossible on the next call.
    latest_completed_start: datetime | None = None
    if portfolio_id is not None:
        completed_stmt = (
            select(AgentRun.started_at)
            .where(AgentRun.portfolio_id == portfolio_id)
            .where(AgentRun.completed_at.is_not(None))
            .where(AgentRun.started_at.is_not(None))
            .order_by(AgentRun.started_at.desc())
            .limit(1)
        )
        latest_completed_start = await db.scalar(completed_stmt)

    stale_runs: list[AgentRun] = []
    for run in running_runs:
        impossible = (
            latest_completed_start is not None
            and latest_completed_start > run.started_at
        )
        too_old = run.started_at < cutoff
        if impossible or too_old:
            stale_runs.append(run)
    if not stale_runs:
        return 0

    for run in stale_runs:
        run.status = "failed"
        minutes_running = max(0, int((now - run.started_at).total_seconds() // 60))
        note = (
            "Skipped: run was marked stale (stuck in running state). "
            f"Elapsed≈{minutes_running}m; threshold={stale_after_minutes}m."
        )
        existing = (run.summary or "").strip()
        run.summary = existing + ("\n\n" if existing else "") + note
        run.completed_at = now
    await db.commit()
    return len(stale_runs)
async def get_portfolio_tickers(db: AsyncSession, portfolio_id) -> list[str]:
    """Return the portfolio's watched tickers, sorted ascending and upper-cased."""
    query = select(PortfolioTicker.ticker)
    query = query.where(PortfolioTicker.portfolio_id == portfolio_id)
    query = query.order_by(PortfolioTicker.ticker.asc())
    result = await db.scalars(query)
    return [str(symbol).upper() for symbol in result.all()]
async def _price_for_holding(ticker: str, avg_buy: float) -> tuple[float, str, str | None]:
    """Resolve a price for ``ticker``, degrading gracefully on failure.

    Lookup order:

    1. A fresh cache hit returns ``(price, "cache", None)``.
    2. A live fetch, bounded by an 8 second timeout, returns
       ``(price, "live", None)`` and refreshes the cache.
    3. If the live fetch fails, a stale cached price is used when present,
       giving ``(price, "stale_cache", error)``. Otherwise ``avg_buy`` is
       used, giving ``(avg_buy, "avg_buy", error)``.

    ``error`` is a human-readable description of the live-fetch failure.
    """
    fresh = _get_cached_price(ticker)
    if fresh is not None:
        return fresh, "cache", None
    # Read the stale fallback before the fetch, so it is available whatever happens.
    fallback = _get_cached_price(ticker, allow_stale=True)

    def _degrade(message: str) -> tuple[float, str, str | None]:
        # Prefer a stale quote over the cost basis when one exists.
        if fallback is not None:
            return fallback, "stale_cache", message
        return avg_buy, "avg_buy", message

    try:
        # Bounded timeout so one slow provider call cannot stall the whole view.
        live = await asyncio.wait_for(get_current_price(ticker), timeout=8.0)
        _set_cached_price(ticker, live)
        return live, "live", None
    except asyncio.TimeoutError:
        provider_details = get_provider_debug(ticker)
        message = "market data request timed out"
        if provider_details:
            message = f"{message} ({provider_details})"
        logging.warning("Market data timeout for %s: %s", ticker, message)
        return _degrade(message)
    except MarketDataUnavailableError as exc:
        message = str(exc)
        logging.warning("Market data unavailable for %s: %s", ticker, message)
        return _degrade(message)
    except Exception as exc:
        message = f"market data error: {exc}"
        logging.warning("Market data error for %s: %s", ticker, message)
        return _degrade(message)
async def build_holdings_view(
    db: AsyncSession,
    portfolio_id,
) -> tuple[list[HoldingOut], float]:
    """Price every holding with a positive share count and total their value.

    Prices are fetched concurrently, with at most three lookups in flight at
    once. Holdings are returned in ascending ticker order.

    Returns:
        ``(holdings, holdings_value)``. ``holdings_value`` is the sum of the
        unrounded per-holding values, rounded to cents. It is ``([], 0.0)``
        when no holding has shares.
    """
    query = (
        select(Holding)
        .where(Holding.portfolio_id == portfolio_id)
        .order_by(Holding.ticker.asc())
    )
    all_holdings = (await db.scalars(query)).all()
    held = [holding for holding in all_holdings if as_float(holding.shares) > 0]
    if not held:
        return [], 0.0

    # Cap concurrent market-data requests.
    limiter = asyncio.Semaphore(3)

    async def _quote(holding):
        async with limiter:
            return await _price_for_holding(holding.ticker, as_float(holding.avg_buy_price))

    quotes = await asyncio.gather(*(_quote(holding) for holding in held))

    views: list[HoldingOut] = []
    running_total = 0.0
    for holding, (price, source, error) in zip(held, quotes, strict=True):
        qty = as_float(holding.shares)
        unit_cost = as_float(holding.avg_buy_price)
        market_value = qty * price
        basis = qty * unit_cost
        pnl = market_value - basis
        pnl_pct = 0.0
        if basis > 0:
            pnl_pct = pnl / basis * 100
        running_total += market_value
        views.append(
            HoldingOut(
                ticker=holding.ticker,
                shares=round(qty, 6),
                avg_buy_price=round(unit_cost, 4),
                current_price=round(price, 4),
                value=round(market_value, 2),
                profit_loss=round(pnl, 2),
                profit_loss_pct=round(pnl_pct, 4),
                price_source=source,
                price_error=error,
            )
        )
    return views, round(running_total, 2)
def _portfolio_out_from_values(
    portfolio: Portfolio,
    tickers: list[str],
    holdings_value: float,
) -> PortfolioOut:
    """Build a ``PortfolioOut`` from a portfolio row and a precomputed holdings value.

    Total value is cash plus holdings. Profit/loss is measured against
    ``starting_capital``; the percentage is 0.0 when starting capital is zero.
    Money fields are rounded to 2 decimals and the percentage to 4.
    """
    cash = as_float(portfolio.current_cash)
    capital = as_float(portfolio.starting_capital)
    total = cash + holdings_value
    gain = total - capital
    gain_pct = 0.0
    if capital:
        gain_pct = gain / capital * 100
    return PortfolioOut(
        id=portfolio.id,
        name=portfolio.name,
        description=portfolio.description,
        starting_capital=round(capital, 2),
        current_cash=round(cash, 2),
        holdings_value=round(holdings_value, 2),
        total_value=round(total, 2),
        profit_loss=round(gain, 2),
        profit_loss_pct=round(gain_pct, 4),
        is_active=portfolio.is_active,
        tickers=tickers,
        created_at=portfolio.created_at,
    )
async def build_portfolio_out(db: AsyncSession, portfolio: Portfolio) -> PortfolioOut:
    """Load the portfolio's tickers, price its holdings, and assemble the API view."""
    watched = await get_portfolio_tickers(db, portfolio.id)
    holdings_total = (await build_holdings_view(db, portfolio.id))[1]
    return _portfolio_out_from_values(portfolio, watched, holdings_total)
def transaction_to_out(tx: Transaction) -> TransactionOut:
    """Convert a ``Transaction`` row into its ``TransactionOut`` schema.

    Numeric columns go through ``as_float``, so ``None`` becomes 0.0. A missing
    reasoning becomes ``""`` and missing tool metadata becomes ``{}``.
    """
    reasoning = tx.llm_reasoning or ""
    tools = tx.tools_called or {}
    return TransactionOut(
        id=tx.id,
        portfolio_id=tx.portfolio_id,
        ticker=tx.ticker,
        action=tx.action,
        shares=as_float(tx.shares),
        price_at_trade=as_float(tx.price_at_trade),
        total_value=as_float(tx.total_value),
        llm_reasoning=reasoning,
        tools_called=tools,
        executed_at=tx.executed_at,
    )
async def snapshot_portfolio(db: AsyncSession, portfolio_id) -> PortfolioSnapshot:
    """Persist a point-in-time valuation (cash + priced holdings) of a portfolio.

    Commits the new ``PortfolioSnapshot`` and refreshes it before returning.

    Raises:
        ValueError: if no portfolio with ``portfolio_id`` exists.
    """
    portfolio = await db.get(Portfolio, portfolio_id)
    if not portfolio:
        raise ValueError("Portfolio not found")

    holdings_total = (await build_holdings_view(db, portfolio_id))[1]
    cash = as_float(portfolio.current_cash)

    snapshot = PortfolioSnapshot(
        portfolio_id=portfolio_id,
        total_value=round(cash + holdings_total, 2),
        cash_value=round(cash, 2),
        holdings_value=round(holdings_total, 2),
        snapshot_at=datetime.now(timezone.utc),
    )
    db.add(snapshot)
    await db.commit()
    await db.refresh(snapshot)
    return snapshot
def _realized_trade_results(transactions) -> list[dict]:
    """Replay transactions oldest-first using average-cost accounting.

    Returns one ``{"ticker", "gain_pct"}`` dict per SELL that closes shares of
    an open position. ``gain_pct`` is the sell price relative to the running
    average buy price.
    """
    results: list[dict] = []
    positions: dict[str, dict[str, float]] = {}
    for tx in transactions:
        ticker = tx.ticker.upper()
        shares = as_float(tx.shares)
        price = as_float(tx.price_at_trade)
        pos = positions.setdefault(ticker, {"shares": 0.0, "avg": 0.0})
        if tx.action == "BUY":
            total_cost = (pos["shares"] * pos["avg"]) + (shares * price)
            pos["shares"] += shares
            if pos["shares"] > 0:
                pos["avg"] = total_cost / pos["shares"]
        elif tx.action == "SELL" and pos["shares"] > 0:
            # A non-positive share count on a SELL is treated as "sell everything".
            qty = min(shares, pos["shares"]) if shares > 0 else pos["shares"]
            if qty <= 0:
                continue
            gain_pct = ((price - pos["avg"]) / pos["avg"] * 100) if pos["avg"] else 0.0
            results.append({"ticker": ticker, "gain_pct": gain_pct})
            pos["shares"] -= qty
    return results


def _max_drawdown_pct(values: list[float]) -> float:
    """Return the largest peak-to-trough decline in ``values``, in percent (<= 0)."""
    if not values:
        return 0.0
    peak = values[0]
    worst = 0.0
    for value in values:
        peak = max(peak, value)
        drawdown = ((value - peak) / peak * 100) if peak else 0.0
        worst = min(worst, drawdown)
    return worst


def _annualized_sharpe_ratio(values: list[float], periods_per_year: int = 252) -> float:
    """Return the annualized Sharpe ratio of the simple returns of ``values``.

    Uses a zero risk-free rate and population standard deviation. Returns 0.0
    with fewer than three points or a zero/undefined volatility.
    NOTE(review): ``periods_per_year=252`` assumes roughly one snapshot per
    trading day. Confirm against the snapshot schedule.
    """
    if len(values) <= 2:
        return 0.0
    arr = np.array(values, dtype=float)
    # An equity value of 0 would divide by zero; the resulting inf/nan makes
    # std non-finite, which the guard below rejects, so silence the warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        returns = np.diff(arr) / arr[:-1]
        volatility = float(np.std(returns))
        if returns.size > 1 and volatility > 0:
            return float(np.mean(returns) / volatility * math.sqrt(periods_per_year))
    return 0.0


async def build_performance_stats(
    db: AsyncSession,
    portfolio_id,
) -> PerformanceStatsOut:
    """Compute trade and equity-curve statistics for a portfolio.

    * Trade stats (win rate, best/worst trade) come from replaying all
      transactions with average-cost accounting.
    * ``total_return_pct`` compares the latest snapshot's total value with the
      portfolio's starting capital. It is 0.0 when there are no snapshots or
      starting capital is zero.
    * Max drawdown and Sharpe ratio are computed over snapshot total values in
      time order.

    Raises:
        ValueError: if no portfolio with ``portfolio_id`` exists.
    """
    portfolio = await db.get(Portfolio, portfolio_id)
    if not portfolio:
        raise ValueError("Portfolio not found")

    tx_stmt = (
        select(Transaction)
        .where(Transaction.portfolio_id == portfolio_id)
        .order_by(Transaction.executed_at.asc())
    )
    transactions = (await db.scalars(tx_stmt)).all()
    realized_results = _realized_trade_results(transactions)

    sell_trades_count = len(realized_results)
    profitable_trades = sum(1 for item in realized_results if item["gain_pct"] > 0)
    win_rate = profitable_trades / sell_trades_count if sell_trades_count else 0.0

    if realized_results:
        best = max(realized_results, key=lambda item: item["gain_pct"])
        worst = min(realized_results, key=lambda item: item["gain_pct"])
        best_trade = {"ticker": best["ticker"], "gain_pct": round(best["gain_pct"], 4)}
        worst_trade = {"ticker": worst["ticker"], "loss_pct": round(worst["gain_pct"], 4)}
    else:
        best_trade = {"ticker": "N/A", "gain_pct": 0.0}
        worst_trade = {"ticker": "N/A", "loss_pct": 0.0}

    snapshots_stmt = (
        select(PortfolioSnapshot)
        .where(PortfolioSnapshot.portfolio_id == portfolio_id)
        .order_by(PortfolioSnapshot.snapshot_at.asc())
    )
    snapshots = (await db.scalars(snapshots_stmt)).all()
    values = [as_float(row.total_value) for row in snapshots if row.total_value is not None]

    # Previously hard-coded to 0.0. Measured against starting capital to agree
    # with PortfolioOut.profit_loss_pct.
    starting_capital = as_float(portfolio.starting_capital)
    total_return_pct = (
        (values[-1] - starting_capital) / starting_capital * 100
        if values and starting_capital
        else 0.0
    )

    return PerformanceStatsOut(
        total_return_pct=round(total_return_pct, 4),
        sharpe_ratio=round(_annualized_sharpe_ratio(values), 4),
        max_drawdown_pct=round(_max_drawdown_pct(values), 4),
        win_rate=round(win_rate, 4),
        total_trades=sell_trades_count,
        profitable_trades=profitable_trades,
        best_trade=best_trade,
        worst_trade=worst_trade,
    )
async def get_recent_transactions(
    db: AsyncSession,
    portfolio_id,
    limit: int = 10,
) -> list[TransactionOut]:
    """Return up to ``limit`` of the portfolio's transactions, newest first."""
    query = select(Transaction).where(Transaction.portfolio_id == portfolio_id)
    query = query.order_by(desc(Transaction.executed_at)).limit(limit)
    result = await db.scalars(query)
    return list(map(transaction_to_out, result.all()))
async def get_agent_runs(
    db: AsyncSession,
    portfolio_id,
    limit: int = 20,
) -> list[AgentRun]:
    """Return up to ``limit`` agent runs for a portfolio, most recently started first.

    Stale "running" rows for this portfolio are healed (and committed) before
    the listing is read, so callers see the healed statuses.
    """
    await heal_stale_agent_runs(db, portfolio_id=portfolio_id)
    query = select(AgentRun).where(AgentRun.portfolio_id == portfolio_id)
    query = query.order_by(desc(AgentRun.started_at)).limit(limit)
    result = await db.scalars(query)
    return list(result.all())