import argparse
import logging

from datasets import load_dataset, DatasetDict

from utils import default_logging_config, get_uniq_training_labels, show_examples

logger = logging.getLogger(__name__)

# Penn Treebank POS tags used in OntoNotes (English); examples containing any
# other POS tag are dropped by is_valid_example below.
allowed_pos = {'``', '$', "''", ',', '-LRB-', '-RRB-', '.', ':', 'ADD', 'CC', 'CD', 'DT', 'EX',
               'FW', 'HYPH', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NFP', 'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS',
               'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ',
               'WDT', 'WP', 'WP$', 'WRB'}

# The 18 OntoNotes entity types in BIO encoding, plus the 'O' (outside) tag.
allowed_ner = {'O', 'B-PERSON', 'I-PERSON', 'B-NORP', 'I-NORP', 'B-FAC', 'I-FAC', 'B-ORG', 'I-ORG', 'B-GPE', 'I-GPE',
               'B-LOC', 'I-LOC', 'B-PRODUCT', 'I-PRODUCT', 'B-DATE', 'I-DATE', 'B-TIME', 'I-TIME', 'B-PERCENT',
               'I-PERCENT', 'B-MONEY', 'I-MONEY', 'B-QUANTITY', 'I-QUANTITY', 'B-ORDINAL', 'I-ORDINAL',
               'B-CARDINAL', 'I-CARDINAL', 'B-EVENT', 'I-EVENT', 'B-WORK_OF_ART', 'I-WORK_OF_ART', 'B-LAW', 'I-LAW',
               'B-LANGUAGE', 'I-LANGUAGE'}


def is_valid_example(exp):
    """
    Simple filter that checks if all pos_tags are in allowed_pos
    and all ner_tags are in allowed_ner. If you do not want any
    filtering, simply return True.
    """
    for pos_tag in exp["pos_tags"]:
        if pos_tag not in allowed_pos:
            return False
    for ner_tag in exp["ner_tags"]:
        if ner_tag not in allowed_ner:
            return False
    return True
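
# Quick sanity check of the filter (hypothetical records, not drawn from the
# corpus): every tag must appear in the corresponding allow-list to survive.
#   is_valid_example({"pos_tags": ["DT", "NN"], "ner_tags": ["O", "B-ORG"]})  # -> True
#   is_valid_example({"pos_tags": ["XX"], "ner_tags": ["O"]})                 # -> False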


def transform_and_filter_dataset(onto_ds):
    """
    onto_ds is a DatasetDict with splits: 'train', 'validation', 'test', etc.
    Return a new DatasetDict with the same splits but:
      - unwanted examples filtered out
      - columns renamed or removed as needed
      - new columns introduced as needed
    """
    # Lookup tables that map integer ClassLabel ids to their string names.
    pos_tag_int2str = onto_ds["train"].features["sentences"][0]["pos_tags"].feature.names
    ner_tag_int2str = onto_ds["train"].features["sentences"][0]["named_entities"].feature.names

    def flatten_ontonotes(batch):
        # OntoNotes groups sentences by document; flatten so that each output
        # row is a single sentence, with integer tag ids mapped to strings.
        out = {
            "tokens": [],
            "ner_tags": [],
            "pos_tags": [],
            "verb_predicate": [],
        }
        for sents in batch["sentences"]:
            for sent_info in sents:
                out["tokens"].append(sent_info["words"])
                out["ner_tags"].append([ner_tag_int2str[i] for i in sent_info["named_entities"]])
                out["pos_tags"].append([pos_tag_int2str[i] for i in sent_info["pos_tags"]])
                # predicate_lemmas holds None for non-predicate tokens, so this
                # marks predicate positions with "Yes" and all others with "O".
                out["verb_predicate"].append([("Yes" if s else "O") for s in sent_info["predicate_lemmas"]])
        return out

    new_splits = {}
    for split_name, split_ds in onto_ds.items():
        flattened_ds = split_ds.map(
            flatten_ontonotes,
            batched=True,
            remove_columns=["sentences", "document_id"],
        )
        filtered_split = flattened_ds.filter(is_valid_example)
        new_splits[split_name] = filtered_split
    return DatasetDict(new_splits)
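
# After the transform, every example is one sentence of parallel string lists,
# along the lines of (illustrative values only):
#   {"tokens": ["John", "slept", "."],
#    "pos_tags": ["NNP", "VBD", "."],
#    "ner_tags": ["B-PERSON", "O", "O"],
#    "verb_predicate": ["O", "Yes", "O"]}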


if __name__ == "__main__":
    import logging.config

    arg_parser = argparse.ArgumentParser(description="Process OntoNotes CoNLL-2012 (English).")
    arg_parser.add_argument("--log-level", help="Log level.", action="store",
                            default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
    arg_parser.add_argument("--save", help="Save final dataset to disk.", action="store_true", default=False)
    arg_parser.add_argument("--save-path", help="Where to save final dataset.",
                            default="./conll2012_en12_training_data")
    arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>", default=None)
    args = arg_parser.parse_args()

    logging.config.dictConfig(default_logging_config)
    logger.setLevel(args.log_level)

    # Downloads (or reuses the cached copy of) the English v12 configuration.
    ontonotes_ds = load_dataset("conll2012_ontonotesv5", "english_v12")
    logger.info("Splits loaded: %s", ontonotes_ds)

    final_dataset = transform_and_filter_dataset(ontonotes_ds)

    show_examples(final_dataset, args.show)

    get_uniq_training_labels(final_dataset)

    if args.save:
        final_dataset.save_to_disk(args.save_path)
        logger.info("Saved dataset to %s", args.save_path)
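
# Example invocation (the filename prepare_conll2012.py is hypothetical; use
# whatever this module is saved as):
#   python prepare_conll2012.py --save --show train/ner_tags/B-PERSON/3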