# multi-classifier / openai_dataset_maker.py
# (Hugging Face viewer residue preserved as a comment: veryfansome,
#  "feat: UD is back, LlaMA play", commit 0cdb887, raw/history/blame, 13.4 kB)
from asyncio import Task
from datasets import Dataset, load_dataset
from datetime import datetime
from openai import AsyncOpenAI
from traceback import format_exc
from typing import Union
import asyncio
import json
import logging
from utils import default_logging_config
client = AsyncOpenAI()
logger = logging.getLogger(__name__)
# Label inventories for each classification "feature" (column of the dataset).
# Keys are short column names; values map a tag to its human-readable
# description, which is sent verbatim to the model as the label legend.
# Every feature includes "O" (out-of-scope) as a catch-all.
# Tag names largely follow Penn Treebank POS conventions, with some custom
# additions (EMOJI, TIME, BRACKET/QUOTE/TICK, spelled-out punctuation names).
features = {
    # Adjectives (PTB JJ/JJR/JJS).
    "adj": {"JJ": "adjective",
            "JJR": "comparative adjective",
            "JJS": "superlative adjective",
            "O": "out-of-scope"},
    # Adverbs (PTB RB/RBR/RBS).
    "adv": {"RB": "adverb",
            "RBR": "comparative adverb",
            "RBS": "superlative adverb",
            "O": "out-of-scope"},
    # Determiners and related function words.
    "det": {"DT": "articles, demonstratives, and other determiners",
            "EX": "existential 'there'",
            "PDT": "predeterminer before a determiner to modify a noun phrase",
            "O": "out-of-scope"},
    # Enclosure: whether a token sits inside (or carries) paired delimiters.
    "enc": {"BRACKET": "in or contains bracket wrapped text",
            "QUOTE": "in or contains quote wrapped text",
            "TICK": "in or contains backtick wrapped text",
            "O": "out-of-scope"},
    # Miscellaneous function words.
    "func": {"CC": "coordinating conjunction",
             "IN": "preposition or subordinating conjunction",
             "RP": "particle",
             "TO": "to",
             "UH": "interjection",
             "O": "out-of-scope"},
    # Non-lexical tokens: numbers, currency, URLs/handles, emoji, dates.
    "misc": {"$": "currency",
             "ADD": "address, URLs, usernames, or other non-lexical representations of places or entities",
             "CD": "cardinal numbers",
             "EMOJI": "emoji",
             "TIME": "date or time",
             "O": "out-of-scope"},
    # Named entities, set 1 (BIO scheme): GPE / ORG / PER.
    "ner1": {"B-GPE": "beginning of geopolitical entities",
             "I-GPE": "inside of geopolitical entities",
             "B-ORG": "beginning of organization",
             "I-ORG": "inside of organization",
             "B-PER": "beginning of person",
             "I-PER": "inside of person",
             "O": "out-of-scope"},
    # Named entities, set 2 (BIO scheme): EVENT / LOC.
    "ner2": {"B-EVENT": "beginning of event",
             "I-EVENT": "inside of event",
             "B-LOC": "beginning of location",
             "I-LOC": "inside of location",
             "O": "out-of-scope"},
    # Nouns (PTB NN/NNS/NNP/NNPS).
    "noun": {"NN": "common noun singular",
             "NNS": "common noun plural",
             "NNP": "proper noun singular",
             "NNPS": "proper noun plural",
             "O": "out-of-scope"},
    # Pronouns and possessives.
    "pronoun": {"POS": "possessive ending like the 's",
                "PRP$": "possessive pronoun",
                "PRP": "personal pronoun",
                "O": "out-of-scope"},
    # Punctuation classes (spelled-out names instead of PTB symbol tags).
    "punct": {"COLON": "colon or semicolon",
              "COMMA": "comma",
              "EXCLAIM": "exclamation mark",
              "HYPH": "dash or hyphen",
              "LS": "list item marker",
              "PERIOD": "period",
              "QUESTION": "question mark",
              "SEP": "any section separator",
              "O": "out-of-scope"},
    # Verbs (PTB MD/VB/VBD/VBG/VBN/VBP/VBZ).
    "verb": {"MD": "modal verb",
             "VB": "verb base form",
             "VBD": "verb past tense",
             "VBG": "present participle, gerund",
             "VBN": "past participle",
             "VBP": "non-3rd person singular present",
             "VBZ": "3rd person singular present",
             "O": "out-of-scope"},
    # Wh-words.
    "wh": {"WDT": "Wh-determiner",
           "WP$": "Wh-possessive pronoun",
           "WP": "Wh-pronoun",
           "WRB": "Wh-adverb",
           "O": "out-of-scope"},
}
# Feature-specific phrasing injected into the final user message of
# classify_tokens ("...that best describes {prompt}: ...").
# Fix: these were all f-strings with no placeholders — the f-prefix did
# nothing (ruff F541), so plain string literals are used instead.
prompts = {
    "adj": "its semantic role",
    "adv": "its semantic role",
    "det": "its semantic role",
    "enc": "its sentence chunk classification",
    "func": "its semantic role",
    "misc": "its semantic role",
    "ner1": "its NER classification",
    "ner2": "its NER classification",
    "noun": "its semantic role",
    "pronoun": "its semantic role",
    "punct": "the punctuation classes it contains",
    "verb": "its semantic role",
    "wh": "its semantic role",
}
async def classify_tokens(args, prompt: str, labels: dict[str, str], tokens: list[str],
                          model=None):
    """Ask an OpenAI chat model to assign one label to each token in `tokens`.

    :param args: parsed CLI args; `args.openai_model` supplies the default model.
    :param prompt: feature-specific phrasing, e.g. "its semantic role".
    :param labels: mapping of label tag -> human-readable description.
    :param tokens: token strings to classify; one label is expected per token.
    :param model: optional model-name override; defaults to `args.openai_model`.
    :return: list of label strings parsed from the model's JSON response.
    :raises: re-raises any OpenAI client or JSON-decoding error after logging it.
    """
    # Bug fix: previously the kwargs branch below tested this function's
    # *default* parameter (model="gpt-4o") while the request itself used
    # args.openai_model, so o-series runs (o1/o3-mini) were sent
    # temperature/presence_penalty, which those models reject.
    model = model if model is not None else args.openai_model
    tok_len = len(tokens)
    # Render the token list as a JSON-ish array so the model sees exact boundaries.
    example = "[" + (", ".join([f'"{tok}"' for tok in tokens])) + "]"
    try:
        response = await client.chat.completions.create(
            model=model, timeout=30,
            # o-series reasoning models take reasoning_effort; chat models take sampling params.
            **({"reasoning_effort": "low"} if model.startswith("o") else {"presence_penalty": 0, "temperature": 0}),
            messages=[
                {
                    "role": "system",
                    "content": (
                        "Analyze the user provided sequence. Consider each string's semantic role "
                        f"in the given sequence, then return a list of {tok_len} label strings. "
                        f"Generate no more than {tok_len} labels. "
                        "When typos or out-of-order words are provided, infer the intended meaning."
                    ),
                },
                {
                    "role": "system",
                    "content": f"Labels: {labels}",
                },
                {
                    "role": "user",
                    "content": example,
                },
                {
                    "role": "user",
                    "content": (f"Replace each of the {tok_len} given strings with one of the following labels "
                                f"that best describes {prompt}: {sorted(labels.keys())}"),
                },
            ],
            # Structured output: force a {"labels": [...]} JSON object.
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "labels",
                    "strict": True,
                    "schema": {
                        "type": "object",
                        "properties": {
                            "labels": {
                                "type": "array",
                                "description": f"List of {tok_len} labels, one for each string from the user's sequence.",
                                "items": {
                                    "type": "string",
                                }
                            }
                        },
                        "additionalProperties": False,
                        "required": ["labels"]
                    }
                }
            },
        )
    except Exception:
        logger.error(f"openai call failed: {format_exc()}")
        raise
    try:
        return json.loads(response.choices[0].message.content)["labels"]
    except Exception:
        # Log the raw message so malformed model output can be inspected.
        logger.error(f"response: {response.choices[0].message} {format_exc()}")
        raise
