import torch
from torch.utils.data import Dataset
from datasets import load_dataset


class ChessDataset(Dataset):
    """Wraps a Hugging Face dataset of game transcripts for causal LM training."""

    def __init__(self, data, tokenizer, block_size):
        self.data = data
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["text"]
        # truncation=True is required; max_length alone does not clip long games.
        tokens = self.tokenizer(text, truncation=True, max_length=self.block_size)["input_ids"]
        input_ids = torch.tensor(tokens, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}


class ChessDataCollator:
    """Pads variable-length examples into a batch and builds causal-LM labels."""

    def __init__(self, tokenizer=None, max_length=None):
        # Pad with the tokenizer's pad token when one is available. Padded label
        # positions are masked to -100 in __call__, so the fill value only needs
        # to be a valid token id.
        self.pad_token_id = 0
        if tokenizer is not None and tokenizer.pad_token_id is not None:
            self.pad_token_id = tokenizer.pad_token_id
        self.max_length = max_length

    def __call__(self, features):
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [f["input_ids"] for f in features],
            batch_first=True,
            padding_value=self.pad_token_id,
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [f["attention_mask"] for f in features],
            batch_first=True,
            padding_value=0,
        )
        # -100 tells the cross-entropy loss to ignore padded positions.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


def create_train_val_datasets(dataset_name, tokenizer, val_samples=1000, **kwargs):
    """Loads a dataset, caps its size, and splits it into train/val ChessDatasets."""
    max_train = kwargs.get("train_samples", kwargs.get("max_train_samples", 50000))
    block_size = kwargs.get("n_ctx", kwargs.get("max_length", 256))
    ds = load_dataset(dataset_name, split="train")
    if len(ds) > max_train + val_samples:
        ds = ds.select(range(max_train + val_samples))
    split = ds.train_test_split(test_size=val_samples)
    return (
        ChessDataset(split["train"], tokenizer, block_size),
        ChessDataset(split["test"], tokenizer, block_size),
    )
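

# Minimal usage sketch. Assumptions not fixed by this module: the dataset name
# below is a hypothetical placeholder (any dataset with a "text" column works),
# and the GPT-2 tokenizer is illustrative (any causal-LM tokenizer works once a
# pad token is set).
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so the collator can pad.
    tokenizer.pad_token = tokenizer.eos_token

    train_ds, val_ds = create_train_val_datasets(
        "user/chess-games-dataset",  # hypothetical dataset name for illustration
        tokenizer,
        val_samples=1000,
        train_samples=50000,
        n_ctx=256,
    )
    loader = DataLoader(
        train_ds,
        batch_size=8,
        shuffle=True,
        collate_fn=ChessDataCollator(tokenizer=tokenizer),
    )
    # One batch: input_ids, attention_mask, and labels share a padded shape.
    batch = next(iter(loader))
    print({k: v.shape for k, v in batch.items()})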