emanuelaboros committed
Commit 471ce47 · verified · 1 Parent(s): 095fd51

Update generic_ner.py

Files changed (1)
  1. generic_ner.py +32 -46
generic_ner.py CHANGED
@@ -200,72 +200,58 @@ class MultitaskTokenClassificationPipeline(Pipeline):
         }
         return preprocess_kwargs, {}, {}

-    def chunk_text_exact(self, text, tokenizer, max_subtokens):
-        """
-        Splits text into exact subtoken chunks based on the tokenizer's max length.
-        """
-        subtokens = tokenizer.encode(text, add_special_tokens=False)
-        for i in range(0, len(subtokens), max_subtokens):
-            chunk = subtokens[i : i + max_subtokens]
-            yield tokenizer.decode(chunk, clean_up_tokenization_spaces=False)
-
     def preprocess(self, text, **kwargs):
-        # Get the model's max input length
-        max_input_length = self.tokenizer.model_max_length - 2  # Reserve space for [CLS] and [SEP]

-        # Split the text into subtoken chunks
-        text_chunks = list(self.chunk_text_exact(text, self.tokenizer, max_input_length))
+        tokenized_inputs = self.tokenizer(
+            text, padding="max_length", truncation=True, max_length=512
+        )

-        # Tokenize and add special tokens for each chunk
-        tokenized_chunks = [
-            self.tokenizer(
-                chunk, padding="max_length", truncation=True, max_length=self.tokenizer.model_max_length
-            )
-            for chunk in text_chunks
-        ]
-
-        return tokenized_chunks, text_chunks, text
+        text_sentence = tokenize(add_spaces_around_punctuation(text))
+        return tokenized_inputs, text_sentence, text

     def _forward(self, inputs):
-        tokenized_chunks, text_chunks, text = inputs
-        outputs = []
+        inputs, text_sentences, text = inputs
+        input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
+            self.model.device
+        )
+        attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
+            self.model.device
+        )
         with torch.no_grad():
-            for tokenized_input in tokenized_chunks:
-                input_ids = torch.tensor([tokenized_input["input_ids"]], dtype=torch.long).to(self.model.device)
-                attention_mask = torch.tensor([tokenized_input["attention_mask"]], dtype=torch.long).to(self.model.device)
-                outputs.append(self.model(input_ids, attention_mask))
-        return outputs, text_chunks, text
+            outputs = self.model(input_ids, attention_mask)
+        return outputs, text_sentences, text

-    def postprocess(self, outputs, **kwargs):
-        tokens_result, text_chunks, text = outputs

-        # Initialize variables for collecting results across chunks
-        predictions = {task: [] for task in self.label_map.keys()}
-        confidence_scores = {task: [] for task in self.label_map.keys()}
+    def postprocess(self, outputs, **kwargs):
+        """
+        Postprocess the outputs of the model
+        :param outputs:
+        :param kwargs:
+        :return:
+        """
+        tokens_result, text_sentence, text = outputs

-        # Collect predictions from each chunk
-        for chunk_result in tokens_result:
-            for task, logits in chunk_result.logits.items():
-                predictions[task].extend(torch.argmax(logits, dim=-1).tolist())
-                confidence_scores[task].extend(F.softmax(logits, dim=-1).tolist())
+        predictions = {}
+        confidence_scores = {}
+        for task, logits in tokens_result.logits.items():
+            predictions[task] = torch.argmax(logits, dim=-1).tolist()
+            confidence_scores[task] = F.softmax(logits, dim=-1).tolist()

-        # Decode and process the predictions
         decoded_predictions = {}
         for task, preds in predictions.items():
             decoded_predictions[task] = [
                 [self.id2label[task][label] for label in seq] for seq in preds
             ]
-
-        # Extract entities from the combined predictions
         entities = {}
         for task, preds in predictions.items():
             words_list, preds_list, confidence_list = realign(
-                text_chunks,
-                preds,
-                confidence_scores[task],
+                text_sentence,
+                preds[0],
+                confidence_scores[task][0],
                 self.tokenizer,
                 self.id2label[task],
             )
+
             entities[task] = get_entities(words_list, preds_list, confidence_list, text)
-            print(entities[task])
+
         return entities
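
The net effect of the change is that preprocess() no longer splits long inputs into exact subtoken chunks: a single tokenizer call with truncation=True, max_length=512 keeps only the first window, and _forward()/postprocess() now handle one sequence instead of a list of chunk outputs. Below is a minimal sketch of that truncation behaviour, not part of the commit; "bert-base-cased" is only a placeholder checkpoint, not the model this pipeline ships with.

# Sketch only, not part of the commit: illustrates what the new single-call
# tokenization does to inputs longer than one 512-token window.
# "bert-base-cased" is a placeholder checkpoint, not the pipeline's model.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
long_text = "Geneva " * 2000  # far more than one window of subtokens

# Removed behaviour: chunk the subtokens so no text is discarded.
subtokens = tokenizer.encode(long_text, add_special_tokens=False)
max_subtokens = tokenizer.model_max_length - 2  # reserve [CLS] and [SEP]
chunks = [
    subtokens[i : i + max_subtokens]
    for i in range(0, len(subtokens), max_subtokens)
]
print(len(subtokens), "subtokens ->", len(chunks), "chunks")

# New behaviour: one encoding, truncated to a single window.
encoded = tokenizer(long_text, padding="max_length", truncation=True, max_length=512)
print(len(encoded["input_ids"]))  # 512 -- text past the first window is dropped

If full coverage of longer documents is still needed, chunking (or an overlap-aware variant of it) would have to happen upstream of the pipeline, since this commit removes it from preprocess().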