tested and added
Browse files- handler.py +21 -12
handler.py
CHANGED
|
@@ -23,31 +23,40 @@ class EndpointHandler():
|
|
| 23 |
self.pipe = self.pipe.to(device)
|
| 24 |
|
| 25 |
|
| 26 |
-
def __call__(self, data: Any) ->
|
| 27 |
"""
|
| 28 |
Args:
|
| 29 |
-
data (
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
"""
|
| 34 |
-
|
| 35 |
negative_prompt = data.pop("negative_prompt", None)
|
| 36 |
height = data.pop("height", 512)
|
| 37 |
width = data.pop("width", 512)
|
|
|
|
| 38 |
guidance_scale = data.pop("guidance_scale", 7.5)
|
| 39 |
|
| 40 |
-
#
|
| 41 |
with autocast(device.type):
|
| 42 |
if negative_prompt is None:
|
| 43 |
-
image = self.pipe(prompt
|
|
|
|
| 44 |
else:
|
| 45 |
-
image = self.pipe(prompt
|
|
|
|
| 46 |
|
| 47 |
-
#
|
| 48 |
buffered = BytesIO()
|
| 49 |
image.save(buffered, format="JPEG")
|
| 50 |
img_str = base64.b64encode(buffered.getvalue())
|
| 51 |
|
| 52 |
-
#
|
| 53 |
-
return {"image": img_str.decode()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
self.pipe = self.pipe.to(device)
|
| 24 |
|
| 25 |
|
| 26 |
+
def __call__(self, data: Any) -> Dict[str, str]:
|
| 27 |
"""
|
| 28 |
Args:
|
| 29 |
+
data (Any): Includes the input data and the parameters for the inference.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
Dict[str, str]: Dictionary with the base64 encoded image.
|
| 33 |
"""
|
| 34 |
+
positive_prompt = data.pop("positive_prompt", "")
|
| 35 |
negative_prompt = data.pop("negative_prompt", None)
|
| 36 |
height = data.pop("height", 512)
|
| 37 |
width = data.pop("width", 512)
|
| 38 |
+
|
| 39 |
guidance_scale = data.pop("guidance_scale", 7.5)
|
| 40 |
|
| 41 |
+
# Run inference pipeline
|
| 42 |
with autocast(device.type):
|
| 43 |
if negative_prompt is None:
|
| 44 |
+
image = self.pipe(prompt=positive_prompt, height=height, width=width, guidance_scale=float(guidance_scale))
|
| 45 |
+
image = image.images[0]
|
| 46 |
else:
|
| 47 |
+
image = self.pipe(prompt=positive_prompt, negative_prompt=negative_prompt, height=height, width=width, guidance_scale=float(guidance_scale))
|
| 48 |
+
image = image.images[0]
|
| 49 |
|
| 50 |
+
# Encode image as base64
|
| 51 |
buffered = BytesIO()
|
| 52 |
image.save(buffered, format="JPEG")
|
| 53 |
img_str = base64.b64encode(buffered.getvalue())
|
| 54 |
|
| 55 |
+
# Postprocess the prediction
|
| 56 |
+
return {"image": img_str.decode()}
|
| 57 |
+
|
| 58 |
+
def decode_base64_image(self, image_string):
|
| 59 |
+
base64_image = base64.b64decode(image_string)
|
| 60 |
+
buffer = BytesIO(base64_image)
|
| 61 |
+
image = Image.open(buffer)
|
| 62 |
+
return image
|