| """Basic model. Predicts tags for every token""" |
| from typing import Dict, Optional, List, Any |
|
| import numpy |
| import torch |
| import torch.nn.functional as F |
| from allennlp.data import Vocabulary |
| from allennlp.models.model import Model |
| from allennlp.modules import TimeDistributed, TextFieldEmbedder |
| from allennlp.nn import InitializerApplicator, RegularizerApplicator |
| from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits |
| from allennlp.training.metrics import CategoricalAccuracy |
| from overrides import overrides |
| from torch.nn.modules.linear import Linear |
|
|
| @Model.register("seq2labels") |
| class Seq2Labels(Model): |
| """ |
| This ``Seq2Labels`` simply encodes a sequence of text with a stacked ``Seq2SeqEncoder``, then |
| predicts a tag (or couple tags) for each token in the sequence. |
| |
| Parameters |
| ---------- |
| vocab : ``Vocabulary``, required |
| A Vocabulary, required in order to compute sizes for input/output projections. |
| text_field_embedder : ``TextFieldEmbedder``, required |
| Used to embed the ``tokens`` ``TextField`` we get as input to the model. |
| encoder : ``Seq2SeqEncoder`` |
| The encoder (with its own internal stacking) that we will use in between embedding tokens |
| and predicting output tags. |
| calculate_span_f1 : ``bool``, optional (default=``None``) |
| Calculate span-level F1 metrics during training. If this is ``True``, then |
| ``label_encoding`` is required. If ``None`` and |
| label_encoding is specified, this is set to ``True``. |
| If ``None`` and label_encoding is not specified, it defaults |
| to ``False``. |
| label_encoding : ``str``, optional (default=``None``) |
| Label encoding to use when calculating span f1. |
| Valid options are "BIO", "BIOUL", "IOB1", "BMES". |
| Required if ``calculate_span_f1`` is true. |
| labels_namespace : ``str``, optional (default=``labels``) |
| This is needed to compute the SpanBasedF1Measure metric, if desired. |
| Unless you did something unusual, the default value should be what you want. |
| verbose_metrics : ``bool``, optional (default = False) |
| If true, metrics will be returned per label class in addition |
| to the overall statistics. |
| initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) |
| Used to initialize the model parameters. |
| regularizer : ``RegularizerApplicator``, optional (default=``None``) |
| If provided, will be used to calculate the regularization penalty during training. |
| """ |
|
| def __init__(self, vocab: Vocabulary, |
| text_field_embedder: TextFieldEmbedder, |
                 predictor_dropout: float = 0.0,
| labels_namespace: str = "labels", |
| detect_namespace: str = "d_tags", |
| verbose_metrics: bool = False, |
| label_smoothing: float = 0.0, |
| confidence: float = 0.0, |
| del_confidence: float = 0.0, |
| initializer: InitializerApplicator = InitializerApplicator(), |
| regularizer: Optional[RegularizerApplicator] = None) -> None: |
        super().__init__(vocab, regularizer)
|
| self.label_namespaces = [labels_namespace, |
| detect_namespace] |
| self.text_field_embedder = text_field_embedder |
| self.num_labels_classes = self.vocab.get_vocab_size(labels_namespace) |
| self.num_detect_classes = self.vocab.get_vocab_size(detect_namespace) |
| self.label_smoothing = label_smoothing |
| self.confidence = confidence |
| self.del_conf = del_confidence |
        # Index of the "INCORRECT" detection tag, used in forward() to read out
        # per-token error probabilities.
        self.incorr_index = self.vocab.get_token_index("INCORRECT",
                                                       namespace=detect_namespace)
|
| self._verbose_metrics = verbose_metrics |
| self.predictor_dropout = TimeDistributed(torch.nn.Dropout(predictor_dropout)) |
|
        # Projections from the embedded text onto the two tag vocabularies;
        # ``get_output_dim()`` matches the last dimension of the embedder output.
        self.tag_labels_projection_layer = TimeDistributed(
            Linear(text_field_embedder.get_output_dim(), self.num_labels_classes))
        self.tag_detect_projection_layer = TimeDistributed(
            Linear(text_field_embedder.get_output_dim(), self.num_detect_classes))
|
| self.metrics = {"accuracy": CategoricalAccuracy()} |
|
| initializer(self) |
|
| @overrides |
| def forward(self, |
| tokens: Dict[str, torch.LongTensor], |
                labels: Optional[torch.LongTensor] = None,
                d_tags: Optional[torch.LongTensor] = None,
                metadata: Optional[List[Dict[str, Any]]] = None) -> Dict[str, torch.Tensor]:
