| from datasets import load_dataset, DatasetDict, concatenate_datasets |
| from openai import OpenAI |
| from traceback import format_exc |
| import argparse |
| import ast |
| import json |
| import logging.config |
| import random |
|
|
| from goemotions_predict import GoEmotionsPredictor |
| from utils.typos import generate_typo |
| from utils import default_logging_config, get_uniq_training_labels, show_examples |
|
|
| logger = logging.getLogger(__name__) |
|
|
# Shared GoEmotions classifier: proposes sentence-level emotion candidates
# that introduce_emotion() then narrows to per-token labels via OpenAI.
goemotions_predictor = GoEmotionsPredictor(
    "veryfansome/deberta-goemotions", subfolder="pos_weight_best")


# Penn Treebank xpos tags accepted during filtering; examples containing any
# other tag are dropped by is_valid_example().
allowed_xpos = [
    "''",
    '$',
    ',',
    '-LRB-',
    '-RRB-',
    '.',
    ':',
    'ADD',
    'CC',
    'CD',
    'DT',
    'EX',
    'FW',
    'HYPH',
    'IN',
    'JJ',
    'JJR',
    'JJS',
    'LS',
    'MD',
    'NFP',
    'NN',
    'NNP',
    'NNPS',
    'NNS',
    'PDT',
    'POS',
    'PRP$',
    'PRP',
    'RB',
    'RBR',
    'RBS',
    'RP',
    'SYM',
    'TO',
    'UH',
    'VB',
    'VBD',
    'VBG',
    'VBN',
    'VBP',
    'VBZ',
    'WDT',
    'WP$',
    'WP',
    'WRB',
    '``',
]


# Universal Dependencies relation labels accepted during filtering; examples
# containing any other deprel are dropped by is_valid_example().
allowed_deprel = [
    'acl',
    'acl:relcl',
    'advcl',
    'advmod',
    'amod',
    'appos',
    'aux',
    'aux:pass',
    'case',
    'cc',
    'cc:preconj',
    'ccomp',
    'compound',
    'compound:prt',
    'conj',
    'cop',
    'csubj',
    'csubj:pass',
    'dep',
    'det',
    'det:predet',
    'discourse',
    'dislocated',
    'expl',
    'fixed',
    'flat',
    'flat:foreign',
    'goeswith',
    'iobj',
    'list',
    'mark',
    'nmod',
    'nmod:npmod',
    'nmod:poss',
    'nmod:tmod',
    'nsubj',
    'nsubj:pass',
    'nummod',
    'obj',
    'obl',
    'obl:npmod',
    'obl:tmod',
    'orphan',
    'parataxis',
    'punct',
    'reparandum',
    'root',
    'vocative',
    'xcomp',
]


# Morphological features we deliberately do NOT turn into label columns.
# parse_morphological_feats() appends every value it encounters to these
# lists. NOTE(review): this module-level dict grows unboundedly across
# examples — it appears to exist only for ad-hoc inspection; confirm.
non_target_feats = {
    "Abbr": [],
    "Foreign": [],
    "Polarity": [],
    "Voice": [],
}


# Deterministic sampling settings shared by every OpenAI chat-completions
# call in this module (temperature 0 for reproducible labelling).
openai_classification_params = {
    "model": "gpt-4o",
    "temperature": 0.0,
    "top_p": 1.0,
    "presence_penalty": 0.0,
    "frequency_penalty": 0.0,
    "timeout": 30,
}


# UD morphological features promoted to their own per-token label columns
# by add_target_feat_columns().
target_feats = [
    "Case", "Definite", "Degree", "Gender", "Mood", "NumType", "Number",
    "Person", "Poss", "PronType", "Reflex", "Tense", "Typo", "VerbForm"
]


