Commit 04eaca9 by D3vShoaib · 0 parent(s)

Add Git LFS support and remove binary files

This view is limited to 50 files because it contains too many changes; see the raw diff for the full changeset.

Files changed (50)
  1. README.md +102 -0
  2. app.py +150 -0
  3. app/kontext/README.md +12 -0
  4. app/kontext/assets/description.html +21 -0
  5. app/kontext/assets/style.css +40 -0
  6. app/kontext/run_gradio.py +150 -0
  7. app/kontext/utils.py +14 -0
  8. examples/__init__.py +1 -0
  9. examples/flux.1-kontext-FALAI_lora.py +30 -0
  10. examples/flux.1-kontext-dev-teacache.py +30 -0
  11. examples/flux.1-kontext-dev.py +22 -0
  12. nunchaku/__init__.py +9 -0
  13. nunchaku/__version__.py +1 -0
  14. nunchaku/csrc/flux.h +254 -0
  15. nunchaku/csrc/gemm.h +114 -0
  16. nunchaku/csrc/gemm88.h +37 -0
  17. nunchaku/csrc/module.h +85 -0
  18. nunchaku/csrc/ops.h +173 -0
  19. nunchaku/csrc/pybind.cpp +124 -0
  20. nunchaku/csrc/sana.h +102 -0
  21. nunchaku/csrc/utils.h +39 -0
  22. nunchaku/lora/__init__.py +1 -0
  23. nunchaku/lora/flux/__init__.py +5 -0
  24. nunchaku/lora/flux/compose.py +218 -0
  25. nunchaku/lora/flux/convert.py +74 -0
  26. nunchaku/lora/flux/diffusers_converter.py +220 -0
  27. nunchaku/lora/flux/nunchaku_converter.py +949 -0
  28. nunchaku/lora/flux/packer.py +517 -0
  29. nunchaku/lora/flux/utils.py +94 -0
  30. nunchaku/models/__init__.py +9 -0
  31. nunchaku/models/attention.py +123 -0
  32. nunchaku/models/embeddings.py +138 -0
  33. nunchaku/models/linear.py +414 -0
  34. nunchaku/models/normalization.py +166 -0
  35. nunchaku/models/text_encoders/__init__.py +5 -0
  36. nunchaku/models/text_encoders/linear.py +238 -0
  37. nunchaku/models/text_encoders/t5_encoder.py +116 -0
  38. nunchaku/models/text_encoders/tinychat_utils.py +188 -0
  39. nunchaku/models/transformers/__init__.py +5 -0
  40. nunchaku/models/transformers/transformer_flux.py +991 -0
  41. nunchaku/models/transformers/transformer_flux_v2.py +646 -0
  42. nunchaku/models/transformers/transformer_qwenimage.py +601 -0
  43. nunchaku/models/transformers/transformer_sana.py +374 -0
  44. nunchaku/models/transformers/utils.py +147 -0
  45. nunchaku/models/utils.py +262 -0
  46. nunchaku/ops/__init__.py +1 -0
  47. nunchaku/ops/fused.py +178 -0
  48. nunchaku/ops/gemm.py +160 -0
  49. nunchaku/ops/gemv.py +56 -0
  50. nunchaku/ops/quantize.py +81 -0
README.md ADDED
@@ -0,0 +1,102 @@
+ # FLUX-Kontext Optimized Implementation
+
+ This package contains an optimized implementation of FLUX-Kontext using quantization and acceleration techniques.
+
+ ## Features
+
+ - **Quantized FLUX Transformer**: Efficient INT4/FP4 quantized implementation of FLUX.1-Kontext
+ - **Quantized T5 Encoder**: AWQ INT4-quantized T5 text encoder for memory efficiency
+ - **LoRA Support**: Full support for loading LoRA weights at inference
+ - **Gradio Web Interface**: Ready-to-use web interface for image editing
+
+ ## Installation
+
+ ```bash
+ pip install -r requirements.txt
+ python setup.py build_ext --inplace
+ ```
+
+ ## Quick Start
+
+ ### Using the Gradio Interface
+
+ ```bash
+ cd app/kontext
+ python run_gradio.py --precision int4
+ ```
+
+ ### Programmatic Usage
+
+ ```python
+ import torch
+ from diffusers import FluxKontextPipeline
+
+ from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+ from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
+
+ # Load the quantized transformer
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+     "mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-int4_r32-flux.1-kontext-dev.safetensors"
+ )
+
+ # Load the quantized text encoder (optional)
+ text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+     "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+ )
+
+ # Create the pipeline
+ pipeline = FluxKontextPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-Kontext-dev",
+     transformer=transformer,
+     text_encoder_2=text_encoder_2,
+     torch_dtype=torch.bfloat16,
+ )
+ pipeline = pipeline.to("cuda")
+
+ # Generate an image
+ result = pipeline(
+     prompt="Your prompt here",
+     image=your_input_image,  # a PIL image
+     num_inference_steps=28,
+     guidance_scale=2.5,
+ ).images[0]
+ ```
+
+ ## Available Models
+
+ - `int4`: INT4 quantized transformer (default; most memory-efficient)
+ - `fp4`: FP4 quantized transformer
+ - `bf16`: Full-precision BFloat16 (highest quality, highest memory usage)
+
+ ## Directory Structure
+
+ ```
+ flux-kontext/
+ ├── nunchaku/          # Core quantized models and utilities
+ │   ├── models/        # Transformer and text encoder models
+ │   ├── lora/          # LoRA utilities
+ │   ├── ops/           # Quantized operations
+ │   └── csrc/          # C++/CUDA kernels
+ ├── app/               # Application interfaces
+ │   └── kontext/       # Gradio web interface
+ ├── examples/          # Example scripts
+ └── tests/             # Test scripts
+ ```
+
+ ## Examples
+
+ See the `examples/` directory for various usage patterns:
+
+ - `flux.1-kontext-dev.py`: Basic usage example
+ - `flux.1-kontext-dev-teacache.py`: Using TeaCache for acceleration
+ - `flux.1-kontext-FALAI_lora.py`: Applying a LoRA at inference
+
+ ## Requirements
+
+ - Python >= 3.10
+ - PyTorch >= 2.5
+ - CUDA-capable GPU (required for quantized inference)
+ - 8GB+ GPU memory (for INT4 quantization)
+
+ ## License
+
+ See the main nunchaku project for license information.
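Editor's note: the precision choice above only changes the checkpoint filename. A minimal sketch mirroring the f-string used in `app.py`/`run_gradio.py` (choosing `bf16` instead skips the quantized transformer and loads the stock pipeline):

```python
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel

precision = "fp4"  # or "int4"; the launcher validates against these choices
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{precision}_r32-flux.1-kontext-dev.safetensors"
)
```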
app.py ADDED
@@ -0,0 +1,150 @@
+ import random
+ import time
+
+ import torch
+ from diffusers import FluxKontextPipeline
+ from PIL import Image
+ from utils import get_args
+
+ from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+ from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
+
+ import gradio as gr
+
+ MAX_SEED = 1000000000
+
+ args = get_args()
+
+ if args.precision == "bf16":
+     pipeline = FluxKontextPipeline.from_pretrained(
+         "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+     )
+     pipeline = pipeline.to("cuda")
+     pipeline.precision = "bf16"
+ else:
+     assert args.precision in ["int4", "fp4"]
+     pipeline_init_kwargs = {}
+     transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+         f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
+     )
+     pipeline_init_kwargs["transformer"] = transformer
+     if args.use_qencoder:
+         text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+             "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+         )
+         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
+
+     pipeline = FluxKontextPipeline.from_pretrained(
+         "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
+     )
+     pipeline = pipeline.to("cuda")
+     pipeline.precision = args.precision
+
+
+ def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image.Image, str]:
+     img = image["composite"].convert("RGB")
+
+     start_time = time.time()
+     result_image = pipeline(
+         prompt=prompt,
+         image=img,
+         height=img.height,
+         width=img.width,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         generator=torch.Generator().manual_seed(seed),
+     ).images[0]
+
+     latency = time.time() - start_time
+     if latency < 1:
+         latency_str = f"{latency * 1000:.2f}ms"
+     else:
+         latency_str = f"{latency:.2f}s"
+     torch.cuda.empty_cache()
+     return result_image, latency_str
+
+
+ with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
+     with open("assets/description.html", "r") as f:
+         DESCRIPTION = f.read()
+     # Get the GPU properties.
+     if torch.cuda.device_count() > 0:
+         gpu_properties = torch.cuda.get_device_properties(0)
+         gpu_memory = gpu_properties.total_memory / (1024**3)  # convert to GiB
+         gpu_name = torch.cuda.get_device_name(0)
+         device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
+     else:
+         device_info = "Running on CPU 🥶 This demo does not work on CPU."
+
+     header_str = DESCRIPTION.format(precision=args.precision.upper(), device_info=device_info, count_info="")
+     header = gr.HTML(header_str)
+
+     with gr.Row(elem_id="main_row"):
+         with gr.Column(elem_id="column_input"):
+             gr.Markdown("## INPUT", elem_id="input_header")
+             with gr.Group():
+                 canvas = gr.ImageEditor(
+                     height=640,
+                     image_mode="RGB",
+                     sources=["upload", "clipboard"],
+                     type="pil",
+                     label="Input",
+                     show_label=False,
+                     show_download_button=True,
+                     interactive=True,
+                     transforms=[],
+                     canvas_size=(1024, 1024),
+                     scale=1,
+                     format="png",
+                     layers=False,
+                 )
+                 with gr.Row():
+                     prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
+                     run_button = gr.Button("Run", scale=1, elem_id="run_button")
+
+             with gr.Row():
+                 seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
+                 randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
+             with gr.Accordion("Advanced options", open=False):
+                 with gr.Group():
+                     num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
+                     guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)
+
+         with gr.Column(elem_id="column_output"):
+             gr.Markdown("## OUTPUT", elem_id="output_header")
+             with gr.Group():
+                 result = gr.Image(
+                     format="png",
+                     height=640,
+                     image_mode="RGB",
+                     type="pil",
+                     label="Result",
+                     show_label=False,
+                     show_download_button=True,
+                     interactive=False,
+                     elem_id="output_image",
+                 )
+                 latency_result = gr.Text(label="Inference Latency", show_label=True)
+
+             gr.Markdown("### Instructions")
+             gr.Markdown("**1**. Enter a text prompt")
+             gr.Markdown("**2**. Upload an image")
+             gr.Markdown("**3**. Try different seeds to generate different results")
+
+     run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
+     run_outputs = [result, latency_result]
+
+     randomize_seed.click(
+         lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
+     ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
+
+     gr.on(
+         triggers=[prompt.submit, run_button.click],
+         fn=run,
+         inputs=run_inputs,
+         outputs=run_outputs,
+         api_name=False,
+     )
+
+ if __name__ == "__main__":
+     demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
app/kontext/README.md ADDED
@@ -0,0 +1,12 @@
+ # Nunchaku INT4 FLUX.1 Kontext Demo
+
+ ![demo](https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/assets/kontext.png)
+
+ This interactive Gradio application allows you to edit an image with natural language. Simply run:
+
+ ```shell
+ python run_gradio.py
+ ```
+
+ - To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
+ - By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model. The flags can be combined, as shown below.
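For example, both options can be passed together (flag names as defined in `app/kontext/utils.py`):

```shell
# INT4 transformer plus the 4-bit (W4A16) T5 text encoder, for the lowest GPU memory use.
python run_gradio.py -p int4 --use-qencoder
```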
app/kontext/assets/description.html ADDED
@@ -0,0 +1,21 @@
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+   <div>
+     <!-- Logo Row -->
+     <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
+       <a href="https://github.com/mit-han-lab/nunchaku" target="_blank">
+         <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
+              alt="nunchaku logo" style="height: 150px; width: auto;" />
+       </a>
+       <a href="https://hanlab.mit.edu/projects/svdquant" target="_blank">
+         <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
+              alt="svdquant logo" style="height: 40px; width: auto;" />
+       </a>
+     </div>
+     <h1 style="margin-top: 0;">{precision} FLUX.1-Kontext-dev Demo</h1>
+
+     <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+       {device_info}
+     </div>
+     {count_info}
+   </div>
+ </div>
app/kontext/assets/style.css ADDED
@@ -0,0 +1,40 @@
+ @import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
+
+ .gradio-container {
+     max-width: 1200px !important;
+     margin: auto; /* centers the container horizontally */
+ }
+
+ h1 {
+     text-align: center;
+ }
+
+ .wrap.svelte-p4aq0j.svelte-p4aq0j {
+     display: none;
+ }
+
+ #column_input, #column_output {
+     width: 500px;
+     display: flex;
+     align-items: center;
+ }
+
+ #input_header, #output_header {
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     width: 400px;
+ }
+
+ #accessibility {
+     text-align: center; /* center-aligns the text */
+     margin: auto;       /* centers the element horizontally */
+ }
+
+ #random_seed {
+     height: 71px;
+ }
+
+ #run_button {
+     height: 87px;
+ }
app/kontext/run_gradio.py ADDED
@@ -0,0 +1,150 @@
+ import random
+ import time
+
+ import torch
+ from diffusers import FluxKontextPipeline
+ from PIL import Image
+ from utils import get_args
+
+ from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+ from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
+
+ import gradio as gr
+
+ MAX_SEED = 1000000000
+
+ args = get_args()
+
+ if args.precision == "bf16":
+     pipeline = FluxKontextPipeline.from_pretrained(
+         "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+     )
+     pipeline = pipeline.to("cuda")
+     pipeline.precision = "bf16"
+ else:
+     assert args.precision in ["int4", "fp4"]
+     pipeline_init_kwargs = {}
+     transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+         f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{args.precision}_r32-flux.1-kontext-dev.safetensors"
+     )
+     pipeline_init_kwargs["transformer"] = transformer
+     if args.use_qencoder:
+         text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+             "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+         )
+         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
+
+     pipeline = FluxKontextPipeline.from_pretrained(
+         "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
+     )
+     pipeline = pipeline.to("cuda")
+     pipeline.precision = args.precision
+
+
+ def run(image, prompt: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image.Image, str]:
+     img = image["composite"].convert("RGB")
+
+     start_time = time.time()
+     result_image = pipeline(
+         prompt=prompt,
+         image=img,
+         height=img.height,
+         width=img.width,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         generator=torch.Generator().manual_seed(seed),
+     ).images[0]
+
+     latency = time.time() - start_time
+     if latency < 1:
+         latency_str = f"{latency * 1000:.2f}ms"
+     else:
+         latency_str = f"{latency:.2f}s"
+     torch.cuda.empty_cache()
+     return result_image, latency_str
+
+
+ with gr.Blocks(css_paths="assets/style.css", title="Nunchaku FLUX.1-Kontext Demo") as demo:
+     with open("assets/description.html", "r") as f:
+         DESCRIPTION = f.read()
+     # Get the GPU properties.
+     if torch.cuda.device_count() > 0:
+         gpu_properties = torch.cuda.get_device_properties(0)
+         gpu_memory = gpu_properties.total_memory / (1024**3)  # convert to GiB
+         gpu_name = torch.cuda.get_device_name(0)
+         device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
+     else:
+         device_info = "Running on CPU 🥶 This demo does not work on CPU."
+
+     header_str = DESCRIPTION.format(precision=args.precision.upper(), device_info=device_info, count_info="")
+     header = gr.HTML(header_str)
+
+     with gr.Row(elem_id="main_row"):
+         with gr.Column(elem_id="column_input"):
+             gr.Markdown("## INPUT", elem_id="input_header")
+             with gr.Group():
+                 canvas = gr.ImageEditor(
+                     height=640,
+                     image_mode="RGB",
+                     sources=["upload", "clipboard"],
+                     type="pil",
+                     label="Input",
+                     show_label=False,
+                     show_download_button=True,
+                     interactive=True,
+                     transforms=[],
+                     canvas_size=(1024, 1024),
+                     scale=1,
+                     format="png",
+                     layers=False,
+                 )
+                 with gr.Row():
+                     prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
+                     run_button = gr.Button("Run", scale=1, elem_id="run_button")
+
+             with gr.Row():
+                 seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
+                 randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
+             with gr.Accordion("Advanced options", open=False):
+                 with gr.Group():
+                     num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
+                     guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.5)
+
+         with gr.Column(elem_id="column_output"):
+             gr.Markdown("## OUTPUT", elem_id="output_header")
+             with gr.Group():
+                 result = gr.Image(
+                     format="png",
+                     height=640,
+                     image_mode="RGB",
+                     type="pil",
+                     label="Result",
+                     show_label=False,
+                     show_download_button=True,
+                     interactive=False,
+                     elem_id="output_image",
+                 )
+                 latency_result = gr.Text(label="Inference Latency", show_label=True)
+
+             gr.Markdown("### Instructions")
+             gr.Markdown("**1**. Enter a text prompt")
+             gr.Markdown("**2**. Upload an image")
+             gr.Markdown("**3**. Try different seeds to generate different results")
+
+     run_inputs = [canvas, prompt, num_inference_steps, guidance_scale, seed]
+     run_outputs = [result, latency_result]
+
+     randomize_seed.click(
+         lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
+     ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
+
+     gr.on(
+         triggers=[prompt.submit, run_button.click],
+         fn=run,
+         inputs=run_inputs,
+         outputs=run_outputs,
+         api_name=False,
+     )
+
+ if __name__ == "__main__":
+     demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
app/kontext/utils.py ADDED
@@ -0,0 +1,14 @@
+ import argparse
+
+
+ def get_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precision to use"
+     )
+     parser.add_argument("--use-qencoder", action="store_true", help="Whether to use the 4-bit text encoder")
+     parser.add_argument("--no-safety-checker", action="store_true", help="Disable the safety checker")
+     parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
+     parser.add_argument("--gradio-root-path", type=str, default="")
+     args = parser.parse_args()
+     return args
examples/__init__.py ADDED
@@ -0,0 +1 @@
+ # FLUX-Kontext examples
examples/flux.1-kontext-FALAI_lora.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ from diffusers import FluxKontextPipeline
+ from diffusers.utils import load_image
+
+ from nunchaku import NunchakuFluxTransformer2dModel
+ from nunchaku.utils import get_precision
+
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+     f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
+ )
+
+ pipeline = FluxKontextPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
+ ).to("cuda")
+
+ image = load_image(
+     "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
+ ).convert("RGB")
+
+ ### LoRA-related code ###
+ # Path to your LoRA safetensors; can also be a remote Hugging Face path.
+ transformer.update_lora_params(
+     "nunchaku-tech/nunchaku-test-models/relight-kontext-lora-single-caption_comfy.safetensors"
+     # "linoyts/relight-kontext-lora-single-caption/relight-kontext-lora-single-caption.safetensors"
+ )
+ transformer.set_lora_strength(1)  # your LoRA strength here
+ ### End of LoRA-related code ###
+
+ prompt = "neon light, city"
+ image = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(23), guidance_scale=2.5).images[0]
+ image.save("flux-kontext-dev.png")
examples/flux.1-kontext-dev-teacache.py ADDED
@@ -0,0 +1,30 @@
+ import time
+
+ import torch
+ from diffusers import FluxKontextPipeline
+ from diffusers.utils import load_image
+
+ from nunchaku import NunchakuFluxTransformer2dModel
+ from nunchaku.caching.teacache import TeaCache
+ from nunchaku.utils import get_precision
+
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+     f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
+ )
+
+ pipeline = FluxKontextPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
+ ).to("cuda")
+
+ image = load_image(
+     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+ ).convert("RGB")
+
+ prompt = "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors"
+
+ start_time = time.time()
+ with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=True, model_name="flux-kontext"):
+     image = pipeline(image=image, prompt=prompt, guidance_scale=2.5).images[0]
+ end_time = time.time()
+ print(f"Time taken: {end_time - start_time} seconds")
+ image.save(f"flux-kontext-dev-{get_precision()}-tc.png")
examples/flux.1-kontext-dev.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+ from diffusers import FluxKontextPipeline
+ from diffusers.utils import load_image
+
+ from nunchaku import NunchakuFluxTransformer2dModel
+ from nunchaku.utils import get_precision
+
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+     f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
+ )
+
+ pipeline = FluxKontextPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
+ ).to("cuda")
+
+ image = load_image(
+     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+ ).convert("RGB")
+
+ prompt = "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors"
+ image = pipeline(image=image, prompt=prompt, guidance_scale=2.5).images[0]
+ image.save("flux-kontext-dev.png")
nunchaku/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .models import (
+     NunchakuFluxTransformer2dModel,
+     NunchakuT5EncoderModel,
+ )
+
+ __all__ = [
+     "NunchakuFluxTransformer2dModel",
+     "NunchakuT5EncoderModel",
+ ]
nunchaku/__version__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "1.0.0-flux-kontext"
nunchaku/csrc/flux.h ADDED
@@ -0,0 +1,254 @@
+ #pragma once
+
+ #include "interop/torch.h"
+ #include "FluxModel.h"
+ #include "Serialization.h"
+ #include "debug.h"
+ #include "Linear.h"
+ #include "module.h"
+
+ class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
+ public:
+     void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
+         spdlog::info("Initializing QuantizedFluxModel on device {}", deviceId);
+         if (!bf16) {
+             spdlog::info("Use FP16 model");
+         }
+         if (offload) {
+             spdlog::info("Layer offloading enabled");
+         }
+         ModuleWrapper::init(deviceId);
+
+         CUDADeviceContext ctx(this->deviceId);
+         net = std::make_unique<FluxModel>(
+             use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+     }
+
+     bool isBF16() {
+         checkModel();
+         return net->dtype == Tensor::BF16;
+     }
+
+     pybind11::function residual_callback;
+
+     void set_residual_callback(pybind11::function callback) {
+         pybind11::gil_scoped_acquire gil;
+         if (!callback || callback.is_none()) {
+             residual_callback = pybind11::function();
+             if (net) {
+                 net->set_residual_callback(nullptr);
+             }
+             return;
+         }
+         residual_callback = std::move(callback);
+         if (net) {
+             pybind11::object cb = residual_callback;
+             net->set_residual_callback([cb](const Tensor &x) -> Tensor {
+                 torch::Tensor torch_x = to_torch(x);
+                 pybind11::object result = cb(torch_x);
+                 torch::Tensor torch_y = result.cast<torch::Tensor>();
+                 Tensor y = from_torch(torch_y);
+                 return y;
+             });
+         }
+     }
+
+     torch::Tensor forward(torch::Tensor hidden_states,
+                           torch::Tensor encoder_hidden_states,
+                           torch::Tensor temb,
+                           torch::Tensor rotary_emb_img,
+                           torch::Tensor rotary_emb_context,
+                           torch::Tensor rotary_emb_single,
+                           std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
+                           std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
+                           bool skip_first_layer = false) {
+         checkModel();
+         CUDADeviceContext ctx(deviceId);
+
+         spdlog::debug("QuantizedFluxModel forward");
+
+         hidden_states = hidden_states.contiguous();
+         encoder_hidden_states = encoder_hidden_states.contiguous();
+         temb = temb.contiguous();
+         rotary_emb_img = rotary_emb_img.contiguous();
+         rotary_emb_context = rotary_emb_context.contiguous();
+         rotary_emb_single = rotary_emb_single.contiguous();
+
+         Tensor result = net->forward(
+             from_torch(hidden_states),
+             from_torch(encoder_hidden_states),
+             from_torch(temb),
+             from_torch(rotary_emb_img),
+             from_torch(rotary_emb_context),
+             from_torch(rotary_emb_single),
+             controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous())
+                                                  : Tensor{},
+             controlnet_single_block_samples.has_value()
+                 ? from_torch(controlnet_single_block_samples.value().contiguous())
+                 : Tensor{},
+             skip_first_layer);
+
+         torch::Tensor output = to_torch(result);
+         Tensor::synchronizeDevice();
+
+         return output;
+     }
+
+     std::tuple<torch::Tensor, torch::Tensor>
+     forward_layer(int64_t idx,
+                   torch::Tensor hidden_states,
+                   torch::Tensor encoder_hidden_states,
+                   torch::Tensor temb,
+                   torch::Tensor rotary_emb_img,
+                   torch::Tensor rotary_emb_context,
+                   std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
+                   std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt) {
+         CUDADeviceContext ctx(deviceId);
+
+         spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
+
+         hidden_states = hidden_states.contiguous();
+         encoder_hidden_states = encoder_hidden_states.contiguous();
+         temb = temb.contiguous();
+         rotary_emb_img = rotary_emb_img.contiguous();
+         rotary_emb_context = rotary_emb_context.contiguous();
+
+         auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
+             idx,
+             from_torch(hidden_states),
+             from_torch(encoder_hidden_states),
+             from_torch(temb),
+             from_torch(rotary_emb_img),
+             from_torch(rotary_emb_context),
+             controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous())
+                                                  : Tensor{},
+             controlnet_single_block_samples.has_value()
+                 ? from_torch(controlnet_single_block_samples.value().contiguous())
+                 : Tensor{});
+
+         hidden_states = to_torch(hidden_states_);
+         encoder_hidden_states = to_torch(encoder_hidden_states_);
+         Tensor::synchronizeDevice();
+
+         return {hidden_states, encoder_hidden_states};
+     }
+
+     torch::Tensor forward_single_layer(int64_t idx,
+                                        torch::Tensor hidden_states,
+                                        torch::Tensor temb,
+                                        torch::Tensor rotary_emb_single) {
+         CUDADeviceContext ctx(deviceId);
+
+         spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
+
+         hidden_states = hidden_states.contiguous();
+         temb = temb.contiguous();
+         rotary_emb_single = rotary_emb_single.contiguous();
+
+         if (net->isOffloadEnabled()) {
+             net->single_transformer_blocks.at(idx)->loadLazyParams();
+         }
+
+         Tensor result = net->single_transformer_blocks.at(idx)->forward(
+             from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single));
+
+         if (net->isOffloadEnabled()) {
+             net->single_transformer_blocks.at(idx)->releaseLazyParams();
+         }
+
+         hidden_states = to_torch(result);
+         Tensor::synchronizeDevice();
+
+         return hidden_states;
+     }
+
+     // Expose the norm1 forward method of the transformer blocks.
+     // This is used by TeaCache to get the norm1 output.
+     std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+     norm_one_forward(int64_t idx, torch::Tensor hidden_states, torch::Tensor temb) {
+         AdaLayerNormZero::Output result =
+             net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
+         return {to_torch(result.x),
+                 to_torch(result.gate_msa),
+                 to_torch(result.shift_mlp),
+                 to_torch(result.scale_mlp),
+                 to_torch(result.gate_mlp)};
+     }
+
+     // Must be called after loading a LoRA.
+     // Skips the first `skipRanks` ranks in W4A4 layers.
+     void setLoraScale(int skipRanks, float scale) {
+         if (skipRanks % 16 != 0) {
+             throw std::invalid_argument("skipRanks must be a multiple of 16");
+         }
+
+         CUDADeviceContext ctx(deviceId);
+
+         spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks);
+
+         net->traverse([&](Module *module) {
+             if (auto *m = dynamic_cast<GEMV_AWQ *>(module)) {
+                 m->lora_scale = scale;
+             } else if (auto *m = dynamic_cast<GEMM_W4A4 *>(module)) {
+                 for (int i = 0; i < skipRanks / 16; i++) {
+                     m->lora_scales[i] = 1.0f;
+                 }
+                 for (int i = skipRanks / 16; i < (int)m->lora_scales.size(); i++) {
+                     m->lora_scales[i] = scale;
+                 }
+             }
+         });
+     }
+
+     void setAttentionImpl(std::string name) {
+         if (name.empty() || name == "default") {
+             name = "flashattn2";
+         }
+
+         spdlog::info("Set attention implementation to {}", name);
+
+         if (name == "flashattn2") {
+             net->setAttentionImpl(AttentionImpl::FlashAttention2);
+         } else if (name == "nunchaku-fp16") {
+             net->setAttentionImpl(AttentionImpl::NunchakuFP16);
+         } else {
+             throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
+         }
+     }
+
+     std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+     forward_layer_ip_adapter(int64_t idx,
+                              torch::Tensor hidden_states,
+                              torch::Tensor encoder_hidden_states,
+                              torch::Tensor temb,
+                              torch::Tensor rotary_emb_img,
+                              torch::Tensor rotary_emb_context,
+                              std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
+                              std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt) {
+         CUDADeviceContext ctx(deviceId);
+
+         spdlog::debug("QuantizedFluxModel forward_layer_ip_adapter {}", idx);
+
+         hidden_states = hidden_states.contiguous();
+         encoder_hidden_states = encoder_hidden_states.contiguous();
+         temb = temb.contiguous();
+         rotary_emb_img = rotary_emb_img.contiguous();
+         rotary_emb_context = rotary_emb_context.contiguous();
+
+         auto &&[hidden_states_, encoder_hidden_states_, ip_query_] = net->forward_ip_adapter(
+             idx,
+             from_torch(hidden_states),
+             from_torch(encoder_hidden_states),
+             from_torch(temb),
+             from_torch(rotary_emb_img),
+             from_torch(rotary_emb_context),
+             controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous())
+                                                  : Tensor{},
+             controlnet_single_block_samples.has_value()
+                 ? from_torch(controlnet_single_block_samples.value().contiguous())
+                 : Tensor{});
+
+         hidden_states = to_torch(hidden_states_);
+         encoder_hidden_states = to_torch(encoder_hidden_states_);
+         torch::Tensor ip_query = to_torch(ip_query_);
+         Tensor::synchronizeDevice();
+
+         return {hidden_states, encoder_hidden_states, ip_query};
+     }
+ };
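Editor's note: `setLoraScale` above stores LoRA scales per 16-rank group, so "skipping" ranks just means pinning the first `skipRanks / 16` group scales at 1.0. A hedged Python sketch of the same bookkeeping (function and variable names here are illustrative, not part of the bindings):

```python
def set_lora_scales(lora_scales: list[float], skip_ranks: int, scale: float) -> list[float]:
    # Mirrors QuantizedFluxModel::setLoraScale: one scale entry per 16-rank group.
    if skip_ranks % 16 != 0:
        raise ValueError("skip_ranks must be a multiple of 16")
    groups = skip_ranks // 16
    return [1.0] * groups + [scale] * (len(lora_scales) - groups)

# With 4 groups (64 ranks) total, skipping the first 16 ranks:
assert set_lora_scales([0.0] * 4, 16, 0.8) == [1.0, 0.8, 0.8, 0.8]
```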
nunchaku/csrc/gemm.h ADDED
@@ -0,0 +1,114 @@
+ #pragma once
+
+ #include "interop/torch.h"
+ #include "Serialization.h"
+ #include "Linear.h"
+ #include "debug.h"
+ #include "module.h"
+
+ class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
+ public:
+     void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
+         spdlog::info("Initializing QuantizedGEMM");
+
+         size_t val = 0;
+         checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
+         checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
+         spdlog::debug("Stack={}", val);
+
+         net = std::make_unique<GEMM_W4A4>((int)in_features,
+                                           (int)out_features,
+                                           bias,
+                                           use_fp4,
+                                           bf16 ? Tensor::BF16 : Tensor::FP16,
+                                           Device::cuda((int)deviceId));
+     }
+
+     torch::Tensor forward(torch::Tensor x) {
+         checkModel();
+
+         std::cerr << "QuantizedGEMM forward" << std::endl;
+
+         x = x.contiguous();
+
+         Tensor result = net->forward(from_torch(x));
+
+         torch::Tensor output = to_torch(result);
+         Tensor::synchronizeDevice();
+
+         return output;
+     }
+
+     std::string dumpTensorBF16(Tensor x) {
+         std::stringstream ss;
+         for (int i = 0; i < 256; i++) {
+             ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__nv_bfloat16>()[i]));
+         }
+         ss << std::endl;
+         return ss.str();
+     }
+
+     std::string dumpTensorINT4(Tensor x) {
+         using spdlog::fmt_lib::format;
+
+         const int M = x.shape[0];
+         const int K = x.shape[1] * 2;
+
+         assert(x.dtype() == Tensor::INT8);
+
+         // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
+
+         constexpr int BLOCK_M = 256;
+         constexpr int WARP_K = 64;
+         constexpr int NUM_WARPS = 8;
+         constexpr int WARP_M_TILES = 2;
+         constexpr int WARP_SIZE = 32;
+
+         std::stringstream ss;
+         for (int bm = 0; bm < M / BLOCK_M; bm++) {
+             for (int bn = 0; bn < K / WARP_K; bn++) {
+                 for (int warpId = 0; warpId < NUM_WARPS; warpId++) {
+                     ss << format("[bm={},bn={},warp={}] ", bm, bn, warpId);
+                     const int offset = ((bm * (K / WARP_K) + bn) * NUM_WARPS + warpId) * WARP_M_TILES * WARP_SIZE * 4;
+
+                     for (int i = 0; i < 16; i++) {
+                         assert(static_cast<size_t>(offset + i) < x.numel() / 4);
+                         uint32_t val = x.data_ptr<uint32_t>()[offset + i];
+                         ss << "{";
+                         for (int j = 0; j < 8; j++) {
+                             int i4val = (val >> (j * 4)) & 0xf;
+                             if (i4val & 0x8) {
+                                 i4val = -((~i4val & 0x7) + 1);
+                             }
+                             ss << format("{} ", i4val);
+                         }
+                         ss << format("}} {:x} ", val);
+                     }
+                     ss << std::endl;
+                 }
+             }
+         }
+
+         ss << std::endl;
+         return ss.str();
+     }
+
+     void quantize(torch::Tensor x, bool fuse_glu) {
+         checkModel();
+
+         spdlog::debug("QuantizedGEMM quantize");
+
+         x = x.contiguous();
+
+         auto qout = net->quantize(from_torch(x), fuse_glu);
+
+         Tensor act = qout.act.copy(Device::cpu());
+         Tensor ascales = qout.ascales.copy(Device::cpu());
+         Tensor lora_act = qout.lora_act.copy(Device::cpu());
+
+         Tensor::synchronizeDevice();
+
+         spdlog::debug("act = {}", dumpTensorINT4(act));
+         spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
+     }
+ };
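Editor's note: `dumpTensorINT4` above decodes eight signed 4-bit values packed into each `uint32_t`, sign-extending nibbles whose high bit is set. The same decode in a short Python sketch (illustrative only; the helper name is not part of the codebase):

```python
def unpack_int4(word: int) -> list[int]:
    # Extract eight 4-bit two's-complement values from a 32-bit word,
    # least-significant nibble first, as in QuantizedGEMM::dumpTensorINT4.
    vals = []
    for j in range(8):
        nibble = (word >> (j * 4)) & 0xF
        vals.append(nibble - 16 if nibble & 0x8 else nibble)
    return vals

# Nibbles of 0x8F21 from the low end are 1, 2, 0xF (-1), 0x8 (-8), then zeros.
assert unpack_int4(0x8F21) == [1, 2, -1, -8, 0, 0, 0, 0]
```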
nunchaku/csrc/gemm88.h ADDED
@@ -0,0 +1,37 @@
+ #pragma once
+
+ #include "interop/torch.h"
+ #include "Serialization.h"
+ #include "Linear.h"
+ #include "debug.h"
+ #include "module.h"
+
+ class QuantizedGEMM88 : public ModuleWrapper<GEMM_W8A8> {
+ public:
+     void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
+         spdlog::info("Initializing QuantizedGEMM88");
+
+         size_t val = 0;
+         checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
+         checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
+         spdlog::debug("Stack={}", val);
+
+         net = std::make_unique<GEMM_W8A8>(
+             (int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+     }
+
+     torch::Tensor forward(torch::Tensor x) {
+         checkModel();
+
+         std::cerr << "QuantizedGEMM88 forward" << std::endl;
+
+         x = x.contiguous();
+
+         Tensor result = net->forward(from_torch(x));
+
+         torch::Tensor output = to_torch(result);
+         Tensor::synchronizeDevice();
+
+         return output;
+     }
+ };
nunchaku/csrc/module.h ADDED
@@ -0,0 +1,85 @@
+ #pragma once
+
+ #include "interop/torch.h"
+ #include "Serialization.h"
+ #include "Module.h"
+ #include "debug.h"
+ #include "utils.h"
+
+ template<typename M>
+ class ModuleWrapper {
+ public:
+     void init(int deviceId) {
+         this->deviceId = deviceId;
+     }
+
+     void reset() {
+         CUDADeviceContext ctx(this->deviceId);
+
+         debugContext.reset();
+         net.reset();
+         Tensor::synchronizeDevice();
+
+         nunchaku::utils::trim_memory();
+         Tensor::synchronizeDevice();
+     }
+
+     void load(std::string path, bool partial = false) {
+         checkModel();
+         CUDADeviceContext ctx(this->deviceId);
+
+         spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
+
+         std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);
+         net->loadParams(*provider, partial);
+         Tensor::synchronizeDevice();
+
+         spdlog::info("Done.");
+     }
+
+     void loadDict(std::map<std::string, torch::Tensor> dict, bool partial = false) {
+         checkModel();
+         CUDADeviceContext ctx(this->deviceId);
+
+         spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading");
+
+         std::shared_ptr<TensorsProviderTorch> provider = std::make_shared<TensorsProviderTorch>(std::move(dict));
+         net->loadParams(*provider, partial);
+         Tensor::synchronizeDevice();
+
+         spdlog::info("Done.");
+     }
+
+     void startDebug() {
+         debugContext = std::make_unique<DebugContext>();
+     }
+
+     void stopDebug() {
+         debugContext.reset();
+     }
+
+     auto getDebugResults() {
+         CUDADeviceContext ctx(this->deviceId);
+
+         std::map<std::string, torch::Tensor> result;
+
+         if (debugContext) {
+             for (auto &&[key, value] : debugContext->tensors) {
+                 result[key] = to_torch(value);
+             }
+         }
+
+         return result;
+     }
+
+ protected:
+     void checkModel() {
+         if (!net) {
+             throw std::runtime_error("Model not initialized");
+         }
+     }
+
+ protected:
+     std::unique_ptr<M> net;
+     std::unique_ptr<DebugContext> debugContext;
+
+     int deviceId = -1;
+ };
nunchaku/csrc/ops.h ADDED
@@ -0,0 +1,173 @@
+ #pragma once
+
+ #include "interop/torch.h"
+ #include "kernels/zgemm/zgemm.h"
+ #include "kernels/awq/gemv_awq.h"
+ #include "kernels/awq/gemm_awq.h"
+
+ namespace nunchaku::ops {
+
+ void gemm_w4a4(std::optional<torch::Tensor> act,            // packed act [M, K / 2]
+                std::optional<torch::Tensor> wgt,            // packed wgt [N, K / 2]
+                std::optional<torch::Tensor> out,            // linear [M, N]
+                std::optional<torch::Tensor> qout,           // packed act [M, N / 2]
+                std::optional<torch::Tensor> ascales,        // packed as [K / 64, M]
+                std::optional<torch::Tensor> wscales,        // packed ws [K / 64, N]
+                std::optional<torch::Tensor> oscales,        // packed as [N / 64, M]
+                std::optional<torch::Tensor> poolout,        // linear [M / PoolSize, N]
+                std::optional<torch::Tensor> lora_act_in,    // packed lora_act [M, R]
+                std::optional<torch::Tensor> lora_up,        // packed lora_wgt [N, R]
+                std::optional<torch::Tensor> lora_down,      // packed lora_wgt [N, R]
+                std::optional<torch::Tensor> lora_act_out,   // packed lora_act [M, R]
+                std::optional<torch::Tensor> norm_q,         // linear [HEAD_DIM]
+                std::optional<torch::Tensor> norm_k,         // linear [HEAD_DIM]
+                std::optional<torch::Tensor> rotary_emb,     // linear [M, HEAD_DIM / 2, 2, 2]
+                std::optional<torch::Tensor> bias,           // packed ws [N]
+                std::optional<torch::Tensor> smooth_factor,  // packed ws [N], for quantization of the next layer
+                std::optional<torch::Tensor> out_vk,         // linear [B, num_heads, head_dim + 1, head_dim]
+                std::optional<torch::Tensor> out_linearattn, // linear [B, (M), N / 3]
+                bool act_unsigned,
+                std::vector<float> lora_scales,
+                bool fuse_silu,
+                bool fp4,
+                float alpha,
+                std::optional<torch::Tensor> wcscales,
+                std::optional<torch::Tensor> out_q,          // packed attention [B, H, M, D]
+                std::optional<torch::Tensor> out_k,          // packed attention [B, H, M, D]
+                std::optional<torch::Tensor> out_v,          // packed attention [B, H, M, D]
+                int attn_tokens) {
+     TorchOpContext ctx;
+     spdlog::trace("running gemm_w4a4: ");
+
+     auto getTensor = [](std::optional<torch::Tensor> &t) {
+         Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
+         if (ret.valid()) {
+             spdlog::trace("  {}", ret.shape.str());
+         } else {
+             spdlog::trace("  <invalid>");
+         }
+         return ret;
+     };
+     nunchaku::kernels::gemm_w4a4(getTensor(act),
+                                  getTensor(wgt),
+                                  getTensor(out),
+                                  getTensor(qout),
+                                  getTensor(ascales),
+                                  getTensor(wscales),
+                                  getTensor(oscales),
+                                  getTensor(poolout),
+                                  getTensor(lora_act_in),
+                                  getTensor(lora_up),
+                                  getTensor(lora_down),
+                                  getTensor(lora_act_out),
+                                  getTensor(norm_q),
+                                  getTensor(norm_k),
+                                  getTensor(rotary_emb),
+                                  getTensor(bias),
+                                  getTensor(smooth_factor),
+                                  getTensor(out_vk),
+                                  getTensor(out_linearattn),
+                                  act_unsigned,
+                                  lora_scales,
+                                  fuse_silu,
+                                  fp4,
+                                  alpha,
+                                  getTensor(wcscales),
+                                  getTensor(out_q),
+                                  getTensor(out_k),
+                                  getTensor(out_v),
+                                  attn_tokens);
+     // Tensor::synchronizeDevice();
+ }
+
+ void quantize_w4a4_act_fuse_lora(std::optional<torch::Tensor> input,
+                                  std::optional<torch::Tensor> output,
+                                  std::optional<torch::Tensor> oscales,
+                                  std::optional<torch::Tensor> lora_down,
+                                  std::optional<torch::Tensor> lora_act_out,
+                                  std::optional<torch::Tensor> smooth,
+                                  bool fuse_glu,
+                                  bool fp4) {
+     TorchOpContext ctx;
+
+     spdlog::trace("running quantize_w4a4_act_fuse_lora: ");
+
+     auto getTensor = [](std::optional<torch::Tensor> &t) {
+         Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
+         if (ret.valid()) {
+             spdlog::trace("  {}", ret.shape.str());
+         } else {
+             spdlog::trace("  <invalid>");
+         }
+         return ret;
+     };
+     nunchaku::kernels::quantize_w4a4_act_fuse_lora(getTensor(input),
+                                                    getTensor(output),
+                                                    getTensor(oscales),
+                                                    getTensor(lora_down),
+                                                    getTensor(lora_act_out),
+                                                    getTensor(smooth),
+                                                    fuse_glu,
+                                                    fp4);
+ }
+
+ void attention_fp16(torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
+                     torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
+                     torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
+                     torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
+                     float scale) {
+     TorchOpContext ctx;
+     nunchaku::kernels::attention_fp16(from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale);
+ }
+
+ torch::Tensor gemv_awq(torch::Tensor _in_feats,
+                        torch::Tensor _kernel,
+                        torch::Tensor _scaling_factors,
+                        torch::Tensor _zeros,
+                        int64_t m,
+                        int64_t n,
+                        int64_t k,
+                        int64_t group_size) {
+     TorchOpContext ctx;
+     Tensor result = ::gemv_awq(from_torch(_in_feats.contiguous()),
+                                from_torch(_kernel.contiguous()),
+                                from_torch(_scaling_factors.contiguous()),
+                                from_torch(_zeros.contiguous()),
+                                (int)m,
+                                (int)n,
+                                (int)k,
+                                (int)group_size);
+
+     torch::Tensor output = to_torch(result);
+     // Tensor::synchronizeDevice();
+
+     return output;
+ }
+
+ torch::Tensor
+ gemm_awq(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros) {
+     TorchOpContext ctx; // moved before the kernel call so the op context covers it, matching gemv_awq
+     Tensor result = ::awq_gemm_forward_cuda(from_torch(_in_feats.contiguous()),
+                                             from_torch(_kernel.contiguous()),
+                                             from_torch(_scaling_factors.contiguous()),
+                                             from_torch(_zeros.contiguous()));
+
+     // TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
+     torch::Tensor output = to_torch(result);
+     // Tensor::synchronizeDevice();
+
+     return output;
+ }
+
+ void test_rmsnorm_rope(
+     torch::Tensor input, torch::Tensor output, torch::Tensor norm_q, torch::Tensor norm_k, torch::Tensor rotary_emb) {
+     nunchaku::kernels::test_rmsnorm_rope(
+         from_torch(input), from_torch(output), from_torch(norm_q), from_torch(norm_k), from_torch(rotary_emb));
+ }
+
+ void test_pack_qkv(torch::Tensor input, torch::Tensor out_q, torch::Tensor out_k, torch::Tensor out_v, int numTokens) {
+     nunchaku::kernels::test_pack_qkv(
+         from_torch(input), from_torch(out_q), from_torch(out_k), from_torch(out_v), numTokens);
+ }
+
+ }; // namespace nunchaku::ops
nunchaku/csrc/pybind.cpp ADDED
@@ -0,0 +1,124 @@
+ #include "gemm.h"
+ #include "gemm88.h"
+ #include "flux.h"
+ #include "sana.h"
+ #include "ops.h"
+ #include "utils.h"
+ #include <torch/extension.h>
+ #include "interop/torch.h"
+ #include <pybind11/pybind11.h>
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
+         .def(py::init<>())
+         .def("init",
+              &QuantizedFluxModel::init,
+              py::arg("use_fp4"),
+              py::arg("offload"),
+              py::arg("bf16"),
+              py::arg("deviceId"))
+         .def("set_residual_callback",
+              [](QuantizedFluxModel &self, pybind11::object call_back) {
+                  if (call_back.is_none()) {
+                      self.set_residual_callback(pybind11::function());
+                  } else {
+                      self.set_residual_callback(call_back);
+                  }
+              })
+         .def("reset", &QuantizedFluxModel::reset)
+         .def("load", &QuantizedFluxModel::load, py::arg("path"), py::arg("partial") = false)
+         .def("loadDict", &QuantizedFluxModel::loadDict, py::arg("dict"), py::arg("partial") = false)
+         .def("forward",
+              &QuantizedFluxModel::forward,
+              py::arg("hidden_states"),
+              py::arg("encoder_hidden_states"),
+              py::arg("temb"),
+              py::arg("rotary_emb_img"),
+              py::arg("rotary_emb_context"),
+              py::arg("rotary_emb_single"),
+              py::arg("controlnet_block_samples") = py::none(),
+              py::arg("controlnet_single_block_samples") = py::none(),
+              py::arg("skip_first_layer") = false)
+         .def("forward_layer",
+              &QuantizedFluxModel::forward_layer,
+              py::arg("idx"),
+              py::arg("hidden_states"),
+              py::arg("encoder_hidden_states"),
+              py::arg("temb"),
+              py::arg("rotary_emb_img"),
+              py::arg("rotary_emb_context"),
+              py::arg("controlnet_block_samples") = py::none(),
+              py::arg("controlnet_single_block_samples") = py::none())
+         .def("forward_layer_ip_adapter",
+              &QuantizedFluxModel::forward_layer_ip_adapter,
+              py::arg("idx"),
+              py::arg("hidden_states"),
+              py::arg("encoder_hidden_states"),
+              py::arg("temb"),
+              py::arg("rotary_emb_img"),
+              py::arg("rotary_emb_context"),
+              py::arg("controlnet_block_samples") = py::none(),
+              py::arg("controlnet_single_block_samples") = py::none())
+         .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
+         .def("norm_one_forward", &QuantizedFluxModel::norm_one_forward)
+         .def("startDebug", &QuantizedFluxModel::startDebug)
+         .def("stopDebug", &QuantizedFluxModel::stopDebug)
+         .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
+         .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
+         .def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
+         .def("isBF16", &QuantizedFluxModel::isBF16);
+     py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
+         .def(py::init<>())
+         .def("init",
+              &QuantizedSanaModel::init,
+              py::arg("config"),
+              py::arg("pag_layers"),
+              py::arg("use_fp4"),
+              py::arg("bf16"),
+              py::arg("deviceId"))
+         .def("reset", &QuantizedSanaModel::reset)
+         .def("load", &QuantizedSanaModel::load, py::arg("path"), py::arg("partial") = false)
+         .def("loadDict", &QuantizedSanaModel::loadDict, py::arg("dict"), py::arg("partial") = false)
+         .def("forward", &QuantizedSanaModel::forward)
+         .def("forward_layer", &QuantizedSanaModel::forward_layer)
+         .def("startDebug", &QuantizedSanaModel::startDebug)
+         .def("stopDebug", &QuantizedSanaModel::stopDebug)
+         .def("getDebugResults", &QuantizedSanaModel::getDebugResults);
+     py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
+         .def(py::init<>())
+         .def("init", &QuantizedGEMM::init)
+         .def("reset", &QuantizedGEMM::reset)
+         .def("load", &QuantizedGEMM::load)
+         .def("forward", &QuantizedGEMM::forward)
+         .def("quantize", &QuantizedGEMM::quantize)
+         .def("startDebug", &QuantizedGEMM::startDebug)
+         .def("stopDebug", &QuantizedGEMM::stopDebug)
+         .def("getDebugResults", &QuantizedGEMM::getDebugResults);
+     py::class_<Tensor>(m, "Tensor");
+     py::class_<QuantizedGEMM88>(m, "QuantizedGEMM88")
+         .def(py::init<>())
+         .def("init", &QuantizedGEMM88::init)
+         .def("reset", &QuantizedGEMM88::reset)
+         .def("load", &QuantizedGEMM88::load)
+         .def("forward", &QuantizedGEMM88::forward)
+         .def("startDebug", &QuantizedGEMM88::startDebug)
+         .def("stopDebug", &QuantizedGEMM88::stopDebug)
+         .def("getDebugResults", &QuantizedGEMM88::getDebugResults);
+
+     m.def_submodule("ops")
+         .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
+         .def("quantize_w4a4_act_fuse_lora", nunchaku::ops::quantize_w4a4_act_fuse_lora)
+         .def("attention_fp16", nunchaku::ops::attention_fp16)
+         .def("gemm_awq", nunchaku::ops::gemm_awq)
+         .def("gemv_awq", nunchaku::ops::gemv_awq)
+
+         .def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
+         .def("test_pack_qkv", nunchaku::ops::test_pack_qkv);
+
+     m.def_submodule("utils")
+         .def("set_log_level", [](const std::string &level) { spdlog::set_level(spdlog::level::from_str(level)); })
+         .def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
+         .def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
+         .def("trim_memory", nunchaku::utils::trim_memory)
+         .def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode);
+ }
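Editor's note: from Python, the bound classes are used roughly as follows. This is a hedged sketch, not the project's documented API: the extension's import path depends on how `TORCH_EXTENSION_NAME` is set at build time, so `nunchaku._C` and the checkpoint path below are assumptions; the class, method, and argument names come from the bindings above.

```python
from nunchaku import _C  # import path is an assumption; see TORCH_EXTENSION_NAME above

model = _C.QuantizedFluxModel()
model.init(use_fp4=False, offload=False, bf16=True, deviceId=0)   # INT4 weights, BF16 activations
model.load("svdq-int4_r32-flux.1-kontext-dev.safetensors")        # hypothetical local path

_C.utils.set_log_level("debug")  # the "utils" submodule bound at the bottom of pybind.cpp
```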
nunchaku/csrc/sana.h ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ #pragma once
+
+ #include "interop/torch.h"
+ #include "SanaModel.h"
+ #include "Serialization.h"
+ #include "debug.h"
+ #include "module.h"
+
+ class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
+ public:
+     void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
+         spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);
+         SanaConfig cfg{
+             .num_layers = config["num_layers"].cast<int>(),
+             .num_attention_heads = config["num_attention_heads"].cast<int>(),
+             .attention_head_dim = config["attention_head_dim"].cast<int>(),
+             .num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
+             .expand_ratio = config["mlp_ratio"].cast<double>(),
+             .pag_layers = pag_layers,
+             .use_fp4 = use_fp4,
+         };
+
+         ModuleWrapper::init(deviceId);
+         CUDADeviceContext ctx(this->deviceId);
+         net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+     }
+
+     torch::Tensor forward(torch::Tensor hidden_states,
+                           torch::Tensor encoder_hidden_states,
+                           torch::Tensor timestep,
+                           torch::Tensor cu_seqlens_img,
+                           torch::Tensor cu_seqlens_txt,
+                           int H,
+                           int W,
+                           bool pag,
+                           bool cfg,
+                           bool skip_first_layer = false) {
+         checkModel();
+         CUDADeviceContext ctx(deviceId);
+
+         spdlog::debug("QuantizedSanaModel forward");
+
+         hidden_states = hidden_states.contiguous();
+         encoder_hidden_states = encoder_hidden_states.contiguous();
+         timestep = timestep.contiguous();
+         cu_seqlens_img = cu_seqlens_img.contiguous();
+         cu_seqlens_txt = cu_seqlens_txt.contiguous();
+
+         Tensor result = net->forward(from_torch(hidden_states),
+                                      from_torch(encoder_hidden_states),
+                                      from_torch(timestep),
+                                      from_torch(cu_seqlens_img),
+                                      from_torch(cu_seqlens_txt),
+                                      H,
+                                      W,
+                                      pag,
+                                      cfg,
+                                      skip_first_layer);
+
+         torch::Tensor output = to_torch(result);
+         // Tensor::synchronizeDevice();
+
+         return output;
+     }
+
+     torch::Tensor forward_layer(int64_t idx,
+                                 torch::Tensor hidden_states,
+                                 torch::Tensor encoder_hidden_states,
+                                 torch::Tensor timestep,
+                                 torch::Tensor cu_seqlens_img,
+                                 torch::Tensor cu_seqlens_txt,
+                                 int H,
+                                 int W,
+                                 bool pag,
+                                 bool cfg) {
+         checkModel();
+         CUDADeviceContext ctx(deviceId);
+
+         spdlog::debug("QuantizedSanaModel forward_layer {}", idx);
+
+         hidden_states = hidden_states.contiguous();
+         encoder_hidden_states = encoder_hidden_states.contiguous();
+         timestep = timestep.contiguous();
+         cu_seqlens_img = cu_seqlens_img.contiguous();
+         cu_seqlens_txt = cu_seqlens_txt.contiguous();
+
+         Tensor result = net->transformer_blocks.at(idx)->forward(from_torch(hidden_states),
+                                                                  from_torch(encoder_hidden_states),
+                                                                  from_torch(timestep),
+                                                                  from_torch(cu_seqlens_img),
+                                                                  from_torch(cu_seqlens_txt),
+                                                                  H,
+                                                                  W,
+                                                                  pag,
+                                                                  cfg);
+
+         torch::Tensor output = to_torch(result);
+         // Tensor::synchronizeDevice();
+
+         return output;
+     }
+ };
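
A rough sketch of what a Python-side call into this wrapper could look like. Everything here is an assumption for illustration: the binding name mirrors the class above, the config keys mirror the `SanaConfig` fields, and the values are placeholders:

```python
# Illustrative sketch only; not a verified invocation of this repository's API.
from nunchaku import _C  # assumed extension module name

model = _C.QuantizedSanaModel()  # assumed binding name for the class above
model.init(
    {
        "num_layers": 20,             # placeholder values mirroring SanaConfig fields
        "num_attention_heads": 36,
        "attention_head_dim": 32,
        "num_cross_attention_heads": 16,
        "mlp_ratio": 2.5,
    },
    [],     # pag_layers
    False,  # use_fp4
    True,   # bf16
    0,      # deviceId
)
# Weights still need to be loaded (see ModuleWrapper / Serialization) before forward().
```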
nunchaku/csrc/utils.h ADDED
@@ -0,0 +1,39 @@
+ #pragma once
+
+ #include "common.h"
+ #include "Tensor.h"
+ #include "kernels/zgemm/zgemm.h"
+
+ namespace nunchaku::utils {
+
+ void set_cuda_stack_limit(int64_t newval) {
+     size_t val = 0;
+     checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
+     checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
+     spdlog::debug("Stack={}", val);
+ }
+
+ void disable_memory_auto_release() {
+     int device;
+     checkCUDA(cudaGetDevice(&device));
+     cudaMemPool_t mempool;
+     checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
+     uint64_t threshold = UINT64_MAX;
+     checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
+ }
+
+ void trim_memory() {
+     int device;
+     checkCUDA(cudaGetDevice(&device));
+     cudaMemPool_t mempool;
+     checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
+     size_t bytesToKeep = 0;
+     checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
+ }
+
+ void set_faster_i2f_mode(std::string mode) {
+     spdlog::info("Set fasteri2f mode to {}", mode);
+     kernels::set_faster_i2f_mode(mode);
+ }
+
+ }; // namespace nunchaku::utils
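
`disable_memory_auto_release` and `trim_memory` are a pair: the first raises the default CUDA memory pool's release threshold to `UINT64_MAX` so freed blocks stay cached between denoising steps (avoiding repeated `cudaMalloc`/`cudaFree`), and the second trims the pool back down once the hot loop is done. A plausible usage pattern, assuming the bindings are exposed under `nunchaku._C.utils` as registered in `pybind.cpp` above:

```python
from nunchaku import _C  # extension module name is an assumption

_C.utils.disable_memory_auto_release()  # keep freed CUDA blocks cached during sampling
# ... run the diffusion sampling loop ...
_C.utils.trim_memory()                  # hand cached pool memory back to the driver when idle
```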
nunchaku/lora/__init__.py ADDED
@@ -0,0 +1 @@
+ # LoRA utilities for FLUX models
nunchaku/lora/flux/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .diffusers_converter import to_diffusers
+ from .nunchaku_converter import convert_to_nunchaku_flux_lowrank_dict, to_nunchaku
+ from .utils import is_nunchaku_format
+
+ __all__ = ["to_diffusers", "to_nunchaku", "convert_to_nunchaku_flux_lowrank_dict", "is_nunchaku_format"]
nunchaku/lora/flux/compose.py ADDED
@@ -0,0 +1,218 @@
+ """
+ Compose multiple LoRA weights into a single LoRA for FLUX models.
+
+ This script merges several LoRA safetensors files into one, applying individual strength values to each.
+
+ **Example Usage:**
+
+ .. code-block:: bash
+
+     python -m nunchaku.lora.flux.compose \\
+         -i lora1.safetensors lora2.safetensors \\
+         -s 0.8 1.0 \\
+         -o composed_lora.safetensors
+
+ **Arguments:**
+
+ - ``-i``, ``--input-paths``: Input LoRA safetensors files (one or more).
+ - ``-s``, ``--strengths``: Strength value for each LoRA (must match number of inputs).
+ - ``-o``, ``--output-path``: Output path for the composed LoRA safetensors file.
+
+ This will merge ``lora1.safetensors`` (strength 0.8) and ``lora2.safetensors`` (strength 1.0) into ``composed_lora.safetensors``.
+
+ **Main Function**
+
+ :func:`compose_lora`
+ """
+
+ import argparse
+ import os
+
+ import torch
+ import torch.nn.functional as F
+ from safetensors.torch import save_file
+
+ from .diffusers_converter import to_diffusers
+ from .utils import is_nunchaku_format, load_state_dict_in_safetensors
+
+
+ def compose_lora(
+     loras: list[tuple[str | dict[str, torch.Tensor], float]], output_path: str | None = None
+ ) -> dict[str, torch.Tensor]:
+     """
+     Compose multiple LoRA weights into a single LoRA representation.
+
+     Parameters
+     ----------
+     loras : list of (str or dict[str, torch.Tensor], float)
+         Each tuple contains:
+         - Path to a LoRA safetensors file or a LoRA weights dictionary.
+         - Strength/scale factor for that LoRA.
+     output_path : str, optional
+         Path to save the composed LoRA weights as a safetensors file. If None, does not save.
+
+     Returns
+     -------
+     dict[str, torch.Tensor]
+         The composed LoRA weights.
+
+     Raises
+     ------
+     AssertionError
+         If LoRA weights are in Nunchaku format (must be converted to Diffusers format first)
+         or if tensor shapes are incompatible.
+
+     Notes
+     -----
+     - Converts all input LoRAs to Diffusers format.
+     - Handles QKV projection fusion for attention layers.
+     - Applies strength scaling to LoRA weights.
+     - Concatenates multiple LoRAs along appropriate dimensions.
+     - Handles normalization layers, bias vectors, and FLUX.1-tools LoRA compatibility.
+
+     Examples
+     --------
+     >>> lora_paths = [("lora1.safetensors", 0.8), ("lora2.safetensors", 0.6)]
+     >>> composed = compose_lora(lora_paths, "composed_lora.safetensors")
+     >>> lora_dicts = [({"layer.weight": torch.randn(10, 20)}, 1.0)]
+     >>> composed = compose_lora(lora_dicts)
+     """
+     if len(loras) == 1:
+         # Shortcut: a single Nunchaku-format LoRA at strength ~1 needs no recomposition.
+         # abs() is required here; without it, strengths below 1 would take this shortcut too.
+         if is_nunchaku_format(loras[0][0]) and abs(loras[0][1] - 1) < 1e-5:
+             if isinstance(loras[0][0], str):
+                 return load_state_dict_in_safetensors(loras[0][0], device="cpu")
+             else:
+                 return loras[0][0]
+
+     composed = {}
+     for lora, strength in loras:
+         assert not is_nunchaku_format(lora)
+         lora = to_diffusers(lora)
+         for k, v in list(lora.items()):
+             if v.ndim == 1:
+                 previous_tensor = composed.get(k, None)
+                 if previous_tensor is None:
+                     if "norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k:
+                         composed[k] = v
+                     else:
+                         composed[k] = v * strength
+                 else:
+                     assert not ("norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k)
+                     composed[k] = previous_tensor + v * strength
+             else:
+                 assert v.ndim == 2
+                 if ".to_q." in k or ".add_q_proj." in k:  # qkv must all exist
+                     if "lora_B" in k:
+                         continue
+
+                     q_a = v
+                     k_a = lora[k.replace(".to_q.", ".to_k.").replace(".add_q_proj.", ".add_k_proj.")]
+                     v_a = lora[k.replace(".to_q.", ".to_v.").replace(".add_q_proj.", ".add_v_proj.")]
+
+                     q_b = lora[k.replace("lora_A", "lora_B")]
+                     k_b = lora[
+                         k.replace("lora_A", "lora_B")
+                         .replace(".to_q.", ".to_k.")
+                         .replace(".add_q_proj.", ".add_k_proj.")
+                     ]
+                     v_b = lora[
+                         k.replace("lora_A", "lora_B")
+                         .replace(".to_q.", ".to_v.")
+                         .replace(".add_q_proj.", ".add_v_proj.")
+                     ]
+
+                     # Add paddings if their ranks are different
+                     max_rank = max(q_a.shape[0], k_a.shape[0], v_a.shape[0])
+                     q_a = F.pad(q_a, (0, 0, 0, max_rank - q_a.shape[0]))
+                     k_a = F.pad(k_a, (0, 0, 0, max_rank - k_a.shape[0]))
+                     v_a = F.pad(v_a, (0, 0, 0, max_rank - v_a.shape[0]))
+                     q_b = F.pad(q_b, (0, max_rank - q_b.shape[1]))
+                     k_b = F.pad(k_b, (0, max_rank - k_b.shape[1]))
+                     v_b = F.pad(v_b, (0, max_rank - v_b.shape[1]))
+
+                     if torch.isclose(q_a, k_a).all() and torch.isclose(q_a, v_a).all():
+                         lora_a = q_a
+                         lora_b = torch.cat((q_b, k_b, v_b), dim=0)
+                     else:
+                         lora_a_group = (q_a, k_a, v_a)
+                         new_shape_a = [sum([_.shape[0] for _ in lora_a_group]), q_a.shape[1]]
+                         lora_a = torch.zeros(new_shape_a, dtype=q_a.dtype, device=q_a.device)
+                         start_dim = 0
+                         for tensor in lora_a_group:
+                             lora_a[start_dim : start_dim + tensor.shape[0]] = tensor
+                             start_dim += tensor.shape[0]
+
+                         lora_b_group = (q_b, k_b, v_b)
+                         new_shape_b = [sum([_.shape[0] for _ in lora_b_group]), sum([_.shape[1] for _ in lora_b_group])]
+                         lora_b = torch.zeros(new_shape_b, dtype=q_b.dtype, device=q_b.device)
+                         start_dims = (0, 0)
+                         for tensor in lora_b_group:
+                             end_dims = (start_dims[0] + tensor.shape[0], start_dims[1] + tensor.shape[1])
+                             lora_b[start_dims[0] : end_dims[0], start_dims[1] : end_dims[1]] = tensor
+                             start_dims = end_dims
+
+                     lora_a = lora_a * strength
+
+                     new_k_a = k.replace(".to_q.", ".to_qkv.").replace(".add_q_proj.", ".add_qkv_proj.")
+                     new_k_b = new_k_a.replace("lora_A", "lora_B")
+
+                     for kk, vv, dim in ((new_k_a, lora_a, 0), (new_k_b, lora_b, 1)):
+                         previous_lora = composed.get(kk, None)
+                         composed[kk] = vv if previous_lora is None else torch.cat([previous_lora, vv], dim=dim)
+
+                 elif ".to_k." in k or ".to_v." in k or ".add_k_proj." in k or ".add_v_proj." in k:
+                     continue
+                 else:
+                     if "lora_A" in k:
+                         v = v * strength
+
+                     previous_lora = composed.get(k, None)
+                     if previous_lora is None:
+                         composed[k] = v
+                     else:
+                         if "lora_A" in k:
+                             if previous_lora.shape[1] != v.shape[1]:  # flux.1-tools LoRA compatibility
+                                 assert "x_embedder" in k
+                                 expanded_size = max(previous_lora.shape[1], v.shape[1])
+                                 if expanded_size > previous_lora.shape[1]:
+                                     expanded_previous_lora = torch.zeros(
+                                         (previous_lora.shape[0], expanded_size),
+                                         device=previous_lora.device,
+                                         dtype=previous_lora.dtype,
+                                     )
+                                     expanded_previous_lora[:, : previous_lora.shape[1]] = previous_lora
+                                 else:
+                                     expanded_previous_lora = previous_lora
+                                 if expanded_size > v.shape[1]:
+                                     expanded_v = torch.zeros(
+                                         (v.shape[0], expanded_size), device=v.device, dtype=v.dtype
+                                     )
+                                     expanded_v[:, : v.shape[1]] = v
+                                 else:
+                                     expanded_v = v
+                                 composed[k] = torch.cat([expanded_previous_lora, expanded_v], dim=0)
+                             else:
+                                 composed[k] = torch.cat([previous_lora, v], dim=0)
+                         else:
+                             composed[k] = torch.cat([previous_lora, v], dim=1)
+
+     if output_path is not None:
+         output_dir = os.path.dirname(os.path.abspath(output_path))
+         os.makedirs(output_dir, exist_ok=True)
+         save_file(composed, output_path)
+     return composed
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-i", "--input-paths", type=str, nargs="*", required=True, help="paths to the lora safetensors files"
+     )
+     parser.add_argument("-s", "--strengths", type=float, nargs="*", required=True, help="strengths for each lora")
+     parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output safetensors file")
+     args = parser.parse_args()
+     assert len(args.input_paths) == len(args.strengths)
+     compose_lora(list(zip(args.input_paths, args.strengths)), args.output_path)
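
To see why the QKV fusion above is correct: with a shared down-projection, the fused update for the concatenated QKV weight is simply `cat(q_b, k_b, v_b) @ a`; with distinct down-projections, the up factors must be placed block-diagonally so each output slice only multiplies its own rank slice. A small self-contained check of that identity (toy shapes, not FLUX dimensions):

```python
import torch

rank, d_in, d_out = 4, 16, 8
q_a, k_a, v_a = (torch.randn(rank, d_in) for _ in range(3))   # distinct down-projections
q_b, k_b, v_b = (torch.randn(d_out, rank) for _ in range(3))  # per-projection up factors

# Fused form used by compose_lora: stack the A factors, place the B factors block-diagonally.
lora_a = torch.cat([q_a, k_a, v_a], dim=0)  # [3*rank, d_in]
lora_b = torch.block_diag(q_b, k_b, v_b)    # [3*d_out, 3*rank]

fused = lora_b @ lora_a                     # update for the fused QKV weight
reference = torch.cat([q_b @ q_a, k_b @ k_a, v_b @ v_a], dim=0)
assert torch.allclose(fused, reference, atol=1e-5)
```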
nunchaku/lora/flux/convert.py ADDED
@@ -0,0 +1,74 @@
+ """
+ CLI tool to convert LoRA weights to Nunchaku format.
+
+ **Example Usage:**
+
+ .. code-block:: bash
+
+     python -m nunchaku.lora.flux.convert \\
+         --lora-path composed_lora.safetensors \\
+         --quant-path mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors \\
+         --output-root ./converted \\
+         --dtype bfloat16
+
+ **Arguments:**
+
+ - ``--lora-path``: Path to the LoRA weights safetensor file (required)
+ - ``--quant-path``: Path to the quantized model safetensor file (default: ``mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors``)
+ - ``--output-root``: Root directory for the output safetensor file (default: parent directory of the lora file)
+ - ``--lora-name``: Name of the LoRA weights (optional, auto-generated if not provided)
+ - ``--dtype``: Data type of the converted weights, either ``bfloat16`` or ``float16`` (default: ``bfloat16``)
+
+ **Main Function**
+
+ :func:`nunchaku.lora.flux.nunchaku_converter.to_nunchaku`
+ """
+
+ import argparse
+ import os
+
+ from .nunchaku_converter import to_nunchaku
+ from .utils import is_nunchaku_format
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--quant-path",
+         type=str,
+         help="Path to the quantized model safetensors file.",
+         default="mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors",
+     )
+     parser.add_argument("--lora-path", type=str, required=True, help="Path to LoRA weights safetensors file.")
+     parser.add_argument("--output-root", type=str, default="", help="Root directory for output safetensors file.")
+     parser.add_argument("--lora-name", type=str, default=None, help="Name for the output LoRA weights.")
+     parser.add_argument(
+         "--dtype",
+         type=str,
+         default="bfloat16",
+         choices=["bfloat16", "float16"],
+         help="Data type of the converted weights.",
+     )
+     args = parser.parse_args()
+
+     if is_nunchaku_format(args.lora_path):
+         print("Already in Nunchaku format, no conversion needed.")
+         exit(0)
+
+     if not args.output_root:
+         args.output_root = os.path.dirname(args.lora_path)
+     if args.lora_name is None:
+         base_name = os.path.basename(args.lora_path)
+         lora_name = base_name.rsplit(".", 1)[0]
+         precision = "fp4" if "fp4" in args.quant_path else "int4"
+         lora_name = f"svdq-{precision}-{lora_name}"
+         print(f"LoRA name not provided, using {lora_name} as the LoRA name")
+     else:
+         lora_name = args.lora_name
+         assert lora_name, "LoRA name must be provided."
+
+     to_nunchaku(
+         args.lora_path,
+         args.quant_path,
+         dtype=args.dtype,
+         output_path=os.path.join(args.output_root, f"{lora_name}.safetensors"),
+     )
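
Since the CLI above is a thin wrapper around :func:`to_nunchaku`, the compose and convert steps can also be chained programmatically; a minimal sketch (file names are placeholders):

```python
from nunchaku.lora.flux import to_nunchaku
from nunchaku.lora.flux.compose import compose_lora

# Merge two LoRAs with per-LoRA strengths, then convert for the quantized base model.
composed = compose_lora([("style.safetensors", 0.8), ("detail.safetensors", 1.0)])
to_nunchaku(
    composed,
    "mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors",
    dtype="bfloat16",
    output_path="svdq-int4-composed.safetensors",
)
```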
nunchaku/lora/flux/diffusers_converter.py ADDED
@@ -0,0 +1,220 @@
+ """
+ This module implements the functions to convert FLUX LoRA weights from various formats
+ to the Diffusers format, which will later be converted to Nunchaku format.
+ """
+
+ import argparse
+ import logging
+ import os
+
+ import torch
+ from diffusers.loaders import FluxLoraLoaderMixin
+ from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
+ from safetensors.torch import save_file
+
+ from ...utils import load_state_dict_in_safetensors
+
+ # Get log level from environment variable (default to INFO)
+ log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+ # Configure logging
+ logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+
+ def handle_kohya_lora(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+     """
+     Convert Kohya LoRA format keys to Diffusers format.
+
+     Parameters
+     ----------
+     state_dict : dict[str, torch.Tensor]
+         LoRA weights, possibly in Kohya format.
+
+     Returns
+     -------
+     dict[str, torch.Tensor]
+         LoRA weights in Diffusers format.
+     """
+     # first check if the state_dict is in the kohya format
+     # like: https://civitai.com/models/1118358?modelVersionId=1256866
+     if any([not k.startswith("lora_transformer_") for k in state_dict.keys()]):
+         return state_dict
+     else:
+         new_state_dict = {}
+         for k, v in state_dict.items():
+             new_k = k.replace("lora_transformer_", "transformer.")
+
+             new_k = new_k.replace("norm_out_", "norm_out.")
+
+             new_k = new_k.replace("time_text_embed_", "time_text_embed.")
+             new_k = new_k.replace("guidance_embedder_", "guidance_embedder.")
+             new_k = new_k.replace("text_embedder_", "text_embedder.")
+             new_k = new_k.replace("timestep_embedder_", "timestep_embedder.")
+
+             new_k = new_k.replace("single_transformer_blocks_", "single_transformer_blocks.")
+             new_k = new_k.replace("_attn_", ".attn.")
+             new_k = new_k.replace("_norm_linear.", ".norm.linear.")
+             new_k = new_k.replace("_proj_mlp.", ".proj_mlp.")
+             new_k = new_k.replace("_proj_out.", ".proj_out.")
+
+             new_k = new_k.replace("transformer_blocks_", "transformer_blocks.")
+             new_k = new_k.replace("to_out_0.", "to_out.0.")
+             new_k = new_k.replace("_ff_context_net_0_proj.", ".ff_context.net.0.proj.")
+             new_k = new_k.replace("_ff_context_net_2.", ".ff_context.net.2.")
+             new_k = new_k.replace("_ff_net_0_proj.", ".ff.net.0.proj.")
+             new_k = new_k.replace("_ff_net_2.", ".ff.net.2.")
+             new_k = new_k.replace("_norm1_context_linear.", ".norm1_context.linear.")
+             new_k = new_k.replace("_norm1_linear.", ".norm1.linear.")
+
+             new_k = new_k.replace(".lora_down.", ".lora_A.")
+             new_k = new_k.replace(".lora_up.", ".lora_B.")
+
+             new_state_dict[new_k] = v
+         return new_state_dict
+
+
+ def convert_peft_to_comfyui(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+     """
+     Convert PEFT format (base_model.model.*) to ComfyUI format (lora_unet_*).
+
+     Mapping rules:
+     - base_model.model.double_blocks.X.img_attn.proj → lora_unet_double_blocks_X_img_attn_proj
+     - base_model.model.single_blocks.X.linear1 → lora_unet_single_blocks_X_linear1
+     - base_model.model.final_layer.linear → lora_unet_final_layer_linear
+     - lora_A/lora_B → lora_down/lora_up
+
+     Parameters
+     ----------
+     state_dict : dict[str, torch.Tensor]
+         LoRA weights in PEFT format
+
+     Returns
+     -------
+     dict[str, torch.Tensor]
+         LoRA weights in ComfyUI format
+     """
+     converted_dict = {}
+
+     for key, value in state_dict.items():
+         new_key = key
+
+         if key.startswith("base_model.model."):
+             # Remove base_model.model. prefix
+             new_key = key.replace("base_model.model.", "")
+
+             # Convert to ComfyUI format with underscores
+             # Handle double_blocks
+             if "double_blocks" in new_key:
+                 # Replace dots with underscores within the block structure
+                 # e.g., double_blocks.0.img_attn.proj → double_blocks_0_img_attn_proj
+                 new_key = new_key.replace("double_blocks.", "lora_unet_double_blocks_")
+                 # Replace remaining dots with underscores
+                 new_key = new_key.replace(".", "_")
+
+             # Handle single_blocks
+             elif "single_blocks" in new_key:
+                 new_key = new_key.replace("single_blocks.", "lora_unet_single_blocks_")
+                 # Special handling for modulation.lin → modulation_lin
+                 new_key = new_key.replace("modulation.lin", "modulation_lin")
+                 # Replace remaining dots with underscores
+                 new_key = new_key.replace(".", "_")
+
+             # Handle final_layer
+             elif "final_layer" in new_key:
+                 new_key = new_key.replace("final_layer.linear", "lora_unet_final_layer_linear")
+                 # Replace remaining dots with underscores
+                 new_key = new_key.replace(".", "_")
+
+             else:
+                 # For any other keys, add lora_unet_ prefix and replace dots
+                 new_key = "lora_unet_" + new_key.replace(".", "_")
+
+         # Convert lora_A/lora_B to lora_down/lora_up
+         new_key = new_key.replace("_lora_A_weight", ".lora_down.weight")
+         new_key = new_key.replace("_lora_B_weight", ".lora_up.weight")
+
+         converted_dict[new_key] = value
+
+         if key != new_key:
+             logger.debug(f"Converted: {key} → {new_key}")
+
+     return converted_dict
+
+
+ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | None = None) -> dict[str, torch.Tensor]:
+     """
+     Convert LoRA weights to Diffusers format, which will later be converted to Nunchaku format.
+
+     Parameters
+     ----------
+     input_lora : str or dict[str, torch.Tensor]
+         Path to a safetensors file or a LoRA weight dictionary.
+     output_path : str, optional
+         If given, save the converted weights to this path.
+
+     Returns
+     -------
+     dict[str, torch.Tensor]
+         LoRA weights in Diffusers format.
+     """
+     if isinstance(input_lora, str):
+         tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
+     else:
+         tensors = {k: v for k, v in input_lora.items()}
+
+     tensors = handle_kohya_lora(tensors)
+
+     # Convert FP8 tensors to BF16
+     for k, v in tensors.items():
+         if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
+             tensors[k] = v.to(torch.bfloat16)
+
+     # Apply Kontext-specific key conversion for both PEFT format and ComfyUI format
+     # This handles LoRAs with base_model.model.* prefix or lora_unet_* prefix (including final_layer_linear)
+     if any(k.startswith("base_model.model.") for k in tensors.keys()):
+         logger.info("Converting PEFT format to ComfyUI format")
+         return convert_peft_to_comfyui(tensors)
+
+     # Handle LoRAs that only have final_layer_linear without adaLN_modulation
+     # This is a workaround for incomplete final layer LoRAs
+     final_keys = [k for k in tensors.keys() if "final_layer" in k]
+     has_linear = any("final_layer_linear" in k for k in final_keys)
+     has_adaln = any("final_layer_adaLN_modulation" in k for k in final_keys)
+
+     if has_linear and not has_adaln:
+         for key in list(tensors.keys()):
+             if "final_layer_linear" in key:
+                 adaln_key = key.replace("final_layer_linear", "final_layer_adaLN_modulation_1")
+                 if adaln_key not in tensors:
+                     tensors[adaln_key] = torch.zeros_like(tensors[key])
+
+     new_tensors, alphas = FluxLoraLoaderMixin.lora_state_dict(tensors, return_alphas=True)
+     new_tensors = convert_unet_state_dict_to_peft(new_tensors)
+
+     if alphas is not None and len(alphas) > 0:
+         for k, v in alphas.items():
+             key_A = k.replace(".alpha", ".lora_A.weight")
+             key_B = k.replace(".alpha", ".lora_B.weight")
+             assert key_A in new_tensors, f"Key {key_A} not found in new tensors."
+             assert key_B in new_tensors, f"Key {key_B} not found in new tensors."
+             rank = new_tensors[key_A].shape[0]
+             assert new_tensors[key_B].shape[1] == rank, f"Rank mismatch for {key_B}."
+             new_tensors[key_A] = new_tensors[key_A] * v / rank
+
+     if output_path is not None:
+         output_dir = os.path.dirname(os.path.abspath(output_path))
+         os.makedirs(output_dir, exist_ok=True)
+         save_file(new_tensors, output_path)
+
+     return new_tensors
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensors file")
+     parser.add_argument(
+         "-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensors file"
+     )
+     args = parser.parse_args()
+     to_diffusers(args.input_path, args.output_path)
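
The alpha handling at the end of `to_diffusers` folds the conventional LoRA scale into the A factor: a LoRA layer contributes `delta_W = (alpha / rank) * B @ A`, so scaling `lora_A` once by `alpha / rank` lets downstream code ignore alpha entirely. A quick numeric check of that equivalence:

```python
import torch

rank, alpha, d_in, d_out = 8, 16.0, 32, 64
A = torch.randn(rank, d_in)
B = torch.randn(d_out, rank)

delta_w = (alpha / rank) * (B @ A)  # textbook LoRA update
A_folded = A * (alpha / rank)       # what to_diffusers stores in lora_A
assert torch.allclose(delta_w, B @ A_folded, atol=1e-5)
```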
nunchaku/lora/flux/nunchaku_converter.py ADDED
@@ -0,0 +1,949 @@
+ """
+ Nunchaku LoRA format converter for Flux models.
+
+ This module provides utilities to convert LoRA weights from Diffusers format
+ to Nunchaku format for efficient quantized inference in Flux models.
+
+ Key functions
+ -------------
+ - :func:`to_nunchaku` : Main conversion entry point
+ - :func:`fuse_vectors` : Vector fusion for bias terms
+ """
+
+ import logging
+ import os
+
+ import torch
+ from safetensors.torch import save_file
+ from tqdm import tqdm
+
+ from ...utils import filter_state_dict, load_state_dict_in_safetensors
+ from .diffusers_converter import to_diffusers
+ from .packer import NunchakuWeightPacker
+ from .utils import is_nunchaku_format, pad
+
+ # Get log level from environment variable (default to INFO)
+ log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+ # Configure logging
+ logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+
+ # region utilities
+
+
+ def update_state_dict(
+     lhs: dict[str, torch.Tensor], rhs: dict[str, torch.Tensor], prefix: str = ""
+ ) -> dict[str, torch.Tensor]:
+     """
+     Update a state dictionary with values from another, optionally adding a prefix to keys.
+
+     Parameters
+     ----------
+     lhs : dict[str, torch.Tensor]
+         Target state dictionary.
+     rhs : dict[str, torch.Tensor]
+         Source state dictionary.
+     prefix : str, optional
+         Prefix to add to keys from rhs.
+
+     Returns
+     -------
+     dict[str, torch.Tensor]
+         Updated state dictionary.
+
+     Raises
+     ------
+     AssertionError
+         If any key already exists in the target dictionary.
+     """
+     for rkey, value in rhs.items():
+         lkey = f"{prefix}.{rkey}" if prefix else rkey
+         assert lkey not in lhs, f"Key {lkey} already exists in the state dict."
+         lhs[lkey] = value
+     return lhs
+
+
+ # endregion
+
+
+ def pack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
+     """
+     Pack the low-rank weight tensor for W4A4 linear layers.
+
+     Parameters
+     ----------
+     weight : torch.Tensor
+         Low-rank weight tensor.
+     down : bool
+         If True, pack as down-projection; else as up-projection.
+
+     Returns
+     -------
+     torch.Tensor
+         Packed weight tensor.
+     """
+     assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
+     lane_n, lane_k = 1, 2  # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
+     n_pack_size, k_pack_size = 2, 2
+     num_n_lanes, num_k_lanes = 8, 4
+     frag_n = n_pack_size * num_n_lanes * lane_n
+     frag_k = k_pack_size * num_k_lanes * lane_k
+     weight = pad(weight, divisor=(frag_n, frag_k), dim=(0, 1))
+     if down:
+         r, c = weight.shape
+         r_frags, c_frags = r // frag_n, c // frag_k
+         weight = weight.view(r_frags, frag_n, c_frags, frag_k).permute(2, 0, 1, 3)
+     else:
+         c, r = weight.shape
+         c_frags, r_frags = c // frag_n, r // frag_k
+         weight = weight.view(c_frags, frag_n, r_frags, frag_k).permute(0, 2, 1, 3)
+     weight = weight.reshape(c_frags, r_frags, n_pack_size, num_n_lanes, k_pack_size, num_k_lanes, lane_k)
+     weight = weight.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
+     return weight.view(c, r)
+
+
+ def unpack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
+     """
+     Unpack the low-rank weight tensor from W4A4 linear layers.
+
+     Parameters
+     ----------
+     weight : torch.Tensor
+         Packed low-rank weight tensor.
+     down : bool
+         If True, unpack as down-projection; else as up-projection.
+
+     Returns
+     -------
+     torch.Tensor
+         Unpacked weight tensor.
+     """
+     c, r = weight.shape
+     assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
+     lane_n, lane_k = 1, 2  # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
+     n_pack_size, k_pack_size = 2, 2
+     num_n_lanes, num_k_lanes = 8, 4
+     frag_n = n_pack_size * num_n_lanes * lane_n
+     frag_k = k_pack_size * num_k_lanes * lane_k
+     if down:
+         r_frags, c_frags = r // frag_n, c // frag_k
+     else:
+         c_frags, r_frags = c // frag_n, r // frag_k
+     weight = weight.view(c_frags, r_frags, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, lane_k)
+     weight = weight.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
+     weight = weight.view(c_frags, r_frags, frag_n, frag_k)
+     if down:
+         weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
+     else:
+         weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
+     return weight
+
+
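`pack_lowrank_weight` and `unpack_lowrank_weight` are intended as exact inverses built from pure reshapes and permutes; that is what lets the converter unpack the base model's SVDQuant low-rank branch, concatenate a user LoRA onto it, and repack. A quick round-trip check (the 64x32 shape is an arbitrary multiple of the 16x16 fragment size, so padding is a no-op):

```python
import torch

from nunchaku.lora.flux.nunchaku_converter import pack_lowrank_weight, unpack_lowrank_weight

w = torch.randn(64, 32, dtype=torch.bfloat16)  # [rank, channels]
packed = pack_lowrank_weight(w, down=True)
assert unpack_lowrank_weight(packed, down=True).equal(w)  # lossless round-trip
```
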
+ def reorder_adanorm_lora_up(lora_up: torch.Tensor, splits: int) -> torch.Tensor:
+     """
+     Reorder AdaNorm LoRA up-projection tensor for correct shape.
+
+     Parameters
+     ----------
+     lora_up : torch.Tensor
+         LoRA up-projection tensor.
+     splits : int
+         Number of splits for AdaNorm.
+
+     Returns
+     -------
+     torch.Tensor
+         Reordered tensor.
+     """
+     c, r = lora_up.shape
+     assert c % splits == 0
+     return lora_up.view(splits, c // splits, r).transpose(0, 1).reshape(c, r).contiguous()
+
+
+ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
+     orig_state_dict: dict[str, torch.Tensor],
+     extra_lora_dict: dict[str, torch.Tensor],
+     converted_block_name: str,
+     candidate_block_name: str,
+     local_name_map: dict[str, str | list[str]],
+     convert_map: dict[str, str],
+     default_dtype: torch.dtype = torch.bfloat16,
+ ) -> dict[str, torch.Tensor]:
+     """
+     Convert LoRA weights for a transformer block from Diffusers to Nunchaku format.
+
+     Merges and converts LoRA weights from the original SVDQuant low-rank branch and an extra LoRA dict
+     for a given transformer block, producing a Nunchaku-compatible dictionary. Handles both fused and
+     unfused LoRA branches (e.g., qkv), and merges multiple LoRA branches as needed.
+
+     Parameters
+     ----------
+     orig_state_dict : dict[str, torch.Tensor]
+         Original state dict with LoRA weights, keys like ``"{block}.{local}.lora_down"`` and ``"{block}.{local}.lora_up"``.
+     extra_lora_dict : dict[str, torch.Tensor]
+         Extra LoRA weights, keys like ``"{block}.{local}.lora_A.weight"`` and ``"{block}.{local}.lora_B.weight"``.
+     converted_block_name : str
+         Block name for output (e.g., ``"transformer_blocks.0"``).
+     candidate_block_name : str
+         Block name for input lookup (e.g., ``"blocks.0"``).
+     local_name_map : dict[str, str | list[str]]
+         Maps output local names (e.g., ``"attn.qkv"``) to one or more input local names.
+     convert_map : dict[str, str]
+         Maps output local names to conversion types: ``"adanorm_single"``, ``"adanorm_zero"``, or ``"linear"``.
+     default_dtype : torch.dtype, optional
+         Output tensor dtype (default: ``torch.bfloat16``).
+
+     Returns
+     -------
+     dict[str, torch.Tensor]
+         A dictionary containing the converted LoRA weights in Nunchaku format.
+
+     Notes
+     -----
+     - If both original and extra LoRA weights are present, they are merged by concatenation.
+     - Handles both fused and unfused attention projections (e.g., qkv).
+     - Applies special packing for W4A16 linear layers (e.g., ``"adanorm_single"`` and ``"adanorm_zero"``).
+     """
+     logger.debug(f"Converting LoRA branch for block {candidate_block_name}...")
+     converted: dict[str, torch.Tensor] = {}
+     for converted_local_name, candidate_local_names in local_name_map.items():
+         if isinstance(candidate_local_names, str):
+             candidate_local_names = [candidate_local_names]
+         # region original LoRA
+         orig_lora = (
+             orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_down", None),
+             orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_up", None),
+         )
+         if orig_lora[0] is None or orig_lora[1] is None:
+             assert orig_lora[0] is None and orig_lora[1] is None
+             orig_lora = None
+         elif orig_lora[0].numel() == 0 or orig_lora[1].numel() == 0:
+             assert orig_lora[0].numel() == 0 and orig_lora[1].numel() == 0
+             orig_lora = None
+         else:
+             assert orig_lora[0] is not None and orig_lora[1] is not None
+             orig_lora = (
+                 unpack_lowrank_weight(orig_lora[0], down=True),
+                 unpack_lowrank_weight(orig_lora[1], down=False),
+             )
+             logger.debug(
+                 f" - Found {converted_block_name} LoRA of {converted_local_name} (rank: {orig_lora[0].shape[0]})"
+             )
+         # endregion
+         # region extra LoRA
+         extra_lora_list = None
+
+         # if the qkv are already fused
+         if "qkv" in converted_local_name:
+             candidate_local_name = candidate_local_names[0]
+             assert "_q" in candidate_local_name
+             candidate_local_name = candidate_local_name.replace("_q", "_qkv")
+             lora_A = extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None)
+             lora_B = extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None)
+             if lora_A is None and lora_B is None:
+                 extra_lora_list = None
+             else:
+                 assert lora_A is not None and lora_B is not None
+                 extra_lora_list = [(lora_A, lora_B)]
+
+         # not fused, fuse them manually
+         if extra_lora_list is None:
+             extra_lora_list = [
+                 (
+                     extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None),
+                     extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None),
+                 )
+                 for candidate_local_name in candidate_local_names
+             ]
+         if any(lora[0] is not None or lora[1] is not None for lora in extra_lora_list):
+             # merge extra LoRAs into one LoRA
+             if len(extra_lora_list) > 1:
+                 first_lora = None
+                 for lora in extra_lora_list:
+                     if lora[0] is not None:
+                         assert lora[1] is not None
+                         first_lora = lora
+                         break
+                 assert first_lora is not None
+                 for lora_index in range(len(extra_lora_list)):
+                     if extra_lora_list[lora_index][0] is None:
+                         assert extra_lora_list[lora_index][1] is None
+                         extra_lora_list[lora_index] = (first_lora[0].clone(), torch.zeros_like(first_lora[1]))
+                 if all(lora[0].equal(extra_lora_list[0][0]) for lora in extra_lora_list):
+                     # if all extra LoRAs have the same lora_down, use it
+                     extra_lora_down = extra_lora_list[0][0]
+                     extra_lora_up = torch.cat([lora[1] for lora in extra_lora_list], dim=0)
+                 else:
+                     extra_lora_down = torch.cat([lora[0] for lora in extra_lora_list], dim=0)
+                     extra_lora_up_c = sum(lora[1].shape[0] for lora in extra_lora_list)
+                     extra_lora_up_r = sum(lora[1].shape[1] for lora in extra_lora_list)
+                     assert extra_lora_up_r == extra_lora_down.shape[0]
+                     extra_lora_up = torch.zeros((extra_lora_up_c, extra_lora_up_r), dtype=extra_lora_down.dtype)
+                     c, r = 0, 0
+                     for lora in extra_lora_list:
+                         c_next, r_next = c + lora[1].shape[0], r + lora[1].shape[1]
+                         extra_lora_up[c:c_next, r:r_next] = lora[1]
+                         c, r = c_next, r_next
+             else:
+                 extra_lora_down, extra_lora_up = extra_lora_list[0]
+             extra_lora: tuple[torch.Tensor, torch.Tensor] = (extra_lora_down, extra_lora_up)
+             logger.debug(
+                 f" - Found {candidate_block_name} LoRA of {candidate_local_names} (rank: {extra_lora[0].shape[0]})"
+             )
+         else:
+             extra_lora = None
+         # endregion
+         # region merge LoRA
+         if orig_lora is None:
+             if extra_lora is None:
+                 lora = None
+             else:
+                 logger.debug(" - Using extra LoRA")
+                 lora = (extra_lora[0].to(default_dtype), extra_lora[1].to(default_dtype))
+         elif extra_lora is None:
+             logger.debug(" - Using original LoRA")
+             lora = orig_lora
+         else:
+             try:
+                 lora = (
+                     torch.cat([orig_lora[0], extra_lora[0].to(orig_lora[0].dtype)], dim=0),  # [r, c]
+                     torch.cat([orig_lora[1], extra_lora[1].to(orig_lora[1].dtype)], dim=1),  # [c, r]
+                 )
+                 logger.debug(f" - Merging original and extra LoRA (rank: {lora[0].shape[0]})")
+             except RuntimeError as e:
+                 if "Sizes of tensors must match" in str(e):
+                     # Handle various dimension mismatch cases for LoRA
+                     logger.debug(
+                         f" - Dimension mismatch detected: orig_lora[1]={orig_lora[1].shape}, extra_lora[1]={extra_lora[1].shape}"
+                     )
+
+                     # Handle dimension mismatch by using only the properly sized portion of extra_lora
+                     # instead of trying to concatenate mismatched dimensions
+
+                     # Case 1: single_blocks linear1 [21504] -> mlp_fc1 [12288]
+                     if extra_lora[1].shape[1] == 21504 and orig_lora[1].shape[1] == 12288:
+                         # Use only the first 12288 dimensions from the 21504 extra LoRA
+                         extra_lora_up_split = extra_lora[1][:, :12288].clone()
+                         extra_lora_down = extra_lora[0].clone()
+                         # logger.debug(f" - Dimension fix 21504->12288: using split extra LoRA instead of merge")
+
+                         # Use the split extra LoRA instead of concatenating
+                         lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype))
+
+                     # Case 2: transformer_blocks with different MLP dimensions (27648 -> 9216)
+                     elif extra_lora[1].shape[1] == 27648 and orig_lora[1].shape[1] == 9216:
+                         # Use only the first 9216 dimensions from the 27648 extra LoRA
+                         extra_lora_up_split = extra_lora[1][:, :9216].clone()
+                         extra_lora_down = extra_lora[0].clone()
+                         # logger.debug(f" - Dimension fix 27648->9216: using split extra LoRA instead of merge")
+
+                         # Use the split extra LoRA instead of concatenating
+                         lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype))
+
+                     # Case 3: Other dimension ratios - try to find a reasonable split
+                     elif extra_lora[1].shape[1] > orig_lora[1].shape[1]:
+                         # Use only what we need from extra LoRA
+                         target_dim = orig_lora[1].shape[1]
+                         extra_lora_up_split = extra_lora[1][:, :target_dim].clone()
+                         extra_lora_down = extra_lora[0].clone()
+                         # logger.debug(
+                         #     f" - Dimension fix {extra_lora[1].shape[1]}->{target_dim}: using truncated extra LoRA"
+                         # )
+
+                         # Use the truncated extra LoRA instead of concatenating
+                         lora = (extra_lora_down.to(orig_lora[0].dtype), extra_lora_up_split.to(orig_lora[1].dtype))
+
+                     else:
+                         # For cases where extra LoRA has fewer dimensions, use original LoRA only
+                         # logger.warning(
+                         #     f" - Cannot split extra LoRA {extra_lora[1].shape[1]}->{orig_lora[1].shape[1]}, using original only"
+                         # )
+                         lora = orig_lora
+                 else:
+                     raise e
+         # endregion
+         if lora is not None:
+             if convert_map[converted_local_name] == "adanorm_single":
+                 update_state_dict(
+                     converted,
+                     {
+                         "lora_down": pad(lora[0], divisor=16, dim=0),
+                         "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
+                     },
+                     prefix=converted_local_name,
+                 )
+             elif convert_map[converted_local_name] == "adanorm_zero":
+                 update_state_dict(
+                     converted,
+                     {
+                         "lora_down": pad(lora[0], divisor=16, dim=0),
+                         "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
+                     },
+                     prefix=converted_local_name,
+                 )
+             elif convert_map[converted_local_name] == "linear":
+                 update_state_dict(
+                     converted,
+                     {
+                         "lora_down": pack_lowrank_weight(lora[0], down=True),
+                         "lora_up": pack_lowrank_weight(lora[1], down=False),
+                     },
+                     prefix=converted_local_name,
+                 )
+     return converted
+
+
+ def preprocess_single_blocks_lora(
+     extra_lora_dict: dict[str, torch.Tensor], candidate_block_name: str
+ ) -> dict[str, torch.Tensor]:
+     """
+     Preprocess LoRA weights from single_blocks format to match single_transformer_blocks structure.
+
+     This function handles the architectural mismatch between old and new models:
+     - Old single_blocks: linear1 (fused 21504-dim layer) and linear2
+     - New single_transformer_blocks: mlp_fc1 (12288-dim), qkv_proj (9216-dim), and mlp_fc2
+
+     The linear1 layer in the old architecture combines two functions:
+     1. MLP projection (first 12288 dimensions)
+     2. QKV projection for attention (last 9216 dimensions)
+
+     These are split into separate layers in the new architecture.
+     """
+     processed_dict = extra_lora_dict.copy()
+
+     # Find all single_transformer_blocks keys that need preprocessing
+     single_blocks_keys = [k for k in extra_lora_dict.keys() if "single_transformer_blocks" in k and "linear" in k]
+
+     logger.debug(f"Preprocessing LoRA for candidate: {candidate_block_name}")
+     logger.debug(f"All keys in extra_lora_dict: {list(extra_lora_dict.keys())[:10]}...")  # Show first 10 keys
+     logger.debug(f"Found single_transformer_blocks keys: {single_blocks_keys[:5]}...")  # Show first 5 keys
+
+     if single_blocks_keys:
+         logger.debug(f"Found single_transformer_blocks LoRA keys, preprocessing for candidate: {candidate_block_name}")
+
+         # The candidate_block_name is already "single_transformer_blocks.0"
+         # Look for linear1 and linear2 keys with this exact name
+         linear1_lora_A_key = f"{candidate_block_name}.linear1.lora_A.weight"
+         linear1_lora_B_key = f"{candidate_block_name}.linear1.lora_B.weight"
+         linear2_lora_A_key = f"{candidate_block_name}.linear2.lora_A.weight"
+         linear2_lora_B_key = f"{candidate_block_name}.linear2.lora_B.weight"
+
+         logger.debug(f"Looking for keys: {linear1_lora_B_key}")
+         logger.debug(
+             f"Available keys matching pattern: {[k for k in extra_lora_dict.keys() if candidate_block_name in k][:5]}..."
+         )
+
+         if linear1_lora_B_key in extra_lora_dict:
+             linear1_lora_A = extra_lora_dict[linear1_lora_A_key]
+             linear1_lora_B = extra_lora_dict[linear1_lora_B_key]
+
+             # Check if this is the problematic 21504 dimension case
+             if linear1_lora_B.shape[0] == 21504:
+                 logger.debug(
+                     f"Splitting linear1 LoRA weights: [21504, {linear1_lora_B.shape[1]}] -> "
+                     f"mlp_fc1 [12288, {linear1_lora_B.shape[1]}] + qkv_proj [9216, {linear1_lora_B.shape[1]}]"
+                 )
+
+                 # Split linear1.lora_B [21504, rank] into two parts:
+                 # 1. First 12288 dimensions -> mlp_fc1
+                 # 2. Last 9216 dimensions (12288:21504) -> qkv_proj
+                 mlp_fc1_lora_B = linear1_lora_B[:12288, :].clone()
+                 qkv_proj_lora_B = linear1_lora_B[12288:21504, :].clone()
+
+                 # The lora_A weight is reused for both new layers
+                 # since it represents the down-projection from the input
+                 mlp_fc1_lora_A = linear1_lora_A.clone()
+                 qkv_proj_lora_A = linear1_lora_A.clone()
+
+                 # Map to new architecture:
+                 # 1. proj_mlp corresponds to mlp_fc1
+                 processed_dict[f"{candidate_block_name}.proj_mlp.lora_A.weight"] = mlp_fc1_lora_A
+                 processed_dict[f"{candidate_block_name}.proj_mlp.lora_B.weight"] = mlp_fc1_lora_B
+
+                 # 2. Map the QKV part to the attention layers
+                 # Note: In the new architecture, this maps to attn.to_q, attn.to_k, attn.to_v
+                 # which get fused into qkv_proj during the conversion
+                 processed_dict[f"{candidate_block_name}.attn.to_q.lora_A.weight"] = qkv_proj_lora_A
+                 processed_dict[f"{candidate_block_name}.attn.to_q.lora_B.weight"] = qkv_proj_lora_B[
+                     :3072, :
+                 ]  # Q projection
+                 processed_dict[f"{candidate_block_name}.attn.to_k.lora_A.weight"] = qkv_proj_lora_A
+                 processed_dict[f"{candidate_block_name}.attn.to_k.lora_B.weight"] = qkv_proj_lora_B[
+                     3072:6144, :
+                 ]  # K projection
+                 processed_dict[f"{candidate_block_name}.attn.to_v.lora_A.weight"] = qkv_proj_lora_A
+                 processed_dict[f"{candidate_block_name}.attn.to_v.lora_B.weight"] = qkv_proj_lora_B[
+                     6144:9216, :
+                 ]  # V projection
+
+                 # Handle linear2 -> mlp_fc2 mapping
+                 if linear2_lora_B_key in extra_lora_dict:
+                     linear2_lora_A = extra_lora_dict[linear2_lora_A_key]
+                     linear2_lora_B = extra_lora_dict[linear2_lora_B_key]
+
+                     # Map linear2 to proj_out.linears.1 (mlp_fc2)
+                     processed_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = linear2_lora_A
+                     processed_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = linear2_lora_B
+
+                     # Remove original keys
+                     processed_dict.pop(linear2_lora_A_key, None)
+                     processed_dict.pop(linear2_lora_B_key, None)
+
+                 # Remove original linear1 keys
+                 processed_dict.pop(linear1_lora_A_key, None)
+                 processed_dict.pop(linear1_lora_B_key, None)
+
+     return processed_dict
+
+
+ def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
+     orig_state_dict: dict[str, torch.Tensor],
+     extra_lora_dict: dict[str, torch.Tensor],
+     converted_block_name: str,
+     candidate_block_name: str,
+     default_dtype: torch.dtype = torch.bfloat16,
+ ) -> dict[str, torch.Tensor]:
+     """
+     Convert LoRA weights for a single FLUX transformer block from Diffusers to Nunchaku format.
+
+     This function merges and converts LoRA weights from the original SVDQuant low-rank branch and an
+     extra LoRA dictionary for a given transformer block, producing a Nunchaku-compatible dictionary.
+     It handles both fused and unfused LoRA branches (e.g., qkv), and merges multiple LoRA branches as needed.
+
+     Parameters
+     ----------
+     orig_state_dict : dict[str, torch.Tensor]
+         Original state dict with LoRA weights, keys like ``"{block}.{local}.lora_down"`` and ``"{block}.{local}.lora_up"``.
+     extra_lora_dict : dict[str, torch.Tensor]
+         Extra LoRA weights, keys like ``"{block}.{local}.lora_A.weight"`` and ``"{block}.{local}.lora_B.weight"``.
+     converted_block_name : str
+         Block name for output (e.g., ``"transformer_blocks.0"``).
+     candidate_block_name : str
+         Block name for input lookup (e.g., ``"blocks.0"``).
+     default_dtype : torch.dtype, optional
+         Output tensor dtype (default: ``torch.bfloat16``).
+
+     Returns
+     -------
+     dict[str, torch.Tensor]
+         A dictionary containing the converted LoRA weights in Nunchaku format.
+
+     Notes
+     -----
+     - If both original and extra LoRA weights are present, they are merged by concatenation.
+     - Handles both fused and unfused attention projections (e.g., qkv).
+     - Applies special packing for W4A16 linear layers (e.g., ``"adanorm_single"`` and ``"adanorm_zero"``).
+     """
+
+     # Preprocess single_blocks LoRA structure if needed
+     # extra_lora_dict = preprocess_single_blocks_lora(extra_lora_dict, candidate_block_name)
+
+     if f"{candidate_block_name}.proj_out.lora_A.weight" in extra_lora_dict:
+         assert f"{converted_block_name}.out_proj.qweight" in orig_state_dict
+         assert f"{converted_block_name}.mlp_fc2.qweight" in orig_state_dict
+         n1 = orig_state_dict[f"{converted_block_name}.out_proj.qweight"].shape[1] * 2
+         n2 = orig_state_dict[f"{converted_block_name}.mlp_fc2.qweight"].shape[1] * 2
+         lora_down = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_A.weight"]
+         lora_up = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_B.weight"]
+         assert lora_down.shape[1] == n1 + n2
+         extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_A.weight"] = lora_down[:, :n1].clone()
+         extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_B.weight"] = lora_up.clone()
+         extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = lora_down[:, n1:].clone()
+         extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = lora_up.clone()
+         extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
+         extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")
+
+         for component in ["lora_A", "lora_B"]:
+             fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight"
+             fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight"
+             fc1_v = extra_lora_dict[fc1_k]
+             fc2_v = extra_lora_dict[fc2_k]
+             dim = 0 if "lora_A" in fc1_k else 1
+
+             fc1_rank = fc1_v.shape[dim]
+             fc2_rank = fc2_v.shape[dim]
+             if fc1_rank != fc2_rank:
+                 rank = max(fc1_rank, fc2_rank)
+                 if fc1_rank < rank:
+                     extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+                 if fc2_rank < rank:
+                     extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
+
+     return convert_to_nunchaku_transformer_block_lowrank_dict(
+         orig_state_dict=orig_state_dict,
+         extra_lora_dict=extra_lora_dict,
+         converted_block_name=converted_block_name,
+         candidate_block_name=candidate_block_name,
+         local_name_map={
+             "norm.linear": "norm.linear",
+             "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
+             "norm_q": "attn.norm_q",
+             "norm_k": "attn.norm_k",
+             "out_proj": "proj_out.linears.0",
+             "mlp_fc1": "proj_mlp",
+             "mlp_fc2": "proj_out.linears.1",
+         },
+         convert_map={
+             "norm.linear": "adanorm_single",
+             "qkv_proj": "linear",
+             "out_proj": "linear",
+             "mlp_fc1": "linear",
+             "mlp_fc2": "linear",
+         },
+         default_dtype=default_dtype,
+     )
+
+
+ def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
+     orig_state_dict: dict[str, torch.Tensor],
+     extra_lora_dict: dict[str, torch.Tensor],
+     converted_block_name: str,
+     candidate_block_name: str,
+     default_dtype: torch.dtype = torch.bfloat16,
+ ) -> dict[str, torch.Tensor]:
+     """
+     Convert LoRA weights for a single transformer block from Diffusers to Nunchaku format.
+
+     Parameters
+     ----------
+     orig_state_dict : dict[str, torch.Tensor]
+         Original model state dict.
+     extra_lora_dict : dict[str, torch.Tensor]
+         LoRA weights state dict.
+     converted_block_name : str
+         Output block name for the converted weights.
+     candidate_block_name : str
+         Input block name for lookup.
+     default_dtype : torch.dtype, optional
+         Output tensor dtype (default: torch.bfloat16).
+
+     Returns
+     -------
+     dict[str, torch.Tensor]
+         Converted LoRA weights in Nunchaku format.
+     """
+     return convert_to_nunchaku_transformer_block_lowrank_dict(
+         orig_state_dict=orig_state_dict,
+         extra_lora_dict=extra_lora_dict,
+         converted_block_name=converted_block_name,
+         candidate_block_name=candidate_block_name,
+         local_name_map={
+             "norm1.linear": "norm1.linear",
+             "norm1_context.linear": "norm1_context.linear",
+             "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
+             "qkv_proj_context": ["attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj"],
+             "norm_q": "attn.norm_q",
+             "norm_k": "attn.norm_k",
+             "norm_added_q": "attn.norm_added_q",
+             "norm_added_k": "attn.norm_added_k",
+             "out_proj": "attn.to_out.0",
+             "out_proj_context": "attn.to_add_out",
+             "mlp_fc1": "ff.net.0.proj",
+             "mlp_fc2": "ff.net.2",
+             "mlp_context_fc1": "ff_context.net.0.proj",
+             "mlp_context_fc2": "ff_context.net.2",
+         },
+         convert_map={
+             "norm1.linear": "adanorm_zero",
+             "norm1_context.linear": "adanorm_zero",
+             "qkv_proj": "linear",
+             "qkv_proj_context": "linear",
+             "out_proj": "linear",
+             "out_proj_context": "linear",
+             "mlp_fc1": "linear",
+             "mlp_fc2": "linear",
+             "mlp_context_fc1": "linear",
+             "mlp_context_fc2": "linear",
+         },
+         default_dtype=default_dtype,
+     )
+
+
+ def convert_to_nunchaku_flux_lowrank_dict(
+     base_model: dict[str, torch.Tensor] | str,
+     lora: dict[str, torch.Tensor] | str,
+     default_dtype: torch.dtype = torch.bfloat16,
+ ) -> dict[str, torch.Tensor]:
+     """
+     Convert a base model and LoRA weights from Diffusers format to Nunchaku format.
+
+     Parameters
+     ----------
+     base_model : dict[str, torch.Tensor] or str
+         Base model weights or path to safetensors file.
+     lora : dict[str, torch.Tensor] or str
+         LoRA weights or path to safetensors file.
+     default_dtype : torch.dtype, optional
+         Output tensor dtype (default: torch.bfloat16).
+
+     Returns
+     -------
+     dict[str, torch.Tensor]
+         LoRA weights in Nunchaku format.
+     """
+     if isinstance(base_model, str):
+         orig_state_dict = load_state_dict_in_safetensors(base_model)
+     else:
+         orig_state_dict = base_model
+
+     if isinstance(lora, str):
+         # Load the LoRA - check if it has transformer prefix
+         temp_dict = load_state_dict_in_safetensors(lora)
+         if any(k.startswith("transformer.") for k in temp_dict.keys()):
+             # Standard FLUX LoRA with transformer prefix
+             extra_lora_dict = filter_state_dict(temp_dict, filter_prefix="transformer.")
+             # Remove the transformer. prefix after filtering
+             renamed_dict = {}
+             for k, v in extra_lora_dict.items():
+                 new_k = k.replace("transformer.", "") if k.startswith("transformer.") else k
+                 renamed_dict[new_k] = v
+             extra_lora_dict = renamed_dict
+         else:
+             # Kontext LoRA without transformer prefix - use as is
+             extra_lora_dict = temp_dict
+     else:
+         # When called from to_nunchaku, lora is already processed by to_diffusers
+         # Keys should be in format: single_blocks.0.linear1.lora_A.weight
+         extra_lora_dict = lora
+
+     # Add transformer. prefix and rename blocks to match expectations
+     renamed_dict = {}
+     for k, v in extra_lora_dict.items():
+         new_k = k
+         # Add transformer. prefix and rename blocks
+         if k.startswith("single_blocks."):
+             new_k = "transformer.single_transformer_blocks." + k[14:]
+         elif k.startswith("double_blocks."):
+             new_k = "transformer.transformer_blocks." + k[14:]
+         elif k.startswith("proj_out."):
+             new_k = "transformer." + k
+         elif not k.startswith("transformer."):
+             new_k = "transformer." + k
+         renamed_dict[new_k] = v
+     extra_lora_dict = renamed_dict
+
+     # Now filter for transformer prefix and remove it for processing
+     extra_lora_dict = filter_state_dict(extra_lora_dict, filter_prefix="transformer.")
+
+     # Remove the transformer. prefix for internal processing
+     renamed_dict = {}
+     for k, v in extra_lora_dict.items():
+         new_k = k.replace("transformer.", "") if k.startswith("transformer.") else k
+         renamed_dict[new_k] = v
+     extra_lora_dict = renamed_dict
+
+     vector_dict, unquantized_lora_dict = {}, {}
+     for k in list(extra_lora_dict.keys()):
+         v = extra_lora_dict[k]
+         if v.ndim == 1:
+             vector_dict[k.replace(".lora_B.bias", ".bias")] = extra_lora_dict.pop(k)
+         elif "transformer_blocks" not in k and "single_transformer_blocks" not in k:
+             # Only unquantized parts (like final_layer) go here
+             unquantized_lora_dict[k] = extra_lora_dict.pop(k)
+
+     # Concatenate qkv_proj biases if present
+     for k in list(vector_dict.keys()):
+         if ".to_q." in k or ".add_q_proj." in k:
+             k_q = k
+             k_k = k.replace(".to_q.", ".to_k.").replace(".add_q_proj.", ".add_k_proj.")
+             k_v = k.replace(".to_q.", ".to_v.").replace(".add_q_proj.", ".add_v_proj.")
+             keys = [k_q, k_k, k_v]
+             values = [vector_dict.pop(key) for key in keys]
+             new_k = k_q.replace(".to_q.", ".to_qkv.").replace(".add_q_proj.", ".add_qkv_proj.")
+             vector_dict[new_k] = torch.cat(values, dim=0)
+
+     for k in extra_lora_dict.keys():
+         fc1_k = k
+         if "ff.net.0.proj" in k:
+             fc2_k = k.replace("ff.net.0.proj", "ff.net.2")
+         elif "ff_context.net.0.proj" in k:
+             fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2")
+         else:
+             continue
+         assert fc2_k in extra_lora_dict
+         fc1_v = extra_lora_dict[fc1_k]
+         fc2_v = extra_lora_dict[fc2_k]
+         dim = 0 if "lora_A" in fc1_k else 1
+
+         fc1_rank = fc1_v.shape[dim]
+         fc2_rank = fc2_v.shape[dim]
+         if fc1_rank != fc2_rank:
+             rank = max(fc1_rank, fc2_rank)
+             if fc1_rank < rank:
+                 extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+             if fc2_rank < rank:
+                 extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
+
+     block_names: set[str] = set()
+     for param_name in orig_state_dict.keys():
+         if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
+             block_names.add(".".join(param_name.split(".")[:2]))
+     block_names = sorted(block_names, key=lambda x: (x.split(".")[0], int(x.split(".")[-1])))
+     logger.debug(f"Converting {len(block_names)} transformer blocks...")
+     converted: dict[str, torch.Tensor] = {}
+     for block_name in tqdm(block_names, dynamic_ncols=True, desc="Converting LoRAs to nunchaku format"):
+         if block_name.startswith("transformer_blocks"):
+             convert_fn = convert_to_nunchaku_flux_transformer_block_lowrank_dict
+         else:
+             convert_fn = convert_to_nunchaku_flux_single_transformer_block_lowrank_dict
+         update_state_dict(
+             converted,
+             convert_fn(
+                 orig_state_dict=orig_state_dict,
+                 extra_lora_dict=extra_lora_dict,
+                 converted_block_name=block_name,
+                 candidate_block_name=block_name,
+                 default_dtype=default_dtype,
+             ),
+             prefix=block_name,
+         )
+
+     converted.update(unquantized_lora_dict)
+     converted.update(vector_dict)
+     return converted
+
+
+ def to_nunchaku(
+     input_lora: str | dict[str, torch.Tensor],
+     base_sd: str | dict[str, torch.Tensor],
+     dtype: str | torch.dtype = torch.bfloat16,
+     output_path: str | None = None,
+ ) -> dict[str, torch.Tensor]:
+     """
+     Convert LoRA weights to Nunchaku format.
+
+     Parameters
+     ----------
+     input_lora : str or dict[str, torch.Tensor]
+         Path or dictionary of LoRA weights in Diffusers format. Can be composed of multiple LoRA weights.
+     base_sd : str or dict[str, torch.Tensor]
+         Path or dictionary of base quantized model weights.
+     dtype : str or torch.dtype, optional
+         Output data type ("bfloat16", "float16", or torch dtype). Default is torch.bfloat16.
+     output_path : str, optional
+         If provided, saves the result to this path.
+
+     Returns
+     -------
+     dict[str, torch.Tensor]
+         LoRA weights in Nunchaku format.
+
+     Example
+     -------
+     .. code-block:: python
+
+         nunchaku_weights = to_nunchaku("lora.safetensors", "base_model.safetensors")
838
+ nunchaku_weights = to_nunchaku(lora_dict, base_dict)
839
+ """
840
+ if isinstance(input_lora, str):
841
+ tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
842
+ else:
843
+ tensors = input_lora
844
+ if is_nunchaku_format(tensors):
845
+ logger.debug("Already in nunchaku format, no conversion needed.")
846
+ converted = tensors
847
+ else:
848
+ extra_lora_dict = to_diffusers(tensors)
849
+
850
+ if isinstance(base_sd, str):
851
+ orig_state_dict = load_state_dict_in_safetensors(base_sd)
852
+ else:
853
+ orig_state_dict = base_sd
854
+
855
+ if isinstance(dtype, str):
856
+ if dtype == "bfloat16":
857
+ dtype = torch.bfloat16
858
+ elif dtype == "float16":
859
+ dtype = torch.float16
860
+ else:
861
+ raise ValueError(f"Unsupported dtype {dtype}.")
862
+ else:
863
+ assert isinstance(dtype, torch.dtype)
864
+
865
+ converted = convert_to_nunchaku_flux_lowrank_dict(
866
+ base_model=orig_state_dict, lora=extra_lora_dict, default_dtype=dtype
867
+ )
868
+ if output_path is not None:
869
+ output_dir = os.path.dirname(os.path.abspath(output_path))
870
+ os.makedirs(output_dir, exist_ok=True)
871
+ save_file(converted, output_path)
872
+ return converted
873
+
874
+
875
+ #### fuse vectors ####
876
+
877
+
878
+ def fuse_vectors(
879
+ vectors: dict[str, torch.Tensor], base_sd: dict[str, torch.Tensor], strength: float = 1
880
+ ) -> dict[str, torch.Tensor]:
881
+ """
882
+ Fuse vector (bias) terms from LoRA into the base model.
883
+
884
+ Parameters
885
+ ----------
886
+ vectors : dict[str, torch.Tensor]
887
+ LoRA vector terms.
888
+ base_sd : dict[str, torch.Tensor]
889
+ Base model state dict.
890
+ strength : float, optional
891
+ Scaling factor for LoRA vectors.
892
+
893
+ Returns
894
+ -------
895
+ dict[str, torch.Tensor]
896
+ State dict with fused vectors.
897
+ """
898
+ tensors: dict[str, torch.Tensor] = {}
899
+ packer = NunchakuWeightPacker(bits=4)
900
+ for k, v in base_sd.items():
901
+ if v.ndim != 1 or "smooth" in k or (k.startswith("single_transformer_blocks.") and ".mlp_fc2." in k):
902
+ continue
903
+ if "norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k:
904
+ new_k = k.replace(".norm_", ".attn.norm_")
905
+ new_v = vectors.get(new_k, None)
906
+ tensors[k] = v if new_v is None else new_v
907
+
908
+ elif "norm.linear" in k or "norm1.linear" in k or "norm1_context.linear" in k:
909
+ diff = vectors.get(k, None)
910
+
911
+ if diff is not None:
912
+ if k.startswith("single_transformer_blocks."):
913
+ adanorm_splits = 3
914
+ else:
915
+ assert k.startswith("transformer_blocks.")
916
+ adanorm_splits = 6
917
+ diff = diff.view(adanorm_splits, -1).transpose(0, 1).reshape(-1)
918
+ tensors[k] = v + diff * strength
919
+ else:
920
+ tensors[k] = v
921
+
922
+ else:
923
+ if k.startswith("single_transformer_blocks."):
924
+ name_map = {".qkv_proj.": ".attn.to_qkv.", ".out_proj.": ".proj_out.", ".mlp_fc1.": ".proj_mlp."}
925
+ else:
926
+ assert k.startswith("transformer_blocks.")
927
+ name_map = {
928
+ ".qkv_proj.": ".attn.to_qkv.",
929
+ ".qkv_proj_context.": ".attn.add_qkv_proj.",
930
+ ".out_proj.": ".attn.to_out.0.",
931
+ ".out_proj_context.": ".attn.to_add_out.",
932
+ ".mlp_fc1.": ".ff.net.0.proj.",
933
+ ".mlp_fc2.": ".ff.net.2.",
934
+ ".mlp_context_fc1.": ".ff_context.net.0.proj.",
935
+ ".mlp_context_fc2.": ".ff_context.net.2.",
936
+ }
937
+
938
+ for original_pattern, new_pattern in name_map.items():
939
+ if original_pattern in k:
940
+ new_k = k.replace(original_pattern, new_pattern)
941
+ diff = vectors.get(new_k, None)
942
+ if diff is not None:
943
+ diff = diff * strength
944
+ diff = packer.pad_scale(diff, group_size=-1)
945
+ diff = packer.pack_scale(diff, group_size=-1)
946
+ tensors[k] = v + diff
947
+ break
948
+
949
+ return tensors
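
Taken together, a typical conversion flow looks like the sketch below. This is illustrative only: the file names are hypothetical placeholders, and the module paths are assumed from this repository's layout (`nunchaku/lora/flux/nunchaku_converter.py`, `nunchaku/utils.py`).

```python
import torch

from nunchaku.lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
from nunchaku.utils import load_state_dict_in_safetensors

# Convert a Diffusers-format LoRA against a quantized base checkpoint and save it.
converted = to_nunchaku(
    input_lora="lora.safetensors",  # hypothetical path
    base_sd="svdq-int4-flux-base.safetensors",  # hypothetical path
    dtype="bfloat16",
    output_path="lora-nunchaku.safetensors",
)

# Optionally fuse the 1-D (bias/modulation) terms into the base state dict.
base_sd = load_state_dict_in_safetensors("svdq-int4-flux-base.safetensors")
vectors = {k: v for k, v in converted.items() if v.ndim == 1}
fused_sd = fuse_vectors(vectors, base_sd, strength=1.0)
```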
nunchaku/lora/flux/packer.py ADDED
@@ -0,0 +1,517 @@
"""
Weight packing utilities for Nunchaku quantization.

This module provides concise tools for packing and unpacking weight tensors,
optimized for efficient GPU computation using Matrix Multiply and Accumulate (MMA) operations.
"""

import torch

from ...utils import ceil_divide
from .utils import pad


class MmaWeightPackerBase:
    """
    Base class for Matrix Multiply and Accumulate (MMA) weight packing.

    Packs weight tensors for efficient GPU computation using MMA operations.
    Handles tile sizes, memory layout, and packing parameters.

    Parameters
    ----------
    bits : int
        Quantization bits. Must be 1, 4, 8, 16, or 32.
    warp_n : int
        Warp size in the n dimension.
    comp_n : int, optional
        Computation tile size in n (default: 16).
    comp_k : int, optional
        Computation tile size in k (default: 256 // bits).

    Raises
    ------
    AssertionError
        If bits or tile/pack sizes are invalid.

    Attributes
    ----------
    comp_n : int
        Tile size in n for MMA computation.
    comp_k : int
        Tile size in k for MMA computation.
    insn_n : int
        MMA instruction tile size in n.
    insn_k : int
        MMA instruction tile size in k.
    num_lanes : int
        Number of lanes (threads) in a warp.
    num_k_lanes : int
        Number of lanes in k.
    num_n_lanes : int
        Number of lanes in n.
    warp_n : int
        Warp size in n.
    reg_k : int
        Elements in a register in k.
    reg_n : int
        Elements in a register in n.
    k_pack_size : int
        Elements in a pack in k.
    n_pack_size : int
        Elements in a pack in n.
    pack_size : int
        Elements in a pack accessed by a lane.
    mem_k : int
        Tile size in k for one memory access.
    mem_n : int
        Tile size in n for one memory access.
    num_k_packs : int
        Packs in k for one memory access.
    num_n_packs : int
        Packs in n for one memory access.
    """

    def __init__(self, bits: int, warp_n: int, comp_n: int = None, comp_k: int = None):
        self.bits = bits
        assert self.bits in (1, 4, 8, 16, 32), "weight bits should be 1, 4, 8, 16, or 32."

        # region compute tile size
        self.comp_n = comp_n if comp_n is not None else 16
        # smallest tile size in `n` dimension for MMA computation.
        self.comp_k = comp_k if comp_k is not None else 256 // self.bits
        # smallest tile size in `k` dimension for MMA computation.
        # the smallest MMA computation may contain several MMA instructions
        self.insn_n = 8  # mma instruction tile size in `n` dimension
        # tile size in `n` dimension for MMA instruction.
        self.insn_k = self.comp_k
        # tile size in `k` dimension for MMA instruction.
        assert self.insn_k * self.bits in (
            128,
            256,
        ), f"insn_k ({self.insn_k}) * bits ({self.bits}) should be 128 or 256."
        assert self.comp_n % self.insn_n == 0, f"comp_n ({self.comp_n}) should be divisible by insn_n ({self.insn_n})."
        self.num_lanes = 32
        # there are 32 lanes (or threads) in a warp.
        self.num_k_lanes = 4
        self.num_n_lanes = 8
        assert (
            warp_n >= self.comp_n and warp_n % self.comp_n == 0
        ), f"warp_n ({warp_n}) should be divisible by comp_n ({self.comp_n})."
        self.warp_n = warp_n
        # endregion
        # region memory
        self.reg_k = 32 // self.bits
        # number of elements in a register in `k` dimension.
        self.reg_n = 1
        # number of elements in a register in `n` dimension (always 1).
        self.k_pack_size = self.comp_k // (self.num_k_lanes * self.reg_k)
        # number of elements in a pack in `k` dimension.
        self.n_pack_size = self.comp_n // (self.num_n_lanes * self.reg_n)
        # number of elements in a pack in `n` dimension.
        self.pack_size = self.k_pack_size * self.n_pack_size
        # number of elements in a pack accessed by a lane at a time.
        assert 1 <= self.pack_size <= 4, "pack size should be less than or equal to 4."
        assert self.k_pack_size * self.num_k_lanes * self.reg_k == self.comp_k
        assert self.n_pack_size * self.num_n_lanes * self.reg_n == self.comp_n
        self.mem_k = self.comp_k
        # the tile size in `k` dimension for one tensor memory access.
        self.mem_n = warp_n
        # the tile size in `n` dimension for one tensor memory access.
        self.num_k_packs = self.mem_k // (self.k_pack_size * self.num_k_lanes * self.reg_k)
        # number of packs in `k` dimension for one tensor memory access.
        self.num_n_packs = self.mem_n // (self.n_pack_size * self.num_n_lanes * self.reg_n)
        # number of packs in `n` dimension for one tensor memory access.
        # endregion

    def get_view_shape(self, n: int, k: int) -> tuple[int, int, int, int, int, int, int, int, int, int]:
        """
        Returns the tensor view shape for MMA operations.

        Parameters
        ----------
        n : int
            Output channel size (must be divisible by mem_n).
        k : int
            Input channel size (must be divisible by mem_k).

        Returns
        -------
        tuple of int
            (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n,
             k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)

        Raises
        ------
        AssertionError
            If n or k is not divisible by mem_n or mem_k.
        """
        assert n % self.mem_n == 0, "output channel size should be divisible by mem_n."
        assert k % self.mem_k == 0, "input channel size should be divisible by mem_k."
        return (
            n // self.mem_n,
            self.num_n_packs,
            self.n_pack_size,
            self.num_n_lanes,
            self.reg_n,
            k // self.mem_k,
            self.num_k_packs,
            self.k_pack_size,
            self.num_k_lanes,
            self.reg_k,
        )


class NunchakuWeightPacker(MmaWeightPackerBase):
    """
    Nunchaku-specific weight packer. Provides Nunchaku-specific packing of
    quantized weights, scales, and low-rank weights.

    Parameters
    ----------
    bits : int
        Number of quantization bits. Must be 1, 4, 8, 16, or 32.
    warp_n : int, optional
        Warp size in the n dimension. Default is 128.

    Attributes
    ----------
    num_k_unrolls : int
        Number of unrolls in the k dimension (always 2 for Nunchaku).
    """

    def __init__(self, bits: int, warp_n: int = 128):
        super().__init__(bits=bits, warp_n=warp_n)
        self.num_k_unrolls = 2

    def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """
        Pack a quantized weight tensor for Nunchaku MMA.

        Parameters
        ----------
        weight : torch.Tensor
            Quantized weight tensor of dtype torch.int32 and shape (n, k).

        Returns
        -------
        torch.Tensor
            Packed weight tensor of dtype torch.int8.
        """
        assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
        n, k = weight.shape
        assert n % self.mem_n == 0, f"output channel size ({n}) should be divisible by mem_n ({self.mem_n})."
        # currently, Nunchaku does not check the boundary of the unrolled `k` dimension
        assert k % (self.mem_k * self.num_k_unrolls) == 0, (
            f"input channel size ({k}) should be divisible by "
            f"mem_k ({self.mem_k}) * num_k_unrolls ({self.num_k_unrolls})."
        )
        n_tiles, k_tiles = n // self.mem_n, k // self.mem_k
        weight = weight.reshape(
            n_tiles,
            self.num_n_packs,  # 8 when warp_n = 128
            self.n_pack_size,  # always 2 in nunchaku
            self.num_n_lanes,  # constant 8
            self.reg_n,  # constant 1
            k_tiles,
            self.num_k_packs,  # 1
            self.k_pack_size,  # always 2 in nunchaku
            self.num_k_lanes,  # constant 4
            self.reg_k,  # always 8 = 32 bits / 4 bits
        )
        # (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n, k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)
        # =>
        # (n_tiles, k_tiles, num_k_packs, num_n_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        weight = weight.permute(0, 5, 6, 1, 3, 8, 2, 7, 4, 9).contiguous()
        assert weight.shape[4:-2] == (8, 4, 2, 2)
        if self.bits == 4:
            weight = weight.bitwise_and_(0xF)
            shift = torch.arange(0, 32, 4, dtype=torch.int32, device=weight.device)
            weight = weight.bitwise_left_shift_(shift)
            weight = weight.sum(dim=-1, dtype=torch.int32)
        elif self.bits == 8:
            weight = weight.bitwise_and_(0xFF)
            shift = torch.arange(0, 32, 8, dtype=torch.int32, device=weight.device)
            weight = weight.bitwise_left_shift_(shift)
            weight = weight.sum(dim=-1, dtype=torch.int32)
        else:
            raise NotImplementedError(f"weight bits {self.bits} is not supported.")
        return weight.view(dtype=torch.int8).view(n, -1)  # assume little-endian

    def pack_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
        """
        Pack a scale tensor for Nunchaku MMA.

        Parameters
        ----------
        scale : torch.Tensor
            Scale tensor of dtype torch.float16 or torch.bfloat16.
        group_size : int
            Group size for quantization.

        Returns
        -------
        torch.Tensor
            Packed scale tensor.
        """
        if self.check_if_micro_scale(group_size=group_size):
            return self.pack_micro_scale(scale, group_size=group_size)
        # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
        assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
        n = scale.shape[0]
        # nunchaku loads all scales in one access
        # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
        # scale loading is parallelized in `n` dimension, that is,
        # `num_s_lanes` lanes in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
        # each element in `n` dimension is 16 bit as it contains 1 fp16
        # min `s_pack_size` set to 2 elements, since each lane holds at least 2 accumulator results in `n` dimension
        # max `s_pack_size` set to 128b/16b = 8 elements
        # for `warp_n = 8`, we have
        #     `s_pack_size = 2`, `num_s_lanes = 4`, `num_s_packs = 1`
        # for `warp_n = 128`, we have
        #     `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
        # for `warp_n = 512`, we have
        #     `s_pack_size = 8`, `num_s_lanes = 32`, `num_s_packs = 2`
        s_pack_size = min(max(self.warp_n // self.num_lanes, 2), 8)
        num_s_lanes = min(self.num_lanes, self.warp_n // s_pack_size)
        num_s_packs = self.warp_n // (s_pack_size * num_s_lanes)
        warp_s = num_s_packs * num_s_lanes * s_pack_size
        assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
        # `num_n_lanes = 8 (constant)` generates 8 elements consecutive in `n` dimension
        # however, they are held by 4 lanes, each lane holds 2 elements in `n` dimension
        # thus, we start from the first 4 lanes, assign 2 elements to each lane, until all 8 elements are assigned
        # we then repeat the process for the same 4 lanes, until each lane holds `s_pack_size` elements
        # finally, we move to the next 4 lanes, and repeat the process until all `num_s_lanes` lanes are assigned
        # the process is repeated `num_s_packs` times
        # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
        # wscales store order:
        #     0   1   8   9  <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #     2   3  10  11  <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
        #     4   5  12  13  <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
        #     6   7  14  15  <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
        #    16  17  24  25  <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #    ...
        #    22  23  30  31  <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
        #    ... ...
        #   112 113 120 121  <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #    ...
        #   118 119 126 127  <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
        scale = scale.reshape(n // warp_s, num_s_packs, num_s_lanes // 4, s_pack_size // 2, 4, 2, -1)
        scale = scale.permute(0, 6, 1, 2, 4, 3, 5).contiguous()
        return scale.view(-1) if group_size == -1 else scale.view(-1, n)  # the shape is just used for validation

    def pack_micro_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
        """
        Pack a micro scale tensor for Nunchaku MMA.

        Parameters
        ----------
        scale : torch.Tensor
            Scale tensor of dtype torch.float16 or torch.bfloat16.
        group_size : int
            Group size for quantization (must be 16).

        Returns
        -------
        torch.Tensor
            Packed micro scale tensor.
        """
        assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
        assert scale.max() <= 448, "scale should be less than 448."
        assert scale.min() >= -448, "scale should be greater than -448."
        assert group_size == 16, "currently only support group size 16."
        assert self.insn_k == 64, "insn_k should be 64."
        scale = scale.to(dtype=torch.float8_e4m3fn)
        n = scale.shape[0]
        assert self.warp_n >= 32, "currently only support warp_n >= 32."
        # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
        # scale loading is parallelized in `n` dimension, that is,
        # `num_s_lanes` lanes in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
        # each element in `n` dimension is 32 bit as it contains 4 fp8 in `k` dimension
        # min `s_pack_size` set to 1 element
        # max `s_pack_size` set to 128b/32b = 4 elements
        # for `warp_n = 128`, we have
        #     `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
        # for `warp_n = 512`, we have
        #     `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 4`
        s_pack_size = min(max(self.warp_n // self.num_lanes, 1), 4)
        num_s_lanes = 4 * 8  # 32 lanes is divided into 4 pieces, each piece has 8 lanes at a stride of 4
        num_s_packs = ceil_divide(self.warp_n, s_pack_size * num_s_lanes)
        warp_s = num_s_packs * num_s_lanes * s_pack_size
        assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
        # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-scaling-thread-id-b-selection
        # we start from the first 8 lanes at a stride of 4, assign 1 element to each lane, until all 8 elements are assigned
        # we then move to the next 8 lanes at a stride of 4, and repeat the process until all 32 lanes are assigned
        # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
        # wscales store order:
        #     0  32  64  96  <-- load by lane 0
        #     8  40  72 104  <-- load by lane 1
        #    16  48  80 112  <-- load by lane 2
        #    24  56  88 120  <-- load by lane 3
        #     1  33  65  97  <-- load by lane 4
        #    ...
        #    25  57  81 113  <-- load by lane 7
        #    ...
        #     7  39  71 103  <-- load by lane 28
        #    ...
        #    31  63  95 127  <-- load by lane 31
        scale = scale.view(n // warp_s, num_s_packs, s_pack_size, 4, 8, -1, self.insn_k // group_size)
        scale = scale.permute(0, 5, 1, 4, 3, 2, 6).contiguous()
        return scale.view(-1, n)  # the shape is just used for validation

    def pack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        """
        Pack a low-rank weight tensor.

        Parameters
        ----------
        weight : torch.Tensor
            Low-rank weight tensor of dtype torch.float16 or torch.bfloat16.
        down : bool
            If True, the weight is the down projection in the low-rank branch.

        Returns
        -------
        torch.Tensor
            Packed low-rank weight tensor.
        """
        assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
        reg_n, reg_k = 1, 2  # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
        pack_n = self.n_pack_size * self.num_n_lanes * reg_n
        pack_k = self.k_pack_size * self.num_k_lanes * reg_k
        weight = pad(weight, divisor=(pack_n, pack_k), dim=(0, 1))
        if down:
            r, c = weight.shape
            r_packs, c_packs = r // pack_n, c // pack_k
            weight = weight.view(r_packs, pack_n, c_packs, pack_k).permute(2, 0, 1, 3)
        else:
            c, r = weight.shape
            c_packs, r_packs = c // pack_n, r // pack_k
            weight = weight.view(c_packs, pack_n, r_packs, pack_k).permute(0, 2, 1, 3)
        weight = weight.reshape(
            c_packs, r_packs, self.n_pack_size, self.num_n_lanes, reg_n, self.k_pack_size, self.num_k_lanes, reg_k
        )
        # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
        # =>
        # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        weight = weight.permute(0, 1, 3, 6, 2, 5, 4, 7).contiguous()
        return weight.view(c, r)

    def unpack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        """
        Unpack a low-rank weight tensor.

        Parameters
        ----------
        weight : torch.Tensor
            Packed low-rank weight tensor of dtype torch.float16 or torch.bfloat16.
        down : bool
            If True, the weight is the down projection in the low-rank branch.

        Returns
        -------
        torch.Tensor
            Unpacked low-rank weight tensor.
        """
        c, r = weight.shape
        assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
        reg_n, reg_k = 1, 2  # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
        pack_n = self.n_pack_size * self.num_n_lanes * reg_n
        pack_k = self.k_pack_size * self.num_k_lanes * reg_k
        if down:
            r_packs, c_packs = r // pack_n, c // pack_k
        else:
            c_packs, r_packs = c // pack_n, r // pack_k
        weight = weight.view(
            c_packs, r_packs, self.num_n_lanes, self.num_k_lanes, self.n_pack_size, self.k_pack_size, reg_n, reg_k
        )
        # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        # =>
        # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
        weight = weight.permute(0, 1, 4, 2, 6, 5, 3, 7).contiguous()
        weight = weight.view(c_packs, r_packs, pack_n, pack_k)
        if down:
            weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
        else:
            weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
        return weight

    def check_if_micro_scale(self, group_size: int) -> bool:
        """
        Check if micro scale packing is required.

        Parameters
        ----------
        group_size : int
            Group size for quantization.

        Returns
        -------
        bool
            True if micro scale packing is required.
        """
        return self.insn_k == group_size * 4

    def pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """
        Pad a weight tensor to the required shape.

        Parameters
        ----------
        weight : torch.Tensor
            Weight tensor of shape (n, k).

        Returns
        -------
        torch.Tensor
            Padded weight tensor.
        """
        assert weight.ndim == 2, "weight tensor should be 2D."
        return pad(weight, divisor=(self.mem_n, self.mem_k * self.num_k_unrolls), dim=(0, 1))

    def pad_scale(self, scale: torch.Tensor, group_size: int, fill_value: float = 0) -> torch.Tensor:
        """
        Pad a scale tensor to the required shape.

        Parameters
        ----------
        scale : torch.Tensor
            Scale tensor.
        group_size : int
            Group size for quantization.
        fill_value : float, optional
            Value to use for padding. Default is 0.

        Returns
        -------
        torch.Tensor
            Padded scale tensor.
        """
        if group_size > 0 and scale.numel() > scale.shape[0]:
            scale = scale.view(scale.shape[0], 1, -1, 1)
            if self.check_if_micro_scale(group_size=group_size):
                scale = pad(scale, divisor=(self.warp_n, self.insn_k // group_size), dim=(0, 2), fill_value=fill_value)
            else:
                scale = pad(scale, divisor=(self.warp_n, self.num_k_unrolls), dim=(0, 2), fill_value=fill_value)
        else:
            scale = pad(scale, divisor=self.warp_n, dim=0, fill_value=fill_value)
        return scale

    def pad_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        """
        Pad a low-rank weight tensor to the required shape.

        Parameters
        ----------
        weight : torch.Tensor
            Low-rank weight tensor.
        down : bool
            If True, the weight is the down projection in the low-rank branch.

        Returns
        -------
        torch.Tensor
            Padded low-rank weight tensor.
        """
        assert weight.ndim == 2, "weight tensor should be 2D."
        return pad(weight, divisor=self.warp_n, dim=1 if down else 0)
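
A quick way to sanity-check the layout logic above is that the low-rank pack/unpack pair round-trips exactly when no padding is needed. A minimal sketch, assuming `bits=4` with the default `warp_n=128` (shapes chosen to be multiples of the 16x16 pack tile):

```python
import torch

from nunchaku.lora.flux.packer import NunchakuWeightPacker

packer = NunchakuWeightPacker(bits=4)  # warp_n defaults to 128

# A low-rank "down" projection: (in_features, rank) = (4096, 32).
# Both dimensions are already multiples of the pack tile, so no padding occurs.
proj_down = torch.randn(4096, 32, dtype=torch.bfloat16)
packed = packer.pack_lowrank_weight(proj_down, down=True)
restored = packer.unpack_lowrank_weight(packed, down=True)
assert torch.equal(restored, proj_down)
```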
nunchaku/lora/flux/utils.py ADDED
@@ -0,0 +1,94 @@
"""
Utility functions for LoRAs in Flux models.
"""

import typing as tp

import torch

from ...utils import ceil_divide, load_state_dict_in_safetensors


def is_nunchaku_format(lora: str | dict[str, torch.Tensor]) -> bool:
    """
    Check if LoRA weights are in Nunchaku format.

    Parameters
    ----------
    lora : str or dict[str, torch.Tensor]
        Path to a safetensors file or a dictionary of LoRA weights.

    Returns
    -------
    bool
        True if the weights are in Nunchaku format, False otherwise.

    Examples
    --------
    >>> is_nunchaku_format("path/to/lora.safetensors")
    True
    """
    if isinstance(lora, str):
        tensors = load_state_dict_in_safetensors(lora, device="cpu", return_metadata=False)
        assert isinstance(tensors, dict), "Expected dict when return_metadata=False"
    else:
        tensors = lora

    for k in tensors.keys():
        if ".mlp_fc" in k or "mlp_context_fc1" in k:
            return True
    return False


def pad(
    tensor: tp.Optional[torch.Tensor],
    divisor: int | tp.Sequence[int],
    dim: int | tp.Sequence[int],
    fill_value: float | int = 0,
) -> torch.Tensor | None:
    """
    Pad a tensor so the specified dimensions are divisible by the given divisors.

    Parameters
    ----------
    tensor : torch.Tensor or None
        The tensor to pad. If None, returns None.
    divisor : int or sequence of int
        Divisor(s) for the dimension(s) to pad.
    dim : int or sequence of int
        Dimension(s) to pad.
    fill_value : float or int, optional
        Value to use for padding (default: 0).

    Returns
    -------
    torch.Tensor or None
        The padded tensor, or None if the input tensor was None.

    Examples
    --------
    >>> tensor = torch.randn(10, 20)
    >>> pad(tensor, divisor=16, dim=0).shape
    torch.Size([16, 20])
    >>> pad(tensor, divisor=[16, 32], dim=[0, 1]).shape
    torch.Size([16, 32])
    """
    if isinstance(divisor, int):
        if divisor <= 1:
            return tensor
    elif all(d <= 1 for d in divisor):
        return tensor
    if tensor is None:
        return None
    shape = list(tensor.shape)
    if isinstance(dim, int):
        assert isinstance(divisor, int)
        shape[dim] = ceil_divide(shape[dim], divisor) * divisor
    else:
        if isinstance(divisor, int):
            divisor = [divisor] * len(dim)
        for d, div in zip(dim, divisor, strict=True):
            shape[d] = ceil_divide(shape[d], div) * div
    result = torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device)
    result[[slice(0, extent) for extent in tensor.shape]] = tensor
    return result
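
This is the same helper `convert_to_nunchaku_flux_lowrank_dict` uses to equalize mismatched LoRA ranks before further processing; a small illustrative sketch of that pattern (shapes are arbitrary):

```python
import torch

from nunchaku.lora.flux.utils import pad

fc1_A = torch.randn(16, 3072)  # rank-16 lora_A for ff.net.0.proj
fc2_A = torch.randn(32, 3072)  # rank-32 lora_A for ff.net.2

# Zero-pad the smaller rank up to the larger one so the two can be handled together.
rank = max(fc1_A.shape[0], fc2_A.shape[0])
fc1_A = pad(fc1_A, divisor=rank, dim=0)
assert fc1_A.shape == (32, 3072)
```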
nunchaku/models/__init__.py ADDED
@@ -0,0 +1,9 @@
from .text_encoders.t5_encoder import NunchakuT5EncoderModel
from .transformers import (
    NunchakuFluxTransformer2dModel,
)

__all__ = [
    "NunchakuFluxTransformer2dModel",
    "NunchakuT5EncoderModel",
]
nunchaku/models/attention.py ADDED
@@ -0,0 +1,123 @@
"""
Nunchaku quantized attention-related modules.
"""

import torch
from diffusers.models.activations import GELU
from diffusers.models.attention import FeedForward
from torch import nn

from ..ops.fused import fused_gelu_mlp
from .linear import SVDQW4A4Linear


class NunchakuBaseAttention(nn.Module):
    """
    Base class for Nunchaku attention modules.

    Provides a common interface for attention modules with processor selection.

    Parameters
    ----------
    processor : str, optional
        Name of the attention processor to use. Default is "flashattn2".
    *args, **kwargs :
        Additional arguments for subclass initialization.
    """

    def __init__(self, processor: str = "flashattn2", *args, **kwargs):
        super(NunchakuBaseAttention, self).__init__()
        self.processor = None
        self.set_processor(processor)

    def set_processor(self, processor: str):
        """
        Set the attention processor. Must be implemented by subclasses.

        Parameters
        ----------
        processor : str
            Name of the processor to use.

        Raises
        ------
        NotImplementedError
            If not implemented in a subclass.
        """
        raise NotImplementedError("Subclass must implement this method")


def _patch_linear(module: nn.Module, linear_cls, **kwargs) -> nn.Module:
    """
    Recursively replace all nn.Linear modules in a given module with a custom linear class.

    Parameters
    ----------
    module : nn.Module
        The module to patch.
    linear_cls : type
        The custom linear class to use for replacement.
    **kwargs :
        Additional arguments passed to ``from_linear``.

    Returns
    -------
    nn.Module
        The patched module with custom linear layers.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, linear_cls.from_linear(child, **kwargs))
        else:
            _patch_linear(child, linear_cls, **kwargs)
    return module


class NunchakuFeedForward(FeedForward):
    """
    Quantized feed-forward (MLP) block with fused GELU support.

    Replaces the linear layers in a FeedForward block with :class:`~nunchaku.models.linear.SVDQW4A4Linear` for quantized inference.
    Supports fused GELU-MLP computation for efficiency.

    Parameters
    ----------
    ff : FeedForward
        Source FeedForward block to quantize.
    **kwargs :
        Additional arguments for SVDQW4A4Linear.

    Notes
    -----
    For int4 quantization, the activation of the second MLP layer is shifted to be unsigned.
    """

    def __init__(self, ff: FeedForward, **kwargs):
        super(FeedForward, self).__init__()
        self.net = _patch_linear(ff.net, SVDQW4A4Linear, **kwargs)
        # For int4, shift the activation of mlp_fc2 to make it unsigned
        self.net[2].act_unsigned = self.net[2].precision != "nvfp4"

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the quantized feed-forward block.
        It will call :func:`~nunchaku.ops.fused.fused_gelu_mlp` if the first layer is GELU;
        otherwise, it applies the modules sequentially.

        Parameters
        ----------
        hidden_states : torch.Tensor, shape (B, S, D)
            Input tensor.

        Returns
        -------
        torch.Tensor, shape (B, S, D)
            Output tensor after the feed-forward transformation.
        """
        if isinstance(self.net[0], GELU):
            return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2])
        else:
            # Fallback to the original implementation
            for module in self.net:
                hidden_states = module(hidden_states)
            return hidden_states
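
A construction-level sketch of wrapping a stock Diffusers `FeedForward` block. Building the module works on CPU, since only dummy quantized buffers are allocated; running the forward pass additionally requires the compiled Nunchaku CUDA kernels. The dimensions are arbitrary:

```python
import torch
from diffusers.models.attention import FeedForward

from nunchaku.models.attention import NunchakuFeedForward

ff = FeedForward(dim=3072, dim_out=3072, activation_fn="gelu-approximate")
qff = NunchakuFeedForward(ff, precision="int4")  # nn.Linear layers become SVDQW4A4Linear
print(qff.net[0].proj)  # SVDQW4A4Linear(in_features=3072, out_features=12288, ...)
```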
nunchaku/models/embeddings.py ADDED
@@ -0,0 +1,138 @@
"""
Embedding layers for Nunchaku.
"""

import diffusers
import torch
from packaging.version import Version
from torch import nn


def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """
    Rotary positional embedding function.
    Copied from https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/transformers/transformer_flux.py#L38

    Parameters
    ----------
    pos : torch.Tensor, shape (..., n), dtype int
        Position indices.
    dim : int
        Embedding dimension (must be even).
    theta : int
        Rotary base.

    Returns
    -------
    out : torch.Tensor, shape (B, M, D//2, 1, 2), dtype float32
        Rotary embedding tensor.

    Notes
    -----
    - B: batch size
    - M: sequence length
    - D: embedding dimension
    """
    assert dim % 2 == 0, "The dimension must be even."

    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    batch_size, seq_length = pos.shape
    out = torch.einsum("...n,d->...nd", pos, omega)

    # Sin/cos representation for rotary embedding
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)
    stacked_out = torch.stack([sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 1, 2)

    return out.float()


class NunchakuFluxPosEmbed(nn.Module):
    """
    Nunchaku multi-dimensional rotary embedding module for FLUX.
    Adapted from https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/transformers/transformer_flux.py#L55

    Parameters
    ----------
    dim : int
        Embedding dimension.
    theta : int
        Rotary base.
    axes_dim : list of int
        Dimension for each spatial axis.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super(NunchakuFluxPosEmbed, self).__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Compute rotary embeddings for multi-dimensional positions.

        Parameters
        ----------
        ids : torch.Tensor, shape (..., n_axes), dtype int
            Position indices.

        Returns
        -------
        out : torch.Tensor, shape (B, 1, ...), dtype float32
            Rotary embedding tensor.

        Notes
        -----
        - B: batch size
        - n_axes: number of spatial axes
        """
        if Version(diffusers.__version__) >= Version("0.31.0"):
            ids = ids[None, ...]
        n_axes = ids.shape[-1]
        emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
        return emb.unsqueeze(1)


def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
    """
    Pack rotary embeddings for efficient CUDA computation.

    Parameters
    ----------
    rotemb : torch.Tensor, shape (B, M, D//2, 1, 2), dtype float32
        Rotary embedding tensor.

    Returns
    -------
    packed : torch.Tensor, shape (B, M, D), dtype float32
        Packed rotary embedding tensor.

    Notes
    -----
    - B: batch size
    - M: sequence length (must be divisible by 16)
    - D: embedding dimension (must be divisible by 8)
    """
    assert rotemb.dtype == torch.float32
    B = rotemb.shape[0]
    M = rotemb.shape[1]
    D = rotemb.shape[2] * 2
    assert rotemb.shape == (B, M, D // 2, 1, 2)
    assert M % 16 == 0
    assert D % 8 == 0
    rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8)
    rotemb = rotemb.permute(0, 1, 3, 2, 4)
    # 16*8 pack, FP32 accumulator (C) format
    # https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
    ##########################################|--M--|--D--|
    ##########################################|-3--4--5--6|
    ##########################################  :  :  :  :
    rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
    rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
    rotemb = rotemb.contiguous()
    rotemb = rotemb.view(B, M, D)
    return rotemb
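
A shape-level sketch of the rotary pipeline. The positions are created as float64 here so the einsum against the float64 `omega` succeeds, and the sequence length and dimension satisfy the divisibility asserts in `pack_rotemb`:

```python
import torch

from nunchaku.models.embeddings import pack_rotemb, rope

pos = torch.arange(64, dtype=torch.float64).reshape(1, 64)  # (B=1, M=64) positions
emb = rope(pos, dim=56, theta=10000)  # -> (1, 64, 28, 1, 2), float32
packed = pack_rotemb(emb)  # -> (1, 64, 56), reordered for the MMA kernels
assert packed.shape == (1, 64, 56)
```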
nunchaku/models/linear.py ADDED
@@ -0,0 +1,414 @@
"""
Quantized linear layers for Nunchaku.
"""

import torch
from torch import nn

from ..ops.gemm import svdq_gemm_w4a4_cuda
from ..ops.gemv import awq_gemv_w4a16_cuda
from ..ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda


class SVDQW4A4Linear(nn.Module):
    """
    `SVDQuant <paper_svdquant_>`_ W4A4 quantized linear layer.

    Parameters
    ----------
    in_features : int
        Input feature dimension.
    out_features : int
        Output feature dimension.
    rank : int, optional
        SVD low-rank dimension. Default is 32.
    bias : bool, optional
        If True, adds a learnable bias. Default is True.
    precision : {'int4', 'nvfp4'}, optional
        Quantization precision data type ('int4' or 'nvfp4'). Default is 'int4'.
    act_unsigned : bool, optional
        If True, use unsigned activation quantization (int4 only). Default is False.
    torch_dtype : torch.dtype, optional
        Parameter dtype. Default is torch.bfloat16.
    device : str or torch.device or None, optional
        Device for parameters. Default is CPU.

    Attributes
    ----------
    in_features : int
    out_features : int
    rank : int
    precision : str
        'int4' or 'nvfp4'.
    group_size : int
        64 for int4, 16 for nvfp4.
    qweight : nn.Parameter
        Packed quantized weights, shape (out_features, in_features // 2), dtype int8.
    bias : nn.Parameter or None
        Bias tensor.
    wscales : nn.Parameter
        Weight scales, shape (in_features // group_size, out_features).
        Dtype: bfloat16/float16 (int4), float8_e4m3fn (nvfp4).
    smooth_factor : nn.Parameter
        Smoothing factors, shape (in_features,).
    smooth_factor_orig : nn.Parameter
        Original smoothing factors, shape (in_features,). (Unused)
    proj_down : nn.Parameter
        Packed low-rank down projection, shape (in_features, rank), dtype bfloat16/float16.
    proj_up : nn.Parameter
        Packed low-rank up projection, shape (out_features, rank), dtype bfloat16/float16.
    wtscale : float or None
        Global weight scale (nvfp4 only).
    wcscales : nn.Parameter or None
        Channel-wise weight scale (nvfp4 only), shape (out_features,), dtype float8_e4m3fn.
    act_unsigned : bool
        If True, input activations are unsigned (int4 only).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 32,
        bias: bool = True,
        precision: str = "int4",
        act_unsigned: bool = False,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device | None = None,
    ):
        super(SVDQW4A4Linear, self).__init__()
        if device is None:
            device = torch.device("cpu")
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank

        self.precision = precision
        self.torch_dtype = torch_dtype

        if precision == "nvfp4":
            self.group_size = 16
        elif precision == "int4":
            self.group_size = 64
        else:
            raise ValueError(f"Invalid precision: {precision}")

        self.qweight = nn.Parameter(
            torch.empty(out_features, in_features // 2, dtype=torch.int8, device=device), requires_grad=False
        )
        self.bias = (
            nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=True)
            if bias
            else None
        )

        self.wscales = nn.Parameter(
            torch.empty(
                in_features // self.group_size,
                out_features,
                dtype=torch_dtype if precision == "int4" else torch.float8_e4m3fn,
                device=device,
            ),
            requires_grad=False,
        )
        self.smooth_factor = nn.Parameter(
            torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False
        )
        self.smooth_factor_orig = nn.Parameter(
            torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False
        )

        self.proj_down = nn.Parameter(torch.empty(in_features, rank, dtype=torch_dtype, device=device))
        self.proj_up = nn.Parameter(torch.empty(out_features, rank, dtype=torch_dtype, device=device))

        if precision == "nvfp4":
            self.wcscales = nn.Parameter(
                torch.ones(out_features, dtype=torch_dtype, device=device), requires_grad=False
            )
            self.wtscale = 1.0
        else:
            self.wtscale = None
            self.wcscales = None

        self.act_unsigned = act_unsigned

    @classmethod
    def from_linear(cls, linear: nn.Linear, **kwargs):
        """
        Create an SVDQW4A4Linear from a standard nn.Linear. The weight and bias are dummy tensors.

        Parameters
        ----------
        linear : nn.Linear
            Source linear layer.
        **kwargs
            Additional init arguments.

        Returns
        -------
        SVDQW4A4Linear
        """
        in_features = kwargs.pop("in_features", linear.in_features)
        return cls(
            in_features=in_features,
            out_features=linear.out_features,
            bias=linear.bias is not None,
            torch_dtype=linear.weight.dtype,
            device=linear.weight.device,
            **kwargs,
        )

    def forward(self, x: torch.Tensor, output: torch.Tensor | None = None) -> torch.Tensor:
        """
        Forward pass with 16-bit input. It will call :meth:`quantize` and :meth:`forward_quant`.

        Parameters
        ----------
        x : torch.Tensor, shape (B, S, in_features), dtype float16 or bfloat16
            Input tensor.
        output : torch.Tensor or None, optional
            Optional output buffer.

        Returns
        -------
        torch.Tensor, shape (B, S, out_features)
            Output tensor.

        Notes
        -----
        B: batch size, S: sequence length
        """
        batch_size, seq_len, channels = x.shape
        x = x.reshape(batch_size * seq_len, channels)
        if output is None:
            output = torch.empty(batch_size * seq_len, self.out_features, dtype=x.dtype, device=x.device)
        quantized_x, ascales, lora_act_out = self.quantize(x)
        output = self.forward_quant(quantized_x, ascales, lora_act_out, output)
        output = output.reshape(batch_size, seq_len, -1)
        return output

    def quantize(self, x: torch.Tensor, pad_size: int = 256) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize the input to 4-bit and compute the low-rank hidden states. It will call :func:`~nunchaku.ops.quantize.svdq_quantize_w4a4_act_fuse_lora_cuda`.

        Parameters
        ----------
        x : torch.Tensor, shape (N, in_features), dtype float16 or bfloat16
            Input tensor.
        pad_size : int, optional
            Batch padding size. Default is 256.

        Returns
        -------
        quantized_x : torch.Tensor
            Quantized input, shape (pad_size * ceil(N / pad_size), in_features // 2), dtype uint8.
        ascales : torch.Tensor
            Activation scales, shape (in_features // group_size,), dtype float8_e4m3fn for nvfp4 and the input dtype for int4.
        lora_act_out : torch.Tensor
            Low-rank hidden states, shape (pad_size * ceil(N / pad_size), rank), dtype float32.

        Notes
        -----
        N: batch size
        """
        quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda(
            x, lora_down=self.proj_down, smooth=self.smooth_factor, fp4=self.precision == "nvfp4", pad_size=pad_size
        )
        return quantized_x, ascales, lora_act_out

    def forward_quant(
        self,
        quantized_x: torch.Tensor,
        ascales: torch.Tensor,
        lora_act: torch.Tensor,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Forward pass with pre-quantized input. It will call :func:`~nunchaku.ops.gemm.svdq_gemm_w4a4_cuda`.

        Parameters
        ----------
        quantized_x : torch.Tensor
            Quantized input, shape (N, in_features // 2), dtype uint8.
        ascales : torch.Tensor
            Activation scales, shape (in_features // group_size,), dtype float8_e4m3fn for nvfp4 and the input dtype for int4.
        lora_act : torch.Tensor
            Low-rank hidden states, shape (N, rank), dtype float32.
        output : torch.Tensor or None, optional
            Optional output buffer.

        Returns
        -------
        torch.Tensor
            Output tensor, shape (N, out_features), dtype bfloat16/float16 for int4 and float8_e4m3fn for nvfp4.

        Notes
        -----
        N: batch size
        """
        if output is None:
            output = torch.empty(
                quantized_x.shape[0], self.out_features, dtype=self.proj_up.dtype, device=quantized_x.device
            )

        svdq_gemm_w4a4_cuda(
            act=quantized_x,
            wgt=self.qweight,
            out=output,
            ascales=ascales,
            wscales=self.wscales,
            lora_act_in=lora_act,
            lora_up=self.proj_up,
            bias=self.bias,
            fp4=self.precision == "nvfp4",
            alpha=self.wtscale,
            wcscales=self.wcscales,
            act_unsigned=self.act_unsigned,
        )
        return output

    def __repr__(self):
        return (
            f"SVDQW4A4Linear(in_features={self.in_features}, out_features={self.out_features}, "
            f"rank={self.rank}, precision={self.precision}, act_unsigned={self.act_unsigned})"
        )


class AWQW4A16Linear(nn.Module):
    """
    `AWQ <paper_awq_>`_ W4A16 quantized linear layer.

    Parameters
    ----------
    in_features : int
        Input feature dimension.
    out_features : int
        Output feature dimension.
    bias : bool, optional
        If True, adds a learnable bias. Default is True.
    group_size : int, optional
        Quantization group size. Default is 64.
    torch_dtype : torch.dtype, optional
        Parameter dtype. Default is torch.bfloat16.
    device : str or torch.device or None, optional
        Device for parameters. Default is CPU.

    Attributes
    ----------
    in_features : int
    out_features : int
    group_size : int
    qweight : nn.Parameter
        Packed quantized weights, shape (out_features // 4, in_features // 2), dtype int32.
    bias : nn.Parameter or None
        Bias tensor.
    wscales : nn.Parameter
        Weight scales, shape (in_features // group_size, out_features), dtype float16 or bfloat16.
    wzeros : nn.Parameter
        Weight zero points, shape (in_features // group_size, out_features), dtype float16 or bfloat16.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        group_size: int = 64,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device | None = None,
    ):
        super(AWQW4A16Linear, self).__init__()
        if device is None:
            device = torch.device("cpu")
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size

        self.qweight = nn.Parameter(
            torch.empty(out_features // 4, in_features // 2, dtype=torch.int32, device=device), requires_grad=False
        )
        self.bias = (
            nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=True)
            if bias
            else None
        )
        self.wscales = nn.Parameter(
            torch.empty(in_features // self.group_size, out_features, dtype=torch_dtype, device=device),
            requires_grad=False,
        )
        self.wzeros = nn.Parameter(
            torch.empty(in_features // self.group_size, out_features, dtype=torch_dtype, device=device),
            requires_grad=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for AWQW4A16Linear.

        Parameters
        ----------
        x : torch.Tensor, shape (N, in_features)
            Input tensor.

        Returns
        -------
        torch.Tensor, shape (N, out_features)
            Output tensor.

        Notes
        -----
        N: batch size
        """
        output = awq_gemv_w4a16_cuda(
            in_feats=x,
            kernel=self.qweight,
            scaling_factors=self.wscales,
            zeros=self.wzeros,
            m=x.shape[0],
            n=self.out_features,
            k=self.in_features,
            group_size=self.group_size,
        )
        if self.bias is not None:
            view_shape = [1] * (output.ndim - 1) + [-1]
            output.add_(self.bias.view(view_shape))
        return output

    @classmethod
    def from_linear(
        cls,
        linear: nn.Linear,
        group_size: int = 64,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str = "cpu",
        **kwargs,
    ):
        """
        Create an uninitialized AWQW4A16Linear from a standard nn.Linear.

        Parameters
        ----------
        linear : nn.Linear
            Source linear layer.
        group_size : int, optional
            Quantization group size.
        torch_dtype : torch.dtype, optional
            Parameter dtype.
        device : str, optional
            Device for parameters.

        Returns
        -------
        AWQW4A16Linear
        """
        return cls(
            in_features=linear.in_features,
            out_features=linear.out_features,
            bias=linear.bias is not None,
            group_size=group_size,
            torch_dtype=torch_dtype,
            device=device,
        )

    def __repr__(self):
        return f"AWQW4A16Linear(in_features={self.in_features}, out_features={self.out_features}, group_size={self.group_size})"
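
A construction-level sketch of `SVDQW4A4Linear.from_linear`. As its docstring notes, the resulting buffers are uninitialized placeholders until real packed weights are loaded, and the forward pass requires the compiled Nunchaku CUDA extension:

```python
import torch
from torch import nn

from nunchaku.models.linear import SVDQW4A4Linear

linear = nn.Linear(3072, 3072, dtype=torch.bfloat16)
qlinear = SVDQW4A4Linear.from_linear(linear, rank=32, precision="int4")
print(qlinear)
# SVDQW4A4Linear(in_features=3072, out_features=3072, rank=32, precision=int4, act_unsigned=False)
```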
nunchaku/models/normalization.py ADDED
@@ -0,0 +1,166 @@
+ """
+ Quantized normalization layers for efficient inference.
+ """
+
+ from typing import Optional, Tuple
+
+ import torch
+ from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormZeroSingle
+
+ from .linear import AWQW4A16Linear
+
+
+ class NunchakuAdaLayerNormZero(AdaLayerNormZero):
+     """
+     Nunchaku quantized AdaLayerNormZero for diffusion models.
+
+     Replaces the linear projection with AWQW4A16Linear for quantized inference.
+
+     Parameters
+     ----------
+     other : AdaLayerNormZero
+         Source AdaLayerNormZero instance to copy weights and structure from.
+     scale_shift : float, optional
+         Value to add to scale parameters. Default is 1.0.
+         Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
+
+     Notes
+     -----
+     - B: batch size
+     - D: hidden dimension
+     """
+
+     def __init__(self, other: AdaLayerNormZero, scale_shift: float = 1.0):
+         super(AdaLayerNormZero, self).__init__()
+         self.scale_shift = scale_shift
+         self.emb = other.emb
+         self.silu = other.silu
+         self.linear = AWQW4A16Linear.from_linear(other.linear)
+         self.norm = other.norm
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         timestep: Optional[torch.Tensor] = None,
+         class_labels: Optional[torch.LongTensor] = None,
+         hidden_dtype: Optional[torch.dtype] = None,
+         emb: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Forward pass for quantized AdaLayerNormZero.
+
+         Parameters
+         ----------
+         x : torch.Tensor, shape (B, D), dtype float32/float16
+             Input tensor.
+         timestep : Optional[torch.Tensor], shape (B,) or (1,), optional
+             Timestep embedding input.
+         class_labels : Optional[torch.LongTensor], shape (B,) or (1,), optional
+             Class label input.
+         hidden_dtype : Optional[torch.dtype], optional
+             Dtype for embedding computation.
+         emb : Optional[torch.Tensor], shape (B, E), optional
+             Precomputed embedding. Ignored when ``self.emb`` is present, in which
+             case the embedding is recomputed from timestep and class_labels.
+
+         Returns
+         -------
+         norm_x_scaled : torch.Tensor, shape (B, D)
+             Normalized and scaled input.
+         gate_msa : torch.Tensor, shape (B, D)
+             Gate for MSA branch.
+         shift_mlp : torch.Tensor, shape (B, D)
+             Shift for MLP branch.
+         scale_mlp : torch.Tensor, shape (B, D)
+             Scale for MLP branch.
+         gate_mlp : torch.Tensor, shape (B, D)
+             Gate for MLP branch.
+
+         Notes
+         -----
+         - B: batch size
+         - D: hidden dimension
+         """
+         if self.emb is not None:
+             emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+         emb = self.linear(self.silu(emb))
+
+         # The quantized linear stores the six modulation tensors interleaved,
+         # so split them with a strided view/permute rather than with chunk.
+         emb = emb.view(emb.shape[0], -1, 6).permute(2, 0, 1)
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb
+
+         norm_x = self.norm(x)
+
+         if self.scale_shift != 0:
+             scale_msa.add_(self.scale_shift)
+             scale_mlp.add_(self.scale_shift)
+
+         norm_x_scaled = norm_x * scale_msa[:, None] + shift_msa[:, None]
+         return norm_x_scaled, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+ class NunchakuAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
+     """
+     Nunchaku quantized AdaLayerNormZeroSingle.
+
+     Uses AWQW4A16Linear for the quantized embedding projection. Suitable for single-branch normalization.
+
+     Parameters
+     ----------
+     other : AdaLayerNormZeroSingle
+         Source AdaLayerNormZeroSingle instance to copy weights and structure from.
+     scale_shift : float, optional
+         Value to add to scale parameters. Default is 1.0.
+         Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
+
+     Notes
+     -----
+     - B: batch size
+     - D: hidden dimension
+     """
+
+     def __init__(self, other: AdaLayerNormZeroSingle, scale_shift: float = 1.0):
+         super(AdaLayerNormZeroSingle, self).__init__()
+         self.scale_shift = scale_shift
+         self.silu = other.silu
+         self.linear = AWQW4A16Linear.from_linear(other.linear)
+         self.norm = other.norm
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         emb: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Forward pass for quantized AdaLayerNormZeroSingle.
+
+         Parameters
+         ----------
+         x : torch.Tensor, shape (B, D), dtype float32/float16
+             Input tensor.
+         emb : Optional[torch.Tensor], shape (B, E), optional
+             Embedding tensor.
+
+         Returns
+         -------
+         norm_x_scaled : torch.Tensor, shape (B, D)
+             Normalized and scaled input.
+         gate_msa : torch.Tensor, shape (B, D)
+             Gate for MSA branch.
+
+         Notes
+         -----
+         - B: batch size
+         - D: hidden dimension
+         """
+         emb = self.linear(self.silu(emb))
+
+         # The quantized linear stores the three modulation tensors interleaved,
+         # so split them with a strided view/permute rather than with chunk.
+         emb = emb.view(emb.shape[0], -1, 3).permute(2, 0, 1)
+         shift_msa, scale_msa, gate_msa = emb
+
+         if self.scale_shift != 0:
+             scale_msa.add_(self.scale_shift)
+
+         norm_x = self.norm(x)
+         norm_x_scaled = norm_x * scale_msa[:, None] + shift_msa[:, None]
+         return norm_x_scaled, gate_msa
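A rough usage sketch of the wrapper pattern these classes follow, assuming a FLUX-sized hidden dimension of 3072 and a checkpoint whose scale_shift is already fused (hence `scale_shift=0`). As with `AWQW4A16Linear.from_linear`, the quantized projection weights would still need to be loaded from a checkpoint before the output is meaningful:

```python
from diffusers.models.normalization import AdaLayerNormZero

# Hypothetical example: wrap a stock diffusers AdaLayerNormZero so its 6*D
# modulation projection runs through the quantized AWQW4A16Linear instead.
base_norm = AdaLayerNormZero(embedding_dim=3072)
quant_norm = NunchakuAdaLayerNormZero(base_norm, scale_shift=0.0)
# quant_norm(x, emb=emb) then returns the same five tensors as the original
# layer: (norm_x_scaled, gate_msa, shift_mlp, scale_mlp, gate_mlp).
```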
nunchaku/models/text_encoders/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .t5_encoder import NunchakuT5EncoderModel
+
+ __all__ = [
+     "NunchakuT5EncoderModel",
+ ]
nunchaku/models/text_encoders/linear.py ADDED
@@ -0,0 +1,238 @@
+ # -*- coding: utf-8 -*-
+ """
+ This module provides the :class:`W4Linear` quantized linear layer, which implements
+ 4-bit weight-only quantization for efficient inference.
+ """
+
+ import torch
+ import torch.nn as nn
+
+ from ..._C.ops import gemm_awq, gemv_awq
+ from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
+
+ __all__ = ["W4Linear"]
+
+
+ class W4Linear(nn.Module):
+     """
+     4-bit quantized linear layer with group-wise quantization.
+
+     Parameters
+     ----------
+     in_features : int
+         Number of input features.
+     out_features : int
+         Number of output features.
+     bias : bool, optional
+         If True, adds a bias term, stored as a buffer (default: False).
+     group_size : int, optional
+         Number of input channels per quantization group (default: 128).
+         If -1, uses the full input dimension as a single group.
+     dtype : torch.dtype, optional
+         Data type for quantization scales and zeros (default: torch.float16).
+     device : str or torch.device, optional
+         Device for weights and buffers (default: "cuda").
+
+     Attributes
+     ----------
+     in_features : int
+     out_features : int
+     group_size : int
+     qweight : torch.Tensor
+         Quantized weight tensor (int16).
+     scales : torch.Tensor
+         Per-group scale tensor.
+     scaled_zeros : torch.Tensor
+         Per-group zero-point tensor (scaled).
+     bias : torch.Tensor or None
+         Optional bias tensor.
+     """
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = False,
+         group_size: int = 128,
+         dtype: torch.dtype = torch.float16,
+         device: str | torch.device = "cuda",
+     ):
+         super().__init__()
+         assert dtype in (torch.float16, torch.bfloat16), f"Unsupported dtype: {dtype}"
+
+         self.in_features = in_features
+         self.out_features = out_features
+         self.group_size = group_size if group_size != -1 else in_features
+         assert self.in_features % self.group_size == 0
+         assert out_features % (32 // self.weight_bits) == 0
+         self.ceil_num_groups = ceil_num_groups(
+             in_features=self.in_features,
+             group_size=self.group_size,
+             weight_bits=self.weight_bits,
+         )
+
+         assert out_features % (self.interleave) == 0
+         self.register_buffer(
+             "qweight",
+             torch.zeros(
+                 (
+                     self.out_features // self.interleave,
+                     self.in_features // (16 // self.weight_bits) * self.interleave,
+                 ),
+                 dtype=torch.int16,
+                 device=device,
+             ),
+         )
+         self.register_buffer(
+             "scales",
+             torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
+         )
+         self.register_buffer(
+             "scaled_zeros",
+             torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
+         )
+         if bias:
+             self.register_buffer("bias", torch.zeros((out_features), dtype=dtype, device=device))
+         else:
+             self.bias = None
+
+     @property
+     def weight_bits(self) -> int:
+         """
+         Number of bits per quantized weight (always 4).
+         """
+         return 4
+
+     @property
+     def interleave(self) -> int:
+         """
+         Interleave factor for quantized weights (always 4).
+         """
+         return 4
+
+     @torch.no_grad()
+     def forward(self, x):
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (..., in_features).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor of shape (..., out_features).
+         """
+         if x.numel() / x.shape[-1] < 8:
+             out = gemv_awq(
+                 x,
+                 self.qweight,
+                 self.scales,
+                 self.scaled_zeros,
+                 x.numel() // x.shape[-1],
+                 self.out_features,
+                 self.in_features,
+                 self.group_size,
+             )
+         else:
+             if self.group_size != 128:
+                 raise NotImplementedError("Kernel currently only supports group_size=128.")
+             out = gemm_awq(x, self.qweight, self.scales, self.scaled_zeros)
+         out = out + self.bias if self.bias is not None else out
+         return out
+
+     @staticmethod
+     def from_linear(
+         linear: nn.Linear,
+         group_size: int,
+         init_only: bool = False,
+         weight: torch.Tensor | None = None,
+         scale: torch.Tensor | None = None,
+         zero: torch.Tensor | None = None,
+         zero_pre_scaled: bool = False,
+     ) -> "W4Linear":
+         """
+         Convert a standard nn.Linear to a quantized W4Linear.
+
+         Parameters
+         ----------
+         linear : nn.Linear
+             The linear layer to convert.
+         group_size : int
+             Quantization group size.
+         init_only : bool, optional
+             If True, only initializes the quantized layer without copying weights (default: False).
+         weight : torch.Tensor, optional
+             Precomputed quantized weight (default: None).
+         scale : torch.Tensor, optional
+             Precomputed scale tensor (default: None).
+         zero : torch.Tensor, optional
+             Precomputed zero-point tensor (default: None).
+         zero_pre_scaled : bool, optional
+             Whether the zero-point tensor is pre-scaled (default: False).
+
+         Returns
+         -------
+         W4Linear
+             Quantized linear layer.
+         """
+         assert isinstance(linear, nn.Linear)
+         weight = linear.weight.data if weight is None else weight.data
+         dtype, device = weight.dtype, weight.device
+         oc, ic = linear.out_features, linear.in_features
+         _linear = W4Linear(
+             in_features=ic,
+             out_features=oc,
+             bias=linear.bias is not None,
+             group_size=group_size,
+             dtype=dtype,
+             device=device,
+         )
+         if init_only:
+             return _linear
+         if linear.bias is not None:
+             _linear.bias.data.copy_(linear.bias.data)
+         if scale is None:
+             assert zero is None, "scale and zero point tensors should be provided together."
+             group_size = ic if group_size <= 0 else group_size
+             assert group_size <= ic, "group size should be less than or equal to input channel size."
+             assert ic % group_size == 0, "input channel size should be divisible by group size."
+             ng, gs = ic // group_size, group_size
+             weight = weight.to(dtype=torch.float32).view(oc, 1, ng, gs)
+             vmin, vmax = weight.amin(dim=-1, keepdim=True), weight.amax(dim=-1, keepdim=True)
+             scale = (vmax - vmin).div_(15)
+             scale[scale == 0] = 1.0
+             if zero_pre_scaled:
+                 zero = vmin.neg_().div_(scale).round_().clamp_(0, 15)
+                 weight = weight.div_(scale).add_(zero).round_().clamp_(0, 15).sub_(zero).mul_(scale)
+             else:
+                 zero = vmin.neg_().clamp_min(0)
+                 weight = weight.add_(zero).div_(scale).round_().clamp_(0, 15).mul_(scale).sub_(zero)
+             weight = weight.to(dtype=dtype).view(oc, ic)
+             scale = scale.to(dtype=dtype)
+             zero = zero.to(dtype=dtype)
+         weight, scale, zero = convert_to_tinychat_w4x16y16_linear_weight(
+             weight=weight,
+             scale=scale,
+             zero=zero,
+             group_size=group_size,
+             zero_pre_scaled=zero_pre_scaled,
+         )
+         _linear.qweight.data.copy_(weight)
+         _linear.scales.data.copy_(scale)
+         _linear.scaled_zeros.data.copy_(zero)
+         return _linear
+
+     def extra_repr(self) -> str:
+         """
+         Returns a string describing the layer configuration.
+         """
+         return "in_features={}, out_features={}, bias={}, weight_bits={}, group_size={}".format(
+             self.in_features,
+             self.out_features,
+             self.bias is not None,
+             self.weight_bits,
+             self.group_size,
+         )
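A small end-to-end sketch of the conversion path, assuming a CUDA device and that the compiled `gemm_awq`/`gemv_awq` kernels are available; the layer shapes are illustrative. When no `scale`/`zero` is supplied, `from_linear` computes per-group min/max scales on the fly:

```python
import torch
import torch.nn as nn

# Hypothetical example: quantize a dense fp16 layer to W4 with 128-wide groups.
dense = nn.Linear(4096, 10240, bias=True, dtype=torch.float16, device="cuda")
w4 = W4Linear.from_linear(dense, group_size=128)

x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y_ref = dense(x)
    y_q = w4(x)  # batch of 2 (< 8 rows), so this dispatches to the gemv_awq kernel
print((y_ref - y_q).abs().mean())  # small quantization error; not bit-exact
```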
nunchaku/models/text_encoders/t5_encoder.py ADDED
@@ -0,0 +1,116 @@
+ """
+ The NunchakuT5EncoderModel class enables loading T5 encoder weights from safetensors files,
+ automatically replacing supported linear layers with quantized :class:`~nunchaku.models.text_encoders.linear.W4Linear`
+ modules for improved performance and memory efficiency.
+ """
+
+ import json
+ import logging
+ import os
+ from pathlib import Path
+
+ import torch
+ from accelerate import init_empty_weights
+ from torch import nn
+ from transformers import T5Config, T5EncoderModel
+
+ from ...utils import load_state_dict_in_safetensors
+ from .linear import W4Linear
+
+ # Get log level from environment variable (default to INFO)
+ log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+ # Configure logging
+ logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+
+ class NunchakuT5EncoderModel(T5EncoderModel):
+     """
+     Nunchaku T5 Encoder Model.
+
+     Extends :class:`transformers.T5EncoderModel` to support quantized weights and
+     memory-efficient inference using :class:`~nunchaku.models.text_encoders.linear.W4Linear`.
+
+     This class provides a convenient interface for loading T5 encoder weights from
+     safetensors files, automatically replacing supported linear layers with quantized
+     modules for improved speed and reduced memory usage.
+
+     Example
+     -------
+     .. code-block:: python
+
+         model = NunchakuT5EncoderModel.from_pretrained(
+             "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+         )
+     """
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
+         """
+         Load a :class:`NunchakuT5EncoderModel` from a safetensors file.
+
+         This method loads the model configuration and weights from a safetensors file,
+         initializes the model on the 'meta' device (no memory allocation for weights),
+         and replaces supported linear layers with quantized :class:`~nunchaku.models.text_encoders.linear.W4Linear` modules.
+
+         Parameters
+         ----------
+         pretrained_model_name_or_path : str or os.PathLike
+             Path to the safetensors file containing the model weights and metadata.
+         torch_dtype : torch.dtype, optional
+             Data type for model initialization (default: ``torch.bfloat16``).
+             Set to ``torch.float16`` for Turing GPUs.
+         device : str or torch.device, optional
+             Device to load the model onto (default: ``"cuda"``).
+             If the model is loaded on CPU, it will be automatically moved to GPU.
+
+         Returns
+         -------
+         NunchakuT5EncoderModel
+             The loaded and quantized T5 encoder model.
+
+         Example
+         -------
+         .. code-block:: python
+
+             model = NunchakuT5EncoderModel.from_pretrained(
+                 "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+             )
+         """
+         pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+         state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
+
+         # Load the config file from metadata
+         config = json.loads(metadata["config"])
+         config = T5Config(**config)
+
+         # Initialize model on 'meta' device (no memory allocation for weights)
+         with init_empty_weights():
+             t5_encoder = T5EncoderModel(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+
+         t5_encoder.eval()
+
+         # Load the model weights from the safetensors file and quantize supported linear layers
+         named_modules = {}
+         for name, module in t5_encoder.named_modules():
+             assert isinstance(name, str)
+             if isinstance(module, nn.Linear):
+                 if f"{name}.qweight" in state_dict:
+                     logger.debug(f"Switching {name} to W4Linear")
+                     qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
+                     # modeling_t5.py: T5DenseGatedActDense needs the dtype of the weight
+                     qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
+
+                     parent_name, child_name = name.rsplit(".", 1)
+                     setattr(named_modules[parent_name], child_name, qmodule)
+             else:
+                 named_modules[name] = module
+
+         device = kwargs.get("device", "cuda")
+         if isinstance(device, str):
+             device = torch.device(device)
+         t5_encoder.to_empty(device=device)
+         t5_encoder.load_state_dict(state_dict, strict=True)
+
+         return t5_encoder
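A typical way to plug this encoder into a FLUX pipeline, as a sketch: the repo id and file name follow the docstring example above, and the snippet assumes diffusers' usual component-override behavior, where `from_pretrained` accepts a preloaded `text_encoder_2`:

```python
import torch
from diffusers import FluxPipeline

# Sketch: load the AWQ INT4 T5 and hand it to the pipeline as text_encoder_2.
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
).to("cuda")
```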
nunchaku/models/text_encoders/tinychat_utils.py ADDED
@@ -0,0 +1,188 @@
+ # -*- coding: utf-8 -*-
+ """
+ This module provides utility functions for quantized linear layers in the TinyChat backend.
+ """
+
+ import torch
+
+ __all__ = ["ceil_num_groups", "convert_to_tinychat_w4x16y16_linear_weight"]
+
+
+ def ceil_divide(x: int, divisor: int) -> int:
+     """
+     Compute the ceiling of integer division.
+
+     Parameters
+     ----------
+     x : int
+         Dividend.
+     divisor : int
+         Divisor.
+
+     Returns
+     -------
+     int
+         The smallest integer greater than or equal to ``x / divisor``.
+     """
+     return (x + divisor - 1) // divisor
+
+
+ def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int:
+     """
+     Calculate the padded number of quantization groups for TinyChat quantization.
+
+     This ensures the number of groups is compatible with TinyChat's packing and kernel requirements.
+
+     Parameters
+     ----------
+     in_features : int
+         Input channel size (number of input features).
+     group_size : int
+         Quantization group size.
+     weight_bits : int, optional
+         Number of bits per quantized weight (default: 4).
+
+     Returns
+     -------
+     int
+         The padded number of quantization groups.
+
+     Raises
+     ------
+     AssertionError
+         If ``in_features`` is not divisible by ``group_size``, or if ``weight_bits`` is not 4, 2, or 1.
+     NotImplementedError
+         If ``group_size`` is not one of the supported values (>=128, 64, 32).
+     """
+     assert in_features % group_size == 0, "input channel size should be divisible by group size."
+     num_groups = in_features // group_size
+     assert weight_bits in (4, 2, 1), "weight bits should be 4, 2, or 1."
+     pack_size = 32 // weight_bits  # one INT32 contains `pack_size` elements of weights
+     num_packs = ceil_divide(num_groups, pack_size)
+     if group_size >= 128:
+         num_packs_factor = 1
+     elif group_size == 64:
+         num_packs_factor = 2
+     elif group_size == 32:
+         num_packs_factor = 4
+     else:
+         raise NotImplementedError("Unsupported group size for TinyChat quantization.")
+     # make sure num_packs is a multiple of num_packs_factor
+     num_packs = ceil_divide(num_packs, num_packs_factor) * num_packs_factor
+     num_groups = num_packs * pack_size
+     return num_groups
+
+
+ def pack_w4(weight: torch.Tensor) -> torch.Tensor:
+     """
+     Pack quantized 4-bit weights into TinyChat's int16 format.
+
+     This function rearranges and packs 4-bit quantized weights (stored as int32) into
+     the format expected by TinyChat CUDA kernels.
+
+     Parameters
+     ----------
+     weight : torch.Tensor
+         Quantized weight tensor of shape (out_features, in_features), dtype int32.
+         The input channel dimension must be divisible by 32.
+
+     Returns
+     -------
+     torch.Tensor
+         Packed weight tensor of dtype int16.
+
+     Raises
+     ------
+     AssertionError
+         If the input tensor is not int32 or the input channel size is not divisible by 32.
+     """
+     assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
+     oc, ic = weight.shape
+     assert ic % 32 == 0, "input channel size should be divisible by 32."
+     # [0, 1, ..., 31] -> [0, 8, 16, 24, 1, 9, 17, 25, ..., 7, 15, 23, 31]
+     weight = weight.view(-1, 4, 8)
+     weight = weight[:, 0] | (weight[:, 1] << 4) | (weight[:, 2] << 8) | (weight[:, 3] << 12)
+     weight = weight.view(oc // 4, 4, ic // 64, 16).permute(0, 2, 1, 3).reshape(oc // 4, ic)
+     return weight.to(torch.int16)
+
+
+ def convert_to_tinychat_w4x16y16_linear_weight(
+     weight: torch.Tensor,
+     scale: torch.Tensor,
+     zero: torch.Tensor,
+     group_size: int = -1,
+     zero_pre_scaled: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Convert a floating-point weight tensor to TinyChat W4-X16-Y16 quantized linear format.
+
+     This function quantizes the input weights to 4 bits per value, applies group-wise
+     scaling and zero-point, and packs the result into the format expected by TinyChat
+     quantized linear layers.
+
+     Parameters
+     ----------
+     weight : torch.Tensor
+         Floating-point weight tensor of shape (out_features, in_features).
+         Must be of dtype ``torch.float16`` or ``torch.bfloat16``.
+     scale : torch.Tensor
+         Per-group scale tensor (can be broadcastable).
+     zero : torch.Tensor
+         Per-group zero-point tensor (can be broadcastable).
+     group_size : int, optional
+         Quantization group size. If set to -1 (default), uses the full input dimension as a single group.
+     zero_pre_scaled : bool, optional
+         If True, the zero tensor is already scaled by the scale tensor (default: False).
+
+     Returns
+     -------
+     tuple of torch.Tensor
+         - packed_weight : torch.Tensor
+             Packed quantized weight tensor (int16).
+         - packed_scale : torch.Tensor
+             Packed scale tensor (shape: [num_groups, out_features], dtype matches input).
+         - packed_zero : torch.Tensor
+             Packed zero-point tensor (shape: [num_groups, out_features], dtype matches input).
+
+     Raises
+     ------
+     AssertionError
+         If input types or shapes are invalid, or quantized values are out of range.
+
+     Example
+     -------
+     .. code-block:: python
+
+         qweight, qscale, qzero = convert_to_tinychat_w4x16y16_linear_weight(
+             weight, scale, zero, group_size=128
+         )
+     """
+     dtype, device = weight.dtype, weight.device
+     assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16."
+     assert scale is not None, "scale tensor is required for quantization."
+     assert zero is not None, "zero point tensor is required for quantization."
+     weight = weight.to(dtype=torch.float32)
+     scale = scale.to(dtype=torch.float32, device=device)
+     zero = zero.to(dtype=torch.float32, device=device)
+     if zero_pre_scaled:
+         zero = zero * scale
+     oc, ic = weight.shape
+     group_size = ic if group_size <= 0 else group_size
+     assert group_size <= ic, "group size should be less than or equal to input channel size."
+     assert ic % group_size == 0, "input channel size should be divisible by group size."
+     ng = ic // group_size
+     if scale.numel() == 1:
+         scale = scale.view(1, 1).expand(oc, ng)
+     scale = scale.reshape(oc, ng).contiguous().view(oc, ng, 1)
+     if zero.numel() == 1:
+         zero = zero.view(1, 1).expand(oc, ng)
+     zero = zero.reshape(oc, ng).contiguous().view(oc, ng, 1)
+     weight = weight.view(oc, ng, -1).add_(zero).div_(scale).round_().view(oc, ic)
+     assert weight.min() >= 0 and weight.max() <= 15, "quantized weight should be in [0, 15]."
+     _weight = pack_w4(weight.to(torch.int32))
+     _ng = ceil_num_groups(ic, group_size, weight_bits=4)
+     _scale = torch.zeros((_ng, oc), dtype=dtype, device=device)
+     _zero = torch.zeros((_ng, oc), dtype=dtype, device=device)
+     _scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype)
+     _zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_()
+     return _weight, _scale, _zero
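To make the group-padding rule concrete, a worked example whose values follow directly from the code above:

```python
# For a projection with in_features=4096 and group_size=128:
#   num_groups = 4096 // 128 = 32
#   pack_size  = 32 // 4 = 8          (eight 4-bit values per INT32)
#   num_packs  = ceil(32 / 8) = 4     (group_size >= 128 -> num_packs_factor = 1)
#   padded groups = 4 * 8 = 32        -> no padding needed
assert ceil_num_groups(4096, 128) == 32

# With group_size=64 the pack count must be even (num_packs_factor = 2):
#   num_groups = 64, num_packs = ceil(64 / 8) = 8, already even -> 8 * 8 = 64
assert ceil_num_groups(4096, 64) == 64
```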
nunchaku/models/transformers/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .transformer_flux import NunchakuFluxTransformer2dModel
+
+ __all__ = [
+     "NunchakuFluxTransformer2dModel",
+ ]
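Before the full implementation below, a short sketch of how the exported class is typically consumed. The repo id is illustrative and the top-level `nunchaku` import path is assumed; `from_pretrained` is defined in the `transformer_flux.py` that follows:

```python
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel  # assumed top-level re-export

# Sketch: load a quantized transformer and swap it into a standard FLUX pipeline.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev"  # hypothetical repo id
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
```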
nunchaku/models/transformers/transformer_flux.py ADDED
@@ -0,0 +1,991 @@
1
+ """
2
+ Implements the :class:`NunchakuFluxTransformer2dModel`, a quantized transformer for Diffusers with efficient inference and LoRA support.
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional, Union
10
+
11
+ import diffusers
12
+ import torch
13
+ from diffusers import FluxTransformer2DModel
14
+ from diffusers.configuration_utils import register_to_config
15
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
+ from huggingface_hub import utils
17
+ from packaging.version import Version
18
+ from safetensors.torch import load_file
19
+ from torch import nn
20
+
21
+ from ..._C import QuantizedFluxModel
22
+ from ..._C import utils as cutils
23
+ from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
24
+ from ...lora.flux.utils import is_nunchaku_format
25
+ from ...utils import check_hardware_compatibility, get_precision, load_state_dict_in_safetensors, pad_tensor
26
+ from .utils import NunchakuModelLoaderMixin
27
+
28
+ SVD_RANK = 32
29
+
30
+ # Get log level from environment variable (default to INFO)
31
+ log_level = os.getenv("LOG_LEVEL", "INFO").upper()
32
+
33
+ # Configure logging
34
+ logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ class NunchakuFluxTransformerBlocks(nn.Module):
39
+ """
40
+ Wrapper for quantized Nunchaku FLUX transformer blocks.
41
+
42
+ This class manages the forward pass, rotary embedding packing, and optional
43
+ residual callbacks for ID embeddings.
44
+
45
+ Parameters
46
+ ----------
47
+ m : QuantizedFluxModel
48
+ The quantized transformer model.
49
+ device : str or torch.device
50
+ Device to run the model on.
51
+ """
52
+
53
+ def __init__(self, m: QuantizedFluxModel, device: str | torch.device):
54
+ super(NunchakuFluxTransformerBlocks, self).__init__()
55
+ self.m = m
56
+ self.dtype = torch.bfloat16 if m.isBF16() else torch.float16
57
+ self.device = device
58
+
59
+ @staticmethod
60
+ def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
61
+ """
62
+ Packs rotary embeddings for efficient computation.
63
+
64
+ Parameters
65
+ ----------
66
+ rotemb : torch.Tensor
67
+ Rotary embedding tensor of shape (B, M, D//2, 1, 2), dtype float32.
68
+
69
+ Returns
70
+ -------
71
+ torch.Tensor
72
+ Packed rotary embedding tensor of shape (B, M, D).
73
+ """
74
+ assert rotemb.dtype == torch.float32
75
+ B = rotemb.shape[0]
76
+ M = rotemb.shape[1]
77
+ D = rotemb.shape[2] * 2
78
+ assert rotemb.shape == (B, M, D // 2, 1, 2)
79
+ assert M % 16 == 0
80
+ assert D % 8 == 0
81
+ rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8)
82
+ rotemb = rotemb.permute(0, 1, 3, 2, 4)
83
+ # 16*8 pack, FP32 accumulator (C) format
84
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
85
+ ##########################################|--M--|--D--|
86
+ ##########################################|-3--4--5--6|
87
+ ########################################## : : : :
88
+ rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
89
+ rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
90
+ rotemb = rotemb.contiguous()
91
+ rotemb = rotemb.view(B, M, D)
92
+ return rotemb
93
+
94
+ def forward(
95
+ self,
96
+ hidden_states: torch.Tensor,
97
+ temb: torch.Tensor,
98
+ encoder_hidden_states: torch.Tensor,
99
+ image_rotary_emb: torch.Tensor,
100
+ id_embeddings=None,
101
+ id_weight=None,
102
+ joint_attention_kwargs=None,
103
+ controlnet_block_samples=None,
104
+ controlnet_single_block_samples=None,
105
+ skip_first_layer=False,
106
+ ):
107
+ """
108
+ Forward pass for the quantized transformer blocks.
109
+ It will call the forward method of ``m`` on the C backend.
110
+
111
+ Parameters
112
+ ----------
113
+ hidden_states : torch.Tensor
114
+ Input hidden states for image tokens.
115
+ temb : torch.Tensor
116
+ Temporal embedding tensor.
117
+ encoder_hidden_states : torch.Tensor
118
+ Input hidden states for text tokens.
119
+ image_rotary_emb : torch.Tensor
120
+ Rotary embedding tensor for all tokens.
121
+ id_embeddings : torch.Tensor, optional
122
+ Optional ID embeddings for residual callback.
123
+ id_weight : float, optional
124
+ Weight for ID embedding residual.
125
+ joint_attention_kwargs : dict, optional
126
+ Additional kwargs for joint attention.
127
+ controlnet_block_samples : list[torch.Tensor], optional
128
+ ControlNet block samples.
129
+ controlnet_single_block_samples : list[torch.Tensor], optional
130
+ ControlNet single block samples.
131
+ skip_first_layer : bool, optional
132
+ Whether to skip the first layer.
133
+
134
+ Returns
135
+ -------
136
+ tuple[torch.Tensor, torch.Tensor]
137
+ (encoder_hidden_states, hidden_states) after transformer blocks.
138
+ """
139
+ # batch_size = hidden_states.shape[0]
140
+ txt_tokens = encoder_hidden_states.shape[1]
141
+ img_tokens = hidden_states.shape[1]
142
+
143
+ self.id_embeddings = id_embeddings
144
+ self.id_weight = id_weight
145
+ self.pulid_ca_idx = 0
146
+ if self.id_embeddings is not None:
147
+ self.set_pulid_residual_callback()
148
+
149
+ original_dtype = hidden_states.dtype
150
+ original_device = hidden_states.device
151
+
152
+ hidden_states = hidden_states.to(self.dtype).to(self.device)
153
+ encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
154
+ temb = temb.to(self.dtype).to(self.device)
155
+ image_rotary_emb = image_rotary_emb.to(self.device)
156
+
157
+ if controlnet_block_samples is not None:
158
+ if len(controlnet_block_samples) > 0:
159
+ controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
160
+ else:
161
+ controlnet_block_samples = None
162
+
163
+ if controlnet_single_block_samples is not None:
164
+ if len(controlnet_single_block_samples) > 0:
165
+ controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
166
+ else:
167
+ controlnet_single_block_samples = None
168
+
169
+ assert image_rotary_emb.ndim == 6
170
+ assert image_rotary_emb.shape[0] == 1
171
+ assert image_rotary_emb.shape[1] == 1
172
+ assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
173
+ # [1, tokens, head_dim / 2, 1, 2] (sincos)
174
+ image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
175
+ rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
176
+ rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
177
+ rotary_emb_single = image_rotary_emb # .to(self.dtype)
178
+
179
+ rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
180
+ rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
181
+ rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
182
+ hidden_states = self.m.forward(
183
+ hidden_states,
184
+ encoder_hidden_states,
185
+ temb,
186
+ rotary_emb_img,
187
+ rotary_emb_txt,
188
+ rotary_emb_single,
189
+ controlnet_block_samples,
190
+ controlnet_single_block_samples,
191
+ skip_first_layer,
192
+ )
193
+
194
+ if self.id_embeddings is not None:
195
+ self.reset_pulid_residual_callback()
196
+
197
+ hidden_states = hidden_states.to(original_dtype).to(original_device)
198
+
199
+ encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
200
+ hidden_states = hidden_states[:, txt_tokens:, ...]
201
+
202
+ return encoder_hidden_states, hidden_states
203
+
204
+ def forward_layer_at(
205
+ self,
206
+ idx: int,
207
+ hidden_states: torch.Tensor,
208
+ encoder_hidden_states: torch.Tensor,
209
+ temb: torch.Tensor,
210
+ image_rotary_emb: torch.Tensor,
211
+ joint_attention_kwargs=None,
212
+ controlnet_block_samples=None,
213
+ controlnet_single_block_samples=None,
214
+ ):
215
+ """
216
+ Forward pass for a specific transformer layer in ``m``.
217
+
218
+ Parameters
219
+ ----------
220
+ idx : int
221
+ Index of the transformer layer.
222
+ hidden_states : torch.Tensor
223
+ Input hidden states for image tokens.
224
+ encoder_hidden_states : torch.Tensor
225
+ Input hidden states for text tokens.
226
+ temb : torch.Tensor
227
+ Temporal embedding tensor.
228
+ image_rotary_emb : torch.Tensor
229
+ Rotary embedding tensor for all tokens.
230
+ joint_attention_kwargs : dict, optional
231
+ Additional kwargs for joint attention.
232
+ controlnet_block_samples : list[torch.Tensor], optional
233
+ ControlNet block samples.
234
+ controlnet_single_block_samples : list[torch.Tensor], optional
235
+ ControlNet single block samples.
236
+
237
+ Returns
238
+ -------
239
+ tuple[torch.Tensor, torch.Tensor]
240
+ (encoder_hidden_states, hidden_states) after the specified layer.
241
+ """
242
+ # batch_size = hidden_states.shape[0]
243
+ txt_tokens = encoder_hidden_states.shape[1]
244
+ img_tokens = hidden_states.shape[1]
245
+
246
+ original_dtype = hidden_states.dtype
247
+ original_device = hidden_states.device
248
+
249
+ hidden_states = hidden_states.to(self.dtype).to(self.device)
250
+ encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
251
+ temb = temb.to(self.dtype).to(self.device)
252
+ image_rotary_emb = image_rotary_emb.to(self.device)
253
+
254
+ if controlnet_block_samples is not None:
255
+ controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
256
+ if controlnet_single_block_samples is not None:
257
+ controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
258
+
259
+ assert image_rotary_emb.ndim == 6
260
+ assert image_rotary_emb.shape[0] == 1
261
+ assert image_rotary_emb.shape[1] == 1
262
+ assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
263
+ # [1, tokens, head_dim / 2, 1, 2] (sincos)
264
+ image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
265
+ rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
266
+ rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
267
+
268
+ rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
269
+ rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
270
+
271
+ hidden_states, encoder_hidden_states = self.m.forward_layer(
272
+ idx,
273
+ hidden_states,
274
+ encoder_hidden_states,
275
+ temb,
276
+ rotary_emb_img,
277
+ rotary_emb_txt,
278
+ controlnet_block_samples,
279
+ controlnet_single_block_samples,
280
+ )
281
+
282
+ hidden_states = hidden_states.to(original_dtype).to(original_device)
283
+ encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device)
284
+
285
+ return encoder_hidden_states, hidden_states
286
+
287
+ def set_pulid_residual_callback(self):
288
+ """
289
+ Sets the residual callback for PulID (personalized ID) embeddings.
290
+ """
291
+ id_embeddings = self.id_embeddings
292
+ pulid_ca = self.pulid_ca
293
+ pulid_ca_idx = [self.pulid_ca_idx]
294
+ id_weight = self.id_weight
295
+
296
+ def callback(hidden_states):
297
+ ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states)
298
+ pulid_ca_idx[0] += 1
299
+ return ip
300
+
301
+ self.callback_holder = callback
302
+ self.m.set_residual_callback(callback)
303
+
304
+ def reset_pulid_residual_callback(self):
305
+ """
306
+ Resets the PulID residual callback to None.
307
+ """
308
+ self.callback_holder = None
309
+ self.m.set_residual_callback(None)
310
+
311
+ def __del__(self):
312
+ """
313
+ Destructor to reset the quantized model.
314
+ """
315
+ self.m.reset()
316
+
317
+ def norm1(
318
+ self,
319
+ hidden_states: torch.Tensor,
320
+ emb: torch.Tensor,
321
+ idx: int = 0,
322
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
323
+ """
324
+ Runs the norm_one_forward for a specific layer in ``m``.
325
+
326
+ Parameters
327
+ ----------
328
+ hidden_states : torch.Tensor
329
+ Input hidden states.
330
+ emb : torch.Tensor
331
+ Embedding tensor.
332
+ idx : int, optional
333
+ Layer index (default: 0).
334
+
335
+ Returns
336
+ -------
337
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
338
+ Output tensors from norm_one_forward.
339
+ """
340
+ return self.m.norm_one_forward(idx, hidden_states, emb)
341
+
342
+
343
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
344
+ """
345
+ Rotary positional embedding function.
346
+
347
+ Parameters
348
+ ----------
349
+ pos : torch.Tensor
350
+ Position tensor of shape (..., n).
351
+ dim : int
352
+ Embedding dimension (must be even).
353
+ theta : int
354
+ Rotary base.
355
+
356
+ Returns
357
+ -------
358
+ torch.Tensor
359
+ Rotary embedding tensor.
360
+ """
361
+ assert dim % 2 == 0, "The dimension must be even."
362
+
363
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
364
+ omega = 1.0 / (theta**scale)
365
+
366
+ batch_size, seq_length = pos.shape
367
+ out = torch.einsum("...n,d->...nd", pos, omega)
368
+
369
+ USE_SINCOS = True
370
+ if USE_SINCOS:
371
+ cos_out = torch.cos(out)
372
+ sin_out = torch.sin(out)
373
+ stacked_out = torch.stack([sin_out, cos_out], dim=-1)
374
+ out = stacked_out.view(batch_size, -1, dim // 2, 1, 2)
375
+ else:
376
+ out = out.view(batch_size, -1, dim // 2, 1, 1)
377
+
378
+ return out.float()
379
+
380
+
381
+ class EmbedND(nn.Module):
382
+ """
383
+ Multi-dimensional rotary embedding module.
384
+
385
+ Parameters
386
+ ----------
387
+ dim : int
388
+ Embedding dimension.
389
+ theta : int
390
+ Rotary base.
391
+ axes_dim : list[int]
392
+ List of axis dimensions for each spatial axis.
393
+ """
394
+
395
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
396
+ super(EmbedND, self).__init__()
397
+ self.dim = dim
398
+ self.theta = theta
399
+ self.axes_dim = axes_dim
400
+
401
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
402
+ """
403
+ Computes rotary embeddings for multi-dimensional positions.
404
+
405
+ Parameters
406
+ ----------
407
+ ids : torch.Tensor
408
+ Position indices tensor of shape (..., n_axes).
409
+
410
+ Returns
411
+ -------
412
+ torch.Tensor
413
+ Rotary embedding tensor.
414
+ """
415
+ if Version(diffusers.__version__) >= Version("0.31.0"):
416
+ ids = ids[None, ...]
417
+ n_axes = ids.shape[-1]
418
+ emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
419
+ return emb.unsqueeze(1)
420
+
421
+
422
+ def load_quantized_module(
423
+ path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
424
+ device: str | torch.device = "cuda",
425
+ use_fp4: bool = False,
426
+ offload: bool = False,
427
+ bf16: bool = True,
428
+ ) -> QuantizedFluxModel:
429
+ """
430
+ Loads a quantized Nunchaku FLUX model from a state dict or file.
431
+
432
+ Parameters
433
+ ----------
434
+ path_or_state_dict : str, os.PathLike, or dict
435
+ Path to the quantized model file or a state dict.
436
+ device : str or torch.device, optional
437
+ Device to load the model on (default: "cuda").
438
+ use_fp4 : bool, optional
439
+ Whether to use FP4 quantization (default: False).
440
+ offload : bool, optional
441
+ Whether to offload weights to CPU (default: False).
442
+ bf16 : bool, optional
443
+ Whether to use bfloat16 (default: True).
444
+
445
+ Returns
446
+ -------
447
+ QuantizedFluxModel
448
+ Loaded quantized model.
449
+ """
450
+ device = torch.device(device)
451
+ assert device.type == "cuda"
452
+ m = QuantizedFluxModel()
453
+ cutils.disable_memory_auto_release()
454
+ m.init(use_fp4, offload, bf16, 0 if device.index is None else device.index)
455
+ if isinstance(path_or_state_dict, dict):
456
+ m.loadDict(path_or_state_dict, True)
457
+ else:
458
+ m.load(str(path_or_state_dict))
459
+ return m
460
+
461
+
462
+ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
463
+ """
464
+ Nunchaku FLUX Transformer 2D Model.
465
+
466
+ This class implements a quantized transformer model compatible with the Diffusers
467
+ library, supporting LoRA, rotary embeddings, and efficient inference.
468
+
469
+ Parameters
470
+ ----------
471
+ patch_size : int, optional
472
+ Patch size for input images (default: 1).
473
+ in_channels : int, optional
474
+ Number of input channels (default: 64).
475
+ out_channels : int or None, optional
476
+ Number of output channels (default: None).
477
+ num_layers : int, optional
478
+ Number of transformer layers (default: 19).
479
+ num_single_layers : int, optional
480
+ Number of single transformer layers (default: 38).
481
+ attention_head_dim : int, optional
482
+ Dimension of each attention head (default: 128).
483
+ num_attention_heads : int, optional
484
+ Number of attention heads (default: 24).
485
+ joint_attention_dim : int, optional
486
+ Joint attention dimension (default: 4096).
487
+ pooled_projection_dim : int, optional
488
+ Pooled projection dimension (default: 768).
489
+ guidance_embeds : bool, optional
490
+ Whether to use guidance embeddings (default: False).
491
+ axes_dims_rope : tuple[int], optional
492
+ Axes dimensions for rotary embeddings (default: (16, 56, 56)).
493
+ """
494
+
495
+ @register_to_config
496
+ def __init__(
497
+ self,
498
+ patch_size: int = 1,
499
+ in_channels: int = 64,
500
+ out_channels: int | None = None,
501
+ num_layers: int = 19,
502
+ num_single_layers: int = 38,
503
+ attention_head_dim: int = 128,
504
+ num_attention_heads: int = 24,
505
+ joint_attention_dim: int = 4096,
506
+ pooled_projection_dim: int = 768,
507
+ guidance_embeds: bool = False,
508
+ axes_dims_rope: tuple[int] = (16, 56, 56),
509
+ ):
510
+ super(NunchakuFluxTransformer2dModel, self).__init__(
511
+ patch_size=patch_size,
512
+ in_channels=in_channels,
513
+ out_channels=out_channels,
514
+ num_layers=num_layers,
515
+ num_single_layers=num_single_layers,
516
+ attention_head_dim=attention_head_dim,
517
+ num_attention_heads=num_attention_heads,
518
+ joint_attention_dim=joint_attention_dim,
519
+ pooled_projection_dim=pooled_projection_dim,
520
+ guidance_embeds=guidance_embeds,
521
+ axes_dims_rope=axes_dims_rope,
522
+ )
523
+ # these state_dicts are used for supporting lora
524
+ self._unquantized_part_sd: dict[str, torch.Tensor] = {}
525
+ self._unquantized_part_loras: dict[str, torch.Tensor] = {}
526
+ self._quantized_part_sd: dict[str, torch.Tensor] = {}
527
+ self._quantized_part_vectors: dict[str, torch.Tensor] = {}
528
+ self._original_in_channels = in_channels
529
+
530
+ # ComfyUI LoRA related
531
+ self.comfy_lora_meta_list = []
532
+ self.comfy_lora_sd_list = []
533
+
534
+ @classmethod
535
+ @utils.validate_hf_hub_args
536
+ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
537
+ """
538
+ Loads a Nunchaku FLUX transformer model from pretrained weights.
539
+
540
+ Parameters
541
+ ----------
542
+ pretrained_model_name_or_path : str or os.PathLike
543
+ Path to the model directory or HuggingFace repo.
544
+ **kwargs
545
+ Additional keyword arguments for device, offload, torch_dtype, precision, etc.
546
+
547
+ Returns
548
+ -------
549
+ NunchakuFluxTransformer2dModel or (NunchakuFluxTransformer2dModel, dict)
550
+ The loaded model, and optionally metadata if `return_metadata=True`.
551
+ """
552
+ device = kwargs.get("device", "cuda")
553
+ if isinstance(device, str):
554
+ device = torch.device(device)
555
+ offload = kwargs.get("offload", False)
556
+ torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
557
+ precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
558
+ metadata = None
559
+
560
+ if isinstance(pretrained_model_name_or_path, str):
561
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
562
+ if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
563
+ (".safetensors", ".sft")
564
+ ):
565
+ transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
566
+ quantized_part_sd = {}
567
+ unquantized_part_sd = {}
568
+ for k, v in model_state_dict.items():
569
+ if k.startswith(("transformer_blocks.", "single_transformer_blocks.")):
570
+ quantized_part_sd[k] = v
571
+ else:
572
+ unquantized_part_sd[k] = v
573
+ precision = get_precision(device=device)
574
+ quantization_config = json.loads(metadata["quantization_config"])
575
+ check_hardware_compatibility(quantization_config, device)
576
+ else:
577
+ transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
578
+ pretrained_model_name_or_path, **kwargs
579
+ )
580
+
581
+ # get the default LoRA branch and all the vectors
582
+ quantized_part_sd = load_file(transformer_block_path)
583
+ unquantized_part_sd = load_file(unquantized_part_path)
584
+ new_quantized_part_sd = {}
585
+ for k, v in quantized_part_sd.items():
586
+ if v.ndim == 1:
587
+ new_quantized_part_sd[k] = v
588
+ elif "qweight" in k:
589
+ # only the shape information of this tensor is needed
590
+ new_quantized_part_sd[k] = v.to("meta")
591
+
592
+ # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors
593
+ for t in ["lora_up", "lora_down"]:
594
+ new_k = k.replace(".qweight", f".{t}")
595
+ if new_k not in quantized_part_sd:
596
+ oc, ic = v.shape
597
+ ic = ic * 2 # v is packed into INT8, so we need to double the size
598
+ new_quantized_part_sd[k.replace(".qweight", f".{t}")] = torch.zeros(
599
+ (0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16
600
+ )
601
+
602
+ elif "lora" in k:
603
+ new_quantized_part_sd[k] = v
604
+ transformer._quantized_part_sd = new_quantized_part_sd
605
+ m = load_quantized_module(
606
+ quantized_part_sd,
607
+ device=device,
608
+ use_fp4=precision == "fp4",
609
+ offload=offload,
610
+ bf16=torch_dtype == torch.bfloat16,
611
+ )
612
+ transformer.inject_quantized_module(m, device)
613
+ transformer.to_empty(device=device)
614
+
615
+ transformer.load_state_dict(unquantized_part_sd, strict=False)
616
+ transformer._unquantized_part_sd = unquantized_part_sd
617
+
618
+ if kwargs.get("return_metadata", False):
619
+ return transformer, metadata
620
+ else:
621
+ return transformer
622
+
623
+ def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
624
+ """
625
+ Injects a quantized module into the model and sets up transformer blocks.
626
+
627
+ Parameters
628
+ ----------
629
+ m : QuantizedFluxModel
630
+ The quantized transformer model.
631
+ device : str or torch.device, optional
632
+ Device to run the model on (default: "cuda").
633
+
634
+ Returns
635
+ -------
636
+ self : NunchakuFluxTransformer2dModel
637
+ The model with injected quantized module.
638
+ """
639
+ print("Injecting quantized module")
640
+ self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
641
+
642
+ ### Compatible with the original forward method
643
+ self.transformer_blocks = nn.ModuleList([NunchakuFluxTransformerBlocks(m, device)])
644
+ self.single_transformer_blocks = nn.ModuleList([])
645
+
646
+ return self
647
+
648
+ def set_attention_impl(self, impl: str):
649
+ """
650
+ Set the attention implementation for the quantized transformer block.
651
+
652
+ Parameters
653
+ ----------
654
+ impl : str
655
+ Attention implementation to use. Supported values:
656
+
657
+ - ``"flashattn2"`` (default): Standard FlashAttention-2.
658
+ - ``"nunchaku-fp16"``: Uses FP16 attention accumulation, up to 1.2× faster than FlashAttention-2 on NVIDIA 30-, 40-, and 50-series GPUs.
659
+ """
660
+ block = self.transformer_blocks[0]
661
+ assert isinstance(block, NunchakuFluxTransformerBlocks)
662
+ block.m.setAttentionImpl(impl)
663
+
664
+ ### LoRA Related Functions
665
+
666
+ def _expand_module(self, module_name: str, new_shape: tuple[int, int]):
667
+ """
668
+ Expands a linear module to a new shape for LoRA compatibility.
669
+ Mostly for FLUX.1-tools LoRA which changes the input channels.
670
+
671
+ Parameters
672
+ ----------
673
+ module_name : str
674
+ Name of the module to expand.
675
+ new_shape : tuple[int, int]
676
+ New shape (out_features, in_features) for the module.
677
+ """
678
+ module = self.get_submodule(module_name)
679
+ assert isinstance(module, nn.Linear)
680
+ weight_shape = module.weight.shape
681
+ logger.info("Expand the shape of module {} from {} to {}".format(module_name, tuple(weight_shape), new_shape))
682
+ assert new_shape[0] >= weight_shape[0] and new_shape[1] >= weight_shape[1]
683
+ new_module = nn.Linear(
684
+ new_shape[1],
685
+ new_shape[0],
686
+ bias=module.bias is not None,
687
+ device=module.weight.device,
688
+ dtype=module.weight.dtype,
689
+ )
690
+ new_module.weight.data.zero_()
691
+ new_module.weight.data[: weight_shape[0], : weight_shape[1]] = module.weight.data
692
+ self._unquantized_part_sd[f"{module_name}.weight"] = new_module.weight.data.clone()
693
+ if new_module.bias is not None:
694
+ new_module.bias.data.zero_()
695
+ new_module.bias.data[: weight_shape[0]] = module.bias.data
696
+ self._unquantized_part_sd[f"{module_name}.bias"] = new_module.bias.data.clone()
697
+ parent_name = ".".join(module_name.split(".")[:-1])
698
+ parent_module = self.get_submodule(parent_name)
699
+ parent_module.add_module(module_name.split(".")[-1], new_module)
700
+
701
+ if module_name == "x_embedder":
702
+ new_value = int(new_module.weight.data.shape[1])
703
+ old_value = getattr(self.config, "in_channels")
704
+ if new_value != old_value:
705
+ logger.info(f"Update in_channels from {old_value} to {new_value}")
706
+ setattr(self.config, "in_channels", new_value)
707
+
708
+ def _update_unquantized_part_lora_params(self, strength: float = 1):
709
+ """
710
+ Updates the unquantized part of the model with LoRA parameters.
711
+
712
+ Parameters
713
+ ----------
714
+ strength : float, optional
715
+ LoRA scaling strength (default: 1).
716
+ """
717
+ # check if we need to expand the linear layers
718
+ device = next(self.parameters()).device
719
+ for k, v in self._unquantized_part_loras.items():
720
+ if "lora_A" in k:
721
+ lora_a = v
722
+ lora_b = self._unquantized_part_loras[k.replace(".lora_A.", ".lora_B.")]
723
+ diff_shape = (lora_b.shape[0], lora_a.shape[1])
724
+ weight_shape = self._unquantized_part_sd[k.replace(".lora_A.", ".")].shape
725
+ if diff_shape[0] > weight_shape[0] or diff_shape[1] > weight_shape[1]:
726
+ module_name = ".".join(k.split(".")[:-2])
727
+ self._expand_module(module_name, diff_shape)
728
+ elif v.ndim == 1:
729
+ diff_shape = v.shape
730
+ weight_shape = self._unquantized_part_sd[k].shape
731
+ if diff_shape[0] > weight_shape[0]:
732
+ assert diff_shape[0] >= weight_shape[0]
733
+ module_name = ".".join(k.split(".")[:-1])
734
+ module = self.get_submodule(module_name)
735
+ weight_shape = module.weight.shape
736
+ diff_shape = (diff_shape[0], weight_shape[1])
737
+ self._expand_module(module_name, diff_shape)
738
+ new_state_dict = {}
739
+ for k in self._unquantized_part_sd.keys():
740
+ v = self._unquantized_part_sd[k]
741
+ v = v.to(device)
742
+ self._unquantized_part_sd[k] = v
743
+
744
+ if v.ndim == 1 and k in self._unquantized_part_loras:
745
+ diff = strength * self._unquantized_part_loras[k]
746
+ if diff.shape[0] < v.shape[0]:
747
+ diff = torch.cat(
748
+ [diff, torch.zeros(v.shape[0] - diff.shape[0], device=device, dtype=v.dtype)], dim=0
749
+ )
750
+ new_state_dict[k] = v + diff
751
+ elif v.ndim == 2 and k.replace(".weight", ".lora_B.weight") in self._unquantized_part_loras:
752
+ lora_a = self._unquantized_part_loras[k.replace(".weight", ".lora_A.weight")]
753
+ lora_b = self._unquantized_part_loras[k.replace(".weight", ".lora_B.weight")]
754
+
755
+ if lora_a.shape[1] < v.shape[1]:
756
+ lora_a = torch.cat(
757
+ [
758
+ lora_a,
759
+ torch.zeros(lora_a.shape[0], v.shape[1] - lora_a.shape[1], device=device, dtype=v.dtype),
760
+ ],
761
+ dim=1,
762
+ )
763
+ if lora_b.shape[0] < v.shape[0]:
764
+ lora_b = torch.cat(
765
+ [
766
+ lora_b,
767
+ torch.zeros(v.shape[0] - lora_b.shape[0], lora_b.shape[1], device=device, dtype=v.dtype),
768
+ ],
769
+ dim=0,
770
+ )
771
+
772
+ diff = strength * (lora_b @ lora_a)
773
+ new_state_dict[k] = v + diff
774
+ else:
775
+ new_state_dict[k] = v
776
+ self.load_state_dict(new_state_dict, strict=True)
777
+
778
+ def update_lora_params(self, path_or_state_dict: str | dict[str, torch.Tensor]):
779
+ """
780
+ Update the model with new LoRA parameters.
781
+
782
+ Parameters
783
+ ----------
784
+ path_or_state_dict : str or dict
785
+ Path to a LoRA weights file or a state dict. The path supports:
786
+
787
+ - Local file path, e.g., ``"/path/to/your/lora.safetensors"``
788
+ - HuggingFace repo with file, e.g., ``"user/repo/lora.safetensors"``
789
+ (automatically downloaded and cached)
790
+ """
791
+ if isinstance(path_or_state_dict, dict):
792
+ state_dict = {
793
+ k: v for k, v in path_or_state_dict.items()
794
+ } # copy a new one to avoid modifying the original one
795
+ else:
796
+ state_dict = load_state_dict_in_safetensors(path_or_state_dict)
797
+
798
+ if not is_nunchaku_format(state_dict):
799
+ state_dict = to_nunchaku(state_dict, base_sd=self._quantized_part_sd)
800
+
801
+ unquantized_part_loras = {}
802
+ for k, v in list(state_dict.items()):
803
+ device = next(self.parameters()).device
804
+ if "transformer_blocks" not in k:
805
+ unquantized_part_loras[k] = state_dict.pop(k).to(device)
806
+
807
+ if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
808
+ self._unquantized_part_loras = unquantized_part_loras
809
+
810
+ self._unquantized_part_sd = {k: v for k, v in self._unquantized_part_sd.items() if "pulid_ca" not in k}
811
+ self._update_unquantized_part_lora_params(1)
812
+
813
+ quantized_part_vectors = {}
814
+ for k, v in list(state_dict.items()):
815
+ if v.ndim == 1:
816
+ quantized_part_vectors[k] = state_dict.pop(k)
817
+ if len(self._quantized_part_vectors) > 0 or len(quantized_part_vectors) > 0:
818
+ self._quantized_part_vectors = quantized_part_vectors
819
+ updated_vectors = fuse_vectors(quantized_part_vectors, self._quantized_part_sd, 1)
820
+ state_dict.update(updated_vectors)
821
+
822
+ # Push the remaining quantized-part LoRA weights into the C++ backend.
823
+
824
+ block = self.transformer_blocks[0]
825
+ assert isinstance(block, NunchakuFluxTransformerBlocks)
826
+
827
+ block.m.loadDict(state_dict, True)
828
+
829
+ def set_lora_strength(self, strength: float = 1):
830
+ """
831
+ Sets the LoRA scaling strength for the model.
832
+
833
+ Note: This adjusts the strength of *all* loaded LoRAs, so call it only when a
834
+ single LoRA is loaded. For multiple LoRAs, fuse each LoRA's scale into its weights instead.
835
+
836
+ Parameters
837
+ ----------
838
+ strength : float, optional
839
+ LoRA scaling strength (default: 1).
842
+ """
843
+ block = self.transformer_blocks[0]
844
+ assert isinstance(block, NunchakuFluxTransformerBlocks)
845
+ block.m.setLoraScale(SVD_RANK, strength)
846
+ if len(self._unquantized_part_loras) > 0:
847
+ self._update_unquantized_part_lora_params(strength)
848
+ if len(self._quantized_part_vectors) > 0:
849
+ vector_dict = fuse_vectors(self._quantized_part_vectors, self._quantized_part_sd, strength)
850
+ block.m.loadDict(vector_dict, True)
851
+
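A minimal call sequence for `update_lora_params` and `set_lora_strength` above (the LoRA path is a placeholder):

```python
# Placeholder path: a local .safetensors file or a HuggingFace "user/repo/file" path.
transformer.update_lora_params("user/repo/lora.safetensors")
transformer.set_lora_strength(0.8)  # only meaningful while a single LoRA is loaded
```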
852
+ def reset_x_embedder(self):
853
+ """
854
+ Resets the x_embedder module if the input channel count has changed.
855
+ This is used for removing the effect of FLUX.1-tools LoRA which changes the input channels.
856
+ """
857
+ # if change the model in channels, we need to update the x_embedder
858
+ if self._original_in_channels != self.config.in_channels:
859
+ assert self._original_in_channels < self.config.in_channels
860
+ old_module = self.x_embedder
861
+ new_module = nn.Linear(
862
+ in_features=self._original_in_channels,
863
+ out_features=old_module.out_features,
864
+ bias=old_module.bias is not None,
865
+ device=old_module.weight.device,
866
+ dtype=old_module.weight.dtype,
867
+ )
868
+ new_module.weight.data.copy_(old_module.weight.data[: new_module.out_features, : new_module.in_features])
869
+ self._unquantized_part_sd["x_embedder.weight"] = new_module.weight.data.clone()
870
+ if new_module.bias is not None:
871
+ new_module.bias.data.zero_()
872
+ new_module.bias.data.copy_(old_module.bias.data[: new_module.out_features])
873
+ self._unquantized_part_sd["x_embedder.bias"] = new_module.bias.data.clone()
874
+ self.x_embedder = new_module
875
+ setattr(self.config, "in_channels", self._original_in_channels)
876
+
877
+ def reset_lora(self):
878
+ """
879
+ Resets all LoRA parameters to their default state.
880
+ """
881
+ unquantized_part_loras = {}
882
+ if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
883
+ self._unquantized_part_loras = unquantized_part_loras
884
+ self._update_unquantized_part_lora_params(1)
885
+ state_dict = {k: v for k, v in self._quantized_part_sd.items() if "lora" in k}
886
+ quantized_part_vectors = {}
887
+ if len(self._quantized_part_vectors) > 0 or len(quantized_part_vectors) > 0:
888
+ self._quantized_part_vectors = quantized_part_vectors
889
+ updated_vectors = fuse_vectors(quantized_part_vectors, self._quantized_part_sd, 1)
890
+ state_dict.update(updated_vectors)
891
+ self.transformer_blocks[0].m.loadDict(state_dict, True)
892
+ self.reset_x_embedder()
893
+
894
+ def forward(
895
+ self,
896
+ hidden_states: torch.Tensor,
897
+ encoder_hidden_states: torch.Tensor = None,
898
+ pooled_projections: torch.Tensor = None,
899
+ timestep: torch.LongTensor = None,
900
+ img_ids: torch.Tensor = None,
901
+ txt_ids: torch.Tensor = None,
902
+ guidance: torch.Tensor = None,
903
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
904
+ controlnet_block_samples=None,
905
+ controlnet_single_block_samples=None,
906
+ return_dict: bool = True,
907
+ controlnet_blocks_repeat: bool = False,
908
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
909
+ """
910
+ Forward pass for the Nunchaku FLUX transformer model.
911
+
912
+ This method is compatible with the Diffusers pipeline and supports LoRA,
913
+ rotary embeddings, and ControlNet.
914
+
915
+ Parameters
916
+ ----------
917
+ hidden_states : torch.FloatTensor
918
+ Input hidden states of shape (batch_size, image_sequence_length, in_channels).
919
+ encoder_hidden_states : torch.FloatTensor, optional
920
+ Conditional embeddings (e.g., prompt embeddings) of shape (batch_size, sequence_len, embed_dims).
921
+ pooled_projections : torch.FloatTensor, optional
922
+ Embeddings projected from the input conditions.
923
+ timestep : torch.LongTensor, optional
924
+ Denoising step.
925
+ img_ids : torch.Tensor, optional
926
+ Image token indices.
927
+ txt_ids : torch.Tensor, optional
928
+ Text token indices.
929
+ guidance : torch.Tensor, optional
930
+ Guidance tensor for classifier-free guidance.
931
+ joint_attention_kwargs : dict, optional
932
+ Additional kwargs for joint attention.
933
+ controlnet_block_samples : list[torch.Tensor], optional
934
+ ControlNet block samples.
935
+ controlnet_single_block_samples : list[torch.Tensor], optional
936
+ ControlNet single block samples.
937
+ return_dict : bool, optional
938
+ Whether to return a Transformer2DModelOutput (default: True).
939
+ controlnet_blocks_repeat : bool, optional
940
+ Whether to repeat ControlNet blocks (default: False).
941
+
942
+ Returns
943
+ -------
944
+ torch.FloatTensor or Transformer2DModelOutput
945
+ Output tensor or output object containing the sample.
946
+ """
947
+ hidden_states = self.x_embedder(hidden_states)
948
+
949
+ timestep = timestep.to(hidden_states.dtype) * 1000
950
+ if guidance is not None:
951
+ guidance = guidance.to(hidden_states.dtype) * 1000
954
+
955
+ temb = (
956
+ self.time_text_embed(timestep, pooled_projections)
957
+ if guidance is None
958
+ else self.time_text_embed(timestep, guidance, pooled_projections)
959
+ )
960
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
961
+
962
+ if txt_ids.ndim == 3:
963
+ txt_ids = txt_ids[0]
964
+ if img_ids.ndim == 3:
965
+ img_ids = img_ids[0]
966
+
967
+ ids = torch.cat((txt_ids, img_ids), dim=0)
968
+ image_rotary_emb = self.pos_embed(ids)
969
+
970
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
971
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
972
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
973
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
974
+
975
+ nunchaku_block = self.transformer_blocks[0]
976
+ encoder_hidden_states, hidden_states = nunchaku_block(
977
+ hidden_states=hidden_states,
978
+ encoder_hidden_states=encoder_hidden_states,
979
+ temb=temb,
980
+ image_rotary_emb=image_rotary_emb,
981
+ joint_attention_kwargs=joint_attention_kwargs,
982
+ controlnet_block_samples=controlnet_block_samples,
983
+ controlnet_single_block_samples=controlnet_single_block_samples,
984
+ )
985
+ hidden_states = self.norm_out(hidden_states, temb)
986
+ output = self.proj_out(hidden_states)
987
+
988
+ if not return_dict:
989
+ return (output,)
990
+
991
+ return Transformer2DModelOutput(sample=output)
nunchaku/models/transformers/transformer_flux_v2.py ADDED
@@ -0,0 +1,646 @@
1
+ """
2
+ This module provides the Nunchaku FluxTransformer2DModelV2 and its building blocks in Python.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
12
+ from diffusers.models.transformers.transformer_flux import (
13
+ FluxAttention,
14
+ FluxSingleTransformerBlock,
15
+ FluxTransformer2DModel,
16
+ FluxTransformerBlock,
17
+ )
18
+ from huggingface_hub import utils
19
+ from torch.nn import GELU
20
+
21
+ from ...ops.fused import fused_gelu_mlp
22
+ from ...utils import get_precision, pad_tensor
23
+ from ..attention import NunchakuBaseAttention, NunchakuFeedForward
24
+ from ..attention_processors.flux import NunchakuFluxFA2Processor, NunchakuFluxFP16AttnProcessor
25
+ from ..embeddings import NunchakuFluxPosEmbed, pack_rotemb
26
+ from ..linear import SVDQW4A4Linear
27
+ from ..normalization import NunchakuAdaLayerNormZero, NunchakuAdaLayerNormZeroSingle
28
+ from ..utils import fuse_linears
29
+ from .utils import NunchakuModelLoaderMixin
30
+
31
+
32
+ class NunchakuFluxAttention(NunchakuBaseAttention):
33
+ """
34
+ Nunchaku-optimized FluxAttention module with quantized and fused QKV projections.
35
+
36
+ Parameters
37
+ ----------
38
+ other : FluxAttention
39
+ The original FluxAttention module to wrap and quantize.
40
+ processor : str, optional
41
+ The attention processor to use ("flashattn2" or "nunchaku-fp16").
42
+ **kwargs
43
+ Additional arguments for quantization.
44
+ """
45
+
46
+ def __init__(self, other: FluxAttention, processor: str = "flashattn2", **kwargs):
47
+ super(NunchakuFluxAttention, self).__init__(processor)
48
+ self.head_dim = other.head_dim
49
+ self.inner_dim = other.inner_dim
50
+ self.query_dim = other.query_dim
51
+ self.use_bias = other.use_bias
52
+ self.dropout = other.dropout
53
+ self.out_dim = other.out_dim
54
+ self.context_pre_only = other.context_pre_only
55
+ self.pre_only = other.pre_only
56
+ self.heads = other.heads
57
+ self.added_kv_proj_dim = other.added_kv_proj_dim
58
+ self.added_proj_bias = other.added_proj_bias
59
+
60
+ self.norm_q = other.norm_q
61
+ self.norm_k = other.norm_k
62
+
63
+ # Fuse the QKV projections for efficiency.
64
+ with torch.device("meta"):
65
+ to_qkv = fuse_linears([other.to_q, other.to_k, other.to_v])
66
+ self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, **kwargs)
67
+
68
+ if not self.pre_only:
69
+ self.to_out = other.to_out
70
+ self.to_out[0] = SVDQW4A4Linear.from_linear(self.to_out[0], **kwargs)
71
+
72
+ if self.added_kv_proj_dim is not None:
73
+ self.norm_added_q = other.norm_added_q
74
+ self.norm_added_k = other.norm_added_k
75
+
76
+ # Fuse the additional QKV projections.
77
+ with torch.device("meta"):
78
+ add_qkv_proj = fuse_linears([other.add_q_proj, other.add_k_proj, other.add_v_proj])
79
+ self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, **kwargs)
80
+ self.to_add_out = SVDQW4A4Linear.from_linear(other.to_add_out, **kwargs)
81
+
82
+ def forward(
83
+ self,
84
+ hidden_states: torch.Tensor,
85
+ encoder_hidden_states: Optional[torch.Tensor] = None,
86
+ attention_mask: Optional[torch.Tensor] = None,
87
+ image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor = None,
88
+ **kwargs,
89
+ ):
90
+ """
91
+ Forward pass for NunchakuFluxAttention.
92
+
93
+ Parameters
94
+ ----------
95
+ hidden_states : torch.Tensor
96
+ Input tensor.
97
+ encoder_hidden_states : torch.Tensor, optional
98
+ Encoder hidden states for cross-attention.
99
+ attention_mask : torch.Tensor, optional
100
+ Attention mask.
101
+ image_rotary_emb : tuple or torch.Tensor, optional
102
+ Rotary embeddings for image/text tokens.
103
+ **kwargs
104
+ Additional arguments.
105
+
106
+ Returns
107
+ -------
108
+ Output of the attention processor.
109
+ """
110
+ return self.processor(
111
+ attn=self,
112
+ hidden_states=hidden_states,
113
+ encoder_hidden_states=encoder_hidden_states,
114
+ attention_mask=attention_mask,
115
+ image_rotary_emb=image_rotary_emb,
116
+ )
117
+
118
+ def set_processor(self, processor: str):
119
+ """
120
+ Set the attention processor.
121
+
122
+ Parameters
123
+ ----------
124
+ processor : str
125
+ Name of the processor ("flashattn2" or "nunchaku-fp16").
126
+
127
+ - ``"flashattn2"``: Standard FlashAttention-2. See :class:`~nunchaku.models.attention_processors.flux.NunchakuFluxFA2Processor`.
128
+ - ``"nunchaku-fp16"``: Uses FP16 attention accumulation, up to 1.2× faster than FlashAttention-2 on NVIDIA 30-, 40-, and 50-series GPUs. See :class:`~nunchaku.models.attention_processors.flux.NunchakuFluxFP16AttnProcessor`.
129
+
130
+ Raises
131
+ ------
132
+ ValueError
133
+ If the processor is not supported.
134
+ """
135
+ if processor == "flashattn2":
136
+ self.processor = NunchakuFluxFA2Processor()
137
+ elif processor == "nunchaku-fp16":
138
+ self.processor = NunchakuFluxFP16AttnProcessor()
139
+ else:
140
+ raise ValueError(f"Processor {processor} is not supported")
141
+
142
+
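The constructors above rely on `fuse_linears` to concatenate the Q, K, and V projections so that a single GEMM produces all three outputs. A minimal sketch of the idea (the real helper lives in `nunchaku.models.utils`; this version is illustrative only):

```python
import torch
import torch.nn as nn

def fuse_qkv(q: nn.Linear, k: nn.Linear, v: nn.Linear) -> nn.Linear:
    # Stack the output rows so one matmul yields [Q; K; V]; callers split the result.
    out_features = q.out_features + k.out_features + v.out_features
    fused = nn.Linear(q.in_features, out_features, bias=q.bias is not None)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
        if fused.bias is not None:
            fused.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
    return fused
```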
143
+ class NunchakuFluxTransformerBlock(FluxTransformerBlock):
144
+ """
145
+ Nunchaku-optimized FluxTransformerBlock with quantized attention and feedforward layers.
146
+
147
+ Parameters
148
+ ----------
149
+ block : FluxTransformerBlock
150
+ The original block to wrap and quantize.
151
+ scale_shift : float, optional
152
+ Value to add to scale parameters. Default is 1.0.
153
+ Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
154
+ **kwargs
155
+ Additional arguments for quantization.
156
+ """
157
+
158
+ def __init__(self, block: FluxTransformerBlock, scale_shift: float = 1, **kwargs):
159
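+ # Initialize nn.Module directly (skipping FluxTransformerBlock.__init__) so the original full-precision submodules are never allocated.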
+ super(FluxTransformerBlock, self).__init__()
160
+ self.scale_shift = scale_shift
161
+
162
+ # When the checkpoint already folds AdaLayerNormZero's scale_shift=1 into the modulation
163
+ # linear weights, the caller passes scale_shift=0 so the shift is not applied twice.
164
+ self.norm1 = NunchakuAdaLayerNormZero(block.norm1, scale_shift=scale_shift)
165
+ self.norm1_context = NunchakuAdaLayerNormZero(block.norm1_context, scale_shift=scale_shift)
166
+
167
+ self.attn = NunchakuFluxAttention(block.attn, **kwargs)
168
+ self.norm2 = block.norm2
169
+ self.norm2_context = block.norm2_context
170
+ self.ff = NunchakuFeedForward(block.ff, **kwargs)
171
+ self.ff_context = NunchakuFeedForward(block.ff_context, **kwargs)
172
+
173
+ def forward(
174
+ self,
175
+ hidden_states: torch.Tensor,
176
+ encoder_hidden_states: torch.Tensor,
177
+ temb: torch.Tensor,
178
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
179
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
180
+ ):
181
+ """
182
+ Forward pass for the transformer block.
183
+
184
+ Parameters
185
+ ----------
186
+ hidden_states : torch.Tensor
187
+ Input hidden states.
188
+ encoder_hidden_states : torch.Tensor
189
+ Encoder hidden states for cross-attention.
190
+ temb : torch.Tensor
191
+ Time or conditioning embedding.
192
+ image_rotary_emb : tuple of torch.Tensor, optional
193
+ Rotary embeddings for image/text tokens.
194
+ joint_attention_kwargs : dict, optional
195
+ Additional attention arguments (not supported).
196
+
197
+ Returns
198
+ -------
199
+ tuple
200
+ (encoder_hidden_states, hidden_states) after block processing.
201
+
202
+ Raises
203
+ ------
204
+ NotImplementedError
205
+ If joint_attention_kwargs is provided.
206
+ """
207
+ if joint_attention_kwargs is not None and len(joint_attention_kwargs) > 0:
208
+ raise NotImplementedError("joint_attention_kwargs is not supported")
209
+
210
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
211
+
212
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
213
+ encoder_hidden_states, emb=temb
214
+ )
215
+
216
+ joint_attention_kwargs = joint_attention_kwargs or {}
217
+
218
+ # Attention.
219
+ attention_outputs = self.attn(
220
+ hidden_states=norm_hidden_states,
221
+ encoder_hidden_states=norm_encoder_hidden_states,
222
+ image_rotary_emb=image_rotary_emb,
223
+ **joint_attention_kwargs,
224
+ )
225
+
226
+ if len(attention_outputs) == 2:
227
+ attn_output, context_attn_output = attention_outputs
228
+ elif len(attention_outputs) == 3:
229
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
230
+
231
+ # Process attention outputs for the `hidden_states`.
232
+ attn_output = gate_msa.unsqueeze(1) * attn_output
233
+ hidden_states = hidden_states + attn_output
234
+
235
+ norm_hidden_states = self.norm2(hidden_states)
236
+ norm_hidden_states = norm_hidden_states * scale_mlp[:, None] + shift_mlp[:, None]
237
+
238
+ ff_output = self.ff(norm_hidden_states)
239
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
240
+
241
+ hidden_states = hidden_states + ff_output
242
+ if len(attention_outputs) == 3:
243
+ hidden_states = hidden_states + ip_attn_output
244
+
245
+ # Process attention outputs for the `encoder_hidden_states`.
246
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
247
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
248
+
249
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
250
+ norm_encoder_hidden_states = norm_encoder_hidden_states * c_scale_mlp[:, None] + c_shift_mlp[:, None]
251
+
252
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
253
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
254
+ if encoder_hidden_states.dtype == torch.float16:
255
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
256
+
257
+ return encoder_hidden_states, hidden_states
258
+
259
+
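Why passing `scale_shift=0` is safe when the checkpoint already folds AdaLN-Zero's "+1" into the modulation weights: the two formulations are numerically identical. A toy check:

```python
import torch

norm_x = torch.randn(2, 16)
shift, scale = torch.randn(2, 16), torch.randn(2, 16)

runtime_shift = norm_x * (scale + 1) + shift        # scale_shift = 1 applied at runtime
fused_weights = norm_x * ((scale + 1) + 0) + shift  # "+1" baked into the checkpoint, scale_shift = 0
assert torch.allclose(runtime_shift, fused_weights)
```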
260
+ class NunchakuFluxSingleTransformerBlock(FluxSingleTransformerBlock):
261
+ """
262
+ Nunchaku-optimized single transformer block with quantized attention and MLP.
263
+
264
+ Parameters
265
+ ----------
266
+ block : FluxSingleTransformerBlock
267
+ The original block to wrap and quantize.
268
+ scale_shift : float, optional
269
+ Value to add to scale parameters. Default is 1.0.
270
+ Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
271
+ **kwargs
272
+ Additional arguments for quantization.
273
+ """
274
+
275
+ def __init__(self, block: FluxSingleTransformerBlock, scale_shift: float = 1, **kwargs):
276
+ super(FluxSingleTransformerBlock, self).__init__()
277
+ self.mlp_hidden_dim = block.mlp_hidden_dim
279
+ self.norm = NunchakuAdaLayerNormZeroSingle(block.norm, scale_shift=scale_shift)
280
+
281
+ self.mlp_fc1 = SVDQW4A4Linear.from_linear(block.proj_mlp, **kwargs)
282
+ self.act_mlp = block.act_mlp
283
+ self.mlp_fc2 = SVDQW4A4Linear.from_linear(block.proj_out, in_features=self.mlp_hidden_dim, **kwargs)
284
+ # For int4, we shift the activation of mlp_fc2 to make it unsigned.
285
+ self.mlp_fc2.act_unsigned = self.mlp_fc2.precision != "nvfp4"
286
+
287
+ self.attn = NunchakuFluxAttention(block.attn, **kwargs)
288
+ self.attn.to_out = SVDQW4A4Linear.from_linear(block.proj_out, in_features=self.mlp_fc1.in_features, **kwargs)
289
+
290
+ def forward(
291
+ self,
292
+ hidden_states: torch.Tensor,
293
+ temb: torch.Tensor,
294
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
295
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
296
+ ) -> torch.Tensor:
297
+ """
298
+ Forward pass for the single transformer block.
299
+
300
+ Parameters
301
+ ----------
302
+ hidden_states : torch.Tensor
303
+ Input hidden states.
304
+ temb : torch.Tensor
305
+ Time or conditioning embedding.
306
+ image_rotary_emb : tuple of torch.Tensor, optional
307
+ Rotary embeddings for tokens.
308
+ joint_attention_kwargs : dict, optional
309
+ Additional attention arguments.
310
+
311
+ Returns
312
+ -------
313
+ torch.Tensor
314
+ Output hidden states after block processing.
315
+ """
316
+ residual = hidden_states
317
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
318
+
319
+ # Feedforward
320
+ if isinstance(self.act_mlp, GELU):
321
+ # Use fused GELU MLP for efficiency.
322
+ mlp_hidden_states = fused_gelu_mlp(norm_hidden_states, self.mlp_fc1, self.mlp_fc2)
323
+ else:
324
+ # Fallback to original MLP.
325
+ mlp_hidden_states = self.mlp_fc1(norm_hidden_states)
326
+ mlp_hidden_states = self.act_mlp(mlp_hidden_states)
327
+ mlp_hidden_states = self.mlp_fc2(mlp_hidden_states)
328
+
329
+ # Attention
330
+ joint_attention_kwargs = joint_attention_kwargs or {}
331
+ attn_output = self.attn(
332
+ hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs
333
+ )
334
+
335
+ hidden_states = attn_output + mlp_hidden_states
336
+ gate = gate.unsqueeze(1)
337
+ hidden_states = gate * hidden_states
338
+ hidden_states = residual + hidden_states
339
+ if hidden_states.dtype == torch.float16:
340
+ hidden_states = hidden_states.clip(-65504, 65504)
341
+
342
+ return hidden_states
343
+
344
+
345
+ class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoaderMixin):
346
+ """
347
+ Nunchaku-optimized FluxTransformer2DModel.
348
+ """
349
+
350
+ def _patch_model(self, **kwargs):
351
+ """
352
+ Patch the model with :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformerBlock`
353
+ and :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxSingleTransformerBlock`.
354
+
355
+ Parameters
356
+ ----------
357
+ **kwargs
358
+ Additional arguments for quantization.
359
+
360
+ Returns
361
+ -------
362
+ self : NunchakuFluxTransformer2DModelV2
363
+ The patched model.
364
+ """
365
+ self.pos_embed = NunchakuFluxPosEmbed(dim=self.inner_dim, theta=10000, axes_dim=self.pos_embed.axes_dim)
366
+ for i, block in enumerate(self.transformer_blocks):
367
+ self.transformer_blocks[i] = NunchakuFluxTransformerBlock(block, scale_shift=0, **kwargs)
368
+ for i, block in enumerate(self.single_transformer_blocks):
369
+ self.single_transformer_blocks[i] = NunchakuFluxSingleTransformerBlock(block, scale_shift=0, **kwargs)
370
+ return self
371
+
372
+ @classmethod
373
+ @utils.validate_hf_hub_args
374
+ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
375
+ """
376
+ Load a pretrained NunchakuFluxTransformer2DModelV2 from a safetensors file.
377
+
378
+ Parameters
379
+ ----------
380
+ pretrained_model_name_or_path : str or os.PathLike
381
+ Path to the safetensors file. It can be a local file or a remote HuggingFace path.
382
+ **kwargs
383
+ Additional arguments (e.g., device, torch_dtype).
384
+
385
+ Returns
386
+ -------
387
+ NunchakuFluxTransformer2DModelV2
388
+ The loaded and quantized model.
389
+
390
+ Raises
391
+ ------
392
+ NotImplementedError
393
+ If offload is requested.
394
+ AssertionError
395
+ If the file is not a safetensors file.
396
+ """
397
+ device = kwargs.get("device", "cpu")
398
+ offload = kwargs.get("offload", False)
399
+
400
+ if offload:
401
+ raise NotImplementedError("Offload is not supported for FluxTransformer2DModelV2")
402
+
403
+ torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
404
+
405
+ if isinstance(pretrained_model_name_or_path, str):
406
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
407
+
408
+ assert pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
409
+ (".safetensors", ".sft")
410
+ ), "Only safetensors are supported"
411
+ transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
412
+ quantization_config = json.loads(metadata.get("quantization_config", "{}"))
413
+ rank = quantization_config.get("rank", 32)
414
+ transformer = transformer.to(torch_dtype)
415
+
416
+ precision = get_precision()
417
+ if precision == "fp4":
418
+ precision = "nvfp4"
419
+ transformer._patch_model(precision=precision, rank=rank)
420
+
421
+ transformer = transformer.to_empty(device=device)
422
+ converted_state_dict = convert_flux_state_dict(model_state_dict)
423
+
424
+ state_dict = transformer.state_dict()
425
+
426
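+ # Any key missing from the checkpoint should be a ".wcscales" per-channel weight scale; default it to ones.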
+ for k in state_dict.keys():
427
+ if k not in converted_state_dict:
428
+ assert ".wcscales" in k
429
+ converted_state_dict[k] = torch.ones_like(state_dict[k])
430
+ else:
431
+ assert state_dict[k].dtype == converted_state_dict[k].dtype
432
+
433
+ # Load the wtscale from the converted state dict.
434
+ for n, m in transformer.named_modules():
435
+ if isinstance(m, SVDQW4A4Linear):
436
+ if m.wtscale is not None:
437
+ m.wtscale = converted_state_dict.pop(f"{n}.wtscale", 1.0)
438
+
439
+ transformer.load_state_dict(converted_state_dict)
440
+
441
+ return transformer
442
+
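A minimal loading sketch for the method above (the checkpoint path is a placeholder; the rank is read from the file's metadata and the precision from `get_precision`):

```python
import torch
from nunchaku.models.transformers.transformer_flux_v2 import NunchakuFluxTransformer2DModelV2

transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
    "path/to/svdq-flux-checkpoint.safetensors",  # placeholder path
    device="cuda",
    torch_dtype=torch.bfloat16,
)
```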
443
+ def forward(
444
+ self,
445
+ hidden_states: torch.Tensor,
446
+ encoder_hidden_states: torch.Tensor = None,
447
+ pooled_projections: torch.Tensor = None,
448
+ timestep: torch.LongTensor = None,
449
+ img_ids: torch.Tensor = None,
450
+ txt_ids: torch.Tensor = None,
451
+ guidance: torch.Tensor = None,
452
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
453
+ controlnet_block_samples=None,
454
+ controlnet_single_block_samples=None,
455
+ return_dict: bool = True,
456
+ controlnet_blocks_repeat: bool = False,
457
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
458
+ """
459
+ Forward pass for the NunchakuFluxTransformer2DModelV2.
460
+
461
+ Parameters
462
+ ----------
463
+ hidden_states : torch.Tensor
464
+ Input hidden states of shape (batch_size, image_sequence_length, in_channels).
465
+ encoder_hidden_states : torch.Tensor, optional
466
+ Conditional embeddings (e.g., from text).
467
+ pooled_projections : torch.Tensor, optional
468
+ Projected embeddings from input conditions.
469
+ timestep : torch.LongTensor, optional
470
+ Denoising step.
471
+ img_ids : torch.Tensor, optional
472
+ Image token IDs.
473
+ txt_ids : torch.Tensor, optional
474
+ Text token IDs.
475
+ guidance : torch.Tensor, optional
476
+ Guidance tensor for classifier-free guidance.
477
+ joint_attention_kwargs : dict, optional
478
+ Additional attention arguments.
479
+ controlnet_block_samples : any, optional
480
+ Not supported.
481
+ controlnet_single_block_samples : any, optional
482
+ Not supported.
483
+ return_dict : bool, optional
484
+ Whether to return a Transformer2DModelOutput (default: True).
485
+ controlnet_blocks_repeat : bool, optional
486
+ Not supported.
487
+
488
+ Returns
489
+ -------
490
+ Transformer2DModelOutput or tuple
491
+ Output sample tensor or output tuple.
492
+
493
+ Raises
494
+ ------
495
+ NotImplementedError
496
+ If controlnet is requested.
497
+ """
498
+ hidden_states = self.x_embedder(hidden_states)
499
+
500
+ timestep = timestep.to(hidden_states.dtype) * 1000
501
+ if guidance is not None:
502
+ guidance = guidance.to(hidden_states.dtype) * 1000
503
+
504
+ temb = (
505
+ self.time_text_embed(timestep, pooled_projections)
506
+ if guidance is None
507
+ else self.time_text_embed(timestep, guidance, pooled_projections)
508
+ )
509
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
510
+
511
+ if txt_ids.ndim == 3:
512
+ txt_ids = txt_ids[0]
513
+ if img_ids.ndim == 3:
514
+ img_ids = img_ids[0]
515
+
516
+ ids = torch.cat((txt_ids, img_ids), dim=0)
517
+ image_rotary_emb = self.pos_embed(ids)
518
+
519
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
520
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
521
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
522
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
523
+
524
+ txt_tokens = encoder_hidden_states.shape[1]
525
+ img_tokens = hidden_states.shape[1]
526
+
527
+ assert image_rotary_emb.ndim == 6
528
+ assert image_rotary_emb.shape[0] == 1
529
+ assert image_rotary_emb.shape[1] == 1
530
+ assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
531
+ # [1, tokens, head_dim / 2, 1, 2] (sincos)
532
+ image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
533
+ rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
534
+ rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
535
+ rotary_emb_single = image_rotary_emb
536
+
537
+ rotary_emb_txt = pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
538
+ rotary_emb_img = pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
539
+ rotary_emb_single = pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
540
+
541
+ for index_block, block in enumerate(self.transformer_blocks):
542
+ encoder_hidden_states, hidden_states = block(
543
+ hidden_states=hidden_states,
544
+ encoder_hidden_states=encoder_hidden_states,
545
+ temb=temb,
546
+ image_rotary_emb=(rotary_emb_img, rotary_emb_txt),
547
+ joint_attention_kwargs=joint_attention_kwargs,
548
+ )
549
+
550
+ # Controlnet residual (not supported for now)
551
+ if controlnet_block_samples is not None:
552
+ raise NotImplementedError("Controlnet is not supported for FluxTransformer2DModelV2 for now")
553
+
554
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
555
+ for index_block, block in enumerate(self.single_transformer_blocks):
556
+ hidden_states = block(
557
+ hidden_states=hidden_states,
558
+ temb=temb,
559
+ image_rotary_emb=rotary_emb_single,
560
+ joint_attention_kwargs=joint_attention_kwargs,
561
+ )
562
+
563
+ # Controlnet residual (not supported for now)
564
+ if controlnet_single_block_samples is not None:
565
+ raise NotImplementedError("Controlnet is not supported for FluxTransformer2DModelV2 for now")
566
+
567
+ hidden_states = hidden_states[:, txt_tokens:]
568
+ hidden_states = self.norm_out(hidden_states, temb)
569
+ output = self.proj_out(hidden_states)
570
+
571
+ if not return_dict:
572
+ return (output,)
573
+
574
+ return Transformer2DModelOutput(sample=output)
575
+
576
+
577
+ def convert_flux_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
578
+ """
579
+ Convert a state dict from the :class:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel`
580
+ format to :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformer2DModelV2` format.
581
+
582
+ Parameters
583
+ ----------
584
+ state_dict : dict[str, torch.Tensor]
585
+ The original state dict.
586
+
587
+ Returns
588
+ -------
589
+ dict[str, torch.Tensor]
590
+ The converted state dict compatible with :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformer2DModelV2`.
591
+ """
592
+ new_state_dict = {}
593
+ for k, v in state_dict.items():
594
+ if "single_transformer_blocks." in k:
595
+ if ".qkv_proj." in k:
596
+ new_k = k.replace(".qkv_proj.", ".attn.to_qkv.")
597
+ elif ".out_proj." in k:
598
+ new_k = k.replace(".out_proj.", ".attn.to_out.")
599
+ elif ".norm_q." in k or ".norm_k." in k:
600
+ new_k = k.replace(".norm_k.", ".attn.norm_k.")
601
+ new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
602
+ else:
603
+ new_k = k
604
+ new_k = new_k.replace(".lora_down", ".proj_down")
605
+ new_k = new_k.replace(".lora_up", ".proj_up")
606
+ if ".smooth_orig" in k:
607
+ new_k = new_k.replace(".smooth_orig", ".smooth_factor_orig")
608
+ elif ".smooth" in k:
609
+ new_k = new_k.replace(".smooth", ".smooth_factor")
610
+ new_state_dict[new_k] = v
611
+ elif "transformer_blocks." in k:
612
+ if ".mlp_context_fc1" in k:
613
+ new_k = k.replace(".mlp_context_fc1.", ".ff_context.net.0.proj.")
614
+ elif ".mlp_context_fc2" in k:
615
+ new_k = k.replace(".mlp_context_fc2.", ".ff_context.net.2.")
616
+ elif ".mlp_fc1" in k:
617
+ new_k = k.replace(".mlp_fc1.", ".ff.net.0.proj.")
618
+ elif ".mlp_fc2" in k:
619
+ new_k = k.replace(".mlp_fc2.", ".ff.net.2.")
620
+ elif ".qkv_proj_context." in k:
621
+ new_k = k.replace(".qkv_proj_context.", ".attn.add_qkv_proj.")
622
+ elif ".qkv_proj." in k:
623
+ new_k = k.replace(".qkv_proj.", ".attn.to_qkv.")
624
+ elif ".norm_q." in k or ".norm_k." in k:
625
+ new_k = k.replace(".norm_k.", ".attn.norm_k.")
626
+ new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
627
+ elif ".norm_added_q." in k or ".norm_added_k." in k:
628
+ new_k = k.replace(".norm_added_k.", ".attn.norm_added_k.")
629
+ new_k = new_k.replace(".norm_added_q.", ".attn.norm_added_q.")
630
+ elif ".out_proj." in k:
631
+ new_k = k.replace(".out_proj.", ".attn.to_out.0.")
632
+ elif ".out_proj_context." in k:
633
+ new_k = k.replace(".out_proj_context.", ".attn.to_add_out.")
634
+ else:
635
+ new_k = k
636
+ new_k = new_k.replace(".lora_down", ".proj_down")
637
+ new_k = new_k.replace(".lora_up", ".proj_up")
638
+ if ".smooth_orig" in k:
639
+ new_k = new_k.replace(".smooth_orig", ".smooth_factor_orig")
640
+ elif ".smooth" in k:
641
+ new_k = new_k.replace(".smooth", ".smooth_factor")
642
+ new_state_dict[new_k] = v
643
+ else:
644
+ new_state_dict[k] = v
645
+
646
+ return new_state_dict
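Two illustrative mappings (hypothetical block indices and shapes) showing what the function produces:

```python
import torch

sd = {
    "transformer_blocks.0.mlp_fc1.lora_down": torch.randn(32, 3072),
    "single_transformer_blocks.3.qkv_proj.smooth": torch.randn(3072),
}
out = convert_flux_state_dict(sd)
# -> "transformer_blocks.0.ff.net.0.proj.proj_down"
# -> "single_transformer_blocks.3.attn.to_qkv.smooth_factor"
```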
nunchaku/models/transformers/transformer_qwenimage.py ADDED
@@ -0,0 +1,601 @@
1
+ """
2
+ This module provides implementations of NunchakuQwenImageTransformer2DModel and its building blocks.
3
+ """
4
+
5
+ import gc
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+ from warnings import warn
11
+
12
+ import torch
13
+ from diffusers.models.attention_processor import Attention
14
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
15
+ from diffusers.models.transformers.transformer_qwenimage import (
16
+ QwenEmbedRope,
17
+ QwenImageTransformer2DModel,
18
+ QwenImageTransformerBlock,
19
+ )
20
+ from huggingface_hub import utils
21
+
22
+ from ...utils import get_precision
23
+ from ..attention import NunchakuBaseAttention, NunchakuFeedForward
24
+ from ..attention_processors.qwenimage import NunchakuQwenImageNaiveFA2Processor
25
+ from ..linear import AWQW4A16Linear, SVDQW4A4Linear
26
+ from ..utils import CPUOffloadManager, fuse_linears
27
+ from .utils import NunchakuModelLoaderMixin
28
+
29
+
30
+ class NunchakuQwenAttention(NunchakuBaseAttention):
31
+ """
32
+ Nunchaku-optimized quantized attention module for QwenImage.
33
+
34
+ Parameters
35
+ ----------
36
+ other : Attention
37
+ The original QwenImage Attention module to wrap and quantize.
38
+ processor : str, default="flashattn2"
39
+ The attention processor to use.
40
+ **kwargs
41
+ Additional arguments for quantization.
42
+ """
43
+
44
+ def __init__(self, other: Attention, processor: str = "flashattn2", **kwargs):
45
+ super(NunchakuQwenAttention, self).__init__(processor)
46
+ self.inner_dim = other.inner_dim
47
+ self.inner_kv_dim = other.inner_kv_dim
48
+ self.query_dim = other.query_dim
49
+ self.use_bias = other.use_bias
50
+ self.is_cross_attention = other.is_cross_attention
51
+ self.cross_attention_dim = other.cross_attention_dim
52
+ self.upcast_attention = other.upcast_attention
53
+ self.upcast_softmax = other.upcast_softmax
54
+ self.rescale_output_factor = other.rescale_output_factor
55
+ self.residual_connection = other.residual_connection
56
+ self.dropout = other.dropout
57
+ self.fused_projections = other.fused_projections
58
+ self.out_dim = other.out_dim
59
+ self.out_context_dim = other.out_context_dim
60
+ self.context_pre_only = other.context_pre_only
61
+ self.pre_only = other.pre_only
62
+ self.is_causal = other.is_causal
63
+ self.scale_qk = other.scale_qk
64
+ self.scale = other.scale
65
+ self.heads = other.heads
66
+ self.sliceable_head_dim = other.sliceable_head_dim
67
+ self.added_kv_proj_dim = other.added_kv_proj_dim
68
+ self.only_cross_attention = other.only_cross_attention
69
+ self.group_norm = other.group_norm
70
+ self.spatial_norm = other.spatial_norm
71
+
72
+ self.norm_cross = other.norm_cross
73
+
74
+ self.norm_q = other.norm_q
75
+ self.norm_k = other.norm_k
76
+ self.norm_added_q = other.norm_added_q
77
+ self.norm_added_k = other.norm_added_k
78
+
79
+ # Fuse the QKV projections for quantization
80
+ with torch.device("meta"):
81
+ to_qkv = fuse_linears([other.to_q, other.to_k, other.to_v])
82
+ self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, **kwargs)
83
+ self.to_out = other.to_out
84
+ self.to_out[0] = SVDQW4A4Linear.from_linear(self.to_out[0], **kwargs)
85
+
86
+ assert self.added_kv_proj_dim is not None
87
+ # Fuse the additional QKV projections
88
+ with torch.device("meta"):
89
+ add_qkv_proj = fuse_linears([other.add_q_proj, other.add_k_proj, other.add_v_proj])
90
+ self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, **kwargs)
91
+ self.to_add_out = SVDQW4A4Linear.from_linear(other.to_add_out, **kwargs)
92
+
93
+ def forward(
94
+ self,
95
+ hidden_states: torch.FloatTensor,
96
+ encoder_hidden_states: torch.FloatTensor = None,
97
+ encoder_hidden_states_mask: torch.FloatTensor = None,
98
+ attention_mask: Optional[torch.FloatTensor] = None,
99
+ image_rotary_emb: Optional[torch.Tensor] = None,
100
+ **kwargs,
101
+ ):
102
+ """
103
+ Forward pass for NunchakuQwenAttention.
104
+
105
+ Parameters
106
+ ----------
107
+ hidden_states : torch.FloatTensor
108
+ Image stream input.
109
+ encoder_hidden_states : torch.FloatTensor, optional
110
+ Text stream input.
111
+ encoder_hidden_states_mask : torch.FloatTensor, optional
112
+ Mask for encoder hidden states.
113
+ attention_mask : torch.FloatTensor, optional
114
+ Attention mask.
115
+ image_rotary_emb : torch.Tensor, optional
116
+ Rotary embedding for images.
117
+ **kwargs
118
+ Additional arguments.
119
+
120
+ Returns
121
+ -------
122
+ tuple
123
+ Attention outputs for image and text streams.
124
+ """
125
+ return self.processor(
126
+ self,
127
+ hidden_states,
128
+ encoder_hidden_states,
129
+ encoder_hidden_states_mask,
130
+ attention_mask,
131
+ image_rotary_emb,
132
+ **kwargs,
133
+ )
134
+
135
+ def set_processor(self, processor: str):
136
+ """
137
+ Set the attention processor.
138
+
139
+ Parameters
140
+ ----------
141
+ processor : str
142
+ Name of the processor to use. Only "flashattn2" is supported for now. See :class:`~nunchaku.models.attention_processors.qwenimage.NunchakuQwenImageNaiveFA2Processor`.
143
+
144
+ Raises
145
+ ------
146
+ ValueError
147
+ If the processor is not supported.
148
+ """
149
+ if processor == "flashattn2":
150
+ self.processor = NunchakuQwenImageNaiveFA2Processor()
151
+ else:
152
+ raise ValueError(f"Processor {processor} is not supported")
153
+
154
+
155
+ class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock):
156
+ """
157
+ Quantized QwenImage Transformer Block.
158
+
159
+ This block supports quantized linear layers and joint attention for image and text streams.
160
+
161
+ Parameters
162
+ ----------
163
+ other : QwenImageTransformerBlock
164
+ The original transformer block to wrap and quantize.
165
+ scale_shift : float, default=1.0
166
+ Value to add to scale parameters. Default is 1.0.
167
+ Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
168
+ **kwargs
169
+ Additional arguments for quantization.
170
+ """
171
+
172
+ def __init__(self, other: QwenImageTransformerBlock, scale_shift: float = 1.0, **kwargs):
173
+ super(QwenImageTransformerBlock, self).__init__()
174
+
175
+ self.dim = other.dim
176
+ self.img_mod = other.img_mod
177
+ self.img_mod[1] = AWQW4A16Linear.from_linear(other.img_mod[1], **kwargs)
178
+ self.img_norm1 = other.img_norm1
179
+ self.attn = NunchakuQwenAttention(other.attn, **kwargs)
180
+ self.img_norm2 = other.img_norm2
181
+ self.img_mlp = NunchakuFeedForward(other.img_mlp, **kwargs)
182
+
183
+ # Text processing modules
184
+ self.txt_mod = other.txt_mod
185
+ self.txt_mod[1] = AWQW4A16Linear.from_linear(other.txt_mod[1], **kwargs)
186
+ self.txt_norm1 = other.txt_norm1
187
+ # Text doesn't need separate attention - it's handled by img_attn joint computation
188
+ self.txt_norm2 = other.txt_norm2
189
+ self.txt_mlp = NunchakuFeedForward(other.txt_mlp, **kwargs)
190
+
191
+ self.scale_shift = scale_shift
192
+
193
+ def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
194
+ """
195
+ Apply modulation to input tensor.
196
+
197
+ Parameters
198
+ ----------
199
+ x : torch.Tensor
200
+ Input tensor.
201
+ mod_params : torch.Tensor
202
+ Modulation parameters.
203
+
204
+ Returns
205
+ -------
206
+ tuple
207
+ Modulated tensor and gate tensor.
208
+ """
209
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
210
+ if self.scale_shift != 0:
211
+ scale.add_(self.scale_shift)
212
+ return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
213
+
214
+ def forward(
215
+ self,
216
+ hidden_states: torch.Tensor,
217
+ encoder_hidden_states: torch.Tensor,
218
+ encoder_hidden_states_mask: torch.Tensor,
219
+ temb: torch.Tensor,
220
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
221
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
222
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
223
+ """
224
+ Forward pass for NunchakuQwenImageTransformerBlock.
225
+
226
+ Parameters
227
+ ----------
228
+ hidden_states : torch.Tensor
229
+ Image stream input.
230
+ encoder_hidden_states : torch.Tensor
231
+ Text stream input.
232
+ encoder_hidden_states_mask : torch.Tensor
233
+ Mask for encoder hidden states.
234
+ temb : torch.Tensor
235
+ Temporal embedding.
236
+ image_rotary_emb : tuple of torch.Tensor, optional
237
+ Rotary embedding for images.
238
+ joint_attention_kwargs : dict, optional
239
+ Additional arguments for joint attention.
240
+
241
+ Returns
242
+ -------
243
+ tuple
244
+ Updated encoder_hidden_states and hidden_states.
245
+ """
246
+ # Get modulation parameters for both streams
247
+ img_mod_params = self.img_mod(temb) # [B, 6*dim]
248
+ txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
249
+
250
+ # img_mod/txt_mod return the six modulation params interleaved per channel ([B, dim, 6] flattened); reorder them into six contiguous dim-sized blocks so the chunk() calls below split shift/scale/gate correctly.
251
+ img_mod_params = (
252
+ img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
253
+ )
254
+ txt_mod_params = (
255
+ txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
256
+ )
257
+
258
+ img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
259
+ txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
260
+
261
+ # Process image stream - norm1 + modulation
262
+ img_normed = self.img_norm1(hidden_states)
263
+ img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
264
+
265
+ # Process text stream - norm1 + modulation
266
+ txt_normed = self.txt_norm1(encoder_hidden_states)
267
+ txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
268
+
269
+ joint_attention_kwargs = joint_attention_kwargs or {}
270
+ attn_output = self.attn(
271
+ hidden_states=img_modulated,
272
+ encoder_hidden_states=txt_modulated,
273
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
274
+ image_rotary_emb=image_rotary_emb,
275
+ **joint_attention_kwargs,
276
+ )
277
+
278
+ # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
279
+ img_attn_output, txt_attn_output = attn_output
280
+
281
+ # Apply attention gates and add residual (like in Megatron)
282
+ hidden_states = hidden_states + img_gate1 * img_attn_output
283
+ encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
284
+
285
+ # Process image stream - norm2 + MLP
286
+ img_normed2 = self.img_norm2(hidden_states)
287
+ img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
288
+ img_mlp_output = self.img_mlp(img_modulated2)
289
+ hidden_states = hidden_states + img_gate2 * img_mlp_output
290
+
291
+ # Process text stream - norm2 + MLP
292
+ txt_normed2 = self.txt_norm2(encoder_hidden_states)
293
+ txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
294
+ txt_mlp_output = self.txt_mlp(txt_modulated2)
295
+ encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
296
+
297
+ # Clip to prevent overflow for fp16
298
+ if encoder_hidden_states.dtype == torch.float16:
299
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
300
+ if hidden_states.dtype == torch.float16:
301
+ hidden_states = hidden_states.clip(-65504, 65504)
302
+
303
+ return encoder_hidden_states, hidden_states
304
+
305
+
306
+ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuModelLoaderMixin):
307
+ """
308
+ Quantized QwenImage Transformer2DModel.
309
+
310
+ This model supports quantized transformer blocks and optional CPU offloading for memory efficiency.
311
+
312
+ Parameters
313
+ ----------
314
+ *args
315
+ Positional arguments for the base model.
316
+ **kwargs
317
+ Keyword arguments for the base model and quantization.
318
+
319
+ Attributes
320
+ ----------
321
+ offload : bool
322
+ Whether CPU offloading is enabled.
323
+ offload_manager : CPUOffloadManager or None
324
+ Manager for offloading transformer blocks.
325
+ _is_initialized : bool
326
+ Whether the model has been patched for quantization.
327
+ """
328
+
329
+ def __init__(self, *args, **kwargs):
330
+ self.offload = kwargs.pop("offload", False)
331
+ self.offload_manager = None
332
+ self._is_initialized = False
333
+ super().__init__(*args, **kwargs)
334
+
335
+ def _patch_model(self, **kwargs):
336
+ """
337
+ Patch the transformer blocks for quantization.
338
+
339
+ Parameters
340
+ ----------
341
+ **kwargs
342
+ Additional arguments for quantization.
343
+
344
+ Returns
345
+ -------
346
+ self
347
+ """
348
+ for i, block in enumerate(self.transformer_blocks):
349
+ self.transformer_blocks[i] = NunchakuQwenImageTransformerBlock(block, scale_shift=0, **kwargs)
350
+ self._is_initialized = True
351
+ return self
352
+
353
+ @classmethod
354
+ @utils.validate_hf_hub_args
355
+ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
356
+ """
357
+ Load a quantized model from a pretrained checkpoint.
358
+
359
+ Parameters
360
+ ----------
361
+ pretrained_model_name_or_path : str or os.PathLike
362
+ Path to the pretrained model checkpoint. It can be a local file or a remote HuggingFace path.
363
+ **kwargs
364
+ Additional arguments for loading and quantization.
365
+
366
+ Returns
367
+ -------
368
+ NunchakuQwenImageTransformer2DModel
369
+ The loaded and quantized model.
370
+
371
+ Raises
372
+ ------
373
+ AssertionError
374
+ If the checkpoint is not a safetensors file.
375
+ """
376
+ device = kwargs.get("device", "cpu")
377
+ offload = kwargs.get("offload", False)
378
+
379
+ torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
380
+
381
+ if isinstance(pretrained_model_name_or_path, str):
382
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
383
+
384
+ assert pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
385
+ (".safetensors", ".sft")
386
+ ), "Only safetensors are supported"
387
+ transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
388
+ quantization_config = json.loads(metadata.get("quantization_config", "{}"))
389
+ config = json.loads(metadata.get("config", "{}"))
390
+ rank = quantization_config.get("rank", 32)
391
+ transformer = transformer.to(torch_dtype)
392
+
393
+ precision = get_precision()
394
+ if precision == "fp4":
395
+ precision = "nvfp4"
396
+ transformer._patch_model(precision=precision, rank=rank)
397
+
398
+ transformer = transformer.to_empty(device=device)
399
+ # need to re-init the pos_embed as to_empty does not work on it
400
+ transformer.pos_embed = QwenEmbedRope(
401
+ theta=10000, axes_dim=list(config.get("axes_dims_rope", [16, 56, 56])), scale_rope=True
402
+ )
403
+
404
+ state_dict = transformer.state_dict()
405
+ for k in state_dict.keys():
406
+ if k not in model_state_dict:
407
+ assert ".wcscales" in k
408
+ model_state_dict[k] = torch.ones_like(state_dict[k])
409
+ else:
410
+ assert state_dict[k].dtype == model_state_dict[k].dtype
411
+
412
+ # load the wtscale from the state dict, as it is a float on CPU
413
+ for n, m in transformer.named_modules():
414
+ if isinstance(m, SVDQW4A4Linear):
415
+ if m.wtscale is not None:
416
+ m.wtscale = model_state_dict.pop(f"{n}.wtscale", 1.0)
417
+ transformer.load_state_dict(model_state_dict)
418
+ transformer.set_offload(offload)
419
+
420
+ return transformer
421
+
422
+ def set_offload(self, offload: bool, **kwargs):
423
+ """
424
+ Enable or disable asynchronous CPU offloading for transformer blocks.
425
+
426
+ Parameters
427
+ ----------
428
+ offload : bool
429
+ Whether to enable offloading.
430
+ **kwargs
431
+ Additional arguments for offload manager.
432
+
433
+ See Also
434
+ --------
435
+ :class:`~nunchaku.models.utils.CPUOffloadManager`
436
+ """
437
+ if offload == self.offload:
438
+ # nothing changed, just return
439
+ return
440
+ self.offload = offload
441
+ if offload:
442
+ self.offload_manager = CPUOffloadManager(
443
+ self.transformer_blocks,
444
+ use_pin_memory=kwargs.get("use_pin_memory", True),
445
+ on_gpu_modules=[
446
+ self.img_in,
447
+ self.txt_in,
448
+ self.txt_norm,
449
+ self.time_text_embed,
450
+ self.norm_out,
451
+ self.proj_out,
452
+ ],
453
+ num_blocks_on_gpu=kwargs.get("num_blocks_on_gpu", 1),
454
+ )
455
+ else:
456
+ self.offload_manager = None
457
+ gc.collect()
458
+ torch.cuda.empty_cache()
459
+
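A minimal sketch of toggling offload on the model above (the path and block count are placeholders):

```python
transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
    "path/to/svdq-qwen-image.safetensors",  # placeholder path
    offload=True,
)
# Or toggle later, keeping two blocks resident on the GPU:
transformer.set_offload(True, num_blocks_on_gpu=2, use_pin_memory=True)
```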
460
+ def forward(
461
+ self,
462
+ hidden_states: torch.Tensor,
463
+ encoder_hidden_states: torch.Tensor = None,
464
+ encoder_hidden_states_mask: torch.Tensor = None,
465
+ timestep: torch.LongTensor = None,
466
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
467
+ txt_seq_lens: Optional[List[int]] = None,
468
+ guidance: torch.Tensor = None,
469
+ attention_kwargs: Optional[Dict[str, Any]] = None,
470
+ return_dict: bool = True,
471
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
472
+ """
473
+ Forward pass for the quantized QwenImage transformer model.
474
+
475
+ Parameters
476
+ ----------
477
+ hidden_states : torch.Tensor
478
+ Image stream input.
479
+ encoder_hidden_states : torch.Tensor, optional
480
+ Text stream input.
481
+ encoder_hidden_states_mask : torch.Tensor, optional
482
+ Mask for encoder hidden states.
483
+ timestep : torch.LongTensor, optional
484
+ Timestep for temporal embedding.
485
+ img_shapes : list of tuple, optional
486
+ Image shapes for rotary embedding.
487
+ txt_seq_lens : list of int, optional
488
+ Text sequence lengths.
489
+ guidance : torch.Tensor, optional
490
+ Guidance tensor (for classifier-free guidance).
491
+ attention_kwargs : dict, optional
492
+ Additional attention arguments.
493
+ return_dict : bool, default=True
494
+ Whether to return a dict or tuple.
495
+
496
+ Returns
497
+ -------
498
+ torch.Tensor or Transformer2DModelOutput
499
+ Model output.
500
+ """
501
+ device = hidden_states.device
502
+ if self.offload:
503
+ self.offload_manager.set_device(device)
504
+
505
+ hidden_states = self.img_in(hidden_states)
506
+
507
+ timestep = timestep.to(hidden_states.dtype)
508
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
509
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
510
+
511
+ if guidance is not None:
512
+ guidance = guidance.to(hidden_states.dtype) * 1000
513
+
514
+ temb = (
515
+ self.time_text_embed(timestep, hidden_states)
516
+ if guidance is None
517
+ else self.time_text_embed(timestep, guidance, hidden_states)
518
+ )
519
+
520
+ image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
521
+
522
+ compute_stream = torch.cuda.current_stream()
523
+ if self.offload:
524
+ self.offload_manager.initialize(compute_stream)
525
+ for block_idx, block in enumerate(self.transformer_blocks):
526
+ with torch.cuda.stream(compute_stream):
527
+ if self.offload:
528
+ block = self.offload_manager.get_block(block_idx)
529
+ encoder_hidden_states, hidden_states = block(
530
+ hidden_states=hidden_states,
531
+ encoder_hidden_states=encoder_hidden_states,
532
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
533
+ temb=temb,
534
+ image_rotary_emb=image_rotary_emb,
535
+ joint_attention_kwargs=attention_kwargs,
536
+ )
537
+ if self.offload:
538
+ self.offload_manager.step(compute_stream)
539
+
540
+ hidden_states = self.norm_out(hidden_states, temb)
541
+ output = self.proj_out(hidden_states)
542
+
543
+ torch.cuda.empty_cache()
544
+
545
+ if not return_dict:
546
+ return (output,)
547
+
548
+ return Transformer2DModelOutput(sample=output)
549
+
550
+ def to(self, *args, **kwargs):
551
+ """
552
+ Override the default ``.to()`` method.
553
+
554
+ If offload is enabled, prevents moving the model to GPU.
555
+ Prevents changing dtype after quantization.
556
+
557
+ Parameters
558
+ ----------
559
+ *args
560
+ Positional arguments for ``.to()``.
561
+ **kwargs
562
+ Keyword arguments for ``.to()``.
563
+
564
+ Returns
565
+ -------
566
+ self
567
+
568
+ Raises
569
+ ------
570
+ ValueError
571
+ If attempting to change dtype after quantization.
572
+ """
573
+ device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
574
+ dtype_present_in_args = "dtype" in kwargs
575
+
576
+ # Try converting arguments to torch.device in case they are passed as strings
577
+ for arg in args:
578
+ if not isinstance(arg, str):
579
+ continue
580
+ try:
581
+ torch.device(arg)
582
+ device_arg_or_kwarg_present = True
583
+ except RuntimeError:
584
+ pass
585
+
586
+ if not dtype_present_in_args:
587
+ for arg in args:
588
+ if isinstance(arg, torch.dtype):
589
+ dtype_present_in_args = True
590
+ break
591
+
592
+ if dtype_present_in_args and self._is_initialized:
593
+ raise ValueError(
594
+ "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
595
+ "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`."
596
+ )
597
+ if self.offload:
598
+ if device_arg_or_kwarg_present:
599
+ warn("Skipping moving the model to GPU as offload is enabled", UserWarning)
600
+ return self
601
+ return super(type(self), self).to(*args, **kwargs)
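Behavior sketch for the override above, assuming an offload-enabled model:

```python
transformer.set_offload(True)
transformer.to("cuda")           # warns and returns self: blocks stay under the offload manager
# transformer.to(torch.float16)  # would raise ValueError: dtype is fixed after quantization
```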
nunchaku/models/transformers/transformer_sana.py ADDED
@@ -0,0 +1,374 @@
1
+ """
2
+ Implements the :class:`NunchakuSanaTransformer2DModel`,
3
+ a quantized Sana transformer for Diffusers with efficient inference support.
4
+ """
5
+
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from diffusers import SanaTransformer2DModel
12
+ from huggingface_hub import utils
13
+ from safetensors.torch import load_file
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+
17
+ from ..._C import QuantizedSanaModel
18
+ from ..._C import utils as cutils
19
+ from ...utils import get_precision
20
+ from .utils import NunchakuModelLoaderMixin
21
+
22
+ SVD_RANK = 32
23
+
24
+
25
+ class NunchakuSanaTransformerBlocks(nn.Module):
26
+ """
27
+ Wrapper for quantized Sana transformer blocks.
28
+
29
+ This module wraps a QuantizedSanaModel and provides forward methods compatible
30
+ with the expected transformer block interface.
31
+
32
+ Parameters
33
+ ----------
34
+ m : QuantizedSanaModel
35
+ The quantized transformer model.
36
+ dtype : torch.dtype
37
+ The data type to use for computation.
38
+ device : str or torch.device
39
+ The device to run the model on.
40
+ """
41
+
42
+ def __init__(self, m: QuantizedSanaModel, dtype: torch.dtype, device: str | torch.device):
43
+ super().__init__()
44
+ self.m = m
45
+ self.dtype = dtype
46
+ self.device = device
47
+
48
+ def forward(
49
+ self,
50
+ hidden_states: torch.Tensor,
51
+ attention_mask: Optional[torch.Tensor] = None,
52
+ encoder_hidden_states: Optional[torch.Tensor] = None,
53
+ encoder_attention_mask: Optional[torch.Tensor] = None,
54
+ timestep: Optional[torch.LongTensor] = None,
55
+ height: Optional[int] = None,
56
+ width: Optional[int] = None,
57
+ skip_first_layer: Optional[bool] = False,
58
+ ):
59
+ """
60
+ Forward pass through all quantized transformer blocks.
61
+
62
+ Parameters
63
+ ----------
64
+ hidden_states : torch.Tensor
65
+ Input hidden states of shape (batch_size, img_tokens, ...).
66
+ attention_mask : torch.Tensor, optional
67
+ Not used.
68
+ encoder_hidden_states : torch.Tensor, optional
69
+ Encoder hidden states of shape (batch_size, txt_tokens, ...).
70
+ encoder_attention_mask : torch.Tensor, optional
71
+ Encoder attention mask of shape (batch_size, 1, txt_tokens).
72
+ timestep : torch.LongTensor, optional
73
+ Timestep tensor.
74
+ height : int, optional
75
+ Image height.
76
+ width : int, optional
77
+ Image width.
78
+ skip_first_layer : bool, optional
79
+ Whether to skip the first layer.
80
+
81
+ Returns
82
+ -------
83
+ torch.Tensor
84
+ Output tensor after passing through the quantized transformer blocks.
85
+ """
86
+ batch_size = hidden_states.shape[0]
87
+ img_tokens = hidden_states.shape[1]
88
+ txt_tokens = encoder_hidden_states.shape[1]
89
+
90
+ original_dtype = hidden_states.dtype
91
+ original_device = hidden_states.device
92
+
93
+ assert encoder_attention_mask is not None
94
+ assert encoder_attention_mask.shape == (batch_size, 1, txt_tokens)
95
+
96
+ mask = encoder_attention_mask.reshape(batch_size, txt_tokens)
97
+ nunchaku_encoder_hidden_states = encoder_hidden_states[mask > -9000]
98
+
99
+ cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
100
+ cu_seqlens_img = torch.arange(
101
+ 0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32, device=self.device
102
+ )
103
+
104
+ if height is None and width is None:
105
+ height = width = int(img_tokens**0.5)
106
+ elif height is None:
107
+ height = img_tokens // width
108
+ elif width is None:
109
+ width = img_tokens // height
110
+ assert height * width == img_tokens
111
+
112
+ return (
113
+ self.m.forward(
114
+ hidden_states.to(self.dtype).to(self.device),
115
+ nunchaku_encoder_hidden_states.to(self.dtype).to(self.device),
116
+ timestep.to(self.dtype).to(self.device),
117
+ cu_seqlens_img.to(self.device),
118
+ cu_seqlens_txt.to(self.device),
119
+ height,
120
+ width,
121
+ batch_size % 3 == 0, # pag is set when loading the model, FIXME: pag_scale == 0
122
+ True, # TODO: find a way to detect if we are doing CFG
123
+ skip_first_layer,
124
+ )
125
+ .to(original_dtype)
126
+ .to(original_device)
127
+ )
128
+
129
+ def forward_layer_at(
130
+ self,
131
+ idx: int,
132
+ hidden_states: torch.Tensor,
133
+ attention_mask: Optional[torch.Tensor] = None,
134
+ encoder_hidden_states: Optional[torch.Tensor] = None,
135
+ encoder_attention_mask: Optional[torch.Tensor] = None,
136
+ timestep: Optional[torch.LongTensor] = None,
137
+ height: Optional[int] = None,
138
+ width: Optional[int] = None,
139
+ ):
140
+ """
141
+ Forward pass through a specific quantized transformer layer.
142
+
143
+ Parameters
144
+ ----------
145
+ idx : int
146
+ Index of the layer to run.
147
+ hidden_states : torch.Tensor
148
+ Input hidden states.
149
+ attention_mask : torch.Tensor, optional
150
+ Not used.
151
+ encoder_hidden_states : torch.Tensor, optional
152
+ Encoder hidden states.
153
+ encoder_attention_mask : torch.Tensor, optional
154
+ Encoder attention mask.
155
+ timestep : torch.LongTensor, optional
156
+ Timestep tensor.
157
+ height : int, optional
158
+ Image height.
159
+ width : int, optional
160
+ Image width.
161
+
162
+ Returns
163
+ -------
164
+ torch.Tensor
165
+ Output tensor after passing through the specified quantized transformer layer.
166
+ """
167
+ batch_size = hidden_states.shape[0]
168
+ img_tokens = hidden_states.shape[1]
169
+ txt_tokens = encoder_hidden_states.shape[1]
170
+
171
+ original_dtype = hidden_states.dtype
172
+ original_device = hidden_states.device
173
+
174
+ assert encoder_attention_mask is not None
175
+ assert encoder_attention_mask.shape == (batch_size, 1, txt_tokens)
176
+
177
+ mask = encoder_attention_mask.reshape(batch_size, txt_tokens)
178
+ nunchaku_encoder_hidden_states = encoder_hidden_states[mask > -9000]
179
+
180
+ cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
181
+ cu_seqlens_img = torch.arange(
182
+ 0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32, device=self.device
183
+ )
184
+
185
+ if height is None and width is None:
186
+ height = width = int(img_tokens**0.5)
187
+ elif height is None:
188
+ height = img_tokens // width
189
+ elif width is None:
190
+ width = img_tokens // height
191
+ assert height * width == img_tokens
192
+
193
+ return (
194
+ self.m.forward_layer(
195
+ idx,
196
+ hidden_states.to(self.dtype).to(self.device),
197
+ nunchaku_encoder_hidden_states.to(self.dtype).to(self.device),
198
+ timestep.to(self.dtype).to(self.device),
199
+ cu_seqlens_img.to(self.device),
200
+ cu_seqlens_txt.to(self.device),
201
+ height,
202
+ width,
203
+ batch_size % 3 == 0, # pag is set when loading the model, FIXME: pag_scale == 0
204
+ True, # TODO: find a way to detect if we are doing CFG
205
+ )
206
+ .to(original_dtype)
207
+ .to(original_device)
208
+ )
209
+
210
+ def __del__(self):
211
+ """
212
+ Destructor to reset the quantized model and free resources.
213
+ """
214
+ self.m.reset()
215
+
216
+
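
Both forward methods above pack variable-length text sequences before calling into the CUDA backend: tokens whose additive attention-mask value is above -9000 (i.e., not masked with a large negative bias) are kept, and ``cu_seqlens_txt`` records cumulative sequence boundaries in the style varlen attention kernels expect. A toy reconstruction of that packing (shapes are illustrative):

```python
# Toy reconstruction of the varlen packing used in the forward methods above.
import torch
import torch.nn.functional as F

batch_size, txt_tokens, dim = 2, 4, 8
encoder_hidden_states = torch.randn(batch_size, txt_tokens, dim)
# Additive mask: 0.0 for valid tokens, a large negative value for padding.
encoder_attention_mask = torch.tensor(
    [
        [[0.0, 0.0, 0.0, -10000.0]],       # sample 0: 3 valid tokens
        [[0.0, 0.0, -10000.0, -10000.0]],  # sample 1: 2 valid tokens
    ]
)

mask = encoder_attention_mask.reshape(batch_size, txt_tokens)
packed = encoder_hidden_states[mask > -9000]  # (5, dim): 3 + 2 valid tokens
cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
print(packed.shape)    # torch.Size([5, 8])
print(cu_seqlens_txt)  # tensor([0, 3, 5], dtype=torch.int32)
```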
217
+ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
218
+ """
219
+ SanaTransformer2DModel with Nunchaku quantized backend support.
220
+
221
+ This class extends the base SanaTransformer2DModel to support loading and
222
+ injecting quantized transformer blocks using Nunchaku's custom backend.
223
+ """
224
+
225
+ @classmethod
226
+ @utils.validate_hf_hub_args
227
+ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
228
+ """
229
+ Load a pretrained NunchakuSanaTransformer2DModel from a local file or HuggingFace Hub.
230
+
231
+ This method supports both quantized and unquantized checkpoints, and will
232
+ automatically inject quantized transformer blocks if available.
233
+
234
+ Parameters
235
+ ----------
236
+ pretrained_model_name_or_path : str or os.PathLike
237
+ Path to the model checkpoint or HuggingFace Hub model name.
238
+ **kwargs
239
+ Additional keyword arguments for model loading.
240
+
241
+ Returns
242
+ -------
243
+ NunchakuSanaTransformer2DModel or (NunchakuSanaTransformer2DModel, dict)
244
+ The loaded model, and optionally metadata if ``return_metadata=True``.
245
+ """
246
+ device = kwargs.get("device", "cuda")
247
+ if isinstance(device, str):
248
+ device = torch.device(device)
249
+ pag_layers = kwargs.get("pag_layers", [])
250
+ precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
251
+ metadata = None
252
+
253
+ if isinstance(pretrained_model_name_or_path, str):
254
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
255
+ if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
256
+ (".safetensors", ".sft")
257
+ ):
258
+ transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path)
259
+ quantized_part_sd = {}
260
+ unquantized_part_sd = {}
261
+ for k, v in model_state_dict.items():
262
+ if k.startswith("transformer_blocks."):
263
+ quantized_part_sd[k] = v
264
+ else:
265
+ unquantized_part_sd[k] = v
266
+ m = load_quantized_module(
267
+ transformer, quantized_part_sd, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
268
+ )
269
+ transformer.inject_quantized_module(m, device)
270
+ transformer.to_empty(device=device)
271
+ transformer.load_state_dict(unquantized_part_sd, strict=False)
272
+ else:
273
+ transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
274
+ pretrained_model_name_or_path, **kwargs
275
+ )
276
+ m = load_quantized_module(
277
+ transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
278
+ )
279
+ transformer.inject_quantized_module(m, device)
280
+ transformer.to_empty(device=device)
281
+ unquantized_state_dict = load_file(unquantized_part_path)
282
+ transformer.load_state_dict(unquantized_state_dict, strict=False)
283
+ if kwargs.get("return_metadata", False):
284
+ return transformer, metadata
285
+ else:
286
+ return transformer
287
+
288
+ def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
289
+ """
290
+ Inject a quantized transformer module into this model.
291
+
292
+ Parameters
293
+ ----------
294
+ m : QuantizedSanaModel
295
+ The quantized transformer module to inject.
296
+ device : str or torch.device, optional
297
+ The device to place the module on (default: "cuda").
298
+
299
+ Returns
300
+ -------
301
+ NunchakuSanaTransformer2DModel
302
+ The model with the quantized module injected.
303
+ """
304
+ self.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, self.dtype, device)])
305
+ return self
306
+
307
+
308
+ def load_quantized_module(
309
+ net: SanaTransformer2DModel,
310
+ path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
311
+ device: str | torch.device = "cuda",
312
+ pag_layers: int | list[int] | None = None,
313
+ use_fp4: bool = False,
314
+ ) -> QuantizedSanaModel:
315
+ """
316
+ Load quantized weights into a QuantizedSanaModel.
317
+
318
+ Parameters
319
+ ----------
320
+ net : SanaTransformer2DModel
321
+ The base transformer model (for config and dtype).
322
+ path_or_state_dict : str, os.PathLike, or dict
323
+ Path to the quantized weights or a state dict.
324
+ device : str or torch.device, optional
325
+ Device to load the quantized model on (default: "cuda").
326
+ pag_layers : int, list of int, or None, optional
327
+ Layers to apply PAG (Perturbed-Attention Guidance) to (default: None).
328
+ use_fp4 : bool, optional
329
+ Whether to use FP4 quantization (default: False).
330
+
331
+ Returns
332
+ -------
333
+ QuantizedSanaModel
334
+ The loaded quantized model.
335
+ """
336
+ if pag_layers is None:
337
+ pag_layers = []
338
+ elif isinstance(pag_layers, int):
339
+ pag_layers = [pag_layers]
340
+ device = torch.device(device)
341
+ assert device.type == "cuda"
342
+
343
+ m = QuantizedSanaModel()
344
+ cutils.disable_memory_auto_release()
345
+ m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
346
+ if isinstance(path_or_state_dict, dict):
347
+ m.loadDict(path_or_state_dict, True)
348
+ else:
349
+ m.load(str(path_or_state_dict))
350
+ return m
351
+
352
+
353
+ def inject_quantized_module(
354
+ net: SanaTransformer2DModel, m: QuantizedSanaModel, device: torch.device
355
+ ) -> SanaTransformer2DModel:
356
+ """
357
+ Inject a quantized transformer module into a SanaTransformer2DModel.
358
+
359
+ Parameters
360
+ ----------
361
+ net : SanaTransformer2DModel
362
+ The base transformer model.
363
+ m : QuantizedSanaModel
364
+ The quantized transformer module to inject.
365
+ device : torch.device
366
+ The device to place the module on.
367
+
368
+ Returns
369
+ -------
370
+ SanaTransformer2DModel
371
+ The model with the quantized module injected.
372
+ """
373
+ net.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, net.dtype, device)])
374
+ return net
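
For context, a typical end-to-end use of this class is to load the quantized transformer and hand it to a Diffusers Sana pipeline. A hedged sketch under that assumption (the checkpoint path and pipeline repo id below are placeholders; adjust them to your setup):

```python
# Sketch: plugging the quantized transformer into a Diffusers Sana pipeline.
# The checkpoint path and repo id are placeholders, not files in this repo.
import torch
from diffusers import SanaPipeline

transformer = NunchakuSanaTransformer2DModel.from_pretrained(
    "path/to/svdq-int4-sana.safetensors", device="cuda", precision="auto"
)
pipe = SanaPipeline.from_pretrained(
    "some-org/sana-base-model", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe("a cyberpunk cat", num_inference_steps=20).images[0]
```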
nunchaku/models/transformers/utils.py ADDED
@@ -0,0 +1,147 @@
1
+ """
2
+ Utilities for Nunchaku transformer model loading.
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ from diffusers import __version__
12
+ from huggingface_hub import constants, hf_hub_download
13
+ from torch import nn
14
+
15
+ from ...utils import load_state_dict_in_safetensors
16
+
17
+ # Get log level from environment variable (default to INFO)
18
+ log_level = os.getenv("LOG_LEVEL", "INFO").upper()
19
+
20
+ # Configure logging
21
+ logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class NunchakuModelLoaderMixin:
26
+ """
27
+ Mixin for standardized model loading in Nunchaku transformer models.
28
+ """
29
+
30
+ @classmethod
31
+ def _build_model(
32
+ cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
33
+ ) -> tuple[nn.Module, dict[str, torch.Tensor], dict[str, str]]:
34
+ """
35
+ Build a transformer model from a safetensors file.
36
+
37
+ Parameters
38
+ ----------
39
+ pretrained_model_name_or_path : str or os.PathLike
40
+ Path to the safetensors file.
41
+ **kwargs
42
+ Additional keyword arguments (e.g., ``torch_dtype``).
43
+
44
+ Returns
45
+ -------
46
+ tuple
47
+ (transformer, state_dict, metadata)
48
+ """
49
+ if isinstance(pretrained_model_name_or_path, str):
50
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
51
+ state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
52
+
53
+ config = json.loads(metadata["config"])
54
+
55
+ with torch.device("meta"):
56
+ transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
57
+
58
+ return transformer, state_dict, metadata
59
+
60
+ @classmethod
61
+ def _build_model_legacy(
62
+ cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
63
+ ) -> tuple[nn.Module, str, str]:
64
+ """
65
+ Build a transformer model from a legacy folder structure.
66
+
67
+ .. warning::
68
+ This method is deprecated and will be removed in December 2025.
69
+ Please use :meth:`_build_model` instead.
70
+
71
+ Parameters
72
+ ----------
73
+ pretrained_model_name_or_path : str or os.PathLike
74
+ Path to the folder containing model weights.
75
+ **kwargs
76
+ Additional keyword arguments for HuggingFace Hub download and config loading.
77
+
78
+ Returns
79
+ -------
80
+ tuple
81
+ (transformer, unquantized_part_path, transformer_block_path)
82
+ """
83
+ logger.warning(
84
+ "Loading models from a folder will be deprecated in December 2025. "
85
+ "Please download the latest safetensors model, or use one of the following tools to "
86
+ "merge your model into a single file: the CLI utility `python -m nunchaku.merge_safetensors` "
87
+ "or the ComfyUI workflow `merge_safetensors.json`."
88
+ )
89
+ subfolder = kwargs.get("subfolder", None)
90
+ if os.path.exists(pretrained_model_name_or_path):
91
+ dirname = (
92
+ pretrained_model_name_or_path
93
+ if subfolder is None
94
+ else os.path.join(pretrained_model_name_or_path, subfolder)
95
+ )
96
+ unquantized_part_path = os.path.join(dirname, "unquantized_layers.safetensors")
97
+ transformer_block_path = os.path.join(dirname, "transformer_blocks.safetensors")
98
+ else:
99
+ download_kwargs = {
100
+ "subfolder": subfolder,
101
+ "repo_type": "model",
102
+ "revision": kwargs.get("revision", None),
103
+ "cache_dir": kwargs.get("cache_dir", None),
104
+ "local_dir": kwargs.get("local_dir", None),
105
+ "user_agent": kwargs.get("user_agent", None),
106
+ "force_download": kwargs.get("force_download", False),
107
+ "proxies": kwargs.get("proxies", None),
108
+ "etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
109
+ "token": kwargs.get("token", None),
110
+ "local_files_only": kwargs.get("local_files_only", None),
111
+ "headers": kwargs.get("headers", None),
112
+ "endpoint": kwargs.get("endpoint", None),
113
+ "resume_download": kwargs.get("resume_download", None),
114
+ "force_filename": kwargs.get("force_filename", None),
115
+ "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
116
+ }
117
+ unquantized_part_path = hf_hub_download(
118
+ repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs
119
+ )
120
+ transformer_block_path = hf_hub_download(
121
+ repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs
122
+ )
123
+
124
+ cache_dir = kwargs.pop("cache_dir", None)
125
+ force_download = kwargs.pop("force_download", False)
126
+ proxies = kwargs.pop("proxies", None)
127
+ local_files_only = kwargs.pop("local_files_only", None)
128
+ token = kwargs.pop("token", None)
129
+ revision = kwargs.pop("revision", None)
130
+ config, _, _ = cls.load_config(
131
+ pretrained_model_name_or_path,
132
+ subfolder=subfolder,
133
+ cache_dir=cache_dir,
134
+ return_unused_kwargs=True,
135
+ return_commit_hash=True,
136
+ force_download=force_download,
137
+ proxies=proxies,
138
+ local_files_only=local_files_only,
139
+ token=token,
140
+ revision=revision,
141
+ user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
142
+ **kwargs,
143
+ )
144
+
145
+ with torch.device("meta"):
146
+ transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
147
+ return transformer, unquantized_part_path, transformer_block_path
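
``_build_model`` expects the Diffusers model config to be embedded as JSON in the safetensors header metadata under the ``"config"`` key. If you want to inspect that metadata directly, a short sketch using ``safetensors.safe_open`` (the file path is a placeholder):

```python
# Sketch: inspect the config embedded in a Nunchaku safetensors checkpoint.
# "model.safetensors" is a placeholder path.
import json

from safetensors import safe_open

with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata()  # header metadata as a dict of strings (or None)
    config = json.loads(metadata["config"])
print(sorted(config.keys()))
```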
nunchaku/models/utils.py ADDED
@@ -0,0 +1,262 @@
1
+ """
2
+ Utility functions and classes for efficient transformer model management in Nunchaku.
3
+ """
4
+
5
+ import copy
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from ..utils import copy_params_into
11
+
12
+
13
+ def fuse_linears(linears: list[nn.Linear]) -> nn.Linear:
14
+ """
15
+ Fuse a list of nn.Linear layers into a single nn.Linear with concatenated output features.
16
+
17
+ Parameters
18
+ ----------
19
+ linears : list of nn.Linear
20
+ List of linear layers to fuse. All must have the same input feature dimension.
21
+
22
+ Returns
23
+ -------
24
+ fused : nn.Linear
25
+ A new linear layer with concatenated output features and the same input features.
26
+
27
+ Raises
28
+ ------
29
+ AssertionError
30
+ If the input feature dimensions do not match.
31
+
32
+ Notes
33
+ -----
34
+ The fused layer does not copy weights or biases from the input layers.
35
+ """
36
+ assert len(linears) > 0
37
+ if len(linears) == 1:
38
+ return linears[0]
39
+ else:
40
+ assert all(linear.in_features == linears[0].in_features for linear in linears)
41
+ out_features = sum(linear.out_features for linear in linears)
42
+ bias = all(linear.bias is not None for linear in linears)
43
+ return nn.Linear(
44
+ linears[0].in_features,
45
+ out_features,
46
+ bias=bias,
47
+ dtype=linears[0].weight.dtype,
48
+ device=linears[0].weight.device,
49
+ )
50
+
51
+
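
Since ``fuse_linears`` only allocates the concatenated layer (see the note in its docstring), the caller copies weights in afterwards. A minimal sketch fusing a Q/K/V triple, shown as plausible usage rather than code from this repo (it assumes ``fuse_linears`` from this module is in scope):

```python
# Sketch: fuse three projections into one linear, then copy weights manually.
import torch
from torch import nn

q, k, v = (nn.Linear(64, 64) for _ in range(3))
fused = fuse_linears([q, k, v])  # nn.Linear(64, 192); weights NOT copied

with torch.no_grad():
    fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
    fused.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))

x = torch.randn(2, 64)
assert torch.allclose(fused(x), torch.cat([q(x), k(x), v(x)], dim=-1), atol=1e-6)
```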
52
+ class CPUOffloadManager:
53
+ """
54
+ Manager for per-transformer-block CPU offloading, using asynchronous memory operations and a ping-pong (double-buffer) strategy.
55
+
56
+ This class enables memory-efficient inference or training by keeping only a subset
57
+ of transformer blocks on GPU, offloading the rest to CPU, and preloading blocks as needed.
58
+
59
+ Parameters
60
+ ----------
61
+ blocks : list of nn.Module
62
+ List of transformer blocks to manage.
63
+ device : str or torch.device, optional
64
+ Target CUDA device for GPU operations. Default is "cuda".
65
+ use_pin_memory : bool, optional
66
+ Whether to use pinned memory for faster CPU-to-GPU transfers. Default is True.
67
+ on_gpu_modules : list of nn.Module, optional
68
+ Additional modules to keep on GPU at all times. Default is None, meaning no extra modules.
69
+ num_blocks_on_gpu : int, optional
70
+ Number of blocks to keep on GPU simultaneously. Must be > 0. Default is 1.
71
+ empty_cache_freq : int, optional
72
+ Frequency (in forward passes) to call torch.cuda.empty_cache(). Default is 0 (never).
73
+
74
+ Attributes
75
+ ----------
76
+ blocks : list of nn.Module
77
+ The managed transformer blocks.
78
+ buffer_blocks : list of nn.Module
79
+ Buffers for preloading blocks onto GPU.
80
+ device : torch.device
81
+ The current CUDA device.
82
+ current_block_idx : int
83
+ Index of the current block on GPU.
84
+ forward_counter : int
85
+ Number of forward passes completed.
86
+ memory_stream : torch.cuda.Stream
87
+ CUDA stream for memory operations.
88
+ compute_done : torch.cuda.Event
89
+ CUDA event signaling compute completion.
90
+ memory_done : torch.cuda.Event
91
+ CUDA event signaling memory completion.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ blocks: list[nn.Module],
97
+ device: str | torch.device = torch.device("cuda"),
98
+ use_pin_memory: bool = True,
99
+ on_gpu_modules: list[nn.Module] | None = None,
100
+ num_blocks_on_gpu: int = 1,
101
+ empty_cache_freq: int = 0,
102
+ ):
103
+ self.blocks = blocks
104
+ self.use_pin_memory = use_pin_memory
105
+ self.on_gpu_modules = [] if on_gpu_modules is None else on_gpu_modules
106
+ self.num_blocks_on_gpu = num_blocks_on_gpu
107
+ assert self.num_blocks_on_gpu > 0
108
+
109
+ # Two streams: one for compute, one for memory operations, will be initialized in set_device
110
+ self.memory_stream = None
111
+
112
+ self.compute_done = torch.cuda.Event(blocking=False)
113
+ self.memory_done = torch.cuda.Event(blocking=False)
114
+
115
+ self.buffer_blocks = [copy.deepcopy(blocks[0]), copy.deepcopy(blocks[0])]
116
+
117
+ self.device = None
118
+ self.set_device(device)
119
+
120
+ self.current_block_idx = 0
121
+ self.forward_counter = 0
122
+ self.empty_cache_freq = empty_cache_freq
123
+
124
+ def set_device(self, device: torch.device | str, force: bool = False):
125
+ """
126
+ Set the CUDA device for offloading and memory operations.
127
+ It will move buffer blocks and on-GPU modules to the specified device and offload other blocks to CPU, optionally using pinned memory.
128
+
129
+ Parameters
130
+ ----------
131
+ device : torch.device or str
132
+ Target CUDA device.
133
+ force : bool, optional
134
+ If True, force re-initialization even if device is unchanged. Default is False.
135
+
136
+ Raises
137
+ ------
138
+ AssertionError
139
+ If the device is not a CUDA device.
140
+ """
141
+ if isinstance(device, str):
142
+ device = torch.device(device)
143
+ assert device.type == "cuda"
144
+ if self.device == device and not force:
145
+ return
146
+ self.device = device
147
+ self.memory_stream = torch.cuda.Stream(device=device)
148
+ for block in self.buffer_blocks:
149
+ block.to(device)
150
+ for module in self.on_gpu_modules:
151
+ module.to(device)
152
+ for i, block in enumerate(self.blocks):
153
+ if i < self.num_blocks_on_gpu:
154
+ block.to(device)
155
+ else:
156
+ block.to("cpu")
157
+ if self.use_pin_memory:
158
+ for p in block.parameters(recurse=True):
159
+ p.data = p.data.pin_memory()
160
+ for b in block.buffers(recurse=True):
161
+ b.data = b.data.pin_memory()
162
+
163
+ def load_block(self, block_idx: int, non_blocking: bool = True):
164
+ """
165
+ Move a transformer block from CPU to GPU buffer.
166
+
167
+ Parameters
168
+ ----------
169
+ block_idx : int
170
+ Index of the block to load.
171
+ non_blocking : bool, optional
172
+ Whether to use non-blocking memory copy. Default is True.
173
+
174
+ Notes
175
+ -----
176
+ - No action is taken if the block is already on GPU or index is out of range.
177
+ """
178
+ # if the block is already on GPU, don't load it to the buffer
179
+ if block_idx < self.num_blocks_on_gpu:
180
+ return
181
+ # if the index is past the last block, there is nothing to preload
182
+ if block_idx >= len(self.blocks):
183
+ return
184
+
185
+ block = self.blocks[block_idx]
186
+ copy_params_into(block, self.buffer_blocks[block_idx % 2], non_blocking=non_blocking)
187
+
188
+ def step(self, compute_stream: torch.cuda.Stream | None = None):
189
+ """
190
+ Advance to the next transformer block, triggering asynchronous preloading.
191
+
192
+ It will preload the next block onto GPU in the background and synchronize between compute and memory streams.
193
+ After all the blocks are processed, it will call torch.cuda.empty_cache() periodically if ``empty_cache_freq`` > 0.
194
+
195
+ Parameters
196
+ ----------
197
+ compute_stream : torch.cuda.Stream, optional
198
+ CUDA stream for compute operations. If None, uses current stream.
199
+ """
200
+ if compute_stream is None:
201
+ compute_stream = torch.cuda.current_stream()
202
+ next_compute_done = torch.cuda.Event()
203
+ next_compute_done.record(compute_stream)
204
+ with torch.cuda.stream(self.memory_stream):
205
+ self.memory_stream.wait_event(self.compute_done)
206
+ self.load_block(self.current_block_idx + 1)  # no-op once the last block has been reached
207
+ next_memory_done = torch.cuda.Event()
208
+ next_memory_done.record(self.memory_stream)
209
+ self.memory_done = next_memory_done
210
+ self.compute_done = next_compute_done
211
+ self.current_block_idx += 1
212
+ if self.current_block_idx < len(self.blocks):
213
+ # get ready for the next compute
214
+ compute_stream.wait_event(self.memory_done)
215
+ else:
216
+ # ready to finish
217
+ compute_stream.wait_event(self.compute_done)
218
+ self.current_block_idx = 0
219
+ self.forward_counter += 1
220
+ if self.empty_cache_freq > 0 and self.forward_counter % self.empty_cache_freq == 0:
221
+ torch.cuda.empty_cache()
222
+
223
+ def get_block(self, block_idx: int | None = None) -> nn.Module:
224
+ """
225
+ Retrieve the current or specified transformer block for computation.
226
+ It will return a buffer block if the requested block is offloaded.
227
+
228
+ Parameters
229
+ ----------
230
+ block_idx : int, optional
231
+ Index of the block to retrieve. If None, returns the current block.
232
+
233
+ Returns
234
+ -------
235
+ block : nn.Module
236
+ The requested transformer block (on GPU if needed).
237
+ """
238
+ if block_idx is None:
239
+ block_idx = self.current_block_idx
240
+ if block_idx < self.num_blocks_on_gpu:
241
+ return self.blocks[block_idx]
242
+ else:
243
+ return self.buffer_blocks[block_idx % 2]
244
+
245
+ def initialize(self, stream: torch.cuda.Stream | None = None):
246
+ """
247
+ Initialize the CUDA events used to synchronize the compute and memory streams.
248
+ The initial events are recorded on the given stream.
249
+
250
+ Parameters
251
+ ----------
252
+ stream : torch.cuda.Stream, optional
253
+ CUDA stream to record initial events. If None, uses current stream.
254
+
255
+ Notes
256
+ -----
257
+ - Should be called before the first forward pass.
258
+ """
259
+ if stream is None:
260
+ stream = torch.cuda.current_stream()
261
+ self.compute_done.record(stream)
262
+ self.memory_done.record(stream)
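
Taken together, the manager is driven by a three-call pattern per forward pass, exactly as in the QwenImage transformer forward earlier in this commit: ``initialize`` once, then ``get_block``/``step`` per block. A condensed sketch of that driver loop (toy stand-in blocks; assumes ``CPUOffloadManager`` from this module is in scope and a CUDA device is available):

```python
# Sketch: driving CPUOffloadManager with the initialize/get_block/step pattern.
import torch
from torch import nn

blocks = [nn.Linear(64, 64) for _ in range(8)]
manager = CPUOffloadManager(blocks, device="cuda", num_blocks_on_gpu=1)

x = torch.randn(4, 64, device="cuda")
compute_stream = torch.cuda.current_stream()
manager.initialize(compute_stream)
for block_idx in range(len(blocks)):
    with torch.cuda.stream(compute_stream):
        block = manager.get_block(block_idx)  # resident block or ping-pong buffer
        x = block(x)
    manager.step(compute_stream)  # prefetch the next block on the memory stream
```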
nunchaku/ops/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Quantized operations for FLUX-Kontext
nunchaku/ops/fused.py ADDED
@@ -0,0 +1,178 @@
1
+ """
2
+ High-performance fused operators for quantized neural network inference.
3
+ """
4
+
5
+ import torch
6
+ from torch.nn import RMSNorm
7
+
8
+ from nunchaku.models.linear import SVDQW4A4Linear
9
+
10
+ from ..utils import ceil_divide
11
+ from .gemm import svdq_gemm_w4a4_cuda
12
+
13
+
14
+ def fused_gelu_mlp(x: torch.Tensor, fc1: SVDQW4A4Linear, fc2: SVDQW4A4Linear, pad_size: int = 256) -> torch.Tensor:
15
+ """
16
+ Fused quantized MLP with GELU activation.
17
+
18
+ Fuses the first quantized linear layer, the GELU activation, and re-quantization for the second layer into a single CUDA kernel call, then applies the second quantized linear layer. Supports INT4 and NVFP4 quantization.
19
+
20
+ Parameters
21
+ ----------
22
+ x : torch.Tensor, shape (B, S, C_in), dtype float16 or bfloat16
23
+ Input tensor.
24
+ fc1 : SVDQW4A4Linear
25
+ First quantized linear layer (input → hidden).
26
+ fc2 : SVDQW4A4Linear
27
+ Second quantized linear layer (hidden → output).
28
+ pad_size : int, optional
29
+ Batch padding size for CUDA kernel efficiency. Default is 256.
30
+
31
+ Returns
32
+ -------
33
+ torch.Tensor, shape (B, S, C_out), dtype as input
34
+ Output tensor.
35
+
36
+ Notes
37
+ -----
38
+ - Notations:
39
+
40
+ - B: batch size
41
+ - S: sequence length
42
+ - C_in: input features
43
+ - C_out: output features
44
+ - For INT4 quantization, GELU activations are shifted by 0.171875 to ensure non-negativity, enabling unsigned quantization for improved quality. See: https://github.com/nunchaku-tech/nunchaku/blob/433f0b228a61a53fb700ac676fd2e290368ac94d/src/kernels/zgemm/gemm_w4a4_launch_impl.cuh#L286
45
+ """
46
+ batch_size, seq_len, channels = x.shape
47
+ x = x.view(batch_size * seq_len, channels)
48
+ quantized_x, ascales, lora_act = fc1.quantize(x)
49
+
50
+ batch_size_pad = ceil_divide(batch_size * seq_len, pad_size) * pad_size
51
+
52
+ qout_act = torch.empty(batch_size_pad, fc1.out_features // 2, dtype=torch.uint8, device=x.device)
53
+ if fc2.precision == "nvfp4":
54
+ qout_ascales = torch.empty(fc1.out_features // 16, batch_size_pad, dtype=torch.float8_e4m3fn, device=x.device)
55
+ else:
56
+ qout_ascales = torch.empty(fc1.out_features // 64, batch_size_pad, dtype=x.dtype, device=x.device)
57
+ qout_lora_act = torch.empty(batch_size_pad, fc2.proj_down.shape[1], dtype=torch.float32, device=x.device)
58
+
59
+ svdq_gemm_w4a4_cuda(
60
+ act=quantized_x,
61
+ wgt=fc1.qweight,
62
+ qout=qout_act,
63
+ ascales=ascales,
64
+ wscales=fc1.wscales,
65
+ oscales=qout_ascales,
66
+ lora_act_in=lora_act,
67
+ lora_up=fc1.proj_up,
68
+ lora_down=fc2.proj_down,
69
+ lora_act_out=qout_lora_act,
70
+ bias=fc1.bias,
71
+ smooth_factor=fc2.smooth_factor,
72
+ fp4=fc1.precision == "nvfp4",
73
+ alpha=fc1.wtscale,
74
+ wcscales=fc1.wcscales,
75
+ )
76
+ output = torch.empty(batch_size * seq_len, fc2.out_features, dtype=x.dtype, device=x.device)
77
+ output = fc2.forward_quant(qout_act, qout_ascales, qout_lora_act, output=output)
78
+ output = output.view(batch_size, seq_len, -1)
79
+ return output
80
+
81
+
82
+ def fused_qkv_norm_rottary(
83
+ x: torch.Tensor,
84
+ proj: SVDQW4A4Linear,
85
+ norm_q: RMSNorm,
86
+ norm_k: RMSNorm,
87
+ rotary_emb: torch.Tensor,
88
+ output: torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
89
+ attn_tokens: int = 0,
90
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
91
+ """
92
+ Fused quantized QKV projection with RMSNorm and rotary embeddings.
93
+
94
+ Performs quantized QKV projection, applies RMS normalization to Q and K, and fuses rotary embeddings in a single CUDA kernel call.
95
+
96
+ Parameters
97
+ ----------
98
+ x : torch.Tensor, shape (B, S, C_in), dtype float16 or bfloat16
99
+ Input tensor.
100
+ proj : SVDQW4A4Linear
101
+ Quantized QKV projection layer.
102
+ norm_q : RMSNorm
103
+ RMSNorm for query.
104
+ norm_k : RMSNorm
105
+ RMSNorm for key.
106
+ rotary_emb : torch.Tensor
107
+ Packed rotary embedding tensor (see :func:`~nunchaku.models.embeddings.pack_rotemb`).
108
+ output : torch.Tensor or tuple of torch.Tensor, optional
109
+ Output tensor(s). If None, a new tensor is allocated.
110
+ If tuple, should be (output_q, output_k, output_v) for fused attention packing.
111
+ attn_tokens : int, optional
112
+ Number of attention tokens. Default is 0.
113
+
114
+ Returns
115
+ -------
116
+ torch.Tensor or tuple of torch.Tensor
117
+ Output tensor of shape (B, S, C_out), or tuple (output_q, output_k, output_v).
118
+
119
+ Notes
120
+ -----
121
+ Notations:
122
+ - B: batch size
123
+ - S: sequence length
124
+ - C_in: input features
125
+ - C_out: output features
126
+ """
127
+ assert isinstance(norm_q, RMSNorm)
128
+ assert isinstance(norm_k, RMSNorm)
129
+
130
+ batch_size, seq_len, channels = x.shape
131
+ x = x.view(batch_size * seq_len, channels)
132
+ quantized_x, ascales, lora_act = proj.quantize(x)
133
+
134
+ if output is None:
135
+ output = torch.empty(quantized_x.shape[0], proj.out_features, dtype=x.dtype, device=x.device)
136
+
137
+ if isinstance(output, tuple):
138
+ assert len(output) == 3
139
+ output_q, output_k, output_v = output
140
+ svdq_gemm_w4a4_cuda(
141
+ act=quantized_x,
142
+ wgt=proj.qweight,
143
+ ascales=ascales,
144
+ wscales=proj.wscales,
145
+ lora_act_in=lora_act,
146
+ lora_up=proj.proj_up,
147
+ bias=proj.bias,
148
+ fp4=proj.precision == "nvfp4",
149
+ alpha=proj.wtscale,
150
+ wcscales=proj.wcscales,
151
+ norm_q=norm_q.weight,
152
+ norm_k=norm_k.weight,
153
+ rotary_emb=rotary_emb,
154
+ out_q=output_q,
155
+ out_k=output_k,
156
+ out_v=output_v,
157
+ attn_tokens=attn_tokens,
158
+ )
159
+ return output_q, output_k, output_v
160
+ else:
161
+ svdq_gemm_w4a4_cuda(
162
+ act=quantized_x,
163
+ wgt=proj.qweight,
164
+ out=output,
165
+ ascales=ascales,
166
+ wscales=proj.wscales,
167
+ lora_act_in=lora_act,
168
+ lora_up=proj.proj_up,
169
+ bias=proj.bias,
170
+ fp4=proj.precision == "nvfp4",
171
+ alpha=proj.wtscale,
172
+ wcscales=proj.wcscales,
173
+ norm_q=norm_q.weight,
174
+ norm_k=norm_k.weight,
175
+ rotary_emb=rotary_emb,
176
+ )
177
+ output = output.view(batch_size, seq_len, -1)
178
+ return output
nunchaku/ops/gemm.py ADDED
@@ -0,0 +1,160 @@
1
+ """
2
+ Python wrappers for Nunchaku's high-performance quantized GEMM (General Matrix-Matrix Multiplication) CUDA kernels.
3
+ """
4
+
5
+ import math
6
+
7
+ import torch
8
+
9
+ from .._C import ops
10
+
11
+
12
+ def svdq_gemm_w4a4_cuda(
13
+ act: torch.Tensor,
14
+ wgt: torch.Tensor,
15
+ out: torch.Tensor | None = None,
16
+ qout: torch.Tensor | None = None,
17
+ ascales: torch.Tensor | None = None,
18
+ wscales: torch.Tensor | None = None,
19
+ oscales: torch.Tensor | None = None,
20
+ poolout: torch.Tensor | None = None,
21
+ lora_act_in: torch.Tensor | None = None,
22
+ lora_up: torch.Tensor | None = None,
23
+ lora_down: torch.Tensor | None = None,
24
+ lora_act_out: torch.Tensor | None = None,
25
+ norm_q: torch.Tensor | None = None,
26
+ norm_k: torch.Tensor | None = None,
27
+ rotary_emb: torch.Tensor | None = None,
28
+ bias: torch.Tensor | None = None,
29
+ smooth_factor: torch.Tensor | None = None,
30
+ out_vk: torch.Tensor | None = None,
31
+ out_linearattn: torch.Tensor | None = None,
32
+ act_unsigned: bool = False,
33
+ lora_scales: list[float] | None = None,
34
+ fuse_silu: bool = False,
35
+ fp4: bool = False,
36
+ alpha: float | None = 1.0,
37
+ wcscales: torch.Tensor | None = None,
38
+ out_q: torch.Tensor | None = None,
39
+ out_k: torch.Tensor | None = None,
40
+ out_v: torch.Tensor | None = None,
41
+ attn_tokens: int = 0,
42
+ ):
43
+ """
44
+ Quantized GEMM using SVDQuant W4A4 CUDA kernel, with support for LoRA, rotary embeddings, normalization, and fused activations.
45
+
46
+ Parameters
47
+ ----------
48
+ act : torch.Tensor, shape (M, K // 2), dtype int8
49
+ Packed input activations.
50
+ wgt : torch.Tensor, shape (N, K // 2), dtype int8
51
+ Packed quantized weights.
52
+ out : torch.Tensor or None, shape (M, N), dtype float16 or bfloat16, optional
53
+ Output tensor for the linear layer.
54
+ qout : torch.Tensor or None, shape (M, N // 2), dtype int8, optional
55
+ Packed quantized input for the next layer.
56
+ ascales : torch.Tensor or None, shape (K // G, M), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
57
+ Activation scales.
58
+ wscales : torch.Tensor or None, shape (K // G, N), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
59
+ Weight scales.
60
+ oscales : torch.Tensor or None, shape (N // G, M), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
61
+ Output scales.
62
+ poolout : torch.Tensor or None, optional
63
+ Reserved for future use.
64
+ lora_act_in : torch.Tensor or None, shape (M, R), dtype float32, optional
65
+ LoRA down-projection activations.
66
+ lora_up : torch.Tensor or None, shape (N, R), dtype float16 or bfloat16, optional
67
+ Packed LoRA up-projection weights.
68
+ lora_down : torch.Tensor or None, shape (N, R), dtype float16 or bfloat16, optional
69
+ Packed LoRA down-projection weights for the next layer.
70
+ lora_act_out : torch.Tensor or None, shape (M, R), dtype float32, optional
71
+ Output for LoRA down-projection in the next layer.
72
+ norm_q : torch.Tensor or None, shape (HEAD_DIM,), dtype float16 or bfloat16, optional
73
+ Query RMS normalization.
74
+ norm_k : torch.Tensor or None, shape (HEAD_DIM,), dtype float16 or bfloat16, optional
75
+ Key RMS normalization.
76
+ rotary_emb : torch.Tensor or None, shape (M, HEAD_DIM // 2, 2, 2), dtype float32, optional
77
+ Packed rotary embeddings.
78
+ bias : torch.Tensor or None, shape (N,), dtype float16 or bfloat16, optional
79
+ Bias tensor.
80
+ smooth_factor : torch.Tensor or None, shape (N,), dtype float16 or bfloat16, optional
81
+ Smoothing factor for quantization in the next layer.
82
+ out_vk : torch.Tensor or None, optional
83
+ Used only in SANA. Leave as None.
84
+ out_linearattn : torch.Tensor or None, optional
85
+ Used only in SANA. Leave as None.
86
+ act_unsigned : bool, default=False
87
+ If True, treat activations as unsigned (e.g., GELU outputs shifted by 0.171875). INT4 only; unsigned quantization improves quality for non-negative activations.
88
+ lora_scales : list of float or None, optional
89
+ Per-group LoRA scaling factors (16 channels per group). Defaults to 1.0 per group.
90
+ fuse_silu : bool, default=False
91
+ If True, fuse SiLU activation.
92
+ fp4 : bool, default=False
93
+ If True, use 4-bit floating point quantization (NVFP4).
94
+ alpha : float or None, default=1.0
95
+ Per-tensor scaling factor for NVFP4.
96
+ wcscales : torch.Tensor or None, shape (N,), dtype float8_e4m3fn, optional
97
+ Per-channel scaling for NVFP4.
98
+ out_q : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
99
+ Packed quantized Q for attention (used in ``nunchaku-fp16`` attention).
100
+ out_k : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
101
+ Packed quantized K for attention (used in ``nunchaku-fp16`` attention).
102
+ out_v : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
103
+ Packed quantized V for attention (used in ``nunchaku-fp16`` attention).
104
+ attn_tokens : int, default=0
105
+ Number of attention tokens.
106
+
107
+ Returns
108
+ -------
109
+ None
110
+ Results are written in-place to the provided output tensors.
111
+
112
+ Notes
113
+ -----
114
+ Notations:
115
+
116
+ - M: batch size (input tokens)
117
+ - K: input channels (feature dimension)
118
+ - N: output channels
119
+ - G: group size (64 for INT4, 16 for NVFP4)
120
+ - R: LoRA rank
121
+ - B: batch size for attention
122
+ - H: number of heads
123
+ - D: head dimension
124
+ """
125
+ if lora_scales is None:
126
+ rank = lora_up.shape[1]
127
+ lora_scales = [1.0] * math.ceil(rank / 16)
128
+ if alpha is None:
129
+ alpha = 1.0
130
+ ops.gemm_w4a4(
131
+ act,
132
+ wgt,
133
+ out,
134
+ qout,
135
+ ascales,
136
+ wscales,
137
+ oscales,
138
+ poolout,
139
+ lora_act_in,
140
+ lora_up,
141
+ lora_down,
142
+ lora_act_out,
143
+ norm_q,
144
+ norm_k,
145
+ rotary_emb,
146
+ bias,
147
+ smooth_factor,
148
+ out_vk,
149
+ out_linearattn,
150
+ act_unsigned,
151
+ lora_scales,
152
+ fuse_silu,
153
+ fp4,
154
+ alpha,
155
+ wcscales,
156
+ out_q,
157
+ out_k,
158
+ out_v,
159
+ attn_tokens,
160
+ )
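
One default worth spelling out: when ``lora_scales`` is omitted, the wrapper derives the group count from the LoRA rank, one scale per 16 channels. A worked example using the ``SVD_RANK`` of 32 seen elsewhere in this commit:

```python
# Worked example of the default lora_scales computed by svdq_gemm_w4a4_cuda.
import math

rank = 32                                   # SVD_RANK used by the SVDQuant layers
lora_scales = [1.0] * math.ceil(rank / 16)  # one scale per 16-channel group
print(lora_scales)                          # [1.0, 1.0]
```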
nunchaku/ops/gemv.py ADDED
1
+ """
2
+ Python wrapper for Nunchaku's high-performance GEMV (General Matrix-Vector Multiplication) CUDA kernels.
3
+ """
4
+
5
+ import torch
6
+
7
+ from .._C import ops
8
+
9
+
10
+ def awq_gemv_w4a16_cuda(
11
+ in_feats: torch.Tensor,
12
+ kernel: torch.Tensor,
13
+ scaling_factors: torch.Tensor,
14
+ zeros: torch.Tensor,
15
+ m: int,
16
+ n: int,
17
+ k: int,
18
+ group_size: int = 64,
19
+ ) -> torch.Tensor:
20
+ """
21
+ Performs quantized GEMV using the AWQ W4A16 format.
22
+
23
+ Parameters
24
+ ----------
25
+ in_feats : torch.Tensor, shape (k,) or (m, k), dtype float16 or bfloat16
26
+ Input feature vector or batch of vectors.
27
+ kernel : torch.Tensor, shape (n // 4, k // 2), dtype int32
28
+ Packed quantized weight matrix.
29
+ scaling_factors : torch.Tensor, shape (k // group_size, n), dtype float16 or bfloat16
30
+ Per-group scaling factors.
31
+ zeros : torch.Tensor, shape (k // group_size, n), dtype float16 or bfloat16
32
+ Per-group zero points.
33
+ m : int
34
+ Batch size (number of input vectors).
35
+ n : int
36
+ Output feature dimension.
37
+ k : int
38
+ Input feature dimension.
39
+ group_size : int, optional
40
+ Number of input channels per quantization group. Default is 64.
41
+
42
+ Returns
43
+ -------
44
+ torch.Tensor, shape (m, n), dtype float16 or bfloat16
45
+ Output tensor.
46
+
47
+ Notes
48
+ -----
49
+ Notations:
50
+
51
+ - m: batch size
52
+ - n: output features
53
+ - k: input features
54
+ - group_size: quantization group size
55
+ """
56
+ return ops.gemv_awq(in_feats, kernel, scaling_factors, zeros, m, n, k, group_size)
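
The ``(n // 4, k // 2)`` kernel shape is just the int4 packing budget: each int32 word holds eight 4-bit values, so ``(n/4) * (k/2)`` words store exactly ``n * k`` nibbles. A quick sanity check for a hypothetical layer size:

```python
# Sanity check of the AWQ W4A16 packing arithmetic for a hypothetical layer.
n, k, group_size = 4096, 4096, 64

packed_words = (n // 4) * (k // 2)  # int32 words in `kernel`
assert packed_words * 8 == n * k    # 8 int4 values per int32 word
num_groups = k // group_size        # rows in scaling_factors and zeros
print(packed_words, num_groups)     # 2097152 64
```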
nunchaku/ops/quantize.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ This module provides Python wrappers for Nunchaku's high-performance SVDQuant quantization CUDA kernels.
3
+ """
4
+
5
+ import torch
6
+
7
+ from .._C import ops
8
+ from ..utils import ceil_divide
9
+
10
+
11
+ def svdq_quantize_w4a4_act_fuse_lora_cuda(
12
+ input: torch.Tensor,
13
+ output: torch.Tensor | None = None,
14
+ oscales: torch.Tensor | None = None,
15
+ lora_down: torch.Tensor | None = None,
16
+ lora_act_out: torch.Tensor | None = None,
17
+ smooth: torch.Tensor | None = None,
18
+ fuse_glu: bool = False,
19
+ fp4: bool = False,
20
+ pad_size: int = 256,
21
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
22
+ """
23
+ Quantizes activations and computes LoRA down-projection using SVDQuant W4A4 CUDA kernel.
24
+
25
+ Parameters
26
+ ----------
27
+ input : torch.Tensor, shape (M, K), dtype bfloat16/float16
28
+ Input activations.
29
+ output : torch.Tensor or None, shape (M_pad, K // 2), dtype uint8, optional
30
+ Packed output tensor for quantized activations. Allocated if None.
31
+ oscales : torch.Tensor or None, shape (K // G, M_pad), dtype float8_e4m3fn for NVFP4 or input dtype for INT4, optional
32
+ Output scales tensor. Allocated if None.
33
+ lora_down : torch.Tensor or None, shape (K, R), dtype bfloat16/float16, optional
34
+ Packed LoRA down-projection weights.
35
+ lora_act_out : torch.Tensor or None, shape (M_pad, R), dtype float32, optional
36
+ Packed output tensor for LoRA activations. Allocated if None.
37
+ smooth : torch.Tensor or None, optional, dtype bfloat16/float16
38
+ Smoothing factor for quantization.
39
+ fuse_glu : bool, default=False
40
+ If True, fuse GLU activation.
41
+ fp4 : bool, default=False
42
+ If True, use NVFP4 quantization; else INT4.
43
+ pad_size : int, default=256
44
+ Pad batch size to a multiple of this value for efficient CUDA execution.
45
+
46
+ Returns
47
+ -------
48
+ output : torch.Tensor, shape (M_pad, K // 2), dtype uint8
49
+ Packed quantized activations.
50
+ oscales : torch.Tensor, shape (K // G, M_pad), dtype float8_e4m3fn for NVFP4 or input dtype for INT4
51
+ Output scales.
52
+ lora_act_out : torch.Tensor, shape (M_pad, R), dtype float32
53
+ Packed LoRA activation output.
54
+
55
+ Notes
56
+ -----
57
+ Notations:
58
+
59
+ - M: batch size
60
+ - K: input channels
61
+ - R: LoRA rank
62
+ - G: group size (64 for INT4, 16 for NVFP4)
63
+ - M_pad: padded batch size = ceil(M / pad_size) * pad_size
64
+ """
65
+ batch_size, channels = input.shape
66
+ rank = lora_down.shape[1]
67
+ batch_size_pad = ceil_divide(batch_size, pad_size) * pad_size
68
+ if output is None:
69
+ output = torch.empty(batch_size_pad, channels // 2, dtype=torch.uint8, device=input.device)
70
+ if oscales is None:
71
+ if fp4:
72
+ assert channels % 16 == 0
73
+ oscales = torch.empty(channels // 16, batch_size_pad, dtype=torch.float8_e4m3fn, device=input.device)
74
+ else:
75
+ assert channels % 64 == 0
76
+ oscales = torch.empty(channels // 64, batch_size_pad, dtype=input.dtype, device=input.device)
77
+ if lora_act_out is None:
78
+ lora_act_out = torch.empty(batch_size_pad, rank, dtype=torch.float32, device=input.device)
79
+
80
+ ops.quantize_w4a4_act_fuse_lora(input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4)
81
+ return output, oscales, lora_act_out
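
To make the allocation logic concrete: INT4 groups scales over 64 input channels and keeps the activation dtype, while NVFP4 groups over 16 channels and stores float8 scales; all outputs are padded to ``M_pad`` rows. A worked shape example (values are illustrative):

```python
# Worked output shapes for svdq_quantize_w4a4_act_fuse_lora_cuda.
M, K, R, pad_size = 1000, 3072, 32, 256
M_pad = -(-M // pad_size) * pad_size  # ceil(1000 / 256) * 256 = 1024

int4_shapes = {
    "output": (M_pad, K // 2),    # packed 4-bit activations, uint8
    "oscales": (K // 64, M_pad),  # one scale per 64-channel group, input dtype
    "lora_act_out": (M_pad, R),   # float32 LoRA activations
}
nvfp4_oscales = (K // 16, M_pad)  # float8_e4m3fn scales, 16-channel groups
print(int4_shapes, nvfp4_oscales)
```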