File size: 2,187 Bytes
40b18c2
ffdea96
40b18c2
 
ffdea96
40b18c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffdea96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import numpy as np
from scipy.signal import freqz
from typing import Iterable
from functools import reduce

from modules import fx
from modules.functional import (
    highpass_biquad_coef,
    lowpass_biquad_coef,
    highshelf_biquad_coef,
    lowshelf_biquad_coef,
    equalizer_biquad_coef,
)


def get_log_mags_from_eq(eq: Iterable, worN=1024, sr=44100):
    get_ba = lambda xs: torch.cat([x.view(1) for x in xs]).view(2, 3)

    def f(biquad):
        params = biquad.params
        match type(biquad):
            case fx.HighPass:
                coeffs = highpass_biquad_coef(sr, params.freq, params.Q)
            case fx.LowPass:
                coeffs = lowpass_biquad_coef(sr, params.freq, params.Q)
            case fx.HighShelf:
                coeffs = highshelf_biquad_coef(sr, params.freq, params.gain, biquad.Q)
            case fx.LowShelf:
                coeffs = lowshelf_biquad_coef(sr, params.freq, params.gain, biquad.Q)
            case fx.Peak:
                coeffs = equalizer_biquad_coef(sr, params.freq, params.gain, params.Q)
            case _:
                raise ValueError(biquad)

        b, a = get_ba(coeffs)
        w, h = freqz(b.numpy(), a.numpy(), worN, fs=sr)
        log_h = 20 * np.log10(np.abs(h) + 1e-10)
        return w, log_h

    log_mags = list(map(f, eq))
    return log_mags[0][0], [x for _, x in log_mags]


def jsonparse2hydra(d):
    """Recursively rewrite ``class_path``/``init_args`` dicts into ``_target_`` dicts.

    * A dict containing ``"class_path"`` becomes ``{"_target_": <class_path>}``
      merged with the recursively converted entries of ``d["init_args"]``
      (when that key is present). Any other keys of such a dict are dropped.
    * Any other dict is converted value by value, keeping its keys.
    * A list is converted element by element.
    * Everything else is returned unchanged.
    """
    if isinstance(d, list):
        return [jsonparse2hydra(item) for item in d]
    if not isinstance(d, dict):
        return d
    if "class_path" not in d:
        return {key: jsonparse2hydra(value) for key, value in d.items()}

    converted = {"_target_": d["class_path"]}
    if "init_args" in d:
        # Entries from init_args win over "_target_" on a key clash, exactly
        # like a right-hand operand of a dict union.
        converted.update(
            (key, jsonparse2hydra(value)) for key, value in d["init_args"].items()
        )
    return converted

def remove_window_fn(d):
    """Return a copy of ``d`` with every ``"window_fn"`` dict key removed.

    Dicts and lists are traversed recursively; any other value (including
    tuples) is returned as-is.
    """
    if isinstance(d, dict):
        cleaned = {}
        for key, value in d.items():
            if key == "window_fn":
                continue
            cleaned[key] = remove_window_fn(value)
        return cleaned
    if isinstance(d, list):
        return [remove_window_fn(item) for item in d]
    return d


def chain_functions(*functions):
    """Compose ``functions`` left to right into a single callable.

    The returned callable passes its positional arguments to the first
    function. Each later function receives the previous result, unpacked
    into positional arguments when that result is a tuple and passed as a
    single argument otherwise. With no functions, the callable returns its
    arguments as a tuple.
    """

    def chained(*initial_args):
        value = initial_args
        for func in functions:
            if isinstance(value, tuple):
                value = func(*value)
            else:
                value = func(value)
        return value

    return chained