mansaripo committed on
Commit
12303b2
·
verified ·
1 Parent(s): f359aa5

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. fake_quartet.py +8 -14
config.json CHANGED
@@ -27,7 +27,7 @@
27
  "ratio": 4,
28
  "scale_type": "1/sqrt(d)",
29
  "tie_word_embeddings": true,
30
- "transformers_version": "5.3.0",
31
  "vocab_size": 32000,
32
  "weight_tying": true
33
  }
 
27
  "ratio": 4,
28
  "scale_type": "1/sqrt(d)",
29
  "tie_word_embeddings": true,
30
+ "transformers_version": "5.4.0",
31
  "vocab_size": 32000,
32
  "weight_tying": true
33
  }
fake_quartet.py CHANGED
@@ -19,14 +19,12 @@ def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device):
19
 
20
 
21
  def rerotate_hadamard(hadamard_matrix):
22
- signs = torch.diag(
23
- torch.randint(
24
- 0, 2, (hadamard_matrix.size(0),),
25
- device=hadamard_matrix.device,
26
- dtype=hadamard_matrix.dtype,
27
- ) * 2 - 1
28
- )
29
- return hadamard_matrix @ signs
30
 
31
 
32
 
@@ -186,15 +184,11 @@ def _eden_1x16s_fp4_kernel(
186
  tl.where(x_scaled_abs >= 0.25, 0.5,
187
  0))))))) * x_scaled_sign
188
 
189
- x_scaled = tl.reshape(x_scaled, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
190
- x_fp4 = tl.reshape(x_fp4, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
191
-
192
  num = tl.sum(x_scaled * x_scaled, axis=-1, keep_dims=True)
193
  denom = tl.sum(x_scaled * x_fp4, axis=-1, keep_dims=True)
194
  correction = tl.where(denom == 0.0, 1.0, num / denom)
195
 
196
- scales = tl.reshape(s_dec_b_e4m3, (BLOCK_SIZE // hadamard_dim, hadamard_dim // group_size))
197
- corrected_scales = tl.reshape(scales * correction, (BLOCK_SIZE // group_size, 1))
198
 
199
  bitscales = tl.cast(corrected_scales.to(tl.float8e4nv), tl.uint8, bitcast=True)
200
  prevscale = tl.cast((bitscales - 1), tl.float8e4nv, bitcast=True).to(tl.float32)
@@ -324,7 +318,7 @@ class FakeQuartetFn(torch.autograd.Function):
324
 
325
  class FakeQuartetLinear(torch.nn.Linear):
326
 
327
- def __init__(self, *args, hadamard_dim=32, delayed_amax=False,
328
  disable_forward_quant=False, disable_backward_quant=False,
329
  four_over_six=True, **kwargs):
330
  super().__init__(*args, **kwargs)
 
19
 
20
 
21
  def rerotate_hadamard(hadamard_matrix):
22
+ signs = torch.randint(
23
+ 0, 2, (hadamard_matrix.size(0),),
24
+ device=hadamard_matrix.device,
25
+ dtype=hadamard_matrix.dtype,
26
+ ) * 2 - 1
27
+ return hadamard_matrix * signs[None, :]
 
 
28
 
29
 
30
 
 
184
  tl.where(x_scaled_abs >= 0.25, 0.5,
185
  0))))))) * x_scaled_sign
186
 
 
 
 
187
  num = tl.sum(x_scaled * x_scaled, axis=-1, keep_dims=True)
188
  denom = tl.sum(x_scaled * x_fp4, axis=-1, keep_dims=True)
189
  correction = tl.where(denom == 0.0, 1.0, num / denom)
190
 
191
+ corrected_scales = s_dec_b_e4m3 * correction
 
192
 
193
  bitscales = tl.cast(corrected_scales.to(tl.float8e4nv), tl.uint8, bitcast=True)
194
  prevscale = tl.cast((bitscales - 1), tl.float8e4nv, bitcast=True).to(tl.float32)
 
318
 
319
  class FakeQuartetLinear(torch.nn.Linear):
320
 
321
+ def __init__(self, *args, hadamard_dim=128, delayed_amax=False,
322
  disable_forward_quant=False, disable_backward_quant=False,
323
  four_over_six=True, **kwargs):
324
  super().__init__(*args, **kwargs)