yeomtong committed on
Commit
7dcc257
·
verified ·
1 Parent(s): d29dbea

Update predictor_up.py

Browse files
Files changed (1) hide show
  1. predictor_up.py +15 -4
predictor_up.py CHANGED
@@ -250,7 +250,7 @@ def predict_srl_single(model, tokenizer, words, predicate_word_idx, id2label, de
250
  return tags, logits.squeeze(0).cpu()
251
 
252
 
253
- def _encode_sentence_once(words, tokenizer, max_length=500):
254
  enc = tokenizer(
255
  words,
256
  is_split_into_words=True,
@@ -264,13 +264,24 @@ def _encode_sentence_once(words, tokenizer, max_length=500):
264
  sent_wp_ids = enc["input_ids"]
265
  if isinstance(sent_wp_ids[0], list):
266
  sent_wp_ids = sent_wp_ids[0]
 
267
  wid = enc.word_ids()
 
268
  first_pos = {}
 
269
  for pos, w in enumerate(wid):
270
  if w is not None and w not in first_pos:
271
- first_pos[w] = pos + 1 # +1 for [CLS]
272
- n_words = len(words)
273
- word_first = torch.tensor([first_pos[i] for i in range(n_words)], dtype=torch.long)
 
 
 
 
 
 
 
 
274
  return sent_wp_ids, word_first, n_words
275
 
276
  @torch.no_grad()
 
250
  return tags, logits.squeeze(0).cpu()
251
 
252
 
253
+ def _encode_sentence_once(words, tokenizer, max_length=500):
254
  enc = tokenizer(
255
  words,
256
  is_split_into_words=True,
 
264
  sent_wp_ids = enc["input_ids"]
265
  if isinstance(sent_wp_ids[0], list):
266
  sent_wp_ids = sent_wp_ids[0]
267
+
268
  wid = enc.word_ids()
269
+
270
  first_pos = {}
271
+ kept_word_ids = []
272
  for pos, w in enumerate(wid):
273
  if w is not None and w not in first_pos:
274
+ first_pos[w] = pos + 1 # +1 for [CLS]
275
+ kept_word_ids.append(w)
276
+
277
+ kept_word_ids = sorted(kept_word_ids)
278
+ n_words = len(kept_word_ids)
279
+
280
+ word_first = torch.tensor(
281
+ [first_pos[w] for w in kept_word_ids],
282
+ dtype=torch.long
283
+ )
284
+
285
  return sent_wp_ids, word_first, n_words
286
 
287
  @torch.no_grad()