transquant_2 / src /pipeline.py
MyApricity's picture
Update src/pipeline.py
7cf7a66 verified
import os
import torch
import torch._dynamo
import gc
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
import json
import transformers
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from pipelines.models import TextToImageRequest
import json
# Module-level runtime configuration and constants for the pipeline.

# Sets TorchDynamo's `suppress_errors` flag.
# NOTE(review): presumably this makes failed compilations fall back to eager
# execution instead of raising — confirm against the PyTorch docs.
torch._dynamo.config.suppress_errors = True
# Allocator setting passed through the environment.
# NOTE(review): torch is already imported at this point; confirm the setting is
# still picked up when the CUDA allocator is first initialised.
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
# Environment flag for the tokenizers library (value is the string "True").
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Hub repository id and pinned revision passed to DiffusionPipeline.from_pretrained
# in load_pipeline().
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
# Placeholder used only as an annotation in load_pipeline() and infer(); it is
# the value None, not a real type.
Pipeline = None
import torch
import math
from typing import Dict, Any
def remove_cache():
    """Release unreferenced Python objects and cached CUDA memory, then reset peak stats.

    Runs the Python garbage collector, asks PyTorch to release its cached CUDA
    blocks, and resets the CUDA peak-memory statistics.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # The deprecated torch.cuda.reset_max_memory_allocated() was dropped here:
    # it only forwards to reset_peak_memory_stats() (and emits a FutureWarning
    # on every call), so the single call below has the same effect.
    torch.cuda.reset_peak_memory_stats()
# ---------------- NF4 ----------------
def functional_linear_4bits(x, weight, bias):
    """Compute a linear projection of ``x`` with a 4-bit quantized weight.

    Args:
        x: Input tensor.
        weight: Quantized weight; must provide ``t()`` and a ``quant_state``
            attribute, both of which are handed to ``bnb.matmul_4bit``.
        bias: Optional bias forwarded to ``bnb.matmul_4bit``.

    Returns:
        The matmul result converted with ``.to(x)``.
    """
    result = bnb.matmul_4bit(
        x,
        weight.t(),
        bias=bias,
        quant_state=weight.quant_state,
    )
    return result.to(x)
def copy_quant_state(state, device=None):
    """Build a new ``QuantState`` from ``state`` with its tensors moved to ``device``.

    Args:
        state: The ``QuantState`` to copy, or ``None``.
        device: Target device. When falsy, the device of ``state.absmax`` is used.

    Returns:
        ``None`` when ``state`` is ``None``; otherwise a new ``QuantState``
        carrying the same shape, blocksize, quant_type and dtype. When
        ``state.nested`` is set, the nested ``state2`` and ``offset`` are
        copied as well.
    """
    if state is None:
        return None

    device = device or state.absmax.device

    nested_copy = None
    if state.nested:
        inner = state.state2
        nested_copy = QuantState(
            absmax=inner.absmax.to(device),
            shape=inner.shape,
            code=inner.code.to(device),
            blocksize=inner.blocksize,
            quant_type=inner.quant_type,
            dtype=inner.dtype,
        )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        # Fix: the code table is moved to the target device as well. It was
        # previously passed through unmoved, unlike absmax, offset and the
        # nested state's code, leaving the copy with tensors on mixed devices.
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=nested_copy,
    )
class ForgeParams4bit(Params4bit):
    """``Params4bit`` variant whose ``to()`` carries the quantization state along."""

    def to(self, *args, **kwargs):
        """Move or convert the parameter, quantizing on first transfer to CUDA.

        Accepts the same arguments as ``torch.Tensor.to`` (they are parsed with
        ``torch._C._nn._parse_to``).

        Returns:
            The result of ``self._quantize(device)`` when the target is a CUDA
            device and the parameter is not yet ``bnb_quantized``; otherwise a
            new ``ForgeParams4bit`` holding the moved data and a copied
            quantization state.
        """
        target_device, target_dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)

        first_move_to_cuda = (
            target_device is not None
            and target_device.type == "cuda"
            and not self.bnb_quantized
        )
        if first_move_to_cuda:
            return self._quantize(target_device)

        moved = ForgeParams4bit(
            torch.nn.Parameter.to(
                self,
                device=target_device,
                dtype=target_dtype,
                non_blocking=non_blocking,
            ),
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, target_device),
            compress_statistics=False,
            blocksize=64,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module,
        )
        # Besides returning the new parameter, update the owning module and
        # this object in place so they reference the moved data and state.
        self.module.quant_state = moved.quant_state
        self.data = moved.data
        self.quant_state = moved.quant_state
        return moved
class ForgeLoader4Bit(torch.nn.Module):
    """Base module that loads its weight as a ``ForgeParams4bit`` from a state dict.

    Until a state dict is loaded, the module only holds a one-element ``dummy``
    parameter created with the requested device and dtype; it is used as the
    conversion target for incoming tensors and deleted after loading.
    """

    def __init__(self, *, device, dtype, quant_type, **kwargs):
        """Create the placeholder parameter and empty weight/bias slots.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save the regular entries plus the weight's packed quantization state."""
        super()._save_to_state_dict(destination, prefix, keep_vars)
        weight_state = getattr(self.weight, "quant_state", None)
        if weight_state is None:
            return
        for name, tensor in weight_state.as_dict(packed=True).items():
            destination[prefix + "weight." + name] = tensor if keep_vars else tensor.detach()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Load weight and bias, handling pre-quantized and plain checkpoints.

        * If any ``<prefix>weight.*`` key contains ``'bitsandbytes'``, the
          weight is rebuilt with ``ForgeParams4bit.from_prequantized``.
        * Otherwise, while the ``dummy`` placeholder still exists, a plain
          weight (if present) is wrapped in a new ``ForgeParams4bit``.
        * Otherwise loading is delegated to ``torch.nn.Module``.

        In the first two cases the bias (if present) is converted to the
        placeholder's dtype/device and the placeholder is then deleted.
        """
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        stats_prefix = prefix + "weight."

        stat_names = {
            key[len(stats_prefix):]
            for key in state_dict.keys()
            if key.startswith(stats_prefix)
        }
        prequantized = any('bitsandbytes' in name for name in stat_names)

        if not prequantized and not hasattr(self, 'dummy'):
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
            return

        if prequantized:
            packed_stats = {name: state_dict[stats_prefix + name] for name in stat_names}
            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[weight_key],
                quantized_stats=packed_stats,
                requires_grad=False,
                device=torch.device('cuda'),
                module=self
            )
            self.quant_state = self.weight.quant_state
        elif weight_key in state_dict:
            self.weight = ForgeParams4bit(
                state_dict[weight_key].to(self.dummy),
                requires_grad=False,
                compress_statistics=True,
                quant_type=self.quant_type,
                quant_storage=torch.uint8,
                module=self,
            )
            self.quant_state = self.weight.quant_state

        if bias_key in state_dict:
            self.bias = torch.nn.Parameter(state_dict[bias_key].to(self.dummy))
        # The placeholder is no longer needed once the real tensors are in place.
        del self.dummy
class Linear(ForgeLoader4Bit):
    """Linear layer backed by an NF4-quantized weight."""

    def __init__(self, *args, device=None, dtype=None, **kwargs):
        """Initialise the loader with ``quant_type='nf4'``.

        Positional and additional keyword arguments are accepted but unused.
        """
        super().__init__(device=device, dtype=dtype, quant_type='nf4')

    def forward(self, x):
        """Apply the quantized linear projection to ``x``."""
        # Re-attach the module's quantization state to the weight before use.
        self.weight.quant_state = self.quant_state

        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            # Cast the bias in place so it matches the input dtype.
            bias.data = bias.data.to(x.dtype)

        return functional_linear_4bits(x, self.weight, self.bias)
# Replace nn.Linear with the 4-bit quantized Linear
# torch.nn.Linear = Linear
class InitModel:
    """Static loaders for the sub-models that make up the pipeline."""

    @staticmethod
    def load_text_encoder() -> T5EncoderModel:
        """Load the T5 text encoder in bfloat16 from its pinned revision."""
        print("Loading text encoder...")
        load_options = {
            "revision": "1b9c856aadb864af93c1dcdc226c2774fa67bc86",
            "torch_dtype": torch.bfloat16,
        }
        encoder = T5EncoderModel.from_pretrained(
            "city96/t5-v1_1-xxl-encoder-bf16",
            **load_options,
        )
        encoder = encoder.to(memory_format=torch.channels_last)
        return encoder

    @staticmethod
    def load_vae() -> AutoencoderTiny:
        """Load the tiny autoencoder in bfloat16 from its pinned revision."""
        print("Loading VAE model...")
        load_options = {
            "revision": "103bcc03998f48ef311c100ee119f1b9942132ab",
            "torch_dtype": torch.bfloat16,
        }
        return AutoencoderTiny.from_pretrained(
            "XiangquiAI/FLUX_Vae_Model",
            **load_options,
        )

    @staticmethod
    def load_transformer(trans_path: str) -> FluxTransformer2DModel:
        """Load the Flux transformer in bfloat16 from ``trans_path``.

        Safetensors loading is explicitly disabled (``use_safetensors=False``).
        """
        print("Loading transformer model...")
        load_options = {
            "torch_dtype": torch.bfloat16,
            "use_safetensors": False,
        }
        model = FluxTransformer2DModel.from_pretrained(trans_path, **load_options)
        model = model.to(memory_format=torch.channels_last)
        return model
def load_pipeline() -> Pipeline:
    """Assemble the diffusion pipeline on CUDA and run a few throwaway generations.

    The transformer is read from a fixed snapshot directory under the
    Hugging Face hub cache; the text encoder and VAE come from ``InitModel``.
    After moving the pipeline to ``"cuda"``, four fixed prompts are run at
    1024x1024 with 4 steps and their outputs discarded (presumably as a
    warm-up).

    Returns:
        The ready-to-use pipeline object.
    """
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db",
    )
    transformer = InitModel.load_transformer(transformer_path)
    text_encoder_2 = InitModel.load_text_encoder()
    vae = InitModel.load_vae()

    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    try:
        pipeline.enable_vae_slicing()
        # NOTE(review): this only sets an attribute named `LinearLayer` on
        # torch.nn; it does not rebind torch.nn.Linear, so existing layers are
        # not replaced by the 4-bit Linear. Kept as-is to preserve behaviour.
        torch.nn.LinearLayer = Linear
    except Exception:
        # Best-effort optimisation: continue with the unmodified pipeline.
        # `Exception` (rather than a bare except) lets KeyboardInterrupt and
        # SystemExit propagate instead of being swallowed.
        print("Using origin pipeline")

    warmup_prompts = (
        "melanogen, tiptilt",
        "melanogen, endosome, apical, polymyodous, ",
        "buffer, cutie, buttinsky, prototrophic",
        "puzzlehead",
    )
    for prompt in warmup_prompts:
        pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request`` using ``pipeline``.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: Callable pipeline exposing ``device``; its result must have
            an ``images`` sequence.

    Returns:
        The first image of the pipeline output.
    """
    # Clear caches before generating; the original author notes this gives
    # better results.
    remove_cache()

    generator = Generator(pipeline.device).manual_seed(request.seed)
    fixed_sampling = {
        "guidance_scale": 0.0,
        "num_inference_steps": 4,
        "max_sequence_length": 256,
    }
    output = pipeline(
        request.prompt,
        generator=generator,
        **fixed_sampling,
        height=request.height,
        width=request.width,
    )
    return output.images[0]