doubleflux / src /pipeline.py
TrendForge's picture
Initial commit with folder contents
b8b4dca verified
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline, AutoencoderTiny
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel
import torch
import torch._dynamo
import os
# Process-wide settings applied as a side effect of importing this module.
# Both values are plain strings handed to the libraries through the environment.
os.environ.update(
    {
        'PYTORCH_CUDA_ALLOC_CONF': "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)
# Tell TorchDynamo to suppress its own errors instead of raising them.
torch._dynamo.config.suppress_errors = True

# Base checkpoint on the Hugging Face Hub and the exact revision pinned for it.
model_checkpoint = "black-forest-labs/FLUX.1-schnell"
model_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"

# Placeholder used only as the annotation for the pipeline object in the
# functions of this module; it carries no runtime type information.
pipeline_class = None
class NormalizationQuantization:
    """Perturb-and-floor a model's trainable weights and offset its buffers.

    The wrapped model is modified in place; ``apply`` returns that same
    object so the call can be used directly as an expression.
    """

    def __init__(self, model, noise_level=0.05):
        """Remember the target model and the noise scale.

        Args:
            model: Object exposing ``named_parameters()`` and
                ``named_buffers()`` (e.g. a ``torch.nn.Module``).
            noise_level: Multiplier applied to the standard-normal noise
                that is added to each parameter before flooring.
        """
        self.noise_level = noise_level
        self.model = model

    def apply(self):
        """Rewrite the model's parameters and buffers, then return the model.

        Each parameter whose ``requires_grad`` is true has scaled Gaussian
        noise added and is then rounded down with ``torch.floor``; parameters
        with ``requires_grad`` false are left untouched. Every buffer is
        incremented in place by a ``torch.full_like(buffer, 0.01)`` tensor.

        NOTE(review): flooring discards the fractional part of every touched
        weight -- confirm this is the intended transformation.

        Returns:
            The same model instance that was passed to the constructor.
        """
        with torch.no_grad():
            for _, weight in self.model.named_parameters():
                if not weight.requires_grad:
                    continue
                jitter = torch.randn_like(weight.data).mul(self.noise_level)
                weight.data = (weight.data + jitter).floor()

            for _, buf in self.model.named_buffers():
                offset = torch.full_like(buf, 0.01)
                buf.add_(offset)

        return self.model
def load_diffusion_pipeline() -> pipeline_class:
    """Assemble the FLUX.1-schnell pipeline from replacement parts and warm it up.

    Steps, in order:
      1. Load the tiny autoencoder and the T5 text encoder from pinned Hub
         revisions, all in bfloat16.
      2. Run ``NormalizationQuantization`` (noise_level=0.03) on the text
         encoder.
      3. Load the transformer from a snapshot directory inside the local Hub
         cache and run the same transformation on it.
      4. Build the ``DiffusionPipeline`` around those parts, move it to
         ``"cuda"``, and run three throw-away generations.

    If a transformation step raises, the error is printed and the model
    object that was loaded is used instead.

    Returns:
        The assembled pipeline after it has been moved to CUDA and warmed up.
    """
    bf16 = torch.bfloat16

    tiny_vae = AutoencoderTiny.from_pretrained(
        "TrendForge/extra2Jan12",
        revision="da7c5cf904a9dbba65a7282396befa49623cd9cd",
        torch_dtype=bf16,
    )

    loaded_t5 = T5EncoderModel.from_pretrained(
        "TrendForge/extra1Jan11",
        revision="c76831ddf0852be22835f79dc5c1fbacb1ccda9e",
        torch_dtype=bf16,
    )
    loaded_t5 = loaded_t5.to(memory_format=torch.channels_last)

    # NOTE(review): apply() edits the model in place, so the fallback below is
    # the same object the failed attempt was working on -- confirm that a
    # partially modified model is acceptable here.
    try:
        t5_quantizer = NormalizationQuantization(loaded_t5, noise_level=0.03)
        text_encoder = t5_quantizer.apply()
    except Exception as e:
        print(f"Failed to apply normalization quantization on text encoder: {e}")
        text_encoder = loaded_t5

    # The transformer is read straight from the on-disk Hub cache snapshot
    # rather than resolved by repo id.
    snapshot_dir = os.path.join(
        HF_HUB_CACHE,
        "models--TrendForge--extra0Jan10/snapshots/d3ded25a77fdef06de4059d94b080a34da6e7a82"
    )
    loaded_transformer = FluxTransformer2DModel.from_pretrained(
        snapshot_dir,
        torch_dtype=bf16,
        use_safetensors=False,
    )
    loaded_transformer = loaded_transformer.to(memory_format=torch.channels_last)

    # Same in-place caveat as for the text encoder above.
    try:
        transformer_quantizer = NormalizationQuantization(loaded_transformer, noise_level=0.03)
        transformer = transformer_quantizer.apply()
    except Exception as e:
        print(f"Failed to apply normalization quantization on transformer model: {e}")
        transformer = loaded_transformer

    pipe = DiffusionPipeline.from_pretrained(
        model_checkpoint,
        revision=model_revision,
        vae=tiny_vae,
        transformer=transformer,
        text_encoder_2=text_encoder,
        torch_dtype=bf16,
    )
    pipe.to("cuda")

    # Warm-up: three full generations with the same settings used at
    # inference time; the outputs are discarded.
    warmup_args = {
        "prompt": "freezable, catacorolla, gaiassa, unenkindled, grubs, solidiform",
        "width": 1024,
        "height": 1024,
        "guidance_scale": 0.0,
        "num_inference_steps": 4,
        "max_sequence_length": 256,
    }
    for _ in range(3):
        pipe(**warmup_args)

    return pipe
@torch.no_grad()
def perform_inference(request: TextToImageRequest, pipeline: pipeline_class) -> Image:
    """Generate a single image for *request* with an already-loaded pipeline.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: Callable pipeline exposing a ``device`` attribute and
            returning an object with an ``images`` sequence.

    Returns:
        The first image produced by the pipeline call.
    """
    # Seed a generator on the pipeline's device so the same request.seed
    # drives the same random stream for this call.
    seeded_rng = Generator(pipeline.device)
    seeded_rng.manual_seed(request.seed)

    # Fixed sampling settings: no guidance, four steps, 256-token prompts.
    result = pipeline(
        request.prompt,
        height=request.height,
        width=request.width,
        generator=seeded_rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    return result.images[0]