# AnimeRAGSystem / src/retrieval/rag_pipeline.py
# Author: Pushkar02-n — "Final changes in local" (commit d7434b7)
import json
from src.retrieval.vector_search import AnimeRetriever
from src.llm.anime_reranker import AnimeReranker
from src.llm.groq_client import GroqLLM
from src.llm.prompts import create_recommendation_prompt, create_system_prompt, ANIME_SEARCH_TOOL, ROUTER_SYSTEM_PROMPT
import logging
from config import settings
logger = logging.getLogger(__name__)
class AnimeRAGPipeline:
    """
    End-to-end anime recommendation pipeline.

    Flow: an LLM router decides whether a vector search is needed; if so,
    candidates are retrieved from ChromaDB, reranked with a Cross-Encoder
    plus Bayesian scoring, and a final LLM call writes the recommendations.
    """

    def __init__(
        self,
        retriever: AnimeRetriever | None = None,
        reranker: AnimeReranker | None = None,
        llm: GroqLLM | None = None,
        retriever_k: int = 50,
        recommendation_n: int | None = 5
    ):
        """
        Initialize RAG Pipeline

        Args:
            retriever: AnimeRetriever instance (created if None)
            reranker: AnimeReranker instance (created if None)
            llm: GroqLLM instance (created if None)
            retriever_k: How many anime to retrieve from vector search
            recommendation_n: How many to recommend in final output
        """
        self.retriever = retriever or AnimeRetriever()
        self.reranker = reranker or AnimeReranker()
        self.llm = llm or GroqLLM(model=settings.model_name)
        self.retriever_k = retriever_k
        self.recommendation_n = recommendation_n
        logger.info("RAG Pipeline initialized")
        logger.info(f" - Retrieve top {retriever_k} anime from vector search")
        logger.info(" - Rerank and filter using Cross-Encoder & Bayesian Math")
        logger.info(
            f" - LLM reasons and recommends top {recommendation_n} anime")

    def recommend(
        self,
        user_query: str,
        filters: dict | None = None
    ) -> dict:
        """
        Get anime recommendations from user query

        Args:
            user_query: User's request (e.g., "Anime like death note for lighter")
            filters: Optional filters (min_score, genre_filter, anime_type)

        Returns:
            Dict with:
                - query: original query
                - optimized_query: search query the LLM produced (only when it searched)
                - retrieved_count_from_DB: how many anime were retrieved from ChromaDB
                - reranked_count: how many survived reranking
                - recommendations: LLM response text
                - retrieved_animes: reranked retrieval results (for debugging)
        """
        logger.info(f"\n----Processing query: {user_query}-----\n")
        filters = filters or {}

        # [STEP 1] The Agentic Decision Call: the router LLM chooses between
        # answering directly and invoking the anime-search tool.
        logger.info("[1/5] Asking LLM if it needs to search...")
        initial_response = self.llm.chat_with_tools(
            messages=[{"role": "user", "content": user_query}],
            tools=ANIME_SEARCH_TOOL,
            system_prompt=ROUTER_SYSTEM_PROMPT
        )

        # [STEP 2] Total API failure -> apologetic fallback payload with the
        # same shape as a successful response, so callers need no special case.
        if not initial_response:
            logger.error("Groq API failed completely.")
            return {
                "query": user_query,
                "recommendations": "Sorry, I'm having trouble processing your query. Can you be more clear and try again?",
                "retrieved_count_from_DB": 0,
                "reranked_count": 0,
                "retrieved_animes": []
            }

        # No tool call means the router answered conversationally.
        if not initial_response.tool_calls:
            logger.info(
                "[2/5] No search needed. Returning conversational response.")
            return {
                "query": user_query,
                "retrieved_count_from_DB": 0,
                "reranked_count": 0,
                "recommendations": initial_response.content,
                "retrieved_animes": []
            }

        # [STEP 3] The LLM wants to search. Extract its optimized parameters.
        logger.info("[2/5] Tool called! Executing vector search...")
        tool_call = initial_response.tool_calls[0]
        tool_args = json.loads(tool_call.function.arguments)
        logger.info(f"Tool called: [{tool_call}] with args: [{tool_args}]\n")
        # Fall back to the raw user query if the tool args are malformed.
        optimized_query = tool_args.get("optimized_query", user_query)

        # 3A: Fetch top `retriever_k` anime from ChromaDB.
        retrieved_animes = self.retriever.search(
            query=optimized_query,
            n_results=self.retriever_k,
            **filters
        )

        # 3B: Rerank the candidates (Cross-Encoder + Bayesian math), keep top 10.
        logger.info("[3/5] Reranking results with Cross-Encoder")
        reranked_df = self.reranker.process(
            user_query=optimized_query,
            retrieved_anime=retrieved_animes,
            top_k=10
        )
        top_animes_list = reranked_df.to_dict(
            orient="records") if not reranked_df.empty else []
        # FIX: was a stray print(); use the module logger like everywhere else.
        logger.info(
            f"After reranking, fetched {len(top_animes_list)} top animes....")

        # [STEP 4] The Final Recommendation Call
        logger.info("[4/5] Creating prompt with retrieved content...")
        prompt = create_recommendation_prompt(
            user_query=user_query,
            retrieved_animes=top_animes_list,
            n_recommendations=self.recommendation_n
        )
        logger.info("[5/5] LLM generating final response...")
        system_prompt = create_system_prompt()
        recommendations = self.llm.generate(
            prompt=prompt,
            system_prompt=system_prompt,
            temperature=0.5,
            max_tokens=1500
        )
        logger.info("\n---Generated Recommendations---")
        return {
            "query": user_query,
            # Backward-compatible addition: expose the router's rewritten query
            # so diagnostics (see __main__) don't have to guess it.
            "optimized_query": optimized_query,
            "retrieved_count_from_DB": len(retrieved_animes),
            "reranked_count": len(top_animes_list),
            "recommendations": recommendations,
            "retrieved_animes": top_animes_list
        }

    def recommend_streaming(self, user_query: str, filters: dict | None = None):
        """
        Streaming version for real-time display.

        NOTE: currently a plain passthrough to recommend(); real token
        streaming is intended for the FastAPI layer later.
        """
        return self.recommend(user_query, filters)
