Spaces:
Sleeping
Sleeping
| import json | |
| import logging | |
| import re | |
| from typing import List, Dict, Any, Optional | |
| from langchain_groq import ChatGroq | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from dotenv import load_dotenv | |
| import os | |
| from services.processor import DocumentProcessor | |
# Module-level logger; handlers and levels are left to the application.
logger = logging.getLogger(__name__)

# Load variables from a local .env file (if present) so that
# os.getenv("GROQ_API_KEY") in RAGAgent.__init__ can see them.
load_dotenv()
| # Any of these patterns alone is enough to signal a product search intent | |
| _PRODUCT_INTENT_PATTERNS = [ | |
| r'\b(phones?|laptops?|tablets?|smartphones?|computers?|tvs?|cameras?|watches?|shoes?|bags?|dress(es)?|clothes?|headphones?|speakers?)\b', | |
| r'\b(how much|price[sd]?|cost[sd]?|cheap(est)?|expensive|budget|affordable|deal[sd]?|offer[sd]?|discount)\b', | |
| r'\b(buy|purchase|order|in stock|available)\b', | |
| r'\b(show me|find me|search for|give me|list|display)\b', | |
| r'\b(recommend|suggest|compare|vs\.?|versus)\b', | |
| r'\b(products?|items?)\b', | |
| ] | |
| _PRODUCT_INTENT_RE = re.compile( | |
| '|'.join(_PRODUCT_INTENT_PATTERNS), | |
| flags=re.IGNORECASE, | |
| ) | |
class RAGAgent:
    """
    RAG agent combining Groq LLM with ChromaDB to answer questions
    grounded in indexed e-commerce site content.

    Supports both plain text responses and structured product card responses.
    The query is routed to one of the two prompts depending on whether it
    matches the module-level product-intent regex.
    """

    # Groq model settings used when building the chat client.
    MODEL_NAME = "llama-3.3-70b-versatile"
    TEMPERATURE = 0.1
    # Chunks scoring at or below this similarity are dropped (unless that
    # would leave nothing, in which case the unfiltered top-k is used).
    MIN_SIMILARITY = 0.1
    # Only the most recent N history entries are injected into the prompt.
    MAX_HISTORY_TURNS = 5

    def __init__(self):
        """
        Build the Groq chat client, the document processor and the prompts.

        Raises:
            ValueError: if the GROQ_API_KEY environment variable is not set.
            Exception: whatever ChatGroq raises on construction (logged, then
                re-raised unchanged).
        """
        self.groq_api_key = os.getenv("GROQ_API_KEY")
        if not self.groq_api_key:
            raise ValueError("GROQ_API_KEY environment variable is required")

        try:
            self.llm = ChatGroq(
                model=self.MODEL_NAME,
                temperature=self.TEMPERATURE,
                groq_api_key=self.groq_api_key,
            )
            logger.info("Groq model initialized: %s", self.MODEL_NAME)
        except Exception:
            logger.exception("Failed to initialize Groq model")
            raise

        self.processor = DocumentProcessor()

        # --- Text Q&A prompt ---
        self.text_prompt_template = """You are a specialized AI assistant for the website "{site_id}".
Answer the user's question using ONLY the information provided in the context below.
SITE CONTEXT:
{context}
CONVERSATION HISTORY:
{history}
RULES:
1. Answer ONLY using information from the context
2. If the context does not contain the answer, politely say so
3. Be precise, concise, and helpful
4. Use lists or short paragraphs for clarity
5. Never invent information that is not in the context
6. Respond in English by default; match the user's language if they write in another language
USER QUESTION:
{query}
Your answer:"""

        # --- Product search prompt ---
        # Literal JSON braces are doubled ({{ }}) because the template goes
        # through str.format().
        self.product_prompt_template = """You are a product search assistant for the website "{site_id}".
The user is looking for products. Extract all matching products from the context below.
SITE CONTEXT:
{context}
CONVERSATION HISTORY:
{history}
USER QUERY:
{query}
INSTRUCTIONS:
- Extract every product that matches the query from the context.
- For each product, collect: name, price, image URL (from markdown image syntax ![...](url)), product URL (from markdown link syntax [...](url)), rating, and availability.
- If a field is not available, use an empty string "".
- Return ONLY a valid JSON object in this exact format, nothing else before or after:
{{
"summary": "Brief one-sentence summary of what you found",
"products": [
{{
"name": "Product name",
"price": "Price as shown (e.g. GH₵ 1,252)",
"image_url": "https://...",
"product_url": "https://...",
"rating": "e.g. 4.5/5",
"availability": "In Stock or empty string"
}}
]
}}
JSON response:"""

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    async def generate_response(
        self,
        query: str,
        site_id: str,
        history: Optional[List[Dict[str, str]]] = None,
        max_context_chunks: int = 5,
    ) -> Dict[str, Any]:
        """
        Generate a RAG response.

        Args:
            query: The user's question.
            site_id: Identifier of the indexed site to search.
            history: Previous exchanges as dicts with "question" and
                "answer" keys; None means no history.
            max_context_chunks: Number of chunks requested from the index.

        Returns:
            A dict in one of two shapes:
            - {"response_type": "text", "response": str}
            - {"response_type": "products", "response": str, "products": list}

            Any exception is logged and converted into a "text" apology
            response, so this method does not raise.
        """
        try:
            is_product_query = bool(_PRODUCT_INTENT_RE.search(query))
            logger.info(
                "Query intent — product_search=%s: %s", is_product_query, query[:60]
            )

            context_chunks = await self._retrieve_relevant_context(
                site_id=site_id,
                query=query,
                max_results=max_context_chunks,
            )
            if not context_chunks:
                logger.warning("No relevant context found")
                return {
                    "response_type": "text",
                    "response": "Sorry, I could not find relevant information for your question on this site.",
                }

            formatted_context = self._format_context(context_chunks)
            formatted_history = self._format_history(history or [])

            if is_product_query:
                return await self._generate_product_response(
                    query, site_id, formatted_context, formatted_history
                )
            return await self._generate_text_response(
                query, site_id, formatted_context, formatted_history
            )
        except Exception:
            # Top-level boundary: keep the traceback in the log, but never
            # leak it to the caller.
            logger.exception("Error generating response")
            return {
                "response_type": "text",
                "response": "Sorry, a technical error occurred. Please try again.",
            }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    async def _generate_text_response(
        self, query: str, site_id: str, context: str, history: str
    ) -> Dict[str, Any]:
        """Fill the text Q&A prompt, call the LLM and wrap the raw answer."""
        prompt = self.text_prompt_template.format(
            site_id=site_id,
            context=context,
            history=history,
            query=query,
        )
        raw = await self._call_llm(prompt, query)
        return {"response_type": "text", "response": raw}

    async def _generate_product_response(
        self, query: str, site_id: str, context: str, history: str
    ) -> Dict[str, Any]:
        """
        Fill the product prompt, call the LLM and parse its JSON answer.

        Falls back to a plain "text" response carrying the raw LLM output
        when no usable JSON object can be extracted.
        """
        prompt = self.product_prompt_template.format(
            site_id=site_id,
            context=context,
            history=history,
            query=query,
        )
        raw = await self._call_llm(prompt, query)

        # Try to parse the JSON the LLM returned
        parsed = self._parse_product_json(raw)
        if parsed:
            return self._product_payload(parsed)

        # Fallback: return as plain text if JSON parsing failed
        logger.warning("Product JSON parsing failed — falling back to text response")
        return {"response_type": "text", "response": raw}

    def _product_payload(self, parsed: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a parsed LLM JSON object into a well-formed "products" response."""
        products = parsed.get("products")
        # The LLM may emit null or a non-list here; never pass that through.
        if not isinstance(products, list):
            products = []
        products = [p for p in products if isinstance(p, dict)]

        summary = parsed.get("summary")
        if not isinstance(summary, str) or not summary.strip():
            summary = f"Found {len(products)} products"

        return {
            "response_type": "products",
            "response": summary,
            "products": products,
        }

    def _parse_product_json(self, raw: str) -> Optional[Dict[str, Any]]:
        """
        Extract and parse the JSON object from the LLM response.

        Returns the parsed dict, or None (after logging a warning) when no
        JSON object can be recovered.
        """
        text = (raw or "").strip()

        # Direct parse first: the prompt asks for a bare JSON object.
        try:
            whole = json.loads(text)
        except json.JSONDecodeError:
            whole = None
        if isinstance(whole, dict):
            return whole

        # Otherwise look for an embedded object (code fences, chatter before
        # or after). raw_decode stops at the end of the object, so trailing
        # text — even text containing braces — does not break parsing.
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                candidate = None
            if isinstance(candidate, dict) and "products" in candidate:
                return candidate
            start = text.find("{", start + 1)

        logger.warning("Could not parse product JSON from: %s", text[:200])
        return None

    async def _retrieve_relevant_context(
        self, site_id: str, query: str, max_results: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Fetch the chunks most similar to *query* from the site's index.

        Returns an empty list when the search fails (the error is logged).
        """
        try:
            chunks = await self.processor.search_similar_chunks(
                site_id=site_id,
                query=query,
                n_results=max_results,
            )
            chunks = chunks or []
            # Filter low-similarity chunks, fall back to top-k if all filtered
            relevant = [
                c for c in chunks
                if (c.get("similarity_score") or 0) > self.MIN_SIMILARITY
            ]
            return relevant if relevant else chunks
        except Exception:
            logger.exception("Error retrieving context")
            return []

    def _format_context(self, context_chunks: List[Dict[str, Any]]) -> str:
        """
        Render chunks as numbered, headed sections for the prompt.

        Chunks with empty content are skipped (their section number is not
        reused). Returns "No context available." when nothing is left.
        """
        if not context_chunks:
            return "No context available."

        sections = []
        for i, chunk in enumerate(context_chunks, 1):
            # "or" guards: keys may be present but hold None.
            content = (chunk.get("content") or "").strip()
            if not content:
                continue
            metadata = chunk.get("metadata") or {}
            source_url = metadata.get("source_url", "")
            page_title = metadata.get("page_title", "")
            score = chunk.get("similarity_score") or 0

            header = f"--- Section {i} (relevance: {score:.2f}"
            if page_title:
                header += f", page: {page_title}"
            if source_url:
                header += f", url: {source_url}"
            header += ") ---"
            sections.append(f"{header}\n{content}")

        return "\n\n".join(sections) if sections else "No context available."

    def _format_history(self, history: List[Dict[str, str]]) -> str:
        """
        Render the most recent exchanges as "User: … / Assistant: …" pairs.

        Entries missing either the question or the answer are skipped.
        """
        if not history:
            return "No conversation history."

        exchanges = []
        for turn in history[-self.MAX_HISTORY_TURNS:]:
            question = (turn.get("question") or "").strip()
            answer = (turn.get("answer") or "").strip()
            if question and answer:
                exchanges.append(f"User: {question}\nAssistant: {answer}")

        return "\n\n".join(exchanges) if exchanges else "No conversation history."

    async def _call_llm(self, system_prompt: str, query: str) -> str:
        """
        Send the filled prompt (as system message) plus the raw query (as
        human message) to the LLM and return the stripped reply text.

        Raises:
            Exception: whatever the LLM client raises (logged, re-raised).
        """
        try:
            messages = [
                SystemMessage(content=system_prompt),
                HumanMessage(content=query),
            ]
            response = await self.llm.ainvoke(messages)
            return response.content.strip()
        except Exception as e:
            logger.error("LLM call failed: %s", e)
            raise

    async def get_site_info(self, site_id: str) -> Optional[Dict[str, Any]]:
        """
        Return a small summary of what is indexed for *site_id*.

        Returns:
            A dict with "site_id", "total_documents", "content_preview"
            (first 200 characters of one document) and "last_indexed"
            ("Unknown" when no upload time is recorded), or None when the
            collection is missing, empty, or an error occurs.
        """
        try:
            collection = self.processor.get_collection(site_id)
            if not collection:
                return None

            results = collection.get(limit=5)
            documents = results.get("documents") or []
            if not documents:
                return None

            metadatas = results.get("metadatas") or []
            first_meta = metadatas[0] if metadatas and metadatas[0] else {}
            sample = documents[0] or ""

            return {
                "site_id": site_id,
                "total_documents": collection.count(),
                "content_preview": sample[:200] + "..." if len(sample) > 200 else sample,
                "last_indexed": first_meta.get("upload_time") or "Unknown",
            }
        except Exception:
            logger.exception("Error getting site info for %s", site_id)
            return None