# transquant_3 / src/pipeline.py
# Source: Hugging Face file viewer - author MyApricity, commit 3f21de6
# ("Update src/pipeline.py", verified), 8.72 kB.
import os
import torch
import torch._dynamo
import gc
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
import json
import transformers
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from pipelines.models import TextToImageRequest
import json
# Have TorchDynamo suppress compilation errors instead of raising them.
torch._dynamo.config.suppress_errors = True

# Process-wide environment knobs, applied once at import time.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)

# Base Flux checkpoint, pinned to an exact Hub revision.
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"

# Placeholder; in this file it only appears in type annotations.
Pipeline = None
# ---------------- NF4 ----------------
def functional_linear_4bits(x, weight, bias):
    """Apply a linear projection whose ``weight`` is a bitsandbytes 4-bit parameter.

    Args:
        x: Input activations.
        weight: 4-bit parameter; its ``quant_state`` is handed to
            ``bnb.matmul_4bit`` together with the transposed weight.
        bias: Optional bias tensor (may be None).

    Returns:
        The matmul result cast to the dtype and device of ``x``.
    """
    return bnb.matmul_4bit(
        x,
        weight.t(),
        bias=bias,
        quant_state=weight.quant_state,
    ).to(x)
def quant_state_copier(state, device=None):
    """Return a copy of a bitsandbytes ``QuantState`` with its tensors on ``device``.

    Every tensor of the state is moved: ``absmax`` and ``code``, plus
    ``offset`` and the nested ``state2`` (its ``absmax`` and ``code``) when
    ``state.nested`` is set.

    Args:
        state: The quant state to copy, or None.
        device: Target device. When falsy, the device of ``state.absmax`` is kept.

    Returns:
        A new ``QuantState``, or None when ``state`` is None.
    """
    if state is None:
        return None
    device = device or state.absmax.device

    nested = None
    if state.nested:
        inner = state.state2
        nested = QuantState(
            absmax=inner.absmax.to(device),
            shape=inner.shape,
            code=inner.code.to(device),
            blocksize=inner.blocksize,
            quant_type=inner.quant_type,
            dtype=inner.dtype,
        )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        # Move the quantization code table as well. Leaving it on the source
        # device (while absmax/offset/state2 move) splits one QuantState
        # across devices.
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=nested,
    )
class Forge_Params_4Bit(Params4bit):
    """``Params4bit`` whose ``.to()`` keeps the owning module's quant state in sync.

    Moving a not-yet-quantized parameter to CUDA quantizes it (``_quantize``).
    Any other move builds a new ``Forge_Params_4Bit`` carrying a device-matched
    copy of the quant state. ``self`` and its owning module are then updated to
    share that state and data.
    """

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            # First move of unquantized data to CUDA: quantize on the device.
            return self._quantize(device)

        moved = Forge_Params_4Bit(
            torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            quant_state=quant_state_copier(self.quant_state, device),
            # Propagate this parameter's own settings rather than hard-coding
            # blocksize=64 / compress_statistics=False. Otherwise a copy that is
            # quantized later (e.g. CPU move, then CUDA) would use a different
            # configuration than the one requested at construction.
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module,
        )
        # Keep the owning module's quant state in sync. Skip this when no module
        # is attached, instead of raising AttributeError on None.
        if self.module is not None:
            self.module.quant_state = moved.quant_state
        # Mirror the move onto self as well as returning the new object.
        self.data = moved.data
        self.quant_state = moved.quant_state
        return moved
