alpha-engine / agent /runner.py
Dharambir Agrawal
HF Space server-only
fd48bc8
from __future__ import annotations

import asyncio
import json
import math
import re
import zlib
from datetime import date, datetime, timedelta, timezone
from statistics import median
from zoneinfo import ZoneInfo

from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError

from agent.tools import (
    classify_direction_tool,
    execute_trade,
    get_portfolio_status_tool,
    get_sentiment_score_tool,
    get_technical_signals_tool,
    predict_price_tool,
)
from api.utils import build_holdings_view, build_portfolio_out, get_portfolio_tickers, snapshot_portfolio
from core.config import settings
from core.database import SessionLocal
from core.models import AgentRun, Portfolio
from ml.evaluator import record_prediction
def _build_reasoning(
ticker: str,
prediction: dict,
direction: dict,
technical: dict,
sentiment: dict | float,
action: str,
) -> str:
sentiment_score = sentiment.get("score", 0.0) if isinstance(sentiment, dict) else sentiment
return (
f"{ticker}: Predicted {prediction.get('predicted_price')} in {prediction.get('horizon_days')}d "
f"(conf {prediction.get('confidence')}). Direction={direction.get('direction')} "
f"(p={direction.get('probability')}). RSI={technical.get('rsi')}, "
f"MACD={technical.get('macd')}, sentiment={sentiment_score}. Action={action}."
)
def _compact_tool_error(exc: Exception) -> str:
if isinstance(exc, asyncio.TimeoutError):
return "tool timed out"
message = str(exc).replace("**", "")
if "No real OHLCV returned for" in message:
return "market data unavailable from providers"
return message
def _market_today() -> date:
    """Return the current calendar date in the configured market timezone."""
    market_zone = ZoneInfo(settings.MARKET_TIMEZONE)
    now_in_market = datetime.now(market_zone)
    return now_in_market.date()
