ping98k committed on
Commit
3bd9ad6
·
1 Parent(s): c43a3f3

Log scoring and pairwise completions

Browse files
Files changed (3) hide show
  1. main.py +70 -27
  2. tests/test_main.py +8 -5
  3. tournament_utils.py +23 -7
main.py CHANGED
@@ -61,47 +61,84 @@ def run_tournament(
61
  process_log = []
62
  hist_fig = None
63
  top_picks_str = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def log(msg):
65
  process_log.append(msg)
66
  tqdm.write(msg)
67
- yield "\n".join(process_log), hist_fig, top_picks_str
68
  yield from log("Generating players …")
69
- all_players = generate_players(
70
  instruction,
71
  n_gen,
72
  model=generate_model,
73
  api_base=api_base,
74
  api_key=api_token,
 
75
  )
 
76
  yield from log(f"{len(all_players)} players generated")
 
 
77
  def criteria_block():
78
  return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
79
 
80
  if enable_score_filter:
81
  def score(player):
82
- data = _clean_json(
83
- prompt_score(
84
- instruction,
85
- criteria_list,
86
- criteria_block(),
87
- player,
88
- model=score_model,
89
- api_base=api_base,
90
- api_key=api_token,
91
- )
92
  )
 
 
 
93
  if "scores" in data and isinstance(data["scores"], list):
94
  vals = data["scores"]
95
  return sum(vals) / len(vals) if vals else 0.0
96
  return float(data.get("score", 0))
97
 
98
- yield from log("Scoring players …")
99
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
100
  scores = {
101
  p: s
102
  for p, s in zip(
103
  all_players,
104
- list(tqdm(ex.map(score, all_players), total=len(all_players))),
105
  )
106
  }
107
  hist_fig = plt.figure()
@@ -109,21 +146,25 @@ def run_tournament(
109
  yield from log("Histogram generated")
110
  top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
111
  yield from log(f"Filtered to {len(top_players)} players with best scores")
 
 
112
  else:
113
  top_players = all_players
114
  if enable_pairwise_filter:
115
  def play(a, b):
116
- winner_label = _clean_json(
117
- prompt_pairwise(
118
- instruction,
119
- criteria_block(),
120
- a,
121
- b,
122
- model=pairwise_model,
123
- api_base=api_base,
124
- api_key=api_token,
125
- )
126
- ).get("winner", "A")
 
 
127
  return a if winner_label == "A" else b
128
 
129
  def tournament_round(pairs, executor):
@@ -172,13 +213,14 @@ def run_tournament(
172
  candidates = list(set(finalists + semifinalists + get_candidates(champion, lost_to)))
173
  return playoff(candidates, executor)[:num_top_picks]
174
 
175
- yield from log("Running tournament …")
176
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
177
  top_k = get_top(top_players, ex)
 
 
178
  else:
179
  top_k = top_players[:num_top_picks]
180
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
181
- yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str
182
 
183
  demo = gr.Interface(
184
  fn=run_tournament,
@@ -201,6 +243,7 @@ demo = gr.Interface(
201
  gr.Textbox(lines=10, label="Process"),
202
  gr.Plot(label="Score Distribution"),
203
  gr.Textbox(lines=50, label="Top picks"),
 
204
  ],
205
  description="Generate multiple completions and use score and pairwise filters to find the best answers.",
206
  )
 
61
  process_log = []
62
  hist_fig = None
63
  top_picks_str = ""
64
+ prompt_tokens = 0
65
+ completion_tokens = 0
66
+ score_outputs: list[str] = []
67
+ pairwise_outputs: list[str] = []
68
+
69
+ def add_usage(usage):
70
+ nonlocal prompt_tokens, completion_tokens
71
+ if not usage:
72
+ return
73
+ pt = getattr(usage, "prompt_tokens", None)
74
+ if pt is None and isinstance(usage, dict):
75
+ pt = usage.get("prompt_tokens")
76
+ ct = getattr(usage, "completion_tokens", None)
77
+ if ct is None and isinstance(usage, dict):
78
+ ct = usage.get("completion_tokens")
79
+ if pt:
80
+ prompt_tokens += pt
81
+ if ct:
82
+ completion_tokens += ct
83
+
84
+ def usage_str():
85
+ return (
86
+ f"Prompt tokens: {prompt_tokens}\n"
87
+ f"Completion tokens: {completion_tokens}\n"
88
+ f"Total tokens: {prompt_tokens + completion_tokens}"
89
+ )
90
+
91
+ def log_completion(prefix: str, text: str):
92
+ disp = text.replace("\n", " ")
93
+ if len(disp) > 100:
94
+ disp = disp[:100] + "…"
95
+ return log(f"{prefix}{disp}")
96
  def log(msg):
97
  process_log.append(msg)
98
  tqdm.write(msg)
99
+ yield "\n".join(process_log), hist_fig, top_picks_str, usage_str()
100
  yield from log("Generating players …")
101
+ all_players, usage = generate_players(
102
  instruction,
103
  n_gen,
104
  model=generate_model,
105
  api_base=api_base,
106
  api_key=api_token,
107
+ return_usage=True,
108
  )
109
+ add_usage(usage)
110
  yield from log(f"{len(all_players)} players generated")
111
+ for i, p in enumerate(all_players, 1):
112
+ yield from log_completion(f"Completion {i}: ", p)
113
  def criteria_block():
114
  return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
115
 
116
  if enable_score_filter:
117
  def score(player):
118
+ text, usage = prompt_score(
119
+ instruction,
120
+ criteria_list,
121
+ criteria_block(),
122
+ player,
123
+ model=score_model,
124
+ api_base=api_base,
125
+ api_key=api_token,
126
+ return_usage=True,
 
127
  )
128
+ add_usage(usage)
129
+ score_outputs.append(text)
130
+ data = _clean_json(text)
131
  if "scores" in data and isinstance(data["scores"], list):
132
  vals = data["scores"]
133
  return sum(vals) / len(vals) if vals else 0.0
134
  return float(data.get("score", 0))
135
 
 
136
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
137
  scores = {
138
  p: s
139
  for p, s in zip(
140
  all_players,
141
+ tqdm(ex.map(score, all_players), total=len(all_players)),
142
  )
143
  }
144
  hist_fig = plt.figure()
 
146
  yield from log("Histogram generated")
147
  top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
148
  yield from log(f"Filtered to {len(top_players)} players with best scores")
149
+ for i, txt in enumerate(score_outputs, 1):
150
+ yield from log_completion(f"Score completion {i}: ", txt)
151
  else:
152
  top_players = all_players
153
  if enable_pairwise_filter:
154
  def play(a, b):
155
+ text, usage = prompt_pairwise(
156
+ instruction,
157
+ criteria_block(),
158
+ a,
159
+ b,
160
+ model=pairwise_model,
161
+ api_base=api_base,
162
+ api_key=api_token,
163
+ return_usage=True,
164
+ )
165
+ add_usage(usage)
166
+ pairwise_outputs.append(text)
167
+ winner_label = _clean_json(text).get("winner", "A")
168
  return a if winner_label == "A" else b
169
 
170
  def tournament_round(pairs, executor):
 
213
  candidates = list(set(finalists + semifinalists + get_candidates(champion, lost_to)))
214
  return playoff(candidates, executor)[:num_top_picks]
215
 
 
216
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
217
  top_k = get_top(top_players, ex)
218
+ for i, txt in enumerate(pairwise_outputs, 1):
219
+ yield from log_completion(f"Pairwise completion {i}: ", txt)
220
  else:
221
  top_k = top_players[:num_top_picks]
222
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
223
+ yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str, usage_str()
224
 
225
  demo = gr.Interface(
226
  fn=run_tournament,
 
243
  gr.Textbox(lines=10, label="Process"),
244
  gr.Plot(label="Score Distribution"),
245
  gr.Textbox(lines=50, label="Top picks"),
246
+ gr.Textbox(lines=5, label="Token Usage"),
247
  ],
248
  description="Generate multiple completions and use score and pairwise filters to find the best answers.",
249
  )
tests/test_main.py CHANGED
@@ -82,10 +82,10 @@ def test_run_tournament_full_loop():
82
  patch('main.tqdm', new=dummy_tqdm), \
83
  patch('main.plt.figure', return_value='fig'), \
84
  patch('main.plt.hist'):
85
- mock_gen.return_value = ['p1', 'p2', 'p3', 'p4']
86
  scores = {'p1':3, 'p2':2, 'p3':1, 'p4':0}
87
- mock_score.side_effect = lambda instr, cl, block, player, **kw: json.dumps({'score': scores[player]})
88
- mock_pair.side_effect = lambda instr, block, a, b, **kw: json.dumps({'winner': 'A'})
89
 
90
  results = list(main.run_tournament(
91
  api_base='b',
@@ -103,10 +103,13 @@ def test_run_tournament_full_loop():
103
  enable_pairwise_filter=True,
104
  ))
105
 
106
- process_log, hist_fig, top_picks = results[-1]
107
  assert 'Done' in process_log
108
  assert hist_fig == 'fig'
109
  assert top_picks.strip() in {'p1', 'p2'}
110
- mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k')
 
 
 
111
  assert mock_score.call_count == 4
112
  assert mock_pair.called
 
82
  patch('main.tqdm', new=dummy_tqdm), \
83
  patch('main.plt.figure', return_value='fig'), \
84
  patch('main.plt.hist'):
85
+ mock_gen.return_value = (['p1', 'p2', 'p3', 'p4'], {'prompt_tokens':1,'completion_tokens':1})
86
  scores = {'p1':3, 'p2':2, 'p3':1, 'p4':0}
87
+ mock_score.side_effect = lambda instr, cl, block, player, **kw: (json.dumps({'score': scores[player]}), {'prompt_tokens':1,'completion_tokens':1})
88
+ mock_pair.side_effect = lambda instr, block, a, b, **kw: (json.dumps({'winner': 'A'}), {'prompt_tokens':1,'completion_tokens':1})
89
 
90
  results = list(main.run_tournament(
91
  api_base='b',
 
103
  enable_pairwise_filter=True,
104
  ))
105
 
106
+ process_log, hist_fig, top_picks, usage = results[-1]
107
  assert 'Done' in process_log
108
  assert hist_fig == 'fig'
109
  assert top_picks.strip() in {'p1', 'p2'}
110
+ mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', return_usage=True)
111
+ assert 'Score completion' in process_log
112
+ assert 'Pairwise completion' in process_log
113
+ assert 'Prompt tokens' in usage
114
  assert mock_score.call_count == 4
115
  assert mock_pair.called
tournament_utils.py CHANGED
@@ -18,15 +18,23 @@ def generate_players(
18
  *,
19
  api_base: str | None = None,
20
  api_key: str | None = None,
21
- ):
22
- """Request `n` completions for the instruction using the given model."""
 
 
 
 
 
23
  response = completion(
24
  model=model,
25
  messages=[{"role": "user", "content": instruction}],
26
  n=n,
27
  **_completion_kwargs(api_base, api_key),
28
  )
29
- return [c.message.content.strip() for c in response.choices]
 
 
 
30
 
31
 
32
  def prompt_score(
@@ -38,7 +46,8 @@ def prompt_score(
38
  *,
39
  api_base: str | None = None,
40
  api_key: str | None = None,
41
- ) -> str:
 
42
  """Return a JSON score string evaluating `player` on the criteria."""
43
  example_scores = ", ".join(["1-10"] * len(criteria_list)) or "1-10"
44
  prompt = f"""Evaluate the output below on the following criteria:
@@ -56,7 +65,10 @@ Output:
56
  messages=[{"role": "system", "content": prompt}],
57
  **_completion_kwargs(api_base, api_key),
58
  )
59
- return response.choices[0].message.content.strip()
 
 
 
60
 
61
 
62
  def prompt_pairwise(
@@ -68,7 +80,8 @@ def prompt_pairwise(
68
  *,
69
  api_base: str | None = None,
70
  api_key: str | None = None,
71
- ) -> str:
 
72
  """Return which player wins in JSON using the given criteria."""
73
  prompt = f"""Compare the two players below using:
74
  {criteria_block}
@@ -86,4 +99,7 @@ Players:
86
  messages=[{"role": "system", "content": prompt}],
87
  **_completion_kwargs(api_base, api_key),
88
  )
89
- return response.choices[0].message.content.strip()
 
 
 
 
18
  *,
19
  api_base: str | None = None,
20
  api_key: str | None = None,
21
+ return_usage: bool = False,
22
+ ) -> list[str] | tuple[list[str], object]:
23
+ """Request ``n`` completions for the instruction using the given model.
24
+
25
+ When ``return_usage`` is ``True`` the ``usage`` object from the completion
26
+ response is also returned.
27
+ """
28
  response = completion(
29
  model=model,
30
  messages=[{"role": "user", "content": instruction}],
31
  n=n,
32
  **_completion_kwargs(api_base, api_key),
33
  )
34
+ players = [c.message.content.strip() for c in response.choices]
35
+ if return_usage:
36
+ return players, getattr(response, "usage", None)
37
+ return players
38
 
39
 
40
  def prompt_score(
 
46
  *,
47
  api_base: str | None = None,
48
  api_key: str | None = None,
49
+ return_usage: bool = False,
50
+ ) -> str | tuple[str, object]:
51
  """Return a JSON score string evaluating `player` on the criteria."""
52
  example_scores = ", ".join(["1-10"] * len(criteria_list)) or "1-10"
53
  prompt = f"""Evaluate the output below on the following criteria:
 
65
  messages=[{"role": "system", "content": prompt}],
66
  **_completion_kwargs(api_base, api_key),
67
  )
68
+ text = response.choices[0].message.content.strip()
69
+ if return_usage:
70
+ return text, getattr(response, "usage", None)
71
+ return text
72
 
73
 
74
  def prompt_pairwise(
 
80
  *,
81
  api_base: str | None = None,
82
  api_key: str | None = None,
83
+ return_usage: bool = False,
84
+ ) -> str | tuple[str, object]:
85
  """Return which player wins in JSON using the given criteria."""
86
  prompt = f"""Compare the two players below using:
87
  {criteria_block}
 
99
  messages=[{"role": "system", "content": prompt}],
100
  **_completion_kwargs(api_base, api_key),
101
  )
102
+ text = response.choices[0].message.content.strip()
103
+ if return_usage:
104
+ return text, getattr(response, "usage", None)
105
+ return text