| import os |
| import torch |
| import torch._dynamo |
| import gc |
| import bitsandbytes as bnb |
| from bitsandbytes.nn.modules import Params4bit, QuantState |
| import json |
| import transformers |
| from huggingface_hub.constants import HF_HUB_CACHE |
| from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel |
|
|
| from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only |
| from torch import Generator |
| from diffusers import FluxTransformer2DModel, DiffusionPipeline |
|
|
| from PIL.Image import Image |
| from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny |
| from pipelines.models import TextToImageRequest |
| import json |
|
|
|
|
|
|
|
|
# Let torch.compile / dynamo fall back to eager execution instead of raising
# when a graph fails to compile.
torch._dynamo.config.suppress_errors = True

# Process-wide environment settings. The allocator option only has an effect
# if it is in place before the CUDA caching allocator is first used (the
# pipeline is moved to CUDA later, inside load_pipeline).
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)

# Base FLUX.1-schnell repository, pinned to an exact commit.
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"

# Placeholder used as the pipeline annotation of load_pipeline / infer.
Pipeline = None
|
|
|
|
| import torch |
| import math |
| from typing import Dict, Any |
|
|
def remove_cache():
    """Free unreferenced Python objects and cached CUDA memory, then reset peak stats.

    A full garbage collection runs first, so tensors kept alive only by
    reference cycles are released before the CUDA caching allocator hands
    its unused blocks back.

    On hosts without CUDA only the garbage collection is performed, because
    the CUDA statistics calls raise when no CUDA device is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    # ``reset_max_memory_allocated`` is deprecated. It merely forwards to
    # ``reset_peak_memory_stats`` and emits a FutureWarning on every call, so
    # this single call resets all peak-memory statistics.
    torch.cuda.reset_peak_memory_stats()
|
|
| |
def functional_linear_4bits(x, weight, bias):
    """Apply a linear projection whose weight is a bitsandbytes 4-bit parameter.

    Args:
        x: Input activations.
        weight: Packed 4-bit weight; its ``quant_state`` drives dequantization.
        bias: Bias forwarded to the matmul (may be None).

    Returns:
        The projection, converted to the dtype and device of ``x``.
    """
    projected = bnb.matmul_4bit(
        x,
        weight.t(),
        bias=bias,
        quant_state=weight.quant_state,
    )
    # Tensor.to(other) matches both the dtype and the device of ``x``.
    return projected.to(x)
|
|
|
|
def copy_quant_state(state, device=None):
    """Return a copy of a bitsandbytes ``QuantState`` with its tensors on ``device``.

    Args:
        state: Quantization state to copy, or None.
        device: Target device. Defaults to the device of ``state.absmax``.

    Returns:
        A new ``QuantState``, or None when ``state`` is None. When
        ``state.nested`` is set, the nested ``state2`` is copied as well.
    """
    if state is None:
        return None

    device = device or state.absmax.device

    nested_copy = None
    if state.nested:
        inner = state.state2
        nested_copy = QuantState(
            absmax=inner.absmax.to(device),
            shape=inner.shape,
            code=inner.code.to(device),
            blocksize=inner.blocksize,
            quant_type=inner.quant_type,
            dtype=inner.dtype,
        )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        # Move the code book together with absmax (as done for the nested
        # state above). Otherwise it stays on the source device after a move.
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=nested_copy,
    )
|
|
|
|
class ForgeParams4bit(Params4bit):
    """``Params4bit`` that keeps its owning module's ``quant_state`` in sync on ``.to()``."""

    def to(self, *args, **kwargs):
        """Move or convert the parameter, quantizing on the first transfer to CUDA.

        An unquantized parameter moved to a CUDA device is quantized via
        ``_quantize``. Otherwise a new ``ForgeParams4bit`` is built on the
        requested device/dtype with a device-matched copy of the quantization
        state. This parameter (and its owning module, if any) is then updated
        to share the new data and state.

        Returns:
            The moved/quantized parameter.
        """
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)

        moved = ForgeParams4bit(
            torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, device),
            # Carry over this parameter's own quantization settings instead of
            # fixed values. That way a later ``_quantize`` (e.g. after moving an
            # unquantized weight through the CPU) uses the same configuration.
            compress_statistics=self.compress_statistics,
            blocksize=self.blocksize,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module,
        )
        # ``module`` may be None for parameters not created by a loader module.
        if self.module is not None:
            self.module.quant_state = moved.quant_state
        self.data = moved.data
        self.quant_state = moved.quant_state
        return moved
|
|
|
|
class ForgeLoader4Bit(torch.nn.Module):
    """Module that materializes a 4-bit ``ForgeParams4bit`` weight from a state dict.

    Until a state dict is loaded, the module holds only a one-element ``dummy``
    parameter. Its device and dtype are the target for incoming tensors, and
    ``dummy`` is deleted once the weight/bias have been taken from a state dict.
    """

    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        # Placeholder recording the target device/dtype for loaded tensors.
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save parameters, plus the packed quantization state under ``<prefix>weight.*``."""
        super()._save_to_state_dict(destination, prefix, keep_vars)
        packed_state = getattr(self.weight, "quant_state", None)
        if packed_state is None:
            return
        weight_prefix = f"{prefix}weight."
        for name, tensor in packed_state.as_dict(packed=True).items():
            destination[weight_prefix + name] = tensor if keep_vars else tensor.detach()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Build the 4-bit weight (and bias) from ``state_dict``.

        Dispatch:
        * Some key under ``<prefix>weight.`` contains ``bitsandbytes``: the
          weight is rebuilt with ``ForgeParams4bit.from_prequantized`` on CUDA.
        * Otherwise, while ``dummy`` still exists: the raw weight (cast to
          ``dummy``'s device/dtype) is wrapped in a ``ForgeParams4bit``.
        * Otherwise: the default ``nn.Module`` loading is used.

        In the first two cases the bias, if present, is cast like ``dummy``,
        and ``dummy`` is deleted afterwards.
        """
        weight_key = prefix + "weight"
        bias_key = prefix + "bias"
        stats_prefix = weight_key + "."
        suffixes = {key[len(stats_prefix):] for key in state_dict.keys() if key.startswith(stats_prefix)}
        prequantized = any("bitsandbytes" in suffix for suffix in suffixes)

        if not prequantized and not hasattr(self, "dummy"):
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
            return

        if prequantized:
            stats = {suffix: state_dict[stats_prefix + suffix] for suffix in suffixes}
            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[weight_key],
                quantized_stats=stats,
                requires_grad=False,
                device=torch.device("cuda"),
                module=self,
            )
            self.quant_state = self.weight.quant_state
        elif weight_key in state_dict:
            self.weight = ForgeParams4bit(
                state_dict[weight_key].to(self.dummy),
                requires_grad=False,
                compress_statistics=True,
                quant_type=self.quant_type,
                quant_storage=torch.uint8,
                module=self,
            )
            self.quant_state = self.weight.quant_state

        if bias_key in state_dict:
            self.bias = torch.nn.Parameter(state_dict[bias_key].to(self.dummy))

        del self.dummy
|
|
|
|
class Linear(ForgeLoader4Bit):
    """Linear-style layer backed by an NF4-quantized 4-bit weight.

    Positional and extra keyword arguments (e.g. in/out features) are
    accepted for signature compatibility but ignored. The weight and bias
    are only populated when a state dict is loaded.
    """

    def __init__(self, *args, device=None, dtype=None, **kwargs):
        super().__init__(device=device, dtype=dtype, quant_type="nf4")

    def forward(self, x):
        # Re-attach the module-level quantization state to the weight.
        self.weight.quant_state = self.quant_state

        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            # Cast the stored bias in place so later calls skip the conversion.
            bias.data = bias.data.to(x.dtype)

        return functional_linear_4bits(x, self.weight, bias)
