OpenLab-NLP committed
Commit 7f8fd1d · verified · 1 Parent(s): 67804dc

Update Test.py

Files changed (1):
  Test.py +31 -14
Test.py CHANGED
@@ -66,13 +66,18 @@ dataset = tf.data.Dataset.from_generator(
 class EncoderBlock(tf.keras.layers.Layer):
     def __init__(self, embed_dim=EMBED_DIM, ff_dim=1152, seq_len=MAX_LEN):
         super().__init__()
+        self.embed_dim = embed_dim
+        self.seq_len = seq_len
+
         self.fc1 = layers.Dense(ff_dim)
         self.fc2 = layers.Dense(embed_dim)
         self.fc3 = layers.Dense(ff_dim)
         self.fc4 = layers.Dense(embed_dim)
 
+        # defined as (seq_len, embed_dim) -- used for the (L -> D) projection
         self.w_proj = self.add_weight(
-            shape=(embed_dim, embed_dim),
+            name="w_proj_L_to_D",
+            shape=(seq_len, embed_dim),
             initializer="glorot_uniform",
             trainable=True
         )
@@ -82,26 +87,38 @@ class EncoderBlock(tf.keras.layers.Layer):
         self.ln = layers.LayerNormalization(epsilon=1e-5)
         self.ln1 = layers.LayerNormalization(epsilon=1e-5)
         self.ln2 = layers.LayerNormalization(epsilon=1e-5)
 
     def call(self, x):
+        # x: (B, L, D)
         x_norm = self.ln(x)
-        x = self.fc1(x_norm)
-        g, v = tf.split(x, 2, axis=-1)
-        x = tf.nn.silu(g) * v
-        x = self.fc2(x)
 
-        x = tf.matmul(x, x, transpose_b=True)  # (B,L,L)
-        x = tf.tensordot(x, self.w_proj, axes=[-1, 0])  # (B,L,D)
-
-        v = tf.nn.softmax(self.alpha2(v), axis=1) * x
+        h = self.fc1(x_norm)            # (B, L, ff_dim)
+        g, v = tf.split(h, 2, axis=-1)  # (B, L, ff_dim/2) each
+        h = tf.nn.silu(g) * v
+        h = self.fc2(h)                 # (B, L, D)
+
+        # --- matmul -> (B, L, L) ---
+        sim = tf.matmul(h, h, transpose_b=True)  # (B, L, L)
+        # (optional) add normalization/scaling here if desired
+        sim = tf.nn.softmax(sim, axis=-1)        # (B, L, L)
+
+        # --- (B, L, L) -> (B, L, D): project with the tensordot axes aligned ---
+        # w_proj: (L, D); sim's last axis matches w_proj's first axis
+        h2 = tf.tensordot(sim, self.w_proj, axes=[[2], [0]])  # (B, L, D)
+
+        # shapes now match -- element-wise multiplication with v works
+        v_gate = tf.nn.softmax(self.alpha2(v), axis=1)  # (B, L, 1)
+        v = v_gate * h2                                  # (B, L, D)
 
         x_norm = x_norm + self.ln2(v)
 
-        x = self.fc3(x_norm)
-        g, v = tf.split(x, 2, axis=-1)
-        x = tf.nn.silu(g) * v
-        x = self.fc4(x)
-
-        return x_norm + self.ln1(x)
+        z = self.fc3(x_norm)
+        g, v = tf.split(z, 2, axis=-1)
+        z = tf.nn.silu(g) * v
+        z = self.fc4(z)
+
+        return x_norm + self.ln1(z)
 
 
 class L2NormLayer(layers.Layer):
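
A minimal, standalone sketch of the new (L -> D) projection step, using illustrative sizes B=2, L=48, D=64 rather than the actual MAX_LEN / EMBED_DIM from Test.py; the gating layer self.alpha2 and the rest of the block are omitted:

import tensorflow as tf

B, L, D = 2, 48, 64                    # illustrative batch / seq_len / embed_dim, not the repo's values
h = tf.random.normal((B, L, D))        # stands in for the fc2 output
w_proj = tf.random.normal((L, D))      # new weight shape: (seq_len, embed_dim)

sim = tf.nn.softmax(tf.matmul(h, h, transpose_b=True), axis=-1)  # (B, L, L)
out = tf.tensordot(sim, w_proj, axes=[[2], [0]])                 # (B, L, D)
print(out.shape)                                                 # (2, 48, 64)

# The previous weight was (embed_dim, embed_dim); contracting sim's last axis
# (length L) against a first axis of length D only lines up when L == D,
# which is why the old tensordot call broke whenever MAX_LEN != EMBED_DIM.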