Bailan-Alex's picture
Upload folder using huggingface_hub
4f2b2f4 verified
from omegaconf import DictConfig
from .parenthesis import BracketDataset
from .text import get_text_dataset, setup_tokeniser_from_dataset, TEXT_DATASETS
from typing import Optional
from transformers import AutoTokenizer
from dataclasses import dataclass
from torch.utils.data import DataLoader
@dataclass
class DatasetBundle:
    """Group the data loaders and the optional tokeniser produced by data setup.

    Only ``train_loader`` is mandatory; the remaining fields default to
    ``None`` so a bundle can be built without validation data or a tokeniser.
    """

    # Loader over the training split.
    train_loader: DataLoader
    # Loader over the validation split, when one is available.
    val_loader: Optional[DataLoader] = None
    # Tokeniser associated with the dataset, when one is available.
    tokeniser: Optional[AutoTokenizer] = None
def _build_loader(dataset, config: DictConfig, *, shuffle: bool) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader using the settings in ``config.training``."""
    num_workers = config.training.cpus
    return DataLoader(
        dataset,
        batch_size=config.training.per_gpu_batch_size,
        shuffle=shuffle,
        # Incomplete final batches are dropped for both splits.
        drop_last=True,
        num_workers=num_workers,
        pin_memory=True,
        # DataLoader raises ValueError if persistent_workers=True while
        # num_workers == 0, so only request it when worker processes exist.
        persistent_workers=num_workers > 0,
    )


def setup_data_and_update_config(config: DictConfig) -> DatasetBundle:
    """Build the train/validation loaders and update ``config`` for text datasets.

    For a dataset listed in ``TEXT_DATASETS`` the tokeniser is created, both
    splits are loaded, and ``config.interpolant`` is updated in place with
    ``tokens`` (``len(tokeniser)``), ``pad_token`` and ``mask_token``.
    For the ``"bracket"`` dataset, two ``BracketDataset`` instances are built
    and the config is left untouched.

    Args:
        config: Run configuration. Reads ``dataset``,
            ``interpolant.max_length``, ``training.filter_max_length``,
            ``training.per_gpu_batch_size`` and ``training.cpus``.

    Returns:
        A ``DatasetBundle`` with the train loader, the validation loader and
        the tokeniser (``None`` unless the dataset is a text dataset).

    Raises:
        ValueError: If ``config.dataset`` is neither ``"bracket"`` nor a
            member of ``TEXT_DATASETS``.
    """
    dataset_name = config.dataset
    is_text = dataset_name in TEXT_DATASETS
    is_bracket = dataset_name == "bracket"

    # Fail early with a clear message; previously an unknown name fell
    # through both branches and crashed later on the unbound `train_set`.
    if not (is_text or is_bracket):
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected 'bracket' or one of "
            f"{list(TEXT_DATASETS)}"
        )

    tokeniser = None

    if is_text:
        tokeniser = setup_tokeniser_from_dataset(dataset_name)
        text_kwargs = {
            "max_length": config.interpolant.max_length,
            "filter_max_length": config.training.filter_max_length,
        }
        train_set = get_text_dataset(dataset_name, split="train", **text_kwargs)
        val_set = get_text_dataset(dataset_name, split="validation", **text_kwargs)

        # Publish tokeniser-derived values so later consumers of the config
        # see the vocabulary size and special token ids.
        config.interpolant.tokens = len(tokeniser)
        config.interpolant.pad_token = tokeniser.pad_token_id
        config.interpolant.mask_token = tokeniser.mask_token_id

    if is_bracket:
        # NOTE(review): the arguments look like (number of samples,
        # {sequence length: sampling weight}) -- confirm against BracketDataset.
        length_weights = {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1}
        # Each dataset gets its own dict copy, as in the original call sites.
        train_set = BracketDataset(2048, dict(length_weights))
        val_set = BracketDataset(300, dict(length_weights))

    return DatasetBundle(
        train_loader=_build_loader(train_set, config, shuffle=True),
        val_loader=_build_loader(val_set, config, shuffle=False),
        tokeniser=tokeniser,
    )