|
|
from typing import List, Optional |
|
|
from diffusers.modular_pipelines import ( |
|
|
InputParam, |
|
|
OutputParam, |
|
|
ModularPipelineBlocks, |
|
|
PipelineState, |
|
|
) |
|
|
|
|
|
|
|
|
class QuantizationConfigBlock(ModularPipelineBlocks):
    """Block to create BitsAndBytes quantization config for model loading."""

    @property
    def description(self) -> str:
        """One-line human-readable summary of what this block does."""
        return "Creates a BitsAndBytes quantization config for loading models with reduced precision"

    @property
    def inputs(self) -> List[InputParam]:
        """Declare the user-facing inputs for this block.

        The ``mellon`` metadata hints at the UI widget used to render each
        input (dropdown / checkbox / slider). The ``bnb_4bit_*`` inputs apply
        only when ``quant_type == "bnb_4bit"``; the ``llm_int8_*`` inputs
        apply to ``"bnb_8bit"`` (except ``llm_int8_skip_modules``, which is
        forwarded in both branches of ``__call__``).
        """
        return [
            InputParam(
                "component",
                type_hint=str,
                default="transformer",
                description="Component name to apply quantization to",
                metadata={"mellon": "dropdown"},
            ),
            InputParam(
                "quant_type",
                type_hint=str,
                default="bnb_4bit",
                description="Quantization backend Type",
                metadata={"mellon": "dropdown"},
            ),
            InputParam(
                "bnb_4bit_quant_type",
                type_hint=str,
                default="nf4",
                description="4-bit quantization type",
                metadata={"mellon": "dropdown"},
            ),
            InputParam(
                "bnb_4bit_compute_dtype",
                type_hint=Optional[str],
                description="Compute dtype for 4-bit quantization",
                metadata={"mellon": "dropdown"},
            ),
            InputParam(
                "bnb_4bit_use_double_quant",
                type_hint=bool,
                default=False,
                description="Use nested quantization (quantize the quantization constants)",
                metadata={"mellon": "checkbox"},
            ),
            InputParam(
                "llm_int8_threshold",
                type_hint=float,
                default=6.0,
                description="Outlier threshold for 8-bit quantization (values above this use fp16)",
                metadata={"mellon": "slider"},
            ),
            InputParam(
                "llm_int8_has_fp16_weight",
                type_hint=bool,
                default=False,
                description="Keep weights in fp16 for 8-bit (useful for fine-tuning)",
                metadata={"mellon": "checkbox"},
            ),
            InputParam(
                "llm_int8_skip_modules",
                type_hint=Optional[List[str]],
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declare the outputs this block writes back into the pipeline state."""
        return [
            OutputParam(
                "quantization_config",
                type_hint=dict,
                description="Quantization config dict for load_components",
            ),
        ]

    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        """Build a ``BitsAndBytesConfig`` from the block inputs.

        Reads the quantization inputs from *state*, constructs the matching
        4-bit or 8-bit ``BitsAndBytesConfig``, and stores it in the state as
        ``quantization_config`` — a dict keyed by the component name so it can
        be passed to ``load_components``.

        Raises:
            ValueError: If ``quant_type`` is neither ``"bnb_4bit"`` nor
                ``"bnb_8bit"``.

        NOTE(review): despite the ``-> PipelineState`` annotation this returns
        the ``(pipeline, state)`` tuple, matching the other blocks' convention
        — confirm against the ModularPipelineBlocks contract.
        """
        # Imported lazily so merely defining the block does not require
        # torch / bitsandbytes to be installed.
        import torch
        from diffusers import BitsAndBytesConfig

        block_state = self.get_block_state(state)

        def str_to_dtype(dtype_str):
            """Map a UI dtype string to a torch dtype; '' or unknown -> None."""
            dtype_map = {
                "": None,
                "float32": torch.float32,
                "float16": torch.float16,
                "bfloat16": torch.bfloat16,
                "uint8": torch.uint8,
                "int8": torch.int8,
                "float64": torch.float64,
            }
            return dtype_map.get(dtype_str, None)

        if block_state.quant_type == "bnb_4bit":
            config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=block_state.bnb_4bit_quant_type,
                bnb_4bit_compute_dtype=str_to_dtype(block_state.bnb_4bit_compute_dtype),
                bnb_4bit_use_double_quant=block_state.bnb_4bit_use_double_quant,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        elif block_state.quant_type == "bnb_8bit":
            config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=block_state.llm_int8_threshold,
                llm_int8_has_fp16_weight=block_state.llm_int8_has_fp16_weight,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        else:
            # BUG FIX: previously an unrecognized quant_type fell through and
            # left `config` unbound, crashing below with an opaque NameError.
            raise ValueError(
                f"Unsupported quant_type {block_state.quant_type!r}; "
                "expected 'bnb_4bit' or 'bnb_8bit'."
            )

        # Keyed by component name so load_components can route the config.
        quantization_config = {block_state.component: config}

        block_state.quantization_config = quantization_config

        self.set_block_state(state, block_state)

        return pipeline, state