Commit
·
5c7120a
1
Parent(s):
c5081c8
feat: unified args
Browse files
- dataset_maker.py +9 -7
dataset_maker.py
CHANGED
|
@@ -103,12 +103,13 @@ prompts = {
|
|
| 103 |
"wh": f"its semantic role",
|
| 104 |
}
|
| 105 |
|
| 106 |
-
async def classify_tokens(prompt: str, labels: dict[str, str], tokens: list[str],
|
|
|
|
| 107 |
tok_len = len(tokens)
|
| 108 |
example = "[" + (", ".join([f'"{tok}"' for tok in tokens])) + "]"
|
| 109 |
try:
|
| 110 |
response = await client.chat.completions.create(
|
| 111 |
-
model=
|
| 112 |
**({"reasoning_effort": "low"} if model.startswith("o") else {"presence_penalty": 0, "temperature": 0}),
|
| 113 |
messages=[
|
| 114 |
{
|
|
@@ -166,20 +167,21 @@ async def classify_tokens(prompt: str, labels: dict[str, str], tokens: list[str]
|
|
| 166 |
raise
|
| 167 |
|
| 168 |
|
| 169 |
-
async def classify_with_retry(prompt, labels, tokens,
|
| 170 |
for i in range(retry):
|
| 171 |
try:
|
| 172 |
-
return await classify_tokens(prompt, labels, tokens
|
| 173 |
except Exception as e:
|
| 174 |
logger.error(f"attempt {i} failed {tokens} {prompt} {format_exc()}")
|
| 175 |
await asyncio.sleep(i)
|
| 176 |
|
| 177 |
-
|
|
|
|
| 178 |
tokens = sp_tokenize(case)
|
| 179 |
sorted_cols = list(sorted(features.keys()))
|
| 180 |
example = {}
|
| 181 |
for idx, labels in enumerate(list(await asyncio.gather(
|
| 182 |
-
*[classify_with_retry(prompts[col], features[col], tokens
|
| 183 |
example[sorted_cols[idx]] = labels
|
| 184 |
return example
|
| 185 |
|
|
@@ -229,7 +231,7 @@ async def main(args, cases):
|
|
| 229 |
while len([t for t in tasks if t is not None]) >= max_concurrent_tasks:
|
| 230 |
await asyncio.sleep(1)
|
| 231 |
logger.info(f"scheduling case {case}")
|
| 232 |
-
tasks.append(asyncio.create_task(generate_token_labels(
|
| 233 |
|
| 234 |
# Block until done
|
| 235 |
while len([t for t in tasks if t is not None]) > 0:
|
|
|
|
| 103 |
"wh": f"its semantic role",
|
| 104 |
}
|
| 105 |
|
| 106 |
+
async def classify_tokens(args, prompt: str, labels: dict[str, str], tokens: list[str],
|
| 107 |
+
model="gpt-4o"):
|
| 108 |
tok_len = len(tokens)
|
| 109 |
example = "[" + (", ".join([f'"{tok}"' for tok in tokens])) + "]"
|
| 110 |
try:
|
| 111 |
response = await client.chat.completions.create(
|
| 112 |
+
model=args.openai_model, timeout=30,
|
| 113 |
**({"reasoning_effort": "low"} if model.startswith("o") else {"presence_penalty": 0, "temperature": 0}),
|
| 114 |
messages=[
|
| 115 |
{
|
|
|
|
| 167 |
raise
|
| 168 |
|
| 169 |
|
| 170 |
+
async def classify_with_retry(args, prompt, labels, tokens, retry=10):
    """Call ``classify_tokens`` up to ``retry`` times, backing off between attempts.

    Args:
        args: Run configuration, passed through unchanged to ``classify_tokens``.
        prompt: Prompt text, passed through to ``classify_tokens``.
        labels: Label mapping, passed through to ``classify_tokens``.
        tokens: Tokens to classify, passed through to ``classify_tokens``.
        retry: Maximum number of attempts.

    Returns:
        Whatever ``classify_tokens`` returns on the first successful attempt,
        or ``None`` when every attempt fails (failures are logged, not raised).
    """
    for i in range(retry):
        try:
            return await classify_tokens(args, prompt, labels, tokens)
        except Exception:
            logger.error(f"attempt {i} failed {tokens} {prompt} {format_exc()}")
            # Linear backoff: 0s, 1s, 2s, ... After the final attempt there is
            # nothing left to wait for, so skip that sleep.
            if i < retry - 1:
                await asyncio.sleep(i)
    # Keep the best-effort contract (return None), but make the exhaustion
    # visible in the log instead of falling off the end silently.
    logger.error(f"giving up on {tokens} after {retry} attempts")
    return None
|
| 177 |
|
| 178 |
+
|
| 179 |
+
async def generate_token_labels(args, case):
    """Label the tokens of ``case`` once per feature column.

    Tokenizes ``case`` with ``sp_tokenize`` and issues one
    ``classify_with_retry`` call per key of the module-level ``features``
    mapping, each paired with the matching ``prompts`` entry. The calls are
    awaited together through ``asyncio.gather``.

    Args:
        args: Run configuration, forwarded to ``classify_with_retry``.
        case: Input handed to ``sp_tokenize``.

    Returns:
        dict mapping each feature column name to the result of its
        classification call.
    """
    tokens = sp_tokenize(case)
    sorted_cols = sorted(features)
    # gather() returns results in the order the awaitables were passed in, so
    # they line up one-to-one with sorted_cols.
    results = await asyncio.gather(
        *(classify_with_retry(args, prompts[col], features[col], tokens) for col in sorted_cols)
    )
    return dict(zip(sorted_cols, results))
|
| 187 |
|
|
|
|
| 231 |
while len([t for t in tasks if t is not None]) >= max_concurrent_tasks:
|
| 232 |
await asyncio.sleep(1)
|
| 233 |
logger.info(f"scheduling case {case}")
|
| 234 |
+
tasks.append(asyncio.create_task(generate_token_labels(args, case)))
|
| 235 |
|
| 236 |
# Block until done
|
| 237 |
while len([t for t in tasks if t is not None]) > 0:
|