import os
import torch
import torch._dynamo
from functools import wraps
from typing import Callable
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from pipelines.models import TextToImageRequest
from torch import Generator


def error_handler(func: Callable):
    """Log exceptions from the wrapped call and return None instead of raising."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")
            return None
    return wrapper


class TorchOptimizer:
    def optimize_settings(self):
        # Allow TF32 matmuls and let cuDNN autotune conv algorithms.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.set_float32_matmul_precision("high")

    def clear_cache(self):
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias of
        # reset_peak_memory_stats(), so a single call covers both.
        torch.cuda.reset_peak_memory_stats()


class PipelineManager:
    def __init__(self):
        self.ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
        self.revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
        self.pipeline = None
        self.optimizer = TorchOptimizer()

        # Fall back to eager execution instead of raising when torch.compile
        # hits an unsupported op.
        torch._dynamo.config.suppress_errors = True
        # Note: PYTORCH_CUDA_ALLOC_CONF only takes effect if it is set before
        # the CUDA context is initialized.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        os.environ["TOKENIZERS_PARALLELISM"] = "True"

        self.optimizer.optimize_settings()

        print("Initializing pipeline...")
        self.pipeline = self.load_pipeline()
        print("Pipeline initialization complete.")

    def load_transformer(self):
        # Load the pinned transformer snapshot directly from the local
        # Hugging Face cache.
        transformer_path = os.path.join(
            HF_HUB_CACHE,
            "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665"
        )
        return FluxTransformer2DModel.from_pretrained(
            transformer_path,
            torch_dtype=torch.bfloat16,
            use_safetensors=False
        )

    @error_handler
    def optimize_pipeline(self, pipe):
        # Fuse the separate Q/K/V projections into single larger matmuls.
        pipe.transformer.fuse_qkv_projections()
        pipe.vae.fuse_qkv_projections()

        # channels_last memory format is usually faster for conv-heavy
        # modules on CUDA.
        pipe.transformer.to(memory_format=torch.channels_last)
        pipe.vae.to(memory_format=torch.channels_last)

        # Inductor knobs: show compile progress and lower 1x1 convs to matmuls.
        inductor_config = torch._inductor.config
        inductor_config.disable_progress = False
        inductor_config.conv_1x1_as_mm = True

        # Compile the two hot paths; "max-autotune" trades longer compile
        # time for faster kernels.
        pipe.transformer = torch.compile(
            pipe.transformer,
            mode="max-autotune",
            fullgraph=True
        )
        pipe.vae.decode = torch.compile(
            pipe.vae.decode,
            mode="max-autotune",
            fullgraph=True
        )

        return pipe

    def load_pipeline(self):
        transformer_model = self.load_transformer()

        pipe = DiffusionPipeline.from_pretrained(
            self.ckpt_root,
            revision=self.revision_root,
            transformer=transformer_model,
            torch_dtype=torch.bfloat16
        )
        pipe.to("cuda")

        # optimize_pipeline() is wrapped in @error_handler and returns None on
        # failure, in which case the unoptimized pipeline is kept.
        pipe_ops = self.optimize_pipeline(pipe)
        if pipe_ops is not None:
            pipe = pipe_ops

        # Warm-up call so torch.compile does its work here rather than on the
        # first real request.
        print("Running torch compilation...")
        pipe(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4
        ).images[0]
        print("Finished torch compilation")

        return pipe

    def run_inference(self, request: TextToImageRequest) -> Image:
        if self.pipeline is None:
            self.pipeline = self.load_pipeline()

        self.optimizer.clear_cache()
        generator = Generator(self.pipeline.device).manual_seed(request.seed)

        # The schnell variant of FLUX is distilled for few-step sampling,
        # hence guidance_scale=0.0 and only 4 inference steps.
        return self.pipeline(
            request.prompt,
            generator=generator,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
        ).images[0]
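

# A minimal usage sketch, assuming TextToImageRequest accepts prompt, seed,
# height, and width as keyword arguments (the attributes run_inference reads
# above); the actual model in pipelines.models may differ.
if __name__ == "__main__":
    manager = PipelineManager()
    request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",
        seed=42,
        height=1024,
        width=1024,
    )
    image = manager.run_inference(request)
    image.save("output.png")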