# src/pipeline_bkp.py
import gc

import torch
from diffusers import DDIMScheduler, StableDiffusionXLPipeline
from PIL.Image import Image
from torch.cuda import nvtx
from torch.cuda.amp import autocast

from loss import SchedulerWrapper
from onediffx import compile_pipe, load_pipe, save_pipe
from pipelines.models import TextToImageRequest
class OptimizedCallback:
    """Optimized callback handler with a pre-computed switch index."""

    def __init__(self, num_timesteps):
        # Switch off classifier-free guidance for roughly the last 12% of steps.
        self.switch_point = int(num_timesteps * 0.88)

    def __call__(self, pipe, step_index, timestep, callback_kwargs):
        if step_index == self.switch_point:
            # Drop the unconditional half of each batched tensor so the
            # remaining steps run a single (conditional) forward pass.
            for key in ("prompt_embeds", "add_text_embeds", "add_time_ids"):
                if key in callback_kwargs:
                    callback_kwargs[key] = callback_kwargs[key].chunk(2)[-1]
            # A guidance scale below 1 makes the pipeline skip CFG entirely.
            pipe._guidance_scale = 0.1
        return callback_kwargs
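# Illustrative sketch (not executed by the pipeline): under classifier-free
# guidance the pipeline stacks [unconditional, conditional] along the batch
# dimension, so `chunk(2)[-1]` keeps only the conditional half. Shapes below
# assume SDXL's default prompt embeddings and are for illustration only.
#
#     embeds = torch.randn(2, 77, 2048)   # [uncond, cond] prompt embeddings
#     cond_only = embeds.chunk(2)[-1]     # -> shape (1, 77, 2048)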
@torch.inference_mode()
def load_pipeline(
    pipeline=None,
    model_path="stablediffusionapi/newdream-sdxl-20",
    cache_dir="/home/sandbox/.cache/",
) -> StableDiffusionXLPipeline:
    """
    Optimized pipeline loading with better memory management and error handling.
    """
    try:
        if pipeline is None:
            # Load fp16 weights under autocast for lower peak memory.
            with autocast(enabled=True):
                pipeline = StableDiffusionXLPipeline.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    use_safetensors=True,  # faster loading
                    variant="fp16",
                    cache_dir=cache_dir,
                ).to("cuda")

        # Optional memory optimizations (disabled: they trade speed for VRAM)
        # pipeline.enable_model_cpu_offload()
        # pipeline.enable_vae_slicing()

        # Wrap a DDIM scheduler so the loss hooks can observe each step.
        pipeline.scheduler = SchedulerWrapper(
            DDIMScheduler.from_config(
                pipeline.scheduler.config,
                use_karras_sigmas=True,  # better sample quality
            )
        )

        # Compile with onediff, then restore a previously compiled graph.
        pipeline = compile_pipe(pipeline)
        load_pipe(pipeline, dir=cache_dir)

        # Warmup runs to trigger compilation and fill caches.
        warmup_params = {
            "prompt": "warmup",
            "output_type": "pil",
            "num_inference_steps": 20,
            "guidance_scale": 5.0,
        }

        # First warmup with deep-cache options enabled.
        with torch.cuda.amp.autocast():
            pipeline(
                **warmup_params,
                cache_interval=1,
                cache_layer_id=1,
                cache_block_id=0,
            )

        # Prepare the loss hooks, then batch the remaining warmup runs.
        pipeline.scheduler.prepare_loss()
        for _ in range(2):
            with torch.cuda.amp.autocast():
                pipeline(**warmup_params)

        # Release warmup allocations.
        torch.cuda.empty_cache()
        gc.collect()

        return pipeline
    except Exception as e:
        print(f"Pipeline loading failed: {e}")
        raise
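# `save_pipe` is imported above but never called; a plausible counterpart to
# the `load_pipe` call (a sketch, assuming onediffx's save/load pair shares
# the same `dir` argument) would persist the compiled graph after warmup:
#
#     pipeline = load_pipeline()
#     save_pipe(pipeline, dir="/home/sandbox/.cache/")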
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline):
    """
    Optimized inference function with better error handling and performance.
    """
    try:
        # Seeded generator for reproducible outputs.
        generator = (
            None
            if request.seed is None
            else torch.Generator("cuda").manual_seed(request.seed)
        )

        # Callback that disables guidance near the end of denoising; uses the
        # timestep count recorded by the pipeline's most recent run.
        callback = OptimizedCallback(pipeline.num_timesteps)

        # Record execution time with CUDA events.
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)
        start_time.record()

        # Run inference with optimized settings.
        with torch.cuda.amp.autocast():
            with nvtx.range("inference"):  # for profiling
                image = pipeline(
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt,
                    generator=generator,
                    num_inference_steps=2,
                    cache_interval=1,
                    cache_layer_id=1,
                    cache_block_id=0,
                    eta=1.0,
                    guidance_scale=5.0,
                    guidance_rescale=0.0,
                    callback_on_step_end=callback,
                    callback_on_step_end_tensor_inputs=[
                        "prompt_embeds",
                        "add_text_embeds",
                        "add_time_ids",
                    ],
                ).images[0]

        end_time.record()
        torch.cuda.synchronize()
        generation_time = start_time.elapsed_time(end_time) / 1000  # ms -> s

        return image, generation_time
    except Exception as e:
        print(f"Inference failed: {e}")
        raise
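

# Minimal smoke test; a sketch only: it assumes TextToImageRequest accepts
# these keyword fields, and it requires a CUDA machine with the model cached.
if __name__ == "__main__":
    pipe = load_pipeline()
    req = TextToImageRequest(
        prompt="a watercolor painting of a lighthouse at dawn",
        negative_prompt="blurry, low quality",
        seed=42,
    )
    img, seconds = infer(req, pipe)
    img.save("sample.png")
    print(f"generated in {seconds:.2f}s")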