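"""Extract dense embeddings for NQ documents or queries with a trained
DPR-style bi-encoder.

The checkpoint layout (a ``CheckpointState`` namedtuple whose model weights
carry ``ctx_model.`` / ``question_model.`` key prefixes) follows
facebookresearch/DPR. Each input line is encoded as the [CLS] vector of the
last hidden layer, and the result is pickled as a dict mapping dataset index
to a 768-d numpy array. Pass ``--doc_or_query doc`` to embed passages; the
default (``query``) embeds queries.
"""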
import collections

from torch.serialization import default_restore_location

from transformers import BertTokenizer, BertModel
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
import pickle
import argparse
import csv


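# Training-state container matching the one DPR uses when saving checkpoints.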
CheckpointState = collections.namedtuple("CheckpointState",
                                         ['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch',
                                          'encoder_params'])

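# Deserialize a DPR training checkpoint, mapping every tensor onto the CPU.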
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    return CheckpointState(**state_dict)

class DocPool(Dataset):
    def __init__(self, path):
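        # Each TSV line is expected to hold two columns: <id>\t<text>.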
        doc = []
        with open(path, "r", encoding="utf8") as f:
            lines = csv.reader(f, delimiter='\t')
            for _id, _text in lines:
                doc.append(_text)
        self.doc = doc

    def __len__(self):
        return len(self.doc)

    def __getitem__(self, index):
        doc = self.doc[index]
        return index, doc


def my_collate(batch):
    # Transpose a list of (index, doc) pairs into parallel lists.
    ids, docs = zip(*batch)
    return {'id': list(ids), 'doc': list(docs)}


def extract_feature(args):
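    # Fix random seeds for reproducibility.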
    torch.manual_seed(2024)
    torch.cuda.manual_seed(2024)
    np.random.seed(2024)
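    # Select the input file, the output file, and which tower of the
    # bi-encoder to load: the context encoder for documents, the question
    # encoder for queries.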
    if args.doc_or_query == 'doc':
        _path = './nq_doc.tsv'
        _out_path = './doc_embedding.pickle'
        _prefix = 'ctx_model.'
    else:
        _path = './nq_query.tsv'
        _out_path = './query_embedding.pickle'
        _prefix = 'question_model.'
    with torch.no_grad():
        doc_dataset = DocPool(_path)
        print(f'{len(doc_dataset)} examples loaded')
        doc_dataloader = DataLoader(dataset=doc_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=my_collate)
        tokenizer = BertTokenizer.from_pretrained(args.model_name)
        model = BertModel.from_pretrained(args.model_name, return_dict=True)
        saved_state = load_states_from_checkpoint(args.model_file)
        prefix_len = len(_prefix)
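        # Keep only the selected encoder's weights and strip the prefix so
        # the key names match a plain BertModel; strict=False tolerates
        # leftover DPR-specific keys (e.g. an optional projection layer).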
        ctx_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
                     key.startswith(_prefix)}
        model.load_state_dict(ctx_state, strict=False)
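        # DataParallel shards each batch across all visible GPUs and acts as
        # a passthrough to the wrapped module when none are available.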
        model = torch.nn.DataParallel(model)
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        model = model.to(device)
        model.eval()

        ids = []
        idx = 0
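        # Pre-allocate one row per example; 768 is the BERT-base hidden size.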
        doc_feature = np.zeros((len(doc_dataset), 768))

        for batch_data in tqdm(doc_dataloader):
            doc_id = batch_data['id']
            doc_body = batch_data['doc']
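            # Tokenize the batch, truncating to the model's maximum length
            # (512 tokens for BERT) and padding to the longest sequence.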
            inputs = tokenizer(doc_body, padding=True, truncation=True, return_tensors="pt",
                               add_special_tokens=True).to(device)
            outputs = model(**inputs)
            # DPR-style pooling: take the [CLS] token's last hidden state,
            # not BERT's tanh-activated pooler output.
            cls_embedding = outputs.last_hidden_state[:, 0]

            ids.extend(doc_id)
            doc_feature[idx: idx + cls_embedding.shape[0]] = cls_embedding.cpu().numpy()
            idx += cls_embedding.shape[0]

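        # Pickle a dict mapping each dataset index to its embedding row.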
        feature_dic = {}
        for i, id_i in enumerate(ids):
            feature_dic[id_i] = doc_feature[i]
        with open(_out_path, 'wb') as f:
            pickle.dump(feature_dic, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--batch_size', type=int, default=512,
                        help='minibatch size')
    parser.add_argument('--model_name', type=str, default='Luyu/co-condenser-wiki',
                        help='Hugging Face model providing the tokenizer and architecture')
    parser.add_argument('--model_file', type=str, default='./dpr_biencoder.38.602',
                        help='path to the trained bi-encoder checkpoint')
    parser.add_argument('--doc_or_query', type=str, default='query',
                        help="encode documents ('doc') or queries ('query')")

    args = parser.parse_args()
    extract_feature(args)