Rachel Ding committed on
Commit
9461a66
·
1 Parent(s): 6ed909f

Restructure: per-prompt blocks with Baseline then Ours (1 image + 3 audios each)

Browse files
Files changed (2) hide show
  1. app.py +36 -38
  2. dataset_loader.py +27 -17
app.py CHANGED
@@ -14,22 +14,19 @@ TOP1_ID = SAMPLE_IDS[0] if SAMPLE_IDS else None
14
 
15
 
16
  def build_noise_demo(sample_id: str | None):
17
- """Returns (baseline, nn1, nn2, nn3) each: (spec, bg_wav, fg_wav, m_wav); nn1/nn2/nn3 also have prompt."""
18
  if not sample_id:
19
- return (None,) * 19
20
  data = get_noise_demo_paths(sample_id)
21
  out = []
22
- for key in ("baseline", "nn1", "nn2", "nn3"):
23
- block = data.get(key, {})
24
- out.extend([
25
- block.get("spec"),
26
- block.get("bg_wav"),
27
- block.get("fg_wav"),
28
- block.get("m_wav"),
29
- ])
30
- if key.startswith("nn"):
31
- prompt = block.get("prompt", "") or ""
32
- out.append(f"**Prompt:** {prompt}" if prompt else "")
33
  return tuple(out)
34
 
35
 
@@ -54,37 +51,38 @@ with gr.Blocks(title="NearestNeighbor Audio Demo", css=".gradio-container { max-
54
  label="Noise (ID)",
55
  )
56
 
57
- def block_ui(title: str, with_prompt: bool = False):
 
58
  with gr.Group():
59
- gr.Markdown(f"### {title}")
60
- if with_prompt:
61
- prompt_md = gr.Markdown(value="", elem_id=f"{title}_prompt")
62
- img = gr.Image(label=f"{title}", show_label=True)
63
  with gr.Row():
64
- abg = gr.Audio(label="BG", show_label=True)
65
- afg = gr.Audio(label="FG", show_label=True)
66
- am = gr.Audio(label="Mix", show_label=True)
67
- if with_prompt:
68
- return prompt_md, img, abg, afg, am
69
- return img, abg, afg, am
70
-
71
- # Baseline
72
- bl_img, bl_bg, bl_fg, bl_m = block_ui("Baseline")
73
- gr.Markdown("---")
74
- # NN1
75
- nn1_prompt, nn1_img, nn1_bg, nn1_fg, nn1_m = block_ui("NN1", with_prompt=True)
 
 
76
  gr.Markdown("---")
77
- # NN2
78
- nn2_prompt, nn2_img, nn2_bg, nn2_fg, nn2_m = block_ui("NN2", with_prompt=True)
79
  gr.Markdown("---")
80
- # NN3
81
- nn3_prompt, nn3_img, nn3_bg, nn3_fg, nn3_m = block_ui("NN3", with_prompt=True)
82
 
83
  all_outputs = [
84
- bl_img, bl_bg, bl_fg, bl_m,
85
- nn1_img, nn1_bg, nn1_fg, nn1_m, nn1_prompt,
86
- nn2_img, nn2_bg, nn2_fg, nn2_m, nn2_prompt,
87
- nn3_img, nn3_bg, nn3_fg, nn3_m, nn3_prompt,
88
  ]
89
 
90
  def on_noise_select(sid):
 
14
 
15
 
16
  def build_noise_demo(sample_id: str | None):
17
+ """Returns for each of 3 blocks: prompt_md, baseline (spec, bg, fg, m), method (spec, bg, fg, m)."""
18
  if not sample_id:
19
+ return (None,) * 27
20
  data = get_noise_demo_paths(sample_id)
21
  out = []
22
+ for i in range(1, 4):
23
+ block = data.get(f"block{i}", {})
24
+ prompt = block.get("prompt", "") or ""
25
+ out.append(f"**Prompt:** {prompt}" if prompt else "")
26
+ bl = block.get("baseline", {})
27
+ out.extend([bl.get("spec"), bl.get("bg_wav"), bl.get("fg_wav"), bl.get("m_wav")])
28
+ nn = block.get("nn", {})
29
+ out.extend([nn.get("spec"), nn.get("bg_wav"), nn.get("fg_wav"), nn.get("m_wav")])
 
 
 
30
  return tuple(out)
31
 
32
 
 
51
  label="Noise (ID)",
52
  )
53
 
54
+ def prompt_block_ui(block_label: str):
55
+ """One block: prompt text, then Baseline (img + 3 audio), then Ours (img + 3 audio)."""
56
  with gr.Group():
57
+ gr.Markdown(f"### {block_label}")
58
+ prompt_md = gr.Markdown(value="")
59
+ gr.Markdown("**Baseline**")
 
60
  with gr.Row():
61
+ bl_img = gr.Image(label="Baseline", show_label=True)
62
+ bl_bg = gr.Audio(label="BG", show_label=True)
63
+ bl_fg = gr.Audio(label="FG", show_label=True)
64
+ bl_m = gr.Audio(label="Mix", show_label=True)
65
+ gr.Markdown("**Ours**")
66
+ with gr.Row():
67
+ nn_img = gr.Image(label="Ours", show_label=True)
68
+ nn_bg = gr.Audio(label="BG", show_label=True)
69
+ nn_fg = gr.Audio(label="FG", show_label=True)
70
+ nn_m = gr.Audio(label="Mix", show_label=True)
71
+ return prompt_md, bl_img, bl_bg, bl_fg, bl_m, nn_img, nn_bg, nn_fg, nn_m
72
+
73
+ # Block 1: Prompt 1 -> Baseline -> Ours
74
+ p1_md, bl1_img, bl1_bg, bl1_fg, bl1_m, nn1_img, nn1_bg, nn1_fg, nn1_m = prompt_block_ui("Prompt 1")
75
  gr.Markdown("---")
