File size: 6,801 Bytes
bcb3d72 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
"""Feature computation for YAMNet."""
import numpy as np
import tensorflow as tf
def waveform_to_log_mel_spectrogram_patches(waveform, params):
    """Compute log mel spectrogram patches of a 1-D waveform.

    Args:
        waveform: 1-D waveform tensor, shape [<# samples>].
        params: Parameter object; this function reads ``sample_rate``,
            ``stft_window_seconds``, ``stft_hop_seconds``, ``tflite_compatible``,
            ``mel_bands``, ``mel_min_hz``, ``mel_max_hz``, ``log_offset``,
            ``patch_window_seconds`` and ``patch_hop_seconds``.

    Returns:
        A tuple ``(log_mel_spectrogram, features)``:
          * ``log_mel_spectrogram`` has shape [<# STFT frames>, params.mel_bands].
          * ``features`` has shape
            [<# patches>, <# STFT frames in a patch>, params.mel_bands].
        Only complete patches are emitted, so a waveform shorter than one patch
        window yields no patches (zero-pad beforehand to avoid this).
    """
    with tf.name_scope('log_mel_features'):
        # waveform has shape [<# samples>]
        # Convert waveform into spectrogram using a Short-Time Fourier Transform.
        # Note that tf.signal.stft() uses a periodic Hann window by default.
        window_length_samples = int(
            round(params.sample_rate * params.stft_window_seconds))
        hop_length_samples = int(
            round(params.sample_rate * params.stft_hop_seconds))
        # Smallest power of two >= window_length_samples. Integer bit arithmetic
        # is exact, whereas ceil(log(n) / log(2)) can overshoot by a whole power
        # of two when floating-point rounding lands just above an integer.
        fft_length = 1 << (window_length_samples - 1).bit_length()
        num_spectrogram_bins = fft_length // 2 + 1
        if params.tflite_compatible:
            magnitude_spectrogram = _tflite_stft_magnitude(
                signal=waveform,
                frame_length=window_length_samples,
                frame_step=hop_length_samples,
                fft_length=fft_length)
        else:
            magnitude_spectrogram = tf.abs(tf.signal.stft(
                signals=waveform,
                frame_length=window_length_samples,
                frame_step=hop_length_samples,
                fft_length=fft_length))
        # magnitude_spectrogram has shape [<# STFT frames>, num_spectrogram_bins]

        # Convert spectrogram into log mel spectrogram.
        linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=params.mel_bands,
            num_spectrogram_bins=num_spectrogram_bins,
            sample_rate=params.sample_rate,
            lower_edge_hertz=params.mel_min_hz,
            upper_edge_hertz=params.mel_max_hz)
        mel_spectrogram = tf.matmul(
            magnitude_spectrogram, linear_to_mel_weight_matrix)
        # log_offset keeps the log finite for silent (all-zero) bins.
        log_mel_spectrogram = tf.math.log(mel_spectrogram + params.log_offset)
        # log_mel_spectrogram has shape [<# STFT frames>, params.mel_bands]

        # Frame the spectrogram into patches (the input examples). Only complete
        # frames are emitted, so if there is less than
        # params.patch_window_seconds of waveform then nothing is emitted
        # (to avoid this, zero-pad before processing).
        # The spectrogram advances one frame per STFT hop, so its "sample rate"
        # is frames per second; patch lengths below are counted in frames.
        spectrogram_sample_rate = params.sample_rate / hop_length_samples
        patch_window_length_frames = int(
            round(spectrogram_sample_rate * params.patch_window_seconds))
        patch_hop_length_frames = int(
            round(spectrogram_sample_rate * params.patch_hop_seconds))
        features = tf.signal.frame(
            signal=log_mel_spectrogram,
            frame_length=patch_window_length_frames,
            frame_step=patch_hop_length_frames,
            axis=0)
        # features has shape [<# patches>, <# STFT frames in a patch>, params.mel_bands]
        return log_mel_spectrogram, features
