diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dc4dab5710f76ef392a8e80791479e040e335025
--- /dev/null
+++ b/README.md
@@ -0,0 +1,102 @@
+# FLUX-Kontext Optimized Implementation
+
+This package contains an optimized implementation of FLUX-Kontext using quantization and acceleration techniques.
+
+## Features
+
+- **Quantized FLUX Transformer**: Efficient INT4/FP4 quantized implementation of FLUX.1-Kontext
+- **Quantized T5 Encoder**: AWQ INT4 quantized T5 text encoder for memory efficiency
+- **LoRA Support**: Full support for LoRA fine-tuning and inference
+- **Gradio Web Interface**: Ready-to-use web interface for image editing
+
+## Installation
+
+```bash
+pip install -r requirements.txt
+python setup.py build_ext --inplace
+```
+
+## Quick Start
+
+### Using the Gradio Interface
+
+```bash
+cd app/kontext
+python run_gradio.py --precision int4
+```
+
+### Programmatic Usage
+
+```python
+import torch
+from diffusers import FluxKontextPipeline
+from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
+
+# Load quantized transformer
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+ "mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-int4_r32-flux.1-kontext-dev.safetensors"
+)
+
+# Load quantized text encoder (optional)
+text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+ "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+)
+
+# Create pipeline
+pipeline = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev",
+ transformer=transformer,
+ text_encoder_2=text_encoder_2,
+ torch_dtype=torch.bfloat16
+)
+pipeline = pipeline.to("cuda")
+
+# Generate image
+result = pipeline(
+ prompt="Your prompt here",
+ image=your_input_image,
+ num_inference_steps=28,
+ guidance_scale=2.5,
+).images[0]
+```
+
+## Available Models
+
+- `int4`: INT4 quantized transformer (default, most memory efficient)
+- `fp4`: FP4 quantized transformer
+- `bf16`: Full precision BFloat16 (highest quality, most memory usage)
+
+## Directory Structure
+
+```
+flux-kontext/
+├── nunchaku/ # Core quantized models and utilities
+│ ├── models/ # Transformer and text encoder models
+│ ├── lora/ # LoRA utilities
+│ ├── ops/ # Quantized operations
+│ └── csrc/ # C++ CUDA kernels
+├── app/ # Application interfaces
+│ └── kontext/ # Gradio web interface
+├── examples/ # Example scripts
+└── tests/ # Test scripts
+```
+
+## Examples
+
+See the `examples/` directory for various usage patterns:
+
+- `flux.1-kontext-dev.py`: Basic usage example
+- `flux.1-kontext-dev-teacache.py`: Using TeaCache for acceleration
+- `flux.1-kontext-FALAI_lora.py`: LoRA fine-tuning example
+
+## Requirements
+
+- Python >= 3.10
+- PyTorch >= 2.5
+- CUDA-capable GPU (recommended)
+- 8GB+ GPU memory (for INT4 quantization)
+
+## License
+
+See the main nunchaku project for license information.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c7f9d040f03eb1b5df4784d4331136cd362833
--- /dev/null
+++ b/app.py
@@ -0,0 +1,150 @@
+import random
+import time
+
+import torch
+from diffusers import FluxKontextPipeline
+from PIL import Image
+from utils import get_args
+from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
+
+
+import gradio as gr
+
+
+MAX_SEED = 1000000000
+
+args = get_args()
+
+if args.precision == "bf16":
+ pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
+ pipeline = pipeline.to("cuda")
+ pipeline.precision = "bf16"
+else:
+ assert args.precision in ["int4", "fp4"]
+ pipeline_init_kwargs = {}
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+ f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
+ )
+ pipeline_init_kwargs["transformer"] = transformer
+ if args.use_qencoder:
+ text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+ "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+ )
+ pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
+
+ pipeline = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
+ )
+ pipeline = pipeline.to("cuda")
+ pipeline.precision = args.precision
+
+
+def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]:
+ img = image["composite"].convert("RGB")
+
+ start_time = time.time()
+ result_image = pipeline(
+ prompt=prompt,
+ image=img,
+ height=img.height,
+ width=img.width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator().manual_seed(seed),
+ ).images[0]
+
+ latency = time.time() - start_time
+ if latency < 1:
+ latency = latency * 1000
+ latency_str = f"{latency:.2f}ms"
+ else:
+ latency_str = f"{latency:.2f}s"
+ torch.cuda.empty_cache()
+ return result_image, latency_str
+
+
+with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
+ with open("assets/description.html", "r") as f:
+ DESCRIPTION = f.read()
+ # Get the GPU properties
+ if torch.cuda.device_count() > 0:
+ gpu_properties = torch.cuda.get_device_properties(0)
+ gpu_memory = gpu_properties.total_memory / (1024**3) # Convert to GiB
+ gpu_name = torch.cuda.get_device_name(0)
+ device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
+ else:
+ device_info = "Running on CPU 🥶 This demo does not work on CPU."
+
+ header_str = DESCRIPTION.format(precision=args.precision.upper(), device_info=device_info, count_info="")
+ header = gr.HTML(header_str)
+
+ with gr.Row(elem_id="main_row"):
+ with gr.Column(elem_id="column_input"):
+ gr.Markdown("## INPUT", elem_id="input_header")
+ with gr.Group():
+ canvas = gr.ImageEditor(
+ height=640,
+ image_mode="RGB",
+ sources=["upload", "clipboard"],
+ type="pil",
+ label="Input",
+ show_label=False,
+ show_download_button=True,
+ interactive=True,
+ transforms=[],
+ canvas_size=(1024, 1024),
+ scale=1,
+ format="png",
+ layers=False,
+ )
+ with gr.Row():
+ prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
+ run_button = gr.Button("Run", scale=1, elem_id="run_button")
+
+ with gr.Row():
+ seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
+ randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
+ with gr.Accordion("Advanced options", open=False):
+ with gr.Group():
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)
+
+ with gr.Column(elem_id="column_output"):
+ gr.Markdown("## OUTPUT", elem_id="output_header")
+ with gr.Group():
+ result = gr.Image(
+ format="png",
+ height=640,
+ image_mode="RGB",
+ type="pil",
+ label="Result",
+ show_label=False,
+ show_download_button=True,
+ interactive=False,
+ elem_id="output_image",
+ )
+ latency_result = gr.Text(label="Inference Latency", show_label=True)
+
+ gr.Markdown("### Instructions")
+ gr.Markdown("**1**. Enter a text prompt")
+ gr.Markdown("**2**. Upload an image")
+ gr.Markdown("**3**. Try different seeds to generate different results")
+
+ run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
+ run_outputs = [result, latency_result]
+
+ randomize_seed.click(
+ lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
+ ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
+
+ gr.on(
+ triggers=[prompt.submit, run_button.click],
+ fn=run,
+ inputs=run_inputs,
+ outputs=run_outputs,
+ api_name=False,
+ )
+
+if __name__ == "__main__":
+ demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
diff --git a/app/kontext/README.md b/app/kontext/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f97be1005826aa11a86e1da103f150e8169f70e5
--- /dev/null
+++ b/app/kontext/README.md
@@ -0,0 +1,12 @@
+# Nunchaku INT4 FLUX.1 Kontext Demo
+
+
+
+This interactive Gradio application allows you to edit an image with natural language. Simply run:
+
+```shell
+python run_gradio.py
+```
+
+- To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
+- By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
diff --git a/app/kontext/assets/description.html b/app/kontext/assets/description.html
new file mode 100644
index 0000000000000000000000000000000000000000..e4d791567cdda3def8cc8f5f248690ead6fcf60c
--- /dev/null
+++ b/app/kontext/assets/description.html
@@ -0,0 +1,21 @@
+
+
+
+
+
{precision} FLUX.1-Kontext-dev Demo
+
+
+ {device_info}
+
+ {count_info}
+
+
diff --git a/app/kontext/assets/style.css b/app/kontext/assets/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..6df54d33fe823c2dcb49ead8b33fb24fc0037bc1
--- /dev/null
+++ b/app/kontext/assets/style.css
@@ -0,0 +1,40 @@
+@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
+
+.gradio-container {
+ max-width: 1200px !important;
+ margin: auto; /* Centers the element horizontally */
+}
+
+h1 {
+ text-align: center
+}
+
+.wrap.svelte-p4aq0j.svelte-p4aq0j {
+ display: none;
+}
+
+#column_input, #column_output {
+ width: 500px;
+ display: flex;
+ align-items: center;
+}
+
+#input_header, #output_header {
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ width: 400px;
+}
+
+#accessibility {
+ text-align: center; /* Center-aligns the text */
+ margin: auto; /* Centers the element horizontally */
+}
+
+#random_seed {
+ height: 71px;
+}
+
+#run_button {
+ height: 87px;
+}
diff --git a/app/kontext/run_gradio.py b/app/kontext/run_gradio.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c7f9d040f03eb1b5df4784d4331136cd362833
--- /dev/null
+++ b/app/kontext/run_gradio.py
@@ -0,0 +1,150 @@
+import random
+import time
+
+import torch
+from diffusers import FluxKontextPipeline
+from PIL import Image
+from utils import get_args
+from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
+
+
+import gradio as gr
+
+
+MAX_SEED = 1000000000
+
+args = get_args()
+
+if args.precision == "bf16":
+ pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
+ pipeline = pipeline.to("cuda")
+ pipeline.precision = "bf16"
+else:
+ assert args.precision in ["int4", "fp4"]
+ pipeline_init_kwargs = {}
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+ f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
+ )
+ pipeline_init_kwargs["transformer"] = transformer
+ if args.use_qencoder:
+ text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+ "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+ )
+ pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
+
+ pipeline = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
+ )
+ pipeline = pipeline.to("cuda")
+ pipeline.precision = args.precision
+
+
+def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]:
+ img = image["composite"].convert("RGB")
+
+ start_time = time.time()
+ result_image = pipeline(
+ prompt=prompt,
+ image=img,
+ height=img.height,
+ width=img.width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator().manual_seed(seed),
+ ).images[0]
+
+ latency = time.time() - start_time
+ if latency < 1:
+ latency = latency * 1000
+ latency_str = f"{latency:.2f}ms"
+ else:
+ latency_str = f"{latency:.2f}s"
+ torch.cuda.empty_cache()
+ return result_image, latency_str
+
+
+with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
+ with open("assets/description.html", "r") as f:
+ DESCRIPTION = f.read()
+ # Get the GPU properties
+ if torch.cuda.device_count() > 0:
+ gpu_properties = torch.cuda.get_device_properties(0)
+ gpu_memory = gpu_properties.total_memory / (1024**3) # Convert to GiB
+ gpu_name = torch.cuda.get_device_name(0)
+ device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
+ else:
+ device_info = "Running on CPU 🥶 This demo does not work on CPU."
+
+ header_str = DESCRIPTION.format(precision=args.precision.upper(), device_info=device_info, count_info="")
+ header = gr.HTML(header_str)
+
+ with gr.Row(elem_id="main_row"):
+ with gr.Column(elem_id="column_input"):
+ gr.Markdown("## INPUT", elem_id="input_header")
+ with gr.Group():
+ canvas = gr.ImageEditor(
+ height=640,
+ image_mode="RGB",
+ sources=["upload", "clipboard"],
+ type="pil",
+ label="Input",
+ show_label=False,
+ show_download_button=True,
+ interactive=True,
+ transforms=[],
+ canvas_size=(1024, 1024),
+ scale=1,
+ format="png",
+ layers=False,
+ )
+ with gr.Row():
+ prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
+ run_button = gr.Button("Run", scale=1, elem_id="run_button")
+
+ with gr.Row():
+ seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
+ randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
+ with gr.Accordion("Advanced options", open=False):
+ with gr.Group():
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)
+
+ with gr.Column(elem_id="column_output"):
+ gr.Markdown("## OUTPUT", elem_id="output_header")
+ with gr.Group():
+ result = gr.Image(
+ format="png",
+ height=640,
+ image_mode="RGB",
+ type="pil",
+ label="Result",
+ show_label=False,
+ show_download_button=True,
+ interactive=False,
+ elem_id="output_image",
+ )
+ latency_result = gr.Text(label="Inference Latency", show_label=True)
+
+ gr.Markdown("### Instructions")
+ gr.Markdown("**1**. Enter a text prompt")
+ gr.Markdown("**2**. Upload an image")
+ gr.Markdown("**3**. Try different seeds to generate different results")
+
+ run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
+ run_outputs = [result, latency_result]
+
+ randomize_seed.click(
+ lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
+ ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
+
+ gr.on(
+ triggers=[prompt.submit, run_button.click],
+ fn=run,
+ inputs=run_inputs,
+ outputs=run_outputs,
+ api_name=False,
+ )
+
+if __name__ == "__main__":
+ demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
diff --git a/app/kontext/utils.py b/app/kontext/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..35956ee7688adc6fd052930dabc9cbfa95eae889
--- /dev/null
+++ b/app/kontext/utils.py
@@ -0,0 +1,14 @@
+import argparse
+
+
+def get_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use"
+ )
+ parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
+ parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
+ parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
+ parser.add_argument("--gradio-root-path", type=str, default="")
+ args = parser.parse_args()
+ return args
diff --git a/examples/__init__.py b/examples/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f60b2ebcab3d03d5863d66723ffc472e1c55ccb
--- /dev/null
+++ b/examples/__init__.py
@@ -0,0 +1 @@
+# FLUX-Kontext examples
diff --git a/examples/flux.1-kontext-FALAI_lora.py b/examples/flux.1-kontext-FALAI_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cf9d30808845fc0799983cc8be83c0b3a98f744
--- /dev/null
+++ b/examples/flux.1-kontext-FALAI_lora.py
@@ -0,0 +1,30 @@
+import torch
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+
+from nunchaku import NunchakuFluxTransformer2dModel
+from nunchaku.utils import get_precision
+
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+ f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
+)
+
+pipeline = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
+).to("cuda")
+
+image = load_image(
+ "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
+).convert("RGB")
+
+### LoRA Related Code ###
+transformer.update_lora_params(
+ "nunchaku-tech/nunchaku-test-models/relight-kontext-lora-single-caption_comfy.safetensors"
+ # "linoyts/relight-kontext-lora-single-caption/relight-kontext-lora-single-caption.safetensors"
+) # Path to your LoRA safetensors, can also be a remote HuggingFace path
+transformer.set_lora_strength(1) # Your LoRA strength here
+### End of LoRA Related Code ###
+
+prompt = "neon light, city"
+image = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(23), guidance_scale=2.5).images[0]
+image.save("flux-kontext-dev.png")
diff --git a/examples/flux.1-kontext-dev-teacache.py b/examples/flux.1-kontext-dev-teacache.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9f40c61677bf522231b3edb498e9606a641897a
--- /dev/null
+++ b/examples/flux.1-kontext-dev-teacache.py
@@ -0,0 +1,30 @@
+import time
+
+import torch
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+
+from nunchaku import NunchakuFluxTransformer2dModel
+from nunchaku.caching.teacache import TeaCache
+from nunchaku.utils import get_precision
+
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+ f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
+)
+
+pipeline = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
+).to("cuda")
+
+image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+).convert("RGB")
+
+prompt = "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors"
+
+start_time = time.time()
+with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=True, model_name="flux-kontext"):
+ image = pipeline(image=image, prompt=prompt, guidance_scale=2.5).images[0]
+end_time = time.time()
+print(f"Time taken: {(end_time - start_time)} seconds")
+image.save(f"flux-kontext-dev-{get_precision()}-tc.png")
diff --git a/examples/flux.1-kontext-dev.py b/examples/flux.1-kontext-dev.py
new file mode 100644
index 0000000000000000000000000000000000000000..2868d582f158d49c16d27ee39a8b0b40d1b3db6d
--- /dev/null
+++ b/examples/flux.1-kontext-dev.py
@@ -0,0 +1,22 @@
+import torch
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+
+from nunchaku import NunchakuFluxTransformer2dModel
+from nunchaku.utils import get_precision
+
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+ f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
+)
+
+pipeline = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
+).to("cuda")
+
+image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+).convert("RGB")
+
+prompt = "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors"
+image = pipeline(image=image, prompt=prompt, guidance_scale=2.5).images[0]
+image.save("flux-kontext-dev.png")
diff --git a/nunchaku/__init__.py b/nunchaku/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..93a19f86d836471ac684e7fbf4d7020e69d7470f
--- /dev/null
+++ b/nunchaku/__init__.py
@@ -0,0 +1,9 @@
+from .models import (
+ NunchakuFluxTransformer2dModel,
+ NunchakuT5EncoderModel,
+)
+
+__all__ = [
+ "NunchakuFluxTransformer2dModel",
+ "NunchakuT5EncoderModel",
+]
diff --git a/nunchaku/__version__.py b/nunchaku/__version__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4deaef52f24d6a5663ae2d6b3864b255c02b92b
--- /dev/null
+++ b/nunchaku/__version__.py
@@ -0,0 +1 @@
+__version__ = "1.0.0-flux-kontext"
diff --git a/nunchaku/csrc/flux.h b/nunchaku/csrc/flux.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2488400b0a285f34ee1b52ff96e194f7f86f029
--- /dev/null
+++ b/nunchaku/csrc/flux.h
@@ -0,0 +1,254 @@
+#pragma once
+
+#include "interop/torch.h"
+#include "FluxModel.h"
+#include "Serialization.h"
+#include "debug.h"
+#include "Linear.h"
+#include "module.h"
+
+class QuantizedFluxModel : public ModuleWrapper { // : public torch::CustomClassHolder {
+public:
+ void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
+ spdlog::info("Initializing QuantizedFluxModel on device {}", deviceId);
+ if (!bf16) {
+ spdlog::info("Use FP16 model");
+ }
+ if (offload) {
+ spdlog::info("Layer offloading enabled");
+ }
+ ModuleWrapper::init(deviceId);
+
+ CUDADeviceContext ctx(this->deviceId);
+ net = std::make_unique(
+ use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+ }
+
+ bool isBF16() {
+ checkModel();
+ return net->dtype == Tensor::BF16;
+ }
+ pybind11::function residual_callback;
+ void set_residual_callback(pybind11::function callback) {
+ pybind11::gil_scoped_acquire gil;
+ if (!callback || callback.is_none()) {
+ residual_callback = pybind11::function();
+ if (net) {
+ net->set_residual_callback(nullptr);
+ }
+ return;
+ }
+ residual_callback = std::move(callback);
+ if (net) {
+ pybind11::object cb = residual_callback;
+ net->set_residual_callback([cb](const Tensor &x) -> Tensor {
+ torch::Tensor torch_x = to_torch(x);
+ pybind11::object result = cb(torch_x);
+ torch::Tensor torch_y = result.cast();
+ Tensor y = from_torch(torch_y);
+ return y;
+ });
+ } else {
+ }
+ }
+
+ torch::Tensor forward(torch::Tensor hidden_states,
+ torch::Tensor encoder_hidden_states,
+ torch::Tensor temb,
+ torch::Tensor rotary_emb_img,
+ torch::Tensor rotary_emb_context,
+ torch::Tensor rotary_emb_single,
+ std::optional controlnet_block_samples = std::nullopt,
+ std::optional controlnet_single_block_samples = std::nullopt,
+ bool skip_first_layer = false) {
+ checkModel();
+ CUDADeviceContext ctx(deviceId);
+
+ spdlog::debug("QuantizedFluxModel forward");
+
+ hidden_states = hidden_states.contiguous();
+ encoder_hidden_states = encoder_hidden_states.contiguous();
+ temb = temb.contiguous();
+ rotary_emb_img = rotary_emb_img.contiguous();
+ rotary_emb_context = rotary_emb_context.contiguous();
+ rotary_emb_single = rotary_emb_single.contiguous();
+
+ Tensor result = net->forward(
+ from_torch(hidden_states),
+ from_torch(encoder_hidden_states),
+ from_torch(temb),
+ from_torch(rotary_emb_img),
+ from_torch(rotary_emb_context),
+ from_torch(rotary_emb_single),
+ controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
+ controlnet_single_block_samples.has_value()
+ ? from_torch(controlnet_single_block_samples.value().contiguous())
+ : Tensor{},
+ skip_first_layer);
+
+ torch::Tensor output = to_torch(result);
+ Tensor::synchronizeDevice();
+
+ return output;
+ }
+
+ std::tuple
+ forward_layer(int64_t idx,
+ torch::Tensor hidden_states,
+ torch::Tensor encoder_hidden_states,
+ torch::Tensor temb,
+ torch::Tensor rotary_emb_img,
+ torch::Tensor rotary_emb_context,
+ std::optional controlnet_block_samples = std::nullopt,
+ std::optional controlnet_single_block_samples = std::nullopt) {
+ CUDADeviceContext ctx(deviceId);
+
+ spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
+
+ hidden_states = hidden_states.contiguous();
+ encoder_hidden_states = encoder_hidden_states.contiguous();
+ temb = temb.contiguous();
+ rotary_emb_img = rotary_emb_img.contiguous();
+ rotary_emb_context = rotary_emb_context.contiguous();
+
+ auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
+ idx,
+ from_torch(hidden_states),
+ from_torch(encoder_hidden_states),
+ from_torch(temb),
+ from_torch(rotary_emb_img),
+ from_torch(rotary_emb_context),
+ controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
+ controlnet_single_block_samples.has_value()
+ ? from_torch(controlnet_single_block_samples.value().contiguous())
+ : Tensor{});
+
+ hidden_states = to_torch(hidden_states_);
+ encoder_hidden_states = to_torch(encoder_hidden_states_);
+ Tensor::synchronizeDevice();
+
+ return {hidden_states, encoder_hidden_states};
+ }
+
+ torch::Tensor forward_single_layer(int64_t idx,
+ torch::Tensor hidden_states,
+ torch::Tensor temb,
+ torch::Tensor rotary_emb_single) {
+ CUDADeviceContext ctx(deviceId);
+
+ spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
+
+ hidden_states = hidden_states.contiguous();
+ temb = temb.contiguous();
+ rotary_emb_single = rotary_emb_single.contiguous();
+
+ if (net->isOffloadEnabled()) {
+ net->single_transformer_blocks.at(idx)->loadLazyParams();
+ }
+
+ Tensor result = net->single_transformer_blocks.at(idx)->forward(
+ from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single));
+
+ if (net->isOffloadEnabled()) {
+ net->single_transformer_blocks.at(idx)->releaseLazyParams();
+ }
+
+ hidden_states = to_torch(result);
+ Tensor::synchronizeDevice();
+
+ return hidden_states;
+ }
+
+ // expose the norm1 forward method of the transformer blocks
+ // this is used by TeaCache to get the norm1 output
+ std::tuple
+ norm_one_forward(int64_t idx, torch::Tensor hidden_states, torch::Tensor temb) {
+ AdaLayerNormZero::Output result =
+ net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
+ return {to_torch(result.x),
+ to_torch(result.gate_msa),
+ to_torch(result.shift_mlp),
+ to_torch(result.scale_mlp),
+ to_torch(result.gate_mlp)};
+ }
+
+ // must be called after loading lora
+ // skip specific ranks in W4A4 layers
+ void setLoraScale(int skipRanks, float scale) {
+ if (skipRanks % 16 != 0) {
+ throw std::invalid_argument("skipRanks must be multiples of 16");
+ }
+
+ CUDADeviceContext ctx(deviceId);
+
+ spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks);
+
+ net->traverse([&](Module *module) {
+ if (auto *m = dynamic_cast(module)) {
+ m->lora_scale = scale;
+ } else if (auto *m = dynamic_cast(module)) {
+ for (int i = 0; i < skipRanks / 16; i++) {
+ m->lora_scales[i] = 1.0f;
+ }
+ for (int i = skipRanks / 16; i < (int)m->lora_scales.size(); i++) {
+ m->lora_scales[i] = scale;
+ }
+ }
+ });
+ }
+
+ void setAttentionImpl(std::string name) {
+ if (name.empty() || name == "default") {
+ name = "flashattn2";
+ }
+
+ spdlog::info("Set attention implementation to {}", name);
+
+ if (name == "flashattn2") {
+ net->setAttentionImpl(AttentionImpl::FlashAttention2);
+ } else if (name == "nunchaku-fp16") {
+ net->setAttentionImpl(AttentionImpl::NunchakuFP16);
+ } else {
+ throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
+ }
+ }
+
+ std::tuple
+ forward_layer_ip_adapter(int64_t idx,
+ torch::Tensor hidden_states,
+ torch::Tensor encoder_hidden_states,
+ torch::Tensor temb,
+ torch::Tensor rotary_emb_img,
+ torch::Tensor rotary_emb_context,
+ std::optional controlnet_block_samples = std::nullopt,
+ std::optional controlnet_single_block_samples = std::nullopt) {
+ CUDADeviceContext ctx(deviceId);
+
+ spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
+
+ hidden_states = hidden_states.contiguous();
+ encoder_hidden_states = encoder_hidden_states.contiguous();
+ temb = temb.contiguous();
+ rotary_emb_img = rotary_emb_img.contiguous();
+ rotary_emb_context = rotary_emb_context.contiguous();
+
+ auto &&[hidden_states_, encoder_hidden_states_, ip_query_] = net->forward_ip_adapter(
+ idx,
+ from_torch(hidden_states),
+ from_torch(encoder_hidden_states),
+ from_torch(temb),
+ from_torch(rotary_emb_img),
+ from_torch(rotary_emb_context),
+ controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
+ controlnet_single_block_samples.has_value()
+ ? from_torch(controlnet_single_block_samples.value().contiguous())
+ : Tensor{});
+
+ hidden_states = to_torch(hidden_states_);
+ encoder_hidden_states = to_torch(encoder_hidden_states_);
+ torch::Tensor ip_query = to_torch(ip_query_);
+ Tensor::synchronizeDevice();
+
+ return {hidden_states, encoder_hidden_states, ip_query};
+ }
+};
diff --git a/nunchaku/csrc/gemm.h b/nunchaku/csrc/gemm.h
new file mode 100644
index 0000000000000000000000000000000000000000..74d1fb4b5ea15200d643fa98e406ef836ccfd307
--- /dev/null
+++ b/nunchaku/csrc/gemm.h
@@ -0,0 +1,114 @@
+#pragma once
+
+#include "interop/torch.h"
+#include "Serialization.h"
+#include "Linear.h"
+#include "debug.h"
+#include "module.h"
+
+class QuantizedGEMM : public ModuleWrapper {
+public:
+ void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
+ spdlog::info("Initializing QuantizedGEMM");
+
+ size_t val = 0;
+ checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
+ checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
+ spdlog::debug("Stack={}", val);
+
+ net = std::make_unique((int)in_features,
+ (int)out_features,
+ bias,
+ use_fp4,
+ bf16 ? Tensor::BF16 : Tensor::FP16,
+ Device::cuda((int)deviceId));
+ }
+
+ torch::Tensor forward(torch::Tensor x) {
+ checkModel();
+
+ std::cerr << "QuantizedGEMM forward" << std::endl;
+
+ x = x.contiguous();
+
+ Tensor result = net->forward(from_torch(x));
+
+ torch::Tensor output = to_torch(result);
+ Tensor::synchronizeDevice();
+
+ return output;
+ }
+
+ std::string dumpTensorBF16(Tensor x) {
+ std::stringstream ss;
+ for (int i = 0; i < 256; i++) {
+ ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__nv_bfloat16>()[i]));
+ }
+ ss << std::endl;
+ return ss.str();
+ }
+
+ std::string dumpTensorINT4(Tensor x) {
+ using spdlog::fmt_lib::format;
+
+ const int M = x.shape[0];
+ const int K = x.shape[1] * 2;
+
+ assert(x.dtype() == Tensor::INT8);
+
+ // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
+
+ constexpr int BLOCK_M = 256;
+ constexpr int WARP_K = 64;
+ constexpr int NUM_WARPS = 8;
+ constexpr int WARP_M_TILES = 2;
+ constexpr int WARP_SIZE = 32;
+
+ std::stringstream ss;
+ for (int bm = 0; bm < M / BLOCK_M; bm++) {
+ for (int bn = 0; bn < K / WARP_K; bn++) {
+ for (int warpId = 0; warpId < NUM_WARPS; warpId++) {
+ ss << format("[bm={},bn={},warp={}] ", bm, bn, warpId);
+ const int offset = ((bm * (K / WARP_K) + bn) * NUM_WARPS + warpId) * WARP_M_TILES * WARP_SIZE * 4;
+
+ for (int i = 0; i < 16; i++) {
+ assert(static_cast(offset + i) < x.numel() / 4);
+ uint32_t val = x.data_ptr()[offset + i];
+ ss << "{";
+ for (int j = 0; j < 8; j++) {
+ int i4val = (val >> (j * 4)) & 0xf;
+ if (i4val & 0x8) {
+ i4val = -((~i4val & 0x7) + 1);
+ }
+ ss << format("{} ", i4val);
+ }
+ ss << format("}} {:x} ", val);
+ }
+ ss << std::endl;
+ }
+ }
+ }
+
+ ss << std::endl;
+ return ss.str();
+ }
+
+ void quantize(torch::Tensor x, bool fuse_glu) {
+ checkModel();
+
+ spdlog::debug("QuantizedGEMM quantize");
+
+ x = x.contiguous();
+
+ auto qout = net->quantize(from_torch(x), fuse_glu);
+
+ Tensor act = qout.act.copy(Device::cpu());
+ Tensor ascales = qout.ascales.copy(Device::cpu());
+ Tensor lora_act = qout.lora_act.copy(Device::cpu());
+
+ Tensor::synchronizeDevice();
+
+ spdlog::debug("act = {}", dumpTensorINT4(act));
+ spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
+ }
+};
diff --git a/nunchaku/csrc/gemm88.h b/nunchaku/csrc/gemm88.h
new file mode 100644
index 0000000000000000000000000000000000000000..aa3dd3175f0a4e89c42c4ff4643a339d49b9e504
--- /dev/null
+++ b/nunchaku/csrc/gemm88.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include "interop/torch.h"
+#include "Serialization.h"
+#include "Linear.h"
+#include "debug.h"
+#include "module.h"
+
+class QuantizedGEMM88 : public ModuleWrapper {
+public:
+ void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
+ spdlog::info("Initializing QuantizedGEMM88");
+
+ size_t val = 0;
+ checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
+ checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
+ spdlog::debug("Stack={}", val);
+
+ net = std::make_unique(
+ (int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+ }
+
+ torch::Tensor forward(torch::Tensor x) {
+ checkModel();
+
+ std::cerr << "QuantizedGEMM88 forward" << std::endl;
+
+ x = x.contiguous();
+
+ Tensor result = net->forward(from_torch(x));
+
+ torch::Tensor output = to_torch(result);
+ Tensor::synchronizeDevice();
+
+ return output;
+ }
+};
diff --git a/nunchaku/csrc/module.h b/nunchaku/csrc/module.h
new file mode 100644
index 0000000000000000000000000000000000000000..812c82720203f3abf5eb5cd3cda2bf6fc9747bce
--- /dev/null
+++ b/nunchaku/csrc/module.h
@@ -0,0 +1,85 @@
+#pragma once
+
+#include "interop/torch.h"
+#include "Serialization.h"
+#include "Module.h"
+#include "debug.h"
+#include "utils.h"
+
+template
+class ModuleWrapper {
+public:
+ void init(int deviceId) {
+ this->deviceId = deviceId;
+ }
+ void reset() {
+ CUDADeviceContext ctx(this->deviceId);
+
+ debugContext.reset();
+ net.reset();
+ Tensor::synchronizeDevice();
+
+ nunchaku::utils::trim_memory();
+ Tensor::synchronizeDevice();
+ }
+
+ void load(std::string path, bool partial = false) {
+ checkModel();
+ CUDADeviceContext ctx(this->deviceId);
+
+ spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
+
+ std::shared_ptr provider = std::make_shared(path);
+ net->loadParams(*provider, partial);
+ Tensor::synchronizeDevice();
+
+ spdlog::info("Done.");
+ }
+
+ void loadDict(std::map dict, bool partial = false) {
+ checkModel();
+ CUDADeviceContext ctx(this->deviceId);
+
+ spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading");
+
+ std::shared_ptr provider = std::make_shared(std::move(dict));
+ net->loadParams(*provider, partial);
+ Tensor::synchronizeDevice();
+
+ spdlog::info("Done.");
+ }
+
+ void startDebug() {
+ debugContext = std::make_unique();
+ }
+ void stopDebug() {
+ debugContext.reset();
+ }
+
+ auto getDebugResults() {
+ CUDADeviceContext ctx(this->deviceId);
+
+ std::map result;
+
+ if (debugContext) {
+ for (auto &&[key, value] : debugContext->tensors) {
+ result[key] = to_torch(value);
+ }
+ }
+
+ return result;
+ }
+
+protected:
+ void checkModel() {
+ if (!net) {
+ throw std::runtime_error("Model not initialized");
+ }
+ }
+
+protected:
+ std::unique_ptr net;
+ std::unique_ptr debugContext;
+
+ int deviceId = -1;
+};
diff --git a/nunchaku/csrc/ops.h b/nunchaku/csrc/ops.h
new file mode 100644
index 0000000000000000000000000000000000000000..dbd15fb22ce89492e520c77775311e612c7e18ed
--- /dev/null
+++ b/nunchaku/csrc/ops.h
@@ -0,0 +1,173 @@
+#pragma once
+
+#include "interop/torch.h"
+#include "kernels/zgemm/zgemm.h"
+#include "kernels/awq/gemv_awq.h"
+#include "kernels/awq/gemm_awq.h"
+
+namespace nunchaku::ops {
+
+void gemm_w4a4(std::optional act, // packed act [M, K / 2]
+ std::optional wgt, // packed act [N, K / 2]
+ std::optional out, // linear [M, N]
+ std::optional qout, // packed act [M, N / 2]
+ std::optional ascales, // packed as [K / 64, M]
+ std::optional wscales, // packed ws [K / 64, N]
+ std::optional oscales, // packed as [N / 64, M]
+ std::optional poolout, // linear [M / PoolSize, N]
+ std::optional lora_act_in, // packed lora_act [M, R]
+ std::optional lora_up, // packed lora_wgt [N, R]
+ std::optional lora_down, // packed lora_wgt [N, R]
+ std::optional lora_act_out, // packed lora_act [M, R]
+ std::optional norm_q, // linear [HEAD_DIM]
+ std::optional norm_k, // linear [HEAD_DIM]
+ std::optional rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
+ std::optional bias, // packed ws [N]
+ std::optional smooth_factor, // packed ws [N], for quantization of the next layer
+ std::optional out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
+ std::optional out_linearattn, // linear [B, (M), N / 3]
+ bool act_unsigned,
+ std::vector lora_scales,
+ bool fuse_silu,
+ bool fp4,
+ float alpha,
+ std::optional wcscales,
+ std::optional out_q, // packed attention [B, H, M, D]
+ std::optional out_k, // packed attention [B, H, M, D]
+ std::optional out_v, // packed attention [B, H, M, D]
+ int attn_tokens) {
+ TorchOpContext ctx;
+ spdlog::trace("running gemm_w4a4: ");
+
+ auto getTensor = [](std::optional &t) {
+ Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
+ if (ret.valid()) {
+ spdlog::trace(" {}", ret.shape.str());
+ } else {
+ spdlog::trace(" ");
+ }
+ return ret;
+ };
+ nunchaku::kernels::gemm_w4a4(getTensor(act),
+ getTensor(wgt),
+ getTensor(out),
+ getTensor(qout),
+ getTensor(ascales),
+ getTensor(wscales),
+ getTensor(oscales),
+ getTensor(poolout),
+ getTensor(lora_act_in),
+ getTensor(lora_up),
+ getTensor(lora_down),
+ getTensor(lora_act_out),
+ getTensor(norm_q),
+ getTensor(norm_k),
+ getTensor(rotary_emb),
+ getTensor(bias),
+ getTensor(smooth_factor),
+ getTensor(out_vk),
+ getTensor(out_linearattn),
+ act_unsigned,
+ lora_scales,
+ fuse_silu,
+ fp4,
+ alpha,
+ getTensor(wcscales),
+ getTensor(out_q),
+ getTensor(out_k),
+ getTensor(out_v),
+ attn_tokens);
+ // Tensor::synchronizeDevice();
+}
+
+void quantize_w4a4_act_fuse_lora(std::optional input,
+ std::optional output,
+ std::optional oscales,
+ std::optional lora_down,
+ std::optional lora_act_out,
+ std::optional smooth,
+ bool fuse_glu,
+ bool fp4) {
+ TorchOpContext ctx;
+
+ spdlog::trace("running quantize_w4a4_act_fuse_lora: ");
+
+ auto getTensor = [](std::optional &t) {
+ Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
+ if (ret.valid()) {
+ spdlog::trace(" {}", ret.shape.str());
+ } else {
+ spdlog::trace(" ");
+ }
+ return ret;
+ };
+ nunchaku::kernels::quantize_w4a4_act_fuse_lora(getTensor(input),
+ getTensor(output),
+ getTensor(oscales),
+ getTensor(lora_down),
+ getTensor(lora_act_out),
+ getTensor(smooth),
+ fuse_glu,
+ fp4);
+}
+
+void attention_fp16(torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
+ torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
+ torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
+ torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
+ float scale) {
+ TorchOpContext ctx;
+ nunchaku::kernels::attention_fp16(from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale);
+}
+
+torch::Tensor gemv_awq(torch::Tensor _in_feats,
+ torch::Tensor _kernel,
+ torch::Tensor _scaling_factors,
+ torch::Tensor _zeros,
+ int64_t m,
+ int64_t n,
+ int64_t k,
+ int64_t group_size) {
+ TorchOpContext ctx;
+ Tensor result = ::gemv_awq(from_torch(_in_feats.contiguous()),
+ from_torch(_kernel.contiguous()),
+ from_torch(_scaling_factors.contiguous()),
+ from_torch(_zeros.contiguous()),
+ (int)m,
+ (int)n,
+ (int)k,
+ (int)group_size);
+
+ torch::Tensor output = to_torch(result);
+ // Tensor::synchronizeDevice();
+
+ return output;
+}
+
+torch::Tensor
+gemm_awq(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros) {
+ Tensor result = ::awq_gemm_forward_cuda(from_torch(_in_feats.contiguous()),
+ from_torch(_kernel.contiguous()),
+ from_torch(_scaling_factors.contiguous()),
+ from_torch(_zeros.contiguous()));
+
+ TorchOpContext ctx;
+ // TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
+ torch::Tensor output = to_torch(result);
+ // Tensor::synchronizeDevice();
+
+ return output;
+}
+
+void test_rmsnorm_rope(
+ torch::Tensor input, torch::Tensor output, torch::Tensor norm_q, torch::Tensor norm_k, torch::Tensor rotary_emb) {
+ nunchaku::kernels::test_rmsnorm_rope(
+ from_torch(input), from_torch(output), from_torch(norm_q), from_torch(norm_k), from_torch(rotary_emb));
+}
+
+void test_pack_qkv(torch::Tensor input, torch::Tensor out_q, torch::Tensor out_k, torch::Tensor out_v, int numTokens) {
+ nunchaku::kernels::test_pack_qkv(
+ from_torch(input), from_torch(out_q), from_torch(out_k), from_torch(out_v), numTokens);
+}
+
+}; // namespace nunchaku::ops
diff --git a/nunchaku/csrc/pybind.cpp b/nunchaku/csrc/pybind.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..74fc37f0f04aec50dc73072234d978e365de7d37
--- /dev/null
+++ b/nunchaku/csrc/pybind.cpp
@@ -0,0 +1,124 @@
+#include "gemm.h"
+#include "gemm88.h"
+#include "flux.h"
+#include "sana.h"
+#include "ops.h"
+#include "utils.h"
+#include
+#include "interop/torch.h"
+#include
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ py::class_(m, "QuantizedFluxModel")
+ .def(py::init<>())
+ .def("init",
+ &QuantizedFluxModel::init,
+ py::arg("use_fp4"),
+ py::arg("offload"),
+ py::arg("bf16"),
+ py::arg("deviceId"))
+ .def("set_residual_callback",
+ [](QuantizedFluxModel &self, pybind11::object call_back) {
+ if (call_back.is_none()) {
+ self.set_residual_callback(pybind11::function());
+ } else {
+ self.set_residual_callback(call_back);
+ }
+ })
+ .def("reset", &QuantizedFluxModel::reset)
+ .def("load", &QuantizedFluxModel::load, py::arg("path"), py::arg("partial") = false)
+ .def("loadDict", &QuantizedFluxModel::loadDict, py::arg("dict"), py::arg("partial") = false)
+ .def("forward",
+ &QuantizedFluxModel::forward,
+ py::arg("hidden_states"),
+ py::arg("encoder_hidden_states"),
+ py::arg("temb"),
+ py::arg("rotary_emb_img"),
+ py::arg("rotary_emb_context"),
+ py::arg("rotary_emb_single"),
+ py::arg("controlnet_block_samples") = py::none(),
+ py::arg("controlnet_single_block_samples") = py::none(),
+ py::arg("skip_first_layer") = false)
+ .def("forward_layer",
+ &QuantizedFluxModel::forward_layer,
+ py::arg("idx"),
+ py::arg("hidden_states"),
+ py::arg("encoder_hidden_states"),
+ py::arg("temb"),
+ py::arg("rotary_emb_img"),
+ py::arg("rotary_emb_context"),
+ py::arg("controlnet_block_samples") = py::none(),
+ py::arg("controlnet_single_block_samples") = py::none())
+ .def("forward_layer_ip_adapter",
+ &QuantizedFluxModel::forward_layer_ip_adapter,
+ py::arg("idx"),
+ py::arg("hidden_states"),
+ py::arg("encoder_hidden_states"),
+ py::arg("temb"),
+ py::arg("rotary_emb_img"),
+ py::arg("rotary_emb_context"),
+ py::arg("controlnet_block_samples") = py::none(),
+ py::arg("controlnet_single_block_samples") = py::none())
+ .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
+ .def("norm_one_forward", &QuantizedFluxModel::norm_one_forward)
+ .def("startDebug", &QuantizedFluxModel::startDebug)
+ .def("stopDebug", &QuantizedFluxModel::stopDebug)
+ .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
+ .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
+ .def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
+ .def("isBF16", &QuantizedFluxModel::isBF16);
+ py::class_(m, "QuantizedSanaModel")
+ .def(py::init<>())
+ .def("init",
+ &QuantizedSanaModel::init,
+ py::arg("config"),
+ py::arg("pag_layers"),
+ py::arg("use_fp4"),
+ py::arg("bf16"),
+ py::arg("deviceId"))
+ .def("reset", &QuantizedSanaModel::reset)
+ .def("load", &QuantizedSanaModel::load, py::arg("path"), py::arg("partial") = false)
+ .def("loadDict", &QuantizedSanaModel::loadDict, py::arg("dict"), py::arg("partial") = false)
+ .def("forward", &QuantizedSanaModel::forward)
+ .def("forward_layer", &QuantizedSanaModel::forward_layer)
+ .def("startDebug", &QuantizedSanaModel::startDebug)
+ .def("stopDebug", &QuantizedSanaModel::stopDebug)
+ .def("getDebugResults", &QuantizedSanaModel::getDebugResults);
+ py::class_(m, "QuantizedGEMM")
+ .def(py::init<>())
+ .def("init", &QuantizedGEMM::init)
+ .def("reset", &QuantizedGEMM::reset)
+ .def("load", &QuantizedGEMM::load)
+ .def("forward", &QuantizedGEMM::forward)
+ .def("quantize", &QuantizedGEMM::quantize)
+ .def("startDebug", &QuantizedGEMM::startDebug)
+ .def("stopDebug", &QuantizedGEMM::stopDebug)
+ .def("getDebugResults", &QuantizedGEMM::getDebugResults);
+ py::class_(m, "Tensor");
+ py::class_(m, "QuantizedGEMM88")
+ .def(py::init<>())
+ .def("init", &QuantizedGEMM88::init)
+ .def("reset", &QuantizedGEMM88::reset)
+ .def("load", &QuantizedGEMM88::load)
+ .def("forward", &QuantizedGEMM88::forward)
+ .def("startDebug", &QuantizedGEMM88::startDebug)
+ .def("stopDebug", &QuantizedGEMM88::stopDebug)
+ .def("getDebugResults", &QuantizedGEMM88::getDebugResults);
+
+ m.def_submodule("ops")
+ .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
+ .def("quantize_w4a4_act_fuse_lora", nunchaku::ops::quantize_w4a4_act_fuse_lora)
+ .def("attention_fp16", nunchaku::ops::attention_fp16)
+ .def("gemm_awq", nunchaku::ops::gemm_awq)
+ .def("gemv_awq", nunchaku::ops::gemv_awq)
+
+ .def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
+ .def("test_pack_qkv", nunchaku::ops::test_pack_qkv);
+
+ m.def_submodule("utils")
+ .def("set_log_level", [](const std::string &level) { spdlog::set_level(spdlog::level::from_str(level)); })
+ .def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
+ .def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
+ .def("trim_memory", nunchaku::utils::trim_memory)
+ .def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode);
+}
diff --git a/nunchaku/csrc/sana.h b/nunchaku/csrc/sana.h
new file mode 100644
index 0000000000000000000000000000000000000000..8d6c75ffba160cc6125772a5b583b99e7c7de541
--- /dev/null
+++ b/nunchaku/csrc/sana.h
@@ -0,0 +1,102 @@
+#pragma once
+
+#include "interop/torch.h"
+#include "SanaModel.h"
+#include "Serialization.h"
+#include "debug.h"
+#include "module.h"
+
+class QuantizedSanaModel : public ModuleWrapper {
+public:
+ void init(pybind11::dict config, std::vector pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
+ spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);
+ SanaConfig cfg{
+ .num_layers = config["num_layers"].cast(),
+ .num_attention_heads = config["num_attention_heads"].cast(),
+ .attention_head_dim = config["attention_head_dim"].cast(),
+ .num_cross_attention_heads = config["num_cross_attention_heads"].cast(),
+ .expand_ratio = config["mlp_ratio"].cast(),
+ .pag_layers = pag_layers,
+ .use_fp4 = use_fp4,
+ };
+
+ ModuleWrapper::init(deviceId);
+ CUDADeviceContext ctx(this->deviceId);
+ net = std::make_unique(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+ }
+
+ torch::Tensor forward(torch::Tensor hidden_states,
+ torch::Tensor encoder_hidden_states,
+ torch::Tensor timestep,
+ torch::Tensor cu_seqlens_img,
+ torch::Tensor cu_seqlens_txt,
+ int H,
+ int W,
+ bool pag,
+ bool cfg,
+ bool skip_first_layer = false) {
+ checkModel();
+ CUDADeviceContext ctx(deviceId);
+
+ spdlog::debug("QuantizedSanaModel forward");
+
+ hidden_states = hidden_states.contiguous();
+ encoder_hidden_states = encoder_hidden_states.contiguous();
+ timestep = timestep.contiguous();
+ cu_seqlens_img = cu_seqlens_img.contiguous();
+ cu_seqlens_txt = cu_seqlens_txt.contiguous();
+
+ Tensor result = net->forward(from_torch(hidden_states),
+ from_torch(encoder_hidden_states),
+ from_torch(timestep),
+ from_torch(cu_seqlens_img),
+ from_torch(cu_seqlens_txt),
+ H,
+ W,
+ pag,
+ cfg,
+ skip_first_layer);
+
+ torch::Tensor output = to_torch(result);
+ // Tensor::synchronizeDevice();
+
+ return output;
+ }
+
+ torch::Tensor forward_layer(int64_t idx,
+ torch::Tensor hidden_states,
+ torch::Tensor encoder_hidden_states,
+ torch::Tensor timestep,
+ torch::Tensor cu_seqlens_img,
+ torch::Tensor cu_seqlens_txt,
+ int H,
+ int W,
+ bool pag,
+ bool cfg) {
+ checkModel();
+ CUDADeviceContext ctx(deviceId);
+
+ spdlog::debug("QuantizedSanaModel forward_layer {}", idx);
+
+ hidden_states = hidden_states.contiguous();
+ encoder_hidden_states = encoder_hidden_states.contiguous();
+ timestep = timestep.contiguous();
+ cu_seqlens_img = cu_seqlens_img.contiguous();
+ cu_seqlens_txt = cu_seqlens_txt.contiguous();
+
+ Tensor result = net->transformer_blocks.at(idx)->forward(from_torch(hidden_states),
+ from_torch(encoder_hidden_states),
+ from_torch(timestep),
+ from_torch(cu_seqlens_img),
+ from_torch(cu_seqlens_txt),
+ H,
+ W,
+ pag,
+ cfg);
+
+ torch::Tensor output = to_torch(result);
+ // Tensor::synchronizeDevice();
+
+ return output;
+ }
+};
diff --git a/nunchaku/csrc/utils.h b/nunchaku/csrc/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..967f87772901aec5f66d4fc0860f8e642e51e60b
--- /dev/null
+++ b/nunchaku/csrc/utils.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include "common.h"
+#include "Tensor.h"
+#include "kernels/zgemm/zgemm.h"
+
+namespace nunchaku::utils {
+
+void set_cuda_stack_limit(int64_t newval) {
+ size_t val = 0;
+ checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
+ checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
+ spdlog::debug("Stack={}", val);
+}
+
+void disable_memory_auto_release() {
+ int device;
+ checkCUDA(cudaGetDevice(&device));
+ cudaMemPool_t mempool;
+ checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
+ uint64_t threshold = UINT64_MAX;
+ checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
+}
+
+void trim_memory() {
+ int device;
+ checkCUDA(cudaGetDevice(&device));
+ cudaMemPool_t mempool;
+ checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
+ size_t bytesToKeep = 0;
+ checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
+}
+
+void set_faster_i2f_mode(std::string mode) {
+ spdlog::info("Set fasteri2f mode to {}", mode);
+ kernels::set_faster_i2f_mode(mode);
+}
+
+}; // namespace nunchaku::utils
diff --git a/nunchaku/lora/__init__.py b/nunchaku/lora/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f037000fbd143a8ad81aeb825a4667002b07d14a
--- /dev/null
+++ b/nunchaku/lora/__init__.py
@@ -0,0 +1 @@
+# LoRA utilities for FLUX models
diff --git a/nunchaku/lora/flux/__init__.py b/nunchaku/lora/flux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d49b5f570dcbe27f81f934f7fb73f5dd7c5456c
--- /dev/null
+++ b/nunchaku/lora/flux/__init__.py
@@ -0,0 +1,5 @@
+from .diffusers_converter import to_diffusers
+from .nunchaku_converter import convert_to_nunchaku_flux_lowrank_dict, to_nunchaku
+from .utils import is_nunchaku_format
+
+__all__ = ["to_diffusers", "to_nunchaku", "convert_to_nunchaku_flux_lowrank_dict", "is_nunchaku_format"]
diff --git a/nunchaku/lora/flux/compose.py b/nunchaku/lora/flux/compose.py
new file mode 100644
index 0000000000000000000000000000000000000000..a44a8aff4726c026f0dc2859c06a054453539c3c
--- /dev/null
+++ b/nunchaku/lora/flux/compose.py
@@ -0,0 +1,218 @@
+"""
+Compose multiple LoRA weights into a single LoRA for FLUX models.
+
+This script merges several LoRA safetensors files into one, applying individual strength values to each.
+
+**Example Usage:**
+
+.. code-block:: bash
+
+ python -m nunchaku.lora.flux.compose \\
+ -i lora1.safetensors lora2.safetensors \\
+ -s 0.8 1.0 \\
+ -o composed_lora.safetensors
+
+**Arguments:**
+
+- ``-i``, ``--input-paths``: Input LoRA safetensors files (one or more).
+- ``-s``, ``--strengths``: Strength value for each LoRA (must match number of inputs).
+- ``-o``, ``--output-path``: Output path for the composed LoRA safetensors file.
+
+This will merge ``lora1.safetensors`` (strength 0.8) and ``lora2.safetensors`` (strength 1.0) into ``composed_lora.safetensors``.
+
+**Main Function**
+
+:func:`compose_lora`
+"""
+
+import argparse
+import os
+
+import torch
+import torch.nn.functional as F
+from safetensors.torch import save_file
+
+from .diffusers_converter import to_diffusers
+from .utils import is_nunchaku_format, load_state_dict_in_safetensors
+
+
+def compose_lora(
+ loras: list[tuple[str | dict[str, torch.Tensor], float]], output_path: str | None = None
+) -> dict[str, torch.Tensor]:
+ """
+ Compose multiple LoRA weights into a single LoRA representation.
+
+ Parameters
+ ----------
+ loras : list of (str or dict[str, torch.Tensor], float)
+ Each tuple contains:
+ - Path to a LoRA safetensors file or a LoRA weights dictionary.
+ - Strength/scale factor for that LoRA.
+ output_path : str, optional
+ Path to save the composed LoRA weights as a safetensors file. If None, does not save.
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ The composed LoRA weights.
+
+ Raises
+ ------
+ AssertionError
+ If LoRA weights are in Nunchaku format (must be converted to Diffusers format first)
+ or if tensor shapes are incompatible.
+
+ Notes
+ -----
+ - Converts all input LoRAs to Diffusers format.
+ - Handles QKV projection fusion for attention layers.
+ - Applies strength scaling to LoRA weights.
+ - Concatenates multiple LoRAs along appropriate dimensions.
+ - Handles normalization layers, bias vectors, and FLUX.1-tools LoRA compatibility.
+
+ Examples
+ --------
+ >>> lora_paths = [("lora1.safetensors", 0.8), ("lora2.safetensors", 0.6)]
+ >>> composed = compose_lora(lora_paths, "composed_lora.safetensors")
+ >>> lora_dicts = [({"layer.weight": torch.randn(10, 20)}, 1.0)]
+ >>> composed = compose_lora(lora_dicts)
+ """
+ if len(loras) == 1:
+ if is_nunchaku_format(loras[0][0]) and (loras[0][1] - 1) < 1e-5:
+ if isinstance(loras[0][0], str):
+ return load_state_dict_in_safetensors(loras[0][0], device="cpu")
+ else:
+ return loras[0][0]
+
+ composed = {}
+ for lora, strength in loras:
+ assert not is_nunchaku_format(lora)
+ lora = to_diffusers(lora)
+ for k, v in list(lora.items()):
+ if v.ndim == 1:
+ previous_tensor = composed.get(k, None)
+ if previous_tensor is None:
+ if "norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k:
+ composed[k] = v
+ else:
+ composed[k] = v * strength
+ else:
+ assert not ("norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k)
+ composed[k] = previous_tensor + v * strength
+ else:
+ assert v.ndim == 2
+ if ".to_q." in k or ".add_q_proj." in k: # qkv must all exist
+ if "lora_B" in k:
+ continue
+
+ q_a = v
+ k_a = lora[k.replace(".to_q.", ".to_k.").replace(".add_q_proj.", ".add_k_proj.")]
+ v_a = lora[k.replace(".to_q.", ".to_v.").replace(".add_q_proj.", ".add_v_proj.")]
+
+ q_b = lora[k.replace("lora_A", "lora_B")]
+ k_b = lora[
+ k.replace("lora_A", "lora_B")
+ .replace(".to_q.", ".to_k.")
+ .replace(".add_q_proj.", ".add_k_proj.")
+ ]
+ v_b = lora[
+ k.replace("lora_A", "lora_B")
+ .replace(".to_q.", ".to_v.")
+ .replace(".add_q_proj.", ".add_v_proj.")
+ ]
+
+ # Add paddings if their ranks are different
+ max_rank = max(q_a.shape[0], k_a.shape[0], v_a.shape[0])
+ q_a = F.pad(q_a, (0, 0, 0, max_rank - q_a.shape[0]))
+ k_a = F.pad(k_a, (0, 0, 0, max_rank - k_a.shape[0]))
+ v_a = F.pad(v_a, (0, 0, 0, max_rank - v_a.shape[0]))
+ q_b = F.pad(q_b, (0, max_rank - q_b.shape[1]))
+ k_b = F.pad(k_b, (0, max_rank - k_b.shape[1]))
+ v_b = F.pad(v_b, (0, max_rank - v_b.shape[1]))
+
+ if torch.isclose(q_a, k_a).all() and torch.isclose(q_a, v_a).all():
+ lora_a = q_a
+ lora_b = torch.cat((q_b, k_b, v_b), dim=0)
+ else:
+ lora_a_group = (q_a, k_a, v_a)
+ new_shape_a = [sum([_.shape[0] for _ in lora_a_group]), q_a.shape[1]]
+ lora_a = torch.zeros(new_shape_a, dtype=q_a.dtype, device=q_a.device)
+ start_dim = 0
+ for tensor in lora_a_group:
+ lora_a[start_dim : start_dim + tensor.shape[0]] = tensor
+ start_dim += tensor.shape[0]
+
+ lora_b_group = (q_b, k_b, v_b)
+ new_shape_b = [sum([_.shape[0] for _ in lora_b_group]), sum([_.shape[1] for _ in lora_b_group])]
+ lora_b = torch.zeros(new_shape_b, dtype=q_b.dtype, device=q_b.device)
+ start_dims = (0, 0)
+ for tensor in lora_b_group:
+ end_dims = (start_dims[0] + tensor.shape[0], start_dims[1] + tensor.shape[1])
+ lora_b[start_dims[0] : end_dims[0], start_dims[1] : end_dims[1]] = tensor
+ start_dims = end_dims
+
+ lora_a = lora_a * strength
+
+ new_k_a = k.replace(".to_q.", ".to_qkv.").replace(".add_q_proj.", ".add_qkv_proj.")
+ new_k_b = new_k_a.replace("lora_A", "lora_B")
+
+ for kk, vv, dim in ((new_k_a, lora_a, 0), (new_k_b, lora_b, 1)):
+ previous_lora = composed.get(kk, None)
+ composed[kk] = vv if previous_lora is None else torch.cat([previous_lora, vv], dim=dim)
+
+ elif ".to_k." in k or ".to_v." in k or ".add_k_proj." in k or ".add_v_proj." in k:
+ continue
+ else:
+ if "lora_A" in k:
+ v = v * strength
+
+ previous_lora = composed.get(k, None)
+ if previous_lora is None:
+ composed[k] = v
+ else:
+ if "lora_A" in k:
+ if previous_lora.shape[1] != v.shape[1]: # flux.1-tools LoRA compatibility
+ assert "x_embedder" in k
+ expanded_size = max(previous_lora.shape[1], v.shape[1])
+ if expanded_size > previous_lora.shape[1]:
+ expanded_previous_lora = torch.zeros(
+ (previous_lora.shape[0], expanded_size),
+ device=previous_lora.device,
+ dtype=previous_lora.dtype,
+ )
+ expanded_previous_lora[:, : previous_lora.shape[1]] = previous_lora
+ else:
+ expanded_previous_lora = previous_lora
+ if expanded_size > v.shape[1]:
+ expanded_v = torch.zeros(
+ (v.shape[0], expanded_size), device=v.device, dtype=v.dtype
+ )
+ expanded_v[:, : v.shape[1]] = v
+ else:
+ expanded_v = v
+ composed[k] = torch.cat([expanded_previous_lora, expanded_v], dim=0)
+ else:
+ composed[k] = torch.cat([previous_lora, v], dim=0)
+ else:
+ composed[k] = torch.cat([previous_lora, v], dim=1)
+
+ composed[k] = (
+ v if previous_lora is None else torch.cat([previous_lora, v], dim=0 if "lora_A" in k else 1)
+ )
+ if output_path is not None:
+ output_dir = os.path.dirname(os.path.abspath(output_path))
+ os.makedirs(output_dir, exist_ok=True)
+ save_file(composed, output_path)
+ return composed
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-i", "--input-paths", type=str, nargs="*", required=True, help="paths to the lora safetensors files"
+ )
+ parser.add_argument("-s", "--strengths", type=float, nargs="*", required=True, help="strengths for each lora")
+ parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output safetensors file")
+ args = parser.parse_args()
+ assert len(args.input_paths) == len(args.strengths)
+ compose_lora(list(zip(args.input_paths, args.strengths)), args.output_path)
diff --git a/nunchaku/lora/flux/convert.py b/nunchaku/lora/flux/convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff11f40accb842f67197469a8c7cc8be0f3e102e
--- /dev/null
+++ b/nunchaku/lora/flux/convert.py
@@ -0,0 +1,74 @@
+"""
+CLI tool to convert LoRA weights to Nunchaku format.
+
+**Example Usage:**
+
+.. code-block:: bash
+
+ python -m nunchaku.lora.flux.convert \\
+ --lora-path composed_lora.safetensors \\
+ --quant-path mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors \\
+ --output-root ./converted \\
+ --dtype bfloat16
+
+**Arguments:**
+
+- ``--lora-path``: Path to the LoRA weights safetensor file (required)
+- ``--quant-path``: Path to the quantized model safetensor file (default: ``mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors``)
+- ``--output-root``: Root directory for the output safetensor file (default: parent directory of the lora file)
+- ``--lora-name``: Name of the LoRA weights (optional, auto-generated if not provided)
+- ``--dtype``: Data type of the converted weights, either ``bfloat16`` or ``float16`` (default: ``bfloat16``)
+
+**Main Function**
+
+:func:`nunchaku.lora.flux.nunchaku_converter.to_nunchaku`
+"""
+
+import argparse
+import os
+
+from .nunchaku_converter import to_nunchaku
+from .utils import is_nunchaku_format
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--quant-path",
+ type=str,
+ help="Path to the quantized model safetensors file.",
+ default="mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors",
+ )
+ parser.add_argument("--lora-path", type=str, required=True, help="Path to LoRA weights safetensors file.")
+ parser.add_argument("--output-root", type=str, default="", help="Root directory for output safetensors file.")
+ parser.add_argument("--lora-name", type=str, default=None, help="Name for the output LoRA weights.")
+ parser.add_argument(
+ "--dtype",
+ type=str,
+ default="bfloat16",
+ choices=["bfloat16", "float16"],
+ help="Data type of the converted weights.",
+ )
+ args = parser.parse_args()
+
+ if is_nunchaku_format(args.lora_path):
+ print("Already in Nunchaku format, no conversion needed.")
+ exit(0)
+
+ if not args.output_root:
+ args.output_root = os.path.dirname(args.lora_path)
+ if args.lora_name is None:
+ base_name = os.path.basename(args.lora_path)
+ lora_name = base_name.rsplit(".", 1)[0]
+ precision = "fp4" if "fp4" in args.quant_path else "int4"
+ lora_name = f"svdq-{precision}-{lora_name}"
+ print(f"LoRA name not provided, using {lora_name} as the LoRA name")
+ else:
+ lora_name = args.lora_name
+ assert lora_name, "LoRA name must be provided."
+
+ to_nunchaku(
+ args.lora_path,
+ args.quant_path,
+ dtype=args.dtype,
+ output_path=os.path.join(args.output_root, f"{lora_name}.safetensors"),
+ )
diff --git a/nunchaku/lora/flux/diffusers_converter.py b/nunchaku/lora/flux/diffusers_converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0c30134ef718ef79509d1d2d22182963cd97651
--- /dev/null
+++ b/nunchaku/lora/flux/diffusers_converter.py
@@ -0,0 +1,220 @@
+"""
+This module implements the functions to convert FLUX LoRA weights from various formats
+to the Diffusers format, which will later be converted to Nunchaku format.
+"""
+
+import argparse
+import logging
+import os
+
+import torch
+from diffusers.loaders import FluxLoraLoaderMixin
+from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
+from safetensors.torch import save_file
+
+from ...utils import load_state_dict_in_safetensors
+
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+def handle_kohya_lora(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Convert Kohya LoRA format keys to Diffusers format.
+
+ Parameters
+ ----------
+ state_dict : dict[str, torch.Tensor]
+ LoRA weights, possibly in Kohya format.
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ LoRA weights in Diffusers format.
+ """
+ # first check if the state_dict is in the kohya format
+ # like: https://civitai.com/models/1118358?modelVersionId=1256866
+ if any([not k.startswith("lora_transformer_") for k in state_dict.keys()]):
+ return state_dict
+ else:
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ new_k = k.replace("lora_transformer_", "transformer.")
+
+ new_k = new_k.replace("norm_out_", "norm_out.")
+
+ new_k = new_k.replace("time_text_embed_", "time_text_embed.")
+ new_k = new_k.replace("guidance_embedder_", "guidance_embedder.")
+ new_k = new_k.replace("text_embedder_", "text_embedder.")
+ new_k = new_k.replace("timestep_embedder_", "timestep_embedder.")
+
+ new_k = new_k.replace("single_transformer_blocks_", "single_transformer_blocks.")
+ new_k = new_k.replace("_attn_", ".attn.")
+ new_k = new_k.replace("_norm_linear.", ".norm.linear.")
+ new_k = new_k.replace("_proj_mlp.", ".proj_mlp.")
+ new_k = new_k.replace("_proj_out.", ".proj_out.")
+
+ new_k = new_k.replace("transformer_blocks_", "transformer_blocks.")
+ new_k = new_k.replace("to_out_0.", "to_out.0.")
+ new_k = new_k.replace("_ff_context_net_0_proj.", ".ff_context.net.0.proj.")
+ new_k = new_k.replace("_ff_context_net_2.", ".ff_context.net.2.")
+ new_k = new_k.replace("_ff_net_0_proj.", ".ff.net.0.proj.")
+ new_k = new_k.replace("_ff_net_2.", ".ff.net.2.")
+ new_k = new_k.replace("_norm1_context_linear.", ".norm1_context.linear.")
+ new_k = new_k.replace("_norm1_linear.", ".norm1.linear.")
+
+ new_k = new_k.replace(".lora_down.", ".lora_A.")
+ new_k = new_k.replace(".lora_up.", ".lora_B.")
+
+ new_state_dict[new_k] = v
+ return new_state_dict
+
+
+def convert_peft_to_comfyui(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Convert PEFT format (base_model.model.*) to ComfyUI format (lora_unet_*).
+
+ Mapping rules:
+ - base_model.model.double_blocks.X.img_attn.proj → lora_unet_double_blocks_X_img_attn_proj
+ - base_model.model.single_blocks.X.linear1 → lora_unet_single_blocks_X_linear1
+ - base_model.model.final_layer.linear → lora_unet_final_layer_linear
+ - lora_A/lora_B → lora_down/lora_up
+
+ Parameters
+ ----------
+ state_dict : dict[str, torch.Tensor]
+ LoRA weights in PEFT format
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ LoRA weights in ComfyUI format
+ """
+ converted_dict = {}
+
+ for key, value in state_dict.items():
+ new_key = key
+
+ if key.startswith("base_model.model."):
+ # Remove base_model.model. prefix
+ new_key = key.replace("base_model.model.", "")
+
+ # Convert to ComfyUI format with underscores
+ # Handle double_blocks
+ if "double_blocks" in new_key:
+ # Replace dots with underscores within the block structure
+ # e.g., double_blocks.0.img_attn.proj → double_blocks_0_img_attn_proj
+ new_key = new_key.replace("double_blocks.", "lora_unet_double_blocks_")
+ # Replace remaining dots with underscores
+ new_key = new_key.replace(".", "_")
+
+ # Handle single_blocks
+ elif "single_blocks" in new_key:
+ new_key = new_key.replace("single_blocks.", "lora_unet_single_blocks_")
+ # Special handling for modulation.lin → modulation_lin
+ new_key = new_key.replace("modulation.lin", "modulation_lin")
+ # Replace remaining dots with underscores
+ new_key = new_key.replace(".", "_")
+
+ # Handle final_layer
+ elif "final_layer" in new_key:
+ new_key = new_key.replace("final_layer.linear", "lora_unet_final_layer_linear")
+ # Replace remaining dots with underscores
+ new_key = new_key.replace(".", "_")
+
+ else:
+ # For any other keys, add lora_unet_ prefix and replace dots
+ new_key = "lora_unet_" + new_key.replace(".", "_")
+
+ # Convert lora_A/lora_B to lora_down/lora_up
+ new_key = new_key.replace("_lora_A_weight", ".lora_down.weight")
+ new_key = new_key.replace("_lora_B_weight", ".lora_up.weight")
+
+ converted_dict[new_key] = value
+
+ if key != new_key:
+ logger.debug(f"Converted: {key} → {new_key}")
+
+ return converted_dict
+
+
+def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | None = None) -> dict[str, torch.Tensor]:
+ """
+ Convert LoRA weights to Diffusers format, which will later be converted to Nunchaku format.
+
+ Parameters
+ ----------
+ input_lora : str or dict[str, torch.Tensor]
+ Path to a safetensors file or a LoRA weight dictionary.
+ output_path : str, optional
+ If given, save the converted weights to this path.
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ LoRA weights in Diffusers format.
+ """
+ if isinstance(input_lora, str):
+ tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
+ else:
+ tensors = {k: v for k, v in input_lora.items()}
+
+ tensors = handle_kohya_lora(tensors)
+
+ # Convert FP8 tensors to BF16
+ for k, v in tensors.items():
+ if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
+ tensors[k] = v.to(torch.bfloat16)
+
+ # Apply Kontext-specific key conversion for both PEFT format and ComfyUI format
+ # This handles LoRAs with base_model.model.* prefix or lora_unet_* prefix (including final_layer_linear)
+ if any(k.startswith("base_model.model.") for k in tensors.keys()):
+ logger.info("Converting PEFT format to ComfyUI format")
+ return convert_peft_to_comfyui(tensors)
+
+ # Handle LoRAs that only have final_layer_linear without adaLN_modulation
+ # This is a workaround for incomplete final layer LoRAs
+ final_keys = [k for k in tensors.keys() if "final_layer" in k]
+ has_linear = any("final_layer_linear" in k for k in final_keys)
+ has_adaln = any("final_layer_adaLN_modulation" in k for k in final_keys)
+
+ if has_linear and not has_adaln:
+ for key in list(tensors.keys()):
+ if "final_layer_linear" in key:
+ adaln_key = key.replace("final_layer_linear", "final_layer_adaLN_modulation_1")
+ if adaln_key not in tensors:
+ tensors[adaln_key] = torch.zeros_like(tensors[key])
+
+ new_tensors, alphas = FluxLoraLoaderMixin.lora_state_dict(tensors, return_alphas=True)
+ new_tensors = convert_unet_state_dict_to_peft(new_tensors)
+
+ if alphas is not None and len(alphas) > 0:
+ for k, v in alphas.items():
+ key_A = k.replace(".alpha", ".lora_A.weight")
+ key_B = k.replace(".alpha", ".lora_B.weight")
+ assert key_A in new_tensors, f"Key {key_A} not found in new tensors."
+ assert key_B in new_tensors, f"Key {key_B} not found in new tensors."
+ rank = new_tensors[key_A].shape[0]
+ assert new_tensors[key_B].shape[1] == rank, f"Rank mismatch for {key_B}."
+ new_tensors[key_A] = new_tensors[key_A] * v / rank
+
+ if output_path is not None:
+ output_dir = os.path.dirname(os.path.abspath(output_path))
+ os.makedirs(output_dir, exist_ok=True)
+ save_file(new_tensors, output_path)
+
+ return new_tensors
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensors file")
+ parser.add_argument(
+ "-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensors file"
+ )
+ args = parser.parse_args()
+ to_diffusers(args.input_path, args.output_path)
diff --git a/nunchaku/lora/flux/nunchaku_converter.py b/nunchaku/lora/flux/nunchaku_converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..b10309c3dee19ba10192a0d3c237bc7a40858cb0
--- /dev/null
+++ b/nunchaku/lora/flux/nunchaku_converter.py
@@ -0,0 +1,949 @@
+"""
+Nunchaku LoRA format converter for Flux models.
+
+This module provides utilities to convert LoRA weights from Diffusers format
+to Nunchaku format for efficient quantized inference in Flux models.
+
+Key functions
+-------------
+- :func:`to_nunchaku` : Main conversion entry point
+- :func:`fuse_vectors` : Vector fusion for bias terms
+"""
+
+import logging
+import os
+
+import torch
+from safetensors.torch import save_file
+from tqdm import tqdm
+
+from ...utils import filter_state_dict, load_state_dict_in_safetensors
+from .diffusers_converter import to_diffusers
+from .packer import NunchakuWeightPacker
+from .utils import is_nunchaku_format, pad
+
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+# region utilities
+
+
+def update_state_dict(
+ lhs: dict[str, torch.Tensor], rhs: dict[str, torch.Tensor], prefix: str = ""
+) -> dict[str, torch.Tensor]:
+ """
+ Update a state dictionary with values from another, optionally adding a prefix to keys.
+
+ Parameters
+ ----------
+ lhs : dict[str, torch.Tensor]
+ Target state dictionary.
+ rhs : dict[str, torch.Tensor]
+ Source state dictionary.
+ prefix : str, optional
+ Prefix to add to keys from rhs.
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ Updated state dictionary.
+
+ Raises
+ ------
+ AssertionError
+ If any key already exists in the target dictionary.
+ """
+ for rkey, value in rhs.items():
+ lkey = f"{prefix}.{rkey}" if prefix else rkey
+ assert lkey not in lhs, f"Key {lkey} already exists in the state dict."
+ lhs[lkey] = value
+ return lhs
+
+
+# endregion
+
+
+def pack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
+ """
+ Pack the low-rank weight tensor for W4A4 linear layers.
+
+ Parameters
+ ----------
+ weight : torch.Tensor
+ Low-rank weight tensor.
+ down : bool
+ If True, pack as down-projection; else as up-projection.
+
+ Returns
+ -------
+ torch.Tensor
+ Packed weight tensor.
+ """
+ assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
+ lane_n, lane_k = 1, 2 # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
+ n_pack_size, k_pack_size = 2, 2
+ num_n_lanes, num_k_lanes = 8, 4
+ frag_n = n_pack_size * num_n_lanes * lane_n
+ frag_k = k_pack_size * num_k_lanes * lane_k
+ weight = pad(weight, divisor=(frag_n, frag_k), dim=(0, 1))
+ if down:
+ r, c = weight.shape
+ r_frags, c_frags = r // frag_n, c // frag_k
+ weight = weight.view(r_frags, frag_n, c_frags, frag_k).permute(2, 0, 1, 3)
+ else:
+ c, r = weight.shape
+ c_frags, r_frags = c // frag_n, r // frag_k
+ weight = weight.view(c_frags, frag_n, r_frags, frag_k).permute(0, 2, 1, 3)
+ weight = weight.reshape(c_frags, r_frags, n_pack_size, num_n_lanes, k_pack_size, num_k_lanes, lane_k)
+ weight = weight.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
+ return weight.view(c, r)
+
+
+def unpack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
+ """
+ Unpack the low-rank weight tensor from W4A4 linear layers.
+
+ Parameters
+ ----------
+ weight : torch.Tensor
+ Packed low-rank weight tensor.
+ down : bool
+ If True, unpack as down-projection; else as up-projection.
+
+ Returns
+ -------
+ torch.Tensor
+ Unpacked weight tensor.
+ """
+ c, r = weight.shape
+ assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
+ lane_n, lane_k = 1, 2 # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
+ n_pack_size, k_pack_size = 2, 2
+ num_n_lanes, num_k_lanes = 8, 4
+ frag_n = n_pack_size * num_n_lanes * lane_n
+ frag_k = k_pack_size * num_k_lanes * lane_k
+ if down:
+ r_frags, c_frags = r // frag_n, c // frag_k
+ else:
+ c_frags, r_frags = c // frag_n, r // frag_k
+ weight = weight.view(c_frags, r_frags, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, lane_k)
+ weight = weight.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
+ weight = weight.view(c_frags, r_frags, frag_n, frag_k)
+ if down:
+ weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
+ else:
+ weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
+ return weight
+
+
+def reorder_adanorm_lora_up(lora_up: torch.Tensor, splits: int) -> torch.Tensor:
+ """
+ Reorder AdaNorm LoRA up-projection tensor for correct shape.
+
+ Parameters
+ ----------
+ lora_up : torch.Tensor
+ LoRA up-projection tensor.
+ splits : int
+ Number of splits for AdaNorm.
+
+ Returns
+ -------
+ torch.Tensor
+ Reordered tensor.
+ """
+ c, r = lora_up.shape
+ assert c % splits == 0
+ return lora_up.view(splits, c // splits, r).transpose(0, 1).reshape(c, r).contiguous()
+
+
+def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
+ orig_state_dict: dict[str, torch.Tensor],
+ extra_lora_dict: dict[str, torch.Tensor],
+ converted_block_name: str,
+ candidate_block_name: str,
+ local_name_map: dict[str, str | list[str]],
+ convert_map: dict[str, str],
+ default_dtype: torch.dtype = torch.bfloat16,
+) -> dict[str, torch.Tensor]:
+ """
+ Convert LoRA weights for a transformer block from Diffusers to Nunchaku format.
+
+ Merges and converts LoRA weights from the original SVDQuant low-rank branch and an extra LoRA dict
+ for a given transformer block, producing a Nunchaku-compatible dictionary. Handles both fused and
+ unfused LoRA branches (e.g., qkv), and merges multiple LoRA branches as needed.
+
+ Parameters
+ ----------
+ orig_state_dict : dict[str, torch.Tensor]
+ Original state dict with LoRA weights, keys like ``"{block}.{local}.lora_down"`` and ``"{block}.{local}.lora_up"``.
+ extra_lora_dict : dict[str, torch.Tensor]
+ Extra LoRA weights, keys like ``"{block}.{local}.lora_A.weight"`` and ``"{block}.{local}.lora_B.weight"``.
+ converted_block_name : str
+ Block name for output (e.g., ``"transformer_blocks.0"``).
+ candidate_block_name : str
+ Block name for input lookup (e.g., ``"blocks.0"``).
+ local_name_map : dict[str, str | list[str]]
+ Maps output local names (e.g., ``"attn.qkv"``) to one or more input local names.
+ convert_map : dict[str, str]
+ Maps output local names to conversion types: ``"adanorm_single"``, ``"adanorm_zero"``, or ``"linear"``.
+ default_dtype : torch.dtype, optional
+ Output tensor dtype (default: ``torch.bfloat16``).
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ A dictionary containing the converted LoRA weights in Nunchaku format.
+
+ Notes
+ -----
+ - If both original and extra LoRA weights are present, they are merged by concatenation.
+ - Handles both fused and unfused attention projections (e.g., qkv).
+ - Applies special packing for W4A16 linear layers (e.g., ``"adanorm_single"`` and ``"adanorm_zero"``).
+ """
+ logger.debug(f"Converting LoRA branch for block {candidate_block_name}...")
+ converted: dict[str, torch.Tensor] = {}
+ for converted_local_name, candidate_local_names in local_name_map.items():
+ if isinstance(candidate_local_names, str):
+ candidate_local_names = [candidate_local_names]
+ # region original LoRA
+ orig_lora = (
+ orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_down", None),
+ orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_up", None),
+ )
+ if orig_lora[0] is None or orig_lora[1] is None:
+ assert orig_lora[0] is None and orig_lora[1] is None
+ orig_lora = None
+ elif orig_lora[0].numel() == 0 or orig_lora[1].numel() == 0:
+ assert orig_lora[0].numel() == 0 and orig_lora[1].numel() == 0
+ orig_lora = None
+ else:
+ assert orig_lora[0] is not None and orig_lora[1] is not None
+ orig_lora = (
+ unpack_lowrank_weight(orig_lora[0], down=True),
+ unpack_lowrank_weight(orig_lora[1], down=False),
+ )
+ logger.debug(
+ f" - Found {converted_block_name} LoRA of {converted_local_name} (rank: {orig_lora[0].shape[0]})"
+ )
+ # endregion
+ # region extra LoRA
+ extra_lora_list = None
+
+ # if the qkv are already fused
+ if "qkv" in converted_local_name:
+ candidate_local_name = candidate_local_names[0]
+ assert "_q" in candidate_local_name
+ candidate_local_name = candidate_local_name.replace("_q", "_qkv")
+ lora_A = extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None)
+ lora_B = extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None)
+ if lora_A is None and lora_B is None:
+ extra_lora_list = None
+ else:
+ assert lora_A is not None and lora_B is not None
+ extra_lora_list = [(lora_A, lora_B)]
+
+ # not fused, fuse them manually
+ if extra_lora_list is None:
+ extra_lora_list = [
+ (
+ extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None),
+ extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None),
+ )
+ for candidate_local_name in candidate_local_names
+ ]
+ if any(lora[0] is not None or lora[1] is not None for lora in extra_lora_list):
+ # merge extra LoRAs into one LoRA
+ if len(extra_lora_list) > 1:
+ first_lora = None
+ for lora in extra_lora_list:
+ if lora[0] is not None:
+ assert lora[1] is not None
+ first_lora = lora
+ break
+ assert first_lora is not None
+ for lora_index in range(len(extra_lora_list)):
+ if extra_lora_list[lora_index][0] is None:
+ assert extra_lora_list[lora_index][1] is None
+ extra_lora_list[lora_index] = (first_lora[0].clone(), torch.zeros_like(first_lora[1]))
+ if all(lora[0].equal(extra_lora_list[0][0]) for lora in extra_lora_list):
+ # if all extra LoRAs have the same lora_down, use it
+ extra_lora_down = extra_lora_list[0][0]
+ extra_lora_up = torch.cat([lora[1] for lora in extra_lora_list], dim=0)
+ else:
+ extra_lora_down = torch.cat([lora[0] for lora in extra_lora_list], dim=0)
+ extra_lora_up_c = sum(lora[1].shape[0] for lora in extra_lora_list)
+ extra_lora_up_r = sum(lora[1].shape[1] for lora in extra_lora_list)
+ assert extra_lora_up_r == extra_lora_down.shape[0]
+ extra_lora_up = torch.zeros((extra_lora_up_c, extra_lora_up_r), dtype=extra_lora_down.dtype)
+ c, r = 0, 0
+ for lora in extra_lora_list:
+ c_next, r_next = c + lora[1].shape[0], r + lora[1].shape[1]
+ extra_lora_up[c:c_next, r:r_next] = lora[1]
+ c, r = c_next, r_next
+ else:
+ extra_lora_down, extra_lora_up = extra_lora_list[0]
+ extra_lora: tuple[torch.Tensor, torch.Tensor] = (extra_lora_down, extra_lora_up)
+ logger.debug(
+ f" - Found {candidate_block_name} LoRA of {candidate_local_names} (rank: {extra_lora[0].shape[0]})"
+ )
+ else:
+ extra_lora = None
+ # endregion
+ # region merge LoRA
+ if orig_lora is None:
+ if extra_lora is None:
+ lora = None
+ else:
+ logger.debug(" - Using extra LoRA")
+ lora = (extra_lora[0].to(default_dtype), extra_lora[1].to(default_dtype))
+ elif extra_lora is None:
+ logger.debug(" - Using original LoRA")
+ lora = orig_lora
+ else:
+ try:
+ lora = (
+ torch.cat([orig_lora[0], extra_lora[0].to(orig_lora[0].dtype)], dim=0), # [r, c]
+ torch.cat([orig_lora[1], extra_lora[1].to(orig_lora[1].dtype)], dim=1), # [c, r]
+ )
+ logger.debug(f" - Merging original and extra LoRA (rank: {lora[0].shape[0]})")
+ except RuntimeError as e:
+ if "Sizes of tensors must match" in str(e):
+ # Handle various dimension mismatch cases for LoRA
+ logger.debug(
+ f" - Dimension mismatch detected: orig_lora[1]={orig_lora[1].shape}, extra_lora[1]={extra_lora[1].shape}"
+ )
+
+ # Handle dimension mismatch by using only the properly sized portion of extra_lora
+ # instead of trying to concatenate mismatched dimensions
+
+ # Case 1: single_blocks linear1 [21504] -> mlp_fc1 [12288]
+ if extra_lora[1].shape[1] == 21504 and orig_lora[1].shape[1] == 12288:
+ # Use only the first 12288 dimensions from the 21504 extra LoRA
+ extra_lora_up_split = extra_lora[1][:, :12288].clone()
+ extra_lora_down = extra_lora[0].clone()
+ # logger.debug(f" - Dimension fix 21504->12288: using split extra LoRA instead of merge")
+
+ # Use the split extra LoRA instead of concatenating
+ lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype))
+
+ # Case 2: transformer_blocks with different MLP dimensions (27648 -> 9216)
+ elif extra_lora[1].shape[1] == 27648 and orig_lora[1].shape[1] == 9216:
+ # Use only the first 9216 dimensions from the 27648 extra LoRA
+ extra_lora_up_split = extra_lora[1][:, :9216].clone()
+ extra_lora_down = extra_lora[0].clone()
+ # logger.debug(f" - Dimension fix 27648->9216: using split extra LoRA instead of merge")
+
+ # Use the split extra LoRA instead of concatenating
+ lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype))
+
+ # Case 3: Other dimension ratios - try to find a reasonable split
+ elif extra_lora[1].shape[1] > orig_lora[1].shape[1]:
+ # Use only what we need from extra LoRA
+ target_dim = orig_lora[1].shape[1]
+ extra_lora_up_split = extra_lora[1][:, :target_dim].clone()
+ extra_lora_down = extra_lora[0].clone()
+ # logger.debug(
+ # f" - Dimension fix {extra_lora[1].shape[1]}->{target_dim}: using truncated extra LoRA"
+ # )
+
+ # Use the truncated extra LoRA instead of concatenating
+ lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype))
+
+ else:
+ # For cases where extra LoRA has fewer dimensions, use original LoRA only
+ # logger.warning(
+ # f" - Cannot split extra LoRA {extra_lora[1].shape[1]}->{orig_lora[1].shape[1]}, using original only"
+ # )
+ lora = orig_lora
+ else:
+ raise e
+ # endregion
+ if lora is not None:
+ if convert_map[converted_local_name] == "adanorm_single":
+ update_state_dict(
+ converted,
+ {
+ "lora_down": pad(lora[0], divisor=16, dim=0),
+ "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
+ },
+ prefix=converted_local_name,
+ )
+ elif convert_map[converted_local_name] == "adanorm_zero":
+ update_state_dict(
+ converted,
+ {
+ "lora_down": pad(lora[0], divisor=16, dim=0),
+ "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
+ },
+ prefix=converted_local_name,
+ )
+ elif convert_map[converted_local_name] == "linear":
+ update_state_dict(
+ converted,
+ {
+ "lora_down": pack_lowrank_weight(lora[0], down=True),
+ "lora_up": pack_lowrank_weight(lora[1], down=False),
+ },
+ prefix=converted_local_name,
+ )
+ return converted
+
+
+def preprocess_single_blocks_lora(
+ extra_lora_dict: dict[str, torch.Tensor], candidate_block_name: str
+) -> dict[str, torch.Tensor]:
+ """
+ Preprocess LoRA weights from single_blocks format to match single_transformer_blocks structure.
+
+ This function handles the architectural mismatch between old and new models:
+ - Old single_blocks: linear1 (fused 21504-dim layer) and linear2
+ - New single_transformer_blocks: mlp_fc1 (12288-dim), qkv_proj (9216-dim), and mlp_fc2
+
+ The linear1 layer in the old architecture combines two functions:
+ 1. MLP projection (first 12288 dimensions)
+ 2. QKV projection for attention (last 9216 dimensions)
+
+ These are split into separate layers in the new architecture.
+ """
+ processed_dict = extra_lora_dict.copy()
+
+ # Find all single_transformer_blocks keys that need preprocessing
+ single_blocks_keys = [k for k in extra_lora_dict.keys() if "single_transformer_blocks" in k and "linear" in k]
+
+ logger.debug(f"Preprocessing LoRA for candidate: {candidate_block_name}")
+ logger.debug(f"All keys in extra_lora_dict: {list(extra_lora_dict.keys())[:10]}...") # Show first 10 keys
+ logger.debug(f"Found single_transformer_blocks keys: {single_blocks_keys[:5]}...") # Show first 5 keys
+
+ if single_blocks_keys:
+ logger.debug(f"Found single_transformer_blocks LoRA keys, preprocessing for candidate: {candidate_block_name}")
+
+ # The candidate_block_name is already "single_transformer_blocks.0"
+ # Look for linear1 and linear2 keys with this exact name
+ linear1_lora_A_key = f"{candidate_block_name}.linear1.lora_A.weight"
+ linear1_lora_B_key = f"{candidate_block_name}.linear1.lora_B.weight"
+ linear2_lora_A_key = f"{candidate_block_name}.linear2.lora_A.weight"
+ linear2_lora_B_key = f"{candidate_block_name}.linear2.lora_B.weight"
+
+ logger.debug(f"Looking for keys: {linear1_lora_B_key}")
+ logger.debug(
+ f"Available keys matching pattern: {[k for k in extra_lora_dict.keys() if candidate_block_name in k][:5]}..."
+ )
+
+ if linear1_lora_B_key in extra_lora_dict:
+ linear1_lora_A = extra_lora_dict[linear1_lora_A_key]
+ linear1_lora_B = extra_lora_dict[linear1_lora_B_key]
+
+ # Check if this is the problematic 21504 dimension case
+ if linear1_lora_B.shape[0] == 21504:
+ logger.debug(
+ f"Splitting linear1 LoRA weights: [21504, {linear1_lora_B.shape[1]}] -> "
+ f"mlp_fc1 [12288, {linear1_lora_B.shape[1]}] + qkv_proj [9216, {linear1_lora_B.shape[1]}]"
+ )
+
+ # Split linear1.lora_B [21504, rank] into two parts:
+ # 1. First 12288 dimensions -> mlp_fc1
+ # 2. Last 9216 dimensions (12288:21504) -> qkv_proj
+ mlp_fc1_lora_B = linear1_lora_B[:12288, :].clone()
+ qkv_proj_lora_B = linear1_lora_B[12288:21504, :].clone()
+
+ # The lora_A weight is reused for both new layers
+ # since it represents the down-projection from the input
+ mlp_fc1_lora_A = linear1_lora_A.clone()
+ qkv_proj_lora_A = linear1_lora_A.clone()
+
+ # Map to new architecture:
+ # 1. proj_mlp corresponds to mlp_fc1
+ processed_dict[f"{candidate_block_name}.proj_mlp.lora_A.weight"] = mlp_fc1_lora_A
+ processed_dict[f"{candidate_block_name}.proj_mlp.lora_B.weight"] = mlp_fc1_lora_B
+
+ # 2. Map the QKV part to the attention layers
+ # Note: In the new architecture, this maps to attn.to_q, attn.to_k, attn.to_v
+ # which get fused into qkv_proj during the conversion
+ processed_dict[f"{candidate_block_name}.attn.to_q.lora_A.weight"] = qkv_proj_lora_A
+ processed_dict[f"{candidate_block_name}.attn.to_q.lora_B.weight"] = qkv_proj_lora_B[
+ :3072, :
+ ] # Q projection
+ processed_dict[f"{candidate_block_name}.attn.to_k.lora_A.weight"] = qkv_proj_lora_A
+ processed_dict[f"{candidate_block_name}.attn.to_k.lora_B.weight"] = qkv_proj_lora_B[
+ 3072:6144, :
+ ] # K projection
+ processed_dict[f"{candidate_block_name}.attn.to_v.lora_A.weight"] = qkv_proj_lora_A
+ processed_dict[f"{candidate_block_name}.attn.to_v.lora_B.weight"] = qkv_proj_lora_B[
+ 6144:9216, :
+ ] # V projection
+
+ # Handle linear2 -> mlp_fc2 mapping
+ if linear2_lora_B_key in extra_lora_dict:
+ linear2_lora_A = extra_lora_dict[linear2_lora_A_key]
+ linear2_lora_B = extra_lora_dict[linear2_lora_B_key]
+
+ # Map linear2 to proj_out.linears.1 (mlp_fc2)
+ processed_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = linear2_lora_A
+ processed_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = linear2_lora_B
+
+ # Remove original keys
+ processed_dict.pop(linear2_lora_A_key, None)
+ processed_dict.pop(linear2_lora_B_key, None)
+
+ # Remove original linear1 keys
+ processed_dict.pop(linear1_lora_A_key, None)
+ processed_dict.pop(linear1_lora_B_key, None)
+
+ return processed_dict
+
+
+def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
+ orig_state_dict: dict[str, torch.Tensor],
+ extra_lora_dict: dict[str, torch.Tensor],
+ converted_block_name: str,
+ candidate_block_name: str,
+ default_dtype: torch.dtype = torch.bfloat16,
+) -> dict[str, torch.Tensor]:
+ """
+ Convert LoRA weights for a single FLUX transformer block from Diffusers to Nunchaku format.
+
+ This function merges and converts LoRA weights from the original SVDQuant low-rank branch and an
+ extra LoRA dictionary for a given transformer block, producing a Nunchaku-compatible dictionary.
+ It handles both fused and unfused LoRA branches (e.g., qkv), and merges multiple LoRA branches as needed.
+
+ Parameters
+ ----------
+ orig_state_dict : dict[str, torch.Tensor]
+ Original state dict with LoRA weights, keys like ``"{block}.{local}.lora_down"`` and ``"{block}.{local}.lora_up"``.
+ extra_lora_dict : dict[str, torch.Tensor]
+ Extra LoRA weights, keys like ``"{block}.{local}.lora_A.weight"`` and ``"{block}.{local}.lora_B.weight"``.
+ converted_block_name : str
+ Block name for output (e.g., ``"transformer_blocks.0"``).
+ candidate_block_name : str
+ Block name for input lookup (e.g., ``"blocks.0"``).
+ default_dtype : torch.dtype, optional
+ Output tensor dtype (default: ``torch.bfloat16``).
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ A dictionary containing the converted LoRA weights in Nunchaku format.
+
+ Notes
+ -----
+ - If both original and extra LoRA weights are present, they are merged by concatenation.
+ - Handles both fused and unfused attention projections (e.g., qkv).
+ - Applies special packing for W4A16 linear layers (e.g., ``"adanorm_single"`` and ``"adanorm_zero"``).
+ """
+
+ # Preprocess single_blocks LoRA structure if needed
+ # extra_lora_dict = preprocess_single_blocks_lora(extra_lora_dict, candidate_block_name)
+
+ if f"{candidate_block_name}.proj_out.lora_A.weight" in extra_lora_dict:
+ assert f"{converted_block_name}.out_proj.qweight" in orig_state_dict
+ assert f"{converted_block_name}.mlp_fc2.qweight" in orig_state_dict
+ n1 = orig_state_dict[f"{converted_block_name}.out_proj.qweight"].shape[1] * 2
+ n2 = orig_state_dict[f"{converted_block_name}.mlp_fc2.qweight"].shape[1] * 2
+ lora_down = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_A.weight"]
+ lora_up = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_B.weight"]
+ assert lora_down.shape[1] == n1 + n2
+ extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_A.weight"] = lora_down[:, :n1].clone()
+ extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_B.weight"] = lora_up.clone()
+ extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = lora_down[:, n1:].clone()
+ extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = lora_up.clone()
+ extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
+ extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")
+
+ for component in ["lora_A", "lora_B"]:
+ fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight"
+ fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight"
+ fc1_v = extra_lora_dict[fc1_k]
+ fc2_v = extra_lora_dict[fc2_k]
+ dim = 0 if "lora_A" in fc1_k else 1
+
+ fc1_rank = fc1_v.shape[dim]
+ fc2_rank = fc2_v.shape[dim]
+ if fc1_rank != fc2_rank:
+ rank = max(fc1_rank, fc2_rank)
+ if fc1_rank < rank:
+ extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+ if fc2_rank < rank:
+ extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
+
+ return convert_to_nunchaku_transformer_block_lowrank_dict(
+ orig_state_dict=orig_state_dict,
+ extra_lora_dict=extra_lora_dict,
+ converted_block_name=converted_block_name,
+ candidate_block_name=candidate_block_name,
+ local_name_map={
+ "norm.linear": "norm.linear",
+ "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
+ "norm_q": "attn.norm_q",
+ "norm_k": "attn.norm_k",
+ "out_proj": "proj_out.linears.0",
+ "mlp_fc1": "proj_mlp",
+ "mlp_fc2": "proj_out.linears.1",
+ },
+ convert_map={
+ "norm.linear": "adanorm_single",
+ "qkv_proj": "linear",
+ "out_proj": "linear",
+ "mlp_fc1": "linear",
+ "mlp_fc2": "linear",
+ },
+ default_dtype=default_dtype,
+ )
+
+
+def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
+ orig_state_dict: dict[str, torch.Tensor],
+ extra_lora_dict: dict[str, torch.Tensor],
+ converted_block_name: str,
+ candidate_block_name: str,
+ default_dtype: torch.dtype = torch.bfloat16,
+) -> dict[str, torch.Tensor]:
+ """
+ Convert LoRA weights for a single transformer block from Diffusers to Nunchaku format.
+
+ Parameters
+ ----------
+ orig_state_dict : dict[str, torch.Tensor]
+ Original model state dict.
+ extra_lora_dict : dict[str, torch.Tensor]
+ LoRA weights state dict.
+ converted_block_name : str
+ Output block name for the converted weights.
+ candidate_block_name : str
+ Input block name for lookup.
+ default_dtype : torch.dtype, optional
+ Output tensor dtype (default: torch.bfloat16).
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ Converted LoRA weights in Nunchaku format.
+ """
+ return convert_to_nunchaku_transformer_block_lowrank_dict(
+ orig_state_dict=orig_state_dict,
+ extra_lora_dict=extra_lora_dict,
+ converted_block_name=converted_block_name,
+ candidate_block_name=candidate_block_name,
+ local_name_map={
+ "norm1.linear": "norm1.linear",
+ "norm1_context.linear": "norm1_context.linear",
+ "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
+ "qkv_proj_context": ["attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj"],
+ "norm_q": "attn.norm_q",
+ "norm_k": "attn.norm_k",
+ "norm_added_q": "attn.norm_added_q",
+ "norm_added_k": "attn.norm_added_k",
+ "out_proj": "attn.to_out.0",
+ "out_proj_context": "attn.to_add_out",
+ "mlp_fc1": "ff.net.0.proj",
+ "mlp_fc2": "ff.net.2",
+ "mlp_context_fc1": "ff_context.net.0.proj",
+ "mlp_context_fc2": "ff_context.net.2",
+ },
+ convert_map={
+ "norm1.linear": "adanorm_zero",
+ "norm1_context.linear": "adanorm_zero",
+ "qkv_proj": "linear",
+ "qkv_proj_context": "linear",
+ "out_proj": "linear",
+ "out_proj_context": "linear",
+ "mlp_fc1": "linear",
+ "mlp_fc2": "linear",
+ "mlp_context_fc1": "linear",
+ "mlp_context_fc2": "linear",
+ },
+ default_dtype=default_dtype,
+ )
+
+
+def convert_to_nunchaku_flux_lowrank_dict(
+ base_model: dict[str, torch.Tensor] | str,
+ lora: dict[str, torch.Tensor] | str,
+ default_dtype: torch.dtype = torch.bfloat16,
+) -> dict[str, torch.Tensor]:
+ """
+ Convert a base model and LoRA weights from Diffusers format to Nunchaku format.
+
+ Parameters
+ ----------
+ base_model : dict[str, torch.Tensor] or str
+ Base model weights or path to safetensors file.
+ lora : dict[str, torch.Tensor] or str
+ LoRA weights or path to safetensors file.
+ default_dtype : torch.dtype, optional
+ Output tensor dtype (default: torch.bfloat16).
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ LoRA weights in Nunchaku format.
+ """
+ if isinstance(base_model, str):
+ orig_state_dict = load_state_dict_in_safetensors(base_model)
+ else:
+ orig_state_dict = base_model
+
+ if isinstance(lora, str):
+ # Load the LoRA - check if it has transformer prefix
+ temp_dict = load_state_dict_in_safetensors(lora)
+ if any(k.startswith("transformer.") for k in temp_dict.keys()):
+ # Standard FLUX LoRA with transformer prefix
+ extra_lora_dict = filter_state_dict(temp_dict, filter_prefix="transformer.")
+ # Remove the transformer. prefix after filtering
+ renamed_dict = {}
+ for k, v in extra_lora_dict.items():
+ new_k = k.replace("transformer.", "") if k.startswith("transformer.") else k
+ renamed_dict[new_k] = v
+ extra_lora_dict = renamed_dict
+ else:
+ # Kontext LoRA without transformer prefix - use as is
+ extra_lora_dict = temp_dict
+ else:
+ # When called from to_nunchaku, lora is already processed by to_diffusers
+ # Keys should be in format: single_blocks.0.linear1.lora_A.weight
+ extra_lora_dict = lora
+
+ # Add transformer. prefix and rename blocks to match expectations
+ renamed_dict = {}
+ for k, v in extra_lora_dict.items():
+ new_k = k
+ # Add transformer. prefix and rename blocks
+ if k.startswith("single_blocks."):
+ new_k = "transformer.single_transformer_blocks." + k[14:]
+ elif k.startswith("double_blocks."):
+ new_k = "transformer.transformer_blocks." + k[14:]
+ elif k.startswith("proj_out."):
+ new_k = "transformer." + k
+ elif not k.startswith("transformer."):
+ new_k = "transformer." + k
+ renamed_dict[new_k] = v
+ extra_lora_dict = renamed_dict
+
+ # Now filter for transformer prefix and remove it for processing
+ extra_lora_dict = filter_state_dict(extra_lora_dict, filter_prefix="transformer.")
+
+ # Remove the transformer. prefix for internal processing
+ renamed_dict = {}
+ for k, v in extra_lora_dict.items():
+ new_k = k.replace("transformer.", "") if k.startswith("transformer.") else k
+ renamed_dict[new_k] = v
+ extra_lora_dict = renamed_dict
+
+ vector_dict, unquantized_lora_dict = {}, {}
+ for k in list(extra_lora_dict.keys()):
+ v = extra_lora_dict[k]
+ if v.ndim == 1:
+ vector_dict[k.replace(".lora_B.bias", ".bias")] = extra_lora_dict.pop(k)
+ elif "transformer_blocks" not in k and "single_transformer_blocks" not in k:
+ # Only unquantized parts (like final_layer) go here
+ unquantized_lora_dict[k] = extra_lora_dict.pop(k)
+
+ # Concatenate qkv_proj biases if present
+ for k in list(vector_dict.keys()):
+ if ".to_q." in k or ".add_q_proj." in k:
+ k_q = k
+ k_k = k.replace(".to_q.", ".to_k.").replace(".add_q_proj.", ".add_k_proj.")
+ k_v = k.replace(".to_q.", ".to_v.").replace(".add_q_proj.", ".add_v_proj.")
+ keys = [k_q, k_k, k_v]
+ values = [vector_dict.pop(key) for key in keys]
+ new_k = k_q.replace(".to_q.", ".to_qkv.").replace(".add_q_proj.", ".add_qkv_proj.")
+ vector_dict[new_k] = torch.cat(values, dim=0)
+
+ for k in extra_lora_dict.keys():
+ fc1_k = k
+ if "ff.net.0.proj" in k:
+ fc2_k = k.replace("ff.net.0.proj", "ff.net.2")
+ elif "ff_context.net.0.proj" in k:
+ fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2")
+ else:
+ continue
+ assert fc2_k in extra_lora_dict
+ fc1_v = extra_lora_dict[fc1_k]
+ fc2_v = extra_lora_dict[fc2_k]
+ dim = 0 if "lora_A" in fc1_k else 1
+
+ fc1_rank = fc1_v.shape[dim]
+ fc2_rank = fc2_v.shape[dim]
+ if fc1_rank != fc2_rank:
+ rank = max(fc1_rank, fc2_rank)
+ if fc1_rank < rank:
+ extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+ if fc2_rank < rank:
+ extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
+
+ block_names: set[str] = set()
+ for param_name in orig_state_dict.keys():
+ if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
+ block_names.add(".".join(param_name.split(".")[:2]))
+ block_names = sorted(block_names, key=lambda x: (x.split(".")[0], int(x.split(".")[-1])))
+ logger.debug(f"Converting {len(block_names)} transformer blocks...")
+ converted: dict[str, torch.Tensor] = {}
+ for block_name in tqdm(block_names, dynamic_ncols=True, desc="Converting LoRAs to nunchaku format"):
+ if block_name.startswith("transformer_blocks"):
+ convert_fn = convert_to_nunchaku_flux_transformer_block_lowrank_dict
+ else:
+ convert_fn = convert_to_nunchaku_flux_single_transformer_block_lowrank_dict
+ update_state_dict(
+ converted,
+ convert_fn(
+ orig_state_dict=orig_state_dict,
+ extra_lora_dict=extra_lora_dict,
+ converted_block_name=block_name,
+ candidate_block_name=block_name,
+ default_dtype=default_dtype,
+ ),
+ prefix=block_name,
+ )
+
+ converted.update(unquantized_lora_dict)
+ converted.update(vector_dict)
+ return converted
+
+
+def to_nunchaku(
+ input_lora: str | dict[str, torch.Tensor],
+ base_sd: str | dict[str, torch.Tensor],
+ dtype: str | torch.dtype = torch.bfloat16,
+ output_path: str | None = None,
+) -> dict[str, torch.Tensor]:
+ """
+ Convert LoRA weights to Nunchaku format.
+
+ Parameters
+ ----------
+ input_lora : str or dict[str, torch.Tensor]
+ Path or dictionary of LoRA weights in Diffusers format. Can be composed of multiple LoRA weights.
+ base_sd : str or dict[str, torch.Tensor]
+ Path or dictionary of base quantized model weights.
+ dtype : str or torch.dtype, optional
+ Output data type ("bfloat16", "float16", or torch dtype). Default is torch.bfloat16.
+ output_path : str, optional
+ If provided, saves the result to this path.
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ LoRA weights in Nunchaku format.
+
+ Example
+ -------
+ .. code-block:: python
+
+ nunchaku_weights = to_nunchaku("lora.safetensors", "base_model.safetensors")
+ nunchaku_weights = to_nunchaku(lora_dict, base_dict)
+ """
+ if isinstance(input_lora, str):
+ tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
+ else:
+ tensors = input_lora
+ if is_nunchaku_format(tensors):
+ logger.debug("Already in nunchaku format, no conversion needed.")
+ converted = tensors
+ else:
+ extra_lora_dict = to_diffusers(tensors)
+
+ if isinstance(base_sd, str):
+ orig_state_dict = load_state_dict_in_safetensors(base_sd)
+ else:
+ orig_state_dict = base_sd
+
+ if isinstance(dtype, str):
+ if dtype == "bfloat16":
+ dtype = torch.bfloat16
+ elif dtype == "float16":
+ dtype = torch.float16
+ else:
+ raise ValueError(f"Unsupported dtype {dtype}.")
+ else:
+ assert isinstance(dtype, torch.dtype)
+
+ converted = convert_to_nunchaku_flux_lowrank_dict(
+ base_model=orig_state_dict, lora=extra_lora_dict, default_dtype=dtype
+ )
+ if output_path is not None:
+ output_dir = os.path.dirname(os.path.abspath(output_path))
+ os.makedirs(output_dir, exist_ok=True)
+ save_file(converted, output_path)
+ return converted
+
+
+#### fuse vectors ####
+
+
+def fuse_vectors(
+ vectors: dict[str, torch.Tensor], base_sd: dict[str, torch.Tensor], strength: float = 1
+) -> dict[str, torch.Tensor]:
+ """
+ Fuse vector (bias) terms from LoRA into the base model.
+
+ Parameters
+ ----------
+ vectors : dict[str, torch.Tensor]
+ LoRA vector terms.
+ base_sd : dict[str, torch.Tensor]
+ Base model state dict.
+ strength : float, optional
+ Scaling factor for LoRA vectors.
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ State dict with fused vectors.
+ """
+ tensors: dict[str, torch.Tensor] = {}
+ packer = NunchakuWeightPacker(bits=4)
+ for k, v in base_sd.items():
+ if v.ndim != 1 or "smooth" in k or (k.startswith("single_transformer_blocks.") and ".mlp_fc2." in k):
+ continue
+ if "norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k:
+ new_k = k.replace(".norm_", ".attn.norm_")
+ new_v = vectors.get(new_k, None)
+ tensors[k] = v if new_v is None else new_v
+
+ elif "norm.linear" in k or "norm1.linear" in k or "norm1_context.linear" in k:
+ diff = vectors.get(k, None)
+
+ if diff is not None:
+ if k.startswith("single_transformer_blocks."):
+ adanorm_splits = 3
+ else:
+ assert k.startswith("transformer_blocks.")
+ adanorm_splits = 6
+ diff = diff.view(adanorm_splits, -1).transpose(0, 1).reshape(-1)
+ tensors[k] = v + diff * strength
+ else:
+ tensors[k] = v
+
+ else:
+ if k.startswith("single_transformer_blocks."):
+ name_map = {".qkv_proj.": ".attn.to_qkv.", ".out_proj.": ".proj_out.", ".mlp_fc1.": ".proj_mlp."}
+ else:
+ assert k.startswith("transformer_blocks.")
+ name_map = {
+ ".qkv_proj.": ".attn.to_qkv.",
+ ".qkv_proj_context.": ".attn.add_qkv_proj.",
+ ".out_proj.": ".attn.to_out.0.",
+ ".out_proj_context.": ".attn.to_add_out.",
+ ".mlp_fc1.": ".ff.net.0.proj.",
+ ".mlp_fc2.": ".ff.net.2.",
+ ".mlp_context_fc1.": ".ff_context.net.0.proj.",
+ ".mlp_context_fc2.": ".ff_context.net.2.",
+ }
+
+ for original_pattern, new_pattern in name_map.items():
+ if original_pattern in k:
+ new_k = k.replace(original_pattern, new_pattern)
+ diff = vectors.get(new_k, None)
+ if diff is not None:
+ diff = diff * strength
+ diff = packer.pad_scale(diff, group_size=-1)
+ diff = packer.pack_scale(diff, group_size=-1)
+ tensors[k] = v + diff
+ break
+
+ return tensors
diff --git a/nunchaku/lora/flux/packer.py b/nunchaku/lora/flux/packer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ab18f6ea4099da6bd74a71d3b0788832fe4c91a
--- /dev/null
+++ b/nunchaku/lora/flux/packer.py
@@ -0,0 +1,517 @@
+"""
+Weight packing utilities for Nunchaku quantization.
+
+This module provides concise tools for packing and unpacking weight tensors,
+optimized for efficient GPU computation using Matrix Multiply and Accumulate (MMA) operations.
+"""
+
+import torch
+
+from ...utils import ceil_divide
+from .utils import pad
+
+
+class MmaWeightPackerBase:
+ """
+ Base class for Matrix Multiply and Accumulate (MMA) weight packing.
+
+ Packs weight tensors for efficient GPU computation using MMA operations.
+ Handles tile sizes, memory layout, and packing parameters.
+
+ Parameters
+ ----------
+ bits : int
+ Quantization bits. Must be 1, 4, 8, 16, or 32.
+ warp_n : int
+ Warp size in the n dimension.
+ comp_n : int, optional
+ Computation tile size in n (default: 16).
+ comp_k : int, optional
+ Computation tile size in k (default: 256 // bits).
+
+ Raises
+ ------
+ AssertionError
+ If bits or tile/pack sizes are invalid.
+
+ Attributes
+ ----------
+ comp_n : int
+ Tile size in n for MMA computation.
+ comp_k : int
+ Tile size in k for MMA computation.
+ insn_n : int
+ MMA instruction tile size in n.
+ insn_k : int
+ MMA instruction tile size in k.
+ num_lanes : int
+ Number of lanes (threads) in a warp.
+ num_k_lanes : int
+ Number of lanes in k.
+ num_n_lanes : int
+ Number of lanes in n.
+ warp_n : int
+ Warp size in n.
+ reg_k : int
+ Elements in a register in k.
+ reg_n : int
+ Elements in a register in n.
+ k_pack_size : int
+ Elements in a pack in k.
+ n_pack_size : int
+ Elements in a pack in n.
+ pack_size : int
+ Elements in a pack accessed by a lane.
+ mem_k : int
+ Tile size in k for one memory access.
+ mem_n : int
+ Tile size in n for one memory access.
+ num_k_packs : int
+ Packs in k for one memory access.
+ num_n_packs : int
+ Packs in n for one memory access.
+ """
+
+ def __init__(self, bits: int, warp_n: int, comp_n: int = None, comp_k: int = None):
+ self.bits = bits
+ assert self.bits in (1, 4, 8, 16, 32), "weight bits should be 1, 4, 8, 16, or 32."
+
+ # region compute tile size
+ self.comp_n = comp_n if comp_n is not None else 16
+ # smallest tile size in `n` dimension for MMA computation.
+ self.comp_k = comp_k if comp_k is not None else 256 // self.bits
+ # smallest tile size in `k` dimension for MMA computation.
+ # the smallest MMA computation may contain several MMA instructions
+ self.insn_n = 8 # mma instruction tile size in `n` dimension
+ # tile size in `n` dimension for MMA instruction.
+ self.insn_k = self.comp_k
+ # tile size in `k` dimension for MMA instruction.
+ assert self.insn_k * self.bits in (
+ 128,
+ 256,
+ ), f"insn_k ({self.insn_k}) * bits ({self.bits}) should be 128 or 256."
+ assert self.comp_n % self.insn_n == 0, f"comp_n ({self.comp_n}) should be divisible by insn_n ({self.insn_n})."
+ self.num_lanes = 32
+ # there are 32 lanes (or threads) in a warp.
+ self.num_k_lanes = 4
+ self.num_n_lanes = 8
+ assert (
+ warp_n >= self.comp_n and warp_n % self.comp_n == 0
+ ), f"warp_n ({warp_n}) should be divisible by comp_n({self.comp_n})."
+ self.warp_n = warp_n
+ # endregion
+ # region memory
+ self.reg_k = 32 // self.bits
+ # number of elements in a register in `k` dimension.
+ self.reg_n = 1
+ # number of elements in a register in `n` dimension (always 1).
+ self.k_pack_size = self.comp_k // (self.num_k_lanes * self.reg_k)
+ # number of elements in a pack in `k` dimension.
+ self.n_pack_size = self.comp_n // (self.num_n_lanes * self.reg_n)
+ # number of elements in a pack in `n` dimension.
+ self.pack_size = self.k_pack_size * self.n_pack_size
+ # number of elements in a pack accessed by a lane at a time.
+ assert 1 <= self.pack_size <= 4, "pack size should be less than or equal to 4."
+ assert self.k_pack_size * self.num_k_lanes * self.reg_k == self.comp_k
+ assert self.n_pack_size * self.num_n_lanes * self.reg_n == self.comp_n
+ self.mem_k = self.comp_k
+ # the tile size in `k` dimension for one tensor memory access.
+ self.mem_n = warp_n
+ # the tile size in `n` dimension for one tensor memory access.
+ self.num_k_packs = self.mem_k // (self.k_pack_size * self.num_k_lanes * self.reg_k)
+ # number of packs in `k` dimension for one tensor memory access.
+ self.num_n_packs = self.mem_n // (self.n_pack_size * self.num_n_lanes * self.reg_n)
+ # number of packs in `n` dimension for one tensor memory access.
+ # endregion
+
+ def get_view_shape(self, n: int, k: int) -> tuple[int, int, int, int, int, int, int, int, int, int]:
+ """
+ Returns the tensor view shape for MMA operations.
+
+ Parameters
+ ----------
+ n : int
+ Output channel size (must be divisible by mem_n).
+ k : int
+ Input channel size (must be divisible by mem_k).
+
+ Returns
+ -------
+ tuple of int
+ (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n,
+ k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)
+
+ Raises
+ ------
+ AssertionError
+ If n or k is not divisible by mem_n or mem_k.
+ """
+ assert n % self.mem_n == 0, "output channel size should be divisible by mem_n."
+ assert k % self.mem_k == 0, "input channel size should be divisible by mem_k."
+ return (
+ n // self.mem_n,
+ self.num_n_packs,
+ self.n_pack_size,
+ self.num_n_lanes,
+ self.reg_n,
+ k // self.mem_k,
+ self.num_k_packs,
+ self.k_pack_size,
+ self.num_k_lanes,
+ self.reg_k,
+ )
+
+
+class NunchakuWeightPacker(MmaWeightPackerBase):
+ """
+ Nunchaku-specific weight packer. Provide Nunchaku-specific packing of
+ quantized weights, scales, and low-rank weights.
+
+ Parameters
+ ----------
+ bits : int
+ Number of quantization bits. Must be 1, 4, 8, 16, or 32.
+ warp_n : int, optional
+ Warp size in the n dimension. Default is 128.
+
+ Attributes
+ ----------
+ num_k_unrolls : int
+ Number of unrolls in the k dimension (always 2 for Nunchaku).
+ """
+
+ def __init__(self, bits: int, warp_n: int = 128):
+ super().__init__(bits=bits, warp_n=warp_n)
+ self.num_k_unrolls = 2
+
+ def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
+ """
+ Pack quantized weight tensor for Nunchaku MMA.
+
+ Parameters
+ ----------
+ weight : torch.Tensor
+ Quantized weight tensor of dtype torch.int32 and shape (n, k).
+
+ Returns
+ -------
+ torch.Tensor
+ Packed weight tensor of dtype torch.int8.
+ """
+ assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
+ n, k = weight.shape
+ assert n % self.mem_n == 0, f"output channel size ({n}) should be divisible by mem_n ({self.mem_n})."
+ # currently, Nunchaku did not check the boundry of unrolled `k` dimension
+ assert k % (self.mem_k * self.num_k_unrolls) == 0, (
+ f"input channel size ({k}) should be divisible by "
+ f"mem_k ({self.mem_k}) * num_k_unrolls ({self.num_k_unrolls})."
+ )
+ n_tiles, k_tiles = n // self.mem_n, k // self.mem_k
+ weight = weight.reshape(
+ n_tiles,
+ self.num_n_packs, # 8 when warp_n = 128
+ self.n_pack_size, # always 2 in nunchaku
+ self.num_n_lanes, # constant 8
+ self.reg_n, # constant 1
+ k_tiles,
+ self.num_k_packs, # 1
+ self.k_pack_size, # always 2 in nunchaku
+ self.num_k_lanes, # constant 4
+ self.reg_k, # always 8 = 32 bits / 4 bits
+ )
+ # (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n, k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)
+ # =>
+ # (n_tiles, k_tiles, num_k_packs, num_n_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
+ weight = weight.permute(0, 5, 6, 1, 3, 8, 2, 7, 4, 9).contiguous()
+ assert weight.shape[4:-2] == (8, 4, 2, 2)
+ if self.bits == 4:
+ weight = weight.bitwise_and_(0xF)
+ shift = torch.arange(0, 32, 4, dtype=torch.int32, device=weight.device)
+ weight = weight.bitwise_left_shift_(shift)
+ weight = weight.sum(dim=-1, dtype=torch.int32)
+ elif self.bits == 8:
+ weight = weight.bitwise_and_(0xFF)
+ shift = torch.arange(0, 32, 8, dtype=torch.int32, device=weight.device)
+ weight = weight.bitwise_left_shift_(shift)
+ weight = weight.sum(dim=-1, dtype=torch.int32)
+ else:
+ raise NotImplementedError(f"weight bits {self.bits} is not supported.")
+ return weight.view(dtype=torch.int8).view(n, -1) # assume little-endian
+
+ def pack_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
+ """
+ Pack scale tensor for Nunchaku MMA.
+
+ Parameters
+ ----------
+ scale : torch.Tensor
+ Scale tensor of dtype torch.float16 or torch.bfloat16.
+ group_size : int
+ Group size for quantization.
+
+ Returns
+ -------
+ torch.Tensor
+ Packed scale tensor.
+ """
+ if self.check_if_micro_scale(group_size=group_size):
+ return self.pack_micro_scale(scale, group_size=group_size)
+ # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
+ assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
+ n = scale.shape[0]
+ # nunchaku load scales all in one access
+ # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
+ # scale loading is parallelized in `n` dimension, that is,
+ # `num_s_lanes` in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
+ # each element in `n` dimension is 16 bit as it contains 1 fp16
+ # min `s_pack_size` set to 2 element, since each lane at least holds 2 accumulator results in `n` dimension
+ # max `s_pack_size` set to 128b/16b = 8 elements
+ # for `warp_n = 8`, we have
+ # `s_pack_size = 2`, `num_s_lanes = 4`, `num_s_packs = 1`
+ # for `warp_n = 128`, we have
+ # `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
+ # for `warp_n = 512`, we have
+ # `s_pack_size = 8`, `num_s_lanes = 32`, `num_s_packs = 2`
+ s_pack_size = min(max(self.warp_n // self.num_lanes, 2), 8)
+ num_s_lanes = min(self.num_lanes, self.warp_n // s_pack_size)
+ num_s_packs = self.warp_n // (s_pack_size * num_s_lanes)
+ warp_s = num_s_packs * num_s_lanes * s_pack_size
+ assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
+ # `num_n_lanes = 8 (constant)` generates 8 elements consecutive in `n` dimension
+ # however, they are held by 4 lanes, each lane holds 2 elements in `n` dimension
+ # thus, we start from first 4 lanes, assign 2 elements to each lane, until all 8 elements are assigned
+ # we then repeat the process for the same 4 lanes, until each lane holds `s_pack_size` elements
+ # finally, we move to next 4 lanes, and repeat the process until all `num_s_lanes` lanes are assigned
+ # the process is repeated for `num_s_packs` times
+ # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
+ # wscales store order:
+ # 0 1 8 9 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
+ # 2 3 10 11 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
+ # 4 5 12 13 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
+ # 6 7 14 15 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
+ # 16 17 24 25 <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
+ # ...
+ # 22 23 30 31 <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
+ # ... ...
+ # 112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
+ # ...
+ # 118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
+ scale = scale.reshape(n // warp_s, num_s_packs, num_s_lanes // 4, s_pack_size // 2, 4, 2, -1)
+ scale = scale.permute(0, 6, 1, 2, 4, 3, 5).contiguous()
+ return scale.view(-1) if group_size == -1 else scale.view(-1, n) # the shape is just used for validation
+
+ def pack_micro_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
+ """
+ Pack micro scale tensor for Nunchaku MMA.
+
+ Parameters
+ ----------
+ scale : torch.Tensor
+ Scale tensor of dtype torch.float16 or torch.bfloat16.
+ group_size : int
+ Group size for quantization (must be 16).
+
+ Returns
+ -------
+ torch.Tensor
+ Packed micro scale tensor.
+ """
+ assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
+ assert scale.max() <= 448, "scale should be less than 448."
+ assert scale.min() >= -448, "scale should be greater than -448."
+ assert group_size == 16, "currently only support group size 16."
+ assert self.insn_k == 64, "insn_k should be 64."
+ scale = scale.to(dtype=torch.float8_e4m3fn)
+ n = scale.shape[0]
+ assert self.warp_n >= 32, "currently only support warp_n >= 32."
+ # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
+ # scale loading is parallelized in `n` dimension, that is,
+ # `num_s_lanes` in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
+ # each element in `n` dimension is 32 bit as it contains 4 fp8 in `k` dimension
+ # min `s_pack_size` set to 1 element
+ # max `s_pack_size` set to 128b/32b = 4 elements
+ # for `warp_n = 128`, we have
+ # `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
+ # for `warp_n = 512`, we have
+ # `s_pack_size = 8`, `num_s_lanes = 32`, `num_s_packs = 2`
+ s_pack_size = min(max(self.warp_n // self.num_lanes, 1), 4)
+ num_s_lanes = 4 * 8 # 32 lanes is divided into 4 pieces, each piece has 8 lanes at a stride of 4
+ num_s_packs = ceil_divide(self.warp_n, s_pack_size * num_s_lanes)
+ warp_s = num_s_packs * num_s_lanes * s_pack_size
+ assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
+ # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-scaling-thread-id-b-selection
+ # we start from first 8 lines at a stride of 4, assign 1 element to each lane, until all 8 elements are assigned
+ # we then move to next 8 lines at a stride of 4, and repeat the process until all 32 lanes are assigned
+ # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
+ # wscales store order:
+ # 0 32 64 96 <-- load by lane 0
+ # 8 40 72 104 <-- load by lane 1
+ # 16 48 80 112 <-- load by lane 2
+ # 24 56 88 120 <-- load by lane 3
+ # 1 33 65 97 <-- load by lane 4
+ # ...
+ # 25 57 81 113 <-- load by lane 7
+ # ...
+ # 7 39 71 103 <-- load by lane 28
+ # ...
+ # 31 63 95 127 <-- load by lane 31
+ scale = scale.view(n // warp_s, num_s_packs, s_pack_size, 4, 8, -1, self.insn_k // group_size)
+ scale = scale.permute(0, 5, 1, 4, 3, 2, 6).contiguous()
+ return scale.view(-1, n) # the shape is just used for validation
+
+ def pack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
+ """
+ Pack low-rank weight tensor.
+
+ Parameters
+ ----------
+ weight : torch.Tensor
+ Low-rank weight tensor of dtype torch.float16 or torch.bfloat16.
+ down : bool
+ If True, weight is for down projection in low-rank branch.
+
+ Returns
+ -------
+ torch.Tensor
+ Packed low-rank weight tensor.
+ """
+ assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
+ reg_n, reg_k = 1, 2 # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
+ pack_n = self.n_pack_size * self.num_n_lanes * reg_n
+ pack_k = self.k_pack_size * self.num_k_lanes * reg_k
+ weight = pad(weight, divisor=(pack_n, pack_k), dim=(0, 1))
+ if down:
+ r, c = weight.shape
+ r_packs, c_packs = r // pack_n, c // pack_k
+ weight = weight.view(r_packs, pack_n, c_packs, pack_k).permute(2, 0, 1, 3)
+ else:
+ c, r = weight.shape
+ c_packs, r_packs = c // pack_n, r // pack_k
+ weight = weight.view(c_packs, pack_n, r_packs, pack_k).permute(0, 2, 1, 3)
+ weight = weight.reshape(
+ c_packs, r_packs, self.n_pack_size, self.num_n_lanes, reg_n, self.k_pack_size, self.num_k_lanes, reg_k
+ )
+ # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
+ # =>
+ # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
+ weight = weight.permute(0, 1, 3, 6, 2, 5, 4, 7).contiguous()
+ return weight.view(c, r)
+
+ def unpack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
+ """
+ Unpack low-rank weight tensor.
+
+ Parameters
+ ----------
+ weight : torch.Tensor
+ Packed low-rank weight tensor of dtype torch.float16 or torch.bfloat16.
+ down : bool
+ If True, weight is for down projection in low-rank branch.
+
+ Returns
+ -------
+ torch.Tensor
+ Unpacked low-rank weight tensor.
+ """
+ c, r = weight.shape
+ assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
+ reg_n, reg_k = 1, 2 # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
+ pack_n = self.n_pack_size * self.num_n_lanes * reg_n
+ pack_k = self.k_pack_size * self.num_k_lanes * reg_k
+ if down:
+ r_packs, c_packs = r // pack_n, c // pack_k
+ else:
+ c_packs, r_packs = c // pack_n, r // pack_k
+ weight = weight.view(
+ c_packs, r_packs, self.num_n_lanes, self.num_k_lanes, self.n_pack_size, self.k_pack_size, reg_n, reg_k
+ )
+ # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
+ # =>
+ # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
+ weight = weight.permute(0, 1, 4, 2, 6, 5, 3, 7).contiguous()
+ weight = weight.view(c_packs, r_packs, pack_n, pack_k)
+ if down:
+ weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
+ else:
+ weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
+ return weight
+
+ def check_if_micro_scale(self, group_size: int) -> bool:
+ """
+ Check if micro scale packing is required.
+
+ Parameters
+ ----------
+ group_size : int
+ Group size for quantization.
+
+ Returns
+ -------
+ bool
+ True if micro scale packing is required.
+ """
+ return self.insn_k == group_size * 4
+
+ def pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
+ """
+ Pad weight tensor to required shape.
+
+ Parameters
+ ----------
+ weight : torch.Tensor
+ Weight tensor of shape (n, k).
+
+ Returns
+ -------
+ torch.Tensor
+ Padded weight tensor.
+ """
+ assert weight.ndim == 2, "weight tensor should be 2D."
+ return pad(weight, divisor=(self.mem_n, self.mem_k * self.num_k_unrolls), dim=(0, 1))
+
+ def pad_scale(self, scale: torch.Tensor, group_size: int, fill_value: float = 0) -> torch.Tensor:
+ """
+ Pad scale tensor to required shape.
+
+ Parameters
+ ----------
+ scale : torch.Tensor
+ Scale tensor.
+ group_size : int
+ Group size for quantization.
+ fill_value : float, optional
+ Value to use for padding. Default is 0.
+
+ Returns
+ -------
+ torch.Tensor
+ Padded scale tensor.
+ """
+ if group_size > 0 and scale.numel() > scale.shape[0]:
+ scale = scale.view(scale.shape[0], 1, -1, 1)
+ if self.check_if_micro_scale(group_size=group_size):
+ scale = pad(scale, divisor=(self.warp_n, self.insn_k // group_size), dim=(0, 2), fill_value=fill_value)
+ else:
+ scale = pad(scale, divisor=(self.warp_n, self.num_k_unrolls), dim=(0, 2), fill_value=fill_value)
+ else:
+ scale = pad(scale, divisor=self.warp_n, dim=0, fill_value=fill_value)
+ return scale
+
+ def pad_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
+ """
+ Pad low-rank weight tensor to required shape.
+
+ Parameters
+ ----------
+ weight : torch.Tensor
+ Low-rank weight tensor.
+ down : bool
+ If True, weight is for down projection in low-rank branch.
+
+ Returns
+ -------
+ torch.Tensor
+ Padded low-rank weight tensor.
+ """
+ assert weight.ndim == 2, "weight tensor should be 2D."
+ return pad(weight, divisor=self.warp_n, dim=1 if down else 0)
diff --git a/nunchaku/lora/flux/utils.py b/nunchaku/lora/flux/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..65b86ea815e365d11a8d91634f0f677e31d17e84
--- /dev/null
+++ b/nunchaku/lora/flux/utils.py
@@ -0,0 +1,94 @@
+"""
+Utility functions for LoRAs in Flux models.
+"""
+
+import typing as tp
+
+import torch
+
+from ...utils import ceil_divide, load_state_dict_in_safetensors
+
+
+def is_nunchaku_format(lora: str | dict[str, torch.Tensor]) -> bool:
+ """
+ Check if LoRA weights are in Nunchaku format.
+
+ Parameters
+ ----------
+ lora : str or dict[str, torch.Tensor]
+ Path to a safetensors file or a dictionary of LoRA weights.
+
+ Returns
+ -------
+ bool
+ True if the weights are in Nunchaku format, False otherwise.
+
+ Examples
+ --------
+ >>> is_nunchaku_format("path/to/lora.safetensors")
+ True
+ """
+ if isinstance(lora, str):
+ tensors = load_state_dict_in_safetensors(lora, device="cpu", return_metadata=False)
+ assert isinstance(tensors, dict), "Expected dict when return_metadata=False"
+ else:
+ tensors = lora
+
+ for k in tensors.keys():
+ if ".mlp_fc" in k or "mlp_context_fc1" in k:
+ return True
+ return False
+
+
+def pad(
+ tensor: tp.Optional[torch.Tensor],
+ divisor: int | tp.Sequence[int],
+ dim: int | tp.Sequence[int],
+ fill_value: float | int = 0,
+) -> torch.Tensor | None:
+ """
+ Pad a tensor so specified dimensions are divisible by given divisors.
+
+ Parameters
+ ----------
+ tensor : torch.Tensor or None
+ The tensor to pad. If None, returns None.
+ divisor : int or sequence of int
+ Divisor(s) for the dimension(s) to pad.
+ dim : int or sequence of int
+ Dimension(s) to pad.
+ fill_value : float or int, optional
+ Value to use for padding (default: 0).
+
+ Returns
+ -------
+ torch.Tensor or None
+ The padded tensor, or None if input tensor was None.
+
+ Examples
+ --------
+ >>> tensor = torch.randn(10, 20)
+ >>> pad(tensor, divisor=16, dim=0).shape
+ torch.Size([16, 20])
+ >>> pad(tensor, divisor=[16, 32], dim=[0, 1]).shape
+ torch.Size([16, 32])
+ """
+ if isinstance(divisor, int):
+ if divisor <= 1:
+ return tensor
+ elif all(d <= 1 for d in divisor):
+ return tensor
+ if tensor is None:
+ return None
+ shape = list(tensor.shape)
+ if isinstance(dim, int):
+ assert isinstance(divisor, int)
+ shape[dim] = ceil_divide(shape[dim], divisor) * divisor
+ else:
+ if isinstance(divisor, int):
+ divisor = [divisor] * len(dim)
+ for d, div in zip(dim, divisor, strict=True):
+ shape[d] = ceil_divide(shape[d], div) * div
+ result = torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device)
+ result[[slice(0, extent) for extent in tensor.shape]] = tensor
+ return result
diff --git a/nunchaku/models/__init__.py b/nunchaku/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bc7e804c946ab9242bed0af708c348b43483a31
--- /dev/null
+++ b/nunchaku/models/__init__.py
@@ -0,0 +1,9 @@
+from .text_encoders.t5_encoder import NunchakuT5EncoderModel
+from .transformers import (
+ NunchakuFluxTransformer2dModel,
+)
+
+__all__ = [
+ "NunchakuFluxTransformer2dModel",
+ "NunchakuT5EncoderModel",
+]
diff --git a/nunchaku/models/attention.py b/nunchaku/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9fdbce2c3cdc2ed81a4a47caa91111bbbf3369
--- /dev/null
+++ b/nunchaku/models/attention.py
@@ -0,0 +1,123 @@
+"""
+Nunchaku quantized attention-related modules.
+"""
+
+import torch
+from diffusers.models.activations import GELU
+from diffusers.models.attention import FeedForward
+from torch import nn
+
+from ..ops.fused import fused_gelu_mlp
+from .linear import SVDQW4A4Linear
+
+
+class NunchakuBaseAttention(nn.Module):
+ """
+ Base class for Nunchaku attention modules.
+
+ Provides a common interface for attention modules with processor selection.
+
+ Parameters
+ ----------
+ processor : str, optional
+ Name of the attention processor to use. Default is "flashattn2".
+ *args, **kwargs :
+ Additional arguments for subclass initialization.
+ """
+
+ def __init__(self, processor: str = "flashattn2", *args, **kwargs):
+ super(NunchakuBaseAttention, self).__init__()
+ self.processor = None
+ self.set_processor(processor)
+
+ def set_processor(self, processor: str):
+ """
+ Set the attention processor. Must be implemented by subclasses.
+
+ Parameters
+ ----------
+ processor : str
+ Name of the processor to use.
+
+ Raises
+ ------
+ NotImplementedError
+ If not implemented in subclass.
+ """
+ raise NotImplementedError("Subclass must implement this method")
+
+
+def _patch_linear(module: nn.Module, linear_cls, **kwargs) -> nn.Module:
+ """
+ Recursively replace all nn.Linear modules in a given module with a custom linear class.
+
+ Parameters
+ ----------
+ module : nn.Module
+ The module to patch.
+ linear_cls : type
+ The custom linear class to use for replacement.
+ **kwargs :
+ Additional arguments passed to ``from_linear``.
+
+ Returns
+ -------
+ nn.Module
+ The patched module with custom linear layers.
+ """
+ for name, child in module.named_children():
+ if isinstance(child, nn.Linear):
+ setattr(module, name, linear_cls.from_linear(child, **kwargs))
+ else:
+ _patch_linear(child, linear_cls, **kwargs)
+ return module
+
+
+class NunchakuFeedForward(FeedForward):
+ """
+ Quantized feed-forward (MLP) block with fused GELU support.
+
+ Replaces linear layers in a FeedForward block with :class:`~nunchaku.models.linear.SVDQW4A4Linear` for quantized inference.
+ Supports fused GELU-MLP computation for efficiency.
+
+ Parameters
+ ----------
+ ff : FeedForward
+ Source FeedForward block to quantize.
+ **kwargs :
+ Additional arguments for SVDQW4A4Linear.
+
+ Notes
+ -----
+ For int4 quantization, the activation of the second MLP layer is shifted to be unsigned.
+ """
+
+ def __init__(self, ff: FeedForward, **kwargs):
+ super(FeedForward, self).__init__()
+ self.net = _patch_linear(ff.net, SVDQW4A4Linear, **kwargs)
+ # For int4, shift the activation of mlp_fc2 to make it unsigned
+ self.net[2].act_unsigned = self.net[2].precision != "nvfp4"
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass for the quantized feed-forward block.
+ It will call :func:`~nunchaku.ops.fused.fused_gelu_mlp` if the first layer is GELU;
+ otherwise, apply modules sequentially.
+
+ Parameters
+ ----------
+ hidden_states : torch.Tensor, shape (B, D)
+ Input tensor.
+
+ Returns
+ -------
+ torch.Tensor, shape (B, D)
+ Output tensor after feed-forward transformation.
+ """
+ if isinstance(self.net[0], GELU):
+ return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2])
+ else:
+ # Fallback to original implementation
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
diff --git a/nunchaku/models/embeddings.py b/nunchaku/models/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..12b50966200dc321c6b99679ac7cadfb07f8e594
--- /dev/null
+++ b/nunchaku/models/embeddings.py
@@ -0,0 +1,138 @@
+"""
+Embedding layers for Nunchaku.
+"""
+
+import diffusers
+import torch
+from packaging.version import Version
+from torch import nn
+
+
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+ """
+ Rotary positional embedding function.
+ Copied from https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/transformers/transformer_flux.py#L38
+
+ Parameters
+ ----------
+ pos : torch.Tensor, shape (..., n), dtype int
+ Position indices.
+ dim : int
+ Embedding dimension (must be even).
+ theta : int
+ Rotary base.
+
+ Returns
+ -------
+ out : torch.Tensor, shape (B, M, D//2, 1, 2), dtype float32
+ Rotary embedding tensor.
+
+ Notes
+ -----
+ - B: batch size
+ - M: sequence length
+ - D: embedding dimension
+ """
+ assert dim % 2 == 0, "The dimension must be even."
+
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+
+ batch_size, seq_length = pos.shape
+ out = torch.einsum("...n,d->...nd", pos, omega)
+
+ # Sin/cos representation for rotary embedding
+ cos_out = torch.cos(out)
+ sin_out = torch.sin(out)
+ stacked_out = torch.stack([sin_out, cos_out], dim=-1)
+ out = stacked_out.view(batch_size, -1, dim // 2, 1, 2)
+
+ return out.float()
+
+
+class NunchakuFluxPosEmbed(nn.Module):
+ """
+ Nunchaku multi-dimensional rotary embedding module for FLUX.
+ Adapted from https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/transformers/transformer_flux.py#L55
+
+ Parameters
+ ----------
+ dim : int
+ Embedding dimension.
+ theta : int
+ Rotary base.
+ axes_dim : list of int
+ Dimension for each spatial axis.
+ """
+
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+ super(NunchakuFluxPosEmbed, self).__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ """
+ Compute rotary embeddings for multi-dimensional positions.
+
+ Parameters
+ ----------
+ ids : torch.Tensor, shape (..., n_axes), dtype int
+ Position indices.
+
+ Returns
+ -------
+ out : torch.Tensor, shape (B, 1, ...), dtype float32
+ Rotary embedding tensor.
+
+ Notes
+ -----
+ - B: batch size
+ - n_axes: number of spatial axes
+ """
+ if Version(diffusers.__version__) >= Version("0.31.0"):
+ ids = ids[None, ...]
+ n_axes = ids.shape[-1]
+ emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
+ return emb.unsqueeze(1)
+
+
+def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
+ """
+ Pack rotary embeddings for efficient CUDA computation.
+
+ Parameters
+ ----------
+ rotemb : torch.Tensor, shape (B, M, D//2, 1, 2), dtype float32
+ Rotary embedding tensor.
+
+ Returns
+ -------
+ packed : torch.Tensor, shape (B, M, D), dtype float32
+ Packed rotary embedding tensor.
+
+ Notes
+ -----
+ - B: batch size
+ - M: sequence length (must be divisible by 16)
+ - D: embedding dimension (must be divisible by 8)
+ """
+ assert rotemb.dtype == torch.float32
+ B = rotemb.shape[0]
+ M = rotemb.shape[1]
+ D = rotemb.shape[2] * 2
+ assert rotemb.shape == (B, M, D // 2, 1, 2)
+ assert M % 16 == 0
+ assert D % 8 == 0
+ rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8)
+ rotemb = rotemb.permute(0, 1, 3, 2, 4)
+ # 16*8 pack, FP32 accumulator (C) format
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
+ ##########################################|--M--|--D--|
+ ##########################################|-3--4--5--6|
+ ########################################## : : : :
+ rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
+ rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
+ rotemb = rotemb.contiguous()
+ rotemb = rotemb.view(B, M, D)
+ return rotemb
diff --git a/nunchaku/models/linear.py b/nunchaku/models/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e3b6cd51aeb98f10a3034e761eb14a27d8d81f9
--- /dev/null
+++ b/nunchaku/models/linear.py
@@ -0,0 +1,414 @@
+"""
+Quantized linear layers for Nunchaku.
+"""
+
+import torch
+from torch import nn
+
+from ..ops.gemm import svdq_gemm_w4a4_cuda
+from ..ops.gemv import awq_gemv_w4a16_cuda
+from ..ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda
+
+
+class SVDQW4A4Linear(nn.Module):
+ """
+ `SVDQuant `_ W4A4 quantized linear layer.
+
+ Parameters
+ ----------
+ in_features : int
+ Input feature dimension.
+ out_features : int
+ Output feature dimension.
+ rank : int, optional
+ SVD low-rank dimension. Default is 32.
+ bias : bool, optional
+ If True, adds a learnable bias. Default is True.
+ precision : {'int4', 'nvfp4'}, optional
+ Quantization precision data type ('int4' or 'nvfp4'). Default is 'int4'.
+ act_unsigned : bool, optional
+ If True, use unsigned activation quantization (int4 only). Default is False.
+ torch_dtype : torch.dtype, optional
+ Parameter dtype. Default is torch.bfloat16.
+ device : str or torch.device or None, optional
+ Device for parameters. Default is CPU.
+
+ Attributes
+ ----------
+ in_features : int
+ out_features : int
+ rank : int
+ precision : str
+ 'int4' or 'nvfp4'.
+ group_size : int
+ 64 for int4, 16 for nvfp4.
+ qweight : nn.Parameter
+ Packed quantized weights, shape (out_features, in_features // 2), dtype int8.
+ bias : nn.Parameter or None
+ Bias tensor.
+ wscales : nn.Parameter
+ Weight scales, shape (in_features // group_size, out_features).
+ Dtype: bfloat16/float16 (int4), float8_e4m3fn (nvfp4).
+ smooth_factor : nn.Parameter
+ Smoothing factors, shape (in_features,).
+ smooth_factor_orig : nn.Parameter
+ Original smoothing factors, shape (in_features,). (Unused)
+ proj_down : nn.Parameter
+ Packed low-rank down projection, shape (in_features, rank), dtype bfloat16/float16.
+ proj_up : nn.Parameter
+ Packed low-rank up projection, shape (out_features, rank), dtype bfloat16/float16.
+ wtscale : float or None
+ Global weight scale (nvfp4 only).
+ wcscales : nn.Parameter or None
+ Channel-wise weight scale (nvfp4 only), shape (out_features,), dtype float8_e4m3fn.
+ act_unsigned : bool
+ If True, input activations are unsigned (int4 only).
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ rank: int = 32,
+ bias: bool = True,
+ precision: str = "int4",
+ act_unsigned: bool = False,
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: str | torch.device | None = None,
+ ):
+ super(SVDQW4A4Linear, self).__init__()
+ if device is None:
+ device = torch.device("cpu")
+ self.in_features = in_features
+ self.out_features = out_features
+ self.rank = rank
+
+ self.precision = precision
+ self.torch_dtype = torch_dtype
+
+ if precision == "nvfp4":
+ self.group_size = 16
+ elif precision == "int4":
+ self.group_size = 64
+ else:
+ raise ValueError(f"Invalid precision: {precision}")
+
+ self.qweight = nn.Parameter(
+ torch.empty(out_features, in_features // 2, dtype=torch.int8, device=device), requires_grad=False
+ )
+ self.bias = (
+ nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=True)
+ if bias
+ else None
+ )
+
+ self.wscales = nn.Parameter(
+ torch.empty(
+ in_features // self.group_size,
+ out_features,
+ dtype=torch_dtype if precision == "int4" else torch.float8_e4m3fn,
+ device=device,
+ ),
+ requires_grad=False,
+ )
+ self.smooth_factor = nn.Parameter(
+ torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False
+ )
+ self.smooth_factor_orig = nn.Parameter(
+ torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False
+ )
+
+ self.proj_down = nn.Parameter(torch.empty(in_features, rank, dtype=torch_dtype, device=device))
+ self.proj_up = nn.Parameter(torch.empty(out_features, rank, dtype=torch_dtype, device=device))
+
+ if precision == "nvfp4":
+ self.wcscales = nn.Parameter(
+ torch.ones(out_features, dtype=torch_dtype, device=device), requires_grad=False
+ )
+ self.wtscale = 1.0
+ else:
+ self.wtscale = None
+ self.wcscales = None
+
+ self.act_unsigned = act_unsigned
+
+ @classmethod
+ def from_linear(cls, linear: nn.Linear, **kwargs):
+ """
+ Create an SVDQW4A4Linear from a standard nn.Linear. The weight and bias are dummy tensors.
+
+ Parameters
+ ----------
+ linear : nn.Linear
+ Source linear layer.
+ **kwargs
+ Additional init arguments.
+
+ Returns
+ -------
+ SVDQW4A4Linear
+ """
+ in_features = kwargs.pop("in_features", linear.in_features)
+ return cls(
+ in_features=in_features,
+ out_features=linear.out_features,
+ bias=linear.bias is not None,
+ torch_dtype=linear.weight.dtype,
+ device=linear.weight.device,
+ **kwargs,
+ )
+
+ def forward(self, x: torch.Tensor, output: torch.Tensor | None = None) -> torch.Tensor:
+ """
+ Forward pass with 16-bit input. It will call :meth:`quantize` and :meth:`forward_quant`.
+
+ Parameters
+ ----------
+ x : torch.Tensor, shape (B, S, in_features), dtype float16 or bfloat16
+ Input tensor.
+ output : torch.Tensor or None, optional
+ Optional output buffer.
+
+ Returns
+ -------
+ torch.Tensor, shape (B, S, out_features)
+ Output tensor.
+
+ Notes
+ -----
+ B: batch size, S: sequence length
+ """
+ batch_size, seq_len, channels = x.shape
+ x = x.reshape(batch_size * seq_len, channels)
+ if output is None:
+ output = torch.empty(batch_size * seq_len, self.out_features, dtype=x.dtype, device=x.device)
+ quantized_x, ascales, lora_act_out = self.quantize(x)
+ output = self.forward_quant(quantized_x, ascales, lora_act_out, output)
+ output = output.reshape(batch_size, seq_len, -1)
+ return output
+
+ def quantize(self, x: torch.Tensor, pad_size: int = 256) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Quantize input to 4-bit and compute low-rank hidden states. It will call :func:`~nunchaku.ops.quantize.svdq_quantize_w4a4_act_fuse_lora_cuda`.
+
+ Parameters
+ ----------
+ x : torch.Tensor, shape (N, in_features), dtype float16 or bfloat16
+ Input tensor.
+ pad_size : int, optional
+ Batch padding size. Default is 256.
+
+ Returns
+ -------
+ quantized_x : torch.Tensor
+ Quantized input, shape (pad_size * ceil(N / pad_size), in_features // 2), dtype uint8.
+ ascales : torch.Tensor
+ Activation scales, shape (in_features // group_size,), dtype float8_e4m3fn for nvfp4 and input dtype for int4.
+ lora_act_out : torch.Tensor
+ Low-rank hidden states, shape (pad_size * ceil(N / pad_size), rank), dtype float32.
+
+ Notes
+ -----
+ N: batch size
+ """
+ quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda(
+ x, lora_down=self.proj_down, smooth=self.smooth_factor, fp4=self.precision == "nvfp4", pad_size=pad_size
+ )
+ return quantized_x, ascales, lora_act_out
+
+ def forward_quant(
+ self,
+ quantized_x: torch.Tensor,
+ ascales: torch.Tensor,
+ lora_act: torch.Tensor,
+ output: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Forward pass with pre-quantized input. It will call :func:`~nunchaku.ops.gemm.svdq_gemm_w4a4_cuda`.
+
+ Parameters
+ ----------
+ quantized_x : torch.Tensor
+ Quantized input, shape (N, in_features // 2), dtype uint8.
+ ascales : torch.Tensor
+ Activation scales, shape (in_features // group_size,), dtype float8_e4m3fn for nvfp4 and input dtype for int4.
+ lora_act : torch.Tensor
+ Low-rank hidden states, shape (N, rank), dtype float32.
+ output : torch.Tensor or None, optional
+ Optional output buffer.
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor, shape (N, out_features), dtype bfloat16/float16 for int4 and float8_e4m3fn for nvfp4.
+
+ Notes
+ -----
+ N: batch size
+ """
+ if output is None:
+ output = torch.empty(
+ quantized_x.shape[0], self.out_features, dtype=self.proj_up.dtype, device=quantized_x.device
+ )
+
+ svdq_gemm_w4a4_cuda(
+ act=quantized_x,
+ wgt=self.qweight,
+ out=output,
+ ascales=ascales,
+ wscales=self.wscales,
+ lora_act_in=lora_act,
+ lora_up=self.proj_up,
+ bias=self.bias,
+ fp4=self.precision == "nvfp4",
+ alpha=self.wtscale,
+ wcscales=self.wcscales,
+ act_unsigned=self.act_unsigned,
+ )
+ return output
+
+ def __repr__(self):
+ return (
+ f"SVDQW4A4Linear(in_features={self.in_features}, out_features={self.out_features}, "
+ f"rank={self.rank}, precision={self.precision}, act_unsigned={self.act_unsigned})"
+ )
+
+
+class AWQW4A16Linear(nn.Module):
+ """
+ `AWQ `_ W4A16 quantized linear layer.
+
+ Parameters
+ ----------
+ in_features : int
+ Input feature dimension.
+ out_features : int
+ Output feature dimension.
+ bias : bool, optional
+ If True, adds learnable bias. Default is True.
+ group_size : int, optional
+ Quantization group size. Default is 64.
+ torch_dtype : torch.dtype, optional
+ Parameter dtype. Default is torch.bfloat16.
+ device : str or torch.device or None, optional
+ Device for parameters. Default is CPU.
+
+ Attributes
+ ----------
+ in_features : int
+ out_features : int
+ group_size : int
+ qweight : nn.Parameter
+ Packed quantized weights, shape (out_features // 4, in_features // 2), dtype int32.
+ bias : nn.Parameter or None
+ Bias tensor.
+ wscales : nn.Parameter
+ Weight scales, shape (in_features // group_size, out_features), dtype float16 or bfloat16.
+ wzeros : nn.Parameter
+ Weight zero points, shape (in_features // group_size, out_features), dtype float16 or bfloat16.
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ group_size: int = 64,
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: str | torch.device | None = None,
+ ):
+ super(AWQW4A16Linear, self).__init__()
+ if device is None:
+ device = torch.device("cpu")
+ self.in_features = in_features
+ self.out_features = out_features
+ self.group_size = group_size
+
+ self.qweight = nn.Parameter(
+ torch.empty(out_features // 4, in_features // 2, dtype=torch.int32, device=device), requires_grad=False
+ )
+ self.bias = (
+ nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=True)
+ if bias
+ else None
+ )
+ self.wscales = nn.Parameter(
+ torch.empty(in_features // self.group_size, out_features, dtype=torch_dtype, device=device),
+ requires_grad=False,
+ )
+ self.wzeros = nn.Parameter(
+ torch.empty(in_features // self.group_size, out_features, dtype=torch_dtype, device=device),
+ requires_grad=False,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass for AWQW4A16Linear.
+
+ Parameters
+ ----------
+ x : torch.Tensor, shape (N, in_features)
+ Input tensor.
+
+ Returns
+ -------
+ torch.Tensor, shape (N, out_features)
+ Output tensor.
+
+ Notes
+ -----
+ N: batch size
+ """
+ output = awq_gemv_w4a16_cuda(
+ in_feats=x,
+ kernel=self.qweight,
+ scaling_factors=self.wscales,
+ zeros=self.wzeros,
+ m=x.shape[0],
+ n=self.out_features,
+ k=self.in_features,
+ group_size=self.group_size,
+ )
+ if self.bias is not None:
+ view_shape = [1] * (output.ndim - 1) + [-1]
+ output.add_(self.bias.view(view_shape))
+ return output
+
+ @classmethod
+ def from_linear(
+ cls,
+ linear: nn.Linear,
+ group_size: int = 64,
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: str = "cpu",
+ **kwargs,
+ ):
+ """
+ Create an uninitialized AWQW4A16Linear from a standard nn.Linear.
+
+ Parameters
+ ----------
+ linear : nn.Linear
+ Source linear layer.
+ group_size : int, optional
+ Quantization group size.
+ torch_dtype : torch.dtype, optional
+ Parameter dtype.
+ device : str, optional
+ Device for parameters.
+
+ Returns
+ -------
+ AWQW4A16Linear
+ """
+ return cls(
+ in_features=linear.in_features,
+ out_features=linear.out_features,
+ bias=linear.bias is not None,
+ group_size=group_size,
+ torch_dtype=torch_dtype,
+ device=device,
+ )
+
+ def __repr__(self):
+ return f"AWQW4A16Linear(in_features={self.in_features}, out_features={self.out_features}, group_size={self.group_size})"
diff --git a/nunchaku/models/normalization.py b/nunchaku/models/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..eede8c73bbd1291137ab931351b3ed8d31ba6b29
--- /dev/null
+++ b/nunchaku/models/normalization.py
@@ -0,0 +1,166 @@
+"""
+Quantized normalization layers for efficient inference.
+"""
+
+from typing import Optional, Tuple
+
+import torch
+from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormZeroSingle
+
+from .linear import AWQW4A16Linear
+
+
+class NunchakuAdaLayerNormZero(AdaLayerNormZero):
+ """
+ Nunchaku quantized AdaLayerNormZero for diffusion models.
+
+ Replaces the linear projection with AWQW4A16Linear for quantized inference.
+
+ Parameters
+ ----------
+ other : AdaLayerNormZero
+ Source AdaLayerNormZero instance to copy weights and structure from.
+ scale_shift : float, optional
+ Value to add to scale parameters. Default is 1.0.
+ Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
+
+ Notes
+ -----
+ - B: batch size
+ - D: hidden dimension
+ """
+
+ def __init__(self, other: AdaLayerNormZero, scale_shift: float = 1.0):
+ super(AdaLayerNormZero, self).__init__()
+ self.scale_shift = scale_shift
+ self.emb = other.emb
+ self.silu = other.silu
+ self.linear = AWQW4A16Linear.from_linear(other.linear)
+ self.norm = other.norm
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timestep: Optional[torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Forward pass for quantized AdaLayerNormZero.
+
+ Parameters
+ ----------
+ x : torch.Tensor, shape (B, D), dtype float32/float16
+ Input tensor.
+ timestep : Optional[torch.Tensor], shape (B,) or (1,), optional
+ Timestep embedding input.
+ class_labels : Optional[torch.LongTensor], shape (B,) or (1,), optional
+ Class label input.
+ hidden_dtype : Optional[torch.dtype], optional
+ Dtype for embedding computation.
+ emb : Optional[torch.Tensor], shape (B, E), optional
+ Precomputed embedding. If None, computed from timestep and class_labels.
+
+ Returns
+ -------
+ norm_x_scaled : torch.Tensor, shape (B, D)
+ Normalized and scaled input.
+ gate_msa : torch.Tensor, shape (B, D)
+ Gate for MSA branch.
+ shift_mlp : torch.Tensor, shape (B, D)
+ Shift for MLP branch.
+ scale_mlp : torch.Tensor, shape (B, D)
+ Scale for MLP branch.
+ gate_mlp : torch.Tensor, shape (B, D)
+ Gate for MLP branch.
+
+ Notes
+ -----
+ - B: batch size
+ - D: hidden dimension
+ """
+ if self.emb is not None:
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+ emb = self.linear(self.silu(emb))
+
+ # The weight layout has changed; use split_mod rather than chunk to separate the embedding.
+ emb = emb.view(emb.shape[0], -1, 6).permute(2, 0, 1)
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb
+
+ norm_x = self.norm(x)
+
+ if self.scale_shift != 0:
+ scale_msa.add_(self.scale_shift)
+ scale_mlp.add_(self.scale_shift)
+
+ norm_x_scaled = norm_x * scale_msa[:, None] + shift_msa[:, None]
+ return norm_x_scaled, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class NunchakuAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
+ """
+ Nunchaku quantized AdaLayerNormZeroSingle.
+
+ Uses AWQW4A16Linear for quantized embedding projection. Suitable for single-branch normalization.
+
+ Parameters
+ ----------
+ other : AdaLayerNormZeroSingle
+ Source AdaLayerNormZeroSingle instance to copy weights and structure from.
+ scale_shift : float, optional
+ Value to add to scale parameters. Default is 1.0.
+ Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
+
+ Notes
+ -----
+ - B: batch size
+ - D: hidden dimension
+ """
+
+ def __init__(self, other: AdaLayerNormZeroSingle, scale_shift: float = 1.0):
+ super(AdaLayerNormZeroSingle, self).__init__()
+ self.scale_shift = scale_shift
+ self.silu = other.silu
+ self.linear = AWQW4A16Linear.from_linear(other.linear)
+ self.norm = other.norm
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Forward pass for quantized AdaLayerNormZeroSingle.
+
+ Parameters
+ ----------
+ x : torch.Tensor, shape (B, D), dtype float32/float16
+ Input tensor.
+ emb : Optional[torch.Tensor], shape (B, E), optional
+ Embedding tensor.
+
+ Returns
+ -------
+ norm_x_scaled : torch.Tensor, shape (B, D)
+ Normalized and scaled input.
+ gate_msa : torch.Tensor, shape (B, D)
+ Gate for MSA branch.
+
+ Notes
+ -----
+ - B: batch size
+ - D: hidden dimension
+ """
+ emb = self.linear(self.silu(emb))
+
+ # The weight layout has changed; use split_mod rather than chunk to separate the embedding.
+ emb = emb.view(emb.shape[0], -1, 3).permute(2, 0, 1)
+ shift_msa, scale_msa, gate_msa = emb
+
+ if self.scale_shift != 0:
+ scale_msa.add_(self.scale_shift)
+
+ norm_x = self.norm(x)
+ norm_x_scaled = norm_x * scale_msa[:, None] + shift_msa[:, None]
+ return norm_x_scaled, gate_msa
diff --git a/nunchaku/models/text_encoders/__init__.py b/nunchaku/models/text_encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b386b5e049d48e9ab5e9d4a9f71ec9c60c363004
--- /dev/null
+++ b/nunchaku/models/text_encoders/__init__.py
@@ -0,0 +1,5 @@
+from .t5_encoder import NunchakuT5EncoderModel
+
+__all__ = [
+ "NunchakuT5EncoderModel",
+]
diff --git a/nunchaku/models/text_encoders/linear.py b/nunchaku/models/text_encoders/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..a878c13c634ec67d3e9fb57ab67ee121df238c20
--- /dev/null
+++ b/nunchaku/models/text_encoders/linear.py
@@ -0,0 +1,238 @@
+# -*- coding: utf-8 -*-
+"""
+This module provides the :class:`W4Linear` quantized linear layer, which implements
+4-bit weight-only quantization for efficient inference.
+"""
+
+import torch
+import torch.nn as nn
+
+from ..._C.ops import gemm_awq, gemv_awq
+from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
+
+__all__ = ["W4Linear"]
+
+
+class W4Linear(nn.Module):
+ """
+ 4-bit quantized linear layer with group-wise quantization.
+
+ Parameters
+ ----------
+ in_features : int
+ Number of input features.
+ out_features : int
+ Number of output features.
+ bias : bool, optional
+ If True, adds a learnable bias (default: False).
+ group_size : int, optional
+ Number of input channels per quantization group (default: 128).
+ If -1, uses the full input dimension as a single group.
+ dtype : torch.dtype, optional
+ Data type for quantization scales and zeros (default: torch.float16).
+ device : str or torch.device, optional
+ Device for weights and buffers (default: "cuda").
+
+ Attributes
+ ----------
+ in_features : int
+ out_features : int
+ group_size : int
+ qweight : torch.Tensor
+ Quantized weight tensor (int16).
+ scales : torch.Tensor
+ Per-group scale tensor.
+ scaled_zeros : torch.Tensor
+ Per-group zero-point tensor (scaled).
+ bias : torch.Tensor or None
+ Optional bias tensor.
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = False,
+ group_size: int = 128,
+ dtype: torch.dtype = torch.float16,
+ device: str | torch.device = "cuda",
+ ):
+ super().__init__()
+ assert dtype in (torch.float16, torch.bfloat16), f"Unsupported dtype: {dtype}"
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.group_size = group_size if group_size != -1 else in_features
+ assert self.in_features % self.group_size == 0
+ assert out_features % (32 // self.weight_bits) == 0
+ self.ceil_num_groups = ceil_num_groups(
+ in_features=self.in_features,
+ group_size=self.group_size,
+ weight_bits=self.weight_bits,
+ )
+
+ assert out_features % (self.interleave) == 0
+ self.register_buffer(
+ "qweight",
+ torch.zeros(
+ (
+ self.out_features // self.interleave,
+ self.in_features // (16 // self.weight_bits) * self.interleave,
+ ),
+ dtype=torch.int16,
+ device=device,
+ ),
+ )
+ self.register_buffer(
+ "scales",
+ torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
+ )
+ self.register_buffer(
+ "scaled_zeros",
+ torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
+ )
+ if bias:
+ self.register_buffer("bias", torch.zeros((out_features), dtype=dtype, device=device))
+ else:
+ self.bias = None
+
+ @property
+ def weight_bits(self) -> int:
+ """
+ Number of bits per quantized weight (always 4).
+ """
+ return 4
+
+ @property
+ def interleave(self) -> int:
+ """
+ Interleave factor for quantized weights (always 4).
+ """
+ return 4
+
+ @torch.no_grad()
+ def forward(self, x):
+ """
+ Forward pass.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor of shape (..., in_features).
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor of shape (..., out_features).
+ """
+ if x.numel() / x.shape[-1] < 8:
+ out = gemv_awq(
+ x,
+ self.qweight,
+ self.scales,
+ self.scaled_zeros,
+ x.numel() // x.shape[-1],
+ self.out_features,
+ self.in_features,
+ self.group_size,
+ )
+ else:
+ if self.group_size != 128:
+ raise NotImplementedError("Kernel currently only supports group_size=128.")
+ out = gemm_awq(x, self.qweight, self.scales, self.scaled_zeros)
+ out = out + self.bias if self.bias is not None else out
+ return out
+
+ @staticmethod
+ def from_linear(
+ linear: nn.Linear,
+ group_size: int,
+ init_only: bool = False,
+ weight: torch.Tensor | None = None,
+ scale: torch.Tensor | None = None,
+ zero: torch.Tensor | None = None,
+ zero_pre_scaled: bool = False,
+ ) -> "W4Linear":
+ """
+ Convert a standard nn.Linear to a quantized W4Linear.
+
+ Parameters
+ ----------
+ linear : nn.Linear
+ The linear layer to convert.
+ group_size : int
+ Quantization group size.
+ init_only : bool, optional
+ If True, only initializes the quantized layer (default: False).
+ weight : torch.Tensor, optional
+ Precomputed quantized weight (default: None).
+ scale : torch.Tensor, optional
+ Precomputed scale tensor (default: None).
+ zero : torch.Tensor, optional
+ Precomputed zero-point tensor (default: None).
+ zero_pre_scaled : bool, optional
+ Whether the zero-point tensor is pre-scaled (default: False).
+
+ Returns
+ -------
+ W4Linear
+ Quantized linear layer.
+ """
+ assert isinstance(linear, nn.Linear)
+ weight = linear.weight.data if weight is None else weight.data
+ dtype, device = weight.dtype, weight.device
+ oc, ic = linear.out_features, linear.in_features
+ _linear = W4Linear(
+ in_features=ic,
+ out_features=oc,
+ bias=linear.bias is not None,
+ group_size=group_size,
+ dtype=dtype,
+ device=device,
+ )
+ if init_only:
+ return _linear
+ if linear.bias is not None:
+ _linear.bias.data.copy_(linear.bias.data)
+ if scale is None:
+ assert zero is None, "scale and zero point tensors should be provided together."
+ group_size = ic if group_size <= 0 else group_size
+ assert group_size <= ic, "group size should be less than or equal to input channel size."
+ assert ic % group_size == 0, "input channel size should be divisible by group size."
+ ng, gs = ic // group_size, group_size
+ weight = weight.to(dtype=torch.float32).view(oc, 1, ng, gs)
+ vmin, vmax = weight.amin(dim=-1, keepdim=True), weight.amax(dim=-1, keepdim=True)
+ scale = (vmax - vmin).div_(15)
+ scale[scale == 0] = 1.0
+ if zero_pre_scaled:
+ zero = vmin.neg_().div_(scale).round_().clamp_(0, 15)
+ weight = weight.div_(scale).add_(zero).round_().clamp_(0, 15).sub_(zero).mul_(scale)
+ else:
+ zero = vmin.neg_().clamp_min(0)
+ weight = weight.add_(zero).div_(scale).round_().clamp_(0, 15).mul_(scale).sub_(zero)
+ weight = weight.to(dtype=dtype).view(oc, ic)
+ scale = scale.to(dtype=dtype)
+ zero = zero.to(dtype=dtype)
+ weight, scale, zero = convert_to_tinychat_w4x16y16_linear_weight(
+ weight=weight,
+ scale=scale,
+ zero=zero,
+ group_size=group_size,
+ zero_pre_scaled=zero_pre_scaled,
+ )
+ _linear.qweight.data.copy_(weight)
+ _linear.scales.data.copy_(scale)
+ _linear.scaled_zeros.data.copy_(zero)
+ return _linear
+
+ def extra_repr(self) -> str:
+ """
+ Returns a string describing the layer configuration.
+ """
+ return "in_features={}, out_features={}, bias={}, weight_bits={}, group_size={}".format(
+ self.in_features,
+ self.out_features,
+ self.bias is not None,
+ self.weight_bits,
+ self.group_size,
+ )
diff --git a/nunchaku/models/text_encoders/t5_encoder.py b/nunchaku/models/text_encoders/t5_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6223437978d92948a5355af3b525b0d13e0e1513
--- /dev/null
+++ b/nunchaku/models/text_encoders/t5_encoder.py
@@ -0,0 +1,116 @@
+"""
+The NunchakuT5EncoderModel class enables loading T5 encoder weights from safetensors files,
+automatically replacing supported linear layers with quantized :class:`~nunchaku.models.text_encoders.linear.W4Linear`
+modules for improved performance and memory efficiency.
+"""
+
+import json
+import logging
+import os
+from pathlib import Path
+
+import torch
+from accelerate import init_empty_weights
+from torch import nn
+from transformers import T5Config, T5EncoderModel
+
+from ...utils import load_state_dict_in_safetensors
+from .linear import W4Linear
+
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+class NunchakuT5EncoderModel(T5EncoderModel):
+ """
+ Nunchaku T5 Encoder Model
+
+ Extends :class:`transformers.T5EncoderModel` to support quantized weights and
+ memory-efficient inference using :class:`~nunchaku.models.text_encoders.linear.W4Linear`.
+
+ This class provides a convenient interface for loading T5 encoder weights from
+ safetensors files, automatically replacing supported linear layers with quantized
+ modules for improved speed and reduced memory usage.
+
+ Example
+ -------
+ .. code-block:: python
+
+ model = NunchakuT5EncoderModel.from_pretrained(
+ "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+ )
+ """
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
+ """
+ Load a :class:`NunchakuT5EncoderModel` from a safetensors file.
+
+ This method loads the model configuration and weights from a safetensors file,
+ initializes the model on the 'meta' device (no memory allocation for weights),
+ and replaces supported linear layers with quantized :class:`~nunchaku.models.text_encoders.linear.W4Linear` modules.
+
+ Parameters
+ ----------
+ pretrained_model_name_or_path : str or os.PathLike
+ Path to the safetensors file containing the model weights and metadata.
+ torch_dtype : torch.dtype, optional
+ Data type for model initialization (default: ``torch.bfloat16``).
+ Set to ``torch.float16`` for Turing GPUs.
+ device : str or torch.device, optional
+ Device to load the model onto (default: ``"cuda"``).
+ If the model is loaded on CPU, it will be automatically moved to GPU.
+
+ Returns
+ -------
+ NunchakuT5EncoderModel
+ The loaded and quantized T5 encoder model.
+
+ Example
+ -------
+ .. code-block:: python
+
+ model = NunchakuT5EncoderModel.from_pretrained(
+ "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+ )
+ """
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+ state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
+
+ # Load the config file from metadata
+ config = json.loads(metadata["config"])
+ config = T5Config(**config)
+
+ # Initialize model on 'meta' device (no memory allocation for weights)
+ with init_empty_weights():
+ t5_encoder = T5EncoderModel(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+
+ t5_encoder.eval()
+
+ # Load the model weights from the safetensors file and quantize supported linear layers
+ named_modules = {}
+ for name, module in t5_encoder.named_modules():
+ assert isinstance(name, str)
+ if isinstance(module, nn.Linear):
+ if f"{name}.qweight" in state_dict:
+ logger.debug(f"Switching {name} to W4Linear")
+ qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
+ # modeling_t5.py: T5DenseGatedActDense needs dtype of weight
+ qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
+
+ parent_name, child_name = name.rsplit(".", 1)
+ setattr(named_modules[parent_name], child_name, qmodule)
+ else:
+ named_modules[name] = module
+
+ device = kwargs.get("device", "cuda")
+ if isinstance(device, str):
+ device = torch.device(device)
+ t5_encoder.to_empty(device=device)
+ t5_encoder.load_state_dict(state_dict, strict=True)
+
+ return t5_encoder
diff --git a/nunchaku/models/text_encoders/tinychat_utils.py b/nunchaku/models/text_encoders/tinychat_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..03cea464e7371cd3e966aefef3a9842e9ba2caca
--- /dev/null
+++ b/nunchaku/models/text_encoders/tinychat_utils.py
@@ -0,0 +1,188 @@
+# -*- coding: utf-8 -*-
+"""
+This module provides utility functions for quantized linear layers in the TinyChat backend.
+"""
+
+import torch
+
+__all__ = ["ceil_num_groups", "convert_to_tinychat_w4x16y16_linear_weight"]
+
+
+def ceil_divide(x: int, divisor: int) -> int:
+ """
+ Compute the ceiling of integer division.
+
+ Parameters
+ ----------
+ x : int
+ Dividend.
+ divisor : int
+ Divisor.
+
+ Returns
+ -------
+ int
+ The smallest integer greater than or equal to ``x / divisor``.
+ """
+ return (x + divisor - 1) // divisor
+
+
+def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int:
+ """
+ Calculate the padded number of quantization groups for TinyChat quantization.
+
+ This ensures the number of groups is compatible with TinyChat's packing and kernel requirements.
+
+ Parameters
+ ----------
+ in_features : int
+ Input channel size (number of input features).
+ group_size : int
+ Quantization group size.
+ weight_bits : int, optional
+ Number of bits per quantized weight (default: 4).
+
+ Returns
+ -------
+ int
+ The padded number of quantization groups.
+
+ Raises
+ ------
+ AssertionError
+ If ``in_features`` is not divisible by ``group_size``, or if ``weight_bits`` is not 4, 2, or 1.
+ NotImplementedError
+ If ``group_size`` is not one of the supported values (>=128, 64, 32).
+ """
+ assert in_features % group_size == 0, "input channel size should be divisible by group size."
+ num_groups = in_features // group_size
+ assert weight_bits in (4, 2, 1), "weight bits should be 4, 2, or 1."
+ pack_size = 32 // weight_bits # one INT32 contains `pack_size` elements of weights
+ num_packs = ceil_divide(num_groups, pack_size)
+ if group_size >= 128:
+ num_packs_factor = 1
+ elif group_size == 64:
+ num_packs_factor = 2
+ elif group_size == 32:
+ num_packs_factor = 4
+ else:
+ raise NotImplementedError("Unsupported group size for TinyChat quantization.")
+ # make sure num_packs is a multiple of num_packs_factor
+ num_packs = ceil_divide(num_packs, num_packs_factor) * num_packs_factor
+ num_groups = num_packs * pack_size
+ return num_groups
+
+
+def pack_w4(weight: torch.Tensor) -> torch.Tensor:
+ """
+ Pack quantized 4-bit weights into TinyChat's int16 format.
+
+ This function rearranges and packs 4-bit quantized weights (stored as int32) into
+ the format expected by TinyChat CUDA kernels.
+
+ Parameters
+ ----------
+ weight : torch.Tensor
+ Quantized weight tensor of shape (out_features, in_features), dtype int32.
+ The input channel dimension must be divisible by 32.
+
+ Returns
+ -------
+ torch.Tensor
+ Packed weight tensor of dtype int16.
+
+ Raises
+ ------
+ AssertionError
+ If input tensor is not int32 or input channel size is not divisible by 32.
+ """
+ assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
+ oc, ic = weight.shape
+ assert ic % 32 == 0, "input channel size should be divisible by 32."
+ # [0, 1, ..., 31] -> [0, 8, 16, 24, 1, 9, 17, 25, ..., 7, 15, 23, 31]
+ weight = weight.view(-1, 4, 8)
+ weight = weight[:, 0] | (weight[:, 1] << 4) | (weight[:, 2] << 8) | (weight[:, 3] << 12)
+ weight = weight.view(oc // 4, 4, ic // 64, 16).permute(0, 2, 1, 3).reshape(oc // 4, ic)
+ return weight.to(torch.int16)
+
+
+def convert_to_tinychat_w4x16y16_linear_weight(
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ zero: torch.Tensor,
+ group_size: int = -1,
+ zero_pre_scaled: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Convert a floating-point weight tensor to TinyChat W4-X16-Y16 quantized linear format.
+
+ This function quantizes the input weights to 4 bits per value, applies group-wise
+ scaling and zero-point, and packs the result into the format expected by TinyChat
+ quantized linear layers.
+
+ Parameters
+ ----------
+ weight : torch.Tensor
+ Floating-point weight tensor of shape (out_features, in_features).
+ Must be of dtype ``torch.float16`` or ``torch.bfloat16``.
+ scale : torch.Tensor
+ Per-group scale tensor (can be broadcastable).
+ zero : torch.Tensor
+ Per-group zero-point tensor (can be broadcastable).
+ group_size : int, optional
+ Quantization group size. If set to -1 (default), uses the full input dimension as a single group.
+ zero_pre_scaled : bool, optional
+ If True, the zero tensor is already scaled by the scale tensor (default: False).
+
+ Returns
+ -------
+ tuple of torch.Tensor
+ - packed_weight : torch.Tensor
+ Packed quantized weight tensor (int16).
+ - packed_scale : torch.Tensor
+ Packed scale tensor (shape: [num_groups, out_features], dtype matches input).
+ - packed_zero : torch.Tensor
+ Packed zero-point tensor (shape: [num_groups, out_features], dtype matches input).
+
+ Raises
+ ------
+ AssertionError
+ If input types or shapes are invalid, or quantized values are out of range.
+
+ Example
+ -------
+ .. code-block:: python
+
+ qweight, qscale, qzero = convert_to_tinychat_w4x16y16_linear_weight(
+ weight, scale, zero, group_size=128
+ )
+ """
+ dtype, device = weight.dtype, weight.device
+ assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16."
+ assert scale is not None, "scale tensor is required for quantization."
+ assert zero is not None, "zero point tensor is required for quantization."
+ weight = weight.to(dtype=torch.float32)
+ scale = scale.to(dtype=torch.float32, device=device)
+ zero = zero.to(dtype=torch.float32, device=device)
+ if zero_pre_scaled:
+ zero = zero * scale
+ oc, ic = weight.shape
+ group_size = ic if group_size <= 0 else group_size
+ assert group_size <= ic, "group size should be less than or equal to input channel size."
+ assert ic % group_size == 0, "input channel size should be divisible by group size."
+ ng = ic // group_size
+ if scale.numel() == 1:
+ scale = scale.view(1, 1).expand(oc, ng)
+ scale = scale.reshape(oc, ng).contiguous().view(oc, ng, 1)
+ if zero.numel() == 1:
+ zero = zero.view(1, 1).expand(oc, ng)
+ zero = zero.reshape(oc, ng).contiguous().view(oc, ng, 1)
+ weight = weight.view(oc, ng, -1).add_(zero).div_(scale).round_().view(oc, ic)
+ assert weight.min() >= 0 and weight.max() <= 15, "quantized weight should be in [0, 15]."
+ _weight = pack_w4(weight.to(torch.int32))
+ _ng = ceil_num_groups(ic, group_size, weight_bits=4)
+ _scale = torch.zeros((_ng, oc), dtype=dtype, device=device)
+ _zero = torch.zeros((_ng, oc), dtype=dtype, device=device)
+ _scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype)
+ _zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_()
+ return _weight, _scale, _zero
diff --git a/nunchaku/models/transformers/__init__.py b/nunchaku/models/transformers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f150e7636481f00859568449c32037dd1b456d34
--- /dev/null
+++ b/nunchaku/models/transformers/__init__.py
@@ -0,0 +1,5 @@
+from .transformer_flux import NunchakuFluxTransformer2dModel
+
+__all__ = [
+ "NunchakuFluxTransformer2dModel",
+]
diff --git a/nunchaku/models/transformers/transformer_flux.py b/nunchaku/models/transformers/transformer_flux.py
new file mode 100644
index 0000000000000000000000000000000000000000..149c18d49b0dc2bef01265f22a83d5afd347e83e
--- /dev/null
+++ b/nunchaku/models/transformers/transformer_flux.py
@@ -0,0 +1,991 @@
+"""
+Implements the :class:`NunchakuFluxTransformer2dModel`, a quantized transformer for Diffusers with efficient inference and LoRA support.
+"""
+
+import json
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import diffusers
+import torch
+from diffusers import FluxTransformer2DModel
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from huggingface_hub import utils
+from packaging.version import Version
+from safetensors.torch import load_file
+from torch import nn
+
+from ..._C import QuantizedFluxModel
+from ..._C import utils as cutils
+from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
+from ...lora.flux.utils import is_nunchaku_format
+from ...utils import check_hardware_compatibility, get_precision, load_state_dict_in_safetensors, pad_tensor
+from .utils import NunchakuModelLoaderMixin
+
+SVD_RANK = 32
+
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+class NunchakuFluxTransformerBlocks(nn.Module):
+ """
+ Wrapper for quantized Nunchaku FLUX transformer blocks.
+
+ This class manages the forward pass, rotary embedding packing, and optional
+ residual callbacks for ID embeddings.
+
+ Parameters
+ ----------
+ m : QuantizedFluxModel
+ The quantized transformer model.
+ device : str or torch.device
+ Device to run the model on.
+ """
+
+ def __init__(self, m: QuantizedFluxModel, device: str | torch.device):
+ super(NunchakuFluxTransformerBlocks, self).__init__()
+ self.m = m
+ self.dtype = torch.bfloat16 if m.isBF16() else torch.float16
+ self.device = device
+
+ @staticmethod
+ def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
+ """
+ Packs rotary embeddings for efficient computation.
+
+ Parameters
+ ----------
+ rotemb : torch.Tensor
+ Rotary embedding tensor of shape (B, M, D//2, 1, 2), dtype float32.
+
+ Returns
+ -------
+ torch.Tensor
+ Packed rotary embedding tensor of shape (B, M, D).
+ """
+ assert rotemb.dtype == torch.float32
+ B = rotemb.shape[0]
+ M = rotemb.shape[1]
+ D = rotemb.shape[2] * 2
+ assert rotemb.shape == (B, M, D // 2, 1, 2)
+ assert M % 16 == 0
+ assert D % 8 == 0
+ rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8)
+ rotemb = rotemb.permute(0, 1, 3, 2, 4)
+ # 16*8 pack, FP32 accumulator (C) format
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
+ ##########################################|--M--|--D--|
+ ##########################################|-3--4--5--6|
+ ########################################## : : : :
+ rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
+ rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
+ rotemb = rotemb.contiguous()
+ rotemb = rotemb.view(B, M, D)
+ return rotemb
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ image_rotary_emb: torch.Tensor,
+ id_embeddings=None,
+ id_weight=None,
+ joint_attention_kwargs=None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ skip_first_layer=False,
+ ):
+ """
+ Forward pass for the quantized transformer blocks.
+ It will call the forward method of ``m`` on the C backend.
+
+ Parameters
+ ----------
+ hidden_states : torch.Tensor
+ Input hidden states for image tokens.
+ temb : torch.Tensor
+ Temporal embedding tensor.
+ encoder_hidden_states : torch.Tensor
+ Input hidden states for text tokens.
+ image_rotary_emb : torch.Tensor
+ Rotary embedding tensor for all tokens.
+ id_embeddings : torch.Tensor, optional
+ Optional ID embeddings for residual callback.
+ id_weight : float, optional
+ Weight for ID embedding residual.
+ joint_attention_kwargs : dict, optional
+ Additional kwargs for joint attention.
+ controlnet_block_samples : list[torch.Tensor], optional
+ ControlNet block samples.
+ controlnet_single_block_samples : list[torch.Tensor], optional
+ ControlNet single block samples.
+ skip_first_layer : bool, optional
+ Whether to skip the first layer.
+
+ Returns
+ -------
+ tuple[torch.Tensor, torch.Tensor]
+ (encoder_hidden_states, hidden_states) after transformer blocks.
+ """
+ # batch_size = hidden_states.shape[0]
+ txt_tokens = encoder_hidden_states.shape[1]
+ img_tokens = hidden_states.shape[1]
+
+ self.id_embeddings = id_embeddings
+ self.id_weight = id_weight
+ self.pulid_ca_idx = 0
+ if self.id_embeddings is not None:
+ self.set_pulid_residual_callback()
+
+ original_dtype = hidden_states.dtype
+ original_device = hidden_states.device
+
+ hidden_states = hidden_states.to(self.dtype).to(self.device)
+ encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
+ temb = temb.to(self.dtype).to(self.device)
+ image_rotary_emb = image_rotary_emb.to(self.device)
+
+ if controlnet_block_samples is not None:
+ if len(controlnet_block_samples) > 0:
+ controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
+ else:
+ controlnet_block_samples = None
+
+ if controlnet_single_block_samples is not None:
+ if len(controlnet_single_block_samples) > 0:
+ controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
+ else:
+ controlnet_single_block_samples = None
+
+ assert image_rotary_emb.ndim == 6
+ assert image_rotary_emb.shape[0] == 1
+ assert image_rotary_emb.shape[1] == 1
+ assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
+ # [1, tokens, head_dim / 2, 1, 2] (sincos)
+ image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
+ rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
+ rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
+ rotary_emb_single = image_rotary_emb # .to(self.dtype)
+
+ rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
+ rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
+ rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
+ hidden_states = self.m.forward(
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ rotary_emb_img,
+ rotary_emb_txt,
+ rotary_emb_single,
+ controlnet_block_samples,
+ controlnet_single_block_samples,
+ skip_first_layer,
+ )
+
+ if self.id_embeddings is not None:
+ self.reset_pulid_residual_callback()
+
+ hidden_states = hidden_states.to(original_dtype).to(original_device)
+
+ encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
+ hidden_states = hidden_states[:, txt_tokens:, ...]
+
+ return encoder_hidden_states, hidden_states
+
+ def forward_layer_at(
+ self,
+ idx: int,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: torch.Tensor,
+ joint_attention_kwargs=None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ ):
+ """
+ Forward pass for a specific transformer layer in ``m``.
+
+ Parameters
+ ----------
+ idx : int
+ Index of the transformer layer.
+ hidden_states : torch.Tensor
+ Input hidden states for image tokens.
+ encoder_hidden_states : torch.Tensor
+ Input hidden states for text tokens.
+ temb : torch.Tensor
+ Temporal embedding tensor.
+ image_rotary_emb : torch.Tensor
+ Rotary embedding tensor for all tokens.
+ joint_attention_kwargs : dict, optional
+ Additional kwargs for joint attention.
+ controlnet_block_samples : list[torch.Tensor], optional
+ ControlNet block samples.
+ controlnet_single_block_samples : list[torch.Tensor], optional
+ ControlNet single block samples.
+
+ Returns
+ -------
+ tuple[torch.Tensor, torch.Tensor]
+ (encoder_hidden_states, hidden_states) after the specified layer.
+ """
+ # batch_size = hidden_states.shape[0]
+ txt_tokens = encoder_hidden_states.shape[1]
+ img_tokens = hidden_states.shape[1]
+
+ original_dtype = hidden_states.dtype
+ original_device = hidden_states.device
+
+ hidden_states = hidden_states.to(self.dtype).to(self.device)
+ encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
+ temb = temb.to(self.dtype).to(self.device)
+ image_rotary_emb = image_rotary_emb.to(self.device)
+
+ if controlnet_block_samples is not None:
+ controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
+ if controlnet_single_block_samples is not None:
+ controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
+
+ assert image_rotary_emb.ndim == 6
+ assert image_rotary_emb.shape[0] == 1
+ assert image_rotary_emb.shape[1] == 1
+ assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
+ # [1, tokens, head_dim / 2, 1, 2] (sincos)
+ image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
+ rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
+ rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
+
+ rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
+ rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
+
+ hidden_states, encoder_hidden_states = self.m.forward_layer(
+ idx,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ rotary_emb_img,
+ rotary_emb_txt,
+ controlnet_block_samples,
+ controlnet_single_block_samples,
+ )
+
+ hidden_states = hidden_states.to(original_dtype).to(original_device)
+ encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device)
+
+ return encoder_hidden_states, hidden_states
+
+ def set_pulid_residual_callback(self):
+ """
+ Sets the residual callback for PulID (personalized ID) embeddings.
+ """
+ id_embeddings = self.id_embeddings
+ pulid_ca = self.pulid_ca
+ pulid_ca_idx = [self.pulid_ca_idx]
+ id_weight = self.id_weight
+
+ def callback(hidden_states):
+ ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states)
+ pulid_ca_idx[0] += 1
+ return ip
+
+ self.callback_holder = callback
+ self.m.set_residual_callback(callback)
+
+ def reset_pulid_residual_callback(self):
+ """
+ Resets the PulID residual callback to None.
+ """
+ self.callback_holder = None
+ self.m.set_residual_callback(None)
+
+ def __del__(self):
+ """
+ Destructor to reset the quantized model.
+ """
+ self.m.reset()
+
+ def norm1(
+ self,
+ hidden_states: torch.Tensor,
+ emb: torch.Tensor,
+ idx: int = 0,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Runs the norm_one_forward for a specific layer in ``m``.
+
+ Parameters
+ ----------
+ hidden_states : torch.Tensor
+ Input hidden states.
+ emb : torch.Tensor
+ Embedding tensor.
+ idx : int, optional
+ Layer index (default: 0).
+
+ Returns
+ -------
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
+ Output tensors from norm_one_forward.
+ """
+ return self.m.norm_one_forward(idx, hidden_states, emb)
+
+
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+ """
+ Rotary positional embedding function.
+
+ Parameters
+ ----------
+ pos : torch.Tensor
+ Position tensor of shape (..., n).
+ dim : int
+ Embedding dimension (must be even).
+ theta : int
+ Rotary base.
+
+ Returns
+ -------
+ torch.Tensor
+ Rotary embedding tensor.
+ """
+ assert dim % 2 == 0, "The dimension must be even."
+
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+
+ batch_size, seq_length = pos.shape
+ out = torch.einsum("...n,d->...nd", pos, omega)
+
+ USE_SINCOS = True
+ if USE_SINCOS:
+ cos_out = torch.cos(out)
+ sin_out = torch.sin(out)
+ stacked_out = torch.stack([sin_out, cos_out], dim=-1)
+ out = stacked_out.view(batch_size, -1, dim // 2, 1, 2)
+ else:
+ out = out.view(batch_size, -1, dim // 2, 1, 1)
+
+ return out.float()
+
+
+class EmbedND(nn.Module):
+ """
+ Multi-dimensional rotary embedding module.
+
+ Parameters
+ ----------
+ dim : int
+ Embedding dimension.
+ theta : int
+ Rotary base.
+ axes_dim : list[int]
+ List of axis dimensions for each spatial axis.
+ """
+
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+ super(EmbedND, self).__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ """
+ Computes rotary embeddings for multi-dimensional positions.
+
+ Parameters
+ ----------
+ ids : torch.Tensor
+ Position indices tensor of shape (..., n_axes).
+
+ Returns
+ -------
+ torch.Tensor
+ Rotary embedding tensor.
+ """
+ if Version(diffusers.__version__) >= Version("0.31.0"):
+ ids = ids[None, ...]
+ n_axes = ids.shape[-1]
+ emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
+ return emb.unsqueeze(1)
+
+
+def load_quantized_module(
+ path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
+ device: str | torch.device = "cuda",
+ use_fp4: bool = False,
+ offload: bool = False,
+ bf16: bool = True,
+) -> QuantizedFluxModel:
+ """
+ Loads a quantized Nunchaku FLUX model from a state dict or file.
+
+ Parameters
+ ----------
+ path_or_state_dict : str, os.PathLike, or dict
+ Path to the quantized model file or a state dict.
+ device : str or torch.device, optional
+ Device to load the model on (default: "cuda").
+ use_fp4 : bool, optional
+ Whether to use FP4 quantization (default: False).
+ offload : bool, optional
+ Whether to offload weights to CPU (default: False).
+ bf16 : bool, optional
+ Whether to use bfloat16 (default: True).
+
+ Returns
+ -------
+ QuantizedFluxModel
+ Loaded quantized model.
+ """
+ device = torch.device(device)
+ assert device.type == "cuda"
+ m = QuantizedFluxModel()
+ cutils.disable_memory_auto_release()
+ m.init(use_fp4, offload, bf16, 0 if device.index is None else device.index)
+ if isinstance(path_or_state_dict, dict):
+ m.loadDict(path_or_state_dict, True)
+ else:
+ m.load(str(path_or_state_dict))
+ return m
+
+
+class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
+ """
+ Nunchaku FLUX Transformer 2D Model.
+
+ This class implements a quantized transformer model compatible with the Diffusers
+ library, supporting LoRA, rotary embeddings, and efficient inference.
+
+ Parameters
+ ----------
+ patch_size : int, optional
+ Patch size for input images (default: 1).
+ in_channels : int, optional
+ Number of input channels (default: 64).
+ out_channels : int or None, optional
+ Number of output channels (default: None).
+ num_layers : int, optional
+ Number of transformer layers (default: 19).
+ num_single_layers : int, optional
+ Number of single transformer layers (default: 38).
+ attention_head_dim : int, optional
+ Dimension of each attention head (default: 128).
+ num_attention_heads : int, optional
+ Number of attention heads (default: 24).
+ joint_attention_dim : int, optional
+ Joint attention dimension (default: 4096).
+ pooled_projection_dim : int, optional
+ Pooled projection dimension (default: 768).
+ guidance_embeds : bool, optional
+ Whether to use guidance embeddings (default: False).
+ axes_dims_rope : tuple[int], optional
+ Axes dimensions for rotary embeddings (default: (16, 56, 56)).
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ out_channels: int | None = None,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ pooled_projection_dim: int = 768,
+ guidance_embeds: bool = False,
+ axes_dims_rope: tuple[int] = (16, 56, 56),
+ ):
+ super(NunchakuFluxTransformer2dModel, self).__init__(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ joint_attention_dim=joint_attention_dim,
+ pooled_projection_dim=pooled_projection_dim,
+ guidance_embeds=guidance_embeds,
+ axes_dims_rope=axes_dims_rope,
+ )
+ # these state_dicts are used for supporting lora
+ self._unquantized_part_sd: dict[str, torch.Tensor] = {}
+ self._unquantized_part_loras: dict[str, torch.Tensor] = {}
+ self._quantized_part_sd: dict[str, torch.Tensor] = {}
+ self._quantized_part_vectors: dict[str, torch.Tensor] = {}
+ self._original_in_channels = in_channels
+
+ # ComfyUI LoRA related
+ self.comfy_lora_meta_list = []
+ self.comfy_lora_sd_list = []
+
+ @classmethod
+ @utils.validate_hf_hub_args
+ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
+ """
+ Loads a Nunchaku FLUX transformer model from pretrained weights.
+
+ Parameters
+ ----------
+ pretrained_model_name_or_path : str or os.PathLike
+ Path to the model directory or HuggingFace repo.
+ **kwargs
+ Additional keyword arguments for device, offload, torch_dtype, precision, etc.
+
+ Returns
+ -------
+ NunchakuFluxTransformer2dModel or (NunchakuFluxTransformer2dModel, dict)
+ The loaded model, and optionally metadata if `return_metadata=True`.
+ """
+ device = kwargs.get("device", "cuda")
+ if isinstance(device, str):
+ device = torch.device(device)
+ offload = kwargs.get("offload", False)
+ torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
+ precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
+ metadata = None
+
+ if isinstance(pretrained_model_name_or_path, str):
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+ if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
+ (".safetensors", ".sft")
+ ):
+ transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
+ quantized_part_sd = {}
+ unquantized_part_sd = {}
+ for k, v in model_state_dict.items():
+ if k.startswith(("transformer_blocks.", "single_transformer_blocks.")):
+ quantized_part_sd[k] = v
+ else:
+ unquantized_part_sd[k] = v
+ precision = get_precision(device=device)
+ quantization_config = json.loads(metadata["quantization_config"])
+ check_hardware_compatibility(quantization_config, device)
+ else:
+ transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
+ pretrained_model_name_or_path, **kwargs
+ )
+
+ # get the default LoRA branch and all the vectors
+ quantized_part_sd = load_file(transformer_block_path)
+ unquantized_part_sd = load_file(unquantized_part_path)
+ new_quantized_part_sd = {}
+ for k, v in quantized_part_sd.items():
+ if v.ndim == 1:
+ new_quantized_part_sd[k] = v
+ elif "qweight" in k:
+ # only the shape information of this tensor is needed
+ new_quantized_part_sd[k] = v.to("meta")
+
+ # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors
+ for t in ["lora_up", "lora_down"]:
+ new_k = k.replace(".qweight", f".{t}")
+ if new_k not in quantized_part_sd:
+ oc, ic = v.shape
+ ic = ic * 2 # v is packed into INT8, so we need to double the size
+ new_quantized_part_sd[k.replace(".qweight", f".{t}")] = torch.zeros(
+ (0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16
+ )
+
+ elif "lora" in k:
+ new_quantized_part_sd[k] = v
+ transformer._quantized_part_sd = new_quantized_part_sd
+ m = load_quantized_module(
+ quantized_part_sd,
+ device=device,
+ use_fp4=precision == "fp4",
+ offload=offload,
+ bf16=torch_dtype == torch.bfloat16,
+ )
+ transformer.inject_quantized_module(m, device)
+ transformer.to_empty(device=device)
+
+ transformer.load_state_dict(unquantized_part_sd, strict=False)
+ transformer._unquantized_part_sd = unquantized_part_sd
+
+ if kwargs.get("return_metadata", False):
+ return transformer, metadata
+ else:
+ return transformer
+
+ def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
+ """
+ Injects a quantized module into the model and sets up transformer blocks.
+
+ Parameters
+ ----------
+ m : QuantizedFluxModel
+ The quantized transformer model.
+ device : str or torch.device, optional
+ Device to run the model on (default: "cuda").
+
+ Returns
+ -------
+ self : NunchakuFluxTransformer2dModel
+ The model with injected quantized module.
+ """
+ print("Injecting quantized module")
+ self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
+
+ ### Compatible with the original forward method
+ self.transformer_blocks = nn.ModuleList([NunchakuFluxTransformerBlocks(m, device)])
+ self.single_transformer_blocks = nn.ModuleList([])
+
+ return self
+
+ def set_attention_impl(self, impl: str):
+ """
+ Set the attention implementation for the quantized transformer block.
+
+ Parameters
+ ----------
+ impl : str
+ Attention implementation to use. Supported values:
+
+ - ``"flashattn2"`` (default): Standard FlashAttention-2.
+ - ``"nunchaku-fp16"``: Uses FP16 attention accumulation, up to 1.2× faster than FlashAttention-2 on NVIDIA 30-, 40-, and 50-series GPUs.
+ """
+ block = self.transformer_blocks[0]
+ assert isinstance(block, NunchakuFluxTransformerBlocks)
+ block.m.setAttentionImpl(impl)
+
+ ### LoRA Related Functions
+
+ def _expand_module(self, module_name: str, new_shape: tuple[int, int]):
+ """
+ Expands a linear module to a new shape for LoRA compatibility.
+ Mostly for FLUX.1-tools LoRA which changes the input channels.
+
+ Parameters
+ ----------
+ module_name : str
+ Name of the module to expand.
+ new_shape : tuple[int, int]
+ New shape (out_features, in_features) for the module.
+ """
+ module = self.get_submodule(module_name)
+ assert isinstance(module, nn.Linear)
+ weight_shape = module.weight.shape
+ logger.info("Expand the shape of module {} from {} to {}".format(module_name, tuple(weight_shape), new_shape))
+ assert new_shape[0] >= weight_shape[0] and new_shape[1] >= weight_shape[1]
+ new_module = nn.Linear(
+ new_shape[1],
+ new_shape[0],
+ bias=module.bias is not None,
+ device=module.weight.device,
+ dtype=module.weight.dtype,
+ )
+ new_module.weight.data.zero_()
+ new_module.weight.data[: weight_shape[0], : weight_shape[1]] = module.weight.data
+ self._unquantized_part_sd[f"{module_name}.weight"] = new_module.weight.data.clone()
+ if new_module.bias is not None:
+ new_module.bias.data.zero_()
+ new_module.bias.data[: weight_shape[0]] = module.bias.data
+ self._unquantized_part_sd[f"{module_name}.bias"] = new_module.bias.data.clone()
+ parent_name = ".".join(module_name.split(".")[:-1])
+ parent_module = self.get_submodule(parent_name)
+ parent_module.add_module(module_name.split(".")[-1], new_module)
+
+ if module_name == "x_embedder":
+ new_value = int(new_module.weight.data.shape[1])
+ old_value = getattr(self.config, "in_channels")
+ if new_value != old_value:
+ logger.info(f"Update in_channels from {old_value} to {new_value}")
+ setattr(self.config, "in_channels", new_value)
+
+ def _update_unquantized_part_lora_params(self, strength: float = 1):
+ """
+ Updates the unquantized part of the model with LoRA parameters.
+
+ Parameters
+ ----------
+ strength : float, optional
+ LoRA scaling strength (default: 1).
+ """
+ # check if we need to expand the linear layers
+ device = next(self.parameters()).device
+ for k, v in self._unquantized_part_loras.items():
+ if "lora_A" in k:
+ lora_a = v
+ lora_b = self._unquantized_part_loras[k.replace(".lora_A.", ".lora_B.")]
+ diff_shape = (lora_b.shape[0], lora_a.shape[1])
+ weight_shape = self._unquantized_part_sd[k.replace(".lora_A.", ".")].shape
+ if diff_shape[0] > weight_shape[0] or diff_shape[1] > weight_shape[1]:
+ module_name = ".".join(k.split(".")[:-2])
+ self._expand_module(module_name, diff_shape)
+ elif v.ndim == 1:
+ diff_shape = v.shape
+ weight_shape = self._unquantized_part_sd[k].shape
+ if diff_shape[0] > weight_shape[0]:
+ assert diff_shape[0] >= weight_shape[0]
+ module_name = ".".join(k.split(".")[:-1])
+ module = self.get_submodule(module_name)
+ weight_shape = module.weight.shape
+ diff_shape = (diff_shape[0], weight_shape[1])
+ self._expand_module(module_name, diff_shape)
+ new_state_dict = {}
+ for k in self._unquantized_part_sd.keys():
+ v = self._unquantized_part_sd[k]
+ v = v.to(device)
+ self._unquantized_part_sd[k] = v
+
+ if v.ndim == 1 and k in self._unquantized_part_loras:
+ diff = strength * self._unquantized_part_loras[k]
+ if diff.shape[0] < v.shape[0]:
+ diff = torch.cat(
+ [diff, torch.zeros(v.shape[0] - diff.shape[0], device=device, dtype=v.dtype)], dim=0
+ )
+ new_state_dict[k] = v + diff
+ elif v.ndim == 2 and k.replace(".weight", ".lora_B.weight") in self._unquantized_part_loras:
+ lora_a = self._unquantized_part_loras[k.replace(".weight", ".lora_A.weight")]
+ lora_b = self._unquantized_part_loras[k.replace(".weight", ".lora_B.weight")]
+
+ if lora_a.shape[1] < v.shape[1]:
+ lora_a = torch.cat(
+ [
+ lora_a,
+ torch.zeros(lora_a.shape[0], v.shape[1] - lora_a.shape[1], device=device, dtype=v.dtype),
+ ],
+ dim=1,
+ )
+ if lora_b.shape[0] < v.shape[0]:
+ lora_b = torch.cat(
+ [
+ lora_b,
+ torch.zeros(v.shape[0] - lora_b.shape[0], lora_b.shape[1], device=device, dtype=v.dtype),
+ ],
+ dim=0,
+ )
+
+ diff = strength * (lora_b @ lora_a)
+ new_state_dict[k] = v + diff
+ else:
+ new_state_dict[k] = v
+ self.load_state_dict(new_state_dict, strict=True)
+
+ def update_lora_params(self, path_or_state_dict: str | dict[str, torch.Tensor]):
+ """
+ Update the model with new LoRA parameters.
+
+ Parameters
+ ----------
+ path_or_state_dict : str or dict
+ Path to a LoRA weights file or a state dict. The path supports:
+
+ - Local file path, e.g., ``"/path/to/your/lora.safetensors"``
+ - HuggingFace repo with file, e.g., ``"user/repo/lora.safetensors"``
+ (automatically downloaded and cached)
+ """
+ if isinstance(path_or_state_dict, dict):
+ state_dict = {
+ k: v for k, v in path_or_state_dict.items()
+ } # copy a new one to avoid modifying the original one
+ else:
+ state_dict = load_state_dict_in_safetensors(path_or_state_dict)
+
+ if not is_nunchaku_format(state_dict):
+ state_dict = to_nunchaku(state_dict, base_sd=self._quantized_part_sd)
+
+ unquantized_part_loras = {}
+ for k, v in list(state_dict.items()):
+ device = next(self.parameters()).device
+ if "transformer_blocks" not in k:
+ unquantized_part_loras[k] = state_dict.pop(k).to(device)
+
+ if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
+ self._unquantized_part_loras = unquantized_part_loras
+
+ self._unquantized_part_sd = {k: v for k, v in self._unquantized_part_sd.items() if "pulid_ca" not in k}
+ self._update_unquantized_part_lora_params(1)
+
+ quantized_part_vectors = {}
+ for k, v in list(state_dict.items()):
+ if v.ndim == 1:
+ quantized_part_vectors[k] = state_dict.pop(k)
+ if len(self._quantized_part_vectors) > 0 or len(quantized_part_vectors) > 0:
+ self._quantized_part_vectors = quantized_part_vectors
+ updated_vectors = fuse_vectors(quantized_part_vectors, self._quantized_part_sd, 1)
+ state_dict.update(updated_vectors)
+
+ # Get the vectors from the quantized part
+
+ block = self.transformer_blocks[0]
+ assert isinstance(block, NunchakuFluxTransformerBlocks)
+
+ block.m.loadDict(state_dict, True)
+
+ def set_lora_strength(self, strength: float = 1):
+ """
+ Sets the LoRA scaling strength for the model.
+
+ Note: This function can only be used with a single LoRA. For multiple LoRAs,
+ please fuse the LoRA scale into the weights.
+
+ Parameters
+ ----------
+ strength : float, optional
+ LoRA scaling strength (default: 1).
+
+ Note: This function will change the strength of all the LoRAs. So only use it when you only have a single LoRA.
+ """
+ block = self.transformer_blocks[0]
+ assert isinstance(block, NunchakuFluxTransformerBlocks)
+ block.m.setLoraScale(SVD_RANK, strength)
+ if len(self._unquantized_part_loras) > 0:
+ self._update_unquantized_part_lora_params(strength)
+ if len(self._quantized_part_vectors) > 0:
+ vector_dict = fuse_vectors(self._quantized_part_vectors, self._quantized_part_sd, strength)
+ block.m.loadDict(vector_dict, True)
+
+ def reset_x_embedder(self):
+ """
+ Resets the x_embedder module if the input channel count has changed.
+ This is used for removing the effect of FLUX.1-tools LoRA which changes the input channels.
+ """
+ # if change the model in channels, we need to update the x_embedder
+ if self._original_in_channels != self.config.in_channels:
+ assert self._original_in_channels < self.config.in_channels
+ old_module = self.x_embedder
+ new_module = nn.Linear(
+ in_features=self._original_in_channels,
+ out_features=old_module.out_features,
+ bias=old_module.bias is not None,
+ device=old_module.weight.device,
+ dtype=old_module.weight.dtype,
+ )
+ new_module.weight.data.copy_(old_module.weight.data[: new_module.out_features, : new_module.in_features])
+ self._unquantized_part_sd["x_embedder.weight"] = new_module.weight.data.clone()
+ if new_module.bias is not None:
+ new_module.bias.data.zero_()
+ new_module.bias.data.copy_(old_module.bias.data[: new_module.out_features])
+ self._unquantized_part_sd["x_embedder.bias"] = new_module.bias.data.clone()
+ self.x_embedder = new_module
+ setattr(self.config, "in_channels", self._original_in_channels)
+
+ def reset_lora(self):
+ """
+ Resets all LoRA parameters to their default state.
+ """
+ unquantized_part_loras = {}
+ if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
+ self._unquantized_part_loras = unquantized_part_loras
+ self._update_unquantized_part_lora_params(1)
+ state_dict = {k: v for k, v in self._quantized_part_sd.items() if "lora" in k}
+ quantized_part_vectors = {}
+ if len(self._quantized_part_vectors) > 0 or len(quantized_part_vectors) > 0:
+ self._quantized_part_vectors = quantized_part_vectors
+ updated_vectors = fuse_vectors(quantized_part_vectors, self._quantized_part_sd, 1)
+ state_dict.update(updated_vectors)
+ self.transformer_blocks[0].m.loadDict(state_dict, True)
+ self.reset_x_embedder()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ Forward pass for the Nunchaku FLUX transformer model.
+
+ This method is compatible with the Diffusers pipeline and supports LoRA,
+ rotary embeddings, and ControlNet.
+
+ Parameters
+ ----------
+ hidden_states : torch.FloatTensor
+ Input hidden states of shape (batch_size, channel, height, width).
+ encoder_hidden_states : torch.FloatTensor, optional
+ Conditional embeddings (e.g., prompt embeddings) of shape (batch_size, sequence_len, embed_dims).
+ pooled_projections : torch.FloatTensor, optional
+ Embeddings projected from the input conditions.
+ timestep : torch.LongTensor, optional
+ Denoising step.
+ img_ids : torch.Tensor, optional
+ Image token indices.
+ txt_ids : torch.Tensor, optional
+ Text token indices.
+ guidance : torch.Tensor, optional
+ Guidance tensor for classifier-free guidance.
+ joint_attention_kwargs : dict, optional
+ Additional kwargs for joint attention.
+ controlnet_block_samples : list[torch.Tensor], optional
+ ControlNet block samples.
+ controlnet_single_block_samples : list[torch.Tensor], optional
+ ControlNet single block samples.
+ return_dict : bool, optional
+ Whether to return a Transformer2DModelOutput (default: True).
+ controlnet_blocks_repeat : bool, optional
+ Whether to repeat ControlNet blocks (default: False).
+
+ Returns
+ -------
+ torch.FloatTensor or Transformer2DModelOutput
+ Output tensor or output object containing the sample.
+ """
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+ else:
+ guidance = None
+
+ temb = (
+ self.time_text_embed(timestep, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, pooled_projections)
+ )
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+ nunchaku_block = self.transformer_blocks[0]
+ encoder_hidden_states, hidden_states = nunchaku_block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ controlnet_block_samples=controlnet_block_samples,
+ controlnet_single_block_samples=controlnet_single_block_samples,
+ )
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/nunchaku/models/transformers/transformer_flux_v2.py b/nunchaku/models/transformers/transformer_flux_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b2d9d052af855bdad2d88b5b2fff0375463da0
--- /dev/null
+++ b/nunchaku/models/transformers/transformer_flux_v2.py
@@ -0,0 +1,646 @@
+"""
+This module provides Nunchaku FluxTransformer2DModel and its building blocks in Python.
+"""
+
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_flux import (
+ FluxAttention,
+ FluxSingleTransformerBlock,
+ FluxTransformer2DModel,
+ FluxTransformerBlock,
+)
+from huggingface_hub import utils
+from torch.nn import GELU
+
+from ...ops.fused import fused_gelu_mlp
+from ...utils import get_precision, pad_tensor
+from ..attention import NunchakuBaseAttention, NunchakuFeedForward
+from ..attention_processors.flux import NunchakuFluxFA2Processor, NunchakuFluxFP16AttnProcessor
+from ..embeddings import NunchakuFluxPosEmbed, pack_rotemb
+from ..linear import SVDQW4A4Linear
+from ..normalization import NunchakuAdaLayerNormZero, NunchakuAdaLayerNormZeroSingle
+from ..utils import fuse_linears
+from .utils import NunchakuModelLoaderMixin
+
+
+class NunchakuFluxAttention(NunchakuBaseAttention):
+ """
+ Nunchaku-optimized FluxAttention module with quantized and fused QKV projections.
+
+ Parameters
+ ----------
+ other : FluxAttention
+ The original FluxAttention module to wrap and quantize.
+ processor : str, optional
+ The attention processor to use ("flashattn2" or "nunchaku-fp16").
+ **kwargs
+ Additional arguments for quantization.
+ """
+
+ def __init__(self, other: FluxAttention, processor: str = "flashattn2", **kwargs):
+ super(NunchakuFluxAttention, self).__init__(processor)
+ self.head_dim = other.head_dim
+ self.inner_dim = other.inner_dim
+ self.query_dim = other.query_dim
+ self.use_bias = other.use_bias
+ self.dropout = other.dropout
+ self.out_dim = other.out_dim
+ self.context_pre_only = other.context_pre_only
+ self.pre_only = other.pre_only
+ self.heads = other.heads
+ self.added_kv_proj_dim = other.added_kv_proj_dim
+ self.added_proj_bias = other.added_proj_bias
+
+ self.norm_q = other.norm_q
+ self.norm_k = other.norm_k
+
+ # Fuse the QKV projections for efficiency.
+ with torch.device("meta"):
+ to_qkv = fuse_linears([other.to_q, other.to_k, other.to_v])
+ self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, **kwargs)
+
+ if not self.pre_only:
+ self.to_out = other.to_out
+ self.to_out[0] = SVDQW4A4Linear.from_linear(self.to_out[0], **kwargs)
+
+ if self.added_kv_proj_dim is not None:
+ self.norm_added_q = other.norm_added_q
+ self.norm_added_k = other.norm_added_k
+
+ # Fuse the additional QKV projections.
+ with torch.device("meta"):
+ add_qkv_proj = fuse_linears([other.add_q_proj, other.add_k_proj, other.add_v_proj])
+ self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, **kwargs)
+ self.to_add_out = SVDQW4A4Linear.from_linear(other.to_add_out, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor = None,
+ **kwargs,
+ ):
+ """
+ Forward pass for NunchakuFluxAttention.
+
+ Parameters
+ ----------
+ hidden_states : torch.Tensor
+ Input tensor.
+ encoder_hidden_states : torch.Tensor, optional
+ Encoder hidden states for cross-attention.
+ attention_mask : torch.Tensor, optional
+ Attention mask.
+ image_rotary_emb : tuple or torch.Tensor, optional
+ Rotary embeddings for image/text tokens.
+ **kwargs
+ Additional arguments.
+
+ Returns
+ -------
+ Output of the attention processor.
+ """
+ return self.processor(
+ attn=self,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ def set_processor(self, processor: str):
+ """
+ Set the attention processor.
+
+ Parameters
+ ----------
+ processor : str
+ Name of the processor ("flashattn2" or "nunchaku-fp16").
+
+ - ``"flashattn2"``: Standard FlashAttention-2. See :class:`~nunchaku.models.attention_processors.flux.NunchakuFluxFA2Processor`.
+ - ``"nunchaku-fp16"``: Uses FP16 attention accumulation, up to 1.2× faster than FlashAttention-2 on NVIDIA 30-, 40-, and 50-series GPUs. See :class:`~nunchaku.models.attention_processors.flux.NunchakuFluxFP16AttnProcessor`.
+
+ Raises
+ ------
+ ValueError
+ If the processor is not supported.
+ """
+ if processor == "flashattn2":
+ self.processor = NunchakuFluxFA2Processor()
+ elif processor == "nunchaku-fp16":
+ self.processor = NunchakuFluxFP16AttnProcessor()
+ else:
+ raise ValueError(f"Processor {processor} is not supported")
+
+
+class NunchakuFluxTransformerBlock(FluxTransformerBlock):
+ """
+ Nunchaku-optimized FluxTransformerBlock with quantized attention and feedforward layers.
+
+ Parameters
+ ----------
+ block : FluxTransformerBlock
+ The original block to wrap and quantize.
+ scale_shift : float, optional
+ Value to add to scale parameters. Default is 1.0.
+ Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
+ **kwargs
+ Additional arguments for quantization.
+ """
+
+ def __init__(self, block: FluxTransformerBlock, scale_shift: float = 1, **kwargs):
+ super(FluxTransformerBlock, self).__init__()
+ self.scale_shift = scale_shift
+
+ # The scale_shift=1 from AdaLayerNormZero has already been fused into the linear weights,
+ # so we set scale_shift=0 here to avoid applying it again.
+ self.norm1 = NunchakuAdaLayerNormZero(block.norm1, scale_shift=scale_shift)
+ self.norm1_context = NunchakuAdaLayerNormZero(block.norm1_context, scale_shift=scale_shift)
+
+ self.attn = NunchakuFluxAttention(block.attn, **kwargs)
+ self.norm2 = block.norm2
+ self.norm2_context = block.norm2_context
+ self.ff = NunchakuFeedForward(block.ff, **kwargs)
+ self.ff_context = NunchakuFeedForward(block.ff_context, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ """
+ Forward pass for the transformer block.
+
+ Parameters
+ ----------
+ hidden_states : torch.Tensor
+ Input hidden states.
+ encoder_hidden_states : torch.Tensor
+ Encoder hidden states for cross-attention.
+ temb : torch.Tensor
+ Time or conditioning embedding.
+ image_rotary_emb : tuple of torch.Tensor, optional
+ Rotary embeddings for image/text tokens.
+ joint_attention_kwargs : dict, optional
+ Additional attention arguments (not supported).
+
+ Returns
+ -------
+ tuple
+ (encoder_hidden_states, hidden_states) after block processing.
+
+ Raises
+ ------
+ NotImplementedError
+ If joint_attention_kwargs is provided.
+ """
+ if joint_attention_kwargs is not None and len(joint_attention_kwargs) > 0:
+ raise NotImplementedError("joint_attention_kwargs is not supported")
+
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * scale_mlp[:, None] + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * c_scale_mlp[:, None] + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class NunchakuFluxSingleTransformerBlock(FluxSingleTransformerBlock):
+ """
+ Nunchaku-optimized single transformer block with quantized attention and MLP.
+
+ Parameters
+ ----------
+ block : FluxSingleTransformerBlock
+ The original block to wrap and quantize.
+ scale_shift : float, optional
+ Value to add to scale parameters. Default is 1.0.
+ Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
+ **kwargs
+ Additional arguments for quantization.
+ """
+
+ def __init__(self, block: FluxSingleTransformerBlock, scale_shift: float = 1, **kwargs):
+ super(FluxSingleTransformerBlock, self).__init__()
+ self.mlp_hidden_dim = block.mlp_hidden_dim
+ self.norm = block.norm
+ self.norm = NunchakuAdaLayerNormZeroSingle(block.norm, scale_shift=scale_shift)
+
+ self.mlp_fc1 = SVDQW4A4Linear.from_linear(block.proj_mlp, **kwargs)
+ self.act_mlp = block.act_mlp
+ self.mlp_fc2 = SVDQW4A4Linear.from_linear(block.proj_out, in_features=self.mlp_hidden_dim, **kwargs)
+ # For int4, we shift the activation of mlp_fc2 to make it unsigned.
+ self.mlp_fc2.act_unsigned = self.mlp_fc2.precision != "nvfp4"
+
+ self.attn = NunchakuFluxAttention(block.attn, **kwargs)
+ self.attn.to_out = SVDQW4A4Linear.from_linear(block.proj_out, in_features=self.mlp_fc1.in_features, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
+ """
+ Forward pass for the single transformer block.
+
+ Parameters
+ ----------
+ hidden_states : torch.Tensor
+ Input hidden states.
+ temb : torch.Tensor
+ Time or conditioning embedding.
+ image_rotary_emb : tuple of torch.Tensor, optional
+ Rotary embeddings for tokens.
+ joint_attention_kwargs : dict, optional
+ Additional attention arguments.
+
+ Returns
+ -------
+ torch.Tensor
+ Output hidden states after block processing.
+ """
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+
+ # Feedforward
+ if isinstance(self.act_mlp, GELU):
+ # Use fused GELU MLP for efficiency.
+ mlp_hidden_states = fused_gelu_mlp(norm_hidden_states, self.mlp_fc1, self.mlp_fc2)
+ else:
+ # Fallback to original MLP.
+ mlp_hidden_states = self.mlp_fc1(norm_hidden_states)
+ mlp_hidden_states = self.act_mlp(mlp_hidden_states)
+ mlp_hidden_states = self.mlp_fc2(mlp_hidden_states)
+
+ # Attention
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs
+ )
+
+ hidden_states = attn_output + mlp_hidden_states
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * hidden_states
+ hidden_states = residual + hidden_states
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states
+
+
+class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoaderMixin):
+ """
+ Nunchaku-optimized FluxTransformer2DModel.
+ """
+
+ def _patch_model(self, **kwargs):
+ """
+ Patch the model with :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformerBlock`
+ and :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxSingleTransformerBlock`.
+
+ Parameters
+ ----------
+ **kwargs
+ Additional arguments for quantization.
+
+ Returns
+ -------
+ self : NunchakuFluxTransformer2DModelV2
+ The patched model.
+ """
+ self.pos_embed = NunchakuFluxPosEmbed(dim=self.inner_dim, theta=10000, axes_dim=self.pos_embed.axes_dim)
+ for i, block in enumerate(self.transformer_blocks):
+ self.transformer_blocks[i] = NunchakuFluxTransformerBlock(block, scale_shift=0, **kwargs)
+ for i, block in enumerate(self.single_transformer_blocks):
+ self.single_transformer_blocks[i] = NunchakuFluxSingleTransformerBlock(block, scale_shift=0, **kwargs)
+ return self
+
+ @classmethod
+ @utils.validate_hf_hub_args
+ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
+ """
+ Load a pretrained NunchakuFluxTransformer2DModelV2 from a safetensors file.
+
+ Parameters
+ ----------
+ pretrained_model_name_or_path : str or os.PathLike
+ Path to the safetensors file. It can be a local file or a remote HuggingFace path.
+ **kwargs
+ Additional arguments (e.g., device, torch_dtype).
+
+ Returns
+ -------
+ NunchakuFluxTransformer2DModelV2
+ The loaded and quantized model.
+
+ Raises
+ ------
+ NotImplementedError
+ If offload is requested.
+ AssertionError
+ If the file is not a safetensors file.
+ """
+ device = kwargs.get("device", "cpu")
+ offload = kwargs.get("offload", False)
+
+ if offload:
+ raise NotImplementedError("Offload is not supported for FluxTransformer2DModelV2")
+
+ torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
+
+ if isinstance(pretrained_model_name_or_path, str):
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+
+ assert pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
+ (".safetensors", ".sft")
+ ), "Only safetensors are supported"
+ transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
+ quantization_config = json.loads(metadata.get("quantization_config", "{}"))
+ rank = quantization_config.get("rank", 32)
+ transformer = transformer.to(torch_dtype)
+
+ precision = get_precision()
+ if precision == "fp4":
+ precision = "nvfp4"
+ transformer._patch_model(precision=precision, rank=rank)
+
+ transformer = transformer.to_empty(device=device)
+ converted_state_dict = convert_flux_state_dict(model_state_dict)
+
+ state_dict = transformer.state_dict()
+
+ for k in state_dict.keys():
+ if k not in converted_state_dict:
+ assert ".wcscales" in k
+ converted_state_dict[k] = torch.ones_like(state_dict[k])
+ else:
+ assert state_dict[k].dtype == converted_state_dict[k].dtype
+
+ # Load the wtscale from the converted state dict.
+ for n, m in transformer.named_modules():
+ if isinstance(m, SVDQW4A4Linear):
+ if m.wtscale is not None:
+ m.wtscale = converted_state_dict.pop(f"{n}.wtscale", 1.0)
+
+ transformer.load_state_dict(converted_state_dict)
+
+ return transformer
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ Forward pass for the NunchakuFluxTransformer2DModelV2.
+
+ Parameters
+ ----------
+ hidden_states : torch.Tensor
+ Input hidden states of shape (batch_size, image_sequence_length, in_channels).
+ encoder_hidden_states : torch.Tensor, optional
+ Conditional embeddings (e.g., from text).
+ pooled_projections : torch.Tensor, optional
+ Projected embeddings from input conditions.
+ timestep : torch.LongTensor, optional
+ Denoising step.
+ img_ids : torch.Tensor, optional
+ Image token IDs.
+ txt_ids : torch.Tensor, optional
+ Text token IDs.
+ guidance : torch.Tensor, optional
+ Guidance tensor for classifier-free guidance.
+ joint_attention_kwargs : dict, optional
+ Additional attention arguments.
+ controlnet_block_samples : any, optional
+ Not supported.
+ controlnet_single_block_samples : any, optional
+ Not supported.
+ return_dict : bool, optional
+ Whether to return a Transformer2DModelOutput (default: True).
+ controlnet_blocks_repeat : bool, optional
+ Not supported.
+
+ Returns
+ -------
+ Transformer2DModelOutput or tuple
+ Output sample tensor or output tuple.
+
+ Raises
+ ------
+ NotImplementedError
+ If controlnet is requested.
+ """
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ temb = (
+ self.time_text_embed(timestep, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, pooled_projections)
+ )
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+ txt_tokens = encoder_hidden_states.shape[1]
+ img_tokens = hidden_states.shape[1]
+
+ assert image_rotary_emb.ndim == 6
+ assert image_rotary_emb.shape[0] == 1
+ assert image_rotary_emb.shape[1] == 1
+ assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
+ # [1, tokens, head_dim / 2, 1, 2] (sincos)
+ image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
+ rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
+ rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
+ rotary_emb_single = image_rotary_emb
+
+ rotary_emb_txt = pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
+ rotary_emb_img = pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
+ rotary_emb_single = pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=(rotary_emb_img, rotary_emb_txt),
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # Controlnet residual (not supported for now)
+ if controlnet_block_samples is not None:
+ raise NotImplementedError("Controlnet is not supported for FluxTransformer2DModelV2 for now")
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ hidden_states = block(
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=rotary_emb_single,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # Controlnet residual (not supported for now)
+ if controlnet_single_block_samples is not None:
+ raise NotImplementedError("Controlnet is not supported for FluxTransformer2DModelV2 for now")
+
+ hidden_states = hidden_states[:, txt_tokens:]
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+
+def convert_flux_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Convert a state dict from the :class:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel`
+ format to :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformer2DModelV2` format.
+
+ Parameters
+ ----------
+ state_dict : dict[str, torch.Tensor]
+ The original state dict.
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ The converted state dict compatible with :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformer2DModelV2`.
+ """
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if "single_transformer_blocks." in k:
+ if ".qkv_proj." in k:
+ new_k = k.replace(".qkv_proj.", ".attn.to_qkv.")
+ elif ".out_proj." in k:
+ new_k = k.replace(".out_proj.", ".attn.to_out.")
+ elif ".norm_q." in k or ".norm_k." in k:
+ new_k = k.replace(".norm_k.", ".attn.norm_k.")
+ new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
+ else:
+ new_k = k
+ new_k = new_k.replace(".lora_down", ".proj_down")
+ new_k = new_k.replace(".lora_up", ".proj_up")
+ if ".smooth_orig" in k:
+ new_k = new_k.replace(".smooth_orig", ".smooth_factor_orig")
+ elif ".smooth" in k:
+ new_k = new_k.replace(".smooth", ".smooth_factor")
+ new_state_dict[new_k] = v
+ elif "transformer_blocks." in k:
+ if ".mlp_context_fc1" in k:
+ new_k = k.replace(".mlp_context_fc1.", ".ff_context.net.0.proj.")
+ elif ".mlp_context_fc2" in k:
+ new_k = k.replace(".mlp_context_fc2.", ".ff_context.net.2.")
+ elif ".mlp_fc1" in k:
+ new_k = k.replace(".mlp_fc1.", ".ff.net.0.proj.")
+ elif ".mlp_fc2" in k:
+ new_k = k.replace(".mlp_fc2.", ".ff.net.2.")
+ elif ".qkv_proj_context." in k:
+ new_k = k.replace(".qkv_proj_context.", ".attn.add_qkv_proj.")
+ elif ".qkv_proj." in k:
+ new_k = k.replace(".qkv_proj.", ".attn.to_qkv.")
+ elif ".norm_q." in k or ".norm_k." in k:
+ new_k = k.replace(".norm_k.", ".attn.norm_k.")
+ new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
+ elif ".norm_added_q." in k or ".norm_added_k." in k:
+ new_k = k.replace(".norm_added_k.", ".attn.norm_added_k.")
+ new_k = new_k.replace(".norm_added_q.", ".attn.norm_added_q.")
+ elif ".out_proj." in k:
+ new_k = k.replace(".out_proj.", ".attn.to_out.0.")
+ elif ".out_proj_context." in k:
+ new_k = k.replace(".out_proj_context.", ".attn.to_add_out.")
+ else:
+ new_k = k
+ new_k = new_k.replace(".lora_down", ".proj_down")
+ new_k = new_k.replace(".lora_up", ".proj_up")
+ if ".smooth_orig" in k:
+ new_k = new_k.replace(".smooth_orig", ".smooth_factor_orig")
+ elif ".smooth" in k:
+ new_k = new_k.replace(".smooth", ".smooth_factor")
+ new_state_dict[new_k] = v
+ else:
+ new_state_dict[k] = v
+
+ return new_state_dict
diff --git a/nunchaku/models/transformers/transformer_qwenimage.py b/nunchaku/models/transformers/transformer_qwenimage.py
new file mode 100644
index 0000000000000000000000000000000000000000..8476f9fed05ed03767cb261dd485d29245e5a85b
--- /dev/null
+++ b/nunchaku/models/transformers/transformer_qwenimage.py
@@ -0,0 +1,601 @@
+"""
+This module provides implementations of NunchakuQwenImageTransformer2DModel and its building blocks.
+"""
+
+import gc
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+from warnings import warn
+
+import torch
+from diffusers.models.attention_processor import Attention
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_qwenimage import (
+ QwenEmbedRope,
+ QwenImageTransformer2DModel,
+ QwenImageTransformerBlock,
+)
+from huggingface_hub import utils
+
+from ...utils import get_precision
+from ..attention import NunchakuBaseAttention, NunchakuFeedForward
+from ..attention_processors.qwenimage import NunchakuQwenImageNaiveFA2Processor
+from ..linear import AWQW4A16Linear, SVDQW4A4Linear
+from ..utils import CPUOffloadManager, fuse_linears
+from .utils import NunchakuModelLoaderMixin
+
+
+class NunchakuQwenAttention(NunchakuBaseAttention):
+ """
+ Nunchaku-optimized quantized attention module for QwenImage.
+
+ Parameters
+ ----------
+ other : Attention
+ The original QwenImage Attention module to wrap and quantize.
+ processor : str, default="flashattn2"
+ The attention processor to use.
+ **kwargs
+ Additional arguments for quantization.
+ """
+
+ def __init__(self, other: Attention, processor: str = "flashattn2", **kwargs):
+ super(NunchakuQwenAttention, self).__init__(processor)
+ self.inner_dim = other.inner_dim
+ self.inner_kv_dim = other.inner_kv_dim
+ self.query_dim = other.query_dim
+ self.use_bias = other.use_bias
+ self.is_cross_attention = other.is_cross_attention
+ self.cross_attention_dim = other.cross_attention_dim
+ self.upcast_attention = other.upcast_attention
+ self.upcast_softmax = other.upcast_softmax
+ self.rescale_output_factor = other.rescale_output_factor
+ self.residual_connection = other.residual_connection
+ self.dropout = other.dropout
+ self.fused_projections = other.fused_projections
+ self.out_dim = other.out_dim
+ self.out_context_dim = other.out_context_dim
+ self.context_pre_only = other.context_pre_only
+ self.pre_only = other.pre_only
+ self.is_causal = other.is_causal
+ self.scale_qk = other.scale_qk
+ self.scale = other.scale
+ self.heads = other.heads
+ self.sliceable_head_dim = other.sliceable_head_dim
+ self.added_kv_proj_dim = other.added_kv_proj_dim
+ self.only_cross_attention = other.only_cross_attention
+ self.group_norm = other.group_norm
+ self.spatial_norm = other.spatial_norm
+
+ self.norm_cross = other.norm_cross
+
+ self.norm_q = other.norm_q
+ self.norm_k = other.norm_k
+ self.norm_added_q = other.norm_added_q
+ self.norm_added_k = other.norm_added_k
+
+ # Fuse the QKV projections for quantization
+ with torch.device("meta"):
+ to_qkv = fuse_linears([other.to_q, other.to_k, other.to_v])
+ self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, **kwargs)
+ self.to_out = other.to_out
+ self.to_out[0] = SVDQW4A4Linear.from_linear(self.to_out[0], **kwargs)
+
+ assert self.added_kv_proj_dim is not None
+ # Fuse the additional QKV projections
+ with torch.device("meta"):
+ add_qkv_proj = fuse_linears([other.add_q_proj, other.add_k_proj, other.add_v_proj])
+ self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, **kwargs)
+ self.to_add_out = SVDQW4A4Linear.from_linear(other.to_add_out, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ encoder_hidden_states_mask: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ """
+ Forward pass for NunchakuQwenAttention.
+
+ Parameters
+ ----------
+ hidden_states : torch.FloatTensor
+ Image stream input.
+ encoder_hidden_states : torch.FloatTensor, optional
+ Text stream input.
+ encoder_hidden_states_mask : torch.FloatTensor, optional
+ Mask for encoder hidden states.
+ attention_mask : torch.FloatTensor, optional
+ Attention mask.
+ image_rotary_emb : torch.Tensor, optional
+ Rotary embedding for images.
+ **kwargs
+ Additional arguments.
+
+ Returns
+ -------
+ tuple
+ Attention outputs for image and text streams.
+ """
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states,
+ encoder_hidden_states_mask,
+ attention_mask,
+ image_rotary_emb,
+ **kwargs,
+ )
+
+ def set_processor(self, processor: str):
+ """
+ Set the attention processor.
+
+ Parameters
+ ----------
+ processor : str
+ Name of the processor to use. Only "flashattn2" is supported for now. See :class:`~nunchaku.models.attention_processors.qwenimage.NunchakuQwenImageNaiveFA2Processor`.
+
+ Raises
+ ------
+ ValueError
+ If the processor is not supported.
+ """
+ if processor == "flashattn2":
+ self.processor = NunchakuQwenImageNaiveFA2Processor()
+ else:
+ raise ValueError(f"Processor {processor} is not supported")
+
+
+class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock):
+ """
+ Quantized QwenImage Transformer Block.
+
+ This block supports quantized linear layers and joint attention for image and text streams.
+
+ Parameters
+ ----------
+ other : QwenImageTransformerBlock
+ The original transformer block to wrap and quantize.
+ scale_shift : float, default=1.0
+ Value to add to scale parameters. Default is 1.0.
+ Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
+ **kwargs
+ Additional arguments for quantization.
+ """
+
+ def __init__(self, other: QwenImageTransformerBlock, scale_shift: float = 1.0, **kwargs):
+ super(QwenImageTransformerBlock, self).__init__()
+
+ self.dim = other.dim
+ self.img_mod = other.img_mod
+ self.img_mod[1] = AWQW4A16Linear.from_linear(other.img_mod[1], **kwargs)
+ self.img_norm1 = other.img_norm1
+ self.attn = NunchakuQwenAttention(other.attn, **kwargs)
+ self.img_norm2 = other.img_norm2
+ self.img_mlp = NunchakuFeedForward(other.img_mlp, **kwargs)
+
+ # Text processing modules
+ self.txt_mod = other.txt_mod
+ self.txt_mod[1] = AWQW4A16Linear.from_linear(other.txt_mod[1], **kwargs)
+ self.txt_norm1 = other.txt_norm1
+ # Text doesn't need separate attention - it's handled by img_attn joint computation
+ self.txt_norm2 = other.txt_norm2
+ self.txt_mlp = NunchakuFeedForward(other.txt_mlp, **kwargs)
+
+ self.scale_shift = scale_shift
+
+ def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply modulation to input tensor.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor.
+ mod_params : torch.Tensor
+ Modulation parameters.
+
+ Returns
+ -------
+ tuple
+ Modulated tensor and gate tensor.
+ """
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
+ if self.scale_shift != 0:
+ scale.add_(self.scale_shift)
+ return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_mask: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Forward pass for NunchakuQwenImageTransformerBlock.
+
+ Parameters
+ ----------
+ hidden_states : torch.Tensor
+ Image stream input.
+ encoder_hidden_states : torch.Tensor
+ Text stream input.
+ encoder_hidden_states_mask : torch.Tensor
+ Mask for encoder hidden states.
+ temb : torch.Tensor
+ Temporal embedding.
+ image_rotary_emb : tuple of torch.Tensor, optional
+ Rotary embedding for images.
+ joint_attention_kwargs : dict, optional
+ Additional arguments for joint attention.
+
+ Returns
+ -------
+ tuple
+ Updated encoder_hidden_states and hidden_states.
+ """
+ # Get modulation parameters for both streams
+ img_mod_params = self.img_mod(temb) # [B, 6*dim]
+ txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
+
+ # nunchaku's mod_params is [B, 6*dim] instead of [B, dim*6]
+ img_mod_params = (
+ img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
+ )
+ txt_mod_params = (
+ txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
+ )
+
+ img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
+ txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
+
+ # Process image stream - norm1 + modulation
+ img_normed = self.img_norm1(hidden_states)
+ img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+
+ # Process text stream - norm1 + modulation
+ txt_normed = self.txt_norm1(encoder_hidden_states)
+ txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=img_modulated,
+ encoder_hidden_states=txt_modulated,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
+ img_attn_output, txt_attn_output = attn_output
+
+ # Apply attention gates and add residual (like in Megatron)
+ hidden_states = hidden_states + img_gate1 * img_attn_output
+ encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+
+ # Process image stream - norm2 + MLP
+ img_normed2 = self.img_norm2(hidden_states)
+ img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+ img_mlp_output = self.img_mlp(img_modulated2)
+ hidden_states = hidden_states + img_gate2 * img_mlp_output
+
+ # Process text stream - norm2 + MLP
+ txt_normed2 = self.txt_norm2(encoder_hidden_states)
+ txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
+ txt_mlp_output = self.txt_mlp(txt_modulated2)
+ encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
+
+ # Clip to prevent overflow for fp16
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuModelLoaderMixin):
+ """
+ Quantized QwenImage Transformer2DModel.
+
+ This model supports quantized transformer blocks and optional CPU offloading for memory efficiency.
+
+ Parameters
+ ----------
+ *args
+ Positional arguments for the base model.
+ **kwargs
+ Keyword arguments for the base model and quantization.
+
+ Attributes
+ ----------
+ offload : bool
+ Whether CPU offloading is enabled.
+ offload_manager : CPUOffloadManager or None
+ Manager for offloading transformer blocks.
+ _is_initialized : bool
+ Whether the model has been patched for quantization.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.offload = kwargs.pop("offload", False)
+ self.offload_manager = None
+ self._is_initialized = False
+ super().__init__(*args, **kwargs)
+
+ def _patch_model(self, **kwargs):
+ """
+ Patch the transformer blocks for quantization.
+
+ Parameters
+ ----------
+ **kwargs
+ Additional arguments for quantization.
+
+ Returns
+ -------
+ self
+ """
+ for i, block in enumerate(self.transformer_blocks):
+ self.transformer_blocks[i] = NunchakuQwenImageTransformerBlock(block, scale_shift=0, **kwargs)
+ self._is_initialized = True
+ return self
+
+ @classmethod
+ @utils.validate_hf_hub_args
+ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
+ """
+ Load a quantized model from a pretrained checkpoint.
+
+ Parameters
+ ----------
+ pretrained_model_name_or_path : str or os.PathLike
+ Path to the pretrained model checkpoint. It can be a local file or a remote HuggingFace path.
+ **kwargs
+ Additional arguments for loading and quantization.
+
+ Returns
+ -------
+ NunchakuQwenImageTransformer2DModel
+ The loaded and quantized model.
+
+ Raises
+ ------
+ AssertionError
+ If the checkpoint is not a safetensors file.
+ """
+ device = kwargs.get("device", "cpu")
+ offload = kwargs.get("offload", False)
+
+ torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
+
+ if isinstance(pretrained_model_name_or_path, str):
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+
+ assert pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
+ (".safetensors", ".sft")
+ ), "Only safetensors are supported"
+ transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
+ quantization_config = json.loads(metadata.get("quantization_config", "{}"))
+ config = json.loads(metadata.get("config", "{}"))
+ rank = quantization_config.get("rank", 32)
+ transformer = transformer.to(torch_dtype)
+
+ precision = get_precision()
+ if precision == "fp4":
+ precision = "nvfp4"
+ transformer._patch_model(precision=precision, rank=rank)
+
+ transformer = transformer.to_empty(device=device)
+ # need to re-init the pos_embed as to_empty does not work on it
+ transformer.pos_embed = QwenEmbedRope(
+ theta=10000, axes_dim=list(config.get("axes_dims_rope", [16, 56, 56])), scale_rope=True
+ )
+
+ state_dict = transformer.state_dict()
+ for k in state_dict.keys():
+ if k not in model_state_dict:
+ assert ".wcscales" in k
+ model_state_dict[k] = torch.ones_like(state_dict[k])
+ else:
+ assert state_dict[k].dtype == model_state_dict[k].dtype
+
+ # load the wtscale from the state dict, as it is a float on CPU
+ for n, m in transformer.named_modules():
+ if isinstance(m, SVDQW4A4Linear):
+ if m.wtscale is not None:
+ m.wtscale = model_state_dict.pop(f"{n}.wtscale", 1.0)
+ transformer.load_state_dict(model_state_dict)
+ transformer.set_offload(offload)
+
+ return transformer
+
+ def set_offload(self, offload: bool, **kwargs):
+ """
+ Enable or disable asynchronous CPU offloading for transformer blocks.
+
+ Parameters
+ ----------
+ offload : bool
+ Whether to enable offloading.
+ **kwargs
+ Additional arguments for offload manager.
+
+ See Also
+ --------
+ :class:`~nunchaku.models.utils.CPUOffloadManager`
+ """
+ if offload == self.offload:
+ # nothing changed, just return
+ return
+ self.offload = offload
+ if offload:
+ self.offload_manager = CPUOffloadManager(
+ self.transformer_blocks,
+ use_pin_memory=kwargs.get("use_pin_memory", True),
+ on_gpu_modules=[
+ self.img_in,
+ self.txt_in,
+ self.txt_norm,
+ self.time_text_embed,
+ self.norm_out,
+ self.proj_out,
+ ],
+ num_blocks_on_gpu=kwargs.get("num_blocks_on_gpu", 1),
+ )
+ else:
+ self.offload_manager = None
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ encoder_hidden_states_mask: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
+ txt_seq_lens: Optional[List[int]] = None,
+ guidance: torch.Tensor = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ Forward pass for the quantized QwenImage transformer model.
+
+ Parameters
+ ----------
+ hidden_states : torch.Tensor
+ Image stream input.
+ encoder_hidden_states : torch.Tensor, optional
+ Text stream input.
+ encoder_hidden_states_mask : torch.Tensor, optional
+ Mask for encoder hidden states.
+ timestep : torch.LongTensor, optional
+ Timestep for temporal embedding.
+ img_shapes : list of tuple, optional
+ Image shapes for rotary embedding.
+ txt_seq_lens : list of int, optional
+ Text sequence lengths.
+ guidance : torch.Tensor, optional
+ Guidance tensor (for classifier-free guidance).
+ attention_kwargs : dict, optional
+ Additional attention arguments.
+ return_dict : bool, default=True
+ Whether to return a dict or tuple.
+
+ Returns
+ -------
+ torch.Tensor or Transformer2DModelOutput
+ Model output.
+ """
+ device = hidden_states.device
+ if self.offload:
+ self.offload_manager.set_device(device)
+
+ hidden_states = self.img_in(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype)
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ temb = (
+ self.time_text_embed(timestep, hidden_states)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, hidden_states)
+ )
+
+ image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+
+ compute_stream = torch.cuda.current_stream()
+ if self.offload:
+ self.offload_manager.initialize(compute_stream)
+ for block_idx, block in enumerate(self.transformer_blocks):
+ with torch.cuda.stream(compute_stream):
+ if self.offload:
+ block = self.offload_manager.get_block(block_idx)
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=attention_kwargs,
+ )
+ if self.offload:
+ self.offload_manager.step(compute_stream)
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ torch.cuda.empty_cache()
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+ def to(self, *args, **kwargs):
+ """
+ Override the default ``.to()`` method.
+
+ If offload is enabled, prevents moving the model to GPU.
+ Prevents changing dtype after quantization.
+
+ Parameters
+ ----------
+ *args
+ Positional arguments for ``.to()``.
+ **kwargs
+ Keyword arguments for ``.to()``.
+
+ Returns
+ -------
+ self
+
+ Raises
+ ------
+ ValueError
+ If attempting to change dtype after quantization.
+ """
+ device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
+ dtype_present_in_args = "dtype" in kwargs
+
+ # Try converting arguments to torch.device in case they are passed as strings
+ for arg in args:
+ if not isinstance(arg, str):
+ continue
+ try:
+ torch.device(arg)
+ device_arg_or_kwarg_present = True
+ except RuntimeError:
+ pass
+
+ if not dtype_present_in_args:
+ for arg in args:
+ if isinstance(arg, torch.dtype):
+ dtype_present_in_args = True
+ break
+
+ if dtype_present_in_args and self._is_initialized:
+ raise ValueError(
+ "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
+ "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`."
+ )
+ if self.offload:
+ if device_arg_or_kwarg_present:
+ warn("Skipping moving the model to GPU as offload is enabled", UserWarning)
+ return self
+ return super(type(self), self).to(*args, **kwargs)
diff --git a/nunchaku/models/transformers/transformer_sana.py b/nunchaku/models/transformers/transformer_sana.py
new file mode 100644
index 0000000000000000000000000000000000000000..50320a125f56779fbc37ae49b31a329de57f618a
--- /dev/null
+++ b/nunchaku/models/transformers/transformer_sana.py
@@ -0,0 +1,374 @@
+"""
+Implements the :class:`NunchakuSanaTransformer2DModel`,
+a quantized Sana transformer for Diffusers with efficient inference support.
+"""
+
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+from diffusers import SanaTransformer2DModel
+from huggingface_hub import utils
+from safetensors.torch import load_file
+from torch import nn
+from torch.nn import functional as F
+
+from ..._C import QuantizedSanaModel
+from ..._C import utils as cutils
+from ...utils import get_precision
+from .utils import NunchakuModelLoaderMixin
+
+SVD_RANK = 32
+
+
+class NunchakuSanaTransformerBlocks(nn.Module):
+ """
+ Wrapper for quantized Sana transformer blocks.
+
+ This module wraps a QuantizedSanaModel and provides forward methods compatible
+ with the expected transformer block interface.
+
+ Parameters
+ ----------
+ m : QuantizedSanaModel
+ The quantized transformer model.
+ dtype : torch.dtype
+ The data type to use for computation.
+ device : str or torch.device
+ The device to run the model on.
+ """
+
+ def __init__(self, m: QuantizedSanaModel, dtype: torch.dtype, device: str | torch.device):
+ super(NunchakuSanaTransformerBlocks, self).__init__()
+ self.m = m
+ self.dtype = dtype
+ self.device = device
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ skip_first_layer: Optional[bool] = False,
+ ):
+ """
+ Forward pass through all quantized transformer blocks.
+
+ Parameters
+ ----------
+ hidden_states : torch.Tensor
+ Input hidden states of shape (batch_size, img_tokens, ...).
+ attention_mask : torch.Tensor, optional
+ Not used.
+ encoder_hidden_states : torch.Tensor, optional
+ Encoder hidden states of shape (batch_size, txt_tokens, ...).
+ encoder_attention_mask : torch.Tensor, optional
+ Encoder attention mask of shape (batch_size, 1, txt_tokens).
+ timestep : torch.LongTensor, optional
+ Timestep tensor.
+ height : int, optional
+ Image height.
+ width : int, optional
+ Image width.
+ skip_first_layer : bool, optional
+ Whether to skip the first layer.
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor after passing through the quantized transformer blocks.
+ """
+ batch_size = hidden_states.shape[0]
+ img_tokens = hidden_states.shape[1]
+ txt_tokens = encoder_hidden_states.shape[1]
+
+ original_dtype = hidden_states.dtype
+ original_device = hidden_states.device
+
+ assert encoder_attention_mask is not None
+ assert encoder_attention_mask.shape == (batch_size, 1, txt_tokens)
+
+ mask = encoder_attention_mask.reshape(batch_size, txt_tokens)
+ nunchaku_encoder_hidden_states = encoder_hidden_states[mask > -9000]
+
+ cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
+ cu_seqlens_img = torch.arange(
+ 0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32, device=self.device
+ )
+
+ if height is None and width is None:
+ height = width = int(img_tokens**0.5)
+ elif height is None:
+ height = img_tokens // width
+ elif width is None:
+ width = img_tokens // height
+ assert height * width == img_tokens
+
+ return (
+ self.m.forward(
+ hidden_states.to(self.dtype).to(self.device),
+ nunchaku_encoder_hidden_states.to(self.dtype).to(self.device),
+ timestep.to(self.dtype).to(self.device),
+ cu_seqlens_img.to(self.device),
+ cu_seqlens_txt.to(self.device),
+ height,
+ width,
+ batch_size % 3 == 0, # pag is set when loading the model, FIXME: pag_scale == 0
+ True, # TODO: find a way to detect if we are doing CFG
+ skip_first_layer,
+ )
+ .to(original_dtype)
+ .to(original_device)
+ )
+
+ def forward_layer_at(
+ self,
+ idx: int,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ ):
+ """
+ Forward pass through a specific quantized transformer layer.
+
+ Parameters
+ ----------
+ idx : int
+ Index of the layer to run.
+ hidden_states : torch.Tensor
+ Input hidden states.
+ attention_mask : torch.Tensor, optional
+ Not used.
+ encoder_hidden_states : torch.Tensor, optional
+ Encoder hidden states.
+ encoder_attention_mask : torch.Tensor, optional
+ Encoder attention mask.
+ timestep : torch.LongTensor, optional
+ Timestep tensor.
+ height : int, optional
+ Image height.
+ width : int, optional
+ Image width.
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor after passing through the specified quantized transformer layer.
+ """
+ batch_size = hidden_states.shape[0]
+ img_tokens = hidden_states.shape[1]
+ txt_tokens = encoder_hidden_states.shape[1]
+
+ original_dtype = hidden_states.dtype
+ original_device = hidden_states.device
+
+ assert encoder_attention_mask is not None
+ assert encoder_attention_mask.shape == (batch_size, 1, txt_tokens)
+
+ mask = encoder_attention_mask.reshape(batch_size, txt_tokens)
+ nunchaku_encoder_hidden_states = encoder_hidden_states[mask > -9000]
+
+ cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
+ cu_seqlens_img = torch.arange(
+ 0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32, device=self.device
+ )
+
+ if height is None and width is None:
+ height = width = int(img_tokens**0.5)
+ elif height is None:
+ height = img_tokens // width
+ elif width is None:
+ width = img_tokens // height
+ assert height * width == img_tokens
+
+ return (
+ self.m.forward_layer(
+ idx,
+ hidden_states.to(self.dtype).to(self.device),
+ nunchaku_encoder_hidden_states.to(self.dtype).to(self.device),
+ timestep.to(self.dtype).to(self.device),
+ cu_seqlens_img.to(self.device),
+ cu_seqlens_txt.to(self.device),
+ height,
+ width,
+ batch_size % 3 == 0, # pag is set when loading the model, FIXME: pag_scale == 0
+ True, # TODO: find a way to detect if we are doing CFG
+ )
+ .to(original_dtype)
+ .to(original_device)
+ )
+
+ def __del__(self):
+ """
+ Destructor to reset the quantized model and free resources.
+ """
+ self.m.reset()
+
+
+class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
+ """
+ SanaTransformer2DModel with Nunchaku quantized backend support.
+
+ This class extends the base SanaTransformer2DModel to support loading and
+ injecting quantized transformer blocks using Nunchaku's custom backend.
+ """
+
+ @classmethod
+ @utils.validate_hf_hub_args
+ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
+ """
+ Load a pretrained NunchakuSanaTransformer2DModel from a local file or HuggingFace Hub.
+
+ This method supports both quantized and unquantized checkpoints, and will
+ automatically inject quantized transformer blocks if available.
+
+ Parameters
+ ----------
+ pretrained_model_name_or_path : str or os.PathLike
+ Path to the model checkpoint or HuggingFace Hub model name.
+ **kwargs
+ Additional keyword arguments for model loading.
+
+ Returns
+ -------
+ NunchakuSanaTransformer2DModel or (NunchakuSanaTransformer2DModel, dict)
+ The loaded model, and optionally metadata if ``return_metadata=True``.
+ """
+ device = kwargs.get("device", "cuda")
+ if isinstance(device, str):
+ device = torch.device(device)
+ pag_layers = kwargs.get("pag_layers", [])
+ precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
+ metadata = None
+
+ if isinstance(pretrained_model_name_or_path, str):
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+ if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
+ (".safetensors", ".sft")
+ ):
+ transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path)
+ quantized_part_sd = {}
+ unquantized_part_sd = {}
+ for k, v in model_state_dict.items():
+ if k.startswith("transformer_blocks."):
+ quantized_part_sd[k] = v
+ else:
+ unquantized_part_sd[k] = v
+ m = load_quantized_module(
+ transformer, quantized_part_sd, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
+ )
+ transformer.inject_quantized_module(m, device)
+ transformer.to_empty(device=device)
+ transformer.load_state_dict(unquantized_part_sd, strict=False)
+ else:
+ transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
+ pretrained_model_name_or_path, **kwargs
+ )
+ m = load_quantized_module(
+ transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
+ )
+ transformer.inject_quantized_module(m, device)
+ transformer.to_empty(device=device)
+ unquantized_state_dict = load_file(unquantized_part_path)
+ transformer.load_state_dict(unquantized_state_dict, strict=False)
+ if kwargs.get("return_metadata", False):
+ return transformer, metadata
+ else:
+ return transformer
+
+ def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
+ """
+ Inject a quantized transformer module into this model.
+
+ Parameters
+ ----------
+ m : QuantizedSanaModel
+ The quantized transformer module to inject.
+ device : str or torch.device, optional
+ The device to place the module on (default: "cuda").
+
+ Returns
+ -------
+ NunchakuSanaTransformer2DModel
+ The model with the quantized module injected.
+ """
+ self.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, self.dtype, device)])
+ return self
+
+
+def load_quantized_module(
+ net: SanaTransformer2DModel,
+ path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
+ device: str | torch.device = "cuda",
+ pag_layers: int | list[int] | None = None,
+ use_fp4: bool = False,
+) -> QuantizedSanaModel:
+ """
+ Load quantized weights into a QuantizedSanaModel.
+
+ Parameters
+ ----------
+ net : SanaTransformer2DModel
+ The base transformer model (for config and dtype).
+ path_or_state_dict : str, os.PathLike, or dict
+ Path to the quantized weights or a state dict.
+ device : str or torch.device, optional
+ Device to load the quantized model on (default: "cuda").
+ pag_layers : int, list of int, or None, optional
+ List of layers to use pag (default: None).
+ use_fp4 : bool, optional
+ Whether to use FP4 quantization (default: False).
+
+ Returns
+ -------
+ QuantizedSanaModel
+ The loaded quantized model.
+ """
+ if pag_layers is None:
+ pag_layers = []
+ elif isinstance(pag_layers, int):
+ pag_layers = [pag_layers]
+ device = torch.device(device)
+ assert device.type == "cuda"
+
+ m = QuantizedSanaModel()
+ cutils.disable_memory_auto_release()
+ m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
+ if isinstance(path_or_state_dict, dict):
+ m.loadDict(path_or_state_dict, True)
+ else:
+ m.load(str(path_or_state_dict))
+ return m
+
+
+def inject_quantized_module(
+ net: SanaTransformer2DModel, m: QuantizedSanaModel, device: torch.device
+) -> SanaTransformer2DModel:
+ """
+ Inject a quantized transformer module into a SanaTransformer2DModel.
+
+ Parameters
+ ----------
+ net : SanaTransformer2DModel
+ The base transformer model.
+ m : QuantizedSanaModel
+ The quantized transformer module to inject.
+ device : torch.device
+ The device to place the module on.
+
+ Returns
+ -------
+ SanaTransformer2DModel
+ The model with the quantized module injected.
+ """
+ net.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, net.dtype, device)])
+ return net
diff --git a/nunchaku/models/transformers/utils.py b/nunchaku/models/transformers/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..179ddd4dbb0615113f6dd5c71483a8b859705fa9
--- /dev/null
+++ b/nunchaku/models/transformers/utils.py
@@ -0,0 +1,147 @@
+"""
+Utilities for Nunchaku transformer model loading.
+"""
+
+import json
+import logging
+import os
+from pathlib import Path
+
+import torch
+from diffusers import __version__
+from huggingface_hub import constants, hf_hub_download
+from torch import nn
+
+from ...utils import load_state_dict_in_safetensors
+
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+class NunchakuModelLoaderMixin:
+ """
+ Mixin for standardized model loading in Nunchaku transformer models.
+ """
+
+ @classmethod
+ def _build_model(
+ cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
+ ) -> tuple[nn.Module, dict[str, torch.Tensor], dict[str, str]]:
+ """
+ Build a transformer model from a safetensors file.
+
+ Parameters
+ ----------
+ pretrained_model_name_or_path : str or os.PathLike
+ Path to the safetensors file.
+ **kwargs
+ Additional keyword arguments (e.g., ``torch_dtype``).
+
+ Returns
+ -------
+ tuple
+ (transformer, state_dict, metadata)
+ """
+ if isinstance(pretrained_model_name_or_path, str):
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+ state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
+
+ config = json.loads(metadata["config"])
+
+ with torch.device("meta"):
+ transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+
+ return transformer, state_dict, metadata
+
+ @classmethod
+ def _build_model_legacy(
+ cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
+ ) -> tuple[nn.Module, str, str]:
+ """
+ Build a transformer model from a legacy folder structure.
+
+ .. warning::
+ This method is deprecated and will be removed in December 2025.
+ Please use :meth:`_build_model` instead.
+
+ Parameters
+ ----------
+ pretrained_model_name_or_path : str or os.PathLike
+ Path to the folder containing model weights.
+ **kwargs
+ Additional keyword arguments for HuggingFace Hub download and config loading.
+
+ Returns
+ -------
+ tuple
+ (transformer, unquantized_part_path, transformer_block_path)
+ """
+ logger.warning(
+ "Loading models from a folder will be deprecated in December 2025. "
+ "Please download the latest safetensors model, or use one of the following tools to "
+ "merge your model into a single file: the CLI utility `python -m nunchaku.merge_safetensors` "
+ "or the ComfyUI workflow `merge_safetensors.json`."
+ )
+ subfolder = kwargs.get("subfolder", None)
+ if os.path.exists(pretrained_model_name_or_path):
+ dirname = (
+ pretrained_model_name_or_path
+ if subfolder is None
+ else os.path.join(pretrained_model_name_or_path, subfolder)
+ )
+ unquantized_part_path = os.path.join(dirname, "unquantized_layers.safetensors")
+ transformer_block_path = os.path.join(dirname, "transformer_blocks.safetensors")
+ else:
+ download_kwargs = {
+ "subfolder": subfolder,
+ "repo_type": "model",
+ "revision": kwargs.get("revision", None),
+ "cache_dir": kwargs.get("cache_dir", None),
+ "local_dir": kwargs.get("local_dir", None),
+ "user_agent": kwargs.get("user_agent", None),
+ "force_download": kwargs.get("force_download", False),
+ "proxies": kwargs.get("proxies", None),
+ "etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
+ "token": kwargs.get("token", None),
+ "local_files_only": kwargs.get("local_files_only", None),
+ "headers": kwargs.get("headers", None),
+ "endpoint": kwargs.get("endpoint", None),
+ "resume_download": kwargs.get("resume_download", None),
+ "force_filename": kwargs.get("force_filename", None),
+ "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
+ }
+ unquantized_part_path = hf_hub_download(
+ repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs
+ )
+ transformer_block_path = hf_hub_download(
+ repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs
+ )
+
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ config, _, _ = cls.load_config(
+ pretrained_model_name_or_path,
+ subfolder=subfolder,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ return_commit_hash=True,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
+ **kwargs,
+ )
+
+ with torch.device("meta"):
+ transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+ return transformer, unquantized_part_path, transformer_block_path
diff --git a/nunchaku/models/utils.py b/nunchaku/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..090540ee37539ddaf229980d3331fec5aab591fd
--- /dev/null
+++ b/nunchaku/models/utils.py
@@ -0,0 +1,262 @@
+"""
+Utility functions and classes for efficient transformer model management in Nunchaku.
+"""
+
+import copy
+
+import torch
+from torch import nn
+
+from ..utils import copy_params_into
+
+
+def fuse_linears(linears: list[nn.Linear]) -> nn.Linear:
+ """
+ Fuse a list of nn.Linear layers into a single nn.Linear with concatenated output features.
+
+ Parameters
+ ----------
+ linears : list of nn.Linear
+ List of linear layers to fuse. All must have the same input feature dimension.
+
+ Returns
+ -------
+ fused : nn.Linear
+ A new linear layer with concatenated output features and the same input features.
+
+ Raises
+ ------
+ AssertionError
+ If the input feature dimensions do not match.
+
+ Notes
+ -----
+ The fused layer does not copy weights or biases from the input layers.
+ """
+ assert len(linears) > 0
+ if len(linears) == 1:
+ return linears[0]
+ else:
+ assert all(linear.in_features == linears[0].in_features for linear in linears)
+ out_features = sum(linear.out_features for linear in linears)
+ bias = all(linear.bias is not None for linear in linears)
+ return nn.Linear(
+ linears[0].in_features,
+ out_features,
+ bias=bias,
+ dtype=linears[0].weight.dtype,
+ device=linears[0].weight.device,
+ )
+
+
+class CPUOffloadManager:
+ """
+ Manager for per-transformer-block CPU offloading with asynchronous memory operations using a Ping-Pong buffer strategy.
+
+ This class enables memory-efficient inference or training by keeping only a subset
+ of transformer blocks on GPU, offloading the rest to CPU, and preloading blocks as needed.
+
+ Parameters
+ ----------
+ blocks : list of nn.Module
+ List of transformer blocks to manage.
+ device : str or torch.device, optional
+ Target CUDA device for GPU operations. Default is "cuda".
+ use_pin_memory : bool, optional
+ Whether to use pinned memory for faster CPU-to-GPU transfers. Default is True.
+ on_gpu_modules : list of nn.Module, optional
+ Additional modules to keep on GPU at all times. Default is [].
+ num_blocks_on_gpu : int, optional
+ Number of blocks to keep on GPU simultaneously. Must be > 0. Default is 1.
+ empty_cache_freq : int, optional
+ Frequency (in forward passes) to call torch.cuda.empty_cache(). Default is 0 (never).
+
+ Attributes
+ ----------
+ blocks : list of nn.Module
+ The managed transformer blocks.
+ buffer_blocks : list of nn.Module
+ Buffers for preloading blocks onto GPU.
+ device : torch.device
+ The current CUDA device.
+ current_block_idx : int
+ Index of the current block on GPU.
+ forward_counter : int
+ Number of forward passes completed.
+ memory_stream : torch.cuda.Stream
+ CUDA stream for memory operations.
+ compute_done : torch.cuda.Event
+ CUDA event signaling compute completion.
+ memory_done : torch.cuda.Event
+ CUDA event signaling memory completion.
+ """
+
+ def __init__(
+ self,
+ blocks: list[nn.Module],
+ device: str | torch.device = torch.device("cuda"),
+ use_pin_memory: bool = True,
+ on_gpu_modules: list[nn.Module] = [],
+ num_blocks_on_gpu: int = 1,
+ empty_cache_freq: int = 0,
+ ):
+ self.blocks = blocks
+ self.use_pin_memory = use_pin_memory
+ self.on_gpu_modules = on_gpu_modules
+ self.num_blocks_on_gpu = num_blocks_on_gpu
+ assert self.num_blocks_on_gpu > 0
+
+ # Two streams: one for compute, one for memory operations, will be initialized in set_device
+ self.memory_stream = None
+
+ self.compute_done = torch.cuda.Event(blocking=False)
+ self.memory_done = torch.cuda.Event(blocking=False)
+
+ self.buffer_blocks = [copy.deepcopy(blocks[0]), copy.deepcopy(blocks[0])]
+
+ self.device = None
+ self.set_device(device)
+
+ self.current_block_idx = 0
+ self.forward_counter = 0
+ self.empty_cache_freq = empty_cache_freq
+
+ def set_device(self, device: torch.device | str, force: bool = False):
+ """
+ Set the CUDA device for offloading and memory operations.
+ It will move buffer blocks and on-GPU modules to the specified device and offload other blocks to CPU, optionally using pinned memory.
+
+ Parameters
+ ----------
+ device : torch.device or str
+ Target CUDA device.
+ force : bool, optional
+ If True, force re-initialization even if device is unchanged. Default is False.
+
+ Raises
+ ------
+ AssertionError
+ If the device is not a CUDA device.
+ """
+ if isinstance(device, str):
+ device = torch.device(device)
+ assert device.type == "cuda"
+ if self.device == device and not force:
+ return
+ self.device = device
+ self.memory_stream = torch.cuda.Stream(device=device)
+ for block in self.buffer_blocks:
+ block.to(device)
+ for module in self.on_gpu_modules:
+ module.to(device)
+ for i, block in enumerate(self.blocks):
+ if i < self.num_blocks_on_gpu:
+ block.to(device)
+ else:
+ block.to("cpu")
+ if self.use_pin_memory:
+ for p in block.parameters(recurse=True):
+ p.data = p.data.pin_memory()
+ for b in block.buffers(recurse=True):
+ b.data = b.data.pin_memory()
+
+ def load_block(self, block_idx: int, non_blocking: bool = True):
+ """
+ Move a transformer block from CPU to GPU buffer.
+
+ Parameters
+ ----------
+ block_idx : int
+ Index of the block to load.
+ non_blocking : bool, optional
+ Whether to use non-blocking memory copy. Default is True.
+
+ Notes
+ -----
+ - No action is taken if the block is already on GPU or index is out of range.
+ """
+ # if the block is already on GPU, don't load it to the buffer
+ if block_idx < self.num_blocks_on_gpu:
+ return
+ # if there are blocks on GPU, don't load the first block to the buffer again
+ if block_idx >= len(self.blocks):
+ return
+
+ block = self.blocks[block_idx]
+ copy_params_into(block, self.buffer_blocks[block_idx % 2], non_blocking=non_blocking)
+
+ def step(self, compute_stream: torch.cuda.Stream | None = None):
+ """
+ Advance to the next transformer block, triggering asynchronous preloading.
+
+ It will preload the next block onto GPU in the background and synchronize between compute and memory streams.
+ After all the blocks are processed, it will call torch.cuda.empty_cache() periodically if ``empty_cache_freq`` > 0.
+
+ Parameters
+ ----------
+ compute_stream : torch.cuda.Stream, optional
+ CUDA stream for compute operations. If None, uses current stream.
+ """
+ if compute_stream is None:
+ compute_stream = torch.cuda.current_stream()
+ next_compute_done = torch.cuda.Event()
+ next_compute_done.record(compute_stream)
+ with torch.cuda.stream(self.memory_stream):
+ self.memory_stream.wait_event(self.compute_done)
+ self.load_block(self.current_block_idx + 1) # if the current block is the last block, load the first block
+ next_memory_done = torch.cuda.Event()
+ next_memory_done.record(self.memory_stream)
+ self.memory_done = next_memory_done
+ self.compute_done = next_compute_done
+ self.current_block_idx += 1
+ if self.current_block_idx < len(self.blocks):
+ # get ready for the next compute
+ compute_stream.wait_event(self.memory_done)
+ else:
+ # ready to finish
+ compute_stream.wait_event(self.compute_done)
+ self.current_block_idx = 0
+ self.forward_counter += 1
+ if self.empty_cache_freq > 0 and self.forward_counter % self.empty_cache_freq == 0:
+ torch.cuda.empty_cache()
+
+ def get_block(self, block_idx: int | None = None) -> nn.Module:
+ """
+ Retrieve the current or specified transformer block for computation.
+ It will return a buffer block if the requested block is offloaded.
+
+ Parameters
+ ----------
+ block_idx : int, optional
+ Index of the block to retrieve. If None, returns the current block.
+
+ Returns
+ -------
+ block : nn.Module
+ The requested transformer block (on GPU if needed).
+ """
+ if block_idx is None:
+ block_idx = self.current_block_idx
+ if block_idx < self.num_blocks_on_gpu:
+ return self.blocks[block_idx]
+ else:
+ return self.buffer_blocks[block_idx % 2]
+
+ def initialize(self, stream: torch.cuda.Stream | None = None):
+ """
+ Initialize CUDA events for compute and memory streams.
+ It will record the initial events for the compute and memory streams.
+
+ Parameters
+ ----------
+ stream : torch.cuda.Stream, optional
+ CUDA stream to record initial events. If None, uses current stream.
+
+ Notes
+ -----
+ - Should be called before the first forward pass.
+ """
+ if stream is None:
+ stream = torch.cuda.current_stream()
+ self.compute_done.record(stream)
+ self.memory_done.record(stream)
diff --git a/nunchaku/ops/__init__.py b/nunchaku/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2123e332cf157a57ea9f97bc83435704d731983e
--- /dev/null
+++ b/nunchaku/ops/__init__.py
@@ -0,0 +1 @@
+# Quantized operations for FLUX-Kontext
diff --git a/nunchaku/ops/fused.py b/nunchaku/ops/fused.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58a30aae6c12d505d290f175ea8478ecf11e771
--- /dev/null
+++ b/nunchaku/ops/fused.py
@@ -0,0 +1,178 @@
+"""
+High-performance fused operators for quantized neural network inference.
+"""
+
+import torch
+from torch.nn import RMSNorm
+
+from nunchaku.models.linear import SVDQW4A4Linear
+
+from ..utils import ceil_divide
+from .gemm import svdq_gemm_w4a4_cuda
+
+
+def fused_gelu_mlp(x: torch.Tensor, fc1: SVDQW4A4Linear, fc2: SVDQW4A4Linear, pad_size: int = 256) -> torch.Tensor:
+ """
+ Fused quantized MLP with GELU activation.
+
+ Combines the first quantized linear layer, GELU activation, and the second quantized linear layer into a single CUDA kernel. Supports INT4 and NVFP4 quantization.
+
+ Parameters
+ ----------
+ x : torch.Tensor, shape (B, S, C_in), dtype float16 or bfloat16
+ Input tensor.
+ fc1 : SVDQW4A4Linear
+ First quantized linear layer (input → hidden).
+ fc2 : SVDQW4A4Linear
+ Second quantized linear layer (hidden → output).
+ pad_size : int, optional
+ Batch padding size for CUDA kernel efficiency. Default is 256.
+
+ Returns
+ -------
+ torch.Tensor, shape (B, S, C_out), dtype as input
+ Output tensor.
+
+ Notes
+ -----
+ - Notations:
+
+ - B: batch size
+ - S: sequence length
+ - C_in: input features
+ - C_out: output features
+ - For INT4 quantization, GELU activations are shifted by 0.171875 to ensure non-negativity, enabling unsigned quantization for improved quality. See: https://github.com/nunchaku-tech/nunchaku/blob/433f0b228a61a53fb700ac676fd2e290368ac94d/src/kernels/zgemm/gemm_w4a4_launch_impl.cuh#L286
+ """
+ batch_size, seq_len, channels = x.shape
+ x = x.view(batch_size * seq_len, channels)
+ quantized_x, ascales, lora_act = fc1.quantize(x)
+
+ batch_size_pad = ceil_divide(batch_size * seq_len, pad_size) * pad_size
+
+ qout_act = torch.empty(batch_size_pad, fc1.out_features // 2, dtype=torch.uint8, device=x.device)
+ if fc2.precision == "nvfp4":
+ qout_ascales = torch.empty(fc1.out_features // 16, batch_size_pad, dtype=torch.float8_e4m3fn, device=x.device)
+ else:
+ qout_ascales = torch.empty(fc1.out_features // 64, batch_size_pad, dtype=x.dtype, device=x.device)
+ qout_lora_act = torch.empty(batch_size_pad, fc2.proj_down.shape[1], dtype=torch.float32, device=x.device)
+
+ svdq_gemm_w4a4_cuda(
+ act=quantized_x,
+ wgt=fc1.qweight,
+ qout=qout_act,
+ ascales=ascales,
+ wscales=fc1.wscales,
+ oscales=qout_ascales,
+ lora_act_in=lora_act,
+ lora_up=fc1.proj_up,
+ lora_down=fc2.proj_down,
+ lora_act_out=qout_lora_act,
+ bias=fc1.bias,
+ smooth_factor=fc2.smooth_factor,
+ fp4=fc1.precision == "nvfp4",
+ alpha=fc1.wtscale,
+ wcscales=fc1.wcscales,
+ )
+ output = torch.empty(batch_size * seq_len, fc2.out_features, dtype=x.dtype, device=x.device)
+ output = fc2.forward_quant(qout_act, qout_ascales, qout_lora_act, output=output)
+ output = output.view(batch_size, seq_len, -1)
+ return output
+
+
+def fused_qkv_norm_rottary(
+ x: torch.Tensor,
+ proj: SVDQW4A4Linear,
+ norm_q: RMSNorm,
+ norm_k: RMSNorm,
+ rotary_emb: torch.Tensor,
+ output: torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ attn_tokens: int = 0,
+) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Fused quantized QKV projection with RMSNorm and rotary embeddings.
+
+ Performs quantized QKV projection, applies RMS normalization to Q and K, and fuses rotary embeddings in a single CUDA kernel call.
+
+ Parameters
+ ----------
+ x : torch.Tensor, shape (B, S, C_in), dtype float16 or bfloat16
+ Input tensor.
+ proj : SVDQW4A4Linear
+ Quantized QKV projection layer.
+ norm_q : RMSNorm
+ RMSNorm for query.
+ norm_k : RMSNorm
+ RMSNorm for key.
+ rotary_emb : torch.Tensor
+ Packed rotary embedding tensor (see :func:`~nunchaku.models.embeddings.pack_rotemb`).
+ output : torch.Tensor or tuple of torch.Tensor, optional
+ Output tensor(s). If None, a new tensor is allocated.
+ If tuple, should be (output_q, output_k, output_v) for fused attention packing.
+ attn_tokens : int, optional
+ Number of attention tokens. Default is 0.
+
+ Returns
+ -------
+ torch.Tensor or tuple of torch.Tensor
+ Output tensor of shape (B, S, C_out), or tuple (output_q, output_k, output_v).
+
+ Notes
+ -----
+ Notations:
+ - B: batch size
+ - S: sequence length
+ - C_in: input features
+ - C_out: output features
+ """
+ assert isinstance(norm_q, RMSNorm)
+ assert isinstance(norm_k, RMSNorm)
+
+ batch_size, seq_len, channels = x.shape
+ x = x.view(batch_size * seq_len, channels)
+ quantized_x, ascales, lora_act = proj.quantize(x)
+
+ if output is None:
+ output = torch.empty(quantized_x.shape[0], proj.out_features, dtype=x.dtype, device=x.device)
+
+ if isinstance(output, tuple):
+ assert len(output) == 3
+ output_q, output_k, output_v = output
+ svdq_gemm_w4a4_cuda(
+ act=quantized_x,
+ wgt=proj.qweight,
+ ascales=ascales,
+ wscales=proj.wscales,
+ lora_act_in=lora_act,
+ lora_up=proj.proj_up,
+ bias=proj.bias,
+ fp4=proj.precision == "nvfp4",
+ alpha=proj.wtscale,
+ wcscales=proj.wcscales,
+ norm_q=norm_q.weight,
+ norm_k=norm_k.weight,
+ rotary_emb=rotary_emb,
+ out_q=output_q,
+ out_k=output_k,
+ out_v=output_v,
+ attn_tokens=attn_tokens,
+ )
+ return output_q, output_k, output_v
+ else:
+ svdq_gemm_w4a4_cuda(
+ act=quantized_x,
+ wgt=proj.qweight,
+ out=output,
+ ascales=ascales,
+ wscales=proj.wscales,
+ lora_act_in=lora_act,
+ lora_up=proj.proj_up,
+ bias=proj.bias,
+ fp4=proj.precision == "nvfp4",
+ alpha=proj.wtscale,
+ wcscales=proj.wcscales,
+ norm_q=norm_q.weight,
+ norm_k=norm_k.weight,
+ rotary_emb=rotary_emb,
+ )
+ output = output.view(batch_size, seq_len, -1)
+ return output
diff --git a/nunchaku/ops/gemm.py b/nunchaku/ops/gemm.py
new file mode 100644
index 0000000000000000000000000000000000000000..be0b39f9fd18fd7c9a8b2b198f568aa6c61fc9bd
--- /dev/null
+++ b/nunchaku/ops/gemm.py
@@ -0,0 +1,160 @@
+"""
+Python wrappers for Nunchaku's high-performance quantized GEMM (General Matrix-Matrix Multiplication) CUDA kernels.
+"""
+
+import math
+
+import torch
+
+from .._C import ops
+
+
+def svdq_gemm_w4a4_cuda(
+ act: torch.Tensor,
+ wgt: torch.Tensor,
+ out: torch.Tensor | None = None,
+ qout: torch.Tensor | None = None,
+ ascales: torch.Tensor | None = None,
+ wscales: torch.Tensor | None = None,
+ oscales: torch.Tensor | None = None,
+ poolout: torch.Tensor | None = None,
+ lora_act_in: torch.Tensor | None = None,
+ lora_up: torch.Tensor | None = None,
+ lora_down: torch.Tensor | None = None,
+ lora_act_out: torch.Tensor | None = None,
+ norm_q: torch.Tensor | None = None,
+ norm_k: torch.Tensor | None = None,
+ rotary_emb: torch.Tensor | None = None,
+ bias: torch.Tensor | None = None,
+ smooth_factor: torch.Tensor | None = None,
+ out_vk: torch.Tensor | None = None,
+ out_linearattn: torch.Tensor | None = None,
+ act_unsigned: bool = False,
+ lora_scales: list[float] | None = None,
+ fuse_silu: bool = False,
+ fp4: bool = False,
+ alpha: float | None = 1.0,
+ wcscales: torch.Tensor | None = None,
+ out_q: torch.Tensor | None = None,
+ out_k: torch.Tensor | None = None,
+ out_v: torch.Tensor | None = None,
+ attn_tokens: int = 0,
+):
+ """
+ Quantized GEMM using SVDQuant W4A4 CUDA kernel, with support for LoRA, rotary embeddings, normalization, and fused activations.
+
+ Parameters
+ ----------
+ act : torch.Tensor, shape (M, K // 2), dtype int8
+ Packed input activations.
+ wgt : torch.Tensor, shape (N, K // 2), dtype int8
+ Packed quantized weights.
+ out : torch.Tensor or None, shape (M, N), dtype float16 or bfloat16, optional
+ Output tensor for the linear layer.
+ qout : torch.Tensor or None, shape (M, N // 2), dtype int8, optional
+ Packed quantized input for the next layer.
+ ascales : torch.Tensor or None, shape (K // G, M), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
+ Activation scales.
+ wscales : torch.Tensor or None, shape (K // G, N), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
+ Weight scales.
+ oscales : torch.Tensor or None, shape (N // G, M), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
+ Output scales.
+ poolout : torch.Tensor or None, optional
+ Reserved for future use.
+ lora_act_in : torch.Tensor or None, shape (M, R), dtype float32, optional
+ LoRA down-projection activations.
+ lora_up : torch.Tensor or None, shape (N, R), dtype float16 or bfloat16, optional
+ Packed LoRA up-projection weights.
+ lora_down : torch.Tensor or None, shape (N, R), dtype float16 or bfloat16, optional
+ Packed LoRA down-projection weights for the next layer.
+ lora_act_out : torch.Tensor or None, shape (M, R), dtype float32, optional
+ Output for LoRA down-projection in the next layer.
+ norm_q : torch.Tensor or None, shape (HEAD_DIM,), dtype float16 or bfloat16, optional
+ Query RMS normalization.
+ norm_k : torch.Tensor or None, shape (HEAD_DIM,), dtype float16 or bfloat16, optional
+ Key RMS normalization.
+ rotary_emb : torch.Tensor or None, shape (M, HEAD_DIM // 2, 2, 2), dtype float32, optional
+ Packed rotary embeddings.
+ bias : torch.Tensor or None, shape (N,), dtype float16 or bfloat16, optional
+ Bias tensor.
+ smooth_factor : torch.Tensor or None, shape (N,), dtype float16 or bfloat16, optional
+ Smoothing factor for quantization in the next layer.
+ out_vk : torch.Tensor or None, optional
+ Used only in SANA. Leave as None.
+ out_linearattn : torch.Tensor or None, optional
+ Used only in SANA. Leave as None.
+ act_unsigned : bool, default=False
+ If True, activations are unsigned (e.g., after GeLU, shifted by 0.171875). This is only used for INT4 to enable unsigned INT4 activation quantization for better quantization quality.
+ lora_scales : list of float or None, optional
+ Per-group LoRA scaling factors (16 channels per group). Defaults to 1.0 per group.
+ fuse_silu : bool, default=False
+ If True, fuse SiLU activation.
+ fp4 : bool, default=False
+ If True, use 4-bit floating point quantization (NVFP4).
+ alpha : float or None, default=1.0
+ Per-tensor scaling factor for NVFP4.
+ wcscales : torch.Tensor or None, shape (N,), dtype float8_e4m3fn, optional
+ Per-channel scaling for NVFP4.
+ out_q : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
+ Packed quantized Q for attention (used in ``nunchaku-fp16`` attention).
+ out_k : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
+ Packed quantized K for attention (used in ``nunchaku-fp16`` attention).
+ out_v : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
+ Packed quantized V for attention (used in ``nunchaku-fp16`` attention).
+ attn_tokens : int, default=0
+ Number of attention tokens.
+
+ Returns
+ -------
+ None
+ Results are written in-place to the provided output tensors.
+
+ Notes
+ -----
+ Notations:
+
+ - M: batch size (input tokens)
+ - K: input channels (feature dimension)
+ - N: output channels
+ - G: group size (64 for INT4, 16 for NVFP4)
+ - R: LoRA rank
+ - B: batch size for attention
+ - H: number of heads
+ - D: head dimension
+ """
+ if lora_scales is None:
+ rank = lora_up.shape[1]
+ lora_scales = [1.0] * math.ceil(rank / 16)
+ if alpha is None:
+ alpha = 1.0
+ ops.gemm_w4a4(
+ act,
+ wgt,
+ out,
+ qout,
+ ascales,
+ wscales,
+ oscales,
+ poolout,
+ lora_act_in,
+ lora_up,
+ lora_down,
+ lora_act_out,
+ norm_q,
+ norm_k,
+ rotary_emb,
+ bias,
+ smooth_factor,
+ out_vk,
+ out_linearattn,
+ act_unsigned,
+ lora_scales,
+ fuse_silu,
+ fp4,
+ alpha,
+ wcscales,
+ out_q,
+ out_k,
+ out_v,
+ attn_tokens,
+ )
diff --git a/nunchaku/ops/gemv.py b/nunchaku/ops/gemv.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0658ba9264ae668ed460ade8ff4987b668161bc
--- /dev/null
+++ b/nunchaku/ops/gemv.py
@@ -0,0 +1,56 @@
+"""
+Python wrapper for Nunchaku's high-performance GEMV (General Matrix-Vector Multiplication) CUDA kernels.
+"""
+
+import torch
+
+from .._C import ops
+
+
+def awq_gemv_w4a16_cuda(
+ in_feats: torch.Tensor,
+ kernel: torch.Tensor,
+ scaling_factors: torch.Tensor,
+ zeros: torch.Tensor,
+ m: int,
+ n: int,
+ k: int,
+ group_size: int = 64,
+) -> torch.Tensor:
+ """
+ Performs quantized GEMV using the AWQ W4A16 format.
+
+ Parameters
+ ----------
+ in_feats : torch.Tensor, shape (k,) or (m, k), dtype float16 or bfloat16
+ Input feature vector or batch of vectors.
+ kernel : torch.Tensor, shape (n // 4, k // 2), dtype int32
+ Packed quantized weight matrix.
+ scaling_factors : torch.Tensor, shape (k // group_size, n), dtype float16 or bfloat16
+ Per-group scaling factors.
+ zeros : torch.Tensor, shape (k // group_size, n), dtype float16 or bfloat16
+ Per-group zero points.
+ m : int
+ Batch size (number of input vectors).
+ n : int
+ Output feature dimension.
+ k : int
+ Input feature dimension.
+ group_size : int, optional
+ Number of input channels per quantization group. Default is 64.
+
+ Returns
+ -------
+ torch.Tensor, shape (m, n), dtype float16 or bfloat16
+ Output tensor.
+
+ Notes
+ -----
+ Notations:
+
+ - m: batch size
+ - n: output features
+ - k: input features
+ - group_size: quantization group size
+ """
+ return ops.gemv_awq(in_feats, kernel, scaling_factors, zeros, m, n, k, group_size)
diff --git a/nunchaku/ops/quantize.py b/nunchaku/ops/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdebb68f8af73af6bcecc7f72b463d067ab97a3f
--- /dev/null
+++ b/nunchaku/ops/quantize.py
@@ -0,0 +1,81 @@
+"""
+This module provides Python wrappers for Nunchaku's high-performance SVDQuant quantization CUDA kernels.
+"""
+
+import torch
+
+from .._C import ops
+from ..utils import ceil_divide
+
+
+def svdq_quantize_w4a4_act_fuse_lora_cuda(
+ input: torch.Tensor,
+ output: torch.Tensor | None = None,
+ oscales: torch.Tensor | None = None,
+ lora_down: torch.Tensor | None = None,
+ lora_act_out: torch.Tensor | None = None,
+ smooth: torch.Tensor | None = None,
+ fuse_glu: bool = False,
+ fp4: bool = False,
+ pad_size: int = 256,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Quantizes activations and computes LoRA down-projection using SVDQuant W4A4 CUDA kernel.
+
+ Parameters
+ ----------
+ input : torch.Tensor, shape (M, K), dtype bfloat16/float16
+ Input activations.
+ output : torch.Tensor or None, shape (M_pad, K // 2), dtype uint8, optional
+ Packed output tensor for quantized activations. Allocated if None.
+ oscales : torch.Tensor or None, shape (K // G, M_pad), dtype float8_e4m3fn for NVFP4 or input dtype for INT4, optional
+ Output scales tensor. Allocated if None.
+ lora_down : torch.Tensor or None, shape (K, R), dtype bfloat16/float16, optional
+ Packed LoRA down-projection weights.
+ lora_act_out : torch.Tensor or None, shape (M_pad, R), dtype float32, optional
+ Packed output tensor for LoRA activations. Allocated if None.
+ smooth : torch.Tensor or None, optional, dtype bfloat16/float16
+ Smoothing factor for quantization.
+ fuse_glu : bool, default=False
+ If True, fuse GLU activation.
+ fp4 : bool, default=False
+ If True, use NVFP4 quantization; else INT4.
+ pad_size : int, default=256
+ Pad batch size to a multiple of this value for efficient CUDA execution.
+
+ Returns
+ -------
+ output : torch.Tensor, shape (M_pad, K // 2), dtype uint8
+ Packed quantized activations.
+ oscales : torch.Tensor, shape (K // G, M_pad), dtype float8_e4m3fn for NVFP4 or input dtype for INT4
+ Output scales.
+ lora_act_out : torch.Tensor, shape (M_pad, R), dtype float32
+ Packed LoRA activation output.
+
+ Notes
+ -----
+ Notations:
+
+ - M: batch size
+ - K: input channels
+ - R: LoRA rank
+ - G: group size (64 for INT4, 16 for NVFP4)
+ - M_pad: padded batch size = ceil(M / pad_size) * pad_size
+ """
+ batch_size, channels = input.shape
+ rank = lora_down.shape[1]
+ batch_size_pad = ceil_divide(batch_size, pad_size) * pad_size
+ if output is None:
+ output = torch.empty(batch_size_pad, channels // 2, dtype=torch.uint8, device=input.device)
+ if oscales is None:
+ if fp4:
+ assert channels % 16 == 0
+ oscales = torch.empty(channels // 16, batch_size_pad, dtype=torch.float8_e4m3fn, device=input.device)
+ else:
+ assert channels % 64 == 0
+ oscales = torch.empty(channels // 64, batch_size_pad, dtype=input.dtype, device=input.device)
+ if lora_act_out is None:
+ lora_act_out = torch.empty(batch_size_pad, rank, dtype=torch.float32, device=input.device)
+
+ ops.quantize_w4a4_act_fuse_lora(input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4)
+ return output, oscales, lora_act_out
diff --git a/nunchaku/utils.py b/nunchaku/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dee226fcd6ab188e8a34837808d76cc35f241f3
--- /dev/null
+++ b/nunchaku/utils.py
@@ -0,0 +1,366 @@
+"""
+Utility functions for Nunchaku.
+"""
+
+import hashlib
+import os
+import warnings
+from pathlib import Path
+from typing import Any
+
+import safetensors
+import torch
+from huggingface_hub import hf_hub_download
+from torch import nn
+
+
+def pad_tensor(tensor: torch.Tensor | None, multiples: int, dim: int, fill: Any = 0) -> torch.Tensor | None:
+ """
+ Pad a tensor along a given dimension to the next multiple of a specified value.
+
+ Parameters
+ ----------
+ tensor : torch.Tensor or None
+ Input tensor. If None, returns None.
+ multiples : int
+ Pad to this multiple. If <= 1, no padding is applied.
+ dim : int
+ Dimension along which to pad.
+ fill : Any, optional
+ Value to use for padding (default: 0).
+
+ Returns
+ -------
+ torch.Tensor or None
+ The padded tensor, or None if input was None.
+ """
+ if multiples <= 1:
+ return tensor
+ if tensor is None:
+ return None
+ shape = list(tensor.shape)
+ if shape[dim] % multiples == 0:
+ return tensor
+ shape[dim] = ceil_divide(shape[dim], multiples) * multiples
+ result = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
+ result.fill_(fill)
+ result[[slice(0, extent) for extent in tensor.shape]] = tensor
+ return result
+
+
+def sha256sum(filepath: str | os.PathLike[str]) -> str:
+ """
+ Compute the SHA-256 checksum of a file.
+
+ Parameters
+ ----------
+ filepath : str or os.PathLike
+ Path to the file.
+
+ Returns
+ -------
+ str
+ The SHA-256 hexadecimal digest of the file.
+ """
+ sha256 = hashlib.sha256()
+ with open(filepath, "rb") as f:
+ for chunk in iter(lambda: f.read(8192), b""):
+ sha256.update(chunk)
+ return sha256.hexdigest()
+
+
+def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
+ """
+ Fetch a file from a local path or download from HuggingFace Hub if not present.
+
+ The remote path should be in the format: ``/`` or ``//``.
+
+ Parameters
+ ----------
+ path : str or Path
+ Local file path or HuggingFace Hub path.
+ repo_type : str, optional
+ Type of HuggingFace repo (default: "model").
+
+ Returns
+ -------
+ Path
+ Path to the local file.
+
+ Raises
+ ------
+ ValueError
+ If the path is too short to extract repo_id and subfolder.
+ """
+ path = Path(path)
+
+ if path.exists():
+ return path
+
+ parts = path.parts
+ if len(parts) < 3:
+ raise ValueError(f"Path '{path}' is too short to extract repo_id and subfolder")
+
+ repo_id = "/".join(parts[:2])
+ sub_path = Path(*parts[2:])
+ filename = sub_path.name
+ subfolder = str(sub_path.parent) if sub_path.parent != Path(".") else None
+
+ path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type)
+ return Path(path)
+
+
+def ceil_divide(x: int, divisor: int) -> int:
+ """
+ Compute the ceiling of x divided by divisor.
+
+ Parameters
+ ----------
+ x : int
+ Dividend.
+ divisor : int
+ Divisor.
+
+ Returns
+ -------
+ int
+ The smallest integer >= x / divisor.
+ """
+ return (x + divisor - 1) // divisor
+
+
+def load_state_dict_in_safetensors(
+ path: str | os.PathLike[str],
+ device: str | torch.device = "cpu",
+ filter_prefix: str = "",
+ return_metadata: bool = False,
+) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], dict[str, str]]:
+ """
+ Load a state dict from a safetensors file, optionally filtering by prefix.
+
+ Parameters
+ ----------
+ path : str or os.PathLike
+ Path to the safetensors file (local or HuggingFace Hub).
+ device : str or torch.device, optional
+ Device to load tensors onto (default: "cpu").
+ filter_prefix : str, optional
+ Only load keys starting with this prefix (default: "", no filter).
+ return_metadata : bool, optional
+ Whether to return safetensors metadata (default: False).
+
+ Returns
+ -------
+ dict[str, torch.Tensor] or tuple[dict[str, torch.Tensor], dict[str, str]]
+ The loaded state dict, and optionally the metadata if ``return_metadata`` is True.
+ """
+ state_dict = {}
+ with safetensors.safe_open(fetch_or_download(path), framework="pt", device=device) as f:
+ metadata = f.metadata()
+ for k in f.keys():
+ if filter_prefix and not k.startswith(filter_prefix):
+ continue
+ state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
+ if return_metadata:
+ return state_dict, metadata
+ else:
+ return state_dict
+
+
+def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str = "") -> dict[str, torch.Tensor]:
+ """
+ Filter a state dict to only include keys starting with a given prefix.
+
+ Parameters
+ ----------
+ state_dict : dict[str, torch.Tensor]
+ The input state dict.
+ filter_prefix : str, optional
+ Prefix to filter keys by (default: "", no filter).
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ Filtered state dict with prefix removed from keys.
+ """
+ return {k.removeprefix(filter_prefix): v for k, v in state_dict.items() if k.startswith(filter_prefix)}
+
+
+def get_precision(
+ precision: str = "auto",
+ device: str | torch.device = "cuda",
+ pretrained_model_name_or_path: str | os.PathLike[str] | None = None,
+) -> str:
+ """
+ Determine the quantization precision to use based on device and model.
+
+ Parameters
+ ----------
+ precision : str, optional
+ "auto", "int4", or "fp4" (default: "auto").
+ device : str or torch.device, optional
+ Device to check (default: "cuda").
+ pretrained_model_name_or_path : str or os.PathLike or None, optional
+ Model name or path for warning checks.
+
+ Returns
+ -------
+ str
+ The selected precision ("int4" or "fp4").
+
+ Raises
+ ------
+ AssertionError
+ If precision is not one of "auto", "int4", or "fp4".
+ """
+ assert precision in ("auto", "int4", "fp4")
+ if precision == "auto":
+ if isinstance(device, str):
+ device = torch.device(device)
+ capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
+ sm = f"{capability[0]}{capability[1]}"
+ precision = "fp4" if sm == "120" else "int4"
+ if pretrained_model_name_or_path is not None:
+ if precision == "int4":
+ if "fp4" in str(pretrained_model_name_or_path):
+ warnings.warn("The model may be quantized to fp4, but you are loading it with int4 precision.")
+ elif precision == "fp4":
+ if "int4" in str(pretrained_model_name_or_path):
+ warnings.warn("The model may be quantized to int4, but you are loading it with fp4 precision.")
+ return precision
+
+
+def is_turing(device: str | torch.device = "cuda") -> bool:
+ """
+ Check if the current GPU is a Turing GPU (compute capability 7.5).
+
+ Parameters
+ ----------
+ device : str or torch.device, optional
+ Device to check (default: "cuda").
+
+ Returns
+ -------
+ bool
+ True if the current GPU is a Turing GPU, False otherwise.
+ """
+ if isinstance(device, str):
+ device = torch.device(device)
+ device_id = 0 if device.index is None else device.index
+ capability = torch.cuda.get_device_capability(device_id)
+ sm = f"{capability[0]}{capability[1]}"
+ return sm == "75"
+
+
+def get_gpu_memory(device: str | torch.device = "cuda", unit: str = "GiB") -> int:
+ """
+ Get the total memory of the current GPU.
+
+ Parameters
+ ----------
+ device : str or torch.device, optional
+ Device to check (default: "cuda").
+ unit : str, optional
+ Unit for memory ("GiB", "MiB", or "B") (default: "GiB").
+
+ Returns
+ -------
+ int
+ GPU memory in the specified unit.
+
+ Raises
+ ------
+ AssertionError
+ If unit is not one of "GiB", "MiB", or "B".
+ """
+ if isinstance(device, str):
+ device = torch.device(device)
+ assert unit in ("GiB", "MiB", "B")
+ memory = torch.cuda.get_device_properties(device).total_memory
+ if unit == "GiB":
+ return memory // (1024**3)
+ elif unit == "MiB":
+ return memory // (1024**2)
+ else:
+ return memory
+
+
+def check_hardware_compatibility(quantization_config: dict, device: str | torch.device = "cuda"):
+ """
+ Check if the quantization config is compatible with the current GPU.
+
+ Parameters
+ ----------
+ quantization_config : dict
+ Quantization configuration dictionary.
+ device : str or torch.device, optional
+ Device to check (default: "cuda").
+
+ Raises
+ ------
+ ValueError
+ If the quantization config is not compatible with the GPU architecture.
+ """
+ if isinstance(device, str):
+ device = torch.device(device)
+ capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
+ sm = f"{capability[0]}{capability[1]}"
+ if sm == "120": # you can only use the fp4 models
+ if quantization_config["weight"]["dtype"] != "fp4_e2m1_all":
+ raise ValueError('Please use "fp4" quantization for Blackwell GPUs. ')
+ elif sm in ["75", "80", "86", "89"]:
+ if quantization_config["weight"]["dtype"] != "int4":
+ raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs. ')
+ else:
+ raise ValueError(
+ f"Unsupported GPU architecture {sm} due to the lack of 4-bit tensorcores. "
+ "Please use a Turing, Ampere, Ada or Blackwell GPU for this quantization configuration."
+ )
+
+
+def get_precision_from_quantization_config(quantization_config: dict) -> str:
+ """
+ Get the precision from the quantization configuration.
+ """
+ if quantization_config["weight"]["dtype"] == "fp4_e2m1_all":
+ if quantization_config["weight"]["group_size"] == 16:
+ return "nvfp4"
+ else:
+ raise ValueError("Currently, nunchaku only supports nvfp4.")
+ elif quantization_config["weight"]["dtype"] == "int4":
+ return "int4"
+ else:
+ raise ValueError(f"Unsupported quantization dtype: {quantization_config['weight']['dtype']}")
+
+
+def copy_params_into(src: nn.Module, dst: nn.Module, non_blocking: bool = True):
+ """
+ Copy all parameters and buffers from a source module to a destination module.
+
+ Parameters
+ ----------
+ src : nn.Module
+ Source module from which parameters and buffers are copied.
+ dst : nn.Module
+ Destination module to which parameters and buffers are copied.
+ non_blocking : bool, optional
+ If True, copies are performed asynchronously with respect to the host if possible (default: True).
+
+ Notes
+ -----
+ - The function assumes that `src` and `dst` have the same structure and number of parameters and buffers.
+ - All copying is performed under `torch.no_grad()` context to avoid tracking in autograd.
+ """
+ with torch.no_grad():
+ for ps, pd in zip(src.parameters(), dst.parameters()):
+ pd.copy_(ps, non_blocking=non_blocking)
+ for bs, bd in zip(src.buffers(), dst.buffers()):
+ bd.copy_(bs, non_blocking=non_blocking)
+
+ for ms, md in zip(src.modules(), dst.modules()):
+ # wtscale is a special case which is a float on the CPU
+ if hasattr(ms, "wtscale"):
+ assert hasattr(md, "wtscale")
+ md.wtscale = ms.wtscale
+ else:
+ assert not hasattr(md, "wtscale")
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..12ae3ea59f3c09b43b9e1bab9d2296cad4b1fc5a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,51 @@
+[tool.isort]
+profile = "black"
+known_first_party = ["nunchaku"]
+line_length = 120
+
+[tool.black]
+line-length = 120
+target-version = ['py311']
+
+[tool.ruff]
+line-length = 120
+
+[project]
+dynamic = ["version"]
+name = "flux-kontext"
+description = "Optimized FLUX-Kontext implementation using quantization and acceleration techniques"
+dependencies = [
+ "diffusers>=0.35.1",
+ "transformers>=4.53.3",
+ "accelerate>=1.9.0",
+ "sentencepiece",
+ "protobuf",
+ "huggingface_hub>=0.34",
+ "peft>=0.17",
+ "einops",
+ "torch>=2.5",
+ "gradio",
+ "pillow",
+]
+requires-python = ">=3.10"
+
+[build-system]
+requires = [
+ "setuptools",
+ "torch>=2.5",
+ "wheel",
+ "ninja",
+]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+include = ["nunchaku"]
+
+[tool.doc8]
+max-line-length = 120
+ignore-path = ["docs/_build"]
+ignore = ["D000", "D001"]
+
+[tool.rstcheck]
+ignore_directives = ["tabs"]
+ignore_messages = ["ERROR/3", "INFO/1"]
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c3d6bf5934114fff3bc69abfd454015d595e7cff
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+diffusers>=0.35.1
+transformers>=4.53.3
+accelerate>=1.9.0
+sentencepiece
+protobuf
+huggingface_hub>=0.34
+peft>=0.17
+einops
+torch>=2.5
+gradio
+pillow
+safetensors
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..f285876d200127e145991a59ee04be7a2377a7d4
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,108 @@
+import os
+import re
+import subprocess
+import sys
+from datetime import date
+
+import setuptools
+import torch
+from packaging import version as packaging_version
+from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
+
+
+class CustomBuildExtension(BuildExtension):
+ def build_extensions(self):
+ for ext in self.extensions:
+ if not "cxx" in ext.extra_compile_args:
+ ext.extra_compile_args["cxx"] = []
+ if not "nvcc" in ext.extra_compile_args:
+ ext.extra_compile_args["nvcc"] = []
+ if self.compiler.compiler_type == "msvc":
+ ext.extra_compile_args["cxx"] += ext.extra_compile_args["msvc"]
+ ext.extra_compile_args["nvcc"] += ext.extra_compile_args["nvcc_msvc"]
+ else:
+ ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
+ super().build_extensions()
+
+
+def get_sm_targets() -> list[str]:
+ nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc"
+ try:
+ nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode()
+ match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
+ if match:
+ nvcc_version = match.group(2)
+ else:
+ raise Exception("nvcc version not found")
+ print(f"Found nvcc version: {nvcc_version}")
+ except:
+ raise Exception("nvcc not found")
+
+ support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8")
+
+ install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")
+ if install_mode == "FAST":
+ ret = []
+ for i in range(torch.cuda.device_count()):
+ capability = torch.cuda.get_device_capability(i)
+ sm = f"{capability[0]}{capability[1]}"
+ if sm == "120" and support_sm120:
+ sm = "120a"
+ ret.append(sm)
+ return ret
+ elif install_mode == "ALL":
+ # All supported architectures (except for experimental ones)
+ sm_targets = ["75", "80", "86", "89", "90"]
+ if support_sm120:
+ sm_targets.append("120a")
+ return sm_targets
+ else:
+ raise ValueError(f"Unknown install mode: {install_mode}")
+
+
+FLUX_SOURCES = [
+ "nunchaku/csrc/pybind.cpp",
+]
+
+ext_modules = []
+
+# Check if CUDA is available
+if torch.cuda.is_available() and CUDA_HOME is not None:
+ sm_targets = get_sm_targets()
+ arch_flags = [f"-gencode=arch=compute_{sm},code=sm_{sm}" for sm in sm_targets]
+
+ ext_modules.append(
+ CUDAExtension(
+ "nunchaku._C",
+ FLUX_SOURCES,
+ extra_compile_args={
+ "cxx": ["-O3", "-std=c++20"],
+ "nvcc": [
+ "-O3",
+ "-std=c++20",
+ "--use_fast_math",
+ "-U__CUDA_NO_HALF_OPERATORS__",
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+ "-U__CUDA_NO_HALF2_OPERATORS__",
+ ] + arch_flags,
+ "msvc": ["/std:c++20"],
+ "gcc": ["-std=c++20"],
+ "nvcc_msvc": [],
+ },
+ include_dirs=[
+ "third_party/cutlass/include",
+ "third_party/cutlass/tools/util/include",
+ ],
+ )
+ )
+else:
+ print("CUDA not available. Installing CPU-only version.")
+
+setuptools.setup(
+ name="flux-kontext",
+ packages=setuptools.find_packages(),
+ ext_modules=ext_modules,
+ cmdclass={"build_ext": CustomBuildExtension},
+ zip_safe=False,
+)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1d80bfb08f3ce1206751a84212253dc7719765b
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# FLUX-Kontext test suite
diff --git a/tests/test_flux_kontext.py b/tests/test_flux_kontext.py
new file mode 100644
index 0000000000000000000000000000000000000000..557ca260c09c328cdeb59af8064dfe0a4364f6c8
--- /dev/null
+++ b/tests/test_flux_kontext.py
@@ -0,0 +1,94 @@
+import gc
+import os
+from pathlib import Path
+
+import pytest
+import torch
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+
+from nunchaku import NunchakuFluxTransformer2dModel
+from nunchaku.utils import get_precision, is_turing
+
+from .utils import already_generate, compute_lpips, hash_str_to_int, offload_pipeline
+
+
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
+@pytest.mark.parametrize("expected_lpips", [0.25 if get_precision() == "int4" else 0.18])
+def test_flux_kontext(expected_lpips: float):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ precision = get_precision()
+
+ ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
+ results_dir_16_bit = ref_root / "bf16" / "flux.1-kontext-dev" / "kontext"
+ results_dir_4_bit = Path("test_results") / precision / "flux.1-kontext-dev" / "kontext"
+
+ os.makedirs(results_dir_16_bit, exist_ok=True)
+ os.makedirs(results_dir_4_bit, exist_ok=True)
+
+ image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+ ).convert("RGB")
+ prompts = [
+ "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors",
+ "Convert the image to ghibli style",
+ "help me convert it to manga style",
+ "Convert it to a realistic photo",
+ ]
+
+ # First, generate results with the 16-bit model
+ if not already_generate(results_dir_16_bit, 4):
+ pipeline = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ )
+
+ # Possibly offload the model to CPU when GPU memory is scarce
+ pipeline = offload_pipeline(pipeline)
+
+ for prompt in prompts:
+ seed = hash_str_to_int(prompt)
+ result = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(seed)).images[0]
+ result.save(os.path.join(results_dir_16_bit, f"{seed}.png"))
+
+ # Clean up the 16-bit model
+ del pipeline.transformer
+ del pipeline.text_encoder
+ del pipeline.text_encoder_2
+ del pipeline.vae
+ del pipeline
+ del result
+ gc.collect()
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+
+ free, total = torch.cuda.mem_get_info() # bytes
+ print(f"After 16-bit generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
+
+ # Then, generate results with the 4-bit model
+ if not already_generate(results_dir_4_bit, 4):
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+ f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{precision}_r32-flux.1-kontext-dev.safetensors"
+ )
+ pipeline = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
+ ).to("cuda")
+ for prompt in prompts:
+ seed = hash_str_to_int(prompt)
+ result = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(seed)).images[0]
+ result.save(os.path.join(results_dir_4_bit, f"{seed}.png"))
+
+ # Clean up the 4-bit model
+ del pipeline
+ del transformer
+ gc.collect()
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+
+ free, total = torch.cuda.mem_get_info() # bytes
+ print(f"After 4-bit generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
+
+ lpips = compute_lpips(results_dir_16_bit, results_dir_4_bit)
+ print(f"lpips: {lpips}")
+ assert lpips < expected_lpips * 1.15
diff --git a/tests/test_flux_kontext_lora.py b/tests/test_flux_kontext_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a81370b077236114be4589b0822f42e8d3cecde
--- /dev/null
+++ b/tests/test_flux_kontext_lora.py
@@ -0,0 +1,228 @@
+"""
+Test LoRA functionality for FLUX.1-Kontext model
+"""
+
+import gc
+import os
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+from PIL import Image
+
+from nunchaku import NunchakuFluxTransformer2dModel
+from nunchaku.utils import get_precision, is_turing
+
+
+def compute_pixel_difference(img1_path: str, img2_path: str) -> dict:
+ """Compute pixel-level differences between two images"""
+ img1 = np.array(Image.open(img1_path)).astype(float)
+ img2 = np.array(Image.open(img2_path)).astype(float)
+
+ diff = np.abs(img1 - img2)
+
+ return {
+ "mean_diff": np.mean(diff),
+ "max_diff": np.max(diff),
+ "pixels_changed": np.mean(diff > 0) * 100,
+ "pixels_changed_significantly": np.mean(diff > 10) * 100,
+ }
+
+
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
+def test_kontext_lora_application():
+ """Test that LoRA weights are properly applied to Kontext model"""
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ precision = get_precision()
+
+ # Setup directories
+ results_dir = Path("test_results") / precision / "flux.1-kontext-dev" / "lora_test"
+ os.makedirs(results_dir, exist_ok=True)
+
+ # Load test image
+ image = load_image(
+ "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
+ ).convert("RGB")
+
+ prompt = "neon light, city atmosphere"
+ seed = 42
+ num_inference_steps = 28
+ guidance_scale = 2.5
+
+ # Load model
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+ f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{precision}_r32-flux.1-kontext-dev.safetensors"
+ )
+
+ pipeline = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
+ ).to("cuda")
+
+ # Test 1: Generate without LoRA
+ generator = torch.Generator().manual_seed(seed)
+ result_no_lora = pipeline(
+ image=image,
+ prompt=prompt,
+ generator=generator,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ ).images[0]
+ result_no_lora.save(results_dir / "no_lora.png")
+
+ # Test 2: Apply LoRA and generate
+ transformer.update_lora_params(
+ "nunchaku-tech/nunchaku-test-models/relight-kontext-lora-single-caption_comfy.safetensors"
+ # linoyts/relight-kontext-lora-single-caption/relight-kontext-lora-single-caption.safetensors"
+ )
+ transformer.set_lora_strength(1.0)
+
+ generator = torch.Generator().manual_seed(seed)
+ result_lora_1 = pipeline(
+ image=image,
+ prompt=prompt,
+ generator=generator,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ ).images[0]
+ result_lora_1.save(results_dir / "lora_1.0.png")
+
+ # Test 3: Change LoRA strength
+ transformer.set_lora_strength(2.0)
+
+ generator = torch.Generator().manual_seed(seed)
+ result_lora_2 = pipeline(
+ image=image,
+ prompt=prompt,
+ generator=generator,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ ).images[0]
+ result_lora_2.save(results_dir / "lora_2.0.png")
+
+ # Test 4: Disable LoRA
+ transformer.set_lora_strength(0.0)
+
+ generator = torch.Generator().manual_seed(seed)
+ result_lora_0 = pipeline(
+ image=image,
+ prompt=prompt,
+ generator=generator,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ ).images[0]
+ result_lora_0.save(results_dir / "lora_0.0.png")
+
+ # Compute differences
+ diff_1 = compute_pixel_difference(results_dir / "no_lora.png", results_dir / "lora_1.0.png")
+
+ diff_2 = compute_pixel_difference(results_dir / "no_lora.png", results_dir / "lora_2.0.png")
+
+ diff_0 = compute_pixel_difference(results_dir / "no_lora.png", results_dir / "lora_0.0.png")
+
+ diff_scale = compute_pixel_difference(results_dir / "lora_1.0.png", results_dir / "lora_2.0.png")
+
+ # Assertions
+ # LoRA 1.0 should change the output
+ assert diff_1["mean_diff"] > 1.0, "LoRA 1.0 should significantly change the output"
+ assert diff_1["pixels_changed"] > 50, "LoRA 1.0 should change more than 50% of pixels"
+
+ # LoRA 2.0 should have a significant effect (but not necessarily stronger than 1.0 due to saturation)
+ assert diff_2["mean_diff"] > 1.0, "LoRA 2.0 should significantly change the output"
+
+ # Different LoRA strengths should produce different results
+ assert diff_scale["mean_diff"] > 1.0, "Different LoRA strengths should produce different results"
+
+ # Log the actual differences for debugging
+ print(f"LoRA 1.0 vs baseline difference: {diff_1['mean_diff']:.2f}")
+ print(f"LoRA 2.0 vs baseline difference: {diff_2['mean_diff']:.2f}")
+ print(f"LoRA 1.0 vs 2.0 difference: {diff_scale['mean_diff']:.2f}")
+
+ # Note: We're not asserting that LoRA 0.0 matches baseline due to known issue
+ # where LoRA weights may not be fully removed when strength=0.0
+ print(f"LoRA 0.0 vs baseline difference: {diff_0['mean_diff']:.2f}")
+ if diff_0["mean_diff"] > 1.0:
+ print("WARNING: LoRA 0.0 differs from baseline - LoRA may not be fully disabled")
+
+ # Clean up
+ del pipeline
+ del transformer
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
+@pytest.mark.parametrize(
+ "lora_strength,expected_change",
+ [
+ (0.5, 1.0), # Medium strength should cause moderate change
+ (1.0, 1.5), # Full strength should cause significant change
+ (1.5, 2.0), # Over-strength should cause larger change
+ ],
+)
+def test_kontext_lora_strength_scaling(lora_strength, expected_change):
+ """Test that LoRA strength scaling works proportionally"""
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ precision = get_precision()
+
+ # Load model
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+ f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{precision}_r32-flux.1-kontext-dev.safetensors"
+ )
+
+ pipeline = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
+ ).to("cuda")
+
+ # Load test image
+ image = load_image(
+ "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
+ ).convert("RGB")
+
+ prompt = "dramatic lighting, cinematic"
+ seed = 123
+
+ # Generate baseline
+ generator = torch.Generator().manual_seed(seed)
+ baseline = pipeline(image=image, prompt=prompt, generator=generator, num_inference_steps=20).images[0]
+
+ transformer.update_lora_params(
+ "nunchaku-tech/nunchaku-test-models/relight-kontext-lora-single-caption_comfy.safetensors"
+ # "linoyts/relight-kontext-lora-single-caption/relight-kontext-lora-single-caption.safetensors"
+ )
+ transformer.set_lora_strength(lora_strength)
+
+ # Generate with LoRA
+ generator = torch.Generator().manual_seed(seed)
+ with_lora = pipeline(image=image, prompt=prompt, generator=generator, num_inference_steps=20).images[0]
+
+ # Compute difference
+ baseline_arr = np.array(baseline).astype(float)
+ lora_arr = np.array(with_lora).astype(float)
+ mean_diff = np.mean(np.abs(baseline_arr - lora_arr))
+
+ # Assert that change is proportional to strength
+ # Allow 50% tolerance due to non-linear effects
+ assert (
+ mean_diff > expected_change * 0.5
+ ), f"LoRA strength {lora_strength} should cause mean difference > {expected_change * 0.5}, got {mean_diff}"
+
+ print(f"LoRA strength {lora_strength}: mean difference = {mean_diff:.2f}")
+
+ # Clean up
+ del pipeline
+ del transformer
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+ test_kontext_lora_application()
+ for strength, expected in [(0.5, 1.0), (1.0, 1.5), (1.5, 2.0)]:
+ test_kontext_lora_strength_scaling(strength, expected)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..35956ee7688adc6fd052930dabc9cbfa95eae889
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,14 @@
+import argparse
+
+
+def get_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use"
+ )
+ parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
+ parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
+ parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
+ parser.add_argument("--gradio-root-path", type=str, default="")
+ args = parser.parse_args()
+ return args