First pass at merging the code bases (does not run yet)
Browse files- .gitignore +3 -0
- Makefile +9 -3
- README.md +79 -262
- agent/main.py +115 -416
- agent/requirements.txt +8 -4
- agent/tools.py +150 -0
- docker-compose.yml +6 -42
- mcp/__init__.py +0 -0
- mcp/core/__init__.py +0 -0
- mcp/core/config.py +24 -0
- mcp/core/database.py +26 -0
- mcp/core/discovery.py +98 -0
- mcp/core/graph.py +151 -0
- mcp/core/intelligence.py +161 -0
- mcp/requirements.txt +2 -1
- ops/scripts/generate_sample_databases.py +160 -0
- ops/scripts/ingest.py +71 -0
- postgres/init.sql +0 -28
- streamlit/app.py +119 -441
.gitignore
CHANGED
|
@@ -3,3 +3,6 @@
|
|
| 3 |
/neo4j/data
|
| 4 |
/postgres/data
|
| 5 |
/neo4j/logs
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
/neo4j/data
|
| 4 |
/postgres/data
|
| 5 |
/neo4j/logs
|
| 6 |
+
.env
|
| 7 |
+
mcp/__pycache__
|
| 8 |
+
mcp/core/__pycache__
|
Makefile
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
.PHONY: up down logs seed clean health test demo
|
| 2 |
|
| 3 |
# Start all services
|
| 4 |
up:
|
|
@@ -19,12 +19,19 @@ logs:
|
|
| 19 |
seed:
|
| 20 |
docker-compose exec mcp python /app/ops/scripts/seed.py
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
# Clean everything (including volumes)
|
| 23 |
clean:
|
| 24 |
docker-compose down -v
|
| 25 |
docker system prune -f
|
| 26 |
@if [ -d "neo4j/data" ]; then rm -rf neo4j/data; fi
|
| 27 |
-
@if [ -d "postgres/data" ]; then rm -rf postgres/data; fi
|
| 28 |
@if [ -d "frontend/.next" ]; then rm -rf frontend/.next; fi
|
| 29 |
@if [ -d "frontend/node_modules" ]; then rm -rf frontend/node_modules; fi
|
| 30 |
|
|
@@ -32,7 +39,6 @@ clean:
|
|
| 32 |
health:
|
| 33 |
@echo "Checking service health..."
|
| 34 |
@docker-compose exec neo4j cypher-shell -u neo4j -p password "MATCH (n) RETURN count(n) LIMIT 1" > /dev/null 2>&1 && echo "β
Neo4j: Healthy" || echo "β Neo4j: Unhealthy"
|
| 35 |
-
@docker-compose exec postgres pg_isready -U postgres > /dev/null 2>&1 && echo "β
PostgreSQL: Healthy" || echo "β PostgreSQL: Unhealthy"
|
| 36 |
@curl -s http://localhost:8000/health > /dev/null && echo "β
MCP Server: Healthy" || echo "β MCP Server: Unhealthy"
|
| 37 |
@curl -s http://localhost:3000 > /dev/null && echo "β
Frontend: Healthy" || echo "β Frontend: Unhealthy"
|
| 38 |
@curl -s http://localhost:8501 > /dev/null && echo "β
Streamlit: Healthy" || echo "β Streamlit: Unhealthy"
|
|
|
|
| 1 |
+
.PHONY: up down logs seed seed-db ingest clean health test demo
|
| 2 |
|
| 3 |
# Start all services
|
| 4 |
up:
|
|
|
|
| 19 |
seed:
|
| 20 |
docker-compose exec mcp python /app/ops/scripts/seed.py
|
| 21 |
|
| 22 |
+
# Seed the SQLite databases from scratch
|
| 23 |
+
seed-db:
|
| 24 |
+
python ops/scripts/generate_sample_databases.py
|
| 25 |
+
|
| 26 |
+
# Ingest SQLite database schemas into Neo4j
|
| 27 |
+
ingest:
|
| 28 |
+
docker-compose run --rm mcp python /app/ops/scripts/ingest.py
|
| 29 |
+
|
| 30 |
# Clean everything (including volumes)
|
| 31 |
clean:
|
| 32 |
docker-compose down -v
|
| 33 |
docker system prune -f
|
| 34 |
@if [ -d "neo4j/data" ]; then rm -rf neo4j/data; fi
|
|
|
|
| 35 |
@if [ -d "frontend/.next" ]; then rm -rf frontend/.next; fi
|
| 36 |
@if [ -d "frontend/node_modules" ]; then rm -rf frontend/node_modules; fi
|
| 37 |
|
|
|
|
| 39 |
health:
|
| 40 |
@echo "Checking service health..."
|
| 41 |
@docker-compose exec neo4j cypher-shell -u neo4j -p password "MATCH (n) RETURN count(n) LIMIT 1" > /dev/null 2>&1 && echo "✅ Neo4j: Healthy" || echo "❌ Neo4j: Unhealthy"
|
|
|
|
| 42 |
@curl -s http://localhost:8000/health > /dev/null && echo "✅ MCP Server: Healthy" || echo "❌ MCP Server: Unhealthy"
|
| 43 |
@curl -s http://localhost:3000 > /dev/null && echo "✅ Frontend: Healthy" || echo "❌ Frontend: Unhealthy"
|
| 44 |
@curl -s http://localhost:8501 > /dev/null && echo "✅ Streamlit: Healthy" || echo "❌ Streamlit: Unhealthy"
|
README.md
CHANGED
|
@@ -1,74 +1,63 @@
|
|
| 1 |
-
#
|
| 2 |
-
"Keep your data where it is but we will treat it like a graph for you and solve these problems for you"
|
| 3 |
|
| 4 |
## Overview
|
| 5 |
-
|
| 6 |
|
| 7 |
## Key Features
|
| 8 |
|
| 9 |
-
π€ **
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
π **Real-time Visualization**: Live workflow progress in browser interface
|
| 15 |
-
π **Complete Audit Trail**: Every action logged with timestamps and relationships
|
| 16 |
|
| 17 |
## Architecture
|
| 18 |
|
| 19 |
```
|
| 20 |
-
|
| 21 |
-
β
|
| 22 |
-
β
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
Use MCP inspector and also have an MCP server that automatically checks the logs in the inspector: https://modelcontextprotocol.io/docs/tools/inspector
|
| 37 |
|
| 38 |
-
|
| 39 |
-
pydantic is key here and now worrying about frontend until demo time
|
| 40 |
|
| 41 |
### Components
|
| 42 |
|
| 43 |
-
- **
|
| 44 |
-
- **
|
| 45 |
-
- **
|
| 46 |
-
- **
|
| 47 |
-
- **
|
| 48 |
|
| 49 |
## Quick Start
|
| 50 |
|
| 51 |
### Prerequisites
|
| 52 |
- Docker & Docker Compose
|
| 53 |
-
- OpenAI
|
| 54 |
|
| 55 |
### Setup
|
| 56 |
1. **Clone and configure**:
|
| 57 |
```bash
|
| 58 |
git clone <repository-url>
|
| 59 |
cd <repository-name>
|
| 60 |
-
|
| 61 |
```
|
| 62 |
|
| 63 |
-
2. **Add your
|
| 64 |
-
```
|
| 65 |
-
|
| 66 |
-
LLM_API_KEY=sk-your-openai-key-here
|
| 67 |
-
LLM_MODEL=gpt-4
|
| 68 |
-
|
| 69 |
-
# For Anthropic
|
| 70 |
-
LLM_API_KEY=your-anthropic-key-here
|
| 71 |
-
LLM_MODEL=claude-3-sonnet-20240229
|
| 72 |
```
|
| 73 |
|
| 74 |
3. **Start the system**:
|
|
@@ -76,235 +65,63 @@ pydantic is key here and now worrying about frontend until demo time
|
|
| 76 |
make up
|
| 77 |
```
|
| 78 |
|
| 79 |
-
4. **Seed
|
| 80 |
```bash
|
| 81 |
-
make seed
|
|
|
|
| 82 |
```
|
| 83 |
|
| 84 |
5. **Open the interface**:
|
| 85 |
-
-
|
| 86 |
- Neo4j Browser: http://localhost:7474 (neo4j/password)
|
| 87 |
|
| 88 |
## Usage
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
2.
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
- Reviews and formats results
|
| 101 |
-
|
| 102 |
-
3. **Human intervention** (during 5-minute pauses):
|
| 103 |
-
- Edit instructions in Neo4j Browser
|
| 104 |
-
- Change parameters or questions
|
| 105 |
-
- Stop workflows if needed
|
| 106 |
-
|
| 107 |
-
### Editing Instructions During Pause
|
| 108 |
-
|
| 109 |
-
When the agent pauses, you can modify instructions in Neo4j Browser:
|
| 110 |
-
|
| 111 |
-
```cypher
|
| 112 |
-
// Change the question being asked
|
| 113 |
-
MATCH (i:Instruction {status: 'pending'})
|
| 114 |
-
SET i.parameters = '{"question": "Show me customers from the last month"}'
|
| 115 |
-
|
| 116 |
-
// Stop the entire workflow
|
| 117 |
-
MATCH (w:Workflow {status: 'active'})
|
| 118 |
-
SET w.status = 'stopped'
|
| 119 |
-
```
|
| 120 |
-
|
| 121 |
-
### Monitoring
|
| 122 |
-
|
| 123 |
-
Check system health:
|
| 124 |
-
```bash
|
| 125 |
-
make health
|
| 126 |
-
```
|
| 127 |
-
|
| 128 |
-
View real-time logs:
|
| 129 |
-
```bash
|
| 130 |
-
make logs
|
| 131 |
-
```
|
| 132 |
-
|
| 133 |
-
Check specific service:
|
| 134 |
-
```bash
|
| 135 |
-
make debug-agent # Agent logs
|
| 136 |
-
make debug-mcp # MCP server logs
|
| 137 |
-
make debug-frontend # Frontend logs
|
| 138 |
-
```
|
| 139 |
-
|
| 140 |
-
## Commands Reference
|
| 141 |
-
|
| 142 |
-
| Command | Description |
|
| 143 |
-
|---------|-------------|
|
| 144 |
-
| `make up` | Start all services |
|
| 145 |
-
| `make down` | Stop all services |
|
| 146 |
-
| `make clean` | Remove all data and containers |
|
| 147 |
-
| `make health` | Check service health |
|
| 148 |
-
| `make seed` | Create demo data |
|
| 149 |
-
| `make logs` | View all logs |
|
| 150 |
-
| `make demo` | Full clean + setup + seed |
|
| 151 |
-
| `make test` | Run integration test |
|
| 152 |
-
|
| 153 |
-
## Configuration
|
| 154 |
-
|
| 155 |
-
### Environment Variables
|
| 156 |
-
|
| 157 |
-
All configuration is in `.env` file:
|
| 158 |
-
|
| 159 |
-
- **Neo4j**: Database connection and auth
|
| 160 |
-
- **PostgreSQL**: Sample data source
|
| 161 |
-
- **MCP**: API keys and server settings
|
| 162 |
-
- **Agent**: Polling interval and pause duration
|
| 163 |
-
- **LLM**: API key and model selection
|
| 164 |
-
|
| 165 |
-
### Pause Duration
|
| 166 |
-
|
| 167 |
-
Default: 5 minutes (300 seconds)
|
| 168 |
-
Configurable via `PAUSE_DURATION` in `.env`
|
| 169 |
-
|
| 170 |
-
### Polling Interval
|
| 171 |
-
|
| 172 |
-
Default: 30 seconds
|
| 173 |
-
Configurable via `AGENT_POLL_INTERVAL` in `.env`
|
| 174 |
|
| 175 |
## Development
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
### File Structure
|
| 178 |
```
|
| 179 |
-
βββ agent/ #
|
| 180 |
-
βββ
|
| 181 |
-
βββ mcp/
|
| 182 |
-
βββ neo4j/
|
| 183 |
-
βββ
|
| 184 |
-
βββ ops/
|
| 185 |
βββ docker-compose.yml
|
| 186 |
βββ Makefile
|
| 187 |
βββ README.md
|
| 188 |
-
```
|
| 189 |
-
|
| 190 |
-
### Adding New Instruction Types
|
| 191 |
-
|
| 192 |
-
1. **Define handler in agent**:
|
| 193 |
-
```python
|
| 194 |
-
def handle_new_instruction_type(instruction):
|
| 195 |
-
# Implementation
|
| 196 |
-
return {"status": "success", "result": "..."}
|
| 197 |
-
```
|
| 198 |
-
|
| 199 |
-
2. **Add to agent main loop**:
|
| 200 |
-
```python
|
| 201 |
-
elif instruction['type'] == 'new_instruction_type':
|
| 202 |
-
exec_result = handle_new_instruction_type(instruction)
|
| 203 |
-
```
|
| 204 |
-
|
| 205 |
-
3. **Update frontend** to create new instruction types in workflows.
|
| 206 |
-
|
| 207 |
-
### Database Schema
|
| 208 |
-
|
| 209 |
-
The system uses two databases:
|
| 210 |
-
|
| 211 |
-
**Neo4j** (Workflow metadata):
|
| 212 |
-
- `Workflow` nodes with status and metadata
|
| 213 |
-
- `Instruction` nodes with type, sequence, parameters
|
| 214 |
-
- `Execution` nodes with results and timestamps
|
| 215 |
-
- Relationships: `HAS_INSTRUCTION`, `EXECUTED_AS`, `NEXT_INSTRUCTION`
|
| 216 |
-
|
| 217 |
-
**PostgreSQL** (Sample data):
|
| 218 |
-
- `customers` table
|
| 219 |
-
- `orders` table
|
| 220 |
-
- Sample data for testing SQL generation
|
| 221 |
-
|
| 222 |
-
## Troubleshooting
|
| 223 |
-
|
| 224 |
-
### Common Issues
|
| 225 |
-
|
| 226 |
-
**Services not starting**:
|
| 227 |
-
```bash
|
| 228 |
-
make down
|
| 229 |
-
make clean
|
| 230 |
-
make up
|
| 231 |
-
```
|
| 232 |
-
|
| 233 |
-
**Agent not processing**:
|
| 234 |
-
```bash
|
| 235 |
-
make restart-agent
|
| 236 |
-
make debug-agent
|
| 237 |
-
```
|
| 238 |
-
|
| 239 |
-
**Frontend not loading**:
|
| 240 |
-
```bash
|
| 241 |
-
make restart-frontend
|
| 242 |
-
make debug-frontend
|
| 243 |
-
```
|
| 244 |
-
|
| 245 |
-
**Database connection issues**:
|
| 246 |
-
```bash
|
| 247 |
-
make health
|
| 248 |
-
# Check .env configuration
|
| 249 |
-
```
|
| 250 |
-
|
| 251 |
-
### Debug Mode
|
| 252 |
-
|
| 253 |
-
For detailed logging, check individual service logs:
|
| 254 |
-
```bash
|
| 255 |
-
docker-compose logs -f agent
|
| 256 |
-
docker-compose logs -f mcp
|
| 257 |
-
docker-compose logs -f frontend
|
| 258 |
-
```
|
| 259 |
-
|
| 260 |
-
### Reset Everything
|
| 261 |
-
|
| 262 |
-
Complete clean slate:
|
| 263 |
-
```bash
|
| 264 |
-
make clean
|
| 265 |
-
cp .env.example .env
|
| 266 |
-
# Edit .env with your API key
|
| 267 |
-
make demo
|
| 268 |
-
```
|
| 269 |
-
|
| 270 |
-
## Contributing
|
| 271 |
-
|
| 272 |
-
1. Fork the repository
|
| 273 |
-
2. Create a feature branch
|
| 274 |
-
3. Test with `make demo`
|
| 275 |
-
4. Submit a pull request
|
| 276 |
-
|
| 277 |
-
## License
|
| 278 |
-
|
| 279 |
-
MIT License - see LICENSE file for details.
|
| 280 |
-
|
| 281 |
-
---
|
| 282 |
-
|
| 283 |
-
## Quick Demo
|
| 284 |
-
|
| 285 |
-
Want to see it in action immediately?
|
| 286 |
-
|
| 287 |
-
```bash
|
| 288 |
-
# 1. Clone repo
|
| 289 |
-
git clone <repo-url> && cd <repo-name>
|
| 290 |
-
|
| 291 |
-
# 2. Add your API key
|
| 292 |
-
cp .env.example .env
|
| 293 |
-
# Edit .env: LLM_API_KEY=your-key-here
|
| 294 |
-
|
| 295 |
-
# 3. Start everything
|
| 296 |
-
make demo
|
| 297 |
-
|
| 298 |
-
# 4. Open http://localhost:3000
|
| 299 |
-
# 5. Ask: "How many records are in the database?"
|
| 300 |
-
# 6. Watch the magic happen! β¨
|
| 301 |
-
```
|
| 302 |
-
|
| 303 |
-
The system will:
|
| 304 |
-
- Create a workflow with 3 instructions
|
| 305 |
-
- Pause for 5 minutes before each step (editable in Neo4j Browser)
|
| 306 |
-
- Generate SQL from your natural language question
|
| 307 |
-
- Execute the query and return formatted results
|
| 308 |
-
- Show the entire process in a visual graph
|
| 309 |
-
|
| 310 |
-
π **Welcome to the future of human-AI collaboration!**
|
|
|
|
| 1 |
+
# GraphRAG Agentic System
|
|
|
|
| 2 |
|
| 3 |
## Overview
|
| 4 |
+
This project implements an intelligent, multi-step GraphRAG-powered agent that uses LangChain to orchestrate complex queries against a federated life sciences dataset. The agent leverages a Neo4j graph database to understand the relationships between disparate SQLite databases, constructs SQL queries, and returns unified results through a conversational UI.
|
| 5 |
|
| 6 |
## Key Features
|
| 7 |
|
| 8 |
+
🤖 **LangChain Agent**: Orchestrates tools for schema discovery, pathfinding, and query execution.
|
| 9 |
+
🕸️ **GraphRAG Enabled**: Uses a Neo4j knowledge graph of database schemas for intelligent query planning.
|
| 10 |
+
🔬 **Life Sciences Dataset**: Comes with a rich dataset across clinical trials, drug discovery, and lab results.
|
| 11 |
+
💬 **Conversational UI**: A Streamlit-based chat interface for interacting with the agent.
|
| 12 |
+
🌐 **RESTful MCP Server**: All core logic is exposed via a secure and scalable FastAPI server.
|
|
|
|
|
|
|
| 13 |
|
| 14 |
## Architecture
|
| 15 |
|
| 16 |
```
|
| 17 |
+
βββββββββββββββββββ βββββββββββββββββ βββββββββββββββββββ
|
| 18 |
+
β Streamlit Chat ββββββββ Agent β β MCP Server β
|
| 19 |
+
β (UI) β β (LangChain) β β (FastAPI) β
|
| 20 |
+
βββββββββββββββββββ βββββββββββββββββ βββββββββββββββββββ
|
| 21 |
+
β
|
| 22 |
+
βββββββββββββββββββββββββΌββββββββββββββββββββββββ
|
| 23 |
+
β β β
|
| 24 |
+
βββββββββββββββ βββββββββββββββ βββββββββββββββ
|
| 25 |
+
β Neo4j β β clinical_ β β laboratory β
|
| 26 |
+
β (Schema KG) β β trials.db β β .db β
|
| 27 |
+
βββββββββββββββ βββββββββββββββ βββββββββββββββ
|
| 28 |
+
β
|
| 29 |
+
βββββββββββββββ
|
| 30 |
+
β drug_ β
|
| 31 |
+
β discovery.dbβ
|
| 32 |
+
βββββββββββββββ
|
|
|
|
| 33 |
|
| 34 |
+
```
|
|
|
|
| 35 |
|
| 36 |
### Components
|
| 37 |
|
| 38 |
+
- **Streamlit**: Provides a conversational chat interface for users to ask questions.
|
| 39 |
+
- **Agent**: A LangChain-powered orchestrator that uses custom tools to query the MCP server.
|
| 40 |
+
- **MCP Server**: A FastAPI application that exposes core logic for schema discovery, graph pathfinding, and federated query execution.
|
| 41 |
+
- **Neo4j**: Stores a knowledge graph of the schemas of all connected SQLite databases.
|
| 42 |
+
- **SQLite Databases**: A set of life sciences databases (`clinical_trials.db`, `drug_discovery.db`, `laboratory.db`) that serve as the federated data sources.
|
| 43 |
|
| 44 |
## Quick Start
|
| 45 |
|
| 46 |
### Prerequisites
|
| 47 |
- Docker & Docker Compose
|
| 48 |
+
- OpenAI API key
|
| 49 |
|
| 50 |
### Setup
|
| 51 |
1. **Clone and configure**:
|
| 52 |
```bash
|
| 53 |
git clone <repository-url>
|
| 54 |
cd <repository-name>
|
| 55 |
+
touch .env
|
| 56 |
```
|
| 57 |
|
| 58 |
+
2. **Add your OpenAI API key** to the `.env` file. This is the only secret you need to provide.
|
| 59 |
+
```
|
| 60 |
+
OPENAI_API_KEY="sk-your-openai-key-here"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
```
|
| 62 |
|
| 63 |
3. **Start the system**:
|
|
|
|
| 65 |
make up
|
| 66 |
```
|
| 67 |
|
| 68 |
+
4. **Seed the databases and ingest schema**:
|
| 69 |
```bash
|
| 70 |
+
make seed-db
|
| 71 |
+
make ingest
|
| 72 |
```
|
| 73 |
|
| 74 |
5. **Open the interface**:
|
| 75 |
+
- Streamlit UI: http://localhost:8501
|
| 76 |
- Neo4j Browser: http://localhost:7474 (neo4j/password)
|
| 77 |
|
| 78 |
## Usage
|
| 79 |
+
Once the system is running, open the Streamlit UI and ask a question about the life sciences data, for example:
|
| 80 |
+
- "What are the names of the trials and their primary purpose for studies on 'Cancer'?"
|
| 81 |
+
- "Find all drugs with 'Aspirin' in their name."
|
| 82 |
+
- "Show me lab results for patient '123'."
|
| 83 |
+
|
| 84 |
+
The agent will then:
|
| 85 |
+
1. Use the `SchemaSearchTool` to find relevant tables.
|
| 86 |
+
2. Use the `JoinPathFinderTool` to determine how to join them.
|
| 87 |
+
3. Construct a SQL query.
|
| 88 |
+
4. Execute the query using the `QueryExecutorTool`.
|
| 89 |
+
5. Return the final answer to the UI.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
## Development
|
| 92 |
|
| 93 |
+
### Running the Agent Manually
|
| 94 |
+
To test the agent's logic directly without the full Docker stack, you can run it from your terminal.
|
| 95 |
+
|
| 96 |
+
1. **Set up the environment**:
|
| 97 |
+
Make sure the MCP and Neo4j services are running (`make up`).
|
| 98 |
+
Create a Python virtual environment and install dependencies:
|
| 99 |
+
```bash
|
| 100 |
+
python -m venv venv
|
| 101 |
+
source venv/bin/activate
|
| 102 |
+
pip install -r agent/requirements.txt
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
2. **Set your API key**:
|
| 106 |
+
```bash
|
| 107 |
+
export OPENAI_API_KEY="sk-your-openai-key-here"
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
3. **Run the agent**:
|
| 111 |
+
```bash
|
| 112 |
+
python agent/main.py
|
| 113 |
+
```
|
| 114 |
+
The agent will run with the hardcoded example question and print the execution trace and final answer to your console.
|
| 115 |
+
|
| 116 |
### File Structure
|
| 117 |
```
|
| 118 |
+
├── agent/ # The LangChain agent and its tools
|
| 119 |
+
├── streamlit/ # The Streamlit conversational UI
|
| 120 |
+
├── mcp/ # FastAPI server with core logic
|
| 121 |
+
├── neo4j/ # Neo4j configuration and data
|
| 122 |
+
├── data/ # SQLite databases
|
| 123 |
+
├── ops/ # Operational scripts (seeding, ingestion, etc.)
|
| 124 |
├── docker-compose.yml
|
| 125 |
├── Makefile
|
| 126 |
└── README.md
|
| 127 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent/main.py
CHANGED
|
@@ -1,435 +1,134 @@
|
|
| 1 |
import os
|
| 2 |
-
import time
|
| 3 |
-
import json
|
| 4 |
-
import requests
|
| 5 |
-
import signal
|
| 6 |
import sys
|
| 7 |
-
|
| 8 |
-
import
|
| 9 |
-
from
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
# Configure LLM
|
| 16 |
-
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4")
|
| 17 |
-
LLM_API_KEY = os.getenv("LLM_API_KEY")
|
| 18 |
-
|
| 19 |
-
if "gpt" in LLM_MODEL:
|
| 20 |
-
openai.api_key = LLM_API_KEY
|
| 21 |
-
llm_client = None
|
| 22 |
-
else:
|
| 23 |
-
llm_client = Anthropic(api_key=LLM_API_KEY)
|
| 24 |
-
|
| 25 |
-
# Defining the agents
|
| 26 |
-
## Data Procurement Agent
|
| 27 |
-
## Graph Analysis Agent
|
| 28 |
-
## x Agent etc...
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
DataProcurementAgent(),
|
| 35 |
-
GraphAnalysisAgent(),
|
| 36 |
-
xAgent()
|
| 37 |
-
],
|
| 38 |
-
plan_type="full",
|
| 39 |
-
plan_output_path=Path("output/execution_plan.md"),
|
| 40 |
-
)
|
| 41 |
|
|
|
|
| 42 |
|
| 43 |
-
#
|
| 44 |
-
|
|
|
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
print(f"\n[{datetime.now()}] Interrupt received, will stop after current instruction")
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
headers={"X-API-Key": API_KEY, "Content-Type": "application/json"},
|
| 58 |
-
json={"tool": tool, "params": params or {}}
|
| 59 |
-
)
|
| 60 |
-
return response.json()
|
| 61 |
|
| 62 |
-
def
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
response = openai.ChatCompletion.create(
|
| 66 |
-
model=LLM_MODEL,
|
| 67 |
-
messages=[
|
| 68 |
-
{"role": "system", "content": "You are a SQL expert. Generate only valid PostgreSQL queries."},
|
| 69 |
-
{"role": "user", "content": prompt}
|
| 70 |
-
],
|
| 71 |
-
temperature=0,
|
| 72 |
-
max_tokens=500
|
| 73 |
-
)
|
| 74 |
-
return response.choices[0].message.content
|
| 75 |
-
else:
|
| 76 |
-
response = llm_client.messages.create(
|
| 77 |
-
model=LLM_MODEL,
|
| 78 |
-
max_tokens=500,
|
| 79 |
-
temperature=0,
|
| 80 |
-
messages=[{"role": "user", "content": prompt}]
|
| 81 |
-
)
|
| 82 |
-
return response.content[0].text
|
| 83 |
|
| 84 |
-
|
| 85 |
-
"""Discover PostgreSQL schema and store in Neo4j"""
|
| 86 |
-
print(f"[{datetime.now()}] Discovering PostgreSQL schema...")
|
| 87 |
-
|
| 88 |
-
# Call MCP to discover schema
|
| 89 |
-
schema_result = call_mcp("discover_postgres_schema")
|
| 90 |
-
|
| 91 |
-
if "error" in schema_result:
|
| 92 |
-
return {"status": "failed", "error": schema_result["error"]}
|
| 93 |
-
|
| 94 |
-
schema = schema_result["schema"]
|
| 95 |
-
|
| 96 |
-
# Create SourceSystem node
|
| 97 |
-
call_mcp("write_graph", {
|
| 98 |
-
"action": "create_node",
|
| 99 |
-
"label": "SourceSystem",
|
| 100 |
-
"properties": {
|
| 101 |
-
"id": "postgres-main",
|
| 102 |
-
"name": "Main PostgreSQL Database",
|
| 103 |
-
"type": "postgresql",
|
| 104 |
-
"discovered_at": datetime.now().isoformat()
|
| 105 |
-
}
|
| 106 |
-
})
|
| 107 |
-
|
| 108 |
-
# For each table, create nodes
|
| 109 |
-
for table_name, columns in schema.items():
|
| 110 |
-
# Create Table node
|
| 111 |
-
table_result = call_mcp("write_graph", {
|
| 112 |
-
"action": "create_node",
|
| 113 |
-
"label": "Table",
|
| 114 |
-
"properties": {
|
| 115 |
-
"name": table_name,
|
| 116 |
-
"schema": "public",
|
| 117 |
-
"column_count": len(columns)
|
| 118 |
-
}
|
| 119 |
-
})
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
MERGE (s)-[:HAS_TABLE]->(t)
|
| 127 |
-
""",
|
| 128 |
-
"parameters": {"table_name": table_name}
|
| 129 |
-
})
|
| 130 |
-
|
| 131 |
-
# Create Column nodes
|
| 132 |
-
for col in columns:
|
| 133 |
-
col_result = call_mcp("write_graph", {
|
| 134 |
-
"action": "create_node",
|
| 135 |
-
"label": "Column",
|
| 136 |
-
"properties": {
|
| 137 |
-
"name": col['column_name'],
|
| 138 |
-
"data_type": col['data_type'],
|
| 139 |
-
"nullable": col['is_nullable'] == 'YES',
|
| 140 |
-
"table_name": table_name
|
| 141 |
-
}
|
| 142 |
-
})
|
| 143 |
-
|
| 144 |
-
# Link Column to Table
|
| 145 |
-
call_mcp("query_graph", {
|
| 146 |
-
"query": """
|
| 147 |
-
MATCH (t:Table {name: $table_name}),
|
| 148 |
-
(c:Column {name: $col_name, table_name: $table_name})
|
| 149 |
-
MERGE (t)-[:HAS_COLUMN]->(c)
|
| 150 |
-
""",
|
| 151 |
-
"parameters": {
|
| 152 |
-
"table_name": table_name,
|
| 153 |
-
"col_name": col['column_name']
|
| 154 |
-
}
|
| 155 |
-
})
|
| 156 |
-
|
| 157 |
-
# Generate sample queries
|
| 158 |
-
for table_name in schema.keys():
|
| 159 |
-
sample_queries = [
|
| 160 |
-
f"SELECT * FROM {table_name} LIMIT 10",
|
| 161 |
-
f"SELECT COUNT(*) FROM {table_name}",
|
| 162 |
-
f"SELECT * FROM {table_name} WHERE id = 1"
|
| 163 |
]
|
| 164 |
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
"action": "create_node",
|
| 168 |
-
"label": "QueryTemplate",
|
| 169 |
-
"properties": {
|
| 170 |
-
"id": f"template-{table_name}-{idx}",
|
| 171 |
-
"table_name": table_name,
|
| 172 |
-
"query": query,
|
| 173 |
-
"description": f"Sample query {idx+1} for {table_name}"
|
| 174 |
-
}
|
| 175 |
-
})
|
| 176 |
-
|
| 177 |
-
return {
|
| 178 |
-
"status": "success",
|
| 179 |
-
"tables_discovered": len(schema),
|
| 180 |
-
"columns_discovered": sum(len(cols) for cols in schema.values())
|
| 181 |
-
}
|
| 182 |
-
|
| 183 |
-
def handle_generate_sql(instruction):
|
| 184 |
-
"""Generate SQL from natural language using LLM"""
|
| 185 |
-
print(f"[{datetime.now()}] Generating SQL from natural language...")
|
| 186 |
-
|
| 187 |
-
# Get the user question from instruction parameters
|
| 188 |
-
params = json.loads(instruction.get('parameters', '{}'))
|
| 189 |
-
user_question = params.get('question', 'Show all data')
|
| 190 |
-
|
| 191 |
-
# Fetch schema from Neo4j
|
| 192 |
-
schema_result = call_mcp("query_graph", {
|
| 193 |
-
"query": """
|
| 194 |
-
MATCH (t:Table)-[:HAS_COLUMN]->(c:Column)
|
| 195 |
-
RETURN t.name as table_name,
|
| 196 |
-
collect({
|
| 197 |
-
name: c.name,
|
| 198 |
-
type: c.data_type,
|
| 199 |
-
nullable: c.nullable
|
| 200 |
-
}) as columns
|
| 201 |
-
"""
|
| 202 |
-
})
|
| 203 |
-
|
| 204 |
-
# Format schema for LLM
|
| 205 |
-
schema_text = "PostgreSQL Schema:\n"
|
| 206 |
-
for record in schema_result['data']:
|
| 207 |
-
table = record['table_name']
|
| 208 |
-
columns = record['columns']
|
| 209 |
-
schema_text += f"\nTable: {table}\n"
|
| 210 |
-
for col in columns:
|
| 211 |
-
nullable = "NULL" if col['nullable'] else "NOT NULL"
|
| 212 |
-
schema_text += f" - {col['name']}: {col['type']} {nullable}\n"
|
| 213 |
-
|
| 214 |
-
# Create prompt
|
| 215 |
-
prompt = f"""Given this PostgreSQL schema:
|
| 216 |
-
|
| 217 |
-
{schema_text}
|
| 218 |
-
|
| 219 |
-
Generate a SQL query for this question: {user_question}
|
| 220 |
-
|
| 221 |
-
Return ONLY the SQL query, no explanations or markdown."""
|
| 222 |
-
|
| 223 |
-
try:
|
| 224 |
-
# Get SQL from LLM
|
| 225 |
-
generated_sql = get_llm_response(prompt)
|
| 226 |
-
|
| 227 |
-
# Clean up the SQL (remove markdown if present)
|
| 228 |
-
generated_sql = generated_sql.strip()
|
| 229 |
-
if generated_sql.startswith("```"):
|
| 230 |
-
generated_sql = generated_sql.split("```")[1]
|
| 231 |
-
if generated_sql.startswith("sql"):
|
| 232 |
-
generated_sql = generated_sql[3:]
|
| 233 |
-
generated_sql = generated_sql.strip()
|
| 234 |
-
|
| 235 |
-
print(f"[{datetime.now()}] Generated SQL: {generated_sql}")
|
| 236 |
|
| 237 |
-
#
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
"
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
if workflow_id and check_workflow_stop(workflow_id):
|
| 310 |
-
print(f"[{datetime.now()}] Workflow stopped during pause")
|
| 311 |
-
return False
|
| 312 |
-
|
| 313 |
-
# Check for global interrupt
|
| 314 |
-
if interrupted:
|
| 315 |
-
print(f"[{datetime.now()}] Interrupted during pause")
|
| 316 |
-
return False
|
| 317 |
-
|
| 318 |
-
# Show progress
|
| 319 |
-
remaining = duration - elapsed
|
| 320 |
-
if remaining > 0 and elapsed % 30 == 0: # Update every 30 seconds
|
| 321 |
-
print(f"[{datetime.now()}] Pause remaining: {remaining} seconds")
|
| 322 |
-
|
| 323 |
-
# Log pause end
|
| 324 |
-
call_mcp("write_graph", {
|
| 325 |
-
"action": "create_node",
|
| 326 |
-
"label": "Log",
|
| 327 |
-
"properties": {
|
| 328 |
-
"type": "pause_completed",
|
| 329 |
-
"instruction_id": instruction_id,
|
| 330 |
-
"timestamp": datetime.now().isoformat()
|
| 331 |
-
}
|
| 332 |
-
})
|
| 333 |
-
|
| 334 |
-
print(f"[{datetime.now()}] Pause complete, continuing execution")
|
| 335 |
-
return True
|
| 336 |
-
|
| 337 |
def main():
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
while not interrupted:
|
| 342 |
-
try:
|
| 343 |
-
result = call_mcp("get_next_instruction")
|
| 344 |
-
instruction = result.get("instruction")
|
| 345 |
-
|
| 346 |
-
if instruction:
|
| 347 |
-
print(f"[{datetime.now()}] Found instruction: {instruction['id']}, type: {instruction['type']}")
|
| 348 |
-
|
| 349 |
-
# Get workflow ID
|
| 350 |
-
workflow_result = call_mcp("query_graph", {
|
| 351 |
-
"query": """
|
| 352 |
-
MATCH (w:Workflow)-[:HAS_INSTRUCTION]->(i:Instruction {id: $id})
|
| 353 |
-
RETURN w.id as workflow_id
|
| 354 |
-
""",
|
| 355 |
-
"parameters": {"id": instruction['id']}
|
| 356 |
-
})
|
| 357 |
-
workflow_id = workflow_result['data'][0]['workflow_id'] if workflow_result['data'] else None
|
| 358 |
-
|
| 359 |
-
# PAUSE BEFORE EXECUTION
|
| 360 |
-
pause_duration = instruction.get('pause_duration', 300)
|
| 361 |
-
if pause_duration > 0:
|
| 362 |
-
if not pause_with_interrupt(pause_duration, instruction['id'], workflow_id):
|
| 363 |
-
print(f"[{datetime.now()}] Execution cancelled during pause")
|
| 364 |
-
continue
|
| 365 |
-
|
| 366 |
-
# Re-fetch instruction to get any edits made during pause
|
| 367 |
-
refetch_result = call_mcp("query_graph", {
|
| 368 |
-
"query": "MATCH (i:Instruction {id: $id}) RETURN i",
|
| 369 |
-
"parameters": {"id": instruction['id']}
|
| 370 |
-
})
|
| 371 |
-
|
| 372 |
-
if refetch_result['data']:
|
| 373 |
-
instruction = refetch_result['data'][0]['i']
|
| 374 |
-
print(f"[{datetime.now()}] Re-fetched instruction after pause, parameters: {instruction.get('parameters')}")
|
| 375 |
-
|
| 376 |
-
# Update status to executing
|
| 377 |
-
call_mcp("query_graph", {
|
| 378 |
-
"query": "MATCH (i:Instruction {id: $id}) SET i.status = 'executing'",
|
| 379 |
-
"parameters": {"id": instruction['id']}
|
| 380 |
-
})
|
| 381 |
-
|
| 382 |
-
# Execute based on type
|
| 383 |
-
if instruction['type'] == 'discover_schema':
|
| 384 |
-
exec_result = handle_discover_schema(instruction)
|
| 385 |
-
elif instruction['type'] == 'generate_sql':
|
| 386 |
-
exec_result = handle_generate_sql(instruction)
|
| 387 |
-
else:
|
| 388 |
-
exec_result = {"status": "success", "result": "Reviewed"}
|
| 389 |
-
|
| 390 |
-
# Store execution result
|
| 391 |
-
exec_node = call_mcp("write_graph", {
|
| 392 |
-
"action": "create_node",
|
| 393 |
-
"label": "Execution",
|
| 394 |
-
"properties": {
|
| 395 |
-
"id": f"exec-{instruction['id']}-{int(time.time())}",
|
| 396 |
-
"started_at": datetime.now().isoformat(),
|
| 397 |
-
"completed_at": datetime.now().isoformat(),
|
| 398 |
-
"result": json.dumps(exec_result)
|
| 399 |
-
}
|
| 400 |
-
})
|
| 401 |
-
|
| 402 |
-
# Link execution
|
| 403 |
-
call_mcp("query_graph", {
|
| 404 |
-
"query": """
|
| 405 |
-
MATCH (i:Instruction {id: $iid}), (e:Execution {id: $eid})
|
| 406 |
-
CREATE (i)-[:EXECUTED_AS]->(e)
|
| 407 |
-
""",
|
| 408 |
-
"parameters": {
|
| 409 |
-
"iid": instruction['id'],
|
| 410 |
-
"eid": exec_node['created']['id']
|
| 411 |
-
}
|
| 412 |
-
})
|
| 413 |
-
|
| 414 |
-
# Update status
|
| 415 |
-
final_status = 'complete' if exec_result.get('status') == 'success' else 'failed'
|
| 416 |
-
call_mcp("query_graph", {
|
| 417 |
-
"query": "MATCH (i:Instruction {id: $id}) SET i.status = $status",
|
| 418 |
-
"parameters": {"id": instruction['id'], "status": final_status}
|
| 419 |
-
})
|
| 420 |
-
|
| 421 |
-
print(f"[{datetime.now()}] Completed instruction: {instruction['id']}")
|
| 422 |
-
|
| 423 |
-
else:
|
| 424 |
-
print(f"[{datetime.now()}] No pending instructions")
|
| 425 |
-
|
| 426 |
-
time.sleep(POLL_INTERVAL)
|
| 427 |
-
|
| 428 |
-
except Exception as e:
|
| 429 |
-
print(f"[{datetime.now()}] Error: {e}")
|
| 430 |
-
time.sleep(POLL_INTERVAL)
|
| 431 |
-
|
| 432 |
-
print(f"[{datetime.now()}] Agent shutting down")
|
| 433 |
|
| 434 |
if __name__ == "__main__":
|
| 435 |
main()
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import sys
|
| 3 |
+
import logging
|
| 4 |
+
import json
|
| 5 |
+
from typing import Annotated, List, TypedDict
|
| 6 |
+
from fastapi import FastAPI
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
import uvicorn
|
| 9 |
+
from fastapi.responses import StreamingResponse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
from langchain_core.messages import BaseMessage, ToolMessage, AIMessage
|
| 12 |
+
from langchain_openai import OpenAI
|
| 13 |
+
from langgraph.graph import StateGraph, START, END
|
| 14 |
+
from langgraph.prebuilt import ToolNode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
from agent.tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool
|
| 17 |
|
| 18 |
+
# --- Configuration & Logging ---
|
| 19 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
+
MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
|
| 23 |
+
API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
|
| 24 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
|
|
| 25 |
|
| 26 |
+
# --- Agent State Definition ---
import operator

class AgentState(TypedDict):
    """State carried between LangGraph nodes.

    `operator.add` is attached as a reducer so that the messages a node
    returns are APPENDED to the conversation history rather than replacing
    it. Without a reducer, each `{"messages": [response]}` update would
    overwrite the list and the llm <-> tools loop would lose all context.
    """
    # Full conversation history: user, AI (with tool_calls) and tool messages.
    messages: Annotated[List[BaseMessage], operator.add]
|
| 29 |
|
| 30 |
+
# --- Agent Initialization ---
class GraphRAGAgent:
    """The core agent for handling GraphRAG queries using LangGraph.

    Builds a two-node graph (llm <-> tools): the LLM decides which
    MCP-backed tool to call next; each step is streamed to the caller
    as one NDJSON record.
    """

    def __init__(self):
        """Construct the LLM, tools, and compiled LangGraph.

        Raises:
            ValueError: if OPENAI_API_KEY is not set in the environment.
        """
        if not OPENAI_API_KEY:
            raise ValueError("OPENAI_API_KEY environment variable not set.")

        # bind_tools / tool-calling requires a *chat* model. The plain
        # `OpenAI` completions wrapper does not support tool calls or
        # message lists, so ChatOpenAI is used here.
        from langchain_openai import ChatOpenAI
        llm = ChatOpenAI(api_key=OPENAI_API_KEY, temperature=0, max_retries=1)

        mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
        tools = [
            SchemaSearchTool(mcp_client=mcp_client),
            JoinPathFinderTool(mcp_client=mcp_client),
            QueryExecutorTool(mcp_client=mcp_client),
        ]

        self.llm_with_tools = llm.bind_tools(tools)
        self.tool_node = ToolNode(tools)

        # Define the agent graph: START -> llm -> (tools -> llm)* -> END
        graph = StateGraph(AgentState)
        graph.add_node("llm", self.call_llm)
        graph.add_node("tools", self.tool_node)

        graph.add_edge(START, "llm")
        graph.add_conditional_edges("llm", self.should_call_tools)
        graph.add_edge("tools", "llm")

        self.graph = graph.compile()

    def should_call_tools(self, state: AgentState) -> str:
        """Determines whether to call tools or end the execution."""
        last_message = state["messages"][-1]
        # Tool/user messages have no `tool_calls` attribute at all, so
        # guard with getattr instead of attribute access.
        if not getattr(last_message, "tool_calls", None):
            return END
        return "tools"

    def call_llm(self, state: AgentState) -> dict:
        """Calls the LLM with the current state to decide the next action."""
        response = self.llm_with_tools.invoke(state["messages"])
        return {"messages": [response]}

    async def stream_query(self, question: str):
        """Processes a question and streams the intermediate steps.

        Yields one JSON record per step, terminated with real newlines
        (the previous version emitted the two characters '\\' + 'n',
        which broke NDJSON parsing on the consumer side).
        """
        inputs = {"messages": [("user", question)]}
        async for event in self.graph.astream(inputs, stream_mode="values"):
            last_message = event["messages"][-1]
            if isinstance(last_message, AIMessage) and last_message.tool_calls:
                # Agent is thinking and calling a tool
                tool_call = last_message.tool_calls[0]
                yield json.dumps({
                    "type": "thought",
                    "content": f"🤔 Calling tool `{tool_call['name']}` with args: {tool_call['args']}"
                }) + "\n\n"
            elif isinstance(last_message, ToolMessage):
                # A tool has returned its result
                yield json.dumps({
                    "type": "observation",
                    "content": f"🛠️ Tool `{last_message.name}` returned:\n\n```\n{last_message.content}\n```"
                }) + "\n\n"
            elif isinstance(last_message, AIMessage):
                # This is the final answer
                yield json.dumps({"type": "final_answer", "content": last_message.content}) + "\n\n"
|
| 94 |
+
|
| 95 |
+
# --- FastAPI Application ---
app = FastAPI(title="GraphRAG Agent Server")
# Global agent instance, populated by the startup hook. Remains None when
# initialization fails (e.g. missing OPENAI_API_KEY); /query checks this.
agent = None

class QueryRequest(BaseModel):
    """Request body for POST /query."""
    # Natural-language question to run through the GraphRAG agent.
    question: str
|
| 101 |
+
|
| 102 |
+
@app.on_event("startup")
def startup_event():
    """Build the global GraphRAGAgent when the server boots.

    If construction raises ValueError (e.g. no OPENAI_API_KEY), the error
    is logged and `agent` stays None so endpoints can report the failure.
    """
    global agent
    try:
        instance = GraphRAGAgent()
    except ValueError as e:
        logger.error(f"Agent initialization failed: {e}")
    else:
        agent = instance
        logger.info("GraphRAGAgent initialized successfully.")
|
| 111 |
+
|
| 112 |
+
@app.post("/query")
async def execute_query(request: QueryRequest) -> StreamingResponse:
    """Endpoint to receive questions and stream the agent's NDJSON response.

    When the agent failed to initialize at startup, a 503 response with a
    single NDJSON error record is returned instead of a silent 200.
    """
    if not agent:
        async def error_stream():
            # Newline-terminate so the payload is a valid NDJSON record.
            yield json.dumps({"error": "Agent is not initialized. Check server logs."}) + "\n"
        return StreamingResponse(
            error_stream(),
            media_type="application/x-ndjson",
            status_code=503,
        )

    return StreamingResponse(agent.stream_query(request.question), media_type="application/x-ndjson")
|
| 121 |
+
|
| 122 |
+
@app.get("/health")
def health_check():
    """Liveness probe: reports whether the global agent was constructed."""
    return {
        "status": "ok",
        "agent_initialized": agent is not None,
    }
|
| 126 |
+
|
| 127 |
+
# --- Main Execution ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
def main():
    """Main entry point: serve the FastAPI app with uvicorn on port 8001."""
    logger.info("Starting agent server...")
    host, port = "0.0.0.0", 8001
    uvicorn.run(app, host=host, port=port)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
# Allow running directly as a script (`python main.py`), which is how the
# container's `command: python -u main.py` invokes this module.
if __name__ == "__main__":
    main()
|
agent/requirements.txt
CHANGED
|
@@ -1,4 +1,8 @@
|
|
| 1 |
-
requests
|
| 2 |
-
python-dotenv
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
requests
|
| 2 |
+
python-dotenv
|
| 3 |
+
langchain
|
| 4 |
+
langchain-openai
|
| 5 |
+
pydantic
|
| 6 |
+
fastapi
|
| 7 |
+
uvicorn[standard]
|
| 8 |
+
langgraph
|
agent/tools.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import requests
|
| 3 |
+
import json
|
| 4 |
+
from typing import Dict, Any, List, Optional
|
| 5 |
+
from langchain.tools import BaseTool
|
| 6 |
+
from pydantic import Field
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class MCPClient:
    """Client for making authenticated REST API calls to the MCP server."""

    def __init__(self, mcp_url: str, api_key: str, timeout: float = 30.0):
        """
        Args:
            mcp_url: Base URL of the MCP server (no trailing slash).
            api_key: Bearer token sent with every request.
            timeout: Per-request timeout in seconds; without one, a dead
                MCP server would hang the agent forever.
        """
        self.mcp_url = mcp_url
        self.timeout = timeout
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

    def post(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Send a POST request to a given MCP endpoint.

        Returns the decoded JSON body on success, or a
        {"status": "error", "message": ...} dict on failure (never raises).
        """
        url = f"{self.mcp_url}/{endpoint}"
        try:
            # `json=` lets requests serialize the body; explicit timeout
            # prevents indefinitely hung calls.
            response = requests.post(url, headers=self.headers, json=data, timeout=self.timeout)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as http_err:
            # `response` is guaranteed bound here: raise_for_status only
            # raises after a response object exists.
            logger.error(f"HTTP error occurred: {http_err} - {response.text}")
            return {"status": "error", "message": f"HTTP error: {response.status_code} {response.reason}"}
        except requests.exceptions.RequestException as req_err:
            logger.error(f"Request error occurred: {req_err}")
            return {"status": "error", "message": f"Request failed: {req_err}"}
        except json.JSONDecodeError:
            logger.error("Failed to decode JSON response.")
            return {"status": "error", "message": "Invalid JSON response from server."}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SchemaSearchTool(BaseTool):
    """LangChain tool for searching database schemas."""

    name: str = "schema_search"
    description: str = """
    Search for relevant database schemas based on a natural language query.
    Use this when you need to find which tables/columns are relevant to a user's question.
    Input should be a descriptive query like 'patient information' or 'drug trials'.
    """
    mcp_client: MCPClient

    def _run(self, query: str) -> str:
        """Execute schema search via the MCP discovery endpoint.

        Returns a human-readable listing of matching database.table.column
        entries, or an error/empty message.
        """
        response = self.mcp_client.post("discovery/get_relevant_schemas", {"query": query})

        if response.get("status") != "success":
            return f"Error searching schemas: {response.get('message', 'Unknown error')}"

        schemas = response.get("schemas", [])
        if not schemas:
            return "No relevant schemas found."

        # Build the listing with real newlines (the previous version emitted
        # the literal two-character sequence '\' + 'n' into the output).
        lines = ["Found relevant schemas:"]
        for schema in schemas:
            # `type` is expected to be a list of node labels from the graph
            # keyword search; `or ['Unknown']` also covers None/empty.
            node_type = schema.get('type') or ['Unknown']
            lines.append(
                f"- {schema.get('database', 'Unknown')}.{schema.get('table', 'Unknown')}"
                f".{schema.get('name', 'Unknown')} ({node_type[0]})"
            )
        return "\n".join(lines) + "\n"

    async def _arun(self, query: str) -> str:
        raise NotImplementedError("SchemaSearchTool does not support async")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class JoinPathFinderTool(BaseTool):
    """LangChain tool for finding join paths between tables."""

    name: str = "find_join_path"
    description: str = """
    Find how to join two tables together using foreign key relationships.
    Use this when you need to query across multiple tables.
    Input should be two table names separated by a comma, like 'patients,studies'.
    """
    mcp_client: MCPClient

    def _run(self, table_names: str) -> str:
        """Resolve a join path between two comma-separated table names."""
        try:
            parts = [name.strip() for name in table_names.split(',')]
            if len(parts) != 2:
                return "Please provide exactly two table names separated by a comma."

            first, second = parts
            result = self.mcp_client.post(
                "graph/find_join_path",
                {"table1": first, "table2": second}
            )

            if result.get("status") != "success":
                return f"Error finding join path: {result.get('message', 'Unknown error')}"
            return f"Join path: {result.get('path', 'No path found')}"
        except Exception as e:
            return f"Failed to find join path: {str(e)}"

    async def _arun(self, table_names: str) -> str:
        raise NotImplementedError("JoinPathFinderTool does not support async")
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class QueryExecutorTool(BaseTool):
    """LangChain tool for executing SQL queries."""

    name: str = "execute_query"
    description: str = """
    Execute a SQL query against the databases and return results.
    Use this after you have a valid SQL query.
    Input should be a valid SQL query string.
    """
    mcp_client: MCPClient

    def _run(self, sql: str) -> str:
        """Execute a query via the MCP intelligence endpoint.

        Formats up to 10 rows as a pipe-separated text table.
        """
        try:
            response = self.mcp_client.post(
                "intelligence/execute_query",
                {"sql": sql}
            )

            if response.get("status") != "success":
                return f"Error executing query: {response.get('message', 'Unknown error')}"

            results = response.get("results", [])
            if not results:
                return "Query executed successfully but returned no results."

            # Format results as a readable table, with REAL newlines (the
            # previous version emitted literal backslash-n characters).
            headers = list(results[0].keys())
            header_line = " | ".join(headers)
            lines = [
                f"Query returned {len(results)} rows:",
                header_line,
                "-" * len(header_line),
            ]

            for row in results[:10]:  # Limit display to first 10 rows
                lines.append(" | ".join(str(row.get(h, "")) for h in headers))

            if len(results) > 10:
                lines.append(f"... and {len(results) - 10} more rows")

            return "\n".join(lines) + "\n"
        except Exception as e:
            return f"Failed to execute query: {str(e)}"

    async def _arun(self, sql: str) -> str:
        raise NotImplementedError("QueryExecutorTool does not support async")
|
docker-compose.yml
CHANGED
|
@@ -19,25 +19,6 @@ services:
|
|
| 19 |
networks:
|
| 20 |
- agent-network
|
| 21 |
|
| 22 |
-
postgres:
|
| 23 |
-
image: postgres:15
|
| 24 |
-
environment:
|
| 25 |
-
- POSTGRES_DB=testdb
|
| 26 |
-
- POSTGRES_USER=postgres
|
| 27 |
-
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
| 28 |
-
ports:
|
| 29 |
-
- "5432:5432"
|
| 30 |
-
volumes:
|
| 31 |
-
- ./postgres/data:/var/lib/postgresql/data
|
| 32 |
-
- ./postgres/init.sql:/docker-entrypoint-initdb.d/init.sql
|
| 33 |
-
healthcheck:
|
| 34 |
-
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
| 35 |
-
interval: 10s
|
| 36 |
-
timeout: 5s
|
| 37 |
-
retries: 5
|
| 38 |
-
networks:
|
| 39 |
-
- agent-network
|
| 40 |
-
|
| 41 |
mcp:
|
| 42 |
build: ./mcp
|
| 43 |
ports:
|
|
@@ -45,17 +26,15 @@ services:
|
|
| 45 |
environment:
|
| 46 |
- NEO4J_BOLT_URL=${NEO4J_BOLT_URL}
|
| 47 |
- NEO4J_AUTH=${NEO4J_AUTH}
|
| 48 |
-
- POSTGRES_CONNECTION=${POSTGRES_CONNECTION}
|
| 49 |
- MCP_API_KEYS=${MCP_API_KEYS}
|
| 50 |
- MCP_PORT=${MCP_PORT}
|
| 51 |
depends_on:
|
| 52 |
neo4j:
|
| 53 |
condition: service_healthy
|
| 54 |
-
postgres:
|
| 55 |
-
condition: service_healthy
|
| 56 |
volumes:
|
| 57 |
- ./mcp:/app
|
| 58 |
- ./ops/scripts:/app/ops/scripts
|
|
|
|
| 59 |
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
| 60 |
healthcheck:
|
| 61 |
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
|
@@ -71,33 +50,17 @@ services:
|
|
| 71 |
- MCP_URL=http://mcp:8000/mcp
|
| 72 |
- MCP_API_KEY=dev-key-123
|
| 73 |
- AGENT_POLL_INTERVAL=${AGENT_POLL_INTERVAL}
|
| 74 |
-
-
|
| 75 |
-
- LLM_API_KEY=${LLM_API_KEY}
|
| 76 |
-
- LLM_MODEL=${LLM_MODEL}
|
| 77 |
-
- POSTGRES_CONNECTION=${POSTGRES_CONNECTION}
|
| 78 |
depends_on:
|
| 79 |
mcp:
|
| 80 |
condition: service_healthy
|
| 81 |
volumes:
|
| 82 |
- ./agent:/app
|
|
|
|
| 83 |
command: python -u main.py
|
| 84 |
-
restart:
|
| 85 |
-
networks:
|
| 86 |
-
- agent-network
|
| 87 |
-
|
| 88 |
-
frontend:
|
| 89 |
-
build: ./frontend
|
| 90 |
ports:
|
| 91 |
-
- "
|
| 92 |
-
environment:
|
| 93 |
-
- NEXT_PUBLIC_MCP_URL=http://localhost:8000
|
| 94 |
-
depends_on:
|
| 95 |
-
- mcp
|
| 96 |
-
volumes:
|
| 97 |
-
- ./frontend:/app
|
| 98 |
-
- /app/node_modules
|
| 99 |
-
- /app/.next
|
| 100 |
-
command: npm run dev
|
| 101 |
networks:
|
| 102 |
- agent-network
|
| 103 |
|
|
@@ -106,6 +69,7 @@ services:
|
|
| 106 |
ports:
|
| 107 |
- "8501:8501"
|
| 108 |
environment:
|
|
|
|
| 109 |
- MCP_URL=http://mcp:8000/mcp
|
| 110 |
- MCP_API_KEY=dev-key-123
|
| 111 |
depends_on:
|
|
|
|
| 19 |
networks:
|
| 20 |
- agent-network
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
mcp:
|
| 23 |
build: ./mcp
|
| 24 |
ports:
|
|
|
|
| 26 |
environment:
|
| 27 |
- NEO4J_BOLT_URL=${NEO4J_BOLT_URL}
|
| 28 |
- NEO4J_AUTH=${NEO4J_AUTH}
|
|
|
|
| 29 |
- MCP_API_KEYS=${MCP_API_KEYS}
|
| 30 |
- MCP_PORT=${MCP_PORT}
|
| 31 |
depends_on:
|
| 32 |
neo4j:
|
| 33 |
condition: service_healthy
|
|
|
|
|
|
|
| 34 |
volumes:
|
| 35 |
- ./mcp:/app
|
| 36 |
- ./ops/scripts:/app/ops/scripts
|
| 37 |
+
- ./data:/app/data
|
| 38 |
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
| 39 |
healthcheck:
|
| 40 |
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
|
|
|
| 50 |
- MCP_URL=http://mcp:8000/mcp
|
| 51 |
- MCP_API_KEY=dev-key-123
|
| 52 |
- AGENT_POLL_INTERVAL=${AGENT_POLL_INTERVAL}
|
| 53 |
+
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
|
|
|
|
|
|
|
|
|
| 54 |
depends_on:
|
| 55 |
mcp:
|
| 56 |
condition: service_healthy
|
| 57 |
volumes:
|
| 58 |
- ./agent:/app
|
| 59 |
+
- ./data:/app/data
|
| 60 |
command: python -u main.py
|
| 61 |
+
restart: on-failure
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
ports:
|
| 63 |
+
- "8001:8001"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
networks:
|
| 65 |
- agent-network
|
| 66 |
|
|
|
|
| 69 |
ports:
|
| 70 |
- "8501:8501"
|
| 71 |
environment:
|
| 72 |
+
- AGENT_URL=http://agent:8001/query
|
| 73 |
- MCP_URL=http://mcp:8000/mcp
|
| 74 |
- MCP_API_KEY=dev-key-123
|
| 75 |
depends_on:
|
mcp/__init__.py
ADDED
|
File without changes
|
mcp/core/__init__.py
ADDED
|
File without changes
|
mcp/core/config.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

