Spaces:
Sleeping
Sleeping
Merge pull request #12 from ping98k/codex/optimize-pairwise-tournament-with-parallelization
Browse files
- README.md +1 -1
- main.py +35 -50
- tests/test_main.py +3 -3
README.md
CHANGED
|
@@ -39,7 +39,7 @@ This project provides a small interface for running "tournaments" between langua
|
|
| 39 |
```
|
| 40 |
4. Open the displayed local URL. At the top of the page you can optionally override the API base path and token (the token field is blank by default). Additional settings let you configure score and pairwise filtering.
|
| 41 |
|
| 42 |
-
The interface will generate multiple answers, optionally filter them by score and run a pairwise tournament to select the best outputs.
|
| 43 |
|
| 44 |
## Terminology
|
| 45 |
|
|
|
|
| 39 |
```
|
| 40 |
4. Open the displayed local URL. At the top of the page you can optionally override the API base path and token (the token field is blank by default). Additional settings let you configure score and pairwise filtering.
|
| 41 |
|
| 42 |
+
The interface will generate multiple answers, optionally filter them by score and run a pairwise tournament to select the best outputs. Results from previous pairwise comparisons are cached, so duplicate matches are skipped for faster tournaments. Pairwise results are aggregated using an Elo rating system to rank the players.
|
| 43 |
|
| 44 |
## Terminology
|
| 45 |
|
main.py
CHANGED
|
@@ -78,8 +78,8 @@ def run_tournament(
|
|
| 78 |
generate_thinking,
|
| 79 |
score_thinking,
|
| 80 |
pairwise_thinking,
|
| 81 |
-
score_explain,
|
| 82 |
-
pairwise_explain,
|
| 83 |
):
|
| 84 |
instruction = instruction_input.strip()
|
| 85 |
criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
|
|
@@ -127,6 +127,7 @@ def run_tournament(
|
|
| 127 |
completion_tokens = 0
|
| 128 |
score_outputs: list[str] = []
|
| 129 |
pairwise_outputs: list[str] = []
|
|
|
|
| 130 |
|
| 131 |
def add_usage(usage):
|
| 132 |
nonlocal prompt_tokens, completion_tokens
|
|
@@ -224,6 +225,9 @@ def run_tournament(
|
|
| 224 |
top_players = all_players
|
| 225 |
if enable_pairwise_filter:
|
| 226 |
def play(a, b):
|
|
|
|
|
|
|
|
|
|
| 227 |
text, usage = prompt_pairwise(
|
| 228 |
instruction,
|
| 229 |
criteria_block(),
|
|
@@ -241,68 +245,49 @@ def run_tournament(
|
|
| 241 |
add_usage(usage)
|
| 242 |
pairwise_outputs.append(text)
|
| 243 |
winner_label = _clean_json(text).get("winner", "A")
|
| 244 |
-
|
|
|
|
|
|
|
| 245 |
|
| 246 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
|
| 248 |
-
|
|
|
|
| 249 |
for fut in as_completed(futures):
|
| 250 |
a, b = futures[fut]
|
| 251 |
winner = fut.result()
|
| 252 |
loser = b if winner == a else a
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
leftover = current[-1] if len(current) % 2 == 1 else None
|
| 263 |
-
pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
|
| 264 |
-
round_results = yield from tournament_round(pairs, executor, progress)
|
| 265 |
-
for w, l in round_results:
|
| 266 |
-
lost_to[l] = w
|
| 267 |
-
current = [w for w, _ in round_results]
|
| 268 |
-
if leftover:
|
| 269 |
-
current.append(leftover)
|
| 270 |
-
return current[0], lost_to
|
| 271 |
-
|
| 272 |
-
def get_candidates(champion, lost_to):
|
| 273 |
-
return [p for p, o in lost_to.items() if o == champion] + [champion]
|
| 274 |
-
|
| 275 |
-
def playoff(candidates, executor):
|
| 276 |
-
wins = {p: 0 for p in candidates}
|
| 277 |
-
pairs = [
|
| 278 |
-
(candidates[i], candidates[j])
|
| 279 |
-
for i in range(len(candidates))
|
| 280 |
-
for j in range(i + 1, len(candidates))
|
| 281 |
-
]
|
| 282 |
-
futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
|
| 283 |
-
prog = SimpleProgress(len(futures), "Playoff")
|
| 284 |
-
for fut in as_completed(futures):
|
| 285 |
-
wins[fut.result()] += 1
|
| 286 |
yield from log(prog.step())
|
| 287 |
-
return
|
| 288 |
-
|
| 289 |
-
def get_top(players, executor):
|
| 290 |
-
champion, lost_to = yield from tournament(players, executor)
|
| 291 |
-
runner_up = lost_to.get(champion)
|
| 292 |
-
finalists = [champion] + ([runner_up] if runner_up else [])
|
| 293 |
-
semifinalists = [p for p, o in lost_to.items() if o in finalists and p not in finalists]
|
| 294 |
-
candidates = list(set(finalists + semifinalists + get_candidates(champion, lost_to)))
|
| 295 |
-
result = yield from playoff(candidates, executor)
|
| 296 |
-
return result[:num_top_picks]
|
| 297 |
|
| 298 |
yield from log("Pairwise generating")
|
| 299 |
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 300 |
-
|
|
|
|
| 301 |
for i, txt in enumerate(pairwise_outputs, 1):
|
| 302 |
yield from log_completion(f"Pairwise completion {i}: ", txt)
|
|
|
|
|
|
|
|
|
|
| 303 |
else:
|
| 304 |
top_k = top_players[:num_top_picks]
|
| 305 |
-
|
| 306 |
yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str, usage_str()
|
| 307 |
|
| 308 |
demo = gr.Interface(
|
|
|
|
| 78 |
generate_thinking,
|
| 79 |
score_thinking,
|
| 80 |
pairwise_thinking,
|
| 81 |
+
score_explain=None,
|
| 82 |
+
pairwise_explain=None,
|
| 83 |
):
|
| 84 |
instruction = instruction_input.strip()
|
| 85 |
criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
|
|
|
|
| 127 |
completion_tokens = 0
|
| 128 |
score_outputs: list[str] = []
|
| 129 |
pairwise_outputs: list[str] = []
|
| 130 |
+
match_cache: dict[tuple[str, str], str] = {}
|
| 131 |
|
| 132 |
def add_usage(usage):
|
| 133 |
nonlocal prompt_tokens, completion_tokens
|
|
|
|
| 225 |
top_players = all_players
|
| 226 |
if enable_pairwise_filter:
|
| 227 |
def play(a, b):
|
| 228 |
+
key = tuple(sorted((a, b)))
|
| 229 |
+
if key in match_cache:
|
| 230 |
+
return match_cache[key]
|
| 231 |
text, usage = prompt_pairwise(
|
| 232 |
instruction,
|
| 233 |
criteria_block(),
|
|
|
|
| 245 |
add_usage(usage)
|
| 246 |
pairwise_outputs.append(text)
|
| 247 |
winner_label = _clean_json(text).get("winner", "A")
|
| 248 |
+
winner = a if winner_label == "A" else b
|
| 249 |
+
match_cache[key] = winner
|
| 250 |
+
return winner
|
| 251 |
|
| 252 |
+
def all_pairs(players):
|
| 253 |
+
for i in range(len(players)):
|
| 254 |
+
for j in range(i + 1, len(players)):
|
| 255 |
+
yield players[i], players[j]
|
| 256 |
+
|
| 257 |
+
def rate(players, executor):
|
| 258 |
+
rating = {p: 1000.0 for p in players}
|
| 259 |
+
pairs = list(all_pairs(players))
|
| 260 |
futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
|
| 261 |
+
prog = SimpleProgress(len(futures), "Elo matches")
|
| 262 |
+
K = 32
|
| 263 |
for fut in as_completed(futures):
|
| 264 |
a, b = futures[fut]
|
| 265 |
winner = fut.result()
|
| 266 |
loser = b if winner == a else a
|
| 267 |
+
ra, rb = rating[a], rating[b]
|
| 268 |
+
ea = 1 / (1 + 10 ** ((rb - ra) / 400))
|
| 269 |
+
eb = 1 - ea
|
| 270 |
+
if winner == a:
|
| 271 |
+
rating[a] = ra + K * (1 - ea)
|
| 272 |
+
rating[b] = rb + K * (0 - eb)
|
| 273 |
+
else:
|
| 274 |
+
rating[a] = ra + K * (0 - ea)
|
| 275 |
+
rating[b] = rb + K * (1 - eb)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
yield from log(prog.step())
|
| 277 |
+
return rating
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
yield from log("Pairwise generating")
|
| 280 |
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 281 |
+
rating = yield from rate(top_players, ex)
|
| 282 |
+
top_k = sorted(top_players, key=rating.get, reverse=True)[:num_top_picks]
|
| 283 |
for i, txt in enumerate(pairwise_outputs, 1):
|
| 284 |
yield from log_completion(f"Pairwise completion {i}: ", txt)
|
| 285 |
+
top_picks_str = "\n\n\n=====================================================\n\n\n".join(
|
| 286 |
+
f"{p}\nElo: {rating[p]:.1f}" for p in top_k
|
| 287 |
+
)
|
| 288 |
else:
|
| 289 |
top_k = top_players[:num_top_picks]
|
| 290 |
+
top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
|
| 291 |
yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str, usage_str()
|
| 292 |
|
| 293 |
demo = gr.Interface(
|
tests/test_main.py
CHANGED
|
@@ -114,7 +114,7 @@ def test_run_tournament_full_loop():
|
|
| 114 |
process_log, hist_fig, top_picks, usage = results[-1]
|
| 115 |
assert 'Done' in process_log
|
| 116 |
assert hist_fig == 'fig'
|
| 117 |
-
assert
|
| 118 |
mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, thinking=True, return_usage=True)
|
| 119 |
assert 'Score completion' in process_log
|
| 120 |
assert 'Pairwise completion' in process_log
|
|
@@ -161,5 +161,5 @@ def test_run_tournament_pairwise_odd_players():
|
|
| 161 |
|
| 162 |
process_log, fig, top_picks, usage = results[-1]
|
| 163 |
assert 'Done' in process_log
|
| 164 |
-
assert
|
| 165 |
-
assert mock_pair.call_count ==
|
|
|
|
| 114 |
process_log, hist_fig, top_picks, usage = results[-1]
|
| 115 |
assert 'Done' in process_log
|
| 116 |
assert hist_fig == 'fig'
|
| 117 |
+
assert any(p in top_picks for p in {'p1', 'p2'})
|
| 118 |
mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, thinking=True, return_usage=True)
|
| 119 |
assert 'Score completion' in process_log
|
| 120 |
assert 'Pairwise completion' in process_log
|
|
|
|
| 161 |
|
| 162 |
process_log, fig, top_picks, usage = results[-1]
|
| 163 |
assert 'Done' in process_log
|
| 164 |
+
assert any(p in top_picks for p in {'p1', 'p2', 'p3'})
|
| 165 |
+
assert mock_pair.call_count == 3
|