Spaces:
Running
Running
File size: 4,536 Bytes
f316449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import os
import argparse
import csv
import logging
import numpy as np
from Nested.utils.helpers import logging_config
from Nested.utils.data import conll_to_segments
# Module-level logger named after this module.
# NOTE(review): the functions in this file emit their messages through the
# root `logging` module directly rather than through this object.
logger = logging.getLogger(__name__)
def to_conll_format(input_files, output_path, multi_label=False):
    """
    Parse tab-separated data files and convert them into CoNLL format.

    Every input file is read as TSV. The first row is treated as a header and
    skipped. For each remaining row the 2nd column is the sentence id, the 4th
    column is the token and the 5th column is a whitespace-separated list of
    labels. One output file per input file is written into ``output_path``
    under the input file's base name, with one ``token label [label ...]``
    line per token and an empty line between consecutive sentences.

    A row is skipped when it is empty. A row is skipped with a warning when it
    has fewer than five columns, has a label that is not ``O`` and contains
    no ``-``, has no label, has an empty token, or has a token that contains
    whitespace.

    :param input_files: List[str] - list of filenames
    :param output_path: str - output path (directory the files are written to)
    :param multi_label: boolean - True to process data with multi-class/multi-label
                        (all labels are written); False writes only the first label
    :return: None
    """
    for input_file in input_files:
        tokens = []
        prev_sent_id = None

        # UTF-8 is explicit so the result does not depend on the platform's
        # locale encoding; newline="" is what the csv module asks for.
        with open(input_file, "r", encoding="utf-8", newline="") as fh:
            # With quotechar=" " a double quote inside a field is read
            # literally instead of being interpreted as CSV quoting.
            reader = csv.reader(fh, delimiter="\t", quotechar=" ")

            # Skip the header row; the default keeps an empty file from
            # raising StopIteration.
            next(reader, None)

            for row in reader:
                if not row:
                    # Blank line in the input: nothing to convert.
                    continue

                if len(row) < 5:
                    logging.warning("Row %s has fewer than 5 columns", str(row))
                    continue

                sent_id, token, labels = row[1], row[3], row[4].split()

                if not all("-" in label or label == "O" for label in labels):
                    logging.warning("Invalid labels found %s", str(row))
                    continue

                if not labels:
                    logging.warning("Token %s has no label", str(row))
                    continue

                if not token:
                    logging.warning("Token %s is missing", str(row))
                    continue

                if len(token.split()) > 1:
                    logging.warning("Token %s has multiple tokens", str(row))
                    continue

                # An empty entry marks a sentence boundary; it becomes the
                # blank line between segments in the output.
                if prev_sent_id is not None and sent_id != prev_sent_id:
                    tokens.append([])

                if multi_label:
                    tokens.append([token] + labels)
                else:
                    tokens.append([token, labels[0]])

                # Only rows that were kept advance the sentence tracker.
                prev_sent_id = sent_id

        num_boundaries = sum(1 for entry in tokens if not entry)
        num_tokens = len(tokens) - num_boundaries
        # N boundaries separate N + 1 segments, but no tokens means no segments.
        num_segments = num_boundaries + 1 if num_tokens else 0
        logging.info("Found %d segments and %d tokens in %s", num_segments, num_tokens, input_file)

        filename = os.path.basename(input_file)
        output_file = os.path.join(output_path, filename)

        with open(output_file, "w", encoding="utf-8") as fh:
            fh.write("\n".join(" ".join(entry) for entry in tokens))

        logging.info("Output file %s", output_file)
