Buckets:

Mercity/FluxDistill / scripts /41_gen_bfl_fp8.py
Pranav2748's picture
download
raw
6.16 kB
"""Generate eval images for BFL's OFFICIAL FP8 klein-4B (model E), faithfully.
BFL's fp8 single-file stores, per quantized Linear: weight (F8_E4M3, plain 2D), weight_scale
(per-tensor F32), input_scale (per-tensor F32 = static activation scale). It is therefore W8A8.
Unlike their NVFP4 file, the fp8 layout is NOT swizzled: weight.float()*weight_scale reconstructs
the teacher weights to cos=0.9997 (verified). We:
1. dequantize each fp8 weight -> bf16 (their exact quantized weights),
2. reuse diffusers' OWN flux2 single-file converter on the all-bf16 dict (exact key remap + QKV split),
3. load into a Flux2Transformer2DModel built from the SAME models/klein-4b config (=> identical TE+VAE),
4. attach static per-tensor FP8(E4M3) activation fake-quant hooks (input_scale) so it is faithfully W8A8.
This runs as fake-quant (bf16 compute) -> correct QUALITY of BFL's checkpoint; it is not a real fp8
kernel, so no speedup is claimed for E (their fp8 kernel needs TensorRT). Images verified on probes first.
Usage: python3 -u scripts/41_gen_bfl_fp8.py OUT_DIR [START] [COUNT] [RES]
"""
import sys, os, json, time, torch
from safetensors.torch import load_file
from diffusers import Flux2Transformer2DModel, Flux2KleinPipeline
from diffusers.loaders.single_file_utils import convert_flux2_transformer_checkpoint_to_diffusers
# BFL's single-file FP8 klein-4B checkpoint: fp8 weights plus per-layer weight_scale/input_scale
# tensors and a "_quantization_metadata" entry in the safetensors header listing the quantized layers.
CKPT = "models/bfl-klein-4b-fp8/flux-2-klein-4b-fp8.safetensors"
# Local diffusers klein-4B repo: provides the transformer config and the rest of the pipeline
# (text encoder, VAE, scheduler) so model E shares everything except the transformer weights.
BASE = "models/klein-4b"
# Largest finite magnitude representable in FP8 E4M3 (torch.float8_e4m3fn).
FP8_MAX = 448.0
# original-flux quantized-layer name -> diffusers module name(s) (qkv splits into 3, shared input_scale)
def diffusers_targets(orig):
    """Map one original-flux quantized Linear name to the diffusers module name(s) it became.

    ``orig`` looks like ``"double_blocks.3.img_attn.qkv"`` or ``"single_blocks.7.linear2"``.
    A fused double-block QKV projection is split by the diffusers converter into three
    modules (returned in q, k, v order); all of them share the original layer's single
    input_scale.

    Args:
        orig: original-flux layer name, ``<double_blocks|single_blocks>.<int>.<sub-layer>``.

    Returns:
        list[str]: diffusers module names corresponding to ``orig``.

    Raises:
        ValueError: if ``orig`` is not a recognised double/single-block Linear name. Any
            other prefix is rejected explicitly rather than being routed to the
            single-block table.
    """
    block_kind, _, rest = orig.partition(".")
    idx, _, sub = rest.partition(".")
    if block_kind == "double_blocks":
        p = f"transformer_blocks.{idx}"
        table = {
            "img_attn.qkv": [f"{p}.attn.to_q", f"{p}.attn.to_k", f"{p}.attn.to_v"],
            "img_attn.proj": [f"{p}.attn.to_out.0"],
            "txt_attn.qkv": [f"{p}.attn.add_q_proj", f"{p}.attn.add_k_proj", f"{p}.attn.add_v_proj"],
            "txt_attn.proj": [f"{p}.attn.to_add_out"],
            "img_mlp.0": [f"{p}.ff.linear_in"],
            "img_mlp.2": [f"{p}.ff.linear_out"],
            "txt_mlp.0": [f"{p}.ff_context.linear_in"],
            "txt_mlp.2": [f"{p}.ff_context.linear_out"],
        }
    elif block_kind == "single_blocks":
        p = f"single_transformer_blocks.{idx}"
        table = {
            "linear1": [f"{p}.attn.to_qkv_mlp_proj"],
            "linear2": [f"{p}.attn.to_out"],
        }
    else:
        table = {}
    # The block index must be an integer, otherwise the generated module path is meaningless.
    targets = table.get(sub) if idx.isdigit() else None
    if targets is None:
        raise ValueError(f"unrecognised quantized layer name: {orig!r}")
    return targets
def _read_quantized_layers(path):
    """Return the original-flux names of the quantized Linears listed in the file's metadata."""
    # The list lives in the safetensors header: an 8-byte little-endian length followed by a
    # JSON object whose "__metadata__"["_quantization_metadata"] is itself a JSON string.
    import struct
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    qmeta = json.loads(header["__metadata__"]["_quantization_metadata"])
    return list(qmeta["layers"].keys())


def _dequantize_fp8(raw, qset):
    """Dequantize each quantized layer's fp8 weight to bf16 and collect its static input_scale.

    Returns ``(state_dict, input_scale)``:
    - ``state_dict`` holds every tensor except the ``*.weight_scale`` / ``*.input_scale``
      entries. Quantized weights are replaced by ``weight.float() * weight_scale`` in bf16.
    - ``input_scale`` maps original layer name -> python float.
    """
    suffix = ".weight"
    state_dict, input_scale = {}, {}
    for key, tensor in raw.items():
        if key.endswith((".weight_scale", ".input_scale")):
            continue  # consumed together with the matching ".weight" key
        layer = key[: -len(suffix)] if key.endswith(suffix) else None
        if layer is not None and layer in qset:
            w_scale = raw[f"{layer}.weight_scale"].float()
            state_dict[key] = (tensor.float() * w_scale).to(torch.bfloat16)  # BFL's exact quantized weight
            input_scale[layer] = float(raw[f"{layer}.input_scale"])
        else:
            state_dict[key] = tensor  # non-quantized tensors pass through unchanged
    return state_dict, input_scale


