Spaces:
Sleeping
Sleeping
| """ | |
| Config persistence: save/load optimized pipeline configs to HuggingFace Hub. | |
| Each config is a JSON file stored in a dataset repo with: | |
| - Pipeline parameters | |
| - Evaluation scores (overall, SI-SDR, sample quality, etc.) | |
| - Metadata (author, timestamp, description) | |
| The leaderboard ranks configs by their overall evaluation score across test songs (highest first) and also shows the per-metric mean scores. | |
| """ | |
| import json, time, io, os | |
| from dataclasses import dataclass, asdict | |
| from typing import Optional | |
# Hugging Face Hub repo that holds the saved pipeline configs; each config is
# stored as its own JSON file under ``configs/``. The repo is a dataset repo.
CONFIGS_REPO: str = "rikhoffbauer2/sample-extractor-configs"
CONFIGS_REPO_TYPE: str = "dataset"
@dataclass
class PipelineConfig:
    """Complete pipeline configuration with eval scores.

    Fields fall into three groups: identity/metadata, pipeline parameters,
    and evaluation scores (left at their defaults until the config has been
    evaluated). Instances round-trip through plain dicts and JSON via
    ``to_dict``/``from_dict`` and ``to_json``/``from_json``.
    """

    # Identity
    name: str = "default"
    description: str = ""
    author: str = ""
    timestamp: str = ""
    version: str = "1.0"

    # Pipeline params
    stem: str = "drums"
    onset_mode: str = "auto"
    pre_pad: float = 0.005
    min_dur: float = 0.02
    max_dur: float = 1.5
    min_gap: float = 0.015
    energy_threshold_db: float = -45.0
    separate_overlaps: bool = True
    overlap_threshold: float = 0.15
    synthesize: bool = True

    # Eval scores (filled after evaluation)
    overall_score: float = 0.0
    mean_si_sdr: float = 0.0
    mean_sample_score: float = 0.0
    mean_env_corr: float = 0.0
    mean_onset_error_ms: float = 50.0
    n_test_songs: int = 0

    def to_dict(self) -> dict:
        """Return every field as a plain dict (suitable for ``json.dumps``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> 'PipelineConfig':
        """Build a config from *d*, ignoring keys that are not fields.

        Unknown keys are dropped and missing keys fall back to the field
        defaults, so JSON written with a different set of fields still loads.
        """
        valid = cls.__dataclass_fields__.keys()
        return cls(**{k: v for k, v in d.items() if k in valid})

    @classmethod
    def from_json(cls, s: str) -> 'PipelineConfig':
        """Inverse of ``to_json``: parse *s* and build a config via ``from_dict``."""
        return cls.from_dict(json.loads(s))

    def to_json(self) -> str:
        """Serialize all fields as JSON indented by two spaces."""
        return json.dumps(self.to_dict(), indent=2)
def save_config(config: PipelineConfig, token: Optional[str] = None) -> str:
    """Save a config to the HF dataset repo. Returns the file path in repo.

    Side effects: tries to create the dataset repo (``exist_ok=True``), sets
    ``config.timestamp`` in place to the current UTC time, and uploads the
    config's JSON as ``configs/<config.name>.json``.

    Args:
        config: The config to upload; its ``timestamp`` is overwritten.
        token: Optional Hugging Face token passed to ``HfApi``.

    Returns:
        The path of the uploaded file inside ``CONFIGS_REPO``.
    """
    from huggingface_hub import HfApi
    api = HfApi(token=token)

    # Ensure repo exists. Deliberately best-effort: a failure here is reported
    # but not fatal, because the upload below is the operation that must succeed.
    try:
        api.create_repo(CONFIGS_REPO, repo_type=CONFIGS_REPO_TYPE, exist_ok=True)
    except Exception as exc:
        print(f"  ⚠ Could not ensure repo {CONFIGS_REPO} exists: {exc}")

    # The trailing 'Z' marks the time as UTC, so format UTC time, not local time.
    config.timestamp = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
    path = f"configs/{config.name}.json"
    api.upload_file(
        path_or_fileobj=io.BytesIO(config.to_json().encode()),
        path_in_repo=path,
        repo_id=CONFIGS_REPO,
        repo_type=CONFIGS_REPO_TYPE,
        commit_message=f"Config: {config.name} (score={config.overall_score:.1f})",
    )
    print(f"  ✓ Saved config '{config.name}' → {CONFIGS_REPO}/{path}")
    return path
def load_config(name: str, token: Optional[str] = None) -> PipelineConfig:
    """Load a config by name from the HF dataset repo.

    Downloads ``configs/<name>.json`` from ``CONFIGS_REPO`` with
    ``hf_hub_download`` and builds a ``PipelineConfig`` from it using
    ``PipelineConfig.from_dict``.

    Args:
        name: Config name, i.e. the file stem under ``configs/``.
        token: Optional Hugging Face token for the download.

    Raises:
        Any error from ``hf_hub_download`` (e.g. missing file or repo), and
        ``json.JSONDecodeError`` if the file is not valid JSON.
    """
    from huggingface_hub import hf_hub_download
    path = hf_hub_download(
        repo_id=CONFIGS_REPO,
        filename=f"configs/{name}.json",
        repo_type=CONFIGS_REPO_TYPE,
        token=token,
    )
    # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        return PipelineConfig.from_dict(json.load(f))
def list_configs(token: Optional[str] = None) -> list:
    """List all available configs with their scores, best first.

    Best-effort: if the repo cannot be listed an empty list is returned, and
    any ``configs/*.json`` file that fails to download or parse is skipped.
    Both cases print a warning instead of failing silently.

    Args:
        token: Optional Hugging Face token for listing and downloading.

    Returns:
        ``PipelineConfig`` objects sorted by ``overall_score``, descending.
    """
    from huggingface_hub import HfApi, hf_hub_download
    api = HfApi(token=token)
    try:
        files = list(api.list_repo_files(CONFIGS_REPO, repo_type=CONFIGS_REPO_TYPE))
    except Exception as exc:
        print(f"  ⚠ Could not list configs in {CONFIGS_REPO}: {exc}")
        return []

    config_files = [f for f in files if f.startswith("configs/") and f.endswith(".json")]
    configs = []
    for filename in config_files:
        try:
            path = hf_hub_download(repo_id=CONFIGS_REPO, filename=filename,
                                   repo_type=CONFIGS_REPO_TYPE, token=token)
            # JSON is UTF-8 by spec; do not depend on the locale encoding.
            with open(path, encoding="utf-8") as fh:
                configs.append(PipelineConfig.from_dict(json.load(fh)))
        except Exception as exc:
            # One bad or unreachable file must not hide the rest of the list.
            print(f"  ⚠ Skipping {filename}: {exc}")

    # Sort by score descending
    configs.sort(key=lambda c: c.overall_score, reverse=True)
    return configs
def get_leaderboard(token: str = None) -> list:
    """Build leaderboard rows for display, one dict per saved config.

    Rows keep the order returned by ``list_configs`` and carry a 1-based
    'Rank'. Score columns are pre-formatted strings, and 'Date' holds the
    first 10 characters of the timestamp (empty when no timestamp is set).
    """
    rows = []
    for rank, cfg in enumerate(list_configs(token), start=1):
        rows.append({
            'Rank': rank,
            'Name': cfg.name,
            'Score': f"{cfg.overall_score:.1f}",
            'SI-SDR': f"{cfg.mean_si_sdr:.1f}",
            'Sample Q': f"{cfg.mean_sample_score:.1f}",
            'Env Corr': f"{cfg.mean_env_corr:.3f}",
            'Onset (ms)': f"{cfg.mean_onset_error_ms:.1f}",
            'Tests': cfg.n_test_songs,
            'Author': cfg.author,
            'Date': cfg.timestamp[:10] if cfg.timestamp else '',
        })
    return rows