import os
import torch
import torch._dynamo
import gc
from PIL.Image import Image
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import (
T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
)
from diffusers import (
FluxPipeline, AutoencoderKL, AutoencoderTiny, FluxTransformer2DModel, DiffusionPipeline
)
from pipelines.models import TextToImageRequest
from torch import Generator
# Set environment variables
# Allow the HF tokenizers' Rust thread pool to parallelize tokenization.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Let the CUDA caching allocator use expandable segments to reduce fragmentation.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
# Fall back to eager execution instead of raising when torch.compile/dynamo fails.
torch._dynamo.config.suppress_errors = True
# NOTE(review): used only as a return/parameter annotation below; it evaluates
# to None at def time, so the annotations are effectively untyped — confirm intent.
Pipeline = None
# Define constants
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
class QuantativeAnalysis:
    """Per-tensor quantization helper (currently a no-op).

    NOTE(review): the original ``apply`` iterated over every trainable
    parameter, computed its min/max and a ``0.8*min + 0.2*max`` blend into a
    local that was never written back to the model — dead code with no
    observable effect.  That dead loop has been removed; the public contract
    (return the wrapped model unchanged) is preserved exactly.
    """

    def __init__(self, model, num_bins=256, scale_ratio=1.0):
        # model: an object exposing named_parameters() (e.g. torch.nn.Module).
        # num_bins / scale_ratio: retained for interface compatibility; they
        # are currently unused because apply() does not modify the model.
        self.model = model
        self.num_bins = num_bins
        self.scale_ratio = scale_ratio

    def apply(self):
        """Return the wrapped model unchanged (see class note)."""
        return self.model
class AttentionQuant:
    """Uniform min-max quantization of selected model parameters, in place.

    ``att_config`` maps a parameter name to ``(num_bins, scale_factor)``.
    Each matched trainable parameter is linearly mapped to [0, 1], rounded
    onto ``num_bins`` evenly spaced levels, mapped back to its original
    range, and multiplied by ``scale_factor``.  Constant tensors (zero
    range) are zeroed, matching the original behavior.
    """

    def __init__(self, model, att_config):
        # model: an object exposing named_parameters() (e.g. torch.nn.Module).
        # att_config: dict of parameter name -> (num_bins, scale_factor).
        self.model = model
        self.att_config = att_config

    def apply(self):
        """Quantize every configured parameter in place and return the model."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            # BUG FIX: the original looked up only name.split(".")[0]
            # (e.g. "transformer_blocks"), which can never match fully
            # qualified config keys such as
            # "transformer_blocks.15.attn.norm_added_k.weight" — so no
            # parameter was ever quantized.  Match the full name first and
            # keep the old first-component lookup for backward compatibility.
            cfg = self.att_config.get(name)
            if cfg is None:
                cfg = self.att_config.get(name.split(".")[0])
            if cfg is None:
                continue
            num_bins, scale_factor = cfg
            with torch.no_grad():
                param_min = param.min()
                param_max = param.max()
                param_range = param_max - param_min
                if param_range > 0:
                    # Normalize to [0, 1], snap to num_bins levels, restore range.
                    normalized = (param - param_min) / param_range
                    binned = torch.round(normalized * (num_bins - 1)) / (num_bins - 1)
                    rescaled = binned * param_range + param_min
                    param.data.copy_(rescaled * scale_factor)
                else:
                    # Degenerate (constant) tensor: zero it, as the original did.
                    param.data.zero_()
        return self.model
def load_pipeline() -> Pipeline:
    """Assemble the FLUX.1-schnell diffusion pipeline from custom components,
    move it to CUDA, and warm it up with three throwaway generations.
    """
    # Replacement text encoder #2: a custom T5 encoder checkpoint.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "TrendForge/extra1manQ1",
        revision="d302b6e39214ed4532be34ec337f93c7eef3eaa6",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)

    # Lightweight VAE (AutoencoderTiny) substitute.
    vae = AutoencoderTiny.from_pretrained(
        "TrendForge/extra2manQ2",
        revision="cef012d2db2f5a006567e797a0b9130aea5449c1",
        torch_dtype=torch.bfloat16,
    )

    # Transformer backbone, loaded directly from a local HF cache snapshot.
    snapshot_dir = os.path.join(HF_HUB_CACHE, "models--TrendForge--extra0manQ0/snapshots/dc2cda167b8f53792a98020a3ef2f21808b09bb4")
    transformer = FluxTransformer2DModel.from_pretrained(
        snapshot_dir, torch_dtype=torch.bfloat16, use_safetensors=False
    ).to(memory_format=torch.channels_last)

    # Best-effort quantization of a few attention-norm weights; keep the
    # unmodified transformer if anything goes wrong.
    try:
        quant_config = {
            "transformer_blocks.15.attn.norm_added_k.weight": (64, 0.1),
            "transformer_blocks.15.attn.norm_added_q.weight": (64, 0.1),
            "transformer_blocks.15.attn.norm_added_v.weight": (64, 0.1)
        }
        transformer = AttentionQuant(transformer, quant_config).apply()
    except Exception:
        pass  # fall back to the unquantized transformer

    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16
    )
    pipeline.to("cuda")

    # Warm up kernels/caches with dummy prompts before serving real requests.
    for _ in range(3):
        pipeline(
            prompt="forswearer, skullcap, Juglandales, bluelegs, cunila, carbro, Ammonites",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256
        )
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Run one seeded text-to-image generation and return the first PIL image."""
    # Seed a generator on the pipeline's device for reproducible sampling.
    rng = Generator(pipeline.device).manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width
    )
    return result.images[0]