add index from tokenizer

BertForJointParsing.py · +10 −47
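In this file, the commit removes the hand-rolled `parse_index()` word-piece indexer and the per-word `idx` plumbing that fed it through `predict()`, `aggregate_ner_tokens()`, and `ner_parse_logits()`; the NER helpers now operate on plain `(token, label, start, end)` tuples.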
```diff
@@ -186,7 +186,7 @@ class BertForJointParsing(BertPreTrainedModel):
             morph_logits=morph_logits
         )
 
-    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False,
+    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
         is_single_sentence = isinstance(sentences, str)
         if is_single_sentence:
             sentences = [sentences]
```
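For orientation, a minimal usage sketch of the `predict()` entry point above. The checkpoint id is an assumption (any checkpoint shipping this `BertForJointParsing` class works), and `trust_remote_code=True` is required because the class lives in the model repo rather than in `transformers`:

```python
# Hypothetical usage of predict(); the checkpoint id is an assumption.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('dicta-il/dictabert-joint')
model = AutoModel.from_pretrained('dicta-il/dictabert-joint', trust_remote_code=True)
model.eval()

# A plain string is accepted and unwrapped on return (the is_single_sentence
# branch above); output_style selects plain JSON or a UD-style tree.
parsed = model.predict('sentence to parse', tokenizer, output_style='json')
```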
```diff
@@ -234,66 +234,32 @@ class BertForJointParsing(BertPreTrainedModel):
         for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label, offset_mapping)):
             if per_token_ner:
                 merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
-            final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
-
+            final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
+
         if output_style in ['ud', 'iahlt_ud']:
             final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
 
         if is_single_sentence:
             final_output = final_output[0]
-
-        words_index = parse_index(inputs['input_ids'], tokenizer)[0]
-        for idx, w in zip(words_index, final_output[0]['tokens']):
-            w['idx'] = idx
-
         return final_output
 
-def parse_index(input_ids: torch.Tensor, tokenizer: BertTokenizerFast):
-    # Create input_indices for each input_id, handling word-pieces
-    input_indices = []
-    for batch_idx, ids in enumerate(input_ids):
-        sentence_indices = []
-        current_word_indices = []
-        for idx, id_value in enumerate(ids):
-            # Skip special tokens
-            if id_value in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]:
-                continue
-
-            token_id = input_ids[batch_idx, idx]
-            token = tokenizer._convert_id_to_token(token_id)
-
-            # If the token is a continuation of a previous word (word-piece), append the index
-            if token.startswith('##'):
-                current_word_indices.append(idx)
-            else:
-                # If there's a current word, add it to sentence indices
-                if current_word_indices:
-                    sentence_indices.append(current_word_indices)
-                current_word_indices = [idx]
-
-        # Add the last word to sentence indices if not empty
-        if current_word_indices:
-            sentence_indices.append(current_word_indices)
-        input_indices.append(sentence_indices)
-    return input_indices
 
 
 def aggregate_ner_tokens(predictions):
     entities = []
     prev = None
-    for word, pred, start, end, idx in predictions:
+    for word, pred, start, end in predictions:
         # O does nothing
         if pred == 'O': prev = None
         # B- || I-entity != prev (different entity or none)
         elif pred.startswith('B-') or pred[2:] != prev:
             prev = pred[2:]
-            entities.append([[word], prev, start, end, idx])
+            entities.append([[word], prev, start, end])
         else:
             entities[-1][0].append(word)
             entities[-1][3] = end
-            entities[-1][4].extend(idx)
 
-    return [dict(phrase=' '.join(words), label=label, start=start, end=end, idx=idx) for words, label, start, end, idx in entities]
+    return [dict(phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end in entities]
 
 def merge_token_list(src, update, key):
     for token_src, token_update in zip(src, update):
```
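The deleted `parse_index()` regrouped word-piece positions by rescanning tokens for the `##` prefix and special-token ids. The same grouping is available directly from a fast tokenizer via `word_ids()`, which the commit title alludes to; a minimal sketch of that alternative (an illustration, not code this commit ships):

```python
from transformers import BertTokenizerFast

def word_piece_groups(encoding, batch_index: int = 0):
    # word_ids() maps each token position to its source word, or None for
    # special tokens ([CLS]/[SEP]/[PAD]) -- the same cases parse_index() skipped.
    groups = []
    last_word_id = None
    for pos, word_id in enumerate(encoding.word_ids(batch_index)):
        if word_id is None:
            continue
        if word_id != last_word_id:
            groups.append([pos])
            last_word_id = word_id
        else:
            groups[-1].append(pos)
    return groups

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
encoding = tokenizer('unbelievable results')
print(word_piece_groups(encoding))  # e.g. [[1, 2, 3], [4]] if the first word splits into three pieces
```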
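A worked example of the new four-field contract consumed by `aggregate_ner_tokens` (words, labels, and offsets invented for illustration):

```python
preds = [('Jane', 'B-PER', 0, 4), ('Doe', 'I-PER', 5, 8),
         ('visited', 'O', 9, 16), ('Paris', 'B-LOC', 17, 22)]

aggregate_ner_tokens(preds)
# -> [{'phrase': 'Jane Doe', 'label': 'PER', 'start': 0, 'end': 8},
#     {'phrase': 'Paris', 'label': 'LOC', 'start': 17, 'end': 22}]
```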
```diff
@@ -310,6 +276,7 @@ def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFast
 
 def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str], offset_mapping):
     input_ids = inputs['input_ids']
+
     predictions = torch.argmax(logits, dim=-1)
     batch_ret = []
     for batch_idx in range(len(sentences)):
```
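`ner_parse_logits` decodes per-token labels by taking an argmax over the label dimension and mapping ids through `id2label`; a self-contained illustration with invented shapes and label map:

```python
import torch

logits = torch.tensor([[[2.0, 0.1], [0.3, 1.5]]])  # (batch=1, tokens=2, labels=2), invented values
id2label = {0: 'O', 1: 'B-PER'}                    # invented label map

predictions = torch.argmax(logits, dim=-1)         # tensor([[0, 1]])
print([id2label[p.item()] for p in predictions[0]])  # ['O', 'B-PER']
```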
```diff
@@ -328,15 +295,11 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str], offset_mapping):
             # we modify the last token in ret
             # by discarding the original end position and replacing it with the new token's end position
             if token.startswith('##'):
-                ret[-1] = [ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item()]
+                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item())
                 continue
             # for each token, we append a tuple containing: token, label, start position, end position
-            ret.append([token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()])
-
-    words_index = parse_index(inputs['input_ids'], tokenizer)[0]
-    for idx, w in zip(words_index, batch_ret[0]):
-        w.append(idx)
-
+            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()))
+
     return batch_ret
 
 def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
```
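The `##` branch above folds a continuation piece into the previous entry, keeping the original start offset and adopting the new end offset; a standalone illustration with plain ints in place of the tensor values that `.item()` extracts:

```python
ret = [('Wash', 'B-LOC', 0, 4)]
token, end_pos = '##ington', 10

# the same merge ner_parse_logits performs for word-pieces
ret[-1] = (ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos)
print(ret)  # [('Washington', 'B-LOC', 0, 10)]
```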