Spaces:
Sleeping
Sleeping
import asyncio
import logging
import os
from pathlib import Path
from typing import List, Optional, TypedDict

from dotenv import load_dotenv
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_groq import ChatGroq
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import create_react_agent

from prompt import SYSTEM_PROMPT
from tools import TOML_TOOLS
# Module-level logger for agent lifecycle and chat diagnostics.
logger = logging.getLogger(__name__)
def _build_system_prompt(mcp_tool_count: int) -> str:
    """Return the agent's system prompt.

    When no MCP tools were loaded (DBHub failed to connect), a hard
    constraint is appended telling the model that 'execute_sql' does not
    exist and restricting it to the TOML-management tools.
    """
    if mcp_tool_count != 0:
        # DB backend is up — the base prompt is sufficient.
        return SYSTEM_PROMPT

    warning = (
        "\n\n"
        "CRITICAL SYSTEM CONSTRAINT — READ CAREFULLY:\n"
        "The database backend (DBHub MCP) failed to connect at startup. "
        "As a result, the tool 'execute_sql' is NOT registered and does NOT exist in your tool list. "
        "\n\nYOU MUST NOT call 'execute_sql' or any database query tool under ANY circumstances. "
        "Doing so will cause an immediate API error and crash the request. "
        "There are NO exceptions to this rule — even if the user asks you to query a database, you CANNOT do it right now. "
        "\n\nInstead, you MUST tell the user: 'The database connection is currently unavailable. "
        "Please restart the server or check the logs, then try again.' "
        "\n\nYou may ONLY use these tools: view_databases, add_database, edit_database, delete_database."
    )
    return SYSTEM_PROMPT + warning
# Load environment variables from .env before reading any configuration.
load_dotenv()

# Accept either a comma-separated GROQ_API_KEYS list or a single GROQ_API_KEY.
_raw_keys = os.getenv("GROQ_API_KEYS", "") or os.getenv("GROQ_API_KEY", "")
GROQ_API_KEYS: list[str] = [k.strip() for k in _raw_keys.split(",") if k.strip()]
_key_index = 0  # Position in the round-robin key rotation (see _rotate_api_key).
if not GROQ_API_KEYS:
    raise RuntimeError("No Groq API keys found. Set GROQ_API_KEYS in .env")
logger.info(f"[Agent] Loaded {len(GROQ_API_KEYS)} Groq API key(s).")

# DBHub MCP server, launched over stdio via npx; its DB sources come from TOML.
TOML_PATH = Path("/app/dbhub.toml").resolve()
MCP_SERVERS = {
    "dbhub": {
        "transport": "stdio",
        "command": "npx",
        # BUG FIX: args must be plain strings — passing a Path object here can
        # fail the MCP client's validation of stdio server parameters.
        "args": ["-y", "@bytebase/dbhub@latest", "--transport", "stdio", "--config", str(TOML_PATH)],
    }
}

# Conversation state is kept in-process only; a restart loses chat history.
MEMORY = MemorySaver()
logger.info("[Agent] Memory: MemorySaver (in-memory)")

# Mutable module state managed by initialize_agent / reinitialize_agent / chat.
workflow = None          # Compiled LangGraph agent (None until initialized).
mcp_client = None        # Active MultiServerMCPClient, if any.
tool_count = 0           # Number of MCP tools currently loaded.
_reinitializing = False  # Guard against concurrent hot-swap reinits.
_cached_tools = []       # Last full tool list (MCP + TOML) for graph rebuilds.
def _current_api_key() -> str:
    """Return the Groq API key at the current rotation position."""
    idx = _key_index % len(GROQ_API_KEYS)
    return GROQ_API_KEYS[idx]
def _rotate_api_key() -> str:
    """Advance the round-robin key index and return the newly selected key."""
    global _key_index
    next_idx = (_key_index + 1) % len(GROQ_API_KEYS)
    _key_index = next_idx
    logger.warning(f"[Agent] Rotated to Groq API key #{_key_index + 1}/{len(GROQ_API_KEYS)}")
    return GROQ_API_KEYS[next_idx]
