Yuchan commited on
Commit
55eb46d
ยท
verified ยท
1 Parent(s): 442acd1

Update Mo.py

Browse files
Files changed (1) hide show
  1. Mo.py +23 -24
Mo.py CHANGED
@@ -124,51 +124,50 @@ class SwiGLU(layers.Layer):
124
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
125
  return self.out(x_val * tf.nn.silu(x_gate))
126
 
127
- class LoU(layers.Layer):
128
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
129
  super().__init__()
130
  self.d_model = d_model
131
  self.clip_value = float(clip_value)
132
  self.eps = float(eps)
133
 
134
- # Q/K/V ๋ณ€ํ™˜
135
  self.Q = layers.Dense(d_model, dtype='float32')
136
  self.K = layers.Dense(d_model, dtype='float32')
137
  self.V = layers.Dense(d_model, dtype='float32')
138
 
139
- # ์ •๊ทœํ™”
140
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
141
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
142
-
143
- # ๋น„์„ ํ˜• ํ‘œํ˜„๋ ฅ
144
- self.glu = SwiGLU(d_model, 320)
145
-
146
- # ํ•™์Šต ๊ฐ€๋Šฅํ•œ ๊ณผ๊ฑฐ ํ† ํฐ ๊ฐ€์ค‘์น˜
147
- self.alpha = self.add_weight(shape=(d_model,), initializer='ones', trainable=True)
148
 
149
  def call(self, x):
150
  x_f32 = tf.cast(x, tf.float32)
151
  residual = x_f32
152
- x_f32 = self.norm1(x)
153
 
154
  q = self.Q(x_f32)
155
  k = self.K(x_f32)
156
- V = self.V(x_f32)
157
 
158
  g_q = (tf.nn.tanh(q) + 1.0) / 2.0
159
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
160
-
161
- # ๊ณผ๊ฑฐ ํ† ํฐ ๊ฐ€์ค‘์น˜ ๋ฐ˜์˜ ์ ์ˆ˜
162
- score = g_q * g_k * self.alpha # element-wise scaling
163
- # ๋ˆ„์ ํ•ฉ ๋Œ€์‹  ๊ฐ€์ค‘ ํ‰๊ท 
164
- # score_t = sum_{i=0}^{t} alpha_i * V_i / sum_{i=0}^{t} alpha_i
165
- score_cum = tf.math.cumsum(score * V, axis=1)
166
- alpha_cum = tf.math.cumsum(score, axis=1)
167
- score_weighted = score_cum / tf.maximum(alpha_cum, self.eps)
168
-
169
- # ์ •๊ทœํ™” + ํด๋ฆฌํ•‘
170
- score_norm = tf.clip_by_value(score_weighted, -self.clip_value, self.clip_value)
171
- out = self.norm(score_norm + residual)
 
 
 
 
 
 
172
  out = self.glu(out)
173
  return tf.cast(out, x.dtype)
174
 
@@ -187,7 +186,7 @@ class Lo(layers.Layer):
187
  class Block(layers.Layer):
188
  def __init__(self, d_model):
189
  super().__init__()
190
- self.lou = LoU(d_model)
191
  self.lo = Lo(d_model)
192
 
193
  def call(self, x):
 
124
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
125
  return self.out(x_val * tf.nn.silu(x_gate))
126
 
127
+ class LoUScan(layers.Layer):
128
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
129
  super().__init__()
130
  self.d_model = d_model
131
  self.clip_value = float(clip_value)
132
  self.eps = float(eps)
133
 
 
134
  self.Q = layers.Dense(d_model, dtype='float32')
135
  self.K = layers.Dense(d_model, dtype='float32')
136
  self.V = layers.Dense(d_model, dtype='float32')
137
 
 
138
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
139
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
140
+ self.glu = SwiGLU(d_model, 3500)
 
 
 
 
 
141
 
142
  def call(self, x):
143
  x_f32 = tf.cast(x, tf.float32)
144
  residual = x_f32
145
+ x_f32 = self.norm1(x_f32)
146
 
147
  q = self.Q(x_f32)
148
  k = self.K(x_f32)
149
+ v = self.V(x_f32)
150
 
151
  g_q = (tf.nn.tanh(q) + 1.0) / 2.0
152
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
153
+ score = g_q * g_k # gating
154
+
155
+ # tf.scan์œผ๋กœ ์ˆœ์ฐจ ๋ˆ„์ ํ•ฉ (์ธ๊ณผ์ )
156
+ def step(carry, inputs):
157
+ prev_sum = carry
158
+ s, v_t = inputs
159
+ new_sum = prev_sum + s * v_t
160
+ # ์ •๊ทœํ™”
161
+ out = new_sum / tf.maximum(tf.reduce_sum(score[:tf.shape(prev_sum)[0]], axis=0, keepdims=True), self.eps)
162
+ return new_sum, out
163
+
164
+ # ์ดˆ๊ธฐ๊ฐ’
165
+ init = tf.zeros_like(v[0])
166
+ _, outputs = tf.scan(step, (score, v), initializer=init, axis=0)
167
+
168
+ # ์•ˆ์ •ํ™”
169
+ outputs = tf.clip_by_value(outputs, -self.clip_value, self.clip_value)
170
+ out = self.norm(outputs + residual)
171
  out = self.glu(out)
172
  return tf.cast(out, x.dtype)
173
 
 
186
  class Block(layers.Layer):
187
  def __init__(self, d_model):
188
  super().__init__()
189
+ self.lou = LoUScan(d_model)
190
  self.lo = Lo(d_model)
191
 
192
  def call(self, x):