justin-shopcapsule commited on
Commit
32f0f58
·
1 Parent(s): a898dc0

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +42 -0
handler.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ from transformers import AutoModelForSemanticSegmentation, AutoFeatureExtractor
5
+ import base64
6
+ import torch
7
+ from torch import nn
8
+
9
+ from RealESRGAN import RealESRGAN
10
+
11
+
12
+ class EndpointHandler():
13
+ def __init__(self, path="."):
14
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ self.model = RealESRGAN(device, scale=4).to(self.device)
16
+ self.model.load_weights('./RealESRGAN_x4.pth', download=True)
17
+
18
+
19
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
20
+ """
21
+ data args:
22
+ images (:obj:`PIL.Image`)
23
+ candiates (:obj:`list`)
24
+ Return:
25
+ A :obj:`list`:. The list contains items that are dicts should be liked {"label": "XXX", "score": 0.82}
26
+ """
27
+ inputs = data.pop("inputs", data)
28
+
29
+ # decode base64 image to PIL
30
+ image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
31
+
32
+ # forward pass
33
+ output_image = self.model.predict(image)
34
+
35
+ # base64 encode output
36
+ buffered = BytesIO()
37
+ output_image = output_image.convert('RGB')
38
+ output_image.save(buffered, format="png")
39
+ img_str = base64.b64encode(buffered.getvalue())
40
+
41
+ # postprocess the prediction
42
+ return {"image": img_str.decode()}