ping98k commited on
Commit
dc010c0
·
unverified ·
2 Parent(s): f94af77 a959b3c

Merge pull request #2 from ping98k/codex/add-inputs-for-configuration-overrides

Browse files
Files changed (4) hide show
  1. README.md +6 -2
  2. main.py +119 -66
  3. tests/test_tournament_utils.py +46 -0
  4. tournament_utils.py +9 -2
README.md CHANGED
@@ -9,6 +9,10 @@ This project provides a small interface for running "tournaments" between langua
9
  - `POOL_SIZE`
10
  - `MAX_WORKERS`
11
  - `NUM_GENERATIONS`
 
 
 
 
12
  2. Install dependencies (example with `pip`):
13
  ```bash
14
  pip install gradio litellm python-dotenv tqdm matplotlib
@@ -17,7 +21,7 @@ This project provides a small interface for running "tournaments" between langua
17
  ```bash
18
  python main.py
19
  ```
20
- 4. Open the displayed local URL to provide an instruction and evaluation criteria.
21
 
22
- The interface will generate multiple answers, score them, and run a head-to-head tournament to find the best outputs.
23
 
 
9
  - `POOL_SIZE`
10
  - `MAX_WORKERS`
11
  - `NUM_GENERATIONS`
12
+ - `OPENAI_API_BASE`
13
+ - `OPENAI_API_KEY`
14
+ - `ENABLE_SCORE_FILTER`
15
+ - `ENABLE_PAIRWISE_FILTER`
16
  2. Install dependencies (example with `pip`):
17
  ```bash
18
  pip install gradio litellm python-dotenv tqdm matplotlib
 
21
  ```bash
22
  python main.py
23
  ```
24
+ 4. Open the displayed local URL. At the top of the page you can optionally override the API base path and token (the token field is blank by default). Additional settings let you configure score and pairwise filtering.
25
 
26
+ The interface will generate multiple answers, optionally filter them by score, and run a pairwise tournament to select the best outputs.
27
 
main.py CHANGED
@@ -6,10 +6,14 @@ from tqdm import tqdm
6
  import matplotlib.pyplot as plt
7
  from tournament_utils import generate_players, prompt_score, prompt_play
8
 
9
- NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 5))
10
- POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 10))
11
  MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 10))
12
- NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS", 20))
 
 
 
 
13
 
14
  def _clean_json(txt):
15
  txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
@@ -18,13 +22,30 @@ def _clean_json(txt):
18
  except json.JSONDecodeError:
19
  return ast.literal_eval(txt)
20
 
21
- def run_tournament(instruction_input, criteria_input, n_gen, num_top_picks, pool_size, max_workers):
 
 
 
 
 
 
 
 
 
 
 
22
  instruction = instruction_input.strip()
23
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
24
  n_gen = int(n_gen)
25
  num_top_picks = int(num_top_picks)
26
  pool_size = int(pool_size)
27
  max_workers = int(max_workers)
 
 
 
 
 
 
28
  process_log = []
29
  hist_fig = None
30
  top_picks_str = ""
