import os
import json
import random
import requests
import asyncio
import uuid
import re
import warnings
from typing import TypedDict, Annotated, List, Dict, Any, Optional, Tuple
from datetime import datetime
from pathlib import Path
from enum import Enum
import tempfile

# Suppress torch FutureWarning on HF Spaces
warnings.filterwarnings('ignore', category=FutureWarning, module='.*torch.*')

# Third-party imports
import gradio as gr
from dotenv import load_dotenv
from bs4 import BeautifulSoup
import faiss
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
import tiktoken
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_text_splitters import RecursiveCharacterTextSplitter

# --- 1. SETUP & CONFIGURATION ---
load_dotenv()


# --- API KEY ROTATION ---
class APIKeyRotator:
    """Round-robin holder for the API keys of one service.

    Keys are read from the environment when the rotator is constructed:
    ``<SERVICE>_API_KEY`` first, then ``<SERVICE>_API_KEY_2``,
    ``<SERVICE>_API_KEY_3`` ... stopping at the first variable that is unset
    or empty. The list is shuffled once so that several processes do not all
    start on the same key.

    Raises:
        ValueError: from the constructor when no key is configured.
    """

    def __init__(self, service_name: str):
        self.service_name = service_name
        self.keys = self._load_keys()
        self.current_index = 0

    def _env_var_name(self, position: int) -> str:
        """Return the environment variable name for the 1-based key slot."""
        base = f"{self.service_name}_API_KEY"
        return base if position == 1 else f"{base}_{position}"

    def _load_keys(self) -> List[str]:
        """Load all API keys for the service (e.g., MISTRAL_API_KEY, MISTRAL_API_KEY_2, etc.)"""
        keys = []
        position = 1
        while True:
            # Strip so a stray space/newline in the environment does not end
            # up inside the credential.
            key = (os.getenv(self._env_var_name(position)) or "").strip()
            if not key:
                break
            keys.append(key)
            position += 1

        if not keys:
            raise ValueError(f"❌ No API keys found for {self.service_name}")

        print(f"✅ Loaded {len(keys)} API key(s) for {self.service_name}")
        random.shuffle(keys)  # Randomize order
        return keys

    def get_key(self) -> str:
        """Get current API key without rotation"""
        if not self.keys:
            raise ValueError(f"❌ No API keys available for {self.service_name}")
        return self.keys[self.current_index]

    def rotate(self) -> str:
        """Advance to the next API key (wrapping around) and return it."""
        if not self.keys:
            raise ValueError(f"❌ No API keys available for {self.service_name}")
        self.current_index = (self.current_index + 1) % len(self.keys)
        print(f"🔄 Rotated {self.service_name} API key to index {self.current_index + 1}/{len(self.keys)}")
        return self.get_key()


# Initialize key rotators
mistral_rotator = APIKeyRotator("MISTRAL")
tavily_rotator = APIKeyRotator("TAVILY")

# Set current keys in environment
os.environ["MISTRAL_API_KEY"] = mistral_rotator.get_key()
os.environ["TAVILY_API_KEY"] = tavily_rotator.get_key()
MISTRAL_API_KEY = os.environ["MISTRAL_API_KEY"]
TAVILY_API_KEY = os.environ["TAVILY_API_KEY"]
print("[OK] API Key rotation initialized")

# Vector DB Configuration
# NOTE: the name is historical; the directory is only created and printed here.
CHROMA_DB_PATH = Path(tempfile.gettempdir()) / "ecommerce_rag"
CHROMA_DB_PATH.mkdir(exist_ok=True, parents=True)
print(f"[VECTOR_STORE] Location: {CHROMA_DB_PATH.absolute()}")
print("[OK] API Keys configured")

# Global FAISS state
faiss_index = None
faiss_chunks_metadata = {}
faiss_chunk_id_to_index = {}
faiss_index_counter = 0


# Resilient LLM wrapper with key rotation
def create_llm_with_rotation(model: str, temperature: float = 0.2, timeout: float = 120.0, max_retries: int = 5):
    """Create an LLM wrapper that rotates Mistral API keys on failure.

    Args:
        model: Mistral model name passed to ``ChatMistralAI``.
        temperature: Sampling temperature.
        timeout: Per-request timeout passed to the client.
        max_retries: Retry count passed to the client (applies per key,
            before this wrapper rotates to the next key).

    Returns:
        An object exposing ``invoke(messages, **kwargs)`` and
        ``bind_tools(tools)``.
    """

    class ResilientLLM:
        def __init__(self, model, temperature, timeout, max_retries):
            self.model = model
            self.temperature = temperature
            self.timeout = timeout
            self.max_retries = max_retries
            self.llm = self._create_llm()

        def _create_llm(self):
            """Build a client bound to the rotator's *current* key."""
            return ChatMistralAI(
                api_key=mistral_rotator.get_key(),
                model=self.model,
                temperature=self.temperature,
                timeout=self.timeout,
                max_retries=self.max_retries
            )

        def invoke(self, messages, **kwargs):
            """Invoke with automatic key rotation on failure.

            Makes at most one attempt per configured key; the exception from
            the final attempt is re-raised unchanged.
            """
            max_key_rotations = len(mistral_rotator.keys)
            for attempt in range(max_key_rotations):
                try:
                    return self.llm.invoke(messages, **kwargs)
                except Exception as e:
                    if attempt >= max_key_rotations - 1:
                        raise
                    print(f"⚠️ Attempt {attempt + 1} failed: {str(e)[:80]}")
                    # rotate() advances the shared rotator, so _create_llm()
                    # picks up the new key.
                    mistral_rotator.rotate()
                    self.llm = self._create_llm()
                    print("🔄 Retrying with rotated key...")

        def bind_tools(self, tools):
            """Bind tools to LLM.

            Returns the underlying client's bound model, so calls made through
            the returned object do not go through the key-rotation logic.
            """
            return self.llm.bind_tools(tools)

    return ResilientLLM(model, temperature, timeout, max_retries)


