import sys
import time
import numpy as np
from keras.activations import relu
from scipy.io.wavfile import read, write
from keras.models import Model
from keras.layers import Convolution1D, AtrousConvolution1D, Flatten, Dense, \
    Input, Lambda, merge
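
# NOTE: this script is written against the Keras 1.x functional API
# (AtrousConvolution1D, merge(), border_mode=, and fit_generator with
# samples_per_epoch/nb_epoch) and is assumed to run on a Keras 1.x install.
# Keras 2+ renamed these to Conv1D(dilation_rate=...), Add()/Multiply(),
# padding=, and steps_per_epoch/epochs, so it will not run there unmodified.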


def wavenetBlock(n_atrous_filters, atrous_filter_size, atrous_rate,
                 n_conv_filters, conv_filter_size):
    # Returns a closure that applies one WaveNet-style block to its input:
    # a gated dilated convolution (tanh branch * sigmoid branch), a 1x1 skip
    # convolution, and a residual connection back to the block input.
    # (n_conv_filters and conv_filter_size are currently unused.)
    def f(input_):
        residual = input_
        tanh_out = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
                                       atrous_rate=atrous_rate,
                                       border_mode='same',
                                       activation='tanh')(input_)
        sigmoid_out = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
                                          atrous_rate=atrous_rate,
                                          border_mode='same',
                                          activation='sigmoid')(input_)
        # Gated activation unit: element-wise product of the two branches.
        merged = merge([tanh_out, sigmoid_out], mode='mul')
        skip_out = Convolution1D(1, 1, activation='relu',
                                 border_mode='same')(merged)
        out = merge([skip_out, residual], mode='sum')
        return out, skip_out
    return f
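

# The generator network below stacks five of these blocks with dilation rates
# 2, 4, 8, 16 and 32, sums their skip outputs, and post-processes the sum with
# a ReLU and two 1x1 convolutions, loosely following the skip-connection
# scheme of the WaveNet paper.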
def get_basic_generative_model(input_size):
    # Each input is one frame of input_size consecutive samples, one channel.
    input_ = Input(shape=(input_size, 1))
    l1a, l1b = wavenetBlock(10, 5, 2, 1, 3)(input_)
    l2a, l2b = wavenetBlock(1, 2, 4, 1, 3)(l1a)
    l3a, l3b = wavenetBlock(1, 2, 8, 1, 3)(l2a)
    l4a, l4b = wavenetBlock(1, 2, 16, 1, 3)(l3a)
    l5a, l5b = wavenetBlock(1, 2, 32, 1, 3)(l4a)
    # Sum the skip connections from every block.
    l6 = merge([l1b, l2b, l3b, l4b, l5b], mode='sum')
    l7 = Lambda(relu)(l6)
    l8 = Convolution1D(1, 1, activation='relu')(l7)
    l9 = Convolution1D(1, 1)(l8)
    l10 = Flatten()(l9)
    # Single tanh output: the next sample, predicted in the [-1, 1] range.
    l11 = Dense(1, activation='tanh')(l10)
    model = Model(input=input_, output=l11)
    model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])
    model.summary()
    return model
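

# The network above is trained as a plain regressor: every training example is
# a frame of frame_size past samples and the target is the single sample that
# immediately follows it (see frame_generator below), so the tanh output and
# mse loss keep predictions in the same [-1, 1] range as the normalised audio.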


def get_audio(filename):
    sr, audio = read(filename)
    audio = audio.astype(float)
    audio = audio - audio.min()
    audio = audio / (audio.max() - audio.min())
    audio = (audio - 0.5) * 2
    return sr, audio


def frame_generator(sr, audio, frame_size, frame_shift):
    # Endlessly yield (frame, next_sample) pairs: each frame is frame_size
    # consecutive samples and the target is the sample that follows it.
    audio_len = len(audio)
    while 1:
        for i in range(0, audio_len - frame_size - 1, frame_shift):
            frame = audio[i:i+frame_size]
            if len(frame) < frame_size:
                break
            if i + frame_size >= audio_len:
                break
            temp = audio[i + frame_size]
            yield frame.reshape(1, frame_size, 1), temp.reshape(1, 1)


if __name__ == '__main__':
    n_epochs = 20
    frame_size = 2048
    frame_shift = 512
    # Use the first 240 s of the training file and 30 s of the validation file.
    sr_training, training_audio = get_audio('train.wav')
    training_audio = training_audio[:sr_training*240]
    sr_valid, valid_audio = get_audio('validate.wav')
    valid_audio = valid_audio[:sr_valid*30]
    assert sr_training == sr_valid, "Training, validation sample rate mismatch"
    n_training_examples = int((len(training_audio) - frame_size - 1) /
                              float(frame_shift))
    n_validation_examples = int((len(valid_audio) - frame_size - 1) /
                                float(frame_shift))
    model = get_basic_generative_model(frame_size)
    print('Total training examples:', n_training_examples)
    print('Total validation examples:', n_validation_examples)
    model.fit_generator(frame_generator(sr_training, training_audio,
                                        frame_size, frame_shift),
                        samples_per_epoch=n_training_examples,
                        nb_epoch=n_epochs,
                        validation_data=frame_generator(sr_valid, valid_audio,
                                                        frame_size,
                                                        frame_shift),
                        nb_val_samples=n_validation_examples,
                        verbose=1)
    print('Saving model...')
    str_timestamp = str(int(time.time()))
    # Note: the 'models/' (and later 'output/') directories must already exist.
    model.save('models/model_' + str_timestamp + '_' + str(n_epochs) + '.h5')
    print('Generating audio...')
    # Generate three seconds of new audio, one sample at a time, seeded with
    # the first frame of the validation audio.
    new_audio = np.zeros((sr_training * 3))
    curr_sample_idx = 0
    audio_context = valid_audio[:frame_size].copy()
    while curr_sample_idx < new_audio.shape[0]:
        predicted_val = model.predict(audio_context.reshape(1, frame_size,
                                                            1))[0, 0]
        # Store the prediction scaled to 16-bit amplitude for the output file.
        ampl_val_16 = predicted_val * 2**15
        new_audio[curr_sample_idx] = ampl_val_16
        # Slide the context window left by one sample and append the new
        # prediction, keeping the context in the normalised [-1, 1] range.
        audio_context[:-1] = audio_context[1:]
        audio_context[-1] = predicted_val
        pc_str = str(round(100*curr_sample_idx/float(new_audio.shape[0]), 2))
        sys.stdout.write('Percent complete: ' + pc_str + '\r')
        sys.stdout.flush()
        curr_sample_idx += 1
    outfilepath = 'output/reg_generated_' + str_timestamp + '.wav'
    print('Writing generated audio to:', outfilepath)
    write(outfilepath, sr_training, new_audio.astype(np.int16))
    print('\nDone!')
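

# To reuse the trained network later, the saved HDF5 file can be reloaded with
# Keras' load_model; a rough sketch (the file name and some_frame are just
# placeholders for a real checkpoint path and a frame_size-long numpy array):
#
#   from keras.models import load_model
#   model = load_model('models/model_<timestamp>_<epochs>.h5')
#   prediction = model.predict(some_frame.reshape(1, frame_size, 1))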