Upload folder using huggingface_hub
Browse files
- config.json +1 -1
- fake_quartet.py +8 -14
config.json
CHANGED
|
@@ -27,7 +27,7 @@
|
|
| 27 |
"ratio": 4,
|
| 28 |
"scale_type": "1/sqrt(d)",
|
| 29 |
"tie_word_embeddings": true,
|
| 30 |
-
"transformers_version": "5.
|
| 31 |
"vocab_size": 32000,
|
| 32 |
"weight_tying": true
|
| 33 |
}
|
|
|
|
| 27 |
"ratio": 4,
|
| 28 |
"scale_type": "1/sqrt(d)",
|
| 29 |
"tie_word_embeddings": true,
|
| 30 |
+
"transformers_version": "5.4.0",
|
| 31 |
"vocab_size": 32000,
|
| 32 |
"weight_tying": true
|
| 33 |
}
|
fake_quartet.py
CHANGED
|
@@ -19,14 +19,12 @@ def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device):
|
|
| 19 |
|
| 20 |
|
| 21 |
def rerotate_hadamard(hadamard_matrix):
|
| 22 |
-
signs = torch.
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
)
|
| 29 |
-
return hadamard_matrix @ signs
|
| 30 |
|
| 31 |
|
| 32 |
|
|
@@ -186,15 +184,11 @@ def _eden_1x16s_fp4_kernel(
|
|
| 186 |
tl.where(x_scaled_abs >= 0.25, 0.5,
|
| 187 |
0))))))) * x_scaled_sign
|
| 188 |
|
| 189 |
-
x_scaled = tl.reshape(x_scaled, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
|
| 190 |
-
x_fp4 = tl.reshape(x_fp4, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
|
| 191 |
-
|
| 192 |
num = tl.sum(x_scaled * x_scaled, axis=-1, keep_dims=True)
|
| 193 |
denom = tl.sum(x_scaled * x_fp4, axis=-1, keep_dims=True)
|
| 194 |
correction = tl.where(denom == 0.0, 1.0, num / denom)
|
| 195 |
|
| 196 |
-
|
| 197 |
-
corrected_scales = tl.reshape(scales * correction, (BLOCK_SIZE // group_size, 1))
|
| 198 |
|
| 199 |
bitscales = tl.cast(corrected_scales.to(tl.float8e4nv), tl.uint8, bitcast=True)
|
| 200 |
prevscale = tl.cast((bitscales - 1), tl.float8e4nv, bitcast=True).to(tl.float32)
|
|
@@ -324,7 +318,7 @@ class FakeQuartetFn(torch.autograd.Function):
|
|
| 324 |
|
| 325 |
class FakeQuartetLinear(torch.nn.Linear):
|
| 326 |
|
| 327 |
-
def __init__(self, *args, hadamard_dim=
|
| 328 |
disable_forward_quant=False, disable_backward_quant=False,
|
| 329 |
four_over_six=True, **kwargs):
|
| 330 |
super().__init__(*args, **kwargs)
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def rerotate_hadamard(hadamard_matrix):
|
| 22 |
+
signs = torch.randint(
|
| 23 |
+
0, 2, (hadamard_matrix.size(0),),
|
| 24 |
+
device=hadamard_matrix.device,
|
| 25 |
+
dtype=hadamard_matrix.dtype,
|
| 26 |
+
) * 2 - 1
|
| 27 |
+
return hadamard_matrix * signs[None, :]
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
|
|
|
|
| 184 |
tl.where(x_scaled_abs >= 0.25, 0.5,
|
| 185 |
0))))))) * x_scaled_sign
|
| 186 |
|
|
|
|
|
|
|
|
|
|
| 187 |
num = tl.sum(x_scaled * x_scaled, axis=-1, keep_dims=True)
|
| 188 |
denom = tl.sum(x_scaled * x_fp4, axis=-1, keep_dims=True)
|
| 189 |
correction = tl.where(denom == 0.0, 1.0, num / denom)
|
| 190 |
|
| 191 |
+
corrected_scales = s_dec_b_e4m3 * correction
|
|
|
|
| 192 |
|
| 193 |
bitscales = tl.cast(corrected_scales.to(tl.float8e4nv), tl.uint8, bitcast=True)
|
| 194 |
prevscale = tl.cast((bitscales - 1), tl.float8e4nv, bitcast=True).to(tl.float32)
|
|
|
|
| 318 |
|
| 319 |
class FakeQuartetLinear(torch.nn.Linear):
|
| 320 |
|
| 321 |
+
def __init__(self, *args, hadamard_dim=128, delayed_amax=False,
|
| 322 |
disable_forward_quant=False, disable_backward_quant=False,
|
| 323 |
four_over_six=True, **kwargs):
|
| 324 |
super().__init__(*args, **kwargs)
|