# quant-block / block.py
# Author: YiYiXu (HF Staff) — "Update block.py"
# Commit: eeb9cad (verified)
from typing import List, Optional
from diffusers.modular_pipelines import (
InputParam,
OutputParam,
ModularPipelineBlocks,
PipelineState,
)
class QuantizationConfigBlock(ModularPipelineBlocks):
    """Block to create a BitsAndBytes quantization config for model loading.

    Produces a ``quantization_config`` dict mapping a single component name
    (e.g. ``"transformer"``) to a ``BitsAndBytesConfig`` instance, intended to
    be consumed by ``load_components``.
    """

    @property
    def description(self) -> str:
        return "Creates a BitsAndBytes quantization config for loading models with reduced precision"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            # Target component
            InputParam(
                "component",
                type_hint=str,
                default="transformer",
                description="Component name to apply quantization to",
                metadata={"mellon": "dropdown"},
            ),
            # Quantization backend selection
            InputParam(
                "quant_type",
                type_hint=str,
                default="bnb_4bit",
                description="Quantization backend Type",
                metadata={"mellon": "dropdown"},  # "options": ["bnb_4bit", "bnb_8bit"]
            ),
            # ===== 4-bit options =====
            InputParam(
                "bnb_4bit_quant_type",
                type_hint=str,
                default="nf4",
                description="4-bit quantization type",
                metadata={"mellon": "dropdown"},  # "options": ["nf4", "fp4"]
            ),
            InputParam(
                "bnb_4bit_compute_dtype",
                type_hint=Optional[str],
                description="Compute dtype for 4-bit quantization",
                metadata={"mellon": "dropdown"},  # "options": ["", "float32", "float16", "bfloat16"]
            ),
            InputParam(
                "bnb_4bit_use_double_quant",
                type_hint=bool,
                default=False,
                description="Use nested quantization (quantize the quantization constants)",
                metadata={"mellon": "checkbox"},
            ),
            # ===== 8-bit options =====
            InputParam(
                "llm_int8_threshold",
                type_hint=float,
                default=6.0,
                description="Outlier threshold for 8-bit quantization (values above this use fp16)",
                metadata={"mellon": "slider"},
            ),
            InputParam(
                "llm_int8_has_fp16_weight",
                type_hint=bool,
                default=False,
                description="Keep weights in fp16 for 8-bit (useful for fine-tuning)",
                metadata={"mellon": "checkbox"},
            ),
            InputParam(
                "llm_int8_skip_modules",
                type_hint=Optional[List[str]],
                description="Module names to exclude from quantization (kept in full precision)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "quantization_config",
                type_hint=dict,
                description="Quantization config dict for load_components",
            ),
        ]

    def __call__(self, pipeline, state: PipelineState):
        """Build a ``BitsAndBytesConfig`` from the block inputs.

        Reads ``quant_type`` and the matching 4-bit/8-bit options from the
        block state, wraps the resulting config in a one-entry dict keyed by
        the target component name, and stores it as ``quantization_config``.

        Raises:
            ValueError: if ``quant_type`` is not ``"bnb_4bit"`` or ``"bnb_8bit"``.

        Returns:
            ``(pipeline, state)`` — the modular-blocks calling convention.
        """
        # Lazy imports: torch / bitsandbytes are only needed when this block runs.
        import torch
        from diffusers import BitsAndBytesConfig

        block_state = self.get_block_state(state)

        # Map a UI dtype string to a torch dtype; "" or unknown values map to
        # None, which lets BitsAndBytesConfig fall back to its own default.
        def str_to_dtype(dtype_str):
            dtype_map = {
                "": None,
                "float32": torch.float32,
                "float16": torch.float16,
                "bfloat16": torch.bfloat16,
                "uint8": torch.uint8,
                "int8": torch.int8,
                "float64": torch.float64,
            }
            return dtype_map.get(dtype_str, None)

        if block_state.quant_type == "bnb_4bit":
            config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=block_state.bnb_4bit_quant_type,
                bnb_4bit_compute_dtype=str_to_dtype(block_state.bnb_4bit_compute_dtype),
                bnb_4bit_use_double_quant=block_state.bnb_4bit_use_double_quant,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        elif block_state.quant_type == "bnb_8bit":
            config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=block_state.llm_int8_threshold,
                llm_int8_has_fp16_weight=block_state.llm_int8_has_fp16_weight,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        else:
            # Previously an unknown quant_type fell through and raised an
            # opaque NameError below; fail fast with a clear message instead.
            raise ValueError(
                f"Unsupported quant_type: {block_state.quant_type!r}; "
                "expected 'bnb_4bit' or 'bnb_8bit'."
            )

        # Output as dict: {"transformer": config}
        quantization_config = {block_state.component: config}
        block_state.quantization_config = quantization_config
        self.set_block_state(state, block_state)
        return pipeline, state