# alpha-engine/agent/tools.py
# Dharambir Agrawal
# HF Space server-only
# fd48bc8
from __future__ import annotations
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from api.utils import as_float, build_holdings_view, build_portfolio_out, transaction_to_out
from core.models import Holding, Portfolio, Transaction
from data.market_data import get_history, get_price_snapshot
from data.news_fetcher import get_sentiment_score
from ml.features import get_technical_signals
from ml.predictor import classify_direction, predict_price
# Catalogue of the tools exposed to the agent. Every entry carries the tool
# name, a free-text description, and a mapping of parameter name to a short
# type/usage hint (an empty mapping means the tool takes no parameters).
TOOLS = [
    # --- model-backed forecasts ---
    {
        "name": "predict_price",
        "description": (
            "Uses trained LSTM model to predict price N days ahead. "
            "Returns predicted price and confidence score."
        ),
        "parameters": {
            "ticker": "string - stock ticker symbol",
            "horizon_days": "integer - forecast horizon (1-10)",
        },
    },
    {
        "name": "classify_direction",
        "description": (
            "Uses XGBoost to classify if price will go UP or DOWN in next 3 days. "
            "Returns direction and probability."
        ),
        "parameters": {
            "ticker": "string",
        },
    },
    # --- market signals ---
    {
        "name": "get_technical_signals",
        "description": "Returns RSI, MACD crossover signal, Bollinger Band position for a ticker.",
        "parameters": {
            "ticker": "string",
        },
    },
    {
        "name": "get_sentiment_score",
        "description": (
            "Fetches latest news for ticker, returns sentiment score "
            "from -1.0 to 1.0 and recent headlines."
        ),
        "parameters": {
            "ticker": "string",
        },
    },
    # --- portfolio inspection and trading ---
    {
        "name": "get_portfolio_status",
        "description": "Returns current cash, holdings, total value, and unrealized P&L for the portfolio.",
        "parameters": {},
    },
    {
        "name": "execute_trade",
        "description": (
            "Execute a BUY or SELL order. For BUY: specify dollar amount. "
            "For SELL: specify shares or 'all'."
        ),
        "parameters": {
            "ticker": "string",
            "action": "string - 'BUY' or 'SELL'",
            "amount_usd": "number - dollar amount for BUY (optional)",
            "shares": "number or 'all' - for SELL (optional)",
        },
    },
    # --- raw price data ---
    {
        "name": "get_price_history",
        "description": "Returns OHLCV price history for a ticker.",
        "parameters": {
            "ticker": "string",
            "days": "integer - lookback days (default 30)",
        },
    },
]
async def predict_price_tool(ticker: str, horizon_days: int = 3) -> dict:
    """Forward a price-prediction request to ``predict_price``.

    The ticker is upper-cased and the horizon is clamped into the 1..10 range
    before the call; the callee's result is returned unchanged.
    """
    symbol = ticker.upper()
    clamped_horizon = max(1, min(10, horizon_days))
    return await predict_price(ticker=symbol, horizon_days=clamped_horizon)
async def classify_direction_tool(ticker: str) -> dict:
    """Forward to ``classify_direction`` with the ticker upper-cased."""
    symbol = ticker.upper()
    outcome = await classify_direction(ticker=symbol)
    return outcome
async def get_technical_signals_tool(ticker: str) -> dict:
    """Forward to ``get_technical_signals`` with the ticker upper-cased."""
    symbol = ticker.upper()
    signals = await get_technical_signals(ticker=symbol)
    return signals
async def get_sentiment_score_tool(ticker: str) -> dict:
    """Forward to ``get_sentiment_score`` with the ticker upper-cased."""
    symbol = ticker.upper()
    sentiment = await get_sentiment_score(ticker=symbol)
    return sentiment
async def get_price_history_tool(ticker: str, days: int = 30) -> list[dict]:
    """Forward to ``get_history`` for an upper-cased ticker.

    The lookback is raised to at least one day before the call.
    """
    symbol = ticker.upper()
    lookback_days = max(1, days)
    return await get_history(ticker=symbol, days=lookback_days)
