Alex Mikulaniec commited on
Commit
cf991ea
·
1 Parent(s): 4766aba

Added handler.py to manage the model’s endpoints

Browse files
Files changed (1) hide show
  1. handler.py +39 -0
handler.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ import torch
3
+ from torch.cuda.amp import autocast
4
+ from diffusers import StableDiffusionPipeline
5
+ import base64
6
+ from io import BytesIO
7
+
8
+ # Setting the device
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+ class EndpointHandler():
12
+
13
+ def __init__(self, path=""):
14
+ # Load the model
15
+ self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
16
+ self.pipe = self.pipe.to(device)
17
+
18
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
19
+ """
20
+ Args:
21
+ data (dict): Includes the input data for inference.
22
+
23
+ Return:
24
+ dict: Base64 encoded image.
25
+ """
26
+ inputs = data.get("inputs") # Getting the inputs from the data dictionary
27
+
28
+ # Run inference pipeline
29
+ with autocast():
30
+ output = self.pipe(inputs, guidance_scale=7.5)
31
+ image = output['images'][0] # Accessing the 'images' key in the output
32
+
33
+ # Encoding image as base 64
34
+ buffered = BytesIO()
35
+ image.save(buffered, format="PNG")
36
+ img_str = base64.b64encode(buffered.getvalue())
37
+
38
+ # Returning the base64 image as a dictionary
39
+ return {"image": img_str.decode()}