Rachel Ding committed on
Commit
6ed909f
·
1 Parent(s): b303c51

Show retrieved prompt for each NN1/NN2/NN3

Browse files
Files changed (2) hide show
  1. app.py +16 -9
  2. dataset_loader.py +12 -3
app.py CHANGED
@@ -14,9 +14,9 @@ TOP1_ID = SAMPLE_IDS[0] if SAMPLE_IDS else None
14
 
15
 
16
  def build_noise_demo(sample_id: str | None):
17
- """Returns (baseline, nn1, nn2, nn3) each: (spec, bg_wav, fg_wav, m_wav)."""
18
  if not sample_id:
19
- return (None,) * 16
20
  data = get_noise_demo_paths(sample_id)
21
  out = []
22
  for key in ("baseline", "nn1", "nn2", "nn3"):
@@ -27,6 +27,9 @@ def build_noise_demo(sample_id: str | None):
27
  block.get("fg_wav"),
28
  block.get("m_wav"),
29
  ])
 
 
 
30
  return tuple(out)
31
 
32
 
@@ -51,33 +54,37 @@ with gr.Blocks(title="NearestNeighbor Audio Demo", css=".gradio-container { max-
51
  label="Noise (ID)",
52
  )
53
 
54
- def block_ui(title: str):
55
  with gr.Group():
56
  gr.Markdown(f"### {title}")
 
 
57
  img = gr.Image(label=f"{title}", show_label=True)
58
  with gr.Row():
59
  abg = gr.Audio(label="BG", show_label=True)
60
  afg = gr.Audio(label="FG", show_label=True)
61
  am = gr.Audio(label="Mix", show_label=True)
 
 
62
  return img, abg, afg, am
63
 
64
  # Baseline
65
  bl_img, bl_bg, bl_fg, bl_m = block_ui("Baseline")
66
  gr.Markdown("---")
67
  # NN1
68
- nn1_img, nn1_bg, nn1_fg, nn1_m = block_ui("NN1")
69
  gr.Markdown("---")
70
  # NN2
71
- nn2_img, nn2_bg, nn2_fg, nn2_m = block_ui("NN2")
72
  gr.Markdown("---")
73
  # NN3
74
- nn3_img, nn3_bg, nn3_fg, nn3_m = block_ui("NN3")
75
 
76
  all_outputs = [
77
  bl_img, bl_bg, bl_fg, bl_m,
78
- nn1_img, nn1_bg, nn1_fg, nn1_m,
79
- nn2_img, nn2_bg, nn2_fg, nn2_m,
80
- nn3_img, nn3_bg, nn3_fg, nn3_m,
81
  ]
82
 
83
  def on_noise_select(sid):
 
14
 
15
 
16
  def build_noise_demo(sample_id: str | None):
17
+ """Returns (baseline, nn1, nn2, nn3) each: (spec, bg_wav, fg_wav, m_wav); nn1/nn2/nn3 also have prompt."""
18
  if not sample_id:
19
+ return (None,) * 19
20
  data = get_noise_demo_paths(sample_id)
21
  out = []
22
  for key in ("baseline", "nn1", "nn2", "nn3"):
 
27
  block.get("fg_wav"),
28
  block.get("m_wav"),
29
  ])
30
+ if key.startswith("nn"):
31
+ prompt = block.get("prompt", "") or ""
32
+ out.append(f"**Prompt:** {prompt}" if prompt else "")
33
  return tuple(out)
34
 
35
 
 
54
  label="Noise (ID)",
55
  )
56
 
57
+ def block_ui(title: str, with_prompt: bool = False):
58
  with gr.Group():
59
  gr.Markdown(f"### {title}")
60
+ if with_prompt:
61
+ prompt_md = gr.Markdown(value="", elem_id=f"{title}_prompt")
62
  img = gr.Image(label=f"{title}", show_label=True)
63
  with gr.Row():
64
  abg = gr.Audio(label="BG", show_label=True)
65
  afg = gr.Audio(label="FG", show_label=True)
66
  am = gr.Audio(label="Mix", show_label=True)
67
+ if with_prompt:
68
+ return prompt_md, img, abg, afg, am
69
  return img, abg, afg, am
70
 
71
  # Baseline
72
  bl_img, bl_bg, bl_fg, bl_m = block_ui("Baseline")
73
  gr.Markdown("---")
74
  # NN1
75
+ nn1_prompt, nn1_img, nn1_bg, nn1_fg, nn1_m = block_ui("NN1", with_prompt=True)
76
  gr.Markdown("---")
77
  # NN2
78
+ nn2_prompt, nn2_img, nn2_bg, nn2_fg, nn2_m = block_ui("NN2", with_prompt=True)
79
  gr.Markdown("---")
80
  # NN3
81
+ nn3_prompt, nn3_img, nn3_bg, nn3_fg, nn3_m = block_ui("NN3", with_prompt=True)
82
 
83
  all_outputs = [
84
  bl_img, bl_bg, bl_fg, bl_m,
85
+ nn1_img, nn1_bg, nn1_fg, nn1_m, nn1_prompt,
86
+ nn2_img, nn2_bg, nn2_fg, nn2_m, nn2_prompt,
87
+ nn3_img, nn3_bg, nn3_fg, nn3_m, nn3_prompt,
88
  ]
89
 
90
  def on_noise_select(sid):
dataset_loader.py CHANGED
@@ -105,13 +105,20 @@ def get_noise_demo_paths(bid: str) -> dict:
105
  """
106
  One block per method: baseline, nn1, nn2, nn3.
107
  Each block: one combined image (spec) + 3 audios (bg_wav, fg_wav, m_wav).
108
- Returns { "baseline": {spec, bg_wav, fg_wav, m_wav}, "nn1": {...}, "nn2": {...}, "nn3": {...} }.
 
109
  """
110
  inner = f"{ROOT_PREFIX}{bid}/{bid}"
111
  files = _find_files(inner)
112
  baseline_inner = f"{inner}/baseline"
113
  baseline_files = _find_files(baseline_inner) if any(f.startswith(baseline_inner) for f in files) else []
114
 
 
 
 
 
 
 
115
  def collect_block(file_list: list, folder_prefix: str) -> dict:
116
  """From files under folder_prefix, get spec + bg_wav, fg_wav, m_wav."""
117
  spec = bg = fg = m = None
@@ -144,10 +151,12 @@ def get_noise_demo_paths(bid: str) -> dict:
144
  break
145
  baseline_block = collect_block(baseline_files, baseline_prefix) if baseline_prefix else {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
146
 
147
- # NN1, NN2, NN3: generated_01_, generated_02_, generated_03_
148
  result = {"baseline": baseline_block}
149
  for i in range(1, 4):
150
  rel_prefix = f"generated_{i:02d}_"
151
  nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
152
- result[f"nn{i}"] = collect_block(nn_files, rel_prefix)
 
 
153
  return result
 
105
  """
106
  One block per method: baseline, nn1, nn2, nn3.
107
  Each block: one combined image (spec) + 3 audios (bg_wav, fg_wav, m_wav).
108
+ nn1/nn2/nn3 also include "prompt" (retrieved text).
109
+ Returns { "baseline": {spec, bg_wav, fg_wav, m_wav}, "nn1": {..., prompt}, ... }.
110
  """
111
  inner = f"{ROOT_PREFIX}{bid}/{bid}"
112
  files = _find_files(inner)
113
  baseline_inner = f"{inner}/baseline"
114
  baseline_files = _find_files(baseline_inner) if any(f.startswith(baseline_inner) for f in files) else []
115
 
116
+ prompts = _load_json_from_repo(f"{inner}/temp_retrieval.json")
117
+ if not prompts:
118
+ prompts = _load_json_from_repo(f"{inner}/natural_prompts.json")
119
+ if not prompts:
120
+ prompts = []
121
+
122
  def collect_block(file_list: list, folder_prefix: str) -> dict:
123
  """From files under folder_prefix, get spec + bg_wav, fg_wav, m_wav."""
124
  spec = bg = fg = m = None
 
151
  break
152
  baseline_block = collect_block(baseline_files, baseline_prefix) if baseline_prefix else {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
153
 
154
+ # NN1, NN2, NN3: generated_01_, generated_02_, generated_03_ + prompt from retrieval JSON
155
  result = {"baseline": baseline_block}
156
  for i in range(1, 4):
157
  rel_prefix = f"generated_{i:02d}_"
158
  nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
159
+ block = collect_block(nn_files, rel_prefix)
160
+ block["prompt"] = prompts[i - 1].get("prompt", "") if i <= len(prompts) else ""
161
+ result[f"nn{i}"] = block
162
  return result