# multi-classifier / ud_dataset_maker.py
# Source: veryfansome — feat: updates for models/ud_ewt_gum_pud_20250611 (commit cf60c27)
from datasets import load_dataset, load_from_disk, DatasetDict, concatenate_datasets
import argparse
import ast
import logging.config
from utils import default_logging_config, get_uniq_training_labels, show_examples
logger = logging.getLogger(__name__)
allowed_xpos = [
"''",
'$',
',',
'-LRB-', # (
'-RRB-', # )
'.',
':',
'ADD', # URLs, email addresses, or other “address” forms (like Twitter handles) that do not fit elsewhere.
'AFX',
'CC',
'CD',
'DT',
'EX',
'FW',
'HYPH',
'IN',
'JJ',
'JJR',
'JJS',
'LS', # List item marker
'MD',
'NFP', # “Non-Final Punctuation” for punctuation that doesn’t fit typical labels, in unexpected or stray positions
'NN',
'NNP',
'NNPS',
'NNS',
'PDT',
'POS',
'PRP$',
'PRP',
'RB',
'RBR',
'RBS',
'RP',
'SYM',
'TO',
'UH',
'VB',
'VBD',
'VBG',
'VBN',
'VBP',
'VBZ',
'WDT',
'WP$',
'WP',
'WRB',
'``',
]
allowed_deprel = [
'acl',
'acl:relcl',
'advcl',
'advmod',
'amod',
'appos',
'aux',
'aux:pass',
'case',
'cc',
'cc:preconj',
'ccomp',
'compound',
'compound:prt',
'conj',
'cop',
'csubj',
'csubj:pass',
'det',
'det:predet',
'discourse',
'expl',
'fixed',
'flat',
'iobj',
'list',
'mark',
'nmod',
'nmod:npmod',
'nmod:poss',
'nmod:tmod',
'nsubj',
'nsubj:pass',
'nummod',
'obj',
'obl',
'obl:npmod',
'obl:tmod',
'parataxis',
'punct',
'root',
'vocative',
'xcomp',
]
non_target_feats = { # Found programmatically and added after analysis
"Abbr": [],
"Typo": [],
"Voice": [],
}
target_feats = [
"Case", "Definite", "Degree", "Foreign", "Gender", "Mood", "NumType", "Number",
"Person", "Polarity", "PronType", "Poss", "Reflex", "Tense", "VerbForm",
]
def add_target_feat_columns(exp):
"""
Convert example["feats"] (list of feats) into separate columns
for each target_feat. Always return a dict with the same structure.
"""
if "feats" in exp:
# example["feats"] is a list of length N (one per token)
feats_list = exp["feats"]
# Parse feats for each token
parsed_feats = [parse_morphological_feats(
f, target_feats, exp, i
) for i, f in enumerate(feats_list)]
# Now add new columns for each target feat
for feat in target_feats:
exp[feat] = [pf[feat] for pf in parsed_feats]
return exp
def convert_upos(exp, labels):
exp["pos"] = [labels[i] for i in exp.pop("upos")]
return exp
def extract_label_groups(exp, feat, target_labels=None):
"""
For example, given a list of labels (e.g. ["X", "X", "NN", "NN", "X", "X", "NNS", "X"]),
this function will extract the index positions of the labels: NN, NNS, NNP, NNPS.
It returns a list of consecutive index groupings for those noun labels.
For example:
["X", "X", "NN", "NN", "X", "X", "NNS", "X"]
would return:
[[2, 3], [6]]
Args:
exp: Example
feat: feature
target_labels (set of str): The set of tags to target.
Returns:
list of lists of int: A list where each sub-list contains consecutive indices
of labels that match NN, NNS, NNP, NNPS.
"""
groups = []
current_group = []
for idx, label in enumerate(exp[feat]):
if (label in target_labels) if target_labels is not None else label != "X":
# If current_group is empty or the current idx is consecutive (i.e., previous index + 1),
# append to current_group. Otherwise, start a new group.
if current_group and idx == current_group[-1] + 1:
current_group.append(idx)
else:
if current_group:
groups.append(current_group)
current_group = [idx]
else:
if current_group:
groups.append(current_group)
current_group = []
# If there's an open group at the end, add it
if current_group:
groups.append(current_group)
return groups
def is_evenly_shaped(exp):
# All your target columns
feats = ["xpos", "deprel", *target_feats]
n_tokens = len(exp["tokens"])
for feat_name in feats:
if len(exp[feat_name]) != n_tokens:
return False
return True
def is_valid_example(exp, dataset_name="ewt"):
"""Return True if all xpos & deprel labels are in the common sets, else False."""
uniq_tokens = list(set(exp["tokens"]))
if len(uniq_tokens) == 1:
if uniq_tokens[0] == "_":
return False
for x in exp["xpos"]:
# If we hit an out-of-common-set xpos, we exclude this entire example
if x not in allowed_xpos:
# From time-to-time, we run into labels that are missing - either _ or None.
if x is None:
return False
elif x == "_":
return False
elif x == "-LSB-": # [, en_gum only, not shared by other datasets
return False
elif x == "-RSB-": # ], en_gum only, not shared by other datasets
return False
elif x == "GW": # 'GW', # "Gap Word", sometimes called “additional word” or “merged/gap word”).
return False
elif x == "XX": # Unknown or “placeholder” words/tokens, 2 examples both word1/word2 with XX on the /
return False
logger.info(f"[{dataset_name}] Filtering example with: xpos={x}\n{exp['tokens']}\n{exp['xpos']}")
return False
for d in exp["deprel"]:
if d not in allowed_deprel:
if d is None:
return False
elif d == "_":
return False
elif d == "dep":
return False
elif d == "dislocated":
return False
elif d == "flat:foreign":
return False
elif d == "goeswith":
return False
elif d == "orphan":
return False
elif d == "reparandum":
return False
logger.info(f"[{dataset_name}] Filtering example with: deprel={d}\n{exp['tokens']}\n{exp['deprel']}")
return False
if "Typo" in exp:
for t in exp["Typo"]:
if t != "X":
return False
return True
def parse_morphological_feats(feats_in, targeted_feats, exp, token_idx):
"""
Return a dict {feat_name: feat_value} for each target_feat.
If a feature is absent or doesn't apply, use "X".
If feats_in is a dict, read from it.
If feats_in is a string, parse it.
If feats_in is None/'_'/'' => no features => all "X".
"""
# Default
token = exp["tokens"][token_idx]
upos = exp["pos"][token_idx]
xpos = exp["xpos"][token_idx]
out = {feat: "X" for feat in targeted_feats}
# If feats_in is None or "_" or an empty string
if not feats_in or feats_in == "_" or feats_in == "None":
feats_in = {}
pristine_feats_in = feats_in
# If feats_in is a dict string: "{'Number': 'Sing', 'Person': '3'}"
if isinstance(feats_in, str):
feats_in = ast.literal_eval(feats_in)
##
# Custom transforms
# Consistency between FW xpos tag and Foreign morphological feature
if xpos == "FW":
feats_in["Foreign"] = "Yes"
# Incorrectly labeled Polarity feature
# - Polarity indicates negation or affirmation on grammatical items.
# - In English, it pertains to only the following function words:
# - the particle not receives Polarity=Neg
# - the coordinating conjunction nor receives Polarity=Neg, as does neither when coupled with nor
# - the interjection no receives Polarity=Neg
# - the interjection yes receives Polarity=Pos
# - Lexical (as opposed to grammatical) items that trigger negative polarity, e.g. lack, doubt, hardly, do not
# receive the feature. Neither do negative prefixes (on adjectives: wise – unwise, probable – improbable), as
# the availability of such prefixes depends on the lexical stem.
# - Other function words conveying negation are pro-forms (tagged as DET, PRON, or ADV) and should therefore
# receive PronType=Neg (not Polarity).
if token in {"Yes", "yes"} and upos == "INTJ":
feats_in["Polarity"] = "Pos"
elif token in {"Non", "non", "Not", "not", "n't", "n’t"}:
feats_in["Polarity"] = "Neg"
elif token in {"Neither", "neither", "Nor", "nor"} and upos == "CCONJ":
feats_in["Polarity"] = "Neg"
elif token in {"Never", "No", "no"} and upos == "INTJ":
feats_in["Polarity"] = "Neg"
elif token in {
"Neither", "neither",
"Never", "never",
"No", "no",
"Nobody", "nobody",
"None", "none",
"Nothing", "nothing",
"Nowhere", "nowhere"
} and upos in {"ADV", "DET"}:
feats_in["Polarity"] = "X"
feats_in["PronType"] = "Neg"
else:
feats_in["Polarity"] = "X"
# feats_in is now always a dictionary (some UD data defaults to this)
if isinstance(feats_in, dict):
for k, v in feats_in.items():
if k in targeted_feats:
out[k] = v
else:
if k in non_target_feats:
non_target_feats[k].append(v)
else:
logger.info(f"Unhandled non-target feat '{k}={v}'")
return out
# Otherwise, unknown type
logger.warning(f"Unknown feats type {type(pristine_feats_in)} => {pristine_feats_in}")
return out
def replace_bracket_label(exp):
label_map = {"(": "-LRB-", ")": "-RRB-"}
exp["xpos"] = [ label_map[tok] if tok in {"(", ")"} else tok for tok in exp["xpos"] ]
return exp
def transform_and_filter_dataset(ud_dataset, dataset_name="ewt"):
"""
ud_dataset is a DatasetDict with splits: 'train', 'validation', 'test' etc.
Return a new DatasetDict with the same splits but transformed/filtered.
"""
new_splits = {}
for _split_name, _split_ds in ud_dataset.items():
if dataset_name == "pud":
_split_ds = _split_ds.map(replace_bracket_label)
transformed_split = _split_ds.filter(lambda ex: is_valid_example(ex, dataset_name=dataset_name))
if "upos" in _split_ds.features:
transformed_split = transformed_split.map(
lambda exp: convert_upos(exp, _split_ds.features["upos"].feature.names),
batched=False)
transformed_split = transformed_split.map(
add_target_feat_columns,
batched=False
)
for col_name in {"deps", "feats", "head", "idx", "lemmas", "misc", "Typo"}:
if col_name in transformed_split.features:
transformed_split = transformed_split.remove_columns([col_name])
new_splits[_split_name] = transformed_split.filter(is_evenly_shaped)
return DatasetDict(new_splits)
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(description="Make training dataset.")
arg_parser.add_argument("--load-path", help="Load dataset from specified path.",
action="store", default=None)
arg_parser.add_argument("--log-level", help='Log level.',
action="store", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
arg_parser.add_argument("--save", help='Save dataset to disk.',
action="store_true", default=False)
arg_parser.add_argument("--save-path", help="Save final dataset to specified path.",
action="store", default="./ud_training_data")
arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>",
action="store", default=None)
args = arg_parser.parse_args()
logging.config.dictConfig(default_logging_config)
if args.load_path is None:
# Load UD Datasets: EWT, GUM, PUD
ud_en_ewt_ds = load_dataset("universal_dependencies", "en_ewt")
ud_en_gum_ds = load_dataset("universal_dependencies", "en_gum")
ud_en_pud_ds = load_dataset("universal_dependencies", "en_pud")
for loaded_ds_name, loaded_ds in {
"ud_en_ewt_ds": ud_en_ewt_ds,
"ud_en_gum_ds": ud_en_gum_ds,
"ud_en_pud_ds": ud_en_pud_ds
}.items():
t_cnt = len(loaded_ds['test']) if 'test' in loaded_ds else 0
tr_cnt = len(loaded_ds['train']) if 'train' in loaded_ds else 0
v_cnt = len(loaded_ds['validation']) if 'train' in loaded_ds else 0
logger.info(f"Loaded {loaded_ds_name}: t:{t_cnt}, tr:{tr_cnt}, v:{v_cnt}")
# Apply transform + filtering to each split in each dataset
en_ewt_processed = transform_and_filter_dataset(ud_en_ewt_ds, "ewt")
en_gum_processed = transform_and_filter_dataset(ud_en_gum_ds, "gum")
en_pud_processed = transform_and_filter_dataset(ud_en_pud_ds, "pud")
# Concatenate Datasets
final_dataset = DatasetDict()
final_dataset["test"] = concatenate_datasets(
[
en_ewt_processed["test"],
en_gum_processed["test"],
en_pud_processed["test"],
]
)
final_dataset["train"] = concatenate_datasets(
[
en_ewt_processed["train"],
en_gum_processed["train"],
]
)
final_dataset["validation"] = concatenate_datasets(
[
en_ewt_processed["validation"],
en_gum_processed["validation"],
]
)
else:
final_dataset = transform_and_filter_dataset(load_from_disk(args.load_path))
show_examples(final_dataset, args.show)
get_uniq_training_labels(final_dataset)
if args.save:
final_dataset.save_to_disk(args.save_path)