# multi-classifier / ud_dataset_maker.py
# Author: veryfansome
# Commit 817dcd8: "wip: adj and adv features" (36.6 kB)
from datasets import load_dataset, load_from_disk, DatasetDict, concatenate_datasets
from openai import OpenAI
from traceback import format_exc
import argparse
import ast
import json
import logging.config
import random
from goemotions_predict import GoEmotionsPredictor
from utils.typos import generate_typo
from utils import default_logging_config, get_uniq_training_labels, show_examples
# Module-level logger; logging is configured in the __main__ block via logging.config.dictConfig.
logger = logging.getLogger(__name__)
# Shared emotion predictor, constructed once at import time and used by introduce_emotion().
goemotions_predictor = GoEmotionsPredictor(
    "veryfansome/deberta-goemotions", subfolder="pos_weight_best")
# XPOS tags an example may contain; is_valid_example() rejects any example
# carrying a tag that is not listed here. Order is preserved from the original.
allowed_xpos = [
    # Punctuation and symbols
    "''", "$", ",",
    "-LRB-",  # (
    "-RRB-",  # )
    ".", ":",
    "ADD",  # URLs, email addresses, or other “address” forms (like Twitter handles) that do not fit elsewhere.
    "CC", "CD", "DT", "EX", "FW", "HYPH", "IN",
    # Adjectives
    "JJ", "JJR", "JJS",
    "LS",  # List item marker
    "MD",
    "NFP",  # “Non-Final Punctuation” for punctuation that doesn’t fit typical labels, in unexpected or stray positions
    # Nouns
    "NN", "NNP", "NNPS", "NNS",
    "PDT", "POS", "PRP$", "PRP",
    # Adverbs
    "RB", "RBR", "RBS",
    "RP", "SYM", "TO", "UH",
    # Verbs
    "VB", "VBD", "VBG", "VBN", "VBP", "VBZ",
    # Wh-words
    "WDT", "WP$", "WP", "WRB",
    "``",
]
# Dependency relations an example may contain; is_valid_example() rejects any
# example carrying a relation that is not listed here (alphabetical order).
allowed_deprel = [
    "acl", "acl:relcl", "advcl", "advmod", "amod", "appos", "aux", "aux:pass",
    "case", "cc", "cc:preconj", "ccomp", "compound", "compound:prt", "conj", "cop",
    "csubj", "csubj:pass",
    "dep", "det", "det:predet", "discourse", "dislocated",
    "expl",
    "fixed", "flat", "flat:foreign",
    "goeswith",
    "iobj",
    "list",
    "mark",
    "nmod", "nmod:npmod", "nmod:poss", "nmod:tmod", "nsubj", "nsubj:pass", "nummod",
    "obj", "obl", "obl:npmod", "obl:tmod", "orphan",
    "parataxis", "punct",
    "reparandum", "root",
    "vocative",
    "xcomp",
]
# Morphological features that are known but not used as training targets
# (found programmatically and added after analysis). parse_morphological_feats()
# appends every value it sees for these features to the matching list.
non_target_feats = {feat_name: [] for feat_name in ("Abbr", "Foreign", "Polarity", "Voice")}
# Keyword arguments splatted into every client.chat.completions.create() call
# made by the introduce_* labelling functions.
# Alternative kept from the original: "model": "o3-mini" with "reasoning_effort": "high".
openai_classification_params = dict(
    model="gpt-4o",
    temperature=0.0,
    top_p=1.0,
    presence_penalty=0.0,
    frequency_penalty=0.0,
    timeout=30,
)
# Morphological features that become their own per-token columns
# (see add_target_feat_columns) and are length-checked by is_evenly_shaped().
target_feats = [
    "Case",
    "Definite",
    "Degree",
    "Gender",
    "Mood",
    "NumType",
    "Number",
    "Person",
    "Poss",
    "PronType",
    "Reflex",
    "Tense",
    "Typo",
    "VerbForm",
]
# Closed word lists used by the introduce_* functions to label a token directly,
# without a model call. Lookups compare the raw token text (case-sensitive).

# Adverbs labelled "Degree" by introduce_adv_type().
word_lists_degree_adverbs = ["almost", "quite", "rather", "too", "very", "extremely"]

# Adjectives labelled "Difference" by introduce_adj_type().
word_lists_difference_adjectives = [
    "contrasting", "different", "disparate", "dissimilar", "distinct",
    "divergent", "diverse", "heterogeneous", "varied", "various",
]

# Adverbs labelled "Frequency" by introduce_adv_type().
word_lists_frequency_adverbs = [
    "always", "daily", "monthly", "often", "rarely",
    "seldom", "sometimes", "weekly", "yearly",
]

# Adjectives labelled "Limit" by introduce_adj_type().
word_lists_limiting_adjectives = [
    "any", "certain", "each", "every", "other", "some",
    # Demonstrative adjectives / determiners
    "that", "these", "this", "those",
]

