| from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny |
| from diffusers.image_processor import VaeImageProcessor |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
| from huggingface_hub.constants import HF_HUB_CACHE |
| from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel |
| import torch |
| import torch._dynamo |
| import gc |
| from PIL import Image as img |
| from PIL.Image import Image |
| from pipelines.models import TextToImageRequest |
| from torch import Generator |
| import time |
| from diffusers import DiffusionPipeline |
| from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only |
| import os |
| os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True" |
|
|
| import torch |
| import math |
| from typing import Type, Dict, Any, Tuple, Callable, Optional, Union |
| import ghanta |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin |
| from diffusers.models.attention import FeedForward |
| from diffusers.models.attention_processor import ( |
| Attention, |
| AttentionProcessor, |
| FluxAttnProcessor2_0, |
| FusedFluxAttnProcessor2_0, |
| ) |
| from diffusers.models.modeling_utils import ModelMixin |
| from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle |
| from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers |
| from diffusers.utils.import_utils import is_torch_npu_available |
| from diffusers.utils.torch_utils import maybe_allow_in_graph |
| from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput |
| from transformers import T5EncoderModel |
| from diffusers.loaders.single_file_model import FromOriginalModelMixin |
| from diffusers.quantizers import DiffusersAutoQuantizer |
| from diffusers.models.modeling_utils import load_model_dict_into_meta |
| from diffusers.utils import deprecate, is_accelerate_available, logging |
| |
|
|
|
|
| from accelerate import init_empty_weights |
| from accelerate import infer_auto_device_map |
| from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device |
|
|
class CustomT5EncoderModel(T5EncoderModel, FromOriginalModelMixin):
    """T5 text encoder combined with diffusers' ``FromOriginalModelMixin``.

    The mixin adds single-file ("original format") checkpoint loading support
    on top of the standard transformers ``T5EncoderModel``.  No behavior is
    overridden; the class only merges the two bases.
    """

    pass
|
|
def load_single_file_checkpoint(
    pretrained_model_link_or_path,
    force_download=False,
    proxies=None,
    token=None,
    cache_dir=None,
    local_files_only=None,
    revision=None,
):
    """Load a single-file checkpoint state dict from a local path or hub link.

    Args:
        pretrained_model_link_or_path: Local file path, or a hub link/repo id
            that is resolved (and downloaded if needed) to a local file.
        force_download / proxies / token / cache_dir / local_files_only /
        revision: Forwarded to the hub file resolver for non-local inputs.

    Returns:
        The checkpoint's flat weight dict, with any (possibly nested)
        ``"state_dict"`` wrappers unwrapped.
    """
    # Helper functions live in diffusers' single-file utilities; imported
    # locally so the dependency is only touched when this path runs.
    from diffusers.loaders.single_file_utils import (
        _extract_repo_id_and_weights_name,
        _get_model_file,
        load_state_dict,
    )

    if not os.path.isfile(pretrained_model_link_or_path):
        # Not a local file: resolve the hub link to a cached local file.
        repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
        pretrained_model_link_or_path = _get_model_file(
            repo_id,
            weights_name=weights_name,
            force_download=force_download,
            cache_dir=cache_dir,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
        )

    checkpoint = load_state_dict(pretrained_model_link_or_path)

    # Some checkpoints nest the weights under repeated "state_dict" keys;
    # unwrap until we reach the actual tensor dict.
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    return checkpoint
|
|
|
|
def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
    """Convert an SD3-style T5 checkpoint to diffusers key naming.

    Keys carrying the ``text_encoders.t5xxl.transformer.`` prefix are kept
    with the prefix stripped; all other keys are dropped.

    Args:
        checkpoint: Mapping of original checkpoint keys to tensors.

    Returns:
        dict: Renamed subset of ``checkpoint`` suitable for a diffusers model.
    """
    remove_prefixes = ["text_encoders.t5xxl.transformer."]

    converted = {}
    for original_key, value in checkpoint.items():
        for prefix in remove_prefixes:
            if original_key.startswith(prefix):
                converted[original_key.replace(prefix, "")] = value

    return converted
