Yuchan
committed on
Update Inference.py
Browse files- Inference.py +3 -24
Inference.py
CHANGED
|
@@ -148,29 +148,10 @@ class CrossBlock(layers.Layer):
|
|
| 148 |
super().__init__()
|
| 149 |
self.clip_value = clip_value
|
| 150 |
self.eps = eps
|
|
|
|
| 151 |
# 💡 수정: 출력 차원을 1에서 d_model로 변경
|
| 152 |
def call(self, x, z):
|
| 153 |
-
|
| 154 |
-
g_q = (tf.nn.tanh(x) + 1.0) / 2.0
|
| 155 |
-
g_k = (tf.nn.tanh(z) + 1.0) / 2.0
|
| 156 |
-
score = (g_q * g_k)
|
| 157 |
-
score = tf.cumsum(score, axis=1)
|
| 158 |
-
|
| 159 |
-
seq_len = tf.shape(score)[1]
|
| 160 |
-
# [1, 2, 3, ..., L]을 D_model 차원으로 확장
|
| 161 |
-
count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
|
| 162 |
-
count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
|
| 163 |
-
|
| 164 |
-
# 누적합을 현재까지의 토큰 개수로 나누어 평균 누적합 계산 (B, L, D)
|
| 165 |
-
score_mean = score / count_for_mean
|
| 166 |
-
|
| 167 |
-
# 정규화 분모 설정
|
| 168 |
-
denom = tf.maximum(score_mean, self.eps)
|
| 169 |
-
score_norm = score / denom
|
| 170 |
-
# -----------------------------------------------
|
| 171 |
-
|
| 172 |
-
score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
|
| 173 |
-
y = score_clipped * z
|
| 174 |
return y
|
| 175 |
|
| 176 |
class LoU(layers.Layer):
|
|
@@ -182,7 +163,7 @@ class LoU(layers.Layer):
|
|
| 182 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 183 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 184 |
|
| 185 |
-
self.glu = SwiGLU(d_model,
|
| 186 |
self.cross = CrossBlock()
|
| 187 |
|
| 188 |
def call(self, x, z):
|
|
@@ -196,8 +177,6 @@ class LoU(layers.Layer):
|
|
| 196 |
out = self.cross(out, z)
|
| 197 |
out = self.glu(out)
|
| 198 |
return tf.cast(out, x.dtype)
|
| 199 |
-
|
| 200 |
-
|
| 201 |
# =======================
|
| 202 |
# 4) AlphaS2S 모델 (기존 코드 유지)
|
| 203 |
# =======================
|
|
|
|
| 148 |
super().__init__()
|
| 149 |
self.clip_value = clip_value
|
| 150 |
self.eps = eps
|
| 151 |
+
self.attn = layers.MultiHeadAttention(8, 20)
|
| 152 |
# 💡 수정: 출력 차원을 1에서 d_model로 변경
|
| 153 |
def call(self, x, z):
    """Attend the query stream to the context stream.

    Runs the layer's multi-head attention with ``x`` as query and ``z`` as
    both value and key, returning the attended tensor unchanged otherwise.

    NOTE(review): assumes ``self.attn`` is the ``MultiHeadAttention`` layer
    created in ``__init__`` — confirm against the full class definition,
    which is outside this diff hunk.
    """
    attended = self.attn(x, z, z)
    return attended
|
| 156 |
|
| 157 |
class LoU(layers.Layer):
|
|
|
|
| 163 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 164 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 165 |
|
| 166 |
+
self.glu = SwiGLU(d_model, 350)
|
| 167 |
self.cross = CrossBlock()
|
| 168 |
|
| 169 |
def call(self, x, z):
|
|
|
|
| 177 |
out = self.cross(out, z)
|
| 178 |
out = self.glu(out)
|
| 179 |
return tf.cast(out, x.dtype)
|
|
|
|
|
|
|
| 180 |
# =======================
|
| 181 |
# 4) AlphaS2S 모델 (기존 코드 유지)
|
| 182 |
# =======================
|