# Adverbs labelled "Negative" by introduce_adv_type().
word_lists_negative_adverbs = ["not"]

# Adjectives labelled "Similarity" by introduce_adj_type().
word_lists_similarity_adjectives = [
    "alike", "analogous", "comparable", "equal", "equivalent",
    "homogeneous", "identical", "interchangeable", "same", "similar",
]

# Forms of "to be"; introduce_emotion() labels these "O" without a model call.
word_lists_states_of_being_verbs = ["am", "are", "be", "been", "being", "is", "was", "were"]

# Adverbs labelled "Time" by introduce_adv_type().
word_lists_time_adverbs = ["already", "soon", "today", "tomorrow", "yesterday"]

# Adverbs labelled "Uncertainty" by introduce_adv_type().
word_lists_uncertainty_adverbs = ["maybe", "perhaps", "possibly"]
def add_target_feat_columns(exp):
    """
    Expand the per-token ``feats`` column into one column per target feature.

    Args:
        exp: Example mapping. If it has a "feats" key, that value is iterated
            with one morphological-feature entry per token.

    Returns:
        The same ``exp`` object. When "feats" is present, every name in
        ``target_feats`` becomes a key holding a per-token list of values
        ("O" where a token lacks that feature). When "feats" is absent the
        example is returned untouched, so no feature columns are added.
    """
    if "feats" not in exp:
        return exp
    # Parse each token's feats once, then pivot into one list per feature.
    parsed_feats = [parse_morphological_feats(f, target_feats) for f in exp["feats"]]
    for feat in target_feats:
        exp[feat] = [pf[feat] for pf in parsed_feats]
    return exp
def _relative_head_label(head, position, max_negative, max_positive):
    """Return the bucketed head offset label for the token at 0-based ``position``."""
    # ``head`` is compared against the token's 1-based position.
    offset = int(head) - (position + 1)
    if max_negative < offset < max_positive:
        return str(offset)
    # Out-of-window offsets collapse into one overflow label per direction.
    # The negative overflow label also carries a "+" suffix (e.g. "-4+").
    if offset > 0:
        return f"{max_positive}+"
    return f"{max_negative}+"


def convert_head_column(batch):
    """
    Derive one relative-head column per part-of-speech family for a batch.

    For every feature below, tokens whose XPOS tag belongs to the feature's tag
    set receive the head value minus the token's 1-based position, as a string.
    Offsets outside the open interval (max_negative, max_positive) become
    "<max_positive>+" or "<max_negative>+". All other tokens receive "O".

    Args:
        batch: Mapping with "head" and "xpos" columns, each a list (one entry
            per example) of per-token lists. Head values must be accepted by
            ``int()`` for every token whose tag is in one of the tag sets.

    Returns:
        The same ``batch`` object with the twelve ``*Head`` columns set.
    """
    head_feature_specs = {
        # feature name: (XPOS tag set, max_negative, max_positive)
        "AdjHead": ({"JJ", "JJR", "JJS"}, -4, 4),
        "AdvHead": ({"RB", "RBR", "RBS"}, -3, 4),
        "CdHead": ({"CD"}, -3, 3),
        "ConjHead": ({"CC"}, -1, 4),
        "DetHead": ({"DT", "PDT"}, -2, 4),
        "InHead": ({"IN"}, -2, 5),
        "ModalHead": ({"MD"}, -1, 3),
        "NounHead": ({"NN", "NNS", "NNP", "NNPS"}, -5, 4),
        "PronounHead": ({"PRP"}, -2, 3),
        "ToHead": ({"TO"}, -1, 2),
        "VerbHead": ({"VB", "VBD", "VBG", "VBN", "VBP", "VBZ"}, -5, 4),
        "WhHead": ({"WDT", "WP", "WP$", "WRB"}, -2, 4),
    }
    for feature_name, (label_set, max_negative, max_positive) in head_feature_specs.items():
        if feature_name not in batch:
            # Shallow copy only fixes the outer length; every entry is replaced below.
            batch[feature_name] = batch["head"].copy()
        for head_idx, head_labels in enumerate(batch["head"]):
            xpos_tags = batch["xpos"][head_idx]
            batch[feature_name][head_idx] = [
                _relative_head_label(label, label_idx, max_negative, max_positive)
                if xpos_tags[label_idx] in label_set else "O"
                for label_idx, label in enumerate(head_labels)
            ]
    return batch
def convert_upos(exp, labels):
    """
    Replace the "upos" column with a "pos" column of label names.

    Args:
        exp: Example mapping with an "upos" key holding per-token indices.
        labels: Sequence indexed by those values to obtain the label for each token.

    Returns:
        The same ``exp`` object, with "upos" removed and "pos" added.
    """
    exp["pos"] = [labels[i] for i in exp.pop("upos")]
    return exp
def extract_label_groups(exp, feat, target_labels=None):
    """
    Group the positions of matching labels into runs of consecutive indices.

    A label matches when it is in ``target_labels``; if ``target_labels`` is
    None, every label other than "O" matches.

    For example, with target labels {"NN", "NNS", "NNP", "NNPS"}:
        ["O", "O", "NN", "NN", "O", "O", "NNS", "O"]
    would return:
        [[2, 3], [6]]

    Args:
        exp: Example mapping.
        feat: Key of the per-token label column in ``exp`` to scan.
        target_labels (set of str, optional): The labels to target.

    Returns:
        list of lists of int: Each sub-list holds the consecutive indices of
        one uninterrupted run of matching labels, in order of appearance.
    """
    groups = []
    current_group = []
    for idx, label in enumerate(exp[feat]):
        matched = (label in target_labels) if target_labels is not None else label != "O"
        if matched:
            # A non-match always closes the open run, so an open run is
            # necessarily contiguous with idx.
            current_group.append(idx)
        elif current_group:
            groups.append(current_group)
            current_group = []
    # If there's an open group at the end, add it
    if current_group:
        groups.append(current_group)
    return groups
def introduce_adj_type(exp):
    """
    Label every adjective token (XPOS JJ/JJR/JJS) of one example with an adjective type.

    Tokens found in the difference / limiting / similarity word lists are
    labelled "Difference" / "Limit" / "Similarity" directly. Every other
    adjective is classified through the OpenAI chat completions API, constrained
    to ``labels`` by a JSON schema. Non-adjective tokens keep "O".

    Args:
        exp: Example mapping with "tokens" and "xpos"; "text" is also read when
            a model call is needed. An existing "AdjType" column is reused.

    Returns:
        The same ``exp`` object with the "AdjType" column set.

    Note:
        A failed or invalid model response is retried with no upper bound, so a
        persistent API failure keeps this function looping.
    """
    if "AdjType" not in exp:
        exp["AdjType"] = ["O" for _ in exp["tokens"]]
    labels = ["Quantity", "Quality", "Size", "Age", "Shape", "Color", "Origin", "Material", "Purpose"]
    last_label_idx = len(labels) - 1
    label_blob = ", ".join(
        f"or {name}" if i == last_label_idx else name for i, name in enumerate(labels))
    if not any(tag in exp["xpos"] for tag in ("JJ", "JJR", "JJS")):
        return exp

    # Loop-invariant: the structured-output schema restricting answers to `labels`.
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "adjective",
            "strict": True,
            "schema": {
                "type": "object",
                "properties": {
                    "label": {
                        "type": "string",
                        "enum": labels
                    }
                },
                "additionalProperties": False,
                "required": ["label"]
            }
        }
    }
    for jj_group in extract_label_groups(exp, "xpos", {"JJ", "JJR", "JJS"}):
        for jj_idx in jj_group:
            jj_token = exp["tokens"][jj_idx]
            if jj_token in word_lists_difference_adjectives:
                exp["AdjType"][jj_idx] = "Difference"
            elif jj_token in word_lists_limiting_adjectives:
                exp["AdjType"][jj_idx] = "Limit"
            elif jj_token in word_lists_similarity_adjectives:
                exp["AdjType"][jj_idx] = "Similarity"
            else:
                messages = [
                    {
                        "role": "system",
                        # Adjacent literals with explicit spaces: stripping newlines from
                        # a multi-line literal glued words together across line breaks.
                        "content": (
                            f"Classify '{jj_token}' at token index position {jj_idx} by choosing "
                            "the best fitting adjective label. Return only the label value, "
                            "nothing else."
                        ),
                    },
                    {
                        "role": "user",
                        "content": exp["text"]
                    },
                    {
                        "role": "user",
                        "content": str(exp["tokens"])
                    },
                    {
                        "role": "user",
                        "content": f"The adjective '{jj_token}' at token index position {jj_idx} above describes a {label_blob}?"
                    },
                ]
                with OpenAI() as client:
                    while exp["AdjType"][jj_idx] == "O":  # While not labeled
                        try:
                            completion = client.chat.completions.create(
                                messages=messages,
                                **openai_classification_params,
                                response_format=response_format,
                            )
                            # Set only if valid so occasional hallucinations are retried
                            new_label = json.loads(completion.choices[0].message.content)['label']
                            logger.info(f"{jj_idx}:{jj_token} {new_label}")
                            if new_label in labels:
                                exp["AdjType"][jj_idx] = new_label
                        except Exception:
                            # Deliberately broad: any API or parsing failure triggers a retry.
                            logger.error(f"failed to get label, trying again:\n{format_exc()}")
    logger.info("\n" + "\n".join([f"{k}\t{v}" for k, v in exp.items() if k in {"tokens", "AdjType"}]))
    return exp
