ping98k committed on
Commit
cbe88a5
·
1 Parent(s): f94af77

Refine interface layout and labels

Browse files
Files changed (2) hide show
  1. README.md +6 -2
  2. main.py +112 -63
README.md CHANGED
@@ -9,6 +9,10 @@ This project provides a small interface for running "tournaments" between langua
9
  - `POOL_SIZE`
10
  - `MAX_WORKERS`
11
  - `NUM_GENERATIONS`
 
 
 
 
12
  2. Install dependencies (example with `pip`):
13
  ```bash
14
  pip install gradio litellm python-dotenv tqdm matplotlib
@@ -17,7 +21,7 @@ This project provides a small interface for running "tournaments" between langua
17
  ```bash
18
  python main.py
19
  ```
20
- 4. Open the displayed local URL to provide an instruction and evaluation criteria.
21
 
22
- The interface will generate multiple answers, score them, and run a head-to-head tournament to find the best outputs.
23
 
 
9
  - `POOL_SIZE`
10
  - `MAX_WORKERS`
11
  - `NUM_GENERATIONS`
12
+ - `OPENAI_API_BASE`
13
+ - `OPENAI_API_KEY`
14
+ - `ENABLE_SCORE_FILTER`
15
+ - `ENABLE_PAIRWISE_FILTER`
16
  2. Install dependencies (example with `pip`):
17
  ```bash
18
  pip install gradio litellm python-dotenv tqdm matplotlib
 
21
  ```bash
22
  python main.py
23
  ```
24
+ 4. Open the displayed local URL. At the top of the page you can optionally override the API base path and token (the token field is blank by default). Additional settings let you configure score and pairwise filtering.
25
 
26
+ The interface will generate multiple answers, optionally filter them by score, and optionally run a pairwise tournament to select the best outputs.
27
 
main.py CHANGED
@@ -10,6 +10,10 @@ NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 5))
10
  POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 10))
11
  MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 10))
12
  NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS", 20))
 
 
 
 
13
 
14
  def _clean_json(txt):
15
  txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
@@ -18,13 +22,30 @@ def _clean_json(txt):
18
  except json.JSONDecodeError:
19
  return ast.literal_eval(txt)
20
 
21
- def run_tournament(instruction_input, criteria_input, n_gen, num_top_picks, pool_size, max_workers):
 
 
 
 
 
 
 
 
 
 
 
22
  instruction = instruction_input.strip()
23
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
24
  n_gen = int(n_gen)
25
  num_top_picks = int(num_top_picks)
26
  pool_size = int(pool_size)
27
  max_workers = int(max_workers)
 
 
 
 
 
 
28
  process_log = []
29
  hist_fig = None
30
  top_picks_str = ""
