Komorebi660 commited on
Commit ·
0e016e5
1
Parent(s): c2e3b84
upload
Browse files- dpr_biencoder.38.602 +3 -0
- dump.py +114 -0
dpr_biencoder.38.602
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:69c370bcdd5630126ab222d4576678143a6c0c57ebb0b90e1ef5a48f2e0a0267
|
| 3 |
+
size 2627976303
|
dump.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
|
| 3 |
+
from torch.serialization import default_restore_location
|
| 4 |
+
|
| 5 |
+
from transformers import BertTokenizer, BertModel
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pickle
|
| 11 |
+
import argparse
|
| 12 |
+
import csv
|
| 13 |
+
|
| 14 |
+
# NOTE(review): nq_temp is never referenced anywhere in this file — looks like
# dead code left over from an earlier version; confirm before removing.
nq_temp = {}

# Container mirroring the dict layout that DPR training writes with torch.save;
# load_states_from_checkpoint unpacks a loaded checkpoint dict into this.
CheckpointState = collections.namedtuple("CheckpointState",
                                         ['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch',
                                          'encoder_params'])
|
| 19 |
+
|
| 20 |
+
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    """Load a saved bi-encoder training state from *model_file*, forcing CPU.

    NOTE(review): torch.load unpickles arbitrary objects — only use this on
    trusted checkpoint files.
    """
    # Remap every storage onto the CPU no matter which device it was saved from.
    cpu_map = lambda storage, location: default_restore_location(storage, 'cpu')
    checkpoint = torch.load(model_file, map_location=cpu_map)
    return CheckpointState(**checkpoint)
|
| 23 |
+
|
| 24 |
+
class DocPool(Dataset):
    """In-memory pool of passages read from a tab-separated (id, text) file."""

    def __init__(self, path):
        # Keep only the text column; row order provides the implicit index.
        # Rows must have exactly two tab-separated fields, as in the original.
        with open(path, "r", encoding="utf8") as f:
            self.doc = [text for _id, text in csv.reader(f, delimiter='\t')]

    def __len__(self):
        return len(self.doc)

    def __getitem__(self, index):
        # Return the positional index together with the passage text.
        return index, self.doc[index]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def my_collate(batch):
    """Collate a list of (index, doc) pairs into {'id': ids, 'doc': docs}.

    Keeps raw text intact (no tensor conversion) so the tokenizer can batch it.
    """
    ids, docs = zip(*batch)
    return {'id': ids, 'doc': docs}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def extract_feature(args):
    # Encode every document (or query) listed in a TSV file with a DPR-style
    # BERT bi-encoder and dump {id: 768-d embedding} to a pickle file.
    # args: expects .doc_or_query, .batch_size, .model_name, .model_file
    # (see the argparse setup in __main__).
    torch.manual_seed(2024)
    torch.cuda.manual_seed(2024)
    np.random.seed(2024)
    # Input path, output path and checkpoint-key prefix depend on whether we
    # encode passages (ctx_model.*) or questions (question_model.*).
    if(args.doc_or_query == 'doc'):
        _path = './nq_doc.tsv'
        _out_path = './doc_embedding.pickle'
        _prefix = 'ctx_model.'
    else:
        _path = './nq_query.tsv'
        _out_path = './query_embedding.pickle'
        _prefix = 'question_model.'
    with torch.no_grad():
        doc_dataset = DocPool(_path)
        print(len(doc_dataset))
        # shuffle=False so row order matches the preallocated feature matrix.
        doc_dataloader = DataLoader(dataset=doc_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=my_collate)
        tokenizer = BertTokenizer.from_pretrained(args.model_name)
        model = BertModel.from_pretrained(args.model_name, return_dict=True)
        saved_state = load_states_from_checkpoint(args.model_file)
        prefix_len = len(_prefix)
        # Keep only the weights of the chosen encoder and strip their prefix so
        # they match plain BertModel parameter names. strict=False tolerates
        # checkpoint-only extras (e.g. projection heads) — TODO confirm nothing
        # essential is silently skipped.
        ctx_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
                     key.startswith(_prefix)}
        model.load_state_dict(ctx_state, strict=False)
        model = torch.nn.DataParallel(model)
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        model = model.to(device)
        model.eval()

        ids = []
        idx = 0
        # 768 = BERT-base hidden size; rows are filled in dataloader order.
        doc_feature = np.zeros((len(doc_dataset), 768))

        for batch_data in tqdm(doc_dataloader):
            doc_id = batch_data['id']
            doc_body = batch_data['doc']
            inputs = tokenizer(doc_body, padding=True, truncation=True, return_tensors="pt",
                               add_special_tokens=True).to(device)
            outputs = model(**inputs)
            # NOTE(review): despite the name, this is the [CLS] token of the
            # last hidden layer (DPR convention), NOT BERT's pooler_output.
            pooler_output = outputs.last_hidden_state[:, 0]

            ids.extend(doc_id)
            doc_feature[idx: idx + pooler_output.shape[0]] = pooler_output.cpu().numpy()
            idx += pooler_output.shape[0]

        # Re-key the dense matrix by dataset index and pickle the result.
        feature_dic = {}
        for i, id_i in enumerate(ids):
            feature_dic[id_i] = doc_feature[i]
        with open(_out_path, 'wb') as f:
            pickle.dump(feature_dic, f)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
    # CLI entry point: parse options and run the embedding dump.
    parser = argparse.ArgumentParser()

    parser.add_argument('--batch_size', type=int, default=512,
                        help='minibatch size')
    parser.add_argument('--model_name', type=str, default='Luyu/co-condenser-wiki',
                        help='model name')
    # BUG FIX: help text was a copy-paste duplicate of 'model name'.
    parser.add_argument('--model_file', type=str, default='./dpr_biencoder.38.602',
                        help='path to the bi-encoder checkpoint file')
    parser.add_argument('--doc_or_query', type=str, default='query',
                        help='transfer documents or queries')

    args = parser.parse_args()
    extract_feature(args)
|