ping98k commited on
Commit
d95d0b3
·
unverified ·
2 Parent(s): 0d1b4d4 b936324

Merge pull request #12 from ping98k/codex/optimize-pairwise-tournament-with-parallelization

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. main.py +35 -50
  3. tests/test_main.py +3 -3
README.md CHANGED
@@ -39,7 +39,7 @@ This project provides a small interface for running "tournaments" between langua
39
  ```
40
  4. Open the displayed local URL. At the top of the page you can optionally override the API base path and token (the token field is blank by default). Additional settings let you configure score and pairwise filtering.
41
 
42
- The interface will generate multiple answers, optionally filter them by score and run a pairwise tournament to select the best outputs.
43
 
44
  ## Terminology
45
 
 
39
  ```
40
  4. Open the displayed local URL. At the top of the page you can optionally override the API base path and token (the token field is blank by default). Additional settings let you configure score and pairwise filtering.
41
 
42
+ The interface will generate multiple answers, optionally filter them by score and run a pairwise tournament to select the best outputs. Results from previous pairwise comparisons are cached, so duplicate matches are skipped for faster tournaments. Pairwise results are aggregated using an Elo rating system to rank the players.
43
 
44
  ## Terminology
45
 
main.py CHANGED
@@ -78,8 +78,8 @@ def run_tournament(
78
  generate_thinking,
79
  score_thinking,
80
  pairwise_thinking,
81
- score_explain,
82
- pairwise_explain,
83
  ):
84
  instruction = instruction_input.strip()
85
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
@@ -127,6 +127,7 @@ def run_tournament(
127
  completion_tokens = 0
128
  score_outputs: list[str] = []
129
  pairwise_outputs: list[str] = []
 
130
 
131
  def add_usage(usage):
132
  nonlocal prompt_tokens, completion_tokens
@@ -224,6 +225,9 @@ def run_tournament(
224
  top_players = all_players
225
  if enable_pairwise_filter:
226
  def play(a, b):
 
 
 
227
  text, usage = prompt_pairwise(
228
  instruction,
229
  criteria_block(),
@@ -241,68 +245,49 @@ def run_tournament(
241
  add_usage(usage)
242
  pairwise_outputs.append(text)
243
  winner_label = _clean_json(text).get("winner", "A")
244
- return a if winner_label == "A" else b
 
 
245
 
246
- def tournament_round(pairs, executor, progress):
 
 
 
 
 
 
 
247
  futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
248
- results = []
 
249
  for fut in as_completed(futures):
250
  a, b = futures[fut]
251
  winner = fut.result()
252
  loser = b if winner == a else a
253
- results.append((winner, loser))
254
- yield from log(progress.step())
255
- return results
256
-
257
- def tournament(players, executor):
258
- lost_to = {}
259
- current = players[:]
260
- progress = SimpleProgress(len(players) - 1, "Pairwise round")
261
- while len(current) > 1:
262
- leftover = current[-1] if len(current) % 2 == 1 else None
263
- pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
264
- round_results = yield from tournament_round(pairs, executor, progress)
265
- for w, l in round_results:
266
- lost_to[l] = w
267
- current = [w for w, _ in round_results]
268
- if leftover:
269
- current.append(leftover)
270
- return current[0], lost_to
271
-
272
- def get_candidates(champion, lost_to):
273
- return [p for p, o in lost_to.items() if o == champion] + [champion]
274
-
275
- def playoff(candidates, executor):
276
- wins = {p: 0 for p in candidates}
277
- pairs = [
278
- (candidates[i], candidates[j])
279
- for i in range(len(candidates))
280
- for j in range(i + 1, len(candidates))
281
- ]
282
- futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
283
- prog = SimpleProgress(len(futures), "Playoff")
284
- for fut in as_completed(futures):
285
- wins[fut.result()] += 1
286
  yield from log(prog.step())
287
- return sorted(candidates, key=lambda p: wins[p], reverse=True)
288
-
289
- def get_top(players, executor):
290
- champion, lost_to = yield from tournament(players, executor)
291
- runner_up = lost_to.get(champion)
292
- finalists = [champion] + ([runner_up] if runner_up else [])
293
- semifinalists = [p for p, o in lost_to.items() if o in finalists and p not in finalists]
294
- candidates = list(set(finalists + semifinalists + get_candidates(champion, lost_to)))
295
- result = yield from playoff(candidates, executor)
296
- return result[:num_top_picks]
297
 
298
  yield from log("Pairwise generating")
299
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
300
- top_k = yield from get_top(top_players, ex)
 
301
  for i, txt in enumerate(pairwise_outputs, 1):
302
  yield from log_completion(f"Pairwise completion {i}: ", txt)
 
 
 
303
  else:
304
  top_k = top_players[:num_top_picks]
305
- top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
306
  yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str, usage_str()
307
 
308
  demo = gr.Interface(
 
78
  generate_thinking,
79
  score_thinking,
80
  pairwise_thinking,
81
+ score_explain=None,
82
+ pairwise_explain=None,
83
  ):
84
  instruction = instruction_input.strip()
85
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
 
127
  completion_tokens = 0
128
  score_outputs: list[str] = []
129
  pairwise_outputs: list[str] = []
130
+ match_cache: dict[tuple[str, str], str] = {}
131
 
132
  def add_usage(usage):
133
  nonlocal prompt_tokens, completion_tokens
 
225
  top_players = all_players
226
  if enable_pairwise_filter:
227
  def play(a, b):
228
+ key = tuple(sorted((a, b)))
229
+ if key in match_cache:
230
+ return match_cache[key]
231
  text, usage = prompt_pairwise(
232
  instruction,
233
  criteria_block(),
 
245
  add_usage(usage)
246
  pairwise_outputs.append(text)
247
  winner_label = _clean_json(text).get("winner", "A")
248
+ winner = a if winner_label == "A" else b
249
+ match_cache[key] = winner
250
+ return winner
251
 
252
+ def all_pairs(players):
253
+ for i in range(len(players)):
254
+ for j in range(i + 1, len(players)):
255
+ yield players[i], players[j]
256
+
257
+ def rate(players, executor):
258
+ rating = {p: 1000.0 for p in players}
259
+ pairs = list(all_pairs(players))
260
  futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
261
+ prog = SimpleProgress(len(futures), "Elo matches")
262
+ K = 32
263
  for fut in as_completed(futures):
264
  a, b = futures[fut]
265
  winner = fut.result()
266
  loser = b if winner == a else a
267
+ ra, rb = rating[a], rating[b]
268
+ ea = 1 / (1 + 10 ** ((rb - ra) / 400))
269
+ eb = 1 - ea
270
+ if winner == a:
271
+ rating[a] = ra + K * (1 - ea)
272
+ rating[b] = rb + K * (0 - eb)
273
+ else:
274
+ rating[a] = ra + K * (0 - ea)
275
+ rating[b] = rb + K * (1 - eb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  yield from log(prog.step())
277
+ return rating
 
 
 
 
 
 
 
 
 
278
 
279
  yield from log("Pairwise generating")
280
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
281
+ rating = yield from rate(top_players, ex)
282
+ top_k = sorted(top_players, key=rating.get, reverse=True)[:num_top_picks]
283
  for i, txt in enumerate(pairwise_outputs, 1):
284
  yield from log_completion(f"Pairwise completion {i}: ", txt)
285
+ top_picks_str = "\n\n\n=====================================================\n\n\n".join(
286
+ f"{p}\nElo: {rating[p]:.1f}" for p in top_k
287
+ )
288
  else:
289
  top_k = top_players[:num_top_picks]
290
+ top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
291
  yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str, usage_str()
292
 
293
  demo = gr.Interface(
tests/test_main.py CHANGED
@@ -114,7 +114,7 @@ def test_run_tournament_full_loop():
114
  process_log, hist_fig, top_picks, usage = results[-1]
115
  assert 'Done' in process_log
116
  assert hist_fig == 'fig'
117
- assert top_picks.strip() in {'p1', 'p2'}
118
  mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, thinking=True, return_usage=True)
119
  assert 'Score completion' in process_log
120
  assert 'Pairwise completion' in process_log
@@ -161,5 +161,5 @@ def test_run_tournament_pairwise_odd_players():
161
 
162
  process_log, fig, top_picks, usage = results[-1]
163
  assert 'Done' in process_log
164
- assert top_picks.strip() in {'p1', 'p2', 'p3'}
165
- assert mock_pair.call_count == 5
 
114
  process_log, hist_fig, top_picks, usage = results[-1]
115
  assert 'Done' in process_log
116
  assert hist_fig == 'fig'
117
+ assert any(p in top_picks for p in {'p1', 'p2'})
118
  mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, thinking=True, return_usage=True)
119
  assert 'Score completion' in process_log
120
  assert 'Pairwise completion' in process_log
 
161
 
162
  process_log, fig, top_picks, usage = results[-1]
163
  assert 'Done' in process_log
164
+ assert any(p in top_picks for p in {'p1', 'p2', 'p3'})
165
+ assert mock_pair.call_count == 3