File size: 5,494 Bytes
795ed51
12db5fc
86cbe3c
 
9d411a7
86cbe3c
 
 
 
398a370
a0eb181
9d411a7
a0eb181
398a370
9d411a7
398a370
86cbe3c
 
 
12db5fc
86cbe3c
 
9d411a7
12db5fc
a0eb181
 
 
 
 
 
 
 
 
 
12db5fc
86cbe3c
 
 
795ed51
86cbe3c
9d411a7
 
7faf776
9d411a7
84473fd
86cbe3c
 
 
 
 
84473fd
 
a0eb181
 
86cbe3c
 
 
a0eb181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86cbe3c
 
 
 
9d411a7
 
 
86cbe3c
9d411a7
86cbe3c
 
 
 
 
9d411a7
 
 
 
 
 
 
86cbe3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
795ed51
86cbe3c
 
 
795ed51
 
84473fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import sys
import logging
import json
from contextlib import asynccontextmanager
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
from fastapi.responses import StreamingResponse

from langchain_core.messages import ToolMessage, AIMessage
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent

from tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool

# --- Configuration & Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# MCP server endpoint and API key; defaults presumably target a docker-compose
# service named "mcp" — confirm against deployment config.
MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
# LLM API key has no default; GraphRAGAgent.__init__ raises ValueError when unset.
LLM_API_KEY = os.getenv("LLM_API_KEY")

# --- System Prompt ---
# Instructions prepended to every conversation; names the three tools the
# agent is given and the expected tool-use order (schema_search first).
SYSTEM_PROMPT = """You are a helpful assistant for querying life sciences databases.

You have access to these tools:
- schema_search: Find relevant database tables and columns based on keywords
- find_join_path: Discover how to join tables together using the knowledge graph
- execute_query: Run SQL queries against the databases

Always use schema_search first to understand the available data, then construct appropriate SQL queries.
When querying, be specific about what tables and columns you're using."""

# --- Agent Initialization ---
class GraphRAGAgent:
    """The core agent for handling GraphRAG queries using LangGraph.

    Wraps a ReAct agent built from a ChatOpenAI model and the MCP-backed
    tools, and exposes an async generator that streams the agent's
    intermediate steps as newline-delimited JSON strings.
    """

    def __init__(self):
        """Build the LLM, the MCP-backed tools, and the ReAct graph.

        Raises:
            ValueError: if the LLM_API_KEY environment variable is not set.
        """
        if not LLM_API_KEY:
            raise ValueError("LLM_API_KEY environment variable not set.")

        # temperature=0 for deterministic tool selection; a single retry keeps
        # failures fast so errors reach the stream handler quickly.
        llm = ChatOpenAI(api_key=LLM_API_KEY, model="gpt-4o-mini", temperature=0, max_retries=1)

        # All three tools share one MCP client configuration.
        mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
        tools = [
            SchemaSearchTool(mcp_client=mcp_client),
            JoinPathFinderTool(mcp_client=mcp_client),
            QueryExecutorTool(mcp_client=mcp_client),
        ]

        # Use LangGraph's prebuilt create_react_agent for proper message handling.
        # NOTE(review): `state_modifier` is deprecated in newer langgraph releases
        # in favor of `prompt` — confirm against the pinned langgraph version.
        self.graph = create_react_agent(llm, tools, state_modifier=SYSTEM_PROMPT)

    async def stream_query(self, question: str):
        """Process *question* and yield the agent's steps as JSON strings.

        Each yielded string is a JSON object with "type" ("thought",
        "observation", or "final_answer") and "content" keys, followed by a
        blank-line delimiter. Any workflow error is logged and surfaced to
        the client as a final_answer event rather than raised.
        """
        try:
            async for event in self.graph.astream(
                {"messages": [("user", question)]},
                stream_mode="values"
            ):
                # create_react_agent uses standard message format
                messages = event.get("messages", [])
                if not messages:
                    continue

                last_message = messages[-1]

                if isinstance(last_message, AIMessage) and last_message.tool_calls:
                    # Agent is deciding to call one or more tools. Report every
                    # call — the model may issue parallel tool calls, and the
                    # original code silently dropped all but the first.
                    for tool_call in last_message.tool_calls:
                        yield json.dumps({
                            "type": "thought",
                            "content": f"🤖 Calling tool `{tool_call['name']}` with args: {tool_call['args']}"
                        }) + "\n\n"
                elif isinstance(last_message, ToolMessage):
                    # A tool has returned its result
                    yield json.dumps({
                        "type": "observation",
                        "content": f"🛠️ Tool `{last_message.name}` returned:\n\n```\n{last_message.content}\n```"
                    }) + "\n\n"
                elif isinstance(last_message, AIMessage) and last_message.content:
                    # This is the final answer (AIMessage with content but no tool_calls)
                    yield json.dumps({
                        "type": "final_answer",
                        "content": last_message.content
                    }) + "\n\n"
        except Exception as e:
            # Top-level boundary: log the traceback, return a user-safe message.
            logger.error(f"Error in agent workflow: {e}", exc_info=True)
            yield json.dumps({
                "type": "final_answer",
                # (was a pointless f-string: no placeholders)
                "content": "I encountered an error while processing your request. Please try rephrasing your question or asking something simpler."
            }) + "\n\n"

# --- FastAPI Application ---
# Global agent singleton: populated by the lifespan handler at startup and
# left as None when initialization fails (endpoints check for this).
agent = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Construct the global GraphRAGAgent at startup; log the shutdown.

    A missing LLM_API_KEY (ValueError) is logged rather than raised so the
    server still comes up and endpoints can report the uninitialized state.
    """
    global agent
    logger.info("Agent server startup...")
    try:
        agent = GraphRAGAgent()
    except ValueError as e:
        logger.error(f"Agent initialization failed: {e}")
    else:
        logger.info("GraphRAGAgent initialized successfully.")
    yield
    logger.info("Agent server shutdown.")

# FastAPI app wired to the lifespan handler above for agent setup/teardown.
app = FastAPI(title="GraphRAG Agent Server", lifespan=lifespan)

class QueryRequest(BaseModel):
    """Request body for POST /query: the natural-language question to answer."""
    question: str

@app.post("/query")
async def execute_query(request: QueryRequest) -> StreamingResponse:
    """Endpoint to receive questions and stream the agent's response."""
    if not agent:
        async def error_stream():
            yield json.dumps({"error": "Agent is not initialized. Check server logs."})
        return StreamingResponse(error_stream())
    
    return StreamingResponse(agent.stream_query(request.question), media_type="application/x-ndjson")

@app.get("/health")
def health_check():
    """Health check endpoint."""
    return {"status": "ok", "agent_initialized": agent is not None}

# --- Main Execution ---
def main():
    """Main entry point to run the FastAPI server."""
    logger.info("Starting agent server...")
    # Bind all interfaces on 8001 — presumably to avoid the MCP service's
    # port 8000 (see MCP_URL default); confirm against deployment config.
    uvicorn.run(app, host="0.0.0.0", port=8001)

if __name__ == "__main__":
    main()