def _confidence_based_amount(total_value: float, confidence: float) -> float:
if confidence >= 0.7:
pct = 0.045
elif confidence >= 0.55:
pct = 0.025
else:
pct = 0.015
return max(0.0, total_value * pct)
def _rule_decision(
*,
cash: float,
total_value: float,
confidence: float,
owned_shares: float,
direction: dict,
technical: dict,
sentiment: float,
) -> dict:
action = "HOLD"
amount_usd: float | None = None
sell_shares: float | None = None
if (
direction.get("direction") == "UP"
and float(direction.get("probability", 0)) >= 0.55
and float(technical.get("rsi", 50)) < 72
and sentiment >= -0.1
and cash >= 50
):
action = "BUY"
confidence_amount = _confidence_based_amount(total_value, confidence)
amount_usd = max(50.0, min(confidence_amount, cash * 0.10))
elif (
owned_shares > 0
and (
direction.get("direction") == "DOWN"
or sentiment < -0.25
or float(technical.get("rsi", 50)) > 78
)
):
action = "SELL"
sell_shares = round(max(owned_shares * 0.25, 0.000001), 6)
return {
"action": action,
"amount_usd": amount_usd,
"sell_shares": sell_shares,
"preferred_minutes": None,
"source": "rules",
}
def _extract_json_object(text: str) -> dict | None:
match = re.search(r"\{[\s\S]*\}", text)
if not match:
return None
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
return None
async def _llm_hybrid_decision(
    *,
    ticker: str,
    cash: float,
    total_value: float,
    owned_shares: float,
    prediction: dict,
    direction: dict,
    technical: dict,
    sentiment: dict | float,
    baseline: dict,
) -> dict | None:
    """Ask Gemini for a BUY/SELL/HOLD suggestion for one ticker (best effort).

    The model is shown the quant signals, the rule-based ``baseline`` and the
    position-sizing guidance, and must reply with a JSON object.

    Returns:
        ``None`` when the hybrid path is disabled (``AGENT_DECISION_MODE`` is
        "rules" or ``GEMINI_API_KEY`` is empty), the SDK cannot be imported,
        building the prompt or calling the model raises, the call exceeds
        30 s, or the reply has no JSON object with a valid action. Otherwise
        a dict with ``action``, ``amount_usd``, ``sell_fraction``,
        ``preferred_minutes`` (int), ``source="llm"`` and ``rationale``.
        Numeric fields that are absent, non-numeric, boolean or non-finite
        are ``None``, so ``_merge_hybrid_decision`` uses its defaults.
    """
    mode = settings.AGENT_DECISION_MODE.lower().strip()
    if mode == "rules" or not settings.GEMINI_API_KEY:
        return None
    try:
        import google.generativeai as genai  # type: ignore
    except Exception:
        return None

    def _to_json(value: object) -> str:
        # Tool payloads may contain values the json module cannot encode
        # natively; for prompt text a str() rendering is good enough.
        return json.dumps(value, default=str)

    def _finite_number(value: object) -> float | None:
        # bool is an int subclass, and json.loads accepts NaN/Infinity;
        # neither is a usable amount, fraction or delay.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
        number = float(value)
        return number if math.isfinite(number) else None

    try:
        # Prompt construction lives inside the try so an unexpected payload
        # degrades to "no LLM opinion" instead of failing the whole run.
        now_time = datetime.now(ZoneInfo(settings.MARKET_TIMEZONE)).strftime("%Y-%m-%d %I:%M %p %Z")
        sentiment_text = _to_json(sentiment) if isinstance(sentiment, dict) else sentiment
        prompt = f"""
You are a risk-aware paper-trading assistant.
Decide one action for a single ticker using both quant signals and baseline rule suggestion.
Return STRICT JSON only with keys:
action (BUY|SELL|HOLD), amount_usd (number|null), sell_fraction (number|null), preferred_minutes (int|null), rationale (string).
Current Market Time: {now_time}
Ticker: {ticker}
Cash available: {cash}
Current portfolio total value: {total_value}
Owned shares: {owned_shares}
Prediction: {_to_json(prediction)}
Direction: {_to_json(direction)}
Technical: {_to_json(technical)}
Sentiment: {sentiment_text}
Baseline rule decision: {_to_json(baseline)}
Position sizing guidance:
- High confidence (>70%) = allocate 4-5% of total portfolio value.
- Medium confidence (55-70%) = 2-3%.
- Low confidence (<55%) = 1-2%.
- Never put more than 10% in one position.
Hard limits:
- BUY only if cash >= 50.
- SELL only if owned_shares > 0.
- amount_usd should follow the confidence-based sizing guidance above and stay within available cash.
- sell_fraction between 0.1 and 1.0 when SELL.
- preferred_minutes between 20 and 240.
""".strip()

        genai.configure(api_key=settings.GEMINI_API_KEY)
        model = genai.GenerativeModel("gemini-1.5-flash")
        # 30-second hard timeout — prevents stalling the whole agent loop on a
        # slow/unavailable LLM. The worker thread started by to_thread cannot
        # be cancelled and may keep running in the background after a timeout.
        response = await asyncio.wait_for(
            asyncio.to_thread(model.generate_content, prompt),
            timeout=30.0,
        )
        text = getattr(response, "text", "") or ""
        payload = _extract_json_object(text)
        if not isinstance(payload, dict):
            return None

        action = str(payload.get("action", "")).upper().strip()
        if action not in {"BUY", "SELL", "HOLD"}:
            return None

        minutes = _finite_number(payload.get("preferred_minutes"))
        return {
            "action": action,
            "amount_usd": _finite_number(payload.get("amount_usd")),
            "sell_fraction": _finite_number(payload.get("sell_fraction")),
            "preferred_minutes": int(minutes) if minutes is not None else None,
            "source": "llm",
            "rationale": str(payload.get("rationale", "")).strip(),
        }
    except Exception:
        return None
