diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dc4dab5710f76ef392a8e80791479e040e335025
--- /dev/null
+++ b/README.md
@@ -0,0 +1,102 @@
+# FLUX-Kontext Optimized Implementation
+
+This package contains an optimized implementation of FLUX-Kontext using quantization and acceleration techniques.
+
+## Features
+
+- **Quantized FLUX Transformer**: Efficient INT4/FP4 quantized implementation of FLUX.1-Kontext
+- **Quantized T5 Encoder**: AWQ INT4 quantized T5 text encoder for memory efficiency
+- **LoRA Support**: Compose, convert, and apply LoRA adapters at inference time
+- **Gradio Web Interface**: Ready-to-use web interface for image editing
+
+## Installation
+
+```bash
+pip install -r requirements.txt
+python setup.py build_ext --inplace
+```
+
+## Quick Start
+
+### Using the Gradio Interface
+
+```bash
+cd app/kontext
+python run_gradio.py --precision int4
+```
+
+### Programmatic Usage
+
+```python
+import torch
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
+
+# Load quantized transformer
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    "mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-int4_r32-flux.1-kontext-dev.safetensors"
+)
+
+# Load quantized text encoder (optional)
+text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+)
+
+# Create pipeline
+pipeline = FluxKontextPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-Kontext-dev",
+    transformer=transformer,
+    text_encoder_2=text_encoder_2,
+    torch_dtype=torch.bfloat16,
+)
+pipeline = pipeline.to("cuda")
+
+# Load the image to edit
+input_image = load_image("path/to/your/image.png")  # any PIL-compatible image
+
+# Generate image
+result = pipeline(
+    prompt="Your prompt here",
+    image=input_image,
+    num_inference_steps=28,
+    guidance_scale=2.5,
+).images[0]
+```
+
+## Available Models
+
+- `int4`: INT4 quantized transformer (default, most memory efficient)
+- `fp4`: FP4 quantized transformer
+- `bf16`: Full-precision BFloat16 (highest quality, highest memory usage)
+
+## Directory Structure
+
+```
+flux-kontext/
+├── nunchaku/           # Core quantized models and utilities
+│   ├── models/         # Transformer and text encoder models
+│   ├── lora/           # LoRA utilities
+│   ├── ops/            # Quantized operations
+│   └── csrc/           # C++ CUDA kernels
+├── app/                # Application interfaces
+│   └── kontext/        # Gradio web interface
+├── examples/           # Example scripts
+└── tests/              # Test scripts
+```
+
+## Examples
+
+See the `examples/` directory for various usage patterns:
+
+- `flux.1-kontext-dev.py`: Basic usage example
+- `flux.1-kontext-dev-teacache.py`: Using TeaCache for acceleration
+- `flux.1-kontext-FALAI_lora.py`: Applying a LoRA at inference time
+
+## Requirements
+
+- Python >= 3.10
+- PyTorch >= 2.5
+- CUDA-capable GPU (required for INT4/FP4 inference)
+- 8GB+ GPU memory (for INT4 quantization)
+
+## License
+
+See the main nunchaku project for license information.
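The Gradio app added below (`app.py`) chooses among these precision variants at startup. As a condensed sketch of that selection logic, assuming the same checkpoint paths used throughout this PR (the `load_pipeline` helper name is hypothetical):

```python
import torch
from diffusers import FluxKontextPipeline

from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel


def load_pipeline(precision: str = "int4") -> FluxKontextPipeline:
    """Load FLUX.1-Kontext, swapping in the quantized transformer for int4/fp4."""
    kwargs = {}
    if precision in ("int4", "fp4"):
        # Quantized SVDQuant checkpoint; "bf16" falls through to the stock weights.
        kwargs["transformer"] = NunchakuFluxTransformer2dModel.from_pretrained(
            f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{precision}_r32-flux.1-kontext-dev.safetensors"
        )
    pipeline = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **kwargs
    )
    return pipeline.to("cuda")
```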
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c7f9d040f03eb1b5df4784d4331136cd362833
--- /dev/null
+++ b/app.py
@@ -0,0 +1,150 @@
+import random
+import time
+
+import gradio as gr
+import torch
+from diffusers import FluxKontextPipeline
+from PIL import Image
+
+from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
+from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+from utils import get_args
+
+MAX_SEED = 1000000000
+
+args = get_args()
+
+if args.precision == "bf16":
+    pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
+    pipeline = pipeline.to("cuda")
+    pipeline.precision = "bf16"
+else:
+    assert args.precision in ["int4", "fp4"]
+    pipeline_init_kwargs = {}
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
+    )
+    pipeline_init_kwargs["transformer"] = transformer
+    if args.use_qencoder:
+        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+        )
+        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
+
+    pipeline = FluxKontextPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
+    )
+    pipeline = pipeline.to("cuda")
+    pipeline.precision = args.precision
+
+
+def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image.Image, str]:
+    img = image["composite"].convert("RGB")
+
+    start_time = time.time()
+    result_image = pipeline(
+        prompt=prompt,
+        image=img,
+        height=img.height,
+        width=img.width,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        generator=torch.Generator().manual_seed(seed),
+    ).images[0]
+
+    latency = time.time() - start_time
+    if latency < 1:
+        latency_str = f"{latency * 1000:.2f}ms"
+    else:
+        latency_str = f"{latency:.2f}s"
+    torch.cuda.empty_cache()
+    return result_image, latency_str
+
+
+with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
+    with open("assets/description.html", "r") as f:
+        DESCRIPTION = f.read()
+    # Get the GPU properties
+    if torch.cuda.device_count() > 0:
+        gpu_properties = torch.cuda.get_device_properties(0)
+        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+        gpu_name = torch.cuda.get_device_name(0)
+        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
+    else:
+        device_info = "Running on CPU 🥶 This demo does not work on CPU."
+
+    header_str = DESCRIPTION.format(precision=args.precision.upper(), device_info=device_info, count_info="")
+    header = gr.HTML(header_str)
+
+    with gr.Row(elem_id="main_row"):
+        with gr.Column(elem_id="column_input"):
+            gr.Markdown("## INPUT", elem_id="input_header")
+            with gr.Group():
+                canvas = gr.ImageEditor(
+                    height=640,
+                    image_mode="RGB",
+                    sources=["upload", "clipboard"],
+                    type="pil",
+                    label="Input",
+                    show_label=False,
+                    show_download_button=True,
+                    interactive=True,
+                    transforms=[],
+                    canvas_size=(1024, 1024),
+                    scale=1,
+                    format="png",
+                    layers=False,
+                )
+                with gr.Row():
+                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
+                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
+
+                with gr.Row():
+                    seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
+                    randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
+            with gr.Accordion("Advanced options", open=False):
+                with gr.Group():
+                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
+                    guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)
+
+        with gr.Column(elem_id="column_output"):
+            gr.Markdown("## OUTPUT", elem_id="output_header")
+            with gr.Group():
+                result = gr.Image(
+                    format="png",
+                    height=640,
+                    image_mode="RGB",
+                    type="pil",
+                    label="Result",
+                    show_label=False,
+                    show_download_button=True,
+                    interactive=False,
+                    elem_id="output_image",
+                )
+                latency_result = gr.Text(label="Inference Latency", show_label=True)
+
+            gr.Markdown("### Instructions")
+            gr.Markdown("**1**. Enter a text prompt")
+            gr.Markdown("**2**. Upload an image")
+            gr.Markdown("**3**. Try different seeds to generate different results")
+
+    run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
+    run_outputs = [result, latency_result]
+
+    randomize_seed.click(
+        lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
+    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
+
+    gr.on(
+        triggers=[prompt.submit, run_button.click],
+        fn=run,
+        inputs=run_inputs,
+        outputs=run_outputs,
+        api_name=False,
+    )
+
+if __name__ == "__main__":
+    demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
diff --git a/app/kontext/README.md b/app/kontext/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f97be1005826aa11a86e1da103f150e8169f70e5
--- /dev/null
+++ b/app/kontext/README.md
@@ -0,0 +1,12 @@
+# Nunchaku INT4 FLUX.1 Kontext Demo
+
+![demo](https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/assets/kontext.png)
+
+This interactive Gradio application allows you to edit an image with natural language. Simply run:
+
+```shell
+python run_gradio.py
+```
+
+- To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
+- By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
diff --git a/app/kontext/assets/description.html b/app/kontext/assets/description.html
new file mode 100644
index 0000000000000000000000000000000000000000..e4d791567cdda3def8cc8f5f248690ead6fcf60c
--- /dev/null
+++ b/app/kontext/assets/description.html
@@ -0,0 +1,21 @@
+<div align="center">
+  <img alt="nunchaku logo">
+  <img alt="svdquant logo">
+</div>
+<h1 align="center">{precision} FLUX.1-Kontext-dev Demo</h1>
+<div align="center">
+  {device_info}
+  <br>
+  {count_info}
+</div>
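The application reads this template and fills the three placeholders with `str.format` before handing the result to `gr.HTML`. A minimal sketch with illustrative values (the demo leaves `count_info` empty; the device string shown here is an example, not real output):

```python
with open("assets/description.html", "r") as f:
    template = f.read()

# {precision}, {device_info}, and {count_info} are the placeholders defined above.
header_html = template.format(
    precision="INT4",
    device_info="Running on NVIDIA RTX 4090 with 24 GiB memory.",  # example value
    count_info="",
)
```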
diff --git a/app/kontext/assets/style.css b/app/kontext/assets/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..6df54d33fe823c2dcb49ead8b33fb24fc0037bc1
--- /dev/null
+++ b/app/kontext/assets/style.css
@@ -0,0 +1,40 @@
+@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
+
+.gradio-container {
+    max-width: 1200px !important;
+    margin: auto; /* Centers the element horizontally */
+}
+
+h1 {
+    text-align: center;
+}
+
+.wrap.svelte-p4aq0j.svelte-p4aq0j {
+    display: none;
+}
+
+#column_input, #column_output {
+    width: 500px;
+    display: flex;
+    align-items: center;
+}
+
+#input_header, #output_header {
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    width: 400px;
+}
+
+#accessibility {
+    text-align: center; /* Center-aligns the text */
+    margin: auto; /* Centers the element horizontally */
+}
+
+#random_seed {
+    height: 71px;
+}
+
+#run_button {
+    height: 87px;
+}
diff --git a/app/kontext/run_gradio.py b/app/kontext/run_gradio.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c7f9d040f03eb1b5df4784d4331136cd362833
--- /dev/null
+++ b/app/kontext/run_gradio.py
@@ -0,0 +1,150 @@
+import random
+import time
+
+import gradio as gr
+import torch
+from diffusers import FluxKontextPipeline
+from PIL import Image
+
+from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
+from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+from utils import get_args
+
+MAX_SEED = 1000000000
+
+args = get_args()
+
+if args.precision == "bf16":
+    pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
+    pipeline = pipeline.to("cuda")
+    pipeline.precision = "bf16"
+else:
+    assert args.precision in ["int4", "fp4"]
+    pipeline_init_kwargs = {}
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
+    )
+    pipeline_init_kwargs["transformer"] = transformer
+    if args.use_qencoder:
+        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+        )
+        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
+
+    pipeline = FluxKontextPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
+    )
+    pipeline = pipeline.to("cuda")
+    pipeline.precision = args.precision
+
+
+def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image.Image, str]:
+    img = image["composite"].convert("RGB")
+
+    start_time = time.time()
+    result_image = pipeline(
+        prompt=prompt,
+        image=img,
+        height=img.height,
+        width=img.width,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        generator=torch.Generator().manual_seed(seed),
+    ).images[0]
+
+    latency = time.time() - start_time
+    if latency < 1:
+        latency_str = f"{latency * 1000:.2f}ms"
+    else:
+        latency_str = f"{latency:.2f}s"
+    torch.cuda.empty_cache()
+    return result_image, latency_str
+
+
+with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
+    with open("assets/description.html", "r") as f:
+        DESCRIPTION = f.read()
+    # Get the GPU properties
+    if torch.cuda.device_count() > 0:
+        gpu_properties = torch.cuda.get_device_properties(0)
+        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+        gpu_name = torch.cuda.get_device_name(0)
+        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
+    else:
+        device_info = "Running on CPU 🥶 This demo does not work on CPU."
+
+    header_str = DESCRIPTION.format(precision=args.precision.upper(), device_info=device_info, count_info="")
+    header = gr.HTML(header_str)
+
+    with gr.Row(elem_id="main_row"):
+        with gr.Column(elem_id="column_input"):
+            gr.Markdown("## INPUT", elem_id="input_header")
+            with gr.Group():
+                canvas = gr.ImageEditor(
+                    height=640,
+                    image_mode="RGB",
+                    sources=["upload", "clipboard"],
+                    type="pil",
+                    label="Input",
+                    show_label=False,
+                    show_download_button=True,
+                    interactive=True,
+                    transforms=[],
+                    canvas_size=(1024, 1024),
+                    scale=1,
+                    format="png",
+                    layers=False,
+                )
+                with gr.Row():
+                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
+                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
+
+                with gr.Row():
+                    seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
+                    randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
+            with gr.Accordion("Advanced options", open=False):
+                with gr.Group():
+                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
+                    guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)
+
+        with gr.Column(elem_id="column_output"):
+            gr.Markdown("## OUTPUT", elem_id="output_header")
+            with gr.Group():
+                result = gr.Image(
+                    format="png",
+                    height=640,
+                    image_mode="RGB",
+                    type="pil",
+                    label="Result",
+                    show_label=False,
+                    show_download_button=True,
+                    interactive=False,
+                    elem_id="output_image",
+                )
+                latency_result = gr.Text(label="Inference Latency", show_label=True)
+
+            gr.Markdown("### Instructions")
+            gr.Markdown("**1**. Enter a text prompt")
+            gr.Markdown("**2**. Upload an image")
+            gr.Markdown("**3**. Try different seeds to generate different results")
+
+    run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
+    run_outputs = [result, latency_result]
+
+    randomize_seed.click(
+        lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
+    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
+
+    gr.on(
+        triggers=[prompt.submit, run_button.click],
+        fn=run,
+        inputs=run_inputs,
+        outputs=run_outputs,
+        api_name=False,
+    )
+
+if __name__ == "__main__":
+    demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
diff --git a/app/kontext/utils.py b/app/kontext/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..35956ee7688adc6fd052930dabc9cbfa95eae889
--- /dev/null
+++ b/app/kontext/utils.py
@@ -0,0 +1,14 @@
+import argparse
+
+
+def get_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precision to use"
+    )
+    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use the 4-bit text encoder")
+    parser.add_argument("--no-safety-checker", action="store_true", help="Disable the safety checker")
+    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
+    parser.add_argument("--gradio-root-path", type=str, default="")
+    args = parser.parse_args()
+    return args
diff --git a/examples/__init__.py b/examples/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f60b2ebcab3d03d5863d66723ffc472e1c55ccb
--- /dev/null
+++ b/examples/__init__.py
@@ -0,0 +1 @@
+# FLUX-Kontext examples
diff --git a/examples/flux.1-kontext-FALAI_lora.py b/examples/flux.1-kontext-FALAI_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cf9d30808845fc0799983cc8be83c0b3a98f744
--- /dev/null
+++ b/examples/flux.1-kontext-FALAI_lora.py
@@ -0,0 +1,30 @@
+import torch
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+
+from nunchaku import NunchakuFluxTransformer2dModel
+from nunchaku.utils import get_precision
+
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
+)
+
+pipeline = FluxKontextPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
+).to("cuda")
+
+image = load_image(
+    "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
+).convert("RGB")
+
+### LoRA Related Code ###
+transformer.update_lora_params(
+    "nunchaku-tech/nunchaku-test-models/relight-kontext-lora-single-caption_comfy.safetensors"
+    # "linoyts/relight-kontext-lora-single-caption/relight-kontext-lora-single-caption.safetensors"
+)  # Path to your LoRA safetensors file; this can also be a remote HuggingFace path
+transformer.set_lora_strength(1)  # Your LoRA strength here
+### End of LoRA Related Code ###
+
+prompt = "neon light, city"
+image = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(23), guidance_scale=2.5).images[0]
+image.save("flux-kontext-dev.png")
diff --git a/examples/flux.1-kontext-dev-teacache.py b/examples/flux.1-kontext-dev-teacache.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9f40c61677bf522231b3edb498e9606a641897a
--- /dev/null
+++ b/examples/flux.1-kontext-dev-teacache.py
@@ -0,0 
+1,30 @@ +import time + +import torch +from diffusers import FluxKontextPipeline +from diffusers.utils import load_image + +from nunchaku import NunchakuFluxTransformer2dModel +from nunchaku.caching.teacache import TeaCache +from nunchaku.utils import get_precision + +transformer = NunchakuFluxTransformer2dModel.from_pretrained( + f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors" +) + +pipeline = FluxKontextPipeline.from_pretrained( + "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16 +).to("cuda") + +image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" +).convert("RGB") + +prompt = "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors" + +start_time = time.time() +with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=True, model_name="flux-kontext"): + image = pipeline(image=image, prompt=prompt, guidance_scale=2.5).images[0] +end_time = time.time() +print(f"Time taken: {(end_time - start_time)} seconds") +image.save(f"flux-kontext-dev-{get_precision()}-tc.png") diff --git a/examples/flux.1-kontext-dev.py b/examples/flux.1-kontext-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..2868d582f158d49c16d27ee39a8b0b40d1b3db6d --- /dev/null +++ b/examples/flux.1-kontext-dev.py @@ -0,0 +1,22 @@ +import torch +from diffusers import FluxKontextPipeline +from diffusers.utils import load_image + +from nunchaku import NunchakuFluxTransformer2dModel +from nunchaku.utils import get_precision + +transformer = NunchakuFluxTransformer2dModel.from_pretrained( + f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors" +) + +pipeline = FluxKontextPipeline.from_pretrained( + "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16 +).to("cuda") + +image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" +).convert("RGB") + +prompt = "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors" +image = pipeline(image=image, prompt=prompt, guidance_scale=2.5).images[0] +image.save("flux-kontext-dev.png") diff --git a/nunchaku/__init__.py b/nunchaku/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93a19f86d836471ac684e7fbf4d7020e69d7470f --- /dev/null +++ b/nunchaku/__init__.py @@ -0,0 +1,9 @@ +from .models import ( + NunchakuFluxTransformer2dModel, + NunchakuT5EncoderModel, +) + +__all__ = [ + "NunchakuFluxTransformer2dModel", + "NunchakuT5EncoderModel", +] diff --git a/nunchaku/__version__.py b/nunchaku/__version__.py new file mode 100644 index 0000000000000000000000000000000000000000..b4deaef52f24d6a5663ae2d6b3864b255c02b92b --- /dev/null +++ b/nunchaku/__version__.py @@ -0,0 +1 @@ +__version__ = "1.0.0-flux-kontext" diff --git a/nunchaku/csrc/flux.h b/nunchaku/csrc/flux.h new file mode 100644 index 0000000000000000000000000000000000000000..f2488400b0a285f34ee1b52ff96e194f7f86f029 --- /dev/null +++ b/nunchaku/csrc/flux.h @@ -0,0 +1,254 @@ +#pragma once + +#include "interop/torch.h" +#include "FluxModel.h" +#include "Serialization.h" +#include "debug.h" +#include "Linear.h" +#include "module.h" + +class QuantizedFluxModel : public ModuleWrapper { // : public torch::CustomClassHolder { +public: + void init(bool use_fp4, bool 
offload, bool bf16, int8_t deviceId) { + spdlog::info("Initializing QuantizedFluxModel on device {}", deviceId); + if (!bf16) { + spdlog::info("Use FP16 model"); + } + if (offload) { + spdlog::info("Layer offloading enabled"); + } + ModuleWrapper::init(deviceId); + + CUDADeviceContext ctx(this->deviceId); + net = std::make_unique( + use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId)); + } + + bool isBF16() { + checkModel(); + return net->dtype == Tensor::BF16; + } + pybind11::function residual_callback; + void set_residual_callback(pybind11::function callback) { + pybind11::gil_scoped_acquire gil; + if (!callback || callback.is_none()) { + residual_callback = pybind11::function(); + if (net) { + net->set_residual_callback(nullptr); + } + return; + } + residual_callback = std::move(callback); + if (net) { + pybind11::object cb = residual_callback; + net->set_residual_callback([cb](const Tensor &x) -> Tensor { + torch::Tensor torch_x = to_torch(x); + pybind11::object result = cb(torch_x); + torch::Tensor torch_y = result.cast(); + Tensor y = from_torch(torch_y); + return y; + }); + } else { + } + } + + torch::Tensor forward(torch::Tensor hidden_states, + torch::Tensor encoder_hidden_states, + torch::Tensor temb, + torch::Tensor rotary_emb_img, + torch::Tensor rotary_emb_context, + torch::Tensor rotary_emb_single, + std::optional controlnet_block_samples = std::nullopt, + std::optional controlnet_single_block_samples = std::nullopt, + bool skip_first_layer = false) { + checkModel(); + CUDADeviceContext ctx(deviceId); + + spdlog::debug("QuantizedFluxModel forward"); + + hidden_states = hidden_states.contiguous(); + encoder_hidden_states = encoder_hidden_states.contiguous(); + temb = temb.contiguous(); + rotary_emb_img = rotary_emb_img.contiguous(); + rotary_emb_context = rotary_emb_context.contiguous(); + rotary_emb_single = rotary_emb_single.contiguous(); + + Tensor result = net->forward( + from_torch(hidden_states), + from_torch(encoder_hidden_states), + from_torch(temb), + from_torch(rotary_emb_img), + from_torch(rotary_emb_context), + from_torch(rotary_emb_single), + controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{}, + controlnet_single_block_samples.has_value() + ? from_torch(controlnet_single_block_samples.value().contiguous()) + : Tensor{}, + skip_first_layer); + + torch::Tensor output = to_torch(result); + Tensor::synchronizeDevice(); + + return output; + } + + std::tuple + forward_layer(int64_t idx, + torch::Tensor hidden_states, + torch::Tensor encoder_hidden_states, + torch::Tensor temb, + torch::Tensor rotary_emb_img, + torch::Tensor rotary_emb_context, + std::optional controlnet_block_samples = std::nullopt, + std::optional controlnet_single_block_samples = std::nullopt) { + CUDADeviceContext ctx(deviceId); + + spdlog::debug("QuantizedFluxModel forward_layer {}", idx); + + hidden_states = hidden_states.contiguous(); + encoder_hidden_states = encoder_hidden_states.contiguous(); + temb = temb.contiguous(); + rotary_emb_img = rotary_emb_img.contiguous(); + rotary_emb_context = rotary_emb_context.contiguous(); + + auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer( + idx, + from_torch(hidden_states), + from_torch(encoder_hidden_states), + from_torch(temb), + from_torch(rotary_emb_img), + from_torch(rotary_emb_context), + controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{}, + controlnet_single_block_samples.has_value() + ? 
from_torch(controlnet_single_block_samples.value().contiguous()) + : Tensor{}); + + hidden_states = to_torch(hidden_states_); + encoder_hidden_states = to_torch(encoder_hidden_states_); + Tensor::synchronizeDevice(); + + return {hidden_states, encoder_hidden_states}; + } + + torch::Tensor forward_single_layer(int64_t idx, + torch::Tensor hidden_states, + torch::Tensor temb, + torch::Tensor rotary_emb_single) { + CUDADeviceContext ctx(deviceId); + + spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx); + + hidden_states = hidden_states.contiguous(); + temb = temb.contiguous(); + rotary_emb_single = rotary_emb_single.contiguous(); + + if (net->isOffloadEnabled()) { + net->single_transformer_blocks.at(idx)->loadLazyParams(); + } + + Tensor result = net->single_transformer_blocks.at(idx)->forward( + from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single)); + + if (net->isOffloadEnabled()) { + net->single_transformer_blocks.at(idx)->releaseLazyParams(); + } + + hidden_states = to_torch(result); + Tensor::synchronizeDevice(); + + return hidden_states; + } + + // expose the norm1 forward method of the transformer blocks + // this is used by TeaCache to get the norm1 output + std::tuple + norm_one_forward(int64_t idx, torch::Tensor hidden_states, torch::Tensor temb) { + AdaLayerNormZero::Output result = + net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb)); + return {to_torch(result.x), + to_torch(result.gate_msa), + to_torch(result.shift_mlp), + to_torch(result.scale_mlp), + to_torch(result.gate_mlp)}; + } + + // must be called after loading lora + // skip specific ranks in W4A4 layers + void setLoraScale(int skipRanks, float scale) { + if (skipRanks % 16 != 0) { + throw std::invalid_argument("skipRanks must be multiples of 16"); + } + + CUDADeviceContext ctx(deviceId); + + spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks); + + net->traverse([&](Module *module) { + if (auto *m = dynamic_cast(module)) { + m->lora_scale = scale; + } else if (auto *m = dynamic_cast(module)) { + for (int i = 0; i < skipRanks / 16; i++) { + m->lora_scales[i] = 1.0f; + } + for (int i = skipRanks / 16; i < (int)m->lora_scales.size(); i++) { + m->lora_scales[i] = scale; + } + } + }); + } + + void setAttentionImpl(std::string name) { + if (name.empty() || name == "default") { + name = "flashattn2"; + } + + spdlog::info("Set attention implementation to {}", name); + + if (name == "flashattn2") { + net->setAttentionImpl(AttentionImpl::FlashAttention2); + } else if (name == "nunchaku-fp16") { + net->setAttentionImpl(AttentionImpl::NunchakuFP16); + } else { + throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name)); + } + } + + std::tuple + forward_layer_ip_adapter(int64_t idx, + torch::Tensor hidden_states, + torch::Tensor encoder_hidden_states, + torch::Tensor temb, + torch::Tensor rotary_emb_img, + torch::Tensor rotary_emb_context, + std::optional controlnet_block_samples = std::nullopt, + std::optional controlnet_single_block_samples = std::nullopt) { + CUDADeviceContext ctx(deviceId); + + spdlog::debug("QuantizedFluxModel forward_layer {}", idx); + + hidden_states = hidden_states.contiguous(); + encoder_hidden_states = encoder_hidden_states.contiguous(); + temb = temb.contiguous(); + rotary_emb_img = rotary_emb_img.contiguous(); + rotary_emb_context = rotary_emb_context.contiguous(); + + auto &&[hidden_states_, encoder_hidden_states_, ip_query_] = net->forward_ip_adapter( + idx, + 
from_torch(hidden_states), + from_torch(encoder_hidden_states), + from_torch(temb), + from_torch(rotary_emb_img), + from_torch(rotary_emb_context), + controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{}, + controlnet_single_block_samples.has_value() + ? from_torch(controlnet_single_block_samples.value().contiguous()) + : Tensor{}); + + hidden_states = to_torch(hidden_states_); + encoder_hidden_states = to_torch(encoder_hidden_states_); + torch::Tensor ip_query = to_torch(ip_query_); + Tensor::synchronizeDevice(); + + return {hidden_states, encoder_hidden_states, ip_query}; + } +}; diff --git a/nunchaku/csrc/gemm.h b/nunchaku/csrc/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..74d1fb4b5ea15200d643fa98e406ef836ccfd307 --- /dev/null +++ b/nunchaku/csrc/gemm.h @@ -0,0 +1,114 @@ +#pragma once + +#include "interop/torch.h" +#include "Serialization.h" +#include "Linear.h" +#include "debug.h" +#include "module.h" + +class QuantizedGEMM : public ModuleWrapper { +public: + void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) { + spdlog::info("Initializing QuantizedGEMM"); + + size_t val = 0; + checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192)); + checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize)); + spdlog::debug("Stack={}", val); + + net = std::make_unique((int)in_features, + (int)out_features, + bias, + use_fp4, + bf16 ? Tensor::BF16 : Tensor::FP16, + Device::cuda((int)deviceId)); + } + + torch::Tensor forward(torch::Tensor x) { + checkModel(); + + std::cerr << "QuantizedGEMM forward" << std::endl; + + x = x.contiguous(); + + Tensor result = net->forward(from_torch(x)); + + torch::Tensor output = to_torch(result); + Tensor::synchronizeDevice(); + + return output; + } + + std::string dumpTensorBF16(Tensor x) { + std::stringstream ss; + for (int i = 0; i < 256; i++) { + ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__nv_bfloat16>()[i])); + } + ss << std::endl; + return ss.str(); + } + + std::string dumpTensorINT4(Tensor x) { + using spdlog::fmt_lib::format; + + const int M = x.shape[0]; + const int K = x.shape[1] * 2; + + assert(x.dtype() == Tensor::INT8); + + // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4) + + constexpr int BLOCK_M = 256; + constexpr int WARP_K = 64; + constexpr int NUM_WARPS = 8; + constexpr int WARP_M_TILES = 2; + constexpr int WARP_SIZE = 32; + + std::stringstream ss; + for (int bm = 0; bm < M / BLOCK_M; bm++) { + for (int bn = 0; bn < K / WARP_K; bn++) { + for (int warpId = 0; warpId < NUM_WARPS; warpId++) { + ss << format("[bm={},bn={},warp={}] ", bm, bn, warpId); + const int offset = ((bm * (K / WARP_K) + bn) * NUM_WARPS + warpId) * WARP_M_TILES * WARP_SIZE * 4; + + for (int i = 0; i < 16; i++) { + assert(static_cast(offset + i) < x.numel() / 4); + uint32_t val = x.data_ptr()[offset + i]; + ss << "{"; + for (int j = 0; j < 8; j++) { + int i4val = (val >> (j * 4)) & 0xf; + if (i4val & 0x8) { + i4val = -((~i4val & 0x7) + 1); + } + ss << format("{} ", i4val); + } + ss << format("}} {:x} ", val); + } + ss << std::endl; + } + } + } + + ss << std::endl; + return ss.str(); + } + + void quantize(torch::Tensor x, bool fuse_glu) { + checkModel(); + + spdlog::debug("QuantizedGEMM quantize"); + + x = x.contiguous(); + + auto qout = net->quantize(from_torch(x), fuse_glu); + + Tensor act = qout.act.copy(Device::cpu()); + Tensor ascales = 
qout.ascales.copy(Device::cpu()); + Tensor lora_act = qout.lora_act.copy(Device::cpu()); + + Tensor::synchronizeDevice(); + + spdlog::debug("act = {}", dumpTensorINT4(act)); + spdlog::debug("ascales = {}", dumpTensorBF16(ascales)); + } +}; diff --git a/nunchaku/csrc/gemm88.h b/nunchaku/csrc/gemm88.h new file mode 100644 index 0000000000000000000000000000000000000000..aa3dd3175f0a4e89c42c4ff4643a339d49b9e504 --- /dev/null +++ b/nunchaku/csrc/gemm88.h @@ -0,0 +1,37 @@ +#pragma once + +#include "interop/torch.h" +#include "Serialization.h" +#include "Linear.h" +#include "debug.h" +#include "module.h" + +class QuantizedGEMM88 : public ModuleWrapper { +public: + void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) { + spdlog::info("Initializing QuantizedGEMM88"); + + size_t val = 0; + checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192)); + checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize)); + spdlog::debug("Stack={}", val); + + net = std::make_unique( + (int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId)); + } + + torch::Tensor forward(torch::Tensor x) { + checkModel(); + + std::cerr << "QuantizedGEMM88 forward" << std::endl; + + x = x.contiguous(); + + Tensor result = net->forward(from_torch(x)); + + torch::Tensor output = to_torch(result); + Tensor::synchronizeDevice(); + + return output; + } +}; diff --git a/nunchaku/csrc/module.h b/nunchaku/csrc/module.h new file mode 100644 index 0000000000000000000000000000000000000000..812c82720203f3abf5eb5cd3cda2bf6fc9747bce --- /dev/null +++ b/nunchaku/csrc/module.h @@ -0,0 +1,85 @@ +#pragma once + +#include "interop/torch.h" +#include "Serialization.h" +#include "Module.h" +#include "debug.h" +#include "utils.h" + +template +class ModuleWrapper { +public: + void init(int deviceId) { + this->deviceId = deviceId; + } + void reset() { + CUDADeviceContext ctx(this->deviceId); + + debugContext.reset(); + net.reset(); + Tensor::synchronizeDevice(); + + nunchaku::utils::trim_memory(); + Tensor::synchronizeDevice(); + } + + void load(std::string path, bool partial = false) { + checkModel(); + CUDADeviceContext ctx(this->deviceId); + + spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path); + + std::shared_ptr provider = std::make_shared(path); + net->loadParams(*provider, partial); + Tensor::synchronizeDevice(); + + spdlog::info("Done."); + } + + void loadDict(std::map dict, bool partial = false) { + checkModel(); + CUDADeviceContext ctx(this->deviceId); + + spdlog::info("{} weights from pytorch", partial ? 
"Loading partial" : "Loading"); + + std::shared_ptr provider = std::make_shared(std::move(dict)); + net->loadParams(*provider, partial); + Tensor::synchronizeDevice(); + + spdlog::info("Done."); + } + + void startDebug() { + debugContext = std::make_unique(); + } + void stopDebug() { + debugContext.reset(); + } + + auto getDebugResults() { + CUDADeviceContext ctx(this->deviceId); + + std::map result; + + if (debugContext) { + for (auto &&[key, value] : debugContext->tensors) { + result[key] = to_torch(value); + } + } + + return result; + } + +protected: + void checkModel() { + if (!net) { + throw std::runtime_error("Model not initialized"); + } + } + +protected: + std::unique_ptr net; + std::unique_ptr debugContext; + + int deviceId = -1; +}; diff --git a/nunchaku/csrc/ops.h b/nunchaku/csrc/ops.h new file mode 100644 index 0000000000000000000000000000000000000000..dbd15fb22ce89492e520c77775311e612c7e18ed --- /dev/null +++ b/nunchaku/csrc/ops.h @@ -0,0 +1,173 @@ +#pragma once + +#include "interop/torch.h" +#include "kernels/zgemm/zgemm.h" +#include "kernels/awq/gemv_awq.h" +#include "kernels/awq/gemm_awq.h" + +namespace nunchaku::ops { + +void gemm_w4a4(std::optional act, // packed act [M, K / 2] + std::optional wgt, // packed act [N, K / 2] + std::optional out, // linear [M, N] + std::optional qout, // packed act [M, N / 2] + std::optional ascales, // packed as [K / 64, M] + std::optional wscales, // packed ws [K / 64, N] + std::optional oscales, // packed as [N / 64, M] + std::optional poolout, // linear [M / PoolSize, N] + std::optional lora_act_in, // packed lora_act [M, R] + std::optional lora_up, // packed lora_wgt [N, R] + std::optional lora_down, // packed lora_wgt [N, R] + std::optional lora_act_out, // packed lora_act [M, R] + std::optional norm_q, // linear [HEAD_DIM] + std::optional norm_k, // linear [HEAD_DIM] + std::optional rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2] + std::optional bias, // packed ws [N] + std::optional smooth_factor, // packed ws [N], for quantization of the next layer + std::optional out_vk, // linear [B, num_heads, head_dim + 1, head_dim] + std::optional out_linearattn, // linear [B, (M), N / 3] + bool act_unsigned, + std::vector lora_scales, + bool fuse_silu, + bool fp4, + float alpha, + std::optional wcscales, + std::optional out_q, // packed attention [B, H, M, D] + std::optional out_k, // packed attention [B, H, M, D] + std::optional out_v, // packed attention [B, H, M, D] + int attn_tokens) { + TorchOpContext ctx; + spdlog::trace("running gemm_w4a4: "); + + auto getTensor = [](std::optional &t) { + Tensor ret = t.has_value() ? 
from_torch(t.value()) : Tensor{}; + if (ret.valid()) { + spdlog::trace(" {}", ret.shape.str()); + } else { + spdlog::trace(" "); + } + return ret; + }; + nunchaku::kernels::gemm_w4a4(getTensor(act), + getTensor(wgt), + getTensor(out), + getTensor(qout), + getTensor(ascales), + getTensor(wscales), + getTensor(oscales), + getTensor(poolout), + getTensor(lora_act_in), + getTensor(lora_up), + getTensor(lora_down), + getTensor(lora_act_out), + getTensor(norm_q), + getTensor(norm_k), + getTensor(rotary_emb), + getTensor(bias), + getTensor(smooth_factor), + getTensor(out_vk), + getTensor(out_linearattn), + act_unsigned, + lora_scales, + fuse_silu, + fp4, + alpha, + getTensor(wcscales), + getTensor(out_q), + getTensor(out_k), + getTensor(out_v), + attn_tokens); + // Tensor::synchronizeDevice(); +} + +void quantize_w4a4_act_fuse_lora(std::optional input, + std::optional output, + std::optional oscales, + std::optional lora_down, + std::optional lora_act_out, + std::optional smooth, + bool fuse_glu, + bool fp4) { + TorchOpContext ctx; + + spdlog::trace("running quantize_w4a4_act_fuse_lora: "); + + auto getTensor = [](std::optional &t) { + Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{}; + if (ret.valid()) { + spdlog::trace(" {}", ret.shape.str()); + } else { + spdlog::trace(" "); + } + return ret; + }; + nunchaku::kernels::quantize_w4a4_act_fuse_lora(getTensor(input), + getTensor(output), + getTensor(oscales), + getTensor(lora_down), + getTensor(lora_act_out), + getTensor(smooth), + fuse_glu, + fp4); +} + +void attention_fp16(torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM] + torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM] + torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM] + torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM] + float scale) { + TorchOpContext ctx; + nunchaku::kernels::attention_fp16(from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale); +} + +torch::Tensor gemv_awq(torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int64_t m, + int64_t n, + int64_t k, + int64_t group_size) { + TorchOpContext ctx; + Tensor result = ::gemv_awq(from_torch(_in_feats.contiguous()), + from_torch(_kernel.contiguous()), + from_torch(_scaling_factors.contiguous()), + from_torch(_zeros.contiguous()), + (int)m, + (int)n, + (int)k, + (int)group_size); + + torch::Tensor output = to_torch(result); + // Tensor::synchronizeDevice(); + + return output; +} + +torch::Tensor +gemm_awq(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros) { + Tensor result = ::awq_gemm_forward_cuda(from_torch(_in_feats.contiguous()), + from_torch(_kernel.contiguous()), + from_torch(_scaling_factors.contiguous()), + from_torch(_zeros.contiguous())); + + TorchOpContext ctx; + // TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy) + torch::Tensor output = to_torch(result); + // Tensor::synchronizeDevice(); + + return output; +} + +void test_rmsnorm_rope( + torch::Tensor input, torch::Tensor output, torch::Tensor norm_q, torch::Tensor norm_k, torch::Tensor rotary_emb) { + nunchaku::kernels::test_rmsnorm_rope( + from_torch(input), from_torch(output), from_torch(norm_q), from_torch(norm_k), from_torch(rotary_emb)); +} + +void test_pack_qkv(torch::Tensor input, torch::Tensor out_q, torch::Tensor out_k, torch::Tensor out_v, int numTokens) { + nunchaku::kernels::test_pack_qkv( + from_torch(input), from_torch(out_q), 
from_torch(out_k), from_torch(out_v), numTokens); +} + +}; // namespace nunchaku::ops diff --git a/nunchaku/csrc/pybind.cpp b/nunchaku/csrc/pybind.cpp new file mode 100644 index 0000000000000000000000000000000000000000..74fc37f0f04aec50dc73072234d978e365de7d37 --- /dev/null +++ b/nunchaku/csrc/pybind.cpp @@ -0,0 +1,124 @@ +#include "gemm.h" +#include "gemm88.h" +#include "flux.h" +#include "sana.h" +#include "ops.h" +#include "utils.h" +#include +#include "interop/torch.h" +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + py::class_(m, "QuantizedFluxModel") + .def(py::init<>()) + .def("init", + &QuantizedFluxModel::init, + py::arg("use_fp4"), + py::arg("offload"), + py::arg("bf16"), + py::arg("deviceId")) + .def("set_residual_callback", + [](QuantizedFluxModel &self, pybind11::object call_back) { + if (call_back.is_none()) { + self.set_residual_callback(pybind11::function()); + } else { + self.set_residual_callback(call_back); + } + }) + .def("reset", &QuantizedFluxModel::reset) + .def("load", &QuantizedFluxModel::load, py::arg("path"), py::arg("partial") = false) + .def("loadDict", &QuantizedFluxModel::loadDict, py::arg("dict"), py::arg("partial") = false) + .def("forward", + &QuantizedFluxModel::forward, + py::arg("hidden_states"), + py::arg("encoder_hidden_states"), + py::arg("temb"), + py::arg("rotary_emb_img"), + py::arg("rotary_emb_context"), + py::arg("rotary_emb_single"), + py::arg("controlnet_block_samples") = py::none(), + py::arg("controlnet_single_block_samples") = py::none(), + py::arg("skip_first_layer") = false) + .def("forward_layer", + &QuantizedFluxModel::forward_layer, + py::arg("idx"), + py::arg("hidden_states"), + py::arg("encoder_hidden_states"), + py::arg("temb"), + py::arg("rotary_emb_img"), + py::arg("rotary_emb_context"), + py::arg("controlnet_block_samples") = py::none(), + py::arg("controlnet_single_block_samples") = py::none()) + .def("forward_layer_ip_adapter", + &QuantizedFluxModel::forward_layer_ip_adapter, + py::arg("idx"), + py::arg("hidden_states"), + py::arg("encoder_hidden_states"), + py::arg("temb"), + py::arg("rotary_emb_img"), + py::arg("rotary_emb_context"), + py::arg("controlnet_block_samples") = py::none(), + py::arg("controlnet_single_block_samples") = py::none()) + .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer) + .def("norm_one_forward", &QuantizedFluxModel::norm_one_forward) + .def("startDebug", &QuantizedFluxModel::startDebug) + .def("stopDebug", &QuantizedFluxModel::stopDebug) + .def("getDebugResults", &QuantizedFluxModel::getDebugResults) + .def("setLoraScale", &QuantizedFluxModel::setLoraScale) + .def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl) + .def("isBF16", &QuantizedFluxModel::isBF16); + py::class_(m, "QuantizedSanaModel") + .def(py::init<>()) + .def("init", + &QuantizedSanaModel::init, + py::arg("config"), + py::arg("pag_layers"), + py::arg("use_fp4"), + py::arg("bf16"), + py::arg("deviceId")) + .def("reset", &QuantizedSanaModel::reset) + .def("load", &QuantizedSanaModel::load, py::arg("path"), py::arg("partial") = false) + .def("loadDict", &QuantizedSanaModel::loadDict, py::arg("dict"), py::arg("partial") = false) + .def("forward", &QuantizedSanaModel::forward) + .def("forward_layer", &QuantizedSanaModel::forward_layer) + .def("startDebug", &QuantizedSanaModel::startDebug) + .def("stopDebug", &QuantizedSanaModel::stopDebug) + .def("getDebugResults", &QuantizedSanaModel::getDebugResults); + py::class_(m, "QuantizedGEMM") + .def(py::init<>()) + .def("init", &QuantizedGEMM::init) + 
.def("reset", &QuantizedGEMM::reset) + .def("load", &QuantizedGEMM::load) + .def("forward", &QuantizedGEMM::forward) + .def("quantize", &QuantizedGEMM::quantize) + .def("startDebug", &QuantizedGEMM::startDebug) + .def("stopDebug", &QuantizedGEMM::stopDebug) + .def("getDebugResults", &QuantizedGEMM::getDebugResults); + py::class_(m, "Tensor"); + py::class_(m, "QuantizedGEMM88") + .def(py::init<>()) + .def("init", &QuantizedGEMM88::init) + .def("reset", &QuantizedGEMM88::reset) + .def("load", &QuantizedGEMM88::load) + .def("forward", &QuantizedGEMM88::forward) + .def("startDebug", &QuantizedGEMM88::startDebug) + .def("stopDebug", &QuantizedGEMM88::stopDebug) + .def("getDebugResults", &QuantizedGEMM88::getDebugResults); + + m.def_submodule("ops") + .def("gemm_w4a4", nunchaku::ops::gemm_w4a4) + .def("quantize_w4a4_act_fuse_lora", nunchaku::ops::quantize_w4a4_act_fuse_lora) + .def("attention_fp16", nunchaku::ops::attention_fp16) + .def("gemm_awq", nunchaku::ops::gemm_awq) + .def("gemv_awq", nunchaku::ops::gemv_awq) + + .def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope) + .def("test_pack_qkv", nunchaku::ops::test_pack_qkv); + + m.def_submodule("utils") + .def("set_log_level", [](const std::string &level) { spdlog::set_level(spdlog::level::from_str(level)); }) + .def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit) + .def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release) + .def("trim_memory", nunchaku::utils::trim_memory) + .def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode); +} diff --git a/nunchaku/csrc/sana.h b/nunchaku/csrc/sana.h new file mode 100644 index 0000000000000000000000000000000000000000..8d6c75ffba160cc6125772a5b583b99e7c7de541 --- /dev/null +++ b/nunchaku/csrc/sana.h @@ -0,0 +1,102 @@ +#pragma once + +#include "interop/torch.h" +#include "SanaModel.h" +#include "Serialization.h" +#include "debug.h" +#include "module.h" + +class QuantizedSanaModel : public ModuleWrapper { +public: + void init(pybind11::dict config, std::vector pag_layers, bool use_fp4, bool bf16, int8_t deviceId) { + spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId); + SanaConfig cfg{ + .num_layers = config["num_layers"].cast(), + .num_attention_heads = config["num_attention_heads"].cast(), + .attention_head_dim = config["attention_head_dim"].cast(), + .num_cross_attention_heads = config["num_cross_attention_heads"].cast(), + .expand_ratio = config["mlp_ratio"].cast(), + .pag_layers = pag_layers, + .use_fp4 = use_fp4, + }; + + ModuleWrapper::init(deviceId); + CUDADeviceContext ctx(this->deviceId); + net = std::make_unique(cfg, bf16 ? 
Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId)); + } + + torch::Tensor forward(torch::Tensor hidden_states, + torch::Tensor encoder_hidden_states, + torch::Tensor timestep, + torch::Tensor cu_seqlens_img, + torch::Tensor cu_seqlens_txt, + int H, + int W, + bool pag, + bool cfg, + bool skip_first_layer = false) { + checkModel(); + CUDADeviceContext ctx(deviceId); + + spdlog::debug("QuantizedSanaModel forward"); + + hidden_states = hidden_states.contiguous(); + encoder_hidden_states = encoder_hidden_states.contiguous(); + timestep = timestep.contiguous(); + cu_seqlens_img = cu_seqlens_img.contiguous(); + cu_seqlens_txt = cu_seqlens_txt.contiguous(); + + Tensor result = net->forward(from_torch(hidden_states), + from_torch(encoder_hidden_states), + from_torch(timestep), + from_torch(cu_seqlens_img), + from_torch(cu_seqlens_txt), + H, + W, + pag, + cfg, + skip_first_layer); + + torch::Tensor output = to_torch(result); + // Tensor::synchronizeDevice(); + + return output; + } + + torch::Tensor forward_layer(int64_t idx, + torch::Tensor hidden_states, + torch::Tensor encoder_hidden_states, + torch::Tensor timestep, + torch::Tensor cu_seqlens_img, + torch::Tensor cu_seqlens_txt, + int H, + int W, + bool pag, + bool cfg) { + checkModel(); + CUDADeviceContext ctx(deviceId); + + spdlog::debug("QuantizedSanaModel forward_layer {}", idx); + + hidden_states = hidden_states.contiguous(); + encoder_hidden_states = encoder_hidden_states.contiguous(); + timestep = timestep.contiguous(); + cu_seqlens_img = cu_seqlens_img.contiguous(); + cu_seqlens_txt = cu_seqlens_txt.contiguous(); + + Tensor result = net->transformer_blocks.at(idx)->forward(from_torch(hidden_states), + from_torch(encoder_hidden_states), + from_torch(timestep), + from_torch(cu_seqlens_img), + from_torch(cu_seqlens_txt), + H, + W, + pag, + cfg); + + torch::Tensor output = to_torch(result); + // Tensor::synchronizeDevice(); + + return output; + } +}; diff --git a/nunchaku/csrc/utils.h b/nunchaku/csrc/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..967f87772901aec5f66d4fc0860f8e642e51e60b --- /dev/null +++ b/nunchaku/csrc/utils.h @@ -0,0 +1,39 @@ +#pragma once + +#include "common.h" +#include "Tensor.h" +#include "kernels/zgemm/zgemm.h" + +namespace nunchaku::utils { + +void set_cuda_stack_limit(int64_t newval) { + size_t val = 0; + checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval)); + checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize)); + spdlog::debug("Stack={}", val); +} + +void disable_memory_auto_release() { + int device; + checkCUDA(cudaGetDevice(&device)); + cudaMemPool_t mempool; + checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device)); + uint64_t threshold = UINT64_MAX; + checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold)); +} + +void trim_memory() { + int device; + checkCUDA(cudaGetDevice(&device)); + cudaMemPool_t mempool; + checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device)); + size_t bytesToKeep = 0; + checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep)); +} + +void set_faster_i2f_mode(std::string mode) { + spdlog::info("Set fasteri2f mode to {}", mode); + kernels::set_faster_i2f_mode(mode); +} + +}; // namespace nunchaku::utils diff --git a/nunchaku/lora/__init__.py b/nunchaku/lora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f037000fbd143a8ad81aeb825a4667002b07d14a --- /dev/null +++ b/nunchaku/lora/__init__.py @@ -0,0 +1 @@ +# LoRA utilities for FLUX models diff --git 
a/nunchaku/lora/flux/__init__.py b/nunchaku/lora/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d49b5f570dcbe27f81f934f7fb73f5dd7c5456c --- /dev/null +++ b/nunchaku/lora/flux/__init__.py @@ -0,0 +1,5 @@ +from .diffusers_converter import to_diffusers +from .nunchaku_converter import convert_to_nunchaku_flux_lowrank_dict, to_nunchaku +from .utils import is_nunchaku_format + +__all__ = ["to_diffusers", "to_nunchaku", "convert_to_nunchaku_flux_lowrank_dict", "is_nunchaku_format"] diff --git a/nunchaku/lora/flux/compose.py b/nunchaku/lora/flux/compose.py new file mode 100644 index 0000000000000000000000000000000000000000..a44a8aff4726c026f0dc2859c06a054453539c3c --- /dev/null +++ b/nunchaku/lora/flux/compose.py @@ -0,0 +1,218 @@ +""" +Compose multiple LoRA weights into a single LoRA for FLUX models. + +This script merges several LoRA safetensors files into one, applying individual strength values to each. + +**Example Usage:** + +.. code-block:: bash + + python -m nunchaku.lora.flux.compose \\ + -i lora1.safetensors lora2.safetensors \\ + -s 0.8 1.0 \\ + -o composed_lora.safetensors + +**Arguments:** + +- ``-i``, ``--input-paths``: Input LoRA safetensors files (one or more). +- ``-s``, ``--strengths``: Strength value for each LoRA (must match number of inputs). +- ``-o``, ``--output-path``: Output path for the composed LoRA safetensors file. + +This will merge ``lora1.safetensors`` (strength 0.8) and ``lora2.safetensors`` (strength 1.0) into ``composed_lora.safetensors``. + +**Main Function** + +:func:`compose_lora` +""" + +import argparse +import os + +import torch +import torch.nn.functional as F +from safetensors.torch import save_file + +from .diffusers_converter import to_diffusers +from .utils import is_nunchaku_format, load_state_dict_in_safetensors + + +def compose_lora( + loras: list[tuple[str | dict[str, torch.Tensor], float]], output_path: str | None = None +) -> dict[str, torch.Tensor]: + """ + Compose multiple LoRA weights into a single LoRA representation. + + Parameters + ---------- + loras : list of (str or dict[str, torch.Tensor], float) + Each tuple contains: + - Path to a LoRA safetensors file or a LoRA weights dictionary. + - Strength/scale factor for that LoRA. + output_path : str, optional + Path to save the composed LoRA weights as a safetensors file. If None, does not save. + + Returns + ------- + dict[str, torch.Tensor] + The composed LoRA weights. + + Raises + ------ + AssertionError + If LoRA weights are in Nunchaku format (must be converted to Diffusers format first) + or if tensor shapes are incompatible. + + Notes + ----- + - Converts all input LoRAs to Diffusers format. + - Handles QKV projection fusion for attention layers. + - Applies strength scaling to LoRA weights. + - Concatenates multiple LoRAs along appropriate dimensions. + - Handles normalization layers, bias vectors, and FLUX.1-tools LoRA compatibility. 
+
+    Examples
+    --------
+    >>> lora_paths = [("lora1.safetensors", 0.8), ("lora2.safetensors", 0.6)]
+    >>> composed = compose_lora(lora_paths, "composed_lora.safetensors")
+    >>> lora_dicts = [({"layer.weight": torch.randn(10, 20)}, 1.0)]
+    >>> composed = compose_lora(lora_dicts)
+    """
+    if len(loras) == 1:
+        # A single Nunchaku-format LoRA at (approximately) unit strength needs no composition.
+        if is_nunchaku_format(loras[0][0]) and abs(loras[0][1] - 1) < 1e-5:
+            if isinstance(loras[0][0], str):
+                return load_state_dict_in_safetensors(loras[0][0], device="cpu")
+            else:
+                return loras[0][0]
+
+    composed = {}
+    for lora, strength in loras:
+        assert not is_nunchaku_format(lora)
+        lora = to_diffusers(lora)
+        for k, v in list(lora.items()):
+            if v.ndim == 1:
+                previous_tensor = composed.get(k, None)
+                if previous_tensor is None:
+                    if "norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k:
+                        composed[k] = v
+                    else:
+                        composed[k] = v * strength
+                else:
+                    assert not ("norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k)
+                    composed[k] = previous_tensor + v * strength
+            else:
+                assert v.ndim == 2
+                if ".to_q." in k or ".add_q_proj." in k:  # q, k, and v projections must all exist
+                    if "lora_B" in k:
+                        continue
+
+                    q_a = v
+                    k_a = lora[k.replace(".to_q.", ".to_k.").replace(".add_q_proj.", ".add_k_proj.")]
+                    v_a = lora[k.replace(".to_q.", ".to_v.").replace(".add_q_proj.", ".add_v_proj.")]
+
+                    q_b = lora[k.replace("lora_A", "lora_B")]
+                    k_b = lora[
+                        k.replace("lora_A", "lora_B")
+                        .replace(".to_q.", ".to_k.")
+                        .replace(".add_q_proj.", ".add_k_proj.")
+                    ]
+                    v_b = lora[
+                        k.replace("lora_A", "lora_B")
+                        .replace(".to_q.", ".to_v.")
+                        .replace(".add_q_proj.", ".add_v_proj.")
+                    ]
+
+                    # Zero-pad so that q, k, and v share the same rank
+                    max_rank = max(q_a.shape[0], k_a.shape[0], v_a.shape[0])
+                    q_a = F.pad(q_a, (0, 0, 0, max_rank - q_a.shape[0]))
+                    k_a = F.pad(k_a, (0, 0, 0, max_rank - k_a.shape[0]))
+                    v_a = F.pad(v_a, (0, 0, 0, max_rank - v_a.shape[0]))
+                    q_b = F.pad(q_b, (0, max_rank - q_b.shape[1]))
+                    k_b = F.pad(k_b, (0, max_rank - k_b.shape[1]))
+                    v_b = F.pad(v_b, (0, max_rank - v_b.shape[1]))
+
+                    if torch.isclose(q_a, k_a).all() and torch.isclose(q_a, v_a).all():
+                        # q, k, and v share the same down-projection: reuse it
+                        lora_a = q_a
+                        lora_b = torch.cat((q_b, k_b, v_b), dim=0)
+                    else:
+                        # stack the down-projections and build a block-diagonal up-projection
+                        lora_a_group = (q_a, k_a, v_a)
+                        new_shape_a = [sum([_.shape[0] for _ in lora_a_group]), q_a.shape[1]]
+                        lora_a = torch.zeros(new_shape_a, dtype=q_a.dtype, device=q_a.device)
+                        start_dim = 0
+                        for tensor in lora_a_group:
+                            lora_a[start_dim : start_dim + tensor.shape[0]] = tensor
+                            start_dim += tensor.shape[0]
+
+                        lora_b_group = (q_b, k_b, v_b)
+                        new_shape_b = [sum([_.shape[0] for _ in lora_b_group]), sum([_.shape[1] for _ in lora_b_group])]
+                        lora_b = torch.zeros(new_shape_b, dtype=q_b.dtype, device=q_b.device)
+                        start_dims = (0, 0)
+                        for tensor in lora_b_group:
+                            end_dims = (start_dims[0] + tensor.shape[0], start_dims[1] + tensor.shape[1])
+                            lora_b[start_dims[0] : end_dims[0], start_dims[1] : end_dims[1]] = tensor
+                            start_dims = end_dims
+
+                    lora_a = lora_a * strength
+
+                    new_k_a = k.replace(".to_q.", ".to_qkv.").replace(".add_q_proj.", ".add_qkv_proj.")
+                    new_k_b = new_k_a.replace("lora_A", "lora_B")
+
+                    for kk, vv, dim in ((new_k_a, lora_a, 0), (new_k_b, lora_b, 1)):
+                        previous_lora = composed.get(kk, None)
+                        composed[kk] = vv if previous_lora is None else torch.cat([previous_lora, vv], dim=dim)
+
+                elif ".to_k." in k or ".to_v." in k or ".add_k_proj." in k or ".add_v_proj." in k:
+                    continue  # already fused into the qkv branch above
+                else:
+                    if "lora_A" in k:
+                        v = v * strength
+
+                    previous_lora = composed.get(k, None)
+                    if previous_lora is None:
+                        composed[k] = v
+                    elif "lora_A" in k:
+                        if previous_lora.shape[1] != v.shape[1]:  # flux.1-tools LoRA compatibility
+                            assert "x_embedder" in k
+                            expanded_size = max(previous_lora.shape[1], v.shape[1])
+                            if expanded_size > previous_lora.shape[1]:
+                                expanded_previous_lora = torch.zeros(
+                                    (previous_lora.shape[0], expanded_size),
+                                    device=previous_lora.device,
+                                    dtype=previous_lora.dtype,
+                                )
+                                expanded_previous_lora[:, : previous_lora.shape[1]] = previous_lora
+                            else:
+                                expanded_previous_lora = previous_lora
+                            if expanded_size > v.shape[1]:
+                                expanded_v = torch.zeros(
+                                    (v.shape[0], expanded_size), device=v.device, dtype=v.dtype
+                                )
+                                expanded_v[:, : v.shape[1]] = v
+                            else:
+                                expanded_v = v
+                            composed[k] = torch.cat([expanded_previous_lora, expanded_v], dim=0)
+                        else:
+                            composed[k] = torch.cat([previous_lora, v], dim=0)
+                    else:
+                        composed[k] = torch.cat([previous_lora, v], dim=1)
+
+    if output_path is not None:
+        output_dir = os.path.dirname(os.path.abspath(output_path))
+        os.makedirs(output_dir, exist_ok=True)
+        save_file(composed, output_path)
+    return composed
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-i", "--input-paths", type=str, nargs="*", required=True, help="paths to the LoRA safetensors files"
+    )
+    parser.add_argument("-s", "--strengths", type=float, nargs="*", required=True, help="strengths for each LoRA")
+    parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output safetensors file")
+    args = parser.parse_args()
+    assert len(args.input_paths) == len(args.strengths), "each input LoRA needs exactly one strength"
+    compose_lora(list(zip(args.input_paths, args.strengths)), args.output_path)
diff --git a/nunchaku/lora/flux/convert.py b/nunchaku/lora/flux/convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff11f40accb842f67197469a8c7cc8be0f3e102e
--- /dev/null
+++ b/nunchaku/lora/flux/convert.py
@@ -0,0 +1,74 @@
+"""
+CLI tool to convert LoRA weights to Nunchaku format.
+
+**Example Usage:**
+
+..
code-block:: bash + + python -m nunchaku.lora.flux.convert \\ + --lora-path composed_lora.safetensors \\ + --quant-path mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors \\ + --output-root ./converted \\ + --dtype bfloat16 + +**Arguments:** + +- ``--lora-path``: Path to the LoRA weights safetensor file (required) +- ``--quant-path``: Path to the quantized model safetensor file (default: ``mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors``) +- ``--output-root``: Root directory for the output safetensor file (default: parent directory of the lora file) +- ``--lora-name``: Name of the LoRA weights (optional, auto-generated if not provided) +- ``--dtype``: Data type of the converted weights, either ``bfloat16`` or ``float16`` (default: ``bfloat16``) + +**Main Function** + +:func:`nunchaku.lora.flux.nunchaku_converter.to_nunchaku` +""" + +import argparse +import os + +from .nunchaku_converter import to_nunchaku +from .utils import is_nunchaku_format + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--quant-path", + type=str, + help="Path to the quantized model safetensors file.", + default="mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors", + ) + parser.add_argument("--lora-path", type=str, required=True, help="Path to LoRA weights safetensors file.") + parser.add_argument("--output-root", type=str, default="", help="Root directory for output safetensors file.") + parser.add_argument("--lora-name", type=str, default=None, help="Name for the output LoRA weights.") + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["bfloat16", "float16"], + help="Data type of the converted weights.", + ) + args = parser.parse_args() + + if is_nunchaku_format(args.lora_path): + print("Already in Nunchaku format, no conversion needed.") + exit(0) + + if not args.output_root: + args.output_root = os.path.dirname(args.lora_path) + if args.lora_name is None: + base_name = os.path.basename(args.lora_path) + lora_name = base_name.rsplit(".", 1)[0] + precision = "fp4" if "fp4" in args.quant_path else "int4" + lora_name = f"svdq-{precision}-{lora_name}" + print(f"LoRA name not provided, using {lora_name} as the LoRA name") + else: + lora_name = args.lora_name + assert lora_name, "LoRA name must be provided." + + to_nunchaku( + args.lora_path, + args.quant_path, + dtype=args.dtype, + output_path=os.path.join(args.output_root, f"{lora_name}.safetensors"), + ) diff --git a/nunchaku/lora/flux/diffusers_converter.py b/nunchaku/lora/flux/diffusers_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..e0c30134ef718ef79509d1d2d22182963cd97651 --- /dev/null +++ b/nunchaku/lora/flux/diffusers_converter.py @@ -0,0 +1,220 @@ +""" +This module implements the functions to convert FLUX LoRA weights from various formats +to the Diffusers format, which will later be converted to Nunchaku format. 
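+
+Supported input key layouts include Kohya (``lora_transformer_*``), PEFT
+(``base_model.model.*``), ComfyUI (``lora_unet_*``), and standard Diffusers
+(``transformer.*``) checkpoints. A minimal sketch (the file name is a placeholder):
+
+.. code-block:: python
+
+    from nunchaku.lora.flux.diffusers_converter import to_diffusers
+
+    # accepts a safetensors path or an in-memory state dict
+    diffusers_sd = to_diffusers("my_lora.safetensors")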
+""" + +import argparse +import logging +import os + +import torch +from diffusers.loaders import FluxLoraLoaderMixin +from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft +from safetensors.torch import save_file + +from ...utils import load_state_dict_in_safetensors + +# Get log level from environment variable (default to INFO) +log_level = os.getenv("LOG_LEVEL", "INFO").upper() + +# Configure logging +logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +def handle_kohya_lora(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Convert Kohya LoRA format keys to Diffusers format. + + Parameters + ---------- + state_dict : dict[str, torch.Tensor] + LoRA weights, possibly in Kohya format. + + Returns + ------- + dict[str, torch.Tensor] + LoRA weights in Diffusers format. + """ + # first check if the state_dict is in the kohya format + # like: https://civitai.com/models/1118358?modelVersionId=1256866 + if any([not k.startswith("lora_transformer_") for k in state_dict.keys()]): + return state_dict + else: + new_state_dict = {} + for k, v in state_dict.items(): + new_k = k.replace("lora_transformer_", "transformer.") + + new_k = new_k.replace("norm_out_", "norm_out.") + + new_k = new_k.replace("time_text_embed_", "time_text_embed.") + new_k = new_k.replace("guidance_embedder_", "guidance_embedder.") + new_k = new_k.replace("text_embedder_", "text_embedder.") + new_k = new_k.replace("timestep_embedder_", "timestep_embedder.") + + new_k = new_k.replace("single_transformer_blocks_", "single_transformer_blocks.") + new_k = new_k.replace("_attn_", ".attn.") + new_k = new_k.replace("_norm_linear.", ".norm.linear.") + new_k = new_k.replace("_proj_mlp.", ".proj_mlp.") + new_k = new_k.replace("_proj_out.", ".proj_out.") + + new_k = new_k.replace("transformer_blocks_", "transformer_blocks.") + new_k = new_k.replace("to_out_0.", "to_out.0.") + new_k = new_k.replace("_ff_context_net_0_proj.", ".ff_context.net.0.proj.") + new_k = new_k.replace("_ff_context_net_2.", ".ff_context.net.2.") + new_k = new_k.replace("_ff_net_0_proj.", ".ff.net.0.proj.") + new_k = new_k.replace("_ff_net_2.", ".ff.net.2.") + new_k = new_k.replace("_norm1_context_linear.", ".norm1_context.linear.") + new_k = new_k.replace("_norm1_linear.", ".norm1.linear.") + + new_k = new_k.replace(".lora_down.", ".lora_A.") + new_k = new_k.replace(".lora_up.", ".lora_B.") + + new_state_dict[new_k] = v + return new_state_dict + + +def convert_peft_to_comfyui(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Convert PEFT format (base_model.model.*) to ComfyUI format (lora_unet_*). + + Mapping rules: + - base_model.model.double_blocks.X.img_attn.proj → lora_unet_double_blocks_X_img_attn_proj + - base_model.model.single_blocks.X.linear1 → lora_unet_single_blocks_X_linear1 + - base_model.model.final_layer.linear → lora_unet_final_layer_linear + - lora_A/lora_B → lora_down/lora_up + + Parameters + ---------- + state_dict : dict[str, torch.Tensor] + LoRA weights in PEFT format + + Returns + ------- + dict[str, torch.Tensor] + LoRA weights in ComfyUI format + """ + converted_dict = {} + + for key, value in state_dict.items(): + new_key = key + + if key.startswith("base_model.model."): + # Remove base_model.model. 
prefix + new_key = key.replace("base_model.model.", "") + + # Convert to ComfyUI format with underscores + # Handle double_blocks + if "double_blocks" in new_key: + # Replace dots with underscores within the block structure + # e.g., double_blocks.0.img_attn.proj → double_blocks_0_img_attn_proj + new_key = new_key.replace("double_blocks.", "lora_unet_double_blocks_") + # Replace remaining dots with underscores + new_key = new_key.replace(".", "_") + + # Handle single_blocks + elif "single_blocks" in new_key: + new_key = new_key.replace("single_blocks.", "lora_unet_single_blocks_") + # Special handling for modulation.lin → modulation_lin + new_key = new_key.replace("modulation.lin", "modulation_lin") + # Replace remaining dots with underscores + new_key = new_key.replace(".", "_") + + # Handle final_layer + elif "final_layer" in new_key: + new_key = new_key.replace("final_layer.linear", "lora_unet_final_layer_linear") + # Replace remaining dots with underscores + new_key = new_key.replace(".", "_") + + else: + # For any other keys, add lora_unet_ prefix and replace dots + new_key = "lora_unet_" + new_key.replace(".", "_") + + # Convert lora_A/lora_B to lora_down/lora_up + new_key = new_key.replace("_lora_A_weight", ".lora_down.weight") + new_key = new_key.replace("_lora_B_weight", ".lora_up.weight") + + converted_dict[new_key] = value + + if key != new_key: + logger.debug(f"Converted: {key} → {new_key}") + + return converted_dict + + +def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | None = None) -> dict[str, torch.Tensor]: + """ + Convert LoRA weights to Diffusers format, which will later be converted to Nunchaku format. + + Parameters + ---------- + input_lora : str or dict[str, torch.Tensor] + Path to a safetensors file or a LoRA weight dictionary. + output_path : str, optional + If given, save the converted weights to this path. + + Returns + ------- + dict[str, torch.Tensor] + LoRA weights in Diffusers format. 
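+
+    Notes
+    -----
+    When the checkpoint carries per-module ``alpha`` scalars, the corresponding
+    ``lora_A`` weights are rescaled by ``alpha / rank`` so that no separate alpha
+    entries remain in the returned dict.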
+ """ + if isinstance(input_lora, str): + tensors = load_state_dict_in_safetensors(input_lora, device="cpu") + else: + tensors = {k: v for k, v in input_lora.items()} + + tensors = handle_kohya_lora(tensors) + + # Convert FP8 tensors to BF16 + for k, v in tensors.items(): + if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]: + tensors[k] = v.to(torch.bfloat16) + + # Apply Kontext-specific key conversion for both PEFT format and ComfyUI format + # This handles LoRAs with base_model.model.* prefix or lora_unet_* prefix (including final_layer_linear) + if any(k.startswith("base_model.model.") for k in tensors.keys()): + logger.info("Converting PEFT format to ComfyUI format") + return convert_peft_to_comfyui(tensors) + + # Handle LoRAs that only have final_layer_linear without adaLN_modulation + # This is a workaround for incomplete final layer LoRAs + final_keys = [k for k in tensors.keys() if "final_layer" in k] + has_linear = any("final_layer_linear" in k for k in final_keys) + has_adaln = any("final_layer_adaLN_modulation" in k for k in final_keys) + + if has_linear and not has_adaln: + for key in list(tensors.keys()): + if "final_layer_linear" in key: + adaln_key = key.replace("final_layer_linear", "final_layer_adaLN_modulation_1") + if adaln_key not in tensors: + tensors[adaln_key] = torch.zeros_like(tensors[key]) + + new_tensors, alphas = FluxLoraLoaderMixin.lora_state_dict(tensors, return_alphas=True) + new_tensors = convert_unet_state_dict_to_peft(new_tensors) + + if alphas is not None and len(alphas) > 0: + for k, v in alphas.items(): + key_A = k.replace(".alpha", ".lora_A.weight") + key_B = k.replace(".alpha", ".lora_B.weight") + assert key_A in new_tensors, f"Key {key_A} not found in new tensors." + assert key_B in new_tensors, f"Key {key_B} not found in new tensors." + rank = new_tensors[key_A].shape[0] + assert new_tensors[key_B].shape[1] == rank, f"Rank mismatch for {key_B}." + new_tensors[key_A] = new_tensors[key_A] * v / rank + + if output_path is not None: + output_dir = os.path.dirname(os.path.abspath(output_path)) + os.makedirs(output_dir, exist_ok=True) + save_file(new_tensors, output_path) + + return new_tensors + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensors file") + parser.add_argument( + "-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensors file" + ) + args = parser.parse_args() + to_diffusers(args.input_path, args.output_path) diff --git a/nunchaku/lora/flux/nunchaku_converter.py b/nunchaku/lora/flux/nunchaku_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..b10309c3dee19ba10192a0d3c237bc7a40858cb0 --- /dev/null +++ b/nunchaku/lora/flux/nunchaku_converter.py @@ -0,0 +1,949 @@ +""" +Nunchaku LoRA format converter for Flux models. + +This module provides utilities to convert LoRA weights from Diffusers format +to Nunchaku format for efficient quantized inference in Flux models. 
+ +Key functions +------------- +- :func:`to_nunchaku` : Main conversion entry point +- :func:`fuse_vectors` : Vector fusion for bias terms +""" + +import logging +import os + +import torch +from safetensors.torch import save_file +from tqdm import tqdm + +from ...utils import filter_state_dict, load_state_dict_in_safetensors +from .diffusers_converter import to_diffusers +from .packer import NunchakuWeightPacker +from .utils import is_nunchaku_format, pad + +# Get log level from environment variable (default to INFO) +log_level = os.getenv("LOG_LEVEL", "INFO").upper() + +# Configure logging +logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +# region utilities + + +def update_state_dict( + lhs: dict[str, torch.Tensor], rhs: dict[str, torch.Tensor], prefix: str = "" +) -> dict[str, torch.Tensor]: + """ + Update a state dictionary with values from another, optionally adding a prefix to keys. + + Parameters + ---------- + lhs : dict[str, torch.Tensor] + Target state dictionary. + rhs : dict[str, torch.Tensor] + Source state dictionary. + prefix : str, optional + Prefix to add to keys from rhs. + + Returns + ------- + dict[str, torch.Tensor] + Updated state dictionary. + + Raises + ------ + AssertionError + If any key already exists in the target dictionary. + """ + for rkey, value in rhs.items(): + lkey = f"{prefix}.{rkey}" if prefix else rkey + assert lkey not in lhs, f"Key {lkey} already exists in the state dict." + lhs[lkey] = value + return lhs + + +# endregion + + +def pack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor: + """ + Pack the low-rank weight tensor for W4A4 linear layers. + + Parameters + ---------- + weight : torch.Tensor + Low-rank weight tensor. + down : bool + If True, pack as down-projection; else as up-projection. + + Returns + ------- + torch.Tensor + Packed weight tensor. + """ + assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}." + lane_n, lane_k = 1, 2 # lane_n is always 1, lane_k is 32 bits // 16 bits = 2 + n_pack_size, k_pack_size = 2, 2 + num_n_lanes, num_k_lanes = 8, 4 + frag_n = n_pack_size * num_n_lanes * lane_n + frag_k = k_pack_size * num_k_lanes * lane_k + weight = pad(weight, divisor=(frag_n, frag_k), dim=(0, 1)) + if down: + r, c = weight.shape + r_frags, c_frags = r // frag_n, c // frag_k + weight = weight.view(r_frags, frag_n, c_frags, frag_k).permute(2, 0, 1, 3) + else: + c, r = weight.shape + c_frags, r_frags = c // frag_n, r // frag_k + weight = weight.view(c_frags, frag_n, r_frags, frag_k).permute(0, 2, 1, 3) + weight = weight.reshape(c_frags, r_frags, n_pack_size, num_n_lanes, k_pack_size, num_k_lanes, lane_k) + weight = weight.permute(0, 1, 3, 5, 2, 4, 6).contiguous() + return weight.view(c, r) + + +def unpack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor: + """ + Unpack the low-rank weight tensor from W4A4 linear layers. + + Parameters + ---------- + weight : torch.Tensor + Packed low-rank weight tensor. + down : bool + If True, unpack as down-projection; else as up-projection. + + Returns + ------- + torch.Tensor + Unpacked weight tensor. + """ + c, r = weight.shape + assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}." 
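+    # Exact inverse of ``pack_lowrank_weight``: packing then unpacking with the same
+    # ``down`` flag round-trips any fp16/bf16 tensor already padded to the 16x16
+    # fragment size.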
+ lane_n, lane_k = 1, 2 # lane_n is always 1, lane_k is 32 bits // 16 bits = 2 + n_pack_size, k_pack_size = 2, 2 + num_n_lanes, num_k_lanes = 8, 4 + frag_n = n_pack_size * num_n_lanes * lane_n + frag_k = k_pack_size * num_k_lanes * lane_k + if down: + r_frags, c_frags = r // frag_n, c // frag_k + else: + c_frags, r_frags = c // frag_n, r // frag_k + weight = weight.view(c_frags, r_frags, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, lane_k) + weight = weight.permute(0, 1, 4, 2, 5, 3, 6).contiguous() + weight = weight.view(c_frags, r_frags, frag_n, frag_k) + if down: + weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c) + else: + weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r) + return weight + + +def reorder_adanorm_lora_up(lora_up: torch.Tensor, splits: int) -> torch.Tensor: + """ + Reorder AdaNorm LoRA up-projection tensor for correct shape. + + Parameters + ---------- + lora_up : torch.Tensor + LoRA up-projection tensor. + splits : int + Number of splits for AdaNorm. + + Returns + ------- + torch.Tensor + Reordered tensor. + """ + c, r = lora_up.shape + assert c % splits == 0 + return lora_up.view(splits, c // splits, r).transpose(0, 1).reshape(c, r).contiguous() + + +def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901 + orig_state_dict: dict[str, torch.Tensor], + extra_lora_dict: dict[str, torch.Tensor], + converted_block_name: str, + candidate_block_name: str, + local_name_map: dict[str, str | list[str]], + convert_map: dict[str, str], + default_dtype: torch.dtype = torch.bfloat16, +) -> dict[str, torch.Tensor]: + """ + Convert LoRA weights for a transformer block from Diffusers to Nunchaku format. + + Merges and converts LoRA weights from the original SVDQuant low-rank branch and an extra LoRA dict + for a given transformer block, producing a Nunchaku-compatible dictionary. Handles both fused and + unfused LoRA branches (e.g., qkv), and merges multiple LoRA branches as needed. + + Parameters + ---------- + orig_state_dict : dict[str, torch.Tensor] + Original state dict with LoRA weights, keys like ``"{block}.{local}.lora_down"`` and ``"{block}.{local}.lora_up"``. + extra_lora_dict : dict[str, torch.Tensor] + Extra LoRA weights, keys like ``"{block}.{local}.lora_A.weight"`` and ``"{block}.{local}.lora_B.weight"``. + converted_block_name : str + Block name for output (e.g., ``"transformer_blocks.0"``). + candidate_block_name : str + Block name for input lookup (e.g., ``"blocks.0"``). + local_name_map : dict[str, str | list[str]] + Maps output local names (e.g., ``"attn.qkv"``) to one or more input local names. + convert_map : dict[str, str] + Maps output local names to conversion types: ``"adanorm_single"``, ``"adanorm_zero"``, or ``"linear"``. + default_dtype : torch.dtype, optional + Output tensor dtype (default: ``torch.bfloat16``). + + Returns + ------- + dict[str, torch.Tensor] + A dictionary containing the converted LoRA weights in Nunchaku format. + + Notes + ----- + - If both original and extra LoRA weights are present, they are merged by concatenation. + - Handles both fused and unfused attention projections (e.g., qkv). + - Applies special packing for W4A16 linear layers (e.g., ``"adanorm_single"`` and ``"adanorm_zero"``). 
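+    - See :func:`convert_to_nunchaku_flux_transformer_block_lowrank_dict` and
+      :func:`convert_to_nunchaku_flux_single_transformer_block_lowrank_dict` for
+      the concrete ``local_name_map`` / ``convert_map`` pairs used for FLUX blocks.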
+ """ + logger.debug(f"Converting LoRA branch for block {candidate_block_name}...") + converted: dict[str, torch.Tensor] = {} + for converted_local_name, candidate_local_names in local_name_map.items(): + if isinstance(candidate_local_names, str): + candidate_local_names = [candidate_local_names] + # region original LoRA + orig_lora = ( + orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_down", None), + orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_up", None), + ) + if orig_lora[0] is None or orig_lora[1] is None: + assert orig_lora[0] is None and orig_lora[1] is None + orig_lora = None + elif orig_lora[0].numel() == 0 or orig_lora[1].numel() == 0: + assert orig_lora[0].numel() == 0 and orig_lora[1].numel() == 0 + orig_lora = None + else: + assert orig_lora[0] is not None and orig_lora[1] is not None + orig_lora = ( + unpack_lowrank_weight(orig_lora[0], down=True), + unpack_lowrank_weight(orig_lora[1], down=False), + ) + logger.debug( + f" - Found {converted_block_name} LoRA of {converted_local_name} (rank: {orig_lora[0].shape[0]})" + ) + # endregion + # region extra LoRA + extra_lora_list = None + + # if the qkv are already fused + if "qkv" in converted_local_name: + candidate_local_name = candidate_local_names[0] + assert "_q" in candidate_local_name + candidate_local_name = candidate_local_name.replace("_q", "_qkv") + lora_A = extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None) + lora_B = extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None) + if lora_A is None and lora_B is None: + extra_lora_list = None + else: + assert lora_A is not None and lora_B is not None + extra_lora_list = [(lora_A, lora_B)] + + # not fused, fuse them manually + if extra_lora_list is None: + extra_lora_list = [ + ( + extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None), + extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None), + ) + for candidate_local_name in candidate_local_names + ] + if any(lora[0] is not None or lora[1] is not None for lora in extra_lora_list): + # merge extra LoRAs into one LoRA + if len(extra_lora_list) > 1: + first_lora = None + for lora in extra_lora_list: + if lora[0] is not None: + assert lora[1] is not None + first_lora = lora + break + assert first_lora is not None + for lora_index in range(len(extra_lora_list)): + if extra_lora_list[lora_index][0] is None: + assert extra_lora_list[lora_index][1] is None + extra_lora_list[lora_index] = (first_lora[0].clone(), torch.zeros_like(first_lora[1])) + if all(lora[0].equal(extra_lora_list[0][0]) for lora in extra_lora_list): + # if all extra LoRAs have the same lora_down, use it + extra_lora_down = extra_lora_list[0][0] + extra_lora_up = torch.cat([lora[1] for lora in extra_lora_list], dim=0) + else: + extra_lora_down = torch.cat([lora[0] for lora in extra_lora_list], dim=0) + extra_lora_up_c = sum(lora[1].shape[0] for lora in extra_lora_list) + extra_lora_up_r = sum(lora[1].shape[1] for lora in extra_lora_list) + assert extra_lora_up_r == extra_lora_down.shape[0] + extra_lora_up = torch.zeros((extra_lora_up_c, extra_lora_up_r), dtype=extra_lora_down.dtype) + c, r = 0, 0 + for lora in extra_lora_list: + c_next, r_next = c + lora[1].shape[0], r + lora[1].shape[1] + extra_lora_up[c:c_next, r:r_next] = lora[1] + c, r = c_next, r_next + else: + extra_lora_down, extra_lora_up = extra_lora_list[0] + extra_lora: tuple[torch.Tensor, torch.Tensor] = 
(extra_lora_down, extra_lora_up) + logger.debug( + f" - Found {candidate_block_name} LoRA of {candidate_local_names} (rank: {extra_lora[0].shape[0]})" + ) + else: + extra_lora = None + # endregion + # region merge LoRA + if orig_lora is None: + if extra_lora is None: + lora = None + else: + logger.debug(" - Using extra LoRA") + lora = (extra_lora[0].to(default_dtype), extra_lora[1].to(default_dtype)) + elif extra_lora is None: + logger.debug(" - Using original LoRA") + lora = orig_lora + else: + try: + lora = ( + torch.cat([orig_lora[0], extra_lora[0].to(orig_lora[0].dtype)], dim=0), # [r, c] + torch.cat([orig_lora[1], extra_lora[1].to(orig_lora[1].dtype)], dim=1), # [c, r] + ) + logger.debug(f" - Merging original and extra LoRA (rank: {lora[0].shape[0]})") + except RuntimeError as e: + if "Sizes of tensors must match" in str(e): + # Handle various dimension mismatch cases for LoRA + logger.debug( + f" - Dimension mismatch detected: orig_lora[1]={orig_lora[1].shape}, extra_lora[1]={extra_lora[1].shape}" + ) + + # Handle dimension mismatch by using only the properly sized portion of extra_lora + # instead of trying to concatenate mismatched dimensions + + # Case 1: single_blocks linear1 [21504] -> mlp_fc1 [12288] + if extra_lora[1].shape[1] == 21504 and orig_lora[1].shape[1] == 12288: + # Use only the first 12288 dimensions from the 21504 extra LoRA + extra_lora_up_split = extra_lora[1][:, :12288].clone() + extra_lora_down = extra_lora[0].clone() + # logger.debug(f" - Dimension fix 21504->12288: using split extra LoRA instead of merge") + + # Use the split extra LoRA instead of concatenating + lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype)) + + # Case 2: transformer_blocks with different MLP dimensions (27648 -> 9216) + elif extra_lora[1].shape[1] == 27648 and orig_lora[1].shape[1] == 9216: + # Use only the first 9216 dimensions from the 27648 extra LoRA + extra_lora_up_split = extra_lora[1][:, :9216].clone() + extra_lora_down = extra_lora[0].clone() + # logger.debug(f" - Dimension fix 27648->9216: using split extra LoRA instead of merge") + + # Use the split extra LoRA instead of concatenating + lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype)) + + # Case 3: Other dimension ratios - try to find a reasonable split + elif extra_lora[1].shape[1] > orig_lora[1].shape[1]: + # Use only what we need from extra LoRA + target_dim = orig_lora[1].shape[1] + extra_lora_up_split = extra_lora[1][:, :target_dim].clone() + extra_lora_down = extra_lora[0].clone() + # logger.debug( + # f" - Dimension fix {extra_lora[1].shape[1]}->{target_dim}: using truncated extra LoRA" + # ) + + # Use the truncated extra LoRA instead of concatenating + lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype)) + + else: + # For cases where extra LoRA has fewer dimensions, use original LoRA only + # logger.warning( + # f" - Cannot split extra LoRA {extra_lora[1].shape[1]}->{orig_lora[1].shape[1]}, using original only" + # ) + lora = orig_lora + else: + raise e + # endregion + if lora is not None: + if convert_map[converted_local_name] == "adanorm_single": + update_state_dict( + converted, + { + "lora_down": pad(lora[0], divisor=16, dim=0), + "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1), + }, + prefix=converted_local_name, + ) + elif convert_map[converted_local_name] == "adanorm_zero": + update_state_dict( + converted, + { + "lora_down": pad(lora[0], divisor=16, dim=0), + 
"lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1), + }, + prefix=converted_local_name, + ) + elif convert_map[converted_local_name] == "linear": + update_state_dict( + converted, + { + "lora_down": pack_lowrank_weight(lora[0], down=True), + "lora_up": pack_lowrank_weight(lora[1], down=False), + }, + prefix=converted_local_name, + ) + return converted + + +def preprocess_single_blocks_lora( + extra_lora_dict: dict[str, torch.Tensor], candidate_block_name: str +) -> dict[str, torch.Tensor]: + """ + Preprocess LoRA weights from single_blocks format to match single_transformer_blocks structure. + + This function handles the architectural mismatch between old and new models: + - Old single_blocks: linear1 (fused 21504-dim layer) and linear2 + - New single_transformer_blocks: mlp_fc1 (12288-dim), qkv_proj (9216-dim), and mlp_fc2 + + The linear1 layer in the old architecture combines two functions: + 1. MLP projection (first 12288 dimensions) + 2. QKV projection for attention (last 9216 dimensions) + + These are split into separate layers in the new architecture. + """ + processed_dict = extra_lora_dict.copy() + + # Find all single_transformer_blocks keys that need preprocessing + single_blocks_keys = [k for k in extra_lora_dict.keys() if "single_transformer_blocks" in k and "linear" in k] + + logger.debug(f"Preprocessing LoRA for candidate: {candidate_block_name}") + logger.debug(f"All keys in extra_lora_dict: {list(extra_lora_dict.keys())[:10]}...") # Show first 10 keys + logger.debug(f"Found single_transformer_blocks keys: {single_blocks_keys[:5]}...") # Show first 5 keys + + if single_blocks_keys: + logger.debug(f"Found single_transformer_blocks LoRA keys, preprocessing for candidate: {candidate_block_name}") + + # The candidate_block_name is already "single_transformer_blocks.0" + # Look for linear1 and linear2 keys with this exact name + linear1_lora_A_key = f"{candidate_block_name}.linear1.lora_A.weight" + linear1_lora_B_key = f"{candidate_block_name}.linear1.lora_B.weight" + linear2_lora_A_key = f"{candidate_block_name}.linear2.lora_A.weight" + linear2_lora_B_key = f"{candidate_block_name}.linear2.lora_B.weight" + + logger.debug(f"Looking for keys: {linear1_lora_B_key}") + logger.debug( + f"Available keys matching pattern: {[k for k in extra_lora_dict.keys() if candidate_block_name in k][:5]}..." + ) + + if linear1_lora_B_key in extra_lora_dict: + linear1_lora_A = extra_lora_dict[linear1_lora_A_key] + linear1_lora_B = extra_lora_dict[linear1_lora_B_key] + + # Check if this is the problematic 21504 dimension case + if linear1_lora_B.shape[0] == 21504: + logger.debug( + f"Splitting linear1 LoRA weights: [21504, {linear1_lora_B.shape[1]}] -> " + f"mlp_fc1 [12288, {linear1_lora_B.shape[1]}] + qkv_proj [9216, {linear1_lora_B.shape[1]}]" + ) + + # Split linear1.lora_B [21504, rank] into two parts: + # 1. First 12288 dimensions -> mlp_fc1 + # 2. Last 9216 dimensions (12288:21504) -> qkv_proj + mlp_fc1_lora_B = linear1_lora_B[:12288, :].clone() + qkv_proj_lora_B = linear1_lora_B[12288:21504, :].clone() + + # The lora_A weight is reused for both new layers + # since it represents the down-projection from the input + mlp_fc1_lora_A = linear1_lora_A.clone() + qkv_proj_lora_A = linear1_lora_A.clone() + + # Map to new architecture: + # 1. proj_mlp corresponds to mlp_fc1 + processed_dict[f"{candidate_block_name}.proj_mlp.lora_A.weight"] = mlp_fc1_lora_A + processed_dict[f"{candidate_block_name}.proj_mlp.lora_B.weight"] = mlp_fc1_lora_B + + # 2. 
Map the QKV part to the attention layers + # Note: In the new architecture, this maps to attn.to_q, attn.to_k, attn.to_v + # which get fused into qkv_proj during the conversion + processed_dict[f"{candidate_block_name}.attn.to_q.lora_A.weight"] = qkv_proj_lora_A + processed_dict[f"{candidate_block_name}.attn.to_q.lora_B.weight"] = qkv_proj_lora_B[ + :3072, : + ] # Q projection + processed_dict[f"{candidate_block_name}.attn.to_k.lora_A.weight"] = qkv_proj_lora_A + processed_dict[f"{candidate_block_name}.attn.to_k.lora_B.weight"] = qkv_proj_lora_B[ + 3072:6144, : + ] # K projection + processed_dict[f"{candidate_block_name}.attn.to_v.lora_A.weight"] = qkv_proj_lora_A + processed_dict[f"{candidate_block_name}.attn.to_v.lora_B.weight"] = qkv_proj_lora_B[ + 6144:9216, : + ] # V projection + + # Handle linear2 -> mlp_fc2 mapping + if linear2_lora_B_key in extra_lora_dict: + linear2_lora_A = extra_lora_dict[linear2_lora_A_key] + linear2_lora_B = extra_lora_dict[linear2_lora_B_key] + + # Map linear2 to proj_out.linears.1 (mlp_fc2) + processed_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = linear2_lora_A + processed_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = linear2_lora_B + + # Remove original keys + processed_dict.pop(linear2_lora_A_key, None) + processed_dict.pop(linear2_lora_B_key, None) + + # Remove original linear1 keys + processed_dict.pop(linear1_lora_A_key, None) + processed_dict.pop(linear1_lora_B_key, None) + + return processed_dict + + +def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict( + orig_state_dict: dict[str, torch.Tensor], + extra_lora_dict: dict[str, torch.Tensor], + converted_block_name: str, + candidate_block_name: str, + default_dtype: torch.dtype = torch.bfloat16, +) -> dict[str, torch.Tensor]: + """ + Convert LoRA weights for a single FLUX transformer block from Diffusers to Nunchaku format. + + This function merges and converts LoRA weights from the original SVDQuant low-rank branch and an + extra LoRA dictionary for a given transformer block, producing a Nunchaku-compatible dictionary. + It handles both fused and unfused LoRA branches (e.g., qkv), and merges multiple LoRA branches as needed. + + Parameters + ---------- + orig_state_dict : dict[str, torch.Tensor] + Original state dict with LoRA weights, keys like ``"{block}.{local}.lora_down"`` and ``"{block}.{local}.lora_up"``. + extra_lora_dict : dict[str, torch.Tensor] + Extra LoRA weights, keys like ``"{block}.{local}.lora_A.weight"`` and ``"{block}.{local}.lora_B.weight"``. + converted_block_name : str + Block name for output (e.g., ``"transformer_blocks.0"``). + candidate_block_name : str + Block name for input lookup (e.g., ``"blocks.0"``). + default_dtype : torch.dtype, optional + Output tensor dtype (default: ``torch.bfloat16``). + + Returns + ------- + dict[str, torch.Tensor] + A dictionary containing the converted LoRA weights in Nunchaku format. + + Notes + ----- + - If both original and extra LoRA weights are present, they are merged by concatenation. + - Handles both fused and unfused attention projections (e.g., qkv). + - Applies special packing for W4A16 linear layers (e.g., ``"adanorm_single"`` and ``"adanorm_zero"``). 
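+    - Splits a fused ``proj_out`` LoRA column-wise into ``proj_out.linears.0``
+      (attention output) and ``proj_out.linears.1`` (MLP ``fc2``), then pads
+      ``proj_mlp`` and ``proj_out.linears.1`` to a common rank before conversion.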
+ """ + + # Preprocess single_blocks LoRA structure if needed + # extra_lora_dict = preprocess_single_blocks_lora(extra_lora_dict, candidate_block_name) + + if f"{candidate_block_name}.proj_out.lora_A.weight" in extra_lora_dict: + assert f"{converted_block_name}.out_proj.qweight" in orig_state_dict + assert f"{converted_block_name}.mlp_fc2.qweight" in orig_state_dict + n1 = orig_state_dict[f"{converted_block_name}.out_proj.qweight"].shape[1] * 2 + n2 = orig_state_dict[f"{converted_block_name}.mlp_fc2.qweight"].shape[1] * 2 + lora_down = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_A.weight"] + lora_up = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_B.weight"] + assert lora_down.shape[1] == n1 + n2 + extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_A.weight"] = lora_down[:, :n1].clone() + extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_B.weight"] = lora_up.clone() + extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = lora_down[:, n1:].clone() + extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = lora_up.clone() + extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight") + extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight") + + for component in ["lora_A", "lora_B"]: + fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight" + fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight" + fc1_v = extra_lora_dict[fc1_k] + fc2_v = extra_lora_dict[fc2_k] + dim = 0 if "lora_A" in fc1_k else 1 + + fc1_rank = fc1_v.shape[dim] + fc2_rank = fc2_v.shape[dim] + if fc1_rank != fc2_rank: + rank = max(fc1_rank, fc2_rank) + if fc1_rank < rank: + extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim) + if fc2_rank < rank: + extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim) + + return convert_to_nunchaku_transformer_block_lowrank_dict( + orig_state_dict=orig_state_dict, + extra_lora_dict=extra_lora_dict, + converted_block_name=converted_block_name, + candidate_block_name=candidate_block_name, + local_name_map={ + "norm.linear": "norm.linear", + "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"], + "norm_q": "attn.norm_q", + "norm_k": "attn.norm_k", + "out_proj": "proj_out.linears.0", + "mlp_fc1": "proj_mlp", + "mlp_fc2": "proj_out.linears.1", + }, + convert_map={ + "norm.linear": "adanorm_single", + "qkv_proj": "linear", + "out_proj": "linear", + "mlp_fc1": "linear", + "mlp_fc2": "linear", + }, + default_dtype=default_dtype, + ) + + +def convert_to_nunchaku_flux_transformer_block_lowrank_dict( + orig_state_dict: dict[str, torch.Tensor], + extra_lora_dict: dict[str, torch.Tensor], + converted_block_name: str, + candidate_block_name: str, + default_dtype: torch.dtype = torch.bfloat16, +) -> dict[str, torch.Tensor]: + """ + Convert LoRA weights for a single transformer block from Diffusers to Nunchaku format. + + Parameters + ---------- + orig_state_dict : dict[str, torch.Tensor] + Original model state dict. + extra_lora_dict : dict[str, torch.Tensor] + LoRA weights state dict. + converted_block_name : str + Output block name for the converted weights. + candidate_block_name : str + Input block name for lookup. + default_dtype : torch.dtype, optional + Output tensor dtype (default: torch.bfloat16). + + Returns + ------- + dict[str, torch.Tensor] + Converted LoRA weights in Nunchaku format. 
+ """ + return convert_to_nunchaku_transformer_block_lowrank_dict( + orig_state_dict=orig_state_dict, + extra_lora_dict=extra_lora_dict, + converted_block_name=converted_block_name, + candidate_block_name=candidate_block_name, + local_name_map={ + "norm1.linear": "norm1.linear", + "norm1_context.linear": "norm1_context.linear", + "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"], + "qkv_proj_context": ["attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj"], + "norm_q": "attn.norm_q", + "norm_k": "attn.norm_k", + "norm_added_q": "attn.norm_added_q", + "norm_added_k": "attn.norm_added_k", + "out_proj": "attn.to_out.0", + "out_proj_context": "attn.to_add_out", + "mlp_fc1": "ff.net.0.proj", + "mlp_fc2": "ff.net.2", + "mlp_context_fc1": "ff_context.net.0.proj", + "mlp_context_fc2": "ff_context.net.2", + }, + convert_map={ + "norm1.linear": "adanorm_zero", + "norm1_context.linear": "adanorm_zero", + "qkv_proj": "linear", + "qkv_proj_context": "linear", + "out_proj": "linear", + "out_proj_context": "linear", + "mlp_fc1": "linear", + "mlp_fc2": "linear", + "mlp_context_fc1": "linear", + "mlp_context_fc2": "linear", + }, + default_dtype=default_dtype, + ) + + +def convert_to_nunchaku_flux_lowrank_dict( + base_model: dict[str, torch.Tensor] | str, + lora: dict[str, torch.Tensor] | str, + default_dtype: torch.dtype = torch.bfloat16, +) -> dict[str, torch.Tensor]: + """ + Convert a base model and LoRA weights from Diffusers format to Nunchaku format. + + Parameters + ---------- + base_model : dict[str, torch.Tensor] or str + Base model weights or path to safetensors file. + lora : dict[str, torch.Tensor] or str + LoRA weights or path to safetensors file. + default_dtype : torch.dtype, optional + Output tensor dtype (default: torch.bfloat16). + + Returns + ------- + dict[str, torch.Tensor] + LoRA weights in Nunchaku format. + """ + if isinstance(base_model, str): + orig_state_dict = load_state_dict_in_safetensors(base_model) + else: + orig_state_dict = base_model + + if isinstance(lora, str): + # Load the LoRA - check if it has transformer prefix + temp_dict = load_state_dict_in_safetensors(lora) + if any(k.startswith("transformer.") for k in temp_dict.keys()): + # Standard FLUX LoRA with transformer prefix + extra_lora_dict = filter_state_dict(temp_dict, filter_prefix="transformer.") + # Remove the transformer. prefix after filtering + renamed_dict = {} + for k, v in extra_lora_dict.items(): + new_k = k.replace("transformer.", "") if k.startswith("transformer.") else k + renamed_dict[new_k] = v + extra_lora_dict = renamed_dict + else: + # Kontext LoRA without transformer prefix - use as is + extra_lora_dict = temp_dict + else: + # When called from to_nunchaku, lora is already processed by to_diffusers + # Keys should be in format: single_blocks.0.linear1.lora_A.weight + extra_lora_dict = lora + + # Add transformer. prefix and rename blocks to match expectations + renamed_dict = {} + for k, v in extra_lora_dict.items(): + new_k = k + # Add transformer. prefix and rename blocks + if k.startswith("single_blocks."): + new_k = "transformer.single_transformer_blocks." + k[14:] + elif k.startswith("double_blocks."): + new_k = "transformer.transformer_blocks." + k[14:] + elif k.startswith("proj_out."): + new_k = "transformer." + k + elif not k.startswith("transformer."): + new_k = "transformer." 
+ k + renamed_dict[new_k] = v + extra_lora_dict = renamed_dict + + # Now filter for transformer prefix and remove it for processing + extra_lora_dict = filter_state_dict(extra_lora_dict, filter_prefix="transformer.") + + # Remove the transformer. prefix for internal processing + renamed_dict = {} + for k, v in extra_lora_dict.items(): + new_k = k.replace("transformer.", "") if k.startswith("transformer.") else k + renamed_dict[new_k] = v + extra_lora_dict = renamed_dict + + vector_dict, unquantized_lora_dict = {}, {} + for k in list(extra_lora_dict.keys()): + v = extra_lora_dict[k] + if v.ndim == 1: + vector_dict[k.replace(".lora_B.bias", ".bias")] = extra_lora_dict.pop(k) + elif "transformer_blocks" not in k and "single_transformer_blocks" not in k: + # Only unquantized parts (like final_layer) go here + unquantized_lora_dict[k] = extra_lora_dict.pop(k) + + # Concatenate qkv_proj biases if present + for k in list(vector_dict.keys()): + if ".to_q." in k or ".add_q_proj." in k: + k_q = k + k_k = k.replace(".to_q.", ".to_k.").replace(".add_q_proj.", ".add_k_proj.") + k_v = k.replace(".to_q.", ".to_v.").replace(".add_q_proj.", ".add_v_proj.") + keys = [k_q, k_k, k_v] + values = [vector_dict.pop(key) for key in keys] + new_k = k_q.replace(".to_q.", ".to_qkv.").replace(".add_q_proj.", ".add_qkv_proj.") + vector_dict[new_k] = torch.cat(values, dim=0) + + for k in extra_lora_dict.keys(): + fc1_k = k + if "ff.net.0.proj" in k: + fc2_k = k.replace("ff.net.0.proj", "ff.net.2") + elif "ff_context.net.0.proj" in k: + fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2") + else: + continue + assert fc2_k in extra_lora_dict + fc1_v = extra_lora_dict[fc1_k] + fc2_v = extra_lora_dict[fc2_k] + dim = 0 if "lora_A" in fc1_k else 1 + + fc1_rank = fc1_v.shape[dim] + fc2_rank = fc2_v.shape[dim] + if fc1_rank != fc2_rank: + rank = max(fc1_rank, fc2_rank) + if fc1_rank < rank: + extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim) + if fc2_rank < rank: + extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim) + + block_names: set[str] = set() + for param_name in orig_state_dict.keys(): + if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")): + block_names.add(".".join(param_name.split(".")[:2])) + block_names = sorted(block_names, key=lambda x: (x.split(".")[0], int(x.split(".")[-1]))) + logger.debug(f"Converting {len(block_names)} transformer blocks...") + converted: dict[str, torch.Tensor] = {} + for block_name in tqdm(block_names, dynamic_ncols=True, desc="Converting LoRAs to nunchaku format"): + if block_name.startswith("transformer_blocks"): + convert_fn = convert_to_nunchaku_flux_transformer_block_lowrank_dict + else: + convert_fn = convert_to_nunchaku_flux_single_transformer_block_lowrank_dict + update_state_dict( + converted, + convert_fn( + orig_state_dict=orig_state_dict, + extra_lora_dict=extra_lora_dict, + converted_block_name=block_name, + candidate_block_name=block_name, + default_dtype=default_dtype, + ), + prefix=block_name, + ) + + converted.update(unquantized_lora_dict) + converted.update(vector_dict) + return converted + + +def to_nunchaku( + input_lora: str | dict[str, torch.Tensor], + base_sd: str | dict[str, torch.Tensor], + dtype: str | torch.dtype = torch.bfloat16, + output_path: str | None = None, +) -> dict[str, torch.Tensor]: + """ + Convert LoRA weights to Nunchaku format. + + Parameters + ---------- + input_lora : str or dict[str, torch.Tensor] + Path or dictionary of LoRA weights in Diffusers format. 
Can be composed of multiple LoRA weights. + base_sd : str or dict[str, torch.Tensor] + Path or dictionary of base quantized model weights. + dtype : str or torch.dtype, optional + Output data type ("bfloat16", "float16", or torch dtype). Default is torch.bfloat16. + output_path : str, optional + If provided, saves the result to this path. + + Returns + ------- + dict[str, torch.Tensor] + LoRA weights in Nunchaku format. + + Example + ------- + .. code-block:: python + + nunchaku_weights = to_nunchaku("lora.safetensors", "base_model.safetensors") + nunchaku_weights = to_nunchaku(lora_dict, base_dict) + """ + if isinstance(input_lora, str): + tensors = load_state_dict_in_safetensors(input_lora, device="cpu") + else: + tensors = input_lora + if is_nunchaku_format(tensors): + logger.debug("Already in nunchaku format, no conversion needed.") + converted = tensors + else: + extra_lora_dict = to_diffusers(tensors) + + if isinstance(base_sd, str): + orig_state_dict = load_state_dict_in_safetensors(base_sd) + else: + orig_state_dict = base_sd + + if isinstance(dtype, str): + if dtype == "bfloat16": + dtype = torch.bfloat16 + elif dtype == "float16": + dtype = torch.float16 + else: + raise ValueError(f"Unsupported dtype {dtype}.") + else: + assert isinstance(dtype, torch.dtype) + + converted = convert_to_nunchaku_flux_lowrank_dict( + base_model=orig_state_dict, lora=extra_lora_dict, default_dtype=dtype + ) + if output_path is not None: + output_dir = os.path.dirname(os.path.abspath(output_path)) + os.makedirs(output_dir, exist_ok=True) + save_file(converted, output_path) + return converted + + +#### fuse vectors #### + + +def fuse_vectors( + vectors: dict[str, torch.Tensor], base_sd: dict[str, torch.Tensor], strength: float = 1 +) -> dict[str, torch.Tensor]: + """ + Fuse vector (bias) terms from LoRA into the base model. + + Parameters + ---------- + vectors : dict[str, torch.Tensor] + LoRA vector terms. + base_sd : dict[str, torch.Tensor] + Base model state dict. + strength : float, optional + Scaling factor for LoRA vectors. + + Returns + ------- + dict[str, torch.Tensor] + State dict with fused vectors. + """ + tensors: dict[str, torch.Tensor] = {} + packer = NunchakuWeightPacker(bits=4) + for k, v in base_sd.items(): + if v.ndim != 1 or "smooth" in k or (k.startswith("single_transformer_blocks.") and ".mlp_fc2." 
in k): + continue + if "norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k: + new_k = k.replace(".norm_", ".attn.norm_") + new_v = vectors.get(new_k, None) + tensors[k] = v if new_v is None else new_v + + elif "norm.linear" in k or "norm1.linear" in k or "norm1_context.linear" in k: + diff = vectors.get(k, None) + + if diff is not None: + if k.startswith("single_transformer_blocks."): + adanorm_splits = 3 + else: + assert k.startswith("transformer_blocks.") + adanorm_splits = 6 + diff = diff.view(adanorm_splits, -1).transpose(0, 1).reshape(-1) + tensors[k] = v + diff * strength + else: + tensors[k] = v + + else: + if k.startswith("single_transformer_blocks."): + name_map = {".qkv_proj.": ".attn.to_qkv.", ".out_proj.": ".proj_out.", ".mlp_fc1.": ".proj_mlp."} + else: + assert k.startswith("transformer_blocks.") + name_map = { + ".qkv_proj.": ".attn.to_qkv.", + ".qkv_proj_context.": ".attn.add_qkv_proj.", + ".out_proj.": ".attn.to_out.0.", + ".out_proj_context.": ".attn.to_add_out.", + ".mlp_fc1.": ".ff.net.0.proj.", + ".mlp_fc2.": ".ff.net.2.", + ".mlp_context_fc1.": ".ff_context.net.0.proj.", + ".mlp_context_fc2.": ".ff_context.net.2.", + } + + for original_pattern, new_pattern in name_map.items(): + if original_pattern in k: + new_k = k.replace(original_pattern, new_pattern) + diff = vectors.get(new_k, None) + if diff is not None: + diff = diff * strength + diff = packer.pad_scale(diff, group_size=-1) + diff = packer.pack_scale(diff, group_size=-1) + tensors[k] = v + diff + break + + return tensors diff --git a/nunchaku/lora/flux/packer.py b/nunchaku/lora/flux/packer.py new file mode 100644 index 0000000000000000000000000000000000000000..7ab18f6ea4099da6bd74a71d3b0788832fe4c91a --- /dev/null +++ b/nunchaku/lora/flux/packer.py @@ -0,0 +1,517 @@ +""" +Weight packing utilities for Nunchaku quantization. + +This module provides concise tools for packing and unpacking weight tensors, +optimized for efficient GPU computation using Matrix Multiply and Accumulate (MMA) operations. +""" + +import torch + +from ...utils import ceil_divide +from .utils import pad + + +class MmaWeightPackerBase: + """ + Base class for Matrix Multiply and Accumulate (MMA) weight packing. + + Packs weight tensors for efficient GPU computation using MMA operations. + Handles tile sizes, memory layout, and packing parameters. + + Parameters + ---------- + bits : int + Quantization bits. Must be 1, 4, 8, 16, or 32. + warp_n : int + Warp size in the n dimension. + comp_n : int, optional + Computation tile size in n (default: 16). + comp_k : int, optional + Computation tile size in k (default: 256 // bits). + + Raises + ------ + AssertionError + If bits or tile/pack sizes are invalid. + + Attributes + ---------- + comp_n : int + Tile size in n for MMA computation. + comp_k : int + Tile size in k for MMA computation. + insn_n : int + MMA instruction tile size in n. + insn_k : int + MMA instruction tile size in k. + num_lanes : int + Number of lanes (threads) in a warp. + num_k_lanes : int + Number of lanes in k. + num_n_lanes : int + Number of lanes in n. + warp_n : int + Warp size in n. + reg_k : int + Elements in a register in k. + reg_n : int + Elements in a register in n. + k_pack_size : int + Elements in a pack in k. + n_pack_size : int + Elements in a pack in n. + pack_size : int + Elements in a pack accessed by a lane. + mem_k : int + Tile size in k for one memory access. + mem_n : int + Tile size in n for one memory access. + num_k_packs : int + Packs in k for one memory access. 
+ num_n_packs : int + Packs in n for one memory access. + """ + + def __init__(self, bits: int, warp_n: int, comp_n: int = None, comp_k: int = None): + self.bits = bits + assert self.bits in (1, 4, 8, 16, 32), "weight bits should be 1, 4, 8, 16, or 32." + + # region compute tile size + self.comp_n = comp_n if comp_n is not None else 16 + # smallest tile size in `n` dimension for MMA computation. + self.comp_k = comp_k if comp_k is not None else 256 // self.bits + # smallest tile size in `k` dimension for MMA computation. + # the smallest MMA computation may contain several MMA instructions + self.insn_n = 8 # mma instruction tile size in `n` dimension + # tile size in `n` dimension for MMA instruction. + self.insn_k = self.comp_k + # tile size in `k` dimension for MMA instruction. + assert self.insn_k * self.bits in ( + 128, + 256, + ), f"insn_k ({self.insn_k}) * bits ({self.bits}) should be 128 or 256." + assert self.comp_n % self.insn_n == 0, f"comp_n ({self.comp_n}) should be divisible by insn_n ({self.insn_n})." + self.num_lanes = 32 + # there are 32 lanes (or threads) in a warp. + self.num_k_lanes = 4 + self.num_n_lanes = 8 + assert ( + warp_n >= self.comp_n and warp_n % self.comp_n == 0 + ), f"warp_n ({warp_n}) should be divisible by comp_n({self.comp_n})." + self.warp_n = warp_n + # endregion + # region memory + self.reg_k = 32 // self.bits + # number of elements in a register in `k` dimension. + self.reg_n = 1 + # number of elements in a register in `n` dimension (always 1). + self.k_pack_size = self.comp_k // (self.num_k_lanes * self.reg_k) + # number of elements in a pack in `k` dimension. + self.n_pack_size = self.comp_n // (self.num_n_lanes * self.reg_n) + # number of elements in a pack in `n` dimension. + self.pack_size = self.k_pack_size * self.n_pack_size + # number of elements in a pack accessed by a lane at a time. + assert 1 <= self.pack_size <= 4, "pack size should be less than or equal to 4." + assert self.k_pack_size * self.num_k_lanes * self.reg_k == self.comp_k + assert self.n_pack_size * self.num_n_lanes * self.reg_n == self.comp_n + self.mem_k = self.comp_k + # the tile size in `k` dimension for one tensor memory access. + self.mem_n = warp_n + # the tile size in `n` dimension for one tensor memory access. + self.num_k_packs = self.mem_k // (self.k_pack_size * self.num_k_lanes * self.reg_k) + # number of packs in `k` dimension for one tensor memory access. + self.num_n_packs = self.mem_n // (self.n_pack_size * self.num_n_lanes * self.reg_n) + # number of packs in `n` dimension for one tensor memory access. + # endregion + + def get_view_shape(self, n: int, k: int) -> tuple[int, int, int, int, int, int, int, int, int, int]: + """ + Returns the tensor view shape for MMA operations. + + Parameters + ---------- + n : int + Output channel size (must be divisible by mem_n). + k : int + Input channel size (must be divisible by mem_k). + + Returns + ------- + tuple of int + (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n, + k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k) + + Raises + ------ + AssertionError + If n or k is not divisible by mem_n or mem_k. + """ + assert n % self.mem_n == 0, "output channel size should be divisible by mem_n." + assert k % self.mem_k == 0, "input channel size should be divisible by mem_k." 
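+        # The 10-tuple mirrors the packed layout hierarchy along each dimension:
+        # tiles -> packs -> pack elements -> lanes -> registers.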
+        return (
+            n // self.mem_n,
+            self.num_n_packs,
+            self.n_pack_size,
+            self.num_n_lanes,
+            self.reg_n,
+            k // self.mem_k,
+            self.num_k_packs,
+            self.k_pack_size,
+            self.num_k_lanes,
+            self.reg_k,
+        )
+
+
+class NunchakuWeightPacker(MmaWeightPackerBase):
+    """
+    Nunchaku-specific weight packer. Provides packing of quantized weights,
+    scales, and low-rank weights in the layout the Nunchaku kernels expect.
+
+    Parameters
+    ----------
+    bits : int
+        Number of quantization bits. Must be 1, 4, 8, 16, or 32.
+    warp_n : int, optional
+        Warp size in the n dimension. Default is 128.
+
+    Attributes
+    ----------
+    num_k_unrolls : int
+        Number of unrolls in the k dimension (always 2 for Nunchaku).
+    """
+
+    def __init__(self, bits: int, warp_n: int = 128):
+        super().__init__(bits=bits, warp_n=warp_n)
+        self.num_k_unrolls = 2
+
+    def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        """
+        Pack quantized weight tensor for Nunchaku MMA.
+
+        Parameters
+        ----------
+        weight : torch.Tensor
+            Quantized weight tensor of dtype torch.int32 and shape (n, k).
+
+        Returns
+        -------
+        torch.Tensor
+            Packed weight tensor of dtype torch.int8.
+        """
+        assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
+        n, k = weight.shape
+        assert n % self.mem_n == 0, f"output channel size ({n}) should be divisible by mem_n ({self.mem_n})."
+        # currently, Nunchaku does not check the boundary of the unrolled `k` dimension
+        assert k % (self.mem_k * self.num_k_unrolls) == 0, (
+            f"input channel size ({k}) should be divisible by "
+            f"mem_k ({self.mem_k}) * num_k_unrolls ({self.num_k_unrolls})."
+        )
+        n_tiles, k_tiles = n // self.mem_n, k // self.mem_k
+        weight = weight.reshape(
+            n_tiles,
+            self.num_n_packs,  # 8 when warp_n = 128
+            self.n_pack_size,  # always 2 in nunchaku
+            self.num_n_lanes,  # constant 8
+            self.reg_n,  # constant 1
+            k_tiles,
+            self.num_k_packs,  # 1
+            self.k_pack_size,  # always 2 in nunchaku
+            self.num_k_lanes,  # constant 4
+            self.reg_k,  # always 8 = 32 bits / 4 bits
+        )
+        # (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n, k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)
+        # =>
+        # (n_tiles, k_tiles, num_k_packs, num_n_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
+        weight = weight.permute(0, 5, 6, 1, 3, 8, 2, 7, 4, 9).contiguous()
+        assert weight.shape[4:-2] == (8, 4, 2, 2)
+        if self.bits == 4:
+            weight = weight.bitwise_and_(0xF)
+            shift = torch.arange(0, 32, 4, dtype=torch.int32, device=weight.device)
+            weight = weight.bitwise_left_shift_(shift)
+            weight = weight.sum(dim=-1, dtype=torch.int32)
+        elif self.bits == 8:
+            weight = weight.bitwise_and_(0xFF)
+            shift = torch.arange(0, 32, 8, dtype=torch.int32, device=weight.device)
+            weight = weight.bitwise_left_shift_(shift)
+            weight = weight.sum(dim=-1, dtype=torch.int32)
+        else:
+            raise NotImplementedError(f"weight bits {self.bits} is not supported.")
+        return weight.view(dtype=torch.int8).view(n, -1)  # assume little-endian
+
+    def pack_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
+        """
+        Pack scale tensor for Nunchaku MMA.
+
+        Parameters
+        ----------
+        scale : torch.Tensor
+            Scale tensor of dtype torch.float16 or torch.bfloat16.
+        group_size : int
+            Group size for quantization.
+
+        Returns
+        -------
+        torch.Tensor
+            Packed scale tensor.
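+
+        Examples
+        --------
+        A minimal sketch with the default ``warp_n=128``; a 1-D tensor (e.g. a
+        fused bias) packed with ``group_size=-1`` keeps its flat shape:
+
+        .. code-block:: python
+
+            packer = NunchakuWeightPacker(bits=4)
+            scale = torch.zeros(3072, dtype=torch.bfloat16)
+            packed = packer.pack_scale(scale, group_size=-1)
+            assert packed.shape == (3072,)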
+ """ + if self.check_if_micro_scale(group_size=group_size): + return self.pack_micro_scale(scale, group_size=group_size) + # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c + assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16." + n = scale.shape[0] + # nunchaku load scales all in one access + # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales + # scale loading is parallelized in `n` dimension, that is, + # `num_s_lanes` in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements + # each element in `n` dimension is 16 bit as it contains 1 fp16 + # min `s_pack_size` set to 2 element, since each lane at least holds 2 accumulator results in `n` dimension + # max `s_pack_size` set to 128b/16b = 8 elements + # for `warp_n = 8`, we have + # `s_pack_size = 2`, `num_s_lanes = 4`, `num_s_packs = 1` + # for `warp_n = 128`, we have + # `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1` + # for `warp_n = 512`, we have + # `s_pack_size = 8`, `num_s_lanes = 32`, `num_s_packs = 2` + s_pack_size = min(max(self.warp_n // self.num_lanes, 2), 8) + num_s_lanes = min(self.num_lanes, self.warp_n // s_pack_size) + num_s_packs = self.warp_n // (s_pack_size * num_s_lanes) + warp_s = num_s_packs * num_s_lanes * s_pack_size + assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights." + # `num_n_lanes = 8 (constant)` generates 8 elements consecutive in `n` dimension + # however, they are held by 4 lanes, each lane holds 2 elements in `n` dimension + # thus, we start from first 4 lanes, assign 2 elements to each lane, until all 8 elements are assigned + # we then repeat the process for the same 4 lanes, until each lane holds `s_pack_size` elements + # finally, we move to next 4 lanes, and repeat the process until all `num_s_lanes` lanes are assigned + # the process is repeated for `num_s_packs` times + # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1` + # wscales store order: + # 0 1 8 9 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x) + # 2 3 10 11 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x) + # 4 5 12 13 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x) + # 6 7 14 15 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x) + # 16 17 24 25 <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x) + # ... + # 22 23 30 31 <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x) + # ... ... + # 112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x) + # ... + # 118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x) + scale = scale.reshape(n // warp_s, num_s_packs, num_s_lanes // 4, s_pack_size // 2, 4, 2, -1) + scale = scale.permute(0, 6, 1, 2, 4, 3, 5).contiguous() + return scale.view(-1) if group_size == -1 else scale.view(-1, n) # the shape is just used for validation + + def pack_micro_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor: + """ + Pack micro scale tensor for Nunchaku MMA. + + Parameters + ---------- + scale : torch.Tensor + Scale tensor of dtype torch.float16 or torch.bfloat16. + group_size : int + Group size for quantization (must be 16). + + Returns + ------- + torch.Tensor + Packed micro scale tensor. + """ + assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16." 
+        assert scale.max() <= 448, "scale should be at most 448."
+        assert scale.min() >= -448, "scale should be at least -448."
+        assert group_size == 16, "currently only support group size 16."
+        assert self.insn_k == 64, "insn_k should be 64."
+        scale = scale.to(dtype=torch.float8_e4m3fn)
+        n = scale.shape[0]
+        assert self.warp_n >= 32, "currently only support warp_n >= 32."
+        # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales.
+        # scale loading is parallelized in the `n` dimension, that is,
+        # `num_s_lanes` lanes in a warp load `num_s_packs` packs of `s_pack_size` elements, `warp_s` elements in total.
+        # each element in the `n` dimension is 32 bits, as it contains 4 fp8 in the `k` dimension.
+        # min `s_pack_size` is 1 element.
+        # max `s_pack_size` is 128b/32b = 4 elements.
+        # for `warp_n = 128`, we have
+        #     `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
+        # for `warp_n = 512`, we have
+        #     `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 4`
+        s_pack_size = min(max(self.warp_n // self.num_lanes, 1), 4)
+        num_s_lanes = 4 * 8  # 32 lanes are divided into 4 pieces; each piece has 8 lanes at a stride of 4
+        num_s_packs = ceil_divide(self.warp_n, s_pack_size * num_s_lanes)
+        warp_s = num_s_packs * num_s_lanes * s_pack_size
+        assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
+        # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-scaling-thread-id-b-selection
+        # we start from the first 8 lanes at a stride of 4 and assign 1 element to each lane until all 8 elements are assigned.
+        # we then move to the next 8 lanes at a stride of 4 and repeat the process until all 32 lanes are assigned.
+        # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
+        # wscales store order:
+        #    0  32  64  96  <-- load by lane 0
+        #    8  40  72 104  <-- load by lane 1
+        #   16  48  80 112  <-- load by lane 2
+        #   24  56  88 120  <-- load by lane 3
+        #    1  33  65  97  <-- load by lane 4
+        #   ...
+        #   25  57  89 121  <-- load by lane 7
+        #   ...
+        #    7  39  71 103  <-- load by lane 28
+        #   ...
+        #   31  63  95 127  <-- load by lane 31
+        scale = scale.view(n // warp_s, num_s_packs, s_pack_size, 4, 8, -1, self.insn_k // group_size)
+        scale = scale.permute(0, 5, 1, 4, 3, 2, 6).contiguous()
+        return scale.view(-1, n)  # the shape is just used for validation
+
+    def pack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
+        """
+        Pack a low-rank weight tensor.
+
+        Parameters
+        ----------
+        weight : torch.Tensor
+            Low-rank weight tensor of dtype torch.float16 or torch.bfloat16.
+        down : bool
+            If True, the weight is the down projection of the low-rank branch.
+
+        Returns
+        -------
+        torch.Tensor
+            Packed low-rank weight tensor.
+        """
+        assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
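+        # the low-rank factors stay in 16-bit precision; packing only rearranges
+        # their layout so that each lane's 32-bit register (2 fp16 values along k)
+        # is contiguous for the mma kernel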
+ reg_n, reg_k = 1, 2 # reg_n is always 1, reg_k is 32 bits // 16 bits = 2 + pack_n = self.n_pack_size * self.num_n_lanes * reg_n + pack_k = self.k_pack_size * self.num_k_lanes * reg_k + weight = pad(weight, divisor=(pack_n, pack_k), dim=(0, 1)) + if down: + r, c = weight.shape + r_packs, c_packs = r // pack_n, c // pack_k + weight = weight.view(r_packs, pack_n, c_packs, pack_k).permute(2, 0, 1, 3) + else: + c, r = weight.shape + c_packs, r_packs = c // pack_n, r // pack_k + weight = weight.view(c_packs, pack_n, r_packs, pack_k).permute(0, 2, 1, 3) + weight = weight.reshape( + c_packs, r_packs, self.n_pack_size, self.num_n_lanes, reg_n, self.k_pack_size, self.num_k_lanes, reg_k + ) + # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k) + # => + # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k) + weight = weight.permute(0, 1, 3, 6, 2, 5, 4, 7).contiguous() + return weight.view(c, r) + + def unpack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor: + """ + Unpack low-rank weight tensor. + + Parameters + ---------- + weight : torch.Tensor + Packed low-rank weight tensor of dtype torch.float16 or torch.bfloat16. + down : bool + If True, weight is for down projection in low-rank branch. + + Returns + ------- + torch.Tensor + Unpacked low-rank weight tensor. + """ + c, r = weight.shape + assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}." + reg_n, reg_k = 1, 2 # reg_n is always 1, reg_k is 32 bits // 16 bits = 2 + pack_n = self.n_pack_size * self.num_n_lanes * reg_n + pack_k = self.k_pack_size * self.num_k_lanes * reg_k + if down: + r_packs, c_packs = r // pack_n, c // pack_k + else: + c_packs, r_packs = c // pack_n, r // pack_k + weight = weight.view( + c_packs, r_packs, self.num_n_lanes, self.num_k_lanes, self.n_pack_size, self.k_pack_size, reg_n, reg_k + ) + # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k) + # => + # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k) + weight = weight.permute(0, 1, 4, 2, 6, 5, 3, 7).contiguous() + weight = weight.view(c_packs, r_packs, pack_n, pack_k) + if down: + weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c) + else: + weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r) + return weight + + def check_if_micro_scale(self, group_size: int) -> bool: + """ + Check if micro scale packing is required. + + Parameters + ---------- + group_size : int + Group size for quantization. + + Returns + ------- + bool + True if micro scale packing is required. + """ + return self.insn_k == group_size * 4 + + def pad_weight(self, weight: torch.Tensor) -> torch.Tensor: + """ + Pad weight tensor to required shape. + + Parameters + ---------- + weight : torch.Tensor + Weight tensor of shape (n, k). + + Returns + ------- + torch.Tensor + Padded weight tensor. + """ + assert weight.ndim == 2, "weight tensor should be 2D." + return pad(weight, divisor=(self.mem_n, self.mem_k * self.num_k_unrolls), dim=(0, 1)) + + def pad_scale(self, scale: torch.Tensor, group_size: int, fill_value: float = 0) -> torch.Tensor: + """ + Pad scale tensor to required shape. + + Parameters + ---------- + scale : torch.Tensor + Scale tensor. + group_size : int + Group size for quantization. + fill_value : float, optional + Value to use for padding. Default is 0. + + Returns + ------- + torch.Tensor + Padded scale tensor. 
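+
+        Notes
+        -----
+        A shape-only sketch (hypothetical sizes, assuming the default
+        ``warp_n = 128`` and ``num_k_unrolls = 2``)::
+
+            packer = NunchakuWeightPacker(bits=4)
+            scale = torch.ones(100, 4, dtype=torch.float16)  # n = 100, 4 groups
+            padded = packer.pad_scale(scale, group_size=64)
+            assert padded.shape == (128, 1, 4, 1)  # n padded up to warp_n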
+ """ + if group_size > 0 and scale.numel() > scale.shape[0]: + scale = scale.view(scale.shape[0], 1, -1, 1) + if self.check_if_micro_scale(group_size=group_size): + scale = pad(scale, divisor=(self.warp_n, self.insn_k // group_size), dim=(0, 2), fill_value=fill_value) + else: + scale = pad(scale, divisor=(self.warp_n, self.num_k_unrolls), dim=(0, 2), fill_value=fill_value) + else: + scale = pad(scale, divisor=self.warp_n, dim=0, fill_value=fill_value) + return scale + + def pad_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor: + """ + Pad low-rank weight tensor to required shape. + + Parameters + ---------- + weight : torch.Tensor + Low-rank weight tensor. + down : bool + If True, weight is for down projection in low-rank branch. + + Returns + ------- + torch.Tensor + Padded low-rank weight tensor. + """ + assert weight.ndim == 2, "weight tensor should be 2D." + return pad(weight, divisor=self.warp_n, dim=1 if down else 0) diff --git a/nunchaku/lora/flux/utils.py b/nunchaku/lora/flux/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..65b86ea815e365d11a8d91634f0f677e31d17e84 --- /dev/null +++ b/nunchaku/lora/flux/utils.py @@ -0,0 +1,94 @@ +""" +Utility functions for LoRAs in Flux models. +""" + +import typing as tp + +import torch + +from ...utils import ceil_divide, load_state_dict_in_safetensors + + +def is_nunchaku_format(lora: str | dict[str, torch.Tensor]) -> bool: + """ + Check if LoRA weights are in Nunchaku format. + + Parameters + ---------- + lora : str or dict[str, torch.Tensor] + Path to a safetensors file or a dictionary of LoRA weights. + + Returns + ------- + bool + True if the weights are in Nunchaku format, False otherwise. + + Examples + -------- + >>> is_nunchaku_format("path/to/lora.safetensors") + True + """ + if isinstance(lora, str): + tensors = load_state_dict_in_safetensors(lora, device="cpu", return_metadata=False) + assert isinstance(tensors, dict), "Expected dict when return_metadata=False" + else: + tensors = lora + + for k in tensors.keys(): + if ".mlp_fc" in k or "mlp_context_fc1" in k: + return True + return False + + +def pad( + tensor: tp.Optional[torch.Tensor], + divisor: int | tp.Sequence[int], + dim: int | tp.Sequence[int], + fill_value: float | int = 0, +) -> torch.Tensor | None: + """ + Pad a tensor so specified dimensions are divisible by given divisors. + + Parameters + ---------- + tensor : torch.Tensor or None + The tensor to pad. If None, returns None. + divisor : int or sequence of int + Divisor(s) for the dimension(s) to pad. + dim : int or sequence of int + Dimension(s) to pad. + fill_value : float or int, optional + Value to use for padding (default: 0). + + Returns + ------- + torch.Tensor or None + The padded tensor, or None if input tensor was None. 
+ + Examples + -------- + >>> tensor = torch.randn(10, 20) + >>> pad(tensor, divisor=16, dim=0).shape + torch.Size([16, 20]) + >>> pad(tensor, divisor=[16, 32], dim=[0, 1]).shape + torch.Size([16, 32]) + """ + if isinstance(divisor, int): + if divisor <= 1: + return tensor + elif all(d <= 1 for d in divisor): + return tensor + if tensor is None: + return None + shape = list(tensor.shape) + if isinstance(dim, int): + assert isinstance(divisor, int) + shape[dim] = ceil_divide(shape[dim], divisor) * divisor + else: + if isinstance(divisor, int): + divisor = [divisor] * len(dim) + for d, div in zip(dim, divisor, strict=True): + shape[d] = ceil_divide(shape[d], div) * div + result = torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device) + result[[slice(0, extent) for extent in tensor.shape]] = tensor + return result diff --git a/nunchaku/models/__init__.py b/nunchaku/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc7e804c946ab9242bed0af708c348b43483a31 --- /dev/null +++ b/nunchaku/models/__init__.py @@ -0,0 +1,9 @@ +from .text_encoders.t5_encoder import NunchakuT5EncoderModel +from .transformers import ( + NunchakuFluxTransformer2dModel, +) + +__all__ = [ + "NunchakuFluxTransformer2dModel", + "NunchakuT5EncoderModel", +] diff --git a/nunchaku/models/attention.py b/nunchaku/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9fdbce2c3cdc2ed81a4a47caa91111bbbf3369 --- /dev/null +++ b/nunchaku/models/attention.py @@ -0,0 +1,123 @@ +""" +Nunchaku quantized attention-related modules. +""" + +import torch +from diffusers.models.activations import GELU +from diffusers.models.attention import FeedForward +from torch import nn + +from ..ops.fused import fused_gelu_mlp +from .linear import SVDQW4A4Linear + + +class NunchakuBaseAttention(nn.Module): + """ + Base class for Nunchaku attention modules. + + Provides a common interface for attention modules with processor selection. + + Parameters + ---------- + processor : str, optional + Name of the attention processor to use. Default is "flashattn2". + *args, **kwargs : + Additional arguments for subclass initialization. + """ + + def __init__(self, processor: str = "flashattn2", *args, **kwargs): + super(NunchakuBaseAttention, self).__init__() + self.processor = None + self.set_processor(processor) + + def set_processor(self, processor: str): + """ + Set the attention processor. Must be implemented by subclasses. + + Parameters + ---------- + processor : str + Name of the processor to use. + + Raises + ------ + NotImplementedError + If not implemented in subclass. + """ + raise NotImplementedError("Subclass must implement this method") + + +def _patch_linear(module: nn.Module, linear_cls, **kwargs) -> nn.Module: + """ + Recursively replace all nn.Linear modules in a given module with a custom linear class. + + Parameters + ---------- + module : nn.Module + The module to patch. + linear_cls : type + The custom linear class to use for replacement. + **kwargs : + Additional arguments passed to ``from_linear``. + + Returns + ------- + nn.Module + The patched module with custom linear layers. + """ + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr(module, name, linear_cls.from_linear(child, **kwargs)) + else: + _patch_linear(child, linear_cls, **kwargs) + return module + + +class NunchakuFeedForward(FeedForward): + """ + Quantized feed-forward (MLP) block with fused GELU support. 
+ + Replaces linear layers in a FeedForward block with :class:`~nunchaku.models.linear.SVDQW4A4Linear` for quantized inference. + Supports fused GELU-MLP computation for efficiency. + + Parameters + ---------- + ff : FeedForward + Source FeedForward block to quantize. + **kwargs : + Additional arguments for SVDQW4A4Linear. + + Notes + ----- + For int4 quantization, the activation of the second MLP layer is shifted to be unsigned. + """ + + def __init__(self, ff: FeedForward, **kwargs): + super(FeedForward, self).__init__() + self.net = _patch_linear(ff.net, SVDQW4A4Linear, **kwargs) + # For int4, shift the activation of mlp_fc2 to make it unsigned + self.net[2].act_unsigned = self.net[2].precision != "nvfp4" + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the quantized feed-forward block. + It will call :func:`~nunchaku.ops.fused.fused_gelu_mlp` if the first layer is GELU; + otherwise, apply modules sequentially. + + Parameters + ---------- + hidden_states : torch.Tensor, shape (B, D) + Input tensor. + + Returns + ------- + torch.Tensor, shape (B, D) + Output tensor after feed-forward transformation. + """ + if isinstance(self.net[0], GELU): + return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2]) + else: + # Fallback to original implementation + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states diff --git a/nunchaku/models/embeddings.py b/nunchaku/models/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..12b50966200dc321c6b99679ac7cadfb07f8e594 --- /dev/null +++ b/nunchaku/models/embeddings.py @@ -0,0 +1,138 @@ +""" +Embedding layers for Nunchaku. +""" + +import diffusers +import torch +from packaging.version import Version +from torch import nn + + +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + """ + Rotary positional embedding function. + Copied from https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/transformers/transformer_flux.py#L38 + + Parameters + ---------- + pos : torch.Tensor, shape (..., n), dtype int + Position indices. + dim : int + Embedding dimension (must be even). + theta : int + Rotary base. + + Returns + ------- + out : torch.Tensor, shape (B, M, D//2, 1, 2), dtype float32 + Rotary embedding tensor. + + Notes + ----- + - B: batch size + - M: sequence length + - D: embedding dimension + """ + assert dim % 2 == 0, "The dimension must be even." + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + + # Sin/cos representation for rotary embedding + cos_out = torch.cos(out) + sin_out = torch.sin(out) + stacked_out = torch.stack([sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 1, 2) + + return out.float() + + +class NunchakuFluxPosEmbed(nn.Module): + """ + Nunchaku multi-dimensional rotary embedding module for FLUX. + Adapted from https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/transformers/transformer_flux.py#L55 + + Parameters + ---------- + dim : int + Embedding dimension. + theta : int + Rotary base. + axes_dim : list of int + Dimension for each spatial axis. 
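+
+    Examples
+    --------
+    A minimal sketch with FLUX-style axis dimensions (illustrative values;
+    assumes diffusers >= 0.31, where ``ids`` carries no batch dimension)::
+
+        pos_embed = NunchakuFluxPosEmbed(dim=128, theta=10000, axes_dim=[16, 56, 56])
+        ids = torch.zeros(512, 3)  # (tokens, n_axes)
+        emb = pos_embed(ids)       # -> (1, 1, 512, 64, 1, 2), since (16 + 56 + 56) / 2 = 64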
+ """ + + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super(NunchakuFluxPosEmbed, self).__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + """ + Compute rotary embeddings for multi-dimensional positions. + + Parameters + ---------- + ids : torch.Tensor, shape (..., n_axes), dtype int + Position indices. + + Returns + ------- + out : torch.Tensor, shape (B, 1, ...), dtype float32 + Rotary embedding tensor. + + Notes + ----- + - B: batch size + - n_axes: number of spatial axes + """ + if Version(diffusers.__version__) >= Version("0.31.0"): + ids = ids[None, ...] + n_axes = ids.shape[-1] + emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) + return emb.unsqueeze(1) + + +def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor: + """ + Pack rotary embeddings for efficient CUDA computation. + + Parameters + ---------- + rotemb : torch.Tensor, shape (B, M, D//2, 1, 2), dtype float32 + Rotary embedding tensor. + + Returns + ------- + packed : torch.Tensor, shape (B, M, D), dtype float32 + Packed rotary embedding tensor. + + Notes + ----- + - B: batch size + - M: sequence length (must be divisible by 16) + - D: embedding dimension (must be divisible by 8) + """ + assert rotemb.dtype == torch.float32 + B = rotemb.shape[0] + M = rotemb.shape[1] + D = rotemb.shape[2] * 2 + assert rotemb.shape == (B, M, D // 2, 1, 2) + assert M % 16 == 0 + assert D % 8 == 0 + rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8) + rotemb = rotemb.permute(0, 1, 3, 2, 4) + # 16*8 pack, FP32 accumulator (C) format + # https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c + ##########################################|--M--|--D--| + ##########################################|-3--4--5--6| + ########################################## : : : : + rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2) + rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6) + rotemb = rotemb.contiguous() + rotemb = rotemb.view(B, M, D) + return rotemb diff --git a/nunchaku/models/linear.py b/nunchaku/models/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3b6cd51aeb98f10a3034e761eb14a27d8d81f9 --- /dev/null +++ b/nunchaku/models/linear.py @@ -0,0 +1,414 @@ +""" +Quantized linear layers for Nunchaku. +""" + +import torch +from torch import nn + +from ..ops.gemm import svdq_gemm_w4a4_cuda +from ..ops.gemv import awq_gemv_w4a16_cuda +from ..ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda + + +class SVDQW4A4Linear(nn.Module): + """ + `SVDQuant `_ W4A4 quantized linear layer. + + Parameters + ---------- + in_features : int + Input feature dimension. + out_features : int + Output feature dimension. + rank : int, optional + SVD low-rank dimension. Default is 32. + bias : bool, optional + If True, adds a learnable bias. Default is True. + precision : {'int4', 'nvfp4'}, optional + Quantization precision data type ('int4' or 'nvfp4'). Default is 'int4'. + act_unsigned : bool, optional + If True, use unsigned activation quantization (int4 only). Default is False. + torch_dtype : torch.dtype, optional + Parameter dtype. Default is torch.bfloat16. + device : str or torch.device or None, optional + Device for parameters. Default is CPU. + + Attributes + ---------- + in_features : int + out_features : int + rank : int + precision : str + 'int4' or 'nvfp4'. + group_size : int + 64 for int4, 16 for nvfp4. 
+ qweight : nn.Parameter + Packed quantized weights, shape (out_features, in_features // 2), dtype int8. + bias : nn.Parameter or None + Bias tensor. + wscales : nn.Parameter + Weight scales, shape (in_features // group_size, out_features). + Dtype: bfloat16/float16 (int4), float8_e4m3fn (nvfp4). + smooth_factor : nn.Parameter + Smoothing factors, shape (in_features,). + smooth_factor_orig : nn.Parameter + Original smoothing factors, shape (in_features,). (Unused) + proj_down : nn.Parameter + Packed low-rank down projection, shape (in_features, rank), dtype bfloat16/float16. + proj_up : nn.Parameter + Packed low-rank up projection, shape (out_features, rank), dtype bfloat16/float16. + wtscale : float or None + Global weight scale (nvfp4 only). + wcscales : nn.Parameter or None + Channel-wise weight scale (nvfp4 only), shape (out_features,), dtype float8_e4m3fn. + act_unsigned : bool + If True, input activations are unsigned (int4 only). + """ + + def __init__( + self, + in_features: int, + out_features: int, + rank: int = 32, + bias: bool = True, + precision: str = "int4", + act_unsigned: bool = False, + torch_dtype: torch.dtype = torch.bfloat16, + device: str | torch.device | None = None, + ): + super(SVDQW4A4Linear, self).__init__() + if device is None: + device = torch.device("cpu") + self.in_features = in_features + self.out_features = out_features + self.rank = rank + + self.precision = precision + self.torch_dtype = torch_dtype + + if precision == "nvfp4": + self.group_size = 16 + elif precision == "int4": + self.group_size = 64 + else: + raise ValueError(f"Invalid precision: {precision}") + + self.qweight = nn.Parameter( + torch.empty(out_features, in_features // 2, dtype=torch.int8, device=device), requires_grad=False + ) + self.bias = ( + nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=True) + if bias + else None + ) + + self.wscales = nn.Parameter( + torch.empty( + in_features // self.group_size, + out_features, + dtype=torch_dtype if precision == "int4" else torch.float8_e4m3fn, + device=device, + ), + requires_grad=False, + ) + self.smooth_factor = nn.Parameter( + torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False + ) + self.smooth_factor_orig = nn.Parameter( + torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False + ) + + self.proj_down = nn.Parameter(torch.empty(in_features, rank, dtype=torch_dtype, device=device)) + self.proj_up = nn.Parameter(torch.empty(out_features, rank, dtype=torch_dtype, device=device)) + + if precision == "nvfp4": + self.wcscales = nn.Parameter( + torch.ones(out_features, dtype=torch_dtype, device=device), requires_grad=False + ) + self.wtscale = 1.0 + else: + self.wtscale = None + self.wcscales = None + + self.act_unsigned = act_unsigned + + @classmethod + def from_linear(cls, linear: nn.Linear, **kwargs): + """ + Create an SVDQW4A4Linear from a standard nn.Linear. The weight and bias are dummy tensors. + + Parameters + ---------- + linear : nn.Linear + Source linear layer. + **kwargs + Additional init arguments. + + Returns + ------- + SVDQW4A4Linear + """ + in_features = kwargs.pop("in_features", linear.in_features) + return cls( + in_features=in_features, + out_features=linear.out_features, + bias=linear.bias is not None, + torch_dtype=linear.weight.dtype, + device=linear.weight.device, + **kwargs, + ) + + def forward(self, x: torch.Tensor, output: torch.Tensor | None = None) -> torch.Tensor: + """ + Forward pass with 16-bit input. 
It will call :meth:`quantize` and :meth:`forward_quant`. + + Parameters + ---------- + x : torch.Tensor, shape (B, S, in_features), dtype float16 or bfloat16 + Input tensor. + output : torch.Tensor or None, optional + Optional output buffer. + + Returns + ------- + torch.Tensor, shape (B, S, out_features) + Output tensor. + + Notes + ----- + B: batch size, S: sequence length + """ + batch_size, seq_len, channels = x.shape + x = x.reshape(batch_size * seq_len, channels) + if output is None: + output = torch.empty(batch_size * seq_len, self.out_features, dtype=x.dtype, device=x.device) + quantized_x, ascales, lora_act_out = self.quantize(x) + output = self.forward_quant(quantized_x, ascales, lora_act_out, output) + output = output.reshape(batch_size, seq_len, -1) + return output + + def quantize(self, x: torch.Tensor, pad_size: int = 256) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Quantize input to 4-bit and compute low-rank hidden states. It will call :func:`~nunchaku.ops.quantize.svdq_quantize_w4a4_act_fuse_lora_cuda`. + + Parameters + ---------- + x : torch.Tensor, shape (N, in_features), dtype float16 or bfloat16 + Input tensor. + pad_size : int, optional + Batch padding size. Default is 256. + + Returns + ------- + quantized_x : torch.Tensor + Quantized input, shape (pad_size * ceil(N / pad_size), in_features // 2), dtype uint8. + ascales : torch.Tensor + Activation scales, shape (in_features // group_size,), dtype float8_e4m3fn for nvfp4 and input dtype for int4. + lora_act_out : torch.Tensor + Low-rank hidden states, shape (pad_size * ceil(N / pad_size), rank), dtype float32. + + Notes + ----- + N: batch size + """ + quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda( + x, lora_down=self.proj_down, smooth=self.smooth_factor, fp4=self.precision == "nvfp4", pad_size=pad_size + ) + return quantized_x, ascales, lora_act_out + + def forward_quant( + self, + quantized_x: torch.Tensor, + ascales: torch.Tensor, + lora_act: torch.Tensor, + output: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Forward pass with pre-quantized input. It will call :func:`~nunchaku.ops.gemm.svdq_gemm_w4a4_cuda`. + + Parameters + ---------- + quantized_x : torch.Tensor + Quantized input, shape (N, in_features // 2), dtype uint8. + ascales : torch.Tensor + Activation scales, shape (in_features // group_size,), dtype float8_e4m3fn for nvfp4 and input dtype for int4. + lora_act : torch.Tensor + Low-rank hidden states, shape (N, rank), dtype float32. + output : torch.Tensor or None, optional + Optional output buffer. + + Returns + ------- + torch.Tensor + Output tensor, shape (N, out_features), dtype bfloat16/float16 for int4 and float8_e4m3fn for nvfp4. + + Notes + ----- + N: batch size + """ + if output is None: + output = torch.empty( + quantized_x.shape[0], self.out_features, dtype=self.proj_up.dtype, device=quantized_x.device + ) + + svdq_gemm_w4a4_cuda( + act=quantized_x, + wgt=self.qweight, + out=output, + ascales=ascales, + wscales=self.wscales, + lora_act_in=lora_act, + lora_up=self.proj_up, + bias=self.bias, + fp4=self.precision == "nvfp4", + alpha=self.wtscale, + wcscales=self.wcscales, + act_unsigned=self.act_unsigned, + ) + return output + + def __repr__(self): + return ( + f"SVDQW4A4Linear(in_features={self.in_features}, out_features={self.out_features}, " + f"rank={self.rank}, precision={self.precision}, act_unsigned={self.act_unsigned})" + ) + + +class AWQW4A16Linear(nn.Module): + """ + `AWQ `_ W4A16 quantized linear layer. 
+ + Parameters + ---------- + in_features : int + Input feature dimension. + out_features : int + Output feature dimension. + bias : bool, optional + If True, adds learnable bias. Default is True. + group_size : int, optional + Quantization group size. Default is 64. + torch_dtype : torch.dtype, optional + Parameter dtype. Default is torch.bfloat16. + device : str or torch.device or None, optional + Device for parameters. Default is CPU. + + Attributes + ---------- + in_features : int + out_features : int + group_size : int + qweight : nn.Parameter + Packed quantized weights, shape (out_features // 4, in_features // 2), dtype int32. + bias : nn.Parameter or None + Bias tensor. + wscales : nn.Parameter + Weight scales, shape (in_features // group_size, out_features), dtype float16 or bfloat16. + wzeros : nn.Parameter + Weight zero points, shape (in_features // group_size, out_features), dtype float16 or bfloat16. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + group_size: int = 64, + torch_dtype: torch.dtype = torch.bfloat16, + device: str | torch.device | None = None, + ): + super(AWQW4A16Linear, self).__init__() + if device is None: + device = torch.device("cpu") + self.in_features = in_features + self.out_features = out_features + self.group_size = group_size + + self.qweight = nn.Parameter( + torch.empty(out_features // 4, in_features // 2, dtype=torch.int32, device=device), requires_grad=False + ) + self.bias = ( + nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=True) + if bias + else None + ) + self.wscales = nn.Parameter( + torch.empty(in_features // self.group_size, out_features, dtype=torch_dtype, device=device), + requires_grad=False, + ) + self.wzeros = nn.Parameter( + torch.empty(in_features // self.group_size, out_features, dtype=torch_dtype, device=device), + requires_grad=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for AWQW4A16Linear. + + Parameters + ---------- + x : torch.Tensor, shape (N, in_features) + Input tensor. + + Returns + ------- + torch.Tensor, shape (N, out_features) + Output tensor. + + Notes + ----- + N: batch size + """ + output = awq_gemv_w4a16_cuda( + in_feats=x, + kernel=self.qweight, + scaling_factors=self.wscales, + zeros=self.wzeros, + m=x.shape[0], + n=self.out_features, + k=self.in_features, + group_size=self.group_size, + ) + if self.bias is not None: + view_shape = [1] * (output.ndim - 1) + [-1] + output.add_(self.bias.view(view_shape)) + return output + + @classmethod + def from_linear( + cls, + linear: nn.Linear, + group_size: int = 64, + torch_dtype: torch.dtype = torch.bfloat16, + device: str = "cpu", + **kwargs, + ): + """ + Create an uninitialized AWQW4A16Linear from a standard nn.Linear. + + Parameters + ---------- + linear : nn.Linear + Source linear layer. + group_size : int, optional + Quantization group size. + torch_dtype : torch.dtype, optional + Parameter dtype. + device : str, optional + Device for parameters. 
+ + Returns + ------- + AWQW4A16Linear + """ + return cls( + in_features=linear.in_features, + out_features=linear.out_features, + bias=linear.bias is not None, + group_size=group_size, + torch_dtype=torch_dtype, + device=device, + ) + + def __repr__(self): + return f"AWQW4A16Linear(in_features={self.in_features}, out_features={self.out_features}, group_size={self.group_size})" diff --git a/nunchaku/models/normalization.py b/nunchaku/models/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..eede8c73bbd1291137ab931351b3ed8d31ba6b29 --- /dev/null +++ b/nunchaku/models/normalization.py @@ -0,0 +1,166 @@ +""" +Quantized normalization layers for efficient inference. +""" + +from typing import Optional, Tuple + +import torch +from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormZeroSingle + +from .linear import AWQW4A16Linear + + +class NunchakuAdaLayerNormZero(AdaLayerNormZero): + """ + Nunchaku quantized AdaLayerNormZero for diffusion models. + + Replaces the linear projection with AWQW4A16Linear for quantized inference. + + Parameters + ---------- + other : AdaLayerNormZero + Source AdaLayerNormZero instance to copy weights and structure from. + scale_shift : float, optional + Value to add to scale parameters. Default is 1.0. + Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0. + + Notes + ----- + - B: batch size + - D: hidden dimension + """ + + def __init__(self, other: AdaLayerNormZero, scale_shift: float = 1.0): + super(AdaLayerNormZero, self).__init__() + self.scale_shift = scale_shift + self.emb = other.emb + self.silu = other.silu + self.linear = AWQW4A16Linear.from_linear(other.linear) + self.norm = other.norm + + def forward( + self, + x: torch.Tensor, + timestep: Optional[torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + hidden_dtype: Optional[torch.dtype] = None, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass for quantized AdaLayerNormZero. + + Parameters + ---------- + x : torch.Tensor, shape (B, D), dtype float32/float16 + Input tensor. + timestep : Optional[torch.Tensor], shape (B,) or (1,), optional + Timestep embedding input. + class_labels : Optional[torch.LongTensor], shape (B,) or (1,), optional + Class label input. + hidden_dtype : Optional[torch.dtype], optional + Dtype for embedding computation. + emb : Optional[torch.Tensor], shape (B, E), optional + Precomputed embedding. If None, computed from timestep and class_labels. + + Returns + ------- + norm_x_scaled : torch.Tensor, shape (B, D) + Normalized and scaled input. + gate_msa : torch.Tensor, shape (B, D) + Gate for MSA branch. + shift_mlp : torch.Tensor, shape (B, D) + Shift for MLP branch. + scale_mlp : torch.Tensor, shape (B, D) + Scale for MLP branch. + gate_mlp : torch.Tensor, shape (B, D) + Gate for MLP branch. + + Notes + ----- + - B: batch size + - D: hidden dimension + """ + if self.emb is not None: + emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) + emb = self.linear(self.silu(emb)) + + # The weight layout has changed; use split_mod rather than chunk to separate the embedding. 
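+        # the quantized projection emits the six modulation tensors interleaved
+        # along the channel axis (channel i of tensor j sits at flat position
+        # i * 6 + j), so view(B, -1, 6) + permute un-interleaves them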
+ emb = emb.view(emb.shape[0], -1, 6).permute(2, 0, 1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb + + norm_x = self.norm(x) + + if self.scale_shift != 0: + scale_msa.add_(self.scale_shift) + scale_mlp.add_(self.scale_shift) + + norm_x_scaled = norm_x * scale_msa[:, None] + shift_msa[:, None] + return norm_x_scaled, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class NunchakuAdaLayerNormZeroSingle(AdaLayerNormZeroSingle): + """ + Nunchaku quantized AdaLayerNormZeroSingle. + + Uses AWQW4A16Linear for quantized embedding projection. Suitable for single-branch normalization. + + Parameters + ---------- + other : AdaLayerNormZeroSingle + Source AdaLayerNormZeroSingle instance to copy weights and structure from. + scale_shift : float, optional + Value to add to scale parameters. Default is 1.0. + Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0. + + Notes + ----- + - B: batch size + - D: hidden dimension + """ + + def __init__(self, other: AdaLayerNormZeroSingle, scale_shift: float = 1.0): + super(AdaLayerNormZeroSingle, self).__init__() + self.scale_shift = scale_shift + self.silu = other.silu + self.linear = AWQW4A16Linear.from_linear(other.linear) + self.norm = other.norm + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for quantized AdaLayerNormZeroSingle. + + Parameters + ---------- + x : torch.Tensor, shape (B, D), dtype float32/float16 + Input tensor. + emb : Optional[torch.Tensor], shape (B, E), optional + Embedding tensor. + + Returns + ------- + norm_x_scaled : torch.Tensor, shape (B, D) + Normalized and scaled input. + gate_msa : torch.Tensor, shape (B, D) + Gate for MSA branch. + + Notes + ----- + - B: batch size + - D: hidden dimension + """ + emb = self.linear(self.silu(emb)) + + # The weight layout has changed; use split_mod rather than chunk to separate the embedding. + emb = emb.view(emb.shape[0], -1, 3).permute(2, 0, 1) + shift_msa, scale_msa, gate_msa = emb + + if self.scale_shift != 0: + scale_msa.add_(self.scale_shift) + + norm_x = self.norm(x) + norm_x_scaled = norm_x * scale_msa[:, None] + shift_msa[:, None] + return norm_x_scaled, gate_msa diff --git a/nunchaku/models/text_encoders/__init__.py b/nunchaku/models/text_encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b386b5e049d48e9ab5e9d4a9f71ec9c60c363004 --- /dev/null +++ b/nunchaku/models/text_encoders/__init__.py @@ -0,0 +1,5 @@ +from .t5_encoder import NunchakuT5EncoderModel + +__all__ = [ + "NunchakuT5EncoderModel", +] diff --git a/nunchaku/models/text_encoders/linear.py b/nunchaku/models/text_encoders/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..a878c13c634ec67d3e9fb57ab67ee121df238c20 --- /dev/null +++ b/nunchaku/models/text_encoders/linear.py @@ -0,0 +1,238 @@ +# -*- coding: utf-8 -*- +""" +This module provides the :class:`W4Linear` quantized linear layer, which implements +4-bit weight-only quantization for efficient inference. +""" + +import torch +import torch.nn as nn + +from ..._C.ops import gemm_awq, gemv_awq +from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight + +__all__ = ["W4Linear"] + + +class W4Linear(nn.Module): + """ + 4-bit quantized linear layer with group-wise quantization. + + Parameters + ---------- + in_features : int + Number of input features. + out_features : int + Number of output features. 
+ bias : bool, optional + If True, adds a learnable bias (default: False). + group_size : int, optional + Number of input channels per quantization group (default: 128). + If -1, uses the full input dimension as a single group. + dtype : torch.dtype, optional + Data type for quantization scales and zeros (default: torch.float16). + device : str or torch.device, optional + Device for weights and buffers (default: "cuda"). + + Attributes + ---------- + in_features : int + out_features : int + group_size : int + qweight : torch.Tensor + Quantized weight tensor (int16). + scales : torch.Tensor + Per-group scale tensor. + scaled_zeros : torch.Tensor + Per-group zero-point tensor (scaled). + bias : torch.Tensor or None + Optional bias tensor. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + group_size: int = 128, + dtype: torch.dtype = torch.float16, + device: str | torch.device = "cuda", + ): + super().__init__() + assert dtype in (torch.float16, torch.bfloat16), f"Unsupported dtype: {dtype}" + + self.in_features = in_features + self.out_features = out_features + self.group_size = group_size if group_size != -1 else in_features + assert self.in_features % self.group_size == 0 + assert out_features % (32 // self.weight_bits) == 0 + self.ceil_num_groups = ceil_num_groups( + in_features=self.in_features, + group_size=self.group_size, + weight_bits=self.weight_bits, + ) + + assert out_features % (self.interleave) == 0 + self.register_buffer( + "qweight", + torch.zeros( + ( + self.out_features // self.interleave, + self.in_features // (16 // self.weight_bits) * self.interleave, + ), + dtype=torch.int16, + device=device, + ), + ) + self.register_buffer( + "scales", + torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device), + ) + self.register_buffer( + "scaled_zeros", + torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device), + ) + if bias: + self.register_buffer("bias", torch.zeros((out_features), dtype=dtype, device=device)) + else: + self.bias = None + + @property + def weight_bits(self) -> int: + """ + Number of bits per quantized weight (always 4). + """ + return 4 + + @property + def interleave(self) -> int: + """ + Interleave factor for quantized weights (always 4). + """ + return 4 + + @torch.no_grad() + def forward(self, x): + """ + Forward pass. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (..., in_features). + + Returns + ------- + torch.Tensor + Output tensor of shape (..., out_features). + """ + if x.numel() / x.shape[-1] < 8: + out = gemv_awq( + x, + self.qweight, + self.scales, + self.scaled_zeros, + x.numel() // x.shape[-1], + self.out_features, + self.in_features, + self.group_size, + ) + else: + if self.group_size != 128: + raise NotImplementedError("Kernel currently only supports group_size=128.") + out = gemm_awq(x, self.qweight, self.scales, self.scaled_zeros) + out = out + self.bias if self.bias is not None else out + return out + + @staticmethod + def from_linear( + linear: nn.Linear, + group_size: int, + init_only: bool = False, + weight: torch.Tensor | None = None, + scale: torch.Tensor | None = None, + zero: torch.Tensor | None = None, + zero_pre_scaled: bool = False, + ) -> "W4Linear": + """ + Convert a standard nn.Linear to a quantized W4Linear. + + Parameters + ---------- + linear : nn.Linear + The linear layer to convert. + group_size : int + Quantization group size. 
+ init_only : bool, optional + If True, only initializes the quantized layer (default: False). + weight : torch.Tensor, optional + Precomputed quantized weight (default: None). + scale : torch.Tensor, optional + Precomputed scale tensor (default: None). + zero : torch.Tensor, optional + Precomputed zero-point tensor (default: None). + zero_pre_scaled : bool, optional + Whether the zero-point tensor is pre-scaled (default: False). + + Returns + ------- + W4Linear + Quantized linear layer. + """ + assert isinstance(linear, nn.Linear) + weight = linear.weight.data if weight is None else weight.data + dtype, device = weight.dtype, weight.device + oc, ic = linear.out_features, linear.in_features + _linear = W4Linear( + in_features=ic, + out_features=oc, + bias=linear.bias is not None, + group_size=group_size, + dtype=dtype, + device=device, + ) + if init_only: + return _linear + if linear.bias is not None: + _linear.bias.data.copy_(linear.bias.data) + if scale is None: + assert zero is None, "scale and zero point tensors should be provided together." + group_size = ic if group_size <= 0 else group_size + assert group_size <= ic, "group size should be less than or equal to input channel size." + assert ic % group_size == 0, "input channel size should be divisible by group size." + ng, gs = ic // group_size, group_size + weight = weight.to(dtype=torch.float32).view(oc, 1, ng, gs) + vmin, vmax = weight.amin(dim=-1, keepdim=True), weight.amax(dim=-1, keepdim=True) + scale = (vmax - vmin).div_(15) + scale[scale == 0] = 1.0 + if zero_pre_scaled: + zero = vmin.neg_().div_(scale).round_().clamp_(0, 15) + weight = weight.div_(scale).add_(zero).round_().clamp_(0, 15).sub_(zero).mul_(scale) + else: + zero = vmin.neg_().clamp_min(0) + weight = weight.add_(zero).div_(scale).round_().clamp_(0, 15).mul_(scale).sub_(zero) + weight = weight.to(dtype=dtype).view(oc, ic) + scale = scale.to(dtype=dtype) + zero = zero.to(dtype=dtype) + weight, scale, zero = convert_to_tinychat_w4x16y16_linear_weight( + weight=weight, + scale=scale, + zero=zero, + group_size=group_size, + zero_pre_scaled=zero_pre_scaled, + ) + _linear.qweight.data.copy_(weight) + _linear.scales.data.copy_(scale) + _linear.scaled_zeros.data.copy_(zero) + return _linear + + def extra_repr(self) -> str: + """ + Returns a string describing the layer configuration. + """ + return "in_features={}, out_features={}, bias={}, weight_bits={}, group_size={}".format( + self.in_features, + self.out_features, + self.bias is not None, + self.weight_bits, + self.group_size, + ) diff --git a/nunchaku/models/text_encoders/t5_encoder.py b/nunchaku/models/text_encoders/t5_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6223437978d92948a5355af3b525b0d13e0e1513 --- /dev/null +++ b/nunchaku/models/text_encoders/t5_encoder.py @@ -0,0 +1,116 @@ +""" +The NunchakuT5EncoderModel class enables loading T5 encoder weights from safetensors files, +automatically replacing supported linear layers with quantized :class:`~nunchaku.models.text_encoders.linear.W4Linear` +modules for improved performance and memory efficiency. 
+""" + +import json +import logging +import os +from pathlib import Path + +import torch +from accelerate import init_empty_weights +from torch import nn +from transformers import T5Config, T5EncoderModel + +from ...utils import load_state_dict_in_safetensors +from .linear import W4Linear + +# Get log level from environment variable (default to INFO) +log_level = os.getenv("LOG_LEVEL", "INFO").upper() + +# Configure logging +logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class NunchakuT5EncoderModel(T5EncoderModel): + """ + Nunchaku T5 Encoder Model + + Extends :class:`transformers.T5EncoderModel` to support quantized weights and + memory-efficient inference using :class:`~nunchaku.models.text_encoders.linear.W4Linear`. + + This class provides a convenient interface for loading T5 encoder weights from + safetensors files, automatically replacing supported linear layers with quantized + modules for improved speed and reduced memory usage. + + Example + ------- + .. code-block:: python + + model = NunchakuT5EncoderModel.from_pretrained( + "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors" + ) + """ + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs): + """ + Load a :class:`NunchakuT5EncoderModel` from a safetensors file. + + This method loads the model configuration and weights from a safetensors file, + initializes the model on the 'meta' device (no memory allocation for weights), + and replaces supported linear layers with quantized :class:`~nunchaku.models.text_encoders.linear.W4Linear` modules. + + Parameters + ---------- + pretrained_model_name_or_path : str or os.PathLike + Path to the safetensors file containing the model weights and metadata. + torch_dtype : torch.dtype, optional + Data type for model initialization (default: ``torch.bfloat16``). + Set to ``torch.float16`` for Turing GPUs. + device : str or torch.device, optional + Device to load the model onto (default: ``"cuda"``). + If the model is loaded on CPU, it will be automatically moved to GPU. + + Returns + ------- + NunchakuT5EncoderModel + The loaded and quantized T5 encoder model. + + Example + ------- + .. 
code-block:: python + + model = NunchakuT5EncoderModel.from_pretrained( + "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors" + ) + """ + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True) + + # Load the config file from metadata + config = json.loads(metadata["config"]) + config = T5Config(**config) + + # Initialize model on 'meta' device (no memory allocation for weights) + with init_empty_weights(): + t5_encoder = T5EncoderModel(config).to(kwargs.get("torch_dtype", torch.bfloat16)) + + t5_encoder.eval() + + # Load the model weights from the safetensors file and quantize supported linear layers + named_modules = {} + for name, module in t5_encoder.named_modules(): + assert isinstance(name, str) + if isinstance(module, nn.Linear): + if f"{name}.qweight" in state_dict: + logger.debug(f"Switching {name} to W4Linear") + qmodule = W4Linear.from_linear(module, group_size=128, init_only=True) + # modeling_t5.py: T5DenseGatedActDense needs dtype of weight + qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device) + + parent_name, child_name = name.rsplit(".", 1) + setattr(named_modules[parent_name], child_name, qmodule) + else: + named_modules[name] = module + + device = kwargs.get("device", "cuda") + if isinstance(device, str): + device = torch.device(device) + t5_encoder.to_empty(device=device) + t5_encoder.load_state_dict(state_dict, strict=True) + + return t5_encoder diff --git a/nunchaku/models/text_encoders/tinychat_utils.py b/nunchaku/models/text_encoders/tinychat_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..03cea464e7371cd3e966aefef3a9842e9ba2caca --- /dev/null +++ b/nunchaku/models/text_encoders/tinychat_utils.py @@ -0,0 +1,188 @@ +# -*- coding: utf-8 -*- +""" +This module provides utility functions for quantized linear layers in the TinyChat backend. +""" + +import torch + +__all__ = ["ceil_num_groups", "convert_to_tinychat_w4x16y16_linear_weight"] + + +def ceil_divide(x: int, divisor: int) -> int: + """ + Compute the ceiling of integer division. + + Parameters + ---------- + x : int + Dividend. + divisor : int + Divisor. + + Returns + ------- + int + The smallest integer greater than or equal to ``x / divisor``. + """ + return (x + divisor - 1) // divisor + + +def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int: + """ + Calculate the padded number of quantization groups for TinyChat quantization. + + This ensures the number of groups is compatible with TinyChat's packing and kernel requirements. + + Parameters + ---------- + in_features : int + Input channel size (number of input features). + group_size : int + Quantization group size. + weight_bits : int, optional + Number of bits per quantized weight (default: 4). + + Returns + ------- + int + The padded number of quantization groups. + + Raises + ------ + AssertionError + If ``in_features`` is not divisible by ``group_size``, or if ``weight_bits`` is not 4, 2, or 1. + NotImplementedError + If ``group_size`` is not one of the supported values (>=128, 64, 32). + """ + assert in_features % group_size == 0, "input channel size should be divisible by group size." + num_groups = in_features // group_size + assert weight_bits in (4, 2, 1), "weight bits should be 4, 2, or 1." 
+ pack_size = 32 // weight_bits # one INT32 contains `pack_size` elements of weights + num_packs = ceil_divide(num_groups, pack_size) + if group_size >= 128: + num_packs_factor = 1 + elif group_size == 64: + num_packs_factor = 2 + elif group_size == 32: + num_packs_factor = 4 + else: + raise NotImplementedError("Unsupported group size for TinyChat quantization.") + # make sure num_packs is a multiple of num_packs_factor + num_packs = ceil_divide(num_packs, num_packs_factor) * num_packs_factor + num_groups = num_packs * pack_size + return num_groups + + +def pack_w4(weight: torch.Tensor) -> torch.Tensor: + """ + Pack quantized 4-bit weights into TinyChat's int16 format. + + This function rearranges and packs 4-bit quantized weights (stored as int32) into + the format expected by TinyChat CUDA kernels. + + Parameters + ---------- + weight : torch.Tensor + Quantized weight tensor of shape (out_features, in_features), dtype int32. + The input channel dimension must be divisible by 32. + + Returns + ------- + torch.Tensor + Packed weight tensor of dtype int16. + + Raises + ------ + AssertionError + If input tensor is not int32 or input channel size is not divisible by 32. + """ + assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}." + oc, ic = weight.shape + assert ic % 32 == 0, "input channel size should be divisible by 32." + # [0, 1, ..., 31] -> [0, 8, 16, 24, 1, 9, 17, 25, ..., 7, 15, 23, 31] + weight = weight.view(-1, 4, 8) + weight = weight[:, 0] | (weight[:, 1] << 4) | (weight[:, 2] << 8) | (weight[:, 3] << 12) + weight = weight.view(oc // 4, 4, ic // 64, 16).permute(0, 2, 1, 3).reshape(oc // 4, ic) + return weight.to(torch.int16) + + +def convert_to_tinychat_w4x16y16_linear_weight( + weight: torch.Tensor, + scale: torch.Tensor, + zero: torch.Tensor, + group_size: int = -1, + zero_pre_scaled: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Convert a floating-point weight tensor to TinyChat W4-X16-Y16 quantized linear format. + + This function quantizes the input weights to 4 bits per value, applies group-wise + scaling and zero-point, and packs the result into the format expected by TinyChat + quantized linear layers. + + Parameters + ---------- + weight : torch.Tensor + Floating-point weight tensor of shape (out_features, in_features). + Must be of dtype ``torch.float16`` or ``torch.bfloat16``. + scale : torch.Tensor + Per-group scale tensor (can be broadcastable). + zero : torch.Tensor + Per-group zero-point tensor (can be broadcastable). + group_size : int, optional + Quantization group size. If set to -1 (default), uses the full input dimension as a single group. + zero_pre_scaled : bool, optional + If True, the zero tensor is already scaled by the scale tensor (default: False). + + Returns + ------- + tuple of torch.Tensor + - packed_weight : torch.Tensor + Packed quantized weight tensor (int16). + - packed_scale : torch.Tensor + Packed scale tensor (shape: [num_groups, out_features], dtype matches input). + - packed_zero : torch.Tensor + Packed zero-point tensor (shape: [num_groups, out_features], dtype matches input). + + Raises + ------ + AssertionError + If input types or shapes are invalid, or quantized values are out of range. + + Example + ------- + .. 
code-block:: python + + qweight, qscale, qzero = convert_to_tinychat_w4x16y16_linear_weight( + weight, scale, zero, group_size=128 + ) + """ + dtype, device = weight.dtype, weight.device + assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16." + assert scale is not None, "scale tensor is required for quantization." + assert zero is not None, "zero point tensor is required for quantization." + weight = weight.to(dtype=torch.float32) + scale = scale.to(dtype=torch.float32, device=device) + zero = zero.to(dtype=torch.float32, device=device) + if zero_pre_scaled: + zero = zero * scale + oc, ic = weight.shape + group_size = ic if group_size <= 0 else group_size + assert group_size <= ic, "group size should be less than or equal to input channel size." + assert ic % group_size == 0, "input channel size should be divisible by group size." + ng = ic // group_size + if scale.numel() == 1: + scale = scale.view(1, 1).expand(oc, ng) + scale = scale.reshape(oc, ng).contiguous().view(oc, ng, 1) + if zero.numel() == 1: + zero = zero.view(1, 1).expand(oc, ng) + zero = zero.reshape(oc, ng).contiguous().view(oc, ng, 1) + weight = weight.view(oc, ng, -1).add_(zero).div_(scale).round_().view(oc, ic) + assert weight.min() >= 0 and weight.max() <= 15, "quantized weight should be in [0, 15]." + _weight = pack_w4(weight.to(torch.int32)) + _ng = ceil_num_groups(ic, group_size, weight_bits=4) + _scale = torch.zeros((_ng, oc), dtype=dtype, device=device) + _zero = torch.zeros((_ng, oc), dtype=dtype, device=device) + _scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype) + _zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_() + return _weight, _scale, _zero diff --git a/nunchaku/models/transformers/__init__.py b/nunchaku/models/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f150e7636481f00859568449c32037dd1b456d34 --- /dev/null +++ b/nunchaku/models/transformers/__init__.py @@ -0,0 +1,5 @@ +from .transformer_flux import NunchakuFluxTransformer2dModel + +__all__ = [ + "NunchakuFluxTransformer2dModel", +] diff --git a/nunchaku/models/transformers/transformer_flux.py b/nunchaku/models/transformers/transformer_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..149c18d49b0dc2bef01265f22a83d5afd347e83e --- /dev/null +++ b/nunchaku/models/transformers/transformer_flux.py @@ -0,0 +1,991 @@ +""" +Implements the :class:`NunchakuFluxTransformer2dModel`, a quantized transformer for Diffusers with efficient inference and LoRA support. 
+""" + +import json +import logging +import os +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import diffusers +import torch +from diffusers import FluxTransformer2DModel +from diffusers.configuration_utils import register_to_config +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from huggingface_hub import utils +from packaging.version import Version +from safetensors.torch import load_file +from torch import nn + +from ..._C import QuantizedFluxModel +from ..._C import utils as cutils +from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku +from ...lora.flux.utils import is_nunchaku_format +from ...utils import check_hardware_compatibility, get_precision, load_state_dict_in_safetensors, pad_tensor +from .utils import NunchakuModelLoaderMixin + +SVD_RANK = 32 + +# Get log level from environment variable (default to INFO) +log_level = os.getenv("LOG_LEVEL", "INFO").upper() + +# Configure logging +logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class NunchakuFluxTransformerBlocks(nn.Module): + """ + Wrapper for quantized Nunchaku FLUX transformer blocks. + + This class manages the forward pass, rotary embedding packing, and optional + residual callbacks for ID embeddings. + + Parameters + ---------- + m : QuantizedFluxModel + The quantized transformer model. + device : str or torch.device + Device to run the model on. + """ + + def __init__(self, m: QuantizedFluxModel, device: str | torch.device): + super(NunchakuFluxTransformerBlocks, self).__init__() + self.m = m + self.dtype = torch.bfloat16 if m.isBF16() else torch.float16 + self.device = device + + @staticmethod + def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor: + """ + Packs rotary embeddings for efficient computation. + + Parameters + ---------- + rotemb : torch.Tensor + Rotary embedding tensor of shape (B, M, D//2, 1, 2), dtype float32. + + Returns + ------- + torch.Tensor + Packed rotary embedding tensor of shape (B, M, D). + """ + assert rotemb.dtype == torch.float32 + B = rotemb.shape[0] + M = rotemb.shape[1] + D = rotemb.shape[2] * 2 + assert rotemb.shape == (B, M, D // 2, 1, 2) + assert M % 16 == 0 + assert D % 8 == 0 + rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8) + rotemb = rotemb.permute(0, 1, 3, 2, 4) + # 16*8 pack, FP32 accumulator (C) format + # https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c + ##########################################|--M--|--D--| + ##########################################|-3--4--5--6| + ########################################## : : : : + rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2) + rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6) + rotemb = rotemb.contiguous() + rotemb = rotemb.view(B, M, D) + return rotemb + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + encoder_hidden_states: torch.Tensor, + image_rotary_emb: torch.Tensor, + id_embeddings=None, + id_weight=None, + joint_attention_kwargs=None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + skip_first_layer=False, + ): + """ + Forward pass for the quantized transformer blocks. + It will call the forward method of ``m`` on the C backend. + + Parameters + ---------- + hidden_states : torch.Tensor + Input hidden states for image tokens. + temb : torch.Tensor + Temporal embedding tensor. + encoder_hidden_states : torch.Tensor + Input hidden states for text tokens. 
+ image_rotary_emb : torch.Tensor + Rotary embedding tensor for all tokens. + id_embeddings : torch.Tensor, optional + Optional ID embeddings for residual callback. + id_weight : float, optional + Weight for ID embedding residual. + joint_attention_kwargs : dict, optional + Additional kwargs for joint attention. + controlnet_block_samples : list[torch.Tensor], optional + ControlNet block samples. + controlnet_single_block_samples : list[torch.Tensor], optional + ControlNet single block samples. + skip_first_layer : bool, optional + Whether to skip the first layer. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + (encoder_hidden_states, hidden_states) after transformer blocks. + """ + # batch_size = hidden_states.shape[0] + txt_tokens = encoder_hidden_states.shape[1] + img_tokens = hidden_states.shape[1] + + self.id_embeddings = id_embeddings + self.id_weight = id_weight + self.pulid_ca_idx = 0 + if self.id_embeddings is not None: + self.set_pulid_residual_callback() + + original_dtype = hidden_states.dtype + original_device = hidden_states.device + + hidden_states = hidden_states.to(self.dtype).to(self.device) + encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device) + temb = temb.to(self.dtype).to(self.device) + image_rotary_emb = image_rotary_emb.to(self.device) + + if controlnet_block_samples is not None: + if len(controlnet_block_samples) > 0: + controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device) + else: + controlnet_block_samples = None + + if controlnet_single_block_samples is not None: + if len(controlnet_single_block_samples) > 0: + controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device) + else: + controlnet_single_block_samples = None + + assert image_rotary_emb.ndim == 6 + assert image_rotary_emb.shape[0] == 1 + assert image_rotary_emb.shape[1] == 1 + assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens) + # [1, tokens, head_dim / 2, 1, 2] (sincos) + image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]]) + rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype) + rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype) + rotary_emb_single = image_rotary_emb # .to(self.dtype) + + rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1)) + rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1)) + rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1)) + hidden_states = self.m.forward( + hidden_states, + encoder_hidden_states, + temb, + rotary_emb_img, + rotary_emb_txt, + rotary_emb_single, + controlnet_block_samples, + controlnet_single_block_samples, + skip_first_layer, + ) + + if self.id_embeddings is not None: + self.reset_pulid_residual_callback() + + hidden_states = hidden_states.to(original_dtype).to(original_device) + + encoder_hidden_states = hidden_states[:, :txt_tokens, ...] + hidden_states = hidden_states[:, txt_tokens:, ...] + + return encoder_hidden_states, hidden_states + + def forward_layer_at( + self, + idx: int, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: torch.Tensor, + joint_attention_kwargs=None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + ): + """ + Forward pass for a specific transformer layer in ``m``. + + Parameters + ---------- + idx : int + Index of the transformer layer. 
+ hidden_states : torch.Tensor + Input hidden states for image tokens. + encoder_hidden_states : torch.Tensor + Input hidden states for text tokens. + temb : torch.Tensor + Temporal embedding tensor. + image_rotary_emb : torch.Tensor + Rotary embedding tensor for all tokens. + joint_attention_kwargs : dict, optional + Additional kwargs for joint attention. + controlnet_block_samples : list[torch.Tensor], optional + ControlNet block samples. + controlnet_single_block_samples : list[torch.Tensor], optional + ControlNet single block samples. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + (encoder_hidden_states, hidden_states) after the specified layer. + """ + # batch_size = hidden_states.shape[0] + txt_tokens = encoder_hidden_states.shape[1] + img_tokens = hidden_states.shape[1] + + original_dtype = hidden_states.dtype + original_device = hidden_states.device + + hidden_states = hidden_states.to(self.dtype).to(self.device) + encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device) + temb = temb.to(self.dtype).to(self.device) + image_rotary_emb = image_rotary_emb.to(self.device) + + if controlnet_block_samples is not None: + controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device) + if controlnet_single_block_samples is not None: + controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device) + + assert image_rotary_emb.ndim == 6 + assert image_rotary_emb.shape[0] == 1 + assert image_rotary_emb.shape[1] == 1 + assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens) + # [1, tokens, head_dim / 2, 1, 2] (sincos) + image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]]) + rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype) + rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype) + + rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1)) + rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1)) + + hidden_states, encoder_hidden_states = self.m.forward_layer( + idx, + hidden_states, + encoder_hidden_states, + temb, + rotary_emb_img, + rotary_emb_txt, + controlnet_block_samples, + controlnet_single_block_samples, + ) + + hidden_states = hidden_states.to(original_dtype).to(original_device) + encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device) + + return encoder_hidden_states, hidden_states + + def set_pulid_residual_callback(self): + """ + Sets the residual callback for PulID (personalized ID) embeddings. + """ + id_embeddings = self.id_embeddings + pulid_ca = self.pulid_ca + pulid_ca_idx = [self.pulid_ca_idx] + id_weight = self.id_weight + + def callback(hidden_states): + ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states) + pulid_ca_idx[0] += 1 + return ip + + self.callback_holder = callback + self.m.set_residual_callback(callback) + + def reset_pulid_residual_callback(self): + """ + Resets the PulID residual callback to None. + """ + self.callback_holder = None + self.m.set_residual_callback(None) + + def __del__(self): + """ + Destructor to reset the quantized model. + """ + self.m.reset() + + def norm1( + self, + hidden_states: torch.Tensor, + emb: torch.Tensor, + idx: int = 0, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Runs the norm_one_forward for a specific layer in ``m``. + + Parameters + ---------- + hidden_states : torch.Tensor + Input hidden states. 
+ emb : torch.Tensor + Embedding tensor. + idx : int, optional + Layer index (default: 0). + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + Output tensors from norm_one_forward. + """ + return self.m.norm_one_forward(idx, hidden_states, emb) + + +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + """ + Rotary positional embedding function. + + Parameters + ---------- + pos : torch.Tensor + Position tensor of shape (..., n). + dim : int + Embedding dimension (must be even). + theta : int + Rotary base. + + Returns + ------- + torch.Tensor + Rotary embedding tensor. + """ + assert dim % 2 == 0, "The dimension must be even." + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + + USE_SINCOS = True + if USE_SINCOS: + cos_out = torch.cos(out) + sin_out = torch.sin(out) + stacked_out = torch.stack([sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 1, 2) + else: + out = out.view(batch_size, -1, dim // 2, 1, 1) + + return out.float() + + +class EmbedND(nn.Module): + """ + Multi-dimensional rotary embedding module. + + Parameters + ---------- + dim : int + Embedding dimension. + theta : int + Rotary base. + axes_dim : list[int] + List of axis dimensions for each spatial axis. + """ + + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super(EmbedND, self).__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + """ + Computes rotary embeddings for multi-dimensional positions. + + Parameters + ---------- + ids : torch.Tensor + Position indices tensor of shape (..., n_axes). + + Returns + ------- + torch.Tensor + Rotary embedding tensor. + """ + if Version(diffusers.__version__) >= Version("0.31.0"): + ids = ids[None, ...] + n_axes = ids.shape[-1] + emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) + return emb.unsqueeze(1) + + +def load_quantized_module( + path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor], + device: str | torch.device = "cuda", + use_fp4: bool = False, + offload: bool = False, + bf16: bool = True, +) -> QuantizedFluxModel: + """ + Loads a quantized Nunchaku FLUX model from a state dict or file. + + Parameters + ---------- + path_or_state_dict : str, os.PathLike, or dict + Path to the quantized model file or a state dict. + device : str or torch.device, optional + Device to load the model on (default: "cuda"). + use_fp4 : bool, optional + Whether to use FP4 quantization (default: False). + offload : bool, optional + Whether to offload weights to CPU (default: False). + bf16 : bool, optional + Whether to use bfloat16 (default: True). + + Returns + ------- + QuantizedFluxModel + Loaded quantized model. + """ + device = torch.device(device) + assert device.type == "cuda" + m = QuantizedFluxModel() + cutils.disable_memory_auto_release() + m.init(use_fp4, offload, bf16, 0 if device.index is None else device.index) + if isinstance(path_or_state_dict, dict): + m.loadDict(path_or_state_dict, True) + else: + m.load(str(path_or_state_dict)) + return m + + +class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin): + """ + Nunchaku FLUX Transformer 2D Model. 
+ + This class implements a quantized transformer model compatible with the Diffusers + library, supporting LoRA, rotary embeddings, and efficient inference. + + Parameters + ---------- + patch_size : int, optional + Patch size for input images (default: 1). + in_channels : int, optional + Number of input channels (default: 64). + out_channels : int or None, optional + Number of output channels (default: None). + num_layers : int, optional + Number of transformer layers (default: 19). + num_single_layers : int, optional + Number of single transformer layers (default: 38). + attention_head_dim : int, optional + Dimension of each attention head (default: 128). + num_attention_heads : int, optional + Number of attention heads (default: 24). + joint_attention_dim : int, optional + Joint attention dimension (default: 4096). + pooled_projection_dim : int, optional + Pooled projection dimension (default: 768). + guidance_embeds : bool, optional + Whether to use guidance embeddings (default: False). + axes_dims_rope : tuple[int], optional + Axes dimensions for rotary embeddings (default: (16, 56, 56)). + """ + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + out_channels: int | None = None, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: tuple[int] = (16, 56, 56), + ): + super(NunchakuFluxTransformer2dModel, self).__init__( + patch_size=patch_size, + in_channels=in_channels, + out_channels=out_channels, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + pooled_projection_dim=pooled_projection_dim, + guidance_embeds=guidance_embeds, + axes_dims_rope=axes_dims_rope, + ) + # these state_dicts are used for supporting lora + self._unquantized_part_sd: dict[str, torch.Tensor] = {} + self._unquantized_part_loras: dict[str, torch.Tensor] = {} + self._quantized_part_sd: dict[str, torch.Tensor] = {} + self._quantized_part_vectors: dict[str, torch.Tensor] = {} + self._original_in_channels = in_channels + + # ComfyUI LoRA related + self.comfy_lora_meta_list = [] + self.comfy_lora_sd_list = [] + + @classmethod + @utils.validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs): + """ + Loads a Nunchaku FLUX transformer model from pretrained weights. + + Parameters + ---------- + pretrained_model_name_or_path : str or os.PathLike + Path to the model directory or HuggingFace repo. + **kwargs + Additional keyword arguments for device, offload, torch_dtype, precision, etc. + + Returns + ------- + NunchakuFluxTransformer2dModel or (NunchakuFluxTransformer2dModel, dict) + The loaded model, and optionally metadata if `return_metadata=True`. 
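+
+        Examples
+        --------
+        A minimal sketch; the checkpoint path is illustrative, not a released file::
+
+            transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+                "path/to/svdq-int4_r32-model.safetensors",
+                device="cuda",
+                torch_dtype=torch.bfloat16,
+            )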
+ """ + device = kwargs.get("device", "cuda") + if isinstance(device, str): + device = torch.device(device) + offload = kwargs.get("offload", False) + torch_dtype = kwargs.get("torch_dtype", torch.bfloat16) + precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path) + metadata = None + + if isinstance(pretrained_model_name_or_path, str): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith( + (".safetensors", ".sft") + ): + transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs) + quantized_part_sd = {} + unquantized_part_sd = {} + for k, v in model_state_dict.items(): + if k.startswith(("transformer_blocks.", "single_transformer_blocks.")): + quantized_part_sd[k] = v + else: + unquantized_part_sd[k] = v + precision = get_precision(device=device) + quantization_config = json.loads(metadata["quantization_config"]) + check_hardware_compatibility(quantization_config, device) + else: + transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy( + pretrained_model_name_or_path, **kwargs + ) + + # get the default LoRA branch and all the vectors + quantized_part_sd = load_file(transformer_block_path) + unquantized_part_sd = load_file(unquantized_part_path) + new_quantized_part_sd = {} + for k, v in quantized_part_sd.items(): + if v.ndim == 1: + new_quantized_part_sd[k] = v + elif "qweight" in k: + # only the shape information of this tensor is needed + new_quantized_part_sd[k] = v.to("meta") + + # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors + for t in ["lora_up", "lora_down"]: + new_k = k.replace(".qweight", f".{t}") + if new_k not in quantized_part_sd: + oc, ic = v.shape + ic = ic * 2 # v is packed into INT8, so we need to double the size + new_quantized_part_sd[k.replace(".qweight", f".{t}")] = torch.zeros( + (0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16 + ) + + elif "lora" in k: + new_quantized_part_sd[k] = v + transformer._quantized_part_sd = new_quantized_part_sd + m = load_quantized_module( + quantized_part_sd, + device=device, + use_fp4=precision == "fp4", + offload=offload, + bf16=torch_dtype == torch.bfloat16, + ) + transformer.inject_quantized_module(m, device) + transformer.to_empty(device=device) + + transformer.load_state_dict(unquantized_part_sd, strict=False) + transformer._unquantized_part_sd = unquantized_part_sd + + if kwargs.get("return_metadata", False): + return transformer, metadata + else: + return transformer + + def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"): + """ + Injects a quantized module into the model and sets up transformer blocks. + + Parameters + ---------- + m : QuantizedFluxModel + The quantized transformer model. + device : str or torch.device, optional + Device to run the model on (default: "cuda"). + + Returns + ------- + self : NunchakuFluxTransformer2dModel + The model with injected quantized module. 
+ """ + print("Injecting quantized module") + self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56]) + + ### Compatible with the original forward method + self.transformer_blocks = nn.ModuleList([NunchakuFluxTransformerBlocks(m, device)]) + self.single_transformer_blocks = nn.ModuleList([]) + + return self + + def set_attention_impl(self, impl: str): + """ + Set the attention implementation for the quantized transformer block. + + Parameters + ---------- + impl : str + Attention implementation to use. Supported values: + + - ``"flashattn2"`` (default): Standard FlashAttention-2. + - ``"nunchaku-fp16"``: Uses FP16 attention accumulation, up to 1.2× faster than FlashAttention-2 on NVIDIA 30-, 40-, and 50-series GPUs. + """ + block = self.transformer_blocks[0] + assert isinstance(block, NunchakuFluxTransformerBlocks) + block.m.setAttentionImpl(impl) + + ### LoRA Related Functions + + def _expand_module(self, module_name: str, new_shape: tuple[int, int]): + """ + Expands a linear module to a new shape for LoRA compatibility. + Mostly for FLUX.1-tools LoRA which changes the input channels. + + Parameters + ---------- + module_name : str + Name of the module to expand. + new_shape : tuple[int, int] + New shape (out_features, in_features) for the module. + """ + module = self.get_submodule(module_name) + assert isinstance(module, nn.Linear) + weight_shape = module.weight.shape + logger.info("Expand the shape of module {} from {} to {}".format(module_name, tuple(weight_shape), new_shape)) + assert new_shape[0] >= weight_shape[0] and new_shape[1] >= weight_shape[1] + new_module = nn.Linear( + new_shape[1], + new_shape[0], + bias=module.bias is not None, + device=module.weight.device, + dtype=module.weight.dtype, + ) + new_module.weight.data.zero_() + new_module.weight.data[: weight_shape[0], : weight_shape[1]] = module.weight.data + self._unquantized_part_sd[f"{module_name}.weight"] = new_module.weight.data.clone() + if new_module.bias is not None: + new_module.bias.data.zero_() + new_module.bias.data[: weight_shape[0]] = module.bias.data + self._unquantized_part_sd[f"{module_name}.bias"] = new_module.bias.data.clone() + parent_name = ".".join(module_name.split(".")[:-1]) + parent_module = self.get_submodule(parent_name) + parent_module.add_module(module_name.split(".")[-1], new_module) + + if module_name == "x_embedder": + new_value = int(new_module.weight.data.shape[1]) + old_value = getattr(self.config, "in_channels") + if new_value != old_value: + logger.info(f"Update in_channels from {old_value} to {new_value}") + setattr(self.config, "in_channels", new_value) + + def _update_unquantized_part_lora_params(self, strength: float = 1): + """ + Updates the unquantized part of the model with LoRA parameters. + + Parameters + ---------- + strength : float, optional + LoRA scaling strength (default: 1). 
+ """ + # check if we need to expand the linear layers + device = next(self.parameters()).device + for k, v in self._unquantized_part_loras.items(): + if "lora_A" in k: + lora_a = v + lora_b = self._unquantized_part_loras[k.replace(".lora_A.", ".lora_B.")] + diff_shape = (lora_b.shape[0], lora_a.shape[1]) + weight_shape = self._unquantized_part_sd[k.replace(".lora_A.", ".")].shape + if diff_shape[0] > weight_shape[0] or diff_shape[1] > weight_shape[1]: + module_name = ".".join(k.split(".")[:-2]) + self._expand_module(module_name, diff_shape) + elif v.ndim == 1: + diff_shape = v.shape + weight_shape = self._unquantized_part_sd[k].shape + if diff_shape[0] > weight_shape[0]: + assert diff_shape[0] >= weight_shape[0] + module_name = ".".join(k.split(".")[:-1]) + module = self.get_submodule(module_name) + weight_shape = module.weight.shape + diff_shape = (diff_shape[0], weight_shape[1]) + self._expand_module(module_name, diff_shape) + new_state_dict = {} + for k in self._unquantized_part_sd.keys(): + v = self._unquantized_part_sd[k] + v = v.to(device) + self._unquantized_part_sd[k] = v + + if v.ndim == 1 and k in self._unquantized_part_loras: + diff = strength * self._unquantized_part_loras[k] + if diff.shape[0] < v.shape[0]: + diff = torch.cat( + [diff, torch.zeros(v.shape[0] - diff.shape[0], device=device, dtype=v.dtype)], dim=0 + ) + new_state_dict[k] = v + diff + elif v.ndim == 2 and k.replace(".weight", ".lora_B.weight") in self._unquantized_part_loras: + lora_a = self._unquantized_part_loras[k.replace(".weight", ".lora_A.weight")] + lora_b = self._unquantized_part_loras[k.replace(".weight", ".lora_B.weight")] + + if lora_a.shape[1] < v.shape[1]: + lora_a = torch.cat( + [ + lora_a, + torch.zeros(lora_a.shape[0], v.shape[1] - lora_a.shape[1], device=device, dtype=v.dtype), + ], + dim=1, + ) + if lora_b.shape[0] < v.shape[0]: + lora_b = torch.cat( + [ + lora_b, + torch.zeros(v.shape[0] - lora_b.shape[0], lora_b.shape[1], device=device, dtype=v.dtype), + ], + dim=0, + ) + + diff = strength * (lora_b @ lora_a) + new_state_dict[k] = v + diff + else: + new_state_dict[k] = v + self.load_state_dict(new_state_dict, strict=True) + + def update_lora_params(self, path_or_state_dict: str | dict[str, torch.Tensor]): + """ + Update the model with new LoRA parameters. + + Parameters + ---------- + path_or_state_dict : str or dict + Path to a LoRA weights file or a state dict. 
The path supports:
+
+            - Local file path, e.g., ``"/path/to/your/lora.safetensors"``
+            - HuggingFace repo with file, e.g., ``"user/repo/lora.safetensors"``
+              (automatically downloaded and cached)
+        """
+        if isinstance(path_or_state_dict, dict):
+            state_dict = dict(path_or_state_dict)  # shallow copy so the caller's dict is not modified
+        else:
+            state_dict = load_state_dict_in_safetensors(path_or_state_dict)
+
+        if not is_nunchaku_format(state_dict):
+            state_dict = to_nunchaku(state_dict, base_sd=self._quantized_part_sd)
+
+        # Split off the LoRA weights that belong to the unquantized part of the model.
+        device = next(self.parameters()).device
+        unquantized_part_loras = {}
+        for k in list(state_dict.keys()):
+            if "transformer_blocks" not in k:
+                unquantized_part_loras[k] = state_dict.pop(k).to(device)
+
+        if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
+            self._unquantized_part_loras = unquantized_part_loras
+
+            self._unquantized_part_sd = {k: v for k, v in self._unquantized_part_sd.items() if "pulid_ca" not in k}
+            self._update_unquantized_part_lora_params(1)
+
+        # Get the vectors from the quantized part and fuse them into the state dict.
+        quantized_part_vectors = {}
+        for k, v in list(state_dict.items()):
+            if v.ndim == 1:
+                quantized_part_vectors[k] = state_dict.pop(k)
+        if len(self._quantized_part_vectors) > 0 or len(quantized_part_vectors) > 0:
+            self._quantized_part_vectors = quantized_part_vectors
+            updated_vectors = fuse_vectors(quantized_part_vectors, self._quantized_part_sd, 1)
+            state_dict.update(updated_vectors)
+
+        block = self.transformer_blocks[0]
+        assert isinstance(block, NunchakuFluxTransformerBlocks)
+
+        block.m.loadDict(state_dict, True)
+
+    def set_lora_strength(self, strength: float = 1):
+        """
+        Sets the LoRA scaling strength for the model.
+
+        Parameters
+        ----------
+        strength : float, optional
+            LoRA scaling strength (default: 1).
+
+        Notes
+        -----
+        This changes the strength of every loaded LoRA, so only use it when a
+        single LoRA is loaded. For multiple LoRAs, fuse each LoRA's scale into
+        the weights instead.
+        """
+        block = self.transformer_blocks[0]
+        assert isinstance(block, NunchakuFluxTransformerBlocks)
+        block.m.setLoraScale(SVD_RANK, strength)
+        if len(self._unquantized_part_loras) > 0:
+            self._update_unquantized_part_lora_params(strength)
+        if len(self._quantized_part_vectors) > 0:
+            vector_dict = fuse_vectors(self._quantized_part_vectors, self._quantized_part_sd, strength)
+            block.m.loadDict(vector_dict, True)
+
+    def reset_x_embedder(self):
+        """
+        Resets the ``x_embedder`` module if the input channel count has changed.
+        This removes the effect of a FLUX.1-tools LoRA that expanded the input channels.
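+
+        Notes
+        -----
+        This is a no-op unless a LoRA previously expanded ``x_embedder``, i.e.,
+        unless ``config.in_channels`` differs from the original value.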
+ """ + # if change the model in channels, we need to update the x_embedder + if self._original_in_channels != self.config.in_channels: + assert self._original_in_channels < self.config.in_channels + old_module = self.x_embedder + new_module = nn.Linear( + in_features=self._original_in_channels, + out_features=old_module.out_features, + bias=old_module.bias is not None, + device=old_module.weight.device, + dtype=old_module.weight.dtype, + ) + new_module.weight.data.copy_(old_module.weight.data[: new_module.out_features, : new_module.in_features]) + self._unquantized_part_sd["x_embedder.weight"] = new_module.weight.data.clone() + if new_module.bias is not None: + new_module.bias.data.zero_() + new_module.bias.data.copy_(old_module.bias.data[: new_module.out_features]) + self._unquantized_part_sd["x_embedder.bias"] = new_module.bias.data.clone() + self.x_embedder = new_module + setattr(self.config, "in_channels", self._original_in_channels) + + def reset_lora(self): + """ + Resets all LoRA parameters to their default state. + """ + unquantized_part_loras = {} + if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0: + self._unquantized_part_loras = unquantized_part_loras + self._update_unquantized_part_lora_params(1) + state_dict = {k: v for k, v in self._quantized_part_sd.items() if "lora" in k} + quantized_part_vectors = {} + if len(self._quantized_part_vectors) > 0 or len(quantized_part_vectors) > 0: + self._quantized_part_vectors = quantized_part_vectors + updated_vectors = fuse_vectors(quantized_part_vectors, self._quantized_part_sd, 1) + state_dict.update(updated_vectors) + self.transformer_blocks[0].m.loadDict(state_dict, True) + self.reset_x_embedder() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + Forward pass for the Nunchaku FLUX transformer model. + + This method is compatible with the Diffusers pipeline and supports LoRA, + rotary embeddings, and ControlNet. + + Parameters + ---------- + hidden_states : torch.FloatTensor + Input hidden states of shape (batch_size, channel, height, width). + encoder_hidden_states : torch.FloatTensor, optional + Conditional embeddings (e.g., prompt embeddings) of shape (batch_size, sequence_len, embed_dims). + pooled_projections : torch.FloatTensor, optional + Embeddings projected from the input conditions. + timestep : torch.LongTensor, optional + Denoising step. + img_ids : torch.Tensor, optional + Image token indices. + txt_ids : torch.Tensor, optional + Text token indices. + guidance : torch.Tensor, optional + Guidance tensor for classifier-free guidance. + joint_attention_kwargs : dict, optional + Additional kwargs for joint attention. + controlnet_block_samples : list[torch.Tensor], optional + ControlNet block samples. + controlnet_single_block_samples : list[torch.Tensor], optional + ControlNet single block samples. + return_dict : bool, optional + Whether to return a Transformer2DModelOutput (default: True). + controlnet_blocks_repeat : bool, optional + Whether to repeat ControlNet blocks (default: False). 
+ + Returns + ------- + torch.FloatTensor or Transformer2DModelOutput + Output tensor or output object containing the sample. + """ + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + nunchaku_block = self.transformer_blocks[0] + encoder_hidden_states, hidden_states = nunchaku_block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + ) + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/nunchaku/models/transformers/transformer_flux_v2.py b/nunchaku/models/transformers/transformer_flux_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b2d9d052af855bdad2d88b5b2fff0375463da0 --- /dev/null +++ b/nunchaku/models/transformers/transformer_flux_v2.py @@ -0,0 +1,646 @@ +""" +This module provides Nunchaku FluxTransformer2DModel and its building blocks in Python. +""" + +import json +import os +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.transformers.transformer_flux import ( + FluxAttention, + FluxSingleTransformerBlock, + FluxTransformer2DModel, + FluxTransformerBlock, +) +from huggingface_hub import utils +from torch.nn import GELU + +from ...ops.fused import fused_gelu_mlp +from ...utils import get_precision, pad_tensor +from ..attention import NunchakuBaseAttention, NunchakuFeedForward +from ..attention_processors.flux import NunchakuFluxFA2Processor, NunchakuFluxFP16AttnProcessor +from ..embeddings import NunchakuFluxPosEmbed, pack_rotemb +from ..linear import SVDQW4A4Linear +from ..normalization import NunchakuAdaLayerNormZero, NunchakuAdaLayerNormZeroSingle +from ..utils import fuse_linears +from .utils import NunchakuModelLoaderMixin + + +class NunchakuFluxAttention(NunchakuBaseAttention): + """ + Nunchaku-optimized FluxAttention module with quantized and fused QKV projections. + + Parameters + ---------- + other : FluxAttention + The original FluxAttention module to wrap and quantize. + processor : str, optional + The attention processor to use ("flashattn2" or "nunchaku-fp16"). + **kwargs + Additional arguments for quantization. 
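+
+    Notes
+    -----
+    ``to_q``, ``to_k``, and ``to_v`` (and the added projections, when present)
+    are fused into a single :class:`~nunchaku.models.linear.SVDQW4A4Linear`
+    layer, so one quantized GEMM produces Q, K, and V together.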
+ """ + + def __init__(self, other: FluxAttention, processor: str = "flashattn2", **kwargs): + super(NunchakuFluxAttention, self).__init__(processor) + self.head_dim = other.head_dim + self.inner_dim = other.inner_dim + self.query_dim = other.query_dim + self.use_bias = other.use_bias + self.dropout = other.dropout + self.out_dim = other.out_dim + self.context_pre_only = other.context_pre_only + self.pre_only = other.pre_only + self.heads = other.heads + self.added_kv_proj_dim = other.added_kv_proj_dim + self.added_proj_bias = other.added_proj_bias + + self.norm_q = other.norm_q + self.norm_k = other.norm_k + + # Fuse the QKV projections for efficiency. + with torch.device("meta"): + to_qkv = fuse_linears([other.to_q, other.to_k, other.to_v]) + self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, **kwargs) + + if not self.pre_only: + self.to_out = other.to_out + self.to_out[0] = SVDQW4A4Linear.from_linear(self.to_out[0], **kwargs) + + if self.added_kv_proj_dim is not None: + self.norm_added_q = other.norm_added_q + self.norm_added_k = other.norm_added_k + + # Fuse the additional QKV projections. + with torch.device("meta"): + add_qkv_proj = fuse_linears([other.add_q_proj, other.add_k_proj, other.add_v_proj]) + self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, **kwargs) + self.to_add_out = SVDQW4A4Linear.from_linear(other.to_add_out, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor = None, + **kwargs, + ): + """ + Forward pass for NunchakuFluxAttention. + + Parameters + ---------- + hidden_states : torch.Tensor + Input tensor. + encoder_hidden_states : torch.Tensor, optional + Encoder hidden states for cross-attention. + attention_mask : torch.Tensor, optional + Attention mask. + image_rotary_emb : tuple or torch.Tensor, optional + Rotary embeddings for image/text tokens. + **kwargs + Additional arguments. + + Returns + ------- + Output of the attention processor. + """ + return self.processor( + attn=self, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + + def set_processor(self, processor: str): + """ + Set the attention processor. + + Parameters + ---------- + processor : str + Name of the processor ("flashattn2" or "nunchaku-fp16"). + + - ``"flashattn2"``: Standard FlashAttention-2. See :class:`~nunchaku.models.attention_processors.flux.NunchakuFluxFA2Processor`. + - ``"nunchaku-fp16"``: Uses FP16 attention accumulation, up to 1.2× faster than FlashAttention-2 on NVIDIA 30-, 40-, and 50-series GPUs. See :class:`~nunchaku.models.attention_processors.flux.NunchakuFluxFP16AttnProcessor`. + + Raises + ------ + ValueError + If the processor is not supported. + """ + if processor == "flashattn2": + self.processor = NunchakuFluxFA2Processor() + elif processor == "nunchaku-fp16": + self.processor = NunchakuFluxFP16AttnProcessor() + else: + raise ValueError(f"Processor {processor} is not supported") + + +class NunchakuFluxTransformerBlock(FluxTransformerBlock): + """ + Nunchaku-optimized FluxTransformerBlock with quantized attention and feedforward layers. + + Parameters + ---------- + block : FluxTransformerBlock + The original block to wrap and quantize. + scale_shift : float, optional + Value to add to scale parameters. Default is 1.0. 
+ Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0. + **kwargs + Additional arguments for quantization. + """ + + def __init__(self, block: FluxTransformerBlock, scale_shift: float = 1, **kwargs): + super(FluxTransformerBlock, self).__init__() + self.scale_shift = scale_shift + + # The scale_shift=1 from AdaLayerNormZero has already been fused into the linear weights, + # so we set scale_shift=0 here to avoid applying it again. + self.norm1 = NunchakuAdaLayerNormZero(block.norm1, scale_shift=scale_shift) + self.norm1_context = NunchakuAdaLayerNormZero(block.norm1_context, scale_shift=scale_shift) + + self.attn = NunchakuFluxAttention(block.attn, **kwargs) + self.norm2 = block.norm2 + self.norm2_context = block.norm2_context + self.ff = NunchakuFeedForward(block.ff, **kwargs) + self.ff_context = NunchakuFeedForward(block.ff_context, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Forward pass for the transformer block. + + Parameters + ---------- + hidden_states : torch.Tensor + Input hidden states. + encoder_hidden_states : torch.Tensor + Encoder hidden states for cross-attention. + temb : torch.Tensor + Time or conditioning embedding. + image_rotary_emb : tuple of torch.Tensor, optional + Rotary embeddings for image/text tokens. + joint_attention_kwargs : dict, optional + Additional attention arguments (not supported). + + Returns + ------- + tuple + (encoder_hidden_states, hidden_states) after block processing. + + Raises + ------ + NotImplementedError + If joint_attention_kwargs is provided. + """ + if joint_attention_kwargs is not None and len(joint_attention_kwargs) > 0: + raise NotImplementedError("joint_attention_kwargs is not supported") + + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * scale_mlp[:, None] + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. 
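+        # The text stream mirrors the image stream above: the adaLN-Zero gates
+        # (c_gate_msa / c_gate_mlp) scale each residual before it is added back.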
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * c_scale_mlp[:, None] + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class NunchakuFluxSingleTransformerBlock(FluxSingleTransformerBlock): + """ + Nunchaku-optimized single transformer block with quantized attention and MLP. + + Parameters + ---------- + block : FluxSingleTransformerBlock + The original block to wrap and quantize. + scale_shift : float, optional + Value to add to scale parameters. Default is 1.0. + Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0. + **kwargs + Additional arguments for quantization. + """ + + def __init__(self, block: FluxSingleTransformerBlock, scale_shift: float = 1, **kwargs): + super(FluxSingleTransformerBlock, self).__init__() + self.mlp_hidden_dim = block.mlp_hidden_dim + self.norm = block.norm + self.norm = NunchakuAdaLayerNormZeroSingle(block.norm, scale_shift=scale_shift) + + self.mlp_fc1 = SVDQW4A4Linear.from_linear(block.proj_mlp, **kwargs) + self.act_mlp = block.act_mlp + self.mlp_fc2 = SVDQW4A4Linear.from_linear(block.proj_out, in_features=self.mlp_hidden_dim, **kwargs) + # For int4, we shift the activation of mlp_fc2 to make it unsigned. + self.mlp_fc2.act_unsigned = self.mlp_fc2.precision != "nvfp4" + + self.attn = NunchakuFluxAttention(block.attn, **kwargs) + self.attn.to_out = SVDQW4A4Linear.from_linear(block.proj_out, in_features=self.mlp_fc1.in_features, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + """ + Forward pass for the single transformer block. + + Parameters + ---------- + hidden_states : torch.Tensor + Input hidden states. + temb : torch.Tensor + Time or conditioning embedding. + image_rotary_emb : tuple of torch.Tensor, optional + Rotary embeddings for tokens. + joint_attention_kwargs : dict, optional + Additional attention arguments. + + Returns + ------- + torch.Tensor + Output hidden states after block processing. + """ + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + + # Feedforward + if isinstance(self.act_mlp, GELU): + # Use fused GELU MLP for efficiency. + mlp_hidden_states = fused_gelu_mlp(norm_hidden_states, self.mlp_fc1, self.mlp_fc2) + else: + # Fallback to original MLP. 
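+            # Runs fc1 -> activation -> fc2 as three separate ops instead of
+            # the fused GELU kernel above.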
+ mlp_hidden_states = self.mlp_fc1(norm_hidden_states) + mlp_hidden_states = self.act_mlp(mlp_hidden_states) + mlp_hidden_states = self.mlp_fc2(mlp_hidden_states) + + # Attention + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs + ) + + hidden_states = attn_output + mlp_hidden_states + gate = gate.unsqueeze(1) + hidden_states = gate * hidden_states + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states + + +class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoaderMixin): + """ + Nunchaku-optimized FluxTransformer2DModel. + """ + + def _patch_model(self, **kwargs): + """ + Patch the model with :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformerBlock` + and :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxSingleTransformerBlock`. + + Parameters + ---------- + **kwargs + Additional arguments for quantization. + + Returns + ------- + self : NunchakuFluxTransformer2DModelV2 + The patched model. + """ + self.pos_embed = NunchakuFluxPosEmbed(dim=self.inner_dim, theta=10000, axes_dim=self.pos_embed.axes_dim) + for i, block in enumerate(self.transformer_blocks): + self.transformer_blocks[i] = NunchakuFluxTransformerBlock(block, scale_shift=0, **kwargs) + for i, block in enumerate(self.single_transformer_blocks): + self.single_transformer_blocks[i] = NunchakuFluxSingleTransformerBlock(block, scale_shift=0, **kwargs) + return self + + @classmethod + @utils.validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs): + """ + Load a pretrained NunchakuFluxTransformer2DModelV2 from a safetensors file. + + Parameters + ---------- + pretrained_model_name_or_path : str or os.PathLike + Path to the safetensors file. It can be a local file or a remote HuggingFace path. + **kwargs + Additional arguments (e.g., device, torch_dtype). + + Returns + ------- + NunchakuFluxTransformer2DModelV2 + The loaded and quantized model. + + Raises + ------ + NotImplementedError + If offload is requested. + AssertionError + If the file is not a safetensors file. 
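+
+        Examples
+        --------
+        A minimal sketch; the checkpoint path is illustrative, not a released file::
+
+            transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
+                "path/to/svdq-int4_r32-model.safetensors", device="cuda"
+            )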
+ """ + device = kwargs.get("device", "cpu") + offload = kwargs.get("offload", False) + + if offload: + raise NotImplementedError("Offload is not supported for FluxTransformer2DModelV2") + + torch_dtype = kwargs.get("torch_dtype", torch.bfloat16) + + if isinstance(pretrained_model_name_or_path, str): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + + assert pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith( + (".safetensors", ".sft") + ), "Only safetensors are supported" + transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs) + quantization_config = json.loads(metadata.get("quantization_config", "{}")) + rank = quantization_config.get("rank", 32) + transformer = transformer.to(torch_dtype) + + precision = get_precision() + if precision == "fp4": + precision = "nvfp4" + transformer._patch_model(precision=precision, rank=rank) + + transformer = transformer.to_empty(device=device) + converted_state_dict = convert_flux_state_dict(model_state_dict) + + state_dict = transformer.state_dict() + + for k in state_dict.keys(): + if k not in converted_state_dict: + assert ".wcscales" in k + converted_state_dict[k] = torch.ones_like(state_dict[k]) + else: + assert state_dict[k].dtype == converted_state_dict[k].dtype + + # Load the wtscale from the converted state dict. + for n, m in transformer.named_modules(): + if isinstance(m, SVDQW4A4Linear): + if m.wtscale is not None: + m.wtscale = converted_state_dict.pop(f"{n}.wtscale", 1.0) + + transformer.load_state_dict(converted_state_dict) + + return transformer + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + Forward pass for the NunchakuFluxTransformer2DModelV2. + + Parameters + ---------- + hidden_states : torch.Tensor + Input hidden states of shape (batch_size, image_sequence_length, in_channels). + encoder_hidden_states : torch.Tensor, optional + Conditional embeddings (e.g., from text). + pooled_projections : torch.Tensor, optional + Projected embeddings from input conditions. + timestep : torch.LongTensor, optional + Denoising step. + img_ids : torch.Tensor, optional + Image token IDs. + txt_ids : torch.Tensor, optional + Text token IDs. + guidance : torch.Tensor, optional + Guidance tensor for classifier-free guidance. + joint_attention_kwargs : dict, optional + Additional attention arguments. + controlnet_block_samples : any, optional + Not supported. + controlnet_single_block_samples : any, optional + Not supported. + return_dict : bool, optional + Whether to return a Transformer2DModelOutput (default: True). + controlnet_blocks_repeat : bool, optional + Not supported. + + Returns + ------- + Transformer2DModelOutput or tuple + Output sample tensor or output tuple. + + Raises + ------ + NotImplementedError + If controlnet is requested. 
+ """ + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + txt_tokens = encoder_hidden_states.shape[1] + img_tokens = hidden_states.shape[1] + + assert image_rotary_emb.ndim == 6 + assert image_rotary_emb.shape[0] == 1 + assert image_rotary_emb.shape[1] == 1 + assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens) + # [1, tokens, head_dim / 2, 1, 2] (sincos) + image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]]) + rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype) + rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype) + rotary_emb_single = image_rotary_emb + + rotary_emb_txt = pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1)) + rotary_emb_img = pack_rotemb(pad_tensor(rotary_emb_img, 256, 1)) + rotary_emb_single = pack_rotemb(pad_tensor(rotary_emb_single, 256, 1)) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=(rotary_emb_img, rotary_emb_txt), + joint_attention_kwargs=joint_attention_kwargs, + ) + + # Controlnet residual (not supported for now) + if controlnet_block_samples is not None: + raise NotImplementedError("Controlnet is not supported for FluxTransformer2DModelV2 for now") + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + for index_block, block in enumerate(self.single_transformer_blocks): + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=rotary_emb_single, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # Controlnet residual (not supported for now) + if controlnet_single_block_samples is not None: + raise NotImplementedError("Controlnet is not supported for FluxTransformer2DModelV2 for now") + + hidden_states = hidden_states[:, txt_tokens:] + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +def convert_flux_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Convert a state dict from the :class:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel` + format to :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformer2DModelV2` format. + + Parameters + ---------- + state_dict : dict[str, torch.Tensor] + The original state dict. 
+ + Returns + ------- + dict[str, torch.Tensor] + The converted state dict compatible with :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformer2DModelV2`. + """ + new_state_dict = {} + for k, v in state_dict.items(): + if "single_transformer_blocks." in k: + if ".qkv_proj." in k: + new_k = k.replace(".qkv_proj.", ".attn.to_qkv.") + elif ".out_proj." in k: + new_k = k.replace(".out_proj.", ".attn.to_out.") + elif ".norm_q." in k or ".norm_k." in k: + new_k = k.replace(".norm_k.", ".attn.norm_k.") + new_k = new_k.replace(".norm_q.", ".attn.norm_q.") + else: + new_k = k + new_k = new_k.replace(".lora_down", ".proj_down") + new_k = new_k.replace(".lora_up", ".proj_up") + if ".smooth_orig" in k: + new_k = new_k.replace(".smooth_orig", ".smooth_factor_orig") + elif ".smooth" in k: + new_k = new_k.replace(".smooth", ".smooth_factor") + new_state_dict[new_k] = v + elif "transformer_blocks." in k: + if ".mlp_context_fc1" in k: + new_k = k.replace(".mlp_context_fc1.", ".ff_context.net.0.proj.") + elif ".mlp_context_fc2" in k: + new_k = k.replace(".mlp_context_fc2.", ".ff_context.net.2.") + elif ".mlp_fc1" in k: + new_k = k.replace(".mlp_fc1.", ".ff.net.0.proj.") + elif ".mlp_fc2" in k: + new_k = k.replace(".mlp_fc2.", ".ff.net.2.") + elif ".qkv_proj_context." in k: + new_k = k.replace(".qkv_proj_context.", ".attn.add_qkv_proj.") + elif ".qkv_proj." in k: + new_k = k.replace(".qkv_proj.", ".attn.to_qkv.") + elif ".norm_q." in k or ".norm_k." in k: + new_k = k.replace(".norm_k.", ".attn.norm_k.") + new_k = new_k.replace(".norm_q.", ".attn.norm_q.") + elif ".norm_added_q." in k or ".norm_added_k." in k: + new_k = k.replace(".norm_added_k.", ".attn.norm_added_k.") + new_k = new_k.replace(".norm_added_q.", ".attn.norm_added_q.") + elif ".out_proj." in k: + new_k = k.replace(".out_proj.", ".attn.to_out.0.") + elif ".out_proj_context." in k: + new_k = k.replace(".out_proj_context.", ".attn.to_add_out.") + else: + new_k = k + new_k = new_k.replace(".lora_down", ".proj_down") + new_k = new_k.replace(".lora_up", ".proj_up") + if ".smooth_orig" in k: + new_k = new_k.replace(".smooth_orig", ".smooth_factor_orig") + elif ".smooth" in k: + new_k = new_k.replace(".smooth", ".smooth_factor") + new_state_dict[new_k] = v + else: + new_state_dict[k] = v + + return new_state_dict diff --git a/nunchaku/models/transformers/transformer_qwenimage.py b/nunchaku/models/transformers/transformer_qwenimage.py new file mode 100644 index 0000000000000000000000000000000000000000..8476f9fed05ed03767cb261dd485d29245e5a85b --- /dev/null +++ b/nunchaku/models/transformers/transformer_qwenimage.py @@ -0,0 +1,601 @@ +""" +This module provides implementations of NunchakuQwenImageTransformer2DModel and its building blocks. 
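+The blocks quantize the fused QKV and MLP projections with
+:class:`~nunchaku.models.linear.SVDQW4A4Linear` and can optionally offload
+transformer blocks to the CPU via :class:`~nunchaku.models.utils.CPUOffloadManager`.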
+""" + +import gc +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn + +import torch +from diffusers.models.attention_processor import Attention +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.transformers.transformer_qwenimage import ( + QwenEmbedRope, + QwenImageTransformer2DModel, + QwenImageTransformerBlock, +) +from huggingface_hub import utils + +from ...utils import get_precision +from ..attention import NunchakuBaseAttention, NunchakuFeedForward +from ..attention_processors.qwenimage import NunchakuQwenImageNaiveFA2Processor +from ..linear import AWQW4A16Linear, SVDQW4A4Linear +from ..utils import CPUOffloadManager, fuse_linears +from .utils import NunchakuModelLoaderMixin + + +class NunchakuQwenAttention(NunchakuBaseAttention): + """ + Nunchaku-optimized quantized attention module for QwenImage. + + Parameters + ---------- + other : Attention + The original QwenImage Attention module to wrap and quantize. + processor : str, default="flashattn2" + The attention processor to use. + **kwargs + Additional arguments for quantization. + """ + + def __init__(self, other: Attention, processor: str = "flashattn2", **kwargs): + super(NunchakuQwenAttention, self).__init__(processor) + self.inner_dim = other.inner_dim + self.inner_kv_dim = other.inner_kv_dim + self.query_dim = other.query_dim + self.use_bias = other.use_bias + self.is_cross_attention = other.is_cross_attention + self.cross_attention_dim = other.cross_attention_dim + self.upcast_attention = other.upcast_attention + self.upcast_softmax = other.upcast_softmax + self.rescale_output_factor = other.rescale_output_factor + self.residual_connection = other.residual_connection + self.dropout = other.dropout + self.fused_projections = other.fused_projections + self.out_dim = other.out_dim + self.out_context_dim = other.out_context_dim + self.context_pre_only = other.context_pre_only + self.pre_only = other.pre_only + self.is_causal = other.is_causal + self.scale_qk = other.scale_qk + self.scale = other.scale + self.heads = other.heads + self.sliceable_head_dim = other.sliceable_head_dim + self.added_kv_proj_dim = other.added_kv_proj_dim + self.only_cross_attention = other.only_cross_attention + self.group_norm = other.group_norm + self.spatial_norm = other.spatial_norm + + self.norm_cross = other.norm_cross + + self.norm_q = other.norm_q + self.norm_k = other.norm_k + self.norm_added_q = other.norm_added_q + self.norm_added_k = other.norm_added_k + + # Fuse the QKV projections for quantization + with torch.device("meta"): + to_qkv = fuse_linears([other.to_q, other.to_k, other.to_v]) + self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, **kwargs) + self.to_out = other.to_out + self.to_out[0] = SVDQW4A4Linear.from_linear(self.to_out[0], **kwargs) + + assert self.added_kv_proj_dim is not None + # Fuse the additional QKV projections + with torch.device("meta"): + add_qkv_proj = fuse_linears([other.add_q_proj, other.add_k_proj, other.add_v_proj]) + self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, **kwargs) + self.to_add_out = SVDQW4A4Linear.from_linear(other.to_add_out, **kwargs) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ): + """ + Forward pass for 
NunchakuQwenAttention. + + Parameters + ---------- + hidden_states : torch.FloatTensor + Image stream input. + encoder_hidden_states : torch.FloatTensor, optional + Text stream input. + encoder_hidden_states_mask : torch.FloatTensor, optional + Mask for encoder hidden states. + attention_mask : torch.FloatTensor, optional + Attention mask. + image_rotary_emb : torch.Tensor, optional + Rotary embedding for images. + **kwargs + Additional arguments. + + Returns + ------- + tuple + Attention outputs for image and text streams. + """ + return self.processor( + self, + hidden_states, + encoder_hidden_states, + encoder_hidden_states_mask, + attention_mask, + image_rotary_emb, + **kwargs, + ) + + def set_processor(self, processor: str): + """ + Set the attention processor. + + Parameters + ---------- + processor : str + Name of the processor to use. Only "flashattn2" is supported for now. See :class:`~nunchaku.models.attention_processors.qwenimage.NunchakuQwenImageNaiveFA2Processor`. + + Raises + ------ + ValueError + If the processor is not supported. + """ + if processor == "flashattn2": + self.processor = NunchakuQwenImageNaiveFA2Processor() + else: + raise ValueError(f"Processor {processor} is not supported") + + +class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock): + """ + Quantized QwenImage Transformer Block. + + This block supports quantized linear layers and joint attention for image and text streams. + + Parameters + ---------- + other : QwenImageTransformerBlock + The original transformer block to wrap and quantize. + scale_shift : float, default=1.0 + Value to add to scale parameters. Default is 1.0. + Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0. + **kwargs + Additional arguments for quantization. + """ + + def __init__(self, other: QwenImageTransformerBlock, scale_shift: float = 1.0, **kwargs): + super(QwenImageTransformerBlock, self).__init__() + + self.dim = other.dim + self.img_mod = other.img_mod + self.img_mod[1] = AWQW4A16Linear.from_linear(other.img_mod[1], **kwargs) + self.img_norm1 = other.img_norm1 + self.attn = NunchakuQwenAttention(other.attn, **kwargs) + self.img_norm2 = other.img_norm2 + self.img_mlp = NunchakuFeedForward(other.img_mlp, **kwargs) + + # Text processing modules + self.txt_mod = other.txt_mod + self.txt_mod[1] = AWQW4A16Linear.from_linear(other.txt_mod[1], **kwargs) + self.txt_norm1 = other.txt_norm1 + # Text doesn't need separate attention - it's handled by img_attn joint computation + self.txt_norm2 = other.txt_norm2 + self.txt_mlp = NunchakuFeedForward(other.txt_mlp, **kwargs) + + self.scale_shift = scale_shift + + def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply modulation to input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + mod_params : torch.Tensor + Modulation parameters. + + Returns + ------- + tuple + Modulated tensor and gate tensor. 
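+
+        Notes
+        -----
+        Implements adaLN-style modulation,
+        ``out = x * (scale + scale_shift) + shift``, with ``gate`` returned
+        separately so the caller can scale the residual branches.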
+        """
+        shift, scale, gate = mod_params.chunk(3, dim=-1)
+        if self.scale_shift != 0:
+            scale.add_(self.scale_shift)
+        return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states_mask: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass for NunchakuQwenImageTransformerBlock.
+
+        Parameters
+        ----------
+        hidden_states : torch.Tensor
+            Image stream input.
+        encoder_hidden_states : torch.Tensor
+            Text stream input.
+        encoder_hidden_states_mask : torch.Tensor
+            Mask for encoder hidden states.
+        temb : torch.Tensor
+            Timestep embedding.
+        image_rotary_emb : tuple of torch.Tensor, optional
+            Rotary embedding for images.
+        joint_attention_kwargs : dict, optional
+            Additional arguments for joint attention.
+
+        Returns
+        -------
+        tuple
+            Updated encoder_hidden_states and hidden_states.
+        """
+        # Get modulation parameters for both streams
+        img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+        txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
+
+        # Nunchaku's quantized modulation layers emit the six parameters
+        # interleaved per channel ([B, dim, 6]); reorder them into six
+        # contiguous [B, dim] chunks so the chunk() calls below split correctly.
+        img_mod_params = (
+            img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
+        )
+        txt_mod_params = (
+            txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
+        )
+
+        img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
+        txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
+
+        # Process image stream - norm1 + modulation
+        img_normed = self.img_norm1(hidden_states)
+        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+
+        # Process text stream - norm1 + modulation
+        txt_normed = self.txt_norm1(encoder_hidden_states)
+        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+
+        joint_attention_kwargs = joint_attention_kwargs or {}
+        attn_output = self.attn(
+            hidden_states=img_modulated,
+            encoder_hidden_states=txt_modulated,
+            encoder_hidden_states_mask=encoder_hidden_states_mask,
+            image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
+        )
+
+        # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
+        img_attn_output, txt_attn_output = attn_output
+
+        # Apply attention gates and add residual (like in Megatron)
+        hidden_states = hidden_states + img_gate1 * img_attn_output
+        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+
+        # Process image stream - norm2 + MLP
+        img_normed2 = self.img_norm2(hidden_states)
+        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+        img_mlp_output = self.img_mlp(img_modulated2)
+        hidden_states = hidden_states + img_gate2 * img_mlp_output
+
+        # Process text stream - norm2 + MLP
+        txt_normed2 = self.txt_norm2(encoder_hidden_states)
+        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
+        txt_mlp_output = self.txt_mlp(txt_modulated2)
+        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
+
+        # Clip to prevent overflow for fp16
+        if encoder_hidden_states.dtype == torch.float16:
+            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+        if hidden_states.dtype == torch.float16:
+            hidden_states = hidden_states.clip(-65504, 65504)
+
+        return encoder_hidden_states, hidden_states
+
+
+class 
NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuModelLoaderMixin): + """ + Quantized QwenImage Transformer2DModel. + + This model supports quantized transformer blocks and optional CPU offloading for memory efficiency. + + Parameters + ---------- + *args + Positional arguments for the base model. + **kwargs + Keyword arguments for the base model and quantization. + + Attributes + ---------- + offload : bool + Whether CPU offloading is enabled. + offload_manager : CPUOffloadManager or None + Manager for offloading transformer blocks. + _is_initialized : bool + Whether the model has been patched for quantization. + """ + + def __init__(self, *args, **kwargs): + self.offload = kwargs.pop("offload", False) + self.offload_manager = None + self._is_initialized = False + super().__init__(*args, **kwargs) + + def _patch_model(self, **kwargs): + """ + Patch the transformer blocks for quantization. + + Parameters + ---------- + **kwargs + Additional arguments for quantization. + + Returns + ------- + self + """ + for i, block in enumerate(self.transformer_blocks): + self.transformer_blocks[i] = NunchakuQwenImageTransformerBlock(block, scale_shift=0, **kwargs) + self._is_initialized = True + return self + + @classmethod + @utils.validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs): + """ + Load a quantized model from a pretrained checkpoint. + + Parameters + ---------- + pretrained_model_name_or_path : str or os.PathLike + Path to the pretrained model checkpoint. It can be a local file or a remote HuggingFace path. + **kwargs + Additional arguments for loading and quantization. + + Returns + ------- + NunchakuQwenImageTransformer2DModel + The loaded and quantized model. + + Raises + ------ + AssertionError + If the checkpoint is not a safetensors file. 
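+
+        Examples
+        --------
+        A minimal sketch; the checkpoint path below is illustrative, not a
+        published model:
+
+        >>> transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
+        ...     "path/to/svdq-int4_r32-qwenimage.safetensors",
+        ...     torch_dtype=torch.bfloat16,
+        ...     device="cuda",
+        ...     offload=False,
+        ... )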
+ """ + device = kwargs.get("device", "cpu") + offload = kwargs.get("offload", False) + + torch_dtype = kwargs.get("torch_dtype", torch.bfloat16) + + if isinstance(pretrained_model_name_or_path, str): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + + assert pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith( + (".safetensors", ".sft") + ), "Only safetensors are supported" + transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs) + quantization_config = json.loads(metadata.get("quantization_config", "{}")) + config = json.loads(metadata.get("config", "{}")) + rank = quantization_config.get("rank", 32) + transformer = transformer.to(torch_dtype) + + precision = get_precision() + if precision == "fp4": + precision = "nvfp4" + transformer._patch_model(precision=precision, rank=rank) + + transformer = transformer.to_empty(device=device) + # need to re-init the pos_embed as to_empty does not work on it + transformer.pos_embed = QwenEmbedRope( + theta=10000, axes_dim=list(config.get("axes_dims_rope", [16, 56, 56])), scale_rope=True + ) + + state_dict = transformer.state_dict() + for k in state_dict.keys(): + if k not in model_state_dict: + assert ".wcscales" in k + model_state_dict[k] = torch.ones_like(state_dict[k]) + else: + assert state_dict[k].dtype == model_state_dict[k].dtype + + # load the wtscale from the state dict, as it is a float on CPU + for n, m in transformer.named_modules(): + if isinstance(m, SVDQW4A4Linear): + if m.wtscale is not None: + m.wtscale = model_state_dict.pop(f"{n}.wtscale", 1.0) + transformer.load_state_dict(model_state_dict) + transformer.set_offload(offload) + + return transformer + + def set_offload(self, offload: bool, **kwargs): + """ + Enable or disable asynchronous CPU offloading for transformer blocks. + + Parameters + ---------- + offload : bool + Whether to enable offloading. + **kwargs + Additional arguments for offload manager. + + See Also + -------- + :class:`~nunchaku.models.utils.CPUOffloadManager` + """ + if offload == self.offload: + # nothing changed, just return + return + self.offload = offload + if offload: + self.offload_manager = CPUOffloadManager( + self.transformer_blocks, + use_pin_memory=kwargs.get("use_pin_memory", True), + on_gpu_modules=[ + self.img_in, + self.txt_in, + self.txt_norm, + self.time_text_embed, + self.norm_out, + self.proj_out, + ], + num_blocks_on_gpu=kwargs.get("num_blocks_on_gpu", 1), + ) + else: + self.offload_manager = None + gc.collect() + torch.cuda.empty_cache() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens: Optional[List[int]] = None, + guidance: torch.Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + Forward pass for the quantized QwenImage transformer model. + + Parameters + ---------- + hidden_states : torch.Tensor + Image stream input. + encoder_hidden_states : torch.Tensor, optional + Text stream input. + encoder_hidden_states_mask : torch.Tensor, optional + Mask for encoder hidden states. + timestep : torch.LongTensor, optional + Timestep for temporal embedding. + img_shapes : list of tuple, optional + Image shapes for rotary embedding. 
+ txt_seq_lens : list of int, optional + Text sequence lengths. + guidance : torch.Tensor, optional + Guidance tensor (for classifier-free guidance). + attention_kwargs : dict, optional + Additional attention arguments. + return_dict : bool, default=True + Whether to return a dict or tuple. + + Returns + ------- + torch.Tensor or Transformer2DModelOutput + Model output. + """ + device = hidden_states.device + if self.offload: + self.offload_manager.set_device(device) + + hidden_states = self.img_in(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + + compute_stream = torch.cuda.current_stream() + if self.offload: + self.offload_manager.initialize(compute_stream) + for block_idx, block in enumerate(self.transformer_blocks): + with torch.cuda.stream(compute_stream): + if self.offload: + block = self.offload_manager.get_block(block_idx) + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + if self.offload: + self.offload_manager.step(compute_stream) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + torch.cuda.empty_cache() + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def to(self, *args, **kwargs): + """ + Override the default ``.to()`` method. + + If offload is enabled, prevents moving the model to GPU. + Prevents changing dtype after quantization. + + Parameters + ---------- + *args + Positional arguments for ``.to()``. + **kwargs + Keyword arguments for ``.to()``. + + Returns + ------- + self + + Raises + ------ + ValueError + If attempting to change dtype after quantization. + """ + device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs + dtype_present_in_args = "dtype" in kwargs + + # Try converting arguments to torch.device in case they are passed as strings + for arg in args: + if not isinstance(arg, str): + continue + try: + torch.device(arg) + device_arg_or_kwarg_present = True + except RuntimeError: + pass + + if not dtype_present_in_args: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + + if dtype_present_in_args and self._is_initialized: + raise ValueError( + "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please " + "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`." 
+            )
+        if self.offload:
+            if device_arg_or_kwarg_present:
+                warn("Skipping moving the model to GPU as offload is enabled", UserWarning)
+                return self
+        # Use zero-argument super() so that subclasses do not recurse infinitely.
+        return super().to(*args, **kwargs)
diff --git a/nunchaku/models/transformers/transformer_sana.py b/nunchaku/models/transformers/transformer_sana.py
new file mode 100644
index 0000000000000000000000000000000000000000..50320a125f56779fbc37ae49b31a329de57f618a
--- /dev/null
+++ b/nunchaku/models/transformers/transformer_sana.py
@@ -0,0 +1,374 @@
+"""
+Implements the :class:`NunchakuSanaTransformer2DModel`,
+a quantized Sana transformer for Diffusers with efficient inference support.
+"""
+
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+from diffusers import SanaTransformer2DModel
+from huggingface_hub import utils
+from safetensors.torch import load_file
+from torch import nn
+from torch.nn import functional as F
+
+from ..._C import QuantizedSanaModel
+from ..._C import utils as cutils
+from ...utils import get_precision
+from .utils import NunchakuModelLoaderMixin
+
+SVD_RANK = 32
+
+
+class NunchakuSanaTransformerBlocks(nn.Module):
+    """
+    Wrapper for quantized Sana transformer blocks.
+
+    This module wraps a QuantizedSanaModel and provides forward methods compatible
+    with the expected transformer block interface.
+
+    Parameters
+    ----------
+    m : QuantizedSanaModel
+        The quantized transformer model.
+    dtype : torch.dtype
+        The data type to use for computation.
+    device : str or torch.device
+        The device to run the model on.
+    """
+
+    def __init__(self, m: QuantizedSanaModel, dtype: torch.dtype, device: str | torch.device):
+        super(NunchakuSanaTransformerBlocks, self).__init__()
+        self.m = m
+        self.dtype = dtype
+        self.device = device
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        skip_first_layer: bool = False,
+    ):
+        """
+        Forward pass through all quantized transformer blocks.
+
+        Parameters
+        ----------
+        hidden_states : torch.Tensor
+            Input hidden states of shape (batch_size, img_tokens, ...).
+        attention_mask : torch.Tensor, optional
+            Not used.
+        encoder_hidden_states : torch.Tensor, optional
+            Encoder hidden states of shape (batch_size, txt_tokens, ...).
+        encoder_attention_mask : torch.Tensor, optional
+            Encoder attention mask of shape (batch_size, 1, txt_tokens).
+        timestep : torch.LongTensor, optional
+            Timestep tensor.
+        height : int, optional
+            Image height.
+        width : int, optional
+            Image width.
+        skip_first_layer : bool, optional
+            Whether to skip the first layer.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor after passing through the quantized transformer blocks.
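+
+        Notes
+        -----
+        ``encoder_attention_mask`` is interpreted as an additive mask:
+        entries greater than -9000 are treated as valid text tokens, and
+        only those tokens are packed and passed to the quantized kernel
+        (see the ``mask > -9000`` filtering in the implementation).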
+ """ + batch_size = hidden_states.shape[0] + img_tokens = hidden_states.shape[1] + txt_tokens = encoder_hidden_states.shape[1] + + original_dtype = hidden_states.dtype + original_device = hidden_states.device + + assert encoder_attention_mask is not None + assert encoder_attention_mask.shape == (batch_size, 1, txt_tokens) + + mask = encoder_attention_mask.reshape(batch_size, txt_tokens) + nunchaku_encoder_hidden_states = encoder_hidden_states[mask > -9000] + + cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32) + cu_seqlens_img = torch.arange( + 0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32, device=self.device + ) + + if height is None and width is None: + height = width = int(img_tokens**0.5) + elif height is None: + height = img_tokens // width + elif width is None: + width = img_tokens // height + assert height * width == img_tokens + + return ( + self.m.forward( + hidden_states.to(self.dtype).to(self.device), + nunchaku_encoder_hidden_states.to(self.dtype).to(self.device), + timestep.to(self.dtype).to(self.device), + cu_seqlens_img.to(self.device), + cu_seqlens_txt.to(self.device), + height, + width, + batch_size % 3 == 0, # pag is set when loading the model, FIXME: pag_scale == 0 + True, # TODO: find a way to detect if we are doing CFG + skip_first_layer, + ) + .to(original_dtype) + .to(original_device) + ) + + def forward_layer_at( + self, + idx: int, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + ): + """ + Forward pass through a specific quantized transformer layer. + + Parameters + ---------- + idx : int + Index of the layer to run. + hidden_states : torch.Tensor + Input hidden states. + attention_mask : torch.Tensor, optional + Not used. + encoder_hidden_states : torch.Tensor, optional + Encoder hidden states. + encoder_attention_mask : torch.Tensor, optional + Encoder attention mask. + timestep : torch.LongTensor, optional + Timestep tensor. + height : int, optional + Image height. + width : int, optional + Image width. + + Returns + ------- + torch.Tensor + Output tensor after passing through the specified quantized transformer layer. 
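+
+        Notes
+        -----
+        Unlike :meth:`forward`, this runs a single transformer layer at a
+        time, which is useful for layer-wise caching schemes or for
+        debugging individual blocks.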
+ """ + batch_size = hidden_states.shape[0] + img_tokens = hidden_states.shape[1] + txt_tokens = encoder_hidden_states.shape[1] + + original_dtype = hidden_states.dtype + original_device = hidden_states.device + + assert encoder_attention_mask is not None + assert encoder_attention_mask.shape == (batch_size, 1, txt_tokens) + + mask = encoder_attention_mask.reshape(batch_size, txt_tokens) + nunchaku_encoder_hidden_states = encoder_hidden_states[mask > -9000] + + cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32) + cu_seqlens_img = torch.arange( + 0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32, device=self.device + ) + + if height is None and width is None: + height = width = int(img_tokens**0.5) + elif height is None: + height = img_tokens // width + elif width is None: + width = img_tokens // height + assert height * width == img_tokens + + return ( + self.m.forward_layer( + idx, + hidden_states.to(self.dtype).to(self.device), + nunchaku_encoder_hidden_states.to(self.dtype).to(self.device), + timestep.to(self.dtype).to(self.device), + cu_seqlens_img.to(self.device), + cu_seqlens_txt.to(self.device), + height, + width, + batch_size % 3 == 0, # pag is set when loading the model, FIXME: pag_scale == 0 + True, # TODO: find a way to detect if we are doing CFG + ) + .to(original_dtype) + .to(original_device) + ) + + def __del__(self): + """ + Destructor to reset the quantized model and free resources. + """ + self.m.reset() + + +class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin): + """ + SanaTransformer2DModel with Nunchaku quantized backend support. + + This class extends the base SanaTransformer2DModel to support loading and + injecting quantized transformer blocks using Nunchaku's custom backend. + """ + + @classmethod + @utils.validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs): + """ + Load a pretrained NunchakuSanaTransformer2DModel from a local file or HuggingFace Hub. + + This method supports both quantized and unquantized checkpoints, and will + automatically inject quantized transformer blocks if available. + + Parameters + ---------- + pretrained_model_name_or_path : str or os.PathLike + Path to the model checkpoint or HuggingFace Hub model name. + **kwargs + Additional keyword arguments for model loading. + + Returns + ------- + NunchakuSanaTransformer2DModel or (NunchakuSanaTransformer2DModel, dict) + The loaded model, and optionally metadata if ``return_metadata=True``. 
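+
+        Examples
+        --------
+        A minimal sketch; the checkpoint path is illustrative:
+
+        >>> transformer = NunchakuSanaTransformer2DModel.from_pretrained(
+        ...     "path/to/svdq-int4-sana.safetensors",
+        ...     precision="int4",
+        ...     device="cuda",
+        ... )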
+ """ + device = kwargs.get("device", "cuda") + if isinstance(device, str): + device = torch.device(device) + pag_layers = kwargs.get("pag_layers", []) + precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path) + metadata = None + + if isinstance(pretrained_model_name_or_path, str): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith( + (".safetensors", ".sft") + ): + transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path) + quantized_part_sd = {} + unquantized_part_sd = {} + for k, v in model_state_dict.items(): + if k.startswith("transformer_blocks."): + quantized_part_sd[k] = v + else: + unquantized_part_sd[k] = v + m = load_quantized_module( + transformer, quantized_part_sd, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4" + ) + transformer.inject_quantized_module(m, device) + transformer.to_empty(device=device) + transformer.load_state_dict(unquantized_part_sd, strict=False) + else: + transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy( + pretrained_model_name_or_path, **kwargs + ) + m = load_quantized_module( + transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4" + ) + transformer.inject_quantized_module(m, device) + transformer.to_empty(device=device) + unquantized_state_dict = load_file(unquantized_part_path) + transformer.load_state_dict(unquantized_state_dict, strict=False) + if kwargs.get("return_metadata", False): + return transformer, metadata + else: + return transformer + + def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"): + """ + Inject a quantized transformer module into this model. + + Parameters + ---------- + m : QuantizedSanaModel + The quantized transformer module to inject. + device : str or torch.device, optional + The device to place the module on (default: "cuda"). + + Returns + ------- + NunchakuSanaTransformer2DModel + The model with the quantized module injected. + """ + self.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, self.dtype, device)]) + return self + + +def load_quantized_module( + net: SanaTransformer2DModel, + path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor], + device: str | torch.device = "cuda", + pag_layers: int | list[int] | None = None, + use_fp4: bool = False, +) -> QuantizedSanaModel: + """ + Load quantized weights into a QuantizedSanaModel. + + Parameters + ---------- + net : SanaTransformer2DModel + The base transformer model (for config and dtype). + path_or_state_dict : str, os.PathLike, or dict + Path to the quantized weights or a state dict. + device : str or torch.device, optional + Device to load the quantized model on (default: "cuda"). + pag_layers : int, list of int, or None, optional + List of layers to use pag (default: None). + use_fp4 : bool, optional + Whether to use FP4 quantization (default: False). + + Returns + ------- + QuantizedSanaModel + The loaded quantized model. 
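+
+    Examples
+    --------
+    A minimal sketch (the weights path is illustrative):
+
+    >>> m = load_quantized_module(net, "transformer_blocks.safetensors", device="cuda")
+    >>> net = inject_quantized_module(net, m, torch.device("cuda"))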
+ """ + if pag_layers is None: + pag_layers = [] + elif isinstance(pag_layers, int): + pag_layers = [pag_layers] + device = torch.device(device) + assert device.type == "cuda" + + m = QuantizedSanaModel() + cutils.disable_memory_auto_release() + m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index) + if isinstance(path_or_state_dict, dict): + m.loadDict(path_or_state_dict, True) + else: + m.load(str(path_or_state_dict)) + return m + + +def inject_quantized_module( + net: SanaTransformer2DModel, m: QuantizedSanaModel, device: torch.device +) -> SanaTransformer2DModel: + """ + Inject a quantized transformer module into a SanaTransformer2DModel. + + Parameters + ---------- + net : SanaTransformer2DModel + The base transformer model. + m : QuantizedSanaModel + The quantized transformer module to inject. + device : torch.device + The device to place the module on. + + Returns + ------- + SanaTransformer2DModel + The model with the quantized module injected. + """ + net.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, net.dtype, device)]) + return net diff --git a/nunchaku/models/transformers/utils.py b/nunchaku/models/transformers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..179ddd4dbb0615113f6dd5c71483a8b859705fa9 --- /dev/null +++ b/nunchaku/models/transformers/utils.py @@ -0,0 +1,147 @@ +""" +Utilities for Nunchaku transformer model loading. +""" + +import json +import logging +import os +from pathlib import Path + +import torch +from diffusers import __version__ +from huggingface_hub import constants, hf_hub_download +from torch import nn + +from ...utils import load_state_dict_in_safetensors + +# Get log level from environment variable (default to INFO) +log_level = os.getenv("LOG_LEVEL", "INFO").upper() + +# Configure logging +logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class NunchakuModelLoaderMixin: + """ + Mixin for standardized model loading in Nunchaku transformer models. + """ + + @classmethod + def _build_model( + cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs + ) -> tuple[nn.Module, dict[str, torch.Tensor], dict[str, str]]: + """ + Build a transformer model from a safetensors file. + + Parameters + ---------- + pretrained_model_name_or_path : str or os.PathLike + Path to the safetensors file. + **kwargs + Additional keyword arguments (e.g., ``torch_dtype``). + + Returns + ------- + tuple + (transformer, state_dict, metadata) + """ + if isinstance(pretrained_model_name_or_path, str): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True) + + config = json.loads(metadata["config"]) + + with torch.device("meta"): + transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16)) + + return transformer, state_dict, metadata + + @classmethod + def _build_model_legacy( + cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs + ) -> tuple[nn.Module, str, str]: + """ + Build a transformer model from a legacy folder structure. + + .. warning:: + This method is deprecated and will be removed in December 2025. + Please use :meth:`_build_model` instead. + + Parameters + ---------- + pretrained_model_name_or_path : str or os.PathLike + Path to the folder containing model weights. 
+ **kwargs + Additional keyword arguments for HuggingFace Hub download and config loading. + + Returns + ------- + tuple + (transformer, unquantized_part_path, transformer_block_path) + """ + logger.warning( + "Loading models from a folder will be deprecated in December 2025. " + "Please download the latest safetensors model, or use one of the following tools to " + "merge your model into a single file: the CLI utility `python -m nunchaku.merge_safetensors` " + "or the ComfyUI workflow `merge_safetensors.json`." + ) + subfolder = kwargs.get("subfolder", None) + if os.path.exists(pretrained_model_name_or_path): + dirname = ( + pretrained_model_name_or_path + if subfolder is None + else os.path.join(pretrained_model_name_or_path, subfolder) + ) + unquantized_part_path = os.path.join(dirname, "unquantized_layers.safetensors") + transformer_block_path = os.path.join(dirname, "transformer_blocks.safetensors") + else: + download_kwargs = { + "subfolder": subfolder, + "repo_type": "model", + "revision": kwargs.get("revision", None), + "cache_dir": kwargs.get("cache_dir", None), + "local_dir": kwargs.get("local_dir", None), + "user_agent": kwargs.get("user_agent", None), + "force_download": kwargs.get("force_download", False), + "proxies": kwargs.get("proxies", None), + "etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT), + "token": kwargs.get("token", None), + "local_files_only": kwargs.get("local_files_only", None), + "headers": kwargs.get("headers", None), + "endpoint": kwargs.get("endpoint", None), + "resume_download": kwargs.get("resume_download", None), + "force_filename": kwargs.get("force_filename", None), + "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"), + } + unquantized_part_path = hf_hub_download( + repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs + ) + transformer_block_path = hf_hub_download( + repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs + ) + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + config, _, _ = cls.load_config( + pretrained_model_name_or_path, + subfolder=subfolder, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"}, + **kwargs, + ) + + with torch.device("meta"): + transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16)) + return transformer, unquantized_part_path, transformer_block_path diff --git a/nunchaku/models/utils.py b/nunchaku/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..090540ee37539ddaf229980d3331fec5aab591fd --- /dev/null +++ b/nunchaku/models/utils.py @@ -0,0 +1,262 @@ +""" +Utility functions and classes for efficient transformer model management in Nunchaku. +""" + +import copy + +import torch +from torch import nn + +from ..utils import copy_params_into + + +def fuse_linears(linears: list[nn.Linear]) -> nn.Linear: + """ + Fuse a list of nn.Linear layers into a single nn.Linear with concatenated output features. 
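+
+    This is used, for example, to merge the separate Q/K/V projections into a
+    single matmul before quantization; see the Notes below regarding weight
+    initialization.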
+ + Parameters + ---------- + linears : list of nn.Linear + List of linear layers to fuse. All must have the same input feature dimension. + + Returns + ------- + fused : nn.Linear + A new linear layer with concatenated output features and the same input features. + + Raises + ------ + AssertionError + If the input feature dimensions do not match. + + Notes + ----- + The fused layer does not copy weights or biases from the input layers. + """ + assert len(linears) > 0 + if len(linears) == 1: + return linears[0] + else: + assert all(linear.in_features == linears[0].in_features for linear in linears) + out_features = sum(linear.out_features for linear in linears) + bias = all(linear.bias is not None for linear in linears) + return nn.Linear( + linears[0].in_features, + out_features, + bias=bias, + dtype=linears[0].weight.dtype, + device=linears[0].weight.device, + ) + + +class CPUOffloadManager: + """ + Manager for per-transformer-block CPU offloading with asynchronous memory operations using a Ping-Pong buffer strategy. + + This class enables memory-efficient inference or training by keeping only a subset + of transformer blocks on GPU, offloading the rest to CPU, and preloading blocks as needed. + + Parameters + ---------- + blocks : list of nn.Module + List of transformer blocks to manage. + device : str or torch.device, optional + Target CUDA device for GPU operations. Default is "cuda". + use_pin_memory : bool, optional + Whether to use pinned memory for faster CPU-to-GPU transfers. Default is True. + on_gpu_modules : list of nn.Module, optional + Additional modules to keep on GPU at all times. Default is []. + num_blocks_on_gpu : int, optional + Number of blocks to keep on GPU simultaneously. Must be > 0. Default is 1. + empty_cache_freq : int, optional + Frequency (in forward passes) to call torch.cuda.empty_cache(). Default is 0 (never). + + Attributes + ---------- + blocks : list of nn.Module + The managed transformer blocks. + buffer_blocks : list of nn.Module + Buffers for preloading blocks onto GPU. + device : torch.device + The current CUDA device. + current_block_idx : int + Index of the current block on GPU. + forward_counter : int + Number of forward passes completed. + memory_stream : torch.cuda.Stream + CUDA stream for memory operations. + compute_done : torch.cuda.Event + CUDA event signaling compute completion. + memory_done : torch.cuda.Event + CUDA event signaling memory completion. + """ + + def __init__( + self, + blocks: list[nn.Module], + device: str | torch.device = torch.device("cuda"), + use_pin_memory: bool = True, + on_gpu_modules: list[nn.Module] = [], + num_blocks_on_gpu: int = 1, + empty_cache_freq: int = 0, + ): + self.blocks = blocks + self.use_pin_memory = use_pin_memory + self.on_gpu_modules = on_gpu_modules + self.num_blocks_on_gpu = num_blocks_on_gpu + assert self.num_blocks_on_gpu > 0 + + # Two streams: one for compute, one for memory operations, will be initialized in set_device + self.memory_stream = None + + self.compute_done = torch.cuda.Event(blocking=False) + self.memory_done = torch.cuda.Event(blocking=False) + + self.buffer_blocks = [copy.deepcopy(blocks[0]), copy.deepcopy(blocks[0])] + + self.device = None + self.set_device(device) + + self.current_block_idx = 0 + self.forward_counter = 0 + self.empty_cache_freq = empty_cache_freq + + def set_device(self, device: torch.device | str, force: bool = False): + """ + Set the CUDA device for offloading and memory operations. 
+        Moves the buffer blocks and always-on-GPU modules to the specified device and offloads the remaining blocks to CPU, optionally using pinned memory for faster transfers.
+
+        Parameters
+        ----------
+        device : torch.device or str
+            Target CUDA device.
+        force : bool, optional
+            If True, force re-initialization even if device is unchanged. Default is False.
+
+        Raises
+        ------
+        AssertionError
+            If the device is not a CUDA device.
+        """
+        if isinstance(device, str):
+            device = torch.device(device)
+        assert device.type == "cuda"
+        if self.device == device and not force:
+            return
+        self.device = device
+        self.memory_stream = torch.cuda.Stream(device=device)
+        for block in self.buffer_blocks:
+            block.to(device)
+        for module in self.on_gpu_modules:
+            module.to(device)
+        for i, block in enumerate(self.blocks):
+            if i < self.num_blocks_on_gpu:
+                block.to(device)
+            else:
+                block.to("cpu")
+                if self.use_pin_memory:
+                    for p in block.parameters(recurse=True):
+                        p.data = p.data.pin_memory()
+                    for b in block.buffers(recurse=True):
+                        b.data = b.data.pin_memory()
+
+    def load_block(self, block_idx: int, non_blocking: bool = True):
+        """
+        Move a transformer block from CPU to GPU buffer.
+
+        Parameters
+        ----------
+        block_idx : int
+            Index of the block to load.
+        non_blocking : bool, optional
+            Whether to use non-blocking memory copy. Default is True.
+
+        Notes
+        -----
+        - No action is taken if the block is already on GPU or index is out of range.
+        """
+        # blocks below num_blocks_on_gpu are permanently resident on GPU; nothing to do
+        if block_idx < self.num_blocks_on_gpu:
+            return
+        # past the last block: nothing to preload
+        if block_idx >= len(self.blocks):
+            return
+
+        block = self.blocks[block_idx]
+        copy_params_into(block, self.buffer_blocks[block_idx % 2], non_blocking=non_blocking)
+
+    def step(self, compute_stream: torch.cuda.Stream | None = None):
+        """
+        Advance to the next transformer block, triggering asynchronous preloading.
+
+        Preloads the next block onto the GPU in the background and synchronizes the compute and memory streams.
+        After all blocks have been processed, ``torch.cuda.empty_cache()`` is called periodically if ``empty_cache_freq`` > 0.
+
+        Parameters
+        ----------
+        compute_stream : torch.cuda.Stream, optional
+            CUDA stream for compute operations. If None, uses current stream.
+        """
+        if compute_stream is None:
+            compute_stream = torch.cuda.current_stream()
+        next_compute_done = torch.cuda.Event()
+        next_compute_done.record(compute_stream)
+        with torch.cuda.stream(self.memory_stream):
+            self.memory_stream.wait_event(self.compute_done)
+            self.load_block(self.current_block_idx + 1)  # preload the next block (a no-op past the last block)
+            next_memory_done = torch.cuda.Event()
+            next_memory_done.record(self.memory_stream)
+        self.memory_done = next_memory_done
+        self.compute_done = next_compute_done
+        self.current_block_idx += 1
+        if self.current_block_idx < len(self.blocks):
+            # get ready for the next compute
+            compute_stream.wait_event(self.memory_done)
+        else:
+            # all blocks processed; reset for the next forward pass
+            compute_stream.wait_event(self.compute_done)
+            self.current_block_idx = 0
+            self.forward_counter += 1
+            if self.empty_cache_freq > 0 and self.forward_counter % self.empty_cache_freq == 0:
+                torch.cuda.empty_cache()
+
+    def get_block(self, block_idx: int | None = None) -> nn.Module:
+        """
+        Retrieve the current or specified transformer block for computation.
+        It will return a buffer block if the requested block is offloaded.
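+        The two ping-pong buffers alternate by block parity, so an offloaded
+        block ``i`` is served from ``buffer_blocks[i % 2]``.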
+ + Parameters + ---------- + block_idx : int, optional + Index of the block to retrieve. If None, returns the current block. + + Returns + ------- + block : nn.Module + The requested transformer block (on GPU if needed). + """ + if block_idx is None: + block_idx = self.current_block_idx + if block_idx < self.num_blocks_on_gpu: + return self.blocks[block_idx] + else: + return self.buffer_blocks[block_idx % 2] + + def initialize(self, stream: torch.cuda.Stream | None = None): + """ + Initialize CUDA events for compute and memory streams. + It will record the initial events for the compute and memory streams. + + Parameters + ---------- + stream : torch.cuda.Stream, optional + CUDA stream to record initial events. If None, uses current stream. + + Notes + ----- + - Should be called before the first forward pass. + """ + if stream is None: + stream = torch.cuda.current_stream() + self.compute_done.record(stream) + self.memory_done.record(stream) diff --git a/nunchaku/ops/__init__.py b/nunchaku/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2123e332cf157a57ea9f97bc83435704d731983e --- /dev/null +++ b/nunchaku/ops/__init__.py @@ -0,0 +1 @@ +# Quantized operations for FLUX-Kontext diff --git a/nunchaku/ops/fused.py b/nunchaku/ops/fused.py new file mode 100644 index 0000000000000000000000000000000000000000..b58a30aae6c12d505d290f175ea8478ecf11e771 --- /dev/null +++ b/nunchaku/ops/fused.py @@ -0,0 +1,178 @@ +""" +High-performance fused operators for quantized neural network inference. +""" + +import torch +from torch.nn import RMSNorm + +from nunchaku.models.linear import SVDQW4A4Linear + +from ..utils import ceil_divide +from .gemm import svdq_gemm_w4a4_cuda + + +def fused_gelu_mlp(x: torch.Tensor, fc1: SVDQW4A4Linear, fc2: SVDQW4A4Linear, pad_size: int = 256) -> torch.Tensor: + """ + Fused quantized MLP with GELU activation. + + Combines the first quantized linear layer, GELU activation, and the second quantized linear layer into a single CUDA kernel. Supports INT4 and NVFP4 quantization. + + Parameters + ---------- + x : torch.Tensor, shape (B, S, C_in), dtype float16 or bfloat16 + Input tensor. + fc1 : SVDQW4A4Linear + First quantized linear layer (input → hidden). + fc2 : SVDQW4A4Linear + Second quantized linear layer (hidden → output). + pad_size : int, optional + Batch padding size for CUDA kernel efficiency. Default is 256. + + Returns + ------- + torch.Tensor, shape (B, S, C_out), dtype as input + Output tensor. + + Notes + ----- + - Notations: + + - B: batch size + - S: sequence length + - C_in: input features + - C_out: output features + - For INT4 quantization, GELU activations are shifted by 0.171875 to ensure non-negativity, enabling unsigned quantization for improved quality. 
See: https://github.com/nunchaku-tech/nunchaku/blob/433f0b228a61a53fb700ac676fd2e290368ac94d/src/kernels/zgemm/gemm_w4a4_launch_impl.cuh#L286 + """ + batch_size, seq_len, channels = x.shape + x = x.view(batch_size * seq_len, channels) + quantized_x, ascales, lora_act = fc1.quantize(x) + + batch_size_pad = ceil_divide(batch_size * seq_len, pad_size) * pad_size + + qout_act = torch.empty(batch_size_pad, fc1.out_features // 2, dtype=torch.uint8, device=x.device) + if fc2.precision == "nvfp4": + qout_ascales = torch.empty(fc1.out_features // 16, batch_size_pad, dtype=torch.float8_e4m3fn, device=x.device) + else: + qout_ascales = torch.empty(fc1.out_features // 64, batch_size_pad, dtype=x.dtype, device=x.device) + qout_lora_act = torch.empty(batch_size_pad, fc2.proj_down.shape[1], dtype=torch.float32, device=x.device) + + svdq_gemm_w4a4_cuda( + act=quantized_x, + wgt=fc1.qweight, + qout=qout_act, + ascales=ascales, + wscales=fc1.wscales, + oscales=qout_ascales, + lora_act_in=lora_act, + lora_up=fc1.proj_up, + lora_down=fc2.proj_down, + lora_act_out=qout_lora_act, + bias=fc1.bias, + smooth_factor=fc2.smooth_factor, + fp4=fc1.precision == "nvfp4", + alpha=fc1.wtscale, + wcscales=fc1.wcscales, + ) + output = torch.empty(batch_size * seq_len, fc2.out_features, dtype=x.dtype, device=x.device) + output = fc2.forward_quant(qout_act, qout_ascales, qout_lora_act, output=output) + output = output.view(batch_size, seq_len, -1) + return output + + +def fused_qkv_norm_rottary( + x: torch.Tensor, + proj: SVDQW4A4Linear, + norm_q: RMSNorm, + norm_k: RMSNorm, + rotary_emb: torch.Tensor, + output: torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + attn_tokens: int = 0, +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Fused quantized QKV projection with RMSNorm and rotary embeddings. + + Performs quantized QKV projection, applies RMS normalization to Q and K, and fuses rotary embeddings in a single CUDA kernel call. + + Parameters + ---------- + x : torch.Tensor, shape (B, S, C_in), dtype float16 or bfloat16 + Input tensor. + proj : SVDQW4A4Linear + Quantized QKV projection layer. + norm_q : RMSNorm + RMSNorm for query. + norm_k : RMSNorm + RMSNorm for key. + rotary_emb : torch.Tensor + Packed rotary embedding tensor (see :func:`~nunchaku.models.embeddings.pack_rotemb`). + output : torch.Tensor or tuple of torch.Tensor, optional + Output tensor(s). If None, a new tensor is allocated. + If tuple, should be (output_q, output_k, output_v) for fused attention packing. + attn_tokens : int, optional + Number of attention tokens. Default is 0. + + Returns + ------- + torch.Tensor or tuple of torch.Tensor + Output tensor of shape (B, S, C_out), or tuple (output_q, output_k, output_v). 
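+
+    See Also
+    --------
+    fused_gelu_mlp : Fused quantized MLP built on the same W4A4 GEMM kernel.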
+ + Notes + ----- + Notations: + - B: batch size + - S: sequence length + - C_in: input features + - C_out: output features + """ + assert isinstance(norm_q, RMSNorm) + assert isinstance(norm_k, RMSNorm) + + batch_size, seq_len, channels = x.shape + x = x.view(batch_size * seq_len, channels) + quantized_x, ascales, lora_act = proj.quantize(x) + + if output is None: + output = torch.empty(quantized_x.shape[0], proj.out_features, dtype=x.dtype, device=x.device) + + if isinstance(output, tuple): + assert len(output) == 3 + output_q, output_k, output_v = output + svdq_gemm_w4a4_cuda( + act=quantized_x, + wgt=proj.qweight, + ascales=ascales, + wscales=proj.wscales, + lora_act_in=lora_act, + lora_up=proj.proj_up, + bias=proj.bias, + fp4=proj.precision == "nvfp4", + alpha=proj.wtscale, + wcscales=proj.wcscales, + norm_q=norm_q.weight, + norm_k=norm_k.weight, + rotary_emb=rotary_emb, + out_q=output_q, + out_k=output_k, + out_v=output_v, + attn_tokens=attn_tokens, + ) + return output_q, output_k, output_v + else: + svdq_gemm_w4a4_cuda( + act=quantized_x, + wgt=proj.qweight, + out=output, + ascales=ascales, + wscales=proj.wscales, + lora_act_in=lora_act, + lora_up=proj.proj_up, + bias=proj.bias, + fp4=proj.precision == "nvfp4", + alpha=proj.wtscale, + wcscales=proj.wcscales, + norm_q=norm_q.weight, + norm_k=norm_k.weight, + rotary_emb=rotary_emb, + ) + output = output.view(batch_size, seq_len, -1) + return output diff --git a/nunchaku/ops/gemm.py b/nunchaku/ops/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..be0b39f9fd18fd7c9a8b2b198f568aa6c61fc9bd --- /dev/null +++ b/nunchaku/ops/gemm.py @@ -0,0 +1,160 @@ +""" +Python wrappers for Nunchaku's high-performance quantized GEMM (General Matrix-Matrix Multiplication) CUDA kernels. +""" + +import math + +import torch + +from .._C import ops + + +def svdq_gemm_w4a4_cuda( + act: torch.Tensor, + wgt: torch.Tensor, + out: torch.Tensor | None = None, + qout: torch.Tensor | None = None, + ascales: torch.Tensor | None = None, + wscales: torch.Tensor | None = None, + oscales: torch.Tensor | None = None, + poolout: torch.Tensor | None = None, + lora_act_in: torch.Tensor | None = None, + lora_up: torch.Tensor | None = None, + lora_down: torch.Tensor | None = None, + lora_act_out: torch.Tensor | None = None, + norm_q: torch.Tensor | None = None, + norm_k: torch.Tensor | None = None, + rotary_emb: torch.Tensor | None = None, + bias: torch.Tensor | None = None, + smooth_factor: torch.Tensor | None = None, + out_vk: torch.Tensor | None = None, + out_linearattn: torch.Tensor | None = None, + act_unsigned: bool = False, + lora_scales: list[float] | None = None, + fuse_silu: bool = False, + fp4: bool = False, + alpha: float | None = 1.0, + wcscales: torch.Tensor | None = None, + out_q: torch.Tensor | None = None, + out_k: torch.Tensor | None = None, + out_v: torch.Tensor | None = None, + attn_tokens: int = 0, +): + """ + Quantized GEMM using SVDQuant W4A4 CUDA kernel, with support for LoRA, rotary embeddings, normalization, and fused activations. + + Parameters + ---------- + act : torch.Tensor, shape (M, K // 2), dtype int8 + Packed input activations. + wgt : torch.Tensor, shape (N, K // 2), dtype int8 + Packed quantized weights. + out : torch.Tensor or None, shape (M, N), dtype float16 or bfloat16, optional + Output tensor for the linear layer. + qout : torch.Tensor or None, shape (M, N // 2), dtype int8, optional + Packed quantized input for the next layer. 
+ ascales : torch.Tensor or None, shape (K // G, M), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional + Activation scales. + wscales : torch.Tensor or None, shape (K // G, N), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional + Weight scales. + oscales : torch.Tensor or None, shape (N // G, M), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional + Output scales. + poolout : torch.Tensor or None, optional + Reserved for future use. + lora_act_in : torch.Tensor or None, shape (M, R), dtype float32, optional + LoRA down-projection activations. + lora_up : torch.Tensor or None, shape (N, R), dtype float16 or bfloat16, optional + Packed LoRA up-projection weights. + lora_down : torch.Tensor or None, shape (N, R), dtype float16 or bfloat16, optional + Packed LoRA down-projection weights for the next layer. + lora_act_out : torch.Tensor or None, shape (M, R), dtype float32, optional + Output for LoRA down-projection in the next layer. + norm_q : torch.Tensor or None, shape (HEAD_DIM,), dtype float16 or bfloat16, optional + Query RMS normalization. + norm_k : torch.Tensor or None, shape (HEAD_DIM,), dtype float16 or bfloat16, optional + Key RMS normalization. + rotary_emb : torch.Tensor or None, shape (M, HEAD_DIM // 2, 2, 2), dtype float32, optional + Packed rotary embeddings. + bias : torch.Tensor or None, shape (N,), dtype float16 or bfloat16, optional + Bias tensor. + smooth_factor : torch.Tensor or None, shape (N,), dtype float16 or bfloat16, optional + Smoothing factor for quantization in the next layer. + out_vk : torch.Tensor or None, optional + Used only in SANA. Leave as None. + out_linearattn : torch.Tensor or None, optional + Used only in SANA. Leave as None. + act_unsigned : bool, default=False + If True, activations are unsigned (e.g., after GeLU, shifted by 0.171875). This is only used for INT4 to enable unsigned INT4 activation quantization for better quantization quality. + lora_scales : list of float or None, optional + Per-group LoRA scaling factors (16 channels per group). Defaults to 1.0 per group. + fuse_silu : bool, default=False + If True, fuse SiLU activation. + fp4 : bool, default=False + If True, use 4-bit floating point quantization (NVFP4). + alpha : float or None, default=1.0 + Per-tensor scaling factor for NVFP4. + wcscales : torch.Tensor or None, shape (N,), dtype float8_e4m3fn, optional + Per-channel scaling for NVFP4. + out_q : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional + Packed quantized Q for attention (used in ``nunchaku-fp16`` attention). + out_k : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional + Packed quantized K for attention (used in ``nunchaku-fp16`` attention). + out_v : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional + Packed quantized V for attention (used in ``nunchaku-fp16`` attention). + attn_tokens : int, default=0 + Number of attention tokens. + + Returns + ------- + None + Results are written in-place to the provided output tensors. 
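+
+    See Also
+    --------
+    nunchaku.ops.fused.fused_gelu_mlp : Fused GELU MLP wrapper around this kernel.
+    nunchaku.ops.fused.fused_qkv_norm_rottary : Fused QKV + RMSNorm + rotary wrapper around this kernel.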
+ + Notes + ----- + Notations: + + - M: batch size (input tokens) + - K: input channels (feature dimension) + - N: output channels + - G: group size (64 for INT4, 16 for NVFP4) + - R: LoRA rank + - B: batch size for attention + - H: number of heads + - D: head dimension + """ + if lora_scales is None: + rank = lora_up.shape[1] + lora_scales = [1.0] * math.ceil(rank / 16) + if alpha is None: + alpha = 1.0 + ops.gemm_w4a4( + act, + wgt, + out, + qout, + ascales, + wscales, + oscales, + poolout, + lora_act_in, + lora_up, + lora_down, + lora_act_out, + norm_q, + norm_k, + rotary_emb, + bias, + smooth_factor, + out_vk, + out_linearattn, + act_unsigned, + lora_scales, + fuse_silu, + fp4, + alpha, + wcscales, + out_q, + out_k, + out_v, + attn_tokens, + ) diff --git a/nunchaku/ops/gemv.py b/nunchaku/ops/gemv.py new file mode 100644 index 0000000000000000000000000000000000000000..b0658ba9264ae668ed460ade8ff4987b668161bc --- /dev/null +++ b/nunchaku/ops/gemv.py @@ -0,0 +1,56 @@ +""" +Python wrapper for Nunchaku's high-performance GEMV (General Matrix-Vector Multiplication) CUDA kernels. +""" + +import torch + +from .._C import ops + + +def awq_gemv_w4a16_cuda( + in_feats: torch.Tensor, + kernel: torch.Tensor, + scaling_factors: torch.Tensor, + zeros: torch.Tensor, + m: int, + n: int, + k: int, + group_size: int = 64, +) -> torch.Tensor: + """ + Performs quantized GEMV using the AWQ W4A16 format. + + Parameters + ---------- + in_feats : torch.Tensor, shape (k,) or (m, k), dtype float16 or bfloat16 + Input feature vector or batch of vectors. + kernel : torch.Tensor, shape (n // 4, k // 2), dtype int32 + Packed quantized weight matrix. + scaling_factors : torch.Tensor, shape (k // group_size, n), dtype float16 or bfloat16 + Per-group scaling factors. + zeros : torch.Tensor, shape (k // group_size, n), dtype float16 or bfloat16 + Per-group zero points. + m : int + Batch size (number of input vectors). + n : int + Output feature dimension. + k : int + Input feature dimension. + group_size : int, optional + Number of input channels per quantization group. Default is 64. + + Returns + ------- + torch.Tensor, shape (m, n), dtype float16 or bfloat16 + Output tensor. + + Notes + ----- + Notations: + + - m: batch size + - n: output features + - k: input features + - group_size: quantization group size + """ + return ops.gemv_awq(in_feats, kernel, scaling_factors, zeros, m, n, k, group_size) diff --git a/nunchaku/ops/quantize.py b/nunchaku/ops/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..fdebb68f8af73af6bcecc7f72b463d067ab97a3f --- /dev/null +++ b/nunchaku/ops/quantize.py @@ -0,0 +1,81 @@ +""" +This module provides Python wrappers for Nunchaku's high-performance SVDQuant quantization CUDA kernels. +""" + +import torch + +from .._C import ops +from ..utils import ceil_divide + + +def svdq_quantize_w4a4_act_fuse_lora_cuda( + input: torch.Tensor, + output: torch.Tensor | None = None, + oscales: torch.Tensor | None = None, + lora_down: torch.Tensor | None = None, + lora_act_out: torch.Tensor | None = None, + smooth: torch.Tensor | None = None, + fuse_glu: bool = False, + fp4: bool = False, + pad_size: int = 256, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Quantizes activations and computes LoRA down-projection using SVDQuant W4A4 CUDA kernel. + + Parameters + ---------- + input : torch.Tensor, shape (M, K), dtype bfloat16/float16 + Input activations. 
+ output : torch.Tensor or None, shape (M_pad, K // 2), dtype uint8, optional + Packed output tensor for quantized activations. Allocated if None. + oscales : torch.Tensor or None, shape (K // G, M_pad), dtype float8_e4m3fn for NVFP4 or input dtype for INT4, optional + Output scales tensor. Allocated if None. + lora_down : torch.Tensor or None, shape (K, R), dtype bfloat16/float16, optional + Packed LoRA down-projection weights. + lora_act_out : torch.Tensor or None, shape (M_pad, R), dtype float32, optional + Packed output tensor for LoRA activations. Allocated if None. + smooth : torch.Tensor or None, optional, dtype bfloat16/float16 + Smoothing factor for quantization. + fuse_glu : bool, default=False + If True, fuse GLU activation. + fp4 : bool, default=False + If True, use NVFP4 quantization; else INT4. + pad_size : int, default=256 + Pad batch size to a multiple of this value for efficient CUDA execution. + + Returns + ------- + output : torch.Tensor, shape (M_pad, K // 2), dtype uint8 + Packed quantized activations. + oscales : torch.Tensor, shape (K // G, M_pad), dtype float8_e4m3fn for NVFP4 or input dtype for INT4 + Output scales. + lora_act_out : torch.Tensor, shape (M_pad, R), dtype float32 + Packed LoRA activation output. + + Notes + ----- + Notations: + + - M: batch size + - K: input channels + - R: LoRA rank + - G: group size (64 for INT4, 16 for NVFP4) + - M_pad: padded batch size = ceil(M / pad_size) * pad_size + """ + batch_size, channels = input.shape + rank = lora_down.shape[1] + batch_size_pad = ceil_divide(batch_size, pad_size) * pad_size + if output is None: + output = torch.empty(batch_size_pad, channels // 2, dtype=torch.uint8, device=input.device) + if oscales is None: + if fp4: + assert channels % 16 == 0 + oscales = torch.empty(channels // 16, batch_size_pad, dtype=torch.float8_e4m3fn, device=input.device) + else: + assert channels % 64 == 0 + oscales = torch.empty(channels // 64, batch_size_pad, dtype=input.dtype, device=input.device) + if lora_act_out is None: + lora_act_out = torch.empty(batch_size_pad, rank, dtype=torch.float32, device=input.device) + + ops.quantize_w4a4_act_fuse_lora(input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4) + return output, oscales, lora_act_out diff --git a/nunchaku/utils.py b/nunchaku/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0dee226fcd6ab188e8a34837808d76cc35f241f3 --- /dev/null +++ b/nunchaku/utils.py @@ -0,0 +1,366 @@ +""" +Utility functions for Nunchaku. +""" + +import hashlib +import os +import warnings +from pathlib import Path +from typing import Any + +import safetensors +import torch +from huggingface_hub import hf_hub_download +from torch import nn + + +def pad_tensor(tensor: torch.Tensor | None, multiples: int, dim: int, fill: Any = 0) -> torch.Tensor | None: + """ + Pad a tensor along a given dimension to the next multiple of a specified value. + + Parameters + ---------- + tensor : torch.Tensor or None + Input tensor. If None, returns None. + multiples : int + Pad to this multiple. If <= 1, no padding is applied. + dim : int + Dimension along which to pad. + fill : Any, optional + Value to use for padding (default: 0). + + Returns + ------- + torch.Tensor or None + The padded tensor, or None if input was None. 
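+
+    Examples
+    --------
+    >>> t = torch.ones(2, 3)
+    >>> pad_tensor(t, multiples=4, dim=1).shape
+    torch.Size([2, 4])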
+    """
+    if multiples <= 1:
+        return tensor
+    if tensor is None:
+        return None
+    shape = list(tensor.shape)
+    if shape[dim] % multiples == 0:
+        return tensor
+    shape[dim] = ceil_divide(shape[dim], multiples) * multiples
+    result = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
+    result.fill_(fill)
+    result[[slice(0, extent) for extent in tensor.shape]] = tensor
+    return result
+
+
+def sha256sum(filepath: str | os.PathLike[str]) -> str:
+    """
+    Compute the SHA-256 checksum of a file.
+
+    Parameters
+    ----------
+    filepath : str or os.PathLike
+        Path to the file.
+
+    Returns
+    -------
+    str
+        The SHA-256 hexadecimal digest of the file.
+    """
+    sha256 = hashlib.sha256()
+    with open(filepath, "rb") as f:
+        for chunk in iter(lambda: f.read(8192), b""):
+            sha256.update(chunk)
+    return sha256.hexdigest()
+
+
+def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
+    """
+    Fetch a file from a local path or download from HuggingFace Hub if not present.
+
+    The remote path should be in the format ``<repo_id>/<filename>`` or
+    ``<repo_id>/<subfolder>/<filename>``, where ``<repo_id>`` is ``<org>/<name>``.
+
+    Parameters
+    ----------
+    path : str or Path
+        Local file path or HuggingFace Hub path.
+    repo_type : str, optional
+        Type of HuggingFace repo (default: "model").
+
+    Returns
+    -------
+    Path
+        Path to the local file.
+
+    Raises
+    ------
+    ValueError
+        If the path is too short to extract repo_id and subfolder.
+    """
+    path = Path(path)
+
+    if path.exists():
+        return path
+
+    parts = path.parts
+    if len(parts) < 3:
+        raise ValueError(f"Path '{path}' is too short to extract repo_id and subfolder")
+
+    repo_id = "/".join(parts[:2])
+    sub_path = Path(*parts[2:])
+    filename = sub_path.name
+    subfolder = str(sub_path.parent) if sub_path.parent != Path(".") else None
+
+    path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type)
+    return Path(path)
+
+
+def ceil_divide(x: int, divisor: int) -> int:
+    """
+    Compute the ceiling of x divided by divisor.
+
+    Parameters
+    ----------
+    x : int
+        Dividend.
+    divisor : int
+        Divisor.
+
+    Returns
+    -------
+    int
+        The smallest integer >= x / divisor.
+    """
+    return (x + divisor - 1) // divisor
+
+
+def load_state_dict_in_safetensors(
+    path: str | os.PathLike[str],
+    device: str | torch.device = "cpu",
+    filter_prefix: str = "",
+    return_metadata: bool = False,
+) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], dict[str, str]]:
+    """
+    Load a state dict from a safetensors file, optionally filtering by prefix.
+
+    Parameters
+    ----------
+    path : str or os.PathLike
+        Path to the safetensors file (local or HuggingFace Hub).
+    device : str or torch.device, optional
+        Device to load tensors onto (default: "cpu").
+    filter_prefix : str, optional
+        Only load keys starting with this prefix (default: "", no filter).
+    return_metadata : bool, optional
+        Whether to return safetensors metadata (default: False).
+
+    Returns
+    -------
+    dict[str, torch.Tensor] or tuple[dict[str, torch.Tensor], dict[str, str]]
+        The loaded state dict, and optionally the metadata if ``return_metadata`` is True.
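+
+    Examples
+    --------
+    A sketch with an illustrative path:
+
+    >>> sd = load_state_dict_in_safetensors(
+    ...     "model.safetensors", filter_prefix="transformer_blocks."
+    ... )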
+ """ + state_dict = {} + with safetensors.safe_open(fetch_or_download(path), framework="pt", device=device) as f: + metadata = f.metadata() + for k in f.keys(): + if filter_prefix and not k.startswith(filter_prefix): + continue + state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k) + if return_metadata: + return state_dict, metadata + else: + return state_dict + + +def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str = "") -> dict[str, torch.Tensor]: + """ + Filter a state dict to only include keys starting with a given prefix. + + Parameters + ---------- + state_dict : dict[str, torch.Tensor] + The input state dict. + filter_prefix : str, optional + Prefix to filter keys by (default: "", no filter). + + Returns + ------- + dict[str, torch.Tensor] + Filtered state dict with prefix removed from keys. + """ + return {k.removeprefix(filter_prefix): v for k, v in state_dict.items() if k.startswith(filter_prefix)} + + +def get_precision( + precision: str = "auto", + device: str | torch.device = "cuda", + pretrained_model_name_or_path: str | os.PathLike[str] | None = None, +) -> str: + """ + Determine the quantization precision to use based on device and model. + + Parameters + ---------- + precision : str, optional + "auto", "int4", or "fp4" (default: "auto"). + device : str or torch.device, optional + Device to check (default: "cuda"). + pretrained_model_name_or_path : str or os.PathLike or None, optional + Model name or path for warning checks. + + Returns + ------- + str + The selected precision ("int4" or "fp4"). + + Raises + ------ + AssertionError + If precision is not one of "auto", "int4", or "fp4". + """ + assert precision in ("auto", "int4", "fp4") + if precision == "auto": + if isinstance(device, str): + device = torch.device(device) + capability = torch.cuda.get_device_capability(0 if device.index is None else device.index) + sm = f"{capability[0]}{capability[1]}" + precision = "fp4" if sm == "120" else "int4" + if pretrained_model_name_or_path is not None: + if precision == "int4": + if "fp4" in str(pretrained_model_name_or_path): + warnings.warn("The model may be quantized to fp4, but you are loading it with int4 precision.") + elif precision == "fp4": + if "int4" in str(pretrained_model_name_or_path): + warnings.warn("The model may be quantized to int4, but you are loading it with fp4 precision.") + return precision + + +def is_turing(device: str | torch.device = "cuda") -> bool: + """ + Check if the current GPU is a Turing GPU (compute capability 7.5). + + Parameters + ---------- + device : str or torch.device, optional + Device to check (default: "cuda"). + + Returns + ------- + bool + True if the current GPU is a Turing GPU, False otherwise. + """ + if isinstance(device, str): + device = torch.device(device) + device_id = 0 if device.index is None else device.index + capability = torch.cuda.get_device_capability(device_id) + sm = f"{capability[0]}{capability[1]}" + return sm == "75" + + +def get_gpu_memory(device: str | torch.device = "cuda", unit: str = "GiB") -> int: + """ + Get the total memory of the current GPU. + + Parameters + ---------- + device : str or torch.device, optional + Device to check (default: "cuda"). + unit : str, optional + Unit for memory ("GiB", "MiB", or "B") (default: "GiB"). + + Returns + ------- + int + GPU memory in the specified unit. + + Raises + ------ + AssertionError + If unit is not one of "GiB", "MiB", or "B". 
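+
+    Examples
+    --------
+    Illustrative only; the reported number depends on the installed GPU:
+
+    >>> get_gpu_memory("cuda:0", unit="GiB")  # doctest: +SKIP
+    24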
+    """
+    if isinstance(device, str):
+        device = torch.device(device)
+    assert unit in ("GiB", "MiB", "B")
+    memory = torch.cuda.get_device_properties(device).total_memory
+    if unit == "GiB":
+        return memory // (1024**3)
+    elif unit == "MiB":
+        return memory // (1024**2)
+    else:
+        return memory
+
+
+def check_hardware_compatibility(quantization_config: dict, device: str | torch.device = "cuda"):
+    """
+    Check if the quantization config is compatible with the current GPU.
+
+    Parameters
+    ----------
+    quantization_config : dict
+        Quantization configuration dictionary.
+    device : str or torch.device, optional
+        Device to check (default: "cuda").
+
+    Raises
+    ------
+    ValueError
+        If the quantization config is not compatible with the GPU architecture.
+    """
+    if isinstance(device, str):
+        device = torch.device(device)
+    capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
+    sm = f"{capability[0]}{capability[1]}"
+    if sm == "120":  # Blackwell GPUs can only use the fp4 models
+        if quantization_config["weight"]["dtype"] != "fp4_e2m1_all":
+            raise ValueError('Please use "fp4" quantization for Blackwell GPUs.')
+    elif sm in ["75", "80", "86", "89"]:
+        if quantization_config["weight"]["dtype"] != "int4":
+            raise ValueError('Please use "int4" quantization for Turing, Ampere, and Ada GPUs.')
+    else:
+        raise ValueError(
+            f"Unsupported GPU architecture {sm} due to the lack of 4-bit tensorcores. "
+            "Please use a Turing, Ampere, Ada, or Blackwell GPU for this quantization configuration."
+        )
+
+
+def get_precision_from_quantization_config(quantization_config: dict) -> str:
+    """
+    Get the precision from the quantization configuration.
+
+    Parameters
+    ----------
+    quantization_config : dict
+        Quantization configuration dictionary.
+
+    Returns
+    -------
+    str
+        The detected precision ("nvfp4" or "int4").
+
+    Raises
+    ------
+    ValueError
+        If the quantization dtype or group size is unsupported.
+    """
+    if quantization_config["weight"]["dtype"] == "fp4_e2m1_all":
+        if quantization_config["weight"]["group_size"] == 16:
+            return "nvfp4"
+        else:
+            raise ValueError("Currently, nunchaku only supports nvfp4.")
+    elif quantization_config["weight"]["dtype"] == "int4":
+        return "int4"
+    else:
+        raise ValueError(f"Unsupported quantization dtype: {quantization_config['weight']['dtype']}")
+
+
+def copy_params_into(src: nn.Module, dst: nn.Module, non_blocking: bool = True):
+    """
+    Copy all parameters and buffers from a source module to a destination module.
+
+    Parameters
+    ----------
+    src : nn.Module
+        Source module from which parameters and buffers are copied.
+    dst : nn.Module
+        Destination module to which parameters and buffers are copied.
+    non_blocking : bool, optional
+        If True, copies are performed asynchronously with respect to the host if possible (default: True).
+
+    Notes
+    -----
+    - The function assumes that `src` and `dst` have the same structure and number of parameters and buffers.
+    - All copying is performed under `torch.no_grad()` context to avoid tracking in autograd.
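+
+    Examples
+    --------
+    A minimal sketch with two identically structured modules:
+
+    >>> import torch
+    >>> from torch import nn
+    >>> src, dst = nn.Linear(4, 4), nn.Linear(4, 4)
+    >>> copy_params_into(src, dst)
+    >>> torch.equal(src.weight, dst.weight)
+    True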
+    """
+    with torch.no_grad():
+        for ps, pd in zip(src.parameters(), dst.parameters()):
+            pd.copy_(ps, non_blocking=non_blocking)
+        for bs, bd in zip(src.buffers(), dst.buffers()):
+            bd.copy_(bs, non_blocking=non_blocking)
+
+    for ms, md in zip(src.modules(), dst.modules()):
+        # wtscale is a special case: a plain Python float kept on the CPU
+        if hasattr(ms, "wtscale"):
+            assert hasattr(md, "wtscale")
+            md.wtscale = ms.wtscale
+        else:
+            assert not hasattr(md, "wtscale")
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..12ae3ea59f3c09b43b9e1bab9d2296cad4b1fc5a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,51 @@
+[tool.isort]
+profile = "black"
+known_first_party = ["nunchaku"]
+line_length = 120
+
+[tool.black]
+line-length = 120
+target-version = ['py311']
+
+[tool.ruff]
+line-length = 120
+
+[project]
+dynamic = ["version"]
+name = "flux-kontext"
+description = "Optimized FLUX-Kontext implementation using quantization and acceleration techniques"
+dependencies = [
+    "diffusers>=0.35.1",
+    "transformers>=4.53.3",
+    "accelerate>=1.9.0",
+    "sentencepiece",
+    "protobuf",
+    "huggingface_hub>=0.34",
+    "peft>=0.17",
+    "einops",
+    "torch>=2.5",
+    "gradio",
+    "pillow",
+    "safetensors",
+]
+requires-python = ">=3.10"
+
+[build-system]
+requires = [
+    "setuptools",
+    "torch>=2.5",
+    "wheel",
+    "ninja",
+]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+include = ["nunchaku*"]
+
+[tool.doc8]
+max-line-length = 120
+ignore-path = ["docs/_build"]
+ignore = ["D000", "D001"]
+
+[tool.rstcheck]
+ignore_directives = ["tabs"]
+ignore_messages = ["ERROR/3", "INFO/1"]
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c3d6bf5934114fff3bc69abfd454015d595e7cff
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+diffusers>=0.35.1
+transformers>=4.53.3
+accelerate>=1.9.0
+sentencepiece
+protobuf
+huggingface_hub>=0.34
+peft>=0.17
+einops
+torch>=2.5
+gradio
+pillow
+safetensors
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..f285876d200127e145991a59ee04be7a2377a7d4
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,108 @@
+import os
+import re
+import subprocess
+
+import setuptools
+import torch
+from packaging import version as packaging_version
+from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
+
+
+class CustomBuildExtension(BuildExtension):
+    def build_extensions(self):
+        for ext in self.extensions:
+            if "cxx" not in ext.extra_compile_args:
+                ext.extra_compile_args["cxx"] = []
+            if "nvcc" not in ext.extra_compile_args:
+                ext.extra_compile_args["nvcc"] = []
+            if self.compiler.compiler_type == "msvc":
+                ext.extra_compile_args["cxx"] += ext.extra_compile_args["msvc"]
+                ext.extra_compile_args["nvcc"] += ext.extra_compile_args["nvcc_msvc"]
+            else:
+                ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
+        super().build_extensions()
+
+
+def get_sm_targets() -> list[str]:
+    nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc"
+    try:
+        nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode()
+    except (OSError, subprocess.SubprocessError) as e:
+        raise RuntimeError("nvcc not found; make sure CUDA_HOME is set or nvcc is on PATH") from e
+    match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
+    if match is None:
+        raise RuntimeError(f"Could not parse the nvcc version from: {nvcc_output!r}")
+    nvcc_version = match.group(2)
+    print(f"Found nvcc version: {nvcc_version}")
+
+    support_sm120 = packaging_version.parse(nvcc_version) >= 
packaging_version.parse("12.8") + + install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST") + if install_mode == "FAST": + ret = [] + for i in range(torch.cuda.device_count()): + capability = torch.cuda.get_device_capability(i) + sm = f"{capability[0]}{capability[1]}" + if sm == "120" and support_sm120: + sm = "120a" + ret.append(sm) + return ret + elif install_mode == "ALL": + # All supported architectures (except for experimental ones) + sm_targets = ["75", "80", "86", "89", "90"] + if support_sm120: + sm_targets.append("120a") + return sm_targets + else: + raise ValueError(f"Unknown install mode: {install_mode}") + + +FLUX_SOURCES = [ + "nunchaku/csrc/pybind.cpp", +] + +ext_modules = [] + +# Check if CUDA is available +if torch.cuda.is_available() and CUDA_HOME is not None: + sm_targets = get_sm_targets() + arch_flags = [f"-gencode=arch=compute_{sm},code=sm_{sm}" for sm in sm_targets] + + ext_modules.append( + CUDAExtension( + "nunchaku._C", + FLUX_SOURCES, + extra_compile_args={ + "cxx": ["-O3", "-std=c++20"], + "nvcc": [ + "-O3", + "-std=c++20", + "--use_fast_math", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + ] + arch_flags, + "msvc": ["/std:c++20"], + "gcc": ["-std=c++20"], + "nvcc_msvc": [], + }, + include_dirs=[ + "third_party/cutlass/include", + "third_party/cutlass/tools/util/include", + ], + ) + ) +else: + print("CUDA not available. Installing CPU-only version.") + +setuptools.setup( + name="flux-kontext", + packages=setuptools.find_packages(), + ext_modules=ext_modules, + cmdclass={"build_ext": CustomBuildExtension}, + zip_safe=False, +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d80bfb08f3ce1206751a84212253dc7719765b --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# FLUX-Kontext test suite diff --git a/tests/test_flux_kontext.py b/tests/test_flux_kontext.py new file mode 100644 index 0000000000000000000000000000000000000000..557ca260c09c328cdeb59af8064dfe0a4364f6c8 --- /dev/null +++ b/tests/test_flux_kontext.py @@ -0,0 +1,94 @@ +import gc +import os +from pathlib import Path + +import pytest +import torch +from diffusers import FluxKontextPipeline +from diffusers.utils import load_image + +from nunchaku import NunchakuFluxTransformer2dModel +from nunchaku.utils import get_precision, is_turing + +from .utils import already_generate, compute_lpips, hash_str_to_int, offload_pipeline + + +@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs") +@pytest.mark.parametrize("expected_lpips", [0.25 if get_precision() == "int4" else 0.18]) +def test_flux_kontext(expected_lpips: float): + gc.collect() + torch.cuda.empty_cache() + + precision = get_precision() + + ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref"))) + results_dir_16_bit = ref_root / "bf16" / "flux.1-kontext-dev" / "kontext" + results_dir_4_bit = Path("test_results") / precision / "flux.1-kontext-dev" / "kontext" + + os.makedirs(results_dir_16_bit, exist_ok=True) + os.makedirs(results_dir_4_bit, exist_ok=True) + + image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ).convert("RGB") + prompts = [ + "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors", + "Convert the image to ghibli style", + "help me convert it to manga style", + "Convert it to 
a realistic photo",
+    ]
+
+    # First, generate results with the 16-bit model
+    if not already_generate(results_dir_16_bit, 4):
+        pipeline = FluxKontextPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+        )
+
+        # Possibly offload the model to CPU when GPU memory is scarce
+        pipeline = offload_pipeline(pipeline)
+
+        for prompt in prompts:
+            seed = hash_str_to_int(prompt)
+            result = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(seed)).images[0]
+            result.save(os.path.join(results_dir_16_bit, f"{seed}.png"))
+
+        # Clean up the 16-bit model
+        del pipeline.transformer
+        del pipeline.text_encoder
+        del pipeline.text_encoder_2
+        del pipeline.vae
+        del pipeline
+        del result
+        gc.collect()
+        torch.cuda.synchronize()
+        torch.cuda.empty_cache()
+
+        free, total = torch.cuda.mem_get_info()  # bytes
+        print(f"After 16-bit generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
+
+    # Then, generate results with the 4-bit model
+    if not already_generate(results_dir_4_bit, 4):
+        transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+            f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{precision}_r32-flux.1-kontext-dev.safetensors"
+        )
+        pipeline = FluxKontextPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
+        ).to("cuda")
+        for prompt in prompts:
+            seed = hash_str_to_int(prompt)
+            result = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(seed)).images[0]
+            result.save(os.path.join(results_dir_4_bit, f"{seed}.png"))
+
+        # Clean up the 4-bit model
+        del pipeline
+        del transformer
+        gc.collect()
+        torch.cuda.synchronize()
+        torch.cuda.empty_cache()
+
+        free, total = torch.cuda.mem_get_info()  # bytes
+        print(f"After 4-bit generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
+
+    lpips = compute_lpips(results_dir_16_bit, results_dir_4_bit)
+    print(f"lpips: {lpips}")
+    assert lpips < expected_lpips * 1.15
diff --git a/tests/test_flux_kontext_lora.py b/tests/test_flux_kontext_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a81370b077236114be4589b0822f42e8d3cecde
--- /dev/null
+++ b/tests/test_flux_kontext_lora.py
@@ -0,0 +1,228 @@
+"""
+Test LoRA functionality for the FLUX.1-Kontext model.
+"""
+
+import gc
+import os
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+from PIL import Image
+
+from nunchaku import NunchakuFluxTransformer2dModel
+from nunchaku.utils import get_precision, is_turing
+
+
+def compute_pixel_difference(img1_path: str | os.PathLike, img2_path: str | os.PathLike) -> dict:
+    """Compute pixel-level difference statistics between two images."""
+    img1 = np.array(Image.open(img1_path)).astype(float)
+    img2 = np.array(Image.open(img2_path)).astype(float)
+
+    diff = np.abs(img1 - img2)
+
+    return {
+        "mean_diff": np.mean(diff),
+        "max_diff": np.max(diff),
+        "pixels_changed": np.mean(diff > 0) * 100,
+        "pixels_changed_significantly": np.mean(diff > 10) * 100,
+    }
+
+
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
+def test_kontext_lora_application():
+    """Test that LoRA weights are properly applied to the Kontext model."""
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    precision = get_precision()
+
+    # Setup directories
+    results_dir = Path("test_results") / precision / "flux.1-kontext-dev" / "lora_test"
+    os.makedirs(results_dir, exist_ok=True)
+
+    # Load test image
+    image = 
load_image( + "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg" + ).convert("RGB") + + prompt = "neon light, city atmosphere" + seed = 42 + num_inference_steps = 28 + guidance_scale = 2.5 + + # Load model + transformer = NunchakuFluxTransformer2dModel.from_pretrained( + f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{precision}_r32-flux.1-kontext-dev.safetensors" + ) + + pipeline = FluxKontextPipeline.from_pretrained( + "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16 + ).to("cuda") + + # Test 1: Generate without LoRA + generator = torch.Generator().manual_seed(seed) + result_no_lora = pipeline( + image=image, + prompt=prompt, + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + ).images[0] + result_no_lora.save(results_dir / "no_lora.png") + + # Test 2: Apply LoRA and generate + transformer.update_lora_params( + "nunchaku-tech/nunchaku-test-models/relight-kontext-lora-single-caption_comfy.safetensors" + # linoyts/relight-kontext-lora-single-caption/relight-kontext-lora-single-caption.safetensors" + ) + transformer.set_lora_strength(1.0) + + generator = torch.Generator().manual_seed(seed) + result_lora_1 = pipeline( + image=image, + prompt=prompt, + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + ).images[0] + result_lora_1.save(results_dir / "lora_1.0.png") + + # Test 3: Change LoRA strength + transformer.set_lora_strength(2.0) + + generator = torch.Generator().manual_seed(seed) + result_lora_2 = pipeline( + image=image, + prompt=prompt, + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + ).images[0] + result_lora_2.save(results_dir / "lora_2.0.png") + + # Test 4: Disable LoRA + transformer.set_lora_strength(0.0) + + generator = torch.Generator().manual_seed(seed) + result_lora_0 = pipeline( + image=image, + prompt=prompt, + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + ).images[0] + result_lora_0.save(results_dir / "lora_0.0.png") + + # Compute differences + diff_1 = compute_pixel_difference(results_dir / "no_lora.png", results_dir / "lora_1.0.png") + + diff_2 = compute_pixel_difference(results_dir / "no_lora.png", results_dir / "lora_2.0.png") + + diff_0 = compute_pixel_difference(results_dir / "no_lora.png", results_dir / "lora_0.0.png") + + diff_scale = compute_pixel_difference(results_dir / "lora_1.0.png", results_dir / "lora_2.0.png") + + # Assertions + # LoRA 1.0 should change the output + assert diff_1["mean_diff"] > 1.0, "LoRA 1.0 should significantly change the output" + assert diff_1["pixels_changed"] > 50, "LoRA 1.0 should change more than 50% of pixels" + + # LoRA 2.0 should have a significant effect (but not necessarily stronger than 1.0 due to saturation) + assert diff_2["mean_diff"] > 1.0, "LoRA 2.0 should significantly change the output" + + # Different LoRA strengths should produce different results + assert diff_scale["mean_diff"] > 1.0, "Different LoRA strengths should produce different results" + + # Log the actual differences for debugging + print(f"LoRA 1.0 vs baseline difference: {diff_1['mean_diff']:.2f}") + print(f"LoRA 2.0 vs baseline difference: {diff_2['mean_diff']:.2f}") + print(f"LoRA 1.0 vs 2.0 difference: {diff_scale['mean_diff']:.2f}") + + # Note: We're not asserting that LoRA 0.0 matches baseline due to known issue + # where LoRA weights may not be fully 
removed when strength=0.0 + print(f"LoRA 0.0 vs baseline difference: {diff_0['mean_diff']:.2f}") + if diff_0["mean_diff"] > 1.0: + print("WARNING: LoRA 0.0 differs from baseline - LoRA may not be fully disabled") + + # Clean up + del pipeline + del transformer + gc.collect() + torch.cuda.empty_cache() + + +@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs") +@pytest.mark.parametrize( + "lora_strength,expected_change", + [ + (0.5, 1.0), # Medium strength should cause moderate change + (1.0, 1.5), # Full strength should cause significant change + (1.5, 2.0), # Over-strength should cause larger change + ], +) +def test_kontext_lora_strength_scaling(lora_strength, expected_change): + """Test that LoRA strength scaling works proportionally""" + gc.collect() + torch.cuda.empty_cache() + + precision = get_precision() + + # Load model + transformer = NunchakuFluxTransformer2dModel.from_pretrained( + f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{precision}_r32-flux.1-kontext-dev.safetensors" + ) + + pipeline = FluxKontextPipeline.from_pretrained( + "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16 + ).to("cuda") + + # Load test image + image = load_image( + "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg" + ).convert("RGB") + + prompt = "dramatic lighting, cinematic" + seed = 123 + + # Generate baseline + generator = torch.Generator().manual_seed(seed) + baseline = pipeline(image=image, prompt=prompt, generator=generator, num_inference_steps=20).images[0] + + transformer.update_lora_params( + "nunchaku-tech/nunchaku-test-models/relight-kontext-lora-single-caption_comfy.safetensors" + # "linoyts/relight-kontext-lora-single-caption/relight-kontext-lora-single-caption.safetensors" + ) + transformer.set_lora_strength(lora_strength) + + # Generate with LoRA + generator = torch.Generator().manual_seed(seed) + with_lora = pipeline(image=image, prompt=prompt, generator=generator, num_inference_steps=20).images[0] + + # Compute difference + baseline_arr = np.array(baseline).astype(float) + lora_arr = np.array(with_lora).astype(float) + mean_diff = np.mean(np.abs(baseline_arr - lora_arr)) + + # Assert that change is proportional to strength + # Allow 50% tolerance due to non-linear effects + assert ( + mean_diff > expected_change * 0.5 + ), f"LoRA strength {lora_strength} should cause mean difference > {expected_change * 0.5}, got {mean_diff}" + + print(f"LoRA strength {lora_strength}: mean difference = {mean_diff:.2f}") + + # Clean up + del pipeline + del transformer + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == "__main__": + test_kontext_lora_application() + for strength, expected in [(0.5, 1.0), (1.0, 1.5), (1.5, 2.0)]: + test_kontext_lora_strength_scaling(strength, expected) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..35956ee7688adc6fd052930dabc9cbfa95eae889 --- /dev/null +++ b/utils.py @@ -0,0 +1,14 @@ +import argparse + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use" + ) + parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder") + parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker") + parser.add_argument("--count-use", action="store_true", help="Whether 
to count the number of uses") + parser.add_argument("--gradio-root-path", type=str, default="") + args = parser.parse_args() + return args