OpenLab-NLP committed on
Commit
cebebd8
·
verified ·
1 Parent(s): bd2afe4

Update V2.py

Browse files
Files changed (1) hide show
  1. V2.py +0 -39
V2.py CHANGED
@@ -128,45 +128,6 @@ ds = ds.batch(BATCH_SIZE, drop_remainder=True)
128
  ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
129
  ds = ds.prefetch(tf.data.AUTOTUNE)
130
 
131
class DynamicConvTPU(layers.Layer):
    """Token-wise dynamic 1-D convolution intended to be TPU-friendly.

    For every position in the sequence a small depth-shared kernel of width
    ``k`` is predicted from the token itself, softmax-normalized, and applied
    as a weighted sum over the token's ``k``-wide neighborhood. The weighted
    sum is implemented with a vectorized gather instead of a Python loop.

    Args:
        d_model: Feature dimension of the input/output.
        k: Convolution window size; must be odd so padding is symmetric.
    """

    def __init__(self, d_model, k=7):
        super().__init__()
        assert k % 2 == 1  # odd width => equal left/right padding
        self.k = k
        self.d_model = d_model

        self.dense = layers.Dense(d_model, activation='silu')
        self.proj = layers.Dense(d_model)
        # Kernel generator kept in float32 for numerically stable softmax
        # under mixed precision.
        self.generator = layers.Dense(k, dtype='float32')

    def call(self, x):
        """Apply the dynamic convolution.

        Args:
            x: Tensor of shape (batch, seq_len, d_model) — assumed rank 3
               from the indexing below; dtype may be reduced precision.

        Returns:
            Tensor of shape (batch, seq_len, d_model) in ``x``'s dtype.
        """
        x_in = x
        x = tf.cast(x, tf.float32)
        B, L = tf.shape(x)[0], tf.shape(x)[1]

        # 1) Generate a token-wise kernel, normalized over the window axis.
        kernels = self.generator(self.dense(x))        # (B, L, k)
        kernels = tf.nn.softmax(kernels, axis=-1)
        kernels_exp = tf.expand_dims(kernels, axis=-1)  # (B, L, k, 1)

        # 2) Pad, then build all shifted patches at once (vectorized).
        pad = (self.k - 1) // 2
        x_pad = tf.pad(x, [[0, 0], [pad, pad], [0, 0]])  # (B, L+k-1, D)

        # BUG FIX: the original built idx with shape (L, k, 1) and called
        # tf.broadcast_to(idx, [B, L, k]) directly — (k, 1) cannot broadcast
        # to (L, k) unless k == L, so call() raised for general lengths.
        # Build a rank-2 (L, k) index and broadcast it over the batch.
        # (The original also computed an unused batch_idx and added a
        # redundant tf.zeros tensor; both removed.)
        idx = tf.range(self.k)[None, :] + tf.range(L)[:, None]  # (L, k)
        idx = tf.broadcast_to(idx[None, :, :], [B, L, self.k])  # (B, L, k)

        # batch_dims=1 gathers per-example along the sequence axis.
        patches = tf.gather(x_pad, idx, axis=1, batch_dims=1)  # (B, L, k, D)

        # 3) Token-wise weighted sum over the window, then output projection.
        out = tf.reduce_sum(patches * kernels_exp, axis=2)  # (B, L, D)
        out = self.proj(out)

        return tf.cast(out, x_in.dtype)
170
 
171
  class HyperConv1D(layers.Layer):
172
  def __init__(self, d_model, k=7, mem_size=64, hyper_dim=128, dropout=0.0):
 
128
  ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
129
  ds = ds.prefetch(tf.data.AUTOTUNE)
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  class HyperConv1D(layers.Layer):
133
  def __init__(self, d_model, k=7, mem_size=64, hyper_dim=128, dropout=0.0):