# NearestNeighbor / dataset_loader.py
# (web-page scrape residue removed from code path: "AE-W's picture",
#  "Upload folder using huggingface_hub", "9d63714 verified")
"""
Load data from Hugging Face dataset AE-W/batch_outputs.
Uses huggingface_hub to list and download files on demand.
"""
import json
import os
import re
from pathlib import Path
from typing import Optional
from huggingface_hub import HfApi, hf_hub_download, list_repo_files
from huggingface_hub import HfFileSystem
# Hugging Face dataset repo that holds all demo assets.
REPO_ID = "AE-W/batch_outputs"
REPO_TYPE = "dataset"
# Top-level folder prefixes inside the repo (trailing slash included).
ROOT_PREFIX = "batch_outputs/"
DASHENG_PREFIX = "batch_outputs_dasheng/"
# Three methods: each has batch_outputs_* + generated_noises_*
BIN_BATCH_PREFIX = "batch_outputs_bin/"
BIN_GENERATED_PREFIX = "generated_noises_bin/"
CLAP_BATCH_PREFIX = "batch_outputs/" # ROOT_PREFIX
CLAP_GENERATED_PREFIX = "generated_noises_clap/"
DASHENG_BATCH_PREFIX = "batch_outputs_dasheng/"
DASHENG_GENERATED_PREFIX = "generated_noises_dasheng/"
# Directory names under generated_noises_* that are not real sample IDs.
GENERATED_SKIP_IDS = {"__pycache__", "NearestNeighbor_space_push"}
# Cache full repo file list so we only call list_repo_files once per process (major speedup)
_cached_repo_files: Optional[list[str]] = None
def _get_repo_files() -> list[str]:
    """Return the full list of repo file paths, fetched once and memoized for the process."""
    global _cached_repo_files
    if _cached_repo_files is not None:
        return _cached_repo_files
    _cached_repo_files = list_repo_files(REPO_ID, repo_type=REPO_TYPE)
    return _cached_repo_files
def _get_sample_ids(prefix: str = ROOT_PREFIX) -> list[str]:
    """Collect the distinct sample IDs (e.g. 07_003277) sitting directly under prefix."""
    pattern = re.compile(re.escape(prefix.rstrip("/")) + r"/([^/]+)/")
    ids = {m.group(1) for path in _get_repo_files() if (m := pattern.match(path))}
    return sorted(ids)
def _get_all_sample_ids() -> list[str]:
    """Merged, sorted sample IDs from batch_outputs and batch_outputs_dasheng."""
    combined = set(_get_sample_ids(ROOT_PREFIX))
    combined.update(_get_sample_ids(DASHENG_PREFIX))
    return sorted(combined)
def _download_file(path_in_repo: str, local_dir: Optional[str] = None) -> str:
    """Fetch one file from the dataset repo; return the local (cache) path."""
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        repo_type=REPO_TYPE,
        filename=path_in_repo,
        local_dir=local_dir,
        force_download=False,
    )
    return local_path
def _load_json_from_repo(path_in_repo: str) -> Optional[list]:
    """Download and parse a JSON file from the repo; None on any failure (best-effort)."""
    try:
        local_path = _download_file(path_in_repo)
    except Exception:
        return None
    try:
        with open(local_path, "r", encoding="utf-8") as handle:
            return json.load(handle)
    except Exception:
        return None
def list_samples() -> list[str]:
    """All sample IDs (bid) across batch_outputs and batch_outputs_dasheng."""
    return _get_all_sample_ids()
def list_samples_bin() -> list[str]:
    """Sample IDs for Bin: batch_outputs_bin ∪ generated_noises_bin (exclude __pycache__ etc)."""
    combined = set(_get_sample_ids(BIN_BATCH_PREFIX))
    combined |= set(_get_sample_ids(BIN_GENERATED_PREFIX)) - GENERATED_SKIP_IDS
    return sorted(combined)
def list_samples_clap() -> list[str]:
    """Sample IDs for Clap: batch_outputs ∪ generated_noises_clap."""
    combined = set(_get_sample_ids(CLAP_BATCH_PREFIX))
    combined |= set(_get_sample_ids(CLAP_GENERATED_PREFIX)) - GENERATED_SKIP_IDS
    return sorted(combined)
def list_samples_dasheng() -> list[str]:
    """Sample IDs for Dasheng: batch_outputs_dasheng (excl. fold*) ∪ generated_noises_dasheng."""
    combined = set()
    for sample_id in _get_sample_ids(DASHENG_BATCH_PREFIX):
        # fold* dirs are cross-validation artifacts, not samples.
        if not sample_id.startswith("fold"):
            combined.add(sample_id)
    combined |= set(_get_sample_ids(DASHENG_GENERATED_PREFIX)) - GENERATED_SKIP_IDS
    return sorted(combined)
def _find_files(inner: str) -> list[str]:
    """All cached repo file paths located under inner/ (uses cached repo file list)."""
    wanted = inner + "/"
    return [path for path in _get_repo_files() if path.startswith(wanted)]
def _list_files_under_via_fs(path_in_repo: str) -> list[str]:
    """List files under path_in_repo using HfFileSystem (avoids relying on full list_repo_files).

    Returns repo-relative paths; [] on any error (this is a best-effort fallback).
    """
    try:
        fs = HfFileSystem()
        repo_root = f"datasets/{REPO_ID}/"
        found = fs.glob(repo_root + path_in_repo + "/**")
        out = []
        for p in found:
            if p.endswith("/"):
                continue
            # Strip only the leading "datasets/<repo>/" part. The previous
            # str.replace() removed *every* occurrence of that substring,
            # which would corrupt a path that happened to repeat it deeper in.
            rel = p.removeprefix(repo_root).lstrip("/")
            if rel.startswith(path_in_repo + "/"):
                out.append(rel)
        return out
    except Exception:
        return []
def _has_files_under(prefix_bid: str) -> bool:
    """True if any repo file lives under prefix_bid/."""
    needle = prefix_bid + "/"
    return any(path.startswith(needle) for path in _get_repo_files())
def resolve_sample_prefix(bid: str, method: str) -> Optional[str]:
    """
    Resolve which repo prefix contains this sample_id (batch first, then generated).
    method: "bin" | "clap" | "dasheng"
    Returns e.g. "batch_outputs_bin/" or "generated_noises_bin/".
    """
    prefix_pairs = {
        "bin": (BIN_BATCH_PREFIX, BIN_GENERATED_PREFIX),
        "clap": (CLAP_BATCH_PREFIX, CLAP_GENERATED_PREFIX),
        "dasheng": (DASHENG_BATCH_PREFIX, DASHENG_GENERATED_PREFIX),
    }
    if method not in prefix_pairs:
        return None
    # Batch prefix wins over generated when both contain the sample.
    for candidate in prefix_pairs[method]:
        if _has_files_under(candidate + bid):
            return candidate
    return None
def get_inner_path(prefix: str, bid: str) -> Optional[str]:
    """
    Return the inner path (contains baseline/, natural_prompts.json, etc.).
    For batch_outputs*: prefix/bid/bid. For generated_noises*: same, or prefix/bid/X/X
    when the marker JSONs live one level deeper (e.g. cars_honking/heavy_machinery/heavy_machinery).
    """
    standard = f"{prefix}{bid}/{bid}"
    if _has_files_under(standard):
        return standard
    if not prefix.startswith("generated_noises"):
        return standard
    # generated_noises may nest as prefix/bid/X/X; locate via a marker JSON file.
    bid_root = prefix + bid + "/"
    for path in _get_repo_files():
        if not path.startswith(bid_root):
            continue
        if "natural_prompts.json" not in path and "temp_retrieval.json" not in path:
            continue
        segments = path.split("/")
        # prefix/bid/X/X/file -> inner = prefix/bid/X/X
        if len(segments) >= 4:
            return "/".join(segments[:4])
    return standard
def _collect_block(file_list: list, folder_prefix: str) -> dict:
    """From files under folder_prefix, pick the spectrogram png plus bg/fg/m wavs and download them."""
    found = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
    # Suffix -> result key, checked in the same order as the original elif chain.
    suffix_keys = (
        (".png", "spec"),
        ("_bg.wav", "bg_wav"),
        ("_fg.wav", "fg_wav"),
        ("_m.wav", "m_wav"),
    )
    for path in file_list:
        if folder_prefix not in path:
            continue
        basename = path.rsplit("/", 1)[-1]
        for suffix, key in suffix_keys:
            if basename.endswith(suffix):
                found[key] = path
                break
    return {key: (_download_file(val) if val else None) for key, val in found.items()}
def get_nn_demo_paths(bid: str, top_k: int = 10, root_prefix: Optional[str] = None, method: Optional[str] = None) -> dict:
    """
    For NN view: NN1-NN10 from baseline (generated_baseline_01, 02, ..., 10) in prompt order.
    root_prefix: legacy; if method is set (bin|clap|dasheng), resolve prefix and inner from repo.
    Returns {nn_list: [{spec, bg_wav, fg_wav, m_wav, prompt, similarity}, ...]}.
    """
    if method is not None:
        prefix = resolve_sample_prefix(bid, method)
        if not prefix:
            return {"nn_list": []}
        inner = get_inner_path(prefix, bid)
        if not inner:
            return {"nn_list": []}
    else:
        prefix = root_prefix if root_prefix is not None else ROOT_PREFIX
        inner = f"{prefix}{bid}/{bid}"
    # temp_retrieval.json preferred; natural_prompts.json is the fallback.
    prompts = _load_json_from_repo(f"{inner}/temp_retrieval.json")
    if not prompts:
        prompts = _load_json_from_repo(f"{inner}/natural_prompts.json")
    if not prompts:
        return {"nn_list": []}
    files = _find_files(inner)
    baseline_inner = f"{inner}/baseline"
    # Check "baseline/" (with slash) so a sibling dir such as "baseline_old" does not
    # count as the baseline folder — matches the guard used in get_results_demo_paths.
    baseline_files = _find_files(baseline_inner) if any(f.startswith(baseline_inner + "/") for f in files) else []
    nn_list = []
    for i, p in enumerate(prompts[:top_k]):
        prompt = p.get("prompt", "")
        sim = p.get("similarity_score", p.get("retrieval_score"))
        bl_prefix = f"generated_baseline_{i+1:02d}_"
        block = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
        # Find the first baseline subfolder matching generated_baseline_<i+1>_*.
        for f in baseline_files:
            parts = f.replace(baseline_inner + "/", "").split("/")
            if parts and parts[0].startswith(bl_prefix):
                full_prefix = baseline_inner + "/" + parts[0]
                block = _collect_block(baseline_files, full_prefix)
                break
        block["prompt"] = prompt
        block["similarity"] = sim
        nn_list.append(block)
    return {"nn_list": nn_list}
