Spaces:
Sleeping
Sleeping
ping98k
commited on
Commit
·
249284d
1
Parent(s):
dc010c0
Add full loop test for run_tournament
Browse files- README.md +3 -0
- main.py +45 -8
- tests/test_main.py +112 -0
- tests/test_tournament_utils.py +9 -5
- tournament_utils.py +42 -4
README.md
CHANGED
|
@@ -11,6 +11,9 @@ This project provides a small interface for running "tournaments" between langua
|
|
| 11 |
- `NUM_GENERATIONS`
|
| 12 |
- `OPENAI_API_BASE`
|
| 13 |
- `OPENAI_API_KEY`
|
|
|
|
|
|
|
|
|
|
| 14 |
- `ENABLE_SCORE_FILTER`
|
| 15 |
- `ENABLE_PAIRWISE_FILTER`
|
| 16 |
2. Install dependencies (example with `pip`):
|
|
|
|
| 11 |
- `NUM_GENERATIONS`
|
| 12 |
- `OPENAI_API_BASE`
|
| 13 |
- `OPENAI_API_KEY`
|
| 14 |
+
- `GENERATE_MODEL`
|
| 15 |
+
- `SCORE_MODEL`
|
| 16 |
+
- `PAIRWISE_MODEL`
|
| 17 |
- `ENABLE_SCORE_FILTER`
|
| 18 |
- `ENABLE_PAIRWISE_FILTER`
|
| 19 |
2. Install dependencies (example with `pip`):
|
main.py
CHANGED
|
@@ -4,7 +4,7 @@ import os, json, re, ast, gradio as gr
|
|
| 4 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 5 |
from tqdm import tqdm
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
-
from tournament_utils import generate_players, prompt_score,
|
| 8 |
|
| 9 |
NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 3))
|
| 10 |
POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 5))
|
|
@@ -14,6 +14,9 @@ API_BASE_DEFAULT = os.getenv("OPENAI_API_BASE", "")
|
|
| 14 |
API_TOKEN_DEFAULT = os.getenv("OPENAI_API_KEY", "")
|
| 15 |
SCORE_FILTER_DEFAULT = os.getenv("ENABLE_SCORE_FILTER", "true").lower() == "true"
|
| 16 |
PAIRWISE_FILTER_DEFAULT = os.getenv("ENABLE_PAIRWISE_FILTER", "true").lower() == "true"
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def _clean_json(txt):
|
| 19 |
txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
|
|
@@ -25,6 +28,9 @@ def _clean_json(txt):
|
|
| 25 |
def run_tournament(
|
| 26 |
api_base,
|
| 27 |
api_token,
|
|
|
|
|
|
|
|
|
|
| 28 |
instruction_input,
|
| 29 |
criteria_input,
|
| 30 |
n_gen,
|
|
@@ -40,10 +46,16 @@ def run_tournament(
|
|
| 40 |
num_top_picks = int(num_top_picks)
|
| 41 |
pool_size = int(pool_size)
|
| 42 |
max_workers = int(max_workers)
|
| 43 |
-
if api_base:
|
| 44 |
-
|
| 45 |
-
if api_token:
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
enable_score_filter = bool(enable_score_filter)
|
| 48 |
enable_pairwise_filter = bool(enable_pairwise_filter)
|
| 49 |
process_log = []
|
|
@@ -54,7 +66,13 @@ def run_tournament(
|
|
| 54 |
tqdm.write(msg)
|
| 55 |
yield "\n".join(process_log), hist_fig, top_picks_str
|
| 56 |
yield from log("Generating players …")
|
| 57 |
-
all_players = generate_players(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
yield from log(f"{len(all_players)} players generated")
|
| 59 |
def criteria_block():
|
| 60 |
return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
|
|
@@ -62,7 +80,15 @@ def run_tournament(
|
|
| 62 |
if enable_score_filter:
|
| 63 |
def score(player):
|
| 64 |
data = _clean_json(
|
| 65 |
-
prompt_score(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
)
|
| 67 |
if "scores" in data and isinstance(data["scores"], list):
|
| 68 |
vals = data["scores"]
|
|
@@ -88,7 +114,15 @@ def run_tournament(
|
|
| 88 |
if enable_pairwise_filter:
|
| 89 |
def play(a, b):
|
| 90 |
winner_label = _clean_json(
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
).get("winner", "A")
|
| 93 |
return a if winner_label == "A" else b
|
| 94 |
|
|
@@ -151,6 +185,9 @@ demo = gr.Interface(
|
|
| 151 |
inputs=[
|
| 152 |
gr.Textbox(value=API_BASE_DEFAULT, label="API Base Path"),
|
| 153 |
gr.Textbox(value="", label="API Token", type="password"),
|
|
|
|
|
|
|
|
|
|
| 154 |
gr.Textbox(lines=10, label="Instruction"),
|
| 155 |
gr.Textbox(lines=5, label="Criteria (comma separated)"),
|
| 156 |
gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
|
|
|
|
| 4 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 5 |
from tqdm import tqdm
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
+
from tournament_utils import generate_players, prompt_score, prompt_pairwise
|
| 8 |
|
| 9 |
NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 3))
|
| 10 |
POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 5))
|
|
|
|
| 14 |
API_TOKEN_DEFAULT = os.getenv("OPENAI_API_KEY", "")
|
| 15 |
SCORE_FILTER_DEFAULT = os.getenv("ENABLE_SCORE_FILTER", "true").lower() == "true"
|
| 16 |
PAIRWISE_FILTER_DEFAULT = os.getenv("ENABLE_PAIRWISE_FILTER", "true").lower() == "true"
|
| 17 |
+
GENERATE_MODEL_DEFAULT = os.getenv("GENERATE_MODEL", "gpt-4o-mini")
|
| 18 |
+
SCORE_MODEL_DEFAULT = os.getenv("SCORE_MODEL", "gpt-4o-mini")
|
| 19 |
+
PAIRWISE_MODEL_DEFAULT = os.getenv("PAIRWISE_MODEL", "gpt-4o-mini")
|
| 20 |
|
| 21 |
def _clean_json(txt):
|
| 22 |
txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
|
|
|
|
| 28 |
def run_tournament(
|
| 29 |
api_base,
|
| 30 |
api_token,
|
| 31 |
+
generate_model,
|
| 32 |
+
score_model,
|
| 33 |
+
pairwise_model,
|
| 34 |
instruction_input,
|
| 35 |
criteria_input,
|
| 36 |
n_gen,
|
|
|
|
| 46 |
num_top_picks = int(num_top_picks)
|
| 47 |
pool_size = int(pool_size)
|
| 48 |
max_workers = int(max_workers)
|
| 49 |
+
if not api_base:
|
| 50 |
+
api_base = API_BASE_DEFAULT
|
| 51 |
+
if not api_token:
|
| 52 |
+
api_token = API_TOKEN_DEFAULT
|
| 53 |
+
if not generate_model:
|
| 54 |
+
generate_model = GENERATE_MODEL_DEFAULT
|
| 55 |
+
if not score_model:
|
| 56 |
+
score_model = SCORE_MODEL_DEFAULT
|
| 57 |
+
if not pairwise_model:
|
| 58 |
+
pairwise_model = PAIRWISE_MODEL_DEFAULT
|
| 59 |
enable_score_filter = bool(enable_score_filter)
|
| 60 |
enable_pairwise_filter = bool(enable_pairwise_filter)
|
| 61 |
process_log = []
|
|
|
|
| 66 |
tqdm.write(msg)
|
| 67 |
yield "\n".join(process_log), hist_fig, top_picks_str
|
| 68 |
yield from log("Generating players …")
|
| 69 |
+
all_players = generate_players(
|
| 70 |
+
instruction,
|
| 71 |
+
n_gen,
|
| 72 |
+
model=generate_model,
|
| 73 |
+
api_base=api_base,
|
| 74 |
+
api_key=api_token,
|
| 75 |
+
)
|
| 76 |
yield from log(f"{len(all_players)} players generated")
|
| 77 |
def criteria_block():
|
| 78 |
return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
|
|
|
|
| 80 |
if enable_score_filter:
|
| 81 |
def score(player):
|
| 82 |
data = _clean_json(
|
| 83 |
+
prompt_score(
|
| 84 |
+
instruction,
|
| 85 |
+
criteria_list,
|
| 86 |
+
criteria_block(),
|
| 87 |
+
player,
|
| 88 |
+
model=score_model,
|
| 89 |
+
api_base=api_base,
|
| 90 |
+
api_key=api_token,
|
| 91 |
+
)
|
| 92 |
)
|
| 93 |
if "scores" in data and isinstance(data["scores"], list):
|
| 94 |
vals = data["scores"]
|
|
|
|
| 114 |
if enable_pairwise_filter:
|
| 115 |
def play(a, b):
|
| 116 |
winner_label = _clean_json(
|
| 117 |
+
prompt_pairwise(
|
| 118 |
+
instruction,
|
| 119 |
+
criteria_block(),
|
| 120 |
+
a,
|
| 121 |
+
b,
|
| 122 |
+
model=pairwise_model,
|
| 123 |
+
api_base=api_base,
|
| 124 |
+
api_key=api_token,
|
| 125 |
+
)
|
| 126 |
).get("winner", "A")
|
| 127 |
return a if winner_label == "A" else b
|
| 128 |
|
|
|
|
| 185 |
inputs=[
|
| 186 |
gr.Textbox(value=API_BASE_DEFAULT, label="API Base Path"),
|
| 187 |
gr.Textbox(value="", label="API Token", type="password"),
|
| 188 |
+
gr.Textbox(value=GENERATE_MODEL_DEFAULT, label="Generation Model"),
|
| 189 |
+
gr.Textbox(value=SCORE_MODEL_DEFAULT, label="Score Model"),
|
| 190 |
+
gr.Textbox(value=PAIRWISE_MODEL_DEFAULT, label="Pairwise Model"),
|
| 191 |
gr.Textbox(lines=10, label="Instruction"),
|
| 192 |
gr.Textbox(lines=5, label="Criteria (comma separated)"),
|
| 193 |
gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
|
tests/test_main.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys, os, types, json
|
| 2 |
+
from unittest.mock import patch, MagicMock
|
| 3 |
+
|
| 4 |
+
# Ensure project root in path
|
| 5 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 6 |
+
|
| 7 |
+
# Provide dummy litellm module so import succeeds
|
| 8 |
+
fake_litellm = types.ModuleType('litellm')
|
| 9 |
+
fake_litellm.completion = MagicMock()
|
| 10 |
+
sys.modules.setdefault('litellm', fake_litellm)
|
| 11 |
+
|
| 12 |
+
# Provide dummy dotenv module
|
| 13 |
+
fake_dotenv = types.ModuleType('dotenv')
|
| 14 |
+
fake_dotenv.load_dotenv = MagicMock()
|
| 15 |
+
sys.modules.setdefault('dotenv', fake_dotenv)
|
| 16 |
+
|
| 17 |
+
# Dummy gradio module so import succeeds
|
| 18 |
+
fake_gradio = types.ModuleType('gradio')
|
| 19 |
+
fake_gradio.Interface = MagicMock(return_value=MagicMock(launch=MagicMock()))
|
| 20 |
+
fake_gradio.Textbox = MagicMock
|
| 21 |
+
fake_gradio.Number = MagicMock
|
| 22 |
+
fake_gradio.Checkbox = MagicMock
|
| 23 |
+
fake_gradio.Plot = MagicMock
|
| 24 |
+
sys.modules.setdefault('gradio', fake_gradio)
|
| 25 |
+
|
| 26 |
+
# Dummy tqdm module for write method
|
| 27 |
+
class FakeTqdmModule(types.ModuleType):
|
| 28 |
+
def __init__(self):
|
| 29 |
+
super().__init__('tqdm')
|
| 30 |
+
self.write = MagicMock()
|
| 31 |
+
def __call__(self, iterable=None, total=None):
|
| 32 |
+
return iterable
|
| 33 |
+
|
| 34 |
+
fake_tqdm_mod = FakeTqdmModule()
|
| 35 |
+
fake_tqdm_mod.tqdm = fake_tqdm_mod
|
| 36 |
+
sys.modules.setdefault('tqdm', fake_tqdm_mod)
|
| 37 |
+
|
| 38 |
+
# Dummy matplotlib module
|
| 39 |
+
fake_plt = types.ModuleType('matplotlib.pyplot')
|
| 40 |
+
fake_plt.figure = MagicMock(return_value='fig')
|
| 41 |
+
fake_plt.hist = MagicMock()
|
| 42 |
+
fake_matplotlib = types.ModuleType('matplotlib')
|
| 43 |
+
fake_matplotlib.pyplot = fake_plt
|
| 44 |
+
sys.modules.setdefault('matplotlib', fake_matplotlib)
|
| 45 |
+
sys.modules.setdefault('matplotlib.pyplot', fake_plt)
|
| 46 |
+
|
| 47 |
+
import main
|
| 48 |
+
|
| 49 |
+
class DummyFuture:
|
| 50 |
+
def __init__(self, func, *args):
|
| 51 |
+
self._func = func
|
| 52 |
+
self._args = args
|
| 53 |
+
def result(self):
|
| 54 |
+
return self._func(*self._args)
|
| 55 |
+
|
| 56 |
+
class DummyExecutor:
|
| 57 |
+
def __init__(self, *args, **kwargs):
|
| 58 |
+
pass
|
| 59 |
+
def __enter__(self):
|
| 60 |
+
return self
|
| 61 |
+
def __exit__(self, exc_type, exc, tb):
|
| 62 |
+
pass
|
| 63 |
+
def submit(self, func, *args):
|
| 64 |
+
return DummyFuture(func, *args)
|
| 65 |
+
def map(self, func, iterable):
|
| 66 |
+
for item in iterable:
|
| 67 |
+
yield func(item)
|
| 68 |
+
|
| 69 |
+
class DummyTqdm:
|
| 70 |
+
def __call__(self, iterable=None, total=None):
|
| 71 |
+
return iterable
|
| 72 |
+
def write(self, msg):
|
| 73 |
+
pass
|
| 74 |
+
|
| 75 |
+
def test_run_tournament_full_loop():
|
| 76 |
+
dummy_tqdm = DummyTqdm()
|
| 77 |
+
with patch('main.generate_players') as mock_gen, \
|
| 78 |
+
patch('main.prompt_score') as mock_score, \
|
| 79 |
+
patch('main.prompt_pairwise') as mock_pair, \
|
| 80 |
+
patch('main.ThreadPoolExecutor', return_value=DummyExecutor()) as MockExec, \
|
| 81 |
+
patch('main.as_completed', new=lambda futs: futs), \
|
| 82 |
+
patch('main.tqdm', new=dummy_tqdm), \
|
| 83 |
+
patch('main.plt.figure', return_value='fig'), \
|
| 84 |
+
patch('main.plt.hist'):
|
| 85 |
+
mock_gen.return_value = ['p1', 'p2', 'p3', 'p4']
|
| 86 |
+
scores = {'p1':3, 'p2':2, 'p3':1, 'p4':0}
|
| 87 |
+
mock_score.side_effect = lambda instr, cl, block, player, **kw: json.dumps({'score': scores[player]})
|
| 88 |
+
mock_pair.side_effect = lambda instr, block, a, b, **kw: json.dumps({'winner': 'A'})
|
| 89 |
+
|
| 90 |
+
results = list(main.run_tournament(
|
| 91 |
+
api_base='b',
|
| 92 |
+
api_token='k',
|
| 93 |
+
generate_model='gm',
|
| 94 |
+
score_model='sm',
|
| 95 |
+
pairwise_model='pm',
|
| 96 |
+
instruction_input='instr',
|
| 97 |
+
criteria_input='c1,c2',
|
| 98 |
+
n_gen=4,
|
| 99 |
+
pool_size=2,
|
| 100 |
+
num_top_picks=1,
|
| 101 |
+
max_workers=1,
|
| 102 |
+
enable_score_filter=True,
|
| 103 |
+
enable_pairwise_filter=True,
|
| 104 |
+
))
|
| 105 |
+
|
| 106 |
+
process_log, hist_fig, top_picks = results[-1]
|
| 107 |
+
assert 'Done' in process_log
|
| 108 |
+
assert hist_fig == 'fig'
|
| 109 |
+
assert top_picks.strip() in {'p1', 'p2'}
|
| 110 |
+
mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k')
|
| 111 |
+
assert mock_score.call_count == 4
|
| 112 |
+
assert mock_pair.called
|
tests/test_tournament_utils.py
CHANGED
|
@@ -25,22 +25,26 @@ def make_response(contents):
|
|
| 25 |
def test_generate_players():
|
| 26 |
resp = make_response([" player1 ", "player2\n"])
|
| 27 |
with patch('tournament_utils.completion', return_value=resp) as mock_comp:
|
| 28 |
-
players = tu.generate_players('instr', 2, model='m')
|
| 29 |
-
mock_comp.assert_called_once_with(model='m', messages=[{'role': 'user', 'content': 'instr'}], n=2)
|
| 30 |
assert players == ['player1', 'player2']
|
| 31 |
|
| 32 |
|
| 33 |
def test_prompt_score():
|
| 34 |
resp = make_response([" {\"score\": [5]} "])
|
| 35 |
with patch('tournament_utils.completion', return_value=resp) as mock_comp:
|
| 36 |
-
result = tu.prompt_score('instr', ['c1'], 'block', 'pl', model='m')
|
| 37 |
mock_comp.assert_called_once()
|
|
|
|
|
|
|
| 38 |
assert result == '{"score": [5]}'
|
| 39 |
|
| 40 |
|
| 41 |
-
def
|
| 42 |
resp = make_response([" {\"winner\": \"A\"} "])
|
| 43 |
with patch('tournament_utils.completion', return_value=resp) as mock_comp:
|
| 44 |
-
result = tu.
|
| 45 |
mock_comp.assert_called_once()
|
|
|
|
|
|
|
| 46 |
assert result == '{"winner": "A"}'
|
|
|
|
| 25 |
def test_generate_players():
|
| 26 |
resp = make_response([" player1 ", "player2\n"])
|
| 27 |
with patch('tournament_utils.completion', return_value=resp) as mock_comp:
|
| 28 |
+
players = tu.generate_players('instr', 2, model='m', api_base='b', api_key='k')
|
| 29 |
+
mock_comp.assert_called_once_with(model='m', messages=[{'role': 'user', 'content': 'instr'}], n=2, api_base='b', api_key='k')
|
| 30 |
assert players == ['player1', 'player2']
|
| 31 |
|
| 32 |
|
| 33 |
def test_prompt_score():
|
| 34 |
resp = make_response([" {\"score\": [5]} "])
|
| 35 |
with patch('tournament_utils.completion', return_value=resp) as mock_comp:
|
| 36 |
+
result = tu.prompt_score('instr', ['c1'], 'block', 'pl', model='m', api_base='b', api_key='k')
|
| 37 |
mock_comp.assert_called_once()
|
| 38 |
+
assert mock_comp.call_args.kwargs['api_base'] == 'b'
|
| 39 |
+
assert mock_comp.call_args.kwargs['api_key'] == 'k'
|
| 40 |
assert result == '{"score": [5]}'
|
| 41 |
|
| 42 |
|
| 43 |
+
def test_prompt_pairwise():
|
| 44 |
resp = make_response([" {\"winner\": \"A\"} "])
|
| 45 |
with patch('tournament_utils.completion', return_value=resp) as mock_comp:
|
| 46 |
+
result = tu.prompt_pairwise('instr', 'block', 'A text', 'B text', model='m', api_base='b', api_key='k')
|
| 47 |
mock_comp.assert_called_once()
|
| 48 |
+
assert mock_comp.call_args.kwargs['api_base'] == 'b'
|
| 49 |
+
assert mock_comp.call_args.kwargs['api_key'] == 'k'
|
| 50 |
assert result == '{"winner": "A"}'
|
tournament_utils.py
CHANGED
|
@@ -1,12 +1,30 @@
|
|
| 1 |
from litellm import completion
|
| 2 |
|
| 3 |
|
| 4 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""Request `n` completions for the instruction using the given model."""
|
| 6 |
response = completion(
|
| 7 |
model=model,
|
| 8 |
messages=[{"role": "user", "content": instruction}],
|
| 9 |
n=n,
|
|
|
|
| 10 |
)
|
| 11 |
return [c.message.content.strip() for c in response.choices]
|
| 12 |
|
|
@@ -17,6 +35,9 @@ def prompt_score(
|
|
| 17 |
criteria_block: str,
|
| 18 |
player: str,
|
| 19 |
model: str = "gpt-4o-mini",
|
|
|
|
|
|
|
|
|
|
| 20 |
) -> str:
|
| 21 |
"""Return a JSON score string evaluating `player` on the criteria."""
|
| 22 |
example_scores = ", ".join(["1-10"] * len(criteria_list)) or "1-10"
|
|
@@ -30,11 +51,24 @@ Instruction:
|
|
| 30 |
|
| 31 |
Output:
|
| 32 |
{player}"""
|
| 33 |
-
response = completion(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
return response.choices[0].message.content.strip()
|
| 35 |
|
| 36 |
|
| 37 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
"""Return which player wins in JSON using the given criteria."""
|
| 39 |
prompt = f"""Compare the two players below using:
|
| 40 |
{criteria_block}
|
|
@@ -47,5 +81,9 @@ Instruction:
|
|
| 47 |
Players:
|
| 48 |
<A>{a}</A>
|
| 49 |
<B>{b}</B>"""
|
| 50 |
-
response = completion(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
return response.choices[0].message.content.strip()
|
|
|
|
| 1 |
from litellm import completion
|
| 2 |
|
| 3 |
|
| 4 |
+
def _completion_kwargs(api_base: str | None, api_key: str | None) -> dict:
|
| 5 |
+
"""Build kwargs for litellm.completion from api settings."""
|
| 6 |
+
kwargs: dict = {}
|
| 7 |
+
if api_base:
|
| 8 |
+
kwargs["api_base"] = api_base
|
| 9 |
+
if api_key:
|
| 10 |
+
kwargs["api_key"] = api_key
|
| 11 |
+
return kwargs
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def generate_players(
|
| 15 |
+
instruction: str,
|
| 16 |
+
n: int,
|
| 17 |
+
model: str = "gpt-4o-mini",
|
| 18 |
+
*,
|
| 19 |
+
api_base: str | None = None,
|
| 20 |
+
api_key: str | None = None,
|
| 21 |
+
):
|
| 22 |
"""Request `n` completions for the instruction using the given model."""
|
| 23 |
response = completion(
|
| 24 |
model=model,
|
| 25 |
messages=[{"role": "user", "content": instruction}],
|
| 26 |
n=n,
|
| 27 |
+
**_completion_kwargs(api_base, api_key),
|
| 28 |
)
|
| 29 |
return [c.message.content.strip() for c in response.choices]
|
| 30 |
|
|
|
|
| 35 |
criteria_block: str,
|
| 36 |
player: str,
|
| 37 |
model: str = "gpt-4o-mini",
|
| 38 |
+
*,
|
| 39 |
+
api_base: str | None = None,
|
| 40 |
+
api_key: str | None = None,
|
| 41 |
) -> str:
|
| 42 |
"""Return a JSON score string evaluating `player` on the criteria."""
|
| 43 |
example_scores = ", ".join(["1-10"] * len(criteria_list)) or "1-10"
|
|
|
|
| 51 |
|
| 52 |
Output:
|
| 53 |
{player}"""
|
| 54 |
+
response = completion(
|
| 55 |
+
model=model,
|
| 56 |
+
messages=[{"role": "system", "content": prompt}],
|
| 57 |
+
**_completion_kwargs(api_base, api_key),
|
| 58 |
+
)
|
| 59 |
return response.choices[0].message.content.strip()
|
| 60 |
|
| 61 |
|
| 62 |
+
def prompt_pairwise(
|
| 63 |
+
instruction: str,
|
| 64 |
+
criteria_block: str,
|
| 65 |
+
a: str,
|
| 66 |
+
b: str,
|
| 67 |
+
model: str = "gpt-4o-mini",
|
| 68 |
+
*,
|
| 69 |
+
api_base: str | None = None,
|
| 70 |
+
api_key: str | None = None,
|
| 71 |
+
) -> str:
|
| 72 |
"""Return which player wins in JSON using the given criteria."""
|
| 73 |
prompt = f"""Compare the two players below using:
|
| 74 |
{criteria_block}
|
|
|
|
| 81 |
Players:
|
| 82 |
<A>{a}</A>
|
| 83 |
<B>{b}</B>"""
|
| 84 |
+
response = completion(
|
| 85 |
+
model=model,
|
| 86 |
+
messages=[{"role": "system", "content": prompt}],
|
| 87 |
+
**_completion_kwargs(api_base, api_key),
|
| 88 |
+
)
|
| 89 |
return response.choices[0].message.content.strip()
|