# ner-explorer / dataset_maker.py
# Commit 051eb53 (veryfansome): "feat: functional CLI editor"
import curses
import logging
import os
import random
import uuid
from collections import Counter
from typing import Optional

import wcwidth
from datasets import Dataset, load_dataset

from examples import custom_examples
from util import naive_sentence_end_pattern, naive_tokenize
logger = logging.getLogger(__name__)

# Directory the labeled dataset is loaded from and saved to.
DATASET_PATH = "dataset"


def _bio_tags(*entity_types):
    """Return BIO tags for ``entity_types``, interleaved as B-X, I-X, B-Y, I-Y, ..."""
    return [f"{prefix}-{entity}" for entity in entity_types for prefix in ("B", "I")]


# Label sets for each feature (one dataset column per feature). Every feature
# lists a "B-" (begin) and "I-" (inside) tag per entity type; the editor labels
# tokens outside any entity as "O".
FEATURES = {
    # Punctuation features (each key is the characters themselves).
    "<>^": _bio_tags("LEFT", "RIGHT", "UP", "WRAP"),
    "{}": _bio_tags("WRAP"),
    "()": _bio_tags("WRAP"),
    "[]": _bio_tags("WRAP"),
    "''": _bio_tags("WRAP"),
    '""': _bio_tags("WRAP"),
    "``": _bio_tags("WRAP"),
    "act": _bio_tags(
        "DANCE",  # dance
        "GAME",  # game
        "PROJECT",  # project
    ),
    "addr": _bio_tags(
        "CHAN",  # radio frequency, TV channel or station name, e.g. 107.7 "The Bone", CBS, CNN, PBS
        "DOOR",  # apartment, door, or suite number
        "EMAIL",  # email address
        "FAC",  # facility address or name of a specific physical building
        "FILE",  # file name and path
        "GEO",  # geo-coordinates
        "IP",  # IP address or CIDR notation
        "MAIL",  # physical mailbox or P.O. box
        "PHONE",  # telephone or fax number
        "SITE",  # DNS domain name or website name
        "URL",  # URL parts that are not EMAIL, FILE, IP, or SITE
    ),
    "concept": _bio_tags(
        "ART",  # art, music, or literary concept
        "BIO",  # biology or medical concept
        "BIZ",  # business or marketing concept
        "CHEM",  # chemistry or biochemistry concept
        "CLIM",  # climate or ocean-science concept
        "ECON",  # economics concept
        "EDU",  # education concept
        "ENG",  # engineering concept
        "FIN",  # finance or investment concept
        "FORMAT",  # formatting concept, e.g. list, outline, paragraph, table, figure
        "GEOG",  # geography concept
        "GEOL",  # geology concept
        "INFO",  # computing, data, or information-science concept
        "LANG",  # linguistics concept
        "LAW",  # legal concept
        "MATH",  # math concept
        "ORG",  # organizational concept
        "PHIL",  # ethical or philosophical concept
        "PHYS",  # physics concept
        "POLI",  # sociological or political concept
        "PROG",  # computer-programming concept
        "PSY",  # psychological concept
        "RELI",  # religious concept
        "SOC",  # sociology concept
        "SPORTS",  # sports concept
        "WAR",  # military concept
    ),
    "coord": _bio_tags(
        "AND",
        "OR",  # or; "nor" is treated as negatives joined by AND
        "NEG",  # negation
    ),
    "error": _bio_tags(
        "OMIT",  # value omitted or missing because of formatting or redaction
        "ORDER",  # word-order problem
        "SPELL",  # spelling error
    ),
    "foreign": _bio_tags(
        "ES",  # Spanish
        "FR",  # French
        "HANS",  # simplified Chinese
        "HANT",  # traditional Chinese
        "JA",  # Japanese
        "LA",  # Latin
        "LANG",  # marker naming the language of the following foreign token
        "LOAN",  # loanword: English word based on a foreign sound
        "PHONE",  # phonetic spelling, formal (e.g. Hepburn romanization) or otherwise
        "TRANS",  # marker introducing a translation
    ),
    "media": _bio_tags(
        "AUD",  # music and audio recordings
        "IMG",  # photos, paintings, and other images
        "SOFT",  # software
        "TXT",  # articles, books, papers, etc.
        "VID",  # film and other video
    ),
    "nature": _bio_tags(
        "FAUNA",  # animal life
        "FLORA",  # plant life
        "PHENOM",  # natural phenomena
    ),
    "num": _bio_tags(
        "AGE",  # age
        "COUNT",  # count
        "DIST",  # distance
        "FRAC",  # fraction
        "MASS",  # mass
        "MONEY",  # currency amount
        "ORD",  # ordinal
        "PCT",  # percent
        "PCTILE",  # percentile
        "RANGE",  # numeric range
        "SPEED",  # speed
        "WEIGHT",  # weight (force due to gravity)
    ),
    "org": _bio_tags(
        "ORG",  # organization
        "TITLE",  # title or role
    ),
    "other": _bio_tags(
        "DIV",  # / or ÷
        "EXP",  # exponent, e.g. ^
        "GT",  # >
        "LT",  # <
        "MATH",  # non-arithmetic math notation
        "MINUS",  # -
        "MULT",  # x, X, or *
        "OUT",  # program output, e.g. stderr/stdout, logs
        "PLUS",  # +
        "PROG",  # programming notation
        "SCI",  # scientific notation outside math and programming
    ),
    "people": _bio_tags(
        "GPE",  # geopolitical entity, e.g. country, city, state, or region
        "LANG",  # language
        "NORP",  # nationality, religious or political group, e.g. "American", "Muslim", "Communist"
    ),
    # Person or personified being.
    "person": _bio_tags(
        "ALIAS",  # nickname or alternative name
        "HONOR",  # honorific
        "NAME",  # person name
        "PROF",  # profession or professional designation, e.g. CFA, CPA, MD
        "USER",  # username
    ),
    "place": _bio_tags(
        "BYTE",  # digital location
        "FIC",  # fictional location
        "LOC",  # physical location
        "UI",  # location on a user interface
        "VIRT",  # virtual location
        "WEB",  # web-connected location
    ),
    "thing": _bio_tags(
        "AWARD",  # named accolade or honorary award
        "DEVICE",  # device, tool, or toy
        "FOOD",  # food
    ),
    "time": _bio_tags(
        "TIME",  # years, dates, and time values
        "EVENT",  # event in time
    ),
}

