# NOTE(review): removed non-code paste artifacts here ("Spaces:" header and
# two "Build error" banner lines) — they were extraction residue, not source.
| import torch | |
| import numpy as np | |
| import torch.multiprocessing as mp | |
| from deepafx_st.processors.dsp.peq import ParametricEQ | |
| from deepafx_st.processors.dsp.compressor import Compressor | |
| from deepafx_st.processors.spsa.spsa_func import SPSAFunction | |
| from deepafx_st.utils import rademacher | |
def dsp_func(x, p, dsp, sample_rate=24000):
    """Run the two-stage DSP chain (parametric EQ, then compressor) on x.

    Args:
        x: Input audio signal.
        p: Flat control-parameter vector; the first ``split`` entries drive
            the EQ and the remainder drive the compressor.
        dsp: Tuple of ``((peq, comp), split)`` where ``split`` is the number
            of EQ parameters.
        sample_rate (int, optional): Sample rate forwarded to both processors.

    Returns:
        The signal after EQ and compression.
    """
    processors, split = dsp
    peq, comp = processors
    eq_out = peq(x, p[:split], sample_rate)
    return comp(eq_out, p[split:], sample_rate)
class SPSAChannel(torch.nn.Module):
    """Audio effect channel (ParametricEQ -> Compressor) trained with SPSA.

    The DSP chain itself is not differentiable; gradients are estimated by
    ``SPSAFunction`` (simultaneous perturbation stochastic approximation).

    Args:
        sample_rate (float): Sample rate of the plugin instance
        parallel (bool, optional): Use parallel workers for DSP.
            By default, this utilizes parallelized instances of the plugin channel,
            where the number of workers is equal to the batch size.
        batch_size (int, optional): Number of worker processes spawned when
            ``parallel`` is True.
    """

    def __init__(
        self,
        sample_rate: int,
        parallel: bool = False,
        batch_size: int = 8,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.parallel = parallel

        if self.parallel:
            self.apply_func = SPSAFunction.apply
            # One dedicated (EQ, compressor) chain and worker process per
            # batch item; requests are sent over a multiprocessing Pipe.
            procs = {}
            for b in range(self.batch_size):
                peq = ParametricEQ(sample_rate)
                comp = Compressor(sample_rate)
                # (processors, number of EQ params) — the split point used
                # by dsp_func to slice the flat parameter vector.
                dsp = ((peq, comp), peq.num_control_params)

                parent_conn, child_conn = mp.Pipe()
                p = mp.Process(target=SPSAChannel.worker_pipe, args=(child_conn, dsp))
                p.start()
                procs[b] = [p, parent_conn, child_conn]

            # Update stuff for external public members TODO: fix
            # NOTE(review): uses peq/comp from the last loop iteration; all
            # workers are constructed identically so the ports/param counts
            # are representative — confirm intended.
            self.ports = [peq.ports, comp.ports]
            self.num_control_params = (
                comp.num_control_params + peq.num_control_params
            )
            self.procs = procs
        else:
            # Serial mode: a single in-process DSP chain (CPU/NumPy).
            self.peq = ParametricEQ(sample_rate)
            self.comp = Compressor(sample_rate)
            self.apply_func = SPSAFunction.apply
            self.ports = [self.peq.ports, self.comp.ports]
            self.num_control_params = (
                self.comp.num_control_params + self.peq.num_control_params
            )
            self.dsp = ((self.peq, self.comp), self.peq.num_control_params)

        # add one param for wet/dry mix
        # self.num_control_params += 1
| def __del__(self): | |
| if hasattr(self, "procs"): | |
| for proc_idx, proc in self.procs.items(): | |
| #print(f"Closing {proc_idx}...") | |
| proc[0].terminate() | |
| def forward(self, x, p, epsilon=0.001, sample_rate=24000, **kwargs): | |
| """ | |
| Args: | |
| x (Tensor): Input signal with shape: [batch x channels x samples] | |
| p (Tensor): Audio effect control parameters with shape: [batch x parameters] | |
| epsilon (float, optional): Twiddle parameter range for SPSA gradient estimation. | |
| Returns: | |
| y (Tensor): Processed audio signal. | |
| """ | |
| if self.parallel: | |
| y = self.apply_func(x, p, None, epsilon, self, sample_rate) | |
| else: | |
| # this will process on CPU in NumPy | |
| y = self.apply_func(x, p, None, epsilon, self, sample_rate) | |
| return y.type_as(x) | |
    def static_backward(dsp, value):
        """Estimate SPSA gradients w.r.t. the input signal and the parameters.

        Central-difference estimate: the DSP chain is evaluated at two points
        perturbed by +/- epsilon along a random Rademacher direction.

        Args:
            dsp: ``((peq, comp), num_peq_params)`` tuple consumed by dsp_func.
            value: Packed tuple of
                (batch_index, x, params, needs_input_grad, needs_param_grad,
                 grad_output, epsilon).

        Returns:
            (grads_input, grads_params); each entry is None when the
            corresponding gradient was not requested.
        """
        (
            batch_index,
            x,
            params,
            needs_input_grad,
            needs_param_grad,
            grad_output,
            epsilon,
        ) = value

        grads_input = None
        grads_params = None
        ps = params.shape[-1]  # number of control parameters
        factors = [1.0]  # perturbation scale factors (currently a single run)

        # estimate gradient w.r.t input
        if needs_input_grad:
            delta_k = rademacher(x.shape).numpy()
            J_plus = dsp_func(x + epsilon * delta_k, params, dsp)
            J_minus = dsp_func(x - epsilon * delta_k, params, dsp)
            grads_input = (J_plus - J_minus) / (2.0 * epsilon)

        # estimate gradient w.r.t params
        grads_params_runs = []
        if needs_param_grad:
            for factor in factors:
                params_sublist = []
                delta_k = rademacher(params.shape).numpy()

                # compute output in two random directions of the parameter space
                # (parameters are clipped to the normalized [0, 1] range)
                params_plus = np.clip(params + (factor * epsilon * delta_k), 0, 1)
                J_plus = dsp_func(x, params_plus, dsp)

                params_minus = np.clip(params - (factor * epsilon * delta_k), 0, 1)
                J_minus = dsp_func(x, params_minus, dsp)

                grad_param = J_plus - J_minus

                # compute gradient for each parameter as a function of epsilon and random direction
                for sub_p_idx in range(ps):
                    grad_p = grad_param / (2 * epsilon * delta_k[sub_p_idx])
                    # chain rule: fold the incoming grad_output into the estimate
                    params_sublist.append(np.sum(grad_output * grad_p))

                grads_params = np.array(params_sublist)
                grads_params_runs.append(grads_params)

            # average gradients across runs (one run while factors == [1.0])
            grads_params = np.mean(grads_params_runs, axis=0)

        return grads_input, grads_params
| def static_forward(dsp, value): | |
| batch_index, x, p, sample_rate = value | |
| y = dsp_func(x, p, dsp, sample_rate) | |
| return y | |
| def worker_pipe(child_conn, dsp): | |
| while True: | |
| msg, value = child_conn.recv() | |
| if msg == "forward": | |
| child_conn.send(SPSAChannel.static_forward(dsp, value)) | |
| elif msg == "backward": | |
| child_conn.send(SPSAChannel.static_backward(dsp, value)) | |
| elif msg == "shutdown": | |
| break | |