# Initialize LLMs with resilience and key rotation
llm_small = create_llm_with_rotation("mistral-small-latest", temperature=0.2, timeout=60.0, max_retries=5)
llm_large = create_llm_with_rotation("mistral-large-latest", temperature=0.2, timeout=120.0, max_retries=5)

# Initialize embedding model
print("[LOADING] Embedding model...")
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
embedding_model.to('cpu')
print("[OK] Embedding model initialized (CPU mode)")

# Initialize tokenizer
tokenizer = tiktoken.get_encoding("cl100k_base")

# --- TRUSTED E-COMMERCE WEBSITES (REMOVED - Now accepting all sources) ---
# No domain restrictions - Tavily will find the best sources dynamically

# --- 2.
# STATE DEFINITIONS ---
class WorkflowStage(Enum):
    """Stages a query moves through; the string value is stored in the state."""
    IDLE = "idle"
    SEARCH = "search"
    CHUNKING = "chunking"
    INDEXING = "indexing"
    RETRIEVAL = "retrieval"
    SYNTHESIS = "synthesis"
    COMPLETE = "complete"
    ERROR = "error"


class EcommerceAgentState(TypedDict):
    """Unified state for tool-based agentic RAG pipeline"""
    query: str
    stage: str
    search_results: List[Dict]
    cleaned_documents: List[Dict]
    parent_chunks: List[Dict]
    child_chunks: List[Dict]
    retrieved_children: List[Dict]
    retrieved_parents: List[Dict]
    parent_child_mapping: Dict[str, List[str]]
    final_answer: str
    citations: List[Dict]
    confidence: float
    error_message: Optional[str]
    progress_log: List[str]
    # Chat transcript as {"role": ..., "content": ...} dicts; the agents read
    # it with msg.get("role") / msg.get("content").
    chat_messages: List[Dict[str, str]]


def create_initial_state(query: str) -> EcommerceAgentState:
    """Create initial state for a new query"""
    return {
        "query": query,
        "stage": WorkflowStage.IDLE.value,
        "search_results": [],
        "cleaned_documents": [],
        "parent_chunks": [],
        "child_chunks": [],
        "retrieved_children": [],
        "retrieved_parents": [],
        "parent_child_mapping": {},
        "final_answer": "",
        "citations": [],
        "confidence": 0.0,
        "error_message": None,
        "progress_log": [],
        "chat_messages": []
    }


# --- CONVERSATION CONTEXT ---
# Global conversation memory to track context between queries
conversation_history = []
previous_answers = []


def add_to_conversation_history(query: str, answer: str):
    """Append a query/answer pair to the global history, keeping the last 5."""
    conversation_history.append({"query": query, "answer": answer})
    previous_answers.append(answer)
    # Keep only last 5 exchanges to avoid context bloat
    if len(conversation_history) > 5:
        conversation_history.pop(0)
        previous_answers.pop(0)


def get_conversation_context() -> str:
    """Return the last 3 exchanges as a numbered, truncated text block.

    Returns an empty string when there is no history yet.
    """
    if not conversation_history:
        return ""
    entries = [
        f"{i}. Q: {item['query'][:100]}...\n A: {item['answer'][:150]}...\n"
        for i, item in enumerate(conversation_history[-3:], 1)  # Last 3 exchanges
    ]
    return "\n**Previous conversation:**\n" + "".join(entries)


# --- 3. UTILITY FUNCTIONS ---
def extract_domain(url: str) -> str:
    """Return the host part of *url* without a leading ``www.``.

    Returns ``"unknown"`` when the URL cannot be parsed. A URL without a
    scheme has no network location and yields an empty string.
    """
    from urllib.parse import urlparse
    try:
        domain = urlparse(url).netloc
    except (ValueError, TypeError, AttributeError):
        return "unknown"
    # Only strip the prefix: replace("www.", "") would also mangle hosts that
    # merely contain the substring (e.g. "awww.example.com").
    if domain.startswith("www."):
        domain = domain[len("www."):]
    return domain


def is_trusted_source(url: str) -> bool:
    """Accept all sources - no domain restrictions"""
    return True  # Accept all domains from Tavily results


def count_tokens(text: str) -> int:
    """Count tokens with the module tokenizer, falling back to a word count."""
    try:
        return len(tokenizer.encode(text))
    except Exception:
        return len(text.split())


def generate_uuid() -> str:
    """Generate a short id: the first 8 characters of a random UUID4."""
    return str(uuid.uuid4())[:8]


def faiss_add_chunks(embeddings: np.ndarray, chunk_ids: List[str], metadatas: List[Dict]) -> None:
    """Add chunk embeddings to the global FAISS index and record their metadata.

    Vectors are L2-normalised and stored in an inner-product index, so search
    scores are cosine similarities. The index is created on first use with the
    dimensionality of *embeddings*.

    Args:
        embeddings: One vector per chunk (a single 1-D vector is accepted).
        chunk_ids: Id for each vector, in the same order.
        metadatas: Metadata dict for each vector, in the same order.

    Raises:
        ValueError: if the three inputs do not have the same length.
    """
    global faiss_index, faiss_index_counter

    embeddings = np.array(embeddings, dtype=np.float32)
    if embeddings.size == 0:
        return  # nothing to add; avoids indexing shape[1] of an empty array
    if embeddings.ndim == 1:
        embeddings = embeddings.reshape(1, -1)
    if not (len(embeddings) == len(chunk_ids) == len(metadatas)):
        # A mismatch would silently misalign FAISS rows and chunk ids.
        raise ValueError(
            f"Mismatched inputs: {len(embeddings)} embeddings, "
            f"{len(chunk_ids)} ids, {len(metadatas)} metadatas"
        )

    faiss.normalize_L2(embeddings)
    if faiss_index is None:
        faiss_index = faiss.IndexFlatIP(embeddings.shape[1])

    # FAISS numbers rows sequentially, so the next row is the current total.
    faiss_index_counter = faiss_index.ntotal
    faiss_index.add(embeddings)

    for chunk_id, metadata in zip(chunk_ids, metadatas):
        faiss_chunks_metadata[chunk_id] = metadata
        faiss_chunk_id_to_index[chunk_id] = faiss_index_counter
        faiss_index_counter += 1