def train_dev_test_split(input_files, output_path, train_ratio, dev_ratio):
    """
    Shuffle the segments of the given files and split them into three files.

    Segments are loaded from every input file with ``conll_to_segments``,
    shuffled in place with ``np.random.shuffle`` and cut at
    ``int(train_ratio * n)`` and ``int((train_ratio + dev_ratio) * n)``,
    where ``n`` is the total number of segments. The three parts are written
    to ``train.txt``, ``val.txt`` and ``test.txt`` inside ``output_path``;
    whatever is left after the train and dev parts goes to the test file.
    Each token is written as ``<text> <tag> [<tag> ...]`` on its own line and
    segments are separated by a blank line.

    :param input_files: List[str] - list of filenames passed to conll_to_segments
    :param output_path: str - directory the three split files are written to
    :param train_ratio: float - fraction of segments for train.txt
    :param dev_ratio: float - fraction of segments for val.txt
    :return: None
    :raises ValueError: if a ratio is None or negative, or the two ratios sum
                        to more than 1
    """
    if train_ratio is None or dev_ratio is None:
        raise ValueError("train_ratio and dev_ratio are required for train_dev_test_split")

    # Small tolerance so float sums such as 0.7 + 0.3 are never rejected.
    if train_ratio < 0 or dev_ratio < 0 or train_ratio + dev_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and dev_ratio must be non-negative and sum to at most 1, "
            f"got train_ratio={train_ratio} and dev_ratio={dev_ratio}"
        )

    segments = []
    filenames = ["train.txt", "val.txt", "test.txt"]

    for input_file in input_files:
        segments += conll_to_segments(input_file)

    n = len(segments)
    np.random.shuffle(segments)

    # Slice the list directly. np.split first converts its input to an
    # ndarray, which recent NumPy releases reject for ragged nested lists
    # (segments of different lengths).
    train_end = int(train_ratio * n)
    dev_end = int((train_ratio + dev_ratio) * n)
    datasets = [segments[:train_end], segments[train_end:dev_end], segments[dev_end:]]

    # write data to files
    for name, dataset in zip(filenames, datasets):
        filename = os.path.join(output_path, name)

        with open(filename, "w", encoding="utf-8") as fh:
            text = "\n\n".join(
                "\n".join(f"{token.text} {' '.join(token.gold_tag)}" for token in segment)
                for segment in dataset
            )
            fh.write(text)

        logging.info("Output file %s", filename)
def main(args):
    """
    Run the task selected on the command line.

    :param args: parsed arguments; ``args.task`` selects the function to call
                 and the remaining attributes are forwarded to it
    :return: None
    """
    task = args.task

    if task == "to_conll_format":
        to_conll_format(
            args.input_files,
            args.output_path,
            multi_label=args.multi_label,
        )
    elif task == "train_dev_test_split":
        train_dev_test_split(
            args.input_files,
            args.output_path,
            args.train_ratio,
            args.dev_ratio,
        )
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, set up logging and run
    # the selected task.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--input_files",
        type=str,
        nargs="+",
        required=True,
        help="List of input files",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path",
    )

    parser.add_argument(
        "--train_ratio",
        type=float,
        required=False,
        help="Training data ratio (fraction of segments). Required with the task train_dev_test_split. "
        "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--dev_ratio",
        type=float,
        required=False,
        help="Dev/val data ratio (fraction of segments). Required with the task train_dev_test_split. "
        "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--task",
        required=True,
        choices=["to_conll_format", "train_dev_test_split"],
        help="Task to run",
    )

    parser.add_argument(
        "--multi_label",
        action="store_true",
        help="Keep all labels of a token instead of only the first one (task to_conll_format)",
    )

    args = parser.parse_args()

    # The ratios are optional for the parser because only one task uses them,
    # so enforce them here with a usage error instead of failing later.
    if args.task == "train_dev_test_split" and (
        args.train_ratio is None or args.dev_ratio is None
    ):
        parser.error("--train_ratio and --dev_ratio are required with --task train_dev_test_split")

    # NOTE(review): presumably sets up logging to process.log in the output
    # directory - confirm in Nested.utils.helpers.logging_config.
    logging_config(os.path.join(args.output_path, "process.log"))
    main(args)
|