# --- Neo4j Configuration ---
NEO4J_URI = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")

# --- SQLite Configuration ---
# Directory where the SQLite database files live inside the container.
SQLITE_DATA_DIR = os.getenv("SQLITE_DATA_DIR", "/app/data")

def get_sqlite_connection_string(db_name: str) -> str:
    """
    Generates the SQLAlchemy connection string for a given SQLite database file.
    Assumes the database file is located in the SQLITE_DATA_DIR.
    Example: get_sqlite_connection_string("clinical_trials.db")
             -> "sqlite:////app/data/clinical_trials.db"
    """
    full_path = os.path.join(SQLITE_DATA_DIR, db_name)
    return "sqlite:///" + full_path

# --- Application Settings ---
# Other application-wide settings (API keys, logging levels, ...) would be
# loaded from environment variables here as well.
|
mcp/core/database.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def get_db_engine(connection_string: str) -> Engine | None:
    """
    Creates a SQLAlchemy engine for a given database connection string.

    Args:
        connection_string: The database connection string.

    Returns:
        A SQLAlchemy Engine instance, or None if connection fails.
    """
    try:
        engine = create_engine(connection_string)
        # Open (and immediately release) a connection to verify the target
        # database is actually reachable before handing the engine back.
        with engine.connect():
            logger.info(f"Successfully connected to {engine.url.database}")
        return engine
    except Exception as e:
        logger.error(f"Failed to connect to database: {e}")
        return None