def faiss_search(query_embedding: np.ndarray, k: int = 5) -> Dict:
    """Return the *k* most similar indexed chunks for a query vector.

    Args:
        query_embedding: Query vector (1-D, or 2-D with a single row).
        k: Maximum number of hits.

    Returns:
        A dict with parallel, singly nested lists under ``"ids"``,
        ``"documents"``, ``"metadatas"`` (metadata without the ``document``
        key) and ``"distances"`` (scores as Python floats). All lists are
        empty when the index is empty or ``k`` is not positive.
    """
    empty = {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
    if faiss_index is None or faiss_index.ntotal == 0 or k <= 0:
        return empty

    # np.array copies, so the in-place normalisation never touches the caller's array.
    query = np.array(query_embedding, dtype=np.float32).reshape(1, -1)
    faiss.normalize_L2(query)
    distances, indices = faiss_index.search(query, min(k, faiss_index.ntotal))

    # Invert the id -> row map once instead of scanning it for every hit.
    row_to_chunk_id = {row: chunk_id for chunk_id, row in faiss_chunk_id_to_index.items()}

    result_ids = []
    result_docs = []
    result_metadatas = []
    result_distances = []
    for dist, row in zip(distances[0], indices[0]):
        chunk_id = row_to_chunk_id.get(int(row))
        if chunk_id is None:
            continue  # unknown row (FAISS pads missing hits with -1)
        metadata = faiss_chunks_metadata[chunk_id]
        result_ids.append(chunk_id)
        result_docs.append(metadata.get("document", ""))
        result_metadatas.append({key: value for key, value in metadata.items() if key != "document"})
        # Convert numpy float32 to Python float for JSON serialization
        result_distances.append(float(dist))

    return {
        "ids": [result_ids],
        "documents": [result_docs],
        "metadatas": [result_metadatas],
        "distances": [result_distances]
    }


# --- 4. TOOL DEFINITIONS ---
# NOTE: the docstrings of @tool functions double as the tool descriptions, so
# they are kept short and extra documentation lives in comments.

# Returns the raw Tavily results as a JSON string, or "Error in search: ..."
# when the search raises.
@tool
def search_ecommerce_products(query: str) -> str:
    """Search for e-commerce products using Tavily Search - searches broadly, curator filters later"""
    try:
        print(f"🔍 Searching Tavily for: '{query}' (will curate results after)")
        search = TavilySearchResults(
            max_results=15,  # Increased to get more diverse results
            api_key=TAVILY_API_KEY,
            search_depth="advanced"
        )
        results = search.invoke(query)
        print(f" ✅ Found {len(results)} raw results (will filter to trusted sources)")
        return json.dumps(results, indent=2)
    except Exception as e:
        return f"Error in search: {str(e)}"


# Returns up to 8000 characters of visible page text. Failures are reported
# as strings: "Insufficient content" (200 characters or fewer), or a message
# starting with "Error".
@tool
def scrape_product_content(url: str) -> str:
    """Scrape and clean content from a URL using requests + BeautifulSoup"""
    try:
        print(f"📄 Scraping: {extract_domain(url)}")
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'en-US,en;q=0.5',
        }

        # The context manager closes the session (and its sockets) on every
        # path, avoiding file descriptor issues on HF Spaces.
        with requests.Session() as session:
            try:
                response = session.get(url, timeout=15, headers=headers)
                response.raise_for_status()

                # Parse the raw bytes; BeautifulSoup detects the encoding.
                soup = BeautifulSoup(response.content, 'html.parser')

                # Remove unwanted elements
                for element in soup(["script", "style", "meta", "link", "nav", "footer", "header", "noscript", "iframe"]):
                    element.decompose()

                # Extract text with newline separator to preserve structure
                text = soup.get_text(separator='\n', strip=True)

                # Clean up whitespace: split "multi-headline" lines on double
                # spaces (a single space would put every word on its own line)
                # and drop empty fragments.
                lines = (line.strip() for line in text.splitlines())
                fragments = (phrase.strip() for line in lines for phrase in line.split("  "))
                text = '\n'.join(fragment for fragment in fragments if fragment)

                # Limit content
                text = text[:8000]

                if len(text) > 200:
                    print(f" ✅ Extracted {len(text)} characters")
                    return text
                return "Insufficient content"
            except requests.Timeout:
                print(" [ERROR] Timeout after 15s")
                return "Error: Request timeout"
            except requests.RequestException as e:
                print(f" [ERROR] Request error: {str(e)}")
                return f"Error: {str(e)}"
    except Exception as e:
        print(f" [ERROR] Unexpected error: {str(e)}")
        return f"Error scraping: {str(e)}"


# Accepts a JSON string (or an already-parsed value) holding a list of URL
# strings / {"url": ...} dicts, or a single such dict. Returns a JSON object
# with "total_input", "curated_count" and "urls" ({"url", "domain"} entries).
@tool
def curator_filter_sites(urls_json: str) -> str:
    """Curator Agent - Process and deduplicate URLs (no domain filtering)

    Returns list of unique URLs from search results
    """
    try:
        urls = json.loads(urls_json) if isinstance(urls_json, str) else urls_json
        if isinstance(urls, dict):
            urls = [urls.get("url", "")] if urls.get("url") else []
        elif not isinstance(urls, list):
            urls = [str(urls)]

        print(f"🎯 Curator Agent: Processing {len(urls)} URLs...")

        curated_urls = []
        seen_urls = set()
        for url in urls:
            if isinstance(url, dict):
                url = url.get("url", "")
            # Skip blanks, duplicates and entries that are not URL strings.
            if not isinstance(url, str) or not url or url in seen_urls:
                continue
            seen_urls.add(url)
            domain = extract_domain(url)
            curated_urls.append({
                "url": url,
                "domain": domain
            })
            print(f" ✅ Added: {domain}")

        result = {
            "total_input": len(urls),
            "curated_count": len(curated_urls),
            "urls": curated_urls
        }
        print(f" 📊 Result: {len(curated_urls)} unique URLs")
        return json.dumps(result, indent=2)
    except Exception as e:
        return f"Error in curator filter: {str(e)}"


