File size: 8,932 Bytes
65600d1
 
 
 
087ec71
 
65600d1
 
b21e8c5
d42fccb
 
349e3bf
b21e8c5
 
 
65600d1
 
d42fccb
 
 
 
 
 
 
65600d1
 
d42fccb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8efe5a4
65600d1
8efe5a4
d42fccb
86d12ed
b44206b
 
 
8efe5a4
 
86d12ed
8efe5a4
86d12ed
 
d42fccb
d3b19b7
 
 
 
 
86d12ed
 
d42fccb
6bd8e43
9461a66
 
 
308a36f
d3b19b7
 
 
 
 
 
86d12ed
 
6bd8e43
65600d1
 
5609e0c
 
 
 
 
 
afef5d4
 
 
 
5609e0c
 
65600d1
 
 
86d12ed
c8edbbe
d42fccb
 
 
 
c8edbbe
 
 
d42fccb
86d12ed
 
d42fccb
86d12ed
b303c51
d42fccb
 
b303c51
d42fccb
 
b303c51
86d12ed
b303c51
 
8efe5a4
86d12ed
d42fccb
b44206b
 
8efe5a4
b44206b
 
72c0ffe
8efe5a4
afef5d4
 
 
8efe5a4
86d12ed
d3b19b7
5609e0c
d42fccb
86d12ed
 
 
 
 
d3b19b7
b21e8c5
 
d3b19b7
 
b21e8c5
afef5d4
 
 
d3b19b7
b21e8c5
afef5d4
 
 
d3b19b7
86d12ed
d3b19b7
b21e8c5
afef5d4
 
 
d3b19b7
b21e8c5
afef5d4
 
 
86d12ed
 
d42fccb
 
 
 
 
 
 
 
 
 
349e3bf
86d12ed
c8edbbe
 
86d12ed
 
c8edbbe
86d12ed
 
 
 
d42fccb
65600d1
087ec71
 
 
 
5b712d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""
Gradio Space for batch_outputs demo.
Loads data from Hugging Face dataset AE-W/batch_outputs.
"""
import os

import gradio as gr

from dataset_loader import (
    list_samples_bin,
    list_samples_clap,
    list_samples_dasheng,
    get_nn_demo_paths,
    get_results_demo_paths,
)


def _method_from_view(view: str) -> str:
    """Return 'bin' | 'clap' | 'dasheng' from view label."""
    if "Bin" in view:
        return "bin"
    if "Dasheng" in view:
        return "dasheng"
    return "clap"


def _sample_choices(method: str) -> list[str]:
    """Return the available sample IDs for the given method key."""
    # Dispatch table; unknown methods fall back to the CLAP loader.
    loaders = {"bin": list_samples_bin, "dasheng": list_samples_dasheng}
    return loaders.get(method, list_samples_clap)()


# Initial selection shown on load: the first sample of the CLAP method.
DEFAULT_METHOD = "clap"
DEFAULT_SAMPLE_IDS = _sample_choices(DEFAULT_METHOD)
# First available ID, or None when the dataset yields no samples.
TOP1_ID = next(iter(DEFAULT_SAMPLE_IDS), None)


def build_nn_view(sample_id: str | None, method: str):
    """NN view: NN1-NN10 from baseline. Each: prompt + spec on top, BG/FG/Mix audios below."""
    if not sample_id:
        return (None,) * (10 * 5)
    data = get_nn_demo_paths(sample_id, top_k=10, method=method)
    out = []
    for i, nn in enumerate(data.get("nn_list", [])[:10]):
        prompt = nn.get("prompt", "") or ""
        out.append(f"**NN{i+1}:** {prompt}" if prompt else "")
        out.extend([nn.get("spec"), nn.get("bg_wav"), nn.get("fg_wav"), nn.get("m_wav")])
    while len(out) < 50:
        out.append(None)
    return tuple(out[:50])


def build_results_view(sample_id: str | None, method: str):
    """
    Results view: 3 blocks. Per block:
    - Row1: Gaussian | Youtube spec + their BG/FG/Mix
    - Row2: Ours | NN baseline spec + their BG/FG/Mix
    """
    if not sample_id:
        return (None,) * (3 * (1 + 4 * 4))
    data = get_results_demo_paths(sample_id, method=method)
    out = []
    for i in range(1, 4):
        block = data.get(f"block{i}", {})
        prompt = block.get("prompt", "") or ""
        out.append(f"**NN{i}:** {prompt}" if prompt else "")
        # Top row: Gaussian, Youtube
        for key in ("baseline_gaussian", "baseline_youtube"):
            b = block.get(key, {})
            out.extend([b.get("spec"), b.get("bg_wav"), b.get("fg_wav"), b.get("m_wav")])
        # Bottom row: Ours, NN baseline (Original)
        for key in ("ours", "baseline_original"):
            b = block.get(key, {})
            out.extend([b.get("spec"), b.get("bg_wav"), b.get("fg_wav"), b.get("m_wav")])
    return tuple(out)


# Top-level UI: one Blocks app containing two mutually exclusive views
# (NN view / Results view) toggled by the view radio below.
with gr.Blocks(
    title="NearestNeighbor Audio Demo",
    css="""
    .gradio-container { max-width: 1400px; }
    /* Results view: force all 4 spec images (Gaussian, Youtube, Ours, NN baseline) to same size */
    #results-column img { width: 700px !important; height: 280px !important; object-fit: contain; }
    /* Reduce audio player row height (BG/FG/Mix) */
    .compact-audio .gr-form { min-height: 0 !important; }
    .compact-audio > div { min-height: 0 !important; max-height: 72px !important; }
    .compact-audio audio { max-height: 48px !important; }
    """,
) as app:
    gr.Markdown("# NearestNeighbor Audio Demo")
    gr.Markdown("Data from [AE-W/batch_outputs](https://huggingface.co/datasets/AE-W/batch_outputs)")

    # View selector: 2 views x 3 methods. _method_from_view parses the method
    # back out of these labels, so labels and parser must stay in sync.
    view_radio = gr.Radio(
        choices=[
            "Nearest Neighbor (Bin)",
            "Results (Bin)",
            "Nearest Neighbor (Clap)",
            "Results (Clap)",
            "Nearest Neighbor (Dasheng)",
            "Results (Dasheng)",
        ],
        value="Nearest Neighbor (Clap)",
        label="View",
    )
    noise_dd = gr.Dropdown(choices=DEFAULT_SAMPLE_IDS, value=TOP1_ID, label="Noise (ID)")

    gr.Markdown("""
