|
|
""" |
|
|
SPARKNET RAG API Routes

Endpoints for RAG queries (plain and streamed), semantic search,
vector store status and clearing, and query cache management.
|
|
""" |
|
|
|
|
|
from fastapi import APIRouter, HTTPException, Query, Depends |
|
|
from fastapi.responses import StreamingResponse |
|
|
from typing import List, Optional |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
import time |
|
|
import json |
|
|
import sys |
|
|
import asyncio |
|
|
|
|
|
|
|
|
PROJECT_ROOT = Path(__file__).parent.parent.parent |
|
|
sys.path.insert(0, str(PROJECT_ROOT)) |
|
|
|
|
|
from api.schemas import ( |
|
|
QueryRequest, RAGResponse, Citation, QueryPlan, QueryIntentType, |
|
|
SearchRequest, SearchResponse, SearchResult, |
|
|
StoreStatus, CollectionInfo |
|
|
) |
|
|
from loguru import logger |
|
|
|
|
|
# Router on which every endpoint in this module is registered.
router = APIRouter()

# In-process cache of query responses, keyed by the digest produced by
# get_cache_key(). Each value is a dict with the stored "response" object
# and the "timestamp" (time.time()) at which it was cached.
_query_cache: dict = {}

# Lifetime of a cache entry, in seconds (one hour).
CACHE_TTL_SECONDS = 60 * 60
|
|
|
|
|
|
|
|
def get_cache_key(query: str, doc_ids: Optional[List[str]]) -> str:
    """Build a deterministic cache key for a query and its document scope.

    The query text and the sorted document IDs are serialised as a JSON
    array before hashing, so the boundary between the two parts (and between
    individual IDs) is unambiguous. Plain string concatenation would let
    different inputs share a key, e.g. query ``"a:b"`` with no filter and
    query ``"a"`` restricted to document ``"b:all"``.

    Args:
        query: The user query text.
        doc_ids: Document IDs the query is restricted to. ``None`` and an
            empty list both mean "no restriction" and yield the same key.
            The order of the IDs does not affect the key.

    Returns:
        A 32-character hexadecimal MD5 digest. MD5 serves only as a compact
        fingerprint here, not as a security measure.
    """
    import hashlib

    scope = sorted(doc_ids) if doc_ids else None
    # json.dumps escapes non-ASCII by default, so the encoded payload is
    # always plain ASCII and quoting keeps every field boundary distinct.
    fingerprint = json.dumps([query, scope])
    return hashlib.md5(fingerprint.encode("utf-8")).hexdigest()
|
|
|
|
|
|
|
|
def get_cached_response(cache_key: str) -> Optional[RAGResponse]:
    """Look up a cached response and return it if it has not expired.

    An entry is fresh while its age is below ``CACHE_TTL_SECONDS``. A stale
    entry is removed from the cache as a side effect of the lookup.

    Args:
        cache_key: Key under which the response was stored.

    Returns:
        The stored response object itself (not a copy), with ``from_cache``
        set to ``True``, or ``None`` when the key is missing or expired.
    """
    entry = _query_cache.get(cache_key)
    if entry is None:
        return None

    age_seconds = time.time() - entry["timestamp"]
    if age_seconds >= CACHE_TTL_SECONDS:
        # Expired: evict eagerly so the slot is freed.
        del _query_cache[cache_key]
        return None

    hit = entry["response"]
    hit.from_cache = True
    return hit
|
|
|
|
|
|
|
|
def cache_response(cache_key: str, response: RAGResponse):
    """Store a response in the query cache, stamped with the current time.

    When the insert pushes the cache above 1000 entries, the single entry
    with the smallest timestamp is evicted.

    Args:
        cache_key: Key to store the response under; an existing entry with
            the same key is overwritten.
        response: The response object to cache (stored by reference).
    """
    entry = {"response": response, "timestamp": time.time()}
    _query_cache[cache_key] = entry

    if len(_query_cache) <= 1000:
        return

    def stored_at(key):
        return _query_cache[key]["timestamp"]

    # Over the size limit: drop whichever entry was written longest ago.
    stalest_key = min(_query_cache, key=stored_at)
    del _query_cache[stalest_key]
|
|
|
|
|
|
|
|
def _get_rag_system():
    """Construct an ``AgenticRAG`` instance with this API's fixed settings.

    The orchestrator module is imported lazily inside the function, so an
    import failure is handled the same way as a construction failure.

    NOTE(review): no instance is kept between calls; every call builds a new
    ``AgenticRAG``. Consider caching if construction is expensive and the
    object is safe to share.

    Returns:
        A new ``AgenticRAG``, or ``None`` if importing or constructing it
        raised (the error is logged).
    """
    rag_settings = dict(
        model_name="llama3.2:latest",
        max_revision_attempts=2,
        retrieval_top_k=10,
        final_top_k=5,
        min_confidence=0.5,
    )

    rag_system = None
    try:
        from src.rag.agentic.orchestrator import AgenticRAG, RAGConfig

        rag_system = AgenticRAG(RAGConfig(**rag_settings))
    except Exception as exc:
        logger.error(f"Failed to initialize RAG system: {exc}")
    return rag_system
|
|
|
|
|
|
|
|
@router.post("/query", response_model=RAGResponse)
async def query_documents(request: QueryRequest):
    """
    Execute a RAG query across indexed documents.

    The query goes through the 5-agent pipeline:
    1. QueryPlanner - Intent classification and query decomposition
    2. Retriever - Hybrid dense+sparse search
    3. Reranker - Cross-encoder reranking with MMR
    4. Synthesizer - Answer generation with citations
    5. Critic - Hallucination detection and validation

    The RAG system is built and queried through synchronous calls, so that
    work is handed to the event loop's default executor rather than run
    directly inside this coroutine, where it would stall every other request
    until the pipeline finished.

    Args:
        request: Query text plus options. The fields read here are ``query``,
            ``doc_ids``, ``top_k``, ``use_cache`` and ``min_confidence``.

    Returns:
        A ``RAGResponse``. Cache hits are returned with ``latency_ms``
        refreshed for the current request.

    Raises:
        HTTPException: 503 when the RAG system cannot be initialized, 500
            when anything else in the pipeline or response assembly fails.
    """
    start_time = time.time()

    # One key serves both the lookup and the later store.
    # NOTE(review): the key is derived from the query text and doc_ids only,
    # so a cached answer is reused even when top_k differs - confirm intended.
    cache_key = None
    if request.use_cache:
        cache_key = get_cache_key(request.query, request.doc_ids)
        cached = get_cached_response(cache_key)
        if cached is not None:
            cached.latency_ms = (time.time() - start_time) * 1000
            return cached

    def run_pipeline():
        """Build the RAG system and run the query (executed in a worker thread)."""
        rag = _get_rag_system()
        if not rag:
            raise HTTPException(status_code=503, detail="RAG system not available")

        filters = None
        if request.doc_ids:
            filters = {"document_id": {"$in": request.doc_ids}}

        logger.info(f"Executing RAG query: {request.query[:50]}...")
        return rag.query(
            query=request.query,
            filters=filters,
            top_k=request.top_k,
        )

    try:
        # Exceptions raised in the worker (including the 503 above) are
        # re-raised here when the future is awaited.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, run_pipeline)

        # Tolerate a missing or None "sources" entry.
        sources = result.get("sources") or []
        citations = [
            Citation(
                citation_id=index + 1,
                doc_id=source.get("document_id", "unknown"),
                document_name=source.get("filename", source.get("document_id", "unknown")),
                chunk_id=source.get("chunk_id", f"chunk_{index}"),
                chunk_text=source.get("text", "")[:300],
                page_num=source.get("page_num"),
                relevance_score=source.get("relevance_score", source.get("score", 0.0)),
                bbox=source.get("bbox"),
            )
            for index, source in enumerate(sources)
        ]

        # A "plan" key holding None is treated the same as no plan at all.
        query_plan = None
        plan = result.get("plan")
        if plan is not None:
            query_plan = QueryPlan(
                intent=QueryIntentType(plan.get("intent", "factoid").lower()),
                sub_queries=plan.get("sub_queries", []),
                keywords=plan.get("keywords", []),
                strategy=plan.get("strategy", "hybrid"),
            )

        response = RAGResponse(
            query=request.query,
            answer=result.get("answer", "I could not find an answer to your question."),
            confidence=result.get("confidence", 0.0),
            citations=citations,
            source_count=len(citations),
            query_plan=query_plan,
            from_cache=False,
            validation=result.get("validation"),
            latency_ms=(time.time() - start_time) * 1000,
            revision_count=result.get("revision_count", 0),
        )

        # Only answers that meet the caller's confidence bar are cached.
        if cache_key is not None and response.confidence >= request.min_confidence:
            cache_response(cache_key, response)

        return response

    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"RAG query failed: {e}")
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}") from e
|
|
|
|
|
|
|
|
@router.post("/query/stream")
async def query_documents_stream(request: QueryRequest):
    """
    Stream RAG response for real-time updates.

    Returns Server-Sent Events (SSE) with partial responses. Each event is a
    JSON object on a ``data:`` line. In order, the stream emits the stages
    ``planning``, ``retrieving``, ``sources`` (with ``count``),
    ``synthesizing``, one ``answer`` event per 50-character slice of the
    answer, and finally ``complete`` (with ``confidence``, ``citations`` and
    ``validation``). A failure at any point is reported as an event with a
    single ``error`` field.

    The full answer is produced by one ``rag.query`` call and only then
    sliced into chunks. Building the RAG system and running the query are
    synchronous, so both are handed to the event loop's default executor to
    avoid blocking other requests while they run.

    Args:
        request: Query text plus options. The fields read here are ``query``,
            ``doc_ids`` and ``top_k``.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """

    def sse_event(payload):
        """Format *payload* as a single SSE ``data:`` frame."""
        return f"data: {json.dumps(payload)}\n\n"

    async def generate():
        try:
            loop = asyncio.get_running_loop()

            # NOTE(review): the RAG object is built in one worker thread and
            # may be queried from another - confirm it has no thread affinity.
            rag = await loop.run_in_executor(None, _get_rag_system)
            if not rag:
                yield sse_event({'error': 'RAG system not available'})
                return

            yield sse_event({'stage': 'planning', 'message': 'Analyzing query...'})
            await asyncio.sleep(0.1)

            filters = None
            if request.doc_ids:
                filters = {"document_id": {"$in": request.doc_ids}}

            yield sse_event({'stage': 'retrieving', 'message': 'Searching documents...'})

            result = await loop.run_in_executor(
                None,
                lambda: rag.query(
                    query=request.query,
                    filters=filters,
                    top_k=request.top_k,
                ),
            )

            sources = result.get("sources", [])
            yield sse_event({'stage': 'sources', 'count': len(sources)})

            yield sse_event({'stage': 'synthesizing', 'message': 'Generating answer...'})

            # Emit the finished answer in fixed-size slices with a short
            # pause between them.
            answer = result.get("answer", "")
            chunk_size = 50
            for offset in range(0, len(answer), chunk_size):
                yield sse_event({'stage': 'answer', 'chunk': answer[offset:offset + chunk_size]})
                await asyncio.sleep(0.02)

            citations = [
                {
                    "citation_id": index + 1,
                    "doc_id": source.get("document_id", "unknown"),
                    "chunk_text": source.get("text", "")[:200],
                    # Same lookup order as the non-streaming /query endpoint:
                    # prefer "relevance_score", fall back to "score".
                    "relevance_score": source.get("relevance_score", source.get("score", 0.0)),
                }
                for index, source in enumerate(sources)
            ]

            final = {
                "stage": "complete",
                "confidence": result.get("confidence", 0.0),
                "citations": citations,
                "validation": result.get("validation"),
            }
            yield sse_event(final)

        except Exception as e:
            # The response has already started, so the failure is reported
            # in-band as an event instead of an HTTP error status.
            logger.exception(f"Streaming query failed: {e}")
            yield sse_event({'error': str(e)})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