# Primary feature ("zz_" prefix so it's labeled last). Its choices are "_all",
# "_ambiguous", or the name of any feature defined above.
# TODO: might be multi-label
FEATURES["zz_prime"] = ["_all", "_ambiguous", *FEATURES.keys()]

# uuid5 namespace for sentence signatures; changing it changes every signature.
UUID5_NS = uuid.UUID("246a5463-afae-4571-a6e0-f319d74147d3")
def get_uniq_training_labels(ds: Dataset, columns_to_exclude: Optional[set[str]] = None):
    """Log every label value (with its frequency) per label column and return the unique values.

    Args:
        ds: Dataset whose label columns hold one list entry per token.
        columns_to_exclude: Columns to skip. Defaults to ``{"text", "tokens", "sig"}``;
            passing a set replaces that default rather than extending it.

    Returns:
        Dict mapping each label column name to the set of values seen in that column.
    """
    if columns_to_exclude is None:
        columns_to_exclude = {"text", "tokens", "sig"}
    columns_to_train_on = [k for k in ds.features.keys() if k not in columns_to_exclude]

    # Per-column label frequencies; each Counter's keys are also the unique labels.
    label_counters = {col: Counter() for col in columns_to_train_on}
    for example in ds:
        for col in columns_to_train_on:
            label_counters[col].update(example[col])

    logger.info("Columns:")
    for col in columns_to_train_on:
        counts = label_counters[col]
        logger.info(f" {col}:")
        # Sorted for a stable, readable ordering.
        vals = sorted(counts)
        logger.info(f" {len(vals)} labels: {[f'{v}:{counts[v]}' for v in vals]}")
    return {col: set(counts) for col, counts in label_counters.items()}
