Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import requests | |
| import asyncio | |
| import uuid | |
| import re | |
| import warnings | |
| from typing import TypedDict, Annotated, List, Dict, Any, Optional, Tuple | |
| from datetime import datetime | |
| from pathlib import Path | |
| from enum import Enum | |
| import tempfile | |
| # Suppress torch FutureWarning on HF Spaces | |
| warnings.filterwarnings('ignore', category=FutureWarning, module='.*torch.*') | |
| # Third-party imports | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| from bs4 import BeautifulSoup | |
| import faiss | |
| import pickle | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| import tiktoken | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage | |
| from langchain_mistralai.chat_models import ChatMistralAI | |
| from langchain_core.tools import tool | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
# --- 1. SETUP & CONFIGURATION ---
# Load variables from a .env file (if present) into the process environment
# before any of the os.getenv() lookups further down run.
load_dotenv()
| # --- API KEY ROTATION --- | |
class APIKeyRotator:
    """Round-robin pool of API keys for a single external service.

    Keys are read from the environment: ``<SERVICE>_API_KEY`` first, then
    ``<SERVICE>_API_KEY_2``, ``<SERVICE>_API_KEY_3`` and so on, stopping at
    the first variable that is unset or empty. The collected keys are
    shuffled once when the rotator is created.
    """

    def __init__(self, service_name: str):
        """Load every key configured for *service_name*.

        Raises:
            ValueError: if not even ``<SERVICE>_API_KEY`` is set.
        """
        self.service_name = service_name
        self.keys = self._load_keys()
        self.current_index = 0

    def _load_keys(self) -> List[str]:
        """Collect the numbered key variables from the environment, shuffled."""
        import random

        base_name = f"{self.service_name}_API_KEY"
        found: List[str] = []
        position = 1
        value = os.getenv(base_name)
        # The first key has no numeric suffix; later ones are _2, _3, ...
        while value:
            found.append(value)
            position += 1
            value = os.getenv(f"{base_name}_{position}")

        if not found:
            raise ValueError(f"β No API keys found for {self.service_name}")

        print(f"β Loaded {len(found)} API key(s) for {self.service_name}")
        random.shuffle(found)  # Randomize order
        return found

    def get_key(self) -> str:
        """Return the currently selected key without advancing the rotation."""
        if self.keys:
            return self.keys[self.current_index]
        raise ValueError(f"β No API keys available for {self.service_name}")

    def rotate(self) -> str:
        """Advance to the next key (wrapping around) and return it."""
        if not self.keys:
            raise ValueError(f"β No API keys available for {self.service_name}")
        total = len(self.keys)
        self.current_index = (self.current_index + 1) % total
        print(f"π Rotated {self.service_name} API key to index {self.current_index + 1}/{len(self.keys)}")
        return self.get_key()
# Initialize key rotators
mistral_rotator = APIKeyRotator("MISTRAL")
tavily_rotator = APIKeyRotator("TAVILY")

# Publish the currently selected keys both as module constants and through
# os.environ. These are snapshots taken at import time; they are not updated
# when a rotator later advances.
MISTRAL_API_KEY = mistral_rotator.get_key()
TAVILY_API_KEY = tavily_rotator.get_key()
os.environ["MISTRAL_API_KEY"] = MISTRAL_API_KEY
os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY
print("[OK] API Key rotation initialized")

# Vector DB Configuration
# NOTE(review): the name says CHROMA but the visible code only uses FAISS -
# confirm whether anything still reads this directory.
CHROMA_DB_PATH = Path(tempfile.gettempdir(), "ecommerce_rag")
CHROMA_DB_PATH.mkdir(parents=True, exist_ok=True)
print(f"[VECTOR_STORE] Location: {CHROMA_DB_PATH.absolute()}")
print("[OK] API Keys configured")

# Global FAISS state
faiss_index = None  # created lazily by faiss_add_chunks
faiss_chunks_metadata = {}  # chunk id -> chunk dict
faiss_chunk_id_to_index = {}  # chunk id -> row position in faiss_index
faiss_index_counter = 0  # next row position to hand out
# Resilient LLM wrapper with key rotation
def create_llm_with_rotation(model: str, temperature: float = 0.2, timeout: float = 120.0, max_retries: int = 5):
    """Build a Mistral chat client that switches API key when a call fails.

    Args:
        model: Mistral model name passed to ``ChatMistralAI``.
        temperature: sampling temperature passed to ``ChatMistralAI``.
        timeout: request timeout passed to ``ChatMistralAI``.
        max_retries: retry count passed to ``ChatMistralAI``.

    Returns:
        A ``ResilientLLM`` instance exposing ``invoke`` and ``bind_tools``.
    """

    class ResilientLLM:
        """Wraps ``ChatMistralAI`` and rebuilds it with the next key on error."""

        def __init__(self, model, temperature, timeout, max_retries):
            self.model = model
            self.temperature = temperature
            self.timeout = timeout
            self.max_retries = max_retries
            self.llm = self._create_llm()

        def _create_llm(self):
            """Create a client bound to the rotator's currently selected key."""
            settings = dict(
                api_key=mistral_rotator.get_key(),
                model=self.model,
                temperature=self.temperature,
                timeout=self.timeout,
                max_retries=self.max_retries,
            )
            return ChatMistralAI(**settings)

        def invoke(self, messages, **kwargs):
            """Invoke with automatic key rotation on failure.

            Tries once per configured key; the exception from the final
            attempt is re-raised unchanged.
            """
            final_attempt = len(mistral_rotator.keys) - 1
            for attempt in range(final_attempt + 1):
                try:
                    return self.llm.invoke(messages, **kwargs)
                except Exception as e:
                    if attempt >= final_attempt:
                        raise
                    print(f"β οΈ Attempt {attempt + 1} failed: {str(e)[:80]}")
                    # rotate() moves the shared rotator forward, so the
                    # rebuilt client picks up the new key.
                    mistral_rotator.rotate()
                    self.llm = self._create_llm()
                    print("π Retrying with rotated key...")

        def bind_tools(self, tools):
            """Bind tools to the underlying client and return the bound client."""
            return self.llm.bind_tools(tools)

    return ResilientLLM(model, temperature, timeout, max_retries)
# Initialize LLMs with resilience and key rotation
llm_small = create_llm_with_rotation(
    "mistral-small-latest", temperature=0.2, timeout=60.0, max_retries=5
)
llm_large = create_llm_with_rotation(
    "mistral-large-latest", temperature=0.2, timeout=120.0, max_retries=5
)

