Spaces:
Sleeping
Sleeping
| import json | |
| from src.retrieval.vector_search import AnimeRetriever | |
| from src.llm.anime_reranker import AnimeReranker | |
| from src.llm.groq_client import GroqLLM | |
| from src.llm.prompts import create_recommendation_prompt, create_system_prompt, ANIME_SEARCH_TOOL, ROUTER_SYSTEM_PROMPT | |
| import logging | |
| from config import settings | |
| logger = logging.getLogger(__name__) | |
class AnimeRAGPipeline:
    """
    Agentic RAG pipeline for anime recommendations.

    Flow: an LLM router decides whether the query needs a search; if so,
    candidates are fetched from the vector store, reranked (cross-encoder
    plus Bayesian quality math), and a final LLM call writes the
    recommendation text.
    """

    def __init__(
        self,
        retriever: AnimeRetriever | None = None,
        reranker: AnimeReranker | None = None,
        llm: GroqLLM | None = None,
        retriever_k: int = 50,
        recommendation_n: int | None = 5
    ):
        """
        Initialize RAG Pipeline

        Args:
            retriever: AnimeRetriever instance (created if None)
            reranker: AnimeReranker instance (created if None)
            llm: GroqLLM instance (created if None)
            retriever_k: How many anime to retrieve from vector search
            recommendation_n: How many to recommend in final output
        """
        self.retriever = retriever or AnimeRetriever()
        self.reranker = reranker or AnimeReranker()
        self.llm = llm or GroqLLM(model=settings.model_name)
        self.retriever_k = retriever_k
        self.recommendation_n = recommendation_n
        logger.info("RAG Pipeline initialized")
        logger.info(f" - Retrieve top {retriever_k} anime from vector search")
        logger.info(f" - Rerank and filter using Cross-Encoder & Bayesian Math")
        logger.info(
            f" - LLM reasons and recommends top {recommendation_n} anime")

    def recommend(
        self,
        user_query: str,
        filters: dict | None = None
    ) -> dict:
        """
        Get anime recommendations from user query

        Args:
            user_query: User's request (e.g., "Anime like Death Note but lighter")
            filters: Optional filters (min_score, genre_filter, anime_type)

        Returns:
            Dict with:
                - query: original query
                - retrieved_count_from_DB: how many anime the vector search returned
                - reranked_count: how many survived reranking
                - recommendations: LLM response text
                - retrieved_animes: reranked results (for debugging)
                - optimized_query: LLM-rewritten search query (present only
                  when a search was actually performed)
        """
        logger.info(f"\n----Processing query: {user_query}-----\n")
        filters = filters or {}

        # [STEP 1] The agentic decision call: let the LLM route the query.
        logger.info("[1/5] Asking LLM if it needs to search...")
        initial_response = self.llm.chat_with_tools(
            messages=[{"role": "user", "content": user_query}],
            tools=ANIME_SEARCH_TOOL,
            system_prompt=ROUTER_SYSTEM_PROMPT
        )

        # [STEP 2] API failure: degrade gracefully instead of raising.
        if not initial_response:
            logger.error("Groq API failed completely.")
            return {
                "query": user_query,
                "recommendations": "Sorry, I'm having trouble processing your query. Can you be more clear and try again?",
                "retrieved_count_from_DB": 0,
                "reranked_count": 0,
                "retrieved_animes": []
            }

        # No tool call -> purely conversational answer; skip retrieval entirely.
        if not initial_response.tool_calls:
            logger.info(
                "[2/5] No search needed. Returning conversational response.")
            return {
                "query": user_query,
                "retrieved_count_from_DB": 0,
                "reranked_count": 0,
                "recommendations": initial_response.content,
                "retrieved_animes": []
            }

        # [STEP 3] The LLM wants to search. Extract its optimized parameters.
        logger.info("[2/5] Tool called! Executing vector search...")
        tool_call = initial_response.tool_calls[0]
        tool_args = json.loads(tool_call.function.arguments)
        logger.info(f"Tool called: [{tool_call}] with args: [{tool_args}]\n")
        optimized_query = tool_args.get("optimized_query", user_query)

        # 3A: Fetch top `retriever_k` anime from ChromaDB.
        retrieved_animes = self.retriever.search(
            query=optimized_query,
            n_results=self.retriever_k,
            **filters
        )

        # 3B: Rerank the candidates using the cross-encoder + Bayesian math.
        logger.info("[3/5] Reranking results with Cross-Encoder")
        reranked_df = self.reranker.process(
            user_query=optimized_query,
            retrieved_anime=retrieved_animes,
            top_k=10
        )
        top_animes_list = reranked_df.to_dict(
            orient="records") if not reranked_df.empty else []
        # FIX: was a bare print(); use the module logger for consistency.
        logger.info(
            f"After reranking, fetched {len(top_animes_list)} top animes....")

        # [STEP 4] The final recommendation call with the reranked context.
        logger.info("[4/5] Creating prompt with retrieved content...")
        prompt = create_recommendation_prompt(
            user_query=user_query,
            retrieved_animes=top_animes_list,
            n_recommendations=self.recommendation_n
        )
        logger.info("[5/5] LLM generating final response...")
        system_prompt = create_system_prompt()
        recommendations = self.llm.generate(
            prompt=prompt,
            system_prompt=system_prompt,
            temperature=0.5,
            max_tokens=1500
        )
        logger.info("\n---Generated Recommendations---")
        return {
            "query": user_query,
            # Surface the LLM's rewritten search query for debugging/diagnostics
            # (the __main__ harness wants this; backward-compatible addition).
            "optimized_query": optimized_query,
            "retrieved_count_from_DB": len(retrieved_animes),
            "reranked_count": len(top_animes_list),
            "recommendations": recommendations,
            "retrieved_animes": top_animes_list
        }

    def recommend_streaming(self, user_query: str, filters: dict | None = None):
        """
        Streaming version for real-time display.

        Currently a non-streaming passthrough to `recommend()`; intended to
        be wired up for FastAPI streaming later.
        """
        return self.recommend(user_query, filters)