def main(stdscr, args):
    """Run the interactive curses labeling session and return the resulting dataset.

    Examples saved under ``DATASET_PATH`` are loaded first. ``args.redo`` drops one of
    them so it can be relabeled, and ``args.replace`` patches one label. The user is
    then offered sentences: ``custom_examples`` first, then sentences from a random
    Wikipedia page. For each accepted sentence, every token is labeled for every
    feature in ``FEATURES``.

    Keys:
        Sentence prompt: 'y' accepts, Esc exits, any other key skips the sentence.
        Token prompt: type a label index and press Enter; an empty entry means "O".
            ``<idx>><n>`` gives this token and the next ``n`` tokens label ``idx``;
            ``<idx>>`` extends it to the end of the sentence; omitting ``idx``
            before ``>`` means "O". Invalid entries re-prompt. Esc exits.
        Feature review: 'y' accepts the labels, Esc exits, any other key relabels.

    Args:
        stdscr: Curses window supplied by ``curses.wrapper``.
        args: Parsed CLI arguments; ``redo`` and ``replace`` are read.

    Returns:
        A ``Dataset`` with the loaded examples plus every fully labeled new example.
        A sentence abandoned with Esc part-way through labeling is not included.
    """
    wikipedia_dataset_name = "20231101.en"
    wikipedia_dataset = load_dataset("wikimedia/wikipedia", wikipedia_dataset_name)
    total_page_cnt = len(wikipedia_dataset["train"])
    stdscr.clear()
    stdscr.addstr(f"Loaded {wikipedia_dataset_name} containing {total_page_cnt} pages.")

    # Column order of the returned dataset.
    new_dataset_dict = {k: [] for k in ["text", "tokens", *FEATURES.keys(), "sig"]}
    # Signatures already stored or offered; each sentence is shown at most once.
    signature_cache = set()
    _load_previous_examples(new_dataset_dict, signature_cache, args)

    esc_pressed = False
    while not esc_pressed:
        # randrange excludes its upper bound; randint(0, total_page_cnt) would include
        # total_page_cnt, which is out of range.
        page = wikipedia_dataset["train"][random.randrange(total_page_cnt)]
        # Custom examples are offered first on every pass; once all of them are in
        # signature_cache, only the Wikipedia sentences remain.
        chunks = custom_examples + page["text"].split("\n\n")
        for sentence_blob in _iter_sentences(chunks):
            sig = str(uuid.uuid5(UUID5_NS, sentence_blob))
            if sig in signature_cache:
                continue
            signature_cache.add(sig)
            # TODO: sentence context
            # - prefix each text with a context blob that gets tokenized with the text
            # - label context blobs as B-CTXT and I-CTXT
            # - this way, contextual information from outside the direct text can be injected
            # - this allows injecting contexts from what we've already processed on the page
            # - use a unique signal sequences to signal contexts, e.g.:
            # - {{[[((prev:a,b;last:c,d))]]}}>>>
            exp_idx = len(new_dataset_dict["text"])
            stdscr.addstr(
                f"\n\n>>>{sentence_blob}<<<\n"
                "Press 'y' to accept or anything else to reject.\n"
                "Press Esc to exit.\n"
            )
            ch = stdscr.getch()
            stdscr.clear()
            if ch == 27:  # Esc
                esc_pressed = True
                break
            if ch != ord("y"):
                continue
            new_exp = _label_sentence(stdscr, exp_idx, sentence_blob)
            if new_exp is None:  # Esc mid-labeling: drop the partial example
                esc_pressed = True
                break
            new_exp["sig"] = sig
            for k, column in new_dataset_dict.items():
                column.append(new_exp[k])
    # Exiting
    stdscr.clear()
    return Dataset.from_dict(new_dataset_dict)


def _load_previous_examples(new_dataset_dict, signature_cache, args):
    """Append the examples saved under DATASET_PATH to ``new_dataset_dict`` in place.

    Duplicate texts are skipped, and so is the example whose signature equals
    ``args.redo`` (so it can be labeled again). With ``args.replace`` set to
    ``<sig>/<col>/<idx>/<label>``, entry ``idx`` of column ``col`` in the matching
    example is overwritten with ``label``. Loaded signatures go into ``signature_cache``.

    Raises:
        ValueError: if ``args.replace`` does not have four '/'-separated parts, or if
            the saved dataset's columns differ from ``new_dataset_dict``'s.
    """
    target_sig, target_col, target_idx, new_label = None, None, None, None
    if args.replace:
        target_sig, target_col, target_idx, new_label = args.replace.split("/")
    if not os.path.exists(DATASET_PATH):
        return
    previous_ds = Dataset.load_from_disk(DATASET_PATH)
    # Fail now rather than in Dataset.from_dict at the end of a labeling session:
    # rows with a different column set would leave the columns with unequal lengths.
    if set(previous_ds.column_names) != set(new_dataset_dict):
        raise ValueError(
            f"Saved dataset at {DATASET_PATH} has columns {sorted(previous_ds.column_names)}, "
            f"but the editor expects {sorted(new_dataset_dict)}."
        )
    for exp in previous_ds:
        sig = str(uuid.uuid5(UUID5_NS, exp["text"]))
        if sig in signature_cache or sig == args.redo:
            continue
        signature_cache.add(sig)
        if sig == target_sig and target_col in exp:
            exp[target_col][int(target_idx)] = new_label
        for k, v in exp.items():
            new_dataset_dict[k].append(v)


def _iter_sentences(chunks):
    """Yield non-empty sentences from text chunks, split on newlines and naive_sentence_end_pattern."""
    for chunk in chunks:
        for line in chunk.split("\n"):
            line = line.strip()
            while line:
                match = naive_sentence_end_pattern.search(line)
                # A zero-width match at position 0 would never consume input; treat
                # the rest of the line as one sentence instead of looping forever.
                if match and match.end() > 0:
                    yield line[:match.end()]
                    line = line[match.end():].strip()
                else:
                    yield line
                    line = ""


def _label_sentence(stdscr, exp_idx, sentence_blob):
    """Tokenize ``sentence_blob`` and collect accepted labels for every FEATURES column.

    Returns the example dict (``text``, ``tokens``, one label list per feature), or
    None if the user pressed Esc.
    """
    new_exp = {"text": sentence_blob, "tokens": naive_tokenize(sentence_blob)}
    for feat_name, feat_labels in FEATURES.items():
        labels = _label_feature(stdscr, exp_idx, new_exp, feat_name, feat_labels)
        if labels is None:
            return None
        new_exp[feat_name] = labels
    return new_exp


def _label_feature(stdscr, exp_idx, new_exp, feat_name, feat_labels):
    """Prompt for one feature's token labels until the user accepts them; None on Esc."""
    while True:
        labels = _collect_token_labels(stdscr, exp_idx, new_exp, feat_name, feat_labels)
        if labels is None:
            return None
        stdscr.clear()
        stdscr.addstr(
            f"Example {exp_idx}:\n"
            f"{_format_example(new_exp)}\n"
            f"{pad_to_desired_len(feat_name)}{labels}\n"
            "Press 'y' to accept or anything else to reject.\n"
            "Press Esc to exit."
        )
        ch = stdscr.getch()
        stdscr.clear()
        if ch == 27:  # Esc
            return None
        if ch == ord("y"):
            return labels


