Rachel Ding commited on
Commit
c8edbbe
·
1 Parent(s): 25c33f6

Results: add Results (CLAP) and Results (Dasheng) views; dataset_loader supports dasheng results

Browse files
Files changed (2) hide show
  1. app.py +20 -12
  2. dataset_loader.py +39 -9
app.py CHANGED
@@ -36,7 +36,7 @@ def build_nn_view(sample_id: str | None, use_dasheng: bool = False):
36
  return tuple(out[:50])
37
 
38
 
39
- def build_results_view(sample_id: str | None):
40
  """
41
  Results view: 3 blocks. Per block:
42
  - Row1: Gaussian | Youtube spec + their BG/FG/Mix
@@ -44,7 +44,8 @@ def build_results_view(sample_id: str | None):
44
  """
45
  if not sample_id:
46
  return (None,) * (3 * (1 + 4 * 4))
47
- data = get_results_demo_paths(sample_id)
 
48
  out = []
49
  for i in range(1, 4):
50
  block = data.get(f"block{i}", {})
@@ -66,7 +67,12 @@ with gr.Blocks(title="NearestNeighbor Audio Demo", css=".gradio-container { max-
66
  gr.Markdown("Data from [AE-W/batch_outputs](https://huggingface.co/datasets/AE-W/batch_outputs)")
67
 
68
  view_radio = gr.Radio(
69
- choices=["Nearest Neighbor (CLAP)", "Nearest Neighbor (Dasheng)", "Results"],
 
 
 
 
 
70
  value="Nearest Neighbor (CLAP)",
71
  label="View",
72
  )
@@ -97,7 +103,7 @@ with gr.Blocks(title="NearestNeighbor Audio Demo", css=".gradio-container { max-
97
 
98
  # ---- Results View: 3 prompts, each with 2 rows (Gaussian|Youtube, Ours|NN baseline) ----
99
  with gr.Column(visible=False) as res_col:
100
- gr.Markdown("### Results: 3 baselines + Ours (top 3 prompts)")
101
  res_outputs = []
102
  for i in range(1, 4):
103
  with gr.Group():
@@ -131,8 +137,8 @@ with gr.Blocks(title="NearestNeighbor Audio Demo", css=".gradio-container { max-
131
  res_outputs.append(gr.Audio(label="Mix"))
132
 
133
  def on_change(sid, view):
134
- use_dasheng = view == "Nearest Neighbor (Dasheng)"
135
- # Dasheng view: only show IDs that exist in batch_outputs_dasheng (no fold*)
136
  if use_dasheng:
137
  choices = DASHENG_SAMPLE_IDS
138
  if sid not in DASHENG_SAMPLE_IDS and DASHENG_SAMPLE_IDS:
@@ -141,17 +147,19 @@ with gr.Blocks(title="NearestNeighbor Audio Demo", css=".gradio-container { max-
141
  choices = SAMPLE_IDS
142
  if sid not in SAMPLE_IDS and SAMPLE_IDS:
143
  sid = SAMPLE_IDS[0]
144
- nn_vals = build_nn_view(sid, use_dasheng=use_dasheng)
145
- res_vals = build_results_view(sid)
146
  is_nn = view in ("Nearest Neighbor (CLAP)", "Nearest Neighbor (Dasheng)")
147
- nn_title = "### Nearest Neighbor (Dasheng): Baseline outputs (top 10 prompts)" if use_dasheng else "### Nearest Neighbor (CLAP): Baseline outputs (top 10 prompts)"
 
 
 
 
148
  dd_update = gr.update(choices=choices, value=sid)
149
  return (
150
- [gr.update(value=nn_title)] + list(nn_vals) + list(res_vals) +
151
- [gr.update(visible=is_nn), gr.update(visible=not is_nn), dd_update]
152
  )
153
 
154
- all_outputs = [nn_section_title] + nn_outputs + res_outputs + [nn_col, res_col, noise_dd]
155
 
156
  noise_dd.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
157
  view_radio.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
 
36
  return tuple(out[:50])
37
 
38
 
39
+ def build_results_view(sample_id: str | None, use_dasheng: bool = False):
40
  """
41
  Results view: 3 blocks. Per block:
42
  - Row1: Gaussian | Youtube spec + their BG/FG/Mix
 
44
  """
45
  if not sample_id:
46
  return (None,) * (3 * (1 + 4 * 4))
47
+ root_prefix = DASHENG_PREFIX if use_dasheng else None
48
+ data = get_results_demo_paths(sample_id, root_prefix=root_prefix)
49
  out = []
50
  for i in range(1, 4):
51
  block = data.get(f"block{i}", {})
 
67
  gr.Markdown("Data from [AE-W/batch_outputs](https://huggingface.co/datasets/AE-W/batch_outputs)")
68
 
69
  view_radio = gr.Radio(
70
+ choices=[
71
+ "Nearest Neighbor (CLAP)",
72
+ "Nearest Neighbor (Dasheng)",
73
+ "Results (CLAP)",
74
+ "Results (Dasheng)",
75
+ ],
76
  value="Nearest Neighbor (CLAP)",
77
  label="View",
78
  )
 
103
 
104
  # ---- Results View: 3 prompts, each with 2 rows (Gaussian|Youtube, Ours|NN baseline) ----
105
  with gr.Column(visible=False) as res_col:
106
+ res_section_title = gr.Markdown("### Results (CLAP): 3 baselines + Ours (top 3 prompts)")
107
  res_outputs = []
108
  for i in range(1, 4):
109
  with gr.Group():
 
137
  res_outputs.append(gr.Audio(label="Mix"))
138
 
139
  def on_change(sid, view):
140
+ use_dasheng = view in ("Nearest Neighbor (Dasheng)", "Results (Dasheng)")
141
+ # Dasheng views: only show IDs that exist in batch_outputs_dasheng (no fold*)
142
  if use_dasheng:
143
  choices = DASHENG_SAMPLE_IDS
144
  if sid not in DASHENG_SAMPLE_IDS and DASHENG_SAMPLE_IDS:
 
147
  choices = SAMPLE_IDS
148
  if sid not in SAMPLE_IDS and SAMPLE_IDS:
149
  sid = SAMPLE_IDS[0]
 
 
150
  is_nn = view in ("Nearest Neighbor (CLAP)", "Nearest Neighbor (Dasheng)")
151
+ is_res = view in ("Results (CLAP)", "Results (Dasheng)")
152
+ nn_vals = build_nn_view(sid, use_dasheng=(view == "Nearest Neighbor (Dasheng)"))
153
+ res_vals = build_results_view(sid, use_dasheng=(view == "Results (Dasheng)"))
154
+ nn_title = "### Nearest Neighbor (Dasheng): Baseline outputs (top 10 prompts)" if view == "Nearest Neighbor (Dasheng)" else "### Nearest Neighbor (CLAP): Baseline outputs (top 10 prompts)"
155
+ res_title = "### Results (Dasheng): 3 baselines + Ours (top 3 prompts)" if view == "Results (Dasheng)" else "### Results (CLAP): 3 baselines + Ours (top 3 prompts)"
156
  dd_update = gr.update(choices=choices, value=sid)
157
  return (
158
+ [gr.update(value=nn_title)] + list(nn_vals) + [gr.update(value=res_title)] + list(res_vals) +
159
+ [gr.update(visible=is_nn), gr.update(visible=is_res), dd_update]
160
  )
161
 
162
+ all_outputs = [nn_section_title] + nn_outputs + [res_section_title] + res_outputs + [nn_col, res_col, noise_dd]
163
 
164
  noise_dd.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
165
  view_radio.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
dataset_loader.py CHANGED
@@ -191,12 +191,14 @@ def get_noise_demo_paths(bid: str) -> dict:
191
  return result
192
 
193
 
194
- def get_results_demo_paths(bid: str) -> dict:
195
  """
196
  For Results view: 3 blocks (prompts 1-3), each with 4 columns:
197
  Baseline (original), Gaussian, Youtube-noise, Ours.
 
198
  """
199
- inner = f"{ROOT_PREFIX}{bid}/{bid}"
 
200
  files = _find_files(inner)
201
  baseline_inner = f"{inner}/baseline"
202
  gaussian_inner = f"{inner}/gaussian_baseline"
@@ -211,6 +213,8 @@ def get_results_demo_paths(bid: str) -> dict:
211
  if not prompts:
212
  prompts = []
213
 
 
 
214
  def get_baseline_folders(bl_inner, bl_files):
215
  seen = set()
216
  folders = []
@@ -223,6 +227,17 @@ def get_results_demo_paths(bid: str) -> dict:
223
  return folders
224
 
225
  def get_youtube_folders():
 
 
 
 
 
 
 
 
 
 
 
226
  seen = set()
227
  folders = []
228
  for f in youtube_files:
@@ -251,13 +266,28 @@ def get_results_demo_paths(bid: str) -> dict:
251
  gaussian_block = _collect_block(gaussian_files, gaussian_inner)
252
 
253
  bl_youtube = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
254
- for fn, fp in youtube_folders:
255
- if fn.startswith(rel_prefix):
256
- bl_youtube = _collect_block(youtube_files, fp)
257
- break
258
-
259
- nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
260
- ours_block = _collect_block(nn_files, rel_prefix)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
  result[f"block{i}"] = {
263
  "prompt": prompt_text,
 
191
  return result
192
 
193
 
194
+ def get_results_demo_paths(bid: str, root_prefix: Optional[str] = None) -> dict:
195
  """
196
  For Results view: 3 blocks (prompts 1-3), each with 4 columns:
197
  Baseline (original), Gaussian, Youtube-noise, Ours.
198
+ root_prefix: None = batch_outputs (CLAP), DASHENG_PREFIX = batch_outputs_dasheng.
199
  """
200
+ prefix = root_prefix if root_prefix is not None else ROOT_PREFIX
201
+ inner = f"{prefix}{bid}/{bid}"
202
  files = _find_files(inner)
203
  baseline_inner = f"{inner}/baseline"
204
  gaussian_inner = f"{inner}/gaussian_baseline"
 
213
  if not prompts:
214
  prompts = []
215
 
216
+ use_dasheng = root_prefix == DASHENG_PREFIX
217
+
218
  def get_baseline_folders(bl_inner, bl_files):
219
  seen = set()
220
  folders = []
 
227
  return folders
228
 
229
  def get_youtube_folders():
230
+ if use_dasheng:
231
+ # Dasheng: subdirs are prompt names (underscores)
232
+ seen = set()
233
+ folders = []
234
+ for f in youtube_files:
235
+ parts = f.replace(youtube_inner + "/", "").split("/")
236
+ if parts and parts[0] not in seen:
237
+ seen.add(parts[0])
238
+ folders.append((parts[0], youtube_inner + "/" + parts[0]))
239
+ folders.sort(key=lambda x: x[0])
240
+ return folders
241
  seen = set()
242
  folders = []
243
  for f in youtube_files:
 
266
  gaussian_block = _collect_block(gaussian_files, gaussian_inner)
267
 
268
  bl_youtube = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
269
+ if use_dasheng:
270
+ # Dasheng: match by prompt -> folder name (spaces to underscores)
271
+ folder_name = prompt_text.replace(" ", "_") if prompt_text else ""
272
+ for fn, fp in youtube_folders:
273
+ if fn == folder_name:
274
+ bl_youtube = _collect_block(youtube_files, fp)
275
+ break
276
+ else:
277
+ for fn, fp in youtube_folders:
278
+ if fn.startswith(rel_prefix):
279
+ bl_youtube = _collect_block(youtube_files, fp)
280
+ break
281
+
282
+ if use_dasheng:
283
+ # Dasheng: "ours" = prompt-named folder under inner
284
+ folder_name = prompt_text.replace(" ", "_") if prompt_text else ""
285
+ ours_prefix = f"{inner}/{folder_name}"
286
+ nn_files = [f for f in files if f.startswith(ours_prefix + "/")]
287
+ ours_block = _collect_block(nn_files, ours_prefix)
288
+ else:
289
+ nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
290
+ ours_block = _collect_block(nn_files, inner + "/" + rel_prefix)
291
 
292
  result[f"block{i}"] = {
293
  "prompt": prompt_text,