|
|
import logging |
|
|
import chromadb |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from typing import List, Dict, Any, Tuple |
|
|
import os |
|
|
import json |
|
|
import re |
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
|
|
|
|
|
class RetrievalManager: |
|
|
""" |
|
|
Manages retrieving documents from a ChromaDB vector database based on a user query. |
|
|
""" |
|
|
|
|
|
STOP_WORDS = { |
|
|
"a", "an", "the", "is", "are", "and", "or", "but", "for", "with", |
|
|
"in", "on", "at", "of", "to", "from", "by", "about", "as", "into", |
|
|
"like", "through", "after", "before", "over", "under", "above", "below", |
|
|
"up", "down", "out", "off", "then", "once", "here", "there", "when", |
|
|
"where", "why", "how", "all", "any", "both", "each", "few", "more", |
|
|
"most", "other", "some", "such", "no", "nor", "not", "only", "own", |
|
|
"same", "so", "than", "too", "very", "s", "t", "can", "will", "just", |
|
|
"don", "should", "now", "compare", "x", "y" |
|
|
} |
|
|
def __init__(self, db_path: str = "./chroma_db", model_name: str = 'BAAI/bge-large-en-v1.5'): |
|
|
""" |
|
|
Initializes the RetrievalManager. |
|
|
|
|
|
Args: |
|
|
db_path (str): Path to the ChromaDB database directory. |
|
|
model_name (str): The name of the sentence-transformer model to use for embeddings. |
|
|
""" |
|
|
logger.info(f"Initializing RetrievalManager with db_path='{db_path}' and model='{model_name}'") |
|
|
|
|
|
|
|
|
self.client = chromadb.PersistentClient(path=db_path) |
|
|
|
|
|
|
|
|
try: |
|
|
self.model = SentenceTransformer(model_name) |
|
|
logger.info(f"Successfully loaded embedding model: {model_name}") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to load sentence-transformer model '{model_name}'. Error: {e}") |
|
|
raise |
|
|
|
|
|
|
|
|
self.filterable_metadata = None |
|
|
try: |
|
|
with open("filterable_metadata.json", "r") as f: |
|
|
self.filterable_metadata = json.load(f) |
|
|
logger.info("Successfully loaded filterable_metadata.json") |
|
|
except FileNotFoundError: |
|
|
logger.warning("filterable_metadata.json not found. Filtering by brand/category will be disabled.") |
|
|
except json.JSONDecodeError: |
|
|
logger.error("Failed to decode filterable_metadata.json. Filtering by brand/category will be disabled.") |
|
|
|
|
|
def _generate_query_embedding(self, query: str) -> List[float]: |
|
|
""" |
|
|
Generates an embedding for a single query string. |
|
|
|
|
|
Args: |
|
|
query (str): The user query string. |
|
|
|
|
|
Returns: |
|
|
List[float]: The embedding vector for the query. |
|
|
""" |
|
|
logger.info(f"Generating embedding for query: '{query}'") |
|
|
embedding = self.model.encode(query, convert_to_tensor=False) |
|
|
logger.info("Finished generating query embedding.") |
|
|
return embedding.tolist() |
|
|
|
|
|
def _extract_filters(self, query: str) -> Tuple[Dict[str, Any], str]: |
|
|
""" |
|
|
Extracts metadata filters (price, brand, category) from the query string |
|
|
and returns a ChromaDB-compatible filter dictionary and a cleaned query. |
|
|
|
|
|
Args: |
|
|
query (str): The user's raw query. |
|
|
|
|
|
Returns: |
|
|
A tuple containing the `where` filter dictionary and the cleaned query string. |
|
|
""" |
|
|
cleaned_query = query |
|
|
filters = [] |
|
|
parts_to_remove = [] |
|
|
|
|
|
|
|
|
price_patterns = { |
|
|
"lt": re.compile(r"(?:under|less than|below|max of|\$lt)\s*\$?(\d+\.?\d*)", re.IGNORECASE), |
|
|
"gt": re.compile(r"(?:over|more than|above|min of|\$gt)\s*\$?(\d+\.?\d*)", re.IGNORECASE), |
|
|
} |
|
|
|
|
|
for op, pattern in price_patterns.items(): |
|
|
match = pattern.search(cleaned_query) |
|
|
if match: |
|
|
price = float(match.group(1)) |
|
|
filters.append({"price": {f"${op}": price}}) |
|
|
parts_to_remove.append(match.group(0)) |
|
|
|
|
|
|
|
|
if self.filterable_metadata: |
|
|
|
|
|
brands = sorted(self.filterable_metadata.get("brands", []), key=len, reverse=True) |
|
|
categories = sorted(self.filterable_metadata.get("categories", []), key=len, reverse=True) |
|
|
|
|
|
for brand in brands: |
|
|
pattern = r'\b' + re.escape(brand) + r'\b' |
|
|
match = re.search(pattern, cleaned_query, re.IGNORECASE) |
|
|
if match: |
|
|
filters.append({"brand": brand}) |
|
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
category_found = False |
|
|
query_lower = cleaned_query.lower() |
|
|
|
|
|
|
|
|
if any(term in query_lower for term in ["laptop", "computer", "pc"]): |
|
|
|
|
|
for category in self.filterable_metadata.get("categories", []): |
|
|
if "Computers and Laptops" in category: |
|
|
filters.append({"category": category}) |
|
|
category_found = True |
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CATEGORY_SYNONYMS = { |
|
|
"Computers and Laptops": ["laptop", "laptops", "computer", "notebook", "ultrabook", \ |
|
|
"chromebook", "PC", "desktop", "workstation", "gaming laptop"], |
|
|
"Cameras and Camcorders": ["camera", "cameras", "camcorder", "photo", "video camera"], |
|
|
"Gaming Consoles and Accessories": ["console", "gaming", "games", "controller"], |
|
|
"Smartphones and Accessories": ["phone", "smartphone", "mobile", "case", "charger"], |
|
|
"Audio Equipment": ["audio", "speaker", "headphone", "earbud", "sound"], |
|
|
"Televisions and Home Theater Systems": ["tv", "television", "home theater", "soundbar", "tvs"] |
|
|
} |
|
|
|
|
|
|
|
|
priority_categories = ["Computers and Laptops", "Cameras and Camcorders", "Smartphones and Accessories"] |
|
|
|
|
|
matched_categories = set() |
|
|
query_lower = query.lower() |
|
|
|
|
|
|
|
|
for category in priority_categories: |
|
|
multi_word_synonyms = [syn for syn in CATEGORY_SYNONYMS[category] if len(syn.split()) > 1] |
|
|
for syn in multi_word_synonyms: |
|
|
if syn in query_lower: |
|
|
matched_categories.add(category) |
|
|
break |
|
|
|
|
|
|
|
|
if not matched_categories: |
|
|
for category in CATEGORY_SYNONYMS.keys(): |
|
|
single_word_synonyms = [syn for syn in CATEGORY_SYNONYMS[category] if len(syn.split()) == 1] |
|
|
if any(syn in query_lower for syn in single_word_synonyms): |
|
|
matched_categories.add(category) |
|
|
|
|
|
|
|
|
for category in matched_categories: |
|
|
filters.append({"category": category}) |
|
|
category_found = True |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
where_filter = {} |
|
|
if filters: |
|
|
if len(filters) > 1: |
|
|
where_filter = {"$or": filters} |
|
|
else: |
|
|
where_filter = filters[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if where_filter: |
|
|
logger.info(f"Extracted filter: {where_filter}") |
|
|
logger.info(f"Cleaned query for embedding: '{cleaned_query}'") |
|
|
|
|
|
return where_filter, cleaned_query |
|
|
|
|
|
def _route_query(self, query: str) -> List[str]: |
|
|
""" |
|
|
Determines which collection(s) to query based on keywords in the query. |
|
|
This provides a simple, fast routing mechanism before performing a semantic search. |
|
|
|
|
|
Args: |
|
|
query (str): The user query string. |
|
|
|
|
|
Returns: |
|
|
List[str]: A list of collection names to target for the search. |
|
|
""" |
|
|
query_lower = query.lower() |
|
|
|
|
|
review_keywords = ["review", "customer", "feedback", "complaints", "say about", "opinion", "experience"] |
|
|
|
|
|
product_keywords = ["specs", "features", "have", "available", "do you have", \ |
|
|
"specification", "technical details", "price", "warranty"] |
|
|
|
|
|
target_collections = [] |
|
|
|
|
|
|
|
|
if any(keyword in query_lower for keyword in review_keywords): |
|
|
target_collections.append("reviews") |
|
|
|
|
|
|
|
|
if any(keyword in query_lower for keyword in product_keywords): |
|
|
target_collections.append("products") |
|
|
|
|
|
|
|
|
if not target_collections: |
|
|
logger.info("No specific keywords found in query, defaulting to both 'products' and 'reviews' collections.") |
|
|
return ["products", "reviews"] |
|
|
|
|
|
logger.info(f"Routing query to collections: {target_collections}") |
|
|
return target_collections |
|
|
|
|
|
def search(self, query: str) -> Dict[str, List[Dict[str, Any]]]: |
|
|
""" |
|
|
Performs a semantic search across relevant collections based on the user query. |
|
|
|
|
|
This method orchestrates the query processing, routing, and retrieval steps. |
|
|
|
|
|
Args: |
|
|
query (str): The user's search query. |
|
|
|
|
|
Returns: |
|
|
Dict[str, List[Dict[str, Any]]]: A dictionary where keys are collection names |
|
|
and values are the retrieval results from ChromaDB. |
|
|
""" |
|
|
logger.info(f"--- Starting search for query: '{query}' ---") |
|
|
|
|
|
|
|
|
where_filter, cleaned_query = self._extract_filters(query) |
|
|
|
|
|
|
|
|
embedding_query = cleaned_query if cleaned_query else query |
|
|
logger.info(f"Using query for embedding: '{embedding_query}'") |
|
|
|
|
|
|
|
|
target_collections = self._route_query(query) |
|
|
|
|
|
|
|
|
query_embedding = self._generate_query_embedding(embedding_query) |
|
|
|
|
|
results = {} |
|
|
|
|
|
|
|
|
for collection_name in target_collections: |
|
|
try: |
|
|
collection = self.client.get_collection(name=collection_name) |
|
|
|
|
|
|
|
|
n_results = 5 if collection_name == "products" else 8 |
|
|
|
|
|
logger.info(f"Querying '{collection_name}' collection with n_results={n_results} and filter={where_filter}...") |
|
|
|
|
|
retrieved = collection.query( |
|
|
query_embeddings=[query_embedding], |
|
|
n_results=n_results, |
|
|
where=where_filter if where_filter else None |
|
|
) |
|
|
results[collection_name] = retrieved |
|
|
logger.info(f"Retrieved {len(retrieved.get('ids', [[]])[0])} results from '{collection_name}'.") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Failed to query collection '{collection_name}'. Error: {e}") |
|
|
results[collection_name] = [] |
|
|
|
|
|
logger.info(f"--- Finished search for query: '{query}' ---") |
|
|
return results |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
EMBEDDING_MODEL = 'BAAI/bge-large-en-v1.5' |
|
|
DB_PATH = "./chroma_db" |
|
|
|
|
|
|
|
|
if not os.path.exists(DB_PATH): |
|
|
logger.error(f"FATAL: ChromaDB path '{DB_PATH}' not found.") |
|
|
logger.error("Please run 'vector_db_manager.py' first to create and populate the database.") |
|
|
else: |
|
|
|
|
|
logger.info("\n--- Starting Retrieval Test ---") |
|
|
try: |
|
|
retrieval_manager = RetrievalManager(db_path=DB_PATH, model_name=EMBEDDING_MODEL) |
|
|
|
|
|
|
|
|
queries = [ |
|
|
"What laptops do you have?", |
|
|
"Gaming laptops", |
|
|
"Lightweight laptop", |
|
|
"What do customers say about battery life?", |
|
|
"SmartX ProPhone camera reviews", |
|
|
"Any feedback on the AudioBliss headphones?", |
|
|
"Show me TVs under $500", |
|
|
"TechPro gaming laptops over $1000" |
|
|
] |
|
|
|
|
|
for query in queries: |
|
|
print("\n" + "="*60) |
|
|
print(f"Executing query: '{query}'") |
|
|
print("="*60) |
|
|
|
|
|
search_results = retrieval_manager.search(query) |
|
|
|
|
|
for collection_name, results in search_results.items(): |
|
|
print(f"\n--- Results from '{collection_name}' collection ---") |
|
|
if results and results.get('documents') and results['documents'][0]: |
|
|
for i, doc in enumerate(results['documents'][0]): |
|
|
distance = results['distances'][0][i] if results.get('distances') else 'N/A' |
|
|
metadata = results['metadatas'][0][i] if results.get('metadatas') else {} |
|
|
|
|
|
print(f" - Result {i+1} (Distance: {distance:.4f}):") |
|
|
print(f" ID: {results['ids'][0][i]}") |
|
|
print(f" Document: {doc[:120].strip()}...") |
|
|
print(f" Metadata: {metadata}") |
|
|
else: |
|
|
print(" No results found in this collection.") |
|
|
|
|
|
print("\n" + "="*60) |
|
|
logger.info("--- Retrieval Test Finished ---") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"An error occurred during the retrieval test: {e}", exc_info=True) |
|
|
|