def get_noise_demo_paths(bid: str) -> dict:
    """
    One block per prompt (1, 2, 3): each has prompt text, baseline (spec + 3 wavs), and our method (spec + 3 wavs).
    Returns { "block1": {prompt, baseline: {...}, nn: {...}}, "block2": ..., "block3": ... }.
    """
    inner = f"{ROOT_PREFIX}{bid}/{bid}"
    files = _find_files(inner)
    baseline_inner = f"{inner}/baseline"
    # Check "baseline/" (with slash) so a sibling dir such as "baseline_old" does not
    # count as the baseline folder — matches the guard used in get_results_demo_paths.
    baseline_files = _find_files(baseline_inner) if any(f.startswith(baseline_inner + "/") for f in files) else []
    # temp_retrieval.json preferred; natural_prompts.json is the fallback.
    prompts = _load_json_from_repo(f"{inner}/temp_retrieval.json")
    if not prompts:
        prompts = _load_json_from_repo(f"{inner}/natural_prompts.json")
    if not prompts:
        prompts = []
    # Find baseline folder names generated_baseline_01_*, 02_*, 03_*
    seen = set()
    baseline_folders = []
    for f in baseline_files:
        parts = f.replace(baseline_inner + "/", "").split("/")
        if parts and parts[0].startswith("generated_baseline_") and parts[0] not in seen:
            seen.add(parts[0])
            baseline_folders.append((parts[0], baseline_inner + "/" + parts[0]))
    baseline_folders.sort(key=lambda x: x[0])
    result = {}
    for i in range(1, 4):
        prompt_text = prompts[i - 1].get("prompt", "") if i <= len(prompts) else ""
        bl_prefix = f"generated_baseline_{i:02d}_"
        baseline_block = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
        for folder_name, full_prefix in baseline_folders:
            if folder_name.startswith(bl_prefix):
                baseline_block = _collect_block(baseline_files, full_prefix)
                break
        # Our method lives directly under inner as generated_<i>_* files/folders.
        rel_prefix = f"generated_{i:02d}_"
        nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
        nn_block = _collect_block(nn_files, rel_prefix)
        nn_block["prompt"] = prompt_text
        result[f"block{i}"] = {
            "prompt": prompt_text,
            "baseline": baseline_block,
            "nn": nn_block,
        }
    return result
