# Plant_Disease_Detection_App / process_dataset.py
# Uploaded by JAMM032 ("Upload github repo files"), commit 97fcc90.
# Original file size: 887 bytes.
from datasets import load_dataset, DatasetDict
# Fixed shuffle seed so the train/test/validation splits are reproducible
# from run to run.
SEED = 21


def split_dataset(dataset, *, train_size=0.70, seed=SEED, stratify_by_column=None):
    """Split *dataset* into train / test / validation subsets.

    ``train_size`` of the rows go to the training split. The remaining rows
    are divided in half, one half for validation and one for test. With the
    default 0.70 this gives a 70/15/15 split.

    Args:
        dataset: the dataset to split; must provide ``train_test_split``
            (e.g. a ``datasets.Dataset``).
        train_size: fraction of rows assigned to the training split.
        seed: shuffle seed used for both splitting steps.
        stratify_by_column: optional column name forwarded to
            ``train_test_split`` so each split keeps the same label
            proportions. It is only passed through when set.

    Returns:
        A ``DatasetDict`` with the keys ``"train"``, ``"test"`` and
        ``"validation"``, in that order.
    """
    split_kwargs = {"shuffle": True, "seed": seed}
    # Only forward the stratification option when requested. This keeps the
    # default call identical to a plain shuffled split.
    if stratify_by_column is not None:
        split_kwargs["stratify_by_column"] = stratify_by_column

    # First carve off the training set...
    train_rest = dataset.train_test_split(train_size=train_size, **split_kwargs)
    # ...then split the held-out remainder evenly: the "train" half becomes
    # validation and the "test" half becomes test.
    valid_test = train_rest["test"].train_test_split(train_size=0.5, **split_kwargs)

    return DatasetDict({
        "train": train_rest["train"],
        "test": valid_test["test"],
        "validation": valid_test["train"],
    })


# NOTE(review): everything below runs at import time (download + write to
# disk); consider an `if __name__ == "__main__":` guard if this module is
# ever imported for `split_dataset`.
print("Downloading dataset from hugging face...")
# Download the dataset from the Hugging Face Hub; only its "train" split is
# used, and it is re-split locally below.
ds = load_dataset("DScomp380/plant_village")["train"]

print("Splitting dataset into train/test/validation...")
ds_dict = split_dataset(ds)

# Relative path: resolved against the current working directory.
ds_dict.save_to_disk("data/processed_plant_village")