Spaces:
Sleeping
Sleeping
| # backend.py | |
| """ | |
| Backend module for MCP Agent | |
| Handles all the MCP server connections, LLM setup, and agent logic | |
| """ | |
| import sys | |
| import os | |
| import re | |
| import asyncio | |
| from dotenv import load_dotenv | |
| from typing import Optional, Dict, List, Any | |
| from pathlib import Path # already imported | |
| here = Path(__file__).parent.resolve() | |
################### --- Auth setup --- ###################
##########################################################
# Load variables from a local .env file (if one exists) before reading the
# token. Without this call the imported `load_dotenv` is never used and a
# token kept only in .env is invisible to os.getenv below. By default
# load_dotenv() leaves variables that are already set untouched.
load_dotenv()

# Accept either of the two variable names for the token.
HF = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HF:
    print("WARNING: HF_TOKEN not set. The app will start, but model calls may fail.")
else:
    # Mirror the token under both names so whichever one a library reads is set.
    os.environ["HF_TOKEN"] = HF
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = HF
    try:
        from huggingface_hub import login
        login(token=HF)
    except Exception as exc:
        # Login is best-effort: the token is still exported above, so keep
        # going, but report the failure instead of hiding it.
        print(f"WARNING: Hugging Face login failed ({exc}); continuing with the env token only.")
| # --- LangChain / MCP --- | |
| from langgraph.prebuilt import create_react_agent | |
| from langchain_mcp_adapters.client import MultiServerMCPClient | |
| from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace | |
| from langchain_core.messages import HumanMessage | |
# First choice, then free/tiny fallbacks.
# - An unset, empty or whitespace-only HF_MODEL_ID falls back to the default
#   instead of producing an empty repo id as the first candidate.
# - dict.fromkeys() removes duplicates while keeping order, so an override
#   that equals one of the fallbacks is not listed (and probed) twice.
_DEFAULT_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
CANDIDATE_MODELS = list(
    dict.fromkeys(
        [
            (os.getenv("HF_MODEL_ID") or "").strip() or _DEFAULT_MODEL_ID,
            "HuggingFaceTB/SmolLM3-3B-Instruct",
            "Qwen/Qwen2.5-1.5B-Instruct",
            "microsoft/Phi-3-mini-4k-instruct",
        ]
    )
)

# System message placed in front of every conversation handed to the agent.
SYSTEM_PROMPT = (
    "You are an AI assistant with tools.\n"
    "- Use the arithmetic tools (`add`, `minus`, `multiply`, `divide`) for arithmetic or multi-step calculations.\n"
    "- Use the stock tools (`get_stock_price`, `get_market_summary`, `get_company_news`) for financial/market queries.\n"
    "- Otherwise, answer directly with your own knowledge.\n"
    "Be concise and accurate. Only call tools when they clearly help."
)
| ################### --- ROUTER helpers --- ################### | |
| ############################################################## | |
# Stock / market intent detection
def is_stock_query(q: str) -> bool:
    """Return True when *q* matches any of the stock/market keyword patterns.

    Every pattern is searched case-insensitively; the first hit wins.
    """
    keyword_groups = (
        r"\b(stock|share|price|ticker|market|nasdaq|dow|s&p|spy|qqq)\b",
        r"\b(AAPL|GOOGL|MSFT|TSLA|AMZN|META|NVDA|AMD)\b",  # Common tickers
        r"\$[A-Z]{1,5}\b",  # $SYMBOL format
        r"\b(trading|invest|portfolio|earnings|dividend)\b",
        r"\b(bull|bear|rally|crash|volatility)\b",
    )
    for regex in keyword_groups:
        if re.search(regex, q, re.I):
            return True
    return False
# Extract ticker symbol from query

# Upper-cased words the case-insensitive context patterns can capture that are
# ordinary English (or the context keywords themselves) rather than symbols.
# A few of these are also real tickers; users can still reach them with the
# explicit $SYMBOL form, which is never filtered.
_TICKER_STOPWORDS = frozenset({
    "A", "AN", "THE", "IS", "ARE", "WAS", "OF", "FOR", "TO", "IN", "ON", "AT",
    "AND", "OR", "MY", "ME", "IT", "ITS", "THIS", "THAT", "SOME", "ANY", "ALL",
    "OUR", "YOUR", "WHAT", "HOW", "GET", "NEW", "NOW", "TODAY", "LAST",
    "STOCK", "SHARE", "PRICE", "QUOTE", "NEWS",
})

# Phrases like "price of AAPL", "AAPL stock", "what is AAPL".
# Group 1 is always the candidate symbol.
_TICKER_CONTEXT_PATTERNS = tuple(
    re.compile(pattern, re.I)
    for pattern in (
        r"\b(?:price of|stock price of|quote for)\s+([A-Z]{1,5})\b",
        r"\b([A-Z]{1,5})\s+(?:stock|share|price|quote)",
        r"\b(?:what is|what's|get)\s+([A-Z]{1,5})\b",
    )
)

