ping98k committed on
Commit
faf573e
·
unverified ·
2 Parent(s): bc6765e cca2d14

Merge pull request #7 from ping98k/codex/add-option-boxes-for-score-and-pairwise-model

Browse files
README.md CHANGED
@@ -14,6 +14,11 @@ This project provides a small interface for running "tournaments" between langua
14
  - `GENERATE_MODEL`
15
  - `SCORE_MODEL`
16
  - `PAIRWISE_MODEL`
 
 
 
 
 
17
  - `ENABLE_SCORE_FILTER`
18
  - `ENABLE_PAIRWISE_FILTER`
19
  2. Install dependencies (example with `pip`):
 
14
  - `GENERATE_MODEL`
15
  - `SCORE_MODEL`
16
  - `PAIRWISE_MODEL`
17
+ - `GENERATE_TEMPERATURE`
18
+ - `SCORE_TEMPERATURE`
19
+ - `PAIRWISE_TEMPERATURE`
20
+ - `PASS_INSTRUCTION_TO_SCORE`
21
+ - `PASS_INSTRUCTION_TO_PAIRWISE`
22
  - `ENABLE_SCORE_FILTER`
23
  - `ENABLE_PAIRWISE_FILTER`
24
  2. Install dependencies (example with `pip`):
main.py CHANGED
@@ -40,6 +40,11 @@ PAIRWISE_FILTER_DEFAULT = os.getenv("ENABLE_PAIRWISE_FILTER", "true").lower() ==
40
  GENERATE_MODEL_DEFAULT = os.getenv("GENERATE_MODEL", "gpt-4o-mini")
41
  SCORE_MODEL_DEFAULT = os.getenv("SCORE_MODEL", "gpt-4o-mini")
42
  PAIRWISE_MODEL_DEFAULT = os.getenv("PAIRWISE_MODEL", "gpt-4o-mini")
 
 
 
 
 
43
  CRITERIA_DEFAULT = "Factuality,Instruction Following,Precision"
44
  def _clean_json(txt):
45
  txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
