FLUX.1 [dev] Grid

This repository provides quantized weights for the FLUX.1 Kontext [dev] model, converted to NF4 format with BitsAndBytes. Quantization reduces the VRAM needed for GPU inference, so the model can run on the Google Colab free tier or on GPUs with 12 GB of VRAM.

The FLUX.1 Kontext [dev] model consists of three main components:

  • Text Encoders: CLIP and T5
  • Flux Transformer
  • VAE

In this repository, only the T5 encoder and the Flux Transformer are quantized. The CLIP encoder and VAE remain in their original precision but are included to ensure a fully functional inference pipeline.
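As a rough guide to why quantizing only these two components pays off, the sketch below estimates weight storage in bf16 versus NF4. It assumes a block size of 64 with one float32 scale per block and no double quantization, and it treats every parameter as quantized, which slightly overstates the savings. The parameter counts are approximate public figures, not values read from this repository, and activations, CLIP, and the VAE are not counted.

```python
def nf4_weight_gib(n_params: float, block_size: int = 64, scale_bits: int = 32) -> float:
    """Estimate NF4 weight storage in GiB.

    Each weight takes 4 bits, plus one scale of `scale_bits` bits shared by
    every `block_size` weights (assumed layout, no double quantization).
    """
    bits_per_param = 4 + scale_bits / block_size
    return n_params * bits_per_param / 8 / 2**30


def bf16_weight_gib(n_params: float) -> float:
    """Estimate bf16 weight storage in GiB (2 bytes per weight)."""
    return n_params * 2 / 2**30


# Approximate parameter counts (assumptions for illustration only).
APPROX_PARAMS = {
    "flux_transformer": 12e9,
    "t5_encoder": 4.7e9,
}

for name, n_params in APPROX_PARAMS.items():
    print(
        f"{name}: bf16 ~{bf16_weight_gib(n_params):.1f} GiB, "
        f"NF4 ~{nf4_weight_gib(n_params):.1f} GiB"
    )
```

These are weight-only estimates; the VRAM figures quoted under Usage also include the unquantized components and inference activations.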

Usage

pip install bitsandbytes==0.48.1 diffusers==0.35.1 peft==0.17.1 protobuf==5.29.5 sentencepiece==0.2.1 transformers==4.56.1

Identify GPU Compute Capability (for bfloat16 support)

import torch

print(torch.cuda.get_device_capability()[0])
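The printed major version can drive the dtype choice that the Split Pipeline Mode script makes inline. A minimal sketch follows; `pick_dtype_name` is a hypothetical helper, not part of any library, and it returns a name rather than a `torch.dtype` so that it can be tried without a GPU.

```python
def pick_dtype_name(major: int) -> str:
    """Return 'bfloat16' for compute capability >= 8, otherwise 'float16'.

    Mirrors the `major >= 8` check used in this README's split pipeline script.
    """
    return "bfloat16" if major >= 8 else "float16"


# Usage with PyTorch (shown as comments because it needs a CUDA device):
#   import torch
#   major, _ = torch.cuda.get_device_capability()
#   dtype = getattr(torch, pick_dtype_name(major))
print(pick_dtype_name(8), pick_dtype_name(7))
```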

Full Pipeline Mode (≈ 17.6 GB VRAM, Compute Capability ≥ 8)

import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

ckpt_4bit_id = "aniketppanchal/flux.1-kontext-dev-nf4-pkg"
input_image = load_image("<your_image_path_or_url_here>")
prompt = "<your_editing_prompt_here>"
height = 1024
width = 1024

pipeline = FluxKontextPipeline.from_pretrained(
    ckpt_4bit_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

image = pipeline(
    image=input_image,
    prompt=prompt,
    height=height,
    width=width,
    num_inference_steps=28,
    guidance_scale=3.5,
    max_sequence_length=512,
    max_area=height * width,
).images[0]
image.save("output.png")
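To decide between this mode and the Split Pipeline Mode, the requirements in the two headings can be checked up front. The sketch below is only a heuristic: `choose_mode` is a hypothetical helper, the thresholds are the approximate figures quoted in this README, and it uses total VRAM as a stand-in for free VRAM with no safety margin.

```python
FULL_MODE_VRAM_GB = 17.6   # approximate figure from the Full Pipeline Mode heading
SPLIT_MODE_VRAM_GB = 10.2  # approximate figure from the Split Pipeline Mode heading


def choose_mode(total_vram_gb: float, major: int) -> str:
    """Suggest 'full', 'split', or 'insufficient' for a given GPU."""
    # Full mode needs both enough memory and compute capability >= 8.
    if total_vram_gb >= FULL_MODE_VRAM_GB and major >= 8:
        return "full"
    # Split mode only has a memory requirement in this README.
    if total_vram_gb >= SPLIT_MODE_VRAM_GB:
        return "split"
    return "insufficient"


# Usage with PyTorch (shown as comments because it needs a CUDA device):
#   import torch
#   total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
#   major, _ = torch.cuda.get_device_capability()
#   print(choose_mode(total_gb, major))
print(choose_mode(24.0, 8), choose_mode(12.0, 8))
```

Whether the quoted figures are decimal GB or GiB is not stated here, so treat results near a threshold with caution.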

Split Pipeline Mode (≈ 10.2 GB VRAM)

import gc

import torch
from diffusers import FluxKontextPipeline, FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS
from diffusers.utils import load_image
from transformers import T5EncoderModel

ckpt_4bit_id = "aniketppanchal/flux.1-kontext-dev-nf4-pkg"
input_image = load_image("<your_image_path_or_url_here>")
prompt = "<your_editing_prompt_here>"
height = 1024
width = 1024

major, _ = torch.cuda.get_device_capability()

# ----------Encode Prompt Embeddings----------

text_encoder_2 = T5EncoderModel.from_pretrained(
    ckpt_4bit_id,
    subfolder="text_encoder_2",
    torch_dtype=torch.bfloat16 if major >= 8 else torch.float16,
    device_map="cuda",
)
pipeline = FluxKontextPipeline.from_pretrained(
    ckpt_4bit_id,
    text_encoder_2=text_encoder_2,
    transformer=None,
    vae=None,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
        prompt=prompt,
        max_sequence_length=512,
    )

del text_encoder_2, pipeline
gc.collect()
torch.cuda.empty_cache()

# ----------Preprocess and Encode Image to Latents----------

pipeline = FluxKontextPipeline.from_pretrained(
    ckpt_4bit_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    transformer=None,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

image_h, image_w = pipeline.image_processor.get_default_height_width(input_image)
aspect_ratio = image_w / image_h

_, pref_image_w, pref_image_h = min(
    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
)
multiple_of = pipeline.vae_scale_factor * 2
new_image_w = pref_image_w // multiple_of * multiple_of
new_image_h = pref_image_h // multiple_of * multiple_of

processed_image = pipeline.image_processor.resize(
    input_image.copy(),
    new_image_h,
    new_image_w,
)
processed_image = pipeline.image_processor.preprocess(
    processed_image,
    new_image_h,
    new_image_w,
)
processed_image = processed_image.to(device=pipeline.device, dtype=pipeline.vae.dtype)

with torch.no_grad():
    image_latents = pipeline._encode_vae_image(processed_image, generator=None)

del processed_image, pipeline
gc.collect()
torch.cuda.empty_cache()

# ----------Generate Diffusion Latents----------

transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_4bit_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16 if major >= 8 else torch.float16,
    device_map="cuda",
)
pipeline = FluxKontextPipeline.from_pretrained(
    ckpt_4bit_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    transformer=transformer,
    vae=None,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

packed_latents = pipeline(
    image=image_latents,
    height=height,
    width=width,
    num_inference_steps=28,
    guidance_scale=3.5,
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    output_type="latent",
    max_sequence_length=512,
    max_area=height * width,
).images

del (
    prompt_embeds,
    pooled_prompt_embeds,
    image_latents,
    transformer,
    pipeline,
)
gc.collect()
torch.cuda.empty_cache()

# ----------Decode Latents to Image----------

pipeline = FluxKontextPipeline.from_pretrained(
    ckpt_4bit_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    transformer=None,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

unpacked_latents = (
    pipeline._unpack_latents(
        packed_latents,
        height=height,
        width=width,
        vae_scale_factor=pipeline.vae_scale_factor,
    )
    / pipeline.vae.config.scaling_factor
    + pipeline.vae.config.shift_factor
).to(device=pipeline.device, dtype=pipeline.vae.dtype)

with torch.no_grad():
    image_tensor = pipeline.vae.decode(unpacked_latents, return_dict=False)[0]

image = pipeline.image_processor.postprocess(image_tensor)[0]
image.save("output.png")

del packed_latents, unpacked_latents, image_tensor, pipeline
gc.collect()
torch.cuda.empty_cache()
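The image preprocessing step in the script above picks the preferred resolution whose aspect ratio is closest to the input image, then rounds it down to a multiple of `vae_scale_factor * 2`. The standalone sketch below isolates that logic so it can be tried without loading any model. `EXAMPLE_RESOLUTIONS` is a small illustrative list, not the real `PREFERRED_KONTEXT_RESOLUTIONS`, and the default `multiple_of=16` assumes a `vae_scale_factor` of 8.

```python
# Illustrative (width, height) pairs only; the script above uses the
# PREFERRED_KONTEXT_RESOLUTIONS list imported from diffusers instead.
EXAMPLE_RESOLUTIONS = [
    (672, 1568),
    (832, 1248),
    (1024, 1024),
    (1248, 832),
    (1568, 672),
]


def pick_resolution(image_w, image_h, candidates, multiple_of=16):
    """Return the (width, height) candidate closest in aspect ratio.

    The result is rounded down to a multiple of `multiple_of`, matching the
    `// multiple_of * multiple_of` step in the script above.
    """
    aspect_ratio = image_w / image_h
    # min() compares tuples by their first element, the aspect-ratio distance.
    _, w, h = min((abs(aspect_ratio - cw / ch), cw, ch) for cw, ch in candidates)
    return w // multiple_of * multiple_of, h // multiple_of * multiple_of


print(pick_resolution(1920, 1080, EXAMPLE_RESOLUTIONS))
```

A landscape input therefore maps to a landscape target and a square input to the square one, so the resize step distorts the image as little as the candidate list allows.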

License

This repository is released under the FLUX.1 [dev] Non-Commercial License. The included LICENSE.md file is a snapshot of the license in the original repository as of 3rd November 2025. For the latest version, see the FLUX.1 [dev] License.
