Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| from datetime import datetime, timezone | |
| from sqlalchemy import select | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from api.utils import as_float, build_holdings_view, build_portfolio_out, transaction_to_out | |
| from core.models import Holding, Portfolio, Transaction | |
| from data.market_data import get_history, get_price_snapshot | |
| from data.news_fetcher import get_sentiment_score | |
| from ml.features import get_technical_signals | |
| from ml.predictor import classify_direction, predict_price | |
def _tool_spec(name: str, description: str, parameters: dict[str, str]) -> dict:
    """Build one tool entry in the ``{"name", "description", "parameters"}`` shape."""
    return {"name": name, "description": description, "parameters": parameters}


# Catalogue of tools advertised to the LLM agent. Every ``name`` here has a
# matching branch in ``execute_tool_call``. ``parameters`` maps argument names
# to free-text type hints; it is not a JSON Schema.
TOOLS = [
    _tool_spec(
        "predict_price",
        "Uses trained LSTM model to predict price N days ahead. Returns predicted price and confidence score.",
        {
            "ticker": "string - stock ticker symbol",
            "horizon_days": "integer - forecast horizon (1-10)",
        },
    ),
    _tool_spec(
        "classify_direction",
        "Uses XGBoost to classify if price will go UP or DOWN in next 3 days. Returns direction and probability.",
        {"ticker": "string"},
    ),
    _tool_spec(
        "get_technical_signals",
        "Returns RSI, MACD crossover signal, Bollinger Band position for a ticker.",
        {"ticker": "string"},
    ),
    _tool_spec(
        "get_sentiment_score",
        "Fetches latest news for ticker, returns sentiment score from -1.0 to 1.0 and recent headlines.",
        {"ticker": "string"},
    ),
    _tool_spec(
        "get_portfolio_status",
        "Returns current cash, holdings, total value, and unrealized P&L for the portfolio.",
        {},
    ),
    _tool_spec(
        "execute_trade",
        "Execute a BUY or SELL order. For BUY: specify dollar amount. For SELL: specify shares or 'all'.",
        {
            "ticker": "string",
            "action": "string - 'BUY' or 'SELL'",
            "amount_usd": "number - dollar amount for BUY (optional)",
            "shares": "number or 'all' - for SELL (optional)",
        },
    ),
    _tool_spec(
        "get_price_history",
        "Returns OHLCV price history for a ticker.",
        {
            "ticker": "string",
            "days": "integer - lookback days (default 30)",
        },
    ),
]
async def predict_price_tool(ticker: str, horizon_days: int = 3) -> dict:
    """Forecast ``ticker``'s price ``horizon_days`` ahead via ``predict_price``.

    The ticker is stripped of surrounding whitespace and upper-cased, matching
    the normalisation ``execute_trade`` applies. The horizon is clamped to the
    advertised 1-10 day range.
    """
    symbol = ticker.strip().upper()
    return await predict_price(ticker=symbol, horizon_days=max(1, min(10, horizon_days)))
async def classify_direction_tool(ticker: str) -> dict:
    """Classify the near-term price direction of ``ticker`` via ``classify_direction``.

    The ticker is stripped of surrounding whitespace and upper-cased, matching
    the normalisation ``execute_trade`` applies.
    """
    return await classify_direction(ticker=ticker.strip().upper())
async def get_technical_signals_tool(ticker: str) -> dict:
    """Return technical indicators for ``ticker`` via ``get_technical_signals``.

    The ticker is stripped of surrounding whitespace and upper-cased, matching
    the normalisation ``execute_trade`` applies.
    """
    return await get_technical_signals(ticker=ticker.strip().upper())
async def get_sentiment_score_tool(ticker: str) -> dict:
    """Return the news sentiment for ``ticker`` via ``get_sentiment_score``.

    The ticker is stripped of surrounding whitespace and upper-cased, matching
    the normalisation ``execute_trade`` applies.
    """
    return await get_sentiment_score(ticker=ticker.strip().upper())
async def get_price_history_tool(ticker: str, days: int = 30) -> list[dict]:
    """Return ``days`` of price history for ``ticker`` via ``get_history``.

    The ticker is stripped of surrounding whitespace and upper-cased, matching
    the normalisation ``execute_trade`` applies. ``days`` is floored at 1.
    """
    return await get_history(ticker=ticker.strip().upper(), days=max(1, days))
async def get_portfolio_status_tool(db: AsyncSession, portfolio_id: str) -> dict:
    """Summarise one portfolio's cash, holdings and total value.

    Args:
        db: async session used to load the portfolio and its holdings.
        portfolio_id: primary key of the portfolio to report on.

    Returns:
        A dict with ``portfolio_id``, ``current_cash``, ``holdings_value`` and
        ``total_value`` (each money figure rounded to 2 decimals), plus
        ``holdings`` as a list of dumped holding models from ``build_holdings_view``.

    Raises:
        ValueError: if no portfolio exists with ``portfolio_id``.
    """
    portfolio = await db.get(Portfolio, portfolio_id)
    if not portfolio:
        raise ValueError("Portfolio not found")

    holding_rows, invested_value = await build_holdings_view(db, portfolio_id)
    cash = as_float(portfolio.current_cash)

    summary = {
        "portfolio_id": str(portfolio.id),
        "current_cash": round(cash, 2),
        "holdings_value": round(invested_value, 2),
        "total_value": round(cash + invested_value, 2),
    }
    summary["holdings"] = [row.model_dump() for row in holding_rows]
    return summary
