chansung commited on
Commit
72c9a82
·
1 Parent(s): 0c40cdd

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +86 -0
handler.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import sys
3
+ import base64
4
+ import logging
5
+ import copy
6
+
7
+ import numpy as np
8
+ import torch
9
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
10
+
11
+ class ReusablePipePool:
12
+ def __init__(
13
+ self,
14
+ size,
15
+ model_base="runwayml/stable-diffusion-v1-5"
16
+ ):
17
+ self._reusablePipes = []
18
+ for i in range(size):
19
+ pipe = StableDiffusionPipeline.from_pretrained(
20
+ model_base, torch_dtype=torch.float16
21
+ )
22
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
23
+ self._reusablePipes.append(pipe)
24
+
25
+ if not self.empty():
26
+ self.original_unet = copy.deepcopy(self._reusablePipes[0].unet)
27
+
28
+ def acquire(self):
29
+ return self._reusablePipes.pop()
30
+
31
+ def release(self, reusablePipe):
32
+ self._reusablePipes.append(reusablePipe)
33
+
34
+ def empty(self):
35
+ return len(self._reusablePipes) == 0
36
+
37
+ class EndpointHandler():
38
+ def __init__(self, path=""):
39
+ self.pool = ReusablePipePool(2)
40
+
41
+ def _generate_images(
42
+ self,
43
+ model_path,
44
+ prompt,
45
+ num_inference_steps=25,
46
+ guidance_scale=7.5,
47
+ num_images_per_prompt=1):
48
+
49
+ reusablePipe = None
50
+ while not self.pool.empty():
51
+ reusablePipe = self.pool.acquire()
52
+
53
+ if model_path == "base":
54
+ reusablePipe.unet = copy.deepcopy(self.pool.original_unet)
55
+ else:
56
+ reusablePipe.unet.load_attn_procs(model_path)
57
+ reusablePipe.to("cuda")
58
+
59
+ pil_images = reusablePipe(
60
+ prompt=prompt,
61
+ num_inference_steps=num_inference_steps,
62
+ guidance_scale=guidance_scale,
63
+ num_images_per_prompt=num_images_per_prompt).images
64
+
65
+ self.pool.release(reusablePipe)
66
+
67
+ np_images = []
68
+ for i in range(len(pil_images)):
69
+ np_images.append(np.asarray(pil_images[i]))
70
+
71
+ return np.stack(np_images, axis=0)
72
+
73
+ def __call__(self, data: Dict[str, Any]) -> str:
74
+ prompt = data.pop("inputs", "test image")
75
+ model_path = data.pop("model_path", "base")
76
+
77
+ num_inference_steps = data.pop("num_inference_steps", 25)
78
+ guidance_scale = data.pop("guidance_scale", 7.5)
79
+ num_images_per_prompt = data.pop("num_images_per_prompt", 1)
80
+
81
+ images = self._generate_images(
82
+ model_path, prompt,
83
+ num_inference_steps, guidance_scale, num_images_per_prompt
84
+ )
85
+
86
+ return base64.b64encode(images.tobytes()).decode()