|
|
import logging |
|
|
import chromadb |
|
|
from chromadb.config import Settings |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from typing import List, Dict, Any |
|
|
import os |
|
|
|
|
|
|
|
|
from document_processor import prepare_product_documents, prepare_review_documents |
|
|
|
|
|
|
|
|
|
|
|
# Configure the root logger once at import time so that every record emitted
# through this module's logger carries a timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
class VectorDBManager:
    """
    Manages the creation, embedding, and population of a ChromaDB vector database.

    The manager owns two resources created in ``__init__``:

    * ``self.client`` -- a ``chromadb.PersistentClient`` bound to ``db_path``.
    * ``self.model``  -- a ``SentenceTransformer`` used to embed document text.
    """

    def __init__(self, db_path: str = "./chroma_db", model_name: str = 'BAAI/bge-large-en-v1.5'):
        """
        Initializes the VectorDBManager.

        Args:
            db_path (str): Path to the ChromaDB database directory.
            model_name (str): The name of the sentence-transformer model to use for embeddings.

        Raises:
            Exception: Whatever ``SentenceTransformer(model_name)`` raises is
                logged and re-raised unchanged.
        """
        logger.info("Initializing VectorDBManager with db_path='%s' and model='%s'", db_path, model_name)

        # allow_reset=True is required for shutdown() to be able to call
        # client.reset(); see the warning on shutdown().
        self.client = chromadb.PersistentClient(
            path=db_path,
            settings=Settings(allow_reset=True),
        )

        try:
            self.model = SentenceTransformer(model_name)
        except Exception as e:
            logger.error("Failed to load sentence-transformer model '%s'. Error: %s", model_name, e)
            raise
        else:
            logger.info("Successfully loaded embedding model: %s", model_name)

    def shutdown(self):
        """
        Shuts down the manager and resets the ChromaDB client to release file locks.

        WARNING: this calls ``client.reset()``. According to the ChromaDB client
        documentation, ``reset()`` empties the database (all collections and
        their entries are deleted), so calling this after a successful ETL run
        discards the data that was just loaded.
        NOTE(review): confirm that callers really want a destructive reset here
        rather than merely dropping the reference to the client.
        """
        logger.info("Shutting down VectorDBManager and resetting client.")
        self.client.reset()

    @staticmethod
    def _collection_names(collections: List[Any]) -> set:
        """
        Normalizes the result of ``client.list_collections()`` to a set of names.

        Some ChromaDB releases return collection objects exposing ``.name``
        while others return plain name strings; both shapes are accepted here
        so the existence check in ``populate_collection`` works with either.

        Args:
            collections (List[Any]): Items returned by ``list_collections()``.

        Returns:
            set: The collection names as strings.
        """
        return {item if isinstance(item, str) else item.name for item in collections}

    def _generate_embeddings(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
        """
        Generates embeddings for a list of texts in batches.

        Args:
            texts (List[str]): A list of text strings to embed.
            batch_size (int): The batch size for embedding generation.

        Returns:
            List[List[float]]: A list of embedding vectors (empty if ``texts`` is empty).
        """
        # Nothing to embed: avoid invoking the model at all.
        if not texts:
            return []

        logger.info("Generating embeddings for %s documents in batches of %s...", len(texts), batch_size)
        embeddings = self.model.encode(texts, batch_size=batch_size, show_progress_bar=True)
        logger.info("Finished generating embeddings.")
        # The encoder result is converted to plain nested lists before being
        # handed to ChromaDB.
        return embeddings.tolist()

    def populate_collection(self, collection_name: str, documents: List[Dict[str, Any]], batch_size: int = 500):
        """
        Populates a ChromaDB collection with documents, generating embeddings first.

        Any existing collection with the same name is deleted and recreated so
        the collection reflects exactly the supplied documents. If adding a
        batch fails, the error is logged and population stops; batches added
        before the failure remain in the collection.

        Args:
            collection_name (str): The name of the collection to create/populate.
            documents (List[Dict[str, Any]]): A list of prepared documents from document_processor.
                Each item is read via the keys ``'id'``, ``'text_for_embedding'`` and ``'metadata'``.
            batch_size (int): Number of records sent to ChromaDB per ``add`` call.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        if not documents:
            logger.warning("No documents provided for collection '%s'. Skipping population.", collection_name)
            return

        logger.info("Populating collection: '%s'", collection_name)

        # Best-effort removal of a stale collection; a failure here is logged
        # and create_collection below is left to surface any real conflict.
        try:
            existing = self._collection_names(self.client.list_collections())
            if collection_name in existing:
                logger.warning(
                    "Collection '%s' already exists. Deleting it for a fresh population.",
                    collection_name,
                )
                self.client.delete_collection(name=collection_name)
        except Exception as e:
            logger.error("Error deleting collection '%s': %s", collection_name, e)

        collection = self.client.create_collection(name=collection_name)

        ids = [doc['id'] for doc in documents]
        texts_to_embed = [doc['text_for_embedding'] for doc in documents]
        metadatas = [doc['metadata'] for doc in documents]

        embeddings = self._generate_embeddings(texts_to_embed)

        # Insert in fixed-size slices; the same text that was embedded is also
        # stored as the document body.
        for batch_number, start in enumerate(range(0, len(ids), batch_size), start=1):
            stop = start + batch_size
            try:
                logger.info("Adding batch %s to '%s' collection...", batch_number, collection_name)
                collection.add(
                    embeddings=embeddings[start:stop],
                    documents=texts_to_embed[start:stop],
                    metadatas=metadatas[start:stop],
                    ids=ids[start:stop],
                )
            except Exception as e:
                logger.error("Failed to add batch to collection '%s'. Error: %s", collection_name, e)
                # Stop at the first failed batch rather than loading a
                # collection with gaps in the middle.
                return

        logger.info("Successfully populated '%s' with %s items.", collection_name, collection.count())
|
|
|
|
|
def run_etl_pipeline(products_file: str, reviews_file: str, db_path: str, model_name: str):
    """
    Execute the complete Extract/Transform/Load flow that fills the vector database.

    A ``VectorDBManager`` is created first; the product documents are then
    prepared and loaded into the ``"products"`` collection, followed by the
    review documents into the ``"reviews"`` collection.

    Args:
        products_file (str): Location of the products JSON file.
        reviews_file (str): Location of the product reviews JSON file.
        db_path (str): Directory in which the ChromaDB database is stored.
        model_name (str): Sentence-transformer model used for embeddings.

    Returns:
        The ``VectorDBManager`` instance that was used for the run.
    """
    logger.info("--- Starting RAG ETL Pipeline ---")

    manager = VectorDBManager(db_path=db_path, model_name=model_name)

    # Products go in first.
    logger.info("Step 1: Preparing product documents...")
    manager.populate_collection(
        "products",
        prepare_product_documents(products_file),
    )

    # Review preparation also receives the products file.
    logger.info("Step 2: Preparing review documents...")
    manager.populate_collection(
        "reviews",
        prepare_review_documents(reviews_file, products_file),
    )

    logger.info("--- RAG ETL Pipeline Finished Successfully ---")
    return manager
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: load products.json and product_reviews.json from the
    # current working directory into a ChromaDB database at DB_PATH.
    EMBEDDING_MODEL = 'BAAI/bge-large-en-v1.5'
    DB_PATH = "./chroma_db"

    # Input files are resolved against the directory the script is run from.
    CWD = os.getcwd()
    PRODUCTS_JSON_PATH = os.path.join(CWD, 'products.json')
    REVIEWS_JSON_PATH = os.path.join(CWD, 'product_reviews.json')

    missing_inputs = [
        path for path in (PRODUCTS_JSON_PATH, REVIEWS_JSON_PATH)
        if not os.path.exists(path)
    ]
    if missing_inputs:
        logger.error("Error: Make sure 'products.json' and 'product_reviews.json' exist in the project root directory.")
        logger.error("Missing input file(s): %s", ", ".join(missing_inputs))
        # Exit non-zero so shells and CI jobs can detect that nothing was loaded.
        raise SystemExit(1)

    run_etl_pipeline(
        products_file=PRODUCTS_JSON_PATH,
        reviews_file=REVIEWS_JSON_PATH,
        db_path=DB_PATH,
        model_name=EMBEDDING_MODEL,
    )
|
|
|