Spaces:
Build error
Build error
| import torch | |
def spsa_func(input, params, process, i, sample_rate=24000):
    """Run ``process`` on CPU copies of the inputs and cast the result back.

    Args:
        input (Tensor): Tensor handed to ``process`` after ``.cpu()``.
        params (Tensor): Parameter tensor handed to ``process`` after ``.cpu()``.
        process (callable): Called as ``process(input, params, i, sample_rate)``;
            its return value must support ``type_as``.
        i: Passed through to ``process`` unchanged (presumably a batch/worker
            index -- confirm against callers).
        sample_rate (int): Passed through to ``process`` unchanged.

    Returns:
        The result of ``process`` converted with ``type_as(input)``.
    """
    cpu_input = input.cpu()
    cpu_params = params.cpu()
    processed = process(cpu_input, cpu_params, i, sample_rate)
    # Cast back so the caller gets a tensor of the same type as ``input``.
    return processed.type_as(input)
class SPSAFunction(torch.autograd.Function):
    """Autograd function that delegates processing and gradients to ``thread_context``.

    The forward pass hands every batch item to an external processor, and the
    backward pass asks the same processor for gradient estimates. Both passes
    go through ``thread_context``, from which this class reads:

    * ``parallel`` -- truthy to dispatch work through ``procs``.
    * ``procs`` -- indexable by batch index; ``procs[i][1]`` must provide
      ``send(msg)`` and ``recv()`` (presumably one end of a
      ``multiprocessing`` pipe -- confirm against the worker setup).
    * ``static_forward(dsp, value)`` / ``static_backward(dsp, value)`` and
      ``dsp`` -- used for in-process execution when ``parallel`` is falsy.
    """

    @staticmethod
    def _flat_numpy(tensor):
        """Return ``tensor`` as a flattened, detached CPU NumPy array."""
        # ``reshape`` (rather than ``view``) also accepts non-contiguous
        # slices, e.g. items of a transposed batch or a strided grad_output.
        return tensor.reshape(-1).detach().cpu().numpy()

    @staticmethod
    def _forward_request(i, input, params, sample_rate):
        """Build the per-item payload consumed by the forward worker."""
        flat = SPSAFunction._flat_numpy
        return (i, flat(input[i]), flat(params[i]), sample_rate)

    @staticmethod
    def _backward_request(
        i, input, params, needs_input_grad, needs_param_grad, grad_output, epsilon
    ):
        """Build the per-item payload consumed by the backward worker."""
        flat = SPSAFunction._flat_numpy
        return (
            i,
            flat(input[i]),
            flat(params[i]),
            needs_input_grad,
            needs_param_grad,
            flat(grad_output[i]),
            epsilon,
        )

    @staticmethod
    def _store_grads(i, result, grads_input, grads_params):
        """Copy one worker result ``(input_grad, param_grad)`` into the buffers."""
        input_grad, param_grad = result
        # Skip entries the worker did not produce, and entries autograd did
        # not ask for (their buffer is None).
        if grads_input is not None and input_grad is not None:
            grads_input[i] = torch.from_numpy(input_grad)
        if grads_params is not None and param_grad is not None:
            grads_params[i] = torch.from_numpy(param_grad)

    @staticmethod
    def forward(
        ctx,
        input,
        params,
        process,
        epsilon,
        thread_context,
        sample_rate=24000,
    ):
        """Apply processor to a batch of tensors using given parameters.

        Args:
            input (Tensor): Audio with shape: batch x 2 x samples
            params (Tensor): Processor parameters with shape: batch x params
            process (function): Function that will apply processing. It is
                only stored on ``ctx``; this class does not call it.
            epsilon (float): Perturbation strength for SPSA computation.
            thread_context: Execution backend; see the class docstring for
                the attributes that are read from it.
            sample_rate (int): Forwarded to the processor with each item.

        Returns:
            output (Tensor): Processed audio with same shape as input.
        """
        ctx.save_for_backward(input, params)
        ctx.epsilon = epsilon
        ctx.process = process
        ctx.thread_context = thread_context

        batch_size = input.shape[0]
        z = torch.empty_like(input)

        if thread_context.parallel:
            # Dispatch every item before collecting any result so the
            # workers can run concurrently.
            for i in range(batch_size):
                request = SPSAFunction._forward_request(
                    i, input, params, sample_rate
                )
                thread_context.procs[i][1].send(("forward", request))
            for i in range(batch_size):
                z[i] = torch.from_numpy(thread_context.procs[i][1].recv())
        else:
            for i in range(batch_size):
                request = SPSAFunction._forward_request(
                    i, input, params, sample_rate
                )
                z[i] = torch.from_numpy(
                    thread_context.static_forward(thread_context.dsp, request)
                )

        return z

    @staticmethod
    def backward(ctx, grad_output):
        """Estimate gradients using SPSA.

        Args:
            grad_output (Tensor): Gradient with respect to the forward output.

        Returns:
            tuple: ``(grads_input, grads_params, None, None, None, None)``.
            The first two entries are ``None`` when autograd does not need
            the corresponding gradient; the remaining entries correspond to
            the non-tensor forward arguments.
        """
        input, params = ctx.saved_tensors
        epsilon = ctx.epsilon
        thread_context = ctx.thread_context
        needs_input_grad = ctx.needs_input_grad[0]
        needs_param_grad = ctx.needs_input_grad[1]
        batch_size = input.shape[0]

        # Allocate buffers only for the gradients autograd actually needs.
        grads_input = torch.empty_like(input) if needs_input_grad else None
        grads_params = torch.empty_like(params) if needs_param_grad else None

        if thread_context.parallel:
            for i in range(batch_size):
                request = SPSAFunction._backward_request(
                    i,
                    input,
                    params,
                    needs_input_grad,
                    needs_param_grad,
                    grad_output,
                    epsilon,
                )
                thread_context.procs[i][1].send(("backward", request))
            # Wait for output
            for i in range(batch_size):
                result = thread_context.procs[i][1].recv()
                SPSAFunction._store_grads(i, result, grads_input, grads_params)
        else:
            for i in range(batch_size):
                request = SPSAFunction._backward_request(
                    i,
                    input,
                    params,
                    needs_input_grad,
                    needs_param_grad,
                    grad_output,
                    epsilon,
                )
                result = thread_context.static_backward(
                    thread_context.dsp, request
                )
                SPSAFunction._store_grads(i, result, grads_input, grads_params)

        return grads_input, grads_params, None, None, None, None