Spaces:
Sleeping
Sleeping
Merge branch 'main' of https://huggingface.co/spaces/supertskone/prompt-search-engine
aad4def
unverified
| import os | |
| import logging | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from datasets import load_dataset | |
| from pinecone import Pinecone, ServerlessSpec | |
| # Disable parallelism for tokenizers | |
| os.environ['TOKENIZERS_PARALLELISM'] = 'false' | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class Vectorizer:
    """Encode text prompts with a SentenceTransformer model and persist /
    retrieve them through a Pinecone serverless index.

    Parameters
    ----------
    model_name : str
        SentenceTransformer model to load. The default, all-mpnet-base-v2,
        produces 768-dim embeddings, matching the index dimension below.
    batch_size : int
        Number of items per encode / fetch / upsert batch.
    init_pinecone : bool
        When True, create the Pinecone index if it does not already exist.
    """

    def __init__(self, model_name='all-mpnet-base-v2', batch_size=64, init_pinecone=True):
        logger.info(f"Initializing Vectorizer with model {model_name} and batch size {batch_size}")
        self.model = SentenceTransformer(model_name)
        self.prompts = []
        self.batch_size = batch_size
        self.pinecone_index_name = "hfs-search-prompts-index"
        self._init_pinecone = init_pinecone
        self._setup_pinecone()
        self._load_prompts()

    def _setup_pinecone(self):
        """Connect to Pinecone and ensure the target index exists and is open."""
        logger.info("Setting up Pinecone")
        # SECURITY: the API key used to be hard-coded here (a leaked secret).
        # Prefer the environment; the fallback literal only keeps legacy
        # deployments alive and should be rotated and removed.
        api_key = os.environ.get(
            'PINECONE_API_KEY', 'b514eb66-8626-4697-8a1c-4c411c06c090')
        pinecone = Pinecone(api_key=api_key)
        # list_indexes() returns an IndexList object; .names() yields plain
        # strings so the membership test below actually works (testing against
        # the raw IndexList never matched in the original code).
        existing_indexes = pinecone.list_indexes().names()
        if self.pinecone_index_name not in existing_indexes:
            # Only log creation when we are actually allowed to create.
            if self._init_pinecone:
                logger.info(f"Creating Pinecone index: {self.pinecone_index_name}")
                pinecone.create_index(
                    name=self.pinecone_index_name,
                    dimension=768,  # all-mpnet-base-v2 embedding size
                    metric='cosine',
                    spec=ServerlessSpec(
                        cloud="aws",
                        region="us-east-1"
                    )
                )
        else:
            # BUG FIX: the original called delete_index() here, wiping every
            # stored vector on startup and then opening a deleted index.
            # An existing index must simply be reused.
            logger.info(f"Pinecone index {self.pinecone_index_name} already exists")
        self.index = pinecone.Index(self.pinecone_index_name)

    def _load_prompts(self):
        """Populate self.prompts from the 'text' metadata of stored vectors.

        Assumes vector ids are the stringified integers 0..vector_count-1,
        which is exactly how _store_prompts writes them.
        """
        logger.info("Loading prompts from Pinecone")
        self.prompts = []
        index_stats = self.index.describe_index_stats()
        logger.info(f"Index stats: {index_stats}")
        namespaces = index_stats['namespaces']
        for namespace, stats in namespaces.items():
            vector_count = stats['vector_count']
            ids = [str(i) for i in range(vector_count)]
            # Fetch in batches to stay under Pinecone's per-request limits.
            for i in range(0, vector_count, self.batch_size):
                batch_ids = ids[i:i + self.batch_size]
                response = self.index.fetch(ids=batch_ids)
                for vector in response.vectors.values():
                    metadata = vector.get('metadata')
                    if metadata and 'text' in metadata:
                        self.prompts.append(metadata['text'])
        logger.info(f"Loaded {len(self.prompts)} prompts from Pinecone")

    def _store_prompts(self, dataset):
        """Encode every prompt in `dataset` (a list of str) and upsert in batches."""
        logger.info("Storing prompts in Pinecone")
        # Ceiling division: the original `len // batch + 1` over-counted by one
        # whenever the length divided evenly.
        total_batches = (len(dataset) + self.batch_size - 1) // self.batch_size
        for i in range(0, len(dataset), self.batch_size):
            batch = dataset[i:i + self.batch_size]
            vectors = self.model.encode(batch)
            # Ids are global dataset positions so _load_prompts can fetch
            # them back as "0".."N-1".
            pinecone_data = [
                {'id': str(i + j),
                 'values': vector.tolist(),
                 'metadata': {'text': batch[j]}}
                for j, vector in enumerate(vectors)
            ]
            self.index.upsert(vectors=pinecone_data)
            logger.info(f"Upserted batch {i // self.batch_size + 1}/{total_batches} to Pinecone")

    def transform(self, prompts):
        """Return the embeddings of `prompts` as a numpy array."""
        return np.array(self.model.encode(prompts))

    def store_from_dataset(self, store_data=False):
        """Optionally ingest the laion-art dataset into Pinecone, then reload prompts.

        No-op when store_data is False; __init__ already loaded the prompts.
        """
        if store_data:
            logger.info("Loading dataset")
            dataset = load_dataset('fantasyfish/laion-art', split='train')
            logger.info(f"Loaded {len(dataset)} items from dataset")
            logger.info("Please wait for storing. This may take up to five minutes. ")
            self._store_prompts([item['text'] for item in dataset])
            logger.info("Items from dataset are stored.")
            # Refresh the in-memory prompt cache after the bulk upsert.
            self._load_prompts()
            logger.info("Items from dataset are loaded.")