MyApricity committed on
Commit
facf29b
·
verified ·
1 Parent(s): 3f21de6

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +43 -107
src/pipeline.py CHANGED
@@ -28,20 +28,19 @@ REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
28
  Pipeline = None
29
 
30
 
31
- # ---------------- NF4 ----------------
32
- def functional_linear_4bits(x, weight, bias):
33
- out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
34
- out = out.to(x)
35
- return out
36
 
 
 
 
 
37
 
38
- def quant_state_copier(state, device=None):
 
39
  if state is None:
40
  return None
41
 
42
  device = device or state.absmax.device
43
-
44
- state2 = (
45
  QuantState(
46
  absmax=state.state2.absmax.to(device),
47
  shape=state.state2.shape,
@@ -50,8 +49,7 @@ def quant_state_copier(state, device=None):
50
  quant_type=state.state2.quant_type,
51
  dtype=state.state2.dtype,
52
  )
53
- if state.nested
54
- else None
55
  )
56
 
57
  return QuantState(
@@ -62,100 +60,45 @@ def quant_state_copier(state, device=None):
62
  quant_type=state.quant_type,
63
  dtype=state.dtype,
64
  offset=state.offset.to(device) if state.nested else None,
65
- state2=state2,
66
  )
67
 
68
-
69
- class Forge_Params_4Bit(Params4bit):
70
  def to(self, *args, **kwargs):
71
- device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
72
  if device is not None and device.type == "cuda" and not self.bnb_quantized:
73
  return self._quantize(device)
74
- else:
75
- n = Forge_Params_4Bit(
76
- torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
77
- requires_grad=self.requires_grad,
78
- quant_state=quant_state_copier(self.quant_state, device),
79
- compress_statistics=False,
80
- blocksize=64,
81
- quant_type=self.quant_type,
82
- quant_storage=self.quant_storage,
83
- bnb_quantized=self.bnb_quantized,
84
- module=self.module
85
- )
86
- self.module.quant_state = n.quant_state
87
- self.data = n.data
88
- self.quant_state = n.quant_state
89
- return n
90
-
91
-
92
- class Force_Loader_4Bits(torch.nn.Module):
93
- def __init__(self, *, device, dtype, quant_type, **kwargs):
94
  super().__init__()
95
  self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
96
  self.weight = None
97
  self.quant_state = None
98
  self.bias = None
99
- self.quant_type = quant_type
100
-
101
- def _save_to_state_dict(self, destination, prefix, keep_vars):
102
- super()._save_to_state_dict(destination, prefix, keep_vars)
103
- quant_state = getattr(self.weight, "quant_state", None)
104
- if quant_state is not None:
105
- for k, v in quant_state.as_dict(packed=True).items():
106
- destination[prefix + "weight." + k] = v if keep_vars else v.detach()
107
- return
108
-
109
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
110
- quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
111
-
112
- if any('bitsandbytes' in k for k in quant_state_keys):
113
- quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
114
-
115
- self.weight = Forge_Params_4Bit.from_prequantized(
116
- data=state_dict[prefix + 'weight'],
117
- quantized_stats=quant_state_dict,
118
- requires_grad=False,
119
- device=torch.device('cuda'),
120
- module=self
121
- )
122
- self.quant_state = self.weight.quant_state
123
-
124
- if prefix + 'bias' in state_dict:
125
- self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
126
-
127
- del self.dummy
128
- elif hasattr(self, 'dummy'):
129
- if prefix + 'weight' in state_dict:
130
- self.weight = Forge_Params_4Bit(
131
- state_dict[prefix + 'weight'].to(self.dummy),
132
- requires_grad=False,
133
- compress_statistics=True,
134
- quant_type=self.quant_type,
135
- quant_storage=torch.uint8,
136
- module=self,
137
- )
138
- self.quant_state = self.weight.quant_state
139
-
140
- if prefix + 'bias' in state_dict:
141
- self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
142
-
143
- del self.dummy
144
- else:
145
- super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
146
-
147
-
148
- class CustomLinear(Force_Loader_4Bits):
149
- def __init__(self, *args, device=None, dtype=None, **kwargs):
150
- super().__init__(device=device, dtype=dtype, quant_type='nf4')
151
 
152
  def forward(self, x):
153
  self.weight.quant_state = self.quant_state
154
-
155
  if self.bias is not None and self.bias.dtype != x.dtype:
156
  self.bias.data = self.bias.data.to(x.dtype)
157
-
158
- return functional_linear_4bits(x, self.weight, self.bias)
159
 
160
 
161
  class InitModel:
@@ -170,16 +113,6 @@ class InitModel:
170
  )
171
  return text_encoder.to(memory_format=torch.channels_last)
172
 
173
- @staticmethod
174
- def load_vae() -> AutoencoderTiny:
175
- print("Loading VAE model...")
176
- vae = AutoencoderTiny.from_pretrained(
177
- "XiangquiAI/FLUX_Vae_Model",
178
- revision="103bcc03998f48ef311c100ee119f1b9942132ab",
179
- torch_dtype=torch.bfloat16,
180
- )
181
- return vae
182
-
183
  @staticmethod
184
  def load_transformer(trans_path: str) -> FluxTransformer2DModel:
185
  print("Loading transformer model...")
@@ -194,8 +127,7 @@ class InitModel:
194
 
195
  def load_pipeline() -> Pipeline:
196
 
197
- t5_encoder_2 = InitModel.load_text_encoder()
198
- vae = InitModel.load_vae()
199
 
