U2INVEST / agent_service.py
DasbootU9607
Add TradingAgents timeout fallback
6c768de
import os
import re
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.messages import HumanMessage
from market_demo import MARKET_META
from runtime_config import DATA_DIR
# Upper-case tokens that TICKER_PATTERN can match but that are ordinary words
# or finance/tech jargon; extract_generic_ticker skips them.
TICKER_STOPWORDS = set(
    "A AI AM API ARE ETF ETFS FOR I IS JSON NOW OF ON PE PM RSI THE TO U2 USD".split()
)
# Phrases (lower-case) that signal the user wants a full TradingAgents run;
# checked by resolve_tradingagents_request against the lower-cased message.
TRADINGAGENTS_TRIGGER_KEYWORDS = {
    # explicit requests for the multi-agent pipeline
    "analyze", "analysis", "deep analysis", "full analysis",
    "multi-agent", "tradingagents",
    # directional views
    "bull case", "bear case", "bullish", "bearish",
    # actionable ratings and targets
    "buy", "sell", "hold", "rating", "outlook", "thesis",
    "price target", "recommend", "recommendation",
    # risk and allocation
    "risk", "portfolio",
}
# Known tickers, longest first, so extract_known_symbol returns the longest
# symbol that matches (e.g. "BRK.B" before "BRK", if both were listed).
KNOWN_SYMBOLS = sorted(MARKET_META.keys(), key=len, reverse=True)
# Lower-cased company name -> ticker, used by extract_known_company.
KNOWN_NAMES = {
    meta["name"].lower(): symbol
    for symbol, meta in MARKET_META.items()
}
# Optionally "$"-prefixed ticker: 1-5 capitals plus an optional "-XX" suffix
# (e.g. BTC-USD) or ".X" suffix (e.g. BRK.B); group 1 is the ticker without "$".
# ASCII-only lookarounds are used instead of \b so a match cannot start in the
# middle of a word (no "LYSIS" out of "ANALYSIS"), while a ticker written right
# next to CJK text (which \b treats as word characters) still matches.
TICKER_PATTERN = re.compile(
    r"(?<![A-Za-z0-9_])\$?([A-Z]{1,5}(?:-[A-Z]{2,5}|\.[A-Z]{1,4})?)(?![A-Za-z0-9_])"
)
# YYYY-MM-DD or YYYY/MM/DD with a 20xx year; groups are (year, month, day).
# Digit lookarounds instead of \b for the same CJK-adjacency reason.
DATE_PATTERN = re.compile(r"(?<!\d)(20\d{2})[-/](\d{2})[-/](\d{2})(?!\d)")
def get_agent_backend_mode() -> str:
    """Return the agent backend chosen via the AGENT_BACKEND env variable.

    The value is stripped and lower-cased. This module acts on "legacy",
    "tradingagents" and "auto"; an unset or blank variable yields "auto".
    """
    configured = os.getenv("AGENT_BACKEND", "auto").strip().lower()
    if configured:
        return configured
    return "auto"
def get_tradingagents_provider() -> str:
    """Return the TradingAgents LLM provider from TRADINGAGENTS_PROVIDER.

    Stripped and lower-cased; an unset or blank variable means "openai".
    """
    provider = os.getenv("TRADINGAGENTS_PROVIDER", "openai")
    provider = provider.strip().lower()
    return provider if provider else "openai"
def agent_is_configured() -> bool:
    """Return True if the selected backend has the credentials it needs.

    - "tradingagents": the TradingAgents provider must be configured.
    - "auto": either TradingAgents or the legacy agent (DEEPSEEK_API_KEY).
    - "legacy" or any other value: DEEPSEEK_API_KEY must be set.
    """
    mode = get_agent_backend_mode()
    if mode == "tradingagents":
        return _tradingagents_is_configured()
    legacy_ready = bool(os.getenv("DEEPSEEK_API_KEY"))
    if mode == "auto":
        return _tradingagents_is_configured() or legacy_ready
    return legacy_ready