@@ -38,81 +59,109 @@ def run_tournament(instruction_input, criteria_input, n_gen, num_top_picks, pool
38
  def criteria_block():
39
  return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
40
 
41
- def score(player):
42
- data = _clean_json(prompt_score(instruction, criteria_block(), player))
43
- lst = data.get("score", data.get("scores", []))
44
- return sum(lst) / len(lst) if lst else 0.0
45
- yield from log("Scoring players …")
46
- with ThreadPoolExecutor(max_workers=max_workers) as ex:
47
- scores = {p: s for p, s in zip(all_players, list(tqdm(ex.map(score, all_players), total=len(all_players))))}
48
- hist_fig = plt.figure()
49
- plt.hist(list(scores.values()), bins=10)
50
- yield from log("Histogram generated")
51
- top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
52
- yield from log(f"Filtered to {len(top_players)} players with best scores")
53
- def play(a, b):
54
- winner_label = _clean_json(
55
- prompt_play(instruction, criteria_block(), a, b)
56
- ).get("winner", "A")
57
- return a if winner_label == "A" else b
58
- def tournament_round(pairs, executor):
59
- futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
60
- results = []
61
- for fut in tqdm(as_completed(futures), total=len(futures)):
62
- a, b = futures[fut]
63
- winner = fut.result()
64
- loser = b if winner == a else a
65
- results.append((winner, loser))
66
- return results
67
- def tournament(players, executor):
68
- lost_to = {}
69
- current = players[:]
70
- while len(current) > 1:
71
- pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
72
- round_results = tournament_round(pairs, executor)
73
- for w, l in round_results:
74
- lost_to[l] = w
75
- current = [w for w, _ in round_results]
76
- if len(players) % 2 == 1 and players[-1] not in current:
77
- current.append(players[-1])
78
- return current[0], lost_to
79
- def get_candidates(champion, lost_to):
80
- return [p for p, o in lost_to.items() if o == champion] + [champion]
81
- def playoff(candidates, executor):
82
- wins = {p: 0 for p in candidates}
83
- pairs = [(candidates[i], candidates[j]) for i in range(len(candidates)) for j in range(i + 1, len(candidates))]
84
- futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
85
- for fut in tqdm(as_completed(futures), total=len(futures)):
86
- wins[fut.result()] += 1
87
- return sorted(candidates, key=lambda p: wins[p], reverse=True)
88
- def get_top(players, executor):
89
- champion, lost_to = tournament(players, executor)
90
- runner_up = lost_to.get(champion)
91
- finalists = [champion] + ([runner_up] if runner_up else [])
92
- semifinalists = [p for p, o in lost_to.items() if o in finalists and p not in finalists]
93
- candidates = list(set(finalists + semifinalists + get_candidates(champion, lost_to)))
94
- return playoff(candidates, executor)[:num_top_picks]
95
- yield from log("Running tournament …")
96
- with ThreadPoolExecutor(max_workers=max_workers) as ex:
97
- top_k = get_top(top_players, ex)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
99
  yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str
100
 
101
  demo = gr.Interface(
102
  fn=run_tournament,
103
  inputs=[
 
 
104
  gr.Textbox(lines=10, label="Instruction"),
105
  gr.Textbox(lines=5, label="Criteria (comma separated)"),
106
  gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
107
- gr.Number(value=NUM_TOP_PICKS_DEFAULT, label="Top Picks (k)"),
108
- gr.Number(value=POOL_SIZE_DEFAULT, label="Filter Size"),
109
- gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers")
 
 
110
  ],
111
  outputs=[
112
  gr.Textbox(lines=10, label="Process"),
113
  gr.Plot(label="Score Distribution"),
114
- gr.Textbox(lines=50, label="Top picks")
115
- ]
 
116
  )
117
 
118
  if __name__ == "__main__":
 
10
  POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 10))
11
  MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 10))
12
  NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS", 20))
13
+ API_BASE_DEFAULT = os.getenv("OPENAI_API_BASE", "")
14
+ API_TOKEN_DEFAULT = os.getenv("OPENAI_API_KEY", "")
15
+ SCORE_FILTER_DEFAULT = os.getenv("ENABLE_SCORE_FILTER", "true").lower() == "true"
16
+ PAIRWISE_FILTER_DEFAULT = os.getenv("ENABLE_PAIRWISE_FILTER", "true").lower() == "true"
17
 
18
  def _clean_json(txt):
19
  txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
 
22
  except json.JSONDecodeError:
23
  return ast.literal_eval(txt)
24
 
25
+ def run_tournament(
26
+ instruction_input,
27
+ criteria_input,
28
+ n_gen,
29
+ num_top_picks,
30
+ pool_size,
31
+ max_workers,
32
+ api_base,
33
+ api_token,
34
+ enable_score_filter,
35
+ enable_pairwise_filter,
36
+ ):
37
  instruction = instruction_input.strip()
38
  criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
39
  n_gen = int(n_gen)
40
  num_top_picks = int(num_top_picks)
41
  pool_size = int(pool_size)
42
  max_workers = int(max_workers)
43
+ if api_base:
44
+ os.environ["OPENAI_API_BASE"] = api_base
45
+ if api_token:
46
+ os.environ["OPENAI_API_KEY"] = api_token
47
+ enable_score_filter = bool(enable_score_filter)
48
+ enable_pairwise_filter = bool(enable_pairwise_filter)
49
  process_log = []
50
  hist_fig = None
51
  top_picks_str = ""
 
59
  def criteria_block():
60
  return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
61
 
62
+ if enable_score_filter:
63
+ def score(player):
64
+ data = _clean_json(prompt_score(instruction, criteria_block(), player))
65
+ lst = data.get("score", data.get("scores", []))
66
+ return sum(lst) / len(lst) if lst else 0.0
67
+
68
+ yield from log("Scoring players …")
69
+ with ThreadPoolExecutor(max_workers=max_workers) as ex:
70
+ scores = {
71
+ p: s
72
+ for p, s in zip(
73
+ all_players,
74
+ list(tqdm(ex.map(score, all_players), total=len(all_players))),
75
+ )
76
+ }
77
+ hist_fig = plt.figure()
78
+ plt.hist(list(scores.values()), bins=10)
79
+ yield from log("Histogram generated")
80
+ top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
81
+ yield from log(f"Filtered to {len(top_players)} players with best scores")
82
+ else:
83
+ top_players = all_players
84
+ if enable_pairwise_filter:
85
+ def play(a, b):
86
+ winner_label = _clean_json(
87
+ prompt_play(instruction, criteria_block(), a, b)
88
+ ).get("winner", "A")
89
+ return a if winner_label == "A" else b
90
+
91
+ def tournament_round(pairs, executor):
92
+ futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
93
+ results = []
94
+ for fut in tqdm(as_completed(futures), total=len(futures)):
95
+ a, b = futures[fut]
96
+ winner = fut.result()
97
+ loser = b if winner == a else a
98
+ results.append((winner, loser))
99
+ return results
100
+
101
+ def tournament(players, executor):
102
+ lost_to = {}
103
+ current = players[:]
104
+ while len(current) > 1:
105
+ pairs = [(current[i], current[i + 1]) for i in range(0, len(current) - 1, 2)]
106
+ round_results = tournament_round(pairs, executor)
107
+ for w, l in round_results:
108
+ lost_to[l] = w
109
+ current = [w for w, _ in round_results]
110
+ if len(players) % 2 == 1 and players[-1] not in current:
111
+ current.append(players[-1])
112
+ return current[0], lost_to
113
+
114
+ def get_candidates(champion, lost_to):
115
+ return [p for p, o in lost_to.items() if o == champion] + [champion]
116
+
117
+ def playoff(candidates, executor):
118
+ wins = {p: 0 for p in candidates}
119
+ pairs = [
120
+ (candidates[i], candidates[j])
121
+ for i in range(len(candidates))
122
+ for j in range(i + 1, len(candidates))
123
+ ]
124
+ futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
125
+ for fut in tqdm(as_completed(futures), total=len(futures)):
126
+ wins[fut.result()] += 1
127
+ return sorted(candidates, key=lambda p: wins[p], reverse=True)
128
+
129
+ def get_top(players, executor):
130
+ champion, lost_to = tournament(players, executor)
131
+ runner_up = lost_to.get(champion)
132
+ finalists = [champion] + ([runner_up] if runner_up else [])
133
+ semifinalists = [p for p, o in lost_to.items() if o in finalists and p not in finalists]
134
+ candidates = list(set(finalists + semifinalists + get_candidates(champion, lost_to)))
135
+ return playoff(candidates, executor)[:num_top_picks]
136
+
137
+ yield from log("Running tournament …")
138
+ with ThreadPoolExecutor(max_workers=max_workers) as ex:
139
+ top_k = get_top(top_players, ex)
140
+ else:
141
+ top_k = top_players[:num_top_picks]
142
  top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
143
  yield "\n".join(process_log + ["Done"]), hist_fig, top_picks_str
144
 
145
  demo = gr.Interface(
146
  fn=run_tournament,
147
  inputs=[
148
+ gr.Textbox(value=API_BASE_DEFAULT, label="API Base Path"),
149
+ gr.Textbox(value="", label="API Token", type="password"),
150
  gr.Textbox(lines=10, label="Instruction"),
151
  gr.Textbox(lines=5, label="Criteria (comma separated)"),
152
  gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
153
+ gr.Number(value=POOL_SIZE_DEFAULT, label="Top Picks Score Filter"),
154
+ gr.Number(value=NUM_TOP_PICKS_DEFAULT, label="Top Picks Pairwise"),
155
+ gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers"),
156
+ gr.Checkbox(value=SCORE_FILTER_DEFAULT, label="Enable Score Filter"),
157
+ gr.Checkbox(value=PAIRWISE_FILTER_DEFAULT, label="Enable Pairwise Filter"),
158
  ],
159
  outputs=[
160
  gr.Textbox(lines=10, label="Process"),
161
  gr.Plot(label="Score Distribution"),
162
+ gr.Textbox(lines=50, label="Top picks"),
163
+ ],
164
+ description="Generate multiple completions and use score and pairwise filters to find the best answers.",
165
  )
166
 
167
  if __name__ == "__main__":