Update V2.py
V2.py CHANGED
@@ -128,45 +128,6 @@ ds = ds.batch(BATCH_SIZE, drop_remainder=True)
 ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
 ds = ds.prefetch(tf.data.AUTOTUNE)
 
-class DynamicConvTPU(layers.Layer):
-    def __init__(self, d_model, k=7):
-        super().__init__()
-        assert k % 2 == 1
-        self.k = k
-        self.d_model = d_model
-
-        self.dense = layers.Dense(d_model, activation='silu')
-        self.proj = layers.Dense(d_model)
-        self.generator = layers.Dense(k, dtype='float32')
-
-    def call(self, x):
-        x_in = x
-        x = tf.cast(x, tf.float32)
-        B, L, D = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
-
-        # 1) generate token-wise kernels
-        kernels = self.generator(self.dense(x))  # (B, L, k)
-        kernels = tf.nn.softmax(kernels, axis=-1)
-        kernels_exp = tf.expand_dims(kernels, axis=-1)  # (B, L, k, 1)
-
-        # 2) padding and shifted-patch construction (vectorized)
-        pad = (self.k - 1) // 2
-        x_pad = tf.pad(x, [[0,0],[pad,pad],[0,0]])  # (B, L+k-1, D)
-
-        # build all shifted patches at once
-        idx = tf.range(self.k)[None, :, None] + tf.range(L)[:, None, None]  # (L, k, 1)
-        idx = tf.broadcast_to(idx, [B, L, self.k]) + tf.zeros([B, L, self.k], dtype=tf.int32)  # (B,L,k)
-        batch_idx = tf.reshape(tf.range(B)[:, None, None], [B,1,1])
-        batch_idx = tf.broadcast_to(batch_idx, [B,L,self.k])
-
-        patches = tf.gather(x_pad, idx, axis=1, batch_dims=1)  # (B, L, k, D)
-
-        # 3) token-wise weighted sum
-        out = tf.reduce_sum(patches * kernels_exp, axis=2)  # (B, L, D)
-        out = self.proj(out)
-
-        return tf.cast(out, x_in.dtype)
-
 
 class HyperConv1D(layers.Layer):
     def __init__(self, d_model, k=7, mem_size=64, hyper_dim=128, dropout=0.0):
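For context, the removed DynamicConvTPU computed a token-wise dynamic convolution: each token predicts its own softmax-normalized k-tap kernel, and the output is the weighted sum of a k-wide window centered on that token. Below is a minimal standalone sketch of that idea, not the committed code: the name DynamicConvSketch is made up here, TF 2.x is assumed, and the window indexing is simplified to a single tf.gather instead of the original's broadcast batch indices.

import tensorflow as tf
from tensorflow.keras import layers

class DynamicConvSketch(layers.Layer):
    """Token-wise dynamic 1D convolution: each position predicts its own k-tap kernel."""
    def __init__(self, d_model, k=7):
        super().__init__()
        assert k % 2 == 1  # odd kernel so 'same' padding is symmetric
        self.k = k
        self.dense = layers.Dense(d_model, activation='silu')
        self.proj = layers.Dense(d_model)
        self.generator = layers.Dense(k)

    def call(self, x):                                   # x: (B, L, D)
        L = tf.shape(x)[1]
        # Per-token kernels, normalized over the k taps: (B, L, k)
        kernels = tf.nn.softmax(self.generator(self.dense(x)), axis=-1)
        pad = (self.k - 1) // 2
        x_pad = tf.pad(x, [[0, 0], [pad, pad], [0, 0]])  # (B, L+k-1, D)
        # Window indices for every output position: (L, k)
        idx = tf.range(L)[:, None] + tf.range(self.k)[None, :]
        patches = tf.gather(x_pad, idx, axis=1)          # (B, L, k, D)
        # Weighted sum over the k taps with each token's own kernel
        return self.proj(tf.reduce_sum(patches * kernels[..., None], axis=2))

# Quick shape check on random data:
# y = DynamicConvSketch(d_model=32)(tf.random.normal([2, 16, 32]))  # -> (2, 16, 32)

Gathering with a 2-D idx on axis=1 broadcasts over the batch dimension, so no explicit batch_idx tensor is needed.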