|
mcp/core/discovery.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlalchemy import inspect, text
|
| 2 |
+
from sqlalchemy.engine import Engine
|
| 3 |
+
from typing import Dict, Any, List
|
| 4 |
+
import logging
|
| 5 |
+
import json
|
| 6 |
+
from concurrent.futures import TimeoutError, ThreadPoolExecutor
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
def get_table_schema(inspector, table_name: str) -> Dict[str, Any]:
    """Extracts schema (columns, primary keys, foreign keys) for one table."""
    column_entries = [
        {
            "name": col['name'],
            "type": str(col['type']),
            "nullable": col['nullable'],
            "default": col.get('default'),
        }
        for col in inspector.get_columns(table_name)
    ]
    return {
        "name": table_name,
        "columns": column_entries,
        "primary_keys": inspector.get_pk_constraint(table_name)['constrained_columns'],
        "foreign_keys": inspector.get_foreign_keys(table_name),
    }
|
| 31 |
+
|
| 32 |
+
def get_sample_data(engine: Engine, table_name: str, sample_size: int = 5) -> Dict[str, Any]:
    """Fetches the row count and up to `sample_size` sample rows for a table."""
    info: Dict[str, Any] = {}
    with engine.connect() as connection:
        # Row count; -1 signals "could not be determined".
        try:
            count_result = connection.execute(text(f'SELECT COUNT(*) FROM "{table_name}"'))
            info['row_count'] = count_result.scalar_one()
        except Exception as e:
            logger.warning(f"Could not get row count for table {table_name}: {e}")
            info['row_count'] = -1  # Indicate error or unknown

        # Sample rows, round-tripped through JSON (default=str) so that
        # non-serializable column values degrade to strings gracefully.
        try:
            rows_result = connection.execute(text(f'SELECT * FROM "{table_name}" LIMIT {sample_size}'))
            raw_rows = [dict(row._mapping) for row in rows_result.fetchall()]
            info['sample_rows'] = json.loads(json.dumps(raw_rows, default=str))
        except Exception as e:
            logger.warning(f"Could not get sample rows for table {table_name}: {e}")
            info['sample_rows'] = []

    return info
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def discover_schema(engine: Engine, timeout: int = 30) -> Dict[str, Any] | None:
    """
    Discovers the full schema of a database using SQLAlchemy's inspection API.
    Includes table schemas and sample data.

    Returns None on timeout or error. A timed-out worker thread cannot be
    killed; we simply stop waiting for it and discard its result.
    """
    # NOTE: a `with ThreadPoolExecutor()` block would wait for the worker
    # on exit, which silently defeats the timeout -- shut down manually.
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(_discover_schema_task, engine)
        result = future.result(timeout=timeout)
        executor.shutdown(wait=True)
        return result
    except TimeoutError:
        logger.error(f"Schema discovery for {engine.url.database} timed out after {timeout} seconds.")
        future.cancel()
        executor.shutdown(wait=False)
        return None
    except Exception as e:
        logger.error(f"An unexpected error occurred during schema discovery for {engine.url.database}: {e}")
        executor.shutdown(wait=False)
        return None
