Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import re | |
| from statistics import median | |
| from datetime import date, datetime, timedelta, timezone | |
| from zoneinfo import ZoneInfo | |
| from sqlalchemy import func, select | |
| from sqlalchemy.exc import IntegrityError | |
| from agent.tools import ( | |
| classify_direction_tool, | |
| execute_trade, | |
| get_portfolio_status_tool, | |
| get_sentiment_score_tool, | |
| get_technical_signals_tool, | |
| predict_price_tool, | |
| ) | |
| from api.utils import build_holdings_view, build_portfolio_out, get_portfolio_tickers, snapshot_portfolio | |
| from core.config import settings | |
| from core.database import SessionLocal | |
| from core.models import AgentRun, Portfolio | |
| from ml.evaluator import record_prediction | |
| def _build_reasoning( | |
| ticker: str, | |
| prediction: dict, | |
| direction: dict, | |
| technical: dict, | |
| sentiment: dict | float, | |
| action: str, | |
| ) -> str: | |
| sentiment_score = sentiment.get("score", 0.0) if isinstance(sentiment, dict) else sentiment | |
| return ( | |
| f"{ticker}: Predicted {prediction.get('predicted_price')} in {prediction.get('horizon_days')}d " | |
| f"(conf {prediction.get('confidence')}). Direction={direction.get('direction')} " | |
| f"(p={direction.get('probability')}). RSI={technical.get('rsi')}, " | |
| f"MACD={technical.get('macd')}, sentiment={sentiment_score}. Action={action}." | |
| ) | |
def _compact_tool_error(exc: Exception) -> str:
    """Condense a tool failure into a short reason string for run summaries.

    - Timeouts (``asyncio.TimeoutError`` or the builtin ``TimeoutError``, which
      are distinct classes before Python 3.11) become "tool timed out".
    - Provider "No real OHLCV returned for ..." failures map to a friendlier
      phrase.
    - Markdown bold markers (``**``) are stripped from other messages.
    - An exception with an empty message falls back to its class name, so the
      summary never reads "tool error ()".
    """
    if isinstance(exc, (asyncio.TimeoutError, TimeoutError)):
        return "tool timed out"
    message = str(exc).replace("**", "").strip()
    if "No real OHLCV returned for" in message:
        return "market data unavailable from providers"
    return message or type(exc).__name__
def _market_today() -> date:
    """Return the current calendar date in the configured market timezone."""
    market_tz = ZoneInfo(settings.MARKET_TIMEZONE)
    now_in_market = datetime.now(market_tz)
    return now_in_market.date()
def _confidence_based_amount(total_value: float, confidence: float) -> float:
    """Size a buy as a fraction of total portfolio value, tiered by confidence.

    confidence >= 0.70 -> 4.5%, >= 0.55 -> 2.5%, anything lower (including NaN)
    -> 1.5%. The result is clamped to be non-negative.
    """
    tiers = ((0.7, 0.045), (0.55, 0.025))
    for threshold, fraction in tiers:
        if confidence >= threshold:
            break
    else:
        fraction = 0.015
    return max(0.0, total_value * fraction)
def _rule_decision(
    *,
    cash: float,
    total_value: float,
    confidence: float,
    owned_shares: float,
    direction: dict,
    technical: dict,
    sentiment: float,
) -> dict:
    """Deterministic baseline BUY/SELL/HOLD decision for one ticker.

    BUY when the direction is "UP" with probability >= 0.55, RSI < 72,
    sentiment >= -0.1 and at least $50 cash. The size comes from
    ``_confidence_based_amount`` and is held between $50 and 10% of cash.

    SELL a quarter of the position when shares are held and the direction is
    "DOWN", sentiment < -0.25, or RSI > 78. A tiny holding is sold at most in
    full, never more than is owned.

    A missing, ``None`` or non-numeric ``probability``/``rsi`` is treated as 0.0
    and 50.0 respectively, instead of raising and aborting the run.

    Returns a dict with ``action``, ``amount_usd``, ``sell_shares``,
    ``preferred_minutes`` (always None) and ``source`` ("rules").
    """

    def _num(value, default: float) -> float:
        # Tool payloads can carry None (e.g. RSI without enough history) or
        # strings; fall back to a neutral value rather than raising TypeError.
        if value is None:
            return default
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    trend = direction.get("direction")
    probability = _num(direction.get("probability"), 0.0)
    rsi = _num(technical.get("rsi"), 50.0)

    action = "HOLD"
    amount_usd: float | None = None
    sell_shares: float | None = None
    if trend == "UP" and probability >= 0.55 and rsi < 72 and sentiment >= -0.1 and cash >= 50:
        action = "BUY"
        confidence_amount = _confidence_based_amount(total_value, confidence)
        # At least $50 per buy, at most 10% of the cash currently available.
        amount_usd = max(50.0, min(confidence_amount, cash * 0.10))
    elif owned_shares > 0 and (trend == "DOWN" or sentiment < -0.25 or rsi > 78):
        action = "SELL"
        # Trim 25% (minimum 1e-6 shares), but never request more than is held.
        sell_shares = min(owned_shares, round(max(owned_shares * 0.25, 0.000001), 6))
    return {
        "action": action,
        "amount_usd": amount_usd,
        "sell_shares": sell_shares,
        "preferred_minutes": None,
        "source": "rules",
    }
