Spaces:
Running
Running
import logging
import secrets
import time
import traceback
from contextlib import asynccontextmanager

import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, Security, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field

from config import settings
from src.api.schemas import RecommendationRequest, RecommendationResponse
from src.retrieval.rag_pipeline import AnimeRAGPipeline
# Configure root logging once at import time. basicConfig() returns None, so
# its result is not bound to anything; the module logger comes from getLogger().
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Reads the client's API key from the "X-API-Key" request header.
api_key_header = APIKeyHeader(name="X-API-Key")


async def verify_api_key(api_key: str = Security(api_key_header)):
    """Reject the request unless the ``X-API-Key`` header matches the configured key.

    Args:
        api_key: Value of the ``X-API-Key`` header, injected by FastAPI via
            ``api_key_header``.

    Returns:
        The validated API key.

    Raises:
        HTTPException: 403 "Invalid API Key" when the header does not match
            ``settings.fastapi_api_key`` or when no key is configured.
    """
    expected = settings.fastapi_api_key
    # An unset/empty configured key must never authenticate anyone.
    # compare_digest runs in constant time so response timing does not reveal
    # how much of the key matched; compare bytes because str arguments must be
    # ASCII-only and would raise TypeError on arbitrary header values.
    if not expected or not secrets.compare_digest(
        api_key.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Invalid API Key")
    return api_key
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handle startup and shutdown of the Anime RAG Pipeline.

    On startup, builds an ``AnimeRAGPipeline`` (``retriever_k=50``) and stores it
    on ``app.state.pipeline`` so request handlers can reach it through
    ``request.app.state`` instead of a lazily-initialised global. On shutdown,
    logs a message.
    """
    logger.info("Initializing Anime RAG Pipeline...")
    app.state.pipeline = AnimeRAGPipeline(retriever_k=50)
    try:
        yield
    finally:
        # NOTE(review): nothing is actually released here yet; add pipeline
        # cleanup (clients, connections) if AnimeRAGPipeline holds any.
        logger.info("Shutting down... Cleaning up resources.")
# Application object; the pipeline is created and attached by `lifespan`.
app = FastAPI(
    title="Anime Recommendation API",
    description="RAG-powered anime recommendation system",
    version="1.0.0",
    lifespan=lifespan,
)

# Permissive CORS: any origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is an
# unusual/risky pairing. The API authenticates via the X-API-Key header rather
# than cookies, so credentials may not be needed; confirm before tightening.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
async def root():
    """Healthcheck: report that the service is up, with its name and API version."""
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on the app (presumably at "/").
    health = dict(
        status="online",
        message="Anime recommendation API",
        version="1.0.0",
    )
    return health
def _build_filters(request):
    """Return the optional filters set on *request* as a dict, or None if none are set."""
    filters = {}
    # Truthiness checks are kept on purpose: unset, empty or zero values all
    # mean "no filter" for the pipeline.
    if request.min_score:
        filters["min_score"] = request.min_score
    if request.genre_filter:
        filters["genre_filter"] = request.genre_filter
    if request.anime_type:
        filters["anime_type"] = request.anime_type
    return filters or None


async def get_recommendations(request: RecommendationRequest, fastapi_req: Request,
                              _=Depends(verify_api_key)):
    """
    Get anime recommendations based on a free-text user query.

    Uses the shared pipeline stored on ``app.state``, applies any optional
    filters from the request, and wraps the pipeline output in a
    ``RecommendationResponse`` together with model/timing metadata.
    Requires a valid ``X-API-Key`` header.

    Raises:
        HTTPException: 500 if the pipeline call or response construction
        fails; the detail carries the underlying error message.

    Example request:
    ```json
    {
        "query": "Anime similar to Death Note but lighter",
        "n_results": 5,
        "min_score": 7.5,
        "genre_filter": ["Comedy", "Fantasy"],
        "anime_type": "TV"
    }
    ```
    """
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on the app.
    try:
        rag_pipeline = fastapi_req.app.state.pipeline
        # NOTE(review): this mutates state shared by every request. It is only
        # race-free because recommend() below is called synchronously with no
        # await in between. That same synchronous call blocks the event loop for
        # the whole LLM + vector search. Offloading it to a thread would need a
        # lock, or a per-call n_results argument on the pipeline, to stay correct.
        rag_pipeline.recommendation_n = request.n_results
        filters = _build_filters(request)

        # perf_counter is monotonic, unlike time.time(), so it suits durations.
        start_time = time.perf_counter()
        result = rag_pipeline.recommend(
            user_query=request.query,
            filters=filters,
        )
        elapsed = time.perf_counter() - start_time

        # Lazy %-args; the previous nested-quote f-strings were a SyntaxError
        # on Python < 3.12.
        logger.info("Retrieved anime Count : \n%s", result["reranked_count"])
        logger.info("Result Recommendations: \n%s", result["recommendations"][:20])

        return RecommendationResponse(
            query=result["query"],
            recommendations=result["recommendations"],
            retrieved_count=result["reranked_count"],
            retrieved_animes=result["retrieved_animes"],
            metadata={
                "model": settings.model_name,
                "retriever_k": rag_pipeline.retriever_k,
                "Time taken for LLM + vector search": str(elapsed),
            },
        )
    except Exception as e:
        # Top-level boundary: log the full traceback, then surface a 500.
        logger.exception("Error processing recommendation request")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                            detail=f"Error processing request: {e}") from e
async def get_stats(fastapi_req: Request, _=Depends(verify_api_key)):
    """Return system statistics for the running pipeline.

    Reports the retriever's ``points_count`` (as ``total_anime``), the
    embedding and LLM model names, and the pipeline's ``retriever_k``.
    Requires a valid ``X-API-Key`` header.
    """
    pipeline = fastapi_req.app.state.pipeline
    # NOTE(review): the embedding model name is hard-coded here and may drift
    # from the model the retriever actually uses; confirm or read it from config.
    stats = dict(
        total_anime=pipeline.retriever.points_count,
        embedding_model="all-MiniLM-L6-v2",
        llm_model=settings.model_name,
        retrieval_k=pipeline.retriever_k,
    )
    return stats
if __name__ == "__main__":
    # Run the app directly with uvicorn; reload=True is intended for local development.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("src.api.main:app", **server_options)