"""
Load data from Hugging Face dataset AE-W/batch_outputs.

Uses huggingface_hub to list and download files on demand.

Layout assumed by the path handling below (one ``batch_outputs_*`` and one
``generated_noises_*`` prefix per method):

    <prefix><sample_id>/<...>/          "inner" dir, usually <sample_id>/<sample_id>
        temp_retrieval.json | natural_prompts.json
        baseline/generated_baseline_NN_*/...
        gaussian_baseline/...
        youtube_noise_baseline/<folder>/...
        generated_NN_*/...  or  <prompt_named_folder>/...   ("ours")
"""

import json
import os
import re
from pathlib import Path
from typing import Optional

from huggingface_hub import HfApi, hf_hub_download, list_repo_files
from huggingface_hub import HfFileSystem

REPO_ID = "AE-W/batch_outputs"
REPO_TYPE = "dataset"
ROOT_PREFIX = "batch_outputs/"
DASHENG_PREFIX = "batch_outputs_dasheng/"

# Three methods: each has batch_outputs_* + generated_noises_*
BIN_BATCH_PREFIX = "batch_outputs_bin/"
BIN_GENERATED_PREFIX = "generated_noises_bin/"
CLAP_BATCH_PREFIX = "batch_outputs/"  # ROOT_PREFIX
CLAP_GENERATED_PREFIX = "generated_noises_clap/"
DASHENG_BATCH_PREFIX = "batch_outputs_dasheng/"
DASHENG_GENERATED_PREFIX = "generated_noises_dasheng/"

# (batch prefix, generated prefix) per method; the batch prefix is tried first.
_METHOD_PREFIXES = {
    "bin": (BIN_BATCH_PREFIX, BIN_GENERATED_PREFIX),
    "clap": (CLAP_BATCH_PREFIX, CLAP_GENERATED_PREFIX),
    "dasheng": (DASHENG_BATCH_PREFIX, DASHENG_GENERATED_PREFIX),
}

# Top-level entries under generated_noises_* that are not samples.
GENERATED_SKIP_IDS = {"__pycache__", "NearestNeighbor_space_push"}

# Prompt-list files inside an inner dir, in lookup priority order.
_PROMPT_FILES = ("temp_retrieval.json", "natural_prompts.json")

# Sub-directories of an inner dir that hold baselines rather than "ours".
_BASELINE_DIRS = {"baseline", "youtube_noise_baseline", "gaussian_baseline"}

# Cache full repo file list so we only call list_repo_files once per process (major speedup)
_cached_repo_files: Optional[list[str]] = None


def _get_repo_files() -> list[str]:
    """Return full list of repo file paths, cached after first call."""
    global _cached_repo_files
    if _cached_repo_files is None:
        _cached_repo_files = list_repo_files(REPO_ID, repo_type=REPO_TYPE)
    return _cached_repo_files


def _get_sample_ids(prefix: str = ROOT_PREFIX) -> list[str]:
    """List sample IDs (e.g. 07_003277) under given prefix in repo.

    A sample ID is the path component directly below ``prefix`` that is a
    directory (at least one more ``/`` follows it). Returned sorted.
    """
    pattern = re.compile(re.escape(prefix.rstrip("/")) + r"/([^/]+)/")
    return sorted({m.group(1) for f in _get_repo_files() if (m := pattern.match(f))})


def _get_all_sample_ids() -> list[str]:
    """Union of sample IDs from batch_outputs and batch_outputs_dasheng."""
    ids = set(_get_sample_ids(ROOT_PREFIX)) | set(_get_sample_ids(DASHENG_PREFIX))
    return sorted(ids)


def _download_file(path_in_repo: str, local_dir: Optional[str] = None) -> str:
    """Download a file from the dataset; return local path."""
    return hf_hub_download(
        repo_id=REPO_ID,
        filename=path_in_repo,
        repo_type=REPO_TYPE,
        local_dir=local_dir,
        force_download=False,
    )


def _load_json_from_repo(path_in_repo: str) -> Optional[list]:
    """Download and parse a JSON file from the repo.

    Deliberately best-effort: any failure (missing file, network error,
    invalid JSON) yields ``None`` so callers can fall back to another file.
    """
    try:
        path = _download_file(path_in_repo)
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return None


