Spaces:
Runtime error
Runtime error
Commit
·
04eaca9
0
Parent(s):
Add Git LFS support and remove binary files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +102 -0
- app.py +150 -0
- app/kontext/README.md +12 -0
- app/kontext/assets/description.html +21 -0
- app/kontext/assets/style.css +40 -0
- app/kontext/run_gradio.py +150 -0
- app/kontext/utils.py +14 -0
- examples/__init__.py +1 -0
- examples/flux.1-kontext-FALAI_lora.py +30 -0
- examples/flux.1-kontext-dev-teacache.py +30 -0
- examples/flux.1-kontext-dev.py +22 -0
- nunchaku/__init__.py +9 -0
- nunchaku/__version__.py +1 -0
- nunchaku/csrc/flux.h +254 -0
- nunchaku/csrc/gemm.h +114 -0
- nunchaku/csrc/gemm88.h +37 -0
- nunchaku/csrc/module.h +85 -0
- nunchaku/csrc/ops.h +173 -0
- nunchaku/csrc/pybind.cpp +124 -0
- nunchaku/csrc/sana.h +102 -0
- nunchaku/csrc/utils.h +39 -0
- nunchaku/lora/__init__.py +1 -0
- nunchaku/lora/flux/__init__.py +5 -0
- nunchaku/lora/flux/compose.py +218 -0
- nunchaku/lora/flux/convert.py +74 -0
- nunchaku/lora/flux/diffusers_converter.py +220 -0
- nunchaku/lora/flux/nunchaku_converter.py +949 -0
- nunchaku/lora/flux/packer.py +517 -0
- nunchaku/lora/flux/utils.py +94 -0
- nunchaku/models/__init__.py +9 -0
- nunchaku/models/attention.py +123 -0
- nunchaku/models/embeddings.py +138 -0
- nunchaku/models/linear.py +414 -0
- nunchaku/models/normalization.py +166 -0
- nunchaku/models/text_encoders/__init__.py +5 -0
- nunchaku/models/text_encoders/linear.py +238 -0
- nunchaku/models/text_encoders/t5_encoder.py +116 -0
- nunchaku/models/text_encoders/tinychat_utils.py +188 -0
- nunchaku/models/transformers/__init__.py +5 -0
- nunchaku/models/transformers/transformer_flux.py +991 -0
- nunchaku/models/transformers/transformer_flux_v2.py +646 -0
- nunchaku/models/transformers/transformer_qwenimage.py +601 -0
- nunchaku/models/transformers/transformer_sana.py +374 -0
- nunchaku/models/transformers/utils.py +147 -0
- nunchaku/models/utils.py +262 -0
- nunchaku/ops/__init__.py +1 -0
- nunchaku/ops/fused.py +178 -0
- nunchaku/ops/gemm.py +160 -0
- nunchaku/ops/gemv.py +56 -0
- nunchaku/ops/quantize.py +81 -0
README.md
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FLUX-Kontext Optimized Implementation
|
| 2 |
+
|
| 3 |
+
This package contains an optimized implementation of FLUX-Kontext using quantization and acceleration techniques.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- **Quantized FLUX Transformer**: Efficient INT4/FP4 quantized implementation of FLUX.1-Kontext
|
| 8 |
+
- **Quantized T5 Encoder**: AWQ INT4 quantized T5 text encoder for memory efficiency
|
| 9 |
+
- **LoRA Support**: Full support for LoRA fine-tuning and inference
|
| 10 |
+
- **Gradio Web Interface**: Ready-to-use web interface for image editing
|
| 11 |
+
|
| 12 |
+
## Installation
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
pip install -r requirements.txt
|
| 16 |
+
python setup.py build_ext --inplace
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
## Quick Start
|
| 20 |
+
|
| 21 |
+
### Using the Gradio Interface
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
cd app/kontext
|
| 25 |
+
python run_gradio.py --precision int4
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
### Programmatic Usage
|
| 29 |
+
|
| 30 |
+
```python
|
| 31 |
+
import torch
|
| 32 |
+
from diffusers import FluxKontextPipeline
|
| 33 |
+
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
|
| 34 |
+
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
|
| 35 |
+
|
| 36 |
+
# Load quantized transformer
|
| 37 |
+
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
|
| 38 |
+
"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-int4_r32-flux.1-kontext-dev.safetensors"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Load quantized text encoder (optional)
|
| 42 |
+
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
|
| 43 |
+
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# Create pipeline
|
| 47 |
+
pipeline = FluxKontextPipeline.from_pretrained(
|
| 48 |
+
"black-forest-labs/FLUX.1-Kontext-dev",
|
| 49 |
+
transformer=transformer,
|
| 50 |
+
text_encoder_2=text_encoder_2,
|
| 51 |
+
torch_dtype=torch.bfloat16
|
| 52 |
+
)
|
| 53 |
+
pipeline = pipeline.to("cuda")
|
| 54 |
+
|
| 55 |
+
# Generate image
|
| 56 |
+
result = pipeline(
|
| 57 |
+
prompt="Your prompt here",
|
| 58 |
+
image=your_input_image,
|
| 59 |
+
num_inference_steps=28,
|
| 60 |
+
guidance_scale=2.5,
|
| 61 |
+
).images[0]
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
## Available Models
|
| 65 |
+
|
| 66 |
+
- `int4`: INT4 quantized transformer (default, most memory efficient)
|
| 67 |
+
- `fp4`: FP4 quantized transformer
|
| 68 |
+
- `bf16`: Full precision BFloat16 (highest quality, most memory usage)
|
| 69 |
+
|
| 70 |
+
## Directory Structure
|
| 71 |
+
|
| 72 |
+
```
|
| 73 |
+
flux-kontext/
|
| 74 |
+
├── nunchaku/ # Core quantized models and utilities
|
| 75 |
+
│ ├── models/ # Transformer and text encoder models
|
| 76 |
+
│ ├── lora/ # LoRA utilities
|
| 77 |
+
│ ├── ops/ # Quantized operations
|
| 78 |
+
│ └── csrc/ # C++ CUDA kernels
|
| 79 |
+
├── app/ # Application interfaces
|
| 80 |
+
│ └── kontext/ # Gradio web interface
|
| 81 |
+
├── examples/ # Example scripts
|
| 82 |
+
└── tests/ # Test scripts
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## Examples
|
| 86 |
+
|
| 87 |
+
See the `examples/` directory for various usage patterns:
|
| 88 |
+
|
| 89 |
+
- `flux.1-kontext-dev.py`: Basic usage example
|
| 90 |
+
- `flux.1-kontext-dev-teacache.py`: Using TeaCache for acceleration
|
| 91 |
+
- `flux.1-kontext-FALAI_lora.py`: LoRA fine-tuning example
|
| 92 |
+
|
| 93 |
+
## Requirements
|
| 94 |
+
|
| 95 |
+
- Python >= 3.10
|
| 96 |
+
- PyTorch >= 2.5
|
| 97 |
+
- CUDA-capable GPU (recommended)
|
| 98 |
+
- 8GB+ GPU memory (for INT4 quantization)
|
| 99 |
+
|
| 100 |
+
## License
|
| 101 |
+
|
| 102 |
+
See the main nunchaku project for license information.
|
app.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from diffusers import FluxKontextPipeline
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from utils import get_args
|
| 8 |
+
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
|
| 9 |
+
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import gradio as gr
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
MAX_SEED = 1000000000
|
| 16 |
+
|
| 17 |
+
args = get_args()
|
| 18 |
+
|
| 19 |
+
if args.precision == "bf16":
|
| 20 |
+
pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
|
| 21 |
+
pipeline = pipeline.to("cuda")
|
| 22 |
+
pipeline.precision = "bf16"
|
| 23 |
+
else:
|
| 24 |
+
assert args.precision in ["int4", "fp4"]
|
| 25 |
+
pipeline_init_kwargs = {}
|
| 26 |
+
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
|
| 27 |
+
f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
|
| 28 |
+
)
|
| 29 |
+
pipeline_init_kwargs["transformer"] = transformer
|
| 30 |
+
if args.use_qencoder:
|
| 31 |
+
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
|
| 32 |
+
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
|
| 33 |
+
)
|
| 34 |
+
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
|
| 35 |
+
|
| 36 |
+
pipeline = FluxKontextPipeline.from_pretrained(
|
| 37 |
+
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
|
| 38 |
+
)
|
| 39 |
+
pipeline = pipeline.to("cuda")
|
| 40 |
+
pipeline.precision = args.precision
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]:
|
| 44 |
+
img = image["composite"].convert("RGB")
|
| 45 |
+
|
| 46 |
+
start_time = time.time()
|
| 47 |
+
result_image = pipeline(
|
| 48 |
+
prompt=prompt,
|
| 49 |
+
image=img,
|
| 50 |
+
height=img.height,
|
| 51 |
+
width=img.width,
|
| 52 |
+
num_inference_steps=num_inference_steps,
|
| 53 |
+
guidance_scale=guidance_scale,
|
| 54 |
+
generator=torch.Generator().manual_seed(seed),
|
| 55 |
+
).images[0]
|
| 56 |
+
|
| 57 |
+
latency = time.time() - start_time
|
| 58 |
+
if latency < 1:
|
| 59 |
+
latency = latency * 1000
|
| 60 |
+
latency_str = f"{latency:.2f}ms"
|
| 61 |
+
else:
|
| 62 |
+
latency_str = f"{latency:.2f}s"
|
| 63 |
+
torch.cuda.empty_cache()
|
| 64 |
+
return result_image, latency_str
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
|
| 68 |
+
with open("assets/description.html", "r") as f:
|
| 69 |
+
DESCRIPTION = f.read()
|
| 70 |
+
# Get the GPU properties
|
| 71 |
+
if torch.cuda.device_count() > 0:
|
| 72 |
+
gpu_properties = torch.cuda.get_device_properties(0)
|
| 73 |
+
gpu_memory = gpu_properties.total_memory / (1024**3) # Convert to GiB
|
| 74 |
+
gpu_name = torch.cuda.get_device_name(0)
|
| 75 |
+
device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
|
| 76 |
+
else:
|
| 77 |
+
device_info = "Running on CPU 🥶 This demo does not work on CPU."
|
| 78 |
+
|
| 79 |
+
header_str = DESCRIPTION.format(precision=args.precision.upper(), device_info=device_info, count_info="")
|
| 80 |
+
header = gr.HTML(header_str)
|
| 81 |
+
|
| 82 |
+
with gr.Row(elem_id="main_row"):
|
| 83 |
+
with gr.Column(elem_id="column_input"):
|
| 84 |
+
gr.Markdown("## INPUT", elem_id="input_header")
|
| 85 |
+
with gr.Group():
|
| 86 |
+
canvas = gr.ImageEditor(
|
| 87 |
+
height=640,
|
| 88 |
+
image_mode="RGB",
|
| 89 |
+
sources=["upload", "clipboard"],
|
| 90 |
+
type="pil",
|
| 91 |
+
label="Input",
|
| 92 |
+
show_label=False,
|
| 93 |
+
show_download_button=True,
|
| 94 |
+
interactive=True,
|
| 95 |
+
transforms=[],
|
| 96 |
+
canvas_size=(1024, 1024),
|
| 97 |
+
scale=1,
|
| 98 |
+
format="png",
|
| 99 |
+
layers=False,
|
| 100 |
+
)
|
| 101 |
+
with gr.Row():
|
| 102 |
+
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
|
| 103 |
+
run_button = gr.Button("Run", scale=1, elem_id="run_button")
|
| 104 |
+
|
| 105 |
+
with gr.Row():
|
| 106 |
+
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
|
| 107 |
+
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
|
| 108 |
+
with gr.Accordion("Advanced options", open=False):
|
| 109 |
+
with gr.Group():
|
| 110 |
+
num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
|
| 111 |
+
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)
|
| 112 |
+
|
| 113 |
+
with gr.Column(elem_id="column_output"):
|
| 114 |
+
gr.Markdown("## OUTPUT", elem_id="output_header")
|
| 115 |
+
with gr.Group():
|
| 116 |
+
result = gr.Image(
|
| 117 |
+
format="png",
|
| 118 |
+
height=640,
|
| 119 |
+
image_mode="RGB",
|
| 120 |
+
type="pil",
|
| 121 |
+
label="Result",
|
| 122 |
+
show_label=False,
|
| 123 |
+
show_download_button=True,
|
| 124 |
+
interactive=False,
|
| 125 |
+
elem_id="output_image",
|
| 126 |
+
)
|
| 127 |
+
latency_result = gr.Text(label="Inference Latency", show_label=True)
|
| 128 |
+
|
| 129 |
+
gr.Markdown("### Instructions")
|
| 130 |
+
gr.Markdown("**1**. Enter a text prompt")
|
| 131 |
+
gr.Markdown("**2**. Upload an image")
|
| 132 |
+
gr.Markdown("**3**. Try different seeds to generate different results")
|
| 133 |
+
|
| 134 |
+
run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
|
| 135 |
+
run_outputs = [result, latency_result]
|
| 136 |
+
|
| 137 |
+
randomize_seed.click(
|
| 138 |
+
lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
|
| 139 |
+
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
|
| 140 |
+
|
| 141 |
+
gr.on(
|
| 142 |
+
triggers=[prompt.submit, run_button.click],
|
| 143 |
+
fn=run,
|
| 144 |
+
inputs=run_inputs,
|
| 145 |
+
outputs=run_outputs,
|
| 146 |
+
api_name=False,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
|
| 150 |
+
demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
|
app/kontext/README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Nunchaku INT4 FLUX.1 Kontext Demo
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
|
| 5 |
+
This interactive Gradio application allows you to edit an image with natural language. Simply run:
|
| 6 |
+
|
| 7 |
+
```shell
|
| 8 |
+
python run_gradio.py
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
- To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
|
| 12 |
+
- By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
|
app/kontext/assets/description.html
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
| 2 |
+
<div>
|
| 3 |
+
<!-- Logo Row -->
|
| 4 |
+
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
|
| 5 |
+
<a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
|
| 6 |
+
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
|
| 7 |
+
alt="nunchaku logo" style="height: 150px; width: auto;" />
|
| 8 |
+
</a>
|
| 9 |
+
<a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
|
| 10 |
+
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
|
| 11 |
+
alt="svdquant logo" style="height: 40px; width: auto;" />
|
| 12 |
+
</a>
|
| 13 |
+
</div>
|
| 14 |
+
<h1 style="margin-top: 0;">{precision} FLUX.1-Kontext-dev Demo</h1>
|
| 15 |
+
|
| 16 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
| 17 |
+
{device_info}
|
| 18 |
+
</div>
|
| 19 |
+
{count_info}
|
| 20 |
+
</div>
|
| 21 |
+
</div>
|
app/kontext/assets/style.css
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
|
| 2 |
+
|
| 3 |
+
.gradio-container {
|
| 4 |
+
max-width: 1200px !important;
|
| 5 |
+
margin: auto; /* Centers the element horizontally */
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
h1 {
|
| 9 |
+
text-align: center
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
.wrap.svelte-p4aq0j.svelte-p4aq0j {
|
| 13 |
+
display: none;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
#column_input, #column_output {
|
| 17 |
+
width: 500px;
|
| 18 |
+
display: flex;
|
| 19 |
+
align-items: center;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
#input_header, #output_header {
|
| 23 |
+
display: flex;
|
| 24 |
+
justify-content: center;
|
| 25 |
+
align-items: center;
|
| 26 |
+
width: 400px;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
#accessibility {
|
| 30 |
+
text-align: center; /* Center-aligns the text */
|
| 31 |
+
margin: auto; /* Centers the element horizontally */
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
#random_seed {
|
| 35 |
+
height: 71px;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
#run_button {
|
| 39 |
+
height: 87px;
|
| 40 |
+
}
|
app/kontext/run_gradio.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from diffusers import FluxKontextPipeline
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from utils import get_args
|
| 8 |
+
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
|
| 9 |
+
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import gradio as gr
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
MAX_SEED = 1000000000
|
| 16 |
+
|
| 17 |
+
args = get_args()
|
| 18 |
+
|
| 19 |
+
if args.precision == "bf16":
|
| 20 |
+
pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
|
| 21 |
+
pipeline = pipeline.to("cuda")
|
| 22 |
+
pipeline.precision = "bf16"
|
| 23 |
+
else:
|
| 24 |
+
assert args.precision in ["int4", "fp4"]
|
| 25 |
+
pipeline_init_kwargs = {}
|
| 26 |
+
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
|
| 27 |
+
f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
|
| 28 |
+
)
|
| 29 |
+
pipeline_init_kwargs["transformer"] = transformer
|
| 30 |
+
if args.use_qencoder:
|
| 31 |
+
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
|
| 32 |
+
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
|
| 33 |
+
)
|
| 34 |
+
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
|
| 35 |
+
|
| 36 |
+
pipeline = FluxKontextPipeline.from_pretrained(
|
| 37 |
+
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
|
| 38 |
+
)
|
| 39 |
+
pipeline = pipeline.to("cuda")
|
| 40 |
+
pipeline.precision = args.precision
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image, str]:
|
| 44 |
+
img = image["composite"].convert("RGB")
|
| 45 |
+
|
| 46 |
+
start_time = time.time()
|
| 47 |
+
result_image = pipeline(
|
| 48 |
+
prompt=prompt,
|
| 49 |
+
image=img,
|
| 50 |
+
height=img.height,
|
| 51 |
+
width=img.width,
|
| 52 |
+
num_inference_steps=num_inference_steps,
|
| 53 |
+
guidance_scale=guidance_scale,
|
| 54 |
+
generator=torch.Generator().manual_seed(seed),
|
| 55 |
+
).images[0]
|
| 56 |
+
|
| 57 |
+
latency = time.time() - start_time
|
| 58 |
+
if latency < 1:
|
| 59 |
+
latency = latency * 1000
|
| 60 |
+
latency_str = f"{latency:.2f}ms"
|
| 61 |
+
else:
|
| 62 |
+
latency_str = f"{latency:.2f}s"
|
| 63 |
+
torch.cuda.empty_cache()
|
| 64 |
+
return result_image, latency_str
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
|
| 68 |
+
with open("assets/description.html", "r") as f:
|
| 69 |
+
DESCRIPTION = f.read()
|
| 70 |
+
# Get the GPU properties
|
| 71 |
+
if torch.cuda.device_count() > 0:
|
| 72 |
+
gpu_properties = torch.cuda.get_device_properties(0)
|
| 73 |
+
gpu_memory = gpu_properties.total_memory / (1024**3) # Convert to GiB
|
| 74 |
+
gpu_name = torch.cuda.get_device_name(0)
|
| 75 |
+
device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
|
| 76 |
+
else:
|
| 77 |
+
device_info = "Running on CPU 🥶 This demo does not work on CPU."
|
| 78 |
+
|
| 79 |
+
header_str = DESCRIPTION.format(precision=args.precision.upper(), device_info=device_info, count_info="")
|
| 80 |
+
header = gr.HTML(header_str)
|
| 81 |
+
|
| 82 |
+
with gr.Row(elem_id="main_row"):
|
| 83 |
+
with gr.Column(elem_id="column_input"):
|
| 84 |
+
gr.Markdown("## INPUT", elem_id="input_header")
|
| 85 |
+
with gr.Group():
|
| 86 |
+
canvas = gr.ImageEditor(
|
| 87 |
+
height=640,
|
| 88 |
+
image_mode="RGB",
|
| 89 |
+
sources=["upload", "clipboard"],
|
| 90 |
+
type="pil",
|
| 91 |
+
label="Input",
|
| 92 |
+
show_label=False,
|
| 93 |
+
show_download_button=True,
|
| 94 |
+
interactive=True,
|
| 95 |
+
transforms=[],
|
| 96 |
+
canvas_size=(1024, 1024),
|
| 97 |
+
scale=1,
|
| 98 |
+
format="png",
|
| 99 |
+
layers=False,
|
| 100 |
+
)
|
| 101 |
+
with gr.Row():
|
| 102 |
+
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
|
| 103 |
+
run_button = gr.Button("Run", scale=1, elem_id="run_button")
|
| 104 |
+
|
| 105 |
+
with gr.Row():
|
| 106 |
+
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
|
| 107 |
+
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
|
| 108 |
+
with gr.Accordion("Advanced options", open=False):
|
| 109 |
+
with gr.Group():
|
| 110 |
+
num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
|
| 111 |
+
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)
|
| 112 |
+
|
| 113 |
+
with gr.Column(elem_id="column_output"):
|
| 114 |
+
gr.Markdown("## OUTPUT", elem_id="output_header")
|
| 115 |
+
with gr.Group():
|
| 116 |
+
result = gr.Image(
|
| 117 |
+
format="png",
|
| 118 |
+
height=640,
|
| 119 |
+
image_mode="RGB",
|
| 120 |
+
type="pil",
|
| 121 |
+
label="Result",
|
| 122 |
+
show_label=False,
|
| 123 |
+
show_download_button=True,
|
| 124 |
+
interactive=False,
|
| 125 |
+
elem_id="output_image",
|
| 126 |
+
)
|
| 127 |
+
latency_result = gr.Text(label="Inference Latency", show_label=True)
|
| 128 |
+
|
| 129 |
+
gr.Markdown("### Instructions")
|
| 130 |
+
gr.Markdown("**1**. Enter a text prompt")
|
| 131 |
+
gr.Markdown("**2**. Upload an image")
|
| 132 |
+
gr.Markdown("**3**. Try different seeds to generate different results")
|
| 133 |
+
|
| 134 |
+
run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
|
| 135 |
+
run_outputs = [result, latency_result]
|
| 136 |
+
|
| 137 |
+
randomize_seed.click(
|
| 138 |
+
lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
|
| 139 |
+
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
|
| 140 |
+
|
| 141 |
+
gr.on(
|
| 142 |
+
triggers=[prompt.submit, run_button.click],
|
| 143 |
+
fn=run,
|
| 144 |
+
inputs=run_inputs,
|
| 145 |
+
outputs=run_outputs,
|
| 146 |
+
api_name=False,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
|
| 150 |
+
demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
|
app/kontext/utils.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_args() -> argparse.Namespace:
|
| 5 |
+
parser = argparse.ArgumentParser()
|
| 6 |
+
parser.add_argument(
|
| 7 |
+
"-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use"
|
| 8 |
+
)
|
| 9 |
+
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
|
| 10 |
+
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
|
| 11 |
+
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
|
| 12 |
+
parser.add_argument("--gradio-root-path", type=str, default="")
|
| 13 |
+
args = parser.parse_args()
|
| 14 |
+
return args
|
examples/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# FLUX-Kontext examples
|
examples/flux.1-kontext-FALAI_lora.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers import FluxKontextPipeline
|
| 3 |
+
from diffusers.utils import load_image
|
| 4 |
+
|
| 5 |
+
from nunchaku import NunchakuFluxTransformer2dModel
|
| 6 |
+
from nunchaku.utils import get_precision
|
| 7 |
+
|
| 8 |
+
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
|
| 9 |
+
f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
pipeline = FluxKontextPipeline.from_pretrained(
|
| 13 |
+
"black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
|
| 14 |
+
).to("cuda")
|
| 15 |
+
|
| 16 |
+
image = load_image(
|
| 17 |
+
"https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
|
| 18 |
+
).convert("RGB")
|
| 19 |
+
|
| 20 |
+
### LoRA Related Code ###
|
| 21 |
+
transformer.update_lora_params(
|
| 22 |
+
"nunchaku-tech/nunchaku-test-models/relight-kontext-lora-single-caption_comfy.safetensors"
|
| 23 |
+
# "linoyts/relight-kontext-lora-single-caption/relight-kontext-lora-single-caption.safetensors"
|
| 24 |
+
) # Path to your LoRA safetensors, can also be a remote HuggingFace path
|
| 25 |
+
transformer.set_lora_strength(1) # Your LoRA strength here
|
| 26 |
+
### End of LoRA Related Code ###
|
| 27 |
+
|
| 28 |
+
prompt = "neon light, city"
|
| 29 |
+
image = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(23), guidance_scale=2.5).images[0]
|
| 30 |
+
image.save("flux-kontext-dev.png")
|
examples/flux.1-kontext-dev-teacache.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from diffusers import FluxKontextPipeline
|
| 5 |
+
from diffusers.utils import load_image
|
| 6 |
+
|
| 7 |
+
from nunchaku import NunchakuFluxTransformer2dModel
|
| 8 |
+
from nunchaku.caching.teacache import TeaCache
|
| 9 |
+
from nunchaku.utils import get_precision
|
| 10 |
+
|
| 11 |
+
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
|
| 12 |
+
f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
pipeline = FluxKontextPipeline.from_pretrained(
|
| 16 |
+
"black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
|
| 17 |
+
).to("cuda")
|
| 18 |
+
|
| 19 |
+
image = load_image(
|
| 20 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
|
| 21 |
+
).convert("RGB")
|
| 22 |
+
|
| 23 |
+
prompt = "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors"
|
| 24 |
+
|
| 25 |
+
start_time = time.time()
|
| 26 |
+
with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=True, model_name="flux-kontext"):
|
| 27 |
+
image = pipeline(image=image, prompt=prompt, guidance_scale=2.5).images[0]
|
| 28 |
+
end_time = time.time()
|
| 29 |
+
print(f"Time taken: {(end_time - start_time)} seconds")
|
| 30 |
+
image.save(f"flux-kontext-dev-{get_precision()}-tc.png")
|
examples/flux.1-kontext-dev.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers import FluxKontextPipeline
|
| 3 |
+
from diffusers.utils import load_image
|
| 4 |
+
|
| 5 |
+
from nunchaku import NunchakuFluxTransformer2dModel
|
| 6 |
+
from nunchaku.utils import get_precision
|
| 7 |
+
|
| 8 |
+
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
|
| 9 |
+
f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
pipeline = FluxKontextPipeline.from_pretrained(
|
| 13 |
+
"black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
|
| 14 |
+
).to("cuda")
|
| 15 |
+
|
| 16 |
+
image = load_image(
|
| 17 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
|
| 18 |
+
).convert("RGB")
|
| 19 |
+
|
| 20 |
+
prompt = "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors"
|
| 21 |
+
image = pipeline(image=image, prompt=prompt, guidance_scale=2.5).images[0]
|
| 22 |
+
image.save("flux-kontext-dev.png")
|
nunchaku/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .models import (
|
| 2 |
+
NunchakuFluxTransformer2dModel,
|
| 3 |
+
NunchakuT5EncoderModel,
|
| 4 |
+
)
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"NunchakuFluxTransformer2dModel",
|
| 8 |
+
"NunchakuT5EncoderModel",
|
| 9 |
+
]
|
nunchaku/__version__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__version__ = "1.0.0-flux-kontext"
|
nunchaku/csrc/flux.h
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "interop/torch.h"
|
| 4 |
+
#include "FluxModel.h"
|
| 5 |
+
#include "Serialization.h"
|
| 6 |
+
#include "debug.h"
|
| 7 |
+
#include "Linear.h"
|
| 8 |
+
#include "module.h"
|
| 9 |
+
|
| 10 |
+
// Pybind-facing wrapper around the quantized FluxModel transformer.
// Handles torch <-> nunchaku tensor conversion, device selection, and exposes
// per-layer entry points used by caching schemes (e.g. TeaCache) and adapters.
class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
public:
    // Construct the underlying FluxModel on the given CUDA device.
    //   use_fp4:  use FP4 weight quantization (otherwise INT4).
    //   offload:  enable lazy per-layer weight loading/releasing (CPU offload).
    //   bf16:     run activations in BF16; otherwise FP16.
    void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
        spdlog::info("Initializing QuantizedFluxModel on device {}", deviceId);
        if (!bf16) {
            spdlog::info("Use FP16 model");
        }
        if (offload) {
            spdlog::info("Layer offloading enabled");
        }
        ModuleWrapper::init(deviceId);

        CUDADeviceContext ctx(this->deviceId);
        net = std::make_unique<FluxModel>(
            use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
    }

    // True when the loaded model runs in BF16.
    bool isBF16() {
        checkModel();
        return net->dtype == Tensor::BF16;
    }

    // Python callback invoked on the transformer residual. Stored here so the
    // C++ lambda handed to the model keeps a live reference to the Python object.
    pybind11::function residual_callback;

    // Install the residual callback, or clear it when `callback` is empty/None.
    // The installed lambda converts tensors across the torch <-> nunchaku
    // boundary before and after invoking the Python callable.
    void set_residual_callback(pybind11::function callback) {
        pybind11::gil_scoped_acquire gil; // manipulating Python objects requires the GIL
        if (!callback || callback.is_none()) {
            residual_callback = pybind11::function();
            if (net) {
                net->set_residual_callback(nullptr);
            }
            return;
        }
        residual_callback = std::move(callback);
        if (net) {
            pybind11::object cb = residual_callback;
            net->set_residual_callback([cb](const Tensor &x) -> Tensor {
                torch::Tensor torch_x = to_torch(x);
                pybind11::object result = cb(torch_x);
                torch::Tensor torch_y = result.cast<torch::Tensor>();
                Tensor y = from_torch(torch_y);
                return y;
            });
        }
        // NOTE(review): if net is not constructed yet, the callback is only
        // stored here and is NOT installed on a model created later by init().
    }

    // Full transformer forward pass. All inputs are made contiguous before
    // crossing into the nunchaku kernels; optional ControlNet residuals are
    // forwarded when provided. Returns the output hidden states.
    torch::Tensor forward(torch::Tensor hidden_states,
                          torch::Tensor encoder_hidden_states,
                          torch::Tensor temb,
                          torch::Tensor rotary_emb_img,
                          torch::Tensor rotary_emb_context,
                          torch::Tensor rotary_emb_single,
                          std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
                          std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
                          bool skip_first_layer = false) {
        checkModel();
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedFluxModel forward");

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        temb = temb.contiguous();
        rotary_emb_img = rotary_emb_img.contiguous();
        rotary_emb_context = rotary_emb_context.contiguous();
        rotary_emb_single = rotary_emb_single.contiguous();

        Tensor result = net->forward(
            from_torch(hidden_states),
            from_torch(encoder_hidden_states),
            from_torch(temb),
            from_torch(rotary_emb_img),
            from_torch(rotary_emb_context),
            from_torch(rotary_emb_single),
            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
            controlnet_single_block_samples.has_value()
                ? from_torch(controlnet_single_block_samples.value().contiguous())
                : Tensor{},
            skip_first_layer);

        torch::Tensor output = to_torch(result);
        Tensor::synchronizeDevice();

        return output;
    }

    // Run a single double-stream transformer block (index `idx`).
    // Returns the updated (hidden_states, encoder_hidden_states) pair.
    std::tuple<torch::Tensor, torch::Tensor>
    forward_layer(int64_t idx,
                  torch::Tensor hidden_states,
                  torch::Tensor encoder_hidden_states,
                  torch::Tensor temb,
                  torch::Tensor rotary_emb_img,
                  torch::Tensor rotary_emb_context,
                  std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
                  std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt) {
        checkModel(); // fail with a clear error instead of dereferencing a null model
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedFluxModel forward_layer {}", idx);

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        temb = temb.contiguous();
        rotary_emb_img = rotary_emb_img.contiguous();
        rotary_emb_context = rotary_emb_context.contiguous();

        auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
            idx,
            from_torch(hidden_states),
            from_torch(encoder_hidden_states),
            from_torch(temb),
            from_torch(rotary_emb_img),
            from_torch(rotary_emb_context),
            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
            controlnet_single_block_samples.has_value()
                ? from_torch(controlnet_single_block_samples.value().contiguous())
                : Tensor{});

        hidden_states = to_torch(hidden_states_);
        encoder_hidden_states = to_torch(encoder_hidden_states_);
        Tensor::synchronizeDevice();

        return {hidden_states, encoder_hidden_states};
    }

    // Run a single single-stream transformer block (index `idx`).
    // When offloading is enabled, the block's weights are loaded before and
    // released after the call.
    torch::Tensor forward_single_layer(int64_t idx,
                                       torch::Tensor hidden_states,
                                       torch::Tensor temb,
                                       torch::Tensor rotary_emb_single) {
        checkModel();
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);

        hidden_states = hidden_states.contiguous();
        temb = temb.contiguous();
        rotary_emb_single = rotary_emb_single.contiguous();

        if (net->isOffloadEnabled()) {
            net->single_transformer_blocks.at(idx)->loadLazyParams();
        }

        Tensor result = net->single_transformer_blocks.at(idx)->forward(
            from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single));

        if (net->isOffloadEnabled()) {
            net->single_transformer_blocks.at(idx)->releaseLazyParams();
        }

        hidden_states = to_torch(result);
        Tensor::synchronizeDevice();

        return hidden_states;
    }

    // expose the norm1 forward method of the transformer blocks
    // this is used by TeaCache to get the norm1 output
    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
    norm_one_forward(int64_t idx, torch::Tensor hidden_states, torch::Tensor temb) {
        checkModel();
        CUDADeviceContext ctx(deviceId); // launch on the model's device, like the other forward paths
        AdaLayerNormZero::Output result =
            net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
        return {to_torch(result.x),
                to_torch(result.gate_msa),
                to_torch(result.shift_mlp),
                to_torch(result.scale_mlp),
                to_torch(result.gate_mlp)};
    }

    // must be called after loading lora
    // skip specific ranks in W4A4 layers
    void setLoraScale(int skipRanks, float scale) {
        if (skipRanks % 16 != 0) {
            throw std::invalid_argument("skipRanks must be multiples of 16");
        }

        CUDADeviceContext ctx(deviceId);

        spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks);

        net->traverse([&](Module *module) {
            if (auto *m = dynamic_cast<GEMV_AWQ *>(module)) {
                m->lora_scale = scale;
            } else if (auto *m = dynamic_cast<GEMM_W4A4 *>(module)) {
                // The first skipRanks/16 rank groups keep scale 1.0 (the
                // "unquantized branch" ranks); the rest get the LoRA scale.
                for (int i = 0; i < skipRanks / 16; i++) {
                    m->lora_scales[i] = 1.0f;
                }
                for (int i = skipRanks / 16; i < (int)m->lora_scales.size(); i++) {
                    m->lora_scales[i] = scale;
                }
            }
        });
    }

    // Select the attention kernel: "flashattn2" (also the default) or
    // "nunchaku-fp16". Throws on any other name.
    void setAttentionImpl(std::string name) {
        checkModel();
        if (name.empty() || name == "default") {
            name = "flashattn2";
        }

        spdlog::info("Set attention implementation to {}", name);

        if (name == "flashattn2") {
            net->setAttentionImpl(AttentionImpl::FlashAttention2);
        } else if (name == "nunchaku-fp16") {
            net->setAttentionImpl(AttentionImpl::NunchakuFP16);
        } else {
            throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
        }
    }

    // Like forward_layer, but additionally returns the IP-adapter query tensor
    // produced by the block.
    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
    forward_layer_ip_adapter(int64_t idx,
                             torch::Tensor hidden_states,
                             torch::Tensor encoder_hidden_states,
                             torch::Tensor temb,
                             torch::Tensor rotary_emb_img,
                             torch::Tensor rotary_emb_context,
                             std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
                             std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt) {
        checkModel();
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedFluxModel forward_layer {}", idx);

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        temb = temb.contiguous();
        rotary_emb_img = rotary_emb_img.contiguous();
        rotary_emb_context = rotary_emb_context.contiguous();

        auto &&[hidden_states_, encoder_hidden_states_, ip_query_] = net->forward_ip_adapter(
            idx,
            from_torch(hidden_states),
            from_torch(encoder_hidden_states),
            from_torch(temb),
            from_torch(rotary_emb_img),
            from_torch(rotary_emb_context),
            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
            controlnet_single_block_samples.has_value()
                ? from_torch(controlnet_single_block_samples.value().contiguous())
                : Tensor{});

        hidden_states = to_torch(hidden_states_);
        encoder_hidden_states = to_torch(encoder_hidden_states_);
        torch::Tensor ip_query = to_torch(ip_query_);
        Tensor::synchronizeDevice();

        return {hidden_states, encoder_hidden_states, ip_query};
    }
};
|
nunchaku/csrc/gemm.h
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "interop/torch.h"
|
| 4 |
+
#include "Serialization.h"
|
| 5 |
+
#include "Linear.h"
|
| 6 |
+
#include "debug.h"
|
| 7 |
+
#include "module.h"
|
| 8 |
+
|
| 9 |
+
// Standalone pybind wrapper around a single W4A4 GEMM layer, used mainly for
// testing and debugging the quantized kernels outside a full model.
class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
public:
    // Construct a GEMM_W4A4 layer on the given CUDA device.
    void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
        spdlog::info("Initializing QuantizedGEMM");

        // Record the device id so inherited reset()/load() run under the right
        // CUDA device context (consistent with QuantizedFluxModel::init).
        ModuleWrapper::init(deviceId);

        size_t val = 0;
        // Enlarge the per-thread stack; the quantized kernels need more than the default.
        checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
        checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
        spdlog::debug("Stack={}", val);

        net = std::make_unique<GEMM_W4A4>((int)in_features,
                                          (int)out_features,
                                          bias,
                                          use_fp4,
                                          bf16 ? Tensor::BF16 : Tensor::FP16,
                                          Device::cuda((int)deviceId));
    }

    // Run the quantized GEMM on `x` and return the result as a torch tensor.
    torch::Tensor forward(torch::Tensor x) {
        checkModel();

        // Log through spdlog instead of printing to stderr on every call.
        spdlog::debug("QuantizedGEMM forward");

        x = x.contiguous();

        Tensor result = net->forward(from_torch(x));

        torch::Tensor output = to_torch(result);
        Tensor::synchronizeDevice();

        return output;
    }

    // Debug helper: format the first 256 BF16 values of `x` as text.
    // Assumes `x` lives in host-accessible memory.
    std::string dumpTensorBF16(Tensor x) {
        std::stringstream ss;
        for (int i = 0; i < 256; i++) {
            ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__nv_bfloat16>()[i]));
        }
        ss << std::endl;
        return ss.str();
    }

    // Debug helper: decode and format a packed INT4 activation tensor,
    // following the [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE]
    // packing used by the W4A4 kernels.
    std::string dumpTensorINT4(Tensor x) {
        using spdlog::fmt_lib::format;

        const int M = x.shape[0];
        const int K = x.shape[1] * 2; // two 4-bit values per stored byte

        assert(x.dtype() == Tensor::INT8);

        // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)

        constexpr int BLOCK_M = 256;
        constexpr int WARP_K = 64;
        constexpr int NUM_WARPS = 8;
        constexpr int WARP_M_TILES = 2;
        constexpr int WARP_SIZE = 32;

        std::stringstream ss;
        for (int bm = 0; bm < M / BLOCK_M; bm++) {
            for (int bn = 0; bn < K / WARP_K; bn++) {
                for (int warpId = 0; warpId < NUM_WARPS; warpId++) {
                    ss << format("[bm={},bn={},warp={}] ", bm, bn, warpId);
                    const int offset = ((bm * (K / WARP_K) + bn) * NUM_WARPS + warpId) * WARP_M_TILES * WARP_SIZE * 4;

                    for (int i = 0; i < 16; i++) {
                        assert(static_cast<size_t>(offset + i) < x.numel() / 4);
                        uint32_t val = x.data_ptr<uint32_t>()[offset + i];
                        ss << "{";
                        for (int j = 0; j < 8; j++) {
                            // Sign-extend each 4-bit nibble to a signed int.
                            int i4val = (val >> (j * 4)) & 0xf;
                            if (i4val & 0x8) {
                                i4val = -((~i4val & 0x7) + 1);
                            }
                            ss << format("{} ", i4val);
                        }
                        ss << format("}} {:x} ", val);
                    }
                    ss << std::endl;
                }
            }
        }

        ss << std::endl;
        return ss.str();
    }

    // Debug entry point: quantize `x` (optionally fusing GLU) and log the
    // packed activations and scales.
    void quantize(torch::Tensor x, bool fuse_glu) {
        checkModel();

        spdlog::debug("QuantizedGEMM quantize");

        x = x.contiguous();

        auto qout = net->quantize(from_torch(x), fuse_glu);

        // Copy to host so the dump helpers can read the data.
        Tensor act = qout.act.copy(Device::cpu());
        Tensor ascales = qout.ascales.copy(Device::cpu());
        Tensor lora_act = qout.lora_act.copy(Device::cpu());

        Tensor::synchronizeDevice();

        spdlog::debug("act = {}", dumpTensorINT4(act));
        spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
    }
};
|
nunchaku/csrc/gemm88.h
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "interop/torch.h"
|
| 4 |
+
#include "Serialization.h"
|
| 5 |
+
#include "Linear.h"
|
| 6 |
+
#include "debug.h"
|
| 7 |
+
#include "module.h"
|
| 8 |
+
|
| 9 |
+
// Standalone pybind wrapper around a single W8A8 GEMM layer (test/debug use).
class QuantizedGEMM88 : public ModuleWrapper<GEMM_W8A8> {
public:
    // Construct a GEMM_W8A8 layer on the given CUDA device.
    void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
        spdlog::info("Initializing QuantizedGEMM88");

        // Record the device id so inherited reset()/load() run under the right
        // CUDA device context (consistent with QuantizedFluxModel::init).
        ModuleWrapper::init(deviceId);

        size_t val = 0;
        // Enlarge the per-thread stack; the quantized kernels need more than the default.
        checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
        checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
        spdlog::debug("Stack={}", val);

        net = std::make_unique<GEMM_W8A8>(
            (int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
    }

    // Run the quantized GEMM on `x` and return the result as a torch tensor.
    torch::Tensor forward(torch::Tensor x) {
        checkModel();

        // Log through spdlog instead of printing to stderr on every call.
        spdlog::debug("QuantizedGEMM88 forward");

        x = x.contiguous();

        Tensor result = net->forward(from_torch(x));

        torch::Tensor output = to_torch(result);
        Tensor::synchronizeDevice();

        return output;
    }
};
|
nunchaku/csrc/module.h
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "interop/torch.h"
|
| 4 |
+
#include "Serialization.h"
|
| 5 |
+
#include "Module.h"
|
| 6 |
+
#include "debug.h"
|
| 7 |
+
#include "utils.h"
|
| 8 |
+
|
| 9 |
+
// Common pybind-facing wrapper around a nunchaku network type `M`.
// Owns the model instance, remembers the CUDA device it lives on, and holds an
// optional debug context used to capture intermediate tensors.
template<typename M>
class ModuleWrapper {
public:
    // Record the CUDA device id; the concrete subclass is responsible for
    // constructing `net` on that device.
    void init(int deviceId) {
        this->deviceId = deviceId;
    }
    // Destroy the model and debug context, then return freed memory to the
    // allocator (trim_memory) so it can be reused by other allocations.
    void reset() {
        CUDADeviceContext ctx(this->deviceId);

        debugContext.reset();
        net.reset();
        Tensor::synchronizeDevice();

        nunchaku::utils::trim_memory();
        Tensor::synchronizeDevice();
    }

    // Load weights from a safetensors file at `path`.
    // `partial` allows loading only a subset of the parameters.
    void load(std::string path, bool partial = false) {
        checkModel();
        CUDADeviceContext ctx(this->deviceId);

        spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);

        std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);
        net->loadParams(*provider, partial);
        Tensor::synchronizeDevice();

        spdlog::info("Done.");
    }

    // Load weights from an in-memory dict of torch tensors (state_dict style).
    // `partial` allows loading only a subset of the parameters.
    void loadDict(std::map<std::string, torch::Tensor> dict, bool partial = false) {
        checkModel();
        CUDADeviceContext ctx(this->deviceId);

        spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading")

        std::shared_ptr<TensorsProviderTorch> provider = std::make_shared<TensorsProviderTorch>(std::move(dict));
        net->loadParams(*provider, partial);
        Tensor::synchronizeDevice();

        spdlog::info("Done.");
    }

    // Begin capturing debug tensors; fetch them via getDebugResults().
    void startDebug() {
        debugContext = std::make_unique<DebugContext>();
    }
    // Stop capturing and discard any held debug context.
    void stopDebug() {
        debugContext.reset();
    }

    // Copy all captured debug tensors out as torch tensors keyed by name.
    // Returns an empty map when debugging is not active.
    auto getDebugResults() {
        CUDADeviceContext ctx(this->deviceId);

        std::map<std::string, torch::Tensor> result;

        if (debugContext) {
            for (auto &&[key, value] : debugContext->tensors) {
                result[key] = to_torch(value);
            }
        }

        return result;
    }

protected:
    // Throw a clear error when a method is used before the model is constructed.
    void checkModel() {
        if (!net) {
            throw std::runtime_error("Model not initialized");
        }
    }

protected:
    std::unique_ptr<M> net;                     // the wrapped network instance
    std::unique_ptr<DebugContext> debugContext; // non-null only between startDebug()/stopDebug()

    int deviceId = -1; // CUDA device ordinal; -1 until init() is called
};
|
nunchaku/csrc/ops.h
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "interop/torch.h"
|
| 4 |
+
#include "kernels/zgemm/zgemm.h"
|
| 5 |
+
#include "kernels/awq/gemv_awq.h"
|
| 6 |
+
#include "kernels/awq/gemm_awq.h"
|
| 7 |
+
|
| 8 |
+
namespace nunchaku::ops {
|
| 9 |
+
|
| 10 |
+
void gemm_w4a4(std::optional<torch::Tensor> act, // packed act [M, K / 2]
|
| 11 |
+
std::optional<torch::Tensor> wgt, // packed act [N, K / 2]
|
| 12 |
+
std::optional<torch::Tensor> out, // linear [M, N]
|
| 13 |
+
std::optional<torch::Tensor> qout, // packed act [M, N / 2]
|
| 14 |
+
std::optional<torch::Tensor> ascales, // packed as [K / 64, M]
|
| 15 |
+
std::optional<torch::Tensor> wscales, // packed ws [K / 64, N]
|
| 16 |
+
std::optional<torch::Tensor> oscales, // packed as [N / 64, M]
|
| 17 |
+
std::optional<torch::Tensor> poolout, // linear [M / PoolSize, N]
|
| 18 |
+
std::optional<torch::Tensor> lora_act_in, // packed lora_act [M, R]
|
| 19 |
+
std::optional<torch::Tensor> lora_up, // packed lora_wgt [N, R]
|
| 20 |
+
std::optional<torch::Tensor> lora_down, // packed lora_wgt [N, R]
|
| 21 |
+
std::optional<torch::Tensor> lora_act_out, // packed lora_act [M, R]
|
| 22 |
+
std::optional<torch::Tensor> norm_q, // linear [HEAD_DIM]
|
| 23 |
+
std::optional<torch::Tensor> norm_k, // linear [HEAD_DIM]
|
| 24 |
+
std::optional<torch::Tensor> rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
|
| 25 |
+
std::optional<torch::Tensor> bias, // packed ws [N]
|
| 26 |
+
std::optional<torch::Tensor> smooth_factor, // packed ws [N], for quantization of the next layer
|
| 27 |
+
std::optional<torch::Tensor> out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
|
| 28 |
+
std::optional<torch::Tensor> out_linearattn, // linear [B, (M), N / 3]
|
| 29 |
+
bool act_unsigned,
|
| 30 |
+
std::vector<float> lora_scales,
|
| 31 |
+
bool fuse_silu,
|
| 32 |
+
bool fp4,
|
| 33 |
+
float alpha,
|
| 34 |
+
std::optional<torch::Tensor> wcscales,
|
| 35 |
+
std::optional<torch::Tensor> out_q, // packed attention [B, H, M, D]
|
| 36 |
+
std::optional<torch::Tensor> out_k, // packed attention [B, H, M, D]
|
| 37 |
+
std::optional<torch::Tensor> out_v, // packed attention [B, H, M, D]
|
| 38 |
+
int attn_tokens) {
|
| 39 |
+
TorchOpContext ctx;
|
| 40 |
+
spdlog::trace("running gemm_w4a4: ");
|
| 41 |
+
|
| 42 |
+
auto getTensor = [](std::optional<torch::Tensor> &t) {
|
| 43 |
+
Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
|
| 44 |
+
if (ret.valid()) {
|
| 45 |
+
spdlog::trace(" {}", ret.shape.str());
|
| 46 |
+
} else {
|
| 47 |
+
spdlog::trace(" <invalid>");
|
| 48 |
+
}
|
| 49 |
+
return ret;
|
| 50 |
+
};
|
| 51 |
+
nunchaku::kernels::gemm_w4a4(getTensor(act),
|
| 52 |
+
getTensor(wgt),
|
| 53 |
+
getTensor(out),
|
| 54 |
+
getTensor(qout),
|
| 55 |
+
getTensor(ascales),
|
| 56 |
+
getTensor(wscales),
|
| 57 |
+
getTensor(oscales),
|
| 58 |
+
getTensor(poolout),
|
| 59 |
+
getTensor(lora_act_in),
|
| 60 |
+
getTensor(lora_up),
|
| 61 |
+
getTensor(lora_down),
|
| 62 |
+
getTensor(lora_act_out),
|
| 63 |
+
getTensor(norm_q),
|
| 64 |
+
getTensor(norm_k),
|
| 65 |
+
getTensor(rotary_emb),
|
| 66 |
+
getTensor(bias),
|
| 67 |
+
getTensor(smooth_factor),
|
| 68 |
+
getTensor(out_vk),
|
| 69 |
+
getTensor(out_linearattn),
|
| 70 |
+
act_unsigned,
|
| 71 |
+
lora_scales,
|
| 72 |
+
fuse_silu,
|
| 73 |
+
fp4,
|
| 74 |
+
alpha,
|
| 75 |
+
getTensor(wcscales),
|
| 76 |
+
getTensor(out_q),
|
| 77 |
+
getTensor(out_k),
|
| 78 |
+
getTensor(out_v),
|
| 79 |
+
attn_tokens);
|
| 80 |
+
// Tensor::synchronizeDevice();
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
void quantize_w4a4_act_fuse_lora(std::optional<torch::Tensor> input,
|
| 84 |
+
std::optional<torch::Tensor> output,
|
| 85 |
+
std::optional<torch::Tensor> oscales,
|
| 86 |
+
std::optional<torch::Tensor> lora_down,
|
| 87 |
+
std::optional<torch::Tensor> lora_act_out,
|
| 88 |
+
std::optional<torch::Tensor> smooth,
|
| 89 |
+
bool fuse_glu,
|
| 90 |
+
bool fp4) {
|
| 91 |
+
TorchOpContext ctx;
|
| 92 |
+
|
| 93 |
+
spdlog::trace("running quantize_w4a4_act_fuse_lora: ");
|
| 94 |
+
|
| 95 |
+
auto getTensor = [](std::optional<torch::Tensor> &t) {
|
| 96 |
+
Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
|
| 97 |
+
if (ret.valid()) {
|
| 98 |
+
spdlog::trace(" {}", ret.shape.str());
|
| 99 |
+
} else {
|
| 100 |
+
spdlog::trace(" <invalid>");
|
| 101 |
+
}
|
| 102 |
+
return ret;
|
| 103 |
+
};
|
| 104 |
+
nunchaku::kernels::quantize_w4a4_act_fuse_lora(getTensor(input),
|
| 105 |
+
getTensor(output),
|
| 106 |
+
getTensor(oscales),
|
| 107 |
+
getTensor(lora_down),
|
| 108 |
+
getTensor(lora_act_out),
|
| 109 |
+
getTensor(smooth),
|
| 110 |
+
fuse_glu,
|
| 111 |
+
fp4);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
void attention_fp16(torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
|
| 115 |
+
torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
|
| 116 |
+
torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
|
| 117 |
+
torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
|
| 118 |
+
float scale) {
|
| 119 |
+
TorchOpContext ctx;
|
| 120 |
+
nunchaku::kernels::attention_fp16(from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale);
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
torch::Tensor gemv_awq(torch::Tensor _in_feats,
|
| 124 |
+
torch::Tensor _kernel,
|
| 125 |
+
torch::Tensor _scaling_factors,
|
| 126 |
+
torch::Tensor _zeros,
|
| 127 |
+
int64_t m,
|
| 128 |
+
int64_t n,
|
| 129 |
+
int64_t k,
|
| 130 |
+
int64_t group_size) {
|
| 131 |
+
TorchOpContext ctx;
|
| 132 |
+
Tensor result = ::gemv_awq(from_torch(_in_feats.contiguous()),
|
| 133 |
+
from_torch(_kernel.contiguous()),
|
| 134 |
+
from_torch(_scaling_factors.contiguous()),
|
| 135 |
+
from_torch(_zeros.contiguous()),
|
| 136 |
+
(int)m,
|
| 137 |
+
(int)n,
|
| 138 |
+
(int)k,
|
| 139 |
+
(int)group_size);
|
| 140 |
+
|
| 141 |
+
torch::Tensor output = to_torch(result);
|
| 142 |
+
// Tensor::synchronizeDevice();
|
| 143 |
+
|
| 144 |
+
return output;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
torch::Tensor
|
| 148 |
+
gemm_awq(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros) {
|
| 149 |
+
Tensor result = ::awq_gemm_forward_cuda(from_torch(_in_feats.contiguous()),
|
| 150 |
+
from_torch(_kernel.contiguous()),
|
| 151 |
+
from_torch(_scaling_factors.contiguous()),
|
| 152 |
+
from_torch(_zeros.contiguous()));
|
| 153 |
+
|
| 154 |
+
TorchOpContext ctx;
|
| 155 |
+
// TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
|
| 156 |
+
torch::Tensor output = to_torch(result);
|
| 157 |
+
// Tensor::synchronizeDevice();
|
| 158 |
+
|
| 159 |
+
return output;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
void test_rmsnorm_rope(
|
| 163 |
+
torch::Tensor input, torch::Tensor output, torch::Tensor norm_q, torch::Tensor norm_k, torch::Tensor rotary_emb) {
|
| 164 |
+
nunchaku::kernels::test_rmsnorm_rope(
|
| 165 |
+
from_torch(input), from_torch(output), from_torch(norm_q), from_torch(norm_k), from_torch(rotary_emb));
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
void test_pack_qkv(torch::Tensor input, torch::Tensor out_q, torch::Tensor out_k, torch::Tensor out_v, int numTokens) {
|
| 169 |
+
nunchaku::kernels::test_pack_qkv(
|
| 170 |
+
from_torch(input), from_torch(out_q), from_torch(out_k), from_torch(out_v), numTokens);
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
}; // namespace nunchaku::ops
|
nunchaku/csrc/pybind.cpp
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "gemm.h"
|
| 2 |
+
#include "gemm88.h"
|
| 3 |
+
#include "flux.h"
|
| 4 |
+
#include "sana.h"
|
| 5 |
+
#include "ops.h"
|
| 6 |
+
#include "utils.h"
|
| 7 |
+
#include <torch/extension.h>
|
| 8 |
+
#include "interop/torch.h"
|
| 9 |
+
#include <pybind11/pybind11.h>
|
| 10 |
+
|
| 11 |
+
// Python bindings for the nunchaku CUDA extension.
// Exposes the quantized model wrappers (Flux, Sana, and two GEMM variants)
// plus low-level kernel entry points under the "ops" and "utils" submodules.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Quantized FLUX transformer: full-model and per-layer forward passes.
    py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
        .def(py::init<>())
        .def("init",
             &QuantizedFluxModel::init,
             py::arg("use_fp4"),
             py::arg("offload"),
             py::arg("bf16"),
             py::arg("deviceId"))
        // Accepts a Python callable or None; None clears the callback by
        // installing an empty pybind11::function.
        .def("set_residual_callback",
             [](QuantizedFluxModel &self, pybind11::object call_back) {
                 if (call_back.is_none()) {
                     self.set_residual_callback(pybind11::function());
                 } else {
                     self.set_residual_callback(call_back);
                 }
             })
        .def("reset", &QuantizedFluxModel::reset)
        // partial=true presumably loads only the weights present in the file —
        // confirm against ModuleWrapper::load.
        .def("load", &QuantizedFluxModel::load, py::arg("path"), py::arg("partial") = false)
        .def("loadDict", &QuantizedFluxModel::loadDict, py::arg("dict"), py::arg("partial") = false)
        .def("forward",
             &QuantizedFluxModel::forward,
             py::arg("hidden_states"),
             py::arg("encoder_hidden_states"),
             py::arg("temb"),
             py::arg("rotary_emb_img"),
             py::arg("rotary_emb_context"),
             py::arg("rotary_emb_single"),
             py::arg("controlnet_block_samples") = py::none(),
             py::arg("controlnet_single_block_samples") = py::none(),
             py::arg("skip_first_layer") = false)
        // Run a single joint transformer block (by index).
        .def("forward_layer",
             &QuantizedFluxModel::forward_layer,
             py::arg("idx"),
             py::arg("hidden_states"),
             py::arg("encoder_hidden_states"),
             py::arg("temb"),
             py::arg("rotary_emb_img"),
             py::arg("rotary_emb_context"),
             py::arg("controlnet_block_samples") = py::none(),
             py::arg("controlnet_single_block_samples") = py::none())
        // Same as forward_layer but with IP-Adapter support.
        .def("forward_layer_ip_adapter",
             &QuantizedFluxModel::forward_layer_ip_adapter,
             py::arg("idx"),
             py::arg("hidden_states"),
             py::arg("encoder_hidden_states"),
             py::arg("temb"),
             py::arg("rotary_emb_img"),
             py::arg("rotary_emb_context"),
             py::arg("controlnet_block_samples") = py::none(),
             py::arg("controlnet_single_block_samples") = py::none())
        .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
        .def("norm_one_forward", &QuantizedFluxModel::norm_one_forward)
        // Debug hooks: record intermediate tensors between startDebug/stopDebug.
        .def("startDebug", &QuantizedFluxModel::startDebug)
        .def("stopDebug", &QuantizedFluxModel::stopDebug)
        .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
        .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
        .def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
        .def("isBF16", &QuantizedFluxModel::isBF16);
    // Quantized Sana transformer (see sana.h for the wrapper).
    py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
        .def(py::init<>())
        .def("init",
             &QuantizedSanaModel::init,
             py::arg("config"),
             py::arg("pag_layers"),
             py::arg("use_fp4"),
             py::arg("bf16"),
             py::arg("deviceId"))
        .def("reset", &QuantizedSanaModel::reset)
        .def("load", &QuantizedSanaModel::load, py::arg("path"), py::arg("partial") = false)
        .def("loadDict", &QuantizedSanaModel::loadDict, py::arg("dict"), py::arg("partial") = false)
        .def("forward", &QuantizedSanaModel::forward)
        .def("forward_layer", &QuantizedSanaModel::forward_layer)
        .def("startDebug", &QuantizedSanaModel::startDebug)
        .def("stopDebug", &QuantizedSanaModel::stopDebug)
        .def("getDebugResults", &QuantizedSanaModel::getDebugResults);
    // Standalone quantized GEMM (W4A4) for testing/benchmarking.
    py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
        .def(py::init<>())
        .def("init", &QuantizedGEMM::init)
        .def("reset", &QuantizedGEMM::reset)
        .def("load", &QuantizedGEMM::load)
        .def("forward", &QuantizedGEMM::forward)
        .def("quantize", &QuantizedGEMM::quantize)
        .def("startDebug", &QuantizedGEMM::startDebug)
        .def("stopDebug", &QuantizedGEMM::stopDebug)
        .def("getDebugResults", &QuantizedGEMM::getDebugResults);
    // Opaque handle type; no methods are exposed to Python.
    py::class_<Tensor>(m, "Tensor");
    // 8-bit GEMM variant.
    py::class_<QuantizedGEMM88>(m, "QuantizedGEMM88")
        .def(py::init<>())
        .def("init", &QuantizedGEMM88::init)
        .def("reset", &QuantizedGEMM88::reset)
        .def("load", &QuantizedGEMM88::load)
        .def("forward", &QuantizedGEMM88::forward)
        .def("startDebug", &QuantizedGEMM88::startDebug)
        .def("stopDebug", &QuantizedGEMM88::stopDebug)
        .def("getDebugResults", &QuantizedGEMM88::getDebugResults);

    // Raw kernel entry points (see ops.h).
    m.def_submodule("ops")
        .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
        .def("quantize_w4a4_act_fuse_lora", nunchaku::ops::quantize_w4a4_act_fuse_lora)
        .def("attention_fp16", nunchaku::ops::attention_fp16)
        .def("gemm_awq", nunchaku::ops::gemm_awq)
        .def("gemv_awq", nunchaku::ops::gemv_awq)

        .def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
        .def("test_pack_qkv", nunchaku::ops::test_pack_qkv);

    // Runtime knobs (see utils.h).
    m.def_submodule("utils")
        // Accepts spdlog level names such as "debug", "info", "warn".
        .def("set_log_level", [](const std::string &level) { spdlog::set_level(spdlog::level::from_str(level)); })
        .def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
        .def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
        .def("trim_memory", nunchaku::utils::trim_memory)
        .def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode);
}
|
nunchaku/csrc/sana.h
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "interop/torch.h"
|
| 4 |
+
#include "SanaModel.h"
|
| 5 |
+
#include "Serialization.h"
|
| 6 |
+
#include "debug.h"
|
| 7 |
+
#include "module.h"
|
| 8 |
+
|
| 9 |
+
// Python-facing wrapper around the quantized Sana diffusion transformer.
// ModuleWrapper<SanaModel> supplies the shared load()/loadDict()/reset() and
// debug plumbing; this class adds config-driven construction and the forward
// entry points exposed through pybind.
class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
public:
    // Build the native SanaModel from a diffusers-style config dict.
    // `pag_layers`: layer indices using perturbed-attention guidance (PAG).
    // `use_fp4` selects FP4 weight quantization; `bf16` selects BF16 over FP16.
    // NOTE(review): deviceId is int8_t here — narrower than the int device ids
    // used by Device::cuda below; confirm this is intended.
    void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
        spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);
        SanaConfig cfg{
            .num_layers = config["num_layers"].cast<int>(),
            .num_attention_heads = config["num_attention_heads"].cast<int>(),
            .attention_head_dim = config["attention_head_dim"].cast<int>(),
            .num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
            // diffusers calls this "mlp_ratio"; the native config names it expand_ratio.
            .expand_ratio = config["mlp_ratio"].cast<double>(),
            .pag_layers = pag_layers,
            .use_fp4 = use_fp4,
        };

        ModuleWrapper::init(deviceId);
        CUDADeviceContext ctx(this->deviceId);
        net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
    }

    // Full-model forward pass.
    // All tensor inputs are made contiguous before being handed to from_torch.
    // `pag` / `cfg` toggle perturbed-attention and classifier-free guidance
    // branches inside the model (semantics live in SanaModel::forward).
    torch::Tensor forward(torch::Tensor hidden_states,
                          torch::Tensor encoder_hidden_states,
                          torch::Tensor timestep,
                          torch::Tensor cu_seqlens_img,
                          torch::Tensor cu_seqlens_txt,
                          int H,
                          int W,
                          bool pag,
                          bool cfg,
                          bool skip_first_layer = false) {
        checkModel();
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedSanaModel forward");

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        timestep = timestep.contiguous();
        cu_seqlens_img = cu_seqlens_img.contiguous();
        cu_seqlens_txt = cu_seqlens_txt.contiguous();

        Tensor result = net->forward(from_torch(hidden_states),
                                     from_torch(encoder_hidden_states),
                                     from_torch(timestep),
                                     from_torch(cu_seqlens_img),
                                     from_torch(cu_seqlens_txt),
                                     H,
                                     W,
                                     pag,
                                     cfg,
                                     skip_first_layer);

        torch::Tensor output = to_torch(result);
        // Tensor::synchronizeDevice();

        return output;
    }

    // Forward through a single transformer block (by index); useful for
    // debugging and layer-wise inspection.
    torch::Tensor forward_layer(int64_t idx,
                                torch::Tensor hidden_states,
                                torch::Tensor encoder_hidden_states,
                                torch::Tensor timestep,
                                torch::Tensor cu_seqlens_img,
                                torch::Tensor cu_seqlens_txt,
                                int H,
                                int W,
                                bool pag,
                                bool cfg) {
        checkModel();
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedSanaModel forward_layer {}", idx);

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        timestep = timestep.contiguous();
        cu_seqlens_img = cu_seqlens_img.contiguous();
        cu_seqlens_txt = cu_seqlens_txt.contiguous();

        // .at(idx) bounds-checks the block index (throws std::out_of_range).
        Tensor result = net->transformer_blocks.at(idx)->forward(from_torch(hidden_states),
                                                                 from_torch(encoder_hidden_states),
                                                                 from_torch(timestep),
                                                                 from_torch(cu_seqlens_img),
                                                                 from_torch(cu_seqlens_txt),
                                                                 H,
                                                                 W,
                                                                 pag,
                                                                 cfg);

        torch::Tensor output = to_torch(result);
        // Tensor::synchronizeDevice();

        return output;
    }
};
|
nunchaku/csrc/utils.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "common.h"
|
| 4 |
+
#include "Tensor.h"
|
| 5 |
+
#include "kernels/zgemm/zgemm.h"
|
| 6 |
+
|
| 7 |
+
namespace nunchaku::utils {
|
| 8 |
+
|
| 9 |
+
void set_cuda_stack_limit(int64_t newval) {
|
| 10 |
+
size_t val = 0;
|
| 11 |
+
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
|
| 12 |
+
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
|
| 13 |
+
spdlog::debug("Stack={}", val);
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
void disable_memory_auto_release() {
|
| 17 |
+
int device;
|
| 18 |
+
checkCUDA(cudaGetDevice(&device));
|
| 19 |
+
cudaMemPool_t mempool;
|
| 20 |
+
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
|
| 21 |
+
uint64_t threshold = UINT64_MAX;
|
| 22 |
+
checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
void trim_memory() {
|
| 26 |
+
int device;
|
| 27 |
+
checkCUDA(cudaGetDevice(&device));
|
| 28 |
+
cudaMemPool_t mempool;
|
| 29 |
+
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
|
| 30 |
+
size_t bytesToKeep = 0;
|
| 31 |
+
checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
void set_faster_i2f_mode(std::string mode) {
|
| 35 |
+
spdlog::info("Set fasteri2f mode to {}", mode);
|
| 36 |
+
kernels::set_faster_i2f_mode(mode);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
}; // namespace nunchaku::utils
|
nunchaku/lora/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# LoRA utilities for FLUX models
|
nunchaku/lora/flux/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .diffusers_converter import to_diffusers
|
| 2 |
+
from .nunchaku_converter import convert_to_nunchaku_flux_lowrank_dict, to_nunchaku
|
| 3 |
+
from .utils import is_nunchaku_format
|
| 4 |
+
|
| 5 |
+
__all__ = ["to_diffusers", "to_nunchaku", "convert_to_nunchaku_flux_lowrank_dict", "is_nunchaku_format"]
|
nunchaku/lora/flux/compose.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Compose multiple LoRA weights into a single LoRA for FLUX models.
|
| 3 |
+
|
| 4 |
+
This script merges several LoRA safetensors files into one, applying individual strength values to each.
|
| 5 |
+
|
| 6 |
+
**Example Usage:**
|
| 7 |
+
|
| 8 |
+
.. code-block:: bash
|
| 9 |
+
|
| 10 |
+
python -m nunchaku.lora.flux.compose \\
|
| 11 |
+
-i lora1.safetensors lora2.safetensors \\
|
| 12 |
+
-s 0.8 1.0 \\
|
| 13 |
+
-o composed_lora.safetensors
|
| 14 |
+
|
| 15 |
+
**Arguments:**
|
| 16 |
+
|
| 17 |
+
- ``-i``, ``--input-paths``: Input LoRA safetensors files (one or more).
|
| 18 |
+
- ``-s``, ``--strengths``: Strength value for each LoRA (must match number of inputs).
|
| 19 |
+
- ``-o``, ``--output-path``: Output path for the composed LoRA safetensors file.
|
| 20 |
+
|
| 21 |
+
This will merge ``lora1.safetensors`` (strength 0.8) and ``lora2.safetensors`` (strength 1.0) into ``composed_lora.safetensors``.
|
| 22 |
+
|
| 23 |
+
**Main Function**
|
| 24 |
+
|
| 25 |
+
:func:`compose_lora`
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import os
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
from safetensors.torch import save_file
|
| 34 |
+
|
| 35 |
+
from .diffusers_converter import to_diffusers
|
| 36 |
+
from .utils import is_nunchaku_format, load_state_dict_in_safetensors
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def compose_lora(
    loras: list[tuple[str | dict[str, torch.Tensor], float]], output_path: str | None = None
) -> dict[str, torch.Tensor]:
    """
    Compose multiple LoRA weights into a single LoRA representation.

    Parameters
    ----------
    loras : list of (str or dict[str, torch.Tensor], float)
        Each tuple contains:
        - Path to a LoRA safetensors file or a LoRA weights dictionary.
        - Strength/scale factor for that LoRA.
    output_path : str, optional
        Path to save the composed LoRA weights as a safetensors file. If None, does not save.

    Returns
    -------
    dict[str, torch.Tensor]
        The composed LoRA weights.

    Raises
    ------
    AssertionError
        If LoRA weights are in Nunchaku format (must be converted to Diffusers format first)
        or if tensor shapes are incompatible.

    Notes
    -----
    - Converts all input LoRAs to Diffusers format.
    - Handles QKV projection fusion for attention layers.
    - Applies strength scaling to LoRA weights (on the lora_A factors).
    - Concatenates multiple LoRAs along appropriate dimensions.
    - Handles normalization layers, bias vectors, and FLUX.1-tools LoRA compatibility.

    Examples
    --------
    >>> lora_paths = [("lora1.safetensors", 0.8), ("lora2.safetensors", 0.6)]
    >>> composed = compose_lora(lora_paths, "composed_lora.safetensors")
    >>> lora_dicts = [({"layer.weight": torch.randn(10, 20)}, 1.0)]
    >>> composed = compose_lora(lora_dicts)
    """
    if len(loras) == 1:
        # Fast path: a single Nunchaku-format LoRA at (approximately) unit
        # strength needs no composition.
        # FIX: the original test was `(strength - 1) < 1e-5`, which is true for
        # ANY strength <= 1 (e.g. 0.5 would silently be treated as 1.0);
        # abs() makes it a genuine closeness check.
        if is_nunchaku_format(loras[0][0]) and abs(loras[0][1] - 1) < 1e-5:
            if isinstance(loras[0][0], str):
                return load_state_dict_in_safetensors(loras[0][0], device="cpu")
            else:
                return loras[0][0]

    composed = {}
    for lora, strength in loras:
        # Nunchaku-format weights cannot be composed; convert them first.
        assert not is_nunchaku_format(lora)
        lora = to_diffusers(lora)
        for k, v in list(lora.items()):
            if v.ndim == 1:
                # 1-D tensors: biases and norm scales.
                previous_tensor = composed.get(k, None)
                if previous_tensor is None:
                    if "norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k:
                        # Norm scales are taken as-is (not scaled by strength)
                        # and may be supplied by at most one LoRA.
                        composed[k] = v
                    else:
                        composed[k] = v * strength
                else:
                    assert not ("norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k)
                    composed[k] = previous_tensor + v * strength
            else:
                assert v.ndim == 2
                if ".to_q." in k or ".add_q_proj." in k:  # qkv must all exist
                    # Q/K/V are fused into a single qkv projection; the lora_B
                    # key is handled together with its lora_A counterpart.
                    if "lora_B" in k:
                        continue

                    q_a = v
                    k_a = lora[k.replace(".to_q.", ".to_k.").replace(".add_q_proj.", ".add_k_proj.")]
                    v_a = lora[k.replace(".to_q.", ".to_v.").replace(".add_q_proj.", ".add_v_proj.")]

                    q_b = lora[k.replace("lora_A", "lora_B")]
                    k_b = lora[
                        k.replace("lora_A", "lora_B")
                        .replace(".to_q.", ".to_k.")
                        .replace(".add_q_proj.", ".add_k_proj.")
                    ]
                    v_b = lora[
                        k.replace("lora_A", "lora_B")
                        .replace(".to_q.", ".to_v.")
                        .replace(".add_q_proj.", ".add_v_proj.")
                    ]

                    # Zero-pad so all three projections share the same rank.
                    max_rank = max(q_a.shape[0], k_a.shape[0], v_a.shape[0])
                    q_a = F.pad(q_a, (0, 0, 0, max_rank - q_a.shape[0]))
                    k_a = F.pad(k_a, (0, 0, 0, max_rank - k_a.shape[0]))
                    v_a = F.pad(v_a, (0, 0, 0, max_rank - v_a.shape[0]))
                    q_b = F.pad(q_b, (0, max_rank - q_b.shape[1]))
                    k_b = F.pad(k_b, (0, max_rank - k_b.shape[1]))
                    v_b = F.pad(v_b, (0, max_rank - v_b.shape[1]))

                    if torch.isclose(q_a, k_a).all() and torch.isclose(q_a, v_a).all():
                        # Shared down-projection: keep one lora_A, stack the Bs.
                        lora_a = q_a
                        lora_b = torch.cat((q_b, k_b, v_b), dim=0)
                    else:
                        # Distinct down-projections: stack the As along rank and
                        # build a block-diagonal lora_B so each A feeds its own B.
                        lora_a_group = (q_a, k_a, v_a)
                        new_shape_a = [sum([_.shape[0] for _ in lora_a_group]), q_a.shape[1]]
                        lora_a = torch.zeros(new_shape_a, dtype=q_a.dtype, device=q_a.device)
                        start_dim = 0
                        for tensor in lora_a_group:
                            lora_a[start_dim : start_dim + tensor.shape[0]] = tensor
                            start_dim += tensor.shape[0]

                        lora_b_group = (q_b, k_b, v_b)
                        new_shape_b = [sum([_.shape[0] for _ in lora_b_group]), sum([_.shape[1] for _ in lora_b_group])]
                        lora_b = torch.zeros(new_shape_b, dtype=q_b.dtype, device=q_b.device)
                        start_dims = (0, 0)
                        for tensor in lora_b_group:
                            end_dims = (start_dims[0] + tensor.shape[0], start_dims[1] + tensor.shape[1])
                            lora_b[start_dims[0] : end_dims[0], start_dims[1] : end_dims[1]] = tensor
                            start_dims = end_dims

                    # Strength is applied once, on the A factor.
                    lora_a = lora_a * strength

                    new_k_a = k.replace(".to_q.", ".to_qkv.").replace(".add_q_proj.", ".add_qkv_proj.")
                    new_k_b = new_k_a.replace("lora_A", "lora_B")

                    # Concatenate with any previously composed LoRA: As along
                    # rank (dim 0), Bs along columns (dim 1).
                    for kk, vv, dim in ((new_k_a, lora_a, 0), (new_k_b, lora_b, 1)):
                        previous_lora = composed.get(kk, None)
                        composed[kk] = vv if previous_lora is None else torch.cat([previous_lora, vv], dim=dim)

                elif ".to_k." in k or ".to_v." in k or ".add_k_proj." in k or ".add_v_proj." in k:
                    # Already consumed by the matching to_q / add_q_proj branch.
                    continue
                else:
                    if "lora_A" in k:
                        v = v * strength

                    previous_lora = composed.get(k, None)
                    if previous_lora is None:
                        composed[k] = v
                    else:
                        if "lora_A" in k:
                            if previous_lora.shape[1] != v.shape[1]:  # flux.1-tools LoRA compatibility
                                # x_embedder input widths can differ between
                                # LoRAs; zero-extend both to the wider one.
                                assert "x_embedder" in k
                                expanded_size = max(previous_lora.shape[1], v.shape[1])
                                if expanded_size > previous_lora.shape[1]:
                                    expanded_previous_lora = torch.zeros(
                                        (previous_lora.shape[0], expanded_size),
                                        device=previous_lora.device,
                                        dtype=previous_lora.dtype,
                                    )
                                    expanded_previous_lora[:, : previous_lora.shape[1]] = previous_lora
                                else:
                                    expanded_previous_lora = previous_lora
                                if expanded_size > v.shape[1]:
                                    expanded_v = torch.zeros(
                                        (v.shape[0], expanded_size), device=v.device, dtype=v.dtype
                                    )
                                    expanded_v[:, : v.shape[1]] = v
                                else:
                                    expanded_v = v
                                composed[k] = torch.cat([expanded_previous_lora, expanded_v], dim=0)
                            else:
                                composed[k] = torch.cat([previous_lora, v], dim=0)
                        else:
                            composed[k] = torch.cat([previous_lora, v], dim=1)
                    # FIX: removed a trailing unconditional re-assignment
                    # (`composed[k] = v if previous_lora is None else torch.cat(...)`)
                    # that clobbered the x_embedder width-expansion above and
                    # raised on mismatched shapes in the flux.1-tools case.
    if output_path is not None:
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)
        save_file(composed, output_path)
    return composed
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
if __name__ == "__main__":
    # CLI entry point: merge several LoRA safetensors files into one, each
    # scaled by its own strength value.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i", "--input-paths", type=str, nargs="*", required=True, help="paths to the lora safetensors files"
    )
    parser.add_argument("-s", "--strengths", type=float, nargs="*", required=True, help="strengths for each lora")
    parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output safetensors file")
    args = parser.parse_args()
    # Exactly one strength per input LoRA.
    assert len(args.input_paths) == len(args.strengths)
    compose_lora(list(zip(args.input_paths, args.strengths)), args.output_path)
|
nunchaku/lora/flux/convert.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CLI tool to convert LoRA weights to Nunchaku format.
|
| 3 |
+
|
| 4 |
+
**Example Usage:**
|
| 5 |
+
|
| 6 |
+
.. code-block:: bash
|
| 7 |
+
|
| 8 |
+
python -m nunchaku.lora.flux.convert \\
|
| 9 |
+
--lora-path composed_lora.safetensors \\
|
| 10 |
+
--quant-path mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors \\
|
| 11 |
+
--output-root ./converted \\
|
| 12 |
+
--dtype bfloat16
|
| 13 |
+
|
| 14 |
+
**Arguments:**
|
| 15 |
+
|
| 16 |
+
- ``--lora-path``: Path to the LoRA weights safetensor file (required)
|
| 17 |
+
- ``--quant-path``: Path to the quantized model safetensor file (default: ``mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors``)
|
| 18 |
+
- ``--output-root``: Root directory for the output safetensor file (default: parent directory of the lora file)
|
| 19 |
+
- ``--lora-name``: Name of the LoRA weights (optional, auto-generated if not provided)
|
| 20 |
+
- ``--dtype``: Data type of the converted weights, either ``bfloat16`` or ``float16`` (default: ``bfloat16``)
|
| 21 |
+
|
| 22 |
+
**Main Function**
|
| 23 |
+
|
| 24 |
+
:func:`nunchaku.lora.flux.nunchaku_converter.to_nunchaku`
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import os
|
| 29 |
+
|
| 30 |
+
from .nunchaku_converter import to_nunchaku
|
| 31 |
+
from .utils import is_nunchaku_format
|
| 32 |
+
|
| 33 |
+
if __name__ == "__main__":
    # CLI entry point: convert a Diffusers/Kohya-format LoRA file into Nunchaku
    # format, matched against a given quantized base model.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--quant-path",
        type=str,
        help="Path to the quantized model safetensors file.",
        default="mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors",
    )
    parser.add_argument("--lora-path", type=str, required=True, help="Path to LoRA weights safetensors file.")
    parser.add_argument("--output-root", type=str, default="", help="Root directory for output safetensors file.")
    parser.add_argument("--lora-name", type=str, default=None, help="Name for the output LoRA weights.")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16"],
        help="Data type of the converted weights.",
    )
    args = parser.parse_args()

    # Inputs already in Nunchaku format need no conversion.
    if is_nunchaku_format(args.lora_path):
        print("Already in Nunchaku format, no conversion needed.")
        exit(0)

    # Default the output directory to the LoRA file's own directory.
    if not args.output_root:
        args.output_root = os.path.dirname(args.lora_path)
    if args.lora_name is None:
        # Auto-generate a name like "svdq-<precision>-<lora file stem>"; the
        # precision tag is inferred from the quantized model path.
        base_name = os.path.basename(args.lora_path)
        lora_name = base_name.rsplit(".", 1)[0]
        precision = "fp4" if "fp4" in args.quant_path else "int4"
        lora_name = f"svdq-{precision}-{lora_name}"
        print(f"LoRA name not provided, using {lora_name} as the LoRA name")
    else:
        lora_name = args.lora_name
        assert lora_name, "LoRA name must be provided."

    to_nunchaku(
        args.lora_path,
        args.quant_path,
        dtype=args.dtype,
        output_path=os.path.join(args.output_root, f"{lora_name}.safetensors"),
    )
|
nunchaku/lora/flux/diffusers_converter.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module implements the functions to convert FLUX LoRA weights from various formats
|
| 3 |
+
to the Diffusers format, which will later be converted to Nunchaku format.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from diffusers.loaders import FluxLoraLoaderMixin
|
| 12 |
+
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
| 13 |
+
from safetensors.torch import save_file
|
| 14 |
+
|
| 15 |
+
from ...utils import load_state_dict_in_safetensors
|
| 16 |
+
|
| 17 |
+
# Get log level from environment variable (default to INFO)
|
| 18 |
+
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
| 19 |
+
|
| 20 |
+
# Configure logging
|
| 21 |
+
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def handle_kohya_lora(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Convert Kohya LoRA format keys to Diffusers format.

    A Kohya export (e.g. https://civitai.com/models/1118358?modelVersionId=1256866)
    prefixes every key with ``lora_transformer_``. If any key lacks that prefix,
    the state dict is assumed to already be in another format and is returned
    unchanged. Tensor values are never modified — only keys are rewritten.

    Parameters
    ----------
    state_dict : dict[str, torch.Tensor]
        LoRA weights, possibly in Kohya format.

    Returns
    -------
    dict[str, torch.Tensor]
        LoRA weights in Diffusers format.
    """
    if any(not k.startswith("lora_transformer_") for k in state_dict):
        return state_dict

    # Ordered (old, new) substring rewrites. Order matters: e.g.
    # "single_transformer_blocks_" must be rewritten before the generic
    # "transformer_blocks_" rule so the latter only matches double blocks.
    _REWRITES = (
        ("lora_transformer_", "transformer."),
        ("norm_out_", "norm_out."),
        ("time_text_embed_", "time_text_embed."),
        ("guidance_embedder_", "guidance_embedder."),
        ("text_embedder_", "text_embedder."),
        ("timestep_embedder_", "timestep_embedder."),
        # Single-stream blocks.
        ("single_transformer_blocks_", "single_transformer_blocks."),
        ("_attn_", ".attn."),
        ("_norm_linear.", ".norm.linear."),
        ("_proj_mlp.", ".proj_mlp."),
        ("_proj_out.", ".proj_out."),
        # Double-stream blocks.
        ("transformer_blocks_", "transformer_blocks."),
        ("to_out_0.", "to_out.0."),
        ("_ff_context_net_0_proj.", ".ff_context.net.0.proj."),
        ("_ff_context_net_2.", ".ff_context.net.2."),
        ("_ff_net_0_proj.", ".ff.net.0.proj."),
        ("_ff_net_2.", ".ff.net.2."),
        ("_norm1_context_linear.", ".norm1_context.linear."),
        ("_norm1_linear.", ".norm1.linear."),
        # Kohya up/down naming -> PEFT A/B naming.
        (".lora_down.", ".lora_A."),
        (".lora_up.", ".lora_B."),
    )

    new_state_dict = {}
    for k, v in state_dict.items():
        for old, new in _REWRITES:
            k = k.replace(old, new)
        new_state_dict[k] = v
    return new_state_dict
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def convert_peft_to_comfyui(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Convert PEFT-format LoRA keys (``base_model.model.*``) to ComfyUI-format keys (``lora_unet_*``).

    Key mapping:
      - ``base_model.model.double_blocks.X.img_attn.proj`` -> ``lora_unet_double_blocks_X_img_attn_proj``
      - ``base_model.model.single_blocks.X.linear1`` -> ``lora_unet_single_blocks_X_linear1``
      - ``base_model.model.final_layer.linear`` -> ``lora_unet_final_layer_linear``
      - ``lora_A``/``lora_B`` -> ``lora_down``/``lora_up``

    Keys without the ``base_model.model.`` prefix pass through unchanged.

    Parameters
    ----------
    state_dict : dict[str, torch.Tensor]
        LoRA weights in PEFT format.

    Returns
    -------
    dict[str, torch.Tensor]
        LoRA weights in ComfyUI format.
    """
    peft_prefix = "base_model.model."
    result: dict[str, torch.Tensor] = {}

    for original_key, tensor in state_dict.items():
        mapped_key = original_key

        if original_key.startswith(peft_prefix):
            mapped_key = original_key.replace(peft_prefix, "")

            # Re-prefix per block family, then flatten the remaining dots to underscores.
            if "double_blocks" in mapped_key:
                # e.g., double_blocks.0.img_attn.proj -> lora_unet_double_blocks_0_img_attn_proj
                mapped_key = mapped_key.replace("double_blocks.", "lora_unet_double_blocks_")
                mapped_key = mapped_key.replace(".", "_")
            elif "single_blocks" in mapped_key:
                mapped_key = mapped_key.replace("single_blocks.", "lora_unet_single_blocks_")
                # Special handling for modulation.lin -> modulation_lin
                mapped_key = mapped_key.replace("modulation.lin", "modulation_lin")
                mapped_key = mapped_key.replace(".", "_")
            elif "final_layer" in mapped_key:
                mapped_key = mapped_key.replace("final_layer.linear", "lora_unet_final_layer_linear")
                mapped_key = mapped_key.replace(".", "_")
            else:
                # Any other key: add the lora_unet_ prefix and flatten dots.
                mapped_key = "lora_unet_" + mapped_key.replace(".", "_")

            # Restore the dotted LoRA leaf names expected by ComfyUI loaders.
            mapped_key = mapped_key.replace("_lora_A_weight", ".lora_down.weight")
            mapped_key = mapped_key.replace("_lora_B_weight", ".lora_up.weight")

        result[mapped_key] = tensor

        if original_key != mapped_key:
            logger.debug(f"Converted: {original_key} → {mapped_key}")

    return result
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _save_safetensors(state_dict: dict[str, torch.Tensor], output_path: str | None) -> None:
    """Save ``state_dict`` to ``output_path`` as safetensors, creating parent dirs; no-op when path is None."""
    if output_path is not None:
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)
        save_file(state_dict, output_path)


def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | None = None) -> dict[str, torch.Tensor]:
    """
    Convert LoRA weights to Diffusers format, which will later be converted to Nunchaku format.

    Parameters
    ----------
    input_lora : str or dict[str, torch.Tensor]
        Path to a safetensors file or a LoRA weight dictionary.
    output_path : str, optional
        If given, save the converted weights to this path.

    Returns
    -------
    dict[str, torch.Tensor]
        LoRA weights in Diffusers format. NOTE(review): inputs detected as PEFT
        format are returned in ComfyUI format instead — confirm downstream callers
        expect this.
    """
    if isinstance(input_lora, str):
        tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
    else:
        # Shallow copy so the caller's dictionary is never mutated below.
        tensors = dict(input_lora)

    # First normalize Kohya-style key names (e.g. lora_transformer_*) if present.
    tensors = handle_kohya_lora(tensors)

    # Convert FP8 (or any other non-standard dtype) tensors to BF16.
    for k, v in tensors.items():
        if v.dtype not in (torch.float64, torch.float32, torch.bfloat16, torch.float16):
            tensors[k] = v.to(torch.bfloat16)

    # Apply Kontext-specific key conversion for PEFT-format LoRAs
    # (keys carrying the base_model.model.* prefix, including final_layer_linear).
    if any(k.startswith("base_model.model.") for k in tensors):
        logger.info("Converting PEFT format to ComfyUI format")
        converted = convert_peft_to_comfyui(tensors)
        # BUGFIX: the previous early return skipped saving even when output_path
        # was provided, contradicting the documented contract.
        _save_safetensors(converted, output_path)
        return converted

    # Workaround for incomplete final-layer LoRAs: some ship only final_layer_linear
    # without the matching adaLN modulation weights, so synthesize zero-filled ones.
    final_keys = [k for k in tensors if "final_layer" in k]
    has_linear = any("final_layer_linear" in k for k in final_keys)
    has_adaln = any("final_layer_adaLN_modulation" in k for k in final_keys)
    if has_linear and not has_adaln:
        for key in list(tensors.keys()):
            if "final_layer_linear" in key:
                adaln_key = key.replace("final_layer_linear", "final_layer_adaLN_modulation_1")
                if adaln_key not in tensors:
                    tensors[adaln_key] = torch.zeros_like(tensors[key])

    new_tensors, alphas = FluxLoraLoaderMixin.lora_state_dict(tensors, return_alphas=True)
    new_tensors = convert_unet_state_dict_to_peft(new_tensors)

    # Fold the LoRA alpha scaling (alpha / rank) into the A matrices.
    if alphas is not None and len(alphas) > 0:
        for k, v in alphas.items():
            key_A = k.replace(".alpha", ".lora_A.weight")
            key_B = k.replace(".alpha", ".lora_B.weight")
            assert key_A in new_tensors, f"Key {key_A} not found in new tensors."
            assert key_B in new_tensors, f"Key {key_B} not found in new tensors."
            rank = new_tensors[key_A].shape[0]
            assert new_tensors[key_B].shape[1] == rank, f"Rank mismatch for {key_B}."
            new_tensors[key_A] = new_tensors[key_A] * v / rank

    _save_safetensors(new_tensors, output_path)
    return new_tensors
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
if __name__ == "__main__":
    # Command-line entry point: convert a ComfyUI LoRA safetensors file to Diffusers format.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensors file"
    )
    arg_parser.add_argument(
        "-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensors file"
    )
    cli_args = arg_parser.parse_args()
    to_diffusers(cli_args.input_path, cli_args.output_path)
|
nunchaku/lora/flux/nunchaku_converter.py
ADDED
|
@@ -0,0 +1,949 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Nunchaku LoRA format converter for Flux models.
|
| 3 |
+
|
| 4 |
+
This module provides utilities to convert LoRA weights from Diffusers format
|
| 5 |
+
to Nunchaku format for efficient quantized inference in Flux models.
|
| 6 |
+
|
| 7 |
+
Key functions
|
| 8 |
+
-------------
|
| 9 |
+
- :func:`to_nunchaku` : Main conversion entry point
|
| 10 |
+
- :func:`fuse_vectors` : Vector fusion for bias terms
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from safetensors.torch import save_file
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
from ...utils import filter_state_dict, load_state_dict_in_safetensors
|
| 21 |
+
from .diffusers_converter import to_diffusers
|
| 22 |
+
from .packer import NunchakuWeightPacker
|
| 23 |
+
from .utils import is_nunchaku_format, pad
|
| 24 |
+
|
| 25 |
+
# Get log level from environment variable (default to INFO)
|
| 26 |
+
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
| 27 |
+
|
| 28 |
+
# Configure logging
|
| 29 |
+
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# region utilities
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def update_state_dict(
    lhs: dict[str, torch.Tensor], rhs: dict[str, torch.Tensor], prefix: str = ""
) -> dict[str, torch.Tensor]:
    """
    Merge ``rhs`` into ``lhs`` in place, optionally namespacing the incoming keys.

    Parameters
    ----------
    lhs : dict[str, torch.Tensor]
        Destination state dictionary (mutated in place).
    rhs : dict[str, torch.Tensor]
        Source state dictionary.
    prefix : str, optional
        When non-empty, each incoming key ``k`` is stored as ``"{prefix}.{k}"``.

    Returns
    -------
    dict[str, torch.Tensor]
        The (mutated) destination dictionary.

    Raises
    ------
    AssertionError
        If a destination key is already present in ``lhs``.
    """
    for src_key, tensor in rhs.items():
        dst_key = f"{prefix}.{src_key}" if prefix else src_key
        # Refuse silent overwrites: every merged key must be new.
        assert dst_key not in lhs, f"Key {dst_key} already exists in the state dict."
        lhs[dst_key] = tensor
    return lhs
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# endregion
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def pack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
    """
    Pack the low-rank weight tensor for W4A4 linear layers.

    Rearranges a 16-bit low-rank factor into the 16x16 fragment layout consumed by
    the quantized kernels: the tensor is zero-padded so both dimensions divide the
    fragment size, split into fragments, and the packed lanes inside each fragment
    are interleaved.

    Parameters
    ----------
    weight : torch.Tensor
        Low-rank weight tensor (fp16 or bf16). For ``down=True`` the layout is
        ``[rank, in_features]``; for ``down=False`` it is ``[out_features, rank]``.
        NOTE(review): layout inferred from :func:`unpack_lowrank_weight`'s inverse
        views — confirm against the kernel side.
    down : bool
        If True, pack as down-projection; else as up-projection.

    Returns
    -------
    torch.Tensor
        Packed weight tensor, flattened back to a 2-D view of the padded shape.
    """
    assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
    # Fragment geometry of the target layout.
    lane_n, lane_k = 1, 2  # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
    n_pack_size, k_pack_size = 2, 2
    num_n_lanes, num_k_lanes = 8, 4
    frag_n = n_pack_size * num_n_lanes * lane_n  # 2 * 8 * 1 = 16
    frag_k = k_pack_size * num_k_lanes * lane_k  # 2 * 4 * 2 = 16
    # Zero-pad so the tensor tiles exactly into (frag_n, frag_k) fragments.
    weight = pad(weight, divisor=(frag_n, frag_k), dim=(0, 1))
    if down:
        r, c = weight.shape
        r_frags, c_frags = r // frag_n, c // frag_k
        weight = weight.view(r_frags, frag_n, c_frags, frag_k).permute(2, 0, 1, 3)
    else:
        c, r = weight.shape
        c_frags, r_frags = c // frag_n, r // frag_k
        weight = weight.view(c_frags, frag_n, r_frags, frag_k).permute(0, 2, 1, 3)
    # Interleave the packed lanes inside each (frag_n, frag_k) fragment.
    weight = weight.reshape(c_frags, r_frags, n_pack_size, num_n_lanes, k_pack_size, num_k_lanes, lane_k)
    weight = weight.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
    # Flatten back to 2-D; note the packed tensor is stored as (c, r) in both branches.
    return weight.view(c, r)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def unpack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
    """
    Inverse of :func:`pack_lowrank_weight`: recover the plain low-rank factor
    from the W4A4 fragment layout.

    Parameters
    ----------
    weight : torch.Tensor
        Packed low-rank weight tensor (fp16 or bf16); packed storage is ``(c, r)``.
    down : bool
        If True, unpack as down-projection; else as up-projection.

    Returns
    -------
    torch.Tensor
        Unpacked weight tensor.
    """
    dim0, dim1 = weight.shape
    assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
    # Fragment geometry — must mirror pack_lowrank_weight exactly.
    lane_n, lane_k = 1, 2  # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
    n_pack_size, k_pack_size = 2, 2
    num_n_lanes, num_k_lanes = 8, 4
    frag_n = n_pack_size * num_n_lanes * lane_n
    frag_k = k_pack_size * num_k_lanes * lane_k
    if down:
        r_frags, c_frags = dim1 // frag_n, dim0 // frag_k
    else:
        c_frags, r_frags = dim0 // frag_n, dim1 // frag_k
    # Undo the lane interleaving inside each fragment.
    interleaved = weight.view(c_frags, r_frags, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, lane_k)
    restored = interleaved.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
    frags = restored.view(c_frags, r_frags, frag_n, frag_k)
    # Reassemble fragments into the original 2-D orientation.
    if down:
        return frags.permute(1, 2, 0, 3).contiguous().view(dim1, dim0)
    return frags.permute(0, 2, 1, 3).contiguous().view(dim0, dim1)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def reorder_adanorm_lora_up(lora_up: torch.Tensor, splits: int) -> torch.Tensor:
    """
    Reorder an AdaNorm LoRA up-projection from split-major to interleaved row order.

    The ``splits`` contiguous chunks of rows are interleaved so that row ``i`` of
    every chunk becomes adjacent in the output.

    Parameters
    ----------
    lora_up : torch.Tensor
        LoRA up-projection tensor of shape ``[out_features, rank]``.
    splits : int
        Number of AdaNorm splits; must divide ``out_features`` evenly.

    Returns
    -------
    torch.Tensor
        Reordered, contiguous tensor with the same shape as the input.
    """
    out_features, rank = lora_up.shape
    assert out_features % splits == 0
    grouped = lora_up.view(splits, out_features // splits, rank)
    return grouped.transpose(0, 1).reshape(out_features, rank).contiguous()
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    local_name_map: dict[str, str | list[str]],
    convert_map: dict[str, str],
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    """
    Convert LoRA weights for a transformer block from Diffusers to Nunchaku format.

    Merges and converts LoRA weights from the original SVDQuant low-rank branch and an extra LoRA dict
    for a given transformer block, producing a Nunchaku-compatible dictionary. Handles both fused and
    unfused LoRA branches (e.g., qkv), and merges multiple LoRA branches as needed.

    Parameters
    ----------
    orig_state_dict : dict[str, torch.Tensor]
        Original state dict with LoRA weights, keys like ``"{block}.{local}.lora_down"`` and ``"{block}.{local}.lora_up"``.
    extra_lora_dict : dict[str, torch.Tensor]
        Extra LoRA weights, keys like ``"{block}.{local}.lora_A.weight"`` and ``"{block}.{local}.lora_B.weight"``.
    converted_block_name : str
        Block name for output (e.g., ``"transformer_blocks.0"``).
    candidate_block_name : str
        Block name for input lookup (e.g., ``"blocks.0"``).
    local_name_map : dict[str, str | list[str]]
        Maps output local names (e.g., ``"attn.qkv"``) to one or more input local names.
    convert_map : dict[str, str]
        Maps output local names to conversion types: ``"adanorm_single"``, ``"adanorm_zero"``, or ``"linear"``.
    default_dtype : torch.dtype, optional
        Output tensor dtype (default: ``torch.bfloat16``).

    Returns
    -------
    dict[str, torch.Tensor]
        A dictionary containing the converted LoRA weights in Nunchaku format.

    Notes
    -----
    - If both original and extra LoRA weights are present, they are merged by concatenation.
    - Handles both fused and unfused attention projections (e.g., qkv).
    - Applies special packing for W4A16 linear layers (e.g., ``"adanorm_single"`` and ``"adanorm_zero"``).
    - On a concat size mismatch, the extra LoRA up-projection is truncated to the original width
      instead of being merged (a workaround for architecture differences; see the except branch).
    """
    logger.debug(f"Converting LoRA branch for block {candidate_block_name}...")
    converted: dict[str, torch.Tensor] = {}
    for converted_local_name, candidate_local_names in local_name_map.items():
        if isinstance(candidate_local_names, str):
            candidate_local_names = [candidate_local_names]
        # region original LoRA
        # Fetch the packed SVDQuant low-rank pair for this layer; both halves must
        # be present (and non-empty) together, otherwise the pair is treated as absent.
        orig_lora = (
            orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_down", None),
            orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_up", None),
        )
        if orig_lora[0] is None or orig_lora[1] is None:
            assert orig_lora[0] is None and orig_lora[1] is None
            orig_lora = None
        elif orig_lora[0].numel() == 0 or orig_lora[1].numel() == 0:
            assert orig_lora[0].numel() == 0 and orig_lora[1].numel() == 0
            orig_lora = None
        else:
            assert orig_lora[0] is not None and orig_lora[1] is not None
            # Stored packed; unpack to plain [r, c] / [c, r] factors before merging.
            orig_lora = (
                unpack_lowrank_weight(orig_lora[0], down=True),
                unpack_lowrank_weight(orig_lora[1], down=False),
            )
            logger.debug(
                f" - Found {converted_block_name} LoRA of {converted_local_name} (rank: {orig_lora[0].shape[0]})"
            )
        # endregion
        # region extra LoRA
        extra_lora_list = None

        # if the qkv are already fused in the extra LoRA, look it up under the "_qkv" name
        if "qkv" in converted_local_name:
            candidate_local_name = candidate_local_names[0]
            assert "_q" in candidate_local_name
            candidate_local_name = candidate_local_name.replace("_q", "_qkv")
            lora_A = extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None)
            lora_B = extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None)
            if lora_A is None and lora_B is None:
                extra_lora_list = None
            else:
                assert lora_A is not None and lora_B is not None
                extra_lora_list = [(lora_A, lora_B)]

        # not fused: collect one (A, B) pair per candidate name and fuse them manually
        if extra_lora_list is None:
            extra_lora_list = [
                (
                    extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None),
                    extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None),
                )
                for candidate_local_name in candidate_local_names
            ]
        if any(lora[0] is not None or lora[1] is not None for lora in extra_lora_list):
            # merge extra LoRAs into one LoRA
            if len(extra_lora_list) > 1:
                # Fill missing branches with (copy of first A, zero B) so shapes line up.
                first_lora = None
                for lora in extra_lora_list:
                    if lora[0] is not None:
                        assert lora[1] is not None
                        first_lora = lora
                        break
                assert first_lora is not None
                for lora_index in range(len(extra_lora_list)):
                    if extra_lora_list[lora_index][0] is None:
                        assert extra_lora_list[lora_index][1] is None
                        extra_lora_list[lora_index] = (first_lora[0].clone(), torch.zeros_like(first_lora[1]))
                if all(lora[0].equal(extra_lora_list[0][0]) for lora in extra_lora_list):
                    # if all extra LoRAs share the same lora_down, reuse it and stack the ups
                    extra_lora_down = extra_lora_list[0][0]
                    extra_lora_up = torch.cat([lora[1] for lora in extra_lora_list], dim=0)
                else:
                    # Different downs: stack them and build a block-diagonal up so each
                    # up only multiplies its own down's rows.
                    extra_lora_down = torch.cat([lora[0] for lora in extra_lora_list], dim=0)
                    extra_lora_up_c = sum(lora[1].shape[0] for lora in extra_lora_list)
                    extra_lora_up_r = sum(lora[1].shape[1] for lora in extra_lora_list)
                    assert extra_lora_up_r == extra_lora_down.shape[0]
                    extra_lora_up = torch.zeros((extra_lora_up_c, extra_lora_up_r), dtype=extra_lora_down.dtype)
                    c, r = 0, 0
                    for lora in extra_lora_list:
                        c_next, r_next = c + lora[1].shape[0], r + lora[1].shape[1]
                        extra_lora_up[c:c_next, r:r_next] = lora[1]
                        c, r = c_next, r_next
            else:
                extra_lora_down, extra_lora_up = extra_lora_list[0]
            extra_lora: tuple[torch.Tensor, torch.Tensor] = (extra_lora_down, extra_lora_up)
            logger.debug(
                f" - Found {candidate_block_name} LoRA of {candidate_local_names} (rank: {extra_lora[0].shape[0]})"
            )
        else:
            extra_lora = None
        # endregion
        # region merge LoRA
        if orig_lora is None:
            if extra_lora is None:
                lora = None
            else:
                logger.debug(" - Using extra LoRA")
                lora = (extra_lora[0].to(default_dtype), extra_lora[1].to(default_dtype))
        elif extra_lora is None:
            logger.debug(" - Using original LoRA")
            lora = orig_lora
        else:
            try:
                # Stack the two low-rank branches: ranks add, output width must match.
                lora = (
                    torch.cat([orig_lora[0], extra_lora[0].to(orig_lora[0].dtype)], dim=0),  # [r, c]
                    torch.cat([orig_lora[1], extra_lora[1].to(orig_lora[1].dtype)], dim=1),  # [c, r]
                )
                logger.debug(f" - Merging original and extra LoRA (rank: {lora[0].shape[0]})")
            except RuntimeError as e:
                if "Sizes of tensors must match" in str(e):
                    # Handle various dimension mismatch cases for LoRA
                    logger.debug(
                        f" - Dimension mismatch detected: orig_lora[1]={orig_lora[1].shape}, extra_lora[1]={extra_lora[1].shape}"
                    )

                    # Handle dimension mismatch by using only the properly sized portion of extra_lora
                    # instead of trying to concatenate mismatched dimensions.
                    # NOTE(review): in every fallback below the ORIGINAL LoRA is dropped and
                    # only the (truncated) extra LoRA is kept — confirm this is intended.

                    # Case 1: single_blocks linear1 [21504] -> mlp_fc1 [12288]
                    # (21504 = 12288 mlp + 9216 qkv in the fused old layout)
                    if extra_lora[1].shape[1] == 21504 and orig_lora[1].shape[1] == 12288:
                        # Use only the first 12288 dimensions from the 21504 extra LoRA
                        extra_lora_up_split = extra_lora[1][:, :12288].clone()
                        extra_lora_down = extra_lora[0].clone()
                        # Use the split extra LoRA instead of concatenating
                        lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype))

                    # Case 2: transformer_blocks with different MLP dimensions (27648 -> 9216)
                    elif extra_lora[1].shape[1] == 27648 and orig_lora[1].shape[1] == 9216:
                        # Use only the first 9216 dimensions from the 27648 extra LoRA
                        extra_lora_up_split = extra_lora[1][:, :9216].clone()
                        extra_lora_down = extra_lora[0].clone()
                        # Use the split extra LoRA instead of concatenating
                        lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype))

                    # Case 3: other ratios where extra is wider — truncate to the original width
                    elif extra_lora[1].shape[1] > orig_lora[1].shape[1]:
                        target_dim = orig_lora[1].shape[1]
                        extra_lora_up_split = extra_lora[1][:, :target_dim].clone()
                        extra_lora_down = extra_lora[0].clone()
                        # Use the truncated extra LoRA instead of concatenating
                        lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype))

                    else:
                        # Extra LoRA is narrower than the original: cannot split, keep original only
                        lora = orig_lora
                else:
                    raise e
        # endregion
        if lora is not None:
            # Pack per layer kind. AdaNorm layers need their up-projection rows
            # regrouped (3 or 6 modulation chunks) and 16-aligned; plain linears
            # use the W4A4 fragment packing.
            if convert_map[converted_local_name] == "adanorm_single":
                update_state_dict(
                    converted,
                    {
                        "lora_down": pad(lora[0], divisor=16, dim=0),
                        "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
                    },
                    prefix=converted_local_name,
                )
            elif convert_map[converted_local_name] == "adanorm_zero":
                update_state_dict(
                    converted,
                    {
                        "lora_down": pad(lora[0], divisor=16, dim=0),
                        "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
                    },
                    prefix=converted_local_name,
                )
            elif convert_map[converted_local_name] == "linear":
                update_state_dict(
                    converted,
                    {
                        "lora_down": pack_lowrank_weight(lora[0], down=True),
                        "lora_up": pack_lowrank_weight(lora[1], down=False),
                    },
                    prefix=converted_local_name,
                )
    return converted
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def preprocess_single_blocks_lora(
    extra_lora_dict: dict[str, torch.Tensor], candidate_block_name: str
) -> dict[str, torch.Tensor]:
    """
    Preprocess LoRA weights from single_blocks format to match single_transformer_blocks structure.

    This function handles the architectural mismatch between old and new models:
    - Old single_blocks: linear1 (fused 21504-dim layer) and linear2
    - New single_transformer_blocks: mlp_fc1 (12288-dim), qkv_proj (9216-dim), and mlp_fc2

    The linear1 layer in the old architecture combines two functions:
    1. MLP projection (first 12288 dimensions)
    2. QKV projection for attention (last 9216 dimensions)

    These are split into separate layers in the new architecture.

    Parameters
    ----------
    extra_lora_dict : dict[str, torch.Tensor]
        LoRA weights keyed like ``"{block}.{layer}.lora_A.weight"`` /
        ``"{block}.{layer}.lora_B.weight"``. The input dict is not mutated;
        a shallow copy is edited and returned.
    candidate_block_name : str
        Fully-qualified block name, e.g. ``"single_transformer_blocks.0"``.

    Returns
    -------
    dict[str, torch.Tensor]
        A copy of ``extra_lora_dict`` in which ``linear1``/``linear2`` keys of
        the given block have been re-mapped to ``proj_mlp``, ``attn.to_{q,k,v}``
        and ``proj_out.linears.1`` keys. If the block has no fused 21504-dim
        ``linear1.lora_B``, the dict is returned with no key changes.
    """
    # Shallow copy: tensors are shared, only the key layout is rewritten.
    processed_dict = extra_lora_dict.copy()

    # Find all single_transformer_blocks keys that need preprocessing
    single_blocks_keys = [k for k in extra_lora_dict.keys() if "single_transformer_blocks" in k and "linear" in k]

    logger.debug(f"Preprocessing LoRA for candidate: {candidate_block_name}")
    logger.debug(f"All keys in extra_lora_dict: {list(extra_lora_dict.keys())[:10]}...")  # Show first 10 keys
    logger.debug(f"Found single_transformer_blocks keys: {single_blocks_keys[:5]}...")  # Show first 5 keys

    if single_blocks_keys:
        logger.debug(f"Found single_transformer_blocks LoRA keys, preprocessing for candidate: {candidate_block_name}")

        # The candidate_block_name is already "single_transformer_blocks.0"
        # Look for linear1 and linear2 keys with this exact name
        linear1_lora_A_key = f"{candidate_block_name}.linear1.lora_A.weight"
        linear1_lora_B_key = f"{candidate_block_name}.linear1.lora_B.weight"
        linear2_lora_A_key = f"{candidate_block_name}.linear2.lora_A.weight"
        linear2_lora_B_key = f"{candidate_block_name}.linear2.lora_B.weight"

        logger.debug(f"Looking for keys: {linear1_lora_B_key}")
        logger.debug(
            f"Available keys matching pattern: {[k for k in extra_lora_dict.keys() if candidate_block_name in k][:5]}..."
        )

        if linear1_lora_B_key in extra_lora_dict:
            linear1_lora_A = extra_lora_dict[linear1_lora_A_key]
            linear1_lora_B = extra_lora_dict[linear1_lora_B_key]

            # Check if this is the problematic 21504 dimension case
            # (21504 = 12288 MLP dims + 9216 fused-QKV dims; see docstring).
            if linear1_lora_B.shape[0] == 21504:
                logger.debug(
                    f"Splitting linear1 LoRA weights: [21504, {linear1_lora_B.shape[1]}] -> "
                    f"mlp_fc1 [12288, {linear1_lora_B.shape[1]}] + qkv_proj [9216, {linear1_lora_B.shape[1]}]"
                )

                # Split linear1.lora_B [21504, rank] into two parts:
                # 1. First 12288 dimensions -> mlp_fc1
                # 2. Last 9216 dimensions (12288:21504) -> qkv_proj
                mlp_fc1_lora_B = linear1_lora_B[:12288, :].clone()
                qkv_proj_lora_B = linear1_lora_B[12288:21504, :].clone()

                # The lora_A weight is reused for both new layers
                # since it represents the down-projection from the input
                mlp_fc1_lora_A = linear1_lora_A.clone()
                qkv_proj_lora_A = linear1_lora_A.clone()

                # Map to new architecture:
                # 1. proj_mlp corresponds to mlp_fc1
                processed_dict[f"{candidate_block_name}.proj_mlp.lora_A.weight"] = mlp_fc1_lora_A
                processed_dict[f"{candidate_block_name}.proj_mlp.lora_B.weight"] = mlp_fc1_lora_B

                # 2. Map the QKV part to the attention layers
                # Note: In the new architecture, this maps to attn.to_q, attn.to_k, attn.to_v
                # which get fused into qkv_proj during the conversion
                # NOTE(review): the 3072-sized slices assume a head dim layout of
                # 3 x 3072 for Q/K/V within the fused 9216 dims — TODO confirm
                # against the target model config.
                processed_dict[f"{candidate_block_name}.attn.to_q.lora_A.weight"] = qkv_proj_lora_A
                processed_dict[f"{candidate_block_name}.attn.to_q.lora_B.weight"] = qkv_proj_lora_B[
                    :3072, :
                ]  # Q projection
                processed_dict[f"{candidate_block_name}.attn.to_k.lora_A.weight"] = qkv_proj_lora_A
                processed_dict[f"{candidate_block_name}.attn.to_k.lora_B.weight"] = qkv_proj_lora_B[
                    3072:6144, :
                ]  # K projection
                processed_dict[f"{candidate_block_name}.attn.to_v.lora_A.weight"] = qkv_proj_lora_A
                processed_dict[f"{candidate_block_name}.attn.to_v.lora_B.weight"] = qkv_proj_lora_B[
                    6144:9216, :
                ]  # V projection

                # Handle linear2 -> mlp_fc2 mapping
                if linear2_lora_B_key in extra_lora_dict:
                    linear2_lora_A = extra_lora_dict[linear2_lora_A_key]
                    linear2_lora_B = extra_lora_dict[linear2_lora_B_key]

                    # Map linear2 to proj_out.linears.1 (mlp_fc2)
                    processed_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = linear2_lora_A
                    processed_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = linear2_lora_B

                    # Remove original keys (pop with default: keys may already be absent)
                    processed_dict.pop(linear2_lora_A_key, None)
                    processed_dict.pop(linear2_lora_B_key, None)

                # Remove original linear1 keys now that they have been re-mapped
                processed_dict.pop(linear1_lora_A_key, None)
                processed_dict.pop(linear1_lora_B_key, None)

    return processed_dict
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    """
    Convert LoRA weights for a single FLUX transformer block from Diffusers to Nunchaku format.

    This function merges and converts LoRA weights from the original SVDQuant low-rank branch and an
    extra LoRA dictionary for a given transformer block, producing a Nunchaku-compatible dictionary.
    It handles both fused and unfused LoRA branches (e.g., qkv), and merges multiple LoRA branches as needed.

    Parameters
    ----------
    orig_state_dict : dict[str, torch.Tensor]
        Original state dict with LoRA weights, keys like ``"{block}.{local}.lora_down"`` and ``"{block}.{local}.lora_up"``.
    extra_lora_dict : dict[str, torch.Tensor]
        Extra LoRA weights, keys like ``"{block}.{local}.lora_A.weight"`` and ``"{block}.{local}.lora_B.weight"``.
        Mutated in place: a fused ``proj_out`` LoRA is split into ``proj_out.linears.{0,1}`` keys.
    converted_block_name : str
        Block name for output (e.g., ``"transformer_blocks.0"``).
    candidate_block_name : str
        Block name for input lookup (e.g., ``"blocks.0"``).
    default_dtype : torch.dtype, optional
        Output tensor dtype (default: ``torch.bfloat16``).

    Returns
    -------
    dict[str, torch.Tensor]
        A dictionary containing the converted LoRA weights in Nunchaku format.

    Notes
    -----
    - If both original and extra LoRA weights are present, they are merged by concatenation.
    - Handles both fused and unfused attention projections (e.g., qkv).
    - Applies special packing for W4A16 linear layers (e.g., ``"adanorm_single"`` and ``"adanorm_zero"``).
    """
    if f"{candidate_block_name}.proj_out.lora_A.weight" in extra_lora_dict:
        # In the quantized single block, the fused proj_out consumes the
        # concatenation of the attention output (n1 columns) and the MLP
        # output (n2 columns). Split the fused down-projection accordingly;
        # the up-projection is shared by both halves since their outputs sum.
        assert f"{converted_block_name}.out_proj.qweight" in orig_state_dict
        assert f"{converted_block_name}.mlp_fc2.qweight" in orig_state_dict
        # presumably qweight packs two 4-bit values per byte, so "* 2"
        # recovers the logical input width — TODO confirm against packer.
        n1 = orig_state_dict[f"{converted_block_name}.out_proj.qweight"].shape[1] * 2
        n2 = orig_state_dict[f"{converted_block_name}.mlp_fc2.qweight"].shape[1] * 2
        lora_down = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_A.weight"]
        lora_up = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_B.weight"]
        assert lora_down.shape[1] == n1 + n2
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_A.weight"] = lora_down[:, :n1].clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_B.weight"] = lora_up.clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = lora_down[:, n1:].clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = lora_up.clone()
        extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
        extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")

        # Align the LoRA ranks of proj_mlp (fc1) and proj_out.linears.1 (fc2):
        # they are fused downstream, so the lower-rank side is zero-padded.
        for component in ["lora_A", "lora_B"]:
            fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight"
            fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight"
            # Fix: a LoRA may train proj_out without proj_mlp; the previous
            # unconditional indexing raised KeyError in that case. Skip rank
            # alignment when either side is absent.
            if fc1_k not in extra_lora_dict or fc2_k not in extra_lora_dict:
                continue
            fc1_v = extra_lora_dict[fc1_k]
            fc2_v = extra_lora_dict[fc2_k]
            # lora_A stacks ranks along dim 0, lora_B along dim 1.
            dim = 0 if "lora_A" in fc1_k else 1

            fc1_rank = fc1_v.shape[dim]
            fc2_rank = fc2_v.shape[dim]
            if fc1_rank != fc2_rank:
                rank = max(fc1_rank, fc2_rank)
                if fc1_rank < rank:
                    extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
                if fc2_rank < rank:
                    extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)

    return convert_to_nunchaku_transformer_block_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
        converted_block_name=converted_block_name,
        candidate_block_name=candidate_block_name,
        local_name_map={
            "norm.linear": "norm.linear",
            "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
            "norm_q": "attn.norm_q",
            "norm_k": "attn.norm_k",
            "out_proj": "proj_out.linears.0",
            "mlp_fc1": "proj_mlp",
            "mlp_fc2": "proj_out.linears.1",
        },
        convert_map={
            "norm.linear": "adanorm_single",
            "qkv_proj": "linear",
            "out_proj": "linear",
            "mlp_fc1": "linear",
            "mlp_fc2": "linear",
        },
        default_dtype=default_dtype,
    )
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    """
    Convert LoRA weights for a single (double-stream) transformer block from Diffusers to Nunchaku format.

    Thin wrapper around :func:`convert_to_nunchaku_transformer_block_lowrank_dict`
    that supplies the name and conversion maps for FLUX double-stream blocks.

    Parameters
    ----------
    orig_state_dict : dict[str, torch.Tensor]
        Original model state dict.
    extra_lora_dict : dict[str, torch.Tensor]
        LoRA weights state dict.
    converted_block_name : str
        Output block name for the converted weights.
    candidate_block_name : str
        Input block name for lookup.
    default_dtype : torch.dtype, optional
        Output tensor dtype (default: torch.bfloat16).

    Returns
    -------
    dict[str, torch.Tensor]
        Converted LoRA weights in Nunchaku format.
    """
    # Nunchaku local name -> Diffusers local name(s); list values are fused
    # (q/k/v projections are concatenated into a single qkv projection).
    nunchaku_to_diffusers = {
        "norm1.linear": "norm1.linear",
        "norm1_context.linear": "norm1_context.linear",
        "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
        "qkv_proj_context": ["attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj"],
        "norm_q": "attn.norm_q",
        "norm_k": "attn.norm_k",
        "norm_added_q": "attn.norm_added_q",
        "norm_added_k": "attn.norm_added_k",
        "out_proj": "attn.to_out.0",
        "out_proj_context": "attn.to_add_out",
        "mlp_fc1": "ff.net.0.proj",
        "mlp_fc2": "ff.net.2",
        "mlp_context_fc1": "ff_context.net.0.proj",
        "mlp_context_fc2": "ff_context.net.2",
    }
    # How each Nunchaku layer is packed; adanorm layers need special reordering.
    layer_kinds = {
        "norm1.linear": "adanorm_zero",
        "norm1_context.linear": "adanorm_zero",
        "qkv_proj": "linear",
        "qkv_proj_context": "linear",
        "out_proj": "linear",
        "out_proj_context": "linear",
        "mlp_fc1": "linear",
        "mlp_fc2": "linear",
        "mlp_context_fc1": "linear",
        "mlp_context_fc2": "linear",
    }
    return convert_to_nunchaku_transformer_block_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
        converted_block_name=converted_block_name,
        candidate_block_name=candidate_block_name,
        local_name_map=nunchaku_to_diffusers,
        convert_map=layer_kinds,
        default_dtype=default_dtype,
    )
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def convert_to_nunchaku_flux_lowrank_dict(
    base_model: dict[str, torch.Tensor] | str,
    lora: dict[str, torch.Tensor] | str,
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    """
    Convert a base model and LoRA weights from Diffusers format to Nunchaku format.

    Parameters
    ----------
    base_model : dict[str, torch.Tensor] or str
        Base model weights or path to safetensors file.
    lora : dict[str, torch.Tensor] or str
        LoRA weights or path to safetensors file.
    default_dtype : torch.dtype, optional
        Output tensor dtype (default: torch.bfloat16).

    Returns
    -------
    dict[str, torch.Tensor]
        LoRA weights in Nunchaku format.
    """
    if isinstance(base_model, str):
        orig_state_dict = load_state_dict_in_safetensors(base_model)
    else:
        orig_state_dict = base_model

    if isinstance(lora, str):
        # Load the LoRA - check if it has a "transformer." prefix
        temp_dict = load_state_dict_in_safetensors(lora)
        if any(k.startswith("transformer.") for k in temp_dict.keys()):
            # Standard FLUX LoRA with transformer prefix
            extra_lora_dict = filter_state_dict(temp_dict, filter_prefix="transformer.")
            # Fix: use removeprefix instead of str.replace so that only the
            # leading "transformer." is stripped — replace() would also mangle
            # any later occurrence of the substring inside a key.
            extra_lora_dict = {k.removeprefix("transformer."): v for k, v in extra_lora_dict.items()}
        else:
            # Kontext LoRA without transformer prefix - use as is
            extra_lora_dict = temp_dict
    else:
        # When called from to_nunchaku, lora is already processed by to_diffusers
        # Keys should be in format: single_blocks.0.linear1.lora_A.weight
        # Add the "transformer." prefix and rename legacy block names.
        renamed_dict = {}
        for k, v in lora.items():
            if k.startswith("single_blocks."):
                new_k = "transformer.single_transformer_blocks." + k.removeprefix("single_blocks.")
            elif k.startswith("double_blocks."):
                new_k = "transformer.transformer_blocks." + k.removeprefix("double_blocks.")
            elif not k.startswith("transformer."):
                # Covers proj_out.* and any other unprefixed key.
                new_k = "transformer." + k
            else:
                new_k = k
            renamed_dict[new_k] = v

        # Keep only transformer weights, then strip the prefix for internal processing.
        extra_lora_dict = filter_state_dict(renamed_dict, filter_prefix="transformer.")
        extra_lora_dict = {k.removeprefix("transformer."): v for k, v in extra_lora_dict.items()}

    # Separate 1-D tensors (bias vectors) and weights of unquantized modules
    # (e.g. final_layer) from the block LoRA weights that get packed below.
    vector_dict, unquantized_lora_dict = {}, {}
    for k in list(extra_lora_dict.keys()):
        v = extra_lora_dict[k]
        if v.ndim == 1:
            vector_dict[k.replace(".lora_B.bias", ".bias")] = extra_lora_dict.pop(k)
        elif "transformer_blocks" not in k and "single_transformer_blocks" not in k:
            # Only unquantized parts (like final_layer) go here
            unquantized_lora_dict[k] = extra_lora_dict.pop(k)

    # Concatenate separate q/k/v biases into a fused qkv bias if present.
    for k in list(vector_dict.keys()):
        if ".to_q." in k or ".add_q_proj." in k:
            k_q = k
            k_k = k.replace(".to_q.", ".to_k.").replace(".add_q_proj.", ".add_k_proj.")
            k_v = k.replace(".to_q.", ".to_v.").replace(".add_q_proj.", ".add_v_proj.")
            values = [vector_dict.pop(key) for key in (k_q, k_k, k_v)]
            new_k = k_q.replace(".to_q.", ".to_qkv.").replace(".add_q_proj.", ".add_qkv_proj.")
            vector_dict[new_k] = torch.cat(values, dim=0)

    # Align the LoRA ranks of paired MLP fc1/fc2 layers (they are fused
    # downstream) by zero-padding the lower-rank side.
    for k in extra_lora_dict.keys():
        fc1_k = k
        if "ff.net.0.proj" in k:
            fc2_k = k.replace("ff.net.0.proj", "ff.net.2")
        elif "ff_context.net.0.proj" in k:
            fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2")
        else:
            continue
        assert fc2_k in extra_lora_dict
        fc1_v = extra_lora_dict[fc1_k]
        fc2_v = extra_lora_dict[fc2_k]
        # lora_A stacks ranks along dim 0, lora_B along dim 1.
        dim = 0 if "lora_A" in fc1_k else 1

        fc1_rank = fc1_v.shape[dim]
        fc2_rank = fc2_v.shape[dim]
        if fc1_rank != fc2_rank:
            rank = max(fc1_rank, fc2_rank)
            if fc1_rank < rank:
                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
            if fc2_rank < rank:
                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)

    # Collect block names present in the base model and convert each block.
    block_names: set[str] = set()
    for param_name in orig_state_dict.keys():
        if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
            block_names.add(".".join(param_name.split(".")[:2]))
    # Fix: keep the sorted list in its own name instead of rebinding the
    # set[str]-annotated variable to a list.
    sorted_block_names = sorted(block_names, key=lambda x: (x.split(".")[0], int(x.split(".")[-1])))
    logger.debug(f"Converting {len(sorted_block_names)} transformer blocks...")
    converted: dict[str, torch.Tensor] = {}
    for block_name in tqdm(sorted_block_names, dynamic_ncols=True, desc="Converting LoRAs to nunchaku format"):
        if block_name.startswith("transformer_blocks"):
            convert_fn = convert_to_nunchaku_flux_transformer_block_lowrank_dict
        else:
            convert_fn = convert_to_nunchaku_flux_single_transformer_block_lowrank_dict
        update_state_dict(
            converted,
            convert_fn(
                orig_state_dict=orig_state_dict,
                extra_lora_dict=extra_lora_dict,
                converted_block_name=block_name,
                candidate_block_name=block_name,
                default_dtype=default_dtype,
            ),
            prefix=block_name,
        )

    converted.update(unquantized_lora_dict)
    converted.update(vector_dict)
    return converted
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def to_nunchaku(
    input_lora: str | dict[str, torch.Tensor],
    base_sd: str | dict[str, torch.Tensor],
    dtype: str | torch.dtype = torch.bfloat16,
    output_path: str | None = None,
) -> dict[str, torch.Tensor]:
    """
    Convert LoRA weights to Nunchaku format.

    Parameters
    ----------
    input_lora : str or dict[str, torch.Tensor]
        Path or dictionary of LoRA weights in Diffusers format. Can be composed of multiple LoRA weights.
    base_sd : str or dict[str, torch.Tensor]
        Path or dictionary of base quantized model weights.
    dtype : str or torch.dtype, optional
        Output data type ("bfloat16", "float16", or torch dtype). Default is torch.bfloat16.
    output_path : str, optional
        If provided, saves the result to this path.

    Returns
    -------
    dict[str, torch.Tensor]
        LoRA weights in Nunchaku format.

    Example
    -------
    .. code-block:: python

        nunchaku_weights = to_nunchaku("lora.safetensors", "base_model.safetensors")
        nunchaku_weights = to_nunchaku(lora_dict, base_dict)
    """
    # Materialize the LoRA weights regardless of whether a path or dict was given.
    tensors = load_state_dict_in_safetensors(input_lora, device="cpu") if isinstance(input_lora, str) else input_lora

    if is_nunchaku_format(tensors):
        # Nothing to do: pass the weights through unchanged.
        logger.debug("Already in nunchaku format, no conversion needed.")
        converted = tensors
    else:
        extra_lora_dict = to_diffusers(tensors)

        orig_state_dict = load_state_dict_in_safetensors(base_sd) if isinstance(base_sd, str) else base_sd

        # Resolve string dtype names to torch dtypes via a lookup table.
        if isinstance(dtype, str):
            name_to_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16}
            if dtype not in name_to_dtype:
                raise ValueError(f"Unsupported dtype {dtype}.")
            dtype = name_to_dtype[dtype]
        else:
            assert isinstance(dtype, torch.dtype)

        converted = convert_to_nunchaku_flux_lowrank_dict(
            base_model=orig_state_dict, lora=extra_lora_dict, default_dtype=dtype
        )

    if output_path is not None:
        # Create the target directory if needed before writing the safetensors file.
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
        save_file(converted, output_path)
    return converted
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
#### fuse vectors ####
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
def fuse_vectors(
    vectors: dict[str, torch.Tensor], base_sd: dict[str, torch.Tensor], strength: float = 1
) -> dict[str, torch.Tensor]:
    """
    Fuse vector (bias) terms from LoRA into the base model.

    Parameters
    ----------
    vectors : dict[str, torch.Tensor]
        LoRA vector terms, keyed by Diffusers-style names
        (e.g. ``"...attn.to_qkv.bias"``, ``"...attn.norm_q.weight"``).
    base_sd : dict[str, torch.Tensor]
        Base model state dict; only its 1-D entries are considered.
    strength : float, optional
        Scaling factor for LoRA vectors.

    Returns
    -------
    dict[str, torch.Tensor]
        State dict with fused vectors. Only keys that pass the filters below
        appear in the output; a base key whose name matches no pattern, or
        whose pattern matched but had no LoRA counterpart, may be absent.
        NOTE(review): the last branch only writes ``tensors[k]`` when a LoRA
        diff exists — looks intentional (caller keeps the base value), but
        verify against the caller.
    """
    tensors: dict[str, torch.Tensor] = {}
    # Packer is only used to pad/pack quantization-scale-shaped vectors in the
    # final branch below.
    packer = NunchakuWeightPacker(bits=4)
    for k, v in base_sd.items():
        # Consider only 1-D (vector) entries; skip smoothing factors and the
        # single-block mlp_fc2 vectors.
        if v.ndim != 1 or "smooth" in k or (k.startswith("single_transformer_blocks.") and ".mlp_fc2." in k):
            continue
        if "norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k:
            # Q/K norm weights: the LoRA value (if any) REPLACES the base
            # value rather than being added; `strength` is not applied here.
            new_k = k.replace(".norm_", ".attn.norm_")
            new_v = vectors.get(new_k, None)
            tensors[k] = v if new_v is None else new_v

        elif "norm.linear" in k or "norm1.linear" in k or "norm1_context.linear" in k:
            # AdaNorm modulation biases: add the scaled LoRA diff after
            # reordering it from split-major to interleaved layout.
            diff = vectors.get(k, None)

            if diff is not None:
                if k.startswith("single_transformer_blocks."):
                    # Single blocks use 3 modulation splits.
                    adanorm_splits = 3
                else:
                    assert k.startswith("transformer_blocks.")
                    # Double blocks use 6 modulation splits.
                    adanorm_splits = 6
                # (splits, dim) -> interleave across splits -> flat vector.
                diff = diff.view(adanorm_splits, -1).transpose(0, 1).reshape(-1)
                tensors[k] = v + diff * strength
            else:
                tensors[k] = v

        else:
            # Remaining linear-layer biases: map the Nunchaku key back to its
            # Diffusers name to look up the LoRA diff.
            if k.startswith("single_transformer_blocks."):
                name_map = {".qkv_proj.": ".attn.to_qkv.", ".out_proj.": ".proj_out.", ".mlp_fc1.": ".proj_mlp."}
            else:
                assert k.startswith("transformer_blocks.")
                name_map = {
                    ".qkv_proj.": ".attn.to_qkv.",
                    ".qkv_proj_context.": ".attn.add_qkv_proj.",
                    ".out_proj.": ".attn.to_out.0.",
                    ".out_proj_context.": ".attn.to_add_out.",
                    ".mlp_fc1.": ".ff.net.0.proj.",
                    ".mlp_fc2.": ".ff.net.2.",
                    ".mlp_context_fc1.": ".ff_context.net.0.proj.",
                    ".mlp_context_fc2.": ".ff_context.net.2.",
                }

            for original_pattern, new_pattern in name_map.items():
                if original_pattern in k:
                    new_k = k.replace(original_pattern, new_pattern)
                    diff = vectors.get(new_k, None)
                    if diff is not None:
                        diff = diff * strength
                        # presumably the base vector is stored in the packer's
                        # padded/packed scale layout, so the diff must be
                        # brought into the same layout before adding — verify
                        # against NunchakuWeightPacker.
                        diff = packer.pad_scale(diff, group_size=-1)
                        diff = packer.pack_scale(diff, group_size=-1)
                        tensors[k] = v + diff
                    break

    return tensors
|
nunchaku/lora/flux/packer.py
ADDED
|
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Weight packing utilities for Nunchaku quantization.
|
| 3 |
+
|
| 4 |
+
This module provides concise tools for packing and unpacking weight tensors,
|
| 5 |
+
optimized for efficient GPU computation using Matrix Multiply and Accumulate (MMA) operations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from ...utils import ceil_divide
|
| 11 |
+
from .utils import pad
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MmaWeightPackerBase:
|
| 15 |
+
"""
|
| 16 |
+
Base class for Matrix Multiply and Accumulate (MMA) weight packing.
|
| 17 |
+
|
| 18 |
+
Packs weight tensors for efficient GPU computation using MMA operations.
|
| 19 |
+
Handles tile sizes, memory layout, and packing parameters.
|
| 20 |
+
|
| 21 |
+
Parameters
|
| 22 |
+
----------
|
| 23 |
+
bits : int
|
| 24 |
+
Quantization bits. Must be 1, 4, 8, 16, or 32.
|
| 25 |
+
warp_n : int
|
| 26 |
+
Warp size in the n dimension.
|
| 27 |
+
comp_n : int, optional
|
| 28 |
+
Computation tile size in n (default: 16).
|
| 29 |
+
comp_k : int, optional
|
| 30 |
+
Computation tile size in k (default: 256 // bits).
|
| 31 |
+
|
| 32 |
+
Raises
|
| 33 |
+
------
|
| 34 |
+
AssertionError
|
| 35 |
+
If bits or tile/pack sizes are invalid.
|
| 36 |
+
|
| 37 |
+
Attributes
|
| 38 |
+
----------
|
| 39 |
+
comp_n : int
|
| 40 |
+
Tile size in n for MMA computation.
|
| 41 |
+
comp_k : int
|
| 42 |
+
Tile size in k for MMA computation.
|
| 43 |
+
insn_n : int
|
| 44 |
+
MMA instruction tile size in n.
|
| 45 |
+
insn_k : int
|
| 46 |
+
MMA instruction tile size in k.
|
| 47 |
+
num_lanes : int
|
| 48 |
+
Number of lanes (threads) in a warp.
|
| 49 |
+
num_k_lanes : int
|
| 50 |
+
Number of lanes in k.
|
| 51 |
+
num_n_lanes : int
|
| 52 |
+
Number of lanes in n.
|
| 53 |
+
warp_n : int
|
| 54 |
+
Warp size in n.
|
| 55 |
+
reg_k : int
|
| 56 |
+
Elements in a register in k.
|
| 57 |
+
reg_n : int
|
| 58 |
+
Elements in a register in n.
|
| 59 |
+
k_pack_size : int
|
| 60 |
+
Elements in a pack in k.
|
| 61 |
+
n_pack_size : int
|
| 62 |
+
Elements in a pack in n.
|
| 63 |
+
pack_size : int
|
| 64 |
+
Elements in a pack accessed by a lane.
|
| 65 |
+
mem_k : int
|
| 66 |
+
Tile size in k for one memory access.
|
| 67 |
+
mem_n : int
|
| 68 |
+
Tile size in n for one memory access.
|
| 69 |
+
num_k_packs : int
|
| 70 |
+
Packs in k for one memory access.
|
| 71 |
+
num_n_packs : int
|
| 72 |
+
Packs in n for one memory access.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
    def __init__(self, bits: int, warp_n: int, comp_n: int | None = None, comp_k: int | None = None):
        # Derive all MMA tile/pack parameters from the bit width and warp size.
        self.bits = bits
        assert self.bits in (1, 4, 8, 16, 32), "weight bits should be 1, 4, 8, 16, or 32."

        # region compute tile size
        self.comp_n = comp_n if comp_n is not None else 16
        # smallest tile size in `n` dimension for MMA computation.
        self.comp_k = comp_k if comp_k is not None else 256 // self.bits
        # smallest tile size in `k` dimension for MMA computation.
        # the smallest MMA computation may contain several MMA instructions
        self.insn_n = 8  # mma instruction tile size in `n` dimension
        # tile size in `n` dimension for MMA instruction.
        self.insn_k = self.comp_k
        # tile size in `k` dimension for MMA instruction.
        # insn_k * bits is the per-instruction k-extent in bits; only 128b/256b layouts are supported.
        assert self.insn_k * self.bits in (
            128,
            256,
        ), f"insn_k ({self.insn_k}) * bits ({self.bits}) should be 128 or 256."
        assert self.comp_n % self.insn_n == 0, f"comp_n ({self.comp_n}) should be divisible by insn_n ({self.insn_n})."
        self.num_lanes = 32
        # there are 32 lanes (or threads) in a warp.
        self.num_k_lanes = 4
        self.num_n_lanes = 8
        # num_k_lanes * num_n_lanes == num_lanes: the 32 lanes form an 8x4 (n x k) grid.
        assert (
            warp_n >= self.comp_n and warp_n % self.comp_n == 0
        ), f"warp_n ({warp_n}) should be divisible by comp_n({self.comp_n})."
        self.warp_n = warp_n
        # endregion
        # region memory
        self.reg_k = 32 // self.bits
        # number of elements in a register in `k` dimension (one 32-bit register holds 32/bits elements).
        self.reg_n = 1
        # number of elements in a register in `n` dimension (always 1).
        self.k_pack_size = self.comp_k // (self.num_k_lanes * self.reg_k)
        # number of elements in a pack in `k` dimension.
        self.n_pack_size = self.comp_n // (self.num_n_lanes * self.reg_n)
        # number of elements in a pack in `n` dimension.
        self.pack_size = self.k_pack_size * self.n_pack_size
        # number of elements in a pack accessed by a lane at a time.
        assert 1 <= self.pack_size <= 4, "pack size should be less than or equal to 4."
        # Sanity: packs * lanes * registers must exactly tile the computation tile.
        assert self.k_pack_size * self.num_k_lanes * self.reg_k == self.comp_k
        assert self.n_pack_size * self.num_n_lanes * self.reg_n == self.comp_n
        self.mem_k = self.comp_k
        # the tile size in `k` dimension for one tensor memory access.
        self.mem_n = warp_n
        # the tile size in `n` dimension for one tensor memory access.
        self.num_k_packs = self.mem_k // (self.k_pack_size * self.num_k_lanes * self.reg_k)
        # number of packs in `k` dimension for one tensor memory access.
        self.num_n_packs = self.mem_n // (self.n_pack_size * self.num_n_lanes * self.reg_n)
        # number of packs in `n` dimension for one tensor memory access.
        # endregion
|
| 126 |
+
|
| 127 |
+
def get_view_shape(self, n: int, k: int) -> tuple[int, int, int, int, int, int, int, int, int, int]:
|
| 128 |
+
"""
|
| 129 |
+
Returns the tensor view shape for MMA operations.
|
| 130 |
+
|
| 131 |
+
Parameters
|
| 132 |
+
----------
|
| 133 |
+
n : int
|
| 134 |
+
Output channel size (must be divisible by mem_n).
|
| 135 |
+
k : int
|
| 136 |
+
Input channel size (must be divisible by mem_k).
|
| 137 |
+
|
| 138 |
+
Returns
|
| 139 |
+
-------
|
| 140 |
+
tuple of int
|
| 141 |
+
(n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n,
|
| 142 |
+
k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)
|
| 143 |
+
|
| 144 |
+
Raises
|
| 145 |
+
------
|
| 146 |
+
AssertionError
|
| 147 |
+
If n or k is not divisible by mem_n or mem_k.
|
| 148 |
+
"""
|
| 149 |
+
assert n % self.mem_n == 0, "output channel size should be divisible by mem_n."
|
| 150 |
+
assert k % self.mem_k == 0, "input channel size should be divisible by mem_k."
|
| 151 |
+
return (
|
| 152 |
+
n // self.mem_n,
|
| 153 |
+
self.num_n_packs,
|
| 154 |
+
self.n_pack_size,
|
| 155 |
+
self.num_n_lanes,
|
| 156 |
+
self.reg_n,
|
| 157 |
+
k // self.mem_k,
|
| 158 |
+
self.num_k_packs,
|
| 159 |
+
self.k_pack_size,
|
| 160 |
+
self.num_k_lanes,
|
| 161 |
+
self.reg_k,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class NunchakuWeightPacker(MmaWeightPackerBase):
    """
    Nunchaku-specific weight packer. Provides Nunchaku-specific packing of
    quantized weights, scales, and low-rank weights.

    Parameters
    ----------
    bits : int
        Number of quantization bits. Must be 1, 4, 8, 16, or 32.
    warp_n : int, optional
        Warp size in the n dimension. Default is 128.

    Attributes
    ----------
    num_k_unrolls : int
        Number of unrolls in the k dimension (always 2 for Nunchaku).
    """

    def __init__(self, bits: int, warp_n: int = 128):
        super().__init__(bits=bits, warp_n=warp_n)
        # The Nunchaku kernel always unrolls the k loop twice.
        self.num_k_unrolls = 2

    def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """
        Pack quantized weight tensor for Nunchaku MMA.

        Reshapes the ``(n, k)`` int32 tensor into the 10-D lane/pack layout,
        permutes it into memory-access order, then bit-packs `reg_k` values
        per 32-bit word (4-bit: 8 values/word; 8-bit: 4 values/word).

        Parameters
        ----------
        weight : torch.Tensor
            Quantized weight tensor of dtype torch.int32 and shape (n, k).

        Returns
        -------
        torch.Tensor
            Packed weight tensor of dtype torch.int8.

        Raises
        ------
        NotImplementedError
            If ``bits`` is not 4 or 8.
        """
        assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
        n, k = weight.shape
        assert n % self.mem_n == 0, f"output channel size ({n}) should be divisible by mem_n ({self.mem_n})."
        # currently, Nunchaku does not check the boundary of the unrolled `k` dimension
        assert k % (self.mem_k * self.num_k_unrolls) == 0, (
            f"input channel size ({k}) should be divisible by "
            f"mem_k ({self.mem_k}) * num_k_unrolls ({self.num_k_unrolls})."
        )
        n_tiles, k_tiles = n // self.mem_n, k // self.mem_k
        weight = weight.reshape(
            n_tiles,
            self.num_n_packs,  # 8 when warp_n = 128
            self.n_pack_size,  # always 2 in nunchaku
            self.num_n_lanes,  # constant 8
            self.reg_n,  # constant 1
            k_tiles,
            self.num_k_packs,  # 1
            self.k_pack_size,  # always 2 in nunchaku
            self.num_k_lanes,  # constant 4
            self.reg_k,  # always 8 = 32 bits / 4 bits
        )
        # (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n, k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)
        # =>
        # (n_tiles, k_tiles, num_k_packs, num_n_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        weight = weight.permute(0, 5, 6, 1, 3, 8, 2, 7, 4, 9).contiguous()
        assert weight.shape[4:-2] == (8, 4, 2, 2)
        if self.bits == 4:
            # Mask each value to 4 bits, shift lane i by 4*i, and sum to fuse
            # 8 nibbles into one int32 word (in-place bitwise ops on the copy).
            weight = weight.bitwise_and_(0xF)
            shift = torch.arange(0, 32, 4, dtype=torch.int32, device=weight.device)
            weight = weight.bitwise_left_shift_(shift)
            weight = weight.sum(dim=-1, dtype=torch.int32)
        elif self.bits == 8:
            # Same scheme with bytes: 4 values per int32 word.
            weight = weight.bitwise_and_(0xFF)
            shift = torch.arange(0, 32, 8, dtype=torch.int32, device=weight.device)
            weight = weight.bitwise_left_shift_(shift)
            weight = weight.sum(dim=-1, dtype=torch.int32)
        else:
            raise NotImplementedError(f"weight bits {self.bits} is not supported.")
        return weight.view(dtype=torch.int8).view(n, -1)  # assume little-endian

    def pack_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
        """
        Pack scale tensor for Nunchaku MMA.

        Delegates to :meth:`pack_micro_scale` when the group size requires the
        micro-scaled (fp8) layout; otherwise packs fp16/bf16 scales into the
        per-lane broadcast order expected by the kernel.

        Parameters
        ----------
        scale : torch.Tensor
            Scale tensor of dtype torch.float16 or torch.bfloat16.
        group_size : int
            Group size for quantization. ``-1`` means per-channel scales.

        Returns
        -------
        torch.Tensor
            Packed scale tensor.
        """
        if self.check_if_micro_scale(group_size=group_size):
            return self.pack_micro_scale(scale, group_size=group_size)
        # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
        assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
        n = scale.shape[0]
        # nunchaku loads scales all in one access
        # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
        # scale loading is parallelized in `n` dimension, that is,
        # `num_s_lanes` in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
        # each element in `n` dimension is 16 bit as it contains 1 fp16
        # min `s_pack_size` set to 2 elements, since each lane at least holds 2 accumulator results in `n` dimension
        # max `s_pack_size` set to 128b/16b = 8 elements
        # for `warp_n = 8`, we have
        #     `s_pack_size = 2`, `num_s_lanes = 4`, `num_s_packs = 1`
        # for `warp_n = 128`, we have
        #     `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
        # for `warp_n = 512`, we have
        #     `s_pack_size = 8`, `num_s_lanes = 32`, `num_s_packs = 2`
        s_pack_size = min(max(self.warp_n // self.num_lanes, 2), 8)
        num_s_lanes = min(self.num_lanes, self.warp_n // s_pack_size)
        num_s_packs = self.warp_n // (s_pack_size * num_s_lanes)
        warp_s = num_s_packs * num_s_lanes * s_pack_size
        assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
        # `num_n_lanes = 8 (constant)` generates 8 elements consecutive in `n` dimension
        # however, they are held by 4 lanes, each lane holds 2 elements in `n` dimension
        # thus, we start from first 4 lanes, assign 2 elements to each lane, until all 8 elements are assigned
        # we then repeat the process for the same 4 lanes, until each lane holds `s_pack_size` elements
        # finally, we move to next 4 lanes, and repeat the process until all `num_s_lanes` lanes are assigned
        # the process is repeated for `num_s_packs` times
        # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
        # wscales store order:
        #     0   1   8   9  <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #     2   3  10  11  <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
        #     4   5  12  13  <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
        #     6   7  14  15  <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
        #    16  17  24  25  <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #    ...
        #    22  23  30  31  <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
        #    ... ...
        #   112 113 120 121  <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #    ...
        #   118 119 126 127  <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
        scale = scale.reshape(n // warp_s, num_s_packs, num_s_lanes // 4, s_pack_size // 2, 4, 2, -1)
        scale = scale.permute(0, 6, 1, 2, 4, 3, 5).contiguous()
        return scale.view(-1) if group_size == -1 else scale.view(-1, n)  # the shape is just used for validation

    def pack_micro_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
        """
        Pack micro scale tensor for Nunchaku MMA.

        Converts fp16/bf16 scales (which must fit the fp8-e4m3 range of
        [-448, 448]) to fp8 and packs them into the thread-id selection
        layout required by MMA block scaling.

        Parameters
        ----------
        scale : torch.Tensor
            Scale tensor of dtype torch.float16 or torch.bfloat16.
        group_size : int
            Group size for quantization (must be 16).

        Returns
        -------
        torch.Tensor
            Packed micro scale tensor (dtype float8_e4m3fn).
        """
        assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
        assert scale.max() <= 448, "scale should be less than 448."
        assert scale.min() >= -448, "scale should be greater than -448."
        assert group_size == 16, "currently only support group size 16."
        assert self.insn_k == 64, "insn_k should be 64."
        scale = scale.to(dtype=torch.float8_e4m3fn)
        n = scale.shape[0]
        assert self.warp_n >= 32, "currently only support warp_n >= 32."
        # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
        # scale loading is parallelized in `n` dimension, that is,
        # `num_s_lanes` in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
        # each element in `n` dimension is 32 bit as it contains 4 fp8 in `k` dimension
        # min `s_pack_size` set to 1 element
        # max `s_pack_size` set to 128b/32b = 4 elements
        # for `warp_n = 128`, we have
        #     `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
        # for `warp_n = 512`, we have
        #     `s_pack_size = 8`, `num_s_lanes = 32`, `num_s_packs = 2`
        s_pack_size = min(max(self.warp_n // self.num_lanes, 1), 4)
        num_s_lanes = 4 * 8  # 32 lanes is divided into 4 pieces, each piece has 8 lanes at a stride of 4
        num_s_packs = ceil_divide(self.warp_n, s_pack_size * num_s_lanes)
        warp_s = num_s_packs * num_s_lanes * s_pack_size
        assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
        # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-scaling-thread-id-b-selection
        # we start from first 8 lanes at a stride of 4, assign 1 element to each lane, until all 8 elements are assigned
        # we then move to next 8 lanes at a stride of 4, and repeat the process until all 32 lanes are assigned
        # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
        # wscales store order:
        #     0  32  64  96  <-- load by lane 0
        #     8  40  72 104  <-- load by lane 1
        #    16  48  80 112  <-- load by lane 2
        #    24  56  88 120  <-- load by lane 3
        #     1  33  65  97  <-- load by lane 4
        #    ...
        #    25  57  81 113  <-- load by lane 7
        #    ...
        #     7  39  71 103  <-- load by lane 28
        #    ...
        #    31  63  95 127  <-- load by lane 31
        scale = scale.view(n // warp_s, num_s_packs, s_pack_size, 4, 8, -1, self.insn_k // group_size)
        scale = scale.permute(0, 5, 1, 4, 3, 2, 6).contiguous()
        return scale.view(-1, n)  # the shape is just used for validation

    def pack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        """
        Pack low-rank weight tensor.

        Pads the tensor to pack-size multiples, then interleaves it into the
        per-lane register order. The ``down`` flag selects which axis is the
        rank dimension (down projection: rank is dim 0; up: rank is dim 1).

        Parameters
        ----------
        weight : torch.Tensor
            Low-rank weight tensor of dtype torch.float16 or torch.bfloat16.
        down : bool
            If True, weight is for down projection in low-rank branch.

        Returns
        -------
        torch.Tensor
            Packed low-rank weight tensor.
        """
        assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
        reg_n, reg_k = 1, 2  # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
        pack_n = self.n_pack_size * self.num_n_lanes * reg_n
        pack_k = self.k_pack_size * self.num_k_lanes * reg_k
        weight = pad(weight, divisor=(pack_n, pack_k), dim=(0, 1))
        if down:
            r, c = weight.shape
            r_packs, c_packs = r // pack_n, c // pack_k
            weight = weight.view(r_packs, pack_n, c_packs, pack_k).permute(2, 0, 1, 3)
        else:
            c, r = weight.shape
            c_packs, r_packs = c // pack_n, r // pack_k
            weight = weight.view(c_packs, pack_n, r_packs, pack_k).permute(0, 2, 1, 3)
        weight = weight.reshape(
            c_packs, r_packs, self.n_pack_size, self.num_n_lanes, reg_n, self.k_pack_size, self.num_k_lanes, reg_k
        )
        # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
        # =>
        # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        weight = weight.permute(0, 1, 3, 6, 2, 5, 4, 7).contiguous()
        return weight.view(c, r)

    def unpack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        """
        Unpack low-rank weight tensor.

        Exact inverse of :meth:`pack_lowrank_weight`: undoes the per-lane
        register interleaving and restores the original (padded) layout.

        Parameters
        ----------
        weight : torch.Tensor
            Packed low-rank weight tensor of dtype torch.float16 or torch.bfloat16.
        down : bool
            If True, weight is for down projection in low-rank branch.

        Returns
        -------
        torch.Tensor
            Unpacked low-rank weight tensor.
        """
        c, r = weight.shape
        assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
        reg_n, reg_k = 1, 2  # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
        pack_n = self.n_pack_size * self.num_n_lanes * reg_n
        pack_k = self.k_pack_size * self.num_k_lanes * reg_k
        if down:
            r_packs, c_packs = r // pack_n, c // pack_k
        else:
            c_packs, r_packs = c // pack_n, r // pack_k
        weight = weight.view(
            c_packs, r_packs, self.num_n_lanes, self.num_k_lanes, self.n_pack_size, self.k_pack_size, reg_n, reg_k
        )
        # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        # =>
        # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
        weight = weight.permute(0, 1, 4, 2, 6, 5, 3, 7).contiguous()
        weight = weight.view(c_packs, r_packs, pack_n, pack_k)
        if down:
            weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
        else:
            weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
        return weight

    def check_if_micro_scale(self, group_size: int) -> bool:
        """
        Check if micro scale packing is required.

        Parameters
        ----------
        group_size : int
            Group size for quantization.

        Returns
        -------
        bool
            True if micro scale packing is required, i.e. when four
            quantization groups fit in one MMA instruction k-tile.
        """
        return self.insn_k == group_size * 4

    def pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """
        Pad weight tensor to required shape.

        Pads ``n`` to a multiple of ``mem_n`` and ``k`` to a multiple of
        ``mem_k * num_k_unrolls``, zero-filling the new entries.

        Parameters
        ----------
        weight : torch.Tensor
            Weight tensor of shape (n, k).

        Returns
        -------
        torch.Tensor
            Padded weight tensor.
        """
        assert weight.ndim == 2, "weight tensor should be 2D."
        return pad(weight, divisor=(self.mem_n, self.mem_k * self.num_k_unrolls), dim=(0, 1))

    def pad_scale(self, scale: torch.Tensor, group_size: int, fill_value: float = 0) -> torch.Tensor:
        """
        Pad scale tensor to required shape.

        Grouped scales (more than one scale per output channel) are reshaped
        to 4-D before padding so both the channel and group axes can be
        padded; per-channel scales are only padded along the channel axis.

        Parameters
        ----------
        scale : torch.Tensor
            Scale tensor.
        group_size : int
            Group size for quantization.
        fill_value : float, optional
            Value to use for padding. Default is 0.

        Returns
        -------
        torch.Tensor
            Padded scale tensor.
        """
        if group_size > 0 and scale.numel() > scale.shape[0]:
            scale = scale.view(scale.shape[0], 1, -1, 1)
            if self.check_if_micro_scale(group_size=group_size):
                scale = pad(scale, divisor=(self.warp_n, self.insn_k // group_size), dim=(0, 2), fill_value=fill_value)
            else:
                scale = pad(scale, divisor=(self.warp_n, self.num_k_unrolls), dim=(0, 2), fill_value=fill_value)
        else:
            scale = pad(scale, divisor=self.warp_n, dim=0, fill_value=fill_value)
        return scale

    def pad_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        """
        Pad low-rank weight tensor to required shape.

        Pads the non-rank axis (dim 1 for down projections, dim 0 for up
        projections) to a multiple of ``warp_n``.

        Parameters
        ----------
        weight : torch.Tensor
            Low-rank weight tensor.
        down : bool
            If True, weight is for down projection in low-rank branch.

        Returns
        -------
        torch.Tensor
            Padded low-rank weight tensor.
        """
        assert weight.ndim == 2, "weight tensor should be 2D."
        return pad(weight, divisor=self.warp_n, dim=1 if down else 0)
|
nunchaku/lora/flux/utils.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for LoRAs in Flux models.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import typing as tp
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from ...utils import ceil_divide, load_state_dict_in_safetensors
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def is_nunchaku_format(lora: str | dict[str, torch.Tensor]) -> bool:
|
| 13 |
+
"""
|
| 14 |
+
Check if LoRA weights are in Nunchaku format.
|
| 15 |
+
|
| 16 |
+
Parameters
|
| 17 |
+
----------
|
| 18 |
+
lora : str or dict[str, torch.Tensor]
|
| 19 |
+
Path to a safetensors file or a dictionary of LoRA weights.
|
| 20 |
+
|
| 21 |
+
Returns
|
| 22 |
+
-------
|
| 23 |
+
bool
|
| 24 |
+
True if the weights are in Nunchaku format, False otherwise.
|
| 25 |
+
|
| 26 |
+
Examples
|
| 27 |
+
--------
|
| 28 |
+
>>> is_nunchaku_format("path/to/lora.safetensors")
|
| 29 |
+
True
|
| 30 |
+
"""
|
| 31 |
+
if isinstance(lora, str):
|
| 32 |
+
tensors = load_state_dict_in_safetensors(lora, device="cpu", return_metadata=False)
|
| 33 |
+
assert isinstance(tensors, dict), "Expected dict when return_metadata=False"
|
| 34 |
+
else:
|
| 35 |
+
tensors = lora
|
| 36 |
+
|
| 37 |
+
for k in tensors.keys():
|
| 38 |
+
if ".mlp_fc" in k or "mlp_context_fc1" in k:
|
| 39 |
+
return True
|
| 40 |
+
return False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def pad(
|
| 44 |
+
tensor: tp.Optional[torch.Tensor],
|
| 45 |
+
divisor: int | tp.Sequence[int],
|
| 46 |
+
dim: int | tp.Sequence[int],
|
| 47 |
+
fill_value: float | int = 0,
|
| 48 |
+
) -> torch.Tensor | None:
|
| 49 |
+
"""
|
| 50 |
+
Pad a tensor so specified dimensions are divisible by given divisors.
|
| 51 |
+
|
| 52 |
+
Parameters
|
| 53 |
+
----------
|
| 54 |
+
tensor : torch.Tensor or None
|
| 55 |
+
The tensor to pad. If None, returns None.
|
| 56 |
+
divisor : int or sequence of int
|
| 57 |
+
Divisor(s) for the dimension(s) to pad.
|
| 58 |
+
dim : int or sequence of int
|
| 59 |
+
Dimension(s) to pad.
|
| 60 |
+
fill_value : float or int, optional
|
| 61 |
+
Value to use for padding (default: 0).
|
| 62 |
+
|
| 63 |
+
Returns
|
| 64 |
+
-------
|
| 65 |
+
torch.Tensor or None
|
| 66 |
+
The padded tensor, or None if input tensor was None.
|
| 67 |
+
|
| 68 |
+
Examples
|
| 69 |
+
--------
|
| 70 |
+
>>> tensor = torch.randn(10, 20)
|
| 71 |
+
>>> pad(tensor, divisor=16, dim=0).shape
|
| 72 |
+
torch.Size([16, 20])
|
| 73 |
+
>>> pad(tensor, divisor=[16, 32], dim=[0, 1]).shape
|
| 74 |
+
torch.Size([16, 32])
|
| 75 |
+
"""
|
| 76 |
+
if isinstance(divisor, int):
|
| 77 |
+
if divisor <= 1:
|
| 78 |
+
return tensor
|
| 79 |
+
elif all(d <= 1 for d in divisor):
|
| 80 |
+
return tensor
|
| 81 |
+
if tensor is None:
|
| 82 |
+
return None
|
| 83 |
+
shape = list(tensor.shape)
|
| 84 |
+
if isinstance(dim, int):
|
| 85 |
+
assert isinstance(divisor, int)
|
| 86 |
+
shape[dim] = ceil_divide(shape[dim], divisor) * divisor
|
| 87 |
+
else:
|
| 88 |
+
if isinstance(divisor, int):
|
| 89 |
+
divisor = [divisor] * len(dim)
|
| 90 |
+
for d, div in zip(dim, divisor, strict=True):
|
| 91 |
+
shape[d] = ceil_divide(shape[d], div) * div
|
| 92 |
+
result = torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device)
|
| 93 |
+
result[[slice(0, extent) for extent in tensor.shape]] = tensor
|
| 94 |
+
return result
|
nunchaku/models/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .text_encoders.t5_encoder import NunchakuT5EncoderModel
|
| 2 |
+
from .transformers import (
|
| 3 |
+
NunchakuFluxTransformer2dModel,
|
| 4 |
+
)
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"NunchakuFluxTransformer2dModel",
|
| 8 |
+
"NunchakuT5EncoderModel",
|
| 9 |
+
]
|
nunchaku/models/attention.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Nunchaku quantized attention-related modules.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from diffusers.models.activations import GELU
|
| 7 |
+
from diffusers.models.attention import FeedForward
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from ..ops.fused import fused_gelu_mlp
|
| 11 |
+
from .linear import SVDQW4A4Linear
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class NunchakuBaseAttention(nn.Module):
    """
    Common base for Nunchaku attention modules.

    Concrete subclasses implement :meth:`set_processor` to install an
    attention processor; the constructor delegates to it immediately.

    Parameters
    ----------
    processor : str, optional
        Name of the attention processor to use. Default is "flashattn2".
    *args, **kwargs :
        Extra arguments reserved for subclass initialization.
    """

    def __init__(self, processor: str = "flashattn2", *args, **kwargs):
        super().__init__()
        self.processor = None
        self.set_processor(processor)

    def set_processor(self, processor: str):
        """
        Install the attention processor named ``processor``.

        Parameters
        ----------
        processor : str
            Name of the processor to use.

        Raises
        ------
        NotImplementedError
            Always, unless a subclass overrides this method.
        """
        raise NotImplementedError("Subclass must implement this method")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _patch_linear(module: nn.Module, linear_cls, **kwargs) -> nn.Module:
|
| 51 |
+
"""
|
| 52 |
+
Recursively replace all nn.Linear modules in a given module with a custom linear class.
|
| 53 |
+
|
| 54 |
+
Parameters
|
| 55 |
+
----------
|
| 56 |
+
module : nn.Module
|
| 57 |
+
The module to patch.
|
| 58 |
+
linear_cls : type
|
| 59 |
+
The custom linear class to use for replacement.
|
| 60 |
+
**kwargs :
|
| 61 |
+
Additional arguments passed to ``from_linear``.
|
| 62 |
+
|
| 63 |
+
Returns
|
| 64 |
+
-------
|
| 65 |
+
nn.Module
|
| 66 |
+
The patched module with custom linear layers.
|
| 67 |
+
"""
|
| 68 |
+
for name, child in module.named_children():
|
| 69 |
+
if isinstance(child, nn.Linear):
|
| 70 |
+
setattr(module, name, linear_cls.from_linear(child, **kwargs))
|
| 71 |
+
else:
|
| 72 |
+
_patch_linear(child, linear_cls, **kwargs)
|
| 73 |
+
return module
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class NunchakuFeedForward(FeedForward):
    """
    Quantized feed-forward (MLP) block with fused GELU support.

    Replaces linear layers in a FeedForward block with :class:`~nunchaku.models.linear.SVDQW4A4Linear` for quantized inference.
    Supports fused GELU-MLP computation for efficiency.

    Parameters
    ----------
    ff : FeedForward
        Source FeedForward block to quantize.
    **kwargs :
        Additional arguments for SVDQW4A4Linear.

    Notes
    -----
    For int4 quantization, the activation of the second MLP layer is shifted to be unsigned.
    """

    def __init__(self, ff: FeedForward, **kwargs):
        # Deliberately skip FeedForward.__init__ (which would rebuild `net`):
        # super(FeedForward, self) resolves to nn.Module in the MRO.
        super(FeedForward, self).__init__()
        self.net = _patch_linear(ff.net, SVDQW4A4Linear, **kwargs)
        # For int4, shift the activation of mlp_fc2 to make it unsigned
        # NOTE(review): assumes the diffusers FeedForward layout where
        # net[0] is the activation (with .proj) and net[2] is the second
        # linear layer — confirm if diffusers changes this structure.
        self.net[2].act_unsigned = self.net[2].precision != "nvfp4"

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the quantized feed-forward block.
        It will call :func:`~nunchaku.ops.fused.fused_gelu_mlp` if the first layer is GELU;
        otherwise, apply modules sequentially.

        Parameters
        ----------
        hidden_states : torch.Tensor, shape (B, D)
            Input tensor.

        Returns
        -------
        torch.Tensor, shape (B, D)
            Output tensor after feed-forward transformation.
        """
        if isinstance(self.net[0], GELU):
            # Fused path: one kernel for proj -> GELU -> second linear.
            return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2])
        else:
            # Fallback to original implementation
            for module in self.net:
                hidden_states = module(hidden_states)
            return hidden_states
|
nunchaku/models/embeddings.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Embedding layers for Nunchaku.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import diffusers
|
| 6 |
+
import torch
|
| 7 |
+
from packaging.version import Version
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """
    Rotary positional embedding function.
    Adapted from https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/transformers/transformer_flux.py#L38

    Parameters
    ----------
    pos : torch.Tensor, shape (B, M)
        Position indices. Must be 2-D: the shape is unpacked into
        (batch_size, seq_length) below.
    dim : int
        Embedding dimension (must be even).
    theta : int
        Rotary base.

    Returns
    -------
    out : torch.Tensor, shape (B, M, dim // 2, 1, 2), dtype float32
        Rotary embedding tensor; the trailing axis holds (sin, cos) pairs.

    Notes
    -----
    - B: batch size
    - M: sequence length
    """
    assert dim % 2 == 0, "The dimension must be even."

    # Frequencies are computed in float64 for accuracy; the result is cast
    # to float32 on return.
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    batch_size, seq_length = pos.shape
    out = torch.einsum("...n,d->...nd", pos, omega)

    # Sin/cos representation for rotary embedding
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)
    stacked_out = torch.stack([sin_out, cos_out], dim=-1)
    # Use seq_length explicitly instead of an opaque -1 (the variable was
    # previously unpacked but unused).
    out = stacked_out.view(batch_size, seq_length, dim // 2, 1, 2)

    return out.float()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class NunchakuFluxPosEmbed(nn.Module):
    """
    Nunchaku multi-dimensional rotary embedding module for FLUX.
    Adapted from https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/transformers/transformer_flux.py#L55

    Parameters
    ----------
    dim : int
        Embedding dimension.
    theta : int
        Rotary base.
    axes_dim : list of int
        Dimension for each spatial axis.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super(NunchakuFluxPosEmbed, self).__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Compute rotary embeddings for multi-dimensional positions.

        Parameters
        ----------
        ids : torch.Tensor, shape (..., n_axes), dtype int
            Position indices, one column per spatial axis.

        Returns
        -------
        out : torch.Tensor, dtype float32
            Rotary embedding tensor with a broadcast axis inserted at dim 1.

        Notes
        -----
        - n_axes: number of spatial axes
        """
        # diffusers >= 0.31.0 passes unbatched ids; add the batch axis here.
        if Version(diffusers.__version__) >= Version("0.31.0"):
            ids = ids.unsqueeze(0)
        per_axis = [rope(ids[..., axis], self.axes_dim[axis], self.theta) for axis in range(ids.shape[-1])]
        return torch.cat(per_axis, dim=-3).unsqueeze(1)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
    """
    Pack rotary embeddings into the layout expected by the CUDA kernels.

    Parameters
    ----------
    rotemb : torch.Tensor, shape (B, M, D//2, 1, 2), dtype float32
        Rotary embedding tensor.

    Returns
    -------
    packed : torch.Tensor, shape (B, M, D), dtype float32
        Packed rotary embedding tensor.

    Notes
    -----
    - B: batch size
    - M: sequence length (must be divisible by 16)
    - D: embedding dimension (must be divisible by 8)
    """
    assert rotemb.dtype == torch.float32
    batch = rotemb.shape[0]
    seq_len = rotemb.shape[1]
    dim = rotemb.shape[2] * 2
    assert rotemb.shape == (batch, seq_len, dim // 2, 1, 2)
    assert seq_len % 16 == 0
    assert dim % 8 == 0

    # Tile the (M, D) plane into 16x8 fragments.
    tiled = rotemb.reshape(batch, seq_len // 16, 16, dim // 8, 8)
    tiled = tiled.permute(0, 1, 3, 2, 4)
    # 16*8 pack, FP32 accumulator (C) format
    # https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
    ##########################################|--M--|--D--|
    ##########################################|-3--4--5--6|
    ########################################## :  :  :  :
    frag = tiled.reshape(*tiled.shape[0:3], 2, 8, 4, 2)
    frag = frag.permute(0, 1, 2, 4, 5, 3, 6)
    return frag.contiguous().view(batch, seq_len, dim)
|
nunchaku/models/linear.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quantized linear layers for Nunchaku.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from ..ops.gemm import svdq_gemm_w4a4_cuda
|
| 9 |
+
from ..ops.gemv import awq_gemv_w4a16_cuda
|
| 10 |
+
from ..ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SVDQW4A4Linear(nn.Module):
    """
    `SVDQuant <paper_svdquant_>`_ W4A4 quantized linear layer.

    Parameters
    ----------
    in_features : int
        Input feature dimension.
    out_features : int
        Output feature dimension.
    rank : int, optional
        SVD low-rank dimension. Default is 32.
    bias : bool, optional
        If True, adds a learnable bias. Default is True.
    precision : {'int4', 'nvfp4'}, optional
        Quantization precision data type ('int4' or 'nvfp4'). Default is 'int4'.
    act_unsigned : bool, optional
        If True, use unsigned activation quantization (int4 only). Default is False.
    torch_dtype : torch.dtype, optional
        Parameter dtype. Default is torch.bfloat16.
    device : str or torch.device or None, optional
        Device for parameters. Default is CPU.

    Attributes
    ----------
    in_features : int
    out_features : int
    rank : int
    precision : str
        'int4' or 'nvfp4'.
    group_size : int
        64 for int4, 16 for nvfp4.
    qweight : nn.Parameter
        Packed quantized weights, shape (out_features, in_features // 2), dtype int8.
    bias : nn.Parameter or None
        Bias tensor.
    wscales : nn.Parameter
        Weight scales, shape (in_features // group_size, out_features).
        Dtype: bfloat16/float16 (int4), float8_e4m3fn (nvfp4).
    smooth_factor : nn.Parameter
        Smoothing factors, shape (in_features,).
    smooth_factor_orig : nn.Parameter
        Original smoothing factors, shape (in_features,). (Unused)
    proj_down : nn.Parameter
        Low-rank down projection, shape (in_features, rank), dtype bfloat16/float16.
    proj_up : nn.Parameter
        Low-rank up projection, shape (out_features, rank), dtype bfloat16/float16.
    wtscale : float or None
        Global weight scale (nvfp4 only).
    wcscales : nn.Parameter or None
        Channel-wise weight scale (nvfp4 only), shape (out_features,), created
        with ``torch_dtype`` (initialized to ones).
    act_unsigned : bool
        If True, input activations are unsigned (int4 only).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 32,
        bias: bool = True,
        precision: str = "int4",
        act_unsigned: bool = False,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device | None = None,
    ):
        super(SVDQW4A4Linear, self).__init__()
        if device is None:
            device = torch.device("cpu")
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank

        self.precision = precision
        self.torch_dtype = torch_dtype

        # The quantization group size is fixed per precision.
        if precision == "nvfp4":
            self.group_size = 16
        elif precision == "int4":
            self.group_size = 64
        else:
            raise ValueError(f"Invalid precision: {precision}")

        # Two 4-bit weights are packed into each int8 byte, hence in_features // 2.
        self.qweight = nn.Parameter(
            torch.empty(out_features, in_features // 2, dtype=torch.int8, device=device), requires_grad=False
        )
        self.bias = (
            nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=True)
            if bias
            else None
        )

        # Per-group weight scales; nvfp4 stores scales in FP8 (e4m3).
        self.wscales = nn.Parameter(
            torch.empty(
                in_features // self.group_size,
                out_features,
                dtype=torch_dtype if precision == "int4" else torch.float8_e4m3fn,
                device=device,
            ),
            requires_grad=False,
        )
        self.smooth_factor = nn.Parameter(
            torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False
        )
        self.smooth_factor_orig = nn.Parameter(
            torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False
        )

        # SVD low-rank branch used by quantize()/forward_quant().
        # NOTE(review): unlike the quantized buffers above, these two default to
        # requires_grad=True — confirm whether that is intentional.
        self.proj_down = nn.Parameter(torch.empty(in_features, rank, dtype=torch_dtype, device=device))
        self.proj_up = nn.Parameter(torch.empty(out_features, rank, dtype=torch_dtype, device=device))

        if precision == "nvfp4":
            self.wcscales = nn.Parameter(
                torch.ones(out_features, dtype=torch_dtype, device=device), requires_grad=False
            )
            self.wtscale = 1.0
        else:
            # int4 path does not use the extra nvfp4 scales.
            self.wtscale = None
            self.wcscales = None

        self.act_unsigned = act_unsigned

    @classmethod
    def from_linear(cls, linear: nn.Linear, **kwargs):
        """
        Create an SVDQW4A4Linear from a standard nn.Linear. The weight and bias are dummy tensors.

        Parameters
        ----------
        linear : nn.Linear
            Source linear layer.
        **kwargs
            Additional init arguments.

        Returns
        -------
        SVDQW4A4Linear
        """
        # Callers may override in_features via kwargs (e.g. for fused inputs).
        in_features = kwargs.pop("in_features", linear.in_features)
        return cls(
            in_features=in_features,
            out_features=linear.out_features,
            bias=linear.bias is not None,
            torch_dtype=linear.weight.dtype,
            device=linear.weight.device,
            **kwargs,
        )

    def forward(self, x: torch.Tensor, output: torch.Tensor | None = None) -> torch.Tensor:
        """
        Forward pass with 16-bit input. It will call :meth:`quantize` and :meth:`forward_quant`.

        Parameters
        ----------
        x : torch.Tensor, shape (B, S, in_features), dtype float16 or bfloat16
            Input tensor.
        output : torch.Tensor or None, optional
            Optional output buffer.

        Returns
        -------
        torch.Tensor, shape (B, S, out_features)
            Output tensor.

        Notes
        -----
        B: batch size, S: sequence length
        """
        # Flatten (B, S) into a single token axis for the kernels, then restore it.
        batch_size, seq_len, channels = x.shape
        x = x.reshape(batch_size * seq_len, channels)
        if output is None:
            output = torch.empty(batch_size * seq_len, self.out_features, dtype=x.dtype, device=x.device)
        quantized_x, ascales, lora_act_out = self.quantize(x)
        output = self.forward_quant(quantized_x, ascales, lora_act_out, output)
        output = output.reshape(batch_size, seq_len, -1)
        return output

    def quantize(self, x: torch.Tensor, pad_size: int = 256) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize input to 4-bit and compute low-rank hidden states. It will call :func:`~nunchaku.ops.quantize.svdq_quantize_w4a4_act_fuse_lora_cuda`.

        Parameters
        ----------
        x : torch.Tensor, shape (N, in_features), dtype float16 or bfloat16
            Input tensor.
        pad_size : int, optional
            Batch padding size. Default is 256.

        Returns
        -------
        quantized_x : torch.Tensor
            Quantized input, shape (pad_size * ceil(N / pad_size), in_features // 2), dtype uint8.
        ascales : torch.Tensor
            Activation scales, shape (in_features // group_size,), dtype float8_e4m3fn for nvfp4 and input dtype for int4.
        lora_act_out : torch.Tensor
            Low-rank hidden states, shape (pad_size * ceil(N / pad_size), rank), dtype float32.

        Notes
        -----
        N: batch size
        """
        # Fused kernel: smoothing, 4-bit activation quantization, and the
        # low-rank down projection in a single pass.
        quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda(
            x, lora_down=self.proj_down, smooth=self.smooth_factor, fp4=self.precision == "nvfp4", pad_size=pad_size
        )
        return quantized_x, ascales, lora_act_out

    def forward_quant(
        self,
        quantized_x: torch.Tensor,
        ascales: torch.Tensor,
        lora_act: torch.Tensor,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Forward pass with pre-quantized input. It will call :func:`~nunchaku.ops.gemm.svdq_gemm_w4a4_cuda`.

        Parameters
        ----------
        quantized_x : torch.Tensor
            Quantized input, shape (N, in_features // 2), dtype uint8.
        ascales : torch.Tensor
            Activation scales, shape (in_features // group_size,), dtype float8_e4m3fn for nvfp4 and input dtype for int4.
        lora_act : torch.Tensor
            Low-rank hidden states, shape (N, rank), dtype float32.
        output : torch.Tensor or None, optional
            Optional output buffer.

        Returns
        -------
        torch.Tensor
            Output tensor, shape (N, out_features), dtype bfloat16/float16 for int4 and float8_e4m3fn for nvfp4.

        Notes
        -----
        N: batch size
        """
        if output is None:
            output = torch.empty(
                quantized_x.shape[0], self.out_features, dtype=self.proj_up.dtype, device=quantized_x.device
            )

        # The kernel fuses the W4A4 GEMM, the low-rank up projection, and the
        # bias addition; it writes into `output` in place.
        svdq_gemm_w4a4_cuda(
            act=quantized_x,
            wgt=self.qweight,
            out=output,
            ascales=ascales,
            wscales=self.wscales,
            lora_act_in=lora_act,
            lora_up=self.proj_up,
            bias=self.bias,
            fp4=self.precision == "nvfp4",
            alpha=self.wtscale,
            wcscales=self.wcscales,
            act_unsigned=self.act_unsigned,
        )
        return output

    def __repr__(self):
        # Compact summary for debugging; parameter tensors are omitted.
        return (
            f"SVDQW4A4Linear(in_features={self.in_features}, out_features={self.out_features}, "
            f"rank={self.rank}, precision={self.precision}, act_unsigned={self.act_unsigned})"
        )
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class AWQW4A16Linear(nn.Module):
    """
    `AWQ <paper_awq_>`_ W4A16 quantized linear layer.

    Parameters
    ----------
    in_features : int
        Input feature dimension.
    out_features : int
        Output feature dimension.
    bias : bool, optional
        If True, adds learnable bias. Default is True.
    group_size : int, optional
        Quantization group size. Default is 64.
    torch_dtype : torch.dtype, optional
        Parameter dtype. Default is torch.bfloat16.
    device : str or torch.device or None, optional
        Device for parameters. Default is CPU.

    Attributes
    ----------
    in_features : int
    out_features : int
    group_size : int
    qweight : nn.Parameter
        Packed quantized weights, shape (out_features // 4, in_features // 2), dtype int32.
    bias : nn.Parameter or None
        Bias tensor.
    wscales : nn.Parameter
        Weight scales, shape (in_features // group_size, out_features), dtype float16 or bfloat16.
    wzeros : nn.Parameter
        Weight zero points, shape (in_features // group_size, out_features), dtype float16 or bfloat16.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        group_size: int = 64,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device | None = None,
    ):
        super(AWQW4A16Linear, self).__init__()
        if device is None:
            device = torch.device("cpu")
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size

        # Eight 4-bit weights per int32 word: (out // 4) x (in // 2) words
        # hold out * in 4-bit values.
        self.qweight = nn.Parameter(
            torch.empty(out_features // 4, in_features // 2, dtype=torch.int32, device=device), requires_grad=False
        )
        self.bias = (
            nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=True)
            if bias
            else None
        )
        # Per-group scales and (pre-scaled) zero points for dequantization.
        self.wscales = nn.Parameter(
            torch.empty(in_features // self.group_size, out_features, dtype=torch_dtype, device=device),
            requires_grad=False,
        )
        self.wzeros = nn.Parameter(
            torch.empty(in_features // self.group_size, out_features, dtype=torch_dtype, device=device),
            requires_grad=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for AWQW4A16Linear.

        Parameters
        ----------
        x : torch.Tensor, shape (N, in_features)
            Input tensor.

        Returns
        -------
        torch.Tensor, shape (N, out_features)
            Output tensor.

        Notes
        -----
        N: batch size
        """
        output = awq_gemv_w4a16_cuda(
            in_feats=x,
            kernel=self.qweight,
            scaling_factors=self.wscales,
            zeros=self.wzeros,
            m=x.shape[0],
            n=self.out_features,
            k=self.in_features,
            group_size=self.group_size,
        )
        if self.bias is not None:
            # Broadcast the bias over all leading dimensions of the output.
            view_shape = [1] * (output.ndim - 1) + [-1]
            output.add_(self.bias.view(view_shape))
        return output

    @classmethod
    def from_linear(
        cls,
        linear: nn.Linear,
        group_size: int = 64,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str = "cpu",
        **kwargs,
    ):
        """
        Create an uninitialized AWQW4A16Linear from a standard nn.Linear.

        Parameters
        ----------
        linear : nn.Linear
            Source linear layer.
        group_size : int, optional
            Quantization group size.
        torch_dtype : torch.dtype, optional
            Parameter dtype.
        device : str, optional
            Device for parameters.

        Returns
        -------
        AWQW4A16Linear
        """
        # NOTE(review): extra **kwargs are accepted but silently dropped here —
        # confirm whether they should be forwarded to __init__.
        return cls(
            in_features=linear.in_features,
            out_features=linear.out_features,
            bias=linear.bias is not None,
            group_size=group_size,
            torch_dtype=torch_dtype,
            device=device,
        )

    def __repr__(self):
        # Compact summary for debugging; parameter tensors are omitted.
        return f"AWQW4A16Linear(in_features={self.in_features}, out_features={self.out_features}, group_size={self.group_size})"
|
nunchaku/models/normalization.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quantized normalization layers for efficient inference.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormZeroSingle
|
| 9 |
+
|
| 10 |
+
from .linear import AWQW4A16Linear
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class NunchakuAdaLayerNormZero(AdaLayerNormZero):
    """
    Quantized drop-in replacement for diffusers' AdaLayerNormZero.

    The embedding projection is swapped for :class:`AWQW4A16Linear`; the six
    modulation vectors are unpacked according to the quantized weight layout.

    Parameters
    ----------
    other : AdaLayerNormZero
        Source AdaLayerNormZero instance to copy weights and structure from.
    scale_shift : float, optional
        Value to add to scale parameters. Default is 1.0.
        Nunchaku may have already fused the scale_shift into the linear
        weights, in which case pass 0.

    Notes
    -----
    - B: batch size
    - D: hidden dimension
    """

    def __init__(self, other: AdaLayerNormZero, scale_shift: float = 1.0):
        # Deliberately skip AdaLayerNormZero.__init__: all sub-modules are
        # taken from `other` instead of being constructed fresh.
        super(AdaLayerNormZero, self).__init__()
        self.scale_shift = scale_shift
        self.emb = other.emb
        self.silu = other.silu
        self.linear = AWQW4A16Linear.from_linear(other.linear)
        self.norm = other.norm

    def forward(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass for the quantized AdaLayerNormZero.

        Parameters
        ----------
        x : torch.Tensor, shape (B, D)
            Input tensor.
        timestep : Optional[torch.Tensor], optional
            Timestep embedding input.
        class_labels : Optional[torch.LongTensor], optional
            Class label input.
        hidden_dtype : Optional[torch.dtype], optional
            Dtype for embedding computation.
        emb : Optional[torch.Tensor], shape (B, E), optional
            Precomputed embedding. Ignored when ``self.emb`` is set, in which
            case it is recomputed from timestep and class_labels.

        Returns
        -------
        norm_x_scaled : torch.Tensor, shape (B, D)
            Normalized and scaled input.
        gate_msa : torch.Tensor, shape (B, D)
            Gate for MSA branch.
        shift_mlp : torch.Tensor, shape (B, D)
            Shift for MLP branch.
        scale_mlp : torch.Tensor, shape (B, D)
            Scale for MLP branch.
        gate_mlp : torch.Tensor, shape (B, D)
            Gate for MLP branch.

        Notes
        -----
        - B: batch size
        - D: hidden dimension
        """
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        projected = self.linear(self.silu(emb))

        # The quantized weight layout interleaves the six modulation vectors,
        # so unpack them modulo 6 rather than with torch.chunk.
        parts = projected.view(projected.shape[0], -1, 6).permute(2, 0, 1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = parts.unbind(0)

        if self.scale_shift != 0:
            # In-place adds: the views share storage with `projected`.
            scale_msa.add_(self.scale_shift)
            scale_mlp.add_(self.scale_shift)

        modulated = self.norm(x) * scale_msa[:, None] + shift_msa[:, None]
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class NunchakuAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
    """
    Quantized drop-in replacement for diffusers' AdaLayerNormZeroSingle.

    Uses :class:`AWQW4A16Linear` for the embedding projection. Suitable for
    single-branch normalization.

    Parameters
    ----------
    other : AdaLayerNormZeroSingle
        Source AdaLayerNormZeroSingle instance to copy weights and structure from.
    scale_shift : float, optional
        Value to add to scale parameters. Default is 1.0.
        Nunchaku may have already fused the scale_shift into the linear
        weights, in which case pass 0.

    Notes
    -----
    - B: batch size
    - D: hidden dimension
    """

    def __init__(self, other: AdaLayerNormZeroSingle, scale_shift: float = 1.0):
        # Deliberately skip AdaLayerNormZeroSingle.__init__: sub-modules are
        # taken from `other` instead of being constructed fresh.
        super(AdaLayerNormZeroSingle, self).__init__()
        self.scale_shift = scale_shift
        self.silu = other.silu
        self.linear = AWQW4A16Linear.from_linear(other.linear)
        self.norm = other.norm

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the quantized AdaLayerNormZeroSingle.

        Parameters
        ----------
        x : torch.Tensor, shape (B, D)
            Input tensor.
        emb : Optional[torch.Tensor], shape (B, E), optional
            Embedding tensor.

        Returns
        -------
        norm_x_scaled : torch.Tensor, shape (B, D)
            Normalized and scaled input.
        gate_msa : torch.Tensor, shape (B, D)
            Gate for MSA branch.

        Notes
        -----
        - B: batch size
        - D: hidden dimension
        """
        projected = self.linear(self.silu(emb))

        # The quantized weight layout interleaves the three modulation vectors,
        # so unpack them modulo 3 rather than with torch.chunk.
        parts = projected.view(projected.shape[0], -1, 3).permute(2, 0, 1)
        shift_msa, scale_msa, gate_msa = parts.unbind(0)

        if self.scale_shift != 0:
            # In-place add: the view shares storage with `projected`.
            scale_msa.add_(self.scale_shift)

        modulated = self.norm(x) * scale_msa[:, None] + shift_msa[:, None]
        return modulated, gate_msa
|
nunchaku/models/text_encoders/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .t5_encoder import NunchakuT5EncoderModel
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"NunchakuT5EncoderModel",
|
| 5 |
+
]
|
nunchaku/models/text_encoders/linear.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
This module provides the :class:`W4Linear` quantized linear layer, which implements
|
| 4 |
+
4-bit weight-only quantization for efficient inference.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from ..._C.ops import gemm_awq, gemv_awq
|
| 11 |
+
from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
|
| 12 |
+
|
| 13 |
+
__all__ = ["W4Linear"]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class W4Linear(nn.Module):
    """
    4-bit quantized linear layer with group-wise quantization.

    Weights are stored 4 bits each in a TinyChat-specific interleaved int16 layout
    (see :func:`~nunchaku.models.text_encoders.tinychat_utils.convert_to_tinychat_w4x16y16_linear_weight`);
    activations and outputs stay in fp16/bf16.

    Parameters
    ----------
    in_features : int
        Number of input features.
    out_features : int
        Number of output features.
    bias : bool, optional
        If True, adds a learnable bias (default: False).
    group_size : int, optional
        Number of input channels per quantization group (default: 128).
        If -1, uses the full input dimension as a single group.
    dtype : torch.dtype, optional
        Data type for quantization scales and zeros (default: torch.float16).
    device : str or torch.device, optional
        Device for weights and buffers (default: "cuda").

    Attributes
    ----------
    in_features : int
    out_features : int
    group_size : int
    qweight : torch.Tensor
        Quantized weight tensor (int16).
    scales : torch.Tensor
        Per-group scale tensor.
    scaled_zeros : torch.Tensor
        Per-group zero-point tensor (scaled).
    bias : torch.Tensor or None
        Optional bias tensor.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        group_size: int = 128,
        dtype: torch.dtype = torch.float16,
        device: str | torch.device = "cuda",
    ):
        super().__init__()
        # The AWQ kernels only operate on half-precision activations.
        assert dtype in (torch.float16, torch.bfloat16), f"Unsupported dtype: {dtype}"

        self.in_features = in_features
        self.out_features = out_features
        # group_size == -1 means a single quantization group spanning all inputs.
        self.group_size = group_size if group_size != -1 else in_features
        assert self.in_features % self.group_size == 0
        # Each packed 32-bit word must hold a whole number of output channels.
        assert out_features % (32 // self.weight_bits) == 0
        # Padded group count required by the TinyChat packing (may exceed the raw
        # in_features // group_size).
        self.ceil_num_groups = ceil_num_groups(
            in_features=self.in_features,
            group_size=self.group_size,
            weight_bits=self.weight_bits,
        )

        assert out_features % (self.interleave) == 0
        # Packed weights: `interleave` output rows are fused per storage row, so the
        # row count shrinks by `interleave` while each row widens accordingly.
        self.register_buffer(
            "qweight",
            torch.zeros(
                (
                    self.out_features // self.interleave,
                    self.in_features // (16 // self.weight_bits) * self.interleave,
                ),
                dtype=torch.int16,
                device=device,
            ),
        )
        # Scales/zeros are stored transposed, (groups, out_features), padded to
        # ceil_num_groups rows as expected by the kernels.
        self.register_buffer(
            "scales",
            torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
        )
        self.register_buffer(
            "scaled_zeros",
            torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
        )
        if bias:
            self.register_buffer("bias", torch.zeros((out_features), dtype=dtype, device=device))
        else:
            # Keep the attribute present (as None) so `self.bias is not None` checks work.
            self.bias = None

    @property
    def weight_bits(self) -> int:
        """
        Number of bits per quantized weight (always 4).
        """
        return 4

    @property
    def interleave(self) -> int:
        """
        Interleave factor for quantized weights (always 4).
        """
        return 4

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Dispatches to a GEMV kernel for small effective batch sizes (< 8 rows) and
        a GEMM kernel otherwise.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (..., in_features).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (..., out_features).

        Raises
        ------
        NotImplementedError
            If the GEMM path is taken and ``group_size`` is not 128.
        """
        # x.numel() / x.shape[-1] is the number of rows being multiplied; the GEMV
        # kernel is preferred for fewer than 8 rows.
        if x.numel() / x.shape[-1] < 8:
            out = gemv_awq(
                x,
                self.qweight,
                self.scales,
                self.scaled_zeros,
                x.numel() // x.shape[-1],
                self.out_features,
                self.in_features,
                self.group_size,
            )
        else:
            # The batched GEMM kernel is hard-coded to group_size=128.
            if self.group_size != 128:
                raise NotImplementedError("Kernel currently only supports group_size=128.")
            out = gemm_awq(x, self.qweight, self.scales, self.scaled_zeros)
        out = out + self.bias if self.bias is not None else out
        return out

    @staticmethod
    def from_linear(
        linear: nn.Linear,
        group_size: int,
        init_only: bool = False,
        weight: torch.Tensor | None = None,
        scale: torch.Tensor | None = None,
        zero: torch.Tensor | None = None,
        zero_pre_scaled: bool = False,
    ) -> "W4Linear":
        """
        Convert a standard nn.Linear to a quantized W4Linear.

        If ``scale``/``zero`` are not supplied, per-group min/max quantization
        parameters are computed from the weight itself.

        Parameters
        ----------
        linear : nn.Linear
            The linear layer to convert.
        group_size : int
            Quantization group size.
        init_only : bool, optional
            If True, only initializes the quantized layer without copying or
            quantizing any weights (default: False).
        weight : torch.Tensor, optional
            Weight tensor to quantize instead of ``linear.weight`` (default: None).
        scale : torch.Tensor, optional
            Precomputed scale tensor (default: None).
        zero : torch.Tensor, optional
            Precomputed zero-point tensor (default: None).
        zero_pre_scaled : bool, optional
            Whether the zero-point tensor is pre-scaled (default: False).

        Returns
        -------
        W4Linear
            Quantized linear layer.
        """
        assert isinstance(linear, nn.Linear)
        weight = linear.weight.data if weight is None else weight.data
        dtype, device = weight.dtype, weight.device
        oc, ic = linear.out_features, linear.in_features
        _linear = W4Linear(
            in_features=ic,
            out_features=oc,
            bias=linear.bias is not None,
            group_size=group_size,
            dtype=dtype,
            device=device,
        )
        if init_only:
            # Buffers stay zero; caller is expected to load real values later.
            return _linear
        if linear.bias is not None:
            _linear.bias.data.copy_(linear.bias.data)
        if scale is None:
            assert zero is None, "scale and zero point tensors should be provided together."
            group_size = ic if group_size <= 0 else group_size
            assert group_size <= ic, "group size should be less than or equal to input channel size."
            assert ic % group_size == 0, "input channel size should be divisible by group size."
            ng, gs = ic // group_size, group_size
            # Work in fp32 per (row, group) for numerically stable min/max statistics.
            weight = weight.to(dtype=torch.float32).view(oc, 1, ng, gs)
            vmin, vmax = weight.amin(dim=-1, keepdim=True), weight.amax(dim=-1, keepdim=True)
            # 15 = 2**4 - 1 quantization levels.
            scale = (vmax - vmin).div_(15)
            # Avoid division by zero for constant groups.
            scale[scale == 0] = 1.0
            if zero_pre_scaled:
                # Integer zero point in [0, 15]; fake-quantize the weight round-trip.
                zero = vmin.neg_().div_(scale).round_().clamp_(0, 15)
                weight = weight.div_(scale).add_(zero).round_().clamp_(0, 15).sub_(zero).mul_(scale)
            else:
                # Real-valued zero offset (-vmin, clamped non-negative — NOTE(review):
                # the clamp discards the offset for all-positive groups; confirm intended).
                zero = vmin.neg_().clamp_min(0)
                weight = weight.add_(zero).div_(scale).round_().clamp_(0, 15).mul_(scale).sub_(zero)
            weight = weight.to(dtype=dtype).view(oc, ic)
            scale = scale.to(dtype=dtype)
            zero = zero.to(dtype=dtype)
        # Re-quantize and pack into the TinyChat kernel layout.
        weight, scale, zero = convert_to_tinychat_w4x16y16_linear_weight(
            weight=weight,
            scale=scale,
            zero=zero,
            group_size=group_size,
            zero_pre_scaled=zero_pre_scaled,
        )
        _linear.qweight.data.copy_(weight)
        _linear.scales.data.copy_(scale)
        _linear.scaled_zeros.data.copy_(zero)
        return _linear

    def extra_repr(self) -> str:
        """
        Returns a string describing the layer configuration.
        """
        return "in_features={}, out_features={}, bias={}, weight_bits={}, group_size={}".format(
            self.in_features,
            self.out_features,
            self.bias is not None,
            self.weight_bits,
            self.group_size,
        )
|
nunchaku/models/text_encoders/t5_encoder.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The NunchakuT5EncoderModel class enables loading T5 encoder weights from safetensors files,
|
| 3 |
+
automatically replacing supported linear layers with quantized :class:`~nunchaku.models.text_encoders.linear.W4Linear`
|
| 4 |
+
modules for improved performance and memory efficiency.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from accelerate import init_empty_weights
|
| 14 |
+
from torch import nn
|
| 15 |
+
from transformers import T5Config, T5EncoderModel
|
| 16 |
+
|
| 17 |
+
from ...utils import load_state_dict_in_safetensors
|
| 18 |
+
from .linear import W4Linear
|
| 19 |
+
|
| 20 |
+
# Get log level from environment variable (default to INFO)
|
| 21 |
+
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
| 22 |
+
|
| 23 |
+
# Configure logging
|
| 24 |
+
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class NunchakuT5EncoderModel(T5EncoderModel):
    """
    Nunchaku T5 Encoder Model.

    Extends :class:`transformers.T5EncoderModel` to support quantized weights and
    memory-efficient inference using :class:`~nunchaku.models.text_encoders.linear.W4Linear`.

    This class provides a convenient interface for loading T5 encoder weights from
    safetensors files, automatically replacing supported linear layers with quantized
    modules for improved speed and reduced memory usage.

    Example
    -------
    .. code-block:: python

        model = NunchakuT5EncoderModel.from_pretrained(
            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
        )
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
        """
        Load a :class:`NunchakuT5EncoderModel` from a safetensors file.

        This method loads the model configuration and weights from a safetensors file,
        initializes the model on the 'meta' device (no memory allocation for weights),
        and replaces supported linear layers with quantized
        :class:`~nunchaku.models.text_encoders.linear.W4Linear` modules.

        Parameters
        ----------
        pretrained_model_name_or_path : str or os.PathLike
            Path to the safetensors file containing the model weights and metadata.
        torch_dtype : torch.dtype, optional
            Data type for model initialization (default: ``torch.bfloat16``).
            Set to ``torch.float16`` for Turing GPUs.
        device : str or torch.device, optional
            Device to load the model onto (default: ``"cuda"``).

        Returns
        -------
        NunchakuT5EncoderModel
            The loaded and quantized T5 encoder model.
        """
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)

        # The model config is embedded in the safetensors metadata under "config".
        config = json.loads(metadata["config"])
        config = T5Config(**config)

        # Initialize on the 'meta' device (no memory allocation for weights).
        # Instantiate via ``cls`` — not the base T5EncoderModel — so the returned
        # object actually is a NunchakuT5EncoderModel (or subclass), matching the
        # documented return type.
        with init_empty_weights():
            t5_encoder = cls(config).to(kwargs.get("torch_dtype", torch.bfloat16))

        t5_encoder.eval()

        # Replace every nn.Linear that has a quantized counterpart in the checkpoint
        # (identified by a "<name>.qweight" key) with a W4Linear shell; real values
        # are filled in by load_state_dict below.
        named_modules = {}
        for name, module in t5_encoder.named_modules():
            assert isinstance(name, str)
            if isinstance(module, nn.Linear):
                if f"{name}.qweight" in state_dict:
                    logger.debug(f"Switching {name} to W4Linear")
                    qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
                    # modeling_t5.py: T5DenseGatedActDense reads the dtype of `weight`,
                    # so expose a tiny placeholder tensor with the right dtype.
                    qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)

                    # Parent modules were recorded earlier in the pre-order traversal.
                    parent_name, child_name = name.rsplit(".", 1)
                    setattr(named_modules[parent_name], child_name, qmodule)
            else:
                named_modules[name] = module

        device = kwargs.get("device", "cuda")
        if isinstance(device, str):
            device = torch.device(device)
        # Materialize the meta-device parameters on the target device, then load the
        # actual (quantized) weights.
        t5_encoder.to_empty(device=device)
        t5_encoder.load_state_dict(state_dict, strict=True)

        return t5_encoder
|
nunchaku/models/text_encoders/tinychat_utils.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
This module provides utility functions for quantized linear layers in the TinyChat backend.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
__all__ = ["ceil_num_groups", "convert_to_tinychat_w4x16y16_linear_weight"]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def ceil_divide(x: int, divisor: int) -> int:
    """
    Integer ceiling division.

    Equivalent to ``math.ceil(x / divisor)`` for positive divisors, but computed
    entirely in integer arithmetic, so there is no float rounding for large values.

    Parameters
    ----------
    x : int
        Dividend.
    divisor : int
        Divisor.

    Returns
    -------
    int
        The smallest integer greater than or equal to ``x / divisor``.
    """
    # Python floor-divides toward negative infinity, so -(-x // d) == ceil(x / d).
    return -(-x // divisor)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int:
    """
    Compute the padded number of quantization groups for TinyChat quantization.

    The raw group count (``in_features // group_size``) is rounded up so that it
    fills whole 32-bit packs, and the pack count itself is rounded up to an
    alignment factor determined by the group size, as required by TinyChat's
    packing and kernels.

    Parameters
    ----------
    in_features : int
        Input channel size (number of input features).
    group_size : int
        Quantization group size.
    weight_bits : int, optional
        Number of bits per quantized weight (default: 4).

    Returns
    -------
    int
        The padded number of quantization groups.

    Raises
    ------
    AssertionError
        If ``in_features`` is not divisible by ``group_size``, or if ``weight_bits``
        is not 4, 2, or 1.
    NotImplementedError
        If ``group_size`` is not one of the supported values (>=128, 64, 32).
    """
    assert in_features % group_size == 0, "input channel size should be divisible by group size."
    raw_groups = in_features // group_size
    assert weight_bits in (4, 2, 1), "weight bits should be 4, 2, or 1."
    # One 32-bit word holds `pack_size` quantized values.
    pack_size = 32 // weight_bits
    packs = -(-raw_groups // pack_size)  # ceiling division
    if group_size >= 128:
        alignment = 1
    elif group_size == 64:
        alignment = 2
    elif group_size == 32:
        alignment = 4
    else:
        raise NotImplementedError("Unsupported group size for TinyChat quantization.")
    # Round the pack count up to a multiple of the alignment factor.
    packs = -(-packs // alignment) * alignment
    return packs * pack_size
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def pack_w4(weight: torch.Tensor) -> torch.Tensor:
    """
    Pack quantized 4-bit weights into TinyChat's int16 format.

    Four weights (each stored as an int32 value in [0, 15]) are fused into one
    16-bit word, and the result is tiled into the interleaved layout the TinyChat
    CUDA kernels consume.

    Parameters
    ----------
    weight : torch.Tensor
        Quantized weight tensor of shape (out_features, in_features), dtype int32.
        The input channel dimension must be divisible by 32 (and the output channel
        dimension by 4, implied by the reshape below).

    Returns
    -------
    torch.Tensor
        Packed weight tensor of shape (out_features // 4, in_features), dtype int16.

    Raises
    ------
    AssertionError
        If the input tensor is not int32 or the input channel size is not divisible by 32.
    """
    assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
    out_ch, in_ch = weight.shape
    assert in_ch % 32 == 0, "input channel size should be divisible by 32."
    # Split every run of 32 channels into 4 lanes of 8 and fuse the lanes into one
    # 16-bit word per slot, reordering channels
    # [0, 1, ..., 31] -> [0, 8, 16, 24, 1, 9, 17, 25, ..., 7, 15, 23, 31].
    lanes = weight.view(-1, 4, 8)
    fused = lanes[:, 0] | (lanes[:, 1] << 4) | (lanes[:, 2] << 8) | (lanes[:, 3] << 12)
    # Interleave blocks of 4 output rows with 16-wide column tiles for the kernel.
    tiled = fused.view(out_ch // 4, 4, in_ch // 64, 16).permute(0, 2, 1, 3)
    return tiled.reshape(out_ch // 4, in_ch).to(torch.int16)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def convert_to_tinychat_w4x16y16_linear_weight(
    weight: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    group_size: int = -1,
    zero_pre_scaled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Convert a floating-point weight tensor to TinyChat W4-X16-Y16 quantized linear format.

    This function quantizes the input weights to 4 bits per value using the supplied
    group-wise scale and zero-point, and packs the result into the layout expected by
    TinyChat quantized linear layers.

    Parameters
    ----------
    weight : torch.Tensor
        Floating-point weight tensor of shape (out_features, in_features).
        Must be of dtype ``torch.float16`` or ``torch.bfloat16``.
    scale : torch.Tensor
        Per-group scale tensor (can be broadcastable).
    zero : torch.Tensor
        Per-group zero-point tensor (can be broadcastable).
    group_size : int, optional
        Quantization group size. If set to -1 (default), uses the full input
        dimension as a single group.
    zero_pre_scaled : bool, optional
        If True, the zero tensor is an integer zero point that must be multiplied
        by the scale first (default: False).

    Returns
    -------
    tuple of torch.Tensor
        - packed_weight : torch.Tensor
            Packed quantized weight tensor (int16).
        - packed_scale : torch.Tensor
            Packed scale tensor (shape: [num_groups, out_features], dtype matches input).
        - packed_zero : torch.Tensor
            Packed (negated) zero-point tensor (shape: [num_groups, out_features],
            dtype matches input).

    Raises
    ------
    AssertionError
        If input types or shapes are invalid, or quantized values are out of range.

    Example
    -------
    .. code-block:: python

        qweight, qscale, qzero = convert_to_tinychat_w4x16y16_linear_weight(
            weight, scale, zero, group_size=128
        )
    """
    dtype, device = weight.dtype, weight.device
    assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16."
    assert scale is not None, "scale tensor is required for quantization."
    assert zero is not None, "zero point tensor is required for quantization."
    # Quantize in fp32 for accuracy; cast back to the original dtype at the end.
    weight = weight.to(dtype=torch.float32)
    scale = scale.to(dtype=torch.float32, device=device)
    zero = zero.to(dtype=torch.float32, device=device)
    if zero_pre_scaled:
        # Convert the integer zero point into the same (real-valued) domain as the weights.
        zero = zero * scale
    oc, ic = weight.shape
    group_size = ic if group_size <= 0 else group_size
    assert group_size <= ic, "group size should be less than or equal to input channel size."
    assert ic % group_size == 0, "input channel size should be divisible by group size."
    ng = ic // group_size
    # Broadcast scalar scale/zero to one value per (output row, group), then shape
    # both as (oc, ng, 1) so they broadcast over the group dimension below.
    if scale.numel() == 1:
        scale = scale.view(1, 1).expand(oc, ng)
    scale = scale.reshape(oc, ng).contiguous().view(oc, ng, 1)
    if zero.numel() == 1:
        zero = zero.view(1, 1).expand(oc, ng)
    zero = zero.reshape(oc, ng).contiguous().view(oc, ng, 1)
    # Quantize: q = round((w + zero) / scale); must land in the 4-bit range [0, 15].
    weight = weight.view(oc, ng, -1).add_(zero).div_(scale).round_().view(oc, ic)
    assert weight.min() >= 0 and weight.max() <= 15, "quantized weight should be in [0, 15]."
    _weight = pack_w4(weight.to(torch.int32))
    # Pad the group dimension to the kernel-required count; extra rows stay zero.
    _ng = ceil_num_groups(ic, group_size, weight_bits=4)
    _scale = torch.zeros((_ng, oc), dtype=dtype, device=device)
    _zero = torch.zeros((_ng, oc), dtype=dtype, device=device)
    # Kernels expect (groups, out_features); the zero point is stored negated.
    _scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype)
    _zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_()
    return _weight, _scale, _zero
|
nunchaku/models/transformers/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .transformer_flux import NunchakuFluxTransformer2dModel
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"NunchakuFluxTransformer2dModel",
|
| 5 |
+
]
|
nunchaku/models/transformers/transformer_flux.py
ADDED
|
@@ -0,0 +1,991 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Implements the :class:`NunchakuFluxTransformer2dModel`, a quantized transformer for Diffusers with efficient inference and LoRA support.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Optional, Union
|
| 10 |
+
|
| 11 |
+
import diffusers
|
| 12 |
+
import torch
|
| 13 |
+
from diffusers import FluxTransformer2DModel
|
| 14 |
+
from diffusers.configuration_utils import register_to_config
|
| 15 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 16 |
+
from huggingface_hub import utils
|
| 17 |
+
from packaging.version import Version
|
| 18 |
+
from safetensors.torch import load_file
|
| 19 |
+
from torch import nn
|
| 20 |
+
|
| 21 |
+
from ..._C import QuantizedFluxModel
|
| 22 |
+
from ..._C import utils as cutils
|
| 23 |
+
from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
|
| 24 |
+
from ...lora.flux.utils import is_nunchaku_format
|
| 25 |
+
from ...utils import check_hardware_compatibility, get_precision, load_state_dict_in_safetensors, pad_tensor
|
| 26 |
+
from .utils import NunchakuModelLoaderMixin
|
| 27 |
+
|
| 28 |
+
SVD_RANK = 32
|
| 29 |
+
|
| 30 |
+
# Get log level from environment variable (default to INFO)
|
| 31 |
+
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
| 32 |
+
|
| 33 |
+
# Configure logging
|
| 34 |
+
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class NunchakuFluxTransformerBlocks(nn.Module):
    """
    Wrapper for quantized Nunchaku FLUX transformer blocks.

    This class manages the forward pass, rotary embedding packing, and optional
    residual callbacks for ID embeddings. All heavy computation is delegated to
    the C/CUDA backend object ``m``.

    Parameters
    ----------
    m : QuantizedFluxModel
        The quantized transformer model (C backend handle).
    device : str or torch.device
        Device to run the model on.

    Notes
    -----
    ``set_pulid_residual_callback`` reads ``self.pulid_ca``, which is never
    assigned in this class — presumably attached externally by the PulID
    integration before ``forward`` is called with ID embeddings (TODO: confirm
    with callers).
    """

    def __init__(self, m: QuantizedFluxModel, device: str | torch.device):
        super(NunchakuFluxTransformerBlocks, self).__init__()
        self.m = m
        # Mirror the backend's compute dtype on the Python side so inputs can
        # be cast before crossing into C++.
        self.dtype = torch.bfloat16 if m.isBF16() else torch.float16
        self.device = device

    @staticmethod
    def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
        """
        Packs rotary embeddings for efficient computation.

        Rearranges the (sin, cos) pairs into the 16x8 MMA fragment layout
        (FP32 accumulator "C" format) expected by the backend kernels.

        Parameters
        ----------
        rotemb : torch.Tensor
            Rotary embedding tensor of shape (B, M, D//2, 1, 2), dtype float32.

        Returns
        -------
        torch.Tensor
            Packed rotary embedding tensor of shape (B, M, D).
        """
        assert rotemb.dtype == torch.float32
        B = rotemb.shape[0]
        M = rotemb.shape[1]
        D = rotemb.shape[2] * 2
        assert rotemb.shape == (B, M, D // 2, 1, 2)
        # The 16x8 tiling below requires these divisibility constraints;
        # callers pad the token dimension (see pad_tensor(..., 256, 1)).
        assert M % 16 == 0
        assert D % 8 == 0
        rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8)
        rotemb = rotemb.permute(0, 1, 3, 2, 4)
        # 16*8 pack, FP32 accumulator (C) format
        # https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
        ##########################################|--M--|--D--|
        ##########################################|-3--4--5--6|
        ########################################## : : : :
        rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
        rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
        rotemb = rotemb.contiguous()
        rotemb = rotemb.view(B, M, D)
        return rotemb

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        id_embeddings=None,
        id_weight=None,
        joint_attention_kwargs=None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        skip_first_layer=False,
    ):
        """
        Forward pass for the quantized transformer blocks.
        It will call the forward method of ``m`` on the C backend.

        The backend returns a single tensor with text tokens first, then image
        tokens; this method splits it back into the two streams.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input hidden states for image tokens.
        temb : torch.Tensor
            Temporal embedding tensor.
        encoder_hidden_states : torch.Tensor
            Input hidden states for text tokens.
        image_rotary_emb : torch.Tensor
            Rotary embedding tensor for all tokens (text tokens first).
        id_embeddings : torch.Tensor, optional
            Optional ID embeddings for the PulID residual callback.
        id_weight : float, optional
            Weight for the ID embedding residual.
        joint_attention_kwargs : dict, optional
            Additional kwargs for joint attention (currently unused here).
        controlnet_block_samples : list[torch.Tensor], optional
            ControlNet block samples.
        controlnet_single_block_samples : list[torch.Tensor], optional
            ControlNet single block samples.
        skip_first_layer : bool, optional
            Whether to skip the first layer.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            (encoder_hidden_states, hidden_states) after transformer blocks,
            restored to the caller's original dtype and device.
        """
        txt_tokens = encoder_hidden_states.shape[1]
        img_tokens = hidden_states.shape[1]

        self.id_embeddings = id_embeddings
        self.id_weight = id_weight
        self.pulid_ca_idx = 0
        if self.id_embeddings is not None:
            self.set_pulid_residual_callback()

        # Remember caller dtype/device so outputs can be converted back.
        original_dtype = hidden_states.dtype
        original_device = hidden_states.device

        hidden_states = hidden_states.to(self.dtype).to(self.device)
        encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
        temb = temb.to(self.dtype).to(self.device)
        image_rotary_emb = image_rotary_emb.to(self.device)

        # Stack ControlNet residual lists into single tensors for the backend;
        # empty lists are normalized to None.
        if controlnet_block_samples is not None:
            if len(controlnet_block_samples) > 0:
                controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
            else:
                controlnet_block_samples = None

        if controlnet_single_block_samples is not None:
            if len(controlnet_single_block_samples) > 0:
                controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
            else:
                controlnet_single_block_samples = None

        assert image_rotary_emb.ndim == 6
        assert image_rotary_emb.shape[0] == 1
        assert image_rotary_emb.shape[1] == 1
        assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
        # [1, tokens, head_dim / 2, 1, 2] (sincos)
        image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
        # Split rotary embeddings: text tokens come first in the layout.
        rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]  # .to(self.dtype)
        rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]  # .to(self.dtype)
        rotary_emb_single = image_rotary_emb  # .to(self.dtype)

        # Pad the token dimension to a multiple of 256, then pack for MMA.
        rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
        rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
        rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
        hidden_states = self.m.forward(
            hidden_states,
            encoder_hidden_states,
            temb,
            rotary_emb_img,
            rotary_emb_txt,
            rotary_emb_single,
            controlnet_block_samples,
            controlnet_single_block_samples,
            skip_first_layer,
        )

        if self.id_embeddings is not None:
            self.reset_pulid_residual_callback()

        hidden_states = hidden_states.to(original_dtype).to(original_device)

        # The backend output concatenates text and image tokens; split them.
        encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
        hidden_states = hidden_states[:, txt_tokens:, ...]

        return encoder_hidden_states, hidden_states

    def forward_layer_at(
        self,
        idx: int,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        joint_attention_kwargs=None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
    ):
        """
        Forward pass for a specific transformer layer in ``m``.

        Parameters
        ----------
        idx : int
            Index of the transformer layer.
        hidden_states : torch.Tensor
            Input hidden states for image tokens.
        encoder_hidden_states : torch.Tensor
            Input hidden states for text tokens.
        temb : torch.Tensor
            Temporal embedding tensor.
        image_rotary_emb : torch.Tensor
            Rotary embedding tensor for all tokens.
        joint_attention_kwargs : dict, optional
            Additional kwargs for joint attention (currently unused here).
        controlnet_block_samples : list[torch.Tensor], optional
            ControlNet block samples.
        controlnet_single_block_samples : list[torch.Tensor], optional
            ControlNet single block samples.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            (encoder_hidden_states, hidden_states) after the specified layer.
        """
        txt_tokens = encoder_hidden_states.shape[1]
        img_tokens = hidden_states.shape[1]

        original_dtype = hidden_states.dtype
        original_device = hidden_states.device

        hidden_states = hidden_states.to(self.dtype).to(self.device)
        encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
        temb = temb.to(self.dtype).to(self.device)
        image_rotary_emb = image_rotary_emb.to(self.device)

        # NOTE(review): unlike forward(), empty ControlNet lists are not
        # normalized to None here — torch.stack would raise on [] — confirm
        # callers never pass empty lists.
        if controlnet_block_samples is not None:
            controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
        if controlnet_single_block_samples is not None:
            controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)

        assert image_rotary_emb.ndim == 6
        assert image_rotary_emb.shape[0] == 1
        assert image_rotary_emb.shape[1] == 1
        assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
        # [1, tokens, head_dim / 2, 1, 2] (sincos)
        image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
        rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]  # .to(self.dtype)
        rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]  # .to(self.dtype)

        rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
        rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))

        hidden_states, encoder_hidden_states = self.m.forward_layer(
            idx,
            hidden_states,
            encoder_hidden_states,
            temb,
            rotary_emb_img,
            rotary_emb_txt,
            controlnet_block_samples,
            controlnet_single_block_samples,
        )

        hidden_states = hidden_states.to(original_dtype).to(original_device)
        encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device)

        return encoder_hidden_states, hidden_states

    def set_pulid_residual_callback(self):
        """
        Sets the residual callback for PulID (personalized ID) embeddings.

        The callback is invoked by the C backend per attention layer; it adds
        a weighted cross-attention residual from the ID embeddings and advances
        an internal per-call layer counter.
        """
        id_embeddings = self.id_embeddings
        # self.pulid_ca is expected to be attached externally — see class Notes.
        pulid_ca = self.pulid_ca
        # One-element list so the closure can mutate the counter.
        pulid_ca_idx = [self.pulid_ca_idx]
        id_weight = self.id_weight

        def callback(hidden_states):
            ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states)
            pulid_ca_idx[0] += 1
            return ip

        # Keep a Python reference so the closure outlives the C-side pointer.
        self.callback_holder = callback
        self.m.set_residual_callback(callback)

    def reset_pulid_residual_callback(self):
        """
        Resets the PulID residual callback to None on both sides.
        """
        self.callback_holder = None
        self.m.set_residual_callback(None)

    def __del__(self):
        """
        Destructor: asks the C backend to release its resources.
        """
        self.m.reset()

    def norm1(
        self,
        hidden_states: torch.Tensor,
        emb: torch.Tensor,
        idx: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Runs the norm_one_forward for a specific layer in ``m``.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input hidden states.
        emb : torch.Tensor
            Embedding tensor.
        idx : int, optional
            Layer index (default: 0).

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Output tensors from norm_one_forward.
        """
        return self.m.norm_one_forward(idx, hidden_states, emb)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """
    Rotary positional embedding function.

    Produces interleaved (sin, cos) pairs for each position and frequency,
    in the layout consumed by :meth:`NunchakuFluxTransformerBlocks.pack_rotemb`.

    Parameters
    ----------
    pos : torch.Tensor
        Position tensor of shape (batch, n).
    dim : int
        Embedding dimension (must be even).
    theta : int
        Rotary base.

    Returns
    -------
    torch.Tensor
        Float32 rotary embedding tensor of shape (batch, n, dim // 2, 1, 2),
        where the last axis holds (sin, cos).
    """
    assert dim % 2 == 0, "The dimension must be even."

    # Frequencies: omega_d = theta ** (-2d / dim), computed in float64 for accuracy.
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    # Unpacking enforces that pos is 2D (batch, seq), as the final view requires.
    batch_size, _ = pos.shape
    # Outer product: angle[b, n, d] = pos[b, n] * omega[d]
    out = torch.einsum("...n,d->...nd", pos, omega)

    # Interleave sin/cos along a trailing axis of size 2.
    # (The original code had a dead USE_SINCOS=False branch; removed.)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)
    stacked_out = torch.stack([sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 1, 2)

    return out.float()
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class EmbedND(nn.Module):
    """
    Multi-dimensional rotary embedding module.

    Each spatial axis gets its own rotary embedding (via :func:`rope`); the
    per-axis embeddings are concatenated along the frequency dimension.

    Parameters
    ----------
    dim : int
        Embedding dimension.
    theta : int
        Rotary base.
    axes_dim : list[int]
        List of axis dimensions for each spatial axis.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Computes rotary embeddings for multi-dimensional positions.

        Parameters
        ----------
        ids : torch.Tensor
            Position indices tensor of shape (..., n_axes).

        Returns
        -------
        torch.Tensor
            Rotary embedding tensor with an extra singleton head dimension.
        """
        # diffusers >= 0.31 passes ids without a batch dimension; add one.
        if Version(diffusers.__version__) >= Version("0.31.0"):
            ids = ids[None, ...]
        axis_count = ids.shape[-1]
        per_axis = []
        for axis in range(axis_count):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))
        combined = torch.cat(per_axis, dim=-3)
        return combined.unsqueeze(1)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def load_quantized_module(
    path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
    device: str | torch.device = "cuda",
    use_fp4: bool = False,
    offload: bool = False,
    bf16: bool = True,
) -> QuantizedFluxModel:
    """
    Load a quantized Nunchaku FLUX model from a file path or an in-memory state dict.

    Parameters
    ----------
    path_or_state_dict : str, os.PathLike, or dict
        Path to the quantized checkpoint, or a state dict.
    device : str or torch.device, optional
        CUDA device to place the model on (default: "cuda").
    use_fp4 : bool, optional
        Whether to use FP4 quantization (default: False).
    offload : bool, optional
        Whether to offload weights to CPU (default: False).
    bf16 : bool, optional
        Whether to use bfloat16 (default: True).

    Returns
    -------
    QuantizedFluxModel
        The initialized backend model with weights loaded.
    """
    target = torch.device(device)
    # The quantized backend only runs on CUDA devices.
    assert target.type == "cuda"
    device_index = 0 if target.index is None else target.index

    model = QuantizedFluxModel()
    # Per the C API name, this disables automatic memory release before init.
    cutils.disable_memory_auto_release()
    model.init(use_fp4, offload, bf16, device_index)

    if isinstance(path_or_state_dict, dict):
        model.loadDict(path_or_state_dict, True)
    else:
        model.load(str(path_or_state_dict))
    return model
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
|
| 463 |
+
"""
|
| 464 |
+
Nunchaku FLUX Transformer 2D Model.
|
| 465 |
+
|
| 466 |
+
This class implements a quantized transformer model compatible with the Diffusers
|
| 467 |
+
library, supporting LoRA, rotary embeddings, and efficient inference.
|
| 468 |
+
|
| 469 |
+
Parameters
|
| 470 |
+
----------
|
| 471 |
+
patch_size : int, optional
|
| 472 |
+
Patch size for input images (default: 1).
|
| 473 |
+
in_channels : int, optional
|
| 474 |
+
Number of input channels (default: 64).
|
| 475 |
+
out_channels : int or None, optional
|
| 476 |
+
Number of output channels (default: None).
|
| 477 |
+
num_layers : int, optional
|
| 478 |
+
Number of transformer layers (default: 19).
|
| 479 |
+
num_single_layers : int, optional
|
| 480 |
+
Number of single transformer layers (default: 38).
|
| 481 |
+
attention_head_dim : int, optional
|
| 482 |
+
Dimension of each attention head (default: 128).
|
| 483 |
+
num_attention_heads : int, optional
|
| 484 |
+
Number of attention heads (default: 24).
|
| 485 |
+
joint_attention_dim : int, optional
|
| 486 |
+
Joint attention dimension (default: 4096).
|
| 487 |
+
pooled_projection_dim : int, optional
|
| 488 |
+
Pooled projection dimension (default: 768).
|
| 489 |
+
guidance_embeds : bool, optional
|
| 490 |
+
Whether to use guidance embeddings (default: False).
|
| 491 |
+
axes_dims_rope : tuple[int], optional
|
| 492 |
+
Axes dimensions for rotary embeddings (default: (16, 56, 56)).
|
| 493 |
+
"""
|
| 494 |
+
|
| 495 |
+
@register_to_config
|
| 496 |
+
def __init__(
|
| 497 |
+
self,
|
| 498 |
+
patch_size: int = 1,
|
| 499 |
+
in_channels: int = 64,
|
| 500 |
+
out_channels: int | None = None,
|
| 501 |
+
num_layers: int = 19,
|
| 502 |
+
num_single_layers: int = 38,
|
| 503 |
+
attention_head_dim: int = 128,
|
| 504 |
+
num_attention_heads: int = 24,
|
| 505 |
+
joint_attention_dim: int = 4096,
|
| 506 |
+
pooled_projection_dim: int = 768,
|
| 507 |
+
guidance_embeds: bool = False,
|
| 508 |
+
axes_dims_rope: tuple[int] = (16, 56, 56),
|
| 509 |
+
):
|
| 510 |
+
super(NunchakuFluxTransformer2dModel, self).__init__(
|
| 511 |
+
patch_size=patch_size,
|
| 512 |
+
in_channels=in_channels,
|
| 513 |
+
out_channels=out_channels,
|
| 514 |
+
num_layers=num_layers,
|
| 515 |
+
num_single_layers=num_single_layers,
|
| 516 |
+
attention_head_dim=attention_head_dim,
|
| 517 |
+
num_attention_heads=num_attention_heads,
|
| 518 |
+
joint_attention_dim=joint_attention_dim,
|
| 519 |
+
pooled_projection_dim=pooled_projection_dim,
|
| 520 |
+
guidance_embeds=guidance_embeds,
|
| 521 |
+
axes_dims_rope=axes_dims_rope,
|
| 522 |
+
)
|
| 523 |
+
# these state_dicts are used for supporting lora
|
| 524 |
+
self._unquantized_part_sd: dict[str, torch.Tensor] = {}
|
| 525 |
+
self._unquantized_part_loras: dict[str, torch.Tensor] = {}
|
| 526 |
+
self._quantized_part_sd: dict[str, torch.Tensor] = {}
|
| 527 |
+
self._quantized_part_vectors: dict[str, torch.Tensor] = {}
|
| 528 |
+
self._original_in_channels = in_channels
|
| 529 |
+
|
| 530 |
+
# ComfyUI LoRA related
|
| 531 |
+
self.comfy_lora_meta_list = []
|
| 532 |
+
self.comfy_lora_sd_list = []
|
| 533 |
+
|
| 534 |
+
    @classmethod
    @utils.validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
        """
        Loads a Nunchaku FLUX transformer model from pretrained weights.

        Two on-disk layouts are supported:

        - A single ``.safetensors``/``.sft`` file (current format), carrying
          weights plus quantization metadata.
        - A legacy layout with separate files for the unquantized part and the
          quantized transformer blocks.

        Parameters
        ----------
        pretrained_model_name_or_path : str or os.PathLike
            Path to the model directory or HuggingFace repo.
        **kwargs
            Additional keyword arguments for device, offload, torch_dtype,
            precision, return_metadata, etc.

        Returns
        -------
        NunchakuFluxTransformer2dModel or (NunchakuFluxTransformer2dModel, dict)
            The loaded model, and optionally metadata if `return_metadata=True`.
            In the legacy path, the returned metadata is ``None``.
        """
        device = kwargs.get("device", "cuda")
        if isinstance(device, str):
            device = torch.device(device)
        offload = kwargs.get("offload", False)
        torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
        precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
        metadata = None

        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
            (".safetensors", ".sft")
        ):
            # Single-file (current) format.
            transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
            quantized_part_sd = {}
            unquantized_part_sd = {}
            for k, v in model_state_dict.items():
                # Double/single transformer blocks are quantized; everything
                # else (embedders, norms, projections) stays unquantized.
                if k.startswith(("transformer_blocks.", "single_transformer_blocks.")):
                    quantized_part_sd[k] = v
                else:
                    unquantized_part_sd[k] = v
            # NOTE(review): this overwrites the precision resolved from kwargs
            # above with an auto-detected one -- confirm this is intentional.
            precision = get_precision(device=device)
            quantization_config = json.loads(metadata["quantization_config"])
            check_hardware_compatibility(quantization_config, device)
        else:
            # Legacy two-file format.
            transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
                pretrained_model_name_or_path, **kwargs
            )

            # get the default LoRA branch and all the vectors
            quantized_part_sd = load_file(transformer_block_path)
            unquantized_part_sd = load_file(unquantized_part_path)
        new_quantized_part_sd = {}
        for k, v in quantized_part_sd.items():
            if v.ndim == 1:
                # Vectors (biases/scales) are kept as-is for later LoRA fusion.
                new_quantized_part_sd[k] = v
            elif "qweight" in k:
                # only the shape information of this tensor is needed
                new_quantized_part_sd[k] = v.to("meta")

                # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors
                for t in ["lora_up", "lora_down"]:
                    new_k = k.replace(".qweight", f".{t}")
                    if new_k not in quantized_part_sd:
                        oc, ic = v.shape
                        ic = ic * 2  # v is packed into INT8, so we need to double the size
                        new_quantized_part_sd[k.replace(".qweight", f".{t}")] = torch.zeros(
                            (0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16
                        )

            elif "lora" in k:
                new_quantized_part_sd[k] = v
        transformer._quantized_part_sd = new_quantized_part_sd
        m = load_quantized_module(
            quantized_part_sd,
            device=device,
            use_fp4=precision == "fp4",
            offload=offload,
            bf16=torch_dtype == torch.bfloat16,
        )
        transformer.inject_quantized_module(m, device)
        # Materialize parameter storage on the target device, then restore the
        # unquantized weights (quantized blocks live in the C backend, hence
        # strict=False).
        transformer.to_empty(device=device)

        transformer.load_state_dict(unquantized_part_sd, strict=False)
        transformer._unquantized_part_sd = unquantized_part_sd

        if kwargs.get("return_metadata", False):
            return transformer, metadata
        else:
            return transformer
|
| 622 |
+
|
| 623 |
+
def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
|
| 624 |
+
"""
|
| 625 |
+
Injects a quantized module into the model and sets up transformer blocks.
|
| 626 |
+
|
| 627 |
+
Parameters
|
| 628 |
+
----------
|
| 629 |
+
m : QuantizedFluxModel
|
| 630 |
+
The quantized transformer model.
|
| 631 |
+
device : str or torch.device, optional
|
| 632 |
+
Device to run the model on (default: "cuda").
|
| 633 |
+
|
| 634 |
+
Returns
|
| 635 |
+
-------
|
| 636 |
+
self : NunchakuFluxTransformer2dModel
|
| 637 |
+
The model with injected quantized module.
|
| 638 |
+
"""
|
| 639 |
+
print("Injecting quantized module")
|
| 640 |
+
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
|
| 641 |
+
|
| 642 |
+
### Compatible with the original forward method
|
| 643 |
+
self.transformer_blocks = nn.ModuleList([NunchakuFluxTransformerBlocks(m, device)])
|
| 644 |
+
self.single_transformer_blocks = nn.ModuleList([])
|
| 645 |
+
|
| 646 |
+
return self
|
| 647 |
+
|
| 648 |
+
def set_attention_impl(self, impl: str):
|
| 649 |
+
"""
|
| 650 |
+
Set the attention implementation for the quantized transformer block.
|
| 651 |
+
|
| 652 |
+
Parameters
|
| 653 |
+
----------
|
| 654 |
+
impl : str
|
| 655 |
+
Attention implementation to use. Supported values:
|
| 656 |
+
|
| 657 |
+
- ``"flashattn2"`` (default): Standard FlashAttention-2.
|
| 658 |
+
- ``"nunchaku-fp16"``: Uses FP16 attention accumulation, up to 1.2× faster than FlashAttention-2 on NVIDIA 30-, 40-, and 50-series GPUs.
|
| 659 |
+
"""
|
| 660 |
+
block = self.transformer_blocks[0]
|
| 661 |
+
assert isinstance(block, NunchakuFluxTransformerBlocks)
|
| 662 |
+
block.m.setAttentionImpl(impl)
|
| 663 |
+
|
| 664 |
+
### LoRA Related Functions
|
| 665 |
+
|
| 666 |
+
    def _expand_module(self, module_name: str, new_shape: tuple[int, int]):
        """
        Expands a linear module to a new shape for LoRA compatibility.
        Mostly for FLUX.1-tools LoRA which changes the input channels.

        The original weights are copied into the top-left corner of a new,
        zero-initialized ``nn.Linear``; the cached unquantized state dict is
        updated to match, and the new module is swapped into the parent.

        Parameters
        ----------
        module_name : str
            Name of the module to expand.
        new_shape : tuple[int, int]
            New shape (out_features, in_features) for the module.
        """
        module = self.get_submodule(module_name)
        assert isinstance(module, nn.Linear)
        weight_shape = module.weight.shape
        logger.info("Expand the shape of module {} from {} to {}".format(module_name, tuple(weight_shape), new_shape))
        # Expansion only: shrinking would lose weights.
        assert new_shape[0] >= weight_shape[0] and new_shape[1] >= weight_shape[1]
        new_module = nn.Linear(
            new_shape[1],
            new_shape[0],
            bias=module.bias is not None,
            device=module.weight.device,
            dtype=module.weight.dtype,
        )
        # Zero-fill, then copy the original weights into the top-left block.
        new_module.weight.data.zero_()
        new_module.weight.data[: weight_shape[0], : weight_shape[1]] = module.weight.data
        self._unquantized_part_sd[f"{module_name}.weight"] = new_module.weight.data.clone()
        if new_module.bias is not None:
            new_module.bias.data.zero_()
            new_module.bias.data[: weight_shape[0]] = module.bias.data
            self._unquantized_part_sd[f"{module_name}.bias"] = new_module.bias.data.clone()
        # Swap the expanded module into its parent.
        parent_name = ".".join(module_name.split(".")[:-1])
        parent_module = self.get_submodule(parent_name)
        parent_module.add_module(module_name.split(".")[-1], new_module)

        # Expanding x_embedder changes the effective input channel count;
        # keep the Diffusers config in sync.
        if module_name == "x_embedder":
            new_value = int(new_module.weight.data.shape[1])
            old_value = getattr(self.config, "in_channels")
            if new_value != old_value:
                logger.info(f"Update in_channels from {old_value} to {new_value}")
                setattr(self.config, "in_channels", new_value)
|
| 707 |
+
|
| 708 |
+
    def _update_unquantized_part_lora_params(self, strength: float = 1):
        """
        Updates the unquantized part of the model with LoRA parameters.

        First pass: expand any linear layers whose LoRA deltas are larger than
        the current weights (FLUX.1-tools style LoRAs). Second pass: fuse the
        LoRA deltas into a fresh state dict and load it.

        Parameters
        ----------
        strength : float, optional
            LoRA scaling strength (default: 1).
        """
        # check if we need to expand the linear layers
        device = next(self.parameters()).device
        for k, v in self._unquantized_part_loras.items():
            if "lora_A" in k:
                # NOTE(review): assumes the matching lora_B key always exists.
                lora_a = v
                lora_b = self._unquantized_part_loras[k.replace(".lora_A.", ".lora_B.")]
                # Fused delta shape is (out_features of B, in_features of A).
                diff_shape = (lora_b.shape[0], lora_a.shape[1])
                weight_shape = self._unquantized_part_sd[k.replace(".lora_A.", ".")].shape
                if diff_shape[0] > weight_shape[0] or diff_shape[1] > weight_shape[1]:
                    module_name = ".".join(k.split(".")[:-2])
                    self._expand_module(module_name, diff_shape)
            elif v.ndim == 1:
                # 1-D entries are bias-like vectors; only out_features can grow.
                diff_shape = v.shape
                weight_shape = self._unquantized_part_sd[k].shape
                if diff_shape[0] > weight_shape[0]:
                    assert diff_shape[0] >= weight_shape[0]
                    module_name = ".".join(k.split(".")[:-1])
                    module = self.get_submodule(module_name)
                    weight_shape = module.weight.shape
                    diff_shape = (diff_shape[0], weight_shape[1])
                    self._expand_module(module_name, diff_shape)
        new_state_dict = {}
        for k in self._unquantized_part_sd.keys():
            v = self._unquantized_part_sd[k]
            v = v.to(device)
            self._unquantized_part_sd[k] = v

            if v.ndim == 1 and k in self._unquantized_part_loras:
                # Bias-like vector delta; zero-pad to the (possibly expanded) size.
                diff = strength * self._unquantized_part_loras[k]
                if diff.shape[0] < v.shape[0]:
                    diff = torch.cat(
                        [diff, torch.zeros(v.shape[0] - diff.shape[0], device=device, dtype=v.dtype)], dim=0
                    )
                new_state_dict[k] = v + diff
            elif v.ndim == 2 and k.replace(".weight", ".lora_B.weight") in self._unquantized_part_loras:
                lora_a = self._unquantized_part_loras[k.replace(".weight", ".lora_A.weight")]
                lora_b = self._unquantized_part_loras[k.replace(".weight", ".lora_B.weight")]

                # Zero-pad the LoRA factors so B @ A matches the weight shape.
                if lora_a.shape[1] < v.shape[1]:
                    lora_a = torch.cat(
                        [
                            lora_a,
                            torch.zeros(lora_a.shape[0], v.shape[1] - lora_a.shape[1], device=device, dtype=v.dtype),
                        ],
                        dim=1,
                    )
                if lora_b.shape[0] < v.shape[0]:
                    lora_b = torch.cat(
                        [
                            lora_b,
                            torch.zeros(v.shape[0] - lora_b.shape[0], lora_b.shape[1], device=device, dtype=v.dtype),
                        ],
                        dim=0,
                    )

                diff = strength * (lora_b @ lora_a)
                new_state_dict[k] = v + diff
            else:
                new_state_dict[k] = v
        self.load_state_dict(new_state_dict, strict=True)
|
| 777 |
+
|
| 778 |
+
def update_lora_params(self, path_or_state_dict: str | dict[str, torch.Tensor]):
|
| 779 |
+
"""
|
| 780 |
+
Update the model with new LoRA parameters.
|
| 781 |
+
|
| 782 |
+
Parameters
|
| 783 |
+
----------
|
| 784 |
+
path_or_state_dict : str or dict
|
| 785 |
+
Path to a LoRA weights file or a state dict. The path supports:
|
| 786 |
+
|
| 787 |
+
- Local file path, e.g., ``"/path/to/your/lora.safetensors"``
|
| 788 |
+
- HuggingFace repo with file, e.g., ``"user/repo/lora.safetensors"``
|
| 789 |
+
(automatically downloaded and cached)
|
| 790 |
+
"""
|
| 791 |
+
if isinstance(path_or_state_dict, dict):
|
| 792 |
+
state_dict = {
|
| 793 |
+
k: v for k, v in path_or_state_dict.items()
|
| 794 |
+
} # copy a new one to avoid modifying the original one
|
| 795 |
+
else:
|
| 796 |
+
state_dict = load_state_dict_in_safetensors(path_or_state_dict)
|
| 797 |
+
|
| 798 |
+
if not is_nunchaku_format(state_dict):
|
| 799 |
+
state_dict = to_nunchaku(state_dict, base_sd=self._quantized_part_sd)
|
| 800 |
+
|
| 801 |
+
unquantized_part_loras = {}
|
| 802 |
+
for k, v in list(state_dict.items()):
|
| 803 |
+
device = next(self.parameters()).device
|
| 804 |
+
if "transformer_blocks" not in k:
|
| 805 |
+
unquantized_part_loras[k] = state_dict.pop(k).to(device)
|
| 806 |
+
|
| 807 |
+
if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
|
| 808 |
+
self._unquantized_part_loras = unquantized_part_loras
|
| 809 |
+
|
| 810 |
+
self._unquantized_part_sd = {k: v for k, v in self._unquantized_part_sd.items() if "pulid_ca" not in k}
|
| 811 |
+
self._update_unquantized_part_lora_params(1)
|
| 812 |
+
|
| 813 |
+
quantized_part_vectors = {}
|
| 814 |
+
for k, v in list(state_dict.items()):
|
| 815 |
+
if v.ndim == 1:
|
| 816 |
+
quantized_part_vectors[k] = state_dict.pop(k)
|
| 817 |
+
if len(self._quantized_part_vectors) > 0 or len(quantized_part_vectors) > 0:
|
| 818 |
+
self._quantized_part_vectors = quantized_part_vectors
|
| 819 |
+
updated_vectors = fuse_vectors(quantized_part_vectors, self._quantized_part_sd, 1)
|
| 820 |
+
state_dict.update(updated_vectors)
|
| 821 |
+
|
| 822 |
+
# Get the vectors from the quantized part
|
| 823 |
+
|
| 824 |
+
block = self.transformer_blocks[0]
|
| 825 |
+
assert isinstance(block, NunchakuFluxTransformerBlocks)
|
| 826 |
+
|
| 827 |
+
block.m.loadDict(state_dict, True)
|
| 828 |
+
|
| 829 |
+
def set_lora_strength(self, strength: float = 1):
|
| 830 |
+
"""
|
| 831 |
+
Sets the LoRA scaling strength for the model.
|
| 832 |
+
|
| 833 |
+
Note: This function can only be used with a single LoRA. For multiple LoRAs,
|
| 834 |
+
please fuse the LoRA scale into the weights.
|
| 835 |
+
|
| 836 |
+
Parameters
|
| 837 |
+
----------
|
| 838 |
+
strength : float, optional
|
| 839 |
+
LoRA scaling strength (default: 1).
|
| 840 |
+
|
| 841 |
+
Note: This function will change the strength of all the LoRAs. So only use it when you only have a single LoRA.
|
| 842 |
+
"""
|
| 843 |
+
block = self.transformer_blocks[0]
|
| 844 |
+
assert isinstance(block, NunchakuFluxTransformerBlocks)
|
| 845 |
+
block.m.setLoraScale(SVD_RANK, strength)
|
| 846 |
+
if len(self._unquantized_part_loras) > 0:
|
| 847 |
+
self._update_unquantized_part_lora_params(strength)
|
| 848 |
+
if len(self._quantized_part_vectors) > 0:
|
| 849 |
+
vector_dict = fuse_vectors(self._quantized_part_vectors, self._quantized_part_sd, strength)
|
| 850 |
+
block.m.loadDict(vector_dict, True)
|
| 851 |
+
|
| 852 |
+
def reset_x_embedder(self):
|
| 853 |
+
"""
|
| 854 |
+
Resets the x_embedder module if the input channel count has changed.
|
| 855 |
+
This is used for removing the effect of FLUX.1-tools LoRA which changes the input channels.
|
| 856 |
+
"""
|
| 857 |
+
# if change the model in channels, we need to update the x_embedder
|
| 858 |
+
if self._original_in_channels != self.config.in_channels:
|
| 859 |
+
assert self._original_in_channels < self.config.in_channels
|
| 860 |
+
old_module = self.x_embedder
|
| 861 |
+
new_module = nn.Linear(
|
| 862 |
+
in_features=self._original_in_channels,
|
| 863 |
+
out_features=old_module.out_features,
|
| 864 |
+
bias=old_module.bias is not None,
|
| 865 |
+
device=old_module.weight.device,
|
| 866 |
+
dtype=old_module.weight.dtype,
|
| 867 |
+
)
|
| 868 |
+
new_module.weight.data.copy_(old_module.weight.data[: new_module.out_features, : new_module.in_features])
|
| 869 |
+
self._unquantized_part_sd["x_embedder.weight"] = new_module.weight.data.clone()
|
| 870 |
+
if new_module.bias is not None:
|
| 871 |
+
new_module.bias.data.zero_()
|
| 872 |
+
new_module.bias.data.copy_(old_module.bias.data[: new_module.out_features])
|
| 873 |
+
self._unquantized_part_sd["x_embedder.bias"] = new_module.bias.data.clone()
|
| 874 |
+
self.x_embedder = new_module
|
| 875 |
+
setattr(self.config, "in_channels", self._original_in_channels)
|
| 876 |
+
|
| 877 |
+
def reset_lora(self):
|
| 878 |
+
"""
|
| 879 |
+
Resets all LoRA parameters to their default state.
|
| 880 |
+
"""
|
| 881 |
+
unquantized_part_loras = {}
|
| 882 |
+
if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
|
| 883 |
+
self._unquantized_part_loras = unquantized_part_loras
|
| 884 |
+
self._update_unquantized_part_lora_params(1)
|
| 885 |
+
state_dict = {k: v for k, v in self._quantized_part_sd.items() if "lora" in k}
|
| 886 |
+
quantized_part_vectors = {}
|
| 887 |
+
if len(self._quantized_part_vectors) > 0 or len(quantized_part_vectors) > 0:
|
| 888 |
+
self._quantized_part_vectors = quantized_part_vectors
|
| 889 |
+
updated_vectors = fuse_vectors(quantized_part_vectors, self._quantized_part_sd, 1)
|
| 890 |
+
state_dict.update(updated_vectors)
|
| 891 |
+
self.transformer_blocks[0].m.loadDict(state_dict, True)
|
| 892 |
+
self.reset_x_embedder()
|
| 893 |
+
|
| 894 |
+
def forward(
|
| 895 |
+
self,
|
| 896 |
+
hidden_states: torch.Tensor,
|
| 897 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 898 |
+
pooled_projections: torch.Tensor = None,
|
| 899 |
+
timestep: torch.LongTensor = None,
|
| 900 |
+
img_ids: torch.Tensor = None,
|
| 901 |
+
txt_ids: torch.Tensor = None,
|
| 902 |
+
guidance: torch.Tensor = None,
|
| 903 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 904 |
+
controlnet_block_samples=None,
|
| 905 |
+
controlnet_single_block_samples=None,
|
| 906 |
+
return_dict: bool = True,
|
| 907 |
+
controlnet_blocks_repeat: bool = False,
|
| 908 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
| 909 |
+
"""
|
| 910 |
+
Forward pass for the Nunchaku FLUX transformer model.
|
| 911 |
+
|
| 912 |
+
This method is compatible with the Diffusers pipeline and supports LoRA,
|
| 913 |
+
rotary embeddings, and ControlNet.
|
| 914 |
+
|
| 915 |
+
Parameters
|
| 916 |
+
----------
|
| 917 |
+
hidden_states : torch.FloatTensor
|
| 918 |
+
Input hidden states of shape (batch_size, channel, height, width).
|
| 919 |
+
encoder_hidden_states : torch.FloatTensor, optional
|
| 920 |
+
Conditional embeddings (e.g., prompt embeddings) of shape (batch_size, sequence_len, embed_dims).
|
| 921 |
+
pooled_projections : torch.FloatTensor, optional
|
| 922 |
+
Embeddings projected from the input conditions.
|
| 923 |
+
timestep : torch.LongTensor, optional
|
| 924 |
+
Denoising step.
|
| 925 |
+
img_ids : torch.Tensor, optional
|
| 926 |
+
Image token indices.
|
| 927 |
+
txt_ids : torch.Tensor, optional
|
| 928 |
+
Text token indices.
|
| 929 |
+
guidance : torch.Tensor, optional
|
| 930 |
+
Guidance tensor for classifier-free guidance.
|
| 931 |
+
joint_attention_kwargs : dict, optional
|
| 932 |
+
Additional kwargs for joint attention.
|
| 933 |
+
controlnet_block_samples : list[torch.Tensor], optional
|
| 934 |
+
ControlNet block samples.
|
| 935 |
+
controlnet_single_block_samples : list[torch.Tensor], optional
|
| 936 |
+
ControlNet single block samples.
|
| 937 |
+
return_dict : bool, optional
|
| 938 |
+
Whether to return a Transformer2DModelOutput (default: True).
|
| 939 |
+
controlnet_blocks_repeat : bool, optional
|
| 940 |
+
Whether to repeat ControlNet blocks (default: False).
|
| 941 |
+
|
| 942 |
+
Returns
|
| 943 |
+
-------
|
| 944 |
+
torch.FloatTensor or Transformer2DModelOutput
|
| 945 |
+
Output tensor or output object containing the sample.
|
| 946 |
+
"""
|
| 947 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 948 |
+
|
| 949 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
| 950 |
+
if guidance is not None:
|
| 951 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 952 |
+
else:
|
| 953 |
+
guidance = None
|
| 954 |
+
|
| 955 |
+
temb = (
|
| 956 |
+
self.time_text_embed(timestep, pooled_projections)
|
| 957 |
+
if guidance is None
|
| 958 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
| 959 |
+
)
|
| 960 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 961 |
+
|
| 962 |
+
if txt_ids.ndim == 3:
|
| 963 |
+
txt_ids = txt_ids[0]
|
| 964 |
+
if img_ids.ndim == 3:
|
| 965 |
+
img_ids = img_ids[0]
|
| 966 |
+
|
| 967 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
| 968 |
+
image_rotary_emb = self.pos_embed(ids)
|
| 969 |
+
|
| 970 |
+
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
| 971 |
+
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
| 972 |
+
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
|
| 973 |
+
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
|
| 974 |
+
|
| 975 |
+
nunchaku_block = self.transformer_blocks[0]
|
| 976 |
+
encoder_hidden_states, hidden_states = nunchaku_block(
|
| 977 |
+
hidden_states=hidden_states,
|
| 978 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 979 |
+
temb=temb,
|
| 980 |
+
image_rotary_emb=image_rotary_emb,
|
| 981 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 982 |
+
controlnet_block_samples=controlnet_block_samples,
|
| 983 |
+
controlnet_single_block_samples=controlnet_single_block_samples,
|
| 984 |
+
)
|
| 985 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 986 |
+
output = self.proj_out(hidden_states)
|
| 987 |
+
|
| 988 |
+
if not return_dict:
|
| 989 |
+
return (output,)
|
| 990 |
+
|
| 991 |
+
return Transformer2DModelOutput(sample=output)
|
nunchaku/models/transformers/transformer_flux_v2.py
ADDED
|
@@ -0,0 +1,646 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module provides Nunchaku FluxTransformer2DModel and its building blocks in Python.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 12 |
+
from diffusers.models.transformers.transformer_flux import (
|
| 13 |
+
FluxAttention,
|
| 14 |
+
FluxSingleTransformerBlock,
|
| 15 |
+
FluxTransformer2DModel,
|
| 16 |
+
FluxTransformerBlock,
|
| 17 |
+
)
|
| 18 |
+
from huggingface_hub import utils
|
| 19 |
+
from torch.nn import GELU
|
| 20 |
+
|
| 21 |
+
from ...ops.fused import fused_gelu_mlp
|
| 22 |
+
from ...utils import get_precision, pad_tensor
|
| 23 |
+
from ..attention import NunchakuBaseAttention, NunchakuFeedForward
|
| 24 |
+
from ..attention_processors.flux import NunchakuFluxFA2Processor, NunchakuFluxFP16AttnProcessor
|
| 25 |
+
from ..embeddings import NunchakuFluxPosEmbed, pack_rotemb
|
| 26 |
+
from ..linear import SVDQW4A4Linear
|
| 27 |
+
from ..normalization import NunchakuAdaLayerNormZero, NunchakuAdaLayerNormZeroSingle
|
| 28 |
+
from ..utils import fuse_linears
|
| 29 |
+
from .utils import NunchakuModelLoaderMixin
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class NunchakuFluxAttention(NunchakuBaseAttention):
    """
    Nunchaku-optimized FluxAttention module with quantized and fused QKV projections.

    Parameters
    ----------
    other : FluxAttention
        The original FluxAttention module to wrap and quantize.
    processor : str, optional
        The attention processor to use ("flashattn2" or "nunchaku-fp16").
    **kwargs
        Additional arguments for quantization.
    """

    def __init__(self, other: FluxAttention, processor: str = "flashattn2", **kwargs):
        super(NunchakuFluxAttention, self).__init__(processor)
        # Mirror the configuration and norm modules of the wrapped attention.
        for attr_name in (
            "head_dim",
            "inner_dim",
            "query_dim",
            "use_bias",
            "dropout",
            "out_dim",
            "context_pre_only",
            "pre_only",
            "heads",
            "added_kv_proj_dim",
            "added_proj_bias",
            "norm_q",
            "norm_k",
        ):
            setattr(self, attr_name, getattr(other, attr_name))

        # Fuse Q/K/V into a single quantized projection for efficiency.
        with torch.device("meta"):
            fused_qkv = fuse_linears([other.to_q, other.to_k, other.to_v])
        self.to_qkv = SVDQW4A4Linear.from_linear(fused_qkv, **kwargs)

        if not self.pre_only:
            self.to_out = other.to_out
            self.to_out[0] = SVDQW4A4Linear.from_linear(self.to_out[0], **kwargs)

        if self.added_kv_proj_dim is not None:
            self.norm_added_q = other.norm_added_q
            self.norm_added_k = other.norm_added_k

            # Fuse the context (added) Q/K/V projections as well.
            with torch.device("meta"):
                fused_add_qkv = fuse_linears([other.add_q_proj, other.add_k_proj, other.add_v_proj])
            self.add_qkv_proj = SVDQW4A4Linear.from_linear(fused_add_qkv, **kwargs)
            self.to_add_out = SVDQW4A4Linear.from_linear(other.to_add_out, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor = None,
        **kwargs,
    ):
        """
        Forward pass for NunchakuFluxAttention.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input tensor.
        encoder_hidden_states : torch.Tensor, optional
            Encoder hidden states for cross-attention.
        attention_mask : torch.Tensor, optional
            Attention mask.
        image_rotary_emb : tuple or torch.Tensor, optional
            Rotary embeddings for image/text tokens.
        **kwargs
            Additional arguments.

        Returns
        -------
        Output of the attention processor.
        """
        # Delegate all the work to the configured processor.
        return self.processor(
            attn=self,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )

    def set_processor(self, processor: str):
        """
        Set the attention processor.

        Parameters
        ----------
        processor : str
            Name of the processor ("flashattn2" or "nunchaku-fp16").

            - ``"flashattn2"``: Standard FlashAttention-2. See :class:`~nunchaku.models.attention_processors.flux.NunchakuFluxFA2Processor`.
            - ``"nunchaku-fp16"``: Uses FP16 attention accumulation, up to 1.2× faster than FlashAttention-2 on NVIDIA 30-, 40-, and 50-series GPUs. See :class:`~nunchaku.models.attention_processors.flux.NunchakuFluxFP16AttnProcessor`.

        Raises
        ------
        ValueError
            If the processor is not supported.
        """
        # Dispatch table instead of an if/elif chain.
        known_processors = {
            "flashattn2": NunchakuFluxFA2Processor,
            "nunchaku-fp16": NunchakuFluxFP16AttnProcessor,
        }
        if processor not in known_processors:
            raise ValueError(f"Processor {processor} is not supported")
        self.processor = known_processors[processor]()
|
| 142 |
+
|
| 143 |
+
class NunchakuFluxTransformerBlock(FluxTransformerBlock):
    """
    Nunchaku-optimized FluxTransformerBlock with quantized attention and feedforward layers.

    Parameters
    ----------
    block : FluxTransformerBlock
        The original block to wrap and quantize.
    scale_shift : float, optional
        Value to add to scale parameters. Default is 1.0.
        Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
    **kwargs
        Additional arguments for quantization.
    """

    def __init__(self, block: FluxTransformerBlock, scale_shift: float = 1, **kwargs):
        # Skip FluxTransformerBlock.__init__ (which would build fresh layers);
        # we instead wrap the layers of the existing `block`.
        super(FluxTransformerBlock, self).__init__()
        self.scale_shift = scale_shift

        # NOTE(review): Nunchaku checkpoints may already fuse the +1 scale shift
        # of AdaLayerNormZero into the linear weights; in that case the caller
        # passes scale_shift=0 so it is not applied twice.
        self.norm1 = NunchakuAdaLayerNormZero(block.norm1, scale_shift=scale_shift)
        self.norm1_context = NunchakuAdaLayerNormZero(block.norm1_context, scale_shift=scale_shift)

        # Quantized attention and feedforward; norm2/norm2_context are reused as-is.
        self.attn = NunchakuFluxAttention(block.attn, **kwargs)
        self.norm2 = block.norm2
        self.norm2_context = block.norm2_context
        self.ff = NunchakuFeedForward(block.ff, **kwargs)
        self.ff_context = NunchakuFeedForward(block.ff_context, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Forward pass for the transformer block.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input hidden states.
        encoder_hidden_states : torch.Tensor
            Encoder hidden states for cross-attention.
        temb : torch.Tensor
            Time or conditioning embedding.
        image_rotary_emb : tuple of torch.Tensor, optional
            Rotary embeddings for image/text tokens.
        joint_attention_kwargs : dict, optional
            Additional attention arguments (not supported).

        Returns
        -------
        tuple
            (encoder_hidden_states, hidden_states) after block processing.

        Raises
        ------
        NotImplementedError
            If joint_attention_kwargs is provided.
        """
        if joint_attention_kwargs is not None and len(joint_attention_kwargs) > 0:
            raise NotImplementedError("joint_attention_kwargs is not supported")

        # AdaLN-Zero modulation: each norm returns the normalized states plus
        # the gate/shift/scale parameters derived from `temb`.
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        joint_attention_kwargs = joint_attention_kwargs or {}

        # Attention. Returns 2 outputs normally, 3 when IP-Adapter states are present.
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        if len(attention_outputs) == 2:
            attn_output, context_attn_output = attention_outputs
        elif len(attention_outputs) == 3:
            attn_output, context_attn_output, ip_attn_output = attention_outputs

        # Process attention outputs for the `hidden_states` (image stream).
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * scale_mlp[:, None] + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output
        # IP-Adapter residual is added to the image stream only.
        if len(attention_outputs) == 3:
            hidden_states = hidden_states + ip_attn_output

        # Process attention outputs for the `encoder_hidden_states` (text stream).
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * c_scale_mlp[:, None] + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        # Clamp to the fp16 representable range to avoid overflow to inf.
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
|
| 259 |
+
|
| 260 |
+
class NunchakuFluxSingleTransformerBlock(FluxSingleTransformerBlock):
    """
    Nunchaku-optimized single transformer block with quantized attention and MLP.

    Parameters
    ----------
    block : FluxSingleTransformerBlock
        The original block to wrap and quantize.
    scale_shift : float, optional
        Value to add to scale parameters. Default is 1.0.
        Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
    **kwargs
        Additional arguments for quantization.
    """

    def __init__(self, block: FluxSingleTransformerBlock, scale_shift: float = 1, **kwargs):
        # Skip FluxSingleTransformerBlock.__init__ (which would build fresh
        # layers); we wrap the layers of the existing `block` instead.
        super(FluxSingleTransformerBlock, self).__init__()
        self.mlp_hidden_dim = block.mlp_hidden_dim
        # Fix: the old code first assigned ``self.norm = block.norm`` and then
        # immediately overwrote it; the redundant assignment has been removed.
        self.norm = NunchakuAdaLayerNormZeroSingle(block.norm, scale_shift=scale_shift)

        self.mlp_fc1 = SVDQW4A4Linear.from_linear(block.proj_mlp, **kwargs)
        self.act_mlp = block.act_mlp
        self.mlp_fc2 = SVDQW4A4Linear.from_linear(block.proj_out, in_features=self.mlp_hidden_dim, **kwargs)
        # For int4, we shift the activation of mlp_fc2 to make it unsigned.
        self.mlp_fc2.act_unsigned = self.mlp_fc2.precision != "nvfp4"

        self.attn = NunchakuFluxAttention(block.attn, **kwargs)
        # proj_out in the original block serves both MLP and attention outputs;
        # here the attention output projection gets its own quantized linear.
        self.attn.to_out = SVDQW4A4Linear.from_linear(block.proj_out, in_features=self.mlp_fc1.in_features, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        Forward pass for the single transformer block.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input hidden states.
        temb : torch.Tensor
            Time or conditioning embedding.
        image_rotary_emb : tuple of torch.Tensor, optional
            Rotary embeddings for tokens.
        joint_attention_kwargs : dict, optional
            Additional attention arguments.

        Returns
        -------
        torch.Tensor
            Output hidden states after block processing.
        """
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)

        # Feedforward
        if isinstance(self.act_mlp, GELU):
            # Use fused GELU MLP for efficiency.
            mlp_hidden_states = fused_gelu_mlp(norm_hidden_states, self.mlp_fc1, self.mlp_fc2)
        else:
            # Fallback to original MLP.
            mlp_hidden_states = self.mlp_fc1(norm_hidden_states)
            mlp_hidden_states = self.act_mlp(mlp_hidden_states)
            mlp_hidden_states = self.mlp_fc2(mlp_hidden_states)

        # Attention
        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs
        )

        # Gated residual: gate * (attn + mlp) + input.
        hidden_states = attn_output + mlp_hidden_states
        gate = gate.unsqueeze(1)
        hidden_states = gate * hidden_states
        hidden_states = residual + hidden_states
        # Clamp to the fp16 representable range to avoid overflow to inf.
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return hidden_states
|
| 344 |
+
|
| 345 |
+
class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoaderMixin):
    """
    Nunchaku-optimized FluxTransformer2DModel.

    Quantized variant of diffusers' ``FluxTransformer2DModel``: the joint and
    single transformer blocks are swapped for their Nunchaku counterparts and
    weights are loaded from a quantized safetensors checkpoint.
    """

    def _patch_model(self, **kwargs):
        """
        Patch the model with :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformerBlock`
        and :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxSingleTransformerBlock`.

        Parameters
        ----------
        **kwargs
            Additional arguments for quantization.

        Returns
        -------
        self : NunchakuFluxTransformer2DModelV2
            The patched model.
        """
        # Replace the positional embedding with the Nunchaku implementation,
        # reusing the axes_dim of the embed created by the base constructor.
        self.pos_embed = NunchakuFluxPosEmbed(dim=self.inner_dim, theta=10000, axes_dim=self.pos_embed.axes_dim)
        # scale_shift=0: the checkpoint already has the shift fused into the
        # adaLN linear weights, so the blocks must not add it again.
        for i, block in enumerate(self.transformer_blocks):
            self.transformer_blocks[i] = NunchakuFluxTransformerBlock(block, scale_shift=0, **kwargs)
        for i, block in enumerate(self.single_transformer_blocks):
            self.single_transformer_blocks[i] = NunchakuFluxSingleTransformerBlock(block, scale_shift=0, **kwargs)
        return self

    @classmethod
    @utils.validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
        """
        Load a pretrained NunchakuFluxTransformer2DModelV2 from a safetensors file.

        Parameters
        ----------
        pretrained_model_name_or_path : str or os.PathLike
            Path to the safetensors file. It can be a local file or a remote HuggingFace path.
        **kwargs
            Additional arguments (e.g., device, torch_dtype).

        Returns
        -------
        NunchakuFluxTransformer2DModelV2
            The loaded and quantized model.

        Raises
        ------
        NotImplementedError
            If offload is requested.
        AssertionError
            If the file is not a safetensors file.
        """
        device = kwargs.get("device", "cpu")
        offload = kwargs.get("offload", False)

        if offload:
            raise NotImplementedError("Offload is not supported for FluxTransformer2DModelV2")

        torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)

        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

        assert pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
            (".safetensors", ".sft")
        ), "Only safetensors are supported"
        # _build_model (NunchakuModelLoaderMixin) returns the skeleton module,
        # the raw checkpoint tensors, and the safetensors metadata dict.
        transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
        quantization_config = json.loads(metadata.get("quantization_config", "{}"))
        rank = quantization_config.get("rank", 32)
        transformer = transformer.to(torch_dtype)

        # get_precision() reports "fp4" on Blackwell-class GPUs; the quantized
        # linear layers expect the name "nvfp4".
        precision = get_precision()
        if precision == "fp4":
            precision = "nvfp4"
        transformer._patch_model(precision=precision, rank=rank)

        # Patching happened on meta/uninitialized storage; materialize on the
        # target device before loading real weights.
        transformer = transformer.to_empty(device=device)
        converted_state_dict = convert_flux_state_dict(model_state_dict)

        state_dict = transformer.state_dict()

        for k in state_dict.keys():
            if k not in converted_state_dict:
                # Only .wcscales entries may be absent from old checkpoints;
                # default them to 1 (identity scaling).
                assert ".wcscales" in k
                converted_state_dict[k] = torch.ones_like(state_dict[k])
            else:
                assert state_dict[k].dtype == converted_state_dict[k].dtype

        # Load the wtscale from the converted state dict.
        # wtscale is a plain Python attribute (not a registered tensor), so it
        # must be popped out before load_state_dict sees the dict.
        for n, m in transformer.named_modules():
            if isinstance(m, SVDQW4A4Linear):
                if m.wtscale is not None:
                    m.wtscale = converted_state_dict.pop(f"{n}.wtscale", 1.0)

        transformer.load_state_dict(converted_state_dict)

        return transformer

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        Forward pass for the NunchakuFluxTransformer2DModelV2.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input hidden states of shape (batch_size, image_sequence_length, in_channels).
        encoder_hidden_states : torch.Tensor, optional
            Conditional embeddings (e.g., from text).
        pooled_projections : torch.Tensor, optional
            Projected embeddings from input conditions.
        timestep : torch.LongTensor, optional
            Denoising step.
        img_ids : torch.Tensor, optional
            Image token IDs.
        txt_ids : torch.Tensor, optional
            Text token IDs.
        guidance : torch.Tensor, optional
            Guidance tensor for classifier-free guidance.
        joint_attention_kwargs : dict, optional
            Additional attention arguments.
        controlnet_block_samples : any, optional
            Not supported.
        controlnet_single_block_samples : any, optional
            Not supported.
        return_dict : bool, optional
            Whether to return a Transformer2DModelOutput (default: True).
        controlnet_blocks_repeat : bool, optional
            Not supported.

        Returns
        -------
        Transformer2DModelOutput or tuple
            Output sample tensor or output tuple.

        Raises
        ------
        NotImplementedError
            If controlnet is requested.
        """
        hidden_states = self.x_embedder(hidden_states)

        # Scheduler passes timesteps in [0, 1]; the embedding expects [0, 1000].
        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # Pipelines may pass batched (3D) ids; only the first batch entry is
        # used since the rotary grid is identical across the batch.
        if txt_ids.ndim == 3:
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        txt_tokens = encoder_hidden_states.shape[1]
        img_tokens = hidden_states.shape[1]

        assert image_rotary_emb.ndim == 6
        assert image_rotary_emb.shape[0] == 1
        assert image_rotary_emb.shape[1] == 1
        assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
        # [1, tokens, head_dim / 2, 1, 2] (sincos)
        image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
        rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]  # .to(self.dtype)
        rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]  # .to(self.dtype)
        # Single blocks attend over the concatenated txt+img sequence, so they
        # take the full (unsplit) rotary embedding.
        rotary_emb_single = image_rotary_emb

        # NOTE(review): padding the token dim to a multiple of 256 presumably
        # matches the quantized attention kernel's tile size — confirm.
        rotary_emb_txt = pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
        rotary_emb_img = pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
        rotary_emb_single = pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))

        for index_block, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=(rotary_emb_img, rotary_emb_txt),
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # Controlnet residual (not supported for now)
        if controlnet_block_samples is not None:
            raise NotImplementedError("Controlnet is not supported for FluxTransformer2DModelV2 for now")

        # Single blocks operate on the concatenated [text, image] sequence.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        for index_block, block in enumerate(self.single_transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=rotary_emb_single,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # Controlnet residual (not supported for now)
        if controlnet_single_block_samples is not None:
            raise NotImplementedError("Controlnet is not supported for FluxTransformer2DModelV2 for now")

        # Drop the text tokens; only the image part is decoded.
        hidden_states = hidden_states[:, txt_tokens:]
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
# Module-name renames for the single transformer blocks. Each entry is
# (needles, replacements): the first entry where *any* needle occurs in the
# key is applied (all of its replacements, in order); later entries are
# skipped — mirroring the original if/elif chain.
_SINGLE_BLOCK_RENAMES = (
    ((".qkv_proj.",), ((".qkv_proj.", ".attn.to_qkv."),)),
    ((".out_proj.",), ((".out_proj.", ".attn.to_out."),)),
    (
        (".norm_q.", ".norm_k."),
        ((".norm_k.", ".attn.norm_k."), (".norm_q.", ".attn.norm_q.")),
    ),
)

# Module-name renames for the joint transformer blocks. Note that the
# ".out_proj." needle cannot match ".out_proj_context." keys (the dot after
# "proj" does not occur there), so the entry order is safe.
_JOINT_BLOCK_RENAMES = (
    ((".mlp_context_fc1",), ((".mlp_context_fc1.", ".ff_context.net.0.proj."),)),
    ((".mlp_context_fc2",), ((".mlp_context_fc2.", ".ff_context.net.2."),)),
    ((".mlp_fc1",), ((".mlp_fc1.", ".ff.net.0.proj."),)),
    ((".mlp_fc2",), ((".mlp_fc2.", ".ff.net.2."),)),
    ((".qkv_proj_context.",), ((".qkv_proj_context.", ".attn.add_qkv_proj."),)),
    ((".qkv_proj.",), ((".qkv_proj.", ".attn.to_qkv."),)),
    (
        (".norm_q.", ".norm_k."),
        ((".norm_k.", ".attn.norm_k."), (".norm_q.", ".attn.norm_q.")),
    ),
    (
        (".norm_added_q.", ".norm_added_k."),
        ((".norm_added_k.", ".attn.norm_added_k."), (".norm_added_q.", ".attn.norm_added_q.")),
    ),
    ((".out_proj.",), ((".out_proj.", ".attn.to_out.0."),)),
    ((".out_proj_context.",), ((".out_proj_context.", ".attn.to_add_out."),)),
)


def _rename_module(key: str, table) -> str:
    """Apply the first matching module rename from ``table`` to ``key``.

    Returns ``key`` unchanged when no entry matches.
    """
    for needles, replacements in table:
        if any(needle in key for needle in needles):
            for old, new in replacements:
                key = key.replace(old, new)
            return key
    return key


def _rename_quant_suffixes(orig_key: str, key: str) -> str:
    """Rename quantization parameter suffixes (low-rank branch and smoothing factors).

    ``orig_key`` is the key *before* any module rename — the smooth checks
    intentionally inspect the original key, matching the legacy behavior.
    """
    key = key.replace(".lora_down", ".proj_down").replace(".lora_up", ".proj_up")
    # ".smooth_orig" must be tested first: ".smooth" is a prefix of it.
    if ".smooth_orig" in orig_key:
        key = key.replace(".smooth_orig", ".smooth_factor_orig")
    elif ".smooth" in orig_key:
        key = key.replace(".smooth", ".smooth_factor")
    return key


def convert_flux_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Convert a state dict from the :class:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel`
    format to :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformer2DModelV2` format.

    Keys of transformer blocks are remapped to the diffusers-style module
    layout (``.attn.*``, ``.ff.*``, ``.ff_context.*``) and quantization
    parameter suffixes are renamed (``lora_down/up`` -> ``proj_down/up``,
    ``smooth[_orig]`` -> ``smooth_factor[_orig]``). All other keys are copied
    through unchanged. Tensor values are never modified.

    Parameters
    ----------
    state_dict : dict[str, torch.Tensor]
        The original state dict.

    Returns
    -------
    dict[str, torch.Tensor]
        The converted state dict compatible with :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformer2DModelV2`.
    """
    new_state_dict = {}
    for k, v in state_dict.items():
        # "single_transformer_blocks." must be tested before
        # "transformer_blocks." since the latter is a substring of the former.
        if "single_transformer_blocks." in k:
            new_k = _rename_quant_suffixes(k, _rename_module(k, _SINGLE_BLOCK_RENAMES))
        elif "transformer_blocks." in k:
            new_k = _rename_quant_suffixes(k, _rename_module(k, _JOINT_BLOCK_RENAMES))
        else:
            new_k = k
        new_state_dict[new_k] = v

    return new_state_dict
|
nunchaku/models/transformers/transformer_qwenimage.py
ADDED
|
@@ -0,0 +1,601 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module provides implementations of NunchakuQwenImageTransformer2DModel and its building blocks.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import gc
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 10 |
+
from warnings import warn
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from diffusers.models.attention_processor import Attention
|
| 14 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 15 |
+
from diffusers.models.transformers.transformer_qwenimage import (
|
| 16 |
+
QwenEmbedRope,
|
| 17 |
+
QwenImageTransformer2DModel,
|
| 18 |
+
QwenImageTransformerBlock,
|
| 19 |
+
)
|
| 20 |
+
from huggingface_hub import utils
|
| 21 |
+
|
| 22 |
+
from ...utils import get_precision
|
| 23 |
+
from ..attention import NunchakuBaseAttention, NunchakuFeedForward
|
| 24 |
+
from ..attention_processors.qwenimage import NunchakuQwenImageNaiveFA2Processor
|
| 25 |
+
from ..linear import AWQW4A16Linear, SVDQW4A4Linear
|
| 26 |
+
from ..utils import CPUOffloadManager, fuse_linears
|
| 27 |
+
from .utils import NunchakuModelLoaderMixin
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class NunchakuQwenAttention(NunchakuBaseAttention):
    """
    Nunchaku-optimized quantized attention module for QwenImage.

    Wraps an existing QwenImage ``Attention`` module: its configuration and
    normalization submodules are carried over as-is, while the linear
    projections are fused (Q/K/V into a single matrix) and replaced with
    quantized :class:`SVDQW4A4Linear` layers.

    Parameters
    ----------
    other : Attention
        The original QwenImage Attention module to wrap and quantize.
    processor : str, default="flashattn2"
        The attention processor to use.
    **kwargs
        Additional arguments for quantization.
    """

    def __init__(self, other: Attention, processor: str = "flashattn2", **kwargs):
        super(NunchakuQwenAttention, self).__init__(processor)

        # Carry over every configuration flag and normalization submodule of
        # the wrapped attention verbatim (assignment order preserved).
        for name in (
            "inner_dim",
            "inner_kv_dim",
            "query_dim",
            "use_bias",
            "is_cross_attention",
            "cross_attention_dim",
            "upcast_attention",
            "upcast_softmax",
            "rescale_output_factor",
            "residual_connection",
            "dropout",
            "fused_projections",
            "out_dim",
            "out_context_dim",
            "context_pre_only",
            "pre_only",
            "is_causal",
            "scale_qk",
            "scale",
            "heads",
            "sliceable_head_dim",
            "added_kv_proj_dim",
            "only_cross_attention",
            "group_norm",
            "spatial_norm",
            "norm_cross",
            "norm_q",
            "norm_k",
            "norm_added_q",
            "norm_added_k",
        ):
            setattr(self, name, getattr(other, name))

        # Fuse the image-stream Q/K/V projections into one quantized linear.
        # The fusion happens on the meta device — no real weights yet.
        with torch.device("meta"):
            fused_qkv = fuse_linears([other.to_q, other.to_k, other.to_v])
        self.to_qkv = SVDQW4A4Linear.from_linear(fused_qkv, **kwargs)
        self.to_out = other.to_out
        self.to_out[0] = SVDQW4A4Linear.from_linear(self.to_out[0], **kwargs)

        assert self.added_kv_proj_dim is not None
        # Same treatment for the text-stream (added) Q/K/V projections.
        with torch.device("meta"):
            fused_add_qkv = fuse_linears([other.add_q_proj, other.add_k_proj, other.add_v_proj])
        self.add_qkv_proj = SVDQW4A4Linear.from_linear(fused_add_qkv, **kwargs)
        self.to_add_out = SVDQW4A4Linear.from_linear(other.to_add_out, **kwargs)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        encoder_hidden_states_mask: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Forward pass for NunchakuQwenAttention; delegates to the processor.

        Parameters
        ----------
        hidden_states : torch.FloatTensor
            Image stream input.
        encoder_hidden_states : torch.FloatTensor, optional
            Text stream input.
        encoder_hidden_states_mask : torch.FloatTensor, optional
            Mask for encoder hidden states.
        attention_mask : torch.FloatTensor, optional
            Attention mask.
        image_rotary_emb : torch.Tensor, optional
            Rotary embedding for images.
        **kwargs
            Additional arguments.

        Returns
        -------
        tuple
            Attention outputs for image and text streams.
        """
        args = (
            hidden_states,
            encoder_hidden_states,
            encoder_hidden_states_mask,
            attention_mask,
            image_rotary_emb,
        )
        return self.processor(self, *args, **kwargs)

    def set_processor(self, processor: str):
        """
        Set the attention processor.

        Parameters
        ----------
        processor : str
            Name of the processor to use. Only "flashattn2" is supported for now. See :class:`~nunchaku.models.attention_processors.qwenimage.NunchakuQwenImageNaiveFA2Processor`.

        Raises
        ------
        ValueError
            If the processor is not supported.
        """
        if processor != "flashattn2":
            raise ValueError(f"Processor {processor} is not supported")
        self.processor = NunchakuQwenImageNaiveFA2Processor()
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock):
    """
    Quantized QwenImage Transformer Block.

    This block supports quantized linear layers and joint attention for image and text streams.

    Parameters
    ----------
    other : QwenImageTransformerBlock
        The original transformer block to wrap and quantize.
    scale_shift : float, default=1.0
        Value to add to scale parameters. Default is 1.0.
        Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
    **kwargs
        Additional arguments for quantization.
    """

    def __init__(self, other: QwenImageTransformerBlock, scale_shift: float = 1.0, **kwargs):
        # Deliberately skip QwenImageTransformerBlock.__init__ (which would
        # build fresh submodules) and call the grandparent nn.Module init;
        # submodules are taken from `other` instead.
        super(QwenImageTransformerBlock, self).__init__()

        self.dim = other.dim
        # Image processing modules; the adaLN modulation linear (index 1 of
        # the [activation, linear] Sequential) is quantized to W4A16.
        self.img_mod = other.img_mod
        self.img_mod[1] = AWQW4A16Linear.from_linear(other.img_mod[1], **kwargs)
        self.img_norm1 = other.img_norm1
        self.attn = NunchakuQwenAttention(other.attn, **kwargs)
        self.img_norm2 = other.img_norm2
        self.img_mlp = NunchakuFeedForward(other.img_mlp, **kwargs)

        # Text processing modules
        self.txt_mod = other.txt_mod
        self.txt_mod[1] = AWQW4A16Linear.from_linear(other.txt_mod[1], **kwargs)
        self.txt_norm1 = other.txt_norm1
        # Text doesn't need separate attention - it's handled by img_attn joint computation
        self.txt_norm2 = other.txt_norm2
        self.txt_mlp = NunchakuFeedForward(other.txt_mlp, **kwargs)

        self.scale_shift = scale_shift

    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply adaLN modulation to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        mod_params : torch.Tensor
            Modulation parameters; split into (shift, scale, gate) thirds
            along the last dimension.

        Returns
        -------
        tuple
            Modulated tensor and gate tensor.
        """
        shift, scale, gate = mod_params.chunk(3, dim=-1)
        if self.scale_shift != 0:
            # In-place add on the chunk view: this intentionally writes the
            # shifted scale back into mod_params' storage (checkpoints without
            # the fused +1 need it applied here exactly once).
            scale.add_(self.scale_shift)
        return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_mask: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for NunchakuQwenImageTransformerBlock.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Image stream input.
        encoder_hidden_states : torch.Tensor
            Text stream input.
        encoder_hidden_states_mask : torch.Tensor
            Mask for encoder hidden states.
        temb : torch.Tensor
            Temporal embedding.
        image_rotary_emb : tuple of torch.Tensor, optional
            Rotary embedding for images.
        joint_attention_kwargs : dict, optional
            Additional arguments for joint attention.

        Returns
        -------
        tuple
            Updated encoder_hidden_states and hidden_states.
        """
        # Get modulation parameters for both streams
        img_mod_params = self.img_mod(temb)  # [B, 6*dim]
        txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]

        # nunchaku's mod_params is [B, 6*dim] instead of [B, dim*6]
        # (per-channel-major vs per-parameter-major); transpose to the layout
        # _modulate expects.
        img_mod_params = (
            img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
        )
        txt_mod_params = (
            txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
        )

        img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
        txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]

        # Process image stream - norm1 + modulation
        img_normed = self.img_norm1(hidden_states)
        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)

        # Process text stream - norm1 + modulation
        txt_normed = self.txt_norm1(encoder_hidden_states)
        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)

        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=img_modulated,
            encoder_hidden_states=txt_modulated,
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
        img_attn_output, txt_attn_output = attn_output

        # Apply attention gates and add residual (like in Megatron)
        hidden_states = hidden_states + img_gate1 * img_attn_output
        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output

        # Process image stream - norm2 + MLP
        img_normed2 = self.img_norm2(hidden_states)
        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
        img_mlp_output = self.img_mlp(img_modulated2)
        hidden_states = hidden_states + img_gate2 * img_mlp_output

        # Process text stream - norm2 + MLP
        txt_normed2 = self.txt_norm2(encoder_hidden_states)
        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
        txt_mlp_output = self.txt_mlp(txt_modulated2)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output

        # Clip to prevent overflow for fp16
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuModelLoaderMixin):
|
| 307 |
+
"""
|
| 308 |
+
Quantized QwenImage Transformer2DModel.
|
| 309 |
+
|
| 310 |
+
This model supports quantized transformer blocks and optional CPU offloading for memory efficiency.
|
| 311 |
+
|
| 312 |
+
Parameters
|
| 313 |
+
----------
|
| 314 |
+
*args
|
| 315 |
+
Positional arguments for the base model.
|
| 316 |
+
**kwargs
|
| 317 |
+
Keyword arguments for the base model and quantization.
|
| 318 |
+
|
| 319 |
+
Attributes
|
| 320 |
+
----------
|
| 321 |
+
offload : bool
|
| 322 |
+
Whether CPU offloading is enabled.
|
| 323 |
+
offload_manager : CPUOffloadManager or None
|
| 324 |
+
Manager for offloading transformer blocks.
|
| 325 |
+
_is_initialized : bool
|
| 326 |
+
Whether the model has been patched for quantization.
|
| 327 |
+
"""
|
| 328 |
+
|
| 329 |
+
    def __init__(self, *args, **kwargs):
        """
        Initialize the model, extracting Nunchaku-specific options first.

        Parameters
        ----------
        *args
            Positional arguments forwarded to ``QwenImageTransformer2DModel``.
        **kwargs
            Keyword arguments forwarded to the base model. The ``offload``
            key (bool, default False) is consumed here and NOT forwarded.
        """
        # Pop `offload` before delegating so the base constructor never sees
        # an unexpected keyword. These plain attributes are set before
        # super().__init__() on purpose: they must exist even while the base
        # init runs.
        self.offload = kwargs.pop("offload", False)
        self.offload_manager = None
        self._is_initialized = False
        super().__init__(*args, **kwargs)
|
| 334 |
+
|
| 335 |
+
def _patch_model(self, **kwargs):
|
| 336 |
+
"""
|
| 337 |
+
Patch the transformer blocks for quantization.
|
| 338 |
+
|
| 339 |
+
Parameters
|
| 340 |
+
----------
|
| 341 |
+
**kwargs
|
| 342 |
+
Additional arguments for quantization.
|
| 343 |
+
|
| 344 |
+
Returns
|
| 345 |
+
-------
|
| 346 |
+
self
|
| 347 |
+
"""
|
| 348 |
+
for i, block in enumerate(self.transformer_blocks):
|
| 349 |
+
self.transformer_blocks[i] = NunchakuQwenImageTransformerBlock(block, scale_shift=0, **kwargs)
|
| 350 |
+
self._is_initialized = True
|
| 351 |
+
return self
|
| 352 |
+
|
| 353 |
+
@classmethod
|
| 354 |
+
@utils.validate_hf_hub_args
|
| 355 |
+
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
|
| 356 |
+
"""
|
| 357 |
+
Load a quantized model from a pretrained checkpoint.
|
| 358 |
+
|
| 359 |
+
Parameters
|
| 360 |
+
----------
|
| 361 |
+
pretrained_model_name_or_path : str or os.PathLike
|
| 362 |
+
Path to the pretrained model checkpoint. It can be a local file or a remote HuggingFace path.
|
| 363 |
+
**kwargs
|
| 364 |
+
Additional arguments for loading and quantization.
|
| 365 |
+
|
| 366 |
+
Returns
|
| 367 |
+
-------
|
| 368 |
+
NunchakuQwenImageTransformer2DModel
|
| 369 |
+
The loaded and quantized model.
|
| 370 |
+
|
| 371 |
+
Raises
|
| 372 |
+
------
|
| 373 |
+
AssertionError
|
| 374 |
+
If the checkpoint is not a safetensors file.
|
| 375 |
+
"""
|
| 376 |
+
device = kwargs.get("device", "cpu")
|
| 377 |
+
offload = kwargs.get("offload", False)
|
| 378 |
+
|
| 379 |
+
torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
|
| 380 |
+
|
| 381 |
+
if isinstance(pretrained_model_name_or_path, str):
|
| 382 |
+
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
|
| 383 |
+
|
| 384 |
+
assert pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
|
| 385 |
+
(".safetensors", ".sft")
|
| 386 |
+
), "Only safetensors are supported"
|
| 387 |
+
transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
|
| 388 |
+
quantization_config = json.loads(metadata.get("quantization_config", "{}"))
|
| 389 |
+
config = json.loads(metadata.get("config", "{}"))
|
| 390 |
+
rank = quantization_config.get("rank", 32)
|
| 391 |
+
transformer = transformer.to(torch_dtype)
|
| 392 |
+
|
| 393 |
+
precision = get_precision()
|
| 394 |
+
if precision == "fp4":
|
| 395 |
+
precision = "nvfp4"
|
| 396 |
+
transformer._patch_model(precision=precision, rank=rank)
|
| 397 |
+
|
| 398 |
+
transformer = transformer.to_empty(device=device)
|
| 399 |
+
# need to re-init the pos_embed as to_empty does not work on it
|
| 400 |
+
transformer.pos_embed = QwenEmbedRope(
|
| 401 |
+
theta=10000, axes_dim=list(config.get("axes_dims_rope", [16, 56, 56])), scale_rope=True
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
state_dict = transformer.state_dict()
|
| 405 |
+
for k in state_dict.keys():
|
| 406 |
+
if k not in model_state_dict:
|
| 407 |
+
assert ".wcscales" in k
|
| 408 |
+
model_state_dict[k] = torch.ones_like(state_dict[k])
|
| 409 |
+
else:
|
| 410 |
+
assert state_dict[k].dtype == model_state_dict[k].dtype
|
| 411 |
+
|
| 412 |
+
# load the wtscale from the state dict, as it is a float on CPU
|
| 413 |
+
for n, m in transformer.named_modules():
|
| 414 |
+
if isinstance(m, SVDQW4A4Linear):
|
| 415 |
+
if m.wtscale is not None:
|
| 416 |
+
m.wtscale = model_state_dict.pop(f"{n}.wtscale", 1.0)
|
| 417 |
+
transformer.load_state_dict(model_state_dict)
|
| 418 |
+
transformer.set_offload(offload)
|
| 419 |
+
|
| 420 |
+
return transformer
|
| 421 |
+
|
| 422 |
+
def set_offload(self, offload: bool, **kwargs):
|
| 423 |
+
"""
|
| 424 |
+
Enable or disable asynchronous CPU offloading for transformer blocks.
|
| 425 |
+
|
| 426 |
+
Parameters
|
| 427 |
+
----------
|
| 428 |
+
offload : bool
|
| 429 |
+
Whether to enable offloading.
|
| 430 |
+
**kwargs
|
| 431 |
+
Additional arguments for offload manager.
|
| 432 |
+
|
| 433 |
+
See Also
|
| 434 |
+
--------
|
| 435 |
+
:class:`~nunchaku.models.utils.CPUOffloadManager`
|
| 436 |
+
"""
|
| 437 |
+
if offload == self.offload:
|
| 438 |
+
# nothing changed, just return
|
| 439 |
+
return
|
| 440 |
+
self.offload = offload
|
| 441 |
+
if offload:
|
| 442 |
+
self.offload_manager = CPUOffloadManager(
|
| 443 |
+
self.transformer_blocks,
|
| 444 |
+
use_pin_memory=kwargs.get("use_pin_memory", True),
|
| 445 |
+
on_gpu_modules=[
|
| 446 |
+
self.img_in,
|
| 447 |
+
self.txt_in,
|
| 448 |
+
self.txt_norm,
|
| 449 |
+
self.time_text_embed,
|
| 450 |
+
self.norm_out,
|
| 451 |
+
self.proj_out,
|
| 452 |
+
],
|
| 453 |
+
num_blocks_on_gpu=kwargs.get("num_blocks_on_gpu", 1),
|
| 454 |
+
)
|
| 455 |
+
else:
|
| 456 |
+
self.offload_manager = None
|
| 457 |
+
gc.collect()
|
| 458 |
+
torch.cuda.empty_cache()
|
| 459 |
+
|
| 460 |
+
def forward(
|
| 461 |
+
self,
|
| 462 |
+
hidden_states: torch.Tensor,
|
| 463 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 464 |
+
encoder_hidden_states_mask: torch.Tensor = None,
|
| 465 |
+
timestep: torch.LongTensor = None,
|
| 466 |
+
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
|
| 467 |
+
txt_seq_lens: Optional[List[int]] = None,
|
| 468 |
+
guidance: torch.Tensor = None,
|
| 469 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 470 |
+
return_dict: bool = True,
|
| 471 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 472 |
+
"""
|
| 473 |
+
Forward pass for the quantized QwenImage transformer model.
|
| 474 |
+
|
| 475 |
+
Parameters
|
| 476 |
+
----------
|
| 477 |
+
hidden_states : torch.Tensor
|
| 478 |
+
Image stream input.
|
| 479 |
+
encoder_hidden_states : torch.Tensor, optional
|
| 480 |
+
Text stream input.
|
| 481 |
+
encoder_hidden_states_mask : torch.Tensor, optional
|
| 482 |
+
Mask for encoder hidden states.
|
| 483 |
+
timestep : torch.LongTensor, optional
|
| 484 |
+
Timestep for temporal embedding.
|
| 485 |
+
img_shapes : list of tuple, optional
|
| 486 |
+
Image shapes for rotary embedding.
|
| 487 |
+
txt_seq_lens : list of int, optional
|
| 488 |
+
Text sequence lengths.
|
| 489 |
+
guidance : torch.Tensor, optional
|
| 490 |
+
Guidance tensor (for classifier-free guidance).
|
| 491 |
+
attention_kwargs : dict, optional
|
| 492 |
+
Additional attention arguments.
|
| 493 |
+
return_dict : bool, default=True
|
| 494 |
+
Whether to return a dict or tuple.
|
| 495 |
+
|
| 496 |
+
Returns
|
| 497 |
+
-------
|
| 498 |
+
torch.Tensor or Transformer2DModelOutput
|
| 499 |
+
Model output.
|
| 500 |
+
"""
|
| 501 |
+
device = hidden_states.device
|
| 502 |
+
if self.offload:
|
| 503 |
+
self.offload_manager.set_device(device)
|
| 504 |
+
|
| 505 |
+
hidden_states = self.img_in(hidden_states)
|
| 506 |
+
|
| 507 |
+
timestep = timestep.to(hidden_states.dtype)
|
| 508 |
+
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
| 509 |
+
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
| 510 |
+
|
| 511 |
+
if guidance is not None:
|
| 512 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 513 |
+
|
| 514 |
+
temb = (
|
| 515 |
+
self.time_text_embed(timestep, hidden_states)
|
| 516 |
+
if guidance is None
|
| 517 |
+
else self.time_text_embed(timestep, guidance, hidden_states)
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
| 521 |
+
|
| 522 |
+
compute_stream = torch.cuda.current_stream()
|
| 523 |
+
if self.offload:
|
| 524 |
+
self.offload_manager.initialize(compute_stream)
|
| 525 |
+
for block_idx, block in enumerate(self.transformer_blocks):
|
| 526 |
+
with torch.cuda.stream(compute_stream):
|
| 527 |
+
if self.offload:
|
| 528 |
+
block = self.offload_manager.get_block(block_idx)
|
| 529 |
+
encoder_hidden_states, hidden_states = block(
|
| 530 |
+
hidden_states=hidden_states,
|
| 531 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 532 |
+
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
| 533 |
+
temb=temb,
|
| 534 |
+
image_rotary_emb=image_rotary_emb,
|
| 535 |
+
joint_attention_kwargs=attention_kwargs,
|
| 536 |
+
)
|
| 537 |
+
if self.offload:
|
| 538 |
+
self.offload_manager.step(compute_stream)
|
| 539 |
+
|
| 540 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 541 |
+
output = self.proj_out(hidden_states)
|
| 542 |
+
|
| 543 |
+
torch.cuda.empty_cache()
|
| 544 |
+
|
| 545 |
+
if not return_dict:
|
| 546 |
+
return (output,)
|
| 547 |
+
|
| 548 |
+
return Transformer2DModelOutput(sample=output)
|
| 549 |
+
|
| 550 |
+
def to(self, *args, **kwargs):
|
| 551 |
+
"""
|
| 552 |
+
Override the default ``.to()`` method.
|
| 553 |
+
|
| 554 |
+
If offload is enabled, prevents moving the model to GPU.
|
| 555 |
+
Prevents changing dtype after quantization.
|
| 556 |
+
|
| 557 |
+
Parameters
|
| 558 |
+
----------
|
| 559 |
+
*args
|
| 560 |
+
Positional arguments for ``.to()``.
|
| 561 |
+
**kwargs
|
| 562 |
+
Keyword arguments for ``.to()``.
|
| 563 |
+
|
| 564 |
+
Returns
|
| 565 |
+
-------
|
| 566 |
+
self
|
| 567 |
+
|
| 568 |
+
Raises
|
| 569 |
+
------
|
| 570 |
+
ValueError
|
| 571 |
+
If attempting to change dtype after quantization.
|
| 572 |
+
"""
|
| 573 |
+
device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
|
| 574 |
+
dtype_present_in_args = "dtype" in kwargs
|
| 575 |
+
|
| 576 |
+
# Try converting arguments to torch.device in case they are passed as strings
|
| 577 |
+
for arg in args:
|
| 578 |
+
if not isinstance(arg, str):
|
| 579 |
+
continue
|
| 580 |
+
try:
|
| 581 |
+
torch.device(arg)
|
| 582 |
+
device_arg_or_kwarg_present = True
|
| 583 |
+
except RuntimeError:
|
| 584 |
+
pass
|
| 585 |
+
|
| 586 |
+
if not dtype_present_in_args:
|
| 587 |
+
for arg in args:
|
| 588 |
+
if isinstance(arg, torch.dtype):
|
| 589 |
+
dtype_present_in_args = True
|
| 590 |
+
break
|
| 591 |
+
|
| 592 |
+
if dtype_present_in_args and self._is_initialized:
|
| 593 |
+
raise ValueError(
|
| 594 |
+
"Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
|
| 595 |
+
"use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`."
|
| 596 |
+
)
|
| 597 |
+
if self.offload:
|
| 598 |
+
if device_arg_or_kwarg_present:
|
| 599 |
+
warn("Skipping moving the model to GPU as offload is enabled", UserWarning)
|
| 600 |
+
return self
|
| 601 |
+
return super(type(self), self).to(*args, **kwargs)
|
nunchaku/models/transformers/transformer_sana.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Implements the :class:`NunchakuSanaTransformer2DModel`,
|
| 3 |
+
a quantized Sana transformer for Diffusers with efficient inference support.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from diffusers import SanaTransformer2DModel
|
| 12 |
+
from huggingface_hub import utils
|
| 13 |
+
from safetensors.torch import load_file
|
| 14 |
+
from torch import nn
|
| 15 |
+
from torch.nn import functional as F
|
| 16 |
+
|
| 17 |
+
from ..._C import QuantizedSanaModel
|
| 18 |
+
from ..._C import utils as cutils
|
| 19 |
+
from ...utils import get_precision
|
| 20 |
+
from .utils import NunchakuModelLoaderMixin
|
| 21 |
+
|
| 22 |
+
SVD_RANK = 32
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class NunchakuSanaTransformerBlocks(nn.Module):
    """
    Wrapper for quantized Sana transformer blocks.

    This module wraps a QuantizedSanaModel and provides forward methods compatible
    with the expected transformer block interface.

    Parameters
    ----------
    m : QuantizedSanaModel
        The quantized transformer model.
    dtype : torch.dtype
        The data type to use for computation.
    device : str or torch.device
        The device to run the model on.
    """

    def __init__(self, m: QuantizedSanaModel, dtype: torch.dtype, device: str | torch.device):
        super(NunchakuSanaTransformerBlocks, self).__init__()
        self.m = m
        self.dtype = dtype
        self.device = device

    def _prepare_inputs(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        height: Optional[int],
        width: Optional[int],
    ):
        """
        Shared preprocessing for :meth:`forward` and :meth:`forward_layer_at`.

        Packs the text tokens kept by ``encoder_attention_mask`` into a dense
        tensor, builds cumulative sequence-length tensors for the varlen kernel,
        and resolves ``height``/``width`` from the image token count.

        Returns
        -------
        tuple
            ``(batch_size, packed_encoder_hidden_states, cu_seqlens_img,
            cu_seqlens_txt, height, width)``
        """
        batch_size = hidden_states.shape[0]
        img_tokens = hidden_states.shape[1]
        txt_tokens = encoder_hidden_states.shape[1]

        assert encoder_attention_mask is not None
        assert encoder_attention_mask.shape == (batch_size, 1, txt_tokens)

        mask = encoder_attention_mask.reshape(batch_size, txt_tokens)
        # Additive attention masks use a large negative value for padded
        # positions; "> -9000" selects the valid (unmasked) text tokens.
        packed_encoder_hidden_states = encoder_hidden_states[mask > -9000]

        cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
        cu_seqlens_img = torch.arange(
            0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32, device=self.device
        )

        # Resolve the latent grid size: assume a square grid when neither side
        # is given, otherwise derive the missing side from the token count.
        if height is None and width is None:
            height = width = int(img_tokens**0.5)
        elif height is None:
            height = img_tokens // width
        elif width is None:
            width = img_tokens // height
        assert height * width == img_tokens

        return batch_size, packed_encoder_hidden_states, cu_seqlens_img, cu_seqlens_txt, height, width

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        skip_first_layer: Optional[bool] = False,
    ):
        """
        Forward pass through all quantized transformer blocks.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input hidden states of shape (batch_size, img_tokens, ...).
        attention_mask : torch.Tensor, optional
            Not used.
        encoder_hidden_states : torch.Tensor, optional
            Encoder hidden states of shape (batch_size, txt_tokens, ...).
        encoder_attention_mask : torch.Tensor, optional
            Encoder attention mask of shape (batch_size, 1, txt_tokens).
        timestep : torch.LongTensor, optional
            Timestep tensor.
        height : int, optional
            Image height.
        width : int, optional
            Image width.
        skip_first_layer : bool, optional
            Whether to skip the first layer.

        Returns
        -------
        torch.Tensor
            Output tensor after passing through the quantized transformer blocks.
        """
        original_dtype = hidden_states.dtype
        original_device = hidden_states.device

        (
            batch_size,
            packed_encoder_hidden_states,
            cu_seqlens_img,
            cu_seqlens_txt,
            height,
            width,
        ) = self._prepare_inputs(hidden_states, encoder_hidden_states, encoder_attention_mask, height, width)

        return (
            self.m.forward(
                hidden_states.to(self.dtype).to(self.device),
                packed_encoder_hidden_states.to(self.dtype).to(self.device),
                timestep.to(self.dtype).to(self.device),
                cu_seqlens_img.to(self.device),
                cu_seqlens_txt.to(self.device),
                height,
                width,
                batch_size % 3 == 0,  # pag is set when loading the model, FIXME: pag_scale == 0
                True,  # TODO: find a way to detect if we are doing CFG
                skip_first_layer,
            )
            .to(original_dtype)
            .to(original_device)
        )

    def forward_layer_at(
        self,
        idx: int,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
    ):
        """
        Forward pass through a specific quantized transformer layer.

        Parameters
        ----------
        idx : int
            Index of the layer to run.
        hidden_states : torch.Tensor
            Input hidden states.
        attention_mask : torch.Tensor, optional
            Not used.
        encoder_hidden_states : torch.Tensor, optional
            Encoder hidden states.
        encoder_attention_mask : torch.Tensor, optional
            Encoder attention mask.
        timestep : torch.LongTensor, optional
            Timestep tensor.
        height : int, optional
            Image height.
        width : int, optional
            Image width.

        Returns
        -------
        torch.Tensor
            Output tensor after passing through the specified quantized transformer layer.
        """
        original_dtype = hidden_states.dtype
        original_device = hidden_states.device

        (
            batch_size,
            packed_encoder_hidden_states,
            cu_seqlens_img,
            cu_seqlens_txt,
            height,
            width,
        ) = self._prepare_inputs(hidden_states, encoder_hidden_states, encoder_attention_mask, height, width)

        return (
            self.m.forward_layer(
                idx,
                hidden_states.to(self.dtype).to(self.device),
                packed_encoder_hidden_states.to(self.dtype).to(self.device),
                timestep.to(self.dtype).to(self.device),
                cu_seqlens_img.to(self.device),
                cu_seqlens_txt.to(self.device),
                height,
                width,
                batch_size % 3 == 0,  # pag is set when loading the model, FIXME: pag_scale == 0
                True,  # TODO: find a way to detect if we are doing CFG
            )
            .to(original_dtype)
            .to(original_device)
        )

    def __del__(self):
        """
        Destructor to reset the quantized model and free resources.
        """
        # Guard against a partially constructed instance: if __init__ raised
        # before assigning self.m, __del__ must not raise AttributeError.
        m = getattr(self, "m", None)
        if m is not None:
            m.reset()
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
    """
    SanaTransformer2DModel with Nunchaku quantized backend support.

    This class extends the base SanaTransformer2DModel to support loading and
    injecting quantized transformer blocks using Nunchaku's custom backend.
    """

    @classmethod
    @utils.validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
        """
        Load a pretrained NunchakuSanaTransformer2DModel from a local file or HuggingFace Hub.

        This method supports both quantized and unquantized checkpoints, and will
        automatically inject quantized transformer blocks if available.

        Parameters
        ----------
        pretrained_model_name_or_path : str or os.PathLike
            Path to the model checkpoint or HuggingFace Hub model name.
        **kwargs
            Additional keyword arguments for model loading
            (``device``, ``pag_layers``, ``precision``, ``return_metadata``, ...).

        Returns
        -------
        NunchakuSanaTransformer2DModel or (NunchakuSanaTransformer2DModel, dict)
            The loaded model, and optionally metadata if ``return_metadata=True``.
        """
        device = kwargs.get("device", "cuda")
        if isinstance(device, str):
            device = torch.device(device)
        pag_layers = kwargs.get("pag_layers", [])
        precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
        # metadata stays None on the legacy (folder) loading path below.
        metadata = None

        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
            (".safetensors", ".sft")
        ):
            # Single-file path: split the checkpoint into quantized blocks
            # (loaded by the C backend) and the remaining unquantized layers.
            transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path)
            quantized_part_sd = {}
            unquantized_part_sd = {}
            for k, v in model_state_dict.items():
                if k.startswith("transformer_blocks."):
                    quantized_part_sd[k] = v
                else:
                    unquantized_part_sd[k] = v
            m = load_quantized_module(
                transformer, quantized_part_sd, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
            )
            transformer.inject_quantized_module(m, device)
            transformer.to_empty(device=device)
            transformer.load_state_dict(unquantized_part_sd, strict=False)
        else:
            # Legacy folder layout with two safetensors files; deprecated.
            transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
                pretrained_model_name_or_path, **kwargs
            )
            m = load_quantized_module(
                transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
            )
            transformer.inject_quantized_module(m, device)
            transformer.to_empty(device=device)
            unquantized_state_dict = load_file(unquantized_part_path)
            transformer.load_state_dict(unquantized_state_dict, strict=False)
        if kwargs.get("return_metadata", False):
            return transformer, metadata
        else:
            return transformer

    def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
        """
        Inject a quantized transformer module into this model.

        Replaces ``self.transformer_blocks`` with a single wrapper module that
        dispatches to the quantized backend.

        Parameters
        ----------
        m : QuantizedSanaModel
            The quantized transformer module to inject.
        device : str or torch.device, optional
            The device to place the module on (default: "cuda").

        Returns
        -------
        NunchakuSanaTransformer2DModel
            The model with the quantized module injected.
        """
        self.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, self.dtype, device)])
        return self
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def load_quantized_module(
    net: SanaTransformer2DModel,
    path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
    device: str | torch.device = "cuda",
    pag_layers: int | list[int] | None = None,
    use_fp4: bool = False,
) -> QuantizedSanaModel:
    """
    Construct a :class:`QuantizedSanaModel` and populate it with quantized weights.

    Parameters
    ----------
    net : SanaTransformer2DModel
        The base transformer model (for config and dtype).
    path_or_state_dict : str, os.PathLike, or dict
        Path to the quantized weights or a state dict.
    device : str or torch.device, optional
        Device to load the quantized model on (default: "cuda"). Must be CUDA.
    pag_layers : int, list of int, or None, optional
        List of layers to use pag (default: None).
    use_fp4 : bool, optional
        Whether to use FP4 quantization (default: False).

    Returns
    -------
    QuantizedSanaModel
        The loaded quantized model.
    """
    # Normalize pag_layers to a list.
    if pag_layers is None:
        pag_layers = []
    elif isinstance(pag_layers, int):
        pag_layers = [pag_layers]

    target_device = torch.device(device)
    assert target_device.type == "cuda"
    device_index = 0 if target_device.index is None else target_device.index

    quantized = QuantizedSanaModel()
    cutils.disable_memory_auto_release()
    quantized.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, device_index)

    # Weights may come either as an in-memory state dict or as a file path.
    if isinstance(path_or_state_dict, dict):
        quantized.loadDict(path_or_state_dict, True)
    else:
        quantized.load(str(path_or_state_dict))
    return quantized
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def inject_quantized_module(
    net: SanaTransformer2DModel, m: QuantizedSanaModel, device: torch.device
) -> SanaTransformer2DModel:
    """
    Replace a SanaTransformer2DModel's blocks with a quantized wrapper.

    Parameters
    ----------
    net : SanaTransformer2DModel
        The base transformer model.
    m : QuantizedSanaModel
        The quantized transformer module to inject.
    device : torch.device
        The device to place the module on.

    Returns
    -------
    SanaTransformer2DModel
        The model with the quantized module injected.
    """
    # A single wrapper module stands in for the whole stack of blocks.
    wrapper = NunchakuSanaTransformerBlocks(m, net.dtype, device)
    net.transformer_blocks = torch.nn.ModuleList([wrapper])
    return net
|
nunchaku/models/transformers/utils.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for Nunchaku transformer model loading.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from diffusers import __version__
|
| 12 |
+
from huggingface_hub import constants, hf_hub_download
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
from ...utils import load_state_dict_in_safetensors
|
| 16 |
+
|
| 17 |
+
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()

# Configure logging
# NOTE(review): logging.basicConfig at import time configures the *root* logger
# for the entire process; library modules conventionally configure only their
# own logger and leave root configuration to the application — confirm intended.
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class NunchakuModelLoaderMixin:
    """
    Mixin for standardized model loading in Nunchaku transformer models.
    """

    @classmethod
    def _build_model(
        cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
    ) -> tuple[nn.Module, dict[str, torch.Tensor], dict[str, str]]:
        """
        Build a transformer model from a safetensors file.

        Parameters
        ----------
        pretrained_model_name_or_path : str or os.PathLike
            Path to the safetensors file.
        **kwargs
            Additional keyword arguments (e.g., ``torch_dtype``).

        Returns
        -------
        tuple
            (transformer, state_dict, metadata)
        """
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)

        # The model config is embedded as a JSON string in the safetensors metadata.
        config = json.loads(metadata["config"])

        # Build on the meta device so no real memory is allocated; parameters are
        # materialized later when the quantized state dict is loaded.
        with torch.device("meta"):
            transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))

        return transformer, state_dict, metadata

    @classmethod
    def _build_model_legacy(
        cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
    ) -> tuple[nn.Module, str, str]:
        """
        Build a transformer model from a legacy folder structure.

        .. warning::
            This method is deprecated and will be removed in December 2025.
            Please use :meth:`_build_model` instead.

        Parameters
        ----------
        pretrained_model_name_or_path : str or os.PathLike
            Path to the folder containing model weights.
        **kwargs
            Additional keyword arguments for HuggingFace Hub download and config loading.

        Returns
        -------
        tuple
            (transformer, unquantized_part_path, transformer_block_path)
        """
        logger.warning(
            "Loading models from a folder will be deprecated in December 2025. "
            "Please download the latest safetensors model, or use one of the following tools to "
            "merge your model into a single file: the CLI utility `python -m nunchaku.merge_safetensors` "
            "or the ComfyUI workflow `merge_safetensors.json`."
        )
        # Pop (not get) the keys that are later passed explicitly to `cls.load_config`.
        # Leaving them inside `kwargs` would raise
        # "TypeError: got multiple values for keyword argument" at the call below.
        subfolder = kwargs.pop("subfolder", None)
        user_agent = kwargs.pop("user_agent", None)
        if os.path.exists(pretrained_model_name_or_path):
            dirname = (
                pretrained_model_name_or_path
                if subfolder is None
                else os.path.join(pretrained_model_name_or_path, subfolder)
            )
            unquantized_part_path = os.path.join(dirname, "unquantized_layers.safetensors")
            transformer_block_path = os.path.join(dirname, "transformer_blocks.safetensors")
        else:
            download_kwargs = {
                "subfolder": subfolder,
                "repo_type": "model",
                "revision": kwargs.get("revision", None),
                "cache_dir": kwargs.get("cache_dir", None),
                "local_dir": kwargs.get("local_dir", None),
                "user_agent": user_agent,
                "force_download": kwargs.get("force_download", False),
                "proxies": kwargs.get("proxies", None),
                "etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
                "token": kwargs.get("token", None),
                "local_files_only": kwargs.get("local_files_only", None),
                "headers": kwargs.get("headers", None),
                "endpoint": kwargs.get("endpoint", None),
                "resume_download": kwargs.get("resume_download", None),
                "force_filename": kwargs.get("force_filename", None),
                "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
            }
            unquantized_part_path = hf_hub_download(
                repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs
            )
            transformer_block_path = hf_hub_download(
                repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs
            )

        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        config, _, _ = cls.load_config(
            pretrained_model_name_or_path,
            subfolder=subfolder,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
            **kwargs,
        )

        # Build on the meta device so no real memory is allocated; parameters are
        # materialized later when the quantized weights are loaded.
        with torch.device("meta"):
            transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
        return transformer, unquantized_part_path, transformer_block_path
|
nunchaku/models/utils.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions and classes for efficient transformer model management in Nunchaku.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from ..utils import copy_params_into
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def fuse_linears(linears: list[nn.Linear]) -> nn.Linear:
    """
    Fuse a list of nn.Linear layers into a single nn.Linear with concatenated output features.

    Parameters
    ----------
    linears : list of nn.Linear
        List of linear layers to fuse. All must have the same input feature dimension.

    Returns
    -------
    fused : nn.Linear
        A new linear layer with concatenated output features and the same input features.
        If only one layer is given, that layer itself is returned unchanged.

    Raises
    ------
    AssertionError
        If the input feature dimensions do not match.

    Notes
    -----
    The fused layer is freshly constructed; weights and biases are NOT copied
    from the input layers.
    """
    assert len(linears) > 0
    if len(linears) == 1:
        return linears[0]

    first = linears[0]
    assert all(layer.in_features == first.in_features for layer in linears)
    # Outputs are stacked, so the fused width is the sum of all output widths.
    total_out_features = sum(layer.out_features for layer in linears)
    # Only keep a bias when every constituent layer has one.
    has_bias = all(layer.bias is not None for layer in linears)
    return nn.Linear(
        first.in_features,
        total_out_features,
        bias=has_bias,
        dtype=first.weight.dtype,
        device=first.weight.device,
    )
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class CPUOffloadManager:
    """
    Manager for per-transformer-block CPU offloading with asynchronous memory operations using a Ping-Pong buffer strategy.

    This class enables memory-efficient inference or training by keeping only a subset
    of transformer blocks on GPU, offloading the rest to CPU, and preloading blocks as needed.

    Parameters
    ----------
    blocks : list of nn.Module
        List of transformer blocks to manage.
    device : str or torch.device, optional
        Target CUDA device for GPU operations. Default is "cuda".
    use_pin_memory : bool, optional
        Whether to use pinned memory for faster CPU-to-GPU transfers. Default is True.
    on_gpu_modules : list of nn.Module, optional
        Additional modules to keep on GPU at all times. Default is [].
    num_blocks_on_gpu : int, optional
        Number of blocks to keep on GPU simultaneously. Must be > 0. Default is 1.
    empty_cache_freq : int, optional
        Frequency (in forward passes) to call torch.cuda.empty_cache(). Default is 0 (never).

    Attributes
    ----------
    blocks : list of nn.Module
        The managed transformer blocks.
    buffer_blocks : list of nn.Module
        Buffers for preloading blocks onto GPU.
    device : torch.device
        The current CUDA device.
    current_block_idx : int
        Index of the current block on GPU.
    forward_counter : int
        Number of forward passes completed.
    memory_stream : torch.cuda.Stream
        CUDA stream for memory operations.
    compute_done : torch.cuda.Event
        CUDA event signaling compute completion.
    memory_done : torch.cuda.Event
        CUDA event signaling memory completion.
    """

    def __init__(
        self,
        blocks: list[nn.Module],
        device: str | torch.device = torch.device("cuda"),
        use_pin_memory: bool = True,
        # NOTE(review): mutable default list — safe here because it is only read,
        # never mutated, but a None sentinel would be the conventional form.
        on_gpu_modules: list[nn.Module] = [],
        num_blocks_on_gpu: int = 1,
        empty_cache_freq: int = 0,
    ):
        self.blocks = blocks
        self.use_pin_memory = use_pin_memory
        self.on_gpu_modules = on_gpu_modules
        self.num_blocks_on_gpu = num_blocks_on_gpu
        assert self.num_blocks_on_gpu > 0

        # Two streams: one for compute, one for memory operations, will be initialized in set_device
        self.memory_stream = None

        # Events used to order work between the compute and memory streams.
        self.compute_done = torch.cuda.Event(blocking=False)
        self.memory_done = torch.cuda.Event(blocking=False)

        # Ping-pong pair: two deep copies of block 0 act as reusable GPU-resident
        # parameter containers; offloaded blocks are copied into them alternately.
        # Assumes all blocks share the same architecture as blocks[0].
        self.buffer_blocks = [copy.deepcopy(blocks[0]), copy.deepcopy(blocks[0])]

        self.device = None
        self.set_device(device)

        self.current_block_idx = 0
        self.forward_counter = 0
        self.empty_cache_freq = empty_cache_freq

    def set_device(self, device: torch.device | str, force: bool = False):
        """
        Set the CUDA device for offloading and memory operations.
        It will move buffer blocks and on-GPU modules to the specified device and offload other blocks to CPU, optionally using pinned memory.

        Parameters
        ----------
        device : torch.device or str
            Target CUDA device.
        force : bool, optional
            If True, force re-initialization even if device is unchanged. Default is False.

        Raises
        ------
        AssertionError
            If the device is not a CUDA device.
        """
        if isinstance(device, str):
            device = torch.device(device)
        assert device.type == "cuda"
        if self.device == device and not force:
            return
        self.device = device
        self.memory_stream = torch.cuda.Stream(device=device)
        for block in self.buffer_blocks:
            block.to(device)
        for module in self.on_gpu_modules:
            module.to(device)
        for i, block in enumerate(self.blocks):
            if i < self.num_blocks_on_gpu:
                # The first `num_blocks_on_gpu` blocks stay resident on the GPU.
                block.to(device)
            else:
                block.to("cpu")
                if self.use_pin_memory:
                    # Pin CPU storage so host-to-device copies can be asynchronous.
                    for p in block.parameters(recurse=True):
                        p.data = p.data.pin_memory()
                    for b in block.buffers(recurse=True):
                        b.data = b.data.pin_memory()

    def load_block(self, block_idx: int, non_blocking: bool = True):
        """
        Move a transformer block from CPU to GPU buffer.

        Parameters
        ----------
        block_idx : int
            Index of the block to load.
        non_blocking : bool, optional
            Whether to use non-blocking memory copy. Default is True.

        Notes
        -----
        - No action is taken if the block is already on GPU or index is out of range.
        """
        # if the block is already on GPU, don't load it to the buffer
        if block_idx < self.num_blocks_on_gpu:
            return
        # out-of-range index (e.g. one past the last block) is a no-op
        if block_idx >= len(self.blocks):
            return

        block = self.blocks[block_idx]
        # Alternate between the two ping-pong buffers by index parity.
        copy_params_into(block, self.buffer_blocks[block_idx % 2], non_blocking=non_blocking)

    def step(self, compute_stream: torch.cuda.Stream | None = None):
        """
        Advance to the next transformer block, triggering asynchronous preloading.

        It will preload the next block onto GPU in the background and synchronize between compute and memory streams.
        After all the blocks are processed, it will call torch.cuda.empty_cache() periodically if ``empty_cache_freq`` > 0.

        Parameters
        ----------
        compute_stream : torch.cuda.Stream, optional
            CUDA stream for compute operations. If None, uses current stream.
        """
        if compute_stream is None:
            compute_stream = torch.cuda.current_stream()
        next_compute_done = torch.cuda.Event()
        next_compute_done.record(compute_stream)
        with torch.cuda.stream(self.memory_stream):
            # Wait until the previous compute finished before overwriting a buffer.
            self.memory_stream.wait_event(self.compute_done)
            self.load_block(self.current_block_idx + 1)  # no-op when past the last block
            next_memory_done = torch.cuda.Event()
            next_memory_done.record(self.memory_stream)
        self.memory_done = next_memory_done
        self.compute_done = next_compute_done
        self.current_block_idx += 1
        if self.current_block_idx < len(self.blocks):
            # get ready for the next compute
            compute_stream.wait_event(self.memory_done)
        else:
            # ready to finish
            compute_stream.wait_event(self.compute_done)
            # Reset for the next forward pass.
            self.current_block_idx = 0
            self.forward_counter += 1
            if self.empty_cache_freq > 0 and self.forward_counter % self.empty_cache_freq == 0:
                torch.cuda.empty_cache()

    def get_block(self, block_idx: int | None = None) -> nn.Module:
        """
        Retrieve the current or specified transformer block for computation.
        It will return a buffer block if the requested block is offloaded.

        Parameters
        ----------
        block_idx : int, optional
            Index of the block to retrieve. If None, returns the current block.

        Returns
        -------
        block : nn.Module
            The requested transformer block (on GPU if needed).
        """
        if block_idx is None:
            block_idx = self.current_block_idx
        if block_idx < self.num_blocks_on_gpu:
            # Block is permanently GPU-resident.
            return self.blocks[block_idx]
        else:
            # Offloaded block: return the ping-pong buffer it was copied into.
            return self.buffer_blocks[block_idx % 2]

    def initialize(self, stream: torch.cuda.Stream | None = None):
        """
        Initialize CUDA events for compute and memory streams.
        It will record the initial events for the compute and memory streams.

        Parameters
        ----------
        stream : torch.cuda.Stream, optional
            CUDA stream to record initial events. If None, uses current stream.

        Notes
        -----
        - Should be called before the first forward pass.
        """
        if stream is None:
            stream = torch.cuda.current_stream()
        self.compute_done.record(stream)
        self.memory_done.record(stream)
|
nunchaku/ops/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Quantized operations for FLUX-Kontext
|
nunchaku/ops/fused.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
High-performance fused operators for quantized neural network inference.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.nn import RMSNorm
|
| 7 |
+
|
| 8 |
+
from nunchaku.models.linear import SVDQW4A4Linear
|
| 9 |
+
|
| 10 |
+
from ..utils import ceil_divide
|
| 11 |
+
from .gemm import svdq_gemm_w4a4_cuda
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def fused_gelu_mlp(x: torch.Tensor, fc1: SVDQW4A4Linear, fc2: SVDQW4A4Linear, pad_size: int = 256) -> torch.Tensor:
    """
    Fused quantized MLP with GELU activation.

    Combines the first quantized linear layer, GELU activation, and the second quantized linear layer into a single CUDA kernel. Supports INT4 and NVFP4 quantization.

    Parameters
    ----------
    x : torch.Tensor, shape (B, S, C_in), dtype float16 or bfloat16
        Input tensor.
    fc1 : SVDQW4A4Linear
        First quantized linear layer (input → hidden).
    fc2 : SVDQW4A4Linear
        Second quantized linear layer (hidden → output).
    pad_size : int, optional
        Batch padding size for CUDA kernel efficiency. Default is 256.

    Returns
    -------
    torch.Tensor, shape (B, S, C_out), dtype as input
        Output tensor.

    Notes
    -----
    - Notations:

      - B: batch size
      - S: sequence length
      - C_in: input features
      - C_out: output features
    - For INT4 quantization, GELU activations are shifted by 0.171875 to ensure non-negativity, enabling unsigned quantization for improved quality. See: https://github.com/nunchaku-tech/nunchaku/blob/433f0b228a61a53fb700ac676fd2e290368ac94d/src/kernels/zgemm/gemm_w4a4_launch_impl.cuh#L286
    """
    # Flatten (B, S, C) -> (B*S, C): the kernels operate on 2-D token batches.
    batch_size, seq_len, channels = x.shape
    x = x.view(batch_size * seq_len, channels)
    quantized_x, ascales, lora_act = fc1.quantize(x)

    # Round the token count up to a multiple of pad_size for kernel efficiency.
    batch_size_pad = ceil_divide(batch_size * seq_len, pad_size) * pad_size

    # Intermediate activation, packed two 4-bit values per byte (hence // 2).
    qout_act = torch.empty(batch_size_pad, fc1.out_features // 2, dtype=torch.uint8, device=x.device)
    # Per-group activation scales for fc2's input: group size 16 (NVFP4, fp8 scales)
    # or 64 (INT4, half-precision scales).
    if fc2.precision == "nvfp4":
        qout_ascales = torch.empty(fc1.out_features // 16, batch_size_pad, dtype=torch.float8_e4m3fn, device=x.device)
    else:
        qout_ascales = torch.empty(fc1.out_features // 64, batch_size_pad, dtype=x.dtype, device=x.device)
    # fc2's LoRA down-projection activations, produced in the same fused kernel.
    qout_lora_act = torch.empty(batch_size_pad, fc2.proj_down.shape[1], dtype=torch.float32, device=x.device)

    # First GEMM: fc1 + GELU, emitting quantized input (and LoRA activations) for fc2.
    # All outputs (qout_act, qout_ascales, qout_lora_act) are written in place.
    svdq_gemm_w4a4_cuda(
        act=quantized_x,
        wgt=fc1.qweight,
        qout=qout_act,
        ascales=ascales,
        wscales=fc1.wscales,
        oscales=qout_ascales,
        lora_act_in=lora_act,
        lora_up=fc1.proj_up,
        lora_down=fc2.proj_down,
        lora_act_out=qout_lora_act,
        bias=fc1.bias,
        smooth_factor=fc2.smooth_factor,
        fp4=fc1.precision == "nvfp4",
        alpha=fc1.wtscale,
        wcscales=fc1.wcscales,
    )
    # Second GEMM: fc2 consumes the pre-quantized activations directly.
    output = torch.empty(batch_size * seq_len, fc2.out_features, dtype=x.dtype, device=x.device)
    output = fc2.forward_quant(qout_act, qout_ascales, qout_lora_act, output=output)
    output = output.view(batch_size, seq_len, -1)
    return output
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def fused_qkv_norm_rottary(
    x: torch.Tensor,
    proj: SVDQW4A4Linear,
    norm_q: RMSNorm,
    norm_k: RMSNorm,
    rotary_emb: torch.Tensor,
    output: torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    attn_tokens: int = 0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Fused quantized QKV projection with RMSNorm and rotary embeddings.

    Performs quantized QKV projection, applies RMS normalization to Q and K, and fuses rotary embeddings in a single CUDA kernel call.

    Parameters
    ----------
    x : torch.Tensor, shape (B, S, C_in), dtype float16 or bfloat16
        Input tensor.
    proj : SVDQW4A4Linear
        Quantized QKV projection layer.
    norm_q : RMSNorm
        RMSNorm for query.
    norm_k : RMSNorm
        RMSNorm for key.
    rotary_emb : torch.Tensor
        Packed rotary embedding tensor (see :func:`~nunchaku.models.embeddings.pack_rotemb`).
    output : torch.Tensor or tuple of torch.Tensor, optional
        Output tensor(s). If None, a new tensor is allocated.
        If tuple, should be (output_q, output_k, output_v) for fused attention packing.
    attn_tokens : int, optional
        Number of attention tokens. Default is 0.

    Returns
    -------
    torch.Tensor or tuple of torch.Tensor
        Output tensor of shape (B, S, C_out), or tuple (output_q, output_k, output_v).

    Notes
    -----
    Notations:
    - B: batch size
    - S: sequence length
    - C_in: input features
    - C_out: output features
    """
    assert isinstance(norm_q, RMSNorm)
    assert isinstance(norm_k, RMSNorm)

    # Flatten (B, S, C) -> (B*S, C): the kernel works on 2-D token batches.
    batch_size, seq_len, channels = x.shape
    x = x.view(batch_size * seq_len, channels)
    quantized_x, ascales, lora_act = proj.quantize(x)

    if output is None:
        output = torch.empty(quantized_x.shape[0], proj.out_features, dtype=x.dtype, device=x.device)

    # Arguments shared by both kernel-invocation variants below.
    shared_kwargs = dict(
        act=quantized_x,
        wgt=proj.qweight,
        ascales=ascales,
        wscales=proj.wscales,
        lora_act_in=lora_act,
        lora_up=proj.proj_up,
        bias=proj.bias,
        fp4=proj.precision == "nvfp4",
        alpha=proj.wtscale,
        wcscales=proj.wcscales,
        norm_q=norm_q.weight,
        norm_k=norm_k.weight,
        rotary_emb=rotary_emb,
    )

    if isinstance(output, tuple):
        # Caller supplied separate packed Q/K/V destinations; the kernel writes
        # each one in place for fused attention packing.
        assert len(output) == 3
        output_q, output_k, output_v = output
        svdq_gemm_w4a4_cuda(
            out_q=output_q,
            out_k=output_k,
            out_v=output_v,
            attn_tokens=attn_tokens,
            **shared_kwargs,
        )
        return output_q, output_k, output_v

    # Single fused output tensor, written in place then reshaped back to (B, S, C_out).
    svdq_gemm_w4a4_cuda(out=output, **shared_kwargs)
    return output.view(batch_size, seq_len, -1)
|
nunchaku/ops/gemm.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Python wrappers for Nunchaku's high-performance quantized GEMM (General Matrix-Matrix Multiplication) CUDA kernels.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .._C import ops
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def svdq_gemm_w4a4_cuda(
    act: torch.Tensor,
    wgt: torch.Tensor,
    out: torch.Tensor | None = None,
    qout: torch.Tensor | None = None,
    ascales: torch.Tensor | None = None,
    wscales: torch.Tensor | None = None,
    oscales: torch.Tensor | None = None,
    poolout: torch.Tensor | None = None,
    lora_act_in: torch.Tensor | None = None,
    lora_up: torch.Tensor | None = None,
    lora_down: torch.Tensor | None = None,
    lora_act_out: torch.Tensor | None = None,
    norm_q: torch.Tensor | None = None,
    norm_k: torch.Tensor | None = None,
    rotary_emb: torch.Tensor | None = None,
    bias: torch.Tensor | None = None,
    smooth_factor: torch.Tensor | None = None,
    out_vk: torch.Tensor | None = None,
    out_linearattn: torch.Tensor | None = None,
    act_unsigned: bool = False,
    lora_scales: list[float] | None = None,
    fuse_silu: bool = False,
    fp4: bool = False,
    alpha: float | None = 1.0,
    wcscales: torch.Tensor | None = None,
    out_q: torch.Tensor | None = None,
    out_k: torch.Tensor | None = None,
    out_v: torch.Tensor | None = None,
    attn_tokens: int = 0,
):
    """
    Quantized GEMM using SVDQuant W4A4 CUDA kernel, with support for LoRA, rotary embeddings, normalization, and fused activations.

    Parameters
    ----------
    act : torch.Tensor, shape (M, K // 2), dtype int8
        Packed input activations.
    wgt : torch.Tensor, shape (N, K // 2), dtype int8
        Packed quantized weights.
    out : torch.Tensor or None, shape (M, N), dtype float16 or bfloat16, optional
        Output tensor for the linear layer.
    qout : torch.Tensor or None, shape (M, N // 2), dtype int8, optional
        Packed quantized input for the next layer.
    ascales : torch.Tensor or None, shape (K // G, M), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
        Activation scales.
    wscales : torch.Tensor or None, shape (K // G, N), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
        Weight scales.
    oscales : torch.Tensor or None, shape (N // G, M), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
        Output scales.
    poolout : torch.Tensor or None, optional
        Reserved for future use.
    lora_act_in : torch.Tensor or None, shape (M, R), dtype float32, optional
        LoRA down-projection activations.
    lora_up : torch.Tensor or None, shape (N, R), dtype float16 or bfloat16, optional
        Packed LoRA up-projection weights.
    lora_down : torch.Tensor or None, shape (N, R), dtype float16 or bfloat16, optional
        Packed LoRA down-projection weights for the next layer.
    lora_act_out : torch.Tensor or None, shape (M, R), dtype float32, optional
        Output for LoRA down-projection in the next layer.
    norm_q : torch.Tensor or None, shape (HEAD_DIM,), dtype float16 or bfloat16, optional
        Query RMS normalization.
    norm_k : torch.Tensor or None, shape (HEAD_DIM,), dtype float16 or bfloat16, optional
        Key RMS normalization.
    rotary_emb : torch.Tensor or None, shape (M, HEAD_DIM // 2, 2, 2), dtype float32, optional
        Packed rotary embeddings.
    bias : torch.Tensor or None, shape (N,), dtype float16 or bfloat16, optional
        Bias tensor.
    smooth_factor : torch.Tensor or None, shape (N,), dtype float16 or bfloat16, optional
        Smoothing factor for quantization in the next layer.
    out_vk : torch.Tensor or None, optional
        Used only in SANA. Leave as None.
    out_linearattn : torch.Tensor or None, optional
        Used only in SANA. Leave as None.
    act_unsigned : bool, default=False
        If True, activations are unsigned (e.g., after GeLU, shifted by 0.171875). This is only used for INT4 to enable unsigned INT4 activation quantization for better quantization quality.
    lora_scales : list of float or None, optional
        Per-group LoRA scaling factors (16 channels per group). Defaults to 1.0 per group.
    fuse_silu : bool, default=False
        If True, fuse SiLU activation.
    fp4 : bool, default=False
        If True, use 4-bit floating point quantization (NVFP4).
    alpha : float or None, default=1.0
        Per-tensor scaling factor for NVFP4.
    wcscales : torch.Tensor or None, shape (N,), dtype float8_e4m3fn, optional
        Per-channel scaling for NVFP4.
    out_q : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
        Packed quantized Q for attention (used in ``nunchaku-fp16`` attention).
    out_k : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
        Packed quantized K for attention (used in ``nunchaku-fp16`` attention).
    out_v : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
        Packed quantized V for attention (used in ``nunchaku-fp16`` attention).
    attn_tokens : int, default=0
        Number of attention tokens.

    Returns
    -------
    None
        Results are written in-place to the provided output tensors.

    Notes
    -----
    Notations:

    - M: batch size (input tokens)
    - K: input channels (feature dimension)
    - N: output channels
    - G: group size (64 for INT4, 16 for NVFP4)
    - R: LoRA rank
    - B: batch size for attention
    - H: number of heads
    - D: head dimension
    """
    if lora_scales is None:
        # NOTE(review): this path assumes `lora_up` was provided whenever
        # `lora_scales` is None; calling with both left as None raises
        # AttributeError here — confirm whether that combination is ever valid.
        rank = lora_up.shape[1]
        # One default scale of 1.0 per group of 16 LoRA channels.
        lora_scales = [1.0] * math.ceil(rank / 16)
    if alpha is None:
        alpha = 1.0
    # Positional call into the C++ extension: argument order must match the
    # binding's signature exactly — do not reorder.
    ops.gemm_w4a4(
        act,
        wgt,
        out,
        qout,
        ascales,
        wscales,
        oscales,
        poolout,
        lora_act_in,
        lora_up,
        lora_down,
        lora_act_out,
        norm_q,
        norm_k,
        rotary_emb,
        bias,
        smooth_factor,
        out_vk,
        out_linearattn,
        act_unsigned,
        lora_scales,
        fuse_silu,
        fp4,
        alpha,
        wcscales,
        out_q,
        out_k,
        out_v,
        attn_tokens,
    )
|
nunchaku/ops/gemv.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Python wrapper for Nunchaku's high-performance GEMV (General Matrix-Vector Multiplication) CUDA kernels.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .._C import ops
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def awq_gemv_w4a16_cuda(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    scaling_factors: torch.Tensor,
    zeros: torch.Tensor,
    m: int,
    n: int,
    k: int,
    group_size: int = 64,
) -> torch.Tensor:
    """
    Run the AWQ W4A16 quantized GEMV CUDA kernel.

    Thin wrapper that forwards its arguments, in the order the compiled
    extension expects, to ``ops.gemv_awq``.

    Parameters
    ----------
    in_feats : torch.Tensor, shape (k,) or (m, k), dtype float16 or bfloat16
        Input feature vector or batch of vectors.
    kernel : torch.Tensor, shape (n // 4, k // 2), dtype int32
        Packed 4-bit quantized weight matrix (AWQ layout).
    scaling_factors : torch.Tensor, shape (k // group_size, n), dtype float16 or bfloat16
        Per-group dequantization scales.
    zeros : torch.Tensor, shape (k // group_size, n), dtype float16 or bfloat16
        Per-group zero points.
    m : int
        Batch size (number of input vectors).
    n : int
        Output feature dimension.
    k : int
        Input feature dimension.
    group_size : int, default=64
        Number of input channels sharing one scale/zero pair.

    Returns
    -------
    torch.Tensor, shape (m, n), dtype float16 or bfloat16
        Output tensor.

    Notes
    -----
    Notations: m = batch size, n = output features, k = input features,
    group_size = quantization group size.
    """
    # All computation happens inside the compiled CUDA extension; argument
    # order here must match the C++ binding signature exactly.
    return ops.gemv_awq(
        in_feats,
        kernel,
        scaling_factors,
        zeros,
        m,
        n,
        k,
        group_size,
    )
|
nunchaku/ops/quantize.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module provides Python wrappers for Nunchaku's high-performance SVDQuant quantization CUDA kernels.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .._C import ops
|
| 8 |
+
from ..utils import ceil_divide
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def svdq_quantize_w4a4_act_fuse_lora_cuda(
|
| 12 |
+
input: torch.Tensor,
|
| 13 |
+
output: torch.Tensor | None = None,
|
| 14 |
+
oscales: torch.Tensor | None = None,
|
| 15 |
+
lora_down: torch.Tensor | None = None,
|
| 16 |
+
lora_act_out: torch.Tensor | None = None,
|
| 17 |
+
smooth: torch.Tensor | None = None,
|
| 18 |
+
fuse_glu: bool = False,
|
| 19 |
+
fp4: bool = False,
|
| 20 |
+
pad_size: int = 256,
|
| 21 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 22 |
+
"""
|
| 23 |
+
Quantizes activations and computes LoRA down-projection using SVDQuant W4A4 CUDA kernel.
|
| 24 |
+
|
| 25 |
+
Parameters
|
| 26 |
+
----------
|
| 27 |
+
input : torch.Tensor, shape (M, K), dtype bfloat16/float16
|
| 28 |
+
Input activations.
|
| 29 |
+
output : torch.Tensor or None, shape (M_pad, K // 2), dtype uint8, optional
|
| 30 |
+
Packed output tensor for quantized activations. Allocated if None.
|
| 31 |
+
oscales : torch.Tensor or None, shape (K // G, M_pad), dtype float8_e4m3fn for NVFP4 or input dtype for INT4, optional
|
| 32 |
+
Output scales tensor. Allocated if None.
|
| 33 |
+
lora_down : torch.Tensor or None, shape (K, R), dtype bfloat16/float16, optional
|
| 34 |
+
Packed LoRA down-projection weights.
|
| 35 |
+
lora_act_out : torch.Tensor or None, shape (M_pad, R), dtype float32, optional
|
| 36 |
+
Packed output tensor for LoRA activations. Allocated if None.
|
| 37 |
+
smooth : torch.Tensor or None, optional, dtype bfloat16/float16
|
| 38 |
+
Smoothing factor for quantization.
|
| 39 |
+
fuse_glu : bool, default=False
|
| 40 |
+
If True, fuse GLU activation.
|
| 41 |
+
fp4 : bool, default=False
|
| 42 |
+
If True, use NVFP4 quantization; else INT4.
|
| 43 |
+
pad_size : int, default=256
|
| 44 |
+
Pad batch size to a multiple of this value for efficient CUDA execution.
|
| 45 |
+
|
| 46 |
+
Returns
|
| 47 |
+
-------
|
| 48 |
+
output : torch.Tensor, shape (M_pad, K // 2), dtype uint8
|
| 49 |
+
Packed quantized activations.
|
| 50 |
+
oscales : torch.Tensor, shape (K // G, M_pad), dtype float8_e4m3fn for NVFP4 or input dtype for INT4
|
| 51 |
+
Output scales.
|
| 52 |
+
lora_act_out : torch.Tensor, shape (M_pad, R), dtype float32
|
| 53 |
+
Packed LoRA activation output.
|
| 54 |
+
|
| 55 |
+
Notes
|
| 56 |
+
-----
|
| 57 |
+
Notations:
|
| 58 |
+
|
| 59 |
+
- M: batch size
|
| 60 |
+
- K: input channels
|
| 61 |
+
- R: LoRA rank
|
| 62 |
+
- G: group size (64 for INT4, 16 for NVFP4)
|
| 63 |
+
- M_pad: padded batch size = ceil(M / pad_size) * pad_size
|
| 64 |
+
"""
|
| 65 |
+
batch_size, channels = input.shape
|
| 66 |
+
rank = lora_down.shape[1]
|
| 67 |
+
batch_size_pad = ceil_divide(batch_size, pad_size) * pad_size
|
| 68 |
+
if output is None:
|
| 69 |
+
output = torch.empty(batch_size_pad, channels // 2, dtype=torch.uint8, device=input.device)
|
| 70 |
+
if oscales is None:
|
| 71 |
+
if fp4:
|
| 72 |
+
assert channels % 16 == 0
|
| 73 |
+
oscales = torch.empty(channels // 16, batch_size_pad, dtype=torch.float8_e4m3fn, device=input.device)
|
| 74 |
+
else:
|
| 75 |
+
assert channels % 64 == 0
|
| 76 |
+
oscales = torch.empty(channels // 64, batch_size_pad, dtype=input.dtype, device=input.device)
|
| 77 |
+
if lora_act_out is None:
|
| 78 |
+
lora_act_out = torch.empty(batch_size_pad, rank, dtype=torch.float32, device=input.device)
|
| 79 |
+
|
| 80 |
+
ops.quantize_w4a4_act_fuse_lora(input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4)
|
| 81 |
+
return output, oscales, lora_act_out
|