# HyperConv-Layer / 이외.py
# Uploaded by OpenLab-NLP (commit ab4d9a8, verified) — MixerBlock definition.
class MixerBlock(layers.Layer):
    """A single MLP-Mixer block: token-mixing MLP then channel-mixing MLP.

    Each sub-block is pre-LayerNorm with a residual connection, following
    the standard Mixer layout. Inputs are expected to have shape
    ``(batch, seq_len, dim)``; the runtime sequence length must equal the
    ``seq_len`` given at construction, because ``token_fc2`` projects the
    token axis back to exactly ``seq_len``.

    Args:
        seq_len: number of tokens L in the input sequence.
        dim: channel dimension D of each token.
        token_mlp_dim: hidden width of the token-mixing MLP.
        channel_mlp_dim: hidden width of the channel-mixing MLP.
        dropout: dropout rate applied to the output of both branches.
    """

    def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0):
        super().__init__()
        self.seq_len = seq_len
        self.dim = dim
        self.token_mlp_dim = token_mlp_dim
        self.channel_mlp_dim = channel_mlp_dim
        self.ln1 = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
        # Token-mixing MLP: Dense layers act on the last axis, so `call`
        # transposes to (B, D, L) before applying them across tokens.
        self.token_fc1 = layers.Dense(token_mlp_dim, activation='gelu', dtype=tf.float32)
        self.token_fc2 = layers.Dense(seq_len, dtype=tf.float32)
        self.ln2 = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
        # Channel-mixing MLP: applied independently to each token's channels.
        self.channel_fc1 = layers.Dense(channel_mlp_dim, activation='gelu', dtype=tf.float32)
        self.channel_fc2 = layers.Dense(dim, dtype=tf.float32)
        # One Dropout layer shared by both residual branches (stateless, so
        # sharing is safe).
        self.dropout = layers.Dropout(dropout)

    def call(self, x, training=None):
        """Apply token-mixing then channel-mixing, each with a residual add.

        Args:
            x: tensor of shape (batch, seq_len, dim).
            training: forwarded to Dropout (enables dropout when True).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # --- Token-mixing: mix information across the token axis ---
        y = self.ln1(x)                          # (B, L, D)
        y = tf.transpose(y, perm=[0, 2, 1])      # (B, D, L)
        y = self.token_fc1(y)                    # (B, D, token_mlp_dim)
        y = self.token_fc2(y)                    # (B, D, L)
        y = tf.transpose(y, perm=[0, 2, 1])      # (B, L, D)
        x = x + self.dropout(y, training=training)
        # --- Channel-mixing: mix information within each token ---
        z = self.ln2(x)
        z = self.channel_fc1(z)
        z = self.channel_fc2(z)
        return x + self.dropout(z, training=training)

    def get_config(self):
        """Return constructor kwargs so the layer is Keras-serializable."""
        config = super().get_config()
        config.update({
            'seq_len': self.seq_len,
            'dim': self.dim,
            'token_mlp_dim': self.token_mlp_dim,
            'channel_mlp_dim': self.channel_mlp_dim,
            'dropout': self.dropout.rate,
        })
        return config