def introduce_adv_type(exp):
    """
    Label every adverb token (XPOS RB/RBR/RBS) of one example with an adverb type.

    Tokens found in the degree / frequency / negative / time / uncertainty word
    lists are labelled directly. Every other adverb is classified through the
    OpenAI chat completions API, constrained to ``labels`` by a JSON schema.
    Non-adverb tokens keep "O".

    Args:
        exp: Example mapping with "tokens" and "xpos"; "text" is also read when
            a model call is needed. An existing "AdvType" column is reused.

    Returns:
        The same ``exp`` object with the "AdvType" column set.

    Note:
        A failed or invalid model response is retried with no upper bound, so a
        persistent API failure keeps this function looping.
    """
    if "AdvType" not in exp:
        exp["AdvType"] = ["O" for _ in exp["tokens"]]
    labels = [
        "Degree",
        "Frequency",
        "Manner",
        "Negative",
        "Place",
        "Purpose",
        "Time",
        "Uncertainty",
    ]
    last_label_idx = len(labels) - 1
    label_blob = ", ".join(
        f"or {name}" if i == last_label_idx else name for i, name in enumerate(labels))
    if not any(tag in exp["xpos"] for tag in ("RB", "RBR", "RBS")):
        return exp

    # Loop-invariant: the structured-output schema restricting answers to `labels`.
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "adverb",
            "strict": True,
            "schema": {
                "type": "object",
                "properties": {
                    "label": {
                        "type": "string",
                        "enum": labels
                    }
                },
                "additionalProperties": False,
                "required": ["label"]
            }
        }
    }
    for rb_group in extract_label_groups(exp, "xpos", {"RB", "RBR", "RBS"}):
        for rb_idx in rb_group:
            rb_token = exp["tokens"][rb_idx]
            if rb_token in word_lists_degree_adverbs:
                exp["AdvType"][rb_idx] = "Degree"
            elif rb_token in word_lists_frequency_adverbs:
                exp["AdvType"][rb_idx] = "Frequency"
            elif rb_token in word_lists_negative_adverbs:
                exp["AdvType"][rb_idx] = "Negative"
            elif rb_token in word_lists_time_adverbs:
                exp["AdvType"][rb_idx] = "Time"
            elif rb_token in word_lists_uncertainty_adverbs:
                exp["AdvType"][rb_idx] = "Uncertainty"
            else:
                messages = [
                    {
                        "role": "system",
                        # Adjacent literals with explicit spaces: stripping newlines from
                        # a multi-line literal glued words together across line breaks.
                        "content": (
                            f"Classify '{rb_token}' at token index position {rb_idx} by choosing "
                            "the best fitting adverb label. Return only the label value, "
                            "nothing else."
                        ),
                    },
                    {
                        "role": "user",
                        "content": exp["text"]
                    },
                    {
                        "role": "user",
                        "content": str(exp["tokens"])
                    },
                    {
                        "role": "user",
                        "content": f"The adverb '{rb_token}' at token index position {rb_idx} above describes a {label_blob}?"
                    },
                ]
                with OpenAI() as client:
                    while exp["AdvType"][rb_idx] == "O":  # While not labeled
                        try:
                            completion = client.chat.completions.create(
                                messages=messages,
                                **openai_classification_params,
                                response_format=response_format,
                            )
                            # Set only if valid so occasional hallucinations are retried
                            new_label = json.loads(completion.choices[0].message.content)['label']
                            logger.info(f"{rb_idx}:{rb_token} {new_label}")
                            if new_label in labels:
                                exp["AdvType"][rb_idx] = new_label
                        except Exception:
                            # Deliberately broad: any API or parsing failure triggers a retry.
                            logger.error(f"failed to get label, trying again:\n{format_exc()}")
    logger.info("\n" + "\n".join([f"{k}\t{v}" for k, v in exp.items() if k in {"tokens", "AdvType"}]))
    return exp
