ping98k committed
Commit 249284d · 1 Parent(s): dc010c0

Add full loop test for run_tournament
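Beyond the new test, this commit makes the generation, scoring, and pairwise models independently configurable (`GENERATE_MODEL`, `SCORE_MODEL`, and `PAIRWISE_MODEL` as env vars, UI fields, and `run_tournament` parameters), threads `api_base`/`api_key` through each helper call instead of mutating `os.environ`, and renames `prompt_play` to `prompt_pairwise`.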

README.md CHANGED
@@ -11,6 +11,9 @@ This project provides a small interface for running "tournaments" between langua
 - `NUM_GENERATIONS`
 - `OPENAI_API_BASE`
 - `OPENAI_API_KEY`
+- `GENERATE_MODEL`
+- `SCORE_MODEL`
+- `PAIRWISE_MODEL`
 - `ENABLE_SCORE_FILTER`
 - `ENABLE_PAIRWISE_FILTER`
 2. Install dependencies (example with `pip`):
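For reference, the three new variables default to `gpt-4o-mini` in main.py; a minimal sketch of setting them before launch (equivalent to exporting them in the shell or a `.env` file):

```python
import os

# Optional per-stage model overrides; anything left unset falls back to
# "gpt-4o-mini" via the os.getenv(...) defaults in main.py.
os.environ.setdefault("GENERATE_MODEL", "gpt-4o-mini")
os.environ.setdefault("SCORE_MODEL", "gpt-4o-mini")
os.environ.setdefault("PAIRWISE_MODEL", "gpt-4o-mini")
```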
main.py CHANGED
@@ -4,7 +4,7 @@ import os, json, re, ast, gradio as gr
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from tqdm import tqdm
 import matplotlib.pyplot as plt
-from tournament_utils import generate_players, prompt_score, prompt_play
+from tournament_utils import generate_players, prompt_score, prompt_pairwise
 
 NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 3))
 POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 5))
@@ -14,6 +14,9 @@ API_BASE_DEFAULT = os.getenv("OPENAI_API_BASE", "")
 API_TOKEN_DEFAULT = os.getenv("OPENAI_API_KEY", "")
 SCORE_FILTER_DEFAULT = os.getenv("ENABLE_SCORE_FILTER", "true").lower() == "true"
 PAIRWISE_FILTER_DEFAULT = os.getenv("ENABLE_PAIRWISE_FILTER", "true").lower() == "true"
+GENERATE_MODEL_DEFAULT = os.getenv("GENERATE_MODEL", "gpt-4o-mini")
+SCORE_MODEL_DEFAULT = os.getenv("SCORE_MODEL", "gpt-4o-mini")
+PAIRWISE_MODEL_DEFAULT = os.getenv("PAIRWISE_MODEL", "gpt-4o-mini")
 
 def _clean_json(txt):
     txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
@@ -25,6 +28,9 @@ def _clean_json(txt):
 def run_tournament(
     api_base,
     api_token,
+    generate_model,
+    score_model,
+    pairwise_model,
     instruction_input,
     criteria_input,
     n_gen,
@@ -40,10 +46,16 @@ def run_tournament(
     num_top_picks = int(num_top_picks)
     pool_size = int(pool_size)
     max_workers = int(max_workers)
-    if api_base:
-        os.environ["OPENAI_API_BASE"] = api_base
-    if api_token:
-        os.environ["OPENAI_API_KEY"] = api_token
+    if not api_base:
+        api_base = API_BASE_DEFAULT
+    if not api_token:
+        api_token = API_TOKEN_DEFAULT
+    if not generate_model:
+        generate_model = GENERATE_MODEL_DEFAULT
+    if not score_model:
+        score_model = SCORE_MODEL_DEFAULT
+    if not pairwise_model:
+        pairwise_model = PAIRWISE_MODEL_DEFAULT
     enable_score_filter = bool(enable_score_filter)
     enable_pairwise_filter = bool(enable_pairwise_filter)
     process_log = []
@@ -54,7 +66,13 @@ def run_tournament(
         tqdm.write(msg)
         yield "\n".join(process_log), hist_fig, top_picks_str
     yield from log("Generating players …")
-    all_players = generate_players(instruction, n_gen)
+    all_players = generate_players(
+        instruction,
+        n_gen,
+        model=generate_model,
+        api_base=api_base,
+        api_key=api_token,
+    )
     yield from log(f"{len(all_players)} players generated")
     def criteria_block():
         return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
@@ -62,7 +80,15 @@ def run_tournament(
     if enable_score_filter:
         def score(player):
             data = _clean_json(
-                prompt_score(instruction, criteria_list, criteria_block(), player)
+                prompt_score(
+                    instruction,
+                    criteria_list,
+                    criteria_block(),
+                    player,
+                    model=score_model,
+                    api_base=api_base,
+                    api_key=api_token,
+                )
             )
             if "scores" in data and isinstance(data["scores"], list):
                 vals = data["scores"]
@@ -88,7 +114,15 @@ def run_tournament(
     if enable_pairwise_filter:
         def play(a, b):
             winner_label = _clean_json(
-                prompt_play(instruction, criteria_block(), a, b)
+                prompt_pairwise(
+                    instruction,
+                    criteria_block(),
+                    a,
+                    b,
+                    model=pairwise_model,
+                    api_base=api_base,
+                    api_key=api_token,
+                )
             ).get("winner", "A")
             return a if winner_label == "A" else b
 
@@ -151,6 +185,9 @@ demo = gr.Interface(
     inputs=[
         gr.Textbox(value=API_BASE_DEFAULT, label="API Base Path"),
         gr.Textbox(value="", label="API Token", type="password"),
+        gr.Textbox(value=GENERATE_MODEL_DEFAULT, label="Generation Model"),
+        gr.Textbox(value=SCORE_MODEL_DEFAULT, label="Score Model"),
+        gr.Textbox(value=PAIRWISE_MODEL_DEFAULT, label="Pairwise Model"),
        gr.Textbox(lines=10, label="Instruction"),
        gr.Textbox(lines=5, label="Criteria (comma separated)"),
        gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
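Since `run_tournament` is a generator yielding `(process_log, hist_fig, top_picks_str)` tuples (which is how the Gradio UI streams progress), it can also be driven headlessly, mirroring what the new test does. A sketch with placeholder inputs; empty strings fall back to the `*_DEFAULT` values:

```python
import main

# Drain the generator outside Gradio; each yield is the current
# (process_log, hist_fig, top_picks_str) snapshot.
for log_text, fig, top_picks in main.run_tournament(
    api_base="",                # "" -> API_BASE_DEFAULT
    api_token="",               # "" -> API_TOKEN_DEFAULT
    generate_model="",          # "" -> GENERATE_MODEL_DEFAULT
    score_model="",
    pairwise_model="",
    instruction_input="Write a tagline for a coffee shop.",  # placeholder input
    criteria_input="clarity, originality",
    n_gen=4,
    pool_size=2,
    num_top_picks=1,
    max_workers=2,
    enable_score_filter=True,
    enable_pairwise_filter=True,
):
    print(log_text.splitlines()[-1])  # latest log line
```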
tests/test_main.py ADDED
@@ -0,0 +1,112 @@
+import sys, os, types, json
+from unittest.mock import patch, MagicMock
+
+# Ensure project root in path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+# Provide dummy litellm module so import succeeds
+fake_litellm = types.ModuleType('litellm')
+fake_litellm.completion = MagicMock()
+sys.modules.setdefault('litellm', fake_litellm)
+
+# Provide dummy dotenv module
+fake_dotenv = types.ModuleType('dotenv')
+fake_dotenv.load_dotenv = MagicMock()
+sys.modules.setdefault('dotenv', fake_dotenv)
+
+# Dummy gradio module so import succeeds
+fake_gradio = types.ModuleType('gradio')
+fake_gradio.Interface = MagicMock(return_value=MagicMock(launch=MagicMock()))
+fake_gradio.Textbox = MagicMock
+fake_gradio.Number = MagicMock
+fake_gradio.Checkbox = MagicMock
+fake_gradio.Plot = MagicMock
+sys.modules.setdefault('gradio', fake_gradio)
+
+# Dummy tqdm module for write method
+class FakeTqdmModule(types.ModuleType):
+    def __init__(self):
+        super().__init__('tqdm')
+        self.write = MagicMock()
+    def __call__(self, iterable=None, total=None):
+        return iterable
+
+fake_tqdm_mod = FakeTqdmModule()
+fake_tqdm_mod.tqdm = fake_tqdm_mod
+sys.modules.setdefault('tqdm', fake_tqdm_mod)
+
+# Dummy matplotlib module
+fake_plt = types.ModuleType('matplotlib.pyplot')
+fake_plt.figure = MagicMock(return_value='fig')
+fake_plt.hist = MagicMock()
+fake_matplotlib = types.ModuleType('matplotlib')
+fake_matplotlib.pyplot = fake_plt
+sys.modules.setdefault('matplotlib', fake_matplotlib)
+sys.modules.setdefault('matplotlib.pyplot', fake_plt)
+
+import main
+
+class DummyFuture:
+    def __init__(self, func, *args):
+        self._func = func
+        self._args = args
+    def result(self):
+        return self._func(*self._args)
+
+class DummyExecutor:
+    def __init__(self, *args, **kwargs):
+        pass
+    def __enter__(self):
+        return self
+    def __exit__(self, exc_type, exc, tb):
+        pass
+    def submit(self, func, *args):
+        return DummyFuture(func, *args)
+    def map(self, func, iterable):
+        for item in iterable:
+            yield func(item)
+
+class DummyTqdm:
+    def __call__(self, iterable=None, total=None):
+        return iterable
+    def write(self, msg):
+        pass
+
+def test_run_tournament_full_loop():
+    dummy_tqdm = DummyTqdm()
+    with patch('main.generate_players') as mock_gen, \
+         patch('main.prompt_score') as mock_score, \
+         patch('main.prompt_pairwise') as mock_pair, \
+         patch('main.ThreadPoolExecutor', return_value=DummyExecutor()) as MockExec, \
+         patch('main.as_completed', new=lambda futs: futs), \
+         patch('main.tqdm', new=dummy_tqdm), \
+         patch('main.plt.figure', return_value='fig'), \
+         patch('main.plt.hist'):
+        mock_gen.return_value = ['p1', 'p2', 'p3', 'p4']
+        scores = {'p1': 3, 'p2': 2, 'p3': 1, 'p4': 0}
+        mock_score.side_effect = lambda instr, cl, block, player, **kw: json.dumps({'score': scores[player]})
+        mock_pair.side_effect = lambda instr, block, a, b, **kw: json.dumps({'winner': 'A'})
+
+        results = list(main.run_tournament(
+            api_base='b',
+            api_token='k',
+            generate_model='gm',
+            score_model='sm',
+            pairwise_model='pm',
+            instruction_input='instr',
+            criteria_input='c1,c2',
+            n_gen=4,
+            pool_size=2,
+            num_top_picks=1,
+            max_workers=1,
+            enable_score_filter=True,
+            enable_pairwise_filter=True,
+        ))
+
+        process_log, hist_fig, top_picks = results[-1]
+        assert 'Done' in process_log
+        assert hist_fig == 'fig'
+        assert top_picks.strip() in {'p1', 'p2'}
+        mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k')
+        assert mock_score.call_count == 4
+        assert mock_pair.called
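The test's key trick is installing stub modules in `sys.modules` before `import main`, so importing the app never pulls in litellm, gradio, dotenv, or matplotlib. A minimal standalone sketch of the technique (the module name `heavydep` is hypothetical):

```python
import sys, types
from unittest.mock import MagicMock

# Register a fake module under the dependency's import name *before*
# any code under test imports it; `import heavydep` then returns the stub.
fake = types.ModuleType('heavydep')
fake.expensive_call = MagicMock(return_value='stubbed')
sys.modules.setdefault('heavydep', fake)

import heavydep
assert heavydep.expensive_call() == 'stubbed'
```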
tests/test_tournament_utils.py CHANGED
@@ -25,22 +25,26 @@ def make_response(contents):
 def test_generate_players():
     resp = make_response([" player1 ", "player2\n"])
     with patch('tournament_utils.completion', return_value=resp) as mock_comp:
-        players = tu.generate_players('instr', 2, model='m')
-        mock_comp.assert_called_once_with(model='m', messages=[{'role': 'user', 'content': 'instr'}], n=2)
+        players = tu.generate_players('instr', 2, model='m', api_base='b', api_key='k')
+        mock_comp.assert_called_once_with(model='m', messages=[{'role': 'user', 'content': 'instr'}], n=2, api_base='b', api_key='k')
     assert players == ['player1', 'player2']
 
 
 def test_prompt_score():
     resp = make_response([" {\"score\": [5]} "])
     with patch('tournament_utils.completion', return_value=resp) as mock_comp:
-        result = tu.prompt_score('instr', ['c1'], 'block', 'pl', model='m')
+        result = tu.prompt_score('instr', ['c1'], 'block', 'pl', model='m', api_base='b', api_key='k')
     mock_comp.assert_called_once()
+    assert mock_comp.call_args.kwargs['api_base'] == 'b'
+    assert mock_comp.call_args.kwargs['api_key'] == 'k'
     assert result == '{"score": [5]}'
 
 
-def test_prompt_play():
+def test_prompt_pairwise():
     resp = make_response([" {\"winner\": \"A\"} "])
     with patch('tournament_utils.completion', return_value=resp) as mock_comp:
-        result = tu.prompt_play('instr', 'block', 'A text', 'B text', model='m')
+        result = tu.prompt_pairwise('instr', 'block', 'A text', 'B text', model='m', api_base='b', api_key='k')
     mock_comp.assert_called_once()
+    assert mock_comp.call_args.kwargs['api_base'] == 'b'
+    assert mock_comp.call_args.kwargs['api_key'] == 'k'
     assert result == '{"winner": "A"}'
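The new assertions read keyword arguments back through `call_args.kwargs`, a mock idiom available since Python 3.8 (older code indexes `call_args[1]`). A tiny standalone illustration:

```python
from unittest.mock import MagicMock

m = MagicMock()
m('pos', model='m', api_base='b', api_key='k')
# call_args exposes the last call's positional and keyword arguments.
assert m.call_args.args == ('pos',)
assert m.call_args.kwargs == {'model': 'm', 'api_base': 'b', 'api_key': 'k'}
```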
tournament_utils.py CHANGED
@@ -1,12 +1,30 @@
 from litellm import completion
 
 
-def generate_players(instruction: str, n: int, model: str = "gpt-4o-mini"):
+def _completion_kwargs(api_base: str | None, api_key: str | None) -> dict:
+    """Build kwargs for litellm.completion from api settings."""
+    kwargs: dict = {}
+    if api_base:
+        kwargs["api_base"] = api_base
+    if api_key:
+        kwargs["api_key"] = api_key
+    return kwargs
+
+
+def generate_players(
+    instruction: str,
+    n: int,
+    model: str = "gpt-4o-mini",
+    *,
+    api_base: str | None = None,
+    api_key: str | None = None,
+):
     """Request `n` completions for the instruction using the given model."""
     response = completion(
         model=model,
         messages=[{"role": "user", "content": instruction}],
         n=n,
+        **_completion_kwargs(api_base, api_key),
     )
     return [c.message.content.strip() for c in response.choices]
 
@@ -17,6 +35,9 @@ def prompt_score(
     criteria_block: str,
     player: str,
     model: str = "gpt-4o-mini",
+    *,
+    api_base: str | None = None,
+    api_key: str | None = None,
 ) -> str:
     """Return a JSON score string evaluating `player` on the criteria."""
     example_scores = ", ".join(["1-10"] * len(criteria_list)) or "1-10"
@@ -30,11 +51,24 @@ Instruction:
 
 Output:
 {player}"""
-    response = completion(model=model, messages=[{"role": "system", "content": prompt}])
+    response = completion(
+        model=model,
+        messages=[{"role": "system", "content": prompt}],
+        **_completion_kwargs(api_base, api_key),
+    )
     return response.choices[0].message.content.strip()
 
 
-def prompt_play(instruction: str, criteria_block: str, a: str, b: str, model: str = "gpt-4o-mini") -> str:
+def prompt_pairwise(
+    instruction: str,
+    criteria_block: str,
+    a: str,
+    b: str,
+    model: str = "gpt-4o-mini",
+    *,
+    api_base: str | None = None,
+    api_key: str | None = None,
+) -> str:
     """Return which player wins in JSON using the given criteria."""
     prompt = f"""Compare the two players below using:
 {criteria_block}
@@ -47,5 +81,9 @@ Instruction:
 Players:
 <A>{a}</A>
 <B>{b}</B>"""
-    response = completion(model=model, messages=[{"role": "system", "content": prompt}])
+    response = completion(
+        model=model,
+        messages=[{"role": "system", "content": prompt}],
+        **_completion_kwargs(api_base, api_key),
+    )
     return response.choices[0].message.content.strip()
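With the new keyword-only parameters, credentials travel with each call instead of through `os.environ`. A hedged usage sketch of the updated helpers; the endpoint URL and key are placeholders, and litellm must be able to reach an OpenAI-compatible server there:

```python
from tournament_utils import generate_players, prompt_pairwise

# Point the helpers at one OpenAI-compatible endpoint (placeholder values).
players = generate_players(
    "Write a one-line product slogan.",
    3,
    model="gpt-4o-mini",
    api_base="http://localhost:8000/v1",  # assumption: local OpenAI-compatible server
    api_key="sk-placeholder",
)

verdict = prompt_pairwise(
    "Write a one-line product slogan.",
    "1) clarity\n2) punchiness",
    players[0],
    players[1],
    model="gpt-4o-mini",
    api_base="http://localhost:8000/v1",
    api_key="sk-placeholder",
)
print(verdict)  # JSON like {"winner": "A"} per the prompt's contract
```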