# DBHUB_MCP_SERVER / agent.py
# Hub header (preserved): Tamannathakur — "Update agent.py" — commit 57c4ad7 (verified)
import os
import asyncio
import logging
from typing import TypedDict, List
from pathlib import Path
from dotenv import load_dotenv
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_groq import ChatGroq
from langchain_core.messages import HumanMessage, BaseMessage
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from prompt import SYSTEM_PROMPT
from tools import TOML_TOOLS
# Module-level logger; handlers/levels are configured by the host application.
logger = logging.getLogger(__name__)
def _build_system_prompt(mcp_tool_count: int) -> str:
    """Return the system prompt, appending a hard "no database" constraint
    when zero MCP tools were loaded (i.e. DBHub failed to connect)."""
    if mcp_tool_count:
        return SYSTEM_PROMPT
    # DBHub connection failed at startup: forbid the model from calling
    # 'execute_sql' (which was never registered) to avoid API errors.
    warning = (
        "\n\n"
        "CRITICAL SYSTEM CONSTRAINT — READ CAREFULLY:\n"
        "The database backend (DBHub MCP) failed to connect at startup. "
        "As a result, the tool 'execute_sql' is NOT registered and does NOT exist in your tool list. "
        "\n\nYOU MUST NOT call 'execute_sql' or any database query tool under ANY circumstances. "
        "Doing so will cause an immediate API error and crash the request. "
        "There are NO exceptions to this rule — even if the user asks you to query a database, you CANNOT do it right now. "
        "\n\nInstead, you MUST tell the user: 'The database connection is currently unavailable. "
        "Please restart the server or check the logs, then try again.' "
        "\n\nYou may ONLY use these tools: view_databases, add_database, edit_database, delete_database."
    )
    return SYSTEM_PROMPT + warning
# --- Configuration & mutable module state -----------------------------------

# Load environment variables from a local .env file (no-op if absent).
load_dotenv()

# Accept a comma-separated GROQ_API_KEYS list, falling back to a single GROQ_API_KEY.
_raw_keys = os.getenv("GROQ_API_KEYS", "") or os.getenv("GROQ_API_KEY", "")
GROQ_API_KEYS: list[str] = [k.strip() for k in _raw_keys.split(",") if k.strip()]
_key_index = 0  # index of the currently active key; advanced by _rotate_api_key()
if not GROQ_API_KEYS:
    raise RuntimeError("No Groq API keys found. Set GROQ_API_KEYS in .env")
logger.info(f"[Agent] Loaded {len(GROQ_API_KEYS)} Groq API key(s).")

# Container-absolute path to the DBHub TOML configuration file.
TOML_PATH = Path("/app/dbhub.toml").resolve()

# DBHub MCP server, launched as a subprocess over stdio via npx.
# NOTE(review): TOML_PATH is a pathlib.Path inside the args list — confirm the
# stdio transport accepts PathLike args; str(TOML_PATH) would be safer.
MCP_SERVERS = {
    "dbhub": {
        "transport": "stdio",
        "command": "npx",
        "args": ["-y", "@bytebase/dbhub@latest", "--transport", "stdio", "--config", TOML_PATH],
    }
}

# In-memory conversation checkpointer (history is lost on process restart).
MEMORY = MemorySaver()
logger.info("[Agent] Memory: MemorySaver (in-memory)")

# Mutable module state, managed by initialize_agent()/reinitialize_agent().
workflow = None  # compiled LangGraph agent; None until initialization completes
mcp_client = None  # active MultiServerMCPClient, if any
tool_count = 0  # number of MCP tools loaded at the last (re)initialization
_reinitializing = False  # guards against concurrent reinitialize_agent() calls
_cached_tools = []  # last full tool list (MCP + TOML tools), reused for rebuilds
def _current_api_key() -> str:
    """Return the Groq API key selected by the current rotation index."""
    idx = _key_index % len(GROQ_API_KEYS)
    return GROQ_API_KEYS[idx]
