Shaltiel committed on
Commit
2f5fc2f
·
verified ·
1 Parent(s): e935839

Update BertForPrefixMarking.py

Browse files
Files changed (1) hide show
  1. BertForPrefixMarking.py +299 -296
BertForPrefixMarking.py CHANGED
@@ -1,296 +1,299 @@
1
- from transformers.utils import ModelOutput
2
- import torch
3
- from torch import nn
4
- from typing import Dict, List, Tuple, Optional, Union
5
- from dataclasses import dataclass
6
- from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast, AutoConfig, AutoModel
7
- from transformers.models.auto.auto_factory import _BaseAutoModelClass
8
- from transformers.modeling_utils import no_init_weights
9
- import inspect, os
10
-
11
- # define the classes, and the possible prefixes for each class
12
- POSSIBLE_PREFIX_CLASSES = [ ['לכש', 'כש', 'מש', 'בש', 'לש'], ['מ'], ['ש'], ['ה'], ['ו'], ['כ'], ['ל'], ['ב'] ]
13
- POSSIBLE_RABBINIC_PREFIX_CLASSES = [ ['לכש', 'כש', 'מש', 'בש', 'לש', 'לד', 'בד', 'מד', 'כד', 'לכד'], ['מ'], ['ש', 'ד'], ['ה'], ['ו'], ['כ'], ['ל'], ['ב'], ['א'], ['ק'] ]
14
-
15
- class PrefixConfig(dict):
16
- def __init__(self, possible_classes, **kwargs): # added kwargs for previous version where all features were kept as dict values
17
- super().__init__()
18
- self.possible_classes = possible_classes
19
- self.total_classes = len(possible_classes)
20
- self.prefix_c2i = {w: i for i, l in enumerate(possible_classes) for w in l}
21
- self.all_prefix_items = list(sorted(self.prefix_c2i.keys(), key=len, reverse=True))
22
-
23
- @property
24
- def possible_classes(self) -> List[List[str]]:
25
- return self.get('possible_classes')
26
-
27
- @possible_classes.setter
28
- def possible_classes(self, value: List[List[str]]):
29
- self['possible_classes'] = value
30
-
31
- DEFAULT_PREFIX_CONFIG = PrefixConfig(POSSIBLE_PREFIX_CLASSES)
32
-
33
- def get_prefixes_from_str(s, cfg: PrefixConfig, greedy=False):
34
- # keep trimming prefixes from the string
35
- while len(s) > 0 and s[0] in cfg.prefix_c2i:
36
- # find the longest string to trim
37
- next_pre = next((pre for pre in cfg.all_prefix_items if s.startswith(pre)), None)
38
- if next_pre is None:
39
- return
40
- yield next_pre
41
- # if the chosen prefix is more than one letter, there is always an option that the
42
- # prefix is actually just the first letter of the prefix - so offer that up as a valid prefix
43
- # as well. We will still jump to the length of the longer one, since if the next two/three
44
- # letters are a prefix, they have to be the longest one
45
- if not greedy and len(next_pre) > 1:
46
- yield next_pre[0]
47
- s = s[len(next_pre):]
48
-
49
- def get_prefix_classes_from_str(s, cfg: PrefixConfig, greedy=False):
50
- for pre in get_prefixes_from_str(s, cfg, greedy):
51
- yield cfg.prefix_c2i[pre]
52
-
53
- @dataclass
54
- class PrefixesClassifiersOutput(ModelOutput):
55
- loss: Optional[torch.FloatTensor] = None
56
- logits: Optional[torch.FloatTensor] = None
57
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
58
- attentions: Optional[Tuple[torch.FloatTensor]] = None
59
-
60
- class BertPrefixMarkingHead(nn.Module):
61
- def __init__(self, config) -> None:
62
- super().__init__()
63
- self.config = config
64
-
65
- if not hasattr(config, 'prefix_cfg') or config.prefix_cfg is None:
66
- setattr(config, 'prefix_cfg', DEFAULT_PREFIX_CONFIG)
67
- if isinstance(config.prefix_cfg, dict):
68
- config.prefix_cfg = PrefixConfig(config.prefix_cfg['possible_classes'])
69
-
70
- # an embedding table containing an embedding for each prefix class + 1 for NONE
71
- # we will concatenate either the embedding/NONE for each class - and we want the concatenate
72
- # size to be the hidden_size
73
- prefix_class_embed = config.hidden_size // config.prefix_cfg.total_classes
74
- self.prefix_class_embeddings = nn.Embedding(config.prefix_cfg.total_classes + 1, prefix_class_embed)
75
-
76
- # one layer for transformation, apply an activation, then another N classifiers for each prefix class
77
- self.transform = nn.Linear(config.hidden_size + prefix_class_embed * config.prefix_cfg.total_classes, config.hidden_size)
78
- self.activation = nn.Tanh()
79
- self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, 2) for _ in range(config.prefix_cfg.total_classes)])
80
-
81
- def forward(
82
- self,
83
- hidden_states: torch.Tensor,
84
- prefix_class_id_options: torch.Tensor,
85
- labels: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
86
-
87
- # encode the prefix_class_id_options
88
- # If input_ids is batch x seq_len
89
- # Then sequence_output is batch x seq_len x hidden_dim
90
- # So prefix_class_id_options is batch x seq_len x total_classes
91
- # Looking up the embeddings should give us batch x seq_len x total_classes x hidden_dim / N
92
- possible_class_embed = self.prefix_class_embeddings(prefix_class_id_options)
93
- # then flatten the final dimension - now we have batch x seq_len x hidden_dim_2
94
- possible_class_embed = possible_class_embed.reshape(possible_class_embed.shape[:-2] + (-1,))
95
-
96
- # concatenate the new class embed into the sequence output before the transform
97
- pre_transform_output = torch.cat((hidden_states, possible_class_embed), dim=-1) # batch x seq_len x (hidden_dim + hidden_dim_2)
98
- pre_logits_output = self.activation(self.transform(pre_transform_output))# batch x seq_len x hidden_dim
99
-
100
- # run each of the classifiers on the transformed output
101
- logits = torch.cat([cls(pre_logits_output).unsqueeze(-2) for cls in self.classifiers], dim=-2)
102
-
103
- loss = None
104
- if labels is not None:
105
- loss_fct = nn.CrossEntropyLoss()
106
- loss = loss_fct(logits.view(-1, 2), labels.view(-1))
107
-
108
- return (loss, logits)
109
-
110
- def can_func_take_parameter(fn, param_name):
111
- signature = inspect.signature(fn)
112
- # Exclude 'self' from parameters
113
- parameters = [p.name for p in signature.parameters.values() if p.name != 'self']
114
- return 'kwargs' in parameters or param_name in parameters
115
-
116
- class BaseForPrefixMarking(BertPreTrainedModel):
117
- base_model_prefix = ""
118
- def __init__(self, config, bert_cls=BertModel):
119
- super().__init__(config)
120
- setattr(config, "hidden_dropout_prob", getattr(config, "hidden_dropout_prob", 0.1))
121
- setattr(config, "initializer_range", getattr(config, "classifier_init_range", getattr(config, 'decoder_init_range', 0.02)))
122
-
123
- self.bert = bert_cls(config, **({} if not can_func_take_parameter(bert_cls.__init__, 'add_pooling_layer') else {'add_pooling_layer': False}))
124
- self.send_token_type_ids = can_func_take_parameter(self.bert.forward, 'token_type_ids')
125
-
126
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
127
- self.prefix = BertPrefixMarkingHead(config)
128
-
129
- # Initialize weights and apply final processing
130
- self.post_init()
131
-
132
- def forward(
133
- self,
134
- input_ids: Optional[torch.Tensor] = None,
135
- attention_mask: Optional[torch.Tensor] = None,
136
- token_type_ids: Optional[torch.Tensor] = None,
137
- prefix_class_id_options: Optional[torch.Tensor] = None,
138
- position_ids: Optional[torch.Tensor] = None,
139
- labels: Optional[torch.Tensor] = None,
140
- head_mask: Optional[torch.Tensor] = None,
141
- inputs_embeds: Optional[torch.Tensor] = None,
142
- output_attentions: Optional[bool] = None,
143
- output_hidden_states: Optional[bool] = None,
144
- return_dict: Optional[bool] = None,
145
- ):
146
- r"""
147
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
148
- Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
149
- """
150
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
151
-
152
- kwargs = dict(token_type_ids=token_type_ids, head_mask=head_mask) if self.send_token_type_ids else {}
153
-
154
- bert_outputs = self.bert(
155
- input_ids,
156
- attention_mask=attention_mask,
157
- position_ids=position_ids,
158
- inputs_embeds=inputs_embeds,
159
- output_attentions=output_attentions,
160
- output_hidden_states=output_hidden_states,
161
- return_dict=return_dict,
162
- **kwargs
163
- )
164
-
165
- hidden_states = bert_outputs[0]
166
- hidden_states = self.dropout(hidden_states)
167
-
168
- loss, logits = self.prefix.forward(hidden_states, prefix_class_id_options, labels)
169
- if not return_dict:
170
- return (loss,logits,) + bert_outputs[2:]
171
-
172
- return PrefixesClassifiersOutput(
173
- loss=loss,
174
- logits=logits,
175
- hidden_states=bert_outputs.hidden_states,
176
- attentions=bert_outputs.attentions,
177
- )
178
-
179
- def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
180
- # step 1: encode the sentences through using the tokenizer, and get the input tensors + prefix id tensors
181
- inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, self.config.prefix_cfg, sentences, padding)
182
- inputs.pop('offset_mapping')
183
- inputs = {k:v.to(self.device) for k,v in inputs.items()}
184
-
185
- # run through bert
186
- logits = self.forward(**inputs, return_dict=True).logits
187
- return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits, self.config.prefix_cfg)
188
-
189
-
190
- class AutoForPrefixMarking(_BaseAutoModelClass):
191
- @classmethod
192
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *args, **kwargs):
193
- auto_cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
194
- base_cls = BaseForPrefixMarking
195
- with no_init_weights():
196
- bert_cls = AutoModel.from_config(auto_cfg, *args, **{k: v for k, v in kwargs.items() if k != 'config'}).__class__
197
- if 'Prefix' in bert_cls.__name__:
198
- base_cls = bert_cls
199
-
200
- return base_cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs, bert_cls=bert_cls, key_mapping={"^model": "bert"})
201
-
202
-
203
- def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.FloatTensor, config: PrefixConfig):
204
- # extract the predictions by argmaxing the final dimension (batch x sequence x prefixes x prediction)
205
- logit_preds = torch.argmax(logits, axis=3).tolist()
206
-
207
- ret = []
208
-
209
- for sent_idx,sent_ids in enumerate(input_ids):
210
- tokens = tokenizer.convert_ids_to_tokens(sent_ids)
211
-
212
- ret.append([])
213
- for tok_idx,token in enumerate(tokens):
214
- # If we've reached the pad token, then we are at the end
215
- if token == tokenizer.pad_token: continue
216
- if token.startswith('##'): continue
217
-
218
- # combine the next tokens in? only if it's a breakup
219
- next_tok_idx = tok_idx + 1
220
- while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
221
- token += tokens[next_tok_idx][2:]
222
- next_tok_idx += 1
223
-
224
- if hasattr(tokenizer, 'splinter') and tokenizer.splinter:
225
- token = tokenizer.splinter.unsplinter_word(token)
226
-
227
- prefix_len = get_predicted_prefix_len_from_logits(token, logit_preds[sent_idx][tok_idx], config)
228
-
229
- if not prefix_len:
230
- ret[-1].append([token])
231
- else:
232
- ret[-1].append([token[:prefix_len], token[prefix_len:]])
233
- return ret
234
-
235
- def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, config: PrefixConfig, sentences: List[str], padding='longest', truncation=True):
236
- inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
237
- # create our prefix_id_options array which will be like the input ids shape but with an addtional
238
- # dimension containing for each prefix whether it can be for that word
239
- prefix_id_options = torch.full(inputs['input_ids'].shape + (config.total_classes,), config.total_classes, dtype=torch.long)
240
-
241
- # go through each token, and fill in the vector accordingly
242
- for sent_idx, sent_ids in enumerate(inputs['input_ids']):
243
- tokens = tokenizer.convert_ids_to_tokens(sent_ids)
244
- for tok_idx, token in enumerate(tokens):
245
- # if the first letter isn't a valid prefix letter, nothing to talk about
246
- if len(token) < 2 or not token[0] in config.prefix_c2i: continue
247
-
248
- # combine the next tokens in? only if it's a breakup
249
- next_tok_idx = tok_idx + 1
250
- while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
251
- token += tokens[next_tok_idx][2:]
252
- next_tok_idx += 1
253
-
254
- # find all the possible prefixes - and mark them as 0 (and in the possible mark it as it's value for embed lookup)
255
- for pre_class in get_prefix_classes_from_str(token, config):
256
- prefix_id_options[sent_idx, tok_idx, pre_class] = pre_class
257
-
258
- inputs['prefix_class_id_options'] = prefix_id_options
259
- return inputs
260
-
261
- def get_predicted_prefix_len_from_logits(token, token_logits, config: PrefixConfig):
262
- # Go through each possible prefix, and check if the prefix is yes - and if
263
- # so increase the counter of the matched length, otherwise break out. That will solve cases
264
- # of predicting prefix combinations that don't exist on the word.
265
- # For example, if we have the word ושכשהלכתי and the model predict ו & כש, then we will only
266
- # take the vuv because in order to get the כש we need the ש as well.
267
- # Two extra items:
268
- # 1] Don't allow the same prefix multiple times
269
- # 2] Always check that the word starts with that prefix - otherwise it's bad
270
- # (except for the case of multi-letter prefix, where we force the next to be last)
271
- cur_len, skip_next, last_check, seen_prefixes = 0, False, False, set()
272
- for prefix in get_prefixes_from_str(token, config):
273
- # Are we skipping this prefix? This will be the case where we matched כש, don't allow ש
274
- if skip_next:
275
- skip_next = False
276
- continue
277
- # check for duplicate prefixes, we don't allow two of the same prefix
278
- # if it predicted two of the same, then we will break out
279
- if prefix in seen_prefixes: break
280
- seen_prefixes.add(prefix)
281
-
282
- # check if we predicted this prefix
283
- if token_logits[config.prefix_c2i[prefix]]:
284
- cur_len += len(prefix)
285
- if last_check: break
286
- skip_next = len(prefix) > 1
287
- # Otherwise, we predicted no. If we didn't, then this is the end of the prefix
288
- # and time to break out. *Except* if it's a multi letter prefix, then we allow
289
- # just the next letter - e.g., if כש doesn't match, then we allow כ, but then we know
290
- # the word continues with a ש, and if it's not כש, then it's not כ-ש- (invalid)
291
- elif len(prefix) > 1:
292
- last_check = True
293
- else:
294
- break
295
-
296
- return cur_len
 
 
 
 
1
+ from transformers.utils import ModelOutput
2
+ import torch
3
+ from torch import nn
4
+ from typing import Dict, List, Tuple, Optional, Union
5
+ from dataclasses import dataclass
6
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast, AutoConfig, AutoModel
7
+ from transformers.models.auto.auto_factory import _BaseAutoModelClass
8
+ try:
9
+ from transformers.modeling_utils import no_init_weights
10
+ except ImportError:
11
+ from transformers.initialization import no_init_weights
12
+ import inspect, os
13
+
14
# Prefix classes and the surface forms belonging to each class. The position of an
# inner list is the class id; the first class groups the multi-letter prefixes.
POSSIBLE_PREFIX_CLASSES = [
    ['לכש', 'כש', 'מש', 'בש', 'לש'],
    ['מ'],
    ['ש'],
    ['ה'],
    ['ו'],
    ['כ'],
    ['ל'],
    ['ב'],
]

