|
|
import curses
import logging
import os
import random
import uuid
from collections import Counter
from typing import Optional

import wcwidth
from datasets import Dataset, load_dataset

from examples import custom_examples
from util import naive_sentence_end_pattern, naive_tokenize
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
# Directory where the annotated dataset is persisted between sessions.
DATASET_PATH = "dataset"


def _bio(label_names):
    """Expand label names into a flat BIO tag list: [B-X, I-X, B-Y, I-Y, ...]."""
    return [tag for name in label_names for tag in (f"B-{name}", f"I-{name}")]


# Per-task label inventories. Each key names an annotation task (bracket/quote
# wrapping tasks, then entity-style categories); each value is that task's flat
# BIO tag list. The "O" (outside) tag is implicit and handled by the UI.
FEATURES = {
    "<>^": _bio(["LEFT", "RIGHT", "UP", "WRAP"]),
    "{}": _bio(["WRAP"]),
    "()": _bio(["WRAP"]),
    "[]": _bio(["WRAP"]),
    "''": _bio(["WRAP"]),
    '""': _bio(["WRAP"]),
    "``": _bio(["WRAP"]),
    "act": _bio(["DANCE", "GAME", "PROJECT"]),
    "addr": _bio(["CHAN", "DOOR", "EMAIL", "FAC", "FILE", "GEO", "IP", "MAIL",
                  "PHONE", "SITE", "URL"]),
    "concept": _bio(["ART", "BIO", "BIZ", "CHEM", "CLIM", "ECON", "EDU", "ENG",
                     "FIN", "FORMAT", "GEOG", "GEOL", "INFO", "LANG", "LAW",
                     "MATH", "ORG", "PHIL", "PHYS", "POLI", "PROG", "PSY",
                     "RELI", "SOC", "SPORTS", "WAR"]),
    "coord": _bio(["AND", "OR", "NEG"]),
    "error": _bio(["OMIT", "ORDER", "SPELL"]),
    "foreign": _bio(["ES", "FR", "HANS", "HANT", "JA", "LA", "LANG", "LOAN",
                     "PHONE", "TRANS"]),
    "media": _bio(["AUD", "IMG", "SOFT", "TXT", "VID"]),
    "nature": _bio(["FAUNA", "FLORA", "PHENOM"]),
    "num": _bio(["AGE", "COUNT", "DIST", "FRAC", "MASS", "MONEY", "ORD", "PCT",
                 "PCTILE", "RANGE", "SPEED", "WEIGHT"]),
    "org": _bio(["ORG", "TITLE"]),
    "other": _bio(["DIV", "EXP", "GT", "LT", "MATH", "MINUS", "MULT", "OUT",
                   "PLUS", "PROG", "SCI"]),
    "people": _bio(["GPE", "LANG", "NORP"]),
    "person": _bio(["ALIAS", "HONOR", "NAME", "PROF", "USER"]),
    "place": _bio(["BYTE", "FIC", "LOC", "UI", "VIRT", "WEB"]),
    "thing": _bio(["AWARD", "DEVICE", "FOOD"]),
    "time": _bio(["TIME", "EVENT"]),
}

# Meta-task whose labels are the other task names plus the special _all and
# _ambiguous markers. Assigned after the dict literal so that zz_prime itself
# is not among the keys it enumerates.
FEATURES["zz_prime"] = ["_all", "_ambiguous", *FEATURES.keys()]

# Namespace for deterministic uuid5 signatures computed from example text.
UUID5_NS = uuid.UUID("246a5463-afae-4571-a6e0-f319d74147d3")
|
|
|
|
|
|
|
|
def get_uniq_training_labels(ds: Dataset, columns_to_exclude: Optional[set[str]] = None):
    """Collect and log the unique label values for each trainable column of *ds*.

    Args:
        ds: dataset whose list-valued label columns are scanned.
        columns_to_exclude: column names to skip; defaults to the non-label
            columns {"text", "tokens", "sig"}.

    Returns:
        Dict mapping each trainable column name to the set of distinct label
        values observed in it.
    """
    excluded = {"text", "tokens", "sig"} if columns_to_exclude is None else columns_to_exclude
    columns_to_train_on = [k for k in ds.features.keys() if k not in excluded]

    # Per-column occurrence counts and distinct values.
    label_counters = {col: Counter() for col in columns_to_train_on}
    unique_label_values = {col: set() for col in columns_to_train_on}

    for example in ds:
        for col in columns_to_train_on:
            unique_label_values[col].update(example[col])
            label_counters[col].update(example[col])

    logger.info("Columns:")
    for col in columns_to_train_on:
        logger.info(f"  {col}:")
        vals = sorted(unique_label_values[col])
        logger.info(f"    {len(vals)} labels: {[f'{v}:{label_counters[col][v]}' for v in vals]}")

    return unique_label_values
