"""Agent run loop for a paper-trading portfolio.

For each ticker in a portfolio the loop gathers prediction, direction,
technical and sentiment signals, derives a rule-based BUY/SELL/HOLD decision,
optionally merges it with an LLM suggestion, executes the trade, and finally
schedules the next adaptive run.
"""

from __future__ import annotations

import asyncio
import json
import re
import zlib
from datetime import date, datetime, timedelta, timezone
from statistics import median
from zoneinfo import ZoneInfo

from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError

from agent.tools import (
    classify_direction_tool,
    execute_trade,
    get_portfolio_status_tool,
    get_sentiment_score_tool,
    get_technical_signals_tool,
    predict_price_tool,
)
from api.utils import (
    build_holdings_view,
    build_portfolio_out,
    get_portfolio_tickers,
    snapshot_portfolio,
)
from core.config import settings
from core.database import SessionLocal
from core.models import AgentRun, Portfolio
from ml.evaluator import record_prediction


def _build_reasoning(
    ticker: str,
    prediction: dict,
    direction: dict,
    technical: dict,
    sentiment: dict | float,
    action: str,
) -> str:
    """Return a one-line, human-readable summary of the signals and the action.

    ``sentiment`` may be the raw tool payload (its ``"score"`` key is used,
    defaulting to 0.0) or an already-extracted float.
    """
    sentiment_score = sentiment.get("score", 0.0) if isinstance(sentiment, dict) else sentiment
    return (
        f"{ticker}: Predicted {prediction.get('predicted_price')} in {prediction.get('horizon_days')}d "
        f"(conf {prediction.get('confidence')}). Direction={direction.get('direction')} "
        f"(p={direction.get('probability')}). RSI={technical.get('rsi')}, "
        f"MACD={technical.get('macd')}, sentiment={sentiment_score}. Action={action}."
    )


def _compact_tool_error(exc: Exception) -> str:
    """Turn a tool exception into a short message suitable for run summaries."""
    if isinstance(exc, asyncio.TimeoutError):
        return "tool timed out"
    # Strip markdown bold markers so the text reads cleanly in summaries.
    message = str(exc).replace("**", "")
    if "No real OHLCV returned for" in message:
        return "market data unavailable from providers"
    return message


def _market_today() -> date:
    """Return today's date in the configured market timezone."""
    return datetime.now(ZoneInfo(settings.MARKET_TIMEZONE)).date()


def _confidence_based_amount(total_value: float, confidence: float) -> float:
    """Return the dollar allocation for a BUY, sized by prediction confidence.

    Allocates 4.5% of ``total_value`` at confidence >= 0.7, 2.5% at >= 0.55
    and 1.5% otherwise; the result is never negative.
    """
    if confidence >= 0.7:
        pct = 0.045
    elif confidence >= 0.55:
        pct = 0.025
    else:
        pct = 0.015
    return max(0.0, total_value * pct)


def _rule_decision(
    *,
    cash: float,
    total_value: float,
    confidence: float,
    owned_shares: float,
    direction: dict,
    technical: dict,
    sentiment: float,
) -> dict:
    """Return the deterministic baseline decision for one ticker.

    BUY when the direction is UP with probability >= 0.55, RSI < 72,
    sentiment >= -0.1 and at least $50 cash is available. Otherwise SELL a
    quarter of the position when shares are owned and any bearish trigger
    fires (direction DOWN, sentiment < -0.25, or RSI > 78). Otherwise HOLD.

    Returns a dict with keys ``action``, ``amount_usd``, ``sell_shares``,
    ``preferred_minutes`` (always ``None`` here) and ``source`` (``"rules"``).
    """
    action = "HOLD"
    amount_usd: float | None = None
    sell_shares: float | None = None

    if (
        direction.get("direction") == "UP"
        and float(direction.get("probability", 0)) >= 0.55
        and float(technical.get("rsi", 50)) < 72
        and sentiment >= -0.1
        and cash >= 50
    ):
        action = "BUY"
        confidence_amount = _confidence_based_amount(total_value, confidence)
        # Cap at 10% of cash, but never go below the $50 minimum ticket.
        amount_usd = max(50.0, min(confidence_amount, cash * 0.10))
    elif (
        owned_shares > 0
        and (
            direction.get("direction") == "DOWN"
            or sentiment < -0.25
            or float(technical.get("rsi", 50)) > 78
        )
    ):
        action = "SELL"
        sell_shares = round(max(owned_shares * 0.25, 0.000001), 6)

    return {
        "action": action,
        "amount_usd": amount_usd,
        "sell_shares": sell_shares,
        "preferred_minutes": None,
        "source": "rules",
    }


def _extract_json_object(text: str) -> dict | None:
    """Return the JSON object that starts at the first ``{`` in ``text``.

    The object is decoded with ``JSONDecoder.raw_decode`` so that anything
    following it (closing code fences, commentary, even stray braces) is
    ignored instead of making the parse fail. Returns ``None`` when there is
    no ``{`` or the text from the first ``{`` is not a valid JSON object.
    Nested objects are deliberately not searched: a malformed reply yields
    ``None`` rather than a partial payload.
    """
    match = re.search(r"\{", text)
    if not match:
        return None
    try:
        payload, _end = json.JSONDecoder().raw_decode(text, match.start())
    except json.JSONDecodeError:
        return None
    return payload if isinstance(payload, dict) else None


async def _llm_hybrid_decision(
    *,
    ticker: str,
    cash: float,
    total_value: float,
    owned_shares: float,
    prediction: dict,
    direction: dict,
    technical: dict,
    sentiment: dict | float,
    baseline: dict,
) -> dict | None:
    """Ask the Gemini model for a decision on one ticker.

    Returns ``None`` (meaning: use the baseline rules) when the decision mode
    is ``"rules"``, no API key is configured, the client library cannot be
    imported, the call fails or times out, or the reply is not usable JSON
    with a valid ``action``. Otherwise returns a dict with keys ``action``,
    ``amount_usd``, ``sell_fraction``, ``preferred_minutes``, ``source``
    (``"llm"``) and ``rationale``. Values are only type-coerced here; range
    clamping happens in ``_merge_hybrid_decision``.
    """
    mode = settings.AGENT_DECISION_MODE.lower().strip()
    if mode == "rules" or not settings.GEMINI_API_KEY:
        return None

    try:
        import google.generativeai as genai  # type: ignore
    except Exception:
        return None

    now_time = datetime.now(ZoneInfo(settings.MARKET_TIMEZONE)).strftime("%Y-%m-%d %I:%M %p %Z")
    prompt = f"""
You are a risk-aware paper-trading assistant.
Decide one action for a single ticker using both quant signals and baseline rule suggestion.
Return STRICT JSON only with keys:
action (BUY|SELL|HOLD), amount_usd (number|null), sell_fraction (number|null), preferred_minutes (int|null), rationale (string).
Current Market Time: {now_time}
Ticker: {ticker}
Cash available: {cash}
Current portfolio total value: {total_value}
Owned shares: {owned_shares}
Prediction: {json.dumps(prediction)}
Direction: {json.dumps(direction)}
Technical: {json.dumps(technical)}
Sentiment: {json.dumps(sentiment) if isinstance(sentiment, dict) else sentiment}
Baseline rule decision: {json.dumps(baseline)}
Position sizing guidance:
- High confidence (>70%) = allocate 4-5% of total portfolio value.
- Medium confidence (55-70%) = 2-3%.
- Low confidence (<55%) = 1-2%.
- Never put more than 10% in one position.
Hard limits:
- BUY only if cash >= 50.
- SELL only if owned_shares > 0.
- amount_usd should follow the confidence-based sizing guidance above and stay within available cash.
- sell_fraction between 0.1 and 1.0 when SELL.
- preferred_minutes between 20 and 240.
""".strip()

    try:
        genai.configure(api_key=settings.GEMINI_API_KEY)
        model = genai.GenerativeModel("gemini-1.5-flash")
        # 30-second hard timeout — prevents stalling the whole agent loop on slow/unavailable LLM
        response = await asyncio.wait_for(
            asyncio.to_thread(model.generate_content, prompt),
            timeout=30.0,
        )
        text = getattr(response, "text", "") or ""
        payload = _extract_json_object(text)
        if not isinstance(payload, dict):
            return None

        action = str(payload.get("action", "")).upper().strip()
        if action not in {"BUY", "SELL", "HOLD"}:
            return None

        amount_raw = payload.get("amount_usd")
        sell_fraction_raw = payload.get("sell_fraction")
        minutes_raw = payload.get("preferred_minutes")
        rationale = str(payload.get("rationale", "")).strip()

        amount_usd = float(amount_raw) if isinstance(amount_raw, (int, float)) else None
        sell_fraction = (
            float(sell_fraction_raw)
            if isinstance(sell_fraction_raw, (int, float))
            else None
        )
        preferred_minutes = (
            int(minutes_raw)
            if isinstance(minutes_raw, (int, float))
            else None
        )
        return {
            "action": action,
            "amount_usd": amount_usd,
            "sell_fraction": sell_fraction,
            "preferred_minutes": preferred_minutes,
            "source": "llm",
            "rationale": rationale,
        }
    except Exception:
        # Best effort by design: any LLM failure falls back to the rules.
        return None


