Yuchan
commited on
Update AlphaS2S.py
Browse files- AlphaS2S.py +17 -5
AlphaS2S.py
CHANGED
|
@@ -4,12 +4,21 @@ from tensorflow.keras import layers, Model
|
|
| 4 |
class SwiGLU(layers.Layer):
|
| 5 |
def __init__(self, d_model, d_ff):
|
| 6 |
super().__init__()
|
| 7 |
-
self.proj = layers.Dense(d_ff
|
| 8 |
self.out = layers.Dense(d_model)
|
| 9 |
def call(self, x):
|
| 10 |
x_proj = self.proj(x)
|
| 11 |
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
|
| 12 |
return self.out(x_val * tf.nn.silu(x_gate))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
class EncoderBlock(layers.Layer):
|
| 15 |
def __init__(self, d_model, num_heads, dff, dropout=0.1):
|
|
@@ -35,12 +44,14 @@ class LoU(layers.Layer):
|
|
| 35 |
self.Q = layers.Dense(d_model, dtype='float32')
|
| 36 |
self.K = layers.Dense(d_model, dtype='float32')
|
| 37 |
self.V = layers.Dense(d_model, dtype='float32')
|
| 38 |
-
self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
|
| 39 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 40 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 41 |
|
| 42 |
self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
|
| 43 |
|
|
|
|
|
|
|
|
|
|
| 44 |
def _ema_over_time(self, score, alpha_dynamic):
|
| 45 |
seq = tf.transpose(score, perm=[1, 0, 2])
|
| 46 |
alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
|
|
@@ -60,7 +71,7 @@ class LoU(layers.Layer):
|
|
| 60 |
ema = tf.transpose(ema_seq, perm=[1, 0, 2])
|
| 61 |
return ema
|
| 62 |
|
| 63 |
-
def call(self, x):
|
| 64 |
x_f32 = tf.cast(x, tf.float32)
|
| 65 |
residual = x_f32
|
| 66 |
x_f32 = self.norm1(x)
|
|
@@ -83,8 +94,9 @@ class LoU(layers.Layer):
|
|
| 83 |
score_norm = score_ema / denom
|
| 84 |
score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
|
| 85 |
x_comb = score_clipped * V
|
| 86 |
-
out = self.
|
| 87 |
-
out = self.
|
|
|
|
| 88 |
return tf.cast(out, x.dtype)
|
| 89 |
|
| 90 |
class Transformer(tf.keras.Model):
|
|
|
|
| 4 |
class SwiGLU(layers.Layer):
|
| 5 |
def __init__(self, d_model, d_ff):
|
| 6 |
super().__init__()
|
| 7 |
+
self.proj = layers.Dense(d_ff)
|
| 8 |
self.out = layers.Dense(d_model)
|
| 9 |
def call(self, x):
|
| 10 |
x_proj = self.proj(x)
|
| 11 |
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
|
| 12 |
return self.out(x_val * tf.nn.silu(x_gate))
|
| 13 |
+
|
| 14 |
+
class CrossBlock(layers.Layer):
|
| 15 |
+
def __init__(self):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.alpha = layers.Dense(1, activation='sigmoid', dtype='float32')
|
| 18 |
+
def call(self, x, z):
|
| 19 |
+
a = self.alpha(x)
|
| 20 |
+
y = a * x + (1.0 - a) * z
|
| 21 |
+
return y
|
| 22 |
|
| 23 |
class EncoderBlock(layers.Layer):
|
| 24 |
def __init__(self, d_model, num_heads, dff, dropout=0.1):
|
|
|
|
| 44 |
self.Q = layers.Dense(d_model, dtype='float32')
|
| 45 |
self.K = layers.Dense(d_model, dtype='float32')
|
| 46 |
self.V = layers.Dense(d_model, dtype='float32')
|
|
|
|
| 47 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 48 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 49 |
|
| 50 |
self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
|
| 51 |
|
| 52 |
+
self.cross = CrossBlock()
|
| 53 |
+
self.glu = SwiGLU(d_model, 512)
|
| 54 |
+
|
| 55 |
def _ema_over_time(self, score, alpha_dynamic):
|
| 56 |
seq = tf.transpose(score, perm=[1, 0, 2])
|
| 57 |
alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
|
|
|
|
| 71 |
ema = tf.transpose(ema_seq, perm=[1, 0, 2])
|
| 72 |
return ema
|
| 73 |
|
| 74 |
+
def call(self, x, z):
|
| 75 |
x_f32 = tf.cast(x, tf.float32)
|
| 76 |
residual = x_f32
|
| 77 |
x_f32 = self.norm1(x)
|
|
|
|
| 94 |
score_norm = score_ema / denom
|
| 95 |
score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
|
| 96 |
x_comb = score_clipped * V
|
| 97 |
+
out = self.norm(x_comb + residual)
|
| 98 |
+
out = self.cross(out, z)
|
| 99 |
+
out = self.glu(out)
|
| 100 |
return tf.cast(out, x.dtype)
|
| 101 |
|
| 102 |
class Transformer(tf.keras.Model):
|