emanuelaboros commited on
Commit
71e5c2e
·
1 Parent(s): 79fdcae
.DS_Store ADDED
Binary file (6.15 kB). View file
 
generic_ner.py CHANGED
@@ -16,21 +16,21 @@ import re, string
16
  stop_words = set(nltk.corpus.stopwords.words("english"))
17
  DEBUG = False
18
  punctuation = (
19
- string.punctuation
20
- + "«»—…“”"
21
- + "—."
22
- + "–"
23
- + "’"
24
- + "‘"
25
- + "´"
26
- + "•"
27
- + "°"
28
- + "»"
29
- + "“"
30
- + "”"
31
- + "–"
32
- + "—"
33
- + "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
34
  )
35
 
36
  # List of additional "strange" punctuation marks
@@ -87,53 +87,6 @@ WHITESPACE_RULES = {
87
  }
88
 
89
 
90
- # def tokenize(text: str, language: str = "other") -> list[str]:
91
- # """Apply whitespace rules to the given text and language, separating it into tokens.
92
- #
93
- # Args:
94
- # text (str): The input text to separate into a list of tokens.
95
- # language (str): Language of the text.
96
- #
97
- # Returns:
98
- # list[str]: List of tokens with punctuation as separate tokens.
99
- # """
100
- # # text = add_spaces_around_punctuation(text)
101
- # if not text:
102
- # return []
103
- #
104
- # if language not in WHITESPACE_RULES:
105
- # # Default behavior for languages without specific rules:
106
- # # tokenize using standard whitespace splitting
107
- # language = "other"
108
- #
109
- # wsrules = WHITESPACE_RULES[language]
110
- # tokenized_text = []
111
- # current_token = ""
112
- #
113
- # for char in text:
114
- # if char in wsrules["pct_no_ws_before_after"]:
115
- # if current_token:
116
- # tokenized_text.append(current_token)
117
- # tokenized_text.append(char)
118
- # current_token = ""
119
- # elif char in wsrules["pct_no_ws_before"] or char in wsrules["pct_no_ws_after"]:
120
- # if current_token:
121
- # tokenized_text.append(current_token)
122
- # tokenized_text.append(char)
123
- # current_token = ""
124
- # elif char.isspace():
125
- # if current_token:
126
- # tokenized_text.append(current_token)
127
- # current_token = ""
128
- # else:
129
- # current_token += char
130
- #
131
- # if current_token:
132
- # tokenized_text.append(current_token)
133
- #
134
- # return tokenized_text
135
-
136
-
137
  def normalize_text(text):
138
  # Remove spaces and tabs for the search but keep newline characters
139
  return re.sub(r"[ \t]+", "", text)
@@ -183,7 +136,6 @@ def find_entity_indices(article_text, search_text):
183
 
184
 
185
  def get_entities(tokens, tags, confidences, text):
186
-
187
  tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
188
  pos_tags = [pos for token, pos in pos_tag(tokens)]
189
 
@@ -208,10 +160,10 @@ def get_entities(tokens, tags, confidences, text):
208
  entity_start_position = indices[0]
209
  entity_end_position = indices[1]
210
  if (
211
- "_".join(
212
- [original_label, original_string, str(entity_start_position)]
213
- )
214
- in already_done
215
  ):
216
  continue
217
  else:
@@ -225,24 +177,24 @@ def get_entities(tokens, tags, confidences, text):
225
  )
226
  )
227
  if len(text[entity_start_position:entity_end_position].strip()) < len(
228
- text[entity_start_position:entity_end_position]
229
  ):
230
  entity_start_position = (
231
- entity_start_position
232
- + len(text[entity_start_position:entity_end_position])
233
- - len(text[entity_start_position:entity_end_position].strip())
234
  )
235
 
