File size: 15,211 Bytes
c29b692
e4a181a
4ccee12
5831b86
 
66f49ec
249284d
a01815f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2584782
a959b3c
41ec98d
1d9f28c
a959b3c
cbe88a5
 
 
 
249284d
 
 
e4a181a
 
 
02aebba
 
3404ee0
 
 
0d1b4d4
295a884
 
 
 
 
 
 
4ccee12
295a884
 
 
 
4ccee12
295a884
 
 
 
 
 
4ccee12
cbe88a5
87cbfc7
 
249284d
 
 
02aebba
 
 
cbe88a5
 
 
 
87cbfc7
cbe88a5
 
 
02aebba
 
3404ee0
 
 
b936324
 
cbe88a5
2584782
c0bf2b8
753833c
 
 
 
02aebba
 
 
 
 
 
249284d
e4a181a
249284d
e4a181a
249284d
 
 
 
 
 
cbe88a5
 
02aebba
 
 
 
3404ee0
 
 
 
 
 
fa5956e
 
 
 
 
c0bf2b8
66f49ec
e6b4ba0
66f49ec
37e55ed
 
50d3606
c2b0ecd
50d3606
b936324
37e55ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50d3606
b241674
50d3606
fa5956e
 
b241674
 
50d3606
c0bf2b8
 
4ccee12
e6b4ba0
fa5956e
37e55ed
249284d
 
 
 
 
02aebba
3404ee0
37e55ed
249284d
37e55ed
c0bf2b8
37e55ed
b241674
2584782
 
c94e158
cbe88a5
b241674
 
 
 
37e55ed
 
 
 
 
 
 
 
02aebba
 
3404ee0
fa5956e
37e55ed
c595827
37e55ed
b241674
295a884
c2b0ecd
c595827
c2b0ecd
 
 
 
 
 
 
 
 
 
cbe88a5
1d9f28c
cbe88a5
a01815f
 
c2b0ecd
 
 
 
a01815f
cbe88a5
 
 
 
 
b241674
 
cbe88a5
 
 
 
b936324
 
 
37e55ed
 
 
 
 
 
 
 
02aebba
 
3404ee0
fa5956e
37e55ed
 
 
50d3606
295a884
b936324
 
 
cbe88a5
b936324
 
 
 
 
 
 
 
cbe88a5
b936324
 
a01815f
cbe88a5
 
 
b936324
 
 
 
 
 
 
 
 
a01815f
b936324
cbe88a5
1d9f28c
cbe88a5
b936324
e6b4ba0
 
 
eaad301
e6b4ba0
50d3606
 
b936324
c2b0ecd
b936324
cbe88a5
 
b936324
e6b4ba0
2584782
 
 
 
ca4a97e
 
249284d
 
 
cca2d14
 
02aebba
57d022a
1633c5a
753833c
cbe88a5
 
 
 
 
02aebba
 
e4a181a
 
 
 
 
2584782
c0bf2b8
 
66f49ec
e6b4ba0
cbe88a5
37e55ed
cbe88a5
 
2584782
3d7317c
5831b86
2584782
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
from dotenv import load_dotenv
# load_dotenv("./local.env",override=True)
import os, json, re, ast, gradio as gr
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import matplotlib.pyplot as plt
from tournament_utils import generate_players, prompt_score, prompt_pairwise
import time


class SimpleProgress:
    """Minimal progress counter that formats a ``current/total - ETA`` string.

    The ETA is a linear extrapolation (average time per completed step
    multiplied by the steps remaining). Not thread-safe; in this file
    ``step`` is only ever called from the coordinating generator, never
    from worker threads.
    """

    def __init__(self, total: int, prefix: str = "Progress"):
        self.total = total        # expected number of steps
        self.prefix = prefix      # label shown before the counter
        self.start = time.time()  # wall-clock start for the ETA estimate
        self.count = 0            # steps completed so far

    def step(self) -> str:
        """Record one completed step and return a progress/ETA line."""
        self.count += 1
        elapsed = time.time() - self.start
        # count is always >= 1 here, so the division is safe. Clamp at 0 so
        # the ETA never goes negative if ``total`` was an underestimate and
        # ``count`` overshoots it.
        remaining = max(0.0, (elapsed / self.count) * (self.total - self.count))
        h, rem = divmod(int(remaining), 3600)
        m, s = divmod(rem, 60)
        # Omit the hours field entirely when it would be zero.
        if h:
            eta = f"{h:d}:{m:02d}:{s:02d}"
        else:
            eta = f"{m:02d}:{s:02d}"
        return f"{self.prefix} {self.count}/{self.total} - ETA {eta}"

NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 3))
POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 6))
MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 100))
NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS", 10))
API_BASE_DEFAULT = os.getenv("OPENAI_API_BASE", "")
API_TOKEN_DEFAULT = os.getenv("OPENAI_API_KEY", "")
SCORE_FILTER_DEFAULT = os.getenv("ENABLE_SCORE_FILTER", "true").lower() == "true"
PAIRWISE_FILTER_DEFAULT = os.getenv("ENABLE_PAIRWISE_FILTER", "true").lower() == "true"
GENERATE_MODEL_DEFAULT = os.getenv("GENERATE_MODEL", "gpt-4o-mini")
SCORE_MODEL_DEFAULT = os.getenv("SCORE_MODEL", "gpt-4o-mini")
PAIRWISE_MODEL_DEFAULT = os.getenv("PAIRWISE_MODEL", "gpt-4o-mini")
GENERATE_TEMPERATURE_DEFAULT = float(os.getenv("GENERATE_TEMPERATURE", "1.2"))
SCORE_TEMPERATURE_DEFAULT = float(os.getenv("SCORE_TEMPERATURE", "0.1"))
PAIRWISE_TEMPERATURE_DEFAULT = float(os.getenv("PAIRWISE_TEMPERATURE", "0.1"))
SCORE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_SCORE", "true").lower() == "true"
PAIRWISE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_PAIRWISE", "true").lower() == "true"
GENERATE_THINKING_DEFAULT = os.getenv("ENABLE_GENERATE_THINKING", "false").lower() == "true"
SCORE_THINKING_DEFAULT = os.getenv("ENABLE_SCORE_THINKING", "false").lower() == "true"
PAIRWISE_THINKING_DEFAULT = os.getenv("ENABLE_PAIRWISE_THINKING", "false").lower() == "true"
CRITERIA_DEFAULT = "Factuality,Concise,Precision"

# Regex used to capture the final verdict from judge output
FINAL_VERDICT_RE = re.compile(r"(?im)^final verdict:\s*(.*)$")


def _parse_verdict(txt: str) -> dict:
    """Extract verdict information from judge output."""
    txt = re.sub(r"^```.*?\n|```$", "", txt, flags=re.DOTALL).strip()
    match = FINAL_VERDICT_RE.search(txt)
    if not match:
        return {}
    verdict = match.group(1).strip()
    try:
        verdict_val = ast.literal_eval(verdict)
    except Exception:
        verdict_val = verdict
    if isinstance(verdict_val, list):
        return {"scores": verdict_val}
    return {"winner": str(verdict_val)}

