Spaces:
Sleeping
Sleeping
| from Imports import * | |
| from Configuration import * | |
# Character-level vocabulary: lowercase letters, space and basic punctuation.
VOCAB = [*"abcdefghijklmnopqrstuvwxyz .,!?-'\""]
PAD_TOKEN = '<PAD>'
EOS_TOKEN = '<EOS>'

# Special tokens are listed first, so PAD gets id 0 and EOS gets id 1.
vocab_list = [PAD_TOKEN, EOS_TOKEN, *VOCAB]
char2id = dict(zip(vocab_list, range(len(vocab_list))))
id2char = dict(enumerate(vocab_list))
VOCAB_SIZE = len(char2id)
PAD_ID = char2id[PAD_TOKEN]
EOS_ID = char2id[EOS_TOKEN]

# TensorFlow lookup table token -> id; keys missing from the table
# resolve to PAD_ID.
keys_tensor = tf.constant(list(char2id))
values_tensor = tf.constant(list(char2id.values()), dtype=tf.int32)
char_to_id_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
    default_value=PAD_ID,
)

# Inverse table id -> token; ids missing from the table resolve to '?'.
keys_tensor2 = tf.constant(list(id2char), dtype=tf.int32)
values_tensor2 = tf.constant(list(id2char.values()))
id_to_char_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor2, values_tensor2),
    default_value='?',
)

print(VOCAB_SIZE)
# Precomputed STFT constants used by jax_mel.

# Hann analysis window with one coefficient per FFT sample.
hann_window = jnp.array(np.hanning(FFT_LENGTH))

# Number of samples padded onto each side of the waveform before framing.
# The original note gives 384 here, which only holds when
# FFT_LENGTH - FRAME_STEP == 768 -- TODO confirm against Configuration.
pad_amt = (FFT_LENGTH - FRAME_STEP) // 2

# librosa returns the filterbank as (n_mels, 1 + n_fft // 2); it is transposed
# so a magnitude spectrum can be right-multiplied by it.
mel_filterbank = jnp.array(
    librosa.filters.mel(
        sr=SAMPLE_RATE,
        n_fft=FFT_LENGTH,
        n_mels=NUM_MEL_BINS,
        fmin=LOWER_EDGE_HERTZ,
        fmax=UPPER_EDGE_HERTZ,
    ).T
)
def jax_mel(waveform):
    """Compute a batched log-mel spectrogram from raw audio.

    Args:
        waveform: Array whose last axis has size 1; that axis is squeezed
            away. The two-axis pad spec below means the result must be
            2-D, presumably (batch, samples) -- TODO confirm with callers.

    Returns:
        Array with a leading batch axis, then frames, then one column per
        mel_filterbank column. The original author annotated this as
        (B, MAX_MEL_LEN, 80).
    """
    # Reflect-pad only the time axis; the batch axis is left untouched.
    padded = jnp.pad(
        jnp.squeeze(waveform, axis=-1),
        ((0, 0), (pad_amt, pad_amt)),
        mode='reflect',
    )

    n_samples = padded.shape[-1]
    n_frames = (n_samples - FFT_LENGTH) // FRAME_STEP + 1

    # gather_idx[f, k] is the sample index of tap k within frame f.
    taps = jnp.arange(FFT_LENGTH)[None, :]
    starts = (jnp.arange(n_frames) * FRAME_STEP)[:, None]
    gather_idx = starts + taps

    def _log_mel_single(signal):
        # Slice the signal into overlapping frames and apply the window.
        windowed = signal[gather_idx] * hann_window
        magnitude = jnp.abs(jnp.fft.rfft(windowed, n=FFT_LENGTH))
        mel_energy = jnp.dot(magnitude, mel_filterbank)
        # Floor at 1e-5 so the log stays finite for silent frames.
        return jnp.log(jnp.maximum(mel_energy, 1e-5))

    # Map the per-example transform over the batch axis.
    return jax.vmap(_log_mel_single)(padded)
# ── text → ids (works inside tf.data) ────────────────────────────────────────
def text_to_ids_tf(text):
    """Encode a string as character ids followed by EOS_ID.

    The text is stripped of surrounding whitespace and lowercased before
    being split into characters.

    Args:
        text: A Python string or a scalar string tensor (eager or symbolic).

    Returns:
        For scalar input, a 1-D int32 tensor holding one id per character
        with EOS_ID appended.
    """
    # tf.convert_to_tensor passes existing tensors through unchanged, whereas
    # tf.constant rejects the symbolic tensors a tf.data map function receives.
    text = tf.convert_to_tensor(text)
    text = tf.strings.lower(tf.strings.strip(text))
    chars = tf.strings.unicode_split(text, 'UTF-8')
    # Characters absent from the table come back as its default value.
    ids = char_to_id_table.lookup(chars)
    # Spaces are kept as tokens; the author's space filter is disabled.
    eos = tf.constant([EOS_ID], dtype=tf.int32)
    return tf.concat([ids, eos], axis=0)
# ── ids → text (for verification) ────────────────────────────────────────────
def ids_to_text_tf(ids):
    """Decode token ids into a single string.

    Each id is looked up in id_to_char_table and the resulting strings are
    joined with no separator. Ids for the special tokens decode to their
    literal token text, and ids missing from the table decode to the
    table's default value.

    Args:
        ids: Integer ids as a Python sequence, array, or tensor.

    Returns:
        A scalar string tensor.
    """
    # convert_to_tensor (rather than tf.constant) also accepts symbolic
    # tensors, so this can be called from graph code as well as eagerly.
    ids = tf.cast(tf.convert_to_tensor(ids), tf.int32)
    chars = id_to_char_table.lookup(ids)
    return tf.strings.reduce_join(chars)