arka7's picture
Update app.py
5e91f60 verified
import os
import json
import requests
import asyncio
import uuid
import re
import warnings
from typing import TypedDict, Annotated, List, Dict, Any, Optional, Tuple
from datetime import datetime
from pathlib import Path
from enum import Enum
import tempfile
# Suppress torch FutureWarning on HF Spaces
warnings.filterwarnings('ignore', category=FutureWarning, module='.*torch.*')
# Third-party imports
import gradio as gr
from dotenv import load_dotenv
from bs4 import BeautifulSoup
import faiss
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
import tiktoken
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_text_splitters import RecursiveCharacterTextSplitter
# --- 1. SETUP & CONFIGURATION ---
# Load variables from a local .env file into the process environment; this
# must run before the API-key rotators below read os.environ.
load_dotenv()
# --- API KEY ROTATION ---
class APIKeyRotator:
    """Round-robin pool of API keys for a single external service.

    Keys are read from ``<SERVICE>_API_KEY``, ``<SERVICE>_API_KEY_2``,
    ``<SERVICE>_API_KEY_3`` ... and the scan stops at the first missing or
    empty variable. The pool is shuffled once when it is loaded.
    """

    def __init__(self, service_name: str):
        self.service_name = service_name
        self.keys = self._load_keys()
        # Position of the key that get_key() currently hands out.
        self.current_index = 0

    def _load_keys(self) -> List[str]:
        """Collect every configured key for the service, in shuffled order.

        Raises:
            ValueError: when not even ``<SERVICE>_API_KEY`` is set.
        """
        import random

        collected: List[str] = []
        ordinal = 1
        while True:
            # The first key has no numeric suffix; later ones are _2, _3, ...
            env_var = (
                f"{self.service_name}_API_KEY"
                if ordinal == 1
                else f"{self.service_name}_API_KEY_{ordinal}"
            )
            value = os.getenv(env_var)
            if not value:
                break
            collected.append(value)
            ordinal += 1

        if not collected:
            raise ValueError(f"❌ No API keys found for {self.service_name}")
        print(f"βœ… Loaded {len(collected)} API key(s) for {self.service_name}")
        random.shuffle(collected)  # Randomize order
        return collected

    def get_key(self) -> str:
        """Return the key at the current position without advancing."""
        if not self.keys:
            raise ValueError(f"❌ No API keys available for {self.service_name}")
        return self.keys[self.current_index]

    def rotate(self) -> str:
        """Advance to the next key (wrapping around to the first) and return it."""
        if not self.keys:
            raise ValueError(f"❌ No API keys available for {self.service_name}")
        pool_size = len(self.keys)
        self.current_index = (self.current_index + 1) % pool_size
        print(f"πŸ”„ Rotated {self.service_name} API key to index {self.current_index + 1}/{len(self.keys)}")
        return self.get_key()
# Initialize key rotators: one shared key pool per external service.
# Construction raises ValueError at import time when a service has no key.
mistral_rotator = APIKeyRotator("MISTRAL")
tavily_rotator = APIKeyRotator("TAVILY")

# Set current keys in environment, so code that reads os.environ directly sees
# the same key each rotator selected.
os.environ.update(
    {
        "MISTRAL_API_KEY": mistral_rotator.get_key(),
        "TAVILY_API_KEY": tavily_rotator.get_key(),
    }
)
# NOTE(review): these constants are snapshots taken at import time; later
# rotate() calls do not update them.
MISTRAL_API_KEY, TAVILY_API_KEY = os.environ["MISTRAL_API_KEY"], os.environ["TAVILY_API_KEY"]
print(f"[OK] API Key rotation initialized")
# Vector DB Configuration
# NOTE(review): despite the CHROMA_ prefix, the index below is an in-memory
# FAISS index; this directory is created here but not referenced elsewhere in
# the code visible in this file.
CHROMA_DB_PATH = Path(tempfile.gettempdir()).joinpath("ecommerce_rag")
CHROMA_DB_PATH.mkdir(parents=True, exist_ok=True)
print(f"[VECTOR_STORE] Location: {CHROMA_DB_PATH.absolute()}")
print(f"[OK] API Keys configured")

# Global FAISS state (rebound/mutated via `global` in faiss_add_chunks).
faiss_index = None            # created lazily on the first add
faiss_chunks_metadata = {}    # chunk_id -> chunk metadata dict
faiss_chunk_id_to_index = {}  # chunk_id -> row number inside faiss_index
faiss_index_counter = 0       # row number the next added vector will get
# Resilient LLM wrapper with key rotation
def create_llm_with_rotation(model: str, temperature: float = 0.2, timeout: float = 120.0, max_retries: int = 5):
    """Create a chat-model wrapper that fails over across the Mistral API keys.

    The returned object exposes ``invoke(messages, **kwargs)`` and
    ``bind_tools(tools)``. ``invoke`` tries each key of the shared
    ``mistral_rotator`` at most once. On any exception it rotates to the next
    key, rebuilds the client and retries. After the last key has been tried it
    re-raises.

    Several wrappers share one rotator. Each wrapper therefore re-syncs with
    the rotator's current key before a call. Without that, a wrapper would
    keep calling with a key that another wrapper had already rotated away
    from after it failed.

    Args:
        model: Mistral model name, e.g. ``"mistral-small-latest"``.
        temperature: sampling temperature passed to ``ChatMistralAI``.
        timeout: request timeout passed to ``ChatMistralAI``.
        max_retries: client-level retry count passed to ``ChatMistralAI``.
            These retries happen before a key rotation is triggered.
    """

    class ResilientLLM:
        def __init__(self, model, temperature, timeout, max_retries):
            self.model = model
            self.temperature = temperature
            self.timeout = timeout
            self.max_retries = max_retries
            self.llm = self._create_llm()

        def _create_llm(self):
            """Build a client for the rotator's current key and remember that key."""
            self._api_key = mistral_rotator.get_key()
            return ChatMistralAI(
                api_key=self._api_key,
                model=self.model,
                temperature=self.temperature,
                timeout=self.timeout,
                max_retries=self.max_retries,
            )

        def _sync_with_rotator(self):
            """Rebuild the client if the shared rotator has moved to another key."""
            if mistral_rotator.get_key() != self._api_key:
                self.llm = self._create_llm()

        def invoke(self, messages, **kwargs):
            """Invoke with automatic key rotation on failure"""
            max_key_rotations = len(mistral_rotator.keys)
            self._sync_with_rotator()
            for attempt in range(max_key_rotations):
                try:
                    return self.llm.invoke(messages, **kwargs)
                except Exception as e:
                    if attempt >= max_key_rotations - 1:
                        raise
                    print(f"⚠️ Attempt {attempt + 1} failed: {str(e)[:80]}")
                    mistral_rotator.rotate()
                    self.llm = self._create_llm()
                    print(f"πŸ”„ Retrying with rotated key...")

        def bind_tools(self, tools):
            """Bind tools to LLM"""
            # NOTE: the bound runnable is tied to the current key and does not
            # rotate on failure.
            self._sync_with_rotator()
            return self.llm.bind_tools(tools)

    return ResilientLLM(model, temperature, timeout, max_retries)
