Spaces:
Sleeping
Sleeping
ping98k
commited on
Commit
·
e6b4ba0
1
Parent(s):
d95d0b3
Add Elo bar plot
Browse files- main.py +9 -3
- tests/test_main.py +9 -4
main.py
CHANGED
|
@@ -122,6 +122,7 @@ def run_tournament(
|
|
| 122 |
|
| 123 |
process_log = []
|
| 124 |
hist_fig = None
|
|
|
|
| 125 |
top_picks_str = ""
|
| 126 |
prompt_tokens = 0
|
| 127 |
completion_tokens = 0
|
|
@@ -161,7 +162,7 @@ def run_tournament(
|
|
| 161 |
def log(msg):
|
| 162 |
process_log.append(msg)
|
| 163 |
tqdm.write(msg)
|
| 164 |
-
yield "\n".join(process_log), hist_fig, top_picks_str, usage_str()
|
| 165 |
yield from log("Generating answers …")
|
| 166 |
all_players, usage = generate_players(
|
| 167 |
instruction,
|
|
@@ -279,7 +280,11 @@ def run_tournament(
|
|
| 279 |
yield from log("Pairwise generating")
|
| 280 |
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 281 |
rating = yield from rate(top_players, ex)
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
for i, txt in enumerate(pairwise_outputs, 1):
|
| 284 |
yield from log_completion(f"Pairwise completion {i}: ", txt)
|
| 285 |
top_picks_str = "\n\n\n=====================================================\n\n\n".join(
|
|
@@ -288,7 +293,7 @@ def run_tournament(
|
|
| 288 |
else:
|
| 289 |
top_k = top_players[:num_top_picks]
|
| 290 |
top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
|
| 291 |
-
yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str, usage_str()
|
| 292 |
|
| 293 |
demo = gr.Interface(
|
| 294 |
fn=run_tournament,
|
|
@@ -320,6 +325,7 @@ demo = gr.Interface(
|
|
| 320 |
outputs=[
|
| 321 |
gr.Textbox(lines=10, label="Process"),
|
| 322 |
gr.Plot(label="Score Distribution"),
|
|
|
|
| 323 |
gr.Textbox(lines=50, label="Top picks"),
|
| 324 |
gr.Textbox(lines=5, label="Token Usage"),
|
| 325 |
],
|
|
|
|
| 122 |
|
| 123 |
process_log = []
|
| 124 |
hist_fig = None
|
| 125 |
+
elo_fig = None
|
| 126 |
top_picks_str = ""
|
| 127 |
prompt_tokens = 0
|
| 128 |
completion_tokens = 0
|
|
|
|
| 162 |
def log(msg):
|
| 163 |
process_log.append(msg)
|
| 164 |
tqdm.write(msg)
|
| 165 |
+
yield "\n".join(process_log), hist_fig, elo_fig, top_picks_str, usage_str()
|
| 166 |
yield from log("Generating answers …")
|
| 167 |
all_players, usage = generate_players(
|
| 168 |
instruction,
|
|
|
|
| 280 |
yield from log("Pairwise generating")
|
| 281 |
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 282 |
rating = yield from rate(top_players, ex)
|
| 283 |
+
elo_fig = plt.figure()
|
| 284 |
+
players_sorted = sorted(rating, key=rating.get, reverse=True)
|
| 285 |
+
plt.bar(range(len(players_sorted)), [rating[p] for p in players_sorted])
|
| 286 |
+
plt.xticks(range(len(players_sorted)), [str(i + 1) for i in range(len(players_sorted))])
|
| 287 |
+
top_k = players_sorted[:num_top_picks]
|
| 288 |
for i, txt in enumerate(pairwise_outputs, 1):
|
| 289 |
yield from log_completion(f"Pairwise completion {i}: ", txt)
|
| 290 |
top_picks_str = "\n\n\n=====================================================\n\n\n".join(
|
|
|
|
| 293 |
else:
|
| 294 |
top_k = top_players[:num_top_picks]
|
| 295 |
top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
|
| 296 |
+
yield "\n".join(process_log + ["Done"]), hist_fig, elo_fig, top_picks_str, usage_str()
|
| 297 |
|
| 298 |
demo = gr.Interface(
|
| 299 |
fn=run_tournament,
|
|
|
|
| 325 |
outputs=[
|
| 326 |
gr.Textbox(lines=10, label="Process"),
|
| 327 |
gr.Plot(label="Score Distribution"),
|
| 328 |
+
gr.Plot(label="Elo Ratings"),
|
| 329 |
gr.Textbox(lines=50, label="Top picks"),
|
| 330 |
gr.Textbox(lines=5, label="Token Usage"),
|
| 331 |
],
|
tests/test_main.py
CHANGED
|
@@ -39,6 +39,8 @@ sys.modules.setdefault('tqdm', fake_tqdm_mod)
|
|
| 39 |
fake_plt = types.ModuleType('matplotlib.pyplot')
|
| 40 |
fake_plt.figure = MagicMock(return_value='fig')
|
| 41 |
fake_plt.hist = MagicMock()
|
|
|
|
|
|
|
| 42 |
fake_matplotlib = types.ModuleType('matplotlib')
|
| 43 |
fake_matplotlib.pyplot = fake_plt
|
| 44 |
sys.modules.setdefault('matplotlib', fake_matplotlib)
|
|
@@ -81,7 +83,8 @@ def test_run_tournament_full_loop():
|
|
| 81 |
patch('main.as_completed', new=lambda futs: futs), \
|
| 82 |
patch('main.tqdm', new=dummy_tqdm), \
|
| 83 |
patch('main.plt.figure', return_value='fig'), \
|
| 84 |
-
patch('main.plt.hist')
|
|
|
|
| 85 |
mock_gen.return_value = (['p1', 'p2', 'p3', 'p4'], {'prompt_tokens':1,'completion_tokens':1})
|
| 86 |
scores = {'p1':3, 'p2':2, 'p3':1, 'p4':0}
|
| 87 |
mock_score.side_effect = lambda instr, cl, block, player, **kw: (json.dumps({'score': scores[player]}), {'prompt_tokens':1,'completion_tokens':1})
|
|
@@ -111,9 +114,10 @@ def test_run_tournament_full_loop():
|
|
| 111 |
pairwise_thinking=True,
|
| 112 |
))
|
| 113 |
|
| 114 |
-
process_log, hist_fig, top_picks, usage = results[-1]
|
| 115 |
assert 'Done' in process_log
|
| 116 |
assert hist_fig == 'fig'
|
|
|
|
| 117 |
assert any(p in top_picks for p in {'p1', 'p2'})
|
| 118 |
mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, thinking=True, return_usage=True)
|
| 119 |
assert 'Score completion' in process_log
|
|
@@ -131,7 +135,8 @@ def test_run_tournament_pairwise_odd_players():
|
|
| 131 |
patch('main.as_completed', new=lambda futs: futs), \
|
| 132 |
patch('main.tqdm', new=dummy_tqdm), \
|
| 133 |
patch('main.plt.figure', return_value='fig'), \
|
| 134 |
-
patch('main.plt.hist')
|
|
|
|
| 135 |
mock_gen.return_value = (['p1', 'p2', 'p3'], {'prompt_tokens':1,'completion_tokens':1})
|
| 136 |
mock_pair.side_effect = lambda instr, block, a, b, **kw: (json.dumps({'winner':'A'}), {'prompt_tokens':1,'completion_tokens':1})
|
| 137 |
|
|
@@ -159,7 +164,7 @@ def test_run_tournament_pairwise_odd_players():
|
|
| 159 |
pairwise_thinking=True,
|
| 160 |
))
|
| 161 |
|
| 162 |
-
process_log,
|
| 163 |
assert 'Done' in process_log
|
| 164 |
assert any(p in top_picks for p in {'p1', 'p2', 'p3'})
|
| 165 |
assert mock_pair.call_count == 3
|
|
|
|
| 39 |
fake_plt = types.ModuleType('matplotlib.pyplot')
|
| 40 |
fake_plt.figure = MagicMock(return_value='fig')
|
| 41 |
fake_plt.hist = MagicMock()
|
| 42 |
+
fake_plt.bar = MagicMock()
|
| 43 |
+
fake_plt.xticks = MagicMock()
|
| 44 |
fake_matplotlib = types.ModuleType('matplotlib')
|
| 45 |
fake_matplotlib.pyplot = fake_plt
|
| 46 |
sys.modules.setdefault('matplotlib', fake_matplotlib)
|
|
|
|
| 83 |
patch('main.as_completed', new=lambda futs: futs), \
|
| 84 |
patch('main.tqdm', new=dummy_tqdm), \
|
| 85 |
patch('main.plt.figure', return_value='fig'), \
|
| 86 |
+
patch('main.plt.hist'), \
|
| 87 |
+
patch('main.plt.bar'):
|
| 88 |
mock_gen.return_value = (['p1', 'p2', 'p3', 'p4'], {'prompt_tokens':1,'completion_tokens':1})
|
| 89 |
scores = {'p1':3, 'p2':2, 'p3':1, 'p4':0}
|
| 90 |
mock_score.side_effect = lambda instr, cl, block, player, **kw: (json.dumps({'score': scores[player]}), {'prompt_tokens':1,'completion_tokens':1})
|
|
|
|
| 114 |
pairwise_thinking=True,
|
| 115 |
))
|
| 116 |
|
| 117 |
+
process_log, hist_fig, elo_fig, top_picks, usage = results[-1]
|
| 118 |
assert 'Done' in process_log
|
| 119 |
assert hist_fig == 'fig'
|
| 120 |
+
assert elo_fig == 'fig'
|
| 121 |
assert any(p in top_picks for p in {'p1', 'p2'})
|
| 122 |
mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, thinking=True, return_usage=True)
|
| 123 |
assert 'Score completion' in process_log
|
|
|
|
| 135 |
patch('main.as_completed', new=lambda futs: futs), \
|
| 136 |
patch('main.tqdm', new=dummy_tqdm), \
|
| 137 |
patch('main.plt.figure', return_value='fig'), \
|
| 138 |
+
patch('main.plt.hist'), \
|
| 139 |
+
patch('main.plt.bar'):
|
| 140 |
mock_gen.return_value = (['p1', 'p2', 'p3'], {'prompt_tokens':1,'completion_tokens':1})
|
| 141 |
mock_pair.side_effect = lambda instr, block, a, b, **kw: (json.dumps({'winner':'A'}), {'prompt_tokens':1,'completion_tokens':1})
|
| 142 |
|
|
|
|
| 164 |
pairwise_thinking=True,
|
| 165 |
))
|
| 166 |
|
| 167 |
+
process_log, hist_fig, elo_fig, top_picks, usage = results[-1]
|
| 168 |
assert 'Done' in process_log
|
| 169 |
assert any(p in top_picks for p in {'p1', 'p2', 'p3'})
|
| 170 |
assert mock_pair.call_count == 3
|