raza2 commited on
Commit
a00b664
·
1 Parent(s): f64541b

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +48 -0
handler.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from torch import autocast
4
+ from diffusers import StableDiffusionPipeline
5
+ import base64
6
+ from io import BytesIO
7
+ # from transformers.utils import logging
8
+
9
+ # logging.set_verbosity_info()
10
+ # logger = logging.get_logger("transformers")
11
+
12
# Select the compute device: prefer CUDA, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# This handler only supports GPU inference (fp16 pipeline); fail fast otherwise.
if device.type != "cuda":
    raise ValueError("need to run on GPU")
18
class EndpointHandler():
    """Inference-endpoint handler wrapping a Stable Diffusion text-to-image pipeline.

    Loads the pipeline once at startup and, per request, renders one image for
    the given prompt and returns it as a base64-encoded JPEG.
    """

    def __init__(self, path=""):
        """Load the fp16 Stable Diffusion pipeline from `path` and move it to the GPU.

        Args:
            path (str): local directory or model id passed to
                `StableDiffusionPipeline.from_pretrained`.
        """
        # load the optimized model in half precision
        self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
        self.pipe = self.pipe.to(device)

    def __call__(self, data: Any) -> Dict[str, str]:
        """Run one inference request.

        Args:
            data (dict): request payload; the prompt is read from data["inputs"]
                (falls back to `data` itself when the key is absent).

        Returns:
            dict: {"image": <base64-encoded JPEG string>, "isRunning": "true"}
        """
        inputs = data.pop("inputs", data)

        # run inference pipeline under autocast for the fp16 weights
        with autocast(device.type):
            # BUGFIX: original line was `self.pipe(inputs, guidance_scale=20["sample"][0]`
            # — unbalanced parenthesis with the subscript attached to the literal 20.
            # NOTE(review): `["sample"]` is the legacy diffusers output key — confirm
            # against the pinned diffusers version (newer releases use `.images`).
            image = self.pipe(inputs, guidance_scale=20)["sample"][0]

        # encode image as base 64
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())

        # postprocess the prediction
        return {"image": img_str.decode(), "isRunning": "true"}