Yuchan committed on
Commit
0d505a8
·
verified ·
1 Parent(s): fff43f5

Update Mo.py

Browse files
Files changed (1) hide show
  1. Mo.py +48 -52
Mo.py CHANGED
@@ -124,57 +124,53 @@ class SwiGLU(layers.Layer):
124
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
125
  return self.out(x_val * tf.nn.silu(x_gate))
126
 
127
class LoUScan(layers.Layer):
    """Causal gated value-accumulation layer.

    Projects the input to Q/K/V, forms an element-wise gate score in
    [0, 1] from Q and K, and scans causally along the sequence,
    accumulating score-weighted values. Each position is normalized by
    the running score sum *up to that position*, followed by clipping,
    a residual LayerNorm, and a SwiGLU feed-forward. Computation runs
    in float32 and the result is cast back to the input dtype.
    """

    def __init__(self, d_model, clip_value=5.0, eps=1e-6):
        """
        Args:
            d_model: projection width for the Q/K/V Dense layers.
            clip_value: symmetric clipping bound applied to the scan output.
            eps: numerical floor for the normalization denominator.
        """
        super().__init__()
        self.d_model = d_model
        self.clip_value = float(clip_value)
        self.eps = float(eps)

        self.Q = layers.Dense(d_model, dtype='float32')
        self.K = layers.Dense(d_model, dtype='float32')
        self.V = layers.Dense(d_model, dtype='float32')

        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
        self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
        self.glu = SwiGLU(d_model, 3500)  # user-defined GLU feed-forward

    def call(self, x):
        # assumes x is (batch, seq, d_model) — TODO confirm against callers
        x_f32 = tf.cast(x, tf.float32)
        residual = x_f32
        x_f32 = self.norm1(x_f32)

        q = self.Q(x_f32)
        k = self.K(x_f32)
        v = self.V(x_f32)

        # Rescale tanh from [-1, 1] to [0, 1] so q/k act as soft gates.
        g_q = (tf.nn.tanh(q) + 1.0) / 2.0
        g_k = (tf.nn.tanh(k) + 1.0) / 2.0
        score = g_q * g_k  # element-wise gating

        # Causal per-sequence scan.
        # Bug fixes vs. the original:
        #  * the normalization denominator was the FULL-sequence score sum
        #    (score_seq[:len] sums every step), leaking future positions
        #    despite the causal intent — we now carry the running score
        #    sum and divide by it per position;
        #  * the scan step returned a 2-tuple while the initializer was a
        #    single tensor, a structure mismatch tf.scan rejects — the
        #    carry is now a matching (value_sum, score_sum) tuple.
        def process_sequence(inputs):
            score_seq, v_seq = inputs
            init = (tf.zeros_like(v_seq[0]), tf.zeros_like(score_seq[0]))

            def step(carry, elems):
                sum_v, sum_s = carry
                s_t, v_t = elems
                # Accumulate score-weighted value and raw score in lockstep.
                return sum_v + s_t * v_t, sum_s + s_t

            sums_v, sums_s = tf.scan(step, (score_seq, v_seq), initializer=init)
            # Normalize each position by its causal running score sum.
            return sums_v / tf.maximum(sums_s, self.eps)

        # Apply the scan independently to every batch element.
        outputs = tf.map_fn(process_sequence, (score, v), dtype=tf.float32)

        outputs = tf.clip_by_value(outputs, -self.clip_value, self.clip_value)
        out = self.norm(outputs + residual)
        out = self.glu(out)
        return tf.cast(out, x.dtype)
178
 
179
  class Lo(layers.Layer):
180
  def __init__(self, d_model):
@@ -191,7 +187,7 @@ class Lo(layers.Layer):
191
  class Block(layers.Layer):
192
  def __init__(self, d_model):
193
  super().__init__()
194
- self.lou = LoUScan(d_model)
195
  self.lo = Lo(d_model)
196
 
197
  def call(self, x):
 
124
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
125
  return self.out(x_val * tf.nn.silu(x_gate))
126
 
127
class LoU(layers.Layer):
    """Gated cumulative-score mixing layer.

    Projects the input to Q/K/V, builds an element-wise gate score from
    Q and K (tanh rescaled to [0, 1]), accumulates it causally with a
    cumulative sum along the sequence axis, normalizes by the running
    mean, and uses the clipped result to scale V. A residual connection,
    LayerNorm and a SwiGLU feed-forward follow. Computation runs in
    float32 and the result is cast back to the input dtype.
    """

    def __init__(self, d_model, clip_value=5.0, eps=1e-6):
        """
        Args:
            d_model: projection width for the Q/K/V Dense layers.
            clip_value: symmetric clipping bound for the normalized score.
            eps: numerical floor for the normalization denominator.
        """
        super().__init__()
        self.d_model = d_model
        self.clip_value = float(clip_value)
        self.eps = float(eps)
        self.Q = layers.Dense(d_model, dtype='float32')
        self.K = layers.Dense(d_model, dtype='float32')
        self.V = layers.Dense(d_model, dtype='float32')
        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
        self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')

        self.glu = SwiGLU(d_model, 320)

    def call(self, x):
        # assumes x is (batch, seq, d_model) — the cumsum below runs on axis 1
        x_f32 = tf.cast(x, tf.float32)
        residual = x_f32
        # Bug fix: pre-norm must consume the float32 cast — the original
        # passed raw `x`, silently discarding the cast on the line above.
        x_f32 = self.norm1(x_f32)

        q = self.Q(x_f32)
        k = self.K(x_f32)
        v = self.V(x_f32)
        # Rescale tanh from [-1, 1] to [0, 1] so q/k act as soft gates.
        g_q = (tf.nn.tanh(q) + 1.0) / 2.0
        g_k = (tf.nn.tanh(k) + 1.0) / 2.0
        score = g_q * g_k

        score = tf.cumsum(score, axis=1)  # causal running sum, (B, L, D)

        # Normalize by the mean cumulative score at each position:
        # divide the cumulative sum at step t by the token count (t + 1).
        seq_len = tf.shape(score)[1]
        # [1, 2, ..., L], broadcast over batch and feature dims.
        count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
        count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
        score_mean = score / count_for_mean

        # NOTE(review): wherever score_mean > eps this simplifies to
        # score / (score / count) == count — the gate content cancels and
        # only the position index survives (then clipped below). Confirm
        # this degeneracy is intended.
        denom = tf.maximum(score_mean, self.eps)
        score_norm = score / denom

        score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
        x_comb = score_clipped * v

        out = self.norm(x_comb + residual)
        out = self.glu(out)
        return tf.cast(out, x.dtype)
 
 
 
 
174
 
175
  class Lo(layers.Layer):
176
  def __init__(self, d_model):
 
187
  class Block(layers.Layer):
188
  def __init__(self, d_model):
189
  super().__init__()
190
+ self.lou = LoU(d_model)
191
  self.lo = Lo(d_model)
192
 
193
  def call(self, x):