def run_agent_message(user_message: str, session_id: str) -> Tuple[str, List[Dict[str, Any]], str]:
    """Answer *user_message* with whichever agent backend is configured.

    Routing is driven by get_agent_backend_mode():
    - "tradingagents": always try TradingAgents first.
    - "auto": try TradingAgents only when resolve_tradingagents_request finds a
      stock plus an analysis intent; otherwise use the legacy agent.
    - anything else: use the legacy agent.
    If TradingAgents raises and DEEPSEEK_API_KEY is set, the legacy agent
    answers instead and the reply is prefixed with a short fallback note.

    Args:
        user_message: The user's chat text.
        session_id: Forwarded to the legacy agent as its thread id.

    Returns:
        ``(response_text, tool_calls, backend_label)`` where *backend_label* is
        "tradingagents", "legacy-fallback" or "legacy".

    Raises:
        RuntimeError: if TradingAgents fails and no legacy agent is configured,
            or if the legacy agent is needed but DEEPSEEK_API_KEY is unset.
    """
    backend = get_agent_backend_mode()
    legacy_configured = bool(os.getenv("DEEPSEEK_API_KEY"))

    if backend in {"tradingagents", "auto"}:
        trading_request = resolve_tradingagents_request(user_message)
        if backend == "tradingagents" or trading_request:
            try:
                response, tools_used = run_tradingagents_message_with_timeout(
                    user_message,
                    trading_request=trading_request,
                )
            except Exception as error:
                # Without a legacy agent there is nothing to fall back to. Report
                # the TradingAgents failure itself (in "auto" mode too) instead of
                # letting run_legacy_message fail with "Legacy agent is not
                # configured", which would hide the real cause.
                if not legacy_configured:
                    raise RuntimeError(f"TradingAgents backend failed: {error}") from error
                print(f"TradingAgents fallback triggered: {error}")
                response, tool_results = run_legacy_message(user_message, session_id)
                fallback_note = (
                    "TradingAgents timed out on this deployment, so I used the fast fallback agent.\n\n"
                    if isinstance(error, TimeoutError)
                    else "TradingAgents was unavailable on this deployment, so I used the fast fallback agent.\n\n"
                )
                # Record the fallback as a pseudo tool call so the UI can show why.
                tool_results = [
                    {
                        "tool": "tradingagents_fallback",
                        "args": {
                            "reason": str(error),
                        },
                    },
                    *tool_results,
                ]
                return fallback_note + response, tool_results, "legacy-fallback"
            return response, tools_used, "tradingagents"

    response, tools_used = run_legacy_message(user_message, session_id)
    return response, tools_used, "legacy"
def run_legacy_message(user_message: str, session_id: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Answer *user_message* with the legacy LangGraph agent (DeepSeek-backed).

    Args:
        user_message: The user's chat text, sent as a single HumanMessage.
        session_id: Passed as the LangGraph ``thread_id`` in the run config
            (presumably keys per-session memory if the graph has a
            checkpointer; confirm in agent_graph).

    Returns:
        ``(response_text, tool_calls)``. *response_text* is the content of the
        last non-empty message streamed by the graph, or a rephrase prompt if
        there was none. *tool_calls* lists ``{"tool": name, "args": args}`` for
        every tool call seen in the stream.

    Raises:
        RuntimeError: if DEEPSEEK_API_KEY is not set.
    """
    if not os.getenv("DEEPSEEK_API_KEY"):
        raise RuntimeError(
            "Legacy agent is not configured. Set DEEPSEEK_API_KEY in your environment."
        )

    # Imported here, not at module level, so the graph module is only loaded
    # when the legacy agent is actually used.
    from agent_graph import stock_agent_app

    config = {"configurable": {"thread_id": session_id}}
    initial_state = {"messages": [HumanMessage(content=user_message)]}
    response_content = ""
    tool_results: List[Dict[str, Any]] = []

    for event in stock_agent_app.stream(initial_state, config):
        for output in event.values():
            # Skip node updates that carry no messages: an update may be None
            # (NOTE(review): LangGraph can emit None for a node that returns
            # nothing; verify for the installed version), and an empty
            # messages list has no last element to inspect.
            if not output or "messages" not in output:
                continue
            messages = output["messages"]
            if not messages:
                continue
            last_msg = messages[-1]
            if hasattr(last_msg, "content") and last_msg.content:
                response_content = last_msg.content
            if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
                for tool_call in last_msg.tool_calls:
                    tool_results.append(
                        {
                            "tool": tool_call.get("name", "unknown"),
                            "args": tool_call.get("args", {}),
                        }
                    )
    return response_content or "Please rephrase your question.", tool_results
def run_tradingagents_message(
    user_message: str,
    trading_request: Optional[Dict[str, str]] = None,
) -> Tuple[str, List[Dict[str, Any]]]:
    """Run the TradingAgents multi-agent graph for the stock named in a message.

    Args:
        user_message: The user's chat text. Only parsed (via
            resolve_tradingagents_request) when *trading_request* is falsy.
        trading_request: Optional pre-resolved ``{"symbol", "trade_date"}`` dict.

    Returns:
        ``(markdown_response, tools_used)``. *tools_used* holds one "tradingagents"
        entry describing the symbol, date, LLM provider and data vendor used.

    Raises:
        RuntimeError: if no stock can be resolved from the message, or (openai
            provider) no API key is available.
    """
    trading_request = trading_request or resolve_tradingagents_request(user_message)
    if not trading_request:
        raise RuntimeError(
            "TradingAgents needs a stock ticker or company name in the message."
        )

    _prime_tradingagents_env()
    # Deferred imports: the tradingagents package is only loaded when this
    # backend actually runs.
    from tradingagents.default_config import DEFAULT_CONFIG
    from tradingagents.graph.trading_graph import TradingAgentsGraph

    config = _build_tradingagents_config(DEFAULT_CONFIG)
    trading_graph = TradingAgentsGraph(
        selected_analysts=_get_tradingagents_selected_analysts(),
        debug=False,
        config=config,
    )
    full_state, decision = trading_graph.propagate(
        trading_request["symbol"],
        trading_request["trade_date"],
    )
    response = build_tradingagents_response(
        symbol=trading_request["symbol"],
        trade_date=trading_request["trade_date"],
        decision=decision,
        full_state=full_state,
    )
    tools_used = [
        {
            "tool": "tradingagents",
            "args": {
                "symbol": trading_request["symbol"],
                "trade_date": trading_request["trade_date"],
                "llm_provider": config["llm_provider"],
                "data_vendor": config["data_vendors"]["core_stock_apis"],
            },
        }
    ]
    return response, tools_used


def _build_tradingagents_config(default_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *default_config* with this deployment's env overrides applied."""
    # Shallow copy is enough: nested values (e.g. data_vendors) are replaced
    # below, never mutated, so the shared default stays untouched.
    config = default_config.copy()
    config["llm_provider"] = get_tradingagents_provider()
    config["deep_think_llm"] = _resolve_tradingagents_model(
        override_name="TRADINGAGENTS_DEEP_MODEL",
        default_model=config.get("deep_think_llm", "gpt-5.4"),
    )
    config["quick_think_llm"] = _resolve_tradingagents_model(
        override_name="TRADINGAGENTS_QUICK_MODEL",
        default_model=config.get("quick_think_llm", "gpt-5.4-mini"),
    )
    backend_url = _resolve_tradingagents_backend_url(
        default_backend_url=config.get("backend_url", "https://api.openai.com/v1"),
    )
    if backend_url:
        config["backend_url"] = backend_url
    # Debate/risk rounds default to 1 to keep runs short.
    config["max_debate_rounds"] = _get_positive_int_env("TRADINGAGENTS_MAX_DEBATE_ROUNDS", 1)
    config["max_risk_discuss_rounds"] = _get_positive_int_env(
        "TRADINGAGENTS_MAX_RISK_ROUNDS",
        1,
    )
    config["output_language"] = os.getenv("TRADINGAGENTS_OUTPUT_LANGUAGE", "English")
    config["results_dir"] = str(DATA_DIR / "tradingagents-logs")
    config["data_cache_dir"] = str(DATA_DIR / "tradingagents-cache")
    # One vendor (default yfinance) serves every data category.
    data_vendor = os.getenv("TRADINGAGENTS_DATA_VENDOR", "yfinance").strip().lower() or "yfinance"
    config["data_vendors"] = {
        category: data_vendor
        for category in (
            "core_stock_apis",
            "technical_indicators",
            "fundamental_data",
            "news_data",
        )
    }
    return config


def _get_tradingagents_selected_analysts() -> List[str]:
    """Parse TRADINGAGENTS_SELECTED_ANALYSTS (comma-separated) into analyst keys.

    Names are stripped, lower-cased and de-duplicated (first occurrence wins).
    An unset, blank or comma-only value falls back to ["market", "fundamentals"]
    rather than handing TradingAgents an empty analyst list.
    """
    raw_value = os.getenv("TRADINGAGENTS_SELECTED_ANALYSTS", "")
    analysts = [name.strip().lower() for name in raw_value.split(",") if name.strip()]
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    return list(dict.fromkeys(analysts)) or ["market", "fundamentals"]
def run_tradingagents_message_with_timeout(
    user_message: str,
    trading_request: Optional[Dict[str, str]] = None,
) -> Tuple[str, List[Dict[str, Any]]]:
    """Run run_tradingagents_message in a worker thread with a wall-clock limit.

    The limit is TRADINGAGENTS_TIMEOUT_SECONDS (positive int, default 25).

    Returns:
        Whatever run_tradingagents_message returns.

    Raises:
        TimeoutError: if the run has not finished within the limit. Python
            threads cannot be interrupted, so the worker is abandoned: it keeps
            running in the background and its eventual result is discarded.
        Exception: any exception raised by run_tradingagents_message itself,
            re-raised unchanged.
    """
    from concurrent.futures import wait

    timeout_seconds = _get_positive_int_env("TRADINGAGENTS_TIMEOUT_SECONDS", 25)
    # Deliberately not a ``with`` block: Executor.__exit__ calls
    # shutdown(wait=True), which would block until the worker finished and
    # make the timeout useless.
    executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tradingagents")
    try:
        future = executor.submit(run_tradingagents_message, user_message, trading_request)
        # wait() separates "still running" from "finished with an error", so a
        # TimeoutError raised inside the worker (concurrent.futures.TimeoutError
        # is the builtin TimeoutError on 3.11+) is not mistaken for this limit.
        finished, _ = wait([future], timeout=timeout_seconds)
        if not finished:
            # cancel() only succeeds if the task never started.
            future.cancel()
            raise TimeoutError(
                f"TradingAgents exceeded {timeout_seconds}s timeout"
            )
        return future.result()
    finally:
        executor.shutdown(wait=False)
def resolve_tradingagents_request(user_message: str) -> Optional[Dict[str, str]]:
    """Decide whether *user_message* should be routed to TradingAgents.

    A request is produced only when the message both
    1. asks for analysis: contains a trigger keyword, or
       TRADINGAGENTS_FORCE_FOR_STOCKS is "true"/"1"/"yes"/"on"; and
    2. names a stock: tried in order via the focus-stock list, a known ticker,
       a known company name, then any ticker-looking token.

    Returns:
        ``{"symbol": ..., "trade_date": "YYYY-MM-DD"}``, or None if either
        condition fails or the message is blank.
    """
    normalized_message = user_message.strip()
    if not normalized_message:
        return None

    lower_message = normalized_message.lower()
    force_for_stocks = os.getenv(
        "TRADINGAGENTS_FORCE_FOR_STOCKS", "false"
    ).strip().lower() in {"1", "true", "yes", "on"}
    # Keywords must start at an ASCII word boundary so "rating" does not fire
    # on "operating" nor "hold" on "threshold"; the prefix match still accepts
    # inflections such as "ratings"/"risks", and text glued to CJK characters.
    has_trigger = any(
        re.search(rf"(?<![a-z0-9_]){re.escape(keyword)}", lower_message)
        for keyword in TRADINGAGENTS_TRIGGER_KEYWORDS
    )
    if not (force_for_stocks or has_trigger):
        return None

    symbol = (
        extract_focus_stock(normalized_message)
        or extract_known_symbol(normalized_message)
        or extract_known_company(normalized_message)
        or extract_generic_ticker(normalized_message)
    )
    if not symbol:
        return None
    return {
        "symbol": symbol,
        "trade_date": extract_trade_date(normalized_message),
    }
def extract_focus_stock(message: str) -> Optional[str]:
    """Return the first known ticker listed after a "focus stocks:" marker.

    Everything after the first (case-insensitive) marker is split on commas,
    slashes and newlines. Each piece is trimmed and upper-cased, and the first
    one present in MARKET_META is returned. Returns None if the marker is absent
    or nothing listed is a known symbol.
    """
    found = re.search(r"focus stocks:\s*(.+)$", message, flags=re.IGNORECASE | re.DOTALL)
    if found is None:
        return None
    pieces = (piece.strip().upper() for piece in re.split(r"[,/\n]", found.group(1)))
    return next((ticker for ticker in pieces if ticker and ticker in MARKET_META), None)
def extract_known_symbol(message: str) -> Optional[str]:
    """Return the first entry of KNOWN_SYMBOLS that appears as a standalone token.

    The message is upper-cased before matching, so "tsla" finds "TSLA". A match
    must not be glued to other letters or digits on either side.
    NOTE(review): because of the upper-casing, an ordinary lower-case word that
    equals a known symbol also matches; confirm MARKET_META has no such symbols.
    """
    haystack = message.upper()
    return next(
        (
            known
            for known in KNOWN_SYMBOLS
            if re.search(rf"(?<![A-Z0-9]){re.escape(known)}(?![A-Z0-9])", haystack)
        ),
        None,
    )
def extract_known_company(message: str) -> Optional[str]:
    """Return the ticker of the first known company name mentioned in *message*.

    Matching is case-insensitive against the lower-cased names in KNOWN_NAMES.
    Longer names are tried first (like KNOWN_SYMBOLS), so a full name beats a
    shorter name it contains. A name must not be glued to ASCII letters, digits
    or "_" on either side. Unlike a plain word boundary, this also works for
    names that begin or end with punctuation (e.g. "inc.") and for names
    written directly next to non-Latin text.

    Returns:
        The matching ticker, or None.
    """
    lower_message = message.lower()
    for company_name in sorted(KNOWN_NAMES, key=len, reverse=True):
        pattern = rf"(?<![a-z0-9_]){re.escape(company_name)}(?![a-z0-9_])"
        if re.search(pattern, lower_message):
            return KNOWN_NAMES[company_name]
    return None
def extract_generic_ticker(message: str) -> Optional[str]:
    """Return the first ticker-looking token in *message*, or None.

    Tokens are matched with TICKER_PATTERN against the message's own casing.
    Upper-casing the whole message would turn every short word ("give",
    "what") into a candidate ticker. So only tokens the user wrote in capitals
    count, plus "$" cashtags of any case ("$tsla"), which are unambiguous.
    Matches that start inside a longer ASCII word (e.g. "LYSIS" in "ANALYSIS")
    and tokens listed in TICKER_STOPWORDS are skipped.
    """
    # Upper-case only the cashtags so "$tsla" is found as "TSLA".
    text = re.sub(
        r"\$([A-Za-z][A-Za-z.-]*)",
        lambda cashtag: "$" + cashtag.group(1).upper(),
        message,
    )
    for match in TICKER_PATTERN.finditer(text):
        start = match.start()
        if start and re.match(r"[A-Za-z0-9_]", text[start - 1]):
            continue  # tail of a longer word, not a standalone ticker
        candidate = match.group(1)
        if candidate in TICKER_STOPWORDS:
            continue
        return candidate
    return None