def _merge_hybrid_decision(
    *,
    baseline: dict,
    llm: dict | None,
    cash: float,
    total_value: float,
    owned_shares: float,
) -> dict:
    """Combine the baseline decision with the LLM suggestion under hard limits.

    Without an LLM suggestion a copy of ``baseline`` is returned unchanged.
    Otherwise the LLM action wins, subject to these guards:

    * BUY needs at least $50 cash, else HOLD; the amount is capped at 10% of
      cash and 10% of portfolio value, with a $50 floor.
    * SELL needs owned shares, else HOLD; the fraction is clamped to
      [0.1, 1.0] and defaults to 0.25.
    * ``preferred_minutes`` is clamped to [20, 240] when numeric, otherwise
      the baseline value is kept.
    """
    final = dict(baseline)
    if not llm:
        return final

    action = str(llm.get("action", baseline["action"])).upper().strip()
    if action not in {"BUY", "SELL", "HOLD"}:
        action = baseline["action"]

    amount_usd: float | None = None
    sell_shares: float | None = None

    if action == "BUY":
        if cash < 50:
            action = "HOLD"
        else:
            llm_amount = llm.get("amount_usd")
            baseline_amount = baseline.get("amount_usd")
            amount_val = (
                float(llm_amount)
                if isinstance(llm_amount, (int, float))
                else float(baseline_amount or 50.0)
            )
            amount_usd = max(50.0, min(amount_val, cash * 0.10, total_value * 0.10))
    elif action == "SELL":
        if owned_shares <= 0:
            action = "HOLD"
        else:
            sell_fraction = llm.get("sell_fraction")
            frac = (
                float(sell_fraction)
                if isinstance(sell_fraction, (int, float))
                else 0.25
            )
            frac = max(0.1, min(frac, 1.0))
            sell_shares = round(max(owned_shares * frac, 0.000001), 6)

    preferred_minutes = llm.get("preferred_minutes")
    final.update(
        {
            "action": action,
            "amount_usd": amount_usd,
            "sell_shares": sell_shares,
            "preferred_minutes": (
                max(20, min(int(preferred_minutes), 240))
                if isinstance(preferred_minutes, (int, float))
                else baseline.get("preferred_minutes")
            ),
            "source": llm.get("source", "rules"),
            "llm_rationale": llm.get("rationale"),
        }
    )
    return final


async def _with_retries(
    coro,
    *args,
    retries: int = 2,
    base_delay_seconds: float = 0.6,
    timeout_seconds: float | None = None,
    **kwargs,
):
    """Call the coroutine function ``coro`` with retries and linear backoff.

    ``coro(*args, **kwargs)`` is invoked afresh on every attempt, up to
    ``retries + 1`` times. Between attempts the helper sleeps
    ``base_delay_seconds * attempt_number``. When ``timeout_seconds`` is
    truthy each attempt is bounded by ``asyncio.wait_for``.

    Returns the first successful result; re-raises the last exception when
    every attempt fails.
    """
    last_error: Exception | None = None
    for attempt in range(retries + 1):
        try:
            if timeout_seconds:
                return await asyncio.wait_for(
                    coro(*args, **kwargs),
                    timeout=timeout_seconds,
                )
            return await coro(*args, **kwargs)
        except Exception as exc:  # noqa: PERF203
            last_error = exc
            if attempt >= retries:
                break
            await asyncio.sleep(base_delay_seconds * (attempt + 1))
    if last_error:
        raise last_error
    raise RuntimeError("Retry helper reached an unexpected state")


def _advisory_lock_key(portfolio_id: str) -> int:
    """Return a stable, non-negative 31-bit advisory-lock key for a portfolio.

    The built-in ``hash()`` of a string is salted per interpreter process, so
    it cannot be used for a lock that must agree across processes. CRC32 gives
    the same key everywhere.
    """
    return zlib.crc32(str(portfolio_id).encode("utf-8")) % (2**31)


def _adaptive_interval_minutes(
    *,
    trades_made: int,
    confidence_values: list[float],
    bullish_signals: int,
    bearish_signals: int,
    tool_errors: int,
    ticker_count: int,
    preferred_minutes_votes: list[int],
) -> int:
    """Return the delay in minutes until the next adaptive run, within [20, 240].

    Starts at 180 minutes (60 if any trade was made), tightens to 45 on high
    average confidence or broadly bullish signals and to 60 on broadly bearish
    signals, backs off to at least 120 after tool errors, then averages with
    the median of the LLM's preferred-minutes votes when there are any.
    """
    avg_conf = sum(confidence_values) / len(confidence_values) if confidence_values else 0.5
    signal_quorum = max(2, ticker_count // 2)

    minutes = 180
    if trades_made > 0:
        minutes = 60
    if avg_conf >= 0.72 or bullish_signals >= signal_quorum:
        minutes = min(minutes, 45)
    if bearish_signals >= signal_quorum:
        minutes = min(minutes, 60)
    if tool_errors > 0:
        minutes = max(minutes, 120)
    if preferred_minutes_votes:
        minutes = int((minutes + median(preferred_minutes_votes)) / 2)
    return max(20, min(minutes, 240))


async def run_agent(
    portfolio_id: str,
    session: str = "manual",
    run_type: str = "manual",
    run_id: str | None = None,
) -> dict:
    """Run one agent pass over every ticker of a portfolio.

    Args:
        portfolio_id: Identifier of the portfolio to process.
        session: Session label stored on a newly created ``AgentRun``.
        run_type: Run type stored on a newly created ``AgentRun``.
        run_id: Identifier of a pre-created ``AgentRun`` to reuse; a new
            record is created when it is missing or not found.

    Returns:
        A dict whose ``"status"`` is ``"skipped"`` (another run holds the
        portfolio's advisory lock), ``"failed"`` (with ``"error"``), or
        ``"done"`` (with ``"run_id"`` and, after processing tickers,
        ``"trades_made"``).
    """
    async with SessionLocal() as db:
        lock_key = _advisory_lock_key(portfolio_id)
        got_lock = bool(await db.scalar(select(func.pg_try_advisory_lock(lock_key))))
        if not got_lock:
            # If a run record was pre-created (manual trigger path), ensure it doesn't
            # stay "running" forever.
            if run_id:
                stale_run = await db.get(AgentRun, run_id)
                if stale_run and stale_run.status == "running":
                    # Keep the values in locals: a rollback expires the
                    # instance and discards pending attribute changes.
                    skipped_trades = stale_run.trades_made or 0
                    skipped_summary = (
                        (stale_run.summary or "").strip()
                        or "Skipped: run already in progress."
                    )
                    skipped_at = datetime.now(timezone.utc)

                    stale_run.status = "skipped"
                    stale_run.trades_made = skipped_trades
                    stale_run.summary = skipped_summary
                    stale_run.completed_at = skipped_at
                    try:
                        await db.commit()
                    except IntegrityError:
                        # Older databases may still enforce a CHECK constraint without
                        # the 'skipped' status; fall back to a terminal status that
                        # always exists.
                        await db.rollback()
                        stale_run.status = "done"
                        stale_run.trades_made = skipped_trades
                        stale_run.summary = skipped_summary
                        stale_run.completed_at = skipped_at
                        await db.commit()
            return {"status": "skipped", "reason": "run already in progress"}

        run: AgentRun | None = None
        try:
            if run_id:
                run = await db.get(AgentRun, run_id)
            if not run:
                run = AgentRun(
                    portfolio_id=portfolio_id,
                    run_type=run_type,
                    session=session,
                    status="running",
                    started_at=datetime.now(timezone.utc),
                )
                db.add(run)
                await db.commit()
                await db.refresh(run)

            portfolio = await db.get(Portfolio, portfolio_id)
            if not portfolio:
                run.status = "failed"
                run.summary = "Run failed: Portfolio not found"
                run.completed_at = datetime.now(timezone.utc)
                await db.commit()
                return {"status": "failed", "error": "Portfolio not found"}

            tickers = await get_portfolio_tickers(db, portfolio.id)
            holdings, _ = await build_holdings_view(db, portfolio.id)
            portfolio_out = await build_portfolio_out(db, portfolio)

            if not tickers:
                run.summary = "No tickers configured for this portfolio."
                run.status = "done"
                run.trades_made = 0
                run.total_pl = portfolio_out.profit_loss
                run.completed_at = datetime.now(timezone.utc)
                await db.commit()
                return {"status": "done", "run_id": str(run.id)}

            trades_made = 0
            summary_lines: list[str] = []
            per_ticker_decisions: list[dict] = []
            confidence_values: list[float] = []
            bearish_signals = 0
            bullish_signals = 0
            tool_errors = 0
            preferred_minutes_votes: list[int] = []

            for ticker in tickers:
                # Gather all signals first; any tool failure skips the ticker.
                try:
                    prediction = await _with_retries(
                        predict_price_tool,
                        ticker,
                        horizon_days=3,
                        timeout_seconds=25.0,
                    )
                    direction = await _with_retries(
                        classify_direction_tool,
                        ticker,
                        timeout_seconds=20.0,
                    )
                    technical = await _with_retries(
                        get_technical_signals_tool,
                        ticker,
                        timeout_seconds=20.0,
                    )
                    sentiment_data = await _with_retries(
                        get_sentiment_score_tool,
                        ticker,
                        timeout_seconds=20.0,
                    )
                    status = await _with_retries(
                        get_portfolio_status_tool,
                        db,
                        portfolio_id,
                        timeout_seconds=10.0,
                    )
                except Exception as exc:
                    compact = _compact_tool_error(exc)
                    summary_lines.append(f"{ticker}: skipped due to tool error ({compact}).")
                    per_ticker_decisions.append(
                        {
                            "ticker": ticker,
                            "action": "HOLD",
                            "rationale": f"Skipped due to tool error: {compact}",
                            "tools_called": {},
                        }
                    )
                    tool_errors += 1
                    continue

                cash = float(status["current_cash"])
                matching_holding = next(
                    (item for item in status["holdings"] if item["ticker"] == ticker),
                    None,
                )
                owned_shares = float((matching_holding or {}).get("shares", 0.0))
                sentiment_score_float = float(sentiment_data.get("score", 0.0))

                baseline = _rule_decision(
                    cash=cash,
                    total_value=float(portfolio_out.total_value),
                    confidence=float(prediction.get("confidence", 0.5)),
                    owned_shares=owned_shares,
                    direction=direction,
                    technical=technical,
                    sentiment=sentiment_score_float,
                )
                llm_pick = await _llm_hybrid_decision(
                    ticker=ticker,
                    cash=cash,
                    total_value=float(portfolio_out.total_value),
                    owned_shares=owned_shares,
                    prediction=prediction,
                    direction=direction,
                    technical=technical,
                    sentiment=sentiment_data,
                    baseline=baseline,
                )
                decision = _merge_hybrid_decision(
                    baseline=baseline,
                    llm=llm_pick,
                    cash=cash,
                    total_value=float(portfolio_out.total_value),
                    owned_shares=owned_shares,
                )

                action = decision["action"]
                amount_usd = decision.get("amount_usd")
                sell_shares = decision.get("sell_shares")
                pref = decision.get("preferred_minutes")
                if isinstance(pref, int):
                    preferred_minutes_votes.append(pref)

                conf = float(prediction.get("confidence", 0.5))
                confidence_values.append(conf)
                if direction.get("direction") == "UP":
                    bullish_signals += 1
                elif direction.get("direction") == "DOWN":
                    bearish_signals += 1

                reasoning = _build_reasoning(
                    ticker=ticker,
                    prediction=prediction,
                    direction=direction,
                    technical=technical,
                    sentiment=sentiment_data,
                    action=action,
                )
                if decision.get("source") == "llm":
                    llm_rationale = str(decision.get("llm_rationale") or "").strip()
                    if llm_rationale:
                        reasoning = f"{reasoning} LLM rationale: {llm_rationale}"

                tools_called = {
                    "lstm_prediction": prediction,
                    "direction": direction,
                    "technical_signals": technical,
                    "sentiment_score": sentiment_data,
                    "decision_source": decision.get("source", "rules"),
                    "decision_mode": settings.AGENT_DECISION_MODE,
                }
                per_ticker_decisions.append(
                    {
                        "ticker": ticker,
                        "action": action,
                        "rationale": reasoning,
                        "tools_called": tools_called,
                    }
                )

                # Store the prediction now; actual_price is left unset here.
                horizon_days = max(1, int(prediction.get("horizon_days", 3)))
                prediction_date = _market_today()
                await record_prediction(
                    db=db,
                    ticker=ticker,
                    model_type="lstm",
                    predicted_price=float(prediction.get("predicted_price", 0.0)),
                    actual_price=None,
                    prediction_date=prediction_date,
                    prediction_for_date=prediction_date + timedelta(days=horizon_days),
                )

                if action in {"BUY", "SELL"}:
                    try:
                        await _with_retries(
                            execute_trade,
                            db=db,
                            portfolio_id=portfolio_id,
                            ticker=ticker,
                            action=action,
                            amount_usd=amount_usd,
                            shares=sell_shares,
                            run_id=str(run.id),
                            llm_reasoning=reasoning,
                            tools_called=tools_called,
                            timeout_seconds=20.0,
                        )
                    except Exception as exc:
                        summary_lines.append(f"{ticker}: trade execution failed ({exc}).")
                        continue
                    trades_made += 1

                summary_lines.append(reasoning)

            await snapshot_portfolio(db, portfolio.id)
            refreshed_portfolio = await build_portfolio_out(db, portfolio)

            run.status = "done"
            run.trades_made = trades_made
            run.total_pl = refreshed_portfolio.profit_loss
            # Only the last ten lines are kept in the stored summary.
            run.summary = "\n".join(summary_lines[-10:])
            run.per_ticker_decisions = per_ticker_decisions
            run.completed_at = datetime.now(timezone.utc)
            await db.commit()

            if portfolio.is_active:
                # Imported here rather than at module level, as in the original.
                from scheduler.jobs import compute_next_market_run, schedule_portfolio_run

                minutes = _adaptive_interval_minutes(
                    trades_made=trades_made,
                    confidence_values=confidence_values,
                    bullish_signals=bullish_signals,
                    bearish_signals=bearish_signals,
                    tool_errors=tool_errors,
                    ticker_count=len(tickers),
                    preferred_minutes_votes=preferred_minutes_votes,
                )
                next_run_at = compute_next_market_run(preferred_minutes=minutes)
                schedule_portfolio_run(
                    portfolio_id=str(portfolio.id),
                    run_at_utc=next_run_at,
                    session="adaptive",
                )

            return {
                "status": "done",
                "run_id": str(run.id),
                "trades_made": trades_made,
            }
        except Exception as exc:
            if run is not None:
                run.status = "failed"
                run.summary = f"Run failed: {exc}"
                run.per_ticker_decisions = run.per_ticker_decisions or []
                run.completed_at = datetime.now(timezone.utc)
                await db.commit()
                return {
                    "status": "failed",
                    "run_id": str(run.id),
                    "error": str(exc),
                }
            return {"status": "failed", "error": str(exc)}
        finally:
            # Advisory locks are connection-scoped; ensure we release on the same session.
            # NOTE(review): the commits above may hand the connection back to
            # the pool depending on engine/session configuration — confirm the
            # unlock reaches the connection that acquired the lock.
            await db.execute(select(func.pg_advisory_unlock(lock_key)))