def _collect_token_labels(stdscr, exp_idx, new_exp, feat_name, feat_labels):
    """Prompt for one feature's label on each token; returns one label per token, or None on Esc.

    Invalid entries re-prompt the same token with an error message, so no token is
    ever left without a label.
    """
    tokens = new_exp["tokens"]
    last_idx = len(tokens) - 1
    labels = []
    skip_to_idx, skip_label = None, None
    for token_idx, token in enumerate(tokens):
        # Tokens covered by an earlier "<idx>><n>" entry reuse its label without prompting.
        if skip_to_idx is not None and token_idx <= skip_to_idx:
            labels.append("O" if skip_label is None else skip_label)
            continue
        error_msg = ""
        while True:
            _draw_token_prompt(stdscr, exp_idx, new_exp, feat_name, feat_labels,
                               labels, token_idx, token, error_msg)
            idx_blob = _read_input(stdscr)
            if idx_blob is None:  # Esc
                return None
            try:
                label, skip_to_idx, skip_label = _parse_label_input(
                    idx_blob, feat_labels, token_idx, last_idx)
            except ValueError as err:
                error_msg = str(err)
                continue
            labels.append(label)
            break
    return labels


def _parse_label_input(idx_blob, feat_labels, token_idx, last_idx):
    """Parse one token-prompt entry into ``(label, skip_to_idx, skip_label)``.

    ``""`` -> "O". ``"<idx>"`` -> ``feat_labels[idx]``. ``"<idx>><n>"`` additionally
    extends the label to token ``token_idx + n`` (clamped to ``last_idx``);
    ``"<idx>>"`` extends it to ``last_idx``; an empty ``<idx>`` before ``>`` means "O".

    Raises:
        ValueError: on a non-integer part or a label index outside ``feat_labels``.
    """
    if not idx_blob:
        return "O", None, None
    label_part, sep, distance_part = idx_blob.partition(">")
    label, skip_label = "O", None
    if label_part:
        try:
            label_idx = int(label_part)
        except ValueError:
            raise ValueError(f"Could not convert {label_part} to an integer idx value.") from None
        if not 0 <= label_idx < len(feat_labels):
            raise ValueError(f"Label idx {label_idx} is out of range 0-{len(feat_labels) - 1}.")
        label = skip_label = feat_labels[label_idx]
    if not sep:
        return label, None, None
    if not distance_part:
        return label, last_idx, skip_label
    try:
        distance = int(distance_part)
    except ValueError:
        raise ValueError(f"Could not convert {distance_part} to an integer idx value.") from None
    return label, min(token_idx + distance, last_idx), skip_label


def _read_input(stdscr):
    """Read keys until Enter and return the typed text, or None if Esc is pressed.

    Backspace/Delete removes the last character; keys outside printable ASCII
    (arrow keys, -1 from getch, ...) are ignored.
    """
    typed = ""
    while True:
        ch = stdscr.getch()
        if ch == 27:  # Esc
            return None
        if ch in (10, 13, curses.KEY_ENTER):  # Enter
            return typed
        if ch in (8, 127, curses.KEY_BACKSPACE):  # Delete
            if typed:
                typed = typed[:-1]
                y, x = stdscr.getyx()
                if x > 0:
                    stdscr.move(y, x - 1)
                    stdscr.clrtoeol()
                    stdscr.refresh()
        elif 32 <= ch < 127:
            ch_chr = chr(ch)
            stdscr.addstr(ch_chr)
            typed += ch_chr


def _display_width(text):
    """Terminal column width of ``text``; falls back to len() when wcswidth reports -1."""
    width = wcwidth.wcswidth(text)
    return len(text) if width < 0 else width


def _format_example(new_exp):
    """One padded ``<key><value>`` line per entry of ``new_exp``, newline-joined."""
    return "\n".join(f"{pad_to_desired_len(k)}{v}" for k, v in new_exp.items())


