# Commit a0eb181 by ohmygaugh: "demo working"
import os
import sys
import logging
import json
from contextlib import asynccontextmanager
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
from fastapi.responses import StreamingResponse
from langchain_core.messages import ToolMessage, AIMessage
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool
# --- Configuration & Logging ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Endpoint and credentials, each overridable through the environment.
MCP_URL = os.environ.get("MCP_URL", "http://mcp:8000/mcp")
# NOTE(review): falls back to a hard-coded development key when MCP_API_KEY is
# unset; make sure production deployments always set it.
API_KEY = os.environ.get("MCP_API_KEY", "dev-key-123")
# No default: GraphRAGAgent refuses to start without it.
LLM_API_KEY = os.environ.get("LLM_API_KEY")
# --- System Prompt ---
# Instructions given to the ReAct agent. The tool names listed here must match
# the names the tool classes register with LangChain.
SYSTEM_PROMPT = (
    "You are a helpful assistant for querying life sciences databases.\n"
    "You have access to these tools:\n"
    "- schema_search: Find relevant database tables and columns based on keywords\n"
    "- find_join_path: Discover how to join tables together using the knowledge graph\n"
    "- execute_query: Run SQL queries against the databases\n"
    "Always use schema_search first to understand the available data, then construct appropriate SQL queries.\n"
    "When querying, be specific about what tables and columns you're using."
)
# --- Agent Initialization ---
class GraphRAGAgent:
    """The core agent for handling GraphRAG queries using LangGraph.

    Wraps a prebuilt LangGraph ReAct agent whose three tools (schema search,
    join-path discovery, SQL execution) share one MCPClient pointed at MCP_URL.
    """

    def __init__(self):
        """Build the LLM client, the MCP-backed tools and the ReAct graph.

        Raises:
            ValueError: If LLM_API_KEY is empty or unset.
        """
        if not LLM_API_KEY:
            raise ValueError("LLM_API_KEY environment variable not set.")
        # temperature=0 for the most deterministic output the API allows;
        # max_retries=1 caps how often a failed LLM call is retried.
        llm = ChatOpenAI(api_key=LLM_API_KEY, model="gpt-4o-mini", temperature=0, max_retries=1)
        mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
        tools = [
            SchemaSearchTool(mcp_client=mcp_client),
            JoinPathFinderTool(mcp_client=mcp_client),
            QueryExecutorTool(mcp_client=mcp_client),
        ]
        # Use LangGraph's prebuilt create_react_agent for proper message handling.
        # NOTE(review): newer langgraph releases deprecate `state_modifier` in
        # favour of `prompt`; confirm against the pinned version before upgrading.
        self.graph = create_react_agent(llm, tools, state_modifier=SYSTEM_PROMPT)

    @staticmethod
    def _events_for_message(message):
        """Yield the serialized stream events (zero or more) for one message."""
        if isinstance(message, AIMessage) and message.tool_calls:
            # The model may request several tools in one turn; report each one.
            for tool_call in message.tool_calls:
                yield json.dumps({
                    "type": "thought",
                    "content": f"🤖 Calling tool `{tool_call['name']}` with args: {tool_call['args']}"
                }) + "\n\n"
        elif isinstance(message, ToolMessage):
            # A tool has returned its result.
            yield json.dumps({
                "type": "observation",
                "content": f"🛠️ Tool `{message.name}` returned:\n\n```\n{message.content}\n```"
            }) + "\n\n"
        elif isinstance(message, AIMessage) and message.content:
            # AIMessage with content but no tool calls: the final answer.
            yield json.dumps({
                "type": "final_answer",
                "content": message.content
            }) + "\n\n"
        # Any other message (e.g. the echoed user input) produces no event.

    async def stream_query(self, question: str):
        """Process a question and stream the intermediate steps.

        Args:
            question: The user's natural-language question.

        Yields:
            JSON strings, each terminated by two newlines:
            - a "thought" event for every tool call the model requests,
            - an "observation" event for every tool result,
            - a "final_answer" event for the model's text reply.
            If the workflow raises, the error is logged and a generic
            "final_answer" apology is yielded instead.
        """
        # stream_mode="values" emits the full graph state after every step, and
        # one step can append several messages (e.g. one ToolMessage per
        # parallel tool call). Remember how many messages were already handled
        # so every new message is reported once, not just the last one.
        handled = 0
        try:
            async for event in self.graph.astream(
                {"messages": [("user", question)]},
                stream_mode="values"
            ):
                messages = event.get("messages", [])
                new_messages = messages[handled:]
                handled = len(messages)
                for message in new_messages:
                    for chunk in self._events_for_message(message):
                        yield chunk
        except Exception as e:
            logger.exception(f"Error in agent workflow: {e}")
            yield json.dumps({
                "type": "final_answer",
                "content": "I encountered an error while processing your request. Please try rephrasing your question or asking something simpler."
            }) + "\n\n"
# --- FastAPI Application ---
# Shared agent instance. It stays None if initialization fails; the /query and
# /health handlers check for that case.
agent = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the GraphRAGAgent at startup and log startup/shutdown.

    A ValueError raised by the GraphRAGAgent constructor is logged and leaves
    ``agent`` as None so the server still starts. Any other exception
    propagates out of startup.
    """
    global agent
    logger.info("Agent server startup...")
    try:
        agent = GraphRAGAgent()
    except ValueError as err:
        logger.error(f"Agent initialization failed: {err}")
    else:
        logger.info("GraphRAGAgent initialized successfully.")
    yield
    logger.info("Agent server shutdown.")
# The lifespan handler builds the agent when the server starts.
app = FastAPI(
    title="GraphRAG Agent Server",
    lifespan=lifespan,
)
class QueryRequest(BaseModel):
    """JSON body accepted by the POST /query endpoint."""

    question: str  # natural-language question forwarded to the agent
@app.post("/query")
async def execute_query(request: QueryRequest) -> StreamingResponse:
    """Endpoint to receive questions and stream the agent's response.

    The body streams the JSON events produced by the agent's stream_query.
    If the agent failed to initialize at startup, a single {"error": ...}
    event is streamed instead.
    """
    if not agent:
        async def error_stream():
            # Terminate the event with two newlines, like the agent's events,
            # so clients that split the stream on blank lines can parse it.
            yield json.dumps({"error": "Agent is not initialized. Check server logs."}) + "\n\n"

        # Use the same media type as the success path.
        return StreamingResponse(error_stream(), media_type="application/x-ndjson")
    return StreamingResponse(agent.stream_query(request.question), media_type="application/x-ndjson")
@app.get("/health")
def health_check():
    """Report liveness and whether the GraphRAG agent was created at startup."""
    initialized = agent is not None
    return {"status": "ok", "agent_initialized": initialized}
# --- Main Execution ---
def main():
    """Serve the FastAPI app with uvicorn on all interfaces, port 8001."""
    logger.info("Starting agent server...")
    bind_host, bind_port = "0.0.0.0", 8001
    uvicorn.run(app, host=bind_host, port=bind_port)


if __name__ == "__main__":
    main()