def introduce_emotion(exp):
    """
    Label the content tokens of one example with the emotion they evoke.

    The example text is first run through ``goemotions_predictor``; its
    predicted emotions (upper-cased, "neutral" dropped) plus "O" form the
    candidate labels. If at least one emotion was predicted, every adjective,
    noun, adverb and verb token is classified through the OpenAI chat
    completions API, except forms of "to be", which are labelled "O" directly.
    Any token still unlabelled at the end becomes "O".

    Args:
        exp: Example mapping with "tokens", "xpos" and "text". An existing
            "Emotion" column is reused; "X" marks a token as not yet labelled.

    Returns:
        The same ``exp`` object with the "Emotion" column set and no "X" left.

    Note:
        A failed or invalid model response is retried with no upper bound, so a
        persistent API failure keeps this function looping.
    """
    if "Emotion" not in exp:
        exp["Emotion"] = ["X" for _ in exp["tokens"]]
    predicted = goemotions_predictor.predict([exp["text"]], use_per_label=True)[0]["emotions"]
    labels = [emotion.upper() for emotion in predicted if emotion != "neutral"]
    labels.append("O")
    labels_len = len(labels)
    label_blob = ", ".join(
        f"or {name}" if (labels_len > 1 and i == labels_len - 1) else name
        for i, name in enumerate(labels))
    logger.info(f"label_blob: {label_blob}")
    if label_blob != "O":  # i.e. at least one non-neutral emotion was predicted
        # Loop-invariant: the structured-output schema restricting answers to `labels`.
        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "label",
                "strict": True,
                "schema": {
                    "type": "object",
                    "properties": {
                        "label": {
                            "type": "string",
                            "enum": labels
                        }
                    },
                    "additionalProperties": False,
                    "required": ["label"]
                }
            }
        }
        for capture_group in extract_label_groups(exp, "xpos", {
            "JJ", "JJR", "JJS",
            "NN", "NNS", "NNP", "NNPS",
            "RB", "RBR", "RBS",
            "VB", "VBD", "VBG", "VBN", "VBP", "VBZ",
        }):
            for token_idx in capture_group:
                token = exp["tokens"][token_idx]
                if token in word_lists_states_of_being_verbs:
                    exp["Emotion"][token_idx] = "O"
                    continue
                messages = [
                    {
                        "role": "system",
                        # Adjacent literals with explicit spaces: stripping newlines from
                        # a multi-line literal glued sentences together ("scope.Pay").
                        "content": (
                            f"Classify '{token}' at token index position {token_idx} by choosing "
                            "the best fitting emotion label or O if out of scope. "
                            "Pay close attention to semantic context but don't over-generalize "
                            "if there is not enough context in the provided text. "
                            "Return only the label value, nothing else."
                        ),
                    },
                    {
                        "role": "user",
                        "content": exp["text"]
                    },
                    {
                        "role": "user",
                        "content": str(exp["tokens"])
                    },
                    {
                        "role": "user",
                        "content": f"The word '{token}' at token index position {token_idx} above evokes {label_blob}?"
                    },
                ]
                with OpenAI() as client:
                    while exp["Emotion"][token_idx] == "X":  # While not labeled
                        try:
                            completion = client.chat.completions.create(
                                messages=messages,
                                **openai_classification_params,
                                response_format=response_format,
                            )
                            # Set only if valid so occasional hallucinations are retried
                            new_label = json.loads(completion.choices[0].message.content)['label']
                            logger.info(f"{token_idx}:{token} {new_label}")
                            if new_label in labels:
                                exp["Emotion"][token_idx] = new_label
                        except Exception:
                            # Deliberately broad: any API or parsing failure triggers a retry.
                            logger.error(f"failed to get label, trying again:\n{format_exc()}")
        logger.info("\n" + "\n".join([f"{k}\t{v}" for k, v in exp.items() if k in {"tokens", "Emotion"}]))
    # Tokens that were never candidates (or never queried) default to "O".
    exp["Emotion"] = [("O" if label == "X" else label) for label in exp["Emotion"]]
    return exp
