import os
import sys
import logging
import json
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import uvicorn
from langchain_core.messages import ToolMessage, AIMessage
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent

from tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool

# --- Configuration & Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# MCP endpoint and credentials come from the environment; the defaults match
# the docker-compose dev setup (service name "mcp", dev API key).
MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
# No default: a missing LLM key is a hard startup error (see GraphRAGAgent).
LLM_API_KEY = os.getenv("LLM_API_KEY")

# --- System Prompt ---
SYSTEM_PROMPT = """You are a helpful assistant for querying life sciences databases. You have access to these tools: - schema_search: Find relevant database tables and columns based on keywords - find_join_path: Discover how to join tables together using the knowledge graph - execute_query: Run SQL queries against the databases Always use schema_search first to understand the available data, then construct appropriate SQL queries. 
When querying, be specific about what tables and columns you're using."""


# --- Agent Initialization ---
class GraphRAGAgent:
    """The core agent for handling GraphRAG queries using LangGraph."""

    def __init__(self):
        """Build the ReAct graph: LLM + MCP-backed tools + system prompt.

        Raises:
            ValueError: if the LLM_API_KEY environment variable is unset.
        """
        if not LLM_API_KEY:
            raise ValueError("LLM_API_KEY environment variable not set.")
        # temperature=0 for deterministic tool selection; max_retries=1 keeps
        # failures fast so the stream can surface an error quickly.
        llm = ChatOpenAI(api_key=LLM_API_KEY, model="gpt-4o-mini", temperature=0, max_retries=1)
        mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
        tools = [
            SchemaSearchTool(mcp_client=mcp_client),
            JoinPathFinderTool(mcp_client=mcp_client),
            QueryExecutorTool(mcp_client=mcp_client),
        ]
        # Use LangGraph's prebuilt create_react_agent for proper message handling.
        # NOTE(review): `state_modifier` was renamed to `prompt` in newer
        # langgraph releases — confirm against the pinned langgraph version.
        self.graph = create_react_agent(llm, tools, state_modifier=SYSTEM_PROMPT)

    async def stream_query(self, question: str):
        """Process *question* and yield NDJSON records for each agent step.

        Each yielded record is a JSON object with a "type" field — "thought"
        (tool call decided), "observation" (tool result), or "final_answer" —
        and a "content" field, terminated by a blank line ("\\n\\n").

        Args:
            question: The user's natural-language question.

        Yields:
            str: One serialized JSON record per agent event.
        """
        try:
            # stream_mode="values" emits the full message list after each
            # graph step; only the newest message is interesting.
            async for event in self.graph.astream(
                {"messages": [("user", question)]}, stream_mode="values"
            ):
                messages = event.get("messages", [])
                if not messages:
                    continue
                last_message = messages[-1]
                if isinstance(last_message, AIMessage) and last_message.tool_calls:
                    # Agent is deciding to call a tool.
                    tool_call = last_message.tool_calls[0]
                    yield json.dumps({
                        "type": "thought",
                        "content": f"🤖 Calling tool `{tool_call['name']}` with args: {tool_call['args']}"
                    }) + "\n\n"
                elif isinstance(last_message, ToolMessage):
                    # A tool has returned its result.
                    yield json.dumps({
                        "type": "observation",
                        "content": f"🛠️ Tool `{last_message.name}` returned:\n\n```\n{last_message.content}\n```"
                    }) + "\n\n"
                elif isinstance(last_message, AIMessage) and last_message.content:
                    # Final answer: AIMessage with content but no tool_calls.
                    yield json.dumps({
                        "type": "final_answer",
                        "content": last_message.content
                    }) + "\n\n"
        except Exception as e:
            # Top-level boundary: log with traceback, surface a friendly
            # message to the client instead of crashing the stream.
            logger.error("Error in agent workflow: %s", e, exc_info=True)
            yield json.dumps({
                "type": "final_answer",
                "content": "I encountered an error while processing your request. Please try rephrasing your question or asking something simpler."
            }) + "\n\n"


# --- FastAPI Application ---
agent = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handles agent initialization on startup.

    A failed initialization (missing LLM key) is logged but does not abort
    startup — /health reports agent_initialized=False instead.
    """
    global agent
    logger.info("Agent server startup...")
    try:
        agent = GraphRAGAgent()
        logger.info("GraphRAGAgent initialized successfully.")
    except ValueError as e:
        logger.error("Agent initialization failed: %s", e)
    yield
    logger.info("Agent server shutdown.")


app = FastAPI(title="GraphRAG Agent Server", lifespan=lifespan)


class QueryRequest(BaseModel):
    # The user's natural-language question for the agent.
    question: str


@app.post("/query")
async def execute_query(request: QueryRequest) -> StreamingResponse:
    """Endpoint to receive questions and stream the agent's response."""
    if not agent:
        async def error_stream():
            # Terminate the record like every other stream message so NDJSON
            # clients can parse it (fix: delimiter was missing).
            yield json.dumps({"error": "Agent is not initialized. Check server logs."}) + "\n\n"
        # Fix: advertise the same media type as the success path.
        return StreamingResponse(error_stream(), media_type="application/x-ndjson")
    return StreamingResponse(agent.stream_query(request.question), media_type="application/x-ndjson")


@app.get("/health")
def health_check():
    """Health check endpoint."""
    return {"status": "ok", "agent_initialized": agent is not None}


# --- Main Execution ---
def main():
    """Main entry point to run the FastAPI server."""
    logger.info("Starting agent server...")
    uvicorn.run(app, host="0.0.0.0", port=8001)


if __name__ == "__main__":
    main()