# multi-classifier / dataset_splitter.py
# Residue from the hosting site's file viewer (not part of the program):
#   veryfansome's picture
#   feat: UD is back, LlaMA play
#   0cdb887
#   raw
#   history blame
#   1.52 kB
from datasets import DatasetDict, load_from_disk
import argparse
from openai_dataset_maker import features
def has_all_valid_labels(exp):
    """Return True when every label in *exp* is known to ``features``.

    *exp* is iterated as a mapping of column name -> iterable of labels.
    The "text" and "tokens" columns are not checked. For every other
    column, each label must be contained in ``features[column]``; the
    first label that is not makes the whole example invalid.

    ``features`` comes from ``openai_dataset_maker``; presumably it maps
    each label column to its allowed label values -- confirm against that
    module.
    """
    unchecked_columns = ("text", "tokens")
    # The lookup ``features[column]`` is deliberately done per label (not
    # hoisted) so a column with no labels never touches ``features``.
    return all(
        tag in features[column]
        for column, tags in exp.items()
        if column not in unchecked_columns
        for tag in tags
    )
def is_evenly_shaped(exp):
    """Return True when all non-"text" columns of *exp* share one length.

    *exp* is iterated as a mapping of column name -> sized value. The
    "text" column is ignored; the lengths of all remaining columns are
    collected and the example passes only if exactly one distinct length
    was seen. An example with no columns besides "text" therefore fails.
    """
    distinct_lengths = {
        len(values) for name, values in exp.items() if name != "text"
    }
    return len(distinct_lengths) == 1
if __name__ == '__main__':
    # Script entry point: load a dataset from disk, drop malformed examples,
    # split what is left into train / validation / test, and save the result.
    arg_parser = argparse.ArgumentParser(
        description="Filter a saved dataset and split it into train/validation/test sets.")
    arg_parser.add_argument("data_path", help="Load dataset from specified path.",
                            action="store")
    arg_parser.add_argument("--save-path", help="Save final dataset to specified path.",
                            action="store", default="./training_data")
    args = arg_parser.parse_args()

    # NOTE(review): train_test_split is called directly on the loaded object,
    # so data_path is assumed to hold a single Dataset rather than a
    # DatasetDict -- confirm.
    loaded_dataset = load_from_disk(args.data_path)

    # Keep only examples whose label columns all have the same length and
    # whose labels are all known.
    loaded_dataset = loaded_dataset.filter(is_evenly_shaped)
    loaded_dataset = loaded_dataset.filter(has_all_valid_labels)

    # Use one fixed seed for BOTH splits so the train/validation/test
    # partition is reproducible from run to run. (Previously only the first
    # split was seeded, so train vs. validation changed on every run.)
    split_seed = 42

    # First split: hold out 9% of the filtered data as the test set.
    first_split = loaded_dataset.train_test_split(
        shuffle=True, seed=split_seed, test_size=0.09)
    # Second split: hold out 10% of the remaining data as the validation set.
    second_split = first_split["train"].train_test_split(
        shuffle=True, seed=split_seed, test_size=0.1)

    new_ds = DatasetDict()
    new_ds["test"] = first_split["test"]
    new_ds["train"] = second_split["train"]
    new_ds["validation"] = second_split["test"]
    new_ds.save_to_disk(args.save_path)