@@ -54,6 +59,9 @@ def run_tournament(
54
  generate_model,
55
  score_model,
56
  pairwise_model,
 
 
 
57
  instruction_input,
58
  criteria_input,
59
  n_gen,
@@ -62,6 +70,8 @@ def run_tournament(
62
  max_workers,
63
  enable_score_filter,
64
  enable_pairwise_filter,
 
 
65
  ):
66
  instruction = instruction_input.strip()
67
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
@@ -69,6 +79,12 @@ def run_tournament(
69
  num_top_picks = int(num_top_picks)
70
  pool_size = int(pool_size)
71
  max_workers = int(max_workers)
 
 
 
 
 
 
72
  if not api_base:
73
  api_base = API_BASE_DEFAULT
74
  if not api_token:
@@ -81,6 +97,10 @@ def run_tournament(
81
  pairwise_model = PAIRWISE_MODEL_DEFAULT
82
  enable_score_filter = bool(enable_score_filter)
83
  enable_pairwise_filter = bool(enable_pairwise_filter)
 
 
 
 
84
  process_log = []
85
  hist_fig = None
86
  top_picks_str = ""
@@ -127,6 +147,7 @@ def run_tournament(
127
  model=generate_model,
128
  api_base=api_base,
129
  api_key=api_token,
 
130
  return_usage=True,
131
  )
132
  add_usage(usage)
@@ -146,6 +167,8 @@ def run_tournament(
146
  model=score_model,
147
  api_base=api_base,
148
  api_key=api_token,
 
 
149
  return_usage=True,
150
  )
151
  add_usage(usage)
@@ -182,6 +205,8 @@ def run_tournament(
182
  model=pairwise_model,
183
  api_base=api_base,
184
  api_key=api_token,
 
 
185
  return_usage=True,
186
  )
187
  add_usage(usage)
@@ -259,6 +284,9 @@ demo = gr.Interface(
259
  gr.Textbox(value=GENERATE_MODEL_DEFAULT, label="Generation Model"),
260
  gr.Textbox(value=SCORE_MODEL_DEFAULT, label="Score Model"),
261
  gr.Textbox(value=PAIRWISE_MODEL_DEFAULT, label="Pairwise Model"),
 
 
 
262
  gr.Textbox(lines=10, label="Instruction"),
263
  gr.Textbox(value=CRITERIA_DEFAULT, lines=5, label="Criteria (comma separated)"),
264
  gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
@@ -267,6 +295,8 @@ demo = gr.Interface(
267
  gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers"),
268
  gr.Checkbox(value=SCORE_FILTER_DEFAULT, label="Enable Score Filter"),
269
  gr.Checkbox(value=PAIRWISE_FILTER_DEFAULT, label="Enable Pairwise Filter"),
 
 
270
  ],
271
  outputs=[
272
  gr.Textbox(lines=10, label="Process"),
 
40
  GENERATE_MODEL_DEFAULT = os.getenv("GENERATE_MODEL", "gpt-4o-mini")
41
  SCORE_MODEL_DEFAULT = os.getenv("SCORE_MODEL", "gpt-4o-mini")
42
  PAIRWISE_MODEL_DEFAULT = os.getenv("PAIRWISE_MODEL", "gpt-4o-mini")
43
+ GENERATE_TEMPERATURE_DEFAULT = float(os.getenv("GENERATE_TEMPERATURE", "1.0"))
44
+ SCORE_TEMPERATURE_DEFAULT = float(os.getenv("SCORE_TEMPERATURE", "1.0"))
45
+ PAIRWISE_TEMPERATURE_DEFAULT = float(os.getenv("PAIRWISE_TEMPERATURE", "1.0"))
46
+ SCORE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_SCORE", "true").lower() == "true"
47
+ PAIRWISE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_PAIRWISE", "true").lower() == "true"
48
  CRITERIA_DEFAULT = "Factuality,Instruction Following,Precision"
49
  def _clean_json(txt):
50
  txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
 
59
  generate_model,
60
  score_model,
61
  pairwise_model,
62
+ generate_temperature,
63
+ score_temperature,
64
+ pairwise_temperature,
65
  instruction_input,
66
  criteria_input,
67
  n_gen,
 
70
  max_workers,
71
  enable_score_filter,
72
  enable_pairwise_filter,
73
+ score_with_instruction,
74
+ pairwise_with_instruction,
75
  ):
76
  instruction = instruction_input.strip()
77
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
 
79
  num_top_picks = int(num_top_picks)
80
  pool_size = int(pool_size)
81
  max_workers = int(max_workers)
82
+ if generate_temperature is None:
83
+ generate_temperature = GENERATE_TEMPERATURE_DEFAULT
84
+ if score_temperature is None:
85
+ score_temperature = SCORE_TEMPERATURE_DEFAULT
86
+ if pairwise_temperature is None:
87
+ pairwise_temperature = PAIRWISE_TEMPERATURE_DEFAULT
88
  if not api_base:
89
  api_base = API_BASE_DEFAULT
90
  if not api_token:
 
97
  pairwise_model = PAIRWISE_MODEL_DEFAULT
98
  enable_score_filter = bool(enable_score_filter)
99
  enable_pairwise_filter = bool(enable_pairwise_filter)
100
+ if score_with_instruction is None:
101
+ score_with_instruction = SCORE_WITH_INSTRUCTION_DEFAULT
102
+ if pairwise_with_instruction is None:
103
+ pairwise_with_instruction = PAIRWISE_WITH_INSTRUCTION_DEFAULT
104
  process_log = []
105
  hist_fig = None
106
  top_picks_str = ""
 
147
  model=generate_model,
148
  api_base=api_base,
149
  api_key=api_token,
150
+ temperature=generate_temperature,
151
  return_usage=True,
152
  )
153
  add_usage(usage)
 
167
  model=score_model,
168
  api_base=api_base,
169
  api_key=api_token,
170
+ temperature=score_temperature,
171
+ include_instruction=score_with_instruction,
172
  return_usage=True,
173
  )
174
  add_usage(usage)
 
205
  model=pairwise_model,
206
  api_base=api_base,
207
  api_key=api_token,
208
+ temperature=pairwise_temperature,
209
+ include_instruction=pairwise_with_instruction,
210
  return_usage=True,
211
  )
212
  add_usage(usage)
 
284
  gr.Textbox(value=GENERATE_MODEL_DEFAULT, label="Generation Model"),
285
  gr.Textbox(value=SCORE_MODEL_DEFAULT, label="Score Model"),
286
  gr.Textbox(value=PAIRWISE_MODEL_DEFAULT, label="Pairwise Model"),
287
+ gr.Number(value=GENERATE_TEMPERATURE_DEFAULT, label="Generation Temperature"),
288
+ gr.Number(value=SCORE_TEMPERATURE_DEFAULT, label="Score Temperature"),
289
+ gr.Number(value=PAIRWISE_TEMPERATURE_DEFAULT, label="Pairwise Temperature"),
290
  gr.Textbox(lines=10, label="Instruction"),
291
  gr.Textbox(value=CRITERIA_DEFAULT, lines=5, label="Criteria (comma separated)"),
292
  gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
 
295
  gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers"),
296
  gr.Checkbox(value=SCORE_FILTER_DEFAULT, label="Enable Score Filter"),
297
  gr.Checkbox(value=PAIRWISE_FILTER_DEFAULT, label="Enable Pairwise Filter"),
298
+ gr.Checkbox(value=SCORE_WITH_INSTRUCTION_DEFAULT, label="Pass Instruction to Score Model"),
299
+ gr.Checkbox(value=PAIRWISE_WITH_INSTRUCTION_DEFAULT, label="Pass Instruction to Pairwise Model"),
300
  ],
