| from datasets import Dataset, DatasetDict, load_from_disk | |
# Fraction held out for validation. NOTE: applied to the data remaining
# AFTER the test split is removed, so it is 15% of the remaining 85%.
VALIDATION_SPLIT = 0.15
# Fraction of the full dataset held out as the final test set.
TEST_SPLIT = 0.15
def get_full_data() -> Dataset:
    """Load the preprocessed dataset from disk and return it shuffled.

    Returns:
        The full ``Dataset``, shuffled with a fixed seed so that every
        caller sees the same row order (reproducible splits downstream).
    """
    full_dataset = load_from_disk("preprocessed_dataset")
    if isinstance(full_dataset, DatasetDict):
        # A DatasetDict is keyed by split *names* (strings), so the previous
        # `full_dataset[0]` raised KeyError. Fall back to the first split's
        # Dataset and let it go through the common shuffle path below.
        print("Warning, found a 'DatasetDict' while expected to find a Dataset!")
        full_dataset = next(iter(full_dataset.values()))
    return full_dataset.shuffle(seed=42)  # Set for reproducibility
def make_train_data():
    """Split the full dataset into train and validation subsets.

    The test portion (``TEST_SPLIT``) is removed first; ``VALIDATION_SPLIT``
    is then taken from the remainder, so validation is 15% of the remaining
    85% of the data. Seeds match make_test_data, so the discarded test
    portion here is exactly what make_test_data returns.

    Returns:
        A ``(train_data, val_data)`` tuple of Datasets.
    """
    data = get_full_data()
    # Index splits by their documented keys instead of unpacking `.values()`,
    # which silently depends on dict insertion order staying ('train', 'test').
    outer = data.train_test_split(test_size=TEST_SPLIT, seed=42)
    # outer["test"] is the held-out test data, not used during training.
    inner = outer["train"].train_test_split(test_size=VALIDATION_SPLIT, seed=42)
    return inner["train"], inner["test"]
def make_test_data():
    """Return the held-out test subset of the full dataset.

    Uses the same seed and ``test_size`` as make_train_data, so the split
    boundary is identical and the returned rows are exactly those excluded
    from training.

    Returns:
        The test ``Dataset``.
    """
    data = get_full_data()
    # Index by the documented "test" key rather than unpacking `.values()`,
    # which depends on dict insertion order.
    return data.train_test_split(test_size=TEST_SPLIT, seed=42)["test"]