rikhoffbauer2 committed on
Commit
1d29056
·
verified ·
1 Parent(s): 3157241

v2: Update optimizer_v2.py

Browse files
Files changed (1) hide show
  1. optimizer_v2.py +214 -0
optimizer_v2.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generalized optimizer: tests across diverse synthetic songs, saves best config.
3
+
4
+ Changes from v1:
5
+ - Tests each config against MULTIPLE songs (rock/funk/halftime/vocal/sfx)
6
+ - Averages scores across all test songs for robust evaluation
7
+ - Saves winning configs to HF Hub with leaderboard scores
8
+ - Diagnostic-driven parameter tuning (same as before, improved)
9
+ """
10
+
11
+ import json, time, traceback, numpy as np
12
+ from copy import deepcopy
13
+ from dataclasses import dataclass, field
14
+
15
+ from synth_generator import generate_test_song
16
+ from config_store import PipelineConfig, save_config
17
+
18
+
19
@dataclass
class IterationResult:
    """Snapshot of a single optimizer iteration.

    Field order doubles as the positional constructor order, so it must
    stay as declared.
    """

    # Zero-based index of the iteration within the run.
    iteration: int
    # Config parameters (as a plain dict) recorded for this iteration.
    params: dict
    # Per-song overall scores keyed by a label such as "song_0".
    scores: dict
    # Overall score averaged across the test songs (0 for a failed iteration).
    avg_score: float
    # Elapsed wall-clock time of the iteration, in seconds.
    duration_s: float
    # Notes describing the perturbation applied, or the error text on failure.
    changes: list
    # Formatted wall-clock time of the record; empty when not supplied.
    timestamp: str = field(default="")
28
+
29
+
30
@dataclass
class OptimizerState:
    """Running bookkeeping for one optimization run.

    Every field has a default, so ``OptimizerState()`` yields a fresh,
    empty state with its own ``history`` list and ``best_config`` dict.
    """

    # One IterationResult appended per iteration (successful or failed).
    history: list = field(default_factory=list)
    # Parameter dict of the best-scoring config so far; empty until one is found.
    best_config: dict = field(default_factory=dict)
    # Highest averaged score observed so far.
    best_score: float = field(default=0.0)
    # Number of iterations completed.
    iteration: int = field(default=0)
36
+
37
+
38
def run_extraction_eval(song, config: PipelineConfig):
    """Run the extraction pipeline on one song and score it against ground truth.

    Args:
        song: Test song object. This function reads its ``drums_only``,
            ``sr``, ``samples`` and ``hits`` attributes.
        config: Pipeline settings forwarded to onset detection and to
            classification/separation; ``config.synthesize`` toggles the
            per-cluster synthesis step.

    Returns:
        dict with the keys ``overall_score``, ``mean_si_sdr``,
        ``mean_sample_score``, ``mean_env_corr``, ``mean_onset_error_ms``
        and ``hit_count_accuracy``. When onset detection yields nothing, a
        fixed worst-case dict with those same keys is returned instead.
    """
    # Resolved at call time rather than at module import.
    # ``sample_quality_score`` is imported but not used in this function.
    from sample_extractor import (detect_onsets, classify_and_separate,
                                  compute_embeddings, cluster_hits, select_best,
                                  synthesize_from_cluster, sample_quality_score)
    from evaluation import evaluate_extraction

    detected = detect_onsets(
        song.drums_only,
        song.sr,
        pre_pad=config.pre_pad,
        min_dur=config.min_dur,
        max_dur=config.max_dur,
        min_gap=config.min_gap,
        energy_threshold_db=config.energy_threshold_db,
        mode=config.onset_mode,
    )
    if not detected:
        # Nothing detected: report the fixed worst-case metrics.
        return {'overall_score': 0, 'mean_si_sdr': -50, 'mean_sample_score': 0,
                'mean_env_corr': 0, 'mean_onset_error_ms': 50, 'hit_count_accuracy': 0}

    detected = classify_and_separate(
        detected,
        separate_overlaps=config.separate_overlaps,
        overlap_threshold=config.overlap_threshold,
    )
    embeddings = compute_embeddings(detected)
    clusters = cluster_hits(detected, embeddings)
    select_best(clusters)  # return value is not used

    if config.synthesize:
        for cluster in clusters:
            # Only clusters with at least two members get a synthesized sample.
            if cluster.count < 2:
                continue
            cluster.synthesized = synthesize_from_cluster(cluster)

    # Ground truth taken from the song itself.
    reference_audio = {}
    for sample_name, sample in song.samples.items():
        reference_audio[sample_name] = sample.audio
    reference_hits = [
        {'sample': hit.sample_name, 'onset': hit.onset_time, 'velocity': hit.velocity}
        for hit in song.hits
    ]

    report = evaluate_extraction(
        extracted_clusters=clusters,
        gt_samples=reference_audio,
        gt_hit_map=reference_hits,
        sr=song.sr,
        all_hits=detected,
    )
    metric_names = (
        'overall_score',
        'mean_si_sdr',
        'mean_sample_score',
        'mean_env_corr',
        'mean_onset_error_ms',
        'hit_count_accuracy',
    )
    return {metric: getattr(report, metric) for metric in metric_names}
78
+
79
+
80
def eval_config_across_songs(config: "PipelineConfig", seeds: list, patterns: list,
                             bpms: list) -> dict:
    """Evaluate a config across multiple test songs and average the metrics.

    One 4-bar, ``variation='medium'`` song is generated per
    ``(seed, pattern, bpm)`` triple and scored with ``run_extraction_eval``.
    The three lists are walked in lockstep with ``zip``, so surplus entries
    in a longer list are ignored.

    A song whose generation or evaluation raises does not abort the run: a
    warning is printed and the fixed worst-case metrics are recorded for it.

    Args:
        config: Pipeline configuration under test.
        seeds: Random seed for each song.
        patterns: Pattern name for each song.
        bpms: Tempo for each song.

    Returns:
        dict mapping each metric name to its mean (as ``float``) over all
        songs, plus ``'n_songs'`` (number of songs scored) and
        ``'per_song'`` (list of the individual metric dicts). With no songs
        at all, the worst-case metrics are returned with ``n_songs == 0``
        instead of raising ``IndexError``.
    """
    # Metrics recorded for a song that could not be generated or evaluated.
    failure_metrics = {
        'overall_score': 0, 'mean_si_sdr': -50, 'mean_sample_score': 0,
        'mean_env_corr': 0, 'mean_onset_error_ms': 50, 'hit_count_accuracy': 0,
    }

    all_scores = []
    for seed, pattern, bpm in zip(seeds, patterns, bpms):
        try:
            song = generate_test_song(pattern_name=pattern, bars=4, bpm=bpm,
                                      variation='medium', seed=seed)
            all_scores.append(run_extraction_eval(song, config))
        except Exception as exc:
            # Deliberately best-effort: one broken song must not sink the
            # whole evaluation, but the failure should be visible.
            print(f" ⚠ Song failed (pattern={pattern!r}, bpm={bpm}, seed={seed}): "
                  f"{type(exc).__name__}: {exc}")
            all_scores.append(dict(failure_metrics))

    if all_scores:
        # Average each metric across all songs.
        avg = {key: float(np.mean([s[key] for s in all_scores]))
               for key in all_scores[0]}
    else:
        # Nothing to average: report the worst case rather than crashing.
        avg = {key: float(value) for key, value in failure_metrics.items()}

    avg['n_songs'] = len(all_scores)
    avg['per_song'] = all_scores
    return avg
103
+
104
+
105
def diagnose_and_perturb(config: PipelineConfig, metrics: dict, rng) -> tuple:
    """Derive a perturbed copy of ``config`` from the weaknesses in ``metrics``.

    The input config is not modified; a copy is built through
    ``to_dict``/``from_dict`` and adjusted. Rules, applied independently:

    * ``mean_onset_error_ms`` > 20: shrink ``pre_pad`` and ``min_gap``.
    * ``hit_count_accuracy`` < 0.7: lower ``energy_threshold_db`` (floored
      at -65) and shrink ``min_dur``.
    * ``mean_si_sdr`` < 5: nudge ``overlap_threshold`` within [0.05, 0.4].
    * ``mean_env_corr`` < 0.7: grow ``max_dur`` (capped at 2.0).

    If no rule fires, a small random exploration step is taken instead.

    Args:
        config: Current pipeline configuration.
        metrics: Averaged metrics; a missing key is treated as healthy.
        rng: Random source exposing ``uniform(low, high)``, e.g. a
            ``numpy.random.RandomState``.

    Returns:
        ``(new_config, changes)`` where ``changes`` is a non-empty list of
        human-readable descriptions of what was adjusted.
    """
    # Lower bound for the energy threshold, shared by both branches that
    # touch it so exploration cannot drift below the diagnostic floor.
    min_threshold_db = -65

    c = PipelineConfig.from_dict(config.to_dict())
    changes = []

    if metrics.get('mean_onset_error_ms', 0) > 20:
        c.pre_pad = max(0.001, config.pre_pad * rng.uniform(0.5, 0.9))
        c.min_gap = max(0.008, config.min_gap * rng.uniform(0.6, 0.9))
        changes.append(f"onset_err={metrics['mean_onset_error_ms']:.0f}ms → tightened timing")

    if metrics.get('hit_count_accuracy', 1) < 0.7:
        c.energy_threshold_db = max(min_threshold_db,
                                    config.energy_threshold_db - rng.uniform(2, 8))
        c.min_dur = max(0.008, config.min_dur * rng.uniform(0.5, 0.8))
        changes.append(f"hit_acc={metrics['hit_count_accuracy']:.2f} → lowered threshold")

    if metrics.get('mean_si_sdr', 0) < 5:
        nudged = c.overlap_threshold + rng.uniform(-0.05, 0.05)
        # np.clip returns a NumPy scalar; keep the config field a builtin float.
        c.overlap_threshold = float(np.clip(nudged, 0.05, 0.4))
        changes.append(f"SI-SDR={metrics['mean_si_sdr']:.1f} → adjusted overlap")

    if metrics.get('mean_env_corr', 1) < 0.7:
        c.max_dur = min(2.0, config.max_dur * rng.uniform(1.1, 1.3))
        changes.append(f"env_corr={metrics['mean_env_corr']:.2f} → increased max_dur")

    if not changes:
        # Nothing diagnosed: take a small random step so the search keeps moving.
        c.energy_threshold_db = max(min_threshold_db,
                                    c.energy_threshold_db + rng.uniform(-3, 3))
        c.pre_pad = max(0.001, c.pre_pad + rng.uniform(-0.002, 0.002))
        c.min_dur = max(0.008, c.min_dur + rng.uniform(-0.005, 0.005))
        changes.append("random exploration")

    return c, changes
136
+
137
+
138
def run_optimization(n_iterations: int = 10, config_name: str = "optimized",
                     author: str = "", save_to_hub: bool = True,
                     seed: int = 42, log_fn=None) -> OptimizerState:
    """Run the optimization loop, scoring each config across several test songs.

    Each iteration evaluates the current config with
    ``eval_config_across_songs``, records the best-scoring parameters, and
    derives the next config with ``diagnose_and_perturb``. An iteration that
    raises is logged (message plus traceback), recorded in the history with
    score 0 and the parameters that were being tested, and followed by a
    random nudge of ``energy_threshold_db`` before the next attempt.

    Args:
        n_iterations: Number of configs to evaluate.
        config_name: Name given to the starting config and to the saved best.
        author: Author recorded on the config.
        save_to_hub: When true and a best config exists, call ``save_config``
            on it after the loop; a failure there is logged, not raised.
        seed: Seeds the perturbation RNG and the per-song seeds.
        log_fn: Optional callable receiving every log message; messages are
            also printed.

    Returns:
        The ``OptimizerState`` holding the history, best config and score.
    """
    rng = np.random.RandomState(seed)
    state = OptimizerState()

    # Test suite: diverse songs
    test_patterns = ['rock', 'funk', 'halftime'] * 2  # 6 songs
    test_bpms = [120, 100, 140, 130, 110, 150]
    # One seed per song, derived from the pattern list instead of a literal 6.
    test_seeds = [seed + i * 17 for i in range(len(test_patterns))]

    config = PipelineConfig(name=config_name, author=author)

    def log(msg):
        if log_fn:
            log_fn(msg)
        print(msg)

    log(f"Optimization: {n_iterations} iters, {len(test_patterns)} test songs each")

    for i in range(n_iterations):
        t0 = time.time()
        log(f"\n{'='*50}\nITERATION {i+1}/{n_iterations}\n{'='*50}")

        # Snapshot up front so both the success and the error path record
        # the parameters that were actually evaluated this iteration.
        tested_params = config.to_dict()

        try:
            log(f" Testing config across {len(test_patterns)} songs...")
            metrics = eval_config_across_songs(config, test_seeds, test_patterns, test_bpms)
            avg_score = metrics['overall_score']

            log(f" Score: {avg_score:.1f}/100 (SI-SDR={metrics['mean_si_sdr']:.1f}, "
                f"sample={metrics['mean_sample_score']:.1f}, "
                f"env={metrics['mean_env_corr']:.2f})")

            if avg_score > state.best_score:
                state.best_score = avg_score
                state.best_config = config.to_dict()
                log(f" ★ NEW BEST: {avg_score:.1f}")

            # Perturb
            new_config, changes = diagnose_and_perturb(config, metrics, rng)
            log(f" Changes: {'; '.join(changes)}")

            state.history.append(IterationResult(
                iteration=i, params=tested_params,
                scores={f"song_{j}": s['overall_score']
                        for j, s in enumerate(metrics.get('per_song', []))},
                avg_score=avg_score, duration_s=time.time()-t0,
                changes=changes, timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
            ))
            config = new_config

        except Exception as e:
            log(f" ERROR: {e}")
            log(traceback.format_exc())
            # Record the failure against the params that failed, then nudge
            # the config so the next iteration tries something different.
            state.history.append(IterationResult(
                iteration=i, params=tested_params, scores={},
                avg_score=0, duration_s=time.time()-t0, changes=[str(e)],
                timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
            ))
            config.energy_threshold_db += rng.uniform(-5, 5)

        state.iteration = i + 1

    # Save best config
    if save_to_hub and state.best_config:
        log(f"\nSaving best config (score={state.best_score:.1f})...")
        best = PipelineConfig.from_dict(state.best_config)
        best.name = config_name
        best.author = author
        best.overall_score = state.best_score
        best.n_test_songs = len(test_patterns)
        try:
            save_config(best)
            log(f" ✓ Saved to {best.name}")
        except Exception as e:
            # Saving is best-effort; the optimization result is still returned.
            log(f" ⚠ Could not save to Hub: {e}")

    log(f"\nBest score: {state.best_score:.1f}/100")
    return state