"""
Generalized optimizer: tests across diverse synthetic songs, saves best config.
Changes from v1:
- Tests each config against MULTIPLE songs (rock/funk/halftime at varied tempos)
- Averages scores across all test songs for robust evaluation
- Saves winning configs to HF Hub with leaderboard scores
- Diagnostic-driven parameter tuning (same as before, improved)
"""
import time
import numpy as np
from dataclasses import dataclass, field
from synth_generator import generate_test_song
from config_store import PipelineConfig, save_config


@dataclass
class IterationResult:
    iteration: int
    params: dict
    scores: dict  # {pattern: score}
    avg_score: float
    duration_s: float
    changes: list
    timestamp: str = ""


@dataclass
class OptimizerState:
    history: list = field(default_factory=list)
    best_config: dict = field(default_factory=dict)
    best_score: float = 0.0
    iteration: int = 0


def run_extraction_eval(song, config: PipelineConfig):
    """Run extraction + evaluation on a single song. Returns eval dict."""
    from sample_extractor import (detect_onsets, classify_and_separate,
                                  compute_embeddings, cluster_hits, select_best,
                                  synthesize_from_cluster, sample_quality_score)
    from evaluation import evaluate_extraction

    # 1) Detect onset candidates in the drums-only stem
    hits = detect_onsets(song.drums_only, song.sr, pre_pad=config.pre_pad,
                         min_dur=config.min_dur, max_dur=config.max_dur,
                         min_gap=config.min_gap, energy_threshold_db=config.energy_threshold_db,
                         mode=config.onset_mode)
    if not hits:
        # No onsets found: return worst-case metrics so this config is penalized
        return {'overall_score': 0, 'mean_si_sdr': -50, 'mean_sample_score': 0,
                'mean_env_corr': 0, 'mean_onset_error_ms': 50, 'hit_count_accuracy': 0}

    # 2) Classify hits and separate overlapping ones
    hits = classify_and_separate(hits, separate_overlaps=config.separate_overlaps,
                                 overlap_threshold=config.overlap_threshold)

    # 3) Embed, cluster, and pick the best exemplar per cluster
    embs = compute_embeddings(hits)
    clusters = cluster_hits(hits, embs)
    select_best(clusters)

    # 4) Optionally synthesize a cleaner sample from multi-hit clusters
    if config.synthesize:
        for c in clusters:
            if c.count >= 2:
                c.synthesized = synthesize_from_cluster(c)

    # 5) Evaluate against the synthetic song's ground truth
    gt_samples = {name: s.audio for name, s in song.samples.items()}
    gt_hits = [{'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity}
               for h in song.hits]
    report = evaluate_extraction(extracted_clusters=clusters, gt_samples=gt_samples,
                                 gt_hit_map=gt_hits, sr=song.sr, all_hits=hits)
    return {
        'overall_score': report.overall_score,
        'mean_si_sdr': report.mean_si_sdr,
        'mean_sample_score': report.mean_sample_score,
        'mean_env_corr': report.mean_env_corr,
        'mean_onset_error_ms': report.mean_onset_error_ms,
        'hit_count_accuracy': report.hit_count_accuracy,
    }


def eval_config_across_songs(config: PipelineConfig, seeds: list, patterns: list,
                             bpms: list) -> dict:
    """Evaluate a config across multiple test songs. Returns averaged metrics."""
    all_scores = []
    for seed, pattern, bpm in zip(seeds, patterns, bpms):
        try:
            song = generate_test_song(pattern_name=pattern, bars=4, bpm=bpm,
                                      variation='medium', seed=seed)
            result = run_extraction_eval(song, config)
            all_scores.append(result)
        except Exception:
            # A failed song counts as a worst-case score, so configs that crash
            # the pipeline are penalized rather than silently skipped.
            all_scores.append({'overall_score': 0, 'mean_si_sdr': -50,
                               'mean_sample_score': 0, 'mean_env_corr': 0,
                               'mean_onset_error_ms': 50, 'hit_count_accuracy': 0})

    # Average each metric across all songs
    avg = {}
    for key in all_scores[0]:
        vals = [s[key] for s in all_scores]
        avg[key] = float(np.mean(vals))
    avg['n_songs'] = len(all_scores)
    avg['per_song'] = all_scores
    return avg


def diagnose_and_perturb(config: PipelineConfig, metrics: dict, rng) -> tuple:
    """Diagnose issues from metrics and perturb config. Returns (new_config, changes)."""
    c = PipelineConfig.from_dict(config.to_dict())  # copy via dict round-trip
    changes = []

    # Onsets placed late/early: tighten pre-pad and minimum gap
    if metrics.get('mean_onset_error_ms', 0) > 20:
        c.pre_pad = max(0.001, config.pre_pad * rng.uniform(0.5, 0.9))
        c.min_gap = max(0.008, config.min_gap * rng.uniform(0.6, 0.9))
        changes.append(f"onset_err={metrics['mean_onset_error_ms']:.0f}ms → tightened timing")

    # Hits being missed: lower the energy threshold and allow shorter hits
    if metrics.get('hit_count_accuracy', 1) < 0.7:
        c.energy_threshold_db = max(-65, config.energy_threshold_db - rng.uniform(2, 8))
        c.min_dur = max(0.008, config.min_dur * rng.uniform(0.5, 0.8))
        changes.append(f"hit_acc={metrics['hit_count_accuracy']:.2f} → lowered threshold")

    # Poor separation quality: nudge the overlap threshold
    if metrics.get('mean_si_sdr', 0) < 5:
        c.overlap_threshold += rng.uniform(-0.05, 0.05)
        c.overlap_threshold = float(np.clip(c.overlap_threshold, 0.05, 0.4))
        changes.append(f"SI-SDR={metrics['mean_si_sdr']:.1f} → adjusted overlap")

    # Truncated decays: allow longer extracted samples
    if metrics.get('mean_env_corr', 1) < 0.7:
        c.max_dur = min(2.0, config.max_dur * rng.uniform(1.1, 1.3))
        changes.append(f"env_corr={metrics['mean_env_corr']:.2f} → increased max_dur")

    # Nothing obviously wrong: explore nearby parameter space at random
    if not changes:
        c.energy_threshold_db += rng.uniform(-3, 3)
        c.pre_pad = max(0.001, c.pre_pad + rng.uniform(-0.002, 0.002))
        c.min_dur = max(0.008, c.min_dur + rng.uniform(-0.005, 0.005))
        changes.append("random exploration")

    return c, changes


def run_optimization(n_iterations: int = 10, config_name: str = "optimized",
                     author: str = "", save_to_hub: bool = True,
                     seed: int = 42, log_fn=None) -> OptimizerState:
    """Run the optimization loop, testing each config across diverse songs."""
    rng = np.random.RandomState(seed)
    state = OptimizerState()

    # Test suite: diverse patterns, seeds, and tempos
    test_patterns = ['rock', 'funk', 'halftime'] * 2  # 6 songs
    test_seeds = [seed + i * 17 for i in range(6)]
    test_bpms = [120, 100, 140, 130, 110, 150]

    config = PipelineConfig(name=config_name, author=author)

    def log(msg):
        if log_fn: log_fn(msg)
        print(msg)

    log(f"Optimization: {n_iterations} iters, {len(test_patterns)} test songs each")
    for i in range(n_iterations):
        t0 = time.time()
        log(f"\n{'='*50}\nITERATION {i+1}/{n_iterations}\n{'='*50}")
        try:
            log(f" Testing config across {len(test_patterns)} songs...")
            metrics = eval_config_across_songs(config, test_seeds, test_patterns, test_bpms)
            avg_score = metrics['overall_score']
            log(f" Score: {avg_score:.1f}/100 (SI-SDR={metrics['mean_si_sdr']:.1f}, "
                f"sample={metrics['mean_sample_score']:.1f}, "
                f"env={metrics['mean_env_corr']:.2f})")
            if avg_score > state.best_score:
                state.best_score = avg_score
                state.best_config = config.to_dict()
                log(f" ★ NEW BEST: {avg_score:.1f}")

            # Diagnose weaknesses and perturb the config for the next iteration
            new_config, changes = diagnose_and_perturb(config, metrics, rng)
            log(f" Changes: {'; '.join(changes)}")
            state.history.append(IterationResult(
                iteration=i, params=config.to_dict(),
                scores={f"song_{j}": s['overall_score']
                        for j, s in enumerate(metrics.get('per_song', []))},
                avg_score=avg_score, duration_s=time.time() - t0,
                changes=changes, timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
            ))
            config = new_config
        except Exception as e:
            log(f" ERROR: {e}")
            # Kick the threshold to escape configurations that crash the pipeline
            config.energy_threshold_db += rng.uniform(-5, 5)
            state.history.append(IterationResult(
                iteration=i, params=config.to_dict(), scores={},
                avg_score=0, duration_s=time.time() - t0, changes=[str(e)],
            ))
        state.iteration = i + 1

    # Save the best config found
    if save_to_hub and state.best_config:
        log(f"\nSaving best config (score={state.best_score:.1f})...")
        best = PipelineConfig.from_dict(state.best_config)
        best.name = config_name
        best.author = author
        best.overall_score = state.best_score
        best.n_test_songs = len(test_patterns)
        try:
            save_config(best)
            log(f" ✓ Saved to {best.name}")
        except Exception as e:
            log(f" ⚠ Could not save to Hub: {e}")

    log(f"\nBest score: {state.best_score:.1f}/100")
    return state
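

# Minimal usage sketch (not part of the optimizer itself): runs a short offline
# optimization when the module is executed directly. save_to_hub=False skips the
# HF Hub upload, which is assumed to require credentials handled by config_store;
# iteration count and seed here are arbitrary demo values.
if __name__ == "__main__":
    final_state = run_optimization(n_iterations=5, config_name="demo",
                                   author="", save_to_hub=False, seed=42)
    print(f"Ran {final_state.iteration} iterations; "
          f"best score {final_state.best_score:.1f}/100")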