|
|
|
|
|
|
|
|
@router.post("/search", response_model=SearchResponse)
async def search_documents(request: SearchRequest):
    """
    Semantic search across indexed documents.

    Returns matching chunks without answer synthesis. The query is embedded,
    the vector store is searched for the ``top_k`` nearest chunks (optionally
    restricted to ``doc_ids``), and hits scoring below ``min_score`` are
    dropped.

    Any failure is logged and reported to the caller as an empty result set
    rather than as an HTTP error.

    Args:
        request: Search text plus options. The fields read here are
            ``query``, ``doc_ids``, ``top_k`` and ``min_score``.

    Returns:
        A ``SearchResponse`` with the surviving hits and the elapsed time in
        milliseconds.
    """
    started_at = time.time()

    def elapsed_ms():
        return (time.time() - started_at) * 1000

    try:
        from src.rag.store import get_vector_store
        from src.rag.embeddings import get_embedding_model

        store = get_vector_store()
        embeddings = get_embedding_model()

        query_embedding = embeddings.embed_query(request.query)

        # Restrict the search to the requested documents, if any were given.
        where_filter = (
            {"document_id": {"$in": request.doc_ids}} if request.doc_ids else None
        )

        scored_docs = store.similarity_search_with_score(
            query_embedding=query_embedding,
            k=request.top_k,
            where=where_filter,
        )

        hits = [
            SearchResult(
                chunk_id=doc.metadata.get("chunk_id", "unknown"),
                doc_id=doc.metadata.get("document_id", "unknown"),
                document_name=doc.metadata.get("filename", "unknown"),
                text=doc.page_content,
                score=score,
                page_num=doc.metadata.get("page_num"),
                chunk_type=doc.metadata.get("chunk_type", "text"),
            )
            for doc, score in scored_docs
            if score >= request.min_score
        ]

        return SearchResponse(
            query=request.query,
            total_results=len(hits),
            results=hits,
            latency_ms=elapsed_ms(),
        )

    except Exception as exc:
        logger.error(f"Search failed: {exc}")

        # Degrade to an empty result set instead of failing the request.
        return SearchResponse(
            query=request.query,
            total_results=0,
            results=[],
            latency_ms=elapsed_ms(),
        )
