from PIL.Image import Image
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline, AutoencoderTiny
from pipelines.models import TextToImageRequest
import os
import torch
import torch._dynamo

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
torch._dynamo.config.suppress_errors = True

Pipeline = None
basePT = "forswearer, skullcap, Juglandales, bluelegs, cunila, carbro, Ammonites"


class Quantization:
    def __init__(self, model):
        self.model = model
        # Per-parameter (num_bins, scale_factor) settings, keyed by the full
        # parameter name as reported by model.named_parameters().
        self.layer_configs = {
            "single_transformer_blocks.0.attn.norm_k.weight": (128, 0.96),
            "single_transformer_blocks.0.attn.norm_q.weight": (128, 0.96),
            "single_transformer_blocks.0.attn.norm_v.weight": (128, 0.96),
        }

    def apply(self):
        for name, param in self.model.named_parameters():
            # Match on the full parameter name; the config keys are fully
            # qualified, so a prefix comparison would never match.
            if param.requires_grad and name in self.layer_configs:
                num_bins, scale_factor = self.layer_configs[name]
                with torch.no_grad():
                    # Normalize weights to [0, 1], snap them to num_bins evenly
                    # spaced levels, then rescale back to the original range.
                    param_min = param.min()
                    param_max = param.max()
                    param_range = param_max - param_min
                    if param_range > 0:
                        normalized = (param - param_min) / param_range
                        binned = torch.round(normalized * (num_bins - 1)) / (num_bins - 1)
                        rescaled = binned * param_range + param_min
                        param.data.copy_(rescaled * scale_factor)
                    else:
                        # Constant weight tensor: nothing to bin, zero it out.
                        param.data.zero_()
        return self.model


def load_pipeline() -> Pipeline:
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "db900/neural-lattice",
        revision="31581dabff21433df68d22d5539d07de6a87380a",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)
    vae = AutoencoderTiny.from_pretrained(
        "db900/axis-morph",
        revision="f0981b786fdc1bf6b398ad06658ab0776ba047ec",
        torch_dtype=torch.bfloat16,
    )
    default = FluxTransformer2DModel.from_pretrained(
        os.path.join(
            HF_HUB_CACHE,
            "models--db900--trans-flux/snapshots/2632cc4202aa3e7f459031cc45804e3693da6722",
        ),
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    # Fall back to the unmodified transformer if quantization fails.
    try:
        transformer = Quantization(default).apply()
    except Exception:
        transformer = default

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        revision="741f7c3ce8b383c54771c7003378a50191e9efe9",
        vae=vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # Warm-up passes: trigger CUDA kernel compilation and caching before the
    # pipeline serves real requests.
    for _ in range(3):
        pipeline(
            prompt=basePT,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    # Fall back to the warm-up prompt if the request has no usable prompt.
    try:
        prompt = request.prompt
    except Exception:
        prompt = basePT

    return pipeline(
        prompt,
        generator=Generator(pipeline.device).manual_seed(request.seed),
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
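

# --- Usage sketch (illustrative, not part of the serving contract) ---
# A minimal sketch of how load_pipeline() and infer() fit together, assuming
# this module is normally driven by an external harness. The
# TextToImageRequest constructor call below is an assumption inferred from
# the fields infer() reads (prompt, seed, height, width); the actual model
# in pipelines.models may differ.
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(  # hypothetical construction
        prompt="a watercolor fox in a snowy forest",
        seed=0,
        width=1024,
        height=1024,
    )
    image = infer(request, pipe)
    image.save("output.png")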