Yuchan
committed on
Update AlphaS2S.py
Browse files- AlphaS2S.py +36 -35
AlphaS2S.py
CHANGED
|
@@ -182,49 +182,51 @@ class SwiGLU(layers.Layer):
|
|
| 182 |
x_proj = self.proj(x)
|
| 183 |
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
|
| 184 |
return self.out(x_val * tf.nn.silu(x_gate))
|
| 185 |
-
|
| 186 |
-
class
|
| 187 |
-
def __init__(self, d_model,
|
| 188 |
super().__init__()
|
| 189 |
self.d_model = d_model
|
| 190 |
-
self.
|
| 191 |
-
self.
|
| 192 |
-
|
| 193 |
-
# d_model * 4๋ก ํ์ฅ
|
| 194 |
-
self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
|
| 195 |
-
self.dropout = layers.Dropout(dropout)
|
| 196 |
|
| 197 |
-
#
|
| 198 |
-
|
| 199 |
-
self.sgu_proj = layers.Dense(seq_len, use_bias=False)
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
-
def call(self, x
|
|
|
|
| 207 |
residual = x
|
| 208 |
-
x = self.norm(x)
|
| 209 |
-
x = self.channel_proj(x) # Shape: (B, L, 4*D)
|
| 210 |
|
| 211 |
-
|
|
|
|
| 212 |
|
| 213 |
-
#
|
| 214 |
-
|
| 215 |
-
v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1]) # Shape: (B, 2*D, L)
|
| 216 |
-
v_proj = self.sgu_proj(v_norm_T) # Shape: (B, 2*D, L)
|
| 217 |
-
v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1]) # Shape: (B, L, 2*D)
|
| 218 |
-
v_gate = self.sgu_final(v_proj_T) # Shape: (B, L, 2*D)
|
| 219 |
|
| 220 |
-
|
| 221 |
-
|
|
|
|
| 222 |
|
| 223 |
-
#
|
| 224 |
-
|
| 225 |
-
|
| 226 |
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
| 228 |
|
| 229 |
class CrossBlock(layers.Layer):
|
| 230 |
def __init__(self, d_model): # ๐ก d_model ์ธ์ ์ถ๊ฐ
|
|
@@ -238,7 +240,6 @@ class CrossBlock(layers.Layer):
|
|
| 238 |
y = a * x + (1.0 - a) * z
|
| 239 |
return y
|
| 240 |
|
| 241 |
-
|
| 242 |
class LoU(layers.Layer):
|
| 243 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
| 244 |
super().__init__()
|
|
@@ -319,7 +320,7 @@ class AlphaS2S(tf.keras.Model):
|
|
| 319 |
self.dec_pos_embedding = layers.Embedding(max_len, d_model)
|
| 320 |
|
| 321 |
# EncoderBlock๊ณผ LoU๋ ๊ธฐ์กด ์ฝ๋์ ๋์ผํ ๊ตฌ์กฐ
|
| 322 |
-
self.enc_layers = [
|
| 323 |
self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
|
| 324 |
|
| 325 |
self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
|
|
@@ -405,7 +406,7 @@ with strategy.scope():
|
|
| 405 |
masked_perplexity
|
| 406 |
]
|
| 407 |
)
|
| 408 |
-
|
| 409 |
print("โ
๋ชจ๋ธ ์ปดํ์ผ ์๋ฃ, ํ์ต ์์...")
|
| 410 |
# โ ๏ธ ํ์ต ์คํ
|
| 411 |
history = chat_model.fit(dataset, epochs=1, verbose=1)
|
|
|
|
| 182 |
x_proj = self.proj(x)
|
| 183 |
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
|
| 184 |
return self.out(x_val * tf.nn.silu(x_gate))
|
| 185 |
+
|
| 186 |
+
class DilatedConvBlock(layers.Layer):
    """Residual block of parallel dilated Conv1D branches.

    Builds ``num_layers`` Conv1D layers whose dilation rates grow as
    powers of two (1, 2, 4, 8, 16 for the default of 5 layers). In
    ``call`` every branch convolves the SAME input (parallel, not
    chained), the branch outputs are summed element-wise, and the input
    is added back as a residual connection — so the output shape equals
    the input shape (B, L, d_model).
    """

    def __init__(self, d_model, num_layers=5, kernel_size=3):
        """Create the parallel dilated convolution branches.

        Args:
            d_model: Channel width; each branch keeps ``filters=d_model``
                so the residual addition is shape-compatible.
            num_layers: Number of parallel branches / dilation rates.
            kernel_size: Kernel size shared by every branch (default 3).
        """
        super().__init__()
        self.d_model = d_model
        self.kernel_size = kernel_size
        # One branch per dilation rate 2**0 .. 2**(num_layers-1); every
        # branch uses padding='same' and activation='relu' so outputs stay
        # (B, L, d_model) and can be summed directly.
        self.conv_layers = [
            layers.Conv1D(
                filters=d_model,
                kernel_size=kernel_size,
                padding='same',
                activation='relu',
                dilation_rate=2**i,  # dilation rates: 1, 2, 4, 8, 16
                name=f"dconv_{i+1}_rate_{2**i}",
            )
            for i in range(num_layers)
        ]

    def call(self, x):
        # NOTE: the branches are applied in PARALLEL — each conv sees the
        # original input x rather than the previous branch's output.
        branch_outputs = [conv(x) for conv in self.conv_layers]
        # Sum all branch outputs (skip-connection style merge), then add
        # the input back as a residual connection.
        return tf.add_n(branch_outputs) + x
|
| 230 |
|
| 231 |
class CrossBlock(layers.Layer):
|
| 232 |
def __init__(self, d_model): # ๐ก d_model ์ธ์ ์ถ๊ฐ
|
|
|
|
| 240 |
y = a * x + (1.0 - a) * z
|
| 241 |
return y
|
| 242 |
|
|
|
|
| 243 |
class LoU(layers.Layer):
|
| 244 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
| 245 |
super().__init__()
|
|
|
|
| 320 |
self.dec_pos_embedding = layers.Embedding(max_len, d_model)
|
| 321 |
|
| 322 |
# EncoderBlock๊ณผ LoU๋ ๊ธฐ์กด ์ฝ๋์ ๋์ผํ ๊ตฌ์กฐ
|
| 323 |
+
self.enc_layers = [DilatedConvBlock(d_model) for _ in range(num_layers)]
|
| 324 |
self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
|
| 325 |
|
| 326 |
self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
|
|
|
|
| 406 |
masked_perplexity
|
| 407 |
]
|
| 408 |
)
|
| 409 |
+
chat_model.summary()
|
| 410 |
print("โ
๋ชจ๋ธ ์ปดํ์ผ ์๋ฃ, ํ์ต ์์...")
|
| 411 |
# โ ๏ธ ํ์ต ์คํ
|
| 412 |
history = chat_model.fit(dataset, epochs=1, verbose=1)
|