|
|
|
|
|
|
|
|
def main(stdscr, args): |
|
|
wikipedia_dataset_name = "20231101.en" |
|
|
wikipedia_dataset = load_dataset("wikimedia/wikipedia", wikipedia_dataset_name) |
|
|
total_page_cnt = len(wikipedia_dataset["train"]) |
|
|
|
|
|
stdscr.clear() |
|
|
stdscr.addstr(f"Loaded {wikipedia_dataset_name} containing {total_page_cnt} pages.") |
|
|
|
|
|
new_dataset_dict = {k: [] for k in ["text", "tokens", *FEATURES.keys(), "sig"]} |
|
|
signature_cache = set() |
|
|
|
|
|
target_sig, target_col, target_idx, new_label = None, None, None, None |
|
|
if args.replace: |
|
|
args_replace_tokens = args.replace.split("/") |
|
|
target_sig, target_col, target_idx, new_label = args_replace_tokens |
|
|
if os.path.exists(DATASET_PATH): |
|
|
|
|
|
for i, exp in enumerate(Dataset.load_from_disk(DATASET_PATH)): |
|
|
sig = str(uuid.uuid5(UUID5_NS, exp["text"])) |
|
|
if sig in signature_cache or sig == args.redo: |
|
|
continue |
|
|
signature_cache.add(sig) |
|
|
if sig == target_sig: |
|
|
for k, v in exp.items(): |
|
|
if k == target_col: |
|
|
v[int(target_idx)] = new_label |
|
|
new_dataset_dict[k].append(v) |
|
|
else: |
|
|
for k, v in exp.items(): |
|
|
new_dataset_dict[k].append(v) |
|
|
|
|
|
esc_pressed = False |
|
|
while not esc_pressed: |
|
|
|
|
|
page = wikipedia_dataset["train"][random.randint(0, total_page_cnt)] |
|
|
|
|
|
for page_chunk in (custom_examples + page["text"].split("\n\n")): |
|
|
page_chunk = page_chunk.strip() |
|
|
if not page_chunk: |
|
|
continue |
|
|
page_chunk_lines = page_chunk.split("\n") |
|
|
for chunk_line in page_chunk_lines: |
|
|
chunk_line = chunk_line.strip() |
|
|
if not chunk_line: |
|
|
continue |
|
|
while not esc_pressed and chunk_line: |
|
|
sentence_end_match = naive_sentence_end_pattern.search(chunk_line) |
|
|
if sentence_end_match: |
|
|
sentence_blob = chunk_line[:sentence_end_match.end()] |
|
|
chunk_line = chunk_line[sentence_end_match.end():].strip() |
|
|
else: |
|
|
sentence_blob = chunk_line |
|
|
chunk_line = "" |
|
|
|
|
|
sig = str(uuid.uuid5(UUID5_NS, sentence_blob)) |
|
|
if sig in signature_cache: |
|
|
continue |
|
|
signature_cache.add(sig) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
exp_idx = len(new_dataset_dict["text"]) |
|
|
stdscr.addstr(f"""\n\n>>>{sentence_blob}<<< |
|
|
|
|
|
Press 'y' to accept or anything else to reject. |
|
|
Press Esc to exit. |
|
|
""") |
|
|
ch = stdscr.getch() |
|
|
stdscr.clear() |
|
|
if ch == 27: |
|
|
esc_pressed = True |
|
|
elif ch == ord("y"): |
|
|
naive_tokens = naive_tokenize(sentence_blob) |
|
|
tokens_len = len(naive_tokens) |
|
|
last_idx = tokens_len - 1 |
|
|
new_exp = { |
|
|
"text": sentence_blob, |
|
|
"tokens": naive_tokens, |
|
|
} |
|
|
|
|
|
for feat_name, feat_labels in FEATURES.items(): |
|
|
feat_labels_len = len(feat_labels) |
|
|
labels_accepted = False |
|
|
while not esc_pressed and not labels_accepted: |
|
|
labels = [] |
|
|
skip_to_idx = None |
|
|
skip_label = None |
|
|
for token_idx, token in enumerate(naive_tokens): |
|
|
if skip_to_idx is not None and skip_to_idx >= token_idx: |
|
|
labels.append(skip_label if skip_label is not None else "O") |
|
|
continue |
|
|
skip_to_idx = None |
|
|
skip_label = None |
|
|
padding_len = ( |
|
|
1 + wcwidth.wcswidth(", ".join([f"'{t}'" for t in naive_tokens[:token_idx]])) |
|
|
+ wcwidth.wcswidth(token)) + (0 if token_idx == 0 else 2) |
|
|
|
|
|
enter_pressed = False |
|
|
idx_blob = "" |
|
|
stdscr.clear() |
|
|
stdscr.addstr(f"""Example {exp_idx}: |
|
|
|
|
|
{"\n".join([f"{pad_to_desired_len(k)}{v}" for k, v in new_exp.items()])} |
|
|
{pad_to_desired_len(feat_name)}{labels} |
|
|
|
|
|
Labels: {", ".join([f"{i}:{l}" for i, l in enumerate(feat_labels)])} |
|
|
|
|
|
{naive_tokens} |
|
|
{" " * padding_len}^ |
|
|
{" " * padding_len}{token_idx} |
|
|
|
|
|
: """) |
|
|
while not esc_pressed and not enter_pressed: |
|
|
ch = stdscr.getch() |
|
|
if ch in {8, 127, curses.KEY_BACKSPACE}: |
|
|
idx_blob = idx_blob[:-1] |
|
|
y, x = stdscr.getyx() |
|
|
next_x = x - 1 |
|
|
if next_x > 1: |
|
|
stdscr.move(y, x - 1) |
|
|
stdscr.clrtoeol() |
|
|
stdscr.refresh() |
|
|
elif ch == 27: |
|
|
esc_pressed = True |
|
|
elif ch in {10, curses.KEY_ENTER}: |
|
|
enter_pressed = True |
|
|
else: |
|
|
|
|
|
ch_chr = chr(ch) |
|
|
stdscr.addstr(ch_chr) |
|
|
idx_blob += ch_chr |
|
|
if not idx_blob: |
|
|
label_blob = idx_blob if idx_blob else "O" |
|
|
labels.append(label_blob) |
|
|
elif ">" in idx_blob: |
|
|
try: |
|
|
idx_blob, skip_distance = idx_blob.split(">") |
|
|
if idx_blob: |
|
|
label_idx = int(idx_blob) |
|
|
if 0 <= label_idx < feat_labels_len: |
|
|
label_blob = feat_labels[label_idx] |
|
|
labels.append(label_blob) |
|
|
skip_label = label_blob |
|
|
else: |
|
|
labels.append("O") |
|
|
|
|
|
if skip_distance: |
|
|
skip_to_idx = token_idx + int(skip_distance) |
|
|
if skip_to_idx > last_idx: |
|
|
skip_to_idx = last_idx |
|
|
else: |
|
|
skip_to_idx = last_idx |
|
|
except ValueError: |
|
|
stdscr.addstr(f"Could not convert {idx_blob} to an integer idx value.") |
|
|
else: |
|
|
try: |
|
|
label_idx = int(idx_blob) |
|
|
if 0 <= label_idx < feat_labels_len: |
|
|
label_blob = feat_labels[label_idx] |
|
|
labels.append(label_blob) |
|
|
except ValueError: |
|
|
stdscr.addstr(f"Could not convert {idx_blob} to an integer idx value.") |
|
|
stdscr.clear() |
|
|
stdscr.addstr(f"""Example {exp_idx}: |
|
|
|
|
|
{"\n".join([f"{pad_to_desired_len(k)}{v}" for k, v in new_exp.items()])} |
|
|
{pad_to_desired_len(feat_name)}{labels} |
|
|
|
|
|
Press 'y' to accept or anything else to reject. |
|
|
Press Esc to exit.""") |
|
|
ch = stdscr.getch() |
|
|
stdscr.clear() |
|
|
if ch == 27: |
|
|
esc_pressed = True |
|
|
elif ch == ord("y"): |
|
|
new_exp[feat_name] = labels |
|
|
labels_accepted = True |
|
|
if esc_pressed: |
|
|
break |
|
|
|
|
|
new_exp["sig"] = sig |
|
|
if sorted(new_exp.keys()) == sorted(new_dataset_dict.keys()): |
|
|
for k, v in new_exp.items(): |
|
|
new_dataset_dict[k].append(v) |
|
|
|
|
|
stdscr.clear() |
|
|
return Dataset.from_dict(new_dataset_dict) |
|
|
|
|
|
|
|
|
def pad_to_desired_len(blob: str, desired: int = 15):
    """Right-pad *blob* with spaces to at least *desired* characters.

    Strings already at or beyond *desired* are returned unchanged. Padding is
    by code-point count (len), not terminal display width — wide characters
    may still misalign columns.
    """
    # str.ljust implements exactly the pad-or-return-unchanged behavior.
    return blob.ljust(desired)
