import argparse
from enum import Enum
import glob
import logging
import os
import re
import subprocess
import sys
import unicodedata

from stanza.models.common.short_name_to_treebank import canonical_treebank_name
import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data
import stanza.utils.datasets.conllu_to_text as conllu_to_text
import stanza.utils.default_paths as default_paths

logger = logging.getLogger('stanza')

# matches the token range of a multi-word token, eg "4-5"
MWT_RE = re.compile("^[0-9]+[-][0-9]+")

# matches either an MWT range or a copy node, eg "4-5" or "8.1"
MWT_OR_COPY_RE = re.compile("^[0-9]+[-.][0-9]+")

# matches a plain integer word index
INT_RE = re.compile("^[0-9]+$")
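
# Illustrative conllu fragments (for reference, not used by the code):
#   "4-5\tvámonos\t_\t..."    MWT range line, matched by MWT_RE and MWT_OR_COPY_RE
#   "8.1\t_\t_\t..."          empty/copy node, matched only by MWT_OR_COPY_RE
#   "4\tvamos\tir\t..."       ordinary word line, its first column matched by INT_RE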

class ModelType(Enum):
    TOKENIZER = 1
    MWT = 2
    POS = 3
    LEMMA = 4
    DEPPARSE = 5

class UnknownDatasetError(ValueError):
    def __init__(self, dataset, text):
        super().__init__(text)
        self.dataset = dataset

def convert_conllu_to_txt(tokenizer_dir, short_name, shards=("train", "dev", "test")):
    """
    Convert the conllu documents for this dataset to a .txt format

    This follows the old conllu_to_text.pl script, except that we
    never used its ZH option, so that option is not reimplemented here.
    """
    for dataset in shards:
        output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
        output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"

        if not os.path.exists(output_conllu):
            raise FileNotFoundError("Cannot convert %s as the file cannot be found" % output_conllu)
        conllu_to_text.main([output_conllu, output_txt])
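
# For example (hypothetical paths), with tokenizer_dir="data/tokenize" and
# short_name="en_ewt", this reads data/tokenize/en_ewt.train.gold.conllu and
# writes data/tokenize/en_ewt.train.txt, and likewise for the dev and test shards.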

def strip_accents(word):
    """
    Remove diacritics from words, such as in the UD GRC datasets

    If stripping the accents leaves nothing behind, the original word
    is returned unchanged.
    """
    converted = ''.join(c for c in unicodedata.normalize('NFD', word)
                        if unicodedata.category(c) != 'Mn')
    if len(converted) == 0:
        return word
    return converted
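
# Example (illustrative): strip_accents("ἀνήρ") should give "ανηρ", since NFD
# splits the breathing and accent into combining marks of category Mn.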

def mwt_name(base_dir, short_name, dataset):
    return os.path.join(base_dir, f"{short_name}-ud-{dataset}-mwt.json")

def tokenizer_conllu_name(base_dir, short_name, dataset):
    return os.path.join(base_dir, f"{short_name}.{dataset}.gold.conllu")

def prepare_tokenizer_dataset_labels(input_txt, input_conllu, tokenizer_dir, short_name, dataset):
    labels_filename = f"{tokenizer_dir}/{short_name}-ud-{dataset}.toklabels"
    mwt_filename = mwt_name(tokenizer_dir, short_name, dataset)
    prepare_tokenizer_data.main([input_txt,
                                 input_conllu,
                                 "-o", labels_filename,
                                 "-m", mwt_filename])

def prepare_tokenizer_treebank_labels(tokenizer_dir, short_name):
    """
    Given the txt and gold.conllu files, prepare mwt and label files for train/dev/test
    """
    for dataset in ("train", "dev", "test"):
        output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
        output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
        try:
            prepare_tokenizer_dataset_labels(output_txt, output_conllu, tokenizer_dir, short_name, dataset)
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            print("Failed to convert %s to %s" % (output_txt, output_conllu))
            raise

def read_sentences_from_conllu(filename):
    """
    Read a conllu file as a list of sentences, each sentence a list of strings

    A blank line separates one sentence from the next
    """
    sents = []
    cache = []
    with open(filename, encoding="utf-8") as infile:
        for line in infile:
            line = line.strip()
            if len(line) == 0:
                if len(cache) > 0:
                    sents.append(cache)
                    cache = []
                continue
            cache.append(line)
        if len(cache) > 0:
            sents.append(cache)
    return sents
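
# Example (illustrative): a file with two sentences separated by a blank line
# comes back as [["# text = ...", "1\tWord\t..."], ["1\t..."]], that is,
# one list of raw conllu lines (comments included) per sentence.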

def maybe_add_fake_dependencies(lines):
    """
    Possibly add fake dependencies in columns 6 and 7 (counting from 0)

    The conllu scripts need the dependencies column filled out, so in
    the case of models we build without dependency data, we need to
    add those fake dependencies in order to use the eval script etc

    lines: a list of strings with 10 tab separated columns
      comments are allowed (they will be skipped)

    returns: the same strings, but with fake dependencies added
      if columns 6 and 7 were empty
    """
    new_lines = []
    root_idx = None
    first_idx = None
    for line_idx, line in enumerate(lines):
        if line.startswith("#"):
            new_lines.append(line)
            continue

        pieces = line.split("\t")
        if MWT_OR_COPY_RE.match(pieces[0]):
            new_lines.append(line)
            continue

        token_idx = int(pieces[0])
        if pieces[6] != '_':
            if pieces[6] == '0':
                root_idx = token_idx
            new_lines.append(line)
        elif token_idx == 1:
            # we can't decide the fake dependency for the first word yet:
            # it becomes the root unless a real root shows up later.
            # keep it as a list of pieces so it can be patched below
            first_idx = line_idx
            new_lines.append(pieces)
        else:
            pieces[6] = "1"
            pieces[7] = "dep"
            new_lines.append("\t".join(pieces))
    if first_idx is not None:
        if root_idx is None:
            new_lines[first_idx][6] = "0"
            new_lines[first_idx][7] = "root"
        else:
            new_lines[first_idx][6] = str(root_idx)
            new_lines[first_idx][7] = "dep"
        new_lines[first_idx] = "\t".join(new_lines[first_idx])
    return new_lines
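
# Example (illustrative): for a two word sentence with "_" in columns 6 and 7,
# word 1 becomes the fake root (head "0", deprel "root") and word 2 is
# attached to it (head "1", deprel "dep").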

def write_sentences_to_file(outfile, sents):
    for lines in sents:
        lines = maybe_add_fake_dependencies(lines)
        for line in lines:
            print(line, file=outfile)
        print("", file=outfile)

def write_sentences_to_conllu(filename, sents):
    with open(filename, 'w', encoding="utf-8") as outfile:
        write_sentences_to_file(outfile, sents)
def find_treebank_dataset_file(treebank, udbase_dir, dataset, extension, fail=False, env_var="UDBASE"): |
|
|
""" |
|
|
For a given treebank, dataset, extension, look for the exact filename to use. |
|
|
|
|
|
Sometimes the short name we use is different from the short name |
|
|
used by UD. For example, Norwegian or Chinese. Hence the reason |
|
|
to not hardcode it based on treebank |
|
|
|
|
|
set fail=True to fail if the file is not found |
|
|
""" |
|
|
if treebank.startswith("UD_Korean") and treebank.endswith("_seg"): |
|
|
treebank = treebank[:-4] |
|
|
if treebank.startswith("UD_Ancient_Greek-") and (treebank.endswith("-Diacritics") or treebank.endswith("-diacritics")): |
|
|
treebank = treebank[:-11] |
|
|
filename = os.path.join(udbase_dir, treebank, f"*-ud-{dataset}.{extension}") |
|
|
files = glob.glob(filename) |
|
|
if len(files) == 0: |
|
|
if fail: |
|
|
raise FileNotFoundError("Could not find any treebank files which matched {}\nIf you have the data elsewhere, you can change the base directory for the search by changing the {} environment variable".format(filename, env_var)) |
|
|
else: |
|
|
return None |
|
|
elif len(files) == 1: |
|
|
return files[0] |
|
|
else: |
|
|
raise RuntimeError(f"Unexpected number of files matched '{udbase_dir}/{treebank}/*-ud-{dataset}.{extension}'") |
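
# Example call (illustrative path):
#   find_treebank_dataset_file("UD_English-EWT", "/data/ud", "train", "conllu")
# globs for /data/ud/UD_English-EWT/*-ud-train.conllu and returns the single match.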

def mostly_underscores(filename):
    """
    Check if a treebank's text is mostly hidden behind underscores

    Certain treebanks have proprietary data, so the text is hidden

    For example:
      UD_Arabic-NYUAD
      UD_English-ESL
      UD_English-GUMReddit
      UD_Hindi_English-HIENCS
      UD_Japanese-BCCWJ
    """
    underscore_count = 0
    total_count = 0
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            if line.startswith("#"):
                continue
            total_count = total_count + 1
            pieces = line.split("\t")
            if pieces[1] in ("_", "-"):
                underscore_count = underscore_count + 1
    return underscore_count / total_count > 0.5

def num_words_in_file(conllu_file):
    """
    Count the number of non-blank, non-comment lines in a conllu file
    """
    count = 0
    with open(conllu_file, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            if line.startswith("#"):
                continue
            count = count + 1
    return count

def get_ud_treebanks(udbase_dir, filtered=True):
    """
    Look in udbase_dir for all the UD treebanks which have enough data to use

    "Enough" means train and test files, mostly non-underscore text, and
    either a dev file or enough train/test data to split off a dev set
    """
    treebanks = sorted(glob.glob(udbase_dir + "/UD_*"))
    treebanks = [os.path.split(t)[1] for t in treebanks]
    # GUMReddit has its text hidden, so it is skipped here
    treebanks = [t for t in treebanks if t != "UD_English-GUMReddit"]
    if filtered:
        treebanks = [t for t in treebanks
                     if (find_treebank_dataset_file(t, udbase_dir, "train", "conllu") and
                         find_treebank_dataset_file(t, udbase_dir, "test", "conllu"))]
        # eliminate treebanks where the text is proprietary and hidden
        treebanks = [t for t in treebanks
                     if not mostly_underscores(find_treebank_dataset_file(t, udbase_dir, "train", "conllu"))]
        # a missing dev file is acceptable if there is enough train or
        # test data to carve out a dev split
        treebanks = [t for t in treebanks
                     if (find_treebank_dataset_file(t, udbase_dir, "dev", "conllu") or
                         num_words_in_file(find_treebank_dataset_file(t, udbase_dir, "train", "conllu")) > 1000 or
                         num_words_in_file(find_treebank_dataset_file(t, udbase_dir, "test", "conllu")) > 5000)]
    return treebanks

def build_argparse():
    parser = argparse.ArgumentParser()
    parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on.  Use all_ud or ud_all for all UD treebanks')
    return parser
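
# Typical usage (illustrative): a dataset preparation script calls
#   main(process_treebank, ModelType.TOKENIZER)
# and is then invoked as, for example,
#   python -m stanza.utils.datasets.prepare_tokenizer_treebank UD_English-EWT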

def main(process_treebank, model_type, add_specific_args=None):
    logger.info("Datasets program called with:\n" + " ".join(sys.argv))

    parser = build_argparse()
    if add_specific_args is not None:
        add_specific_args(parser)
    args = parser.parse_args()

    paths = default_paths.get_default_paths()

    treebanks = []
    for treebank in args.treebanks:
        if treebank.lower() in ('ud_all', 'all_ud'):
            ud_treebanks = get_ud_treebanks(paths["UDBASE"])
            treebanks.extend(ud_treebanks)
        else:
            # map short names such as en_ewt to the canonical UD name
            treebank = canonical_treebank_name(treebank)
            treebanks.append(treebank)

    for treebank in treebanks:
        process_treebank(treebank, model_type, paths, args)