def _create_llm(api_key: Optional[str] = None) -> ChatGroq:
    """Build a ChatGroq client.

    Args:
        api_key: Explicit Groq API key; when omitted, the key at the current
            rotation position is used.
    """
    return ChatGroq(
        api_key=api_key or _current_api_key(),
        model="llama-3.1-8b-instant",
        temperature=0.1,  # Low temperature for predictable tool-calling.
        max_retries=0,    # Retries are handled by key rotation in chat().
    )
async def _get_mcp_tools():
    """Connect to the DBHub MCP server and return its tool list.

    Reads the TOML config (best-effort) to learn the expected tool/source
    counts, then makes up to two connection attempts, tearing down any stale
    client first. Returns an empty list when all attempts fail — the caller
    degrades gracefully to TOML-management tools only.
    """
    global mcp_client
    toml_expected = 0
    toml_config: dict = {}
    try:
        import toml as toml_lib
        with open(TOML_PATH, "r") as f:
            toml_config = toml_lib.load(f)
        toml_expected = len(toml_config.get("tools", []))
    except Exception as toml_err:
        # Best-effort only: without the config we simply skip the
        # expected-count comparison below.
        logger.debug(f"[Agent] Could not read TOML config: {toml_err}")
    tools: list = []
    for attempt in range(2):
        # Tear down any previous client before reconnecting.
        if mcp_client:
            try:
                await mcp_client.__aexit__(None, None, None)
            except Exception:
                pass
            mcp_client = None
        try:
            mcp_client = MultiServerMCPClient(MCP_SERVERS)
            tools = await asyncio.wait_for(mcp_client.get_tools(), timeout=90.0)
        except Exception as conn_err:
            if isinstance(conn_err, asyncio.TimeoutError):
                # Message matches the 90s wait_for above (was mislabeled "30s").
                logger.warning(f"[Agent] MCP timed out after 90s (attempt {attempt + 1}/2).")
            else:
                logger.warning(f"[Agent] MCP connection error (attempt {attempt + 1}/2): {conn_err}")
            if attempt < 1:
                logger.info("[Agent] Retrying in 3s...")
                await asyncio.sleep(3)
            else:
                logger.error("[Agent] All MCP attempts failed. Starting with 0 MCP tools.")
            continue
        loaded = len(tools)
        loaded_names = {t.name for t in tools}
        label = f"{loaded}/{toml_expected}" if toml_expected else str(loaded)
        logger.info(f"[Agent] MCP tools loaded: {label} (attempt {attempt + 1}/2)")
        if toml_expected == 0 or loaded >= toml_expected:
            break
        # Fewer tools than configured sources — report which DBs look missing.
        all_sources = [s.get("id", "") for s in toml_config.get("sources", [])]
        missing = [
            src for src in all_sources
            if not any(src.replace("-", "_") in n or src in n for n in loaded_names)
        ]
        if missing:
            logger.warning(f"[Agent] Missing DB connections: {missing}")
        if attempt < 1:
            wait = 6
            logger.info(f"[Agent] Retrying tool load in {wait}s...")
            await asyncio.sleep(wait)
    return tools
def _build_graph(all_tools: list, mcp_tool_count: int = 0):
    """Construct a ReAct agent graph over the given tools.

    Supports both old and new langgraph releases, which renamed the
    system-prompt argument from `state_modifier` to `prompt`.
    """
    import inspect

    llm = _create_llm()
    system_prompt = _build_system_prompt(mcp_tool_count)
    params = inspect.signature(create_react_agent).parameters
    prompt_kw = "prompt" if "prompt" in params else "state_modifier"
    kwargs = {
        "model": llm,
        "tools": all_tools,
        prompt_kw: system_prompt,
        "checkpointer": MEMORY,
    }
    return create_react_agent(**kwargs)
async def initialize_agent():
    """One-time agent startup: load MCP + TOML tools and build the graph.

    Idempotent — calling again after `workflow` exists is a no-op.
    """
    # NOTE: MEMORY was listed in the original `global` but never reassigned.
    global workflow, tool_count, _cached_tools
    if workflow is not None:
        # Consistency fix: use the module logger instead of bare print().
        logger.info("[Agent] Already initialized, skipping...")
        return
    logger.info("[Agent] Initializing...")
    mcp_tools = await _get_mcp_tools()
    tool_count = len(mcp_tools)
    logger.info(f"[Agent] Loaded {tool_count} MCP tool(s) from DBHub:")
    for t in mcp_tools:
        logger.info(f" • {t.name}")
    logger.info(f"[Agent] Loaded {len(TOML_TOOLS)} TOML management tool(s):")
    for t in TOML_TOOLS:
        logger.info(f" • {t.name}")
    all_tools = mcp_tools + TOML_TOOLS
    _cached_tools = all_tools
    workflow = _build_graph(all_tools, mcp_tool_count=tool_count)
    logger.info("[Agent] ✔ Ready!\n")
async def reinitialize_agent():
    """Hot-swap the agent after a TOML config change.

    Rebuilds the tool list and graph; on failure the previous workflow is
    kept so in-flight chats keep working. Concurrent calls are coalesced
    via the `_reinitializing` flag.
    """
    global workflow, tool_count, _reinitializing, _cached_tools
    if _reinitializing:
        # Consistency fix: use the module logger instead of bare print().
        logger.info("[Agent] Reinit already in progress, skipping duplicate request.")
        return
    _reinitializing = True
    logger.info("[Agent] Reinitializing after TOML change (hot-swap)...")
    try:
        # Give DBHub a moment to pick up the new TOML before reconnecting.
        await asyncio.sleep(3)
        mcp_tools = await _get_mcp_tools()
        new_tool_count = len(mcp_tools)
        logger.info(f"[Agent] Tool count: {tool_count} → {new_tool_count}")
        for t in mcp_tools:
            logger.info(f" • {t.name}")
        all_tools = mcp_tools + TOML_TOOLS
        _cached_tools = all_tools
        # Build the replacement graph fully before swapping it in.
        new_workflow = _build_graph(all_tools, mcp_tool_count=new_tool_count)
        workflow = new_workflow
        tool_count = new_tool_count
        logger.info("[Agent] ✔ Reinitialized successfully!\n")
    except Exception as e:
        logger.error(f"[Agent] Reinit failed: {e}. Keeping existing workflow.")
    finally:
        _reinitializing = False
| # async def chat(query: str, thread_id: str) -> str: | |
| # global workflow, tool_count | |
| # MAX_WAIT_SECONDS = 90 | |
| # POLL_INTERVAL = 1 | |
| # waited = 0 | |
| # while workflow is None: | |
| # if waited >= MAX_WAIT_SECONDS: | |
| # raise RuntimeError( | |
| # "Agent initialization timed out after 90 seconds. " | |
| # "Please check the server logs and try again." | |
| # ) | |
| # print(f"[Chat] Agent reinitializing... waiting ({waited}s)") | |
| # await asyncio.sleep(POLL_INTERVAL) | |
| # waited += POLL_INTERVAL | |
| # last_exc = None | |
| # for attempt in range(len(GROQ_API_KEYS)): | |
| # try: | |
| # result = await workflow.ainvoke( | |
| # {"messages": [HumanMessage(content=query)]}, | |
| # config={"configurable": {"thread_id": thread_id}, "recursion_limit": 50}, | |
| # ) | |
| # break | |
| # except Exception as e: | |
| # err_str = str(e) | |
| # if "tool_use_failed" in err_str or "missing properties" in err_str: | |
| # logger.warning(f"[Chat] tool_use_failed — query was too vague: {query!r}") | |
| # return ( | |
| # "⚠️ Your request was a bit too vague for me to generate a query. " | |
| # "Could you be more specific? For example:\n" | |
| # "- **\"Show me all tables in analytics\"**\n" | |
| # "- **\"How many rows are in the users table in analytics?\"**\n" | |
| # "- **\"Preview data from analytics\"**" | |
| # ) | |
| # if any(code in err_str for code in ["429", "401", "rate_limit", "unauthorized", "invalid_api_key"]): | |
| # if attempt < len(GROQ_API_KEYS) - 1: | |
| # new_key = _rotate_api_key() | |
| # reason = "Rate limited" if "429" in err_str else "Invalid key" | |
| # logger.warning(f"[Chat] {reason} — rotating key and rebuilding agent...") | |
| # workflow = _build_graph(_cached_tools, mcp_tool_count=tool_count) | |
| # last_exc = e | |
| # continue | |
| # else: | |
| # logger.error("[Chat] All Groq API keys are exhausted or invalid.") | |
| # return "Temporary API error. All keys are either limited or incorrect. Please check your HF Secrets." | |
| # raise | |
| # else: | |
| # raise last_exc | |
async def chat(query: str, thread_id: str) -> str:
    """Run one chat turn, rotating Groq keys on rate-limit/auth failures.

    Waits up to 90s for the agent to finish initializing, then tries the
    request once per configured API key. On a 429/401-style error the key is
    rotated and the graph rebuilt; any other error is re-raised.

    Raises:
        RuntimeError: if the agent never finishes initializing.
    """
    global workflow, tool_count, _key_index
    waited = 0
    while workflow is None:
        if waited >= 90:
            raise RuntimeError("Agent initialization timed out.")
        await asyncio.sleep(1)
        waited += 1
    last_exc = None
    # Try every configured key before giving up.
    for attempt in range(len(GROQ_API_KEYS)):
        current_num = (_key_index % len(GROQ_API_KEYS)) + 1
        try:
            logger.info(f"[Chat] Attempt {attempt+1}/{len(GROQ_API_KEYS)} using Key #{current_num}")
            result = await workflow.ainvoke(
                {"messages": [HumanMessage(content=query)]},
                config={"configurable": {"thread_id": thread_id}, "recursion_limit": 50},
            )
            # BUG FIX: the original returned `self._extract_content(result)`,
            # but `self` does not exist in a module-level function, so every
            # successful invocation raised NameError. Extract the final
            # assistant message inline (handles both str and list-of-blocks
            # content formats).
            content = result["messages"][-1].content
            if isinstance(content, list):
                parts = [
                    block.get("text", "") if isinstance(block, dict) else str(block)
                    for block in content
                ]
                return "\n".join(parts).strip()
            return str(content)
        except Exception as e:
            err_str = str(e).lower()
            last_exc = e
            # Rate limit (429) or invalid/unauthorized key (401).
            is_fail = any(c in err_str for c in ["429", "401", "rate_limit", "unauthorized", "invalid_api_key"])
            if is_fail:
                if attempt < len(GROQ_API_KEYS) - 1:
                    reason = "RATE LIMITED" if "429" in err_str else "INVALID (401)"
                    logger.warning(f"[Chat] Key #{current_num} FAILED ({reason}). Rotating...")
                    _rotate_api_key()
                    # Rebuild the graph so the new key takes effect immediately.
                    workflow = _build_graph(_cached_tools, mcp_tool_count=tool_count)
                    continue
                logger.error(f"[Chat] ALL {len(GROQ_API_KEYS)} KEYS FAILED.")
                return f" All {len(GROQ_API_KEYS)} keys failed. Last error: {err_str}"
            logger.error(f"[Chat] Unexpected error: {err_str}")
            raise
    # Defensive: every loop path returns, raises, or continues; the original
    # had unreachable extraction code here.
    raise last_exc if last_exc else RuntimeError("Chat failed without an error.")
def get_agent_status() -> dict:
    """Snapshot of agent health for status/health-check endpoints."""
    ready = workflow is not None
    status = {
        "agent_ready": ready,
        "tools_loaded": tool_count,
        "reinitializing": _reinitializing,
    }
    return status
| def _extract_content(self, result: dict) -> str: | |
| last_message = result["messages"][-1] | |
| content = last_message.content | |
| if isinstance(content, list): | |
| return "\n".join([b.get("text", "") if isinstance(b, dict) else str(b) for b in content]).strip() | |
| return str(content) | |