ping98k commited on
Commit
2b53c20
·
1 Parent(s): 3bd9ad6

Fix pairwise tournament loop

Browse files
Files changed (2) hide show
  1. main.py +3 -2
  2. tests/test_main.py +34 -0
main.py CHANGED
@@ -181,13 +181,14 @@ def run_tournament(
181
  lost_to = {}
182
  current = players[:]
183
  while len(current) > 1:
 
184
  pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
185
  round_results = tournament_round(pairs, executor)
186
  for w, l in round_results:
187
  lost_to[l] = w
188
  current = [w for w, _ in round_results]
189
- if len(players) % 2 == 1 and players[-1] not in current:
190
- current.append(players[-1])
191
  return current[0], lost_to
192
 
193
  def get_candidates(champion, lost_to):
 
181
  lost_to = {}
182
  current = players[:]
183
  while len(current) > 1:
184
+ leftover = current[-1] if len(current) % 2 == 1 else None
185
  pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
186
  round_results = tournament_round(pairs, executor)
187
  for w, l in round_results:
188
  lost_to[l] = w
189
  current = [w for w, _ in round_results]
190
+ if leftover:
191
+ current.append(leftover)
192
  return current[0], lost_to
193
 
194
  def get_candidates(champion, lost_to):
tests/test_main.py CHANGED
@@ -113,3 +113,37 @@ def test_run_tournament_full_loop():
113
  assert 'Prompt tokens' in usage
114
  assert mock_score.call_count == 4
115
  assert mock_pair.called
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  assert 'Prompt tokens' in usage
114
  assert mock_score.call_count == 4
115
  assert mock_pair.called
116
+
117
+
118
+ def test_run_tournament_pairwise_odd_players():
119
+ dummy_tqdm = DummyTqdm()
120
+ with patch('main.generate_players') as mock_gen, \
121
+ patch('main.prompt_pairwise') as mock_pair, \
122
+ patch('main.ThreadPoolExecutor', return_value=DummyExecutor()) as MockEx, \
123
+ patch('main.as_completed', new=lambda futs: futs), \
124
+ patch('main.tqdm', new=dummy_tqdm), \
125
+ patch('main.plt.figure', return_value='fig'), \
126
+ patch('main.plt.hist'):
127
+ mock_gen.return_value = (['p1', 'p2', 'p3'], {'prompt_tokens':1,'completion_tokens':1})
128
+ mock_pair.side_effect = lambda instr, block, a, b, **kw: (json.dumps({'winner':'A'}), {'prompt_tokens':1,'completion_tokens':1})
129
+
130
+ results = list(main.run_tournament(
131
+ api_base='b',
132
+ api_token='k',
133
+ generate_model='gm',
134
+ score_model='sm',
135
+ pairwise_model='pm',
136
+ instruction_input='instr',
137
+ criteria_input='c1,c2',
138
+ n_gen=3,
139
+ pool_size=3,
140
+ num_top_picks=1,
141
+ max_workers=1,
142
+ enable_score_filter=False,
143
+ enable_pairwise_filter=True,
144
+ ))
145
+
146
+ process_log, fig, top_picks, usage = results[-1]
147
+ assert 'Done' in process_log
148
+ assert top_picks.strip() in {'p1', 'p2', 'p3'}
149
+ assert mock_pair.call_count == 5