Komorebi660 commited on
Commit
0e016e5
·
1 Parent(s): c2e3b84
Files changed (2) hide show
  1. dpr_biencoder.38.602 +3 -0
  2. dump.py +114 -0
dpr_biencoder.38.602 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69c370bcdd5630126ab222d4576678143a6c0c57ebb0b90e1ef5a48f2e0a0267
3
+ size 2627976303
dump.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+
3
+ from torch.serialization import default_restore_location
4
+
5
+ from transformers import BertTokenizer, BertModel
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import pickle
11
+ import argparse
12
+ import csv
13
+
14
# Scratch mapping for NQ bookkeeping; not referenced anywhere in this script.
nq_temp = {}

# Mirrors the on-disk layout of a DPR training checkpoint: the encoder
# weights plus optimizer/scheduler state and training-progress metadata.
CheckpointState = collections.namedtuple(
    "CheckpointState",
    [
        "model_dict",
        "optimizer_dict",
        "scheduler_dict",
        "offset",
        "epoch",
        "encoder_params",
    ],
)
19
+
20
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    """Deserialize a DPR bi-encoder checkpoint onto the CPU.

    Args:
        model_file: Path to a checkpoint produced by ``torch.save`` whose
            top-level dict keys match the ``CheckpointState`` fields.

    Returns:
        A ``CheckpointState`` populated from the loaded dict.
    """

    def _to_cpu(storage, _location):
        # Remap every storage to CPU regardless of the device it was saved
        # from, so loading works on machines without the original GPUs.
        return default_restore_location(storage, 'cpu')

    checkpoint = torch.load(model_file, map_location=_to_cpu)
    return CheckpointState(**checkpoint)
23
+
24
class DocPool(Dataset):
    """Dataset over the text column of a two-column (id, text) TSV file.

    The id column of the file is discarded; ``__getitem__`` pairs each
    document with its zero-based row position instead.
    """

    def __init__(self, path):
        with open(path, "r", encoding="utf8") as tsv:
            reader = csv.reader(tsv, delimiter='\t')
            self.doc = [text for _rowid, text in reader]

    def __len__(self):
        return len(self.doc)

    def __getitem__(self, index):
        return index, self.doc[index]
40
+
41
+
42
def my_collate(batch):
    """Collate a list of (index, doc) pairs into columnar form.

    Returns ``{'id': (indices...), 'doc': (docs...)}`` with the raw text
    strings untouched, so the tokenizer can batch-encode them later
    instead of the DataLoader's default tensor collation.
    """
    columns = list(zip(*batch))
    return {'id': columns[0], 'doc': columns[1]}
47
+
48
+
49
def extract_feature(args):
    """Encode every document (or query) in a fixed TSV file into a 768-d
    dense vector with a DPR bi-encoder checkpoint, then pickle the results.

    Args:
        args: argparse.Namespace with ``batch_size``, ``model_name``
            (HuggingFace model id), ``model_file`` (bi-encoder checkpoint
            path) and ``doc_or_query`` ('doc' selects the context encoder;
            any other value selects the question encoder).

    Side effects:
        Writes ``{id: np.ndarray of shape (768,)}`` to
        ``./doc_embedding.pickle`` or ``./query_embedding.pickle``.
        Downloads the pretrained model/tokenizer if not cached.
    """
    # Fix all RNG seeds for reproducibility of this dump run.
    torch.manual_seed(2024)
    torch.cuda.manual_seed(2024)
    np.random.seed(2024)
    # Input path, output path, and checkpoint key prefix all follow from
    # which side of the bi-encoder we are dumping.
    if(args.doc_or_query == 'doc'):
        _path = './nq_doc.tsv'
        _out_path = './doc_embedding.pickle'
        _prefix = 'ctx_model.'
    else:
        _path = './nq_query.tsv'
        _out_path = './query_embedding.pickle'
        _prefix = 'question_model.'
    # Inference only: disable autograd for the whole pipeline.
    with torch.no_grad():
        doc_dataset = DocPool(_path)
        print(len(doc_dataset))
        # shuffle=False keeps row order so `idx` below tracks dataset order;
        # my_collate returns raw strings for batch tokenization.
        doc_dataloader = DataLoader(dataset=doc_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=my_collate)
        tokenizer = BertTokenizer.from_pretrained(args.model_name)
        model = BertModel.from_pretrained(args.model_name, return_dict=True)
        saved_state = load_states_from_checkpoint(args.model_file)
        # Keep only this encoder's weights and strip the bi-encoder prefix
        # ('ctx_model.' / 'question_model.') so keys match BertModel's.
        prefix_len = len(_prefix)
        ctx_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
                     key.startswith(_prefix)}
        # strict=False: the checkpoint may lack some BertModel parameters
        # (e.g. the pooler); those keep their pretrained values.
        model.load_state_dict(ctx_state, strict=False)
        # NOTE(review): DataParallel is applied unconditionally; on a
        # CPU-only machine it wraps a CPU module — presumably harmless,
        # but confirm on the target setup.
        model = torch.nn.DataParallel(model)
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        model = model.to(device)
        model.eval()

        ids = []
        idx = 0  # write cursor into doc_feature, advanced per batch
        # 768 matches BERT-base hidden size — TODO confirm for other model_name values.
        doc_feature = np.zeros((len(doc_dataset), 768))

        for batch_data in tqdm(doc_dataloader):
            doc_id = batch_data['id']
            doc_body = batch_data['doc']
            inputs = tokenizer(doc_body, padding=True, truncation=True, return_tensors="pt",
                               add_special_tokens=True).to(device)
            outputs = model(**inputs)
            # DPR-style representation: the [CLS] token's final hidden state,
            # not the tanh-activated pooler_output, despite the variable name.
            pooler_output = outputs.last_hidden_state[:, 0]

            ids.extend(doc_id)
            doc_feature[idx: idx + pooler_output.shape[0]] = pooler_output.cpu().numpy()
            idx += pooler_output.shape[0]

        # Re-key the dense matrix by document id (here: dataset row index,
        # since DocPool yields positions as ids) and persist it.
        feature_dic = {}
        for i, id_i in enumerate(ids):
            feature_dic[id_i] = doc_feature[i]
        with open(_out_path, 'wb') as f:
            pickle.dump(feature_dic, f)
99
+
100
+
101
if __name__ == "__main__":
    # CLI entry point: parse options and dump embeddings for one side
    # (documents or queries) of the DPR bi-encoder.
    parser = argparse.ArgumentParser(
        description='Dump dense DPR embeddings for NQ documents or queries.')

    parser.add_argument('--batch_size', type=int, default=512,
                        help='minibatch size')
    parser.add_argument('--model_name', type=str, default='Luyu/co-condenser-wiki',
                        help='model name')
    # Fixed copy-paste defect: the help text previously read 'model name',
    # but this flag takes the path to the trained bi-encoder checkpoint.
    parser.add_argument('--model_file', type=str, default='./dpr_biencoder.38.602',
                        help='path to the trained bi-encoder checkpoint file')
    parser.add_argument('--doc_or_query', type=str, default='query',
                        help='transfer documents or queries')

    args = parser.parse_args()
    extract_feature(args)