from datasets import Dataset, DatasetDict, load_from_disk
# Fraction of the *remaining* (non-test) data carved out for validation —
# applied after the test split, so the validation set is 0.15 * 0.85 of the
# full dataset, not 0.15 of it.
VALIDATION_SPLIT = 0.15
# Fraction of the full dataset held out for final testing.
TEST_SPLIT = 0.15
def get_full_data() -> Dataset:
    """Load the preprocessed dataset from disk, shuffled reproducibly.

    Returns:
        The dataset stored at ``"preprocessed_dataset"``, shuffled with a
        fixed seed so every caller sees the same ordering.
    """
    full_dataset = load_from_disk("preprocessed_dataset")
    if isinstance(full_dataset, DatasetDict):
        print("Warning, found a 'DatasetDict' while expected to find a Dataset!")
        # A DatasetDict is keyed by split *names* (strings), so the previous
        # integer indexing (`full_dataset[0]`) raised a KeyError. Take the
        # first split instead, and fall through so it is shuffled like the
        # normal path (the old branch returned it unshuffled).
        full_dataset = next(iter(full_dataset.values()))
    return full_dataset.shuffle(seed=42)  # Set for reproducibility
def make_train_data():
    """Return the (train, validation) datasets; the test split is held out.

    The test portion is removed first with TEST_SPLIT, then VALIDATION_SPLIT
    of the remainder becomes the validation set. Both splits use seed 42.
    """
    dataset = get_full_data()
    # First split: set aside the test data, which is only used after training.
    holdout = dataset.train_test_split(test_size=TEST_SPLIT, seed=42)
    remainder = holdout["train"]
    # Second split: carve the validation set out of what is left.
    train_val = remainder.train_test_split(test_size=VALIDATION_SPLIT, seed=42)
    return train_val["train"], train_val["test"]
def make_test_data():
    """Return only the held-out test split (same seed/size as make_train_data)."""
    split = get_full_data().train_test_split(test_size=TEST_SPLIT, seed=42)
    return split["test"]
|