76
+ # Block 2
77
+ p2_md, bl2_img, bl2_bg, bl2_fg, bl2_m, nn2_img, nn2_bg, nn2_fg, nn2_m = prompt_block_ui("Prompt 2")
78
  gr.Markdown("---")
79
+ # Block 3
80
+ p3_md, bl3_img, bl3_bg, bl3_fg, bl3_m, nn3_img, nn3_bg, nn3_fg, nn3_m = prompt_block_ui("Prompt 3")
81
 
82
  all_outputs = [
83
+ p1_md, bl1_img, bl1_bg, bl1_fg, bl1_m, nn1_img, nn1_bg, nn1_fg, nn1_m,
84
+ p2_md, bl2_img, bl2_bg, bl2_fg, bl2_m, nn2_img, nn2_bg, nn2_fg, nn2_m,
85
+ p3_md, bl3_img, bl3_bg, bl3_fg, bl3_m, nn3_img, nn3_bg, nn3_fg, nn3_m,
 
86
  ]
87
 
88
  def on_noise_select(sid):
dataset_loader.py CHANGED
@@ -103,10 +103,8 @@ def get_nn_demo_paths(bid: str) -> dict:
103
 
104
  def get_noise_demo_paths(bid: str) -> dict:
105
  """
106
- One block per method: baseline, nn1, nn2, nn3.
107
- Each block: one combined image (spec) + 3 audios (bg_wav, fg_wav, m_wav).
108
- nn1/nn2/nn3 also include "prompt" (retrieved text).
109
- Returns { "baseline": {spec, bg_wav, fg_wav, m_wav}, "nn1": {..., prompt}, ... }.
110
  """
111
  inner = f"{ROOT_PREFIX}{bid}/{bid}"
112
  files = _find_files(inner)
@@ -141,22 +139,34 @@ def get_noise_demo_paths(bid: str) -> dict:
141
  "m_wav": _download_file(m) if m else None,
142
  }
143
 
144
- # Baseline: use first generated_baseline_* (01 or 02 etc.)
145
- baseline_prefix = None
 
146
  for f in baseline_files:
147
- if "/baseline/generated_baseline_" in f:
148
- parts = f.replace(baseline_inner + "/", "").split("/")
149
- if parts and parts[0].startswith("generated_baseline_"):
150
- baseline_prefix = baseline_inner + "/" + parts[0]
151
- break
152
- baseline_block = collect_block(baseline_files, baseline_prefix) if baseline_prefix else {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
153
 
154
- # NN1, NN2, NN3: generated_01_, generated_02_, generated_03_ + prompt from retrieval JSON
155
- result = {"baseline": baseline_block}
156
  for i in range(1, 4):
 
 
 
 
 
 
 
 
 
157
  rel_prefix = f"generated_{i:02d}_"
158
  nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
159
- block = collect_block(nn_files, rel_prefix)
160
- block["prompt"] = prompts[i - 1].get("prompt", "") if i <= len(prompts) else ""
161
- result[f"nn{i}"] = block
 
 
 
 
162
  return result
 
103
 
104
  def get_noise_demo_paths(bid: str) -> dict:
105
  """
106
+ One block per prompt (1, 2, 3): each has prompt text, baseline (spec + 3 wavs), and our method (spec + 3 wavs).
107
+ Returns { "block1": {prompt, baseline: {...}, nn: {...}}, "block2": ..., "block3": ... }.
 
 
108
  """
109
  inner = f"{ROOT_PREFIX}{bid}/{bid}"
110
  files = _find_files(inner)
 
139
  "m_wav": _download_file(m) if m else None,
140
  }
141
 
142
+ # Find baseline folder names generated_baseline_01_*, 02_*, 03_*
143
+ seen = set()
144
+ baseline_folders = []
145
  for f in baseline_files:
146
+ parts = f.replace(baseline_inner + "/", "").split("/")
147
+ if parts and parts[0].startswith("generated_baseline_") and parts[0] not in seen:
148
+ seen.add(parts[0])
149
+ baseline_folders.append((parts[0], baseline_inner + "/" + parts[0]))
150
+ baseline_folders.sort(key=lambda x: x[0])
 
151
 
152
+ result = {}
 
153
  for i in range(1, 4):
154
+ prompt_text = prompts[i - 1].get("prompt", "") if i <= len(prompts) else ""
155
+ # Baseline for this prompt: i-th baseline folder (01, 02, 03)
156
+ bl_prefix = f"generated_baseline_{i:02d}_"
157
+ baseline_block = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
158
+ for folder_name, full_prefix in baseline_folders:
159
+ if folder_name.startswith(bl_prefix):
160
+ baseline_block = collect_block(baseline_files, full_prefix)
161
+ break
162
+ # Our method: generated_0{i}_*
163
  rel_prefix = f"generated_{i:02d}_"
164
  nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
165
+ nn_block = collect_block(nn_files, rel_prefix)
166
+ nn_block["prompt"] = prompt_text
167
+ result[f"block{i}"] = {
168
+ "prompt": prompt_text,
169
+ "baseline": baseline_block,
170
+ "nn": nn_block,
171
+ }
172
  return result