def pad_waveform(waveform, params):
    """Pads waveform with silence if needed to get an integral number of patches.

    Args:
        waveform: 1-D waveform tensor, shape [<# samples>].
        params: Parameter object; this function reads ``patch_window_seconds``,
            ``patch_hop_seconds``, ``stft_window_seconds``, ``stft_hop_seconds``
            and ``sample_rate``.

    Returns:
        ``waveform`` with zeros appended at the end, so that it is at least
        long enough for one patch, and any samples beyond the first patch
        cover a whole number of patch hops.
    """
    # In order to produce one patch of log mel spectrogram input to YAMNet, we
    # need at least one patch window length of waveform plus enough extra samples
    # to complete the final STFT analysis window.
    min_waveform_seconds = (
        params.patch_window_seconds +
        params.stft_window_seconds - params.stft_hop_seconds)
    # Round (rather than truncate) the seconds->samples conversions, as
    # waveform_to_log_mel_spectrogram_patches does. A truncating cast turns a
    # product such as 15599.999999999998 into 15599, which can leave the
    # waveform one sample short of completing the final STFT window.
    min_num_samples = int(round(min_waveform_seconds * params.sample_rate))
    hop_samples = int(round(params.patch_hop_seconds * params.sample_rate))

    num_samples = tf.shape(waveform)[0]
    num_padding_samples = tf.maximum(0, min_num_samples - num_samples)

    # In addition, there might be enough waveform for one or more additional
    # patches formed by hopping forward. If there are more samples than one patch,
    # round up to an integral number of hops.
    num_samples = tf.maximum(num_samples, min_num_samples)
    num_samples_after_first_patch = num_samples - min_num_samples
    # Exact integer ceiling division. A float32 ceil cannot represent every
    # sample count above 2**24 and could miscount hops for long waveforms.
    num_hops_after_first_patch = (
        (num_samples_after_first_patch + hop_samples - 1) // hop_samples)
    num_padding_samples += (
        hop_samples * num_hops_after_first_patch - num_samples_after_first_patch)

    padded_waveform = tf.pad(waveform, [[0, num_padding_samples]],
                             mode='CONSTANT', constant_values=0.0)
    return padded_waveform
def _tflite_stft_magnitude(signal, frame_length, frame_step, fft_length):
    """TF-Lite-compatible version of tf.abs(tf.signal.stft()).

    The STFT is built from explicit framing, Hann windowing and matrix
    multiplications against precomputed DFT matrices, instead of an FFT op.

    Args:
        signal: 1-D waveform tensor, shape [<# samples>].
        frame_length: Window length in samples.
        frame_step: Hop between successive frames, in samples.
        fft_length: DFT size; each frame is zero-padded up to this length
            (assumed >= frame_length).

    Returns:
        Magnitude spectrogram of shape [<# frames>, fft_length // 2 + 1].
    """

    def _hann_window():
        # Periodic Hann window: w[n] = 0.5 - 0.5 * cos(2*pi*n / N), n = 0..N-1.
        # Deriving n from an integer arange guarantees exactly frame_length
        # points; np.arange(0, 1.0, 1.0 / N) can produce N + 1 points through
        # float rounding (e.g. N = 49), which would break the reshape below.
        window = 0.5 - 0.5 * np.cos(
            2.0 * np.pi * np.arange(frame_length) / frame_length)
        return tf.reshape(
            tf.constant(window.astype(np.float32), name='hann_window'),
            [1, frame_length])

    def _dft_matrix(dft_length):
        """Return the [dft_length, dft_length // 2 + 1] real-input DFT matrix.

        Entry [n, k] is exp(-2j*pi*n*k / dft_length), the forward-DFT kernel, so
        right-multiplying a frame by this matrix yields its non-negative
        frequency bins. See https://en.wikipedia.org/wiki/DFT_matrix
        """
        n = np.arange(dft_length)
        k = np.arange(dft_length // 2 + 1)
        # Reducing n*k modulo dft_length is exact in integers and keeps the
        # exponent small, which avoids precision loss in np.exp for large sizes.
        phase_index = np.outer(n, k) % dft_length
        # Don't include 1/sqrt(N) scaling, tf.signal.rfft doesn't apply it.
        return np.exp(-2j * np.pi * phase_index / dft_length)

    def _rdft(framed_signal, fft_length):
        """Implement real-input Discrete Fourier Transform by matmul."""
        dft_matrix = _dft_matrix(fft_length)
        real_dft_matrix = tf.constant(
            np.real(dft_matrix).astype(np.float32),
            name='real_dft_matrix')
        imag_dft_matrix = tf.constant(
            np.imag(dft_matrix).astype(np.float32),
            name='imaginary_dft_matrix')
        signal_frame_length = tf.shape(framed_signal)[-1]
        # Zero-pad each frame to fft_length, centering it. Compared with padding
        # only at the end, this is a circular shift of the frame, which changes
        # the phase of each bin but not its magnitude.
        half_pad = (fft_length - signal_frame_length) // 2
        padded_frames = tf.pad(
            framed_signal,
            [
                # Don't add any padding in the frame dimension.
                [0, 0],
                # Pad before and after the signal within each frame.
                [half_pad, fft_length - signal_frame_length - half_pad]
            ],
            mode='CONSTANT',
            constant_values=0.0)
        real_stft = tf.matmul(padded_frames, real_dft_matrix)
        imag_stft = tf.matmul(padded_frames, imag_dft_matrix)
        return real_stft, imag_stft

    def _complex_abs(real, imag):
        return tf.sqrt(tf.add(real * real, imag * imag))

    framed_signal = tf.signal.frame(signal, frame_length, frame_step)
    windowed_signal = framed_signal * _hann_window()
    real_stft, imag_stft = _rdft(windowed_signal, fft_length)
    stft_magnitude = _complex_abs(real_stft, imag_stft)
    return stft_magnitude
|