| def _extract_json_object(text: str) -> dict | None: | |
| match = re.search(r"\{[\s\S]*\}", text) | |
| if not match: | |
| return None | |
| try: | |
| return json.loads(match.group(0)) | |
| except json.JSONDecodeError: | |
| return None | |
async def _llm_hybrid_decision(
    *,
    ticker: str,
    cash: float,
    total_value: float,
    owned_shares: float,
    prediction: dict,
    direction: dict,
    technical: dict,
    sentiment: dict | float,
    baseline: dict,
) -> dict | None:
    """Ask Gemini for a BUY/SELL/HOLD suggestion for one ticker.

    Strictly best-effort. Returns ``None`` when any of these hold:
    - ``AGENT_DECISION_MODE`` is "rules" or no ``GEMINI_API_KEY`` is set;
    - the ``google.generativeai`` SDK cannot be imported;
    - building the prompt or calling the model fails, or the call takes more
      than 30 s;
    - the reply has no JSON object with an allowed action.

    Numeric fields that are missing, boolean, or non-finite come back as
    ``None`` so the caller can fall back to rule-based values.

    The model name is read from ``settings.GEMINI_MODEL`` when that attribute
    exists and is non-empty, otherwise "gemini-1.5-flash".
    """
    import math

    mode = settings.AGENT_DECISION_MODE.lower().strip()
    if mode == "rules" or not settings.GEMINI_API_KEY:
        return None
    try:
        import google.generativeai as genai  # type: ignore
    except Exception:
        return None

    def _finite_number(value) -> float | None:
        # JSON true/false are ints in Python but are not quantities.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
        number = float(value)
        return number if math.isfinite(number) else None

    def _dump(obj) -> str:
        # Tool payloads may contain values json cannot encode natively (dates,
        # numpy scalars); stringify them so prompt construction never fails.
        return json.dumps(obj, default=str)

    try:
        # Prompt construction sits inside the try so formatting problems end in
        # "no LLM opinion" rather than failing the whole agent run.
        now_time = datetime.now(ZoneInfo(settings.MARKET_TIMEZONE)).strftime("%Y-%m-%d %I:%M %p %Z")
        sentiment_text = _dump(sentiment) if isinstance(sentiment, dict) else sentiment
        prompt = f"""
You are a risk-aware paper-trading assistant.
Decide one action for a single ticker using both quant signals and baseline rule suggestion.
Return STRICT JSON only with keys:
action (BUY|SELL|HOLD), amount_usd (number|null), sell_fraction (number|null), preferred_minutes (int|null), rationale (string).
Current Market Time: {now_time}
Ticker: {ticker}
Cash available: {cash}
Current portfolio total value: {total_value}
Owned shares: {owned_shares}
Prediction: {_dump(prediction)}
Direction: {_dump(direction)}
Technical: {_dump(technical)}
Sentiment: {sentiment_text}
Baseline rule decision: {_dump(baseline)}
Position sizing guidance:
- High confidence (>70%) = allocate 4-5% of total portfolio value.
- Medium confidence (55-70%) = 2-3%.
- Low confidence (<55%) = 1-2%.
- Never put more than 10% in one position.
Hard limits:
- BUY only if cash >= 50.
- SELL only if owned_shares > 0.
- amount_usd should follow the confidence-based sizing guidance above and stay within available cash.
- sell_fraction between 0.1 and 1.0 when SELL.
- preferred_minutes between 20 and 240.
""".strip()

        model_name = getattr(settings, "GEMINI_MODEL", None) or "gemini-1.5-flash"
        genai.configure(api_key=settings.GEMINI_API_KEY)
        model = genai.GenerativeModel(model_name)
        # 30-second hard timeout: a slow or unavailable LLM must not stall the agent loop.
        response = await asyncio.wait_for(
            asyncio.to_thread(model.generate_content, prompt),
            timeout=30.0,
        )
        text = getattr(response, "text", "") or ""
        payload = _extract_json_object(text)
        if not isinstance(payload, dict):
            return None
        action = str(payload.get("action", "")).upper().strip()
        if action not in {"BUY", "SELL", "HOLD"}:
            return None
        minutes = _finite_number(payload.get("preferred_minutes"))
        return {
            "action": action,
            "amount_usd": _finite_number(payload.get("amount_usd")),
            "sell_fraction": _finite_number(payload.get("sell_fraction")),
            "preferred_minutes": int(minutes) if minutes is not None else None,
            "source": "llm",
            # A JSON null rationale should read as empty, not as the text "None".
            "rationale": str(payload.get("rationale") or "").strip(),
        }
    except Exception:
        # Best-effort by design: SDK, network, timeout or parsing failures all
        # mean "no LLM opinion", and the caller falls back to the rule baseline.
        return None
| def _merge_hybrid_decision( | |
| *, | |
| baseline: dict, | |
| llm: dict | None, | |
| cash: float, | |
| total_value: float, | |
| owned_shares: float, | |
| ) -> dict: | |
| final = dict(baseline) | |
| if not llm: | |
| return final | |
| action = str(llm.get("action", baseline["action"])).upper().strip() | |
| if action not in {"BUY", "SELL", "HOLD"}: | |
| action = baseline["action"] | |
| amount_usd: float | None = None | |
| sell_shares: float | None = None | |
| if action == "BUY": | |
| if cash < 50: | |
| action = "HOLD" | |
| else: | |
| llm_amount = llm.get("amount_usd") | |
| baseline_amount = baseline.get("amount_usd") | |
| amount_val = ( | |
| float(llm_amount) | |
| if isinstance(llm_amount, (int, float)) | |
| else float(baseline_amount or 50.0) | |
| ) | |
| amount_usd = max(50.0, min(amount_val, cash * 0.10, total_value * 0.10)) | |
| elif action == "SELL": | |
| if owned_shares <= 0: | |
| action = "HOLD" | |
| else: | |
| sell_fraction = llm.get("sell_fraction") | |
| frac = ( | |
| float(sell_fraction) | |
| if isinstance(sell_fraction, (int, float)) | |
| else 0.25 | |
| ) | |
| frac = max(0.1, min(frac, 1.0)) | |
| sell_shares = round(max(owned_shares * frac, 0.000001), 6) | |
| preferred_minutes = llm.get("preferred_minutes") | |
| final.update( | |
| { | |
| "action": action, | |
| "amount_usd": amount_usd, | |
| "sell_shares": sell_shares, | |
| "preferred_minutes": ( | |
| max(20, min(int(preferred_minutes), 240)) | |
| if isinstance(preferred_minutes, (int, float)) | |
| else baseline.get("preferred_minutes") | |
| ), | |
| "source": llm.get("source", "rules"), | |
| "llm_rationale": llm.get("rationale"), | |
| } | |
| ) | |
| return final | |
| async def _with_retries( | |
| coro, | |
| *args, | |
| retries: int = 2, | |
| base_delay_seconds: float = 0.6, | |
| timeout_seconds: float | None = None, | |
| **kwargs, | |
| ): | |
| last_error: Exception | None = None | |
| for attempt in range(retries + 1): | |
| try: | |
| if timeout_seconds: | |
| return await asyncio.wait_for( | |
| coro(*args, **kwargs), | |
| timeout=timeout_seconds, | |
| ) | |
| return await coro(*args, **kwargs) | |
| except Exception as exc: # noqa: PERF203 | |
| last_error = exc | |
| if attempt >= retries: | |
| break | |
| await asyncio.sleep(base_delay_seconds * (attempt + 1)) | |
| if last_error: | |
| raise last_error | |
| raise RuntimeError("Retry helper reached an unexpected state") | |
def _portfolio_lock_key(portfolio_id) -> int:
    """Return a process-independent 31-bit advisory-lock key for ``portfolio_id``.

    ``hash(str)`` is salted per interpreter (PYTHONHASHSEED), so separate workers
    would derive different keys for the same portfolio and never contend.
    CRC-32 of the id is stable across processes.
    """
    import zlib

    return zlib.crc32(str(portfolio_id).encode("utf-8")) & 0x7FFFFFFF


