Spaces:
Sleeping
Sleeping
File size: 2,916 Bytes
c39b616 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | from Imports import *
from Configuration import *
# --- Character vocabulary ----------------------------------------------------
# The two special tokens come first, so PAD_ID == 0 and EOS_ID == 1; the
# literal characters follow in order ('a' -> 2, ..., 'z' -> 27, ' ' -> 28, ...).
VOCAB = list("abcdefghijklmnopqrstuvwxyz .,!?-'\"")
PAD_TOKEN = '<PAD>'
EOS_TOKEN = '<EOS>'
vocab_list = [PAD_TOKEN, EOS_TOKEN, *VOCAB]

# Python-side forward (token -> id) and reverse (id -> token) mappings.
char2id = dict(zip(vocab_list, range(len(vocab_list))))
id2char = {token_id: token for token, token_id in char2id.items()}

VOCAB_SIZE = len(char2id)
PAD_ID, EOS_ID = char2id[PAD_TOKEN], char2id[EOS_TOKEN]

# --- TensorFlow lookup tables (usable on tensors) ----------------------------
# token -> id; any character missing from the vocabulary maps to PAD_ID.
keys_tensor = tf.constant([*char2id])
values_tensor = tf.constant([*char2id.values()], dtype=tf.int32)
char_to_id_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
    default_value=PAD_ID,
)

# id -> token; any id outside the vocabulary decodes to '?'.
keys_tensor2 = tf.constant([*id2char], dtype=tf.int32)
values_tensor2 = tf.constant([*id2char.values()])
id_to_char_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor2, values_tensor2),
    default_value='?',
)

# Reports the vocabulary size when the module is imported.
print(VOCAB_SIZE)
# STFT analysis window, precomputed once. np.hanning returns the *symmetric*
# Hann window of length FFT_LENGTH.
# NOTE(review): librosa / tf.signal default to the periodic Hann window; the
# symmetric one is kept so existing features are unchanged -- confirm intended.
hann_window = jnp.array(np.hanning(FFT_LENGTH))

# Samples of reflect padding added to each side of the waveform before framing.
# Derived from the configuration (e.g. 384 when FFT_LENGTH=1024 and
# FRAME_STEP=256) rather than hard-coded.
pad_amt = (FFT_LENGTH - FRAME_STEP) // 2

# librosa returns the filterbank as (NUM_MEL_BINS, FFT_LENGTH // 2 + 1); it is
# transposed to (FFT_LENGTH // 2 + 1, NUM_MEL_BINS) so a magnitude spectrum can
# be projected onto mel bins with a single matmul.
mel_filterbank = jnp.array(
    librosa.filters.mel(
        sr=SAMPLE_RATE, n_fft=FFT_LENGTH, n_mels=NUM_MEL_BINS,
        fmin=LOWER_EDGE_HERTZ, fmax=UPPER_EDGE_HERTZ,
    ).T
)
@jax.jit
def jax_mel(waveform):
    """Compute a batch of log-mel spectrograms with JAX.

    Args:
        waveform: Batched audio of shape ``(B, T, 1)`` (trailing channel axis)
            or ``(B, T)``.

    Returns:
        Array of shape ``(B, num_frames, NUM_MEL_BINS)`` containing
        ``log(max(mel_magnitude, 1e-5))``, where
        ``num_frames = (T + 2 * pad_amt - FFT_LENGTH) // FRAME_STEP + 1``.
    """
    # Drop the trailing channel axis when present; a plain (B, T) batch is used
    # as-is. ndim is static under jit, so this is resolved at trace time.
    wav = jnp.squeeze(waveform, axis=-1) if waveform.ndim == 3 else waveform

    # Reflect-pad both ends so edge frames are built from mirrored signal
    # rather than zeros.
    wav = jnp.pad(wav, ((0, 0), (pad_amt, pad_amt)), mode='reflect')

    # Gather indices for every analysis frame: row i covers samples
    # [i * FRAME_STEP, i * FRAME_STEP + FFT_LENGTH). Shapes are static under
    # jit, so num_frames is a plain Python int.
    num_samples = wav.shape[-1]
    num_frames = (num_samples - FFT_LENGTH) // FRAME_STEP + 1
    frame_starts = jnp.arange(num_frames) * FRAME_STEP
    frame_idx = frame_starts[:, None] + jnp.arange(FFT_LENGTH)[None, :]

    def frame_and_stft(wav_single):
        # Window each frame, take the magnitude spectrum, project onto mel bins.
        frames = wav_single[frame_idx] * hann_window
        stft = jnp.fft.rfft(frames, n=FFT_LENGTH)
        mel = jnp.dot(jnp.abs(stft), mel_filterbank)
        # Clamp before the log so silent frames give log(1e-5), not -inf.
        return jnp.log(jnp.maximum(mel, 1e-5))

    return jax.vmap(frame_and_stft)(wav)  # (B, num_frames, NUM_MEL_BINS)
# -- text -> ids (works inside tf.data) -----------------------------------------
def text_to_ids_tf(text):
    """Encode a string as a 1-D int32 tensor of vocabulary ids ending in EOS_ID.

    The text is stripped of surrounding whitespace, lower-cased (ASCII only,
    with tf.strings.lower's default encoding), split into UTF-8 characters and
    mapped through ``char_to_id_table``. Characters missing from the vocabulary
    map to the table's default id (PAD_ID). Spaces are kept as ordinary tokens.

    Args:
        text: A scalar string -- a Python ``str``/``bytes`` or a scalar
            ``tf.string`` tensor, eager or symbolic (e.g. inside
            ``tf.data.Dataset.map``).

    Returns:
        int32 tensor of shape ``(num_chars + 1,)``: the character ids followed
        by EOS_ID.
    """
    # tf.constant() rejects symbolic tensors in graph mode (as inside
    # tf.data.Dataset.map); convert_to_tensor accepts Python values and tensors.
    text = tf.convert_to_tensor(text, dtype=tf.string)
    text = tf.strings.lower(tf.strings.strip(text))
    chars = tf.strings.unicode_split(text, 'UTF-8')
    ids = char_to_id_table.lookup(chars)
    # Terminate the sequence with the end-of-sequence id.
    eos = tf.constant([EOS_ID], dtype=tf.int32)
    return tf.concat([ids, eos], axis=0)
# -- ids -> text (for verification) ---------------------------------------------
def ids_to_text_tf(ids):
    """Decode vocabulary ids back into one concatenated tf.string scalar.

    Every id is decoded, including the special tokens, which appear literally
    as '<PAD>' / '<EOS>' in the output; ids outside the vocabulary decode
    to '?'.

    Args:
        ids: Python list, NumPy array or tensor (eager or symbolic) of integer
            ids; non-int32 integer dtypes are cast to int32.

    Returns:
        Scalar ``tf.string`` tensor with all decoded tokens joined together.
    """
    # convert_to_tensor (unlike tf.constant) also accepts symbolic tensors,
    # so this works inside tf.function / tf.data as well as eagerly.
    ids = tf.cast(tf.convert_to_tensor(ids), tf.int32)
    chars = id_to_char_table.lookup(ids)
    return tf.strings.reduce_join(chars)