236
  entities.append(
237
  {
238
  "type": original_label,
239
  "confidence_ner": round(
240
- np.average(confidences[idx : idx + len(subtree)]), 2
241
  ),
242
  "index": (idx, idx + len(subtree)),
243
  "surface": text[
244
- entity_start_position:entity_end_position
245
- ], # original_string,
246
  "lOffset": entity_start_position,
247
  "rOffset": entity_end_position,
248
  }
@@ -282,6 +234,7 @@ def realign(word_ids, tokens, out_label_preds, softmax_scores, tokenizer, revert
282
 
283
  return words_list, preds_list, confidence_list
284
 
 
285
  def add_spaces_around_punctuation(text):
286
  # Add a space before and after all punctuation
287
  all_punctuation = string.punctuation + punctuation
@@ -312,8 +265,8 @@ def attach_comp_to_closest(entities):
312
 
313
  # Ensure the entity type is valid and check for minimal distance
314
  if (
315
- distance < min_distance
316
- and other_entity["type"].split(".")[0] in valid_entity_types
317
  ):
318
  min_distance = distance
319
  closest_entity = other_entity
@@ -363,8 +316,8 @@ def extract_name_from_text(text, partial_name):
363
  # Find the position of the partial name in the word list
364
  for i, word in enumerate(words):
365
  if DEBUG:
366
- print(words, "---", words[i : i + len(partial_words)])
367
- if words[i : i + len(partial_words)] == partial_words:
368
  # Initialize full name with the partial name
369
  full_name = partial_words[:]
370
 
@@ -443,8 +396,8 @@ def postprocess_entities(entities):
443
 
444
  # If the entity text is new, or this entity has more dots, update the map
445
  if (
446
- entity_text not in entity_map
447
- or entity_map[entity_text]["type"].count(".") < num_dots
448
  ):
449
  entity_map[entity_text] = entity
450
 
@@ -480,9 +433,9 @@ def remove_included_entities(entities):
480
  is_included = True
481
  break
482
  elif (
483
- entity["type"].split(".")[0] in other_entity["type"].split(".")[0]
484
- or other_entity["type"].split(".")[0]
485
- in entity["type"].split(".")[0]
486
  ):
487
  if entity["surface"] in other_entity["surface"]:
488
  is_included = True
@@ -547,12 +500,12 @@ def remove_trailing_stopwords(entities):
547
  if len(entity_text.split()) < 1:
548
  continue
549
  while entity_text and (
550
- entity_text.split()[0].lower() in stop_words
551
- or entity_text[0] in punctuation
552
  ):
553
  if entity_text.split()[0].lower() in stop_words:
554
  stopword_len = (
555
- len(entity_text.split()[0]) + 1
556
  ) # Adjust length for stopword and following space
557
  entity_text = entity_text[stopword_len:] # Remove leading stopword
558
  lOffset += stopword_len # Adjust the left offset
@@ -571,11 +524,11 @@ def remove_trailing_stopwords(entities):
571
  # Remove stopwords and punctuation from the end
572
  if len(entity_text.strip()) > 1:
573
  while (
574
- entity_text.strip().split()
575
- and (
576
- entity_text.strip().split()[-1].lower() in stop_words
577
- or entity_text[-1] in punctuation
578
- )
579
  ):
580
  if entity_text.strip().split() and entity_text.strip().split()[-1].lower() in stop_words:
581
  stopword_len = len(entity_text.strip().split()[-1]) + 1 # account for space
@@ -613,7 +566,7 @@ def remove_trailing_stopwords(entities):
613
  continue
614
  # Check if the entire entity is made up of stopwords characters
615
  if all(
616
- [char.lower() in stop_words for char in entity_text if char.isalpha()]
617
  ):
618
  if DEBUG:
619
  print(
@@ -630,11 +583,11 @@ def remove_trailing_stopwords(entities):
630
  # entities.remove(entity)
631
  continue
632
  if all(
633
- [
634
- char.lower() in string.punctuation
635
- for char in entity_text
636
- if char.isalpha()
637
- ]
638
  ):
639
  if DEBUG:
640
  print(
@@ -676,7 +629,7 @@ def remove_trailing_stopwords(entities):
676
  if DEBUG:
677
  print(f"Remained entities in remove_trailing_stopwords: {len(new_entities)}")
678
  return new_entities
679
-
680
 
681
  class MultitaskTokenClassificationPipeline(Pipeline):
682
 
@@ -723,8 +676,8 @@ class MultitaskTokenClassificationPipeline(Pipeline):
723
  def is_within(self, entity1, entity2):
724
  """Check if entity1 is fully within the bounds of entity2."""
725
  return (
726
- entity1["lOffset"] >= entity2["lOffset"]
727
- and entity1["rOffset"] <= entity2["rOffset"]
728
  )
729
 
730
  def postprocess(self, outputs, **kwargs):
 
16
  stop_words = set(nltk.corpus.stopwords.words("english"))
17
  DEBUG = False
18
  punctuation = (
19
+ string.punctuation
20
+ + "«»—…“”"
21
+ + "—."
22
+ + "–"
23
+ + "’"
24
+ + "‘"
25
+ + "´"
26
+ + "•"
27
+ + "°"
28
+ + "»"
29
+ + "“"
30
+ + "”"
31
+ + "–"
32
+ + "—"
33
+ + "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
34
  )
35
 
36
  # List of additional "strange" punctuation marks
 
87
  }
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def normalize_text(text):
91
  # Remove spaces and tabs for the search but keep newline characters
92
  return re.sub(r"[ \t]+", "", text)
 
136
 
137
 
138
  def get_entities(tokens, tags, confidences, text):
 
139
  tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
140
  pos_tags = [pos for token, pos in pos_tag(tokens)]
141
 
 
160
  entity_start_position = indices[0]
161
  entity_end_position = indices[1]
162
  if (
163
+ "_".join(
164
+ [original_label, original_string, str(entity_start_position)]
165
+ )
166
+ in already_done
167
  ):
168
  continue
169
  else:
 
177
  )
178
  )
179
  if len(text[entity_start_position:entity_end_position].strip()) < len(
180
+ text[entity_start_position:entity_end_position]
181
  ):
182
  entity_start_position = (
183
+ entity_start_position
184
+ + len(text[entity_start_position:entity_end_position])
185
+ - len(text[entity_start_position:entity_end_position].strip())
186
  )
187
 
188
  entities.append(
189
  {
190
  "type": original_label,
191
  "confidence_ner": round(
192
+ np.average(confidences[idx: idx + len(subtree)]), 2
193
  ),
194
  "index": (idx, idx + len(subtree)),
195
  "surface": text[
196
+ entity_start_position:entity_end_position
197
+ ], # original_string,
198
  "lOffset": entity_start_position,
199
  "rOffset": entity_end_position,
200
  }
 
234
 
235
  return words_list, preds_list, confidence_list
236
 
237
+
238
  def add_spaces_around_punctuation(text):
239
  # Add a space before and after all punctuation
240
  all_punctuation = string.punctuation + punctuation
 
265
 
266
  # Ensure the entity type is valid and check for minimal distance
267
  if (
268
+ distance < min_distance
269
+ and other_entity["type"].split(".")[0] in valid_entity_types
270
  ):
271
  min_distance = distance
272
  closest_entity = other_entity
 
316
  # Find the position of the partial name in the word list
317
  for i, word in enumerate(words):
318
  if DEBUG:
319
+ print(words, "---", words[i: i + len(partial_words)])
320
+ if words[i: i + len(partial_words)] == partial_words:
321
  # Initialize full name with the partial name
322
  full_name = partial_words[:]
323
 
 
396
 
397
  # If the entity text is new, or this entity has more dots, update the map
398
  if (
399
+ entity_text not in entity_map
400
+ or entity_map[entity_text]["type"].count(".") < num_dots
401
  ):
402
  entity_map[entity_text] = entity
403
 
 
433
  is_included = True
434
  break
435
  elif (
436
+ entity["type"].split(".")[0] in other_entity["type"].split(".")[0]
437
+ or other_entity["type"].split(".")[0]
438
+ in entity["type"].split(".")[0]
439
  ):
440
  if entity["surface"] in other_entity["surface"]:
441
  is_included = True
 
500
  if len(entity_text.split()) < 1:
501
  continue
502
  while entity_text and (
503
+ entity_text.split()[0].lower() in stop_words
504
+ or entity_text[0] in punctuation
505
  ):
506
  if entity_text.split()[0].lower() in stop_words:
507
  stopword_len = (
508
+ len(entity_text.split()[0]) + 1
509
  ) # Adjust length for stopword and following space
510
  entity_text = entity_text[stopword_len:] # Remove leading stopword
511
  lOffset += stopword_len # Adjust the left offset
 
524
  # Remove stopwords and punctuation from the end
525
  if len(entity_text.strip()) > 1:
526
  while (
527
+ entity_text.strip().split()
528
+ and (
529
+ entity_text.strip().split()[-1].lower() in stop_words
530
+ or entity_text[-1] in punctuation
531
+ )
532
  ):
533
  if entity_text.strip().split() and entity_text.strip().split()[-1].lower() in stop_words:
534
  stopword_len = len(entity_text.strip().split()[-1]) + 1 # account for space
 
566
  continue
567
  # Check if the entire entity is made up of stopwords characters
568
  if all(
569
+ [char.lower() in stop_words for char in entity_text if char.isalpha()]
570
  ):
571
  if DEBUG:
572
  print(
 
583
  # entities.remove(entity)
584
  continue
585
  if all(
586
+ [
587
+ char.lower() in string.punctuation
588
+ for char in entity_text
589
+ if char.isalpha()
590
+ ]
591
  ):
592
  if DEBUG:
593
  print(
 
629
  if DEBUG:
630
  print(f"Remained entities in remove_trailing_stopwords: {len(new_entities)}")
631
  return new_entities
632
+
633
 
634
  class MultitaskTokenClassificationPipeline(Pipeline):
635
 
 
676
  def is_within(self, entity1, entity2):
677
  """Check if entity1 is fully within the bounds of entity2."""
678
  return (
679
+ entity1["lOffset"] >= entity2["lOffset"]
680
+ and entity1["rOffset"] <= entity2["rOffset"]
681
  )
682
 
683
  def postprocess(self, outputs, **kwargs):
modeling_stacked.py CHANGED
@@ -16,29 +16,26 @@ def get_info(label_map):
16
  return num_token_labels_dict
17
 
18
 
19
- class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
20
-
21
  config_class = ImpressoConfig
22
  _keys_to_ignore_on_load_missing = [r"position_ids"]
23
 
24
- def __init__(self, config):
25
  super().__init__(config)
26
- self.num_token_labels_dict = get_info(config.label_map)
27
  self.config = config
28
-
29
- self.bert = AutoModel.from_pretrained(
30
- config.pretrained_config["_name_or_path"], config=config.pretrained_config
31
- )
32
- if "classifier_dropout" not in config.__dict__:
33
- classifier_dropout = 0.1
34
- else:
35
- classifier_dropout = (
36
- config.classifier_dropout
37
- if config.classifier_dropout is not None
38
- else config.hidden_dropout_prob
39
- )
40
  self.dropout = nn.Dropout(classifier_dropout)
41
 
 
 
 
42
  # Additional transformer layers
43
  self.transformer_encoder = nn.TransformerEncoder(
44
  nn.TransformerEncoderLayer(
@@ -46,71 +43,72 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
46
  ),
47
  num_layers=2,
48
  )
 
 
 
 
49
 
50
- # For token classification, create a classifier for each task
51
- self.token_classifiers = nn.ModuleDict(
52
- {
53
- task: nn.Linear(config.hidden_size, num_labels)
54
- for task, num_labels in self.num_token_labels_dict.items()
55
- }
56
- )
57
-
58
- # Initialize weights and apply final processing
59
  self.post_init()
60
 
61
  def forward(
62
- self,
63
- input_ids: Optional[torch.Tensor] = None,
64
- attention_mask: Optional[torch.Tensor] = None,
65
- token_type_ids: Optional[torch.Tensor] = None,
66
- position_ids: Optional[torch.Tensor] = None,
67
- head_mask: Optional[torch.Tensor] = None,
68
- inputs_embeds: Optional[torch.Tensor] = None,
69
- labels: Optional[torch.Tensor] = None,
70
- token_labels: Optional[dict] = None,
71
- output_attentions: Optional[bool] = None,
72
- output_hidden_states: Optional[bool] = None,
73
- return_dict: Optional[bool] = None,
 
 
 
 
74
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
75
- r"""
76
- token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
77
- Labels for computing the token classification loss. Keys should match the tasks.
78
- """
79
- return_dict = (
80
- return_dict if return_dict is not None else self.config.use_return_dict
81
- )
 
 
 
82
 
83
  bert_kwargs = {
84
- "input_ids": input_ids,
 
85
  "attention_mask": attention_mask,
86
  "token_type_ids": token_type_ids,
87
  "position_ids": position_ids,
88
  "head_mask": head_mask,
89
- "inputs_embeds": inputs_embeds,
90
  "output_attentions": output_attentions,
91
  "output_hidden_states": output_hidden_states,
92
  "return_dict": return_dict,
93
  }
94
 
95
- if any(
96
- keyword in self.config.name_or_path.lower()
97
- for keyword in ["llama", "deberta"]
98
- ):
99
- bert_kwargs.pop("token_type_ids")
100
- bert_kwargs.pop("head_mask")
101
 
102
- outputs = self.bert(**bert_kwargs)
103
-
104
- # For token classification
105
- token_output = outputs[0]
106
- token_output = self.dropout(token_output)
107
 
108
  # Pass through additional transformer layers
109
  token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(
110
  0, 1
111
  )
 
 
 
 
 
112
 
113
- # Collect the logits and compute the loss for each task
114
  task_logits = {}
115
  total_loss = 0
116
  for task, classifier in self.token_classifiers.items():
@@ -131,6 +129,115 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
131
  return TokenClassifierOutput(
132
  loss=total_loss,
133
  logits=task_logits,
134
- hidden_states=outputs.hidden_states,
135
- attentions=outputs.attentions,
136
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  return num_token_labels_dict
17
 
18
 
19
+ class ExtendedMultitaskTimeModelForTokenClassification(PreTrainedModel):
 
20
  config_class = ImpressoConfig
21
  _keys_to_ignore_on_load_missing = [r"position_ids"]
22
 
23
+ def __init__(self, config, num_token_labels_dict, temporal_fusion_strategy="baseline", num_years=327):
24
  super().__init__(config)
 
25
  self.config = config
26
+ self.num_token_labels_dict = num_token_labels_dict
27
+ self.temporal_fusion_strategy = temporal_fusion_strategy
28
+ self.model = AutoModel.from_pretrained(config.name_or_path, config=config)
29
+ self.model.config.use_cache = False
30
+ self.model.config.pretraining_tp = 1
31
+ self.num_years = num_years
32
+
33
+ classifier_dropout = getattr(config, "classifier_dropout", 0.1) or config.hidden_dropout_prob
 
 
 
 
34
  self.dropout = nn.Dropout(classifier_dropout)
35
 
36
+ self.temporal_fusion = TemporalFusion(config.hidden_size, strategy=self.temporal_fusion_strategy,
37
+ num_years=num_years)
38
+
39
  # Additional transformer layers
40
  self.transformer_encoder = nn.TransformerEncoder(
41
  nn.TransformerEncoderLayer(
 
43
  ),
44
  num_layers=2,
45
  )
46
+ self.token_classifiers = nn.ModuleDict({
47
+ task: nn.Linear(config.hidden_size, num_labels)
48
+ for task, num_labels in num_token_labels_dict.items()
49
+ })
50
 
 
 
 
 
 
 
 
 
 
51
  self.post_init()
52
 
53
  def forward(
54
+ self,
55
+ input_ids: Optional[torch.Tensor] = None,
56
+ attention_mask: Optional[torch.Tensor] = None,
57
+ token_type_ids: Optional[torch.Tensor] = None,
58
+ position_ids: Optional[torch.Tensor] = None,
59
+ head_mask: Optional[torch.Tensor] = None,
60
+ labels: Optional[torch.Tensor] = None,
61
+ inputs_embeds: Optional[torch.Tensor] = None,
62
+ token_labels: Optional[dict] = None,
63
+ date_indices: Optional[torch.Tensor] = None,
64
+ year_index: Optional[torch.Tensor] = None,
65
+ decade_index: Optional[torch.Tensor] = None,
66
+ century_index: Optional[torch.Tensor] = None,
67
+ output_attentions: Optional[bool] = None,
68
+ output_hidden_states: Optional[bool] = None,
69
+ return_dict: Optional[bool] = None,
70
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
71
+
72
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
73
+
74
+ if inputs_embeds is None:
75
+ inputs_embeds = self.model.embeddings(input_ids)
76
+
77
+ # Early cross-attention fusion
78
+ if self.temporal_fusion_strategy == "early-cross-attention":
79
+ year_emb = self.temporal_fusion.compute_time_embedding(year_index) # (B, H)
80
+ inputs_embeds = self.temporal_fusion.cross_attn(inputs_embeds, year_emb)
81
 
82
  bert_kwargs = {
83
+ "inputs_embeds": inputs_embeds if self.temporal_fusion_strategy == "early-cross-attention" else None,
84
+ "input_ids": input_ids if self.temporal_fusion_strategy != "early-cross-attention" else None,
85
  "attention_mask": attention_mask,
86
  "token_type_ids": token_type_ids,
87
  "position_ids": position_ids,
88
  "head_mask": head_mask,
 
89
  "output_attentions": output_attentions,
90
  "output_hidden_states": output_hidden_states,
91
  "return_dict": return_dict,
92
  }
93
 
94
+ if any(keyword in self.config.name_or_path.lower() for keyword in ["llama", "deberta"]):
95
+ bert_kwargs.pop("token_type_ids", None)
96
+ bert_kwargs.pop("head_mask", None)
 
 
 
97
 
98
+ outputs = self.model(**bert_kwargs)
99
+ token_output = self.dropout(outputs[0]) # (B, T, H)
100
+ hidden_states = list(outputs.hidden_states) if output_hidden_states else None
 
 
101
 
102
  # Pass through additional transformer layers
103
  token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(
104
  0, 1
105
  )
106
+ # Apply fusion after transformer if needed
107
+ if self.temporal_fusion_strategy not in ["baseline", "early-cross-attention"]:
108
+ token_output = self.temporal_fusion(token_output, year_index)
109
+ if output_hidden_states:
110
+ hidden_states.append(token_output) # add the final fused state
111
 
 
112
  task_logits = {}
113
  total_loss = 0
114
  for task, classifier in self.token_classifiers.items():
 
129
  return TokenClassifierOutput(
130
  loss=total_loss,
131
  logits=task_logits,
132
+ hidden_states=tuple(hidden_states) if hidden_states is not None else None,
133
+ attentions=outputs.attentions if output_attentions else None,
134
  )
135
+
136
+
137
+ class TemporalFusion(nn.Module):
138
+ def __init__(self, hidden_size, strategy="add", num_years=327, min_year=1700):
139
+ super().__init__()
140
+ self.strategy = strategy
141
+ self.hidden_size = hidden_size
142
+ self.min_year = min_year
143
+ self.max_year = min_year + num_years - 1
144
+
145
+ self.year_emb = nn.Embedding(num_years, hidden_size)
146
+
147
+ if strategy == "concat":
148
+ self.concat_proj = nn.Linear(hidden_size * 2, hidden_size)
149
+ elif strategy == "film":
150
+ self.film_gamma = nn.Linear(hidden_size, hidden_size)
151
+ self.film_beta = nn.Linear(hidden_size, hidden_size)
152
+ elif strategy == "adapter":
153
+ self.adapter = nn.Sequential(
154
+ nn.Linear(hidden_size, hidden_size),
155
+ nn.ReLU(),
156
+ nn.Linear(hidden_size, hidden_size),
157
+ )
158
+ elif strategy == "relative":
159
+ self.relative_encoder = nn.Sequential(
160
+ nn.Linear(hidden_size, hidden_size),
161
+ nn.SiLU(),
162
+ nn.LayerNorm(hidden_size),
163
+ )
164
+ self.film_gamma = nn.Linear(hidden_size, hidden_size)
165
+ self.film_beta = nn.Linear(hidden_size, hidden_size)
166
+ elif strategy == "multiscale":
167
+ self.decade_emb = nn.Embedding(1000, hidden_size)
168
+ self.century_emb = nn.Embedding(100, hidden_size)
169
+ elif strategy in ["early-cross-attention", "late-cross-attention"]:
170
+ self.year_encoder = nn.Sequential(
171
+ nn.Linear(hidden_size, hidden_size),
172
+ nn.SiLU()
173
+ )
174
+ self.cross_attn = TemporalCrossAttention(hidden_size)
175
+
176
+ def compute_time_embedding(self, year_index):
177
+ if self.strategy in ["early-cross-attention", "late-cross-attention"]:
178
+ return self.year_encoder(self.year_emb(year_index))
179
+ elif self.strategy == "multiscale":
180
+ year_index = year_index.long()
181
+ year = year_index + self.min_year
182
+ decade = (year // 10).long()
183
+ century = (year // 100).long()
184
+ return (
185
+ self.year_emb(year_index) +
186
+ self.decade_emb(decade) +
187
+ self.century_emb(century)
188
+ )
189
+ else:
190
+ return self.year_emb(year_index)
191
+
192
+ def forward(self, token_output, year_index):
193
+ B, T, H = token_output.size()
194
+
195
+ if self.strategy == "baseline":
196
+ return token_output
197
+
198
+ year_emb = self.compute_time_embedding(year_index)
199
+
200
+ if self.strategy == "concat":
201
+ expanded_year = year_emb.unsqueeze(1).repeat(1, T, 1)
202
+ fused = torch.cat([token_output, expanded_year], dim=-1)
203
+ return self.concat_proj(fused)
204
+
205
+ elif self.strategy == "film":
206
+ gamma = self.film_gamma(year_emb).unsqueeze(1)
207
+ beta = self.film_beta(year_emb).unsqueeze(1)
208
+ return gamma * token_output + beta
209
+
210
+ elif self.strategy == "adapter":
211
+ return token_output + self.adapter(year_emb).unsqueeze(1)
212
+
213
+ elif self.strategy == "add":
214
+ expanded_year = year_emb.unsqueeze(1).repeat(1, T, 1)
215
+ return token_output + expanded_year
216
+
217
+ elif self.strategy == "relative":
218
+ encoded = self.relative_encoder(year_emb)
219
+ gamma = self.film_gamma(encoded).unsqueeze(1)
220
+ beta = self.film_beta(encoded).unsqueeze(1)
221
+ return gamma * token_output + beta
222
+
223
+ elif self.strategy == "multiscale":
224
+ expanded_year = year_emb.unsqueeze(1).expand(-1, T, -1)
225
+ return token_output + expanded_year
226
+
227
+ elif self.strategy == "late-cross-attention":
228
+ return self.cross_attn(token_output, year_emb)
229
+
230
+ else:
231
+ raise ValueError(f"Unknown fusion strategy: {self.strategy}")
232
+
233
+
234
+ class TemporalCrossAttention(nn.Module):
235
+ def __init__(self, hidden_size, num_heads=4):
236
+ super().__init__()
237
+ self.attn = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_heads, batch_first=True)
238
+
239
+ def forward(self, token_output, time_embedding):
240
+ # token_output: (B, T, H), time_embedding: (B, H)
241
+ time_as_seq = time_embedding.unsqueeze(1) # (B, 1, H)
242
+ attn_output, _ = self.attn(token_output, time_as_seq, time_as_seq)
243
+ return token_output + attn_output
old/config.json ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "experiments_final/model_dbmdz_bert_medium_historic_multilingual_cased_max_sequence_length_512_epochs_5_run_extended_suffix_baseline/checkpoint-450",
3
+ "architectures": [
4
+ "ExtendedMultitaskModelForTokenClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_stacked.ImpressoConfig",
9
+ "AutoModelForTokenClassification": "modeling_stacked.ExtendedMultitaskModelForTokenClassification"
10
+ },
11
+ "classifier_dropout": null,
12
+ "custom_pipelines": {
13
+ "generic-ner": {
14
+ "impl": "generic_ner.MultitaskTokenClassificationPipeline",
15
+ "pt": "AutoModelForTokenClassification"
16
+ }
17
+ },
18
+ "hidden_act": "gelu",
19
+ "hidden_dropout_prob": 0.1,
20
+ "hidden_size": 512,
21
+ "initializer_range": 0.02,
22
+ "intermediate_size": 2048,
23
+ "label_map": {
24
+ "NE-COARSE-LIT": {
25
+ "B-loc": 8,
26
+ "B-org": 0,
27
+ "B-pers": 7,
28
+ "B-prod": 4,
29
+ "B-time": 5,
30
+ "I-loc": 1,
31
+ "I-org": 2,
32
+ "I-pers": 9,
33
+ "I-prod": 10,
34
+ "I-time": 6,
35
+ "O": 3
36
+ },
37
+ "NE-COARSE-METO": {
38
+ "B-loc": 3,
39
+ "B-org": 0,
40
+ "B-time": 5,
41
+ "I-loc": 4,
42
+ "I-org": 2,
43
+ "O": 1
44
+ },
45
+ "NE-FINE-COMP": {
46
+ "B-comp.demonym": 8,
47
+ "B-comp.function": 5,
48
+ "B-comp.name": 1,
49
+ "B-comp.qualifier": 9,
50
+ "B-comp.title": 2,
51
+ "I-comp.demonym": 7,
52
+ "I-comp.function": 3,
53
+ "I-comp.name": 0,
54
+ "I-comp.qualifier": 10,
55
+ "I-comp.title": 4,
56
+ "O": 6
57
+ },
58
+ "NE-FINE-LIT": {
59
+ "B-loc.add.elec": 32,
60
+ "B-loc.add.phys": 5,
61
+ "B-loc.adm.nat": 34,
62
+ "B-loc.adm.reg": 39,
63
+ "B-loc.adm.sup": 12,
64
+ "B-loc.adm.town": 33,
65
+ "B-loc.fac": 36,
66
+ "B-loc.oro": 19,
67
+ "B-loc.phys.geo": 13,
68
+ "B-loc.phys.hydro": 28,
69
+ "B-loc.unk": 4,
70
+ "B-org.adm": 3,
71
+ "B-org.ent": 24,
72
+ "B-org.ent.pressagency": 37,
73
+ "B-pers.coll": 9,
74
+ "B-pers.ind": 0,
75
+ "B-pers.ind.articleauthor": 20,
76
+ "B-prod.doctr": 2,
77
+ "B-prod.media": 10,
78
+ "B-time.date.abs": 23,
79
+ "I-loc.add.elec": 22,
80
+ "I-loc.add.phys": 6,
81
+ "I-loc.adm.nat": 11,
82
+ "I-loc.adm.reg": 35,
83
+ "I-loc.adm.sup": 15,
84
+ "I-loc.adm.town": 8,
85
+ "I-loc.fac": 27,
86
+ "I-loc.oro": 21,
87
+ "I-loc.phys.geo": 25,
88
+ "I-loc.phys.hydro": 17,
89
+ "I-loc.unk": 40,
90
+ "I-org.adm": 29,
91
+ "I-org.ent": 1,
92
+ "I-org.ent.pressagency": 14,
93
+ "I-pers.coll": 26,
94
+ "I-pers.ind": 16,
95
+ "I-pers.ind.articleauthor": 31,
96
+ "I-prod.doctr": 30,
97
+ "I-prod.media": 38,
98
+ "I-time.date.abs": 7,
99
+ "O": 18
100
+ },
101
+ "NE-FINE-METO": {
102
+ "B-loc.adm.town": 6,
103
+ "B-loc.fac": 3,
104
+ "B-loc.oro": 5,
105
+ "B-org.adm": 1,
106
+ "B-org.ent": 7,
107
+ "B-time.date.abs": 9,
108
+ "I-loc.fac": 8,
109
+ "I-org.adm": 2,
110
+ "I-org.ent": 0,
111
+ "O": 4
112
+ },
113
+ "NE-NESTED": {
114
+ "B-loc.adm.nat": 13,
115
+ "B-loc.adm.reg": 15,
116
+ "B-loc.adm.sup": 10,
117
+ "B-loc.adm.town": 9,
118
+ "B-loc.fac": 18,
119
+ "B-loc.oro": 17,
120
+ "B-loc.phys.geo": 11,
121
+ "B-loc.phys.hydro": 1,
122
+ "B-org.adm": 4,
123
+ "B-org.ent": 20,
124
+ "B-pers.coll": 7,
125
+ "B-pers.ind": 2,
126
+ "B-prod.media": 23,
127
+ "I-loc.adm.nat": 8,
128
+ "I-loc.adm.reg": 14,
129
+ "I-loc.adm.town": 6,
130
+ "I-loc.fac": 0,
131
+ "I-loc.oro": 19,
132
+ "I-loc.phys.geo": 21,
133
+ "I-loc.phys.hydro": 22,
134
+ "I-org.adm": 5,
135
+ "I-org.ent": 3,
136
+ "I-pers.ind": 12,
137
+ "I-prod.media": 24,
138
+ "O": 16
139
+ }
140
+ },
141
+ "layer_norm_eps": 1e-12,
142
+ "max_position_embeddings": 512,
143
+ "model_type": "stacked_bert",
144
+ "num_attention_heads": 8,
145
+ "num_hidden_layers": 8,
146
+ "pad_token_id": 0,
147
+ "position_embedding_type": "absolute",
148
+ "pretrained_config": {
149
+ "_name_or_path": "dbmdz/bert-medium-historic-multilingual-cased",
150
+ "add_cross_attention": false,
151
+ "architectures": [
152
+ "BertForMaskedLM"
153
+ ],
154
+ "attention_probs_dropout_prob": 0.1,
155
+ "bad_words_ids": null,
156
+ "begin_suppress_tokens": null,
157
+ "bos_token_id": null,
158
+ "chunk_size_feed_forward": 0,
159
+ "classifier_dropout": null,
160
+ "cross_attention_hidden_size": null,
161
+ "decoder_start_token_id": null,
162
+ "diversity_penalty": 0.0,
163
+ "do_sample": false,
164
+ "early_stopping": false,
165
+ "encoder_no_repeat_ngram_size": 0,
166
+ "eos_token_id": null,
167
+ "exponential_decay_length_penalty": null,
168
+ "finetuning_task": null,
169
+ "forced_bos_token_id": null,
170
+ "forced_eos_token_id": null,
171
+ "hidden_act": "gelu",
172
+ "hidden_dropout_prob": 0.1,
173
+ "hidden_size": 512,
174
+ "id2label": {
175
+ "0": "LABEL_0",
176
+ "1": "LABEL_1"
177
+ },
178
+ "initializer_range": 0.02,
179
+ "intermediate_size": 2048,
180
+ "is_decoder": false,
181
+ "is_encoder_decoder": false,
182
+ "label2id": {
183
+ "LABEL_0": 0,
184
+ "LABEL_1": 1
185
+ },
186
+ "layer_norm_eps": 1e-12,
187
+ "length_penalty": 1.0,
188
+ "max_length": 20,
189
+ "max_position_embeddings": 512,
190
+ "min_length": 0,
191
+ "model_type": "bert",
192
+ "no_repeat_ngram_size": 0,
193
+ "num_attention_heads": 8,
194
+ "num_beam_groups": 1,
195
+ "num_beams": 1,
196
+ "num_hidden_layers": 8,
197
+ "num_return_sequences": 1,
198
+ "output_attentions": false,
199
+ "output_hidden_states": false,
200
+ "output_scores": false,
201
+ "pad_token_id": 0,
202
+ "position_embedding_type": "absolute",
203
+ "prefix": null,
204
+ "problem_type": null,
205
+ "pruned_heads": {},
206
+ "remove_invalid_values": false,
207
+ "repetition_penalty": 1.0,
208
+ "return_dict": true,
209
+ "return_dict_in_generate": false,
210
+ "sep_token_id": null,
211
+ "suppress_tokens": null,
212
+ "task_specific_params": null,
213
+ "temperature": 1.0,
214
+ "tf_legacy_loss": false,
215
+ "tie_encoder_decoder": false,
216
+ "tie_word_embeddings": true,
217
+ "tokenizer_class": null,
218
+ "top_k": 50,
219
+ "top_p": 1.0,
220
+ "torch_dtype": null,
221
+ "torchscript": false,
222
+ "type_vocab_size": 2,
223
+ "typical_p": 1.0,
224
+ "use_bfloat16": false,
225
+ "use_cache": true,
226
+ "vocab_size": 32000
227
+ },
228
+ "torch_dtype": "float32",
229
+ "transformers_version": "4.40.0.dev0",
230
+ "type_vocab_size": 2,
231
+ "use_cache": true,
232
+ "vocab_size": 32000
233
+ }
old/configuration_stacked.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ import torch
3
+
4
+
5
+ class ImpressoConfig(PretrainedConfig):
6
+ model_type = "stacked_bert"
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=30522,
11
+ hidden_size=768,
12
+ num_hidden_layers=12,
13
+ num_attention_heads=12,
14
+ intermediate_size=3072,
15
+ hidden_act="gelu",
16
+ hidden_dropout_prob=0.1,
17
+ attention_probs_dropout_prob=0.1,
18
+ max_position_embeddings=512,
19
+ type_vocab_size=2,
20
+ initializer_range=0.02,
21
+ layer_norm_eps=1e-12,
22
+ pad_token_id=0,
23
+ position_embedding_type="absolute",
24
+ use_cache=True,
25
+ classifier_dropout=None,
26
+ pretrained_config=None,
27
+ values_override=None,
28
+ label_map=None,
29
+ **kwargs,
30
+ ):
31
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
32
+
33
+ self.vocab_size = vocab_size
34
+ self.hidden_size = hidden_size
35
+ self.num_hidden_layers = num_hidden_layers
36
+ self.num_attention_heads = num_attention_heads
37
+ self.hidden_act = hidden_act
38
+ self.intermediate_size = intermediate_size
39
+ self.hidden_dropout_prob = hidden_dropout_prob
40
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
41
+ self.max_position_embeddings = max_position_embeddings
42
+ self.type_vocab_size = type_vocab_size
43
+ self.initializer_range = initializer_range
44
+ self.layer_norm_eps = layer_norm_eps
45
+ self.position_embedding_type = position_embedding_type
46
+ self.use_cache = use_cache
47
+ self.classifier_dropout = classifier_dropout
48
+ self.pretrained_config = pretrained_config
49
+ self.label_map = label_map
50
+
51
+ self.values_override = values_override or {}
52
+ self.outputs = {
53
+ "logits": {"shape": [None, None, self.hidden_size], "dtype": "float32"}
54
+ }
55
+
56
+ @classmethod
57
+ def is_torch_support_available(cls):
58
+ """
59
+ Indicate whether Torch support is available for this configuration.
60
+ Required for compatibility with certain parts of the Transformers library.
61
+ """
62
+ return True
63
+
64
+ @classmethod
65
+ def patch_ops(self):
66
+ """
67
+ A method required by some Hugging Face utilities to modify operator mappings.
68
+ Currently, it performs no operation and is included for compatibility.
69
+ Args:
70
+ ops: A dictionary of operations to potentially patch.
71
+ Returns:
72
+ The (unmodified) ops dictionary.
73
+ """
74
+ return None
75
+
76
+ def generate_dummy_inputs(self, tokenizer, batch_size=1, seq_length=8, framework="pt"):
77
+ """
78
+ Generate dummy inputs for testing or export.
79
+ Args:
80
+ tokenizer: The tokenizer used to tokenize inputs.
81
+ batch_size: Number of input samples in the batch.
82
+ seq_length: Length of each sequence.
83
+ framework: Framework ("pt" for PyTorch, "tf" for TensorFlow).
84
+ Returns:
85
+ Dummy inputs as a dictionary.
86
+ """
87
+ if framework == "pt":
88
+ input_ids = torch.randint(
89
+ low=0,
90
+ high=self.vocab_size,
91
+ size=(batch_size, seq_length),
92
+ dtype=torch.long
93
+ )
94
+ attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
95
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
96
+ else:
97
+ raise ValueError("Framework '{}' not supported.".format(framework))
98
+
99
+
100
+ # Register the configuration with the transformers library
101
+ ImpressoConfig.register_for_auto_class()
old/generic_ner.py ADDED
@@ -0,0 +1,789 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from transformers import Pipeline
3
+ import numpy as np
4
+ import torch
5
+ import nltk
6
+
7
+ nltk.download("averaged_perceptron_tagger")
8
+ nltk.download("averaged_perceptron_tagger_eng")
9
+ nltk.download("stopwords")
10
+ from nltk.chunk import conlltags2tree
11
+ from nltk import pos_tag
12
+ from nltk.tree import Tree
13
+ import torch.nn.functional as F
14
+ import re, string
15
+
16
+ stop_words = set(nltk.corpus.stopwords.words("english"))
17
+ DEBUG = False
18
+ punctuation = (
19
+ string.punctuation
20
+ + "«»—…“”"
21
+ + "—."
22
+ + "–"
23
+ + "’"
24
+ + "‘"
25
+ + "´"
26
+ + "•"
27
+ + "°"
28
+ + "»"
29
+ + "“"
30
+ + "”"
31
+ + "–"
32
+ + "—"
33
+ + "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
34
+ )
35
+
36
+ # List of additional "strange" punctuation marks
37
+ # additional_punctuation = "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
38
+
39
+
40
+ WHITESPACE_RULES = {
41
+ "fr": {
42
+ "pct_no_ws_before": [".", ",", ")", "]", "}", "°", "...", ".-", "%"],
43
+ "pct_no_ws_after": ["(", "[", "{"],
44
+ "pct_no_ws_before_after": ["'", "-"],
45
+ "pct_number": [".", ","],
46
+ },
47
+ "de": {
48
+ "pct_no_ws_before": [
49
+ ".",
50
+ ",",
51
+ ")",
52
+ "]",
53
+ "}",
54
+ "°",
55
+ "...",
56
+ "?",
57
+ "!",
58
+ ":",
59
+ ";",
60
+ ".-",
61
+ "%",
62
+ ],
63
+ "pct_no_ws_after": ["(", "[", "{"],
64
+ "pct_no_ws_before_after": ["'", "-"],
65
+ "pct_number": [".", ","],
66
+ },
67
+ "other": {
68
+ "pct_no_ws_before": [
69
+ ".",
70
+ ",",
71
+ ")",
72
+ "]",
73
+ "}",
74
+ "°",
75
+ "...",
76
+ "?",
77
+ "!",
78
+ ":",
79
+ ";",
80
+ ".-",
81
+ "%",
82
+ ],
83
+ "pct_no_ws_after": ["(", "[", "{"],
84
+ "pct_no_ws_before_after": ["'", "-"],
85
+ "pct_number": [".", ","],
86
+ },
87
+ }
88
+
89
+
90
+ # def tokenize(text: str, language: str = "other") -> list[str]:
91
+ # """Apply whitespace rules to the given text and language, separating it into tokens.
92
+ #
93
+ # Args:
94
+ # text (str): The input text to separate into a list of tokens.
95
+ # language (str): Language of the text.
96
+ #
97
+ # Returns:
98
+ # list[str]: List of tokens with punctuation as separate tokens.
99
+ # """
100
+ # # text = add_spaces_around_punctuation(text)
101
+ # if not text:
102
+ # return []
103
+ #
104
+ # if language not in WHITESPACE_RULES:
105
+ # # Default behavior for languages without specific rules:
106
+ # # tokenize using standard whitespace splitting
107
+ # language = "other"
108
+ #
109
+ # wsrules = WHITESPACE_RULES[language]
110
+ # tokenized_text = []
111
+ # current_token = ""
112
+ #
113
+ # for char in text:
114
+ # if char in wsrules["pct_no_ws_before_after"]:
115
+ # if current_token:
116
+ # tokenized_text.append(current_token)
117
+ # tokenized_text.append(char)
118
+ # current_token = ""
119
+ # elif char in wsrules["pct_no_ws_before"] or char in wsrules["pct_no_ws_after"]:
120
+ # if current_token:
121
+ # tokenized_text.append(current_token)
122
+ # tokenized_text.append(char)
123
+ # current_token = ""
124
+ # elif char.isspace():
125
+ # if current_token:
126
+ # tokenized_text.append(current_token)
127
+ # current_token = ""
128
+ # else:
129
+ # current_token += char
130
+ #
131
+ # if current_token:
132
+ # tokenized_text.append(current_token)
133
+ #
134
+ # return tokenized_text
135
+
136
+
137
+ def normalize_text(text):
138
+ # Remove spaces and tabs for the search but keep newline characters
139
+ return re.sub(r"[ \t]+", "", text)
140
+
141
+
142
+ def find_entity_indices(article_text, search_text):
143
+ # Normalize texts by removing spaces and tabs
144
+ normalized_article = normalize_text(article_text)
145
+ normalized_search = normalize_text(search_text)
146
+
147
+ # Initialize a list to hold all start and end indices
148
+ indices = []
149
+
150
+ # Find all occurrences of the search text in the normalized article text
151
+ start_index = 0
152
+ while True:
153
+ start_index = normalized_article.find(normalized_search, start_index)
154
+ if start_index == -1:
155
+ break
156
+
157
+ # Calculate the actual start and end indices in the original article text
158
+ original_chars = 0
159
+ original_start_index = 0
160
+ for i in range(start_index):
161
+ while article_text[original_start_index] in (" ", "\t"):
162
+ original_start_index += 1
163
+ if article_text[original_start_index] not in (" ", "\t", "\n"):
164
+ original_chars += 1
165
+ original_start_index += 1
166
+
167
+ original_end_index = original_start_index
168
+ search_chars = 0
169
+ while search_chars < len(normalized_search):
170
+ if article_text[original_end_index] not in (" ", "\t", "\n"):
171
+ search_chars += 1
172
+ original_end_index += 1 # Increment to include the last character
173
+
174
+ # Append the found indices to the list
175
+ if article_text[original_start_index] == " ":
176
+ original_start_index += 1
177
+ indices.append((original_start_index, original_end_index))
178
+
179
+ # Move start_index to the next position to continue searching
180
+ start_index += 1
181
+
182
+ return indices
183
+
184
+
185
+ def get_entities(tokens, tags, confidences, text):
186
+
187
+ tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
188
+ pos_tags = [pos for token, pos in pos_tag(tokens)]
189
+
190
+ for i in range(1, len(tags)):
191
+ # If a 'B-' tag is followed by another 'B-' without an 'O' in between, change the second to 'I-'
192
+ if tags[i].startswith("B-") and tags[i - 1].startswith("I-"):
193
+ tags[i] = "I-" + tags[i][2:] # Change 'B-' to 'I-' for the same entity type
194
+
195
+ conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
196
+ ne_tree = conlltags2tree(conlltags)
197
+
198
+ entities = []
199
+ idx: int = 0
200
+ already_done = []
201
+ for subtree in ne_tree:
202
+ # skipping 'O' tags
203
+ if isinstance(subtree, Tree):
204
+ original_label = subtree.label()
205
+ original_string = " ".join([token for token, pos in subtree.leaves()])
206
+
207
+ for indices in find_entity_indices(text, original_string):
208
+ entity_start_position = indices[0]
209
+ entity_end_position = indices[1]
210
+ if (
211
+ "_".join(
212
+ [original_label, original_string, str(entity_start_position)]
213
+ )
214
+ in already_done
215
+ ):
216
+ continue
217
+ else:
218
+ already_done.append(
219
+ "_".join(
220
+ [
221
+ original_label,
222
+ original_string,
223
+ str(entity_start_position),
224
+ ]
225
+ )
226
+ )
227
+ if len(text[entity_start_position:entity_end_position].strip()) < len(
228
+ text[entity_start_position:entity_end_position]
229
+ ):
230
+ entity_start_position = (
231
+ entity_start_position
232
+ + len(text[entity_start_position:entity_end_position])
233
+ - len(text[entity_start_position:entity_end_position].strip())
234
+ )
235
+
236
+ entities.append(
237
+ {
238
+ "type": original_label,
239
+ "confidence_ner": round(
240
+ np.average(confidences[idx : idx + len(subtree)]), 2
241
+ ),
242
+ "index": (idx, idx + len(subtree)),
243
+ "surface": text[
244
+ entity_start_position:entity_end_position
245
+ ], # original_string,
246
+ "lOffset": entity_start_position,
247
+ "rOffset": entity_end_position,
248
+ }
249
+ )
250
+
251
+ idx += len(subtree)
252
+
253
+ # Update the current character position
254
+ # We add the length of the original string + 1 (for the space)
255
+ else:
256
+ token, pos = subtree
257
+ # If it's not a named entity, we still need to update the character
258
+ # position
259
+ idx += 1
260
+
261
+ return entities
262
+
263
+
264
+ def realign(word_ids, tokens, out_label_preds, softmax_scores, tokenizer, reverted_label_map):
265
+ preds_list, words_list, confidence_list = [], [], []
266
+
267
+ seen_word_ids = set()
268
+ for i, word_id in enumerate(word_ids):
269
+ if word_id is None or word_id in seen_word_ids:
270
+ continue # skip special tokens or repeated subwords
271
+
272
+ seen_word_ids.add(word_id)
273
+
274
+ try:
275
+ preds_list.append(reverted_label_map[out_label_preds[i]])
276
+ confidence_list.append(max(softmax_scores[i]))
277
+ except Exception:
278
+ preds_list.append("O")
279
+ confidence_list.append(0.0)
280
+
281
+ words_list.append(tokens[word_id]) # original word list index
282
+
283
+ return words_list, preds_list, confidence_list
284
+
285
+ def add_spaces_around_punctuation(text):
286
+ # Add a space before and after all punctuation
287
+ all_punctuation = string.punctuation + punctuation
288
+ return re.sub(r"([{}])".format(re.escape(all_punctuation)), r" \1 ", text)
289
+
290
+
291
+ def attach_comp_to_closest(entities):
292
+ # Define valid entity types that can receive a "comp.function" or "comp.name" attachment
293
+ valid_entity_types = {"org", "pers", "org.ent", "pers.ind"}
294
+
295
+ # Separate "comp.function" and "comp.name" entities from other entities
296
+ comp_entities = [ent for ent in entities if ent["type"].startswith("comp")]
297
+ other_entities = [ent for ent in entities if not ent["type"].startswith("comp")]
298
+
299
+ for comp_entity in comp_entities:
300
+ closest_entity = None
301
+ min_distance = float("inf")
302
+
303
+ # Find the closest non-"comp" entity that is valid for attaching
304
+ for other_entity in other_entities:
305
+ # Calculate distance between the comp entity and the other entity
306
+ if comp_entity["lOffset"] > other_entity["rOffset"]:
307
+ distance = comp_entity["lOffset"] - other_entity["rOffset"]
308
+ elif comp_entity["rOffset"] < other_entity["lOffset"]:
309
+ distance = other_entity["lOffset"] - comp_entity["rOffset"]
310
+ else:
311
+ distance = 0 # They overlap or touch
312
+
313
+ # Ensure the entity type is valid and check for minimal distance
314
+ if (
315
+ distance < min_distance
316
+ and other_entity["type"].split(".")[0] in valid_entity_types
317
+ ):
318
+ min_distance = distance
319
+ closest_entity = other_entity
320
+
321
+ # Attach the "comp.function" or "comp.name" if a valid entity is found
322
+ if closest_entity:
323
+ suffix = comp_entity["type"].split(".")[
324
+ -1
325
+ ] # Extract the suffix (e.g., 'name', 'function')
326
+ closest_entity[suffix] = comp_entity["surface"] # Attach the text
327
+
328
+ return other_entities
329
+
330
+
331
+ def conflicting_context(comp_entity, target_entity):
332
+ """
333
+ Determines if there is a conflict between the comp_entity and the target entity.
334
+ Prevents incorrect name and function attachments by using a rule-based approach.
335
+ """
336
+ # Case 1: Check for correct function attachment to person or organization entities
337
+ if comp_entity["type"].startswith("comp.function"):
338
+ if not ("pers" in target_entity["type"] or "org" in target_entity["type"]):
339
+ return True # Conflict: Function should only attach to persons or organizations
340
+
341
+ # Case 2: Avoid attaching comp.* entities to non-person, non-organization types (like locations)
342
+ if "loc" in target_entity["type"]:
343
+ return True # Conflict: comp.* entities should not attach to locations or similar types
344
+
345
+ return False # No conflict
346
+
347
+
348
+ def extract_name_from_text(text, partial_name):
349
+ """
350
+ Extracts the full name from the entity's text based on the partial name.
351
+ This function assumes that the full name starts with capitalized letters and does not
352
+ include any words that come after the partial name.
353
+ """
354
+ # Split the text and partial name into words
355
+ words = text.split()
356
+ partial_words = partial_name.split()
357
+
358
+ if DEBUG:
359
+ print("text:", text)
360
+ if DEBUG:
361
+ print("partial_name:", partial_name)
362
+
363
+ # Find the position of the partial name in the word list
364
+ for i, word in enumerate(words):
365
+ if DEBUG:
366
+ print(words, "---", words[i : i + len(partial_words)])
367
+ if words[i : i + len(partial_words)] == partial_words:
368
+ # Initialize full name with the partial name
369
+ full_name = partial_words[:]
370
+
371
+ if DEBUG:
372
+ print("full_name:", full_name)
373
+
374
+ # Check previous words and only add capitalized words (skip lowercase words)
375
+ j = i - 1
376
+ while j >= 0 and words[j][0].isupper():
377
+ full_name.insert(0, words[j])
378
+ j -= 1
379
+ if DEBUG:
380
+ print("full_name:", full_name)
381
+
382
+ # Return only the full name up to the partial name (ignore words after the name)
383
+ return " ".join(full_name).strip() # Join the words to form the full name
384
+
385
+ # If not found, return the original text (as a fallback)
386
+ return text.strip()
387
+
388
+
389
+ def repair_names_in_entities(entities):
390
+ """
391
+ This function repairs the names in the entities by extracting the full name
392
+ from the text of the entity if a partial name (e.g., 'Washington') is incorrectly attached.
393
+ """
394
+ for entity in entities:
395
+ if "name" in entity and "pers" in entity["type"]:
396
+ name = entity["name"]
397
+ text = entity["surface"]
398
+
399
+ # Check if the attached name is part of the entity's text
400
+ if name in text:
401
+ # Extract the full name from the text by splitting around the attached name
402
+ full_name = extract_name_from_text(entity["surface"], name)
403
+ entity["name"] = (
404
+ full_name # Replace the partial name with the full name
405
+ )
406
+ # if "name" not in entity:
407
+ # entity["name"] = entity["surface"]
408
+
409
+ return entities
410
+
411
+
412
+ def clean_coarse_entities(entities):
413
+ """
414
+ This function removes entities that are not useful for the NEL process.
415
+ """
416
+ # Define a set of entity types that are considered useful for NEL
417
+ useful_types = {
418
+ "pers", # Person
419
+ "loc", # Location
420
+ "org", # Organization
421
+ "date", # Product
422
+ "time", # Time
423
+ }
424
+
425
+ # Filter out entities that are not in the useful_types set unless they are comp.* entities
426
+ cleaned_entities = [
427
+ entity
428
+ for entity in entities
429
+ if entity["type"] in useful_types or "comp" in entity["type"]
430
+ ]
431
+
432
+ return cleaned_entities
433
+
434
+
435
+ def postprocess_entities(entities):
436
+ # Step 1: Filter entities with the same text, keeping the one with the most dots in the 'entity' field
437
+ entity_map = {}
438
+
439
+ # Loop over the entities and prioritize the one with the most dots
440
+ for entity in entities:
441
+ entity_text = entity["surface"]
442
+ num_dots = entity["type"].count(".")
443
+
444
+ # If the entity text is new, or this entity has more dots, update the map
445
+ if (
446
+ entity_text not in entity_map
447
+ or entity_map[entity_text]["type"].count(".") < num_dots
448
+ ):
449
+ entity_map[entity_text] = entity
450
+
451
+ # Collect the filtered entities from the map
452
+ filtered_entities = list(entity_map.values())
453
+
454
+ # Step 2: Attach "comp.function" entities to the closest other entities
455
+ filtered_entities = attach_comp_to_closest(filtered_entities)
456
+ if DEBUG:
457
+ print("After attach_comp_to_closest:", filtered_entities, "\n")
458
+ filtered_entities = repair_names_in_entities(filtered_entities)
459
+ if DEBUG:
460
+ print("After repair_names_in_entities:", filtered_entities, "\n")
461
+
462
+ # Step 3: Remove entities that are not useful for NEL
463
+ # filtered_entities = clean_coarse_entities(filtered_entities)
464
+
465
+ # filtered_entities = remove_blacklisted_entities(filtered_entities)
466
+
467
+ return filtered_entities
468
+
469
+
470
+ def remove_included_entities(entities):
471
+ # Loop through entities and remove those whose text is included in another with the same label
472
+ final_entities = []
473
+ for i, entity in enumerate(entities):
474
+ is_included = False
475
+ for other_entity in entities:
476
+ if entity["surface"] != other_entity["surface"]:
477
+ if "comp" in other_entity["type"]:
478
+ # Check if entity's text is a substring of another entity's text
479
+ if entity["surface"] in other_entity["surface"]:
480
+ is_included = True
481
+ break
482
+ elif (
483
+ entity["type"].split(".")[0] in other_entity["type"].split(".")[0]
484
+ or other_entity["type"].split(".")[0]
485
+ in entity["type"].split(".")[0]
486
+ ):
487
+ if entity["surface"] in other_entity["surface"]:
488
+ is_included = True
489
+ if not is_included:
490
+ final_entities.append(entity)
491
+ return final_entities
492
+
493
+
494
+ def refine_entities_with_coarse(all_entities, coarse_entities):
495
+ """
496
+ Looks through all entities and refines them based on the coarse entities.
497
+ If a surface match is found in the coarse entities and the types match,
498
+ the entity's confidence_ner and type are updated based on the coarse entity.
499
+ """
500
+ # Create a dictionary for coarse entities based on surface and type for quick lookup
501
+ coarse_lookup = {}
502
+ for coarse_entity in coarse_entities:
503
+ key = (coarse_entity["surface"], coarse_entity["type"].split(".")[0])
504
+ coarse_lookup[key] = coarse_entity
505
+
506
+ # Iterate through all entities and compare with the coarse entities
507
+ for entity in all_entities:
508
+ key = (
509
+ entity["surface"],
510
+ entity["type"].split(".")[0],
511
+ ) # Use the coarse type for comparison
512
+
513
+ if key in coarse_lookup:
514
+ coarse_entity = coarse_lookup[key]
515
+ # If a match is found, update the confidence_ner and type in the entity
516
+ if entity["confidence_ner"] < coarse_entity["confidence_ner"]:
517
+ entity["confidence_ner"] = coarse_entity["confidence_ner"]
518
+ entity["type"] = coarse_entity[
519
+ "type"
520
+ ] # Update the type if the confidence is higher
521
+
522
+ # No need to append to refined_entities, we're modifying in place
523
+ for entity in all_entities:
524
+ entity["type"] = entity["type"].split(".")[0]
525
+ return all_entities
526
+
527
+
528
+ def remove_trailing_stopwords(entities):
529
+ """
530
+ This function removes stopwords and punctuation from both the beginning and end of each entity's text
531
+ and repairs the lOffset and rOffset accordingly.
532
+ """
533
+ if DEBUG:
534
+ print(f"Initial entities in remove_trailing_stopwords: {len(entities)}")
535
+ new_entities = []
536
+ for entity in entities:
537
+ if "comp" not in entity["type"]:
538
+ entity_text = entity["surface"]
539
+ original_len = len(entity_text)
540
+
541
+ # Initial offsets
542
+ lOffset = entity.get("lOffset", 0)
543
+ rOffset = entity.get("rOffset", original_len)
544
+
545
+ # Remove stopwords and punctuation from the beginning
546
+ # print('----', entity_text)
547
+ if len(entity_text.split()) < 1:
548
+ continue
549
+ while entity_text and (
550
+ entity_text.split()[0].lower() in stop_words
551
+ or entity_text[0] in punctuation
552
+ ):
553
+ if entity_text.split()[0].lower() in stop_words:
554
+ stopword_len = (
555
+ len(entity_text.split()[0]) + 1
556
+ ) # Adjust length for stopword and following space
557
+ entity_text = entity_text[stopword_len:] # Remove leading stopword
558
+ lOffset += stopword_len # Adjust the left offset
559
+ if DEBUG:
560
+ print(
561
+ f"Removed leading stopword from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
562
+ )
563
+ elif entity_text[0] in punctuation:
564
+ entity_text = entity_text[1:] # Remove leading punctuation
565
+ lOffset += 1 # Adjust the left offset
566
+ if DEBUG:
567
+ print(
568
+ f"Removed leading punctuation from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
569
+ )
570
+
571
+ # Remove stopwords and punctuation from the end
572
+ if len(entity_text.strip()) > 1:
573
+ while (
574
+ entity_text.strip().split()
575
+ and (
576
+ entity_text.strip().split()[-1].lower() in stop_words
577
+ or entity_text[-1] in punctuation
578
+ )
579
+ ):
580
+ if entity_text.strip().split() and entity_text.strip().split()[-1].lower() in stop_words:
581
+ stopword_len = len(entity_text.strip().split()[-1]) + 1 # account for space
582
+ entity_text = entity_text[:-stopword_len]
583
+ rOffset -= stopword_len
584
+ if DEBUG:
585
+ print(
586
+ f"Removed trailing stopword from entity: {entity['surface']} --> {entity_text} ({entity['type']})"
587
+ )
588
+ if entity_text and entity_text[-1] in punctuation:
589
+ entity_text = entity_text[:-1]
590
+ rOffset -= 1
591
+ if DEBUG:
592
+ print(
593
+ f"Removed trailing punctuation from entity: {entity['surface']} --> {entity_text} ({entity['type']})"
594
+ )
595
+
596
+ # Skip certain entities based on rules
597
+ if entity_text in string.punctuation:
598
+ if DEBUG:
599
+ print(f"Skipping entity: {entity_text}")
600
+ # entities.remove(entity)
601
+ continue
602
+ # check now if its in stopwords
603
+ if entity_text.lower() in stop_words:
604
+ if DEBUG:
605
+ print(f"Skipping entity: {entity_text}")
606
+ # entities.remove(entity)
607
+ continue
608
+ # check now if the entire entity is a list of stopwords:
609
+ if all([word.lower() in stop_words for word in entity_text.split()]):
610
+ if DEBUG:
611
+ print(f"Skipping entity: {entity_text}")
612
+ # entities.remove(entity)
613
+ continue
614
+ # Check if the entire entity is made up of stopwords characters
615
+ if all(
616
+ [char.lower() in stop_words for char in entity_text if char.isalpha()]
617
+ ):
618
+ if DEBUG:
619
+ print(
620
+ f"Skipping entity: {entity_text} (all characters are stopwords)"
621
+ )
622
+ # entities.remove(entity)
623
+ continue
624
+ # check now if all entity is in a list of punctuation
625
+ if all([word in string.punctuation for word in entity_text.split()]):
626
+ if DEBUG:
627
+ print(
628
+ f"Skipping entity: {entity_text} (all characters are punctuation)"
629
+ )
630
+ # entities.remove(entity)
631
+ continue
632
+ if all(
633
+ [
634
+ char.lower() in string.punctuation
635
+ for char in entity_text
636
+ if char.isalpha()
637
+ ]
638
+ ):
639
+ if DEBUG:
640
+ print(
641
+ f"Skipping entity: {entity_text} (all characters are punctuation)"
642
+ )
643
+ # entities.remove(entity)
644
+ continue
645
+
646
+ # if it's a number and "time" no in it, then continue
647
+ if entity_text.isdigit() and "time" not in entity["type"]:
648
+ if DEBUG:
649
+ print(f"Skipping entity: {entity_text}")
650
+ # entities.remove(entity)
651
+ continue
652
+
653
+ if entity_text.startswith(" "):
654
+ entity_text = entity_text[1:]
655
+ # update lOffset, rOffset
656
+ lOffset += 1
657
+ if entity_text.endswith(" "):
658
+ entity_text = entity_text[:-1]
659
+ # update lOffset, rOffset
660
+ rOffset -= 1
661
+
662
+ # Update the entity surface and offsets
663
+ entity["surface"] = entity_text
664
+ entity["lOffset"] = lOffset
665
+ entity["rOffset"] = rOffset
666
+
667
+ # Remove the entity if the surface is empty after cleaning
668
+ if len(entity["surface"].strip()) == 0:
669
+ if DEBUG:
670
+ print(f"Deleted entity: {entity['surface']}")
671
+ # entities.remove(entity)
672
+ else:
673
+ new_entities.append(entity)
674
+ else:
675
+ new_entities.append(entity)
676
+ if DEBUG:
677
+ print(f"Remained entities in remove_trailing_stopwords: {len(new_entities)}")
678
+ return new_entities
679
+
680
+
681
+ class MultitaskTokenClassificationPipeline(Pipeline):
682
+
683
+ def _sanitize_parameters(self, **kwargs):
684
+ preprocess_kwargs = {}
685
+ if "text" in kwargs:
686
+ preprocess_kwargs["text"] = kwargs["text"]
687
+ if "tokens" in kwargs:
688
+ preprocess_kwargs["tokens"] = kwargs["tokens"]
689
+ self.label_map = self.model.config.label_map
690
+ self.id2label = {
691
+ task: {id_: label for label, id_ in labels.items()}
692
+ for task, labels in self.label_map.items()
693
+ }
694
+ return preprocess_kwargs, {}, {}
695
+
696
+ def preprocess(self, text, **kwargs):
697
+
698
+ tokens = kwargs["tokens"]
699
+ tokenized_inputs = self.tokenizer(
700
+ tokens, # a list of strings
701
+ is_split_into_words=True,
702
+ padding="max_length",
703
+ truncation=True,
704
+ max_length=512,
705
+ )
706
+ word_ids = tokenized_inputs.word_ids()
707
+
708
+ return tokenized_inputs, word_ids, text, tokens
709
+
710
+ def _forward(self, inputs):
711
+ inputs, word_ids, text, tokens = inputs
712
+
713
+ input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
714
+ self.model.device
715
+ )
716
+ attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
717
+ self.model.device
718
+ )
719
+ with torch.no_grad():
720
+ outputs = self.model(input_ids, attention_mask)
721
+ return outputs, word_ids, text, tokens
722
+
723
+ def is_within(self, entity1, entity2):
724
+ """Check if entity1 is fully within the bounds of entity2."""
725
+ return (
726
+ entity1["lOffset"] >= entity2["lOffset"]
727
+ and entity1["rOffset"] <= entity2["rOffset"]
728
+ )
729
+
730
+ def postprocess(self, outputs, **kwargs):
731
+ """
732
+ Postprocess the outputs of the model
733
+ :param outputs:
734
+ :param kwargs:
735
+ :return:
736
+ """
737
+ tokens_result, word_ids, text, tokens = outputs
738
+
739
+ predictions = {}
740
+ confidence_scores = {}
741
+ for task, logits in tokens_result.logits.items():
742
+ predictions[task] = torch.argmax(logits, dim=-1).tolist()[0]
743
+ confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0]
744
+
745
+ entities = {}
746
+ for task in predictions.keys():
747
+ words_list, preds_list, confidence_list = realign(
748
+ word_ids,
749
+ tokens,
750
+ predictions[task],
751
+ confidence_scores[task],
752
+ self.tokenizer,
753
+ self.id2label[task],
754
+ )
755
+
756
+ entities[task] = get_entities(words_list, preds_list, confidence_list, text)
757
+
758
+ # add titles to comp entities
759
+ # from pprint import pprint
760
+
761
+ # print("Before:")
762
+ # pprint(entities)
763
+
764
+ all_entities = []
765
+ coarse_entities = []
766
+ for key in entities:
767
+ if key in ["NE-COARSE-LIT"]:
768
+ coarse_entities = entities[key]
769
+ all_entities.extend(entities[key])
770
+
771
+ if DEBUG:
772
+ print(all_entities)
773
+ # print("After remove_included_entities:")
774
+ all_entities = remove_included_entities(all_entities)
775
+ if DEBUG:
776
+ print("After remove_included_entities:", all_entities)
777
+ all_entities = remove_trailing_stopwords(all_entities)
778
+ if DEBUG:
779
+ print("After remove_trailing_stopwords:", all_entities)
780
+ all_entities = postprocess_entities(all_entities)
781
+ if DEBUG:
782
+ print("After postprocess_entities:", all_entities)
783
+ all_entities = refine_entities_with_coarse(all_entities, coarse_entities)
784
+ if DEBUG:
785
+ print("After refine_entities_with_coarse:", all_entities)
786
+ # print("After attach_comp_to_closest:")
787
+ # pprint(all_entities)
788
+ # print("\n")
789
+ return all_entities
old/label_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"NE-COARSE-LIT": {"B-org": 0, "I-loc": 1, "I-org": 2, "O": 3, "B-prod": 4, "B-time": 5, "I-time": 6, "B-pers": 7, "B-loc": 8, "I-pers": 9, "I-prod": 10}, "NE-COARSE-METO": {"B-org": 0, "O": 1, "I-org": 2, "B-loc": 3, "I-loc": 4, "B-time": 5}, "NE-FINE-LIT": {"B-pers.ind": 0, "I-org.ent": 1, "B-prod.doctr": 2, "B-org.adm": 3, "B-loc.unk": 4, "B-loc.add.phys": 5, "I-loc.add.phys": 6, "I-time.date.abs": 7, "I-loc.adm.town": 8, "B-pers.coll": 9, "B-prod.media": 10, "I-loc.adm.nat": 11, "B-loc.adm.sup": 12, "B-loc.phys.geo": 13, "I-org.ent.pressagency": 14, "I-loc.adm.sup": 15, "I-pers.ind": 16, "I-loc.phys.hydro": 17, "O": 18, "B-loc.oro": 19, "B-pers.ind.articleauthor": 20, "I-loc.oro": 21, "I-loc.add.elec": 22, "B-time.date.abs": 23, "B-org.ent": 24, "I-loc.phys.geo": 25, "I-pers.coll": 26, "I-loc.fac": 27, "B-loc.phys.hydro": 28, "I-org.adm": 29, "I-prod.doctr": 30, "I-pers.ind.articleauthor": 31, "B-loc.add.elec": 32, "B-loc.adm.town": 33, "B-loc.adm.nat": 34, "I-loc.adm.reg": 35, "B-loc.fac": 36, "B-org.ent.pressagency": 37, "I-prod.media": 38, "B-loc.adm.reg": 39, "I-loc.unk": 40}, "NE-FINE-METO": {"I-org.ent": 0, "B-org.adm": 1, "I-org.adm": 2, "B-loc.fac": 3, "O": 4, "B-loc.oro": 5, "B-loc.adm.town": 6, "B-org.ent": 7, "I-loc.fac": 8, "B-time.date.abs": 9}, "NE-FINE-COMP": {"I-comp.name": 0, "B-comp.name": 1, "B-comp.title": 2, "I-comp.function": 3, "I-comp.title": 4, "B-comp.function": 5, "O": 6, "I-comp.demonym": 7, "B-comp.demonym": 8, "B-comp.qualifier": 9, "I-comp.qualifier": 10}, "NE-NESTED": {"I-loc.fac": 0, "B-loc.phys.hydro": 1, "B-pers.ind": 2, "I-org.ent": 3, "B-org.adm": 4, "I-org.adm": 5, "I-loc.adm.town": 6, "B-pers.coll": 7, "I-loc.adm.nat": 8, "B-loc.adm.town": 9, "B-loc.adm.sup": 10, "B-loc.phys.geo": 11, "I-pers.ind": 12, "B-loc.adm.nat": 13, "I-loc.adm.reg": 14, "B-loc.adm.reg": 15, "O": 16, "B-loc.oro": 17, "B-loc.fac": 18, "I-loc.oro": 19, "B-org.ent": 20, "I-loc.phys.geo": 21, "I-loc.phys.hydro": 22, "B-prod.media": 23, "I-prod.media": 
24}}
old/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03a807b124debff782406c816eacb7ced1f2e25b9a5198b27e1616a41faa0662
3
+ size 193971960
old/modeling_stacked.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.modeling_outputs import TokenClassifierOutput
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
5
+ from torch.nn import CrossEntropyLoss
6
+ from typing import Optional, Tuple, Union
7
+ import logging, json, os
8
+
9
+ from .configuration_stacked import ImpressoConfig
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
def get_info(label_map):
    """Map each task name to the number of labels it defines.

    Args:
        label_map: Mapping of task name -> label collection (e.g. a
            label->index dict or a list of label strings).

    Returns:
        dict: task name -> label count, one entry per task.
    """
    counts = {}
    for task_name, task_labels in label_map.items():
        counts[task_name] = len(task_labels)
    return counts
17
+
18
+
19
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """Multitask token-classification model on top of a shared pretrained encoder.

    A backbone encoder (loaded via ``AutoModel``) feeds two additional
    Transformer encoder layers, whose output is routed to one linear
    classifier per task (the keys of ``config.label_map``, e.g. the
    NE-COARSE/NE-FINE/NE-NESTED tasks). Losses of all supervised tasks are
    summed into a single scalar.
    """

    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        # Per-task label counts derived from the label map, e.g.
        # {"NE-COARSE-LIT": 11, ...}; also sizes the per-task classifiers.
        self.num_token_labels_dict = get_info(config.label_map)
        self.config = config

        # Shared backbone; `pretrained_config` is expected to carry the
        # original checkpoint path under "_name_or_path" — TODO confirm
        # against ImpressoConfig.
        self.bert = AutoModel.from_pretrained(
            config.pretrained_config["_name_or_path"], config=config.pretrained_config
        )
        # Dropout before classification; fall back to 0.1 when the config
        # carries no `classifier_dropout` attribute at all, otherwise prefer
        # it over the generic hidden dropout.
        if "classifier_dropout" not in config.__dict__:
            classifier_dropout = 0.1
        else:
            classifier_dropout = (
                config.classifier_dropout
                if config.classifier_dropout is not None
                else config.hidden_dropout_prob
            )
        self.dropout = nn.Dropout(classifier_dropout)

        # Additional transformer layers refining the backbone output.
        # NOTE(review): built without batch_first, hence the transpose
        # round-trip in `forward`.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_attention_heads
            ),
            num_layers=2,
        )

        # For token classification, create a classifier for each task.
        self.token_classifiers = nn.ModuleDict(
            {
                task: nn.Linear(config.hidden_size, num_labels)
                for task, num_labels in self.num_token_labels_dict.items()
            }
        )

        # Initialize weights and apply final processing (HF convention).
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,  # NOTE(review): accepted but unused; kept for HF API compatibility?
        token_labels: Optional[dict] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
            Labels for computing the token classification loss. Keys should match the tasks.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        bert_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }

        # Some backbones (LLaMA-, DeBERTa-style) do not accept these kwargs;
        # detection is by substring of the model name/path.
        if any(
            keyword in self.config.name_or_path.lower()
            for keyword in ["llama", "deberta"]
        ):
            bert_kwargs.pop("token_type_ids")
            bert_kwargs.pop("head_mask")

        outputs = self.bert(**bert_kwargs)

        # For token classification: outputs[0] is the last hidden state,
        # presumably (batch, seq, hidden) — standard HF layout.
        token_output = outputs[0]
        token_output = self.dropout(token_output)

        # Pass through additional transformer layers:
        # (batch, seq, dim) -> (seq, batch, dim) and back, since the extra
        # encoder was constructed without batch_first.
        token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(
            0, 1
        )

        # Collect the logits and compute the loss for each task; losses of
        # all tasks present in `token_labels` are summed.
        task_logits = {}
        total_loss = 0
        for task, classifier in self.token_classifiers.items():
            logits = classifier(token_output)
            task_logits[task] = logits
            if token_labels and task in token_labels:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_token_labels_dict[task]),
                    token_labels[task].view(-1),
                )
                total_loss += loss

        if not return_dict:
            output = (task_logits,) + outputs[2:]
            # NOTE(review): `total_loss != 0` drops the loss from the tuple
            # when the summed loss is exactly zero (and when no labels were
            # given); intentional-looking but worth confirming.
            return ((total_loss,) + output) if total_loss != 0 else output

        return TokenClassifierOutput(
            loss=total_loss,
            logits=task_logits,  # dict of per-task logit tensors, not a single tensor
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
old/special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
old/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
old/tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_basic_tokenize": true,
47
+ "do_lower_case": false,
48
+ "mask_token": "[MASK]",
49
+ "max_len": 512,
50
+ "model_max_length": 512,
51
+ "never_split": null,
52
+ "pad_token": "[PAD]",
53
+ "sep_token": "[SEP]",
54
+ "strip_accents": false,
55
+ "tokenize_chinese_chars": true,
56
+ "tokenizer_class": "BertTokenizer",
57
+ "unk_token": "[UNK]"
58
+ }
old/vocab.txt ADDED
The diff for this file is too large to render. See raw diff