def introduce_ner_feature(exp, class_name: str, class_desc: str):
    """
    Add a BIO named-entity column for one entity class to one example.

    Only tokens that start with an uppercase character and carry an adjective
    or noun XPOS tag are sent to the OpenAI chat completions API, which must
    answer B-<CLASS>, I-<CLASS> or O. Every other token is set to "O".

    Args:
        exp: Example mapping with "tokens", "xpos" and "text". An existing
            column is reused; "X" marks a token as not yet labelled.
        class_name: Entity class, e.g. "location". The column is named
            "Ner" + class_name.capitalize(); labels use class_name.upper().
        class_desc: Phrase describing the class, inserted into the prompt.

    Returns:
        The same ``exp`` object with the "Ner<Class>" column set and no "X" left.

    Note:
        A failed or invalid model response is retried with no upper bound, so a
        persistent API failure keeps this function looping.
    """
    class_name_capital = class_name.capitalize()
    class_name_upper = class_name.upper()
    class_feature_name = f"Ner{class_name_capital}"
    if class_feature_name not in exp:
        exp[class_feature_name] = ["X" for _ in exp["tokens"]]
    labels = [f"B-{class_name_upper}", f"I-{class_name_upper}", "O"]
    last_label_idx = len(labels) - 1
    label_blob = ", ".join(
        f"or {name}" if i == last_label_idx else name for i, name in enumerate(labels))

    # Loop-invariant: the structured-output schema restricting answers to `labels`.
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "label",
            "strict": True,
            "schema": {
                "type": "object",
                "properties": {
                    "label": {
                        "type": "string",
                        "enum": labels
                    }
                },
                "additionalProperties": False,
                "required": ["label"]
            }
        }
    }
    candidate_tags = {
        "JJ", "JJR", "JJS",
        "NN", "NNS", "NNP", "NNPS"
    }
    capital_indices = [
        i for i, t in enumerate(exp["tokens"])
        if len(t) > 0 and t[0].isupper() and exp["xpos"][i] in candidate_tags
    ]
    for capital_idx in capital_indices:
        capital_token = exp["tokens"][capital_idx]
        messages = [
            {
                "role": "system",
                "content": "You are an expert in recognizing all kinds of names.",
            },
            {
                "role": "user",
                # Adjacent literals with explicit spaces: stripping newlines from a
                # multi-line literal glued words together across line breaks.
                "content": (
                    f"Classify '{capital_token}' at token index position {capital_idx} by choosing "
                    "the best fitting BIO named entity label. "
                    "Pay close attention to semantic context and neighboring tokens but don't "
                    "over-generalize if there is not enough context in the provided text. "
                    f"Classify '{capital_token}' as a {class_name_upper} if it is being used as a "
                    f"part of a {class_desc}. "
                    f"Use the B-{class_name_upper} label if the token begins a {class_name_upper} "
                    f"name entity and the I-{class_name_upper} label if '{capital_token}' continues "
                    f"a {class_name_upper} name entity. "
                    "Return only the label value, nothing else."
                ),
            },
            {
                "role": "user",
                "content": exp["text"]
            },
            {
                "role": "user",
                "content": str(exp["tokens"])
            },
            {
                "role": "user",
                "content": (f"The token '{capital_token}' at index position {capital_idx} above "
                            f"is used as a {label_blob} in the text?")
            },
        ]
        with OpenAI() as client:
            while exp[class_feature_name][capital_idx] == "X":  # While not labeled
                try:
                    completion = client.chat.completions.create(
                        messages=messages,
                        **openai_classification_params,
                        response_format=response_format,
                    )
                    # Set if valid label so occasional hallucinations are retried
                    new_label = json.loads(completion.choices[0].message.content)['label']
                    logger.info(f"{capital_idx}:{capital_token} {new_label}")
                    if new_label in labels:
                        exp[class_feature_name][capital_idx] = new_label
                except Exception:
                    # Deliberately broad: any API or parsing failure triggers a retry.
                    logger.error(f"failed to get {class_feature_name} label for {capital_token} at idx {capital_idx} "
                                 f"in \"{exp['text']}\", trying again:\n{format_exc()}")
    logger.info("\n" + "\n".join([f"{k}\t{v}" for k, v in exp.items() if k in {"tokens", class_feature_name}]))
    # Tokens that were never candidates default to "O".
    exp[class_feature_name] = [("O" if label == "X" else label) for label in exp[class_feature_name]]
    return exp
def introduce_typos(exp, typo_probability=0.03):
    """
    Randomly introduce typos in some % of tokens.

    Each token is independently mutated with probability ``typo_probability``
    via ``generate_typo``. A mutated token is labelled "Yes" in the ``Typo``
    column; a token that comes back unchanged keeps its previous label, since
    no typo was actually introduced.

    Args:
        exp: Example mapping with parallel "tokens" and "Typo" lists.
        typo_probability: Per-token chance of mutation, compared against
            ``random.random()``.

    Returns:
        The same ``exp`` object with "tokens" and "Typo" replaced by new lists.
    """
    mutated_tokens = []
    mutated_typo_col = []
    for token, old_typo_label in zip(exp["tokens"], exp["Typo"]):
        new_token, new_label = token, old_typo_label
        # Decide whether to mutate this token
        if random.random() < typo_probability:
            new_token = generate_typo(token)
            if new_token != token:
                new_label = "Yes"  # Mark only tokens that really changed
        mutated_tokens.append(new_token)
        mutated_typo_col.append(new_label)
    exp["tokens"] = mutated_tokens
    exp["Typo"] = mutated_typo_col
    return exp
def is_evenly_shaped(exp):
    """
    Return True if every label column has exactly one entry per token.

    Checks "xpos", "deprel" and every column named in ``target_feats`` against
    ``len(exp["tokens"])``. A missing column raises KeyError.
    """
    n_tokens = len(exp["tokens"])
    return all(
        len(exp[feat_name]) == n_tokens
        for feat_name in ("xpos", "deprel", *target_feats)
    )
def is_valid_example(exp, dataset_name="ewt"):
    """
    Return True if all xpos & deprel labels are in the allowed lists, else False.

    An example whose tokens are all "_" is rejected outright. Rejections caused
    by a known missing or dataset-specific label are silent; any other
    unexpected label is logged before the example is rejected.

    Args:
        exp: Example mapping with "tokens", "xpos" and "deprel" lists.
        dataset_name: Only used as a prefix in the log messages.
    """
    if set(exp["tokens"]) == {"_"}:
        return False

    # Out-of-set tags that are expected from time to time and need no log line.
    silently_dropped_xpos = {
        None,     # missing label
        "_",      # missing label
        "-LSB-",  # [, en_gum only, not shared by other datasets
        "-RSB-",  # ], en_gum only, not shared by other datasets
        "AFX",    # “Affix” for bound morphemes or prefixes/suffixes that are split off from main tokens
        "GW",     # "Gap Word", sometimes called “additional word” or “merged/gap word”
        "XX",     # Unknown or “placeholder” words/tokens, 2 examples both word1/word2 with XX on the /
    }
    for x in exp["xpos"]:
        if x in allowed_xpos:
            continue
        # One out-of-set xpos excludes the entire example.
        if x not in silently_dropped_xpos:
            logger.info(f"[{dataset_name}] Filtering example with: xpos={x}\n{exp['tokens']}\n{exp['xpos']}")
        return False

    for d in exp["deprel"]:
        if d in allowed_deprel:
            continue
        if d is not None and d != "_":
            logger.info(f"[{dataset_name}] Filtering example with: deprel={d}\n{exp['tokens']}\n{exp['deprel']}")
        return False
    return True
def parse_morphological_feats(feats_in, targeted_feats):
    """
    Return a dict {feat_name: feat_value} for each targeted feature.

    If a feature is absent or doesn't apply, use "O".
    If feats_in is a dict, read from it.
    If feats_in is a string, parse it: first as a Python dict literal
    ("{'Number': 'Sing'}"), then as CoNLL-U "Key=Value|Key=Value".
    If feats_in is None/'_'/''/'None' => no features => all "O".

    Side effects: values of features listed in ``non_target_feats`` are
    appended to that module-level dict; other unexpected features are logged.

    Args:
        feats_in: Morphological features of one token (dict, string or empty).
        targeted_feats: Iterable of feature names to extract.

    Returns:
        dict mapping every name in ``targeted_feats`` to its value or "O".
    """
    # Default
    out = {feat: "O" for feat in targeted_feats}

    # Case A: feats_in is None or "_" or an empty string
    if not feats_in or feats_in == "_" or feats_in == "None":
        return out

    pristine_feats_in = feats_in

    # Case B: feats_in is a string
    if isinstance(feats_in, str):
        try:
            feats_in = ast.literal_eval(feats_in)
        except (ValueError, SyntaxError):
            # Not a Python literal: fall back to the CoNLL-U FEATS notation.
            pairs = (item.split("=", 1) for item in feats_in.split("|") if "=" in item)
            # No pair found => None, so the unknown-type warning below fires.
            feats_in = {k.strip(): v.strip() for k, v in pairs} or None

    # Case C: feats_in is a dictionary (some UD data does that)
    if isinstance(feats_in, dict):
        for k, v in feats_in.items():
            if k in targeted_feats:
                out[k] = v
            elif k in non_target_feats:
                non_target_feats[k].append(v)
            else:
                logger.info(f"Unhandled non-target feat '{k}={v}'")
        return out

    # Otherwise, unknown type
    logger.warning(f"Unknown feats type {type(pristine_feats_in)} => {pristine_feats_in}")
    return out
def replace_bracket_label(exp):
    """
    Rewrite literal bracket XPOS tags to their -LRB- / -RRB- names.

    Args:
        exp: Example mapping with an "xpos" list.

    Returns:
        The same ``exp`` object with "(" replaced by "-LRB-" and ")" by "-RRB-"
        in "xpos"; every other tag is kept as is.
    """
    label_map = {"(": "-LRB-", ")": "-RRB-"}
    exp["xpos"] = [label_map.get(tag, tag) for tag in exp["xpos"]]
    return exp
