import asyncio
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Load environment variables before importing project modules below, in case
# they read configuration from the environment at import time.
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import the existing RAG agent functionality (deliberately after load_dotenv).
from agent import RAGAgent  # noqa: E402

# Maximum accepted query length, in characters.
MAX_QUERY_LENGTH = 2000

# Allowed CORS origins, as a comma-separated list in CORS_ALLOW_ORIGINS.
# Defaults to "*" for development. NOTE: combined with allow_credentials=True a
# wildcard effectively lets any site make credentialed requests, so set explicit
# origins in production.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Create FastAPI app
app = FastAPI(
    title="RAG Agent API",
    description="API for RAG Agent with document retrieval and question answering",
    version="1.0.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Pydantic models
class QueryRequest(BaseModel):
    query: str


class ChatRequest(BaseModel):
    query: str
    # Not read by this service; optional so clients are not forced to send it,
    # while clients that still do continue to validate.
    message: Optional[str] = None
    session_id: str
    selected_text: Optional[str] = None
    query_type: str = "global"
    top_k: int = 5


class MatchedChunk(BaseModel):
    content: str
    url: str
    position: int
    similarity_score: float


class QueryResponse(BaseModel):
    answer: str
    sources: List[str]
    matched_chunks: List[MatchedChunk]
    error: Optional[str] = None
    status: str  # "success" or "error" (this module never emits "empty")
    query_time_ms: Optional[float] = None
    confidence: Optional[str] = None


class ChatResponse(BaseModel):
    response: str
    citations: List[Dict[str, str]]
    session_id: str
    query_type: str
    timestamp: str


class HealthResponse(BaseModel):
    status: str
    message: str


# Global RAG agent instance, set by startup_event.
rag_agent: Optional[RAGAgent] = None


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string (no offset suffix).

    Produces the same string as the deprecated ``datetime.utcnow().isoformat()``
    so the wire format of ``ChatResponse.timestamp`` is unchanged.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


def _validate_query(query: str) -> None:
    """Reject empty/whitespace-only or over-long queries.

    Raises:
        HTTPException: 400 if the query is blank or longer than MAX_QUERY_LENGTH.
    """
    if not query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    if len(query) > MAX_QUERY_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Query too long, maximum {MAX_QUERY_LENGTH} characters",
        )


async def _run_agent_query(query: str) -> Dict[str, Any]:
    """Run the synchronous ``RAGAgent.query_agent`` without blocking the event loop.

    The call is offloaded to the loop's default thread-pool executor so other
    requests (e.g. /health) keep being served while retrieval runs.
    NOTE(review): this assumes query_agent is safe to call from worker threads,
    possibly concurrently — confirm against the RAGAgent implementation.

    Returns:
        The agent's result, which callers treat as a dict (read via ``.get``).

    Raises:
        HTTPException: 503 if the agent has not been initialized.
    """
    if rag_agent is None:
        raise HTTPException(status_code=503, detail="RAG Agent is not initialized")
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, rag_agent.query_agent, query)


@app.on_event("startup")
async def startup_event():
    """Initialize the RAG agent on startup; a failure aborts application startup."""
    global rag_agent
    logger.info("Initializing RAG Agent...")
    try:
        rag_agent = RAGAgent()
        logger.info("RAG Agent initialized successfully")
    except Exception as e:
        logger.exception("Failed to initialize RAG Agent: %s", e)
        raise


@app.post("/ask", response_model=QueryResponse)
async def ask_rag(request: QueryRequest):
    """
    Process a user query through the RAG agent and return the response.

    Raises HTTPException(400) for an empty or over-long query and
    HTTPException(503) if the agent is not initialized. Any other failure is
    reported in-band as a QueryResponse with status="error".
    """
    logger.info("Processing query: %s...", request.query[:50])

    try:
        _validate_query(request.query)

        response = await _run_agent_query(request.query)
        error = response.get("error")
        query_time_ms = response.get("query_time_ms")

        formatted_response = QueryResponse(
            answer=response.get("answer") or "",
            sources=response.get("sources") or [],
            matched_chunks=[
                MatchedChunk(
                    content=chunk.get("content") or "",
                    url=chunk.get("url") or "",
                    position=chunk.get("position") or 0,
                    similarity_score=chunk.get("similarity_score") or 0.0,
                )
                for chunk in response.get("matched_chunks") or []
            ],
            error=error,
            status="error" if error else "success",
            query_time_ms=query_time_ms,
            confidence=response.get("confidence"),
        )

        # query_time_ms may be absent or None; formatting None with %.2f would
        # fail, so fall back to 0 (previously this turned a success into an error).
        logger.info("Query processed successfully in %.2fms", query_time_ms or 0)
        return formatted_response

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error processing query: %s", e)
        return QueryResponse(
            answer="",
            sources=[],
            matched_chunks=[],
            error=str(e),
            status="error",
        )


@app.post("/api", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """
    Main chat endpoint that handles conversation with RAG capabilities.

    Raises HTTPException(400) for an empty/over-long query or empty session ID,
    and HTTPException(503) if the agent is not initialized. Any other failure
    is logged and answered with an empty ChatResponse.
    """
    logger.info("Processing chat query: %s...", request.query[:50])

    try:
        _validate_query(request.query)
        if not request.session_id.strip():
            raise HTTPException(status_code=400, detail="Session ID cannot be empty")

        response = await _run_agent_query(request.query)

        # Only the chunk URL is available from the agent; the remaining citation
        # fields are filled with empty strings to keep the expected shape.
        citations = [
            {
                "document_id": "",
                "title": chunk.get("url") or "",
                "chapter": "",
                "section": "",
                "page_reference": "",
            }
            for chunk in response.get("matched_chunks") or []
        ]

        formatted_response = ChatResponse(
            response=response.get("answer") or "",
            citations=citations,
            session_id=request.session_id,
            query_type=request.query_type,
            timestamp=_utc_timestamp(),
        )

        logger.info("Chat query processed successfully")
        return formatted_response

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error processing chat query: %s", e)
        return ChatResponse(
            response="",
            citations=[],
            session_id=request.session_id,
            query_type=request.query_type,
            timestamp=_utc_timestamp(),
        )


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy",
        message="RAG Agent API is running",
    )


# For running with uvicorn
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))