|
|
def load_model_dict_into_meta(
    model,
    state_dict,
    device=None,
    dtype=None,
    model_name_or_path=None,
    hf_quantizer=None,
    keep_in_fp32_modules=None,
):
    """Copy ``state_dict`` parameters onto a (possibly meta-initialized) model.

    Args:
        model: Target ``nn.Module`` whose parameters are filled in place.
        state_dict: Source mapping of parameter names to tensors.
        device: Destination device (str or ``torch.device``); defaults to CPU
            when no quantizer is given.
        dtype: Cast target for floating-point params; defaults to float32
            when no quantizer is given.
        model_name_or_path: Optional identifier used in error messages.
        hf_quantizer: Optional quantizer handling pre-quantized params.
        keep_in_fp32_modules: Module-name fragments whose params stay fp32
            even when ``dtype`` is float16.

    Returns:
        list[str]: Keys present in ``state_dict`` but missing from the model.

    Raises:
        ValueError: On an invalid ``device`` type, or on a shape mismatch
            that the quantizer cannot account for.
    """
    import inspect  # needed for the set_module_tensor_to_device signature probe

    if device is not None and not isinstance(device, (str, torch.device)):
        raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
    if hf_quantizer is None:
        device = device or torch.device("cpu")
        dtype = dtype or torch.float32
    is_quantized = hf_quantizer is not None

    # Older accelerate releases do not accept a `dtype` kwarg.
    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
    empty_state_dict = model.state_dict()
    unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]

    for param_name, param in state_dict.items():
        if param_name not in empty_state_dict:
            continue

        set_module_kwargs = {}

        if torch.is_floating_point(param):
            if (
                keep_in_fp32_modules is not None
                and any(
                    module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
                )
                and dtype == torch.float16
            ):
                # Numerically sensitive modules stay in fp32 under fp16 loads.
                param = param.to(torch.float32)
                if accepts_dtype:
                    set_module_kwargs["dtype"] = torch.float32
            else:
                param = param.to(dtype)
                if accepts_dtype:
                    set_module_kwargs["dtype"] = dtype

        if empty_state_dict[param_name].shape != param.shape:
            if (
                is_quantized
                and hf_quantizer.pre_quantized
                and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
            ):
                # Quantized storage may legitimately differ in shape; let the
                # quantizer validate it instead of failing outright.
                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
            else:
                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                raise ValueError(
                    f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                )

        if is_quantized and (
            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
        ):
            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
        else:
            if accepts_dtype:
                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
            else:
                set_module_tensor_to_device(model, param_name, device, value=param)

    return unexpected_keys
|
|
def create_diffusers_t5_model_from_checkpoint_gguf(
    cls,
    checkpoint,
    subfolder="",
    config=None,
    torch_dtype=None,
    local_files_only=None,
    quantization_config=None,
    device=None,
    force_download=False,
    **kwargs
):
    """Build a T5 encoder of type ``cls`` from a single-file/GGUF checkpoint.

    Args:
        cls: Model class to instantiate (e.g. ``T5EncoderModel``).
        checkpoint: Either an already-loaded state dict or a path/link passed
            to :func:`load_single_file_checkpoint`.
        subfolder: Config subfolder inside the pretrained repo.
        config: Repo path/id (str) or kwargs dict for ``config_class.from_pretrained``;
            fetched from the checkpoint when omitted.
        torch_dtype: Final dtype for the model (skipped when quantized).
        local_files_only: Restrict hub access to the local cache.
        quantization_config: Optional diffusers quantization config.
        device: Device the weights are materialized on (CPU by default).
        force_download: Force re-download for non-local checkpoints.
        **kwargs: May carry ``proxies``/``token``/``cache_dir``/``revision``
            for checkpoint resolution.

    Returns:
        The instantiated model with weights loaded.
    """
    import re  # for filtering keys against _keys_to_ignore_on_load_unexpected
    from contextlib import nullcontext

    # `logging` at module level is diffusers.utils.logging.
    logger = logging.get_logger(__name__)

    proxies = kwargs.pop("proxies", None)
    token = kwargs.pop("token", None)
    cache_dir = kwargs.pop("cache_dir", None)
    revision = kwargs.pop("revision", None)

    # Normalize `config` into the kwargs dict expected by from_pretrained.
    if config:
        if isinstance(config, str):
            config = {"pretrained_model_name_or_path": config}
    else:
        # Infer the matching diffusers config from the checkpoint contents.
        from diffusers.loaders.single_file_utils import fetch_diffusers_config
        config = fetch_diffusers_config(checkpoint)

    model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)

    # Instantiate on the meta device when accelerate is available so no real
    # memory is allocated before the weights arrive.
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        model = cls(model_config)

    if not isinstance(checkpoint, dict):
        checkpoint = load_single_file_checkpoint(
            checkpoint,
            force_download=force_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
        )

    if quantization_config is not None:
        hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
        hf_quantizer.validate_environment()
    else:
        hf_quantizer = None

    diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)

    # fp32-pinning applies under fp16 loads or when the quantizer asks for it.
    use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
        (torch_dtype == torch.float16) or (hf_quantizer is not None and hasattr(hf_quantizer, "use_keep_in_fp32_modules"))
    )
    if use_keep_in_fp32_modules:
        keep_in_fp32_modules = cls._keep_in_fp32_modules
        if not isinstance(keep_in_fp32_modules, list):
            keep_in_fp32_modules = [keep_in_fp32_modules]
    else:
        keep_in_fp32_modules = []

    if hf_quantizer is not None:
        hf_quantizer.preprocess_model(
            model=model,
            device_map=None,
            state_dict=diffusers_format_checkpoint,
            keep_in_fp32_modules=keep_in_fp32_modules,
        )

    # Materialize meta tensors as (uninitialized) real storage on `device`.
    model = model.to_empty(device=device)

    if is_accelerate_available():
        param_device = torch.device(device) if device is not None else torch.device("cpu")
        unexpected_keys = load_model_dict_into_meta(
            model,
            diffusers_format_checkpoint,
            dtype=torch_dtype,
            device=param_device,
            hf_quantizer=hf_quantizer,
            keep_in_fp32_modules=keep_in_fp32_modules,
        )
        if model._keys_to_ignore_on_load_unexpected is not None:
            for pat in model._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
            )
    else:
        # No accelerate: rely on the standard (stricter) state-dict loader.
        model.load_state_dict(diffusers_format_checkpoint)

    if hf_quantizer is not None:
        hf_quantizer.postprocess_model(model)
        model.hf_quantizer = hf_quantizer

    if torch_dtype is not None and hf_quantizer is None:
        model.to(torch_dtype)

    return model