301
  outputs=[
302
  gr.Textbox(lines=10, label="Process"),
tests/test_main.py CHANGED
@@ -93,6 +93,9 @@ def test_run_tournament_full_loop():
93
  generate_model='gm',
94
  score_model='sm',
95
  pairwise_model='pm',
 
 
 
96
  instruction_input='instr',
97
  criteria_input='c1,c2',
98
  n_gen=4,
@@ -101,13 +104,15 @@ def test_run_tournament_full_loop():
101
  max_workers=1,
102
  enable_score_filter=True,
103
  enable_pairwise_filter=True,
 
 
104
  ))
105
 
106
  process_log, hist_fig, top_picks, usage = results[-1]
107
  assert 'Done' in process_log
108
  assert hist_fig == 'fig'
109
  assert top_picks.strip() in {'p1', 'p2'}
110
- mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', return_usage=True)
111
  assert 'Score completion' in process_log
112
  assert 'Pairwise completion' in process_log
113
  assert 'Prompt tokens' in usage
@@ -133,6 +138,9 @@ def test_run_tournament_pairwise_odd_players():
133
  generate_model='gm',
134
  score_model='sm',
135
  pairwise_model='pm',
 
 
 
136
  instruction_input='instr',
137
  criteria_input='c1,c2',
138
  n_gen=3,
@@ -141,6 +149,8 @@ def test_run_tournament_pairwise_odd_players():
141
  max_workers=1,
142
  enable_score_filter=False,
143
  enable_pairwise_filter=True,
 
 
144
  ))
145
 
146
  process_log, fig, top_picks, usage = results[-1]
 
93
  generate_model='gm',
94
  score_model='sm',
95
  pairwise_model='pm',
96
+ generate_temperature=1,
97
+ score_temperature=1,
98
+ pairwise_temperature=1,
99
  instruction_input='instr',
100
  criteria_input='c1,c2',
101
  n_gen=4,
 
104
  max_workers=1,
105
  enable_score_filter=True,
106
  enable_pairwise_filter=True,
107
+ score_with_instruction=True,
108
+ pairwise_with_instruction=True,
109
  ))
110
 
111
  process_log, hist_fig, top_picks, usage = results[-1]
112
  assert 'Done' in process_log
113
  assert hist_fig == 'fig'
114
  assert top_picks.strip() in {'p1', 'p2'}
115
+ mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, return_usage=True)
116
  assert 'Score completion' in process_log
117
  assert 'Pairwise completion' in process_log
118
  assert 'Prompt tokens' in usage
 
138
  generate_model='gm',
139
  score_model='sm',
140
  pairwise_model='pm',
141
+ generate_temperature=1,
142
+ score_temperature=1,
143
+ pairwise_temperature=1,
144
  instruction_input='instr',
145
  criteria_input='c1,c2',
146
  n_gen=3,
 
149
  max_workers=1,
150
  enable_score_filter=False,
151
  enable_pairwise_filter=True,
152
+ score_with_instruction=True,
153
+ pairwise_with_instruction=True,
154
  ))
155
 
156
  process_log, fig, top_picks, usage = results[-1]
tests/test_tournament_utils.py CHANGED
@@ -25,26 +25,28 @@ def make_response(contents):
25
  def test_generate_players():
26
  resp = make_response([" player1 ", "player2\n"])
27
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
28
- players = tu.generate_players('instr', 2, model='m', api_base='b', api_key='k')
29
- mock_comp.assert_called_once_with(model='m', messages=[{'role': 'user', 'content': 'instr'}], n=2, api_base='b', api_key='k')
30
  assert players == ['player1', 'player2']
31
 
32
 
33
  def test_prompt_score():
34
  resp = make_response([" {\"score\": [5]} "])
35
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
36
- result = tu.prompt_score('instr', ['c1'], 'block', 'pl', model='m', api_base='b', api_key='k')
37
  mock_comp.assert_called_once()
38
  assert mock_comp.call_args.kwargs['api_base'] == 'b'
39
  assert mock_comp.call_args.kwargs['api_key'] == 'k'
 
40
  assert result == '{"score": [5]}'
41
 
42
 
43
  def test_prompt_pairwise():
44
  resp = make_response([" {\"winner\": \"A\"} "])
45
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
46
- result = tu.prompt_pairwise('instr', 'block', 'A text', 'B text', model='m', api_base='b', api_key='k')
47
  mock_comp.assert_called_once()
48
  assert mock_comp.call_args.kwargs['api_base'] == 'b'
49
  assert mock_comp.call_args.kwargs['api_key'] == 'k'
 
50
  assert result == '{"winner": "A"}'
 
25
  def test_generate_players():
26
  resp = make_response([" player1 ", "player2\n"])
27
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
28
+ players = tu.generate_players('instr', 2, model='m', api_base='b', api_key='k', temperature=0.5)
29
+ mock_comp.assert_called_once_with(model='m', messages=[{'role': 'user', 'content': 'instr'}], n=2, api_base='b', api_key='k', temperature=0.5)
30
  assert players == ['player1', 'player2']
31
 
32
 
33
  def test_prompt_score():
34
  resp = make_response([" {\"score\": [5]} "])
35
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
36
+ result = tu.prompt_score('instr', ['c1'], 'block', 'pl', model='m', api_base='b', api_key='k', temperature=0.2, include_instruction=False)
37
  mock_comp.assert_called_once()
38
  assert mock_comp.call_args.kwargs['api_base'] == 'b'
39
  assert mock_comp.call_args.kwargs['api_key'] == 'k'
40
+ assert mock_comp.call_args.kwargs['temperature'] == 0.2
41
  assert result == '{"score": [5]}'
42
 
43
 
44
  def test_prompt_pairwise():
45
  resp = make_response([" {\"winner\": \"A\"} "])
46
  with patch('tournament_utils.completion', return_value=resp) as mock_comp:
47
+ result = tu.prompt_pairwise('instr', 'block', 'A text', 'B text', model='m', api_base='b', api_key='k', temperature=0.3, include_instruction=False)
48
  mock_comp.assert_called_once()
49
  assert mock_comp.call_args.kwargs['api_base'] == 'b'
50
  assert mock_comp.call_args.kwargs['api_key'] == 'k'
51
+ assert mock_comp.call_args.kwargs['temperature'] == 0.3
52
  assert result == '{"winner": "A"}'
tournament_utils.py CHANGED
@@ -1,13 +1,19 @@
1
  from litellm import completion
2
 
3
 
4
- def _completion_kwargs(api_base: str | None, api_key: str | None) -> dict:
 
 
 
 
5
  """Build kwargs for litellm.completion from api settings."""
6
  kwargs: dict = {}
7
  if api_base:
8
  kwargs["api_base"] = api_base
9
  if api_key:
10
  kwargs["api_key"] = api_key
 
 
11
  return kwargs
12
 
13
 
@@ -18,6 +24,7 @@ def generate_players(
18
  *,
19
  api_base: str | None = None,
20
  api_key: str | None = None,
 
21
  return_usage: bool = False,
22
  ) -> list[str] | tuple[list[str], object]:
23
  """Request ``n`` completions for the instruction using the given model.
@@ -29,7 +36,7 @@ def generate_players(
29
  model=model,
30
  messages=[{"role": "user", "content": instruction}],
31
  n=n,
32
- **_completion_kwargs(api_base, api_key),
33
  )
34
  players = [c.message.content.strip() for c in response.choices]
35
  if return_usage:
@@ -46,6 +53,8 @@ def prompt_score(
46
  *,
47
  api_base: str | None = None,
48
  api_key: str | None = None,
 
 
49
  return_usage: bool = False,
50
  ) -> str | tuple[str, object]:
51
  """Return a JSON score string evaluating `player` on the criteria."""
@@ -53,17 +62,14 @@ def prompt_score(
53
  prompt = f"""Evaluate the output below on the following criteria:
