"""
SPARKNET RAG API Routes

Endpoints for RAG queries, search, and indexing management.
"""

from fastapi import APIRouter, HTTPException, Query, Depends
from fastapi.responses import StreamingResponse
from typing import Any, Dict, List, Optional
from pathlib import Path
from datetime import datetime
import time
import json
import sys
import asyncio
import hashlib

# Add project root to path
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from api.schemas import (
    QueryRequest, RAGResponse, Citation, QueryPlan, QueryIntentType,
    SearchRequest, SearchResponse, SearchResult,
    StoreStatus, CollectionInfo
)
from loguru import logger

router = APIRouter()

# Simple in-memory cache for query results.
# Maps cache key -> {"response": RAGResponse, "timestamp": float (time.time())}.
_query_cache: Dict[str, Dict[str, Any]] = {}
CACHE_TTL_SECONDS = 3600  # 1 hour
CACHE_MAX_ENTRIES = 1000  # oldest entry is evicted once this is exceeded


def get_cache_key(
    query: str,
    doc_ids: Optional[List[str]],
    top_k: Optional[int] = None,
) -> str:
    """Generate cache key for query.

    Args:
        query: The raw query text.
        doc_ids: Optional document-id restriction; order does not matter
            because the ids are sorted before hashing. ``None``/empty is
            keyed as ``"all"``.
        top_k: Optional retrieval depth. When given it becomes part of the
            key so that requests with different ``top_k`` do not share a
            cached answer. When ``None`` the key is identical to the one
            produced from ``query`` and ``doc_ids`` alone.

    Returns:
        Hex MD5 digest of the key material (used as a lookup key only, not
        for any security purpose).
    """
    doc_str = ",".join(sorted(doc_ids)) if doc_ids else "all"
    content = f"{query}:{doc_str}"
    if top_k is not None:
        content = f"{content}:{top_k}"
    return hashlib.md5(content.encode()).hexdigest()


def get_cached_response(cache_key: str) -> Optional[RAGResponse]:
    """Get cached response if valid.

    Expired entries are removed on lookup. A returned response has its
    ``from_cache`` attribute set to ``True``; note that the object returned
    is the cached instance itself, not a copy.

    Returns:
        The cached response, or ``None`` when missing or older than
        ``CACHE_TTL_SECONDS``.
    """
    cached = _query_cache.get(cache_key)
    if cached is None:
        return None

    if time.time() - cached["timestamp"] >= CACHE_TTL_SECONDS:
        del _query_cache[cache_key]
        return None

    response = cached["response"]
    response.from_cache = True
    return response


def cache_response(cache_key: str, response: RAGResponse):
    """Cache a query response.

    Stores ``response`` under ``cache_key`` with the current time, then
    evicts the single oldest entry if the cache has grown past
    ``CACHE_MAX_ENTRIES``.
    """
    _query_cache[cache_key] = {
        "response": response,
        "timestamp": time.time(),
    }

    # Limit cache size
    if len(_query_cache) > CACHE_MAX_ENTRIES:
        oldest_key = min(_query_cache, key=lambda k: _query_cache[k]["timestamp"])
        del _query_cache[oldest_key]


def _get_rag_system():
    """Build a RAG system instance.

    A new ``AgenticRAG`` is constructed on every call (nothing is memoized
    here).

    Returns:
        The ``AgenticRAG`` instance, or ``None`` if import or construction
        fails; the failure is logged.
    """
    try:
        from src.rag.agentic.orchestrator import AgenticRAG, RAGConfig

        config = RAGConfig(
            model_name="llama3.2:latest",
            max_revision_attempts=2,
            retrieval_top_k=10,
            final_top_k=5,
            min_confidence=0.5,
        )
        return AgenticRAG(config)
    except Exception as e:
        logger.error(f"Failed to initialize RAG system: {e}")
        return None


def _build_filters(doc_ids: Optional[List[str]]) -> Optional[Dict[str, Any]]:
    """Return the document-id metadata filter, or None when unrestricted."""
    if doc_ids:
        return {"document_id": {"$in": doc_ids}}
    return None


def _source_score(source: Dict[str, Any]) -> Any:
    """Return a source's relevance score, preferring ``relevance_score`` over ``score``."""
    return source.get("relevance_score", source.get("score", 0.0))


def _sse(payload: Dict[str, Any]) -> str:
    """Frame a JSON payload as one Server-Sent Events ``data:`` message."""
    return f"data: {json.dumps(payload)}\n\n"


