Spaces:
Sleeping
Sleeping
Rachel Ding committed on
Commit ·
6ed909f
1
Parent(s): b303c51
Show retrieved prompt for each NN1/NN2/NN3
Browse files
- app.py +16 -9
- dataset_loader.py +12 -3
app.py
CHANGED
|
@@ -14,9 +14,9 @@ TOP1_ID = SAMPLE_IDS[0] if SAMPLE_IDS else None
|
|
| 14 |
|
| 15 |
|
| 16 |
def build_noise_demo(sample_id: str | None):
|
| 17 |
-
"""Returns (baseline, nn1, nn2, nn3) each: (spec, bg_wav, fg_wav, m_wav)."""
|
| 18 |
if not sample_id:
|
| 19 |
-
return (None,) * 16
|
| 20 |
data = get_noise_demo_paths(sample_id)
|
| 21 |
out = []
|
| 22 |
for key in ("baseline", "nn1", "nn2", "nn3"):
|
|
@@ -27,6 +27,9 @@ def build_noise_demo(sample_id: str | None):
|
|
| 27 |
block.get("fg_wav"),
|
| 28 |
block.get("m_wav"),
|
| 29 |
])
|
|
|
|
|
|
|
|
|
|
| 30 |
return tuple(out)
|
| 31 |
|
| 32 |
|
|
@@ -51,33 +54,37 @@ with gr.Blocks(title="NearestNeighbor Audio Demo", css=".gradio-container { max-
|
|
| 51 |
label="Noise (ID)",
|
| 52 |
)
|
| 53 |
|
| 54 |
-
def block_ui(title: str):
|
| 55 |
with gr.Group():
|
| 56 |
gr.Markdown(f"### {title}")
|
|
|
|
|
|
|
| 57 |
img = gr.Image(label=f"{title}", show_label=True)
|
| 58 |
with gr.Row():
|
| 59 |
abg = gr.Audio(label="BG", show_label=True)
|
| 60 |
afg = gr.Audio(label="FG", show_label=True)
|
| 61 |
am = gr.Audio(label="Mix", show_label=True)
|
|
|
|
|
|
|
| 62 |
return img, abg, afg, am
|
| 63 |
|
| 64 |
# Baseline
|
| 65 |
bl_img, bl_bg, bl_fg, bl_m = block_ui("Baseline")
|
| 66 |
gr.Markdown("---")
|
| 67 |
# NN1
|
| 68 |
-
nn1_img, nn1_bg, nn1_fg, nn1_m = block_ui("NN1")
|
| 69 |
gr.Markdown("---")
|
| 70 |
# NN2
|
| 71 |
-
nn2_img, nn2_bg, nn2_fg, nn2_m = block_ui("NN2")
|
| 72 |
gr.Markdown("---")
|
| 73 |
# NN3
|
| 74 |
-
nn3_img, nn3_bg, nn3_fg, nn3_m = block_ui("NN3")
|
| 75 |
|
| 76 |
all_outputs = [
|
| 77 |
bl_img, bl_bg, bl_fg, bl_m,
|
| 78 |
-
nn1_img, nn1_bg, nn1_fg, nn1_m,
|
| 79 |
-
nn2_img, nn2_bg, nn2_fg, nn2_m,
|
| 80 |
-
nn3_img, nn3_bg, nn3_fg, nn3_m,
|
| 81 |
]
|
| 82 |
|
| 83 |
def on_noise_select(sid):
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def build_noise_demo(sample_id: str | None):
|
| 17 |
+
"""Returns (baseline, nn1, nn2, nn3) each: (spec, bg_wav, fg_wav, m_wav); nn1/nn2/nn3 also have prompt."""
|
| 18 |
if not sample_id:
|
| 19 |
+
return (None,) * 19
|
| 20 |
data = get_noise_demo_paths(sample_id)
|
| 21 |
out = []
|
| 22 |
for key in ("baseline", "nn1", "nn2", "nn3"):
|
|
|
|
| 27 |
block.get("fg_wav"),
|
| 28 |
block.get("m_wav"),
|
| 29 |
])
|
| 30 |
+
if key.startswith("nn"):
|
| 31 |
+
prompt = block.get("prompt", "") or ""
|
| 32 |
+
out.append(f"**Prompt:** {prompt}" if prompt else "")
|
| 33 |
return tuple(out)
|
| 34 |
|
| 35 |
|
|
|
|
| 54 |
label="Noise (ID)",
|
| 55 |
)
|
| 56 |
|
| 57 |
+
def block_ui(title: str, with_prompt: bool = False):
|
| 58 |
with gr.Group():
|
| 59 |
gr.Markdown(f"### {title}")
|
| 60 |
+
if with_prompt:
|
| 61 |
+
prompt_md = gr.Markdown(value="", elem_id=f"{title}_prompt")
|
| 62 |
img = gr.Image(label=f"{title}", show_label=True)
|
| 63 |
with gr.Row():
|
| 64 |
abg = gr.Audio(label="BG", show_label=True)
|
| 65 |
afg = gr.Audio(label="FG", show_label=True)
|
| 66 |
am = gr.Audio(label="Mix", show_label=True)
|
| 67 |
+
if with_prompt:
|
| 68 |
+
return prompt_md, img, abg, afg, am
|
| 69 |
return img, abg, afg, am
|
| 70 |
|
| 71 |
# Baseline
|
| 72 |
bl_img, bl_bg, bl_fg, bl_m = block_ui("Baseline")
|
| 73 |
gr.Markdown("---")
|
| 74 |
# NN1
|
| 75 |
+
nn1_prompt, nn1_img, nn1_bg, nn1_fg, nn1_m = block_ui("NN1", with_prompt=True)
|
| 76 |
gr.Markdown("---")
|
| 77 |
# NN2
|
| 78 |
+
nn2_prompt, nn2_img, nn2_bg, nn2_fg, nn2_m = block_ui("NN2", with_prompt=True)
|
| 79 |
gr.Markdown("---")
|
| 80 |
# NN3
|
| 81 |
+
nn3_prompt, nn3_img, nn3_bg, nn3_fg, nn3_m = block_ui("NN3", with_prompt=True)
|
| 82 |
|
| 83 |
all_outputs = [
|
| 84 |
bl_img, bl_bg, bl_fg, bl_m,
|
| 85 |
+
nn1_img, nn1_bg, nn1_fg, nn1_m, nn1_prompt,
|
| 86 |
+
nn2_img, nn2_bg, nn2_fg, nn2_m, nn2_prompt,
|
| 87 |
+
nn3_img, nn3_bg, nn3_fg, nn3_m, nn3_prompt,
|
| 88 |
]
|
| 89 |
|
| 90 |
def on_noise_select(sid):
|
dataset_loader.py
CHANGED
|
@@ -105,13 +105,20 @@ def get_noise_demo_paths(bid: str) -> dict:
|
|
| 105 |
"""
|
| 106 |
One block per method: baseline, nn1, nn2, nn3.
|
| 107 |
Each block: one combined image (spec) + 3 audios (bg_wav, fg_wav, m_wav).
|
| 108 |
-
|
|
|
|
| 109 |
"""
|
| 110 |
inner = f"{ROOT_PREFIX}{bid}/{bid}"
|
| 111 |
files = _find_files(inner)
|
| 112 |
baseline_inner = f"{inner}/baseline"
|
| 113 |
baseline_files = _find_files(baseline_inner) if any(f.startswith(baseline_inner) for f in files) else []
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
def collect_block(file_list: list, folder_prefix: str) -> dict:
|
| 116 |
"""From files under folder_prefix, get spec + bg_wav, fg_wav, m_wav."""
|
| 117 |
spec = bg = fg = m = None
|
|
@@ -144,10 +151,12 @@ def get_noise_demo_paths(bid: str) -> dict:
|
|
| 144 |
break
|
| 145 |
baseline_block = collect_block(baseline_files, baseline_prefix) if baseline_prefix else {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
|
| 146 |
|
| 147 |
-
# NN1, NN2, NN3: generated_01_, generated_02_, generated_03_
|
| 148 |
result = {"baseline": baseline_block}
|
| 149 |
for i in range(1, 4):
|
| 150 |
rel_prefix = f"generated_{i:02d}_"
|
| 151 |
nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
|
| 152 |
-
|
|
|
|
|
|
|
| 153 |
return result
|
|
|
|
| 105 |
"""
|
| 106 |
One block per method: baseline, nn1, nn2, nn3.
|
| 107 |
Each block: one combined image (spec) + 3 audios (bg_wav, fg_wav, m_wav).
|
| 108 |
+
nn1/nn2/nn3 also include "prompt" (retrieved text).
|
| 109 |
+
Returns { "baseline": {spec, bg_wav, fg_wav, m_wav}, "nn1": {..., prompt}, ... }.
|
| 110 |
"""
|
| 111 |
inner = f"{ROOT_PREFIX}{bid}/{bid}"
|
| 112 |
files = _find_files(inner)
|
| 113 |
baseline_inner = f"{inner}/baseline"
|
| 114 |
baseline_files = _find_files(baseline_inner) if any(f.startswith(baseline_inner) for f in files) else []
|
| 115 |
|
| 116 |
+
prompts = _load_json_from_repo(f"{inner}/temp_retrieval.json")
|
| 117 |
+
if not prompts:
|
| 118 |
+
prompts = _load_json_from_repo(f"{inner}/natural_prompts.json")
|
| 119 |
+
if not prompts:
|
| 120 |
+
prompts = []
|
| 121 |
+
|
| 122 |
def collect_block(file_list: list, folder_prefix: str) -> dict:
|
| 123 |
"""From files under folder_prefix, get spec + bg_wav, fg_wav, m_wav."""
|
| 124 |
spec = bg = fg = m = None
|
|
|
|
| 151 |
break
|
| 152 |
baseline_block = collect_block(baseline_files, baseline_prefix) if baseline_prefix else {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
|
| 153 |
|
| 154 |
+
# NN1, NN2, NN3: generated_01_, generated_02_, generated_03_ + prompt from retrieval JSON
|
| 155 |
result = {"baseline": baseline_block}
|
| 156 |
for i in range(1, 4):
|
| 157 |
rel_prefix = f"generated_{i:02d}_"
|
| 158 |
nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
|
| 159 |
+
block = collect_block(nn_files, rel_prefix)
|
| 160 |
+
block["prompt"] = prompts[i - 1].get("prompt", "") if i <= len(prompts) else ""
|
| 161 |
+
result[f"nn{i}"] = block
|
| 162 |
return result
|