"""Model export functionality for SDXL Model Merger."""

import torch
from safetensors.torch import save_file
from . import config
from .config import SCRIPT_DIR
from .gpu_decorator import GPU
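# GPU is this project's decorator wrapper; on a ZeroGPU Space it presumably
# delegates to spaces.GPU (an assumption based on usage below), so
# @GPU(duration=180) requests a GPU for up to 180 seconds per call.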

def _quantize_model(model, qtype: str):
    """Apply torchao quantization to a model in place using quantize_."""
    from torchao.quantization import quantize_

    if qtype == "int8":
        from torchao.quantization import Int8WeightOnlyConfig
        print(" ⚙️ Quantizing with int8_weight_only...")
        # Use a local name that does not shadow the imported `config` module.
        quant_config = Int8WeightOnlyConfig()
        quantize_(model, quant_config)
    elif qtype == "int4":
        from torchao.quantization import Int4WeightOnlyConfig
        print(" ⚙️ Quantizing with int4_weight_only (group_size=32)...")
        quant_config = Int4WeightOnlyConfig(group_size=32)
        quantize_(model, quant_config)
    elif qtype == "float8":
        from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
        print(" ⚙️ Quantizing with float8_dynamic_activation_float8_weight...")
        quant_config = Float8DynamicActivationFloat8WeightConfig()
        quantize_(model, quant_config)
    else:
        raise ValueError(f"Unsupported qtype: {qtype}. Must be one of: int8, int4, float8")
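
# Note: quantize_ mutates the module in place and returns nothing, so callers
# pass the live component, e.g. _quantize_model(pipe.unet, "int8"), and then
# read the quantized weights straight out of its state_dict.
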
@GPU(duration=180)
def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
    """GPU-decorated helper that extracts weights and saves the model."""
    if include_lora:
        # Weights fused via fuse_lora() remain baked into the base modules;
        # unloading just drops the adapter layers so they are not exported.
        try:
            pipe.unload_lora_weights()
        except Exception as e:
            print(f" ℹ️ Could not unload LoRAs: {e}")

    # Quantize components in place before extracting state dicts. torchao
    # quantized tensors cannot be saved with safetensors, so torch.save is
    # used later. They are not dequantized: keeping the quantized format
    # keeps the file small.
    if quantize and qtype != "none":
        _quantize_model(pipe.unet, qtype)
    merged_state_dict = {}

    def _collect(prefix: str, module) -> None:
        """Copy a component's weights into the merged dict under a key prefix."""
        for k, v in module.state_dict().items():
            # Quantized tensors (torchao subclasses expose .dequantize) are
            # kept as-is for the smaller file size; regular tensors are
            # stored contiguously in fp16.
            if hasattr(v, "dequantize"):
                merged_state_dict[f"{prefix}.{k}"] = v
            else:
                merged_state_dict[f"{prefix}.{k}"] = v.contiguous().half()

    _collect("unet", pipe.unet)
    if pipe.text_encoder is not None:
        _collect("text_encoder", pipe.text_encoder)
    if pipe.text_encoder_2 is not None:
        _collect("text_encoder_2", pipe.text_encoder_2)
    if pipe.vae is not None:
        # "first_stage_model" matches the VAE prefix used in original SD checkpoints.
        _collect("first_stage_model", pipe.vae)
    # Save the model. torchao quantized tensors are not compatible with
    # safetensors, so quantized checkpoints always go through torch.save,
    # which preserves the quantization format.
    if quantize and qtype != "none":
        out_path = SCRIPT_DIR / f"merged_{qtype}_checkpoint.pt"
        torch.save(merged_state_dict, str(out_path))
    elif save_format == "bin":
        out_path = SCRIPT_DIR / "merged_checkpoint.bin"
        torch.save(merged_state_dict, str(out_path))
    else:
        out_path = SCRIPT_DIR / "merged_checkpoint.safetensors"
        save_file(merged_state_dict, str(out_path))
    return out_path
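
# Note (sketch): loading a quantized .pt checkpoint back requires unpickling
# the torchao tensor subclasses. On recent PyTorch versions, where torch.load
# defaults to weights_only=True, something like the following may be needed
# (assumption; verify against your torch/torchao versions):
#
#     state = torch.load("merged_int8_checkpoint.pt", map_location="cpu",
#                        weights_only=False)
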
def export_merged_model(
    include_lora: bool,
    quantize: bool,
    qtype: str,
    save_format: str = "safetensors",
):
    """
    Export the merged pipeline model with optional LoRA baking and quantization.

    Args:
        include_lora: Whether to include fused LoRAs in the export
        quantize: Whether to apply quantization
        qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
        save_format: Output format - 'safetensors' or 'bin'

    Returns:
        Tuple of (output path as str or None, status message)
    """
    # Fetch the pipeline at call time to avoid the stale import-by-value problem.
    pipe = config.get_pipe()
    if pipe is None:
        return None, "⚠️ Please load a pipeline first."

    try:
        # Validate the quantization type before doing any GPU work.
        valid_qtypes = ("none", "int8", "int4", "float8")
        if qtype not in valid_qtypes:
            return None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"

        out_path = _extract_and_save(pipe, include_lora, quantize, qtype, save_format)
        size_gb = out_path.stat().st_size / 1024**3
        if quantize and qtype != "none":
            msg = f"✅ Quantized checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
        else:
            msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
        return str(out_path), msg
    except ImportError as e:
        return None, f"❌ Missing dependency: {e}"
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        return None, f"❌ Export failed: {e}"

def get_export_status() -> str:
    """Get current export capability status."""
    try:
        # Import check only; the names themselves are unused here.
        from torchao.quantization import (  # noqa: F401
            quantize_,
            Int4WeightOnlyConfig,
            Int8WeightOnlyConfig,
            Float8DynamicActivationFloat8WeightConfig,
        )
        return "✅ torchao available for quantization"
    except ImportError:
        return "ℹ️ Install torchao for quantization support: pip install torchao"