ping98k commited on
Commit
e6b4ba0
·
1 Parent(s): d95d0b3

Add Elo bar plot

Browse files
Files changed (2) hide show
  1. main.py +9 -3
  2. tests/test_main.py +9 -4
main.py CHANGED
@@ -122,6 +122,7 @@ def run_tournament(
122
 
123
  process_log = []
124
  hist_fig = None
 
125
  top_picks_str = ""
126
  prompt_tokens = 0
127
  completion_tokens = 0
@@ -161,7 +162,7 @@ def run_tournament(
161
  def log(msg):
162
  process_log.append(msg)
163
  tqdm.write(msg)
164
- yield "\n".join(process_log), hist_fig, top_picks_str, usage_str()
165
  yield from log("Generating answers …")
166
  all_players, usage = generate_players(
167
  instruction,
@@ -279,7 +280,11 @@ def run_tournament(
279
  yield from log("Pairwise generating")
280
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
281
  rating = yield from rate(top_players, ex)
282
- top_k = sorted(top_players, key=rating.get, reverse=True)[:num_top_picks]
 
 
 
 
283
  for i, txt in enumerate(pairwise_outputs, 1):
284
  yield from log_completion(f"Pairwise completion {i}: ", txt)
285
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(
@@ -288,7 +293,7 @@ def run_tournament(
288
  else:
289
  top_k = top_players[:num_top_picks]
290
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
291
- yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str, usage_str()
292
 
293
  demo = gr.Interface(
294
  fn=run_tournament,
@@ -320,6 +325,7 @@ demo = gr.Interface(
320
  outputs=[
321
  gr.Textbox(lines=10, label="Process"),
322
  gr.Plot(label="Score Distribution"),
 
323
  gr.Textbox(lines=50, label="Top picks"),
324
  gr.Textbox(lines=5, label="Token Usage"),
325
  ],
 
122
 
123
  process_log = []
124
  hist_fig = None
125
+ elo_fig = None
126
  top_picks_str = ""
127
  prompt_tokens = 0
128
  completion_tokens = 0
 
162
  def log(msg):
163
  process_log.append(msg)
164
  tqdm.write(msg)
165
+ yield "\n".join(process_log), hist_fig, elo_fig, top_picks_str, usage_str()
166
  yield from log("Generating answers …")
167
  all_players, usage = generate_players(
168
  instruction,
 
280
  yield from log("Pairwise generating")
281
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
282
  rating = yield from rate(top_players, ex)
283
+ elo_fig = plt.figure()
284
+ players_sorted = sorted(rating, key=rating.get, reverse=True)
285
+ plt.bar(range(len(players_sorted)), [rating[p] for p in players_sorted])
286
+ plt.xticks(range(len(players_sorted)), [str(i + 1) for i in range(len(players_sorted))])
287
+ top_k = players_sorted[:num_top_picks]
288
  for i, txt in enumerate(pairwise_outputs, 1):
289
  yield from log_completion(f"Pairwise completion {i}: ", txt)
290
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(
 
293
  else:
294
  top_k = top_players[:num_top_picks]
295
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
296
+ yield "\n".join(process_log + ["Done"]), hist_fig, elo_fig, top_picks_str, usage_str()
297
 
298
  demo = gr.Interface(
299
  fn=run_tournament,
 
325
  outputs=[
326
  gr.Textbox(lines=10, label="Process"),
327
  gr.Plot(label="Score Distribution"),
328
+ gr.Plot(label="Elo Ratings"),
329
  gr.Textbox(lines=50, label="Top picks"),
330
  gr.Textbox(lines=5, label="Token Usage"),
331
  ],
tests/test_main.py CHANGED
@@ -39,6 +39,8 @@ sys.modules.setdefault('tqdm', fake_tqdm_mod)
39
  fake_plt = types.ModuleType('matplotlib.pyplot')
40
  fake_plt.figure = MagicMock(return_value='fig')
41
  fake_plt.hist = MagicMock()
 
 
42
  fake_matplotlib = types.ModuleType('matplotlib')
43
  fake_matplotlib.pyplot = fake_plt
44
  sys.modules.setdefault('matplotlib', fake_matplotlib)
@@ -81,7 +83,8 @@ def test_run_tournament_full_loop():
81
  patch('main.as_completed', new=lambda futs: futs), \
82
  patch('main.tqdm', new=dummy_tqdm), \
83
  patch('main.plt.figure', return_value='fig'), \
84
- patch('main.plt.hist'):
 
85
  mock_gen.return_value = (['p1', 'p2', 'p3', 'p4'], {'prompt_tokens':1,'completion_tokens':1})
86
  scores = {'p1':3, 'p2':2, 'p3':1, 'p4':0}
87
  mock_score.side_effect = lambda instr, cl, block, player, **kw: (json.dumps({'score': scores[player]}), {'prompt_tokens':1,'completion_tokens':1})
@@ -111,9 +114,10 @@ def test_run_tournament_full_loop():
111
  pairwise_thinking=True,
112
  ))
113
 
114
- process_log, hist_fig, top_picks, usage = results[-1]
115
  assert 'Done' in process_log
116
  assert hist_fig == 'fig'
 
117
  assert any(p in top_picks for p in {'p1', 'p2'})
118
  mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, thinking=True, return_usage=True)
119
  assert 'Score completion' in process_log
@@ -131,7 +135,8 @@ def test_run_tournament_pairwise_odd_players():
131
  patch('main.as_completed', new=lambda futs: futs), \
132
  patch('main.tqdm', new=dummy_tqdm), \
133
  patch('main.plt.figure', return_value='fig'), \
134
- patch('main.plt.hist'):
 
135
  mock_gen.return_value = (['p1', 'p2', 'p3'], {'prompt_tokens':1,'completion_tokens':1})
136
  mock_pair.side_effect = lambda instr, block, a, b, **kw: (json.dumps({'winner':'A'}), {'prompt_tokens':1,'completion_tokens':1})
137
 
@@ -159,7 +164,7 @@ def test_run_tournament_pairwise_odd_players():
159
  pairwise_thinking=True,
160
  ))
161
 
162
- process_log, fig, top_picks, usage = results[-1]
163
  assert 'Done' in process_log
164
  assert any(p in top_picks for p in {'p1', 'p2', 'p3'})
165
  assert mock_pair.call_count == 3
 
39
  fake_plt = types.ModuleType('matplotlib.pyplot')
40
  fake_plt.figure = MagicMock(return_value='fig')
41
  fake_plt.hist = MagicMock()
42
+ fake_plt.bar = MagicMock()
43
+ fake_plt.xticks = MagicMock()
44
  fake_matplotlib = types.ModuleType('matplotlib')
45
  fake_matplotlib.pyplot = fake_plt
46
  sys.modules.setdefault('matplotlib', fake_matplotlib)
 
83
  patch('main.as_completed', new=lambda futs: futs), \
84
  patch('main.tqdm', new=dummy_tqdm), \
85
  patch('main.plt.figure', return_value='fig'), \
86
+ patch('main.plt.hist'), \
87
+ patch('main.plt.bar'):
88
  mock_gen.return_value = (['p1', 'p2', 'p3', 'p4'], {'prompt_tokens':1,'completion_tokens':1})
89
  scores = {'p1':3, 'p2':2, 'p3':1, 'p4':0}
90
  mock_score.side_effect = lambda instr, cl, block, player, **kw: (json.dumps({'score': scores[player]}), {'prompt_tokens':1,'completion_tokens':1})
 
114
  pairwise_thinking=True,
115
  ))
116
 
117
+ process_log, hist_fig, elo_fig, top_picks, usage = results[-1]
118
  assert 'Done' in process_log
119
  assert hist_fig == 'fig'
120
+ assert elo_fig == 'fig'
121
  assert any(p in top_picks for p in {'p1', 'p2'})
122
  mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, thinking=True, return_usage=True)
123
  assert 'Score completion' in process_log
 
135
  patch('main.as_completed', new=lambda futs: futs), \
136
  patch('main.tqdm', new=dummy_tqdm), \
137
  patch('main.plt.figure', return_value='fig'), \
138
+ patch('main.plt.hist'), \
139
+ patch('main.plt.bar'):
140
  mock_gen.return_value = (['p1', 'p2', 'p3'], {'prompt_tokens':1,'completion_tokens':1})
141
  mock_pair.side_effect = lambda instr, block, a, b, **kw: (json.dumps({'winner':'A'}), {'prompt_tokens':1,'completion_tokens':1})
142
 
 
164
  pairwise_thinking=True,
165
  ))
166
 
167
+ process_log, hist_fig, elo_fig, top_picks, usage = results[-1]
168
  assert 'Done' in process_log
169
  assert any(p in top_picks for p in {'p1', 'p2', 'p3'})
170
  assert mock_pair.call_count == 3