YiYiXu HF Staff committed on
Commit
1181cf5
·
verified ·
1 Parent(s): cbd07bb

Update block.py

Browse files
Files changed (1) hide show
  1. block.py +108 -51
block.py CHANGED
@@ -1,73 +1,130 @@
1
- from typing import List
2
-
3
  from diffusers.modular_pipelines import (
4
- ComponentSpec,
5
  InputParam,
6
- ModularPipelineBlocks,
7
  OutputParam,
 
8
  PipelineState,
9
  )
10
 
11
 
12
- class MyCustomBlock(ModularPipelineBlocks):
13
- """
14
- A custom block for [describe what your block does].
15
-
16
- Replace this with your implementation.
17
- """
18
 
19
  @property
20
  def description(self) -> str:
21
- """Description of the block."""
22
- return "A template custom block - replace with your description"
23
-
24
-
25
- @property
26
- def expected_components(self) -> List[ComponentSpec]:
27
- """Define model components your block needs (e.g., transformers, VAEs)."""
28
- return [
29
- # Example:
30
- # ComponentSpec(
31
- # name="model",
32
- # type_hint=SomeModelClass,
33
- # repo="organization/model-name",
34
- # ),
35
- ]
36
 
37
  @property
38
  def inputs(self) -> List[InputParam]:
39
- """Define input parameters for your block."""
40
  return [
41
- # Example:
42
- # InputParam(
43
- # "prompt",
44
- # type_hint=str,
45
- # required=True,
46
- # description="Input prompt",
47
- # metadata={"mellon": "textbox"}, # For Mellon UI
48
- # ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  ]
50
 
51
  @property
52
- def intermediate_outputs(self) -> List[OutputParam]:
53
- """Define output parameters for your block."""
54
  return [
55
- # Example:
56
- # OutputParam(
57
- # "result",
58
- # type_hint=str,
59
- # description="Output result",
60
- # metadata={"mellon": "text"}, # For Mellon UI
61
- # ),
62
  ]
63
 
64
- def __call__(self, components, state: PipelineState) -> PipelineState:
65
- """Execute your block logic."""
 
 
66
  block_state = self.get_block_state(state)
67
-
68
- # Your implementation here
69
- # Access inputs via block_state.<input_name>
70
- # Set outputs via block_state.<output_name> = value
71
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  self.set_block_state(state, block_state)
73
- return components, state
 
1
+ from typing import List, Optional
 
2
  from diffusers.modular_pipelines import (
 
3
  InputParam,
 
4
  OutputParam,
5
+ PipelineBlock,
6
  PipelineState,
7
  )
8
 
9
 
10
class QuantizationConfigBlock(PipelineBlock):
    """Block to create BitsAndBytes quantization config for model loading.

    The block reads the quantization settings declared in ``inputs`` from the
    block state, builds a ``diffusers.BitsAndBytesConfig`` for the selected
    backend (``"bnb_4bit"`` or ``"bnb_8bit"``) and stores
    ``{component: config}`` under the ``quantization_config`` output.
    """

    # Backends this block knows how to build a config for.
    _SUPPORTED_QUANT_TYPES = ("bnb_4bit", "bnb_8bit")

    @property
    def description(self) -> str:
        return "Creates a BitsAndBytes quantization config for loading models with reduced precision"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            # Target component
            InputParam(
                "component",
                type_hint=str,
                default="transformer",
                description="Component name to apply quantization to",
                metadata={"mellon": "dropdown"},
            ),
            # Backend selection (4-bit vs 8-bit)
            InputParam(
                "quant_type",
                type_hint=str,
                default="bnb_4bit",
                description="Quantization backend Type",
                metadata={"mellon": "dropdown"},  # "options": ["bnb_4bit", "bnb_8bit"]
            ),

            # ===== 4-bit options =====
            InputParam(
                "bnb_4bit_quant_type",
                type_hint=str,
                default="nf4",
                description="4-bit quantization type",
                metadata={"mellon": "dropdown"},  # "options": ["nf4", "fp4"]
            ),
            InputParam(
                "bnb_4bit_compute_dtype",
                type_hint=Optional[str],
                description="Compute dtype for 4-bit quantization",
                metadata={"mellon": "dropdown"},  # "options": ["", "float32", "float16", "bfloat16"]
            ),
            InputParam(
                "bnb_4bit_use_double_quant",
                type_hint=bool,
                default=False,
                description="Use nested quantization (quantize the quantization constants)",
                metadata={"mellon": "checkbox"},
            ),

            # ===== 8-bit options =====
            InputParam(
                "llm_int8_threshold",
                type_hint=float,
                default=6.0,
                description="Outlier threshold for 8-bit quantization (values above this use fp16)",
                metadata={"mellon": "slider"},
            ),
            InputParam(
                "llm_int8_has_fp16_weight",
                type_hint=bool,
                default=False,
                description="Keep weights in fp16 for 8-bit (useful for fine-tuning)",
                metadata={"mellon": "checkbox"},
            ),
            # Forwarded to BitsAndBytesConfig in both the 4-bit and 8-bit branches.
            InputParam(
                "llm_int8_skip_modules",
                type_hint=Optional[List[str]],
            ),
        ]

    # NOTE(review): this property is spelled ``intermediates_outputs`` (with an
    # "s"); confirm that this is the name the installed diffusers
    # ``PipelineBlock`` base class actually reads.
    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "quantization_config",
                type_hint=dict,
                description="Quantization config dict for load_components",
            ),
        ]

    @staticmethod
    def _resolve_dtype(dtype_str):
        """Map a dtype name to a ``torch.dtype``.

        Args:
            dtype_str: ``None``, ``""``, a ``torch.dtype``, or one of the
                supported dtype names (e.g. ``"bfloat16"``).

        Returns:
            ``None`` for ``None``/``""`` (no explicit compute dtype), the
            value itself if it is already a ``torch.dtype``, otherwise the
            ``torch.dtype`` matching the name.

        Raises:
            ValueError: if ``dtype_str`` is a name that is not supported.
        """
        # No selection means "do not pass an explicit compute dtype".
        if dtype_str is None or dtype_str == "":
            return None

        import torch

        if isinstance(dtype_str, torch.dtype):
            return dtype_str

        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "uint8": torch.uint8,
            "int8": torch.int8,
            "float64": torch.float64,
        }
        if dtype_str not in dtype_map:
            # Fail loudly instead of silently dropping a mistyped dtype.
            raise ValueError(
                f"Unsupported bnb_4bit_compute_dtype {dtype_str!r}; "
                f"expected one of {sorted(dtype_map)} or an empty value."
            )
        return dtype_map[dtype_str]

    # NOTE(review): the return annotation says ``PipelineState`` but the method
    # returns the ``(pipeline, state)`` tuple; annotation kept as in the
    # original signature.
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        """Build the quantization config and store it in the block state.

        Args:
            pipeline: passed through unchanged and returned to the caller.
            state: pipeline state the block inputs are read from and the
                ``quantization_config`` output is written to.

        Returns:
            The ``(pipeline, state)`` pair.

        Raises:
            ValueError: if ``quant_type`` is neither ``"bnb_4bit"`` nor
                ``"bnb_8bit"``, or if the compute dtype name is unsupported.
        """
        # Imported lazily so the module can be imported without these installed.
        from diffusers import BitsAndBytesConfig

        block_state = self.get_block_state(state)
        quant_type = block_state.quant_type

        if quant_type == "bnb_4bit":
            config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=block_state.bnb_4bit_quant_type,
                bnb_4bit_compute_dtype=self._resolve_dtype(block_state.bnb_4bit_compute_dtype),
                bnb_4bit_use_double_quant=block_state.bnb_4bit_use_double_quant,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        elif quant_type == "bnb_8bit":
            config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=block_state.llm_int8_threshold,
                llm_int8_has_fp16_weight=block_state.llm_int8_has_fp16_weight,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        else:
            # Previously this fell through and crashed with UnboundLocalError
            # on ``config``; report the actual problem instead.
            raise ValueError(
                f"Unsupported quant_type {quant_type!r}; "
                f"expected one of {list(self._SUPPORTED_QUANT_TYPES)}."
            )

        # Output as dict: {"transformer": config}
        block_state.quantization_config = {block_state.component: config}
        self.set_block_state(state, block_state)
        return pipeline, state