# Adjectives labelled by rule (no LLM call) in introduce_adj_type():
# limiting adjectives -> "Limit".
word_lists_limiting_adjectives = [
    "any",
    "certain",
    "each",
    "every",
    "other",
    "some",
    "that",
    "these",
    "this",
    "those",
]
# Difference adjectives -> "Difference" in introduce_adj_type().
word_lists_difference_adjectives = [
    "contrasting",
    "different",
    "disparate",
    "dissimilar",
    "distinct",
    "divergent",
    "diverse",
    "heterogeneous",
    "varied",
    "various",
]


# Similarity adjectives -> "Similarity" in introduce_adj_type().
word_lists_similarity_adjectives = [
    "alike",
    "analogous",
    "comparable",
    "equal",
    "equivalent",
    "homogeneous",
    "identical",
    "interchangeable",
    "same",
    "similar",
]


# Copulas: always labelled "O" by introduce_emotion() (they carry no
# emotion themselves).
word_lists_states_of_being_verbs = [
    "am", "are", "be", "been", "being", "is", "was", "were",
]
|
|
def add_target_feat_columns(exp):
    """
    Convert example["feats"] (list of feats) into separate columns
    for each target_feat. Always return a dict with the same structure.
    """
    # One parsed {feat: value} dict per token; absent feats default to "O".
    per_token = [parse_morphological_feats(raw, target_feats) for raw in exp["feats"]]

    # Pivot: each target feature becomes its own token-aligned column.
    for feat_name in target_feats:
        exp[feat_name] = [token_feats[feat_name] for token_feats in per_token]

    return exp
|
|
|
|
def extract_label_groups(exp, feat, target_labels=None):
    """
    Collect runs of consecutive indices whose label matches the target set.

    For example, with exp[feat] == ["O", "O", "NN", "NN", "O", "O", "NNS", "O"]
    and target_labels == {"NN", "NNS"}, this returns [[2, 3], [6]].

    Args:
        exp: Example (dict-like with a list of labels under `feat`).
        feat: Name of the label column to scan.
        target_labels (set of str): Labels to match; when None, any label
            other than "O" matches.

    Returns:
        list of lists of int: Each sub-list holds consecutive indices of
        matching labels.
    """
    def _is_match(lbl):
        return lbl in target_labels if target_labels is not None else lbl != "O"

    groups = []
    run = []
    for position, lbl in enumerate(exp[feat]):
        if _is_match(lbl):
            if run and position == run[-1] + 1:
                run.append(position)
            else:
                if run:
                    groups.append(run)
                run = [position]
        else:
            if run:
                groups.append(run)
            run = []

    # Flush a run that extends to the end of the sequence.
    if run:
        groups.append(run)

    return groups
