| | import time |
| | from contextlib import contextmanager |
| | from pathlib import Path |
| |
|
| | import accelerate |
| | import torch |
| | from safetensors.torch import load_file, save_file |
| |
|
| | from invokeai.backend.flux.model import Flux |
| | from invokeai.backend.flux.util import params |
| | from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 |
| |
|
| |
|
@contextmanager
def log_time(name: str):
    """Context manager that logs the elapsed wall-clock time of a block of code.

    Args:
        name: Label printed alongside the measured duration.

    Yields:
        None.
    """
    # perf_counter() is monotonic and higher-resolution than time.time(), so the
    # reported duration cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    try:
        yield None
    finally:
        end = time.perf_counter()
        print(f"'{name}' took {end - start:.4f} secs")
| |
|
| |
|
def main():
    """A script for quantizing a FLUX transformer model using the bitsandbytes NF4 quantization method.

    This script is primarily intended for reference. The script params (e.g. the model_path,
    modules_to_not_convert, etc.) are hardcoded and would need to be modified for other use cases.

    Returns:
        The NF4-quantized Flux model, resident on the CUDA device.
    """
    model_path = Path(
        "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
    )

    # Build the model skeleton on the meta device so no weight memory is allocated yet.
    with log_time("Initialize FLUX transformer on meta device"):
        p = params["flux-schnell"]
        with accelerate.init_empty_weights():
            model = Flux(p)

    # No modules are excluded from conversion here; populate this set with submodule
    # names to keep selected layers in their original precision.
    modules_to_not_convert: set[str] = set()

    model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
    if model_nf4_path.exists():
        # Fast path: a pre-quantized checkpoint exists, so load it directly.
        print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")

        with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
            model = quantize_model_nf4(
                model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
            )

        with log_time("Load state dict into model"):
            state_dict = load_file(model_nf4_path)
            # assign=True replaces the meta-device parameters with the loaded tensors in place.
            model.load_state_dict(state_dict, strict=True, assign=True)

        with log_time("Move model to cuda"):
            model = model.to("cuda")

        print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.")

    else:
        # Slow path: load the original full-precision weights, quantize, and cache the result.
        print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...")

        with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
            model = quantize_model_nf4(
                model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
            )

        with log_time("Load state dict into model"):
            state_dict = load_file(model_path)
            model.load_state_dict(state_dict, strict=True, assign=True)

        # NOTE: moving to CUDA is what triggers the actual NF4 quantization of the
        # bitsandbytes layers — presumably on first device transfer; verify against
        # the bnb_nf4 implementation.
        with log_time("Move model to cuda and quantize"):
            model = model.to("cuda")

        with log_time("Save quantized model"):
            model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
            save_file(model.state_dict(), model_nf4_path)

        print(f"Successfully quantized and saved model to '{model_nf4_path}'.")

    assert isinstance(model, Flux)
    return model
| |
|
| |
|
# Script entry point: run the quantize-or-load workflow when executed directly.
if __name__ == "__main__":
    main()
| |
|