aaljabari committed on
Commit
367883f
·
verified ·
1 Parent(s): 1f4a044

Create process.py

Browse files
Files changed (1) hide show
  1. Nested/bin/process.py +140 -0
Nested/bin/process.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import csv
4
+ import logging
5
+ import numpy as np
6
+ from Nested.utils.helpers import logging_config
7
+ from Nested.utils.data import conll_to_segments
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
def to_conll_format(input_files, output_path, multi_label=False):
    """
    Parse tab-separated data files and convert them into CoNLL format.

    The first row of every input file is treated as a header and skipped.
    For each remaining row, column 1 is read as the sentence id, column 3 as
    the token and column 4 as a whitespace-separated list of labels. A row is
    logged and skipped when it has fewer than 5 columns, has a label that is
    neither "O" nor contains "-", has no label, has an empty token, or has a
    token containing whitespace.

    For each input file, a file with the same base name is written to
    ``output_path``: one "token label [label ...]" line per token, with an
    empty line wherever the sentence id changes.

    :param input_files: List[str] - list of filenames
    :param output_path: str - directory the converted files are written to
    :param multi_label: boolean - True to keep every label of a token
                        (multi-class/multi-label), False to keep only the first
    :return: None
    """
    for input_file in input_files:
        tokens = []
        prev_sent_id = None

        # Explicit UTF-8 so the result does not depend on the platform locale.
        with open(input_file, "r", encoding="utf-8") as fh:
            r = csv.reader(fh, delimiter="\t", quotechar=" ")
            # Skip the header row; the default keeps an empty file from raising.
            next(r, None)

            for row in r:
                # Blank or truncated rows cannot be indexed below; skip them
                # like every other malformed row instead of crashing.
                if len(row) < 5:
                    logging.warning("Row %s has fewer than 5 columns", str(row))
                    continue

                sent_id, token, labels = row[1], row[3], row[4].split()
                valid_labels = all("-" in label or label == "O" for label in labels)

                if not valid_labels:
                    logging.warning("Invalid labels found %s", str(row))
                    continue
                if not labels:
                    logging.warning("Token %s has no label", str(row))
                    continue
                if not token:
                    logging.warning("Token %s is missing", str(row))
                    continue
                if len(token.split()) > 1:
                    logging.warning("Token %s has multiple tokens", str(row))
                    continue

                # An empty entry marks a sentence boundary (blank output line).
                if prev_sent_id is not None and sent_id != prev_sent_id:
                    tokens.append([])

                if multi_label:
                    tokens.append([token] + labels)
                else:
                    tokens.append([token, labels[0]])

                prev_sent_id = sent_id

        num_breaks = sum(1 for token in tokens if not token)
        # N boundaries separate N + 1 segments, but only if any token was kept.
        num_segments = num_breaks + 1 if tokens else 0
        logging.info("Found %d segments and %d tokens in %s", num_segments, len(tokens) - num_breaks, input_file)

        filename = os.path.basename(input_file)
        output_file = os.path.join(output_path, filename)

        with open(output_file, "w", encoding="utf-8") as fh:
            fh.write("\n".join(" ".join(token) for token in tokens))
        logging.info("Output file %s", output_file)
64
+
65
+
66
def train_dev_test_split(input_files, output_path, train_ratio, dev_ratio):
    """
    Shuffle the segments of CoNLL files and split them into train/val/test files.

    Segments from all input files are pooled, shuffled in place with
    ``np.random.shuffle`` and cut at ``int(train_ratio * n)`` and
    ``int((train_ratio + dev_ratio) * n)``. The three parts are written to
    ``train.txt``, ``val.txt`` and ``test.txt`` in ``output_path``: one
    "text tag [tag ...]" line per token, with an empty line between segments.

    :param input_files: List[str] - files passed to ``conll_to_segments``
    :param output_path: str - directory the three split files are written to
    :param train_ratio: float - fraction of segments for the training split
    :param dev_ratio: float - fraction of segments for the dev/val split
    :return: None
    :raises ValueError: if ``train_ratio`` or ``dev_ratio`` is None
    """
    if train_ratio is None or dev_ratio is None:
        raise ValueError("train_ratio and dev_ratio are required for train_dev_test_split")

    filenames = ["train.txt", "val.txt", "test.txt"]

    # NOTE(review): each segment is iterated below as tokens exposing `.text`
    # and `.gold_tag`; confirm against conll_to_segments.
    segments = []
    for input_file in input_files:
        segments += conll_to_segments(input_file)

    n = len(segments)
    np.random.shuffle(segments)

    # Slice the list directly: np.split would first convert the
    # variable-length segments to an ndarray, which fails for ragged input
    # on recent NumPy versions.
    train_end = int(train_ratio * n)
    dev_end = int((train_ratio + dev_ratio) * n)
    datasets = [segments[:train_end], segments[train_end:dev_end], segments[dev_end:]]

    # write data to files
    for name, dataset in zip(filenames, datasets):
        filename = os.path.join(output_path, name)

        with open(filename, "w", encoding="utf-8") as fh:
            text = "\n\n".join(
                "\n".join(f"{token.text} {' '.join(token.gold_tag)}" for token in segment)
                for segment in dataset
            )
            fh.write(text)
        logging.info("Output file %s", filename)
85
+
86
+
87
def main(args):
    """
    Run the processing task selected on the command line.

    :param args: parsed arguments; ``args.task`` picks the function to run and
                 the remaining attributes are forwarded to it. Any other task
                 value does nothing.
    :return: None
    """
    task = args.task

    if task == "to_conll_format":
        to_conll_format(
            args.input_files,
            args.output_path,
            multi_label=args.multi_label,
        )
    elif task == "train_dev_test_split":
        train_dev_test_split(
            args.input_files,
            args.output_path,
            args.train_ratio,
            args.dev_ratio,
        )
92
+
93
+
94
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, set up logging to
    # <output_path>/process.log, then run the selected task.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--input_files",
        type=str,
        nargs="+",
        required=True,
        help="List of input files",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path",
    )

    parser.add_argument(
        "--train_ratio",
        type=float,
        required=False,
        help="Training data ratio (percent of segments). Required with the task train_dev_test_split. "
             "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--dev_ratio",
        type=float,
        required=False,
        help="Dev/val data ratio (percent of segments). Required with the task train_dev_test_split. "
             "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--task", required=True, choices=["to_conll_format", "train_dev_test_split"]
    )

    parser.add_argument(
        "--multi_label", action='store_true'
    )

    args = parser.parse_args()

    # The ratios are optional for argparse because only one task uses them;
    # enforce them here so the split fails with a usage error, not a traceback.
    if args.task == "train_dev_test_split" and (args.train_ratio is None or args.dev_ratio is None):
        parser.error("--train_ratio and --dev_ratio are required with the task train_dev_test_split")

    logging_config(os.path.join(args.output_path, "process.log"))
    main(args)