|
|
|
|
def introduce_emotion(exp):
    """
    Add a per-token "Emotion" column to the example.

    A sentence-level GoEmotions prediction proposes candidate labels; each
    content token (adjective/noun/adverb/verb xpos groups) is then classified
    against those candidates with an OpenAI structured-output call.
    State-of-being verbs are labelled "O" by rule; anything left unlabelled
    falls back to "O".
    """
    # "X" marks tokens not yet labelled; leftovers become "O" at the end.
    exp["Emotion"] = ["X" for _ in exp["tokens"]]
    labels = [l.upper() for l in goemotions_predictor.predict([exp["text"]], use_per_label=True)[0]["emotions"] if l != "neutral"]
    labels.append("O")  # "O" (out of scope) is always a valid choice
    labels_len = len(labels)
    # Natural-language enumeration for the prompt, e.g. "JOY, ANGER, or O".
    label_blob = ", ".join([(f"or {l}" if (labels_len > 1 and i == labels_len - 1) else l) for i, l in enumerate(labels)])
    logger.info(f"label_blob: {label_blob}")
    if label_blob != "O":  # skip all LLM calls when no emotion was detected
        for capture_group in extract_label_groups(exp, "xpos", {
            "JJ", "JJR", "JJS",
            "NN", "NNS", "NNP", "NNPS",
            "RB", "RBR", "RBS",
            "VB", "VBD", "VBG", "VBN", "VBP", "VBZ",
        }):
            for token_idx in capture_group:
                token = exp["tokens"][token_idx]
                if token in word_lists_states_of_being_verbs:
                    # Copulas ("is", "was", ...) never carry emotion themselves.
                    exp["Emotion"][token_idx] = "O"
                else:
                    with OpenAI() as client:
                        # NOTE(review): retries forever on persistent API/parse
                        # failures — consider a bounded retry with backoff.
                        while exp["Emotion"][token_idx] == "X":
                            try:
                                completion = client.chat.completions.create(
                                    messages=[
                                        {
                                            "role": "system",
                                            # NOTE(review): replace("\n", "") glues words at line
                                            # breaks unless each line ends with a space — verify.
                                            "content": f"""
Classify '{token}' at token index position {token_idx} by choosing the best fitting emotion label or O if out of scope.
Pay close attention to semantic context but don't over-generalize if there is not enough context in the provided text.
Return only the label value, nothing else.
""".replace("\n", "").strip()
                                        },
                                        {
                                            "role": "user",
                                            "content": exp["text"]
                                        },
                                        {
                                            "role": "user",
                                            "content": str(exp["tokens"])
                                        },
                                        {
                                            "role": "user",
                                            "content": f"The word '{token}' at token index position {token_idx} above evokes {label_blob}?"
                                        },
                                    ],
                                    **openai_classification_params,
                                    # Structured output: model must return {"label": <enum>}.
                                    response_format={
                                        "type": "json_schema",
                                        "json_schema": {
                                            "name": "label",
                                            "strict": True,
                                            "schema": {
                                                "type": "object",
                                                "properties": {
                                                    "label": {
                                                        "type": "string",
                                                        "enum": labels
                                                    }
                                                },
                                                "additionalProperties": False,
                                                "required": ["label"]
                                            }
                                        }
                                    },
                                )
                                new_label = json.loads(completion.choices[0].message.content)['label']
                                logger.info(f"{token_idx}:{token} {new_label}")
                                # Only accept labels from the candidate set; otherwise retry.
                                if new_label in labels:
                                    exp["Emotion"][token_idx] = new_label
                            except Exception as e:
                                logger.error(f"failed to get label, trying again:\n{format_exc()}")
    # Anything still unprocessed (non-content tokens) falls back to "O".
    exp["Emotion"] = [("O" if l == "X" else l) for l in exp["Emotion"]]
    logger.info("\n" + "\n".join([f"{k}\t{v}" for k, v in exp.items() if k in {"tokens", "Emotion"}]))
    return exp