# Produces one parent chunk (first 4000 characters) and up to 5 child chunks
# of ~500 characters with 50 overlap. Returns JSON with "parent", "children"
# and "count", or "Error chunking: ..." on failure.
@tool
def chunk_content(content: str, url: str) -> str:
    """Split content into parent and child chunks"""
    try:
        parent_id = generate_uuid()
        parent_chunk = {
            "id": parent_id,
            "document": content[:4000],
            "source": url,
            "chunk_type": "parent",
            "created_at": datetime.now().isoformat()
        }

        # Create child chunks
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        child_texts = splitter.split_text(content)

        child_chunks = [
            {
                "id": generate_uuid(),
                "document": child_text,
                "source": url,
                "chunk_type": "child",
                "parent_id": parent_id,
                "sequence": sequence,
                "created_at": datetime.now().isoformat()
            }
            for sequence, child_text in enumerate(child_texts[:5])  # Limit to 5 children
        ]

        result = {
            "parent": parent_chunk,
            "children": child_chunks,
            "count": len(child_chunks)
        }
        return json.dumps(result, indent=2)
    except Exception as e:
        return f"Error chunking: {str(e)}"


# Expects the JSON produced by chunk_content ({"parent": {...}, "children":
# [...]}). Returns JSON with "success", "chunks_indexed", "chunks_skipped" and
# "total_index_size", or {"success": false, "error": ...} on failure.
# A batch whose encoding fails is skipped: indexing made-up vectors for it
# would make retrieval return unrelated chunks.
@tool
def generate_embeddings_and_index(chunks_json: str) -> str:
    """Generate embeddings and add to FAISS index"""
    try:
        chunks_data = json.loads(chunks_json)
        parent = chunks_data.get("parent", {})
        children = chunks_data.get("children", [])

        # Only include valid chunks (not empty)
        all_chunks = []
        if parent and parent.get("document"):
            all_chunks.append(parent)
        all_chunks.extend(child for child in children if child.get("document"))

        if not all_chunks:
            print("⚠️ No valid chunks to index")
            return json.dumps({
                "success": True,
                "chunks_indexed": 0,
                "total_index_size": len(faiss_chunk_id_to_index)
            })

        print(f"⏳ Generating embeddings for {len(all_chunks)} chunks...")

        # Generate embeddings with minimal batching - optimized for CPU
        batch_size = 2  # Ultra-minimal for CPU (was 4)
        total_batches = (len(all_chunks) - 1) // batch_size + 1
        embedded_chunks = []
        embeddings_list = []
        for batch_no, start in enumerate(range(0, len(all_chunks), batch_size), 1):
            batch = all_chunks[start:start + batch_size]
            batch_texts = [chunk.get("document", "") for chunk in batch]
            print(f" Processing batch {batch_no}/{total_batches}...")
            try:
                # Encode without progress bar to avoid hang
                batch_embeddings = embedding_model.encode(batch_texts, convert_to_numpy=True, show_progress_bar=False)
            except Exception as be:
                print(f" ⚠️ Batch error: {str(be)}, skipping {len(batch)} chunk(s)")
                continue
            embeddings_list.extend(batch_embeddings)
            embedded_chunks.extend(batch)
            print(f" ✅ Encoded {len(batch_texts)} texts")

        if not embedded_chunks:
            raise RuntimeError("embedding failed for every batch")

        # Add to FAISS
        embeddings = np.array(embeddings_list, dtype=np.float32)
        chunk_ids = [chunk.get("id", "") for chunk in embedded_chunks]
        print(f" Adding {len(embeddings)} embeddings to FAISS...")
        faiss_add_chunks(embeddings, chunk_ids, embedded_chunks)

        print(f"✅ Successfully indexed {len(embedded_chunks)} chunks")
        return json.dumps({
            "success": True,
            "chunks_indexed": len(embedded_chunks),
            "chunks_skipped": len(all_chunks) - len(embedded_chunks),
            "total_index_size": len(faiss_chunk_id_to_index)
        })
    except Exception as e:
        print(f"❌ Indexing error: {str(e)}")
        return json.dumps({"success": False, "error": str(e)})


