# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
# Universidad de Alcalá - Escuela Politécnica Superior      #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import logging
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HfDataset
from datasets import load_from_disk
from .tokenizer import SegmentationTokenizer, SentenceSegmenter


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
class SegmentationDataset(Dataset):
    def __init__(
            self,
            huggingface_dataset: str | HfDataset,
            tokenizer: SegmentationTokenizer,
            segmenter: SentenceSegmenter,
            logger: logging.Logger | None = None,
            percentage: float = 1.0,
            return_type: type = dict
    ):
        """
        A segmentation dataset takes a huggingface dataset or a path to a dataset on disk with the
        wikipedia-segmentation format. It loads the dataset and prepares it for training.

        Wikipedia-segmentation format:
        - The dataset is expected to be a huggingface dataset or a path to a dataset on disk.
        - The dataset should contain the following fields:
            sample = {
                'text': ['Article 1', 'Article 2', ...],
                'titles': ['Title 1', 'Title 2', ...],
                'id': str,
                'words': int,
                'paragraphs': int,
                'sentences': int
            }
        - The dataset should be a list of dictionaries, where each dictionary contains the fields above.
        - Only the 'text' field is read by this class.

        Parameters
        ----------
        huggingface_dataset : str | HfDataset
            A huggingface dataset or a path to a dataset on disk with the wikipedia-segmentation format.
            A string is loaded with ``datasets.load_from_disk``.
        tokenizer : SegmentationTokenizer
            A callable invoked with the segmenter's ``'sentences'`` output. Its result must expose the
            ``'input_ids'`` and ``'attention_mask'`` keys.
        segmenter : SentenceSegmenter
            A sentence segmenter invoked with the ``'text'`` field of a sample. Its result must expose the
            ``'sentences'``, ``'sentence_boundaries'``, ``'sentence_mask'`` and ``'sentence_candidates'`` keys.
        logger : logging.Logger, optional
            Logger instance. If not provided, the shared logger named "null" (with a NullHandler) is used.
        percentage : float
            Fraction of the dataset to use, in the range (0.0, 1.0]. Default is 1.0 (100%).
        return_type : type
            The return type of __getitem__, either dict or tuple. Default is dict.

        Raises
        ------
        ValueError
            If the huggingface_dataset is not a string or a HfDataset.
        ValueError
            If the tokenizer is not callable.
        ValueError
            If the segmenter is not a SentenceSegmenter instance.
        ValueError
            If the percentage is not in the range (0.0, 1.0].
        ValueError
            If the return_type is not a type, or is neither dict nor tuple.
        """
        # Null logging:
        if isinstance(logger, logging.Logger):
            self.logger = logger
        else:
            self.logger = logging.getLogger("null")
            # The "null" logger is shared process-wide: attach the NullHandler only once so that
            # handlers do not pile up with every dataset instance.
            if not any(isinstance(handler, logging.NullHandler) for handler in self.logger.handlers):
                self.logger.addHandler(logging.NullHandler())

        # Loading:
        if isinstance(huggingface_dataset, HfDataset):
            self.huggingface_dataset = huggingface_dataset
        elif isinstance(huggingface_dataset, str):
            self.huggingface_dataset = load_from_disk(huggingface_dataset)
        else:
            message = '[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.'
            self.logger.error(message)
            raise ValueError(message)

        self.logger.info(f'[SegmentationDataset] Loaded dataset: {self.huggingface_dataset}')
        self.logger.info(f'[SegmentationDataset] Loaded dataset length: {self.huggingface_dataset.num_rows}')

        # Tokenizer:
        if not callable(tokenizer):
            message = '[SegmentationDataset] Tokenizer must be a callable function.'
            self.logger.error(message)
            raise ValueError(message)
        self.tokenizer = tokenizer

        # Segmenter:
        if not isinstance(segmenter, SentenceSegmenter):
            message = '[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.'
            self.logger.error(message)
            raise ValueError(message)
        self.segmenter = segmenter

        # Percentage:
        if not (0.0 < percentage <= 1.0):
            message = '[SegmentationDataset] Percentage must be between 0.0 and 1.0.'
            self.logger.error(message)
            raise ValueError(message)
        self.percentage = percentage

        # Return type:
        if not isinstance(return_type, type):
            message = '[SegmentationDataset] return_type must be a type.'
            self.logger.error(message)
            raise ValueError(message)
        if return_type not in (dict, tuple):
            message = '[SegmentationDataset] return_type must be either dict or tuple.'
            self.logger.error(message)
            raise ValueError(message)
        self.return_type = return_type

    def get_loader(self, batch_size: int = 8, shuffle: bool = True, num_workers: int = 0, **kwargs) -> DataLoader:
        """
        Returns a PyTorch DataLoader for this dataset.

        Parameters
        ----------
        batch_size : int
            Number of samples per batch.
        shuffle : bool
            Whether to shuffle the dataset.
        num_workers : int
            Number of worker processes.
        **kwargs
            Additional arguments for DataLoader. ``pin_memory`` defaults to True but may be overridden here.

        Returns
        -------
        torch.utils.data.DataLoader
            Configured DataLoader.
        """
        # Keep pinned memory as the default, but let the caller override it instead of failing with a
        # duplicate keyword argument:
        kwargs.setdefault('pin_memory', True)
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            **kwargs
        )

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.

        Returns
        -------
        int
            Number of rows of the underlying dataset scaled by ``percentage`` and truncated to an integer.
        """
        return int(self.huggingface_dataset.num_rows * self.percentage)

    def __getitem__(self, idx) -> dict | tuple:
        """
        Retrieves a single sample, segments it into sentences and tokenizes them.

        Parameters
        ----------
        idx : int
            Index of the sample. Integer indices are resolved against ``len(self)`` (so the configured
            ``percentage`` is honoured) and negative values count from the end of that subset. Any other
            index type is forwarded to the underlying dataset unchanged.

        Returns
        -------
        dict | tuple
            If ``return_type`` is tuple: ``(input_ids, sentence_boundaries, attention_mask, sentence_mask,
            sentence_candidates)``, i.e. (x, y, x_mask, y_mask, y_prime_mask).
            If ``return_type`` is dict: the same values under the keys ``'input'``, ``'labels'``,
            ``'input_mask'``, ``'output_mask'`` and ``'candidate_mask'``.

        Raises
        ------
        IndexError
            If an integer index falls outside ``[-len(self), len(self))``.
        ValueError
            If ``return_type`` is neither dict nor tuple.
        """
        if isinstance(idx, int):
            # Enforce the subset size: without this, plain iteration (which stops on IndexError)
            # and direct indexing would silently read rows beyond len(self).
            length = len(self)
            position = idx + length if idx < 0 else idx
            if not 0 <= position < length:
                raise IndexError(
                    f'[SegmentationDataset] Index {idx} is out of range for a dataset of length {length}.'
                )
            idx = position

        sample = self.huggingface_dataset[idx]['text']
        sentences = self.segmenter(sample)
        tokenized = self.tokenizer(sentences['sentences'])

        if self.return_type is tuple:
            return (
                tokenized['input_ids'],            # x
                sentences['sentence_boundaries'],  # y
                tokenized['attention_mask'],       # x_mask
                sentences['sentence_mask'],        # y_mask
                sentences['sentence_candidates'],  # y_prime_mask
            )

        if self.return_type is dict:
            return {
                'input': tokenized['input_ids'],
                'input_mask': tokenized['attention_mask'],
                'labels': sentences['sentence_boundaries'],
                'output_mask': sentences['sentence_mask'],
                'candidate_mask': sentences['sentence_candidates']
            }

        # Only reachable if return_type was reassigned after construction:
        raise ValueError('[SegmentationDataset] return_type must be either dict or tuple.')
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #