Update V2.py
Browse files
V2.py
CHANGED
|
@@ -138,12 +138,10 @@ class HyperConv1D(layers.Layer):
|
|
| 138 |
# Input projection
|
| 139 |
self.input_proj = layers.Dense(d_model, name="input_proj")
|
| 140 |
|
| 141 |
-
# Dynamic kernel conv
|
| 142 |
-
self.
|
| 143 |
-
self.dynamic_dense = layers.Dense(self.d_mid, activation='silu')
|
| 144 |
self.dynamic_proj = layers.Dense(d_model)
|
| 145 |
self.kernel_generator = layers.Dense(k, dtype='float32')
|
| 146 |
-
self.kernel_temp = self.add_weight("kernel_temp", shape=(), initializer=tf.constant_initializer(1.0), trainable=True)
|
| 147 |
|
| 148 |
# Hypernetwork: token-wise transform before pooling
|
| 149 |
self.hyper = tf.keras.Sequential([
|
|
@@ -152,13 +150,13 @@ class HyperConv1D(layers.Layer):
|
|
| 152 |
], name="hyper")
|
| 153 |
|
| 154 |
self.attn_pool = layers.Dense(1)
|
| 155 |
-
self.scale_dense = layers.Dense(d_model
|
| 156 |
|
| 157 |
-
# Normalization + dropout
|
| 158 |
self.norm = layers.LayerNormalization()
|
| 159 |
self.dropout = layers.Dropout(dropout)
|
| 160 |
|
| 161 |
def call(self, x, training=None):
|
|
|
|
| 162 |
x_dtype = x.dtype
|
| 163 |
|
| 164 |
# 1) input projection
|
|
@@ -172,9 +170,8 @@ class HyperConv1D(layers.Layer):
|
|
| 172 |
# ------------------------------
|
| 173 |
# 2) DynamicConv local mixing
|
| 174 |
# ------------------------------
|
| 175 |
-
kernels = self.kernel_generator(self.dynamic_dense(x_proj)) # (B,L,k)
|
| 176 |
-
kernels = tf.
|
| 177 |
-
kernels = tf.nn.softmax(kernels / tf.maximum(self.kernel_temp, 1e-6), axis=-1)
|
| 178 |
|
| 179 |
x_pad = tf.pad(x_proj, [[0,0],[pad,pad],[0,0]])
|
| 180 |
x_pad_4d = tf.expand_dims(x_pad, axis=1) # (B,1,L+k-1,D)
|
|
@@ -186,7 +183,7 @@ class HyperConv1D(layers.Layer):
|
|
| 186 |
padding='VALID'
|
| 187 |
)
|
| 188 |
patches = tf.reshape(patches, [B, L, self.k, D])
|
| 189 |
-
kernels_exp = tf.
|
| 190 |
out_local = tf.reduce_sum(patches * kernels_exp, axis=2) # (B,L,D)
|
| 191 |
out_local = self.dynamic_proj(out_local)
|
| 192 |
|
|
@@ -198,12 +195,11 @@ class HyperConv1D(layers.Layer):
|
|
| 198 |
global_z = tf.nn.softmax(global_z, axis=1)
|
| 199 |
global_z = tf.reduce_sum(h * global_z, axis=1)
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
out_local = out_local * tf.expand_dims(scale, 1)
|
| 204 |
|
| 205 |
# ------------------------------
|
| 206 |
-
# 4) Residual + SiLU + LayerNorm
|
| 207 |
# ------------------------------
|
| 208 |
out = x_proj + out_local
|
| 209 |
out = tf.nn.silu(out)
|
|
@@ -211,8 +207,7 @@ class HyperConv1D(layers.Layer):
|
|
| 211 |
out = self.dropout(out, training=training)
|
| 212 |
|
| 213 |
return tf.cast(out, x_dtype)
|
| 214 |
-
|
| 215 |
-
|
| 216 |
class L2NormLayer(layers.Layer):
|
| 217 |
def __init__(self, axis=1, epsilon=1e-10, **kwargs):
|
| 218 |
super().__init__(**kwargs)
|
|
|
|
| 138 |
# Input projection
|
| 139 |
self.input_proj = layers.Dense(d_model, name="input_proj")
|
| 140 |
|
| 141 |
+
# Dynamic kernel conv
|
| 142 |
+
self.dynamic_dense = layers.Dense(d_model, activation='silu')
|
|
|
|
| 143 |
self.dynamic_proj = layers.Dense(d_model)
|
| 144 |
self.kernel_generator = layers.Dense(k, dtype='float32')
|
|
|
|
| 145 |
|
| 146 |
# Hypernetwork: token-wise transform before pooling
|
| 147 |
self.hyper = tf.keras.Sequential([
|
|
|
|
| 150 |
], name="hyper")
|
| 151 |
|
| 152 |
self.attn_pool = layers.Dense(1)
|
| 153 |
+
self.scale_dense = layers.Dense(d_model)
|
| 154 |
|
|
|
|
| 155 |
self.norm = layers.LayerNormalization()
|
| 156 |
self.dropout = layers.Dropout(dropout)
|
| 157 |
|
| 158 |
def call(self, x, training=None):
|
| 159 |
+
x_in = x
|
| 160 |
x_dtype = x.dtype
|
| 161 |
|
| 162 |
# 1) input projection
|
|
|
|
| 170 |
# ------------------------------
|
| 171 |
# 2) DynamicConv local mixing
|
| 172 |
# ------------------------------
|
| 173 |
+
kernels = self.kernel_generator(self.dynamic_dense(x_proj)) # (B, L, k)
|
| 174 |
+
kernels = tf.nn.softmax(kernels, axis=-1)
|
|
|
|
| 175 |
|
| 176 |
x_pad = tf.pad(x_proj, [[0,0],[pad,pad],[0,0]])
|
| 177 |
x_pad_4d = tf.expand_dims(x_pad, axis=1) # (B,1,L+k-1,D)
|
|
|
|
| 183 |
padding='VALID'
|
| 184 |
)
|
| 185 |
patches = tf.reshape(patches, [B, L, self.k, D])
|
| 186 |
+
kernels_exp = tf.expand_dims(kernels, axis=-1)
|
| 187 |
out_local = tf.reduce_sum(patches * kernels_exp, axis=2) # (B,L,D)
|
| 188 |
out_local = self.dynamic_proj(out_local)
|
| 189 |
|
|
|
|
| 195 |
global_z = tf.nn.softmax(global_z, axis=1)
|
| 196 |
global_z = tf.reduce_sum(h * global_z, axis=1)
|
| 197 |
|
| 198 |
+
scale = tf.expand_dims(tf.nn.sigmoid(self.scale_dense(global_z)), 1)
|
| 199 |
+
out_local = out_local * scale
|
|
|
|
| 200 |
|
| 201 |
# ------------------------------
|
| 202 |
+
# 4) Residual + SiLU + LayerNorm
|
| 203 |
# ------------------------------
|
| 204 |
out = x_proj + out_local
|
| 205 |
out = tf.nn.silu(out)
|
|
|
|
| 207 |
out = self.dropout(out, training=training)
|
| 208 |
|
| 209 |
return tf.cast(out, x_dtype)
|
| 210 |
+
|
|
|
|
| 211 |
class L2NormLayer(layers.Layer):
|
| 212 |
def __init__(self, axis=1, epsilon=1e-10, **kwargs):
|
| 213 |
super().__init__(**kwargs)
|