def _rotate_api_key() -> str:
    """Advance to the next Groq API key (wrapping around) and return it."""
    global _key_index
    total = len(GROQ_API_KEYS)
    _key_index = (_key_index + 1) % total
    logger.warning(f"[Agent] Rotated to Groq API key #{_key_index + 1}/{total}")
    return GROQ_API_KEYS[_key_index]
def _create_llm(api_key: str = None) -> ChatGroq:
    """Build a low-temperature Groq chat model.

    Uses the given key if provided, otherwise the currently rotated key.
    max_retries=0 so key rotation in chat() handles failures instead.
    """
    chosen_key = api_key or _current_api_key()
    return ChatGroq(
        api_key=chosen_key,
        model="llama-3.1-8b-instant",
        temperature=0.1,
        max_retries=0,
    )
async def _get_mcp_tools():
    """Connect to the DBHub MCP server and return its tool list.

    Makes up to two connection attempts, tearing down any stale client first.
    If the TOML config is readable, the expected tool/source counts are used
    to decide whether a retry is warranted (some DB sources may come up late).

    Returns:
        list: the loaded MCP tools; empty when every attempt fails.
    """
    global mcp_client
    # Best-effort read of the TOML config to learn how many tools to expect.
    toml_expected = 0
    toml_config = {}
    try:
        import toml as toml_lib
        with open(TOML_PATH, "r") as f:
            toml_config = toml_lib.load(f)
        toml_expected = len(toml_config.get("tools", []))
    except Exception:
        pass  # config missing/unreadable: accept whatever loads
    tools: list = []
    for attempt in range(2):
        # Tear down any stale client before reconnecting.
        # NOTE(review): newer langchain-mcp-adapters clients are not async
        # context managers — confirm __aexit__ exists on the pinned version.
        if mcp_client:
            try:
                await mcp_client.__aexit__(None, None, None)
            except Exception:
                pass
            mcp_client = None
        try:
            mcp_client = MultiServerMCPClient(MCP_SERVERS)
            tools = await asyncio.wait_for(mcp_client.get_tools(), timeout=90.0)
        except Exception as conn_err:  # TimeoutError is an Exception; no tuple needed
            # FIX: the timeout message previously said "30s" while the actual
            # asyncio.wait_for timeout is 90 seconds.
            if isinstance(conn_err, asyncio.TimeoutError):
                logger.warning(f"[Agent] MCP timed out after 90s (attempt {attempt + 1}/2).")
            else:
                logger.warning(f"[Agent] MCP connection error (attempt {attempt + 1}/2): {conn_err}")
            if attempt < 1:
                logger.info("[Agent] Retrying in 3s...")
                await asyncio.sleep(3)
            else:
                logger.error("[Agent] All MCP attempts failed. Starting with 0 MCP tools.")
            continue
        loaded = len(tools)
        loaded_names = {t.name for t in tools}
        label = f"{loaded}/{toml_expected}" if toml_expected else str(loaded)
        logger.info(f"[Agent] MCP tools loaded: {label} (attempt {attempt + 1}/2)")
        # Done when we got everything we expected (or had no expectation).
        if toml_expected == 0 or loaded >= toml_expected:
            break
        # Some configured sources have no matching tool yet — log and maybe retry.
        all_sources = [s.get("id", "") for s in toml_config.get("sources", [])]
        missing = [
            src for src in all_sources
            if not any(src.replace("-", "_") in n or src in n for n in loaded_names)
        ]
        if missing:
            logger.warning(f"[Agent] Missing DB connections: {missing}")
        if attempt < 1:
            wait = 6
            logger.info(f"[Agent] Retrying tool load in {wait}s...")
            await asyncio.sleep(wait)
    return tools
def _build_graph(all_tools: list, mcp_tool_count: int = 0):
    """Compile a ReAct agent graph over the given tools.

    Supports both old and new create_react_agent APIs by inspecting whether
    the function takes a 'prompt' parameter (newer) or 'state_modifier' (older).
    """
    import inspect
    llm = _create_llm()
    system_prompt = _build_system_prompt(mcp_tool_count)
    kwargs = {
        "model": llm,
        "tools": all_tools,
        "checkpointer": MEMORY,
    }
    params = inspect.signature(create_react_agent).parameters
    if "prompt" in params:
        kwargs["prompt"] = system_prompt
    else:
        kwargs["state_modifier"] = system_prompt
    return create_react_agent(**kwargs)
