ping98k commited on
Commit
02aebba
·
1 Parent(s): 1633c5a

Add judge instruction options and temperature settings

Browse files
README.md CHANGED
@@ -14,6 +14,11 @@ This project provides a small interface for running "tournaments" between langua
14
  - `GENERATE_MODEL`
15
  - `SCORE_MODEL`
16
  - `PAIRWISE_MODEL`
 
 
 
 
 
17
  - `ENABLE_SCORE_FILTER`
18
  - `ENABLE_PAIRWISE_FILTER`
19
  2. Install dependencies (example with `pip`):
 
14
  - `GENERATE_MODEL`
15
  - `SCORE_MODEL`
16
  - `PAIRWISE_MODEL`
17
+ - `GENERATE_TEMPERATURE`
18
+ - `SCORE_TEMPERATURE`
19
+ - `PAIRWISE_TEMPERATURE`
20
+ - `PASS_INSTRUCTION_TO_SCORE`
21
+ - `PASS_INSTRUCTION_TO_PAIRWISE`
22
  - `ENABLE_SCORE_FILTER`
23
  - `ENABLE_PAIRWISE_FILTER`
24
  2. Install dependencies (example with `pip`):
main.py CHANGED
@@ -17,6 +17,11 @@ PAIRWISE_FILTER_DEFAULT = os.getenv("ENABLE_PAIRWISE_FILTER", "true").lower() ==
17
  GENERATE_MODEL_DEFAULT = os.getenv("GENERATE_MODEL", "gpt-4o-mini")
18
  SCORE_MODEL_DEFAULT = os.getenv("SCORE_MODEL", "gpt-4o-mini")
19
  PAIRWISE_MODEL_DEFAULT = os.getenv("PAIRWISE_MODEL", "gpt-4o-mini")
 
 
 
 
 
20
  CRITERIA_DEFAULT = "Factuality,Instruction Following,Precision"
21
  def _clean_json(txt):
22
  txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
