import os
import sys
import logging
import json
from contextlib import asynccontextmanager
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
from fastapi.responses import StreamingResponse
from langchain_core.messages import ToolMessage, AIMessage
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool
# --- Configuration & Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# MCP gateway endpoint and API key; defaults target the docker-compose service name.
MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
# No default on purpose: GraphRAGAgent.__init__ raises if this is unset.
LLM_API_KEY = os.getenv("LLM_API_KEY")
# --- System Prompt ---
SYSTEM_PROMPT = """You are a helpful assistant for querying life sciences databases.
You have access to these tools:
- schema_search: Find relevant database tables and columns based on keywords
- find_join_path: Discover how to join tables together using the knowledge graph
- execute_query: Run SQL queries against the databases
Always use schema_search first to understand the available data, then construct appropriate SQL queries.
When querying, be specific about what tables and columns you're using."""
# --- Agent Initialization ---
class GraphRAGAgent:
    """The core agent for handling GraphRAG queries using LangGraph.

    Wraps a prebuilt ReAct agent (LLM + MCP-backed tools) and exposes an
    async generator that streams intermediate reasoning steps as
    newline-delimited JSON chunks.
    """

    def __init__(self):
        """Build the LLM, the MCP tool set, and the ReAct agent graph.

        Raises:
            ValueError: if the LLM_API_KEY environment variable is unset.
        """
        if not LLM_API_KEY:
            raise ValueError("LLM_API_KEY environment variable not set.")
        llm = ChatOpenAI(api_key=LLM_API_KEY, model="gpt-4o-mini", temperature=0, max_retries=1)
        mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
        tools = [
            SchemaSearchTool(mcp_client=mcp_client),
            JoinPathFinderTool(mcp_client=mcp_client),
            QueryExecutorTool(mcp_client=mcp_client),
        ]
        # Use LangGraph's prebuilt create_react_agent for proper message handling
        self.graph = create_react_agent(llm, tools, state_modifier=SYSTEM_PROMPT)

    async def stream_query(self, question: str):
        """Process *question* and yield JSON progress events.

        Each yielded chunk is a JSON object with a "type" field
        ("thought", "observation", or "final_answer") and a "content"
        field, followed by a blank-line separator for stream framing.
        """
        try:
            async for event in self.graph.astream(
                {"messages": [("user", question)]},
                stream_mode="values"
            ):
                # stream_mode="values" emits the full state on each step;
                # only the most recent message is of interest here.
                messages = event.get("messages", [])
                if not messages:
                    continue
                last_message = messages[-1]
                if isinstance(last_message, AIMessage) and last_message.tool_calls:
                    # Agent is deciding to call a tool
                    tool_call = last_message.tool_calls[0]
                    yield json.dumps({
                        "type": "thought",
                        "content": f"🤖 Calling tool `{tool_call['name']}` with args: {tool_call['args']}"
                    }) + "\n\n"
                elif isinstance(last_message, ToolMessage):
                    # A tool has returned its result
                    yield json.dumps({
                        "type": "observation",
                        "content": f"🛠️ Tool `{last_message.name}` returned:\n\n```\n{last_message.content}\n```"
                    }) + "\n\n"
                elif isinstance(last_message, AIMessage) and last_message.content:
                    # Final answer: an AIMessage with content but no tool_calls
                    yield json.dumps({
                        "type": "final_answer",
                        "content": last_message.content
                    }) + "\n\n"
        except Exception as e:
            # Top-level boundary: log the traceback, then surface a generic
            # message to the client instead of killing the stream.
            logger.error(f"Error in agent workflow: {e}", exc_info=True)
            # Fix: was an f-string with no placeholders (ruff F541).
            yield json.dumps({
                "type": "final_answer",
                "content": "I encountered an error while processing your request. Please try rephrasing your question or asking something simpler."
            }) + "\n\n"
# --- FastAPI Application ---
# Module-level agent singleton; populated by the lifespan handler on startup
# and checked by the /query and /health endpoints.
agent = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Construct the global agent on startup; log on shutdown.

    Initialization failure (missing LLM_API_KEY) is logged rather than
    raised, so the server still comes up and /health can report status.
    """
    global agent
    logger.info("Agent server startup...")
    try:
        agent = GraphRAGAgent()
    except ValueError as e:
        logger.error(f"Agent initialization failed: {e}")
    else:
        logger.info("GraphRAGAgent initialized successfully.")
    yield
    logger.info("Agent server shutdown.")
# Wire the lifespan handler so the agent is built before requests arrive.
app = FastAPI(title="GraphRAG Agent Server", lifespan=lifespan)
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    # Natural-language question to run through the agent.
    question: str
@app.post("/query")
async def execute_query(request: QueryRequest) -> StreamingResponse:
    """Endpoint to receive questions and stream the agent's response.

    Returns a newline-delimited JSON stream of progress events. If the
    agent failed to initialize at startup, streams a single error object
    instead of processing the question.
    """
    if not agent:
        async def error_stream():
            # Fix: frame the error chunk like the success stream so
            # line-oriented clients parse it the same way.
            yield json.dumps({"error": "Agent is not initialized. Check server logs."}) + "\n\n"
        # Fix: advertise the same media type as the success path; the
        # original error response defaulted to text/plain-style handling.
        return StreamingResponse(error_stream(), media_type="application/x-ndjson")
    return StreamingResponse(agent.stream_query(request.question), media_type="application/x-ndjson")
@app.get("/health")
def health_check():
    """Report service liveness and whether the agent is ready."""
    initialized = agent is not None
    return {"status": "ok", "agent_initialized": initialized}
# --- Main Execution ---
def main():
    """Main entry point: serve the FastAPI app on all interfaces, port 8001."""
    logger.info("Starting agent server...")
    uvicorn.run(app, host="0.0.0.0", port=8001)


# Fix: removed trailing "|" scrape artifact after main(), which was a
# syntax error in the original last line.
if __name__ == "__main__":
    main()