async def initialize_agent():
    """One-time startup: load MCP + TOML tools and build the agent graph.

    Idempotent — returns immediately if the workflow already exists.
    Populates the module globals: workflow, tool_count, _cached_tools.
    """
    # FIX: dropped unused MEMORY from the global declaration (never reassigned)
    # and switched the stray print() to logger for consistent logging.
    global workflow, tool_count, _cached_tools
    if workflow is not None:
        logger.info("[Agent] Already initialized, skipping...")
        return
    logger.info("[Agent] Initializing...")
    mcp_tools = await _get_mcp_tools()
    tool_count = len(mcp_tools)
    logger.info(f"[Agent] Loaded {tool_count} MCP tool(s) from DBHub:")
    for t in mcp_tools:
        logger.info(f" • {t.name}")
    logger.info(f"[Agent] Loaded {len(TOML_TOOLS)} TOML management tool(s):")
    for t in TOML_TOOLS:
        logger.info(f" • {t.name}")
    all_tools = mcp_tools + TOML_TOOLS
    _cached_tools = all_tools
    workflow = _build_graph(all_tools, mcp_tool_count=tool_count)
    logger.info("[Agent] ✔ Ready!\n")
async def reinitialize_agent():
    """Hot-swap the agent after a TOML config change.

    Reloads MCP tools and rebuilds the graph; on any failure the previous
    workflow is kept. A module-level flag suppresses concurrent reinits.
    """
    global workflow, tool_count, _reinitializing, _cached_tools
    if _reinitializing:
        logger.info("[Agent] Reinit already in progress, skipping duplicate request.")
        return
    _reinitializing = True
    logger.info("[Agent] Reinitializing after TOML change (hot-swap)...")
    try:
        # Give DBHub a moment to pick up the rewritten TOML before reconnecting.
        await asyncio.sleep(3)
        mcp_tools = await _get_mcp_tools()
        new_tool_count = len(mcp_tools)
        # FIX: log line previously rendered as "{old}{new}" with no separator.
        logger.info(f"[Agent] Tool count: {tool_count} → {new_tool_count}")
        for t in mcp_tools:
            logger.info(f" • {t.name}")
        all_tools = mcp_tools + TOML_TOOLS
        _cached_tools = all_tools
        new_workflow = _build_graph(all_tools, mcp_tool_count=new_tool_count)
        workflow = new_workflow
        tool_count = new_tool_count
        logger.info("[Agent] ✔ Reinitialized successfully!\n")
    except Exception as e:
        logger.error(f"[Agent] Reinit failed: {e}. Keeping existing workflow.")
    finally:
        _reinitializing = False