_DOLLAR_TICKER = re.compile(r"\$([A-Z]{1,5})\b", re.I)
_BARE_TICKER = re.compile(r"[A-Z]{2,5}")


def _first_ticker_candidate(pattern, text: str) -> Optional[str]:
    """Return the first group-1 capture of *pattern* in *text* that is not a stopword."""
    pos = 0
    while True:
        match = pattern.search(text, pos)
        if match is None:
            return None
        candidate = match.group(1).upper()
        if candidate not in _TICKER_STOPWORDS:
            return candidate
        # Resume right after the rejected word so a later word in the same
        # phrase (e.g. "the AAPL stock") can still be found.
        pos = match.end(1)


def extract_ticker(q: str) -> Optional[str]:
    """Extract a stock ticker symbol from a free-text query.

    Strategies are tried in order:
      1. ``$SYMBOL`` anywhere in the text (any case).
      2. Context phrases such as "price of X", "X stock" or "what is X"
         (any case), skipping captures that are common English words.
      3. The first standalone all-uppercase word of 2-5 letters.

    Args:
        q: The user's query text.

    Returns:
        The upper-cased symbol, or ``None`` when nothing looks like a ticker.
    """
    if not q:
        return None

    # Check for $SYMBOL format first
    dollar_match = _DOLLAR_TICKER.search(q)
    if dollar_match:
        return dollar_match.group(1).upper()

    # Check for common patterns like "price of AAPL" or "AAPL stock".
    # The patterns are case-insensitive, so "the"/"me" also fit the symbol
    # slot; the stopword filter keeps those from being returned as tickers.
    for pattern in _TICKER_CONTEXT_PATTERNS:
        candidate = _first_ticker_candidate(pattern, q)
        if candidate:
            return candidate

    # Look for standalone uppercase tickers. Surrounding punctuation is
    # stripped, and only purely alphabetic tokens qualify ("S&P" does not).
    for word in q.split():
        clean_word = word.strip(".,!?;:()[]{}\"'")
        if _BARE_TICKER.fullmatch(clean_word) and clean_word not in _TICKER_STOPWORDS:
            return clean_word
    return None
# Market-overview intent detection
def wants_market_summary(q: str) -> bool:
    """Return True when *q* asks about the market as a whole.

    Matches phrases such as "market summary", "how's the market",
    "nasdaq today" or "market indices", ignoring case.
    """
    summary_phrases = (
        r"\bmarket\s+(?:summary|overview|today|status)\b",
        r"\bhow(?:'s| is) the market\b",
        r"\b(?:dow|nasdaq|s&p)\s+(?:today|now)\b",
        r"\bmarket indices\b",
    )
    return any(re.compile(phrase, re.I).search(q) for phrase in summary_phrases)
# News intent detection
def wants_news(q: str) -> bool:
    """Return True when *q* contains a whole-word news keyword (any case)."""
    news_keywords = r"\b(news|headline|announcement|update)\b"
    hit = re.search(news_keywords, q, re.I)
    return hit is not None
def build_tool_map(tools):
    """Index *tools* by their ``name`` attribute.

    When two tools share a name, the later one replaces the earlier one.
    """
    by_name = {}
    for tool in tools:
        by_name[tool.name] = tool
    return by_name
def find_tool(name: str, tool_map: dict):
    """Look up a tool by name, ignoring case.

    A key matches when it equals *name* or ends with ``"/" + name``
    (both compared in lower case). The first matching entry in the map's
    iteration order is returned; ``None`` when nothing matches.
    """
    wanted = name.lower()
    suffix = "/" + wanted
    return next(
        (
            tool
            for key, tool in tool_map.items()
            if key.lower() == wanted or key.lower().endswith(suffix)
        ),
        None,
    )
async def build_chat_llm_with_fallback():
    """
    Return the first candidate chat model that answers a probe request.

    Each entry of CANDIDATE_MODELS is tried in order:
    - create HuggingFaceEndpoint + ChatHuggingFace
    - send a tiny 'ping' as a proper LC message to trigger routing
    On 402/Payment Required (or any other error) the next candidate is tried.

    Returns:
        The ChatHuggingFace model whose probe succeeded.

    Raises:
        RuntimeError: if no candidate could be initialized. The last
            underlying error is attached as the exception's cause.
    """
    last_err = None
    for mid in CANDIDATE_MODELS:
        if not mid:
            # An empty repo id can never resolve; skip it without a request.
            continue
        try:
            llm = HuggingFaceEndpoint(
                repo_id=mid,
                huggingfacehub_api_token=HF,
                temperature=0.1,
                max_new_tokens=256,
            )
            model = ChatHuggingFace(llm=llm)
            # PROBE with a valid message type; the reply itself is not used.
            await model.ainvoke([HumanMessage(content="ping")])
        except Exception as e:
            # Any failure means "try the next model". Payment errors only get
            # a more specific log line.
            last_err = e
            msg = str(e)
            if "402" in msg or "Payment Required" in msg:
                print(f"[LLM] {mid} requires payment; trying next...")
            else:
                print(f"[LLM] {mid} error: {e}; trying next...")
            continue
        print(f"[LLM] Using: {mid}")
        return model
    # Chain the last failure so its traceback is preserved for debugging.
    raise RuntimeError(
        f"Could not initialize any candidate model. Last error: {last_err}"
    ) from last_err