# Initialize LLMs with resilience and key rotation. In this file the small
# model serves the decision / query-expansion calls and the large model writes
# the final answer.
llm_small = create_llm_with_rotation(
    model="mistral-small-latest",
    max_retries=5,
    timeout=60.0,
    temperature=0.2,
)
llm_large = create_llm_with_rotation(
    model="mistral-large-latest",
    max_retries=5,
    timeout=120.0,
    temperature=0.2,
)

# Initialize embedding model, explicitly moved to CPU.
# NOTE: other parts of this file hard-code 384 as this model's vector size.
print("[LOADING] Embedding model...")
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
embedding_model.to('cpu')
print("[OK] Embedding model initialized (CPU mode)")

# Initialize tokenizer (used by count_tokens).
tokenizer = tiktoken.get_encoding("cl100k_base")
# --- TRUSTED E-COMMERCE WEBSITES (REMOVED - Now accepting all sources) ---
# No domain restrictions - Tavily will find the best sources dynamically
# --- 2. STATE DEFINITIONS ---
class WorkflowStage(Enum):
    """Pipeline stage markers; their string values are stored in ``state["stage"]``.

    Members are declared in pipeline order. Agents record the stage that should
    run next; for example, the search agent sets CHUNKING when it finishes.
    ERROR marks a failed run.
    """

    IDLE = "idle"
    SEARCH = "search"
    CHUNKING = "chunking"
    INDEXING = "indexing"
    RETRIEVAL = "retrieval"
    SYNTHESIS = "synthesis"
    COMPLETE = "complete"
    ERROR = "error"
class EcommerceAgentState(TypedDict):
    """Unified state dict passed between the pipeline agents.

    Built by ``create_initial_state`` and updated in place by the
    ``run_*_agent`` functions. ``stage`` holds a ``WorkflowStage`` value.
    """

    query: str  # the user's current question
    stage: str  # a WorkflowStage value, e.g. "retrieval"
    search_results: List[Dict]  # deduplicated search hits (dicts with a "url" key)
    cleaned_documents: List[Dict]  # scraped pages: url / content / source / scraped_at
    parent_chunks: List[Dict]
    child_chunks: List[Dict]
    retrieved_children: List[Dict]  # FAISS hits: id / document / metadata / distance
    retrieved_parents: List[Dict]  # NOTE(review): not populated anywhere visible in this file
    parent_child_mapping: Dict[str, List[str]]  # parent chunk id -> child chunk ids
    final_answer: str
    citations: List[Dict]
    confidence: float  # 0.0-1.0; shown to the user as a percentage
    error_message: Optional[str]
    progress_log: List[str]
    # Chat messages as {"role": ..., "content": ...} dicts. This is the format
    # chat_interface builds and the agents read via msg.get("role"/"content"),
    # not (role, message) tuples as previously annotated.
    chat_messages: List[Dict[str, Any]]
def create_initial_state(query: str) -> EcommerceAgentState:
    """Return a fresh pipeline state for *query*.

    Every collection field gets its own new empty container, so separate
    states never share mutable data. The stage starts at ``WorkflowStage.IDLE``.
    """
    fresh_state = dict(
        query=query,
        stage=WorkflowStage.IDLE.value,
        search_results=[],
        cleaned_documents=[],
        parent_chunks=[],
        child_chunks=[],
        retrieved_children=[],
        retrieved_parents=[],
        parent_child_mapping={},
        final_answer="",
        citations=[],
        confidence=0.0,
        error_message=None,
        progress_log=[],
        chat_messages=[],
    )
    return fresh_state
# --- CONVERSATION CONTEXT ---
# Module-level conversation memory, shared by every caller in this process.
# Both lists are appended and trimmed together by add_to_conversation_history().
conversation_history, previous_answers = [], []
def add_to_conversation_history(query: str, answer: str):
    """Record one question/answer exchange in the module-level memory.

    Appends ``{"query": ..., "answer": ...}`` to ``conversation_history`` and
    the bare answer to ``previous_answers``. When the history then exceeds
    five entries, the single oldest entry is dropped from both lists.
    """
    global conversation_history, previous_answers
    conversation_history.append(dict(query=query, answer=answer))
    previous_answers.append(answer)
    # Keep only last 5 exchanges to avoid context bloat
    if len(conversation_history) <= 5:
        return
    del conversation_history[0]
    del previous_answers[0]
def get_conversation_context() -> str:
    """Format up to the last three recorded exchanges as a Markdown snippet.

    Each query is cut to 100 characters and each answer to 150, and both are
    always followed by "...". Returns "" when nothing has been recorded.
    """
    if not conversation_history:
        return ""
    entries = [
        f"{position}. Q: {exchange['query'][:100]}...\n A: {exchange['answer'][:150]}...\n"
        for position, exchange in enumerate(conversation_history[-3:], start=1)  # Last 3 exchanges
    ]
    return "\n**Previous conversation:**\n" + "".join(entries)
# --- 3. UTILITY FUNCTIONS ---
def extract_domain(url: str) -> str:
    """Return the host name of *url* without a leading ``www.`` prefix.

    Only the host is returned. Credentials (``user:pass@``) and a ``:port``
    suffix are dropped, and the host is lower-cased (``urlparse().hostname``
    semantics). A URL without a scheme (e.g. ``"amazon.in/x"``) has no network
    location and yields ``""``. Non-string input, or a URL that ``urlparse``
    rejects (such as a malformed IPv6 literal), yields ``"unknown"``.
    """
    from urllib.parse import urlparse

    if not isinstance(url, str):
        return "unknown"
    try:
        host = urlparse(url).hostname or ""
    except ValueError:
        return "unknown"
    # Strip only a *leading* "www."; str.replace() would also corrupt hosts
    # such as "m.www.example.com" or "shop.newwww.com".
    if host.startswith("www."):
        host = host[len("www."):]
    return host
def is_trusted_source(url: str) -> bool:
    """Report whether *url* may be used as a source.

    Domain allow-listing has been removed, so every URL is accepted and the
    argument is never inspected.
    """
    del url  # unused; parameter kept for call-site compatibility
    return True  # Accept all domains from Tavily results
def count_tokens(text: str) -> int:
    """Return the number of cl100k_base tokens in *text*.

    Special-token markers such as ``"<|endoftext|>"`` are counted as ordinary
    text instead of making tiktoken raise. If encoding still fails, fall back
    to a whitespace word count so callers always get an int.
    """
    try:
        # disallowed_special=(): by default tiktoken raises ValueError when the
        # text contains a special-token string, which would push scraped text
        # into the much less accurate word-count fallback.
        return len(tokenizer.encode(text, disallowed_special=()))
    except Exception:
        # Best-effort fallback. Unlike the former bare `except:`, this does not
        # swallow KeyboardInterrupt / SystemExit.
        return len(text.split())