|
|
|
|
def introduce_adj_type(exp):
    """
    Add a per-token "AdjType" column classifying adjectives by semantic type.

    Adjectives found in the curated word lists are labelled by rule
    ("Difference", "Limit", "Similarity"); every other adjective token is
    classified with an OpenAI structured-output call constrained to the
    label enum. Non-adjective tokens stay "O".
    """
    exp["AdjType"] = ["O" for _ in exp["tokens"]]
    labels = ["Quantity", "Quality", "Size", "Age", "Shape", "Color", "Origin", "Material", "Purpose"]
    labels_len = len(labels)
    # Natural-language enumeration for the prompt, ending in "or Purpose".
    label_blob = ", ".join([(f"or {l}" if i == labels_len - 1 else l) for i, l in enumerate(labels)])
    if "JJ" in exp["xpos"] or "JJR" in exp["xpos"] or "JJS" in exp["xpos"]:
        for jj_group in extract_label_groups(exp, "xpos", {"JJ", "JJR", "JJS"}):
            for jj_idx in jj_group:
                jj_token = exp["tokens"][jj_idx]
                if jj_token in word_lists_difference_adjectives:
                    exp["AdjType"][jj_idx] = "Difference"
                elif jj_token in word_lists_limiting_adjectives:
                    exp["AdjType"][jj_idx] = "Limit"
                elif jj_token in word_lists_similarity_adjectives:
                    exp["AdjType"][jj_idx] = "Similarity"
                else:
                    with OpenAI() as client:
                        # NOTE(review): loops until a valid enum label is set —
                        # retries forever on persistent API/parse failures.
                        while exp["AdjType"][jj_idx] == "O":
                            try:
                                completion = client.chat.completions.create(
                                    messages=[
                                        {
                                            "role": "system",
                                            # NOTE(review): replace("\n", "") glues words at line
                                            # breaks unless each line ends with a space — verify.
                                            "content": f"""
Classify '{jj_token}' at token index position {jj_idx} by choosing the best fitting adjective label. Return only the
label value, nothing else.
""".replace("\n", "").strip()
                                        },
                                        {
                                            "role": "user",
                                            "content": exp["text"]
                                        },
                                        {
                                            "role": "user",
                                            "content": str(exp["tokens"])
                                        },
                                        {
                                            "role": "user",
                                            "content": f"The adjective '{jj_token}' at token index position {jj_idx} above describes a {label_blob}?"
                                        },
                                    ],
                                    **openai_classification_params,
                                    # Structured output: model must return {"label": <enum>}.
                                    response_format={
                                        "type": "json_schema",
                                        "json_schema": {
                                            "name": "label",
                                            "strict": True,
                                            "schema": {
                                                "type": "object",
                                                "properties": {
                                                    "label": {
                                                        "type": "string",
                                                        "enum": labels
                                                    }
                                                },
                                                "additionalProperties": False,
                                                "required": ["label"]
                                            }
                                        }
                                    },
                                )
                                new_label = json.loads(completion.choices[0].message.content)['label']
                                logger.info(f"{jj_idx}:{jj_token} {new_label}")
                                # Only accept labels from the fixed enum; otherwise retry.
                                if new_label in labels:
                                    exp["AdjType"][jj_idx] = new_label
                            except Exception as e:
                                logger.error(f"failed to get label, trying again:\n{format_exc()}")
    logger.info("\n" + "\n".join([f"{k}\t{v}" for k, v in exp.items() if k in {"tokens", "AdjType"}]))
    return exp
