# multi-classifier / conll2012_dataset_maker.py
# Source snapshot: commit 406d54a ("feat: updated conll model") by veryfansome, 4.97 kB.
from datasets import load_dataset, DatasetDict
import argparse
import logging
from utils import default_logging_config, get_uniq_training_labels, show_examples
logger = logging.getLogger(__name__)

# POS tag strings accepted by is_valid_example(); a sentence containing any
# tag outside this set is filtered out of the final dataset.
allowed_pos = {
    '``', "''", '$', ',', '.', ':', '-LRB-', '-RRB-', 'HYPH', 'NFP', 'SYM',
    'ADD', 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'LS', 'MD', 'PDT', 'POS', 'RP', 'TO', 'UH',
    'JJ', 'JJR', 'JJS',
    'NN', 'NNS', 'NNP', 'NNPS',
    'PRP', 'PRP$',
    'RB', 'RBR', 'RBS',
    'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ',
    'WDT', 'WP', 'WP$', 'WRB',
}

# NER tag strings accepted by is_valid_example(): 'O' plus a B-/I- pair for
# each entity type listed below.
allowed_ner = {
    'O',
    'B-PERSON', 'I-PERSON',
    'B-NORP', 'I-NORP',
    'B-FAC', 'I-FAC',
    'B-ORG', 'I-ORG',
    'B-GPE', 'I-GPE',
    'B-LOC', 'I-LOC',
    'B-PRODUCT', 'I-PRODUCT',
    'B-DATE', 'I-DATE',
    'B-TIME', 'I-TIME',
    'B-PERCENT', 'I-PERCENT',
    'B-MONEY', 'I-MONEY',
    'B-QUANTITY', 'I-QUANTITY',
    'B-ORDINAL', 'I-ORDINAL',
    'B-CARDINAL', 'I-CARDINAL',
    'B-EVENT', 'I-EVENT',
    'B-WORK_OF_ART', 'I-WORK_OF_ART',
    'B-LAW', 'I-LAW',
    'B-LANGUAGE', 'I-LANGUAGE',
}
def is_valid_example(exp):
    """
    Return True when every tag of ``exp`` is in the allowed vocabularies.

    ``exp["pos_tags"]`` is checked against ``allowed_pos`` first; only if all
    POS tags pass is ``exp["ner_tags"]`` checked against ``allowed_ner``.
    Intended as a ``Dataset.filter`` predicate. To disable filtering, make
    this function return True unconditionally.
    """
    pos_ok = all(tag in allowed_pos for tag in exp["pos_tags"])
    # `and` short-circuits, so ner_tags is not read when a POS tag fails.
    return pos_ok and all(tag in allowed_ner for tag in exp["ner_tags"])
def transform_and_filter_dataset(onto_ds):
    """
    Flatten OntoNotes documents into one row per sentence, then filter rows.

    Parameters
    ----------
    onto_ds : DatasetDict
        Splits such as 'train', 'validation' and 'test'. Each row holds a
        ``sentences`` list whose items carry ``words``, ``pos_tags``,
        ``named_entities`` and ``predicate_lemmas``. A 'train' split is
        required, because the id-to-name tag tables are read from its features.

    Returns
    -------
    DatasetDict
        Has the same split names as the input. Every row is one sentence with
        these columns:

        - ``tokens``: the sentence's words.
        - ``ner_tags`` and ``pos_tags``: integer tag ids mapped to their
          string names.
        - ``verb_predicate``: "Yes" where the predicate lemma is truthy and
          "O" otherwise.

        All original columns are dropped. Rows for which
        ``is_valid_example`` returns False are removed.
    """
    sentence_features = onto_ds["train"].features["sentences"][0]
    pos_tag_int2str = sentence_features["pos_tags"].feature.names
    ner_tag_int2str = sentence_features["named_entities"].feature.names

    def flatten_ontonotes(batch):
        # Emits one output row per sentence, so a batch of N documents becomes
        # M sentence rows; the row count changes, which is why every input
        # column has to be removed in the map() call below.
        out = {
            "tokens": [],
            "ner_tags": [],
            "pos_tags": [],
            "verb_predicate": [],
        }
        for sents in batch["sentences"]:
            for sent_info in sents:
                out["tokens"].append(sent_info["words"])
                out["ner_tags"].append([ner_tag_int2str[i] for i in sent_info["named_entities"]])
                out["pos_tags"].append([pos_tag_int2str[i] for i in sent_info["pos_tags"]])
                out["verb_predicate"].append([("Yes" if s else "O") for s in sent_info["predicate_lemmas"]])
        return out

    new_splits = {}
    for split_name, split_ds in onto_ds.items():
        flattened_ds = split_ds.map(
            flatten_ontonotes,
            batched=True,
            # A batched map that changes the number of rows cannot keep any
            # input column (their lengths would no longer match), so drop all
            # of them instead of a hard-coded subset.
            remove_columns=split_ds.column_names,
        )
        new_splits[split_name] = flattened_ds.filter(is_valid_example)
    return DatasetDict(new_splits)
# ------------------------------------------------------------------------------
# Main script
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Command-line entry point: load OntoNotes (English v12), flatten and
    # filter it into sentence-level rows, optionally show examples, report
    # labels, and optionally save the result to disk.
    import logging.config

    arg_parser = argparse.ArgumentParser(description="Process OntoNotes CoNLL-2012 (English).")
    arg_parser.add_argument("--log-level", help="Log level.", action="store",
                            default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
    arg_parser.add_argument("--save", help="Save final dataset to disk.", action="store_true", default=False)
    arg_parser.add_argument("--save-path", help="Where to save final dataset.",
                            default="./conll2012_en12_training_data")
    arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>", default=None)
    args = arg_parser.parse_args()

    # NOTE(review): dictConfig disables loggers that already exist (such as
    # this module's `logger`) unless the config sets
    # disable_existing_loggers=False or names this logger. Confirm that
    # default_logging_config does one of these, or the logger.info calls below
    # will be silent.
    logging.config.dictConfig(default_logging_config)
    logger.setLevel(args.log_level)

    # Load the English v12 configuration of OntoNotes.
    # NOTE(review): this is a script-based dataset id; newer `datasets`
    # releases may require a different repo id. Confirm against the installed
    # version.
    ontonotes_ds = load_dataset("conll2012_ontonotesv5", "english_v12")
    logger.info("Splits loaded: %s", ontonotes_ds)

    # Flatten documents into sentence rows and drop rows with disallowed tags.
    final_dataset = transform_and_filter_dataset(ontonotes_ds)

    # Show examples selected by --show; presumably a no-op when it is None
    # (behavior lives in utils.show_examples, not visible here).
    show_examples(final_dataset, args.show)

    # Report the unique training labels via utils (return value unused here).
    get_uniq_training_labels(final_dataset)

    if args.save:
        final_dataset.save_to_disk(args.save_path)
        logger.info("Saved dataset to %s", args.save_path)