import logging
from typing import List, Dict, Any
import pickle
import nltk
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Okapi
import chromadb
from chromadb.config import Settings
from openai import OpenAI
import pandas as pd
from tqdm import tqdm
from dotenv import load_dotenv
import os
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
class VectorStoreCreator:
    """Create and persist the retrieval indices for dog-food product search.

    Loads an enriched product DataFrame from a pickle file, builds one
    bilingual (EN + ES) description chunk plus a metadata dict per product,
    and materialises two stores:

    * a BM25 lexical index, pickled to disk (``create_bm25_index``), and
    * a persistent ChromaDB collection with OpenAI embeddings
      (``create_chroma_db``).
    """

    def __init__(self, data_path: str) -> None:
        """
        Initialize the VectorStoreCreator.

        Args:
            data_path: Path to the pickle file containing the product data.
        """
        # Pull OPENAI_API_KEY (and related settings) from a local .env when
        # present; OpenAI() below reads the key from the environment.
        load_dotenv()

        # OpenAI client used for embedding generation.
        self.client = OpenAI()

        # word_tokenize needs 'punkt'; NLTK >= 3.8.2 additionally requires
        # the 'punkt_tab' resource, so fetch both (no-ops if cached).
        nltk.download('punkt', quiet=True)
        nltk.download('punkt_tab', quiet=True)

        # NOTE(review): unpickling is arbitrary code execution — only load
        # pickles produced by this project, never untrusted files.
        self.df = pd.read_pickle(data_path)

        # Populated by prepare_data() and the create_* methods.
        self.bm25_model = None
        self.chroma_collection = None
        self.chunks: List[str] = []
        self.metadata: List[Dict[str, Any]] = []

    def prepare_data(self) -> None:
        """Build the text chunks and per-product metadata for both stores."""
        logging.info("Preparing data for vector stores...")

        total_rows = len(self.df)
        logging.info("Total rows in DataFrame: %d", total_rows)

        for _, row in self.df.iterrows():
            # One chunk per product: English + Spanish descriptions combined
            # so a single index serves queries in either language.
            self.chunks.append(
                f"{row['description_en']} {row['description_es']}"
            )
            self.metadata.append({
                "product_name": row["product_name"],
                "brand": row["brand"],
                "dog_type": row["dog_type"],
                "food_type": row["food_type"],
                "weight": float(row["weight"]),
                "price": float(row["price"]),
                # Missing review scores default to 0.0 so the metadata
                # value is always numeric (ChromaDB rejects NaN/None).
                "reviews": float(row["reviews"]) if pd.notna(row["reviews"]) else 0.0,
            })

        logging.info("Total chunks created: %d", len(self.chunks))
        if len(self.chunks) != total_rows:
            logging.warning(
                "Mismatch between DataFrame rows (%d) and chunks created (%d)",
                total_rows, len(self.chunks),
            )
        if self.chunks:
            logging.info("Sample of first chunk: %s...", self.chunks[0][:200])

    def create_bm25_index(self, save_path: str = "bm25_index.pkl") -> None:
        """
        Create and save BM25 index.

        Args:
            save_path: Path to save the BM25 index.
        """
        logging.info("Creating BM25 index...")

        # Lowercase + word-tokenize each chunk; BM25 works on token lists.
        tokenized_chunks = [word_tokenize(chunk.lower()) for chunk in self.chunks]
        self.bm25_model = BM25Okapi(tokenized_chunks)

        # Persist model together with the chunks/metadata it indexes, so a
        # single file is enough to answer queries later.
        with open(save_path, 'wb') as f:
            pickle.dump({
                'model': self.bm25_model,
                'chunks': self.chunks,
                'metadata': self.metadata,
            }, f)
        logging.info("BM25 index saved to %s", save_path)

    def create_chroma_db(self, db_path: str = "chroma_db") -> None:
        """
        Create ChromaDB database.

        Args:
            db_path: Path to save the ChromaDB.
        """
        logging.info("Creating ChromaDB database...")

        client = chromadb.PersistentClient(path=db_path)
        self.chroma_collection = client.get_or_create_collection(
            name="dog_food_descriptions"
        )

        batch_size = 10
        for start in tqdm(range(0, len(self.chunks), batch_size)):
            end = min(start + batch_size, len(self.chunks))
            batch_chunks = self.chunks[start:end]

            # One API call per batch: the embeddings endpoint accepts a list
            # of inputs and returns embeddings in the same order, so this
            # avoids one round-trip per chunk.
            response = self.client.embeddings.create(
                model="text-embedding-ada-002",
                input=batch_chunks,
            )
            embeddings = [item.embedding for item in response.data]

            self.chroma_collection.add(
                embeddings=embeddings,
                metadatas=self.metadata[start:end],
                documents=batch_chunks,
                ids=[str(idx) for idx in range(start, end)],
            )
        logging.info("ChromaDB saved to %s", db_path)
def main(data_path: str = "3rd_clean_comida_dogs_enriched_multilingual_2.pkl") -> None:
    """Build all vector stores from the enriched product pickle.

    Args:
        data_path: Path to the product-data pickle. Defaults to the
            multilingual dataset shipped with the project.

    Raises:
        Exception: Re-raised after logging if any stage fails.
    """
    try:
        creator = VectorStoreCreator(data_path)
        creator.prepare_data()
        creator.create_bm25_index()
        creator.create_chroma_db()
        logging.info("Vector stores created successfully!")
    except Exception:
        # Log the full traceback, then propagate so callers/CI see the failure.
        logging.exception("An error occurred while creating the vector stores")
        raise


if __name__ == "__main__":
    main()