Buckets:

Mercity/FluxDistill / scripts /40_load_bfl.py
Pranav2748's picture
download
raw
3.51 kB
"""Load BFL official quantized klein-4B (NVFP4 or FP8) single-file via diffusers + NVIDIA ModelOpt,
swap into a pipeline built from models/klein-4b (so TE + VAE are IDENTICAL to teacher/ours).
Returns a ready Flux2KleinPipeline. Used by the sanity test and scripts/41 (BFL gen).
fmt = 'nvfp4' -> models/bfl-klein-4b-nvfp4/flux-2-klein-4b-nvfp4.safetensors, quant_type NVFP4
fmt = 'fp8' -> models/bfl-klein-4b-fp8/flux-2-klein-4b-fp8.safetensors, quant_type FP8
"""
import time, torch
# fmt -> (single-file checkpoint path, relative to the working directory; quant_type string
# handed to NVIDIAModelOptConfig and used to pick the matching modelopt default config).
# Both BFL checkpoints follow the same layout:
#   models/bfl-klein-4b-<fmt>/flux-2-klein-4b-<fmt>.safetensors  with quant_type <FMT>.
CKPT = {
    name: (f"models/bfl-klein-4b-{name}/flux-2-klein-4b-{name}.safetensors", name.upper())
    for name in ("nvfp4", "fp8")
}
def load_bfl_pipeline(fmt="nvfp4", base="models/klein-4b", device="cuda"):
    """Build a Flux2KleinPipeline around BFL's pre-quantized klein-4B transformer.

    The transformer is restored from the single-file checkpoint registered in ``CKPT[fmt]``
    via diffusers' ``from_single_file`` with an NVIDIA ModelOpt quantization config. Every
    other pipeline component is loaded from ``base``, so the text encoder and VAE are the
    same ones the teacher pipeline uses.

    Args:
        fmt: Key into ``CKPT`` selecting the checkpoint ("nvfp4" or "fp8").
        base: Local diffusers pipeline directory. ``{base}/transformer`` supplies the
            transformer config for the single-file load.
        device: Device the assembled pipeline is moved to.

    Returns:
        The pipeline (loaded with ``torch_dtype=torch.bfloat16``), already on ``device``.

    Raises:
        KeyError: if ``fmt`` is not a key of ``CKPT``.
        FileNotFoundError: if the checkpoint file registered for ``fmt`` does not exist.
    """
    import os

    if fmt not in CKPT:
        # Same exception type the bare dict lookup used to raise, but say what is valid.
        raise KeyError(f"unknown fmt {fmt!r}; expected one of {list(CKPT)}")
    path, quant_type = CKPT[fmt]
    # Check up front so a missing or mistyped path fails with a clear message, before the
    # heavy modelopt/diffusers imports. Without this, the error comes from whatever the
    # single-file loader reports for a path it cannot find locally.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"[bfl-{fmt}] checkpoint not found: {os.path.abspath(path)}")

    # ModelOpt must patch the diffusers model classes BEFORE from_single_file so the
    # pre-quantized (packed U8 + FP8 group scales + per-tensor scales) weights restore correctly.
    from modelopt.torch.opt import enable_huggingface_checkpointing

    enable_huggingface_checkpointing()
    import modelopt.torch.quantization as mtq
    from diffusers import Flux2Transformer2DModel, Flux2KleinPipeline, NVIDIAModelOptConfig

    # diffusers' get_config_from_quant_type() assumes modelopt's OLD dict-based quant_cfg and
    # crashes on modelopt 0.44 (list-based). Bypass it by passing modelopt's own canonical
    # config, which matches BFL's stored layout (NVFP4: group-16 weights + FP8[E4M3] scales).
    mcfg = {"NVFP4": mtq.NVFP4_DEFAULT_CFG, "FP8": mtq.FP8_DEFAULT_CFG}[quant_type]
    qcfg = NVIDIAModelOptConfig(quant_type=quant_type, modelopt_config=mcfg)

    # perf_counter is monotonic, so the reported load time is not skewed by clock adjustments.
    t0 = time.perf_counter()
    # config= points the single-file loader at the LOCAL transformer config so it does not try
    # to fetch a default (black-forest-labs/FLUX.2-dev). Same architecture as the teacher.
    tf = Flux2Transformer2DModel.from_single_file(
        path, config=f"{base}/transformer", quantization_config=qcfg, torch_dtype=torch.bfloat16
    )
    print(f"[bfl-{fmt}] transformer loaded in {time.perf_counter() - t0:.0f}s", flush=True)
    pipe = Flux2KleinPipeline.from_pretrained(base, transformer=tf, torch_dtype=torch.bfloat16)
    return pipe.to(device)
if __name__ == "__main__":
    # Sanity run: python 40_load_bfl.py [fmt=nvfp4] [out_dir=/tmp/sanity/bfl]
    # Generates a few 4-step 512x512 images with the BFL-quantized pipeline, saves them as
    # PNGs in out_dir and reports peak CUDA memory.
    import json
    import os
    import sys

    fmt = sys.argv[1] if len(sys.argv) > 1 else "nvfp4"
    OUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/sanity/bfl"
    os.makedirs(OUT, exist_ok=True)

    # Read the eval prompts BEFORE the slow model load so a missing file fails fast;
    # the context manager also closes the handle (json.load(open(...)) leaked it).
    with open("outputs/eval/prompts.json", encoding="utf-8") as fh:
        prompts = json.load(fh)

    # (file stem, prompt, seed): 3 sensitive probes (text / hand / counting) + 1 paired
    # MJHQ prompt. The paired prompt reuses its eval idx as the seed, which can collide
    # with a probe seed (idx 0 vs seed 0). With a shared "probe_" stem it silently
    # overwrote probe_00000.png, so it gets its own stem.
    probes = [
        ("probe", 'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"', 0),
        ("probe", "a close-up photograph of a human hand with five fingers spread out, palm facing camera", 1),
        ("probe", "a flat-lay overhead photo of a wooden table with exactly three brown eggs in a white bowl", 2),
        ("mjhq", prompts[0]["prompt"], prompts[0]["idx"]),  # paired with teacher idx 0
    ]

    pipe = load_bfl_pipeline(fmt)
    # Reset after loading so the peak covers the resident weights plus generation.
    torch.cuda.reset_peak_memory_stats()
    for stem, p, seed in probes:
        g = torch.Generator("cuda").manual_seed(seed)
        img = pipe(
            prompt=p,
            num_inference_steps=4,
            guidance_scale=1.0,
            height=512,
            width=512,
            generator=g,
        ).images[0]
        out_path = os.path.join(OUT, f"{stem}_{seed:05d}.png")
        img.save(out_path)
        print(f" saved {os.path.basename(out_path)} seed={seed}: {p[:50]}", flush=True)
    print(f"[bfl-{fmt}] peak VRAM {torch.cuda.max_memory_allocated()/1e9:.2f} GB", flush=True)

Xet Storage Details

Size:
3.51 kB
·
Xet hash:
1ab3272edcc55626698b02d3dfe1c6eb69b031d939fad6f4a96a687e5e0e353a

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.