54
  {criteria_block}
55
 
56
- Return JSON exactly like: {{"scores": [{example_scores}]}}.
57
-
58
- Instruction:
59
- {instruction}
60
-
61
- Output:
62
- {player}"""
63
  response = completion(
64
  model=model,
65
  messages=[{"role": "system", "content": prompt}],
66
- **_completion_kwargs(api_base, api_key),
67
  )
68
  text = response.choices[0].message.content.strip()
69
  if return_usage:
@@ -80,24 +86,22 @@ def prompt_pairwise(
80
  *,
81
  api_base: str | None = None,
82
  api_key: str | None = None,
 
 
83
  return_usage: bool = False,
84
  ) -> str | tuple[str, object]:
85
  """Return which player wins in JSON using the given criteria."""
86
  prompt = f"""Compare the two players below using:
87
  {criteria_block}
88
 
89
- Return ONLY JSON {{"winner": "A"}} or {{"winner": "B"}}.
90
-
91
- Instruction:
92
- {instruction}
93
-
94
- Players:
95
- <A>{a}</A>
96
- <B>{b}</B>"""
97
  response = completion(
98
  model=model,
99
  messages=[{"role": "system", "content": prompt}],
100
- **_completion_kwargs(api_base, api_key),
101
  )
102
  text = response.choices[0].message.content.strip()
103
  if return_usage:
 
1
  from litellm import completion
2
 
3
 
4
+ def _completion_kwargs(
5
+ api_base: str | None,
6
+ api_key: str | None,
7
+ temperature: float | None,
8
+ ) -> dict:
9
  """Build kwargs for litellm.completion from api settings."""
10
  kwargs: dict = {}
11
  if api_base:
12
  kwargs["api_base"] = api_base
13
  if api_key:
14
  kwargs["api_key"] = api_key
15
+ if temperature is not None:
16
+ kwargs["temperature"] = temperature
17
  return kwargs
18
 
19
 
 
24
  *,
25
  api_base: str | None = None,
26
  api_key: str | None = None,
27
+ temperature: float | None = None,
28
  return_usage: bool = False,
29
  ) -> list[str] | tuple[list[str], object]:
30
  """Request ``n`` completions for the instruction using the given model.
 
36
  model=model,
37
  messages=[{"role": "user", "content": instruction}],
38
  n=n,
39
+ **_completion_kwargs(api_base, api_key, temperature),
40
  )
41
  players = [c.message.content.strip() for c in response.choices]
42
  if return_usage:
 
53
  *,
54
  api_base: str | None = None,
55
  api_key: str | None = None,
56
+ temperature: float | None = None,
57
+ include_instruction: bool = True,
58
  return_usage: bool = False,
59
  ) -> str | tuple[str, object]:
60
  """Return a JSON score string evaluating `player` on the criteria."""
 
62
  prompt = f"""Evaluate the output below on the following criteria:
63
  {criteria_block}
64
 
65
+ Return JSON exactly like: {{"scores": [{example_scores}]}}."""
66
+ if include_instruction:
67
+ prompt += f"\n\nInstruction:\n{instruction}"
68
+ prompt += f"\n\nOutput:\n{player}"
 
 
 
69
  response = completion(
70
  model=model,
71
  messages=[{"role": "system", "content": prompt}],
72
+ **_completion_kwargs(api_base, api_key, temperature),
73
  )
74
  text = response.choices[0].message.content.strip()
75
  if return_usage:
 
86
  *,
87
  api_base: str | None = None,
88
  api_key: str | None = None,
89
+ temperature: float | None = None,
90
+ include_instruction: bool = True,
91
  return_usage: bool = False,
92
  ) -> str | tuple[str, object]:
93
  """Return which player wins in JSON using the given criteria."""
94
  prompt = f"""Compare the two players below using:
95
  {criteria_block}
96
 
97
+ Return ONLY JSON {{"winner": "A"}} or {{"winner": "B"}}."""
98
+ if include_instruction:
99
+ prompt += f"\n\nInstruction:\n{instruction}"
100
+ prompt += f"\n\nPlayers:\n<A>{a}</A>\n<B>{b}</B>"
 
 
 
 
101
  response = completion(
102
  model=model,
103
  messages=[{"role": "system", "content": prompt}],
104
+ **_completion_kwargs(api_base, api_key, temperature),
105
  )
106
  text = response.choices[0].message.content.strip()
107
  if return_usage: