# RAG engine for customer-review Q&A: FAISS retrieval + sentence-transformers
# embeddings + an OpenAI-compatible LLM served via Groq.
import json
import logging
import os
import pickle
import time

import faiss
import numpy as np
from dotenv import load_dotenv
from groq import Groq
from sentence_transformers import SentenceTransformer
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
| # Load environment variables from the .env file at project root | |
| _project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) | |
| load_dotenv(os.path.join(_project_root, '.env')) | |
| # Configure logging for tenacity retries | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class ReviewRAGEngine:
    """Retrieval-Augmented Generation engine over customer product reviews.

    Loads a FAISS similarity index, the pickled review-metadata table, and a
    sentence-transformer embedder, then answers natural-language questions by
    retrieving the most relevant reviews and synthesizing an answer with an
    OpenAI-compatible LLM served by Groq.

    Each backing resource is loaded best-effort in __init__; a failure leaves
    the corresponding attribute as None and `retrieve` degrades gracefully
    instead of crashing.
    """

    def __init__(self, vectorstore_dir: str = 'vectorstore'):
        """
        Initializes the RAG Engine.
        Loads the FAISS index, embedding model, and sets up the OpenAI-compatible LLM client.

        Args:
            vectorstore_dir: Directory containing 'reviews.index' and
                'reviews_metadata.pkl'.
        """
        self.vectorstore_dir = vectorstore_dir
        print("Initializing RAG Engine...")

        # Embedding model used to vectorize incoming queries; None on failure.
        try:
            self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        except Exception as e:
            print(f"Failed to load sentence-transformer: {e}")
            self.embedder = None

        # FAISS index holding one (normalized) vector per review; None on failure.
        index_path = os.path.join(vectorstore_dir, 'reviews.index')
        try:
            self.index = faiss.read_index(index_path)
            print(f"Loaded FAISS index with {self.index.ntotal} vectors.")
        except Exception as e:
            print(f"Warning: Could not load FAISS index from {index_path}. Error: {e}")
            self.index = None

        # Review metadata, aligned row-for-row with the FAISS vectors.
        # NOTE(review): pickle is unsafe on untrusted files — this assumes the
        # vectorstore was produced locally by a trusted pipeline.
        metadata_path = os.path.join(vectorstore_dir, 'reviews_metadata.pkl')
        try:
            with open(metadata_path, 'rb') as f:
                self.metadata_df = pickle.load(f)
            print(f"Loaded metadata for {len(self.metadata_df)} reviews.")
        except Exception as e:
            print(f"Warning: Could not load metadata from {metadata_path}. Error: {e}")
            self.metadata_df = None

        # OpenAI-compatible LLM client served via Groq.
        api_key = os.getenv("GROQ_API_KEY", "")
        self.client = Groq(api_key=api_key)
        self.llm_model = "moonshotai/kimi-k2-instruct-0905"

    def _call_llm_with_retry(self, messages, max_attempts: int = 5):
        """Calls the LLM API with exponential backoff to handle rate limits gracefully.

        BUG FIX: the original comment promised "retries up to 5 times, waiting
        2^x seconds between each retry (max 10s wait)" but no retry was ever
        applied — the first failure propagated immediately. The loop below
        implements exactly that documented policy.

        Args:
            messages: Chat messages in OpenAI format ({"role", "content"} dicts).
            max_attempts: Total attempts before giving up (default 5).

        Returns:
            The assistant message content string.

        Raises:
            Exception: the last API error, once all attempts are exhausted.
        """
        for attempt in range(1, max_attempts + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.llm_model,
                    messages=messages,
                    temperature=0.6,
                    max_completion_tokens=4096,
                    top_p=1,
                    stream=False,
                    stop=None
                )
                return response.choices[0].message.content
            except Exception as e:
                # In a real environment, you'd only catch specific RateLimitErrors here.
                logger.warning(f"LLM API Call failed. Entering exponential backoff... Error: {e}")
                if attempt == max_attempts:
                    raise
                # 2, 4, 8, 10 seconds — exponential, capped at 10s per the policy.
                time.sleep(min(2 ** attempt, 10))

    def retrieve(self, query: str, top_k: int = 15):
        """
        Embeds the query and retrieves the Top K most similar reviews from FAISS.

        Returns:
            A list of {"text": review_text, "aspects": aspect_dict} dicts, or a
            single sentinel entry when the backing resources failed to load.
        """
        # BUG FIX: the original guard read
        #   `not self.index or not self.metadata_df is not None or not self.embedder`
        # where `not x is not None` parses as `x is None` only by accident, and
        # truthiness checks on third-party objects (a DataFrame raises on bool())
        # are unreliable. Explicit None checks express the actual intent.
        if self.index is None or self.metadata_df is None or self.embedder is None:
            return [{"text": "Systems not fully loaded", "aspects": {}}]

        # Embed and L2-normalize the query so inner-product search behaves
        # like cosine similarity (index vectors are normalized the same way).
        q_embedding = self.embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(q_embedding)

        distances, indices = self.index.search(q_embedding, top_k)

        results = []
        for idx in indices[0]:
            if idx == -1:
                continue  # FAISS pads with -1 if there aren't enough vectors
            row = self.metadata_df.iloc[idx]
            text = row.get('reviewDocument', str(row.values[0]))  # fallback to first column if missing
            aspects_str = row.get('predicted_aspects', '{}')
            try:
                aspects_dict = json.loads(aspects_str)
            except (TypeError, ValueError):
                # was a bare `except:`; narrowed to bad/missing JSON (NaN gives
                # TypeError, malformed text gives json.JSONDecodeError/ValueError)
                aspects_dict = {}
            results.append({
                "text": text,
                "aspects": aspects_dict
            })
        return results

    def answer_question(self, question: str, top_k: int = 15) -> str:
        """
        Full RAG Pipeline: Retrieve relevant reviews -> Build Context -> Synthesize Answer

        Args:
            question: The user's natural-language question.
            top_k: Number of reviews to retrieve as context.

        Returns:
            The LLM-synthesized answer, or a human-readable error string.
        """
        # 1. Retrieve
        retrieved_reviews = self.retrieve(question, top_k)
        if not retrieved_reviews:
            return "I couldn't find any relevant reviews to answer your question."

        # 2. Build the context string: each review plus its aspect sentiments.
        context_parts = []
        for i, rev in enumerate(retrieved_reviews):
            aspect_summaries = [
                f"{aspect}: {details.get('sentiment', 'unknown')}"
                for aspect, details in rev['aspects'].items()
            ]
            aspects_joined = ", ".join(aspect_summaries) if aspect_summaries else "None detected"
            context_parts.append(f"Review {i+1}: \"{rev['text']}\"\nDetected Aspects: [{aspects_joined}]")
        context_block = "\n\n".join(context_parts)

        # 3. Build the prompt: grounded, concise, analyst persona.
        system_prompt = (
            "You are an expert E-Commerce Product Analyst. "
            "You help product managers understand customer feedback by analyzing reviews and aspect sentiments. "
            "Always base your answers strictly on the context provided. Do not invent information. "
            "Use bullet points and be concise. If possible, mention specific percentages or counts based on the context."
        )
        user_prompt = f"""Based on the following retrieved customer reviews and the AI-extracted aspect sentiments, answer the user's question.
<CONTEXT>
{context_block}
</CONTEXT>
Question: {question}
Answer:"""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        # 4. Call LLM (with safety retries)
        try:
            answer = self._call_llm_with_retry(messages)
            return answer
        except Exception as e:
            return f"Error: Failed to reach the Moonshot API after multiple retries due to rate limits or connection errors. Detail: {e}"
if __name__ == "__main__":
    import argparse

    # Minimal CLI harness for exercising the engine outside the app.
    parser = argparse.ArgumentParser(description="Test the RAG Engine locally")
    parser.add_argument(
        '--query',
        type=str,
        default="What do people complain about the most?",
        help="Question to ask",
    )
    cli_args = parser.parse_args()

    engine = ReviewRAGEngine()
    print(f"\nUser Query: {cli_args.query}")
    print("\nRetrieving and synthesizing...")
    final_answer = engine.answer_question(cli_args.query, top_k=5)

    print("\n-------------------------")
    print("RAG System Output:")
    print("-------------------------")
    print(final_answer)