Update block.py
Browse files
block.py
CHANGED
|
@@ -1,73 +1,130 @@
|
|
| 1 |
-
from typing import List
|
| 2 |
-
|
| 3 |
from diffusers.modular_pipelines import (
|
| 4 |
-
ComponentSpec,
|
| 5 |
InputParam,
|
| 6 |
-
ModularPipelineBlocks,
|
| 7 |
OutputParam,
|
|
|
|
| 8 |
PipelineState,
|
| 9 |
)
|
| 10 |
|
| 11 |
|
| 12 |
-
class
|
| 13 |
-
"""
|
| 14 |
-
A custom block for [describe what your block does].
|
| 15 |
-
|
| 16 |
-
Replace this with your implementation.
|
| 17 |
-
"""
|
| 18 |
|
| 19 |
@property
|
| 20 |
def description(self) -> str:
|
| 21 |
-
"
|
| 22 |
-
return "A template custom block - replace with your description"
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
@property
|
| 26 |
-
def expected_components(self) -> List[ComponentSpec]:
|
| 27 |
-
"""Define model components your block needs (e.g., transformers, VAEs)."""
|
| 28 |
-
return [
|
| 29 |
-
# Example:
|
| 30 |
-
# ComponentSpec(
|
| 31 |
-
# name="model",
|
| 32 |
-
# type_hint=SomeModelClass,
|
| 33 |
-
# repo="organization/model-name",
|
| 34 |
-
# ),
|
| 35 |
-
]
|
| 36 |
|
| 37 |
@property
|
| 38 |
def inputs(self) -> List[InputParam]:
|
| 39 |
-
"""Define input parameters for your block."""
|
| 40 |
return [
|
| 41 |
-
#
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
]
|
| 50 |
|
| 51 |
@property
|
| 52 |
-
def
|
| 53 |
-
"""Define output parameters for your block."""
|
| 54 |
return [
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
# metadata={"mellon": "text"}, # For Mellon UI
|
| 61 |
-
# ),
|
| 62 |
]
|
| 63 |
|
| 64 |
-
def __call__(self,
|
| 65 |
-
|
|
|
|
|
|
|
| 66 |
block_state = self.get_block_state(state)
|
| 67 |
-
|
| 68 |
-
#
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self.set_block_state(state, block_state)
|
| 73 |
-
return
|
|
|
|
| 1 |
+
from typing import List, Optional
|
|
|
|
| 2 |
from diffusers.modular_pipelines import (
|
|
|
|
| 3 |
InputParam,
|
|
|
|
| 4 |
OutputParam,
|
| 5 |
+
PipelineBlock,
|
| 6 |
PipelineState,
|
| 7 |
)
|
| 8 |
|
| 9 |
|
| 10 |
+
class QuantizationConfigBlock(PipelineBlock):
|
| 11 |
+
"""Block to create BitsAndBytes quantization config for model loading."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
@property
|
| 14 |
def description(self) -> str:
|
| 15 |
+
return "Creates a BitsAndBytes quantization config for loading models with reduced precision"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
@property
|
| 18 |
def inputs(self) -> List[InputParam]:
|
|
|
|
| 19 |
return [
|
| 20 |
+
# Target component
|
| 21 |
+
InputParam(
|
| 22 |
+
"component",
|
| 23 |
+
type_hint=str,
|
| 24 |
+
default="transformer",
|
| 25 |
+
description="Component name to apply quantization to",
|
| 26 |
+
metadata={"mellon": "dropdown"}
|
| 27 |
+
),
|
| 28 |
+
# Bits selection
|
| 29 |
+
InputParam(
|
| 30 |
+
"quant_type",
|
| 31 |
+
type_hint=str,
|
| 32 |
+
default="bnb_4bit",
|
| 33 |
+
description="Quantization backend Type",
|
| 34 |
+
metadata={"mellon": "dropdown"}, # "options": ["bnb_4bit", "bnb_8bit"]
|
| 35 |
+
),
|
| 36 |
+
|
| 37 |
+
# ===== 4-bit options =====
|
| 38 |
+
InputParam(
|
| 39 |
+
"bnb_4bit_quant_type",
|
| 40 |
+
type_hint=str,
|
| 41 |
+
default="nf4",
|
| 42 |
+
description="4-bit quantization type",
|
| 43 |
+
metadata={"mellon": "dropdown"}, # "options": ["nf4", "fp4"]
|
| 44 |
+
),
|
| 45 |
+
InputParam(
|
| 46 |
+
"bnb_4bit_compute_dtype",
|
| 47 |
+
type_hint=Optional[str],
|
| 48 |
+
description="Compute dtype for 4-bit quantization",
|
| 49 |
+
metadata={"mellon": "dropdown"}, # "options": ["", "float32", "float16", "bfloat16"]
|
| 50 |
+
),
|
| 51 |
+
InputParam(
|
| 52 |
+
"bnb_4bit_use_double_quant",
|
| 53 |
+
type_hint=bool,
|
| 54 |
+
default=False,
|
| 55 |
+
description="Use nested quantization (quantize the quantization constants)",
|
| 56 |
+
metadata={"mellon": "checkbox"}
|
| 57 |
+
),
|
| 58 |
+
|
| 59 |
+
# ===== 8-bit options =====
|
| 60 |
+
InputParam(
|
| 61 |
+
"llm_int8_threshold",
|
| 62 |
+
type_hint=float,
|
| 63 |
+
default=6.0,
|
| 64 |
+
description="Outlier threshold for 8-bit quantization (values above this use fp16)",
|
| 65 |
+
metadata={"mellon": "slider"},
|
| 66 |
+
),
|
| 67 |
+
InputParam(
|
| 68 |
+
"llm_int8_has_fp16_weight",
|
| 69 |
+
type_hint=bool,
|
| 70 |
+
default=False,
|
| 71 |
+
description="Keep weights in fp16 for 8-bit (useful for fine-tuning)",
|
| 72 |
+
metadata={"mellon": "checkbox"},
|
| 73 |
+
),
|
| 74 |
+
InputParam(
|
| 75 |
+
"llm_int8_skip_modules",
|
| 76 |
+
type_hint=Optional[List[str]],
|
| 77 |
+
),
|
| 78 |
]
|
| 79 |
|
| 80 |
@property
|
| 81 |
+
def intermediates_outputs(self) -> List[OutputParam]:
|
|
|
|
| 82 |
return [
|
| 83 |
+
OutputParam(
|
| 84 |
+
"quantization_config",
|
| 85 |
+
type_hint=dict,
|
| 86 |
+
description="Quantization config dict for load_components",
|
| 87 |
+
),
|
|
|
|
|
|
|
| 88 |
]
|
| 89 |
|
| 90 |
+
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
| 91 |
+
import torch
|
| 92 |
+
from diffusers import BitsAndBytesConfig
|
| 93 |
+
|
| 94 |
block_state = self.get_block_state(state)
|
| 95 |
+
|
| 96 |
+
# Map string dtype to torch dtype
|
| 97 |
+
def str_to_dtype(dtype_str):
|
| 98 |
+
dtype_map = {
|
| 99 |
+
"": None,
|
| 100 |
+
"float32": torch.float32,
|
| 101 |
+
"float16": torch.float16,
|
| 102 |
+
"bfloat16": torch.bfloat16,
|
| 103 |
+
"uint8": torch.uint8,
|
| 104 |
+
"int8": torch.int8,
|
| 105 |
+
"float64": torch.float64,
|
| 106 |
+
}
|
| 107 |
+
return dtype_map.get(dtype_str, None)
|
| 108 |
+
|
| 109 |
+
if block_state.quant_type == "bnb_4bit":
|
| 110 |
+
config = BitsAndBytesConfig(
|
| 111 |
+
load_in_4bit=True,
|
| 112 |
+
bnb_4bit_quant_type=block_state.bnb_4bit_quant_type,
|
| 113 |
+
bnb_4bit_compute_dtype=str_to_dtype(block_state.bnb_4bit_compute_dtype),
|
| 114 |
+
bnb_4bit_use_double_quant=block_state.bnb_4bit_use_double_quant,
|
| 115 |
+
llm_int8_skip_modules=block_state.llm_int8_skip_modules,
|
| 116 |
+
)
|
| 117 |
+
elif block_state.quant_type == "bnb_8bit":
|
| 118 |
+
config = BitsAndBytesConfig(
|
| 119 |
+
load_in_8bit=True,
|
| 120 |
+
llm_int8_threshold=block_state.llm_int8_threshold,
|
| 121 |
+
llm_int8_has_fp16_weight=block_state.llm_int8_has_fp16_weight,
|
| 122 |
+
llm_int8_skip_modules=block_state.llm_int8_skip_modules,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Output as dict: {"transformer": config}
|
| 126 |
+
quantization_config = {block_state.component: config}
|
| 127 |
+
|
| 128 |
+
block_state.quantization_config = quantization_config
|
| 129 |
self.set_block_state(state, block_state)
|
| 130 |
+
return pipeline, state
|