def _resolve_buy_amount(amount_usd: float | str | None, cash_available: float) -> float:
    """Return the dollar amount a BUY should spend, capped at ``cash_available``.

    Without an explicit ``amount_usd`` the order defaults to 5% of available cash.

    Raises:
        ValueError: if the resulting amount is not a positive number (no cash
            left, a zero/negative request, or NaN). ``float()`` may also raise
            ValueError for a non-numeric string.
    """
    requested = float(amount_usd) if amount_usd is not None else cash_available * 0.05
    amount = min(requested, cash_available)
    # `not amount > 0` also rejects NaN, which `amount <= 0` would let through.
    if not amount > 0:
        raise ValueError("Insufficient cash for BUY")
    return amount


def _resolve_sell_shares(shares: float | str | None, existing_shares: float) -> float:
    """Return how many shares a SELL should execute, capped at ``existing_shares``.

    ``None`` or the string ``"all"`` (case-insensitive, surrounding whitespace
    ignored) sells the whole position.

    Raises:
        ValueError: if the requested count is not positive (or is NaN), or
            ``float()`` cannot parse it.
    """
    if shares is None or (isinstance(shares, str) and shares.strip().lower() == "all"):
        return existing_shares
    executed = max(0.0, min(float(shares), existing_shares))
    if not executed > 0:
        raise ValueError("No valid shares specified for SELL")
    return executed


async def _trade_response(db: AsyncSession, portfolio: Portfolio, tx: Transaction) -> dict:
    """Serialise ``tx`` together with a freshly built portfolio snapshot."""
    portfolio_out = await build_portfolio_out(db, portfolio)
    return {
        "transaction": transaction_to_out(tx).model_dump(),
        "portfolio": portfolio_out.model_dump(),
    }


async def execute_trade(
    db: AsyncSession,
    portfolio_id: str,
    ticker: str,
    action: str,
    run_id: str | None = None,
    amount_usd: float | None = None,
    shares: float | str | None = None,
    llm_reasoning: str | None = None,
    tools_called: dict | None = None,
) -> dict:
    """Execute a BUY, SELL or HOLD for one ticker and record it as a Transaction.

    BUY spends ``amount_usd`` (default: 5% of available cash), capped at the
    available cash, and updates the holding's weighted average buy price.
    SELL disposes of ``shares`` (``None`` or ``"all"`` sells the whole
    position), capped at the shares held. HOLD only records a zero-share
    transaction at the current price and does not touch holdings.

    When ``run_id`` is given, ``"{run_id}:{symbol}:{action}"`` is used as an
    idempotency key. If a transaction with that key already exists, it is
    returned as-is and no new trade is placed.

    All validation happens before the session is mutated, so a ValueError
    leaves no pending Holding or Transaction behind.

    Args:
        db: async session; committed once the trade is recorded.
        portfolio_id: primary key of the portfolio to trade in.
        ticker: symbol to trade; whitespace-stripped and upper-cased.
        action: ``"BUY"``, ``"SELL"`` or ``"HOLD"`` (case-insensitive).
        run_id: optional agent-run id used to build the idempotency key.
        amount_usd: dollars to spend on a BUY.
        shares: shares to sell, or ``"all"``.
        llm_reasoning: stored on the transaction (a placeholder is used if empty).
        tools_called: stored on the transaction (``{}`` if empty).

    Returns:
        ``{"transaction": ..., "portfolio": ...}`` built from
        ``transaction_to_out(...).model_dump()`` and
        ``build_portfolio_out(...).model_dump()``.

    Raises:
        ValueError: unknown action, empty ticker, missing portfolio, unusable
            price, insufficient cash, nothing to sell, or an invalid amount/share count.
    """
    symbol = ticker.upper().strip()
    action = action.upper().strip()
    if action not in {"BUY", "SELL", "HOLD"}:
        raise ValueError("Invalid action")
    if not symbol:
        raise ValueError("Ticker is required")

    portfolio = await db.get(Portfolio, portfolio_id)
    if not portfolio:
        raise ValueError("Portfolio not found")

    # Replays are detected before any price fetch or session mutation, so a
    # duplicate call has no side effects (no phantom Holding gets flushed).
    idempotency_key = f"{run_id}:{symbol}:{action}" if run_id else None
    if idempotency_key:
        existing_tx = await db.scalar(
            select(Transaction).where(Transaction.idempotency_key == idempotency_key)
        )
        if existing_tx:
            return await _trade_response(db, portfolio, existing_tx)

    price_payload = await get_price_snapshot(symbol)
    price = float(price_payload["price"])
    # A zero price would divide by zero in BUY; the chained comparison also rejects NaN/inf.
    if not 0 < price < float("inf"):
        raise ValueError(f"Invalid price for {symbol}: {price}")

    # HOLD never needs the position, so only BUY/SELL look it up.
    holding = None
    if action != "HOLD":
        holding = await db.scalar(
            select(Holding).where(
                Holding.portfolio_id == portfolio.id,
                Holding.ticker == symbol,
            )
        )

    executed_shares = 0.0
    trade_value = 0.0
    if action == "BUY":
        cash_available = as_float(portfolio.current_cash)
        amount = _resolve_buy_amount(amount_usd, cash_available)
        if holding is None:
            holding = Holding(
                portfolio_id=portfolio.id,
                ticker=symbol,
                shares=0,
                avg_buy_price=price,
            )
            db.add(holding)
        executed_shares = amount / price
        existing_shares = as_float(holding.shares)
        existing_avg = as_float(holding.avg_buy_price)
        new_total_shares = existing_shares + executed_shares
        if new_total_shares > 0:
            # Weighted average of the old position and the newly bought shares.
            holding.avg_buy_price = (
                (existing_shares * existing_avg) + (executed_shares * price)
            ) / new_total_shares
        holding.shares = new_total_shares
        portfolio.current_cash = cash_available - amount
        trade_value = amount
    elif action == "SELL":
        existing_shares = as_float(holding.shares) if holding is not None else 0.0
        if existing_shares <= 0:
            raise ValueError(f"No shares available to sell for {symbol}")
        executed_shares = _resolve_sell_shares(shares, existing_shares)
        trade_value = executed_shares * price
        remaining = max(0.0, existing_shares - executed_shares)
        holding.shares = remaining
        if remaining == 0:
            # A fully closed position has no meaningful cost basis.
            holding.avg_buy_price = None
        portfolio.current_cash = as_float(portfolio.current_cash) + trade_value

    now = datetime.now(timezone.utc)
    tx = Transaction(
        portfolio_id=portfolio.id,
        ticker=symbol,
        action=action,
        shares=round(executed_shares, 6),
        price_at_trade=round(price, 4),
        total_value=round(trade_value, 2),
        idempotency_key=idempotency_key,
        llm_reasoning=llm_reasoning or "No detailed reasoning provided.",
        tools_called=tools_called or {},
        executed_at=now,
    )
    db.add(tx)
    portfolio.updated_at = now
    await db.commit()
    await db.refresh(tx)
    return await _trade_response(db, portfolio, tx)