async def classify_with_retry(args, prompt, labels, tokens, retry=10):
    """Call classify_tokens, retrying up to `retry` times with linear backoff.

    :param retry: maximum number of attempts (default 10).
    :return: the label list from the first successful attempt.
    :raises: re-raises the final attempt's exception once retries are exhausted.
             (Previously the function silently returned None in that case,
             which crashed downstream code with a confusing error.)
    """
    for i in range(retry):
        try:
            return await classify_tokens(args, prompt, labels, tokens)
        except Exception:
            logger.error(f"attempt {i} failed {tokens} {prompt} {format_exc()}")
            if i == retry - 1:
                raise  # out of retries: surface the failure to the caller
            # Linear backoff; first retry sleeps 0s (yields control only).
            await asyncio.sleep(i)
async def generate_token_labels(args, case):
    """Whitespace-tokenize `case` and classify it under every feature set concurrently.

    :param case: a raw text example; tokens are produced by str.split().
    :return: dict mapping feature column name -> list of labels (one per token),
             with columns in sorted order.
    """
    tokens = case.split()
    # sorted() already returns a list; the old list(sorted(...)) was redundant,
    # as was wrapping asyncio.gather's list result in list().
    sorted_cols = sorted(features.keys())
    # One classification request per feature column, issued concurrently.
    label_lists = await asyncio.gather(
        *[classify_with_retry(args, prompts[col], features[col], tokens) for col in sorted_cols])
    # gather preserves argument order, so zip pairs each column with its labels.
    return dict(zip(sorted_cols, label_lists))