def _make_act_quant_hook(scale):
    """Build a forward-pre-hook that fake-quantizes a Linear's input to FP8 E4M3 with a static scale."""
    def hook(mod, args):
        x = args[0]
        # torch's float8_e4m3fn cast is NOT saturating: |v| >= 480 becomes NaN. Clamp to the
        # format's max first so outliers saturate, as a real static-scale FP8 GEMM would.
        # The arithmetic runs in fp32 so the only rounding is the fp8 one (plus the cast back).
        q = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        xq = (q.float() * scale).to(x.dtype)
        return (xq,) + args[1:]
    return hook


def load_bfl_fp8_pipeline(device="cuda"):
    """Build a Flux2KleinPipeline whose transformer reproduces BFL's FP8 (W8A8) klein-4B checkpoint.

    Steps:
    1. Dequantize the fp8 weights in ``CKPT`` to bf16.
    2. Convert them with diffusers' flux2 single-file converter.
    3. Load the result into a transformer built from ``BASE``'s config.
    4. Give every quantized Linear a static per-tensor FP8-E4M3 activation fake-quant
       pre-hook, using the layer's ``input_scale``.

    Args:
        device: torch device for the transformer and the pipeline.

    Returns:
        Flux2KleinPipeline loaded from ``BASE`` with the fake-quant transformer swapped in.

    Raises:
        RuntimeError: if the converted state dict has missing or unexpected keys.
    """
    raw = load_file(CKPT)
    qlayers = _read_quantized_layers(CKPT)
    # 1) dequantize fp8 weights, stash input_scale per original layer
    deq, input_scale = _dequantize_fp8(raw, set(qlayers))
    del raw  # the fp8 copies are no longer needed; deq keeps every other tensor
    # 2) diffusers' own converter (all tensors now bf16/2D -> QKV chunk works)
    conv = convert_flux2_transformer_checkpoint_to_diffusers(deq)
    del deq
    # 3) build model from the SAME config and load; any key mismatch means the remap is wrong
    config = Flux2Transformer2DModel.load_config(f"{BASE}/transformer")
    model = Flux2Transformer2DModel.from_config(config)
    missing, unexpected = model.load_state_dict(conv, strict=False)
    if unexpected:
        raise RuntimeError(f"unexpected keys: {unexpected[:8]}")
    if missing:
        raise RuntimeError(f"missing keys: {missing[:8]}")
    del conv
    model = model.to(device=device, dtype=torch.bfloat16).eval().requires_grad_(False)
    # 4) faithful W8A8: static per-tensor FP8(E4M3) activation fake-quant on each quantized Linear
    name2mod = dict(model.named_modules())
    n_hooks = 0
    for layer in qlayers:
        hook = _make_act_quant_hook(input_scale[layer])  # q/k/v splits share one input_scale
        for target in diffusers_targets(layer):
            name2mod[target].register_forward_pre_hook(hook)
            n_hooks += 1
    print(f"[bfl-fp8] loaded; {len(qlayers)} quantized layers, {n_hooks} activation-quant hooks", flush=True)
    pipe = Flux2KleinPipeline.from_pretrained(BASE, transformer=model, torch_dtype=torch.bfloat16).to(device)
    return pipe
if __name__ == "__main__":
    # Usage: OUT_DIR [START] [COUNT] [RES]. Each image is written as <idx:05d>.png and seeded
    # with its prompt idx; files that already exist are skipped, so the script can be re-run.
    OUT = sys.argv[1]
    START = int(sys.argv[2]) if len(sys.argv) > 2 else 0
    COUNT = int(sys.argv[3]) if len(sys.argv) > 3 else 10**9
    RES = int(sys.argv[4]) if len(sys.argv) > 4 else 512
    os.makedirs(OUT, exist_ok=True)
    prompts_path = os.environ.get("PROMPTS_JSON", "outputs/eval/prompts.json")
    # Close the prompts file promptly and decode it as UTF-8 regardless of the locale.
    with open(prompts_path, encoding="utf-8") as fh:
        prompts = json.load(fh)[START:START + COUNT]
    print(f"=== gen MODE=bfl-fp8 OUT={OUT} N={len(prompts)} RES={RES} ===", flush=True)
    pipe = load_bfl_fp8_pipeline()
    t0 = time.time()
    done = 0
    for d in prompts:
        img_path = os.path.join(OUT, f"{d['idx']:05d}.png")
        if os.path.exists(img_path):
            continue  # resume support
        g = torch.Generator("cuda").manual_seed(d["idx"])
        img = pipe(prompt=d["prompt"], num_inference_steps=4, guidance_scale=1.0,
                   height=RES, width=RES, generator=g).images[0]
        img.save(img_path)
        done += 1
        if done % 25 == 0:
            print(f" {done}/{len(prompts)} ({time.time()-t0:.0f}s)", flush=True)
    print(f"DONE bfl-fp8 -> {OUT} ({done} new imgs, {time.time()-t0:.0f}s)", flush=True)

Xet Storage Details

Size:
6.16 kB
·
Xet hash:
6cf541bd5cf2a7e4149dd762251f334097c1545abc779eabe84bd06d88565bc1

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.