"""Dataset and dataloader setup for the bracket and text datasets."""

from dataclasses import dataclass
from typing import Optional

from omegaconf import DictConfig
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from .parenthesis import BracketDataset
from .text import TEXT_DATASETS, get_text_dataset, setup_tokeniser_from_dataset


@dataclass
class DatasetBundle:
    """Dataloaders plus, for text datasets, the tokeniser used to build them."""

    train_loader: DataLoader
    val_loader: Optional[DataLoader] = None
    tokeniser: Optional[PreTrainedTokenizerBase] = None


def setup_data_and_update_config(config: DictConfig) -> DatasetBundle:
    """
    Get the dataset and update the config with token information for text datasets.
    """
    tokeniser = None
    if config.dataset in TEXT_DATASETS:
        tokeniser = setup_tokeniser_from_dataset(config.dataset)
        train_set = get_text_dataset(
            config.dataset,
            split="train",
            max_length=config.interpolant.max_length,
            filter_max_length=config.training.filter_max_length,
        )
        val_set = get_text_dataset(
            config.dataset,
            split="validation",
            max_length=config.interpolant.max_length,
            filter_max_length=config.training.filter_max_length,
        )
        # Record the vocabulary size and special-token ids on the config so
        # the model can be built to match the tokeniser.
        config.interpolant.tokens = len(tokeniser)
        config.interpolant.pad_token = tokeniser.pad_token_id
        config.interpolant.mask_token = tokeniser.mask_token_id
    elif config.dataset == "bracket":
        # Synthetic bracket sequences: dataset size plus a length -> weight
        # distribution (weights sum to 1).
        train_set = BracketDataset(2048, {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1})
        val_set = BracketDataset(300, {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1})
    else:
        raise ValueError(f"Unknown dataset: {config.dataset!r}")

    # persistent_workers is only valid when num_workers > 0.
    persistent_workers = config.training.cpus > 0
    train_loader = DataLoader(
        train_set,
        batch_size=config.training.per_gpu_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=config.training.cpus,
        pin_memory=True,
        persistent_workers=persistent_workers,
    )
    # Note: drop_last=True on validation too, so trailing samples are skipped.
    val_loader = DataLoader(
        val_set,
        batch_size=config.training.per_gpu_batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=config.training.cpus,
        pin_memory=True,
        persistent_workers=persistent_workers,
    )

    return DatasetBundle(
        train_loader=train_loader, val_loader=val_loader, tokeniser=tokeniser
    )
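

# Illustrative usage sketch (not part of the original module): builds a minimal
# OmegaConf config for the "bracket" dataset and fetches the loaders. The
# config schema is inferred from the attribute accesses above, and the field
# values are placeholder assumptions, not defaults from this repo.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {
            "dataset": "bracket",
            "interpolant": {"max_length": 64},
            "training": {
                "per_gpu_batch_size": 32,
                "cpus": 2,
                "filter_max_length": True,
            },
        }
    )
    bundle = setup_data_and_update_config(config)
    # 2048 samples / batch size 32 -> 64 training batches.
    print(f"train batches: {len(bundle.train_loader)}")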