viba98 committed on
Commit
8cc5118
·
1 Parent(s): 79579ce

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +49 -0
handler.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from torch import autocast
4
+ from diffusers import StableDiffusionPipeline
5
+ import base64
6
+ from io import BytesIO
7
+
8
+
9
# Resolve the compute device once at import time, preferring CUDA when present.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Fail fast at import: this module refuses to load on a machine without a GPU.
if device.type != 'cuda':
    raise ValueError("need to run on GPU")
14
+
15
class EndpointHandler():
    """Request handler that renders a text prompt into a base64-encoded PNG.

    The Stable Diffusion pipeline is loaded once in ``__init__`` and reused
    for every call.
    """

    def __init__(self, path=""):
        """Load the pipeline stored at ``path`` and move it to ``device``.

        Args:
            path: Location handed to ``StableDiffusionPipeline.from_pretrained``.
        """
        # Weights are loaded in float16, matching the autocast block in __call__.
        self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
        self.pipe = self.pipe.to(device)

    def __call__(self, data: Any) -> Dict[str, str]:
        """Generate one image for the request and return it base64-encoded.

        Args:
            data: Request payload. The prompt is taken from its ``"inputs"``
                entry (which is removed from ``data``); when that key is
                absent, ``data`` itself is passed to the pipeline.

        Returns:
            A dict ``{"image": <str>}`` holding the base64 text of the PNG.
        """
        inputs = data.pop("inputs", data)

        # Only a single image is returned, so only a single image is requested;
        # asking for more would spend GPU time and memory on discarded samples.
        with autocast(device.type):
            output = self.pipe(
                inputs,
                guidance_scale=7.5,
                num_inference_steps=100,
                num_images_per_prompt=1,
            )

        # Current diffusers releases expose the results as ``output.images``;
        # early releases returned a mapping with a ``"sample"`` entry instead.
        images = getattr(output, "images", None)
        if images is None:
            images = output["sample"]
        image = images[0]

        # Serialize the image as PNG in memory, then base64-encode the bytes.
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue())

        return {"image": img_str.decode()}