Spaces:
Sleeping
Sleeping
| """ | |
| Load data from Hugging Face dataset AE-W/batch_outputs. | |
| Uses huggingface_hub to list and download files on demand. | |
| """ | |
| import json | |
| import os | |
| import re | |
| from pathlib import Path | |
| from typing import Optional | |
| from huggingface_hub import HfApi, hf_hub_download, list_repo_files | |
| from huggingface_hub import HfFileSystem | |
# Hugging Face repo that stores every artifact read by this module.
REPO_ID = "AE-W/batch_outputs"
REPO_TYPE = "dataset"

# Legacy top-level folders used when no method is given.
ROOT_PREFIX = "batch_outputs/"
DASHENG_PREFIX = "batch_outputs_dasheng/"

# Three methods (bin, clap, dasheng): each has a batch_outputs_* folder
# and a generated_noises_* folder.
BIN_BATCH_PREFIX = "batch_outputs_bin/"
BIN_GENERATED_PREFIX = "generated_noises_bin/"

CLAP_BATCH_PREFIX = ROOT_PREFIX  # "batch_outputs/"
CLAP_GENERATED_PREFIX = "generated_noises_clap/"

DASHENG_BATCH_PREFIX = DASHENG_PREFIX  # "batch_outputs_dasheng/"
DASHENG_GENERATED_PREFIX = "generated_noises_dasheng/"

# Folder names under generated_noises_* that are removed from sample-ID listings.
GENERATED_SKIP_IDS = {"__pycache__", "NearestNeighbor_space_push"}

# Cache of the full repo file list so list_repo_files is called only once
# per process (major speedup); None means "not fetched yet".
_cached_repo_files: Optional[list[str]] = None
def _get_repo_files() -> list[str]:
    """Return every file path in the dataset repo.

    The listing is fetched with ``list_repo_files`` on the first call and
    stored in the module-level ``_cached_repo_files``; later calls return
    that stored list without contacting the Hub again.
    """
    global _cached_repo_files
    if _cached_repo_files is not None:
        return _cached_repo_files
    # First call in this process: fetch once and remember the result.
    _cached_repo_files = list_repo_files(REPO_ID, repo_type=REPO_TYPE)
    return _cached_repo_files
def _get_sample_ids(prefix: str = ROOT_PREFIX) -> list[str]:
    """Return the sorted sample IDs (e.g. 07_003277) found under ``prefix``.

    A sample ID is the first path component after ``prefix`` for any repo
    file that lies at least one directory below it; files sitting directly
    in ``prefix`` do not contribute an ID.
    """
    repo_files = _get_repo_files()
    # "<prefix>/<id>/..." anchored at the start of the path; group 1 is the ID.
    id_pattern = re.compile(re.escape(prefix.rstrip("/")) + r"/([^/]+)/")
    hits = (id_pattern.match(path) for path in repo_files)
    return sorted({hit.group(1) for hit in hits if hit})
def _get_all_sample_ids() -> list[str]:
    """Return the sorted union of sample IDs under ROOT_PREFIX and DASHENG_PREFIX."""
    combined = set(_get_sample_ids(ROOT_PREFIX))
    combined.update(_get_sample_ids(DASHENG_PREFIX))
    return sorted(combined)
def _download_file(path_in_repo: str, local_dir: Optional[str] = None) -> str:
    """Fetch one file of the dataset repo via ``hf_hub_download``.

    Args:
        path_in_repo: Path of the file relative to the repo root.
        local_dir: Optional directory forwarded unchanged to ``hf_hub_download``.

    Returns:
        The local filesystem path reported by ``hf_hub_download``.
    """
    download_kwargs = {
        "repo_id": REPO_ID,
        "filename": path_in_repo,
        "repo_type": REPO_TYPE,
        "local_dir": local_dir,
        # Never force a re-download of a file that is already available locally.
        "force_download": False,
    }
    return hf_hub_download(**download_kwargs)