class Force_Loader_4Bits(torch.nn.Module):
    """Base module that builds a 4-bit bitsandbytes weight while loading a state dict.

    Before loading, the module only owns ``dummy``, a one-element parameter
    whose device/dtype serve as the conversion target for loaded tensors.
    Loading replaces it with ``weight`` (a ``Forge_Params_4Bit``),
    ``quant_state`` and an optional ``bias``.
    """

    def __init__(self, *, device, dtype, quant_type, **kwargs):
        """Create the placeholder state.

        Args:
            device: Device of the ``dummy`` placeholder parameter.
            dtype: Dtype of the ``dummy`` placeholder parameter.
            quant_type: 4-bit quantization type forwarded to
                ``Forge_Params_4Bit`` when an unquantized weight is loaded.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        # Placeholder parameter; the loader converts tensors with .to(self.dummy).
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        # Populated by _load_from_state_dict.
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save parameters, then append the packed quant state of ``weight``.

        Quant-state entries go under ``prefix + "weight." + <key>``. This is the
        same namespace ``_load_from_state_dict`` scans when reloading.
        """
        super()._save_to_state_dict(destination, prefix, keep_vars)
        quant_state = getattr(self.weight, "quant_state", None)
        if quant_state is not None:
            for k, v in quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
        return

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Load ``weight``/``bias`` from ``state_dict``, handling 4-bit weights.

        Three cases:

        1. Pre-quantized checkpoint (some ``weight.*`` sub-key contains
           ``'bitsandbytes'``): rebuild ``weight`` with
           ``Forge_Params_4Bit.from_prequantized`` on CUDA, then delete ``dummy``.
        2. Otherwise, while ``dummy`` still exists: wrap the raw weight tensor
           in ``Forge_Params_4Bit`` (``Forge_Params_4Bit.to`` quantizes it on a
           move to CUDA while ``bnb_quantized`` is false), then delete ``dummy``.
        3. Once ``dummy`` is gone: defer to ``nn.Module``'s default loader.

        NOTE(review): cases 1 and 2 never touch ``missing_keys``,
        ``unexpected_keys`` or ``error_msgs``.
        """
        # Sub-keys stored under "<prefix>weight." (quant-state entries). The
        # bare "<prefix>weight" key does not match because of the trailing dot.
        quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
        if any('bitsandbytes' in k for k in quant_state_keys):
            # Case 1: packed 4-bit data plus serialized quant state.
            quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
            self.weight = Forge_Params_4Bit.from_prequantized(
                data=state_dict[prefix + 'weight'],
                quantized_stats=quant_state_dict,
                requires_grad=False,
                # NOTE(review): device is hard-coded to CUDA; dummy's device is ignored.
                device=torch.device('cuda'),
                module=self
            )
            self.quant_state = self.weight.quant_state
            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
            # NOTE(review): unlike case 2 this is not guarded by hasattr, so a
            # second pre-quantized load into the same module raises AttributeError.
            del self.dummy
        elif hasattr(self, 'dummy'):
            # Case 2: first load of an unquantized weight.
            if prefix + 'weight' in state_dict:
                self.weight = Forge_Params_4Bit(
                    state_dict[prefix + 'weight'].to(self.dummy),
                    requires_grad=False,
                    compress_statistics=True,
                    quant_type=self.quant_type,
                    quant_storage=torch.uint8,
                    module=self,
                )
                self.quant_state = self.weight.quant_state
            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
            del self.dummy
        else:
            # Case 3: placeholder already consumed; use the standard loader.
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
class CustomLinear(Force_Loader_4Bits):
    """NF4 linear layer built on ``Force_Loader_4Bits``.

    Weight and bias come from the loaded state dict. ``forward`` delegates the
    projection to ``functional_linear_4bits``.
    """

    def __init__(self, *args, device=None, dtype=None, **kwargs):
        # Positional and extra keyword arguments are accepted but ignored
        # (presumably so the class can be constructed like nn.Linear).
        super().__init__(quant_type='nf4', device=device, dtype=dtype)

    def forward(self, x):
        """Project ``x`` through the NF4 weight and (optional) bias."""
        weight = self.weight
        # Re-attach the module-level quant state before each call.
        weight.quant_state = self.quant_state
        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            # Cast the bias to the activation dtype in place; it stays cast.
            bias.data = bias.data.to(x.dtype)
        return functional_linear_4bits(x, weight, bias)
class InitModel:
    """Static loaders for the Flux sub-models; each one is loaded in bfloat16."""

    @staticmethod
    def load_text_encoder() -> T5EncoderModel:
        """Load the pinned bf16 T5 v1.1 XXL encoder checkpoint.

        NOTE(review): the channels_last conversion presumably has no effect on
        a text encoder without 4-D weights - confirm before relying on it.
        """
        print("Loading text encoder...")
        encoder = T5EncoderModel.from_pretrained(
            "city96/t5-v1_1-xxl-encoder-bf16",
            torch_dtype=torch.bfloat16,
            revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        )
        encoder = encoder.to(memory_format=torch.channels_last)
        return encoder

    @staticmethod
    def load_vae() -> AutoencoderTiny:
        """Load the pinned ``AutoencoderTiny`` VAE in bfloat16."""
        print("Loading VAE model...")
        return AutoencoderTiny.from_pretrained(
            "XiangquiAI/FLUX_Vae_Model",
            torch_dtype=torch.bfloat16,
            revision="103bcc03998f48ef311c100ee119f1b9942132ab",
        )

    @staticmethod
    def load_transformer(trans_path: str) -> FluxTransformer2DModel:
        """Load a Flux transformer from the local path ``trans_path`` in bfloat16.

        Safetensors loading is explicitly disabled (``use_safetensors=False``).
        """
        print("Loading transformer model...")
        model = FluxTransformer2DModel.from_pretrained(
            trans_path,
            use_safetensors=False,
            torch_dtype=torch.bfloat16,
        )
        return model.to(memory_format=torch.channels_last)
# Prompts used only to exercise the pipeline once before serving requests.
_WARMUP_PROMPTS = (
    "melanogen, tiptilt",
    "melanogen, endosome, apical, polymyodous, ",
    "buffer, cutie, buttinsky, prototrophic",
    "puzzlehead",
)


def _warm_up(pipeline) -> None:
    """Run one 1024x1024, 4-step, guidance-free generation per warm-up prompt."""
    # Presumably this moves one-time setup costs out of the first real request.
    for prompt in _WARMUP_PROMPTS:
        pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )


def load_pipeline() -> Pipeline:
    """Assemble the Flux pipeline on CUDA and warm it up.

    Loads the bf16 T5 encoder, the tiny VAE and a transformer from a pinned
    snapshot in the local Hub cache. These are combined with ``CHECKPOINT`` at
    ``REVISION`` in bfloat16. The pipeline is then moved to CUDA, VAE
    slicing/tiling is enabled, and the warm-up prompts are run.

    Returns:
        The loaded diffusers pipeline.
    """
    t5_encoder_2 = InitModel.load_text_encoder()
    vae = InitModel.load_vae()
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db",
    )
    transformer = InitModel.load_transformer(transformer_path)
    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=vae,
        transformer=transformer,
        text_encoder_2=t5_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    try:
        # Decode latents in slices/tiles to reduce VAE memory use.
        pipeline.enable_vae_slicing()
        pipeline.enable_vae_tiling()
        # NOTE(review): torch.nn has no LinearLayer class. This only adds a new
        # attribute after every model is already built, so no layer is
        # replaced by CustomLinear. Kept for parity; confirm whether anything reads it.
        torch.nn.LinearLayer = CustomLinear
    except Exception as err:
        # Best effort: execution continues, but report the cause instead of
        # silently swallowing it (the previous bare except also caught
        # KeyboardInterrupt/SystemExit).
        print(f"Debug here: {err!r}")

    _warm_up(pipeline)
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate a single image for ``request``.

    Sampling uses 4 steps, ``guidance_scale=0.0`` and ``max_sequence_length=256``,
    with a generator on ``pipeline.device`` seeded from ``request.seed``.

    Args:
        request: Provides ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: The loaded diffusers pipeline.

    Returns:
        The first image produced by the pipeline.
    """
    # Release cached allocator blocks and restart peak-memory tracking
    # (presumably so peak stats cover this request only). reset_peak_memory_stats
    # already resets everything the deprecated reset_max_memory_allocated alias
    # did; calling both was redundant and emitted a FutureWarning.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    generator = Generator(pipeline.device).manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    return result.images[0]