from typing import List, Optional
from diffusers.modular_pipelines import (
InputParam,
OutputParam,
ModularPipelineBlocks,
PipelineState,
)
class QuantizationConfigBlock(ModularPipelineBlocks):
    """Block to create a BitsAndBytes quantization config for model loading.

    Produces a ``{component_name: BitsAndBytesConfig}`` dict on the pipeline
    state (under ``quantization_config``) suitable for passing to
    ``load_components`` so the named component is loaded with reduced
    precision (4-bit or 8-bit) weights.
    """

    @property
    def description(self) -> str:
        """One-line human-readable summary of what this block does."""
        return "Creates a BitsAndBytes quantization config for loading models with reduced precision"

    @property
    def inputs(self) -> List[InputParam]:
        """Declare the user-configurable inputs consumed by ``__call__``."""
        return [
            # Target component whose weights will be quantized on load
            InputParam(
                "component",
                type_hint=str,
                default="transformer",
                description="Component name to apply quantization to",
                metadata={"mellon": "dropdown"}
            ),
            # Backend selection: 4-bit vs 8-bit BitsAndBytes
            InputParam(
                "quant_type",
                type_hint=str,
                default="bnb_4bit",
                description="Quantization backend Type",
                metadata={"mellon": "dropdown"},  # "options": ["bnb_4bit", "bnb_8bit"]
            ),
            # ===== 4-bit options (used only when quant_type == "bnb_4bit") =====
            InputParam(
                "bnb_4bit_quant_type",
                type_hint=str,
                default="nf4",
                description="4-bit quantization type",
                metadata={"mellon": "dropdown"},  # "options": ["nf4", "fp4"]
            ),
            InputParam(
                "bnb_4bit_compute_dtype",
                type_hint=Optional[str],
                description="Compute dtype for 4-bit quantization",
                metadata={"mellon": "dropdown"},  # "options": ["", "float32", "float16", "bfloat16"]
            ),
            InputParam(
                "bnb_4bit_use_double_quant",
                type_hint=bool,
                default=False,
                description="Use nested quantization (quantize the quantization constants)",
                metadata={"mellon": "checkbox"}
            ),
            # ===== 8-bit options (used only when quant_type == "bnb_8bit") =====
            InputParam(
                "llm_int8_threshold",
                type_hint=float,
                default=6.0,
                description="Outlier threshold for 8-bit quantization (values above this use fp16)",
                metadata={"mellon": "slider"},
            ),
            InputParam(
                "llm_int8_has_fp16_weight",
                type_hint=bool,
                default=False,
                description="Keep weights in fp16 for 8-bit (useful for fine-tuning)",
                metadata={"mellon": "checkbox"},
            ),
            # Modules to exclude from quantization (passed through for both modes)
            InputParam(
                "llm_int8_skip_modules",
                type_hint=Optional[List[str]],
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declare the outputs this block writes back onto the state."""
        return [
            OutputParam(
                "quantization_config",
                type_hint=dict,
                description="Quantization config dict for load_components",
            ),
        ]

    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        """Build the BitsAndBytes config and store it on the pipeline state.

        Args:
            pipeline: The modular pipeline this block runs inside.
            state: Current pipeline state carrying the inputs declared in
                :attr:`inputs`.

        Returns:
            ``(pipeline, state)`` with ``state.quantization_config`` set to a
            ``{component_name: BitsAndBytesConfig}`` dict.

        Raises:
            ValueError: If ``quant_type`` is not ``"bnb_4bit"`` or ``"bnb_8bit"``.
        """
        # Imported lazily so the block can be declared without torch/diffusers
        # extras loaded at module-import time.
        import torch
        from diffusers import BitsAndBytesConfig

        block_state = self.get_block_state(state)

        def str_to_dtype(dtype_str):
            # Map a UI dtype string to a torch dtype; "" or an unknown value
            # yields None, letting BitsAndBytesConfig use its own default.
            dtype_map = {
                "": None,
                "float32": torch.float32,
                "float16": torch.float16,
                "bfloat16": torch.bfloat16,
                "uint8": torch.uint8,
                "int8": torch.int8,
                "float64": torch.float64,
            }
            return dtype_map.get(dtype_str, None)

        if block_state.quant_type == "bnb_4bit":
            config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=block_state.bnb_4bit_quant_type,
                bnb_4bit_compute_dtype=str_to_dtype(block_state.bnb_4bit_compute_dtype),
                bnb_4bit_use_double_quant=block_state.bnb_4bit_use_double_quant,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        elif block_state.quant_type == "bnb_8bit":
            config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=block_state.llm_int8_threshold,
                llm_int8_has_fp16_weight=block_state.llm_int8_has_fp16_weight,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        else:
            # Previously an unrecognized quant_type fell through both branches
            # and surfaced as an opaque NameError on `config`; fail fast with
            # an actionable message instead.
            raise ValueError(
                f"Unknown quant_type {block_state.quant_type!r}; "
                "expected 'bnb_4bit' or 'bnb_8bit'."
            )

        # Output as dict keyed by component name, e.g. {"transformer": config}
        block_state.quantization_config = {block_state.component: config}
        self.set_block_state(state, block_state)
        return pipeline, state