Update V2.py
Browse files
V2.py
CHANGED
|
@@ -128,57 +128,87 @@ ds = ds.batch(BATCH_SIZE, drop_remainder=True)
|
|
| 128 |
ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
|
| 129 |
ds = ds.prefetch(tf.data.AUTOTUNE)
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
def __init__(self, d_model, k=7, mem_size=64, hyper_dim=128, dropout=0.0):
|
| 134 |
super().__init__()
|
| 135 |
assert k % 2 == 1
|
| 136 |
self.k = k
|
| 137 |
self.d_model = d_model
|
| 138 |
-
self.mem_size = mem_size
|
| 139 |
|
| 140 |
# Input projection
|
| 141 |
self.input_proj = layers.Dense(d_model, name="input_proj")
|
| 142 |
|
| 143 |
-
#
|
| 144 |
-
self.
|
| 145 |
-
self.
|
|
|
|
| 146 |
|
| 147 |
-
# Hypernetwork:
|
| 148 |
self.hyper = tf.keras.Sequential([
|
| 149 |
layers.Dense(hyper_dim, activation='gelu'),
|
| 150 |
layers.Dense(d_model)
|
| 151 |
], name="hyper")
|
| 152 |
|
| 153 |
-
|
| 154 |
-
self.norm = layers.LayerNormalization()
|
| 155 |
self.attn_pool = layers.Dense(1)
|
| 156 |
self.scale_dense = layers.Dense(d_model)
|
| 157 |
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
| 159 |
x_in = x
|
| 160 |
x_dtype = x.dtype
|
| 161 |
|
| 162 |
# 1) input projection
|
| 163 |
x_proj = self.input_proj(x)
|
| 164 |
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
h = self.hyper(x_proj)
|
| 169 |
global_z = self.attn_pool(h)
|
| 170 |
global_z = tf.nn.softmax(global_z, axis=1)
|
| 171 |
-
global_z = tf.reduce_sum(
|
| 172 |
|
| 173 |
scale = tf.expand_dims(tf.nn.sigmoid(self.scale_dense(global_z)), 1)
|
| 174 |
out_local = out_local * scale
|
| 175 |
-
out_local = self.local_proj(out_local)
|
| 176 |
|
|
|
|
|
|
|
|
|
|
| 177 |
out = x_proj + out_local
|
| 178 |
out = tf.nn.silu(out)
|
| 179 |
out = self.norm(out)
|
|
|
|
|
|
|
| 180 |
return tf.cast(out, x_dtype)
|
| 181 |
|
|
|
|
| 182 |
class L2NormLayer(layers.Layer):
|
| 183 |
def __init__(self, axis=1, epsilon=1e-10, **kwargs):
|
| 184 |
super().__init__(**kwargs)
|
|
|
|
# Attach a dummy all-zeros label to each (view1, view2) pair — the loss is
# computed from the pair itself, so the label is a placeholder — then let
# tf.data overlap preprocessing with training.
ds = ds.map(
    lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)),
    num_parallel_calls=tf.data.AUTOTUNE,
).prefetch(tf.data.AUTOTUNE)
class HyperDynamicConv1D(layers.Layer):
    """Dynamic 1-D convolution with per-token kernels and a global hyper gate.

    Pipeline (see ``call``):
      1) Dense projection of the input to ``d_model`` channels.
      2) DynamicConv-style local mixing: every token generates its own
         softmax-normalized kernel of width ``k``, applied over its k-wide
         neighborhood (neighborhoods gathered via ``tf.image.extract_patches``).
      3) Hyper scaling: a hypernetwork plus attention pooling yield one global
         vector per example, turned into a sigmoid channel-wise gate on the
         local branch.
      4) Residual add, SiLU, LayerNorm, dropout.

    Args:
        d_model: working/output channel width.
        k: odd kernel width; odd so symmetric padding preserves sequence length.
        hyper_dim: hidden width of the hypernetwork MLP.
        dropout: dropout rate applied after the final LayerNorm.
        **kwargs: forwarded to ``layers.Layer`` (``name``, ``dtype``, ...).

    Raises:
        ValueError: if ``k`` is even.
    """

    def __init__(self, d_model, k=7, hyper_dim=128, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        # Real exception rather than `assert`: asserts are stripped under -O.
        if k % 2 != 1:
            raise ValueError(f"k must be odd, got {k}")
        self.k = k
        self.d_model = d_model
        # Keep constructor args so get_config() can round-trip the layer.
        self.hyper_dim = hyper_dim
        self.dropout_rate = dropout

        # Input projection
        self.input_proj = layers.Dense(d_model, name="input_proj")

        # Dynamic kernel conv
        self.dynamic_dense = layers.Dense(d_model, activation='silu')
        self.dynamic_proj = layers.Dense(d_model)
        # Kernel logits pinned to float32 so the softmax stays numerically
        # stable even if the rest of the model runs in lower precision.
        self.kernel_generator = layers.Dense(k, dtype='float32')

        # Hypernetwork: token-wise transform before pooling
        self.hyper = tf.keras.Sequential([
            layers.Dense(hyper_dim, activation='gelu'),
            layers.Dense(d_model)
        ], name="hyper")

        self.attn_pool = layers.Dense(1)
        self.scale_dense = layers.Dense(d_model)

        self.norm = layers.LayerNormalization()
        self.dropout = layers.Dropout(dropout)

    def get_config(self):
        """Return a serializable config (enables model.save / clone_model)."""
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "k": self.k,
            "hyper_dim": self.hyper_dim,
            "dropout": self.dropout_rate,
        })
        return config

    def call(self, x, training=None):
        # assumes x is (batch, length, features) — TODO confirm with callers
        x_dtype = x.dtype

        # 1) input projection
        x_proj = self.input_proj(x)

        B = tf.shape(x_proj)[0]
        L = tf.shape(x_proj)[1]
        D = self.d_model
        pad = (self.k - 1) // 2  # symmetric "same" padding for odd k

        # ------------------------------
        # 2) DynamicConv local mixing
        # ------------------------------
        # Each token produces its own k-tap kernel, normalized via softmax.
        kernels = self.kernel_generator(self.dynamic_dense(x_proj))  # (B, L, k)
        kernels = tf.nn.softmax(kernels, axis=-1)
        # NOTE(review): kernel_generator outputs float32; under mixed precision
        # `patches * kernels_exp` below would mix dtypes — confirm the model
        # keeps x_proj in float32 or add an explicit cast.

        # Gather every k-wide neighborhood as a patch tensor.
        x_pad = tf.pad(x_proj, [[0, 0], [pad, pad], [0, 0]])
        x_pad_4d = tf.expand_dims(x_pad, axis=1)  # (B, 1, L+k-1, D)
        patches = tf.image.extract_patches(
            images=x_pad_4d,
            sizes=[1, 1, self.k, 1],
            strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        # extract_patches emits patches as k positions x D channels, so this
        # reshape recovers the (position, channel) layout.
        patches = tf.reshape(patches, [B, L, self.k, D])
        kernels_exp = tf.expand_dims(kernels, axis=-1)            # (B, L, k, 1)
        out_local = tf.reduce_sum(patches * kernels_exp, axis=2)  # (B, L, D)
        out_local = self.dynamic_proj(out_local)

        # ------------------------------
        # 3) Hyper scaling
        # ------------------------------
        # Attention-pool the hypernetwork features into one global vector,
        # then gate the local branch channel-wise with a sigmoid scale.
        h = self.hyper(x_proj)
        global_z = self.attn_pool(h)                    # (B, L, 1) logits
        global_z = tf.nn.softmax(global_z, axis=1)      # weights over sequence
        global_z = tf.reduce_sum(h * global_z, axis=1)  # (B, D)

        scale = tf.expand_dims(tf.nn.sigmoid(self.scale_dense(global_z)), 1)
        out_local = out_local * scale

        # ------------------------------
        # 4) Residual + SiLU + LayerNorm
        # ------------------------------
        out = x_proj + out_local
        out = tf.nn.silu(out)
        out = self.norm(out)
        out = self.dropout(out, training=training)

        return tf.cast(out, x_dtype)
| 211 |
+
|
| 212 |
class L2NormLayer(layers.Layer):
|
| 213 |
def __init__(self, axis=1, epsilon=1e-10, **kwargs):
|
| 214 |
super().__init__(**kwargs)
|