async def get_portfolio_status_tool(db: AsyncSession, portfolio_id: str) -> dict:
    """Summarise a portfolio's cash, holdings value, and holdings list.

    Args:
        db: Session used to load the portfolio and build the holdings view.
        portfolio_id: Primary key of the portfolio to report on.

    Returns:
        A dict with ``portfolio_id``, ``current_cash``, ``holdings_value``,
        ``total_value`` (cash plus holdings, each rounded to 2 decimals) and
        ``holdings`` (each holdings-view item dumped via ``model_dump()``).

    Raises:
        ValueError: If no portfolio exists for ``portfolio_id``.
    """
    record = await db.get(Portfolio, portfolio_id)
    if not record:
        raise ValueError("Portfolio not found")

    positions, positions_value = await build_holdings_view(db, portfolio_id)
    cash = as_float(record.current_cash)

    status = {"portfolio_id": str(record.id)}
    status["current_cash"] = round(cash, 2)
    status["holdings_value"] = round(positions_value, 2)
    status["total_value"] = round(cash + positions_value, 2)
    status["holdings"] = [position.model_dump() for position in positions]
    return status
async def execute_trade(
    db: AsyncSession,
    portfolio_id: str,
    ticker: str,
    action: str,
    run_id: str | None = None,
    amount_usd: float | None = None,
    shares: float | str | None = None,
    llm_reasoning: str | None = None,
    tools_called: dict | None = None,
) -> dict:
    """Apply a BUY, SELL, or HOLD decision to a portfolio and record it.

    All validation happens before the session is mutated, so a rejected
    request leaves no pending rows behind. A ``Holding`` row is only created
    when a BUY actually needs one.

    Args:
        db: Session used for all reads and writes; committed on success.
        portfolio_id: Primary key of the portfolio to trade in.
        ticker: Symbol to trade; upper-cased and stripped before use.
        action: ``"BUY"``, ``"SELL"`` or ``"HOLD"`` (case-insensitive).
        run_id: When given, forms the idempotency key
            ``"{run_id}:{symbol}:{action}"``. If a transaction with that key
            already exists it is returned and nothing else is done.
        amount_usd: Dollar amount for BUY. Defaults to 5% of available cash
            and is always capped at available cash.
        shares: Share count for SELL, or ``"all"`` / ``None`` to sell the
            whole position. A numeric request is capped at the shares held.
        llm_reasoning: Free-text reasoning stored on the transaction.
        tools_called: Tool-call record stored on the transaction.

    Returns:
        ``{"transaction": ..., "portfolio": ...}`` holding the
        ``model_dump()`` of the transaction and portfolio output models.

    Raises:
        ValueError: On an unknown action, blank ticker, missing portfolio,
            unusable market price for a BUY/SELL, no cash to buy with, or no
            (valid) shares to sell.
    """
    symbol = ticker.upper().strip()
    action = action.upper().strip()
    if action not in {"BUY", "SELL", "HOLD"}:
        raise ValueError("Invalid action")
    if not symbol:
        raise ValueError("Ticker is required")

    portfolio = await db.get(Portfolio, portfolio_id)
    if not portfolio:
        raise ValueError("Portfolio not found")

    # Replay check comes first: a repeated run must not fetch prices or
    # touch the session before returning the already-recorded transaction.
    idempotency_key = f"{run_id}:{symbol}:{action}" if run_id else None
    if idempotency_key:
        existing_tx = await db.scalar(
            select(Transaction).where(Transaction.idempotency_key == idempotency_key)
        )
        if existing_tx:
            portfolio_out = await build_portfolio_out(db, portfolio)
            return {
                "transaction": transaction_to_out(existing_tx).model_dump(),
                "portfolio": portfolio_out.model_dump(),
            }

    price_payload = await get_price_snapshot(symbol)
    price = float(price_payload["price"])
    # A trade needs a finite, positive price; the chained comparison is also
    # False for NaN. HOLD only records the price, so it is left unchecked.
    if action != "HOLD" and not 0 < price < float("inf"):
        raise ValueError(f"Invalid market price for {symbol}")

    holding = await db.scalar(
        select(Holding).where(
            Holding.portfolio_id == portfolio.id,
            Holding.ticker == symbol,
        )
    )

    executed_shares = 0.0
    trade_value = 0.0

    if action == "BUY":
        cash_available = as_float(portfolio.current_cash)
        requested_amount = (
            float(amount_usd) if amount_usd is not None else cash_available * 0.05
        )
        amount = min(requested_amount, cash_available)
        # `not amount > 0` (rather than `amount <= 0`) also rejects NaN.
        if not amount > 0:
            raise ValueError("Insufficient cash for BUY")

        executed_shares = amount / price
        if holding is None:
            # First purchase of this symbol: open the position now that the
            # order is known to be valid.
            holding = Holding(
                portfolio_id=portfolio.id,
                ticker=symbol,
                shares=0,
                avg_buy_price=price,
            )
            db.add(holding)

        existing_shares = as_float(holding.shares)
        existing_avg = as_float(holding.avg_buy_price)
        new_total_shares = existing_shares + executed_shares
        if new_total_shares > 0:
            # Share-weighted average of the old cost basis and this fill.
            holding.avg_buy_price = (
                (existing_shares * existing_avg) + (executed_shares * price)
            ) / new_total_shares
        holding.shares = new_total_shares
        portfolio.current_cash = cash_available - amount
        trade_value = amount
    elif action == "SELL":
        existing_shares = as_float(holding.shares) if holding is not None else 0.0
        if existing_shares <= 0:
            raise ValueError(f"No shares available to sell for {symbol}")

        sell_everything = shares is None or (
            isinstance(shares, str) and shares.strip().lower() == "all"
        )
        if sell_everything:
            executed_shares = existing_shares
        else:
            try:
                requested_shares = float(shares)
            except (TypeError, ValueError) as exc:
                raise ValueError("No valid shares specified for SELL") from exc
            executed_shares = max(0.0, min(requested_shares, existing_shares))
        if not executed_shares > 0:
            raise ValueError("No valid shares specified for SELL")

        trade_value = executed_shares * price
        remaining_shares = max(0.0, existing_shares - executed_shares)
        holding.shares = remaining_shares
        if remaining_shares == 0:
            # Position fully closed: clear the cost basis.
            holding.avg_buy_price = None
        portfolio.current_cash = as_float(portfolio.current_cash) + trade_value
    # HOLD: nothing moves; a zero-share, zero-value transaction is recorded.

    now = datetime.now(timezone.utc)
    tx = Transaction(
        portfolio_id=portfolio.id,
        ticker=symbol,
        action=action,
        shares=round(executed_shares, 6),
        price_at_trade=round(price, 4),
        total_value=round(trade_value, 2),
        idempotency_key=idempotency_key,
        llm_reasoning=llm_reasoning or "No detailed reasoning provided.",
        tools_called=tools_called or {},
        executed_at=now,
    )
    db.add(tx)
    portfolio.updated_at = now

    try:
        await db.commit()
    except Exception:
        # Leave the session usable for the caller, then surface the error.
        await db.rollback()
        raise
    await db.refresh(tx)

    portfolio_out = await build_portfolio_out(db, portfolio)
    return {
        "transaction": transaction_to_out(tx).model_dump(),
        "portfolio": portfolio_out.model_dump(),
    }
