# Source: MarketLens / be / agent_tools.py
# Author: royzhou01 — initial deploy to Hugging Face Spaces (commit 332f271)
"""
Tool definitions and executor for the ReAct agent.
Maps Gemini function declarations to existing backend services.
Includes hybrid caching: frontend context -> server cache -> API call.
"""
import logging
import time
from datetime import datetime, timedelta, timezone

from polygon_api import PolygonAPI
from rag_pipeline import ContextRetriever
from sentiment_service import get_sentiment_service
from forecast_service import get_forecast_service
logger = logging.getLogger(__name__)
# -- Tool Schemas (Gemini function declarations) --

def _ticker_prop(description="Stock ticker symbol, e.g. AAPL"):
    """Schema fragment for a ticker-symbol string parameter (fresh dict each call)."""
    return {"type": "string", "description": description}


def _declare(name, description, properties, required):
    """Assemble a single Gemini function declaration."""
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": required,
        },
    }


TOOL_DECLARATIONS = [
    _declare(
        "get_stock_quote",
        "Get the most recent closing price, open, high, low, and volume for a stock ticker. Use this when the user asks about current price, today's price, or recent trading data.",
        {"ticker": _ticker_prop()},
        ["ticker"],
    ),
    _declare(
        "get_company_info",
        "Get detailed company information including name, description, market cap, sector, industry, and exchange. Use this when the user asks about what a company does, its sector, or general company details.",
        {"ticker": _ticker_prop()},
        ["ticker"],
    ),
    _declare(
        "get_financials",
        "Get recent financial statements including revenue, net income, gross profit, total assets, and liabilities. Returns the last 4 filing periods. Use this for questions about earnings, revenue, profitability, or balance sheet.",
        {"ticker": _ticker_prop()},
        ["ticker"],
    ),
    _declare(
        "get_news",
        "Get recent news articles about a stock. Returns headlines, sources, dates, and descriptions. Use this when the user asks about recent news, headlines, or events.",
        {
            "ticker": _ticker_prop(),
            "limit": {
                "type": "integer",
                "description": "Number of articles to return (default 10, max 20)",
            },
        },
        ["ticker"],
    ),
    _declare(
        "search_knowledge_base",
        "Semantic search over previously indexed news articles and research. Use this when the user asks about a specific topic and you need in-depth article content beyond headlines.",
        {
            "query": {
                "type": "string",
                "description": "Natural language search query",
            },
            "ticker": _ticker_prop("Stock ticker symbol to filter results"),
        },
        ["query", "ticker"],
    ),
    _declare(
        "analyze_sentiment",
        "Analyze social media sentiment for a stock by scraping StockTwits, Reddit, and Twitter posts and running FinBERT analysis. This operation takes 10-30 seconds. Use when the user asks about sentiment, social media buzz, or what people think about a stock.",
        {"ticker": _ticker_prop()},
        ["ticker"],
    ),
    _declare(
        "get_price_forecast",
        "Get an LSTM neural network price forecast for the next 30 trading days. May take 30-60 seconds if the model needs training. Use when the user asks about price predictions or forecasts.",
        {"ticker": _ticker_prop()},
        ["ticker"],
    ),
    _declare(
        "get_dividends",
        "Get dividend payment history including ex-dividend dates, pay dates, and amounts. Use when the user asks about dividends, yield, or dividend history.",
        {
            "ticker": _ticker_prop(),
            "limit": {
                "type": "integer",
                "description": "Number of dividend records to return (default 10)",
            },
        },
        ["ticker"],
    ),
    _declare(
        "get_stock_splits",
        "Get stock split history including execution dates and split ratios. Use when the user asks about stock splits.",
        {"ticker": _ticker_prop()},
        ["ticker"],
    ),
    _declare(
        "get_price_history",
        "Get historical OHLCV price data for a date range. Use when the user asks about price trends, historical performance, or needs to compare prices between dates.",
        {
            "ticker": _ticker_prop(),
            "from_date": {
                "type": "string",
                "description": "Start date in YYYY-MM-DD format",
            },
            "to_date": {
                "type": "string",
                "description": "End date in YYYY-MM-DD format",
            },
            "timespan": {
                "type": "string",
                "description": "Time interval for each bar: day, week, or month (default: day)",
            },
        },
        ["ticker", "from_date", "to_date"],
    ),
]
class ToolCache:
    """Server-side TTL cache for tool results (Layer 2).

    Entries expire ``ttl_seconds`` after being stored. Expired entries are
    evicted lazily on lookup so the cache does not grow without bound over
    the lifetime of a long-running server process.
    """

    def __init__(self, ttl_seconds=300):
        """
        Args:
            ttl_seconds: Lifetime of each cached entry, in seconds.
        """
        self._cache = {}
        self._ttl = ttl_seconds

    def get(self, key):
        """Return the cached value for ``key``, or None if absent or expired."""
        entry = self._cache.get(key)
        if entry is None:
            return None
        if time.time() - entry["ts"] < self._ttl:
            return entry["data"]
        # BUGFIX: previously expired entries were kept forever, leaking
        # memory one dead key at a time. Drop them on access.
        del self._cache[key]
        return None

    def set(self, key, data):
        """Store ``data`` under ``key``, stamped with the current time."""
        self._cache[key] = {"data": data, "ts": time.time()}
