UltraThink / src /pipeline.py
passfh's picture
Initial commit with folder contents
1ab446e verified
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from diffusers import AutoencoderTiny
from pipelines.models import TextToImageRequest
import os
import torch
import torch._dynamo
# Process-wide runtime configuration, applied as a side effect of importing
# this module.
# PYTORCH_CUDA_ALLOC_CONF carries options for PyTorch's CUDA caching
# allocator; TOKENIZERS_PARALLELISM is the parallelism switch read by the
# Hugging Face tokenizers library.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)

# Ask TorchDynamo to suppress errors hit during compilation instead of
# propagating them to the caller.
torch._dynamo.config.suppress_errors = True

# Base FLUX checkpoint on the Hugging Face Hub, pinned to a fixed revision.
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"

# Placeholder that this module uses only as the annotation for pipeline
# objects (see load_pipeline / infer signatures).
Pipeline = None
class Normalization:
    """Snap a model's trainable parameters onto a uniform grid, in place.

    Each trainable parameter is min-max normalized to [0, 1], rounded to one
    of ``num_bins`` evenly spaced levels, mapped back to its original range
    and finally multiplied by ``scale_factor``. Every buffer of the model is
    multiplied by ``scale_factor`` as well.

    Args:
        model: Object exposing ``named_parameters()`` and ``named_buffers()``
            (e.g. a ``torch.nn.Module``).
        num_bins: Number of quantization levels per parameter tensor.
        scale_factor: Multiplier applied to quantized parameters and buffers.
    """

    def __init__(self, model, num_bins=256, scale_factor=1.0):
        self.model = model
        self.num_bins = num_bins
        self.scale_factor = scale_factor

    def apply(self):
        """Quantize parameters and scale buffers of ``self.model`` in place.

        Parameters with ``requires_grad == False`` and parameters with no
        elements are left untouched. A parameter whose values are all equal
        (range of zero) is set to zero.

        Returns:
            The same model object that was passed to the constructor.
        """
        levels = self.num_bins - 1

        # All updates are raw weight surgery; none of it should be tracked
        # by autograd.
        with torch.no_grad():
            for _name, param in self.model.named_parameters():
                if not param.requires_grad:
                    continue
                # min()/max() are undefined for a tensor without elements.
                if param.numel() == 0:
                    continue

                param_min = param.min()
                param_max = param.max()
                param_range = param_max - param_min

                if param_range > 0:
                    # Normalize to [0, 1], snap to the grid, then map back
                    # to the parameter's original range.
                    normalized = (param - param_min) / param_range
                    binned = torch.round(normalized * levels) / levels
                    rescaled = binned * param_range + param_min
                    param.data.copy_(rescaled * self.scale_factor)
                else:
                    # Degenerate range (constant tensor): zero it out.
                    param.data.zero_()

            # NOTE(review): in-place multiplication by a float scale may fail
            # on integer-typed buffers -- confirm against the target models.
            for _buffer_name, buffer in self.model.named_buffers():
                buffer.mul_(self.scale_factor)

        return self.model
def load_pipeline() -> Pipeline:
    """Build the FLUX text-to-image pipeline on CUDA and warm it up.

    Loads a T5 text encoder, a Flux transformer and a tiny VAE in bfloat16,
    plugs them into the pipeline loaded from ``CHECKPOINT`` at ``REVISION``,
    moves the pipeline to the ``"cuda"`` device and runs three throwaway
    generations before returning it.

    The transformer is read from a snapshot directory under ``HF_HUB_CACHE``
    -- this assumes the snapshot is already present in the local cache.

    Returns:
        The assembled pipeline object, ready to be passed to ``infer``.
    """
    dtype = torch.bfloat16

    text_encoder_2 = T5EncoderModel.from_pretrained(
        "passfh/textenc",
        revision="a44db2ac3d729d6cc1243dcb906903e77ba26c45",
        torch_dtype=dtype,
    ).to(memory_format=torch.channels_last)

    transformer_dir = os.path.join(
        HF_HUB_CACHE,
        "models--passfh--flux_transformer/snapshots/3c3bcc511f409569adb6c798da415b3fdc9e927d",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        torch_dtype=dtype,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    vae = AutoencoderTiny.from_pretrained(
        "passfh/vae",
        revision="edd99d452c03a8b836758bb89bc775f2f3c3849a",
        torch_dtype=dtype,
    )

    # Use the module-level constants so the pinned checkpoint is declared in
    # exactly one place.
    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
    )
    pipeline.to("cuda")

    # Warm-up: run the same settings infer() uses (4 steps, no guidance,
    # sequence length 256) and discard the results.
    for _ in range(3):
        pipeline(
            prompt="bluelegs, cunila, carbro, Ammonites, Lollardism, forswearer, skullcap, Juglandales",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request`` using ``pipeline``.

    The prompt, seed, height and width are read from ``request``; the
    sampling settings are fixed (4 steps, guidance scale 0.0, maximum
    sequence length 256). A generator created on the pipeline's device is
    seeded with ``request.seed``.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    prompt = request.prompt
    seeded_rng = Generator(pipeline.device).manual_seed(request.seed)

    call_kwargs = dict(
        generator=seeded_rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    result = pipeline(prompt, **call_kwargs)
    return result.images[0]