File size: 8,916 Bytes
4fcd1d5 24f9b3f 4fcd1d5 24f9b3f 4fcd1d5 24f9b3f 4fcd1d5 24f9b3f 4fcd1d5 24f9b3f 4fcd1d5 7cf7a66 4fcd1d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 |
import os
import torch
import torch._dynamo
import gc
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
import json
import transformers
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from pipelines.models import TextToImageRequest
import json
# Set TorchDynamo's suppress_errors flag (per the PyTorch docs, compilation
# failures then fall back to eager execution instead of raising).
torch._dynamo.config.suppress_errors = True
# Configure the CUDA caching allocator through its environment variable.
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
# Explicitly enable parallelism in the tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Hugging Face repository id and pinned revision passed to
# DiffusionPipeline.from_pretrained in load_pipeline().
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
# Placeholder name used only in the annotations of load_pipeline() and infer().
Pipeline = None
import torch
import math
from typing import Dict, Any
def remove_cache():
    """Release cached memory and reset the CUDA peak-memory counters.

    Runs the Python garbage collector, then empties the CUDA caching
    allocator and resets its peak statistics. On a host without CUDA only
    the garbage collection is performed.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # The CUDA calls below need a usable device; skip them on CPU-only hosts.
        return
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() resets all peak counters, including the
    # "max memory allocated" one, so the deprecated
    # reset_max_memory_allocated() call is not needed.
    torch.cuda.reset_peak_memory_stats()
# ---------------- NF4 ----------------
def functional_linear_4bits(x, weight, bias):
    """Apply a linear projection through bitsandbytes' 4-bit matmul.

    Args:
        x: Input tensor.
        weight: Quantized weight; its ``quant_state`` attribute is forwarded
            to ``bnb.matmul_4bit`` and its transpose is used as the matrix.
        bias: Optional bias forwarded to ``bnb.matmul_4bit``.

    Returns:
        The matmul result converted with ``.to(x)``.
    """
    result = bnb.matmul_4bit(
        x,
        weight.t(),
        bias=bias,
        quant_state=weight.quant_state,
    )
    return result.to(x)
def copy_quant_state(state, device=None):
    """Return a copy of a bitsandbytes ``QuantState`` placed on ``device``.

    Args:
        state: The quantization state to copy, or ``None``.
        device: Target device. When falsy, the device of ``state.absmax``
            is used.

    Returns:
        A new ``QuantState`` whose tensors (``absmax``, ``code`` and, for
        nested states, ``offset`` and the inner ``state2``) are moved to the
        target device, or ``None`` when ``state`` is ``None``.
    """
    if state is None:
        return None

    device = device or state.absmax.device

    # Nested states carry a second QuantState describing the quantized absmax.
    nested_state = None
    if state.nested:
        inner = state.state2
        nested_state = QuantState(
            absmax=inner.absmax.to(device),
            shape=inner.shape,
            code=inner.code.to(device),
            blocksize=inner.blocksize,
            quant_type=inner.quant_type,
            dtype=inner.dtype,
        )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        # Move the code tensor too, so it lives on the same device as absmax
        # (the nested state's code is already handled this way).
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=nested_state,
    )
class ForgeParams4bit(Params4bit):
    """``Params4bit`` subclass whose ``to`` keeps the owning module's quant state in sync."""

    def to(self, *args, **kwargs):
        """Move or convert the parameter, quantizing on the first transfer to CUDA.

        Accepts the same arguments as ``torch.Tensor.to``. When the target is
        a CUDA device and the parameter has not been quantized yet, the
        result of ``self._quantize`` is returned. Otherwise a new
        ``ForgeParams4bit`` is built on the target device, and this object
        and ``self.module`` are updated to share the new data and quant state.
        """
        target_device, target_dtype, non_blocking, _memory_format = torch._C._nn._parse_to(*args, **kwargs)

        needs_quantization = (
            target_device is not None
            and target_device.type == "cuda"
            and not self.bnb_quantized
        )
        if needs_quantization:
            return self._quantize(target_device)

        moved = ForgeParams4bit(
            torch.nn.Parameter.to(
                self,
                device=target_device,
                dtype=target_dtype,
                non_blocking=non_blocking,
            ),
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, target_device),
            compress_statistics=False,
            blocksize=64,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module,
        )

        # Point the owning module and this parameter at the relocated state.
        relocated_state = moved.quant_state
        self.module.quant_state = relocated_state
        self.data = moved.data
        self.quant_state = relocated_state
        return moved
class ForgeLoader4Bit(torch.nn.Module):
    """Module base that materialises its weight as ``ForgeParams4bit`` while loading a state dict."""

    def __init__(self, *, device, dtype, quant_type, **kwargs):
        """Create an empty layer; weight and bias are filled in by state-dict loading."""
        super().__init__()
        # One-element placeholder: it is the device/dtype reference for
        # `.to(self.dummy)` during loading and is deleted once loading is done.
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save parameters and, when present, the packed quant-state entries of the weight."""
        super()._save_to_state_dict(destination, prefix, keep_vars)

        state = getattr(self.weight, "quant_state", None)
        if state is None:
            return

        stats_prefix = prefix + "weight."
        for name, tensor in state.as_dict(packed=True).items():
            destination[stats_prefix + name] = tensor if keep_vars else tensor.detach()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Load weight and bias from either a pre-quantized or a plain checkpoint.

        Three cases are handled:
        * keys under ``<prefix>weight.`` mention ``bitsandbytes``: the weight
          is rebuilt with ``ForgeParams4bit.from_prequantized`` on CUDA;
        * otherwise, while the placeholder ``dummy`` still exists: a plain
          weight (if present) is wrapped in a ``ForgeParams4bit``;
        * otherwise: loading is delegated to ``torch.nn.Module``.
        """
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        stats_prefix = prefix + "weight."

        stat_names = {key[len(stats_prefix):] for key in state_dict if key.startswith(stats_prefix)}
        is_prequantized = any('bitsandbytes' in name for name in stat_names)

        if not is_prequantized and not hasattr(self, 'dummy'):
            # Already loaded once and nothing pre-quantized: use the default loader.
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
            return

        if is_prequantized:
            quantized_stats = {name: state_dict[stats_prefix + name] for name in stat_names}
            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[weight_key],
                quantized_stats=quantized_stats,
                requires_grad=False,
                device=torch.device('cuda'),
                module=self,
            )
            self.quant_state = self.weight.quant_state
        elif weight_key in state_dict:
            self.weight = ForgeParams4bit(
                state_dict[weight_key].to(self.dummy),
                requires_grad=False,
                compress_statistics=True,
                quant_type=self.quant_type,
                quant_storage=torch.uint8,
                module=self,
            )
            self.quant_state = self.weight.quant_state

        if bias_key in state_dict:
            self.bias = torch.nn.Parameter(state_dict[bias_key].to(self.dummy))

        # The placeholder is no longer needed once the real tensors are in place.
        del self.dummy
class Linear(ForgeLoader4Bit):
    """Linear layer on top of ``ForgeLoader4Bit`` that always uses the ``'nf4'`` quant type."""

    def __init__(self, *args, device=None, dtype=None, **kwargs):
        """Accept ``nn.Linear``-style arguments; only ``device`` and ``dtype`` are used."""
        super().__init__(quant_type='nf4', dtype=dtype, device=device)

    def forward(self, x):
        """Run the 4-bit linear projection on ``x``."""
        # Re-attach the module's quant state to the weight before every call.
        self.weight.quant_state = self.quant_state

        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            # Convert the bias in place so it matches the input dtype.
            bias.data = bias.data.to(x.dtype)

        return functional_linear_4bits(x, self.weight, self.bias)
# Disabled: globally replacing nn.Linear with the 4-bit quantized Linear.
# torch.nn.Linear = Linear
class InitModel:
    """Static loaders for the individual model components used by the pipeline."""

    @staticmethod
    def load_text_encoder() -> T5EncoderModel:
        """Load the pinned bf16 T5 encoder and return it in channels-last memory format."""
        print("Loading text encoder...")
        encoder = T5EncoderModel.from_pretrained(
            "city96/t5-v1_1-xxl-encoder-bf16",
            revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
            torch_dtype=torch.bfloat16,
        )
        encoder = encoder.to(memory_format=torch.channels_last)
        return encoder

    @staticmethod
    def load_vae() -> AutoencoderTiny:
        """Load the pinned tiny autoencoder in bf16."""
        print("Loading VAE model...")
        return AutoencoderTiny.from_pretrained(
            "XiangquiAI/FLUX_Vae_Model",
            revision="103bcc03998f48ef311c100ee119f1b9942132ab",
            torch_dtype=torch.bfloat16,
        )

    @staticmethod
    def load_transformer(trans_path: str) -> FluxTransformer2DModel:
        """Load the Flux transformer from ``trans_path`` and return it in channels-last memory format.

        Args:
            trans_path: Path handed to ``FluxTransformer2DModel.from_pretrained``.
        """
        print("Loading transformer model...")
        model = FluxTransformer2DModel.from_pretrained(
            trans_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )
        model = model.to(memory_format=torch.channels_last)
        return model
def load_pipeline() -> Pipeline:
    """Build the FLUX pipeline on CUDA and run a few generations before returning it.

    The transformer is read from a fixed snapshot directory inside the
    Hugging Face hub cache; the text encoder and VAE come from
    ``InitModel``. The remaining components are loaded from ``CHECKPOINT``
    at ``REVISION``.

    Returns:
        The assembled pipeline, already moved to ``"cuda"``.
    """
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db",
    )
    transformer = InitModel.load_transformer(transformer_path)
    text_encoder_2 = InitModel.load_text_encoder()
    vae = InitModel.load_vae()

    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # Best-effort optimisation: keep going with the unmodified pipeline if it
    # fails, but catch only ordinary errors (not KeyboardInterrupt/SystemExit)
    # and report the reason.
    try:
        pipeline.enable_vae_slicing()
    except Exception as exc:
        print("Using origin pipeline")
        print(f"enable_vae_slicing failed: {exc!r}")
    else:
        # NOTE(review): this only sets a new attribute on torch.nn; nothing
        # visible in this file reads `torch.nn.LinearLayer` - confirm intent.
        torch.nn.LinearLayer = Linear

    # Run the pipeline a few times with throwaway prompts (results are
    # discarded) - presumably a warm-up before serving requests.
    warmup_prompts = (
        "melanogen, tiptilt",
        "melanogen, endosome, apical, polymyodous, ",
        "buffer, cutie, buttinsky, prototrophic",
        "puzzlehead",
    )
    for prompt in warmup_prompts:
        pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request`` with an already loaded pipeline.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: Pipeline returned by ``load_pipeline``.

    Returns:
        The first image of the pipeline output.
    """
    # Free cached memory before generating.
    remove_cache()

    # Seed a generator on the pipeline's device with the request's seed.
    generator = Generator(pipeline.device).manual_seed(request.seed)

    output = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    return output.images[0]