from Imports import *
from Configuration import *

# ── vocabulary ───────────────────────────────────────────────────────────────
# Character-level vocabulary. Ids 0 and 1 are reserved for PAD and EOS, so the
# first regular character ('a') gets id 2 and space gets id 28.
VOCAB = list("abcdefghijklmnopqrstuvwxyz .,!?-'\"")
# The special tokens must be distinct strings that cannot occur as a single
# character of input text. If they were equal (e.g. both ''), the dict
# comprehension below would collapse them onto one id (PAD_ID == EOS_ID).
PAD_TOKEN = '<pad>'
EOS_TOKEN = '<eos>'
vocab_list = [PAD_TOKEN, EOS_TOKEN] + VOCAB
char2id = {c: i for i, c in enumerate(vocab_list)}
if len(char2id) != len(vocab_list):
    raise ValueError("vocab_list contains duplicate tokens; ids would collide")
id2char = {i: c for c, i in char2id.items()}
VOCAB_SIZE = len(char2id)
PAD_ID = char2id[PAD_TOKEN]
EOS_ID = char2id[EOS_TOKEN]

# tf.lookup tables so the char <-> id mappings can run inside tf.data graphs.
# Characters missing from the vocabulary are mapped to PAD_ID.
keys_tensor = tf.constant(list(char2id.keys()))
values_tensor = tf.constant(list(char2id.values()), dtype=tf.int32)
char_to_id_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
    default_value=PAD_ID,
)
# Ids missing from the vocabulary decode to '?'.
keys_tensor2 = tf.constant(list(id2char.keys()), dtype=tf.int32)
values_tensor2 = tf.constant(list(id2char.values()))
id_to_char_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor2, values_tensor2),
    default_value='?',
)

print(VOCAB_SIZE)

# ── log-mel features (JAX) ───────────────────────────────────────────────────
# Analysis window, computed once at import time.
# NOTE(review): np.hanning is the symmetric Hann window; STFT code usually uses
# the periodic variant (np.hanning(N + 1)[:-1]) — confirm which is intended.
hann_window = jnp.array(np.hanning(FFT_LENGTH))
# Reflect-padding applied to each end of the waveform: the extra
# (FFT_LENGTH - FRAME_STEP) samples are split evenly between the two sides.
pad_amt = (FFT_LENGTH - FRAME_STEP) // 2
# librosa returns (n_mels, n_fft // 2 + 1); transpose so |rfft| frames of
# length FFT_LENGTH // 2 + 1 can be right-multiplied into NUM_MEL_BINS bins.
mel_filterbank = jnp.array(
    librosa.filters.mel(
        sr=SAMPLE_RATE,
        n_fft=FFT_LENGTH,
        n_mels=NUM_MEL_BINS,
        fmin=LOWER_EDGE_HERTZ,
        fmax=UPPER_EDGE_HERTZ,
    ).T
)


@jax.jit
def jax_mel(waveform):
    """Compute log-mel spectrograms for a batch of waveforms.

    Args:
        waveform: array of shape (B, T, 1); the trailing size-1 axis is
            squeezed away before framing.

    Returns:
        Array of shape (B, num_frames, NUM_MEL_BINS) holding
        ``log(max(mel, 1e-5))``, where ``mel`` is the mel-weighted STFT
        *magnitude* (not power) of Hann-windowed frames of FFT_LENGTH samples
        taken every FRAME_STEP samples of the reflect-padded signal.
    """
    wav = jnp.squeeze(waveform, axis=-1)  # (B, T)
    wav = jnp.pad(wav, ((0, 0), (pad_amt, pad_amt)), mode='reflect')
    T = wav.shape[-1]
    num_frames = (T - FFT_LENGTH) // FRAME_STEP + 1
    # Gather indices of shape (num_frames, FFT_LENGTH). Shapes are static under
    # jit, so these are built once per input shape at trace time.
    frame_starts = jnp.arange(num_frames) * FRAME_STEP
    frame_idx = frame_starts[:, None] + jnp.arange(FFT_LENGTH)[None, :]

    def frame_and_stft(wav_single):
        # wav_single: (T,) -> (num_frames, NUM_MEL_BINS)
        frames = wav_single[frame_idx] * hann_window
        stft = jnp.fft.rfft(frames, n=FFT_LENGTH)
        mel = jnp.dot(jnp.abs(stft), mel_filterbank)
        return jnp.log(jnp.maximum(mel, 1e-5))  # floor avoids log(0)

    return jax.vmap(frame_and_stft)(wav)


# ── text → ids (works inside tf.data) ────────────────────────────────────────
def text_to_ids_tf(text):
    """Encode text as character ids terminated by EOS_ID.

    Accepts a Python str or a scalar tf.string tensor, including the symbolic
    tensors passed to a ``tf.data.Dataset.map`` function. The text is stripped
    of surrounding whitespace, lowercased, split into UTF-8 characters and
    looked up in ``char_to_id_table``. Characters outside the vocabulary become
    PAD_ID. Spaces are kept as ordinary tokens.

    Returns:
        1-D int32 tensor of length ``num_chars + 1`` (the last entry is EOS_ID).
    """
    # tf.convert_to_tensor, unlike tf.constant, lets symbolic (graph-mode)
    # tensors pass through, which is what tf.data hands to map functions.
    text = tf.strings.lower(tf.strings.strip(tf.convert_to_tensor(text)))
    chars = tf.strings.unicode_split(text, 'UTF-8')
    ids = char_to_id_table.lookup(chars)
    eos = tf.constant([EOS_ID], dtype=tf.int32)
    return tf.concat([ids, eos], axis=0)


# ── ids → text (for verification) ────────────────────────────────────────────
def ids_to_text_tf(ids):
    """Decode character ids back into a single scalar tf.string.

    Accepts a list, array or tensor of integer ids (cast to int32). Ids missing
    from the vocabulary decode to '?'. PAD_ID and EOS_ID decode to PAD_TOKEN and
    EOS_TOKEN. All elements are joined into one string, because reduce_join is
    called without an axis.
    """
    ids = tf.cast(tf.convert_to_tensor(ids), tf.int32)
    chars = id_to_char_table.lookup(ids)
    return tf.strings.reduce_join(chars)