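"""Retrieve the arXiv papers most relevant to a query by embedding each paper's
hierarchical topic labels and ranking papers by cosine similarity to the query."""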
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import json
import torch
from tqdm import tqdm
import os
import pandas as pd
import numpy as np
from datasets import load_dataset
from utils.utils import read_yaml_file


def generate_topic_level_embeddings(model, tokenizer, paper_list, tmp_id_2_abs):
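    """Embed every paper's topic labels at each of the three hierarchy levels,
    save one parquet file per level, and save their element-wise sum as the
    final per-paper embeddings. Yields progress fractions in [0, 1) so the
    caller can report progress.
    """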
    id2topics = {
        entry["paper_id"]: [entry["Level 1"], entry["Level 2"], entry["Level 3"]]
        for entry in tmp_id_2_abs['train']
    }

    # Each of the three levels accounts for roughly a third of the overall progress.
    level_offsets = {'Level 1': 0.0, 'Level 2': 0.33, 'Level 3': 0.66}

    for level_idx, topic_level in enumerate(['Level 1', 'Level 2', 'Level 3']):
        i = 0
        batch_size = 2048
        candidate_emb_list = []
        pbar = tqdm(total=len(paper_list))
        while i < len(paper_list):
            yield level_offsets[topic_level] + i / len(paper_list) / 3
            paper_batch = paper_list[i:i+batch_size]
            paper_text_batch = []
            for paper_id in paper_batch:
                # Join the paper's topic labels at the current level into one string.
                topics = id2topics[paper_id][level_idx]
                paper_text_batch.append(','.join(topics))
            inputs = tokenizer(paper_text_batch, return_tensors='pt', padding=True, truncation=True)
            with torch.no_grad():
                outputs = model(**inputs.to('cuda'))
                # Use the [CLS] (first-token) hidden state as the paper's embedding;
                # the slice is already (batch_size, hidden_dim), so no reshape is needed.
                candidate_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
                candidate_emb_list.append(candidate_embeddings)

                i += len(candidate_embeddings)
                pbar.update(len(candidate_embeddings))
        pbar.close()

        all_candidate_embs = torch.cat(candidate_emb_list, 0)
        
        df = pd.DataFrame({
            "paper_id": paper_list,
            "embedding": list(all_candidate_embs.numpy())
        })
        
        os.makedirs('datasets/topic_level_embeds', exist_ok=True)

        df.to_parquet(f'datasets/topic_level_embeds/{topic_level}_emb.parquet', engine='pyarrow', compression='snappy')
        
    all_candidate_embs_L1 = torch.tensor(np.array(pd.read_parquet('datasets/topic_level_embeds/Level 1_emb.parquet')['embedding'].tolist()))
    all_candidate_embs_L2 = torch.tensor(np.array(pd.read_parquet('datasets/topic_level_embeds/Level 2_emb.parquet')['embedding'].tolist()))
    all_candidate_embs_L3 = torch.tensor(np.array(pd.read_parquet('datasets/topic_level_embeds/Level 3_emb.parquet')['embedding'].tolist()))
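    # Element-wise sum blends the three topic granularities into a single vector per paper.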
    all_candidate_embs = all_candidate_embs_L1 + all_candidate_embs_L2 + all_candidate_embs_L3
    
    df = pd.DataFrame({
        "paper_id": paper_list,
        "embedding": list(all_candidate_embs.numpy())
    })
    
    df.to_parquet('datasets/topic_level_embeds/arxiv_papers_embeds.parquet', engine='pyarrow', compression='snappy')



def retriever(query, retrieval_nodes_path):    
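    """Rank all arXiv papers by cosine similarity between the query embedding and
    the papers' topic-level embeddings, then write the top-K paper ids as JSON to
    retrieval_nodes_path. Yields progress fractions in [0, 1].
    """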
    yield 0
    config = read_yaml_file('configs/config.yaml')

    # Load the model and tokenizer to generate the embeddings
    embedder_name = config['retriever']['embedder']
    tokenizer = AutoTokenizer.from_pretrained(embedder_name)
    model = AutoModel.from_pretrained(embedder_name).to(device='cuda', dtype=torch.float16)


    # Load the arXiv dataset
    tmp_id_2_abs = load_dataset("AliMaatouk/arXiv_Topics", cache_dir="datasets/arxiv_topics")
    paper_list = list(tmp_id_2_abs['train']['paper_id'])
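    # The order of paper_list must match the row order of the candidate embeddings below.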


    # Generate the query embeddings
    inputs = tokenizer([query], return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs.to('cuda'))
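        # Same [CLS] (first-token) pooling as used for the candidate embeddings.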
        query_embeddings = outputs.last_hidden_state[:, 0, :].cpu()

    # Generate the candidate embeddings
    # Load the embeddings from the dataset, otherwise generate the embeddings and save them
    if config['retriever']['load_arxiv_embeds']:
        dataset = load_dataset("AliMaatouk/arXiv-Topics-Embeddings", cache_dir="datasets/topic_level_embeds")
        table = dataset["train"].data  # Get PyArrow Table
        all_candidate_embs = table.column("embedding").to_numpy()
    else:
        # Generate and save the embeddings if the parquet file does not exist, then load them from disk
        if not os.path.exists('datasets/topic_level_embeds/arxiv_papers_embeds.parquet'):
            yield from generate_topic_level_embeddings(model, tokenizer, paper_list, tmp_id_2_abs)
        
        all_candidate_embs = np.array(pd.read_parquet('datasets/topic_level_embeds/arxiv_papers_embeds.parquet')['embedding'].tolist())

    # Ensure a dense 2-D (num_papers, hidden_dim) matrix; the Hugging Face branch
    # above yields an object array of per-paper vectors.
    all_candidate_embs = np.stack(all_candidate_embs)


    # Calculate the cosine similarity between the query and all candidate embeddings
    query_embeddings = np.array(query_embeddings)
    similarity_scores = cosine_similarity(query_embeddings, all_candidate_embs)[0]


    # Sort the papers by similarity scores and select the top K papers
    id_score_list = list(zip(paper_list, similarity_scores))
    sorted_scores = sorted(id_score_list, key=lambda pair: pair[1], reverse=True)
    top_K_paper = [paper_id for paper_id, _ in sorted_scores[:config['retriever']['num_retrievals']]]

    papers_results = {paper: True for paper in top_K_paper}

    with open(retrieval_nodes_path, 'w') as f:
        json.dump(papers_results, f)

    yield 1.0
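

if __name__ == '__main__':
    # Minimal usage sketch; the query string and output path below are illustrative
    # placeholders, not values from the original pipeline. retriever() is a
    # generator, so it must be iterated for the retrieval to actually run.
    for progress in retriever('graph neural networks for wireless communications',
                              'retrieval_nodes.json'):
        print(f'retrieval progress: {progress:.0%}')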