def extract_trade_date(message: str) -> str:
    """Return the analysis date requested in *message* as ``YYYY-MM-DD``.

    The first DATE_PATTERN match that is a real calendar date wins; impossible
    dates such as 2024-02-30 or 2024-13-01 are skipped instead of being handed
    to TradingAgents. Without a valid explicit date, a message containing
    "yesterday" yields yesterday's UTC date, and anything else yields today's
    UTC date.
    """
    # Local import: the module header only imports datetime and timedelta.
    from datetime import timezone

    for match in DATE_PATTERN.finditer(message):
        year, month, day = (int(part) for part in match.groups())
        try:
            return datetime(year, month, day).date().isoformat()
        except ValueError:
            continue  # matched the shape but is not a real date
    # datetime.utcnow() is deprecated since Python 3.12; use an aware "now".
    today = datetime.now(timezone.utc).date()
    if "yesterday" in message.lower():
        return (today - timedelta(days=1)).isoformat()
    return today.isoformat()
def build_tradingagents_response(
    symbol: str,
    trade_date: str,
    decision: str,
    full_state: Dict[str, Any],
) -> str:
    """Render a TradingAgents run as a Markdown chat reply.

    Args:
        symbol: Ticker that was analysed.
        trade_date: Analysis date (``YYYY-MM-DD``).
        decision: Final rating returned by TradingAgentsGraph.propagate.
        full_state: Final graph state; the keys "final_trade_decision",
            "investment_plan", "market_report", "sentiment_report",
            "news_report" and "fundamentals_report" are read if present.

    Returns:
        Markdown with a decision header, the portfolio-manager text and
        investment plan (long text truncated), and one-line analyst summaries.
        Sections are separated by blank lines.
    """
    sections = [
        "### TradingAgents Decision",
        f"- Symbol: {symbol}",
        f"- Analysis date: {trade_date}",
        f"- Final rating: {decision}",
        "",
        "### Portfolio Manager",
        _truncate_text(full_state.get("final_trade_decision")),
        "",
        "### Investment Plan",
        _truncate_text(full_state.get("investment_plan")),
        "",
        "### Analyst Highlights",
        f"- Market: {_summarize_text(full_state.get('market_report'))}",
        f"- Sentiment: {_summarize_text(full_state.get('sentiment_report'))}",
        f"- News: {_summarize_text(full_state.get('news_report'))}",
        f"- Fundamentals: {_summarize_text(full_state.get('fundamentals_report'))}",
    ]
    # The "" entries are intentional blank-line separators between sections.
    # They were previously filtered out, which made the sections run together.
    return "\n".join(sections)
def _truncate_text(text: Any, limit: int = 1800) -> str:
    """Whitespace-normalise *text* (via _clean_text) and cap its length.

    Empty or missing text yields a fixed placeholder. Text longer than *limit*
    is cut to ``limit - 3`` characters, trailing whitespace is removed, and
    "..." is appended.
    """
    flat = _clean_text(text)
    if not flat:
        return "No detailed portfolio-manager report was returned."
    if len(flat) > limit:
        flat = flat[: limit - 3].rstrip() + "..."
    return flat
def _summarize_text(text: Any, limit: int = 240) -> str:
    """Return a one-line preview of an analyst report.

    Whitespace is collapsed via _clean_text. Missing or empty text yields a
    placeholder, and text longer than *limit* is cut to ``limit - 3``
    characters (trailing whitespace removed) plus "...".
    """
    flat = _clean_text(text)
    if not flat:
        return "No analyst report returned."
    return flat if len(flat) <= limit else f"{flat[: limit - 3].rstrip()}..."
def _clean_text(text: Any) -> str:
    """Return ``str(text)`` with every whitespace run collapsed to one space.

    Leading and trailing whitespace is dropped; None becomes "".
    """
    return "" if text is None else " ".join(str(text).split())
