hemantn committed on
Commit 88ddbb1 · 1 Parent(s): ed12887

Override _encode_sequences and _predict_logits to use HuggingFace tokenizer interface
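The overridden methods rely on the standard HuggingFace tokenizer call pattern rather than AbLang2's native tokenizer. As a hedged sketch of that interface (the concrete tokenizer class is not shown in this diff; `tokenizer` below stands in for whatever the adapter loads):

```python
# Standard HuggingFace tokenizer call, as used by the overridden methods:
# padding=True pads to the longest sequence in the batch, and
# return_tensors='pt' returns PyTorch tensors in a dict-like BatchEncoding.
seqs = ["EVQLVESGGGLVQ", "DIQMTQSPSSLSA"]        # placeholder sequences
tokens = tokenizer(seqs, padding=True, return_tensors="pt")

input_ids = tokens["input_ids"]                  # (batch, max_len) LongTensor
attention_mask = tokens["attention_mask"]        # 1 = real token, 0 = padding
```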

Files changed (1)
  1. adapter.py +22 -1
adapter.py CHANGED
@@ -191,12 +191,13 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
         self.AbLang.train()
 
     def _encode_sequences(self, seqs):
-        # Use HuggingFace-style padding and return PyTorch tensors
+        # Override to use HuggingFace tokenizer interface
         tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
         tokens = extract_input_ids(tokens, self.used_device)
         return self.AbRep(tokens).last_hidden_states.detach()
 
     def _predict_logits(self, seqs):
+        # Override to use HuggingFace tokenizer interface
         tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
         tokens = extract_input_ids(tokens, self.used_device)
         output = self.AbLang(tokens)
@@ -204,6 +205,26 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
             return output.last_hidden_state.detach()
         return output.detach()
 
+    def _predict_logits_with_step_masking(self, seqs):
+        # Override the stepwise masking method to use HuggingFace tokenizer
+        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
+        tokens = extract_input_ids(tokens, self.used_device)
+
+        logits = []
+        for single_seq_tokens in tokens:
+            tkn_len = len(single_seq_tokens)
+            masked_tokens = single_seq_tokens.repeat(tkn_len, 1)
+            for num in range(tkn_len):
+                masked_tokens[num, num] = self.tokenizer.mask_token_id
+
+            with torch.no_grad():
+                logits_tmp = self.AbLang(masked_tokens)
+
+            logits_tmp = torch.stack([logits_tmp[num, num] for num in range(tkn_len)])
+            logits.append(logits_tmp)
+
+        return torch.stack(logits, dim=0)
+
     def _preprocess_labels(self, labels):
         labels = extract_input_ids(labels, self.used_device)
         return labels
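`extract_input_ids` is referenced throughout but not defined in this diff. A minimal sketch consistent with its call sites here (it receives either a tokenizer `BatchEncoding` or raw label ids and must yield a tensor on `self.used_device`); the body below is an assumption, not the repository's implementation:

```python
import torch

def extract_input_ids(tokens, device):
    # Hypothetical reconstruction: pull input_ids out of a HuggingFace
    # BatchEncoding (or accept a raw tensor) and move them to `device`.
    if isinstance(tokens, torch.Tensor):
        return tokens.to(device)
    return tokens["input_ids"].to(device)
```

The new `_predict_logits_with_step_masking` follows the usual pseudo-log-likelihood recipe for masked language models: each position is masked one at a time and the logits at that position are collected. Assuming `self.AbLang` returns a raw logits tensor here (as the final `return output.detach()` branch of `_predict_logits` suggests), the result can be scored against the unmasked ids; a sketch, with `model` as an adapter instance:

```python
import torch.nn.functional as F

seqs = ["EVQLVESGGGLVQ"]                          # placeholder input
logits = model._predict_logits_with_step_masking(seqs)   # (batch, seq, vocab)

ids = model.tokenizer(seqs, padding=True, return_tensors="pt")["input_ids"]
ids = ids.to(logits.device)                       # (batch, seq)

# Log-probability of each true token at its singly-masked position.
log_probs = F.log_softmax(logits, dim=-1)
pll = log_probs.gather(-1, ids.unsqueeze(-1)).squeeze(-1)  # (batch, seq)
```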