|
|
|
|
def introduce_ner_feature(exp, class_name: str, class_desc: str):
    """
    Add a BIO named-entity column (e.g. "NerPerson") for one entity class.

    Only capitalized tokens tagged as adjectives or nouns are candidates;
    each is labelled B-<CLASS>, I-<CLASS>, or O via an OpenAI
    structured-output call. Non-candidate tokens end up "O".

    Args:
        exp: Example with "tokens", "xpos", and "text".
        class_name: Entity class name, e.g. "person" -> column "NerPerson".
        class_desc: Human-readable description inserted into the prompt.
    """
    class_name_capital = class_name.capitalize()
    class_name_upper = class_name.upper()
    class_feature_name = f"Ner{class_name_capital}"
    # "X" marks unprocessed candidates; leftovers become "O" at the end.
    exp[class_feature_name] = ["X" for _ in exp["tokens"]]

    labels = [f"B-{class_name_upper}", f"I-{class_name_upper}", "O"]
    labels_len = len(labels)
    # Natural-language enumeration for the prompt, e.g. "B-PERSON, I-PERSON, or O".
    label_blob = ", ".join([(f"or {l}" if i == labels_len - 1 else l) for i, l in enumerate(labels)])
    # Candidates: capitalized tokens tagged as adjectives or nouns.
    for capital_idx in [i for i, t in enumerate(exp["tokens"]) if len(t) > 0
                        and t[0].isupper()
                        and exp["xpos"][i] in {
                            "JJ", "JJR", "JJS",
                            "NN", "NNS", "NNP", "NNPS"
                        }]:
        capital_token = exp["tokens"][capital_idx]
        with OpenAI() as client:
            # NOTE(review): retries forever on persistent API/parse failures —
            # consider a bounded retry with backoff.
            while exp[class_feature_name][capital_idx] == "X":
                try:
                    completion = client.chat.completions.create(
                        messages=[
                            {
                                "role": "system",
                                "content": "You are an expert in recognizing all kinds of names.",
                            },
                            {
                                "role": "user",
                                # NOTE(review): replace("\n", "") glues words at line
                                # breaks unless each line ends with a space — verify.
                                "content": f"""
Classify '{capital_token}' at token index position {capital_idx} by choosing the best fitting BIO named entity label.
Pay close attention to semantic context and neighboring tokens but don't over-generalize if there is not enough context
in the provided text. Classify '{capital_token}' as a {class_name_upper} if it is being used as a part of a
{class_desc}. Use the B-{class_name_upper} label if the token begins a {class_name_upper} name entity and the
I-{class_name_upper} label if '{capital_token}' continues a {class_name_upper} name entity. Return only the label
value, nothing else.
""".replace("\n", "").strip()
                            },
                            {
                                "role": "user",
                                "content": exp["text"]
                            },
                            {
                                "role": "user",
                                "content": str(exp["tokens"])
                            },
                            {
                                "role": "user",
                                "content": (f"The token '{capital_token}' at index position {capital_idx} above "
                                            f"is used as a {label_blob} in the text?")
                            },
                        ],
                        **openai_classification_params,
                        # Structured output: model must return {"label": <enum>}.
                        response_format={
                            "type": "json_schema",
                            "json_schema": {
                                "name": "label",
                                "strict": True,
                                "schema": {
                                    "type": "object",
                                    "properties": {
                                        "label": {
                                            "type": "string",
                                            "enum": labels
                                        }
                                    },
                                    "additionalProperties": False,
                                    "required": ["label"]
                                }
                            }
                        },
                    )
                    new_label = json.loads(completion.choices[0].message.content)['label']
                    logger.info(f"{capital_idx}:{capital_token} {new_label}")
                    # Only accept labels from the BIO enum; otherwise retry.
                    if new_label in labels:
                        exp[class_feature_name][capital_idx] = new_label
                except Exception as e:
                    logger.error(f"failed to get {class_feature_name} label for {capital_token} at idx {capital_idx} "
                                 f"in \"{exp['text']}\", trying again:\n{format_exc()}")
    # Non-candidate tokens (still "X") fall back to "O".
    exp[class_feature_name] = [("O" if l == "X" else l) for l in exp[class_feature_name]]
    logger.info("\n" + "\n".join([f"{k}\t{v}" for k, v in exp.items() if k in {"tokens", class_feature_name}]))
    return exp
|
|
|
|
def introduce_typos(exp, typo_probability=0.03):
    """
    Randomly introduce typos in some % of tokens.
    Update the `tokens` and the `Typo` columns in-place.
    """
    new_tokens = []
    new_typo_labels = []

    for original_token, existing_label in zip(exp["tokens"], exp["Typo"]):
        if random.random() < typo_probability:
            # Mutate the token and flag it in the Typo column.
            new_tokens.append(generate_typo(original_token))
            new_typo_labels.append("Yes")
        else:
            # Keep token and its pre-existing Typo label untouched.
            new_tokens.append(original_token)
            new_typo_labels.append(existing_label)

    exp["tokens"] = new_tokens
    exp["Typo"] = new_typo_labels
    return exp
|
|
|
|
def is_evenly_shaped(exp):
    """Return True iff every label column has exactly one entry per token."""
    expected = len(exp["tokens"])
    return all(
        len(exp[column]) == expected
        for column in ("xpos", "deprel", *target_feats)
    )
