Spaces:
Build error
Build error
| import os | |
| import asyncio | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List, Optional, Dict | |
| from dotenv import load_dotenv | |
| import logging | |
# Populate the process environment from a .env file (python-dotenv), then
# configure root logging at INFO and create this module's logger.
load_dotenv()
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
| # Import the existing RAG agent functionality | |
| from agent import RAGAgent | |
# Create FastAPI app
app = FastAPI(
    title="RAG Agent API",
    description="API for RAG Agent with document retrieval and question answering",
    version="1.0.0",
)


def _cors_allowed_origins():
    """Return the list of origins the CORS middleware should allow.

    Read from the comma-separated ``CORS_ALLOW_ORIGINS`` environment variable
    (e.g. ``https://a.example,https://b.example``); blank entries are ignored.
    When the variable is unset or yields no entries, fall back to ``["*"]``,
    the previous hard-coded development setting.
    """
    raw = os.getenv("CORS_ALLOW_ORIGINS", "")
    origins = [origin.strip() for origin in raw.split(",") if origin.strip()]
    return origins or ["*"]


# CORS middleware. The wildcard default combined with allow_credentials=True is
# a development-only configuration; set CORS_ALLOW_ORIGINS to an explicit list
# of origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allowed_origins(),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models
class QueryRequest(BaseModel):
    """Input model for ask_rag: a single free-text question."""

    # ask_rag rejects a blank value or one longer than 2000 characters.
    query: str
class ChatRequest(BaseModel):
    """Input model for chat_endpoint.

    NOTE(review): ``message`` is required by validation, but chat_endpoint
    only reads ``query``, ``session_id`` and ``query_type``; ``message``,
    ``selected_text`` and ``top_k`` are accepted but unused there. Confirm
    whether clients are really expected to send both ``query`` and ``message``.
    """

    query: str  # the question passed to the RAG agent
    message: str  # required, but not read by chat_endpoint
    session_id: str  # must be non-blank; echoed back in ChatResponse
    selected_text: Optional[str] = None  # not read by chat_endpoint
    query_type: str = "global"  # echoed back in ChatResponse.query_type
    top_k: int = 5  # not read by chat_endpoint
class MatchedChunk(BaseModel):
    """One retrieved chunk, as listed in ``QueryResponse.matched_chunks``.

    ask_rag fills keys missing from the agent's chunk dict with
    ``""``, ``""``, ``0`` and ``0.0`` respectively.
    """

    content: str
    url: str
    position: int  # presumably the chunk's index within its source -- TODO confirm
    similarity_score: float  # agent-reported score; scale/direction not shown here
class QueryResponse(BaseModel):
    """Result model returned by ask_rag.

    Input-validation failures are raised as HTTP errors rather than returned
    here; other failures come back with ``status="error"`` and ``error`` set.
    """

    answer: str
    sources: List[str]
    matched_chunks: List[MatchedChunk]
    error: Optional[str] = None
    status: str # "success", "error", "empty" (ask_rag only emits "success"/"error")
    query_time_ms: Optional[float] = None  # agent-reported elapsed time in milliseconds
    confidence: Optional[str] = None  # passed through from the agent; format not shown here
class ChatResponse(BaseModel):
    """Result model returned by chat_endpoint.

    There is no error field: on an internal error chat_endpoint returns this
    model with an empty ``response`` and no citations.
    """

    response: str
    # chat_endpoint builds one dict per matched chunk with the keys
    # document_id, title, chapter, section and page_reference.
    citations: List[Dict[str, str]]
    session_id: str  # echoed from the request
    query_type: str  # echoed from the request
    timestamp: str  # ISO-8601 UTC time without an offset suffix
class HealthResponse(BaseModel):
    """Result model returned by health_check."""

    status: str
    message: str
# Process-wide RAGAgent shared by the handlers; stays None until
# startup_event assigns it.
rag_agent: Optional["RAGAgent"] = None
async def startup_event():
    """Construct the shared RAGAgent and store it in the module-level rag_agent.

    If construction fails, the error is logged and re-raised unchanged and
    rag_agent is left as it was.

    NOTE(review): no ``@app.on_event("startup")`` (or lifespan) registration is
    visible alongside this function; if none exists elsewhere, rag_agent is
    never set and the handlers will see None -- confirm.
    """
    global rag_agent
    logger.info("Initializing RAG Agent...")
    try:
        agent = RAGAgent()
    except Exception as exc:
        logger.error(f"Failed to initialize RAG Agent: {exc}")
        raise
    rag_agent = agent
    logger.info("RAG Agent initialized successfully")
async def ask_rag(request: QueryRequest):
    """
    Process a user query through the RAG agent and return the response.

    Validation failures (blank query, query over 2000 characters) raise
    ``HTTPException`` with status 400. Any other failure -- including the
    agent not having been initialized -- is reported in-band as a
    ``QueryResponse`` with ``status="error"`` and the message in ``error``.
    """
    logger.info(f"Processing query: {request.query[:50]}...")
    try:
        # Validate input
        if not request.query.strip():
            raise HTTPException(status_code=400, detail="Query cannot be empty")
        if len(request.query) > 2000:
            raise HTTPException(status_code=400, detail="Query too long, maximum 2000 characters")

        # rag_agent stays None until startup_event succeeds; fail with a clear
        # message instead of "'NoneType' object has no attribute ...".
        if rag_agent is None:
            raise RuntimeError("RAG agent is not initialized")

        # query_agent is synchronous. Run it in the default thread pool so a
        # slow retrieval/LLM call does not stall the event loop (and every
        # other in-flight request) for its whole duration.
        # NOTE(review): this permits concurrent query_agent calls -- confirm
        # RAGAgent is safe to share across threads.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, rag_agent.query_agent, request.query)

        query_time_ms = response.get("query_time_ms")
        formatted_response = QueryResponse(
            answer=response.get("answer") or "",
            sources=response.get("sources") or [],
            matched_chunks=[
                MatchedChunk(
                    content=chunk.get("content", ""),
                    url=chunk.get("url", ""),
                    position=chunk.get("position", 0),
                    similarity_score=chunk.get("similarity_score", 0.0),
                )
                for chunk in response.get("matched_chunks") or []
            ],
            error=response.get("error"),
            status="error" if response.get("error") else "success",
            query_time_ms=query_time_ms,
            confidence=response.get("confidence"),
        )

        # query_time_ms may be present but None; formatting None with ".2f"
        # would raise here and turn a completed query into an error response.
        elapsed_ms = query_time_ms or 0
        logger.info(f"Query processed successfully in {elapsed_ms:.2f}ms")
        return formatted_response
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception(f"Error processing query: {e}")
        return QueryResponse(
            answer="",
            sources=[],
            matched_chunks=[],
            error=str(e),
            status="error",
        )
async def chat_endpoint(request: ChatRequest):
    """
    Main chat endpoint that handles conversation with RAG capabilities.

    Validation failures (blank query, blank session_id, query over 2000
    characters) raise ``HTTPException`` with status 400. Any other failure --
    including the agent not having been initialized -- is logged and answered
    with an empty ``ChatResponse`` (``response=""``, no citations), because
    ``ChatResponse`` has no error field.
    """
    from datetime import datetime, timezone

    def _utc_timestamp():
        # Same naive ISO-8601 UTC string that datetime.utcnow().isoformat()
        # produced, without the API deprecated since Python 3.12.
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    logger.info(f"Processing chat query: {request.query[:50]}...")
    try:
        # Validate input
        if not request.query.strip():
            raise HTTPException(status_code=400, detail="Query cannot be empty")
        if not request.session_id.strip():
            raise HTTPException(status_code=400, detail="Session ID cannot be empty")
        if len(request.query) > 2000:
            raise HTTPException(status_code=400, detail="Query too long, maximum 2000 characters")

        # rag_agent stays None until startup_event succeeds.
        if rag_agent is None:
            raise RuntimeError("RAG agent is not initialized")

        # query_agent is synchronous; keep it off the event loop.
        # NOTE(review): this permits concurrent query_agent calls -- confirm
        # RAGAgent is safe to share across threads.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, rag_agent.query_agent, request.query)

        # ChatResponse cannot carry the agent's error, so at least record it.
        if response.get("error"):
            logger.warning(f"RAG agent reported an error for chat query: {response.get('error')}")

        # Convert matched chunks to the citation format. `or ""` guards against
        # a None url, which would fail Dict[str, str] validation.
        citations = [
            {
                "document_id": "",
                "title": chunk.get("url") or "",
                "chapter": "",
                "section": "",
                "page_reference": "",
            }
            for chunk in response.get("matched_chunks") or []
        ]

        formatted_response = ChatResponse(
            response=response.get("answer") or "",
            citations=citations,
            session_id=request.session_id,
            query_type=request.query_type,
            timestamp=_utc_timestamp(),
        )
        logger.info("Chat query processed successfully")
        return formatted_response
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception(f"Error processing chat query: {e}")
        return ChatResponse(
            response="",
            citations=[],
            session_id=request.session_id,
            query_type=request.query_type,
            timestamp=_utc_timestamp(),
        )
async def health_check():
    """Liveness check: report that the API process is up and serving.

    The reply is constant; it does not inspect rag_agent.
    """
    payload = {"status": "healthy", "message": "RAG Agent API is running"}
    return HealthResponse(**payload)
# For running with uvicorn
if __name__ == "__main__":
    import uvicorn

    # HOST and PORT may be overridden from the environment (defaults keep the
    # previous 0.0.0.0:8000), so hosts that assign the port need no code edit.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
    )