ping98k committed on
Commit
3404ee0
·
1 Parent(s): 348dd93

Add configurable thinking budget

Browse files
README.md CHANGED
@@ -21,6 +21,15 @@ This project provides a small interface for running "tournaments" between langua
21
  - `PASS_INSTRUCTION_TO_PAIRWISE`
22
  - `ENABLE_SCORE_FILTER`
23
  - `ENABLE_PAIRWISE_FILTER`
 
 
 
 
 
 
 
 
 
24
  2. Install dependencies (example with `pip`):
25
  ```bash
26
  pip install gradio litellm python-dotenv tqdm matplotlib
 
21
  - `PASS_INSTRUCTION_TO_PAIRWISE`
22
  - `ENABLE_SCORE_FILTER`
23
  - `ENABLE_PAIRWISE_FILTER`
24
+ - `ENABLE_GENERATE_THINKING`
25
+ - `ENABLE_SCORE_THINKING`
26
+ - `ENABLE_PAIRWISE_THINKING`
27
+ - `THINKING_BUDGET_TOKENS`
28
+
29
+ When one of the thinking flags is enabled, the app sends
30
+ `thinking={"type": "enabled", "budget_tokens": $THINKING_BUDGET_TOKENS}` with each
31
+ `litellm.completion` call for the corresponding model (generate, score, or pairwise). Otherwise it sends
32
+ `thinking={"type": "disabled", "budget_tokens": 0}`.
33
  2. Install dependencies (example with `pip`):
34
  ```bash
35
  pip install gradio litellm python-dotenv tqdm matplotlib
main.py CHANGED
@@ -45,6 +45,10 @@ SCORE_TEMPERATURE_DEFAULT = float(os.getenv("SCORE_TEMPERATURE", "0.6"))
45
  PAIRWISE_TEMPERATURE_DEFAULT = float(os.getenv("PAIRWISE_TEMPERATURE", "0.6"))
46
  SCORE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_SCORE", "true").lower() == "true"
47
  PAIRWISE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_PAIRWISE", "true").lower() == "true"
 
 
 
 
48
  CRITERIA_DEFAULT = "Factuality,Instruction Following,Precision"
49
  def _clean_json(txt):
50
  txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
@@ -72,6 +76,9 @@ def run_tournament(
72
  enable_pairwise_filter,
73
  score_with_instruction,
74
  pairwise_with_instruction,
 
 
 
75
  ):
76
  instruction = instruction_input.strip()
77
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
@@ -101,6 +108,12 @@ def run_tournament(
101
  score_with_instruction = SCORE_WITH_INSTRUCTION_DEFAULT
102
  if pairwise_with_instruction is None:
103
  pairwise_with_instruction = PAIRWISE_WITH_INSTRUCTION_DEFAULT
 
 
 
 
 
 
104
  process_log = []
105
  hist_fig = None
106
  top_picks_str = ""
@@ -150,6 +163,8 @@ def run_tournament(
150
  api_base=api_base,
151
  api_key=api_token,
152
  temperature=generate_temperature,
 
 
153
  return_usage=True,
154
  )
155
  add_usage(usage)
@@ -174,6 +189,8 @@ def run_tournament(
174
  api_key=api_token,
175
  temperature=score_temperature,
176
  include_instruction=score_with_instruction,
 
 
177
  return_usage=True,
178
  )
179
  add_usage(usage)
@@ -212,6 +229,8 @@ def run_tournament(
212
  api_key=api_token,
213
  temperature=pairwise_temperature,
214
  include_instruction=pairwise_with_instruction,
 
 
215
  return_usage=True,
216
  )
217
  add_usage(usage)
@@ -302,6 +321,9 @@ demo = gr.Interface(
302
  gr.Checkbox(value=PAIRWISE_FILTER_DEFAULT, label="Enable Pairwise Filter"),
303
  gr.Checkbox(value=SCORE_WITH_INSTRUCTION_DEFAULT, label="Pass Instruction to Score Model"),
304
  gr.Checkbox(value=PAIRWISE_WITH_INSTRUCTION_DEFAULT, label="Pass Instruction to Pairwise Model"),
 
 
 
305
  ],
306
  outputs=[
307
  gr.Textbox(lines=10, label="Process"),
 
45
  PAIRWISE_TEMPERATURE_DEFAULT = float(os.getenv("PAIRWISE_TEMPERATURE", "0.6"))
46
  SCORE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_SCORE", "true").lower() == "true"
47
  PAIRWISE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_PAIRWISE", "true").lower() == "true"
48
+ GENERATE_THINKING_DEFAULT = os.getenv("ENABLE_GENERATE_THINKING", "false").lower() == "true"
49
+ SCORE_THINKING_DEFAULT = os.getenv("ENABLE_SCORE_THINKING", "false").lower() == "true"
50
+ PAIRWISE_THINKING_DEFAULT = os.getenv("ENABLE_PAIRWISE_THINKING", "false").lower() == "true"
51
+ THINKING_BUDGET_TOKENS_DEFAULT = int(os.getenv("THINKING_BUDGET_TOKENS", "1024"))
52
  CRITERIA_DEFAULT = "Factuality,Instruction Following,Precision"
53
  def _clean_json(txt):
54
  txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
 
76
  enable_pairwise_filter,
77
  score_with_instruction,
78
  pairwise_with_instruction,
79
+ generate_thinking,
80
+ score_thinking,
81
+ pairwise_thinking,
82
  ):
83
  instruction = instruction_input.strip()
84
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
 
108
  score_with_instruction = SCORE_WITH_INSTRUCTION_DEFAULT
109
  if pairwise_with_instruction is None:
110
  pairwise_with_instruction = PAIRWISE_WITH_INSTRUCTION_DEFAULT
111
+ if generate_thinking is None:
112
+ generate_thinking = GENERATE_THINKING_DEFAULT
113
+ if score_thinking is None:
114
+ score_thinking = SCORE_THINKING_DEFAULT
115
+ if pairwise_thinking is None:
116
+ pairwise_thinking = PAIRWISE_THINKING_DEFAULT
117
  process_log = []
118
  hist_fig = None
119
  top_picks_str = ""
 
163
  api_base=api_base,
164
  api_key=api_token,
165
  temperature=generate_temperature,
166
+ thinking=generate_thinking,
167
+ budget_tokens=THINKING_BUDGET_TOKENS_DEFAULT,
168
  return_usage=True,
169
  )
