ohmygaugh committed
Commit 86cbe3c · 1 Parent(s): 398a370

First pass at merging the code bases (does not run yet)
.gitignore CHANGED
@@ -3,3 +3,6 @@
  /neo4j/data
  /postgres/data
  /neo4j/logs
+ .env
+ mcp/__pycache__
+ mcp/core/__pycache__
Makefile CHANGED
@@ -1,4 +1,4 @@
- .PHONY: up down logs seed clean health test demo
+ .PHONY: up down logs seed seed-db ingest clean health test demo

  # Start all services
  up:
@@ -19,12 +19,19 @@ logs:
  seed:
  	docker-compose exec mcp python /app/ops/scripts/seed.py

+ # Seed the SQLite databases from scratch
+ seed-db:
+ 	python ops/scripts/generate_sample_databases.py
+
+ # Ingest SQLite database schemas into Neo4j
+ ingest:
+ 	docker-compose run --rm mcp python /app/ops/scripts/ingest.py
+
  # Clean everything (including volumes)
  clean:
  	docker-compose down -v
  	docker system prune -f
  	@if [ -d "neo4j/data" ]; then rm -rf neo4j/data; fi
- 	@if [ -d "postgres/data" ]; then rm -rf postgres/data; fi
  	@if [ -d "frontend/.next" ]; then rm -rf frontend/.next; fi
  	@if [ -d "frontend/node_modules" ]; then rm -rf frontend/node_modules; fi

@@ -32,7 +39,6 @@ clean:
  health:
  	@echo "Checking service health..."
  	@docker-compose exec neo4j cypher-shell -u neo4j -p password "MATCH (n) RETURN count(n) LIMIT 1" > /dev/null 2>&1 && echo "✅ Neo4j: Healthy" || echo "❌ Neo4j: Unhealthy"
- 	@docker-compose exec postgres pg_isready -U postgres > /dev/null 2>&1 && echo "✅ PostgreSQL: Healthy" || echo "❌ PostgreSQL: Unhealthy"
  	@curl -s http://localhost:8000/health > /dev/null && echo "✅ MCP Server: Healthy" || echo "❌ MCP Server: Unhealthy"
  	@curl -s http://localhost:3000 > /dev/null && echo "✅ Frontend: Healthy" || echo "❌ Frontend: Unhealthy"
  	@curl -s http://localhost:8501 > /dev/null && echo "✅ Streamlit: Healthy" || echo "❌ Streamlit: Unhealthy"
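The new `ingest` target runs `ops/scripts/ingest.py`, which is not shown in this commit. As a rough, stdlib-only sketch of the schema-extraction half of such a script (the Neo4j write is omitted, and the table names below are illustrative, not the project's real schemas):

```python
import sqlite3

def extract_schema(conn: sqlite3.Connection) -> dict:
    """Map each table in a SQLite database to its (column, declared type) pairs."""
    schema = {}
    for (table,) in conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ):
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
        cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
        schema[table] = [(c[1], c[2]) for c in cols]
    return schema

# Illustrative stand-in for one of the real .db files
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE trials (id INTEGER PRIMARY KEY, phase TEXT)")
print(extract_schema(conn))  # {'trials': [('id', 'INTEGER'), ('phase', 'TEXT')]}
```

The real script would then emit `Table`/`Column` nodes into Neo4j from this mapping.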
README.md CHANGED
@@ -1,74 +1,63 @@
- # Graph-Driven Agentic System MVP
- "Keep your data where it is but we will treat it like a graph for you and solve these problems for you"

  ## Overview
- An intelligent agent system that reads instructions from Neo4j, queries PostgreSQL databases, pauses for human review, and maintains a complete audit trail. The system demonstrates agentic workflow orchestration with human-in-the-loop controls.

  ## Key Features

- 🤖 **Autonomous Agent**: Processes instructions sequentially with configurable pause durations
- 📊 **Graph Database**: All workflow metadata stored in Neo4j for complete traceability
- 🔍 **Natural Language SQL**: Converts questions to SQL using LLM integration
- ⏸️ **Human-in-the-Loop**: 5-minute pauses allow instruction editing during execution
- 🎯 **Single API Gateway**: All Neo4j operations routed through MCP server
- 📈 **Real-time Visualization**: Live workflow progress in browser interface
- 🔄 **Complete Audit Trail**: Every action logged with timestamps and relationships

  ## Architecture

  ```
- ┌─────────────┐    ┌─────────────┐    ┌─────────────┐
- │  Frontend   │────│ MCP Server  │────│    Neo4j    │
- │  (Next.js)  │    │  (FastAPI)  │    │   (Graph)   │
- └─────────────┘    └─────────────┘    └─────────────┘
-
- ┌─────────────┐    ┌─────────────┐
- │    Agent    │────│ PostgreSQL  │
- │  (Python)   │    │   (Data)    │
- └─────────────┘    └─────────────┘
- ```
-
- ###### Any MCP application/API/agent can be both a client and a server. Clients and servers are a logical separation only, not a physical one. There is a natural idea of chaining/composability between clients and servers, like a fire-bucket chain of context slosh. Use a pydantic graph here as the engine for the orchestrator? I think the point is to create a co-pilot for the analyst that uses GraphRAG to inform itself, given the user's request, to think in GraphRAG before determining how to navigate the MCP tools.
-
- Actually, you aren't immediately writing data to Neo4j from the relational DB; instead it's about doing GraphRAG to curate the proper SQL statements and tool calls to make... there may be tools to do this.
-
- Use MCP Inspector, and also have an MCP server that automatically checks the logs in the inspector: https://modelcontextprotocol.io/docs/tools/inspector
-
- pydantic is key here, and not worrying about the frontend until demo time
41
  ### Components
- - **Neo4j**: Graph database storing workflows, instructions, and execution metadata
- - **MCP Server**: Single gateway for all Neo4j operations with parameter fixing
- - **Agent**: Executes instructions with configurable pause periods for human review
- - **PostgreSQL**: Sample data source for testing SQL generation
- - **Frontend**: React/Next.js chat interface with real-time workflow visualization

  ## Quick Start

  ### Prerequisites
  - Docker & Docker Compose
- - OpenAI or Anthropic API key (for LLM integration)

  ### Setup
  1. **Clone and configure**:
  ```bash
  git clone <repository-url>
  cd <repository-name>
- cp .env.example .env
  ```

- 2. **Add your LLM API key** to `.env`:
- ```bash
- # For OpenAI
- LLM_API_KEY=sk-your-openai-key-here
- LLM_MODEL=gpt-4
-
- # For Anthropic
- LLM_API_KEY=your-anthropic-key-here
- LLM_MODEL=claude-3-sonnet-20240229
  ```

  3. **Start the system**:
@@ -76,235 +65,63 @@ pydantic is key here, and not worrying about the frontend until demo time
  make up
  ```

- 4. **Seed demo data**:
  ```bash
- make seed
  ```

  5. **Open the interface**:
- - Frontend: http://localhost:3000
  - Neo4j Browser: http://localhost:7474 (neo4j/password)

  ## Usage
-
- ### Basic Workflow
- 1. **Ask a question** in the chat interface:
-    - "How many customers do we have?"
-    - "Show me all customers who have placed orders"
-    - "What's the total revenue?"
-
- 2. **Watch the agent process**:
-    - Creates workflow with 3 instructions
-    - Discovers database schema
-    - Generates SQL from your question
-    - Reviews and formats results
-
- 3. **Human intervention** (during 5-minute pauses):
-    - Edit instructions in Neo4j Browser
-    - Change parameters or questions
-    - Stop workflows if needed
-
- ### Editing Instructions During Pause
-
- When the agent pauses, you can modify instructions in Neo4j Browser:
-
- ```cypher
- // Change the question being asked
- MATCH (i:Instruction {status: 'pending'})
- SET i.parameters = '{"question": "Show me customers from the last month"}'
-
- // Stop the entire workflow
- MATCH (w:Workflow {status: 'active'})
- SET w.status = 'stopped'
- ```
-
- ### Monitoring
-
- Check system health:
- ```bash
- make health
- ```
-
- View real-time logs:
- ```bash
- make logs
- ```
-
- Check specific service:
- ```bash
- make debug-agent     # Agent logs
- make debug-mcp       # MCP server logs
- make debug-frontend  # Frontend logs
- ```
-
140
- ## Commands Reference
-
- | Command | Description |
- |---------|-------------|
- | `make up` | Start all services |
- | `make down` | Stop all services |
- | `make clean` | Remove all data and containers |
- | `make health` | Check service health |
- | `make seed` | Create demo data |
- | `make logs` | View all logs |
- | `make demo` | Full clean + setup + seed |
- | `make test` | Run integration test |
-
- ## Configuration
-
- ### Environment Variables
-
- All configuration is in the `.env` file:
-
- - **Neo4j**: Database connection and auth
- - **PostgreSQL**: Sample data source
- - **MCP**: API keys and server settings
- - **Agent**: Polling interval and pause duration
- - **LLM**: API key and model selection
-
- ### Pause Duration
-
- Default: 5 minutes (300 seconds).
- Configurable via `PAUSE_DURATION` in `.env`.
-
- ### Polling Interval
-
- Default: 30 seconds.
- Configurable via `AGENT_POLL_INTERVAL` in `.env`.

  ## Development
  ### File Structure
  ```
- ├── agent/      # Python agent that executes instructions
- ├── frontend/   # Next.js chat interface
- ├── mcp/        # FastAPI server for Neo4j operations
- ├── neo4j/      # Neo4j configuration and data
- ├── postgres/   # PostgreSQL setup and sample data
- ├── ops/        # Operational scripts (seeding, etc.)
  ├── docker-compose.yml
  ├── Makefile
  └── README.md
- ```
-
- ### Adding New Instruction Types
-
- 1. **Define handler in agent**:
- ```python
- def handle_new_instruction_type(instruction):
-     # Implementation
-     return {"status": "success", "result": "..."}
- ```
-
- 2. **Add to agent main loop**:
- ```python
- elif instruction['type'] == 'new_instruction_type':
-     exec_result = handle_new_instruction_type(instruction)
- ```
-
- 3. **Update frontend** to create new instruction types in workflows.
-
- ### Database Schema
-
- The system uses two databases:
-
- **Neo4j** (Workflow metadata):
- - `Workflow` nodes with status and metadata
- - `Instruction` nodes with type, sequence, parameters
- - `Execution` nodes with results and timestamps
- - Relationships: `HAS_INSTRUCTION`, `EXECUTED_AS`, `NEXT_INSTRUCTION`
-
- **PostgreSQL** (Sample data):
- - `customers` table
- - `orders` table
- - Sample data for testing SQL generation
-
- ## Troubleshooting
-
- ### Common Issues
-
- **Services not starting**:
- ```bash
- make down
- make clean
- make up
- ```
-
- **Agent not processing**:
- ```bash
- make restart-agent
- make debug-agent
- ```
-
- **Frontend not loading**:
- ```bash
- make restart-frontend
- make debug-frontend
- ```
-
- **Database connection issues**:
- ```bash
- make health
- # Check .env configuration
- ```
-
- ### Debug Mode
-
- For detailed logging, check individual service logs:
- ```bash
- docker-compose logs -f agent
- docker-compose logs -f mcp
- docker-compose logs -f frontend
- ```
-
- ### Reset Everything
-
- Complete clean slate:
- ```bash
- make clean
- cp .env.example .env
- # Edit .env with your API key
- make demo
- ```
-
- ## Contributing
-
- 1. Fork the repository
- 2. Create a feature branch
- 3. Test with `make demo`
- 4. Submit a pull request
-
- ## License
-
- MIT License - see LICENSE file for details.
-
- ---
-
- ## Quick Demo
-
- Want to see it in action immediately?
-
- ```bash
- # 1. Clone repo
- git clone <repo-url> && cd <repo-name>
-
- # 2. Add your API key
- cp .env.example .env
- # Edit .env: LLM_API_KEY=your-key-here
-
- # 3. Start everything
- make demo
-
- # 4. Open http://localhost:3000
- # 5. Ask: "How many records are in the database?"
- # 6. Watch the magic happen! ✨
- ```
-
- The system will:
- - Create a workflow with 3 instructions
- - Pause for 5 minutes before each step (editable in Neo4j Browser)
- - Generate SQL from your natural language question
- - Execute the query and return formatted results
- - Show the entire process in a visual graph
-
- 🎉 **Welcome to the future of human-AI collaboration!**

+ # GraphRAG Agentic System

  ## Overview
+ This project implements an intelligent, multi-step, GraphRAG-powered agent that uses LangChain to orchestrate complex queries against a federated life sciences dataset. The agent leverages a Neo4j graph database to understand the relationships between disparate SQLite databases, constructs SQL queries, and returns unified results through a conversational UI.

  ## Key Features

+ 🤖 **LangChain Agent**: Orchestrates tools for schema discovery, pathfinding, and query execution.
+ 🕸️ **GraphRAG Enabled**: Uses a Neo4j knowledge graph of database schemas for intelligent query planning.
+ 🔬 **Life Sciences Dataset**: Comes with a rich dataset across clinical trials, drug discovery, and lab results.
+ 💬 **Conversational UI**: A Streamlit-based chat interface for interacting with the agent.
+ 🔌 **RESTful MCP Server**: All core logic is exposed via a secure and scalable FastAPI server.

  ## Architecture

  ```
+ ┌────────────────┐    ┌─────────────┐    ┌─────────────┐
+ │ Streamlit Chat │────│    Agent    │────│ MCP Server  │
+ │      (UI)      │    │ (LangChain) │    │  (FastAPI)  │
+ └────────────────┘    └─────────────┘    └─────────────┘
+                                                 │
+        ┌──────────────────┬─────────────────────┘
+        │                  │
+ ┌─────────────┐    ┌─────────────┐    ┌─────────────┐
+ │    Neo4j    │    │  clinical_  │    │ laboratory  │
+ │ (Schema KG) │    │  trials.db  │    │     .db     │
+ └─────────────┘    └─────────────┘    └─────────────┘
+
+                    ┌─────────────┐
+                    │    drug_    │
+                    │ discovery.db│
+                    └─────────────┘
+ ```

  ### Components

+ - **Streamlit**: Provides a conversational chat interface for users to ask questions.
+ - **Agent**: A LangChain-powered orchestrator that uses custom tools to query the MCP server.
+ - **MCP Server**: A FastAPI application that exposes core logic for schema discovery, graph pathfinding, and federated query execution.
+ - **Neo4j**: Stores a knowledge graph of the schemas of all connected SQLite databases.
+ - **SQLite Databases**: A set of life sciences databases (`clinical_trials.db`, `drug_discovery.db`, `laboratory.db`) that serve as the federated data sources.
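The "federated query execution" idea can be illustrated with plain SQLite: one connection can `ATTACH` additional database files and join across them in a single statement. A minimal sketch with made-up tables (the real schemas live in `clinical_trials.db`, `drug_discovery.db`, and `laboratory.db`):

```python
import sqlite3

# One "main" connection, with a second database attached under the alias `lab`.
# In the real system these would be file paths to the .db files above.
main = sqlite3.connect(":memory:")
main.execute("ATTACH DATABASE ':memory:' AS lab")

main.execute("CREATE TABLE patients (id INTEGER PRIMARY KEY, name TEXT)")
main.execute("CREATE TABLE lab.results (patient_id INTEGER, test TEXT, value REAL)")

main.execute("INSERT INTO patients VALUES (123, 'Ada')")
main.execute("INSERT INTO lab.results VALUES (123, 'CRP', 4.2)")

# A single query joining tables that live in two different databases
rows = main.execute(
    """
    SELECT p.name, r.test, r.value
    FROM patients AS p
    JOIN lab.results AS r ON r.patient_id = p.id
    """
).fetchall()
print(rows)  # [('Ada', 'CRP', 4.2)]
```

The MCP server's query executor presumably does something similar, using the Neo4j schema graph to decide which databases and join keys to use.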

  ## Quick Start

  ### Prerequisites
  - Docker & Docker Compose
+ - OpenAI API key

  ### Setup
  1. **Clone and configure**:
  ```bash
  git clone <repository-url>
  cd <repository-name>
+ touch .env
  ```

+ 2. **Add your OpenAI API key** to the `.env` file. This is the only secret you need to provide.
+ ```
+ OPENAI_API_KEY="sk-your-openai-key-here"
  ```

  3. **Start the system**:
  make up
  ```

+ 4. **Seed the databases and ingest schema**:
  ```bash
+ make seed-db
+ make ingest
  ```

  5. **Open the interface**:
+ - Streamlit UI: http://localhost:8501
  - Neo4j Browser: http://localhost:7474 (neo4j/password)

  ## Usage
+ Once the system is running, open the Streamlit UI and ask a question about the life sciences data, for example:
+ - "What are the names of the trials and their primary purpose for studies on 'Cancer'?"
+ - "Find all drugs with 'Aspirin' in their name."
+ - "Show me lab results for patient '123'."
+
+ The agent will then:
+ 1. Use the `SchemaSearchTool` to find relevant tables.
+ 2. Use the `JoinPathFinderTool` to determine how to join them.
+ 3. Construct a SQL query.
+ 4. Execute the query using the `QueryExecutorTool`.
+ 5. Return the final answer to the UI.

  ## Development

+ ### Running the Agent Manually
+ To test the agent's logic directly without the full Docker stack, you can run it from your terminal.
+
+ 1. **Set up the environment**:
+    Make sure the MCP and Neo4j services are running (`make up`).
+    Create a Python virtual environment and install dependencies:
+    ```bash
+    python -m venv venv
+    source venv/bin/activate
+    pip install -r agent/requirements.txt
+    ```
+
+ 2. **Set your API key**:
+    ```bash
+    export OPENAI_API_KEY="sk-your-openai-key-here"
+    ```
+
+ 3. **Run the agent**:
+    ```bash
+    python agent/main.py
+    ```
+    The agent will run with the hardcoded example question and print the execution trace and final answer to your console.
+
  ### File Structure
  ```
+ ├── agent/      # The LangChain agent and its tools
+ ├── streamlit/  # The Streamlit conversational UI
+ ├── mcp/        # FastAPI server with core logic
+ ├── neo4j/      # Neo4j configuration and data
+ ├── data/       # SQLite databases
+ ├── ops/        # Operational scripts (seeding, ingestion, etc.)
  ├── docker-compose.yml
  ├── Makefile
  └── README.md
+ ```
agent/main.py CHANGED
@@ -1,435 +1,134 @@
1
  import os
2
- import time
3
- import json
4
- import requests
5
- import signal
6
  import sys
7
- from datetime import datetime
8
- import openai
9
- from anthropic import Anthropic
10
-
11
- MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
12
- API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
13
- POLL_INTERVAL = int(os.getenv("AGENT_POLL_INTERVAL", "30"))
14
-
15
- # Configure LLM
16
- LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4")
17
- LLM_API_KEY = os.getenv("LLM_API_KEY")
18
-
19
- if "gpt" in LLM_MODEL:
20
- openai.api_key = LLM_API_KEY
21
- llm_client = None
22
- else:
23
- llm_client = Anthropic(api_key=LLM_API_KEY)
24
-
25
- # Defining the agents
26
- ## Data Procurement Agent
27
- ## Graph Analysis Agent
28
- ## x Agent etc...
29
 
30
- # Create Orchestrato with all the agents and their tasks per orchestrator.py file
31
- orchestrator = Orchestrator(
32
- llm_factory=EastridgeAugmentedLLM,
33
- available_agents=[
34
- DataProcurementAgent(),
35
- GraphAnalysisAgent(),
36
- xAgent()
37
- ],
38
- plan_type="full",
39
- plan_output_path=Path("output/execution_plan.md"),
40
- )
41
 
 
42
 
43
- # Global flag for interrupt handling
44
- interrupted = False
 
45
 
46
- def signal_handler(sig, frame):
47
- global interrupted
48
- interrupted = True
49
- print(f"\n[{datetime.now()}] Interrupt received, will stop after current instruction")
50
 
51
- signal.signal(signal.SIGINT, signal_handler)
52
- signal.signal(signal.SIGTERM, signal_handler)
 
53
 
54
- def call_mcp(tool, params=None):
55
- response = requests.post(
56
- MCP_URL,
57
- headers={"X-API-Key": API_KEY, "Content-Type": "application/json"},
58
- json={"tool": tool, "params": params or {}}
59
- )
60
- return response.json()
61
 
62
- def get_llm_response(prompt):
63
- """Get response from configured LLM"""
64
- if "gpt" in LLM_MODEL:
65
- response = openai.ChatCompletion.create(
66
- model=LLM_MODEL,
67
- messages=[
68
- {"role": "system", "content": "You are a SQL expert. Generate only valid PostgreSQL queries."},
69
- {"role": "user", "content": prompt}
70
- ],
71
- temperature=0,
72
- max_tokens=500
73
- )
74
- return response.choices[0].message.content
75
- else:
76
- response = llm_client.messages.create(
77
- model=LLM_MODEL,
78
- max_tokens=500,
79
- temperature=0,
80
- messages=[{"role": "user", "content": prompt}]
81
- )
82
- return response.content[0].text
83
 
84
- def handle_discover_schema(instruction):
85
- """Discover PostgreSQL schema and store in Neo4j"""
86
- print(f"[{datetime.now()}] Discovering PostgreSQL schema...")
87
-
88
- # Call MCP to discover schema
89
- schema_result = call_mcp("discover_postgres_schema")
90
-
91
- if "error" in schema_result:
92
- return {"status": "failed", "error": schema_result["error"]}
93
-
94
- schema = schema_result["schema"]
95
-
96
- # Create SourceSystem node
97
- call_mcp("write_graph", {
98
- "action": "create_node",
99
- "label": "SourceSystem",
100
- "properties": {
101
- "id": "postgres-main",
102
- "name": "Main PostgreSQL Database",
103
- "type": "postgresql",
104
- "discovered_at": datetime.now().isoformat()
105
- }
106
- })
107
-
108
- # For each table, create nodes
109
- for table_name, columns in schema.items():
110
- # Create Table node
111
- table_result = call_mcp("write_graph", {
112
- "action": "create_node",
113
- "label": "Table",
114
- "properties": {
115
- "name": table_name,
116
- "schema": "public",
117
- "column_count": len(columns)
118
- }
119
- })
120
 
121
- # Link Table to SourceSystem
122
- call_mcp("query_graph", {
123
- "query": """
124
- MATCH (s:SourceSystem {id: 'postgres-main'}),
125
- (t:Table {name: $table_name})
126
- MERGE (s)-[:HAS_TABLE]->(t)
127
- """,
128
- "parameters": {"table_name": table_name}
129
- })
130
-
131
- # Create Column nodes
132
- for col in columns:
133
- col_result = call_mcp("write_graph", {
134
- "action": "create_node",
135
- "label": "Column",
136
- "properties": {
137
- "name": col['column_name'],
138
- "data_type": col['data_type'],
139
- "nullable": col['is_nullable'] == 'YES',
140
- "table_name": table_name
141
- }
142
- })
143
-
144
- # Link Column to Table
145
- call_mcp("query_graph", {
146
- "query": """
147
- MATCH (t:Table {name: $table_name}),
148
- (c:Column {name: $col_name, table_name: $table_name})
149
- MERGE (t)-[:HAS_COLUMN]->(c)
150
- """,
151
- "parameters": {
152
- "table_name": table_name,
153
- "col_name": col['column_name']
154
- }
155
- })
156
-
157
- # Generate sample queries
158
- for table_name in schema.keys():
159
- sample_queries = [
160
- f"SELECT * FROM {table_name} LIMIT 10",
161
- f"SELECT COUNT(*) FROM {table_name}",
162
- f"SELECT * FROM {table_name} WHERE id = 1"
163
  ]
164
 
165
- for idx, query in enumerate(sample_queries):
166
- call_mcp("write_graph", {
167
- "action": "create_node",
168
- "label": "QueryTemplate",
169
- "properties": {
170
- "id": f"template-{table_name}-{idx}",
171
- "table_name": table_name,
172
- "query": query,
173
- "description": f"Sample query {idx+1} for {table_name}"
174
- }
175
- })
176
-
177
- return {
178
- "status": "success",
179
- "tables_discovered": len(schema),
180
- "columns_discovered": sum(len(cols) for cols in schema.values())
181
- }
182
-
183
- def handle_generate_sql(instruction):
184
- """Generate SQL from natural language using LLM"""
185
- print(f"[{datetime.now()}] Generating SQL from natural language...")
186
-
187
- # Get the user question from instruction parameters
188
- params = json.loads(instruction.get('parameters', '{}'))
189
- user_question = params.get('question', 'Show all data')
190
-
191
- # Fetch schema from Neo4j
192
- schema_result = call_mcp("query_graph", {
193
- "query": """
194
- MATCH (t:Table)-[:HAS_COLUMN]->(c:Column)
195
- RETURN t.name as table_name,
196
- collect({
197
- name: c.name,
198
- type: c.data_type,
199
- nullable: c.nullable
200
- }) as columns
201
- """
202
- })
203
-
204
- # Format schema for LLM
205
- schema_text = "PostgreSQL Schema:\n"
206
- for record in schema_result['data']:
207
- table = record['table_name']
208
- columns = record['columns']
209
- schema_text += f"\nTable: {table}\n"
210
- for col in columns:
211
- nullable = "NULL" if col['nullable'] else "NOT NULL"
212
- schema_text += f" - {col['name']}: {col['type']} {nullable}\n"
213
-
214
- # Create prompt
215
- prompt = f"""Given this PostgreSQL schema:
216
-
217
- {schema_text}
218
-
219
- Generate a SQL query for this question: {user_question}
220
-
221
- Return ONLY the SQL query, no explanations or markdown."""
222
-
223
- try:
224
- # Get SQL from LLM
225
- generated_sql = get_llm_response(prompt)
226
-
227
- # Clean up the SQL (remove markdown if present)
228
- generated_sql = generated_sql.strip()
229
- if generated_sql.startswith("```"):
230
- generated_sql = generated_sql.split("```")[1]
231
- if generated_sql.startswith("sql"):
232
- generated_sql = generated_sql[3:]
233
- generated_sql = generated_sql.strip()
234
-
235
- print(f"[{datetime.now()}] Generated SQL: {generated_sql}")
236
 
237
- # Execute the SQL
238
- query_result = call_mcp("query_postgres", {"query": generated_sql})
 
 
 
 
 
 
239
 
240
- if "error" in query_result:
241
- return {
242
- "status": "failed",
243
- "generated_sql": generated_sql,
244
- "error": query_result["error"]
245
- }
246
-
247
- # Store successful query as template
248
- call_mcp("write_graph", {
249
- "action": "create_node",
250
- "label": "QueryTemplate",
251
- "properties": {
252
- "id": f"generated-{int(time.time())}",
253
- "query": generated_sql,
254
- "question": user_question,
255
- "created_at": datetime.now().isoformat()
256
- }
257
- })
258
-
259
- return {
260
- "status": "success",
261
- "generated_sql": generated_sql,
262
- "row_count": query_result.get("row_count", 0),
263
- "data": query_result.get("data", [])[:10] # Limit to 10 rows for storage
264
- }
265
-
266
- except Exception as e:
267
- return {
268
- "status": "failed",
269
- "error": str(e)
270
- }
271
-
272
- def check_workflow_stop(workflow_id):
273
- """Check if workflow has been marked to stop"""
274
- result = call_mcp("query_graph", {
275
- "query": "MATCH (w:Workflow {id: $id}) RETURN w.status as status",
276
- "parameters": {"id": workflow_id}
277
- })
278
-
279
- if result['data'] and result['data'][0]['status'] == 'stopped':
280
- return True
281
- return False
282
-
283
- def pause_with_interrupt(duration, instruction_id, workflow_id=None):
284
- """Pause for duration seconds with interrupt checking every 10 seconds"""
285
- print(f"[{datetime.now()}] Pausing for {duration} seconds for human review")
286
- print(f"[{datetime.now()}] You can edit instruction in Neo4j Browser:")
287
- print(f" MATCH (i:Instruction {{id: '{instruction_id}'}}) SET i.parameters = '{{\"key\": \"value\"}}'")
288
-
289
- # Log pause start
290
- call_mcp("write_graph", {
291
- "action": "create_node",
292
- "label": "Log",
293
- "properties": {
294
- "type": "pause_started",
295
- "instruction_id": instruction_id,
296
- "duration": duration,
297
- "timestamp": datetime.now().isoformat()
298
- }
299
- })
300
-
301
- elapsed = 0
302
- while elapsed < duration:
303
- # Check every 10 seconds
304
- sleep_time = min(10, duration - elapsed)
305
- time.sleep(sleep_time)
306
- elapsed += sleep_time
307
-
308
- # Check for workflow stop
309
- if workflow_id and check_workflow_stop(workflow_id):
310
- print(f"[{datetime.now()}] Workflow stopped during pause")
311
- return False
312
-
313
- # Check for global interrupt
314
- if interrupted:
315
- print(f"[{datetime.now()}] Interrupted during pause")
316
- return False
317
-
318
- # Show progress
319
- remaining = duration - elapsed
320
- if remaining > 0 and elapsed % 30 == 0: # Update every 30 seconds
321
- print(f"[{datetime.now()}] Pause remaining: {remaining} seconds")
322
-
323
- # Log pause end
324
- call_mcp("write_graph", {
325
- "action": "create_node",
326
- "label": "Log",
327
- "properties": {
328
- "type": "pause_completed",
329
- "instruction_id": instruction_id,
330
- "timestamp": datetime.now().isoformat()
331
- }
332
- })
333
-
334
- print(f"[{datetime.now()}] Pause complete, continuing execution")
335
- return True
336
-
337
  def main():
338
- global interrupted
339
- print(f"[{datetime.now()}] Agent starting, polling every {POLL_INTERVAL}s")
340
-
341
- while not interrupted:
342
- try:
343
- result = call_mcp("get_next_instruction")
344
- instruction = result.get("instruction")
345
-
346
- if instruction:
347
- print(f"[{datetime.now()}] Found instruction: {instruction['id']}, type: {instruction['type']}")
348
-
349
- # Get workflow ID
350
- workflow_result = call_mcp("query_graph", {
351
- "query": """
352
- MATCH (w:Workflow)-[:HAS_INSTRUCTION]->(i:Instruction {id: $id})
353
- RETURN w.id as workflow_id
354
- """,
355
- "parameters": {"id": instruction['id']}
356
- })
357
- workflow_id = workflow_result['data'][0]['workflow_id'] if workflow_result['data'] else None
358
-
359
- # PAUSE BEFORE EXECUTION
360
- pause_duration = instruction.get('pause_duration', 300)
361
- if pause_duration > 0:
362
- if not pause_with_interrupt(pause_duration, instruction['id'], workflow_id):
363
- print(f"[{datetime.now()}] Execution cancelled during pause")
364
- continue
365
-
366
- # Re-fetch instruction to get any edits made during pause
367
- refetch_result = call_mcp("query_graph", {
368
- "query": "MATCH (i:Instruction {id: $id}) RETURN i",
369
- "parameters": {"id": instruction['id']}
370
- })
371
-
372
- if refetch_result['data']:
373
- instruction = refetch_result['data'][0]['i']
374
- print(f"[{datetime.now()}] Re-fetched instruction after pause, parameters: {instruction.get('parameters')}")
375
-
376
- # Update status to executing
377
- call_mcp("query_graph", {
378
- "query": "MATCH (i:Instruction {id: $id}) SET i.status = 'executing'",
379
- "parameters": {"id": instruction['id']}
380
- })
381
-
382
- # Execute based on type
383
- if instruction['type'] == 'discover_schema':
384
- exec_result = handle_discover_schema(instruction)
385
- elif instruction['type'] == 'generate_sql':
386
- exec_result = handle_generate_sql(instruction)
387
- else:
388
- exec_result = {"status": "success", "result": "Reviewed"}
389
-
390
- # Store execution result
391
- exec_node = call_mcp("write_graph", {
392
- "action": "create_node",
393
- "label": "Execution",
394
- "properties": {
395
- "id": f"exec-{instruction['id']}-{int(time.time())}",
396
- "started_at": datetime.now().isoformat(),
397
- "completed_at": datetime.now().isoformat(),
398
- "result": json.dumps(exec_result)
399
- }
400
- })
401
-
402
- # Link execution
403
- call_mcp("query_graph", {
404
- "query": """
405
- MATCH (i:Instruction {id: $iid}), (e:Execution {id: $eid})
406
- CREATE (i)-[:EXECUTED_AS]->(e)
407
- """,
408
- "parameters": {
409
- "iid": instruction['id'],
410
- "eid": exec_node['created']['id']
411
- }
412
- })
413
-
414
- # Update status
415
- final_status = 'complete' if exec_result.get('status') == 'success' else 'failed'
416
- call_mcp("query_graph", {
417
- "query": "MATCH (i:Instruction {id: $id}) SET i.status = $status",
418
- "parameters": {"id": instruction['id'], "status": final_status}
419
- })
420
-
421
- print(f"[{datetime.now()}] Completed instruction: {instruction['id']}")
422
-
423
- else:
424
- print(f"[{datetime.now()}] No pending instructions")
425
-
426
- time.sleep(POLL_INTERVAL)
427
-
428
- except Exception as e:
429
- print(f"[{datetime.now()}] Error: {e}")
430
- time.sleep(POLL_INTERVAL)
431
-
432
- print(f"[{datetime.now()}] Agent shutting down")
433
 
434
  if __name__ == "__main__":
435
  main()
 
1
  import os
 
 
 
 
2
  import sys
3
+ import logging
4
+ import json
5
+ from typing import Annotated, List, TypedDict
6
+ from fastapi import FastAPI
7
+ from pydantic import BaseModel
8
+ import uvicorn
9
+ from fastapi.responses import StreamingResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ from langchain_core.messages import BaseMessage, ToolMessage, AIMessage
12
+ from langchain_openai import ChatOpenAI
13
+ from langgraph.graph import StateGraph, START, END
+ from langgraph.graph.message import add_messages
14
+ from langgraph.prebuilt import ToolNode
15

16
+ from agent.tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool
17

18
+ # --- Configuration & Logging ---
19
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
+ logger = logging.getLogger(__name__)
21

22
+ MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
23
+ API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
24
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
25

26
+ # --- Agent State Definition ---
27
+ class AgentState(TypedDict):
28
+ # add_messages appends each turn to the history instead of overwriting it,
+ # which the llm -> tools -> llm loop relies on
+ messages: Annotated[List[BaseMessage], add_messages]
29

30
+ # --- Agent Initialization ---
31
+ class GraphRAGAgent:
32
+ """The core agent for handling GraphRAG queries using LangGraph."""
33

34
+ def __init__(self):
35
+ if not OPENAI_API_KEY:
36
+ raise ValueError("OPENAI_API_KEY environment variable not set.")
37

38
+ # ChatOpenAI is required for bind_tools; the completions-only OpenAI class does not support tool calling
+ llm = ChatOpenAI(api_key=OPENAI_API_KEY, temperature=0, max_retries=1)
39
 
40
+ mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
41
+ tools = [
42
+ SchemaSearchTool(mcp_client=mcp_client),
43
+ JoinPathFinderTool(mcp_client=mcp_client),
44
+ QueryExecutorTool(mcp_client=mcp_client),
45
  ]
46
 
47
+ self.llm_with_tools = llm.bind_tools(tools)
48
+ self.tool_node = ToolNode(tools)
49
 
50
+ # Define the agent graph
51
+ graph = StateGraph(AgentState)
52
+ graph.add_node("llm", self.call_llm)
53
+ graph.add_node("tools", self.tool_node)
54
+
55
+ graph.add_edge(START, "llm")
56
+ graph.add_conditional_edges("llm", self.should_call_tools)
57
+ graph.add_edge("tools", "llm")
58
 
59
+ self.graph = graph.compile()
60
+
61
+ def should_call_tools(self, state: AgentState) -> str:
62
+ """Determines whether to call tools or end the execution."""
63
+ last_message = state["messages"][-1]
64
+ if not last_message.tool_calls:
65
+ return END
66
+ return "tools"
67
+
68
+ def call_llm(self, state: AgentState) -> dict:
69
+ """Calls the LLM with the current state to decide the next action."""
70
+ response = self.llm_with_tools.invoke(state["messages"])
71
+ return {"messages": [response]}
72
+
73
+ async def stream_query(self, question: str):
74
+ """Processes a question and streams the intermediate steps."""
75
+ inputs = {"messages": [("user", question)]}
76
+ async for event in self.graph.astream(inputs, stream_mode="values"):
77
+ last_message = event["messages"][-1]
78
+ if isinstance(last_message, AIMessage) and last_message.tool_calls:
79
+ # Agent is thinking and calling a tool
80
+ tool_call = last_message.tool_calls[0]
81
+ yield json.dumps({
82
+ "type": "thought",
83
+ "content": f"🤖 Calling tool `{tool_call['name']}` with args: {tool_call['args']}"
84
+ }) + "\n"  # one JSON object per line (ndjson)
85
+ elif isinstance(last_message, ToolMessage):
86
+ # A tool has returned its result
87
+ yield json.dumps({
88
+ "type": "observation",
89
+ "content": f"🛠️ Tool `{last_message.name}` returned:\n\n```\n{last_message.content}\n```"
90
+ }) + "\n"
91
+ elif isinstance(last_message, AIMessage):
92
+ # This is the final answer
93
+ yield json.dumps({"type": "final_answer", "content": last_message.content}) + "\n"
94
+
95
+ # --- FastAPI Application ---
96
+ app = FastAPI(title="GraphRAG Agent Server")
97
+ agent = None
98
+
99
+ class QueryRequest(BaseModel):
100
+ question: str
101
+
102
+ @app.on_event("startup")
103
+ def startup_event():
104
+ """Initialize the agent on server startup."""
105
+ global agent
106
+ try:
107
+ agent = GraphRAGAgent()
108
+ logger.info("GraphRAGAgent initialized successfully.")
109
+ except ValueError as e:
110
+ logger.error(f"Agent initialization failed: {e}")
111
+
112
+ @app.post("/query")
113
+ async def execute_query(request: QueryRequest) -> StreamingResponse:
114
+ """Endpoint to receive questions and stream the agent's response."""
115
+ if not agent:
116
+ async def error_stream():
117
+ yield json.dumps({"error": "Agent is not initialized. Check server logs."}) + "\n"
118
+ return StreamingResponse(error_stream())
119
+
120
+ return StreamingResponse(agent.stream_query(request.question), media_type="application/x-ndjson")
121
+
122
+ @app.get("/health")
123
+ def health_check():
124
+ """Health check endpoint."""
125
+ return {"status": "ok", "agent_initialized": agent is not None}
126
+
127
+ # --- Main Execution ---
128
  def main():
129
+ """Main entry point to run the FastAPI server."""
130
+ logger.info("Starting agent server...")
131
+ uvicorn.run(app, host="0.0.0.0", port=8001)
132
 
133
  if __name__ == "__main__":
134
  main()
agent/requirements.txt CHANGED
@@ -1,4 +1,8 @@
1
- requests==2.31.0
2
- python-dotenv==1.0.0
3
- openai==0.28.1
4
- anthropic==0.7.0
 
 
 
 
 
1
+ requests
2
+ python-dotenv
3
+ langchain
4
+ langchain-openai
5
+ pydantic
6
+ fastapi
7
+ uvicorn[standard]
8
+ langgraph
agent/tools.py ADDED
@@ -0,0 +1,150 @@
1
+ import os
2
+ import requests
3
+ import json
4
+ from typing import Dict, Any, List, Optional
5
+ from langchain.tools import BaseTool
6
+ from pydantic import Field
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class MCPClient:
12
+ """Client for making authenticated REST API calls to the MCP server."""
13
+
14
+ def __init__(self, mcp_url: str, api_key: str):
15
+ self.mcp_url = mcp_url
16
+ self.headers = {
17
+ "Authorization": f"Bearer {api_key}",
18
+ "Content-Type": "application/json"
19
+ }
20
+
21
+ def post(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
22
+ """Send a POST request to a given MCP endpoint."""
23
+ try:
24
+ url = f"{self.mcp_url}/{endpoint}"
25
+ response = requests.post(url, headers=self.headers, data=json.dumps(data))
26
+ response.raise_for_status()
27
+ return response.json()
28
+ except requests.exceptions.HTTPError as http_err:
29
+ logger.error(f"HTTP error occurred: {http_err} - {response.text}")
30
+ return {"status": "error", "message": f"HTTP error: {response.status_code} {response.reason}"}
31
+ except requests.exceptions.RequestException as req_err:
32
+ logger.error(f"Request error occurred: {req_err}")
33
+ return {"status": "error", "message": f"Request failed: {req_err}"}
34
+ except json.JSONDecodeError:
35
+ logger.error("Failed to decode JSON response.")
36
+ return {"status": "error", "message": "Invalid JSON response from server."}
37
+
38
+
39
+ class SchemaSearchTool(BaseTool):
40
+ """LangChain tool for searching database schemas."""
41
+
42
+ name: str = "schema_search"
43
+ description: str = """
44
+ Search for relevant database schemas based on a natural language query.
45
+ Use this when you need to find which tables/columns are relevant to a user's question.
46
+ Input should be a descriptive query like 'patient information' or 'drug trials'.
47
+ """
48
+ mcp_client: MCPClient
49
+
50
+ def _run(self, query: str) -> str:
51
+ """Execute schema search."""
52
+ response = self.mcp_client.post("discovery/get_relevant_schemas", {"query": query})
53
+
54
+ if response.get("status") == "success":
55
+ schemas = response.get("schemas", [])
56
+ if schemas:
57
+ schema_text = "Found relevant schemas:\n"
58
+ for schema in schemas:
59
+ schema_text += f"- {schema.get('database', 'Unknown')}.{schema.get('table', 'Unknown')}.{schema.get('name', 'Unknown')} ({schema.get('type', ['Unknown'])[0]})\n"
60
+ return schema_text
61
+ else:
62
+ return "No relevant schemas found."
63
+ else:
64
+ return f"Error searching schemas: {response.get('message', 'Unknown error')}"
65
+
66
+ async def _arun(self, query: str) -> str:
67
+ raise NotImplementedError("SchemaSearchTool does not support async")
68
+
69
+
70
+ class JoinPathFinderTool(BaseTool):
71
+ """LangChain tool for finding join paths between tables."""
72
+
73
+ name: str = "find_join_path"
74
+ description: str = """
75
+ Find how to join two tables together using foreign key relationships.
76
+ Use this when you need to query across multiple tables.
77
+ Input should be two table names separated by a comma, like 'patients,studies'.
78
+ """
79
+ mcp_client: MCPClient
80
+
81
+ def _run(self, table_names: str) -> str:
82
+ """Find join path."""
83
+ try:
84
+ tables = [t.strip() for t in table_names.split(',')]
85
+ if len(tables) != 2:
86
+ return "Please provide exactly two table names separated by a comma."
87
+
88
+ response = self.mcp_client.post(
89
+ "graph/find_join_path",
90
+ {"table1": tables[0], "table2": tables[1]}
91
+ )
92
+
93
+ if response.get("status") == "success":
94
+ path = response.get("path", "No path found")
95
+ return f"Join path: {path}"
96
+ else:
97
+ return f"Error finding join path: {response.get('message', 'Unknown error')}"
98
+ except Exception as e:
99
+ return f"Failed to find join path: {str(e)}"
100
+
101
+ async def _arun(self, table_names: str) -> str:
102
+ raise NotImplementedError("JoinPathFinderTool does not support async")
103
+
104
+
105
+ class QueryExecutorTool(BaseTool):
106
+ """LangChain tool for executing SQL queries."""
107
+
108
+ name: str = "execute_query"
109
+ description: str = """
110
+ Execute a SQL query against the databases and return results.
111
+ Use this after you have a valid SQL query.
112
+ Input should be a valid SQL query string.
113
+ """
114
+ mcp_client: MCPClient
115
+
116
+ def _run(self, sql: str) -> str:
117
+ """Execute query."""
118
+ try:
119
+ response = self.mcp_client.post(
120
+ "intelligence/execute_query",
121
+ {"sql": sql}
122
+ )
123
+
124
+ if response.get("status") == "success":
125
+ results = response.get("results", [])
126
+
127
+ if results:
128
+ # Format results as a readable table
129
+ result_text = f"Query returned {len(results)} rows:\\n"
130
+ headers = list(results[0].keys())
131
+ result_text += " | ".join(headers) + "\\n"
132
+ result_text += "-" * (len(" | ".join(headers))) + "\\n"
133
+
134
+ for row in results[:10]: # Limit display to first 10 rows
135
+ values = [str(row.get(h, "")) for h in headers]
136
+ result_text += " | ".join(values) + "\\n"
137
+
138
+ if len(results) > 10:
139
+ result_text += f"... and {len(results) - 10} more rows\\n"
140
+
141
+ return result_text
142
+ else:
143
+ return "Query executed successfully but returned no results."
144
+ else:
145
+ return f"Error executing query: {response.get('message', 'Unknown error')}"
146
+ except Exception as e:
147
+ return f"Failed to execute query: {str(e)}"
148
+
149
+ async def _arun(self, sql: str) -> str:
150
+ raise NotImplementedError("QueryExecutorTool does not support async")
docker-compose.yml CHANGED
@@ -19,25 +19,6 @@ services:
19
  networks:
20
  - agent-network
21
 
22
- postgres:
23
- image: postgres:15
24
- environment:
25
- - POSTGRES_DB=testdb
26
- - POSTGRES_USER=postgres
27
- - POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
28
- ports:
29
- - "5432:5432"
30
- volumes:
31
- - ./postgres/data:/var/lib/postgresql/data
32
- - ./postgres/init.sql:/docker-entrypoint-initdb.d/init.sql
33
- healthcheck:
34
- test: ["CMD-SHELL", "pg_isready -U postgres"]
35
- interval: 10s
36
- timeout: 5s
37
- retries: 5
38
- networks:
39
- - agent-network
40
-
41
  mcp:
42
  build: ./mcp
43
  ports:
@@ -45,17 +26,15 @@ services:
45
  environment:
46
  - NEO4J_BOLT_URL=${NEO4J_BOLT_URL}
47
  - NEO4J_AUTH=${NEO4J_AUTH}
48
- - POSTGRES_CONNECTION=${POSTGRES_CONNECTION}
49
  - MCP_API_KEYS=${MCP_API_KEYS}
50
  - MCP_PORT=${MCP_PORT}
51
  depends_on:
52
  neo4j:
53
  condition: service_healthy
54
- postgres:
55
- condition: service_healthy
56
  volumes:
57
  - ./mcp:/app
58
  - ./ops/scripts:/app/ops/scripts
 
59
  command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
60
  healthcheck:
61
  test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
@@ -71,33 +50,17 @@ services:
71
  - MCP_URL=http://mcp:8000/mcp
72
  - MCP_API_KEY=dev-key-123
73
  - AGENT_POLL_INTERVAL=${AGENT_POLL_INTERVAL}
74
- - PAUSE_DURATION=${PAUSE_DURATION}
75
- - LLM_API_KEY=${LLM_API_KEY}
76
- - LLM_MODEL=${LLM_MODEL}
77
- - POSTGRES_CONNECTION=${POSTGRES_CONNECTION}
78
  depends_on:
79
  mcp:
80
  condition: service_healthy
81
  volumes:
82
  - ./agent:/app
 
83
  command: python -u main.py
84
- restart: unless-stopped
85
- networks:
86
- - agent-network
87
-
88
- frontend:
89
- build: ./frontend
90
  ports:
91
- - "3000:3000"
92
- environment:
93
- - NEXT_PUBLIC_MCP_URL=http://localhost:8000
94
- depends_on:
95
- - mcp
96
- volumes:
97
- - ./frontend:/app
98
- - /app/node_modules
99
- - /app/.next
100
- command: npm run dev
101
  networks:
102
  - agent-network
103
 
@@ -106,6 +69,7 @@ services:
106
  ports:
107
  - "8501:8501"
108
  environment:
 
109
  - MCP_URL=http://mcp:8000/mcp
110
  - MCP_API_KEY=dev-key-123
111
  depends_on:
 
19
  networks:
20
  - agent-network
21
 
22
  mcp:
23
  build: ./mcp
24
  ports:
 
26
  environment:
27
  - NEO4J_BOLT_URL=${NEO4J_BOLT_URL}
28
  - NEO4J_AUTH=${NEO4J_AUTH}
 
29
  - MCP_API_KEYS=${MCP_API_KEYS}
30
  - MCP_PORT=${MCP_PORT}
31
  depends_on:
32
  neo4j:
33
  condition: service_healthy
 
 
34
  volumes:
35
  - ./mcp:/app
36
  - ./ops/scripts:/app/ops/scripts
37
+ - ./data:/app/data
38
  command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
39
  healthcheck:
40
  test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
 
50
  - MCP_URL=http://mcp:8000/mcp
51
  - MCP_API_KEY=dev-key-123
52
  - AGENT_POLL_INTERVAL=${AGENT_POLL_INTERVAL}
53
+ - OPENAI_API_KEY=${OPENAI_API_KEY}
 
 
 
54
  depends_on:
55
  mcp:
56
  condition: service_healthy
57
  volumes:
58
  - ./agent:/app
59
+ - ./data:/app/data
60
  command: python -u main.py
61
+ restart: on-failure
 
 
 
 
 
62
  ports:
63
+ - "8001:8001"
 
 
 
 
 
 
 
 
 
64
  networks:
65
  - agent-network
66
 
 
69
  ports:
70
  - "8501:8501"
71
  environment:
72
+ - AGENT_URL=http://agent:8001/query
73
  - MCP_URL=http://mcp:8000/mcp
74
  - MCP_API_KEY=dev-key-123
75
  depends_on:
mcp/__init__.py ADDED
File without changes
mcp/core/__init__.py ADDED
File without changes
mcp/core/config.py ADDED
@@ -0,0 +1,24 @@
1
+ import os
2
+
3
+ # --- Neo4j Configuration ---
4
+ NEO4J_URI = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
5
+ NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
6
+ NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
7
+
8
+ # --- SQLite Configuration ---
9
+ SQLITE_DATA_DIR = os.getenv("SQLITE_DATA_DIR", "/app/data")
10
+
11
+ def get_sqlite_connection_string(db_name: str) -> str:
12
+ """
13
+ Generates the SQLAlchemy connection string for a given SQLite database file.
14
+ Assumes the database file is located in the SQLITE_DATA_DIR.
15
+ Example: get_sqlite_connection_string("clinical_trials.db")
16
+ -> "sqlite:////app/data/clinical_trials.db"
17
+ """
18
+ db_path = os.path.join(SQLITE_DATA_DIR, db_name)
19
+ return f"sqlite:///{db_path}"
20
+
21
+ # --- Application Settings ---
22
+ # You can add other application-wide settings here
23
+ # For example, API keys, logging levels, etc.
24
+ # These would typically be loaded from environment variables as well.
mcp/core/database.py ADDED
@@ -0,0 +1,26 @@
1
+ from sqlalchemy import create_engine
2
+ from sqlalchemy.engine import Engine
3
+ import logging
4
+
5
+ logging.basicConfig(level=logging.INFO)
6
+ logger = logging.getLogger(__name__)
7
+
8
+ def get_db_engine(connection_string: str) -> Engine | None:
9
+ """
10
+ Creates a SQLAlchemy engine for a given database connection string.
11
+
12
+ Args:
13
+ connection_string: The database connection string.
14
+
15
+ Returns:
16
+ A SQLAlchemy Engine instance, or None if connection fails.
17
+ """
18
+ try:
19
+ engine = create_engine(connection_string)
20
+ # Test the connection
21
+ with engine.connect() as connection:
22
+ logger.info(f"Successfully connected to {engine.url.database}")
23
+ return engine
24
+ except Exception as e:
25
+ logger.error(f"Failed to connect to database: {e}")
26
+ return None
mcp/core/discovery.py ADDED
@@ -0,0 +1,98 @@
1
+ from sqlalchemy import inspect, text
2
+ from sqlalchemy.engine import Engine
3
+ from typing import Dict, Any, List
4
+ import logging
5
+ import json
6
+ from concurrent.futures import TimeoutError, ThreadPoolExecutor
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ def get_table_schema(inspector, table_name: str) -> Dict[str, Any]:
11
+ """Extracts schema for a single table."""
12
+ columns = inspector.get_columns(table_name)
13
+ primary_keys = inspector.get_pk_constraint(table_name)['constrained_columns']
14
+ foreign_keys = inspector.get_foreign_keys(table_name)
15
+
16
+ table_schema = {
17
+ "name": table_name,
18
+ "columns": [],
19
+ "primary_keys": primary_keys,
20
+ "foreign_keys": foreign_keys
21
+ }
22
+
23
+ for col in columns:
24
+ table_schema["columns"].append({
25
+ "name": col['name'],
26
+ "type": str(col['type']),
27
+ "nullable": col['nullable'],
28
+ "default": col.get('default'),
29
+ })
30
+ return table_schema
31
+
32
+ def get_sample_data(engine: Engine, table_name: str, sample_size: int = 5) -> Dict[str, Any]:
33
+ """Fetches sample data and distinct values for each column."""
34
+ sample_data = {}
35
+ with engine.connect() as connection:
36
+ # Get row count
37
+ try:
38
+ result = connection.execute(text(f'SELECT COUNT(*) FROM "{table_name}"'))
39
+ sample_data['row_count'] = result.scalar_one()
40
+ except Exception as e:
41
+ logger.warning(f"Could not get row count for table {table_name}: {e}")
42
+ sample_data['row_count'] = -1 # Indicate error or unknown
43
+
44
+ # Get sample rows
45
+ try:
46
+ result = connection.execute(text(f'SELECT * FROM "{table_name}" LIMIT {sample_size}'))
47
+ rows = [dict(row._mapping) for row in result.fetchall()]
48
+ # Attempt to JSON serialize to handle complex types gracefully
49
+ sample_data['sample_rows'] = json.loads(json.dumps(rows, default=str))
50
+ except Exception as e:
51
+ logger.warning(f"Could not get sample rows for table {table_name}: {e}")
52
+ sample_data['sample_rows'] = []
53
+
54
+ return sample_data
55
+
56
+
57
+ def discover_schema(engine: Engine, timeout: int = 30) -> Dict[str, Any] | None:
58
+ """
59
+ Discovers the full schema of a database using SQLAlchemy's inspection API.
60
+ Includes table schemas and sample data.
61
+ """
62
+ try:
63
+ with ThreadPoolExecutor() as executor:
64
+ future = executor.submit(_discover_schema_task, engine)
65
+ return future.result(timeout=timeout)
66
+ except TimeoutError:
67
+ logger.error(f"Schema discovery for {engine.url.database} timed out after {timeout} seconds.")
68
+ return None
69
+ except Exception as e:
70
+ logger.error(f"An unexpected error occurred during schema discovery for {engine.url.database}: {e}")
71
+ return None
72
+
73
+ def _discover_schema_task(engine: Engine) -> Dict[str, Any]:
74
+ """The actual schema discovery logic to be run with a timeout."""
75
+ inspector = inspect(engine)
76
+ db_schema = {
77
+ "database_name": engine.url.database,
78
+ "dialect": engine.dialect.name,
79
+ "tables": []
80
+ }
81
+
82
+ table_names = inspector.get_table_names()
83
+
84
+ for table_name in table_names:
85
+ try:
86
+ logger.info(f"Discovering schema for table: {table_name}")
87
+ table_schema = get_table_schema(inspector, table_name)
88
+
89
+ logger.info(f"Collecting sample data for table: {table_name}")
90
+ sample_info = get_sample_data(engine, table_name)
91
+ table_schema.update(sample_info)
92
+
93
+ db_schema["tables"].append(table_schema)
94
+ except Exception as e:
95
+ logger.error(f"Could not inspect table '{table_name}': {e}")
96
+ continue
97
+
98
+ return db_schema
mcp/core/graph.py ADDED
@@ -0,0 +1,151 @@
1
+ from neo4j import GraphDatabase
2
+ import logging
3
+ import json
4
+ from typing import List, Dict, Any
5
+ from . import config
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ class GraphStore:
10
+ def __init__(self):
11
+ self._driver = GraphDatabase.driver(config.NEO4J_URI, auth=(config.NEO4J_USER, config.NEO4J_PASSWORD))
12
+ self.ensure_constraints()
13
+
14
+ def close(self):
15
+ self._driver.close()
16
+
17
+ def ensure_constraints(self):
18
+ """Ensure uniqueness constraints are set up in Neo4j."""
19
+ with self._driver.session() as session:
20
+ session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (d:Database) REQUIRE d.name IS UNIQUE")
21
+ session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.unique_name IS UNIQUE")
22
+ session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (c:Column) REQUIRE c.unique_name IS UNIQUE")
23
+ logger.info("Neo4j constraints ensured.")
24
+
25
+ def import_schema(self, schema_data: dict):
26
+ """
27
+ Imports a discovered database schema into the Neo4j graph.
28
+ """
29
+ db_name = schema_data['database_name']
30
+
31
+ with self._driver.session() as session:
32
+ # Create Database node
33
+ session.run("MERGE (d:Database {name: $db_name})", db_name=db_name)
34
+
35
+ for table in schema_data['tables']:
36
+ table_unique_name = f"{db_name}.{table['name']}"
37
+ table_properties = {
38
+ "name": table['name'],
39
+ "unique_name": table_unique_name,
40
+ "row_count": table.get('row_count', -1),
41
+ "sample_rows": json.dumps(table.get('sample_rows', []))
42
+ }
43
+
44
+ # Create Table node and HAS_TABLE relationship
45
+ session.run(
46
+ """
47
+ MATCH (d:Database {name: $db_name})
48
+ MERGE (t:Table {unique_name: $unique_name})
49
+ ON CREATE SET t += $props
50
+ ON MATCH SET t += $props
51
+ MERGE (d)-[:HAS_TABLE]->(t)
52
+ """,
53
+ db_name=db_name,
54
+ unique_name=table_unique_name,
55
+ props=table_properties
56
+ )
57
+
58
+ for column in table['columns']:
59
+ column_unique_name = f"{table_unique_name}.{column['name']}"
60
+ column_properties = {
61
+ "name": column['name'],
62
+ "unique_name": column_unique_name,
63
+ "type": column['type'],
64
+ "nullable": column['nullable'],
65
+ "default": str(column.get('default')) # Ensure default is string
66
+ }
67
+
68
+ # Create Column node and HAS_COLUMN relationship
69
+ session.run(
70
+ """
71
+ MATCH (t:Table {unique_name: $table_unique_name})
72
+ MERGE (c:Column {unique_name: $column_unique_name})
73
+ ON CREATE SET c += $props
74
+ ON MATCH SET c += $props
75
+ MERGE (t)-[:HAS_COLUMN]->(c)
76
+ """,
77
+ table_unique_name=table_unique_name,
78
+ column_unique_name=column_unique_name,
79
+ props=column_properties
80
+ )
81
+
82
+ # After all tables and columns are created, create foreign key relationships
83
+ for table in schema_data['tables']:
84
+ table_unique_name = f"{db_name}.{table['name']}"
85
+ if table.get('foreign_keys'):
86
+ for fk in table['foreign_keys']:
87
+ constrained_columns = fk['constrained_columns']
88
+ referred_table = fk['referred_table']
89
+ referred_columns = fk['referred_columns']
90
+
91
+ referred_table_unique_name = f"{db_name}.{referred_table}"
92
+
93
+ for i, col_name in enumerate(constrained_columns):
94
+ from_col_unique_name = f"{table_unique_name}.{col_name}"
95
+ to_col_unique_name = f"{referred_table_unique_name}.{referred_columns[i]}"
96
+
97
+ session.run(
98
+ """
99
+ MATCH (from_col:Column {unique_name: $from_col})
100
+ MATCH (to_col:Column {unique_name: $to_col})
101
+ MERGE (from_col)-[:REFERENCES]->(to_col)
102
+ """,
103
+ from_col=from_col_unique_name,
104
+ to_col=to_col_unique_name
105
+ )
106
+ logger.info(f"Successfully imported schema for database: {db_name}")
107
+
108
+ def find_shortest_path(self, start_node_name: str, end_node_name: str) -> List[Dict[str, Any]]:
109
+ """
110
+ Finds the shortest path between two nodes (Tables or Columns) in the graph.
111
+ This is a generic pathfinder.
112
+ """
113
+ query = """
114
+ MATCH (start {unique_name: $start_name}), (end {unique_name: $end_name})
115
+ CALL apoc.path.shortestPath(start, end, 'REFERENCES|HAS_COLUMN|HAS_TABLE', {maxLevel: 10}) YIELD path
116
+ RETURN path
117
+ """
118
+ with self._driver.session() as session:
119
+ result = session.run(query, start_name=start_node_name, end_name=end_node_name)
120
+ # The result is complex, we need to parse it into a user-friendly format.
121
+ # For now, returning the raw path objects.
122
+ return [record["path"] for record in result]
123
+
124
+ def keyword_search(self, keyword: str) -> List[Dict[str, Any]]:
125
+ """
126
+ Searches for tables and columns matching a keyword.
127
+ Returns a list of matching nodes with their database and table context.
128
+ """
129
+ query = """
130
+ MATCH (n)
131
+ WHERE (n:Table OR n:Column) AND n.name CONTAINS $keyword
132
+ OPTIONAL MATCH (d:Database)-[:HAS_TABLE]->(t:Table)-[:HAS_COLUMN]->(n) WHERE n:Column
133
+ OPTIONAL MATCH (d2:Database)-[:HAS_TABLE]->(n) WHERE n:Table
134
+ WITH COALESCE(d, d2) AS db, COALESCE(t, n) AS tbl, n AS item
135
+ RETURN db.name AS database, tbl.name AS table, item.name AS name, labels(item) AS type
136
+ LIMIT 25
137
+ """
138
+ with self._driver.session() as session:
139
+ result = session.run(query, keyword=keyword)
140
+ return [record.data() for record in result]
141
+
142
+ def get_table_row_count(self, table_unique_name: str) -> int:
143
+ """Retrieves the stored row count for a given table."""
144
+ query = """
145
+ MATCH (t:Table {unique_name: $unique_name})
146
+ RETURN t.row_count AS row_count
147
+ """
148
+ with self._driver.session() as session:
149
+ result = session.run(query, unique_name=table_unique_name)
150
+ record = result.single()
151
+ return record['row_count'] if record else -1
mcp/core/intelligence.py ADDED
@@ -0,0 +1,161 @@
1
+ import sqlparse
2
+ import logging
3
+ from typing import List, Dict, Any
4
+
5
+ from .graph import GraphStore
6
+ from .database import get_db_engine
7
+ from . import config
8
+ from sqlalchemy import text
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Constants for query cost estimation
13
+ ROW_EXECUTION_THRESHOLD = 100 # Execute directly only when the estimate is below this many rows
14
+ JOIN_CARDINALITY_ESTIMATE = 1000 # A simplistic estimate for joins
15
+
16
+ class QueryIntelligence:
17
+ """
18
+ Provides intelligence for handling SQL queries. It estimates query cost
19
+ and decides on an execution strategy.
20
+ """
21
+ def __init__(self, graph_store: GraphStore):
22
+ self.graph_store = graph_store
23
+ self.db_engines = {}
24
+
25
+ def _get_engine_for_db(self, db_name: str):
26
+ """Helper to get or create an engine for a specific database."""
27
+ if db_name not in self.db_engines:
28
+ # Assuming db_name includes the .db extension
29
+ connection_string = config.get_sqlite_connection_string(db_name)
30
+ self.db_engines[db_name] = get_db_engine(connection_string)
31
+ return self.db_engines.get(db_name)
32
+
33
+ async def get_relevant_schemas(self, query: str) -> List[Dict[str, Any]]:
34
+ """Finds schemas relevant to a natural language query."""
35
+ # This is a simplistic keyword search. A real implementation would use
36
+ # embedding-based search or an LLM to extract entities.
37
+ keywords = query.split()
38
+ all_results = []
39
+ for keyword in keywords:
40
+ if len(keyword) > 2: # Avoid very short keywords
41
+ results = self.graph_store.keyword_search(keyword)
42
+ all_results.extend(results)
43
+ # Deduplicate results; the `type` field holds a list of labels, so the
+ # dicts are not hashable and a set-of-tuples dedupe would raise TypeError
44
+ seen = set()
+ unique = []
+ for d in all_results:
+ key = str(sorted(d.items()))
+ if key not in seen:
+ seen.add(key)
+ unique.append(d)
+ return unique
45
+
46
+ async def find_join_path(self, table1_name: str, table2_name: str) -> str:
47
+ """Finds a join path between two tables using the graph."""
48
+ # This is a simplification: it requires table names to be unique, or the
49
+ # user to provide fully qualified names (db.table).
50
+ t1_nodes = self.graph_store.keyword_search(table1_name)
51
+ t2_nodes = self.graph_store.keyword_search(table2_name)
52
+
53
+ if not t1_nodes or not t2_nodes:
54
+ return "Could not find one or both tables."
55
+
56
+ # Assume the first result is correct for simplicity
57
+ t1_unique_name = f"{t1_nodes[0]['database']}.{t1_nodes[0]['table']}"
58
+ t2_unique_name = f"{t2_nodes[0]['database']}.{t2_nodes[0]['table']}"
59
+
60
+ path_result = self.graph_store.find_shortest_path(t1_unique_name, t2_unique_name)
61
+
62
+ if not path_result:
63
+ return f"No path found between {table1_name} and {table2_name}."
64
+
65
+ # Format the path for display
66
+ # This is a complex task. The raw path from Neo4j needs careful parsing.
67
+ # This is a placeholder for that logic.
68
+         return f"Path found (details require parsing): {path_result}"
+
+     async def execute_query(self, sql: str, limit: int) -> List[Dict[str, Any]]:
+         """
+         Executes a SQL query against the appropriate database if the estimated
+         cost is below the threshold.
+         """
+         cost_estimate = self.estimate_query_cost(sql)
+
+         if cost_estimate['decision'] != 'execute':
+             raise PermissionError(f"Query execution denied. Estimated cost is too high ({cost_estimate['estimated_rows']} rows).")
+
+         # This is a major simplification. Determining which database to run the query
+         # against is a hard problem (especially for federated queries).
+         # We assume the first table found belongs to the correct database.
+         parsed_sql = self._parse_sql(sql)
+         if not parsed_sql['tables']:
+             raise ValueError("No tables found in SQL query.")
+
+         first_table = parsed_sql['tables'][0]
+         search_results = self.graph_store.keyword_search(first_table)
+         if not search_results:
+             raise ValueError(f"Table '{first_table}' not found in any known database.")
+
+         db_name = search_results[0]['database']
+         engine = self._get_engine_for_db(db_name)
+
+         if not engine:
+             raise ConnectionError(f"Could not connect to database: {db_name}")
+
+         with engine.connect() as connection:
+             # Append limit to the query
+             safe_sql = f"{sql.strip().rstrip(';')} LIMIT {int(limit)}"
+             result = connection.execute(text(safe_sql))
+             return [dict(row._mapping) for row in result.fetchall()]
+
+     def _parse_sql(self, sql: str) -> Dict[str, Any]:
+         """Parses the SQL to identify tables and columns."""
+         parsed = sqlparse.parse(sql)[0]
+         # This is a simplistic parser. A real implementation would need
+         # a much more robust SQL parsing library to handle complex queries, CTEs, etc.
+         tables = set()
+         for token in parsed.tokens:
+             if isinstance(token, sqlparse.sql.Identifier):
+                 tables.add(token.get_real_name())
+             elif token.is_group:
+                 # Look for identifiers within subgroups (e.g., in FROM or JOIN clauses)
+                 for sub_token in token.tokens:
+                     if isinstance(sub_token, sqlparse.sql.Identifier):
+                         tables.add(sub_token.get_real_name())
+
+         return {"tables": list(tables)}
+
+     def estimate_query_cost(self, sql: str) -> Dict[str, Any]:
+         """
+         Estimates the cost of a query based on row counts from the graph.
+         """
+         try:
+             parsed_sql = self._parse_sql(sql)
+             tables_in_query = parsed_sql['tables']
+
+             if not tables_in_query:
+                 return {"estimated_rows": 0, "decision": "execute", "message": "No tables found in query."}
+
+             # For simplicity, we'll take the max row count of any table in the query.
+             # A real system would analyze JOINs and WHERE clauses.
+             max_rows = 0
+             for table_name in tables_in_query:
+                 # Need to find the unique name. This assumes table names are unique across DBs for now.
+                 # A real implementation needs context of which DB is being queried.
+                 search_result = self.graph_store.keyword_search(table_name)
+                 if search_result:
+                     table_unique_name = f"{search_result[0]['database']}.{search_result[0]['table']}"
+                     row_count = self.graph_store.get_table_row_count(table_unique_name)
+                     if row_count > max_rows:
+                         max_rows = row_count
+
+             estimated_rows = max_rows
+             # Crude adjustment for joins
+             if len(tables_in_query) > 1:
+                 # A better estimate would involve graph traversal and statistical models
+                 estimated_rows *= JOIN_CARDINALITY_ESTIMATE * (len(tables_in_query) - 1)
+
+             decision = "execute" if estimated_rows < ROW_EXECUTION_THRESHOLD else "return_sql"
+
+             return {
+                 "estimated_rows": estimated_rows,
+                 "decision": decision,
+                 "tables_found": tables_in_query
+             }
+
+         except Exception as e:
+             logger.error(f"Error estimating query cost: {e}")
+             return {"estimated_rows": -1, "decision": "error", "message": str(e)}
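The cost heuristic in `estimate_query_cost` above can be sketched in isolation: take the largest per-table row count, then inflate it by a join-cardinality factor for each extra table. The constant values below are assumptions for illustration; the real `JOIN_CARDINALITY_ESTIMATE` and `ROW_EXECUTION_THRESHOLD` live in the project's config.

```python
# Assumed constants; the real values live in the project's config module.
JOIN_CARDINALITY_ESTIMATE = 10
ROW_EXECUTION_THRESHOLD = 10_000

def estimate(row_counts):
    """row_counts: row count of each table referenced by the query."""
    if not row_counts:
        return 0, "execute"
    estimated = max(row_counts)
    if len(row_counts) > 1:
        # Crude join adjustment, mirroring the handler above
        estimated *= JOIN_CARDINALITY_ESTIMATE * (len(row_counts) - 1)
    decision = "execute" if estimated < ROW_EXECUTION_THRESHOLD else "return_sql"
    return estimated, decision

print(estimate([500, 200]))      # (5000, 'execute')
print(estimate([500, 200, 50]))  # (10000, 'return_sql')
```

Note the multiplicative blow-up: a third table pushes the same 500-row base over the threshold, which is exactly why the handler falls back to returning SQL instead of executing it.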
mcp/requirements.txt CHANGED
@@ -3,4 +3,5 @@ uvicorn==0.24.0
  neo4j==5.14.0
  pydantic==2.4.0
  requests==2.31.0
- psycopg2-binary==2.9.9
+ SQLAlchemy==2.0.29
+ sqlparse==0.5.0
ops/scripts/generate_sample_databases.py ADDED
@@ -0,0 +1,160 @@
+ import sqlite3
+ import pandas as pd
+ from datetime import datetime, timedelta
+ import random
+ import os
+
+ # Ensure the data directory exists
+ DATA_DIR = 'data'
+ os.makedirs(DATA_DIR, exist_ok=True)
+
+ def create_clinical_trials_db():
+     """Creates the clinical_trials.db database."""
+     conn = sqlite3.connect(os.path.join(DATA_DIR, 'clinical_trials.db'))
+
+     # Studies table
+     studies_data = {
+         'study_id': ['ONCO-2023-001', 'CARDIO-2023-047', 'NEURO-2024-012', 'DIAB-2023-089', 'RARE-2024-003'],
+         'study_name': ['Phase III Immunotherapy Trial', 'Beta Blocker Efficacy Study', 'Alzheimer Prevention Trial', 'Insulin Resistance Study', 'Rare Disease Natural History'],
+         'phase': ['Phase 3', 'Phase 2', 'Phase 2', 'Phase 3', 'Observational'],
+         'status': ['RECRUITING', 'ACTIVE', 'PLANNING', 'COMPLETED', 'RECRUITING'],
+         'sponsor': ['OncoPharm Inc', 'CardioHealth', 'NeuroGen', 'DiabetesCare', 'NIH'],
+         'target_enrollment': [500, 200, 150, 800, 50],
+         'current_enrollment': [237, 178, 0, 800, 12],
+         'start_date': ['2023-03-15', '2023-06-01', '2024-01-15', '2023-01-10', '2024-02-01']
+     }
+     pd.DataFrame(studies_data).to_sql('studies', conn, index=False, if_exists='replace')
+
+     # Patients table
+     patients_data = {
+         'patient_id': [f'PT{str(i).zfill(6)}' for i in range(1, 101)],
+         'study_id': [random.choice(studies_data['study_id']) for _ in range(100)],
+         'enrollment_date': [(datetime.now() - timedelta(days=random.randint(1, 365))).strftime('%Y-%m-%d') for _ in range(100)],
+         'age': [random.randint(18, 85) for _ in range(100)],
+         'gender': [random.choice(['M', 'F']) for _ in range(100)],
+         'status': [random.choice(['ENROLLED', 'COMPLETED', 'WITHDRAWN', 'SCREENING']) for _ in range(100)]
+     }
+     pd.DataFrame(patients_data).to_sql('patients', conn, index=False, if_exists='replace')
+
+     # Adverse Events table
+     adverse_events_data = {
+         'event_id': list(range(1, 51)),
+         'patient_id': [random.choice(patients_data['patient_id'][:50]) for _ in range(50)],
+         'event_date': [(datetime.now() - timedelta(days=random.randint(1, 180))).strftime('%Y-%m-%d') for _ in range(50)],
+         'event_type': [random.choice(['NAUSEA', 'HEADACHE', 'FATIGUE', 'RASH', 'FEVER']) for _ in range(50)],
+         'severity': [random.choice(['MILD', 'MODERATE', 'SEVERE']) for _ in range(50)],
+         'related_to_treatment': [random.choice(['YES', 'NO', 'UNKNOWN']) for _ in range(50)]
+     }
+     pd.DataFrame(adverse_events_data).to_sql('adverse_events', conn, index=False, if_exists='replace')
+
+     # Add foreign keys (SQLite doesn't enforce, but documents relationships)
+     conn.execute("CREATE INDEX idx_patients_study ON patients(study_id)")
+     conn.execute("CREATE INDEX idx_events_patient ON adverse_events(patient_id)")
+     conn.commit()
+     conn.close()
+     print("✅ Clinical Trials database created successfully!")
+
+ def create_laboratory_db():
+     """Creates the laboratory.db database."""
+     conn = sqlite3.connect(os.path.join(DATA_DIR, 'laboratory.db'))
+
+     # Lab Tests table
+     lab_tests_data = {
+         'test_id': [f'LAB{str(i).zfill(8)}' for i in range(1, 201)],
+         'patient_id': [f'PT{str(random.randint(1, 100)).zfill(6)}' for _ in range(200)],
+         'test_date': [(datetime.now() - timedelta(days=random.randint(1, 365))).strftime('%Y-%m-%d') for _ in range(200)],
+         'test_type': [random.choice(['CBC', 'METABOLIC_PANEL', 'LIVER_FUNCTION', 'LIPID_PANEL', 'HBA1C']) for _ in range(200)],
+         'ordered_by': [f'DR{str(random.randint(1, 20)).zfill(4)}' for _ in range(200)],
+         'priority': [random.choice(['ROUTINE', 'URGENT', 'STAT']) for _ in range(200)]
+     }
+     pd.DataFrame(lab_tests_data).to_sql('lab_tests', conn, index=False, if_exists='replace')
+
+     # Test Results table
+     results_data = {
+         'result_id': list(range(1, 601)),
+         'test_id': [random.choice(lab_tests_data['test_id']) for _ in range(600)],
+         'analyte': [random.choice(['GLUCOSE', 'WBC', 'RBC', 'PLATELETS', 'CREATININE', 'ALT', 'AST', 'CHOLESTEROL']) for _ in range(600)],
+         'value': [round(random.uniform(1, 200), 2) for _ in range(600)],
+         'unit': [random.choice(['mg/dL', 'K/uL', 'M/uL', 'g/dL', 'mmol/L']) for _ in range(600)],
+         'reference_low': [round(random.uniform(1, 50), 2) for _ in range(600)],
+         'reference_high': [round(random.uniform(100, 200), 2) for _ in range(600)],
+         'flag': [random.choice(['NORMAL', 'HIGH', 'LOW', 'CRITICAL']) for _ in range(600)]
+     }
+     pd.DataFrame(results_data).to_sql('test_results', conn, index=False, if_exists='replace')
+
+     # Biomarkers table (for research)
+     biomarkers_data = {
+         'biomarker_id': list(range(1, 31)),
+         'patient_id': [f'PT{str(random.randint(1, 100)).zfill(6)}' for _ in range(30)],
+         'biomarker_name': [random.choice(['PD-L1', 'BRCA1', 'EGFR', 'KRAS', 'HER2']) for _ in range(30)],
+         'expression_level': [random.choice(['HIGH', 'MEDIUM', 'LOW', 'NEGATIVE']) for _ in range(30)],
+         'test_method': [random.choice(['IHC', 'PCR', 'NGS', 'FLOW_CYTOMETRY']) for _ in range(30)],
+         'collection_date': [(datetime.now() - timedelta(days=random.randint(1, 365))).strftime('%Y-%m-%d') for _ in range(30)]
+     }
+     pd.DataFrame(biomarkers_data).to_sql('biomarkers', conn, index=False, if_exists='replace')
+
+     conn.execute("CREATE INDEX idx_results_test ON test_results(test_id)")
+     conn.execute("CREATE INDEX idx_tests_patient ON lab_tests(patient_id)")
+     conn.commit()
+     conn.close()
+     print("✅ Laboratory database created successfully!")
+
+ def create_drug_discovery_db():
+     """Creates the drug_discovery.db database."""
+     conn = sqlite3.connect(os.path.join(DATA_DIR, 'drug_discovery.db'))
+
+     # Compounds table
+     compounds_data = {
+         'compound_id': [f'CMP-{str(i).zfill(6)}' for i in range(1, 51)],
+         'compound_name': [f'Compound-{chr(65+i//10)}{i%10}' for i in range(50)],
+         'molecular_weight': [round(random.uniform(200, 800), 2) for _ in range(50)],
+         'formula': [f'C{random.randint(10,30)}H{random.randint(10,40)}N{random.randint(0,5)}O{random.randint(1,10)}' for _ in range(50)],
+         'development_stage': [random.choice(['DISCOVERY', 'LEAD_OPT', 'PRECLINICAL', 'CLINICAL', 'DISCONTINUED']) for _ in range(50)],
+         'target_class': [random.choice(['KINASE', 'GPCR', 'ION_CHANNEL', 'PROTEASE', 'ANTIBODY']) for _ in range(50)]
+     }
+     pd.DataFrame(compounds_data).to_sql('compounds', conn, index=False, if_exists='replace')
+
+     # Assay Results table
+     assays_data = {
+         'assay_id': list(range(1, 201)),
+         'compound_id': [random.choice(compounds_data['compound_id']) for _ in range(200)],
+         'assay_type': [random.choice(['BINDING', 'CELL_VIABILITY', 'ENZYME_INHIBITION', 'TOXICITY']) for _ in range(200)],
+         'ic50_nm': [round(random.uniform(0.1, 10000), 2) for _ in range(200)],
+         'efficacy_percent': [round(random.uniform(0, 100), 1) for _ in range(200)],
+         'assay_date': [(datetime.now() - timedelta(days=random.randint(1, 365))).strftime('%Y-%m-%d') for _ in range(200)],
+         'scientist': [f'SCI-{random.randint(1, 10)}' for _ in range(200)]
+     }
+     pd.DataFrame(assays_data).to_sql('assay_results', conn, index=False, if_exists='replace')
+
+     # Drug Targets table
+     targets_data = {
+         'target_id': [f'TGT-{str(i).zfill(4)}' for i in range(1, 21)],
+         'target_name': [f'Protein-{i}' for i in range(1, 21)],
+         'gene_symbol': [f'GENE{i}' for i in range(1, 21)],
+         'pathway': [random.choice(['MAPK', 'PI3K/AKT', 'WNT', 'NOTCH', 'HEDGEHOG']) for _ in range(20)],
+         'disease_area': [random.choice(['ONCOLOGY', 'CARDIOLOGY', 'NEUROLOGY', 'IMMUNOLOGY']) for _ in range(20)]
+     }
+     pd.DataFrame(targets_data).to_sql('drug_targets', conn, index=False, if_exists='replace')
+
+     # Compound-Target Associations
+     associations_data = {
+         'association_id': list(range(1, 76)),
+         'compound_id': [random.choice(compounds_data['compound_id']) for _ in range(75)],
+         'target_id': [random.choice(targets_data['target_id']) for _ in range(75)],
+         'affinity_nm': [round(random.uniform(0.1, 1000), 2) for _ in range(75)],
+         'selectivity_fold': [round(random.uniform(1, 100), 1) for _ in range(75)]
+     }
+     pd.DataFrame(associations_data).to_sql('compound_targets', conn, index=False, if_exists='replace')
+
+     conn.execute("CREATE INDEX idx_assays_compound ON assay_results(compound_id)")
+     conn.execute("CREATE INDEX idx_associations_compound ON compound_targets(compound_id)")
+     conn.execute("CREATE INDEX idx_associations_target ON compound_targets(target_id)")
+     conn.commit()
+     conn.close()
+     print("✅ Drug Discovery database created successfully!")
+
+ if __name__ == "__main__":
+     create_clinical_trials_db()
+     create_laboratory_db()
+     create_drug_discovery_db()
+     print("\n🎉 All three Life Sciences databases created successfully in the 'data' directory!")
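The seeding pattern used throughout this script (build a dict of columns, write it with `to_sql`, then add indexes) can be demonstrated in miniature with only the standard library, against an in-memory database so nothing is written to disk. Table and column names here are borrowed from the script purely for illustration.

```python
import sqlite3

# Minimal sketch of the seeding pattern above, using stdlib sqlite3
# instead of pandas and an in-memory DB instead of data/*.db files.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE studies (study_id TEXT, status TEXT)")
conn.executemany(
    "INSERT INTO studies VALUES (?, ?)",
    [("ONCO-2023-001", "RECRUITING"), ("DIAB-2023-089", "COMPLETED")],
)
# Indexes are added after the bulk load, as in the script
conn.execute("CREATE INDEX idx_studies_status ON studies(status)")
rows = conn.execute(
    "SELECT study_id FROM studies WHERE status = 'RECRUITING'"
).fetchall()
print(rows)  # [('ONCO-2023-001',)]
conn.close()
```

Loading first and indexing afterwards is the usual order for bulk seeds; it avoids paying index-maintenance cost on every insert.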
ops/scripts/ingest.py ADDED
@@ -0,0 +1,71 @@
+ import os
+ import sys
+ import logging
+ from sqlalchemy import create_engine
+
+ # Add project root to path to allow imports from mcp
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ sys.path.append(project_root)
+
+ from core.discovery import discover_schema
+ from core.graph import GraphStore
+ from core.config import SQLITE_DATA_DIR
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ def ingest_sqlite_database(db_file: str, graph_store: GraphStore):
+     """Discovers schema from a SQLite DB and ingests it into Neo4j."""
+     db_path = os.path.join(SQLITE_DATA_DIR, db_file)
+     logger.info(f"Processing database: {db_path}")
+
+     if not os.path.exists(db_path):
+         logger.error(f"Database file not found: {db_path}")
+         return
+
+     try:
+         engine = create_engine(f"sqlite:///{db_path}")
+         schema_data = discover_schema(engine)
+
+         if schema_data:
+             logger.info(f"Discovered schema for {db_file}, ingesting into Neo4j...")
+             graph_store.import_schema(schema_data)
+             logger.info(f"Successfully ingested schema for {db_file}")
+         else:
+             logger.warning(f"Could not discover schema for {db_file}. Skipping.")
+
+     except Exception as e:
+         logger.error(f"An error occurred while processing {db_file}: {e}")
+
+ def main():
+     """
+     Main function to run the ingestion process for all SQLite databases
+     found in the data directory.
+     """
+     logger.info("Starting schema ingestion process...")
+
+     if not os.path.exists(SQLITE_DATA_DIR) or not os.path.isdir(SQLITE_DATA_DIR):
+         logger.error(f"Data directory not found: {SQLITE_DATA_DIR}")
+         return
+
+     db_files = [f for f in os.listdir(SQLITE_DATA_DIR) if f.endswith(".db")]
+
+     if not db_files:
+         logger.warning(f"No SQLite database files (.db) found in {SQLITE_DATA_DIR}.")
+         return
+
+     try:
+         graph_store = GraphStore()
+         logger.info("Successfully connected to Neo4j.")
+     except Exception as e:
+         logger.error(f"Failed to connect to Neo4j. Aborting ingestion. Error: {e}")
+         return
+
+     for db_file in db_files:
+         ingest_sqlite_database(db_file, graph_store)
+
+     graph_store.close()
+     logger.info("Schema ingestion process completed.")
+
+ if __name__ == "__main__":
+     main()
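The ingest script leans on `discover_schema(engine)` to read tables and columns out of each SQLite file. The sketch below shows the underlying idea with only the standard library: list tables from `sqlite_master` and columns via `PRAGMA table_info`. It is an illustrative stand-in, not the project's actual `discover_schema` (which presumably uses SQLAlchemy's inspector).

```python
import sqlite3

# Build a {table: [(column, declared_type), ...]} map from a SQLite DB.
# In-memory DB with one table so the sketch is self-contained.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE patients (patient_id TEXT PRIMARY KEY, age INTEGER)")

schema = {}
tables = conn.execute(
    "SELECT name FROM sqlite_master WHERE type = 'table'"
).fetchall()
for (table,) in tables:
    # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
    cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
    schema[table] = [(c[1], c[2]) for c in cols]

print(schema)  # {'patients': [('patient_id', 'TEXT'), ('age', 'INTEGER')]}
conn.close()
```

A map of this shape is roughly what the Neo4j side needs to create `Table` and `Column` nodes with `HAS_COLUMN` relationships.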
postgres/init.sql DELETED
@@ -1,28 +0,0 @@
- -- Create sample tables for testing
- CREATE TABLE customers (
-     id SERIAL PRIMARY KEY,
-     email VARCHAR(255) UNIQUE NOT NULL,
-     first_name VARCHAR(100),
-     last_name VARCHAR(100),
-     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
- );
-
- CREATE TABLE orders (
-     id SERIAL PRIMARY KEY,
-     customer_id INTEGER REFERENCES customers(id),
-     order_date DATE NOT NULL,
-     total_amount DECIMAL(10,2),
-     status VARCHAR(50)
- );
-
- -- Insert sample data
- INSERT INTO customers (email, first_name, last_name) VALUES
- ('john.doe@email.com', 'John', 'Doe'),
- ('jane.smith@email.com', 'Jane', 'Smith'),
- ('bob.johnson@email.com', 'Bob', 'Johnson');
-
- INSERT INTO orders (customer_id, order_date, total_amount, status) VALUES
- (1, '2024-01-15', 99.99, 'completed'),
- (1, '2024-02-01', 149.99, 'completed'),
- (2, '2024-01-20', 79.99, 'pending'),
- (3, '2024-02-10', 199.99, 'completed');
streamlit/app.py CHANGED
@@ -7,478 +7,156 @@ All database access MUST go through MCP server - no direct connections allowed.

  import streamlit as st
  import requests
- import time
  import json
  import pandas as pd
- from datetime import datetime, timedelta
- import os
- from typing import Dict, Any, Optional, Tuple

- # Configuration
  MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
  MCP_API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")

- # Page configuration
  st.set_page_config(
-     page_title="MCP Monitor & Query Tester",
-     page_icon="🤖",
-     layout="wide",
-     initial_sidebar_state="expanded"
  )

- # Initialize session state
- if 'workflow_id' not in st.session_state:
-     st.session_state.workflow_id = None
- if 'debug_log' not in st.session_state:
-     st.session_state.debug_log = []
- if 'last_refresh' not in st.session_state:
-     st.session_state.last_refresh = None

- def call_mcp(tool: str, params: Optional[Dict[str, Any]] = None) -> Tuple[Dict[str, Any], int]:
-     """
-     Call MCP server - the ONLY way to access databases.
-     Returns (response_data, response_time_ms)
-     """
-     start_time = time.time()
-
      try:
          response = requests.post(
-             MCP_URL,
-             headers={
-                 "X-API-Key": MCP_API_KEY,
-                 "Content-Type": "application/json"
-             },
-             json={"tool": tool, "params": params or {}},
-             timeout=10
          )
-
-         response_time = int((time.time() - start_time) * 1000)
-
-         # Log the request/response for debugging
-         debug_entry = {
-             "timestamp": datetime.now().isoformat(),
-             "tool": tool,
-             "params": params,
-             "status_code": response.status_code,
-             "response_time_ms": response_time,
-             "success": response.status_code == 200
-         }
-         st.session_state.debug_log.append(debug_entry)
-
-         # Keep only last 5 entries
-         if len(st.session_state.debug_log) > 5:
-             st.session_state.debug_log = st.session_state.debug_log[-5:]
-
-         if response.status_code == 200:
-             return response.json(), response_time
-         else:
-             return {"error": f"HTTP {response.status_code}: {response.text}"}, response_time

-     except requests.exceptions.RequestException as e:
-         response_time = int((time.time() - start_time) * 1000)
-         error_msg = f"MCP Server Error: {str(e)}"
-
-         # Log the error
-         debug_entry = {
-             "timestamp": datetime.now().isoformat(),
-             "tool": tool,
-             "params": params,
-             "status_code": 0,
-             "response_time_ms": response_time,
-             "success": False,
-             "error": error_msg
-         }
-         st.session_state.debug_log.append(debug_entry)
-
-         return {"error": error_msg}, response_time

- def test_neo4j_connection() -> Tuple[bool, int, str]:
-     """Test Neo4j connection through MCP server"""
-     result, response_time = call_mcp("get_schema")
-     if "error" in result:
-         return False, response_time, result["error"]
-     return True, response_time, "Connected"

- def test_postgres_connection() -> Tuple[bool, int, str]:
-     """Test PostgreSQL connection through MCP server"""
-     result, response_time = call_mcp("query_postgres", {"query": "SELECT 1 as test"})
-     if "error" in result:
-         return False, response_time, result["error"]
-     return True, response_time, "Connected"

- def test_mcp_server() -> Tuple[bool, int, str]:
-     """Test MCP server health"""
      try:
-         start_time = time.time()
-         response = requests.get(f"{MCP_URL.replace('/mcp', '/health')}", timeout=5)
-         response_time = int((time.time() - start_time) * 1000)

-         if response.status_code == 200:
-             return True, response_time, "Healthy"
-         else:
-             return False, response_time, f"HTTP {response.status_code}"
-     except Exception as e:
-         return False, 0, str(e)

- def get_performance_stats() -> Dict[str, Any]:
-     """Get performance statistics through MCP"""
-     result, _ = call_mcp("query_graph", {
-         "query": "MATCH (l:Log) WHERE l.timestamp > datetime() - duration('PT1H') RETURN count(l) as count"
-     })
-
-     if "error" in result:
-         return {"error": result["error"]}
-
-     return result.get("data", [{}])[0] if result.get("data") else {}

- def create_workflow(question: str) -> Optional[str]:
-     """Create a new workflow for the given question"""
-     workflow_id = f"streamlit-{int(time.time())}"
-
-     # Create workflow node
-     workflow_result, _ = call_mcp("write_graph", {
-         "action": "create_node",
-         "label": "Workflow",
-         "properties": {
-             "id": workflow_id,
-             "name": f"Streamlit Query: {question[:50]}...",
-             "description": f"Query from Streamlit: {question}",
-             "status": "active",
-             "created_at": datetime.now().isoformat(),
-             "source": "streamlit"
-         }
-     })
-
-     if "error" in workflow_result:
-         st.error(f"Failed to create workflow: {workflow_result['error']}")
-         return None
-
-     # Create instruction sequence
-     instructions = [
-         {
-             "id": f"{workflow_id}-inst-1",
-             "type": "discover_schema",
-             "sequence": 1,
-             "description": "Discover database schema",
-             "status": "pending",
-             "pause_duration": 5,  # Short pause for testing
-             "parameters": "{}"
-         },
-         {
-             "id": f"{workflow_id}-inst-2",
-             "type": "generate_sql",
-             "sequence": 2,
-             "description": f"Generate SQL for: {question}",
-             "status": "pending",
-             "pause_duration": 5,
-             "parameters": json.dumps({"question": question})
-         },
-         {
-             "id": f"{workflow_id}-inst-3",
-             "type": "review_results",
-             "sequence": 3,
-             "description": "Review and format results",
-             "status": "pending",
-             "pause_duration": 0,
-             "parameters": "{}"
-         }
-     ]
-
-     # Create instruction nodes
-     for inst in instructions:
-         inst_result, _ = call_mcp("write_graph", {
-             "action": "create_node",
-             "label": "Instruction",
-             "properties": inst
-         })

-         if "error" in inst_result:
-             st.error(f"Failed to create instruction: {inst_result['error']}")
-             return None

-         # Link instruction to workflow
-         link_result, _ = call_mcp("query_graph", {
-             "query": """
-                 MATCH (w:Workflow {id: $workflow_id}), (i:Instruction {id: $inst_id})
-                 CREATE (w)-[:HAS_INSTRUCTION]->(i)
-             """,
-             "parameters": {"workflow_id": workflow_id, "inst_id": inst["id"]}
-         })
-
-     # Create instruction chain
-     for i in range(len(instructions) - 1):
-         chain_result, _ = call_mcp("query_graph", {
-             "query": """
-                 MATCH (i1:Instruction {id: $id1}), (i2:Instruction {id: $id2})
-                 CREATE (i1)-[:NEXT_INSTRUCTION]->(i2)
-             """,
-             "parameters": {"id1": instructions[i]["id"], "id2": instructions[i + 1]["id"]}
-         })
-
-     return workflow_id

- def get_workflow_status(workflow_id: str) -> Dict[str, Any]:
-     """Get workflow execution status"""
-     result, _ = call_mcp("query_graph", {
-         "query": """
-             MATCH (w:Workflow {id: $id})-[:HAS_INSTRUCTION]->(i:Instruction)
-             RETURN w.status as workflow_status,
-                    collect(i.status) as instruction_statuses,
-                    collect(i.type) as instruction_types,
-                    collect(i.sequence) as sequences
-         """,
-         "parameters": {"id": workflow_id}
-     })
-
-     if "error" in result or not result.get("data"):
-         return {"error": "Workflow not found"}
-
-     return result["data"][0]

- def get_workflow_results(workflow_id: str) -> Dict[str, Any]:
-     """Get workflow execution results"""
-     result, _ = call_mcp("query_graph", {
-         "query": """
-             MATCH (w:Workflow {id: $id})-[:HAS_INSTRUCTION]->(i:Instruction)-[:EXECUTED_AS]->(e:Execution)
-             RETURN i.sequence as sequence,
-                    i.type as type,
-                    i.description as description,
-                    e.result as result,
-                    e.started_at as started_at,
-                    e.completed_at as completed_at
-             ORDER BY i.sequence
-         """,
-         "parameters": {"id": workflow_id}
-     })
-
-     if "error" in result:
-         return {"error": result["error"]}
-
-     return {"executions": result.get("data", [])}

- def get_schema_context() -> str:
-     """Get database schema context for display"""
-     result, _ = call_mcp("query_graph", {
-         "query": """
-             MATCH (t:Table)-[:HAS_COLUMN]->(c:Column)
-             RETURN t.name as table_name,
-                    collect({name: c.name, type: c.data_type, nullable: c.nullable}) as columns
-             ORDER BY t.name
-         """
-     })
-
-     if "error" in result:
-         return f"Error fetching schema: {result['error']}"
-
-     schema_text = "Database Schema:\n"
-     for record in result.get("data", []):
-         table_name = record["table_name"]
-         columns = record["columns"]
-         schema_text += f"\nTable: {table_name}\n"
-         for col in columns:
-             nullable = "NULL" if col["nullable"] else "NOT NULL"
-             schema_text += f"  - {col['name']}: {col['type']} {nullable}\n"
-
-     return schema_text

- def main():
-     st.title("🤖 MCP Monitor & Query Tester")
-     st.caption("Monitor agentic system health and test queries through MCP server")
-
-     # Sidebar
-     with st.sidebar:
-         st.header("🔧 Configuration")
-         st.code(f"MCP URL: {MCP_URL}")
-         st.code(f"API Key: {MCP_API_KEY[:10]}...")
-
-         if st.button("🔄 Refresh All", type="primary"):
-             st.rerun()
-
-         st.header("📊 Quick Stats")
-         stats = get_performance_stats()
-         if "error" not in stats:
-             st.metric("Logs (1h)", stats.get("count", 0))
-         else:
-             st.error(f"Stats error: {stats['error']}")
-
-     # Main tabs
-     tab1, tab2 = st.tabs(["🔌 Connection Status", "🤖 Query Tester"])
-
-     with tab1:
-         st.header("Connection Status Monitor")
-         st.caption("All database access goes through MCP server - no direct connections allowed")
-
-         # Connection status in columns
-         col1, col2, col3 = st.columns(3)
-
-         with col1:
-             st.subheader("Neo4j (via MCP)")
-             neo4j_ok, neo4j_time, neo4j_msg = test_neo4j_connection()
-             st.metric(
-                 label="Status",
-                 value="Online" if neo4j_ok else "Offline",
-                 delta=f"{neo4j_time}ms"
-             )
-             if neo4j_ok:
-                 st.success(neo4j_msg)
-             else:
-                 st.error(neo4j_msg)
-
-         with col2:
-             st.subheader("PostgreSQL (via MCP)")
-             postgres_ok, postgres_time, postgres_msg = test_postgres_connection()
-             st.metric(
-                 label="Status",
-                 value="Online" if postgres_ok else "Offline",
-                 delta=f"{postgres_time}ms"
-             )
-             if postgres_ok:
-                 st.success(postgres_msg)
-             else:
-                 st.error(postgres_msg)
-
-         with col3:
-             st.subheader("MCP Server")
-             mcp_ok, mcp_time, mcp_msg = test_mcp_server()
-             st.metric(
-                 label="Status",
-                 value="Online" if mcp_ok else "Offline",
-                 delta=f"{mcp_time}ms"
-             )
-             if mcp_ok:
-                 st.success(mcp_msg)
-             else:
-                 st.error(mcp_msg)
-
-         # Performance stats
-         st.subheader("Performance Statistics")
-         stats = get_performance_stats()
-         if "error" not in stats:
-             st.info(f"Operations in last hour: {stats.get('count', 0)}")
-         else:
-             st.error(f"Cannot fetch stats: {stats['error']}")
-
-         # Auto-refresh info
-         st.session_state.last_refresh = datetime.now()
-         st.caption(f"Last checked: {st.session_state.last_refresh.strftime('%H:%M:%S')}")
-
-         # Auto-refresh every 5 seconds
-         time.sleep(5)
-         st.rerun()
-
-     with tab2:
-         st.header("Query Tester")
-         st.caption("Test natural language queries through the agentic engine")
-
-         # Query input
-         question = st.text_area(
-             "Enter your question:",
-             height=100,
-             placeholder="e.g., 'How many customers do we have?' or 'Show me all orders from last month'"
-         )
-
-         col1, col2 = st.columns([1, 1])
-
-         with col1:
-             if st.button("🚀 Execute Query", type="primary", disabled=not question.strip()):
-                 if question.strip():
-                     with st.spinner("Creating workflow..."):
-                         workflow_id = create_workflow(question.strip())
-                         if workflow_id:
-                             st.session_state.workflow_id = workflow_id
-                             st.success(f"Workflow created: {workflow_id}")
-                         else:
-                             st.error("Failed to create workflow")
-
-         with col2:
-             if st.button("🗑️ Clear Results"):
-                 st.session_state.workflow_id = None
-                 st.rerun()
-
-         # Workflow execution monitoring
-         if st.session_state.workflow_id:
-             st.subheader("Execution Progress")
-
-             # Get workflow status
-             status = get_workflow_status(st.session_state.workflow_id)

-             if "error" in status:
-                 st.error(f"Status error: {status['error']}")
-             else:
-                 workflow_status = status.get("workflow_status", "unknown")
-                 instruction_statuses = status.get("instruction_statuses", [])
-                 instruction_types = status.get("instruction_types", [])
-
-                 # Progress bar
-                 completed = sum(1 for s in instruction_statuses if s == "complete")
-                 total = len(instruction_statuses)
-                 progress = completed / total if total > 0 else 0

-                 st.progress(progress)
-                 st.caption(f"Progress: {completed}/{total} instructions completed")

-                 # Status display
-                 status_cols = st.columns(len(instruction_types))
-                 for i, (inst_type, inst_status) in enumerate(zip(instruction_types, instruction_statuses)):
-                     with status_cols[i]:
-                         if inst_status == "complete":
-                             st.success(f" {inst_type}")
-                         elif inst_status == "executing":
-                             st.warning(f"🔄 {inst_type}")
-                         elif inst_status == "failed":
-                             st.error(f"❌ {inst_type}")
-                         else:
-                             st.info(f"⏳ {inst_type}")
-
-                 # Get and display results
-                 if completed > 0:
-                     results = get_workflow_results(st.session_state.workflow_id)
-
-                     if "error" not in results:
-                         st.subheader("Execution Results")
-
-                         for execution in results.get("executions", []):
-                             with st.expander(f"Step {execution['sequence']}: {execution['type']}"):
-                                 st.write(f"**Description:** {execution['description']}")
-
-                                 if execution['started_at'] and execution['completed_at']:
-                                     start = datetime.fromisoformat(execution['started_at'].replace('Z', '+00:00'))
-                                     end = datetime.fromisoformat(execution['completed_at'].replace('Z', '+00:00'))
-                                     duration = (end - start).total_seconds()
-                                     st.write(f"**Duration:** {duration:.2f} seconds")
-
-                                 if execution['result']:
-                                     try:
-                                         result_data = json.loads(execution['result']) if isinstance(execution['result'], str) else execution['result']
-
-                                         if execution['type'] == 'generate_sql' and 'generated_sql' in result_data:
-                                             st.write("**Generated SQL:**")
-                                             st.code(result_data['generated_sql'], language='sql')
-
-                                         if 'data' in result_data and result_data['data']:
-                                             st.write("**Query Results:**")
-                                             df = pd.DataFrame(result_data['data'])
-                                             st.dataframe(df)
-
-                                         if 'error' in result_data:
-                                             st.error(f"Error: {result_data['error']}")
-
-                                     except Exception as e:
-                                         st.write("**Raw Result:**")
-                                         st.code(str(execution['result']))
-                     else:
-                         st.error(f"Results error: {results['error']}")

-     # Debug information
-     with st.expander("🔧 Debug Information"):
-         st.write("**Last 5 MCP Requests:**")
-         for entry in st.session_state.debug_log:
-             status_icon = "✅" if entry["success"] else "❌"
-             st.write(f"{status_icon} {entry['timestamp']} - {entry['tool']} ({entry['response_time_ms']}ms)")
-             if not entry["success"] and "error" in entry:
-                 st.error(f"Error: {entry['error']}")
-
-     st.write("**Important:** All database operations go through MCP server. Direct database access is not permitted.")

  if __name__ == "__main__":
      main()
 
 
 import streamlit as st
 import requests
+import os
 import json
 import pandas as pd
+from typing import Dict, Any
 
+# --- Configuration ---
+AGENT_URL = os.getenv("AGENT_URL", "http://agent:8001/query")
+NEO4J_URL = os.getenv("NEO4J_URL", "http://neo4j:7474")
 MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
 MCP_API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
 
 st.set_page_config(
+    page_title="GraphRAG Chat",
+    page_icon="💬",
+    layout="wide"
 )
 
+# --- Session State ---
+if 'messages' not in st.session_state:
+    st.session_state.messages = []
+if 'schema_info' not in st.session_state:
+    st.session_state.schema_info = ""
+
+# --- Helper Functions ---
+def stream_agent_response(question: str):
+    """Streams the agent's response, yielding JSON objects."""
+    try:
+        with requests.post(AGENT_URL, json={"question": question}, stream=True, timeout=300) as r:
+            r.raise_for_status()
+            for chunk in r.iter_content(chunk_size=None):
+                if chunk:
+                    try:
+                        yield json.loads(chunk)
+                    except json.JSONDecodeError:
+                        # Handle potential parsing errors if chunks are not perfect JSON
+                        st.warning(f"Could not decode JSON chunk: {chunk}")
+                        continue
+    except requests.exceptions.RequestException as e:
+        yield {"error": f"Failed to connect to agent: {e}"}
 
 
+def fetch_schema_info() -> str:
+    """Fetches the database schema from the MCP server for display."""
     try:
         response = requests.post(
+            f"{MCP_URL}/discovery/get_relevant_schemas",
+            headers={"Authorization": f"Bearer {MCP_API_KEY}", "Content-Type": "application/json"},
+            json={"query": ""}
         )
+        response.raise_for_status()
+        data = response.json()
+
+        if data.get("status") == "success":
+            schemas = data.get("schemas", [])
+            if not schemas: return "No schema information found."
 
+            # Group columns by table
+            tables = {}
+            for s in schemas:
+                table_key = f"{s.get('database', '')}.{s.get('table', '')}"
+                if table_key not in tables:
+                    tables[table_key] = []
+                tables[table_key].append(f"{s.get('name', '')} ({s.get('type', [''])[0]})")
+
+            schema_text = ""
+            for table, columns in tables.items():
+                schema_text += f"**{table}**:\n"
+                for col in columns:
+                    schema_text += f"- {col}\n"
+            return schema_text
+        else:
+            return f"Error from MCP: {data.get('message', 'Unknown error')}"
 
+    except requests.exceptions.RequestException as e:
+        return f"Could not fetch schema: {e}"
 
+@st.cache_data(ttl=600)
+def get_cached_schema():
+    """Cache the schema info to avoid repeated calls."""
+    return fetch_schema_info()
 
+def check_service_health(service_name: str, url: str) -> bool:
+    """Checks if a service is reachable."""
     try:
+        response = requests.get(url, timeout=5)
+        return response.status_code in [200, 401]
+    except requests.exceptions.RequestException:
+        return False
+
+# --- UI Components ---
+def display_sidebar():
+    with st.sidebar:
+        st.title("🗄️ Database Schema")
 
+        if st.button("🔄 Refresh Schema"):
+            st.cache_data.clear()
 
+        st.session_state.schema_info = get_cached_schema()
+        st.markdown(st.session_state.schema_info)
 
+        st.markdown("---")
+        st.title("🔌 Service Status")
+
+        neo4j_status = "✅ Online" if check_service_health("Neo4j", NEO4J_URL) else "❌ Offline"
+        mcp_status = "✅ Online" if check_service_health("MCP", MCP_URL.replace("/mcp", "/health")) else "❌ Offline"
 
+        st.markdown(f"**Neo4j:** {neo4j_status}")
+        st.markdown(f"**MCP Server:** {mcp_status}")
 
+        st.markdown("---")
+        if st.button("🗑️ Clear Chat History"):
+            st.session_state.messages = []
+            st.rerun()
 
+def main():
+    display_sidebar()
+    st.title("💬 GraphRAG Conversational Agent")
+    st.markdown("Ask questions about the life sciences dataset. The agent's thought process will be shown below.")
 
+    # Display chat history
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
 
+    if prompt := st.chat_input("Ask your question here..."):
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        with st.chat_message("user"):
+            st.markdown(prompt)
 
+        with st.chat_message("assistant"):
+            full_response = ""
+            response_box = st.empty()
 
+            for chunk in stream_agent_response(prompt):
+                if "error" in chunk:
+                    full_response = chunk["error"]
+                    response_box.error(full_response)
+                    break
 
+                content = chunk.get("content", "")
 
+                if chunk.get("type") == "thought":
+                    full_response += f"🤔 *{content}*\n\n"
+                elif chunk.get("type") == "observation":
+                    full_response += f"{content}\n\n"
+                elif chunk.get("type") == "final_answer":
+                    full_response += f"**Final Answer:**\n{content}"
+
+                response_box.markdown(full_response)
 
+            st.session_state.messages.append({"role": "assistant", "content": full_response})
 
 if __name__ == "__main__":
     main()