|
| 72 |
+
|
| 73 |
+
def _discover_schema_task(engine: Engine) -> Dict[str, Any]:
    """The actual schema discovery logic to be run with a timeout."""
    inspector = inspect(engine)
    discovered: Dict[str, Any] = {
        "database_name": engine.url.database,
        "dialect": engine.dialect.name,
        "tables": []
    }

    for name in inspector.get_table_names():
        try:
            logger.info(f"Discovering schema for table: {name}")
            table_entry = get_table_schema(inspector, name)

            logger.info(f"Collecting sample data for table: {name}")
            table_entry.update(get_sample_data(engine, name))

            discovered["tables"].append(table_entry)
        except Exception as e:
            # Skip broken tables rather than failing the whole discovery.
            logger.error(f"Could not inspect table '{name}': {e}")
            continue

    return discovered
|
mcp/core/graph.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from neo4j import GraphDatabase
|
| 2 |
+
import logging
|
| 3 |
+
import json
|
| 4 |
+
from typing import List, Dict, Any
|
| 5 |
+
from . import config
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
class GraphStore:
|
| 10 |
+
    def __init__(self):
        # Driver is created eagerly; uniqueness constraints are ensured
        # immediately so subsequent imports can rely on MERGE semantics.
        self._driver = GraphDatabase.driver(config.NEO4J_URI, auth=(config.NEO4J_USER, config.NEO4J_PASSWORD))
        self.ensure_constraints()
|
| 13 |
+
|
| 14 |
+
    def close(self):
        """Close the underlying Neo4j driver and release its connections."""
        self._driver.close()
|
| 16 |
+
|
| 17 |
+
def ensure_constraints(self):
|
| 18 |
+
"""Ensure uniqueness constraints are set up in Neo4j."""
|
| 19 |
+
with self._driver.session() as session:
|
| 20 |
+
session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (d:Database) REQUIRE d.name IS UNIQUE")
|
| 21 |
+
session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.unique_name IS UNIQUE")
|
| 22 |
+
session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (c:Column) REQUIRE c.unique_name IS UNIQUE")
|
| 23 |
+
logger.info("Neo4j constraints ensured.")
|
| 24 |
+
|
| 25 |
+
    def import_schema(self, schema_data: dict):
        """
        Imports a discovered database schema into the Neo4j graph.

        Expects the dict shape produced by discovery (`database_name`,
        `tables` with `columns` / `foreign_keys`). Idempotent: all writes
        use MERGE keyed on name / unique_name. FK relationships are created
        in a second pass so every Column node already exists.
        """
        db_name = schema_data['database_name']

        with self._driver.session() as session:
            # Create Database node
            session.run("MERGE (d:Database {name: $db_name})", db_name=db_name)

            for table in schema_data['tables']:
                # unique_name ("db.table") is the MERGE key, so the same
                # table name in different databases stays distinct.
                table_unique_name = f"{db_name}.{table['name']}"
                table_properties = {
                    "name": table['name'],
                    "unique_name": table_unique_name,
                    "row_count": table.get('row_count', -1),
                    # Neo4j properties can't hold nested maps, so sample
                    # rows are stored as a JSON string.
                    "sample_rows": json.dumps(table.get('sample_rows', []))
                }

                # Create Table node and HAS_TABLE relationship
                session.run(
                    """
                    MATCH (d:Database {name: $db_name})
                    MERGE (t:Table {unique_name: $unique_name})
                    ON CREATE SET t += $props
                    ON MATCH SET t += $props
                    MERGE (d)-[:HAS_TABLE]->(t)
                    """,
                    db_name=db_name,
                    unique_name=table_unique_name,
                    props=table_properties
                )

                for column in table['columns']:
                    column_unique_name = f"{table_unique_name}.{column['name']}"
                    column_properties = {
                        "name": column['name'],
                        "unique_name": column_unique_name,
                        "type": column['type'],
                        "nullable": column['nullable'],
                        "default": str(column.get('default'))  # Ensure default is string
                    }

                    # Create Column node and HAS_COLUMN relationship
                    session.run(
                        """
                        MATCH (t:Table {unique_name: $table_unique_name})
                        MERGE (c:Column {unique_name: $column_unique_name})
                        ON CREATE SET c += $props
                        ON MATCH SET c += $props
                        MERGE (t)-[:HAS_COLUMN]->(c)
                        """,
                        table_unique_name=table_unique_name,
                        column_unique_name=column_unique_name,
                        props=column_properties
                    )

            # After all tables and columns are created, create foreign key relationships
            for table in schema_data['tables']:
                table_unique_name = f"{db_name}.{table['name']}"
                if table.get('foreign_keys'):
                    for fk in table['foreign_keys']:
                        constrained_columns = fk['constrained_columns']
                        referred_table = fk['referred_table']
                        referred_columns = fk['referred_columns']

                        # NOTE(review): assumes the referred table lives in
                        # the SAME database -- confirm no cross-db FKs exist.
                        referred_table_unique_name = f"{db_name}.{referred_table}"

                        # Composite FKs: pair constrained/referred columns
                        # positionally.
                        for i, col_name in enumerate(constrained_columns):
                            from_col_unique_name = f"{table_unique_name}.{col_name}"
                            to_col_unique_name = f"{referred_table_unique_name}.{referred_columns[i]}"

                            session.run(
                                """
                                MATCH (from_col:Column {unique_name: $from_col})
                                MATCH (to_col:Column {unique_name: $to_col})
                                MERGE (from_col)-[:REFERENCES]->(to_col)
                                """,
                                from_col=from_col_unique_name,
                                to_col=to_col_unique_name
                            )
        logger.info(f"Successfully imported schema for database: {db_name}")
|
| 107 |
+
|
| 108 |
+
def find_shortest_path(self, start_node_name: str, end_node_name: str) -> List[Dict[str, Any]]:
    """Find the shortest path between two schema nodes (Tables or Columns).

    Fix: the previous query called ``apoc.path.shortestPath``, which is not a
    procedure APOC provides, so every call failed. Cypher's built-in
    ``shortestPath()`` does the same job with no plugin dependency.

    Args:
        start_node_name: ``unique_name`` property of the start node.
        end_node_name: ``unique_name`` property of the end node.

    Returns:
        A list of raw Neo4j path objects; empty if no path exists within
        10 hops. Parsing them into a user-friendly format is still TODO.
    """
    query = """
    MATCH (start {unique_name: $start_name}), (end {unique_name: $end_name})
    MATCH path = shortestPath((start)-[:REFERENCES|HAS_COLUMN|HAS_TABLE*..10]-(end))
    RETURN path
    """
    with self._driver.session() as session:
        result = session.run(query, start_name=start_node_name, end_name=end_node_name)
        # The result is complex; returning the raw path objects for now.
        return [record["path"] for record in result]
|
| 123 |
+
|
| 124 |
+
def keyword_search(self, keyword: str) -> List[Dict[str, Any]]:
    """Search Table and Column nodes whose name contains *keyword*.

    Each hit is returned with its database and table context so callers can
    build a fully qualified name; at most 25 rows are returned.
    """
    query = """
    MATCH (n)
    WHERE (n:Table OR n:Column) AND n.name CONTAINS $keyword
    OPTIONAL MATCH (d:Database)-[:HAS_TABLE]->(t:Table)-[:HAS_COLUMN]->(n) WHERE n:Column
    OPTIONAL MATCH (d2:Database)-[:HAS_TABLE]->(n) WHERE n:Table
    WITH COALESCE(d, d2) AS db, COALESCE(t, n) AS tbl, n AS item
    RETURN db.name AS database, tbl.name AS table, item.name AS name, labels(item) AS type
    LIMIT 25
    """
    matches = []
    with self._driver.session() as session:
        for record in session.run(query, keyword=keyword):
            matches.append(record.data())
    return matches
|
| 141 |
+
|
| 142 |
+
def get_table_row_count(self, table_unique_name: str) -> int:
    """Return the stored row count for a table, or -1 when the table is unknown."""
    query = """
    MATCH (t:Table {unique_name: $unique_name})
    RETURN t.row_count AS row_count
    """
    with self._driver.session() as session:
        record = session.run(query, unique_name=table_unique_name).single()
    if record is None:
        return -1
    return record['row_count']
|
mcp/core/intelligence.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlparse
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List, Dict, Any
|
| 4 |
+
|
| 5 |
+
from .graph import GraphStore
|
| 6 |
+
from .database import get_db_engine
|
| 7 |
+
from . import config
|
| 8 |
+
from sqlalchemy import text
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
# Constants for query cost estimation
ROW_EXECUTION_THRESHOLD = 100  # Execute queries expected to return fewer rows than this
JOIN_CARDINALITY_ESTIMATE = 1000  # A simplistic per-join blow-up estimate


class QueryIntelligence:
    """
    Provides intelligence for handling SQL queries. It estimates query cost
    and decides on an execution strategy (run the query vs. hand back the SQL).
    """

    def __init__(self, graph_store: GraphStore):
        self.graph_store = graph_store
        # Cache of SQLAlchemy engines, keyed by database name.
        self.db_engines = {}

    def _get_engine_for_db(self, db_name: str):
        """Get (or lazily create and cache) an engine for a specific database."""
        if db_name not in self.db_engines:
            # Assuming db_name includes the .db extension
            connection_string = config.get_sqlite_connection_string(db_name)
            self.db_engines[db_name] = get_db_engine(connection_string)
        return self.db_engines.get(db_name)

    async def get_relevant_schemas(self, query: str) -> List[Dict[str, Any]]:
        """Find schema elements relevant to a natural language query.

        This is a simplistic keyword search. A real implementation would use
        embedding-based search or an LLM to extract entities.
        """
        all_results = []
        for keyword in query.split():
            if len(keyword) > 2:  # Avoid very short keywords
                all_results.extend(self.graph_store.keyword_search(keyword))

        # Deduplicate. Fix: the previous one-liner hashed tuple(d.items()),
        # which raised "unhashable type: 'list'" because keyword_search rows
        # carry a list-valued 'type' field (Neo4j labels). Normalize list
        # values to tuples before hashing, and sort keys so equal dicts
        # collapse regardless of insertion order.
        seen = set()
        unique_results = []
        for row in all_results:
            key = tuple(
                (k, tuple(v) if isinstance(v, list) else v)
                for k, v in sorted(row.items())
            )
            if key not in seen:
                seen.add(key)
                unique_results.append(row)
        return unique_results

    async def find_join_path(self, table1_name: str, table2_name: str) -> str:
        """Find a join path between two tables using the graph.

        Simplification: names are resolved via keyword search and the first
        hit is assumed correct; fully qualified (db.table) input would remove
        the ambiguity.
        """
        t1_nodes = self.graph_store.keyword_search(table1_name)
        t2_nodes = self.graph_store.keyword_search(table2_name)

        if not t1_nodes or not t2_nodes:
            return "Could not find one or both tables."

        # Assume the first result is correct for simplicity
        t1_unique_name = f"{t1_nodes[0]['database']}.{t1_nodes[0]['table']}"
        t2_unique_name = f"{t2_nodes[0]['database']}.{t2_nodes[0]['table']}"

        path_result = self.graph_store.find_shortest_path(t1_unique_name, t2_unique_name)

        if not path_result:
            return f"No path found between {table1_name} and {table2_name}."

        # Formatting the raw Neo4j path into a readable join chain is still TODO.
        return f"Path found (details require parsing): {path_result}"

    async def execute_query(self, sql: str, limit: int) -> List[Dict[str, Any]]:
        """
        Execute a SQL query against the appropriate database if the estimated
        cost is below the threshold.

        Raises:
            PermissionError: when the cost estimate vetoes direct execution.
            ValueError: when no table can be resolved from the SQL.
            ConnectionError: when no engine is available for the target DB.
        """
        cost_estimate = self.estimate_query_cost(sql)

        if cost_estimate['decision'] != 'execute':
            raise PermissionError(f"Query execution denied. Estimated cost is too high ({cost_estimate['estimated_rows']} rows).")

        # Major simplification: the target database is inferred from the first
        # table found in the SQL (federated queries are out of scope here).
        parsed_sql = self._parse_sql(sql)
        if not parsed_sql['tables']:
            raise ValueError("No tables found in SQL query.")

        first_table = parsed_sql['tables'][0]
        search_results = self.graph_store.keyword_search(first_table)
        if not search_results:
            raise ValueError(f"Table '{first_table}' not found in any known database.")

        db_name = search_results[0]['database']
        engine = self._get_engine_for_db(db_name)

        if not engine:
            raise ConnectionError(f"Could not connect to database: {db_name}")

        safe_sql = sql.strip().rstrip(';')
        # Fix: only append a LIMIT when the query does not already contain
        # one -- blindly appending produced invalid SQL such as
        # "... LIMIT 10 LIMIT 100" for queries carrying their own clause.
        if not any(token.casefold() == 'limit' for token in safe_sql.split()):
            safe_sql = f"{safe_sql} LIMIT {int(limit)}"

        with engine.connect() as connection:
            result = connection.execute(text(safe_sql))
            return [dict(row._mapping) for row in result.fetchall()]

    def _parse_sql(self, sql: str) -> Dict[str, Any]:
        """Parse the SQL to identify table names (best effort).

        NOTE(review): this simplistic walk over sqlparse tokens also picks up
        column aliases; a robust implementation needs a real SQL parser to
        handle complex queries, CTEs, etc.
        """
        parsed = sqlparse.parse(sql)[0]
        tables = set()
        for token in parsed.tokens:
            if isinstance(token, sqlparse.sql.Identifier):
                tables.add(token.get_real_name())
            elif token.is_group:
                # Look for identifiers within subgroups (e.g., FROM or JOIN clauses)
                for sub_token in token.tokens:
                    if isinstance(sub_token, sqlparse.sql.Identifier):
                        tables.add(sub_token.get_real_name())

        return {"tables": list(tables)}

    def estimate_query_cost(self, sql: str) -> Dict[str, Any]:
        """
        Estimate the cost of a query based on row counts stored in the graph.

        Returns:
            Dict with 'estimated_rows', a 'decision' of
            'execute' / 'return_sql' / 'error', and supporting details.
        """
        try:
            parsed_sql = self._parse_sql(sql)
            tables_in_query = parsed_sql['tables']

            if not tables_in_query:
                return {"estimated_rows": 0, "decision": "execute", "message": "No tables found in query."}

            # For simplicity, take the max row count of any table in the query.
            # A real system would analyze JOINs and WHERE clauses.
            max_rows = 0
            for table_name in tables_in_query:
                # Resolve the unique name via keyword search; assumes table
                # names are unique across databases for now.
                search_result = self.graph_store.keyword_search(table_name)
                if search_result:
                    table_unique_name = f"{search_result[0]['database']}.{search_result[0]['table']}"
                    row_count = self.graph_store.get_table_row_count(table_unique_name)
                    # get_table_row_count returns -1 for unknown tables, which
                    # the max comparison naturally ignores.
                    if row_count > max_rows:
                        max_rows = row_count

            estimated_rows = max_rows
            # Crude adjustment for joins
            if len(tables_in_query) > 1:
                # A better estimate would involve graph traversal and statistical models
                estimated_rows *= JOIN_CARDINALITY_ESTIMATE * (len(tables_in_query) - 1)

            decision = "execute" if estimated_rows < ROW_EXECUTION_THRESHOLD else "return_sql"

            return {
                "estimated_rows": estimated_rows,
                "decision": decision,
                "tables_found": tables_in_query
            }

        except Exception as e:
            logger.error(f"Error estimating query cost: {e}")
            return {"estimated_rows": -1, "decision": "error", "message": str(e)}
|
mcp/requirements.txt
CHANGED
|
@@ -3,4 +3,5 @@ uvicorn==0.24.0
|
|
| 3 |
neo4j==5.14.0
|
| 4 |
pydantic==2.4.0
|
| 5 |
requests==2.31.0
|
| 6 |
-
|
|
|
|
|
|
| 3 |
neo4j==5.14.0
|
| 4 |
pydantic==2.4.0
|
| 5 |
requests==2.31.0
|
| 6 |
+
SQLAlchemy==2.0.29
|
| 7 |
+
sqlparse==0.5.0
|
ops/scripts/generate_sample_databases.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from datetime import datetime, timedelta
|
| 4 |
+
import random
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# Ensure the data directory exists
DATA_DIR = 'data'  # output directory (relative to CWD) for the generated .db files
os.makedirs(DATA_DIR, exist_ok=True)  # import-time side effect: creates ./data if missing
|
| 10 |
+
|
| 11 |
+
def create_clinical_trials_db(data_dir: str = 'data'):
    """Create the clinical_trials.db sample database.

    Builds three tables populated with randomized demo data -- studies,
    patients, adverse_events -- plus lookup indexes on the join columns.

    Args:
        data_dir: Directory the SQLite file is written to. Defaults to
            'data', matching the module-level DATA_DIR, so existing callers
            are unaffected.
    """
    # Create the directory here too, so the function works even when called
    # with a directory other than the one prepared at import time.
    os.makedirs(data_dir, exist_ok=True)
    conn = sqlite3.connect(os.path.join(data_dir, 'clinical_trials.db'))
    try:
        # Studies table
        studies_data = {
            'study_id': ['ONCO-2023-001', 'CARDIO-2023-047', 'NEURO-2024-012', 'DIAB-2023-089', 'RARE-2024-003'],
            'study_name': ['Phase III Immunotherapy Trial', 'Beta Blocker Efficacy Study', 'Alzheimer Prevention Trial', 'Insulin Resistance Study', 'Rare Disease Natural History'],
            'phase': ['Phase 3', 'Phase 2', 'Phase 2', 'Phase 3', 'Observational'],
            'status': ['RECRUITING', 'ACTIVE', 'PLANNING', 'COMPLETED', 'RECRUITING'],
            'sponsor': ['OncoPharm Inc', 'CardioHealth', 'NeuroGen', 'DiabetesCare', 'NIH'],
            'target_enrollment': [500, 200, 150, 800, 50],
            'current_enrollment': [237, 178, 0, 800, 12],
            'start_date': ['2023-03-15', '2023-06-01', '2024-01-15', '2023-01-10', '2024-02-01']
        }
        pd.DataFrame(studies_data).to_sql('studies', conn, index=False, if_exists='replace')

        # Patients table: 100 synthetic patients randomly assigned to studies.
        patients_data = {
            'patient_id': [f'PT{str(i).zfill(6)}' for i in range(1, 101)],
            'study_id': [random.choice(studies_data['study_id']) for _ in range(100)],
            'enrollment_date': [(datetime.now() - timedelta(days=random.randint(1, 365))).strftime('%Y-%m-%d') for _ in range(100)],
            'age': [random.randint(18, 85) for _ in range(100)],
            'gender': [random.choice(['M', 'F']) for _ in range(100)],
            'status': [random.choice(['ENROLLED', 'COMPLETED', 'WITHDRAWN', 'SCREENING']) for _ in range(100)]
        }
        pd.DataFrame(patients_data).to_sql('patients', conn, index=False, if_exists='replace')

        # Adverse Events table: 50 events spread over the first 50 patients.
        adverse_events_data = {
            'event_id': list(range(1, 51)),
            'patient_id': [random.choice(patients_data['patient_id'][:50]) for _ in range(50)],
            'event_date': [(datetime.now() - timedelta(days=random.randint(1, 180))).strftime('%Y-%m-%d') for _ in range(50)],
            'event_type': [random.choice(['NAUSEA', 'HEADACHE', 'FATIGUE', 'RASH', 'FEVER']) for _ in range(50)],
            'severity': [random.choice(['MILD', 'MODERATE', 'SEVERE']) for _ in range(50)],
            'related_to_treatment': [random.choice(['YES', 'NO', 'UNKNOWN']) for _ in range(50)]
        }
        pd.DataFrame(adverse_events_data).to_sql('adverse_events', conn, index=False, if_exists='replace')

        # Add indexes on the FK columns (SQLite doesn't enforce FKs here, but
        # this documents the relationships). IF NOT EXISTS keeps repeated
        # runs from failing with "index already exists".
        conn.execute("CREATE INDEX IF NOT EXISTS idx_patients_study ON patients(study_id)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_events_patient ON adverse_events(patient_id)")
        conn.commit()
    finally:
        # Close the connection even if table creation fails part-way through.
        conn.close()
    print("\u2705 Clinical Trials database created successfully!")
|
| 56 |
+
|
| 57 |
+
def create_laboratory_db(data_dir: str = 'data'):
    """Create the laboratory.db sample database.

    Builds lab_tests, test_results, and biomarkers tables with randomized
    demo data, plus lookup indexes on the join columns.

    Args:
        data_dir: Directory the SQLite file is written to. Defaults to
            'data' (same as the module-level DATA_DIR) for compatibility.
    """
    # Ensure the target directory exists for non-default call sites.
    os.makedirs(data_dir, exist_ok=True)
    conn = sqlite3.connect(os.path.join(data_dir, 'laboratory.db'))
    try:
        # Lab Tests table: 200 orders referencing the clinical-trials patients.
        lab_tests_data = {
            'test_id': [f'LAB{str(i).zfill(8)}' for i in range(1, 201)],
            'patient_id': [f'PT{str(random.randint(1, 100)).zfill(6)}' for _ in range(200)],
            'test_date': [(datetime.now() - timedelta(days=random.randint(1, 365))).strftime('%Y-%m-%d') for _ in range(200)],
            'test_type': [random.choice(['CBC', 'METABOLIC_PANEL', 'LIVER_FUNCTION', 'LIPID_PANEL', 'HBA1C']) for _ in range(200)],
            'ordered_by': [f'DR{str(random.randint(1, 20)).zfill(4)}' for _ in range(200)],
            'priority': [random.choice(['ROUTINE', 'URGENT', 'STAT']) for _ in range(200)]
        }
        pd.DataFrame(lab_tests_data).to_sql('lab_tests', conn, index=False, if_exists='replace')

        # Test Results table: 600 analyte results across the 200 tests.
        results_data = {
            'result_id': list(range(1, 601)),
            'test_id': [random.choice(lab_tests_data['test_id']) for _ in range(600)],
            'analyte': [random.choice(['GLUCOSE', 'WBC', 'RBC', 'PLATELETS', 'CREATININE', 'ALT', 'AST', 'CHOLESTEROL']) for _ in range(600)],
            'value': [round(random.uniform(1, 200), 2) for _ in range(600)],
            'unit': [random.choice(['mg/dL', 'K/uL', 'M/uL', 'g/dL', 'mmol/L']) for _ in range(600)],
            'reference_low': [round(random.uniform(1, 50), 2) for _ in range(600)],
            'reference_high': [round(random.uniform(100, 200), 2) for _ in range(600)],
            'flag': [random.choice(['NORMAL', 'HIGH', 'LOW', 'CRITICAL']) for _ in range(600)]
        }
        pd.DataFrame(results_data).to_sql('test_results', conn, index=False, if_exists='replace')

        # Biomarkers table (for research): 30 expression records.
        biomarkers_data = {
            'biomarker_id': list(range(1, 31)),
            'patient_id': [f'PT{str(random.randint(1, 100)).zfill(6)}' for _ in range(30)],
            'biomarker_name': [random.choice(['PD-L1', 'BRCA1', 'EGFR', 'KRAS', 'HER2']) for _ in range(30)],
            'expression_level': [random.choice(['HIGH', 'MEDIUM', 'LOW', 'NEGATIVE']) for _ in range(30)],
            'test_method': [random.choice(['IHC', 'PCR', 'NGS', 'FLOW_CYTOMETRY']) for _ in range(30)],
            'collection_date': [(datetime.now() - timedelta(days=random.randint(1, 365))).strftime('%Y-%m-%d') for _ in range(30)]
        }
        pd.DataFrame(biomarkers_data).to_sql('biomarkers', conn, index=False, if_exists='replace')

        # IF NOT EXISTS keeps repeated runs from failing on existing indexes.
        conn.execute("CREATE INDEX IF NOT EXISTS idx_results_test ON test_results(test_id)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_tests_patient ON lab_tests(patient_id)")
        conn.commit()
    finally:
        # Close the connection even if table creation fails part-way through.
        conn.close()
    print("\u2705 Laboratory database created successfully!")
|
| 101 |
+
|
| 102 |
+
def create_drug_discovery_db(data_dir: str = 'data'):
    """Create the drug_discovery.db sample database.

    Builds compounds, assay_results, drug_targets, and compound_targets
    tables with randomized demo data, plus lookup indexes.

    Args:
        data_dir: Directory the SQLite file is written to. Defaults to
            'data' (same as the module-level DATA_DIR) for compatibility.
    """
    # Ensure the target directory exists for non-default call sites.
    os.makedirs(data_dir, exist_ok=True)
    conn = sqlite3.connect(os.path.join(data_dir, 'drug_discovery.db'))
    try:
        # Compounds table: 50 synthetic small molecules.
        compounds_data = {
            'compound_id': [f'CMP-{str(i).zfill(6)}' for i in range(1, 51)],
            'compound_name': [f'Compound-{chr(65+i//10)}{i%10}' for i in range(50)],
            'molecular_weight': [round(random.uniform(200, 800), 2) for _ in range(50)],
            'formula': [f'C{random.randint(10,30)}H{random.randint(10,40)}N{random.randint(0,5)}O{random.randint(1,10)}' for _ in range(50)],
            'development_stage': [random.choice(['DISCOVERY', 'LEAD_OPT', 'PRECLINICAL', 'CLINICAL', 'DISCONTINUED']) for _ in range(50)],
            'target_class': [random.choice(['KINASE', 'GPCR', 'ION_CHANNEL', 'PROTEASE', 'ANTIBODY']) for _ in range(50)]
        }
        pd.DataFrame(compounds_data).to_sql('compounds', conn, index=False, if_exists='replace')

        # Assay Results table: 200 measurements across the compounds.
        assays_data = {
            'assay_id': list(range(1, 201)),
            'compound_id': [random.choice(compounds_data['compound_id']) for _ in range(200)],
            'assay_type': [random.choice(['BINDING', 'CELL_VIABILITY', 'ENZYME_INHIBITION', 'TOXICITY']) for _ in range(200)],
            'ic50_nm': [round(random.uniform(0.1, 10000), 2) for _ in range(200)],
            'efficacy_percent': [round(random.uniform(0, 100), 1) for _ in range(200)],
            'assay_date': [(datetime.now() - timedelta(days=random.randint(1, 365))).strftime('%Y-%m-%d') for _ in range(200)],
            'scientist': [f'SCI-{random.randint(1, 10)}' for _ in range(200)]
        }
        pd.DataFrame(assays_data).to_sql('assay_results', conn, index=False, if_exists='replace')

        # Drug Targets table: 20 synthetic protein targets.
        targets_data = {
            'target_id': [f'TGT-{str(i).zfill(4)}' for i in range(1, 21)],
            'target_name': [f'Protein-{i}' for i in range(1, 21)],
            'gene_symbol': [f'GENE{i}' for i in range(1, 21)],
            'pathway': [random.choice(['MAPK', 'PI3K/AKT', 'WNT', 'NOTCH', 'HEDGEHOG']) for _ in range(20)],
            'disease_area': [random.choice(['ONCOLOGY', 'CARDIOLOGY', 'NEUROLOGY', 'IMMUNOLOGY']) for _ in range(20)]
        }
        pd.DataFrame(targets_data).to_sql('drug_targets', conn, index=False, if_exists='replace')

        # Compound-Target Associations: 75 many-to-many links.
        associations_data = {
            'association_id': list(range(1, 76)),
            'compound_id': [random.choice(compounds_data['compound_id']) for _ in range(75)],
            'target_id': [random.choice(targets_data['target_id']) for _ in range(75)],
            'affinity_nm': [round(random.uniform(0.1, 1000), 2) for _ in range(75)],
            'selectivity_fold': [round(random.uniform(1, 100), 1) for _ in range(75)]
        }
        pd.DataFrame(associations_data).to_sql('compound_targets', conn, index=False, if_exists='replace')

        # IF NOT EXISTS keeps repeated runs from failing on existing indexes.
        conn.execute("CREATE INDEX IF NOT EXISTS idx_assays_compound ON assay_results(compound_id)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_associations_compound ON compound_targets(compound_id)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_associations_target ON compound_targets(target_id)")
        conn.commit()
    finally:
        # Close the connection even if table creation fails part-way through.
        conn.close()
    print("\u2705 Drug Discovery database created successfully!")
|
| 155 |
+
|
| 156 |
+
# Script entry point: build all three demo SQLite databases under ./data.
# NOTE(review): the leading "π" below looks like mojibake of an emoji in the
# original source -- confirm against the pre-merge file.
if __name__ == "__main__":
    create_clinical_trials_db()
    create_laboratory_db()
    create_drug_discovery_db()
    print("\nπ All three Life Sciences databases created successfully in the 'data' directory!")
|
ops/scripts/ingest.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys
import logging

from sqlalchemy import create_engine