async def execute_tool_call(
    db: AsyncSession,
    portfolio_id: str,
    name: str,
    arguments: dict,
):
    """Dispatch a named tool call to its implementation.

    Arguments whose value is missing *or* explicitly ``None`` fall back to
    the tool's default, so a null ``ticker`` is not turned into the literal
    string ``"None"`` and a null ``days``/``horizon_days`` does not raise
    ``TypeError`` from ``int(None)``. A ``None`` ``arguments`` mapping is
    treated as empty.

    Args:
        db: Session passed to the tools that read or write the portfolio.
        portfolio_id: Portfolio the portfolio/trade tools operate on.
        name: Tool name, matching a ``"name"`` entry in ``TOOLS``.
        arguments: Raw argument mapping supplied for the call.

    Returns:
        Whatever the selected tool returns.

    Raises:
        ValueError: If ``name`` is not a known tool.
    """
    args = arguments or {}

    def arg(key, default):
        # dict.get only applies its default when the key is absent; treat an
        # explicit None the same way.
        value = args.get(key)
        return default if value is None else value

    if name == "predict_price":
        return await predict_price_tool(
            ticker=str(arg("ticker", "")),
            horizon_days=int(arg("horizon_days", 3)),
        )
    if name == "classify_direction":
        return await classify_direction_tool(ticker=str(arg("ticker", "")))
    if name == "get_technical_signals":
        return await get_technical_signals_tool(ticker=str(arg("ticker", "")))
    if name == "get_sentiment_score":
        return await get_sentiment_score_tool(ticker=str(arg("ticker", "")))
    if name == "get_portfolio_status":
        return await get_portfolio_status_tool(db, portfolio_id)
    if name == "execute_trade":
        return await execute_trade(
            db=db,
            portfolio_id=portfolio_id,
            ticker=str(arg("ticker", "")),
            action=str(arg("action", "HOLD")),
            # These three are optional downstream, so None passes through.
            run_id=args.get("run_id"),
            amount_usd=args.get("amount_usd"),
            shares=args.get("shares"),
        )
    if name == "get_price_history":
        return await get_price_history_tool(
            ticker=str(arg("ticker", "")),
            days=int(arg("days", 30)),
        )
    raise ValueError(f"Unknown tool call: {name}")