async def main(args, cases):
    """Label every case with capped concurrency, checkpointing periodically.

    Schedules at most `max_concurrent_tasks` in-flight generate_token_labels
    tasks. Two background tasks run alongside: one drains finished tasks into
    `ds_dict`, the other snapshots the dataset to disk roughly every 10 minutes.
    The final dataset is saved under `{args.save_path}/{model}_{YYYYMMDD}`.

    :param args: parsed CLI args (uses `openai_model` and `save_path`).
    :param cases: iterable of raw text examples to label.
    """
    ds_dict = {k: [] for k in features.keys()}  # column name -> list of label lists
    ts_run = datetime.now().strftime("%Y%m%d")
    drain_completed = False  # read by both background loops as their stop flag
    max_concurrent_tasks = 15
    tasks: list[Union[Task, None]] = []  # slots are set to None once drained

    async def checkpoint_task():
        # Snapshot the accumulated dataset every 600 one-second ticks (~10 min).
        tick_cnt = 0
        while not drain_completed:
            tick_cnt += 1
            if tick_cnt % 600 == 0:
                # checkpoint
                _ds = Dataset.from_dict(ds_dict)
                _ds.save_to_disk(f"{args.save_path}/{args.openai_model}_{ts_run}_checkpoint")
                logger.info(f"\n{_ds}")
            await asyncio.sleep(1)
    future_checkpoint_task_completion = asyncio.create_task(checkpoint_task())

    async def drain_tasks():
        # Collect finished tasks, appending their labels column-wise into ds_dict.
        while not drain_completed:
            for idx, task in enumerate(tasks):
                if task is None:
                    continue  # slot already drained
                try:
                    logger.info(f"attempting Example {idx}")
                    example = await asyncio.wait_for(task, timeout=180)
                    for col, labels in example.items():
                        logger.info(f" {col}: {labels}")
                        ds_dict[col].append(labels)
                    tasks[idx] = None  # free the slot for the scheduler
                except asyncio.exceptions.TimeoutError:
                    # Fix: message previously claimed "10 seconds" but the
                    # wait_for timeout above is 180s.
                    logger.warning(f"attempt to wait_for Example {idx} timed out after 180 seconds.")
                except Exception:
                    logger.error(f"attempt to generate Example {idx} failed.\n{format_exc()}")
                    tasks[idx] = None
                    raise
            await asyncio.sleep(1)
    future_drain_completion = asyncio.create_task(drain_tasks())

    for case in cases:
        # Throttle scheduling until an in-flight slot frees up.
        while len([t for t in tasks if t is not None]) >= max_concurrent_tasks:
            await asyncio.sleep(1)
        logger.info(f"scheduling case {case}")
        tasks.append(asyncio.create_task(generate_token_labels(args, case)))
    # Block until done
    while len([t for t in tasks if t is not None]) > 0:
        logger.info(f"waiting on {len([t for t in tasks if t is not None])} tasks")
        await asyncio.sleep(1)
    drain_completed = True
    await future_drain_completion
    await future_checkpoint_task_completion
    ds = Dataset.from_dict(ds_dict)
    ds.save_to_disk(f"{args.save_path}/{args.openai_model}_{ts_run}")
    logger.info(f"\n{ds}")