# Initialize embedding model (kept on CPU)
print("[LOADING] Embedding model...")
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
embedding_model.to('cpu')
print("[OK] Embedding model initialized (CPU mode)")

# Initialize tokenizer used by count_tokens
tokenizer = tiktoken.get_encoding("cl100k_base")
| # --- TRUSTED E-COMMERCE WEBSITES (REMOVED - Now accepting all sources) --- | |
| # No domain restrictions - Tavily will find the best sources dynamically | |
| # --- 2. STATE DEFINITIONS --- | |
class WorkflowStage(Enum):
    """Pipeline stages; the string values are what gets stored in state["stage"]."""

    IDLE = "idle"
    SEARCH = "search"
    CHUNKING = "chunking"
    INDEXING = "indexing"
    RETRIEVAL = "retrieval"
    SYNTHESIS = "synthesis"
    COMPLETE = "complete"
    ERROR = "error"
class EcommerceAgentState(TypedDict):
    """Unified state for tool-based agentic RAG pipeline.

    One instance is threaded through the search, chunking, indexing,
    retrieval and synthesis agents; each agent fills in its own fields and
    advances ``stage``.
    """

    query: str  # the user's question
    stage: str  # a WorkflowStage value
    search_results: List[Dict]  # deduplicated search hits, each with a "url"
    cleaned_documents: List[Dict]  # scraped pages: url, content, source, scraped_at
    parent_chunks: List[Dict]
    child_chunks: List[Dict]
    retrieved_children: List[Dict]  # id, document, metadata, distance
    retrieved_parents: List[Dict]
    parent_child_mapping: Dict[str, List[str]]  # parent id -> child ids
    final_answer: str
    citations: List[Dict]
    confidence: float
    error_message: Optional[str]  # set when an agent moves the stage to "error"
    progress_log: List[str]
    # Chat turns as dicts; the synthesis agent reads them with
    # msg.get("role") / msg.get("content"), so they are not (role, text) tuples.
    chat_messages: List[Dict[str, Any]]
def create_initial_state(query: str) -> EcommerceAgentState:
    """Return a fresh pipeline state for *query*.

    The stage starts at IDLE, every collection is empty, confidence is 0.0
    and no error is recorded.
    """
    return EcommerceAgentState(
        query=query,
        stage=WorkflowStage.IDLE.value,
        search_results=[],
        cleaned_documents=[],
        parent_chunks=[],
        child_chunks=[],
        retrieved_children=[],
        retrieved_parents=[],
        parent_child_mapping={},
        final_answer="",
        citations=[],
        confidence=0.0,
        error_message=None,
        progress_log=[],
        chat_messages=[],
    )
# --- CONVERSATION CONTEXT ---
# Module-wide conversation memory shared by all queries in this process
conversation_history = []  # [{"query": ..., "answer": ...}, ...] oldest first
previous_answers = []  # the answers only, kept in step with conversation_history


def add_to_conversation_history(query: str, answer: str):
    """Record one question/answer exchange, keeping only the newest five."""
    conversation_history.append({"query": query, "answer": answer})
    previous_answers.append(answer)

    # Drop the oldest exchange once the cap is exceeded to avoid context bloat
    if len(conversation_history) > 5:
        del conversation_history[0]
        del previous_answers[0]
def get_conversation_context() -> str:
    """Format the three most recent exchanges as a numbered text block.

    Returns an empty string when nothing has been recorded. Questions are cut
    to 100 characters and answers to 150; "..." is appended to both
    regardless of whether anything was cut.
    """
    if not conversation_history:
        return ""

    recent = conversation_history[-3:]  # Last 3 exchanges
    entries = [
        f"{number}. Q: {exchange['query'][:100]}...\n A: {exchange['answer'][:150]}...\n"
        for number, exchange in enumerate(recent, 1)
    ]
    return "\n**Previous conversation:**\n" + "".join(entries)
| # --- 3. UTILITY FUNCTIONS --- | |
def extract_domain(url: str) -> str:
    """Return the network location of *url* without a leading ``www.``.

    Only a ``www.`` prefix is removed; a "www." that appears elsewhere in the
    host (e.g. ``awww.example.com``) is left untouched. A URL without a
    scheme has no network location and yields an empty string.

    Returns:
        The host (including any port), or ``"unknown"`` when *url* cannot be
        parsed or is not a string.
    """
    from urllib.parse import urlparse

    try:
        netloc = urlparse(url).netloc
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed URL (e.g. broken IPv6 literal);
        # TypeError/AttributeError: *url* is not a string.
        return "unknown"

    if not isinstance(netloc, str):
        # bytes input parses to a bytes netloc; treat it as unparseable.
        return "unknown"

    prefix = "www."
    if netloc.startswith(prefix):
        netloc = netloc[len(prefix):]
    return netloc
def is_trusted_source(url: str) -> bool:
    """Report whether *url* may be used as a source.

    Domain restrictions were removed, so every URL is accepted and the
    argument is not inspected.
    """
    accepted = True  # Accept all domains from Tavily results
    return accepted
def count_tokens(text: str) -> int:
    """Count the tokens in *text* with the module tokenizer.

    If the tokenizer is unavailable or fails, fall back to a
    whitespace-separated word count so callers still get an estimate.
    """
    try:
        return len(tokenizer.encode(text))
    except Exception:
        # Deliberately broad (the tokenizer is third-party) but, unlike a bare
        # except, this lets KeyboardInterrupt/SystemExit propagate.
        return len(text.split())
def generate_uuid() -> str:
    """Return a short random identifier: the first 8 hex digits of a UUID4."""
    return uuid.uuid4().hex[:8]