|
|
|
|
def is_valid_example(exp, dataset_name="ewt"):
    """
    Return True if all xpos & deprel labels are in the common sets, else False.

    Known-bad placeholder labels are rejected silently; any other unexpected
    label is logged before rejection so new label values surface in the logs.

    Args:
        exp: Example with "tokens", "xpos", and "deprel" columns.
        dataset_name: Short name used as a log prefix.
    """
    # Labels rejected without logging: UD placeholders and tags we never train on.
    silent_bad_xpos = {None, "_", "-LSB-", "-RSB-", "AFX", "GW", "XX"}
    silent_bad_deprel = {None, "_"}

    # Reject fully-redacted examples where every token is the "_" placeholder.
    if set(exp["tokens"]) == {"_"}:
        return False

    # Sets give O(1) membership vs O(n) scans of the allowed-label lists.
    allowed_xpos_set = set(allowed_xpos)
    for x in exp["xpos"]:
        if x not in allowed_xpos_set:
            if x not in silent_bad_xpos:
                logger.info(f"[{dataset_name}] Filtering example with: xpos={x}\n{exp['tokens']}\n{exp['xpos']}")
            return False

    allowed_deprel_set = set(allowed_deprel)
    for d in exp["deprel"]:
        if d not in allowed_deprel_set:
            if d not in silent_bad_deprel:
                logger.info(f"[{dataset_name}] Filtering example with: deprel={d}\n{exp['tokens']}\n{exp['deprel']}")
            return False

    return True
|
|
|
|
def parse_morphological_feats(feats_in, targeted_feats):
    """
    Return a dict {feat_name: feat_value} for each target_feat.
    If a feature is absent or doesn't apply, use "O".
    If feats_in is a dict, read from it.
    If feats_in is a string, parse it.
    If feats_in is None/'_'/'' => no features => all "O".
    """
    parsed = {name: "O" for name in targeted_feats}

    # Empty / placeholder values mean "no morphological features at all".
    if not feats_in or feats_in in ("_", "None"):
        return parsed

    raw_input = feats_in

    # Stringified dicts (as stored by some dataset exports) get parsed back.
    if isinstance(feats_in, str):
        feats_in = ast.literal_eval(feats_in)

    if isinstance(feats_in, dict):
        for name, value in feats_in.items():
            if name in targeted_feats:
                parsed[name] = value
            elif name in non_target_feats:
                # Collect non-target values on the module-level dict for inspection.
                non_target_feats[name].append(value)
            else:
                logger.info(f"Unhandled non-target feat '{name}={value}'")
        return parsed

    # Anything else is unexpected — log the original input for debugging.
    logger.warning(f"Unknown feats type {type(raw_input)} => {raw_input}")
    return parsed
|
|
|
|
def replace_bracket_label(exp):
    """Map literal bracket xpos tags to their Penn Treebank escape forms."""
    bracket_tags = {"(": "-LRB-", ")": "-RRB-"}
    exp["xpos"] = [bracket_tags.get(tag, tag) for tag in exp["xpos"]]
    return exp
|
|
|
|
def transform_and_filter_dataset(ud_dataset, dataset_name="ewt"):
    """
    ud_dataset is a DatasetDict with splits: 'train', 'validation', 'test' etc.
    Return a new DatasetDict with the same splits but transformed/filtered.
    """
    new_splits = {}
    for _split_name, _split_ds in ud_dataset.items():
        # PUD uses literal "(" / ")" xpos tags; normalize them first so the
        # allowed-xpos filter doesn't drop those examples.
        if dataset_name == "pud":
            _split_ds = _split_ds.map(replace_bracket_label)
        filtered_split = _split_ds.filter(lambda ex: is_valid_example(ex, dataset_name=dataset_name))

        # Expand raw "feats" into one column per target morphological feature.
        transformed_split = filtered_split.map(
            add_target_feat_columns,
            batched=False
        )

        # Model/LLM-backed feature columns (each adds one token-aligned column).
        transformed_split = transformed_split.map(introduce_emotion, batched=False)
        transformed_split = transformed_split.map(introduce_adj_type, batched=False)
        transformed_split = transformed_split.map(
            lambda exp: introduce_ner_feature(
                exp, "location",
                "location's name"),
            batched=False)
        transformed_split = transformed_split.map(
            lambda exp: introduce_ner_feature(
                exp, "organization",
                "organization's name"),
            batched=False)
        transformed_split = transformed_split.map(
            lambda exp: introduce_ner_feature(
                exp, "person",
                "person's name"),
            batched=False)

        # Fix: removed a dead assignment here that stored the pre-pruned split
        # only to overwrite it immediately below.
        # Drop raw UD columns we no longer need, then guard against ragged rows.
        transformed_split = transformed_split.remove_columns(["deps", "feats", "head", "idx", "lemmas", "misc", "upos"])
        new_splits[_split_name] = transformed_split.filter(is_evenly_shaped)
    return DatasetDict(new_splits)