def run_tournament(
    api_base,
    api_token,
    generate_model,
    score_model,
    pairwise_model,
    generate_temperature,
    score_temperature,
    pairwise_temperature,
    instruction_input,
    criteria_input,
    n_gen,
    pool_size,
    num_top_picks,
    max_workers,
    enable_score_filter,
    enable_pairwise_filter,
    score_with_instruction,
    pairwise_with_instruction,
    generate_thinking,
    score_thinking,
    pairwise_thinking,
    score_explain=None,
    pairwise_explain=None,
):
    """Drive the full tournament and stream progress to the Gradio UI.

    Pipeline: generate ``n_gen`` candidate answers, optionally score each
    one and keep the best ``pool_size`` (score filter), optionally run a
    round-robin Elo tournament over the survivors (pairwise filter), then
    present the top ``num_top_picks`` answers.

    This is a generator: every ``yield`` emits a 5-tuple matching the
    Gradio outputs — (process log text, score-histogram figure, Elo
    bar-chart figure, top-picks text, token-usage summary).
    """
    instruction = instruction_input.strip()
    # Fall back to a default criteria set when the textbox is empty.
    criteria_list = [c.strip() for c in criteria_input.split(",") if c.strip()] or ["Factuality", "Instruction Following", "Precision"]
    # gr.Number delivers floats; coerce the count-like inputs to int.
    n_gen = int(n_gen)
    num_top_picks = int(num_top_picks)
    pool_size = int(pool_size)
    max_workers = int(max_workers)
    # None / empty means "not supplied by the UI" -> use env-derived defaults.
    if generate_temperature is None:
        generate_temperature = GENERATE_TEMPERATURE_DEFAULT
    if score_temperature is None:
        score_temperature = SCORE_TEMPERATURE_DEFAULT
    if pairwise_temperature is None:
        pairwise_temperature = PAIRWISE_TEMPERATURE_DEFAULT
    if not api_base:
        api_base = ""
    if not api_token:
        api_token = ""
    if not generate_model:
        generate_model = GENERATE_MODEL_DEFAULT
    if not score_model:
        score_model = SCORE_MODEL_DEFAULT
    if not pairwise_model:
        pairwise_model = PAIRWISE_MODEL_DEFAULT
    enable_score_filter = bool(enable_score_filter)
    enable_pairwise_filter = bool(enable_pairwise_filter)
    if score_with_instruction is None:
        score_with_instruction = SCORE_WITH_INSTRUCTION_DEFAULT
    if pairwise_with_instruction is None:
        pairwise_with_instruction = PAIRWISE_WITH_INSTRUCTION_DEFAULT
    if generate_thinking is None:
        generate_thinking = GENERATE_THINKING_DEFAULT
    if score_thinking is None:
        score_thinking = SCORE_THINKING_DEFAULT
    if pairwise_thinking is None:
        pairwise_thinking = PAIRWISE_THINKING_DEFAULT
    if score_explain is None:
        score_explain = False
    if pairwise_explain is None:
        pairwise_explain = False

    # Mutable state shared by the nested helpers below.
    process_log = []            # accumulated log lines shown in the UI
    hist_fig = None             # score histogram (set only when score filter runs)
    elo_fig = None              # Elo bar chart (set only when pairwise filter runs)
    top_picks_str = ""          # final answer text shown in "Top picks"
    prompt_tokens = 0           # running token totals across all API calls
    completion_tokens = 0
    score_outputs: list[str] = []        # raw judge outputs from the score stage
    raw_scores: dict[str, list] = {}     # player text -> per-criterion score list
    pairwise_outputs: list[str] = []     # raw judge outputs from pairwise matches
    match_cache: dict[tuple[str, str], str] = {}  # dedupes repeat pairings

    def add_usage(usage):
        # Accumulate token usage; tolerates both attribute-style objects
        # and plain dicts, and silently ignores missing/None usage.
        # NOTE(review): called from worker threads; int += relies on the
        # GIL for atomicity — confirm acceptable.
        nonlocal prompt_tokens, completion_tokens
        if not usage:
            return
        pt = getattr(usage, "prompt_tokens", None)
        if pt is None and isinstance(usage, dict):
            pt = usage.get("prompt_tokens")
        ct = getattr(usage, "completion_tokens", None)
        if ct is None and isinstance(usage, dict):
            ct = usage.get("completion_tokens")
        if pt:
            prompt_tokens += pt
        if ct:
            completion_tokens += ct

    def usage_str():
        # Human-readable token-usage summary for the UI.
        return (
            f"Prompt tokens: {prompt_tokens}\n"
            f"Completion tokens: {completion_tokens}\n"
            f"Total tokens: {prompt_tokens + completion_tokens}"
        )

    def log_completion(prefix: str, text: str, player_id: int | None = None):
        # Log a (possibly long) model completion on one line, truncated to 1000 chars.
        disp = text.replace("\n", " ")
        if len(disp) > 1000:
            disp = disp[:1000] + "…"
        if player_id is not None:
            prefix = f"{prefix}(ID {player_id}) "
        return log(f"{prefix}{disp}")
    def log(msg):
        # Append to the log and yield a full UI snapshot; callers use
        # ``yield from log(...)`` to forward it to Gradio.
        process_log.append(msg)
        tqdm.write(msg)
        yield "\n".join(process_log), hist_fig, elo_fig, top_picks_str, usage_str()
    # --- Stage 1: generate candidate answers ---
    yield from log("Generating answers …")
    all_players, usage = generate_players(
        instruction,
        n_gen,
        model=generate_model,
        api_base=api_base,
        api_key=api_token,
        temperature=generate_temperature,
        thinking=generate_thinking,
        return_usage=True,
    )
    add_usage(usage)
    yield from log(f"{len(all_players)} players generated")
    for i, p in enumerate(all_players, 1):
        yield from log_completion(f"Completion {i}: ", p, i)
    def criteria_block():
        # Numbered criteria list, e.g. "1) Factuality\n2) Concise".
        return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))

    # --- Stage 2 (optional): score each answer and keep the best pool_size ---
    if enable_score_filter:
        players_with_ids = list(enumerate(all_players, 1))

        def score(item):
            # Judge a single player; returns (average score, raw per-criterion list).
            # Runs on worker threads; appends to shared lists (atomic under the GIL).
            idx, player = item
            text, usage = prompt_score(
                instruction,
                criteria_list,
                criteria_block(),
                player,
                model=score_model,
                api_base=api_base,
                api_key=api_token,
                temperature=score_temperature,
                include_instruction=score_with_instruction,
                thinking=score_thinking,
                explain=score_explain,
                return_usage=True,
            )
            add_usage(usage)
            score_outputs.append((idx, text))
            data = _parse_verdict(text)
            raw_vals = None
            if "scores" in data and isinstance(data["scores"], list):
                # List verdict: average the per-criterion scores.
                raw_vals = data["scores"]
                avg = sum(raw_vals) / len(raw_vals) if raw_vals else 0.0
            else:
                # Otherwise try a single numeric score; unparsable -> 0.0.
                try:
                    avg = float(data.get("score", 0))
                    raw_vals = [avg]
                except Exception:
                    avg = 0.0
                    raw_vals = None
            return avg, raw_vals

        yield from log("Histogram generating")
        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            prog = SimpleProgress(len(all_players), "Scoring")
            scores = {}
            # ex.map preserves input order, so zip re-pairs results with players.
            for (idx, p), (s_val, raw_val) in zip(players_with_ids, ex.map(score, players_with_ids)):
                scores[p] = s_val
                if raw_val is not None:
                    raw_scores[p] = raw_val
                yield from log(prog.step())
        hist_fig = plt.figure()
        plt.hist(list(scores.values()), bins=10)
        yield from log("Histogram generated")
        # Keep the pool_size highest-scoring answers for the pairwise stage.
        top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
        yield from log(f"Filtered to {len(top_players)} players with best scores")
        for i, (idx, txt) in enumerate(score_outputs, 1):
            yield from log_completion(f"Score completion {i}: ", txt, idx)
    else:
        # Score filter disabled: every generated answer advances.
        top_players = all_players
    # --- Stage 3 (optional): round-robin pairwise matches with Elo ratings ---
    if enable_pairwise_filter:
        def play(a, b):
            # Run (or reuse) one pairwise match; returns the winning player text.
            # The sorted key makes (a, b) and (b, a) share a cache entry.
            key = tuple(sorted((a, b)))
            if key in match_cache:
                return match_cache[key]
            text, usage = prompt_pairwise(
                instruction,
                criteria_block(),
                a,
                b,
                model=pairwise_model,
                api_base=api_base,
                api_key=api_token,
                temperature=pairwise_temperature,
                include_instruction=pairwise_with_instruction,
                thinking=pairwise_thinking,
                explain=pairwise_explain,
                return_usage=True,
            )
            add_usage(usage)
            pairwise_outputs.append(text)
            # Any verdict other than "B" (including unparsable) counts as a win for A.
            winner_label = _parse_verdict(text).get("winner", "A")
            winner = a if winner_label == "A" else b
            match_cache[key] = winner
            return winner

        def all_pairs(players):
            # Yield every unordered pair once (round-robin schedule).
            for i in range(len(players)):
                for j in range(i + 1, len(players)):
                    yield players[i], players[j]

        def rate(players, executor):
            # Generator: plays all matches concurrently, applies Elo updates as
            # each finishes, streams progress via ``yield from log``, and
            # returns the final {player: rating} mapping to the caller of
            # ``yield from rate(...)``.
            rating = {p: 1000.0 for p in players}
            pairs = list(all_pairs(players))
            futures = {executor.submit(play, a, b): (a, b) for a, b in pairs}
            prog = SimpleProgress(len(futures), "Elo matches")
            K = 32  # standard Elo K-factor
            for fut in as_completed(futures):
                a, b = futures[fut]
                winner = fut.result()
                loser = b if winner == a else a
                ra, rb = rating[a], rating[b]
                # Expected score of A under the Elo logistic model.
                ea = 1 / (1 + 10 ** ((rb - ra) / 400))
                eb = 1 - ea
                if winner == a:
                    rating[a] = ra + K * (1 - ea)
                    rating[b] = rb + K * (0 - eb)
                else:
                    rating[a] = ra + K * (0 - ea)
                    rating[b] = rb + K * (1 - eb)
                yield from log(prog.step())
            return rating

        yield from log("Pairwise generating")
        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            # ``yield from`` both forwards the UI updates and captures the
            # generator's return value (the rating dict).
            rating = yield from rate(top_players, ex)
        elo_fig = plt.figure()
        players_sorted = sorted(rating, key=rating.get, reverse=True)
        plt.bar(range(len(players_sorted)), [rating[p] for p in players_sorted])
        plt.xticks(range(len(players_sorted)), [str(i + 1) for i in range(len(players_sorted))])
        top_k = players_sorted[:num_top_picks]
        for i, txt in enumerate(pairwise_outputs, 1):
            yield from log_completion(f"Pairwise completion {i}: ", txt)
        # Each pick shows its Elo and, when available, its raw criterion scores.
        top_picks_str = "\n\n\n=====================================================\n\n\n".join(
            f"{p}\nElo: {rating[p]:.1f}" + (f"\nScore: {raw_scores.get(p)}" if p in raw_scores else "") for p in top_k
        )
    else:
        # Pairwise filter disabled: take the first num_top_picks survivors as-is.
        top_k = top_players[:num_top_picks]
        top_picks_str = "\n\n\n=====================================================\n\n\n".join(top_k)
    # Final snapshot with the "Done" marker appended to the log.
    yield "\n".join(process_log + ["Done"]), hist_fig, elo_fig, top_picks_str, usage_str()