# Extended table: more multi-letter forms, a second form in the third class, and two extra classes.
POSSIBLE_RABBINIC_PREFIX_CLASSES = [
    ['לכש', 'כש', 'מש', 'בש', 'לש', 'לד', 'בד', 'מד', 'כד', 'לכד'],
    ['מ'],
    ['ש', 'ד'],
    ['ה'],
    ['ו'],
    ['כ'],
    ['ל'],
    ['ב'],
    ['א'],
    ['ק'],
]
17
+
18
class PrefixConfig(dict):
    """Dict-backed description of the prefix classes.

    Only ``possible_classes`` is stored as a dict item (via the property below);
    the derived lookups are kept as plain instance attributes:

    * ``total_classes``    - number of prefix classes
    * ``prefix_c2i``       - prefix string -> index of the class listing it
    * ``all_prefix_items`` - every prefix string, longest first
    """

    def __init__(self, possible_classes, **kwargs):
        # **kwargs is accepted and ignored; it exists for a previous version
        # where all features were kept as dict values.
        super().__init__()
        self.possible_classes = possible_classes
        self.total_classes = len(possible_classes)

        class_of_prefix = {}
        for class_idx, members in enumerate(possible_classes):
            for prefix in members:
                class_of_prefix[prefix] = class_idx
        self.prefix_c2i = class_of_prefix

        # Longest first; the sort is stable, so equal-length prefixes keep table order.
        self.all_prefix_items = sorted(class_of_prefix, key=len, reverse=True)

    @property
    def possible_classes(self) -> List[List[str]]:
        """The class table, read from the ``'possible_classes'`` dict item."""
        return self.get('possible_classes')

    @possible_classes.setter
    def possible_classes(self, value: List[List[str]]):
        self['possible_classes'] = value
33
+
34
+ DEFAULT_PREFIX_CONFIG = PrefixConfig(POSSIBLE_PREFIX_CLASSES)
35
+
36
def get_prefixes_from_str(s, cfg: PrefixConfig, greedy=False):
    """Yield the candidate prefixes found at the start of ``s``, left to right.

    At each step the longest entry of ``cfg.all_prefix_items`` that ``s`` starts
    with is yielded and stripped. Unless ``greedy`` is set, a multi-letter match is
    followed by its first letter alone, because the prefix might really be just
    that letter. Scanning stops once the next character is not a key of
    ``cfg.prefix_c2i``.
    """
    remaining = s
    while remaining and remaining[0] in cfg.prefix_c2i:
        # all_prefix_items is ordered longest-first, so the first hit is the longest
        longest = None
        for candidate in cfg.all_prefix_items:
            if remaining.startswith(candidate):
                longest = candidate
                break
        if longest is None:
            return

        yield longest
        if len(longest) > 1 and not greedy:
            yield longest[0]

        # Always advance by the full match: if the next two/three letters form a
        # prefix, they have to be the longest one.
        remaining = remaining[len(longest):]
51
+
52
def get_prefix_classes_from_str(s, cfg: PrefixConfig, greedy=False):
    """Yield the class index (``cfg.prefix_c2i``) of every candidate prefix of ``s``."""
    yield from (cfg.prefix_c2i[prefix] for prefix in get_prefixes_from_str(s, cfg, greedy))
