File size: 4,461 Bytes
3157241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
Config persistence: save/load optimized pipeline configs to HuggingFace Hub.

Each config is a JSON file stored in a dataset repo with:
  - Pipeline parameters
  - Evaluation scores (overall, SI-SDR, sample quality, etc.)
  - Metadata (author, timestamp, description)

The leaderboard ranks configs by their evaluation scores across test songs.
"""

import json, time, io, os
from dataclasses import dataclass, asdict
from typing import Optional


# Hugging Face Hub repo (id, repo type) that every config JSON is stored in.
CONFIGS_REPO, CONFIGS_REPO_TYPE = (
    "rikhoffbauer2/sample-extractor-configs",  # repo id
    "dataset",  # repo_type passed to every huggingface_hub call below
)


@dataclass
class PipelineConfig:
    """A pipeline configuration bundled with its evaluation scores.

    Serialized as a flat JSON object whose keys are exactly the field names
    declared below (see ``to_dict`` / ``to_json`` / ``from_dict``).
    """
    # --- Identity / metadata ---
    name: str = "default"
    description: str = ""
    author: str = ""
    timestamp: str = ""
    version: str = "1.0"

    # --- Pipeline parameters ---
    stem: str = "drums"
    onset_mode: str = "auto"
    pre_pad: float = 0.005
    min_dur: float = 0.02
    max_dur: float = 1.5
    min_gap: float = 0.015
    energy_threshold_db: float = -45.0
    separate_overlaps: bool = True
    overlap_threshold: float = 0.15
    synthesize: bool = True

    # --- Evaluation scores (filled in after evaluation) ---
    overall_score: float = 0.0
    mean_si_sdr: float = 0.0
    mean_sample_score: float = 0.0
    mean_env_corr: float = 0.0
    mean_onset_error_ms: float = 50.0
    n_test_songs: int = 0

    def to_dict(self) -> dict:
        """Return the config as a plain ``{field_name: value}`` dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> 'PipelineConfig':
        """Build a config from ``d``, ignoring keys that are not fields.

        Fields missing from ``d`` keep their declared defaults.
        """
        known_fields = cls.__dataclass_fields__
        kwargs = {}
        for key, value in d.items():
            if key in known_fields:
                kwargs[key] = value
        return cls(**kwargs)

    def to_json(self) -> str:
        """Serialize the config as JSON indented by two spaces."""
        payload = self.to_dict()
        return json.dumps(payload, indent=2)


def save_config(config: PipelineConfig, token: Optional[str] = None) -> str:
    """Save a config to the HF dataset repo and return its path in the repo.

    Side effects: sets ``config.timestamp`` to the current UTC time
    (``YYYY-MM-DDTHH:MM:SSZ``) before uploading, so the caller's object is
    mutated; attempts to create the configs repo if it does not exist.

    Args:
        config: The config to persist. Its ``name`` determines the file path,
            ``configs/<name>.json``.
        token: Optional Hugging Face API token, passed straight to ``HfApi``.

    Returns:
        The path of the uploaded file inside the repo.
    """
    from huggingface_hub import HfApi

    api = HfApi(token=token)

    # Best-effort: a failure here (e.g. the repo exists but we lack permission
    # to create repos) should not prevent the upload attempt below, where any
    # real problem is expected to surface. Report it rather than hide it.
    try:
        api.create_repo(CONFIGS_REPO, repo_type=CONFIGS_REPO_TYPE, exist_ok=True)
    except Exception as exc:
        print(f"  ⚠ Could not ensure repo {CONFIGS_REPO} exists: {exc}")

    # time.strftime() formats *local* time unless given a struct_time; pass
    # gmtime() so the trailing 'Z' (UTC designator) is actually correct.
    config.timestamp = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
    path = f"configs/{config.name}.json"

    api.upload_file(
        path_or_fileobj=io.BytesIO(config.to_json().encode()),
        path_in_repo=path,
        repo_id=CONFIGS_REPO,
        repo_type=CONFIGS_REPO_TYPE,
        commit_message=f"Config: {config.name} (score={config.overall_score:.1f})",
    )
    print(f"  ✓ Saved config '{config.name}' → {CONFIGS_REPO}/{path}")
    return path


