from __future__ import annotations import asyncio import logging import math from datetime import datetime, timedelta, timezone import numpy as np from sqlalchemy import desc, select from sqlalchemy.ext.asyncio import AsyncSession from core.models import ( AgentRun, Holding, Portfolio, PortfolioSnapshot, PortfolioTicker, Transaction, ) from core.schemas import HoldingOut, PerformanceStatsOut, PortfolioOut, TransactionOut from data.exceptions import MarketDataUnavailableError from data.market_data import get_current_price, get_provider_debug # Simple in-memory price cache with TTL (5 minutes) _price_cache: dict[str, tuple[float, float]] = {} # {ticker: (price, timestamp_utc_seconds)} _PRICE_CACHE_TTL = 300.0 # 5 minutes _PRICE_CACHE_STALE_TTL = 3600.0 # 1 hour def _get_cached_price(ticker: str, *, allow_stale: bool = False) -> float | None: """Get cached price if exists and not expired (or allow stale within max age).""" symbol = ticker.strip().upper() entry = _price_cache.get(symbol) if not entry: return None price, ts = entry now = datetime.now(timezone.utc).timestamp() age = now - ts if age <= _PRICE_CACHE_TTL: return price if allow_stale and age <= _PRICE_CACHE_STALE_TTL: return price if age > _PRICE_CACHE_STALE_TTL: _price_cache.pop(symbol, None) return None def _set_cached_price(ticker: str, price: float) -> None: """Cache a price for 5 minutes.""" symbol = ticker.strip().upper() _price_cache[symbol] = (price, datetime.now(timezone.utc).timestamp()) def normalize_ticker(value: str) -> str: return value.strip().upper() def as_float(value, default: float = 0.0) -> float: if value is None: return default return float(value) async def heal_stale_agent_runs( db: AsyncSession, *, portfolio_id=None, stale_after_minutes: int = 180, ) -> int: """Mark long-running agent runs as failed so they don't block new runs. This primarily defends against crashes/lock contention bugs that leave rows in a perpetual "running" state. """ now = datetime.now(timezone.utc) cutoff = now - timedelta(minutes=stale_after_minutes) # Fetch all currently-running rows (for a portfolio if provided). running_stmt = ( select(AgentRun) .where(AgentRun.status == "running") .where(AgentRun.completed_at.is_(None)) .order_by(AgentRun.started_at.asc()) ) if portfolio_id is not None: running_stmt = running_stmt.where(AgentRun.portfolio_id == portfolio_id) running_runs = list((await db.scalars(running_stmt)).all()) if not running_runs: return 0 # If we know a later run completed, any older "running" row is impossible # (per-portfolio advisory lock enforces single-run execution). latest_completed_at: datetime | None = None if portfolio_id is not None: completed_stmt = ( select(AgentRun.completed_at) .where(AgentRun.portfolio_id == portfolio_id) .where(AgentRun.completed_at.is_not(None)) .order_by(AgentRun.completed_at.desc()) .limit(1) ) latest_completed_at = await db.scalar(completed_stmt) stale_runs: list[AgentRun] = [] for run in running_runs: impossible = bool(latest_completed_at and latest_completed_at > run.started_at) too_old = run.started_at < cutoff if impossible or too_old: stale_runs.append(run) if not stale_runs: return 0 for run in stale_runs: run.status = "failed" minutes_running = max(0, int((now - run.started_at).total_seconds() // 60)) note = ( "Skipped: run was marked stale (stuck in running state). " f"Elapsed≈{minutes_running}m; threshold={stale_after_minutes}m." ) existing = (run.summary or "").strip() run.summary = existing + ("\n\n" if existing else "") + note run.completed_at = now await db.commit() return len(stale_runs) async def get_portfolio_tickers(db: AsyncSession, portfolio_id) -> list[str]: stmt = ( select(PortfolioTicker.ticker) .where(PortfolioTicker.portfolio_id == portfolio_id) .order_by(PortfolioTicker.ticker.asc()) ) rows = (await db.scalars(stmt)).all() return [str(row).upper() for row in rows] async def _price_for_holding(ticker: str, avg_buy: float) -> tuple[float, str, str | None]: """Fetch current price with caching and timeout. Returns (price, source, error). Source indicates whether live, cache, stale_cache, or avg_buy was used. Error is set when live fetch failed. """ # Check cache first cached = _get_cached_price(ticker) if cached is not None: return cached, "cache", None stale_cached = _get_cached_price(ticker, allow_stale=True) try: # Fetch price with a bounded timeout to avoid slow requests price = await asyncio.wait_for(get_current_price(ticker), timeout=8.0) _set_cached_price(ticker, price) return price, "live", None except asyncio.TimeoutError: details = get_provider_debug(ticker) error = "market data request timed out" if details: error = f"{error} ({details})" logging.warning("Market data timeout for %s: %s", ticker, error) if stale_cached is not None: return stale_cached, "stale_cache", error return avg_buy, "avg_buy", error except MarketDataUnavailableError as exc: error = str(exc) logging.warning("Market data unavailable for %s: %s", ticker, error) if stale_cached is not None: return stale_cached, "stale_cache", error return avg_buy, "avg_buy", error except Exception as exc: error = f"market data error: {exc}" logging.warning("Market data error for %s: %s", ticker, error) if stale_cached is not None: return stale_cached, "stale_cache", error return avg_buy, "avg_buy", error async def build_holdings_view( db: AsyncSession, portfolio_id, ) -> tuple[list[HoldingOut], float]: stmt = ( select(Holding) .where(Holding.portfolio_id == portfolio_id) .order_by(Holding.ticker.asc()) ) rows = (await db.scalars(stmt)).all() eligible = [row for row in rows if as_float(row.shares) > 0] if not eligible: return [], 0.0 sem = asyncio.Semaphore(3) async def _fetch_with_sem(row): async with sem: return await _price_for_holding(row.ticker, as_float(row.avg_buy_price)) prices = await asyncio.gather(*[_fetch_with_sem(row) for row in eligible]) output: list[HoldingOut] = [] holdings_value = 0.0 for row, price_payload in zip(eligible, prices, strict=True): current_price, price_source, price_error = price_payload shares = as_float(row.shares) avg_buy = as_float(row.avg_buy_price) value = shares * current_price cost_basis = shares * avg_buy profit_loss = value - cost_basis profit_loss_pct = (profit_loss / cost_basis * 100) if cost_basis > 0 else 0.0 holdings_value += value output.append( HoldingOut( ticker=row.ticker, shares=round(shares, 6), avg_buy_price=round(avg_buy, 4), current_price=round(current_price, 4), value=round(value, 2), profit_loss=round(profit_loss, 2), profit_loss_pct=round(profit_loss_pct, 4), price_source=price_source, price_error=price_error, ) ) return output, round(holdings_value, 2) def _portfolio_out_from_values( portfolio: Portfolio, tickers: list[str], holdings_value: float, ) -> PortfolioOut: current_cash = as_float(portfolio.current_cash) starting_capital = as_float(portfolio.starting_capital) total_value = current_cash + holdings_value profit_loss = total_value - starting_capital profit_loss_pct = (profit_loss / starting_capital * 100) if starting_capital else 0.0 return PortfolioOut( id=portfolio.id, name=portfolio.name, description=portfolio.description, starting_capital=round(starting_capital, 2), current_cash=round(current_cash, 2), holdings_value=round(holdings_value, 2), total_value=round(total_value, 2), profit_loss=round(profit_loss, 2), profit_loss_pct=round(profit_loss_pct, 4), is_active=portfolio.is_active, tickers=tickers, created_at=portfolio.created_at, ) async def build_portfolio_out(db: AsyncSession, portfolio: Portfolio) -> PortfolioOut: tickers = await get_portfolio_tickers(db, portfolio.id) _, holdings_value = await build_holdings_view(db, portfolio.id) return _portfolio_out_from_values(portfolio, tickers, holdings_value) def transaction_to_out(tx: Transaction) -> TransactionOut: return TransactionOut( id=tx.id, portfolio_id=tx.portfolio_id, ticker=tx.ticker, action=tx.action, shares=as_float(tx.shares), price_at_trade=as_float(tx.price_at_trade), total_value=as_float(tx.total_value), llm_reasoning=tx.llm_reasoning or "", tools_called=tx.tools_called or {}, executed_at=tx.executed_at, ) async def snapshot_portfolio(db: AsyncSession, portfolio_id) -> PortfolioSnapshot: portfolio = await db.get(Portfolio, portfolio_id) if not portfolio: raise ValueError("Portfolio not found") _, holdings_value = await build_holdings_view(db, portfolio_id) current_cash = as_float(portfolio.current_cash) total_value = current_cash + holdings_value snapshot = PortfolioSnapshot( portfolio_id=portfolio_id, total_value=round(total_value, 2), cash_value=round(current_cash, 2), holdings_value=round(holdings_value, 2), snapshot_at=datetime.now(timezone.utc), ) db.add(snapshot) await db.commit() await db.refresh(snapshot) return snapshot async def build_performance_stats( db: AsyncSession, portfolio_id, ) -> PerformanceStatsOut: portfolio = await db.get(Portfolio, portfolio_id) if not portfolio: raise ValueError("Portfolio not found") tx_stmt = ( select(Transaction) .where(Transaction.portfolio_id == portfolio_id) .order_by(Transaction.executed_at.asc()) ) transactions = (await db.scalars(tx_stmt)).all() realized_results: list[dict] = [] positions: dict[str, dict[str, float]] = {} for tx in transactions: ticker = tx.ticker.upper() action = tx.action shares = as_float(tx.shares) price = as_float(tx.price_at_trade) if ticker not in positions: positions[ticker] = {"shares": 0.0, "avg": 0.0} pos = positions[ticker] if action == "BUY": total_cost = (pos["shares"] * pos["avg"]) + (shares * price) pos["shares"] += shares if pos["shares"] > 0: pos["avg"] = total_cost / pos["shares"] elif action == "SELL" and pos["shares"] > 0: qty = min(shares, pos["shares"]) if shares > 0 else pos["shares"] if qty <= 0: continue gain_pct = ((price - pos["avg"]) / pos["avg"] * 100) if pos["avg"] else 0.0 realized_results.append({"ticker": ticker, "gain_pct": gain_pct}) pos["shares"] -= qty sell_trades_count = len(realized_results) profitable_trades = len([item for item in realized_results if item["gain_pct"] > 0]) win_rate = profitable_trades / sell_trades_count if sell_trades_count else 0.0 if realized_results: best = max(realized_results, key=lambda item: item["gain_pct"]) worst = min(realized_results, key=lambda item: item["gain_pct"]) best_trade = { "ticker": best["ticker"], "gain_pct": round(best["gain_pct"], 4), } worst_trade = { "ticker": worst["ticker"], "loss_pct": round(worst["gain_pct"], 4), } else: best_trade = {"ticker": "N/A", "gain_pct": 0.0} worst_trade = {"ticker": "N/A", "loss_pct": 0.0} snapshots_stmt = ( select(PortfolioSnapshot) .where(PortfolioSnapshot.portfolio_id == portfolio_id) .order_by(PortfolioSnapshot.snapshot_at.asc()) ) snapshots = (await db.scalars(snapshots_stmt)).all() values = [as_float(row.total_value) for row in snapshots if row.total_value is not None] max_drawdown_pct = 0.0 if values: peak = values[0] drawdowns: list[float] = [] for value in values: peak = max(peak, value) drawdown = ((value - peak) / peak * 100) if peak else 0.0 drawdowns.append(drawdown) max_drawdown_pct = min(drawdowns) sharpe_ratio = 0.0 if len(values) > 2: arr = np.array(values, dtype=float) returns = np.diff(arr) / arr[:-1] if returns.size > 1 and float(np.std(returns)) > 0: sharpe_ratio = float(np.mean(returns) / np.std(returns) * math.sqrt(252)) return PerformanceStatsOut( total_return_pct=0.0, sharpe_ratio=round(sharpe_ratio, 4), max_drawdown_pct=round(max_drawdown_pct, 4), win_rate=round(win_rate, 4), total_trades=sell_trades_count, profitable_trades=profitable_trades, best_trade=best_trade, worst_trade=worst_trade, ) async def get_recent_transactions( db: AsyncSession, portfolio_id, limit: int = 10, ) -> list[TransactionOut]: stmt = ( select(Transaction) .where(Transaction.portfolio_id == portfolio_id) .order_by(desc(Transaction.executed_at)) .limit(limit) ) rows = (await db.scalars(stmt)).all() return [transaction_to_out(row) for row in rows] async def get_agent_runs( db: AsyncSession, portfolio_id, limit: int = 20, ) -> list[AgentRun]: await heal_stale_agent_runs(db, portfolio_id=portfolio_id) stmt = ( select(AgentRun) .where(AgentRun.portfolio_id == portfolio_id) .order_by(desc(AgentRun.started_at)) .limit(limit) ) return list((await db.scalars(stmt)).all())