aaljabari committed on
Commit
d53598d
·
verified ·
1 Parent(s): bb62cd8

Create datasets.py

Browse files
Files changed (1) hide show
  1. Nested/data/datasets.py +150 -0
Nested/data/datasets.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ from torch.nn.utils.rnn import pad_sequence
5
+ from Nested.data.transforms import (
6
+ BertSeqTransform,
7
+ NestedTagsTransform
8
+ )
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class Token:
    """A single token together with its gold tags, predicted tags and subwords."""

    def __init__(self, text=None, pred_tag=None, gold_tag=None):
        """
        Token object to hold token attributes
        :param text: str - the token text
        :param pred_tag: predicted tags; ``__str__`` reads the ``"tag"`` key of
                         every item, so this is expected to be a collection of
                         dict-like objects (or None / empty when there is no
                         prediction)
        :param gold_tag: gold tags; ``__str__`` joins the items with "|", so
                         this is expected to be an iterable of str (or None /
                         empty when the token is unlabeled)
        """
        self.text = text
        self.gold_tag = gold_tag
        self.pred_tag = pred_tag
        self.subwords = None

    @property
    def subwords(self):
        """Value last assigned to ``subwords`` (None until it is set)."""
        return self._subwords

    @subwords.setter
    def subwords(self, value):
        self._subwords = value

    def __str__(self):
        """
        Token text representation
        :return: str - "text<TAB>gold_tags<TAB>pred_tags" when gold tags are
                 present, otherwise "text<TAB>pred_tags"; multiple tags are
                 separated by "|"
        """
        if self.pred_tag:
            pred_tags = "|".join(pred_tag["tag"] for pred_tag in self.pred_tag)
        else:
            pred_tags = ""

        # Unlabeled token: skip the gold column entirely. Joining must not be
        # attempted here because gold_tag defaults to None.
        if not self.gold_tag:
            return f"{self.text}\t{pred_tags}"

        gold_tags = "|".join(self.gold_tag)
        return f"{self.text}\t{gold_tags}\t{pred_tags}"
52
+
53
+
54
class DefaultDataset(Dataset):
    """Dataset that runs every example through a ``BertSeqTransform``."""

    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset that is used to transform the segments into training data
        :param examples: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples from -- Nested.data.dataset.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model
        :param max_seq_len: int - maximum sequence length
        """
        self.vocab = vocab
        self.examples = examples
        self.transform = BertSeqTransform(bert_model, vocab, max_seq_len=max_seq_len)

    def __len__(self):
        """Number of examples held by the dataset."""
        return len(self.examples)

    def __getitem__(self, item):
        """
        Transform the example stored at position ``item``
        :param item: index into ``self.examples``
        :return: tuple - (subwords, tags, tokens, valid_len) as produced by the transform
        """
        encoded_subwords, encoded_tags, example_tokens, length = self.transform(
            self.examples[item]
        )
        return encoded_subwords, encoded_tags, example_tokens, length

    def collate_fn(self, batch):
        """
        Collate function that is called when the batch is requested by the trainer
        :param batch: Dataloader batch - sequence of __getitem__ outputs
        :return: tuple - (subwords, tags, tokens, valid_len); subwords and tags
                 are padded tensors with the batch as first dimension, tokens
                 and valid_len are tuples with one entry per example
        """
        subword_seqs, tag_seqs, tokens, valid_len = zip(*batch)

        # Tags are padded with the index of the "O" tag of the first tag vocab,
        # subwords are padded with zeros. Tokens and lengths are not padded.
        outside_tag_index = self.vocab.tags[0].get_stoi()["O"]
        padded_subwords = pad_sequence(subword_seqs, batch_first=True, padding_value=0)
        padded_tags = pad_sequence(
            tag_seqs, batch_first=True, padding_value=outside_tag_index
        )
        return padded_subwords, padded_tags, tokens, valid_len
97
+
98
+
99
class NestedTagsDataset(Dataset):
    """Dataset that runs every example through a ``NestedTagsTransform``."""

    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset that is used to transform the segments into training data
        :param examples: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples from -- Nested.data.dataset.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model
        :param max_seq_len: int - maximum sequence length
        """
        self.transform = NestedTagsTransform(
            bert_model, vocab, max_seq_len=max_seq_len
        )
        self.examples = examples
        self.vocab = vocab

    def __len__(self):
        """Number of examples held by the dataset."""
        return len(self.examples)

    def __getitem__(self, item):
        """
        Transform the example stored at position ``item``
        :param item: index into ``self.examples``
        :return: tuple - (subwords, tags, tokens, masks, valid_len) as produced by the transform
        """
        subwords, tags, tokens, masks, valid_len = self.transform(self.examples[item])
        return subwords, tags, tokens, masks, valid_len

    @staticmethod
    def _pad_tags(tag, target_len, outside_ids):
        """
        Right-pad one example's tag tensor along its last dimension
        :param tag: tensor of tag indices for one example
        :param target_len: int - length of the last dimension after padding
        :param outside_ids: list[int] - index of the "O" tag in each tag vocab
        :return: padded tensor, same dtype as ``tag``
        """
        seq_len = tag.shape[-1]
        pad_len = target_len - seq_len
        padded = torch.nn.ConstantPad1d((0, pad_len), outside_ids[0])(tag)

        # NOTE(review): assumes the second-to-last dimension of ``tag`` holds one
        # row per tag vocab in ``vocab.tags[1:]`` -- confirm against
        # NestedTagsTransform. When the layout matches, every row is padded with
        # the "O" index of its own vocab; otherwise the single value used above
        # is kept.
        rows_match_vocabs = tag.dim() >= 2 and tag.shape[-2] == len(outside_ids)
        if pad_len > 0 and rows_match_vocabs:
            padded[..., seq_len:] = torch.tensor(
                outside_ids, dtype=tag.dtype, device=tag.device
            ).unsqueeze(-1)
        return padded

    def collate_fn(self, batch):
        """
        Collate function that is called when the batch is requested by the trainer
        :param batch: Dataloader batch - sequence of __getitem__ outputs
        :return: tuple - (subwords, tags, tokens, masks, valid_len); subwords is
                 a padded tensor with the batch as first dimension, tags and
                 masks are the per-example tensors padded to the same length and
                 concatenated along dim 0, tokens and valid_len are tuples with
                 one entry per example
        """
        subwords, tags, tokens, masks, valid_len = zip(*batch)

        # Pad sequences in this batch: subwords and masks are padded with zeros,
        # tags are padded with the index of the O tag
        subwords = pad_sequence(subwords, batch_first=True, padding_value=0)
        max_len = subwords.shape[-1]

        masks = [torch.nn.ConstantPad1d((0, max_len - tag.shape[-1]), 0)(mask)
                 for tag, mask in zip(tags, masks)]
        masks = torch.cat(masks)

        # Pad the tags of EVERY example. The examples must not be zipped with
        # the tag vocabs: that pairs example i with vocab i and drops examples
        # once the batch is larger than the number of tag vocabs.
        outside_ids = [vocab.get_stoi()["O"] for vocab in self.vocab.tags[1:]]
        tags = torch.cat([self._pad_tags(tag, max_len, outside_ids) for tag in tags])

        return subwords, tags, tokens, masks, valid_len