"""Model export functionality for SDXL Model Merger."""

import torch
from safetensors.torch import save_file

from . import config
from .config import SCRIPT_DIR
from .gpu_decorator import GPU


def _quantize_model(model, qtype: str):
    """Apply torchao quantization to a model in-place using quantize_."""
    from torchao.quantization import quantize_

    # Note: the config objects are passed inline so the local name does not
    # shadow the imported `config` module.
    if qtype == "int8":
        from torchao.quantization import Int8WeightOnlyConfig

        print("  ⚙️ Quantizing with int8_weight_only...")
        quantize_(model, Int8WeightOnlyConfig())

    elif qtype == "int4":
        from torchao.quantization import Int4WeightOnlyConfig

        print("  ⚙️ Quantizing with int4_weight_only (group_size=32)...")
        quantize_(model, Int4WeightOnlyConfig(group_size=32))

    elif qtype == "float8":
        from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

        print("  ⚙️ Quantizing with float8_dynamic_activation_float8_weight...")
        quantize_(model, Float8DynamicActivationFloat8WeightConfig())

    else:
        raise ValueError(f"Unsupported qtype: {qtype}. Must be one of: int8, int4, float8")


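# Assumption: GPU is a ZeroGPU-style decorator from gpu_decorator that
# provisions a GPU for up to `duration` seconds while the function runs.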
@GPU(duration=180)
def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
    """GPU-decorated helper that extracts weights and saves the model."""
    if include_lora:
        try:
            pipe.unload_lora_weights()
        except Exception as e:
            print(f"  ℹ️ Could not unload LoRAs: {e}")

    # Quantize the UNet in-place before extracting state dicts; the text
    # encoders and VAE are left in fp16.
    if quantize and qtype != "none":
        _quantize_model(pipe.unet, qtype)
        # torchao quantized tensors cannot be saved with safetensors, so the
        # save step below falls back to torch.save. Don't dequantize: the
        # quantized format is what keeps the file small.

    merged_state_dict = {}

    def _collect(component, prefix):
        """Copy a component's weights into merged_state_dict under `prefix`."""
        for k, v in component.state_dict().items():
            # torchao wraps quantized weights in tensor subclasses that expose
            # dequantize(); keep those as-is, cast plain tensors to fp16.
            if hasattr(v, "dequantize"):
                merged_state_dict[f"{prefix}.{k}"] = v
            else:
                merged_state_dict[f"{prefix}.{k}"] = v.contiguous().half()

    _collect(pipe.unet, "unet")
    if pipe.text_encoder is not None:
        _collect(pipe.text_encoder, "text_encoder")
    if pipe.text_encoder_2 is not None:
        _collect(pipe.text_encoder_2, "text_encoder_2")
    # The VAE is stored under the "first_stage_model" prefix used by SD checkpoints
    if pipe.vae is not None:
        _collect(pipe.vae, "first_stage_model")

    # Save the model. torchao quantized tensors are not compatible with
    # safetensors, so quantized exports always go through torch.save, which
    # preserves the quantization format.
    if quantize and qtype != "none":
        out_path = SCRIPT_DIR / f"merged_{qtype}_checkpoint.pt"
        torch.save(merged_state_dict, str(out_path))
    elif save_format == "bin":
        out_path = SCRIPT_DIR / "merged_checkpoint.bin"
        torch.save(merged_state_dict, str(out_path))
    else:
        out_path = SCRIPT_DIR / "merged_checkpoint.safetensors"
        save_file(merged_state_dict, str(out_path))

    return out_path
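
# Loading sketch for a quantized export (a sketch, not app code; the filename
# matches the pattern produced above, and weights_only=False is likely needed
# on recent torch versions because torchao stores tensor subclasses):
#
#     state = torch.load("merged_int8_checkpoint.pt", map_location="cpu",
#                        weights_only=False)
#     unet_sd = {k[len("unet."):]: v for k, v in state.items()
#                if k.startswith("unet.")}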


def export_merged_model(
    include_lora: bool,
    quantize: bool,
    qtype: str,
    save_format: str = "safetensors",
):
    """
    Export the merged pipeline model with optional LoRA baking and quantization.

    Args:
        include_lora: Whether to include fused LoRAs in export
        quantize: Whether to apply quantization
        qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
        save_format: Output format - 'safetensors' or 'bin'

    Returns:
        Tuple of (output_path or None, status message)
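
    Example (illustrative; assumes a pipeline was already loaded via config):
        out_path, msg = export_merged_model(
            include_lora=True, quantize=True, qtype="int8"
        )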
    """
    # Fetch the pipeline at call time; importing it directly would bind a
    # stale value from import time.
    pipe = config.get_pipe()

    if pipe is None:
        return None, "⚠️ Please load a pipeline first."

    try:
        # Validate quantization type
        valid_qtypes = ("none", "int8", "int4", "float8")
        if qtype not in valid_qtypes:
            return None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"

        out_path = _extract_and_save(pipe, include_lora, quantize, qtype, save_format)

        size_gb = out_path.stat().st_size / 1024**3

        if quantize and qtype != "none":
            msg = f"✅ Quantized checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
        else:
            msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"

        return str(out_path), msg

    except ImportError as e:
        return None, f"❌ Missing dependency: {str(e)}"
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        return None, f"❌ Export failed: {str(e)}"


def get_export_status() -> str:
    """Get current export capability status."""
    try:
        # Imported only to verify that the quantization configs are available
        from torchao.quantization import (  # noqa: F401
            quantize_,
            Int4WeightOnlyConfig,
            Int8WeightOnlyConfig,
            Float8DynamicActivationFloat8WeightConfig,
        )
        return "✅ torchao available for quantization"
    except ImportError:
        return "ℹ️ Install torchao for quantization support: pip install torchao"