# Class to manage the MCP Agent (moved from main function)
class MCPAgent:
    """Lazily initialized bundle of MCP client, chat model and ReAct agent.

    Nothing is connected in ``__init__``; ``initialize()`` builds everything
    and ``process_message()`` calls it on first use. ``cleanup()`` closes the
    client (when it exposes ``close``) and resets the instance so it can be
    initialized again.
    """

    def __init__(self):
        self.client = None       # MultiServerMCPClient, set by initialize()
        self.agent = None        # object returned by create_react_agent
        self.tool_map = None     # tool name -> tool, from build_tool_map
        self.tools = None        # tools as returned by client.get_tools()
        self.model = None        # model from build_chat_llm_with_fallback
        self.initialized = False

    async def initialize(self):
        """Initialize the MCP client, tools, model and agent.

        Returns:
            The list of available tool names. Repeat calls skip the setup
            and return the same kind of list (previously they returned None).
        """
        if self.initialized:
            return list(self.tool_map or {})

        # Both servers are declared with an interpreter command, a script path
        # next to this file and the stdio transport, so no server has to be
        # started by hand beforehand.
        self.client = MultiServerMCPClient(
            {
                "arithmetic": {
                    "command": sys.executable,
                    "args": [str(here / "arithmetic_server.py")],
                    "transport": "stdio",
                },
                "stocks": {
                    "command": sys.executable,
                    "args": [str(here / "stock_server.py")],
                    "transport": "stdio",
                },
            }
        )

        # 1. MCP client + tools
        self.tools = await self.client.get_tools()
        self.tool_map = build_tool_map(self.tools)

        # 2. Build LLM with auto-fallback
        self.model = await build_chat_llm_with_fallback()

        # 3. Build the ReAct agent with MCP tools
        self.agent = create_react_agent(self.model, self.tools)

        self.initialized = True
        return list(self.tool_map.keys())  # Return available tools

    def _direct_stock_call(self, user_text: str):
        """Return ``(tool, args)`` when one stock tool can answer *user_text*, else None."""
        if not is_stock_query(user_text):
            return None

        if wants_market_summary(user_text):
            tool_name, args = "get_market_summary", {}
        else:
            ticker = extract_ticker(user_text)
            if not ticker:
                # No symbol to pass to a tool; let the agent handle it.
                return None
            if wants_news(user_text):
                tool_name, args = "get_company_news", {"symbol": ticker, "limit": 3}
            else:
                tool_name, args = "get_stock_price", {"symbol": ticker}

        tool = find_tool(tool_name, self.tool_map)
        if tool is None:
            return None
        return tool, args

    async def process_message(self, user_text: str, history: List[Dict]) -> str:
        """Process a single message and return the reply.

        Stock queries that map onto exactly one tool are answered by invoking
        that tool directly; everything else goes to the ReAct agent together
        with the system prompt and the prior conversation.

        Args:
            user_text: The new user message.
            history: Earlier messages as role/content dicts. ``None`` or an
                empty list means there is no prior conversation.

        Returns:
            The direct tool result, or the content of the agent's last message.
        """
        if not self.initialized:
            await self.initialize()

        # Try direct stock tool routing first
        direct = self._direct_stock_call(user_text)
        if direct is not None:
            tool, args = direct
            return await tool.ainvoke(args)

        # Fall back to agent for everything else
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        messages.extend(history or [])
        messages.append({"role": "user", "content": user_text})
        result = await self.agent.ainvoke({"messages": messages})
        return result["messages"][-1].content

    async def cleanup(self):
        """Close the client if it supports it and reset the instance.

        State is reset even when closing raises, so a later
        ``process_message`` re-initializes instead of reusing a closed client.
        """
        import inspect  # local import: only needed here

        client = self.client
        try:
            if client is not None:
                close = getattr(client, "close", None)
                if callable(close):
                    outcome = close()
                    # close() may be sync or async depending on the client.
                    if inspect.isawaitable(outcome):
                        await outcome
        finally:
            self.client = None
            self.agent = None
            self.tool_map = None
            self.tools = None
            self.model = None
            self.initialized = False
# Module-level slot holding the one shared MCPAgent (created on first use)
_agent_instance = None


def get_agent() -> MCPAgent:
    """Return the shared MCPAgent, constructing it on the first call."""
    global _agent_instance
    if _agent_instance is not None:
        return _agent_instance
    _agent_instance = MCPAgent()
    return _agent_instance