Spaces:
Sleeping
Sleeping
| """ | |
| Gradio Space for batch_outputs demo. | |
| Loads data from Hugging Face dataset AE-W/batch_outputs. | |
| """ | |
| import os | |
| import gradio as gr | |
| from dataset_loader import ( | |
| list_samples_bin, | |
| list_samples_clap, | |
| list_samples_dasheng, | |
| get_nn_demo_paths, | |
| get_results_demo_paths, | |
| ) | |
| def _method_from_view(view: str) -> str: | |
| """Return 'bin' | 'clap' | 'dasheng' from view label.""" | |
| if "Bin" in view: | |
| return "bin" | |
| if "Dasheng" in view: | |
| return "dasheng" | |
| return "clap" | |
def _sample_choices(method: str) -> list[str]:
    """Return the sample IDs listed by the loader for *method*.

    "bin" and "dasheng" select their own loaders; any other key uses the
    Clap loader.
    """
    loaders = {"bin": list_samples_bin, "dasheng": list_samples_dasheng}
    loader = loaders.get(method, list_samples_clap)
    return loader()
# Initial selection shown on page load: the "clap" method and its first
# listed sample ID (TOP1_ID is None when that method lists no samples).
DEFAULT_METHOD = "clap"
DEFAULT_SAMPLE_IDS = _sample_choices(DEFAULT_METHOD)
TOP1_ID = None
if DEFAULT_SAMPLE_IDS:
    TOP1_ID = DEFAULT_SAMPLE_IDS[0]
| def build_nn_view(sample_id: str | None, method: str): | |
| """NN view: NN1-NN10 from baseline. Each: prompt + spec on top, BG/FG/Mix audios below.""" | |
| if not sample_id: | |
| return (None,) * (10 * 5) | |
| data = get_nn_demo_paths(sample_id, top_k=10, method=method) | |
| out = [] | |
| for i, nn in enumerate(data.get("nn_list", [])[:10]): | |
| prompt = nn.get("prompt", "") or "" | |
| out.append(f"**NN{i+1}:** {prompt}" if prompt else "") | |
| out.extend([nn.get("spec"), nn.get("bg_wav"), nn.get("fg_wav"), nn.get("m_wav")]) | |
| while len(out) < 50: | |
| out.append(None) | |
| return tuple(out[:50]) | |
| def build_results_view(sample_id: str | None, method: str): | |
| """ | |
| Results view: 3 blocks. Per block: | |
| - Row1: Gaussian | Youtube spec + their BG/FG/Mix | |
| - Row2: Ours | NN baseline spec + their BG/FG/Mix | |
| """ | |
| if not sample_id: | |
| return (None,) * (3 * (1 + 4 * 4)) | |
| data = get_results_demo_paths(sample_id, method=method) | |
| out = [] | |
| for i in range(1, 4): | |
| block = data.get(f"block{i}", {}) | |
| prompt = block.get("prompt", "") or "" | |
| out.append(f"**NN{i}:** {prompt}" if prompt else "") | |
| # Top row: Gaussian, Youtube | |
| for key in ("baseline_gaussian", "baseline_youtube"): | |
| b = block.get(key, {}) | |
| out.extend([b.get("spec"), b.get("bg_wav"), b.get("fg_wav"), b.get("m_wav")]) | |
| # Bottom row: Ours, NN baseline (Original) | |
| for key in ("ours", "baseline_original"): | |
| b = block.get(key, {}) | |
| out.extend([b.get("spec"), b.get("bg_wav"), b.get("fg_wav"), b.get("m_wav")]) | |
| return tuple(out) | |
# View radio labels. _method_from_view derives the method key from each label.
VIEW_CHOICES = [
    "Nearest Neighbor (Bin)",
    "Results (Bin)",
    "Nearest Neighbor (Clap)",
    "Results (Clap)",
    "Nearest Neighbor (Dasheng)",
    "Results (Dasheng)",
]
NN_VIEWS = tuple(v for v in VIEW_CHOICES if v.startswith("Nearest Neighbor"))
RESULTS_VIEWS = tuple(v for v in VIEW_CHOICES if v.startswith("Results"))
# The initial view must match the method whose samples pre-populate the dropdown.
DEFAULT_VIEW = f"Nearest Neighbor ({DEFAULT_METHOD.capitalize()})"

_AUDIO_LABELS = ("BG", "FG", "Mix")
# Fixed spectrogram size in the Results view (mirrored by the #results-column CSS rule).
_SPEC_SIZE = {"height": 280, "width": 700}

_ABOUT_MD = """
**Three prompt-search methods**: **Bin** | **Clap** | **Dasheng**. Each combines `batch_outputs_*` and `generated_noises_*` from the dataset.

**How to read the IDs**

- **Numeric IDs** (e.g. `00_000357`) come from batch_outputs (SONYC/UrbanSound8k).
- **Long prompt-like IDs** (e.g. `a_bulldozer_moving_gravel_...`) come from generated_noises.

**Audio labels**: **BG** = background noise | **FG** = generated foreground | **Mix** = BG + FG
"""


def _nn_title(method: str) -> str:
    """Section heading for the Nearest Neighbor column."""
    return f"### Nearest Neighbor ({method.capitalize()}): Baseline outputs (top 10 prompts)"


def _res_title(method: str) -> str:
    """Section heading for the Results column."""
    return f"### Results ({method.capitalize()}): 3 baselines + Ours (top 3 prompts)"


def on_change(sid, view):
    """Refresh every output component for the selected sample and view.

    Args:
        sid: Currently selected sample ID (may be None or belong to another method).
        view: Selected view radio label.

    Returns:
        A list aligned one-to-one with ``all_outputs``: NN title, NN values,
        Results title, Results values, NN column visibility, Results column
        visibility, dropdown update.
    """
    method = _method_from_view(view)
    choices = _sample_choices(method)
    if sid not in choices:
        # Fall back to the method's first sample. If the method has no samples,
        # clear the selection rather than keep an ID from another method.
        sid = choices[0] if choices else None
    is_nn = view in NN_VIEWS
    is_res = view in RESULTS_VIEWS
    # Only fetch data for the visible view; the hidden one is cleared. Switching
    # views re-runs this handler, so nothing is lost by not prefetching.
    nn_vals = build_nn_view(sid if is_nn else None, method)
    res_vals = build_results_view(sid if is_res else None, method)
    return [
        gr.update(value=_nn_title(method)),
        *nn_vals,
        gr.update(value=_res_title(method)),
        *res_vals,
        gr.update(visible=is_nn),
        gr.update(visible=is_res),
        gr.update(choices=choices, value=sid),
    ]


with gr.Blocks(
    title="NearestNeighbor Audio Demo",
    css="""
.gradio-container { max-width: 1400px; }
/* Results view: force all 4 spec images (Gaussian, Youtube, Ours, NN baseline) to same size */
#results-column img { width: 700px !important; height: 280px !important; object-fit: contain; }
/* Reduce audio player row height (BG/FG/Mix) */
.compact-audio .gr-form { min-height: 0 !important; }
.compact-audio > div { min-height: 0 !important; max-height: 72px !important; }
.compact-audio audio { max-height: 48px !important; }
""",
) as app:
    gr.Markdown("# NearestNeighbor Audio Demo")
    gr.Markdown("Data from [AE-W/batch_outputs](https://huggingface.co/datasets/AE-W/batch_outputs)")
    view_radio = gr.Radio(choices=VIEW_CHOICES, value=DEFAULT_VIEW, label="View")
    noise_dd = gr.Dropdown(choices=DEFAULT_SAMPLE_IDS, value=TOP1_ID, label="Noise (ID)")
    gr.Markdown(_ABOUT_MD)

    # ---- NN view: NN1-NN10, each a prompt + spectrogram with BG/FG/Mix players below ----
    # Component order must match build_nn_view's output order.
    with gr.Column(visible=True) as nn_col:
        nn_section_title = gr.Markdown(_nn_title(DEFAULT_METHOD))
        nn_outputs = []
        for rank in range(1, 11):
            with gr.Group():
                nn_outputs.append(gr.Markdown(value=""))
                nn_outputs.append(gr.Image(label=f"NN{rank}", show_label=True, height=480))
                for audio_label in _AUDIO_LABELS:
                    nn_outputs.append(
                        gr.Audio(label=audio_label, show_label=True, elem_classes=["compact-audio"])
                    )

    # ---- Results view: 3 prompts, each with 2 rows (Gaussian|Youtube, Ours|NN baseline) ----
    # Component order must match build_results_view's output order.
    with gr.Column(visible=False, elem_id="results-column") as res_col:
        res_section_title = gr.Markdown(_res_title(DEFAULT_METHOD))
        res_outputs = []
        for _ in range(3):
            with gr.Group():
                res_outputs.append(gr.Markdown(value=""))
                for row_labels in (("Gaussian", "Youtube"), ("Ours", "NN baseline")):
                    with gr.Row():
                        for spec_label in row_labels:
                            with gr.Column():
                                res_outputs.append(gr.Image(label=spec_label, **_SPEC_SIZE))
                                for audio_label in _AUDIO_LABELS:
                                    res_outputs.append(
                                        gr.Audio(label=audio_label, elem_classes=["compact-audio"])
                                    )

    all_outputs = [nn_section_title, *nn_outputs, res_section_title, *res_outputs, nn_col, res_col, noise_dd]
    # `.input` fires only on user interaction. `.change` would also fire when
    # on_change itself rewrites the dropdown value, triggering a redundant reload.
    noise_dd.input(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
    view_radio.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
    app.load(lambda: on_change(TOP1_ID, DEFAULT_VIEW), outputs=all_outputs)
def _resolve_hf_hub_cache() -> str:
    """Return the Hugging Face Hub cache directory, following huggingface_hub's lookup order.

    Precedence: ``HF_HUB_CACHE``, then legacy ``HUGGINGFACE_HUB_CACHE``, then
    ``$HF_HOME/hub``, then ``$XDG_CACHE_HOME/huggingface/hub``, then
    ``~/.cache/huggingface/hub``.
    """
    explicit = os.environ.get("HF_HUB_CACHE") or os.environ.get("HUGGINGFACE_HUB_CACHE")
    if explicit:
        return os.path.expandvars(os.path.expanduser(explicit))
    hf_home = os.environ.get("HF_HOME")
    if not hf_home:
        xdg_cache = os.environ.get("XDG_CACHE_HOME") or os.path.join(os.path.expanduser("~"), ".cache")
        hf_home = os.path.join(xdg_cache, "huggingface")
    return os.path.join(os.path.expandvars(os.path.expanduser(hf_home)), "hub")


# Gradio only serves local files from allowed paths. NOTE(review): this assumes
# dataset_loader downloads into the default hub cache (no custom cache_dir/local_dir).
_hf_hub_cache = _resolve_hf_hub_cache()
app.launch(allowed_paths=[_hf_hub_cache])