| """ | |
| Randomly split a file into train, dev, and test sections | |
| Specifically used in the case of building a tagger from the initial | |
| POS tagging provided by Isra, but obviously can be used to split any | |
| conllu file | |
| """ | |
# Standard library
import argparse
import os
import random

# Stanza document model and conllu I/O helpers
from stanza.models.common.doc import Document
from stanza.utils.conll import CoNLL
from stanza.utils.default_paths import get_default_paths
def random_split(doc, weights, remove_xpos=False, remove_feats=False):
    """Randomly assign each sentence of doc to a train, dev, or test split.

    doc: a stanza Document whose sentences will be partitioned
    weights: a tuple / list of (train, dev, test) weights, used as the
      relative probabilities passed to random.choices
    remove_xpos: if True, drop the 'xpos' field from every word
    remove_feats: if True, drop the 'feats' field from every word

    Returns a list of three Documents: [train, dev, test].  Sentence
    comments are carried along with each sentence into its split.
    """
    # each bucket accumulates (sentence dicts, per-sentence comments)
    train_doc = ([], [])
    dev_doc = ([], [])
    test_doc = ([], [])
    splits = [train_doc, dev_doc, test_doc]
    for sentence in doc.sentences:
        sentence_dict = sentence.to_dict()
        if remove_xpos:
            for word in sentence_dict:
                word.pop('xpos', None)
        if remove_feats:
            for word in sentence_dict:
                word.pop('feats', None)
        # weighted draw of which split this sentence lands in
        split = random.choices(splits, weights)[0]
        split[0].append(sentence_dict)
        split[1].append(sentence.comments)
    # materialize each (sentences, comments) bucket as a Document
    return [Document(sentences, comments=comments) for sentences, comments in splits]
def main():
    """Parse command line arguments, split the conllu file, write the pieces.

    Reads --filename, splits its sentences using the (--train, --dev,
    --test) weights after seeding the RNG with --seed, then writes
    <short_name>.<split>.in.conllu files into --output_directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', default='extern_data/sindhi/upos/sindhi_upos.conllu', help='Which file to split')
    parser.add_argument('--train', type=float, default=0.8, help='Fraction of the data to use for train')
    parser.add_argument('--dev', type=float, default=0.1, help='Fraction of the data to use for dev')
    parser.add_argument('--test', type=float, default=0.1, help='Fraction of the data to use for test')
    # NOTE: the seed is deliberately left as a string; random.seed accepts it
    parser.add_argument('--seed', default='1234', help='Random seed to use')
    parser.add_argument('--short_name', default='sd_isra', help='Dataset name to use when writing output files')
    # store_false flags: xpos and feats are removed unless these are given
    parser.add_argument('--no_remove_xpos', default=True, action='store_false', dest='remove_xpos', help='By default, we remove the xpos from the dataset')
    parser.add_argument('--no_remove_feats', default=True, action='store_false', dest='remove_feats', help='By default, we remove the feats from the dataset')
    parser.add_argument('--output_directory', default=get_default_paths()["POS_DATA_DIR"], help="Where to put the split conllu")
    args = parser.parse_args()

    weights = (args.train, args.dev, args.test)
    doc = CoNLL.conll2doc(args.filename)
    # seed after reading so the split itself is reproducible
    random.seed(args.seed)
    splits = random_split(doc, weights, args.remove_xpos, args.remove_feats)
    for split_doc, split_name in zip(splits, ("train", "dev", "test")):
        filename = os.path.join(args.output_directory, "%s.%s.in.conllu" % (args.short_name, split_name))
        print("Outputting %d sentences to %s" % (len(split_doc.sentences), filename))
        CoNLL.write_doc2conll(split_doc, filename)
# Script entry point: only run when executed directly, not when imported.
if __name__ == '__main__':
    main()