GreymanT commited on
Commit
8bf4d58
·
verified ·
1 Parent(s): ca9a435

Upload 80 files

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. api/__init__.py +2 -0
  3. api/__pycache__/__init__.cpython-311.pyc +0 -0
  4. api/__pycache__/main.cpython-311.pyc +0 -0
  5. api/__pycache__/routes.cpython-311.pyc +0 -0
  6. api/main.py +66 -0
  7. api/routes.py +126 -0
  8. data/chroma_db/12c6a58a-a370-4695-a9d6-a858314de1c1/data_level0.bin +3 -0
  9. data/chroma_db/12c6a58a-a370-4695-a9d6-a858314de1c1/header.bin +3 -0
  10. data/chroma_db/12c6a58a-a370-4695-a9d6-a858314de1c1/length.bin +3 -0
  11. data/chroma_db/12c6a58a-a370-4695-a9d6-a858314de1c1/link_lists.bin +3 -0
  12. data/chroma_db/chroma.sqlite3 +3 -0
  13. requirements.txt +43 -3
  14. scripts/__pycache__/add_documents.cpython-311.pyc +0 -0
  15. scripts/add_documents.py +214 -0
  16. scripts/add_sample_documents.py +67 -0
  17. scripts/start_api.sh +12 -0
  18. scripts/start_ui.sh +12 -0
  19. src/__init__.py +4 -0
  20. src/__pycache__/__init__.cpython-311.pyc +0 -0
  21. src/agents/__init__.py +2 -0
  22. src/agents/__pycache__/__init__.cpython-311.pyc +0 -0
  23. src/agents/__pycache__/aggregator_agent.cpython-311.pyc +0 -0
  24. src/agents/__pycache__/base_agent.cpython-311.pyc +0 -0
  25. src/agents/__pycache__/cloud_agent.cpython-311.pyc +0 -0
  26. src/agents/__pycache__/local_data_agent.cpython-311.pyc +0 -0
  27. src/agents/__pycache__/search_agent.cpython-311.pyc +0 -0
  28. src/agents/__pycache__/snowflake_agent.cpython-311.pyc +0 -0
  29. src/agents/aggregator_agent.py +266 -0
  30. src/agents/base_agent.py +305 -0
  31. src/agents/cloud_agent.py +162 -0
  32. src/agents/local_data_agent.py +86 -0
  33. src/agents/search_agent.py +101 -0
  34. src/agents/snowflake_agent.py +245 -0
  35. src/core/__init__.py +2 -0
  36. src/core/__pycache__/__init__.cpython-311.pyc +0 -0
  37. src/core/__pycache__/config.cpython-311.pyc +0 -0
  38. src/core/__pycache__/orchestrator.cpython-311.pyc +0 -0
  39. src/core/config.py +220 -0
  40. src/core/orchestrator.py +332 -0
  41. src/mcp/__init__.py +2 -0
  42. src/mcp/__pycache__/__init__.cpython-311.pyc +0 -0
  43. src/mcp/__pycache__/mcp_server.cpython-311.pyc +0 -0
  44. src/mcp/__pycache__/snowflake_server.cpython-311.pyc +0 -0
  45. src/mcp/cloud_server.py +156 -0
  46. src/mcp/local_server.py +122 -0
  47. src/mcp/mcp_server.py +78 -0
  48. src/mcp/search_server.py +62 -0
  49. src/mcp/snowflake_server.py +185 -0
  50. src/memory/__init__.py +2 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/chroma_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
api/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """API layer."""
2
+
api/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (182 Bytes). View file
 
api/__pycache__/main.cpython-311.pyc ADDED
Binary file (2.79 kB). View file
 
api/__pycache__/routes.cpython-311.pyc ADDED
Binary file (6.59 kB). View file
 
api/main.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""FastAPI application entry point for the Agentic RAG system."""

import sys
from pathlib import Path

# Add parent directory to path to allow imports from src when run as a script.
parent_dir = Path(__file__).parent.parent
if str(parent_dir) not in sys.path:
    sys.path.insert(0, str(parent_dir))

import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from src.core.config import get_settings
from api.routes import router

# Configure logging once for the whole process.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

settings = get_settings()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    Replaces the deprecated ``@app.on_event("startup")`` /
    ``@app.on_event("shutdown")`` hooks (deprecated since FastAPI 0.93):
    code before ``yield`` runs at startup, code after it at shutdown.
    """
    logger.info("Starting Agentic RAG System API")
    logger.info(f"API running on {settings.api_host}:{settings.api_port}")
    yield
    logger.info("Shutting down Agentic RAG System API")


# Initialize FastAPI app
app = FastAPI(
    title="Agentic RAG System API",
    description="Production-ready Agentic RAG system with multiple agents and MCP servers",
    version="1.0.0",
    lifespan=lifespan,
)

# Configure CORS.
# NOTE(review): wildcard origins together with allow_credentials=True only
# works because Starlette echoes the request Origin header in that case;
# in production, list explicit allowed origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify allowed origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routes
app.include_router(router)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "main:app",
        host=settings.api_host,
        port=settings.api_port,
        reload=settings.api_debug,
    )
api/routes.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""API routes for the Agentic RAG system."""

import logging
from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from src.core.orchestrator import get_orchestrator
from src.memory.long_term_memory import LongTermMemory

logger = logging.getLogger(__name__)

router = APIRouter()


# ---------------------------------------------------------------------------
# Request/Response models
# ---------------------------------------------------------------------------
class QueryRequest(BaseModel):
    """Query request model."""
    query: str
    tier: str = "basic"  # "basic", "agent", or "advanced"
    session_id: Optional[str] = None


class QueryResponse(BaseModel):
    """Query response model."""
    success: bool
    answer: Optional[str] = None
    tier: str
    error: Optional[str] = None
    sources: Optional[list] = None
    model: Optional[str] = None
    agent: Optional[str] = None


class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    version: str


@router.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "version": "1.0.0",
    }


@router.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    """
    Main query endpoint supporting all tiers.

    - **basic**: Simple RAG (retrieval + generation)
    - **agent**: Agent with tools (calculator, web search, database)
    - **advanced**: Multi-agent system with MCP servers
    """
    try:
        orchestrator = get_orchestrator()
        response = await orchestrator.process_query(
            query=request.query,
            tier=request.tier,
            session_id=request.session_id,
        )
        return QueryResponse(**response)
    except Exception as e:
        # logger.exception records the full traceback; `from e` preserves
        # the original cause on the re-raised HTTPException.
        logger.exception("Error processing query")
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/agents")
async def get_agents():
    """Get status of all agents."""
    try:
        orchestrator = get_orchestrator()
        return orchestrator.get_agent_status()
    except Exception as e:
        logger.exception("Error getting agent status")
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/system")
async def get_system_info():
    """Get system information."""
    try:
        orchestrator = get_orchestrator()
        return orchestrator.get_system_info()
    except Exception as e:
        logger.exception("Error getting system info")
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/memory/{session_id}")
async def get_memory(session_id: str):
    """Get memory for a session."""
    try:
        # A fresh LongTermMemory is created per request; assumed cheap —
        # TODO(review): confirm, or hoist to an app-level dependency.
        long_term_memory = LongTermMemory()
        memories = long_term_memory.get_session_memories(session_id, limit=50)
        return {
            "session_id": session_id,
            "memories": memories,
            "count": len(memories),
        }
    except Exception as e:
        logger.exception("Error getting memory")
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.delete("/memory/{session_id}")
async def delete_memory(session_id: str):
    """Delete memory for a session."""
    try:
        long_term_memory = LongTermMemory()
        deleted_count = long_term_memory.delete_session_memories(session_id)
        return {
            "session_id": session_id,
            "deleted": deleted_count,
        }
    except Exception as e:
        logger.exception("Error deleting memory")
        raise HTTPException(status_code=500, detail=str(e)) from e
data/chroma_db/12c6a58a-a370-4695-a9d6-a858314de1c1/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:065a5aa61390e7ff9c4d37dbb028fd9a866fd618df83adeb7b41c957a09d4dc0
3
+ size 628400
data/chroma_db/12c6a58a-a370-4695-a9d6-a858314de1c1/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b081be2c2276a57e995075c7de2f3cb25e903798aac36d98042045533ab28f7d
3
+ size 100
data/chroma_db/12c6a58a-a370-4695-a9d6-a858314de1c1/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
3
+ size 400
data/chroma_db/12c6a58a-a370-4695-a9d6-a858314de1c1/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
3
+ size 0
data/chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1235fee08e11e0ecfb47ccd075b737c7eec7d2c316a571f5512adc721b2110d
3
+ size 1687552
requirements.txt CHANGED
@@ -1,3 +1,43 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ openai>=1.12.0
3
+ chromadb>=0.4.22
4
+ pydantic>=2.5.0
5
+ pydantic-settings>=2.1.0
6
+ python-dotenv>=1.0.0
7
+
8
+ # MCP SDK
9
+ mcp>=0.9.0
10
+
11
+ # API framework
12
+ fastapi>=0.109.0
13
+ uvicorn[standard]>=0.27.0
14
+ httpx>=0.26.0
15
+
16
+ # UI framework
17
+ streamlit>=1.31.0
18
+
19
+ # Utilities
20
+ tiktoken>=0.5.2
21
+ numpy>=1.26.0
22
+ aiofiles>=23.2.1
23
+ nest-asyncio>=1.6.0 # For async handling in Streamlit
24
+
25
+ # Testing
26
+ pytest>=7.4.4
27
+ pytest-asyncio>=0.23.3
28
+ pytest-mock>=3.12.0
29
+
30
+ # Optional: Web search providers
31
+ tavily-python>=0.3.0
32
+
33
+ # Optional: Database support
34
+ sqlalchemy>=2.0.25
35
+
36
+ # Optional: Cloud storage
37
+ boto3>=1.34.0 # AWS S3
38
+ google-cloud-storage>=2.14.0 # GCS
39
+
40
+ # Optional: Snowflake
41
+ snowflake-connector-python>=3.7.0
42
+ pandas>=2.0.0 # For Snowflake data operations
43
+
scripts/__pycache__/add_documents.cpython-311.pyc ADDED
Binary file (10.6 kB). View file
 
scripts/add_documents.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Script to add documents to the vector store from files or text."""

import sys
import os
from pathlib import Path
from typing import List, Dict, Optional

# Add parent directory to path so `src` is importable when run as a script.
try:
    parent_dir = Path(__file__).parent.parent
    sys.path.insert(0, str(parent_dir))
except (NameError, AttributeError):
    # __file__ is not available (e.g. interactive execution); fall back to CWD.
    # (os is already imported at module top; no need to re-import here.)
    parent_dir = Path(os.getcwd())
    if str(parent_dir) not in sys.path:
        sys.path.insert(0, str(parent_dir))

# Lazy import - only import when functions are actually called.
# This prevents import errors when the module is scanned but not used.
_vector_store = None
_vector_store_error = None


def _get_vector_store():
    """Lazily import and cache the vector store singleton.

    Raises:
        ImportError: if the vector store dependencies (e.g. chromadb) are
            missing. The error is cached so repeated calls fail fast.
    """
    global _vector_store, _vector_store_error
    if _vector_store_error is not None:
        raise _vector_store_error
    if _vector_store is None:
        try:
            from src.retrieval.vector_store import get_vector_store
            _vector_store = get_vector_store()
        except ImportError as e:
            _vector_store_error = ImportError(
                f"Failed to import vector store. Make sure all dependencies (including chromadb) are installed. "
                f"Run: pip install -r requirements.txt\n"
                f"Original error: {e}"
            )
            # Chain to the original ImportError for full diagnostics.
            raise _vector_store_error from e
    return _vector_store


