|
|
import os |
|
|
import torch |
|
|
import torch._dynamo |
|
|
import gc |
|
|
import bitsandbytes as bnb |
|
|
from bitsandbytes.nn.modules import Params4bit, QuantState |
|
|
import json |
|
|
import transformers |
|
|
from huggingface_hub.constants import HF_HUB_CACHE |
|
|
from transformers import T5EncoderModel, T5TokenizerFast |
|
|
|
|
|
from torch import Generator |
|
|
from diffusers import FluxTransformer2DModel, DiffusionPipeline |
|
|
|
|
|
from PIL.Image import Image |
|
|
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny |
|
|
from pipelines.models import TextToImageRequest |
|
|
import json |
|
|
|
|
|
|
|
|
|
|
|
# Ask TorchDynamo to fall back to eager execution instead of raising when
# compilation hits an error.
torch._dynamo.config.suppress_errors = True

os.environ.update(
    {
        # CUDA caching-allocator option; it only takes effect if set before the
        # allocator is initialized (i.e. before the first CUDA allocation).
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        # Keep HF tokenizers' internal parallelism enabled.
        "TOKENIZERS_PARALLELISM": "True",
    }
)

# Base FLUX.1-schnell repository and the pinned commit its components load from.
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"

# Placeholder name used in return/parameter annotations; it is not a real type.
Pipeline = None
|
|
|
|
|
|
|
|
|
|
|
def quantized_matrix_multiply(x, weight, bias):
    """Multiply ``x`` by a bitsandbytes 4-bit ``weight`` (transposed), adding ``bias``.

    The quantization state is taken from ``weight.quant_state``. The product is
    converted with ``.to(x)``, i.e. to ``x``'s dtype and device, before return.
    """
    transposed = weight.t()
    state = weight.quant_state
    product = bnb.matmul_4bit(x, transposed, bias=bias, quant_state=state)
    return product.to(x)
|
|
|
|
|
def copy_quant_state(state, device=None):
    """Create a copy of a bitsandbytes ``QuantState`` placed on ``device``.

    Every tensor carried by the state is moved with ``.to(device)``. That covers
    ``absmax`` and ``code``, and for nested states also ``offset`` plus the inner
    state's ``absmax`` and ``code``. Scalar metadata (shape, blocksize,
    quant_type, dtype) is copied unchanged.

    Args:
        state: The ``QuantState`` to copy, or ``None``.
        device: Target device. When falsy, the device of ``state.absmax`` is used.

    Returns:
        A new ``QuantState`` on ``device``, or ``None`` when ``state`` is ``None``.
    """
    if state is None:
        return None

    device = device or state.absmax.device

    # Nested states carry an inner QuantState (``state2``) that is copied as well.
    nested_state = None
    if state.nested:
        inner = state.state2
        nested_state = QuantState(
            absmax=inner.absmax.to(device),
            shape=inner.shape,
            code=inner.code.to(device),
            blocksize=inner.blocksize,
            quant_type=inner.quant_type,
            dtype=inner.dtype,
        )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        # Move the lookup table too (the nested branch already did). Before
        # this fix it stayed on its original device while every other tensor
        # of the copy moved.
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=nested_state,
    )
|
|
|
|
|
class QuantizedModelParams(Params4bit):
    """``Params4bit`` whose ``.to()`` keeps its quant state (and its module's) in sync.

    Moving a not-yet-quantized parameter to CUDA defers to ``Params4bit._quantize``.
    Any other move builds a new ``QuantizedModelParams`` with a device-local copy
    of the quantization state and the same quantization settings as ``self``.
    """

    def to(self, *args, **kwargs):
        """Move/cast the parameter; accepts the same arguments as ``torch.Tensor.to``.

        Returns:
            The result of ``self._quantize(device)`` on the first CUDA move of an
            unquantized parameter. Otherwise, a new ``QuantizedModelParams`` on
            the requested device/dtype. In that case ``self`` is also updated in
            place so existing references see the moved data and quant state.
        """
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)

        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)

        moved = QuantizedModelParams(
            torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, device),
            # Preserve this parameter's own settings. They were previously reset
            # to hard-coded values (blocksize=64, compress_statistics=False) on
            # every move.
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module,
        )

        # ``module`` defaults to None in Params4bit; only sync an owner that exists.
        if self.module is not None:
            self.module.quant_state = moved.quant_state
        self.data = moved.data
        self.quant_state = moved.quant_state
        return moved
