Yuchan committed on
Commit
133d2fa
·
verified ·
1 Parent(s): 1d975a6

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +17 -5
AlphaS2S.py CHANGED
@@ -4,12 +4,21 @@ from tensorflow.keras import layers, Model
4
  class SwiGLU(layers.Layer):
5
  def __init__(self, d_model, d_ff):
6
  super().__init__()
7
- self.proj = layers.Dense(d_ff*2)
8
  self.out = layers.Dense(d_model)
9
  def call(self, x):
10
  x_proj = self.proj(x)
11
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
12
  return self.out(x_val * tf.nn.silu(x_gate))
 
 
 
 
 
 
 
 
 
13
 
14
  class EncoderBlock(layers.Layer):
15
  def __init__(self, d_model, num_heads, dff, dropout=0.1):
@@ -35,12 +44,14 @@ class LoU(layers.Layer):
35
  self.Q = layers.Dense(d_model, dtype='float32')
36
  self.K = layers.Dense(d_model, dtype='float32')
37
  self.V = layers.Dense(d_model, dtype='float32')
38
- self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
39
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
40
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
41
 
42
  self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
43
 
 
 
 
44
  def _ema_over_time(self, score, alpha_dynamic):
45
  seq = tf.transpose(score, perm=[1, 0, 2])
46
  alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
@@ -60,7 +71,7 @@ class LoU(layers.Layer):
60
  ema = tf.transpose(ema_seq, perm=[1, 0, 2])
61
  return ema
62
 
63
- def call(self, x):
64
  x_f32 = tf.cast(x, tf.float32)
65
  residual = x_f32
66
  x_f32 = self.norm1(x)
@@ -83,8 +94,9 @@ class LoU(layers.Layer):
83
  score_norm = score_ema / denom
84
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
85
  x_comb = score_clipped * V
86
- out = self.proj(x_comb)
87
- out = self.norm(out + residual)
 
88
  return tf.cast(out, x.dtype)
89
 
90
  class Transformer(tf.keras.Model):
 
4
class SwiGLU(layers.Layer):
    """SwiGLU feed-forward layer: out = W_o(silu(x W_g) * (x W_v)).

    Uses a single fused projection of width 2*d_ff that is split into a
    value half and a gate half, so each branch has the stated hidden
    width d_ff. Projecting to only d_ff and then splitting in two (as
    the previous revision did) silently halves the hidden width and
    raises in `tf.split` whenever d_ff is odd.

    Args:
        d_model: output width restored by the final projection.
        d_ff: hidden width of each SwiGLU branch.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Fused value+gate projection: 2*d_ff so each split half is d_ff.
        self.proj = layers.Dense(d_ff * 2)
        self.out = layers.Dense(d_model)

    def call(self, x):
        x_proj = self.proj(x)
        # Split the last axis into the value and gate branches (d_ff each).
        x_val, x_gate = tf.split(x_proj, 2, axis=-1)
        # Gate the value branch with SiLU(gate), then project back to d_model.
        return self.out(x_val * tf.nn.silu(x_gate))
13
+
14
class CrossBlock(layers.Layer):
    """Learned convex blend of two streams.

    Computes a per-position scalar gate a = sigmoid(Dense(x)) in [0, 1]
    and returns a * x + (1 - a) * z, i.e. the gate decides how much of
    `x` to keep versus how much of `z` to mix in. The gate is a
    function of `x` only.
    """

    def __init__(self):
        super().__init__()
        # Scalar sigmoid gate computed from the first stream.
        self.alpha = layers.Dense(1, activation='sigmoid', dtype='float32')

    def call(self, x, z):
        gate = self.alpha(x)
        # Interpolate: gate -> keep x, (1 - gate) -> take z.
        blended = gate * x + (1.0 - gate) * z
        return blended
22
 
23
  class EncoderBlock(layers.Layer):
24
  def __init__(self, d_model, num_heads, dff, dropout=0.1):
 
44
  self.Q = layers.Dense(d_model, dtype='float32')
45
  self.K = layers.Dense(d_model, dtype='float32')
46
  self.V = layers.Dense(d_model, dtype='float32')
 
47
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
48
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
49
 
50
  self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
51
 
52
+ self.cross = CrossBlock()
53
+ self.glu = SwiGLU(d_model, 512)
54
+
55
  def _ema_over_time(self, score, alpha_dynamic):
56
  seq = tf.transpose(score, perm=[1, 0, 2])
57
  alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
 
71
  ema = tf.transpose(ema_seq, perm=[1, 0, 2])
72
  return ema
73
 
74
+ def call(self, x, z):
75
  x_f32 = tf.cast(x, tf.float32)
76
  residual = x_f32
77
  x_f32 = self.norm1(x)
 
94
  score_norm = score_ema / denom
95
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
96
  x_comb = score_clipped * V
97
+ out = self.norm(x_comb + residual)
98
+ out = self.cross(out, z)
99
+ out = self.glu(out)
100
  return tf.cast(out, x.dtype)
101
 
102
  class Transformer(tf.keras.Model):