# NearestNeighbor / app.py
# Uploaded by AE-W via huggingface_hub (commit d42fccb).
"""
Gradio Space for batch_outputs demo.
Loads data from Hugging Face dataset AE-W/batch_outputs.
"""
import os
import gradio as gr
from dataset_loader import (
list_samples_bin,
list_samples_clap,
list_samples_dasheng,
get_nn_demo_paths,
get_results_demo_paths,
)
def _method_from_view(view: str) -> str:
"""Return 'bin' | 'clap' | 'dasheng' from view label."""
if "Bin" in view:
return "bin"
if "Dasheng" in view:
return "dasheng"
return "clap"
def _sample_choices(method: str) -> list[str]:
    """Return the sample IDs listed by the dataset loader for *method*.

    'bin' and 'dasheng' use their own listings; any other value uses the Clap list.
    """
    if method == "bin":
        lister = list_samples_bin
    elif method == "dasheng":
        lister = list_samples_dasheng
    else:
        lister = list_samples_clap
    return lister()
# Initial selection: the Clap method (matching the View radio's default,
# "Nearest Neighbor (Clap)") and its first sample ID, or None if it has none.
DEFAULT_METHOD = "clap"
DEFAULT_SAMPLE_IDS = _sample_choices(DEFAULT_METHOD)
TOP1_ID = None
if DEFAULT_SAMPLE_IDS:
    TOP1_ID = DEFAULT_SAMPLE_IDS[0]
def build_nn_view(sample_id: str | None, method: str):
"""NN view: NN1-NN10 from baseline. Each: prompt + spec on top, BG/FG/Mix audios below."""
if not sample_id:
return (None,) * (10 * 5)
data = get_nn_demo_paths(sample_id, top_k=10, method=method)
out = []
for i, nn in enumerate(data.get("nn_list", [])[:10]):
prompt = nn.get("prompt", "") or ""
out.append(f"**NN{i+1}:** {prompt}" if prompt else "")
out.extend([nn.get("spec"), nn.get("bg_wav"), nn.get("fg_wav"), nn.get("m_wav")])
while len(out) < 50:
out.append(None)
return tuple(out[:50])
def build_results_view(sample_id: str | None, method: str):
"""
Results view: 3 blocks. Per block:
- Row1: Gaussian | Youtube spec + their BG/FG/Mix
- Row2: Ours | NN baseline spec + their BG/FG/Mix
"""
if not sample_id:
return (None,) * (3 * (1 + 4 * 4))
data = get_results_demo_paths(sample_id, method=method)
out = []
for i in range(1, 4):
block = data.get(f"block{i}", {})
prompt = block.get("prompt", "") or ""
out.append(f"**NN{i}:** {prompt}" if prompt else "")
# Top row: Gaussian, Youtube
for key in ("baseline_gaussian", "baseline_youtube"):
b = block.get(key, {})
out.extend([b.get("spec"), b.get("bg_wav"), b.get("fg_wav"), b.get("m_wav")])
# Bottom row: Ours, NN baseline (Original)
for key in ("ours", "baseline_original"):
b = block.get(key, {})
out.extend([b.get("spec"), b.get("bg_wav"), b.get("fg_wav"), b.get("m_wav")])
return tuple(out)
# View radio labels in display order. _method_from_view() maps each label to
# its prompt-search method; the label prefix selects which column is shown.
_VIEW_CHOICES = [
    "Nearest Neighbor (Bin)",
    "Results (Bin)",
    "Nearest Neighbor (Clap)",
    "Results (Clap)",
    "Nearest Neighbor (Dasheng)",
    "Results (Dasheng)",
]
_NN_VIEWS = tuple(v for v in _VIEW_CHOICES if v.startswith("Nearest Neighbor"))
_RES_VIEWS = tuple(v for v in _VIEW_CHOICES if v.startswith("Results"))
_DEFAULT_VIEW = "Nearest Neighbor (Clap)"