@@ -31,6 +36,9 @@ def run_tournament(
31
  generate_model,
32
  score_model,
33
  pairwise_model,
 
 
 
34
  instruction_input,
35
  criteria_input,
36
  n_gen,
@@ -39,6 +47,8 @@ def run_tournament(
39
  max_workers,
40
  enable_score_filter,
41
  enable_pairwise_filter,
 
 
42
  ):
43
  instruction = instruction_input.strip()
44
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
@@ -46,6 +56,12 @@ def run_tournament(
46
  num_top_picks = int(num_top_picks)
47
  pool_size = int(pool_size)
48
  max_workers = int(max_workers)
 
 
 
 
 
 
49
  if not api_base:
50
  api_base = API_BASE_DEFAULT
51
  if not api_token:
@@ -58,6 +74,10 @@ def run_tournament(
58
  pairwise_model = PAIRWISE_MODEL_DEFAULT
59
  enable_score_filter = bool(enable_score_filter)
60
  enable_pairwise_filter = bool(enable_pairwise_filter)
 
 
 
 
61
  process_log = []
62
  hist_fig = None
63
  top_picks_str = ""
@@ -104,6 +124,7 @@ def run_tournament(
104
  model=generate_model,
105
  api_base=api_base,
106
  api_key=api_token,
 
107
  return_usage=True,
108
  )
109
  add_usage(usage)
@@ -123,6 +144,8 @@ def run_tournament(
123
  model=score_model,
124
  api_base=api_base,
125
  api_key=api_token,
 
 
126
  return_usage=True,
127
  )
128
  add_usage(usage)
@@ -161,6 +184,8 @@ def run_tournament(
161
  model=pairwise_model,
162
  api_base=api_base,
163
  api_key=api_token,
 
 
164
  return_usage=True,
165
  )
166
  add_usage(usage)
@@ -231,8 +256,11 @@ demo = gr.Interface(
231
  gr.Textbox(value=API_BASE_DEFAULT, label="API Base Path"),
232
  gr.Textbox(value="", label="API Token", type="password"),
233
  gr.Textbox(value=GENERATE_MODEL_DEFAULT, label="Generation Model"),
 
234
  gr.Textbox(value=SCORE_MODEL_DEFAULT, label="Score Model"),
 
235
  gr.Textbox(value=PAIRWISE_MODEL_DEFAULT, label="Pairwise Model"),
 
236
  gr.Textbox(lines=10, label="Instruction"),
237
  gr.Textbox(value=CRITERIA_DEFAULT, lines=5, label="Criteria (comma separated)"),
238
  gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
@@ -241,6 +269,8 @@ demo = gr.Interface(
241
  gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers"),
242
  gr.Checkbox(value=SCORE_FILTER_DEFAULT, label="Enable Score Filter"),
243
  gr.Checkbox(value=PAIRWISE_FILTER_DEFAULT, label="Enable Pairwise Filter"),
 
 
244
  ],
245
  outputs=[
246
  gr.Textbox(lines=10, label="Process"),
 
17
  GENERATE_MODEL_DEFAULT = os.getenv("GENERATE_MODEL", "gpt-4o-mini")
18
  SCORE_MODEL_DEFAULT = os.getenv("SCORE_MODEL", "gpt-4o-mini")
19
  PAIRWISE_MODEL_DEFAULT = os.getenv("PAIRWISE_MODEL", "gpt-4o-mini")
20
+ GENERATE_TEMPERATURE_DEFAULT = float(os.getenv("GENERATE_TEMPERATURE", "1.0"))
21
+ SCORE_TEMPERATURE_DEFAULT = float(os.getenv("SCORE_TEMPERATURE", "1.0"))
22
+ PAIRWISE_TEMPERATURE_DEFAULT = float(os.getenv("PAIRWISE_TEMPERATURE", "1.0"))
23
+ SCORE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_SCORE", "true").lower() == "true"
24
+ PAIRWISE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_PAIRWISE", "true").lower() == "true"
25
  CRITERIA_DEFAULT = "Factuality,Instruction Following,Precision"
26
  def _clean_json(txt):
27
  txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
 
36
  generate_model,
37
  score_model,
38
  pairwise_model,
39
+ generate_temperature,
40
+ score_temperature,
41
+ pairwise_temperature,
42
  instruction_input,
43
  criteria_input,
44
  n_gen,
 
47
  max_workers,
48
  enable_score_filter,
49
  enable_pairwise_filter,
50
+ score_with_instruction,
51
+ pairwise_with_instruction,
52
  ):
53
  instruction = instruction_input.strip()
54
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
 
56
  num_top_picks = int(num_top_picks)
57
  pool_size = int(pool_size)
58
  max_workers = int(max_workers)
59
+ if generate_temperature is None:
60
+ generate_temperature = GENERATE_TEMPERATURE_DEFAULT
61
+ if score_temperature is None:
62
+ score_temperature = SCORE_TEMPERATURE_DEFAULT
63
+ if pairwise_temperature is None:
64
+ pairwise_temperature = PAIRWISE_TEMPERATURE_DEFAULT
65
  if not api_base:
66
  api_base = API_BASE_DEFAULT
67
  if not api_token:
 
74
  pairwise_model = PAIRWISE_MODEL_DEFAULT
75
  enable_score_filter = bool(enable_score_filter)
76
  enable_pairwise_filter = bool(enable_pairwise_filter)
77
+ if score_with_instruction is None:
78
+ score_with_instruction = SCORE_WITH_INSTRUCTION_DEFAULT
79
+ if pairwise_with_instruction is None:
80
+ pairwise_with_instruction = PAIRWISE_WITH_INSTRUCTION_DEFAULT
81
  process_log = []
82
  hist_fig = None
83
  top_picks_str = ""
 
124
  model=generate_model,
125
  api_base=api_base,
126
  api_key=api_token,
127
+ temperature=generate_temperature,
128
  return_usage=True,
129
  )
130
  add_usage(usage)
 
144
  model=score_model,
145
  api_base=api_base,
146
  api_key=api_token,
147
+ temperature=score_temperature,
148
+ include_instruction=score_with_instruction,
149
  return_usage=True,
150
  )
151
  add_usage(usage)
 
184
  model=pairwise_model,
185
  api_base=api_base,
186
  api_key=api_token,
187
+ temperature=pairwise_temperature,
188
+ include_instruction=pairwise_with_instruction,
189
  return_usage=True,
190
  )
191
  add_usage(usage)
 
256
  gr.Textbox(value=API_BASE_DEFAULT, label="API Base Path"),
257
  gr.Textbox(value="", label="API Token", type="password"),
258
  gr.Textbox(value=GENERATE_MODEL_DEFAULT, label="Generation Model"),
259
+ gr.Number(value=GENERATE_TEMPERATURE_DEFAULT, label="Generation Temperature"),
260
  gr.Textbox(value=SCORE_MODEL_DEFAULT, label="Score Model"),
261
+ gr.Number(value=SCORE_TEMPERATURE_DEFAULT, label="Score Temperature"),
262
  gr.Textbox(value=PAIRWISE_MODEL_DEFAULT, label="Pairwise Model"),
263
+ gr.Number(value=PAIRWISE_TEMPERATURE_DEFAULT, label="Pairwise Temperature"),
264
  gr.Textbox(lines=10, label="Instruction"),
265
  gr.Textbox(value=CRITERIA_DEFAULT, lines=5, label="Criteria (comma separated)"),
266
  gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
 
269
  gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers"),
270
  gr.Checkbox(value=SCORE_FILTER_DEFAULT, label="Enable Score Filter"),
271
  gr.Checkbox(value=PAIRWISE_FILTER_DEFAULT, label="Enable Pairwise Filter"),
272
+ gr.Checkbox(value=SCORE_WITH_INSTRUCTION_DEFAULT, label="Pass Instruction to Score Model"),
273
+ gr.Checkbox(value=PAIRWISE_WITH_INSTRUCTION_DEFAULT, label="Pass Instruction to Pairwise Model"),
274
  ],
275
  outputs=[
276
  gr.Textbox(lines=10, label="Process"),
tests/test_main.py CHANGED
@@ -93,6 +93,9 @@ def test_run_tournament_full_loop():
93
  generate_model='gm',
94
  score_model='sm',
95
  pairwise_model='pm',
 
 
 
96
  instruction_input='instr',
97
  criteria_input='c1,c2',
98
  n_gen=4,
@@ -101,13 +104,15 @@ def test_run_tournament_full_loop():
101
  max_workers=1,
102
  enable_score_filter=True,
103
  enable_pairwise_filter=True,
 
 
104
  ))
105
 
106
  process_log, hist_fig, top_picks, usage = results[-1]
107
  assert 'Done' in process_log
108
  assert hist_fig == 'fig'
109
  assert top_picks.strip() in {'p1', 'p2'}
110
- mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', return_usage=True)
111
  assert 'Score completion' in process_log
112
  assert 'Pairwise completion' in process_log
113
  assert 'Prompt tokens' in usage
@@ -133,6 +138,9 @@ def test_run_tournament_pairwise_odd_players():
133
  generate_model='gm',
134
  score_model='sm',
135
  pairwise_model='pm',
 
 
 
136
  instruction_input='instr',
137
  criteria_input='c1,c2',
138
  n_gen=3,
@@ -141,6 +149,8 @@ def test_run_tournament_pairwise_odd_players():
141
  max_workers=1,
142
  enable_score_filter=False,
143
  enable_pairwise_filter=True,
 
 
144
  ))
145
 
146
  process_log, fig, top_picks, usage = results[-1]
 
93
  generate_model='gm',
94
  score_model='sm',
95
  pairwise_model='pm',
96
+ generate_temperature=1,
97
+ score_temperature=1,
98
+ pairwise_temperature=1,
99
  instruction_input='instr',
100
  criteria_input='c1,c2',
101
  n_gen=4,
 
104
  max_workers=1,
105
  enable_score_filter=True,
106
  enable_pairwise_filter=True,
107
+ score_with_instruction=True,
108
+ pairwise_with_instruction=True,
109
  ))
110
 
111
  process_log, hist_fig, top_picks, usage = results[-1]
112
  assert 'Done' in process_log
113
  assert hist_fig == 'fig'
114
  assert top_picks.strip() in {'p1', 'p2'}
115
+ mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, return_usage=True)
116
  assert 'Score completion' in process_log
117
  assert 'Pairwise completion' in process_log
118
  assert 'Prompt tokens' in usage
 
138
  generate_model='gm',
139
  score_model='sm',
140
  pairwise_model='pm',
141
+ generate_temperature=1,
142
+ score_temperature=1,
143
+ pairwise_temperature=1,
144
  instruction_input='instr',
145
  criteria_input='c1,c2',
146
  n_gen=3,
 
149
  max_workers=1,
150
  enable_score_filter=False,
151
  enable_pairwise_filter=True,
152
+ score_with_instruction=True,
153
+ pairwise_with_instruction=True,
154
  ))
155
 
156
  process_log, fig, top_picks, usage = results[-1]
tests/test_tournament_utils.py CHANGED
@@ -25,26 +25,28 @@ def make_response(contents):
25
  def test_generate_players():
26
  resp = make_response([" player1 ", "player2\n"])
