# Commit dbd79bd (alverciito): upload safetensors and refactor research files
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import logging
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HfDataset
from datasets import load_from_disk
from .tokenizer import SegmentationTokenizer, SentenceSegmenter
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
class SegmentationDataset(Dataset):
    """
    A segmentation dataset takes a huggingface dataset or a path to a dataset on disk with the
    wikipedia-segmentation format. It loads the dataset and prepares it for training.

    Wikipedia-segmentation format:
        - The dataset is expected to be a huggingface dataset or a path to a dataset on disk.
        - The dataset should contain the following fields:
            >>> sample = {
            >>>     'text': ['Article 1', 'Article 2', ...],
            >>>     'titles': ['Title 1', 'Title 2', ...],
            >>>     'id': str,
            >>>     'words': int,
            >>>     'paragraphs': int,
            >>>     'sentences': int
            >>> }
        - The dataset should be a list of dictionaries, where each dictionary contains the fields above.
    """
    def __init__(
            self,
            huggingface_dataset: str | HfDataset,
            tokenizer: SegmentationTokenizer,
            segmenter: SentenceSegmenter,
            logger: logging.Logger | None = None,
            percentage: float = 1.0,
            return_type: type = dict
    ):
        """
        Load (or adopt) the dataset and validate every configuration argument.

        Parameters
        ----------
        huggingface_dataset : str | HfDataset
            A huggingface dataset or a path to a dataset on disk with the wikipedia-segmentation format.
        tokenizer : SegmentationTokenizer
            A callable tokenizer; given the segmenter's sentences it must return a mapping
            with 'input_ids' and 'attention_mask' (see ``__getitem__``).
        segmenter : SentenceSegmenter
            Sentence segmenter; given an article it must return a mapping with 'sentences',
            'sentence_boundaries', 'sentence_mask' and 'sentence_candidates'.
        logger : logging.Logger, optional
            Logger instance. If not provided, a null logger will be used.
        percentage : float
            Percentage of the dataset to use. Default is 1.0 (100%).
        return_type : type
            The return type of __getitem__, either dict or tuple. Default is dict.

        Raises
        ------
        ValueError
            If huggingface_dataset is not a string or a HfDataset; if tokenizer is not
            callable; if segmenter is not a SentenceSegmenter instance; if percentage is
            outside (0.0, 1.0]; or if return_type is not dict or tuple.
        """
        # Fall back to a no-op logger so logging calls below are always safe:
        if isinstance(logger, logging.Logger):
            self.logger = logger
        else:
            self.logger = logging.getLogger("null")
            self.logger.addHandler(logging.NullHandler())
        # Loading: accept an in-memory dataset or a path on disk.
        if isinstance(huggingface_dataset, HfDataset):
            self.huggingface_dataset = huggingface_dataset
        elif isinstance(huggingface_dataset, str):
            self.huggingface_dataset = load_from_disk(huggingface_dataset)
        else:
            self._fail('[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
        # Lazy %-style args avoid formatting when the level is disabled:
        self.logger.info('[SegmentationDataset] Loaded dataset: %s', self.huggingface_dataset)
        self.logger.info('[SegmentationDataset] Loaded dataset length: %s', self.huggingface_dataset.num_rows)
        # Tokenizer must be a callable function or a callable class instance:
        if not callable(tokenizer):
            self._fail('[SegmentationDataset] Tokenizer must be a callable function.')
        self.tokenizer = tokenizer
        # Segmenter:
        if not isinstance(segmenter, SentenceSegmenter):
            self._fail('[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
        self.segmenter = segmenter
        # Fraction of the dataset that __len__ exposes:
        if not (0.0 < percentage <= 1.0):
            self._fail('[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
        self.percentage = percentage
        # Return type of __getitem__:
        if not isinstance(return_type, type):
            self._fail('[SegmentationDataset] return_type must be a type.')
        if return_type not in (dict, tuple):
            self._fail('[SegmentationDataset] return_type must be either dict or tuple.')
        self.return_type = return_type

    def _fail(self, message: str) -> None:
        """Log *message* as an error and raise it as a ValueError (single source of truth)."""
        self.logger.error(message)
        raise ValueError(message)

    def get_loader(self, batch_size=8, shuffle=True, num_workers=0, **kwargs) -> DataLoader:
        """
        Returns a PyTorch DataLoader for this dataset.

        Parameters
        ----------
        batch_size : int
            Number of samples per batch.
        shuffle : bool
            Whether to shuffle the dataset.
        num_workers : int
            Number of worker processes.
        **kwargs
            Additional arguments for DataLoader.

        Returns
        -------
        torch.utils.data.DataLoader
            Configured DataLoader (pin_memory=True for faster host-to-device copies).
        """
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                          pin_memory=True, **kwargs)

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.

        Returns
        -------
        int
            Total number of rows scaled by `percentage`, truncated toward zero.
        """
        return int(self.huggingface_dataset.num_rows * self.percentage)

    def __getitem__(self, idx) -> dict | tuple:
        """
        Retrieves a single article, segments it into sentences and tokenizes them.

        Parameters
        ----------
        idx : int
            Index of the sample.

        Returns
        -------
        dict | tuple
            According to `return_type`: a dict with keys 'input', 'input_mask',
            'labels', 'output_mask', 'candidate_mask', or the tuple
            (input_ids, sentence_boundaries, attention_mask, sentence_mask, sentence_candidates).
        """
        article = self.huggingface_dataset[idx]['text']
        segmented = self.segmenter(article)
        tokenized = self.tokenizer(segmented['sentences'])
        if self.return_type is tuple:
            return (
                tokenized['input_ids'],             # x
                segmented['sentence_boundaries'],   # y
                tokenized['attention_mask'],        # x_mask
                segmented['sentence_mask'],         # y_mask
                segmented['sentence_candidates'],   # y_prime_mask
            )
        if self.return_type is dict:
            return {
                'input': tokenized['input_ids'],
                'input_mask': tokenized['attention_mask'],
                'labels': segmented['sentence_boundaries'],
                'output_mask': segmented['sentence_mask'],
                'candidate_mask': segmented['sentence_candidates']
            }
        # Unreachable when __init__ validated return_type; kept as a defensive guard.
        raise ValueError('[SegmentationDataset] return_type must be either dict or tuple.')
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #