emanuelaboros commited on
Commit
a681df1
·
1 Parent(s): e7392bf

move to pregenerated tokens - some bug with word ids -- move to the inital ones

Browse files
Files changed (1) hide show
  1. generic_ner.py +14 -15
generic_ner.py CHANGED
@@ -261,27 +261,26 @@ def get_entities(tokens, tags, confidences, text):
261
  return entities
262
 
263
 
264
- def realign(
265
- word_ids, tokens, out_label_preds, softmax_scores, tokenizer, reverted_label_map
266
- ):
267
  preds_list, words_list, confidence_list = [], [], []
268
- # word_ids = tokenizer(tokens, is_split_into_words=True).word_ids()
269
- print('--'*20)
270
- print("word_ids", word_ids)
271
- print("tokens", tokens)
272
- print('--'*20)
273
- for idx, word in enumerate(tokens):
274
- beginning_index = word_ids.index(idx)
 
275
  try:
276
- preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
277
- confidence_list.append(max(softmax_scores[beginning_index]))
278
- except Exception as ex: # the sentence was longer then max_length
279
  preds_list.append("O")
280
  confidence_list.append(0.0)
281
- words_list.append(word)
282
 
283
- return words_list, preds_list, confidence_list
284
 
 
285
 
286
  def add_spaces_around_punctuation(text):
287
  # Add a space before and after all punctuation
 
261
  return entities
262
 
263
 
264
+ def realign(word_ids, tokens, out_label_preds, softmax_scores, tokenizer, reverted_label_map):
 
 
265
  preds_list, words_list, confidence_list = [], [], []
266
+
267
+ seen_word_ids = set()
268
+ for i, word_id in enumerate(word_ids):
269
+ if word_id is None or word_id in seen_word_ids:
270
+ continue # skip special tokens or repeated subwords
271
+
272
+ seen_word_ids.add(word_id)
273
+
274
  try:
275
+ preds_list.append(reverted_label_map[out_label_preds[i]])
276
+ confidence_list.append(max(softmax_scores[i]))
277
+ except Exception:
278
  preds_list.append("O")
279
  confidence_list.append(0.0)
 
280
 
281
+ words_list.append(tokens[word_id]) # original word list index
282
 
283
+ return words_list, preds_list, confidence_list
284
 
285
  def add_spaces_around_punctuation(text):
286
  # Add a space before and after all punctuation