| | from datasets import load_dataset, DatasetDict |
| | import argparse |
| | import logging |
| |
|
| | from utils import default_logging_config, get_uniq_training_labels, show_examples |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
# Penn-Treebank-style POS tag inventory used by OntoNotes (47 tags, incl.
# punctuation tags such as `` '' , . : and bracket tags -LRB-/-RRB-).
allowed_pos = set(
    "`` $ '' , -LRB- -RRB- . : ADD CC CD DT EX "
    "FW HYPH IN JJ JJR JJS LS MD NFP NN NNP NNPS NNS PDT POS "
    "PRP PRP$ RB RBR RBS RP SYM TO UH VB VBD VBG VBN VBP VBZ "
    "WDT WP WP$ WRB".split()
)
| |
|
# The 18 OntoNotes entity types; the allowed NER labels are their IOB2
# B-/I- variants plus the outside tag "O" (37 labels total).
_NER_ENTITY_TYPES = (
    'PERSON', 'NORP', 'FAC', 'ORG', 'GPE', 'LOC', 'PRODUCT', 'DATE', 'TIME',
    'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL', 'CARDINAL', 'EVENT',
    'WORK_OF_ART', 'LAW', 'LANGUAGE',
)
allowed_ner = {'O'} | {
    f'{prefix}-{etype}' for etype in _NER_ENTITY_TYPES for prefix in ('B', 'I')
}
| |
|
| |
|
def is_valid_example(exp, pos_allowed=None, ner_allowed=None):
    """
    Return True if every POS tag and every NER tag of an example is allowed.

    Simple filter that checks all of ``exp["pos_tags"]`` against the allowed
    POS set and all of ``exp["ner_tags"]`` against the allowed NER set.
    If you do not want any filtering, simply return True.

    Args:
        exp: a flattened example (mapping) with "pos_tags" and "ner_tags"
            lists of string labels.
        pos_allowed: optional set of permitted POS tags; defaults to the
            module-level ``allowed_pos``.
        ner_allowed: optional set of permitted NER tags; defaults to the
            module-level ``allowed_ner``.

    Returns:
        bool: True when both tag lists are fully contained in their
        allowed sets (vacuously True for empty lists).
    """
    if pos_allowed is None:
        pos_allowed = allowed_pos
    if ner_allowed is None:
        ner_allowed = allowed_ner

    # issuperset replaces the explicit membership loops with one
    # C-level containment check per tag list.
    return (pos_allowed.issuperset(exp["pos_tags"])
            and ner_allowed.issuperset(exp["ner_tags"]))
| |
|
| |
|
def transform_and_filter_dataset(onto_ds):
    """
    Flatten and filter an OntoNotes CoNLL-2012 DatasetDict.

    onto_ds is a DatasetDict with splits: 'train', 'validation', 'test', etc.
    Each input example is a document holding a list of sentences; this
    explodes every document into one example per sentence with columns:
      - "tokens": the sentence words
      - "ner_tags": NER labels decoded from class-label ints to strings
      - "pos_tags": POS labels decoded from class-label ints to strings
      - "verb_predicate": "Yes" where a predicate lemma is truthy, else "O"
    Sentences whose tags fall outside allowed_pos/allowed_ner are dropped
    via is_valid_example.

    Returns:
        DatasetDict with the same split names, flattened and filtered.
    """
    # int -> string lookup tables taken from the train split's features;
    # assumes all splits share the same label vocabulary — TODO confirm.
    pos_tag_int2str = onto_ds["train"].features["sentences"][0]["pos_tags"].feature.names
    ner_tag_int2str = onto_ds["train"].features["sentences"][0]["named_entities"].feature.names

    def flatten_ontonotes(batch):
        # Build one output row per sentence across every document in the batch.
        out = {
            "tokens": [],
            "ner_tags": [],
            "pos_tags": [],
            "verb_predicate": [],
        }
        # The document id was previously iterated in lockstep but never used,
        # so we walk the sentences column alone.
        for sents in batch["sentences"]:
            for sent_info in sents:
                out["tokens"].append(sent_info["words"])
                out["ner_tags"].append([ner_tag_int2str[i] for i in sent_info["named_entities"]])
                out["pos_tags"].append([pos_tag_int2str[i] for i in sent_info["pos_tags"]])
                # Any truthy predicate lemma collapses to the single label "Yes";
                # falsy entries (presumably None for non-predicates) become "O".
                out["verb_predicate"].append([("Yes" if s else "O") for s in sent_info["predicate_lemmas"]])
        return out

    new_splits = {}
    for split_name, split_ds in onto_ds.items():
        flattened_ds = split_ds.map(
            flatten_ontonotes,
            batched=True,
            remove_columns=["sentences", "document_id"],
        )
        # Drop sentences containing tags outside the allowed sets.
        new_splits[split_name] = flattened_ds.filter(is_valid_example)

    return DatasetDict(new_splits)
| |
|
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    import logging.config

    # CLI for one-off preprocessing runs of OntoNotes CoNLL-2012 (English v12).
    arg_parser = argparse.ArgumentParser(description="Process OntoNotes CoNLL-2012 (English).")
    arg_parser.add_argument("--log-level", help="Log level.", action="store",
                            default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
    arg_parser.add_argument("--save", help="Save final dataset to disk.", action="store_true", default=False)
    arg_parser.add_argument("--save-path", help="Where to save final dataset.",
                            default="./conll2012_en12_training_data")
    arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>", default=None)
    args = arg_parser.parse_args()

    logging.config.dictConfig(default_logging_config)
    logger.setLevel(args.log_level)

    # Download/load the HuggingFace copy of OntoNotes (english_v12 config).
    ontonotes_ds = load_dataset("conll2012_ontonotesv5", "english_v12")
    # Lazy %-style args instead of an f-string, matching the logging call
    # style used for the save message below.
    logger.info("Splits loaded: %s", ontonotes_ds)

    # Flatten documents to per-sentence examples and filter by allowed tags.
    final_dataset = transform_and_filter_dataset(ontonotes_ds)

    # Optional peek at examples, driven by the --show spec string (may be None).
    show_examples(final_dataset, args.show)

    # Presumably reports the unique labels in the training split — see utils.
    get_uniq_training_labels(final_dataset)

    if args.save:
        final_dataset.save_to_disk(args.save_path)
        logger.info("Saved dataset to %s", args.save_path)
|