File size: 3,140 Bytes
541c6b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import string
from dataclasses import dataclass
from pathlib import Path
from typing import List

import argbind
import gradio as gr
from audiotools import preference as pr


# Command-line options for the listening-test app. Fields are exposed as CLI
# flags through argbind (bound without a prefix); see the __main__ guard.
@argbind.bind(without_prefix=True)
@dataclass
class Config:
    # Sample source handed to ``pr.Samples`` together with ``n_samples``.
    folder: str = None
    # CSV file that ``pr.save_result`` writes ratings to; ``main`` also passes
    # it to ``samples.filter_completed`` to skip samples a user already did.
    save_path: str = "results.csv"
    # Names of the conditions to rate. ``main`` creates one "Play A/B/..."
    # button and one rating slider per entry and calls ``len`` on it, so it
    # must be set (not None) before ``main`` runs.
    conditions: List[str] = None
    # Optional reference; when not None, ``main`` adds a "Play Reference"
    # button and passes it to ``samples.get_next_sample``.
    reference: str = None
    # NOTE(review): not read anywhere in this file's visible code — confirm
    # whether it is meant to seed sample selection/ordering.
    seed: int = 0
    # Forwarded to ``app.launch(share=...)``.
    share: bool = False
    # Forwarded to ``pr.Samples(..., n_samples=...)``.
    n_samples: int = 10


def get_text(wav_file: str) -> str:
    """Return the transcript for ``wav_file`` wrapped in a centred HTML ``<div>``.

    The transcript is read from the file at the same path as ``wav_file`` but
    with a ``.txt`` suffix. If no such regular file exists, the ``<div>`` is
    empty.

    Parameters
    ----------
    wav_file : str
        Path to an audio file (anything accepted by ``pathlib.Path``).

    Returns
    -------
    str
        An HTML snippet containing the transcript text.
    """
    txt_file = Path(wav_file).with_suffix(".txt")
    # Decode explicitly as UTF-8 instead of the platform's locale encoding so
    # non-ASCII transcripts read the same on every machine. ``is_file`` (not
    # ``exists``) avoids crashing when a directory happens to match the name.
    # NOTE(review): the text is inserted into HTML unescaped, so markup inside
    # the .txt file is rendered; use ``html.escape`` if that is unwanted.
    txt = txt_file.read_text(encoding="utf-8") if txt_file.is_file() else ""
    return f"""<div style="text-align:center;font-size:large;">{txt}</div>"""


def main(config: Config):
    """Build and launch the Gradio rating survey described by ``config``.

    The page shows a player with one "Play <letter>" button per condition
    (labelled ``A``, ``B``, ... in ``config.conditions`` order), plus a
    "Play Reference" button when ``config.reference`` is set, and one rating
    slider per condition under a ``pr.slider_mushra`` header. Each "Submit"
    saves the ratings for the sample just heard to ``config.save_path`` and
    loads the next sample.

    Parameters
    ----------
    config : Config
        Survey options (sample folder, conditions, output CSV, ...).

    Raises
    ------
    ValueError
        If ``config.conditions`` is None, or has more entries than there are
        single capital letters available to label them (26).
    """
    conditions = config.conditions
    # Validate up front: otherwise ``len(None)`` or an out-of-range letter
    # lookup would fail with an opaque error half-way through building the UI.
    if conditions is None:
        raise ValueError("`conditions` must be set to a list of condition names.")
    max_conditions = len(string.ascii_uppercase)
    if len(conditions) > max_conditions:
        raise ValueError(
            f"At most {max_conditions} conditions are supported "
            f"(one letter each), got {len(conditions)}."
        )

    save_path = config.save_path
    reference = config.reference
    # Initial slider position; sliders are reset to it after every submit.
    default_rating = 50

    with gr.Blocks() as app:
        samples = gr.State(pr.Samples(config.folder, n_samples=config.n_samples))

        player = pr.Player(app)
        player.create()
        if reference is not None:
            player.add("Play Reference")

        user = pr.create_tracker(app)
        ratings = []

        with gr.Row():
            txt = gr.HTML("")

        with gr.Row():
            gr.Button("Rate audio quality", interactive=False)
            with gr.Column(scale=8):
                gr.HTML(pr.slider_mushra)

        # One labelled play button + slider row per condition.
        for letter in string.ascii_uppercase[: len(conditions)]:
            with gr.Row().style(equal_height=True):
                player.add(f"Play {letter}")
                with gr.Column(scale=9):
                    ratings.append(
                        gr.Slider(value=default_rating, interactive=True)
                    )

        def build(user, samples, *scores):
            """Save the scores for the sample just rated, then load the next one.

            Returns the values for the ``outputs`` wired up in ``begin.click``,
            in order: player updates, one reset per slider, then the submit
            button update, the samples state, the progress HTML and the
            transcript HTML.
            """
            # Filter out samples the user has done already, by looking in the CSV.
            samples.filter_completed(user, save_path)

            # Write results for the previous sample to CSV. Skipped while
            # ``samples.current`` is 0 (presumably before any sample was shown).
            if samples.current > 0:
                # With a reference, the first entry of ``samples.order`` has no
                # slider (presumably the reference itself), so it is skipped.
                start_idx = 1 if reference is not None else 0
                name = samples.names[samples.current - 1]
                result = {"sample": name, "user": user}
                result.update(zip(samples.order[start_idx:], scores))
                pr.save_result(result, save_path)

            updates, done, pbar = samples.get_next_sample(reference, conditions)
            # The first player update's value is used as the audio path whose
            # sibling .txt transcript is displayed.
            wav_file = updates[0]["value"]
            txt_update = gr.update(value=get_text(wav_file))

            return (
                updates
                + [gr.update(value=default_rating) for _ in scores]
                + [done, samples, pbar, txt_update]
            )

        progress = gr.HTML()
        begin = gr.Button("Submit", elem_id="start-survey")
        begin.click(
            fn=build,
            inputs=[user, samples] + ratings,
            outputs=player.to_list() + ratings + [begin, samples, progress, txt],
        ).then(None, _js=pr.reset_player)

        app.launch(share=config.share)


if __name__ == "__main__":
    # Parse the command line, then construct Config while the parsed values
    # are in scope so the argbind-bound dataclass can pick them up.
    cli_args = argbind.parse_args()
    with argbind.scope(cli_args):
        main(Config())