Spaces:
Running
Running
| import os | |
| import argparse | |
| import csv | |
| import logging | |
| import numpy as np | |
| from Nested.utils.helpers import logging_config | |
| from Nested.utils.data import conll_to_segments | |
| logger = logging.getLogger(__name__) | |
def _read_conll_tokens(input_file, multi_label):
    """
    Read one tab-separated data file and return its rows as CoNLL lines
    :param input_file: str - path of a TSV file whose first line is a header
    :param multi_label: boolean - True to keep every label of a token, False to keep only the first
    :return: List[List[str]] - one [token, label, ...] list per kept row, with an
             empty list inserted wherever the sentence id changes
    """
    tokens = []
    prev_sent_id = None

    # Explicit encoding so the result does not depend on the platform's locale;
    # newline="" lets the csv module do its own line-ending handling.
    with open(input_file, "r", encoding="utf-8", newline="") as fh:
        reader = csv.reader(fh, delimiter="\t", quotechar=" ")
        # Skip the header line; the default keeps an empty file from raising StopIteration.
        next(reader, None)

        for row in reader:
            if not row:
                # Blank line, nothing to convert.
                continue

            if len(row) < 5:
                # Columns 1, 3 and 4 are read below, so shorter rows cannot be parsed.
                logging.warning("Invalid row found %s", str(row))
                continue

            sent_id, token, labels = row[1], row[3], row[4].split()

            # A label is accepted when it contains a "-" or is exactly "O".
            if not all("-" in label or label == "O" for label in labels):
                logging.warning("Invalid labels found %s", str(row))
                continue

            if not labels:
                logging.warning("Token %s has no label", str(row))
                continue

            if not token:
                logging.warning("Token %s is missing", str(row))
                continue

            if len(token.split()) > 1:
                logging.warning("Token %s has multiple tokens", str(row))
                continue

            # An empty entry becomes a blank line, i.e. a segment boundary.
            if prev_sent_id is not None and sent_id != prev_sent_id:
                tokens.append([])

            if multi_label:
                tokens.append([token] + labels)
            else:
                tokens.append([token, labels[0]])

            # Only rows that were kept move the sentence id forward.
            prev_sent_id = sent_id

    return tokens


def to_conll_format(input_files, output_path, multi_label=False):
    """
    Parse data files and convert them into CoNLL format

    Each input file is written to output_path under its own base name, one
    "token label [label ...]" line per token and a blank line between segments.
    Rows that are too short, have invalid or missing labels, or have a missing
    or multi-word token are skipped with a warning.

    :param input_files: List[str] - list of filenames
    :param output_path: str - output path, created if it does not exist
    :param multi_label: boolean - True to process data with multi-class/multi-label
    :return:
    """
    if output_path:
        os.makedirs(output_path, exist_ok=True)

    for input_file in input_files:
        tokens = _read_conll_tokens(input_file, multi_label)

        # Empty entries are segment separators, so N separators mean N + 1 segments,
        # except when nothing at all was read.
        num_separators = sum(1 for token in tokens if not token)
        num_segments = num_separators + 1 if tokens else 0
        logging.info(
            "Found %d segments and %d tokens in %s",
            num_segments,
            len(tokens) - num_separators,
            input_file,
        )

        filename = os.path.basename(input_file)
        output_file = os.path.join(output_path, filename)

        with open(output_file, "w", encoding="utf-8") as fh:
            fh.write("\n".join(" ".join(token) for token in tokens))

        logging.info("Output file %s", output_file)
def train_dev_test_split(input_files, output_path, train_ratio, dev_ratio):
    """
    Shuffle the segments of the input files and write train/val/test splits

    Segments from all input files are pooled, shuffled in place with
    np.random.shuffle and cut at int(train_ratio * n) and
    int((train_ratio + dev_ratio) * n). The three parts are written to
    train.txt, val.txt and test.txt in output_path.

    :param input_files: List[str] - files read with conll_to_segments
    :param output_path: str - output path, created if it does not exist
    :param train_ratio: float - fraction (0..1) of segments written to train.txt
    :param dev_ratio: float - fraction (0..1) of segments written to val.txt;
                      whatever remains goes to test.txt
    :return:
    :raises ValueError: if a ratio is missing or negative, or the ratios sum to more than 1
    """
    if train_ratio is None or dev_ratio is None:
        raise ValueError("train_ratio and dev_ratio are required for train_dev_test_split")

    # Small tolerance so that sums such as 0.7 + 0.3 are not rejected by rounding error.
    if train_ratio < 0 or dev_ratio < 0 or train_ratio + dev_ratio > 1 + 1e-9:
        raise ValueError("train_ratio and dev_ratio must be non-negative and sum to at most 1")

    filenames = ["train.txt", "val.txt", "test.txt"]

    segments = []
    for input_file in input_files:
        segments.extend(conll_to_segments(input_file))

    n = len(segments)
    np.random.shuffle(segments)

    # Plain list slicing instead of np.split: np.split converts the list to an
    # ndarray first, which fails for segments of different lengths on recent NumPy.
    train_end = int(train_ratio * n)
    dev_end = int((train_ratio + dev_ratio) * n)
    datasets = [segments[:train_end], segments[train_end:dev_end], segments[dev_end:]]

    if output_path:
        os.makedirs(output_path, exist_ok=True)

    # write data to files
    for name, dataset in zip(filenames, datasets):
        filename = os.path.join(output_path, name)

        # One "text tag [tag ...]" line per token, blank line between segments.
        text = "\n\n".join(
            "\n".join(f"{token.text} {' '.join(token.gold_tag)}" for token in segment)
            for segment in dataset
        )

        with open(filename, "w", encoding="utf-8") as fh:
            fh.write(text)

        logging.info("Output file %s", filename)
def main(args):
    """
    Run the task selected on the command line
    :param args: argparse.Namespace - parsed arguments; args.task picks the function to call
    :return:
    """
    task = args.task

    if task == "to_conll_format":
        to_conll_format(args.input_files, args.output_path, multi_label=args.multi_label)
        return

    if task == "train_dev_test_split":
        train_dev_test_split(args.input_files, args.output_path, args.train_ratio, args.dev_ratio)
        return
# Command-line entry point: parse the arguments, configure logging and run the task.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--input_files",
        type=str,
        nargs="+",
        required=True,
        help="List of input files",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path",
    )

    parser.add_argument(
        "--train_ratio",
        type=float,
        required=False,
        help="Training data ratio (percent of segments). Required with the task train_dev_test_split. "
             "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--dev_ratio",
        type=float,
        required=False,
        help="Dev/val data ratio (percent of segments). Required with the task train_dev_test_split. "
             "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--task", required=True, choices=["to_conll_format", "train_dev_test_split"]
    )

    parser.add_argument(
        "--multi_label", action='store_true'
    )

    args = parser.parse_args()

    # The ratios are optional for argparse but mandatory for the split task;
    # report that as a usage error instead of failing later on a None value.
    if args.task == "train_dev_test_split" and (args.train_ratio is None or args.dev_ratio is None):
        parser.error("--train_ratio and --dev_ratio are required with --task train_dev_test_split")

    # The log file lives inside the output directory, so make sure it exists first.
    if args.output_path:
        os.makedirs(args.output_path, exist_ok=True)

    logging_config(os.path.join(args.output_path, "process.log"))
    main(args)