|
|
|
|
| if __name__ == "__main__": |
| arg_parser = argparse.ArgumentParser(description="Make training dataset.") |
| arg_parser.add_argument("--augment-typos", help='Augment final merged training data with typos.', |
| action="store_true", default=False) |
| arg_parser.add_argument("--log-level", help='Log level.', |
| action="store", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]) |
| arg_parser.add_argument("--save", help='Save dataset to disk.', |
| action="store_true", default=False) |
| arg_parser.add_argument("--save-path", help="Save final model to specified path.", |
| action="store", default="./ud_training_data") |
| arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>", |
| action="store", default=None) |
| args = arg_parser.parse_args() |
| logging.config.dictConfig(default_logging_config) |
|
|
| |
| ud_en_ewt_ds = load_dataset("universal_dependencies", "en_ewt") |
| ud_en_gum_ds = load_dataset("universal_dependencies", "en_gum") |
| ud_en_pud_ds = load_dataset("universal_dependencies", "en_pud") |
|
|
| for loaded_ds_name, loaded_ds in { |
| "ud_en_ewt_ds": ud_en_ewt_ds, |
| "ud_en_gum_ds": ud_en_gum_ds, |
| "ud_en_pud_ds": ud_en_pud_ds |
| }.items(): |
| t_cnt = len(loaded_ds['test']) if 'test' in loaded_ds else 0 |
| tr_cnt = len(loaded_ds['train']) if 'train' in loaded_ds else 0 |
| v_cnt = len(loaded_ds['validation']) if 'train' in loaded_ds else 0 |
| logger.info(f"Loaded {loaded_ds_name}: t:{t_cnt}, tr:{tr_cnt}, v:{v_cnt}") |
|
|
| |
| en_ewt_processed = transform_and_filter_dataset(ud_en_ewt_ds, "ewt") |
| en_gum_processed = transform_and_filter_dataset(ud_en_gum_ds, "gum") |
| en_pud_processed = transform_and_filter_dataset(ud_en_pud_ds, "pud") |
|
|
| |
| final_dataset = DatasetDict() |
| final_dataset["test"] = concatenate_datasets( |
| [ |
| en_ewt_processed["test"], |
| en_gum_processed["test"], |
| en_pud_processed["test"], |
| ] |
| ) |
|
|
| final_dataset["train"] = concatenate_datasets( |
| [ |
| en_ewt_processed["train"], |
| en_gum_processed["train"], |
| ] |
| ) |
| if args.augment_typos: |
| final_dataset["train"] = final_dataset["train"].map(introduce_typos, batched=False) |
|
|
| final_dataset["validation"] = concatenate_datasets( |
| [ |
| en_ewt_processed["validation"], |
| en_gum_processed["validation"], |
| ] |
| ) |
| show_examples(final_dataset, args.show) |
| get_uniq_training_labels(final_dataset) |
| if args.save: |
| final_dataset.save_to_disk(args.save_path) |
|
|