Spaces:
No application file
No application file
| from omegaconf import DictConfig | |
| from .parenthesis import BracketDataset | |
| from .text import get_text_dataset, setup_tokeniser_from_dataset, TEXT_DATASETS | |
| from typing import Optional | |
| from transformers import AutoTokenizer | |
| from dataclasses import dataclass | |
| from torch.utils.data import DataLoader | |
@dataclass
class DatasetBundle:
    """Bundle of dataloaders (and optional tokeniser) produced by data setup.

    Note: the @dataclass decorator is required — setup_data_and_update_config
    constructs this with keyword arguments, which needs the generated __init__.
    """

    # Loader yielding training batches.
    train_loader: DataLoader
    # Loader for the validation split; None when no validation set is built.
    val_loader: Optional[DataLoader] = None
    # Tokeniser for text datasets; None for non-text datasets (e.g. bracket).
    tokeniser: Optional[AutoTokenizer] = None
def setup_data_and_update_config(config: DictConfig) -> DatasetBundle:
    """
    Build train/validation dataloaders for the configured dataset.

    For text datasets a tokeniser is created and ``config.interpolant`` is
    updated IN PLACE with the vocabulary size and special-token ids, so the
    model config stays consistent with the data pipeline.

    Args:
        config: OmegaConf config with ``dataset``, ``training``
            (``per_gpu_batch_size``, ``cpus``, ``filter_max_length``) and
            ``interpolant`` (``max_length``) sections.

    Returns:
        DatasetBundle holding the train loader, validation loader and (for
        text datasets) the tokeniser.

    Raises:
        ValueError: if ``config.dataset`` names no known dataset.
    """
    tokeniser = None
    train_set = None
    val_set = None
    if config.dataset in TEXT_DATASETS:
        tokeniser = setup_tokeniser_from_dataset(config.dataset)
        train_set = get_text_dataset(
            config.dataset,
            split="train",
            max_length=config.interpolant.max_length,
            filter_max_length=config.training.filter_max_length,
        )
        val_set = get_text_dataset(
            config.dataset,
            split="validation",
            max_length=config.interpolant.max_length,
            filter_max_length=config.training.filter_max_length,
        )
        # Propagate tokeniser-derived vocabulary info into the config.
        config.interpolant.tokens = len(tokeniser)
        config.interpolant.pad_token = tokeniser.pad_token_id
        config.interpolant.mask_token = tokeniser.mask_token_id
    if config.dataset == "bracket":
        # Synthetic bracket-matching dataset; presumably the mapping is
        # sequence length -> sampling proportion — confirm in BracketDataset.
        length_weights = {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1}
        train_set = BracketDataset(2048, length_weights)
        val_set = BracketDataset(300, length_weights)
    if train_set is None:
        # Previously an unknown dataset fell through to an opaque NameError
        # on train_set; fail fast with a clear message instead.
        raise ValueError(f"Unknown dataset: {config.dataset!r}")

    # NOTE(review): persistent_workers=True requires num_workers > 0, so
    # config.training.cpus must be >= 1 — confirm against the configs used.
    train_loader = DataLoader(
        train_set,
        batch_size=config.training.per_gpu_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=config.training.cpus,
        pin_memory=True,
        persistent_workers=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=config.training.per_gpu_batch_size,
        shuffle=False,
        drop_last=True,  # keeps validation batch shapes uniform
        num_workers=config.training.cpus,
        pin_memory=True,
        persistent_workers=True,
    )
    return DatasetBundle(
        train_loader=train_loader, val_loader=val_loader, tokeniser=tokeniser
    )