Yuchan committed on
Commit
1e5945c
Β·
verified Β·
1 Parent(s): 7f390c3

Update Inference.py

Browse files
Files changed (1) hide show
  1. Inference.py +2 -30
Inference.py CHANGED
@@ -178,11 +178,7 @@ class LoU(layers.Layer):
178
  super().__init__()
179
  self.d_model = d_model
180
  self.clip_value = float(clip_value)
181
- self.eps = float(eps)
182
- self.Q = layers.Dense(d_model, dtype='float32')
183
- self.K = layers.Dense(d_model, dtype='float32')
184
- self.V = layers.Dense(d_model, dtype='float32')
185
- self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
186
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
187
 
188
  self.glu = SwiGLU(d_model, 320)
@@ -193,31 +189,7 @@ class LoU(layers.Layer):
193
  residual = x_f32
194
  x_f32 = self.norm1(x)
195
 
196
- q = self.Q(x_f32)
197
- k = self.K(x_f32)
198
- V = self.V(x_f32)
199
- g_q = (tf.nn.tanh(q) + 1.0) / 2.0
200
- g_k = (tf.nn.tanh(k) + 1.0) / 2.0
201
- score = g_q * g_k
202
-
203
- score = tf.cumsum(score, axis=1) # (B, L, D)
204
-
205
- # πŸ’‘ μˆ˜μ •λœ λΆ€λΆ„: ν˜„μž¬ ν† ν°κΉŒμ§€μ˜ λˆ„μ ν•© ν‰κ· μœΌλ‘œ μ •κ·œν™”
206
- seq_len = tf.shape(score)[1]
207
- # [1, 2, 3, ..., L]을 D_model μ°¨μ›μœΌλ‘œ ν™•μž₯
208
- count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
209
- count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
210
-
211
- # λˆ„μ ν•©μ„ ν˜„μž¬κΉŒμ§€μ˜ 토큰 개수둜 λ‚˜λˆ„μ–΄ 평균 λˆ„μ ν•© 계산 (B, L, D)
212
- score_mean = score / count_for_mean
213
-
214
- # μ •κ·œν™” λΆ„λͺ¨ μ„€μ •
215
- denom = tf.maximum(score_mean, self.eps)
216
- score_norm = score / denom
217
- # -----------------------------------------------
218
-
219
- score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
220
- x_comb = score_clipped * V
221
 
222
  out = self.norm(x_comb + residual)
223
  out = self.cross(out, z)
 
178
  super().__init__()
179
  self.d_model = d_model
180
  self.clip_value = float(clip_value)
181
+ self.mha = layers.MultiHeadAttention(8, 20)
 
 
 
 
182
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
183
 
184
  self.glu = SwiGLU(d_model, 320)
 
189
  residual = x_f32
190
  x_f32 = self.norm1(x)
191
 
192
+ x_comb = self.mha(x, x, x, use_causal_mask=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  out = self.norm(x_comb + residual)
195
  out = self.cross(out, z)