from datasets import load_dataset, DatasetDict import argparse import logging from utils import default_logging_config, get_uniq_training_labels, show_examples logger = logging.getLogger(__name__) allowed_pos = {'``', '$', "''", ',', '-LRB-', '-RRB-', '.', ':', 'ADD', 'CC', 'CD', 'DT', 'EX', 'FW', 'HYPH', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NFP', 'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB'} allowed_ner = {'O', 'B-PERSON', 'I-PERSON', 'B-NORP', 'I-NORP', 'B-FAC', 'I-FAC', 'B-ORG', 'I-ORG', 'B-GPE', 'I-GPE', 'B-LOC', 'I-LOC', 'B-PRODUCT', 'I-PRODUCT', 'B-DATE', 'I-DATE', 'B-TIME', 'I-TIME', 'B-PERCENT', 'I-PERCENT', 'B-MONEY', 'I-MONEY', 'B-QUANTITY', 'I-QUANTITY', 'B-ORDINAL', 'I-ORDINAL', 'B-CARDINAL', 'I-CARDINAL', 'B-EVENT', 'I-EVENT', 'B-WORK_OF_ART', 'I-WORK_OF_ART', 'B-LAW', 'I-LAW', 'B-LANGUAGE', 'I-LANGUAGE'} def is_valid_example(exp): """ Simple filter that checks if all pos_tags are in allowed_pos and all ner_tags are in allowed_ner. If you do not want any filtering, simply return True. """ # You can skip filtering by just returning True: # return True # If your dataset has multiple tokens with possibly different tags, # check them all: for pos_tag in exp["pos_tags"]: if pos_tag not in allowed_pos: return False for ner_tag in exp["ner_tags"]: if ner_tag not in allowed_ner: return False return True def transform_and_filter_dataset(onto_ds): """ onto_ds is a DatasetDict with splits: 'train', 'validation', 'test', etc. Return a new DatasetDict with the same splits but: - Filter out unwanted examples - Possibly rename or remove columns - Possibly introduce new columns """ pos_tag_int2str = onto_ds["train"].features["sentences"][0]["pos_tags"].feature.names ner_tag_int2str = onto_ds["train"].features["sentences"][0]["named_entities"].feature.names def flatten_ontonotes(batch): out = { "tokens": [], "ner_tags": [], "pos_tags": [], "verb_predicate": [], } for doc_id, sents in zip(batch["document_id"], batch["sentences"]): for sent_info in sents: out["tokens"].append(sent_info["words"]) out["ner_tags"].append([ner_tag_int2str[i] for i in sent_info["named_entities"]]) out["pos_tags"].append([pos_tag_int2str[i] for i in sent_info["pos_tags"]]) out["verb_predicate"].append([("Yes" if s else "O") for s in sent_info["predicate_lemmas"]]) return out new_splits = {} for split_name, split_ds in onto_ds.items(): # Flatten flattened_ds = split_ds.map( flatten_ontonotes, batched=True, remove_columns=["sentences", "document_id"], # remove old columns ) # Filter out invalid examples filtered_split = flattened_ds.filter(is_valid_example) new_splits[split_name] = filtered_split return DatasetDict(new_splits) # ------------------------------------------------------------------------------ # 6) Main Script # ------------------------------------------------------------------------------ if __name__ == "__main__": import logging.config arg_parser = argparse.ArgumentParser(description="Process OntoNotes CoNLL-2012 (English).") arg_parser.add_argument("--log-level", help="Log level.", action="store", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]) arg_parser.add_argument("--save", help="Save final dataset to disk.", action="store_true", default=False) arg_parser.add_argument("--save-path", help="Where to save final dataset.", default="./conll2012_en12_training_data") arg_parser.add_argument("--show", help="Show examples: //