"""Download the PlantVillage dataset and save a train/test/validation split.

The dataset is fetched from the Hugging Face Hub, shuffled with a fixed seed,
split 70% / 15% / 15% into train / test / validation, and written to disk as a
single ``DatasetDict``.
"""

from datasets import DatasetDict, load_dataset

# Fixed seed so both shuffles (and therefore the splits) are reproducible.
SEED = 21

# Hub repository to download and the directory the processed splits go to.
DATASET_NAME = "DScomp380/plant_village"
OUTPUT_DIR = "data/processed_plant_village"

# Fraction of the full dataset kept for training.
TRAIN_FRACTION = 0.70
# Fraction of the held-out remainder that becomes the validation set; the
# rest of the remainder becomes the test set (0.5 -> 15% / 15% overall).
VALIDATION_FRACTION_OF_HOLDOUT = 0.5

print("Downloading dataset from hugging face...")

# Download the dataset from the Hugging Face Hub.
ds = load_dataset(DATASET_NAME)

# load_dataset returns a DatasetDict here; all examples are taken from its
# 'train' entry (presumably the only split the repository publishes).
ds = ds['train']

print("Splitting dataset into train/test/validation...")

# First carve out the training set; temp['test'] holds the remaining 30%.
temp = ds.train_test_split(
    train_size=TRAIN_FRACTION,
    shuffle=True,
    seed=SEED,
)

# Then split the remainder in half: its 'train' part is used for validation
# and its 'test' part for testing.
test_valid_ds = temp['test'].train_test_split(
    train_size=VALIDATION_FRACTION_OF_HOLDOUT,
    shuffle=True,
    seed=SEED,
)

# Assign the sub-datasets.
train_ds = temp['train']
validation_ds = test_valid_ds['train']
test_ds = test_valid_ds['test']

# Combine into one DatasetDict and persist it.
ds_dict = DatasetDict({
    "train": train_ds,
    "test": test_ds,
    "validation": validation_ds,
})

ds_dict.save_to_disk(OUTPUT_DIR)