File size: 4,536 Bytes
f316449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import argparse
import csv
import logging
import numpy as np
from Nested.utils.helpers import logging_config
from Nested.utils.data import conll_to_segments

logger = logging.getLogger(__name__)


def _read_tsv_tokens(input_file, multi_label):
    """
    Read one tab-separated data file into CoNLL token rows.

    The first row is treated as a header and skipped. In each following row,
    column 1 is the sentence id, column 3 the token and column 4 a
    whitespace-separated list of labels. Invalid rows are logged and skipped.

    :param input_file: str - path of the TSV file (read as UTF-8)
    :param multi_label: boolean - True to keep every label, False to keep only the first
    :return: List[List[str]] - ``[token, label, ...]`` rows, with an empty list
        inserted wherever the sentence id changes (a segment boundary)
    """
    tokens = list()
    prev_sent_id = None

    # newline="" is what the csv module expects for file objects. With
    # quotechar=" ", a literal '"' in a token is not treated as a quote character.
    with open(input_file, "r", encoding="utf-8", newline="") as fh:
        r = csv.reader(fh, delimiter="\t", quotechar=" ")
        # Skip the header row; an empty file simply yields no rows.
        next(r, None)

        for row in r:
            # csv.reader yields [] for blank lines; indexing row[4] would crash.
            if len(row) < 5:
                logging.warning("Row %s has too few columns", str(row))
                continue

            sent_id, token, labels = row[1], row[3], row[4].split()
            valid_labels = all("-" in label or label == "O" for label in labels)

            if not valid_labels:
                logging.warning("Invalid labels found %s", str(row))
                continue
            if not labels:
                logging.warning("Token %s has no label", str(row))
                continue
            if not token:
                logging.warning("Token %s is missing", str(row))
                continue
            if len(token.split()) > 1:
                logging.warning("Token %s has multiple tokens", str(row))
                continue

            # A change of sentence id starts a new segment (blank line in the output).
            if prev_sent_id is not None and sent_id != prev_sent_id:
                tokens.append([])

            if multi_label:
                tokens.append([token] + labels)
            else:
                tokens.append([token, labels[0]])

            prev_sent_id = sent_id

    return tokens


def to_conll_format(input_files, output_path, multi_label=False):
    """
    Parse tab-separated data files and convert them into CoNLL format.

    Every input file is converted separately and written to ``output_path``
    under its own base name. Each output line is ``token label [label ...]``;
    segments (sentences) are separated by a blank line. Rows that are too
    short, have no labels, have a label that is neither "O" nor contains "-",
    have no token, or contain several whitespace-separated tokens are logged
    and skipped. Files are read and written as UTF-8.

    :param input_files: List[str] - list of filenames
    :param output_path: str - output directory
    :param multi_label: boolean - True to process data with multi-class/multi-label
        (keep all labels of a token); False keeps only the first label
    :return: None
    """
    for input_file in input_files:
        tokens = _read_tsv_tokens(input_file, multi_label)

        # Empty rows mark boundaries between segments; no tokens means no segments.
        num_boundaries = sum(1 for token in tokens if not token)
        num_segments = num_boundaries + 1 if tokens else 0
        logging.info("Found %d segments and %d tokens in %s", num_segments, len(tokens) - num_boundaries, input_file)

        filename = os.path.basename(input_file)
        output_file = os.path.join(output_path, filename)

        with open(output_file, "w", encoding="utf-8") as fh:
            fh.write("\n".join(" ".join(token) for token in tokens))
            logging.info("Output file %s", output_file)


