File size: 2,916 Bytes
c39b616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from Imports import *
from Configuration import *

# Character-level vocabulary: two special tokens followed by the plain
# characters the tokenizer recognises. Ids are positions in `vocab_list`.
PAD_TOKEN = '<PAD>'
EOS_TOKEN = '<EOS>'
VOCAB = list("abcdefghijklmnopqrstuvwxyz .,!?-'\"")
vocab_list = [PAD_TOKEN, EOS_TOKEN, *VOCAB]

# Build id -> char first, then invert it; both dicts keep vocabulary order.
id2char = dict(enumerate(vocab_list))
char2id = {char: idx for idx, char in id2char.items()}

VOCAB_SIZE = len(char2id)
PAD_ID = char2id[PAD_TOKEN]
EOS_ID = char2id[EOS_TOKEN]

# Forward table: character -> id. Characters that are not in the vocabulary
# resolve to PAD_ID.
keys_tensor = tf.constant(list(char2id))
values_tensor = tf.constant(list(char2id.values()), dtype=tf.int32)
char_to_id_table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=keys_tensor,
        values=values_tensor,
    ),
    default_value=PAD_ID,
)

# Reverse table: id -> character. Ids that are not in the vocabulary
# resolve to '?'.
keys_tensor2 = tf.constant(list(id2char), dtype=tf.int32)
values_tensor2 = tf.constant(list(id2char.values()))
id_to_char_table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=keys_tensor2,
        values=values_tensor2,
    ),
    default_value='?',
)

# Emit the vocabulary size once when the module is imported.
print(VOCAB_SIZE)

# Precomputed STFT / mel constants used by jax_mel.

# Hann window of FFT_LENGTH samples (np.hanning yields the symmetric form),
# multiplied into every frame before the FFT.
hann_window = jnp.array(np.hanning(FFT_LENGTH))

# Number of samples reflect-padded onto each side of the waveform before
# framing. The original note gives 384 for this value; that depends on
# FFT_LENGTH and FRAME_STEP in Configuration -- TODO confirm.
pad_amt = (FFT_LENGTH - FRAME_STEP) // 2

# librosa returns the filterbank as (n_mels, 1 + n_fft // 2); it is stored
# transposed so spectrogram frames can be right-multiplied by it.
mel_filterbank = jnp.array(
    librosa.filters.mel(
        sr=SAMPLE_RATE, n_fft=FFT_LENGTH, n_mels=NUM_MEL_BINS,
        fmin=LOWER_EDGE_HERTZ, fmax=UPPER_EDGE_HERTZ
    ).T
)

@jax.jit
def jax_mel(waveform):
    """Compute a log-mel spectrogram for a batch of waveforms.

    Args:
        waveform: Array whose trailing axis has size 1 and which is 2-D once
            that axis is squeezed, i.e. (batch, samples, 1).

    Returns:
        Array of shape (batch, num_frames, mel_filterbank.shape[1]) holding
        the natural log of the mel-projected STFT magnitudes, floored at
        1e-5 before the log. The original author annotates this as
        (B, MAX_MEL_LEN, 80), which presumably holds for the project's
        configuration -- TODO confirm.
    """
    # Drop the channel axis and reflect-pad both ends of the time axis.
    signal = jnp.squeeze(waveform, axis=-1)
    signal = jnp.pad(signal, ((0, 0), (pad_amt, pad_amt)), mode='reflect')

    # Gather indices for every frame: row i covers samples
    # [i * FRAME_STEP, i * FRAME_STEP + FFT_LENGTH).
    n_samples = signal.shape[-1]
    n_frames = (n_samples - FFT_LENGTH) // FRAME_STEP + 1
    starts = jnp.arange(n_frames) * FRAME_STEP
    offsets = jnp.arange(FFT_LENGTH)
    gather_idx = starts[:, None] + offsets[None, :]

    def _log_mel(single):
        # Window each frame, take the magnitude spectrum, project onto mels.
        windowed = single[gather_idx] * hann_window
        magnitude = jnp.abs(jnp.fft.rfft(windowed, n=FFT_LENGTH))
        mel_energy = jnp.dot(magnitude, mel_filterbank)
        return jnp.log(jnp.maximum(mel_energy, 1e-5))

    # Apply the per-utterance transform across the batch axis.
    return jax.vmap(_log_mel)(signal)


# ── text → ids (works inside tf.data) ────────────────────────────────────────
def text_to_ids_tf(text):
    """Encode a string as a 1-D int32 tensor of character ids ending in EOS.

    The text is stripped of surrounding whitespace, lower-cased, split into
    Unicode characters and mapped through `char_to_id_table`; characters
    missing from the table map to its default value (PAD_ID).

    Args:
        text: A Python string or a scalar `tf.string` tensor (eager or
            symbolic, e.g. inside a `tf.data` map function).

    Returns:
        1-D `tf.int32` tensor: one id per character followed by EOS_ID.
    """
    # tf.convert_to_tensor passes existing tensors through unchanged, whereas
    # tf.constant rejects the symbolic tensors handed to tf.data map functions.
    text = tf.convert_to_tensor(text)
    text = tf.strings.lower(tf.strings.strip(text))
    chars = tf.strings.unicode_split(text, 'UTF-8')
    ids = char_to_id_table.lookup(chars)
    # Spaces are kept as ordinary tokens. To drop them instead, enable:
    # ids = tf.boolean_mask(ids, tf.not_equal(ids, char2id[' ']))
    eos = tf.constant([EOS_ID], dtype=tf.int32)
    return tf.concat([ids, eos], axis=0)

# ── ids → text (for verification) ────────────────────────────────────────────
def ids_to_text_tf(ids):
    """Decode character ids back into a single string tensor.

    Every id is looked up in `id_to_char_table` and the resulting strings
    are concatenated. Ids missing from the table decode to '?'. Special ids
    are not filtered: PAD and EOS decode to their literal token strings.

    Args:
        ids: Integer ids as a Python sequence, NumPy array or tensor
            (eager or symbolic).

    Returns:
        Scalar `tf.string` tensor with all decoded characters joined.
    """
    # tf.convert_to_tensor accepts tensors as well as Python/NumPy values;
    # tf.constant would raise on a symbolic tensor.
    ids = tf.cast(tf.convert_to_tensor(ids), tf.int32)
    chars = id_to_char_table.lookup(ids)
    return tf.strings.reduce_join(chars)