Spaces:
Sleeping
Sleeping
DevLujain
commited on
Commit
Β·
068aa4e
1
Parent(s):
0ff7caf
Deploy FYP dashboard
Browse files- .env +1 -0
- Dockerfile +19 -0
- requirements.txt +10 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/__pycache__/agent_orchestrator.cpython-310.pyc +0 -0
- src/__pycache__/api.cpython-310.pyc +0 -0
- src/__pycache__/hybrid_search.cpython-310.pyc +0 -0
- src/__pycache__/query_agent.cpython-310.pyc +0 -0
- src/__pycache__/rag_system.cpython-310.pyc +0 -0
- src/__pycache__/retrieval_agent.cpython-310.pyc +0 -0
- src/__pycache__/synthesis_agent.cpython-310.pyc +0 -0
- src/__pycache__/validation_agent.cpython-310.pyc +0 -0
- src/agent_orchestrator.py +211 -0
- src/api.py +178 -0
- src/dashboard.py +121 -0
- src/document_processor.py +109 -0
- src/hybrid_search.py +103 -0
- src/query_agent.py +95 -0
- src/rag_system.py +164 -0
- src/retrieval_agent.py +234 -0
- src/synthesis_agent.py +127 -0
- src/validation_agent.py +205 -0
- src/vector_database.py +149 -0
.env
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
GROQ_API_KEY=<REDACTED — a live key was committed in this file and exposed publicly; revoke it immediately and inject the key via deployment secrets instead of committing .env>
|
Dockerfile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
|
| 6 |
+
|
| 7 |
+
COPY requirements.txt .
|
| 8 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 9 |
+
|
| 10 |
+
# WARNING: this copies .env (which contains the GROQ_API_KEY secret) into the
# image — add a .dockerignore excluding .env and __pycache__/ before building.
COPY . .
|
| 11 |
+
|
| 12 |
+
RUN mkdir -p /app/data/vectordb
|
| 13 |
+
|
| 14 |
+
ENV PYTHONUNBUFFERED=1
|
| 15 |
+
ENV CHROMADB_PERSIST_DIR=/app/data/vectordb
|
| 16 |
+
|
| 17 |
+
EXPOSE 7860
|
| 18 |
+
|
| 19 |
+
CMD ["streamlit", "run", "src/dashboard.py", "--server.port=7860", "--server.address=0.0.0.0"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn[standard]
|
| 3 |
+
langchain
|
| 4 |
+
chromadb
|
| 5 |
+
sentence-transformers
|
| 6 |
+
groq
|
| 7 |
+
python-dotenv
|
| 8 |
+
PyPDF2
|
| 9 |
+
redis
|
| 10 |
+
rank-bm25
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (126 Bytes). View file
|
|
|
src/__pycache__/agent_orchestrator.cpython-310.pyc
ADDED
|
Binary file (6.09 kB). View file
|
|
|
src/__pycache__/api.cpython-310.pyc
ADDED
|
Binary file (4.8 kB). View file
|
|
|
src/__pycache__/hybrid_search.cpython-310.pyc
ADDED
|
Binary file (2.62 kB). View file
|
|
|
src/__pycache__/query_agent.cpython-310.pyc
ADDED
|
Binary file (3.03 kB). View file
|
|
|
src/__pycache__/rag_system.cpython-310.pyc
ADDED
|
Binary file (5.09 kB). View file
|
|
|
src/__pycache__/retrieval_agent.cpython-310.pyc
ADDED
|
Binary file (6.18 kB). View file
|
|
|
src/__pycache__/synthesis_agent.cpython-310.pyc
ADDED
|
Binary file (3.93 kB). View file
|
|
|
src/__pycache__/validation_agent.cpython-310.pyc
ADDED
|
Binary file (5.84 kB). View file
|
|
|
src/agent_orchestrator.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent Orchestrator
|
| 3 |
+
Connects all agents using LangGraph workflow
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from langgraph.graph import StateGraph, START, END
|
| 8 |
+
from typing import TypedDict, List, Dict
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
| 12 |
+
class AgentState(TypedDict):
    """State passed between agents.

    Every LangGraph node receives this full state, fills in its own
    slice, and returns it for the next node in the chain.
    """
    original_query: str              # raw user question, set once at entry
    reformulated_query: str          # output of the query-understanding agent
    retrieved_documents: List[Dict]  # dicts with 'content', 'source', 'score'
    synthesized_answer: str          # draft answer from the synthesis agent
    validation_result: Dict          # validator output; read for 'is_valid'/'confidence'
    final_answer: str                # finalized answer (currently a copy of the draft)
    metadata: Dict                   # timings/counters accumulated along the way
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class AgentOrchestrator:
    """Wires the RAG system's agents into a linear LangGraph pipeline:

        query understanding -> retrieval -> synthesis -> validation -> finalize

    The graph has no branching; each node mutates the shared AgentState.
    NOTE(review): emoji characters in the printed banners were mojibake in
    the committed source and are reproduced as-is — confirm intended glyphs.
    """

    def __init__(self, rag_system):
        """Initialize orchestrator with RAG system.

        rag_system is expected to expose query_agent, retrieval_agent,
        synthesis_agent and validation_agent attributes (used by the nodes
        below) — TODO confirm against RAGSystem.
        """
        print("π Initializing Agent Orchestrator...\n")

        self.rag = rag_system
        self.workflow = self._build_workflow()  # compiled LangGraph graph

        print("β Agent Orchestrator ready!\n")

    def _build_workflow(self):
        """Build LangGraph workflow (a strictly sequential chain)."""
        workflow = StateGraph(AgentState)

        # Define nodes — one per agent plus a finalization step
        workflow.add_node("query_understanding", self._query_understanding_node)
        workflow.add_node("retrieval", self._retrieval_node)
        workflow.add_node("synthesis", self._synthesis_node)
        workflow.add_node("validation", self._validation_node)
        workflow.add_node("finalize", self._finalize_node)

        # Define edges — straight line from START to END, no conditionals
        workflow.add_edge(START, "query_understanding")
        workflow.add_edge("query_understanding", "retrieval")
        workflow.add_edge("retrieval", "synthesis")
        workflow.add_edge("synthesis", "validation")
        workflow.add_edge("validation", "finalize")
        workflow.add_edge("finalize", END)

        return workflow.compile()

    def _query_understanding_node(self, state: AgentState) -> AgentState:
        """Query Understanding Agent Node: rewrite the raw user query."""
        print("\n" + "=" * 70)
        print("π§ AGENT 1: QUERY UNDERSTANDING")
        print("=" * 70)

        original_query = state["original_query"]
        reformulated_query = self.rag.query_agent.reformulate_query(original_query)

        state["reformulated_query"] = reformulated_query
        # Placeholder: the time is hard-coded to 0, not actually measured.
        state["metadata"]["query_understanding_time"] = 0

        return state

    def _retrieval_node(self, state: AgentState) -> AgentState:
        """Multi-Source Retrieval Agent Node: fetch the top-5 documents."""
        print("\n" + "=" * 70)
        print("π AGENT 2: MULTI-SOURCE RETRIEVAL")
        print("=" * 70)

        reformulated_query = state["reformulated_query"]
        retrieved_results = self.rag.retrieval_agent.retrieve(reformulated_query, top_k=5)

        # Convert to the uniform document format downstream nodes expect
        documents = []
        for result in retrieved_results:
            documents.append({
                'content': result['content'],
                'source': result.get('source', 'unknown'),
                'score': result['score']
            })

        state["retrieved_documents"] = documents
        state["metadata"]["num_documents_retrieved"] = len(documents)

        return state

    def _synthesis_node(self, state: AgentState) -> AgentState:
        """Synthesis Agent Node: draft an answer from retrieved documents."""
        print("\n" + "=" * 70)
        print("𧬠AGENT 3: SYNTHESIS")
        print("=" * 70)

        # Synthesis is given the ORIGINAL query, not the reformulated one.
        original_query = state["original_query"]
        documents = state["retrieved_documents"]

        synthesized_answer = self.rag.synthesis_agent.synthesize(
            original_query,
            documents
        )

        state["synthesized_answer"] = synthesized_answer

        return state

    def _validation_node(self, state: AgentState) -> AgentState:
        """Validation Agent Node: score the draft against its sources."""
        print("\n" + "=" * 70)
        print("β AGENT 4: VALIDATION")
        print("=" * 70)

        answer = state["synthesized_answer"]
        documents = state["retrieved_documents"]

        validation_result = self.rag.validation_agent.validate(answer, documents)

        state["validation_result"] = validation_result

        return state

    def _finalize_node(self, state: AgentState) -> AgentState:
        """Finalize and format response.

        Currently a pass-through: the synthesized draft becomes the final
        answer with no post-processing.
        """
        print("\n" + "=" * 70)
        print("π FINALIZATION")
        print("=" * 70 + "\n")

        state["final_answer"] = state["synthesized_answer"]

        return state

    def run(self, query: str) -> Dict:
        """Run complete agent orchestration workflow.

        Returns the final AgentState dict produced by the graph; also
        prints a human-readable summary to stdout.
        """
        print("\n" + "=" * 80)
        print("π MULTI-AGENT ORCHESTRATION WORKFLOW")
        print("=" * 80)
        print(f"\nINPUT QUERY: {query}\n")

        # Initialize state with empty slots for every downstream field
        initial_state = {
            "original_query": query,
            "reformulated_query": "",
            "retrieved_documents": [],
            "synthesized_answer": "",
            "validation_result": {},
            "final_answer": "",
            "metadata": {}
        }

        # Run workflow
        final_state = self.workflow.invoke(initial_state)

        # Format and display results
        self._display_results(final_state)

        return final_state

    def _display_results(self, state: AgentState):
        """Pretty-print the final state to stdout."""
        print("\n" + "=" * 80)
        print("π― FINAL RESULTS")
        print("=" * 80 + "\n")

        print("ORIGINAL QUERY:")
        print(f" {state['original_query']}\n")

        print("REFORMULATED QUERY:")
        print(f" {state['reformulated_query']}\n")

        print("ANSWER:")
        print("-" * 80)
        print(state['final_answer'])
        print("-" * 80 + "\n")

        # NOTE(review): assumes validate() always returns 'is_valid' and
        # 'confidence' keys — a missing key raises KeyError here. Confirm
        # against ValidationAgent.
        validation = state['validation_result']
        print("VALIDATION:")
        print(f" Status: {'β VALID' if validation['is_valid'] else 'β οΈ NEEDS REVIEW'}")
        print(f" Confidence: {validation['confidence']}%\n")

        print("SOURCES:")
        for i, doc in enumerate(state['retrieved_documents'], 1):
            print(f" {i}. {doc['source']} (relevance: {doc['score']:.2%})")

        print("\n" + "=" * 80 + "\n")
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# Test the orchestrator
|
| 190 |
+
if __name__ == "__main__":
    # Smoke test: run a few representative queries end-to-end.
    from rag_system import RAGSystem

    # Key comes from .env via load_dotenv() at module import time.
    api_key = os.getenv("GROQ_API_KEY")

    # Initialize RAG system
    print("Initializing RAG System...")
    rag = RAGSystem(groq_api_key=api_key)

    # Initialize orchestrator
    orchestrator = AgentOrchestrator(rag)

    # Test queries
    test_queries = [
        "How do I create a FastAPI endpoint?",
        "What is the leave policy?",
        "Tell me about remote work"
    ]

    for query in test_queries:
        # run() prints its own report; the returned state is not inspected here.
        result = orchestrator.run(query)
        print("\n" + "=" * 80 + "\n")
|
src/api.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI REST API Service
|
| 3 |
+
Exposes the multi-agent knowledge system
|
| 4 |
+
"""
|
| 5 |
+
from fastapi import FastAPI, HTTPException
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
from typing import List, Dict
|
| 8 |
+
import time
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
from .rag_system import RAGSystem
|
| 13 |
+
from agent_orchestrator import AgentOrchestrator
|
| 14 |
+
|
| 15 |
+
load_dotenv()
|
| 16 |
+
|
| 17 |
+
# Initialize FastAPI
|
| 18 |
+
app = FastAPI(
|
| 19 |
+
title="Multi-Agent Knowledge System",
|
| 20 |
+
description="RAG system with query understanding, retrieval, synthesis, and validation",
|
| 21 |
+
version="1.0.0"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Initialize RAG system
|
| 25 |
+
api_key = os.getenv("GROQ_API_KEY")
|
| 26 |
+
rag_system = RAGSystem(groq_api_key=api_key)
|
| 27 |
+
|
| 28 |
+
# Request/Response Models
|
| 29 |
+
class QueryRequest(BaseModel):
|
| 30 |
+
query: str
|
| 31 |
+
top_k: int = 5
|
| 32 |
+
|
| 33 |
+
class SourceDocument(BaseModel):
|
| 34 |
+
source: str
|
| 35 |
+
relevance: float
|
| 36 |
+
|
| 37 |
+
class ValidationInfo(BaseModel):
|
| 38 |
+
status: str
|
| 39 |
+
confidence: int
|
| 40 |
+
|
| 41 |
+
class QueryResponse(BaseModel):
|
| 42 |
+
query: str
|
| 43 |
+
reformulated_query: str
|
| 44 |
+
answer: str
|
| 45 |
+
validation: ValidationInfo
|
| 46 |
+
sources: List[SourceDocument]
|
| 47 |
+
processing_time: float
|
| 48 |
+
|
| 49 |
+
class HealthResponse(BaseModel):
|
| 50 |
+
status: str
|
| 51 |
+
model_loaded: bool
|
| 52 |
+
db_connected: bool
|
| 53 |
+
timestamp: str
|
| 54 |
+
|
| 55 |
+
class MetricsResponse(BaseModel):
|
| 56 |
+
total_queries: int
|
| 57 |
+
avg_latency: float
|
| 58 |
+
avg_confidence: float
|
| 59 |
+
|
| 60 |
+
# Global metrics
|
| 61 |
+
metrics = {
|
| 62 |
+
"total_queries": 0,
|
| 63 |
+
"latencies": [],
|
| 64 |
+
"confidences": []
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
# Health check endpoint
|
| 68 |
+
@app.get("/health", response_model=HealthResponse)
|
| 69 |
+
async def health_check():
|
| 70 |
+
"""Check system health"""
|
| 71 |
+
from datetime import datetime
|
| 72 |
+
|
| 73 |
+
return HealthResponse(
|
| 74 |
+
status="healthy",
|
| 75 |
+
model_loaded=True,
|
| 76 |
+
db_connected=True,
|
| 77 |
+
timestamp=datetime.now().isoformat()
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
# Main query endpoint
|
| 81 |
+
@app.post("/query", response_model=QueryResponse)
|
| 82 |
+
async def query(request: QueryRequest):
|
| 83 |
+
"""
|
| 84 |
+
Process a query through the multi-agent system
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
query: User query
|
| 88 |
+
top_k: Number of documents to retrieve
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
QueryResponse with answer, sources, and validation
|
| 92 |
+
"""
|
| 93 |
+
try:
|
| 94 |
+
start_time = time.time()
|
| 95 |
+
|
| 96 |
+
# Store top_k in rag_system temporarily
|
| 97 |
+
original_top_k = 5
|
| 98 |
+
|
| 99 |
+
# Run orchestrator
|
| 100 |
+
result = rag_system.answer_question(request.query)
|
| 101 |
+
|
| 102 |
+
# Extract data
|
| 103 |
+
processing_time = time.time() - start_time
|
| 104 |
+
|
| 105 |
+
# Format sources
|
| 106 |
+
sources = []
|
| 107 |
+
for doc in result.get("retrieved_documents", []):
|
| 108 |
+
sources.append(SourceDocument(
|
| 109 |
+
source=doc["source"],
|
| 110 |
+
relevance=doc["score"]
|
| 111 |
+
))
|
| 112 |
+
|
| 113 |
+
# Format validation
|
| 114 |
+
validation_info = result.get("validation_result", {})
|
| 115 |
+
validation = ValidationInfo(
|
| 116 |
+
status="β
VALID" if validation_info.get("is_valid") else "β οΈ NEEDS REVIEW",
|
| 117 |
+
confidence=validation_info.get("confidence", 0)
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Update metrics
|
| 121 |
+
metrics["total_queries"] += 1
|
| 122 |
+
metrics["latencies"].append(processing_time)
|
| 123 |
+
metrics["confidences"].append(validation.confidence)
|
| 124 |
+
|
| 125 |
+
# Build response
|
| 126 |
+
response = QueryResponse(
|
| 127 |
+
query=result.get("original_query", ""),
|
| 128 |
+
reformulated_query=result.get("reformulated_query", ""),
|
| 129 |
+
answer=result.get("final_answer", ""),
|
| 130 |
+
validation=validation,
|
| 131 |
+
sources=sources,
|
| 132 |
+
processing_time=processing_time
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
return response
|
| 136 |
+
|
| 137 |
+
except Exception as e:
|
| 138 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 139 |
+
|
| 140 |
+
# Metrics endpoint
|
| 141 |
+
@app.get("/metrics", response_model=MetricsResponse)
|
| 142 |
+
async def get_metrics():
|
| 143 |
+
"""Get system metrics"""
|
| 144 |
+
avg_latency = sum(metrics["latencies"]) / len(metrics["latencies"]) if metrics["latencies"] else 0
|
| 145 |
+
avg_confidence = sum(metrics["confidences"]) / len(metrics["confidences"]) if metrics["confidences"] else 0
|
| 146 |
+
|
| 147 |
+
return MetricsResponse(
|
| 148 |
+
total_queries=metrics["total_queries"],
|
| 149 |
+
avg_latency=avg_latency,
|
| 150 |
+
avg_confidence=avg_confidence
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Root endpoint
|
| 154 |
+
@app.get("/")
|
| 155 |
+
async def root():
|
| 156 |
+
"""Root endpoint"""
|
| 157 |
+
return {
|
| 158 |
+
"message": "Multi-Agent Knowledge System API",
|
| 159 |
+
"version": "1.0.0",
|
| 160 |
+
"endpoints": {
|
| 161 |
+
"health": "/health",
|
| 162 |
+
"query": "/query (POST)",
|
| 163 |
+
"metrics": "/metrics",
|
| 164 |
+
"docs": "/docs"
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
if __name__ == "__main__":
    # Local development entry point; production should run uvicorn directly.
    import uvicorn

    divider = "=" * 70
    print("\n" + divider)
    print("π Starting Multi-Agent Knowledge System API")
    print(divider)
    print("π API running at: http://localhost:8000")
    print("π Documentation at: http://localhost:8000/docs")
    print(divider + "\n")

    uvicorn.run(app, host="0.0.0.0", port=8000)
|
src/dashboard.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Streamlit monitoring dashboard for the multi-agent knowledge system.

Talks to the FastAPI service (URL configurable in the sidebar) and shows
health, ad-hoc test queries, and a per-session query history.
"""
import streamlit as st
import requests
import json
from datetime import datetime

st.set_page_config(page_title="FYP Monitoring", layout="wide")
st.title("π Multi-Agent Knowledge System - Monitoring Dashboard")

st.sidebar.header("βοΈ Controls")
api_url = st.sidebar.text_input("API URL", "http://localhost:8000")

# BUG FIX: query_history was initialised at the bottom of the script but
# appended-to nowhere, so the "Query History" section was permanently empty.
# Initialise it up front so the send-query handler below can record into it.
if 'query_history' not in st.session_state:
    st.session_state.query_history = []

# ====== METRICS SECTION ======
st.header("π System Metrics")
col1, col2, col3, col4 = st.columns(4)

try:
    health = requests.get(f"{api_url}/health").json()
    with col1:
        st.metric("API Status", health.get("status", "unknown"))
except Exception as e:
    with col1:
        st.error("API Down")

# Static deployment facts (not queried from the API)
with col2:
    st.metric("Region", "Singapore")
with col3:
    st.metric("Runtime", "Docker")
with col4:
    st.metric("Model", "Mixtral 8x7B")

st.divider()

# ====== TEST QUERIES SECTION ======
st.header("π§ͺ Test Queries")
query = st.text_input("Enter a query:", "What is FastAPI?", key="query_input")

if st.button("Send Query", key="send_button_unique"):
    try:
        with st.spinner("β³ Processing your query..."):
            response = requests.post(
                f"{api_url}/query",
                json={"query": query},
                timeout=30
            ).json()

        st.session_state.last_response = response
        # BUG FIX: record successful queries so the history section works.
        st.session_state.query_history.append({
            "query": query,
            "time": datetime.now().strftime("%H:%M:%S")
        })
        st.success("β Query processed!")

    except requests.exceptions.ConnectionError:
        st.error("β Cannot connect to API. Check the URL above.")
    except requests.exceptions.Timeout:
        st.error("β Request timed out. API is taking too long.")
    except json.JSONDecodeError:
        st.error("β API returned invalid JSON. Check if API is running.")
    except Exception as e:
        st.error(f"β Error: {str(e)}")

# Display response if it exists (persists across reruns via session state)
if 'last_response' in st.session_state:
    response = st.session_state.last_response

    # Display Answer
    st.subheader("π Answer")
    st.write(response.get("answer", "No answer available"))

    # Display Confidence & Time
    col1, col2, col3 = st.columns(3)
    with col1:
        confidence = response.get("validation", {}).get("confidence", 0)
        st.metric("Confidence", f"{confidence}%")
    with col2:
        status = response.get("validation", {}).get("status", "Unknown")
        st.metric("Status", status)
    with col3:
        st.metric("Sources Found", len(response.get("sources", [])))

    # Display Sources
    if response.get("sources"):
        st.subheader("π Retrieved Sources")
        for i, source in enumerate(response.get("sources", []), 1):
            st.write(f"**{i}. {source['source']}** - Relevance: {source['relevance']:.0%}")

    # Show raw response in expander
    with st.expander("π Show Raw Response"):
        st.json(response)

# ====== SYSTEM HEALTH ======
st.header("π₯ System Health")
col1, col2 = st.columns(2)

with col1:
    try:
        health = requests.get(f"{api_url}/health", timeout=5).json()
        st.success(f"β API Status: {health.get('status', 'unknown').upper()}")
        st.json(health)
    except Exception as e:
        st.error(f"β API is down: {str(e)}")

with col2:
    # TYPO FIX: "sidconfidenceebar" -> "sidebar".
    st.info("π‘ Tips:\n- Change API URL in sidebar\n- Check Render logs if API fails\n- Use http://localhost:8000 for local testing")

st.divider()

# ====== QUERY HISTORY ======
st.header("π Query History")

# Clear history button
if st.button("Clear History", key="clear_history_button"):
    st.session_state.query_history = []
    st.success("β History cleared!")

# Display history: most recent first, capped at the last 10 entries
if st.session_state.query_history:
    for i, item in enumerate(reversed(st.session_state.query_history[-10:]), 1):
        st.write(f"{i}. **{item['query']}** - {item['time']}")
else:
    st.write("No queries yet")
|
| 121 |
+
|
src/document_processor.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
class DocumentProcessor:
    """Walks a folder tree of markdown/text files, cleans each one, and
    saves the result as a single JSON corpus for downstream indexing."""

    def __init__(self, input_folder, output_folder):
        self.input_folder = input_folder    # root folder scanned recursively
        self.output_folder = output_folder  # destination for the JSON output
        self.documents = []                 # accumulated document dicts

    # Extract text from markdown or text files
    def extract_text(self, file_path):
        """Return the file's text, or None if it cannot be read.

        Undecodable bytes are silently dropped (errors='ignore').
        """
        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
                return file.read()
        except Exception as e:
            print(f" β Error reading {file_path}: {e}")
            return None

    # Clean the text
    def clean_text(self, text):
        """Collapse all whitespace to single spaces and strip NUL bytes."""
        # Collapse every run of whitespace (this removes all newlines, so
        # the original follow-up replace('\n\n\n', '\n') was a no-op and
        # has been dropped).
        text = ' '.join(text.split())
        text = text.replace('\x00', '')
        return text.strip()

    # Process all documents
    def process_all_documents(self):
        """Walk input_folder, clean every .md/.txt file, collect results.

        Returns the list of document dicts (also kept on self.documents).
        """
        doc_id = 1

        # Walk through all folders
        for root, dirs, files in os.walk(self.input_folder):
            for filename in files:
                # Only process markdown and text files
                if filename.endswith(('.md', '.txt')):
                    filepath = os.path.join(root, filename)
                    # BUG FIX: the committed code printed the literal
                    # "(unknown)" here; restored the intended file path.
                    print(f"Processing: {filepath}")

                    # Extract text
                    text = self.extract_text(filepath)
                    if not text:
                        print(f" β Failed to extract: {filepath}")
                        continue

                    # Clean the text
                    cleaned = self.clean_text(text)

                    # Skip near-empty files
                    if len(cleaned) < 50:
                        print(f" β οΈ Too short, skipping")
                        continue

                    # Create document object
                    document = {
                        "doc_id": f"doc_{doc_id}",
                        "title": filename.replace('.md', '').replace('.txt', ''),
                        "content": cleaned,
                        "word_count": len(cleaned.split()),
                        "character_count": len(cleaned),
                        "processed_date": datetime.now().isoformat(),
                        "source_file": filename,
                        "source_path": filepath
                    }

                    self.documents.append(document)
                    doc_id += 1
                    print(f" β Processed ({len(cleaned)} chars)")

        return self.documents

    # Save to JSON file
    def save_documents(self):
        """Write all collected documents to processed_documents.json and
        print summary statistics."""
        output_path = os.path.join(self.output_folder, "processed_documents.json")

        # Create output folder if doesn't exist
        os.makedirs(self.output_folder, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(self.documents, f, ensure_ascii=False, indent=2)

        print(f"\nβ Saved {len(self.documents)} documents to {output_path}")

        # Print statistics
        total_words = sum(doc['word_count'] for doc in self.documents)
        total_chars = sum(doc['character_count'] for doc in self.documents)

        print(f"\nπ STATISTICS:")
        print(f" Total documents: {len(self.documents)}")
        print(f" Total words: {total_words:,}")
        print(f" Total characters: {total_chars:,}")
        # Guard against division by zero when no documents were processed
        print(f" Average words per document: {total_words // len(self.documents) if self.documents else 0}")
|
| 98 |
+
|
| 99 |
+
# Use it
|
| 100 |
+
if __name__ == "__main__":
    # Paths are relative to the working directory the script is run from.
    processor = DocumentProcessor(
        input_folder="data/raw",
        output_folder="data/processed"
    )

    print("π Starting document processing...\n")
    processor.process_all_documents()
    processor.save_documents()
    print("\nβ Document processing complete!")
|
src/hybrid_search.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hybrid Search: Combines Vector Search + BM25 Sparse Retrieval
|
| 3 |
+
"""
|
| 4 |
+
from rank_bm25 import BM25Okapi
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
class HybridSearch:
    """Keyword retrieval over a fixed document list using BM25.

    The vector half of the "hybrid" search is supplied by the caller via
    hybrid_search(); this class itself only builds and queries the BM25
    index.
    """

    def __init__(self, documents):
        """
        Initialize BM25 index.

        documents: list of document texts (plain strings); tokenization is
        a naive lowercase whitespace split — no stemming or punctuation
        handling.
        """
        print("π Building BM25 index...")

        # Tokenize documents
        self.tokenized_docs = [doc.lower().split() for doc in documents]
        self.documents = documents

        # Create BM25 index
        self.bm25 = BM25Okapi(self.tokenized_docs)
        print(f"β BM25 index created for {len(documents)} documents\n")
|
| 22 |
+
|
| 23 |
+
    def bm25_search(self, query, top_k=5):
        """Search using BM25 (keyword matching).

        Returns up to top_k dicts with 'index', 'score', 'content',
        ordered by descending BM25 score. Note that 'index' is a numpy
        integer (from argsort), not a Python int.
        """
        # Same naive tokenization as the index was built with
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)

        # Get top-k indices (argsort ascending, reversed for descending)
        top_indices = np.argsort(scores)[::-1][:top_k]

        results = []
        for idx in top_indices:
            results.append({
                'index': idx,
                'score': scores[idx],
                'content': self.documents[idx]
            })

        return results
|
| 40 |
+
|
| 41 |
+
def hybrid_search(self, query, vector_results, top_k=5):
|
| 42 |
+
"""
|
| 43 |
+
Combine vector search + BM25 results
|
| 44 |
+
Uses Reciprocal Rank Fusion (RRF)
|
| 45 |
+
"""
|
| 46 |
+
print(f"π Performing hybrid search for: '{query}'\n")
|
| 47 |
+
|
| 48 |
+
# Get BM25 results
|
| 49 |
+
bm25_results = self.bm25_search(query, top_k)
|
| 50 |
+
|
| 51 |
+
# Normalize and combine scores (simple average)
|
| 52 |
+
combined_scores = {}
|
| 53 |
+
|
| 54 |
+
# Add vector scores
|
| 55 |
+
for vec_result in vector_results:
|
| 56 |
+
doc_id = vec_result.get('index', 0)
|
| 57 |
+
combined_scores[doc_id] = {
|
| 58 |
+
'vector_score': vec_result['score'],
|
| 59 |
+
'bm25_score': 0,
|
| 60 |
+
'content': vec_result['content']
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
# Add BM25 scores
|
| 64 |
+
for bm25_result in bm25_results:
|
| 65 |
+
doc_id = bm25_result['index']
|
| 66 |
+
if doc_id not in combined_scores:
|
| 67 |
+
combined_scores[doc_id] = {
|
| 68 |
+
'vector_score': 0,
|
| 69 |
+
'bm25_score': 0,
|
| 70 |
+
'content': bm25_result['content']
|
| 71 |
+
}
|
| 72 |
+
combined_scores[doc_id]['bm25_score'] = bm25_result['score']
|
| 73 |
+
|
| 74 |
+
# Calculate combined score (weighted average)
|
| 75 |
+
for doc_id in combined_scores:
|
| 76 |
+
vector_score = combined_scores[doc_id]['vector_score']
|
| 77 |
+
bm25_score = combined_scores[doc_id]['bm25_score'] / 100 # Normalize
|
| 78 |
+
|
| 79 |
+
# Weighted combination
|
| 80 |
+
combined_scores[doc_id]['combined_score'] = (
|
| 81 |
+
0.6 * vector_score + # 60% weight to vector
|
| 82 |
+
0.4 * bm25_score # 40% weight to BM25
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Sort by combined score
|
| 86 |
+
sorted_results = sorted(
|
| 87 |
+
combined_scores.items(),
|
| 88 |
+
key=lambda x: x[1]['combined_score'],
|
| 89 |
+
reverse=True
|
| 90 |
+
)[:top_k]
|
| 91 |
+
|
| 92 |
+
results = []
|
| 93 |
+
for doc_id, scores_info in sorted_results:
|
| 94 |
+
results.append({
|
| 95 |
+
'index': doc_id,
|
| 96 |
+
'content': scores_info['content'],
|
| 97 |
+
'vector_score': scores_info['vector_score'],
|
| 98 |
+
'bm25_score': scores_info['bm25_score'],
|
| 99 |
+
'combined_score': scores_info['combined_score']
|
| 100 |
+
})
|
| 101 |
+
|
| 102 |
+
print(f"β
Hybrid search returned {len(results)} results\n")
|
| 103 |
+
return results
|
src/query_agent.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Query Understanding Agent
|
| 3 |
+
Reformulates vague/ambiguous queries into precise search queries
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from groq import Groq
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
class QueryUnderstandingAgent:
    """Turns vague user questions into precise, retrieval-friendly queries."""

    def __init__(self, groq_api_key=None):
        """Set up the Groq client and the reformulation system prompt."""
        print("🧠 Initializing Query Understanding Agent...\n")

        self.groq_client = Groq(api_key=groq_api_key)
        self.model = "llama-3.3-70b-versatile"

        self.system_prompt = """You are a query reformulation expert. Your task is to take vague or ambiguous user queries and reformulate them into precise, specific search queries that will retrieve the most relevant information.

Guidelines:
1. Expand acronyms (e.g., "API" → "Application Programming Interface")
2. Add context when needed
3. Break down complex multi-part questions into clear components
4. Make implicit requirements explicit
5. Keep reformulated query concise but comprehensive

Examples:
- Vague: "How do I make an API?"
  Reformulated: "How do I create a REST API endpoint using FastAPI?"

- Vague: "What about leave?"
  Reformulated: "What is the employee leave policy and how do I request leave?"

- Vague: "Remote work stuff"
  Reformulated: "What are the remote work policies and guidelines?"

Return ONLY the reformulated query, nothing else."""

    def reformulate_query(self, user_query):
        """Return a sharpened version of user_query (or the original on error)."""
        print(f"🔍 Original query: '{user_query}'")

        chat_messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_query},
        ]

        try:
            response = self.groq_client.chat.completions.create(
                messages=chat_messages,
                model=self.model,
                temperature=0.3,  # Lower temp for consistency
                max_tokens=200
            )
            reformulated = response.choices[0].message.content.strip()
            print(f"✨ Reformulated: '{reformulated}'\n")
            return reformulated
        except Exception as e:
            # Any API failure degrades gracefully to the untouched query.
            print(f"❌ Error reformulating query: {e}\n")
            return user_query
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Test the agent
|
| 72 |
+
if __name__ == "__main__":
    from dotenv import load_dotenv
    import os

    # Smoke-test the agent against a handful of deliberately vague queries.
    load_dotenv()
    groq_key = os.getenv("GROQ_API_KEY")

    reformulator = QueryUnderstandingAgent(groq_api_key=groq_key)

    sample_queries = [
        "How do I make an API?",
        "What about leave?",
        "Remote work stuff",
        "How to get docs?",
        "Tell me about policies",
    ]

    banner = "=" * 70
    print(banner)
    print("🧠 QUERY UNDERSTANDING AGENT TEST")
    print(banner + "\n")

    for vague_query in sample_queries:
        reformulator.reformulate_query(vague_query)
        print("-" * 70 + "\n")
|
src/rag_system.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import chromadb
|
| 3 |
+
from sentence_transformers import SentenceTransformer
|
| 4 |
+
from groq import Groq
|
| 5 |
+
from .hybrid_search import HybridSearch
|
| 6 |
+
from query_agent import QueryUnderstandingAgent
|
| 7 |
+
from retrieval_agent import RetrievalAgent
|
| 8 |
+
from synthesis_agent import SynthesisAgent
|
| 9 |
+
from validation_agent import ValidationAgent
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RAGSystem:
    """Retrieval-Augmented Generation pipeline backed by ChromaDB + Groq.

    Wires together the persistent vector store, the embedding model,
    hybrid (vector + BM25) search, and the specialised agents (query
    understanding, retrieval, synthesis, validation). The orchestrator
    that drives the full workflow is created lazily on the first call to
    answer_question().
    """

    def __init__(self, db_path="data/vectordb", groq_api_key=None):
        """
        Initialize the RAG system.

        db_path: directory of the persistent ChromaDB store
        groq_api_key: API key for the Groq chat-completion service
        """
        print("🚀 Initializing RAG System with Groq...\n")

        # Persistent vector store; the collection is created on first use.
        self.client = chromadb.PersistentClient(path=db_path)
        self.collection = self.client.get_or_create_collection(name="documents")

        # Embedding model used for query-time dense retrieval.
        print("📦 Loading embedding model...")
        self.model = SentenceTransformer("all-MiniLM-L6-v2")
        print("✅ Model loaded!\n")

        # LLM client used for answer generation.
        self.groq_client = Groq(api_key=groq_api_key)
        self.model_name = "llama-3.3-70b-versatile"  # Fast and good quality

        # Hybrid (BM25 + vector) search over the current corpus.
        # NOTE(review): BM25 indexing assumes the collection is non-empty —
        # confirm behaviour on a fresh/empty database.
        print("🔍 Setting up hybrid search...")
        all_docs = [doc['content'] for doc in self.get_all_documents()]
        self.hybrid_search = HybridSearch(all_docs)
        print("✅ Hybrid search ready!\n")

        # Specialised agents for each stage of the workflow.
        print("🧠 Setting up Query Understanding Agent...")
        self.query_agent = QueryUnderstandingAgent(groq_api_key=groq_api_key)
        print("✅ Query Agent ready!\n")

        print("🔍 Setting up Multi-Source Retrieval Agent...")
        self.retrieval_agent = RetrievalAgent(self.collection, groq_api_key=groq_api_key)
        print("✅ Retrieval Agent ready!\n")

        print("🧬 Setting up Synthesis Agent...")
        self.synthesis_agent = SynthesisAgent(groq_api_key=groq_api_key)
        print("✅ Synthesis Agent ready!\n")

        print("✅ Setting up Validation Agent...")
        self.validation_agent = ValidationAgent(groq_api_key=groq_api_key)
        print("✅ Validation Agent ready!\n")

    def retrieve_documents(self, query, top_k=5):
        """Return up to top_k semantically similar documents for the query."""
        print(f"🔍 Retrieving documents for: '{query}'")

        query_embedding = self.model.encode([query])[0]

        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=top_k
        )

        retrieved_docs = []
        if results and results['documents']:
            for i, doc in enumerate(results['documents'][0]):
                retrieved_docs.append({
                    'content': doc,
                    # .get() guards against chunks ingested without metadata.
                    'source': (results['metadatas'][0][i] or {}).get('source_file', 'unknown'),
                    # NOTE(review): assumes distances behave like 1-similarity
                    # in [0, 1]; with L2 distance this can go negative.
                    'score': 1 - results['distances'][0][i]
                })

        print(f"✅ Retrieved {len(retrieved_docs)} documents\n")
        return retrieved_docs

    def format_context(self, documents):
        """Format retrieved documents as a context string for the LLM."""
        context = "## RETRIEVED DOCUMENTS:\n\n"

        for i, doc in enumerate(documents, 1):
            context += f"[Document {i}] (Source: {doc['source']})\n"
            # Truncate to keep the prompt within the model's budget.
            context += f"{doc['content'][:500]}...\n\n"

        return context

    def query_groq(self, prompt):
        """Send a single-turn prompt to Groq and return the model's reply."""
        print("🤖 Generating answer with Groq...\n")

        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model=self.model_name,
                temperature=0.7,
                max_tokens=1500
            )
            return chat_completion.choices[0].message.content
        except Exception as e:
            # Errors are surfaced as text so callers always get a string.
            return f"❌ Error with Groq API: {e}"

    def get_all_documents(self):
        """Return every document in the collection with index and source."""
        results = self.collection.get()
        docs = []
        metadatas = results.get('metadatas') or []
        for i, doc in enumerate(results['documents']):
            # Guard both a short/missing metadata list and a missing
            # 'source_file' key (the original indexed unconditionally).
            meta = metadatas[i] if i < len(metadatas) and metadatas[i] else {}
            docs.append({
                'index': i,
                'content': doc,
                'source': meta.get('source_file', 'unknown')
            })
        return docs

    def answer_question(self, query):
        """Run the full agent workflow for a query (orchestrator is lazy)."""
        if not hasattr(self, 'orchestrator'):
            # Imported lazily to avoid a circular import at module load time.
            from agent_orchestrator import AgentOrchestrator
            self.orchestrator = AgentOrchestrator(self)

        return self.orchestrator.run(query)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# Main execution
|
| 135 |
+
if __name__ == "__main__":
    import os

    banner = "=" * 70
    print(banner)
    print("🚀 RAG SYSTEM WITH GROQ API")
    print(banner + "\n")

    # The API key must come from the environment; bail out with help if not.
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        print("❌ Error: GROQ_API_KEY environment variable not set")
        print("\nTo set it, run:")
        print(' export GROQ_API_KEY="your_key_here"')
        print("\nThen run this script again")
        exit(1)

    rag = RAGSystem(groq_api_key=groq_api_key)

    demo_questions = [
        "How do I create a FastAPI endpoint?",
        "What is the employee leave policy?",
        "How can I work remotely?",
    ]

    for question in demo_questions:
        rag.answer_question(question)
        print("\n")
|
src/retrieval_agent.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-Source Retrieval Agent
|
| 3 |
+
Intelligently decides which sources to query based on query type
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from groq import Groq
|
| 8 |
+
from hybrid_search import HybridSearch
|
| 9 |
+
from sentence_transformers import SentenceTransformer
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
class RetrievalAgent:
    """Multi-source retrieval agent: dense vector search + BM25 keywords.

    The LLM query classification is currently informational (logged) only;
    both retrieval sources are always queried, then results are merged,
    deduplicated and ranked by score.
    """

    def __init__(self, chromadb_collection, groq_api_key=None):
        """
        chromadb_collection: existing ChromaDB collection to search
        groq_api_key: API key for the Groq classification model
        """
        print("🔍 Initializing Multi-Source Retrieval Agent...\n")

        self.groq_client = Groq(api_key=groq_api_key)
        self.model_name = "llama-3.3-70b-versatile"
        self.collection = chromadb_collection
        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

        # Build the BM25 keyword index over the full corpus once.
        all_docs = self._get_all_documents()
        self.hybrid_search = HybridSearch(all_docs)

        self.classification_prompt = """Analyze this query and classify it:

QUERY: "{query}"

Determine:
1. Query Type: factual, conceptual, procedural, comparative
2. Information Need: general knowledge, specific details, step-by-step guide, comparison
3. Search Strategy: broad (many results), narrow (specific results), mixed

Respond in this format ONLY:
TYPE: [type]
NEED: [need]
STRATEGY: [strategy]"""

        print("✅ Retrieval Agent ready!\n")

    def _get_all_documents(self):
        """Return all document texts, or [] if the collection is unreadable."""
        try:
            return list(self.collection.get()['documents'])
        except Exception:
            # Narrowed from a bare `except:`; startup degrades to an empty
            # BM25 corpus instead of crashing.
            return []

    def classify_query(self, query):
        """Ask the LLM to classify the query; fall back to 'mixed' on error."""
        print(f"🔍 Classifying query: '{query}'")

        try:
            response = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": self.classification_prompt.format(query=query)
                    }
                ],
                model=self.model_name,
                temperature=0.3,
                max_tokens=100
            )

            classification = response.choices[0].message.content.strip()
            print(f"✅ Classification:\n{classification}\n")
            return classification

        except Exception as e:
            print(f"❌ Classification error: {e}\n")
            return "TYPE: mixed\nNEED: general\nSTRATEGY: mixed"

    def vector_search(self, query, top_k=5):
        """Dense (semantic) search via ChromaDB; returns [] on error."""
        print(" 🔍 Performing vector search...")

        try:
            query_embedding = self.embedding_model.encode([query])[0]

            results = self.collection.query(
                query_embeddings=[query_embedding.tolist()],
                n_results=top_k
            )

            vector_results = []
            if results and results['documents']:
                for i, doc in enumerate(results['documents'][0]):
                    vector_results.append({
                        # NOTE(review): this is the rank within the query
                        # result, not a corpus-wide document id.
                        'index': i,
                        'content': doc,
                        # .get() guards chunks ingested without metadata.
                        'source': (results['metadatas'][0][i] or {}).get('source_file', 'unknown'),
                        'score': 1 - results['distances'][0][i],
                        'method': 'vector_search'
                    })

            print(f" ✓ Found {len(vector_results)} results via vector search")
            return vector_results

        except Exception as e:
            print(f" ❌ Vector search error: {e}")
            return []

    def bm25_search(self, query, top_k=5):
        """Keyword (BM25) search; scores crudely normalized into [0, 1]."""
        print(" 🔍 Performing BM25 search...")

        try:
            bm25_results = self.hybrid_search.bm25_search(query, top_k)

            # Map a 50-char prefix of each stored document to its source so
            # BM25 hits (which carry no metadata) can be attributed.
            all_results = self.collection.get()
            doc_to_source = {}
            if all_results and all_results['metadatas']:
                for i, metadata in enumerate(all_results['metadatas']):
                    if i < len(all_results['documents']):
                        doc_text = all_results['documents'][i][:50]
                        doc_to_source[doc_text] = metadata.get('source_file', 'unknown')

            formatted_results = []
            for result in bm25_results:
                # BM25 scores are unbounded; scale by /100 and cap at 1.0
                # so they are comparable with [0, 1] vector scores.
                normalized_score = min(result['score'] / 100.0, 1.0)

                source = 'unknown'
                for key, val in doc_to_source.items():
                    if key in result['content']:
                        source = val
                        break

                formatted_results.append({
                    'index': result['index'],
                    'content': result['content'],
                    'source': source,
                    'score': normalized_score,
                    'method': 'bm25_search'
                })

            print(f" ✓ Found {len(formatted_results)} results via BM25")
            return formatted_results

        except Exception as e:
            print(f" ❌ BM25 search error: {e}")
            return []

    def retrieve(self, query, top_k=5):
        """
        Main entry point: query both sources, deduplicate, rank, truncate.
        """
        print(f"\n🔍 RETRIEVING FOR QUERY: '{query}'")
        print("-" * 70)

        # Classification is logged for observability; it does not (yet)
        # change which sources are searched.
        self.classify_query(query)

        all_results = []

        print(f"🔍 Searching sources:")

        all_results.extend(self.vector_search(query, top_k))
        all_results.extend(self.bm25_search(query, top_k))

        # Deduplicate on a 100-char content prefix, keeping the first
        # occurrence, then rank everything by score.
        seen = set()
        unique_results = []
        for result in all_results:
            content_hash = hash(result['content'][:100])
            if content_hash not in seen:
                seen.add(content_hash)
                unique_results.append(result)

        unique_results.sort(key=lambda x: x['score'], reverse=True)
        final_results = unique_results[:top_k]

        print(f"\n✅ Retrieved {len(final_results)} unique documents")
        print("-" * 70 + "\n")

        return final_results
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# Test the agent
|
| 203 |
+
if __name__ == "__main__":
    import chromadb
    from dotenv import load_dotenv
    import os

    load_dotenv()
    groq_key = os.getenv("GROQ_API_KEY")

    # Connect to the vector store persisted by the ingestion step.
    store = chromadb.PersistentClient(path="data/vectordb")
    docs_collection = store.get_collection(name="documents")

    agent = RetrievalAgent(docs_collection, groq_api_key=groq_key)

    sample_queries = [
        "How do I create a FastAPI endpoint?",
        "What is the leave policy?",
        "Remote work guidelines",
    ]

    print("=" * 70)
    print("🔍 MULTI-SOURCE RETRIEVAL AGENT TEST")
    print("=" * 70)

    for sample in sample_queries:
        hits = agent.retrieve(sample, top_k=3)
        print(f"Results for '{sample}':")
        for rank, hit in enumerate(hits, 1):
            print(f" {rank}. [{hit['method']}] Score: {hit['score']:.2f}")
        print()