170
  add_usage(usage)
 
189
  api_key=api_token,
190
  temperature=score_temperature,
191
  include_instruction=score_with_instruction,
192
+ thinking=score_thinking,
193
+ budget_tokens=THINKING_BUDGET_TOKENS_DEFAULT,
194
  return_usage=True,
195
  )
196
  add_usage(usage)
 
229
  api_key=api_token,
230
  temperature=pairwise_temperature,
231
  include_instruction=pairwise_with_instruction,
232
+ thinking=pairwise_thinking,
233
+ budget_tokens=THINKING_BUDGET_TOKENS_DEFAULT,
234
  return_usage=True,
235
  )
236
  add_usage(usage)
 
321
  gr.Checkbox(value=PAIRWISE_FILTER_DEFAULT, label="Enable Pairwise Filter"),
322
  gr.Checkbox(value=SCORE_WITH_INSTRUCTION_DEFAULT, label="Pass Instruction to Score Model"),
323
  gr.Checkbox(value=PAIRWISE_WITH_INSTRUCTION_DEFAULT, label="Pass Instruction to Pairwise Model"),
324
+ gr.Checkbox(value=GENERATE_THINKING_DEFAULT, label="Enable Thinking (Generate)"),
325
+ gr.Checkbox(value=SCORE_THINKING_DEFAULT, label="Enable Thinking (Score)"),
326
+ gr.Checkbox(value=PAIRWISE_THINKING_DEFAULT, label="Enable Thinking (Pairwise)"),
327
  ],
328
  outputs=[
329
  gr.Textbox(lines=10, label="Process"),
tests/test_main.py CHANGED
@@ -106,13 +106,16 @@ def test_run_tournament_full_loop():
106
  enable_pairwise_filter=True,
107
  score_with_instruction=True,
108
  pairwise_with_instruction=True,
 
 
 
109
  ))
110
 
111
  process_log, hist_fig, top_picks, usage = results[-1]
112
  assert 'Done' in process_log
113
  assert hist_fig == 'fig'
114
  assert top_picks.strip() in {'p1', 'p2'}
115
- mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, return_usage=True)
116
  assert 'Score completion' in process_log
117
  assert 'Pairwise completion' in process_log
118
  assert 'Prompt tokens' in usage
@@ -151,6 +154,9 @@ def test_run_tournament_pairwise_odd_players():
151
  enable_pairwise_filter=True,
152
  score_with_instruction=True,
153
  pairwise_with_instruction=True,
 
 
 
154
  ))
155
 
156
  process_log, fig, top_picks, usage = results[-1]
 
106
  enable_pairwise_filter=True,
107
  score_with_instruction=True,
108
  pairwise_with_instruction=True,
109
+ generate_thinking=True,
110
+ score_thinking=True,
111
+ pairwise_thinking=True,
112
  ))
113
 
114
  process_log, hist_fig, top_picks, usage = results[-1]
115
  assert 'Done' in process_log
116
  assert hist_fig == 'fig'
117
  assert top_picks.strip() in {'p1', 'p2'}
