Yuchan committed
Update Inference.py
Inference.py +2 -13
Inference.py
CHANGED
@@ -68,7 +68,6 @@ class LoSoU(layers.Layer):
         self.K = layers.Dense(96, dtype='float32')
         self.V = Lo(d_model)  # Lo already handles casting to model dtype; we'll cast back to float32
         self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
-        self.O = layers.Dense(d_model, dtype='float32')
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
         # layer for computing the dynamic alpha
@@ -118,7 +117,7 @@ class LoSoU(layers.Layer):
         # x: (B, L, d_model) maybe bfloat16 or float32
         # cast to float32 for all internal computations
         x_f32 = tf.cast(x, tf.float32)
-
+        residual = x_f32
 
         # Q, K, V
         q = self.Q(x_f32)  # (B, L, 96)
@@ -133,7 +132,7 @@ class LoSoU(layers.Layer):
         score = g_q * g_k
 
         # dynamic alpha computation: (B, L, d_model) -> (B, L, 1)
-        alpha_dynamic = self.alpha_linear(x_f32)  # (B, L, 1)
+        alpha_dynamic = self.alpha_linear(x_f32) * 0.8 + 0.1  # (B, L, 1)
         # if needed, alpha_dynamic can be post-processed (e.g. min/max)
         # ex: alpha_dynamic = tf.clip_by_value(alpha_dynamic, 0.01, 0.99)
 
@@ -152,16 +151,6 @@ class LoSoU(layers.Layer):
         x_comb = score_clipped * V  # (B, L, d_model)
 
         out = self.proj(x_comb)  # (B, L, d_model)
-
-        # ensure out dim even for split
-        d = out.shape[-1]  # this is an int (static shape)
-        if d is not None and d % 2 == 1:
-            out = tf.pad(out, [[0,0],[0,0],[0,1]])
-
-        a, b = tf.split(out, 2, axis=-1)
-        gated = tf.nn.silu(a) * b
-        out = self.O(gated)
-
         out = self.norm(out)
 
         # cast back to original dtype for downstream layers
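Why the new alpha line: assuming self.alpha_linear ends in a sigmoid (its definition sits outside the hunks shown, so this is an assumption), the affine map a * 0.8 + 0.1 rescales the sigmoid's open range (0, 1) to (0.1, 0.9). That keeps alpha away from the degenerate endpoints, like the clip_by_value the comment suggests, but stays smooth with a nonzero gradient everywhere. A minimal sketch under that assumption:

    import tensorflow as tf
    from tensorflow.keras import layers

    # Hypothetical stand-in for self.alpha_linear; the real head is defined elsewhere.
    alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')

    x_f32 = tf.random.normal([2, 8, 64])               # (B, L, d_model)
    alpha_rescaled = alpha_linear(x_f32) * 0.8 + 0.1   # smooth: (0, 1) -> (0.1, 0.9)
    alpha_clipped = tf.clip_by_value(alpha_linear(x_f32), 0.01, 0.99)  # hard clip from the comment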
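For reference, the block deleted in the last hunk was a SiLU-gated (GLU-style) output unit: split the projection into two halves, gate one with SiLU of the other, and project back through self.O, which the first hunk also removes from __init__. A self-contained sketch of that pattern, with illustrative sizes:

    import tensorflow as tf
    from tensorflow.keras import layers

    def silu_gate(out, O):
        # Pad the channel dim to an even width so tf.split can halve it.
        d = out.shape[-1]                  # static int when the shape is known
        if d is not None and d % 2 == 1:
            out = tf.pad(out, [[0, 0], [0, 0], [0, 1]])
        a, b = tf.split(out, 2, axis=-1)   # two (B, L, d/2) halves
        return O(tf.nn.silu(a) * b)        # gate one half, project back to d_model

    O = layers.Dense(64, dtype='float32')  # plays the role of the removed self.O
    y = silu_gate(tf.random.normal([2, 8, 64]), O)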
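Finally, the comments around x_f32 describe the layer's mixed-precision convention: do every internal computation in float32, then hand the caller's original dtype back to downstream layers. A minimal sketch of that cast-in/cast-out pattern (the class name and width are illustrative, not from the repo):

    import tensorflow as tf
    from tensorflow.keras import layers

    class Float32Inside(layers.Layer):
        # Illustrative only: compute in float32, return the caller's dtype.
        def __init__(self, d_model=64):
            super().__init__()
            self.dense = layers.Dense(d_model, dtype='float32')

        def call(self, x):
            x_f32 = tf.cast(x, tf.float32)  # cast to float32 for all internal computations
            out = self.dense(x_f32)
            return tf.cast(out, x.dtype)    # cast back to original dtype for downstream layers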