import heapq

from datasets import load_dataset
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util

from config import CONFIG


class Retriever:
    """Chunk retriever over the ``abstract`` field of a ``datasets`` dataset.

    The methods must be called in order:

    1. :meth:`load_and_prepare_dataset` builds ``self.corpus``, a flat list of
       word chunks taken from the configured dataset's train split.
    2. :meth:`prepare_bm25` builds a lexical BM25 index over that corpus,
       and/or :meth:`compute_embeddings` encodes it with a sentence-transformers
       model.
    3. :meth:`retrieve_documents_bm25` / :meth:`retrieve_documents_semantic`
       return the top-scoring chunks for a query.

    Calling a later step before its prerequisite raises ``RuntimeError``.
    """

    # Checkpoint used by compute_embeddings when no model_name is given.
    DEFAULT_MODEL_NAME = 'all-MiniLM-L6-v2'

    def __init__(self):
        self.corpus = None            # list[str] of chunks; set by load_and_prepare_dataset
        self.bm25 = None              # BM25Okapi index; set by prepare_bm25
        self.model = None             # SentenceTransformer; set by compute_embeddings
        self.chunk_embeddings = None  # encoded corpus (tensor); set by compute_embeddings

    @staticmethod
    def _tokenize(text):
        """Split *text* into BM25 tokens on any run of whitespace."""
        # str.split() with no separator never yields empty tokens and also
        # splits on tabs/newlines, so corpus and query are tokenized alike.
        return text.split()

    def _require(self, attr, setup_method):
        """Raise RuntimeError if ``self.<attr>`` has not been built yet."""
        if getattr(self, attr) is None:
            raise RuntimeError(
                f"Retriever.{attr} is not initialised; call {setup_method}() first."
            )

    def load_and_prepare_dataset(self):
        """Load ``CONFIG['DATASET']``, chunk its abstracts and store them in ``self.corpus``.

        At most ``CONFIG['MAX_NUM_OF_RECORDS']`` records of the ``train`` split
        are used; each record's ``abstract`` is cut into word chunks by
        :meth:`chunk_text`, and all chunks are flattened into one list.
        """
        dataset = load_dataset(CONFIG['DATASET'])
        train = dataset['train']
        # Dataset.select() rejects indices past the end, so never ask for
        # more records than the split actually has.
        num_records = min(CONFIG['MAX_NUM_OF_RECORDS'], len(train))
        train = train.select(range(num_records))
        # A missing (None) abstract is treated as empty text -> no chunks.
        train = train.map(lambda x: {'chunks': self.chunk_text(x['abstract'] or '')})
        self.corpus = [chunk for chunks in train['chunks'] for chunk in chunks]

    def prepare_bm25(self):
        """Build ``self.bm25``, a BM25Okapi index over ``self.corpus``.

        Raises:
            RuntimeError: if the corpus has not been loaded yet.
        """
        self._require('corpus', 'load_and_prepare_dataset')
        tokenized_corpus = [self._tokenize(doc) for doc in self.corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def compute_embeddings(self, model_name=None):
        """Load a SentenceTransformer and encode ``self.corpus`` into ``self.chunk_embeddings``.

        Args:
            model_name: sentence-transformers checkpoint to load; ``None``
                (default) uses :attr:`DEFAULT_MODEL_NAME`.

        Raises:
            RuntimeError: if the corpus has not been loaded yet.
        """
        self._require('corpus', 'load_and_prepare_dataset')
        if model_name is None:
            model_name = self.DEFAULT_MODEL_NAME
        self.model = SentenceTransformer(model_name)
        tokenizer = self.model._first_module().tokenizer
        # NOTE(review): presumably kept so checkpoints whose tokenizer has no
        # pad token (GPT-style) can still be batch-encoded; reuses EOS as pad.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.chunk_embeddings = self.model.encode(self.corpus, convert_to_tensor=True)

    def chunk_text(self, text, chunk_size=None):
        """Split *text* into consecutive chunks of at most *chunk_size* words.

        Args:
            text: input string; words are separated by any whitespace.
            chunk_size: words per chunk. ``None`` (default) reads
                ``CONFIG['CHUNK_SIZE']`` at call time.

        Returns:
            list[str]: chunks whose words are joined by single spaces; the last
            chunk may be shorter. Blank text gives ``[]``.

        Raises:
            ValueError: if *chunk_size* is not positive.
        """
        if chunk_size is None:
            chunk_size = CONFIG['CHUNK_SIZE']
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size!r}")
        words = text.split()
        return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

    def retrieve_documents_bm25(self, query, top_k=None):
        """Return the *top_k* corpus chunks with the highest BM25 score for *query*.

        Args:
            query: free-text query, tokenized on whitespace like the corpus.
            top_k: number of chunks to return; ``None`` (default) uses
                ``CONFIG['TOP_DOCS']``. Fewer are returned if the corpus is smaller.

        Returns:
            list[str]: chunks ordered by descending score; ties keep corpus order.

        Raises:
            RuntimeError: if :meth:`prepare_bm25` has not been called.
        """
        self._require('bm25', 'prepare_bm25')
        if top_k is None:
            top_k = CONFIG['TOP_DOCS']
        scores = self.bm25.get_scores(self._tokenize(query))
        # nlargest(k, ...) == sorted(..., reverse=True)[:k] (same tie order)
        # without sorting the entire corpus.
        top_docs = heapq.nlargest(top_k, range(len(scores)), key=lambda i: scores[i])
        return [self.corpus[i] for i in top_docs]

    def retrieve_documents_semantic(self, query, top_k=None):
        """Return the *top_k* corpus chunks most cosine-similar to *query*.

        Args:
            query: free-text query, encoded with the same model as the corpus.
            top_k: number of chunks to return; ``None`` (default) uses
                ``CONFIG['TOP_DOCS']``. Clamped to the corpus size.

        Returns:
            list[str]: chunks ordered by descending similarity.

        Raises:
            RuntimeError: if :meth:`compute_embeddings` has not been called.
        """
        self._require('chunk_embeddings', 'compute_embeddings')
        if top_k is None:
            top_k = CONFIG['TOP_DOCS']
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        scores = util.pytorch_cos_sim(query_embedding, self.chunk_embeddings)[0]
        # Tensor.topk raises when k exceeds the number of elements, so clamp
        # it to the corpus size (the BM25 path truncates the same way).
        k = max(0, min(top_k, len(scores)))
        top_chunks = scores.topk(k).indices.tolist()
        return [self.corpus[i] for i in top_chunks]