from datasets import Dataset, DatasetDict, load_from_disk

# Fractions of the data held out for validation and final testing.
VALIDATION_SPLIT = 0.15
TEST_SPLIT = 0.15


def get_full_data() -> Dataset:
    """Load the preprocessed dataset from disk and return it shuffled.

    If the saved object is a ``DatasetDict`` (splits keyed by name) rather
    than a plain ``Dataset``, fall back to its first split.

    Returns:
        The full ``Dataset``, shuffled with a fixed seed for reproducibility.
    """
    full_dataset = load_from_disk("preprocessed_dataset")
    if isinstance(full_dataset, DatasetDict):
        print("Warning, found a 'DatasetDict' while expected to find a Dataset!")
        # BUG FIX: a DatasetDict is keyed by split *names* (strings), so the
        # original `full_dataset[0]` would raise a KeyError. Take the first
        # split by value instead, and fall through so it gets shuffled too.
        full_dataset = next(iter(full_dataset.values()))
    return full_dataset.shuffle(seed=42)  # Seed fixed for reproducibility


def make_train_data():
    """Split the full dataset into train and validation ``Dataset``s.

    The test portion is carved off first (with the same seed as
    ``make_test_data``, so test rows never leak into training), then the
    remainder is split into train/validation.

    Returns:
        A ``(train_data, val_data)`` tuple of ``Dataset`` objects.
    """
    data = get_full_data()
    # Index the returned DatasetDict by name instead of unpacking .values():
    # explicit keys don't depend on dict ordering and read unambiguously.
    first_split = data.train_test_split(test_size=TEST_SPLIT, seed=42)
    train_val_data = first_split["train"]  # first_split["test"] is unused here
    second_split = train_val_data.train_test_split(
        test_size=VALIDATION_SPLIT, seed=42
    )
    return second_split["train"], second_split["test"]


def make_test_data():
    """Return the held-out test ``Dataset``.

    Uses the same seed and split size as ``make_train_data``, so the rows
    returned here are exactly those excluded from training.
    """
    data = get_full_data()
    return data.train_test_split(test_size=TEST_SPLIT, seed=42)["test"]