# Returns the faiss_search result as a JSON string. Any failure, including a
# failure to encode the query, yields the empty result structure: a made-up
# query vector would only retrieve arbitrary chunks.
@tool
def retrieve_similar_chunks(query_text: str, k: int = 5) -> str:
    """Retrieve similar chunks from FAISS"""
    empty_results = {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
    try:
        print(f"🔎 Retrieving similar chunks for: '{query_text[:50]}...'")

        try:
            query_embedding = embedding_model.encode(query_text, convert_to_numpy=True, show_progress_bar=False)
            print(" ✅ Query encoded")
        except Exception as e:
            print(f" ⚠️ Encoding timeout/error: {str(e)[:50]}, returning no results")
            return json.dumps(empty_results)

        results = faiss_search(query_embedding, k=k)

        if not results["ids"][0]:
            index_size = faiss_index.ntotal if faiss_index is not None else 0
            print(f" ⚠️ No chunks found in FAISS (index size: {index_size})")
        else:
            print(f" ✅ Found {len(results['ids'][0])} similar chunks")

        return json.dumps(results, indent=2)
    except Exception as e:
        print(f" ❌ Retrieval error: {str(e)}")
        return json.dumps(empty_results)


def generate_answer(query: str, context_chunks: str, chat_history: str = "") -> str:
    """Generate an answer with the large LLM from context and chat history.

    Makes up to 3 attempts, sleeping 5s then 10s between them.

    Args:
        query: The user's current question.
        context_chunks: Retrieved context, already formatted as text.
        chat_history: Earlier conversation, already formatted as text.

    Returns:
        The model's answer, or a string starting with
        ``"Error generating answer:"`` when every attempt fails. Exceptions
        from the LLM call are not propagated.
    """
    import time

    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            print(f"💡 Generating answer for: '{query[:50]}...' (Attempt {attempt}/{max_attempts})")

            system_msg = SystemMessage(content="""You are an e-commerce assistant. Answer the user's query based on the provided context AND previous conversation.
Be concise, accurate, and cite sources when possible.
Remember that default region is INDIA and current year is 2026.
If answering a follow-up question, use insights from previous messages to provide better recommendations.""")

            # Build user message with chat history context
            user_content = f"""
{chat_history}

Current Query: {query}

Context chunks:
{context_chunks}

Please provide a helpful answer with confidence level (0-100%)."""

            user_msg = HumanMessage(content=user_content)
            response = llm_large.invoke([system_msg, user_msg])
            print("✅ Answer generated successfully")
            return response.content
        except Exception as e:
            error_msg = str(e)
            print(f"⚠️ Attempt {attempt} failed: {error_msg[:80]}")
            if attempt >= max_attempts:
                print(f"❌ Failed after {max_attempts} attempts")
                return f"Error generating answer: {error_msg[:100]}"
            wait_time = 5 * attempt  # 5s, 10s
            print(f" ⏳ Retrying in {wait_time}s...")
            time.sleep(wait_time)


# --- 5.
# AGENT FUNCTIONS WITH TOOLS ---
def _chat_message_parts(msg: Any) -> Tuple[str, str]:
    """Return ``(ROLE, content)`` for one chat message.

    Accepts ``{"role", "content"}`` dicts as well as objects exposing those
    attributes; anything else yields empty strings instead of raising.
    Non-string content is converted with ``str``.
    """
    if isinstance(msg, dict):
        role, content = msg.get("role", ""), msg.get("content", "")
    else:
        role, content = getattr(msg, "role", ""), getattr(msg, "content", "")
    if not isinstance(content, str):
        content = "" if content is None else str(content)
    return str(role or "").upper(), content


def _format_chat_history(messages: List, max_chars: int, separator: str) -> str:
    """Render messages as ``ROLE: content`` entries under a common header."""
    rendered = "Previous Conversation:\n"
    for msg in messages:
        role, content = _chat_message_parts(msg)
        rendered += f"{role}: {content[:max_chars]}{separator}"
    return rendered


def run_decision_agent(query: str, chat_history: List = None, progress_callback=None) -> Tuple[bool, str]:
    """
    Decision Agent - Decides if web search is needed

    A history of at most one message always triggers a search. Otherwise the
    small LLM is asked to answer "SEARCH" or "CONTEXT" given the last six
    messages (200 characters each); any LLM failure defaults to searching.

    Returns: (needs_search: bool, reasoning: str)
    """
    print(f"\n{'='*70}\n🤔 DECISION AGENT\n{'='*70}")

    if chat_history is None:
        chat_history = []

    # Check if we have previous answers
    if not chat_history or len(chat_history) <= 1:
        print("✅ First question - will perform web search")
        return True, "First query - need fresh data"

    # Build conversation context from chat history (last 3 exchanges)
    context = _format_chat_history(chat_history[-6:], max_chars=200, separator="\n")

    # Decision prompt
    system_msg = SystemMessage(content="""You are a decision agent for e-commerce queries.
Analyze if the current query can be answered from PREVIOUS conversation context or if it needs NEW web search.

Rules:
1. If query asks about SPECIFIC products/prices/comparisons from before → NO search needed, use previous context
2. If query asks for DIFFERENT products or NEW information → YES search needed
3. If query is a follow-up question about something already discussed → NO search needed
4. If unclear or asking for LATEST data → YES search needed

Respond with ONLY: "SEARCH" or "CONTEXT"
""")

    user_msg = HumanMessage(content=f"""{context}

Current query: {query}

Decision: (respond with ONLY "SEARCH" or "CONTEXT")""")

    try:
        response = llm_small.invoke([system_msg, user_msg])
        decision = response.content.strip().upper()
        needs_search = "SEARCH" in decision
        reasoning = f"Decision: {'Will search web' if needs_search else 'Will use previous context'}"
        print(f"✅ {reasoning}")
        return needs_search, reasoning
    except Exception as e:
        print(f"⚠️ Decision error: {str(e)[:50]}, defaulting to search")
        return True, f"Error in decision (default to search): {str(e)[:30]}"


def _build_search_queries(query: str) -> List[str]:
    """Return up to 5 unique search strings: the query plus LLM variations.

    Falls back to ``"<query> price"`` / ``"<query> buy online"`` when the LLM
    call fails or its reply cannot be parsed as a JSON array.
    """
    fallback = [f"{query} price", f"{query} buy online"]

    system_msg = SystemMessage(content="""Generate 3-4 search variations for an e-commerce query.
Focus on product names, prices, availability.
Remember that the default region is INDIA and current year is 2026.
Return ONLY a JSON array of strings.""")
    user_msg = HumanMessage(content=f"Query: {query}\n\nReturn search variations as JSON array:")

    try:
        response = llm_small.invoke([system_msg, user_msg])
    except Exception as e:
        print(f"⚠️ LLM timeout/error, using default queries: {str(e)[:50]}")
        return [query] + fallback

    candidates = [query]
    try:
        json_match = re.search(r'\[.*\]', response.content, re.DOTALL)
        if json_match:
            candidates.extend(json.loads(json_match.group()))
    except (ValueError, TypeError):
        candidates.extend(fallback)

    # The model may return non-strings, blanks or repeats; keep clean ones only.
    queries = []
    seen = set()
    for candidate in candidates:
        if isinstance(candidate, str) and candidate.strip() and candidate not in seen:
            seen.add(candidate)
            queries.append(candidate)
    return queries[:5]


def run_search_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Search Agent - generate query variations and collect Tavily results.

    Stores up to 12 de-duplicated result dicts in ``state["search_results"]``
    and moves the stage to CHUNKING. On an unexpected failure the stage is set
    to ERROR and ``error_message`` is filled in.
    """
    try:
        print(f"\n{'='*70}\n🔍 SEARCH AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("🔍 Searching for products...")

        query = state["query"]

        # Step 1: Generate search queries (with fallback)
        search_queries = _build_search_queries(query)
        state["progress_log"].append(f"✅ Generated {len(search_queries)} search queries")
        print(f"Search queries: {search_queries}")

        # Step 2: Broad Tavily search
        print("📄 Searching broadly with Tavily...")
        tavily_search = TavilySearchResults(
            max_results=15,  # More results for better diversity
            api_key=TAVILY_API_KEY,
            search_depth="advanced"
        )

        all_results = []
        for sq in search_queries[:3]:  # Limit to 3 to avoid too many API calls
            try:
                print(f" 🔍 Query: '{sq[:40]}...'")
                results = tavily_search.invoke(sq)
                if not isinstance(results, list):
                    # The tool reports its own failures as a string; extending
                    # the list with it would add one "result" per character.
                    raise ValueError(str(results))
                all_results.extend(results)
                print(f" ✅ Found {len(results)} results")
            except Exception as e:
                print(f" ⚠️ Error: {str(e)[:50]}")
                state["progress_log"].append(f"⚠️ Search error: {str(e)[:50]}")

        # Step 3: Deduplicate results (no domain filtering)
        print(f"🎯 Processing {len(all_results)} results...")
        seen_urls = set()
        seen_domains = {}
        unique_results = []
        for result in all_results:
            # Ensure result is a dictionary
            if isinstance(result, str):
                try:
                    result = json.loads(result)
                except ValueError:
                    continue
            if not isinstance(result, dict):
                continue

            url = result.get("url", "")
            if url and url not in seen_urls:
                seen_urls.add(url)
                domain = extract_domain(url)
                unique_results.append(result)
                seen_domains[domain] = seen_domains.get(domain, 0) + 1
                print(f" ✅ Added: {domain}")

        print(f"\n📊 Results: {len(unique_results)} URLs from {len(seen_domains)} domains")
        print(f" Domains found: {list(seen_domains.keys())}")

        unique_results = unique_results[:12]  # Keep top 12 for diversity

        state["search_results"] = unique_results
        state["stage"] = WorkflowStage.CHUNKING.value
        state["progress_log"].append(f"✅ Search complete: {len(unique_results)} URLs found")
        print(f"✅ Search complete: {len(unique_results)} URLs")
        return state
    except Exception as e:
        print(f"❌ Search agent error: {str(e)}")
        state["error_message"] = f"Search agent failed: {str(e)[:100]}"
        state["stage"] = WorkflowStage.ERROR.value
        return state


def run_chunking_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Chunking Agent - scrape each search result and split it into chunks.

    Fills ``cleaned_documents``, ``parent_chunks``, ``child_chunks`` and
    ``parent_child_mapping`` and moves the stage to INDEXING. Pages that fail
    to scrape or chunk are skipped.
    """
    try:
        print(f"\n{'='*70}\n📄 CHUNKING AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("📄 Chunking documents...")

        # Scrape URLs
        cleaned_documents = []
        for i, result in enumerate(state.get("search_results", []), 1):
            url = result.get("url", "")
            if not url:
                continue

            domain = extract_domain(url)
            print(f"📄 Scraping [{i}]: {domain}")

            try:
                content = scrape_product_content.invoke({"url": url})
                # The scraper reports failures as strings starting with
                # "Error"; testing for the word anywhere would also reject
                # real pages that merely mention it.
                if content and not content.startswith("Error") and len(content) > 200:
                    cleaned_documents.append({
                        "url": url,
                        "content": content,
                        "source": domain,
                        "scraped_at": datetime.now().isoformat()
                    })
                    print(f" ✅ Extracted {len(content)} chars")
                    state["progress_log"].append(f"✅ Scraped: {domain}")
            except Exception as e:
                print(f" ⚠️ Error: {str(e)[:50]}")
                state["progress_log"].append(f"⚠️ Scrape error: {domain}")

        # Chunk documents
        parent_chunks = []
        child_chunks = []
        parent_child_mapping = {}

        for doc in cleaned_documents:
            try:
                chunks_result = chunk_content.invoke({
                    "content": doc["content"],
                    "url": doc["url"]
                })
                chunks_data = json.loads(chunks_result)
                parent = chunks_data.get("parent", {})
                children = chunks_data.get("children", [])

                if parent:
                    parent_chunks.append(parent)
                    parent_child_mapping[parent["id"]] = [c["id"] for c in children]
                child_chunks.extend(children)
            except Exception as e:
                print(f" ❌ Chunking error: {str(e)[:50]}")

        state["cleaned_documents"] = cleaned_documents
        state["parent_chunks"] = parent_chunks
        state["child_chunks"] = child_chunks
        state["parent_child_mapping"] = parent_child_mapping
        state["stage"] = WorkflowStage.INDEXING.value
        state["progress_log"].append(f"✅ Chunking: {len(parent_chunks)} parents, {len(child_chunks)} children")
        print(f"✅ Chunking complete: {len(parent_chunks)} parents, {len(child_chunks)} children")
        return state
    except Exception as e:
        state["error_message"] = f"Chunking agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state


def run_indexing_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Indexing Agent - embed and index every parent chunk with its children.

    Each parent is sent to the indexing tool together with the children whose
    ``parent_id`` matches it; children without a known parent are indexed in a
    final batch. The stage always moves on to RETRIEVAL unless an unexpected
    exception occurs.
    """
    try:
        print(f"\n{'='*70}\n🗂️ INDEXING AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("🗂️ Indexing chunks...")

        parent_chunks = state.get("parent_chunks", [])
        child_chunks = state.get("child_chunks", [])

        if not parent_chunks and not child_chunks:
            state["progress_log"].append("⚠️ No chunks to index")
            state["stage"] = WorkflowStage.RETRIEVAL.value
            return state

        # Group children under their parent so every parent gets indexed
        # (previously only the first parent was sent to the tool).
        children_by_parent = {}
        for child in child_chunks:
            children_by_parent.setdefault(child.get("parent_id"), []).append(child)

        batches = [
            {"parent": parent, "children": children_by_parent.pop(parent.get("id"), [])}
            for parent in parent_chunks
        ]
        orphans = [child for group in children_by_parent.values() for child in group]
        if orphans:
            batches.append({"parent": {}, "children": orphans})

        # Generate embeddings using tool
        chunks_indexed = 0
        for batch in batches:
            result = generate_embeddings_and_index.invoke({"chunks_json": json.dumps(batch)})
            index_result = json.loads(result)
            chunks_indexed += index_result.get("chunks_indexed", 0)
            if not index_result.get("success", False):
                state["progress_log"].append(
                    f"⚠️ Indexing error: {str(index_result.get('error', 'unknown'))[:50]}"
                )

        state["stage"] = WorkflowStage.RETRIEVAL.value
        state["progress_log"].append(f"✅ Indexed {chunks_indexed} chunks")
        print("✅ Indexing complete")
        return state
    except Exception as e:
        state["error_message"] = f"Indexing agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state


def run_retrieval_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Retrieval Agent - fetch the 5 chunks most similar to the query.

    Stores ``{"id", "document", "metadata", "distance"}`` dicts in
    ``state["retrieved_children"]`` (empty when the index is empty) and moves
    the stage to SYNTHESIS.
    """
    try:
        print(f"\n{'='*70}\n🔎 RETRIEVAL AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("🔎 Retrieving similar content...")

        query = state["query"]

        # Check if FAISS has any chunks
        if faiss_index is None or faiss_index.ntotal == 0:
            print("⚠️ FAISS index is empty - no chunks to retrieve")
            state["retrieved_children"] = []
            state["progress_log"].append("⚠️ FAISS index empty")
        else:
            # Retrieve chunks
            result_json = retrieve_similar_chunks.invoke({"query_text": query, "k": 5})
            results = result_json if isinstance(result_json, dict) else json.loads(result_json)

            retrieved_children = [
                {
                    "id": doc_id,
                    "document": doc_text,
                    "metadata": metadata,
                    "distance": float(dist)
                }
                for doc_id, doc_text, metadata, dist in zip(
                    results.get("ids", [[]])[0],
                    results.get("documents", [[]])[0],
                    results.get("metadatas", [[]])[0],
                    results.get("distances", [[]])[0]
                )
            ]

            state["retrieved_children"] = retrieved_children
            state["progress_log"].append(f"✅ Retrieved {len(retrieved_children)} chunks")
            print(f"✅ Retrieval complete: {len(retrieved_children)} chunks")

        state["stage"] = WorkflowStage.SYNTHESIS.value
        return state
    except Exception as e:
        print(f"❌ Retrieval error: {str(e)}")
        state["error_message"] = f"Retrieval agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state


def run_synthesis_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Synthesis Agent - generate the final answer from retrieved chunks.

    Uses the first 5 retrieved chunks as context and all chat messages except
    the last (the current query) as history. Confidence is a heuristic:
    ``min(0.9, retrieved / 10)``, or 0.1 when nothing was retrieved. Up to 3
    generation rounds are made on top of ``generate_answer``'s own retries.
    Sets ``final_answer``, ``citations`` (first 3 chunks), ``confidence`` and
    moves the stage to COMPLETE.
    """
    import time

    try:
        print(f"\n{'='*70}\n💡 SYNTHESIS AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("💡 Generating answer...")

        query = state["query"]
        retrieved = state.get("retrieved_children", [])
        chat_messages = state.get("chat_messages", [])

        # Prepare context - handle empty case
        if retrieved:
            context = "\n\n".join([
                f"[{i+1}] {chunk.get('metadata', {}).get('source', 'Unknown')}\n{chunk.get('document', '')}"
                for i, chunk in enumerate(retrieved[:5])
            ])
            confidence = min(0.9, len(retrieved) / 10.0)
        else:
            context = "[No relevant information found in knowledge base]"
            confidence = 0.1

        # Build chat history context (all except current query, 300 chars each)
        chat_context = ""
        if chat_messages and len(chat_messages) > 1:
            chat_context = _format_chat_history(chat_messages[:-1], max_chars=300, separator="\n\n")

        # Generate answer with retry logic built-in
        answer = None
        max_retries = 3
        for retry_attempt in range(max_retries):
            is_last_attempt = retry_attempt == max_retries - 1
            try:
                print(f" ⏳ Generating (attempt {retry_attempt + 1}/{max_retries})...")
                answer = generate_answer(query, context, chat_context)

                # generate_answer reports failure as a string starting with "Error"
                if answer and not answer.startswith("Error"):
                    print(" ✅ Answer generated successfully")
                    break
                if answer:
                    print(f" ⚠️ {answer[:80]}")
            except Exception as gen_err:
                print(f" ⚠️ Generation error: {str(gen_err)[:80]}")
                if is_last_attempt:
                    answer = f"Error generating answer: {str(gen_err)[:100]}"

            if is_last_attempt:
                print(" ❌ Max retries reached")
            else:
                wait = 5 * (retry_attempt + 1)
                print(f" ⏳ Waiting {wait}s before retry...")
                time.sleep(wait)

        if not answer:
            answer = f"Error generating answer: Unable to generate response after {max_retries} attempts"

        # Extract citations from retrieved chunks
        citations = [
            {
                "source": chunk.get("metadata", {}).get("source", "Unknown"),
                "section": "Content",
                "url": chunk.get("metadata", {}).get("source", "")
            }
            for chunk in retrieved[:3]
        ]

        state["final_answer"] = answer
        state["citations"] = citations
        state["confidence"] = confidence
        state["stage"] = WorkflowStage.COMPLETE.value
        state["progress_log"].append(f"✅ Synthesis complete: Confidence={confidence:.2f}")
        print("✅ Synthesis complete")
        return state
    except Exception as e:
        print(f"❌ Synthesis error: {str(e)}")
        state["error_message"] = f"Synthesis agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state


# --- 6. WORKFLOW ORCHESTRATION ---
def run_complete_workflow(query: str, chat_history: List = None, progress=gr.Progress()) -> EcommerceAgentState:
    """Run the complete tool-based agentic pipeline with intelligent context.

    The decision agent chooses between a fresh search (search -> chunk ->
    index) and reusing what is already indexed; retrieval and synthesis always
    run. The pipeline stops at the first stage that reports ERROR.

    Args:
        query: The user's question.
        chat_history: Chat messages including the current one.
        progress: Gradio progress tracker.

    Returns:
        The final state. ``final_answer`` is also appended to the global
        conversation history when the pipeline runs to the end.
    """
    if chat_history is None:
        chat_history = []

    state = create_initial_state(query)
    state["chat_messages"] = chat_history  # Add chat history to state

    def update_progress(msg: str):
        progress(0.2, desc=msg)

    error_stage = WorkflowStage.ERROR.value

    # Step 1: Decision Agent - Determine if search is needed
    needs_search, decision_reason = run_decision_agent(query, chat_history, update_progress)
    state["progress_log"].append(f"🤔 {decision_reason}")

    # Step 2: Conditionally run search agents
    if needs_search:
        ingestion_steps = [
            (0.25, "🔍 Searching...", run_search_agent),
            (0.35, "📄 Chunking...", run_chunking_agent),
            (0.5, "🗂️ Indexing...", run_indexing_agent),
        ]
        for fraction, label, agent in ingestion_steps:
            progress(fraction, desc=label)
            state = agent(state, update_progress)
            if state["stage"] == error_stage:
                return state
    else:
        # Skip search/chunking/indexing, go directly to retrieval
        print("📚 Reusing indexed data from previous query...")
        state["progress_log"].append("📚 Using existing indexed knowledge")
        state["stage"] = WorkflowStage.RETRIEVAL.value

    # Step 3: Retrieval and Synthesis (always run)
    progress(0.65, desc="🔎 Retrieving...")
    state = run_retrieval_agent(state, update_progress)
    if state["stage"] == error_stage:
        return state

    progress(0.85, desc="💡 Generating answer...")
    state = run_synthesis_agent(state, update_progress)

    progress(1.0, desc="✅ Complete")

    # Add to conversation history
    add_to_conversation_history(query, state["final_answer"])

    return state


# --- 7.
# GRADIO INTERFACE ---#
def chat_interface(message: str, chat_history: List) -> Tuple[str, List]:
    """Handle one chat turn of the e-commerce RAG system.

    Appends the user message, runs the workflow, then appends an assistant
    message containing the answer, confidence, up to 3 sources, the process
    log and any error.

    Args:
        message: Text from the input box; blank input is ignored.
        chat_history: Current chatbot value as ``{"role", "content"}`` dicts
            (``None`` is treated as an empty history).

    Returns:
        ``("", updated_history)`` - the empty string clears the input box.
    """
    if chat_history is None:
        chat_history = []

    if not message or message.strip() == "":
        return "", chat_history

    # Add user message to history (new format: dict with role and content)
    chat_history.append({"role": "user", "content": message})

    # Run workflow with full chat history
    state = run_complete_workflow(message, chat_history=chat_history)

    # Build assistant response
    sections = [f"""🤖 **Answer**

{state['final_answer']}

**Confidence:** {state['confidence']*100:.0f}%

---

📚 **Sources:**
"""]
    sections.extend(f"- {cite['source']}\n" for cite in state["citations"][:3])
    sections.append("\n---\n\n📋 **Process Log:**\n")
    sections.extend(f"✓ {log}\n" for log in state["progress_log"])
    if state["error_message"]:
        sections.append(f"\n⚠️ **Error:** {state['error_message']}")
    assistant_response = "".join(sections)

    # Append the assistant response (new format)
    chat_history.append({"role": "assistant", "content": assistant_response})

    return "", chat_history


# Build Gradio interface
with gr.Blocks(title="🛍️ E-commerce Chat RAG") as demo:
    gr.Markdown("""
    # 🛍️ E-commerce Chat RAG System

    Ask any question about e-commerce products and prices!

    *Agents:* 🔍 Search → 📄 Chunk → 🗂️ Index → 🔎 Retrieve → 💡 Synthesize
    """)

    # NOTE(review): chat_interface stores role/content dicts; confirm the
    # installed Gradio version uses that message format for Chatbot by default.
    chatbot = gr.Chatbot(
        label="💬 Chat",
        height=600
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Your Question",
            placeholder="e.g., What is the current price of iPhone 14?",
            lines=2,
            scale=4
        )
        submit_btn = gr.Button("Send 📤", variant="primary", scale=1)

    # The button click and the Enter key trigger the same handler.
    submit_kwargs = dict(
        fn=chat_interface,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True
    )

    # Chat submission
    submit_btn.click(**submit_kwargs)

    # Allow Enter key to submit
    msg.submit(**submit_kwargs)


if __name__ == "__main__":
    try:
        # For HF Spaces: suppress asyncio event loop cleanup warnings
        demo.launch(share=False, show_error=True)
    finally:
        # Clean up event loop on shutdown. This is best-effort: there may be
        # no loop, or it may still be running, so failures are ignored.
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop.stop()
            loop.close()
        except Exception:
            pass