File size: 6,536 Bytes
4b7eae4
 
 
 
 
 
 
 
 
 
 
c7c2e28
 
4b7eae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7c2e28
4b7eae4
c7c2e28
 
4b7eae4
 
 
 
 
 
 
 
 
 
c7c2e28
4b7eae4
c7c2e28
 
 
 
4b7eae4
 
 
 
 
 
 
 
c7c2e28
 
4b7eae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7c2e28
4b7eae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from __future__ import annotations

import json
import os
import random
from collections import Counter, defaultdict
from datetime import datetime, timezone
from pathlib import Path

import gradio as gr

from storage import append_vote, load_vote_rows, new_vote_id


# Directory containing this module; the asset paths below are resolved against it.
ROOT = Path(__file__).resolve().parent
# Clip catalogue read by load_manifest().
MANIFEST_PATH = ROOT.joinpath("manifest.json")
# Vote log location passed to the storage helpers.
VOTES_PATH = ROOT.joinpath("votes.csv")


def load_manifest() -> list[dict]:
    if not MANIFEST_PATH.exists():
        raise FileNotFoundError("manifest.json is missing. Run scripts/tts_bakeoff/arena_assets.py before deploying.")
    return json.loads(MANIFEST_PATH.read_text())


# Loaded once at import time; a missing manifest aborts startup with FileNotFoundError.
CLIPS = load_manifest()


def sample_options() -> list[str]:
    """Return every distinct ``sample_id`` present in ``CLIPS``, sorted ascending."""
    distinct_ids = set()
    for entry in CLIPS:
        distinct_ids.add(entry["sample_id"])
    return sorted(distinct_ids)


def choose_pair(sample_id: str = "short_dialogue") -> tuple[str, str, dict, str]:
    """Pick two clips of one sample that come from different models.

    Args:
        sample_id: Only clips whose ``sample_id`` equals this value are eligible.

    Returns:
        ``(left_path, right_path, state, status)``: both clip paths resolved
        against ``ROOT``, a dict mapping ``"left"``/``"right"`` to the chosen
        clip dicts, and the status text.

    Raises:
        gr.Error: If the sample has no two clips from different models.
    """
    pool = [entry for entry in CLIPS if entry["sample_id"] == sample_id]

    # The clip_id ordering keeps each unordered pair exactly once.
    matchups = []
    for first in pool:
        for second in pool:
            if not first["clip_id"] < second["clip_id"]:
                continue
            if first["sample_id"] != second["sample_id"]:
                continue
            if first["model_id"] == second["model_id"]:
                continue
            matchups.append((first, second))

    if not matchups:
        raise gr.Error("Not enough clips for this sample.")

    first, second = random.choice(matchups)
    # Coin flip so the side a clip appears on is not tied to its clip_id order.
    left, right = (second, first) if random.random() < 0.5 else (first, second)

    left_path = str(ROOT / left["path"])
    right_path = str(ROOT / right["path"])
    return left_path, right_path, {"left": left, "right": right}, "Dialogue sample"


def vote(
    choice: str,
    state: dict | None,
    voter_id: str,
    notes: str,
    sample_id: str,
) -> tuple[str, str, dict, str, str]:
    """Record one vote for the current pair and draw the next pair.

    Args:
        choice: ``"left"`` or ``"right"`` credits that side's model; any other
            value (the UI sends ``"tie"`` and ``"both_bad"``) records no winner.
        state: The ``{"left": clip, "right": clip}`` dict returned by
            ``choose_pair``; falsy when no pair has been loaded yet.
        voter_id: Free-text voter identifier; ``None`` is treated as empty.
        notes: Optional free-text notes; ``None`` is treated as empty.
        sample_id: Sample from which the next pair is drawn.

    Returns:
        ``(left_path, right_path, next_state, status, "")``. The trailing empty
        string is wired to the notes textbox, which clears it.

    Raises:
        gr.Error: If ``state`` is falsy, or if ``choose_pair`` cannot find a
            next pair.
    """
    if not state:
        raise gr.Error("Load a pair before voting.")

    left = state["left"]
    right = state["right"]
    # Only an explicit side pick names a winning model.
    winner_by_choice = {"left": left["model_id"], "right": right["model_id"]}
    winner_model_id = winner_by_choice.get(choice, "")

    now = datetime.now(timezone.utc)
    row = {
        "timestamp": now.isoformat(timespec="seconds"),
        "vote_id": new_vote_id(now),
        # Text inputs can arrive as None (e.g. from API callers); don't crash on .strip().
        "voter_id": (voter_id or "").strip(),
        "sample_id": left["sample_id"],
        "left_clip_id": left["clip_id"],
        "right_clip_id": right["clip_id"],
        "left_model_id": left["model_id"],
        "right_model_id": right["model_id"],
        "winner": choice,
        "winner_model_id": winner_model_id,
        "notes": (notes or "").strip(),
    }
    # A truthy return value from append_vote is reported as a persistence problem.
    persistence_error = append_vote(VOTES_PATH, row)
    audio_a, audio_b, next_state, pair_status = choose_pair(sample_id)
    if persistence_error:
        save_status = f"Vote saved locally; persistent log unavailable ({persistence_error})."
    else:
        save_status = "Vote saved to persistent log."
    return audio_a, audio_b, next_state, f"{save_status} {pair_status}", ""


