Your Name committed on
Commit
d25aa2d
·
1 Parent(s): 24f9b3f
Files changed (1) hide show
  1. src/pipeline.py +22 -32
src/pipeline.py CHANGED
@@ -7,9 +7,8 @@ from bitsandbytes.nn.modules import Params4bit, QuantState
7
  import json
8
  import transformers
9
  from huggingface_hub.constants import HF_HUB_CACHE
10
- from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
11
 
12
- from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
13
  from torch import Generator
14
  from diffusers import FluxTransformer2DModel, DiffusionPipeline
15
 
@@ -20,7 +19,6 @@ import json
20
 
21
 
22
 
23
-
24
  torch._dynamo.config.suppress_errors = True
25
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
26
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
@@ -30,16 +28,6 @@ REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
30
  Pipeline = None
31
 
32
 
33
- import torch
34
- import math
35
- from typing import Dict, Any
36
-
37
- def remove_cache():
38
- gc.collect()
39
- torch.cuda.empty_cache()
40
- torch.cuda.reset_max_memory_allocated()
41
- torch.cuda.reset_peak_memory_stats()
42
-
43
  # ---------------- NF4 ----------------
44
  def functional_linear_4bits(x, weight, bias):
45
  out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
@@ -47,7 +35,7 @@ def functional_linear_4bits(x, weight, bias):
47
  return out
48
 
49
 
50
- def copy_quant_state(state, device=None):
51
  if state is None:
52
  return None
53
 
@@ -78,16 +66,16 @@ def copy_quant_state(state, device=None):
78
  )
79
 
80
 
81
- class ForgeParams4bit(Params4bit):
82
  def to(self, *args, **kwargs):
83
  device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
84
  if device is not None and device.type == "cuda" and not self.bnb_quantized:
85
  return self._quantize(device)
86
  else:
87
- n = ForgeParams4bit(
88
  torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
89
  requires_grad=self.requires_grad,
90
- quant_state=copy_quant_state(self.quant_state, device),
91
  compress_statistics=False,
92
  blocksize=64,
93
  quant_type=self.quant_type,
@@ -101,7 +89,7 @@ class ForgeParams4bit(Params4bit):
101
  return n
102
 
103
 
104
- class ForgeLoader4Bit(torch.nn.Module):
105
  def __init__(self, *, device, dtype, quant_type, **kwargs):
106
  super().__init__()
107
  self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
@@ -124,7 +112,7 @@ class ForgeLoader4Bit(torch.nn.Module):
124
  if any('bitsandbytes' in k for k in quant_state_keys):
125
  quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
126
 
127
- self.weight = ForgeParams4bit.from_prequantized(
128
  data=state_dict[prefix + 'weight'],
129
  quantized_stats=quant_state_dict,
130
  requires_grad=False,
@@ -139,7 +127,7 @@ class ForgeLoader4Bit(torch.nn.Module):
139
  del self.dummy
140
  elif hasattr(self, 'dummy'):
141
  if prefix + 'weight' in state_dict:
142
- self.weight = ForgeParams4bit(
143
  state_dict[prefix + 'weight'].to(self.dummy),
144
  requires_grad=False,
145
  compress_statistics=True,
@@ -157,7 +145,7 @@ class ForgeLoader4Bit(torch.nn.Module):
157
  super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
158
 
159
 
160
- class Linear(ForgeLoader4Bit):
161
  def __init__(self, *args, device=None, dtype=None, **kwargs):
162
  super().__init__(device=device, dtype=dtype, quant_type='nf4')
163
 
@@ -170,9 +158,6 @@ class Linear(ForgeLoader4Bit):
170
  return functional_linear_4bits(x, self.weight, self.bias)
171
 
172
 
173
- # Replace nn.Linear with the 4-bit quantized Linear
174
- # torch.nn.Linear = Linear
175
-
176
  class InitModel:
177
 
178
  @staticmethod
@@ -209,26 +194,28 @@ class InitModel:
209
 
210
  def load_pipeline() -> Pipeline:
211
 
 
 
212
 
213
  transformer_path = os.path.join(HF_HUB_CACHE, "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db")
214
  transformer = InitModel.load_transformer(transformer_path)
215
-
216
- text_encoder_2 = InitModel.load_text_encoder()
217
- vae = InitModel.load_vae()
218
-
219
 
220
  pipeline = DiffusionPipeline.from_pretrained(CHECKPOINT,
221
  revision=REVISION,
222
  vae=vae,
223
  transformer=transformer,
224
- text_encoder_2=text_encoder_2,
225
  torch_dtype=torch.bfloat16)
226
  pipeline.to("cuda")
 
227
  try:
 
228
  pipeline.enable_vae_slicing()
229
- torch.nn.LinearLayer = Linear
 
230
  except:
231
- print("Using origin pipeline")
232
 
233
 
234
  prms = [
@@ -252,7 +239,10 @@ def load_pipeline() -> Pipeline:
252
  @torch.no_grad()
253
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
254
 
255
- remove_cache()
 
 
 
256
  # remove cache here for better result
257
  generator = Generator(pipeline.device).manual_seed(request.seed)
258
 
 
7
  import json
8
  import transformers
9
  from huggingface_hub.constants import HF_HUB_CACHE
10
+ from transformers import T5EncoderModel, T5TokenizerFast
11
 
 
12
  from torch import Generator
13
  from diffusers import FluxTransformer2DModel, DiffusionPipeline
14
 
 
19
 
20
 
21
 
 
22
  torch._dynamo.config.suppress_errors = True
23
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
24
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
 
28
  Pipeline = None
29
 
30
 
 
 
 
 
 
 
 
 
 
 
31
  # ---------------- NF4 ----------------
32
  def functional_linear_4bits(x, weight, bias):
33
  out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
 
35
  return out
36
 
37
 
38
+ def quant_state_copier(state, device=None):
39
  if state is None:
40
  return None
41
 
 
66
  )
67
 
68
 
69
+ class Forge_Params_4Bit(Params4bit):
70
  def to(self, *args, **kwargs):
71
  device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
72
  if device is not None and device.type == "cuda" and not self.bnb_quantized:
73
  return self._quantize(device)
74
  else:
75
+ n = Forge_Params_4Bit(
76
  torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
77
  requires_grad=self.requires_grad,
78
+ quant_state=quant_state_copier(self.quant_state, device),
79
  compress_statistics=False,
80
  blocksize=64,
81
  quant_type=self.quant_type,
 
89
  return n
90
 
91
 
92
+ class Force_Loader_4Bits(torch.nn.Module):
93
  def __init__(self, *, device, dtype, quant_type, **kwargs):
94
  super().__init__()
95
  self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
 
112
  if any('bitsandbytes' in k for k in quant_state_keys):
113
  quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
114
 
115
+ self.weight = Forge_Params_4Bit.from_prequantized(
116
  data=state_dict[prefix + 'weight'],
117
  quantized_stats=quant_state_dict,
118
  requires_grad=False,
 
127
  del self.dummy
128
  elif hasattr(self, 'dummy'):
129
  if prefix + 'weight' in state_dict:
130
+ self.weight = Forge_Params_4Bit(
131
  state_dict[prefix + 'weight'].to(self.dummy),
132
  requires_grad=False,
133
  compress_statistics=True,
 
145
  super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
146
 
147
 
148
+ class CustomLinear(Force_Loader_4Bits):
149
  def __init__(self, *args, device=None, dtype=None, **kwargs):
150
  super().__init__(device=device, dtype=dtype, quant_type='nf4')
151
 
 
158
  return functional_linear_4bits(x, self.weight, self.bias)
159
 
160
 
 
 
 
161
  class InitModel:
162
 
163
  @staticmethod
 
194
 
195
  def load_pipeline() -> Pipeline:
196
 
197
+ t5_encoder_2 = InitModel.load_text_encoder()
198
+ vae = InitModel.load_vae()
199
 
200
  transformer_path = os.path.join(HF_HUB_CACHE, "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db")
201
  transformer = InitModel.load_transformer(transformer_path)
202
+
 
 
 
203
 
204
  pipeline = DiffusionPipeline.from_pretrained(CHECKPOINT,
205
  revision=REVISION,
206
  vae=vae,
207
  transformer=transformer,
208
+ text_encoder_2=t5_encoder_2,
209
  torch_dtype=torch.bfloat16)
210
  pipeline.to("cuda")
211
+
212
  try:
213
+ # Enable some options for better vae
214
  pipeline.enable_vae_slicing()
215
+ pipeline.enable_vae_tiling()
216
+ torch.nn.LinearLayer = CustomLinear
217
  except:
218
+ print("Debug here")
219
 
220
 
221
  prms = [
 
239
  @torch.no_grad()
240
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
241
 
242
+ torch.cuda.empty_cache()
243
+ torch.cuda.reset_max_memory_allocated()
244
+ torch.cuda.reset_peak_memory_stats()
245
+
246
  # remove cache here for better result
247
  generator = Generator(pipeline.device).manual_seed(request.seed)
248