if __name__ == "__main__":
    import argparse
    import logging.config

    logging.config.dictConfig(default_logging_config)

    arg_parser = argparse.ArgumentParser(description="Train multi-task model.")
    arg_parser.add_argument("--openai-model", help="OpenAI model.",
                            action="store", default="o3-mini", choices=["gpt-4o", "o3-mini", "o1"])
    arg_parser.add_argument("--save-path", help="Save final dataset to specified path.",
                            action="store", default="./dataset")
    arg_parser.add_argument("--ud", help='Use UD datasets.',
                            action="store_true", default=False)
    parsed_args = arg_parser.parse_args()

    all_text = []
    if parsed_args.ud:
        # DRY fix: the six copy-pasted load_dataset stanzas are replaced with a
        # data-driven loop. Most English UD treebanks ship train/validation/test;
        # en_pronouns and en_pud ship only a test split.
        ud_configs = [
            ("en_ewt", ("train", "validation", "test")),
            ("en_gum", ("train", "validation", "test")),
            ("en_lines", ("train", "validation", "test")),
            ("en_partut", ("train", "validation", "test")),
            ("en_pronouns", ("test",)),
            ("en_pud", ("test",)),
        ]
        for config, splits in ud_configs:
            ud_ds = load_dataset("universal_dependencies", config)
            for split in splits:
                all_text += ud_ds[split]["text"]
        logger.info(f"{len(all_text)} UD examples")
    # Hand-curated examples covering edge cases: URLs, emoji, slang, typos,
    # punctuation-heavy text, named entities, markdown, etc.
    all_text += [
        "Hello world!",
        "127.0.0.1 is the localhost address.",
        "1/2 is equivalent to 0.5 or 50%",
        "John was running so fast, you can just tell he's a runner.",
        "He excels at math and competed in the Math Olympiad",
        "They're only $5!",
        "Where is your sense of adventure?",
        "I have only 3 cents.",
        "Watson was on his way to 221B Baker Street when the robbery occurred.",
        "That's uncopyrightable.",
        "She's full of incomprehensibilities.",
        "He's a total sesquipedalian.",
        "you piece of SHIT!!",
        "uh........... what..",
        "Steph Curry is GOAT!!",
        "[click here!](http://www.google.com)",
        "Dude! The stock's value grew like 10x in a year!",
        "Yea, I was at the DMV - God what a shit show!",
        "Send an email to help@example.com",
        "@goober, take your question to #corp-help-desk",
        "Example 1 : Joe Shmoe has a big toe.",
        "1. Steal under-pants. 2. ... 3. Profit!",
        "I expect `len(word_list) == 3`",
        "Home | Shop | Contact Us",
        "This is me on cake <(^.^)>",
        "and then he fell right on his face 😂",
        "Putin is from Russia.",
        "Zelenskyy is Ukrainian",
        "In 2013, the Pentagon and other agencies officially acknowledged the existence of Area-51.",
        "The Freedom of Information Act gives us the right to request access to records from any federal agency.",
        "His motives here are totally sus",
        "Yea, he finished by doing the Dab",
        "Be back i'll",
        "hes got a bon to pick",
        "so then he says then he says, \"you'll regret this\" lol",
    ]
    asyncio.run(main(parsed_args, all_text))