|
|
from datasets import load_dataset, load_from_disk, DatasetDict, concatenate_datasets |
|
|
import argparse |
|
|
import ast |
|
|
import logging.config |
|
|
|
|
|
from utils import default_logging_config, get_uniq_training_labels, show_examples |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
# Penn Treebank XPOS tags accepted in training examples.
# Stored as a frozenset: is_valid_example tests membership per token, so a
# hashed lookup replaces a linear list scan, and the constant is immutable.
allowed_xpos = frozenset({
    "''", '$', ',', '-LRB-', '-RRB-', '.', ':',
    'ADD', 'AFX', 'CC', 'CD', 'DT', 'EX', 'FW', 'HYPH', 'IN',
    'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NFP',
    'NN', 'NNP', 'NNPS', 'NNS',
    'PDT', 'POS', 'PRP$', 'PRP',
    'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH',
    'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ',
    'WDT', 'WP$', 'WP', 'WRB', '``',
})
|
|
|
|
|
# Universal Dependencies relation labels accepted in training examples.
# Stored as a frozenset: is_valid_example tests membership per token, so a
# hashed lookup replaces a linear list scan, and the constant is immutable.
allowed_deprel = frozenset({
    'acl', 'acl:relcl', 'advcl', 'advmod', 'amod', 'appos',
    'aux', 'aux:pass', 'case', 'cc', 'cc:preconj', 'ccomp',
    'compound', 'compound:prt', 'conj', 'cop', 'csubj', 'csubj:pass',
    'det', 'det:predet', 'discourse', 'expl', 'fixed', 'flat',
    'iobj', 'list', 'mark',
    'nmod', 'nmod:npmod', 'nmod:poss', 'nmod:tmod',
    'nsubj', 'nsubj:pass', 'nummod', 'obj',
    'obl', 'obl:npmod', 'obl:tmod',
    'parataxis', 'punct', 'root', 'vocative', 'xcomp',
})
|
|
|
|
|
# Morphological features we recognise but do not turn into training columns.
# NOTE(review): this is mutated at runtime — parse_morphological_feats appends
# every observed value to the matching list, so it doubles as a module-global
# accumulator of non-target feature values seen during processing.
non_target_feats = {
    "Abbr": [],
    "Typo": [],
    "Voice": [],
}


# Morphological features that each become a separate dataset column
# (see add_target_feat_columns). List order fixes the column order, so this
# must stay a list, not a set.
target_feats = [
    "Case", "Definite", "Degree", "Foreign", "Gender", "Mood", "NumType", "Number",
    "Person", "Polarity", "PronType", "Poss", "Reflex", "Tense", "VerbForm",
]
|
|
|
|
|
|
|
|
def add_target_feat_columns(exp):
    """
    Expand example["feats"] (one feats entry per token) into a separate
    column per entry of target_feats. The example is returned either way,
    so the mapper always yields the same structure.
    """
    if "feats" not in exp:
        return exp

    # Parse each token's raw feats into a {feat_name: value} dict.
    parsed = []
    for token_idx, raw_feats in enumerate(exp["feats"]):
        parsed.append(parse_morphological_feats(raw_feats, target_feats, exp, token_idx))

    # Transpose: one list-valued column per target feature.
    for feat_name in target_feats:
        exp[feat_name] = [token_feats[feat_name] for token_feats in parsed]
    return exp
|
|
|
|
|
|
|
|
def convert_upos(exp, labels):
    """Replace the integer-coded "upos" column with a string "pos" column."""
    label_ids = exp.pop("upos")
    exp["pos"] = [labels[label_id] for label_id in label_ids]
    return exp
|
|
|
|
|
|
|
|
def extract_label_groups(exp, feat, target_labels=None):
    """
    Collect runs of consecutive indices whose label matches.

    A label matches when it is in target_labels; with target_labels=None,
    every label other than "X" matches. For example, targeting noun tags in
        ["X", "X", "NN", "NN", "X", "X", "NNS", "X"]
    yields
        [[2, 3], [6]]

    Args:
        exp: Example.
        feat: Name of the label column in exp to scan.
        target_labels (set of str): The set of tags to target, or None to
            match everything except "X".

    Returns:
        list of lists of int: Each sub-list holds consecutive indices of
        matching labels.
    """
    def _matches(label):
        if target_labels is None:
            return label != "X"
        return label in target_labels

    groups = []
    run = []
    for idx, label in enumerate(exp[feat]):
        if not _matches(label):
            # A non-match terminates any open run.
            if run:
                groups.append(run)
            run = []
            continue
        if run and idx == run[-1] + 1:
            run.append(idx)
        else:
            if run:
                groups.append(run)
            run = [idx]

    # Flush a run that reaches the end of the sequence.
    if run:
        groups.append(run)

    return groups
|
|
|
|
|
|
|
|
def is_evenly_shaped(exp):
    """Return True when every label column has one entry per token."""
    expected_len = len(exp["tokens"])
    columns = ["xpos", "deprel", *target_feats]
    return all(len(exp[column]) == expected_len for column in columns)
