import gc

import torch
from torch import Generator
from torch.cuda import nvtx

from PIL.Image import Image

from diffusers import DDIMScheduler, StableDiffusionXLPipeline
from onediffx import compile_pipe, load_pipe, save_pipe

from loss import SchedulerWrapper
from pipelines.models import TextToImageRequest


class OptimizedCallback:
    """Optimized callback handler with a pre-computed guidance switch point."""

    def __init__(self, num_timesteps):
        # Disable classifier-free guidance for roughly the final 12% of steps.
        self.switch_point = int(num_timesteps * 0.88)

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        if step_index == self.switch_point:
            # Keep only the conditional half of each CFG-doubled tensor.
            for key in ("prompt_embeds", "add_text_embeds", "add_time_ids"):
                if key in callback_kwargs:
                    callback_kwargs[key] = callback_kwargs[key].chunk(2)[-1]
            # A guidance scale <= 1 makes the pipeline skip CFG entirely.
            pipe._guidance_scale = 0.1
        return callback_kwargs

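
# A hedged, self-contained illustration (not called anywhere) of the
# `.chunk(2)[-1]` trick used by the callback: diffusers stacks CFG batches
# as [unconditional, conditional] along dim 0, so the last chunk is the
# conditional half. The tensor shapes here are illustrative only.
def _cfg_chunk_example():
    uncond = torch.zeros(1, 77, 2048)
    cond = torch.ones(1, 77, 2048)
    embeds = torch.cat([uncond, cond])  # (2, 77, 2048), CFG-doubled batch
    cond_only = embeds.chunk(2)[-1]     # (1, 77, 2048), conditional half
    assert torch.equal(cond_only, cond)
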

@torch.inference_mode()
def load_pipeline(pipeline=None, model_path="stablediffusionapi/newdream-sdxl-20", cache_dir="/home/sandbox/.cache/") -> StableDiffusionXLPipeline:
    """
    Optimized pipeline loading with better memory management and error handling.
    """
    try:
        if pipeline is None:
            pipeline = StableDiffusionXLPipeline.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                use_safetensors=True,
                variant="fp16",
                cache_dir=cache_dir,
            ).to("cuda")

        # Swap in a DDIM scheduler (with Karras sigmas requested), wrapped so
        # that loss hooks can be attached after warmup.
        pipeline.scheduler = SchedulerWrapper(
            DDIMScheduler.from_config(
                pipeline.scheduler.config,
                use_karras_sigmas=True,
            )
        )

        # Compile with onediffx, then restore a previously compiled graph
        # from the cache directory if one was saved there.
        pipeline = compile_pipe(pipeline)
        load_pipe(pipeline, dir=cache_dir)

        # Warmup: the first pass triggers compilation and exercises the
        # DeepCache-style caching arguments.
        warmup_params = {
            "prompt": "warmup",
            "output_type": "pil",
            "num_inference_steps": 20,
            "guidance_scale": 5.0,
        }

        with torch.cuda.amp.autocast():
            pipeline(
                **warmup_params,
                cache_interval=1,
                cache_layer_id=1,
                cache_block_id=0,
            )

        # Attach the loss hooks, then run two more passes to warm them up.
        pipeline.scheduler.prepare_loss()

        for _ in range(2):
            with torch.cuda.amp.autocast():
                pipeline(**warmup_params)

        # Release memory held by the warmup passes.
        torch.cuda.empty_cache()
        gc.collect()

        return pipeline

    except Exception as e:
        print(f"Pipeline loading failed: {e}")
        raise

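
# `save_pipe` is imported above but never invoked; a hedged sketch of how the
# compiled graph could be persisted so later `load_pipe` calls skip
# recompilation (assuming onediffx's save_pipe(pipe, dir=...) mirrors
# load_pipe) would add, after the warmup passes in load_pipeline:
#
#     save_pipe(pipeline, dir=cache_dir)
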

@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline) -> tuple[Image, float]:
    """
    Optimized inference function with better error handling and performance.
    """
    try:
        # Seeded generator for reproducibility; None falls back to random.
        generator = None if request.seed is None else Generator("cuda").manual_seed(request.seed)

        # The switch point derives from the step count of the pipeline's most
        # recent (warmup) run, exposed via `num_timesteps`.
        callback = OptimizedCallback(pipeline.num_timesteps)

        # Time the generation on-device with CUDA events.
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)
        start_time.record()

        with torch.cuda.amp.autocast():
            with nvtx.range("inference"):
                image = pipeline(
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt,
                    generator=generator,
                    num_inference_steps=2,
                    cache_interval=1,
                    cache_layer_id=1,
                    cache_block_id=0,
                    eta=1.0,
                    guidance_scale=5.0,
                    guidance_rescale=0.0,
                    callback_on_step_end=callback,
                    callback_on_step_end_tensor_inputs=[
                        "prompt_embeds",
                        "add_text_embeds",
                        "add_time_ids",
                    ],
                ).images[0]

        end_time.record()
        torch.cuda.synchronize()

        # Convert milliseconds to seconds.
        generation_time = start_time.elapsed_time(end_time) / 1000

        return image, generation_time

    except Exception as e:
        print(f"Inference failed: {e}")
        raise
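

# A minimal usage sketch, assuming TextToImageRequest exposes `prompt`,
# `negative_prompt`, and `seed` fields; the output path is illustrative.
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(
        prompt="a photo of an astronaut riding a horse on the moon",
        negative_prompt=None,
        seed=42,
    )
    image, seconds = infer(request, pipe)
    print(f"Generated image in {seconds:.2f}s")
    image.save("output.png")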