def faiss_add_chunks(embeddings: np.ndarray, chunk_ids: List[str], metadatas: List[Dict]) -> None:
    """Add chunk vectors to the global FAISS index and record their metadata.

    Vectors are copied to float32 and L2-normalised before being added to an
    inner-product index, which is created on first use with the dimension of
    the incoming vectors.

    Args:
        embeddings: 2-D array-like, one row per chunk.
        chunk_ids: id for each row, in the same order.
        metadatas: metadata dict for each row, in the same order.

    Raises:
        ValueError: if the embeddings are not 2-D, or the three inputs do not
            have the same length. Without this check a shorter id list would
            silently desynchronise the id -> row mapping from the index.
    """
    global faiss_index, faiss_chunks_metadata, faiss_chunk_id_to_index, faiss_index_counter

    embeddings = np.array(embeddings, dtype=np.float32)
    if embeddings.size == 0:
        return  # nothing to add
    if embeddings.ndim != 2:
        raise ValueError(f"embeddings must be 2-D, got shape {embeddings.shape}")
    if not (len(embeddings) == len(chunk_ids) == len(metadatas)):
        raise ValueError(
            "embeddings, chunk_ids and metadatas must have the same length "
            f"(got {len(embeddings)}, {len(chunk_ids)}, {len(metadatas)})"
        )

    faiss.normalize_L2(embeddings)
    if faiss_index is None:
        faiss_index = faiss.IndexFlatIP(embeddings.shape[1])
    faiss_index.add(embeddings)

    # Rows are appended in order, so the running counter is each chunk's row.
    for chunk_id, metadata in zip(chunk_ids, metadatas):
        faiss_chunks_metadata[chunk_id] = metadata
        faiss_chunk_id_to_index[chunk_id] = faiss_index_counter
        faiss_index_counter += 1
def faiss_search(query_embedding: np.ndarray, k: int = 5) -> Dict:
    """Return the *k* nearest indexed chunks for a query vector.

    The query is copied to float32 and L2-normalised before searching.

    Args:
        query_embedding: 1-D query vector.
        k: maximum number of hits; capped at the index size.

    Returns:
        A dict with ``ids``, ``documents``, ``metadatas`` and ``distances``,
        each a single-element list wrapping the per-hit list. ``metadatas``
        omits the ``"document"`` field, and ``distances`` holds the index's
        scores as plain floats. All lists are empty when nothing is indexed.
    """
    if faiss_index is None or faiss_index.ntotal == 0:
        return {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}

    query_matrix = np.array([query_embedding], dtype=np.float32)
    faiss.normalize_L2(query_matrix)
    distances, indices = faiss_index.search(query_matrix, min(k, faiss_index.ntotal))

    # Invert the id -> row map once instead of scanning it for every hit.
    row_to_chunk_id = {row: chunk_id for chunk_id, row in faiss_chunk_id_to_index.items()}

    result_ids = []
    result_docs = []
    result_metadatas = []
    result_distances = []
    for dist, idx in zip(distances[0], indices[0]):
        chunk_id = row_to_chunk_id.get(int(idx))
        if chunk_id is None:
            # Row with no known chunk (including FAISS's -1 padding): skip it.
            continue
        metadata = faiss_chunks_metadata[chunk_id]
        result_ids.append(chunk_id)
        result_docs.append(metadata.get("document", ""))
        result_metadatas.append(
            {field: value for field, value in metadata.items() if field != "document"}
        )
        # Convert numpy float32 to Python float for JSON serialization
        result_distances.append(float(dist))

    return {
        "ids": [result_ids],
        "documents": [result_docs],
        "metadatas": [result_metadatas],
        "distances": [result_distances],
    }
| # --- 4. TOOL DEFINITIONS --- | |
def search_ecommerce_products(query: str) -> str:
    """Run one advanced-depth Tavily search for *query*.

    Returns:
        The raw results serialised as indented JSON, or a string starting
        with "Error in search:" if anything fails.
    """
    try:
        print(f"π Searching Tavily for: '{query}' (will curate results after)")
        tavily = TavilySearchResults(
            max_results=15,  # Increased to get more diverse results
            api_key=TAVILY_API_KEY,
            search_depth="advanced",
        )
        hits = tavily.invoke(query)
        print(f" β Found {len(hits)} raw results (will filter to trusted sources)")
        serialized = json.dumps(hits, indent=2)
    except Exception as e:
        return f"Error in search: {str(e)}"
    return serialized
def scrape_product_content(url: str) -> str:
    """Download *url* and return its visible text, cleaned and truncated.

    Script, style, navigation and similar non-content elements are removed
    before the text is extracted. The result is capped at 8000 characters.

    Returns:
        The cleaned text when more than 200 characters were extracted;
        ``"Insufficient content"`` when less was found; otherwise a string
        starting with ``"Error"`` describing the failure. This function does
        not raise.
    """
    try:
        print(f"π₯ Scraping: {extract_domain(url)}")
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'en-US,en;q=0.5',
        }

        # The context manager closes the session (and its sockets) on every
        # path, which avoids file descriptor issues on HF Spaces.
        with requests.Session() as session:
            response = session.get(url, timeout=15, headers=headers)
            response.raise_for_status()
            # Raw bytes: BeautifulSoup works out the encoding itself.
            html = response.content

        soup = BeautifulSoup(html, 'html.parser')

        # Remove unwanted elements
        for element in soup(["script", "style", "meta", "link", "nav", "footer", "header", "noscript", "iframe"]):
            element.decompose()

        # Extract text with newline separator to preserve structure
        text = soup.get_text(separator='\n', strip=True)

        # Clean up whitespace: trim each line, break a line only on runs of
        # two spaces, and drop empty pieces. Splitting on a single space
        # would put every word on its own line.
        lines = (line.strip() for line in text.splitlines())
        phrases = (phrase.strip() for line in lines for phrase in line.split("  "))
        text = '\n'.join(phrase for phrase in phrases if phrase)

        # Limit content
        text = text[:8000]

        if len(text) > 200:
            print(f" β Extracted {len(text)} characters")
            return text
        return "Insufficient content"
    except requests.Timeout:
        print(f" [ERROR] Timeout after 15s")
        return "Error: Request timeout"
    except requests.RequestException as e:
        print(f" [ERROR] Request error: {str(e)}")
        return f"Error: {str(e)}"
    except Exception as e:
        print(f" [ERROR] Unexpected error: {str(e)}")
        return f"Error scraping: {str(e)}"
def curator_filter_sites(urls_json: str) -> str:
    """Curator Agent - normalise and deduplicate URLs (no domain filtering).

    Args:
        urls_json: a JSON string, or an already-decoded value. Accepted
            shapes are a list of URL strings and/or dicts with a "url" key,
            a single such dict, or any other scalar (used as one URL).

    Returns:
        Indented JSON with ``total_input``, ``curated_count`` and ``urls``
        (each entry has ``url`` and ``domain``), or a string starting with
        "Error in curator filter:" on failure.
    """
    try:
        decoded = json.loads(urls_json) if isinstance(urls_json, str) else urls_json

        # Coerce whatever was passed into a list of candidates
        if isinstance(decoded, list):
            urls = decoded
        elif isinstance(decoded, dict):
            urls = [decoded.get("url", "")] if decoded.get("url") else []
        else:
            urls = [str(decoded)]

        print(f"π― Curator Agent: Processing {len(urls)} URLs...")

        curated_urls = []
        seen_urls = set()
        for candidate in urls:
            url = candidate.get("url", "") if isinstance(candidate, dict) else candidate
            if url and url not in seen_urls:
                seen_urls.add(url)
                curated_urls.append({"url": url, "domain": extract_domain(url)})
                print(f" β Added: {extract_domain(url)}")

        print(f" π Result: {len(curated_urls)} unique URLs")
        return json.dumps(
            {
                "total_input": len(urls),
                "curated_count": len(curated_urls),
                "urls": curated_urls,
            },
            indent=2,
        )
    except Exception as e:
        return f"Error in curator filter: {str(e)}"
