Yuchan committed on
Update Model.py
Model.py CHANGED
@@ -124,7 +124,7 @@ class Lo(layers.Layer):
     def __init__(self):
         super().__init__()
         # keep internal computation in float32
-        self.p = layers.Dense(
+        self.p = layers.Dense(64, use_bias=True, dtype='float32')
         self._out_dtype = 'float32'

     def call(self, x):
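The dtype='float32' argument is what enforces the comment above: under a mixed_float16 global policy (which the float32 pinning throughout this file suggests), a layer constructed this way keeps its compute in float32 while unpinned layers run in float16. A minimal standalone check, not part of Model.py:

import tensorflow as tf
from tensorflow.keras import layers, mixed_precision

mixed_precision.set_global_policy('mixed_float16')

pinned = layers.Dense(64, use_bias=True, dtype='float32')  # as in Lo.__init__
default = layers.Dense(64)                                 # follows the global policy

print(pinned.compute_dtype)   # float32, despite the mixed_float16 policy
print(default.compute_dtype)  # float16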
@@ -137,26 +137,22 @@ class Lo(layers.Layer):
 class rGLU(layers.Layer):
     def __init__(self, d_model, hyper_n):
         super().__init__()
-        self.Wr =
-        self.
-        self.
+        self.Wr = Lo()
+        self.W2 = layers.Dense(256)
+        self.W1 = layers.Dense(256)
+        self.Wr1 = Lo()
         self.W = layers.Dense(d_model)
     def call(self, x):
-        x = self.Wr(x)
-        x = self.WB(x)
-        a, b = tf.split(x, 2, axis=-1)
-        o = tf.nn.silu(a) * b
-        o = self.Wr1(o)
-        o = self.W(o)
+        o = tf.nn.silu(self.W1(self.Wr(x)) + x) * (self.W2(self.Wr1(x)) + x)
         return o

 class Adapter(layers.Layer):
     def __init__(self, d_model, hyper_n):
         super().__init__()
-        self.Wr =
+        self.Wr = Lo()
         self.W = layers.Dense(d_model)
     def call(self, x):
-        return self.W(self.Wr(x))
+        return self.W(tf.nn.gelu(self.Wr(x)))

 class LoSoU(layers.Layer):
     """
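For reference, a self-contained sketch of the two layers as they read after this hunk. Two assumptions are not in the diff: Lo.call simply applies the Dense(64) projection from its __init__, and d_model must equal the 256-unit inner width for the + x residual adds in rGLU to type-check, so the demo uses d_model=256. As in the hunk above, self.W stays defined but unused in rGLU.call.

import tensorflow as tf
from tensorflow.keras import layers

class Lo(layers.Layer):
    def __init__(self):
        super().__init__()
        self.p = layers.Dense(64, use_bias=True, dtype='float32')
    def call(self, x):
        return self.p(x)  # assumption: Lo is just this float32 projection

class rGLU(layers.Layer):
    def __init__(self, d_model, hyper_n):
        super().__init__()
        self.Wr = Lo()
        self.W2 = layers.Dense(256)
        self.W1 = layers.Dense(256)
        self.Wr1 = Lo()
        self.W = layers.Dense(d_model)  # unused in call, as committed
    def call(self, x):
        # gate = silu(W1(Lo(x)) + x), value = W2(Lo(x)) + x, combined elementwise
        o = tf.nn.silu(self.W1(self.Wr(x)) + x) * (self.W2(self.Wr1(x)) + x)
        return o

class Adapter(layers.Layer):
    def __init__(self, d_model, hyper_n):
        super().__init__()
        self.Wr = Lo()
        self.W = layers.Dense(d_model)
    def call(self, x):
        return self.W(tf.nn.gelu(self.Wr(x)))  # Lo -> GELU -> project to d_model

x = tf.random.normal([2, 8, 256])              # inner width forces d_model == 256
print(rGLU(d_model=256, hyper_n=4)(x).shape)   # (2, 8, 256)
print(Adapter(d_model=256, hyper_n=4)(x).shape)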
@@ -237,9 +233,9 @@ class LoSoU(layers.Layer):
         residual = x_f32

         # Q, K, V
-        q = self.Q(self.Qr(x_f32))  # (B, L, 96)
-        k = self.K(self.Kr(x_f32))
-        V =
+        q = self.Q(self.Qr(x_f32)) + x_f32  # (B, L, 96)
+        k = self.K(self.Kr(x_f32)) + x_f32  # (B, L, 96)
+        V = self.V(self.Vr(x_f32)) + x_f32  # ensure V's output is float32

         # gating signals in (0,1)
         g_q = tf.nn.sigmoid(q)
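Feeding x_f32 to every branch is what the "ensure V's output is float32" comment requires: TensorFlow will not silently add tensors of different float dtypes, so a float32 projection plus a float16 residual fails at runtime rather than casting. A quick standalone illustration (names here are not from Model.py):

import tensorflow as tf

x16 = tf.zeros([2, 4, 96], tf.float16)  # mixed-precision activations
x32 = tf.cast(x16, tf.float32)          # the x_f32 copy used above

v = tf.keras.layers.Dense(96, dtype='float32')(x32)  # float32 branch

print((v + x32).dtype)  # float32: both operands match
# v + x16               # raises InvalidArgumentError: float32 + float16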
@@ -333,7 +329,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
 model = ReLaM(
     vocab_size=vocab_size,
     max_seq_len=max_len,
-    d_model=
+    d_model=192,
     n_layers=1
 )
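create_lr_schedule itself sits outside this diff, but its signature in the hunk header (initial_lr=5e-5, decay_steps=10000, decay_rate=0.9) mirrors the parameters of Keras's ExponentialDecay, so a plausible sketch of what it wraps is:

import tensorflow as tf

def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
    # lr(step) = initial_lr * decay_rate ** (step / decay_steps)
    return tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=initial_lr,
        decay_steps=decay_steps,
        decay_rate=decay_rate,
    )

optimizer = tf.keras.optimizers.Adam(learning_rate=create_lr_schedule())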