**Three prompt-search methods**: **Bin** | **Clap** | **Dasheng**. Each combines `batch_outputs_*` and `generated_noises_*` from the dataset.

**How to read the IDs**
- **Numeric IDs** (e.g. `00_000357`) come from batch_outputs (SONYC/UrbanSound8k).
- **Long prompt-like IDs** (e.g. `a_bulldozer_moving_gravel_...`) come from generated_noises.

**Audio labels**: **BG** = background noise | **FG** = generated foreground | **Mix** = BG + FG
""")

    # ---- NN View: NN1-NN10, each: spec on top, BG/FG/Mix audios below ----
    # nn_outputs collects 5 components per neighbor (markdown, image, 3 audios)
    # in the exact order build_nn_view emits its values.
    with gr.Column(visible=True) as nn_col:
        nn_section_title = gr.Markdown("### Nearest Neighbor (Clap): Baseline outputs (top 10 prompts)")
        nn_outputs = []
        for i in range(10):
            with gr.Group():
                nn_p_md = gr.Markdown(value="")
                nn_outputs.append(nn_p_md)
                nn_img = gr.Image(label=f"NN{i+1}", show_label=True, height=480)
                nn_outputs.append(nn_img)
                nn_bg = gr.Audio(label="BG", show_label=True, elem_classes=["compact-audio"])
                nn_fg = gr.Audio(label="FG", show_label=True, elem_classes=["compact-audio"])
                nn_m = gr.Audio(label="Mix", show_label=True, elem_classes=["compact-audio"])
                nn_outputs.extend([nn_bg, nn_fg, nn_m])

    # ---- Results View: 3 prompts, each with 2 rows (Gaussian|Youtube, Ours|NN baseline) ----
    # res_outputs order (markdown, then spec+BG/FG/Mix per variant in
    # Gaussian, Youtube, Ours, NN-baseline order) must match build_results_view.
    with gr.Column(visible=False, elem_id="results-column") as res_col:
        res_section_title = gr.Markdown("### Results (Clap): 3 baselines + Ours (top 3 prompts)")
        res_outputs = []
        for i in range(1, 4):
            with gr.Group():
                res_p_md = gr.Markdown(value="")
                res_outputs.append(res_p_md)
                # Row 1: Gaussian | Youtube (spec + BG/FG/Mix under each)
                # Fixed height & width for consistent display
                spec_size = {"height": 280, "width": 700}
                with gr.Row():
                    with gr.Column():
                        res_outputs.append(gr.Image(label="Gaussian", **spec_size))
                        res_outputs.append(gr.Audio(label="BG", elem_classes=["compact-audio"]))
                        res_outputs.append(gr.Audio(label="FG", elem_classes=["compact-audio"]))
                        res_outputs.append(gr.Audio(label="Mix", elem_classes=["compact-audio"]))
                    with gr.Column():
                        res_outputs.append(gr.Image(label="Youtube", **spec_size))
                        res_outputs.append(gr.Audio(label="BG", elem_classes=["compact-audio"]))
                        res_outputs.append(gr.Audio(label="FG", elem_classes=["compact-audio"]))
                        res_outputs.append(gr.Audio(label="Mix", elem_classes=["compact-audio"]))
                # Row 2: Ours | NN baseline (spec + BG/FG/Mix under each)
                with gr.Row():
                    with gr.Column():
                        res_outputs.append(gr.Image(label="Ours", **spec_size))
                        res_outputs.append(gr.Audio(label="BG", elem_classes=["compact-audio"]))
                        res_outputs.append(gr.Audio(label="FG", elem_classes=["compact-audio"]))
                        res_outputs.append(gr.Audio(label="Mix", elem_classes=["compact-audio"]))
                    with gr.Column():
                        res_outputs.append(gr.Image(label="NN baseline", **spec_size))
                        res_outputs.append(gr.Audio(label="BG", elem_classes=["compact-audio"]))
                        res_outputs.append(gr.Audio(label="FG", elem_classes=["compact-audio"]))
                        res_outputs.append(gr.Audio(label="Mix", elem_classes=["compact-audio"]))

    def on_change(sid: str | None, view: str):
        """Refresh both views for the selected view/sample.

        Returns updates in the same order as ``all_outputs`` below:
        NN title, 50 NN values, Results title, 51 Results values, then
        visibility updates for the two columns and the dropdown update.
        """
        method = _method_from_view(view)
        choices = _sample_choices(method)
        # Keep the dropdown selection valid when switching methods.
        if sid not in choices and choices:
            sid = choices[0]
        is_nn = view in ("Nearest Neighbor (Bin)", "Nearest Neighbor (Clap)", "Nearest Neighbor (Dasheng)")
        is_res = view in ("Results (Bin)", "Results (Clap)", "Results (Dasheng)")
        # Both views are rebuilt regardless of visibility so a later toggle
        # shows current data without another fetch.
        nn_vals = build_nn_view(sid, method)
        res_vals = build_results_view(sid, method)
        nn_title = f"### Nearest Neighbor ({method.capitalize()}): Baseline outputs (top 10 prompts)"
        res_title = f"### Results ({method.capitalize()}): 3 baselines + Ours (top 3 prompts)"
        dd_update = gr.update(choices=choices, value=sid)
        return (
            [gr.update(value=nn_title)] + list(nn_vals) + [gr.update(value=res_title)] + list(res_vals) +
            [gr.update(visible=is_nn), gr.update(visible=is_res), dd_update]
        )

    # Must mirror on_change's return order exactly, element for element.
    all_outputs = [nn_section_title] + nn_outputs + [res_section_title] + res_outputs + [nn_col, res_col, noise_dd]

    noise_dd.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
    view_radio.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)

    # Populate the default view once when the page loads.
    app.load(lambda: on_change(TOP1_ID, "Nearest Neighbor (Clap)"), outputs=all_outputs)

# Gradio only serves files from whitelisted paths; allow the HF hub cache
# (where the dataset's media files are downloaded to) so spectrograms and
# audio resolve. Honors HUGGINGFACE_HUB_CACHE when set.
_default_hub_cache = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
_hf_hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE", _default_hub_cache)
app.launch(allowed_paths=[_hf_hub_cache])