27
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
28
- players = tu.generate_players('instr', 2, model='m', api_base='b', api_key='k')
29
- mock_comp.assert_called_once_with(model='m', messages=[{'role': 'user', 'content': 'instr'}], n=2, api_base='b', api_key='k')
30
  assert players == ['player1', 'player2']
31
 
32
 
33
  def test_prompt_score():
34
  resp = make_response([" {\"score\": [5]} "])
35
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
36
- result = tu.prompt_score('instr', ['c1'], 'block', 'pl', model='m', api_base='b', api_key='k')
37
  mock_comp.assert_called_once()
38
  assert mock_comp.call_args.kwargs['api_base'] == 'b'
39
  assert mock_comp.call_args.kwargs['api_key'] == 'k'
 
40
  assert result == '{"score": [5]}'
41
 
42
 
43
  def test_prompt_pairwise():
44
  resp = make_response([" {\"winner\": \"A\"} "])
45
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
46
- result = tu.prompt_pairwise('instr', 'block', 'A text', 'B text', model='m', api_base='b', api_key='k')
47
  mock_comp.assert_called_once()
48
  assert mock_comp.call_args.kwargs['api_base'] == 'b'
49
  assert mock_comp.call_args.kwargs['api_key'] == 'k'
 
50
  assert result == '{"winner": "A"}'
 
25
  def test_generate_players():
26
  resp = make_response([" player1 ", "player2\n"])
27
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
28
+ players = tu.generate_players('instr', 2, model='m', api_base='b', api_key='k', temperature=0.5)
29
+ mock_comp.assert_called_once_with(model='m', messages=[{'role': 'user', 'content': 'instr'}], n=2, api_base='b', api_key='k', temperature=0.5)
30
  assert players == ['player1', 'player2']
31
 
32
 
33
  def test_prompt_score():
34
  resp = make_response([" {\"score\": [5]} "])
35
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
36
+ result = tu.prompt_score('instr', ['c1'], 'block', 'pl', model='m', api_base='b', api_key='k', temperature=0.2, include_instruction=False)
37
  mock_comp.assert_called_once()
38
  assert mock_comp.call_args.kwargs['api_base'] == 'b'
39
  assert mock_comp.call_args.kwargs['api_key'] == 'k'
40
+ assert mock_comp.call_args.kwargs['temperature'] == 0.2
41
  assert result == '{"score": [5]}'
42
 
43
 
44
  def test_prompt_pairwise():
45
  resp = make_response([" {\"winner\": \"A\"} "])
46
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
47
+ result = tu.prompt_pairwise('instr', 'block', 'A text', 'B text', model='m', api_base='b', api_key='k', temperature=0.3, include_instruction=False)
48
  mock_comp.assert_called_once()
49
  assert mock_comp.call_args.kwargs['api_base'] == 'b'
50
  assert mock_comp.call_args.kwargs['api_key'] == 'k'
51
+ assert mock_comp.call_args.kwargs['temperature'] == 0.3
52
  assert result == '{"winner": "A"}'
tournament_utils.py CHANGED
@@ -1,13 +1,19 @@
1
  from litellm import completion
2
 
3
 
4
- def _completion_kwargs(api_base: str | None, api_key: str | None) -> dict:
 
 
 
 
5
  """Build kwargs for litellm.completion from api settings."""
6
  kwargs: dict = {}
7
  if api_base:
8
  kwargs["api_base"] = api_base
9
  if api_key:
10
  kwargs["api_key"] = api_key
 
 
11
  return kwargs
12
 
13
 
@@ -18,6 +24,7 @@ def generate_players(
18
  *,
19
  api_base: str | None = None,
20
  api_key: str | None = None,
 
21
  return_usage: bool = False,
22
  ) -> list[str] | tuple[list[str], object]:
23
  """Request ``n`` completions for the instruction using the given model.
@@ -29,7 +36,7 @@ def generate_players(
29
  model=model,
30
  messages=[{"role": "user", "content": instruction}],
31
  n=n,
32
- **_completion_kwargs(api_base, api_key),
33
  )
34
  players = [c.message.content.strip() for c in response.choices]
35
  if return_usage:
@@ -46,6 +53,8 @@ def prompt_score(
46
  *,
47
  api_base: str | None = None,
48
  api_key: str | None = None,
 
 
49
  return_usage: bool = False,
50
  ) -> str | tuple[str, object]:
51
  """Return a JSON score string evaluating `player` on the criteria."""
@@ -53,17 +62,14 @@ def prompt_score(
53
  prompt = f"""Evaluate the output below on the following criteria:
54
  {criteria_block}
55
 
56
- Return JSON exactly like: {{"scores": [{example_scores}]}}.
57
-
58
- Instruction:
59
- {instruction}
60
-
61
- Output:
62
- {player}"""
63
  response = completion(
64
  model=model,
65
  messages=[{"role": "system", "content": prompt}],
66
- **_completion_kwargs(api_base, api_key),
67
  )
68
  text = response.choices[0].message.content.strip()
69
  if return_usage:
@@ -80,24 +86,22 @@ def prompt_pairwise(
80
  *,
81
  api_base: str | None = None,
82
  api_key: str | None = None,
 
 
83
  return_usage: bool = False,
84
  ) -> str | tuple[str, object]:
85
  """Return which player wins in JSON using the given criteria."""
86
  prompt = f"""Compare the two players below using:
87
  {criteria_block}
88
 
89
- Return ONLY JSON {{"winner": "A"}} or {{"winner": "B"}}.
90
-
91
- Instruction:
92
- {instruction}
93
-
94
- Players:
95
- <A>{a}</A>
96
- <B>{b}</B>"""
97
  response = completion(
98
  model=model,
99
  messages=[{"role": "system", "content": prompt}],
100
- **_completion_kwargs(api_base, api_key),
101
  )
102
  text = response.choices[0].message.content.strip()
103
  if return_usage:
 
1
  from litellm import completion
2
 
3
 
4
+ def _completion_kwargs(
5
+ api_base: str | None,
6
+ api_key: str | None,
7
+ temperature: float | None,
8
+ ) -> dict:
9
  """Build kwargs for litellm.completion from api settings."""
10
  kwargs: dict = {}
11
  if api_base:
12
  kwargs["api_base"] = api_base
13
  if api_key:
14
  kwargs["api_key"] = api_key
15
+ if temperature is not None:
16
+ kwargs["temperature"] = temperature
17
  return kwargs
18
 
19
 
 
24
  *,
25
  api_base: str | None = None,
26
  api_key: str | None = None,
27
+ temperature: float | None = None,
28
  return_usage: bool = False,
29
  ) -> list[str] | tuple[list[str], object]:
30
  """Request ``n`` completions for the instruction using the given model.
 
36
  model=model,
37
  messages=[{"role": "user", "content": instruction}],
38
  n=n,
39
+ **_completion_kwargs(api_base, api_key, temperature),
40
  )
41
  players = [c.message.content.strip() for c in response.choices]
42
  if return_usage:
 
53
  *,
54
  api_base: str | None = None,
55
  api_key: str | None = None,
56
+ temperature: float | None = None,
57
+ include_instruction: bool = True,
58
  return_usage: bool = False,
59
  ) -> str | tuple[str, object]:
60
  """Return a JSON score string evaluating `player` on the criteria."""
 
62
  prompt = f"""Evaluate the output below on the following criteria:
63
  {criteria_block}
64
 
65
+ Return JSON exactly like: {{"scores": [{example_scores}]}}."""
66
+ if include_instruction:
67
+ prompt += f"\n\nInstruction:\n{instruction}"
68
+ prompt += f"\n\nOutput:\n{player}"
 
 
 
69
  response = completion(
70
  model=model,
71
  messages=[{"role": "system", "content": prompt}],
72
+ **_completion_kwargs(api_base, api_key, temperature),
73
  )
74
  text = response.choices[0].message.content.strip()
75
  if return_usage:
 
86
  *,
87
  api_base: str | None = None,
88
  api_key: str | None = None,
89
+ temperature: float | None = None,
90
+ include_instruction: bool = True,
91
  return_usage: bool = False,
92
  ) -> str | tuple[str, object]:
93
  """Return which player wins in JSON using the given criteria."""
94
  prompt = f"""Compare the two players below using:
95
  {criteria_block}
96
 
97
+ Return ONLY JSON {{"winner": "A"}} or {{"winner": "B"}}."""
98
+ if include_instruction:
99
+ prompt += f"\n\nInstruction:\n{instruction}"
100
+ prompt += f"\n\nPlayers:\n<A>{a}</A>\n<B>{b}</B>"
 
 
 
 
101
  response = completion(
102
  model=model,
103
  messages=[{"role": "system", "content": prompt}],
104
+ **_completion_kwargs(api_base, api_key, temperature),
105
  )
106
  text = response.choices[0].message.content.strip()
107
  if return_usage: