"""FLUX.1-schnell text-to-image pipeline: model loading, warm-up and inference.

The module also defines Forge-style NF4 (bitsandbytes 4-bit) linear-layer
helpers. NOTE(review): nothing in this file builds a model out of ``Linear``.
``load_pipeline`` only assigns it to ``torch.nn.LinearLayer``, and no code here
reads that attribute, so the loaded pipeline runs with its original layers.
"""
import os
import math
from typing import Dict, Any
import torch
import torch._dynamo
import gc
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
import json
import transformers
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from pipelines.models import TextToImageRequest

# Let TorchDynamo fall back to eager execution instead of raising when it
# fails to compile a frame.
torch._dynamo.config.suppress_errors = True
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"

# Placeholder used only in the type annotations below; it is never bound to a class.
Pipeline = None


def remove_cache():
    """Run the garbage collector, release cached CUDA memory and reset peak-memory stats."""
    gc.collect()
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() also resets the max-allocated counter.
    # reset_max_memory_allocated() merely forwards to it, and it warns.
    torch.cuda.reset_peak_memory_stats()


# ---------------- NF4 ----------------


def functional_linear_4bits(x, weight, bias):
    """Apply a bitsandbytes 4-bit linear layer to ``x``.

    ``weight`` must carry a bitsandbytes ``quant_state``. The result is cast to
    the dtype and device of ``x``.
    """
    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    return out.to(x)


def copy_quant_state(state, device=None):
    """Return a copy of a bitsandbytes ``QuantState`` with its tensors on ``device``.

    Args:
        state: The quant state to copy, or ``None``.
        device: Target device. Defaults to the device of ``state.absmax``.

    Returns:
        A new ``QuantState`` (including the nested ``state2`` when
        ``state.nested`` is set), or ``None`` if ``state`` is ``None``.
    """
    if state is None:
        return None

    device = device or state.absmax.device

    state2 = (
        QuantState(
            absmax=state.state2.absmax.to(device),
            shape=state.state2.shape,
            code=state.state2.code.to(device),
            blocksize=state.state2.blocksize,
            quant_type=state.state2.quant_type,
            dtype=state.state2.dtype,
        )
        if state.nested
        else None
    )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        # Move the quantization code table too, like state2.code above.
        # Otherwise it stays on the old device.
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )


class ForgeParams4bit(Params4bit):
    """``Params4bit`` whose ``.to()`` also moves its quant state and keeps its owning module in sync."""

    def to(self, *args, **kwargs):
        """Move or cast the parameter, parsing arguments like ``torch.Tensor.to``.

        If the target is a CUDA device and the data is not yet quantized, this
        quantizes it via ``Params4bit._quantize``. Otherwise it builds a new
        ``ForgeParams4bit`` on the target device/dtype with a copied quant state.
        ``self`` and ``self.module`` are then updated to share that data and
        quant state.
        """
        device, dtype, non_blocking, _convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)

        moved = ForgeParams4bit(
            torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, device),
            # Keep this parameter's own quantization settings. Hard-coding them
            # would make a later `_quantize` use the wrong block size and
            # statistics compression.
            compress_statistics=self.compress_statistics,
            blocksize=self.blocksize,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module,
        )
        self.module.quant_state = moved.quant_state
        self.data = moved.data
        self.quant_state = moved.quant_state
        return moved


