import torch
from torch.utils.data import Dataset
from datasets import load_dataset


class ChessDataset(Dataset):
    """Wraps a Hugging Face dataset of chess game texts for causal LM training."""

    def __init__(self, data, tokenizer, block_size):
        self.data = data
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["text"]
        # Truncate to block_size; padding is deferred to the collator.
        tokens = self.tokenizer(text, truncation=True, max_length=self.block_size)["input_ids"]
        input_ids = torch.tensor(tokens, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}


class ChessDataCollator:
    """Pads variable-length examples into a batch and builds causal LM labels."""

    def __init__(self, tokenizer=None, max_length=None):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, features):
        # Pad with the tokenizer's pad id when one is available, otherwise 0.
        pad_id = 0
        if self.tokenizer is not None and self.tokenizer.pad_token_id is not None:
            pad_id = self.tokenizer.pad_token_id
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [f["input_ids"] for f in features], batch_first=True, padding_value=pad_id)
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [f["attention_mask"] for f in features], batch_first=True, padding_value=0)
        # Labels mirror the inputs; padded positions are masked out of the loss with -100.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


def create_train_val_datasets(dataset_name, tokenizer, val_samples=1000, **kwargs):
    """Loads the dataset, caps its size, and splits it into train/validation ChessDatasets."""
    max_train = kwargs.get("train_samples", kwargs.get("max_train_samples", 50000))
    block_size = kwargs.get("n_ctx", kwargs.get("max_length", 256))
    ds = load_dataset(dataset_name, split="train")
    if len(ds) > max_train + val_samples:
        ds = ds.select(range(max_train + val_samples))
    split = ds.train_test_split(test_size=val_samples)
    return (ChessDataset(split["train"], tokenizer, block_size),
            ChessDataset(split["test"], tokenizer, block_size))
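

# Minimal usage sketch (assumes a GPT-2 style tokenizer; the dataset name
# "user/chess-games" below is a placeholder, not the project's real dataset).
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

    train_ds, val_ds = create_train_val_datasets(
        "user/chess-games", tokenizer, val_samples=1000, n_ctx=256)
    loader = DataLoader(train_ds, batch_size=8,
                        collate_fn=ChessDataCollator(tokenizer, max_length=256))
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})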