# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import logging
import json
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
class TokenizedSegmentationDataset(Dataset):
def __init__(
self,
tokenized_dataset: str,
        logger: logging.Logger | None = None,
percentage: float = 1.0,
return_type: type = dict
):
"""
        A tokenized segmentation dataset takes a path to a tokenized dataset on disk with the
        wikipedia-segmentation format. It loads the dataset and prepares it for training.
Wikipedia-segmentation format:
        - The dataset is expected to be a directory on disk containing an info.json
          metadata file and one memory-mapped binary file per array.
- The dataset should contain the following fields:
>>> sample = {
>>> 'text': ['Article 1', 'Article 2', ...],
>>> 'titles': ['Title 1', 'Title 2', ...],
>>> 'id': str,
        >>> 'words': int,
        >>> 'paragraphs': int,
>>> 'sentences': int
>>> }
- The dataset should be a list of dictionaries, where each dictionary contains the fields above.
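        The accompanying info.json metadata file, as read by this class, follows the
        sketch below (field values here are illustrative assumptions, not canonical):
        >>> info = {
        >>>     'fingerprint': 'abc123',               # non-empty fingerprint is required
        >>>     'samples': 1000,                       # total number of samples
        >>>     'x': {
        >>>         'name': 'x', 'extension': '.bin',  # file on disk: name + extension
        >>>         'dtype': 'int32', 'samples': 1000,
        >>>         'element_shape': [64, 128]         # (max_sentences, max_tokens)
        >>>     },
        >>>     # ... analogous entries for 'y', 'x_mask', 'y_mask' and 'y_cand'
        >>> }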
Parameters
----------
tokenized_dataset : str
A path to a tokenized dataset on disk with the wikipedia-segmentation format.
logger : logging.Logger, optional
Logger instance. If not provided, a null logger will be used.
percentage : float
Percentage of the dataset to use. Default is 1.0 (100%).
return_type : type
The return type of __getitem__, either dict or tuple. Default is dict.
        Raises
        ------
        FileNotFoundError
            If the dataset metadata file (info.json) is not found.
        ValueError
            If tokenized_dataset is not a string path.
        ValueError
            If the metadata file is missing fingerprint information.
        ValueError
            If percentage is not in the interval (0.0, 1.0].
        ValueError
            If return_type is not dict or tuple.
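        Examples
        --------
        A minimal usage sketch (the path below is hypothetical):
        >>> dataset = TokenizedSegmentationDataset('data/wiki_tokenized', percentage=0.1)
        >>> len(dataset)  # 10% of the samples declared in info.json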
"""
# Null logging:
if not isinstance(logger, logging.Logger):
self.logger = logging.getLogger("null")
self.logger.addHandler(logging.NullHandler())
else:
self.logger = logger
# Loading:
if isinstance(tokenized_dataset, str):
self.metadata_path = os.path.join(tokenized_dataset, 'info.json')
if not os.path.exists(self.metadata_path):
self.logger.error(f'[SegmentationDataset] Dataset metadata file not found at {self.metadata_path}.')
raise FileNotFoundError(f'[SegmentationDataset] Dataset metadata file not found at {self.metadata_path}.')
else:
with open(self.metadata_path, 'r', encoding='utf-8') as f:
self.metadata = json.load(f)
                if 'fingerprint' not in self.metadata or not self.metadata['fingerprint']:
                    self.logger.error('[SegmentationDataset] Dataset metadata file is missing fingerprint information.')
                    raise ValueError('[SegmentationDataset] Dataset metadata file is missing fingerprint information.')
else:
            self.logger.error('[SegmentationDataset] tokenized_dataset must be a string path to a dataset on disk.')
            raise ValueError('[SegmentationDataset] tokenized_dataset must be a string path to a dataset on disk.')
self.logger.info(f'[SegmentationDataset] Loaded dataset: {tokenized_dataset}')
self.logger.info(f'[SegmentationDataset] Loaded dataset length: {self.metadata["samples"]}')
# Percentage:
if not (0.0 < percentage <= 1.0):
            self.logger.error('[SegmentationDataset] Percentage must be in the interval (0.0, 1.0].')
            raise ValueError('[SegmentationDataset] Percentage must be in the interval (0.0, 1.0].')
else:
self.percentage = percentage
# Return type:
if not isinstance(return_type, type):
            self.logger.error('[SegmentationDataset] return_type must be a type.')
            raise ValueError('[SegmentationDataset] return_type must be a type.')
        elif return_type not in [dict, tuple]:
            self.logger.error('[SegmentationDataset] return_type must be either dict or tuple.')
            raise ValueError('[SegmentationDataset] return_type must be either dict or tuple.')
else:
self.return_type = return_type
self.metadata['max_sentences'] = self.metadata['x']['element_shape'][0]
self.metadata['max_tokens'] = self.metadata['x']['element_shape'][1]
# Build maps:
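        # Each array is stored in its own binary file described by info.json;
        # np.memmap indexes samples lazily from disk instead of loading the
        # whole dataset into RAM.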
read_mode = 'r'
self.x_map = np.memmap(
os.path.join(tokenized_dataset, self.metadata['x']['name'] + self.metadata['x']['extension']),
dtype=self.metadata['x']['dtype'],
mode=read_mode,
shape=(self.metadata['x']['samples'], *self.metadata['x']['element_shape'])
)
self.y_map = np.memmap(
os.path.join(tokenized_dataset, self.metadata['y']['name'] + self.metadata['y']['extension']),
dtype=self.metadata['y']['dtype'],
mode=read_mode,
shape=(self.metadata['y']['samples'], *self.metadata['y']['element_shape'])
)
self.x_mask_map = np.memmap(
os.path.join(tokenized_dataset, self.metadata['x_mask']['name'] + self.metadata['x_mask']['extension']),
dtype=self.metadata['x_mask']['dtype'],
mode=read_mode,
shape=(self.metadata['x_mask']['samples'], *self.metadata['x_mask']['element_shape'])
)
self.y_mask_map = np.memmap(
os.path.join(tokenized_dataset, self.metadata['y_mask']['name'] + self.metadata['y_mask']['extension']),
dtype=self.metadata['y_mask']['dtype'],
mode=read_mode,
shape=(self.metadata['y_mask']['samples'], *self.metadata['y_mask']['element_shape'])
)
self.y_cand_map = np.memmap(
os.path.join(tokenized_dataset, self.metadata['y_cand']['name'] + self.metadata['y_cand']['extension']),
dtype=self.metadata['y_cand']['dtype'],
mode=read_mode,
shape=(self.metadata['y_cand']['samples'], *self.metadata['y_cand']['element_shape'])
)
def get_loader(self, batch_size=8, shuffle=True, num_workers=0, **kwargs) -> DataLoader:
"""
Returns a PyTorch DataLoader for this dataset.
Parameters
----------
batch_size : int
Number of samples per batch.
shuffle : bool
Whether to shuffle the dataset.
num_workers : int
Number of worker processes.
**kwargs
Additional arguments for DataLoader.
Returns
-------
        torch.utils.data.DataLoader
Configured DataLoader.
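        Examples
        --------
        A minimal sketch (the batch size is an arbitrary choice):
        >>> loader = dataset.get_loader(batch_size=16, shuffle=True)
        >>> batch = next(iter(loader))  # batched tensors via the default collate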
"""
        # Build the DataLoader; pin_memory=True speeds up host-to-GPU transfers:
return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True,
**kwargs)
def __len__(self) -> int:
"""
Returns the number of samples in the dataset.
Returns
-------
int
Total number of samples.
"""
return int(self.metadata['samples'] * self.percentage)
def __getitem__(self, idx) -> dict | tuple:
"""
        Retrieves a single sample (inputs, labels and masks) from the memory-mapped arrays.
Parameters
----------
idx : int
Index of the sample.
Returns
-------
        dict | tuple
            The sample as a dict with keys 'input', 'input_mask', 'labels',
            'output_mask' and 'candidate_mask', or as a tuple in the order
            (x, y, x_mask, y_mask, y_cand), depending on return_type.
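        Examples
        --------
        A minimal sketch of the dict form (assumes return_type=dict):
        >>> sample = dataset[0]
        >>> sorted(sample.keys())
        ['candidate_mask', 'input', 'input_mask', 'labels', 'output_mask']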
"""
if self.return_type == tuple:
return (
                np.array(self.x_map[idx]),  # copy the slice out of the read-only memmap
np.array(self.y_map[idx]),
np.array(self.x_mask_map[idx]),
np.array(self.y_mask_map[idx]),
np.array(self.y_cand_map[idx]),
)
elif self.return_type == dict:
return {
'input': np.array(self.x_map[idx]),
'input_mask': np.array(self.x_mask_map[idx]),
'labels': np.array(self.y_map[idx]),
'output_mask': np.array(self.y_mask_map[idx]),
'candidate_mask': np.array(self.y_cand_map[idx]),
}
else:
            raise ValueError('[SegmentationDataset] return_type must be either dict or tuple.')
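

# A minimal smoke-test sketch; the dataset path below is hypothetical and must point
# to a directory holding info.json plus the memmapped arrays it describes:
if __name__ == '__main__':
    _dataset = TokenizedSegmentationDataset('data/wiki_tokenized', return_type=dict)
    _loader = _dataset.get_loader(batch_size=2, shuffle=False)
    _batch = next(iter(_loader))
    for _key, _value in _batch.items():
        print(_key, tuple(_value.shape))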
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #