Spaces:
Running
Running
Rachel Ding commited on
Commit ·
349e3bf
1
Parent(s): b21e8c5
Dasheng view: dropdown only shows batch_outputs_dasheng sample IDs (no fold*)
Browse files- app.py +14 -2
- dataset_loader.py +5 -0
app.py
CHANGED
|
@@ -9,12 +9,14 @@ import gradio as gr
|
|
| 9 |
from dataset_loader import (
|
| 10 |
DASHENG_PREFIX,
|
| 11 |
list_samples,
|
|
|
|
| 12 |
get_nn_demo_paths,
|
| 13 |
get_results_demo_paths,
|
| 14 |
)
|
| 15 |
|
| 16 |
|
| 17 |
SAMPLE_IDS = list_samples()
|
|
|
|
| 18 |
TOP1_ID = SAMPLE_IDS[0] if SAMPLE_IDS else None
|
| 19 |
|
| 20 |
|
|
@@ -130,16 +132,26 @@ with gr.Blocks(title="NearestNeighbor Audio Demo", css=".gradio-container { max-
|
|
| 130 |
|
| 131 |
def on_change(sid, view):
|
| 132 |
use_dasheng = view == "Nearest Neighbor (Dasheng)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
nn_vals = build_nn_view(sid, use_dasheng=use_dasheng)
|
| 134 |
res_vals = build_results_view(sid)
|
| 135 |
is_nn = view in ("Nearest Neighbor (CLAP)", "Nearest Neighbor (Dasheng)")
|
| 136 |
nn_title = "### Nearest Neighbor (Dasheng): Baseline outputs (top 10 prompts)" if use_dasheng else "### Nearest Neighbor (CLAP): Baseline outputs (top 10 prompts)"
|
|
|
|
| 137 |
return (
|
| 138 |
[gr.update(value=nn_title)] + list(nn_vals) + list(res_vals) +
|
| 139 |
-
[gr.update(visible=is_nn), gr.update(visible=not is_nn)]
|
| 140 |
)
|
| 141 |
|
| 142 |
-
all_outputs = [nn_section_title] + nn_outputs + res_outputs + [nn_col, res_col]
|
| 143 |
|
| 144 |
noise_dd.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
|
| 145 |
view_radio.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
|
|
|
|
| 9 |
from dataset_loader import (
|
| 10 |
DASHENG_PREFIX,
|
| 11 |
list_samples,
|
| 12 |
+
list_samples_dasheng,
|
| 13 |
get_nn_demo_paths,
|
| 14 |
get_results_demo_paths,
|
| 15 |
)
|
| 16 |
|
| 17 |
|
| 18 |
SAMPLE_IDS = list_samples()
|
| 19 |
+
DASHENG_SAMPLE_IDS = list_samples_dasheng() # Only IDs in batch_outputs_dasheng (no fold*)
|
| 20 |
TOP1_ID = SAMPLE_IDS[0] if SAMPLE_IDS else None
|
| 21 |
|
| 22 |
|
|
|
|
| 132 |
|
| 133 |
def on_change(sid, view):
|
| 134 |
use_dasheng = view == "Nearest Neighbor (Dasheng)"
|
| 135 |
+
# Dasheng view: only show IDs that exist in batch_outputs_dasheng (no fold*)
|
| 136 |
+
if use_dasheng:
|
| 137 |
+
choices = DASHENG_SAMPLE_IDS
|
| 138 |
+
if sid not in DASHENG_SAMPLE_IDS and DASHENG_SAMPLE_IDS:
|
| 139 |
+
sid = DASHENG_SAMPLE_IDS[0]
|
| 140 |
+
else:
|
| 141 |
+
choices = SAMPLE_IDS
|
| 142 |
+
if sid not in SAMPLE_IDS and SAMPLE_IDS:
|
| 143 |
+
sid = SAMPLE_IDS[0]
|
| 144 |
nn_vals = build_nn_view(sid, use_dasheng=use_dasheng)
|
| 145 |
res_vals = build_results_view(sid)
|
| 146 |
is_nn = view in ("Nearest Neighbor (CLAP)", "Nearest Neighbor (Dasheng)")
|
| 147 |
nn_title = "### Nearest Neighbor (Dasheng): Baseline outputs (top 10 prompts)" if use_dasheng else "### Nearest Neighbor (CLAP): Baseline outputs (top 10 prompts)"
|
| 148 |
+
dd_update = gr.update(choices=choices, value=sid)
|
| 149 |
return (
|
| 150 |
[gr.update(value=nn_title)] + list(nn_vals) + list(res_vals) +
|
| 151 |
+
[gr.update(visible=is_nn), gr.update(visible=not is_nn), dd_update]
|
| 152 |
)
|
| 153 |
|
| 154 |
+
all_outputs = [nn_section_title] + nn_outputs + res_outputs + [nn_col, res_col, noise_dd]
|
| 155 |
|
| 156 |
noise_dd.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
|
| 157 |
view_radio.change(on_change, inputs=[noise_dd, view_radio], outputs=all_outputs)
|
dataset_loader.py
CHANGED
|
@@ -61,6 +61,11 @@ def list_samples() -> list[str]:
|
|
| 61 |
return _get_all_sample_ids()
|
| 62 |
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def _find_files(inner: str) -> list[str]:
|
| 65 |
"""List all repo files under inner path."""
|
| 66 |
files = list_repo_files(REPO_ID, repo_type=REPO_TYPE)
|
|
|
|
| 61 |
return _get_all_sample_ids()
|
| 62 |
|
| 63 |
|
| 64 |
+
def list_samples_dasheng() -> list[str]:
|
| 65 |
+
"""Return only sample IDs that exist in batch_outputs_dasheng (no fold* from UrbanSound8k)."""
|
| 66 |
+
return _get_sample_ids(DASHENG_PREFIX)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
def _find_files(inner: str) -> list[str]:
|
| 70 |
"""List all repo files under inner path."""
|
| 71 |
files = list_repo_files(REPO_ID, repo_type=REPO_TYPE)
|