def _load_json_from_repo(path_in_repo: str) -> Optional[list]:
    """Download a JSON file from the repo and return its parsed content.

    This is deliberately best-effort: callers probe for optional files
    (e.g. temp_retrieval.json, then natural_prompts.json) and treat ``None``
    as "not available". Any failure (download error, unreadable file,
    invalid JSON) therefore yields ``None`` instead of raising, but the
    cause is logged at debug level so it is no longer lost entirely.

    Args:
        path_in_repo: Path of the JSON file relative to the repo root.

    Returns:
        The decoded JSON value, or ``None`` if it could not be loaded.
    """
    import logging  # local import keeps this block self-contained

    try:
        local_path = _download_file(path_in_repo)
        with open(local_path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except Exception as exc:  # best-effort by design; see docstring
        logging.getLogger(__name__).debug(
            "Could not load JSON %r from repo: %s", path_in_repo, exc
        )
        return None
def list_samples() -> list[str]:
    """Return the sample IDs (bid) found under batch_outputs and batch_outputs_dasheng.

    Public wrapper around ``_get_all_sample_ids``.
    """
    sample_ids = _get_all_sample_ids()
    return sample_ids
def list_samples_bin() -> list[str]:
    """Return sorted sample IDs for Bin: batch_outputs_bin ∪ generated_noises_bin.

    Names listed in GENERATED_SKIP_IDS (``__pycache__`` etc.) are removed
    from the generated side before the union.
    """
    from_batch = _get_sample_ids(BIN_BATCH_PREFIX)
    from_generated = set(_get_sample_ids(BIN_GENERATED_PREFIX)) - GENERATED_SKIP_IDS
    return sorted(from_generated.union(from_batch))
def list_samples_clap() -> list[str]:
    """Return sorted sample IDs for Clap: batch_outputs ∪ generated_noises_clap.

    Names listed in GENERATED_SKIP_IDS are removed from the generated side
    before the union.
    """
    from_batch = _get_sample_ids(CLAP_BATCH_PREFIX)
    from_generated = set(_get_sample_ids(CLAP_GENERATED_PREFIX)) - GENERATED_SKIP_IDS
    return sorted(from_generated.union(from_batch))
def list_samples_dasheng() -> list[str]:
    """Return sorted sample IDs for Dasheng: batch_outputs_dasheng ∪ generated_noises_dasheng.

    Batch-side names starting with ``fold`` are dropped, and names listed in
    GENERATED_SKIP_IDS are dropped from the generated side.
    """
    from_batch = set()
    for sample_id in _get_sample_ids(DASHENG_BATCH_PREFIX):
        if sample_id.startswith("fold"):
            continue  # fold* folders are not samples
        from_batch.add(sample_id)
    from_generated = set(_get_sample_ids(DASHENG_GENERATED_PREFIX)) - GENERATED_SKIP_IDS
    return sorted(from_batch.union(from_generated))
def _find_files(inner: str) -> list[str]:
    """Return all repo file paths located below the directory ``inner``.

    Works on the cached repo file list; a path matches when it starts with
    ``inner`` followed by a slash.
    """
    dir_prefix = inner + "/"
    matching = []
    for repo_path in _get_repo_files():
        if repo_path.startswith(dir_prefix):
            matching.append(repo_path)
    return matching
def _list_files_under_via_fs(path_in_repo: str) -> list[str]:
    """List repo-relative paths below ``path_in_repo`` using HfFileSystem.

    This is an alternative to the cached ``list_repo_files`` listing: it
    globs ``datasets/<REPO_ID>/<path_in_repo>/**`` directly. Best-effort:
    any exception results in an empty list.

    Returns:
        Paths relative to the repo root that start with ``path_in_repo/``.
    """
    try:
        repo_root = f"datasets/{REPO_ID}/"
        listing = HfFileSystem().glob(f"datasets/{REPO_ID}/{path_in_repo}" + "/**")
        # Drop entries with a trailing slash, then make paths repo-relative.
        relative_paths = (
            entry.replace(repo_root, "").lstrip("/")
            for entry in listing
            if not entry.endswith("/")
        )
        return [rel for rel in relative_paths if rel.startswith(path_in_repo + "/")]
    except Exception:
        return []
def _has_files_under(prefix_bid: str) -> bool:
    """Report whether the cached repo listing has at least one file below ``prefix_bid/``."""
    dir_prefix = prefix_bid + "/"
    for repo_path in _get_repo_files():
        if repo_path.startswith(dir_prefix):
            return True
    return False
def resolve_sample_prefix(bid: str, method: str) -> Optional[str]:
    """Find the repo prefix that holds sample ``bid`` for the given method.

    The method's batch folder is checked first, then its generated-noises
    folder.

    Args:
        bid: Sample ID.
        method: "bin" | "clap" | "dasheng".

    Returns:
        The matching prefix (e.g. "batch_outputs_bin/" or
        "generated_noises_bin/"), or ``None`` when the method is unknown or
        neither folder has files for ``bid``.
    """
    if method not in ("bin", "clap", "dasheng"):
        return None
    # (batch prefix, generated prefix) per method, in lookup order.
    prefixes_by_method = {
        "bin": (BIN_BATCH_PREFIX, BIN_GENERATED_PREFIX),
        "clap": (CLAP_BATCH_PREFIX, CLAP_GENERATED_PREFIX),
        "dasheng": (DASHENG_BATCH_PREFIX, DASHENG_GENERATED_PREFIX),
    }
    for candidate in prefixes_by_method[method]:
        if _has_files_under(candidate + bid):
            return candidate
    return None
def get_inner_path(prefix: str, bid: str) -> Optional[str]:
    """Return the inner directory of a sample (holds baseline/, natural_prompts.json, etc.).

    The standard layout is ``prefix/bid/bid``. For ``generated_noises*``
    prefixes the data may instead live under ``prefix/bid/X/X`` (e.g.
    cars_honking/heavy_machinery/heavy_machinery); in that case the
    directory is derived from the first natural_prompts.json or
    temp_retrieval.json found below ``prefix/bid/``.

    Args:
        prefix: Repo prefix ending in a slash, e.g. "generated_noises_bin/".
        bid: Sample ID.

    Returns:
        The inner directory path without trailing slash. Falls back to
        ``prefix/bid/bid`` when nothing better is found.
    """
    inner_std = f"{prefix}{bid}/{bid}"
    if _has_files_under(inner_std):
        return inner_std
    # Only generated_noises* folders use the alternative nesting.
    if not prefix.startswith("generated_noises"):
        return inner_std

    sample_root = prefix + bid + "/"
    for repo_path in _get_repo_files():
        if not repo_path.startswith(sample_root):
            continue
        if "natural_prompts.json" in repo_path or "temp_retrieval.json" in repo_path:
            parts = repo_path.split("/")
            # prefix/bid/X/X/file -> inner = prefix/bid/X/X
            if len(parts) >= 5:
                return "/".join(parts[:4])
            # prefix/bid/X/file -> inner = prefix/bid/X. Taking parts[:4] here
            # would return the JSON file itself rather than its directory.
            if len(parts) == 4:
                return "/".join(parts[:3])
    return inner_std
def _collect_block(file_list: list, folder_prefix: str) -> dict:
    """Pick the spectrogram and the three wavs of one folder and download them.

    Only paths that contain ``folder_prefix`` (substring test) are
    considered. Files are classified by the suffix of their basename; when
    several files share a suffix, the last one in ``file_list`` wins.

    Returns:
        Dict with keys spec, bg_wav, fg_wav, m_wav, each a local path from
        ``_download_file`` or ``None`` when no such file was found.
    """
    # Basename suffix -> result key; the suffixes are mutually exclusive.
    suffix_to_key = (
        (".png", "spec"),
        ("_bg.wav", "bg_wav"),
        ("_fg.wav", "fg_wav"),
        ("_m.wav", "m_wav"),
    )
    chosen = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
    for repo_path in file_list:
        if folder_prefix not in repo_path:
            continue
        basename = repo_path.rsplit("/", 1)[-1]
        for suffix, key in suffix_to_key:
            if basename.endswith(suffix):
                chosen[key] = repo_path
                break
    return {
        key: (_download_file(repo_path) if repo_path else None)
        for key, repo_path in chosen.items()
    }
def get_nn_demo_paths(bid: str, top_k: int = 10, root_prefix: Optional[str] = None, method: Optional[str] = None) -> dict:
    """Build the NN view: one entry per prompt, in prompt order.

    Entry number k (1-based) is paired with the baseline folder whose name
    starts with ``generated_baseline_<kk>_``.

    Args:
        bid: Sample ID.
        top_k: Maximum number of prompts to include.
        root_prefix: Legacy; used only when ``method`` is None (defaults to
            ROOT_PREFIX), with the inner path assumed to be prefix/bid/bid.
        method: bin | clap | dasheng. When set, prefix and inner path are
            resolved from the repo.

    Returns:
        {nn_list: [{spec, bg_wav, fg_wav, m_wav, prompt, similarity}, ...]};
        the list is empty when the sample or its prompt file cannot be found.
    """
    if method is None:
        prefix = ROOT_PREFIX if root_prefix is None else root_prefix
        inner = f"{prefix}{bid}/{bid}"
    else:
        prefix = resolve_sample_prefix(bid, method)
        if not prefix:
            return {"nn_list": []}
        inner = get_inner_path(prefix, bid)
        if not inner:
            return {"nn_list": []}

    # temp_retrieval.json is preferred; natural_prompts.json is the fallback.
    prompts = _load_json_from_repo(f"{inner}/temp_retrieval.json") or _load_json_from_repo(
        f"{inner}/natural_prompts.json"
    )
    if not prompts:
        return {"nn_list": []}

    baseline_inner = f"{inner}/baseline"
    baseline_root = baseline_inner + "/"
    # _find_files already returns [] when nothing lives under baseline/.
    baseline_files = _find_files(baseline_inner)

    nn_list = []
    for rank, entry in enumerate(prompts[:top_k], start=1):
        prompt = entry.get("prompt", "")
        similarity = entry.get("similarity_score", entry.get("retrieval_score"))
        wanted = f"generated_baseline_{rank:02d}_"
        # First baseline file whose top-level folder matches this rank.
        folder = next(
            (
                top
                for top in (path.replace(baseline_root, "").split("/")[0] for path in baseline_files)
                if top.startswith(wanted)
            ),
            None,
        )
        if folder is None:
            block = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
        else:
            block = _collect_block(baseline_files, baseline_root + folder)
        block["prompt"] = prompt
        block["similarity"] = similarity
        nn_list.append(block)
    return {"nn_list": nn_list}
def get_noise_demo_paths(bid: str) -> dict:
    """Build one block per prompt (1, 2, 3) for a sample under ROOT_PREFIX.

    Each block carries the prompt text, the baseline files (spec + 3 wavs)
    from ``baseline/generated_baseline_<ii>_*`` and our method's files
    (spec + 3 wavs) from entries named ``generated_<ii>_*`` directly under
    the inner folder.

    Returns:
        { "block1": {prompt, baseline: {...}, nn: {...}}, "block2": ..., "block3": ... }.
    """
    inner = f"{ROOT_PREFIX}{bid}/{bid}"
    inner_root = inner + "/"
    files = _find_files(inner)
    baseline_inner = f"{inner}/baseline"
    baseline_root = baseline_inner + "/"
    # _find_files already returns [] when nothing lives under baseline/.
    baseline_files = _find_files(baseline_inner)

    # temp_retrieval.json is preferred; natural_prompts.json is the fallback.
    prompts = (
        _load_json_from_repo(f"{inner}/temp_retrieval.json")
        or _load_json_from_repo(f"{inner}/natural_prompts.json")
        or []
    )

    # Distinct baseline folder names generated_baseline_01_*, 02_*, ..., sorted.
    top_level = {path.replace(baseline_root, "").split("/")[0] for path in baseline_files}
    baseline_names = sorted(name for name in top_level if name.startswith("generated_baseline_"))

    result = {}
    for idx in (1, 2, 3):
        has_prompt = idx <= len(prompts)
        prompt_text = prompts[idx - 1].get("prompt", "") if has_prompt else ""

        wanted_baseline = f"generated_baseline_{idx:02d}_"
        baseline_name = next((n for n in baseline_names if n.startswith(wanted_baseline)), None)
        if baseline_name is None:
            baseline_block = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
        else:
            baseline_block = _collect_block(baseline_files, baseline_root + baseline_name)

        rel_prefix = f"generated_{idx:02d}_"
        nn_files = [path for path in files if path.replace(inner_root, "").startswith(rel_prefix)]
        nn_block = _collect_block(nn_files, rel_prefix)
        nn_block["prompt"] = prompt_text

        result[f"block{idx}"] = {
            "prompt": prompt_text,
            "baseline": baseline_block,
            "nn": nn_block,
        }
    return result
def get_results_demo_paths(bid: str, root_prefix: Optional[str] = None, method: Optional[str] = None) -> dict:
    """Build the Results view: 3 blocks (prompts 1-3), each with 4 columns.

    Columns: Baseline (original), Gaussian, Youtube-noise, Ours.

    Two folder-naming layouts are handled:
      * numbered: folders named ``generated_<ii>_*`` / ``generated_baseline_<ii>_*``;
      * Dasheng-style: youtube and "ours" folders are named after the prompt
        (spaces replaced by underscores, possibly truncated). Used for
        batch_outputs_bin, batch_outputs_dasheng and all generated_noises_*.

    Args:
        bid: Sample ID.
        root_prefix: Legacy; used only when ``method`` is None (defaults to
            ROOT_PREFIX), with the inner path assumed to be prefix/bid/bid.
        method: bin | clap | dasheng. When set, prefix and inner path are
            resolved from the repo.

    Returns:
        {"block1": {prompt, baseline_original, baseline_gaussian,
        baseline_youtube, ours}, "block2": ..., "block3": ...} where each
        column is a dict {spec, bg_wav, fg_wav, m_wav}. Returns {} when
        ``method`` is given but the sample cannot be resolved.
    """
    if method is not None:
        prefix = resolve_sample_prefix(bid, method)
        if not prefix:
            return {}
        inner = get_inner_path(prefix, bid)
        if not inner:
            return {}
        # Dasheng-style: prompt-named folders (batch_outputs_bin,
        # batch_outputs_dasheng, all generated_noises_*).
        use_dasheng = prefix in (BIN_BATCH_PREFIX, DASHENG_BATCH_PREFIX) or prefix.startswith("generated_noises")
    else:
        prefix = root_prefix if root_prefix is not None else ROOT_PREFIX
        inner = f"{prefix}{bid}/{bid}"
        use_dasheng = root_prefix == DASHENG_PREFIX

    inner_root = inner + "/"
    files = _find_files(inner)
    baseline_inner = f"{inner}/baseline"
    gaussian_inner = f"{inner}/gaussian_baseline"
    youtube_inner = f"{inner}/youtube_noise_baseline"

    # _find_files scans the full cached repo list and returns [] when the
    # folder is absent, so no separate existence check is needed.
    all_repo = _get_repo_files()
    baseline_files = _find_files(baseline_inner)
    gaussian_files = _find_files(gaussian_inner)
    youtube_files = _find_files(youtube_inner)

    # Fallback for bin/generated: gaussian or youtube may live under a
    # different inner (e.g. prefix/bid/X/X).
    may_be_relocated = prefix == BIN_BATCH_PREFIX or prefix.startswith("generated_noises")
    sample_root = prefix + bid + "/"
    wav_suffixes = ("_m.wav", "_bg.wav")
    if not gaussian_files and may_be_relocated:
        for f in all_repo:
            if f.startswith(sample_root) and "/gaussian_baseline/" in f and f.endswith(wav_suffixes):
                # Directory that directly contains the matched wav.
                gaussian_inner = f.rsplit("/", 1)[0]
                gaussian_files = _find_files(gaussian_inner)
                break
    if not youtube_files and may_be_relocated:
        for f in all_repo:
            if f.startswith(sample_root) and "/youtube_noise_baseline/" in f and f.endswith(wav_suffixes):
                # f = .../youtube_noise_baseline/<prompt_folder>/file_m.wav
                # -> keep everything up to and including youtube_noise_baseline.
                youtube_inner = f.split("/youtube_noise_baseline/", 1)[0] + "/youtube_noise_baseline"
                youtube_files = _find_files(youtube_inner)
                break

    # For generated_noises: list_repo_files() may not include these paths in
    # large repos; fall back to listing by path via HfFileSystem.
    if prefix.startswith("generated_noises"):
        if not gaussian_files:
            gaussian_files = _list_files_under_via_fs(gaussian_inner)
        if not youtube_files:
            youtube_files = _list_files_under_via_fs(youtube_inner)

    # temp_retrieval.json is preferred; natural_prompts.json is the fallback.
    prompts = (
        _load_json_from_repo(f"{inner}/temp_retrieval.json")
        or _load_json_from_repo(f"{inner}/natural_prompts.json")
        or []
    )

    def _empty_block() -> dict:
        """Fresh column dict with every slot unset."""
        return {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}

    def _top_level_folders(root: str, paths: list, name_prefix: str) -> list:
        """Sorted, distinct (name, full_path) of first-level entries under root whose name starts with name_prefix."""
        strip = root + "/"
        names = {p.replace(strip, "").split("/")[0] for p in paths}
        return [(n, strip + n) for n in sorted(names) if n.startswith(name_prefix)]

    def _match_dasheng_folder(folder_name: str, folders: list) -> Optional[tuple]:
        """Match prompt-derived folder_name to an actual folder; allow truncated names and hyphen/underscore."""
        if not folder_name or not folders:
            return None
        # Prompt may have "ground-level" / "low-intensity" but the dir is
        # "ground_level" / "low_inte".
        normalized = folder_name.replace("-", "_")
        # 1) Exact match.
        for fn, fp in folders:
            if fn == folder_name or fn == normalized:
                return (fn, fp)
        # 2) Actual folder is a truncation of the prompt name: longest wins.
        candidates = [(fn, fp) for fn, fp in folders if normalized.startswith(fn) or folder_name.startswith(fn)]
        if candidates:
            return max(candidates, key=lambda x: len(x[0]))
        # 3) Prompt name is a prefix of the folder name: shortest wins.
        candidates = [(fn, fp) for fn, fp in folders if fn.startswith(normalized) or fn.startswith(folder_name)]
        if candidates:
            return min(candidates, key=lambda x: len(x[0]))
        return None

    baseline_folders = _top_level_folders(baseline_inner, baseline_files, "generated_baseline_")
    # Dasheng-style youtube subdirs are prompt names, so accept any name there.
    youtube_folders = _top_level_folders(youtube_inner, youtube_files, "" if use_dasheng else "generated_")

    # Loop-invariant work, done once instead of once per prompt.
    gaussian_block = _collect_block(gaussian_files, gaussian_inner)
    inner_folders = []
    if use_dasheng:
        # Ours: prompt-named dirs under inner (exclude baseline,
        # gaussian_baseline, youtube_noise_baseline and generated_baseline*).
        skip = {"baseline", "youtube_noise_baseline", "gaussian_baseline"}
        inner_dirs = set()
        for f in files:
            if not f.startswith(inner_root):
                continue
            rest = f.replace(inner_root, "", 1)
            if "/" in rest:
                top = rest.split("/")[0]
                if top not in skip and not top.startswith("generated_baseline"):
                    inner_dirs.add(top)
        inner_folders = [(d, inner_root + d) for d in sorted(inner_dirs)]

    result = {}
    for i in range(1, 4):
        prompt_text = prompts[i - 1].get("prompt", "") if i <= len(prompts) else ""
        bl_prefix = f"generated_baseline_{i:02d}_"
        rel_prefix = f"generated_{i:02d}_"

        bl_orig = _empty_block()
        for fn, fp in baseline_folders:
            if fn.startswith(bl_prefix):
                bl_orig = _collect_block(baseline_files, fp)
                break

        bl_youtube = _empty_block()
        ours_block = _empty_block()
        if use_dasheng:
            folder_name = prompt_text.replace(" ", "_") if prompt_text else ""
            matched = _match_dasheng_folder(folder_name, youtube_folders)
            if matched:
                bl_youtube = _collect_block(youtube_files, matched[1])
            ours_match = _match_dasheng_folder(folder_name, inner_folders)
            if ours_match:
                ours_path = ours_match[1]
                nn_files = [f for f in files if f.startswith(ours_path + "/")]
                ours_block = _collect_block(nn_files, ours_path)
        else:
            for fn, fp in youtube_folders:
                if fn.startswith(rel_prefix):
                    bl_youtube = _collect_block(youtube_files, fp)
                    break
            nn_files = [f for f in files if f.replace(inner_root, "").startswith(rel_prefix)]
            ours_block = _collect_block(nn_files, inner_root + rel_prefix)

        result[f"block{i}"] = {
            "prompt": prompt_text,
            "baseline_original": bl_orig,
            # Copy so the three blocks do not share one mutable dict.
            "baseline_gaussian": dict(gaussian_block),
            "baseline_youtube": bl_youtube,
            "ours": ours_block,
        }
    return result