def load_results(admin_code: str) -> tuple[list[list], str | None]:
    """Build the per-model results table for the admin panel.

    Args:
        admin_code: Code typed into the admin textbox; ``None`` is treated as
            an empty code.

    Returns:
        ``(table, votes_file)``. ``table`` has one row per model seen in the
        votes, sorted by model id: ``[label, model_id, wins, appearances,
        win_rate]``. ``votes_file`` is ``str(VOTES_PATH)`` when the source
        reported by ``load_vote_rows`` contains ``"local"``, otherwise
        ``None``. With no votes the result is ``([], None)``.

    Raises:
        gr.Error: If ``ARENA_ADMIN_CODE`` is unset or empty, or if the supplied
            code does not match it.
    """
    import hmac  # function-scope import keeps this block self-contained

    expected = os.getenv("ARENA_ADMIN_CODE")
    if not expected:
        raise gr.Error("Admin results are disabled until ARENA_ADMIN_CODE is set.")

    # Compare in constant time so response timing does not leak the code.
    # Bytes are used because compare_digest rejects non-ASCII str arguments;
    # surrogatepass keeps malformed input from raising instead of being rejected.
    supplied = (admin_code or "").encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied, expected.encode("utf-8", "surrogatepass")):
        raise gr.Error("Wrong admin code.")

    rows, source = load_vote_rows(VOTES_PATH)
    if not rows:
        return [], None

    labels = {clip["model_id"]: clip["label"] for clip in CLIPS}
    # Rows with an empty winner_model_id credit no model.
    wins = Counter(row["winner_model_id"] for row in rows if row["winner_model_id"])
    appearances = Counter()
    for row in rows:
        appearances[row["left_model_id"]] += 1
        appearances[row["right_model_id"]] += 1

    # Every key in `appearances` has a count of at least 1, so the division is safe.
    table = [
        [
            labels.get(model_id, model_id),
            model_id,
            wins[model_id],
            count,
            round(wins[model_id] / count, 3),
        ]
        for model_id, count in sorted(appearances.items())
    ]
    votes_file = str(VOTES_PATH) if "local" in source else None
    return table, votes_file


# UI definition: components are created in layout order, then events are wired.
with gr.Blocks(title="Sonic Caucus") as demo:
    gr.Markdown("# Sonic Caucus")
    gr.Markdown("Blind pairwise listening test. Vote on the better clip without model names.")

    # Holds the {"left": ..., "right": ...} dict produced by choose_pair.
    pair_state = gr.State()
    with gr.Row():
        voter_id = gr.Textbox(label="Voter ID", placeholder="Name or initials", scale=2)
        # Hidden selector: the sample stays at its initial value in the UI.
        sample_id = gr.Radio(sample_options(), value="short_dialogue", label="Sample", visible=False)
        next_pair = gr.Button("Load random pair", variant="primary", scale=1)

    status = gr.Markdown()
    with gr.Row():
        audio_a = gr.Audio(label="Clip A", type="filepath", interactive=False)
        audio_b = gr.Audio(label="Clip B", type="filepath", interactive=False)

    notes = gr.Textbox(label="Notes", lines=2, placeholder="Optional")

    with gr.Row():
        vote_a = gr.Button("A better")
        vote_b = gr.Button("B better")
        vote_tie = gr.Button("Tie")
        vote_bad = gr.Button("Both bad")

    next_pair.click(
        choose_pair,
        inputs=[sample_id],
        outputs=[audio_a, audio_b, pair_state, status],
    )

    def _vote_callback(choice_key):
        """Return a click callback that records ``choice_key`` as the vote."""
        # Binding through a factory avoids the late-binding closure pitfall in the loop below.
        return lambda state, voter, note, sample: vote(choice_key, state, voter, note, sample)

    # All four vote buttons share the same inputs/outputs and differ only in the recorded choice.
    _vote_buttons = (
        (vote_a, "left"),
        (vote_b, "right"),
        (vote_tie, "tie"),
        (vote_bad, "both_bad"),
    )
    for _button, _choice_key in _vote_buttons:
        _button.click(
            _vote_callback(_choice_key),
            inputs=[pair_state, voter_id, notes, sample_id],
            outputs=[audio_a, audio_b, pair_state, status, notes],
        )

    with gr.Accordion("Admin results", open=False):
        admin_code = gr.Textbox(label="Admin code", type="password")
        load_results_button = gr.Button("Load results")
        results_table = gr.Dataframe(
            headers=["Model", "Model ID", "Wins", "Appearances", "Win rate"],
            datatype=["str", "str", "number", "number", "number"],
            interactive=False,
        )
        votes_file = gr.File(label="Votes CSV")
        load_results_button.click(
            load_results,
            inputs=[admin_code],
            outputs=[results_table, votes_file],
        )


if __name__ == "__main__":
    # NOTE(review): when ARENA_PASSWORD is unset the login falls back to the
    # literal below, which is visible in source — confirm deployments set it.
    arena_password = os.getenv("ARENA_PASSWORD", "voice-arena")
    demo.launch(auth=("team", arena_password))