Yuchan
committed on
Update Inference.py
Inference.py +2 -30
Inference.py CHANGED

@@ -178,11 +178,7 @@ class LoU(layers.Layer):
         super().__init__()
         self.d_model = d_model
         self.clip_value = float(clip_value)
-        self.
-        self.Q = layers.Dense(d_model, dtype='float32')
-        self.K = layers.Dense(d_model, dtype='float32')
-        self.V = layers.Dense(d_model, dtype='float32')
-        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+        self.mha = layers.MultiHeadAttention(8, 20)
         self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')

         self.glu = SwiGLU(d_model, 320)
@@ -193,31 +189,7 @@ class LoU(layers.Layer):
         residual = x_f32
         x_f32 = self.norm1(x)

-
-        k = self.K(x_f32)
-        V = self.V(x_f32)
-        g_q = (tf.nn.tanh(q) + 1.0) / 2.0
-        g_k = (tf.nn.tanh(k) + 1.0) / 2.0
-        score = g_q * g_k
-
-        score = tf.cumsum(score, axis=1)  # (B, L, D)
-
-        # 💡 Changed part: normalize by the running mean of the cumulative sum up to the current token
-        seq_len = tf.shape(score)[1]
-        # Broadcast [1, 2, 3, ..., L] across the d_model dimension
-        count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
-        count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
-
-        # Divide the cumulative sum by the number of tokens so far to get its running mean (B, L, D)
-        score_mean = score / count_for_mean
-
-        # Set up the normalization denominator
-        denom = tf.maximum(score_mean, self.eps)
-        score_norm = score / denom
-        # -----------------------------------------------
-
-        score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
-        x_comb = score_clipped * V
+        x_comb = self.mha(x, x, x, use_causal_mask=True)

         out = self.norm(x_comb + residual)
         out = self.cross(out, z)
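For reference, the removed branch computes a gated cumulative-sum score rather than softmax attention: tanh gates squashed into (0, 1) are multiplied elementwise, summed causally along the sequence, and normalized by their running mean. Below is a minimal standalone sketch of that computation; the function name and the eps/clip_value defaults are illustrative, since the hunk only shows how those values are used. Note that wherever the running mean exceeds eps, score / max(score / count, eps) simplifies to count itself, so the normalized score collapses to the position index, which may be one reason the branch was dropped.

import tensorflow as tf

def gated_cumsum_scores(q, k, eps=1e-6, clip_value=1.0):
    # Sketch only: eps and clip_value defaults are assumptions, not from the commit.
    g_q = (tf.nn.tanh(q) + 1.0) / 2.0      # squash each feature into (0, 1)
    g_k = (tf.nn.tanh(k) + 1.0) / 2.0
    score = tf.cumsum(g_q * g_k, axis=1)   # causal running sum over the sequence, (B, L, D)

    # Running mean: the cumulative sum at position t divided by t + 1
    seq_len = tf.shape(score)[1]
    count = tf.reshape(tf.cast(tf.range(seq_len) + 1, score.dtype), (1, seq_len, 1))
    score_mean = score / count

    # Normalize by the clamped running mean; where score_mean > eps this
    # simplifies to score / (score / count) == count, i.e. the position index.
    score_norm = score / tf.maximum(score_mean, eps)
    return tf.clip_by_value(score_norm, -clip_value, clip_value)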
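The replacement is the stock Keras layer: positionally, MultiHeadAttention(8, 20) sets num_heads=8 and key_dim=20 per head, and use_causal_mask=True (available since TF 2.10) applies the lower-triangular mask, keeping the layer autoregressive like the cumulative sum it replaces. A quick shape check, with batch size, sequence length, and d_model chosen only for illustration:

import tensorflow as tf
from tensorflow.keras import layers

# Positional args as in the commit: 8 heads, key_dim of 20 per head.
mha = layers.MultiHeadAttention(num_heads=8, key_dim=20)

x = tf.random.normal((2, 16, 160))      # (batch, seq_len, d_model); all three sizes are assumed examples
y = mha(x, x, x, use_causal_mask=True)  # self-attention (query = value = key = x) with a causal mask
print(y.shape)                          # (2, 16, 160): output is projected back to the query's feature size

Two things the hunks leave open: the new call attends over x rather than the normalized x_f32, so the norm1 output computed just above goes unused on this path, and the constructor no longer defines self.norm even though the unchanged context line out = self.norm(x_comb + residual) still calls it, which would raise AttributeError unless norm is defined elsewhere in the file.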