Spaces:
Configuration error
Configuration error
| import os | |
| import pickle | |
| from collections import defaultdict | |
| from typing import List, Tuple | |
| import numpy as np | |
| import scipy | |
| import torch | |
| import tqdm | |
| from loguru import logger | |
| from transformers import AutoModelForMaskedLM, AutoTokenizer | |
| from app.config.models.configs import Config, Document | |
| from app.utils import torch_device, split | |
class SpladeSparseVectorDB:
    """Sparse vector store backed by the SPLADE masked-language model.

    Documents are encoded into vocabulary-sized sparse vectors, persisted to
    disk as a SciPy CSR matrix plus pickled id/metadata lists, and queried by
    cosine similarity against an encoded search string.
    """

    # Hugging Face model id used for both the tokenizer and the encoder.
    _MODEL_NAME = "naver/splade-v3"

    def __init__(
        self,
        config: "Config",
    ) -> None:
        """Load the tokenizer/model and prepare empty in-memory indices.

        Args:
            config: Application config; ``embeddings.embeddings_path`` and
                ``embeddings.splade_config.n_batch`` are read from it.
        """
        self._config = config

        # cuda or mps or cpu
        self._device = torch_device()
        logger.info(f"Setting device to {self._device}")

        self.tokenizer = AutoTokenizer.from_pretrained(
            self._MODEL_NAME, device=self._device, use_fast=True
        )
        self.model = AutoModelForMaskedLM.from_pretrained(self._MODEL_NAME)
        self.model.to(self._device)

        # Populated lazily by `load()`.
        self._embeddings = None
        self._ids = None
        self._metadatas = None
        self._l2_norm_matrix = None
        self._labels_to_ind = defaultdict(list)
        self._chunk_size_to_ind = defaultdict(list)

        self.n_batch = config.embeddings.splade_config.n_batch

    def _get_batch_embeddings(self, docs: List[str]) -> np.ndarray:
        """Encode a batch of texts into SPLADE vectors.

        Each vector is the max over token positions of ``log(1 + relu(logits))``
        with padded positions zeroed out by the attention mask.

        Args:
            docs: Texts to encode (must be non-empty).

        Returns:
            Array with one row per text. Because of the trailing ``squeeze()``,
            a single-text batch yields a 1-D vector rather than a 1-row matrix;
            `query` relies on that.
        """
        tokens = self.tokenizer(
            docs, return_tensors="pt", padding=True, truncation=True
        ).to(self._device)

        # Inference only: skip building the autograd graph to save memory/time.
        with torch.no_grad():
            output = self.model(**tokens)
            activations = torch.log(1 + torch.relu(output.logits))
            masked = activations * tokens.attention_mask.unsqueeze(-1)
            pooled = torch.max(masked, dim=1)[0]

        vecs = pooled.squeeze().cpu().numpy()

        # Drop references to the large tensors as early as possible.
        del output, tokens, activations, masked, pooled
        return vecs

    def _get_embedding_fnames(self):
        """Return ``(folder, embeddings_file, ids_file, metadatas_file)`` paths."""
        folder_name = os.path.join(self._config.embeddings.embeddings_path, "splade")
        fn_embeddings = os.path.join(folder_name, "splade_embeddings.npz")
        fn_ids = os.path.join(folder_name, "splade_ids.pickle")
        fn_metadatas = os.path.join(folder_name, "splade_metadatas.pickle")
        return folder_name, fn_embeddings, fn_ids, fn_metadatas

    def load(self) -> None:
        """Load persisted embeddings, ids and metadata and rebuild the indices.

        Instance state is only replaced once every file has been read, so a
        failed load leaves the previous state intact. Calling this repeatedly
        rebuilds the label/chunk-size indices instead of appending to them.

        Raises:
            FileNotFoundError: If any of the persisted files is missing.
        """
        _, fn_embeddings, fn_ids, fn_metadatas = self._get_embedding_fnames()

        try:
            embeddings = scipy.sparse.load_npz(fn_embeddings)
            # NOTE(security): pickle executes arbitrary code on load; these
            # files must only ever come from `persist_embeddings`.
            with open(fn_ids, "rb") as fp:
                ids = np.array(pickle.load(fp))
            with open(fn_metadatas, "rb") as fm:
                metadatas = np.array(pickle.load(fm))
        except FileNotFoundError as err:
            raise FileNotFoundError("Embeddings don't exist") from err

        labels_to_ind = defaultdict(list)
        chunk_size_to_ind = defaultdict(list)
        for ind, m in enumerate(metadatas):
            if m["label"]:
                labels_to_ind[m["label"]].append(ind)
            # Every row is indexed by chunk size, labelled or not.
            chunk_size_to_ind[m["chunk_size"]].append(ind)

        self._embeddings = embeddings
        self._ids = ids
        self._metadatas = metadatas
        self._l2_norm_matrix = scipy.sparse.linalg.norm(embeddings, axis=1)
        self._labels_to_ind = labels_to_ind
        self._chunk_size_to_ind = chunk_size_to_ind

        logger.info(f"SPLADE: Got {len(self._labels_to_ind)} labels.")
        logger.info(f"Loaded sparse embeddings from {fn_embeddings}")

    def generate_embeddings(
        self, docs: "List[Document]", persist: bool = True
    ) -> Tuple[np.ndarray, List[str], List[dict]]:
        """Encode documents in batches of ``self.n_batch`` and optionally persist.

        Documents with empty ``page_content`` are skipped *together with* their
        ids and metadata, so row ``i`` of the returned matrix always belongs to
        ``ids[i]`` / ``metadatas[i]``.

        Args:
            docs: Documents to encode; ``metadata["document_id"]`` is the id.
            persist: When true, write the results via `persist_embeddings`.

        Returns:
            ``(embeddings, ids, metadatas)`` for the non-empty documents.

        Raises:
            ValueError: If no document has non-empty content.
        """
        chunk_size = self.n_batch

        non_empty = [d for d in docs if d.page_content]
        n_dropped = len(docs) - len(non_empty)
        if n_dropped:
            logger.warning(f"SPLADE: Skipping {n_dropped} documents with empty content.")
        if not non_empty:
            raise ValueError("No documents with non-empty content to embed")

        ids = [d.metadata["document_id"] for d in non_empty]
        metadatas = [d.metadata for d in non_empty]

        # Ceiling division so the progress bar accounts for a partial last batch.
        n_batches = (len(non_empty) + chunk_size - 1) // chunk_size
        vecs = [
            self._get_batch_embeddings([d.page_content for d in chunk])
            for chunk in tqdm.tqdm(
                split(non_empty, chunk_size=chunk_size), total=n_batches
            )
        ]

        embeddings = np.vstack(vecs)
        if persist:
            self.persist_embeddings(embeddings, metadatas, ids)
        return embeddings, ids, metadatas

    def persist_embeddings(self, embeddings, metadatas, ids):
        """Write embeddings (as CSR ``.npz``), ids and metadata (pickled) to disk."""
        folder_name, fn_embeddings, fn_ids, fn_metadatas = self._get_embedding_fnames()
        csr_embeddings = scipy.sparse.csr_matrix(embeddings)

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(folder_name, exist_ok=True)

        scipy.sparse.save_npz(fn_embeddings, csr_embeddings)
        self.save_list(ids, fn_ids)
        self.save_list(metadatas, fn_metadatas)
        logger.info(f"Saved embeddings to {fn_embeddings}")

    def query(
        self, search: str, chunk_size: int, n: int = 50, label: str = ""
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return the ``n`` stored rows most similar to ``search``.

        Candidates are restricted to rows with the given ``chunk_size``; when
        ``label`` is non-empty and known, also to rows carrying that label (an
        unknown label applies no label filter).

        Args:
            search: Query text.
            chunk_size: Chunk size whose rows should be searched.
            n: Maximum number of results.
            label: Optional label filter.

        Returns:
            ``(ids, cosine_similarities)`` sorted by descending similarity.
            Both arrays are empty when ``n <= 0`` or no row matches the filter.
        """
        if self._embeddings is None or self._ids is None:
            logger.info("Loading embeddings...")
            self.load()

        # `.get` avoids inserting a key into the defaultdict for unknown sizes.
        candidates = set(self._chunk_size_to_ind.get(chunk_size, ()))
        if label and label in self._labels_to_ind:
            candidates &= set(self._labels_to_ind[label])
        indices = sorted(candidates)

        # `arr[-0:]` is the whole array, so n == 0 must be handled explicitly.
        if n <= 0 or not indices:
            return self._ids[:0], np.array([])

        embeddings = self._embeddings[indices]
        ids = self._ids[indices]
        l2_norm_matrix = scipy.sparse.linalg.norm(embeddings, axis=1)

        embed_query = self._get_batch_embeddings(docs=[search])
        l2_norm_query = scipy.linalg.norm(embed_query)

        scores = embeddings.dot(embed_query)
        denominator = l2_norm_matrix * l2_norm_query
        # All-zero vectors get similarity 0 instead of NaN (NaN would sort to
        # the top of the ranking).
        cosine_similarity = np.divide(
            scores,
            denominator,
            out=np.zeros(scores.shape, dtype=np.result_type(scores, denominator)),
            where=denominator != 0,
        )

        top_similar_indices = np.argsort(cosine_similarity)[-n:][::-1]
        return (
            ids[top_similar_indices],
            cosine_similarity[top_similar_indices],
        )

    def save_list(self, list_: list, fname: str) -> None:
        """Pickle ``list_`` to the file ``fname``."""
        with open(fname, "wb") as fp:
            pickle.dump(list_, fp)