File size: 6,058 Bytes
19b8775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
From a directory of files containing VTB trees, split the trees into
train/dev/test sets with a 70/15/15 split

The script requires two arguments
1. org_dir: the original directory obtainable from running vtb_convert.py
2. split_dir: the directory where the train/dev/test splits will be stored
"""

import os
import argparse
import random


def create_shuffle_list(org_dir):
    """
    Build the randomized order in which the tree files will be visited

    The directory listing is sorted before shuffling, so the resulting order
    depends only on the state of the `random` module and the set of names,
    not on the order in which the file system happens to list them.

    :param org_dir: original directory storing the files that store the trees
    :return: list of the names found in org_dir, randomly shuffled
    """
    shuffled_names = os.listdir(org_dir)
    shuffled_names.sort()
    random.shuffle(shuffled_names)
    return shuffled_names


def create_paths(split_dir, short_name):
    """
    Build the output file paths for the train/dev/test splits

    :param split_dir: directory that stores the splits
    :param short_name: optional prefix for the file names.  A trailing "_"
        is appended if the prefix does not already end with one.  A falsy
        value (None or "") means no prefix.
    :return: train path, dev path, test path
    """
    prefix = short_name if short_name else ""
    if prefix and not prefix.endswith("_"):
        prefix += "_"

    return (
        os.path.join(split_dir, '%strain.mrg' % prefix),
        os.path.join(split_dir, '%sdev.mrg' % prefix),
        os.path.join(split_dir, '%stest.mrg' % prefix),
    )


def get_num_samples(org_dir, file_names):
    """
    Function for obtaining the number of samples

    One sample is one non-blank line of a .mrg file.  Blank and
    whitespace-only lines are not counted, matching how split_files
    discards them when it reads the trees; counting them would make the
    split limits larger than the number of trees actually available.

    :param org_dir: original directory storing the tree files
    :param file_names: list of file names in the directory
    :return: number of samples
    """
    count = 0
    for filename in file_names:
        # Skip files that are not .mrg
        if not filename.endswith('.mrg'):
            continue
        file_path = os.path.join(org_dir, filename)
        with open(file_path, 'r', encoding='utf-8') as reader:
            # Stream the file rather than loading every line into memory
            count += sum(1 for line in reader if line.strip())

    return count

def split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0.15, rotation=None):
    """
    Split the trees in the .mrg files of org_dir into train/dev/test files

    Each non-blank line of each .mrg file is treated as one tree.  The files
    are visited in the order returned by create_shuffle_list, and the trees
    are written in that order to the output files named by create_paths.

    :param org_dir: original directory storing the tree files
    :param split_dir: directory where the splits are written (created if missing)
    :param short_name: optional prefix for the output file names
    :param train_size: fraction of the trees written to the train file
    :param dev_size: fraction of the trees written to the dev file.
        If train_size + dev_size >= 1.0, the dev file gets everything after
        the train section and no test file is written.
        If train_size + dev_size <= 0.0, everything is written to the test file.
    :param rotation: optional (numerator, denominator) pair.  When the
        numerator is positive, the train & dev sections are rotated so they
        start at tree number len(trees) * numerator // denominator.  The
        test section is left untouched.
    :raises RuntimeError: if the computed limits ask for more trees than exist
    """
    os.makedirs(split_dir, exist_ok=True)

    if train_size + dev_size >= 1.0:
        print("Not making a test slice with the given ratios: train {} dev {}".format(train_size, dev_size))

    # Create a random shuffle list of the file names in the original directory
    file_names = create_shuffle_list(org_dir)

    # Create train_path, dev_path, test_path
    train_path, dev_path, test_path = create_paths(split_dir, short_name)

    # Read all of the trees first.  The split limits are derived from the
    # number of trees actually kept, so blank lines in the input cannot make
    # the limits larger than the number of trees available to write.
    trees = []
    for filename in file_names:
        if not filename.endswith('.mrg'):
            continue
        with open(os.path.join(org_dir, filename), encoding='utf-8') as reader:
            stripped = (line.strip() for line in reader)
            trees.extend(tree for tree in stripped if tree)

    # Set up the number of samples for each train/dev/test set
    # TODO: if we ever wanted to split files with <s> </s> in them,
    # this particular code would need some updating to pay attention to the ids
    num_samples = len(trees)
    print("Found {} total lines in {}".format(num_samples, org_dir))

    stop_train = int(num_samples * train_size)
    if train_size + dev_size >= 1.0:
        stop_dev = num_samples
        output_limits = (stop_train, stop_dev)
        output_names = (train_path, dev_path)
        print("Splitting {} train, {} dev".format(stop_train, stop_dev - stop_train))
    elif train_size + dev_size > 0.0:
        stop_dev = int(num_samples * (train_size + dev_size))
        output_limits = (stop_train, stop_dev, num_samples)
        output_names = (train_path, dev_path, test_path)
        print("Splitting {} train, {} dev, {} test".format(stop_train, stop_dev - stop_train, num_samples - stop_dev))
    else:
        stop_dev = 0
        output_limits = (num_samples,)
        output_names = (test_path,)
        print("Copying all {} lines to test".format(num_samples))

    # rotate the train & dev sections, leave the test section the same
    if rotation is not None and rotation[0] > 0:
        rotation_end = stop_dev
        # Clamp the start to the end of the dev section.  If the start were
        # past rotation_end, trees[:rotation_start] and trees[rotation_end:]
        # would overlap, duplicating some trees and pushing others off the end.
        rotation_start = min(len(trees) * rotation[0] // rotation[1], rotation_end)
        # if there are no test trees, rotation_end: will be empty anyway
        trees = trees[rotation_start:rotation_end] + trees[:rotation_start] + trees[rotation_end:]

    # Count how much stuff we've written.
    # We will switch to the next output file when we've written enough
    count = 0
    tree_iter = iter(trees)
    for write_path, count_limit in zip(output_names, output_limits):
        with open(write_path, 'w', encoding='utf-8') as writer:
            # Keep writing trees to write_path until the running total hits this file's limit
            while count < count_limit:
                next_tree = next(tree_iter, None)
                if next_tree is None:
                    raise RuntimeError("Ran out of trees before reading all of the expected trees")
                writer.write(next_tree)
                writer.write("\n")
                count += 1

def main():
    """
    Command line entry point

    Reads the two positional arguments (org_dir and split_dir), seeds the
    random number generator with a fixed value, and calls split_files with
    its default ratios, which gives a 70/15/15 train/dev/test split.
    """
    arg_parser = argparse.ArgumentParser(
        description="Script that splits a list of files of vtb trees into train/dev/test sets",
    )
    positional_args = (
        ('org_dir', 'The location of the original directory storing correctly formatted vtb trees '),
        ('split_dir', 'The location of new directory storing the train/dev/test set'),
    )
    for arg_name, arg_help in positional_args:
        arg_parser.add_argument(arg_name, help=arg_help)

    parsed = arg_parser.parse_args()

    # Fixed seed: the file order is shuffled with the random module
    random.seed(1234)

    split_files(parsed.org_dir, parsed.split_dir)

# Run the split only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()