from typing import List, Optional

from diffusers.modular_pipelines import (
    InputParam,
    OutputParam,
    ModularPipelineBlocks,
    PipelineState,
)


class QuantizationConfigBlock(ModularPipelineBlocks):
    """Block to create BitsAndBytes quantization config for model loading.

    Exposes UI-facing input parameters (4-bit and 8-bit BitsAndBytes options)
    and, when called, produces a ``{component_name: BitsAndBytesConfig}`` dict
    on the pipeline state under ``quantization_config``.
    """

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        return "Creates a BitsAndBytes quantization config for loading models with reduced precision"

    @property
    def inputs(self) -> List[InputParam]:
        """Declare the block's input parameters (with Mellon UI metadata)."""
        return [
            # Target component to quantize, e.g. "transformer".
            InputParam(
                "component",
                type_hint=str,
                default="transformer",
                description="Component name to apply quantization to",
                metadata={"mellon": "dropdown"},
            ),
            # Backend selection: 4-bit vs 8-bit BitsAndBytes quantization.
            InputParam(
                "quant_type",
                type_hint=str,
                default="bnb_4bit",
                description="Quantization backend Type",
                metadata={"mellon": "dropdown"},  # "options": ["bnb_4bit", "bnb_8bit"]
            ),
            # ===== 4-bit options =====
            InputParam(
                "bnb_4bit_quant_type",
                type_hint=str,
                default="nf4",
                description="4-bit quantization type",
                metadata={"mellon": "dropdown"},  # "options": ["nf4", "fp4"]
            ),
            InputParam(
                "bnb_4bit_compute_dtype",
                type_hint=Optional[str],
                description="Compute dtype for 4-bit quantization",
                metadata={"mellon": "dropdown"},  # "options": ["", "float32", "float16", "bfloat16"]
            ),
            InputParam(
                "bnb_4bit_use_double_quant",
                type_hint=bool,
                default=False,
                description="Use nested quantization (quantize the quantization constants)",
                metadata={"mellon": "checkbox"},
            ),
            # ===== 8-bit options =====
            InputParam(
                "llm_int8_threshold",
                type_hint=float,
                default=6.0,
                description="Outlier threshold for 8-bit quantization (values above this use fp16)",
                metadata={"mellon": "slider"},
            ),
            InputParam(
                "llm_int8_has_fp16_weight",
                type_hint=bool,
                default=False,
                description="Keep weights in fp16 for 8-bit (useful for fine-tuning)",
                metadata={"mellon": "checkbox"},
            ),
            # Fix: was missing default/description, inconsistent with siblings.
            InputParam(
                "llm_int8_skip_modules",
                type_hint=Optional[List[str]],
                default=None,
                description="Module names to skip during quantization (kept in full precision)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declare the block's outputs."""
        return [
            OutputParam(
                "quantization_config",
                type_hint=dict,
                description="Quantization config dict for load_components",
            ),
        ]

    # NOTE: returns a (pipeline, state) tuple (modular-pipeline convention);
    # the original ``-> PipelineState`` annotation was incorrect and is dropped.
    def __call__(self, pipeline, state: PipelineState):
        """Build a BitsAndBytesConfig from the block state and store it.

        Args:
            pipeline: The modular pipeline instance (passed through unchanged).
            state: Pipeline state carrying the input parameter values.

        Returns:
            The ``(pipeline, state)`` tuple, with ``quantization_config``
            (a ``{component: BitsAndBytesConfig}`` dict) written to the state.

        Raises:
            ValueError: If ``quant_type`` is not ``"bnb_4bit"`` or ``"bnb_8bit"``.
        """
        # Imported lazily so the block can be declared without torch/bnb installed.
        import torch
        from diffusers import BitsAndBytesConfig

        block_state = self.get_block_state(state)

        # Map UI string dtype names to torch dtypes; unknown/empty -> None
        # (lets BitsAndBytesConfig fall back to its own default).
        dtype_map = {
            "": None,
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "uint8": torch.uint8,
            "int8": torch.int8,
            "float64": torch.float64,
        }

        if block_state.quant_type == "bnb_4bit":
            config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=block_state.bnb_4bit_quant_type,
                bnb_4bit_compute_dtype=dtype_map.get(block_state.bnb_4bit_compute_dtype),
                bnb_4bit_use_double_quant=block_state.bnb_4bit_use_double_quant,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        elif block_state.quant_type == "bnb_8bit":
            config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=block_state.llm_int8_threshold,
                llm_int8_has_fp16_weight=block_state.llm_int8_has_fp16_weight,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        else:
            # Fix: original fell through with ``config`` unbound, raising an
            # opaque NameError for any unrecognized quant_type.
            raise ValueError(
                f"Unsupported quant_type {block_state.quant_type!r}; "
                "expected 'bnb_4bit' or 'bnb_8bit'."
            )

        # Output shaped as {"<component>": config}, e.g. {"transformer": config}.
        block_state.quantization_config = {block_state.component: config}

        self.set_block_state(state, block_state)
        return pipeline, state