RobertML commited on
Commit
a8fe2a7
·
verified ·
1 Parent(s): 34a8dbb

Add files using upload-large-folder tool

Browse files
Files changed (1) hide show
  1. src/pipeline.py +51 -0
src/pipeline.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pathlib import Path
3
+ from PIL.Image import Image
4
+ from diffusers import StableDiffusionXLPipeline, DDIMScheduler
5
+ from pipelines.models import TextToImageRequest
6
+ from torch import Generator
7
+ from cache_diffusion import cachify
8
+ from pipe.deploy import compile
9
+
10
+ generator = Generator(torch.device("cuda")).manual_seed(6969)
11
+ prompt = "Make submissions great again"
12
+ SDXL_DEFAULT_CONFIG = [
13
+ {
14
+ "wildcard_or_filter_func": lambda name: "down_blocks.2" not in name and"down_blocks.3" not in name and "up_blocks.2" not in name,
15
+ "select_cache_step_func": lambda step: (step % 2 != 0) and (step >= 14),
16
+ }]
17
+ def load_pipeline() -> StableDiffusionXLPipeline:
18
+ pipe = StableDiffusionXLPipeline.from_pretrained(
19
+ "models/newdream-sdxl-20",torch_dtype=torch.float16, use_safetensors=True, local_files_only=True
20
+ ).to("cuda")
21
+ compile(
22
+ pipe,
23
+ onnx_path=Path("./onnx"),
24
+ engine_path=Path("./engine"),
25
+ batch_size=1,
26
+ )
27
+ cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)
28
+ cachify.enable(pipe)
29
+ with cachify.infer(pipe) as cached_pipe:
30
+ for _ in range(4):
31
+ pipe(prompt=prompt, num_inference_steps=20)
32
+ cachify.disable(pipe)
33
+ return pipe
34
+
35
+ def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
36
+
37
+ if request.seed is None:
38
+ generator = None
39
+ else:
40
+ generator = Generator(pipeline.device).manual_seed(request.seed)
41
+ cachify.enable(pipeline)
42
+ with cachify.infer(pipeline) as cached_pipe:
43
+ image = cached_pipe(
44
+ prompt=request.prompt,
45
+ negative_prompt=request.negative_prompt,
46
+ width=request.width,
47
+ height=request.height,
48
+ generator=generator,
49
+ num_inference_steps=20,
50
+ ).images[0]
51
+ return image