def _load_prompts(inner: str) -> list:
    """Return the prompt list stored in an inner sample dir.

    Tries temp_retrieval.json, then natural_prompts.json; the first file that
    parses to a non-empty JSON list wins. Returns [] when neither does (a JSON
    object is skipped instead of crashing later slicing/indexing).
    """
    for name in _PROMPT_FILES:
        data = _load_json_from_repo(f"{inner}/{name}")
        if data and isinstance(data, list):
            return data
    return []


def _prompt_entry(prompts: list, index: int) -> dict:
    """Return prompts[index] when it exists and is a dict, else an empty dict."""
    if 0 <= index < len(prompts) and isinstance(prompts[index], dict):
        return prompts[index]
    return {}


def list_samples() -> list[str]:
    """Return list of sample IDs (bid) from both batch_outputs and batch_outputs_dasheng."""
    return _get_all_sample_ids()


def list_samples_bin() -> list[str]:
    """Sample IDs for Bin: batch_outputs_bin ∪ generated_noises_bin (exclude __pycache__ etc)."""
    batch_ids = set(_get_sample_ids(BIN_BATCH_PREFIX))
    gen_ids = set(_get_sample_ids(BIN_GENERATED_PREFIX)) - GENERATED_SKIP_IDS
    return sorted(batch_ids | gen_ids)


def list_samples_clap() -> list[str]:
    """Sample IDs for Clap: batch_outputs ∪ generated_noises_clap."""
    batch_ids = set(_get_sample_ids(CLAP_BATCH_PREFIX))
    gen_ids = set(_get_sample_ids(CLAP_GENERATED_PREFIX)) - GENERATED_SKIP_IDS
    return sorted(batch_ids | gen_ids)


def list_samples_dasheng() -> list[str]:
    """Sample IDs for Dasheng: batch_outputs_dasheng (excl. fold*) ∪ generated_noises_dasheng."""
    batch_ids = {x for x in _get_sample_ids(DASHENG_BATCH_PREFIX) if not x.startswith("fold")}
    gen_ids = set(_get_sample_ids(DASHENG_GENERATED_PREFIX)) - GENERATED_SKIP_IDS
    return sorted(batch_ids | gen_ids)


def _files_in_dir(files: list[str], directory: str) -> list[str]:
    """Return the entries of ``files`` that lie (at any depth) under ``directory``/."""
    head = directory + "/"
    return [f for f in files if f.startswith(head)]


def _find_files(inner: str) -> list[str]:
    """List all repo files under inner path (uses cached repo file list)."""
    return _files_in_dir(_get_repo_files(), inner)


def _list_files_under_via_fs(path_in_repo: str) -> list[str]:
    """List files under path_in_repo using HfFileSystem (avoids relying on full list_repo_files).

    Returns repo-relative file paths (directories excluded), or [] if the
    listing fails for any reason (best-effort).
    """
    strip = f"datasets/{REPO_ID}/"
    head = path_in_repo + "/"
    try:
        # find() returns files only; glob("**") would also return directories.
        found = HfFileSystem().find(strip + path_in_repo)
    except Exception:
        return []
    out = []
    for p in found:
        rel = p.removeprefix(strip).lstrip("/")
        if rel.startswith(head):
            out.append(rel)
    return out


def _has_files_under(prefix_bid: str) -> bool:
    """True if repo has any file under prefix_bid/."""
    head = prefix_bid + "/"
    return any(f.startswith(head) for f in _get_repo_files())


def resolve_sample_prefix(bid: str, method: str) -> Optional[str]:
    """
    Resolve which repo prefix contains this sample_id (batch first, then generated).
    method: "bin" | "clap" | "dasheng"
    Returns e.g. "batch_outputs_bin/" or "generated_noises_bin/"; None for an
    unknown method or when neither prefix has files for ``bid``.
    """
    prefixes = _METHOD_PREFIXES.get(method)
    if prefixes is None:
        return None
    for prefix in prefixes:
        if _has_files_under(prefix + bid):
            return prefix
    return None


