Yuchan committed
Commit 1d975a6 (verified)
1 parent: cafd528

Update AlphaS2S.py

Files changed (1)
  1. AlphaS2S.py +60 -19
AlphaS2S.py CHANGED
@@ -26,26 +26,67 @@ class EncoderBlock(layers.Layer):
         ffn_out = self.dropout2(self.ffn(out1), training=training)
         return self.norm2(out1 + ffn_out)
 
-class DecoderBlock(layers.Layer):
-    def __init__(self, d_model, num_heads, dff, dropout=0.1):
+class LoU(layers.Layer):
+    def __init__(self, d_model, clip_value=5.0, eps=1e-6):
         super().__init__()
-        self.self_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
-        self.cross_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
-        self.ffn = SwiGLU(d_model, dff)
-        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
-        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
-        self.norm3 = layers.LayerNormalization(epsilon=1e-6)
-        self.dropout1 = layers.Dropout(dropout)
-        self.dropout2 = layers.Dropout(dropout)
-        self.dropout3 = layers.Dropout(dropout)
-    def call(self, x, enc_out, training=False):
-        attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
-        out1 = self.norm1(x + attn1)
-        attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
-        out2 = self.norm2(out1 + attn2)
-        ffn_out = self.dropout3(self.ffn(out2), training=training)
-        return self.norm3(out2 + ffn_out)
+        self.d_model = d_model
+        self.clip_value = float(clip_value)
+        self.eps = float(eps)
+        self.Q = layers.Dense(d_model, dtype='float32')
+        self.K = layers.Dense(d_model, dtype='float32')
+        self.V = layers.Dense(d_model, dtype='float32')
+        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
+        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+        self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+
+        self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
+
+    def _ema_over_time(self, score, alpha_dynamic):
+        seq = tf.transpose(score, perm=[1, 0, 2])
+        alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
+
+        def step(prev_ema, inputs):
+            x_t, alpha_t = inputs
+            new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
+            return new
+
+        init = seq[0]
+        first_alpha = alpha_seq[0]
+        remaining_seq = seq[1:]
+        remaining_alpha = alpha_seq[1:]
+        elems = (remaining_seq, remaining_alpha)
+        ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
+        ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
+        ema = tf.transpose(ema_seq, perm=[1, 0, 2])
+        return ema
+
+    def call(self, x):
+        x_f32 = tf.cast(x, tf.float32)
+        residual = x_f32
+        x_f32 = self.norm1(x)
+
+        q = self.Q(x_f32)
+        k = self.K(x_f32)
+        V = self.V(x_f32)
+        # previous version:
+        # g_q = tf.nn.sigmoid(q)
+        # g_k = tf.nn.sigmoid(k)
+
+        g_q = (tf.nn.tanh(q) + 1.0) / 2.0
+        g_k = (tf.nn.tanh(k) + 1.0) / 2.0
+        score = g_q * g_k
 
+        alpha_dynamic = self.alpha_linear(x_f32)
+        score_ema = self._ema_over_time(score, alpha_dynamic)
+        mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True)
+        denom = tf.maximum(mean_last, self.eps)
+        score_norm = score_ema / denom
+        score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
+        x_comb = score_clipped * V
+        out = self.proj(x_comb)
+        out = self.norm(out + residual)
+        return tf.cast(out, x.dtype)
+
 class Transformer(tf.keras.Model):
     def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, max_len=100, dropout=0.1):
         super().__init__()
@@ -56,7 +97,7 @@ class Transformer(tf.keras.Model):
         self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
         self.dec_pos_embedding = layers.Embedding(max_len, d_model)
         self.enc_layers = [EncoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
-        self.dec_layers = [DecoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
+        self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
         self.final_layer = layers.Dense(target_vocab_size)
     def call(self, inputs, training=False):
         enc_inputs = inputs["enc_inputs"]
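The heart of the new layer is _ema_over_time: per feature channel it runs an exponential moving average along the time axis, ema_t = alpha_t * s_t + (1 - alpha_t) * ema_{t-1} with ema_0 = s_0, where each alpha_t in (0, 1) is predicted from the input by alpha_linear (note first_alpha is computed but never used, so the first step keeps s_0 as-is). Below is a minimal equivalence check against a plain Python loop; a sketch only, assuming TensorFlow 2.x and that AlphaS2S.py is importable as a module:

import tensorflow as tf
from AlphaS2S import LoU  # assumes this file's own imports make it loadable

def ema_loop(score, alpha):
    # Reference EMA with an explicit loop over time: ema_0 = s_0,
    # then ema_t = alpha_t * s_t + (1 - alpha_t) * ema_{t-1}.
    ema = [score[:, 0]]
    for t in range(1, score.shape[1]):
        ema.append(alpha[:, t] * score[:, t] + (1.0 - alpha[:, t]) * ema[-1])
    return tf.stack(ema, axis=1)

score = tf.random.uniform([2, 5, 8])  # [batch, time, features]
alpha = tf.random.uniform([2, 5, 1])  # per-position smoothing factors in (0, 1)
layer = LoU(d_model=8)
gap = tf.reduce_max(tf.abs(layer._ema_over_time(score, alpha) - ema_loop(score, alpha)))
print(float(gap))  # ~0.0 up to float32 rounding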
 
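Two things stand out in the wiring. First, LoU.call takes only x, with no enc_out and no training argument, so the rebuilt decoder stack no longer cross-attends to the encoder output (how Transformer.call now drives dec_layers is outside the visible hunks). Second, every sublayer is pinned to dtype='float32' and the result is cast back to the input dtype, so the EMA and the division by mean_last stay in float32 even under mixed precision. A forward-pass smoke test illustrating that, again a sketch under the same assumptions:

import tensorflow as tf
from AlphaS2S import LoU  # assumes this file's own imports make it loadable

tf.keras.mixed_precision.set_global_policy("mixed_float16")
layer = LoU(d_model=64)
x = tf.random.normal([2, 10, 64], dtype=tf.float16)  # [batch, time, d_model]
y = layer(x)  # internals run in float32, output is cast back to x.dtype
print(y.shape, y.dtype)  # (2, 10, 64) float16

The residual layout is pre-norm: norm1 normalizes the input before the Q/K/V projections, while the un-normalized float32 copy is kept as the residual that self.norm adds back at the end.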