emanuelaboros committed on
Commit
8c3be27
·
verified ·
1 Parent(s): 1a59d5a

Update generic_ner.py

Browse files
Files changed (1) hide show
  1. generic_ner.py +60 -37
generic_ner.py CHANGED
@@ -202,54 +202,77 @@ class MultitaskTokenClassificationPipeline(Pipeline):
202
  }
203
  return preprocess_kwargs, {}, {}
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  def preprocess(self, text, **kwargs):
206
 
207
- language = detect(text)
208
- sentences = segment_and_trim_sentences(text, language, 512)
209
 
210
  tokenized_inputs = self.tokenizer(
211
- sentences,
212
- padding="max_length",
213
- truncation=True,
214
- max_length=512,
215
- return_tensors="pt",
216
  )
217
 
218
- text_sentences = [
219
- tokenize(add_spaces_around_punctuation(sentence)) for sentence in sentences
220
- ]
221
- return tokenized_inputs, text_sentences, text
222
 
223
  def _forward(self, inputs):
224
  inputs, text_sentences, text = inputs
225
- all_logits = {}
226
-
227
- for i in range(len(text_sentences)):
228
- print(inputs["input_ids"][i].shape)
229
- input_ids = torch.tensor([inputs["input_ids"][i]], dtype=torch.long).to(
230
- self.model.device
231
- )
232
- attention_mask = torch.tensor(
233
- [inputs["attention_mask"][i]], dtype=torch.long
234
- ).to(self.model.device)
235
-
236
- with torch.no_grad():
237
- outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
238
-
239
- # Accumulate logits for each task
240
- if not all_logits:
241
- all_logits = {task: logits for task, logits in outputs.logits.items()}
242
- else:
243
- for task in all_logits:
244
- all_logits[task] = torch.cat(
245
- (all_logits[task], outputs.logits[task]), dim=1
246
- )
247
-
248
- # Replace outputs.logits with accumulated logits
249
- outputs.logits = all_logits
250
-
251
  return outputs, text_sentences, text
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  def postprocess(self, outputs, **kwargs):
254
  """
255
  Postprocess the outputs of the model
 
202
  }
203
  return preprocess_kwargs, {}, {}
204
 
205
+ # def preprocess(self, text, **kwargs):
206
+ #
207
+ # language = detect(text)
208
+ # sentences = segment_and_trim_sentences(text, language, 512)
209
+ #
210
+ # tokenized_inputs = self.tokenizer(
211
+ # text,
212
+ # padding="max_length",
213
+ # truncation=True,
214
+ # max_length=512,
215
+ # return_tensors="pt",
216
+ # )
217
+ #
218
+ # text_sentences = [
219
+ # tokenize(add_spaces_around_punctuation(sentence)) for sentence in sentences
220
+ # ]
221
+ # return tokenized_inputs, text_sentences, text
222
  def preprocess(self, text, **kwargs):
223
 
224
+ # sentences = segment_and_trim_sentences(text, language, 512)
 
225
 
226
  tokenized_inputs = self.tokenizer(
227
+ text, padding="max_length", truncation=True, max_length=512
 
 
 
 
228
  )
229
 
230
+ text_sentence = tokenize(add_spaces_around_punctuation(text))
231
+ return tokenized_inputs, text_sentence, text
 
 
232
 
233
  def _forward(self, inputs):
234
  inputs, text_sentences, text = inputs
235
+ input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
236
+ self.model.device
237
+ )
238
+ print(input_ids.shape)
239
+ attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
240
+ self.model.device
241
+ )
242
+ with torch.no_grad():
243
+ outputs = self.model(input_ids, attention_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  return outputs, text_sentences, text
245
 
246
+ # def _forward(self, inputs):
247
+ # inputs, text_sentences, text = inputs
248
+ # all_logits = {}
249
+ #
250
+ # for i in range(len(text_sentences)):
251
+ # print(inputs["input_ids"][i].shape)
252
+ # input_ids = torch.tensor([inputs["input_ids"][i]], dtype=torch.long).to(
253
+ # self.model.device
254
+ # )
255
+ # attention_mask = torch.tensor(
256
+ # [inputs["attention_mask"][i]], dtype=torch.long
257
+ # ).to(self.model.device)
258
+ #
259
+ # with torch.no_grad():
260
+ # outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
261
+ #
262
+ # # Accumulate logits for each task
263
+ # if not all_logits:
264
+ # all_logits = {task: logits for task, logits in outputs.logits.items()}
265
+ # else:
266
+ # for task in all_logits:
267
+ # all_logits[task] = torch.cat(
268
+ # (all_logits[task], outputs.logits[task]), dim=1
269
+ # )
270
+ #
271
+ # # Replace outputs.logits with accumulated logits
272
+ # outputs.logits = all_logits
273
+ #
274
+ # return outputs, text_sentences, text
275
+
276
  def postprocess(self, outputs, **kwargs):
277
  """
278
  Postprocess the outputs of the model