# drum-sample-extractor / config_store.py
# (rikhoffbauer2) v2: Update config_store.py — commit 3157241
"""
Config persistence: save/load optimized pipeline configs to HuggingFace Hub.
Each config is a JSON file stored in a dataset repo with:
- Pipeline parameters
- Evaluation scores (overall, SI-SDR, sample quality, etc.)
- Metadata (author, timestamp, description)
The leaderboard ranks configs by overall_score (highest first); scores are computed across a set of test songs.
"""
import json, time, io, os
from dataclasses import dataclass, asdict
from typing import Optional
# HF Hub repository holding one JSON file per saved config, at configs/<name>.json.
CONFIGS_REPO = "rikhoffbauer2/sample-extractor-configs"
# Repo type passed to every Hub call in this module (create/upload/list/download).
CONFIGS_REPO_TYPE = "dataset"
@dataclass
class PipelineConfig:
    """A pipeline configuration together with the evaluation scores recorded for it.

    Serializes to and from a flat JSON object whose keys are exactly the field
    names declared below.
    """

    # --- Identity / metadata ---
    name: str = "default"
    description: str = ""
    author: str = ""
    timestamp: str = ""  # empty by default; save_config stamps it on upload
    version: str = "1.0"

    # --- Pipeline parameters ---
    stem: str = "drums"
    onset_mode: str = "auto"
    pre_pad: float = 0.005
    min_dur: float = 0.02
    max_dur: float = 1.5
    min_gap: float = 0.015
    energy_threshold_db: float = -45.0
    separate_overlaps: bool = True
    overlap_threshold: float = 0.15
    synthesize: bool = True

    # --- Evaluation scores (populated after evaluation) ---
    overall_score: float = 0.0  # leaderboard sort key
    mean_si_sdr: float = 0.0
    mean_sample_score: float = 0.0
    mean_env_corr: float = 0.0
    mean_onset_error_ms: float = 50.0
    n_test_songs: int = 0

    def to_dict(self) -> dict:
        """Return every field as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> 'PipelineConfig':
        """Build a config from *d*, ignoring any keys that are not fields.

        Fields absent from *d* keep their declared defaults.
        """
        known = set(cls.__dataclass_fields__)
        kwargs = {}
        for key, value in d.items():
            if key in known:
                kwargs[key] = value
        return cls(**kwargs)

    def to_json(self) -> str:
        """Serialize the config as JSON indented by two spaces."""
        payload = self.to_dict()
        return json.dumps(payload, indent=2)
def save_config(config: PipelineConfig, token: Optional[str] = None) -> str:
    """Upload *config* to the HF dataset repo and return its path inside the repo.

    The config is written as JSON to ``configs/<config.name>.json`` in
    CONFIGS_REPO. Saving under an existing name uploads to the same path.

    Args:
        config: The config to persist. Its ``timestamp`` field is overwritten
            in place with the current UTC time before upload.
        token: Optional HF access token, passed straight to ``HfApi``.

    Returns:
        The file path within the repo (``configs/<name>.json``).
    """
    from huggingface_hub import HfApi

    api = HfApi(token=token)

    # Ensure the repo exists. This is deliberately best-effort: presumably the
    # token may be allowed to upload but not to create repos. If the repo is
    # truly missing, upload_file below surfaces its own error.
    try:
        api.create_repo(CONFIGS_REPO, repo_type=CONFIGS_REPO_TYPE, exist_ok=True)
    except Exception:
        pass

    # The trailing 'Z' declares UTC, so format gmtime(). strftime() without a
    # time tuple would format *local* time and mislabel it as UTC.
    config.timestamp = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())

    path = f"configs/{config.name}.json"
    api.upload_file(
        path_or_fileobj=io.BytesIO(config.to_json().encode()),  # UTF-8 bytes
        path_in_repo=path,
        repo_id=CONFIGS_REPO,
        repo_type=CONFIGS_REPO_TYPE,
        commit_message=f"Config: {config.name} (score={config.overall_score:.1f})",
    )
    print(f" ✓ Saved config '{config.name}' → {CONFIGS_REPO}/{path}")
    return path
def load_config(name: str, token: Optional[str] = None) -> PipelineConfig:
    """Download and parse the config stored at ``configs/<name>.json``.

    Args:
        name: Config name, i.e. the file stem that ``save_config`` used.
        token: Optional HF access token, passed to ``hf_hub_download``.

    Returns:
        The parsed PipelineConfig. JSON keys that are not config fields are
        ignored, and missing fields take their defaults.

    Raises:
        Any error from ``hf_hub_download`` (e.g. missing repo or file) or from
        JSON parsing; nothing is caught here.
    """
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(
        repo_id=CONFIGS_REPO,
        filename=f"configs/{name}.json",
        repo_type=CONFIGS_REPO_TYPE,
        token=token,
    )
    # save_config uploads UTF-8 bytes. Decode explicitly as UTF-8 instead of
    # relying on the platform's locale default, which breaks non-ASCII text.
    with open(path, encoding="utf-8") as f:
        return PipelineConfig.from_dict(json.load(f))
def list_configs(token: Optional[str] = None) -> list:
    """Return every config stored under ``configs/`` in the repo, best score first.

    This is best-effort. If the repo's file listing fails, an empty list is
    returned. Individual files that cannot be downloaded or parsed, or that
    carry a non-numeric ``overall_score``, are skipped.

    Args:
        token: Optional HF access token, used for listing and downloading.

    Returns:
        A list of PipelineConfig objects sorted by ``overall_score``, descending.
    """
    from huggingface_hub import HfApi, hf_hub_download

    api = HfApi(token=token)
    try:
        files = list(api.list_repo_files(CONFIGS_REPO, repo_type=CONFIGS_REPO_TYPE))
    except Exception:
        # Listing failed (e.g. repo not created yet): report "no configs".
        return []

    configs = []
    for repo_file in files:
        if not (repo_file.startswith("configs/") and repo_file.endswith(".json")):
            continue
        try:
            path = hf_hub_download(repo_id=CONFIGS_REPO, filename=repo_file,
                                   repo_type=CONFIGS_REPO_TYPE, token=token)
            # Configs are uploaded as UTF-8 bytes; don't depend on the locale.
            with open(path, encoding="utf-8") as fh:
                cfg = PipelineConfig.from_dict(json.load(fh))
        except Exception:
            # Skip unreadable/malformed files so one bad upload can't hide the rest.
            continue
        # A non-numeric score (e.g. JSON null) would make the sort below raise
        # TypeError and take down the whole listing; treat it as malformed.
        if not isinstance(cfg.overall_score, (int, float)):
            continue
        configs.append(cfg)

    # Highest overall score first.
    configs.sort(key=lambda c: c.overall_score, reverse=True)
    return configs
def get_leaderboard(token: str = None) -> list:
    """Build one display row (a dict) per saved config for a leaderboard table.

    Rows follow the order returned by ``list_configs`` (highest overall_score
    first). Score columns are pre-formatted strings, and 'Date' is the first
    10 characters of the timestamp, or '' when no timestamp is set.
    """
    rows = []
    for rank, cfg in enumerate(list_configs(token), start=1):
        rows.append({
            'Rank': rank,
            'Name': cfg.name,
            'Score': f"{cfg.overall_score:.1f}",
            'SI-SDR': f"{cfg.mean_si_sdr:.1f}",
            'Sample Q': f"{cfg.mean_sample_score:.1f}",
            'Env Corr': f"{cfg.mean_env_corr:.3f}",
            'Onset (ms)': f"{cfg.mean_onset_error_ms:.1f}",
            'Tests': cfg.n_test_songs,
            'Author': cfg.author,
            'Date': cfg.timestamp[:10] if cfg.timestamp else '',
        })
    return rows