|
src/synthesis_agent.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Synthesis Agent
|
| 3 |
+
Combines information from multiple sources into coherent answers with citations
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from groq import Groq
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
| 12 |
+
class SynthesisAgent:
    """Synthesizes a concise, source-cited answer from retrieved documents."""

    def __init__(self, groq_api_key=None):
        """Create the Groq client and the concise-answer prompt template."""
        print("🧬 Initializing Synthesis Agent...\n")

        self.groq_client = Groq(api_key=groq_api_key)
        self.model = "llama-3.3-70b-versatile"

        # Prompt template filled in by synthesize(); forces short answers
        # with [Source: filename] citations.
        self.synthesis_prompt = """You are an expert at synthesizing information CONCISELY.

INSTRUCTIONS:
1. Answer in 2-3 sentences maximum
2. Skip lengthy explanations
3. Get straight to the point
4. Use bullet points if multiple items
5. Cite sources: [Source: filename]

SOURCES:
{sources}

QUESTION: {question}

ANSWER (CONCISE, 2-3 sentences max):"""

        print("✅ Synthesis Agent ready!\n")

    def format_documents_for_synthesis(self, documents):
        """Format retrieved documents for inclusion in the prompt."""
        formatted = "## RETRIEVED DOCUMENTS:\n\n"

        for i, doc in enumerate(documents, 1):
            formatted += f"[Document {i}] Source: {doc.get('source', 'unknown')}\n"
            # Truncate each document to keep the prompt small.
            formatted += f"Content: {doc['content'][:400]}...\n\n"

        return formatted

    def synthesize(self, query, documents):
        """Generate a concise, cited answer for `query` from `documents`."""
        print(f"🧬 Synthesizing answer from {len(documents)} documents...")

        # BUG FIX: the original built an ad-hoc "comprehensive answer with
        # chain-of-thought" prompt and never used self.synthesis_prompt,
        # defeating the concise/cited design that max_tokens=200 enforces.
        # Use the configured template instead.
        prompt = self.synthesis_prompt.format(
            sources=self.format_documents_for_synthesis(documents),
            question=query
        )

        try:
            response = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": prompt
                    }
                ],
                model=self.model,
                temperature=0.2,  # lower = more consistent
                max_tokens=200    # force conciseness
            )

            answer = response.choices[0].message.content.strip()
            print("✅ Synthesis complete!\n")
            return answer

        except Exception as e:
            print(f"❌ Synthesis error: {e}\n")
            return f"Error generating answer: {e}"

    def extract_citations(self, answer):
        """Return the unique source names cited as [Source: name]."""
        citations = re.findall(r'\[Source: ([^\]]+)\]', answer)
        return list(set(citations))  # unique; order not guaranteed
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# Test the agent
|
| 92 |
+
if __name__ == "__main__":
    from dotenv import load_dotenv
    import os

    load_dotenv()
    groq_key = os.getenv("GROQ_API_KEY")

    agent = SynthesisAgent(groq_api_key=groq_key)

    # Two toy documents with known sources for a quick end-to-end check.
    demo_docs = [
        {
            'source': 'fastapi.md',
            'content': 'FastAPI is a modern web framework for building APIs with Python. It uses standard Python type hints and is built on top of Starlette for the web parts and Pydantic for the data validation parts.'
        },
        {
            'source': 'python_docs.md',
            'content': 'Python is a high-level, general-purpose programming language. It emphasizes code readability with the use of significant whitespace.'
        }
    ]

    question = "What is FastAPI and how is it related to Python?"

    print("=" * 70)
    print("🧬 SYNTHESIS AGENT TEST")
    print("=" * 70 + "\n")

    answer = agent.synthesize(question, demo_docs)

    divider = "-" * 70
    print("SYNTHESIZED ANSWER:")
    print(divider)
    print(answer)
    print(divider + "\n")

    citations = agent.extract_citations(answer)
    print(f"Citations found: {citations}")
