# WaveNet-style audio generation script using the legacy Keras 1.x API.
| import os | |
| import sys | |
| import time | |
| import numpy as np | |
| from keras.callbacks import Callback | |
| from scipy.io.wavfile import read, write | |
| from keras.models import Model, Sequential | |
| from keras.layers import Convolution1D, AtrousConvolution1D, Flatten, Dense, \ | |
| Input, Lambda, merge, Activation | |
def wavenetBlock(n_atrous_filters, atrous_filter_size, atrous_rate):
    """Build one gated WaveNet block (legacy Keras 1.x functional API).

    Returns a closure that, applied to a tensor, yields
    ``(residual_out, skip_out)``: the gated dilated convolution's 1x1
    projection, with the block input added back on the residual path.
    """
    def f(input_):
        shortcut = input_
        # Gated activation unit: a tanh "filter" branch modulated by a
        # sigmoid "gate" branch, both dilated convolutions over the input.
        filter_out = AtrousConvolution1D(n_atrous_filters,
                                         atrous_filter_size,
                                         atrous_rate=atrous_rate,
                                         border_mode='same',
                                         activation='tanh')(input_)
        gate_out = AtrousConvolution1D(n_atrous_filters,
                                       atrous_filter_size,
                                       atrous_rate=atrous_rate,
                                       border_mode='same',
                                       activation='sigmoid')(input_)
        gated = merge([filter_out, gate_out], mode='mul')
        # 1x1 convolution produces the skip connection; adding the block
        # input gives the residual output fed to the next block.
        skip = Convolution1D(1, 1, activation='relu',
                             border_mode='same')(gated)
        residual_out = merge([skip, shortcut], mode='sum')
        return residual_out, skip
    return f
def get_basic_generative_model(input_size):
    """Assemble and compile the stacked-WaveNet-block classifier.

    Takes ``input_size`` context samples (shape ``(input_size, 1)``) and
    predicts a softmax over 256 amplitude bins for the next sample.
    """
    input_ = Input(shape=(input_size, 1))
    residual, skip = wavenetBlock(64, 2, 2)(input_)
    skips = [skip]
    # 20 further blocks; dilation rate cycles through 2**((i+2) % 9).
    for layer_idx in range(20):
        residual, skip = wavenetBlock(64, 2, 2 ** ((layer_idx + 2) % 9))(residual)
        skips.append(skip)
    # Sum all skip connections, then project down to the 256-way softmax.
    out = merge(skips, mode='sum')
    out = Activation('relu')(out)
    out = Convolution1D(1, 1, activation='relu')(out)
    out = Convolution1D(1, 1)(out)
    out = Flatten()(out)
    out = Dense(256, activation='softmax')(out)
    model = Model(input=input_, output=out)
    model.compile(loss='categorical_crossentropy', optimizer='sgd',
                  metrics=['accuracy'])
    model.summary()
    return model
def get_audio(filename):
    """Read a wav file and return ``(sample_rate, samples)`` scaled to [-1, 1].

    Min-max normalizes the raw samples to [0, 1] and then rescales to
    [-1, 1]. Bug fix: the original divided by ``max - min`` unconditionally,
    which produced NaNs for constant (e.g. all-silent) audio; such input now
    maps to all zeros.
    """
    sr, audio = read(filename)
    audio = audio.astype(float)
    span = audio.max() - audio.min()
    if span == 0:
        # Constant signal: normalization is undefined, return silence.
        return sr, np.zeros_like(audio)
    audio = (audio - audio.min()) / span  # [0, 1]
    audio = (audio - 0.5) * 2             # [-1, 1]
    return sr, audio
def frame_generator(sr, audio, frame_size, frame_shift, minibatch_size=20):
    """Endlessly yield ``(X, y)`` minibatches for next-sample prediction.

    X: ``(minibatch_size, frame_size, 1)`` windows of normalized audio.
    y: ``(minibatch_size, 256)`` one-hot mu-law-companded next samples.
    ``sr`` is accepted for interface compatibility but unused here.
    Loops over the audio forever (Keras generator contract).
    """
    audio_len = len(audio)
    # Perf fix: the original built a fresh 256x256 np.eye(256) (~512 KB)
    # for every single frame; build the one-hot table once.
    one_hot = np.eye(256)
    log1p_mu = np.log(1 + 256)
    X = []
    y = []
    while True:
        for i in range(0, audio_len - frame_size - 1, frame_shift):
            frame = audio[i:i + frame_size]
            if len(frame) < frame_size:
                break
            if i + frame_size >= audio_len:
                break
            temp = audio[i + frame_size]
            # mu-law companding (mu=256) of the next sample into [0, 255].
            target_val = int((np.sign(temp) * (np.log(1 + 256 * abs(temp)) /
                             log1p_mu) + 1) / 2.0 * 255)
            X.append(frame.reshape(frame_size, 1))
            y.append(one_hot[target_val])
            if len(X) == minibatch_size:
                yield np.array(X), np.array(y)
                X = []
                y = []
def get_audio_from_model(model, sr, duration, seed_audio):
    """Autoregressively sample ``duration`` seconds of audio from ``model``.

    model: object whose ``predict`` returns a length-256 distribution for a
        ``(1, frame_size, 1)`` input.
    sr: output sample rate.
    seed_audio: 1-D array of normalized samples in [-1, 1]; mutated in place
        as the rolling context window.
    Returns the generated samples as ``np.int16``.
    """
    print('Generating audio...')
    # Bug fix: the original referenced a global ``frame_size`` defined only
    # in __main__ (NameError when imported); derive it from the seed.
    frame_size = seed_audio.shape[0]
    n_samples = int(sr * duration)  # int(): np.zeros rejects floats
    new_audio = np.zeros(n_samples)
    for curr_sample_idx in range(n_samples):
        distribution = np.array(model.predict(
            seed_audio.reshape(1, frame_size, 1)), dtype=float).reshape(256)
        distribution /= distribution.sum()
        predicted_val = np.random.choice(range(256), p=distribution)
        ampl_val_8 = (((predicted_val) / 255.0) - 0.5) * 2.0
        # Inverse mu-law expansion back to full-scale 16-bit amplitude.
        ampl_val_16 = (np.sign(ampl_val_8) * (1 / 256.0) * ((1 + 256.0) ** abs(
            ampl_val_8) - 1)) * 2 ** 15
        new_audio[curr_sample_idx] = ampl_val_16
        # Bug fixes: (1) shift the window left BEFORE writing the new sample
        # (the original wrote seed_audio[-1] first, so the shift duplicated
        # it); (2) feed back the normalized amplitude, not the int16-scaled
        # one — the model only ever saw [-1, 1] inputs during training.
        seed_audio[:-1] = seed_audio[1:]
        seed_audio[-1] = ampl_val_8
        pc_str = str(round(100 * curr_sample_idx / float(n_samples), 2))
        sys.stdout.write('Percent complete: ' + pc_str + '\r')
        sys.stdout.flush()
    print('Audio generated.')
    return new_audio.astype(np.int16)
class SaveAudioCallback(Callback):
    """Keras callback that renders and saves a short generated clip every
    ``ckpt_freq`` epochs so training progress can be audited by ear."""

    def __init__(self, ckpt_freq, sr, seed_audio):
        super(SaveAudioCallback, self).__init__()
        self.ckpt_freq = ckpt_freq    # render a clip every N epochs
        self.sr = sr                  # sample rate for the saved wav
        self.seed_audio = seed_audio  # context window fed to the generator

    def on_epoch_end(self, epoch, logs=None):
        # Bug fix: original used the mutable default ``logs={}``; Keras
        # always passes ``logs`` anyway, so ``None`` is safe and compatible.
        if (epoch + 1) % self.ckpt_freq == 0:
            ts = str(int(time.time()))
            filepath = os.path.join('output/', 'ckpt_' + ts + '.wav')
            # assumes 'output/' exists — TODO confirm it is created upstream
            audio = get_audio_from_model(self.model, self.sr, 0.5,
                                         self.seed_audio)
            write(filepath, self.sr, audio)
if __name__ == '__main__':
    n_epochs = 2000
    frame_size = 2048   # model context length in samples
    frame_shift = 128   # hop between successive training frames
    sr_training, training_audio = get_audio('train.wav')
    # training_audio = training_audio[:sr_training*1200]
    sr_valid, valid_audio = get_audio('validate.wav')
    # valid_audio = valid_audio[:sr_valid*60]
    assert sr_training == sr_valid, "Training, validation samplerate mismatch"
    n_training_examples = int((len(training_audio) - frame_size - 1) / float(
        frame_shift))
    n_validation_examples = int((len(valid_audio) - frame_size - 1) / float(
        frame_shift))
    model = get_basic_generative_model(frame_size)
    # Consistency fix: the file mixed Python-2 print statements with
    # print() calls; use the function-call form throughout.
    print('Total training examples: ' + str(n_training_examples))
    print('Total validation examples: ' + str(n_validation_examples))
    audio_context = valid_audio[:frame_size]
    save_audio_clbk = SaveAudioCallback(100, sr_training, audio_context)
    validation_data_gen = frame_generator(sr_valid, valid_audio, frame_size,
                                          frame_shift)
    training_data_gen = frame_generator(sr_training, training_audio,
                                        frame_size, frame_shift)
    model.fit_generator(training_data_gen, samples_per_epoch=3000,
                        nb_epoch=n_epochs,
                        validation_data=validation_data_gen,
                        nb_val_samples=500, verbose=1,
                        callbacks=[save_audio_clbk])
    print('Saving model...')
    str_timestamp = str(int(time.time()))
    model.save('models/model_' + str_timestamp + '_' + str(n_epochs) + '.h5')
    print('Generating audio...')
    new_audio = get_audio_from_model(model, sr_training, 2, audio_context)
    outfilepath = 'output/generated_' + str_timestamp + '.wav'
    print('Writing generated audio to: ' + outfilepath)
    write(outfilepath, sr_training, new_audio)
    print('\nDone!')