class ForgeLoader4Bit(torch.nn.Module):
    """Module that materializes a 4-bit ``weight`` (and optional ``bias``) from a state dict.

    Until weights are loaded, a 1-element ``dummy`` parameter records the
    requested device and dtype. The first successful load deletes it.
    """

    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        # Placeholder that carries the target device/dtype until loading.
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save parameters plus the weight's packed quant state under ``<prefix>weight.<key>``."""
        super()._save_to_state_dict(destination, prefix, keep_vars)
        quant_state = getattr(self.weight, "quant_state", None)
        if quant_state is not None:
            for k, v in quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
        return

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Load the weight along one of three paths.

        1. Pre-quantized: ``<prefix>weight.*`` keys contain bitsandbytes
           quant-state entries. The weight is rebuilt with
           ``ForgeParams4bit.from_prequantized`` on CUDA.
        2. Raw weight while ``dummy`` still exists: the tensor is wrapped
           unquantized in ``ForgeParams4bit``. ``ForgeParams4bit.to``
           quantizes it when it is moved to a CUDA device.
        3. Otherwise: default ``nn.Module`` loading.
        """
        quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}

        if any("bitsandbytes" in k for k in quant_state_keys):
            quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}

            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[prefix + "weight"],
                quantized_stats=quant_state_dict,
                requires_grad=False,
                device=torch.device("cuda"),
                module=self,
            )
            self.quant_state = self.weight.quant_state

            if prefix + "bias" in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + "bias"].to(self.dummy))

            # NOTE(review): assumes `dummy` still exists, i.e. this is the first load.
            del self.dummy
        elif hasattr(self, "dummy"):
            if prefix + "weight" in state_dict:
                self.weight = ForgeParams4bit(
                    state_dict[prefix + "weight"].to(self.dummy),
                    requires_grad=False,
                    compress_statistics=True,
                    quant_type=self.quant_type,
                    quant_storage=torch.uint8,
                    module=self,
                )
                self.quant_state = self.weight.quant_state

            if prefix + "bias" in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + "bias"].to(self.dummy))

            del self.dummy
        else:
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


class Linear(ForgeLoader4Bit):
    """NF4 drop-in for ``torch.nn.Linear``; constructor arguments other than device/dtype are ignored."""

    def __init__(self, *args, device=None, dtype=None, **kwargs):
        super().__init__(device=device, dtype=dtype, quant_type="nf4")

    def forward(self, x):
        # ForgeParams4bit.to keeps `self.quant_state` current.
        # Re-attach it to the weight before the 4-bit matmul.
        self.weight.quant_state = self.quant_state
        if self.bias is not None and self.bias.dtype != x.dtype:
            # Cast the bias in place once so later calls skip the conversion.
            self.bias.data = self.bias.data.to(x.dtype)
        return functional_linear_4bits(x, self.weight, self.bias)


# Replace nn.Linear with the 4-bit quantized Linear
# torch.nn.Linear = Linear


class InitModel:
    """Loaders for the individual FLUX components, all in bfloat16."""

    @staticmethod
    def load_text_encoder() -> T5EncoderModel:
        """Load the pinned bf16 T5-XXL encoder and convert it to channels_last memory format."""
        print("Loading text encoder...")
        text_encoder = T5EncoderModel.from_pretrained(
            "city96/t5-v1_1-xxl-encoder-bf16",
            revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
            torch_dtype=torch.bfloat16,
        )
        return text_encoder.to(memory_format=torch.channels_last)

    @staticmethod
    def load_vae() -> AutoencoderTiny:
        """Load the pinned ``AutoencoderTiny`` VAE in bf16."""
        print("Loading VAE model...")
        vae = AutoencoderTiny.from_pretrained(
            "XiangquiAI/FLUX_Vae_Model",
            revision="103bcc03998f48ef311c100ee119f1b9942132ab",
            torch_dtype=torch.bfloat16,
        )
        return vae

    @staticmethod
    def load_transformer(trans_path: str) -> FluxTransformer2DModel:
        """Load the FLUX transformer from ``trans_path`` in bf16, channels_last.

        NOTE: ``use_safetensors=False`` permits pickle-based weight files; only
        point this at trusted snapshots.
        """
        print("Loading transformer model...")
        transformer = FluxTransformer2DModel.from_pretrained(
            trans_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )
        return transformer.to(memory_format=torch.channels_last)


def load_pipeline() -> Pipeline:
    """Build the FLUX.1-schnell pipeline on CUDA and warm it up.

    The transformer comes from a pre-downloaded snapshot in the Hugging Face
    hub cache. The bf16 T5 encoder and tiny VAE from ``InitModel`` replace the
    checkpoint's defaults. The pipeline is moved to CUDA, and then four 4-step
    1024x1024 generations run (presumably to absorb first-call setup cost).

    Returns:
        The ready-to-use diffusers pipeline.
    """
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db",
    )
    transformer = InitModel.load_transformer(transformer_path)
    text_encoder_2 = InitModel.load_text_encoder()
    vae = InitModel.load_vae()

    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    try:
        pipeline.enable_vae_slicing()
        # NOTE(review): no code in this file reads `torch.nn.LinearLayer`, so
        # this assignment does not change which layers the pipeline uses.
        torch.nn.LinearLayer = Linear
    except Exception:
        # Best effort: keep the pipeline unchanged. Unlike the previous bare
        # `except:`, this no longer swallows KeyboardInterrupt/SystemExit.
        print("Using origin pipeline")

    warmup_prompts = [
        "melanogen, tiptilt",
        "melanogen, endosome, apical, polymyodous, ",
        "buffer, cutie, buttinsky, prototrophic",
        "puzzlehead",
    ]
    for prompt in warmup_prompts:
        pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate the first image for ``request``, seeded with ``request.seed``.

    Uses 4 inference steps, guidance scale 0.0 and a maximum sequence length
    of 256 tokens, at ``request.height`` x ``request.width``.
    """
    # Free cached allocator memory and reset peak-memory stats before each request.
    remove_cache()
    generator = Generator(pipeline.device).manual_seed(request.seed)
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]