File size: 3,145 Bytes
2cbdeaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import torch
from pathlib import Path
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, AutoencoderTiny
from autoencoder_kl import AutoencoderKL
from pipelines.models import TextToImageRequest
from torch import Generator
from cache_diffusion import cachify
from trt_pipeline.deploy import load_unet_trt
from loss import SchedulerWrapper

# UNet sub-module name fragments that must NEVER be served from cache
# (presumably the blocks most sensitive to stale activations — not verifiable here).
no_cache_blk = ["down_blocks.2", "up_blocks.0", "mid_block"]

# cache_diffusion config: cache every module except those matching
# `no_cache_blk`, and only reuse cached activations on even steps from
# step 8 onward.  `any(...)` takes a generator — no throwaway list.
SDXL_DEFAULT_CONFIG = [{
    "wildcard_or_filter_func": lambda name: any(blk in name for blk in no_cache_blk),
    "select_cache_step_func": lambda step: (step % 2 == 0) and (step >= 8),
}]

# User home directory; raises KeyError at import time if HOME is unset.
HOME = os.environ["HOME"]

def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    """Shrink the CFG batch once 75% of the denoising steps are done.

    At exactly the 75% step, each conditioning tensor is halved along its
    batch dimension — keeping the second chunk (presumably the conditional
    half of the classifier-free-guidance pair) — and the pipeline's guidance
    scale is dropped to 1.1, cheapening the remaining UNet calls.
    Returns the (possibly modified) callback kwargs, as diffusers expects.
    """
    switch_step = int(pipe.num_timesteps * 0.75)
    if step_index == switch_step:
        for key in ("prompt_embeds", "add_text_embeds", "add_time_ids"):
            # keep only the last of the two batch chunks
            callback_kwargs[key] = callback_kwargs[key].chunk(2)[-1]
        pipe._guidance_scale = 1.1
    return callback_kwargs

def load_pipeline() -> StableDiffusionXLPipeline:
    """Build and warm up the SDXL pipeline used by `infer`.

    Loads the newdream-sdxl-20 checkpoint in fp16, wraps its scheduler,
    swaps the UNet for a prebuilt TensorRT engine, then runs warm-up
    generations with deep-caching enabled so later requests hit steady state.
    NOTE(review): statement order matters here — the TRT engine is loaded
    before any warm-up, and cachify is prepared only after one plain run.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stablediffusionapi/newdream-sdxl-20", 
        torch_dtype=torch.float16,
        use_safetensors=True
    ) 
    # Wrap a DDIM scheduler built from the existing scheduler's config;
    # SchedulerWrapper is a project type (see loss.py) — its extra behavior
    # (prepare_loss below, and the custom `end_cfg` kwarg used in infer)
    # is not visible from this file.
    pipe.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipe.scheduler.config)) 
    pipe = pipe.to("cuda") 
    pipe.scheduler.prepare_loss() 
    # Snapshot directory of the pre-serialized TensorRT engine in the local
    # HF hub cache (repo slobers/cancer, pinned revision).
    ENGINE_PATH = f"{HOME}/.cache/huggingface/hub/models--slobers--cancer/snapshots/209cecbed645ffa913ebaefc115029021a0fa230"
    try: 
        # Best-effort removal of the stray .gitattributes file in the
        # snapshot dir — presumably so only engine files remain; a failure
        # (e.g. already removed) is logged and ignored.
        file_path = os.path.join(ENGINE_PATH, ".gitattributes")
        os.remove(file_path)
    except Exception as err:
        print(err)
        pass
    # Replace the fp16 UNet with the TensorRT engine (in-place on pipe.unet).
    load_unet_trt(
        pipe.unet,
        engine_path=Path(ENGINE_PATH),
        batch_size=1,
    )

    # One plain warm-up run before caching is configured.
    pipe(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus")
    # Configure deep-caching, then warm it up for 5 runs using the same
    # dynamic-CFG callback that production inference relies on.
    cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipe)
    for _ in range(5):
        pipe(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", 
                callback_on_step_end=callback_dynamic_cfg,
                callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids']
)
    # Leave caching disabled; infer() re-enables it per request.
    cachify.disable(pipe)
    return pipe

def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate one image for `request` with deep-caching enabled for the call.

    A seeded torch Generator is used when the request carries a seed so the
    output is reproducible; otherwise generation is left unseeded.
    `end_cfg` is a non-standard kwarg (presumably consumed by the project's
    SchedulerWrapper — not verifiable from this file).
    """
    generator = (
        None
        if request.seed is None
        else Generator(pipeline.device).manual_seed(request.seed)
    )
    cachify.enable(pipeline)
    result = pipeline(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        width=request.width,
        height=request.height,
        generator=generator,
        num_inference_steps=15,
        end_cfg=0.5,
        eta=1.0,
        guidance_scale=5.0,
        guidance_rescale=0.0,
    )
    image = result.images[0]
    cachify.disable(pipeline)
    return image