Spaces:
Build error
Build error
| from typing import Dict | |
| from data_generation.data_generation import param_descriptions | |
| import numpy as np | |
| from melody_synth.melody_generator import MelodyGenerator | |
| from melody_synth.random_midi import RandomMidi | |
def decode_label(prediction: np.ndarray,
                 sample_rate: int,
                 n_samples: int,
                 return_params=False,
                 discard_parameters=None):
    """Parse a network prediction array, synthesize the described audio and return it.

    Parameters
    ----------
    prediction: np.ndarray
        The network prediction array. It is read as consecutive segments, one
        per entry of ``param_descriptions`` and in that order; each segment is
        ``len(param_description.values)`` entries long and the position of its
        maximum selects the parameter value. Extra trailing entries are ignored.
    sample_rate: int
        Sample rate of the audio to generate.
    n_samples: int
        Number of samples per wav file.
    return_params: bool
        Whether or not to also return the parameters alongside the signal.
    discard_parameters: list of str, optional
        Parameter names that should be discarded. They are currently
        overwritten with 0 (not with their default value). ``None`` (the
        default) means nothing is discarded.

    Returns
    -------
    np.ndarray or tuple
        The generated signal, or ``(signal, params)`` when ``return_params``
        is true, where ``params`` maps parameter names to decoded values.

    Raises
    ------
    ValueError
        If ``prediction`` has fewer entries than the total encoding length of
        ``param_descriptions``.
    """
    # A None sentinel avoids the shared-mutable-default pitfall of `=[]`.
    if discard_parameters is None:
        discard_parameters = []

    # Validate up front: a short array would otherwise silently argmax a
    # truncated trailing segment (or fail on an empty slice deep in the loop).
    expected_len = sum(len(desc.values) for desc in param_descriptions)
    if len(prediction) < expected_len:
        raise ValueError(
            f"prediction has {len(prediction)} entries but the parameter "
            f"encoding requires {expected_len}")

    params: Dict[str, float] = {}
    index = 0
    for param_description in param_descriptions:
        # Decode this parameter's one-hot segment: the highest-scoring
        # position is taken as the selected value index.
        bits = len(param_description.values)
        hot_index = prediction[index:index + bits].argmax()
        params[param_description.name] = param_description.parameter_value(hot_index).value
        index += bits

    for param_str in discard_parameters:
        # TODO: make this safe and change to the parameter's default value, not just 0
        params[param_str] = 0

    # NOTE(review): n_samples is passed for both trailing constructor
    # arguments, as in the original; confirm against MelodyGenerator's signature.
    synth = MelodyGenerator(sample_rate, n_samples, n_samples)

    # MIDI input is produced by RandomMidi using these fixed strategy names;
    # only the second element of its result (the MIDI itself) is used.
    random_midi = RandomMidi()
    strategy = {"rhythm_strategy": "single_note_rhythm",
                "pitch_strategy": "fixed_pitch",
                "duration_strategy": "fixed_duration",
                }
    _, midi = random_midi(strategy)

    # get_melody's result is converted to a NumPy array via `.numpy()`.
    signal = synth.get_melody(params, midi=midi).numpy()
    if return_params:
        return signal, params
    return signal