File size: 6,801 Bytes
bcb3d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Feature computation for YAMNet."""

import numpy as np
import tensorflow as tf


def waveform_to_log_mel_spectrogram_patches(waveform, params):
  """Compute log mel spectrogram patches of a 1-D waveform.

  Args:
    waveform: 1-D waveform tensor with shape [<# samples>].
    params: Hyperparameter object; reads sample_rate, stft_window_seconds,
      stft_hop_seconds, tflite_compatible, mel_bands, mel_min_hz, mel_max_hz,
      log_offset, patch_window_seconds and patch_hop_seconds.

  Returns:
    A (log_mel_spectrogram, features) tuple where log_mel_spectrogram has
    shape [<# STFT frames>, params.mel_bands] and features has shape
    [<# patches>, <# STFT frames in a patch>, params.mel_bands].
  """
  with tf.name_scope('log_mel_features'):
    # waveform has shape [<# samples>]

    # Convert waveform into spectrogram using a Short-Time Fourier Transform.
    # Note that tf.signal.stft() uses a periodic Hann window by default.
    window_length_samples = int(
        round(params.sample_rate * params.stft_window_seconds))
    hop_length_samples = int(
        round(params.sample_rate * params.stft_hop_seconds))
    # Smallest power of two >= window_length_samples. Integer bit_length() is
    # exact, whereas ceil(log(n) / log(2)) can overshoot by one power of two
    # when floating-point rounding nudges an exact power just above its
    # integer exponent.
    fft_length = 2 ** (window_length_samples - 1).bit_length()
    num_spectrogram_bins = fft_length // 2 + 1
    if params.tflite_compatible:
      magnitude_spectrogram = _tflite_stft_magnitude(
          signal=waveform,
          frame_length=window_length_samples,
          frame_step=hop_length_samples,
          fft_length=fft_length)
    else:
      magnitude_spectrogram = tf.abs(tf.signal.stft(
          signals=waveform,
          frame_length=window_length_samples,
          frame_step=hop_length_samples,
          fft_length=fft_length))
    # magnitude_spectrogram has shape [<# STFT frames>, num_spectrogram_bins]

    # Convert spectrogram into log mel spectrogram.
    linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=params.mel_bands,
        num_spectrogram_bins=num_spectrogram_bins,
        sample_rate=params.sample_rate,
        lower_edge_hertz=params.mel_min_hz,
        upper_edge_hertz=params.mel_max_hz)
    mel_spectrogram = tf.matmul(
        magnitude_spectrogram, linear_to_mel_weight_matrix)
    # Adding params.log_offset avoids log(0) for empty mel bins (assuming
    # log_offset > 0).
    log_mel_spectrogram = tf.math.log(mel_spectrogram + params.log_offset)
    # log_mel_spectrogram has shape [<# STFT frames>, params.mel_bands]

    # Frame the log mel spectrogram into patches (the input examples). Only
    # complete patches are emitted, so if there is less than
    # params.patch_window_seconds of waveform then nothing is emitted (to
    # avoid this, zero-pad the waveform first, e.g. with pad_waveform()).
    # The spectrogram has one frame per STFT hop, so its frame rate is
    # sample_rate / hop_length_samples.
    spectrogram_sample_rate = params.sample_rate / hop_length_samples
    patch_window_length_samples = int(
        round(spectrogram_sample_rate * params.patch_window_seconds))
    patch_hop_length_samples = int(
        round(spectrogram_sample_rate * params.patch_hop_seconds))
    features = tf.signal.frame(
        signal=log_mel_spectrogram,
        frame_length=patch_window_length_samples,
        frame_step=patch_hop_length_samples,
        axis=0)
    # features has shape
    # [<# patches>, <# STFT frames in a patch>, params.mel_bands]

    return log_mel_spectrogram, features


def pad_waveform(waveform, params):
  """Pads waveform with silence if needed to get an integral number of patches.

  Args:
    waveform: 1-D waveform tensor; zeros are appended along axis 0.
    params: Hyperparameter object; reads sample_rate, patch_window_seconds,
      patch_hop_seconds, stft_window_seconds and stft_hop_seconds.

  Returns:
    `waveform` zero-padded at the end so that it holds at least one patch and
    the samples beyond the first patch form a whole number of patch hops.
  """
  # In order to produce one patch of log mel spectrogram input to YAMNet, we
  # need at least one patch window length of waveform plus enough extra samples
  # to complete the final STFT analysis window.
  min_waveform_seconds = (
      params.patch_window_seconds +
      params.stft_window_seconds - params.stft_hop_seconds)
  # Round (rather than truncate) seconds -> samples, matching how
  # waveform_to_log_mel_spectrogram_patches converts durations, so that float
  # error such as 15599.999... does not drop a sample.
  min_num_samples = int(round(min_waveform_seconds * params.sample_rate))
  num_samples = tf.shape(waveform)[0]
  num_padding_samples = tf.maximum(0, min_num_samples - num_samples)

  # In addition, there might be enough waveform for one or more additional
  # patches formed by hopping forward. If there are more samples than one patch,
  # round up to an integral number of hops.
  num_samples = tf.maximum(num_samples, min_num_samples)
  num_samples_after_first_patch = num_samples - min_num_samples
  hop_samples = int(round(params.patch_hop_seconds * params.sample_rate))
  # Exact integer ceiling division. A float32 ratio loses precision once the
  # sample count exceeds 2**24 (about 17 minutes at 16 kHz) and can undercount
  # the hops, yielding a negative padding that tf.pad rejects.
  num_hops_after_first_patch = (
      (num_samples_after_first_patch + hop_samples - 1) // hop_samples)
  num_padding_samples += (
      hop_samples * num_hops_after_first_patch - num_samples_after_first_patch)

  padded_waveform = tf.pad(waveform, [[0, num_padding_samples]],
                           mode='CONSTANT', constant_values=0.0)
  return padded_waveform


def _tflite_stft_magnitude(signal, frame_length, frame_step, fft_length):
  """TF-Lite-compatible version of tf.abs(tf.signal.stft()).

  Computes the STFT magnitude with explicit framing, a periodic Hann window and
  matrix-multiply DFTs instead of calling tf.signal.stft().

  Args:
    signal: 1-D waveform tensor.
    frame_length: Samples per analysis frame (Python int).
    frame_step: Samples between the starts of consecutive frames.
    fft_length: DFT length (Python int); must be >= frame_length.

  Returns:
    Tensor of shape [<# frames>, fft_length // 2 + 1] holding the magnitude of
    each frame's non-negative-frequency DFT bins.
  """
  def _hann_window():
    """Periodic Hann window of frame_length samples, shaped [1, frame_length]."""
    # w[k] = 0.5 - 0.5 * cos(2 * pi * k / frame_length), k = 0..frame_length-1.
    # Integer indices are used because a float-step
    # np.arange(0, 1.0, 1.0 / frame_length) can return frame_length + 1 points
    # (e.g. for frame_length=49, since 1.0 / (1.0 / 49) > 49), which would make
    # the reshape below fail.
    phase = 2.0 * np.pi * np.arange(frame_length) / frame_length
    window = (0.5 - 0.5 * np.cos(phase)).astype(np.float32)
    return tf.reshape(tf.constant(window, name='hann_window'),
                      [1, frame_length])

  def _dft_matrix(dft_length):
    """Calculate the full DFT matrix in NumPy."""
    # See https://en.wikipedia.org/wiki/DFT_matrix
    # This uses exp(+2*pi*i*j*k/N), the conjugate of the usual forward-DFT
    # kernel. That only flips the sign of the imaginary part, which does not
    # affect the magnitude returned by the enclosing function.
    omega = (0 + 1j) * 2.0 * np.pi / float(dft_length)
    # Don't include 1/sqrt(N) scaling, tf.signal.rfft doesn't apply it.
    return np.exp(omega * np.outer(np.arange(dft_length), np.arange(dft_length)))

  def _rdft(framed_signal, fft_length):
    """Implement real-input Discrete Fourier Transform by matmul."""
    # We are right-multiplying by the DFT matrix, and we are keeping only the
    # first half ("positive frequencies").  So discard the second half of rows,
    # but transpose the array for right-multiplication.  The DFT matrix is
    # symmetric, so we could have done it more directly, but this reflects our
    # intention better.
    complex_dft_matrix_kept_values = _dft_matrix(fft_length)[:(
        fft_length // 2 + 1), :].transpose()
    real_dft_matrix = tf.constant(
        np.real(complex_dft_matrix_kept_values).astype(np.float32),
        name='real_dft_matrix')
    imag_dft_matrix = tf.constant(
        np.imag(complex_dft_matrix_kept_values).astype(np.float32),
        name='imaginary_dft_matrix')
    signal_frame_length = tf.shape(framed_signal)[-1]
    # Center each frame in the fft_length buffer. Relative to padding only at
    # the end (as tf.signal.stft does), this is a circular shift, which changes
    # the phase of each bin but not its magnitude.
    half_pad = (fft_length - signal_frame_length) // 2
    padded_frames = tf.pad(
        framed_signal,
        [
            # Don't add any padding in the frame dimension.
            [0, 0],
            # Pad before and after the signal within each frame.
            [half_pad, fft_length - signal_frame_length - half_pad]
        ],
        mode='CONSTANT',
        constant_values=0.0)
    real_stft = tf.matmul(padded_frames, real_dft_matrix)
    imag_stft = tf.matmul(padded_frames, imag_dft_matrix)
    return real_stft, imag_stft

  def _complex_abs(real, imag):
    """Elementwise magnitude sqrt(real**2 + imag**2)."""
    return tf.sqrt(tf.add(real * real, imag * imag))

  framed_signal = tf.signal.frame(signal, frame_length, frame_step)
  windowed_signal = framed_signal * _hann_window()
  real_stft, imag_stft = _rdft(windowed_signal, fft_length)
  stft_magnitude = _complex_abs(real_stft, imag_stft)
  return stft_magnitude