OpenLab-NLP committed on
Commit
c60c507
·
verified ·
1 Parent(s): 9f2cecf

Update V2.py

Browse files
Files changed (1) hide show
  1. V2.py +11 -16
V2.py CHANGED
@@ -138,12 +138,10 @@ class HyperConv1D(layers.Layer):
138
  # Input projection
139
  self.input_proj = layers.Dense(d_model, name="input_proj")
140
 
141
- # Dynamic kernel conv (중간 차원)
142
- self.d_mid = max(64, d_model // 8)
143
- self.dynamic_dense = layers.Dense(self.d_mid, activation='silu')
144
  self.dynamic_proj = layers.Dense(d_model)
145
  self.kernel_generator = layers.Dense(k, dtype='float32')
146
- self.kernel_temp = self.add_weight("kernel_temp", shape=(), initializer=tf.constant_initializer(1.0), trainable=True)
147
 
148
  # Hypernetwork: token-wise transform before pooling
149
  self.hyper = tf.keras.Sequential([
@@ -152,13 +150,13 @@ class HyperConv1D(layers.Layer):
152
  ], name="hyper")
153
 
154
  self.attn_pool = layers.Dense(1)
155
- self.scale_dense = layers.Dense(d_model, bias_initializer=tf.keras.initializers.Constant(0.0))
156
 
157
- # Normalization + dropout
158
  self.norm = layers.LayerNormalization()
159
  self.dropout = layers.Dropout(dropout)
160
 
161
  def call(self, x, training=None):
 
162
  x_dtype = x.dtype
163
 
164
  # 1) input projection
@@ -172,9 +170,8 @@ class HyperConv1D(layers.Layer):
172
  # ------------------------------
173
  # 2) DynamicConv local mixing
174
  # ------------------------------
175
- kernels = self.kernel_generator(self.dynamic_dense(x_proj)) # (B,L,k)
176
- kernels = tf.cast(kernels, tf.float32)
177
- kernels = tf.nn.softmax(kernels / tf.maximum(self.kernel_temp, 1e-6), axis=-1)
178
 
179
  x_pad = tf.pad(x_proj, [[0,0],[pad,pad],[0,0]])
180
  x_pad_4d = tf.expand_dims(x_pad, axis=1) # (B,1,L+k-1,D)
@@ -186,7 +183,7 @@ class HyperConv1D(layers.Layer):
186
  padding='VALID'
187
  )
188
  patches = tf.reshape(patches, [B, L, self.k, D])
189
- kernels_exp = tf.cast(tf.expand_dims(kernels, axis=-1), x_proj.dtype)
190
  out_local = tf.reduce_sum(patches * kernels_exp, axis=2) # (B,L,D)
191
  out_local = self.dynamic_proj(out_local)
192
 
@@ -198,12 +195,11 @@ class HyperConv1D(layers.Layer):
198
  global_z = tf.nn.softmax(global_z, axis=1)
199
  global_z = tf.reduce_sum(h * global_z, axis=1)
200
 
201
- # residual-gate 스타일 scale: 1 + α*tanh(...)
202
- scale = 1.0 + 0.1 * tf.tanh(self.scale_dense(global_z))
203
- out_local = out_local * tf.expand_dims(scale, 1)
204
 
205
  # ------------------------------
206
- # 4) Residual + SiLU + LayerNorm + Dropout
207
  # ------------------------------
208
  out = x_proj + out_local
209
  out = tf.nn.silu(out)
@@ -211,8 +207,7 @@ class HyperConv1D(layers.Layer):
211
  out = self.dropout(out, training=training)
212
 
213
  return tf.cast(out, x_dtype)
214
-
215
-
216
  class L2NormLayer(layers.Layer):
217
  def __init__(self, axis=1, epsilon=1e-10, **kwargs):
218
  super().__init__(**kwargs)
 
138
  # Input projection
139
  self.input_proj = layers.Dense(d_model, name="input_proj")
140
 
141
+ # Dynamic kernel conv
142
+ self.dynamic_dense = layers.Dense(d_model, activation='silu')
 
143
  self.dynamic_proj = layers.Dense(d_model)
144
  self.kernel_generator = layers.Dense(k, dtype='float32')
 
145
 
146
  # Hypernetwork: token-wise transform before pooling
147
  self.hyper = tf.keras.Sequential([
 
150
  ], name="hyper")
151
 
152
  self.attn_pool = layers.Dense(1)
153
+ self.scale_dense = layers.Dense(d_model)
154
 
 
155
  self.norm = layers.LayerNormalization()
156
  self.dropout = layers.Dropout(dropout)
157
 
158
  def call(self, x, training=None):
159
+ x_in = x
160
  x_dtype = x.dtype
161
 
162
  # 1) input projection
 
170
  # ------------------------------
171
  # 2) DynamicConv local mixing
172
  # ------------------------------
173
+ kernels = self.kernel_generator(self.dynamic_dense(x_proj)) # (B, L, k)
174
+ kernels = tf.nn.softmax(kernels, axis=-1)
 
175
 
176
  x_pad = tf.pad(x_proj, [[0,0],[pad,pad],[0,0]])
177
  x_pad_4d = tf.expand_dims(x_pad, axis=1) # (B,1,L+k-1,D)
 
183
  padding='VALID'
184
  )
185
  patches = tf.reshape(patches, [B, L, self.k, D])
186
+ kernels_exp = tf.expand_dims(kernels, axis=-1)
187
  out_local = tf.reduce_sum(patches * kernels_exp, axis=2) # (B,L,D)
188
  out_local = self.dynamic_proj(out_local)
189
 
 
195
  global_z = tf.nn.softmax(global_z, axis=1)
196
  global_z = tf.reduce_sum(h * global_z, axis=1)
197
 
198
+ scale = tf.expand_dims(tf.nn.sigmoid(self.scale_dense(global_z)), 1)
199
+ out_local = out_local * scale
 
200
 
201
  # ------------------------------
202
+ # 4) Residual + SiLU + LayerNorm
203
  # ------------------------------
204
  out = x_proj + out_local
205
  out = tf.nn.silu(out)
 
207
  out = self.dropout(out, training=training)
208
 
209
  return tf.cast(out, x_dtype)
210
+
 
211
  class L2NormLayer(layers.Layer):
212
  def __init__(self, axis=1, epsilon=1e-10, **kwargs):
213
  super().__init__(**kwargs)