|
|
|
|
|
|
|
|
# Label values that are known-bad and filtered out silently (no log entry);
# anything else outside the allowed sets is logged before being rejected.
_SILENT_BAD_XPOS = {None, "_", "-LSB-", "-RSB-", "GW", "XX"}
_SILENT_BAD_DEPREL = {
    None, "_", "dep", "dislocated", "flat:foreign",
    "goeswith", "orphan", "reparandum",
}


def is_valid_example(exp, dataset_name="ewt"):
    """Return True if all xpos & deprel labels are in the common sets, else False.

    Also rejects placeholder sentences (every token is "_") and, when a
    "Typo" column is present, any example containing a typo-flagged token.

    Args:
        exp: Example with "tokens", "xpos" and "deprel" columns.
        dataset_name: Name used to tag log messages for unexpected labels.
    """
    # Placeholder sentences consist solely of "_" tokens.
    if set(exp["tokens"]) == {"_"}:
        return False

    for x in exp["xpos"]:
        if x not in allowed_xpos:
            if x not in _SILENT_BAD_XPOS:
                logger.info(f"[{dataset_name}] Filtering example with: xpos={x}\n{exp['tokens']}\n{exp['xpos']}")
            return False

    for d in exp["deprel"]:
        if d not in allowed_deprel:
            if d not in _SILENT_BAD_DEPREL:
                logger.info(f"[{dataset_name}] Filtering example with: deprel={d}\n{exp['tokens']}\n{exp['deprel']}")
            return False

    # Drop examples where any token carries a typo annotation.
    if "Typo" in exp:
        if any(t != "X" for t in exp["Typo"]):
            return False

    return True
|
|
|
|
|
|
|
|
def parse_morphological_feats(feats_in, targeted_feats, exp, token_idx):
    """
    Return a dict {feat_name: feat_value} for each target_feat.
    If a feature is absent or doesn't apply, use "X".
    If feats_in is a dict, read from it.
    If feats_in is a string, parse it.
    If feats_in is None/'_'/'' => no features => all "X".
    """
    # Context for the token whose features are being parsed.
    token = exp["tokens"][token_idx]
    upos = exp["pos"][token_idx]
    xpos = exp["xpos"][token_idx]
    # Start with every targeted feature set to the "not applicable" marker.
    out = {feat: "X" for feat in targeted_feats}

    # UD marks "no features" as "_" (sometimes None or the string "None").
    if not feats_in or feats_in == "_" or feats_in == "None":
        feats_in = {}

    # Keep the untouched input for the type warning at the bottom.
    pristine_feats_in = feats_in

    # Parse stringified feature dicts, e.g. "{'Number': 'Sing'}".
    # NOTE(review): assumes a Python-literal string, not the raw CoNLL-U
    # "A=B|C=D" surface form — confirm against the loaded dataset's column.
    if isinstance(feats_in, str):
        feats_in = ast.literal_eval(feats_in)

    # Foreign words (xpos FW) are always tagged Foreign=Yes.
    # NOTE(review): this and the Polarity rules below mutate feats_in in
    # place; if the caller handed over the example's own dict, the example
    # is modified too — confirm that is intended.
    if xpos == "FW":
        feats_in["Foreign"] = "Yes"

    # Heuristic Polarity / PronType assignment from the surface token.
    # Exactly one branch runs, so Polarity is always (re)assigned here,
    # overriding any Polarity value present in the input feats.
    if token in {"Yes", "yes"} and upos == "INTJ":
        feats_in["Polarity"] = "Pos"
    elif token in {"Non", "non", "Not", "not", "n't", "n’t"}:
        feats_in["Polarity"] = "Neg"
    elif token in {"Neither", "neither", "Nor", "nor"} and upos == "CCONJ":
        feats_in["Polarity"] = "Neg"
    elif token in {"Never", "No", "no"} and upos == "INTJ":
        feats_in["Polarity"] = "Neg"
    elif token in {
        "Neither", "neither",
        "Never", "never",
        "No", "no",
        "Nobody", "nobody",
        "None", "none",
        "Nothing", "nothing",
        "Nowhere", "nowhere"
    } and upos in {"ADV", "DET"}:
        # Negative pronouns/determiners: encode the negation as
        # PronType=Neg rather than Polarity.
        feats_in["Polarity"] = "X"
        feats_in["PronType"] = "Neg"
    else:
        feats_in["Polarity"] = "X"

    # Copy targeted features into the output; divert everything else into
    # the module-level non_target_feats accumulator (or log it if unknown).
    if isinstance(feats_in, dict):
        for k, v in feats_in.items():
            if k in targeted_feats:
                out[k] = v
            else:
                if k in non_target_feats:
                    non_target_feats[k].append(v)
                else:
                    logger.info(f"Unhandled non-target feat '{k}={v}'")
        return out

    # Fallback for unexpected feats types. NOTE(review): plain lists/tuples
    # would raise on the item assignments above, so in practice this is only
    # reachable for a non-dict mapping-like value — verify whether this path
    # ever fires.
    logger.warning(f"Unknown feats type {type(pristine_feats_in)} => {pristine_feats_in}")
    return out
