emanuelaboros committed
Commit a84fd08 · verified · 1 Parent(s): 08dfe16

Update generic_ner.py

Files changed (1)
  1. generic_ner.py +114 -59
generic_ner.py CHANGED
@@ -1,15 +1,15 @@
-from transformers import Pipeline
 import numpy as np
-import torch
 from nltk.chunk import conlltags2tree
 from nltk import pos_tag
 from nltk.tree import Tree
-import string
-import torch.nn.functional as F
-import re
-
-
 import re, string
+import pysbd
+import torch
+import torch.nn.functional as F
+from transformers import Pipeline
+from langdetect import detect
+from nltk.tokenize import sent_tokenize
+from typing import List
 
 
 def tokenize(text):
@@ -88,14 +88,20 @@ def get_entities(tokens, tags, confidences, text):
         entity_start_position = indices[0]
         entity_end_position = indices[1]
         if (
-            "_".join([original_label, original_string, str(entity_start_position)])
+            "_".join(
+                [original_label, original_string, str(entity_start_position)]
+            )
             in already_done
         ):
             continue
         else:
             already_done.append(
                 "_".join(
-                    [original_label, original_string, str(entity_start_position)]
+                    [
+                        original_label,
+                        original_string,
+                        str(entity_start_position),
+                    ]
                 )
             )
             entities.append(
@@ -141,6 +147,37 @@ def realign(
     return words_list, preds_list, confidence_list
 
 
+def segment_and_trim_sentences(article, language, max_length):
+
+    try:
+        segmenter = pysbd.Segmenter(language=language, clean=False)
+    except:
+        segmenter = pysbd.Segmenter(language="en", clean=False)
+
+    sentences = segmenter.segment(article)
+
+    trimmed_sentences = []
+    for sentence in sentences:
+        while len(sentence) > max_length:
+            # Find the last space within max_length
+            cut_index = sentence.rfind(" ", 0, max_length)
+            if cut_index == -1:
+                # If no space found, forcibly cut at max_length
+                cut_index = max_length
+
+            # Cut the sentence and add the first part to trimmed sentences
+            trimmed_sentences.append(sentence[:cut_index])
+
+            # Update the sentence to be the remaining part
+            sentence = sentence[cut_index:].lstrip()
+
+        # Add the remaining part of the sentence if it's not empty
+        if sentence:
+            trimmed_sentences.append(sentence)
+
+    return trimmed_sentences
+
+
 # List of additional "strange" punctuation marks
 additional_punctuation = "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
 
@@ -164,56 +201,74 @@ class MultitaskTokenClassificationPipeline(Pipeline):
         }
         return preprocess_kwargs, {}, {}
 
-    def preprocess(self, text, **kwargs):
-        tokenized_inputs = self.tokenizer(
-            text, padding="max_length", truncation=True, max_length=512
-        )
-
-        text_sentence = tokenize(add_spaces_around_punctuation(text))
-        return tokenized_inputs, text_sentence, text
-
-    def _forward(self, inputs):
-        inputs, text_sentence, text = inputs
-        input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
-            self.model.device
-        )
-        attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
-            self.model.device
-        )
-        with torch.no_grad():
-            outputs = self.model(input_ids, attention_mask)
-        return outputs, text_sentence, text
-
-    def postprocess(self, outputs, **kwargs):
-        """
-        Postprocess the outputs of the model
-        :param outputs:
-        :param kwargs:
-        :return:
-        """
-        tokens_result, text_sentence, text = outputs
-
-        predictions = {}
-        confidence_scores = {}
-        for task, logits in tokens_result.logits.items():
-            predictions[task] = torch.argmax(logits, dim=-1).tolist()
-            confidence_scores[task] = F.softmax(logits, dim=-1).tolist()
-
-        decoded_predictions = {}
-        for task, preds in predictions.items():
-            decoded_predictions[task] = [
-                [self.id2label[task][label] for label in seq] for seq in preds
-            ]
-        entities = {}
-        for task, preds in predictions.items():
-            words_list, preds_list, confidence_list = realign(
-                text_sentence,
-                preds[0],
-                confidence_scores[task][0],
-                self.tokenizer,
-                self.id2label[task],
-            )
+class MultitaskTokenClassificationPipeline(Pipeline):
+
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "text" in kwargs:
+            preprocess_kwargs["text"] = kwargs["text"]
+        self.label_map = self.model.config.label_map
+        self.id2label = {
+            task: {id_: label for label, id_ in labels.items()}
+            for task, labels in self.label_map.items()
+        }
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, text, **kwargs):
+
+        language = detect(text)
+        sentences = segment_and_trim_sentences(text, language, 512)
+
+        tokenized_inputs = self.tokenizer(
+            sentences,
+            padding="max_length",
+            truncation=True,
+            max_length=512,
+            return_tensors="pt",
+        )
 
-            entities[task] = get_entities(words_list, preds_list, confidence_list, text)
+        text_sentence = [
+            tokenize(add_spaces_around_punctuation(sentence))
+            for sentence in sentences
+        ]
+        return tokenized_inputs, text_sentence, text
+
+    def _forward(self, inputs):
+        inputs, text_sentence, text = inputs
+        input_ids = inputs["input_ids"].to(self.model.device)
+        attention_mask = inputs["attention_mask"].to(self.model.device)
+
+        with torch.no_grad():
+            outputs = self.model(input_ids, attention_mask)
+
+        return outputs, text_sentence, text
+
+    def postprocess(self, outputs, **kwargs):
+        tokens_result, text_sentence, text = outputs
+
+        predictions = {}
+        confidence_scores = {}
+        for task, logits in tokens_result.logits.items():
+            predictions[task] = torch.argmax(logits, dim=-1).tolist()
+            confidence_scores[task] = F.softmax(logits, dim=-1).tolist()
+
+        decoded_predictions = {}
+        for task, preds in predictions.items():
+            decoded_predictions[task] = [
+                [self.id2label[task][label] for label in seq] for seq in preds
+            ]
+        entities = {}
+        for task, preds in predictions.items():
+            words_list, preds_list, confidence_list = realign(
+                text_sentence,
+                preds[0],
+                confidence_scores[task][0],
+                self.tokenizer,
+                self.id2label[task],
+            )
+
+            entities[task] = get_entities(
+                words_list, preds_list, confidence_list, text
+            )
 
-        return entities
+        return entities
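
Note: the following is a minimal, hypothetical sketch of what the newly added
segment_and_trim_sentences helper does, assuming pysbd is installed and
generic_ner.py is importable; the sample text and the 60-character limit are
made up for illustration:

    # Illustrative only: exercises the helper added in this commit.
    from generic_ner import segment_and_trim_sentences

    article = (
        "Short sentence. "
        "This second sentence is deliberately made long enough that the "
        "character-based trimming loop has to break it into several pieces."
    )
    # pysbd splits the article into sentences; any sentence longer than
    # max_length characters is re-cut at the last space before the limit.
    chunks = segment_and_trim_sentences(article, "en", 60)
    for chunk in chunks:
        print(len(chunk), repr(chunk))  # every printed length is <= 60

The cut is purely character-based, so it guarantees short inputs to the
tokenizer but is only a rough proxy for the 512-token budget that preprocess
later passes to the tokenizer.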
 
 
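For completeness, a hedged sketch of driving the reworked pipeline end to end.
Everything here is an assumption rather than part of the commit: the checkpoint
name is a placeholder, and the snippet presumes a model whose config carries the
per-task label_map and whose forward pass returns an output with a
{task: logits} dict, which is what _sanitize_parameters and postprocess above
expect:

    # Hypothetical usage; "your-org/your-multitask-ner-model" is a placeholder.
    from transformers import AutoModel, AutoTokenizer
    from generic_ner import MultitaskTokenClassificationPipeline

    model_name = "your-org/your-multitask-ner-model"  # placeholder id
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    ner = MultitaskTokenClassificationPipeline(model=model, tokenizer=tokenizer)
    # preprocess() auto-detects the language, segments and trims the text to
    # fit the 512-token window, then batches all sentences in one forward pass.
    entities = ner("Albert Einstein was born in Ulm and worked in Bern.")
    for task, spans in entities.items():
        print(task, spans)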