OpenLab-NLP committed on
Commit
cc13462
·
verified ·
1 Parent(s): 52becc8

Update 연구중.py

Browse files
Files changed (1) hide show
  1. 연구중.py +27 -14
연구중.py CHANGED
@@ -128,38 +128,51 @@ ds = ds.batch(BATCH_SIZE, drop_remainder=True)
128
# Pair each batch with an all-zero dummy target of shape (BATCH_SIZE,).
# Presumably the real loss is computed inside the model and the zeros only
# satisfy Keras' (inputs, targets) interface — TODO confirm against the model.
ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
# Overlap preprocessing with training by prefetching batches.
ds = ds.prefetch(tf.data.AUTOTUNE)
130
 
 
 
 
 
131
 
132
class MixerBlock(layers.Layer):
    """Residual mixer block: GLU-gated token mixing, then GLU-gated channel mixing.

    Each sub-block is pre-normalized with LayerNormalization and added back to
    its input. Note: ``token_mlp_dim``, ``channel_mlp_dim`` and ``dropout`` are
    accepted but not referenced anywhere in this block.
    """

    def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0):
        super().__init__()
        self.dim = dim

        # One pre-norm per sub-block.
        self.ln_token = layers.LayerNormalization(epsilon=1e-6)
        self.ln_channel = layers.LayerNormalization(epsilon=1e-6)

        # Token mixer acts along the sequence axis; fc1 widens to 4*seq_len so
        # the result can be split into a value/gate pair.
        self.token_fc1 = layers.Dense(seq_len * 4)
        self.token_fc2 = layers.Dense(seq_len)

        # Channel mixer acts along the feature axis, same 4x widening scheme.
        self.ch_fc1 = layers.Dense(self.dim * 4)
        self.ch_fc2 = layers.Dense(self.dim)

    def call(self, x, training=None):
        # --- token mixing (residual) ---
        normed = self.ln_token(x)
        mixed = tf.transpose(normed, perm=[0, 2, 1])  # move sequence axis last
        mixed = self.token_fc1(mixed)
        value, gate = tf.split(mixed, 2, axis=-1)
        mixed = self.token_fc2(value * tf.nn.gelu(gate))  # GLU-style gating
        x = x + tf.transpose(mixed, perm=[0, 2, 1])  # restore axis order

        # --- channel mixing (residual) ---
        value, gate = tf.split(self.ch_fc1(self.ln_channel(x)), 2, axis=-1)
        x = x + self.ch_fc2(value * tf.nn.gelu(gate))

        return x
164
 
165
 
 
128
# Pair each batch with an all-zero dummy target of shape (BATCH_SIZE,).
# Presumably the real loss is computed inside the model and the zeros only
# satisfy Keras' (inputs, targets) interface — TODO confirm against the model.
ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
# Overlap preprocessing with training by prefetching batches.
ds = ds.prefetch(tf.data.AUTOTUNE)
130
 
131
# NOTE(review): dead code — this stub is immediately shadowed by the full
# MixerBlock definition below (same name), and its __init__ signature
# (self, dim) does not match the real one. Looks like a leftover from an
# in-progress edit; confirm and remove.
class MixerBlock(layers.Layer):
    def __init__(self, dim):
        super().__init__()
134
+
135
 
136
class MixerBlock(layers.Layer):
    """Pre-LN mixer block: token mixing, scalar softmax reweighting, channel mixing.

    Each of the three stages normalizes its input, transforms it, and adds the
    result back to the running residual ``x``.

    NOTE(review): ``token_mlp_dim``, ``channel_mlp_dim`` and ``dropout`` are
    accepted but never used in this block — confirm whether they should be
    wired in (e.g. as the Dense widths / Dropout rate) or removed.
    """

    def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0):
        super().__init__()
        self.dim = dim

        # All norms and token/attn Dense layers pinned to float32 — presumably
        # for numerical stability under mixed precision; confirm the policy.
        self.ln_token = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
        self.ln_attn = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
        self.ln_channel = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
        # Channel mixer: plain 4x-expansion MLP with GELU (no gating here).
        self.ch_fc1 = layers.Dense(self.dim * 4, activation=tf.nn.gelu)
        self.ch_fc2 = layers.Dense(self.dim)

        # Token mixer: fc1 doubles the sequence axis so the output can be split
        # into a value/gate pair (GLU); fc2 projects back to seq_len.
        self.token_fc1 = layers.Dense(seq_len * 2, dtype=tf.float32)
        self.token_fc2 = layers.Dense(seq_len, dtype=tf.float32)

        # Produces one scalar score per position, softmaxed over the sequence.
        self.attn = layers.Dense(1, dtype=tf.float32)

    def call(self, x, training=None):
        # x: (B, L, D)

        # ---------- Token Mixer (Pre-LN) ----------
        y = self.ln_token(x)
        y_t = tf.transpose(y, perm=[0, 2, 1])  # (B, D, L)
        y_t = self.token_fc1(y_t)
        a, b = tf.split(y_t, 2, axis=-1)  # GLU halves: value a, gate b
        y_t = self.token_fc2(a * tf.nn.gelu(b))
        y = tf.transpose(y_t, perm=[0, 2, 1])  # (B, L, D)
        x = x + y

        # ---------- Scalar Attention Reweight (Pre-LN) ----------
        y = self.ln_attn(x)
        # weight sums to 1 over the sequence axis, so the residual added here
        # is a heavily down-scaled, position-reweighted copy of the norm output.
        weight = tf.nn.softmax(self.attn(y), axis=1)  # (B, L, 1)
        y = y * weight
        x = x + y

        # ---------- Channel Mixer (Pre-LN) ----------
        y = self.ln_channel(x)
        y = self.ch_fc1(y)
        y = self.ch_fc2(y)
        x = x + y
        return x
177
 
178