| """ |
| Parameters |
| ---------- |
| tokens : Dict[str, torch.LongTensor], required |
| The output of ``TextField.as_array()``, which should typically be passed directly to a |
| ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` |
| tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": |
| Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used |
| for the ``TokenIndexers`` when you created the ``TextField`` representing your |
| sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, |
| which knows how to combine different word representations into a single vector per |
| token in your input. |
| labels : torch.LongTensor, optional (default = None) |
| A torch tensor representing the sequence of integer gold class labels of shape |
| ``(batch_size, num_tokens)``. |
| d_tags : torch.LongTensor, optional (default = None) |
| A torch tensor representing the sequence of integer gold class labels of shape |
| ``(batch_size, num_tokens)``. |
| metadata : ``List[Dict[str, Any]]``, optional, (default = None) |
| metadata containing the original words in the sentence to be tagged under a 'words' key. |
| |
| Returns |
| ------- |
| An output dictionary consisting of: |
| logits : torch.FloatTensor |
| A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing |
| unnormalised log probabilities of the tag classes. |
| class_probabilities : torch.FloatTensor |
| A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing |
| a distribution of the tag classes per word. |
| loss : torch.FloatTensor, optional |
| A scalar loss to be optimised. |
| |
| """ |
        encoded_text = self.text_field_embedder(tokens)
        batch_size, sequence_length, _ = encoded_text.size()
        mask = get_text_field_mask(tokens)
        # Dropout is applied only in front of the label head; the detection head sees
        # the raw embeddings.
        logits_labels = self.tag_labels_projection_layer(self.predictor_dropout(encoded_text))
        logits_d = self.tag_detect_projection_layer(encoded_text)
|
        class_probabilities_labels = F.softmax(logits_labels, dim=-1).view(
            [batch_size, sequence_length, self.num_labels_classes])
        class_probabilities_d = F.softmax(logits_d, dim=-1).view(
            [batch_size, sequence_length, self.num_detect_classes])
        # Per-token probability of the INCORRECT tag, zeroed at padding; its maximum over
        # the sequence serves as a sentence-level error probability.
        error_probs = class_probabilities_d[:, :, self.incorr_index] * mask
        incorr_prob = torch.max(error_probs, dim=-1)[0]
|
        # Inference-time confidence biases added to the first two label classes (assumed
        # to be the keep and delete tags, as in the standard GECToR label vocabulary).
        probability_change = [self.confidence, self.del_conf] + [0] * (self.num_labels_classes - 2)
        class_probabilities_labels += torch.FloatTensor(probability_change).repeat(
            (batch_size, sequence_length, 1)).to(class_probabilities_labels.device)
|
| output_dict = {"logits_labels": logits_labels, |
| "logits_d_tags": logits_d, |
| "class_probabilities_labels": class_probabilities_labels, |
| "class_probabilities_d_tags": class_probabilities_d, |
| "max_error_probability": incorr_prob} |
| if labels is not None and d_tags is not None: |
| loss_labels = sequence_cross_entropy_with_logits(logits_labels, labels, mask, |
| label_smoothing=self.label_smoothing) |
| loss_d = sequence_cross_entropy_with_logits(logits_d, d_tags, mask) |
            # Both heads update the same accuracy metric.
            for metric in self.metrics.values():
                metric(logits_labels, labels, mask.float())
                metric(logits_d, d_tags, mask.float())
| output_dict["loss"] = loss_labels + loss_d |
|
| if metadata is not None: |
| output_dict["words"] = [x["words"] for x in metadata] |
| return output_dict |
|
| @overrides |
| def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
| """ |
| Does a simple position-wise argmax over each token, converts indices to string labels, and |
| adds a ``"tags"`` key to the dictionary with the result. |
| """ |
| for label_namespace in self.label_namespaces: |
| all_predictions = output_dict[f'class_probabilities_{label_namespace}'] |
| all_predictions = all_predictions.cpu().data.numpy() |
| if all_predictions.ndim == 3: |
| predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])] |
| else: |
| predictions_list = [all_predictions] |
| all_tags = [] |
|
| for predictions in predictions_list: |
| argmax_indices = numpy.argmax(predictions, axis=-1) |
| tags = [self.vocab.get_token_from_index(x, namespace=label_namespace) |
| for x in argmax_indices] |
| all_tags.append(tags) |
            output_dict[label_namespace] = all_tags
| return output_dict |
|
| @overrides |
| def get_metrics(self, reset: bool = False) -> Dict[str, float]: |
        metrics_to_return = {metric_name: metric.get_metric(reset)
                             for metric_name, metric in self.metrics.items()}
| return metrics_to_return |
|