| from asyncio import Task |
| from datasets import Dataset, load_dataset |
| from datetime import datetime |
| from openai import AsyncOpenAI |
| from traceback import format_exc |
| from typing import Union |
| import asyncio |
| import json |
| import logging |
|
|
| from utils import default_logging_config |
|
|
# Shared async OpenAI client used by classify_tokens; presumably picks up
# credentials (OPENAI_API_KEY) from the environment — confirm deployment config.
client = AsyncOpenAI()
# Module-level logger; handlers are configured in the __main__ block via
# logging.config.dictConfig(default_logging_config).
logger = logging.getLogger(__name__)
|
|
# Per-task label sets: maps a feature/task name (which becomes a dataset
# column) to a {label: human-readable description} dict. Tag names largely
# follow the Penn Treebank POS tagset, extended with custom classes
# (EMOJI, TIME, SEP, BRACKET, ...) and BIO-style NER tags. Every task
# includes an "O" (out-of-scope) fallback so the model always has a
# valid label for tokens outside the task's scope.
features = {
    "adj": {"JJ": "adjective",
            "JJR": "comparative adjective",
            "JJS": "superlative adjective",
            "O": "out-of-scope"},
    "adv": {"RB": "adverb",
            "RBR": "comparative adverb",
            "RBS": "superlative adverb",
            "O": "out-of-scope"},
    "det": {"DT": "articles, demonstratives, and other determiners",
            "EX": "existential 'there'",
            "PDT": "predeterminer before a determiner to modify a noun phrase",
            "O": "out-of-scope"},
    # Enclosure/chunking task: tokens inside bracket/quote/backtick spans.
    "enc": {"BRACKET": "in or contains bracket wrapped text",
            "QUOTE": "in or contains quote wrapped text",
            "TICK": "in or contains backtick wrapped text",
            "O": "out-of-scope"},
    "func": {"CC": "coordinating conjunction",
             "IN": "preposition or subordinating conjunction",
             "RP": "particle",
             "TO": "to",
             "UH": "interjection",
             "O": "out-of-scope"},
    "misc": {"$": "currency",
             "ADD": "address, URLs, usernames, or other non-lexical representations of places or entities",
             "CD": "cardinal numbers",
             "EMOJI": "emoji",
             "TIME": "date or time",
             "O": "out-of-scope"},
    # NER tasks use BIO tagging (B- = first token of entity, I- = continuation).
    "ner1": {"B-GPE": "beginning of geopolitical entities",
             "I-GPE": "inside of geopolitical entities",
             "B-ORG": "beginning of organization",
             "I-ORG": "inside of organization",
             "B-PER": "beginning of person",
             "I-PER": "inside of person",
             "O": "out-of-scope"},
    "ner2": {"B-EVENT": "beginning of event",
             "I-EVENT": "inside of event",
             "B-LOC": "beginning of location",
             "I-LOC": "inside of location",
             "O": "out-of-scope"},
    "noun": {"NN": "common noun singular",
             "NNS": "common noun plural",
             "NNP": "proper noun singular",
             "NNPS": "proper noun plural",
             "O": "out-of-scope"},
    "pronoun": {"POS": "possessive ending like the 's",
                "PRP$": "possessive pronoun",
                "PRP": "personal pronoun",
                "O": "out-of-scope"},
    # Punctuation classes use spelled-out names (COMMA, PERIOD, ...) rather
    # than the literal characters, presumably to keep labels JSON-friendly.
    "punct": {"COLON": "colon or semicolon",
              "COMMA": "comma",
              "EXCLAIM": "exclamation mark",
              "HYPH": "dash or hyphen",
              "LS": "list item marker",
              "PERIOD": "period",
              "QUESTION": "question mark",
              "SEP": "any section separator",
              "O": "out-of-scope"},
    "verb": {"MD": "modal verb",
             "VB": "verb base form",
             "VBD": "verb past tense",
             "VBG": "present participle, gerund",
             "VBN": "past participle",
             "VBP": "non-3rd person singular present",
             "VBZ": "3rd person singular present",
             "O": "out-of-scope"},
    "wh": {"WDT": "Wh-determiner",
           "WP$": "Wh-possessive pronoun",
           "WP": "Wh-pronoun",
           "WRB": "Wh-adverb",
           "O": "out-of-scope"},
}
|
|
# Task-specific phrasing spliced into the final user instruction in
# classify_tokens ("...that best describes {prompt}: ..."). Keys mirror
# the keys of `features`. Plain string literals: the originals were
# f-strings with no placeholders, so the f-prefix did nothing.
prompts = {
    "adj": "its semantic role",
    "adv": "its semantic role",
    "det": "its semantic role",
    "enc": "its sentence chunk classification",
    "func": "its semantic role",
    "misc": "its semantic role",
    "ner1": "its NER classification",
    "ner2": "its NER classification",
    "noun": "its semantic role",
    "pronoun": "its semantic role",
    "punct": "the punctuation classes it contains",
    "verb": "its semantic role",
    "wh": "its semantic role",
}
|
|
async def classify_tokens(args, prompt: str, labels: dict[str, str], tokens: list[str],
                          model: Union[str, None] = None):
    """Ask an OpenAI chat model to assign one label to each token.

    Args:
        args: parsed CLI namespace; ``args.openai_model`` supplies the
            default model name.
        prompt: task-specific phrasing inserted into the instruction
            (e.g. "its semantic role"), see module-level ``prompts``.
        labels: {label: description} mapping for this task, see ``features``.
        tokens: strings to classify, one label requested per token.
        model: optional model-name override. Defaults to
            ``args.openai_model``. (Previously this defaulted to
            "gpt-4o" while the request always sent ``args.openai_model``,
            so the o-series parameter branch below keyed off the wrong
            name and could send sampling params to a reasoning model.)

    Returns:
        The list of label strings parsed from the JSON response.

    Raises:
        Whatever the OpenAI client raises on request failure, or
        KeyError / json.JSONDecodeError when the response body cannot
        be parsed; both paths are logged before re-raising.
    """
    model = model or args.openai_model
    tok_len = len(tokens)
    # json.dumps escapes embedded quotes/backslashes; the previous hand-built
    # '"tok", "tok"' join produced malformed output for tokens containing '"'.
    example = json.dumps(tokens)
    # o-series reasoning models reject sampling parameters; send
    # reasoning_effort for them, deterministic sampling for everything else.
    extra_params = ({"reasoning_effort": "low"} if model.startswith("o")
                    else {"presence_penalty": 0, "temperature": 0})
    try:
        response = await client.chat.completions.create(
            model=model, timeout=30,
            **extra_params,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "Analyze the user provided sequence. Consider each string's semantic role "
                        f"in the given sequence, then return a list of {tok_len} label strings. "
                        f"Generate no more than {tok_len} labels. "
                        "When typos or out-of-order words are provided, infer the intended meaning."
                    ),
                },
                {
                    "role": "system",
                    "content": f"Labels: {labels}",
                },
                {
                    "role": "user",
                    "content": example,
                },
                {
                    "role": "user",
                    "content": (f"Replace each of the {tok_len} given strings with one of the following labels "
                                f"that best describes {prompt}: {sorted(labels.keys())}"),
                },
            ],
            # Structured output: force a {"labels": [...]} JSON object.
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "labels",
                    "strict": True,
                    "schema": {
                        "type": "object",
                        "properties": {
                            "labels": {
                                "type": "array",
                                "description": f"List of {tok_len} labels, one for each string from the user's sequence.",
                                "items": {
                                    "type": "string",
                                }
                            }
                        },
                        "additionalProperties": False,
                        "required": ["labels"]
                    }
                }
            },
        )
    except Exception:
        logger.error(f"openai call failed: {format_exc()}")
        raise
    try:
        return json.loads(response.choices[0].message.content)["labels"]
    except Exception:
        logger.error(f"response: {response.choices[0].message} {format_exc()}")
        raise
|
|
|
|
async def classify_with_retry(args, prompt, labels, tokens, retry=10):
    """Call classify_tokens with up to ``retry`` attempts and linear backoff.

    Returns the label list from the first successful attempt. Re-raises the
    final exception when every attempt fails — the previous version fell off
    the loop and silently returned None, which then got appended into the
    dataset as a bogus example.
    """
    for attempt in range(retry):
        try:
            return await classify_tokens(args, prompt, labels, tokens)
        except Exception:
            logger.error(f"attempt {attempt} failed {tokens} {prompt} {format_exc()}")
            if attempt == retry - 1:
                raise
            # Linear backoff: 0s after the first failure, then 1s, 2s, ...
            await asyncio.sleep(attempt)
|
|
|
|
async def generate_token_labels(args, case):
    """Classify one text case under every feature task concurrently.

    Splits ``case`` on whitespace and fires one classify_with_retry call per
    feature column in parallel, returning {column: labels} with columns in
    sorted order. (asyncio.gather already yields results as a list aligned
    with its inputs, so the previous list()/enumerate indexing was redundant.)
    """
    tokens = case.split()
    sorted_cols = sorted(features)
    results = await asyncio.gather(
        *(classify_with_retry(args, prompts[col], features[col], tokens) for col in sorted_cols))
    return dict(zip(sorted_cols, results))
|
|
|
|
async def main(args, cases):
    """Generate token labels for every case with bounded concurrency.

    Schedules at most ``max_concurrent_tasks`` generate_token_labels tasks at
    a time, drains finished tasks into ``ds_dict`` via a background sweep
    loop, checkpoints the partial dataset roughly every 10 minutes, and
    finally saves the full Dataset to
    ``{args.save_path}/{args.openai_model}_{ts_run}``.
    """
    ds_dict = {k: [] for k in features}

    ts_run = datetime.now().strftime("%Y%m%d")
    drain_completed = False  # flipped by main(); closures below only read it
    max_concurrent_tasks = 15
    tasks: list[Union[Task, None]] = []  # None marks a drained slot

    async def checkpoint_task():
        # Persist a partial dataset every ~600 one-second ticks (~10 min),
        # so a crash mid-run loses bounded work.
        tick_cnt = 0
        while not drain_completed:
            tick_cnt += 1
            if tick_cnt % 600 == 0:
                _ds = Dataset.from_dict(ds_dict)
                _ds.save_to_disk(f"{args.save_path}/{args.openai_model}_{ts_run}_checkpoint")
                logger.info(f"\n{_ds}")
            await asyncio.sleep(1)
    future_checkpoint_task_completion = asyncio.create_task(checkpoint_task())

    async def drain_tasks():
        # Repeatedly sweep the task list, moving finished examples into
        # ds_dict and freeing their slots.
        while not drain_completed:
            for idx, task in enumerate(tasks):
                if task is None:
                    continue
                try:
                    logger.info(f"attempting Example {idx}")
                    # shield() keeps a wait_for timeout from cancelling the
                    # underlying task: the unshielded original cancelled it,
                    # and the next sweep's await then raised CancelledError
                    # (a BaseException that escapes `except Exception`),
                    # killing this drain loop.
                    example = await asyncio.wait_for(asyncio.shield(task), timeout=180)
                    for col, labels in example.items():
                        logger.info(f"  {col}: {labels}")
                        ds_dict[col].append(labels)
                    tasks[idx] = None
                except asyncio.TimeoutError:
                    # Not fatal: the slot stays occupied and is retried on
                    # the next sweep. (Message previously claimed 10 seconds
                    # while the actual timeout is 180.)
                    logger.warning(f"attempt to wait_for Example {idx} timed out after 180 seconds.")
                except Exception:
                    logger.error(f"attempt to generate Example {idx} failed.\n{format_exc()}")
                    tasks[idx] = None
                    raise
            await asyncio.sleep(1)
    future_drain_completion = asyncio.create_task(drain_tasks())

    for case in cases:
        # Back-pressure: wait for a free slot before scheduling the next case.
        while len([t for t in tasks if t is not None]) >= max_concurrent_tasks:
            await asyncio.sleep(1)
        logger.info(f"scheduling case {case}")
        tasks.append(asyncio.create_task(generate_token_labels(args, case)))

    # Wait for the drain loop to clear every outstanding slot.
    while any(t is not None for t in tasks):
        if future_drain_completion.done():
            # The drain loop died on an exception; re-raise it here instead
            # of spinning forever on slots nobody will ever clear.
            await future_drain_completion
            break
        logger.info(f"waiting on {len([t for t in tasks if t is not None])} tasks")
        await asyncio.sleep(1)

    drain_completed = True
    await future_drain_completion
    await future_checkpoint_task_completion

    ds = Dataset.from_dict(ds_dict)
    ds.save_to_disk(f"{args.save_path}/{args.openai_model}_{ts_run}")
    logger.info(f"\n{ds}")
