Update Whisper.py
Browse files- Whisper.py +13 -3
Whisper.py
CHANGED
|
@@ -150,7 +150,7 @@ class AudioEncoder:
|
|
| 150 |
return x
|
| 151 |
|
| 152 |
|
| 153 |
-
class TextDecoder:
|
| 154 |
def __init__(
|
| 155 |
self,
|
| 156 |
n_vocab: int,
|
|
@@ -160,8 +160,18 @@ class TextDecoder:
|
|
| 160 |
n_layer: int,
|
| 161 |
dtype = tf.float16,
|
| 162 |
):
|
| 163 |
-
self.token_embedding =
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
self.blocks = [
|
| 167 |
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
|
|
|
|
| 150 |
return x
|
| 151 |
|
| 152 |
|
| 153 |
+
class TextDecoder(tf.keras.layers.Layer):
|
| 154 |
def __init__(
|
| 155 |
self,
|
| 156 |
n_vocab: int,
|
|
|
|
| 160 |
n_layer: int,
|
| 161 |
dtype = tf.float16,
|
| 162 |
):
|
| 163 |
+
self.token_embedding = self.add_weight(
|
| 164 |
+
name='token_embedding',
|
| 165 |
+
shape=[self.n_vocab, self.n_state],
|
| 166 |
+
initializer=tf.keras.initializers.RandomNormal(stddev=0.02), # 设定标准差 stddev
|
| 167 |
+
trainable=True
|
| 168 |
+
)
|
| 169 |
+
self.positional_embedding = self.add_weight(
|
| 170 |
+
name='positional_embedding',
|
| 171 |
+
shape=[self.n_ctx, self.n_state],
|
| 172 |
+
initializer=tf.keras.initializers.Zeros(), # 初始化为全零
|
| 173 |
+
trainable=True
|
| 174 |
+
)
|
| 175 |
|
| 176 |
self.blocks = [
|
| 177 |
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
|