def get_inner_path(prefix: str, bid: str) -> Optional[str]:
    """
    Return the inner path (contains baseline/, natural_prompts.json, etc.).
    For batch_outputs*: prefix/bid/bid.
    For generated_noises*: prefix/bid/bid if it has files, else prefix/bid/X/X
    taken from the first prompt JSON found at that depth, else prefix/bid/bid.
    """
    inner_std = f"{prefix}{bid}/{bid}"
    if _has_files_under(inner_std) or not prefix.startswith("generated_noises"):
        return inner_std
    # generated_noises: may have prefix/bid/X/X (e.g. cars_honking/heavy_machinery/heavy_machinery)
    sample_head = prefix + bid + "/"
    for f in _get_repo_files():
        if not f.startswith(sample_head):
            continue
        parts = f.split("/")
        # prefix/bid/X/X/<prompt json> -> inner = prefix/bid/X/X. Require depth >= 5
        # so parts[:4] is always a directory, never the JSON file itself.
        if parts[-1] in _PROMPT_FILES and len(parts) >= 5:
            return "/".join(parts[:4])
    return inner_std


def _empty_block() -> dict:
    """Return a media block with every slot unset."""
    return {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}


def _collect_block(file_list: list, folder_prefix: str) -> dict:
    """From files under folder_prefix, get spec + bg_wav, fg_wav, m_wav.

    Files are selected by *substring* match on ``folder_prefix`` and by name
    suffix (.png, _bg.wav, _fg.wav, _m.wav); when several match a slot the
    last one wins. Callers needing exact directory scoping pre-filter
    ``file_list`` with ``_files_in_dir``. Each selected file is downloaded and
    its local path returned; unmatched slots are None.
    """
    picked = _empty_block()
    for f in file_list:
        if folder_prefix not in f:
            continue
        name = f.rsplit("/", 1)[-1]
        if name.endswith(".png"):
            picked["spec"] = f
        elif name.endswith("_bg.wav"):
            picked["bg_wav"] = f
        elif name.endswith("_fg.wav"):
            picked["fg_wav"] = f
        elif name.endswith("_m.wav"):
            picked["m_wav"] = f
    return {slot: _download_file(path) if path else None for slot, path in picked.items()}


def _collect_folder(files: list[str], folder: Optional[str]) -> dict:
    """Collect a media block from files strictly inside ``folder`` (None -> empty block)."""
    if not folder:
        return _empty_block()
    return _collect_block(_files_in_dir(files, folder), folder)


def _subdirs(base: str, files: list[str], name_prefix: str = "") -> list[tuple[str, str]]:
    """Return sorted (name, full_path) for immediate sub-directories of ``base`` seen in ``files``.

    Only entries followed by a further path component (i.e. directories) whose
    name starts with ``name_prefix`` are included.
    """
    head = base + "/"
    names = set()
    for f in files:
        if not f.startswith(head):
            continue
        top, sep, _ = f[len(head):].partition("/")
        if sep and top.startswith(name_prefix):
            names.add(top)
    return [(name, head + name) for name in sorted(names)]


def _baseline_block(baseline_files: list[str], baseline_folders: list[tuple[str, str]], index: int) -> dict:
    """Collect the first (sorted) baseline folder named generated_baseline_<index:02d>_*."""
    wanted = f"generated_baseline_{index:02d}_"
    for name, path in baseline_folders:
        if name.startswith(wanted):
            return _collect_folder(baseline_files, path)
    return _empty_block()


def _match_dasheng_folder(folder_name: str, folders: list[tuple[str, str]]) -> Optional[tuple[str, str]]:
    """Match prompt-derived folder_name to actual folder; allow truncated names and hyphen/underscore.

    Tried in order: exact match (as-is or with '-' -> '_'); the longest folder
    name that is a prefix of folder_name (folders may be truncated); the
    shortest folder name that starts with folder_name.
    """
    if not folder_name or not folders:
        return None
    # Normalize: prompt may have "ground-level" / "low-intensity" but dir is "ground_level" / "low_inte"
    normalized = folder_name.replace("-", "_")
    for fn, fp in folders:
        if fn in (folder_name, normalized):
            return (fn, fp)
    # Folder may be truncated: actual fn is prefix of folder_name (e.g. fn="..._low_inte")
    candidates = [(fn, fp) for fn, fp in folders if normalized.startswith(fn) or folder_name.startswith(fn)]
    if candidates:
        return max(candidates, key=lambda x: len(x[0]))
    # Or folder_name (or normalized) is prefix of fn
    candidates = [(fn, fp) for fn, fp in folders if fn.startswith(normalized) or fn.startswith(folder_name)]
    if candidates:
        return min(candidates, key=lambda x: len(x[0]))
    return None


def _resolve_inner(
    bid: str, root_prefix: Optional[str], method: Optional[str]
) -> tuple[Optional[str], Optional[str]]:
    """Return (prefix, inner) for a sample, or (None, None) when it cannot be resolved.

    With ``method`` set, prefix and inner are looked up in the repo; otherwise
    the legacy layout <root_prefix or ROOT_PREFIX><bid>/<bid> is used unchecked.
    """
    if method is not None:
        prefix = resolve_sample_prefix(bid, method)
        if not prefix:
            return None, None
        inner = get_inner_path(prefix, bid)
        return (prefix, inner) if inner else (None, None)
    prefix = root_prefix if root_prefix is not None else ROOT_PREFIX
    return prefix, f"{prefix}{bid}/{bid}"


def _locate_baseline_dir(prefix: str, bid: str, inner: str, dirname: str) -> tuple[str, list[str]]:
    """Find a sample's ``dirname`` baseline directory and the files under it.

    Lookup order: <inner>/<dirname> in the cached repo listing; then, for
    batch_outputs_bin/ and generated_noises_*, the first <...>/<dirname>/
    under <prefix><bid>/ that holds an _m.wav or _bg.wav; finally, for
    generated_noises_*, a direct HfFileSystem listing (list_repo_files() may
    not include these paths in large repos). Returns (directory, files).
    """
    directory = f"{inner}/{dirname}"
    files = _find_files(directory)
    is_generated = prefix.startswith("generated_noises")
    if not files and (prefix == BIN_BATCH_PREFIX or is_generated):
        marker = f"/{dirname}/"
        sample_head = prefix + bid + "/"
        for f in _get_repo_files():
            if f.startswith(sample_head) and marker in f and f.endswith(("_m.wav", "_bg.wav")):
                # Same rule for every baseline kind: cut at the <dirname> component.
                directory = f.split(marker, 1)[0] + "/" + dirname
                files = _find_files(directory)
                break
    if not files and is_generated:
        files = _list_files_under_via_fs(directory)
    return directory, files


def get_nn_demo_paths(bid: str, top_k: int = 10, root_prefix: Optional[str] = None, method: Optional[str] = None) -> dict:
    """
    For NN view: NN1-NN10 from baseline (generated_baseline_01, 02, ..., 10) in prompt order.
    root_prefix: legacy; if method is set (bin|clap|dasheng), resolve prefix and inner from repo.
    Returns {nn_list: [{spec, bg_wav, fg_wav, m_wav, prompt, similarity}, ...]}.
    """
    _, inner = _resolve_inner(bid, root_prefix, method)
    if not inner:
        return {"nn_list": []}
    prompts = _load_prompts(inner)
    if not prompts:
        return {"nn_list": []}
    baseline_inner = f"{inner}/baseline"
    baseline_files = _find_files(baseline_inner)
    baseline_folders = _subdirs(baseline_inner, baseline_files, "generated_baseline_")
    nn_list = []
    for i, p in enumerate(prompts[:top_k]):
        entry = p if isinstance(p, dict) else {}
        block = _baseline_block(baseline_files, baseline_folders, i + 1)
        block["prompt"] = entry.get("prompt", "")
        block["similarity"] = entry.get("similarity_score", entry.get("retrieval_score"))
        nn_list.append(block)
    return {"nn_list": nn_list}