|
|
|
|
|
class QuantizedLinearLayer(torch.nn.Module):
    """Linear layer whose weight is expected to be a bitsandbytes 4-bit tensor.

    ``weight``, ``bias`` and ``quant_state`` start out as ``None``; the caller
    must assign them before ``forward`` is used. Positional and extra keyword
    constructor arguments are accepted but ignored (presumably to mirror the
    ``torch.nn.Linear`` signature — confirm). Only ``device``/``dtype`` are used,
    and only for the placeholder ``dummy`` parameter.
    """

    def __init__(self, *args, device=None, dtype=None, **kwargs):
        super().__init__()
        # One-element placeholder parameter allocated with the requested device/dtype.
        placeholder = torch.empty(1, device=device, dtype=dtype)
        self.dummy = torch.nn.Parameter(placeholder)
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = 'nf4'

    def forward(self, x):
        """Apply the 4-bit linear transform to ``x`` via ``quantized_matrix_multiply``."""
        weight = self.weight
        # Attach the layer's current quant state to the weight before each matmul.
        weight.quant_state = self.quant_state
        bias = self.bias
        needs_cast = bias is not None and bias.dtype != x.dtype
        if needs_cast:
            # Cast in place and keep it, so later calls with the same dtype skip this.
            bias.data = bias.data.to(x.dtype)
        return quantized_matrix_multiply(x, weight, bias)
|
|
|
|
|
|
|
|
class InitModel:
    """Namespace of static loaders for individual FLUX components in bfloat16."""

    @staticmethod
    def load_text_encoder() -> T5EncoderModel:
        """Load the T5 encoder from ``city96/t5-v1_1-xxl-encoder-bf16`` at a pinned revision.

        The model is loaded in bfloat16 and passed through
        ``.to(memory_format=torch.channels_last)`` before being returned.
        """
        print("Loading text encoder...")
        options = {
            "revision": "1b9c856aadb864af93c1dcdc226c2774fa67bc86",
            "torch_dtype": torch.bfloat16,
        }
        encoder = T5EncoderModel.from_pretrained("city96/t5-v1_1-xxl-encoder-bf16", **options)
        return encoder.to(memory_format=torch.channels_last)

    @staticmethod
    def load_transformer(trans_path: str) -> FluxTransformer2DModel:
        """Load a FLUX transformer from ``trans_path`` in bfloat16 with ``use_safetensors=False``.

        Args:
            trans_path: Local path or hub id passed straight to ``from_pretrained``.

        Returns:
            The model after ``.to(memory_format=torch.channels_last)``.
        """
        print("Loading transformer model...")
        options = {"torch_dtype": torch.bfloat16, "use_safetensors": False}
        model = FluxTransformer2DModel.from_pretrained(trans_path, **options)
        return model.to(memory_format=torch.channels_last)
|
|
|
|
|
|
|
|
|
|
|
def load_pipeline() -> Pipeline:
    """Build the FLUX.1-schnell pipeline on CUDA, compile its transformer and warm it up.

    The transformer is read from a cached snapshot of
    ``MyApricity/Flux_Transformer_float8`` under ``HF_HUB_CACHE``. All other
    components come from ``CHECKPOINT`` at ``REVISION``, in bfloat16.

    VAE slicing/tiling and ``torch.compile`` are best-effort. If setting one up
    raises, the error is printed and loading continues without it.

    Returns:
        The pipeline, already run once on each warm-up prompt.
    """
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db",
    )
    transformer = InitModel.load_transformer(transformer_path)

    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # Catch Exception rather than using a bare ``except:`` so Ctrl-C/SystemExit
    # still propagate, and report what actually failed.
    try:
        pipeline.enable_vae_slicing()
        pipeline.enable_vae_tiling()
    except Exception as exc:
        print(f"VAE slicing/tiling not enabled: {exc!r}")

    # NOTE(review): nothing in this file reads ``torch.nn.LinearLayer``, so this
    # patch appears to have no effect on the pipeline. It is kept for backward
    # compatibility; confirm nothing elsewhere relies on it before removing.
    torch.nn.LinearLayer = QuantizedLinearLayer

    try:
        pipeline.transformer = torch.compile(
            pipeline.transformer, mode="max-autotune", fullgraph=True
        )
    except Exception as exc:
        # torch.compile is lazy: most compilation errors surface on the first
        # call (the warm-up below), not here.
        print(f"torch.compile failed; using the eager transformer: {exc!r}")

    warmup_prompts = (
        "overgross, mandative, inventful, braunite, penneeck",
        "melanogen, endosome, apical, polymyodous, ",
        "buffer, cutie, buttinsky, prototrophic",
        "puzzlehead",
    )
    # Run the same settings `infer` uses so compilation/autotuning happens
    # before the first real request rather than during it.
    for prompt in warmup_prompts:
        pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline
|
|
|
|
|
|
|
|
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate a single image for ``request``.

    Sampling uses 4 inference steps, ``guidance_scale=0.0`` and
    ``max_sequence_length=256``. ``request.width`` and ``request.height`` are
    forwarded as-is.

    Args:
        request: Supplies ``prompt``, ``seed``, ``width`` and ``height``.
        pipeline: A loaded diffusers pipeline; its ``device`` hosts the generator.

    Returns:
        The first image of the pipeline output.
    """
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() resets all peak counters. The deprecated
    # reset_max_memory_allocated() the code used to call here only forwarded to it.
    torch.cuda.reset_peak_memory_stats()

    # NOTE(review): assumes ``request.seed`` may be None. In that case the run is
    # unseeded instead of failing in manual_seed — confirm against TextToImageRequest.
    if request.seed is None:
        generator = None
    else:
        generator = Generator(pipeline.device).manual_seed(request.seed)

    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]