with gr.Blocks(
    title="NearestNeighbor Audio Demo",
    css="""
.gradio-container { max-width: 1400px; }
/* Results view: force all 4 spec images (Gaussian, Youtube, Ours, NN baseline) to same size */
#results-column img { width: 700px !important; height: 280px !important; object-fit: contain; }
/* Reduce audio player row height (BG/FG/Mix) */
.compact-audio .gr-form { min-height: 0 !important; }
.compact-audio > div { min-height: 0 !important; max-height: 72px !important; }
.compact-audio audio { max-height: 48px !important; }
""",
) as app:
    gr.Markdown("# NearestNeighbor Audio Demo")
    gr.Markdown("Data from [AE-W/batch_outputs](https://huggingface.co/datasets/AE-W/batch_outputs)")
    view_radio = gr.Radio(choices=_VIEW_CHOICES, value=_DEFAULT_VIEW, label="View")
    noise_dd = gr.Dropdown(choices=DEFAULT_SAMPLE_IDS, value=TOP1_ID, label="Noise (ID)")
    # Blank lines separate the sections; without them Markdown merges the bold
    # headings into the previous paragraph / list item.
    gr.Markdown("""
**Three prompt-search methods**: **Bin** | **Clap** | **Dasheng**. Each combines `batch_outputs_*` and `generated_noises_*` from the dataset.

**How to read the IDs**

- **Numeric IDs** (e.g. `00_000357`) come from batch_outputs (SONYC/UrbanSound8k).
- **Long prompt-like IDs** (e.g. `a_bulldozer_moving_gravel_...`) come from generated_noises.

**Audio labels**: **BG** = background noise | **FG** = generated foreground | **Mix** = BG + FG
""")

    # ---- NN view: NN1-NN10, each a prompt + spectrogram with BG/FG/Mix players below ----
    with gr.Column(visible=True) as nn_col:
        nn_section_title = gr.Markdown("### Nearest Neighbor (Clap): Baseline outputs (top 10 prompts)")
        # Flat list in the exact order build_nn_view() returns values:
        # [prompt_md, spec, BG, FG, Mix] for NN1..NN10 (50 components).
        nn_outputs = []
        for i in range(10):
            with gr.Group():
                nn_outputs.append(gr.Markdown(value=""))
                nn_outputs.append(gr.Image(label=f"NN{i+1}", show_label=True, height=480))
                for audio_label in ("BG", "FG", "Mix"):
                    nn_outputs.append(
                        gr.Audio(label=audio_label, show_label=True, elem_classes=["compact-audio"])
                    )

    # ---- Results view: 3 prompts, each with 2 rows (Gaussian|Youtube, Ours|NN baseline) ----
    with gr.Column(visible=False, elem_id="results-column") as res_col:
        res_section_title = gr.Markdown("### Results (Clap): 3 baselines + Ours (top 3 prompts)")
        # Fixed spectrogram size for a consistent grid (mirrors the CSS rule above).
        spec_size = {"height": 280, "width": 700}
        # Flat list in build_results_view() order: per prompt block, prompt_md then
        # (spec, BG, FG, Mix) for Gaussian, Youtube, Ours, NN baseline (51 components).
        res_outputs = []
        for _ in range(3):
            with gr.Group():
                res_outputs.append(gr.Markdown(value=""))
                for row_labels in (("Gaussian", "Youtube"), ("Ours", "NN baseline")):
                    with gr.Row():
                        for spec_label in row_labels:
                            with gr.Column():
                                res_outputs.append(gr.Image(label=spec_label, **spec_size))
                                for audio_label in ("BG", "FG", "Mix"):
                                    res_outputs.append(
                                        gr.Audio(label=audio_label, elem_classes=["compact-audio"])
                                    )

    def on_change(sid, view):
        """Refresh the UI for sample *sid* in *view*.

        Returns one value per entry of ``all_outputs``: NN title, the NN view's
        values, Results title, the Results view's values, then visibility
        updates for both columns and an update for the sample dropdown.
        """
        method = _method_from_view(view)
        choices = _sample_choices(method)
        if sid not in choices:
            # The current ID may come from another method's list: fall back to
            # this method's first sample, or clear the selection if it has none.
            sid = choices[0] if choices else None
        is_nn = view in _NN_VIEWS
        is_res = view in _RES_VIEWS
        # Only load data for the visible view; the hidden view's components get a
        # no-op update. Switching views fires view_radio.change, which re-runs
        # this handler and fills the newly visible view. Fresh gr.update() dicts
        # per component, in case Gradio mutates them while post-processing.
        nn_vals = build_nn_view(sid, method) if is_nn else [gr.update() for _ in nn_outputs]
        res_vals = build_results_view(sid, method) if is_res else [gr.update() for _ in res_outputs]
        method_label = method.capitalize()
        nn_title = f"### Nearest Neighbor ({method_label}): Baseline outputs (top 10 prompts)"
        res_title = f"### Results ({method_label}): 3 baselines + Ours (top 3 prompts)"
        return [
            gr.update(value=nn_title),
            *nn_vals,
            gr.update(value=res_title),
            *res_vals,
            gr.update(visible=is_nn),
            gr.update(visible=is_res),
            gr.update(choices=choices, value=sid),
        ]

    all_outputs = [nn_section_title, *nn_outputs, res_section_title, *res_outputs, nn_col, res_col, noise_dd]
    # NOTE(review): per Gradio docs, .change also fires when a function updates the
    # dropdown's value, so a method switch that changes the ID triggers one extra
    # (redundant but harmless) refresh.
    noise_dd.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
    view_radio.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
    app.load(lambda: on_change(TOP1_ID, _DEFAULT_VIEW), outputs=all_outputs)
def _resolve_hf_hub_cache() -> str:
    """Return the Hugging Face Hub cache directory, using huggingface_hub's precedence.

    The first one set wins:
      1. HF_HUB_CACHE
      2. the legacy HUGGINGFACE_HUB_CACHE
      3. $HF_HOME/hub, where HF_HOME defaults to $XDG_CACHE_HOME/huggingface
         and XDG_CACHE_HOME defaults to ~/.cache

    ``~`` is expanded. With none of these variables set, the result is
    ~/.cache/huggingface/hub.
    """
    env = os.environ
    cache_root = env.get("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache"))
    hf_home = os.path.expanduser(env.get("HF_HOME", os.path.join(cache_root, "huggingface")))
    legacy_cache = env.get("HUGGINGFACE_HUB_CACHE", os.path.join(hf_home, "hub"))
    return os.path.expanduser(env.get("HF_HUB_CACHE", legacy_cache))


# Gradio serves local files only from allowed directories. The media paths are
# presumably files downloaded into the hub cache by dataset_loader (TODO confirm).
_hf_hub_cache = _resolve_hf_hub_cache()
app.launch(allowed_paths=[_hf_hub_cache])