def add_text_documents(texts: List[str], metadatas: Optional[List[Dict]] = None):
    """
    Add text documents to the vector store.

    Args:
        texts: List of document texts
        metadatas: Optional list of metadata dictionaries

    Returns:
        List of document IDs assigned by the vector store.
    """
    vector_store = _get_vector_store()

    if metadatas is None:
        metadatas = [{}] * len(texts)

    ids = vector_store.add_documents(texts, metadatas)
    print(f"✅ Added {len(ids)} documents to vector store")
    return ids


def add_file_documents(file_paths: List[str], chunk_size: int = 1000):
    """
    Add documents from text files to the vector store.

    Args:
        file_paths: List of file paths to read
        chunk_size: Size of text chunks (characters) for splitting large documents

    Returns:
        List of document IDs, or [] when no readable files were found.
    """
    all_documents = []
    all_metadatas = []

    for file_path in file_paths:
        file_path = Path(file_path)
        if not file_path.exists():
            print(f"⚠️ Warning: File not found: {file_path}")
            continue

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # Split large documents into fixed-size chunks (no overlap).
            if len(content) > chunk_size:
                chunks = [content[i:i+chunk_size] for i in range(0, len(content), chunk_size)]
                for i, chunk in enumerate(chunks):
                    all_documents.append(chunk)
                    all_metadatas.append({
                        "source": str(file_path.name),
                        "chunk": i + 1,
                        "type": "file"
                    })
            else:
                all_documents.append(content)
                all_metadatas.append({
                    "source": str(file_path.name),
                    "type": "file"
                })

            print(f"✅ Loaded: {file_path.name}")
        except Exception as e:
            print(f"❌ Error reading {file_path}: {e}")

    if all_documents:
        return add_text_documents(all_documents, all_metadatas)
    print("⚠️ No documents to add")
    return []


def add_from_directory(directory: str, extensions: Optional[List[str]] = None):
    """
    Add all text files from a directory (recursively).

    Args:
        directory: Directory path
        extensions: List of file extensions to include
            (default: ['.txt', '.md', '.py', '.json'])

    Returns:
        List of document IDs, or [] when the directory is missing or empty.
    """
    if extensions is None:
        extensions = ['.txt', '.md', '.py', '.json']

    directory = Path(directory)
    if not directory.exists():
        print(f"❌ Directory not found: {directory}")
        return []

    file_paths = []
    for ext in extensions:
        file_paths.extend(directory.glob(f"**/*{ext}"))

    if not file_paths:
        print(f"⚠️ No files found with extensions {extensions} in {directory}")
        return []

    print(f"📁 Found {len(file_paths)} files in {directory}")
    return add_file_documents([str(f) for f in file_paths])


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Add documents to the vector store")
    parser.add_argument("--text", nargs="+", help="Add text documents directly")
    parser.add_argument("--file", nargs="+", help="Add documents from files")
    parser.add_argument("--directory", help="Add all documents from a directory")
    parser.add_argument("--sample-docs", action="store_true", help="Add sample documents")

    args = parser.parse_args()

    if args.sample_docs:
        # Add sample documents
        sample_docs = [
            {
                "text": """
                Oracle Exadata is a database machine that combines hardware and software
                to provide high-performance database solutions. When migrating Exadata
                workloads to the cloud, it's important to consider compatibility,
                performance, and feature parity.
                """,
                "metadata": {"source": "exadata_migration_guide", "type": "documentation"},
            },
            {
                "text": """
                Cloud migration strategies for Oracle Exadata include:
                1. Lift and shift - moving workloads with minimal changes
                2. Replatforming - adapting to cloud-native services
                3. Refactoring - redesigning for cloud architecture

                Each approach has different trade-offs in terms of effort, cost, and feature availability.
                """,
                "metadata": {"source": "migration_strategies", "type": "guide"},
            },
            {
                "text": """
                Oracle Cloud Infrastructure (OCI) provides Exadata Cloud Service which
                maintains full feature compatibility with on-premises Exadata. This
                service offers the same architecture and capabilities, making it ideal
                for migrations requiring minimal changes.
                """,
                "metadata": {"source": "oci_exadata", "type": "cloud_service"},
            },
            {
                "text": """
                Oracle AI Database services on AWS provide customers with a simplified path
                to migrate Oracle Exadata workloads. These services run on AWS infrastructure
                and offer managed database solutions that maintain Oracle compatibility while
                leveraging AWS cloud capabilities. The services include automated migration tools,
                performance optimization, and seamless integration with AWS services.
                """,
                "metadata": {"source": "oracle_aws_services", "type": "cloud_service"},
            },
        ]

        documents = [doc["text"] for doc in sample_docs]
        metadatas = [doc["metadata"] for doc in sample_docs]
        add_text_documents(documents, metadatas)

    elif args.text:
        add_text_documents(args.text)

    elif args.file:
        add_file_documents(args.file)

    elif args.directory:
        add_from_directory(args.directory)

    else:
        print("Please specify --text, --file, --directory, or --sample-docs")
        print("\nExamples:")
        print("  python scripts/add_documents.py --sample-docs")
        print("  python scripts/add_documents.py --file doc1.txt doc2.txt")
        print("  python scripts/add_documents.py --directory data/sample_documents")
        print("  python scripts/add_documents.py --text 'Your document text here'")
scripts/add_sample_documents.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Script to add sample documents to the vector store."""
2
+
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ # Add parent directory to path
7
+ parent_dir = Path(__file__).parent.parent
8
+ sys.path.insert(0, str(parent_dir))
9
+
10
+ # Lazy import to avoid issues when module is scanned but not used
11
+ def _get_vector_store():
12
+ """Lazy import of vector store."""
13
+ try:
14
+ from src.retrieval.vector_store import get_vector_store
15
+ return get_vector_store()
16
+ except ImportError as e:
17
+ raise ImportError(
18
+ f"Failed to import vector store. Make sure all dependencies are installed. "
19
+ f"Original error: {e}"
20
+ )
21
+
22
+ def add_sample_documents():
23
+ """Add sample documents to the vector store."""
24
+ vector_store = _get_vector_store()
25
+
26
+ sample_docs = [
27
+ {
28
+ "text": """
29
+ Oracle Exadata is a database machine that combines hardware and software
30
+ to provide high-performance database solutions. When migrating Exadata
31
+ workloads to the cloud, it's important to consider compatibility,
32
+ performance, and feature parity.
33
+ """,
34
+ "metadata": {"source": "exadata_migration_guide", "type": "documentation"},
35
+ },
36
+ {
37
+ "text": """
38
+ Cloud migration strategies for Oracle Exadata include:
39
+ 1. Lift and shift - moving workloads with minimal changes
40
+ 2. Replatforming - adapting to cloud-native services
41
+ 3. Refactoring - redesigning for cloud architecture
42
+
43
+ Each approach has different trade-offs in terms of effort, cost, and feature availability.
44
+ """,
45
+ "metadata": {"source": "migration_strategies", "type": "guide"},
46
+ },
47
+ {
48
+ "text": """
49
+ Oracle Cloud Infrastructure (OCI) provides Exadata Cloud Service which
50
+ maintains full feature compatibility with on-premises Exadata. This
51
+ service offers the same architecture and capabilities, making it ideal
52
+ for migrations requiring minimal changes.
53
+ """,
54
+ "metadata": {"source": "oci_exadata", "type": "cloud_service"},
55
+ },
56
+ ]
57
+
58
+ documents = [doc["text"] for doc in sample_docs]
59
+ metadatas = [doc["metadata"] for doc in sample_docs]
60
+
61
+ ids = vector_store.add_documents(documents, metadatas)
62
+ print(f"Added {len(ids)} sample documents to vector store")
63
+ print(f"Document IDs: {ids}")
64
+
65
+ if __name__ == "__main__":
66
+ add_sample_documents()
67
+
scripts/start_api.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/bin/bash
# Script to start the API server.
# Fail fast: abort on any error (e.g. if the cd fails we must not start
# uvicorn from the wrong directory), on unset variables, and on pipe failures.
set -euo pipefail

# Run from the repository root regardless of the caller's CWD.
cd "$(dirname "$0")/.."

echo "Starting Agentic RAG API server..."
echo "API will be available at http://localhost:8000"
echo "Press Ctrl+C to stop the server"
echo ""

uvicorn api.main:app --reload --host 0.0.0.0 --port 8000
scripts/start_ui.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/bin/bash
# Script to start the Streamlit UI.
# Fail fast: abort on any error (e.g. if the cd fails we must not launch
# streamlit from the wrong directory), on unset variables, and on pipe failures.
set -euo pipefail

# Run from the repository root regardless of the caller's CWD.
cd "$(dirname "$0")/.."

echo "Starting Agentic RAG Streamlit UI..."
echo "UI will be available at http://localhost:8501"
echo "Press Ctrl+C to stop the server"
echo ""

streamlit run ui/streamlit_app.py
src/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """Agentic RAG System - Main package."""
2
+
3
+ __version__ = "1.0.0"
4
+
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (233 Bytes). View file
 
src/agents/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Agent implementations."""
2
+
src/agents/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (201 Bytes). View file
 
src/agents/__pycache__/aggregator_agent.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
src/agents/__pycache__/base_agent.cpython-311.pyc ADDED
Binary file (12.7 kB). View file
 
src/agents/__pycache__/cloud_agent.cpython-311.pyc ADDED
Binary file (8.68 kB). View file
 
src/agents/__pycache__/local_data_agent.cpython-311.pyc ADDED
Binary file (4.22 kB). View file
 
src/agents/__pycache__/search_agent.cpython-311.pyc ADDED
Binary file (5.08 kB). View file
 
