File size: 4,767 Bytes
1181cf5
cbd07bb
 
 
3475425
cbd07bb
 
 
 
3475425
1181cf5
cbd07bb
 
 
1181cf5
cbd07bb
 
 
 
1181cf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbd07bb
 
 
eeb9cad
cbd07bb
1181cf5
 
 
 
 
cbd07bb
 
1181cf5
 
 
 
cbd07bb
1181cf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbd07bb
1181cf5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from typing import List, Optional
from diffusers.modular_pipelines import (
    InputParam,
    OutputParam,
    ModularPipelineBlocks,
    PipelineState,
)


class QuantizationConfigBlock(ModularPipelineBlocks):
    """Block that builds a BitsAndBytes quantization config for model loading.

    Produces a ``quantization_config`` dict mapping a single component name
    (e.g. ``"transformer"``) to a ``BitsAndBytesConfig`` instance, so that a
    downstream loading step can apply reduced-precision (4-bit or 8-bit)
    weights to that component.
    """

    @property
    def description(self) -> str:
        """One-line human-readable summary shown in pipeline UIs."""
        return "Creates a BitsAndBytes quantization config for loading models with reduced precision"

    @property
    def inputs(self) -> List[InputParam]:
        """Declare the user-configurable inputs for this block.

        The ``metadata={"mellon": ...}`` entries hint at the widget type for
        the Mellon UI; the commented ``"options"`` lists document the values
        each dropdown is expected to offer.
        """
        return [
            # Target component
            InputParam(
                "component",
                type_hint=str,
                default="transformer",
                description="Component name to apply quantization to",
                metadata={"mellon": "dropdown"}
            ),
            # Backend selection: 4-bit vs 8-bit BitsAndBytes quantization
            InputParam(
                "quant_type",
                type_hint=str,
                default="bnb_4bit",
                description="Quantization backend Type",
                metadata={"mellon": "dropdown"},  # "options": ["bnb_4bit", "bnb_8bit"]
            ),

            # ===== 4-bit options (used only when quant_type == "bnb_4bit") =====
            InputParam(
                "bnb_4bit_quant_type",
                type_hint=str,
                default="nf4",
                description="4-bit quantization type",
                metadata={"mellon": "dropdown"},  # "options": ["nf4", "fp4"]
            ),
            InputParam(
                "bnb_4bit_compute_dtype",
                type_hint=Optional[str],
                description="Compute dtype for 4-bit quantization",
                metadata={"mellon": "dropdown"},  # "options": ["", "float32", "float16", "bfloat16"]
            ),
            InputParam(
                "bnb_4bit_use_double_quant",
                type_hint=bool,
                default=False,
                description="Use nested quantization (quantize the quantization constants)",
                metadata={"mellon": "checkbox"}
            ),

            # ===== 8-bit options (used only when quant_type == "bnb_8bit") =====
            InputParam(
                "llm_int8_threshold",
                type_hint=float,
                default=6.0,
                description="Outlier threshold for 8-bit quantization (values above this use fp16)",
                metadata={"mellon": "slider"},
            ),
            InputParam(
                "llm_int8_has_fp16_weight",
                type_hint=bool,
                default=False,
                description="Keep weights in fp16 for 8-bit (useful for fine-tuning)",
                metadata={"mellon": "checkbox"},
            ),
            InputParam(
                "llm_int8_skip_modules",
                type_hint=Optional[List[str]],
                default=None,
                description="Module names to exclude from quantization (applies to both backends)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declare the single output produced by this block."""
        return [
            OutputParam(
                "quantization_config",
                type_hint=dict,
                description="Quantization config dict for load_components",
            ),
        ]

    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        """Build the BitsAndBytesConfig from the block's inputs.

        Reads the inputs declared above from ``state``, constructs the
        appropriate 4-bit or 8-bit config, and stores it on the state as
        ``quantization_config`` keyed by the target component name.

        Raises:
            ValueError: if ``quant_type`` is not one of the supported values.
        """
        # Imported lazily so the block can be declared without torch/diffusers
        # loaded; they are only needed when the block actually runs.
        import torch
        from diffusers import BitsAndBytesConfig

        block_state = self.get_block_state(state)

        def str_to_dtype(dtype_str):
            """Map a dtype name to a torch dtype; "" or unknown maps to None."""
            dtype_map = {
                "": None,
                "float32": torch.float32,
                "float16": torch.float16,
                "bfloat16": torch.bfloat16,
                "uint8": torch.uint8,
                "int8": torch.int8,
                "float64": torch.float64,
            }
            return dtype_map.get(dtype_str, None)

        if block_state.quant_type == "bnb_4bit":
            config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=block_state.bnb_4bit_quant_type,
                bnb_4bit_compute_dtype=str_to_dtype(block_state.bnb_4bit_compute_dtype),
                bnb_4bit_use_double_quant=block_state.bnb_4bit_use_double_quant,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        elif block_state.quant_type == "bnb_8bit":
            config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=block_state.llm_int8_threshold,
                llm_int8_has_fp16_weight=block_state.llm_int8_has_fp16_weight,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        else:
            # Fix: previously an unrecognized quant_type left `config` unbound,
            # producing a confusing NameError below instead of a clear error.
            raise ValueError(
                f"Unsupported quant_type {block_state.quant_type!r}; "
                "expected 'bnb_4bit' or 'bnb_8bit'."
            )

        # Output as dict: {"transformer": config}
        quantization_config = {block_state.component: config}

        block_state.quantization_config = quantization_config
        self.set_block_state(state, block_state)
        return pipeline, state