Update Test.py
Browse files
Test.py
CHANGED
|
@@ -66,13 +66,18 @@ dataset = tf.data.Dataset.from_generator(
|
|
| 66 |
class EncoderBlock(tf.keras.layers.Layer):
    """Encoder block: two SiLU-gated (GLU-style) feed-forward sub-blocks
    wrapped around a token-mixing step.

    The token mixer builds a row-softmaxed (B, L, L) similarity matrix and
    projects it back to (B, L, D) with a learned (seq_len, embed_dim)
    weight, so its output is shape-compatible with the residual stream.

    Args:
        embed_dim: model width D of the residual stream.
        ff_dim: inner width of the gated feed-forward layers (split in two
            halves, so the effective hidden size is ff_dim // 2).
        seq_len: fixed sequence length L required by ``w_proj``.
    """

    def __init__(self, embed_dim=EMBED_DIM, ff_dim=1152, seq_len=MAX_LEN):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = seq_len

        self.fc1 = layers.Dense(ff_dim)
        self.fc2 = layers.Dense(embed_dim)
        self.fc3 = layers.Dense(ff_dim)
        self.fc4 = layers.Dense(embed_dim)

        # BUG FIX: the original add_weight() call passed no name/shape, so
        # the variable could never be built, and the (B, L, L) mixing
        # result was later multiplied element-wise against a (B, L, D)
        # tensor — a runtime shape mismatch. w_proj (L, D) now projects
        # the sequence axis back to the embedding axis.
        self.w_proj = self.add_weight(
            name="w_proj_L_to_D",
            shape=(seq_len, embed_dim),
            initializer="glorot_uniform",
            trainable=True,
        )

        # NOTE(review): the visible diff hides a few __init__ lines here;
        # self.alpha2 is reconstructed as Dense(1) based on its (B, L, 1)
        # gate usage in call() — confirm against the full file.
        self.alpha2 = layers.Dense(1)

        self.ln = layers.LayerNormalization(epsilon=1e-5)
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x):
        # x: (B, L, D)
        x_norm = self.ln(x)

        # First SiLU-gated feed-forward: (B, L, D) -> (B, L, D).
        h = self.fc1(x_norm)                       # (B, L, ff_dim)
        g, v = tf.split(h, 2, axis=-1)             # (B, L, ff_dim/2) each
        h = tf.nn.silu(g) * v
        h = self.fc2(h)                            # (B, L, D)

        # Token mixing: pairwise similarities over the sequence axis.
        sim = tf.matmul(h, h, transpose_b=True)    # (B, L, L)
        sim = tf.nn.softmax(sim, axis=-1)          # row-normalized (B, L, L)

        # (B, L, L) -> (B, L, D): contract sim's last axis (length L)
        # with w_proj's first axis (length L).
        h2 = tf.tensordot(sim, self.w_proj, axes=[[2], [0]])  # (B, L, D)

        # Sequence-axis gate; alpha2 output (B, L, 1) broadcasts over D.
        v_gate = tf.nn.softmax(self.alpha2(v), axis=1)        # (B, L, 1)
        v = v_gate * h2                                       # (B, L, D)

        x_norm = x_norm + self.ln2(v)

        # Second SiLU-gated feed-forward.
        z = self.fc3(x_norm)
        g, v = tf.split(z, 2, axis=-1)
        z = tf.nn.silu(g) * v
        z = self.fc4(z)

        return x_norm + self.ln1(z)
class L2NormLayer(layers.Layer):
|
|
|
|
class EncoderBlock(tf.keras.layers.Layer):
    """Gated feed-forward encoder block with a learned token-mixing step.

    Pipeline per call: pre-LayerNorm -> SiLU-gated MLP -> pairwise
    sequence similarity (B, L, L) -> learned (L, D) projection back to
    the embedding width -> gated residual add -> second SiLU-gated MLP
    with its own residual add.
    """

    def __init__(self, embed_dim=EMBED_DIM, ff_dim=1152, seq_len=MAX_LEN):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = seq_len

        self.fc1 = layers.Dense(ff_dim)
        self.fc2 = layers.Dense(embed_dim)
        self.fc3 = layers.Dense(ff_dim)
        self.fc4 = layers.Dense(embed_dim)

        # Declared as (seq_len, embed_dim) — reused as the L -> D projection.
        self.w_proj = self.add_weight(
            name="w_proj_L_to_D",
            shape=(seq_len, embed_dim),
            initializer="glorot_uniform",
            trainable=True
        )

        # NOTE(review): a few __init__ lines (incl. the definition of
        # self.alpha2) fall between the diff hunks and are not visible;
        # Dense(1) is assumed from its (B, L, 1) usage below — confirm.
        self.alpha2 = layers.Dense(1)

        self.ln = layers.LayerNormalization(epsilon=1e-5)
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x):
        # x: (B, L, D)
        x_norm = self.ln(x)

        h = self.fc1(x_norm)                # (B, L, ff_dim)
        g, v = tf.split(h, 2, axis=-1)      # (B, L, ff_dim/2) each
        h = tf.nn.silu(g) * v
        h = self.fc2(h)                     # (B, L, D)

        # --- matmul -> (B, L, L) ---
        sim = tf.matmul(h, h, transpose_b=True)   # (B, L, L)
        # (optional) add normalization/scaling here if desired
        sim = tf.nn.softmax(sim, axis=-1)         # (B, L, L)

        # --- (B, L, L) -> (B, L, D): tensordot with matched axes ---
        # w_proj: (L, D); sim's last axis matches w_proj's first axis.
        h2 = tf.tensordot(sim, self.w_proj, axes=[[2], [0]])  # (B, L, D)

        # Shapes now agree -> element-wise product with the gate works.
        v_gate = tf.nn.softmax(self.alpha2(v), axis=1)  # (B, L, 1)
        v = v_gate * h2                                 # (B, L, D)

        x_norm = x_norm + self.ln2(v)

        z = self.fc3(x_norm)
        g, v = tf.split(z, 2, axis=-1)
        z = tf.nn.silu(g) * v
        z = self.fc4(z)

        return x_norm + self.ln1(z)
|
|
|
| 122 |
|
| 123 |
|
| 124 |
class L2NormLayer(layers.Layer):
|