Update README.md
Browse files
README.md
CHANGED
|
@@ -128,7 +128,7 @@ Potential fix: app.diffusion.pipeline.config.py
|
|
| 128 |
if isinstance(module, torch.nn.Module):
|
| 129 |
try:
|
| 130 |
print(f">>> Moving {name} to {device} using to_empty()")
|
| 131 |
-
module.to_empty(device)
|
| 132 |
except Exception as e:
|
| 133 |
print(f">>> WARNING: {name}.to_empty({device}) failed: {e}")
|
| 134 |
try:
|
|
@@ -150,38 +150,6 @@ Potential fix: app.diffusion.pipeline.config.py
|
|
| 150 |
return pipeline
|
| 151 |
```
|
| 152 |
|
| 153 |
-
Debug Log
|
| 154 |
-
```
|
| 155 |
-
25-07-22 20:11:56 | I | === Start Evaluating ===
|
| 156 |
-
25-07-22 20:11:56 | I | * Building diffusion model pipeline
|
| 157 |
-
Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]
|
| 158 |
-
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
|
| 159 |
-
Loading checkpoint shards: 100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 2/2 [00:00<00:00, 18.92it/s]
|
| 160 |
-
Loading pipeline components...: 100%|βββββββββββββββββββββββββββββββββββββββββββββββββββββ| 7/7 [00:00<00:00, 9.50it/s]
|
| 161 |
-
>>> DEVICE: cuda
|
| 162 |
-
>>> PIPELINE TYPE: <class 'diffusers.pipelines.flux.pipeline_flux_kontext.FluxKontextPipeline'>
|
| 163 |
-
>>> Moving transformer to cuda using to_empty()
|
| 164 |
-
>>> WARNING: transformer.to_empty(cuda) failed: Module.to_empty() takes 1 positional argument but 2 were given
|
| 165 |
-
>>> Falling back to transformer.to(cuda)
|
| 166 |
-
>>> ERROR: transformer.to(cuda) also failed: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
|
| 167 |
-
>>> Moving vae to cuda using to_empty()
|
| 168 |
-
>>> WARNING: vae.to_empty(cuda) failed: Module.to_empty() takes 1 positional argument but 2 were given
|
| 169 |
-
>>> Falling back to vae.to(cuda)
|
| 170 |
-
>>> Moving text_encoder to cuda using to_empty()
|
| 171 |
-
>>> WARNING: text_encoder.to_empty(cuda) failed: Module.to_empty() takes 1 positional argument but 2 were given
|
| 172 |
-
>>> Falling back to text_encoder.to(cuda)
|
| 173 |
-
25-07-22 20:11:59 | I | Replacing fused Linear with ConcatLinear.
|
| 174 |
-
25-07-22 20:11:59 | I | + Replacing fused Linear in single_transformer_blocks.0 with ConcatLinear.
|
| 175 |
-
25-07-22 20:11:59 | I | - in_features = 3072/15360
|
| 176 |
-
25-07-22 20:11:59 | I | - out_features = 3072
|
| 177 |
-
25-07-22 20:11:59 | I | + Replacing fused Linear in single_transformer_blocks.1 with ConcatLinear.
|
| 178 |
-
25-07-22 20:11:59 | I | - in_features = 3072/15360
|
| 179 |
-
25-07-22 20:11:59 | I | - out_features = 3072
|
| 180 |
-
25-07-22 20:11:59 | I | + Replacing fused Linear in single_transformer_blocks.2 with ConcatLinear.
|
| 181 |
-
25-07-22 20:11:59 | I | - in_features = 3072/15360
|
| 182 |
-
25-07-22 20:11:59 | I | - out_features = 3072
|
| 183 |
-
```
|
| 184 |
-
|
| 185 |
2) KeyError: <class 'diffusers.models.transformers.transformer_flux.FluxAttention'>
|
| 186 |
|
| 187 |
Potential fix: app.diffusion.nn.struct.py
|
|
@@ -413,7 +381,7 @@ def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
|
|
| 413 |
|
| 414 |
4) RuntimeError: Tensor.item() cannot be called on meta tensors
|
| 415 |
|
| 416 |
-
Potential Fix:
|
| 417 |
|
| 418 |
```python
|
| 419 |
def quantize_scale(
|
|
@@ -591,6 +559,288 @@ def quantize_scale(
|
|
| 591 |
return s, z
|
| 592 |
```
|
| 593 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
References
|
| 595 |
|
| 596 |
https://github.com/nunchaku-tech/nunchaku/commit/b99fb8be615bc98c6915bbe06a1e0092cbc074a5
|
|
|
|
| 128 |
if isinstance(module, torch.nn.Module):
|
| 129 |
try:
|
| 130 |
print(f">>> Moving {name} to {device} using to_empty()")
|
| 131 |
+
module.to_empty(device=device)
|
| 132 |
except Exception as e:
|
| 133 |
print(f">>> WARNING: {name}.to_empty({device}) failed: {e}")
|
| 134 |
try:
|
|
|
|
| 150 |
return pipeline
|
| 151 |
```
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
2) KeyError: <class 'diffusers.models.transformers.transformer_flux.FluxAttention'>
|
| 154 |
|
| 155 |
Potential fix: app.diffusion.nn.struct.py
|
|
|
|
| 381 |
|
| 382 |
4) RuntimeError: Tensor.item() cannot be called on meta tensors
|
| 383 |
|
| 384 |
+
Potential Fix: quantizer.impl.scale.py
|
| 385 |
|
| 386 |
```python
|
| 387 |
def quantize_scale(
|
|
|
|
| 559 |
return s, z
|
| 560 |
```
|
| 561 |
|
| 562 |
+
Potential Fix: app.diffusion.ptq.py
|
| 563 |
+
|
| 564 |
+
```python
|
| 565 |
+
def ptq( # noqa: C901
|
| 566 |
+
model: DiffusionModelStruct,
|
| 567 |
+
config: DiffusionQuantConfig,
|
| 568 |
+
cache: DiffusionPtqCacheConfig | None = None,
|
| 569 |
+
load_dirpath: str = "",
|
| 570 |
+
save_dirpath: str = "",
|
| 571 |
+
copy_on_save: bool = False,
|
| 572 |
+
save_model: bool = False,
|
| 573 |
+
) -> DiffusionModelStruct:
|
| 574 |
+
"""Post-training quantization of a diffusion model.
|
| 575 |
+
|
| 576 |
+
Args:
|
| 577 |
+
model (`DiffusionModelStruct`):
|
| 578 |
+
The diffusion model.
|
| 579 |
+
config (`DiffusionQuantConfig`):
|
| 580 |
+
The diffusion model post-training quantization configuration.
|
| 581 |
+
cache (`DiffusionPtqCacheConfig`, *optional*, defaults to `None`):
|
| 582 |
+
The diffusion model quantization cache path configuration.
|
| 583 |
+
load_dirpath (`str`, *optional*, defaults to `""`):
|
| 584 |
+
The directory path to load the quantization checkpoint.
|
| 585 |
+
save_dirpath (`str`, *optional*, defaults to `""`):
|
| 586 |
+
The directory path to save the quantization checkpoint.
|
| 587 |
+
copy_on_save (`bool`, *optional*, defaults to `False`):
|
| 588 |
+
Whether to copy the cache to the save directory.
|
| 589 |
+
save_model (`bool`, *optional*, defaults to `False`):
|
| 590 |
+
Whether to save the quantized model checkpoint.
|
| 591 |
+
|
| 592 |
+
Returns:
|
| 593 |
+
`DiffusionModelStruct`:
|
| 594 |
+
The quantized diffusion model.
|
| 595 |
+
"""
|
| 596 |
+
logger = tools.logging.getLogger(__name__)
|
| 597 |
+
if not isinstance(model, DiffusionModelStruct):
|
| 598 |
+
model = DiffusionModelStruct.construct(model)
|
| 599 |
+
assert isinstance(model, DiffusionModelStruct)
|
| 600 |
+
|
| 601 |
+
quant_wgts = config.enabled_wgts
|
| 602 |
+
quant_ipts = config.enabled_ipts
|
| 603 |
+
quant_opts = config.enabled_opts
|
| 604 |
+
quant_acts = quant_ipts or quant_opts
|
| 605 |
+
quant = quant_wgts or quant_acts
|
| 606 |
+
|
| 607 |
+
load_model_path, load_path, save_path = "", None, None
|
| 608 |
+
if load_dirpath:
|
| 609 |
+
load_path = DiffusionQuantCacheConfig(
|
| 610 |
+
smooth=os.path.join(load_dirpath, "smooth.pt"),
|
| 611 |
+
branch=os.path.join(load_dirpath, "branch.pt"),
|
| 612 |
+
wgts=os.path.join(load_dirpath, "wgts.pt"),
|
| 613 |
+
acts=os.path.join(load_dirpath, "acts.pt"),
|
| 614 |
+
)
|
| 615 |
+
load_model_path = os.path.join(load_dirpath, "model.pt")
|
| 616 |
+
if os.path.exists(load_model_path):
|
| 617 |
+
if config.enabled_wgts and config.wgts.enabled_low_rank:
|
| 618 |
+
if os.path.exists(load_path.branch):
|
| 619 |
+
load_model = True
|
| 620 |
+
else:
|
| 621 |
+
logger.warning(f"Model low-rank branch checkpoint {load_path.branch} does not exist")
|
| 622 |
+
load_model = False
|
| 623 |
+
else:
|
| 624 |
+
load_model = True
|
| 625 |
+
if load_model:
|
| 626 |
+
logger.info(f"* Loading model from {load_model_path}")
|
| 627 |
+
save_dirpath = "" # do not save the model if loading
|
| 628 |
+
else:
|
| 629 |
+
logger.warning(f"Model checkpoint {load_model_path} does not exist")
|
| 630 |
+
load_model = False
|
| 631 |
+
else:
|
| 632 |
+
load_model = False
|
| 633 |
+
if save_dirpath:
|
| 634 |
+
os.makedirs(save_dirpath, exist_ok=True)
|
| 635 |
+
save_path = DiffusionQuantCacheConfig(
|
| 636 |
+
smooth=os.path.join(save_dirpath, "smooth.pt"),
|
| 637 |
+
branch=os.path.join(save_dirpath, "branch.pt"),
|
| 638 |
+
wgts=os.path.join(save_dirpath, "wgts.pt"),
|
| 639 |
+
acts=os.path.join(save_dirpath, "acts.pt"),
|
| 640 |
+
)
|
| 641 |
+
else:
|
| 642 |
+
save_model = False
|
| 643 |
+
|
| 644 |
+
if quant and config.enabled_rotation:
|
| 645 |
+
logger.info("* Rotating model for quantization")
|
| 646 |
+
tools.logging.Formatter.indent_inc()
|
| 647 |
+
rotate_diffusion(model, config=config)
|
| 648 |
+
tools.logging.Formatter.indent_dec()
|
| 649 |
+
gc.collect()
|
| 650 |
+
torch.cuda.empty_cache()
|
| 651 |
+
|
| 652 |
+
# region smooth quantization
|
| 653 |
+
if quant and config.enabled_smooth:
|
| 654 |
+
logger.info("* Smoothing model for quantization")
|
| 655 |
+
tools.logging.Formatter.indent_inc()
|
| 656 |
+
load_from = ""
|
| 657 |
+
if load_path and os.path.exists(load_path.smooth):
|
| 658 |
+
load_from = load_path.smooth
|
| 659 |
+
elif cache and cache.path.smooth and os.path.exists(cache.path.smooth):
|
| 660 |
+
load_from = cache.path.smooth
|
| 661 |
+
if load_from:
|
| 662 |
+
logger.info(f"- Loading smooth scales from {load_from}")
|
| 663 |
+
smooth_cache = torch.load(load_from)
|
| 664 |
+
smooth_diffusion(model, config, smooth_cache=smooth_cache)
|
| 665 |
+
else:
|
| 666 |
+
logger.info("- Generating smooth scales")
|
| 667 |
+
smooth_cache = smooth_diffusion(model, config)
|
| 668 |
+
if cache and cache.path.smooth:
|
| 669 |
+
logger.info(f"- Saving smooth scales to {cache.path.smooth}")
|
| 670 |
+
os.makedirs(cache.dirpath.smooth, exist_ok=True)
|
| 671 |
+
torch.save(smooth_cache, cache.path.smooth)
|
| 672 |
+
load_from = cache.path.smooth
|
| 673 |
+
if save_path:
|
| 674 |
+
if not copy_on_save and load_from:
|
| 675 |
+
logger.info(f"- Linking smooth scales to {save_path.smooth}")
|
| 676 |
+
os.symlink(os.path.relpath(load_from, save_dirpath), save_path.smooth)
|
| 677 |
+
else:
|
| 678 |
+
logger.info(f"- Saving smooth scales to {save_path.smooth}")
|
| 679 |
+
torch.save(smooth_cache, save_path.smooth)
|
| 680 |
+
del smooth_cache
|
| 681 |
+
tools.logging.Formatter.indent_dec()
|
| 682 |
+
gc.collect()
|
| 683 |
+
torch.cuda.empty_cache()
|
| 684 |
+
# endregion
|
| 685 |
+
# region collect original state dict
|
| 686 |
+
if config.needs_acts_quantizer_cache:
|
| 687 |
+
if load_path and os.path.exists(load_path.acts):
|
| 688 |
+
orig_state_dict = None
|
| 689 |
+
elif cache and cache.path.acts and os.path.exists(cache.path.acts):
|
| 690 |
+
orig_state_dict = None
|
| 691 |
+
else:
|
| 692 |
+
orig_state_dict: dict[str, torch.Tensor] = {
|
| 693 |
+
name: param.detach().clone() for name, param in model.module.named_parameters() if param.ndim > 1
|
| 694 |
+
}
|
| 695 |
+
else:
|
| 696 |
+
orig_state_dict = None
|
| 697 |
+
# endregion
|
| 698 |
+
if load_model:
|
| 699 |
+
logger.info(f"* Loading model checkpoint from {load_model_path}")
|
| 700 |
+
load_diffusion_weights_state_dict(
|
| 701 |
+
model,
|
| 702 |
+
config,
|
| 703 |
+
state_dict=torch.load(load_model_path),
|
| 704 |
+
branch_state_dict=torch.load(load_path.branch) if os.path.exists(load_path.branch) else None,
|
| 705 |
+
)
|
| 706 |
+
gc.collect()
|
| 707 |
+
torch.cuda.empty_cache()
|
| 708 |
+
elif quant_wgts:
|
| 709 |
+
logger.info("* Ensuring model is on actual device before quantization")
|
| 710 |
+
|
| 711 |
+
# Check if model has meta tensors
|
| 712 |
+
has_meta_tensors = any(param.is_meta for param in model.module.parameters())
|
| 713 |
+
|
| 714 |
+
if has_meta_tensors:
|
| 715 |
+
logger.info("* Model contains meta tensors, materializing to actual device")
|
| 716 |
+
|
| 717 |
+
# Option 1: Use to_empty() and reload weights (recommended)
|
| 718 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 719 |
+
|
| 720 |
+
# Store original state dict if available
|
| 721 |
+
try:
|
| 722 |
+
original_state_dict = model.module.state_dict()
|
| 723 |
+
model.module = model.module.to_empty(device=device)
|
| 724 |
+
model.module.load_state_dict(original_state_dict)
|
| 725 |
+
logger.info("* Successfully materialized model with original weights")
|
| 726 |
+
except Exception as e:
|
| 727 |
+
logger.warning(f"* Failed to preserve weights during materialization: {e}")
|
| 728 |
+
# Fallback: just move to empty device (weights will be zero)
|
| 729 |
+
model.module = model.module.to_empty(device=device)
|
| 730 |
+
logger.warning("* Model moved to device but weights may be uninitialized")
|
| 731 |
+
else:
|
| 732 |
+
# Model already has real tensors, just ensure it's on the right device
|
| 733 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 734 |
+
model.module = model.module.to(device)
|
| 735 |
+
|
| 736 |
+
# Verify no meta tensors remain
|
| 737 |
+
remaining_meta = [name for name, param in model.module.named_parameters() if param.is_meta]
|
| 738 |
+
if remaining_meta:
|
| 739 |
+
raise RuntimeError(f"Parameters still on meta device: {remaining_meta}")
|
| 740 |
+
|
| 741 |
+
logger.info("* Model successfully prepared for quantization")
|
| 742 |
+
|
| 743 |
+
logger.info("* Quantizing weights")
|
| 744 |
+
tools.logging.Formatter.indent_inc()
|
| 745 |
+
quantizer_state_dict, quantizer_load_from = None, ""
|
| 746 |
+
if load_path and os.path.exists(load_path.wgts):
|
| 747 |
+
quantizer_load_from = load_path.wgts
|
| 748 |
+
elif cache and cache.path.wgts and os.path.exists(cache.path.wgts):
|
| 749 |
+
quantizer_load_from = cache.path.wgts
|
| 750 |
+
if quantizer_load_from:
|
| 751 |
+
logger.info(f"- Loading weight settings from {quantizer_load_from}")
|
| 752 |
+
quantizer_state_dict = torch.load(quantizer_load_from)
|
| 753 |
+
branch_state_dict, branch_load_from = None, ""
|
| 754 |
+
if load_path and os.path.exists(load_path.branch):
|
| 755 |
+
branch_load_from = load_path.branch
|
| 756 |
+
elif cache and cache.path.branch and os.path.exists(cache.path.branch):
|
| 757 |
+
branch_load_from = cache.path.branch
|
| 758 |
+
if branch_load_from:
|
| 759 |
+
logger.info(f"- Loading branch settings from {branch_load_from}")
|
| 760 |
+
branch_state_dict = torch.load(branch_load_from)
|
| 761 |
+
if not quantizer_load_from:
|
| 762 |
+
logger.info("- Generating weight settings")
|
| 763 |
+
if not branch_load_from:
|
| 764 |
+
logger.info("- Generating branch settings")
|
| 765 |
+
quantizer_state_dict, branch_state_dict, scale_state_dict = quantize_diffusion_weights(
|
| 766 |
+
model,
|
| 767 |
+
config,
|
| 768 |
+
quantizer_state_dict=quantizer_state_dict,
|
| 769 |
+
branch_state_dict=branch_state_dict,
|
| 770 |
+
return_with_scale_state_dict=bool(save_dirpath),
|
| 771 |
+
)
|
| 772 |
+
if not quantizer_load_from and cache and cache.dirpath.wgts:
|
| 773 |
+
logger.info(f"- Saving weight settings to {cache.path.wgts}")
|
| 774 |
+
os.makedirs(cache.dirpath.wgts, exist_ok=True)
|
| 775 |
+
torch.save(quantizer_state_dict, cache.path.wgts)
|
| 776 |
+
quantizer_load_from = cache.path.wgts
|
| 777 |
+
if not branch_load_from and cache and cache.dirpath.branch:
|
| 778 |
+
logger.info(f"- Saving branch settings to {cache.path.branch}")
|
| 779 |
+
os.makedirs(cache.dirpath.branch, exist_ok=True)
|
| 780 |
+
torch.save(branch_state_dict, cache.path.branch)
|
| 781 |
+
branch_load_from = cache.path.branch
|
| 782 |
+
if save_path:
|
| 783 |
+
if not copy_on_save and quantizer_load_from:
|
| 784 |
+
logger.info(f"- Linking weight settings to {save_path.wgts}")
|
| 785 |
+
os.symlink(os.path.relpath(quantizer_load_from, save_dirpath), save_path.wgts)
|
| 786 |
+
else:
|
| 787 |
+
logger.info(f"- Saving weight settings to {save_path.wgts}")
|
| 788 |
+
torch.save(quantizer_state_dict, save_path.wgts)
|
| 789 |
+
if not copy_on_save and branch_load_from:
|
| 790 |
+
logger.info(f"- Linking branch settings to {save_path.branch}")
|
| 791 |
+
os.symlink(os.path.relpath(branch_load_from, save_dirpath), save_path.branch)
|
| 792 |
+
else:
|
| 793 |
+
logger.info(f"- Saving branch settings to {save_path.branch}")
|
| 794 |
+
torch.save(branch_state_dict, save_path.branch)
|
| 795 |
+
if save_model:
|
| 796 |
+
logger.info(f"- Saving model to {save_dirpath}")
|
| 797 |
+
torch.save(scale_state_dict, os.path.join(save_dirpath, "scale.pt"))
|
| 798 |
+
torch.save(model.module.state_dict(), os.path.join(save_dirpath, "model.pt"))
|
| 799 |
+
del quantizer_state_dict, branch_state_dict, scale_state_dict
|
| 800 |
+
tools.logging.Formatter.indent_dec()
|
| 801 |
+
gc.collect()
|
| 802 |
+
torch.cuda.empty_cache()
|
| 803 |
+
if quant_acts:
|
| 804 |
+
logger.info(" * Quantizing activations")
|
| 805 |
+
tools.logging.Formatter.indent_inc()
|
| 806 |
+
if config.needs_acts_quantizer_cache:
|
| 807 |
+
load_from = ""
|
| 808 |
+
if load_path and os.path.exists(load_path.acts):
|
| 809 |
+
load_from = load_path.acts
|
| 810 |
+
elif cache and cache.path.acts and os.path.exists(cache.path.acts):
|
| 811 |
+
load_from = cache.path.acts
|
| 812 |
+
if load_from:
|
| 813 |
+
logger.info(f"- Loading activation settings from {load_from}")
|
| 814 |
+
quantizer_state_dict = torch.load(load_from)
|
| 815 |
+
quantize_diffusion_activations(
|
| 816 |
+
model, config, quantizer_state_dict=quantizer_state_dict, orig_state_dict=orig_state_dict
|
| 817 |
+
)
|
| 818 |
+
else:
|
| 819 |
+
logger.info("- Generating activation settings")
|
| 820 |
+
quantizer_state_dict = quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
|
| 821 |
+
if cache and cache.dirpath.acts and quantizer_state_dict is not None:
|
| 822 |
+
logger.info(f"- Saving activation settings to {cache.path.acts}")
|
| 823 |
+
os.makedirs(cache.dirpath.acts, exist_ok=True)
|
| 824 |
+
torch.save(quantizer_state_dict, cache.path.acts)
|
| 825 |
+
load_from = cache.path.acts
|
| 826 |
+
if save_dirpath:
|
| 827 |
+
if not copy_on_save and load_from:
|
| 828 |
+
logger.info(f"- Linking activation quantizer settings to {save_path.acts}")
|
| 829 |
+
os.symlink(os.path.relpath(load_from, save_dirpath), save_path.acts)
|
| 830 |
+
else:
|
| 831 |
+
logger.info(f"- Saving activation quantizer settings to {save_path.acts}")
|
| 832 |
+
torch.save(quantizer_state_dict, save_path.acts)
|
| 833 |
+
del quantizer_state_dict
|
| 834 |
+
else:
|
| 835 |
+
logger.info("- No need to generate/load activation quantizer settings")
|
| 836 |
+
quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
|
| 837 |
+
tools.logging.Formatter.indent_dec()
|
| 838 |
+
del orig_state_dict
|
| 839 |
+
gc.collect()
|
| 840 |
+
torch.cuda.empty_cache()
|
| 841 |
+
return model
|
| 842 |
+
```
|
| 843 |
+
|
| 844 |
References
|
| 845 |
|
| 846 |
https://github.com/nunchaku-tech/nunchaku/commit/b99fb8be615bc98c6915bbe06a1e0092cbc074a5
|