import logging

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

from Nested.data.transforms import (
    BertSeqTransform,
    NestedTagsTransform,
)

logger = logging.getLogger(__name__)


class Token:
    def __init__(self, text=None, pred_tag=None, gold_tag=None):
        """
        Token object to hold token attributes

        :param text: str
        :param pred_tag: str
        :param gold_tag: str
        """
        self.text = text
        self.gold_tag = gold_tag
        self.pred_tag = pred_tag
        self.subwords = None

    @property
    def subwords(self):
        return self._subwords

    @subwords.setter
    def subwords(self, value):
        self._subwords = value

    def __str__(self):
        """
        Token text representation

        :return: str
        """
        if self.pred_tag:
            pred_tags = "|".join([pred_tag["tag"] for pred_tag in self.pred_tag])
        else:
            pred_tags = ""

        if self.gold_tag:
            gold_tags = "|".join(self.gold_tag)
            r = f"{self.text}\t{gold_tags}\t{pred_tags}"
        else:
            r = f"{self.text}\t{pred_tags}"

        return r


class DefaultDataset(Dataset):
    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset that is used to transform the segments into training data

        :param examples: list[list[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples from -- Nested.data.dataset.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model
        :param max_seq_len: int - maximum sequence length
        """
        self.transform = BertSeqTransform(bert_model, vocab, max_seq_len=max_seq_len)
        self.examples = examples
        self.vocab = vocab

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        subwords, tags, tokens, valid_len = self.transform(self.examples[item])
        return subwords, tags, tokens, valid_len

    def collate_fn(self, batch):
        """
        Collate function that is called when a batch is fetched by the trainer

        :param batch: Dataloader batch
        :return: Same output as the __getitem__ function
        """
        subwords, tags, tokens, valid_len = zip(*batch)

        # Pad sequences in this batch
        # subwords and tokens are padded with zeros
        # tags are padded with the index of the O tag
        subwords = pad_sequence(subwords, batch_first=True, padding_value=0)
        tags = pad_sequence(
            tags, batch_first=True, padding_value=self.vocab.tags[0].get_stoi()["O"]
        )
        return subwords, tags, tokens, valid_len
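# Usage sketch (illustrative, not part of the module): build a DefaultDataset and
# batch it with its own collate_fn. The `parse_conll_files` helper and the exact
# shape of its return value are assumptions taken from the docstring reference to
# Nested.data.dataset.parse_conll_files and may differ in your checkout.
#
#     from torch.utils.data import DataLoader
#     from Nested.data.dataset import parse_conll_files
#
#     datasets, vocab = parse_conll_files(("train.txt",))
#     train_dataset = DefaultDataset(
#         examples=datasets[0],
#         vocab=vocab,
#         bert_model="aubmindlab/bert-base-arabertv2",
#         max_seq_len=512,
#     )
#     loader = DataLoader(
#         train_dataset, batch_size=8, shuffle=True, collate_fn=train_dataset.collate_fn
#     )
#     subwords, tags, tokens, valid_len = next(iter(loader))

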
class NestedTagsDataset(Dataset):
    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset that is used to transform the segments into training data

        :param examples: list[list[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples from -- Nested.data.dataset.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model
        :param max_seq_len: int - maximum sequence length
        """
        self.transform = NestedTagsTransform(
            bert_model, vocab, max_seq_len=max_seq_len
        )
        self.examples = examples
        self.vocab = vocab

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        subwords, tags, tokens, masks, valid_len = self.transform(self.examples[item])
        return subwords, tags, tokens, masks, valid_len

    def collate_fn(self, batch):
        """
        Collate function that is called when a batch is fetched by the trainer

        :param batch: Dataloader batch
        :return: Same output as the __getitem__ function
        """
        subwords, tags, tokens, masks, valid_len = zip(*batch)

        # Pad sequences in this batch
        # subwords and tokens are padded with zeros
        # tags are padded with the index of the O tag
        subwords = pad_sequence(subwords, batch_first=True, padding_value=0)

        # Pad each mask to the padded subword length
        masks = [
            torch.nn.ConstantPad1d((0, subwords.shape[-1] - tag.shape[-1]), 0)(mask)
            for tag, mask in zip(tags, masks)
        ]
        masks = torch.cat(masks)

        # Pad the tags, do the padding for each tag type
        tags = [
            torch.nn.ConstantPad1d(
                (0, subwords.shape[-1] - tag.shape[-1]), vocab.get_stoi()["O"]
            )(tag)
            for tag, vocab in zip(tags, self.vocab.tags[1:])
        ]
        tags = torch.cat(tags)

        return subwords, tags, tokens, masks, valid_len
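# Usage sketch for the nested variant (illustrative, same assumptions as the
# DefaultDataset sketch above, reusing `datasets` and `vocab` from it): the
# nested collate_fn additionally returns `masks`, and each tag tensor is padded
# to the padded subword length with its tag type's "O" index before the tensors
# are concatenated.
#
#     nested_dataset = NestedTagsDataset(
#         examples=datasets[0],
#         vocab=vocab,
#         bert_model="aubmindlab/bert-base-arabertv2",
#         max_seq_len=512,
#     )
#     nested_loader = DataLoader(
#         nested_dataset, batch_size=8, shuffle=True, collate_fn=nested_dataset.collate_fn
#     )
#     subwords, tags, tokens, masks, valid_len = next(iter(nested_loader))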