| |
| import os |
| import torch |
| import torch._dynamo |
| import gc |
|
|
|
|
| from huggingface_hub.constants import HF_HUB_CACHE |
| from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel |
|
|
| from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only |
| from torch import Generator |
| from diffusers import FluxTransformer2DModel, DiffusionPipeline |
|
|
| from PIL.Image import Image |
| from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny |
| from pipelines.models import TextToImageRequest |
| from optimum.quanto import requantize |
| import json |
| import transformers |
| import torch |
| import gc |
| import os |
| import json |
| import transformers |
|
|
|
|
# Let torch.compile/dynamo fall back to eager execution instead of raising
# when graph capture fails.
torch._dynamo.config.suppress_errors = True
# NOTE(review): set after `import torch`; expandable_segments is read when the
# CUDA caching allocator initializes — confirm no CUDA work has run before this
# point so the setting actually takes effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
# Allow the HF fast tokenizers to use parallelism during warm-up prompts.
os.environ["TOKENIZERS_PARALLELISM"] = "True"


# Base FLUX.1-schnell checkpoint and the pinned revision loaded below.
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
# Placeholder used only as a return annotation in this file; no concrete
# pipeline type is declared here (annotations evaluate to None harmlessly).
Pipeline = None
|
|
class CleanAndOptimization:
    """Bundle of model, memory and backend optimization helpers.

    Wraps a model plus a small scratch cache and exposes utilities for
    cuDNN backend tuning, sequence preprocessing, int8 weight
    quantization and memory cleanup.
    """

    def __init__(self, model, device="cuda"):
        self.model = model    # model passed to quantize_model()
        self.device = device  # target device label (not used by current methods)
        self.cache = {}       # scratch cache, cleared by optimize_memory()

    @staticmethod
    def enhance_performance():
        """Enable cuDNN autotuning (allows non-deterministic kernels).

        Returns a short status string for logging.
        """
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        return "Torch backend opt"

    def preprocess(self, data):
        """Return a new list with each sequence in *data* reversed."""
        return [d[::-1] for d in data]

    def quantize_model(self):
        """Quantize the wrapped model's weights to int8, in place.

        Fixes the original flow, which (a) passed an unsupported
        ``weight_dtype`` kwarg to ``quantize_`` and rebound ``self.model``
        to its ``None`` return, and (b) called the ``int8_weight_only``
        *config factory* on the model object itself.
        """
        # torchao's quantize_ mutates the model in place and returns None,
        # so keep the existing reference rather than rebinding from the call.
        quantize_(self.model, int8_weight_only())
        return self.model

    def optimize_memory(self):
        """Release cached CUDA blocks, run the GC and clear the local cache."""
        torch.cuda.empty_cache()
        gc.collect()
        self.cache.clear()

    def apply_all(self, data):
        """Run cleanup, preprocessing, quantization and backend tuning.

        Returns the backend-tuning status string (original contract kept).
        NOTE(review): the preprocessed data is computed but discarded, as
        in the original — confirm whether callers expect it returned.
        """
        self.optimize_memory()
        processed = self.preprocess(data)  # intentionally unused (original behavior)
        self.quantize_model()
        return self.enhance_performance()
|
|
def t5_mapping_loader(repo_path):
    """Build a bfloat16 T5 encoder from *repo_path* and requantize it.

    Reads ``config.json`` from the repo snapshot, instantiates a
    ``T5EncoderModel`` in bfloat16 (CUDA only), and applies the
    quantization map from ``mapping_encoder_2.json``.

    Returns the quantized model, or ``None`` when CUDA is unavailable.
    """

    def _load_json(filepath):
        # Context manager guarantees the handle is closed (the original
        # used open(...).read() and leaked the file object).
        with open(filepath, "r") as fh:
            return json.load(fh)

    def _load_config():
        # Rebuild the T5 config straight from the snapshot's config.json.
        return transformers.T5Config(**_load_json(os.path.join(repo_path, "config.json")))

    def _apply_quantization(model):
        quant_map = _load_json("mapping_encoder_2.json")
        # NOTE(review): state_dict=None relies on requantize tolerating a
        # missing state dict (weights presumably restored elsewhere) — confirm
        # against optimum.quanto's requantize contract.
        requantize(
            model=model,
            state_dict=None,
            quantization_map=quant_map,
            device=torch.device("cuda"),
        )

    temp_model = None
    if torch.cuda.is_available():
        temp_model = transformers.T5EncoderModel(_load_config()).to(torch.bfloat16)

    # Explicit None check instead of truthiness: nn.Module truthiness is
    # not a reliable "was built" signal.
    if temp_model is not None:
        _apply_quantization(temp_model)

    return temp_model
|
|
def load_pipeline() -> Pipeline:
    """Assemble and warm up the FLUX.1-schnell text-to-image pipeline.

    Loads a custom transformer and tiny VAE plus a quantized T5 text
    encoder (local HF-hub cache first, network fallback second), builds
    the DiffusionPipeline on CUDA, applies best-effort CUDA-graph and
    cuDNN tuning, and runs three warm-up generations before returning.
    """
    # Custom transformer snapshot resolved from the local HF hub cache.
    trans_path = os.path.join(
        HF_HUB_CACHE,
        "models--RichardWilliam--XULF_Transfomer/snapshots/6860c51af40329808f270e159a0d018559a1204f",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        trans_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    origin_vae = AutoencoderTiny.from_pretrained(
        "RichardWilliam/XULF_Vae",
        revision="3ee225c539465c27adadec45c6e8af50a7397b7d",
        torch_dtype=torch.bfloat16,
    )

    try:
        # Prefer the locally cached, pre-quantized encoder snapshot.
        base_encoder_2 = os.path.join(
            HF_HUB_CACHE,
            "models--RichardWilliam--XULF_T5_bf16/snapshots/63a3d9ef7b586655600ac9bd4e4747d038237761",
        )
        text_encoder_2 = t5_mapping_loader(repo_path=base_encoder_2)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; fall back to downloading the bf16 encoder.
        text_encoder_2 = T5EncoderModel.from_pretrained(
            "RichardWilliam/XULF_T5_bf16",
            revision="63a3d9ef7b586655600ac9bd4e4747d038237761",
            torch_dtype=torch.bfloat16,
        ).to(memory_format=torch.channels_last)

    flux_pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=origin_vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    flux_pipeline.to("cuda")

    try:
        torch.cuda.empty_cache()
        gc.collect()
        # Best-effort speedups: CUDA-graph capture and cuDNN autotuning may
        # not be supported by every transformer build.
        flux_pipeline.transformer.enable_cuda_graph()
        torch_opt = CleanAndOptimization.enhance_performance()
        print(torch_opt)
    except Exception:
        # Deliberate best-effort: failure here only costs speed, not output.
        pass

    # Warm-up prompts to trigger compilation/kernel caching before serving.
    prompt_test = [
        "commensality, eurycephalous, cellulipetal, chiefish, Leskeaceae",
        "skedlock, palatopterygoid, bacteriogenic",
        "tariric, corrobboree, Sanetch, return non-duplicate",
    ]
    for prompt in prompt_test:
        flux_pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    torch.cuda.empty_cache()
    return flux_pipeline
|
|
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Run one text-to-image generation for *request* and return the image."""
    # Reclaim as much memory as possible before starting the generation.
    torch.cuda.empty_cache()
    gc.collect()

    # Seed a generator on the pipeline's device for reproducible output.
    rng = Generator(pipeline.device).manual_seed(request.seed)

    result = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    return result.images[0]