Buckets:
| """Dequantize BFL's ModelOpt NVFP4 klein-4B -> a normal bf16 single-file, so diffusers | |
| from_single_file can ingest it (the standard converter handles flux->diffusers naming once the | |
| weights are unquantized; it only choked on the quant layout). | |
| NVFP4 (ModelOpt two-level): weight U8[out,in/2] (2x E2M1/byte) * weight_scale FP8[out,in/16] | |
| (per-group-16) * weight_scale_2 F32 (per-tensor). This is W-only dequant (bf16 activations); | |
| input_scale (static act quant) is dropped -> an UPPER BOUND on BFL quality (favourable to BFL). | |
| Validates against the original klein-4B bf16 single-file (same flux naming) -> rel-err should be | |
| the ~few-% 4-bit quant error. Usage: python3 scripts/38_bfl_dequant.py | |
| """ | |
| import torch | |
| from safetensors import safe_open | |
| from safetensors.torch import save_file | |
| SRC = "models/klein-4b-nvfp4/flux-2-klein-4b-nvfp4.safetensors" | |
| ORIG = "models/klein-4b/flux-2-klein-4b.safetensors" | |
| DST = "models/klein-4b-nvfp4/flux-2-klein-4b-bf16-from-nvfp4.safetensors" | |
# Magnitudes of the eight E2M1 (FP4) codes, indexed by the low three bits of a
# nibble; bit 3 of the nibble is the sign.
_E2M1 = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.])


def dequant(wu8, wscale_fp8, wscale2, group_size=16):
    """Dequantize one packed NVFP4 weight matrix to bf16.

    Args:
        wu8: 2-D tensor of packed bytes, shape (out, in / 2). Each byte holds
            two 4-bit codes: the low nibble is element 2j, the high nibble is
            element 2j + 1.
        wscale_fp8: per-group scales, shape (out, in / group_size). Converted
            with ``.float()`` before use.
        wscale2: per-tensor scale; anything ``float()`` accepts (e.g. a
            one-element tensor).
        group_size: number of consecutive input elements sharing one entry of
            ``wscale_fp8``. Defaults to 16.

    Returns:
        A ``torch.bfloat16`` tensor of shape (out, in) equal to
        ``code_value * group_scale * wscale2``.

    Raises:
        ValueError: if ``wu8`` is not 2-D, the unpacked width is not a multiple
            of ``group_size``, or ``wscale_fp8`` does not have exactly one
            scale per group. Without this check a scale tensor with the wrong
            number of rows would be broadcast silently.
    """
    if wu8.dim() != 2:
        raise ValueError(f"expected a 2-D packed weight, got shape {tuple(wu8.shape)}")
    out, inh = wu8.shape
    cols = inh * 2
    if cols % group_size != 0:
        raise ValueError(
            f"unpacked width {cols} is not a multiple of group size {group_size}"
        )
    groups = cols // group_size
    if tuple(wscale_fp8.shape) != (out, groups):
        raise ValueError(
            f"scale shape {tuple(wscale_fp8.shape)} does not match expected {(out, groups)}"
        )

    # Signed 16-entry table: codes 0-7 are +magnitude, codes 8-15 are -magnitude
    # (code 8 is -0.0). Built on the weight's device so indexing never mixes devices.
    mags = _E2M1.to(wu8.device)
    lut = torch.cat([mags, -mags])

    codes = wu8.to(torch.int32)
    # Interleave so that element 2j comes from the low nibble and 2j+1 from the high.
    nibbles = torch.stack([codes & 0xF, (codes >> 4) & 0xF], dim=-1).reshape(out, cols)
    vals = lut[nibbles.long()]

    # Apply the per-group scale through a grouped view (no full-size scale
    # tensor), then the per-tensor scale, in that order.
    scaled = vals.reshape(out, groups, group_size) * wscale_fp8.float().unsqueeze(-1)
    scaled = scaled * float(wscale2)
    return scaled.reshape(out, cols).to(torch.bfloat16)
def _report(errs, n_copied):
    """Print how many tensors were converted and the rel-err summary.

    ``errs`` is a list of ``(key, rel_err)`` pairs; ``n_copied`` is the number
    of tensors copied through without dequantization.
    """
    print(f"dequantized {len(errs)} quantized weights; copied {n_copied} non-quant tensors")
    if not errs:
        # Nothing was dequantized, so "worst", "best" and "mean" are undefined;
        # skip them instead of failing on errs[-1] or a division by zero.
        print("no quantized weights found; rel-err statistics skipped")
        return
    ranked = sorted(errs, key=lambda item: -item[1])
    print("rel-err vs original klein-4B bf16 (worst 5, best 1, mean):")
    for k, e in ranked[:5]:
        print(f" {e:.4f} {k}")
    print(f" {ranked[-1][1]:.4f} {ranked[-1][0]} (best)")
    print(f" mean={sum(e for _, e in ranked)/len(ranked):.4f}")


def main():
    """Dequantize SRC into a bf16 single-file at DST, validating against ORIG.

    Iterates over the key set of ORIG. A key ending in ``.weight`` that has a
    matching ``<key>_scale`` entry in SRC is dequantized and compared with the
    ORIG tensor; every other key is copied from SRC unchanged. Tensors that
    exist only in SRC (the scale tensors, for instance) are not written.

    Raises:
        KeyError: if a key of ORIG that is not a quantized weight is missing
            from SRC.
    """
    out = {}
    errs = []
    # Context managers make sure both file handles are released.
    with safe_open(SRC, framework="pt") as bf, safe_open(ORIG, framework="pt") as og:
        bkeys = set(bf.keys())
        for k in og.keys():
            if k.endswith(".weight") and (k + "_scale") in bkeys:
                w = dequant(
                    bf.get_tensor(k),
                    bf.get_tensor(k + "_scale"),
                    bf.get_tensor(k + "_scale_2"),
                )
                out[k] = w
                ow = og.get_tensor(k).float()
                # Relative Frobenius error; the epsilon guards an all-zero reference.
                rel = (w.float() - ow).norm() / (ow.norm() + 1e-8)
                errs.append((k, float(rel)))
            elif k in bkeys:
                out[k] = bf.get_tensor(k)
            else:
                raise KeyError(f"{k} is present in {ORIG} but missing from {SRC}")

    _report(errs, len(out) - len(errs))
    save_file(out, DST)
    print(f"saved bf16 single-file -> {DST}")


if __name__ == "__main__":
    main()
Xet Storage Details
- Size: 2.44 kB
- Xet hash: e2c4e7cc92971dd995fc1e792f59311f2345e43089451ea39a1a803bad9a7cff
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.