import logging

from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HfDataset
from datasets import load_from_disk

from .tokenizer import SegmentationTokenizer, SentenceSegmenter


class SegmentationDataset(Dataset):
    def __init__(
        self,
        huggingface_dataset: str | HfDataset,
        tokenizer: SegmentationTokenizer,
        segmenter: SentenceSegmenter,
        logger: logging.Logger | None = None,
        percentage: float = 1.0,
        return_type: type = dict
    ):
""" |
|
|
A segmentation dataset takes a huggingface dataset or a path to a dataset on disk with the |
|
|
wikipedia-segmentation format. It loads the dataset and prepares it for training. |
|
|
|
|
|
Wikipedia-segmentation format: |
|
|
- The dataset is expected to be a huggingface dataset or a path to a dataset on disk. |
|
|
- The dataset should contain the following fields: |
|
|
>>> sample = { |
|
|
>>> 'text': ['Article 1', 'Article 2', ...], |
|
|
>>> 'titles': ['Title 1', 'Title 2', ...], |
|
|
>>> 'id': str, |
|
|
>>> 'words': int |
|
|
>>> 'paragraphs': int |
|
|
>>> 'sentences': int |
|
|
>>> } |
|
|
- The dataset should be a list of dictionaries, where each dictionary contains the fields above. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
huggingface_dataset : str | HfDataset |
|
|
A huggingface dataset or a path to a dataset on disk with the wikipedia-segmentation format. |
|
|
|
|
|
tokenizer : callable |
|
|
A tokenizer function that takes a string and returns a list of tokens. |
|
|
|
|
|
logger : logging.Logger, optional |
|
|
Logger instance. If not provided, a null logger will be used. |
|
|
|
|
|
percentage : float |
|
|
Percentage of the dataset to use. Default is 1.0 (100%). |
|
|
|
|
|
return_type : type |
|
|
The return type of __getitem__, either dict or tuple. Default is dict. |
|
|
|
|
|
Raises |
|
|
------ |
|
|
ValueError |
|
|
If the huggingface_dataset is not a string or a HfDataset. |
|
|
ValueError |
|
|
If the tokenizer is not a callable function or class. |
|
|
ValueError |
|
|
If the sentence_tokenizer is not a callable function or class. |
|
|
ValueError |
|
|
If the dtype is not a type. |
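
        Examples
        --------
        A minimal usage sketch; the dataset path and the tokenizer/segmenter construction below are
        illustrative assumptions, not a fixed API.

        >>> tokenizer = SegmentationTokenizer(...)   # hypothetical construction
        >>> segmenter = SentenceSegmenter(...)       # hypothetical construction
        >>> dataset = SegmentationDataset(
        >>>     huggingface_dataset='path/to/wikipedia-segmentation',  # hypothetical path
        >>>     tokenizer=tokenizer,
        >>>     segmenter=segmenter,
        >>>     percentage=0.1,
        >>> )
        >>> sample = dataset[0]
        >>> sorted(sample.keys())
        ['candidate_mask', 'input', 'input_mask', 'labels', 'output_mask']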
        """
        # Fall back to a null logger if no usable logger was provided.
        if not isinstance(logger, logging.Logger):
            self.logger = logging.getLogger("null")
            self.logger.addHandler(logging.NullHandler())
        else:
            self.logger = logger

        # Accept either an in-memory huggingface dataset or a path to one on disk.
        if isinstance(huggingface_dataset, HfDataset):
            self.huggingface_dataset = huggingface_dataset
        elif isinstance(huggingface_dataset, str):
            self.huggingface_dataset = load_from_disk(huggingface_dataset)
        else:
            self.logger.error('[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
            raise ValueError('[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
        self.logger.info(f'[SegmentationDataset] Loaded dataset: {self.huggingface_dataset}')
        self.logger.info(f'[SegmentationDataset] Loaded dataset length: {self.huggingface_dataset.num_rows}')

        # Validate the tokenizer.
        if callable(tokenizer):
            self.tokenizer = tokenizer
        else:
            self.logger.error('[SegmentationDataset] Tokenizer must be a callable function.')
            raise ValueError('[SegmentationDataset] Tokenizer must be a callable function.')

        # Validate the sentence segmenter.
        if not isinstance(segmenter, SentenceSegmenter):
            self.logger.error('[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
            raise ValueError('[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
        else:
            self.segmenter = segmenter

        # Validate the fraction of the dataset to use.
        if not (0.0 < percentage <= 1.0):
            self.logger.error('[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
            raise ValueError('[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
        else:
            self.percentage = percentage

        # Validate the return type of __getitem__.
        if not isinstance(return_type, type):
            self.logger.error('[SegmentationDataset] return_type must be a type.')
            raise ValueError('[SegmentationDataset] return_type must be a type.')
        elif return_type not in [dict, tuple]:
            self.logger.error('[SegmentationDataset] return_type must be either dict or tuple.')
            raise ValueError('[SegmentationDataset] return_type must be either dict or tuple.')
        else:
            self.return_type = return_type

    def get_loader(self, batch_size=8, shuffle=True, num_workers=0, **kwargs) -> DataLoader:
""" |
|
|
Returns a PyTorch DataLoader for this dataset. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
batch_size : int |
|
|
Number of samples per batch. |
|
|
shuffle : bool |
|
|
Whether to shuffle the dataset. |
|
|
num_workers : int |
|
|
Number of worker processes. |
|
|
**kwargs |
|
|
Additional arguments for DataLoader. |
|
|
|
|
|
Returns |
|
|
------- |
|
|
[torch.utils.data.DataLoader |
|
|
Configured DataLoader. |
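
        Examples
        --------
        A hedged sketch; whether the default collation works depends on how the tokenizer pads
        sequences, and a custom collate_fn can be passed through **kwargs if needed.

        >>> loader = dataset.get_loader(batch_size=16, shuffle=True)
        >>> for batch in loader:
        >>>     inputs, labels = batch['input'], batch['labels']
        >>>     break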
        """
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                          pin_memory=True, **kwargs)

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.

        Returns
        -------
        int
            Total number of samples (scaled by percentage).
        """
        return int(self.huggingface_dataset.num_rows * self.percentage)

    def __getitem__(self, idx) -> dict | tuple:
        """
        Retrieves a single sample and generates segmentation labels.

        Parameters
        ----------
        idx : int
            Index of the sample.

        Returns
        -------
        dict | tuple
            The tokenized input ids, sentence-boundary labels, and the corresponding masks,
            as a dict or a tuple depending on return_type.
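
        Examples
        --------
        Illustrative only; exact shapes depend on the tokenizer and segmenter.

        >>> x = dataset[0]                                     # return_type == dict
        >>> x['input'], x['labels']                            # token ids and boundary labels
        >>> ids, boundaries, attn, mask, cands = dataset[0]    # return_type == tuple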
        """
        # Segment the article into sentences, then tokenize the sentence list.
        sample = self.huggingface_dataset[idx]['text']
        sentences = self.segmenter(sample)
        tokenized = self.tokenizer(sentences['sentences'])

        if self.return_type == tuple:
            return (
                tokenized['input_ids'],
                sentences['sentence_boundaries'],
                tokenized['attention_mask'],
                sentences['sentence_mask'],
                sentences['sentence_candidates'],
            )
        elif self.return_type == dict:
            return_value = {
                'input': tokenized['input_ids'],
                'input_mask': tokenized['attention_mask'],
                'labels': sentences['sentence_boundaries'],
                'output_mask': sentences['sentence_mask'],
                'candidate_mask': sentences['sentence_candidates']
            }
        else:
            raise ValueError('[SegmentationDataset] return_type must be either dict or tuple.')
        return return_value