# Nested/data/datasets.py
import logging
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from Nested.data.transforms import (
    BertSeqTransform,
    NestedTagsTransform,
)

logger = logging.getLogger(__name__)


class Token:
    def __init__(self, text=None, pred_tag=None, gold_tag=None):
        """
        Token object to hold token attributes.

        :param text: str - token text
        :param pred_tag: list[dict] - predicted tags, each a dict with a "tag" key
        :param gold_tag: list[str] - gold tags
        """
        self.text = text
        self.gold_tag = gold_tag
        self.pred_tag = pred_tag
        self.subwords = None

    @property
    def subwords(self):
        return self._subwords

    @subwords.setter
    def subwords(self, value):
        self._subwords = value

    def __str__(self):
        """
        Tab-separated text representation of the token:
        text, gold tags (if any), then predicted tags.

        :return: str
        """
        if self.pred_tag:
            pred_tags = "|".join([pred_tag["tag"] for pred_tag in self.pred_tag])
        else:
            pred_tags = ""

        if self.gold_tag:
            # Join only when gold tags exist; joining None would raise a TypeError
            gold_tags = "|".join(self.gold_tag)
            r = f"{self.text}\t{gold_tags}\t{pred_tags}"
        else:
            r = f"{self.text}\t{pred_tags}"

        return r
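
# A minimal usage sketch of Token (illustrative only: the tag values below
# are made up; only the attribute names come from this file):
#
#     token = Token(text="فلسطين", gold_tag=["B-GPE"],
#                   pred_tag=[{"tag": "B-GPE"}])
#     print(token)  # -> "فلسطين\tB-GPE\tB-GPE"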


class DefaultDataset(Dataset):
    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset used to transform segments into training data.

        :param examples: list[list[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples with Nested.data.dataset.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model name
        :param max_seq_len: int - maximum sequence length
        """
        self.transform = BertSeqTransform(bert_model, vocab, max_seq_len=max_seq_len)
        self.examples = examples
        self.vocab = vocab

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        subwords, tags, tokens, valid_len = self.transform(self.examples[item])
        return subwords, tags, tokens, valid_len

    def collate_fn(self, batch):
        """
        Collate function invoked by the DataLoader when it assembles a batch.

        :param batch: list of __getitem__ outputs
        :return: batched (and padded) versions of the __getitem__ outputs
        """
        subwords, tags, tokens, valid_len = zip(*batch)

        # Pad sequences in this batch: subwords are padded with zeros,
        # tags with the index of the "O" tag
        subwords = pad_sequence(subwords, batch_first=True, padding_value=0)
        tags = pad_sequence(
            tags, batch_first=True, padding_value=self.vocab.tags[0].get_stoi()["O"]
        )

        return subwords, tags, tokens, valid_len
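
# A minimal usage sketch (hedged): wiring DefaultDataset into a DataLoader.
# `examples` and `vocab` are assumed to come from the Nested.data pipeline
# (e.g. parse_conll_files, per the docstring above); batch_size is arbitrary.
#
#     from torch.utils.data import DataLoader
#
#     dataset = DefaultDataset(examples=examples, vocab=vocab)
#     loader = DataLoader(dataset, batch_size=32, shuffle=True,
#                         collate_fn=dataset.collate_fn)
#     for subwords, tags, tokens, valid_len in loader:
#         ...  # subwords/tags are padded to the longest sequence in the batch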


class NestedTagsDataset(Dataset):
    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset used to transform segments into training data for nested
        tags (one tag vocabulary per tag type).

        :param examples: list[list[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples with Nested.data.dataset.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model name
        :param max_seq_len: int - maximum sequence length
        """
        self.transform = NestedTagsTransform(
            bert_model, vocab, max_seq_len=max_seq_len
        )
        self.examples = examples
        self.vocab = vocab

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        subwords, tags, tokens, masks, valid_len = self.transform(self.examples[item])
        return subwords, tags, tokens, masks, valid_len

    def collate_fn(self, batch):
        """
        Collate function invoked by the DataLoader when it assembles a batch.

        :param batch: list of __getitem__ outputs
        :return: batched (and padded) versions of the __getitem__ outputs
        """
        subwords, tags, tokens, masks, valid_len = zip(*batch)

        # Pad the subword sequences in this batch with zeros
        subwords = pad_sequence(subwords, batch_first=True, padding_value=0)

        # Right-pad the masks with zeros up to the padded subword length
        masks = [
            torch.nn.ConstantPad1d((0, subwords.shape[-1] - tag.shape[-1]), 0)(mask)
            for tag, mask in zip(tags, masks)
        ]
        masks = torch.cat(masks)

        # Pad the tags, doing the padding per tag type with the index of
        # the "O" tag from that type's own vocabulary
        tags = [
            torch.nn.ConstantPad1d(
                (0, subwords.shape[-1] - tag.shape[-1]), vocab.get_stoi()["O"]
            )(tag)
            for tag, vocab in zip(tags, self.vocab.tags[1:])
        ]
        tags = torch.cat(tags)

        return subwords, tags, tokens, masks, valid_len
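
# A minimal usage sketch (hedged), mirroring the flat case above; the extra
# `masks` output is padded alongside the tags. `examples` and `vocab` are
# assumed to come from the Nested.data pipeline.
#
#     from torch.utils.data import DataLoader
#
#     dataset = NestedTagsDataset(examples=examples, vocab=vocab)
#     loader = DataLoader(dataset, batch_size=16, shuffle=True,
#                         collate_fn=dataset.collate_fn)
#     for subwords, tags, tokens, masks, valid_len in loader:
#         ...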