Buckets:
| """Load BFL official quantized klein-4B (NVFP4 or FP8) single-file via diffusers + NVIDIA ModelOpt, | |
| swap into a pipeline built from models/klein-4b (so TE + VAE are IDENTICAL to teacher/ours). | |
| Returns a ready Flux2KleinPipeline. Used by the sanity test and scripts/41 (BFL gen). | |
| fmt = 'nvfp4' -> models/bfl-klein-4b-nvfp4/flux-2-klein-4b-nvfp4.safetensors, quant_type NVFP4 | |
| fmt = 'fp8' -> models/bfl-klein-4b-fp8/flux-2-klein-4b-fp8.safetensors, quant_type FP8 | |
| """ | |
| import time, torch | |
# Supported BFL checkpoints: fmt key -> (single-file checkpoint path, ModelOpt quant_type).
CKPT = {
    "nvfp4": ("models/bfl-klein-4b-nvfp4/flux-2-klein-4b-nvfp4.safetensors", "NVFP4"),
    "fp8": ("models/bfl-klein-4b-fp8/flux-2-klein-4b-fp8.safetensors", "FP8"),
}


def load_bfl_pipeline(fmt="nvfp4", base="models/klein-4b", device="cuda"):
    """Build a Flux2KleinPipeline around a BFL pre-quantized transformer.

    The transformer is loaded from the single-file checkpoint registered in
    ``CKPT[fmt]``; every other pipeline component comes from ``base`` via
    ``Flux2KleinPipeline.from_pretrained``.

    Args:
        fmt: Key into ``CKPT`` ("nvfp4" or "fp8").
        base: Local diffusers model directory. ``f"{base}/transformer"`` is
            handed to the single-file loader as the transformer config, and
            ``base`` itself is the source for the rest of the pipeline.
        device: Device string passed to ``pipe.to``.

    Returns:
        The assembled pipeline, already moved to ``device``.

    Raises:
        KeyError: If ``fmt`` is not a key of ``CKPT``. The message lists the
            supported formats.
    """
    # Validate up front, before the heavy imports, and say what IS supported
    # instead of surfacing a bare KeyError('xyz').
    try:
        path, quant_type = CKPT[fmt]
    except KeyError:
        raise KeyError(
            f"unsupported fmt {fmt!r}; expected one of {sorted(CKPT)}"
        ) from None

    # ModelOpt must patch the diffusers model classes BEFORE from_single_file so the
    # pre-quantized (packed U8 + FP8 group scales + per-tensor scales) weights restore correctly.
    from modelopt.torch.opt import enable_huggingface_checkpointing
    enable_huggingface_checkpointing()
    import modelopt.torch.quantization as mtq
    from diffusers import Flux2Transformer2DModel, Flux2KleinPipeline, NVIDIAModelOptConfig

    # diffusers' get_config_from_quant_type() assumes modelopt's OLD dict-based quant_cfg and
    # crashes on modelopt 0.44 (list-based). Bypass it by passing modelopt's own canonical config,
    # which matches BFL's stored layout exactly (NVFP4: group-16 weights + FP8[E4M3] scales).
    mcfg = {"NVFP4": mtq.NVFP4_DEFAULT_CFG, "FP8": mtq.FP8_DEFAULT_CFG}[quant_type]
    qcfg = NVIDIAModelOptConfig(quant_type=quant_type, modelopt_config=mcfg)

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments during a long load.
    t0 = time.perf_counter()
    # config= points the single-file loader at the LOCAL transformer config so it does not try to
    # fetch a default (black-forest-labs/FLUX.2-dev). Same architecture as the teacher.
    tf = Flux2Transformer2DModel.from_single_file(
        path,
        config=f"{base}/transformer",
        quantization_config=qcfg,
        torch_dtype=torch.bfloat16,
    )
    print(f"[bfl-{fmt}] transformer loaded in {time.perf_counter()-t0:.0f}s", flush=True)

    pipe = Flux2KleinPipeline.from_pretrained(base, transformer=tf, torch_dtype=torch.bfloat16)
    pipe = pipe.to(device)
    return pipe
if __name__ == "__main__":
    # Sanity driver: python <this file> [fmt] [out_dir]
    # Loads the quantized pipeline, renders a few probe prompts at 512x512 with
    # 4 inference steps, saves them to out_dir and reports peak VRAM.
    import json
    import os
    import sys

    fmt = sys.argv[1] if len(sys.argv) > 1 else "nvfp4"
    OUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/sanity/bfl"
    os.makedirs(OUT, exist_ok=True)
    pipe = load_bfl_pipeline(fmt)

    # 3 sensitive probes (text / hand / counting) + 1 paired MJHQ prompt.
    # Each entry is (prompt, seed); the seed also names the output file.
    probes = [
        ('a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"', 0),
        ('a close-up photograph of a human hand with five fingers spread out, palm facing camera', 1),
        ('a flat-lay overhead photo of a wooden table with exactly three brown eggs in a white bowl', 2),
    ]
    # Close the prompt file deterministically instead of leaking the handle.
    with open("outputs/eval/prompts.json", encoding="utf-8") as fh:
        prompts = json.load(fh)
    probes.append((prompts[0]["prompt"], prompts[0]["idx"]))  # paired with teacher idx 0

    torch.cuda.reset_peak_memory_stats()
    written = set()
    for p, seed in probes:
        g = torch.Generator("cuda").manual_seed(seed)
        img = pipe(
            prompt=p,
            num_inference_steps=4,
            guidance_scale=1.0,
            height=512,
            width=512,
            generator=g,
        ).images[0]
        out_path = os.path.join(OUT, f"probe_{seed:05d}.png")
        # NOTE(review): files are keyed by seed only, so if the paired prompt's
        # idx equals one of the hand-written probe seeds (0/1/2) the earlier
        # image is overwritten. Surface that instead of losing it silently.
        if out_path in written:
            print(f"  WARNING: overwriting {out_path} (duplicate seed {seed})", flush=True)
        written.add(out_path)
        img.save(out_path)
        print(f"  saved probe seed={seed}: {p[:50]}", flush=True)
    print(f"[bfl-{fmt}] peak VRAM {torch.cuda.max_memory_allocated()/1e9:.2f} GB", flush=True)
Xet Storage Details
- Size:
- 3.51 kB
- Xet hash:
- 1ab3272edcc55626698b02d3dfe1c6eb69b031d939fad6f4a96a687e5e0e353a
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.