Spaces:
Running
Running
Rachel Ding committed on
Commit ·
86d12ed
1
Parent(s): bbe6901
Add view toggle: Nearest Neighbor (baseline + 10 NN) vs Results (3 baselines + Ours, top 3 prompts)
Browse files- app.py +81 -71
- dataset_loader.py +125 -35
app.py
CHANGED
|
@@ -6,103 +6,113 @@ import os
|
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
|
| 9 |
-
from dataset_loader import list_samples,
|
| 10 |
|
| 11 |
|
| 12 |
SAMPLE_IDS = list_samples()
|
| 13 |
TOP1_ID = SAMPLE_IDS[0] if SAMPLE_IDS else None
|
| 14 |
|
| 15 |
|
| 16 |
-
def
|
| 17 |
-
"""
|
| 18 |
if not sample_id:
|
| 19 |
-
return (None,) *
|
| 20 |
-
data =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
out = []
|
| 22 |
for i in range(1, 4):
|
| 23 |
block = data.get(f"block{i}", {})
|
| 24 |
prompt = block.get("prompt", "") or ""
|
| 25 |
out.append(f"**Prompt:** {prompt}" if prompt else "")
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
out.extend([nn.get("spec"), nn.get("bg_wav"), nn.get("fg_wav"), nn.get("m_wav")])
|
| 30 |
return tuple(out)
|
| 31 |
|
| 32 |
|
| 33 |
-
with gr.Blocks(title="NearestNeighbor Audio Demo", css=".gradio-container { max-width:
|
| 34 |
gr.Markdown("# NearestNeighbor Audio Demo")
|
| 35 |
gr.Markdown("Data from [AE-W/batch_outputs](https://huggingface.co/datasets/AE-W/batch_outputs)")
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
gr.Markdown("""
|
| 38 |
**How to read the IDs**
|
| 39 |
- **Numeric IDs** (e.g. `00_000357`) come from the **SONYC** dataset.
|
| 40 |
- **IDs starting with `fold`** come from the **UrbanSound8k** dataset.
|
| 41 |
|
| 42 |
-
**Audio labels**
|
| 43 |
-
- **BG** = background noise
|
| 44 |
-
- **FG** = generated foreground sound
|
| 45 |
-
- **Mix** = BG + FG (mixed)
|
| 46 |
""")
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
outputs=all_outputs,
|
| 98 |
-
)
|
| 99 |
-
|
| 100 |
-
def init():
|
| 101 |
-
return build_noise_demo(TOP1_ID)
|
| 102 |
-
|
| 103 |
-
app.load(init, outputs=all_outputs)
|
| 104 |
|
| 105 |
-
# Allow serving files from HF dataset cache (required on Spaces)
|
| 106 |
_hf_hub_cache = os.environ.get(
|
| 107 |
"HUGGINGFACE_HUB_CACHE",
|
| 108 |
os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
|
|
|
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
|
| 9 |
+
from dataset_loader import list_samples, get_nn_demo_paths, get_results_demo_paths
|
| 10 |
|
| 11 |
|
| 12 |
# Sample IDs returned by dataset_loader.list_samples(); evaluated once when this module is imported.
SAMPLE_IDS = list_samples()
# Default dropdown selection: the first listed ID, or None when no samples were returned.
TOP1_ID = SAMPLE_IDS[0] if SAMPLE_IDS else None
|
| 14 |
|
| 15 |
|
| 16 |
+
def build_nn_view(sample_id: str | None):
|
| 17 |
+
"""NN view: [Baseline] [NN1] [NN2] ... [NN10]. Each NN has spec + m_wav."""
|
| 18 |
if not sample_id:
|
| 19 |
+
return (None,) * (4 + 10 * 2)
|
| 20 |
+
data = get_nn_demo_paths(sample_id, top_k=10)
|
| 21 |
+
out = []
|
| 22 |
+
bl = data.get("baseline", {})
|
| 23 |
+
out.extend([bl.get("spec"), bl.get("bg_wav"), bl.get("fg_wav"), bl.get("m_wav")])
|
| 24 |
+
for nn in data.get("nn_list", [])[:10]:
|
| 25 |
+
out.extend([nn.get("spec"), nn.get("m_wav")])
|
| 26 |
+
while len(out) < 4 + 20:
|
| 27 |
+
out.append(None)
|
| 28 |
+
return tuple(out[: 4 + 20])
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def build_results_view(sample_id: str | None):
|
| 32 |
+
"""Results view: 3 blocks, each with Baseline (orig), Gaussian, Youtube-noise, Ours."""
|
| 33 |
+
if not sample_id:
|
| 34 |
+
return (None,) * (3 * (1 + 4 * 4))
|
| 35 |
+
data = get_results_demo_paths(sample_id)
|
| 36 |
out = []
|
| 37 |
for i in range(1, 4):
|
| 38 |
block = data.get(f"block{i}", {})
|
| 39 |
prompt = block.get("prompt", "") or ""
|
| 40 |
out.append(f"**Prompt:** {prompt}" if prompt else "")
|
| 41 |
+
for key in ("baseline_original", "baseline_gaussian", "baseline_youtube", "ours"):
|
| 42 |
+
b = block.get(key, {})
|
| 43 |
+
out.extend([b.get("spec"), b.get("bg_wav"), b.get("fg_wav"), b.get("m_wav")])
|
|
|
|
| 44 |
return tuple(out)
|
| 45 |
|
| 46 |
|
| 47 |
+
with gr.Blocks(title="NearestNeighbor Audio Demo", css=".gradio-container { max-width: 1400px; }") as app:
    gr.Markdown("# NearestNeighbor Audio Demo")
    gr.Markdown("Data from [AE-W/batch_outputs](https://huggingface.co/datasets/AE-W/batch_outputs)")

    # Top-level controls: which view to show and which sample to load.
    view_radio = gr.Radio(
        choices=["Nearest Neighbor", "Results"],
        value="Nearest Neighbor",
        label="View",
    )
    noise_dd = gr.Dropdown(choices=SAMPLE_IDS, value=TOP1_ID, label="Noise (ID)")

    # The help text is kept flush-left so Markdown never treats it as an indented code block.
    gr.Markdown("""
**How to read the IDs**
- **Numeric IDs** (e.g. `00_000357`) come from the **SONYC** dataset.
- **IDs starting with `fold`** come from the **UrbanSound8k** dataset.

**Audio labels**: **BG** = background noise | **FG** = generated foreground | **Mix** = BG + FG
""")

    # ---- NN View: Baseline + 10 NN ----
    with gr.Column(visible=True) as nn_col:
        gr.Markdown("### Nearest Neighbor: Baseline + top 10 NN")
        with gr.Row():
            with gr.Column(min_width=180):
                gr.Markdown("**Baseline**")
                nn_bl_img = gr.Image(label="Spec", show_label=False, height=220)
                nn_bl_bg = gr.Audio(label="BG", show_label=True)
                nn_bl_fg = gr.Audio(label="FG", show_label=True)
                nn_bl_m = gr.Audio(label="Mix", show_label=True)
        gr.Markdown("**NN1–NN10**")
        nn_items = []
        with gr.Row():
            for i in range(10):
                with gr.Column(min_width=120):
                    nn_items.append(gr.Image(label=f"NN{i+1}", show_label=True, height=140))
                    nn_items.append(gr.Audio(label="Mix", show_label=True))
    # Component order must match the tuple returned by build_nn_view.
    nn_outputs = [nn_bl_img, nn_bl_bg, nn_bl_fg, nn_bl_m] + nn_items

    # ---- Results View: 3 prompts x 4 methods ----
    with gr.Column(visible=False) as res_col:
        gr.Markdown("### Results: 3 baselines + Ours (top 3 prompts)")
        # Component order must match the tuple returned by build_results_view.
        res_outputs = []
        for i in range(1, 4):
            with gr.Group():
                res_p_md = gr.Markdown(value="")
                res_outputs.append(res_p_md)
                with gr.Row():
                    for _ in ["Original", "Gaussian", "Youtube", "Ours"]:
                        res_outputs.append(gr.Image(height=180))
                        res_outputs.append(gr.Audio(label="BG"))
                        res_outputs.append(gr.Audio(label="FG"))
                        res_outputs.append(gr.Audio(label="Mix"))

    def on_change(sid, view):
        """Refresh the visible view for sample ``sid`` and toggle column visibility.

        Only the view selected in ``view`` is rebuilt; the hidden view's
        components receive no-op updates so its assets are not fetched on every
        event. Switching the radio fires this handler again, so the newly shown
        view is always rebuilt for the current sample.

        Returns a list aligned with ``all_outputs``: NN values, Results values,
        then the two column-visibility updates.
        """
        is_nn = view == "Nearest Neighbor"
        if is_nn:
            nn_vals = list(build_nn_view(sid))
            res_vals = [gr.update() for _ in res_outputs]
        else:
            nn_vals = [gr.update() for _ in nn_outputs]
            res_vals = list(build_results_view(sid))
        return nn_vals + res_vals + [gr.update(visible=is_nn), gr.update(visible=not is_nn)]

    all_outputs = nn_outputs + res_outputs + [nn_col, res_col]

    # Both controls feed the same handler so either change refreshes the page.
    noise_dd.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
    view_radio.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)

    # Initial page load: show the default sample in the Nearest Neighbor view.
    app.load(lambda: on_change(TOP1_ID, "Nearest Neighbor"), outputs=all_outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
|
|
|
| 116 |
_hf_hub_cache = os.environ.get(
|
| 117 |
"HUGGINGFACE_HUB_CACHE",
|
| 118 |
os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
|
dataset_loader.py
CHANGED
|
@@ -59,25 +59,60 @@ def _find_files(inner: str) -> list[str]:
|
|
| 59 |
return [f for f in files if f.startswith(inner + "/")]
|
| 60 |
|
| 61 |
|
| 62 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
"""
|
| 64 |
-
For
|
| 65 |
-
Returns {
|
| 66 |
"""
|
| 67 |
inner = f"{ROOT_PREFIX}{bid}/{bid}"
|
| 68 |
prompts = _load_json_from_repo(f"{inner}/temp_retrieval.json")
|
| 69 |
if not prompts:
|
| 70 |
prompts = _load_json_from_repo(f"{inner}/natural_prompts.json")
|
| 71 |
if not prompts:
|
| 72 |
-
return {"bg_wav": None, "
|
| 73 |
|
| 74 |
files = _find_files(inner)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
nn_list = []
|
| 76 |
-
for i, p in enumerate(prompts[:
|
| 77 |
prompt = p.get("prompt", "")
|
| 78 |
sim = p.get("similarity_score", p.get("retrieval_score"))
|
| 79 |
gen_prefix = f"generated_{i+1:02d}_"
|
| 80 |
-
fg_path = bg_path = spec_path = None
|
| 81 |
for f in files:
|
| 82 |
parts = f.replace(inner + "/", "").split("/")
|
| 83 |
if len(parts) >= 2 and parts[0].startswith(gen_prefix):
|
|
@@ -86,19 +121,20 @@ def get_nn_demo_paths(bid: str) -> dict:
|
|
| 86 |
fg_path = f
|
| 87 |
elif name.endswith("_bg.wav"):
|
| 88 |
bg_path = f
|
|
|
|
|
|
|
| 89 |
elif name.endswith(".png"):
|
| 90 |
spec_path = f
|
| 91 |
nn_list.append({
|
| 92 |
"fg_wav": _download_file(fg_path) if fg_path else None,
|
| 93 |
"spec": _download_file(spec_path) if spec_path else None,
|
| 94 |
"bg_wav": _download_file(bg_path) if bg_path else None,
|
|
|
|
| 95 |
"prompt": prompt,
|
| 96 |
"similarity": sim,
|
| 97 |
})
|
| 98 |
|
| 99 |
-
|
| 100 |
-
bg_spec = nn_list[0]["spec"] if nn_list else None
|
| 101 |
-
return {"bg_wav": bg_wav, "bg_spec": bg_spec, "nn_list": nn_list}
|
| 102 |
|
| 103 |
|
| 104 |
def get_noise_demo_paths(bid: str) -> dict:
|
|
@@ -117,28 +153,6 @@ def get_noise_demo_paths(bid: str) -> dict:
|
|
| 117 |
if not prompts:
|
| 118 |
prompts = []
|
| 119 |
|
| 120 |
-
def collect_block(file_list: list, folder_prefix: str) -> dict:
|
| 121 |
-
"""From files under folder_prefix, get spec + bg_wav, fg_wav, m_wav."""
|
| 122 |
-
spec = bg = fg = m = None
|
| 123 |
-
for f in file_list:
|
| 124 |
-
if folder_prefix not in f:
|
| 125 |
-
continue
|
| 126 |
-
name = f.split("/")[-1]
|
| 127 |
-
if name.endswith(".png"):
|
| 128 |
-
spec = f
|
| 129 |
-
elif name.endswith("_bg.wav"):
|
| 130 |
-
bg = f
|
| 131 |
-
elif name.endswith("_fg.wav"):
|
| 132 |
-
fg = f
|
| 133 |
-
elif name.endswith("_m.wav"):
|
| 134 |
-
m = f
|
| 135 |
-
return {
|
| 136 |
-
"spec": _download_file(spec) if spec else None,
|
| 137 |
-
"bg_wav": _download_file(bg) if bg else None,
|
| 138 |
-
"fg_wav": _download_file(fg) if fg else None,
|
| 139 |
-
"m_wav": _download_file(m) if m else None,
|
| 140 |
-
}
|
| 141 |
-
|
| 142 |
# Find baseline folder names generated_baseline_01_*, 02_*, 03_*
|
| 143 |
seen = set()
|
| 144 |
baseline_folders = []
|
|
@@ -152,17 +166,15 @@ def get_noise_demo_paths(bid: str) -> dict:
|
|
| 152 |
result = {}
|
| 153 |
for i in range(1, 4):
|
| 154 |
prompt_text = prompts[i - 1].get("prompt", "") if i <= len(prompts) else ""
|
| 155 |
-
# Baseline for this prompt: i-th baseline folder (01, 02, 03)
|
| 156 |
bl_prefix = f"generated_baseline_{i:02d}_"
|
| 157 |
baseline_block = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
|
| 158 |
for folder_name, full_prefix in baseline_folders:
|
| 159 |
if folder_name.startswith(bl_prefix):
|
| 160 |
-
baseline_block =
|
| 161 |
break
|
| 162 |
-
# Our method: generated_0{i}_*
|
| 163 |
rel_prefix = f"generated_{i:02d}_"
|
| 164 |
nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
|
| 165 |
-
nn_block =
|
| 166 |
nn_block["prompt"] = prompt_text
|
| 167 |
result[f"block{i}"] = {
|
| 168 |
"prompt": prompt_text,
|
|
@@ -170,3 +182,81 @@ def get_noise_demo_paths(bid: str) -> dict:
|
|
| 170 |
"nn": nn_block,
|
| 171 |
}
|
| 172 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
return [f for f in files if f.startswith(inner + "/")]
|
| 60 |
|
| 61 |
|
| 62 |
+
def _collect_block(file_list: list, folder_prefix: str) -> dict:
|
| 63 |
+
"""From files under folder_prefix, get spec + bg_wav, fg_wav, m_wav."""
|
| 64 |
+
spec = bg = fg = m = None
|
| 65 |
+
for f in file_list:
|
| 66 |
+
if folder_prefix not in f:
|
| 67 |
+
continue
|
| 68 |
+
name = f.split("/")[-1]
|
| 69 |
+
if name.endswith(".png"):
|
| 70 |
+
spec = f
|
| 71 |
+
elif name.endswith("_bg.wav"):
|
| 72 |
+
bg = f
|
| 73 |
+
elif name.endswith("_fg.wav"):
|
| 74 |
+
fg = f
|
| 75 |
+
elif name.endswith("_m.wav"):
|
| 76 |
+
m = f
|
| 77 |
+
return {
|
| 78 |
+
"spec": _download_file(spec) if spec else None,
|
| 79 |
+
"bg_wav": _download_file(bg) if bg else None,
|
| 80 |
+
"fg_wav": _download_file(fg) if fg else None,
|
| 81 |
+
"m_wav": _download_file(m) if m else None,
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def get_nn_demo_paths(bid: str, top_k: int = 10) -> dict:
|
| 86 |
"""
|
| 87 |
+
For NN view: [Baseline] [NN1] [NN2] ... [NN10].
|
| 88 |
+
Returns {baseline: {spec, bg, fg, m}, nn_list: [{fg_wav, spec, bg_wav, prompt, similarity}, ...]}.
|
| 89 |
"""
|
| 90 |
inner = f"{ROOT_PREFIX}{bid}/{bid}"
|
| 91 |
prompts = _load_json_from_repo(f"{inner}/temp_retrieval.json")
|
| 92 |
if not prompts:
|
| 93 |
prompts = _load_json_from_repo(f"{inner}/natural_prompts.json")
|
| 94 |
if not prompts:
|
| 95 |
+
return {"baseline": {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}, "nn_list": []}
|
| 96 |
|
| 97 |
files = _find_files(inner)
|
| 98 |
+
baseline_inner = f"{inner}/baseline"
|
| 99 |
+
baseline_files = _find_files(baseline_inner) if any(f.startswith(baseline_inner) for f in files) else []
|
| 100 |
+
|
| 101 |
+
# Baseline: first baseline folder (generated_baseline_01_*)
|
| 102 |
+
baseline_block = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
|
| 103 |
+
for f in baseline_files:
|
| 104 |
+
parts = f.replace(baseline_inner + "/", "").split("/")
|
| 105 |
+
if parts and parts[0].startswith("generated_baseline_01_"):
|
| 106 |
+
full_prefix = baseline_inner + "/" + parts[0]
|
| 107 |
+
baseline_block = _collect_block(baseline_files, full_prefix)
|
| 108 |
+
break
|
| 109 |
+
|
| 110 |
nn_list = []
|
| 111 |
+
for i, p in enumerate(prompts[:top_k]):
|
| 112 |
prompt = p.get("prompt", "")
|
| 113 |
sim = p.get("similarity_score", p.get("retrieval_score"))
|
| 114 |
gen_prefix = f"generated_{i+1:02d}_"
|
| 115 |
+
fg_path = bg_path = m_path = spec_path = None
|
| 116 |
for f in files:
|
| 117 |
parts = f.replace(inner + "/", "").split("/")
|
| 118 |
if len(parts) >= 2 and parts[0].startswith(gen_prefix):
|
|
|
|
| 121 |
fg_path = f
|
| 122 |
elif name.endswith("_bg.wav"):
|
| 123 |
bg_path = f
|
| 124 |
+
elif name.endswith("_m.wav"):
|
| 125 |
+
m_path = f
|
| 126 |
elif name.endswith(".png"):
|
| 127 |
spec_path = f
|
| 128 |
nn_list.append({
|
| 129 |
"fg_wav": _download_file(fg_path) if fg_path else None,
|
| 130 |
"spec": _download_file(spec_path) if spec_path else None,
|
| 131 |
"bg_wav": _download_file(bg_path) if bg_path else None,
|
| 132 |
+
"m_wav": _download_file(m_path) if m_path else None,
|
| 133 |
"prompt": prompt,
|
| 134 |
"similarity": sim,
|
| 135 |
})
|
| 136 |
|
| 137 |
+
return {"baseline": baseline_block, "nn_list": nn_list}
|
|
|
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
def get_noise_demo_paths(bid: str) -> dict:
|
|
|
|
| 153 |
if not prompts:
|
| 154 |
prompts = []
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
# Find baseline folder names generated_baseline_01_*, 02_*, 03_*
|
| 157 |
seen = set()
|
| 158 |
baseline_folders = []
|
|
|
|
| 166 |
result = {}
|
| 167 |
for i in range(1, 4):
|
| 168 |
prompt_text = prompts[i - 1].get("prompt", "") if i <= len(prompts) else ""
|
|
|
|
| 169 |
bl_prefix = f"generated_baseline_{i:02d}_"
|
| 170 |
baseline_block = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
|
| 171 |
for folder_name, full_prefix in baseline_folders:
|
| 172 |
if folder_name.startswith(bl_prefix):
|
| 173 |
+
baseline_block = _collect_block(baseline_files, full_prefix)
|
| 174 |
break
|
|
|
|
| 175 |
rel_prefix = f"generated_{i:02d}_"
|
| 176 |
nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
|
| 177 |
+
nn_block = _collect_block(nn_files, rel_prefix)
|
| 178 |
nn_block["prompt"] = prompt_text
|
| 179 |
result[f"block{i}"] = {
|
| 180 |
"prompt": prompt_text,
|
|
|
|
| 182 |
"nn": nn_block,
|
| 183 |
}
|
| 184 |
return result
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def get_results_demo_paths(bid: str) -> dict:
    """
    For Results view: 3 blocks (prompts 1-3), each with 4 columns:
    Baseline (original), Gaussian, Youtube-noise, Ours.

    Returns a dict with keys ``block1``..``block3``. Each block maps to
    ``{prompt, baseline_original, baseline_gaussian, baseline_youtube, ours}``,
    where every method entry is a ``{spec, bg_wav, fg_wav, m_wav}`` dict
    (values are None for anything that was not found).
    """
    inner = f"{ROOT_PREFIX}{bid}/{bid}"
    files = _find_files(inner)
    baseline_inner = f"{inner}/baseline"
    gaussian_inner = f"{inner}/gaussian_baseline"
    youtube_inner = f"{inner}/youtube_noise_baseline"

    def files_under(sub_root: str) -> list:
        """List files below sub_root, or [] when the sample has no such folder."""
        if any(f.startswith(sub_root) for f in files):
            return _find_files(sub_root)
        return []

    def folders_under(root: str, paths: list, name_prefix: str) -> list:
        """Sorted, unique (folder_name, full_path) pairs for folders directly under root."""
        names = sorted({
            first
            for first in (p.removeprefix(root + "/").split("/")[0] for p in paths)
            if first.startswith(name_prefix)
        })
        return [(name, f"{root}/{name}") for name in names]

    def block_for(folders: list, paths: list, name_prefix: str) -> dict:
        """Collect the first folder whose name starts with name_prefix, else an all-None block."""
        for name, full_path in folders:
            if name.startswith(name_prefix):
                return _collect_block(paths, full_path)
        return {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}

    baseline_files = files_under(baseline_inner)
    gaussian_files = files_under(gaussian_inner)
    youtube_files = files_under(youtube_inner)

    # Prefer the retrieval output; fall back to the natural-prompt list.
    prompts = _load_json_from_repo(f"{inner}/temp_retrieval.json")
    if not prompts:
        prompts = _load_json_from_repo(f"{inner}/natural_prompts.json")
    if not prompts:
        prompts = []

    baseline_folders = folders_under(baseline_inner, baseline_files, "generated_baseline_")
    youtube_folders = folders_under(youtube_inner, youtube_files, "generated_")

    # The Gaussian block does not depend on the prompt index, so build it once.
    # NOTE(review): the same Gaussian assets are shown for all three prompts --
    # confirm that gaussian_baseline is not meant to be selected per prompt.
    gaussian_block = _collect_block(gaussian_files, gaussian_inner)

    result = {}
    for i in range(1, 4):
        prompt_text = prompts[i - 1].get("prompt", "") if i <= len(prompts) else ""
        bl_prefix = f"generated_baseline_{i:02d}_"
        rel_prefix = f"generated_{i:02d}_"

        # "Ours" lives directly under the sample folder as generated_NN_*.
        nn_files = [f for f in files if f.removeprefix(inner + "/").startswith(rel_prefix)]

        result[f"block{i}"] = {
            "prompt": prompt_text,
            "baseline_original": block_for(baseline_folders, baseline_files, bl_prefix),
            # Copy so each block keeps its own dict, as before the hoist.
            "baseline_gaussian": dict(gaussian_block),
            "baseline_youtube": block_for(youtube_folders, youtube_files, rel_prefix),
            "ours": _collect_block(nn_files, rel_prefix),
        }
    return result
|