Update predictor_up.py
Browse files- predictor_up.py +15 -4
predictor_up.py
CHANGED
|
@@ -250,7 +250,7 @@ def predict_srl_single(model, tokenizer, words, predicate_word_idx, id2label, de
|
|
| 250 |
return tags, logits.squeeze(0).cpu()
|
| 251 |
|
| 252 |
|
| 253 |
-
def _encode_sentence_once(words, tokenizer, max_length=500):
|
| 254 |
enc = tokenizer(
|
| 255 |
words,
|
| 256 |
is_split_into_words=True,
|
|
@@ -264,13 +264,24 @@ def _encode_sentence_once(words, tokenizer, max_length=500):
|
|
| 264 |
sent_wp_ids = enc["input_ids"]
|
| 265 |
if isinstance(sent_wp_ids[0], list):
|
| 266 |
sent_wp_ids = sent_wp_ids[0]
|
|
|
|
| 267 |
wid = enc.word_ids()
|
|
|
|
| 268 |
first_pos = {}
|
|
|
|
| 269 |
for pos, w in enumerate(wid):
|
| 270 |
if w is not None and w not in first_pos:
|
| 271 |
-
first_pos[w] = pos + 1
|
| 272 |
-
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
return sent_wp_ids, word_first, n_words
|
| 275 |
|
| 276 |
@torch.no_grad()
|
|
|
|
| 250 |
return tags, logits.squeeze(0).cpu()
|
| 251 |
|
| 252 |
|
| 253 |
+
def _encode_sentence_once(words, tokenizer, max_length=500):
|
| 254 |
enc = tokenizer(
|
| 255 |
words,
|
| 256 |
is_split_into_words=True,
|
|
|
|
| 264 |
sent_wp_ids = enc["input_ids"]
|
| 265 |
if isinstance(sent_wp_ids[0], list):
|
| 266 |
sent_wp_ids = sent_wp_ids[0]
|
| 267 |
+
|
| 268 |
wid = enc.word_ids()
|
| 269 |
+
|
| 270 |
first_pos = {}
|
| 271 |
+
kept_word_ids = []
|
| 272 |
for pos, w in enumerate(wid):
|
| 273 |
if w is not None and w not in first_pos:
|
| 274 |
+
first_pos[w] = pos + 1 # +1 for [CLS]
|
| 275 |
+
kept_word_ids.append(w)
|
| 276 |
+
|
| 277 |
+
kept_word_ids = sorted(kept_word_ids)
|
| 278 |
+
n_words = len(kept_word_ids)
|
| 279 |
+
|
| 280 |
+
word_first = torch.tensor(
|
| 281 |
+
[first_pos[w] for w in kept_word_ids],
|
| 282 |
+
dtype=torch.long
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
return sent_wp_ids, word_first, n_words
|
| 286 |
|
| 287 |
@torch.no_grad()
|