Buckets:
| """Generate eval images for BFL's OFFICIAL FP8 klein-4B (model E), faithfully. | |
| BFL's fp8 single-file stores, per quantized Linear: weight (F8_E4M3, plain 2D), weight_scale | |
| (per-tensor F32), input_scale (per-tensor F32 = static activation scale). It is therefore W8A8. | |
| Unlike their NVFP4 file, the fp8 layout is NOT swizzled: weight.float()*weight_scale reconstructs | |
| the teacher weights to cos=0.9997 (verified). We: | |
| 1. dequantize each fp8 weight -> bf16 (their exact quantized weights), | |
| 2. reuse diffusers' OWN flux2 single-file converter on the all-bf16 dict (exact key remap + QKV split), | |
| 3. load into a Flux2Transformer2DModel built from the SAME models/klein-4b config (=> identical TE+VAE), | |
| 4. attach static per-tensor FP8(E4M3) activation fake-quant hooks (input_scale) so it is faithfully W8A8. | |
| This runs as fake-quant (bf16 compute) -> correct QUALITY of BFL's checkpoint; it is not a real fp8 | |
| kernel, so no speedup is claimed for E (their fp8 kernel needs TensorRT). Images verified on probes first. | |
| Usage: python3 -u scripts/41_gen_bfl_fp8.py OUT_DIR [START] [COUNT] [RES] | |
| """ | |
| import sys, os, json, time, torch | |
| from safetensors.torch import load_file | |
| from diffusers import Flux2Transformer2DModel, Flux2KleinPipeline | |
| from diffusers.loaders.single_file_utils import convert_flux2_transformer_checkpoint_to_diffusers | |
| CKPT = "models/bfl-klein-4b-fp8/flux-2-klein-4b-fp8.safetensors" | |
| BASE = "models/klein-4b" | |
| FP8_MAX = 448.0 | |
| # original-flux quantized-layer name -> diffusers module name(s) (qkv splits into 3, shared input_scale) | |
def diffusers_targets(orig):
    """Map an original-flux quantized layer name to its diffusers module name(s).

    Args:
        orig: Layer name from the checkpoint, e.g. ``"double_blocks.3.img_attn.qkv"``
            or ``"single_blocks.7.linear2"``.

    Returns:
        A list of diffusers module names. Fused QKV layers of double blocks map
        to three modules (q, k, v); every other layer maps to exactly one.

    Raises:
        KeyError: if ``orig`` is not ``<group>.<index>.<sublayer>``, if the group
            is neither ``double_blocks`` nor ``single_blocks``, or if the
            sub-layer is not one of the known quantized layers.
    """
    parts = orig.split(".")
    if len(parts) < 3:
        raise KeyError(
            f"cannot map quantized layer {orig!r}: expected '<group>.<index>.<sublayer>'"
        )

    group, n = parts[0], parts[1]
    if group == "double_blocks":
        # Sub-layer names of double blocks contain dots themselves ("img_attn.qkv").
        sub = ".".join(parts[2:])
        P = f"transformer_blocks.{n}"
        table = {
            "img_attn.qkv": [f"{P}.attn.to_q", f"{P}.attn.to_k", f"{P}.attn.to_v"],
            "img_attn.proj": [f"{P}.attn.to_out.0"],
            "txt_attn.qkv": [f"{P}.attn.add_q_proj", f"{P}.attn.add_k_proj", f"{P}.attn.add_v_proj"],
            "txt_attn.proj": [f"{P}.attn.to_add_out"],
            "img_mlp.0": [f"{P}.ff.linear_in"],
            "img_mlp.2": [f"{P}.ff.linear_out"],
            "txt_mlp.0": [f"{P}.ff_context.linear_in"],
            "txt_mlp.2": [f"{P}.ff_context.linear_out"],
        }
    elif group == "single_blocks":
        # Only the first component after the index selects the sub-layer.
        sub = parts[2]
        P = f"single_transformer_blocks.{n}"
        table = {
            "linear1": [f"{P}.attn.to_qkv_mlp_proj"],
            "linear2": [f"{P}.attn.to_out"],
        }
    else:
        # Previously any unknown group was silently treated as a single block.
        raise KeyError(
            f"cannot map quantized layer {orig!r}: unknown block group {group!r}"
        )

    if sub not in table:
        raise KeyError(
            f"cannot map quantized layer {orig!r}: unknown sub-layer {sub!r} in {group}"
        )
    return table[sub]
def load_bfl_fp8_pipeline(device="cuda"):
    """Build a Flux2KleinPipeline whose transformer emulates BFL's FP8 (W8A8) checkpoint.

    Steps:
      1. Read the quantized-layer list from the safetensors header metadata.
      2. Dequantize every quantized weight (``weight.float() * weight_scale``) to bf16
         and remember each layer's static ``input_scale``.
      3. Run diffusers' flux2 single-file converter on the resulting dict and load
         it into a ``Flux2Transformer2DModel`` built from ``BASE``'s config.
      4. Register forward pre-hooks that fake-quantize the input of every quantized
         Linear to FP8 E4M3 using the stored per-tensor ``input_scale``.

    Args:
        device: Device the transformer and the pipeline are moved to.

    Returns:
        The ``Flux2KleinPipeline`` loaded from ``BASE`` with the converted transformer.

    Raises:
        RuntimeError: if the converted state dict has missing or unexpected keys.
        KeyError: if a quantized layer lacks its scales/weight or cannot be mapped
            to a diffusers module.
    """
    import struct  # only needed here, for the safetensors header length prefix

    raw = load_file(CKPT)

    # _quantization_metadata lives in the file metadata, which load_file() does not
    # return; re-read the header (8-byte little-endian length, then JSON).
    with open(CKPT, "rb") as f:
        (hdr_len,) = struct.unpack("<Q", f.read(8))
        hdr = json.loads(f.read(hdr_len))
    qlayers = list(json.loads(hdr["__metadata__"]["_quantization_metadata"])["layers"])
    qset = set(qlayers)

    # 1) dequantize fp8 weights, stash input_scale per original layer
    w_suffix = ".weight"
    deq = {}
    input_scale = {}
    for k, v in raw.items():
        if k.endswith((".weight_scale", ".input_scale")):
            continue  # consumed together with the matching weight below
        layer = k[:-len(w_suffix)] if k.endswith(w_suffix) else None
        if layer in qset:
            ws = raw[f"{layer}.weight_scale"].float()
            deq[k] = (v.float() * ws).to(torch.bfloat16)  # their exact dequantized weight
            input_scale[layer] = float(raw[f"{layer}.input_scale"])
        else:
            deq[k] = v  # non-quantized tensors pass through unchanged

    # 2) diffusers' own converter (all tensors are now plain 2D -> QKV chunk works)
    conv = convert_flux2_transformer_checkpoint_to_diffusers(deq)

    # 3) build model from the SAME config and load
    model = Flux2Transformer2DModel.from_config(f"{BASE}/transformer")
    missing, unexpected = model.load_state_dict(conv, strict=False)
    # Explicit raises instead of assert: these checks must survive `python -O`,
    # otherwise a partially initialized model would silently generate images.
    if unexpected:
        raise RuntimeError(f"unexpected keys: {unexpected[:8]}")
    if missing:
        raise RuntimeError(f"missing keys: {missing[:8]}")
    model = model.to(device=device, dtype=torch.bfloat16).eval().requires_grad_(False)

    # 4) faithful W8A8: static per-tensor FP8(E4M3) activation fake-quant on each quantized Linear
    name2mod = dict(model.named_modules())

    def make_hook(scale):
        def hook(mod, args):
            x = args[0]
            # PyTorch's cast to float8_e4m3fn does not saturate (out-of-range values
            # become NaN), so clamp to the representable range first. FP8_MAX is
            # the module-level E4M3 maximum.
            scaled = (x / scale).clamp(min=-FP8_MAX, max=FP8_MAX)
            xq = scaled.to(torch.float8_e4m3fn).to(x.dtype) * scale
            return (xq,) + args[1:]
        return hook

    n_hooks = 0
    for layer in qlayers:
        s = input_scale[layer]
        # A fused qkv layer maps to three modules that share one input_scale.
        for tgt in diffusers_targets(layer):
            name2mod[tgt].register_forward_pre_hook(make_hook(s))
            n_hooks += 1
    print(f"[bfl-fp8] loaded; {len(qlayers)} quantized layers, {n_hooks} activation-quant hooks", flush=True)

    pipe = Flux2KleinPipeline.from_pretrained(BASE, transformer=model, torch_dtype=torch.bfloat16).to(device)
    return pipe
if __name__ == "__main__":
    # CLI: OUT_DIR [START] [COUNT] [RES]
    if len(sys.argv) < 2:
        sys.exit("Usage: python3 -u scripts/41_gen_bfl_fp8.py OUT_DIR [START] [COUNT] [RES]")
    OUT = sys.argv[1]
    START = int(sys.argv[2]) if len(sys.argv) > 2 else 0
    COUNT = int(sys.argv[3]) if len(sys.argv) > 3 else 10**9
    RES = int(sys.argv[4]) if len(sys.argv) > 4 else 512
    os.makedirs(OUT, exist_ok=True)

    # Close the prompts file deterministically instead of leaking the handle.
    prompts_path = os.environ.get("PROMPTS_JSON", "outputs/eval/prompts.json")
    with open(prompts_path, encoding="utf-8") as fh:
        prompts = json.load(fh)[START:START + COUNT]
    print(f"=== gen MODE=bfl-fp8 OUT={OUT} N={len(prompts)} RES={RES} ===", flush=True)

    pipe = load_bfl_fp8_pipeline()
    t0 = time.time()
    done = 0
    for d in prompts:
        f = os.path.join(OUT, f"{d['idx']:05d}.png")
        if os.path.exists(f):
            continue  # resumable: images already on disk are not regenerated
        # Seed from the prompt index so each image is reproducible on its own.
        g = torch.Generator("cuda").manual_seed(d["idx"])
        img = pipe(prompt=d["prompt"], num_inference_steps=4, guidance_scale=1.0,
                   height=RES, width=RES, generator=g).images[0]
        img.save(f)
        done += 1
        if done % 25 == 0:
            print(f"  {done}/{len(prompts)}  ({time.time()-t0:.0f}s)", flush=True)
    print(f"DONE bfl-fp8 -> {OUT}  ({done} new imgs, {time.time()-t0:.0f}s)", flush=True)
Xet Storage Details
- Size:
- 6.16 kB
- Xet hash:
- 6cf541bd5cf2a7e4149dd762251f334097c1545abc779eabe84bd06d88565bc1
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.