Commit 4fd1faf
Parent: 8d73145
Initial commit of the trained NER model with code

Files changed:
- config.json: +7 -0
- generic_ner.py: +173 -0
config.json CHANGED

@@ -5,6 +5,13 @@
   ],
   "attention_probs_dropout_prob": 0.1,
   "classifier_dropout": null,
+  "custom_pipelines": {
+    "generic-ner": {
+      "impl": "generic_ner.MultitaskTokenClassificationPipeline",
+      "pt": "models.ExtendedMultitaskModelForTokenClassification",
+      "tf": []
+    }
+  },
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
   "hidden_size": 512,
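The "custom_pipelines" entry is the standard Hugging Face hook for shipping a pipeline implementation alongside the weights: "impl" names the pipeline class inside generic_ner.py, and "pt" names the PyTorch model class in models.py. As a rough sketch of how a downstream user would load it (the repo id and label map below are placeholders, not part of this commit; label_map is required by the pipeline's __init__, and trust_remote_code=True lets transformers import the repo's Python files):

from transformers import pipeline

# Hypothetical per-task label maps; the real ones ship with the model.
label_map = {"ner": {"O": 0, "B-PER": 1, "I-PER": 2}}

ner = pipeline(
    "generic-ner",            # task name registered in custom_pipelines
    model="user/model-repo",  # placeholder repo id
    label_map=label_map,      # forwarded to the pipeline's __init__
    trust_remote_code=True,   # allow importing generic_ner.py / models.py
)
print(ner("Albert Einstein was born in Ulm."))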
generic_ner.py ADDED
@@ -0,0 +1,173 @@
from transformers import Pipeline
import numpy as np
import torch
from nltk.chunk import conlltags2tree
from nltk import pos_tag
from nltk.tree import Tree
import string
import torch.nn.functional as F
import re
from models import ExtendedMultitaskModelForTokenClassification

# Register the custom pipeline
from transformers import pipeline


def tokenize(text):
    # Pad every punctuation mark with spaces, then split on whitespace.
    for punctuation in string.punctuation:
        text = text.replace(punctuation, " " + punctuation + " ")
    return text.split()


def find_entity_indices(article, entity):
    """
    Find all occurrences of an entity in the article and return their indices.

    :param article: The complete article text.
    :param entity: The entity to search for.
    :return: A list of (start, end) character offsets, one per occurrence.
    """
    entity_indices = []
    for match in re.finditer(re.escape(entity), article):
        entity_indices.append((match.start(), match.end()))
    return entity_indices


def get_entities(tokens, tags, confidences, text):
    # Collapse BIOES tags onto plain BIO so nltk can build the chunk tree.
    tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
    pos_tags = [pos for token, pos in pos_tag(tokens)]

    conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
    ne_tree = conlltags2tree(conlltags)

    entities = []
    idx: int = 0

    for subtree in ne_tree:
        # 'O' tokens come through as plain tuples, entities as Trees.
        if isinstance(subtree, Tree):
            original_label = subtree.label()
            original_string = " ".join([token for token, pos in subtree.leaves()])

            for indices in find_entity_indices(text, original_string):
                entity_start_position, entity_end_position = indices
                entities.append(
                    {
                        "entity": original_label,
                        "score": np.average(confidences[idx : idx + len(subtree)]),
                        "index": idx,
                        "word": original_string,
                        "start": entity_start_position,
                        "end": entity_end_position,
                    }
                )
                assert (
                    text[entity_start_position:entity_end_position] == original_string
                )
            # Advance the token index past this (possibly multi-token) entity.
            idx += len(subtree)
        else:
            # Not a named entity: advance the token index by one.
            idx += 1

    return entities


def realign(
    text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
):
    # Map subword-level predictions back onto whitespace-level tokens:
    # each word inherits the prediction of its first subword.
    preds_list, words_list, confidence_list = [], [], []
    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
    for idx, word in enumerate(text_sentence):
        try:
            # .index() raises ValueError when the word was truncated away,
            # so it has to live inside the try block.
            beginning_index = word_ids.index(idx)
            preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
            confidence_list.append(max(softmax_scores[beginning_index]))
        except Exception:  # the sentence was longer than max_length
            preds_list.append("O")
            confidence_list.append(0.0)
        words_list.append(word)

    return words_list, preds_list, confidence_list


class MultitaskTokenClassificationPipeline(Pipeline):
    def __init__(self, model, tokenizer, label_map, **kwargs):
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        self.label_map = label_map
        # Invert each task's {label: id} mapping into {id: label}.
        self.id2label = {
            task: {id_: label for label, id_ in labels.items()}
            for task, labels in label_map.items()
        }

    def _sanitize_parameters(self, **kwargs):
        # Route all extra kwargs to preprocess; forward and postprocess take none.
        return kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        tokenized_inputs = self.tokenizer(
            text, padding="max_length", truncation=True, max_length=512
        )

        text_sentence = tokenize(text)
        return tokenized_inputs, text_sentence, text

    def _forward(self, inputs):
        inputs, text_sentence, text = inputs
        input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
            self.model.device
        )
        attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
            self.model.device
        )
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
        return outputs, text_sentence, text

    def postprocess(self, outputs, **kwargs):
        """
        Turn the raw model outputs into per-task lists of entity dicts.

        :param outputs: (model outputs, whitespace tokens, original text).
        :return: A dict mapping each task name to its extracted entities.
        """
        tokens_result, text_sentence, text = outputs

        predictions = {}
        confidence_scores = {}
        for task, logits in tokens_result.logits.items():
            predictions[task] = torch.argmax(logits, dim=-1).tolist()
            confidence_scores[task] = F.softmax(logits, dim=-1).tolist()

        # Decoded label sequences per task (currently unused downstream).
        decoded_predictions = {}
        for task, preds in predictions.items():
            decoded_predictions[task] = [
                [self.id2label[task][label] for label in seq] for seq in preds
            ]

        entities = {}
        for task, preds in predictions.items():
            words_list, preds_list, confidence_list = realign(
                text_sentence,
                preds[0],
                confidence_scores[task][0],
                self.tokenizer,
                self.id2label[task],
            )

            entities[task] = get_entities(words_list, preds_list, confidence_list, text)

        return entities
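For intuition, get_entities leans on nltk's conlltags2tree to group BIO-tagged tokens into labelled chunks before character offsets are recovered. A minimal standalone sketch of that step (sentence and tags are invented; running it requires nltk's POS tagger data, e.g. the averaged_perceptron_tagger resource, for pos_tag):

from nltk import pos_tag
from nltk.chunk import conlltags2tree
from nltk.tree import Tree

tokens = ["Albert", "Einstein", "was", "born", "in", "Ulm", "."]
tags = ["B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]  # invented BIO tags
pos_tags = [pos for _, pos in pos_tag(tokens)]

# conlltags2tree groups contiguous B-/I- tokens into labelled subtrees;
# 'O' tokens stay as plain (token, pos) tuples.
tree = conlltags2tree(list(zip(tokens, pos_tags, tags)))
for subtree in tree:
    if isinstance(subtree, Tree):
        print(subtree.label(), "->", " ".join(tok for tok, _ in subtree.leaves()))
# PER -> Albert Einstein
# LOC -> Ulm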