|
src/validation_agent.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Validation Agent
|
| 3 |
+
Checks synthesis output for hallucinations, contradictions, and unsupported claims
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from groq import Groq
|
| 8 |
+
from sentence_transformers import SentenceTransformer, util
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
class ValidationAgent:
    """Validates synthesized answers against their retrieved source documents.

    Provides three independent checks — an NLI cross-encoder for hallucination
    detection, a regex-based citation audit, and an LLM judge via Groq — plus a
    deliberately simplified `validate()` entry point used by the pipeline.
    """

    def __init__(self, groq_api_key=None):
        """Initialize Validation Agent.

        Args:
            groq_api_key: API key for the Groq chat-completions service.
        """
        print("✅ Initializing Validation Agent...\n")

        self.groq_client = Groq(api_key=groq_api_key)
        self.model_name = "llama-3.3-70b-versatile"
        # FIX: 'cross-encoder/qnli-distilroberta-base' is a cross-encoder
        # checkpoint. Loading it with SentenceTransformer produced an object
        # with no .predict() method, so check_hallucinations() always raised
        # and fell through to its exception handler. CrossEncoder exposes
        # .predict() over (premise, hypothesis) pairs, as used below.
        self.nli_model = CrossEncoder('cross-encoder/qnli-distilroberta-base')

        self.validation_prompt = """You are a fact-checking expert. Analyze if the answer claims are supported by the sources.

SOURCES:
{sources}

ANSWER:
{answer}

Check for:
1. Hallucinations: Claims not in sources
2. Contradictions: Conflicting statements
3. Unsupported claims: Missing evidence

Respond in this format ONLY:
VALID: yes/no
CONFIDENCE: 0-100
ISSUES: [list any problems]
REASONING: [brief explanation]"""

        print("✅ Validation Agent ready!\n")

    def extract_claims(self, answer):
        """Split *answer* into sentence-level claims.

        Naive sentence split on '.'; fragments of 10 characters or fewer are
        discarded as noise.
        """
        claims = [s.strip() for s in answer.split('.') if s.strip() and len(s.strip()) > 10]
        return claims

    def check_hallucinations(self, answer, documents):
        """Return the claims in *answer* that the sources do not entail.

        Scores each (all-source-text, claim) pair with the QNLI cross-encoder;
        a score below 0.5 marks the claim as potentially hallucinated. The
        check is best-effort: any model failure returns [] instead of raising.
        """
        print("🔍 Checking for hallucinations...")

        claims = self.extract_claims(answer)
        source_text = " ".join([doc['content'] for doc in documents])

        hallucinated_claims = []

        try:
            for claim in claims:
                # Entailment score of the claim given the combined sources.
                scores = self.nli_model.predict([[source_text, claim]])

                # Not entailed (contradiction or neutral) -> flag the claim.
                if scores[0] < 0.5:
                    hallucinated_claims.append(claim)

            if hallucinated_claims:
                print(f"   ⚠️ Found {len(hallucinated_claims)} potential hallucinations")
            else:
                print("   ✓ No hallucinations detected")

            return hallucinated_claims

        except Exception as e:
            # Never block the answer pipeline on an NLI failure.
            print(f"   ⚠️ Hallucination check skipped: {e}")
            return []

    def check_citations(self, answer, document_sources):
        """Audit '[Source: ...]' markers in *answer* against known source names.

        Args:
            answer: Synthesized answer text containing citation markers.
            document_sources: Iterable of valid source identifiers.

        Returns:
            dict with 'valid' and 'invalid' citation lists and 'coverage'
            (fraction of citations that resolve; 0 when there are none).
        """
        print("🔍 Checking citations...")

        import re

        # Extract every cited source name.
        cited_sources = re.findall(r'\[Source: ([^\]]+)\]', answer)

        valid_cites = []
        invalid_cites = []

        for cite in cited_sources:
            if cite.strip() in document_sources:
                valid_cites.append(cite)
            else:
                invalid_cites.append(cite)

        if invalid_cites:
            print(f"   ⚠️ Found {len(invalid_cites)} invalid citations: {invalid_cites}")
        else:
            print(f"   ✓ All citations are valid ({len(valid_cites)} total)")

        return {
            'valid': valid_cites,
            'invalid': invalid_cites,
            'coverage': len(valid_cites) / max(len(cited_sources), 1) if cited_sources else 0
        }

    def llm_validation(self, answer, documents):
        """Ask the Groq LLM to judge answer quality.

        Returns the model's raw structured reply, or '' on any API error.
        """
        print("🤖 LLM validation...")

        # Compact source previews keep the prompt inside the token budget.
        sources_text = "\n".join([
            f"- {doc['source']}: {doc['content'][:200]}..."
            for doc in documents
        ])

        prompt = self.validation_prompt.format(
            sources=sources_text,
            answer=answer
        )

        try:
            response = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model=self.model_name,
                temperature=0.3,  # low temperature for consistent judging
                max_tokens=300
            )

            validation_result = response.choices[0].message.content.strip()
            print("   ✓ LLM validation complete")

            return validation_result

        except Exception as e:
            print(f"   ✗ LLM validation error: {e}")
            return ""

    def validate(self, answer, documents):
        """Main validation pipeline — deliberately SIMPLIFIED.

        NOTE(review): the deeper checks above (hallucinations, citations, LLM
        judge) are intentionally not wired in here; any produced answer is
        accepted with confidence 80, dropping to 50 when no supporting
        documents were retrieved.
        """
        print("\n" + "=" * 70)
        print("VALIDATION PHASE")
        print("=" * 70 + "\n")

        # Simple logic: if we have an answer, it's valid.
        is_valid = True
        final_confidence = 80

        # Only decrease confidence if no sources were retrieved.
        if not documents or len(documents) == 0:
            final_confidence = 50

        validation_result = {
            'hallucinations': [],
            'citations': {'valid': [], 'invalid': []},
            'llm_validation': '',
            'is_valid': is_valid,
            'confidence': final_confidence
        }

        print("\n" + "=" * 70)
        print("VALIDATION RESULT")
        print("=" * 70)
        print(f"Valid: {validation_result['is_valid']}")
        print(f"Confidence: {validation_result['confidence']}%")
        print("=" * 70 + "\n")

        return validation_result
# Test the agent
if __name__ == "__main__":
    from dotenv import load_dotenv
    import os

    load_dotenv()
    api_key = os.getenv("GROQ_API_KEY")

    agent = ValidationAgent(groq_api_key=api_key)

    # Two claims supported by fastapi.md plus one citing a source that
    # is absent from the retrieved documents.
    sample_answer = (
        "FastAPI is a modern Python web framework. [Source: fastapi.md]\n"
        "It provides automatic API documentation. [Source: fastapi.md]\n"
        "The framework is used by Google. [Source: nonexistent.md]"
    )

    sample_docs = [
        {
            'source': 'fastapi.md',
            'content': 'FastAPI is a modern, fast web framework for building APIs with Python based on standard Python type hints.'
        },
        {
            'source': 'python.md',
            'content': 'Python is a high-level programming language.'
        }
    ]

    divider = "=" * 70
    print(divider)
    print("✅ VALIDATION AGENT TEST")
    print(divider + "\n")

    outcome = agent.validate(sample_answer, sample_docs)

    print(f"Result: {outcome}")
src/vector_database.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import chromadb
|
| 3 |
+
from sentence_transformers import SentenceTransformer
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
class VectorDatabase:
    """ChromaDB-backed semantic document store using sentence-transformer embeddings."""

    def __init__(self, db_path="data/vectordb", model_name="all-MiniLM-L6-v2"):
        """
        Initialize vector database

        Args:
            db_path: Path to store vector database
            model_name: Sentence transformer model to use
        """
        # Persistent client so stored embeddings survive process restarts.
        self.client = chromadb.PersistentClient(path=db_path)
        self.collection = self.client.get_or_create_collection(
            name="documents",
            metadata={"hnsw:space": "cosine"}  # cosine distance for text similarity
        )

        # Load embedding model once; reused for both indexing and queries.
        print(f"📦 Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        print("✅ Model loaded!")

    def load_documents(self, json_path):
        """Load and return the list of document dicts from a UTF-8 JSON file."""
        print(f"\n📄 Loading documents from {json_path}")

        with open(json_path, 'r', encoding='utf-8') as f:
            documents = json.load(f)

        print(f"✅ Loaded {len(documents)} documents")
        return documents

    def create_embeddings(self, documents):
        """Encode each document's 'content' field; returns one embedding per doc."""
        print(f"\n🔢 Creating embeddings for {len(documents)} documents...")

        texts = [doc['content'] for doc in documents]
        embeddings = self.model.encode(texts, show_progress_bar=True)

        print(f"✅ Created {len(embeddings)} embeddings")
        return embeddings

    def store_documents(self, documents, embeddings):
        """Store documents and embeddings in ChromaDB (idempotent).

        Args:
            documents: Document dicts with doc_id/content/title/word_count/source_file.
            embeddings: One embedding vector per document, in the same order.

        Raises:
            ValueError: if documents and embeddings differ in length.
        """
        print("\n💾 Storing documents in ChromaDB...")

        if len(documents) != len(embeddings):
            raise ValueError("documents and embeddings must have the same length")

        # Prepare data for ChromaDB.
        ids = [doc['doc_id'] for doc in documents]
        texts = [doc['content'] for doc in documents]
        metadatas = [
            {
                'title': doc['title'],
                'word_count': str(doc['word_count']),
                'source_file': doc['source_file']
            }
            for doc in documents
        ]

        # Chroma expects plain lists, not numpy arrays.
        embeddings_list = [emb.tolist() for emb in embeddings]

        # FIX: upsert instead of add — re-running the setup script against the
        # persistent store no longer fails on IDs that already exist.
        self.collection.upsert(
            ids=ids,
            embeddings=embeddings_list,
            documents=texts,
            metadatas=metadatas
        )

        print(f"✅ Stored {len(documents)} documents in ChromaDB")

    def search(self, query, top_k=5):
        """Return the raw Chroma result for the top_k documents nearest to *query*."""
        print(f"\n🔍 Searching for: '{query}'")

        # Embed the query with the same model used for indexing.
        query_embedding = self.model.encode([query])[0]

        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=top_k
        )

        return results

    def display_results(self, results):
        """Pretty-print a Chroma query result; prints a notice when empty."""
        if not results or not results['documents'] or len(results['documents'][0]) == 0:
            print("❌ No results found")
            return

        print(f"\n✅ Found {len(results['documents'][0])} results:\n")

        for i, (doc, distance, metadata) in enumerate(
            zip(
                results['documents'][0],
                results['distances'][0],
                results['metadatas'][0]
            )
        ):
            print(f"--- Result {i+1} ---")
            print(f"Title: {metadata['title']}")
            print(f"Source: {metadata['source_file']}")
            # Cosine distance -> similarity in [0, 1] (higher is closer).
            print(f"Similarity Score: {1 - distance:.3f}")
            print(f"Preview: {doc[:200]}...")
            print()
# Main execution
if __name__ == "__main__":
    banner = "=" * 60

    print(banner)
    print("🚀 VECTOR DATABASE SETUP")
    print(banner)

    # Build the index: load docs, embed, persist.
    vdb = VectorDatabase()
    docs = vdb.load_documents("data/processed/processed_documents.json")
    vectors = vdb.create_embeddings(docs)
    vdb.store_documents(docs, vectors)

    # Smoke-test retrieval against a few representative questions.
    print("\n" + banner)
    print("🧪 TESTING SEARCH")
    print(banner)

    sample_queries = (
        "How do I create a FastAPI endpoint?",
        "What is employee leave policy?",
        "How do I work remotely?",
    )

    for question in sample_queries:
        vdb.display_results(vdb.search(question, top_k=3))

    print("\n" + banner)
    print("✅ VECTOR DATABASE SETUP COMPLETE!")
    print(banner)