FLUX.2 Klein 9B — TorchAO NVFP4 Selective Quantization

This is an unofficial, community-made TorchAO NVFP4 quantized version of black-forest-labs/FLUX.2-klein-9B for Diffusers.

It was created by loading the original BF16 Diffusers pipeline, applying selective TorchAO NVFP4 quantization to the transformer and text encoder, testing generation, and saving the resulting pipeline as .bin weights for reload compatibility.

This is not an official Black Forest Labs release and is not endorsed, approved, or validated by Black Forest Labs.

License and attribution

This model is a derivative of FLUX.2 Klein 9B and is distributed under the FLUX Non-Commercial License.

This FLUX Model is licensed by Black Forest Labs Inc. under the FLUX Non-Commercial License. Copyright Black Forest Labs Inc. IN NO EVENT SHALL BLACK FOREST LABS INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.

Additional notice: this repository contains a modified/quantized derivative of the original FLUX.2 Klein 9B model. The modification consists of selective TorchAO NVFP4 quantization of model components for inference.

For commercial use, consult Black Forest Labs licensing.

What was quantized?

The model was quantized using TorchAO NVFP4:

NVFP4DynamicActivationNVFP4WeightConfig(
    use_dynamic_per_tensor_scale=True,
    use_triton_kernel=True,
)

Transformer

Selective NVFP4 quantization was applied to the largest transformer linear layers that were both safe to quantize and fast under NVFP4.

Final transformer coverage:

| Component | Linear params | NVFP4 params | Coverage |
|---|---|---|---|
| transformer | 9.079B | ~8.05B | ~88.70% |

The production filter corresponds to the best benchmarked configuration, named aggressive_a during testing:

  • transformer_blocks.*.attn.to_q
  • transformer_blocks.*.attn.to_k
  • transformer_blocks.*.attn.to_v
  • transformer_blocks.*.ff.linear_in
  • transformer_blocks.*.ff.linear_out
  • transformer_blocks.*.ff_context.linear_in
  • transformer_blocks.*.ff_context.linear_out
  • single_transformer_blocks.*.attn.to_qkv_mlp_proj
  • single_transformer_blocks.*.attn.to_out

The remaining layers, including small, modulation, embedding, and output layers, stay in BF16.
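
The filter listed above can be expressed as a small predicate over module names. The sketch below is an illustration, not the script used to build this checkpoint: the names `PRODUCTION_PATTERNS` and `in_production_filter` are hypothetical, and it assumes that `*` in each pattern stands for a single numeric block index.

```python
import re

# Patterns copied from the list above; "*" is assumed to be the block index.
PRODUCTION_PATTERNS = [
    "transformer_blocks.*.attn.to_q",
    "transformer_blocks.*.attn.to_k",
    "transformer_blocks.*.attn.to_v",
    "transformer_blocks.*.ff.linear_in",
    "transformer_blocks.*.ff.linear_out",
    "transformer_blocks.*.ff_context.linear_in",
    "transformer_blocks.*.ff_context.linear_out",
    "single_transformer_blocks.*.attn.to_qkv_mlp_proj",
    "single_transformer_blocks.*.attn.to_out",
]


def _compile(pattern):
    # Escape the literal parts and let "*" match one numeric index only.
    literal_parts = [re.escape(part) for part in pattern.split("*")]
    return re.compile(r"\d+".join(literal_parts))


_COMPILED = [_compile(p) for p in PRODUCTION_PATTERNS]


def in_production_filter(fqn: str) -> bool:
    """Return True if a fully qualified module name matches the filter."""
    return any(rx.fullmatch(fqn) for rx in _COMPILED)


if __name__ == "__main__":
    for name in ("transformer_blocks.3.attn.to_q", "proj_out"):
        print(name, in_production_filter(name))
```

TorchAO's `quantize_` accepts a `filter_fn(module, fqn)` callback, so a predicate like this could be wrapped as `lambda module, fqn: isinstance(module, torch.nn.Linear) and in_production_filter(fqn)`. Full-string matching keeps `transformer_blocks.*` patterns from also matching `single_transformer_blocks.*` names.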

Text encoder

Selective NVFP4 quantization was applied to Qwen3 text encoder attention and MLP linear layers.

Final text encoder coverage:

| Component | Linear params | NVFP4 params | Coverage |
|---|---|---|---|
| text_encoder | 7.568B | 6.946B | 91.78% |

The lm_head remains BF16.
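
The coverage figures are the share of linear-layer parameters stored as NVFP4. A minimal sketch of that arithmetic, using the rounded counts reported above (the function name `coverage_pct` is hypothetical):

```python
def coverage_pct(nvfp4_params_b: float, linear_params_b: float) -> float:
    """Percentage of linear parameters quantized; both counts in billions."""
    return 100.0 * nvfp4_params_b / linear_params_b


if __name__ == "__main__":
    # Counts taken from the coverage figures in this document.
    print(f"text_encoder: {coverage_pct(6.946, 7.568):.2f}%")
    # The transformer count is given only approximately (~8.05B), so the
    # result can differ slightly from the reported ~88.70%.
    print(f"transformer:  {coverage_pct(8.05, 9.079):.2f}%")
```

Because the transformer's NVFP4 count is rounded, the recomputed value agrees with the reported coverage only to within rounding.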

VAE

The VAE remains BF16.

Benchmark

Benchmarked on an NVIDIA RTX PRO 6000 (Blackwell) GPU with CUDA and TorchAO NVFP4 support.

Prompt:

A cat holding a sign that says hello world

Settings:

height=1024
width=1024
guidance_scale=4.0
num_inference_steps=4
dtype=torch.bfloat16

Results:

| Model | Avg latency | Peak VRAM | Speedup |
|---|---|---|---|
| BF16 original | 1.673 s | 34.76 GB | 1.00x |
| NVFP4 | 1.087 s | 14.72 GB | 1.54x |
| NVFP4 + compile | 0.817 s | 14.72 GB | 2.05x |

compile_repeated_blocks(fullgraph=True) was used for the compiled benchmark after warmup.
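
As a rough cross-check of these numbers, the sketch below recomputes the speedups from the latencies and estimates linear-weight memory from the coverage figures. It is an estimate, not a measurement. It assumes NVFP4 costs about 4.5 bits per weight (4-bit values plus one FP8 scale per 16-element block), ignores per-tensor scales, non-linear parameters, the VAE, and activations, and treats 1 GB as 1e9 bytes. All names are hypothetical.

```python
BF16_BYTES = 2.0
# 4-bit value (0.5 byte) plus one 1-byte FP8 scale shared by 16 values.
NVFP4_BYTES = 0.5 + 1.0 / 16.0

# (linear params, NVFP4 params) in billions, from the coverage figures.
COMPONENTS = {
    "transformer": (9.079, 8.05),
    "text_encoder": (7.568, 6.946),
}


def linear_weight_gb(linear_params_b: float, nvfp4_params_b: float) -> float:
    """Estimated size in GB of a component's linear weights after quantization."""
    bf16_params_b = linear_params_b - nvfp4_params_b
    return nvfp4_params_b * NVFP4_BYTES + bf16_params_b * BF16_BYTES


def speedup(baseline_s: float, candidate_s: float) -> float:
    """Latency ratio relative to the BF16 baseline."""
    return baseline_s / candidate_s


before_gb = sum(total * BF16_BYTES for total, _ in COMPONENTS.values())
after_gb = sum(linear_weight_gb(total, q) for total, q in COMPONENTS.values())

if __name__ == "__main__":
    print(f"linear weights: {before_gb:.2f} GB -> {after_gb:.2f} GB")
    print(f"NVFP4 speedup: {speedup(1.673, 1.087):.2f}x")
    print(f"NVFP4 + compile speedup: {speedup(1.673, 0.817):.2f}x")
```

Under these assumptions the linear weights shrink by roughly 21.6 GB, which is in the same range as the roughly 20 GB drop in measured peak VRAM. The two figures are not expected to match exactly, since peak VRAM also covers everything the estimate leaves out.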

Installation

A recent PyTorch/TorchAO stack with Blackwell NVFP4 support is recommended.

Example:

pip install -U diffusers transformers accelerate safetensors sentencepiece protobuf huggingface_hub
pip install --pre -U torch torchvision torchao mslk --index-url https://download.pytorch.org/whl/nightly/cu130

Verify your GPU stack:

import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.get_device_name(0))
print(torch.cuda.get_arch_list())
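
To turn that check into a pass/fail test, one option is to look at the compute capability. The helper below is a heuristic sketch: the name `looks_like_blackwell` is hypothetical, and the rule "major version 10 or higher" is an assumption based on Blackwell GPUs reporting compute capability 10.x or 12.x. It does not prove that the installed TorchAO build ships working NVFP4 kernels.

```python
from typing import Tuple


def looks_like_blackwell(capability: Tuple[int, int]) -> bool:
    """Heuristic: compute capability major >= 10 means Blackwell or newer."""
    major, _minor = capability
    return major >= 10


if __name__ == "__main__":
    # Example capabilities: (12, 0) Blackwell workstation, (9, 0) Hopper.
    for cap in [(12, 0), (10, 0), (9, 0), (8, 9)]:
        print(cap, looks_like_blackwell(cap))
```

With PyTorch installed, the tuple can come from `torch.cuda.get_device_capability(0)`.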

Usage

Important: this repo uses .bin serialization for TorchAO tensor subclasses, so load with use_safetensors=False.

import torch
from diffusers import Flux2KleinPipeline

# use_safetensors=False is required because the NVFP4 weights are stored as .bin files.
pipe = Flux2KleinPipeline.from_pretrained(
    "joseplcam/FLUX.2-klein-9B-nvfp4",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
).to("cuda")

prompt = "A cat holding a sign that says hello world"

image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=4.0,
    num_inference_steps=4,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]

image.save("flux2-klein-nvfp4.png")

Optional compile

For best steady-state speed:

pipe.transformer.compile_repeated_blocks(fullgraph=True)

Run one or two warmup generations before measuring latency, because the first compiled run includes compile overhead.
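
A minimal timing helper along these lines, assuming nothing about how the benchmark above was actually scripted (the name `average_latency` and its parameters are hypothetical):

```python
import time
from statistics import mean
from typing import Callable, Optional


def average_latency(
    run: Callable[[], object],
    warmup: int = 2,
    iters: int = 5,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Average wall-clock seconds per call, excluding warmup calls."""
    # Warmup calls absorb one-time costs such as compilation.
    for _ in range(warmup):
        run()
    if sync is not None:
        sync()  # make sure warmup work has finished before timing

    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        run()
        if sync is not None:
            sync()  # wait for asynchronous GPU work before stopping the clock
        timings.append(time.perf_counter() - start)
    return mean(timings)


if __name__ == "__main__":
    # Stand-in workload; replace with a pipeline call in practice.
    print(average_latency(lambda: sum(range(100_000))))
```

With the pipeline from the usage example, `run` could be `lambda: pipe(prompt=prompt, height=1024, width=1024, guidance_scale=4.0, num_inference_steps=4)` and `sync` could be `torch.cuda.synchronize`, so that queued GPU work is included in each measurement.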

Notes

  • This model is intended for inference.
  • This is a PyTorch/Diffusers/TorchAO quantized repo, not a GGUF model.
  • The checkpoint is saved as .bin because current TorchAO NVFP4 tensor subclasses are not reliably compatible with safetensors serialization in this workflow.
  • Use use_safetensors=False when loading.
  • This model is for non-commercial use under the FLUX Non-Commercial License.