class ToolExecutor:
    """
    Executes tool calls using a 3-layer cache strategy:
    Layer 1: Frontend context (passed per-request via set_context)
    Layer 2: Server-side TTL cache (ToolCache, 5-min TTL)
    Layer 3: Live API / service call (last resort)
    """

    def __init__(self, polygon_api, context_retriever, vector_store):
        """
        Args:
            polygon_api: PolygonAPI client used for all market-data calls.
            context_retriever: retriever over the indexed article store.
            vector_store: vector store shared with the sentiment service.
        """
        self.polygon = polygon_api
        self.context_retriever = context_retriever
        self.vector_store = vector_store
        self.sentiment_service = get_sentiment_service(vector_store)
        self.forecast_service = get_forecast_service()
        self.server_cache = ToolCache(ttl_seconds=300)
        # Frontend context for the current request (Layer 1)
        self._frontend_context = {}
        self._context_ticker = None
        # Dispatch table: Gemini tool name -> bound handler method.
        self._handlers = {
            "get_stock_quote": self._get_stock_quote,
            "get_company_info": self._get_company_info,
            "get_financials": self._get_financials,
            "get_news": self._get_news,
            "search_knowledge_base": self._search_knowledge_base,
            "analyze_sentiment": self._analyze_sentiment,
            "get_price_forecast": self._get_price_forecast,
            "get_dividends": self._get_dividends,
            "get_stock_splits": self._get_stock_splits,
            "get_price_history": self._get_price_history,
        }

    def set_context(self, frontend_context, ticker):
        """Prime Layer 1 cache with frontend-provided data for this request."""
        self._frontend_context = frontend_context or {}
        self._context_ticker = ticker.upper() if ticker else None

    def execute(self, tool_name, args):
        """Dispatch a tool call by name; returns the result dict or an error dict."""
        handler = self._handlers.get(tool_name)
        if not handler:
            return {"error": f"Unknown tool: {tool_name}"}
        try:
            return handler(**args)
        except Exception as e:
            # BUGFIX: logger.exception keeps the traceback (logger.error with
            # an f-string discarded it); lazy %-args per logging convention.
            logger.exception("Tool %s failed", tool_name)
            return {"error": f"Tool execution failed: {str(e)}"}

    def _check_frontend_context(self, tool_name, ticker):
        """Layer 1: return data the frontend already sent for this tool, or None.

        Only valid when the requested ticker matches the ticker the context
        was primed with via set_context.
        """
        if ticker.upper() != self._context_ticker:
            return None
        mapping = {
            "get_stock_quote": "previousClose",
            "get_company_info": "details",
            "get_financials": "financials",
            "get_news": "news",
            "get_dividends": "dividends",
            "get_stock_splits": "splits",
            "analyze_sentiment": "sentiment",
        }
        context_key = mapping.get(tool_name)
        if not context_key:
            return None
        overview = self._frontend_context.get("overview", {})
        # Quote and company details are nested under the overview payload.
        if context_key in ("previousClose", "details"):
            data = overview.get(context_key)
        else:
            data = self._frontend_context.get(context_key)
        return data if data else None

    def _check_server_cache(self, tool_name, ticker):
        """Layer 2: Check server-side TTL cache."""
        return self.server_cache.get(f"{tool_name}:{ticker}")

    def _cache_result(self, tool_name, ticker, result):
        """Store result in server-side cache."""
        self.server_cache.set(f"{tool_name}:{ticker}", result)

    # -- Tool Handlers --

    def _get_stock_quote(self, ticker):
        """Most recent previous-close OHLCV bar for a ticker."""
        ticker = ticker.upper()
        # Layer 1: Frontend context
        fe_data = self._check_frontend_context("get_stock_quote", ticker)
        if fe_data:
            results = fe_data.get("results", [])
            if results:
                return self._format_quote(ticker, results[0], cached=True)
        # Layer 2: Server cache
        cached = self._check_server_cache("get_stock_quote", ticker)
        if cached:
            return cached
        # Layer 3: API call
        data = self.polygon.get_previous_close(ticker)
        results = data.get("results", [])
        if not results:
            return {"error": "No quote data available"}
        result = self._format_quote(ticker, results[0])
        self._cache_result("get_stock_quote", ticker, result)
        return result

    @staticmethod
    def _format_quote(ticker, bar, cached=False):
        """Map a Polygon aggregate bar to the quote result shape."""
        result = {
            "ticker": ticker,
            "close": bar.get("c"),
            "open": bar.get("o"),
            "high": bar.get("h"),
            "low": bar.get("l"),
            "volume": bar.get("v"),
            "vwap": bar.get("vw"),
        }
        if cached:
            result["source"] = "cached"
        return result

    def _get_company_info(self, ticker):
        """Company profile: name, description, market cap, sector, etc."""
        ticker = ticker.upper()
        fe_data = self._check_frontend_context("get_company_info", ticker)
        if fe_data:
            # Frontend may send either the raw API envelope or bare results.
            return self._format_company(ticker, fe_data.get("results", fe_data), cached=True)
        cached = self._check_server_cache("get_company_info", ticker)
        if cached:
            return cached
        data = self.polygon.get_ticker_details(ticker)
        r = data.get("results", {})
        if not r:
            return {"error": "No company data available"}
        result = self._format_company(ticker, r)
        self._cache_result("get_company_info", ticker, result)
        return result

    @staticmethod
    def _format_company(ticker, r, cached=False):
        """Map Polygon ticker details to the company-info result shape."""
        result = {
            "ticker": ticker,
            "name": r.get("name"),
            # None-safe slice: description may be present but null.
            "description": (r.get("description") or "")[:500],
            "market_cap": r.get("market_cap"),
            "sector": r.get("sic_description"),
            "homepage_url": r.get("homepage_url"),
            "total_employees": r.get("total_employees"),
        }
        if cached:
            result["source"] = "cached"
        return result

    def _get_financials(self, ticker):
        """Last 4 filing periods of income-statement / balance-sheet highlights."""
        ticker = ticker.upper()
        fe_data = self._check_frontend_context("get_financials", ticker)
        if fe_data:
            return self._format_financials(ticker, fe_data)
        cached = self._check_server_cache("get_financials", ticker)
        if cached:
            return cached
        data = self.polygon.get_financials(ticker)
        result = self._format_financials(ticker, data)
        self._cache_result("get_financials", ticker, result)
        return result

    def _format_financials(self, ticker, data):
        """Map Polygon financials to a compact per-period summary."""
        results = data.get("results", [])
        if not results:
            return {"error": "No financial data available"}
        periods = []
        for r in results[:4]:
            financials = r.get("financials", {})
            income = financials.get("income_statement", {})
            balance = financials.get("balance_sheet", {})
            periods.append({
                "period": f"{r.get('fiscal_period', '')} {r.get('fiscal_year', '')}",
                "revenue": income.get("revenues", {}).get("value"),
                "net_income": income.get("net_income_loss", {}).get("value"),
                "gross_profit": income.get("gross_profit", {}).get("value"),
                "total_assets": balance.get("assets", {}).get("value"),
                "total_liabilities": balance.get("liabilities", {}).get("value"),
            })
        return {"ticker": ticker, "periods": periods}

    def _get_news(self, ticker, limit=10):
        """Recent headlines for a ticker (limit clamped to 20)."""
        ticker = ticker.upper()
        limit = min(limit or 10, 20)
        fe_data = self._check_frontend_context("get_news", ticker)
        if fe_data:
            return self._format_news(ticker, fe_data, limit)
        # BUGFIX: key includes limit so a cached 5-article result is not
        # silently returned for a later 20-article request.
        cache_key = f"get_news:{ticker}:{limit}"
        cached = self.server_cache.get(cache_key)
        if cached:
            return cached
        data = self.polygon.get_ticker_news(ticker, limit=limit)
        result = self._format_news(ticker, data, limit)
        self.server_cache.set(cache_key, result)
        return result

    def _format_news(self, ticker, data, limit=10):
        """Map news articles (API envelope or bare list) to the result shape."""
        # BUGFIX: the old `data.get(..., data if isinstance(data, list) else [])`
        # raised AttributeError whenever data actually WAS a list.
        articles = data if isinstance(data, list) else data.get("results", [])
        if not articles:
            return {"error": "No news articles available"}
        formatted = []
        # BUGFIX: honor the requested limit; previously hard-capped at 10.
        for a in articles[:limit]:
            publisher = a.get("publisher")
            if isinstance(publisher, dict):
                source = publisher.get("name", "Unknown")
            else:
                source = a.get("publisher", "Unknown")
            formatted.append({
                "title": a.get("title", ""),
                "source": source,
                "published": (a.get("published_utc") or "")[:10],
                "description": (a.get("description") or "")[:200],
                "url": a.get("article_url", ""),
            })
        return {"ticker": ticker, "articles": formatted}

    def _search_knowledge_base(self, query, ticker):
        """Search FAISS vector store — no caching, always live search."""
        ticker = ticker.upper()
        contexts = self.context_retriever.retrieve_context(query, ticker)
        if not contexts:
            return {"message": "No relevant articles found in knowledge base. Try using get_news for recent headlines."}
        results = []
        for ctx in contexts[:5]:
            meta = ctx["metadata"]
            results.append({
                "title": meta.get("title", "Untitled"),
                "source": meta.get("source", "Unknown"),
                "date": (meta.get("published_date") or "")[:10],
                # Prefer full content; fall back to the stored preview.
                "content": (meta.get("full_content") or meta.get("content_preview", ""))[:500],
                "relevance_score": round(ctx["score"], 3),
            })
        return {"ticker": ticker, "results": results}

    def _analyze_sentiment(self, ticker):
        """FinBERT social-media sentiment; live path scrapes and is slow (10-30s)."""
        ticker = ticker.upper()
        # Layer 1: Frontend context (already-analyzed sentiment)
        fe_data = self._check_frontend_context("analyze_sentiment", ticker)
        if fe_data:
            return self._format_sentiment(
                ticker,
                fe_data.get("aggregate", fe_data),
                fe_data.get("posts", []),
                cached=True,
            )
        cached = self._check_server_cache("analyze_sentiment", ticker)
        if cached:
            return cached
        # Layer 3: Live scrape + analysis (slow, 10-30s)
        data = self.sentiment_service.analyze_ticker(ticker)
        result = self._format_sentiment(
            ticker, data.get("aggregate", {}), data.get("posts", [])
        )
        self._cache_result("analyze_sentiment", ticker, result)
        return result

    @staticmethod
    def _format_sentiment(ticker, aggregate, posts, cached=False):
        """Map a sentiment aggregate plus top posts to the tool result shape."""
        result = {
            "ticker": ticker,
            "overall_sentiment": aggregate.get("label"),
            "score": aggregate.get("score"),
            "confidence": aggregate.get("confidence"),
            "post_count": aggregate.get("post_count"),
            "sources": aggregate.get("sources", {}),
            "top_posts": [
                {
                    "platform": p.get("platform"),
                    "content": (p.get("content") or "")[:200],
                    # Posts may carry either a nested sentiment dict or a
                    # flat sentiment_label field (frontend variant).
                    "sentiment": p.get("sentiment", {}).get(
                        "label", p.get("sentiment_label", "")
                    ),
                }
                for p in posts[:5]
            ],
        }
        if cached:
            result["source"] = "cached"
        return result

    def _get_price_forecast(self, ticker):
        """LSTM 30-day forecast; only the first 10 days are returned."""
        ticker = ticker.upper()
        cached = self._check_server_cache("get_price_forecast", ticker)
        if cached:
            return cached
        data = self.forecast_service.get_forecast(ticker)
        if "error" in data:
            return {"error": data["error"]}
        forecast = data.get("forecast", [])
        result = {
            "ticker": ticker,
            "predictions": [
                {
                    "date": f.get("date"),
                    "predicted_close": f.get("predicted_close"),
                    "upper_bound": f.get("upper_bound"),
                    "lower_bound": f.get("lower_bound"),
                }
                for f in forecast[:10]  # First 10 days to keep context manageable
            ],
            "model_info": data.get("model_info", {}),
        }
        self._cache_result("get_price_forecast", ticker, result)
        return result

    def _get_dividends(self, ticker, limit=10):
        """Dividend history (ex-date, pay date, amount, frequency)."""
        ticker = ticker.upper()
        limit = limit or 10
        fe_data = self._check_frontend_context("get_dividends", ticker)
        if fe_data:
            return self._format_dividends(ticker, fe_data, limit)
        # BUGFIX: key includes limit so differently-sized requests don't
        # collide in the server cache.
        cache_key = f"get_dividends:{ticker}:{limit}"
        cached = self.server_cache.get(cache_key)
        if cached:
            return cached
        data = self.polygon.get_dividends(ticker, limit=limit)
        result = self._format_dividends(ticker, data, limit)
        self.server_cache.set(cache_key, result)
        return result

    def _format_dividends(self, ticker, data, limit=10):
        """Map dividend records (API envelope or bare list) to the result shape."""
        # BUGFIX: handle a bare list without calling .get on it.
        records = data if isinstance(data, list) else data.get("results", [])
        if not records:
            return {"message": "No dividend data available for this ticker."}
        formatted = []
        for d in records[:limit]:
            formatted.append({
                "ex_date": d.get("ex_dividend_date", ""),
                "pay_date": d.get("pay_date", ""),
                "amount": d.get("cash_amount"),
                "frequency": d.get("frequency"),
            })
        return {"ticker": ticker, "dividends": formatted}

    def _get_stock_splits(self, ticker):
        """Stock-split history with human-readable ratios."""
        ticker = ticker.upper()
        fe_data = self._check_frontend_context("get_stock_splits", ticker)
        if fe_data:
            return self._format_splits(ticker, fe_data)
        cached = self._check_server_cache("get_stock_splits", ticker)
        if cached:
            return cached
        data = self.polygon.get_splits(ticker)
        result = self._format_splits(ticker, data)
        self._cache_result("get_stock_splits", ticker, result)
        return result

    def _format_splits(self, ticker, data):
        """Map split records (API envelope or bare list) to the result shape."""
        # BUGFIX: handle a bare list without calling .get on it.
        records = data if isinstance(data, list) else data.get("results", [])
        if not records:
            return {"message": "No stock split history found for this ticker."}
        formatted = []
        for s in records[:10]:
            formatted.append({
                "execution_date": s.get("execution_date", ""),
                "split_from": s.get("split_from"),
                "split_to": s.get("split_to"),
                "ratio": f"{s.get('split_to', 1)}-for-{s.get('split_from', 1)}",
            })
        return {"ticker": ticker, "splits": formatted}

    def _get_price_history(self, ticker, from_date, to_date, timespan=None):
        """Historical OHLCV bars for a date range (day/week/month timespan)."""
        ticker = ticker.upper()
        timespan = timespan or "day"
        # Cache key carries the full query so different ranges don't collide.
        cache_key = f"get_price_history:{ticker}:{from_date}:{to_date}:{timespan}"
        cached = self.server_cache.get(cache_key)
        if cached:
            return cached
        data = self.polygon.get_aggregates(ticker, timespan=timespan, from_date=from_date, to_date=to_date)
        results = data.get("results", [])
        if not results:
            return {"error": "No price history available for the given date range"}
        bars = []
        for bar in results:
            bars.append({
                # BUGFIX: "t" is Unix ms in UTC; converting with the server's
                # local timezone could shift the date label by a day.
                "date": datetime.fromtimestamp(bar["t"] / 1000, tz=timezone.utc).strftime("%Y-%m-%d"),
                "open": bar.get("o"),
                "high": bar.get("h"),
                "low": bar.get("l"),
                "close": bar.get("c"),
                "volume": bar.get("v"),
            })
        result = {
            "ticker": ticker,
            "timespan": timespan,
            "from": from_date,
            "to": to_date,
            "bars": bars,
            "count": len(bars),
        }
        self.server_cache.set(cache_key, result)
        return result