118
+ mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, thinking=True, budget_tokens=1024, return_usage=True)
119
  assert 'Score completion' in process_log
120
  assert 'Pairwise completion' in process_log
121
  assert 'Prompt tokens' in usage
 
154
  enable_pairwise_filter=True,
155
  score_with_instruction=True,
156
  pairwise_with_instruction=True,
157
+ generate_thinking=True,
158
+ score_thinking=True,
159
+ pairwise_thinking=True,
160
  ))
161
 
162
  process_log, fig, top_picks, usage = results[-1]
tests/test_tournament_utils.py CHANGED
@@ -26,7 +26,7 @@ def test_generate_players():
26
  resp = make_response([" player1 ", "player2\n"])
27
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
28
  players = tu.generate_players('instr', 2, model='m', api_base='b', api_key='k', temperature=0.5)
29
- mock_comp.assert_called_once_with(model='m', messages=[{'role': 'user', 'content': 'instr'}], n=2, api_base='b', api_key='k', temperature=0.5)
30
  assert players == ['player1', 'player2']
31
 
32
 
@@ -50,3 +50,25 @@ def test_prompt_pairwise():
50
  assert mock_comp.call_args.kwargs['api_key'] == 'k'
51
  assert mock_comp.call_args.kwargs['temperature'] == 0.3
52
  assert result == '{"winner": "A"}'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  resp = make_response([" player1 ", "player2\n"])
27
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
28
  players = tu.generate_players('instr', 2, model='m', api_base='b', api_key='k', temperature=0.5)
29
+ mock_comp.assert_called_once_with(model='m', messages=[{'role': 'user', 'content': 'instr'}], n=2, api_base='b', api_key='k', temperature=0.5, thinking={'type': 'disabled', 'budget_tokens': 0})
30
  assert players == ['player1', 'player2']
31
 
32
 
 
50
  assert mock_comp.call_args.kwargs['api_key'] == 'k'
51
  assert mock_comp.call_args.kwargs['temperature'] == 0.3
52
  assert result == '{"winner": "A"}'
53
+
54
+
55
+ def test_thinking_passed_to_completion():
56
+ resp = make_response(["ok"])
57
+ with patch('tournament_utils.completion', return_value=resp) as mock_comp:
58
+ tu.generate_players('i', 1, thinking=True)
59
+ tu.prompt_score('i', ['c'], 'block', 'p', thinking=True)
60
+ tu.prompt_pairwise('i', 'block', 'a', 'b', thinking=True)
61
+ assert mock_comp.call_count == 3
62
+ for call in mock_comp.call_args_list:
63
+ assert call.kwargs['thinking'] == {'type': 'enabled', 'budget_tokens': 1024}
64
+
65
+
66
+ def test_thinking_disabled_by_default():
67
+ resp = make_response(["ok"])
68
+ with patch('tournament_utils.completion', return_value=resp) as mock_comp:
69
+ tu.generate_players('i', 1)
70
+ tu.prompt_score('i', ['c'], 'block', 'p')
71
+ tu.prompt_pairwise('i', 'block', 'a', 'b')
72
+ assert mock_comp.call_count == 3
73
+ for call in mock_comp.call_args_list:
74
+ assert call.kwargs['thinking'] == {'type': 'disabled', 'budget_tokens': 0}
tournament_utils.py CHANGED
@@ -1,5 +1,8 @@
 
1
  from litellm import completion
2
 
 
 
3
 
