import logging

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

from Nested.data.transforms import (
    BertSeqTransform,
    NestedTagsTransform,
)

logger = logging.getLogger(__name__)

class Token:
    def __init__(self, text=None, pred_tag=None, gold_tag=None):
        """
        Token object to hold token attributes
        :param text: str - token text
        :param pred_tag: list[dict] - predicted tags, each a dict with a "tag" key
        :param gold_tag: list[str] - gold tags
        """
        self.text = text
        self.gold_tag = gold_tag
        self.pred_tag = pred_tag
        self.subwords = None

    @property
    def subwords(self):
        return self._subwords

    @subwords.setter
    def subwords(self, value):
        self._subwords = value

    def __str__(self):
        """
        Token text representation
        :return: str
        """
        # Guard against a missing gold_tag; "|".join(None) would raise TypeError
        gold_tags = "|".join(self.gold_tag) if self.gold_tag else ""

        if self.pred_tag:
            pred_tags = "|".join([pred_tag["tag"] for pred_tag in self.pred_tag])
        else:
            pred_tags = ""

        if self.gold_tag:
            r = f"{self.text}\t{gold_tags}\t{pred_tags}"
        else:
            r = f"{self.text}\t{pred_tags}"

        return r
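
# A minimal usage sketch (assumption: gold_tag is a list of tag strings and
# pred_tag is a list of dicts carrying a "tag" key, as __str__ above implies;
# the tag value is a hypothetical example):
#   token = Token(text="word", gold_tag=["B-PERS"], pred_tag=[{"tag": "B-PERS"}])
#   str(token)  # -> "word\tB-PERS\tB-PERS"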

class DefaultDataset(Dataset):
    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset used to transform the segments into training data
        :param examples: list[list[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples with Nested.data.dataset.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model
        :param max_seq_len: int - maximum sequence length
        """
        self.transform = BertSeqTransform(bert_model, vocab, max_seq_len=max_seq_len)
        self.examples = examples
        self.vocab = vocab

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        subwords, tags, tokens, valid_len = self.transform(self.examples[item])
        return subwords, tags, tokens, valid_len

    def collate_fn(self, batch):
        """
        Collate function that is called when the trainer requests a batch
        :param batch: Dataloader batch
        :return: Same output as the __getitem__ function
        """
        subwords, tags, tokens, valid_len = zip(*batch)

        # Pad sequences in this batch
        # subwords and tokens are padded with zeros
        # tags are padded with the index of the O tag
        subwords = pad_sequence(subwords, batch_first=True, padding_value=0)
        tags = pad_sequence(
            tags, batch_first=True, padding_value=self.vocab.tags[0].get_stoi()["O"]
        )

        return subwords, tags, tokens, valid_len
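
# A hedged usage sketch; `examples` and `vocab` are assumed to come from
# Nested.data.dataset.parse_conll_files (referenced in the docstring above):
#   dataset = DefaultDataset(examples=examples, vocab=vocab)
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=32, shuffle=True, collate_fn=dataset.collate_fn
#   )
#   for subwords, tags, tokens, valid_len in loader:
#       ...  # subwords/tags arrive padded to the longest sequence in the batch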

class NestedTagsDataset(Dataset):
    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset used to transform the segments into training data
        :param examples: list[list[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples with Nested.data.dataset.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model
        :param max_seq_len: int - maximum sequence length
        """
        self.transform = NestedTagsTransform(
            bert_model, vocab, max_seq_len=max_seq_len
        )
        self.examples = examples
        self.vocab = vocab

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        subwords, tags, tokens, masks, valid_len = self.transform(self.examples[item])
        return subwords, tags, tokens, masks, valid_len

    def collate_fn(self, batch):
        """
        Collate function that is called when the trainer requests a batch
        :param batch: Dataloader batch
        :return: Same output as the __getitem__ function
        """
        subwords, tags, tokens, masks, valid_len = zip(*batch)

        # Pad sequences in this batch
        # subwords and tokens are padded with zeros
        # tags are padded with the index of the O tag
        subwords = pad_sequence(subwords, batch_first=True, padding_value=0)

        # Right-pad each mask with zeros up to the padded sequence length
        masks = [
            torch.nn.ConstantPad1d((0, subwords.shape[-1] - tag.shape[-1]), 0)(mask)
            for tag, mask in zip(tags, masks)
        ]
        masks = torch.cat(masks)

        # Pad the tags, padding each tag type with its own vocab's "O" index
        tags = [
            torch.nn.ConstantPad1d(
                (0, subwords.shape[-1] - tag.shape[-1]), vocab.get_stoi()["O"]
            )(tag)
            for tag, vocab in zip(tags, self.vocab.tags[1:])
        ]
        tags = torch.cat(tags)

        return subwords, tags, tokens, masks, valid_len
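
# The nested variant is wired up the same way; each batch additionally carries
# per-tag-type masks (a sketch under the same parse_conll_files assumption):
#   nested = NestedTagsDataset(examples=examples, vocab=vocab)
#   loader = torch.utils.data.DataLoader(
#       nested, batch_size=32, shuffle=True, collate_fn=nested.collate_fn
#   )
#   for subwords, tags, tokens, masks, valid_len in loader:
#       ...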