| | from abc import ABC, abstractmethod |
| |
|
| | import pandas as pd |
| | import torch |
| | from datasets import load_from_disk |
| | from sentence_transformers import SentenceTransformer |
| |
|
| | |
| |
|
| |
|
class TextEmbedder(ABC):
    """Abstract base for embedding paragraphs and retrieving them by similarity.

    The paragraph dataset is loaded from disk with ``load_from_disk``; the
    embedding model and the text-to-vector step are supplied by subclasses
    through ``_load_model`` and ``_generate_embeddings``.
    """

    def __init__(self, model_name, paragraphs_path, device, load_existing_index=False):
        """Load the paragraph dataset and the embedding model.

        Args:
            model_name: Forwarded to ``_load_model``.
            paragraphs_path: Directory handed to ``load_from_disk``; when
                ``load_existing_index`` is truthy, ``index.faiss`` inside this
                directory is loaded as the ``"embeddings"`` index.
            device: Forwarded to ``_load_model``.
            load_existing_index: Whether to attach a previously saved FAISS
                index to the dataset.

        Raises:
            AssertionError: If the loaded dataset is empty.
        """
        self.dataset = load_from_disk(paragraphs_path)

        # Validate before loading the model so an empty dataset fails fast.
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # kept so existing callers see the same error.
        if len(self.dataset) == 0:
            raise AssertionError("The loaded dataset is empty !!")

        self.model = self._load_model(model_name, device)

        if load_existing_index:
            self.dataset.load_faiss_index(
                "embeddings", f"{paragraphs_path}/index.faiss"
            )

    def generate_paragraphs_embedding(self):
        """Add an ``"embeddings"`` column computed from each row's ``"content"``."""
        self.dataset = self.dataset.map(
            lambda x: {"embeddings": self._generate_embeddings(x["content"])}
        )

    def save_embeddings(self, output_path):
        """Build a FAISS index over ``"embeddings"`` and write it to disk.

        The index is written to ``{output_path}/index.faiss``. Only the index
        is saved here, not the dataset itself.
        """
        self.dataset.add_faiss_index(column="embeddings")
        self.dataset.save_faiss_index("embeddings", f"{output_path}/index.faiss")

    def retrieve_faiss(self, query: str, k_total: int, threshold: float):
        """Return the passages nearest to ``query`` whose score exceeds ``threshold``.

        Args:
            query: Text to embed and search for.
            k_total: Number of nearest examples requested from the index.
            threshold: Lower bound (exclusive) applied to the rescaled scores.

        Returns:
            A ``(passages, scores)`` pair. ``passages`` is a list of
            ``{"content": ..., "meta": {...}}`` dicts where ``meta`` holds
            every other column of the hit, including ``"scores"``. ``scores``
            is the matching array of rescaled scores, highest first. When
            nothing passes the threshold, ``([], [])`` is returned.
        """
        question_embedding = self._generate_embeddings(query)
        scores, samples = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k_total
        )

        passages_df = pd.DataFrame(samples)
        # NOTE(review): the /100 rescale presumably maps raw index scores to
        # the range ``threshold`` is expressed in -- confirm with callers.
        passages_df["scores"] = scores / 100
        passages_df = passages_df[passages_df["scores"] > threshold]

        if passages_df.empty:
            return [], []

        # NOTE(review): results are ordered highest score first; if the index
        # reports distances (lower is closer) this ordering may need
        # revisiting -- confirm against the index metric in use.
        passages_df = passages_df.sort_values(by=["scores"], ascending=False)

        contents = passages_df["content"].tolist()
        meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
        passages = [
            {"content": content, "meta": record}
            for content, record in zip(contents, meta)
        ]
        return passages, passages_df["scores"].values

    def retrieve_elastic(self, query: str, k_total: int, threshold: int):
        """Placeholder for an Elasticsearch-backed retrieval; not implemented."""
        raise NotImplementedError

    @abstractmethod
    def _load_model(self, model_name: str, device: str):
        """Return the embedding model; the result is stored as ``self.model``."""
        pass

    @abstractmethod
    def _generate_embeddings(self, text: str):
        """Return the embedding for ``text``."""
        pass
| |
|
| |
|
class SentenceTransformersTextEmbedder(TextEmbedder):
    """``TextEmbedder`` implementation backed by a ``SentenceTransformer`` model."""

    def _load_model(self, model_name: str, device: str):
        """Instantiate ``SentenceTransformer(model_name)`` and move it to ``device``."""
        encoder = SentenceTransformer(model_name)
        encoder.to(torch.device(device))
        return encoder

    def _generate_embeddings(self, text: str):
        """Encode ``text`` with the loaded model and return the result."""
        embedding = self.model.encode(text)
        return embedding
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|