Your Name committed on
Commit
24f9b3f
·
1 Parent(s): 4fcd1d5
Files changed (2) hide show
  1. pyproject.toml +1 -1
  2. src/pipeline.py +141 -8
pyproject.toml CHANGED
@@ -16,7 +16,7 @@ dependencies = [
16
  "protobuf==5.28.3",
17
  "sentencepiece==0.2.0",
18
  "torchao==0.6.1",
19
- "optimum-quanto",
20
  "hf_transfer==0.1.8",
21
  "setuptools==75.2.0",
22
  "edge-maxxing-pipelines @ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines",
 
16
  "protobuf==5.28.3",
17
  "sentencepiece==0.2.0",
18
  "torchao==0.6.1",
19
+ "bitsandbytes",
20
  "hf_transfer==0.1.8",
21
  "setuptools==75.2.0",
22
  "edge-maxxing-pipelines @ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines",
src/pipeline.py CHANGED
@@ -2,7 +2,8 @@ import os
2
  import torch
3
  import torch._dynamo
4
  import gc
5
-
 
6
  import json
7
  import transformers
8
  from huggingface_hub.constants import HF_HUB_CACHE
@@ -15,7 +16,6 @@ from diffusers import FluxTransformer2DModel, DiffusionPipeline
15
  from PIL.Image import Image
16
  from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
17
  from pipelines.models import TextToImageRequest
18
- from optimum.quanto import requantize
19
  import json
20
 
21
 
@@ -40,6 +40,138 @@ def remove_cache():
40
  torch.cuda.reset_max_memory_allocated()
41
  torch.cuda.reset_peak_memory_stats()
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  class InitModel:
45
 
@@ -93,19 +225,20 @@ def load_pipeline() -> Pipeline:
93
  torch_dtype=torch.bfloat16)
94
  pipeline.to("cuda")
95
  try:
96
- pipeline.disable_vae_slice()
 
97
  except:
98
  print("Using origin pipeline")
99
 
100
 
101
- promts_listing = [
102
- "melanogen, endosome",
 
103
  "buffer, cutie, buttinsky, prototrophic",
104
- "puzzlehead, fistical, must return non duplicate",
105
- "apical, polymyodous, tiptilt"
106
  ]
107
 
108
- for p in promts_listing:
109
  pipeline(prompt=p,
110
  width=1024,
111
  height=1024,
 
2
  import torch
3
  import torch._dynamo
4
  import gc
5
+ import bitsandbytes as bnb
6
+ from bitsandbytes.nn.modules import Params4bit, QuantState
7
  import json
8
  import transformers
9
  from huggingface_hub.constants import HF_HUB_CACHE
 
16
  from PIL.Image import Image
17
  from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
18
  from pipelines.models import TextToImageRequest
 
19
  import json
20
 
21
 
 
40
  torch.cuda.reset_max_memory_allocated()
41
  torch.cuda.reset_peak_memory_stats()
42
 
43
# ---------------- NF4 quantization helpers ----------------
def functional_linear_4bits(x, weight, bias):
    """Apply a 4-bit quantized linear layer: y = x @ W^T + b.

    Uses bitsandbytes' fused 4-bit matmul with the quantization statistics
    attached to ``weight.quant_state``.
    """
    result = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    # The kernel may produce a different dtype/device; cast back to match the input.
    return result.to(x)
48
+
49
+
50
def copy_quant_state(state, device=None):
    """Deep-copy a bitsandbytes QuantState onto ``device``.

    Returns ``None`` when ``state`` is ``None``. When the state is nested
    (double quantization), the inner ``state2`` statistics are copied too.

    Fix: the outer ``code`` tensor is now moved to ``device`` like every other
    tensor field (``absmax``, ``offset``, and the nested ``state2`` tensors);
    previously it was copied by reference and could remain on the wrong device.
    """
    if state is None:
        return None

    # Default to wherever the existing statistics live.
    device = device or state.absmax.device

    state2 = (
        QuantState(
            absmax=state.state2.absmax.to(device),
            shape=state.state2.shape,
            code=state.state2.code.to(device),
            blocksize=state.state2.blocksize,
            quant_type=state.state2.quant_type,
            dtype=state.state2.dtype,
        )
        if state.nested
        else None
    )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        # offset/state2 only exist for nested (double-quantized) states.
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )
79
+
80
+
81
class ForgeParams4bit(Params4bit):
    """Params4bit variant that quantizes lazily on the first move to CUDA.

    A CPU->CUDA transfer of a not-yet-quantized parameter triggers the actual
    4-bit quantization; any other ``.to(...)`` produces a copied parameter with
    its quant state cloned onto the target device.
    """

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        # First move onto a CUDA device performs the real quantization step.
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)

        moved = ForgeParams4bit(
            torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, device),
            compress_statistics=False,
            blocksize=64,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module,
        )
        # Keep this parameter and its owning module in sync with the copy.
        self.module.quant_state = moved.quant_state
        self.data = moved.data
        self.quant_state = moved.quant_state
        return moved
102
+
103
+
104
class ForgeLoader4Bit(torch.nn.Module):
    """Module base that loads/saves bitsandbytes 4-bit weights via the state dict.

    Supports two checkpoint layouts: pre-quantized weights (extra
    ``weight.*bitsandbytes*`` statistic entries) and ordinary full-precision
    weights, which are wrapped so they quantize on the first move to CUDA.
    """

    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        # Placeholder parameter: remembers the target device/dtype until the
        # real weights arrive from a state dict, then it is deleted.
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        quant_state = getattr(self.weight, "quant_state", None)
        if quant_state is None:
            return
        # Persist the quantization statistics alongside the packed weight tensor.
        for key, value in quant_state.as_dict(packed=True).items():
            destination[prefix + "weight." + key] = value if keep_vars else value.detach()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        stat_prefix = prefix + "weight."
        quant_state_keys = {
            key[len(stat_prefix):] for key in state_dict if key.startswith(stat_prefix)
        }

        if any("bitsandbytes" in key for key in quant_state_keys):
            # Checkpoint already holds pre-quantized 4-bit weights + statistics.
            quantized_stats = {key: state_dict[stat_prefix + key] for key in quant_state_keys}

            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[prefix + "weight"],
                quantized_stats=quantized_stats,
                requires_grad=False,
                device=torch.device("cuda"),
                module=self,
            )
            self.quant_state = self.weight.quant_state

            if prefix + "bias" in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + "bias"].to(self.dummy))

            del self.dummy
        elif hasattr(self, "dummy"):
            # Full-precision checkpoint: wrap the weight so it quantizes lazily
            # when the module is moved to CUDA.
            if prefix + "weight" in state_dict:
                self.weight = ForgeParams4bit(
                    state_dict[prefix + "weight"].to(self.dummy),
                    requires_grad=False,
                    compress_statistics=True,
                    quant_type=self.quant_type,
                    quant_storage=torch.uint8,
                    module=self,
                )
                self.quant_state = self.weight.quant_state

            if prefix + "bias" in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + "bias"].to(self.dummy))

            del self.dummy
        else:
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                          missing_keys, unexpected_keys, error_msgs)
159
+
160
class Linear(ForgeLoader4Bit):
    """Drop-in ``nn.Linear`` replacement backed by NF4 4-bit weights."""

    def __init__(self, *args, device=None, dtype=None, **kwargs):
        # Positional in/out-feature args are accepted but unused: the weight
        # shape is recovered from the checkpoint in _load_from_state_dict.
        super().__init__(device=device, dtype=dtype, quant_type='nf4')

    def forward(self, x):
        # Re-attach the module-level quant state; ``.to()`` moves can replace it.
        self.weight.quant_state = self.quant_state

        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        return functional_linear_4bits(x, self.weight, self.bias)
171
+
172
+
173
+ # Replace nn.Linear with the 4-bit quantized Linear
174
+ # torch.nn.Linear = Linear
175
 
176
  class InitModel:
177
 
 
225
  torch_dtype=torch.bfloat16)
226
  pipeline.to("cuda")
227
  try:
228
+ pipeline.enable_vae_slicing()
229
+ torch.nn.LinearLayer = Linear
230
  except:
231
  print("Using origin pipeline")
232
 
233
 
234
+ prms = [
235
+ "melanogen, tiptilt",
236
+ "melanogen, endosome, apical, polymyodous, ",
237
  "buffer, cutie, buttinsky, prototrophic",
238
+ "puzzlehead",
 
239
  ]
240
 
241
+ for p in prms:
242
  pipeline(prompt=p,
243
  width=1024,
244
  height=1024,