if __name__ == "__main__":
    import json
    import os
    import time

    # Ensure the data directory exists
    os.makedirs("data", exist_ok=True)

    pipeline = AnimeRAGPipeline(
        retriever_k=50,
        recommendation_n=5
    )

    test_queries = [
        "Anime similar to Death Note but lighter in tone",
        # Test the hidden gem / passion rate
        "A really obscure and weird sci-fi mecha from the 90s",
        "A generic isekai with an overpowered main character"  # Test the mainstream math
    ]

    for query in test_queries:
        print("\n" + "=" * 80)
        print(f"TESTING QUERY: '{query}'")
        print("=" * 80)

        # 1. Run the pipeline (LLM routing, retrieval, reranking, generation).
        result = pipeline.recommend(user_query=query)

        # BUG FIX: recommend() returns "retrieved_count_from_DB", not
        # "retrieved_count" — the old key raised KeyError on every run.
        retrieved_count = result.get("retrieved_count_from_DB", 0)

        # 2. Build the diagnostic report. If the LLM chose NOT to search,
        # retrieved_count is 0 and the detailed lists stay empty.
        diagnostic_data = {
            "1_initial_user_query": result["query"],
            "2_llm_routing_decision": "Searched" if retrieved_count > 0 else "Chatted",
            "3_total_retrieved_from_chroma": retrieved_count,
            "4_total_survived_reranking": result.get("reranked_count", 0),
            "5_final_llm_recommendation_text": result["recommendations"],
            # The "Before" state: raw ChromaDB embedding hits (not exposed by
            # recommend(), so this stays empty for now).
            "6_raw_chroma_results_before_reranking": [],
            # The "After" state: exact math scores for the top survivors.
            "7_reranked_results_with_math": []
        }

        # Expose every reranking score for the survivors.
        if retrieved_count > 0:
            for rank, anime in enumerate(result["retrieved_animes"]):
                diagnostic_data["7_reranked_results_with_math"].append({
                    "rank": rank + 1,
                    "title": anime["title"],
                    "hybrid_score_final": round(anime.get("final_hybrid_score", 0), 4),
                    "semantic_score_raw": round(anime.get("semantic_score", 0), 4),
                    "quality_score_raw": round(anime.get("raw_quality_score", 0), 4),
                    "bayesian_average": round(anime.get("bayesian_score", 0), 4),
                    "passion_rate": round(anime.get("passion_rate", 0), 5),
                    # Showing original vector distance
                    "chroma_distance": round(1 - anime.get("relevance_score", 1), 4)
                })

        # 3. Save the diagnostic report
        safe_filename = f"diagnostic_{query[:20].replace(' ', '_').lower()}.json"
        filepath = os.path.join("data", safe_filename)
        with open(filepath, "w") as f:
            json.dump(diagnostic_data, f, indent=2)

        print(f"\n✅ Diagnostic saved to {filepath}")
        print(f"Generated Output:\n{result['recommendations']}")

        # Pause to respect Groq API rate limits
        time.sleep(3)