Spaces:
Running
Running
File size: 4,241 Bytes
4564881 617515c 1a2b9e6 55943c5 4564881 617515c 4564881 d7434b7 4564881 1a2b9e6 4564881 1a2b9e6 617515c 1a2b9e6 f69a6fa 1a2b9e6 4564881 1a2b9e6 d7434b7 f69a6fa 1a2b9e6 dbb9b6d 4564881 f69a6fa dbb9b6d 1a2b9e6 f69a6fa 1a2b9e6 55943c5 1a2b9e6 55943c5 1a2b9e6 617515c 1a2b9e6 4564881 1a2b9e6 617515c 1a2b9e6 55943c5 1a2b9e6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | from src.retrieval.rag_pipeline import AnimeRAGPipeline
from fastapi import FastAPI, HTTPException, status, Request, Security, Depends
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
from src.api.schemas import RecommendationRequest, RecommendationResponse
import time
from config import settings
import traceback
import logging
from contextlib import asynccontextmanager
# Configure root logging at import time. basicConfig() returns None, so it is
# called purely for its side effect and its result is not bound to a name.
logging.basicConfig(level=logging.INFO)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Security scheme that reads the client's key from the "X-API-Key" header.
api_key_header = APIKeyHeader(name="X-API-Key")


async def verify_api_key(api_key: str = Security(api_key_header)):
    """Dependency that rejects requests presenting a wrong API key.

    Args:
        api_key: Value obtained through ``Security(api_key_header)``, i.e. the
            ``X-API-Key`` request header.

    Returns:
        The presented key when it equals ``settings.fastapi_api_key``.

    Raises:
        HTTPException: 403 with detail "Invalid API Key" when the keys differ
            or when either side is not a string.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import secrets

    expected = settings.fastapi_api_key

    # compare_digest() runs in time independent of where the inputs differ,
    # unlike ``!=``, so response timing does not leak how much of the key
    # matched. Comparing UTF-8 bytes avoids the TypeError that
    # compare_digest() raises for non-ASCII ``str`` arguments.
    keys_match = (
        isinstance(api_key, str)
        and isinstance(expected, str)
        and secrets.compare_digest(
            api_key.encode("utf-8"), expected.encode("utf-8")
        )
    )
    if not keys_match:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Invalid API Key"
        )
    return api_key
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handles startup and shutdown of the Anime RAG Pipeline.
    Replaces lazy global initialization with app state.

    On startup an ``AnimeRAGPipeline(retriever_k=50)`` is created and stored
    on ``app.state.pipeline``, where the request handlers read it. Control is
    then yielded to the application until it shuts down.
    """
    # Use the module logger rather than print() so these messages follow the
    # logging configuration set up for this module.
    logger.info("Initializing Anime RAG Pipeline...")
    app.state.pipeline = AnimeRAGPipeline(retriever_k=50)
    try:
        yield
    finally:
        # finally: the shutdown message is emitted even if an exception is
        # raised into the generator at the yield point.
        logger.info("Shutting down... Cleaning up resources.")
# Application object; `lifespan` builds the RAG pipeline on startup.
app = FastAPI(
    title="Anime Recommendation API",
    description="RAG-powered anime recommendation system",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: any origin, method and header is accepted.
# Credentialed CORS is disabled because it must not be combined with wildcard
# origins: FastAPI's CORS documentation requires explicit origins, methods and
# headers whenever allow_credentials is True. This API authenticates with the
# X-API-Key header, which wildcard allow_headers already permits.
# NOTE(review): if a browser client really needs cookies/credentials, list its
# exact origins in allow_origins and turn allow_credentials back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Healthcheck endpoint: returns a static status, message and version."""
    # Keys are inserted in the same order as the response has always used.
    payload = dict(
        status="online",
        message="Anime recommendation API",
        version="1.0.0",
    )
    return payload
@app.post("/recommend", response_model=RecommendationResponse)
async def get_recommendations(
    request: RecommendationRequest,
    fastapi_req: Request,
    _=Depends(verify_api_key),
):
    """
    Get anime recommendations based on a user query.

    Example request:
    ```json
    {
        "query": "Anime similar to Death Note but lighter",
        "n_results": 5,
        "min_score": 7.5,
        "genre_filter": ["Comedy", "Fantasy"],
        "anime_type": "TV"
    }
    ```
    """
    # Parameters:
    #   request     - parsed request body (query, n_results and the optional
    #                 min_score / genre_filter / anime_type filters).
    #   fastapi_req - raw request, used only to reach app.state.pipeline.
    #   _           - result of the API-key dependency; unused here.
    # Returns a RecommendationResponse built from the pipeline's result dict.
    # Raises HTTPException(500) for any exception raised while processing.
    try:
        rag_pipeline = fastapi_req.app.state.pipeline
        # NOTE(review): this writes to the pipeline object stored on app.state,
        # which every request shares - confirm that is safe if recommend() is
        # ever moved off the event loop.
        rag_pipeline.recommendation_n = request.n_results

        # Only filters the client actually supplied are forwarded. Falsy
        # values (None, 0, empty list/string) are treated as "not supplied".
        filters = {}
        if request.min_score:
            filters["min_score"] = request.min_score
        if request.genre_filter:
            filters["genre_filter"] = request.genre_filter
        if request.anime_type:
            filters["anime_type"] = request.anime_type

        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments the way time.time() can.
        started = time.perf_counter()
        result = rag_pipeline.recommend(
            user_query=request.query,
            filters=filters if filters else None,
        )
        elapsed = time.perf_counter() - started

        # Lazy %-style logging instead of print(): the previous f-strings
        # reused double quotes inside the replacement field, which is a
        # SyntaxError on Python versions before 3.12.
        logger.info("Retrieved anime count: %s", result["reranked_count"])
        logger.info(
            "Result recommendations: %s", result["recommendations"][:20]
        )

        return RecommendationResponse(
            query=result["query"],
            recommendations=result["recommendations"],
            retrieved_count=result["reranked_count"],
            retrieved_animes=result["retrieved_animes"],
            metadata={
                "model": settings.model_name,
                "retriever_k": rag_pipeline.retriever_k,
                "Time taken for LLM + vector search": str(elapsed),
            },
        )
    except Exception as e:
        # Top-level boundary for this endpoint: log the full traceback through
        # the module logger, then report a 500 chained to the original error.
        logger.exception("Error processing recommendation request")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error processing request: {str(e)}",
        ) from e
@app.get("/stats")
async def get_stats(fastapi_req: Request, _=Depends(verify_api_key)):
    """Get system statistics for the running pipeline.

    Reads the pipeline from ``app.state`` and reports the retriever's
    ``points_count``, the embedding model name, the configured LLM model name
    and the pipeline's ``retriever_k``. Requires a valid API key.
    """
    pipeline = fastapi_req.app.state.pipeline

    # Built key by key, in the same order as the response has always used.
    stats = {}
    stats["total_anime"] = pipeline.retriever.points_count
    stats["embedding_model"] = "all-MiniLM-L6-v2"
    stats["llm_model"] = settings.model_name
    stats["retrieval_k"] = pipeline.retriever_k
    return stats
# Script entry point: start uvicorn only when this file is executed directly.
if __name__ == "__main__":
    # The app is given as an import string and served on all interfaces,
    # port 8000, with reload enabled.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("src.api.main:app", **server_options)
|