Spaces:
Sleeping
Sleeping
ping98k
committed on
Commit
·
2b53c20
1
Parent(s):
3bd9ad6
Fix pairwise tournament loop
Browse files
- main.py +3 -2
- tests/test_main.py +34 -0
main.py
CHANGED
|
@@ -181,13 +181,14 @@ def run_tournament(
|
|
| 181 |
lost_to = {}
|
| 182 |
current = players[:]
|
| 183 |
while len(current) > 1:
|
|
|
|
| 184 |
pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
|
| 185 |
round_results = tournament_round(pairs, executor)
|
| 186 |
for w, l in round_results:
|
| 187 |
lost_to[l] = w
|
| 188 |
current = [w for w, _ in round_results]
|
| 189 |
-
if
|
| 190 |
-
current.append(
|
| 191 |
return current[0], lost_to
|
| 192 |
|
| 193 |
def get_candidates(champion, lost_to):
|
|
|
|
| 181 |
lost_to = {}
|
| 182 |
current = players[:]
|
| 183 |
while len(current) > 1:
|
| 184 |
+
leftover = current[-1] if len(current) % 2 == 1 else None
|
| 185 |
pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
|
| 186 |
round_results = tournament_round(pairs, executor)
|
| 187 |
for w, l in round_results:
|
| 188 |
lost_to[l] = w
|
| 189 |
current = [w for w, _ in round_results]
|
| 190 |
+
if leftover:
|
| 191 |
+
current.append(leftover)
|
| 192 |
return current[0], lost_to
|
| 193 |
|
| 194 |
def get_candidates(champion, lost_to):
|
tests/test_main.py
CHANGED
|
@@ -113,3 +113,37 @@ def test_run_tournament_full_loop():
|
|
| 113 |
assert 'Prompt tokens' in usage
|
| 114 |
assert mock_score.call_count == 4
|
| 115 |
assert mock_pair.called
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
assert 'Prompt tokens' in usage
|
| 114 |
assert mock_score.call_count == 4
|
| 115 |
assert mock_pair.called
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def test_run_tournament_pairwise_odd_players():
|
| 119 |
+
dummy_tqdm = DummyTqdm()
|
| 120 |
+
with patch('main.generate_players') as mock_gen, \
|
| 121 |
+
patch('main.prompt_pairwise') as mock_pair, \
|
| 122 |
+
patch('main.ThreadPoolExecutor', return_value=DummyExecutor()) as MockEx, \
|
| 123 |
+
patch('main.as_completed', new=lambda futs: futs), \
|
| 124 |
+
patch('main.tqdm', new=dummy_tqdm), \
|
| 125 |
+
patch('main.plt.figure', return_value='fig'), \
|
| 126 |
+
patch('main.plt.hist'):
|
| 127 |
+
mock_gen.return_value = (['p1', 'p2', 'p3'], {'prompt_tokens':1,'completion_tokens':1})
|
| 128 |
+
mock_pair.side_effect = lambda instr, block, a, b, **kw: (json.dumps({'winner':'A'}), {'prompt_tokens':1,'completion_tokens':1})
|
| 129 |
+
|
| 130 |
+
results = list(main.run_tournament(
|
| 131 |
+
api_base='b',
|
| 132 |
+
api_token='k',
|
| 133 |
+
generate_model='gm',
|
| 134 |
+
score_model='sm',
|
| 135 |
+
pairwise_model='pm',
|
| 136 |
+
instruction_input='instr',
|
| 137 |
+
criteria_input='c1,c2',
|
| 138 |
+
n_gen=3,
|
| 139 |
+
pool_size=3,
|
| 140 |
+
num_top_picks=1,
|
| 141 |
+
max_workers=1,
|
| 142 |
+
enable_score_filter=False,
|
| 143 |
+
enable_pairwise_filter=True,
|
| 144 |
+
))
|
| 145 |
+
|
| 146 |
+
process_log, fig, top_picks, usage = results[-1]
|
| 147 |
+
assert 'Done' in process_log
|
| 148 |
+
assert top_picks.strip() in {'p1', 'p2', 'p3'}
|
| 149 |
+
assert mock_pair.call_count == 5
|