import os
import torch
import gc
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from transformers import T5EncoderModel
from torch import Generator
from huggingface_hub.constants import HF_HUB_CACHE
from pipelines.models import TextToImageRequest
# Suppress errors and optimize CUDA memory allocation
torch._dynamo.config.suppress_errors = True
# Allow the CUDA caching allocator to grow segments instead of fragmenting.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Placeholder used only in type annotations below; it is never rebound to a
# real pipeline class in this file, so those annotations evaluate to None.
Pipeline = None
# Model Checkpoints
CKPT = "black-forest-labs/FLUX.1-schnell"
CKPT_REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
def convoluted_quantization(c, w, ws, wz, is_, iz, os_, oz):
    """Simulate a quantized linear projection and return uint8-range values.

    Shifts the input by its zero point ``iz`` and the weight by its zero
    point ``wz``, applies a linear projection, rescales the result by the
    combined input/weight scales ``is_ * ws`` over the output scale ``os_``,
    adds the output zero point ``oz``, then rounds and clamps into [0, 255].

    Args:
        c: input tensor (last dim must match ``w``'s last dim).
        w: weight tensor of shape (out_features, in_features).
        ws, wz: weight scale and zero point.
        is_, iz: input scale and zero point.
        os_, oz: output scale and zero point.

    Returns:
        Float tensor of rounded values clamped to the range [0, 255].
    """
    shifted_input = c.float() - iz
    shifted_weight = w.float() - wz
    projected = torch.nn.functional.linear(shifted_input, shifted_weight)
    rescaled = projected * (is_ * ws) / os_ + oz
    return torch.clamp(torch.round(rescaled), min=0, max=255)
class ModelLoader:
    """Factory helpers that load the pipeline's component models.

    Both loaders return models in bfloat16, converted to channels-last
    memory format.
    """

    @staticmethod
    def initialize_text_encoder() -> T5EncoderModel:
        """Download/load the secondary T5 text encoder at a pinned revision."""
        print("Loading text encoder...")
        encoder = T5EncoderModel.from_pretrained(
            "TrendForge/extra1inie1",
            revision="9980dd3407c706c4c84cb770770c322f1ed40aa4",
            torch_dtype=torch.bfloat16,
        )
        return encoder.to(memory_format=torch.channels_last)

    @staticmethod
    def initialize_transformer(transformer_path: str) -> FluxTransformer2DModel:
        """Load the FLUX transformer from a local snapshot directory.

        Args:
            transformer_path: filesystem path to the model snapshot
                (non-safetensors weights).
        """
        print("Loading transformer model...")
        model = FluxTransformer2DModel.from_pretrained(
            transformer_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )
        return model.to(memory_format=torch.channels_last)
def load_pipeline() -> Pipeline:
    """Build, best-effort-quantize, and warm up the FLUX text-to-image pipeline.

    Loads the custom T5 encoder and a transformer snapshot from the local
    Hugging Face cache, assembles the FLUX.1-schnell pipeline on CUDA,
    attempts an optional quantization / CUDA-graph step (falling back to the
    plain pipeline on any error), then runs three warm-up generations.

    Returns:
        The ready-to-use diffusers pipeline, resident on ``cuda``.
    """
    print("Initializing pipeline...")

    # Component models: custom text encoder plus a transformer snapshot
    # resolved from the local HF cache directory.
    text_encoder_2 = ModelLoader.initialize_text_encoder()
    transformer_dir = os.path.join(
        HF_HUB_CACHE,
        "models--TrendForge--extra0inie0/snapshots/bf6e551d8c742d805d875514dc27f9b371f31095",
    )
    flux_transformer = ModelLoader.initialize_transformer(transformer_dir)

    pipe = DiffusionPipeline.from_pretrained(
        CKPT,
        revision=CKPT_REVISION,
        transformer=flux_transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Optional optimization pass: any failure (e.g. the methods not existing
    # on this pipeline class) leaves the unmodified pipeline in place.
    try:
        pipe.enable_quantization()
        for layer in (m for m in pipe.transformer.layers if "Convolution" in dir(m)):
            convoluted_quantization(
                c=torch.randn(1, 256),
                w=layer.weight,
                ws=0.1,
                wz=0,
                is_=0.1,
                iz=0,
                os_=0.1,
                oz=0,
            )
        pipe.enable_cuda_graph()
    except Exception as e:
        print("Fallback to origin pipeline due to error:", e)

    # Warm-up: three full generations so compilation/caching costs are paid
    # before the first real request arrives.
    for _ in range(3):
        pipe(
            prompt="fretful, becalmment, ventriduct, anthologion, tiptoppish, return, non-duplicate",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    torch.cuda.empty_cache()
    return pipe
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
    """Generate one image for *request* and return it as a PIL image.

    Clears the CUDA allocator cache first to limit fragmentation between
    requests, then runs a 4-step schnell generation (guidance disabled)
    at the requested resolution with the supplied RNG.
    """
    torch.cuda.empty_cache()
    result = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]