def chunk_content(content: str, url: str) -> str:
    """Split scraped text into one parent chunk and up to five child chunks.

    The parent holds the first 4000 characters of *content*. Children are
    500-character pieces (50 overlap) of the full text; only the first five
    are kept, each pointing at the parent through ``parent_id``.

    Returns:
        Indented JSON with ``parent``, ``children`` and ``count``, or a
        string starting with "Error chunking:" on failure.
    """
    try:
        parent_id = generate_uuid()
        parent_chunk = dict(
            id=parent_id,
            document=content[:4000],
            source=url,
            chunk_type="parent",
            created_at=datetime.now().isoformat(),
        )

        # Create child chunks
        pieces = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_text(content)
        child_chunks = [
            {
                "id": generate_uuid(),
                "document": piece,
                "source": url,
                "chunk_type": "child",
                "parent_id": parent_id,
                "sequence": sequence,
                "created_at": datetime.now().isoformat(),
            }
            for sequence, piece in enumerate(pieces[:5])  # Limit to 5 children
        ]

        payload = {"parent": parent_chunk, "children": child_chunks, "count": len(child_chunks)}
        return json.dumps(payload, indent=2)
    except Exception as e:
        return f"Error chunking: {str(e)}"
def generate_embeddings_and_index(chunks_json: str) -> str:
    """Embed the chunks described by *chunks_json* and add them to FAISS.

    Args:
        chunks_json: JSON object with an optional ``parent`` chunk and a
            ``children`` list; entries without a non-empty ``document`` are
            ignored.

    Returns:
        JSON with ``success``, ``chunks_indexed`` and ``total_index_size`` on
        success, or ``success: false`` and ``error`` on failure. This
        function does not raise.

    A batch whose encoding fails is skipped rather than indexed: storing
    placeholder vectors would make unrelated chunks come back as matches.
    """
    try:
        chunks_data = json.loads(chunks_json)
        parent = chunks_data.get("parent", {})
        children = chunks_data.get("children", [])

        # Only include valid chunks (not empty)
        all_chunks = []
        if parent and parent.get("document"):
            all_chunks.append(parent)
        all_chunks.extend(c for c in children if c.get("document"))

        if not all_chunks:
            print("β οΈ No valid chunks to index")
            return json.dumps({"success": True, "chunks_indexed": 0, "total_index_size": 0})

        print(f"β³ Generating embeddings for {len(all_chunks)} chunks...")

        # Small batches keep memory and latency low on CPU
        batch_size = 2
        total_batches = (len(all_chunks) - 1) // batch_size + 1
        encoded_chunks = []
        vectors = []
        for start in range(0, len(all_chunks), batch_size):
            batch = all_chunks[start:start + batch_size]
            print(f" Processing batch {start//batch_size + 1}/{total_batches}...")
            try:
                # Encode without progress bar to avoid hang
                batch_vectors = embedding_model.encode(
                    [chunk.get("document", "") for chunk in batch],
                    convert_to_numpy=True,
                    show_progress_bar=False,
                )
            except Exception as be:
                print(f" [WARN] Batch error: {str(be)}, skipping {len(batch)} chunk(s)")
                continue
            # Extend both lists together so chunk i always matches vector i.
            vectors.extend(batch_vectors)
            encoded_chunks.extend(batch)
            print(f" β Encoded {len(batch)} texts")

        if not encoded_chunks:
            return json.dumps({"success": False, "error": "Embedding failed for every batch"})

        # Add to FAISS
        embeddings = np.array(vectors, dtype=np.float32)
        chunk_ids = [chunk.get("id", "") for chunk in encoded_chunks]
        print(f" Adding {len(embeddings)} embeddings to FAISS...")
        faiss_add_chunks(embeddings, chunk_ids, encoded_chunks)

        print(f"β Successfully indexed {len(encoded_chunks)} chunks")
        return json.dumps({
            "success": True,
            "chunks_indexed": len(encoded_chunks),
            "total_index_size": len(faiss_chunk_id_to_index),
        })
    except Exception as e:
        print(f"β Indexing error: {str(e)}")
        return json.dumps({"success": False, "error": str(e)})