def _adaptive_interval_minutes(
    *,
    confidence_values: list,
    trades_made: int,
    bullish_signals: int,
    bearish_signals: int,
    ticker_count: int,
    tool_errors: int,
    preferred_minutes_votes: list,
) -> int:
    """Pick the delay in minutes (clamped to [20, 240]) before the next adaptive run."""
    avg_conf = sum(confidence_values) / len(confidence_values) if confidence_values else 0.5
    majority = max(2, ticker_count // 2)
    minutes = 60 if trades_made > 0 else 180
    if avg_conf >= 0.72 or bullish_signals >= majority:
        minutes = min(minutes, 45)
    if bearish_signals >= majority:
        minutes = min(minutes, 60)
    if tool_errors > 0:
        minutes = max(minutes, 120)
    if preferred_minutes_votes:
        # Blend the heuristic with the median per-ticker preference.
        minutes = int((minutes + median(preferred_minutes_votes)) / 2)
    return max(20, min(minutes, 240))


async def _mark_run_skipped(db, run_id) -> None:
    """Close out a pre-created "running" AgentRun when the portfolio lock is busy."""
    if not run_id:
        return
    run = await db.get(AgentRun, run_id)
    if not run or run.status != "running":
        return
    trades_made = run.trades_made or 0
    summary = (run.summary or "").strip() or "Skipped: run already in progress."

    def _close(target, status: str) -> None:
        target.status = status
        target.trades_made = trades_made
        target.summary = summary
        target.completed_at = datetime.now(timezone.utc)

    _close(run, "skipped")
    try:
        await db.commit()
    except IntegrityError:
        # Older databases may still enforce a CHECK constraint without the
        # 'skipped' status. The rollback discards every pending change and
        # expires the instance, so reload it and re-apply all fields with a
        # terminal status that always exists.
        await db.rollback()
        run = await db.get(AgentRun, run_id)
        if run is not None:
            _close(run, "done")
            await db.commit()


async def _mark_run_failed(db, run, run_pk, exc: Exception) -> dict:
    """Persist a "failed" status on the run and build the failure response."""
    message = f"Run failed: {exc}"

    def _fail(target) -> None:
        target.status = "failed"
        target.summary = message
        target.per_ticker_decisions = target.per_ticker_decisions or []
        target.completed_at = datetime.now(timezone.utc)

    try:
        _fail(run)
        await db.commit()
    except Exception:
        # The triggering error may have left the session unusable (e.g. an
        # aborted transaction or an expired instance). Roll back, reload the
        # run by primary key and try once more; a second failure propagates.
        await db.rollback()
        reloaded = await db.get(AgentRun, run_pk)
        if reloaded is not None:
            _fail(reloaded)
            await db.commit()
    return {"status": "failed", "run_id": str(run_pk), "error": str(exc)}


async def _run_with_lock(db, *, portfolio_id, session, run_type, run_id) -> dict:
    """Execute one agent pass for ``portfolio_id``; the caller holds the portfolio lock.

    Steps:
    1. Reuse the AgentRun identified by ``run_id``, or create one.
    2. Evaluate every portfolio ticker and place the resulting trades.
    3. Snapshot the portfolio.
    4. For active portfolios, schedule the next adaptive run.

    Any unexpected error marks the run "failed" and returns a failure dict.
    """
    run: AgentRun | None = None
    run_pk = None
    try:
        if run_id:
            run = await db.get(AgentRun, run_id)
        if not run:
            run = AgentRun(
                portfolio_id=portfolio_id,
                run_type=run_type,
                session=session,
                status="running",
                started_at=datetime.now(timezone.utc),
            )
            db.add(run)
            await db.commit()
            await db.refresh(run)
        # Keep the primary key so the failure path can reload the row even if
        # the session must be rolled back (a rollback expires ``run``).
        run_pk = run.id

        portfolio = await db.get(Portfolio, portfolio_id)
        if not portfolio:
            run.status = "failed"
            run.summary = "Run failed: Portfolio not found"
            run.completed_at = datetime.now(timezone.utc)
            await db.commit()
            return {"status": "failed", "error": "Portfolio not found"}

        tickers = await get_portfolio_tickers(db, portfolio.id)
        # NOTE(review): the holdings view is not used below; the call is kept only
        # in case it has side effects -- confirm and drop if it does not.
        await build_holdings_view(db, portfolio.id)
        portfolio_out = await build_portfolio_out(db, portfolio)

        if not tickers:
            run.summary = "No tickers configured for this portfolio."
            run.status = "done"
            run.trades_made = 0
            run.total_pl = portfolio_out.profit_loss
            run.completed_at = datetime.now(timezone.utc)
            await db.commit()
            return {"status": "done", "run_id": str(run.id)}

        trades_made = 0
        summary_lines: list[str] = []
        per_ticker_decisions: list[dict] = []
        confidence_values: list[float] = []
        bearish_signals = 0
        bullish_signals = 0
        tool_errors = 0
        preferred_minutes_votes: list[int] = []

        for ticker in tickers:
            try:
                prediction = await _with_retries(
                    predict_price_tool,
                    ticker,
                    horizon_days=3,
                    timeout_seconds=25.0,
                )
                direction = await _with_retries(
                    classify_direction_tool,
                    ticker,
                    timeout_seconds=20.0,
                )
                technical = await _with_retries(
                    get_technical_signals_tool,
                    ticker,
                    timeout_seconds=20.0,
                )
                sentiment_data = await _with_retries(
                    get_sentiment_score_tool,
                    ticker,
                    timeout_seconds=20.0,
                )
                status = await _with_retries(
                    get_portfolio_status_tool,
                    db,
                    portfolio_id,
                    timeout_seconds=10.0,
                )
            except Exception as exc:
                compact = _compact_tool_error(exc)
                summary_lines.append(f"{ticker}: skipped due to tool error ({compact}).")
                per_ticker_decisions.append(
                    {
                        "ticker": ticker,
                        "action": "HOLD",
                        "rationale": f"Skipped due to tool error: {compact}",
                        "tools_called": {},
                    }
                )
                tool_errors += 1
                continue

            cash = float(status["current_cash"])
            matching_holding = next(
                (item for item in status["holdings"] if item["ticker"] == ticker),
                None,
            )
            owned_shares = float((matching_holding or {}).get("shares", 0.0))
            sentiment_score_float = float(sentiment_data.get("score", 0.0))

            baseline = _rule_decision(
                cash=cash,
                total_value=float(portfolio_out.total_value),
                confidence=float(prediction.get("confidence", 0.5)),
                owned_shares=owned_shares,
                direction=direction,
                technical=technical,
                sentiment=sentiment_score_float,
            )
            llm_pick = await _llm_hybrid_decision(
                ticker=ticker,
                cash=cash,
                total_value=float(portfolio_out.total_value),
                owned_shares=owned_shares,
                prediction=prediction,
                direction=direction,
                technical=technical,
                sentiment=sentiment_data,
                baseline=baseline,
            )
            decision = _merge_hybrid_decision(
                baseline=baseline,
                llm=llm_pick,
                cash=cash,
                total_value=float(portfolio_out.total_value),
                owned_shares=owned_shares,
            )
            action = decision["action"]
            amount_usd = decision.get("amount_usd")
            sell_shares = decision.get("sell_shares")
            pref = decision.get("preferred_minutes")
            if isinstance(pref, int):
                preferred_minutes_votes.append(pref)

            conf = float(prediction.get("confidence", 0.5))
            confidence_values.append(conf)
            if direction.get("direction") == "UP":
                bullish_signals += 1
            elif direction.get("direction") == "DOWN":
                bearish_signals += 1

            reasoning = _build_reasoning(
                ticker=ticker,
                prediction=prediction,
                direction=direction,
                technical=technical,
                sentiment=sentiment_data,
                action=action,
            )
            if decision.get("source") == "llm":
                llm_rationale = str(decision.get("llm_rationale") or "").strip()
                if llm_rationale:
                    reasoning = f"{reasoning} LLM rationale: {llm_rationale}"
            tools_called = {
                "lstm_prediction": prediction,
                "direction": direction,
                "technical_signals": technical,
                "sentiment_score": sentiment_data,
                "decision_source": decision.get("source", "rules"),
                "decision_mode": settings.AGENT_DECISION_MODE,
            }
            per_ticker_decisions.append(
                {
                    "ticker": ticker,
                    "action": action,
                    "rationale": reasoning,
                    "tools_called": tools_called,
                }
            )

            horizon_days = max(1, int(prediction.get("horizon_days", 3)))
            prediction_date = _market_today()
            await record_prediction(
                db=db,
                ticker=ticker,
                model_type="lstm",
                predicted_price=float(prediction.get("predicted_price", 0.0)),
                actual_price=None,
                prediction_date=prediction_date,
                prediction_for_date=prediction_date + timedelta(days=horizon_days),
            )

            if action in {"BUY", "SELL"}:
                try:
                    # Trades are not idempotent: a timed-out attempt may already
                    # have been applied, so a retry could double the order.
                    # Make exactly one (time-bounded) attempt.
                    await _with_retries(
                        execute_trade,
                        db=db,
                        portfolio_id=portfolio_id,
                        ticker=ticker,
                        action=action,
                        amount_usd=amount_usd,
                        shares=sell_shares,
                        run_id=str(run.id),
                        llm_reasoning=reasoning,
                        tools_called=tools_called,
                        retries=0,
                        timeout_seconds=20.0,
                    )
                except Exception as exc:
                    summary_lines.append(f"{ticker}: trade execution failed ({exc}).")
                    continue
                trades_made += 1
            summary_lines.append(reasoning)

        await snapshot_portfolio(db, portfolio.id)
        refreshed_portfolio = await build_portfolio_out(db, portfolio)
        run.status = "done"
        run.trades_made = trades_made
        run.total_pl = refreshed_portfolio.profit_loss
        run.summary = "\n".join(summary_lines[-10:])
        run.per_ticker_decisions = per_ticker_decisions
        run.completed_at = datetime.now(timezone.utc)
        await db.commit()

        if portfolio.is_active:
            from scheduler.jobs import compute_next_market_run, schedule_portfolio_run

            minutes = _adaptive_interval_minutes(
                confidence_values=confidence_values,
                trades_made=trades_made,
                bullish_signals=bullish_signals,
                bearish_signals=bearish_signals,
                ticker_count=len(tickers),
                tool_errors=tool_errors,
                preferred_minutes_votes=preferred_minutes_votes,
            )
            next_run_at = compute_next_market_run(preferred_minutes=minutes)
            schedule_portfolio_run(
                portfolio_id=str(portfolio.id),
                run_at_utc=next_run_at,
                session="adaptive",
            )
        return {
            "status": "done",
            "run_id": str(run.id),
            "trades_made": trades_made,
        }
    except Exception as exc:
        if run_pk is None:
            return {"status": "failed", "error": str(exc)}
        return await _mark_run_failed(db, run, run_pk, exc)


async def run_agent(
    portfolio_id: str,
    session: str = "manual",
    run_type: str = "manual",
    run_id: str | None = None,
) -> dict:
    """Run one trading-agent pass for a portfolio, at most one at a time per portfolio.

    Mutual exclusion uses a PostgreSQL transaction-scoped advisory lock taken in
    a dedicated session. That session's transaction stays open for the whole
    run, so the lock is tied to a single connection. The lock is released when
    that transaction ends, including if the connection dies.

    By contrast, a session-level lock taken on the work session is not pinned:
    with a default-configured session, each commit hands the connection back to
    the pool, so a later unlock can run on a different connection and leak the
    lock. Note that each run holds two pooled connections.

    Returns:
    - ``{"status": "skipped", ...}`` when another run holds the lock (a
      pre-created ``run_id`` row is closed out);
    - ``{"status": "done", "run_id", ...}`` on success;
    - ``{"status": "failed", ...}`` otherwise.
    """
    lock_key = _portfolio_lock_key(portfolio_id)
    async with SessionLocal() as lock_db, SessionLocal() as db:
        got_lock = bool(await lock_db.scalar(select(func.pg_try_advisory_xact_lock(lock_key))))
        if not got_lock:
            await _mark_run_skipped(db, run_id)
            return {"status": "skipped", "reason": "run already in progress"}
        try:
            return await _run_with_lock(
                db,
                portfolio_id=portfolio_id,
                session=session,
                run_type=run_type,
                run_id=run_id,
            )
        finally:
            try:
                # Ending the lock session's transaction releases the xact lock.
                await lock_db.rollback()
            except Exception:
                # If the lock connection is already gone, PostgreSQL released the
                # transaction-scoped lock with it; there is nothing left to undo.
                pass