OpenLab-NLP commited on
Commit
aec1b8a
·
verified ·
1 Parent(s): 2e02c67

Update 연구중.py

Browse files
Files changed (1) hide show
  1. 연구중.py +23 -23
연구중.py CHANGED
@@ -128,50 +128,50 @@ ds = ds.batch(BATCH_SIZE, drop_remainder=True)
128
  ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
129
  ds = ds.prefetch(tf.data.AUTOTUNE)
130
 
131
- class MixerBlock(layers.Layer):
132
- def __init__(self, dim):
133
- super().__init__()
134
-
135
 
136
  class MixerBlock(layers.Layer):
137
  def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0):
138
  super().__init__()
 
139
 
140
- self.ln_token = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
141
- self.ln_attn = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
142
- self.ln_channel = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
143
- self.ch_fc1 = layers.Dense(self.dim * 4, activation=tf.nn.gelu)
144
- self.ch_fc2 = layers.Dense(self.dim)
145
 
 
 
 
146
 
147
- self.token_fc1 = layers.Dense(seq_len * 2, dtype=tf.float32)
148
- self.token_fc2 = layers.Dense(seq_len, dtype=tf.float32)
149
 
150
- self.attn = layers.Dense(1, dtype=tf.float32)
 
 
151
 
152
  def call(self, x, training=None):
153
- # x: (B, L, D)
154
-
155
- # ---------- Token Mixer (Pre-LN) ----------
156
  y = self.ln_token(x)
157
- y_t = tf.transpose(y, perm=[0, 2, 1]) # (B, D, L)
158
  y_t = self.token_fc1(y_t)
159
  a, b = tf.split(y_t, 2, axis=-1)
160
  y_t = self.token_fc2(a * tf.nn.gelu(b))
161
- y = tf.transpose(y_t, perm=[0, 2, 1]) # (B, L, D)
162
  x = x + y
163
 
164
- # ---------- Scalar Attention Reweight (Pre-LN) ----------
165
- y = self.ln_attn(x)
166
- weight = tf.nn.softmax(self.attn(y), axis=1) # (B, L, 1)
167
- y = y * weight
168
- x = x + y
 
169
 
170
- # ---------- Channel Mixer (Pre-LN) ----------
171
  y = self.ln_channel(x)
172
  y = self.ch_fc1(y)
173
  y = self.ch_fc2(y)
174
  x = x + y
 
175
  return x
176
 
177
 
 
128
  ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
129
  ds = ds.prefetch(tf.data.AUTOTUNE)
130
 
 
 
 
 
131
 
132
  class MixerBlock(layers.Layer):
133
  def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0):
134
  super().__init__()
135
+ self.dim = dim
136
 
137
+ self.ln_token = layers.LayerNormalization(epsilon=1e-6)
138
+ self.ln_gate = layers.LayerNormalization(epsilon=1e-6) # 이름 변경
139
+ self.ln_channel = layers.LayerNormalization(epsilon=1e-6)
 
 
140
 
141
+ # Token Mixer
142
+ self.token_fc1 = layers.Dense(seq_len * 2)
143
+ self.token_fc2 = layers.Dense(seq_len)
144
 
145
+ # Gating (Sigmoid) - Temperature 불필요
146
+ self.gate_dense = layers.Dense(1)
147
 
148
+ # Channel Mixer
149
+ self.ch_fc1 = layers.Dense(self.dim * 4, activation='gelu')
150
+ self.ch_fc2 = layers.Dense(self.dim)
151
 
152
  def call(self, x, training=None):
153
+ # 1. Token Mixer
 
 
154
  y = self.ln_token(x)
155
+ y_t = tf.transpose(y, perm=[0, 2, 1])
156
  y_t = self.token_fc1(y_t)
157
  a, b = tf.split(y_t, 2, axis=-1)
158
  y_t = self.token_fc2(a * tf.nn.gelu(b))
159
+ y = tf.transpose(y_t, perm=[0, 2, 1])
160
  x = x + y
161
 
162
+ # 2. Scalar Gating (수정됨)
163
+ # Softmax의 1/N 희석 문제를 해결하기 위해 Sigmoid 사용
164
+ y = self.ln_gate(x)
165
+ gate = tf.nn.sigmoid(self.gate_dense(y)) # (B, L, 1) Range: 0~1
166
+ y = y * gate
167
+ x = x + y
168
 
169
+ # 3. Channel Mixer
170
  y = self.ln_channel(x)
171
  y = self.ch_fc1(y)
172
  y = self.ch_fc2(y)
173
  x = x + y
174
+
175
  return x
176
 
177