# NOTE: removed Hugging Face Spaces page-scrape residue ("Spaces: Sleeping")
# that was not part of the original source file.
| import logging | |
| from typing import List, Dict, Any | |
| import pickle | |
| import nltk | |
| from nltk.tokenize import word_tokenize | |
| from rank_bm25 import BM25Okapi | |
| import chromadb | |
| from chromadb.config import Settings | |
| from openai import OpenAI | |
| import pandas as pd | |
| from tqdm import tqdm | |
| from dotenv import load_dotenv | |
| import os | |
# Configure root logging once for the whole run: timestamped INFO records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
class VectorStoreCreator:
    """Create and persist the retrieval indices for dog food product search.

    Builds two complementary stores from a pickled pandas DataFrame of
    products with bilingual (English/Spanish) descriptions:

    * a BM25 lexical index (rank_bm25), saved as a pickle, and
    * a ChromaDB collection with OpenAI ``text-embedding-ada-002`` vectors.

    Typical usage: ``prepare_data()`` first, then either/both index builders.
    """

    def __init__(self, data_path: str) -> None:
        """
        Initialize the VectorStoreCreator.

        Args:
            data_path: Path to the pickle file containing the product data.
        """
        # Load OPENAI_API_KEY (and any other settings) from a local .env
        # file when present. The import already existed but the call had
        # been commented out, leaving the OpenAI client without credentials.
        load_dotenv()

        # OpenAI client for embedding requests; reads OPENAI_API_KEY from
        # the environment.
        self.client = OpenAI()

        # Tokenizer model required by nltk.word_tokenize().
        nltk.download('punkt', quiet=True)

        # NOTE(review): read_pickle executes arbitrary pickled code — only
        # load trusted files.
        self.df = pd.read_pickle(data_path)

        # Populated by prepare_data(); consumed by the index builders.
        self.bm25_model = None
        self.chroma_collection = None
        self.chunks: List[str] = []
        self.metadata: List[Dict[str, Any]] = []

    def prepare_data(self) -> None:
        """Build text chunks and per-product metadata from the DataFrame.

        Each row becomes one chunk combining the English and Spanish
        descriptions, so a single index serves both languages. Safe to call
        repeatedly: the chunk/metadata lists are reset on every call.
        """
        logging.info("Preparing data for vector stores...")

        total_rows = len(self.df)
        logging.info(f"Total rows in DataFrame: {total_rows}")

        # Reset so repeated calls do not silently duplicate every product
        # (the previous version only ever appended).
        self.chunks = []
        self.metadata = []

        for _, row in self.df.iterrows():
            # One chunk per product: both language descriptions concatenated.
            self.chunks.append(f"{row['description_en']} {row['description_es']}")

            self.metadata.append({
                "product_name": row["product_name"],
                "brand": row["brand"],
                "dog_type": row["dog_type"],
                "food_type": row["food_type"],
                "weight": float(row["weight"]),
                "price": float(row["price"]),
                # Missing review scores become 0.0 so the field is always a
                # plain float (ChromaDB metadata values must be scalars).
                "reviews": float(row["reviews"]) if pd.notna(row["reviews"]) else 0.0,
            })

        logging.info(f"Total chunks created: {len(self.chunks)}")
        if len(self.chunks) != total_rows:
            logging.warning(f"Mismatch between DataFrame rows ({total_rows}) and chunks created ({len(self.chunks)})")
        if self.chunks:
            logging.info(f"Sample of first chunk: {self.chunks[0][:200]}...")

    def create_bm25_index(self, save_path: str = "bm25_index.pkl") -> None:
        """
        Create and save BM25 index.

        Args:
            save_path: Path to save the BM25 index.

        Raises:
            ValueError: If no chunks exist (prepare_data() not run yet) —
                BM25Okapi would otherwise fail with an opaque division error.
        """
        logging.info("Creating BM25 index...")

        if not self.chunks:
            raise ValueError("No chunks available; call prepare_data() first")

        # Lowercased word tokens are what BM25 scores against.
        tokenized_chunks = [word_tokenize(chunk.lower()) for chunk in self.chunks]
        self.bm25_model = BM25Okapi(tokenized_chunks)

        # Persist the model together with the raw chunks and metadata so a
        # loader can map scores back to products without re-reading the data.
        with open(save_path, 'wb') as f:
            pickle.dump({
                'model': self.bm25_model,
                'chunks': self.chunks,
                'metadata': self.metadata
            }, f)
        logging.info(f"BM25 index saved to {save_path}")

    def create_chroma_db(self, db_path: str = "chroma_db") -> None:
        """
        Create ChromaDB database.

        Args:
            db_path: Path to save the ChromaDB.
        """
        logging.info("Creating ChromaDB database...")

        # PersistentClient is the current (non-deprecated) on-disk client.
        client = chromadb.PersistentClient(path=db_path)
        self.chroma_collection = client.get_or_create_collection(
            name="dog_food_descriptions"
        )

        # Embed and insert in batches to bound request size and memory.
        batch_size = 10
        for i in tqdm(range(0, len(self.chunks), batch_size)):
            batch_chunks = self.chunks[i:i + batch_size]
            batch_metadata = self.metadata[i:i + batch_size]
            batch_ids = [str(idx) for idx in range(i, min(i + batch_size, len(self.chunks)))]

            # One batched embeddings request per batch (the API accepts a
            # list input and returns vectors in input order) instead of one
            # HTTP round-trip per chunk.
            response = self.client.embeddings.create(
                model="text-embedding-ada-002",
                input=batch_chunks
            )
            embeddings = [item.embedding for item in response.data]

            self.chroma_collection.add(
                embeddings=embeddings,
                metadatas=batch_metadata,
                documents=batch_chunks,
                ids=batch_ids
            )
        logging.info(f"ChromaDB saved to {db_path}")
def main() -> None:
    """Build both vector stores (BM25 + ChromaDB) from the product pickle."""
    try:
        creator = VectorStoreCreator("3rd_clean_comida_dogs_enriched_multilingual_2.pkl")
        creator.prepare_data()
        creator.create_bm25_index()
        creator.create_chroma_db()
        logging.info("Vector stores created successfully!")
    except Exception as e:
        # logging.exception records the full traceback (logging.error did
        # not); lazy %-args avoid f-string formatting when logging is off.
        logging.exception("An error occurred: %s", e)
        raise


if __name__ == "__main__":
    main()