def transform_and_filter_dataset(ud_dataset, dataset_name="ewt"):
    """
    ud_dataset is a DatasetDict with splits: 'train', 'validation', 'test' etc.
    Return a new DatasetDict with the same splits but transformed/filtered.

    Per split, in order: normalise bracket tags (only when dataset_name is
    "pud"), drop examples rejected by is_valid_example, convert "upos" to
    "pos", add the target morphological feature columns, add the relative
    head columns, remove the raw UD columns, and finally drop examples whose
    columns are not evenly shaped.

    Args:
        ud_dataset: Mapping of split name to dataset.
        dataset_name: Short dataset tag; selects the "pud" pre-processing and
            is passed to is_valid_example for its log messages.
    """
    raw_columns = ("deps", "feats", "head", "idx", "lemmas", "misc")
    new_splits = {}
    for _split_name, _split_ds in ud_dataset.items():
        if dataset_name == "pud":
            _split_ds = _split_ds.map(replace_bracket_label)
        # Resolve the label names once per split instead of once per example.
        upos_names = _split_ds.features["upos"].feature.names
        filtered_split = _split_ds.filter(lambda ex: is_valid_example(ex, dataset_name=dataset_name))
        transformed_split = filtered_split.map(
            lambda exp: convert_upos(exp, upos_names),
            batched=False
        )
        transformed_split = transformed_split.map(
            add_target_feat_columns,
            batched=False
        )
        transformed_split = transformed_split.map(convert_head_column, batched=True, batch_size=1000)
        # TODO:
        #   - Get emotion classes and label adj and adv tokens based on classified emotions. This connects descriptions,
        #     with the kind of attribute, with the emotions evoked.
        #   - checkpoints after each phase to avoid costly re-dos
        #transformed_split = transformed_split.map(introduce_emotion, batched=False)
        #transformed_split = transformed_split.map(introduce_adj_type, batched=False)
        #transformed_split = transformed_split.map(
        #    lambda exp: introduce_ner_feature(
        #        exp, "location",
        #        "location's name"),
        #    batched=False)
        #transformed_split = transformed_split.map(
        #    lambda exp: introduce_ner_feature(
        #        exp, "organization",
        #        "organization's name"),
        #    batched=False)
        #transformed_split = transformed_split.map(
        #    lambda exp: introduce_ner_feature(
        #        exp, "person",
        #        "person's name"),
        #    batched=False)
        # Drop the raw UD columns that are present, in a single call.
        present_raw_columns = [c for c in raw_columns if c in transformed_split.features]
        if present_raw_columns:
            transformed_split = transformed_split.remove_columns(present_raw_columns)
        new_splits[_split_name] = transformed_split.filter(is_evenly_shaped)
    return DatasetDict(new_splits)
# Script entry point: build (or reload and re-process) the merged UD training
# dataset, print examples and label inventories, and optionally save to disk.
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Make training dataset.")
    arg_parser.add_argument("--augment-typos", help='Augment final merged training data with typos.',
                            action="store_true", default=False)
    arg_parser.add_argument("--load-path", help="Load dataset from specified path.",
                            action="store", default=None)
    # NOTE(review): --log-level is parsed but not applied anywhere in this block.
    arg_parser.add_argument("--log-level", help='Log level.',
                            action="store", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
    arg_parser.add_argument("--save", help='Save dataset to disk.',
                            action="store_true", default=False)
    arg_parser.add_argument("--save-path", help="Save final dataset to specified path.",
                            action="store", default="./ud_training_data")
    arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>",
                            action="store", default=None)
    args = arg_parser.parse_args()
    logging.config.dictConfig(default_logging_config)

    if args.load_path is None:
        # Load UD Datasets: EWT, GUM, PUD
        ud_en_ewt_ds = load_dataset("universal_dependencies", "en_ewt")
        ud_en_gum_ds = load_dataset("universal_dependencies", "en_gum")
        ud_en_pud_ds = load_dataset("universal_dependencies", "en_pud")
        for loaded_ds_name, loaded_ds in {
            "ud_en_ewt_ds": ud_en_ewt_ds,
            "ud_en_gum_ds": ud_en_gum_ds,
            "ud_en_pud_ds": ud_en_pud_ds
        }.items():
            # Each count is guarded by the presence of its own split (the
            # validation count was previously guarded by 'train').
            t_cnt = len(loaded_ds['test']) if 'test' in loaded_ds else 0
            tr_cnt = len(loaded_ds['train']) if 'train' in loaded_ds else 0
            v_cnt = len(loaded_ds['validation']) if 'validation' in loaded_ds else 0
            logger.info(f"Loaded {loaded_ds_name}: t:{t_cnt}, tr:{tr_cnt}, v:{v_cnt}")

        # Apply transform + filtering to each split in each dataset
        en_ewt_processed = transform_and_filter_dataset(ud_en_ewt_ds, "ewt")
        en_gum_processed = transform_and_filter_dataset(ud_en_gum_ds, "gum")
        en_pud_processed = transform_and_filter_dataset(ud_en_pud_ds, "pud")

        # Concatenate Datasets (PUD contributes to the test split only)
        final_dataset = DatasetDict()
        final_dataset["test"] = concatenate_datasets(
            [
                en_ewt_processed["test"],
                en_gum_processed["test"],
                en_pud_processed["test"],
            ]
        )
        final_dataset["train"] = concatenate_datasets(
            [
                en_ewt_processed["train"],
                en_gum_processed["train"],
            ]
        )
        if args.augment_typos:
            # Typos are injected into the training split only.
            final_dataset["train"] = final_dataset["train"].map(introduce_typos, batched=False)
        final_dataset["validation"] = concatenate_datasets(
            [
                en_ewt_processed["validation"],
                en_gum_processed["validation"],
            ]
        )
    else:
        final_dataset = transform_and_filter_dataset(load_from_disk(args.load_path))

    show_examples(final_dataset, args.show)
    get_uniq_training_labels(final_dataset)
    if args.save:
        final_dataset.save_to_disk(args.save_path)