Spaces:
Sleeping
Sleeping
ping98k
committed on
Commit
·
cbe88a5
1
Parent(s):
f94af77
Refine interface layout and labels
Browse files
README.md
CHANGED
|
@@ -9,6 +9,10 @@ This project provides a small interface for running "tournaments" between langua
|
|
| 9 |
- `POOL_SIZE`
|
| 10 |
- `MAX_WORKERS`
|
| 11 |
- `NUM_GENERATIONS`
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
2. Install dependencies (example with `pip`):
|
| 13 |
```bash
|
| 14 |
pip install gradio litellm python-dotenv tqdm matplotlib
|
|
@@ -17,7 +21,7 @@ This project provides a small interface for running "tournaments" between langua
|
|
| 17 |
```bash
|
| 18 |
python main.py
|
| 19 |
```
|
| 20 |
-
4. Open the displayed local URL
|
| 21 |
|
| 22 |
-
The interface will generate multiple answers,
|
| 23 |
|
|
|
|
| 9 |
- `POOL_SIZE`
|
| 10 |
- `MAX_WORKERS`
|
| 11 |
- `NUM_GENERATIONS`
|
| 12 |
+
- `OPENAI_API_BASE`
|
| 13 |
+
- `OPENAI_API_KEY`
|
| 14 |
+
- `ENABLE_SCORE_FILTER`
|
| 15 |
+
- `ENABLE_PAIRWISE_FILTER`
|
| 16 |
2. Install dependencies (example with `pip`):
|
| 17 |
```bash
|
| 18 |
pip install gradio litellm python-dotenv tqdm matplotlib
|
|
|
|
| 21 |
```bash
|
| 22 |
python main.py
|
| 23 |
```
|
| 24 |
+
4. Open the displayed local URL. At the top of the page you can optionally override the API base path and token (the token field is blank by default). Additional settings let you configure score and pairwise filtering.
|
| 25 |
|
| 26 |
+
The interface generates multiple answers, optionally filters them by score, and optionally runs a pairwise tournament to select the best outputs.
|
| 27 |
|
main.py
CHANGED
|
@@ -10,6 +10,10 @@ NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 5))
|
|
| 10 |
POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 10))
|
| 11 |
MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 10))
|
| 12 |
NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS", 20))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def _clean_json(txt):
|
| 15 |
txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
|
|
@@ -18,13 +22,30 @@ def _clean_json(txt):
|
|
| 18 |
except json.JSONDecodeError:
|
| 19 |
return ast.literal_eval(txt)
|
| 20 |
|
| 21 |
-
def run_tournament(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
instruction = instruction_input.strip()
|
| 23 |
criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
|
| 24 |
n_gen = int(n_gen)
|
| 25 |
num_top_picks = int(num_top_picks)
|
| 26 |
pool_size = int(pool_size)
|
| 27 |
max_workers = int(max_workers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
process_log = []
|
| 29 |
hist_fig = None
|
| 30 |
top_picks_str = ""
|
|
@@ -38,81 +59,109 @@ def run_tournament(instruction_input, criteria_input, n_gen, num_top_picks, pool
|
|
| 38 |
def criteria_block():
|
| 39 |
return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
for
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
|
| 99 |
yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str
|
| 100 |
|
| 101 |
demo = gr.Interface(
|
| 102 |
fn=run_tournament,
|
| 103 |
inputs=[
|
|
|
|
|
|
|
| 104 |
gr.Textbox(lines=10, label="Instruction"),
|
| 105 |
gr.Textbox(lines=5, label="Criteria (comma separated)"),
|
| 106 |
gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
|
| 107 |
-
gr.Number(value=
|
| 108 |
-
gr.Number(value=
|
| 109 |
-
gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers")
|
|
|
|
|
|
|
| 110 |
],
|
| 111 |
outputs=[
|
| 112 |
gr.Textbox(lines=10, label="Process"),
|
| 113 |
gr.Plot(label="Score Distribution"),
|
| 114 |
-
gr.Textbox(lines=50, label="Top picks")
|
| 115 |
-
]
|
|
|
|
| 116 |
)
|
| 117 |
|
| 118 |
if __name__ == "__main__":
|
|
|
|
| 10 |
POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 10))
|
| 11 |
MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 10))
|
| 12 |
NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS", 20))
|
| 13 |
+
API_BASE_DEFAULT = os.getenv("OPENAI_API_BASE", "")
|
| 14 |
+
API_TOKEN_DEFAULT = os.getenv("OPENAI_API_KEY", "")
|
| 15 |
+
SCORE_FILTER_DEFAULT = os.getenv("ENABLE_SCORE_FILTER", "true").lower() == "true"
|
| 16 |
+
PAIRWISE_FILTER_DEFAULT = os.getenv("ENABLE_PAIRWISE_FILTER", "true").lower() == "true"
|
| 17 |
|
| 18 |
def _clean_json(txt):
|
| 19 |
txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
|
|
|
|
| 22 |
except json.JSONDecodeError:
|
| 23 |
return ast.literal_eval(txt)
|
| 24 |
|
| 25 |
+
def run_tournament(
|
| 26 |
+
instruction_input,
|
| 27 |
+
criteria_input,
|
| 28 |
+
n_gen,
|
| 29 |
+
num_top_picks,
|
| 30 |
+
pool_size,
|
| 31 |
+
max_workers,
|
| 32 |
+
api_base,
|
| 33 |
+
api_token,
|
| 34 |
+
enable_score_filter,
|
| 35 |
+
enable_pairwise_filter,
|
| 36 |
+
):
|
| 37 |
instruction = instruction_input.strip()
|
| 38 |
criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
|
| 39 |
n_gen = int(n_gen)
|
| 40 |
num_top_picks = int(num_top_picks)
|
| 41 |
pool_size = int(pool_size)
|
| 42 |
max_workers = int(max_workers)
|
| 43 |
+
if api_base:
|
| 44 |
+
os.environ["OPENAI_API_BASE"] = api_base
|
| 45 |
+
if api_token:
|
| 46 |
+
os.environ["OPENAI_API_KEY"] = api_token
|
| 47 |
+
enable_score_filter = bool(enable_score_filter)
|
| 48 |
+
enable_pairwise_filter = bool(enable_pairwise_filter)
|
| 49 |
process_log = []
|
| 50 |
hist_fig = None
|
| 51 |
top_picks_str = ""
|
|
|
|
| 59 |
def criteria_block():
|
| 60 |
return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
|
| 61 |
|
| 62 |
+
if enable_score_filter:
|
| 63 |
+
def score(player):
    """Return the mean judge score for one generated answer.

    The answer is sent to ``prompt_score`` together with the instruction and
    the numbered criteria; the reply is parsed with ``_clean_json`` and the
    value stored under ``"score"`` (or, failing that, ``"scores"``) is read.

    Args:
        player: The candidate answer to be scored.

    Returns:
        The arithmetic mean of the reported scores as a float, the number
        itself when the judge reports a single scalar, or ``0.0`` when no
        scores are present.
    """
    data = _clean_json(prompt_score(instruction, criteria_block(), player))
    values = data.get("score", data.get("scores", []))
    # A judge may answer with one number instead of a list; ``sum()`` would
    # raise TypeError on it and abort the whole scoring pass.
    if isinstance(values, (int, float)):
        return float(values)
    return sum(values) / len(values) if values else 0.0
|
| 67 |
+
|
| 68 |
+
yield from log("Scoring players …")
|
| 69 |
+
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 70 |
+
scores = {
|
| 71 |
+
p: s
|
| 72 |
+
for p, s in zip(
|
| 73 |
+
all_players,
|
| 74 |
+
list(tqdm(ex.map(score, all_players), total=len(all_players))),
|
| 75 |
+
)
|
| 76 |
+
}
|
| 77 |
+
hist_fig = plt.figure()
|
| 78 |
+
plt.hist(list(scores.values()), bins=10)
|
| 79 |
+
yield from log("Histogram generated")
|
| 80 |
+
top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
|
| 81 |
+
yield from log(f"Filtered to {len(top_players)} players with best scores")
|
| 82 |
+
else:
|
| 83 |
+
top_players = all_players
|
| 84 |
+
if enable_pairwise_filter:
|
| 85 |
+
def play(a, b):
    """Let the judge compare two answers and return the preferred one.

    Both answers are passed to ``prompt_play`` with the instruction and the
    numbered criteria, and the reply is parsed with ``_clean_json``. A reply
    without a ``"winner"`` field counts as ``"A"``; any label other than
    ``"A"`` selects ``b``.
    """
    verdict = _clean_json(prompt_play(instruction, criteria_block(), a, b))
    if verdict.get("winner", "A") == "A":
        return a
    return b
|
| 90 |
+
|
| 91 |
+
def tournament_round(pairs, executor):
    """Play every pairing concurrently and report the outcomes.

    Args:
        pairs: Iterable of ``(a, b)`` match-ups.
        executor: Executor on which each ``play(a, b)`` call is submitted.

    Returns:
        A list of ``(winner, loser)`` tuples in the order the matches
        finished, not the order they were submitted.
    """
    pending = {}
    for first, second in pairs:
        pending[executor.submit(play, first, second)] = (first, second)

    def outcome(done):
        # Recover the pairing so the loser can be derived from the winner.
        first, second = pending[done]
        victor = done.result()
        return (victor, second if victor == first else first)

    return [outcome(done) for done in tqdm(as_completed(pending), total=len(pending))]
|
| 100 |
+
|
| 101 |
+
def tournament(players, executor):
    """Run a single-elimination bracket and return its champion.

    Neighbouring entrants are paired each round and the winners advance.
    When a round has an odd number of entrants, the last one has no opponent
    and advances unplayed (a bye). The bye is decided from the size of the
    *current* round, so no winner is dropped when an even field shrinks to an
    odd one, and an eliminated entrant is never put back into the bracket.

    Args:
        players: Sequence of entrants; it is not modified.
        executor: Executor handed to ``tournament_round`` for the matches.

    Returns:
        A ``(champion, lost_to)`` tuple, where ``lost_to`` maps every
        eliminated entrant to the entrant that beat it.

    Raises:
        IndexError: If ``players`` is empty.
    """
    lost_to = {}
    current = list(players)
    while len(current) > 1:
        pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
        # The unpaired tail of an odd round advances without playing.
        bye = [current[-1]] if len(current) % 2 else []
        round_results = tournament_round(pairs, executor)
        for winner, loser in round_results:
            lost_to[loser] = winner
        # Each round strictly shrinks the field, so the loop terminates.
        current = [winner for winner, _ in round_results] + bye
    return current[0], lost_to
|
| 113 |
+
|
| 114 |
+
def get_candidates(champion, lost_to):
    """List everyone recorded as beaten by ``champion``, then the champion.

    Args:
        champion: The bracket winner.
        lost_to: Mapping of eliminated entrant to the entrant that beat it.

    Returns:
        The entrants whose conqueror equals ``champion`` in the mapping's
        iteration order, with ``champion`` appended last.
    """
    shortlist = []
    for entrant, conqueror in lost_to.items():
        if conqueror == champion:
            shortlist.append(entrant)
    shortlist.append(champion)
    return shortlist
|
| 116 |
+
|
| 117 |
+
def playoff(candidates, executor):
    """Rank candidates by a round-robin in which everyone meets everyone once.

    Args:
        candidates: Entrants to rank.
        executor: Executor on which each ``play`` comparison is submitted.

    Returns:
        ``candidates`` sorted by number of wins, highest first. Entrants with
        the same number of wins keep their relative order from ``candidates``.
    """
    wins = dict.fromkeys(candidates, 0)
    pending = {}
    for idx, left in enumerate(candidates):
        # Only pair with later entrants so each match-up is played once.
        for right in candidates[idx + 1:]:
            pending[executor.submit(play, left, right)] = (left, right)
    for done in tqdm(as_completed(pending), total=len(pending)):
        wins[done.result()] += 1
    return sorted(candidates, key=wins.get, reverse=True)
|
| 128 |
+
|
| 129 |
+
def get_top(players, executor):
    """Pick the best ``num_top_picks`` answers via bracket plus playoff.

    A single-elimination ``tournament`` finds the champion. Everyone who lost
    to the champion or to the runner-up then meets in a round-robin
    ``playoff``, and the top of that ranking is returned.

    Args:
        players: Entrants for the bracket.
        executor: Executor shared by the bracket and the playoff.

    Returns:
        Up to ``num_top_picks`` entrants, best first.
    """
    champion, lost_to = tournament(players, executor)
    # The champion never loses, so it is never a key of ``lost_to``. The
    # runner-up is taken as the entrant the champion beat most recently;
    # this assumes ``lost_to`` keeps matches in the order they were recorded,
    # with the final last.
    beaten_by_champion = [p for p, o in lost_to.items() if o == champion]
    finalists = [champion] + beaten_by_champion[-1:]
    semifinalists = [p for p, o in lost_to.items() if o in finalists and p not in finalists]
    # ``dict.fromkeys`` removes duplicates but keeps a stable order, so
    # playoff tie-breaking is the same from run to run (``set`` order varies
    # with string hash randomisation).
    candidates = list(dict.fromkeys(finalists + semifinalists + get_candidates(champion, lost_to)))
    return playoff(candidates, executor)[:num_top_picks]
|
| 136 |
+
|
| 137 |
+
yield from log("Running tournament …")
|
| 138 |
+
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 139 |
+
top_k = get_top(top_players, ex)
|
| 140 |
+
else:
|
| 141 |
+
top_k = top_players[:num_top_picks]
|
| 142 |
top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
|
| 143 |
yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str
|
| 144 |
|
| 145 |
def _run_from_ui(
    api_base,
    api_token,
    instruction_input,
    criteria_input,
    n_gen,
    pool_size,
    num_top_picks,
    max_workers,
    enable_score_filter,
    enable_pairwise_filter,
):
    """Translate the on-screen widget order into ``run_tournament`` arguments.

    ``gr.Interface`` passes component values to ``fn`` positionally, in the
    order of ``inputs``. The page shows the API settings first and the pool
    size before the number of top picks, which is not the parameter order of
    ``run_tournament``. Passing every value by keyword keeps the layout free
    to change without feeding, for example, the instruction text into
    ``n_gen``.

    This is a generator so that the intermediate results yielded by
    ``run_tournament`` are still streamed to the outputs.
    """
    yield from run_tournament(
        instruction_input=instruction_input,
        criteria_input=criteria_input,
        n_gen=n_gen,
        num_top_picks=num_top_picks,
        pool_size=pool_size,
        max_workers=max_workers,
        api_base=api_base,
        api_token=api_token,
        enable_score_filter=enable_score_filter,
        enable_pairwise_filter=enable_pairwise_filter,
    )


# The order of ``inputs`` must match the parameters of ``_run_from_ui``.
demo = gr.Interface(
    fn=_run_from_ui,
    inputs=[
        gr.Textbox(value=API_BASE_DEFAULT, label="API Base Path"),
        # Left blank on purpose so the configured key is never sent to the browser.
        gr.Textbox(value="", label="API Token", type="password"),
        gr.Textbox(lines=10, label="Instruction"),
        gr.Textbox(lines=5, label="Criteria (comma separated)"),
        gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
        gr.Number(value=POOL_SIZE_DEFAULT, label="Top Picks Score Filter"),
        gr.Number(value=NUM_TOP_PICKS_DEFAULT, label="Top Picks Pairwise"),
        gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers"),
        gr.Checkbox(value=SCORE_FILTER_DEFAULT, label="Enable Score Filter"),
        gr.Checkbox(value=PAIRWISE_FILTER_DEFAULT, label="Enable Pairwise Filter"),
    ],
    outputs=[
        gr.Textbox(lines=10, label="Process"),
        gr.Plot(label="Score Distribution"),
        gr.Textbox(lines=50, label="Top picks"),
    ],
    description="Generate multiple completions and use score and pairwise filters to find the best answers.",
)
|
| 166 |
|
| 167 |
if __name__ == "__main__":
|