import os
import gc
import torch
from torch import Generator
from PIL.Image import Image
from diffusers import AutoencoderKL, FluxPipeline
from diffusers.image_processor import VaeImageProcessor
from pipelines.models import TextToImageRequest
from transformers import T5EncoderModel
# Allocator config is read at the first CUDA allocation, so setting it after
# importing torch (but before any CUDA work) still takes effect
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.001"
ckpt_id = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16
Pipeline = FluxPipeline
# Configure CUDA settings: autotuned cuDNN kernels, TF32 matmuls, and almost
# the full per-process memory budget
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.cuda.set_per_process_memory_fraction(0.99)
class BasicQuantization:
    """Uniform affine (min-max) quantization to a signed `bits`-wide integer grid."""

    def __init__(self, bits=1):
        self.bits = bits
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1

    def quantize_tensor(self, tensor):
        # Map [tensor.min(), tensor.max()] onto [qmin, qmax]; clamp the scale so
        # a constant tensor does not divide by zero
        scale = (tensor.max() - tensor.min()) / (self.qmax - self.qmin)
        scale = torch.clamp(scale, min=1e-8)
        zero_point = self.qmin - torch.round(tensor.min() / scale)
        qtensor = torch.round(tensor / scale + zero_point)
        qtensor = torch.clamp(qtensor, self.qmin, self.qmax)
        # Dequantize immediately so callers keep working with float tensors
        return (qtensor - zero_point) * scale, scale, zero_point
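# A minimal round-trip sketch of BasicQuantization (illustrative addition; this
# helper is not called anywhere in this module)
def _quantization_roundtrip_error(bits: int = 8) -> torch.Tensor:
    quant = BasicQuantization(bits)
    weight = torch.randn(64, 64)
    dequantized, scale, zero_point = quant.quantize_tensor(weight)
    # Values are snapped to a 2**bits-step grid, so the worst-case
    # reconstruction error is about half a step (scale / 2)
    return (dequantized - weight).abs().max()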
class ModelQuantization:
    """Fake-quantizes the weights and biases of every nn.Linear in a model."""

    def __init__(self, model, bits=7):
        self.model = model
        self.quant = BasicQuantization(bits)

    def quantize_model(self):
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if hasattr(module, 'weight'):
                    quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
                    module.weight = torch.nn.Parameter(quantized_weight)
                if hasattr(module, 'bias') and module.bias is not None:
                    quantized_bias, _, _ = self.quant.quantize_tensor(module.bias)
                    module.bias = torch.nn.Parameter(quantized_bias)
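# Hedged usage sketch: quantizing a toy module the same way load_pipeline()
# below quantizes the VAE (the toy model here is hypothetical and unused)
def _model_quantization_demo() -> torch.nn.Module:
    toy = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU(), torch.nn.Linear(16, 4))
    ModelQuantization(toy, bits=7).quantize_model()  # Linear weights now lie on a 7-bit grid
    return toy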
def empty_cache():
    # Release Python garbage and cached CUDA blocks, then reset the allocator's
    # peak statistics so later measurements start from a clean slate
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> Pipeline:
    empty_cache()
    # Load the VAE and fake-quantize its Linear layers (7-bit by default)
    vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype)
    quantizer = ModelQuantization(vae)
    quantizer.quantize_model()
# text_encoder_2 = T5EncoderModel.from_pretrained(
# "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
# )
# Initialize pipeline
pipeline = FluxPipeline.from_pretrained(
ckpt_id,
# text_encoder_2=text_encoder_2,
vae=vae,
torch_dtype=dtype
)
# Optimize memory format
for component in [pipeline.text_encoder, pipeline.text_encoder_2, pipeline.transformer, pipeline.vae]:
component.to(memory_format=torch.channels_last)
# Compile and configure pipeline
    pipeline.vae = torch.compile(pipeline.vae, mode="reduce-overhead")
    # Offload everything to CPU between uses except the VAE, which stays on GPU
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.vae.encoder.to(device)
    pipeline.vae.decoder.to(device)
    # Warmup: two passes so torch.compile captures the graph and then reuses it,
    # keeping compilation latency out of later requests
    empty_cache()
    for _ in range(2):
pipeline(
prompt="posteroexternal, eurythmical, inspection, semicotton, specification, Mercatorial, ethylate, misprint",
width=1480,
height=1480,
guidance_scale=0.0,
num_inference_steps=4,
max_sequence_length=256
)
return pipeline
_inference_count = 0
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
global _inference_count
# Clear on first inference
if _inference_count == 0:
empty_cache()
# Increment counter and empty cache every 4 inferences
_inference_count += 1
if _inference_count >= 4:
empty_cache()
_inference_count = 0
# torch.cuda.reset_peak_memory_stats()
generator = Generator("cuda").manual_seed(request.seed)
return pipeline(
prompt=request.prompt,
generator=generator,
guidance_scale=0.0,
num_inference_steps=4,
max_sequence_length=256,
height=request.height,
width=request.width,
output_type="pil"
).images[0]
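# Hedged end-to-end sketch, assuming TextToImageRequest accepts the fields that
# infer() reads above (prompt, seed, height, width); runs only as a script
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(prompt="a lighthouse at dawn", seed=0, height=1024, width=1024)
    infer(request, pipe).save("sample.png")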