4
  def _completion_kwargs(
5
  api_base: str | None,
@@ -25,6 +28,8 @@ def generate_players(
25
  api_base: str | None = None,
26
  api_key: str | None = None,
27
  temperature: float | None = None,
 
 
28
  return_usage: bool = False,
29
  ) -> list[str] | tuple[list[str], object]:
30
  """Request ``n`` completions for the instruction using the given model.
@@ -32,11 +37,17 @@ def generate_players(
32
  When ``return_usage`` is ``True`` the ``usage`` object from the completion
33
  response is also returned.
34
  """
 
 
 
 
 
 
35
  response = completion(
36
  model=model,
37
- messages=[{"role": "user", "content": instruction}],
38
  n=n,
39
- **_completion_kwargs(api_base, api_key, temperature),
40
  )
41
  players = [c.message.content.strip() for c in response.choices]
42
  if return_usage:
@@ -55,6 +66,8 @@ def prompt_score(
55
  api_key: str | None = None,
56
  temperature: float | None = None,
57
  include_instruction: bool = True,
 
 
58
  return_usage: bool = False,
59
  ) -> str | tuple[str, object]:
60
  """Return a JSON score string evaluating `player` on the criteria."""
@@ -66,10 +79,15 @@ Return JSON exactly like: {{"scores": [{example_scores}]}}."""
66
  if include_instruction:
67
  prompt += f"\n\nInstruction:\n{instruction}"
68
  prompt += f"\n\nOutput:\n{player}"
 
 
 
 
 
69
  response = completion(
70
  model=model,
71
  messages=[{"role": "system", "content": prompt}],
72
- **_completion_kwargs(api_base, api_key, temperature),
73
  )
74
  text = response.choices[0].message.content.strip()
75
  if return_usage:
@@ -88,6 +106,8 @@ def prompt_pairwise(
88
  api_key: str | None = None,
89
  temperature: float | None = None,
90
  include_instruction: bool = True,
 
 
91
  return_usage: bool = False,
92
  ) -> str | tuple[str, object]:
93
  """Return which player wins in JSON using the given criteria."""
@@ -98,10 +118,15 @@ Return ONLY JSON {{"winner": "A"}} or {{"winner": "B"}}."""
98
  if include_instruction:
99
  prompt += f"\n\nInstruction:\n{instruction}"
100
  prompt += f"\n\nPlayers:\n<A>{a}</A>\n<B>{b}</B>"
 
 
 
 
 
101
  response = completion(
102
  model=model,
103
  messages=[{"role": "system", "content": prompt}],
104
- **_completion_kwargs(api_base, api_key, temperature),
105
  )
106
  text = response.choices[0].message.content.strip()
107
  if return_usage:
 
1
+ import os
2
  from litellm import completion
3
 
4
+ BUDGET_TOKENS_DEFAULT = int(os.getenv("THINKING_BUDGET_TOKENS", "1024"))
5
+
6
 
7
  def _completion_kwargs(
8
  api_base: str | None,
 
28
  api_base: str | None = None,
29
  api_key: str | None = None,
30
  temperature: float | None = None,
31
+ thinking: bool = False,
32
+ budget_tokens: int = BUDGET_TOKENS_DEFAULT,
33
  return_usage: bool = False,
34
  ) -> list[str] | tuple[list[str], object]:
35
  """Request ``n`` completions for the instruction using the given model.
 
37
  When ``return_usage`` is ``True`` the ``usage`` object from the completion
38
  response is also returned.
39
  """
40
+ messages = [{"role": "user", "content": instruction}]
41
+ kwargs = _completion_kwargs(api_base, api_key, temperature)
42
+ kwargs["thinking"] = {
43
+ "type": "enabled" if thinking else "disabled",
44
+ "budget_tokens": budget_tokens if thinking else 0,
45
+ }
46
  response = completion(
47
  model=model,
48
+ messages=messages,
49
  n=n,
50
+ **kwargs,
51
  )
52
  players = [c.message.content.strip() for c in response.choices]
53
  if return_usage:
 
66
  api_key: str | None = None,
67
  temperature: float | None = None,
68
  include_instruction: bool = True,
69
+ thinking: bool = False,
70
+ budget_tokens: int = BUDGET_TOKENS_DEFAULT,
71
  return_usage: bool = False,
72
  ) -> str | tuple[str, object]:
73
  """Return a JSON score string evaluating `player` on the criteria."""
 
79
  if include_instruction:
80
  prompt += f"\n\nInstruction:\n{instruction}"
81
  prompt += f"\n\nOutput:\n{player}"
82
+ kwargs = _completion_kwargs(api_base, api_key, temperature)
83
+ kwargs["thinking"] = {
84
+ "type": "enabled" if thinking else "disabled",
85
+ "budget_tokens": budget_tokens if thinking else 0,
86
+ }
87
  response = completion(
88
  model=model,
89
  messages=[{"role": "system", "content": prompt}],
90
+ **kwargs,
91
  )
92
  text = response.choices[0].message.content.strip()
93
  if return_usage:
 
106
  api_key: str | None = None,
107
  temperature: float | None = None,
108
  include_instruction: bool = True,
109
+ thinking: bool = False,
110
+ budget_tokens: int = BUDGET_TOKENS_DEFAULT,
111
  return_usage: bool = False,
112
  ) -> str | tuple[str, object]:
113
  """Return which player wins in JSON using the given criteria."""
 
118
  if include_instruction:
119
  prompt += f"\n\nInstruction:\n{instruction}"
120
  prompt += f"\n\nPlayers:\n<A>{a}</A>\n<B>{b}</B>"
121
+ kwargs = _completion_kwargs(api_base, api_key, temperature)
122
+ kwargs["thinking"] = {
123
+ "type": "enabled" if thinking else "disabled",
124
+ "budget_tokens": budget_tokens if thinking else 0,
125
+ }
126
  response = completion(
127
  model=model,
128
  messages=[{"role": "system", "content": prompt}],
129
+ **kwargs,
130
  )
131
  text = response.choices[0].message.content.strip()
132
  if return_usage: