logger = logging.getLogger(__name__)

# Directory the labeled dataset is loaded from when it already exists.
DATASET_PATH = "dataset"


def _bio_tags(*names: str) -> list[str]:
    """Expand each label name into its ``B-``/``I-`` tag pair, keeping order."""
    return [f"{prefix}-{name}" for name in names for prefix in ("B", "I")]


# Token-level label sets, keyed by feature (dataset column) name.
# Order matters: the annotator picks a label by its position in the list.
FEATURES = {
    "<>^": _bio_tags(
        "LEFT",
        "RIGHT",
        "UP",
        "WRAP",
    ),
    "{}": _bio_tags(
        "WRAP",
    ),
    "()": _bio_tags(
        "WRAP",
    ),
    "[]": _bio_tags(
        "WRAP",
    ),
    "''": _bio_tags(
        "WRAP",
    ),
    '""': _bio_tags(
        "WRAP",
    ),
    "``": _bio_tags(
        "WRAP",
    ),
    "act": _bio_tags(
        "DANCE",  # dance
        "GAME",  # game
        "PROJECT",  # project
    ),
    "addr": _bio_tags(
        "CHAN",  # radio frequency, TV channel or station name, e.g. 107.7 "The Bone", CBS, CNN, PBS, etc.
        "DOOR",  # apt, door, or suite number
        "EMAIL",  # email address
        "FAC",  # facility address or specific physical building name
        "FILE",  # file name and path
        "GEO",  # geo-coordinates
        "IP",  # IP address or CIDR notation
        "MAIL",  # physical mailbox or p.o. box
        "PHONE",  # telephone or fax
        "SITE",  # DNS domain name or website name
        "URL",  # URL parts not EMAIL, FILE, IP, or SITE
    ),
    "concept": _bio_tags(
        "ART",  # art, music, or literary concept
        "BIO",  # biology or medical concept
        "BIZ",  # business or marketing concept
        "CHEM",  # chemistry or bio-chem concept
        "CLIM",  # climate or ocean science concept
        "ECON",  # economic concept
        "EDU",  # education concept
        "ENG",  # engineering concept
        "FIN",  # finance or investment concept
        "FORMAT",  # formatting concept, e.g. list, outline, paragraph, table, figure, etc.
        "GEOG",  # geography concept
        "GEOL",  # geology concept
        "INFO",  # computing, data, or info sciences concept
        "LANG",  # linguistics concept
        "LAW",  # legal concept
        "MATH",  # math concept
        "ORG",  # organizational concept
        "PHIL",  # ethical or philosophical concept
        "PHYS",  # physics concept
        "POLI",  # sociological or political concept
        "PROG",  # computer programming concept
        "PSY",  # psychological concept
        "RELI",  # religious concept
        "SOC",  # sociology concept
        "SPORTS",  # sports concept
        "WAR",  # military concept
    ),
    "coord": _bio_tags(
        "AND",
        "OR",  # or; "nor" is treated as negatives connected by AND
        "NEG",  # negative
    ),
    "error": _bio_tags(
        "OMIT",  # omitted or missing values due to formatting or redactions
        "ORDER",  # word order problem
        "SPELL",  # spelling error
    ),
    "foreign": _bio_tags(
        "ES",  # Spanish
        "FR",  # French
        "HANS",  # Chinese simplified
        "HANT",  # Chinese traditional
        "JA",  # Japanese
        "LA",  # Latin
        "LANG",  # marker indicating language of subsequent foreign token
        "LOAN",  # loanword, English word based on foreign sound
        "PHONE",  # phonetic, formal (e.g. Hepburn romanization) or otherwise
        "TRANS",  # marker indicating translation
    ),
    "media": _bio_tags(
        "AUD",  # music and audio recordings
        "IMG",  # photos, paintings, and other images
        "SOFT",  # software
        "TXT",  # articles, books, papers, etc.
        "VID",  # film and other videos
    ),
    "nature": _bio_tags(
        "FAUNA",  # animal life
        "FLORA",  # plant life
        "PHENOM",  # phenomena
    ),
    "num": _bio_tags(
        "AGE",  # age
        "COUNT",  # count
        "DIST",  # distance
        "FRAC",  # fraction
        "MASS",  # mass
        "MONEY",  # currency
        "ORD",  # ordinal
        "PCT",  # percent
        "PCTILE",  # percentile
        "RANGE",  # numeric range
        "SPEED",  # speed
        "WEIGHT",  # weight, force due to gravity
    ),
    "org": _bio_tags(
        "ORG",  # organization
        "TITLE",  # title or role
    ),
    "other": _bio_tags(
        "DIV",  # / or ÷
        "EXP",  # exponent, e.g. ^
        "GT",  # >
        "LT",  # <
        "MATH",  # non-arithmetic math notation
        "MINUS",  # -
        "MULT",  # x, X, or *
        "OUT",  # computer program output, e.g. stderr/out, logs, etc.
        "PLUS",  # +
        "PROG",  # computer programming notation
        "SCI",  # scientific notation outside math and programming
    ),
    "people": _bio_tags(
        "GPE",  # geopolitical entity e.g. countries, cities, states, or regions
        "LANG",  # language
        "NORP",  # nationalities, religious, or political groups. e.g. "American", "Muslim", or "Communist"
    ),
    # person or personified being
    "person": _bio_tags(
        "ALIAS",  # nickname or alternative name
        "HONOR",  # honorific
        "NAME",  # person name
        "PROF",  # profession or professional designation e.g. CFA, CPA, MD
        "USER",  # username
    ),
    "place": _bio_tags(
        "BYTE",  # digital location
        "FIC",  # fictional locations
        "LOC",  # physical locations
        "UI",  # location on a user interface
        "VIRT",  # virtual location
        "WEB",  # web-connected location
    ),
    "thing": _bio_tags(
        "AWARD",  # named accolade or honorary award
        "DEVICE",  # device, tool, or toy
        "FOOD",  # food
    ),
    "time": _bio_tags(
        "TIME",  # years, dates, time values
        "EVENT",  # event in time
    ),
}

# TODO: might be multi-label
# Primary feature; the zz_ prefix makes it the last feature to be labeled.
# Built from the keys present at this point, so it does not list itself.
FEATURES["zz_prime"] = ["_all", "_ambiguous", *FEATURES.keys()]

# Namespace for the uuid5 sentence signatures; changing it changes every signature.
UUID5_NS = uuid.UUID("246a5463-afae-4571-a6e0-f319d74147d3")
def get_uniq_training_labels(ds: Dataset, columns_to_exclude: Optional[set[str]] = None):
    """Collect, log, and return the distinct label values of each label column.

    Args:
        ds: Dataset whose ``features`` mapping names the columns. Iterating it
            must yield examples that map every column name to an iterable of
            per-token values.
        columns_to_exclude: Column names to leave out. When ``None`` the
            non-label columns ``{"text", "tokens", "sig"}`` are excluded.

    Returns:
        A dict mapping each remaining column name to the set of values seen
        in that column across all examples.
    """
    # Function-scope import so this block does not depend on a file-level edit.
    from collections import Counter

    # Resolve the exclusion set once instead of once per column.
    excluded = {"text", "tokens", "sig"} if columns_to_exclude is None else columns_to_exclude
    columns_to_train_on = [col for col in ds.features.keys() if col not in excluded]

    # One counter per column; its keys double as the set of unique values.
    label_counters = {col: Counter() for col in columns_to_train_on}
    for example in ds:
        for col in columns_to_train_on:
            label_counters[col].update(example[col])

    logger.info("Columns:")
    for col in columns_to_train_on:
        counts = label_counters[col]
        logger.info(f" {col}:")
        # Sorted so the log output has a stable ordering.
        vals = sorted(counts)
        logger.info(f" {len(vals)} labels: {[f'{v}:{counts[v]}' for v in vals]}")

    return {col: set(counts) for col, counts in label_counters.items()}
def _render_example_header(exp_idx, new_exp, feat_name, labels):
    """Build the status text shown above every labeling prompt."""
    rows = [f"Example {exp_idx}:"]
    rows.extend(f"{pad_to_desired_len(k)}{v}" for k, v in new_exp.items())
    rows.append(f"{pad_to_desired_len(feat_name)}{labels}")
    return "\n".join(rows)


def _read_label_input(stdscr):
    """Echo and collect keystrokes until Enter or Esc.

    Returns:
        ``(typed_text, esc_pressed)``.
    """
    typed = ""
    while True:
        ch = stdscr.getch()
        if ch in {8, 127, curses.KEY_BACKSPACE}:  # Delete
            typed = typed[:-1]
            y, x = stdscr.getyx()
            # The prompt ": " occupies columns 0-1; never back up into it.
            if x - 1 > 1:
                stdscr.move(y, x - 1)
            stdscr.clrtoeol()
            stdscr.refresh()
        elif ch == 27:  # Esc
            return typed, True
        elif ch in {10, curses.KEY_ENTER}:  # Enter
            return typed, False
        else:
            # Otherwise, add the character to the string
            ch_chr = chr(ch)
            stdscr.addstr(ch_chr)
            typed += ch_chr


