# derbim7 / src / pipeline.py
# TrendForge — initial commit with folder contents (revision 01dbf80, verified)
import os
import torch
import gc
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from transformers import T5EncoderModel
from torch import Generator
from huggingface_hub.constants import HF_HUB_CACHE
from pipelines.models import TextToImageRequest
# Suppress errors and optimize CUDA memory allocation
# torch.compile/dynamo failures fall back to eager execution instead of raising.
torch._dynamo.config.suppress_errors = True
# Let the CUDA caching allocator grow segments in place to reduce fragmentation.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
# Enable thread parallelism in the HF fast tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Placeholder used only in return-type annotations below; None at runtime.
Pipeline = None
# Model Checkpoints
CKPT = "black-forest-labs/FLUX.1-schnell"  # base pipeline repo on the HF hub
CKPT_REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"  # pinned commit for reproducibility
def convoluted_quantization(c, w, ws, wz, is_, iz, os_, oz):
    """Simulate an affine-quantized linear layer in float arithmetic.

    Shifts input ``c`` and weight ``w`` by their zero points (``iz``, ``wz``),
    applies a linear projection, rescales by the combined input/weight scales
    (``is_ * ws``) over the output scale ``os_``, adds the output zero point
    ``oz``, rounds, and clamps to the uint8 range [0, 255].
    """
    shifted_input = c.float() - iz
    shifted_weight = w.float() - wz
    projected = torch.nn.functional.linear(shifted_input, shifted_weight)
    rescaled = projected * (is_ * ws) / os_ + oz
    return torch.clamp(torch.round(rescaled), min=0, max=255)
class ModelLoader:
    """Static factory helpers that load the pipeline's sub-models."""

    @staticmethod
    def initialize_text_encoder() -> T5EncoderModel:
        """Load the pinned T5 text encoder in bfloat16, channels-last layout."""
        print("Loading text encoder...")
        encoder = T5EncoderModel.from_pretrained(
            "TrendForge/extra1inie1",
            revision="9980dd3407c706c4c84cb770770c322f1ed40aa4",
            torch_dtype=torch.bfloat16,
        )
        return encoder.to(memory_format=torch.channels_last)

    @staticmethod
    def initialize_transformer(transformer_path: str) -> FluxTransformer2DModel:
        """Load the FLUX transformer from *transformer_path* in bfloat16, channels-last layout."""
        print("Loading transformer model...")
        model = FluxTransformer2DModel.from_pretrained(
            transformer_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )
        return model.to(memory_format=torch.channels_last)
def load_pipeline() -> Pipeline:
    """Build and warm up the FLUX.1-schnell text-to-image pipeline.

    Loads a fine-tuned T5 text encoder and a transformer from a pinned
    local HF-cache snapshot, assembles a DiffusionPipeline on CUDA,
    attempts a best-effort optimization pass, then runs three warm-up
    generations so subsequent calls hit steady-state performance.

    Returns:
        The CUDA-resident DiffusionPipeline, ready for inference.
    """
    print("Initializing pipeline...")
    encoder_2 = ModelLoader.initialize_text_encoder()
    # Pinned snapshot path inside the local HF hub cache; assumes the model
    # was pre-downloaded there — TODO confirm the snapshot exists at deploy.
    trans_path = os.path.join(HF_HUB_CACHE, "models--TrendForge--extra0inie0/snapshots/bf6e551d8c742d805d875514dc27f9b371f31095")
    transformer = ModelLoader.initialize_transformer(trans_path)
    flux_pipeline = DiffusionPipeline.from_pretrained(
        CKPT,
        revision=CKPT_REVISION,
        transformer=transformer,
        text_encoder_2=encoder_2,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    try:
        # Best-effort optimization. NOTE(review): enable_quantization() and
        # enable_cuda_graph() are not standard DiffusionPipeline methods, so
        # this block likely raises AttributeError and falls through to the
        # except below — verify against the installed diffusers version.
        flux_pipeline.enable_quantization()
        linear_layers = [layer for layer in flux_pipeline.transformer.layers if "Convolution" in dir(layer)]
        for layer in linear_layers:
            # Return value is discarded: this only exercises the quantization
            # math and does not modify the layer's weights.
            convoluted_quantization(
                c=torch.randn(1, 256),
                w=layer.weight,
                ws=0.1,
                wz=0,
                is_=0.1,
                iz=0,
                os_=0.1,
                oz=0,
            )
        flux_pipeline.enable_cuda_graph()
    except Exception as e:
        # Deliberate best-effort: any failure above leaves the plain pipeline usable.
        print("Fallback to origin pipeline due to error:", e)
    # Warm-up inference: same settings as production calls (4 steps, no CFG)
    # to prime CUDA kernels and allocator before the first real request.
    for _ in range(3):
        flux_pipeline(
            prompt="fretful, becalmment, ventriduct, anthologion, tiptoppish, return, non-duplicate",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    torch.cuda.empty_cache()
    return flux_pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
    """Generate one PIL image from *request* using the given pipeline and RNG."""
    # Release cached CUDA blocks from the previous request before generating.
    torch.cuda.empty_cache()
    result = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]