def get_results_demo_paths(bid: str, root_prefix: Optional[str] = None, method: Optional[str] = None) -> dict:
    """
    For Results view: 3 blocks (prompts 1-3), each with 4 columns:
    Baseline (original), Gaussian, Youtube-noise, Ours.
    root_prefix: legacy; if method is set (bin|clap|dasheng), resolve prefix and inner from repo.

    Returns {"blockN": {prompt, baseline_original, baseline_gaussian, baseline_youtube, ours}}
    where each column is a dict of local paths {spec, bg_wav, fg_wav, m_wav} (None when missing).
    Returns {} when the sample cannot be resolved at all.
    """
    if method is not None:
        prefix = resolve_sample_prefix(bid, method)
        if not prefix:
            return {}
        inner = get_inner_path(prefix, bid)
        if not inner:
            return {}
        # Dasheng-style: prompt-named folders (batch_outputs_bin, batch_outputs_dasheng, all generated_noises_*)
        use_dasheng = prefix in (BIN_BATCH_PREFIX, DASHENG_BATCH_PREFIX, DASHENG_GENERATED_PREFIX) or prefix.startswith("generated_noises")
    else:
        # Legacy path: caller supplies the prefix (or we default to batch_outputs/).
        prefix = root_prefix if root_prefix is not None else ROOT_PREFIX
        inner = f"{prefix}{bid}/{bid}"
        use_dasheng = root_prefix == DASHENG_PREFIX
    files = _find_files(inner)
    baseline_inner = f"{inner}/baseline"
    gaussian_inner = f"{inner}/gaussian_baseline"
    youtube_inner = f"{inner}/youtube_noise_baseline"
    # Use full repo file list for baseline/gaussian/youtube so we find them even if "files" is partial
    all_repo = _get_repo_files()
    baseline_files = _find_files(baseline_inner) if any(f.startswith(baseline_inner + "/") for f in all_repo) else []
    gaussian_files = _find_files(gaussian_inner) if any(f.startswith(gaussian_inner + "/") for f in all_repo) else []
    youtube_files = _find_files(youtube_inner) if any(f.startswith(youtube_inner + "/") for f in all_repo) else []
    # Fallback for bin/generated: gaussian or youtube may live under a different inner (e.g. prefix/bid/X/X)
    if not gaussian_files and (prefix == BIN_BATCH_PREFIX or prefix.startswith("generated_noises")):
        # A wav file under any .../gaussian_baseline/ pins down the real directory.
        for f in all_repo:
            if f.startswith(prefix + bid + "/") and "/gaussian_baseline/" in f and (f.endswith("_m.wav") or f.endswith("_bg.wav")):
                gaussian_inner_fb = f.rsplit("/", 1)[0] # path to gaussian_baseline dir
                gaussian_files = _find_files(gaussian_inner_fb)
                gaussian_inner = gaussian_inner_fb
                break
    if not youtube_files and (prefix == BIN_BATCH_PREFIX or prefix.startswith("generated_noises")):
        for f in all_repo:
            if f.startswith(prefix + bid + "/") and "/youtube_noise_baseline/" in f and (f.endswith("_m.wav") or f.endswith("_bg.wav")):
                # f = .../youtube_noise_baseline/<prompt_folder>/file_m.wav -> parent = youtube_noise_baseline
                youtube_inner_fb = f.split("/youtube_noise_baseline/", 1)[0] + "/youtube_noise_baseline"
                youtube_files = _find_files(youtube_inner_fb)
                youtube_inner = youtube_inner_fb
                break
    # For generated_noises: list_repo_files() may not include these paths in large repos; use HfFileSystem by path
    if prefix.startswith("generated_noises"):
        if not gaussian_files:
            gaussian_files = _list_files_under_via_fs(gaussian_inner)
        if not youtube_files:
            youtube_files = _list_files_under_via_fs(youtube_inner)
    # Prompt metadata: temp_retrieval.json preferred, natural_prompts.json fallback.
    prompts = _load_json_from_repo(f"{inner}/temp_retrieval.json")
    if not prompts:
        prompts = _load_json_from_repo(f"{inner}/natural_prompts.json")
    if not prompts:
        prompts = []
    def get_baseline_folders(bl_inner, bl_files):
        """Distinct generated_baseline_* subfolders as sorted (name, full_path) pairs."""
        seen = set()
        folders = []
        for f in bl_files:
            parts = f.replace(bl_inner + "/", "").split("/")
            if parts and parts[0].startswith("generated_baseline_") and parts[0] not in seen:
                seen.add(parts[0])
                folders.append((parts[0], bl_inner + "/" + parts[0]))
        folders.sort(key=lambda x: x[0])
        return folders
    def get_youtube_folders():
        """Youtube-noise subfolders as sorted (name, full_path) pairs; layout depends on use_dasheng."""
        if use_dasheng:
            # Dasheng: subdirs are prompt names (underscores)
            seen = set()
            folders = []
            for f in youtube_files:
                parts = f.replace(youtube_inner + "/", "").split("/")
                if parts and parts[0] not in seen:
                    seen.add(parts[0])
                    folders.append((parts[0], youtube_inner + "/" + parts[0]))
            folders.sort(key=lambda x: x[0])
            return folders
        # Non-dasheng layout: subdirs are named generated_*.
        seen = set()
        folders = []
        for f in youtube_files:
            parts = f.replace(youtube_inner + "/", "").split("/")
            if parts and parts[0].startswith("generated_") and parts[0] not in seen:
                seen.add(parts[0])
                folders.append((parts[0], youtube_inner + "/" + parts[0]))
        folders.sort(key=lambda x: x[0])
        return folders
    def _match_dasheng_folder(folder_name: str, folders: list[tuple[str, str]]) -> Optional[tuple[str, str]]:
        """Match prompt-derived folder_name to actual folder; allow truncated names and hyphen/underscore."""
        if not folder_name or not folders:
            return None
        # Normalize: prompt may have "ground-level" / "low-intensity" but dir is "ground_level" / "low_inte"
        normalized = folder_name.replace("-", "_")
        # Exact match
        for fn, fp in folders:
            if fn == folder_name or fn == normalized:
                return (fn, fp)
        # Folder may be truncated: actual fn is prefix of folder_name (e.g. fn="..._low_inte", folder_name="..._low_intensity_...")
        candidates = [(fn, fp) for fn, fp in folders if normalized.startswith(fn) or folder_name.startswith(fn)]
        if candidates:
            # Longest folder name = most specific truncation match.
            return max(candidates, key=lambda x: len(x[0]))
        # Or folder_name (or normalized) is prefix of fn
        candidates = [(fn, fp) for fn, fp in folders if fn.startswith(normalized) or fn.startswith(folder_name)]
        if candidates:
            # Shortest folder name = closest extension of the prompt name.
            return min(candidates, key=lambda x: len(x[0]))
        return None
    baseline_folders = get_baseline_folders(baseline_inner, baseline_files)
    youtube_folders = get_youtube_folders()
    result = {}
    # One block per prompt index 1..3.
    for i in range(1, 4):
        prompt_text = prompts[i - 1].get("prompt", "") if i <= len(prompts) else ""
        bl_prefix = f"generated_baseline_{i:02d}_"
        rel_prefix = f"generated_{i:02d}_"
        # Column 1: original baseline — first folder matching generated_baseline_<i>_*.
        bl_orig = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
        for fn, fp in baseline_folders:
            if fn.startswith(bl_prefix):
                bl_orig = _collect_block(baseline_files, fp)
                break
        # Column 2: gaussian baseline — shared across prompts (one folder, not per-prompt).
        gaussian_block = _collect_block(gaussian_files, gaussian_inner)
        # Column 3: youtube-noise baseline.
        bl_youtube = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
        if use_dasheng:
            # Dasheng layout: folder is derived from the prompt text (fuzzy-matched).
            folder_name = prompt_text.replace(" ", "_") if prompt_text else ""
            matched = _match_dasheng_folder(folder_name, youtube_folders)
            if matched:
                fn, fp = matched
                bl_youtube = _collect_block(youtube_files, fp)
        else:
            for fn, fp in youtube_folders:
                if fn.startswith(rel_prefix):
                    bl_youtube = _collect_block(youtube_files, fp)
                    break
        # Column 4: ours.
        if use_dasheng:
            folder_name = prompt_text.replace(" ", "_") if prompt_text else ""
            # Ours: list prompt-named dirs under inner (exclude baseline, gaussian_baseline, youtube_noise_baseline)
            skip = {"baseline", "youtube_noise_baseline", "gaussian_baseline"}
            inner_dirs = set()
            for f in files:
                if not f.startswith(inner + "/"):
                    continue
                rest = f.replace(inner + "/", "", 1)
                if "/" in rest:
                    top = rest.split("/")[0]
                    if top not in skip and not top.startswith("generated_baseline"):
                        inner_dirs.add(top)
            inner_folders = [(d, inner + "/" + d) for d in sorted(inner_dirs)]
            ours_fn_fp = _match_dasheng_folder(folder_name, inner_folders)
            if ours_fn_fp:
                fn, fp = ours_fn_fp
                nn_files = [f for f in files if f.startswith(fp + "/")]
                ours_block = _collect_block(nn_files, fp)
            else:
                ours_block = {"spec": None, "bg_wav": None, "fg_wav": None, "m_wav": None}
        else:
            # Non-dasheng: ours lives directly under inner as generated_<i>_*.
            nn_files = [f for f in files if f.replace(inner + "/", "").startswith(rel_prefix)]
            ours_block = _collect_block(nn_files, inner + "/" + rel_prefix)
        result[f"block{i}"] = {
            "prompt": prompt_text,
            "baseline_original": bl_orig,
            "baseline_gaussian": gaussian_block,
            "baseline_youtube": bl_youtube,
            "ours": ours_block,
        }
    return result