# Add project root to path to allow imports from the mcp package
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(project_root)

# Fix: these modules live in the mcp/core package (see mcp/core/*.py), so the
# bare "core.*" imports could never resolve with project_root on sys.path.
from mcp.core.discovery import discover_schema
from mcp.core.graph import GraphStore
from mcp.core.config import SQLITE_DATA_DIR

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
def ingest_sqlite_database(db_file: str, graph_store: GraphStore):
    """Discover the schema of one SQLite database file and load it into Neo4j.

    Missing files and discovery/ingestion failures are logged and swallowed so
    a batch run can continue with the remaining databases.
    """
    db_path = os.path.join(SQLITE_DATA_DIR, db_file)
    logger.info(f"Processing database: {db_path}")

    if not os.path.exists(db_path):
        logger.error(f"Database file not found: {db_path}")
        return

    try:
        sqlite_engine = create_engine(f"sqlite:///{db_path}")
        schema_data = discover_schema(sqlite_engine)
        if not schema_data:
            logger.warning(f"Could not discover schema for {db_file}. Skipping.")
            return
        logger.info(f"Discovered schema for {db_file}, ingesting into Neo4j...")
        graph_store.import_schema(schema_data)
        logger.info(f"Successfully ingested schema for {db_file}")
    except Exception as e:
        logger.error(f"An error occurred while processing {db_file}: {e}")
|
| 39 |
+
|
| 40 |
+
def main():
    """
    Main function to run the ingestion process for all SQLite databases
    found in the data directory.
    """
    logger.info("Starting schema ingestion process...")

    # Guard clause: nothing to do without a data directory.
    if not (os.path.exists(SQLITE_DATA_DIR) and os.path.isdir(SQLITE_DATA_DIR)):
        logger.error(f"Data directory not found: {SQLITE_DATA_DIR}")
        return

    candidates = [name for name in os.listdir(SQLITE_DATA_DIR) if name.endswith(".db")]
    if not candidates:
        logger.warning(f"No SQLite database files (.db) found in {SQLITE_DATA_DIR}.")
        return

    # A working Neo4j connection is a hard prerequisite for ingestion.
    try:
        graph_store = GraphStore()
    except Exception as e:
        logger.error(f"Failed to connect to Neo4j. Aborting ingestion. Error: {e}")
        return
    logger.info("Successfully connected to Neo4j.")

    for db_file in candidates:
        ingest_sqlite_database(db_file, graph_store)

    graph_store.close()
    logger.info("Schema ingestion process completed.")
|
| 69 |
+
|
| 70 |
+
# Allow the ingestion pipeline to be run directly as a script.
if __name__ == "__main__":
    main()
|
postgres/init.sql
DELETED
|
@@ -1,28 +0,0 @@
|
|
| 1 |
-
-- Create sample tables for testing
|
| 2 |
-
CREATE TABLE customers (
|
| 3 |
-
id SERIAL PRIMARY KEY,
|
| 4 |
-
email VARCHAR(255) UNIQUE NOT NULL,
|
| 5 |
-
first_name VARCHAR(100),
|
| 6 |
-
last_name VARCHAR(100),
|
| 7 |
-
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 8 |
-
);
|
| 9 |
-
|
| 10 |
-
CREATE TABLE orders (
|
| 11 |
-
id SERIAL PRIMARY KEY,
|
| 12 |
-
customer_id INTEGER REFERENCES customers(id),
|
| 13 |
-
order_date DATE NOT NULL,
|
| 14 |
-
total_amount DECIMAL(10,2),
|
| 15 |
-
status VARCHAR(50)
|
| 16 |
-
);
|
| 17 |
-
|
| 18 |
-
-- Insert sample data
|
| 19 |
-
INSERT INTO customers (email, first_name, last_name) VALUES
|
| 20 |
-
('john.doe@email.com', 'John', 'Doe'),
|
| 21 |
-
('jane.smith@email.com', 'Jane', 'Smith'),
|
| 22 |
-
('bob.johnson@email.com', 'Bob', 'Johnson');
|
| 23 |
-
|
| 24 |
-
INSERT INTO orders (customer_id, order_date, total_amount, status) VALUES
|
| 25 |
-
(1, '2024-01-15', 99.99, 'completed'),
|
| 26 |
-
(1, '2024-02-01', 149.99, 'completed'),
|
| 27 |
-
(2, '2024-01-20', 79.99, 'pending'),
|
| 28 |
-
(3, '2024-02-10', 199.99, 'completed');
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
streamlit/app.py
CHANGED
|
@@ -7,478 +7,156 @@ All database access MUST go through MCP server - no direct connections allowed.
|
|
| 7 |
|
| 8 |
import streamlit as st
|
| 9 |
import requests
|
| 10 |
-
import
|
| 11 |
import json
|
| 12 |
import pandas as pd
|
| 13 |
-
from
|
| 14 |
-
import os
|
| 15 |
-
from typing import Dict, Any, Optional, Tuple
|
| 16 |
|
| 17 |
-
# Configuration
|
|
|
|
|
|
|
| 18 |
MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
|
| 19 |
MCP_API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")
|
| 20 |
|
| 21 |
-
# Page configuration
|
| 22 |
st.set_page_config(
|
| 23 |
-
page_title="
|
| 24 |
-
page_icon="
|
| 25 |
-
layout="wide"
|
| 26 |
-
initial_sidebar_state="expanded"
|
| 27 |
)
|
| 28 |
|
| 29 |
-
#
|
| 30 |
-
if '
|
| 31 |
-
st.session_state.
|
| 32 |
-
if '
|
| 33 |
-
st.session_state.
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
def
|
| 38 |
-
"""
|
| 39 |
-
Call MCP server - the ONLY way to access databases.
|
| 40 |
-
Returns (response_data, response_time_ms)
|
| 41 |
-
"""
|
| 42 |
-
start_time = time.time()
|
| 43 |
-
|
| 44 |
try:
|
| 45 |
response = requests.post(
|
| 46 |
-
MCP_URL,
|
| 47 |
-
headers={
|
| 48 |
-
|
| 49 |
-
"Content-Type": "application/json"
|
| 50 |
-
},
|
| 51 |
-
json={"tool": tool, "params": params or {}},
|
| 52 |
-
timeout=10
|
| 53 |
)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
"tool": tool,
|
| 61 |
-
"params": params,
|
| 62 |
-
"status_code": response.status_code,
|
| 63 |
-
"response_time_ms": response_time,
|
| 64 |
-
"success": response.status_code == 200
|
| 65 |
-
}
|
| 66 |
-
st.session_state.debug_log.append(debug_entry)
|
| 67 |
-
|
| 68 |
-
# Keep only last 5 entries
|
| 69 |
-
if len(st.session_state.debug_log) > 5:
|
| 70 |
-
st.session_state.debug_log = st.session_state.debug_log[-5:]
|
| 71 |
-
|
| 72 |
-
if response.status_code == 200:
|
| 73 |
-
return response.json(), response_time
|
| 74 |
-
else:
|
| 75 |
-
return {"error": f"HTTP {response.status_code}: {response.text}"}, response_time
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
"
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
return {"error": error_msg}, response_time
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
result, response_time = call_mcp("get_schema")
|
| 98 |
-
if "error" in result:
|
| 99 |
-
return False, response_time, result["error"]
|
| 100 |
-
return True, response_time, "Connected"
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
return False, response_time, result["error"]
|
| 107 |
-
return True, response_time, "Connected"
|
| 108 |
|
| 109 |
-
def
|
| 110 |
-
"""
|
| 111 |
try:
|
| 112 |
-
|
| 113 |
-
response
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
-
if
|
| 117 |
-
|
| 118 |
-
else:
|
| 119 |
-
return False, response_time, f"HTTP {response.status_code}"
|
| 120 |
-
except Exception as e:
|
| 121 |
-
return False, 0, str(e)
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
result, _ = call_mcp("query_graph", {
|
| 126 |
-
"query": "MATCH (l:Log) WHERE l.timestamp > datetime() - duration('PT1H') RETURN count(l) as count"
|
| 127 |
-
})
|
| 128 |
-
|
| 129 |
-
if "error" in result:
|
| 130 |
-
return {"error": result["error"]}
|
| 131 |
-
|
| 132 |
-
return result.get("data", [{}])[0] if result.get("data") else {}
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
workflow_result, _ = call_mcp("write_graph", {
|
| 140 |
-
"action": "create_node",
|
| 141 |
-
"label": "Workflow",
|
| 142 |
-
"properties": {
|
| 143 |
-
"id": workflow_id,
|
| 144 |
-
"name": f"Streamlit Query: {question[:50]}...",
|
| 145 |
-
"description": f"Query from Streamlit: {question}",
|
| 146 |
-
"status": "active",
|
| 147 |
-
"created_at": datetime.now().isoformat(),
|
| 148 |
-
"source": "streamlit"
|
| 149 |
-
}
|
| 150 |
-
})
|
| 151 |
-
|
| 152 |
-
if "error" in workflow_result:
|
| 153 |
-
st.error(f"Failed to create workflow: {workflow_result['error']}")
|
| 154 |
-
return None
|
| 155 |
-
|
| 156 |
-
# Create instruction sequence
|
| 157 |
-
instructions = [
|
| 158 |
-
{
|
| 159 |
-
"id": f"{workflow_id}-inst-1",
|
| 160 |
-
"type": "discover_schema",
|
| 161 |
-
"sequence": 1,
|
| 162 |
-
"description": "Discover database schema",
|
| 163 |
-
"status": "pending",
|
| 164 |
-
"pause_duration": 5, # Short pause for testing
|
| 165 |
-
"parameters": "{}"
|
| 166 |
-
},
|
| 167 |
-
{
|
| 168 |
-
"id": f"{workflow_id}-inst-2",
|
| 169 |
-
"type": "generate_sql",
|
| 170 |
-
"sequence": 2,
|
| 171 |
-
"description": f"Generate SQL for: {question}",
|
| 172 |
-
"status": "pending",
|
| 173 |
-
"pause_duration": 5,
|
| 174 |
-
"parameters": json.dumps({"question": question})
|
| 175 |
-
},
|
| 176 |
-
{
|
| 177 |
-
"id": f"{workflow_id}-inst-3",
|
| 178 |
-
"type": "review_results",
|
| 179 |
-
"sequence": 3,
|
| 180 |
-
"description": "Review and format results",
|
| 181 |
-
"status": "pending",
|
| 182 |
-
"pause_duration": 0,
|
| 183 |
-
"parameters": "{}"
|
| 184 |
-
}
|
| 185 |
-
]
|
| 186 |
-
|
| 187 |
-
# Create instruction nodes
|
| 188 |
-
for inst in instructions:
|
| 189 |
-
inst_result, _ = call_mcp("write_graph", {
|
| 190 |
-
"action": "create_node",
|
| 191 |
-
"label": "Instruction",
|
| 192 |
-
"properties": inst
|
| 193 |
-
})
|
| 194 |
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
return None
|
| 198 |
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
CREATE (w)-[:HAS_INSTRUCTION]->(i)
|
| 204 |
-
""",
|
| 205 |
-
"parameters": {"workflow_id": workflow_id, "inst_id": inst["id"]}
|
| 206 |
-
})
|
| 207 |
-
|
| 208 |
-
# Create instruction chain
|
| 209 |
-
for i in range(len(instructions) - 1):
|
| 210 |
-
chain_result, _ = call_mcp("query_graph", {
|
| 211 |
-
"query": """
|
| 212 |
-
MATCH (i1:Instruction {id: $id1}), (i2:Instruction {id: $id2})
|
| 213 |
-
CREATE (i1)-[:NEXT_INSTRUCTION]->(i2)
|
| 214 |
-
""",
|
| 215 |
-
"parameters": {"id1": instructions[i]["id"], "id2": instructions[i + 1]["id"]}
|
| 216 |
-
})
|
| 217 |
-
|
| 218 |
-
return workflow_id
|
| 219 |
|
| 220 |
-
def
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
MATCH (w:Workflow {id: $id})-[:HAS_INSTRUCTION]->(i:Instruction)
|
| 225 |
-
RETURN w.status as workflow_status,
|
| 226 |
-
collect(i.status) as instruction_statuses,
|
| 227 |
-
collect(i.type) as instruction_types,
|
| 228 |
-
collect(i.sequence) as sequences
|
| 229 |
-
""",
|
| 230 |
-
"parameters": {"id": workflow_id}
|
| 231 |
-
})
|
| 232 |
-
|
| 233 |
-
if "error" in result or not result.get("data"):
|
| 234 |
-
return {"error": "Workflow not found"}
|
| 235 |
-
|
| 236 |
-
return result["data"][0]
|
| 237 |
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
MATCH (w:Workflow {id: $id})-[:HAS_INSTRUCTION]->(i:Instruction)-[:EXECUTED_AS]->(e:Execution)
|
| 243 |
-
RETURN i.sequence as sequence,
|
| 244 |
-
i.type as type,
|
| 245 |
-
i.description as description,
|
| 246 |
-
e.result as result,
|
| 247 |
-
e.started_at as started_at,
|
| 248 |
-
e.completed_at as completed_at
|
| 249 |
-
ORDER BY i.sequence
|
| 250 |
-
""",
|
| 251 |
-
"parameters": {"id": workflow_id}
|
| 252 |
-
})
|
| 253 |
-
|
| 254 |
-
if "error" in result:
|
| 255 |
-
return {"error": result["error"]}
|
| 256 |
-
|
| 257 |
-
return {"executions": result.get("data", [])}
|
| 258 |
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
MATCH (t:Table)-[:HAS_COLUMN]->(c:Column)
|
| 264 |
-
RETURN t.name as table_name,
|
| 265 |
-
collect({name: c.name, type: c.data_type, nullable: c.nullable}) as columns
|
| 266 |
-
ORDER BY t.name
|
| 267 |
-
"""
|
| 268 |
-
})
|
| 269 |
-
|
| 270 |
-
if "error" in result:
|
| 271 |
-
return f"Error fetching schema: {result['error']}"
|
| 272 |
-
|
| 273 |
-
schema_text = "Database Schema:\n"
|
| 274 |
-
for record in result.get("data", []):
|
| 275 |
-
table_name = record["table_name"]
|
| 276 |
-
columns = record["columns"]
|
| 277 |
-
schema_text += f"\nTable: {table_name}\n"
|
| 278 |
-
for col in columns:
|
| 279 |
-
nullable = "NULL" if col["nullable"] else "NOT NULL"
|
| 280 |
-
schema_text += f" - {col['name']}: {col['type']} {nullable}\n"
|
| 281 |
-
|
| 282 |
-
return schema_text
|
| 283 |
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
# Sidebar
|
| 289 |
-
with st.sidebar:
|
| 290 |
-
st.header("π§ Configuration")
|
| 291 |
-
st.code(f"MCP URL: {MCP_URL}")
|
| 292 |
-
st.code(f"API Key: {MCP_API_KEY[:10]}...")
|
| 293 |
-
|
| 294 |
-
if st.button("π Refresh All", type="primary"):
|
| 295 |
-
st.rerun()
|
| 296 |
-
|
| 297 |
-
st.header("π Quick Stats")
|
| 298 |
-
stats = get_performance_stats()
|
| 299 |
-
if "error" not in stats:
|
| 300 |
-
st.metric("Logs (1h)", stats.get("count", 0))
|
| 301 |
-
else:
|
| 302 |
-
st.error(f"Stats error: {stats['error']}")
|
| 303 |
-
|
| 304 |
-
# Main tabs
|
| 305 |
-
tab1, tab2 = st.tabs(["π Connection Status", "π€ Query Tester"])
|
| 306 |
-
|
| 307 |
-
with tab1:
|
| 308 |
-
st.header("Connection Status Monitor")
|
| 309 |
-
st.caption("All database access goes through MCP server - no direct connections allowed")
|
| 310 |
-
|
| 311 |
-
# Connection status in columns
|
| 312 |
-
col1, col2, col3 = st.columns(3)
|
| 313 |
-
|
| 314 |
-
with col1:
|
| 315 |
-
st.subheader("Neo4j (via MCP)")
|
| 316 |
-
neo4j_ok, neo4j_time, neo4j_msg = test_neo4j_connection()
|
| 317 |
-
st.metric(
|
| 318 |
-
label="Status",
|
| 319 |
-
value="Online" if neo4j_ok else "Offline",
|
| 320 |
-
delta=f"{neo4j_time}ms"
|
| 321 |
-
)
|
| 322 |
-
if neo4j_ok:
|
| 323 |
-
st.success(neo4j_msg)
|
| 324 |
-
else:
|
| 325 |
-
st.error(neo4j_msg)
|
| 326 |
-
|
| 327 |
-
with col2:
|
| 328 |
-
st.subheader("PostgreSQL (via MCP)")
|
| 329 |
-
postgres_ok, postgres_time, postgres_msg = test_postgres_connection()
|
| 330 |
-
st.metric(
|
| 331 |
-
label="Status",
|
| 332 |
-
value="Online" if postgres_ok else "Offline",
|
| 333 |
-
delta=f"{postgres_time}ms"
|
| 334 |
-
)
|
| 335 |
-
if postgres_ok:
|
| 336 |
-
st.success(postgres_msg)
|
| 337 |
-
else:
|
| 338 |
-
st.error(postgres_msg)
|
| 339 |
-
|
| 340 |
-
with col3:
|
| 341 |
-
st.subheader("MCP Server")
|
| 342 |
-
mcp_ok, mcp_time, mcp_msg = test_mcp_server()
|
| 343 |
-
st.metric(
|
| 344 |
-
label="Status",
|
| 345 |
-
value="Online" if mcp_ok else "Offline",
|
| 346 |
-
delta=f"{mcp_time}ms"
|
| 347 |
-
)
|
| 348 |
-
if mcp_ok:
|
| 349 |
-
st.success(mcp_msg)
|
| 350 |
-
else:
|
| 351 |
-
st.error(mcp_msg)
|
| 352 |
-
|
| 353 |
-
# Performance stats
|
| 354 |
-
st.subheader("Performance Statistics")
|
| 355 |
-
stats = get_performance_stats()
|
| 356 |
-
if "error" not in stats:
|
| 357 |
-
st.info(f"Operations in last hour: {stats.get('count', 0)}")
|
| 358 |
-
else:
|
| 359 |
-
st.error(f"Cannot fetch stats: {stats['error']}")
|
| 360 |
-
|
| 361 |
-
# Auto-refresh info
|
| 362 |
-
st.session_state.last_refresh = datetime.now()
|
| 363 |
-
st.caption(f"Last checked: {st.session_state.last_refresh.strftime('%H:%M:%S')}")
|
| 364 |
-
|
| 365 |
-
# Auto-refresh every 5 seconds
|
| 366 |
-
time.sleep(5)
|
| 367 |
-
st.rerun()
|
| 368 |
-
|
| 369 |
-
with tab2:
|
| 370 |
-
st.header("Query Tester")
|
| 371 |
-
st.caption("Test natural language queries through the agentic engine")
|
| 372 |
-
|
| 373 |
-
# Query input
|
| 374 |
-
question = st.text_area(
|
| 375 |
-
"Enter your question:",
|
| 376 |
-
height=100,
|
| 377 |
-
placeholder="e.g., 'How many customers do we have?' or 'Show me all orders from last month'"
|
| 378 |
-
)
|
| 379 |
-
|
| 380 |
-
col1, col2 = st.columns([1, 1])
|
| 381 |
-
|
| 382 |
-
with col1:
|
| 383 |
-
if st.button("π Execute Query", type="primary", disabled=not question.strip()):
|
| 384 |
-
if question.strip():
|
| 385 |
-
with st.spinner("Creating workflow..."):
|
| 386 |
-
workflow_id = create_workflow(question.strip())
|
| 387 |
-
if workflow_id:
|
| 388 |
-
st.session_state.workflow_id = workflow_id
|
| 389 |
-
st.success(f"Workflow created: {workflow_id}")
|
| 390 |
-
else:
|
| 391 |
-
st.error("Failed to create workflow")
|
| 392 |
-
|
| 393 |
-
with col2:
|
| 394 |
-
if st.button("ποΈ Clear Results"):
|
| 395 |
-
st.session_state.workflow_id = None
|
| 396 |
-
st.rerun()
|
| 397 |
-
|
| 398 |
-
# Workflow execution monitoring
|
| 399 |
-
if st.session_state.workflow_id:
|
| 400 |
-
st.subheader("Execution Progress")
|
| 401 |
-
|
| 402 |
-
# Get workflow status
|
| 403 |
-
status = get_workflow_status(st.session_state.workflow_id)
|
| 404 |
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
instruction_types = status.get("instruction_types", [])
|
| 411 |
-
|
| 412 |
-
# Progress bar
|
| 413 |
-
completed = sum(1 for s in instruction_statuses if s == "complete")
|
| 414 |
-
total = len(instruction_statuses)
|
| 415 |
-
progress = completed / total if total > 0 else 0
|
| 416 |
|
| 417 |
-
|
| 418 |
-
st.caption(f"Progress: {completed}/{total} instructions completed")
|
| 419 |
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
elif inst_status == "failed":
|
| 429 |
-
st.error(f"β {inst_type}")
|
| 430 |
-
else:
|
| 431 |
-
st.info(f"β³ {inst_type}")
|
| 432 |
-
|
| 433 |
-
# Get and display results
|
| 434 |
-
if completed > 0:
|
| 435 |
-
results = get_workflow_results(st.session_state.workflow_id)
|
| 436 |
-
|
| 437 |
-
if "error" not in results:
|
| 438 |
-
st.subheader("Execution Results")
|
| 439 |
-
|
| 440 |
-
for execution in results.get("executions", []):
|
| 441 |
-
with st.expander(f"Step {execution['sequence']}: {execution['type']}"):
|
| 442 |
-
st.write(f"**Description:** {execution['description']}")
|
| 443 |
-
|
| 444 |
-
if execution['started_at'] and execution['completed_at']:
|
| 445 |
-
start = datetime.fromisoformat(execution['started_at'].replace('Z', '+00:00'))
|
| 446 |
-
end = datetime.fromisoformat(execution['completed_at'].replace('Z', '+00:00'))
|
| 447 |
-
duration = (end - start).total_seconds()
|
| 448 |
-
st.write(f"**Duration:** {duration:.2f} seconds")
|
| 449 |
-
|
| 450 |
-
if execution['result']:
|
| 451 |
-
try:
|
| 452 |
-
result_data = json.loads(execution['result']) if isinstance(execution['result'], str) else execution['result']
|
| 453 |
-
|
| 454 |
-
if execution['type'] == 'generate_sql' and 'generated_sql' in result_data:
|
| 455 |
-
st.write("**Generated SQL:**")
|
| 456 |
-
st.code(result_data['generated_sql'], language='sql')
|
| 457 |
-
|
| 458 |
-
if 'data' in result_data and result_data['data']:
|
| 459 |
-
st.write("**Query Results:**")
|
| 460 |
-
df = pd.DataFrame(result_data['data'])
|
| 461 |
-
st.dataframe(df)
|
| 462 |
-
|
| 463 |
-
if 'error' in result_data:
|
| 464 |
-
st.error(f"Error: {result_data['error']}")
|
| 465 |
-
|
| 466 |
-
except Exception as e:
|
| 467 |
-
st.write("**Raw Result:**")
|
| 468 |
-
st.code(str(execution['result']))
|
| 469 |
-
else:
|
| 470 |
-
st.error(f"Results error: {results['error']}")
|
| 471 |
|
| 472 |
-
|
| 473 |
-
with st.expander("π§ Debug Information"):
|
| 474 |
-
st.write("**Last 5 MCP Requests:**")
|
| 475 |
-
for entry in st.session_state.debug_log:
|
| 476 |
-
status_icon = "β
" if entry["success"] else "β"
|
| 477 |
-
st.write(f"{status_icon} {entry['timestamp']} - {entry['tool']} ({entry['response_time_ms']}ms)")
|
| 478 |
-
if not entry["success"] and "error" in entry:
|
| 479 |
-
st.error(f"Error: {entry['error']}")
|
| 480 |
-
|
| 481 |
-
st.write("**Important:** All database operations go through MCP server. Direct database access is not permitted.")
|
| 482 |
|
| 483 |
if __name__ == "__main__":
|
| 484 |
main()
|
|
|
|
| 7 |
|
| 8 |
import json
import logging
import os
from typing import Dict, Any

import pandas as pd
import requests
import streamlit as st
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
# --- Configuration ---
# Service endpoints; env-overridable so the same code runs under docker-compose
# (service hostnames) or locally.
AGENT_URL = os.getenv("AGENT_URL", "http://agent:8001/query")
NEO4J_URL = os.getenv("NEO4J_URL", "http://neo4j:7474")
MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")
MCP_API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")

st.set_page_config(
    page_title="GraphRAG Chat",
    # NOTE(review): icon reconstructed from a mojibake'd emoji in the pasted
    # source — confirm the intended glyph.
    page_icon="💬",
    layout="wide"
)

# --- Session State ---
# messages: chat transcript as a list of {"role": ..., "content": ...} dicts.
if 'messages' not in st.session_state:
    st.session_state.messages = []
# schema_info: markdown schema text rendered in the sidebar.
if 'schema_info' not in st.session_state:
    st.session_state.schema_info = ""
|
| 33 |
+
# --- Helper Functions ---
|
| 34 |
+
def stream_agent_response(question: str):
    """Stream the agent's answer for *question*, yielding decoded JSON objects.

    Each yielded dict is one agent event (e.g. "thought" / "observation" /
    "final_answer" chunks consumed by main()). A transport failure is surfaced
    to the caller as a single {"error": "..."} dict rather than an exception.
    """
    try:
        with requests.post(AGENT_URL, json={"question": question}, stream=True, timeout=300) as r:
            r.raise_for_status()
            for chunk in r.iter_content(chunk_size=None):
                if not chunk:
                    continue
                try:
                    yield json.loads(chunk)
                except json.JSONDecodeError:
                    # A network read may split one JSON object across chunks,
                    # making individual chunks unparseable; skip and keep
                    # consuming the stream.
                    # BUG FIX: `logger` was undefined here (NameError at
                    # runtime); use a stdlib module logger instead.
                    logging.getLogger(__name__).warning(
                        "Could not decode JSON chunk: %r", chunk
                    )
                    continue
    except requests.exceptions.RequestException as e:
        yield {"error": f"Failed to connect to agent: {e}"}
|
| 50 |
+
def fetch_schema_info() -> str:
    """Fetch the database schema from the MCP server, formatted as markdown.

    Returns a human-readable error string (never raises) when the MCP server
    is unreachable or reports a failure.
    """
    auth_headers = {
        "Authorization": f"Bearer {MCP_API_KEY}",
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(
            f"{MCP_URL}/discovery/get_relevant_schemas",
            headers=auth_headers,
            json={"query": ""}
        )
        response.raise_for_status()
        payload = response.json()
    except requests.exceptions.RequestException as e:
        return f"Could not fetch schema: {e}"

    if payload.get("status") != "success":
        return f"Error from MCP: {payload.get('message', 'Unknown error')}"

    schemas = payload.get("schemas", [])
    if not schemas:
        return "No schema information found."

    # Bucket column entries under their "<database>.<table>" key.
    tables: Dict[str, list] = {}
    for entry in schemas:
        table_key = f"{entry.get('database', '')}.{entry.get('table', '')}"
        tables.setdefault(table_key, []).append(
            f"{entry.get('name', '')} ({entry.get('type', [''])[0]})"
        )

    # Render one markdown section per table with a bullet per column.
    parts = []
    for table, columns in tables.items():
        parts.append(f"**{table}**:\n")
        parts.extend(f"- {col}\n" for col in columns)
    return "".join(parts)
|
| 85 |
+
@st.cache_data(ttl=600)  # Streamlit memoizes the result for 10 minutes
def get_cached_schema():
    """Return the MCP schema text, cached to avoid repeated HTTP calls.

    The cache can be invalidated on demand with st.cache_data.clear()
    (e.g. from a "Refresh Schema" control), after which the next call
    refetches from the MCP server.
    """
    return fetch_schema_info()
|
|
|
|
| 89 |
|
| 90 |
+
def check_service_health(service_name: str, url: str) -> bool:
    """Return True when *url* answers an HTTP GET within 5 seconds.

    A 401 also counts as healthy: the service is up, it merely wants auth.
    (*service_name* is accepted for caller readability; the probe itself
    only uses the URL.)
    """
    try:
        status = requests.get(url, timeout=5).status_code
    except requests.exceptions.RequestException:
        return False
    return status in (200, 401)
+
|
| 98 |
+
# --- UI Components ---
|
| 99 |
+
def display_sidebar():
    """Render the sidebar: schema browser, service health badges, chat controls."""
    with st.sidebar:
        st.title("🗄️ Database Schema")

        # Invalidate the cached schema on demand; the call just below refetches.
        if st.button("🔄 Refresh Schema"):
            st.cache_data.clear()

        st.session_state.schema_info = get_cached_schema()
        st.markdown(st.session_state.schema_info)

        st.markdown("---")
        st.title("📡 Service Status")

        # NOTE(review): status/emoji labels reconstructed from mojibake in the
        # pasted source — confirm against the deployed UI.
        neo4j_ok = check_service_health("Neo4j", NEO4J_URL)
        mcp_ok = check_service_health("MCP", MCP_URL.replace("/mcp", "/health"))
        neo4j_status = "✅ Online" if neo4j_ok else "❌ Offline"
        mcp_status = "✅ Online" if mcp_ok else "❌ Offline"

        st.markdown(f"**Neo4j:** {neo4j_status}")
        st.markdown(f"**MCP Server:** {mcp_status}")

        st.markdown("---")
        if st.button("🗑️ Clear Chat History"):
            st.session_state.messages = []
            st.rerun()
|
|
| 122 |
|
| 123 |
+
def main():
    """Render the chat page: sidebar, transcript replay, and the streaming Q&A loop."""
    display_sidebar()
    st.title("💬 GraphRAG Conversational Agent")
    st.markdown("Ask questions about the life sciences dataset. The agent's thought process will be shown below.")

    # Replay the conversation so far.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    prompt = st.chat_input("Ask your question here...")
    if not prompt:
        return

    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        accumulated = ""
        placeholder = st.empty()

        # Consume the agent's event stream, growing one markdown transcript
        # and re-rendering it in place after every event.
        for chunk in stream_agent_response(prompt):
            if "error" in chunk:
                accumulated = chunk["error"]
                placeholder.error(accumulated)
                break

            text = chunk.get("content", "")
            kind = chunk.get("type")
            # NOTE(review): emoji reconstructed from mojibake in the pasted
            # source — confirm the intended glyph.
            if kind == "thought":
                accumulated += f"🤔 *{text}*\n\n"
            elif kind == "observation":
                accumulated += f"{text}\n\n"
            elif kind == "final_answer":
                accumulated += f"**Final Answer:**\n{text}"

            placeholder.markdown(accumulated)

        st.session_state.messages.append({"role": "assistant", "content": accumulated})
|
|
| 160 |
|
| 161 |
if __name__ == "__main__":
|
| 162 |
main()
|