def get_noise_demo_paths(bid: str) -> dict:
    """
    One block per prompt (1, 2, 3): each has prompt text, baseline (spec + 3 wavs),
    and our method (spec + 3 wavs).
    Returns { "block1": {prompt, baseline: {...}, nn: {...}}, "block2": ..., "block3": ... }.
    """
    inner = f"{ROOT_PREFIX}{bid}/{bid}"
    files = _find_files(inner)
    baseline_inner = f"{inner}/baseline"
    baseline_files = _find_files(baseline_inner)
    baseline_folders = _subdirs(baseline_inner, baseline_files, "generated_baseline_")
    prompts = _load_prompts(inner)
    result = {}
    for i in range(1, 4):
        prompt_text = _prompt_entry(prompts, i - 1).get("prompt", "")
        baseline_block = _baseline_block(baseline_files, baseline_folders, i)
        ours_head = f"{inner}/generated_{i:02d}_"
        nn_block = _collect_block([f for f in files if f.startswith(ours_head)], ours_head)
        nn_block["prompt"] = prompt_text
        result[f"block{i}"] = {
            "prompt": prompt_text,
            "baseline": baseline_block,
            "nn": nn_block,
        }
    return result


def get_results_demo_paths(bid: str, root_prefix: Optional[str] = None, method: Optional[str] = None) -> dict:
    """
    For Results view: 3 blocks (prompts 1-3), each with 4 columns: Baseline (original),
    Gaussian, Youtube-noise, Ours.
    root_prefix: legacy; if method is set (bin|clap|dasheng), resolve prefix and inner from repo.
    Returns {} when the sample cannot be resolved, else
    {"blockN": {prompt, baseline_original, baseline_gaussian, baseline_youtube, ours}}.
    """
    prefix, inner = _resolve_inner(bid, root_prefix, method)
    if not inner:
        return {}
    if method is not None:
        # Dasheng-style: prompt-named folders (batch_outputs_bin, batch_outputs_dasheng, all generated_noises_*)
        use_dasheng = prefix in (BIN_BATCH_PREFIX, DASHENG_BATCH_PREFIX) or prefix.startswith("generated_noises")
    else:
        use_dasheng = root_prefix == DASHENG_PREFIX

    files = _find_files(inner)
    baseline_inner = f"{inner}/baseline"
    baseline_files = _find_files(baseline_inner)
    gaussian_inner, gaussian_files = _locate_baseline_dir(prefix, bid, inner, "gaussian_baseline")
    youtube_inner, youtube_files = _locate_baseline_dir(prefix, bid, inner, "youtube_noise_baseline")
    prompts = _load_prompts(inner)

    baseline_folders = _subdirs(baseline_inner, baseline_files, "generated_baseline_")
    # Dasheng: youtube subdirs are prompt names (underscores); otherwise generated_NN_*.
    youtube_folders = _subdirs(youtube_inner, youtube_files, "" if use_dasheng else "generated_")
    # Gaussian baseline is not per prompt: collect (and download) it once, copy per block.
    gaussian_block = _collect_block(gaussian_files, gaussian_inner)
    # Dasheng "ours": prompt-named top-level dirs of inner, excluding the baseline dirs.
    ours_folders = [
        (name, path)
        for name, path in (_subdirs(inner, files) if use_dasheng else [])
        if name not in _BASELINE_DIRS and not name.startswith("generated_baseline")
    ]

    result = {}
    for i in range(1, 4):
        prompt_text = _prompt_entry(prompts, i - 1).get("prompt", "")
        bl_orig = _baseline_block(baseline_files, baseline_folders, i)
        if use_dasheng:
            folder_name = prompt_text.replace(" ", "_") if prompt_text else ""
            yt_match = _match_dasheng_folder(folder_name, youtube_folders)
            bl_youtube = _collect_folder(youtube_files, yt_match[1] if yt_match else None)
            ours_match = _match_dasheng_folder(folder_name, ours_folders)
            ours_block = _collect_folder(files, ours_match[1] if ours_match else None)
        else:
            rel_prefix = f"generated_{i:02d}_"
            yt_path = next((fp for fn, fp in youtube_folders if fn.startswith(rel_prefix)), None)
            bl_youtube = _collect_folder(youtube_files, yt_path)
            ours_head = f"{inner}/{rel_prefix}"
            ours_block = _collect_block([f for f in files if f.startswith(ours_head)], ours_head)
        result[f"block{i}"] = {
            "prompt": prompt_text,
            "baseline_original": bl_orig,
            "baseline_gaussian": dict(gaussian_block),
            "baseline_youtube": bl_youtube,
            "ours": ours_block,
        }
    return result