|
|
|
|
| if __name__ == "__main__": |
| import argparse |
| import logging.config |
|
|
| logging.config.dictConfig(default_logging_config) |
|
|
| arg_parser = argparse.ArgumentParser(description="Train multi-task model.") |
| arg_parser.add_argument("--openai-model", help="OpenAI model.", |
| action="store", default="o3-mini", choices=["gpt-4o", "o3-mini", "o1"]) |
| arg_parser.add_argument("--save-path", help="Save final dataset to specified path.", |
| action="store", default="./dataset") |
| arg_parser.add_argument("--ud", help='Use UD datasets.', |
| action="store_true", default=False) |
| parsed_args = arg_parser.parse_args() |
|
|
| all_text = [] |
|
|
| if parsed_args.ud: |
| ud_en_ewt_ds = load_dataset("universal_dependencies", "en_ewt") |
| all_text += ud_en_ewt_ds["train"]["text"] |
| all_text += ud_en_ewt_ds["validation"]["text"] |
| all_text += ud_en_ewt_ds["test"]["text"] |
|
|
| ud_en_gum_ds = load_dataset("universal_dependencies", "en_gum") |
| all_text += ud_en_gum_ds["train"]["text"] |
| all_text += ud_en_gum_ds["validation"]["text"] |
| all_text += ud_en_gum_ds["test"]["text"] |
|
|
| ud_en_lines_ds = load_dataset("universal_dependencies", "en_lines") |
| all_text += ud_en_lines_ds["train"]["text"] |
| all_text += ud_en_lines_ds["validation"]["text"] |
| all_text += ud_en_lines_ds["test"]["text"] |
|
|
| ud_en_partut_ds = load_dataset("universal_dependencies", "en_partut") |
| all_text += ud_en_partut_ds["train"]["text"] |
| all_text += ud_en_partut_ds["validation"]["text"] |
| all_text += ud_en_partut_ds["test"]["text"] |
|
|
| ud_en_pronouns_ds = load_dataset("universal_dependencies", "en_pronouns") |
| all_text += ud_en_pronouns_ds["test"]["text"] |
|
|
| ud_en_pud_ds = load_dataset("universal_dependencies", "en_pud") |
| all_text += ud_en_pud_ds["test"]["text"] |
|
|
| logger.info(f"{len(all_text)} UD examples") |
|
|
| all_text += [ |
| "Hello world!", |
| "127.0.0.1 is the localhost address.", |
| "1/2 is equivalent to 0.5 or 50%", |
| "John was running so fast, you can just tell he's a runner.", |
| "He excels at math and competed in the Math Olympiad", |
| "They're only $5!", |
| "Where is your sense of adventure?", |
| "I have only 3 cents.", |
| "Watson was on his way to 221B Baker Street when the robbery occurred.", |
| "That's uncopyrightable.", |
| "She's full of incomprehensibilities.", |
| "He's a total sesquipedalian.", |
| "you piece of SHIT!!", |
| "uh........... what..", |
| "Steph Curry is GOAT!!", |
| "[click here!](http://www.google.com)", |
| "Dude! The stock's value grew like 10x in a year!", |
| "Yea, I was at the DMV - God what a shit show!", |
| "Send an email to help@example.com", |
| "@goober, take your question to #corp-help-desk", |
| "Example 1 : Joe Shmoe has a big toe.", |
| "1. Steal under-pants. 2. ... 3. Profit!", |
| "I expect `len(word_list) == 3`", |
| "Home | Shop | Contact Us", |
| "This is me on cake <(^.^)>", |
| "and then he fell right on his face 😂", |
| "Putin is from Russia.", |
| "Zelenskyy is Ukrainian", |
| "In 2013, the Pentagon and other agencies officially acknowledged the existence of Area-51.", |
| "The Freedom of Information Act gives us the right to request access to records from any federal agency.", |
| "His motives here are totally sus", |
| "Yea, he finished by doing the Dab", |
| "Be back i'll", |
| "hes got a bon to pick", |
| "so then he says then he says, \"you'll regret this\" lol", |
| ] |
| asyncio.run(main(parsed_args, all_text)) |
|
|
|
|