if __name__ == "__main__":
    import os
    import time

    # Ensure the output directory for diagnostic reports exists.
    os.makedirs("data", exist_ok=True)

    pipeline = AnimeRAGPipeline(
        retriever_k=50,
        recommendation_n=5
    )

    test_queries = [
        "Anime similar to Death Note but lighter in tone",
        # Test the hidden gem / passion rate
        "A really obscure and weird sci-fi mecha from the 90s",
        "A generic isekai with an overpowered main character"  # Test the mainstream math
    ]

    for query in test_queries:
        print("\n" + "=" * 80)
        print(f"TESTING QUERY: '{query}'")
        print("=" * 80)

        # 1. Run the pipeline (LLM routing, retrieval, reranking, generation).
        result = pipeline.recommend(user_query=query)

        # FIX: `recommend()` returns the key "retrieved_count_from_DB", not
        # "retrieved_count" — the original lookup raised KeyError on every run.
        retrieved_count = result.get("retrieved_count_from_DB", 0)

        # 2. Build the diagnostic report.
        # Note: if the LLM chose NOT to search, retrieved_count will be 0.
        diagnostic_data = {
            "1_initial_user_query": result["query"],
            "2_llm_routing_decision": "Searched" if retrieved_count > 0 else "Chatted",
            "3_total_retrieved_from_chroma": retrieved_count,
            "4_total_survived_reranking": result.get("reranked_count", 0),
            "5_final_llm_recommendation_text": result["recommendations"],
            # The "before" state: raw ChromaDB embedding hits. The pipeline only
            # returns the *reranked* list, so this stays empty for now.
            "6_raw_chroma_results_before_reranking": [],
            # The "after" state: exact math scores for the top survivors.
            "7_reranked_results_with_math": []
        }

        # If the LLM actually performed a search, expose the reranking math.
        if retrieved_count > 0:
            for rank, anime in enumerate(result["retrieved_animes"]):
                diagnostic_data["7_reranked_results_with_math"].append({
                    "rank": rank + 1,
                    "title": anime["title"],
                    "hybrid_score_final": round(anime.get("final_hybrid_score", 0), 4),
                    "semantic_score_raw": round(anime.get("semantic_score", 0), 4),
                    "quality_score_raw": round(anime.get("raw_quality_score", 0), 4),
                    "bayesian_average": round(anime.get("bayesian_score", 0), 4),
                    "passion_rate": round(anime.get("passion_rate", 0), 5),
                    # Original vector distance (relevance_score is 1 - distance).
                    "chroma_distance": round(1 - anime.get("relevance_score", 1), 4)
                })

        # 3. Save the diagnostic report (json is imported at module level).
        safe_filename = f"diagnostic_{query[:20].replace(' ', '_').lower()}.json"
        filepath = os.path.join("data", safe_filename)
        with open(filepath, "w") as f:
            json.dump(diagnostic_data, f, indent=2)

        print(f"\n✅ Diagnostic saved to {filepath}")
        print(f"Generated Output:\n{result['recommendations']}")

        # Pause between queries to respect Groq API rate limits.
        time.sleep(3)