@@ -38,81 +59,113 @@ def run_tournament(instruction_input, criteria_input, n_gen, num_top_picks, pool
38
  def criteria_block():
39
  return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
40
 
41
- def score(player):
42
- data = _clean_json(prompt_score(instruction, criteria_block(), player))
43
- lst = data.get("score", data.get("scores", []))
44
- return sum(lst) / len(lst) if lst else 0.0
45
- yield from log("Scoring players …")
46
- with ThreadPoolExecutor(max_workers=max_workers) as ex:
47
- scores = {p: s for p, s in zip(all_players, list(tqdm(ex.map(score, all_players), total=len(all_players))))}
48
- hist_fig = plt.figure()
49
- plt.hist(list(scores.values()), bins=10)
50
- yield from log("Histogram generated")
51
- top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
52
- yield from log(f"Filtered to {len(top_players)} players with best scores")
53
- def play(a, b):
54
- winner_label = _clean_json(
55
- prompt_play(instruction, criteria_block(), a, b)
56
- ).get("winner", "A")
57
- return a if winner_label == "A" else b
58
- def tournament_round(pairs, executor):
59
- futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
60
- results = []
61
- for fut in tqdm(as_completed(futures), total=len(futures)):
62
- a, b = futures[fut]
63
- winner = fut.result()
64
- loser = b if winner == a else a
65
- results.append((winner, loser))
66
- return results
67
- def tournament(players, executor):
68
- lost_to = {}
69
- current = players[:]
70
- while len(current) > 1:
71
- pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
72
- round_results = tournament_round(pairs, executor)
73
- for w, l in round_results:
74
- lost_to[l] = w
75
- current = [w for w, _ in round_results]
76
- if len(players) % 2 == 1 and players[-1] not in current:
77
- current.append(players[-1])
78
- return current[0], lost_to
79
- def get_candidates(champion, lost_to):
80
- return [p for p, o in lost_to.items() if o == champion] + [champion]
81
- def playoff(candidates, executor):
82
- wins = {p: 0 for p in candidates}
83
- pairs = [(candidates[i], candidates[j]) for i in range(len(candidates)) for j in range(i + 1, len(candidates))]
84
- futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
85
- for fut in tqdm(as_completed(futures), total=len(futures)):
86
- wins[fut.result()] += 1
87
- return sorted(candidates, key=lambda p: wins[p], reverse=True)
88
- def get_top(players, executor):
89
- champion, lost_to = tournament(players, executor)
90
- runner_up = lost_to.get(champion)
91
- finalists = [champion] + ([runner_up] if runner_up else [])
92
- semifinalists = [p for p, o in lost_to.items() if o in finalists and p not in finalists]
93
- candidates = list(set(finalists + semifinalists + get_candidates(champion, lost_to)))
94
- return playoff(candidates, executor)[:num_top_picks]
95
- yield from log("Running tournament …")
96
- with ThreadPoolExecutor(max_workers=max_workers) as ex:
97
- top_k = get_top(top_players, ex)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
99
  yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str
100
 
101
  demo = gr.Interface(
102
  fn=run_tournament,
103
  inputs=[
 
 
104
  gr.Textbox(lines=10, label="Instruction"),
105
  gr.Textbox(lines=5, label="Criteria (comma separated)"),
106
  gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
107
- gr.Number(value=NUM_TOP_PICKS_DEFAULT, label="Top Picks (k)"),
108
- gr.Number(value=POOL_SIZE_DEFAULT, label="Filter Size"),
109
- gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers")
 
 
110
  ],
111
  outputs=[
112
  gr.Textbox(lines=10, label="Process"),
113
  gr.Plot(label="Score Distribution"),
114
- gr.Textbox(lines=50, label="Top picks")
115
- ]
 
116
  )
117
 
118
  if __name__ == "__main__":
 
6
  import matplotlib.pyplot as plt
7
  from tournament_utils import generate_players, prompt_score, prompt_play
8
 
9
+ NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 3))
10
+ POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 5))
11
  MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 10))
12
+ NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS", 10))
13
+ API_BASE_DEFAULT = os.getenv("OPENAI_API_BASE", "")
14
+ API_TOKEN_DEFAULT = os.getenv("OPENAI_API_KEY", "")
15
+ SCORE_FILTER_DEFAULT = os.getenv("ENABLE_SCORE_FILTER", "true").lower() == "true"
16
+ PAIRWISE_FILTER_DEFAULT = os.getenv("ENABLE_PAIRWISE_FILTER", "true").lower() == "true"
17
 
18
  def _clean_json(txt):
19
  txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
 
22
  except json.JSONDecodeError:
23
  return ast.literal_eval(txt)
24
 
25
+ def run_tournament(
26
+ api_base,
27
+ api_token,
28
+ instruction_input,
29
+ criteria_input,
30
+ n_gen,
31
+ pool_size,
32
+ num_top_picks,
33
+ max_workers,
34
+ enable_score_filter,
35
+ enable_pairwise_filter,
36
+ ):
37
  instruction = instruction_input.strip()
38
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
39
  n_gen = int(n_gen)
40
  num_top_picks = int(num_top_picks)
41
  pool_size = int(pool_size)
42
  max_workers = int(max_workers)
