|
|
""" |
|
|
From a directory of files with VTB Trees, split into train/dev/test set |
|
|
with a split of 70/15/15 |
|
|
|
|
|
The script requires two arguments |
|
|
1. org_dir: the original directory obtainable from running vtb_convert.py |
|
|
2. split_dir: the directory where the train/dev/test splits will be stored |
|
|
""" |
|
|
|
|
|
import os |
|
|
import argparse |
|
|
import random |
|
|
|
|
|
|
|
|
def create_shuffle_list(org_dir):
    """
    Build the randomized order in which the files of org_dir are visited

    The directory listing is sorted before it is shuffled, so the result
    depends only on the state of the `random` module and not on the order
    in which the OS happens to return the directory entries.

    :param org_dir: original directory storing the files that store the trees
    :return: list of file names randomly shuffled
    """
    shuffled = os.listdir(org_dir)
    shuffled.sort()
    # in-place shuffle using the module level random generator
    random.shuffle(shuffled)
    return shuffled
|
|
|
|
|
|
|
|
def create_paths(split_dir, short_name):
    """
    Compute the output file locations for the train/dev/test splits

    A non-empty short_name is used as a prefix of each file name.  An
    underscore is appended to the prefix unless it already ends with one.

    :param split_dir: directory that stores the splits
    :param short_name: optional file name prefix; None or "" means no prefix
    :return: train path, dev path, test path
    """
    prefix = short_name or ""
    if prefix and not prefix.endswith("_"):
        prefix = prefix + "_"

    templates = ('%strain.mrg', '%sdev.mrg', '%stest.mrg')
    return tuple(os.path.join(split_dir, template % prefix) for template in templates)
|
|
|
|
|
|
|
|
def get_num_samples(org_dir, file_names):
    """
    Function for obtaining the number of samples

    Each non-blank line of a .mrg file counts as one sample.  Blank and
    whitespace-only lines are not counted: split_files discards them when
    it loads the trees, so counting them would promise more trees than can
    actually be written.

    :param org_dir: original directory storing the tree files
    :param file_names: list of file names in the directory
    :return: number of samples
    """
    count = 0

    for filename in file_names:
        # only the .mrg files hold trees
        if not filename.endswith('.mrg'):
            continue

        file_path = os.path.join(org_dir, filename)
        with open(file_path, 'r', encoding='utf-8') as reader:
            # iterate the file lazily instead of loading every line at once
            count += sum(1 for line in reader if line.strip())

    return count
|
|
|
|
|
def split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0.15, rotation=None):
    """
    Split the trees found in the .mrg files of org_dir into train/dev/test files

    The files are visited in the order given by create_shuffle_list.  Every
    non-blank line of each .mrg file is treated as one tree.  The first
    train_size fraction of the trees goes to train, the next dev_size
    fraction to dev, and the remainder to test.

    The split boundaries are computed from the number of trees actually
    loaded.  Blank lines in the input therefore cannot make the boundaries
    overshoot the available trees.

    :param org_dir: original directory storing the tree files
    :param split_dir: directory that stores the splits, created if missing
    :param short_name: optional prefix for the output file names
    :param train_size: fraction of the trees written to the train file
    :param dev_size: fraction of the trees written to the dev file
    :param rotation: optional pair (i, n).  When i > 0 the trees in front of
      the dev boundary are rotated so that the tree at position
      len(trees) * i // n comes first
    :raises RuntimeError: if the split boundaries ask for more trees than exist
    """
    os.makedirs(split_dir, exist_ok=True)

    if train_size + dev_size >= 1.0:
        print("Not making a test slice with the given ratios: train {} dev {}".format(train_size, dev_size))

    # Create a random shuffle list of the file names in the original directory
    file_names = create_shuffle_list(org_dir)

    # Create train_path, dev_path, test_path
    train_path, dev_path, test_path = create_paths(split_dir, short_name)

    # Load the trees first, so the sample count below is exactly the number
    # of trees that can be written
    trees = []
    for filename in file_names:
        if not filename.endswith('.mrg'):
            continue
        with open(os.path.join(org_dir, filename), encoding='utf-8') as reader:
            stripped_lines = (line.strip() for line in reader)
            trees.extend(tree for tree in stripped_lines if tree)

    num_samples = len(trees)
    print("Found {} total lines in {}".format(num_samples, org_dir))

    # Set up the cumulative stopping point of each output file
    stop_train = int(num_samples * train_size)
    if train_size + dev_size >= 1.0:
        # everything after train goes to dev, no test file is written
        stop_dev = num_samples
        output_limits = (stop_train, stop_dev)
        output_names = (train_path, dev_path)
        print("Splitting {} train, {} dev".format(stop_train, stop_dev - stop_train))
    elif train_size + dev_size > 0.0:
        stop_dev = int(num_samples * (train_size + dev_size))
        output_limits = (stop_train, stop_dev, num_samples)
        output_names = (train_path, dev_path, test_path)
        print("Splitting {} train, {} dev, {} test".format(stop_train, stop_dev - stop_train, num_samples - stop_dev))
    else:
        stop_dev = 0
        output_limits = (num_samples,)
        output_names = (test_path,)
        print("Copying all {} lines to test".format(num_samples))

    if rotation is not None and rotation[0] > 0:
        rotation_start = len(trees) * rotation[0] // rotation[1]
        rotation_end = stop_dev
        # Only the trees in front of rotation_end are rotated; when there is
        # no test slice, trees[rotation_end:] is empty.
        # NOTE(review): if rotation_start > rotation_end the last two slices
        # overlap and some trees appear twice - confirm callers never pass
        # such a rotation
        trees = trees[rotation_start:rotation_end] + trees[:rotation_start] + trees[rotation_end:]

    # Write trees to the output files; count is cumulative across the files,
    # matching the cumulative limits computed above
    count = 0
    tree_iter = iter(trees)
    for write_path, count_limit in zip(output_names, output_limits):
        with open(write_path, 'w', encoding='utf-8') as writer:
            while count < count_limit:
                next_tree = next(tree_iter, None)
                if next_tree is None:
                    raise RuntimeError("Ran out of trees before reading all of the expected trees")

                writer.write(next_tree)
                writer.write("\n")
                count += 1
|
|
|
|
|
def main():
    """
    Command line entry point of the script

    Reads the original directory and the split directory from the command
    line, fixes the random seed so the shuffle is the same on every run,
    and writes the train/dev/test files using the default split_files
    ratios (70/15/15)
    """
    parser = argparse.ArgumentParser(
        description="Script that splits a list of files of vtb trees into train/dev/test sets",
    )
    parser.add_argument(
        'org_dir',
        help='The location of the original directory storing correctly formatted vtb trees ',
    )
    parser.add_argument(
        'split_dir',
        help='The location of new directory storing the train/dev/test set',
    )
    args = parser.parse_args()

    # seed before splitting so that the file order is reproducible
    random.seed(1234)

    split_files(args.org_dir, args.split_dir)


if __name__ == '__main__':
    main()
|
|
|