# stanza-digphil/stanza/utils/datasets/random_split_conllu.py
# Albin Thörn Cleland
# Clean initial commit with LFS
# 19b8775
"""
Randomly split a file into train, dev, and test sections.

Specifically used in the case of building a tagger from the initial
POS tagging provided by Isra, but obviously can be used to split any
conllu file.
"""
import argparse
import os
import random
from stanza.models.common.doc import Document
from stanza.utils.conll import CoNLL
from stanza.utils.default_paths import get_default_paths
def random_split(doc, weights, remove_xpos=False, remove_feats=False):
    """
    Randomly assign each sentence of doc to a train, dev, or test section.

    doc: the document whose sentences are distributed
    weights: a tuple / list of (train, dev, test) weights
    remove_xpos: if set, drop the 'xpos' entry from each item of the sentence
    remove_feats: if set, drop the 'feats' entry from each item of the sentence

    Returns a list of three Documents in (train, dev, test) order.
    """
    dropped_fields = []
    if remove_xpos:
        dropped_fields.append('xpos')
    if remove_feats:
        dropped_fields.append('feats')

    # one (sentences, comments) pair per section: train, dev, test
    buckets = [([], []) for _ in range(3)]

    for sentence in doc.sentences:
        entries = sentence.to_dict()
        if dropped_fields:
            for entry in entries:
                for field in dropped_fields:
                    entry.pop(field, None)

        # a single weighted draw per sentence picks the destination section
        chosen_sentences, chosen_comments = random.choices(buckets, weights)[0]
        chosen_sentences.append(entries)
        chosen_comments.append(sentence.comments)

    return [Document(sentences, comments=comments) for sentences, comments in buckets]
def main():
    """
    Split a conllu file into train, dev, and test files from the command line.

    Reads --filename, assigns each sentence to a section at random using
    the --train / --dev / --test values as weights, and writes one file
    per section named <short_name>.<section>.in.conllu in --output_directory.
    The output directory is created if it does not exist yet.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', default='extern_data/sindhi/upos/sindhi_upos.conllu', help='Which file to split')
    parser.add_argument('--train', type=float, default=0.8, help='Fraction of the data to use for train')
    parser.add_argument('--dev', type=float, default=0.1, help='Fraction of the data to use for dev')
    parser.add_argument('--test', type=float, default=0.1, help='Fraction of the data to use for test')
    # no type= here, so the seed reaches random.seed as a str; left that way
    # so that splits produced with an existing seed stay reproducible
    parser.add_argument('--seed', default='1234', help='Random seed to use')
    parser.add_argument('--short_name', default='sd_isra', help='Dataset name to use when writing output files')
    parser.add_argument('--no_remove_xpos', default=True, action='store_false', dest='remove_xpos', help='By default, we remove the xpos from the dataset')
    parser.add_argument('--no_remove_feats', default=True, action='store_false', dest='remove_feats', help='By default, we remove the feats from the dataset')
    parser.add_argument('--output_directory', default=get_default_paths()["POS_DATA_DIR"], help="Where to put the split conllu")
    args = parser.parse_args()

    weights = (args.train, args.dev, args.test)

    doc = CoNLL.conll2doc(args.filename)

    random.seed(args.seed)
    splits = random_split(doc, weights, args.remove_xpos, args.remove_feats)

    # Make sure the destination exists before writing the first file.
    # An empty path means "current directory", which os.makedirs rejects.
    if args.output_directory:
        os.makedirs(args.output_directory, exist_ok=True)

    for split_doc, split_name in zip(splits, ("train", "dev", "test")):
        filename = os.path.join(args.output_directory, "%s.%s.in.conllu" % (args.short_name, split_name))
        print("Outputting %d sentences to %s" % (len(split_doc.sentences), filename))
        CoNLL.write_doc2conll(split_doc, filename)


# only run the split when executed as a script, not when imported
if __name__ == '__main__':
    main()