RobertML committed on
Commit
335de76
·
verified ·
1 Parent(s): 5aec2d3

Add files using upload-large-folder tool

Browse files
Files changed (1) hide show
  1. src/pipeline.py +54 -0
src/pipeline.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pathlib import Path
3
+ from PIL.Image import Image
4
+ from diffusers import StableDiffusionXLPipeline, DDIMScheduler
5
+ from pipelines.models import TextToImageRequest
6
+ from torch import Generator
7
+ from cache_diffusion import cachify
8
+ from pipe.deploy import compile
9
+ from loss import SchedulerWrapper
10
+
11
# Module-level CUDA RNG with a fixed seed so warmup runs are reproducible.
# NOTE(review): the function `infer` below builds its own per-request
# generator; this module-level one appears unused in the visible code — confirm.
generator = Generator(torch.device("cuda")).manual_seed(6969)

# Prompt used only for the warmup passes in `load_pipeline`.
prompt = "Make submissions great again"

# Configuration consumed by `cachify.prepare` (from the external
# `cache_diffusion` package — exact semantics defined there, TODO confirm):
# - "wildcard_or_filter_func": selects UNet submodules by name; excludes any
#   module whose name contains "down_blocks.2", "down_blocks.3", or
#   "up_blocks.2" (presumably so those blocks are never cached).
# - "select_cache_step_func": the cache is used only on odd denoising steps
#   from step 8 onward.
SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": lambda name: "down_blocks.2" not in name and "down_blocks.3" not in name and "up_blocks.2" not in name,
        "select_cache_step_func": lambda step: (step % 2 != 0) and (step >= 8),
    }]
18
def load_pipeline() -> StableDiffusionXLPipeline:
    """Build, compile, cache-enable, and warm up the SDXL pipeline.

    Steps (order matters — each call mutates `pipe` in place):
      1. Load the fp16 SDXL weights from the local `models/newdream-sdxl-20`
         directory (no network access: `local_files_only=True`) onto CUDA.
      2. `compile` (project helper from `pipe.deploy`) exports ONNX and builds
         an inference engine under ./onnx and ./engine at batch size 1.
      3. Prepare and enable step-caching via `cachify` using
         SDXL_DEFAULT_CONFIG.
      4. Replace the scheduler with the project `SchedulerWrapper` around a
         DDIM scheduler derived from the current scheduler's config.
      5. Run 4 warmup generations (20 steps each) with the module-level
         `prompt`, then let the wrapped scheduler compute its loss state.
      6. Disable caching before handing the pipeline back; `infer` re-enables
         it per request.

    Returns:
        The fully prepared StableDiffusionXLPipeline, resident on CUDA.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "models/newdream-sdxl-20", torch_dtype=torch.float16, use_safetensors=True, local_files_only=True
    ).to("cuda")
    compile(
        pipe,
        onnx_path=Path("./onnx"),
        engine_path=Path("./engine"),
        batch_size=1,
    )
    cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipe)
    pipe.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipe.scheduler.config))
    # NOTE(review): `cached_pipe` is bound but the warmup calls `pipe`
    # directly — confirm whether the warmup was meant to go through
    # `cached_pipe` (they may be the same object; semantics live in
    # `cache_diffusion`).
    with cachify.infer(pipe) as cached_pipe:
        for _ in range(4):
            pipe(prompt=prompt, num_inference_steps=20)
    pipe.scheduler.prepare_loss()
    cachify.disable(pipe)
    return pipe
37
+
38
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate a single image for *request* with the prepared SDXL pipeline.

    Caching is (re-)enabled for the duration of the call, and the pipeline is
    run for a fixed 13 inference steps. When the request carries a seed, a
    device-local generator is seeded with it for reproducible output;
    otherwise the pipeline draws fresh randomness.
    """
    # Seeded generator only when the caller pinned a seed.
    rng = (
        None
        if request.seed is None
        else Generator(pipeline.device).manual_seed(request.seed)
    )
    cachify.enable(pipeline)
    with cachify.infer(pipeline) as cached:
        result = cached(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            generator=rng,
            num_inference_steps=13,
        )
    return result.images[0]