|
|
|
|
|
|
|
|
def show_examples(ds: Dataset, show_expr: Optional[str]):
    """Log a sample of examples from *ds*.

    Args:
        ds: dataset to sample from.
        show_expr: optional "<col>/<label>/<count>" filter; when given, only
            examples whose *col* column contains *label* are shown. When
            falsy, up to 25 random examples are shown.
    """
    if not show_expr:
        # Random sample, capped at 25. Slicing a shuffled Dataset yields a
        # dict of column-name -> list of values.
        count_to_show = min(len(ds), 25)
        examples_to_show = ds.shuffle()[:count_to_show]
    else:
        col_to_show, label_to_show, count_to_show = show_expr.split("/")
        count_to_show = int(count_to_show)
        examples_to_show = ds.filter(
            lambda exp: label_to_show in exp[col_to_show]).shuffle(seed=42)[:count_to_show]
        # BUG FIX: the filter may match fewer examples than requested; clamp
        # to the actual sliced length to avoid an IndexError below.
        count_to_show = min(
            count_to_show,
            min((len(v) for v in examples_to_show.values()), default=0))
    for i in range(count_to_show):
        logger.info(f"Example {i}:")
        for feature in examples_to_show.keys():
            logger.info(f"  {feature}: {examples_to_show[feature][i]}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse
    import logging.config

    # Command-line interface for the annotation tool.
    arg_parser = argparse.ArgumentParser(description="Train multi-task model.")
    arg_parser.add_argument("--redo", action="store", default=None,
                            help="Redo example based on signature")
    arg_parser.add_argument("--replace", action="store", default=None,
                            help="Replace a label using a sig, col, idx, and new label")
    arg_parser.add_argument("--show", action="store", default=None,
                            help="Show examples: <col>/<label>/<count>")
    parsed_args = arg_parser.parse_args()

    # Route every logger to the console at INFO with a timestamped format.
    log_config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "default": {
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            },
        },
        "handlers": {
            "console": {
                "class": "logging.StreamHandler",
                "formatter": "default",
            },
        },
        "loggers": {
            "": {
                "level": "INFO",
                "handlers": ["console"],
            },
        },
    }
    logging.config.dictConfig(log_config)

    # Run the curses annotation UI, then report on and persist the result.
    new_ds = curses.wrapper(main, parsed_args)
    logger.info(f"Writing dataset to disk...\n{new_ds}")
    show_examples(new_ds, parsed_args.show)
    get_uniq_training_labels(new_ds)
    new_ds.save_to_disk(DATASET_PATH)
|
|
|