def train_dev_test_split(input_files, output_path, train_ratio, dev_ratio):
    """
    Pool the segments of CoNLL files, shuffle them and split them into
    train/val/test files.

    The first ``int(train_ratio * n)`` shuffled segments go to train.txt. The
    segments up to ``int((train_ratio + dev_ratio) * n)`` go to val.txt, and
    the rest go to test.txt. Each token is written as ``text tag [tag ...]``;
    segments are separated by a blank line. Output is written as UTF-8.

    :param input_files: List[str] - CoNLL files whose segments are pooled
    :param output_path: str - directory receiving train.txt, val.txt and test.txt
    :param train_ratio: float - fraction of segments for train.txt
    :param dev_ratio: float - fraction of segments for val.txt
    :raises ValueError: if a ratio is missing or negative, or the two sum above 1
    :return: None
    """
    if train_ratio is None or dev_ratio is None:
        raise ValueError("train_ratio and dev_ratio are required for train_dev_test_split")
    # Small tolerance so float sums such as 0.7 + 0.3 are not rejected.
    if train_ratio < 0 or dev_ratio < 0 or train_ratio + dev_ratio > 1 + 1e-9:
        raise ValueError(f"Invalid split ratios: train_ratio={train_ratio}, dev_ratio={dev_ratio}")

    segments = list()
    filenames = ["train.txt", "val.txt", "test.txt"]

    for input_file in input_files:
        segments += conll_to_segments(input_file)

    n = len(segments)
    np.random.shuffle(segments)

    # Slice the list directly rather than using np.split. np.split converts the
    # list to an ndarray first, which fails on recent NumPy when segments differ
    # in length. It also turns segments into object arrays. (Assumes
    # conll_to_segments returns a list of token lists — TODO confirm.)
    train_end = int(train_ratio * n)
    dev_end = int((train_ratio + dev_ratio) * n)
    datasets = [segments[:train_end], segments[train_end:dev_end], segments[dev_end:]]

    # write data to files
    for filename, dataset in zip(filenames, datasets):
        path = os.path.join(output_path, filename)
        text = "\n\n".join(
            "\n".join(f"{token.text} {' '.join(token.gold_tag)}" for token in segment)
            for segment in dataset
        )

        with open(path, "w", encoding="utf-8") as fh:
            fh.write(text)
            logging.info("Output file %s", path)


def main(args):
    """
    Run the processing routine selected by ``args.task``.

    :param args: argparse.Namespace - parsed command-line arguments
    :return: None
    """
    # Lambdas defer argument lookup until the chosen task actually runs.
    dispatch = {
        "to_conll_format": lambda: to_conll_format(
            args.input_files, args.output_path, multi_label=args.multi_label
        ),
        "train_dev_test_split": lambda: train_dev_test_split(
            args.input_files, args.output_path, args.train_ratio, args.dev_ratio
        ),
    }
    runner = dispatch.get(args.task)
    if runner is not None:
        runner()


if __name__ == "__main__":
    # Command-line entry point: parse arguments, validate task-specific options,
    # set up logging under output_path and dispatch to main().
    parser = argparse.ArgumentParser(
        description="Convert TSV data to CoNLL format, or split CoNLL data into train/val/test files.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--input_files",
        type=str,
        nargs="+",
        required=True,
        help="List of input files",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output directory (created if it does not exist)",
    )

    parser.add_argument(
        "--train_ratio",
        type=float,
        required=False,
        help="Training data ratio (fraction of segments, e.g. 0.8). Required with the task train_dev_test_split. "
             "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--dev_ratio",
        type=float,
        required=False,
        help="Dev/val data ratio (fraction of segments, e.g. 0.1). Required with the task train_dev_test_split. "
             "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--task",
        required=True,
        choices=["to_conll_format", "train_dev_test_split"],
        help="Processing task to run",
    )

    parser.add_argument(
        "--multi_label",
        action='store_true',
        help="Keep all labels of each token (multi-class/multi-label data). Used with the task to_conll_format",
    )

    args = parser.parse_args()

    # Fail early with a usage message instead of a TypeError deep inside the split.
    if args.task == "train_dev_test_split" and (args.train_ratio is None or args.dev_ratio is None):
        parser.error("--train_ratio and --dev_ratio are required with --task train_dev_test_split")

    # The log file and every output file are written into output_path.
    os.makedirs(args.output_path, exist_ok=True)
    logging_config(os.path.join(args.output_path, "process.log"))
    main(args)