| | import os |
| | import sys |
| | import logging |
| | import json |
| | from contextlib import asynccontextmanager |
| | from fastapi import FastAPI |
| | from pydantic import BaseModel |
| | import uvicorn |
| | from fastapi.responses import StreamingResponse |
| |
|
| | from langchain_core.messages import ToolMessage, AIMessage |
| | from langchain_openai import ChatOpenAI |
| | from langgraph.prebuilt import create_react_agent |
| |
|
| | from tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool |
| |
|
| | |
# Basic process-wide logging; containers capture stdout/stderr directly.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Connection settings, overridable via environment. The defaults suit a
# docker-compose style setup (service name "mcp" on port 8000).
# LLM_API_KEY intentionally has no default: it is required and is
# validated in GraphRAGAgent.__init__, which raises if it is missing.
MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
LLM_API_KEY = os.getenv("LLM_API_KEY")
| |
|
| | |
# System prompt injected into the ReAct agent (see GraphRAGAgent.__init__).
# NOTE(review): the tool names listed here (schema_search, find_join_path,
# execute_query) presumably match the `name` attributes of the tool classes
# imported from `tools` — verify against tools.py if renaming anything.
SYSTEM_PROMPT = """You are a helpful assistant for querying life sciences databases.

You have access to these tools:
- schema_search: Find relevant database tables and columns based on keywords
- find_join_path: Discover how to join tables together using the knowledge graph
- execute_query: Run SQL queries against the databases

Always use schema_search first to understand the available data, then construct appropriate SQL queries.
When querying, be specific about what tables and columns you're using."""
| |
|
| | |
class GraphRAGAgent:
    """Core agent for handling GraphRAG queries using LangGraph.

    Wraps a ReAct-style agent (LLM + MCP-backed tools) and exposes an
    async generator that streams intermediate steps as newline-delimited
    JSON strings (each record terminated by a blank line).
    """

    def __init__(self):
        """Build the LLM, the MCP-backed tools, and the ReAct agent graph.

        Raises:
            ValueError: If the LLM_API_KEY environment variable is not set.
        """
        if not LLM_API_KEY:
            raise ValueError("LLM_API_KEY environment variable not set.")

        # temperature=0 for deterministic tool selection; a single retry
        # keeps latency bounded on transient API failures.
        llm = ChatOpenAI(api_key=LLM_API_KEY, model="gpt-4o-mini", temperature=0, max_retries=1)

        # All three tools share one MCP client (same URL and API key).
        mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
        tools = [
            SchemaSearchTool(mcp_client=mcp_client),
            JoinPathFinderTool(mcp_client=mcp_client),
            QueryExecutorTool(mcp_client=mcp_client),
        ]

        self.graph = create_react_agent(llm, tools, state_modifier=SYSTEM_PROMPT)

    async def stream_query(self, question: str):
        """Process *question* and yield one NDJSON event per agent step.

        Each yielded string is a JSON object with a "type" of "thought"
        (tool call requested), "observation" (tool result), or
        "final_answer" (the model's answer, or an error notice), followed
        by "\\n\\n".
        """
        try:
            async for event in self.graph.astream(
                {"messages": [("user", question)]},
                stream_mode="values"
            ):
                # stream_mode="values" yields the full state each step;
                # only the newest message is interesting.
                messages = event.get("messages", [])
                if not messages:
                    continue

                last_message = messages[-1]

                if isinstance(last_message, AIMessage) and last_message.tool_calls:
                    # Report every requested tool call, not just the first:
                    # models may batch several calls in one AIMessage.
                    for tool_call in last_message.tool_calls:
                        yield json.dumps({
                            "type": "thought",
                            "content": f"🤖 Calling tool `{tool_call['name']}` with args: {tool_call['args']}"
                        }) + "\n\n"
                elif isinstance(last_message, ToolMessage):
                    yield json.dumps({
                        "type": "observation",
                        "content": f"🛠️ Tool `{last_message.name}` returned:\n\n```\n{last_message.content}\n```"
                    }) + "\n\n"
                elif isinstance(last_message, AIMessage) and last_message.content:
                    # An AIMessage without tool calls but with content is
                    # the model's final answer.
                    yield json.dumps({
                        "type": "final_answer",
                        "content": last_message.content
                    }) + "\n\n"
        except Exception as e:
            # Surface a friendly message to the client; the full traceback
            # goes to the server log only.
            logger.error(f"Error in agent workflow: {e}", exc_info=True)
            yield json.dumps({
                "type": "final_answer",
                "content": "I encountered an error while processing your request. Please try rephrasing your question or asking something simpler."
            }) + "\n\n"
| |
|
| | |
| | agent = None |
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handle agent initialization on startup and log on shutdown.

    Any initialization failure is logged rather than raised, so the
    server still starts and /health can report the degraded state
    (agent_initialized: false).
    """
    global agent
    logger.info("Agent server startup...")
    try:
        agent = GraphRAGAgent()
        logger.info("GraphRAGAgent initialized successfully.")
    except Exception as e:
        # Previously only ValueError (missing LLM_API_KEY) was caught;
        # any other construction error (e.g. from the LLM client or MCP
        # tools) would have crashed startup. Catch broadly here so the
        # health endpoint stays reachable, and keep the traceback.
        logger.error(f"Agent initialization failed: {e}", exc_info=True)
    yield
    logger.info("Agent server shutdown.")
| |
|
| | app = FastAPI(title="GraphRAG Agent Server", lifespan=lifespan) |
| |
|
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    question: str  # Natural-language question to run through the agent.
| |
|
@app.post("/query")
async def execute_query(request: QueryRequest) -> StreamingResponse:
    """Endpoint to receive questions and stream the agent's response.

    Streams newline-delimited JSON events from GraphRAGAgent.stream_query.
    If the agent failed to initialize, a single NDJSON error record is
    streamed instead.
    """
    if not agent:
        async def error_stream():
            # Terminate the record with a blank line, matching the
            # framing used by the normal stream.
            yield json.dumps({"error": "Agent is not initialized. Check server logs."}) + "\n\n"
        # Bug fix: the error path previously omitted the NDJSON media
        # type (and record terminator) used by the success path, so
        # clients parsing the stream as NDJSON broke on errors.
        return StreamingResponse(error_stream(), media_type="application/x-ndjson")

    return StreamingResponse(agent.stream_query(request.question), media_type="application/x-ndjson")
| |
|
@app.get("/health")
def health_check():
    """Liveness probe; also reports whether the agent booted successfully."""
    initialized = agent is not None
    return {"status": "ok", "agent_initialized": initialized}
| |
|
| | |
def main():
    """Main entry point to run the FastAPI server."""
    logger.info("Starting agent server...")
    # Bind to all interfaces so the server is reachable from other
    # containers; port 8001 avoids clashing with the MCP service,
    # which defaults to port 8000 (see MCP_URL).
    uvicorn.run(app, host="0.0.0.0", port=8001)
| |
|
# Allow running the module directly (python <file>.py) in addition to an
# external ASGI runner importing `app`.
if __name__ == "__main__":
    main()