def load_config(name: str, token: Optional[str] = None) -> PipelineConfig:
    """Load a config by name from the HF dataset repo.

    Args:
        name: Config name; resolved to ``configs/<name>.json`` in the repo,
            the same path ``save_config`` uploads to.
        token: Optional Hugging Face API token, passed to ``hf_hub_download``.

    Returns:
        The parsed config. JSON keys that are not ``PipelineConfig`` fields
        are ignored, and fields missing from the JSON keep their defaults.

    Raises:
        Whatever ``hf_hub_download`` raises when the file cannot be fetched
        (nothing is caught here), and ``json.JSONDecodeError`` for a
        malformed file.
    """
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(
        repo_id=CONFIGS_REPO,
        filename=f"configs/{name}.json",
        repo_type=CONFIGS_REPO_TYPE,
        token=token,
    )
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
    # the platform default, which can mis-read non-ASCII text on e.g. Windows.
    with open(path, encoding="utf-8") as f:
        return PipelineConfig.from_dict(json.load(f))


def list_configs(token: Optional[str] = None) -> list:
    """List all available configs, best score first.

    Downloads every ``configs/*.json`` file in the configs repo and parses it
    into a ``PipelineConfig``. This is best-effort. If the repo cannot be
    listed, an empty list is returned. Files that fail to download or parse
    are skipped with a printed warning instead of aborting the whole listing.

    Args:
        token: Optional Hugging Face API token, passed to ``HfApi`` and
            ``hf_hub_download``.

    Returns:
        ``PipelineConfig`` objects sorted by ``overall_score``, highest first.
    """
    from huggingface_hub import HfApi, hf_hub_download

    api = HfApi(token=token)
    try:
        # list() keeps any lazy evaluation (and its errors) inside the try.
        files = list(api.list_repo_files(CONFIGS_REPO, repo_type=CONFIGS_REPO_TYPE))
    except Exception:
        # Presumably the repo does not exist yet / is unreachable: nothing to list.
        return []

    configs = []
    for filename in files:
        if not (filename.startswith("configs/") and filename.endswith(".json")):
            continue
        try:
            path = hf_hub_download(repo_id=CONFIGS_REPO, filename=filename,
                                   repo_type=CONFIGS_REPO_TYPE, token=token)
            # JSON is UTF-8; don't depend on the platform default encoding.
            with open(path, encoding="utf-8") as fh:
                configs.append(PipelineConfig.from_dict(json.load(fh)))
        except Exception as exc:
            # One bad file must not hide the rest, but say which one was dropped.
            print(f"  ⚠ Skipping unreadable config {filename}: {exc}")

    # Sort by score descending
    configs.sort(key=lambda c: c.overall_score, reverse=True)
    return configs


def get_leaderboard(token: str = None) -> list:
    """Return the leaderboard as display-ready row dicts.

    Rows follow the order returned by ``list_configs`` and are ranked from 1.
    Scores are pre-formatted as strings; 'Date' is the first 10 characters of
    the stored timestamp, or '' when no timestamp is set.
    """
    rows = []
    for rank, cfg in enumerate(list_configs(token), start=1):
        rows.append({
            'Rank': rank,
            'Name': cfg.name,
            'Score': f"{cfg.overall_score:.1f}",
            'SI-SDR': f"{cfg.mean_si_sdr:.1f}",
            'Sample Q': f"{cfg.mean_sample_score:.1f}",
            'Env Corr': f"{cfg.mean_env_corr:.3f}",
            'Onset (ms)': f"{cfg.mean_onset_error_ms:.1f}",
            'Tests': cfg.n_test_songs,
            'Author': cfg.author,
            'Date': cfg.timestamp[:10] if cfg.timestamp else '',
        })
    return rows