Override _encode_sequences and _predict_logits to use HuggingFace tokenizer interface
Browse files- adapter.py +22 -1
adapter.py
CHANGED
|
@@ -191,12 +191,13 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
|
|
| 191 |
self.AbLang.train()
|
| 192 |
|
| 193 |
def _encode_sequences(self, seqs):
|
| 194 |
-
#
|
| 195 |
tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
|
| 196 |
tokens = extract_input_ids(tokens, self.used_device)
|
| 197 |
return self.AbRep(tokens).last_hidden_states.detach()
|
| 198 |
|
| 199 |
def _predict_logits(self, seqs):
|
|
|
|
| 200 |
tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
|
| 201 |
tokens = extract_input_ids(tokens, self.used_device)
|
| 202 |
output = self.AbLang(tokens)
|
|
@@ -204,6 +205,26 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
|
|
| 204 |
return output.last_hidden_state.detach()
|
| 205 |
return output.detach()
|
| 206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
def _preprocess_labels(self, labels):
|
| 208 |
labels = extract_input_ids(labels, self.used_device)
|
| 209 |
return labels
|
|
|
|
| 191 |
self.AbLang.train()
|
| 192 |
|
| 193 |
def _encode_sequences(self, seqs):
    """Encode *seqs* into AbRep hidden states via the HuggingFace tokenizer.

    Overrides the base implementation so tokenization goes through the
    HuggingFace tokenizer interface (``padding=True``, PyTorch tensors)
    instead of the native AbLang tokenizer.
    """
    encoded = self.tokenizer(seqs, padding=True, return_tensors='pt')
    input_ids = extract_input_ids(encoded, self.used_device)
    # Detach so callers get plain tensors with no autograd graph attached.
    return self.AbRep(input_ids).last_hidden_states.detach()
|
| 198 |
|
| 199 |
def _predict_logits(self, seqs):
|
| 200 |
+
# Override to use HuggingFace tokenizer interface
|
| 201 |
tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
|
| 202 |
tokens = extract_input_ids(tokens, self.used_device)
|
| 203 |
output = self.AbLang(tokens)
|
|
|
|
| 205 |
return output.last_hidden_state.detach()
|
| 206 |
return output.detach()
|
| 207 |
|
| 208 |
+
def _predict_logits_with_step_masking(self, seqs):
    """Predict per-position logits by masking one token position at a time.

    Overrides the stepwise masking method to tokenize through the HuggingFace
    tokenizer. For each sequence, a batch is built containing one copy per
    token position, where copy ``i`` has position ``i`` replaced by the mask
    token; the logit vector for position ``i`` is then read from row ``i`` of
    that copy's output.

    NOTE(review): every position of the padded sequence is masked and scored,
    including padding/special-token positions — confirm callers discard those
    downstream.
    """
    encoded = self.tokenizer(seqs, padding=True, return_tensors='pt')
    token_ids = extract_input_ids(encoded, self.used_device)

    mask_id = self.tokenizer.mask_token_id
    per_seq_logits = []
    for seq_tokens in token_ids:
        n_positions = len(seq_tokens)
        # One row per position; row i gets position i masked.
        masked_batch = seq_tokens.repeat(n_positions, 1)
        for pos in range(n_positions):
            masked_batch[pos, pos] = mask_id

        # Inference only — no gradients needed for the masked forward pass.
        with torch.no_grad():
            raw_logits = self.AbLang(masked_batch)

        # Diagonal gather: position i's logits come from the row where i was masked.
        diag = torch.stack([raw_logits[pos, pos] for pos in range(n_positions)])
        per_seq_logits.append(diag)

    return torch.stack(per_seq_logits, dim=0)
|
| 227 |
+
|
| 228 |
def _preprocess_labels(self, labels):
    """Extract label input-ids and move them onto the model's device."""
    return extract_input_ids(labels, self.used_device)
|