import logging
import time
import traceback
from contextlib import asynccontextmanager

import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, Security, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader

from config import settings
from src.api.schemas import RecommendationRequest, RecommendationResponse
from src.retrieval.rag_pipeline import AnimeRAGPipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

api_key_header = APIKeyHeader(name="X-API-Key")


async def verify_api_key(api_key: str = Security(api_key_header)):
    """Reject requests whose X-API-Key header does not match the configured key."""
    if api_key != settings.fastapi_api_key:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key",
        )
    return api_key


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handle startup and shutdown of the Anime RAG Pipeline.

    Replaces lazy global initialization with app state.
    """
    logger.info("Initializing Anime RAG Pipeline...")
    app.state.pipeline = AnimeRAGPipeline(retriever_k=50)
    yield
    logger.info("Shutting down... Cleaning up resources.")


app = FastAPI(
    title="Anime Recommendation API",
    description="RAG-powered anime recommendation system",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "status": "online",
        "message": "Anime recommendation API",
        "version": "1.0.0",
    }


@app.post("/recommend", response_model=RecommendationResponse)
async def get_recommendations(
    request: RecommendationRequest,
    fastapi_req: Request,
    _=Depends(verify_api_key),
):
    """
    Get anime recommendations based on a user query.

    Example request:
    ```json
    {
        "query": "Anime similar to Death Note but lighter",
        "n_results": 5,
        "min_score": 7.5,
        "genre_filter": ["Comedy", "Fantasy"],
        "anime_type": "TV"
    }
    ```
    """
    try:
        rag_pipeline = fastapi_req.app.state.pipeline
        # Note: this mutates shared pipeline state stored on app.state,
        # so concurrent requests see the most recently set value.
        rag_pipeline.recommendation_n = request.n_results

        # Build the optional metadata filters from the request fields.
        filters = {}
        if request.min_score:
            filters["min_score"] = request.min_score
        if request.genre_filter:
            filters["genre_filter"] = request.genre_filter
        if request.anime_type:
            filters["anime_type"] = request.anime_type

        start_time = time.time()
        result = rag_pipeline.recommend(
            user_query=request.query,
            filters=filters if filters else None,
        )
        end_time = time.time()

        logger.info("Retrieved anime count: %s", result["reranked_count"])
        logger.info("Result recommendations: %s", result["recommendations"][:20])

        return RecommendationResponse(
            query=result["query"],
            recommendations=result["recommendations"],
            retrieved_count=result["reranked_count"],
            retrieved_animes=result["retrieved_animes"],
            metadata={
                "model": settings.model_name,
                "retriever_k": rag_pipeline.retriever_k,
                "Time taken for LLM + vector search": str(end_time - start_time),
            },
        )
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error processing request: {str(e)}",
        )


@app.get("/stats")
async def get_stats(fastapi_req: Request, _=Depends(verify_api_key)):
    """Get system statistics."""
    rag_pipeline = fastapi_req.app.state.pipeline
    return {
        "total_anime": rag_pipeline.retriever.points_count,
        "embedding_model": "all-MiniLM-L6-v2",
        "llm_model": settings.model_name,
        "retrieval_k": rag_pipeline.retriever_k,
    }


if __name__ == "__main__":
    uvicorn.run(
        "src.api.main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )
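
# Example client call (a minimal sketch, kept as comments so it never runs on import;
# it assumes the server is running locally on port 8000, that "<your-api-key>" is a
# placeholder for the value in settings.fastapi_api_key, and that the payload fields
# mirror RecommendationRequest):
#
#     import httpx
#
#     response = httpx.post(
#         "http://localhost:8000/recommend",
#         headers={"X-API-Key": "<your-api-key>"},
#         json={
#             "query": "Anime similar to Death Note but lighter",
#             "n_results": 5,
#             "min_score": 7.5,
#         },
#         timeout=60.0,
#     )
#     response.raise_for_status()
#     print(response.json()["recommendations"])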