def _tool_arg(arguments: dict, key: str, default):
    """Return ``arguments[key]``, falling back to ``default`` when missing or ``None``.

    Tool-call JSON may carry an explicit ``null`` for an optional field.
    Treating that like an omitted key avoids ``int(None)`` TypeErrors and a
    literal ``"None"`` ticker.
    """
    value = arguments.get(key)
    return default if value is None else value


async def execute_tool_call(
    db: AsyncSession,
    portfolio_id: str,
    name: str,
    arguments: dict | None,
):
    """Dispatch an LLM tool call named ``name`` to its implementation.

    Args:
        db: session handed to the tools that read or write the database.
        portfolio_id: portfolio used by ``get_portfolio_status`` and ``execute_trade``.
        name: tool name; one of the ``name`` entries advertised in ``TOOLS``.
        arguments: tool-call arguments. ``None`` is treated as ``{}``, and keys
            whose value is ``None`` fall back to the tool's default.

    Returns:
        Whatever the selected tool returns.

    Raises:
        ValueError: if ``name`` is not a known tool (tools may raise their own errors).
    """
    args = arguments or {}
    if name == "predict_price":
        return await predict_price_tool(
            ticker=str(_tool_arg(args, "ticker", "")),
            horizon_days=int(_tool_arg(args, "horizon_days", 3)),
        )
    if name == "classify_direction":
        return await classify_direction_tool(ticker=str(_tool_arg(args, "ticker", "")))
    if name == "get_technical_signals":
        return await get_technical_signals_tool(ticker=str(_tool_arg(args, "ticker", "")))
    if name == "get_sentiment_score":
        return await get_sentiment_score_tool(ticker=str(_tool_arg(args, "ticker", "")))
    if name == "get_portfolio_status":
        return await get_portfolio_status_tool(db, portfolio_id)
    if name == "execute_trade":
        return await execute_trade(
            db=db,
            portfolio_id=portfolio_id,
            ticker=str(_tool_arg(args, "ticker", "")),
            action=str(_tool_arg(args, "action", "HOLD")),
            run_id=args.get("run_id"),
            amount_usd=args.get("amount_usd"),
            shares=args.get("shares"),
        )
    if name == "get_price_history":
        return await get_price_history_tool(
            ticker=str(_tool_arg(args, "ticker", "")),
            days=int(_tool_arg(args, "days", 30)),
        )
    raise ValueError(f"Unknown tool call: {name}")