def _draw_token_prompt(stdscr, exp_idx, new_exp, feat_name, feat_labels, labels, token_idx, token, error_msg):
    """Redraw the screen asking for ``feat_name``'s label on token ``token_idx``."""
    tokens = new_exp["tokens"]
    # Column of the current token's last character when the token list is printed
    # as "['a', 'bb', ...]"; the caret and index are drawn under it.
    padding_len = (
        1 + _display_width(", ".join(f"'{t}'" for t in tokens[:token_idx]))
        + _display_width(token)
    ) + (0 if token_idx == 0 else 2)
    label_menu = ", ".join(f"{i}:{label}" for i, label in enumerate(feat_labels))
    error_line = f"{error_msg}\n" if error_msg else ""
    stdscr.clear()
    stdscr.addstr(
        f"Example {exp_idx}:\n"
        f"{_format_example(new_exp)}\n"
        f"{pad_to_desired_len(feat_name)}{labels}\n"
        f"Labels: {label_menu}\n"
        f"{tokens}\n"
        f"{' ' * padding_len}^\n"
        f"{' ' * padding_len}{token_idx}\n"
        f"{error_line}"
        ": "
    )
def pad_to_desired_len(blob: str, desired: int = 15):
    """Left-align ``blob`` by appending spaces until it is ``desired`` characters long.

    Strings that are already ``desired`` characters or longer are returned unchanged.
    """
    return blob.ljust(desired)
def show_examples(ds: Dataset, show_expr: Optional[str]):
    """Log a sample of examples from ``ds``, one feature per line.

    Args:
        ds: Dataset to sample from.
        show_expr: Falsy to log up to 25 randomly shuffled examples, or
            ``"<col>/<label>/<count>"`` to log up to ``count`` examples whose
            ``col`` column contains ``label`` (shuffled with seed 42).

    Raises:
        ValueError: if ``show_expr`` does not have three '/'-separated parts or
            ``count`` is not an integer.
    """
    if not show_expr:
        selected = ds.shuffle()
        count_to_show = min(len(selected), 25)
    else:
        col_to_show, label_to_show, requested_count = show_expr.split("/")
        requested_count = int(requested_count)
        selected = ds.filter(lambda exp: label_to_show in exp[col_to_show]).shuffle(seed=42)
        # Fewer examples may match than were requested; indexing past them raised IndexError.
        count_to_show = max(0, min(len(selected), requested_count))
    # Slicing a Dataset returns a dict of column name -> list of values.
    examples_to_show = selected[:count_to_show]
    for i in range(count_to_show):
        logger.info(f"Example {i}:")
        for feature, values in examples_to_show.items():
            logger.info(f" {feature}: {values[i]}")
if __name__ == "__main__":
    import argparse
    import logging.config

    arg_parser = argparse.ArgumentParser(
        description="Interactively label sentences and save them as a token-labeling dataset.")
    arg_parser.add_argument("--redo",
                            help="Redo example based on signature",
                            action="store", default=None)
    arg_parser.add_argument("--replace",
                            help="Replace a label using a sig, col, idx, and new label",
                            action="store", default=None)
    arg_parser.add_argument("--show",
                            help="Show examples: <col>/<label>/<count>",
                            action="store", default=None)
    parsed_args = arg_parser.parse_args()

    logging.config.dictConfig({
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "default": {
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            },
        },
        "handlers": {
            "console": {
                "class": "logging.StreamHandler",
                "formatter": "default",
            },
        },
        "loggers": {
            "": {
                "level": "INFO",
                "handlers": ["console"],
            },
        },
    })

    new_ds = curses.wrapper(main, parsed_args)
    logger.info(f"Writing dataset to disk...\n{new_ds}")
    # Save before reporting: a malformed --show spec or any error in the reporting
    # steps below must not discard the labels from this session.
    new_ds.save_to_disk(DATASET_PATH)
    show_examples(new_ds, parsed_args.show)
    get_uniq_training_labels(new_ds)