# Gradio UI wiring. The ``inputs`` list is positional: its order must match
# the parameter order of ``run_tournament`` exactly. Likewise ``outputs``
# must match the 5-tuples the generator yields.
demo = gr.Interface(
    fn=run_tournament,
    inputs=[
        gr.Textbox(value=API_BASE_DEFAULT, label="API Base Path", info="Clone repos and run locally To change the API base path If you leave OPENAI_API_BASE blank, LiteLLM defaults to https://api.openai.com/v1."),
        # Deliberately blank by default so the key is never pre-filled in the UI.
        gr.Textbox(value="", label="OPENAI_API_KEY ", type="password"),
        gr.Textbox(value=GENERATE_MODEL_DEFAULT, label="Generation Model"),
        gr.Textbox(value=SCORE_MODEL_DEFAULT, label="Score Model"),
        gr.Textbox(value=PAIRWISE_MODEL_DEFAULT, label="Pairwise Model"),
        gr.Number(value=GENERATE_TEMPERATURE_DEFAULT, label="Generation Temperature"),
        gr.Number(value=SCORE_TEMPERATURE_DEFAULT, label="Score Temperature"),
        gr.Number(value=PAIRWISE_TEMPERATURE_DEFAULT, label="Pairwise Temperature"),
        gr.Textbox(lines=10, label="Instruction"),
        gr.Textbox(value=CRITERIA_DEFAULT, lines=5, label="Criteria (comma separated)"),
        gr.Number(value=NUM_GENERATIONS_DEFAULT, label="Number of Generations"),
        gr.Number(value=POOL_SIZE_DEFAULT, label="Top Picks Score Filter"),
        gr.Number(value=NUM_TOP_PICKS_DEFAULT, label="Top Picks Pairwise"),
        gr.Number(value=MAX_WORKERS_DEFAULT, label="Max Workers"),
        gr.Checkbox(value=SCORE_FILTER_DEFAULT, label="Enable Score Filter"),
        gr.Checkbox(value=PAIRWISE_FILTER_DEFAULT, label="Enable Pairwise Filter"),
        gr.Checkbox(value=SCORE_WITH_INSTRUCTION_DEFAULT, label="Pass Instruction to Score Model"),
        gr.Checkbox(value=PAIRWISE_WITH_INSTRUCTION_DEFAULT, label="Pass Instruction to Pairwise Model"),
        gr.Checkbox(value=GENERATE_THINKING_DEFAULT, label="Enable Thinking (Generate)", info="Enable Qwen3 think mode"),
        gr.Checkbox(value=SCORE_THINKING_DEFAULT, label="Enable Thinking (Score)" , info="Enable Qwen3 think mode"),
        gr.Checkbox(value=PAIRWISE_THINKING_DEFAULT, label="Enable Thinking (Pairwise)" , info="Enable Qwen3 think mode"),
        # These two map to score_explain / pairwise_explain (default None in the
        # function signature; the UI supplies True unless the user unticks them).
        gr.Checkbox(value=True, label="Enable Explain (Score)", info="Prompt LLM to think step by step"),
        gr.Checkbox(value=True, label="Enable Explain (Pairwise)", info="Prompt LLM to think step by step"),
    ],
    outputs=[
        gr.Textbox(lines=10, label="Process"),
        gr.Plot(label="Score Distribution"),
        gr.Plot(label="Elo Ratings"),
        gr.Textbox(lines=50, label="Top picks"),
        gr.Textbox(lines=5, label="Token Usage"),
    ],
    description="Generate multiple completions and use score and pairwise filters to find the best answers.",
)

def _main() -> None:
    """Script entry point: start the Gradio web UI."""
    demo.launch()


if __name__ == "__main__":
    _main()