File size: 6,058 Bytes
19b8775 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
"""
Split a directory of files containing VTB trees into train/dev/test sets,
using a 70/15/15 split.
The script requires two arguments:
1. org_dir: the original directory obtainable from running vtb_convert.py
2. split_dir: the directory where the train/dev/test splits will be stored
"""
import os
import argparse
import random
def create_shuffle_list(org_dir):
    """
    Build the randomized order in which the files of a directory are visited.

    The directory listing is sorted before it is shuffled, so the result
    depends only on the names found in the directory and on the current state
    of the ``random`` module, not on the order the OS lists the entries in.

    :param org_dir: original directory storing the files that store the trees
    :return: list of the entry names found in org_dir, randomly shuffled
    """
    shuffled_names = os.listdir(org_dir)
    shuffled_names.sort()
    random.shuffle(shuffled_names)
    return shuffled_names
def create_paths(split_dir, short_name):
    """
    Build the output file paths for the train/dev/test splits.

    :param split_dir: directory that stores the splits
    :param short_name: optional prefix for the output file names.  A false
        value (None, "") means no prefix; otherwise a trailing "_" is appended
        unless the prefix already ends with one
    :return: train path, dev path, test path
    """
    if short_name and not short_name.endswith("_"):
        prefix = short_name + "_"
    else:
        # either no prefix was requested or it already ends with "_"
        prefix = short_name or ""

    train_path, dev_path, test_path = (
        os.path.join(split_dir, template % prefix)
        for template in ('%strain.mrg', '%sdev.mrg', '%stest.mrg')
    )
    return train_path, dev_path, test_path
def get_num_samples(org_dir, file_names):
    """
    Count the samples (trees) stored in the .mrg files of a directory.

    Every non-blank line of a .mrg file counts as one sample.  Blank and
    whitespace-only lines are not counted: split_files drops such lines before
    writing, so counting them would make the split limits larger than the
    number of trees that can actually be written.

    :param org_dir: original directory storing the tree files
    :param file_names: list of file names in the directory; names that do not
        end with '.mrg' are ignored
    :return: number of samples
    """
    count = 0
    for filename in file_names:
        # Skip files that are not .mrg
        if not filename.endswith('.mrg'):
            continue
        file_path = os.path.join(org_dir, filename)
        with open(file_path, 'r', encoding='utf-8') as reader:
            count += sum(1 for line in reader if line.strip())
    return count
def _read_trees(org_dir, file_names):
    """
    Read the trees from the .mrg files, visiting the files in the given order.

    :param org_dir: directory storing the tree files
    :param file_names: names of the files to visit; non-.mrg names are skipped
    :return: list of the stripped, non-blank lines of the visited files
    """
    trees = []
    for filename in file_names:
        if not filename.endswith('.mrg'):
            continue
        with open(os.path.join(org_dir, filename), encoding='utf-8') as reader:
            for line in reader:
                line = line.strip()
                if line:
                    trees.append(line)
    return trees


def split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0.15, rotation=None):
    """
    Split the trees found in org_dir into train/dev/test files in split_dir.

    The .mrg files are visited in shuffled order, their non-blank lines are
    concatenated, and consecutive slices of that list are written to the
    output files.  Three cases are distinguished:

    - train_size + dev_size >= 1.0: only train and dev files are written
    - train_size + dev_size > 0.0: train, dev and test files are written
    - otherwise: every tree is written to the test file

    :param org_dir: original directory storing the tree files
    :param split_dir: directory where the splits are written (created if needed)
    :param short_name: optional prefix for the output file names
    :param train_size: fraction of the trees that goes to the train file
    :param dev_size: fraction of the trees that goes to the dev file
    :param rotation: optional pair (numerator, denominator).  When the numerator
        is positive, the first len(trees) * numerator // denominator trees are
        moved behind the remainder of the train + dev section
    :raises RuntimeError: if the trees run out before the limits are reached
    """
    os.makedirs(split_dir, exist_ok=True)

    if train_size + dev_size >= 1.0:
        print("Not making a test slice with the given ratios: train {} dev {}".format(train_size, dev_size))

    # Create a random shuffle list of the file names in the original directory
    file_names = create_shuffle_list(org_dir)

    # Create train_path, dev_path, test_path
    train_path, dev_path, test_path = create_paths(split_dir, short_name)

    # Read every file once and derive the sample count from the trees that
    # will actually be written.  Blank lines are dropped while reading, so a
    # raw line count could exceed the number of available trees.
    # TODO: if we ever wanted to split files with <s> </s> in them,
    # this particular code would need some updating to pay attention to the ids
    trees = _read_trees(org_dir, file_names)
    num_samples = len(trees)
    print("Found {} total lines in {}".format(num_samples, org_dir))

    # Set up the cumulative limits for each train/dev/test set
    stop_train = int(num_samples * train_size)
    if train_size + dev_size >= 1.0:
        stop_dev = num_samples
        output_limits = (stop_train, stop_dev)
        output_names = (train_path, dev_path)
        print("Splitting {} train, {} dev".format(stop_train, stop_dev - stop_train))
    elif train_size + dev_size > 0.0:
        stop_dev = int(num_samples * (train_size + dev_size))
        output_limits = (stop_train, stop_dev, num_samples)
        output_names = (train_path, dev_path, test_path)
        print("Splitting {} train, {} dev, {} test".format(stop_train, stop_dev - stop_train, num_samples - stop_dev))
    else:
        stop_dev = 0
        output_limits = (num_samples,)
        output_names = (test_path,)
        print("Copying all {} lines to test".format(num_samples))

    # rotate the train & dev sections, leave the test section the same
    if rotation is not None and rotation[0] > 0:
        rotation_start = len(trees) * rotation[0] // rotation[1]
        rotation_end = stop_dev
        # if there are no test trees, rotation_end: will be empty anyway
        # NOTE(review): when rotation_start > rotation_end the slices below
        # overlap and some trees appear twice - confirm callers never pass
        # such a rotation
        trees = trees[rotation_start:rotation_end] + trees[:rotation_start] + trees[rotation_end:]

    # Count how much stuff we've written.
    # We switch to the next output file once the current limit is reached;
    # the limits are cumulative, so the counter is shared across all files.
    count = 0
    tree_iter = iter(trees)
    for write_path, count_limit in zip(output_names, output_limits):
        with open(write_path, 'w', encoding='utf-8') as writer:
            while count < count_limit:
                next_tree = next(tree_iter, None)
                if next_tree is None:
                    raise RuntimeError("Ran out of trees before reading all of the expected trees")
                writer.write(next_tree)
                writer.write("\n")
                count += 1
def main():
    """
    Command line entry point of the script

    Reads the two positional arguments (org_dir, split_dir), fixes the random
    seed, and hands both directories to split_files, which uses its default
    70/15/15 train/dev/test split.
    """
    parser = argparse.ArgumentParser(
        description="Script that splits a list of files of vtb trees into train/dev/test sets",
    )
    positional_arguments = (
        ('org_dir', 'The location of the original directory storing correctly formatted vtb trees '),
        ('split_dir', 'The location of new directory storing the train/dev/test set'),
    )
    for argument_name, argument_help in positional_arguments:
        parser.add_argument(argument_name, help=argument_help)
    parsed = parser.parse_args()

    # fixed seed so that the shuffle order is the same on every run
    random.seed(1234)
    split_files(parsed.org_dir, parsed.split_dir)
# Run the split only when this file is executed as a script, not on import
if __name__ == '__main__':
    main()
|