Safetensors
daviddongdong commited on
Commit
349d756
·
verified ·
1 Parent(s): 0f0a5a9

Create text_wrapper.py

Browse files
Files changed (1) hide show
  1. text_wrapper.py +305 -0
text_wrapper.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+
5
+ class Sent_Retriever:
6
+ def __init__(self, bs=256, use_gpu=True):
7
+ self.bs = bs
8
+ self.device = torch.device("cuda" if (torch.cuda.is_available() and use_gpu) else "cpu")
9
+
10
+ def embed_passages(self, passages, prefix=""):
11
+ if prefix != "":
12
+ passages = [prefix + item for item in passages]
13
+ embeddings = []
14
+ with torch.no_grad():
15
+ for i in tqdm(range(0, len(passages), self.bs)):
16
+ batch_passage = passages[i:(i + self.bs)]
17
+ emb = self.model.encode(batch_passage, normalize_embeddings=True)
18
+ embeddings.extend(emb)
19
+ return embeddings
20
+
21
+ def score(self, queries, quotes):
22
+ query_emb = np.asarray(self.embed_queries(queries))
23
+ quote_emb = np.asarray(self.embed_quotes(quotes))
24
+ return (query_emb @ quote_emb.T).tolist()
25
+
26
+ def get_tok_len(self, text_input):
27
+ return self.model._first_module().tokenizer(
28
+ text=[text_input],
29
+ truncation=False, max_length=False, return_tensors="pt"
30
+ )["input_ids"].size()[-1]
31
+
32
+
33
+ class BGE(Sent_Retriever):
34
+ def __init__(self, bs=256, use_gpu=True, model_path="checkpoint/bge-large-en-v1.5"):
35
+ from sentence_transformers import SentenceTransformer
36
+ super().__init__(bs=bs, use_gpu=use_gpu)
37
+ self.model_path = model_path
38
+ self.model = SentenceTransformer(self.model_path)
39
+ print("[text_wrapper.py - init] Setting up BGE...")
40
+ print("[text_wrapper.py - init] BGE is loaded from '{}'...".format( self.model_path ))
41
+ self.model.eval()
42
+ self.model = self.model.to(self.device)
43
+
44
+ def embed_queries(self, queries):
45
+ prefix = "Represent this sentence for searching relevant passages:"
46
+ if isinstance(queries, str): queries = [queries]
47
+ return self.embed_passages(queries, prefix)
48
+
49
+ def embed_quotes(self, quotes):
50
+ if isinstance(quotes, str): quotes = [quotes]
51
+ return self.embed_passages(quotes)
52
+
53
+
54
+ class E5(Sent_Retriever):
55
+ def __init__(self, bs=256, use_gpu=True, model_path="checkpoint/e5-large-v2"):
56
+ from sentence_transformers import SentenceTransformer
57
+ super().__init__(bs=bs, use_gpu=use_gpu)
58
+ self.model_path = model_path
59
+ self.model = SentenceTransformer(self.model_path)
60
+ print("[text_wrapper.py - init] Setting up E5...")
61
+ print("[text_wrapper.py - init] E5 is loaded from '{}'...".format( self.model_path ))
62
+ self.model.eval()
63
+ self.model = self.model.to(self.device)
64
+
65
+ def embed_queries(self, queries):
66
+ prefix = "query:"
67
+ if isinstance(queries, str): queries = [queries]
68
+ return self.embed_passages(queries, prefix)
69
+
70
+ def embed_quotes(self, quotes):
71
+ prefix = "passage: "
72
+ if isinstance(quotes, str): quotes = [quotes]
73
+ return self.embed_passages(quotes, prefix)
74
+
75
+
76
+ class GTE(Sent_Retriever):
77
+ def __init__(self, bs=256, use_gpu=True, model_path="checkpoint/gte-large"):
78
+ from sentence_transformers import SentenceTransformer
79
+ super().__init__(bs=bs, use_gpu=use_gpu)
80
+ self.model_path = model_path
81
+ self.model = SentenceTransformer(self.model_path)
82
+ print("[text_wrapper.py - init] Setting up GTE...")
83
+ print("[text_wrapper.py - init] GTE is loaded from '{}'...".format( self.model_path ))
84
+ self.model.eval()
85
+ self.model = self.model.to(self.device)
86
+
87
+ def embed_queries(self, queries):
88
+ if isinstance(queries, str): queries = [queries]
89
+ return self.embed_passages(queries)
90
+
91
+ def embed_quotes(self, quotes):
92
+ if isinstance(quotes, str): quotes = [quotes]
93
+ return self.embed_passages(quotes)
94
+
95
+
96
+
97
+ class Contriever():
98
+ def __init__(self, bs = 256, use_gpu= True):
99
+ from transformers import AutoTokenizer, AutoModel
100
+ self.model_path = 'checkpoint/contriever-msmarco'
101
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
102
+ self.model = AutoModel.from_pretrained(self.model_path)
103
+ self.bs = bs
104
+ self.device = torch.device("cuda" if (torch.cuda.is_available() and use_gpu) else "cpu")
105
+ print("[text_wrapper.py - init] Setting up Contriever...")
106
+ print("[text_wrapper.py - init] Contriever is loaded from '{}'...".format( self.model_path ))
107
+ self.model.eval()
108
+ self.model = self.model.to(self.device)
109
+
110
+ def mean_pooling(self, token_embeddings, mask):
111
+ token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
112
+ sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
113
+ return sentence_embeddings
114
+
115
+ def embed_queries(self, query):
116
+ return self.embed_passages(query)
117
+
118
+ def embed_quotes(self, quotes):
119
+ return self.embed_passages(quotes)
120
+
121
+ def embed_passages(self, quotes):
122
+ if isinstance(quotes, str): quotes = [quotes]
123
+ quote_embeddings = []
124
+ with torch.no_grad():
125
+ for i in tqdm(range(0, len(quotes), self.bs)):
126
+ batch_quotes = quotes[i:(i + self.bs)]
127
+ encoded_quotes = self.tokenizer.batch_encode_plus(
128
+ batch_quotes, return_tensors = "pt",
129
+ max_length = 512, padding = True, truncation = True)
130
+ encoded_data = {k: v.to(self.device) for k, v in encoded_quotes.items()}
131
+ batched_outputs = self.model(**encoded_data)
132
+ batched_quote_embs = self.mean_pooling(batched_outputs[0], encoded_data['attention_mask'])
133
+ quote_embeddings.extend([q.cpu().detach().numpy() for q in batched_quote_embs])
134
+ return quote_embeddings
135
+
136
+ def score(self, query, quotes):
137
+ query_emb = np.asarray(self.embed_queries(query))
138
+ quote_emb = np.asarray(self.embed_quotes(quotes))
139
+ scores = (query_emb @ quote_emb.T).tolist()
140
+ return scores
141
+
142
+
143
+ class DPR():
144
+ def __init__(self, bs = 256, use_gpu= True):
145
+ from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
146
+ self.model_path = "checkpoint/"
147
+ self.query_tok = DPRQuestionEncoderTokenizer.from_pretrained(self.model_path +"dpr-question_encoder-multiset-base")
148
+ self.query_enc = DPRQuestionEncoder.from_pretrained(self.model_path +"dpr-question_encoder-multiset-base")
149
+ self.ctx_tok = DPRContextEncoderTokenizer.from_pretrained(self.model_path +"dpr-ctx_encoder-multiset-base")
150
+ self.ctx_enc = DPRContextEncoder.from_pretrained(self.model_path +"dpr-ctx_encoder-multiset-base")
151
+ self.bs = bs
152
+ print("[text_wrapper.py - init] Setting up DPR...")
153
+ print("[text_wrapper.py - init] DPR is loaded from '{}'...".format( self.model_path ))
154
+ self.device = torch.device("cuda" if (torch.cuda.is_available() and use_gpu) else "cpu")
155
+ self.query_enc.eval()
156
+ self.query_enc = self.query_enc.to(self.device)
157
+ self.ctx_enc.eval()
158
+ self.ctx_enc = self.ctx_enc.to(self.device)
159
+
160
+ def embed_queries(self, queries):
161
+ if isinstance(queries, str): queries = [queries]
162
+ query_embeddings = []
163
+ with torch.no_grad():
164
+ for i in tqdm(range(0, len(queries), self.bs)):
165
+ batch_queries = queries[i:(i + self.bs)]
166
+ encoded_query = self.query_tok.batch_encode_plus(
167
+ batch_queries, truncation=True, padding=True,
168
+ return_tensors='pt', max_length=512)
169
+ encoded_data = {k : v.cuda() for k, v in encoded_query.items()}
170
+ query_emb = self.query_enc(**encoded_data).pooler_output
171
+ query_emb = [q.cpu().detach().numpy() for q in query_emb]
172
+ query_embeddings.extend(query_emb)
173
+ return query_embeddings
174
+
175
+ def embed_quotes(self, quotes):
176
+ if isinstance(quotes, str): quotes = [quotes]
177
+ quote_embeddings = []
178
+ with torch.no_grad():
179
+ for i in tqdm(range(0, len(quotes), self.bs)):
180
+ batch_quotes = quotes[i:(i + self.bs)]
181
+ encoded_ctx = self.ctx_tok.batch_encode_plus(
182
+ batch_quotes, truncation=True, padding=True,
183
+ return_tensors='pt', max_length=512)
184
+ encoded_data = {k: v.cuda() for k, v in encoded_ctx.items()}
185
+ quote_emb = self.ctx_enc(**encoded_data).pooler_output
186
+ quote_emb = [q.cpu().detach().numpy() for q in quote_emb]
187
+ quote_embeddings.extend(quote_emb)
188
+ return quote_embeddings
189
+
190
+ def score(self, query, quotes):
191
+ query_emb = np.asarray(self.embed_queries(query))
192
+ quote_emb = np.asarray(self.embed_quotes(quotes))
193
+ scores = (query_emb @ quote_emb.T).tolist()
194
+ return scores
195
+
196
+
197
+ class ColBERTReranker:
198
+ def __init__(self, bs = 256, use_gpu= True):
199
+ from colbert.modeling.colbert import ColBERT
200
+ from colbert.infra import ColBERTConfig
201
+ from transformers import AutoTokenizer
202
+ self.model_path = "checkpoint/colbertv2.0"
203
+ self.bs = bs
204
+ config = ColBERTConfig(bsize=bs, root='./', query_token_id='[Q]', doc_token_id='[D]')
205
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
206
+ self.model = ColBERT(name=self.model_path, colbert_config=config)
207
+ self.doc_token_id = self.tokenizer.convert_tokens_to_ids(config.doc_token_id)
208
+ self.query_token_id = self.tokenizer.convert_tokens_to_ids(config.query_token_id)
209
+ self.add_special_tokens = True
210
+ self.device = torch.device("cuda" if (torch.cuda.is_available() and use_gpu) else "cpu")
211
+ print("[text_wrapper.py - init] Setting up ColBERT Reranker...")
212
+ print("[text_wrapper.py - init] ColBERT is loaded from '{}'...".format( self.model_path ))
213
+ self.model.eval()
214
+ self.model = self.model.to(self.device)
215
+
216
+ def embed_queries(self, queries):
217
+ if isinstance(queries, str): queries = [queries]
218
+ query_embeddings = []
219
+ query = ['. ' + item for item in queries] # placeholder for query emb
220
+ with torch.no_grad():
221
+ for i in tqdm(range(0, len(queries), self.bs)):
222
+ batch_queries = queries[i:(i + self.bs)]
223
+ encoded_query = self.tokenizer.batch_encode_plus(
224
+ batch_queries, max_length = 512, padding=True, truncation=True,
225
+ add_special_tokens=self.add_special_tokens, return_tensors='pt')
226
+ encoded_data = {k: v.to(self.device) for k, v in encoded_query.items()}
227
+ encoded_data['input_ids'][:, 1] = self.query_token_id
228
+ batch_query_emb = self.model.query(encoded_data['input_ids'], encoded_data['attention_mask'])
229
+
230
+ for emb, mask in zip(batch_query_emb, encoded_data['attention_mask']):
231
+ length = mask.sum().item() # Number of true tokens in this sequence
232
+ np_emb = emb[:length].cpu().numpy() # Shape: [L, H]
233
+ query_embeddings.append(np_emb) # `L` varies per example
234
+
235
+ # torch.cuda.empty_cache()
236
+ return query_embeddings
237
+
238
+ @staticmethod
239
+ def pad_tok_len(quote_embeddings, pad_value=0):
240
+ lengths = [e.shape[0] for e in quote_embeddings]
241
+ max_len = max(lengths)
242
+ N, H = len(quote_embeddings), quote_embeddings[0].shape[1]
243
+ padded_embeddings = np.full((N, max_len, H), pad_value, dtype=quote_embeddings[0].dtype)
244
+ padded_masks = np.zeros((N, max_len), dtype=np.int64)
245
+ for i, (emb, length) in enumerate(zip(quote_embeddings, lengths)):
246
+ padded_embeddings[i, :length, :] = emb
247
+ padded_masks[i, :length] = 1
248
+ return padded_embeddings, padded_masks
249
+
250
+ def embed_quotes(self, quotes, pad_token_len = False):
251
+ quote_embeddings = []
252
+ quote_masks = []
253
+ quotes = ['. ' + quote for quote in quotes]
254
+ with torch.no_grad():
255
+ # placeholder for query emb
256
+ for i in tqdm(range(0, len(quotes), self.bs)):
257
+ batch_quotes = quotes[i:(i + self.bs)]
258
+ encoded_quotes = self.tokenizer.batch_encode_plus(
259
+ batch_quotes, return_tensors = "pt",
260
+ max_length = 512, padding = True, truncation = True)
261
+ encoded_data = {k: v.to(self.device) for k, v in encoded_quotes.items()}
262
+ encoded_data['input_ids'][:, 1] = self.doc_token_id
263
+ # bz x # max num_token in batch x 128
264
+ batched_quote_embs = self.model.doc(encoded_data['input_ids'], encoded_data['attention_mask'])
265
+
266
+ for emb, mask in zip(batched_quote_embs, encoded_data['attention_mask']):
267
+ length = mask.sum().item() # Number of true tokens in this sequence
268
+ np_emb = emb[:length].cpu().numpy() # Shape: [L, H]
269
+ quote_embeddings.append(np_emb) # `L` varies per example
270
+
271
+ # max length of quotes could differ between different batches
272
+ if pad_token_len:
273
+ quote_embeddings, quote_masks = self.pad_tok_len(quote_embeddings)
274
+ return quote_embeddings, quote_masks
275
+ return quote_embeddings
276
+
277
+
278
+ @staticmethod
279
+ def colbert_score(query_embed, quote_embeddings, quote_masks):
280
+ Q, H = query_embed.shape # [Q, H]
281
+ N, L, _ = quote_embeddings.shape # [N, L, H]
282
+ # 1. Compute [Q, N, L] (similarity btw every query token to every quote token)
283
+ # Expand query to [Q, 1, 1, H], quote_embeddings to [1, N, L, H]
284
+ query_expanded = query_embed[:, np.newaxis, np.newaxis, :] # [Q, 1, 1, H]
285
+ quote_expanded = quote_embeddings[np.newaxis, :, :, :] # [1, N, L, H]
286
+ sim = np.matmul(query_expanded, np.transpose(quote_expanded, (0 ,1 ,3 ,2))) # (Q, N, 1, L)
287
+ # But let's use broadcasting for dot product:
288
+ # sim[q, n, l] = np.dot(query_embed[q], quote_embeddings[n,l])
289
+ sim = np.einsum('qh,nlh->qnl', query_embed, quote_embeddings) # [Q, N, L]
290
+ # 2. Mask invalid tokens
291
+ sim = np.where(quote_masks[np.newaxis, :, : ]==1, sim, -1e9) # [Q, N, L]
292
+ # 3. MaxSim: For each query token, take max over quote tokens (L dimension)
293
+ maxsim = sim.max(-1) # [Q, N]
294
+ # 4. Aggregate (sum over query tokens)
295
+ scores = maxsim.sum(axis=0) # [N]
296
+ return scores
297
+
298
+ def score(self, query, quotes):
299
+ query_embeddings = self.embed_queries(query)
300
+ quote_embeddings, quote_masks = self.embed_quotes(quotes, pad_token_len=True)
301
+ scores_list = []
302
+ for query_embed in query_embeddings:
303
+ scores = self.colbert_score(query_embed, quote_embeddings, quote_masks)
304
+ scores_list.append(scores.tolist())
305
+ return scores_list