43
+ if api_base:
44
+ os.environ["OPENAI_API_BASE"] = api_base
45
+ if api_token:
46
+ os.environ["OPENAI_API_KEY"] = api_token
47
+ enable_score_filter = bool(enable_score_filter)
48
+ enable_pairwise_filter = bool(enable_pairwise_filter)
49
  process_log = []
50
  hist_fig = None
51
  top_picks_str = ""
 
59
  def criteria_block():
60
  return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
61
 
62
+ if enable_score_filter:
63
+ def score(player):
64
+ data = _clean_json(
65
+ prompt_score(instruction, criteria_list, criteria_block(), player)
66
+ )
67
+ if "scores" in data and isinstance(data["scores"], list):
68
+ vals = data["scores"]
69
+ return sum(vals) / len(vals) if vals else 0.0
70
+ return float(data.get("score", 0))
71
+
72
+ yield from log("Scoring players …")
73
+ with ThreadPoolExecutor(max_workers=max_workers) as ex:
74
+ scores = {
75
+ p: s
76
+ for p, s in zip(
77
+ all_players,
78
+ list(tqdm(ex.map(score, all_players), total=len(all_players))),
79
+ )
80
+ }
81
+ hist_fig = plt.figure()
82
+ plt.hist(list(scores.values()), bins=10)
83
+ yield from log("Histogram generated")
84
+ top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
85
+ yield from log(f"Filtered to {len(top_players)} players with best scores")
86
+ else:
87
+ top_players = all_players
88
+ if enable_pairwise_filter:
89
+ def play(a, b):
90
+ winner_label = _clean_json(
91
+ prompt_play(instruction, criteria_block(), a, b)
92
+ ).get("winner", "A")
93
+ return a if winner_label == "A" else b
94
+
95
+ def tournament_round(pairs, executor):
96
+ futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
97
+ results = []
98
+ for fut in tqdm(as_completed(futures), total=len(futures)):
99
+ a, b = futures[fut]
100
+ winner = fut.result()
101
+ loser = b if winner == a else a
102
+ results.append((winner, loser))
103
+ return results
104
+
105
+ def tournament(players, executor):
106
+ lost_to = {}
107
+ current = players[:]
108
+ while len(current) > 1:
109
+ pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
110
+ round_results = tournament_round(pairs, executor)
111
+ for w, l in round_results:
112
+ lost_to[l] = w
113
+ current = [w for w, _ in round_results]
114
+ if len(players) % 2 == 1 and players[-1] not in current:
115
+ current.append(players[-1])
116
+ return current[0], lost_to
117
+
118
+ def get_candidates(champion, lost_to):
119
+ return [p for p, o in lost_to.items() if o == champion] + [champion]
120
+
121
+ def playoff(candidates, executor):
122
+ wins = {p: 0 for p in candidates}
123
+ pairs = [
124
+ (candidates[i], candidates[j])
125
+ for i in range(len(candidates))
126
+ for j in range(i + 1, len(candidates))
127
+ ]
128
+ futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
129
+ for fut in tqdm(as_completed(futures), total=len(futures)):
130
+ wins[fut.result()] += 1
131
+ return sorted(candidates, key=lambda p: wins[p], reverse=True)
132
+
133
+ def get_top(players, executor):
134
+ champion, lost_to = tournament(players, executor)
135
+ runner_up = lost_to.get(champion)
136
+ finalists = [champion] + ([runner_up] if runner_up else [])
137
+ semifinalists = [p for p, o in lost_to.items() if o in finalists and p not in finalists]
138
+ candidates = list(set(finalists + semifinalists + get_candidates(champion, lost_to)))
139
+ return playoff(candidates, executor)[:num_top_picks]
140
+
141
+ yield from log("Running tournament …")
142
+ with ThreadPoolExecutor(max_workers=max_workers) as ex:
143
+ top_k = get_top(top_players, ex)
144
+ else:
145
+ top_k = top_players[:num_top_picks]
146
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
147
  yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str
148
 
