# AnimeRAGSystem / src/api/main.py
# Author: Pushkar02-n — commit 617515c ("Add remaining changes to previous commit")
from src.retrieval.rag_pipeline import AnimeRAGPipeline
from fastapi import FastAPI, HTTPException, status, Request, Security, Depends
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
from src.api.schemas import RecommendationRequest, RecommendationResponse
import time
from config import settings
import traceback
import logging
from contextlib import asynccontextmanager
# basicConfig() configures the root logger and returns None, so its result
# must never be bound to `logger`; the module logger comes from getLogger().
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Header scheme: clients must present their key in the "X-API-Key" header.
api_key_header = APIKeyHeader(name="X-API-Key")


async def verify_api_key(api_key: str = Security(api_key_header)):
    """Dependency that authenticates a request via its X-API-Key header.

    Returns the key when it matches the configured ``fastapi_api_key``;
    otherwise rejects the request with a 403.
    """
    if api_key == settings.fastapi_api_key:
        return api_key
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN, detail="Invalid API Key")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handles startup and shutdown of the Anime RAG Pipeline.

    Builds the pipeline once at startup and stores it on ``app.state`` so
    all request handlers share a single instance (replaces lazy global
    initialization). Uses the module logger instead of print() so startup
    messages go through the configured logging setup.
    """
    logger.info("Initializing Anime RAG Pipeline...")
    app.state.pipeline = AnimeRAGPipeline(retriever_k=50)
    yield
    logger.info("Shutting down... Cleaning up resources.")
# Application instance; `lifespan` owns pipeline startup/shutdown.
app = FastAPI(title="Anime Recommendation API",
              description="RAG-powered anime recommendation system",
              version="1.0.0",
              lifespan=lifespan)
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec (credentialed requests require an
# explicit origin). Consider listing concrete origins for production — confirm
# with the deployment setup.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)
@app.get("/")
async def root():
    """Healthcheck Endpoint"""
    payload = {
        "status": "online",
        "message": "Anime recommendation API",
        "version": "1.0.0",
    }
    return payload
@app.post("/recommend", response_model=RecommendationResponse)
async def get_recommendations(request: RecommendationRequest, fastapi_req: Request, _=Depends(verify_api_key)):
    """
    Get anime recommendations based on a user query.

    Example request:
    ```json
    {
        "query": "Anime similar to Death Note but lighter",
        "n_results": 5,
        "min_score": 7.5,
        "genre_filter": ["Comedy", "Fantasy"],
        "anime_type": "TV"
    }
    ```

    Raises:
        HTTPException(500): if the pipeline fails for any reason.
    """
    try:
        rag_pipeline = fastapi_req.app.state.pipeline
        rag_pipeline.recommendation_n = request.n_results

        # Build the optional metadata filter. Use `is not None` (not
        # truthiness) so a legitimate min_score of 0.0 is not dropped.
        filters = {}
        if request.min_score is not None:
            filters["min_score"] = request.min_score
        if request.genre_filter:
            filters["genre_filter"] = request.genre_filter
        if request.anime_type:
            filters["anime_type"] = request.anime_type

        start_time = time.time()
        result = rag_pipeline.recommend(
            user_query=request.query,
            filters=filters if filters else None
        )
        end_time = time.time()

        # The original f-strings nested double quotes inside double-quoted
        # f-strings — a SyntaxError on Python < 3.12. Use lazy %-style
        # logging instead of print().
        logger.info("Retrieved anime count: %s", result["reranked_count"])
        logger.info("Result recommendations: %s", result["recommendations"][:20])

        return RecommendationResponse(
            query=result["query"],
            recommendations=result["recommendations"],
            retrieved_count=result["reranked_count"],
            retrieved_animes=result["retrieved_animes"],
            metadata={
                "model": settings.model_name,
                "retriever_k": rag_pipeline.retriever_k,
                "Time taken for LLM + vector search": str(end_time - start_time)
            }
        )
    except Exception as e:
        # logger.exception records the full traceback; the client only
        # sees a generic 500 with the error message.
        logger.exception("Recommendation request failed")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                            detail=f"Error processing request: {str(e)}")
@app.get("/stats")
async def get_stats(fastapi_req: Request, _=Depends(verify_api_key)):
    """Get system statistics"""
    pipeline = fastapi_req.app.state.pipeline
    stats = {
        "total_anime": pipeline.retriever.points_count,
        "embedding_model": "all-MiniLM-L6-v2",
        "llm_model": settings.model_name,
        "retrieval_k": pipeline.retriever_k,
    }
    return stats
if __name__ == "__main__":
    # Dev entry point: import-string target + reload=True enables
    # auto-restart on code changes.
    uvicorn.run(
        app="src.api.main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )