# Source: LitBench-UI / src/retriever/retriever.py
# Last commit: 908351f (verified) "Upload 22 files" by Andreas Varvarigos
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import json
import torch
from tqdm import tqdm
import os
import pandas as pd
import numpy as np
from datasets import load_dataset
from utils.utils import read_yaml_file
def generate_topic_level_embeddings(model, tokenizer, paper_list, tmp_id_2_abs, batch_size=2048):
    """Embed every paper's topic labels level by level and cache them as parquet files.

    For each topic level ('Level 1', 'Level 2', 'Level 3'), the paper's topics are
    concatenated into one comma-terminated string (e.g. ``['a', 'b']`` -> ``'a,b,'``).
    That string is tokenized and passed through ``model``, and only the hidden state of
    the first token is kept. Each level's embeddings are written to
    ``datasets/topic_level_embeds/<level>_emb.parquet``. Their element-wise sum is written
    to ``datasets/topic_level_embeds/arxiv_papers_embeds.parquet``. Every file has a
    ``paper_id`` and an ``embedding`` column, with rows in ``paper_list`` order.

    Args:
        model: transformers model whose output exposes ``last_hidden_state``. Inputs
            are moved to ``model.device``.
        tokenizer: tokenizer matching ``model``.
        paper_list: paper ids to embed, in output row order. Every id must appear in
            ``tmp_id_2_abs['train']`` (otherwise a KeyError is raised).
        tmp_id_2_abs: dataset dict whose 'train' rows carry ``paper_id`` and the
            columns 'Level 1'..'Level 3'. Each of those columns presumably holds a
            list of topic strings.
        batch_size: number of papers encoded per forward pass.

    Yields:
        A progress value before each batch: ``offset + done_fraction / 3``, with
        offsets 0, 0.33 and 0.66 for the three levels.

    Raises:
        ValueError: if ``paper_list`` is empty or ``batch_size`` is not positive.
    """
    if not paper_list:
        raise ValueError("paper_list must contain at least one paper id")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    levels = ('Level 1', 'Level 2', 'Level 3')
    # Progress is split across the three passes at 0 / 0.33 / 0.66.
    progress_offsets = (0.0, 0.33, 0.66)
    out_dir = 'datasets/topic_level_embeds'
    os.makedirs(out_dir, exist_ok=True)

    id2topics = {
        entry["paper_id"]: [entry[level] for level in levels]
        for entry in tmp_id_2_abs['train']
    }
    num_papers = len(paper_list)
    device = model.device

    def _save(embs, filename):
        # One row per paper: its id and its embedding as a 1-D array.
        df = pd.DataFrame({
            "paper_id": paper_list,
            "embedding": list(embs.numpy())
        })
        df.to_parquet(os.path.join(out_dir, filename), engine='pyarrow', compression='snappy')

    level_embs = []
    for level_idx, (topic_level, offset) in enumerate(zip(levels, progress_offsets)):
        batches = []
        # The context manager closes the progress bar even if the generator is abandoned.
        with tqdm(total=num_papers) as pbar:
            for start in range(0, num_papers, batch_size):
                yield offset + start / num_papers / 3
                paper_batch = paper_list[start:start + batch_size]
                texts = [
                    ''.join(t + ',' for t in id2topics[paper_id][level_idx])
                    for paper_id in paper_batch
                ]
                inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
                with torch.no_grad():
                    outputs = model(**inputs.to(device))
                # First-token hidden state: (batch, hidden_size). No hard-coded width,
                # so embedders whose hidden size is not 1024 keep rows aligned.
                batches.append(outputs.last_hidden_state[:, 0, :].cpu())
                pbar.update(len(paper_batch))
        level_emb = torch.cat(batches, 0)
        _save(level_emb, f'{topic_level}_emb.parquet')
        level_embs.append(level_emb)

    # Combined representation: element-wise sum of the three level embeddings,
    # computed from the in-memory tensors instead of re-reading the files just written.
    _save(level_embs[0] + level_embs[1] + level_embs[2], 'arxiv_papers_embeds.parquet')
def retriever(query, retrieval_nodes_path):
    """Rank arXiv papers against ``query`` and save the top hits as JSON.

    This is a generator used for progress reporting. It yields ``0`` on start. It
    forwards the progress values of ``generate_topic_level_embeddings`` when the
    candidate embeddings have to be built locally. It yields ``1.0`` once
    ``retrieval_nodes_path`` has been written.

    Settings are read from ``configs/config.yaml`` under the ``retriever`` key:
    - ``embedder``: model name.
    - ``load_arxiv_embeds``: download precomputed candidate embeddings instead of
      using the local parquet cache.
    - ``num_retrievals``: how many papers to keep.

    Args:
        query: free-text query to embed and compare against the papers.
        retrieval_nodes_path: output JSON path. The file maps each selected paper id
            to ``true``, in descending similarity order.

    Raises:
        ValueError: if fewer candidate embeddings than papers are available, so the
            scores cannot be paired with every paper id.
    """
    yield 0
    config = read_yaml_file('configs/config.yaml')
    retriever_cfg = config['retriever']

    # Load the embedder. Use the GPU in float16 when available. Otherwise fall back
    # to the CPU in float32 (half-precision support on CPU is limited in some PyTorch
    # versions) instead of failing on a hard-coded 'cuda' device.
    if torch.cuda.is_available():
        device, dtype = 'cuda', torch.float16
    else:
        device, dtype = 'cpu', torch.float32
    embedder_name = retriever_cfg['embedder']
    tokenizer = AutoTokenizer.from_pretrained(embedder_name)
    model = AutoModel.from_pretrained(embedder_name).to(device=device, dtype=dtype)

    # Load the arXiv dataset; the row order of its 'train' split fixes paper_list.
    tmp_id_2_abs = load_dataset("AliMaatouk/arXiv_Topics", cache_dir="datasets/arxiv_topics")
    paper_list = list(tmp_id_2_abs['train']['paper_id'])

    # Embed the query using the first-token hidden state, the same pooling
    # generate_topic_level_embeddings applies to the candidates.
    inputs = tokenizer([query], return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs.to(device))
    query_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()

    # Candidate embeddings: either download them, or load the local parquet cache
    # (building it first if it is missing).
    if retriever_cfg['load_arxiv_embeds']:
        dataset = load_dataset("AliMaatouk/arXiv-Topics-Embeddings", cache_dir="datasets/topic_level_embeds")
        table = dataset["train"].data  # underlying PyArrow-backed table
        # Presumably to_numpy() yields one array per row here; stack them into a 2-D matrix.
        all_candidate_embs = np.stack(table.column("embedding").to_numpy())
    else:
        embeds_path = 'datasets/topic_level_embeds/arxiv_papers_embeds.parquet'
        if not os.path.exists(embeds_path):
            yield from generate_topic_level_embeddings(model, tokenizer, paper_list, tmp_id_2_abs)
        all_candidate_embs = np.stack(pd.read_parquet(embeds_path)['embedding'].tolist())

    similarity_scores = cosine_similarity(query_embeddings, all_candidate_embs)[0]
    # NOTE(review): scores are paired with paper ids by position. This assumes the
    # embeddings are stored in the same row order as the arXiv_Topics 'train' split;
    # confirm this for the downloaded dataset.
    if len(similarity_scores) < len(paper_list):
        raise ValueError(
            f"Got {len(similarity_scores)} candidate embeddings for {len(paper_list)} papers"
        )
    scores = similarity_scores[:len(paper_list)]

    # Stable descending sort: ties keep dataset order, exactly as
    # sorted(..., reverse=True) would. Done in NumPy instead of on a Python list of pairs.
    ranking = np.argsort(-scores, kind='stable')
    top_k = ranking[:retriever_cfg['num_retrievals']]
    papers_results = {paper_list[idx]: True for idx in top_k}

    with open(retrieval_nodes_path, 'w', encoding='utf-8') as f:
        json.dump(papers_results, f)
    yield 1.0