|
|
|
|
| |
| |
|
|
class InitModel:
    """Loaders for the FLUX components that replace the checkpoint's defaults."""

    @staticmethod
    def load_text_encoder() -> T5EncoderModel:
        """Load the pinned bfloat16 T5 encoder (used as ``text_encoder_2``)."""
        print("Loading text encoder...")
        repo_id = "city96/t5-v1_1-xxl-encoder-bf16"
        pinned_revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86"
        encoder = T5EncoderModel.from_pretrained(
            repo_id,
            revision=pinned_revision,
            torch_dtype=torch.bfloat16,
        )
        # channels_last only changes the layout of 4-D tensors.
        return encoder.to(memory_format=torch.channels_last)

    @staticmethod
    def load_vae() -> AutoencoderTiny:
        """Load the pinned ``AutoencoderTiny`` VAE in bfloat16."""
        print("Loading VAE model...")
        repo_id = "XiangquiAI/FLUX_Vae_Model"
        pinned_revision = "103bcc03998f48ef311c100ee119f1b9942132ab"
        return AutoencoderTiny.from_pretrained(
            repo_id,
            revision=pinned_revision,
            torch_dtype=torch.bfloat16,
        )

    @staticmethod
    def load_transformer(trans_path: str) -> FluxTransformer2DModel:
        """Load the FLUX transformer in bfloat16 from the local snapshot at ``trans_path``."""
        print("Loading transformer model...")
        # NOTE(review): use_safetensors=False reads non-safetensors weight
        # files; only point this at trusted snapshots.
        model = FluxTransformer2DModel.from_pretrained(
            trans_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )
        return model.to(memory_format=torch.channels_last)
|
|
|
|
|
|
def load_pipeline() -> Pipeline:
    """Build the FLUX.1-schnell pipeline on CUDA and warm it up.

    Component sources:
    * transformer: a locally cached Hub snapshot;
    * T5 text encoder and tiny VAE: their pinned Hub revisions;
    * everything else: ``CHECKPOINT`` at ``REVISION``.

    A few throw-away generations are then run so one-time initialization
    cost is paid before the first real request.

    Returns:
        The pipeline, already moved to ``"cuda"``.
    """
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--Flux_Transformer_float8/snapshots/66c5f182385555a00ec90272ab711bb6d3c197db",
    )
    transformer = InitModel.load_transformer(transformer_path)
    text_encoder_2 = InitModel.load_text_encoder()
    vae = InitModel.load_vae()

    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    try:
        pipeline.enable_vae_slicing()
    except Exception:
        # Best effort: VAE slicing is an optional memory saving, so keep going
        # with the unmodified pipeline when it is unsupported. Catch Exception
        # rather than everything, so KeyboardInterrupt/SystemExit propagate.
        print("Using origin pipeline")
    else:
        # NOTE(review): torch.nn has no ``LinearLayer`` attribute, so this
        # only adds a new name and does not replace ``torch.nn.Linear``.
        # Kept for compatibility; confirm whether anything reads it.
        torch.nn.LinearLayer = Linear

    warmup_prompts = [
        "melanogen, tiptilt",
        "melanogen, endosome, apical, polymyodous, ",
        "buffer, cutie, buttinsky, prototrophic",
        "puzzlehead",
    ]
    # Warm-up with the same settings infer uses. The loop variable is now the
    # prompt actually passed in (previously an undefined name ``p``).
    for prompt in warmup_prompts:
        pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline
|
|
|
|
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request`` with ``pipeline``.

    Cached memory is released first via ``remove_cache``. Sampling is seeded
    from ``request.seed`` on the pipeline's device and uses a fixed 4-step,
    ``guidance_scale=0.0`` schedule with ``max_sequence_length=256``.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: The pipeline to run.

    Returns:
        The first image of the pipeline output.
    """
    remove_cache()

    # NOTE(review): manual_seed needs an int; confirm request.seed is never None.
    seeded_generator = Generator(pipeline.device)
    seeded_generator = seeded_generator.manual_seed(request.seed)

    result = pipeline(
        request.prompt,
        generator=seeded_generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    return result.images[0]