ping98k committed on
Commit
37e55ed
·
1 Parent(s): c43a3f3

Add token usage logging and completion output

Browse files
Files changed (3) hide show
  1. main.py +58 -24
  2. tests/test_main.py +6 -5
  3. tournament_utils.py +23 -7
main.py CHANGED
@@ -61,35 +61,67 @@ def run_tournament(
61
  process_log = []
62
  hist_fig = None
63
  top_picks_str = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def log(msg):
65
  process_log.append(msg)
66
  tqdm.write(msg)
67
- yield "\n".join(process_log), hist_fig, top_picks_str
68
  yield from log("Generating players …")
69
- all_players = generate_players(
70
  instruction,
71
  n_gen,
72
  model=generate_model,
73
  api_base=api_base,
74
  api_key=api_token,
 
75
  )
 
76
  yield from log(f"{len(all_players)} players generated")
 
 
 
 
 
77
  def criteria_block():
78
  return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
79
 
80
  if enable_score_filter:
81
  def score(player):
82
- data = _clean_json(
83
- prompt_score(
84
- instruction,
85
- criteria_list,
86
- criteria_block(),
87
- player,
88
- model=score_model,
89
- api_base=api_base,
90
- api_key=api_token,
91
- )
92
  )
 
 
93
  if "scores" in data and isinstance(data["scores"], list):
94
  vals = data["scores"]
95
  return sum(vals) / len(vals) if vals else 0.0
@@ -113,17 +145,18 @@ def run_tournament(
113
  top_players = all_players
114
  if enable_pairwise_filter:
115
  def play(a, b):
116
- winner_label = _clean_json(
117
- prompt_pairwise(
118
- instruction,
119
- criteria_block(),
120
- a,
121
- b,
122
- model=pairwise_model,
123
- api_base=api_base,
124
- api_key=api_token,
125
- )
126
- ).get("winner", "A")
 
127
  return a if winner_label == "A" else b
128
 
129
  def tournament_round(pairs, executor):
@@ -178,7 +211,7 @@ def run_tournament(
178
  else:
179
  top_k = top_players[:num_top_picks]
180
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
181
- yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str
182
 
183
  demo = gr.Interface(
184
  fn=run_tournament,
@@ -201,6 +234,7 @@ demo = gr.Interface(
201
  gr.Textbox(lines=10, label="Process"),
202
  gr.Plot(label="Score Distribution"),
203
  gr.Textbox(lines=50, label="Top picks"),
 
204
  ],
205
  description="Generate multiple completions and use score and pairwise filters to find the best answers.",
206
  )
 
61
  process_log = []
62
  hist_fig = None
63
  top_picks_str = ""
64
+ prompt_tokens = 0
65
+ completion_tokens = 0
66
+
67
+ def add_usage(usage):
68
+ nonlocal prompt_tokens, completion_tokens
69
+ if not usage:
70
+ return
71
+ pt = getattr(usage, "prompt_tokens", None)
72
+ if pt is None and isinstance(usage, dict):
73
+ pt = usage.get("prompt_tokens")
74
+ ct = getattr(usage, "completion_tokens", None)
75
+ if ct is None and isinstance(usage, dict):
76
+ ct = usage.get("completion_tokens")
77
+ if pt:
78
+ prompt_tokens += pt
79
+ if ct:
80
+ completion_tokens += ct
81
+
82
+ def usage_str():
83
+ return (
84
+ f"Prompt tokens: {prompt_tokens}\n"
85
+ f"Completion tokens: {completion_tokens}\n"
86
+ f"Total tokens: {prompt_tokens + completion_tokens}"
87
+ )
88
  def log(msg):
89
  process_log.append(msg)
90
  tqdm.write(msg)
91
+ yield "\n".join(process_log), hist_fig, top_picks_str, usage_str()
92
  yield from log("Generating players …")
93
+ all_players, usage = generate_players(
94
  instruction,
95
  n_gen,
96
  model=generate_model,
97
  api_base=api_base,
98
  api_key=api_token,
99
+ return_usage=True,
100
  )
101
+ add_usage(usage)
102
  yield from log(f"{len(all_players)} players generated")
103
+ for i, p in enumerate(all_players, 1):
104
+ disp = p.replace("\n", " ")
105
+ if len(disp) > 100:
106
+ disp = disp[:100] + "…"
107
+ yield from log(f"Completion {i}: {disp}")
108
  def criteria_block():
109
  return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
110
 
111
  if enable_score_filter:
112
  def score(player):
113
+ text, usage = prompt_score(
114
+ instruction,
115
+ criteria_list,
116
+ criteria_block(),
117
+ player,
118
+ model=score_model,
119
+ api_base=api_base,
120
+ api_key=api_token,
121
+ return_usage=True,
 
122
  )
123
+ add_usage(usage)
124
+ data = _clean_json(text)
125
  if "scores" in data and isinstance(data["scores"], list):
126
  vals = data["scores"]
127
  return sum(vals) / len(vals) if vals else 0.0
 
145
  top_players = all_players
146
  if enable_pairwise_filter:
147
  def play(a, b):
148
+ text, usage = prompt_pairwise(
149
+ instruction,
150
+ criteria_block(),
151
+ a,
152
+ b,
153
+ model=pairwise_model,
154
+ api_base=api_base,
155
+ api_key=api_token,
156
+ return_usage=True,
157
+ )
158
+ add_usage(usage)
159
+ winner_label = _clean_json(text).get("winner", "A")
160
  return a if winner_label == "A" else b
161
 
162
  def tournament_round(pairs, executor):
 
211
  else:
212
  top_k = top_players[:num_top_picks]
213
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
214
+ yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str, usage_str()
215
 
216
  demo = gr.Interface(
217
  fn=run_tournament,
 
234
  gr.Textbox(lines=10, label="Process"),
235
  gr.Plot(label="Score Distribution"),
236
  gr.Textbox(lines=50, label="Top picks"),
237
+ gr.Textbox(lines=5, label="Token Usage"),
238
  ],
239
  description="Generate multiple completions and use score and pairwise filters to find the best answers.",
240
  )
tests/test_main.py CHANGED
@@ -82,10 +82,10 @@ def test_run_tournament_full_loop():
82
  patch('main.tqdm', new=dummy_tqdm), \
83
  patch('main.plt.figure', return_value='fig'), \
84
  patch('main.plt.hist'):
85
- mock_gen.return_value = ['p1', 'p2', 'p3', 'p4']
86
  scores = {'p1':3, 'p2':2, 'p3':1, 'p4':0}
87
- mock_score.side_effect = lambda instr, cl, block, player, **kw: json.dumps({'score': scores[player]})
88
- mock_pair.side_effect = lambda instr, block, a, b, **kw: json.dumps({'winner': 'A'})
89
 
90
  results = list(main.run_tournament(
91
  api_base='b',
@@ -103,10 +103,11 @@ def test_run_tournament_full_loop():
103
  enable_pairwise_filter=True,
104
  ))
105
 
106
- process_log, hist_fig, top_picks = results[-1]
107
  assert 'Done' in process_log
108
  assert hist_fig == 'fig'
109
  assert top_picks.strip() in {'p1', 'p2'}
110
- mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k')
 
111
  assert mock_score.call_count == 4
112
  assert mock_pair.called
 
82
  patch('main.tqdm', new=dummy_tqdm), \
83
  patch('main.plt.figure', return_value='fig'), \
84
  patch('main.plt.hist'):
85
+ mock_gen.return_value = (['p1', 'p2', 'p3', 'p4'], {'prompt_tokens':1,'completion_tokens':1})
86
  scores = {'p1':3, 'p2':2, 'p3':1, 'p4':0}
87
+ mock_score.side_effect = lambda instr, cl, block, player, **kw: (json.dumps({'score': scores[player]}), {'prompt_tokens':1,'completion_tokens':1})
88
+ mock_pair.side_effect = lambda instr, block, a, b, **kw: (json.dumps({'winner': 'A'}), {'prompt_tokens':1,'completion_tokens':1})
89
 
90
  results = list(main.run_tournament(
91
  api_base='b',
 
103
  enable_pairwise_filter=True,
104
  ))
105
 
106
+ process_log, hist_fig, top_picks, usage = results[-1]
107
  assert 'Done' in process_log
108
  assert hist_fig == 'fig'
109
  assert top_picks.strip() in {'p1', 'p2'}
110
+ mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', return_usage=True)
111
+ assert 'Prompt tokens' in usage
112
  assert mock_score.call_count == 4
113
  assert mock_pair.called
tournament_utils.py CHANGED
@@ -18,15 +18,23 @@ def generate_players(
18
  *,
19
  api_base: str | None = None,
20
  api_key: str | None = None,
21
- ):
22
- """Request `n` completions for the instruction using the given model."""
 
 
 
 
 
23
  response = completion(
24
  model=model,
25
  messages=[{"role": "user", "content": instruction}],
26
  n=n,
27
  **_completion_kwargs(api_base, api_key),
28
  )
29
- return [c.message.content.strip() for c in response.choices]
 
 
 
30
 
31
 
32
  def prompt_score(
@@ -38,7 +46,8 @@ def prompt_score(
38
  *,
39
  api_base: str | None = None,
40
  api_key: str | None = None,
41
- ) -> str:
 
42
  """Return a JSON score string evaluating `player` on the criteria."""
43
  example_scores = ", ".join(["1-10"] * len(criteria_list)) or "1-10"
44
  prompt = f"""Evaluate the output below on the following criteria:
@@ -56,7 +65,10 @@ Output:
56
  messages=[{"role": "system", "content": prompt}],
57
  **_completion_kwargs(api_base, api_key),
58
  )
59
- return response.choices[0].message.content.strip()
 
 
 
60
 
61
 
62
  def prompt_pairwise(
@@ -68,7 +80,8 @@ def prompt_pairwise(
68
  *,
69
  api_base: str | None = None,
70
  api_key: str | None = None,
71
- ) -> str:
 
72
  """Return which player wins in JSON using the given criteria."""
73
  prompt = f"""Compare the two players below using:
74
  {criteria_block}
@@ -86,4 +99,7 @@ Players:
86
  messages=[{"role": "system", "content": prompt}],
87
  **_completion_kwargs(api_base, api_key),
88
  )
89
- return response.choices[0].message.content.strip()
 
 
 
 
18
  *,
19
  api_base: str | None = None,
20
  api_key: str | None = None,
21
+ return_usage: bool = False,
22
+ ) -> list[str] | tuple[list[str], object]:
23
+ """Request ``n`` completions for the instruction using the given model.
24
+
25
+ When ``return_usage`` is ``True`` the ``usage`` object from the completion
26
+ response is also returned.
27
+ """
28
  response = completion(
29
  model=model,
30
  messages=[{"role": "user", "content": instruction}],
31
  n=n,
32
  **_completion_kwargs(api_base, api_key),
33
  )
34
+ players = [c.message.content.strip() for c in response.choices]
35
+ if return_usage:
36
+ return players, getattr(response, "usage", None)
37
+ return players
38
 
39
 
40
  def prompt_score(
 
46
  *,
47
  api_base: str | None = None,
48
  api_key: str | None = None,
49
+ return_usage: bool = False,
50
+ ) -> str | tuple[str, object]:
51
  """Return a JSON score string evaluating `player` on the criteria."""
52
  example_scores = ", ".join(["1-10"] * len(criteria_list)) or "1-10"
53
  prompt = f"""Evaluate the output below on the following criteria:
 
65
  messages=[{"role": "system", "content": prompt}],
66
  **_completion_kwargs(api_base, api_key),
67
  )
68
+ text = response.choices[0].message.content.strip()
69
+ if return_usage:
70
+ return text, getattr(response, "usage", None)
71
+ return text
72
 
73
 
74
  def prompt_pairwise(
 
80
  *,
81
  api_base: str | None = None,
82
  api_key: str | None = None,
83
+ return_usage: bool = False,
84
+ ) -> str | tuple[str, object]:
85
  """Return which player wins in JSON using the given criteria."""
86
  prompt = f"""Compare the two players below using:
87
  {criteria_block}
 
99
  messages=[{"role": "system", "content": prompt}],
100
  **_completion_kwargs(api_base, api_key),
101
  )
102
+ text = response.choices[0].message.content.strip()
103
+ if return_usage:
104
+ return text, getattr(response, "usage", None)
105
+ return text