ping98k committed on
Commit
348dd93
·
unverified ·
2 Parent(s): f801763 41ec98d

Merge pull request #10 from ping98k/codex/add-id-to-each-player

Browse files
Files changed (1) hide show
  1. main.py +17 -12
main.py CHANGED
@@ -1,5 +1,5 @@
1
  from dotenv import load_dotenv
2
- load_dotenv("./local.env")
3
  import os, json, re, ast, gradio as gr
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
5
  from tqdm import tqdm
@@ -30,7 +30,7 @@ class SimpleProgress:
30
  return f"{self.prefix} {self.count}/{self.total} - ETA {eta}"
31
 
32
  NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 3))
33
- POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 5))
34
  MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 100))
35
  NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS", 10))
36
  API_BASE_DEFAULT = os.getenv("OPENAI_API_BASE", "")
@@ -40,9 +40,9 @@ PAIRWISE_FILTER_DEFAULT = os.getenv("ENABLE_PAIRWISE_FILTER", "true").lower() ==
40
  GENERATE_MODEL_DEFAULT = os.getenv("GENERATE_MODEL", "gpt-4o-mini")
41
  SCORE_MODEL_DEFAULT = os.getenv("SCORE_MODEL", "gpt-4o-mini")
42
  PAIRWISE_MODEL_DEFAULT = os.getenv("PAIRWISE_MODEL", "gpt-4o-mini")
43
- GENERATE_TEMPERATURE_DEFAULT = float(os.getenv("GENERATE_TEMPERATURE", "1.0"))
44
- SCORE_TEMPERATURE_DEFAULT = float(os.getenv("SCORE_TEMPERATURE", "0.1"))
45
- PAIRWISE_TEMPERATURE_DEFAULT = float(os.getenv("PAIRWISE_TEMPERATURE", "0.1"))
46
  SCORE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_SCORE", "true").lower() == "true"
47
  PAIRWISE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_PAIRWISE", "true").lower() == "true"
48
  CRITERIA_DEFAULT = "Factuality,Instruction Following,Precision"
@@ -131,10 +131,12 @@ def run_tournament(
131
  f"Total tokens: {prompt_tokens + completion_tokens}"
132
  )
133
 
134
- def log_completion(prefix: str, text: str):
135
  disp = text.replace("\n", " ")
136
  if len(disp) > 100:
137
  disp = disp[:100] + "…"
 
 
138
  return log(f"{prefix}{disp}")
139
  def log(msg):
140
  process_log.append(msg)
@@ -153,12 +155,15 @@ def run_tournament(
153
  add_usage(usage)
154
  yield from log(f"{len(all_players)} players generated")
155
  for i, p in enumerate(all_players, 1):
156
- yield from log_completion(f"Completion {i}: ", p)
157
  def criteria_block():
158
  return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
159
 
160
  if enable_score_filter:
161
- def score(player):
 
 
 
162
  text, usage = prompt_score(
163
  instruction,
164
  criteria_list,
@@ -172,7 +177,7 @@ def run_tournament(
172
  return_usage=True,
173
  )
174
  add_usage(usage)
175
- score_outputs.append(text)
176
  data = _clean_json(text)
177
  if "scores" in data and isinstance(data["scores"], list):
178
  vals = data["scores"]
@@ -183,7 +188,7 @@ def run_tournament(
183
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
184
  prog = SimpleProgress(len(all_players), "Scoring")
185
  scores = {}
186
- for p, s in zip(all_players, ex.map(score, all_players)):
187
  scores[p] = s
188
  yield from log(prog.step())
189
  hist_fig = plt.figure()
@@ -191,8 +196,8 @@ def run_tournament(
191
  yield from log("Histogram generated")
192
  top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
193
  yield from log(f"Filtered to {len(top_players)} players with best scores")
194
- for i, txt in enumerate(score_outputs, 1):
195
- yield from log_completion(f"Score completion {i}: ", txt)
196
  else:
197
  top_players = all_players
198
  if enable_pairwise_filter:
 
1
  from dotenv import load_dotenv
2
+ load_dotenv("./local.env",override=True)
3
  import os, json, re, ast, gradio as gr
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
5
  from tqdm import tqdm
 
30
  return f"{self.prefix} {self.count}/{self.total} - ETA {eta}"
31
 
32
  NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 3))
33
+ POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 6))
34
  MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 100))
35
  NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS", 10))
36
  API_BASE_DEFAULT = os.getenv("OPENAI_API_BASE", "")
 
40
  GENERATE_MODEL_DEFAULT = os.getenv("GENERATE_MODEL", "gpt-4o-mini")
41
  SCORE_MODEL_DEFAULT = os.getenv("SCORE_MODEL", "gpt-4o-mini")
42
  PAIRWISE_MODEL_DEFAULT = os.getenv("PAIRWISE_MODEL", "gpt-4o-mini")
43
+ GENERATE_TEMPERATURE_DEFAULT = float(os.getenv("GENERATE_TEMPERATURE", "0.9"))
44
+ SCORE_TEMPERATURE_DEFAULT = float(os.getenv("SCORE_TEMPERATURE", "0.6"))
45
+ PAIRWISE_TEMPERATURE_DEFAULT = float(os.getenv("PAIRWISE_TEMPERATURE", "0.6"))
46
  SCORE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_SCORE", "true").lower() == "true"
47
  PAIRWISE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_PAIRWISE", "true").lower() == "true"
48
  CRITERIA_DEFAULT = "Factuality,Instruction Following,Precision"
 
131
  f"Total tokens: {prompt_tokens + completion_tokens}"
132
  )
133
 
134
+ def log_completion(prefix: str, text: str, player_id: int | None = None):
135
  disp = text.replace("\n", " ")
136
  if len(disp) > 100:
137
  disp = disp[:100] + "…"
138
+ if player_id is not None:
139
+ prefix = f"{prefix}(ID {player_id}) "
140
  return log(f"{prefix}{disp}")
141
  def log(msg):
142
  process_log.append(msg)
 
155
  add_usage(usage)
156
  yield from log(f"{len(all_players)} players generated")
157
  for i, p in enumerate(all_players, 1):
158
+ yield from log_completion(f"Completion {i}: ", p, i)
159
  def criteria_block():
160
  return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
161
 
162
  if enable_score_filter:
163
+ players_with_ids = list(enumerate(all_players, 1))
164
+
165
+ def score(item):
166
+ idx, player = item
167
  text, usage = prompt_score(
168
  instruction,
169
  criteria_list,
 
177
  return_usage=True,
178
  )
179
  add_usage(usage)
180
+ score_outputs.append((idx, text))
181
  data = _clean_json(text)
182
  if "scores" in data and isinstance(data["scores"], list):
183
  vals = data["scores"]
 
188
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
189
  prog = SimpleProgress(len(all_players), "Scoring")
190
  scores = {}
191
+ for (idx, p), s in zip(players_with_ids, ex.map(score, players_with_ids)):
192
  scores[p] = s
193
  yield from log(prog.step())
194
  hist_fig = plt.figure()
 
196
  yield from log("Histogram generated")
197
  top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
198
  yield from log(f"Filtered to {len(top_players)} players with best scores")
199
+ for i, (idx, txt) in enumerate(score_outputs, 1):
200
+ yield from log_completion(f"Score completion {i}: ", txt, idx)
201
  else:
202
  top_players = all_players
203
  if enable_pairwise_filter: