veryfansome committed on
Commit
5c7120a
·
1 Parent(s): c5081c8

feat: unified args

Browse files
Files changed (1) hide show
  1. dataset_maker.py +9 -7
dataset_maker.py CHANGED
@@ -103,12 +103,13 @@ prompts = {
103
  "wh": f"its semantic role",
104
  }
105
 
106
- async def classify_tokens(prompt: str, labels: dict[str, str], tokens: list[str], model="gpt-4o"):
 
107
  tok_len = len(tokens)
108
  example = "[" + (", ".join([f'"{tok}"' for tok in tokens])) + "]"
109
  try:
110
  response = await client.chat.completions.create(
111
- model=model, timeout=30,
112
  **({"reasoning_effort": "low"} if model.startswith("o") else {"presence_penalty": 0, "temperature": 0}),
113
  messages=[
114
  {
@@ -166,20 +167,21 @@ async def classify_tokens(prompt: str, labels: dict[str, str], tokens: list[str]
166
  raise
167
 
168
 
169
- async def classify_with_retry(prompt, labels, tokens, model="gpt-4o", retry=10):
170
  for i in range(retry):
171
  try:
172
- return await classify_tokens(prompt, labels, tokens, model=model)
173
  except Exception as e:
174
  logger.error(f"attempt {i} failed {tokens} {prompt} {format_exc()}")
175
  await asyncio.sleep(i)
176
 
177
- async def generate_token_labels(case, model="gpt-4o"):
 
178
  tokens = sp_tokenize(case)
179
  sorted_cols = list(sorted(features.keys()))
180
  example = {}
181
  for idx, labels in enumerate(list(await asyncio.gather(
182
- *[classify_with_retry(prompts[col], features[col], tokens, model=model) for col in sorted_cols]))):
183
  example[sorted_cols[idx]] = labels
184
  return example
185
 
@@ -229,7 +231,7 @@ async def main(args, cases):
229
  while len([t for t in tasks if t is not None]) >= max_concurrent_tasks:
230
  await asyncio.sleep(1)
231
  logger.info(f"scheduling case {case}")
232
- tasks.append(asyncio.create_task(generate_token_labels(case, model=args.openai_model)))
233
 
234
  # Block until done
235
  while len([t for t in tasks if t is not None]) > 0:
 
103
  "wh": f"its semantic role",
104
  }
105
 
106
+ async def classify_tokens(args, prompt: str, labels: dict[str, str], tokens: list[str],
107
+ model="gpt-4o"):
108
  tok_len = len(tokens)
109
  example = "[" + (", ".join([f'"{tok}"' for tok in tokens])) + "]"
110
  try:
111
  response = await client.chat.completions.create(
112
+ model=args.openai_model, timeout=30,
113
  **({"reasoning_effort": "low"} if model.startswith("o") else {"presence_penalty": 0, "temperature": 0}),
114
  messages=[
115
  {
 
167
  raise
168
 
169
 
170
+ async def classify_with_retry(args, prompt, labels, tokens, retry=10):
171
  for i in range(retry):
172
  try:
173
+ return await classify_tokens(args, prompt, labels, tokens)
174
  except Exception as e:
175
  logger.error(f"attempt {i} failed {tokens} {prompt} {format_exc()}")
176
  await asyncio.sleep(i)
177
 
178
+
179
+ async def generate_token_labels(args, case):
180
  tokens = sp_tokenize(case)
181
  sorted_cols = list(sorted(features.keys()))
182
  example = {}
183
  for idx, labels in enumerate(list(await asyncio.gather(
184
+ *[classify_with_retry(args, prompts[col], features[col], tokens) for col in sorted_cols]))):
185
  example[sorted_cols[idx]] = labels
186
  return example
187
 
 
231
  while len([t for t in tasks if t is not None]) >= max_concurrent_tasks:
232
  await asyncio.sleep(1)
233
  logger.info(f"scheduling case {case}")
234
+ tasks.append(asyncio.create_task(generate_token_labels(args, case)))
235
 
236
  # Block until done
237
  while len([t for t in tasks if t is not None]) > 0: