import torch
from torch import nn
from typing import Optional, List
from itertools import islice
from transformers import PreTrainedModel, RobertaModel, BertModel

from .configuration_multiheadcrf import MultiHeadCRFConfig


# Each encoder layer of the RoBERTa/BERT backbone exposes 16 named parameters;
# used by manage_freezing() to freeze whole layers via islice.
NUM_PER_LAYER = 16


class RobertaMultiHeadCRFModel(PreTrainedModel):
    config_class = MultiHeadCRFConfig
    transformers_backbone_name = "roberta"
    transformers_backbone_class = RobertaModel
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.number_of_layer_per_head = config.number_of_layer_per_head
        self.heads = config.classes

        # Register the backbone under its canonical attribute name ("roberta")
        # so pretrained checkpoint weights map onto it directly.
        setattr(self, self.transformers_backbone_name,
                self.transformers_backbone_class(config, add_pooling_layer=False))

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Build one classification head per entity type: a stack of dense
        # layers followed by a linear classifier and a CRF layer.
        for ent in self.heads:
            for i in range(self.number_of_layer_per_head):
                setattr(self, f"{ent}_dense_{i}", nn.Linear(config.hidden_size, config.hidden_size))
                setattr(self, f"{ent}_dense_activation_{i}", nn.GELU(approximate='none'))
            setattr(self, f"{ent}_classifier", nn.Linear(config.hidden_size, config.num_labels))
            setattr(self, f"{ent}_crf", CRF(num_tags=config.num_labels, batch_first=True))
            setattr(self, f"{ent}_reduction", config.crf_reduction)
        self.reduction = config.crf_reduction

        if self.config.freeze:
            self.manage_freezing()

    def training_mode(self):
        """Re-initialize every head so the model can be fine-tuned from scratch.

        The backbone weights are left untouched; only the per-head dense,
        classifier, and CRF parameters are reset.
        """
        for ent in self.heads:
            for i in range(self.number_of_layer_per_head):
                getattr(self, f"{ent}_dense_{i}").reset_parameters()
            getattr(self, f"{ent}_classifier").reset_parameters()
            getattr(self, f"{ent}_crf").reset_parameters()
            getattr(self, f"{ent}_crf").mask_impossible_transitions()

    def manage_freezing(self):
        """Freeze the embeddings and the first `num_frozen_encoder` encoder layers."""
        for _, param in getattr(self, self.transformers_backbone_name).embeddings.named_parameters():
            param.requires_grad = False

        num_encoders_to_freeze = self.config.num_frozen_encoder
        if num_encoders_to_freeze > 0:
            # named_parameters() yields parameters layer by layer, NUM_PER_LAYER per
            # encoder layer, so this slice freezes exactly the first N layers.
            for _, param in islice(getattr(self, self.transformers_backbone_name).encoder.named_parameters(),
                                   num_encoders_to_freeze * NUM_PER_LAYER):
                param.requires_grad = False
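
    # For example (hypothetical setting): with config.num_frozen_encoder = 4, the
    # embeddings and encoder layers 0-3 stop receiving gradient updates; only
    # layers 4 and above, plus the per-entity heads, remain trainable.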

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = getattr(self, self.transformers_backbone_name)(input_ids,
                                                                 attention_mask=attention_mask,
                                                                 token_type_ids=token_type_ids,
                                                                 position_ids=position_ids,
                                                                 head_mask=head_mask,
                                                                 inputs_embeds=inputs_embeds,
                                                                 output_attentions=output_attentions,
                                                                 output_hidden_states=output_hidden_states,
                                                                 return_dict=return_dict)

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        # Each head applies its own dense stack to the shared encoder output,
        # then projects to per-token tag scores (the CRF emissions). The dense
        # layers are chained so that layer i feeds layer i + 1.
        logits = {}
        for ent in self.heads:
            hidden = sequence_output
            for i in range(self.number_of_layer_per_head):
                hidden = getattr(self, f"{ent}_dense_{i}")(hidden)
                hidden = getattr(self, f"{ent}_dense_activation_{i}")(hidden)
            logits[ent] = getattr(self, f"{ent}_classifier")(hidden)

        if labels is not None:
            # Training: each head's CRF scores its emissions against the gold
            # tags and returns a negative log likelihood; the loss is their sum.
            losses = {ent: getattr(self, f"{ent}_crf")(logits[ent], labels[ent], reduction=self.reduction)
                      for ent in self.heads}
            return sum(losses.values()), logits
        else:
            # Inference: Viterbi-decode the most likely tag sequence per head,
            # returned as a list ordered by sorted head name.
            decoded = {ent: torch.tensor(getattr(self, f"{ent}_crf").decode(logits[ent]))
                       for ent in self.heads}
            return [decoded[ent] for ent in sorted(self.heads)]
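
    # Usage sketch (hypothetical head names and tensors, for illustration only):
    #
    #   model = RobertaMultiHeadCRFModel(config)   # e.g. config.classes = ["Disease", "Chemical"]
    #   loss, logits = model(input_ids, attention_mask=attention_mask,
    #                        labels={"Disease": disease_tags, "Chemical": chemical_tags})
    #   loss.backward()                            # training step
    #
    #   predictions = model(input_ids, attention_mask=attention_mask)
    #   # -> list of (batch_size, seq_length) tag tensors, sorted by head name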


class BertMultiHeadCRFModel(RobertaMultiHeadCRFModel):
    config_class = MultiHeadCRFConfig
    transformers_backbone_name = "bert"
    transformers_backbone_class = BertModel
    _keys_to_ignore_on_load_unexpected = [r"pooler"]


LARGE_NEGATIVE_NUMBER = -1e9


class CRF(nn.Module):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the negative log likelihood of the given sequence of tags and
    emission score tensor. This class also has a `~CRF.decode` method which finds
    the best tag sequence given an emission score tensor using the `Viterbi algorithm`_.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.

    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.

    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.

    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()
        self.mask_impossible_transitions()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def mask_impossible_transitions(self) -> None:
        """Set the scores of impossible transitions to LARGE_NEGATIVE_NUMBER:
        - the start transition score of I-X
        - the transition score of O -> I-X

        The hard-coded indices assume a BIO tag set indexed as O = 0, B-X = 1, I-X = 2.
        """
        with torch.no_grad():
            self.start_transitions[2] = LARGE_NEGATIVE_NUMBER  # cannot start with I-X
            self.transitions[0][2] = LARGE_NEGATIVE_NUMBER     # cannot go O -> I-X

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the negative conditional log likelihood of a sequence of tags given emission scores.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
            reduction: Specifies the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.

        Returns:
            `~torch.Tensor`: The negative log likelihood. This will have size
            ``(batch_size,)`` if reduction is ``none``, ``()`` otherwise.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # log likelihood, negated so the result can be used directly as a loss
        llh = numerator - denominator
        nllh = -llh

        if reduction == 'none':
            return nllh
        if reduction == 'sum':
            return nllh.sum()
        if reduction == 'mean':
            return nllh.mean()
        assert reduction == 'token_mean'
        return nllh.sum() / mask.type_as(emissions).sum()
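
    # Example (illustrative values only, assuming batch_first=True):
    #
    #   crf = CRF(num_tags=3, batch_first=True)
    #   emissions = torch.randn(2, 5, 3)               # (batch, seq_len, num_tags)
    #   tags = torch.zeros(2, 5, dtype=torch.long)     # gold tag indices
    #   loss = crf(emissions, tags, reduction='sum')   # scalar NLL, ready for backward()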

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using the Viterbi algorithm.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.

        Returns:
            List of lists containing the best tag sequence for each batch.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)
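
    # Example (continuing the sketch above): one best path per batch item, e.g.
    #
    #   best_paths = crf.decode(emissions)   # e.g. [[0, 1, 2, 2, 0], [0, 0, 1, 2, 0]]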

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.ByteTensor] = None) -> None:
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.type_as(emissions)

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of
            # all possible tag sequences so far that end with transitioning from tag i
            # to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags; since we are in (log-)score space,
            # the sum becomes a log-sum-exp: the score of all possible tag sequences
            # so far that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # the value at column j stores the score of the best tag sequence so far
        # that ends with tag j. history saves where the best tag candidates
        # transitioned from; it is used to trace back the best tag sequence.

        # Viterbi recursive case: compute the score of the best tag sequence for
        # every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tags
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produced the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is
            # our best tag for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # Trace back where the best last tag comes from, append that to our best
            # tag sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we started from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list
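

if __name__ == "__main__":
    # Minimal smoke-test sketch for the CRF layer alone (illustrative values only).
    # The negative log likelihood of the gold path is non-negative because the
    # normalizer log-sum-exps over all paths, including the gold one.
    crf = CRF(num_tags=3, batch_first=True)
    emissions = torch.randn(2, 5, 3)
    tags = torch.zeros(2, 5, dtype=torch.long)
    nll = crf(emissions, tags, reduction='mean')
    paths = crf.decode(emissions)
    print(f"nll={nll.item():.4f}, decoded={paths}")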