def retrieve_similar_chunks(query_text: str, k: int = 5) -> str:
    """Embed *query_text* and return the *k* most similar indexed chunks.

    Returns:
        Indented JSON in the ``faiss_search`` format (``ids``, ``documents``,
        ``metadatas``, ``distances``). The same structure with empty lists is
        returned when the query cannot be embedded or the search fails, so
        callers never receive arbitrary chunks presented as matches. This
        function does not raise.
    """
    no_results = {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
    try:
        print(f"π Retrieving similar chunks for: '{query_text[:50]}...'")

        try:
            query_embedding = embedding_model.encode(query_text, convert_to_numpy=True, show_progress_bar=False)
            print(f" β Query encoded")
        except Exception as e:
            # A substitute vector would rank the index at random, so report
            # "nothing found" instead.
            print(f" [WARN] Encoding error: {str(e)[:50]}, returning no results")
            return json.dumps(no_results)

        results = faiss_search(query_embedding, k=k)

        if not results["ids"][0]:
            print(f" β οΈ No chunks found in FAISS (index size: {faiss_index.ntotal if faiss_index else 0})")
        else:
            print(f" β Found {len(results['ids'][0])} similar chunks")

        return json.dumps(results, indent=2)
    except Exception as e:
        print(f" β Retrieval error: {str(e)}")
        return json.dumps(no_results)
def generate_answer(query: str, context_chunks: str, chat_history: str = "") -> str:
    """Ask the large model to answer *query* from the retrieved context.

    Up to three attempts are made, sleeping 5s and then 10s between them.

    Args:
        query: the user's question.
        context_chunks: retrieved text, already formatted for the prompt.
        chat_history: earlier conversation, already formatted (may be empty).

    Returns:
        The model's answer text, or a string starting with
        "Error generating answer:" once every attempt has failed.
    """
    import time

    max_attempts = 3
    for attempt_no in range(1, max_attempts + 1):
        try:
            print(f"π‘ Generating answer for: '{query[:50]}...' (Attempt {attempt_no}/{max_attempts})")

            system_msg = SystemMessage(content="""You are an e-commerce assistant.
Answer the user's query based on the provided context AND previous conversation.
Be concise, accurate, and cite sources when possible. Remember that default region is INDIA and current year in 2026.
If answering a follow-up question, use insights from previous messages to provide better recommendations.""")

            # Build user message with chat history context
            user_content = f"""
{chat_history}
Current Query: {query}
Context chunks:
{context_chunks}
Please provide a helpful answer with confidence level (0-100%)."""
            reply = llm_large.invoke([system_msg, HumanMessage(content=user_content)])
            print(f"β Answer generated successfully")
            return reply.content
        except Exception as e:
            error_msg = str(e)
            print(f"β οΈ Attempt {attempt_no} failed: {error_msg[:80]}")
            if attempt_no >= max_attempts:
                print(f"β Failed after {max_attempts} attempts")
                return f"Error generating answer: {error_msg[:100]}"
            wait_time = 5 * attempt_no  # linear back-off between attempts
            print(f" β³ Retrying in {wait_time}s...")
            time.sleep(wait_time)
| # --- 5. AGENT FUNCTIONS WITH TOOLS --- | |
def run_decision_agent(query: str, chat_history: List = None, progress_callback=None) -> Tuple[bool, str]:
    """Decision Agent - decide whether *query* needs a fresh web search.

    With at most one message of history the answer is always "search".
    Otherwise the small model is shown the last six messages (200 characters
    each) and asked to reply SEARCH or CONTEXT. Any failure defaults to
    searching.

    Args:
        query: the current user question.
        chat_history: earlier messages as dicts with "role" and "content".
        progress_callback: accepted for signature parity; not used here.

    Returns:
        (needs_search, reasoning)
    """
    print(f"\n{'='*70}\nπ€ DECISION AGENT\n{'='*70}")

    history = [] if chat_history is None else chat_history

    # Nothing earlier to answer from
    if not history or len(history) <= 1:
        print("β First question - will perform web search")
        return True, "First query - need fresh data"

    # Build conversation context from the last 3 exchanges
    transcript = "".join(
        f"{msg.get('role', '').upper()}: {msg.get('content', '')[:200]}\n"
        for msg in history[-6:]
    )
    context = "Previous Conversation:\n" + transcript

    # Decision prompt
    system_msg = SystemMessage(content="""You are a decision agent for e-commerce queries.
Analyze if the current query can be answered from PREVIOUS conversation context or if it needs NEW web search.
Rules:
1. If query asks about SPECIFIC products/prices/comparisons from before β NO search needed, use previous context
2. If query asks for DIFFERENT products or NEW information β YES search needed
3. If query is a follow-up question about something already discussed β NO search needed
4. If unclear or asking for LATEST data β YES search needed
Respond with ONLY: "SEARCH" or "CONTEXT"
""")
    user_msg = HumanMessage(content=f"""{context}
Current query: {query}
Decision: (respond with ONLY "SEARCH" or "CONTEXT")""")

    try:
        verdict = llm_small.invoke([system_msg, user_msg]).content.strip().upper()
        needs_search = "SEARCH" in verdict
        if needs_search:
            reasoning = "Decision: Will search web"
        else:
            reasoning = "Decision: Will use previous context"
        print(f"β {reasoning}")
        return needs_search, reasoning
    except Exception as e:
        print(f"β οΈ Decision error: {str(e)[:50]}, defaulting to search")
        return True, f"Error in decision (default to search): {str(e)[:30]}"
def run_search_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Search Agent - expand the query, search Tavily, deduplicate the hits.

    Steps:
      1. Ask the small model for query variations (falling back to fixed
         "price" / "buy online" variants if that fails).
      2. Run the first three queries through Tavily.
      3. Keep the first 12 results with distinct URLs.

    On success ``state["search_results"]`` is set and the stage moves to
    CHUNKING; on an unexpected failure ``error_message`` is set and the stage
    moves to ERROR. The (mutated) state is returned either way.

    Args:
        state: pipeline state; ``query`` and ``progress_log`` are read.
        progress_callback: optional callable given a status string.
    """
    try:
        print(f"\n{'='*70}\nπ SEARCH AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("π Searching for products...")

        query = state["query"]
        fallback_variations = [f"{query} price", f"{query} buy online"]

        # Step 1: Generate search queries (with timeout fallback)
        system_msg = SystemMessage(content="""Generate 3-4 search variations for an e-commerce query.
Focus on product names, prices, availability. Remember that the default region is INDIA and current year is 2026.
Return ONLY a JSON array of strings.""")
        user_msg = HumanMessage(content=f"Query: {query}\n\nReturn search variations as JSON array:")

        try:
            response = llm_small.invoke([system_msg, user_msg])
        except Exception as e:
            print(f"β οΈ LLM timeout/error, using default queries: {str(e)[:50]}")
            search_queries = [query] + fallback_variations
        else:
            search_queries = [query]
            try:
                json_match = re.search(r'\[.*\]', response.content, re.DOTALL)
                if json_match:
                    variations = json.loads(json_match.group())
                    # The model's output is untrusted: keep non-empty strings only.
                    search_queries.extend(
                        v for v in variations if isinstance(v, str) and v.strip()
                    )
            except (ValueError, TypeError, AttributeError):
                # Unparseable or oddly shaped reply: use the fixed variations.
                search_queries.extend(fallback_variations)

        search_queries = search_queries[:5]
        state["progress_log"].append(f"β Generated {len(search_queries)} search queries")
        print(f"Search queries: {search_queries}")

        # Step 2: Broad Tavily search (will filter with Curator Agent after)
        print("π₯ Searching broadly with Tavily...")
        tavily_search = TavilySearchResults(
            max_results=15,  # More results for better diversity
            api_key=TAVILY_API_KEY,
            search_depth="advanced"
        )

        all_results = []
        for sq in search_queries[:3]:  # Limit to 3 to avoid too many API calls
            try:
                print(f" π Query: '{sq[:40]}...'")
                results = tavily_search.invoke(sq)
                if not isinstance(results, list):
                    # e.g. an error string; extending with it would add it
                    # character by character.
                    raise TypeError(f"unexpected search response: {str(results)[:50]}")
                all_results.extend(results)
                print(f" β Found {len(results)} results")
            except Exception as e:
                print(f" β οΈ Error: {str(e)[:50]}")
                state["progress_log"].append(f"β οΈ Search error: {str(e)[:50]}")

        # Step 3: Deduplicate results (no domain filtering)
        print(f"π― Processing {len(all_results)} results...")
        seen_urls = set()
        seen_domains = {}
        unique_results = []
        for result in all_results:
            # Ensure result is a dictionary
            if isinstance(result, str):
                try:
                    result = json.loads(result)
                except ValueError:
                    continue
            elif not isinstance(result, dict):
                result = {"url": str(result)}
            if not isinstance(result, dict):
                # A JSON string that decoded to a list/number/etc.
                continue

            url = result.get("url", "")
            if url and url not in seen_urls:
                seen_urls.add(url)
                domain = extract_domain(url)
                unique_results.append(result)
                seen_domains[domain] = seen_domains.get(domain, 0) + 1
                print(f" β Added: {domain}")

        print(f"\nπ Results: {len(unique_results)} URLs from {len(seen_domains)} domains")
        print(f" Domains found: {list(seen_domains.keys())}")

        unique_results = unique_results[:12]  # Keep top 12 for diversity
        state["search_results"] = unique_results
        state["stage"] = WorkflowStage.CHUNKING.value
        state["progress_log"].append(f"β Search complete: {len(unique_results)} URLs found")
        print(f"β Search complete: {len(unique_results)} URLs")
        return state
    except Exception as e:
        print(f"β Search agent error: {str(e)}")
        state["error_message"] = f"Search agent failed: {str(e)[:100]}"
        state["stage"] = WorkflowStage.ERROR.value
        return state
def run_chunking_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Chunking Agent - scrape each search result and split it into chunks.

    Every URL in ``state["search_results"]`` is scraped; pages that yield
    more than 200 characters and no scraper error are kept and passed to
    ``chunk_content``. On success the state's ``cleaned_documents``,
    ``parent_chunks``, ``child_chunks`` and ``parent_child_mapping`` are set
    and the stage moves to INDEXING; on an unexpected failure the stage moves
    to ERROR. Failures for a single URL or document are logged and skipped.

    Args:
        state: pipeline state.
        progress_callback: optional callable given a status string.
    """

    def _call(tool_fn, **kwargs):
        # Works whether tool_fn is a plain function or a LangChain tool
        # object exposing .invoke(dict).
        invoke = getattr(tool_fn, "invoke", None)
        return invoke(kwargs) if callable(invoke) else tool_fn(**kwargs)

    try:
        print(f"\n{'='*70}\nπ CHUNKING AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("π Chunking documents...")

        # Scrape URLs
        cleaned_documents = []
        for i, result in enumerate(state.get("search_results", []), 1):
            url = result.get("url", "")
            if not url:
                continue

            domain = extract_domain(url)
            print(f"π₯ Scraping [{i}]: {domain}")

            try:
                content = _call(scrape_product_content, url=url)
                # The scraper reports failures as strings that START with
                # "Error"; a page merely mentioning the word must be kept.
                if content and not content.startswith("Error") and len(content) > 200:
                    cleaned_documents.append({
                        "url": url,
                        "content": content,
                        "source": domain,
                        "scraped_at": datetime.now().isoformat()
                    })
                    print(f" β Extracted {len(content)} chars")
                    state["progress_log"].append(f"β Scraped: {domain}")
            except Exception as e:
                print(f" β οΈ Error: {str(e)[:50]}")
                state["progress_log"].append(f"β οΈ Scrape error: {domain}")

        # Chunk documents
        parent_chunks = []
        child_chunks = []
        parent_child_mapping = {}
        for doc in cleaned_documents:
            try:
                chunks_result = _call(chunk_content, content=doc["content"], url=doc["url"])
                chunks_data = json.loads(chunks_result)
                parent = chunks_data.get("parent", {})
                children = chunks_data.get("children", [])
                if parent:
                    parent_chunks.append(parent)
                    parent_child_mapping[parent["id"]] = [c["id"] for c in children]
                child_chunks.extend(children)
            except Exception as e:
                print(f" β Chunking error: {str(e)[:50]}")

        state["cleaned_documents"] = cleaned_documents
        state["parent_chunks"] = parent_chunks
        state["child_chunks"] = child_chunks
        state["parent_child_mapping"] = parent_child_mapping
        state["stage"] = WorkflowStage.INDEXING.value
        state["progress_log"].append(f"β Chunking: {len(parent_chunks)} parents, {len(child_chunks)} children")
        print(f"β Chunking complete: {len(parent_chunks)} parents, {len(child_chunks)} children")
        return state
    except Exception as e:
        state["error_message"] = f"Chunking agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state
def run_indexing_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Indexing Agent - embed every parent and child chunk into FAISS.

    The stage moves to RETRIEVAL afterwards (also when there is nothing to
    index), or to ERROR on an unexpected failure.

    Args:
        state: pipeline state; ``parent_chunks`` and ``child_chunks`` are read.
        progress_callback: optional callable given a status string.
    """
    try:
        print(f"\n{'='*70}\nποΈ INDEXING AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("ποΈ Indexing chunks...")

        parents = state.get("parent_chunks", [])
        children = state.get("child_chunks", [])
        if not parents and not children:
            state["progress_log"].append("β οΈ No chunks to index")
            state["stage"] = WorkflowStage.RETRIEVAL.value
            return state

        # generate_embeddings_and_index takes ONE "parent" plus a list, and
        # indexes every entry that has a document. Sending only the first
        # parent would drop all the others, so the remaining parents travel
        # in the list alongside the children.
        chunks_json = json.dumps({
            "parent": parents[0] if parents else {},
            "children": parents[1:] + children
        })

        # Works whether the indexer is a plain function or a LangChain tool.
        invoke = getattr(generate_embeddings_and_index, "invoke", None)
        if callable(invoke):
            result = invoke({"chunks_json": chunks_json})
        else:
            result = generate_embeddings_and_index(chunks_json)
        index_result = json.loads(result)

        if not index_result.get("success", False):
            state["progress_log"].append(
                f"[WARN] Indexing problem: {str(index_result.get('error', 'unknown'))[:100]}"
            )

        state["stage"] = WorkflowStage.RETRIEVAL.value
        state["progress_log"].append(f"β Indexed {index_result.get('chunks_indexed', 0)} chunks")
        print(f"β Indexing complete")
        return state
    except Exception as e:
        state["error_message"] = f"Indexing agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state
def run_retrieval_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Retrieval Agent - fetch the five chunks most similar to the query.

    ``state["retrieved_children"]`` becomes a list of dicts with ``id``,
    ``document``, ``metadata`` and ``distance`` (empty when the index is
    empty). The stage moves to SYNTHESIS, or to ERROR on an unexpected
    failure.

    Args:
        state: pipeline state; ``query`` is read.
        progress_callback: optional callable given a status string.
    """
    try:
        print(f"\n{'='*70}\nπ RETRIEVAL AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("π Retrieving similar content...")

        query = state["query"]

        # Explicit None check: do not rely on the truthiness of the index object.
        if faiss_index is None or faiss_index.ntotal == 0:
            print("β οΈ FAISS index is empty - no chunks to retrieve")
            state["retrieved_children"] = []
            state["progress_log"].append("β οΈ FAISS index empty")
        else:
            # Works whether the retriever is a plain function or a LangChain tool.
            invoke = getattr(retrieve_similar_chunks, "invoke", None)
            if callable(invoke):
                result_json = invoke({"query_text": query, "k": 5})
            else:
                result_json = retrieve_similar_chunks(query, k=5)

            results = result_json if isinstance(result_json, dict) else json.loads(result_json)

            retrieved_children = [
                {
                    "id": doc_id,
                    "document": doc_text,
                    "metadata": metadata,
                    "distance": float(dist),
                }
                for doc_id, doc_text, metadata, dist in zip(
                    results.get("ids", [[]])[0],
                    results.get("documents", [[]])[0],
                    results.get("metadatas", [[]])[0],
                    results.get("distances", [[]])[0],
                )
            ]

            state["retrieved_children"] = retrieved_children
            state["progress_log"].append(f"β Retrieved {len(retrieved_children)} chunks")
            print(f"β Retrieval complete: {len(retrieved_children)} chunks")

        state["stage"] = WorkflowStage.SYNTHESIS.value
        return state
    except Exception as e:
        print(f"β Retrieval error: {str(e)}")
        state["error_message"] = f"Retrieval agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state
def _chunk_source(chunk: Dict[str, Any], default: str) -> str:
    """Return the chunk's metadata ``source``, tolerating missing or None metadata."""
    metadata = chunk.get("metadata") or {}
    return metadata.get("source", default)


def _format_retrieved_context(retrieved: List[Dict[str, Any]], limit: int = 5) -> str:
    """Render up to *limit* retrieved chunks as a numbered, source-labelled context block."""
    if not retrieved:
        return "[No relevant information found in knowledge base]"
    entries = []
    for rank, chunk in enumerate(retrieved[:limit], start=1):
        source = _chunk_source(chunk, "Unknown")
        entries.append(f"[{rank}] {source}\n{chunk.get('document', '')}")
    return "\n\n".join(entries)


def _format_chat_history(chat_messages: List[Dict[str, Any]], max_chars: int = 300) -> str:
    """Render every message except the last (the current query) as ``ROLE: text`` lines."""
    if not chat_messages or len(chat_messages) <= 1:
        return ""
    lines = []
    for message in chat_messages[:-1]:
        role = str(message.get("role") or "").upper()
        # str() guards against non-text content, which cannot be sliced.
        content = str(message.get("content") or "")
        lines.append(f"{role}: {content[:max_chars]}\n\n")
    return "Previous Conversation:\n" + "".join(lines)


def _generate_with_retry(query: str, context: str, chat_context: str,
                         max_retries: int = 3, base_delay: int = 5) -> str:
    """Call ``generate_answer`` up to *max_retries* times with linear backoff.

    An attempt fails when the call raises, returns something empty, or returns
    text starting with ``"Error"``. Always returns a string: the first
    successful answer, otherwise the last error text.
    """
    import time  # imported once here instead of inside the retry loop

    answer = None
    for attempt in range(max_retries):
        print(f" β³ Generating (attempt {attempt + 1}/{max_retries})...")
        try:
            answer = generate_answer(query, context, chat_context)
            if answer and not answer.startswith("Error"):
                print(f" β Answer generated successfully")
                return answer
            failure = answer[:80] if answer else "Empty response from generate_answer"
        except Exception as gen_err:
            failure = f"Generation error: {str(gen_err)[:80]}"
            answer = f"Error generating answer: {str(gen_err)[:100]}"

        print(f" β οΈ {failure}")
        if attempt < max_retries - 1:
            # Linear backoff: 5s, 10s, ... Also applied to empty answers so a
            # struggling backend is not re-queried immediately.
            wait = base_delay * (attempt + 1)
            print(f" β³ Waiting {wait}s before retry...")
            time.sleep(wait)
        else:
            print(f" β Max retries reached")

    return answer or f"Error generating answer: Unable to generate response after {max_retries} attempts"


def run_synthesis_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Synthesis stage: generate the final answer from retrieved chunks and chat history.

    Reads ``query``, ``retrieved_children`` and ``chat_messages`` from the
    state (missing or None values are treated as empty). It builds a context
    from the first five chunks, calls ``generate_answer`` with retries, and
    stores ``final_answer``, ``citations`` (first three chunks) and
    ``confidence`` (``min(0.9, hits / 10)``, or 0.1 with no hits).

    The stage is set to COMPLETE even when the stored answer is an
    ``"Error ..."`` string from a failed generation. ERROR is used only when
    this function itself raises.

    Args:
        state: Workflow state; mutated in place and returned.
        progress_callback: Optional callable that receives a status message.

    Returns:
        The same ``state`` object; exceptions are captured in
        ``state["error_message"]`` rather than raised.
    """
    try:
        print(f"\n{'='*70}\nπ‘ SYNTHESIS AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("π‘ Generating answer...")

        query = state["query"]
        # "or []" also covers keys that are present but set to None.
        retrieved = state.get("retrieved_children") or []
        chat_messages = state.get("chat_messages") or []

        context = _format_retrieved_context(retrieved)
        confidence = min(0.9, len(retrieved) / 10.0) if retrieved else 0.1
        chat_context = _format_chat_history(chat_messages)

        answer = _generate_with_retry(query, context, chat_context)

        # The chunk source doubles as both the display name and the URL.
        citations = [
            {
                "source": _chunk_source(chunk, "Unknown"),
                "section": "Content",
                "url": _chunk_source(chunk, ""),
            }
            for chunk in retrieved[:3]
        ]

        state["final_answer"] = answer
        state["citations"] = citations
        state["confidence"] = confidence
        state["stage"] = WorkflowStage.COMPLETE.value
        state["progress_log"].append(f"β Synthesis complete: Confidence={confidence:.2f}")
        print(f"β Synthesis complete")
        return state
    except Exception as e:
        print(f"β Synthesis error: {str(e)}")
        state["error_message"] = f"Synthesis agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state
| # --- 6. WORKFLOW ORCHESTRATION --- | |
def run_complete_workflow(query: str, chat_history: Optional[List] = None, progress=gr.Progress()) -> EcommerceAgentState:
    """Run the complete tool-based agentic pipeline with intelligent context.

    Flow: the decision agent chooses whether fresh data is needed. If so, the
    search, chunking and indexing agents run in that order; otherwise the
    existing index is reused. Retrieval and synthesis always run. The first
    stage that leaves ``state["stage"]`` at ERROR ends the run early, and the
    state is returned with ``error_message`` set.

    Args:
        query: The user's question.
        chat_history: Prior chat messages; stored (not copied) in
            ``state["chat_messages"]``. ``None`` means no history.
        progress: Gradio progress tracker, called as
            ``progress(fraction, desc=...)``.

    Returns:
        The final workflow state. The query/answer pair is appended to the
        conversation history only when the run did not end in ERROR.
    """
    if chat_history is None:
        chat_history = []

    state = create_initial_state(query)
    state["chat_messages"] = chat_history  # Add chat history to state

    error_stage = WorkflowStage.ERROR.value

    # Fraction last reported by the orchestrator. Agent callbacks reuse it so
    # their status text updates the label without moving the bar backwards.
    current_fraction = 0.2

    def report(fraction: float, desc: str) -> None:
        nonlocal current_fraction
        current_fraction = fraction
        progress(fraction, desc=desc)

    def update_progress(msg: str):
        progress(current_fraction, desc=msg)

    # Step 1: Decision Agent - Determine if search is needed
    needs_search, decision_reason = run_decision_agent(query, chat_history, update_progress)
    state["progress_log"].append(f"π€ {decision_reason}")

    # Step 2: Conditionally run search agents
    if needs_search:
        ingest_stages = (
            (0.25, "π Searching...", run_search_agent),
            (0.35, "π Chunking...", run_chunking_agent),
            (0.5, "ποΈ Indexing...", run_indexing_agent),
        )
        for fraction, desc, agent in ingest_stages:
            report(fraction, desc)
            state = agent(state, update_progress)
            if state["stage"] == error_stage:
                return state
    else:
        # Skip search/chunking/indexing, go directly to retrieval
        print("π Reusing indexed data from previous query...")
        state["progress_log"].append("π Using existing indexed knowledge")
        state["stage"] = WorkflowStage.RETRIEVAL.value

    # Step 3: Retrieval and Synthesis (always run)
    report(0.65, "π Retrieving...")
    state = run_retrieval_agent(state, update_progress)
    if state["stage"] == error_stage:
        return state

    report(0.85, "π‘ Generating answer...")
    state = run_synthesis_agent(state, update_progress)
    if state["stage"] == error_stage:
        # Same handling as every other stage: do not report completion or
        # record an answer that was never produced.
        return state

    report(1.0, "β Complete")

    # Add to conversation history
    add_to_conversation_history(query, state["final_answer"])
    return state
# --- 7. GRADIO INTERFACE ---
def chat_interface(message: str, chat_history: List) -> Tuple[str, List]:
    """Gradio handler: answer one user message and extend the chat transcript.

    Empty or whitespace-only input is ignored. Otherwise the user message is
    appended to ``chat_history`` (role/content dict format), the full workflow
    runs, and a formatted assistant reply is appended. The reply contains the
    answer, confidence, up to three sources, the process log and any error.

    Args:
        message: Text submitted by the user.
        chat_history: Existing transcript; mutated in place.

    Returns:
        ``("", chat_history)`` - an empty string for the textbox output and
        the updated transcript for the chatbot output.
    """
    if not message or not message.strip():
        return "", chat_history

    # Record the user's turn before running the workflow so the pipeline
    # sees it as the latest message.
    chat_history.append({"role": "user", "content": message})

    state = run_complete_workflow(message, chat_history=chat_history)

    # Collect the reply piece by piece and join once at the end.
    parts = [
        f"π€ **Answer**\n"
        f"{state['final_answer']}\n"
        f"**Confidence:** {state['confidence']*100:.0f}%\n"
        "---\n"
        "π **Sources:**\n"
    ]
    parts.extend(f"- {cite['source']}\n" for cite in state["citations"][:3])

    parts.append("\n---\n\nπ **Process Log:**\n")
    parts.extend(f"β {log}\n" for log in state["progress_log"])

    if state["error_message"]:
        parts.append(f"\nβ οΈ **Error:** {state['error_message']}")

    assistant_response = "".join(parts)

    # Append the assistant's turn in the same role/content format.
    chat_history.append({"role": "assistant", "content": assistant_response})
    return "", chat_history
# Assemble the Gradio UI: header text, chat transcript, and an input row.
with gr.Blocks(title="ποΈ E-commerce Chat RAG") as demo:
    gr.Markdown(
        "\n"
        "# ποΈ E-commerce Chat RAG System\n"
        "Ask any question about e-commerce products and prices!\n"
        "*Agents:* π Search β π Chunk β ποΈ Index β π Retrieve β π‘ Synthesize\n"
    )

    chatbot = gr.Chatbot(label="π¬ Chat", height=600)

    with gr.Row():
        msg = gr.Textbox(
            label="Your Question",
            placeholder="e.g., What is the current price of iPhone 14?",
            lines=2,
            scale=4,
        )
        submit_btn = gr.Button("Send π€", variant="primary", scale=1)

    # The Send button and the Enter key both route to chat_interface with
    # identical wiring; fresh input/output lists are built per registration.
    for register in (submit_btn.click, msg.submit):
        register(
            fn=chat_interface,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=True,
        )
if __name__ == "__main__":
    try:
        demo.launch(share=False, show_error=True)
    finally:
        # Best-effort teardown of this thread's asyncio event loop once
        # launch() returns or raises. Any failure here is deliberately
        # ignored so shutdown never ends with a traceback.
        try:
            event_loop = asyncio.get_event_loop()
            if event_loop.is_running():
                event_loop.stop()
            event_loop.close()
        except Exception:
            pass