import os
import faiss
import pickle
import numpy as np
import json
from sentence_transformers import SentenceTransformer
from groq import Groq
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import logging
from dotenv import load_dotenv
# Load environment variables from the .env file at project root
_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
load_dotenv(os.path.join(_project_root, '.env'))
# Configure logging for tenacity retries
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ReviewRAGEngine:
    """Retrieval-Augmented Generation engine over customer reviews.

    Embeds a user question with a sentence-transformer, retrieves the most
    similar reviews from a FAISS index, and asks a Groq-hosted LLM to
    synthesize an answer grounded strictly in the retrieved context.
    """

    def __init__(self, vectorstore_dir: str = 'vectorstore'):
        """
        Initializes the RAG Engine.

        Loads the FAISS index, the pickled review metadata, and the embedding
        model, and sets up the OpenAI-compatible Groq LLM client. Each
        component is loaded best-effort: on failure the attribute is set to
        None so that `retrieve` can degrade gracefully instead of crashing.

        Args:
            vectorstore_dir: Directory containing 'reviews.index' and
                'reviews_metadata.pkl'.
        """
        self.vectorstore_dir = vectorstore_dir
        print("Initializing RAG Engine...")
        # Load embedding model (must match the model used to build the index).
        try:
            self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        except Exception as e:
            print(f"Failed to load sentence-transformer: {e}")
            self.embedder = None
        # Load FAISS index.
        index_path = os.path.join(vectorstore_dir, 'reviews.index')
        try:
            self.index = faiss.read_index(index_path)
            print(f"Loaded FAISS index with {self.index.ntotal} vectors.")
        except Exception as e:
            print(f"Warning: Could not load FAISS index from {index_path}. Error: {e}")
            self.index = None
        # Load reviews metadata database — presumably a DataFrame whose rows
        # align 1:1 with the FAISS vectors (retrieve() indexes it with .iloc).
        metadata_path = os.path.join(vectorstore_dir, 'reviews_metadata.pkl')
        try:
            with open(metadata_path, 'rb') as f:
                self.metadata_df = pickle.load(f)
            print(f"Loaded metadata for {len(self.metadata_df)} reviews.")
        except Exception as e:
            print(f"Warning: Could not load metadata from {metadata_path}. Error: {e}")
            self.metadata_df = None
        # Setup LLM client via Groq.
        api_key = os.getenv("GROQ_API_KEY", "")
        self.client = Groq(api_key=api_key)
        self.llm_model = "moonshotai/kimi-k2-instruct-0905"

    # Retry decorator: retries up to 5 times, waiting 2^x * 1 seconds between
    # each retry (max 10s wait). This prevents the application from crashing
    # if the Groq API hits rate limits.
    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True
    )
    def _call_llm_with_retry(self, messages):
        """Calls the LLM API with exponential backoff to handle rate limits gracefully.

        Args:
            messages: Chat-completion message list ({"role", "content"} dicts).

        Returns:
            The assistant message content string.

        Raises:
            Exception: re-raised after logging so tenacity can retry; the
                final failure propagates to the caller (reraise=True).
        """
        try:
            response = self.client.chat.completions.create(
                model=self.llm_model,
                messages=messages,
                temperature=0.6,
                max_completion_tokens=4096,
                top_p=1,
                stream=False,
                stop=None
            )
            return response.choices[0].message.content
        except Exception as e:
            # In a real environment, you'd only catch specific RateLimitErrors here.
            logger.warning(f"LLM API Call failed. Entering exponential backoff... Error: {e}")
            raise e

    def retrieve(self, query: str, top_k: int = 15):
        """
        Embeds the query and retrieves the Top K most similar reviews from FAISS.

        Args:
            query: Natural-language question to embed and search with.
            top_k: Number of nearest neighbours to fetch.

        Returns:
            List of {"text": str, "aspects": dict} dicts; a single sentinel
            entry if any backing component failed to load.
        """
        # Fix: the original condition (`not self.metadata_df is not None`)
        # worked only via operator precedence; explicit None checks are
        # equivalent, readable, and safe for DataFrame operands.
        if self.index is None or self.metadata_df is None or self.embedder is None:
            return [{"text": "Systems not fully loaded", "aspects": {}}]
        # Embed query; L2-normalize so inner-product search behaves as cosine.
        q_embedding = self.embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(q_embedding)
        # Search FAISS (distances unused; only the row indices matter here).
        _distances, indices = self.index.search(q_embedding, top_k)
        # Fetch metadata for each hit.
        results = []
        for idx in indices[0]:
            if idx == -1:
                continue  # FAISS returns -1 if there aren't enough vectors
            row = self.metadata_df.iloc[idx]
            text = row.get('reviewDocument', str(row.values[0]))  # fallback to first column if missing
            aspects_str = row.get('predicted_aspects', '{}')
            try:
                aspects_dict = json.loads(aspects_str)
            except (TypeError, ValueError):
                # ValueError covers json.JSONDecodeError (malformed JSON);
                # TypeError covers non-string cells such as NaN.
                aspects_dict = {}
            results.append({
                "text": text,
                "aspects": aspects_dict
            })
        return results

    def answer_question(self, question: str, top_k: int = 15) -> str:
        """
        Full RAG Pipeline: Retrieve relevant reviews -> Build Context -> Synthesize Answer.

        Args:
            question: The user's question.
            top_k: How many reviews to retrieve as context.

        Returns:
            The LLM's synthesized answer, or a human-readable error string.
        """
        # 1. Retrieve
        retrieved_reviews = self.retrieve(question, top_k)
        if not retrieved_reviews:
            return "I couldn't find any relevant reviews to answer your question."
        # 2. Build context string: each review plus its aspect sentiments.
        context_parts = []
        for i, rev in enumerate(retrieved_reviews):
            aspect_summaries = []
            for aspect, details in rev['aspects'].items():
                sentiment = details.get('sentiment', 'unknown')
                aspect_summaries.append(f"{aspect}: {sentiment}")
            aspects_joined = ", ".join(aspect_summaries) if aspect_summaries else "None detected"
            context_parts.append(f"Review {i+1}: \"{rev['text']}\"\nDetected Aspects: [{aspects_joined}]")
        context_block = "\n\n".join(context_parts)
        # 3. Build prompt.
        system_prompt = (
            "You are an expert E-Commerce Product Analyst. "
            "You help product managers understand customer feedback by analyzing reviews and aspect sentiments. "
            "Always base your answers strictly on the context provided. Do not invent information. "
            "Use bullet points and be concise. If possible, mention specific percentages or counts based on the context."
        )
        user_prompt = f"""Based on the following retrieved customer reviews and the AI-extracted aspect sentiments, answer the user's question.
<CONTEXT>
{context_block}
</CONTEXT>
Question: {question}
Answer:"""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        # 4. Call LLM (with safety retries).
        try:
            answer = self._call_llm_with_retry(messages)
            return answer
        except Exception as e:
            # Fix: the message previously blamed the "Moonshot API"; the
            # client actually talks to the Groq API.
            return f"Error: Failed to reach the Groq API after multiple retries due to rate limits or connection errors. Detail: {e}"
if __name__ == "__main__":
    import argparse

    # Minimal CLI so the engine can be smoke-tested without the web app.
    cli = argparse.ArgumentParser(description="Test the RAG Engine locally")
    cli.add_argument(
        '--query',
        type=str,
        default="What do people complain about the most?",
        help="Question to ask",
    )
    cli_args = cli.parse_args()

    engine = ReviewRAGEngine()
    print(f"\nUser Query: {cli_args.query}")
    print("\nRetrieving and synthesizing...")
    response = engine.answer_question(cli_args.query, top_k=5)
    print("\n-------------------------")
    print("RAG System Output:")
    print("-------------------------")
    print(response)