"""GraphRAG agent server: FastAPI app streaming LangGraph ReAct-agent steps as NDJSON."""
| import os | |
| import sys | |
| import logging | |
| import json | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import uvicorn | |
| from fastapi.responses import StreamingResponse | |
| from langchain_core.messages import ToolMessage, AIMessage | |
| from langchain_openai import ChatOpenAI | |
| from langgraph.prebuilt import create_react_agent | |
| from tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool | |
# --- Configuration & Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# MCP server endpoint and API key; defaults target the docker-compose service
# name "mcp" with a dev-only key — override both in production.
MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
# No default on purpose: GraphRAGAgent.__init__ raises if this is unset.
LLM_API_KEY = os.getenv("LLM_API_KEY")
# --- System Prompt ---
# Injected into the ReAct agent as its system message; tool names listed here
# must match the tool classes wired up in GraphRAGAgent.__init__.
SYSTEM_PROMPT = """You are a helpful assistant for querying life sciences databases.
You have access to these tools:
- schema_search: Find relevant database tables and columns based on keywords
- find_join_path: Discover how to join tables together using the knowledge graph
- execute_query: Run SQL queries against the databases
Always use schema_search first to understand the available data, then construct appropriate SQL queries.
When querying, be specific about what tables and columns you're using."""
| # --- Agent Initialization --- | |
class GraphRAGAgent:
    """The core agent for handling GraphRAG queries using LangGraph."""

    def __init__(self):
        """Build the LLM, the MCP-backed tools, and the prebuilt ReAct graph.

        Raises:
            ValueError: If the LLM_API_KEY environment variable is not set.
        """
        if not LLM_API_KEY:
            raise ValueError("LLM_API_KEY environment variable not set.")
        llm = ChatOpenAI(api_key=LLM_API_KEY, model="gpt-4o-mini", temperature=0, max_retries=1)
        mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
        tools = [
            SchemaSearchTool(mcp_client=mcp_client),
            JoinPathFinderTool(mcp_client=mcp_client),
            QueryExecutorTool(mcp_client=mcp_client),
        ]
        # Use LangGraph's prebuilt create_react_agent for proper message handling
        self.graph = create_react_agent(llm, tools, state_modifier=SYSTEM_PROMPT)

    async def stream_query(self, question: str):
        """Process a question and stream intermediate steps as NDJSON chunks.

        Yields JSON strings (each terminated by a blank line) of the form
        {"type": "thought" | "observation" | "final_answer", "content": str}.
        """
        try:
            async for event in self.graph.astream(
                {"messages": [("user", question)]},
                stream_mode="values"
            ):
                # create_react_agent uses standard message format
                messages = event.get("messages", [])
                if not messages:
                    continue
                last_message = messages[-1]
                if isinstance(last_message, AIMessage) and last_message.tool_calls:
                    # Agent is deciding to call one or more tools; report each
                    # (the old code reported only tool_calls[0], silently
                    # dropping parallel tool calls).
                    for tool_call in last_message.tool_calls:
                        yield json.dumps({
                            "type": "thought",
                            "content": f"🤖 Calling tool `{tool_call['name']}` with args: {tool_call['args']}"
                        }) + "\n\n"
                elif isinstance(last_message, ToolMessage):
                    # A tool has returned its result
                    yield json.dumps({
                        "type": "observation",
                        "content": f"🛠️ Tool `{last_message.name}` returned:\n\n```\n{last_message.content}\n```"
                    }) + "\n\n"
                elif isinstance(last_message, AIMessage) and last_message.content:
                    # Final answer: AIMessage with content but no tool_calls
                    yield json.dumps({
                        "type": "final_answer",
                        "content": last_message.content
                    }) + "\n\n"
        except Exception as e:
            # Top-level boundary: surface a user-friendly message, log the rest.
            logger.error("Error in agent workflow: %s", e, exc_info=True)
            yield json.dumps({
                "type": "final_answer",
                # Plain string — the original was an f-string with no placeholders (F541).
                "content": "I encountered an error while processing your request. Please try rephrasing your question or asking something simpler."
            }) + "\n\n"
| # --- FastAPI Application --- | |
agent = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handle agent initialization on startup and log shutdown.

    FastAPI's ``lifespan=`` parameter expects an async context manager
    factory, so the ``@asynccontextmanager`` decorator (already imported at
    the top of the file) is required — without it startup fails.
    Initialization errors are logged rather than raised so the server still
    comes up and /health can report the failure.
    """
    global agent
    logger.info("Agent server startup...")
    try:
        agent = GraphRAGAgent()
        logger.info("GraphRAGAgent initialized successfully.")
    except ValueError as e:
        # Typically a missing LLM_API_KEY; keep serving so health checks work.
        logger.error("Agent initialization failed: %s", e)
    yield
    logger.info("Agent server shutdown.")
# Application instance; the lifespan handler initializes the global agent.
app = FastAPI(title="GraphRAG Agent Server", lifespan=lifespan)


class QueryRequest(BaseModel):
    """Request body for the query endpoint."""
    # The user's natural-language question to run through the agent.
    question: str
@app.post("/query")  # route decorator was missing, so the endpoint was never registered — TODO confirm path against callers
async def execute_query(request: QueryRequest) -> StreamingResponse:
    """Endpoint to receive questions and stream the agent's response.

    Returns an ``application/x-ndjson`` stream of JSON objects produced by
    ``GraphRAGAgent.stream_query``. If the agent failed to initialize at
    startup, a single error object is streamed instead.
    """
    if not agent:
        async def error_stream():
            yield json.dumps({"error": "Agent is not initialized. Check server logs."})
        # Use the same media type as the success path so clients parse uniformly.
        return StreamingResponse(error_stream(), media_type="application/x-ndjson")
    return StreamingResponse(agent.stream_query(request.question), media_type="application/x-ndjson")
@app.get("/health")  # route decorator was missing, so the endpoint was never registered — TODO confirm path against deployment probes
def health_check():
    """Health check endpoint.

    Returns a dict with overall status and whether the agent initialized,
    letting probes distinguish "server up" from "agent ready".
    """
    return {"status": "ok", "agent_initialized": agent is not None}
# --- Main Execution ---
def main():
    """Entry point: serve the FastAPI application with uvicorn on port 8001."""
    logger.info("Starting agent server...")
    # Bind on all interfaces so the container port mapping works.
    uvicorn.run(app, port=8001, host="0.0.0.0")


if __name__ == "__main__":
    main()