File size: 4,241 Bytes
4564881
617515c
 
1a2b9e6
 
 
 
 
55943c5
 
4564881
 
 
 
 
 
617515c
 
 
 
 
 
 
 
 
4564881
 
 
 
 
 
 
 
d7434b7
4564881
 
 
1a2b9e6
 
 
 
4564881
 
1a2b9e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617515c
1a2b9e6
 
 
 
 
 
 
 
f69a6fa
 
 
1a2b9e6
 
 
 
4564881
1a2b9e6
 
 
 
 
 
 
 
 
d7434b7
f69a6fa
 
1a2b9e6
 
 
 
 
 
 
dbb9b6d
4564881
f69a6fa
dbb9b6d
1a2b9e6
 
 
f69a6fa
 
1a2b9e6
55943c5
1a2b9e6
 
 
 
 
 
55943c5
1a2b9e6
 
 
 
 
617515c
1a2b9e6
4564881
1a2b9e6
 
617515c
1a2b9e6
55943c5
1a2b9e6
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from src.retrieval.rag_pipeline import AnimeRAGPipeline
from fastapi import FastAPI, HTTPException, status, Request, Security, Depends
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
from src.api.schemas import RecommendationRequest, RecommendationResponse
import time
from config import settings
import traceback
import logging
from contextlib import asynccontextmanager

# Configure root logging once at import time. basicConfig() returns None, so
# its result must not be bound to ``logger``; only getLogger() yields a logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Security scheme that pulls the caller's key out of the "X-API-Key" header.
_API_KEY_HEADER_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=_API_KEY_HEADER_NAME)


async def verify_api_key(api_key: str = Security(api_key_header)):
    """
    Dependency that validates the API key sent in the ``X-API-Key`` header.

    Args:
        api_key: Key extracted from the request by ``api_key_header``.

    Returns:
        The supplied key, unchanged, when it matches ``settings.fastapi_api_key``.

    Raises:
        HTTPException: 403 when the key does not match, or when no string key
            is configured in settings (fails closed).
    """
    # Local import keeps this block self-contained; the module is cached after
    # the first call.
    import secrets

    expected = settings.fastapi_api_key

    # Constant-time comparison so response timing does not leak how much of
    # the key matched. Bytes are compared because compare_digest() rejects
    # non-ASCII str arguments.
    key_is_valid = (
        isinstance(api_key, str)
        and isinstance(expected, str)
        and secrets.compare_digest(
            api_key.encode("utf-8"), expected.encode("utf-8"))
    )
    if not key_is_valid:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Invalid API Key")
    return api_key


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan hook.

    On startup, builds one AnimeRAGPipeline (retriever_k=50) and stores it on
    ``app.state.pipeline`` so request handlers can reuse it. On shutdown it
    only prints a message; no explicit resource cleanup is performed here.
    """
    print("Initializing Anime RAG Pipeline...")
    pipeline = AnimeRAGPipeline(retriever_k=50)
    app.state.pipeline = pipeline

    # Control returns here once the application is shutting down.
    yield

    print("Shutting down... Cleaning up resources.")


# The application object; ``lifespan`` creates the shared RAG pipeline.
app = FastAPI(title="Anime Recommendation API",
              description="RAG-powered anime recommendation system",
              version="1.0.0",
              lifespan=lifespan)


# CORS: any origin may call the API. Credentials (cookies / HTTP auth) are
# disabled because a wildcard origin must not be combined with
# allow_credentials=True. Authentication here is done with the X-API-Key
# header, which is covered by allow_headers and does not need credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Health-check endpoint: reports a static status, message and version."""
    payload = dict(
        status="online",
        message="Anime recommendation API",
        version="1.0.0",
    )
    return payload


@app.post("/recommend", response_model=RecommendationResponse)
async def get_recommendations(request: RecommendationRequest, fastapi_req: Request, _=Depends(verify_api_key)):
    """
    Get anime recommendations based on a user query.

    Collects the optional filters from the request, runs the pipeline stored
    on ``app.state.pipeline`` and wraps its result in a
    ``RecommendationResponse``.

    Example request:
    ```json
    {
        "query": "Anime similar to Death Note but lighter",
        "n_results": 5,
        "min_score": 7.5,
        "genre_filter": ["Comedy", "Fantasy"],
        "anime_type": "TV"
    }
    ```

    Raises:
        HTTPException: 500 carrying the error text when the pipeline call or
            response construction fails. An HTTPException raised inside the
            handler is propagated unchanged instead of being rewrapped.
    """
    try:
        rag_pipeline = fastapi_req.app.state.pipeline

        # NOTE(review): this writes to the pipeline instance shared by every
        # request; if requests are ever served concurrently they could
        # overwrite each other's value. Confirm whether recommend() can take
        # the result count as an argument instead.
        rag_pipeline.recommendation_n = request.n_results

        # Falsy values (None, 0, empty list/string) mean "no filter".
        requested_filters = {
            "min_score": request.min_score,
            "genre_filter": request.genre_filter,
            "anime_type": request.anime_type,
        }
        filters = {
            name: value for name, value in requested_filters.items() if value
        }

        # perf_counter() is monotonic, so the elapsed time cannot be skewed
        # by wall-clock adjustments.
        start_time = time.perf_counter()
        result = rag_pipeline.recommend(
            user_query=request.query,
            filters=filters if filters else None
        )
        elapsed = time.perf_counter() - start_time

        logger.info("Retrieved anime count: %s", result["reranked_count"])
        logger.info("Result recommendations: %s",
                    result["recommendations"][:20])

        return RecommendationResponse(
            query=result["query"],
            recommendations=result["recommendations"],
            retrieved_count=result["reranked_count"],
            retrieved_animes=result["retrieved_animes"],
            metadata={
                "model": settings.model_name,
                "retriever_k": rag_pipeline.retriever_k,
                "Time taken for LLM + vector search": str(elapsed)
            }
        )

    except HTTPException:
        # Keep deliberate HTTP errors (status and detail) intact.
        raise
    except Exception as e:
        # Logs the message together with the full traceback.
        logger.exception("Error processing recommendation request")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                            detail=f"Error processing request: {str(e)}") from e


@app.get("/stats")
async def get_stats(fastapi_req: Request, _=Depends(verify_api_key)):
    """
    Report basic system statistics.

    Reads the pipeline from ``app.state`` and returns the retriever's
    ``points_count``, the configured LLM model name, the retrieval ``k`` and
    a fixed embedding-model label.
    """
    pipeline = fastapi_req.app.state.pipeline

    stats = {}
    stats["total_anime"] = pipeline.retriever.points_count
    # The embedding model name is a fixed literal, not read from settings.
    stats["embedding_model"] = "all-MiniLM-L6-v2"
    stats["llm_model"] = settings.model_name
    stats["retrieval_k"] = pipeline.retriever_k
    return stats

if __name__ == "__main__":
    # Script entry point: serve the app via uvicorn on port 8000, on host
    # 0.0.0.0, with reload enabled.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
    }
    uvicorn.run("src.api.main:app", **server_options)