149
  demo = gr.Interface(
150
  fn=run_tournament,
151
  inputs=[
152
+ gr.Textbox(value=API_BASE_DEFAULT, label="API Base Path"),
153
+ gr.Textbox(value="", label="API Token", type="password"),
154
  gr.Textbox(lines=10, label="Instruction"),
155
  gr.Textbox(lines=5, label="Criteria (comma separated)"),
156
  gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
157
+ gr.Number(value=POOL_SIZE_DEFAULT, label="Top Picks Score Filter"),
158
+ gr.Number(value=NUM_TOP_PICKS_DEFAULT, label="Top Picks Pairwise"),
159
+ gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers"),
160
+ gr.Checkbox(value=SCORE_FILTER_DEFAULT, label="Enable Score Filter"),
161
+ gr.Checkbox(value=PAIRWISE_FILTER_DEFAULT, label="Enable Pairwise Filter"),
162
  ],
163
  outputs=[
164
  gr.Textbox(lines=10, label="Process"),
165
  gr.Plot(label="Score Distribution"),
166
+ gr.Textbox(lines=50, label="Top picks"),
167
+ ],
168
+ description="Generate multiple completions and use score and pairwise filters to find the best answers.",
169
  )
170
 
171
  if __name__ == "__main__":
tests/test_tournament_utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os, types
2
+ from unittest.mock import patch, MagicMock
3
+
4
+ # Ensure project root in path
5
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
6
+
7
+ # Provide dummy litellm module so import succeeds
8
+ fake_litellm = types.ModuleType('litellm')
9
+ fake_litellm.completion = MagicMock()
10
+ sys.modules.setdefault('litellm', fake_litellm)
11
+
12
+ import tournament_utils as tu
13
+
14
+
15
+ def make_response(contents):
16
+ class Message:
17
+ def __init__(self, content):
18
+ self.content = content
19
+ class Choice:
20
+ def __init__(self, content):
21
+ self.message = Message(content)
22
+ return MagicMock(choices=[Choice(c) for c in contents])
23
+
24
+
25
+ def test_generate_players():
26
+ resp = make_response([" player1 ", "player2\n"])
27
+ with patch('tournament_utils.completion', return_value=resp) as mock_comp:
28
+ players = tu.generate_players('instr', 2, model='m')
29
+ mock_comp.assert_called_once_with(model='m', messages=[{'role': 'user', 'content': 'instr'}], n=2)
30
+ assert players == ['player1', 'player2']
31
+
32
+
33
+ def test_prompt_score():
34
+ resp = make_response([" {\"score\": [5]} "])
35
+ with patch('tournament_utils.completion', return_value=resp) as mock_comp:
36
+ result = tu.prompt_score('instr', ['c1'], 'block', 'pl', model='m')
37
+ mock_comp.assert_called_once()
38
+ assert result == '{"score": [5]}'
39
+
40
+
41
+ def test_prompt_play():
42
+ resp = make_response([" {\"winner\": \"A\"} "])
43
+ with patch('tournament_utils.completion', return_value=resp) as mock_comp:
44
+ result = tu.prompt_play('instr', 'block', 'A text', 'B text', model='m')
45
+ mock_comp.assert_called_once()
46
+ assert result == '{"winner": "A"}'
tournament_utils.py CHANGED
@@ -11,12 +11,19 @@ def generate_players(instruction: str, n: int, model: str = "gpt-4o-mini"):
11
  return [c.message.content.strip() for c in response.choices]
12
 
13
 
14
- def prompt_score(instruction: str, criteria_block: str, player: str, model: str = "gpt-4o-mini") -> str:
 
 
 
 
 
 
15
  """Return a JSON score string evaluating `player` on the criteria."""
 
16
  prompt = f"""Evaluate the output below on the following criteria:
17
  {criteria_block}
18
 
19
- Return JSON exactly like: {{"score": [1-10]}}.
20
 
21
  Instruction:
22
  {instruction}
 
11
  return [c.message.content.strip() for c in response.choices]
12
 
13
 
14
+ def prompt_score(
15
+ instruction: str,
16
+ criteria_list: list[str],
17
+ criteria_block: str,
18
+ player: str,
19
+ model: str = "gpt-4o-mini",
20
+ ) -> str:
21
  """Return a JSON score string evaluating `player` on the criteria."""
22
+ example_scores = ", ".join(["1-10"] * len(criteria_list)) or "1-10"
23
  prompt = f"""Evaluate the output below on the following criteria:
24
  {criteria_block}
25
 
26
+ Return JSON exactly like: {{"scores": [{example_scores}]}}.
27
 
28
  Instruction:
29
  {instruction}