""" Gradio Space for batch_outputs demo. Loads data from Hugging Face dataset AE-W/batch_outputs. """ import os import gradio as gr from dataset_loader import ( list_samples_bin, list_samples_clap, list_samples_dasheng, get_nn_demo_paths, get_results_demo_paths, ) def _method_from_view(view: str) -> str: """Return 'bin' | 'clap' | 'dasheng' from view label.""" if "Bin" in view: return "bin" if "Dasheng" in view: return "dasheng" return "clap" def _sample_choices(method: str) -> list[str]: if method == "bin": return list_samples_bin() if method == "dasheng": return list_samples_dasheng() return list_samples_clap() # Default: first sample of first method DEFAULT_METHOD = "clap" DEFAULT_SAMPLE_IDS = _sample_choices(DEFAULT_METHOD) TOP1_ID = DEFAULT_SAMPLE_IDS[0] if DEFAULT_SAMPLE_IDS else None def build_nn_view(sample_id: str | None, method: str): """NN view: NN1-NN10 from baseline. Each: prompt + spec on top, BG/FG/Mix audios below.""" if not sample_id: return (None,) * (10 * 5) data = get_nn_demo_paths(sample_id, top_k=10, method=method) out = [] for i, nn in enumerate(data.get("nn_list", [])[:10]): prompt = nn.get("prompt", "") or "" out.append(f"**NN{i+1}:** {prompt}" if prompt else "") out.extend([nn.get("spec"), nn.get("bg_wav"), nn.get("fg_wav"), nn.get("m_wav")]) while len(out) < 50: out.append(None) return tuple(out[:50]) def build_results_view(sample_id: str | None, method: str): """ Results view: 3 blocks. Per block: - Row1: Gaussian | Youtube spec + their BG/FG/Mix - Row2: Ours | NN baseline spec + their BG/FG/Mix """ if not sample_id: return (None,) * (3 * (1 + 4 * 4)) data = get_results_demo_paths(sample_id, method=method) out = [] for i in range(1, 4): block = data.get(f"block{i}", {}) prompt = block.get("prompt", "") or "" out.append(f"**NN{i}:** {prompt}" if prompt else "") # Top row: Gaussian, Youtube for key in ("baseline_gaussian", "baseline_youtube"): b = block.get(key, {}) out.extend([b.get("spec"), b.get("bg_wav"), b.get("fg_wav"), b.get("m_wav")]) # Bottom row: Ours, NN baseline (Original) for key in ("ours", "baseline_original"): b = block.get(key, {}) out.extend([b.get("spec"), b.get("bg_wav"), b.get("fg_wav"), b.get("m_wav")]) return tuple(out) with gr.Blocks( title="NearestNeighbor Audio Demo", css=""" .gradio-container { max-width: 1400px; } /* Results view: force all 4 spec images (Gaussian, Youtube, Ours, NN baseline) to same size */ #results-column img { width: 700px !important; height: 280px !important; object-fit: contain; } /* Reduce audio player row height (BG/FG/Mix) */ .compact-audio .gr-form { min-height: 0 !important; } .compact-audio > div { min-height: 0 !important; max-height: 72px !important; } .compact-audio audio { max-height: 48px !important; } """, ) as app: gr.Markdown("# NearestNeighbor Audio Demo") gr.Markdown("Data from [AE-W/batch_outputs](https://huggingface.co/datasets/AE-W/batch_outputs)") view_radio = gr.Radio( choices=[ "Nearest Neighbor (Bin)", "Results (Bin)", "Nearest Neighbor (Clap)", "Results (Clap)", "Nearest Neighbor (Dasheng)", "Results (Dasheng)", ], value="Nearest Neighbor (Clap)", label="View", ) noise_dd = gr.Dropdown(choices=DEFAULT_SAMPLE_IDS, value=TOP1_ID, label="Noise (ID)") gr.Markdown(""" **Three prompt-search methods**: **Bin** | **Clap** | **Dasheng**. Each combines `batch_outputs_*` and `generated_noises_*` from the dataset. **How to read the IDs** - **Numeric IDs** (e.g. `00_000357`) come from batch_outputs (SONYC/UrbanSound8k). - **Long prompt-like IDs** (e.g. `a_bulldozer_moving_gravel_...`) come from generated_noises. **Audio labels**: **BG** = background noise | **FG** = generated foreground | **Mix** = BG + FG """) # ---- NN View: NN1-NN10, each: spec on top, BG/FG/Mix audios below ---- with gr.Column(visible=True) as nn_col: nn_section_title = gr.Markdown("### Nearest Neighbor (Clap): Baseline outputs (top 10 prompts)") nn_outputs = [] for i in range(10): with gr.Group(): nn_p_md = gr.Markdown(value="") nn_outputs.append(nn_p_md) nn_img = gr.Image(label=f"NN{i+1}", show_label=True, height=480) nn_outputs.append(nn_img) nn_bg = gr.Audio(label="BG", show_label=True, elem_classes=["compact-audio"]) nn_fg = gr.Audio(label="FG", show_label=True, elem_classes=["compact-audio"]) nn_m = gr.Audio(label="Mix", show_label=True, elem_classes=["compact-audio"]) nn_outputs.extend([nn_bg, nn_fg, nn_m]) # ---- Results View: 3 prompts, each with 2 rows (Gaussian|Youtube, Ours|NN baseline) ---- with gr.Column(visible=False, elem_id="results-column") as res_col: res_section_title = gr.Markdown("### Results (Clap): 3 baselines + Ours (top 3 prompts)") res_outputs = [] for i in range(1, 4): with gr.Group(): res_p_md = gr.Markdown(value="") res_outputs.append(res_p_md) # Row 1: Gaussian | Youtube (spec + BG/FG/Mix under each) # Fixed height & width for consistent display spec_size = {"height": 280, "width": 700} with gr.Row(): with gr.Column(): res_outputs.append(gr.Image(label="Gaussian", **spec_size)) res_outputs.append(gr.Audio(label="BG", elem_classes=["compact-audio"])) res_outputs.append(gr.Audio(label="FG", elem_classes=["compact-audio"])) res_outputs.append(gr.Audio(label="Mix", elem_classes=["compact-audio"])) with gr.Column(): res_outputs.append(gr.Image(label="Youtube", **spec_size)) res_outputs.append(gr.Audio(label="BG", elem_classes=["compact-audio"])) res_outputs.append(gr.Audio(label="FG", elem_classes=["compact-audio"])) res_outputs.append(gr.Audio(label="Mix", elem_classes=["compact-audio"])) # Row 2: Ours | NN baseline (spec + BG/FG/Mix under each) with gr.Row(): with gr.Column(): res_outputs.append(gr.Image(label="Ours", **spec_size)) res_outputs.append(gr.Audio(label="BG", elem_classes=["compact-audio"])) res_outputs.append(gr.Audio(label="FG", elem_classes=["compact-audio"])) res_outputs.append(gr.Audio(label="Mix", elem_classes=["compact-audio"])) with gr.Column(): res_outputs.append(gr.Image(label="NN baseline", **spec_size)) res_outputs.append(gr.Audio(label="BG", elem_classes=["compact-audio"])) res_outputs.append(gr.Audio(label="FG", elem_classes=["compact-audio"])) res_outputs.append(gr.Audio(label="Mix", elem_classes=["compact-audio"])) def on_change(sid, view): method = _method_from_view(view) choices = _sample_choices(method) if sid not in choices and choices: sid = choices[0] is_nn = view in ("Nearest Neighbor (Bin)", "Nearest Neighbor (Clap)", "Nearest Neighbor (Dasheng)") is_res = view in ("Results (Bin)", "Results (Clap)", "Results (Dasheng)") nn_vals = build_nn_view(sid, method) res_vals = build_results_view(sid, method) nn_title = f"### Nearest Neighbor ({method.capitalize()}): Baseline outputs (top 10 prompts)" res_title = f"### Results ({method.capitalize()}): 3 baselines + Ours (top 3 prompts)" dd_update = gr.update(choices=choices, value=sid) return ( [gr.update(value=nn_title)] + list(nn_vals) + [gr.update(value=res_title)] + list(res_vals) + [gr.update(visible=is_nn), gr.update(visible=is_res), dd_update] ) all_outputs = [nn_section_title] + nn_outputs + [res_section_title] + res_outputs + [nn_col, res_col, noise_dd] noise_dd.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs) view_radio.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs) app.load(lambda: on_change(TOP1_ID, "Nearest Neighbor (Clap)"), outputs=all_outputs) _hf_hub_cache = os.environ.get( "HUGGINGFACE_HUB_CACHE", os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"), ) app.launch(allowed_paths=[_hf_hub_cache])