|
|
|
|
|
|
# Placeholder alias used only in annotations below; it is not a real type.
Pipeline = None
# Enable TF32 matmuls and cuDNN autotuning for faster GPU inference.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


# Base FLUX.1-schnell repo and the pinned snapshot revision to load.
ckpt_id = "black-forest-labs/FLUX.1-schnell"
ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
def empty_cache():
    """Run garbage collection and reset CUDA memory state.

    Safe on CPU-only machines: the CUDA-specific cleanup is skipped when no
    CUDA device is available (the original unconditionally called CUDA APIs,
    which raise without a GPU).
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias of
        # reset_peak_memory_stats(), so a single call covers both.
        torch.cuda.reset_peak_memory_stats()
|
|
def load_pipeline() -> Pipeline:
    """Build, warm up, and return the FLUX.1-schnell text-to-image pipeline.

    Loads a GGUF-quantized T5-XXL text encoder from the local HF hub cache,
    wires it into the FLUX pipeline at bfloat16, runs a few warm-up
    generations so CUDA kernels are compiled and cached, and returns the
    ready-to-serve pipeline.

    Requires a CUDA device and the referenced snapshots present in
    ``HF_HUB_CACHE`` (loading is local-files-only).
    """
    empty_cache()

    from diffusers.loaders.single_file_utils import create_diffusers_t5_model_from_checkpoint

    # Locally cached GGUF T5 encoder weights and the pinned FLUX base snapshot.
    t5_path = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--t5-v1_1-xxl-encoder-q8/snapshots/59c6c9cb99dcea42067f32caac3ea0836ef4c548/t5-v1_1-xxl-encoder-Q8_0.gguf",
    )
    config_path = os.path.join(
        HF_HUB_CACHE,
        "models--black-forest-labs--FLUX.1-schnell/snapshots/741f7c3ce8b383c54771c7003378a50191e9efe9/",
    )

    ckpt_t5 = load_single_file_checkpoint(t5_path, local_files_only=True)
    t5 = create_diffusers_t5_model_from_checkpoint(cls=T5EncoderModel, checkpoint=ckpt_t5, config=config_path)

    # Sanity check: all weights must be materialized before moving to CUDA.
    if any(param.is_meta for param in t5.parameters()):
        print("Model is still on the meta device!")
    else:
        print("Model weights are loaded onto a real device!")

    t5 = t5.to("cuda")

    pipeline = FluxPipeline.from_pretrained(
        config_path,
        text_encoder_2=t5,
        generator=torch.manual_seed(0),
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Warm-up passes so kernel compilation/caching doesn't tax the first
    # real request (prompt content is irrelevant).
    for _ in range(3):
        pipeline(prompt="blah blah waah waah oneshot oneshot gang gang", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)

    empty_cache()
    return pipeline
|
|
|
|
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
    """Generate one image for `request` with the given pipeline and RNG.

    Runs 4 schnell steps at guidance scale 0 and returns the first PIL image.
    """
    result = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]