200
  transformer_path = os.path.join(HF_HUB_CACHE, "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db")
201
  transformer = InitModel.load_transformer(transformer_path)
@@ -203,9 +135,7 @@ def load_pipeline() -> Pipeline:
203
 
204
  pipeline = DiffusionPipeline.from_pretrained(CHECKPOINT,
205
  revision=REVISION,
206
- vae=vae,
207
  transformer=transformer,
208
- text_encoder_2=t5_encoder_2,
209
  torch_dtype=torch.bfloat16)
210
  pipeline.to("cuda")
211
 
@@ -213,19 +143,25 @@ def load_pipeline() -> Pipeline:
213
  # Enable some options for better vae
214
  pipeline.enable_vae_slicing()
215
  pipeline.enable_vae_tiling()
216
- torch.nn.LinearLayer = CustomLinear
217
  except:
218
  print("Debug here")
 
 
 
 
 
 
219
 
220
 
221
- prms = [
222
- "melanogen, tiptilt",
223
  "melanogen, endosome, apical, polymyodous, ",
224
  "buffer, cutie, buttinsky, prototrophic",
225
  "puzzlehead",
226
  ]
227
 
228
- for warmprompt in prms:
229
  pipeline(prompt=warmprompt,
230
  width=1024,
231
  height=1024,
 
28
  Pipeline = None
29
 
30
 
 
 
 
 
 
31
 
32
+ def quantized_matrix_multiply(x, weight, bias):
33
+ """Perform matrix multiplication for 4-bit quantized weights."""
34
+ output = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
35
+ return output.to(x)
36
 
37
+ def copy_quant_state(state, device=None):
38
+ """Create a copy of quantization state for a given device."""
39
  if state is None:
40
  return None
41
 
42
  device = device or state.absmax.device
43
+ nested_state = (
 
44
  QuantState(
45
  absmax=state.state2.absmax.to(device),
46
  shape=state.state2.shape,
 
49
  quant_type=state.state2.quant_type,
50
  dtype=state.state2.dtype,
51
  )
52
+ if state.nested else None
 
53
  )
54
 
55
  return QuantState(
 
60
  quant_type=state.quant_type,
61
  dtype=state.dtype,
62
  offset=state.offset.to(device) if state.nested else None,
63
+ state2=nested_state,
64
  )
65
 
66
+ class QuantizedModelParams(Params4bit):
 
67
  def to(self, *args, **kwargs):
68
+ device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
69
  if device is not None and device.type == "cuda" and not self.bnb_quantized:
70
  return self._quantize(device)
71
+
72
+ updated_params = QuantizedModelParams(
73
+ torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
74
+ requires_grad=self.requires_grad,
75
+ quant_state=copy_quant_state(self.quant_state, device),
76
+ compress_statistics=False,
77
+ blocksize=64,
78
+ quant_type=self.quant_type,
79
+ quant_storage=self.quant_storage,
80
+ bnb_quantized=self.bnb_quantized,
81
+ module=self.module
82
+ )
83
+ self.module.quant_state = updated_params.quant_state
84
+ self.data = updated_params.data
85
+ self.quant_state = updated_params.quant_state
86
+ return updated_params
87
+
88
+ class QuantizedLinearLayer(torch.nn.Module):
89
+ def __init__(self, *args, device=None, dtype=None, **kwargs):
 
90
  super().__init__()
91
  self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
92
  self.weight = None
93
  self.quant_state = None
94
  self.bias = None
95
+ self.quant_type = 'nf4'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  def forward(self, x):
98
  self.weight.quant_state = self.quant_state
 
99
  if self.bias is not None and self.bias.dtype != x.dtype:
100
  self.bias.data = self.bias.data.to(x.dtype)
101
+ return quantized_matrix_multiply(x, self.weight, self.bias)
 
102
 
103
 
104
  class InitModel:
 
113
  )
114
  return text_encoder.to(memory_format=torch.channels_last)
115
 
 
 
 
 
 
 
 
 
 
 
116
  @staticmethod
117
  def load_transformer(trans_path: str) -> FluxTransformer2DModel:
118
  print("Loading transformer model...")
 
127
 
128
  def load_pipeline() -> Pipeline:
129
 
130
+
 
131
 
132
  transformer_path = os.path.join(HF_HUB_CACHE, "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db")
133
  transformer = InitModel.load_transformer(transformer_path)
 
135
 
136
  pipeline = DiffusionPipeline.from_pretrained(CHECKPOINT,
137
  revision=REVISION,
 
138
  transformer=transformer,
 
139
  torch_dtype=torch.bfloat16)
140
  pipeline.to("cuda")
141
 
 
143
  # Enable some options for better vae
144
  pipeline.enable_vae_slicing()
145
  pipeline.enable_vae_tiling()
146
+ torch.nn.LinearLayer = QuantizedLinearLayer
147
  except:
148
  print("Debug here")
149
+
150
+ try:
151
+ pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
152
+
153
+ except:
154
+ print("nothing")
155
 
156
 
157
+ ps = [
158
+ "overgross, mandative, inventful, braunite, penneeck",
159
  "melanogen, endosome, apical, polymyodous, ",
160
  "buffer, cutie, buttinsky, prototrophic",
161
  "puzzlehead",
162
  ]
163
 
164
+ for warmprompt in ps:
165
  pipeline(prompt=warmprompt,
166
  width=1024,
167
  height=1024,