Yuchan
committed on
Update AlphaS2S.py
Browse files- AlphaS2S.py +47 -41
AlphaS2S.py
CHANGED
|
@@ -182,62 +182,68 @@ class SwiGLU(layers.Layer):
|
|
| 182 |
x_proj = self.proj(x)
|
| 183 |
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
|
| 184 |
return self.out(x_val * tf.nn.silu(x_gate))
|
| 185 |
-
|
| 186 |
-
class
|
| 187 |
-
def __init__(self, d_model,
|
| 188 |
super().__init__()
|
| 189 |
self.d_model = d_model
|
| 190 |
-
self.
|
| 191 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
-
#
|
| 194 |
-
|
|
|
|
| 195 |
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
kernel_size=kernel_size,
|
| 201 |
-
padding='same',
|
| 202 |
-
activation='relu',
|
| 203 |
-
dilation_rate=rate, # ๋๋ ์ด์
๋ ์ดํธ ์ ์ฉ
|
| 204 |
-
name=f"dconv_{i+1}_rate_{rate}"
|
| 205 |
-
)
|
| 206 |
-
self.conv_layers.append(conv)
|
| 207 |
|
| 208 |
-
def call(self, x):
|
| 209 |
-
#
|
| 210 |
residual = x
|
|
|
|
|
|
|
| 211 |
|
| 212 |
-
#
|
| 213 |
-
|
| 214 |
|
| 215 |
-
#
|
| 216 |
-
|
|
|
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
|
| 222 |
-
#
|
| 223 |
-
#
|
| 224 |
-
|
|
|
|
|
|
|
| 225 |
|
| 226 |
-
#
|
| 227 |
-
|
|
|
|
|
|
|
| 228 |
|
| 229 |
-
|
|
|
|
| 230 |
|
| 231 |
class CrossBlock(layers.Layer):
|
| 232 |
-
def __init__(self
|
| 233 |
super().__init__()
|
| 234 |
-
# ๐ก ์์ : ์ถ๋ ฅ ์ฐจ์์ 1์์ d_model๋ก ๋ณ๊ฒฝ
|
| 235 |
-
self.alpha = layers.Dense(d_model, activation='sigmoid', dtype='float32')
|
| 236 |
def call(self, x, z):
|
| 237 |
# a์ shape: (Batch, Seq_len, D_model)
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
y =
|
| 241 |
return y
|
| 242 |
|
| 243 |
class LoU(layers.Layer):
|
|
@@ -254,7 +260,7 @@ class LoU(layers.Layer):
|
|
| 254 |
|
| 255 |
self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
|
| 256 |
self.glu = SwiGLU(d_model, d_model)
|
| 257 |
-
self.cross = CrossBlock(
|
| 258 |
|
| 259 |
def _ema_over_time(self, score, alpha_dynamic):
|
| 260 |
seq = tf.transpose(score, perm=[1, 0, 2])
|
|
@@ -320,7 +326,7 @@ class AlphaS2S(tf.keras.Model):
|
|
| 320 |
self.dec_pos_embedding = layers.Embedding(max_len, d_model)
|
| 321 |
|
| 322 |
# EncoderBlock๊ณผ LoU๋ ๊ธฐ์กด ์ฝ๋์ ๋์ผํ ๊ตฌ์กฐ
|
| 323 |
-
self.enc_layers = [
|
| 324 |
self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
|
| 325 |
|
| 326 |
self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
|
|
|
|
| 182 |
x_proj = self.proj(x)
|
| 183 |
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
|
| 184 |
return self.out(x_val * tf.nn.silu(x_gate))
|
| 185 |
+
|
| 186 |
+
class gMLPBlock(layers.Layer):
    """gMLP-style encoder block with a Spatial Gating Unit (SGU).

    Maps (batch, seq_len, d_model) -> (batch, seq_len, d_model) and wraps
    the whole transform in a residual connection.

    NOTE(review): `sgu_proj` is a Dense applied along the sequence axis,
    so inputs must have a time dimension of exactly `seq_len` — confirm
    this holds at every call site (e.g. padded-to-max-len batches).
    """

    def __init__(self, d_model, seq_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len

        # Pre-norm followed by channel expansion to 4 * d_model.
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
        self.dropout = layers.Dropout(dropout)

        # Spatial Gating Unit: normalize the gate stream, mix tokens along
        # the sequence axis, then project the gate to 2 * d_model so it
        # matches the width of the value stream U.
        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
        self.sgu_proj = layers.Dense(seq_len, use_bias=False)
        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)

        # Contraction back to d_model.
        self.out_proj = layers.Dense(d_model, use_bias=True)

    def call(self, x, training=False):
        shortcut = x

        # 1. Norm + channel expansion: (B, L, D) -> (B, L, 4D).
        expanded = self.channel_proj(self.norm(x))

        # 2. Split into value stream `u` and gate stream `v`, each (B, L, 2D).
        u, v = tf.split(expanded, 2, axis=-1)

        # 3. SGU token mixing: apply the Dense over the sequence axis via a
        # transpose sandwich, (B, L, 2D) -> (B, 2D, L) -> (B, 2D, L) -> (B, L, 2D).
        mixed = tf.transpose(
            self.sgu_proj(tf.transpose(self.sgu_norm(v), perm=[0, 2, 1])),
            perm=[0, 2, 1],
        )

        # 4. GELU on the value stream; linear gate derived from the mixed
        # stream. (Standard gMLP gates with V directly; this variant adds
        # an extra projection `sgu_final` first.)
        gated = tf.nn.gelu(u) * self.sgu_final(mixed)

        # 5. Dropout, then contract to (B, L, D).
        gated = self.dropout(gated, training=training)

        # 6. Residual connection around the whole block.
        return shortcut + self.out_proj(gated)
| 237 |
|
| 238 |
class CrossBlock(layers.Layer):
    """Stateless cross-gating layer: rescales `z` by elementwise gates.

    Both gates lie in (0, 1); note (tanh(t) + 1) / 2 == sigmoid(2t).
    Shapes: x and z are broadcast-compatible, typically
    (batch, seq_len, d_model), and the output matches z.
    """

    def __init__(self):
        super().__init__()

    def call(self, x, z):
        # Gate derived from the query-side stream x.
        gate_q = (tf.nn.tanh(x) + 1.0) / 2.0
        # Gate derived from the key-side stream z itself.
        gate_k = (tf.nn.tanh(z) + 1.0) / 2.0
        # Multiplicative gating of z by both gates.
        return (gate_q * gate_k) * z
| 248 |
|
| 249 |
class LoU(layers.Layer):
|
|
|
|
| 260 |
|
| 261 |
self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
|
| 262 |
self.glu = SwiGLU(d_model, d_model)
|
| 263 |
+
self.cross = CrossBlock()
|
| 264 |
|
| 265 |
def _ema_over_time(self, score, alpha_dynamic):
|
| 266 |
seq = tf.transpose(score, perm=[1, 0, 2])
|
|
|
|
| 326 |
self.dec_pos_embedding = layers.Embedding(max_len, d_model)
|
| 327 |
|
| 328 |
# EncoderBlock and LoU keep the same structure as the original code
|
| 329 |
+
self.enc_layers = [gMLPBlock(d_model, seq_len=max_len) for _ in range(num_layers)]
|
| 330 |
self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
|
| 331 |
|
| 332 |
self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
|