def _prime_tradingagents_env() -> None:
    """Ensure credentials for the TradingAgents provider are in the environment.

    Only the "openai" provider is handled; other providers are left as-is.
    If OPENAI_API_KEY is missing but DEEPSEEK_API_KEY is set, the DeepSeek key
    is copied into ``os.environ["OPENAI_API_KEY"]``. This is a process-wide
    change, made so the OpenAI-compatible client can use the DeepSeek key.

    Raises:
        RuntimeError: provider is "openai" and neither key is available.
    """
    if get_tradingagents_provider() != "openai":
        return
    deepseek_key = os.getenv("DEEPSEEK_API_KEY")
    if deepseek_key and not os.getenv("OPENAI_API_KEY"):
        os.environ["OPENAI_API_KEY"] = deepseek_key
    if os.getenv("OPENAI_API_KEY"):
        return
    raise RuntimeError(
        "TradingAgents openai-compatible provider needs OPENAI_API_KEY or DEEPSEEK_API_KEY."
    )
def _tradingagents_is_configured() -> bool:
    """Return True if the env variable the selected provider needs is set.

    "openai" accepts either OPENAI_API_KEY or DEEPSEEK_API_KEY (the latter is
    copied over by _prime_tradingagents_env). Each other supported provider
    has one required variable; an unknown provider is never configured.
    """
    provider = get_tradingagents_provider()
    if provider == "openai":
        return bool(os.getenv("OPENAI_API_KEY") or os.getenv("DEEPSEEK_API_KEY"))
    env_name = {
        "google": "GOOGLE_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "xai": "XAI_API_KEY",
        "openrouter": "OPENROUTER_API_KEY",
        "ollama": "OLLAMA_HOST",
    }.get(provider)
    if env_name is None:
        return False
    return bool(os.getenv(env_name))
def _get_positive_int_env(name: str, default: int) -> int:
    """Read env variable *name* as a positive int.

    Returns *default* when the variable is unset or blank, is not an integer,
    or is zero or negative.
    """
    raw = os.getenv(name, "").strip()
    try:
        number = int(raw) if raw else default
    except ValueError:
        return default
    return number if number > 0 else default
def _resolve_tradingagents_model(override_name: str, default_model: str) -> str:
    """Pick the model name for one TradingAgents LLM slot.

    Precedence:
    1. the env variable *override_name*, if non-blank;
    2. DEEPSEEK_MODEL (default "deepseek-chat"), but only when DeepSeek is what
       actually serves the provider. That means the provider is "openai", a
       DEEPSEEK_API_KEY is set, and OPENAI_API_KEY is unset or equal to it
       (_prime_tradingagents_env copies the DeepSeek key into OPENAI_API_KEY);
    3. *default_model*.
    Previously any DEEPSEEK_API_KEY forced a DeepSeek model, even for other
    providers or alongside a separate OpenAI key.
    """
    override = os.getenv(override_name, "").strip()
    if override:
        return override
    deepseek_key = os.getenv("DEEPSEEK_API_KEY")
    if deepseek_key and get_tradingagents_provider() == "openai":
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key or openai_key == deepseek_key:
            return os.getenv("DEEPSEEK_MODEL", "").strip() or "deepseek-chat"
    return default_model
def _resolve_tradingagents_backend_url(default_backend_url: str) -> str:
    """Pick the API base URL TradingAgents should call.

    Precedence:
    1. TRADINGAGENTS_BACKEND_URL, if non-blank;
    2. DEEPSEEK_BASE_URL (default "https://api.deepseek.com"), but only when
       DeepSeek actually serves the provider. That means the provider is
       "openai", a DEEPSEEK_API_KEY is set, and OPENAI_API_KEY is unset or equal
       to it (_prime_tradingagents_env copies the DeepSeek key into
       OPENAI_API_KEY), so a real OpenAI key is never sent to DeepSeek;
    3. *default_backend_url*.
    """
    override = os.getenv("TRADINGAGENTS_BACKEND_URL", "").strip()
    if override:
        return override
    deepseek_key = os.getenv("DEEPSEEK_API_KEY")
    if deepseek_key and get_tradingagents_provider() == "openai":
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key or openai_key == deepseek_key:
            return os.getenv("DEEPSEEK_BASE_URL", "").strip() or "https://api.deepseek.com"
    return default_backend_url