import torch import numpy as np from scipy.signal import freqz from typing import Iterable from functools import reduce from modules import fx from modules.functional import ( highpass_biquad_coef, lowpass_biquad_coef, highshelf_biquad_coef, lowshelf_biquad_coef, equalizer_biquad_coef, ) def get_log_mags_from_eq(eq: Iterable, worN=1024, sr=44100): get_ba = lambda xs: torch.cat([x.view(1) for x in xs]).view(2, 3) def f(biquad): params = biquad.params match type(biquad): case fx.HighPass: coeffs = highpass_biquad_coef(sr, params.freq, params.Q) case fx.LowPass: coeffs = lowpass_biquad_coef(sr, params.freq, params.Q) case fx.HighShelf: coeffs = highshelf_biquad_coef(sr, params.freq, params.gain, biquad.Q) case fx.LowShelf: coeffs = lowshelf_biquad_coef(sr, params.freq, params.gain, biquad.Q) case fx.Peak: coeffs = equalizer_biquad_coef(sr, params.freq, params.gain, params.Q) case _: raise ValueError(biquad) b, a = get_ba(coeffs) w, h = freqz(b.numpy(), a.numpy(), worN, fs=sr) log_h = 20 * np.log10(np.abs(h) + 1e-10) return w, log_h log_mags = list(map(f, eq)) return log_mags[0][0], [x for _, x in log_mags] jsonparse2hydra = lambda d: ( ( {"_target_": d["class_path"]} | ( {k: jsonparse2hydra(v) for k, v in d["init_args"].items()} if "init_args" in d else {} ) if "class_path" in d else {k: jsonparse2hydra(v) for k, v in d.items()} ) if isinstance(d, dict) else (list(map(jsonparse2hydra, d)) if isinstance(d, list) else d) ) remove_window_fn = lambda d: ( {k: remove_window_fn(v) for k, v in d.items() if k != "window_fn"} if isinstance(d, dict) else (list(map(remove_window_fn, d)) if isinstance(d, list) else d) ) def chain_functions(*functions): return lambda *initial_args: reduce( lambda xs, f: f(*xs) if isinstance(xs, tuple) else f(xs), functions, initial_args, )