from transformers.utils import ModelOutput
import torch
from torch import nn
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast

# Define the prefix classes, and the possible prefix strings belonging to each class.
POSSIBLE_PREFIX_CLASSES = [ ['לכש', 'כש', 'מש', 'בש', 'לש'], ['מ'], ['ש'], ['ה'], ['ו'], ['כ'], ['ל'], ['ב'] ]
# Map each individual prefix string to its class index.
PREFIXES_TO_CLASS = {w:i for i,l in enumerate(POSSIBLE_PREFIX_CLASSES) for w in l}
# Keep all prefixes sorted longest-first, so a word can be decomposed by always
# matching the longest prefix available.
ALL_PREFIX_ITEMS = list(sorted(PREFIXES_TO_CLASS.keys(), key=len, reverse=True))
TOTAL_POSSIBLE_PREFIX_CLASSES = len(POSSIBLE_PREFIX_CLASSES)
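# Derived from the constants above: all multi-letter combinations share class 0, and each single
# letter gets its own class, e.g. PREFIXES_TO_CLASS['כש'] == 0, PREFIXES_TO_CLASS['ה'] == 3 and
# PREFIXES_TO_CLASS['ב'] == 7, with TOTAL_POSSIBLE_PREFIX_CLASSES == 8.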


def get_prefixes_from_str(s, greedy=False):
    # Keep peeling prefixes off the front of the string, yielding each prefix found.
    while len(s) > 0 and s[0] in PREFIXES_TO_CLASS:
        # Find the longest prefix that matches the start of the remaining string.
        next_pre = next((pre for pre in ALL_PREFIX_ITEMS if s.startswith(pre)), None)
        if next_pre is None:
            return
        yield next_pre
        # If the matched prefix is more than one letter, its first letter alone is also a valid
        # possibility, so offer it as well (unless running greedily). We still skip ahead by the
        # full length of the longer match, since if the next letters form a prefix they must
        # form the longest one.
        if not greedy and len(next_pre) > 1:
            yield next_pre[0]
        s = s[len(next_pre):]


def get_prefix_classes_from_str(s, greedy=False):
    # Same decomposition as above, but yield the class index of each prefix instead of the string.
    for pre in get_prefixes_from_str(s, greedy):
        yield PREFIXES_TO_CLASS[pre]
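# A small illustration of the two generators, traced on the bare prefix string 'כשה' rather than
# a real word:
#   list(get_prefixes_from_str('כשה'))        -> ['כש', 'כ', 'ה']
#   list(get_prefix_classes_from_str('כשה'))  -> [0, 5, 3]
# With greedy=True the single-letter alternative 'כ' (class 5) is not offered.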


@dataclass
class PrefixesClassifiersOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

|
| class BertPrefixMarkingHead(nn.Module): |
| def __init__(self, config) -> None: |
| super().__init__() |
| self.config = config |
|
|
| |
| |
| |
| prefix_class_embed = config.hidden_size // TOTAL_POSSIBLE_PREFIX_CLASSES |
| self.prefix_class_embeddings = nn.Embedding(TOTAL_POSSIBLE_PREFIX_CLASSES + 1, prefix_class_embed) |
| |
| |
| self.transform = nn.Linear(config.hidden_size + prefix_class_embed * TOTAL_POSSIBLE_PREFIX_CLASSES, config.hidden_size) |
| self.activation = nn.Tanh() |
| self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, 2) for _ in range(TOTAL_POSSIBLE_PREFIX_CLASSES)]) |
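    # Worked sizes, assuming a BERT-base style config with hidden_size=768 (an illustrative
    # assumption - nothing in this head fixes the hidden size):
    #   prefix_class_embed = 768 // 8 = 96
    #   transform input    = 768 + 96 * 8 = 1536  ->  transform output = 768
    #   logits per token   = 8 binary classifiers x 2 logits each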

    def forward(
            self,
            hidden_states: torch.Tensor,
            prefix_class_id_options: torch.Tensor,
            labels: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:

        # Encode the prefix_class_id_options:
        #   hidden_states is            batch x seq_len x hidden_size
        #   prefix_class_id_options is  batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES
        # so the embedding lookup gives batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES x prefix_class_embed
        possible_class_embed = self.prefix_class_embeddings(prefix_class_id_options)
        # Flatten the last two dimensions: batch x seq_len x (TOTAL_POSSIBLE_PREFIX_CLASSES * prefix_class_embed)
        possible_class_embed = possible_class_embed.reshape(possible_class_embed.shape[:-2] + (-1,))

        # Concatenate the class-option embeddings onto the token hidden states and project back to hidden_size.
        pre_transform_output = torch.cat((hidden_states, possible_class_embed), dim=-1)
        pre_logits_output = self.activation(self.transform(pre_transform_output))

        # Run every per-class binary classifier: logits is batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES x 2.
        logits = torch.cat([cls(pre_logits_output).unsqueeze(-2) for cls in self.classifiers], dim=-2)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, 2), labels.view(-1))

        return (loss, logits)


class BertForPrefixMarking(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.prefix = BertPrefixMarkingHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        prefix_class_id_options: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, total_prefix_classes)`, *optional*):
            Labels for computing the prefix-marking loss. For every token and prefix class the label
            should be `0` (prefix class not present) or `1` (prefix class present).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = bert_outputs[0]
        hidden_states = self.dropout(hidden_states)

        loss, logits = self.prefix.forward(hidden_states, prefix_class_id_options, labels)
        if not return_dict:
            return (loss, logits,) + bert_outputs[2:]

        return PrefixesClassifiersOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )

    def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
        # Tokenize the sentences and compute, for every token, which prefix classes are even possible.
        inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
        # offset_mapping is produced by the tokenizer but is not an argument of forward(), so drop it.
        inputs.pop('offset_mapping', None)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Run the model, then turn the per-class predictions into per-word prefix segmentations.
        logits = self.forward(**inputs, return_dict=True).logits
        return parse_logits(inputs, sentences, tokenizer, logits)
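# A sketch of how predict() can be used. The checkpoint name is a placeholder, not a reference to
# a specific published model - substitute the checkpoint this module ships with:
#
#   from transformers import AutoModel, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained('<checkpoint-name>')
#   model = AutoModel.from_pretrained('<checkpoint-name>', trust_remote_code=True)
#   model.eval()
#
#   with torch.no_grad():
#       segmented = model.predict(['<a Hebrew sentence>'], tokenizer)
#
# `segmented` holds one list per sentence with one entry per word: [word] when no prefix was
# predicted, or [prefix, remainder] when one was (special tokens such as [CLS] come through unsplit).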


def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.FloatTensor):
    # For every token and prefix class, take the argmax over the two logits (prefix present / absent).
    logit_preds = torch.argmax(logits, dim=-1)

    ret = []

    for sent_idx, sent_ids in enumerate(inputs['input_ids']):
        tokens = tokenizer.convert_ids_to_tokens(sent_ids)
        ret.append([])
        for tok_idx, token in enumerate(tokens):
            # Skip padding, and skip word-piece continuations - they are folded into their head token below.
            if token == tokenizer.pad_token: continue
            if token.startswith('##'): continue

            # Reassemble the full word by appending any following '##' word pieces.
            next_tok_idx = tok_idx + 1
            while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
                token += tokens[next_tok_idx][2:]
                next_tok_idx += 1

            prefix_len = get_predicted_prefix_len_from_logits(token, logit_preds[sent_idx, tok_idx])

            if not prefix_len:
                ret[-1].append([token])
            else:
                ret[-1].append([token[:prefix_len], token[prefix_len:]])
    return ret


def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, sentences: List[str], padding='longest', truncation=True):
    inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')

    # Create the prefix_class_id_options tensor: one slot per token per prefix class, initialized
    # to the "not a possibility" index (TOTAL_POSSIBLE_PREFIX_CLASSES).
    prefix_id_options = torch.full(inputs['input_ids'].shape + (TOTAL_POSSIBLE_PREFIX_CLASSES,), TOTAL_POSSIBLE_PREFIX_CLASSES, dtype=torch.long)

    # Mark which prefix classes are possible for each word, based on its surface form.
    for sent_idx, sent_ids in enumerate(inputs['input_ids']):
        tokens = tokenizer.convert_ids_to_tokens(sent_ids)
        for tok_idx, token in enumerate(tokens):
            # A token shorter than two characters, or one not starting with a prefix letter, can't carry a prefix.
            if len(token) < 2 or not token[0] in PREFIXES_TO_CLASS: continue

            # Reassemble the full word by appending any following '##' word pieces.
            next_tok_idx = tok_idx + 1
            while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
                token += tokens[next_tok_idx][2:]
                next_tok_idx += 1

            # For every prefix class that could open this word, store that class's own index, so the
            # head looks up the class embedding rather than the "none" embedding.
            for pre_class in get_prefix_classes_from_str(token):
                prefix_id_options[sent_idx, tok_idx, pre_class] = pre_class

    inputs['prefix_class_id_options'] = prefix_id_options
    return inputs
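# For illustration: for a token whose reassembled word starts with 'כשה', the classes traced by
# get_prefix_classes_from_str begin with 0 ('כש'), 5 ('כ') and 3 ('ה'), so those slots are set to
# their own indices while every other slot keeps the "none" index, TOTAL_POSSIBLE_PREFIX_CLASSES (8).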


def get_predicted_prefix_len_from_logits(token, token_logits):
    # Walk through the prefix possibilities of the word in order, and only extend the prefix length
    # when the matching class was predicted as present. The walk mirrors get_prefixes_from_str:
    #   - after a multi-letter prefix is accepted, the single-letter alternative that follows it is skipped;
    #   - if a multi-letter prefix is rejected, its single-letter alternative still gets one last chance;
    #   - a prefix string seen twice ends the walk, so a repeated letter inside the word isn't peeled off again.
    cur_len, skip_next, last_check, seen_prefixes = 0, False, False, set()
    for prefix in get_prefixes_from_str(token):
        # Skip the single-letter alternative emitted right after an accepted multi-letter prefix.
        if skip_next:
            skip_next = False
            continue

        if prefix in seen_prefixes: break
        seen_prefixes.add(prefix)

        # A prediction of 1 means the classifier marked this prefix class as present.
        if token_logits[PREFIXES_TO_CLASS[prefix]].item():
            cur_len += len(prefix)
            if last_check: break
            skip_next = len(prefix) > 1
        # A rejected multi-letter prefix: still consider its single-letter alternative, but whatever
        # happens on that alternative ends the walk.
        elif len(prefix) > 1:
            last_check = True
        else:
            break

    return cur_len
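# A worked example with illustrative predictions: suppose the reassembled word is 'כשהלכתי' and the
# classifiers marked class 0 ('כש') as present but class 3 ('ה') as absent. The walk accepts 'כש'
# (length 2), skips the single-letter 'כ' alternative, rejects 'ה' and stops, returning 2 - so
# parse_logits splits the word into ['כש', 'הלכתי'].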