@router.post("/query", response_model=RAGResponse)
async def query_documents(request: QueryRequest):
    """
    Execute a RAG query across indexed documents.

    The query goes through the 5-agent pipeline:
    1. QueryPlanner - Intent classification and query decomposition
    2. Retriever - Hybrid dense+sparse search
    3. Reranker - Cross-encoder reranking with MMR
    4. Synthesizer - Answer generation with citations
    5. Critic - Hallucination detection and validation
    """
    # Raises HTTPException 503 when the RAG system cannot be built and 500 on
    # any other failure while executing the query or building the response.
    start_time = time.time()

    # Check cache if enabled. top_k is part of the key so that a request with
    # a different retrieval depth is not answered from another depth's entry.
    cache_key = None
    if request.use_cache:
        cache_key = get_cache_key(request.query, request.doc_ids, request.top_k)
        cached = get_cached_response(cache_key)
        if cached is not None:
            cached.latency_ms = (time.time() - start_time) * 1000
            return cached

    try:
        # Initialize RAG system
        rag = _get_rag_system()
        if not rag:
            raise HTTPException(status_code=503, detail="RAG system not available")

        # Execute query
        # NOTE(review): rag.query is called synchronously inside an async
        # handler; if it blocks, the event loop is stalled for its duration.
        logger.info(f"Executing RAG query: {request.query[:50]}...")
        result = rag.query(
            query=request.query,
            filters=_build_filters(request.doc_ids),
            top_k=request.top_k,
        )

        # Build response
        citations = [
            Citation(
                citation_id=i + 1,
                doc_id=source.get("document_id", "unknown"),
                document_name=source.get("filename", source.get("document_id", "unknown")),
                chunk_id=source.get("chunk_id", f"chunk_{i}"),
                # `or ""` guards against an explicit None stored under "text".
                chunk_text=(source.get("text") or "")[:300],
                page_num=source.get("page_num"),
                relevance_score=_source_score(source),
                bbox=source.get("bbox"),
            )
            for i, source in enumerate(result.get("sources", []))
        ]

        # Query plan info
        query_plan = None
        if "plan" in result:
            plan = result["plan"]
            query_plan = QueryPlan(
                intent=QueryIntentType(plan.get("intent", "factoid").lower()),
                sub_queries=plan.get("sub_queries", []),
                keywords=plan.get("keywords", []),
                strategy=plan.get("strategy", "hybrid"),
            )

        response = RAGResponse(
            query=request.query,
            answer=result.get("answer", "I could not find an answer to your question."),
            confidence=result.get("confidence", 0.0),
            citations=citations,
            source_count=len(citations),
            query_plan=query_plan,
            from_cache=False,
            validation=result.get("validation"),
            latency_ms=(time.time() - start_time) * 1000,
            revision_count=result.get("revision_count", 0),
        )

        # Cache successful responses
        if cache_key is not None and response.confidence >= request.min_confidence:
            cache_response(cache_key, response)

        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"RAG query failed: {e}")
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")


@router.post("/query/stream")
async def query_documents_stream(request: QueryRequest):
    """
    Stream RAG response for real-time updates.

    Returns Server-Sent Events (SSE) with partial responses.
    """

    async def generate():
        # Emits, in order: planning, retrieving, sources, synthesizing, one
        # "answer" event per 50-character slice, then "complete". Any failure
        # is reported as a single {"error": ...} event instead of raising.
        try:
            # Initialize RAG system
            rag = _get_rag_system()
            if not rag:
                yield _sse({'error': 'RAG system not available'})
                return

            # Send planning stage
            yield _sse({'stage': 'planning', 'message': 'Analyzing query...'})
            await asyncio.sleep(0.1)

            # Send retrieval stage
            yield _sse({'stage': 'retrieving', 'message': 'Searching documents...'})

            # Execute query. The full answer is computed first and only then
            # sliced into chunks below.
            # NOTE(review): synchronous call inside an async generator; if it
            # blocks, the event loop is stalled for its duration.
            result = rag.query(
                query=request.query,
                filters=_build_filters(request.doc_ids),
                top_k=request.top_k,
            )
            sources = result.get("sources", [])

            # Send sources
            yield _sse({'stage': 'sources', 'count': len(sources)})

            # Send synthesis stage
            yield _sse({'stage': 'synthesizing', 'message': 'Generating answer...'})

            # Stream answer in chunks
            answer = result.get("answer", "")
            chunk_size = 50
            for start in range(0, len(answer), chunk_size):
                yield _sse({'stage': 'answer', 'chunk': answer[start:start + chunk_size]})
                await asyncio.sleep(0.02)

            # Send final result
            citations = [
                {
                    "citation_id": i + 1,
                    "doc_id": source.get("document_id", "unknown"),
                    "chunk_text": (source.get("text") or "")[:200],
                    # Same score lookup as the /query endpoint.
                    "relevance_score": _source_score(source),
                }
                for i, source in enumerate(sources)
            ]

            final = {
                "stage": "complete",
                "confidence": result.get("confidence", 0.0),
                "citations": citations,
                "validation": result.get("validation"),
            }
            yield _sse(final)

        except Exception as e:
            logger.error(f"Streaming query failed: {e}")
            yield _sse({'error': str(e)})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )


@router.post("/search", response_model=SearchResponse)
async def search_documents(request: SearchRequest):
    """
    Semantic search across indexed documents.

    Returns matching chunks without answer synthesis.
    """
    start_time = time.time()

    try:
        from src.rag.store import get_vector_store
        from src.rag.embeddings import get_embedding_model

        store = get_vector_store()
        embeddings = get_embedding_model()

        # Generate query embedding
        query_embedding = embeddings.embed_query(request.query)

        # Search
        results = store.similarity_search_with_score(
            query_embedding=query_embedding,
            k=request.top_k,
            where=_build_filters(request.doc_ids),
        )

        # Filter by minimum score
        search_results = [
            SearchResult(
                chunk_id=doc.metadata.get("chunk_id", "unknown"),
                doc_id=doc.metadata.get("document_id", "unknown"),
                document_name=doc.metadata.get("filename", "unknown"),
                text=doc.page_content,
                score=score,
                page_num=doc.metadata.get("page_num"),
                chunk_type=doc.metadata.get("chunk_type", "text"),
            )
            for doc, score in results
            if score >= request.min_score
        ]

        return SearchResponse(
            query=request.query,
            total_results=len(search_results),
            results=search_results,
            latency_ms=(time.time() - start_time) * 1000,
        )

    except Exception as e:
        logger.error(f"Search failed: {e}")
        # Fallback: return empty results (deliberate best-effort; the failure
        # is only visible in the log, not to the caller).
        return SearchResponse(
            query=request.query,
            total_results=0,
            results=[],
            latency_ms=(time.time() - start_time) * 1000,
        )


@router.get("/store/status", response_model=StoreStatus)
async def get_store_status():
    """Get vector store status and statistics."""
    try:
        from src.rag.store import get_vector_store

        store = get_vector_store()

        # Get collection info
        collection = store._collection
        count = collection.count()

        # Get unique documents. `or []` covers the key being present with a
        # None value, which would otherwise fail the whole status check.
        all_metadata = collection.get(include=["metadatas"])
        metadatas = all_metadata.get("metadatas") or []
        doc_ids = {
            meta["document_id"]
            for meta in metadatas
            if meta and "document_id" in meta
        }

        collections = [CollectionInfo(
            name=store.collection_name,
            document_count=len(doc_ids),
            chunk_count=count,
            # Falls back to 1024 when the store does not expose a dimension.
            embedding_dimension=getattr(store, "embedding_dimension", 1024),
        )]

        return StoreStatus(
            status="healthy",
            collections=collections,
            total_documents=len(doc_ids),
            total_chunks=count,
        )

    except Exception as e:
        logger.error(f"Store status check failed: {e}")
        return StoreStatus(
            status="error",
            collections=[],
            total_documents=0,
            total_chunks=0,
        )


@router.delete("/store/collection/{collection_name}")
async def clear_collection(collection_name: str, confirm: bool = Query(False)):
    """Clear a vector store collection (dangerous operation)."""
    # Responds 400 unless confirm=true, 404 when the name does not match the
    # active store's collection, and 500 when the delete itself fails.
    if not confirm:
        raise HTTPException(
            status_code=400,
            detail="This operation will delete all data. Set confirm=true to proceed."
        )

    try:
        from src.rag.store import get_vector_store

        store = get_vector_store()

        if store.collection_name != collection_name:
            raise HTTPException(status_code=404, detail=f"Collection not found: {collection_name}")

        # Clear collection
        # NOTE(review): assumes the backing collection treats an empty
        # `where` filter as "match everything" - confirm against the store's
        # client library version.
        store._collection.delete(where={})

        return {"status": "cleared", "collection": collection_name, "message": "Collection cleared successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Collection clear failed: {e}")
        raise HTTPException(status_code=500, detail=f"Clear failed: {str(e)}")


@router.get("/cache/stats")
async def get_cache_stats():
    """Get query cache statistics."""
    current_time = time.time()
    total_entries = len(_query_cache)
    valid_entries = sum(
        1 for entry in _query_cache.values()
        if current_time - entry["timestamp"] < CACHE_TTL_SECONDS
    )

    return {
        "total_entries": total_entries,
        "valid_entries": valid_entries,
        "expired_entries": total_entries - valid_entries,
        "ttl_seconds": CACHE_TTL_SECONDS,
    }


@router.delete("/cache")
async def clear_cache():
    """Clear the query cache."""
    count = len(_query_cache)
    _query_cache.clear()
    return {"status": "cleared", "entries_removed": count}