Spaces:
Sleeping
Sleeping
Merge pull request #2 from ping98k/codex/add-inputs-for-configuration-overrides
Browse files
- README.md +6 -2
- main.py +119 -66
- tests/test_tournament_utils.py +46 -0
- tournament_utils.py +9 -2
README.md
CHANGED
|
@@ -9,6 +9,10 @@ This project provides a small interface for running "tournaments" between langua
|
|
| 9 |
- `POOL_SIZE`
|
| 10 |
- `MAX_WORKERS`
|
| 11 |
- `NUM_GENERATIONS`
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
2. Install dependencies (example with `pip`):
|
| 13 |
```bash
|
| 14 |
pip install gradio litellm python-dotenv tqdm matplotlib
|
|
@@ -17,7 +21,7 @@ This project provides a small interface for running "tournaments" between langua
|
|
| 17 |
```bash
|
| 18 |
python main.py
|
| 19 |
```
|
| 20 |
-
4. Open the displayed local URL
|
| 21 |
|
| 22 |
-
The interface will generate multiple answers,
|
| 23 |
|
|
|
|
| 9 |
- `POOL_SIZE`
|
| 10 |
- `MAX_WORKERS`
|
| 11 |
- `NUM_GENERATIONS`
|
| 12 |
+
- `OPENAI_API_BASE`
|
| 13 |
+
- `OPENAI_API_KEY`
|
| 14 |
+
- `ENABLE_SCORE_FILTER`
|
| 15 |
+
- `ENABLE_PAIRWISE_FILTER`
|
| 16 |
2. Install dependencies (example with `pip`):
|
| 17 |
```bash
|
| 18 |
pip install gradio litellm python-dotenv tqdm matplotlib
|
|
|
|
| 21 |
```bash
|
| 22 |
python main.py
|
| 23 |
```
|
| 24 |
+
4. Open the displayed local URL. At the top of the page you can optionally override the API base path and token (the token field is blank by default). Additional settings let you configure score and pairwise filtering.
|
| 25 |
|
| 26 |
+
The interface will generate multiple answers, optionally filter them by score, and optionally run a pairwise tournament to select the best outputs.
|
| 27 |
|
main.py
CHANGED
|
@@ -6,10 +6,14 @@ from tqdm import tqdm
|
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
from tournament_utils import generate_players, prompt_score, prompt_play
|
| 8 |
|
| 9 |
-
NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS",
|
| 10 |
-
POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE",
|
| 11 |
MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 10))
|
| 12 |
-
NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def _clean_json(txt):
|
| 15 |
txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
|
|
@@ -18,13 +22,30 @@ def _clean_json(txt):
|
|
| 18 |
except json.JSONDecodeError:
|
| 19 |
return ast.literal_eval(txt)
|
| 20 |
|
| 21 |
-
def run_tournament(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
instruction = instruction_input.strip()
|
| 23 |
criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
|
| 24 |
n_gen = int(n_gen)
|
| 25 |
num_top_picks = int(num_top_picks)
|
| 26 |
pool_size = int(pool_size)
|
| 27 |
max_workers = int(max_workers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
process_log = []
|
| 29 |
hist_fig = None
|
| 30 |
top_picks_str = ""
|
|
@@ -38,81 +59,113 @@ def run_tournament(instruction_input, criteria_input, n_gen, num_top_picks, pool
|
|
| 38 |
def criteria_block():
|
| 39 |
return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
|
| 99 |
yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str
|
| 100 |
|
| 101 |
demo = gr.Interface(
|
| 102 |
fn=run_tournament,
|
| 103 |
inputs=[
|
|
|
|
|
|
|
| 104 |
gr.Textbox(lines=10, label="Instruction"),
|
| 105 |
gr.Textbox(lines=5, label="Criteria (comma separated)"),
|
| 106 |
gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
|
| 107 |
-
gr.Number(value=
|
| 108 |
-
gr.Number(value=
|
| 109 |
-
gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers")
|
|
|
|
|
|
|
| 110 |
],
|
| 111 |
outputs=[
|
| 112 |
gr.Textbox(lines=10, label="Process"),
|
| 113 |
gr.Plot(label="Score Distribution"),
|
| 114 |
-
gr.Textbox(lines=50, label="Top picks")
|
| 115 |
-
]
|
|
|
|
| 116 |
)
|
| 117 |
|
| 118 |
if __name__ == "__main__":
|
|
|
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
from tournament_utils import generate_players, prompt_score, prompt_play
|
| 8 |
|
| 9 |
+
def _int_setting(name, fallback):
    """Read environment variable `name` as an int, using `fallback` when unset."""
    return int(os.environ.get(name, fallback))


def _flag_setting(name):
    """Read environment variable `name` as a bool; an unset variable counts as on.

    Only the word "true" (in any letter case) enables the flag.
    """
    return os.environ.get(name, "true").lower() == "true"


# Numeric defaults shown in the UI; each can be overridden via the environment.
NUM_TOP_PICKS_DEFAULT = _int_setting("NUM_TOP_PICKS", 3)
POOL_SIZE_DEFAULT = _int_setting("POOL_SIZE", 5)
MAX_WORKERS_DEFAULT = _int_setting("MAX_WORKERS", 10)
NUM_GENERATIONS_DEFAULT = _int_setting("NUM_GENERATIONS", 10)

# API endpoint and credential defaults; both fall back to an empty string.
API_BASE_DEFAULT = os.environ.get("OPENAI_API_BASE", "")
API_TOKEN_DEFAULT = os.environ.get("OPENAI_API_KEY", "")

# Initial state of the two filter checkboxes.
SCORE_FILTER_DEFAULT = _flag_setting("ENABLE_SCORE_FILTER")
PAIRWISE_FILTER_DEFAULT = _flag_setting("ENABLE_PAIRWISE_FILTER")
|
| 17 |
|
| 18 |
def _clean_json(txt):
|
| 19 |
txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
|
|
|
|
| 22 |
except json.JSONDecodeError:
|
| 23 |
return ast.literal_eval(txt)
|
| 24 |
|
| 25 |
+
def run_tournament(
|
| 26 |
+
api_base,
|
| 27 |
+
api_token,
|
| 28 |
+
instruction_input,
|
| 29 |
+
criteria_input,
|
| 30 |
+
n_gen,
|
| 31 |
+
pool_size,
|
| 32 |
+
num_top_picks,
|
| 33 |
+
max_workers,
|
| 34 |
+
enable_score_filter,
|
| 35 |
+
enable_pairwise_filter,
|
| 36 |
+
):
|
| 37 |
instruction = instruction_input.strip()
|
| 38 |
criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
|
| 39 |
n_gen = int(n_gen)
|
| 40 |
num_top_picks = int(num_top_picks)
|
| 41 |
pool_size = int(pool_size)
|
| 42 |
max_workers = int(max_workers)
|
| 43 |
+
if api_base:
|
| 44 |
+
os.environ["OPENAI_API_BASE"] = api_base
|
| 45 |
+
if api_token:
|
| 46 |
+
os.environ["OPENAI_API_KEY"] = api_token
|
| 47 |
+
enable_score_filter = bool(enable_score_filter)
|
| 48 |
+
enable_pairwise_filter = bool(enable_pairwise_filter)
|
| 49 |
process_log = []
|
| 50 |
hist_fig = None
|
| 51 |
top_picks_str = ""
|
|
|
|
| 59 |
def criteria_block():
|
| 60 |
return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
|
| 61 |
|
| 62 |
+
if enable_score_filter:
|
| 63 |
+
def score(player):
|
| 64 |
+
data = _clean_json(
|
| 65 |
+
prompt_score(instruction, criteria_list, criteria_block(), player)
|
| 66 |
+
)
|
| 67 |
+
if "scores" in data and isinstance(data["scores"], list):
|
| 68 |
+
vals = data["scores"]
|
| 69 |
+
return sum(vals) / len(vals) if vals else 0.0
|
| 70 |
+
return float(data.get("score", 0))
|
| 71 |
+
|
| 72 |
+
yield from log("Scoring players …")
|
| 73 |
+
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 74 |
+
scores = {
|
| 75 |
+
p: s
|
| 76 |
+
for p, s in zip(
|
| 77 |
+
all_players,
|
| 78 |
+
list(tqdm(ex.map(score, all_players), total=len(all_players))),
|
| 79 |
+
)
|
| 80 |
+
}
|
| 81 |
+
hist_fig = plt.figure()
|
| 82 |
+
plt.hist(list(scores.values()), bins=10)
|
| 83 |
+
yield from log("Histogram generated")
|
| 84 |
+
top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
|
| 85 |
+
yield from log(f"Filtered to {len(top_players)} players with best scores")
|
| 86 |
+
else:
|
| 87 |
+
top_players = all_players
|
| 88 |
+
if enable_pairwise_filter:
|
| 89 |
+
def play(a, b):
|
| 90 |
+
winner_label = _clean_json(
|
| 91 |
+
prompt_play(instruction, criteria_block(), a, b)
|
| 92 |
+
).get("winner", "A")
|
| 93 |
+
return a if winner_label == "A" else b
|
| 94 |
+
|
| 95 |
+
def tournament_round(pairs, executor):
|
| 96 |
+
futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
|
| 97 |
+
results = []
|
| 98 |
+
for fut in tqdm(as_completed(futures), total=len(futures)):
|
| 99 |
+
a, b = futures[fut]
|
| 100 |
+
winner = fut.result()
|
| 101 |
+
loser = b if winner == a else a
|
| 102 |
+
results.append((winner, loser))
|
| 103 |
+
return results
|
| 104 |
+
|
| 105 |
+
def tournament(players, executor):
|
| 106 |
+
lost_to = {}
|
| 107 |
+
current = players[:]
|
| 108 |
+
while len(current) > 1:
|
| 109 |
+
pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
|
| 110 |
+
round_results = tournament_round(pairs, executor)
|
| 111 |
+
for w, l in round_results:
|
| 112 |
+
lost_to[l] = w
|
| 113 |
+
current = [w for w, _ in round_results]
|
| 114 |
+
if len(players) % 2 == 1 and players[-1] not in current:
|
| 115 |
+
current.append(players[-1])
|
| 116 |
+
return current[0], lost_to
|
| 117 |
+
|
| 118 |
+
def get_candidates(champion, lost_to):
|
| 119 |
+
return [p for p, o in lost_to.items() if o == champion] + [champion]
|
| 120 |
+
|
| 121 |
+
def playoff(candidates, executor):
|
| 122 |
+
wins = {p: 0 for p in candidates}
|
| 123 |
+
pairs = [
|
| 124 |
+
(candidates[i], candidates[j])
|
| 125 |
+
for i in range(len(candidates))
|
| 126 |
+
for j in range(i + 1, len(candidates))
|
| 127 |
+
]
|
| 128 |
+
futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
|
| 129 |
+
for fut in tqdm(as_completed(futures), total=len(futures)):
|
| 130 |
+
wins[fut.result()] += 1
|
| 131 |
+
return sorted(candidates, key=lambda p: wins[p], reverse=True)
|
| 132 |
+
|
| 133 |
+
def get_top(players, executor):
|
| 134 |
+
champion, lost_to = tournament(players, executor)
|
| 135 |
+
runner_up = lost_to.get(champion)
|
| 136 |
+
finalists = [champion] + ([runner_up] if runner_up else [])
|
| 137 |
+
semifinalists = [p for p, o in lost_to.items() if o in finalists and p not in finalists]
|
| 138 |
+
candidates = list(set(finalists + semifinalists + get_candidates(champion, lost_to)))
|
| 139 |
+
return playoff(candidates, executor)[:num_top_picks]
|
| 140 |
+
|
| 141 |
+
yield from log("Running tournament …")
|
| 142 |
+
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 143 |
+
top_k = get_top(top_players, ex)
|
| 144 |
+
else:
|
| 145 |
+
top_k = top_players[:num_top_picks]
|
| 146 |
top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
|
| 147 |
yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str
|
| 148 |
|
| 149 |
# Input widgets, kept in the same order as run_tournament's parameters:
# api_base, api_token, instruction, criteria, n_gen, pool_size,
# num_top_picks, max_workers, enable_score_filter, enable_pairwise_filter.
_tournament_inputs = [
    gr.Textbox(value=API_BASE_DEFAULT, label="API Base Path"),
    # The token box always starts blank and is rendered as a password field.
    gr.Textbox(value="", label="API Token", type="password"),
    gr.Textbox(lines=10, label="Instruction"),
    gr.Textbox(lines=5, label="Criteria (comma separated)"),
    gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
    gr.Number(value=POOL_SIZE_DEFAULT, label="Top Picks Score Filter"),
    gr.Number(value=NUM_TOP_PICKS_DEFAULT, label="Top Picks Pairwise"),
    gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers"),
    gr.Checkbox(value=SCORE_FILTER_DEFAULT, label="Enable Score Filter"),
    gr.Checkbox(value=PAIRWISE_FILTER_DEFAULT, label="Enable Pairwise Filter"),
]

# Output widgets for the three values run_tournament yields:
# the process log, the score histogram figure, and the joined top picks.
_tournament_outputs = [
    gr.Textbox(lines=10, label="Process"),
    gr.Plot(label="Score Distribution"),
    gr.Textbox(lines=50, label="Top picks"),
]

demo = gr.Interface(
    fn=run_tournament,
    inputs=_tournament_inputs,
    outputs=_tournament_outputs,
    description="Generate multiple completions and use score and pairwise filters to find the best answers.",
)
|
| 170 |
|
| 171 |
if __name__ == "__main__":
|
tests/test_tournament_utils.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys, os, types
|
| 2 |
+
from unittest.mock import patch, MagicMock
|
| 3 |
+
|
| 4 |
+
# Put the directory above this test file at the front of sys.path so that
# `tournament_utils` can be imported below.
_TESTS_DIR = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.abspath(os.path.join(_TESTS_DIR, '..'))
sys.path.insert(0, _PROJECT_ROOT)

# Register a stand-in `litellm` module so the import below succeeds without
# the real package. `setdefault` leaves an already-imported litellm untouched.
fake_litellm = types.ModuleType('litellm')
fake_litellm.completion = MagicMock()
sys.modules.setdefault('litellm', fake_litellm)
|
| 11 |
+
|
| 12 |
+
import tournament_utils as tu
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def make_response(contents):
    """Build a mock completion response with one choice per item in `contents`.

    Each choice exposes its text as `choice.message.content`, mirroring the
    `response.choices[i].message.content` shape that `generate_players` reads.
    The text is stored as given; no stripping happens here.
    """
    choices = [
        types.SimpleNamespace(message=types.SimpleNamespace(content=text))
        for text in contents
    ]
    return MagicMock(choices=choices)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def test_generate_players():
    """generate_players sends one user message and strips each returned choice."""
    fake_reply = make_response([" player1 ", "player2\n"])
    with patch('tournament_utils.completion', return_value=fake_reply) as completion_mock:
        produced = tu.generate_players('instr', 2, model='m')
        # The instruction must be forwarded verbatim, with n equal to the
        # number of players requested.
        completion_mock.assert_called_once_with(
            model='m',
            messages=[{'role': 'user', 'content': 'instr'}],
            n=2,
        )
        assert produced == ['player1', 'player2']
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def test_prompt_score():
    """prompt_score returns the model reply with surrounding whitespace removed."""
    # The mocked reply uses the `scores` key, matching the format the prompt
    # requests ("Return JSON exactly like: {"scores": [...]}"). The previous
    # fixture used `score` with a list value, which is not a payload the
    # application accepts.
    resp = make_response([' {"scores": [5]} '])
    with patch('tournament_utils.completion', return_value=resp) as mock_comp:
        result = tu.prompt_score('instr', ['c1'], 'block', 'pl', model='m')
        # Exactly one completion request should be issued per scored player.
        mock_comp.assert_called_once()
        assert result == '{"scores": [5]}'
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_prompt_play():
    """prompt_play returns the judge's reply with surrounding whitespace removed."""
    fake_reply = make_response([' {"winner": "A"} '])
    with patch('tournament_utils.completion', return_value=fake_reply) as completion_mock:
        verdict = tu.prompt_play('instr', 'block', 'A text', 'B text', model='m')
        # A single comparison should trigger exactly one completion request.
        completion_mock.assert_called_once()
        assert verdict == '{"winner": "A"}'
|
tournament_utils.py
CHANGED
|
@@ -11,12 +11,19 @@ def generate_players(instruction: str, n: int, model: str = "gpt-4o-mini"):
|
|
| 11 |
return [c.message.content.strip() for c in response.choices]
|
| 12 |
|
| 13 |
|
| 14 |
-
def prompt_score(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"""Return a JSON score string evaluating `player` on the criteria."""
|
|
|
|
| 16 |
prompt = f"""Evaluate the output below on the following criteria:
|
| 17 |
{criteria_block}
|
| 18 |
|
| 19 |
-
Return JSON exactly like: {{"
|
| 20 |
|
| 21 |
Instruction:
|
| 22 |
{instruction}
|
|
|
|
| 11 |
return [c.message.content.strip() for c in response.choices]
|
| 12 |
|
| 13 |
|
| 14 |
+
def prompt_score(
|
| 15 |
+
instruction: str,
|
| 16 |
+
criteria_list: list[str],
|
| 17 |
+
criteria_block: str,
|
| 18 |
+
player: str,
|
| 19 |
+
model: str = "gpt-4o-mini",
|
| 20 |
+
) -> str:
|
| 21 |
"""Return a JSON score string evaluating `player` on the criteria."""
|
| 22 |
+
example_scores = ", ".join(["1-10"] * len(criteria_list)) or "1-10"
|
| 23 |
prompt = f"""Evaluate the output below on the following criteria:
|
| 24 |
{criteria_block}
|
| 25 |
|
| 26 |
+
Return JSON exactly like: {{"scores": [{example_scores}]}}.
|
| 27 |
|
| 28 |
Instruction:
|
| 29 |
{instruction}
|