src/agents/__pycache__/snowflake_agent.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
src/agents/aggregator_agent.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Aggregator agent that coordinates multiple specialized agents."""

import logging
from typing import List, Dict, Any, Optional
from openai import OpenAI
from src.agents.base_agent import BaseAgent
from src.agents.local_data_agent import LocalDataAgent
from src.agents.search_agent import SearchAgent
from src.agents.cloud_agent import CloudAgent
from src.core.config import get_settings

logger = logging.getLogger(__name__)


class AggregatorAgent(BaseAgent):
    """Agent that coordinates multiple specialized agents and aggregates responses."""

    def __init__(self, use_planning: bool = True):
        """Initialize aggregator agent.

        Args:
            use_planning: Whether the aggregator itself uses CoT planning.
        """
        super().__init__(
            name="aggregator_agent",
            description=(
                "You are an aggregator agent that coordinates multiple specialized agents "
                "to answer complex questions. You route queries to appropriate agents and "
                "synthesize their responses into a comprehensive answer."
            ),
            use_memory=True,
            use_planning=use_planning,
            planning_type="cot",
        )

        # Initialize specialized agents
        self.local_agent = LocalDataAgent(use_planning=False)
        self.search_agent = SearchAgent(use_planning=True)
        self.cloud_agent = CloudAgent(use_planning=False)

        # Initialize Snowflake agent only when credentials are configured.
        # (get_settings is already imported at module level; the original
        # redundantly re-imported it here.)
        self.snowflake_agent = None
        settings = get_settings()
        if settings.has_snowflake():
            # Imported lazily so the optional snowflake dependency is only
            # required when Snowflake is actually configured.
            from src.agents.snowflake_agent import SnowflakeAgent
            snowflake_config = settings.get_snowflake_config()
            self.snowflake_agent = SnowflakeAgent(
                snowflake_config=snowflake_config,
                use_planning=False
            )

        # Registry of all available agents, keyed by routing name.
        self.agents = {
            "local": self.local_agent,
            "search": self.search_agent,
            "cloud": self.cloud_agent,
        }
        if self.snowflake_agent:
            self.agents["snowflake"] = self.snowflake_agent

    async def retrieve_context(self, query: str) -> str:
        """
        Retrieve context by querying relevant agents.

        Args:
            query: User query

        Returns:
            Aggregated context string
        """
        # Determine which agents to query based on query content
        agents_to_query = self._select_agents(query)

        # Query selected agents one at a time; a failure in one agent is
        # recorded and does not prevent the others from contributing.
        results = {}
        for agent_name, agent in agents_to_query.items():
            try:
                results[agent_name] = await agent.retrieve_context(query)
            except Exception as e:
                # logger.exception keeps the traceback for debugging.
                logger.exception(f"Error querying {agent_name} agent")
                results[agent_name] = f"Error: {str(e)}"

        # Combine results into one labeled context block.
        context_parts = ["Context from specialized agents:"]
        for agent_name, context in results.items():
            context_parts.append(f"\n--- {agent_name.upper()} AGENT ---")
            context_parts.append(context)

        return "\n".join(context_parts)

    def _select_agents(self, query: str) -> Dict[str, BaseAgent]:
        """
        Select which agents to query based on the query content.

        Routing is keyword-based; an agent is included when any of its
        trigger keywords appears in the lowercased query.

        Args:
            query: User query

        Returns:
            Dictionary of agent names to agents
        """
        query_lower = query.lower()
        selected = {}

        # Local agent for document/file queries.
        if any(keyword in query_lower for keyword in ["document", "file", "local", "data"]):
            selected["local"] = self.local_agent

        # Search agent for current information or web queries.
        if any(keyword in query_lower for keyword in [
            "current", "latest", "recent", "news", "web", "internet", "online", "search"
        ]):
            selected["search"] = self.search_agent

        # Cloud agent for cloud-storage-related queries.
        if any(keyword in query_lower for keyword in ["cloud", "s3", "gcs", "storage", "remote"]):
            selected["cloud"] = self.cloud_agent

        # Snowflake agent for database/data warehouse queries (if configured).
        if self.snowflake_agent and any(keyword in query_lower for keyword in [
            "snowflake", "data warehouse", "sql", "database", "query", "table", "schema"
        ]):
            selected["snowflake"] = self.snowflake_agent

        # If no specific match, use local and search by default.
        if not selected:
            selected["local"] = self.local_agent
            selected["search"] = self.search_agent

        return selected

    async def process(
        self,
        query: str,
        session_id: Optional[str] = None,
        context: Optional[str] = None,
    ) -> dict:
        """
        Process query by coordinating multiple agents.

        Args:
            query: User query
            session_id: Optional session ID
            context: Optional additional context

        Returns:
            Aggregated response dictionary
        """
        # Select agents to query
        agents_to_query = self._select_agents(query)

        # Get responses from selected agents; per-agent errors are captured
        # so one failing agent does not abort the whole request.
        agent_responses = {}
        for agent_name, agent in agents_to_query.items():
            try:
                agent_responses[agent_name] = await agent.process(query, session_id, context)
            except Exception as e:
                logger.exception(f"Error processing with {agent_name} agent")
                agent_responses[agent_name] = {
                    "success": False,
                    "error": str(e),
                }

        # Synthesize responses
        return await self._synthesize_responses(
            query=query,
            agent_responses=agent_responses,
            session_id=session_id,
        )

    async def _synthesize_responses(
        self,
        query: str,
        agent_responses: Dict[str, dict],
        session_id: Optional[str],
    ) -> dict:
        """
        Synthesize responses from multiple agents.

        Args:
            query: Original query
            agent_responses: Dictionary of agent responses
            session_id: Optional session ID

        Returns:
            Synthesized response
        """
        # Collect successful responses
        successful_responses = {
            name: resp for name, resp in agent_responses.items()
            if resp.get("success", False)
        }

        if not successful_responses:
            # No agent succeeded: surface every agent's error message.
            error_messages = []
            for name, resp in agent_responses.items():
                error_msg = resp.get("error", "Unknown error")
                error_messages.append(f"{name}: {error_msg}")

            return {
                "success": False,
                "error": f"No agents provided successful responses. Errors: {'; '.join(error_messages)}",
                "agent_responses": agent_responses,
            }

        # If only one agent responded, return its response unchanged.
        if len(successful_responses) == 1:
            response = next(iter(successful_responses.values()))
            response["aggregated_by"] = "single_agent"
            return response

        # Multiple responses - synthesize using the LLM.
        try:
            # Build synthesis prompt
            synthesis_parts = [
                "You are synthesizing responses from multiple specialized agents.",
                f"Original question: {query}",
                "",
                "Agent responses:",
            ]

            for agent_name, response in successful_responses.items():
                answer = response.get("answer", "No answer provided")
                synthesis_parts.append(f"\n{agent_name.upper()} Agent:")
                synthesis_parts.append(answer)

            synthesis_parts.extend([
                "",
                "Synthesize these responses into a comprehensive, coherent answer.",
                "If there are conflicts, note them. If information is complementary, combine it.",
            ])

            synthesis_prompt = "\n".join(synthesis_parts)

            # Call LLM for synthesis (client/model come from BaseAgent).
            messages = [
                {"role": "system", "content": self.description},
                {"role": "user", "content": synthesis_prompt},
            ]

            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.7,
            )

            synthesized_answer = response.choices[0].message.content

            return {
                "success": True,
                "answer": synthesized_answer,
                "agent": self.name,
                "aggregated_by": "multiple_agents",
                "source_agents": list(successful_responses.keys()),
                "agent_responses": successful_responses,
                "model": self.model,
            }

        except Exception as e:
            logger.exception("Error synthesizing responses")
            # Fallback: return the first successful response unmodified.
            first_response = next(iter(successful_responses.values()))
            first_response["aggregated_by"] = "fallback"
            return first_response
src/agents/base_agent.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base agent class with common functionality."""
2
+
3
+ import logging
4
+ from abc import ABC, abstractmethod
5
+ from typing import List, Dict, Any, Optional, Callable
6
+ from openai import OpenAI
7
+ from src.core.config import get_settings
8
+ from src.memory.short_term_memory import ShortTermMemory
9
+ from src.memory.long_term_memory import LongTermMemory
10
+ from src.planning.react_planner import ReActPlanner
11
+ from src.planning.cot_planner import CoTPlanner
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class BaseAgent(ABC):
    """Base class for all agents.

    Provides the plumbing shared by every concrete agent:

    * an OpenAI-compatible chat client configured from application settings
      (may target OpenRouter or any compatible endpoint),
    * optional short-term and long-term memory,
    * optional planning (ReAct or chain-of-thought),
    * a tool registry supporting both sync and async tool functions.

    Subclasses must implement :meth:`retrieve_context`.
    """

    def __init__(
        self,
        name: str,
        description: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        use_memory: bool = True,
        use_planning: bool = False,
        planning_type: str = "react",  # "react" or "cot"
    ):
        """
        Initialize base agent.

        Args:
            name: Agent name
            description: Agent description; also used verbatim as the
                system prompt for every LLM call this agent makes.
            tools: List of available tool schemas
            use_memory: Whether to use memory
            use_planning: Whether to use planning
            planning_type: Type of planning ("react" or "cot")
        """
        self.name = name
        self.description = description
        self.settings = get_settings()

        # Initialize OpenAI client (endpoint/model are taken from settings)
        self.client = OpenAI(**self.settings.get_openai_client_kwargs())
        self.model = self.settings.openai_model

        # Initialize memory: either both stores exist, or neither does
        self.use_memory = use_memory
        self.short_term_memory: Optional[ShortTermMemory] = None
        self.long_term_memory: Optional[LongTermMemory] = None
        if use_memory:
            self.short_term_memory = ShortTermMemory()
            self.long_term_memory = LongTermMemory()

        # Initialize planning. An unrecognised planning_type leaves
        # self.planner as None, which silently degrades process() to the
        # direct (non-planning) path.
        self.use_planning = use_planning
        self.planning_type = planning_type
        self.planner: Optional[ReActPlanner | CoTPlanner] = None
        if use_planning:
            if planning_type == "react":
                self.planner = ReActPlanner(tools=tools or [])
            elif planning_type == "cot":
                self.planner = CoTPlanner()
            else:
                logger.warning(f"Unknown planning type: {planning_type}")

        # Tools: the schema list plus a name -> callable registry that
        # add_tool() fills in
        self.tools = tools or []
        self.tool_functions: Dict[str, Callable] = {}

    def add_tool(self, tool: Dict[str, Any], tool_function: Callable) -> None:
        """
        Add a tool to the agent.

        The schema is appended to ``self.tools`` and, when a ReAct planner
        is active, also registered with the planner. NOTE(review): no
        de-duplication is performed -- registering a schema that was already
        passed to ``__init__`` keeps both copies.

        Args:
            tool: Tool schema (must contain a "name" key)
            tool_function: Function to execute the tool (sync or async)
        """
        self.tools.append(tool)
        self.tool_functions[tool["name"]] = tool_function
        if self.planner and isinstance(self.planner, ReActPlanner):
            self.planner.add_tool(tool)

    async def process(
        self,
        query: str,
        session_id: Optional[str] = None,
        context: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Process a query using the agent.

        Flow: record the query in short-term memory, pull related memories
        from long-term storage, answer either directly or via the configured
        planner, then persist the conversation. Any exception is caught and
        converted into a ``{"success": False, ...}`` dictionary rather than
        being raised to the caller.

        Args:
            query: User query
            session_id: Optional session ID for memory
            context: Optional additional context

        Returns:
            Response dictionary; always contains "success" and "agent", and
            on success also "answer" and "model".
        """
        try:
            # Add user message to memory. _process_direct replays the last
            # few memory messages, which is how the query reaches the LLM
            # when memory is enabled.
            if self.short_term_memory:
                self.short_term_memory.add_message("user", query)

            # Load long-term memory if available
            long_term_context = ""
            if self.long_term_memory and session_id:
                memories = self.long_term_memory.search_memories(query, session_id, n_results=3)
                if memories:
                    long_term_context = "\n".join([
                        m["content"] for m in memories
                    ])

            # Combine contexts (long-term memory + caller-supplied context)
            full_context = self._build_context(context, long_term_context)

            # Use planning if enabled
            if self.use_planning and self.planner:
                response = await self._process_with_planning(query, full_context, session_id)
            else:
                response = await self._process_direct(query, full_context, session_id)

            # Add assistant response to memory
            if self.short_term_memory and "answer" in response:
                self.short_term_memory.add_message("assistant", response["answer"])

            # Store in long-term memory. This persists the entire running
            # short-term transcript for the session, not just this turn.
            if self.long_term_memory and session_id:
                messages = self.short_term_memory.get_messages() if self.short_term_memory else []
                self.long_term_memory.store_conversation(session_id, messages)

            return response

        except Exception as e:
            logger.error(f"Error processing query in {self.name}: {e}")
            return {
                "success": False,
                "error": str(e),
                "agent": self.name,
            }

    async def _process_direct(
        self,
        query: str,
        context: str,
        session_id: Optional[str],
    ) -> Dict[str, Any]:
        """Process query directly with a single chat completion (no planning).

        Raises:
            Exception: LLM errors are re-raised; quota (429) and auth (401)
                failures are translated into actionable messages first.
        """
        # Build messages: the agent description is the system prompt, with
        # any retrieved context appended to it.
        messages = []
        if context:
            messages.append({
                "role": "system",
                "content": f"{self.description}\n\nContext: {context}",
            })
        else:
            messages.append({
                "role": "system",
                "content": self.description,
            })

        # Add conversation history. When memory is enabled the current query
        # is already the most recent history entry (added by process());
        # without memory it must be appended here explicitly.
        if self.short_term_memory:
            history = self.short_term_memory.get_messages(format_for_llm=True)
            messages.extend(history[-5:])  # Last 5 messages
        else:
            messages.append({
                "role": "user",
                "content": query,
            })

        # Call LLM
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.7,
            )

            answer = response.choices[0].message.content

            return {
                "success": True,
                "answer": answer,
                "agent": self.name,
                "model": self.model,
            }
        except Exception as e:
            # Map the most common API failures to clearer messages; anything
            # else propagates unchanged for process() to wrap.
            error_msg = str(e)
            if "quota" in error_msg.lower() or "429" in error_msg:
                logger.error(f"OpenAI API quota exceeded: {e}")
                raise Exception("OpenAI API quota exceeded. Please check your billing and plan details.")
            elif "api key" in error_msg.lower() or "401" in error_msg:
                logger.error(f"Invalid OpenAI API key: {e}")
                raise Exception("Invalid OpenAI API key. Please check your .env file.")
            else:
                logger.error(f"Error calling LLM: {e}")
                raise

    async def _process_with_planning(
        self,
        query: str,
        context: str,
        session_id: Optional[str],
    ) -> Dict[str, Any]:
        """Process query using the configured planner.

        Falls back to :meth:`_process_direct` when no planner was created
        (e.g. an unknown planning_type was given).
        """
        if not self.planner:
            return await self._process_direct(query, context, session_id)

        # Create sync LLM call function (planner expects sync)
        def llm_call(prompt: str) -> str:
            messages = [
                {"role": "system", "content": self.description},
                {"role": "user", "content": prompt},
            ]
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.7,
            )
            return response.choices[0].message.content

        # Generate plan (planner methods are sync). Both branches make the
        # same call today; kept separate in case the planner APIs diverge.
        if isinstance(self.planner, ReActPlanner):
            plan = self.planner.plan(
                query=query,
                context=context,
                llm_call=llm_call,
            )
        else:  # CoT planner
            plan = self.planner.plan(
                query=query,
                context=context,
                llm_call=llm_call,
            )

        # Extract final answer: ReAct plans expose "final_answer", CoT plans
        # expose "conclusion"
        if isinstance(self.planner, ReActPlanner):
            answer = plan.get("final_answer", "I couldn't find a complete answer.")
        else:
            answer = plan.get("conclusion", "I couldn't find a complete answer.")

        return {
            "success": True,
            "answer": answer,
            "agent": self.name,
            "plan": plan,
            "model": self.model,
        }

    def _build_context(
        self,
        additional_context: Optional[str],
        long_term_context: str,
    ) -> str:
        """Build full context string (empty string when both parts are empty)."""
        parts = []
        if long_term_context:
            parts.append(f"Relevant past conversations:\n{long_term_context}")
        if additional_context:
            parts.append(f"Additional context:\n{additional_context}")
        return "\n\n".join(parts)

    async def _execute_tool(
        self,
        tool_name: str,
        **kwargs,
    ) -> Any:
        """Execute a registered tool by name (supports sync and async tools).

        Raises:
            ValueError: if no tool with ``tool_name`` has been registered.
        """
        if tool_name not in self.tool_functions:
            raise ValueError(f"Tool '{tool_name}' not found")

        tool_func = self.tool_functions[tool_name]
        # Check if tool is async; await coroutines, call plain functions
        import asyncio
        if asyncio.iscoroutinefunction(tool_func):
            return await tool_func(**kwargs)
        else:
            return tool_func(**kwargs)

    @abstractmethod
    async def retrieve_context(self, query: str) -> str:
        """
        Retrieve relevant context for the query.

        Subclasses implement this to pull domain-specific context (local
        documents, web results, warehouse data, ...).

        Args:
            query: User query

        Returns:
            Context string
        """
        pass

    def get_status(self) -> Dict[str, Any]:
        """Get agent status as a plain, JSON-serializable dictionary."""
        return {
            "name": self.name,
            "description": self.description,
            "tools": [t["name"] for t in self.tools],
            "memory_enabled": self.use_memory,
            "planning_enabled": self.use_planning,
            "planning_type": self.planning_type if self.use_planning else None,
        }
src/agents/cloud_agent.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cloud storage agent for remote data access."""
2
+
3
+ import logging
4
+ import os
5
+ from typing import Optional
6
+ from src.agents.base_agent import BaseAgent
7
+ from src.core.config import get_settings
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class CloudAgent(BaseAgent):
    """Agent specialized in accessing cloud storage and remote data.

    Supports AWS S3 (via boto3) and Google Cloud Storage, selected from the
    application settings. The storage client lives on ``self.cloud_client``;
    ``self.client`` remains the OpenAI chat client owned by BaseAgent.
    """

    def __init__(self, use_planning: bool = False):
        """Initialize cloud agent.

        Args:
            use_planning: Whether to enable planning in the base agent.
        """
        super().__init__(
            name="cloud_agent",
            description=(
                "You are a specialized agent for accessing cloud storage and remote data. "
                "You can retrieve documents and information from cloud storage services "
                "like AWS S3 or Google Cloud Storage."
            ),
            use_memory=True,
            use_planning=use_planning,
        )
        # BaseAgent.__init__ already populated self.settings from get_settings().
        self._init_cloud_client()

    def _init_cloud_client(self):
        """Initialize the cloud storage client based on configuration.

        BUGFIX: the storage client is stored on ``self.cloud_client`` instead
        of ``self.client``. Previously it overwrote the OpenAI chat client
        that BaseAgent.__init__ assigned to ``self.client``, which broke every
        subsequent ``self.client.chat.completions.create(...)`` call made by
        the base class.
        """
        self.cloud_type = None
        self.cloud_client = None

        # Prefer AWS S3 when credentials and a bucket are configured
        if self.settings.aws_access_key_id and self.settings.aws_s3_bucket:
            try:
                import boto3
                self.cloud_client = boto3.client(
                    "s3",
                    aws_access_key_id=self.settings.aws_access_key_id,
                    aws_secret_access_key=self.settings.aws_secret_access_key,
                    region_name=self.settings.aws_region,
                )
                self.cloud_type = "s3"
                self.bucket_name = self.settings.aws_s3_bucket
                logger.info("Initialized AWS S3 client")
            except ImportError:
                logger.warning("boto3 not installed, AWS S3 unavailable")
            except Exception as e:
                logger.error(f"Error initializing S3 client: {e}")

        # Otherwise fall back to Google Cloud Storage
        elif self.settings.google_application_credentials and self.settings.gcs_bucket_name:
            try:
                from google.cloud import storage
                # The GCS SDK reads credentials from this environment variable
                os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.settings.google_application_credentials
                self.cloud_client = storage.Client()
                self.cloud_type = "gcs"
                self.bucket_name = self.settings.gcs_bucket_name
                logger.info("Initialized Google Cloud Storage client")
            except ImportError:
                logger.warning("google-cloud-storage not installed, GCS unavailable")
            except Exception as e:
                logger.error(f"Error initializing GCS client: {e}")

        if not self.cloud_client:
            logger.warning("No cloud storage configured")

    async def retrieve_context(self, query: str) -> str:
        """
        Retrieve relevant context from cloud storage.

        Args:
            query: User query

        Returns:
            Context string from cloud documents; a human-readable error
            message when storage is unconfigured or the provider call fails.
        """
        if not self.cloud_client:
            return "Cloud storage is not configured."

        try:
            if self.cloud_type == "s3":
                return await self._retrieve_from_s3(query)
            elif self.cloud_type == "gcs":
                return await self._retrieve_from_gcs(query)
            else:
                return "Unknown cloud storage type."
        except Exception as e:
            logger.error(f"Error retrieving cloud context: {e}")
            return f"Error retrieving from cloud storage: {str(e)}"

    async def _retrieve_from_s3(self, query: str) -> str:
        """List a sample of objects from the configured S3 bucket.

        NOTE(review): the query is currently unused -- this just lists
        objects (in production, a vector search over object contents would
        filter by relevance).
        """
        try:
            response = self.cloud_client.list_objects_v2(
                Bucket=self.bucket_name,
                MaxKeys=10,
            )

            # boto3 omits "Contents" entirely when the bucket is empty
            if "Contents" not in response:
                return "No documents found in S3 bucket."

            context_parts = [f"Documents in S3 bucket '{self.bucket_name}':"]
            for obj in response["Contents"][:5]:  # Limit to 5
                key = obj["Key"]
                size = obj["Size"]
                context_parts.append(f"- {key} ({size} bytes)")

            return "\n".join(context_parts)
        except Exception as e:
            logger.error(f"Error listing S3 objects: {e}")
            return f"Error accessing S3: {str(e)}"

    async def _retrieve_from_gcs(self, query: str) -> str:
        """List a sample of blobs from the configured GCS bucket.

        NOTE(review): like the S3 path, the query is not used for filtering.
        """
        try:
            bucket = self.cloud_client.bucket(self.bucket_name)
            blobs = list(bucket.list_blobs(max_results=10))

            if not blobs:
                return "No documents found in GCS bucket."

            context_parts = [f"Documents in GCS bucket '{self.bucket_name}':"]
            for blob in blobs[:5]:  # Limit to 5
                context_parts.append(f"- {blob.name} ({blob.size} bytes)")

            return "\n".join(context_parts)
        except Exception as e:
            logger.error(f"Error listing GCS objects: {e}")
            return f"Error accessing GCS: {str(e)}"

    async def process(
        self,
        query: str,
        session_id: Optional[str] = None,
        context: Optional[str] = None,
    ) -> dict:
        """
        Process query with cloud storage access.

        Args:
            query: User query
            session_id: Optional session ID
            context: Optional additional context

        Returns:
            Response dictionary from the base agent.
        """
        # Retrieve cloud context
        cloud_context = await self.retrieve_context(query)

        # Combine with provided context
        full_context = cloud_context
        if context:
            full_context = f"{context}\n\n{cloud_context}"

        # Process using base agent (LLM call via the untouched self.client)
        return await super().process(query, session_id, full_context)
src/agents/local_data_agent.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local data agent for document queries."""
2
+
3
+ import logging
4
+ from typing import Optional
5
+ from src.agents.base_agent import BaseAgent
6
+ from src.retrieval.vector_store import get_vector_store
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class LocalDataAgent(BaseAgent):
    """Agent specialized in querying local documents and data.

    Wraps the shared vector store: relevant chunks are fetched per query and
    handed to the base agent as additional context.
    """

    def __init__(self, use_planning: bool = False):
        """Initialize the local data agent and attach the shared vector store."""
        super().__init__(
            name="local_data_agent",
            description=(
                "You are a specialized agent for querying local documents and data. "
                "You have access to a vector store of local documents and can retrieve "
                "relevant information to answer questions."
            ),
            use_memory=True,
            use_planning=use_planning,
        )
        self.vector_store = get_vector_store()

    async def retrieve_context(self, query: str) -> str:
        """
        Retrieve relevant context from local documents.

        Args:
            query: User query

        Returns:
            Context string from retrieved documents, or a human-readable
            message when nothing matches or the lookup fails.
        """
        try:
            hits = self.vector_store.search(query=query, n_results=5)

            if not hits["documents"]:
                return "No relevant documents found in local data."

            # Render each hit as a numbered entry with its source
            lines = ["Relevant documents from local data:"]
            for rank, (doc, meta) in enumerate(
                zip(hits["documents"], hits["metadatas"]), start=1
            ):
                origin = meta.get("source", "Unknown")
                lines.append(f"\n[{rank}] Source: {origin}")
                # Clip long documents to keep the prompt small
                lines.append(f"Content: {doc[:500]}...")

            return "\n".join(lines)
        except Exception as exc:
            logger.error(f"Error retrieving local context: {exc}")
            return f"Error retrieving local documents: {str(exc)}"

    async def process(
        self,
        query: str,
        session_id: Optional[str] = None,
        context: Optional[str] = None,
    ) -> dict:
        """
        Process query with local document retrieval.

        Args:
            query: User query
            session_id: Optional session ID
            context: Optional additional context

        Returns:
            Response dictionary from the base agent.
        """
        # Fetch local context, prepend any caller-supplied context, then
        # delegate to the base agent
        local_context = await self.retrieve_context(query)
        merged = f"{context}\n\n{local_context}" if context else local_context
        return await super().process(query, session_id, merged)
src/agents/search_agent.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Web search agent for online information."""
2
+
3
+ import logging
4
+ from typing import Optional
5
+ from src.agents.base_agent import BaseAgent
6
+ from src.tools.web_search import get_web_search
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class SearchAgent(BaseAgent):
    """Agent specialized in web search and online information.

    Uses a ReAct planner (when planning is enabled) together with a single
    registered web-search tool.
    """

    def __init__(self, use_planning: bool = True):
        """Initialize search agent.

        Args:
            use_planning: Whether to enable ReAct planning (default True).
        """
        web_search = get_web_search()

        # BUGFIX: the web-search tool used to be registered twice -- once via
        # the ``tools=`` argument to super().__init__() (which also seeds the
        # ReAct planner) and again via add_tool(), leaving duplicate schemas
        # in both self.tools and the planner. It is now registered exactly
        # once, through add_tool().
        super().__init__(
            name="search_agent",
            description=(
                "You are a specialized agent for searching the web and finding "
                "online information. You can search the internet to answer questions "
                "that require current or external information."
            ),
            use_memory=True,
            use_planning=use_planning,
            planning_type="react",
        )

        # Async wrapper so the tool registry can await the search call
        async def web_search_tool(query: str, max_results: int = 5):
            return await web_search.search(query, max_results)

        self.add_tool(
            tool=web_search.get_tool_schema(),
            tool_function=web_search_tool,
        )
        self.web_search = web_search

    async def retrieve_context(self, query: str) -> str:
        """
        Retrieve relevant context from web search.

        Args:
            query: User query

        Returns:
            Context string built from web search results, or a human-readable
            message when the search fails or returns nothing.
        """
        try:
            # Perform web search
            search_results = await self.web_search.search(query, max_results=5)

            if not search_results.get("success") or not search_results.get("results"):
                return "No relevant information found from web search."

            # Format results as numbered entries with title, URL and snippet
            context_parts = ["Web search results:"]
            for i, result in enumerate(search_results["results"], 1):
                title = result.get("title", "No title")
                url = result.get("url", "")
                content = result.get("content", "")[:300]  # Truncate snippet
                context_parts.append(f"\n[{i}] {title}")
                context_parts.append(f"URL: {url}")
                context_parts.append(f"Content: {content}...")

            return "\n".join(context_parts)
        except Exception as e:
            logger.error(f"Error retrieving web context: {e}")
            return f"Error performing web search: {str(e)}"

    async def process(
        self,
        query: str,
        session_id: Optional[str] = None,
        context: Optional[str] = None,
    ) -> dict:
        """
        Process query with web search.

        Args:
            query: User query
            session_id: Optional session ID
            context: Optional additional context

        Returns:
            Response dictionary from the base agent.
        """
        # Retrieve web context
        web_context = await self.retrieve_context(query)

        # Combine with provided context
        full_context = web_context
        if context:
            full_context = f"{context}\n\n{web_context}"

        # Process using base agent (which will use planning if enabled)
        return await super().process(query, session_id, full_context)
src/agents/snowflake_agent.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Snowflake data warehouse agent."""
2
+
3
+ import logging
4
+ from typing import Dict, List, Optional
5
+ import json
6
+ from src.agents.base_agent import BaseAgent
7
+ from src.mcp.snowflake_server import SnowflakeMCPServer
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class SnowflakeAgent(BaseAgent):
    """Agent specialized in querying Snowflake data warehouse.

    Translates natural-language questions into Snowflake SQL with the LLM,
    executes them through the Snowflake MCP server, and lets the base agent
    explain the results.
    """

    def __init__(self, snowflake_config: Optional[Dict] = None, use_planning: bool = False):
        """Initialize Snowflake agent.

        Args:
            snowflake_config: Optional connection config passed through to
                the MCP server.
            use_planning: Whether to enable planning in the base agent.
        """
        super().__init__(
            name="snowflake_agent",
            description=(
                "You are a specialized agent for querying Snowflake data warehouse. "
                "You can convert natural language queries to SQL and execute them "
                "on Snowflake databases."
            ),
            use_memory=True,
            use_planning=use_planning,
        )

        # Initialize Snowflake MCP server
        self.snowflake_server = SnowflakeMCPServer(config=snowflake_config)
        self.tables_cache: Optional[List[str]] = None

    def get_available_tables(self) -> List[str]:
        """Cache and return available tables.

        An empty result (including after an error) is not cached, so the
        list is re-fetched on the next call -- this acts as a retry.
        """
        if not self.tables_cache:
            try:
                self.tables_cache = self.snowflake_server.get_tables()
            except Exception as e:
                logger.error(f"Error getting tables: {e}")
                self.tables_cache = []
        return self.tables_cache

    def get_context_for_query(self, user_query: str) -> str:
        """Build a textual description of available tables and their schemas.

        Limited to the first 10 tables and 5 columns each to avoid blowing
        the prompt token budget.
        """
        try:
            tables = self.get_available_tables()

            if not tables:
                return "No tables available in Snowflake database."

            context = "Available Snowflake tables:\n\n"
            for table in tables[:10]:  # Limit to first 10 tables
                try:
                    schema = self.snowflake_server.get_table_schema(table)
                    context += f"Table: {table}\n"
                    if schema:
                        context += "Columns: " + ", ".join([
                            f"{col.get('COLUMN_NAME', 'unknown')} ({col.get('DATA_TYPE', 'unknown')})"
                            for col in schema[:5]  # First 5 columns
                        ]) + "\n\n"
                    else:
                        context += "Columns: (schema not available)\n\n"
                except Exception as e:
                    logger.warning(f"Error getting schema for {table}: {e}")
                    context += f"Table: {table}\nColumns: (error retrieving schema)\n\n"

            return context
        except Exception as e:
            logger.error(f"Error building context: {e}")
            return f"Error building context: {str(e)}"

    def natural_language_to_sql(self, user_query: str) -> str:
        """Convert a natural language query to Snowflake SQL using the LLM.

        Returns:
            The generated SQL string, with any markdown fencing stripped.

        Raises:
            Exception: propagated from the LLM call on failure.
        """
        context = self.get_context_for_query(user_query)

        prompt = f"""You are a Snowflake SQL expert. Convert this natural language query to SQL.

Database context:
{context}

User query: {user_query}

Requirements:
1. Generate ONLY valid Snowflake SQL
2. Use proper table and column names from the context
3. Include appropriate filters and aggregations
4. Limit results to 100 rows for safety
5. Return ONLY the SQL query, no explanation

SQL Query:"""

        try:
            messages = [
                {
                    "role": "system",
                    "content": "You are a Snowflake SQL expert. Generate only valid SQL queries.",
                },
                {
                    "role": "user",
                    "content": prompt,
                },
            ]

            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.3,  # Lower temperature for more deterministic SQL
            )

            sql = response.choices[0].message.content.strip()

            # Clean up any markdown code blocks the model may have added
            sql = sql.replace("```sql", "").replace("```", "").strip()

            return sql
        except Exception as e:
            logger.error(f"Error generating SQL: {e}")
            raise

    async def retrieve_context(self, query: str) -> str:
        """
        Retrieve relevant context from Snowflake.

        Args:
            query: User query

        Returns:
            Context string: the table/schema overview, plus a sample of query
            results when the question looks like a data request.
        """
        try:
            # Get available tables context
            context = self.get_context_for_query(query)

            # If query seems to be asking for data, try to generate and execute SQL
            if any(keyword in query.lower() for keyword in ['show', 'list', 'get', 'find', 'select']):
                try:
                    sql = self.natural_language_to_sql(query)
                    results = self.snowflake_server.query(sql)

                    # BUGFIX: previously any row merely *containing* the
                    # substring "error" suppressed the results. Only a
                    # structured error payload (first row is a dict with an
                    # 'error' key) is treated as a failure now.
                    failed = bool(results) and isinstance(results[0], dict) and 'error' in results[0]
                    if results and not failed:
                        # Format a sample of results for context
                        context += f"\n\nQuery Results:\n"
                        context += json.dumps(results[:5], indent=2)  # First 5 rows
                except Exception as e:
                    logger.warning(f"Could not execute query for context: {e}")

            return context
        except Exception as e:
            logger.error(f"Error retrieving Snowflake context: {e}")
            return f"Error retrieving Snowflake context: {str(e)}"

    async def process(
        self,
        query: str,
        session_id: Optional[str] = None,
        context: Optional[str] = None,
    ) -> dict:
        """
        Process query against the Snowflake data warehouse.

        Args:
            query: User query
            session_id: Optional session ID
            context: Optional additional context

        Returns:
            Response dictionary; on failure includes "error" and, when SQL
            was generated, the offending "sql_query".
        """
        try:
            # Convert natural language to SQL
            sql_query = self.natural_language_to_sql(query)

            logger.info(f"Generated SQL: {sql_query}")

            # Execute query
            results = self.snowflake_server.query(sql_query)

            # Check for a structured error payload from the MCP server
            if results and isinstance(results, list) and len(results) > 0:
                if isinstance(results[0], dict) and 'error' in results[0]:
                    return {
                        "success": False,
                        "error": results[0].get('error', 'Unknown error'),
                        "sql_query": sql_query,
                        "agent": self.name,
                    }

            # BUGFIX: a _summarize_results() call used to run here with its
            # return value discarded -- a full LLM round-trip wasted on every
            # query. The base agent below already summarizes the raw results,
            # so the redundant call is removed.

            # Build full context with the SQL and a sample of the results
            snowflake_context = f"SQL Query: {sql_query}\n\nResults ({len(results)} rows):\n{json.dumps(results[:10], indent=2)}"
            full_context = f"{context}\n\n{snowflake_context}" if context else snowflake_context

            # Process using base agent
            return await super().process(query, session_id, full_context)

        except Exception as e:
            logger.error(f"Error processing Snowflake query: {e}")
            return {
                "success": False,
                "error": str(e),
                "agent": self.name,
            }

    async def _summarize_results(self, query: str, results: List[Dict]) -> str:
        """Use the LLM to summarize query results.

        Kept as a public-ish helper for callers that want a standalone
        summary of a result set.
        """
        if not results:
            return "No results found."

        # Convert a sample of the results to a readable format
        results_text = json.dumps(results[:10], indent=2)

        prompt = f"""Summarize these Snowflake query results for the user.

Original question: {query}
Number of results: {len(results)}

Sample data:
{results_text}

Provide a clear, concise summary of the findings."""

        try:
            messages = [
                {
                    "role": "system",
                    "content": "You are a helpful assistant that summarizes database query results.",
                },
                {
                    "role": "user",
                    "content": prompt,
                },
            ]

            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.7,
            )

            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error summarizing results: {e}")
            return f"Found {len(results)} results. (Summary generation failed: {str(e)})"
src/core/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Core orchestration and configuration."""
2
+
src/core/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (214 Bytes). View file
 
src/core/__pycache__/config.cpython-311.pyc ADDED
Binary file (8.94 kB). View file
 
src/core/__pycache__/orchestrator.cpython-311.pyc ADDED
Binary file (14.9 kB). View file
 
src/core/config.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration management using pydantic-settings."""
2
+
3
+ import os
4
+ from typing import Optional, Dict, Any
5
+ from pydantic import Field
6
+ from pydantic_settings import BaseSettings, SettingsConfigDict
7
+
8
+
9
class Settings(BaseSettings):
    """Application settings, populated from the environment / .env file."""

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    # --- OpenAI / OpenRouter ---
    openai_api_key: str = Field(default="", description="OpenAI or OpenRouter API key")
    openai_base_url: Optional[str] = Field(default=None, description="OpenAI/OpenRouter base URL (e.g., https://openrouter.ai/api/v1)")
    openai_model: str = Field(default="gpt-4-turbo-preview", description="Model to use (OpenAI or OpenRouter model name)")
    openai_embedding_model: str = Field(default="text-embedding-3-small", description="Embedding model to use")
    # Optional OpenRouter-specific request headers.
    openrouter_http_referer: Optional[str] = Field(default=None, description="HTTP-Referer header for OpenRouter (optional)")
    openrouter_title: Optional[str] = Field(default=None, description="X-Title header for OpenRouter (optional)")

    # --- ChromaDB ---
    chroma_db_path: str = Field(default="./data/chroma_db", description="Path to ChromaDB database")
    chroma_collection_name: str = Field(default="documents", description="ChromaDB collection name")

    # --- MCP server ---
    mcp_server_host: str = Field(default="localhost", description="MCP server host")
    mcp_server_port: int = Field(default=8001, description="MCP server port")

    # --- Memory ---
    short_term_memory_size: int = Field(default=10, description="Number of recent messages to keep in short-term memory")
    long_term_memory_enabled: bool = Field(default=True, description="Enable long-term memory")
    max_context_tokens: int = Field(default=4000, description="Maximum context tokens for LLM")

    # --- API server ---
    api_host: str = Field(default="0.0.0.0", description="API server host")
    api_port: int = Field(default=8000, description="API server port")
    api_debug: bool = Field(default=False, description="Enable API debug mode")

    # --- Web search (optional) ---
    tavily_api_key: Optional[str] = Field(default=None, description="Tavily API key for web search")
    serper_api_key: Optional[str] = Field(default=None, description="Serper API key for web search")

    # --- Database (optional) ---
    database_url: Optional[str] = Field(default="sqlite:///./data/app.db", description="Database connection URL")

    # --- AWS (optional) ---
    aws_access_key_id: Optional[str] = Field(default=None, description="AWS access key ID")
    aws_secret_access_key: Optional[str] = Field(default=None, description="AWS secret access key")
    aws_region: str = Field(default="us-east-1", description="AWS region")
    aws_s3_bucket: Optional[str] = Field(default=None, description="AWS S3 bucket name")

    # --- GCS (optional) ---
    google_application_credentials: Optional[str] = Field(default=None, description="Path to GCS service account JSON")
    gcs_bucket_name: Optional[str] = Field(default=None, description="GCS bucket name")

    # --- Snowflake (optional) ---
    snowflake_account: Optional[str] = Field(default=None, description="Snowflake account identifier")
    snowflake_user: Optional[str] = Field(default=None, description="Snowflake username")
    snowflake_password: Optional[str] = Field(default=None, description="Snowflake password")
    snowflake_warehouse: Optional[str] = Field(default=None, description="Snowflake warehouse name")
    snowflake_database: Optional[str] = Field(default=None, description="Snowflake database name")
    snowflake_schema: Optional[str] = Field(default="PUBLIC", description="Snowflake schema name")
    snowflake_role: Optional[str] = Field(default="ACCOUNTADMIN", description="Snowflake role")

    # --- Logging ---
    log_level: str = Field(default="INFO", description="Logging level")

    def get_openai_client_kwargs(self) -> dict:
        """Build kwargs for OpenAI client construction (OpenRouter-compatible)."""
        kwargs: Dict[str, Any] = {"api_key": self.openai_api_key}

        # A custom base URL switches the client to OpenRouter / a proxy.
        if self.openai_base_url:
            kwargs["base_url"] = self.openai_base_url

        # Attach optional OpenRouter attribution headers, if configured.
        headers = {
            header: value
            for header, value in (
                ("HTTP-Referer", self.openrouter_http_referer),
                ("X-Title", self.openrouter_title),
            )
            if value
        }
        if headers:
            kwargs["default_headers"] = headers

        return kwargs

    def get_chroma_client_kwargs(self) -> dict:
        """Build kwargs for ChromaDB client construction."""
        return {"path": self.chroma_db_path}

    def has_web_search(self) -> bool:
        """True when at least one web-search provider key is configured."""
        return bool(self.tavily_api_key or self.serper_api_key)

    def has_cloud_storage(self) -> bool:
        """True when either S3 or GCS credentials + bucket are configured."""
        s3_ready = self.aws_access_key_id and self.aws_s3_bucket
        gcs_ready = self.google_application_credentials and self.gcs_bucket_name
        return bool(s3_ready or gcs_ready)

    def has_snowflake(self) -> bool:
        """True when all mandatory Snowflake connection fields are set."""
        return all(
            (
                self.snowflake_account,
                self.snowflake_user,
                self.snowflake_password,
                self.snowflake_warehouse,
                self.snowflake_database,
            )
        )

    def get_snowflake_config(self) -> Optional[Dict[str, Any]]:
        """Return the Snowflake connection dict, or None when unconfigured."""
        if not self.has_snowflake():
            return None

        return {
            "account": self.snowflake_account,
            "user": self.snowflake_user,
            "password": self.snowflake_password,
            "warehouse": self.snowflake_warehouse,
            "database": self.snowflake_database,
            "schema": self.snowflake_schema,
            "role": self.snowflake_role,
        }
202
+
203
+
204
# Process-wide Settings singleton, created lazily by get_settings().
_settings: Optional[Settings] = None


def get_settings() -> Settings:
    """Return the shared Settings instance, building it on first use."""
    global _settings
    if _settings is None:
        _settings = Settings()
    return _settings


def reset_settings() -> None:
    """Discard the cached Settings so the next get_settings() rebuilds it.

    Primarily a testing aid: lets tests re-read the environment.
    """
    global _settings
    _settings = None
220
+
src/core/orchestrator.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main orchestrator for coordinating all components."""
2
+
3
+ import logging
4
+ from typing import Dict, Any, Optional
5
+ from enum import Enum
6
+ from src.core.config import get_settings
7
+ from src.retrieval.vector_store import get_vector_store
8
+ from src.agents.local_data_agent import LocalDataAgent
9
+ from src.agents.search_agent import SearchAgent
10
+ from src.agents.cloud_agent import CloudAgent
11
+ from src.agents.aggregator_agent import AggregatorAgent
12
+ from src.agents.snowflake_agent import SnowflakeAgent
13
+ from src.tools.calculator import get_calculator
14
+ from src.tools.web_search import get_web_search
15
+ from src.tools.database_query import get_database_query
16
+ from openai import OpenAI
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class Tier(Enum):
    """The three operating modes of the RAG system.

    The enum value is the string accepted by the public API's ``tier``
    parameter (see Orchestrator.process_query).
    """

    BASIC_RAG = "basic"
    AGENT_WITH_TOOLS = "agent"
    ADVANCED_AGENTIC = "advanced"
+
27
+
28
class Orchestrator:
    """Main orchestrator for the RAG system.

    Routes each query to one of three tiers:
      * ``basic``    - plain RAG: vector retrieval + LLM generation
      * ``agent``    - a single agent with tool calling
      * ``advanced`` - multi-agent aggregation
    """

    def __init__(self):
        """Initialize clients, the vector store, lazy agents, and tools."""
        self.settings = get_settings()
        self.client = OpenAI(**self.settings.get_openai_client_kwargs())
        self.model = self.settings.openai_model

        # Shared retrieval component.
        self.vector_store = get_vector_store()

        # Agents are created lazily, on the first query of each tier.
        self._local_agent: Optional[LocalDataAgent] = None
        self._search_agent: Optional[SearchAgent] = None
        self._cloud_agent: Optional[CloudAgent] = None
        self._snowflake_agent: Optional[SnowflakeAgent] = None
        self._aggregator_agent: Optional[AggregatorAgent] = None

        # Tools (always constructed; some are inert if unconfigured).
        self.calculator = get_calculator()
        self.web_search = get_web_search()
        self.database_query = get_database_query()

    async def process_query(
        self,
        query: str,
        tier: str = "basic",
        session_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Process a query using the specified tier.

        Args:
            query: User query
            tier: System tier ("basic", "agent", or "advanced")
            session_id: Optional session ID for memory

        Returns:
            Response dictionary (always contains a "success" key).
        """
        try:
            tier_enum = Tier(tier.lower())

            if tier_enum == Tier.BASIC_RAG:
                return await self._process_basic_rag(query, session_id)
            elif tier_enum == Tier.AGENT_WITH_TOOLS:
                return await self._process_agent_with_tools(query, session_id)
            elif tier_enum == Tier.ADVANCED_AGENTIC:
                return await self._process_advanced_agentic(query, session_id)
            else:
                raise ValueError(f"Unknown tier: {tier}")

        except ValueError as e:
            # Tier(...) raises ValueError for unrecognized tier strings.
            logger.error(f"Invalid tier: {e}")
            return {
                "success": False,
                "error": f"Invalid tier: {tier}",
            }
        except Exception as e:
            logger.error(f"Error processing query: {e}")
            return {
                "success": False,
                "error": str(e),
            }

    async def _process_basic_rag(
        self,
        query: str,
        session_id: Optional[str],
    ) -> Dict[str, Any]:
        """Process query using basic RAG (retrieval + generation)."""
        try:
            # Fail fast with an actionable message if no API key is set.
            if not self.settings.openai_api_key:
                return {
                    "success": False,
                    "error": "OpenAI API key not configured. Please set OPENAI_API_KEY in your .env file.",
                    "tier": "basic",
                }

            # Retrieve relevant documents from the vector store.
            results = self.vector_store.search(query=query, n_results=5)

            # Build context from retrieved documents; fall back to an
            # explicit "nothing found" context so the LLM does not hallucinate.
            if results["documents"]:
                context_parts = ["Retrieved documents:"]
                for i, (doc, metadata) in enumerate(
                    zip(results["documents"], results["metadatas"]), 1
                ):
                    source = metadata.get("source", "Unknown")
                    context_parts.append(f"\n[{i}] Source: {source}")
                    # Ensure doc is a string; truncate to keep the prompt small.
                    doc_str = str(doc) if doc else ""
                    context_parts.append(f"Content: {doc_str[:500]}...")
                context = "\n".join(context_parts)
                sources = [
                    {"id": id, "metadata": meta}
                    for id, meta in zip(results["ids"], results["metadatas"])
                ]
            else:
                context = "No relevant documents found in the knowledge base."
                sources = []

            # Generate the answer from the assembled context.
            messages = [
                {
                    "role": "system",
                    "content": "You are a helpful assistant that answers questions based on the provided context.",
                },
                {
                    "role": "user",
                    "content": f"Context:\n{context}\n\nQuestion: {query}",
                },
            ]

            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=0.7,
                )

                answer = response.choices[0].message.content
            except Exception as api_error:
                # Translate common provider failures into actionable messages,
                # chaining the original error for debuggability.
                error_msg = str(api_error)
                if "quota" in error_msg.lower() or "429" in error_msg:
                    raise Exception(
                        "OpenAI API quota exceeded. Please check your billing and plan details."
                    ) from api_error
                elif "api key" in error_msg.lower() or "401" in error_msg:
                    raise Exception(
                        "Invalid OpenAI API key. Please check your .env file."
                    ) from api_error
                else:
                    raise Exception(f"OpenAI API error: {error_msg}") from api_error

            return {
                "success": True,
                "answer": answer,
                "tier": "basic",
                "sources": sources,
                "model": self.model,
            }

        except Exception as e:
            logger.error(f"Error in basic RAG: {e}", exc_info=True)
            return {
                "success": False,
                "error": f"Error processing query: {str(e)}",
                "tier": "basic",
            }

    async def _process_agent_with_tools(
        self,
        query: str,
        session_id: Optional[str],
    ) -> Dict[str, Any]:
        """Process query using a single agent with tool calling."""
        try:
            if not self.settings.openai_api_key:
                return {
                    "success": False,
                    "error": "OpenAI API key not configured. Please set OPENAI_API_KEY in your .env file.",
                    "tier": "agent",
                }

            # Create the agent once and register its tools exactly once.
            # (Previously tools were re-added on every call, so the cached
            # agent accumulated duplicate registrations per query.)
            if not self._local_agent:
                self._local_agent = LocalDataAgent(use_planning=True)
                self._register_local_agent_tools()

            # Process query
            response = await self._local_agent.process(query, session_id)

            return {
                **response,
                "tier": "agent",
            }

        except Exception as e:
            logger.error(f"Error in agent with tools: {e}", exc_info=True)
            return {
                "success": False,
                "error": f"Error processing query: {str(e)}",
                "tier": "agent",
            }

    def _register_local_agent_tools(self) -> None:
        """Attach calculator, web-search, and database tools to the local agent."""
        # Calculator is always available.
        self._local_agent.add_tool(
            tool=self.calculator.get_tool_schema(),
            tool_function=lambda expression: self.calculator.calculate(expression),
        )

        # Web search only when a provider key is configured.
        if self.settings.has_web_search():
            async def web_search_tool(query: str, max_results: int = 5):
                return await self.web_search.search(query, max_results)

            self._local_agent.add_tool(
                tool=self.web_search.get_tool_schema(),
                tool_function=web_search_tool,
            )

        # SQL tool only when a database URL is configured.
        if self.settings.database_url:
            def db_query_tool(sql: str, limit: int = 100):
                return self.database_query.query(sql, limit)

            self._local_agent.add_tool(
                tool=self.database_query.get_tool_schema(),
                tool_function=db_query_tool,
            )

    async def _process_advanced_agentic(
        self,
        query: str,
        session_id: Optional[str],
    ) -> Dict[str, Any]:
        """Process query using advanced agentic RAG with multiple agents."""
        try:
            if not self.settings.openai_api_key:
                return {
                    "success": False,
                    "error": "OpenAI API key not configured. Please set OPENAI_API_KEY in your .env file.",
                    "tier": "advanced",
                }

            # Aggregator agent is created lazily.
            if not self._aggregator_agent:
                self._aggregator_agent = AggregatorAgent(use_planning=True)

            # Instantiate the Snowflake agent once, if configured.
            if self.settings.has_snowflake() and not self._snowflake_agent:
                snowflake_config = self.settings.get_snowflake_config()
                self._snowflake_agent = SnowflakeAgent(
                    snowflake_config=snowflake_config,
                    use_planning=False
                )
                # NOTE(review): the agent is never passed to the aggregator
                # here; the original comment claims AggregatorAgent discovers
                # it via its agent-selection logic -- verify against
                # AggregatorAgent's implementation.

            # Process query
            response = await self._aggregator_agent.process(query, session_id)

            return {
                **response,
                "tier": "advanced",
            }

        except Exception as e:
            logger.error(f"Error in advanced agentic: {e}", exc_info=True)
            return {
                "success": False,
                "error": f"Error processing query: {str(e)}",
                "tier": "advanced",
            }

    def get_agent_status(self) -> Dict[str, Any]:
        """Report status for every agent that has been instantiated so far."""
        status = {
            "tiers_available": ["basic", "agent", "advanced"],
            "agents": {},
        }

        if self._local_agent:
            status["agents"]["local"] = self._local_agent.get_status()
        if self._search_agent:
            status["agents"]["search"] = self._search_agent.get_status()
        if self._cloud_agent:
            status["agents"]["cloud"] = self._cloud_agent.get_status()
        if self._snowflake_agent:
            status["agents"]["snowflake"] = self._snowflake_agent.get_status()
        if self._aggregator_agent:
            status["agents"]["aggregator"] = self._aggregator_agent.get_status()

        return status

    def get_system_info(self) -> Dict[str, Any]:
        """Summarize vector store, tool availability, memory, and model config."""
        return {
            "vector_store": {
                "document_count": self.vector_store.count(),
                "collection_name": self.settings.chroma_collection_name,
            },
            "tools": {
                "calculator": True,
                "web_search": self.settings.has_web_search(),
                "database": bool(self.settings.database_url),
                "snowflake": self.settings.has_snowflake(),
            },
            "memory": {
                "short_term_enabled": True,
                "long_term_enabled": self.settings.long_term_memory_enabled,
            },
            "model": self.model,
        }
+ }
320
+
321
+
322
# Process-wide Orchestrator singleton, built lazily by get_orchestrator().
_orchestrator: Optional[Orchestrator] = None


def get_orchestrator() -> Orchestrator:
    """Return the shared Orchestrator, constructing it on first use."""
    global _orchestrator
    if _orchestrator is None:
        _orchestrator = Orchestrator()
    return _orchestrator
+ return _orchestrator
332
+
src/mcp/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """MCP server implementations."""
2
+
src/mcp/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (203 Bytes). View file
 
src/mcp/__pycache__/mcp_server.cpython-311.pyc ADDED
Binary file (5.44 kB). View file
 
src/mcp/__pycache__/snowflake_server.cpython-311.pyc ADDED
Binary file (10 kB). View file
 
src/mcp/cloud_server.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cloud storage MCP server."""
2
+
3
+ import logging
4
+ from typing import Any, Dict
5
+
6
+ try:
7
+ from mcp.types import Tool
8
+ MCP_AVAILABLE = True
9
+ except ImportError:
10
+ MCP_AVAILABLE = False
11
+ class Tool:
12
+ def __init__(self, **kwargs):
13
+ pass
14
+
15
+ from src.mcp.mcp_server import BaseMCPServer
16
+ from src.core.config import get_settings
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class CloudMCPServer(BaseMCPServer):
    """MCP server exposing cloud storage (AWS S3 or Google Cloud Storage)."""

    def __init__(self):
        """Initialize cloud MCP server."""
        super().__init__("cloud_storage_server")
        self.settings = get_settings()
        self._init_cloud_client()
        self._register_tools()

    def _init_cloud_client(self):
        """Select and initialize a storage backend: S3 first, then GCS."""
        self.cloud_type = None
        self.client = None

        if self.settings.aws_access_key_id and self.settings.aws_s3_bucket:
            # AWS S3 takes precedence when both backends are configured.
            try:
                import boto3
                self.client = boto3.client(
                    "s3",
                    aws_access_key_id=self.settings.aws_access_key_id,
                    aws_secret_access_key=self.settings.aws_secret_access_key,
                    region_name=self.settings.aws_region,
                )
                self.cloud_type = "s3"
                self.bucket_name = self.settings.aws_s3_bucket
            except Exception as e:
                logger.error(f"Error initializing S3: {e}")
        elif self.settings.google_application_credentials and self.settings.gcs_bucket_name:
            try:
                from google.cloud import storage
                import os
                os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.settings.google_application_credentials
                self.client = storage.Client()
                self.cloud_type = "gcs"
                self.bucket_name = self.settings.gcs_bucket_name
            except Exception as e:
                logger.error(f"Error initializing GCS: {e}")

    def _register_tools(self):
        """Register cloud storage tools."""
        if not self.client:
            logger.warning("No cloud storage configured, skipping tool registration")
            return

        tool_specs = [
            {
                "name": "list_cloud_objects",
                "description": "List objects in cloud storage",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "prefix": {
                            "type": "string",
                            "description": "Object key prefix to filter",
                        },
                        "max_keys": {
                            "type": "integer",
                            "description": "Maximum number of objects to return",
                            "default": 10,
                        },
                    },
                },
            },
            {
                "name": "get_cloud_object",
                "description": "Get an object from cloud storage",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "key": {
                            "type": "string",
                            "description": "Object key",
                        },
                    },
                    "required": ["key"],
                },
            },
        ]
        for spec in tool_specs:
            self.register_tool(Tool(**spec))

    async def _execute_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Dispatch a cloud storage tool call by name."""
        if not self.client:
            return {"error": "Cloud storage not configured"}

        if name == "list_cloud_objects":
            prefix = arguments.get("prefix", "")
            max_keys = arguments.get("max_keys", 10)
            return self._list_objects(prefix, max_keys)
        elif name == "get_cloud_object":
            return self._get_object(arguments.get("key"))
        else:
            raise ValueError(f"Unknown tool: {name}")

    def _list_objects(self, prefix: str, max_keys: int) -> Any:
        """List up to max_keys objects under prefix on the active backend."""
        if self.cloud_type == "s3":
            response = self.client.list_objects_v2(
                Bucket=self.bucket_name,
                Prefix=prefix,
                MaxKeys=max_keys,
            )
            objects = [
                {"key": obj["Key"], "size": obj["Size"]}
                for obj in response.get("Contents", [])
            ]
            return {"objects": objects, "count": len(objects)}

        if self.cloud_type == "gcs":
            bucket = self.client.bucket(self.bucket_name)
            blobs = list(bucket.list_blobs(prefix=prefix, max_results=max_keys))
            objects = [{"key": blob.name, "size": blob.size} for blob in blobs]
            return {"objects": objects, "count": len(objects)}

        return None  # no recognized backend (should be unreachable with a client)

    def _get_object(self, key: Any) -> Any:
        """Fetch one object's text content; errors are returned, not raised."""
        if self.cloud_type == "s3":
            try:
                response = self.client.get_object(Bucket=self.bucket_name, Key=key)
                content = response["Body"].read().decode("utf-8")
                return {"key": key, "content": content}
            except Exception as e:
                return {"error": str(e)}

        if self.cloud_type == "gcs":
            try:
                blob = self.client.bucket(self.bucket_name).blob(key)
                return {"key": key, "content": blob.download_as_text()}
            except Exception as e:
                return {"error": str(e)}

        return None
156
+
src/mcp/local_server.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local data MCP server."""
2
+
3
+ import logging
4
+ from typing import Any, Dict
5
+
6
+ try:
7
+ from mcp.types import Tool
8
+ MCP_AVAILABLE = True
9
+ except ImportError:
10
+ MCP_AVAILABLE = False
11
+ # Create a mock Tool class for type hints
12
+ class Tool:
13
+ def __init__(self, **kwargs):
14
+ pass
15
+
16
+ from src.mcp.mcp_server import BaseMCPServer
17
+ from src.retrieval.vector_store import get_vector_store
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class LocalMCPServer(BaseMCPServer):
    """MCP server for local document operations backed by the vector store."""

    def __init__(self):
        """Initialize local MCP server."""
        super().__init__("local_data_server")
        self.vector_store = get_vector_store()
        self._register_tools()

    def _register_tools(self):
        """Register the search / get / list document tools."""
        tool_specs = [
            {
                "name": "search_local_documents",
                "description": "Search local documents in the vector store",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query",
                        },
                        "n_results": {
                            "type": "integer",
                            "description": "Number of results to return",
                            "default": 5,
                        },
                    },
                    "required": ["query"],
                },
            },
            {
                "name": "get_local_document",
                "description": "Get a document by its ID",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "document_id": {
                            "type": "string",
                            "description": "Document ID",
                        },
                    },
                    "required": ["document_id"],
                },
            },
            {
                "name": "list_local_documents",
                "description": "List all documents in the vector store",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "limit": {
                            "type": "integer",
                            "description": "Maximum number of documents to return",
                            "default": 10,
                        },
                    },
                },
            },
        ]
        for spec in tool_specs:
            self.register_tool(Tool(**spec))

    async def _execute_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Dispatch a local data tool call by name."""
        if name == "search_local_documents":
            results = self.vector_store.search(
                query=arguments.get("query", ""),
                n_results=arguments.get("n_results", 5),
            )
            return {
                "documents": results["documents"],
                "ids": results["ids"],
                "metadatas": results["metadatas"],
            }

        if name == "get_local_document":
            hits = self.vector_store.get_by_ids([arguments.get("document_id")])
            if not hits["documents"]:
                return {"error": "Document not found"}
            metadatas = hits["metadatas"]
            return {
                "document": hits["documents"][0],
                "metadata": metadatas[0] if metadatas else {},
            }

        if name == "list_local_documents":
            # Only the total count is reported; the limit is echoed back.
            return {
                "total_documents": self.vector_store.count(),
                "limit": arguments.get("limit", 10),
            }

        raise ValueError(f"Unknown tool: {name}")
+ raise ValueError(f"Unknown tool: {name}")
122
+
src/mcp/mcp_server.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base MCP server implementation."""
2
+
3
+ import logging
4
+ from typing import Any, Dict, List, Optional
5
+ import asyncio
6
+
7
# Module logger must exist BEFORE the import guard below: the except branch
# logs a warning, and defining the logger afterwards (as before) raised a
# NameError at import time whenever the MCP SDK was missing.
logger = logging.getLogger(__name__)

# Try to import MCP SDK - adjust imports based on actual SDK version
try:
    from mcp.server import Server
    from mcp.server.stdio import stdio_server
    from mcp.types import Tool, TextContent
    MCP_AVAILABLE = True
except ImportError:
    # Fallback if MCP SDK structure is different
    MCP_AVAILABLE = False
    logger.warning("MCP SDK not available - MCP servers will not function")
19
+
20
+
21
class BaseMCPServer:
    """Shared plumbing for all MCP servers: tool registry + stdio transport."""

    def __init__(self, name: str):
        """Initialize base MCP server."""
        self.name = name
        self.tools: List[Any] = []

        if not MCP_AVAILABLE:
            # Degrade to an inert server when the SDK is missing.
            logger.warning(f"MCP SDK not available - {name} server cannot be initialized")
            self.server = None
            return

        self.server = Server(name)
        self._setup_handlers()

    def _setup_handlers(self):
        """Wire the list_tools / call_tool callbacks into the MCP server."""
        if not self.server:
            return

        @self.server.list_tools()
        async def list_tools() -> List[Any]:
            """List available tools."""
            return self.tools

        @self.server.call_tool()
        async def call_tool(name: str, arguments: Dict[str, Any]) -> List[Any]:
            """Call a tool by name, returning errors as text content."""
            try:
                result = await self._execute_tool(name, arguments)
            except Exception as e:
                logger.error(f"Error executing tool {name}: {e}")
                return [TextContent(type="text", text=f"Error: {str(e)}")]
            return [TextContent(type="text", text=str(result))]

    async def _execute_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Execute a tool - to be overridden by subclasses."""
        raise NotImplementedError("Subclasses must implement _execute_tool")

    def register_tool(self, tool: Any):
        """Register a tool with the server."""
        self.tools.append(tool)
        tool_name = tool.name if hasattr(tool, "name") else "unknown"
        logger.info(f"Registered tool: {tool_name}")

    async def run(self):
        """Run the MCP server over stdio until the stream closes."""
        if not (self.server and MCP_AVAILABLE):
            logger.error("Cannot run MCP server - SDK not available")
            return

        async with stdio_server() as (read_stream, write_stream):
            await self.server.run(
                read_stream,
                write_stream,
                self.server.create_initialization_options(),
            )
78
+
src/mcp/search_server.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Web search MCP server."""
2
+
3
+ import logging
4
+ from typing import Any, Dict
5
+
6
+ try:
7
+ from mcp.types import Tool
8
+ MCP_AVAILABLE = True
9
+ except ImportError:
10
+ MCP_AVAILABLE = False
11
+ class Tool:
12
+ def __init__(self, **kwargs):
13
+ pass
14
+
15
+ from src.mcp.mcp_server import BaseMCPServer
16
+ from src.tools.web_search import get_web_search
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class SearchMCPServer(BaseMCPServer):
    """MCP server for web search operations."""

    def __init__(self):
        """Initialize search MCP server."""
        super().__init__("web_search_server")
        self.web_search = get_web_search()
        self._register_tools()

    def _register_tools(self):
        """Register the single web_search tool."""
        schema = {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query",
                },
                "max_results": {
                    "type": "integer",
                    "description": "Maximum number of results",
                    "default": 5,
                },
            },
            "required": ["query"],
        }
        self.register_tool(
            Tool(
                name="web_search",
                description="Search the web for information",
                inputSchema=schema,
            )
        )

    async def _execute_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Execute a web search tool call."""
        if name != "web_search":
            raise ValueError(f"Unknown tool: {name}")
        query = arguments.get("query", "")
        max_results = arguments.get("max_results", 5)
        return await self.web_search.search(query, max_results)
+ raise ValueError(f"Unknown tool: {name}")
62
+
src/mcp/snowflake_server.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MCP Server for Snowflake data warehouse."""
2
+
3
+ import logging
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ try:
7
+ from mcp.types import Tool
8
+ MCP_AVAILABLE = True
9
+ except ImportError:
10
+ MCP_AVAILABLE = False
11
+ class Tool:
12
+ def __init__(self, **kwargs):
13
+ pass
14
+
15
+ from src.mcp.mcp_server import BaseMCPServer
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ try:
20
+ import snowflake.connector
21
+ import pandas as pd
22
+ SNOWFLAKE_AVAILABLE = True
23
+ except ImportError:
24
+ SNOWFLAKE_AVAILABLE = False
25
+ logger.warning("snowflake-connector-python not installed")
26
+
27
+
28
class SnowflakeMCPServer(BaseMCPServer):
    """MCP Server for Snowflake data warehouse operations.

    Exposes three tools: raw SQL execution (`snowflake_query`), table listing
    (`snowflake_list_tables`), and table-schema inspection
    (`snowflake_get_schema`). Connection settings come from the `config` dict
    (account, user, password, warehouse, database, schema, role).
    """

    def __init__(self, config: Optional[Dict] = None):
        """Initialize Snowflake MCP server.

        Args:
            config: Snowflake connection settings. May be omitted, in which
                case tool calls return error payloads instead of connecting.
        """
        super().__init__("snowflake_server")
        self.config = config or {}
        self.connection = None
        self.cursor = None
        if SNOWFLAKE_AVAILABLE:
            self._register_tools()

    def _register_tools(self):
        """Register Snowflake tools with the MCP server."""
        if not SNOWFLAKE_AVAILABLE:
            logger.warning("Snowflake connector not available, skipping tool registration")
            return

        # Raw SQL execution tool (arbitrary SQL by design — callers of this
        # tool are trusted; values they embed are their responsibility).
        query_tool = Tool(
            name="snowflake_query",
            description="Execute SQL query on Snowflake data warehouse",
            inputSchema={
                "type": "object",
                "properties": {
                    "sql": {
                        "type": "string",
                        "description": "SQL query to execute",
                    },
                },
                "required": ["sql"],
            },
        )
        self.register_tool(query_tool)

        # Table listing tool (no arguments).
        list_tables_tool = Tool(
            name="snowflake_list_tables",
            description="List all tables in the current schema",
            inputSchema={"type": "object", "properties": {}},
        )
        self.register_tool(list_tables_tool)

        # Table schema inspection tool.
        schema_tool = Tool(
            name="snowflake_get_schema",
            description="Get schema information for a table",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {
                        "type": "string",
                        "description": "Name of the table",
                    },
                },
                "required": ["table_name"],
            },
        )
        self.register_tool(schema_tool)

    @staticmethod
    def _is_safe_identifier(name: Optional[str]) -> bool:
        """Return True when `name` is a plain unquoted SQL identifier.

        Allows only letters, digits, underscore and dollar sign — the
        characters of standard (unquoted) Snowflake identifiers — so a value
        interpolated into SQL text cannot carry quotes or statement breaks.
        """
        return bool(name) and all(c.isalnum() or c in "_$" for c in name)

    def connect(self):
        """Establish a connection to Snowflake.

        Returns:
            True on success, False if the connector is missing or the
            connection attempt fails.
        """
        if not SNOWFLAKE_AVAILABLE:
            return False

        try:
            self.connection = snowflake.connector.connect(
                account=self.config.get('account'),
                user=self.config.get('user'),
                password=self.config.get('password'),
                warehouse=self.config.get('warehouse'),
                database=self.config.get('database'),
                schema=self.config.get('schema'),
                role=self.config.get('role', 'ACCOUNTADMIN'),
            )
            self.cursor = self.connection.cursor()
            logger.info(f"Connected to Snowflake account: {self.config.get('account')}")
            return True
        except Exception as e:
            logger.error(f"Snowflake connection failed: {e}")
            return False

    def query(self, sql_query: str, params: Optional[tuple] = None) -> List[Dict]:
        """Execute a SQL query on Snowflake and return rows as dicts.

        Args:
            sql_query: SQL text. Use %s placeholders together with `params`
                for data values (the connector's default pyformat binding).
            params: Optional bind parameters; None preserves the original
                unparameterized behavior.

        Returns:
            A list of row dicts keyed by column name, or a single-element
            list containing an "error" dict on failure.
        """
        if not SNOWFLAKE_AVAILABLE:
            return [{"error": "Snowflake connector not available"}]

        # Lazily connect on first use.
        if not self.connection:
            if not self.connect():
                return [{"error": "Failed to connect to Snowflake"}]

        try:
            self.cursor.execute(sql_query, params)
            columns = [desc[0] for desc in self.cursor.description]
            results = self.cursor.fetchall()
            return [dict(zip(columns, row)) for row in results]
        except Exception as e:
            logger.error(f"Query error: {e}")
            return [{"error": str(e), "query": sql_query}]

    def get_tables(self) -> List[str]:
        """List all tables in the configured database/schema.

        Returns an empty list when database/schema are unconfigured or the
        database name is not a safe identifier.
        """
        database = self.config.get('database')
        schema = self.config.get('schema')
        if not database or not schema:
            return []

        # The database name is interpolated as an identifier (identifiers
        # cannot be bound), so validate it to prevent SQL injection; the
        # schema is a data value and is bound as a parameter.
        if not self._is_safe_identifier(database):
            logger.error(f"Rejected unsafe database identifier: {database!r}")
            return []

        sql = (
            f"SELECT TABLE_NAME FROM {database}.INFORMATION_SCHEMA.TABLES "
            "WHERE TABLE_SCHEMA = %s"
        )
        results = self.query(sql, (schema,))
        return [row['TABLE_NAME'] for row in results if 'TABLE_NAME' in row]

    def get_table_schema(self, table_name: str) -> List[Dict]:
        """Get column name/type/nullability for a table.

        Args:
            table_name: Table to describe; bound as a query parameter, so
                untrusted values (this arrives from tool arguments) are safe.
        """
        database = self.config.get('database')
        schema = self.config.get('schema')
        if not database or not schema:
            return []

        # Same identifier-vs-value split as get_tables().
        if not self._is_safe_identifier(database):
            logger.error(f"Rejected unsafe database identifier: {database!r}")
            return []

        sql = (
            "SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE "
            f"FROM {database}.INFORMATION_SCHEMA.COLUMNS "
            "WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s"
        )
        return self.query(sql, (schema, table_name))

    async def _execute_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Execute a Snowflake tool.

        Raises:
            ValueError: if `name` is not a registered tool.
        """
        if not self.config:
            return {"error": "Snowflake configuration not provided"}

        if name == "snowflake_query":
            sql = arguments.get("sql", "")
            return {"results": self.query(sql)}

        elif name == "snowflake_list_tables":
            return {"tables": self.get_tables()}

        elif name == "snowflake_get_schema":
            table_name = arguments.get("table_name")
            if not table_name:
                return {"error": "table_name is required"}
            return {"schema": self.get_table_schema(table_name)}

        else:
            raise ValueError(f"Unknown tool: {name}")

    def close(self):
        """Close cursor and connection; idempotent and swallows close errors."""
        if self.cursor:
            try:
                self.cursor.close()
            except Exception:
                pass
            self.cursor = None
        if self.connection:
            try:
                self.connection.close()
            except Exception:
                pass
            self.connection = None

    def __del__(self):
        """Best-effort cleanup; must never raise during interpreter teardown."""
        try:
            self.close()
        except Exception:
            pass
185
+
src/memory/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Memory management system."""
2
+