def _interpret_label_input(raw, feat_labels, token_idx, last_idx):
    """Translate typed text into ``(label, skip_label, skip_to_idx)``.

    Accepted forms are ``""`` (label "O"), ``"<idx>"``, and ``"<idx>><n>"``,
    where the part after ``>`` is how many following tokens receive
    ``skip_label`` (all remaining tokens when it is empty). Exactly one label
    is always returned for the current token: an out-of-range or non-integer
    index falls back to "O" so labels stay aligned with tokens.
    """
    label_part, has_skip, skip_part = raw.partition(">")

    label = None
    if label_part:
        try:
            label_idx = int(label_part)
        except ValueError:
            return "O", None, None
        if 0 <= label_idx < len(feat_labels):
            label = feat_labels[label_idx]

    current = label if label is not None else "O"
    if not has_skip:
        return current, None, None

    skip_to_idx = last_idx
    if skip_part:
        try:
            skip_to_idx = min(token_idx + int(skip_part), last_idx)
        except ValueError:
            skip_to_idx = None  # unreadable distance: label this token only
    return current, label, skip_to_idx


def _label_feature(stdscr, exp_idx, new_exp, feat_name, feat_labels):
    """Prompt for one feature's per-token labels until accepted.

    Returns:
        The accepted label list (one entry per token), or ``None`` if Esc was
        pressed at any prompt.
    """
    naive_tokens = new_exp["tokens"]
    last_idx = len(naive_tokens) - 1
    label_menu = ", ".join([f"{i}:{l}" for i, l in enumerate(feat_labels)])

    while True:
        labels = []
        skip_to_idx = None
        skip_label = None
        for token_idx, token in enumerate(naive_tokens):
            if skip_to_idx is not None and skip_to_idx >= token_idx:
                labels.append(skip_label if skip_label is not None else "O")
                continue

            # Column of the current token inside the printed token list, so the
            # caret and index line up underneath it.
            padding_len = (
                1
                + wcwidth.wcswidth(", ".join([f"'{t}'" for t in naive_tokens[:token_idx]]))
                + wcwidth.wcswidth(token)
            ) + (0 if token_idx == 0 else 2)

            stdscr.clear()
            stdscr.addstr("\n".join([
                _render_example_header(exp_idx, new_exp, feat_name, labels),
                "",
                f"Labels: {label_menu}",
                "",
                f"{naive_tokens}",
                f"{' ' * padding_len}^",
                f"{' ' * padding_len}{token_idx}",
                ": ",
            ]))

            typed, esc_pressed = _read_label_input(stdscr)
            if esc_pressed:
                return None
            label, skip_label, skip_to_idx = _interpret_label_input(
                typed, feat_labels, token_idx, last_idx)
            labels.append(label)

        stdscr.clear()
        stdscr.addstr(
            _render_example_header(exp_idx, new_exp, feat_name, labels)
            + "\n\nPress 'y' to accept or anything else to reject.\nPress Esc to exit."
        )
        ch = stdscr.getch()
        stdscr.clear()
        if ch == 27:  # Esc
            return None
        if ch == ord("y"):
            return labels


def main(stdscr, args):
    """Interactively label sentences in a curses window and return the dataset.

    Previously saved examples under ``DATASET_PATH`` are loaded first. Then
    sentences are drawn from ``custom_examples`` and randomly chosen Wikipedia
    pages; each accepted sentence is labeled for every entry of ``FEATURES``.

    Args:
        stdscr: Curses window used for all prompts and key input.
        args: Parsed arguments. ``args.replace`` is an optional
            ``"sig/col/idx/label"`` string that overwrites one label of a
            saved example; ``args.redo`` is an optional signature whose saved
            example is dropped so the sentence can be labeled again.

    Returns:
        A ``Dataset`` built from the loaded examples plus every example that
        was labeled for all features before Esc was pressed.
    """
    wikipedia_dataset_name = "20231101.en"
    wikipedia_dataset = load_dataset("wikimedia/wikipedia", wikipedia_dataset_name)
    total_page_cnt = len(wikipedia_dataset["train"])

    stdscr.clear()
    stdscr.addstr(f"Loaded {wikipedia_dataset_name} containing {total_page_cnt} pages.")

    new_dataset_dict = {k: [] for k in ["text", "tokens", *FEATURES.keys(), "sig"]}
    signature_cache = set()

    target_sig, target_col, target_idx, new_label = None, None, None, None
    if args.replace:
        target_sig, target_col, target_idx, new_label = args.replace.split("/")

    if os.path.exists(DATASET_PATH):
        # Load previous examples
        for exp in Dataset.load_from_disk(DATASET_PATH):
            sig = str(uuid.uuid5(UUID5_NS, exp["text"]))
            if sig in signature_cache or sig == args.redo:
                continue
            signature_cache.add(sig)
            for k, v in exp.items():
                if sig == target_sig and k == target_col:
                    v[int(target_idx)] = new_label
                new_dataset_dict[k].append(v)

    esc_pressed = False
    while not esc_pressed:
        # Select random Wikipedia page. randrange excludes the upper bound;
        # randint(0, total_page_cnt) could index one past the last page.
        page = wikipedia_dataset["train"][random.randrange(total_page_cnt)]

        # If all custom examples are labeled, move on to Wikipedia
        for page_chunk in (custom_examples + page["text"].split("\n\n")):
            if esc_pressed:
                break
            for chunk_line in page_chunk.strip().split("\n"):
                chunk_line = chunk_line.strip()
                while not esc_pressed and chunk_line:
                    sentence_end_match = naive_sentence_end_pattern.search(chunk_line)
                    if sentence_end_match:
                        sentence_blob = chunk_line[:sentence_end_match.end()]
                        chunk_line = chunk_line[sentence_end_match.end():].strip()
                    else:
                        sentence_blob = chunk_line
                        chunk_line = ""

                    sig = str(uuid.uuid5(UUID5_NS, sentence_blob))
                    if sig in signature_cache:
                        continue
                    signature_cache.add(sig)

                    # TODO: sentence context
                    # - prefix each text with a context blob that gets tokenized with the text
                    # - label context blobs as B-CTXT and I-CTXT
                    # - this way, contextual information from outside the direct text can be injected
                    # - this allows injecting contexts from what we've already processed on the page
                    # - use unique signal sequences to signal contexts, e.g.:
                    #   - {{[[((prev:a,b;last:c,d))]]}}>>>

                    exp_idx = len(new_dataset_dict["text"])
                    stdscr.addstr(
                        f"\n\n>>>{sentence_blob}<<<\n\n"
                        "Press 'y' to accept or anything else to reject.\n"
                        "Press Esc to exit.\n"
                    )
                    ch = stdscr.getch()
                    stdscr.clear()
                    if ch == 27:  # Esc
                        esc_pressed = True
                    elif ch == ord("y"):
                        new_exp = {
                            "text": sentence_blob,
                            "tokens": naive_tokenize(sentence_blob),
                        }
                        for feat_name, feat_labels in FEATURES.items():
                            labels = _label_feature(
                                stdscr, exp_idx, new_exp, feat_name, feat_labels)
                            if labels is None:
                                esc_pressed = True
                                break
                            new_exp[feat_name] = labels

                        # Add if complete
                        new_exp["sig"] = sig
                        if sorted(new_exp.keys()) == sorted(new_dataset_dict.keys()):
                            for k, v in new_exp.items():
                                new_dataset_dict[k].append(v)

    # Exiting
    stdscr.clear()
    return Dataset.from_dict(new_dataset_dict)
def pad_to_desired_len(blob: str, desired: int = 15):
    """Right-pad ``blob`` with spaces to at least ``desired`` characters.

    Strings that are already ``desired`` characters or longer come back
    unchanged.
    """
    return blob.ljust(desired)


def show_examples(ds: Dataset, show_expr: Optional[str]):
    """Log a shuffled sample of examples, one line per feature.

    Args:
        ds: Dataset to sample from.
        show_expr: Either falsy, in which case up to 25 random examples are
            logged, or a ``"column/label/count"`` string, in which case up to
            ``count`` examples whose ``column`` contains ``label`` are logged
            (shuffled with a fixed seed).

    Raises:
        ValueError: If ``show_expr`` does not have exactly three
            ``/``-separated parts, or its count is not an integer.
    """
    if not show_expr:
        count_to_show = min(len(ds), 25)
        examples_to_show = ds.shuffle()[:count_to_show]
    else:
        parts = show_expr.split("/")
        if len(parts) != 3:
            raise ValueError(
                f"show_expr must look like 'column/label/count', got {show_expr!r}")
        col_to_show, label_to_show, count_to_show = parts
        count_to_show = int(count_to_show)
        examples_to_show = ds.filter(
            lambda exp: label_to_show in exp[col_to_show]).shuffle(seed=42)[:count_to_show]

    # The filter may match fewer rows than requested; never index past what
    # the slice actually returned.
    available = min((len(values) for values in examples_to_show.values()), default=0)
    for i in range(min(count_to_show, available)):
        logger.info(f"Example {i}:")
        for feature, values in examples_to_show.items():
            logger.info(f" {feature}: {values[i]}")