|
|
|
|
|
|
|
|
def replace_bracket_label(exp):
    """Rewrite raw bracket xpos tags to their PTB names (-LRB- / -RRB-)."""
    bracket_names = {"(": "-LRB-", ")": "-RRB-"}
    exp["xpos"] = [bracket_names.get(tag, tag) for tag in exp["xpos"]]
    return exp
|
|
|
|
|
|
|
|
def transform_and_filter_dataset(ud_dataset, dataset_name="ewt"):
    """
    ud_dataset is a DatasetDict with splits: 'train', 'validation', 'test' etc.
    Return a new DatasetDict with the same splits but transformed/filtered.
    """
    # Raw columns that are no longer needed once features are expanded.
    drop_columns = ("deps", "feats", "head", "idx", "lemmas", "misc", "Typo")

    new_splits = {}
    for split_name, split_ds in ud_dataset.items():
        # PUD keeps raw "(" / ")" bracket tags; normalise them first.
        if dataset_name == "pud":
            split_ds = split_ds.map(replace_bracket_label)

        processed = split_ds.filter(lambda ex: is_valid_example(ex, dataset_name=dataset_name))

        # Decode integer upos ids into string labels under a "pos" column.
        if "upos" in split_ds.features:
            upos_names = split_ds.features["upos"].feature.names
            processed = processed.map(
                lambda exp: convert_upos(exp, upos_names),
                batched=False)

        processed = processed.map(add_target_feat_columns, batched=False)

        for column in drop_columns:
            if column in processed.features:
                processed = processed.remove_columns([column])

        # Keep only examples whose label columns line up with their tokens.
        new_splits[split_name] = processed.filter(is_evenly_shaped)

    return DatasetDict(new_splits)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Make training dataset.")
    arg_parser.add_argument("--load-path", help="Load dataset from specified path.",
                            action="store", default=None)
    # NOTE(review): --log-level is parsed but never applied; the dictConfig
    # below always installs default_logging_config unchanged — wire it up or
    # drop the option.
    arg_parser.add_argument("--log-level", help='Log level.',
                            action="store", default="INFO",
                            choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
    arg_parser.add_argument("--save", help='Save dataset to disk.',
                            action="store_true", default=False)
    arg_parser.add_argument("--save-path", help="Save final dataset to specified path.",
                            action="store", default="./ud_training_data")
    arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>",
                            action="store", default=None)
    args = arg_parser.parse_args()
    logging.config.dictConfig(default_logging_config)

    if args.load_path is None:
        # Build the corpus from the three English UD treebanks.
        ud_en_ewt_ds = load_dataset("universal_dependencies", "en_ewt")
        ud_en_gum_ds = load_dataset("universal_dependencies", "en_gum")
        ud_en_pud_ds = load_dataset("universal_dependencies", "en_pud")

        for loaded_ds_name, loaded_ds in {
                "ud_en_ewt_ds": ud_en_ewt_ds,
                "ud_en_gum_ds": ud_en_gum_ds,
                "ud_en_pud_ds": ud_en_pud_ds}.items():
            t_cnt = len(loaded_ds['test']) if 'test' in loaded_ds else 0
            tr_cnt = len(loaded_ds['train']) if 'train' in loaded_ds else 0
            # BUG FIX: validation count was guarded by 'train' in loaded_ds,
            # which would raise KeyError on a dataset with train but no
            # validation split.
            v_cnt = len(loaded_ds['validation']) if 'validation' in loaded_ds else 0
            logger.info(f"Loaded {loaded_ds_name}: t:{t_cnt}, tr:{tr_cnt}, v:{v_cnt}")

        en_ewt_processed = transform_and_filter_dataset(ud_en_ewt_ds, "ewt")
        en_gum_processed = transform_and_filter_dataset(ud_en_gum_ds, "gum")
        en_pud_processed = transform_and_filter_dataset(ud_en_pud_ds, "pud")

        # PUD contributes only to the test split; train/validation come
        # from EWT + GUM.
        final_dataset = DatasetDict()
        final_dataset["test"] = concatenate_datasets([
            en_ewt_processed["test"],
            en_gum_processed["test"],
            en_pud_processed["test"],
        ])
        final_dataset["train"] = concatenate_datasets([
            en_ewt_processed["train"],
            en_gum_processed["train"],
        ])
        final_dataset["validation"] = concatenate_datasets([
            en_ewt_processed["validation"],
            en_gum_processed["validation"],
        ])
    else:
        final_dataset = transform_and_filter_dataset(load_from_disk(args.load_path))

    show_examples(final_dataset, args.show)
    get_uniq_training_labels(final_dataset)
    if args.save:
        final_dataset.save_to_disk(args.save_path)
|