def generate_uuid() -> str:
    """Return a short random id: the first 8 lowercase hex digits of a UUID4.

    NOTE: 8 hex digits is only 32 bits, so collisions become plausible once
    many thousands of ids exist in one process.
    """
    return uuid.uuid4().hex[:8]
def faiss_add_chunks(embeddings: np.ndarray, chunk_ids: List[str], metadatas: List[Dict]) -> None:
    """Append vectors to the global FAISS index and record their chunks.

    Row ``i`` of *embeddings* belongs to ``chunk_ids[i]`` and ``metadatas[i]``.
    Vectors are L2-normalised on a private copy and added to an inner-product
    index, so search scores are cosine similarities. The index is created
    lazily with the dimension of the first batch.

    Args:
        embeddings: ``(n, d)`` array-like of vectors. A single ``(d,)`` vector
            is accepted as ``n == 1``.
        chunk_ids: ``n`` chunk identifiers.
        metadatas: ``n`` chunk dicts, stored verbatim under their ids.

    Raises:
        ValueError: if the three inputs disagree in length, *d* is 0, or *d*
            differs from the existing index dimension. These checks run before
            anything is added. Previously ``zip`` silently truncated a
            length mismatch while FAISS still added every row, which
            permanently desynchronised row numbers from ids.
    """
    global faiss_index, faiss_chunks_metadata, faiss_chunk_id_to_index, faiss_index_counter

    # np.array (not asarray) so the in-place normalize_L2 never touches the
    # caller's array; FAISS also needs C-contiguous float32 input.
    vectors = np.array(embeddings, dtype=np.float32, order="C")
    if vectors.size == 0 and not chunk_ids:
        return  # nothing to add
    if vectors.ndim == 1:
        vectors = vectors.reshape(1, -1)
    if (
        vectors.ndim != 2
        or vectors.shape[1] == 0
        or not (vectors.shape[0] == len(chunk_ids) == len(metadatas))
    ):
        raise ValueError(
            f"faiss_add_chunks: got vectors of shape {vectors.shape}, "
            f"{len(chunk_ids)} ids and {len(metadatas)} metadata entries"
        )
    if faiss_index is not None and faiss_index.d != vectors.shape[1]:
        raise ValueError(
            f"faiss_add_chunks: vector dimension {vectors.shape[1]} does not match "
            f"index dimension {faiss_index.d}"
        )

    faiss.normalize_L2(vectors)
    if faiss_index is None:
        faiss_index = faiss.IndexFlatIP(vectors.shape[1])
    faiss_index.add(vectors)

    for chunk_id, metadata in zip(chunk_ids, metadatas):
        # NOTE: a repeated chunk_id re-points the id at the new row; the old
        # row stays in the index but no longer maps to any id.
        faiss_chunks_metadata[chunk_id] = metadata
        faiss_chunk_id_to_index[chunk_id] = faiss_index_counter
        faiss_index_counter += 1
def faiss_search(query_embedding: np.ndarray, k: int = 5) -> Dict:
    """Return the *k* indexed chunks most similar to *query_embedding*.

    The result mimics a Chroma-style query response: each value is a list
    holding one inner list (one query). ``"distances"`` holds FAISS
    inner-product scores of normalised vectors, i.e. cosine similarity. Higher
    means more similar, despite the key name.

    Args:
        query_embedding: a ``(d,)`` vector or a ``(1, d)`` row.
        k: maximum number of hits; ``k <= 0`` returns an empty result instead
            of tripping FAISS's ``k > 0`` check.
    """
    global faiss_index, faiss_chunks_metadata, faiss_chunk_id_to_index

    empty_result = {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
    if k <= 0 or faiss_index is None or faiss_index.ntotal == 0:
        return empty_result

    # reshape(1, -1) accepts both (d,) and (1, d) input; the old
    # np.array([query_embedding]) turned a (1, d) row into a 3-D array.
    query = np.array(query_embedding, dtype=np.float32, order="C").reshape(1, -1)
    faiss.normalize_L2(query)
    scores, rows = faiss_index.search(query, min(k, faiss_index.ntotal))

    # Invert id -> row once (O(n)) instead of scanning the whole mapping for
    # every hit (O(k * n)).
    row_to_chunk_id = {row: chunk_id for chunk_id, row in faiss_chunk_id_to_index.items()}

    result_ids, result_docs, result_metadatas, result_distances = [], [], [], []
    for score, row in zip(scores[0], rows[0]):
        chunk_id = row_to_chunk_id.get(int(row))
        if chunk_id is None:
            # -1 ("no result") or a row whose id was later re-pointed.
            continue
        metadata = faiss_chunks_metadata[chunk_id]
        result_ids.append(chunk_id)
        result_docs.append(metadata.get("document", ""))
        result_metadatas.append({key: value for key, value in metadata.items() if key != "document"})
        # Convert numpy float32 to Python float for JSON serialization
        result_distances.append(float(score))

    return {
        "ids": [result_ids],
        "documents": [result_docs],
        "metadatas": [result_metadatas],
        "distances": [result_distances],
    }
# --- 4. TOOL DEFINITIONS ---
# NOTE: for @tool functions the docstring becomes the tool description, so the
# docstrings below are kept verbatim and explanations live in comments.
@tool
def search_ecommerce_products(query: str) -> str:
    """Search for e-commerce products using Tavily Search - searches broadly, curator filters later"""
    # Returns the raw Tavily hits as indented JSON, or an "Error in search: ..."
    # string; exceptions never escape the tool.
    try:
        print(f"πŸ” Searching Tavily for: '{query}' (will curate results after)")
        tavily = TavilySearchResults(
            search_depth="advanced",
            api_key=TAVILY_API_KEY,
            max_results=15,  # Increased to get more diverse results
        )
        hits = tavily.invoke(query)
        print(f" βœ… Found {len(hits)} raw results (will filter to trusted sources)")
        serialized = json.dumps(hits, indent=2)
    except Exception as e:
        return f"Error in search: {str(e)}"
    return serialized
@tool
def scrape_product_content(url: str) -> str:
    """Scrape and clean content from a URL using requests + BeautifulSoup"""
    # Contract (the docstring above is the tool description and is kept verbatim):
    # - returns up to 8000 characters of visible page text, one phrase per line;
    # - returns "Insufficient content" when 200 characters or fewer remain;
    # - returns an "Error..."-prefixed message on any failure; never raises.
    try:
        print(f"πŸ“₯ Scraping: {extract_domain(url)}")
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'en-US,en;q=0.5',
        }
        # The session is a context manager, so it is always closed, even on
        # unexpected errors (avoids file-descriptor leaks on HF Spaces).
        with requests.Session() as session:
            try:
                response = session.get(url, timeout=15, headers=headers)
                response.raise_for_status()
            except requests.Timeout:
                print(f" [ERROR] Timeout after 15s")
                return "Error: Request timeout"
            except requests.RequestException as e:
                print(f" [ERROR] Request error: {str(e)}")
                return f"Error: {str(e)}"

            # Parse the raw bytes so BeautifulSoup detects the page's own
            # charset. The former response.encoding override only affected
            # response.text, which is never read.
            soup = BeautifulSoup(response.content, 'html.parser')

        # Remove unwanted elements
        for element in soup(["script", "style", "meta", "link", "nav", "footer", "header", "noscript", "iframe"]):
            element.decompose()

        # Extract text with newline separator to preserve structure
        text = soup.get_text(separator='\n', strip=True)

        # Clean up whitespace: break lines only on runs of two spaces (layout
        # gaps). Splitting on every single space put each *word* on its own line.
        lines = (line.strip() for line in text.splitlines())
        phrases = (phrase.strip() for line in lines for phrase in line.split("  "))
        text = '\n'.join(phrase for phrase in phrases if phrase)

        # Limit content
        text = text[:8000]
        if len(text) > 200:
            print(f" βœ… Extracted {len(text)} characters")
            return text
        return "Insufficient content"
    except Exception as e:
        print(f" [ERROR] Unexpected error: {str(e)}")
        return f"Error scraping: {str(e)}"
@tool
def curator_filter_sites(urls_json: str) -> str:
    """Curator Agent - Process and deduplicate URLs (no domain filtering)
Returns list of unique URLs from search results
"""
    # Accepts a JSON string (or an already-decoded value) holding a list of
    # URLs / {"url": ...} dicts, a single {"url": ...} dict, or any scalar.
    # Returns JSON {"total_input", "curated_count", "urls": [{"url", "domain"}]}
    # keeping the first occurrence of each URL, or an error string.
    try:
        decoded = json.loads(urls_json) if isinstance(urls_json, str) else urls_json
        if isinstance(decoded, dict):
            only_url = decoded.get("url")
            urls = [only_url] if only_url else []
        elif isinstance(decoded, list):
            urls = decoded
        else:
            urls = [str(decoded)]

        print(f"🎯 Curator Agent: Processing {len(urls)} URLs...")
        already_seen = set()
        curated_urls = []
        for entry in urls:
            candidate = entry.get("url", "") if isinstance(entry, dict) else entry
            if not candidate or candidate in already_seen:
                continue
            already_seen.add(candidate)
            domain = extract_domain(candidate)
            curated_urls.append({"url": candidate, "domain": domain})
            print(f" βœ… Added: {domain}")

        summary = {
            "total_input": len(urls),
            "curated_count": len(curated_urls),
            "urls": curated_urls,
        }
        print(f" πŸ“Š Result: {len(curated_urls)} unique URLs")
        return json.dumps(summary, indent=2)
    except Exception as e:
        return f"Error in curator filter: {str(e)}"
@tool
def chunk_content(content: str, url: str) -> str:
    """Split content into parent and child chunks"""
    # Returns JSON {"parent": {...}, "children": [...], "count": n}:
    # - parent: the first 4000 characters of *content* under a fresh id;
    # - children: at most 5 pieces (~500 chars, 50 overlap) produced by
    #   RecursiveCharacterTextSplitter, each linked back via "parent_id".
    # NOTE(review): children are cut from the *full* content, while the parent
    # keeps only its first 4000 characters.
    # Any exception is reported as an "Error chunking: ..." string.
    try:
        parent_id = generate_uuid()
        parent_chunk = dict(
            id=parent_id,
            document=content[:4000],
            source=url,
            chunk_type="parent",
            created_at=datetime.now().isoformat(),
        )
        pieces = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_text(content)
        child_chunks = [
            {
                "id": generate_uuid(),
                "document": piece,
                "source": url,
                "chunk_type": "child",
                "parent_id": parent_id,
                "sequence": position,
                "created_at": datetime.now().isoformat(),
            }
            for position, piece in enumerate(pieces[:5])  # Limit to 5 children
        ]
        return json.dumps(
            {"parent": parent_chunk, "children": child_chunks, "count": len(child_chunks)},
            indent=2,
        )
    except Exception as e:
        return f"Error chunking: {str(e)}"
@tool
def generate_embeddings_and_index(chunks_json: str) -> str:
    """Generate embeddings and add to FAISS index"""
    # Input: JSON {"parent": {...}, "children": [...]} (as built by chunk_content).
    # Output: JSON {"success", "chunks_indexed", "chunks_skipped",
    # "total_index_size"}, or {"success": False, "error": ...} on failure.
    # Chunks whose batch fails to encode are skipped. They are no longer
    # indexed with random vectors, which similarity search would later
    # return as if they were relevant matches.
    try:
        chunks_data = json.loads(chunks_json)
        parent = chunks_data.get("parent", {})
        children = chunks_data.get("children", [])

        # Only include valid chunks (not empty)
        all_chunks = []
        if parent and parent.get("document"):
            all_chunks.append(parent)
        all_chunks.extend(c for c in children if c.get("document"))

        if not all_chunks:
            print("⚠️ No valid chunks to index")
            return json.dumps({
                "success": True,
                "chunks_indexed": 0,
                "chunks_skipped": 0,
                "total_index_size": len(faiss_chunk_id_to_index),
            })

        texts = [chunk.get("document", "") for chunk in all_chunks]
        print(f"⏳ Generating embeddings for {len(all_chunks)} chunks...")

        # Small batches keep CPU memory use low on HF Spaces.
        batch_size = 2
        total_batches = (len(texts) - 1) // batch_size + 1
        encoded_chunks = []
        encoded_vectors = []
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            batch_chunks = all_chunks[start:start + batch_size]
            print(f" Processing batch {start // batch_size + 1}/{total_batches}...")
            try:
                # Encode without progress bar to avoid hang
                batch_embeddings = embedding_model.encode(batch_texts, convert_to_numpy=True, show_progress_bar=False)
            except Exception as be:
                print(f" ⚠️ Batch error: {str(be)}, skipping {len(batch_texts)} chunk(s)")
                continue
            encoded_chunks.extend(batch_chunks)
            encoded_vectors.extend(batch_embeddings)
            print(f" βœ… Encoded {len(batch_texts)} texts")

        if encoded_chunks:
            embeddings = np.array(encoded_vectors, dtype=np.float32)
            print(f" Adding {len(embeddings)} embeddings to FAISS...")
            faiss_add_chunks(embeddings, [chunk.get("id", "") for chunk in encoded_chunks], encoded_chunks)

        print(f"βœ… Successfully indexed {len(encoded_chunks)} chunks")
        return json.dumps({
            "success": True,
            "chunks_indexed": len(encoded_chunks),
            "chunks_skipped": len(all_chunks) - len(encoded_chunks),
            "total_index_size": len(faiss_chunk_id_to_index),
        })
    except Exception as e:
        print(f"❌ Indexing error: {str(e)}")
        return json.dumps({"success": False, "error": str(e)})
@tool
def retrieve_similar_chunks(query_text: str, k: int = 5) -> str:
    """Retrieve similar chunks from FAISS"""
    # Returns faiss_search()'s result serialised as JSON. If the query cannot be
    # embedded, or anything else fails, the empty result shape is returned.
    # The old code substituted a random query vector on encode failure, which
    # returned arbitrary chunks as though they were relevant.
    empty_results = json.dumps({"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]})
    try:
        print(f"πŸ”Ž Retrieving similar chunks for: '{query_text[:50]}...'")
        try:
            query_embedding = embedding_model.encode(query_text, convert_to_numpy=True, show_progress_bar=False)
        except Exception as e:
            print(f" ⚠️ Encoding error: {str(e)[:50]}, returning no results")
            return empty_results
        print(f" βœ… Query encoded")

        results = faiss_search(query_embedding, k=k)
        if not results["ids"][0]:
            index_size = faiss_index.ntotal if faiss_index is not None else 0
            print(f" ⚠️ No chunks found in FAISS (index size: {index_size})")
        else:
            print(f" βœ… Found {len(results['ids'][0])} similar chunks")
        return json.dumps(results, indent=2)
    except Exception as e:
        print(f" ❌ Retrieval error: {str(e)}")
        return empty_results
def generate_answer(query: str, context_chunks: str, chat_history: str = "") -> str:
    """Ask the large model to answer *query* from retrieved context.

    Args:
        query: the user's current question.
        context_chunks: pre-formatted retrieved snippets that ground the answer.
        chat_history: optional pre-formatted prior conversation.

    Returns:
        The model's reply text, or ``"Error generating answer: ..."`` after
        ``max_attempts`` failed calls (with waits of 5 s, then 10 s, between
        attempts). Never raises.
    """
    import time

    max_attempts = 3
    # Use the real current year rather than a hard-coded one so the prompt does
    # not go stale. The old text also had a typo: "current year in 2026".
    current_year = datetime.now().year
    system_msg = SystemMessage(content=f"""You are an e-commerce assistant.
Answer the user's query based on the provided context AND previous conversation.
Be concise, accurate, and cite sources when possible. Remember that default region is INDIA and current year is {current_year}.
If answering a follow-up question, use insights from previous messages to provide better recommendations.""")
    # Build user message with chat history context
    user_msg = HumanMessage(content=f"""
{chat_history}
Current Query: {query}
Context chunks:
{context_chunks}
Please provide a helpful answer with confidence level (0-100%).""")

    for attempt in range(1, max_attempts + 1):
        try:
            print(f"πŸ’‘ Generating answer for: '{query[:50]}...' (Attempt {attempt}/{max_attempts})")
            response = llm_large.invoke([system_msg, user_msg])
            print(f"βœ… Answer generated successfully")
            return response.content
        except Exception as e:
            error_msg = str(e)
            print(f"⚠️ Attempt {attempt} failed: {error_msg[:80]}")
            if attempt < max_attempts:
                wait_time = 5 * attempt  # 5s, 10s
                print(f" ⏳ Retrying in {wait_time}s...")
                time.sleep(wait_time)
            else:
                print(f"❌ Failed after {max_attempts} attempts")
                return f"Error generating answer: {error_msg[:100]}"
# --- 5. AGENT FUNCTIONS WITH TOOLS ---
def _decision_requires_search(reply) -> bool:
    """Map the decision model's reply to "is a web search needed?".

    The verdict is the first whole-word SEARCH/CONTEXT keyword in the reply.
    A reply containing neither falls back to searching, which matches the
    prompt's "if unclear -> search" rule and the error fallback.
    """
    match = re.search(r"\b(SEARCH|CONTEXT)\b", str(reply).upper())
    return match is None or match.group(1) == "SEARCH"


def run_decision_agent(query: str, chat_history: List = None, progress_callback=None) -> Tuple[bool, str]:
    """
    Decision Agent - Decides if web search is needed
    Returns: (needs_search: bool, reasoning: str)

    *chat_history* is the list of {"role", "content"} message dicts and
    includes the current user message. With at most one message this is the
    first question, and a search is always performed. *progress_callback* is
    accepted for signature parity with the other agents but is not used.
    """
    print(f"\n{'='*70}\nπŸ€” DECISION AGENT\n{'='*70}")
    if chat_history is None:
        chat_history = []

    # Check if we have previous answers
    if len(chat_history) <= 1:
        print("βœ… First question - will perform web search")
        return True, "First query - need fresh data"

    # Build conversation context from chat history (last 3 exchanges)
    context = "Previous Conversation:\n"
    for msg in chat_history[-6:]:
        if not isinstance(msg, dict):
            continue  # tolerate foreign history formats instead of crashing
        role = str(msg.get("role", "")).upper()
        content = msg.get("content", "")[:200]  # First 200 chars
        context += f"{role}: {content}\n"

    # Decision prompt
    system_msg = SystemMessage(content="""You are a decision agent for e-commerce queries.
Analyze if the current query can be answered from PREVIOUS conversation context or if it needs NEW web search.
Rules:
1. If query asks about SPECIFIC products/prices/comparisons from before β†’ NO search needed, use previous context
2. If query asks for DIFFERENT products or NEW information β†’ YES search needed
3. If query is a follow-up question about something already discussed β†’ NO search needed
4. If unclear or asking for LATEST data β†’ YES search needed
Respond with ONLY: "SEARCH" or "CONTEXT"
""")
    user_msg = HumanMessage(content=f"""{context}
Current query: {query}
Decision: (respond with ONLY "SEARCH" or "CONTEXT")""")

    try:
        response = llm_small.invoke([system_msg, user_msg])
        needs_search = _decision_requires_search(response.content)
        reasoning = f"Decision: {'Will search web' if needs_search else 'Will use previous context'}"
        print(f"βœ… {reasoning}")
        return needs_search, reasoning
    except Exception as e:
        print(f"⚠️ Decision error: {str(e)[:50]}, defaulting to search")
        return True, f"Error in decision (default to search): {str(e)[:30]}"
def _generate_search_queries(query: str) -> List[str]:
    """Return up to 5 distinct search strings: *query* first, then LLM variations.

    Falls back to "<query> price" / "<query> buy online" when the model call
    fails or yields no usable suggestion.
    """
    fallback = [f"{query} price", f"{query} buy online"]
    # Real current year instead of a hard-coded one, so the prompt never goes stale.
    current_year = datetime.now().year
    system_msg = SystemMessage(content=f"""Generate 3-4 search variations for an e-commerce query.
Focus on product names, prices, availability. Remember that the default region is INDIA and current year is {current_year}.
Return ONLY a JSON array of strings.""")
    user_msg = HumanMessage(content=f"Query: {query}\n\nReturn search variations as JSON array:")
    try:
        response = llm_small.invoke([system_msg, user_msg])
    except Exception as e:
        print(f"⚠️ LLM timeout/error, using default queries: {str(e)[:50]}")
        return [query] + fallback

    suggestions: List[str] = []
    try:
        json_match = re.search(r'\[.*\]', response.content, re.DOTALL)
        if json_match:
            parsed = json.loads(json_match.group())
            if isinstance(parsed, list):
                # Keep only non-empty strings; objects or numbers would break
                # the later slicing / Tavily calls.
                suggestions = [s.strip() for s in parsed if isinstance(s, str) and s.strip()]
    except (TypeError, ValueError):
        suggestions = []
    if not suggestions:
        suggestions = fallback
    # dict.fromkeys de-duplicates while keeping order (models often echo the query).
    return list(dict.fromkeys([query] + suggestions))[:5]


def _run_tavily_searches(search_queries: List[str], state: EcommerceAgentState) -> List[Any]:
    """Run Tavily for the first 3 queries and concatenate their hit lists.

    After a failed call, the Tavily key is rotated (when more than one key is
    configured) before the next query. Non-list results are logged and
    ignored: ``list.extend`` with a string would add it character by character.
    """
    def make_client():
        return TavilySearchResults(
            max_results=15,  # More results for better diversity
            api_key=tavily_rotator.get_key(),
            search_depth="advanced",
        )

    tavily_search = make_client()
    all_results: List[Any] = []
    for sq in search_queries[:3]:  # Limit to 3 to avoid too many API calls
        print(f" πŸ” Query: '{sq[:40]}...'")
        try:
            results = tavily_search.invoke(sq)
        except Exception as e:
            print(f" ⚠️ Error: {str(e)[:50]}")
            state["progress_log"].append(f"⚠️ Search error: {str(e)[:50]}")
            if len(tavily_rotator.keys) > 1:
                tavily_rotator.rotate()
                tavily_search = make_client()
            continue
        if not isinstance(results, list):
            problem = str(results)[:50]
            print(f" ⚠️ Error: {problem}")
            state["progress_log"].append(f"⚠️ Search error: {problem}")
            continue
        all_results.extend(results)
        print(f" βœ… Found {len(results)} results")
    return all_results


def _dedupe_results(all_results: List[Any]) -> List[Dict]:
    """Keep the first dict result per URL (JSON strings are decoded; anything else is skipped)."""
    seen_urls = set()
    seen_domains: Dict[str, int] = {}
    unique_results = []
    for result in all_results:
        if isinstance(result, str):
            try:
                result = json.loads(result)
            except ValueError:
                continue
        if not isinstance(result, dict):
            continue
        url = result.get("url", "")
        if not isinstance(url, str) or not url or url in seen_urls:
            continue
        seen_urls.add(url)
        domain = extract_domain(url)
        unique_results.append(result)
        seen_domains[domain] = seen_domains.get(domain, 0) + 1
        print(f" βœ… Added: {domain}")
    print(f"\nπŸ“Š Results: {len(unique_results)} URLs from {len(seen_domains)} domains")
    print(f" Domains found: {list(seen_domains.keys())}")
    return unique_results


def run_search_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Search Agent - Find products using Tavily + scrape content

    Expands the query into variations, searches Tavily, deduplicates the hits
    by URL and stores the first 12 in ``state["search_results"]``. The stage
    then advances to CHUNKING. On an unexpected exception the stage is set to
    ERROR instead of raising.
    """
    try:
        print(f"\n{'='*70}\nπŸ” SEARCH AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("πŸ” Searching for products...")
        query = state["query"]

        # Step 1: Generate search queries (with fallback)
        search_queries = _generate_search_queries(query)
        state["progress_log"].append(f"βœ… Generated {len(search_queries)} search queries")
        print(f"Search queries: {search_queries}")

        # Step 2: Broad Tavily search
        print("πŸ“₯ Searching broadly with Tavily...")
        all_results = _run_tavily_searches(search_queries, state)

        # Step 3: Deduplicate results (no domain filtering)
        print(f"🎯 Processing {len(all_results)} results...")
        unique_results = _dedupe_results(all_results)[:12]  # Keep top 12 for diversity

        state["search_results"] = unique_results
        state["stage"] = WorkflowStage.CHUNKING.value
        state["progress_log"].append(f"βœ… Search complete: {len(unique_results)} URLs found")
        print(f"βœ… Search complete: {len(unique_results)} URLs")
        return state
    except Exception as e:
        print(f"❌ Search agent error: {str(e)}")
        state["error_message"] = f"Search agent failed: {str(e)[:100]}"
        state["stage"] = WorkflowStage.ERROR.value
        return state
def run_chunking_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Chunking Agent - Scrape and chunk documents

    Scrapes every URL in ``state["search_results"]`` and keeps pages with more
    than 200 characters of text. Each page is then split into one parent
    chunk and its children. The stage advances to INDEXING; on an unexpected
    exception it is set to ERROR instead of raising.
    """
    try:
        print(f"\n{'='*70}\nπŸ“„ CHUNKING AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("πŸ“„ Chunking documents...")

        # Scrape URLs
        cleaned_documents = []
        for i, result in enumerate(state.get("search_results", []), 1):
            if not isinstance(result, dict):
                continue
            url = result.get("url", "")
            if not url:
                continue
            domain = extract_domain(url)
            print(f"πŸ“₯ Scraping [{i}]: {domain}")
            try:
                content = scrape_product_content.invoke({"url": url})
                # The scraper signals failure with an "Error..." prefix. The old
                # substring test ("Error" not in content) also discarded every
                # genuine page that merely mentions the word "Error".
                if content and not content.startswith("Error") and len(content) > 200:
                    cleaned_documents.append({
                        "url": url,
                        "content": content,
                        "source": domain,
                        "scraped_at": datetime.now().isoformat()
                    })
                    print(f" βœ… Extracted {len(content)} chars")
                    state["progress_log"].append(f"βœ… Scraped: {domain}")
            except Exception as e:
                print(f" ⚠️ Error: {str(e)[:50]}")
                state["progress_log"].append(f"⚠️ Scrape error: {domain}")

        # Chunk documents
        parent_chunks = []
        child_chunks = []
        parent_child_mapping = {}
        for doc in cleaned_documents:
            try:
                chunks_result = chunk_content.invoke({
                    "content": doc["content"],
                    "url": doc["url"]
                })
                # chunk_content reports failures as an "Error chunking: ..." string.
                if chunks_result.startswith("Error"):
                    raise ValueError(chunks_result)
                chunks_data = json.loads(chunks_result)
                parent = chunks_data.get("parent", {})
                children = chunks_data.get("children", [])
                if parent:
                    parent_chunks.append(parent)
                    parent_child_mapping[parent["id"]] = [c["id"] for c in children]
                    child_chunks.extend(children)
            except Exception as e:
                print(f" ❌ Chunking error: {str(e)[:50]}")
                state["progress_log"].append(f"⚠️ Chunking error: {doc['source']}")

        state["cleaned_documents"] = cleaned_documents
        state["parent_chunks"] = parent_chunks
        state["child_chunks"] = child_chunks
        state["parent_child_mapping"] = parent_child_mapping
        state["stage"] = WorkflowStage.INDEXING.value
        state["progress_log"].append(f"βœ… Chunking: {len(parent_chunks)} parents, {len(child_chunks)} children")
        print(f"βœ… Chunking complete: {len(parent_chunks)} parents, {len(child_chunks)} children")
        return state
    except Exception as e:
        state["error_message"] = f"Chunking agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state
def run_indexing_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Indexing Agent - Index chunks in FAISS

    ``generate_embeddings_and_index`` takes one parent per call, so it is
    invoked once per parent chunk together with that parent's children
    (matched by ``parent_id``). Children whose parent is missing are indexed
    in a final parent-less call. Previously only ``parent_chunks[0]`` was
    sent, so the parent chunk of every other scraped page never reached the
    index. The stage advances to RETRIEVAL; on an unexpected exception it is
    set to ERROR instead of raising.
    """
    try:
        print(f"\n{'='*70}\nπŸ—‚οΈ INDEXING AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("πŸ—‚οΈ Indexing chunks...")

        parents = state.get("parent_chunks", [])
        children = state.get("child_chunks", [])
        if not parents and not children:
            state["progress_log"].append("⚠️ No chunks to index")
            state["stage"] = WorkflowStage.RETRIEVAL.value
            return state

        # Group children under their parent id.
        children_by_parent: Dict[Any, List[Dict]] = {}
        for child in children:
            children_by_parent.setdefault(child.get("parent_id"), []).append(child)

        batches = [
            {"parent": parent, "children": children_by_parent.pop(parent.get("id"), [])}
            for parent in parents
        ]
        orphans = [child for group in children_by_parent.values() for child in group]
        if orphans:
            batches.append({"parent": {}, "children": orphans})

        total_indexed = 0
        failed_batches = 0
        for batch in batches:
            result = generate_embeddings_and_index.invoke({"chunks_json": json.dumps(batch)})
            index_result = json.loads(result)
            if not index_result.get("success", False):
                failed_batches += 1
            total_indexed += index_result.get("chunks_indexed", 0)

        state["stage"] = WorkflowStage.RETRIEVAL.value
        state["progress_log"].append(f"βœ… Indexed {total_indexed} chunks")
        if failed_batches:
            state["progress_log"].append(f"⚠️ Indexing failed for {failed_batches} batch(es)")
        print(f"βœ… Indexing complete")
        return state
    except Exception as e:
        state["error_message"] = f"Indexing agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state
def run_retrieval_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Retrieval Agent - Find similar chunks

    Stores up to 5 FAISS hits in ``state["retrieved_children"]`` as dicts with
    ``id``, ``document``, ``metadata`` and ``distance`` (the similarity
    score). Despite the field name, hits may be parent or child chunks, since
    both are indexed. An empty index yields an empty list, not an error. The
    stage advances to SYNTHESIS; on an unexpected exception it is set to ERROR.
    """
    try:
        print(f"\n{'='*70}\nπŸ”Ž RETRIEVAL AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("πŸ”Ž Retrieving similar content...")
        query = state["query"]

        # Explicit None check instead of relying on the truthiness of the
        # SWIG-wrapped FAISS index object.
        if faiss_index is None or faiss_index.ntotal == 0:
            print("⚠️ FAISS index is empty - no chunks to retrieve")
            state["retrieved_children"] = []
            state["progress_log"].append("⚠️ FAISS index empty")
        else:
            result_json = retrieve_similar_chunks.invoke({"query_text": query, "k": 5})
            results = result_json if isinstance(result_json, dict) else json.loads(result_json)
            if not isinstance(results, dict):
                results = {}

            retrieved_children = [
                {
                    "id": doc_id,
                    "document": doc_text,
                    "metadata": metadata,
                    "distance": float(dist),
                }
                for doc_id, doc_text, metadata, dist in zip(
                    results.get("ids", [[]])[0],
                    results.get("documents", [[]])[0],
                    results.get("metadatas", [[]])[0],
                    results.get("distances", [[]])[0],
                )
            ]
            state["retrieved_children"] = retrieved_children
            state["progress_log"].append(f"βœ… Retrieved {len(retrieved_children)} chunks")
            print(f"βœ… Retrieval complete: {len(retrieved_children)} chunks")

        state["stage"] = WorkflowStage.SYNTHESIS.value
        return state
    except Exception as e:
        print(f"❌ Retrieval error: {str(e)}")
        state["error_message"] = f"Retrieval agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state
def run_synthesis_agent(state: EcommerceAgentState, progress_callback=None) -> EcommerceAgentState:
    """Synthesis Agent - write the final answer from the retrieved chunks.

    Formats up to 5 retrieved chunks and the earlier chat messages into prompt
    context, then calls ``generate_answer`` once. The answer, up to 3
    citations and a heuristic confidence are recorded in *state*, and the
    stage becomes COMPLETE. On an unexpected exception the stage is set to
    ERROR instead of raising.

    ``generate_answer`` already retries failed model calls with back-off, and
    the LLM wrapper also rotates keys. This agent therefore no longer wraps
    it in a second retry loop: that multiplied the attempts (3 x 3) and
    stacked the sleeps on a persistent failure.
    """
    try:
        print(f"\n{'='*70}\nπŸ’‘ SYNTHESIS AGENT\n{'='*70}")
        if progress_callback:
            progress_callback("πŸ’‘ Generating answer...")

        query = state["query"]
        retrieved = state.get("retrieved_children", [])
        chat_messages = state.get("chat_messages", [])

        # Prepare context - handle empty case
        if retrieved:
            context = "\n\n".join([
                f"[{i+1}] {chunk.get('metadata', {}).get('source', 'Unknown')}\n{chunk.get('document', '')}"
                for i, chunk in enumerate(retrieved[:5])
            ])
            confidence = min(0.9, len(retrieved) / 10.0)
        else:
            context = "[No relevant information found in knowledge base]"
            confidence = 0.1

        # Build chat history context (all messages except the current query)
        chat_context = ""
        if chat_messages and len(chat_messages) > 1:
            chat_context = "Previous Conversation:\n"
            for msg in chat_messages[:-1]:
                if not isinstance(msg, dict):
                    continue
                role = str(msg.get("role", "")).upper()
                content = msg.get("content", "")
                chat_context += f"{role}: {content[:300]}\n\n"  # First 300 chars of each message

        print(f" ⏳ Generating answer...")
        answer = generate_answer(query, context, chat_context)
        if not isinstance(answer, str):
            answer = str(answer) if answer else ""
        if not answer:
            answer = "Error generating answer: Unable to generate response"

        if answer.startswith("Error"):
            print(f" ⚠️ {answer[:80]}")
            confidence = 0.0  # do not advertise confidence in an error message
        else:
            print(f" βœ… Answer generated successfully")

        # Extract citations from retrieved chunks
        citations = [
            {
                "source": chunk.get("metadata", {}).get("source", "Unknown"),
                "section": "Content",
                "url": chunk.get("metadata", {}).get("source", "")
            }
            for chunk in retrieved[:3]
        ]

        state["final_answer"] = answer
        state["citations"] = citations
        state["confidence"] = confidence
        state["stage"] = WorkflowStage.COMPLETE.value
        state["progress_log"].append(f"βœ… Synthesis complete: Confidence={confidence:.2f}")
        print(f"βœ… Synthesis complete")
        return state
    except Exception as e:
        print(f"❌ Synthesis error: {str(e)}")
        state["error_message"] = f"Synthesis agent failed: {str(e)}"
        state["stage"] = WorkflowStage.ERROR.value
        return state
# --- 6. WORKFLOW ORCHESTRATION ---
def run_complete_workflow(query: str, chat_history: List = None, progress=gr.Progress()) -> EcommerceAgentState:
    """Run the full agent pipeline for *query* and return the final state.

    The decision agent first chooses between fresh web data (search -> chunk
    -> index) and reusing the existing FAISS index. Retrieval and synthesis
    always run afterwards. The run stops early, returning the state as it is,
    when an agent reports the ERROR stage before synthesis. A completed run is
    recorded with add_to_conversation_history().
    """
    history = [] if chat_history is None else chat_history
    state = create_initial_state(query)
    state["chat_messages"] = history  # Add chat history to state

    def update_progress(msg: str):
        progress(0.2, desc=msg)

    # Step 1: Decision Agent - Determine if search is needed
    needs_search, decision_reason = run_decision_agent(query, history, update_progress)
    state["progress_log"].append(f"πŸ€” {decision_reason}")

    # Step 2: Conditionally run the ingestion agents
    if not needs_search:
        # Skip search/chunking/indexing, go directly to retrieval
        print("πŸ“š Reusing indexed data from previous query...")
        state["progress_log"].append("πŸ“š Using existing indexed knowledge")
        state["stage"] = WorkflowStage.RETRIEVAL.value
    else:
        ingestion_steps = (
            (0.25, "πŸ” Searching...", run_search_agent),
            (0.35, "πŸ“„ Chunking...", run_chunking_agent),
            (0.5, "πŸ—‚οΈ Indexing...", run_indexing_agent),
        )
        for fraction, label, agent in ingestion_steps:
            progress(fraction, desc=label)
            state = agent(state, update_progress)
            if state["stage"] == WorkflowStage.ERROR.value:
                return state

    # Step 3: Retrieval and Synthesis (always run)
    progress(0.65, desc="πŸ”Ž Retrieving...")
    state = run_retrieval_agent(state, update_progress)
    if state["stage"] == WorkflowStage.ERROR.value:
        return state

    progress(0.85, desc="πŸ’‘ Generating answer...")
    state = run_synthesis_agent(state, update_progress)
    progress(1.0, desc="βœ… Complete")

    # Add to conversation history
    add_to_conversation_history(query, state["final_answer"])
    return state
# --- 7. GRADIO INTERFACE ---#
def chat_interface(message: str, chat_history: List) -> Tuple[str, List]:
    """Gradio handler: run the pipeline for *message* and append the exchange.

    Returns ``("", history)`` so Gradio clears the textbox and re-renders the
    chat. History entries are ``{"role": ..., "content": ...}`` dicts. A blank
    message leaves the history unchanged. If the pipeline raises, the error is
    shown as an assistant message instead of aborting the turn.
    """
    # Gradio may pass None for an empty chatbot. Work on a copy so a crashed
    # run never leaves a half-updated list behind.
    history = list(chat_history or [])
    if not message or not message.strip():
        return "", history

    # Add user message to history (dict with role and content)
    history.append({"role": "user", "content": message})

    try:
        # Run workflow with full chat history
        state = run_complete_workflow(message, chat_history=history)
    except Exception as e:
        print(f"❌ Workflow error: {str(e)}")
        history.append({"role": "assistant", "content": f"⚠️ **Error:** {str(e)[:200]}"})
        return "", history

    # Build assistant response
    parts = [f"""πŸ€– **Answer**
{state['final_answer']}
**Confidence:** {state['confidence']*100:.0f}%
---
πŸ“š **Sources:**
"""]
    parts.extend(f"- {cite['source']}\n" for cite in state["citations"][:3])
    parts.append(f"\n---\n\nπŸ“‹ **Process Log:**\n")
    parts.extend(f"βœ“ {log}\n" for log in state["progress_log"])
    if state["error_message"]:
        parts.append(f"\n⚠️ **Error:** {state['error_message']}")

    history.append({"role": "assistant", "content": "".join(parts)})
    return "", history
# Build Gradio interface
with gr.Blocks(title="πŸ›οΈ E-commerce Chat RAG") as demo:
    gr.Markdown("""
# πŸ›οΈ E-commerce Chat RAG System
Ask any question about e-commerce products and prices!
*Agents:* πŸ” Search β†’ πŸ“„ Chunk β†’ πŸ—‚οΈ Index β†’ πŸ”Ž Retrieve β†’ πŸ’‘ Synthesize
""")
    # NOTE(review): chat_interface reads and writes {"role", "content"} dicts;
    # depending on the installed Gradio version the Chatbot may need an
    # explicit messages format for that - confirm.
    chatbot = gr.Chatbot(height=600, label="πŸ’¬ Chat")

    with gr.Row():
        msg = gr.Textbox(
            scale=4,
            lines=2,
            placeholder="e.g., What is the current price of iPhone 14?",
            label="Your Question",
        )
        submit_btn = gr.Button("Send πŸ“€", scale=1, variant="primary")

    # The Send button and the Enter key in the textbox share one queued
    # handler, which clears the textbox and returns the updated chat.
    for _register_event in (submit_btn.click, msg.submit):
        _register_event(
            fn=chat_interface,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=True,
        )
    del _register_event
if __name__ == "__main__":

    def _shutdown_event_loop() -> None:
        """Best-effort: stop and close this thread's asyncio loop, ignoring every error."""
        # NOTE(review): once launch() has returned, no loop is normally running
        # in this thread, and get_event_loop() may create a new loop or warn,
        # depending on the Python version - confirm this cleanup is still needed.
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop.stop()
            loop.close()
        except Exception:
            pass

    try:
        # For HF Spaces: suppress asyncio event loop cleanup warnings
        demo.launch(share=False, show_error=True)
    finally:
        # Clean up event loop on shutdown
        _shutdown_event_loop()