def _merge_hybrid_decision(
*,
baseline: dict,
llm: dict | None,
cash: float,
total_value: float,
owned_shares: float,
) -> dict:
final = dict(baseline)
if not llm:
return final
action = str(llm.get("action", baseline["action"])).upper().strip()
if action not in {"BUY", "SELL", "HOLD"}:
action = baseline["action"]
amount_usd: float | None = None
sell_shares: float | None = None
if action == "BUY":
if cash < 50:
action = "HOLD"
else:
llm_amount = llm.get("amount_usd")
baseline_amount = baseline.get("amount_usd")
amount_val = (
float(llm_amount)
if isinstance(llm_amount, (int, float))
else float(baseline_amount or 50.0)
)
amount_usd = max(50.0, min(amount_val, cash * 0.10, total_value * 0.10))
elif action == "SELL":
if owned_shares <= 0:
action = "HOLD"
else:
sell_fraction = llm.get("sell_fraction")
frac = (
float(sell_fraction)
if isinstance(sell_fraction, (int, float))
else 0.25
)
frac = max(0.1, min(frac, 1.0))
sell_shares = round(max(owned_shares * frac, 0.000001), 6)
preferred_minutes = llm.get("preferred_minutes")
final.update(
{
"action": action,
"amount_usd": amount_usd,
"sell_shares": sell_shares,
"preferred_minutes": (
max(20, min(int(preferred_minutes), 240))
if isinstance(preferred_minutes, (int, float))
else baseline.get("preferred_minutes")
),
"source": llm.get("source", "rules"),
"llm_rationale": llm.get("rationale"),
}
)
return final
async def _with_retries(
coro,
*args,
retries: int = 2,
base_delay_seconds: float = 0.6,
timeout_seconds: float | None = None,
**kwargs,
):
last_error: Exception | None = None
for attempt in range(retries + 1):
try:
if timeout_seconds:
return await asyncio.wait_for(
coro(*args, **kwargs),
timeout=timeout_seconds,
)
return await coro(*args, **kwargs)
except Exception as exc: # noqa: PERF203
last_error = exc
if attempt >= retries:
break
await asyncio.sleep(base_delay_seconds * (attempt + 1))
if last_error:
raise last_error
raise RuntimeError("Retry helper reached an unexpected state")
def _portfolio_lock_key(portfolio_id: str) -> int:
    """Return a stable, process-independent advisory-lock key for a portfolio.

    ``hash(str)`` is randomized per interpreter process (PYTHONHASHSEED), so
    two workers would derive different keys for the same portfolio and the
    lock would not exclude them. CRC32 is deterministic everywhere; the
    modulo keeps keys in the same 0..2**31-1 range as before.
    """
    return zlib.crc32(str(portfolio_id).encode("utf-8")) % (2**31)


def _adaptive_next_run_minutes(
    *,
    trades_made: int,
    confidence_values: list[float],
    bullish_signals: int,
    bearish_signals: int,
    tool_errors: int,
    ticker_count: int,
    preferred_minutes_votes: list[int],
) -> int:
    """Choose the delay (clamped to 20-240 minutes) before the next adaptive run."""
    avg_conf = sum(confidence_values) / len(confidence_values) if confidence_values else 0.5
    majority = max(2, ticker_count // 2)
    # Trading activity shortens the default 180-minute cadence to 60.
    minutes = 60 if trades_made > 0 else 180
    if avg_conf >= 0.72 or bullish_signals >= majority:
        minutes = min(minutes, 45)
    if bearish_signals >= majority:
        minutes = min(minutes, 60)
    # Back off while tools are failing.
    if tool_errors > 0:
        minutes = max(minutes, 120)
    # Blend in the cadence preferred by the decisions, when any were given.
    if preferred_minutes_votes:
        minutes = int((minutes + median(preferred_minutes_votes)) / 2)
    return max(20, min(minutes, 240))


async def _record_run_failure(db, run: AgentRun, message: str, decisions: list[dict]) -> None:
    """Persist a terminal "failed" state on ``run``, recovering a broken session if needed."""

    def _apply() -> None:
        run.status = "failed"
        run.summary = message
        run.per_ticker_decisions = decisions or run.per_ticker_decisions or []
        run.completed_at = datetime.now(timezone.utc)

    try:
        _apply()
        await db.commit()
        return
    except Exception:
        # The error that aborted the run may have left the transaction
        # unusable (e.g. a failed flush), so this commit fails as well. Roll
        # back, reload the row explicitly (AsyncSession does not support
        # implicit lazy loads) and try once more.
        await db.rollback()
    await db.refresh(run)
    _apply()
    await db.commit()


async def run_agent(
    portfolio_id: str,
    session: str = "manual",
    run_type: str = "manual",
    run_id: str | None = None,
) -> dict:
    """Run one analysis/trading pass over every ticker of a portfolio.

    For each ticker the quant tools are called (with retries and timeouts).
    A rule-based decision is computed, optionally refined by the LLM, and the
    price prediction is recorded. BUY/SELL decisions are then executed.

    The ``AgentRun`` row (loaded when ``run_id`` is given, otherwise created
    here) is finalized as "done" or "failed". Active portfolios get their
    next adaptive run scheduled.

    Only one run per portfolio proceeds at a time. A concurrent call returns
    a "skipped" result and moves a pre-created "running" row to a terminal
    status.

    Returns:
        ``{"status": "done", "run_id", "trades_made"}`` on success
        (``trades_made`` omitted when the portfolio has no tickers);
        ``{"status": "skipped", "reason"}`` when another run holds the lock;
        ``{"status": "failed", "error"}`` plus ``run_id`` when a run row exists.
    """
    lock_key = _portfolio_lock_key(portfolio_id)
    # The lock lives on a dedicated session whose transaction stays open for
    # the whole run. ``db`` below commits several times, and SQLAlchemy hands
    # its pooled connection back on every commit. A session-level lock taken
    # through ``db`` could therefore stay held by a connection this run no
    # longer owns, and the final unlock might run on another connection.
    # A transaction-scoped lock is released automatically when ``lock_db``
    # closes, even if the run dies half-way. Cost: one extra pooled
    # connection per in-flight run.
    async with SessionLocal() as lock_db, SessionLocal() as db:
        got_lock = bool(
            await lock_db.scalar(select(func.pg_try_advisory_xact_lock(lock_key)))
        )
        if not got_lock:
            # If a run record was pre-created (manual trigger path), ensure it
            # doesn't stay "running" forever.
            if run_id:
                pending_run = await db.get(AgentRun, run_id)
                if pending_run and pending_run.status == "running":
                    trades = pending_run.trades_made or 0
                    summary = (pending_run.summary or "").strip() or "Skipped: run already in progress."
                    finished_at = datetime.now(timezone.utc)
                    pending_run.status = "skipped"
                    pending_run.trades_made = trades
                    pending_run.summary = summary
                    pending_run.completed_at = finished_at
                    try:
                        await db.commit()
                    except IntegrityError:
                        # Older databases may still enforce a CHECK constraint
                        # without the 'skipped' status; fall back to a terminal
                        # status that always exists. The rollback discarded the
                        # other fields too, so set them again.
                        await db.rollback()
                        pending_run.status = "done"
                        pending_run.trades_made = trades
                        pending_run.summary = summary
                        pending_run.completed_at = finished_at
                        await db.commit()
            return {"status": "skipped", "reason": "run already in progress"}

        run: AgentRun | None = None
        run_pk: str | None = None
        # Declared before the try so a failure can still record what was decided.
        per_ticker_decisions: list[dict] = []
        try:
            if run_id:
                run = await db.get(AgentRun, run_id)
            if not run:
                run = AgentRun(
                    portfolio_id=portfolio_id,
                    run_type=run_type,
                    session=session,
                    status="running",
                    started_at=datetime.now(timezone.utc),
                )
                db.add(run)
                await db.commit()
                await db.refresh(run)
            run_pk = str(run.id)

            portfolio = await db.get(Portfolio, portfolio_id)
            if not portfolio:
                run.status = "failed"
                run.summary = "Run failed: Portfolio not found"
                run.completed_at = datetime.now(timezone.utc)
                await db.commit()
                return {"status": "failed", "error": "Portfolio not found"}

            tickers = await get_portfolio_tickers(db, portfolio.id)
            # Result unused here; the call is kept in case it has side effects.
            # NOTE(review): drop it if build_holdings_view is a pure read.
            await build_holdings_view(db, portfolio.id)
            portfolio_out = await build_portfolio_out(db, portfolio)
            if not tickers:
                run.summary = "No tickers configured for this portfolio."
                run.status = "done"
                run.trades_made = 0
                run.total_pl = portfolio_out.profit_loss
                run.completed_at = datetime.now(timezone.utc)
                await db.commit()
                return {"status": "done", "run_id": run_pk}

            total_value = float(portfolio_out.total_value)
            trades_made = 0
            summary_lines: list[str] = []
            confidence_values: list[float] = []
            bearish_signals = 0
            bullish_signals = 0
            tool_errors = 0
            preferred_minutes_votes: list[int] = []

            for ticker in tickers:
                try:
                    prediction = await _with_retries(
                        predict_price_tool,
                        ticker,
                        horizon_days=3,
                        timeout_seconds=25.0,
                    )
                    direction = await _with_retries(
                        classify_direction_tool,
                        ticker,
                        timeout_seconds=20.0,
                    )
                    technical = await _with_retries(
                        get_technical_signals_tool,
                        ticker,
                        timeout_seconds=20.0,
                    )
                    sentiment_data = await _with_retries(
                        get_sentiment_score_tool,
                        ticker,
                        timeout_seconds=20.0,
                    )
                    status = await _with_retries(
                        get_portfolio_status_tool,
                        db,
                        portfolio_id,
                        timeout_seconds=10.0,
                    )
                except Exception as exc:  # noqa: PERF203
                    compact = _compact_tool_error(exc)
                    summary_lines.append(f"{ticker}: skipped due to tool error ({compact}).")
                    per_ticker_decisions.append(
                        {
                            "ticker": ticker,
                            "action": "HOLD",
                            "rationale": f"Skipped due to tool error: {compact}",
                            "tools_called": {},
                        }
                    )
                    tool_errors += 1
                    continue

                cash = float(status["current_cash"])
                matching_holding = next(
                    (item for item in status["holdings"] if item["ticker"] == ticker),
                    None,
                )
                owned_shares = float((matching_holding or {}).get("shares", 0.0))
                # Accept either the tool's dict or a bare score, as _build_reasoning does.
                if isinstance(sentiment_data, dict):
                    sentiment_score = float(sentiment_data.get("score", 0.0))
                else:
                    sentiment_score = float(sentiment_data)
                confidence = float(prediction.get("confidence", 0.5))

                baseline = _rule_decision(
                    cash=cash,
                    total_value=total_value,
                    confidence=confidence,
                    owned_shares=owned_shares,
                    direction=direction,
                    technical=technical,
                    sentiment=sentiment_score,
                )
                llm_pick = await _llm_hybrid_decision(
                    ticker=ticker,
                    cash=cash,
                    total_value=total_value,
                    owned_shares=owned_shares,
                    prediction=prediction,
                    direction=direction,
                    technical=technical,
                    sentiment=sentiment_data,
                    baseline=baseline,
                )
                decision = _merge_hybrid_decision(
                    baseline=baseline,
                    llm=llm_pick,
                    cash=cash,
                    total_value=total_value,
                    owned_shares=owned_shares,
                )
                action = decision["action"]
                amount_usd = decision.get("amount_usd")
                sell_shares = decision.get("sell_shares")
                pref = decision.get("preferred_minutes")
                if isinstance(pref, int):
                    preferred_minutes_votes.append(pref)

                confidence_values.append(confidence)
                if direction.get("direction") == "UP":
                    bullish_signals += 1
                elif direction.get("direction") == "DOWN":
                    bearish_signals += 1

                reasoning = _build_reasoning(
                    ticker=ticker,
                    prediction=prediction,
                    direction=direction,
                    technical=technical,
                    sentiment=sentiment_data,
                    action=action,
                )
                if decision.get("source") == "llm":
                    llm_rationale = str(decision.get("llm_rationale") or "").strip()
                    if llm_rationale:
                        reasoning = f"{reasoning} LLM rationale: {llm_rationale}"

                tools_called = {
                    "lstm_prediction": prediction,
                    "direction": direction,
                    "technical_signals": technical,
                    "sentiment_score": sentiment_data,
                    "decision_source": decision.get("source", "rules"),
                    "decision_mode": settings.AGENT_DECISION_MODE,
                }
                per_ticker_decisions.append(
                    {
                        "ticker": ticker,
                        "action": action,
                        "rationale": reasoning,
                        "tools_called": tools_called,
                    }
                )

                horizon_days = max(1, int(prediction.get("horizon_days", 3)))
                prediction_date = _market_today()
                await record_prediction(
                    db=db,
                    ticker=ticker,
                    model_type="lstm",
                    predicted_price=float(prediction.get("predicted_price", 0.0)),
                    actual_price=None,
                    prediction_date=prediction_date,
                    prediction_for_date=prediction_date + timedelta(days=horizon_days),
                )

                if action in {"BUY", "SELL"}:
                    try:
                        # NOTE(review): this retries a write. If a timed-out
                        # attempt had already committed, a retry could place a
                        # duplicate trade — confirm execute_trade is idempotent.
                        await _with_retries(
                            execute_trade,
                            db=db,
                            portfolio_id=portfolio_id,
                            ticker=ticker,
                            action=action,
                            amount_usd=amount_usd,
                            shares=sell_shares,
                            run_id=run_pk,
                            llm_reasoning=reasoning,
                            tools_called=tools_called,
                            timeout_seconds=20.0,
                        )
                    except Exception as exc:  # noqa: PERF203
                        summary_lines.append(
                            f"{ticker}: trade execution failed ({_compact_tool_error(exc)})."
                        )
                        continue
                    trades_made += 1
                summary_lines.append(reasoning)

            await snapshot_portfolio(db, portfolio.id)
            refreshed_portfolio = await build_portfolio_out(db, portfolio)
            run.status = "done"
            run.trades_made = trades_made
            run.total_pl = refreshed_portfolio.profit_loss
            run.summary = "\n".join(summary_lines[-10:])
            run.per_ticker_decisions = per_ticker_decisions
            run.completed_at = datetime.now(timezone.utc)
            await db.commit()

            if portfolio.is_active:
                # Imported lazily, as before (presumably to avoid an import cycle).
                from scheduler.jobs import compute_next_market_run, schedule_portfolio_run

                minutes = _adaptive_next_run_minutes(
                    trades_made=trades_made,
                    confidence_values=confidence_values,
                    bullish_signals=bullish_signals,
                    bearish_signals=bearish_signals,
                    tool_errors=tool_errors,
                    ticker_count=len(tickers),
                    preferred_minutes_votes=preferred_minutes_votes,
                )
                next_run_at = compute_next_market_run(preferred_minutes=minutes)
                schedule_portfolio_run(
                    portfolio_id=str(portfolio.id),
                    run_at_utc=next_run_at,
                    session="adaptive",
                )
            return {
                "status": "done",
                "run_id": run_pk,
                "trades_made": trades_made,
            }
        except Exception as exc:
            if run is None:
                return {"status": "failed", "error": str(exc)}
            await _record_run_failure(db, run, f"Run failed: {exc}", per_ticker_decisions)
            return {
                "status": "failed",
                "run_id": run_pk if run_pk is not None else str(run.id),
                "error": str(exc),
            }