Yuchan committed on
Update Model.py

Model.py CHANGED
@@ -120,13 +120,22 @@ dataset = dataset.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
 
 print("✅ TF Dataset creation complete!")
 
-class
+class Lo(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
-
-        self.
+        # keep internal computation in float32
+        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
+        self.p = layers.Dense(96, use_bias=True, dtype='float32')
+        self._out_dtype = 'float32'
+
     def call(self, x):
-
+        # x may be bfloat16; cast to float32 for stable intermediate computation
+        x_f32 = tf.cast(x, tf.float32)
+        x = self.proj(x_f32)
+        x = tf.nn.gelu(x)
+        x = self.p(x)
+        # cast back to model dtype for consistency
+        return tf.cast(x, self._out_dtype)
 
 class LoSoU(layers.Layer):
     """
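The new Lo block is a small two-layer MLP (Dense → GELU → Dense) that always computes in float32 and returns float32 regardless of input dtype. A quick standalone smoke test of the block as committed (input shape here is illustrative); note the output width is fixed at 96, not d_model:

import tensorflow as tf
from tensorflow.keras import layers

class Lo(layers.Layer):
    def __init__(self, d_model):
        super().__init__()
        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
        self.p = layers.Dense(96, use_bias=True, dtype='float32')
        self._out_dtype = 'float32'

    def call(self, x):
        x_f32 = tf.cast(x, tf.float32)       # bfloat16-safe entry cast
        h = tf.nn.gelu(self.proj(x_f32))     # Dense -> GELU
        return tf.cast(self.p(h), self._out_dtype)

y = Lo(128)(tf.zeros([2, 8, 128], tf.bfloat16))
print(y.shape, y.dtype)  # (2, 8, 96) float32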
@@ -145,12 +154,12 @@ class LoSoU(layers.Layer):
         self.eps = float(eps)
 
         # projection / gating layers in float32
-        self.Q = layers.Dense(
-        self.K = layers.Dense(
-        self.
+        self.Q = layers.Dense(96, dtype='float32')
+        self.K = layers.Dense(96, dtype='float32')
+        self.V = Lo(d_model)  # Lo already casts its output back to float32
+        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
+        self.O = layers.Dense(d_model, dtype='float32')
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
-        self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
-        self.norm2 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
         # layers for computing the dynamic alpha
         # alpha must stay in [0, 1], so a sigmoid is used
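Worth noting why these sub-layers pin dtype='float32': under a Keras mixed-precision policy, layers otherwise default to bfloat16 compute, and the explicit override keeps the gating and normalization math in float32. A minimal sketch of the mechanism (the mixed_bfloat16 policy is an assumption about the surrounding training setup, which this diff does not show):

import tensorflow as tf
from tensorflow.keras import layers

tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
d = layers.Dense(96)                      # follows the global policy
d32 = layers.Dense(96, dtype='float32')   # explicit float32 override
x = tf.random.normal([2, 8, 128])
print(d(x).dtype, d32(x).dtype)           # bfloat16 vs float32
tf.keras.mixed_precision.set_global_policy('float32')  # reset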
@@ -199,22 +208,22 @@ class LoSoU(layers.Layer):
         # x: (B, L, d_model), maybe bfloat16 or float32
         # cast to float32 for all internal computations
         x_f32 = tf.cast(x, tf.float32)
-        x_f32 = self.norm2(x_f32)
         residual = x_f32
 
         # Q, K, V
-        q = self.Q(x_f32)
-        k = self.K(x_f32)
+        q = self.Q(x_f32)  # (B, L, 96)
+        k = self.K(x_f32)  # (B, L, 96)
+        V = tf.cast(self.V(x), tf.float32)  # ensure V's output is float32
 
         # gating signals in (0, 1)
         g_q = tf.nn.sigmoid(q)
-        g_k = tf.nn.
+        g_k = tf.nn.sigmoid(k)
 
         # elementwise product -> bounded roughly [0, 1]
         score = g_q * g_k
 
         # dynamic alpha: (B, L, d_model) -> (B, L, 1)
-        alpha_dynamic = self.alpha_linear(x_f32)  # (B, L, 1)
+        alpha_dynamic = self.alpha_linear(x_f32) * 0.8 + 0.1  # (B, L, 1)
         # optional post-processing of alpha_dynamic (e.g. min/max clipping)
         # ex: alpha_dynamic = tf.clip_by_value(alpha_dynamic, 0.01, 0.99)
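The attention-free update scores tokens with an element-wise product of two sigmoid gates, so every score entry lies in (0, 1), and the new * 0.8 + 0.1 affine keeps the dynamic alpha inside (0.1, 0.9) instead of letting it saturate at the sigmoid's endpoints. A standalone check of both bounds (layer widths follow the diff; alpha_linear as Dense(1, sigmoid) is an assumption based on the comments above):

import tensorflow as tf
from tensorflow.keras import layers

Q = layers.Dense(96, dtype='float32')
K = layers.Dense(96, dtype='float32')
alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')  # assumed form

x = tf.random.normal([2, 8, 128])
score = tf.nn.sigmoid(Q(x)) * tf.nn.sigmoid(K(x))   # entries in (0, 1)
alpha = alpha_linear(x) * 0.8 + 0.1                 # entries in (0.1, 0.9)
print(float(tf.reduce_max(score)) < 1.0,
      0.1 < float(tf.reduce_min(alpha)) < float(tf.reduce_max(alpha)) < 0.9)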
@@ -230,9 +239,20 @@ class LoSoU(layers.Layer):
         score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
 
         # combine with V
-        x_comb =
-
-        out = self.
+        x_comb = score_clipped * V  # (B, L, 96)
+
+        out = self.proj(x_comb)  # (B, L, d_model)
+
+        # ensure the feature dim is even before the two-way split
+        d = out.shape[-1]  # static shape, an int
+        if d is not None and d % 2 == 1:
+            out = tf.pad(out, [[0, 0], [0, 0], [0, 1]])
+
+        a, b = tf.split(out, 2, axis=-1)
+        gated = tf.nn.silu(a) * b
+        out = self.O(gated)
+
+        out = self.norm(out + residual)
 
         # cast back to original dtype for downstream layers
         return tf.cast(out, x.dtype)
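The new combine step is a GLU-style gate: the projected features are split in half and one half, passed through SiLU, gates the other; the padding guard makes the feature dimension even so tf.split cannot fail. A minimal check of just that mechanism, with a deliberately odd width:

import tensorflow as tf

out = tf.random.normal([2, 4, 7])                 # odd last dim
d = out.shape[-1]                                 # static int
if d is not None and d % 2 == 1:
    out = tf.pad(out, [[0, 0], [0, 0], [0, 1]])   # pad to even width
a, b = tf.split(out, 2, axis=-1)
gated = tf.nn.silu(a) * b                         # silu(a) gates b
print(gated.shape)                                # (2, 4, 4)

With d_model = 256 as instantiated below, the pad is a no-op, but it keeps the layer safe for odd widths.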
@@ -251,8 +271,10 @@ class ReLaM(tf.keras.Model):
     def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
         super().__init__()
         self.token_embedding = layers.Embedding(vocab_size, 128)
-        self.pos_embedding = layers.Embedding(max_seq_len,
+        self.pos_embedding = layers.Embedding(max_seq_len, d_model)
         self.blocks = [Block(d_model, hyper_n=1) for _ in range(n_layers)]
+
+        # LayerNormalization in float32 to avoid precision issues
         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
 
     def call(self, x, training=False):
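One thing to watch in this hunk: token_embedding stays hard-coded to width 128 while the restored pos_embedding uses d_model, so the addition in call only broadcasts when d_model == 128. A minimal shape check of that sum (sizes here are hypothetical):

import tensorflow as tf
from tensorflow.keras import layers

vocab_size, max_seq_len, d_model = 1000, 64, 128  # works only if d_model == 128
tok = layers.Embedding(vocab_size, 128)
pos = layers.Embedding(max_seq_len, d_model)
ids = tf.zeros([2, 64], tf.int32)
positions = tf.range(64)
x = tok(ids) + pos(positions)   # (2, 64, 128); raises if the widths differ
print(x.shape)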
@@ -262,6 +284,7 @@ class ReLaM(tf.keras.Model):
         x = self.token_embedding(x) + self.pos_embedding(positions)
         for block in self.blocks:
             x = block(x)
+
         x = self.ln_f(x)
 
         embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
@@ -294,7 +317,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
 model = ReLaM(
     vocab_size=vocab_size,
     max_seq_len=max_len,
-    d_model=
+    d_model=256,
     n_layers=1
 )
 
@@ -363,4 +386,4 @@ def generate_text_topp(model, prompt, max_len=100, max_gen=98, p=0.9, temperature
     return ids_to_text(generated)
 
 print("\n\n===== Generation results =====")
-print(generate_text_topp(model, "
+print(generate_text_topp(model, "안녕", p=0.9))
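For context on the restored call: generate_text_topp presumably applies nucleus (top-p) filtering before sampling, but its body is not shown in this diff, so the following is only a generic sketch of that filtering step, not the repository's code:

import tensorflow as tf

def top_p_filter(logits, p=0.9):
    # Keep the smallest set of tokens whose cumulative probability reaches p,
    # zero out the rest, then renormalize. (Illustrative helper, not from the repo.)
    probs = tf.nn.softmax(logits, axis=-1)
    sorted_idx = tf.argsort(probs, direction='DESCENDING', axis=-1)
    sorted_probs = tf.gather(probs, sorted_idx, batch_dims=1)
    cum = tf.cumsum(sorted_probs, axis=-1)
    keep = cum - sorted_probs < p                 # the top token is always kept
    filtered = tf.where(keep, sorted_probs, 0.0)
    filtered /= tf.reduce_sum(filtered, axis=-1, keepdims=True)
    return sorted_idx, filtered

idx, dist = top_p_filter(tf.random.normal([1, 32]), p=0.9)
print(idx.shape, float(tf.reduce_sum(dist)))      # (1, 32) 1.0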