from datasets import load_dataset, DatasetDict
import argparse
import logging

from utils import default_logging_config, get_uniq_training_labels, show_examples

logger = logging.getLogger(__name__)


# POS tags accepted by the filter: the Penn Treebank tagset as used by
# OntoNotes, including punctuation tags and additions such as ADD, HYPH, and NFP.
allowed_pos = {'``', '$', "''", ',', '-LRB-', '-RRB-', '.', ':', 'ADD', 'CC', 'CD', 'DT', 'EX',
               'FW', 'HYPH', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NFP', 'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS',
               'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ',
               'WDT', 'WP', 'WP$', 'WRB'}

# NER tags accepted by the filter: the 18 OntoNotes v5 entity types in BIO
# encoding (36 B-/I- tags) plus the 'O' (outside) tag.
allowed_ner = {'O', 'B-PERSON', 'I-PERSON', 'B-NORP', 'I-NORP', 'B-FAC', 'I-FAC', 'B-ORG', 'I-ORG', 'B-GPE', 'I-GPE',
               'B-LOC', 'I-LOC', 'B-PRODUCT', 'I-PRODUCT', 'B-DATE', 'I-DATE', 'B-TIME', 'I-TIME', 'B-PERCENT',
               'I-PERCENT', 'B-MONEY', 'I-MONEY', 'B-QUANTITY', 'I-QUANTITY', 'B-ORDINAL', 'I-ORDINAL',
               'B-CARDINAL', 'I-CARDINAL', 'B-EVENT', 'I-EVENT', 'B-WORK_OF_ART', 'I-WORK_OF_ART', 'B-LAW', 'I-LAW',
               'B-LANGUAGE', 'I-LANGUAGE'}


def is_valid_example(exp):
    """
    Simple filter that checks whether all pos_tags are in allowed_pos
    and all ner_tags are in allowed_ner. If you do not want any
    filtering, simply return True.
    """
    # You can skip filtering by just returning True:
    # return True

    # Each example carries one tag per token, so check every token's tags:
    for pos_tag in exp["pos_tags"]:
        if pos_tag not in allowed_pos:
            return False

    for ner_tag in exp["ner_tags"]:
        if ner_tag not in allowed_ner:
            return False

    return True
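
# Illustrative behaviour of is_valid_example (hand-picked tags, not corpus data):
#   is_valid_example({"pos_tags": ["NNP", "VBZ"], "ner_tags": ["B-PERSON", "O"]})  # -> True
#   is_valid_example({"pos_tags": ["XX"], "ner_tags": ["O"]})   # -> False ('XX' is not whitelisted)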


def transform_and_filter_dataset(onto_ds):
    """
    onto_ds is a DatasetDict with splits: 'train', 'validation', 'test', etc.
    Return a new DatasetDict with the same splits in which:
      - each document is flattened into one example per sentence,
      - integer POS/NER class ids are mapped back to string labels,
      - examples whose tags fall outside the allowed sets are filtered out.
    """
    # Recover the string label inventories from the dataset's ClassLabel features.
    pos_tag_int2str = onto_ds["train"].features["sentences"][0]["pos_tags"].feature.names
    ner_tag_int2str = onto_ds["train"].features["sentences"][0]["named_entities"].feature.names

    def flatten_ontonotes(batch):
        """Turn document-level examples into one example per sentence."""
        out = {
            "tokens": [],
            "ner_tags": [],
            "pos_tags": [],
            "verb_predicate": [],
        }
        for sents in batch["sentences"]:
            for sent_info in sents:
                out["tokens"].append(sent_info["words"])
                # Map integer class ids back to their string labels.
                out["ner_tags"].append([ner_tag_int2str[i] for i in sent_info["named_entities"]])
                out["pos_tags"].append([pos_tag_int2str[i] for i in sent_info["pos_tags"]])
                # predicate_lemmas holds a lemma for predicate tokens and None
                # elsewhere; reduce it to a binary "Yes"/"O" marker per token.
                out["verb_predicate"].append([("Yes" if s else "O") for s in sent_info["predicate_lemmas"]])
        return out

    new_splits = {}
    for split_name, split_ds in onto_ds.items():
        # Flatten
        flattened_ds = split_ds.map(
            flatten_ontonotes,
            batched=True,
            remove_columns=["sentences", "document_id"],  # remove old columns
        )

        # Filter out invalid examples
        filtered_split = flattened_ds.filter(is_valid_example)
        new_splits[split_name] = filtered_split

    return DatasetDict(new_splits)
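
# After flattening, each example in every split has this shape (values are
# illustrative, not taken from the corpus):
#   {
#       "tokens":         ["John", "lives", "in", "Boston", "."],
#       "pos_tags":       ["NNP", "VBZ", "IN", "NNP", "."],
#       "ner_tags":       ["B-PERSON", "O", "O", "B-GPE", "O"],
#       "verb_predicate": ["O", "Yes", "O", "O", "O"],
#   }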


# ------------------------------------------------------------------------------
# 6) Main Script
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    import logging.config

    arg_parser = argparse.ArgumentParser(description="Process OntoNotes CoNLL-2012 (English).")
    arg_parser.add_argument("--log-level", help="Log level.", action="store",
                            default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
    arg_parser.add_argument("--save", help="Save final dataset to disk.", action="store_true", default=False)
    arg_parser.add_argument("--save-path", help="Where to save final dataset.", default="./conll2012_en12_training_data")
    arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>", default=None)
    args = arg_parser.parse_args()
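
    # Example invocation (the script file name is hypothetical):
    #   python prepare_ontonotes.py --log-level DEBUG --save \
    #       --save-path ./conll2012_en12_training_data --show train/ner_tags/B-PERSON/5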

    logging.config.dictConfig(default_logging_config)
    logger.setLevel(args.log_level)

    # 6a) Load OntoNotes (English) from the 'conll2012_ontonotesv5' script
    #     This usually yields "train", "validation", "test" splits.
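    #     NOTE: this is a script-based dataset; depending on your installed
    #     `datasets` version, the call may need trust_remote_code=True.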
    ontonotes_ds = load_dataset("conll2012_ontonotesv5", "english_v12")
    logger.info("Splits loaded: %s", ontonotes_ds)

    # 6b) Transform & Filter
    final_dataset = transform_and_filter_dataset(ontonotes_ds)

    # 6c) Show examples if the user requested them
    show_examples(final_dataset, args.show)

    # 6d) Log unique training labels (POS/NER) if you like
    get_uniq_training_labels(final_dataset)

    # 6e) Save to disk if requested
    if args.save:
        final_dataset.save_to_disk(args.save_path)
        logger.info("Saved dataset to %s", args.save_path)
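        # The saved dataset can be reloaded later with (usage sketch):
        #   from datasets import load_from_disk
        #   ds = load_from_disk(args.save_path)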