55
+
56
@dataclass
class PrefixesClassifiersOutput(ModelOutput):
    """Output container returned by the prefix-marking model when ``return_dict`` is set."""
    # loss produced by the prefix head; None when no labels were supplied
    loss: Optional[torch.FloatTensor] = None
    # per-class binary logits from the prefix head
    # (presumably batch x seq_len x total_classes x 2 - confirm against the head)
    logits: Optional[torch.FloatTensor] = None
    # forwarded unchanged from the backbone's output
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # forwarded unchanged from the backbone's output
    attentions: Optional[Tuple[torch.FloatTensor]] = None
62
+
63
class BertPrefixMarkingHead(nn.Module):
    """Classification head producing one yes/no decision per prefix class for each token.

    The backbone hidden state is concatenated with embeddings of the prefix classes
    that are possible for the token, passed through a linear layer + tanh, and then
    fed to one 2-way linear classifier per prefix class.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config

        # Fall back to the default table, and normalise dict-like configs
        # (PrefixConfig itself is a dict) into a fresh PrefixConfig.
        if getattr(config, 'prefix_cfg', None) is None:
            config.prefix_cfg = DEFAULT_PREFIX_CONFIG
        if isinstance(config.prefix_cfg, dict):
            config.prefix_cfg = PrefixConfig(config.prefix_cfg['possible_classes'])

        num_classes = config.prefix_cfg.total_classes
        hidden_size = config.hidden_size

        # One embedding per prefix class plus one extra row for NONE. Each class
        # slot gets hidden_size // num_classes dims so the concatenation of all
        # slots is (about) hidden_size wide.
        per_class_dim = hidden_size // num_classes
        self.prefix_class_embeddings = nn.Embedding(num_classes + 1, per_class_dim)

        # Transformation layer, activation, then one binary classifier per class.
        self.transform = nn.Linear(hidden_size + per_class_dim * num_classes, hidden_size)
        self.activation = nn.Tanh()
        self.classifiers = nn.ModuleList([nn.Linear(hidden_size, 2) for _ in range(num_classes)])

    def forward(
            self,
            hidden_states: torch.Tensor,
            prefix_class_id_options: torch.Tensor,
            labels: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Return ``(loss, logits)``; ``loss`` is None when ``labels`` is None.

        ``hidden_states`` is batch x seq_len x hidden_dim and
        ``prefix_class_id_options`` is batch x seq_len x total_classes; the
        resulting ``logits`` are batch x seq_len x total_classes x 2.
        """
        # batch x seq_len x total_classes x (hidden_dim / N) ...
        class_embeds = self.prefix_class_embeddings(prefix_class_id_options)
        # ... merged over the last two dims into batch x seq_len x hidden_dim_2
        class_embeds = class_embeds.reshape(*class_embeds.shape[:-2], -1)

        # batch x seq_len x (hidden_dim + hidden_dim_2) -> batch x seq_len x hidden_dim
        features = torch.cat((hidden_states, class_embeds), dim=-1)
        features = self.activation(self.transform(features))

        # one 2-way decision per prefix class, stacked along a new class axis
        logits = torch.stack([classifier(features) for classifier in self.classifiers], dim=-2)

        if labels is None:
            return (None, logits)

        loss = nn.CrossEntropyLoss()(logits.view(-1, 2), labels.view(-1))
        return (loss, logits)
112
+
113
def can_func_take_parameter(fn, param_name):
    """Return True when ``fn`` can be called with ``param_name`` as a keyword argument.

    That is the case when the signature has a ``**`` catch-all (whatever it is
    named), or declares a parameter called ``param_name`` that may be passed by
    keyword. ``self`` is never treated as a match.

    The catch-all is detected by parameter kind rather than by the literal name
    ``kwargs``: ``**extra`` is recognised, and an ordinary parameter that merely
    happens to be called ``kwargs`` is not mistaken for one.
    """
    keyword_capable = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    for param in inspect.signature(fn).parameters.values():
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        # Exclude 'self' from parameters
        if param.name != 'self' and param.name == param_name and param.kind in keyword_capable:
            return True
    return False
118
+
119
class BaseForPrefixMarking(BertPreTrainedModel):
    """Backbone encoder (``bert_cls``) followed by dropout and a prefix-marking head."""

    base_model_prefix = ""

    def __init__(self, config, bert_cls=BertModel):
        super().__init__(config)
        # Fill in a dropout default, and set initializer_range from the
        # classifier/decoder init range when the config provides one (else 0.02).
        setattr(config, "hidden_dropout_prob", getattr(config, "hidden_dropout_prob", 0.1))
        setattr(config, "initializer_range", getattr(config, "classifier_init_range", getattr(config, 'decoder_init_range', 0.02)))

        # Only pass add_pooling_layer / token_type_ids to backbones whose signatures accept them.
        self.bert = bert_cls(config, **({} if not can_func_take_parameter(bert_cls.__init__, 'add_pooling_layer') else {'add_pooling_layer': False}))
        self.send_token_type_ids = can_func_take_parameter(self.bert.forward, 'token_type_ids')

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.prefix = BertPrefixMarkingHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        prefix_class_id_options: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        prefix_class_id_options (`torch.LongTensor` of shape `(batch_size, sequence_length, total_classes)`):
            For every token and prefix class, the class id when that prefix is possible for the word,
            otherwise `total_classes` (the NONE embedding row).
        labels (`torch.LongTensor`, *optional*):
            Binary targets for the per-class classifiers. They are flattened and matched against the
            logits reshaped to `(-1, 2)`, so they need one 0/1 entry per (batch, token, prefix class).

        Returns a `PrefixesClassifiersOutput` when `return_dict` is true, otherwise the tuple
        `(loss, logits) + bert_outputs[2:]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # token_type_ids and head_mask are only forwarded to backbones that accept token_type_ids
        kwargs = dict(token_type_ids=token_type_ids, head_mask=head_mask) if self.send_token_type_ids else {}

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = bert_outputs[0]
        hidden_states = self.dropout(hidden_states)

        # Call the module itself rather than .forward() so registered hooks are honoured.
        loss, logits = self.prefix(hidden_states, prefix_class_id_options, labels)
        if not return_dict:
            return (loss, logits,) + bert_outputs[2:]

        return PrefixesClassifiersOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )

    def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
        """Split the prefixes off every word of ``sentences``.

        Returns one list per sentence; each entry is ``[word]`` or ``[prefix, rest]``
        as produced by ``parse_logits``. An empty ``sentences`` yields ``[]``.
        """
        if not sentences:
            return []

        # step 1: encode the sentences using the tokenizer, and get the input tensors + prefix id tensors
        inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, self.config.prefix_cfg, sentences, padding)
        inputs.pop('offset_mapping')
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # step 2: run through bert. Only the argmax of the logits is used afterwards,
        # so skip building the autograd graph.
        with torch.no_grad():
            logits = self(**inputs, return_dict=True).logits
        return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits, self.config.prefix_cfg)
191
+
192
+
193
class AutoForPrefixMarking(_BaseAutoModelClass):
    """Loader that picks the backbone class from the checkpoint's config."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *args, **kwargs):
        """Load a prefix-marking model from ``pretrained_model_name_or_path``.

        The backbone class is discovered by building a model from the checkpoint's
        config (without weight initialisation) and taking its class. If that class
        already has ``Prefix`` in its name it is loaded directly; otherwise
        ``BaseForPrefixMarking`` wraps it.

        NOTE(review): ``trust_remote_code=True`` is hard-coded for the config lookup.
        """
        auto_cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)

        # a caller-supplied 'config' is not forwarded to from_config
        from_config_kwargs = {key: value for key, value in kwargs.items() if key != 'config'}
        with no_init_weights():
            bert_cls = type(AutoModel.from_config(auto_cfg, *args, **from_config_kwargs))

        base_cls = bert_cls if 'Prefix' in bert_cls.__name__ else BaseForPrefixMarking

        return base_cls.from_pretrained(
            pretrained_model_name_or_path,
            *args,
            **kwargs,
            bert_cls=bert_cls,
            key_mapping={"^model": "bert"},
        )
204
+
205
+
206
def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.FloatTensor, config: PrefixConfig):
    """Turn the model logits into per-word prefix splits.

    Returns one list per sentence. Every token that is neither padding nor a
    ``##`` continuation piece contributes an entry: ``[word]`` when no prefix was
    predicted, otherwise ``[prefix, rest]``. ``sentences`` is not read here.
    """
    # argmax over the final dimension (batch x sequence x prefixes x prediction)
    predictions = torch.argmax(logits, dim=3).tolist()

    results = []
    for sent_idx, sent_ids in enumerate(input_ids):
        tokens = tokenizer.convert_ids_to_tokens(sent_ids)
        sentence_words = []
        results.append(sentence_words)

        for tok_idx, token in enumerate(tokens):
            # padding and continuation pieces never start a word
            if token == tokenizer.pad_token or token.startswith('##'):
                continue

            # glue the following '##' pieces back onto the word
            end = tok_idx + 1
            while end < len(tokens) and tokens[end].startswith('##'):
                end += 1
            word = token + ''.join(piece[2:] for piece in tokens[tok_idx + 1:end])

            splinter = getattr(tokenizer, 'splinter', None)
            if splinter:
                word = splinter.unsplinter_word(word)

            prefix_len = get_predicted_prefix_len_from_logits(word, predictions[sent_idx][tok_idx], config)
            sentence_words.append([word[:prefix_len], word[prefix_len:]] if prefix_len else [word])

    return results
237
+
238
def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, config: PrefixConfig, sentences: List[str], padding='longest', truncation=True):
    """Tokenize ``sentences`` and attach the ``prefix_class_id_options`` tensor.

    The returned encoding (which still contains ``offset_mapping``) gains a
    ``prefix_class_id_options`` entry shaped like ``input_ids`` plus a trailing
    ``total_classes`` axis. Entry ``[s, t, c]`` is ``c`` when class ``c`` is a
    possible prefix of the word starting at token ``t``, and ``total_classes``
    (the NONE id) otherwise.
    """
    inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')

    # Start with every slot set to NONE; the extra trailing dimension says, for each
    # prefix class, whether it can apply to that word.
    none_id = config.total_classes
    options_shape = inputs['input_ids'].shape + (config.total_classes,)
    prefix_id_options = torch.full(options_shape, none_id, dtype=torch.long)

    # go through each token, and fill in the vector accordingly
    for sent_idx, sent_ids in enumerate(inputs['input_ids']):
        tokens = tokenizer.convert_ids_to_tokens(sent_ids)
        for tok_idx, token in enumerate(tokens):
            # if the first letter isn't a valid prefix letter, nothing to talk about
            if len(token) < 2 or token[0] not in config.prefix_c2i:
                continue

            # glue the following '##' pieces back onto the word
            end = tok_idx + 1
            while end < len(tokens) and tokens[end].startswith('##'):
                end += 1
            word = token + ''.join(piece[2:] for piece in tokens[tok_idx + 1:end])

            # mark every possible prefix class with its own id (used for the embedding lookup)
            for pre_class in get_prefix_classes_from_str(word, config):
                prefix_id_options[sent_idx, tok_idx, pre_class] = pre_class

    inputs['prefix_class_id_options'] = prefix_id_options
    return inputs
263
+
264
def get_predicted_prefix_len_from_logits(token, token_logits, config: PrefixConfig):
    """Return how many leading characters of ``token`` form the predicted prefix.

    ``token_logits`` holds one truthy/falsy prediction per prefix class. The
    candidate prefixes of ``token`` are walked in order and accepted while the
    model said "yes"; the first "no" ends the walk. This rejects predicted
    combinations that don't exist on the word: for ושכשהלכתי with predictions
    ו and כש, only the ו is taken, because reaching כש needs the ש as well.

    Additional rules:
      1] the same prefix string is never accepted twice;
      2] after accepting a multi-letter prefix, its single-letter alternative is skipped;
      3] when a multi-letter prefix is rejected, only its first letter may still be
         accepted, and nothing after it (if it isn't כש then it can't be כ-ש-).
    """
    matched_len = 0
    skip_alternative = False   # the next candidate is the first letter of an accepted multi-letter prefix
    final_attempt = False      # a multi-letter prefix was rejected; at most one more acceptance
    used = set()

    for candidate in get_prefixes_from_str(token, config):
        if skip_alternative:
            skip_alternative = False
            continue

        # a repeated prefix ends the walk
        if candidate in used:
            break
        used.add(candidate)

        is_multi_letter = len(candidate) > 1

        if not token_logits[config.prefix_c2i[candidate]]:
            # rejected: a single letter ends the prefix, a multi-letter one still
            # leaves its first letter as a last option
            if not is_multi_letter:
                break
            final_attempt = True
            continue

        matched_len += len(candidate)
        if final_attempt:
            break
        skip_alternative = is_multi_letter

    return matched_len