# async def chat(query: str, thread_id: str) -> str:
# global workflow, tool_count
# MAX_WAIT_SECONDS = 90
# POLL_INTERVAL = 1
# waited = 0
# while workflow is None:
# if waited >= MAX_WAIT_SECONDS:
# raise RuntimeError(
# "Agent initialization timed out after 90 seconds. "
# "Please check the server logs and try again."
# )
# print(f"[Chat] Agent reinitializing... waiting ({waited}s)")
# await asyncio.sleep(POLL_INTERVAL)
# waited += POLL_INTERVAL
# last_exc = None
# for attempt in range(len(GROQ_API_KEYS)):
# try:
# result = await workflow.ainvoke(
# {"messages": [HumanMessage(content=query)]},
# config={"configurable": {"thread_id": thread_id}, "recursion_limit": 50},
# )
# break
# except Exception as e:
# err_str = str(e)
# if "tool_use_failed" in err_str or "missing properties" in err_str:
# logger.warning(f"[Chat] tool_use_failed — query was too vague: {query!r}")
# return (
# "⚠️ Your request was a bit too vague for me to generate a query. "
# "Could you be more specific? For example:\n"
# "- **\"Show me all tables in analytics\"**\n"
# "- **\"How many rows are in the users table in analytics?\"**\n"
# "- **\"Preview data from analytics\"**"
# )
# if any(code in err_str for code in ["429", "401", "rate_limit", "unauthorized", "invalid_api_key"]):
# if attempt < len(GROQ_API_KEYS) - 1:
# new_key = _rotate_api_key()
# reason = "Rate limited" if "429" in err_str else "Invalid key"
# logger.warning(f"[Chat] {reason} — rotating key and rebuilding agent...")
# workflow = _build_graph(_cached_tools, mcp_tool_count=tool_count)
# last_exc = e
# continue
# else:
# logger.error("[Chat] All Groq API keys are exhausted or invalid.")
# return "Temporary API error. All keys are either limited or incorrect. Please check your HF Secrets."
# raise
# else:
# raise last_exc
async def chat(query: str, thread_id: str) -> str:
    """Run one chat turn against the agent, rotating Groq keys on failures.

    Waits up to 90s for initialization, then tries each configured API key
    in turn. Rate-limit (429) / auth (401) errors rotate the key and rebuild
    the graph; any other error is re-raised.

    Raises:
        RuntimeError: if the agent never finishes initializing.
    """
    global workflow, tool_count, _key_index
    # Block until initialize_agent()/reinitialize_agent() has produced a graph.
    waited = 0
    while workflow is None:
        if waited >= 90:
            raise RuntimeError("Agent initialization timed out.")
        await asyncio.sleep(1)
        waited += 1
    last_exc = None
    # Try every configured key before giving up.
    for attempt in range(len(GROQ_API_KEYS)):
        current_num = (_key_index % len(GROQ_API_KEYS)) + 1
        try:
            logger.info(f"[Chat] Attempt {attempt+1}/{len(GROQ_API_KEYS)} using Key #{current_num}")
            result = await workflow.ainvoke(
                {"messages": [HumanMessage(content=query)]},
                config={"configurable": {"thread_id": thread_id}, "recursion_limit": 50},
            )
        except Exception as e:
            err_str = str(e).lower()
            last_exc = e
            # Rate limit (429) or invalid/unauthorized key (401)?
            is_fail = any(c in err_str for c in ["429", "401", "rate_limit", "unauthorized", "invalid_api_key"])
            if is_fail:
                if attempt < len(GROQ_API_KEYS) - 1:
                    reason = "RATE LIMITED" if "429" in err_str else "INVALID (401)"
                    logger.warning(f"[Chat] Key #{current_num} FAILED ({reason}). Rotating...")
                    _rotate_api_key()
                    # Rebuild the graph so the new key takes effect immediately.
                    workflow = _build_graph(_cached_tools, mcp_tool_count=tool_count)
                    continue
                logger.error(f"[Chat] ALL {len(GROQ_API_KEYS)} KEYS FAILED.")
                return f" All {len(GROQ_API_KEYS)} keys failed. Last error: {err_str}"
            logger.error(f"[Chat] Unexpected error: {err_str}")
            raise
        # FIX: the original returned self._extract_content(result) — a NameError
        # in a module-level function — leaving the extraction code below it
        # unreachable. The extraction is now inlined on the success path.
        last_message = result["messages"][-1]
        content = last_message.content
        if isinstance(content, list):
            # Structured content: flatten text blocks to newline-joined text.
            parts = [
                block.get("text", "") if isinstance(block, dict) else str(block)
                for block in content
            ]
            return "\n".join(parts).strip()
        logger.debug(f"[Chat] CONTENT: {str(content)}")
        return str(content)
    # Defensive: loop always returns or raises above (GROQ_API_KEYS is non-empty).
    raise last_exc
def get_agent_status() -> dict:
    """Snapshot of agent state for health/status endpoints."""
    ready = workflow is not None
    return {
        "agent_ready": ready,
        "tools_loaded": tool_count,
        "reinitializing": _reinitializing,
    }
def _extract_content(self, result: dict) -> str:
last_message = result["messages"][-1]
content = last_message.content
if isinstance(content, list):
return "\n".join([b.get("text", "") if isinstance(b, dict) else str(b) for b in content]).strip()
return str(content)