Upload Whisper.py
Browse files- Whisper.py +263 -0
Whisper.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
import gzip
from typing import Union

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv1D, Dense, LayerNormalization, ZeroPadding1D
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ModelDimensions:
    """Declares the size hyper-parameters read by ``Whisper.__init__``.

    NOTE(review): only annotations are declared here -- no values are
    assigned and there is no ``__init__`` or ``@dataclass`` decorator, so a
    fresh instance does not carry these attributes until a caller sets
    them. Confirm how callers populate the fields.
    """

    # Forwarded to AudioEncoder as ``n_mels``.
    n_mels: int
    # Forwarded to AudioEncoder as ``n_ctx`` (rows of its positional embedding).
    n_audio_ctx: int
    # Forwarded to AudioEncoder as ``n_state`` (layer width).
    n_audio_state: int
    # Forwarded to AudioEncoder as ``n_head``.
    n_audio_head: int
    # Forwarded to AudioEncoder as ``n_layer`` (number of residual blocks).
    n_audio_layer: int
    # Forwarded to TextDecoder as ``n_vocab`` (rows of the token embedding).
    n_vocab: int
    # Forwarded to TextDecoder as ``n_ctx`` (positional embedding / mask size).
    n_text_ctx: int
    # Forwarded to TextDecoder as ``n_state`` (layer width).
    n_text_state: int
    # Forwarded to TextDecoder as ``n_head``.
    n_text_head: int
    # Forwarded to TextDecoder as ``n_layer`` (number of residual blocks).
    n_text_layer: int
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def sinusoids(length, channels, max_timescale=10000):
    """Build a sinusoidal positional-embedding table.

    The result has ``length`` rows and ``channels`` columns: the first half
    of the columns holds sines and the second half cosines of the position
    index scaled by inverse timescales running from 1 down to
    ``1 / max_timescale``.
    """
    assert channels % 2 == 0
    half_channels = channels // 2
    # Constant step, in log space, between consecutive inverse timescales.
    log_step = np.log(max_timescale) / (half_channels - 1)
    inv_timescales = tf.math.exp(-log_step * np.arange(half_channels))
    positions = np.arange(length)[:, np.newaxis]
    scaled_time = positions * inv_timescales[np.newaxis, :]
    sin_part = tf.math.sin(scaled_time)
    cos_part = tf.math.cos(scaled_time)
    return tf.concat([sin_part, cos_part], axis=1)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class LayerNorm:
    """Layer normalization evaluated in float32, returned in the input dtype."""

    def __init__(self, n_state):
        """Create the underlying Keras layer.

        ``n_state`` is accepted for interface compatibility only; the Keras
        layer sizes its weights from the input on first call.
        """
        # The layer must be *instantiated*: storing the bare class made every
        # call construct a new layer object instead of normalizing ``x``.
        # dtype="float32" keeps the computation in float32 even under a
        # mixed-precision policy, matching the explicit cast in __call__.
        # epsilon=1e-5 matches the PyTorch ``nn.LayerNorm`` default used by
        # the reference Whisper implementation (Keras defaults to 1e-3).
        self.layer_norm = LayerNormalization(epsilon=1e-5, dtype="float32")

    def __call__(self, x):
        """Normalize ``x`` over its last axis and cast back to ``x.dtype``."""
        normalized = self.layer_norm(tf.cast(x, "float32"))
        return tf.cast(normalized, x.dtype)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MultiHeadAttention:
    """Multi-head scaled dot-product attention usable as self- or cross-attention."""

    def __init__(self, n_state: int, n_head: int):
        """Create the query/key/value/output projections, each ``n_state`` wide."""
        self.n_head = n_head
        self.query = Dense(n_state)
        self.key = Dense(n_state, use_bias=False)
        self.value = Dense(n_state)
        self.out = Dense(n_state)

    def __call__(
        self,
        x,
        xa=None,
        mask=None,
        kv_cache=None,
    ):
        """Attend from ``x`` to itself, or to ``xa`` when it is given.

        Args:
            x: query-side input.
            xa: optional key/value-side input; ``None`` selects self-attention.
            mask: optional additive mask applied to the attention scores.
            kv_cache: optional ``(k, v)`` pair from a previous call.

        Returns:
            ``(output, (k, v), qk)``: the projected attention output, the
            keys/values that were used (so the caller can cache them), and
            the float32 pre-softmax attention scores.
        """
        q = self.query(x)

        if xa is None:
            # Self-attention: project the new positions and append them to
            # whatever was cached from earlier calls.
            k = self.key(x)
            v = self.value(x)
            if kv_cache is not None:
                k = tf.concat([kv_cache[0], k], axis=1)
                v = tf.concat([kv_cache[1], v], axis=1)
        elif kv_cache is None:
            # Cross-attention, first call: project ``xa``.
            k = self.key(xa)
            v = self.value(xa)
        else:
            # Cross-attention with a cache: reuse the stored projections.
            k, v = kv_cache

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), (k, v), qk

    def _split_heads(self, t, perm):
        """Reshape ``(batch, seq, n_state)`` into heads, then permute axes by ``perm``."""
        dims = tf.shape(t)
        t = tf.reshape(t, [dims[0], dims[1], self.n_head, -1])
        return tf.transpose(t, perm)

    def qkv_attention(self, q, k, v, mask=None):
        """Compute attention for already-projected ``q``, ``k`` and ``v``.

        Returns ``(out, qk)`` where ``out`` has the same shape as ``q`` and
        ``qk`` holds the float32 pre-softmax scores.
        """
        # The feature size comes from the static shape (it is needed as a
        # Python number for the scale); batch and sequence sizes are read
        # dynamically so an unknown batch dimension does not break reshapes.
        n_state = q.shape[-1]
        q_dims = tf.shape(q)
        n_batch, n_ctx = q_dims[0], q_dims[1]

        # The scale is applied to both q and k, so each gets the fourth root.
        scale = (n_state // self.n_head) ** -0.25
        # tf.Tensor has no .reshape()/.transpose() methods by default, so the
        # functional tf.reshape / tf.transpose forms are used.
        q = self._split_heads(q, (0, 2, 1, 3)) * scale
        k = self._split_heads(k, (0, 2, 3, 1)) * scale
        v = self._split_heads(v, (0, 2, 1, 3))

        qk = tf.matmul(q, k)
        if mask is not None:
            # TensorFlow does not promote dtypes implicitly; align the mask
            # with the scores before adding it.
            qk = qk + tf.cast(mask[:n_ctx, :n_ctx], qk.dtype)
        qk = tf.cast(qk, tf.float32)

        # Softmax runs in float32; the weights return to the working dtype.
        w = tf.cast(tf.nn.softmax(qk, axis=-1), q.dtype)
        out = tf.transpose(tf.matmul(w, v), (0, 2, 1, 3))
        out = tf.reshape(out, [n_batch, n_ctx, n_state])
        return out, qk
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class ResidualAttentionBlock:
    """Pre-norm residual block: self-attention, optional cross-attention, MLP."""

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        """Create the sub-layers; cross-attention exists only when requested."""
        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp1 = Dense(n_mlp)
        self.mlp2 = Dense(n_state)
        self.mlp_ln = LayerNorm(n_state)

    def __call__(self, x, xa=None, mask=None, kv_cache=None):
        """Apply the block to ``x``.

        Args:
            x: block input.
            xa: input for cross-attention; only used when the block was
                built with ``cross_attention=True``.
            mask: additive mask forwarded to the self-attention.
            kv_cache: optional ``(self_kv, cross_kv)`` pair from a previous call.

        Returns:
            ``(x, (kv, cross_kv), cross_qk)``: the block output, the updated
            caches, and the cross-attention scores (``None`` when the block
            has no cross-attention).
        """
        kv, cross_kv = kv_cache if kv_cache else (None, None)

        y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
        # Cast every residual branch to x.dtype, as the MLP branch does, so a
        # layer computing in a different precision cannot break the addition.
        x = x + tf.cast(y, x.dtype)

        cross_qk = None
        if self.cross_attn is not None:
            y, cross_kv, cross_qk = self.cross_attn(
                self.cross_attn_ln(x), xa, kv_cache=cross_kv
            )
            x = x + tf.cast(y, x.dtype)

        # The original closed tf.cast's parenthesis too late, which passed
        # x.dtype to self.mlp2 and left tf.cast without a target dtype.
        hidden = tf.nn.gelu(self.mlp1(self.mlp_ln(x)))
        x = x + tf.cast(self.mlp2(hidden), x.dtype)
        return x, (kv, cross_kv), cross_qk
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class AudioEncoder:
    """Convolutional stem plus a stack of residual attention blocks."""

    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        dtype = tf.float16,
    ):
        """Create the encoder layers.

        Args:
            n_mels: accepted for interface compatibility; not read here (the
                convolution sizes its kernel from the input on first call).
            n_ctx: number of rows of the positional embedding.
            n_state: width of the convolutions and attention blocks.
            n_head: attention heads per block.
            n_layer: number of residual attention blocks.
            dtype: dtype the positional embedding is stored in.
        """
        self.zeropadding1d1 = ZeroPadding1D(padding=1)
        # The Keras class is ``Conv1D``; ``Conv1d`` does not exist.
        self.conv1 = Conv1D(filters=n_state, kernel_size=3)
        self.zeropadding1d2 = ZeroPadding1D(padding=1)
        self.conv2 = Conv1D(filters=n_state, kernel_size=3, strides=2)
        self._positional_embedding = tf.cast(sinusoids(n_ctx, n_state), dtype)

        self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        self.ln_post = LayerNorm(n_state)

    def __call__(self, x):
        """Encode ``x`` and return the features in the dtype of ``x``.

        NOTE(review): the padding and convolution layers use Keras defaults,
        so ``x`` is presumably channels-last, i.e. (batch, frames, n_mels) --
        confirm against callers.
        """
        x = self.zeropadding1d1(x)
        x = tf.cast(tf.nn.gelu(self.conv1(x)), x.dtype)
        x = self.zeropadding1d2(x)
        x = tf.cast(tf.nn.gelu(self.conv2(x)), x.dtype)
        assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
        # The embedding is stored in the constructor's dtype; align it with
        # the activations, since TensorFlow does not promote dtypes implicitly.
        x = x + tf.cast(self._positional_embedding, x.dtype)

        for block in self.blocks:
            x, _, _ = block(x)

        x = self.ln_post(x)
        return x
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class TextDecoder:
    """Token decoder: embeddings, cross-attending residual blocks, tied output projection."""

    def __init__(
        self,
        n_vocab: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        dtype = tf.float16,
    ):
        """Create the embeddings, the decoder blocks and the causal mask."""
        self.token_embedding = tf.Variable(tf.random.normal([n_vocab, n_state]))
        self.positional_embedding = tf.Variable(tf.zeros([n_ctx, n_state]))

        self.blocks = [
            ResidualAttentionBlock(n_state, n_head, cross_attention=True)
            for _ in range(n_layer)
        ]
        self.ln = LayerNorm(n_state)

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        causal = tf.fill((n_ctx, n_ctx), float("-inf"))
        causal = tf.linalg.band_part(causal, 0, -1)
        causal = tf.linalg.set_diag(causal, tf.zeros(n_ctx))
        self._mask = tf.cast(causal, dtype)

    def __call__(self, x, xa, kv_cache=None):
        """
        x : shape = (batch_size, <= n_ctx)
            the text tokens
        xa : shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on

        Returns ``(logits, kv_cache, cross_qk)``; the last two hold one entry
        per decoder block. A ``kv_cache`` list passed in is updated in place.
        """
        # Positions already held in the cache shift the positional embedding.
        offset = 0
        if kv_cache:
            offset = kv_cache[0][0][0].shape[1]

        n_tokens = x.shape[-1]
        token_emb = tf.gather(self.token_embedding, x)
        pos_emb = self.positional_embedding[offset : offset + n_tokens]
        x = token_emb + pos_emb

        n_blocks = len(self.blocks)
        if kv_cache is None:
            kv_cache = [None] * n_blocks
        cross_qk = [None] * n_blocks
        for idx, block in enumerate(self.blocks):
            x, kv_cache[idx], cross_qk[idx] = block(
                x, xa, mask=self._mask, kv_cache=kv_cache[idx]
            )

        x = self.ln(x)
        # The output projection reuses the token embedding matrix.
        logits = tf.matmul(x, tf.transpose(self.token_embedding))
        return logits, kv_cache, cross_qk
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class Whisper(Model):
    """Audio encoder and text decoder wired together from a ``ModelDimensions``."""

    def __init__(self, dims: ModelDimensions, dtype = tf.float16):
        """Build the encoder and decoder and pick the default alignment heads."""
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            dims.n_mels,
            dims.n_audio_ctx,
            dims.n_audio_state,
            dims.n_audio_head,
            dims.n_audio_layer,
            dtype,
        )
        self.decoder = TextDecoder(
            dims.n_vocab,
            dims.n_text_ctx,
            dims.n_text_state,
            dims.n_text_head,
            dims.n_text_layer,
            dtype,
        )
        # By default every head in the last half of the decoder layers is
        # used for time alignment; `set_alignment_heads()` selects a
        # specific set instead.
        default_heads = np.zeros((dims.n_text_layer, dims.n_text_head), dtype=bool)
        default_heads[dims.n_text_layer // 2 :] = True
        head_indices = tf.cast(tf.where(default_heads != 0), dtype=tf.int32)
        self.alignment_heads = tf.transpose(head_indices)

    def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
        """Replace ``alignment_heads``.

        An ``np.ndarray`` is converted to a tensor and stored as given.
        ``bytes`` are base85-decoded, gunzipped, read as a boolean
        (n_text_layer, n_text_head) mask and stored as the transposed int32
        indices of its true entries. Any other type raises ``ValueError``.
        """
        if isinstance(dump, np.ndarray):
            self.alignment_heads = tf.convert_to_tensor(dump)
            return

        if not isinstance(dump, bytes):
            raise ValueError(
                f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
                " alignment_head information"
            )

        raw = gzip.decompress(base64.b85decode(dump))
        # np.frombuffer returns a read-only view; copy to own the data.
        flags = np.frombuffer(raw, dtype=bool).copy()
        mask = flags.reshape(self.dims.n_text_layer, self.dims.n_text_head)
        head_indices = tf.cast(tf.where(mask != 0), dtype=tf.int32)
        self.alignment_heads = tf.transpose(head_indices)

    def embed_audio(self, mel):
        """Run only the audio encoder on ``mel``."""
        return self.encoder(mel)

    def logits(self, tokens, audio_features):
        """Run only the decoder on already-encoded audio and return its logits."""
        decoded = self.decoder(tokens, audio_features)
        return decoded[0]

    def forward_with_cross_qk(self, mel, tokens):
        """Run encoder and decoder; return the logits and per-block cross-attention scores."""
        audio_features = self.encoder(mel)
        logits, _, cross_qk = self.decoder(tokens, audio_features)
        return logits, cross_qk

    def __call__(self, mel, tokens):
        """Run encoder and decoder and return only the logits."""
        audio_features = self.encoder(mel)
        return self.decoder(tokens, audio_features)[0]

    @property
    def is_multilingual(self):
        """True when the vocabulary has at least 51865 entries."""
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        """Language count derived from the vocabulary size."""
        multilingual_flag = int(self.is_multilingual)
        return self.dims.n_vocab - 51765 - multilingual_flag
|