|
|
|
|
|
|
|
|
@router.get("/store/status", response_model=StoreStatus)
async def get_store_status():
    """Report vector store health together with chunk and document counts.

    The chunk count comes from the collection's ``count()``. The document
    count is the number of distinct ``document_id`` values found in the
    metadata of every stored chunk, which means all metadata is fetched on
    each call.

    Returns:
        A ``StoreStatus`` with status ``"healthy"`` and one ``CollectionInfo``
        entry, or status ``"error"`` with zeroed counts if anything raised
        (the error is logged).
    """
    try:
        from src.rag.store import get_vector_store

        store = get_vector_store()
        collection = store._collection
        chunk_total = collection.count()

        fetched = collection.get(include=["metadatas"])
        unique_doc_ids = {
            meta["document_id"]
            for meta in fetched.get("metadatas", [])
            if meta and "document_id" in meta
        }
        document_total = len(unique_doc_ids)

        collection_info = CollectionInfo(
            name=store.collection_name,
            document_count=document_total,
            chunk_count=chunk_total,
            # Falls back to 1024 when the store does not expose a dimension.
            embedding_dimension=getattr(store, "embedding_dimension", 1024),
        )

        return StoreStatus(
            status="healthy",
            collections=[collection_info],
            total_documents=document_total,
            total_chunks=chunk_total,
        )

    except Exception as exc:
        logger.error(f"Store status check failed: {exc}")
        return StoreStatus(
            status="error",
            collections=[],
            total_documents=0,
            total_chunks=0,
        )
|
|
|
|
|
|
|
|
@router.delete("/store/collection/{collection_name}")
async def clear_collection(collection_name: str, confirm: bool = Query(False)):
    """Clear a vector store collection (dangerous operation).

    Deletes every entry in the active collection and then empties the
    in-memory query cache, so answers synthesized from the deleted chunks are
    not served afterwards.

    Args:
        collection_name: Name of the collection to clear. It must match the
            active store's ``collection_name``.
        confirm: Safety switch; the request is rejected unless it is true.

    Returns:
        A dict with ``status``, ``collection`` and ``message`` keys.

    Raises:
        HTTPException: 400 when ``confirm`` is not set, 404 when the name
            does not match the active collection, 500 when the delete fails.
    """
    if not confirm:
        raise HTTPException(
            status_code=400,
            detail="This operation will delete all data. Set confirm=true to proceed."
        )

    try:
        from src.rag.store import get_vector_store

        store = get_vector_store()
        if store.collection_name != collection_name:
            raise HTTPException(status_code=404, detail=f"Collection not found: {collection_name}")

        # NOTE(review): an empty `where` filter is passed to mean "match
        # everything" - confirm the store backend accepts it.
        store._collection.delete(where={})

        # Cached responses cite chunks that no longer exist; drop them all.
        _query_cache.clear()

        return {"status": "cleared", "collection": collection_name, "message": "Collection cleared successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Collection clear failed: {e}")
        raise HTTPException(status_code=500, detail=f"Clear failed: {str(e)}") from e
|
|
|
|
|
|
|
|
@router.get("/cache/stats")
async def get_cache_stats():
    """Summarize the query cache: total, still-valid and expired entries.

    An entry counts as valid while its age is below ``CACHE_TTL_SECONDS``.
    Expired entries are only counted here, not removed.

    Returns:
        A dict with ``total_entries``, ``valid_entries``, ``expired_entries``
        and ``ttl_seconds``.
    """
    now = time.time()
    total = len(_query_cache)

    fresh = 0
    for entry in _query_cache.values():
        if now - entry["timestamp"] < CACHE_TTL_SECONDS:
            fresh += 1

    return {
        "total_entries": total,
        "valid_entries": fresh,
        "expired_entries": total - fresh,
        "ttl_seconds": CACHE_TTL_SECONDS,
    }
|
|
|
|
|
|
|
|
@router.delete("/cache")
async def clear_cache():
    """Remove every entry from the query cache.

    Returns:
        A dict with ``status`` set to ``"cleared"`` and ``entries_removed``
        holding the number of entries that were present before clearing.
    """
    removed = len(_query_cache)
    _query_cache.clear()
    return {"status": "cleared", "entries_removed": removed}
|
|
|