from diffusers import FluxPipeline, AutoencoderKL, FluxTransformer2DModel  # AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
import torch.nn.functional as F
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
import torch.nn as nn
# from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight #PerRow,
import os
# Tune the CUDA caching allocator before torch initializes CUDA:
# disable expandable segments and set an aggressive allocator GC threshold.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
# Empty token -> anonymous (unauthenticated) Hugging Face Hub downloads.
os.environ["HUGGINGFACE_HUB_TOKEN"] = ""
# Placeholder used only as a return-type annotation below; no concrete
# pipeline class is materialized at import time.
Pipeline = None
# def w8_a16_forward(weight, input, scales, bias=None):
# casted_weights = weight.to(input.dtype)
# output = F.linear(input, casted_weights) * scales # overhead
# if bias is not None:
# output = output + bias
# return output
# class W8A16LinearLayer(nn.Module):
# def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
# super().__init__()
# self.register_buffer(
# "int8_weights",
# torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8))
# self.register_buffer("scales", torch.randn((out_features), dtype=dtype))
# if bias:
# self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype))
# def quantize(self, weights):
# w_fp32 = weights.clone().to(torch.float32)
# scales = w_fp32.abs().max(dim=-1).values / 127
# scales = scales.to(weights.dtype)
# int8_weights = torch.round(weights/scales.unsqueeze(1)).to(torch.int8)
# self.int8_weights = int8_weights
# self.scales = scales
# self.bias = None
# def forward(self, input):
# return w8_a16_forward(self.int8_weights, input, self.scales, self.bias)
class W8A16LinearLayer(nn.Module):
    """Linear layer holding int8 weights with per-output-channel scales.

    ``quantize`` converts a float weight matrix to int8 via symmetric
    absmax quantization; ``forward`` casts the int8 weights back to the
    activation dtype, applies the matmul, and rescales per output channel
    (W8A16: 8-bit weights, 16-bit activations).
    """

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        # Frozen weight tensor; its .data is replaced with int8 values by quantize().
        self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=dtype))
        self.weight.requires_grad = False
        if bias:
            self.bias = nn.Parameter(torch.randn(1, out_features, dtype=dtype))
        else:
            # Fix: the original never defined self.bias when bias=False, so
            # forward() raised AttributeError. Register it explicitly as None.
            self.register_parameter("bias", None)
        # Per-output-channel dequantization scales.
        self.scales = nn.Parameter(torch.randn(out_features, dtype=dtype))

    def quantize(self, weights):
        """Symmetric absmax int8 quantization of ``weights`` (out_features, in_features)."""
        w_fp32 = weights.clone().to(torch.float32)
        # One scale per row so each output channel uses its full int8 range.
        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)
        self.weight.data = torch.round(weights / scales.unsqueeze(1)).to(torch.int8)
        self.scales.data = scales

    def forward(self, input):
        # Dequantize on the fly: int8 -> activation dtype, matmul, rescale.
        casted_weights = self.weight.to(input.dtype)
        output = F.linear(input, casted_weights) * self.scales
        if self.bias is not None:
            output = output + self.bias
        return output
# def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
# # with open("/root/.cache/huggingface/hub/output_layers.txt", "a") as f:
# for name, child in module.named_children():
# if isinstance(child, nn.Linear) and ( 'add_k_proj' in name or 'add_v_proj' in name or 'add_q_proj' in name ): #and not any([x == name for x in module_name_to_exclude]): 'linear' in name or
# old_bias = child.bias
# old_weight = child.weight
# new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
# new_module.quantize(old_weight)
# delattr(module, name)
# setattr(module, name, new_module)
# if old_bias is not None:
# getattr(module, name).bias = old_bias
# # # Print the replaced layer name and calculate the change in size
# # old_size = old_weight.numel() * old_weight.element_size()
# # new_size = new_module.int8_weights.numel() * new_module.int8_weights.element_size()
# # f.write(f"Replaced layer: {name}" + f" Size reduction: {old_size} bytes -> {new_size} bytes ({(old_size - new_size) / old_size * 100:.2f}% reduction)")
# else:
# # Recursively call the function for nested modules
# replace_linear_with_target_and_quantize(child, target_class, module_name_to_exclude)
def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
    """Recursively swap the attention ``add_{q,k,v}_proj`` nn.Linear children
    of ``module`` for quantized ``target_class`` instances.

    Only direct children whose name contains 'add_k_proj', 'add_v_proj' or
    'add_q_proj' are replaced; every other submodule is descended into.
    ``module_name_to_exclude`` is accepted for API compatibility but unused.
    """
    targeted = ("add_k_proj", "add_v_proj", "add_q_proj")
    # Snapshot the child names up front: we mutate module._modules while looping.
    for child_name in list(module._modules.keys()):
        child = module._modules[child_name]
        matches = isinstance(child, nn.Linear) and any(t in child_name for t in targeted)
        if not matches:
            # Recurse into nested containers looking for more candidates.
            replace_linear_with_target_and_quantize(child, target_class, module_name_to_exclude)
            continue
        previous_bias = child.bias
        previous_weight = child.weight
        replacement = target_class(
            child.in_features,
            child.out_features,
            previous_bias is not None,
            child.weight.dtype,
        )
        replacement.quantize(previous_weight)
        delattr(module, child_name)
        setattr(module, child_name, replacement)
        if previous_bias is not None:
            getattr(module, child_name).bias = previous_bias
# Hugging Face model id of the base FLUX.1-schnell checkpoint.
ckpt_id = "black-forest-labs/FLUX.1-schnell"
def empty_cache() -> None:
    """Run the Python GC, release PyTorch's cached CUDA memory, and reset
    the peak-memory statistics so later measurements start from zero.
    """
    # Fix: dropped the unused local `start = time.time()` (dead code) and the
    # deprecated torch.cuda.reset_max_memory_allocated(), which is just an
    # alias of reset_peak_memory_stats() and was called redundantly.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> Pipeline:
    """Build, configure, and warm up the FLUX.1-schnell text-to-image pipeline.

    Loads a pre-converted bf16 T5-XXL text encoder and the schnell VAE,
    enables cuDNN autotuning and TF32 matmuls, compiles the VAE, offloads
    every component except the VAE to CPU, and runs two warm-up generations
    so later calls hit compiled/autotuned kernels.
    """
    empty_cache()
    dtype = torch.bfloat16
    # Pre-converted bf16 T5-XXL avoids downloading and casting fp32 weights.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
    )
    vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype)
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        vae=vae,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
    )
    # Speed over determinism: let cuDNN autotune and allow TF32 matmuls.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    # Fix: removed torch.cuda.set_memory_growth(True) — that is a TensorFlow
    # API (tf.config.experimental.set_memory_growth); PyTorch has no such
    # function, so the call raised AttributeError at load time.
    torch.cuda.set_per_process_memory_fraction(0.99)
    pipeline.text_encoder.to(memory_format=torch.channels_last)
    pipeline.transformer.to(memory_format=torch.channels_last)
    pipeline.vae.to(memory_format=torch.channels_last)
    pipeline.vae = torch.compile(pipeline.vae)
    # Keep the compiled VAE resident on the GPU; offload everything else.
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()
    # Two warm-up runs: the first triggers compilation/autotuning, the
    # second executes on the now-warm kernels.
    for _ in range(2):
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one PIL image for ``request`` with the warmed-up pipeline,
    seeding a CUDA generator from the request for reproducible output.
    """
    torch.cuda.reset_peak_memory_stats()
    seeded_generator = Generator("cuda").manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=seeded_generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]
|