import torch
from torch.utils.data import Dataset
from datasets import load_dataset

class ChessDataset(Dataset):
    """Wraps a Hugging Face dataset split and tokenizes one example's text per item."""
    def __init__(self, data, tokenizer, block_size):
        self.data = data
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["text"]
        # truncation=True is required for max_length to actually cap the sequence length.
        tokens = self.tokenizer(text, truncation=True, max_length=self.block_size)["input_ids"]
        input_ids = torch.tensor(tokens, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

class ChessDataCollator:
    """Pads variable-length examples into a batch for causal-LM training."""
    def __init__(self, tokenizer=None, max_length=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Use the tokenizer's pad token id when one is available, else fall back to 0.
        self.pad_id = getattr(tokenizer, "pad_token_id", None) or 0

    def __call__(self, features):
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [f["input_ids"] for f in features], batch_first=True, padding_value=self.pad_id)
        mask = torch.nn.utils.rnn.pad_sequence(
            [f["attention_mask"] for f in features], batch_first=True, padding_value=0)
        labels = input_ids.clone()
        labels[mask == 0] = -100  # padded positions are ignored by the loss
        return {"input_ids": input_ids, "attention_mask": mask, "labels": labels}

def create_train_val_datasets(dataset_name, tokenizer, val_samples=1000, **kwargs):
    """Loads the named dataset's train split and returns (train, validation) ChessDatasets."""
    max_train = kwargs.get("train_samples", kwargs.get("max_train_samples", 50000))
    block_size = kwargs.get("n_ctx", kwargs.get("max_length", 256))
    ds = load_dataset(dataset_name, split="train")
    # Cap the dataset so train + validation never exceeds the requested budget.
    if len(ds) > max_train + val_samples:
        ds = ds.select(range(max_train + val_samples))
    split = ds.train_test_split(test_size=val_samples)
    return ChessDataset(split["train"], tokenizer, block_size), ChessDataset(split["test"], tokenizer, block_size)
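
# Example usage (a minimal sketch, not part of the module's public API: the dataset id
# below is hypothetical, and the GPT-2 tokenizer is only an assumption about how this
# module might be driven):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
#     train_ds, val_ds = create_train_val_datasets(
#         "user/chess-games",   # hypothetical dataset id with a "text" column
#         tokenizer,
#         val_samples=1000,
#         train_samples=50000,
#         n_ctx=256,
#     )
#     collator = ChessDataCollator(tokenizer)
#     batch = collator([train_ds[0], train_ds[1]])  # dict of padded input_ids/attention_mask/labels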