- handler.py +24 -12
handler.py
CHANGED
|
@@ -1,16 +1,21 @@
|
|
| 1 |
from typing import Dict, List, Any
|
| 2 |
from transformers import pipeline
|
| 3 |
-
from diffusers import AutoPipelineForText2Image
|
| 4 |
import torch
|
| 5 |
import base64
|
| 6 |
from io import BytesIO
|
| 7 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class EndpointHandler():
|
| 11 |
-
def __init__(self, path=""):
|
| 12 |
-
self.pipe =
|
| 13 |
self.pipe.to("cuda")
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 16 |
"""
|
|
@@ -23,7 +28,6 @@ class EndpointHandler():
|
|
| 23 |
# get inputs
|
| 24 |
inputs = data.pop("inputs", data)
|
| 25 |
encoded_image = data.pop("image", None)
|
| 26 |
-
encoded_mask_image = data.pop("mask_image", None)
|
| 27 |
|
| 28 |
# hyperparamters
|
| 29 |
num_inference_steps = data.pop("num_inference_steps", 25)
|
|
@@ -31,20 +35,30 @@ class EndpointHandler():
|
|
| 31 |
negative_prompt = data.pop("negative_prompt", None)
|
| 32 |
height = data.pop("height", None)
|
| 33 |
width = data.pop("width", None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# process image
|
| 36 |
-
if encoded_image is not None
|
| 37 |
-
image = self.decode_base64_image(encoded_image)
|
| 38 |
-
mask_image = self.decode_base64_image(encoded_mask_image)
|
| 39 |
else:
|
| 40 |
image = None
|
| 41 |
-
|
| 42 |
|
| 43 |
# run inference pipeline
|
| 44 |
out = self.pipe(inputs,
|
| 45 |
image=image,
|
| 46 |
-
|
| 47 |
num_inference_steps=num_inference_steps,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
guidance_scale=guidance_scale,
|
| 49 |
num_images_per_prompt=1,
|
| 50 |
negative_prompt=negative_prompt,
|
|
@@ -60,6 +74,4 @@ class EndpointHandler():
|
|
| 60 |
base64_image = base64.b64decode(image_string)
|
| 61 |
buffer = BytesIO(base64_image)
|
| 62 |
image = Image.open(buffer)
|
| 63 |
-
return image
|
| 64 |
-
|
| 65 |
-
|
|
|
|
| 1 |
from typing import Dict, List, Any
|
| 2 |
from transformers import pipeline
|
|
|
|
| 3 |
import torch
|
| 4 |
import base64
|
| 5 |
from io import BytesIO
|
| 6 |
from PIL import Image
|
| 7 |
+
from diffusers import StableDiffusionXLImg2ImgPipeline
|
| 8 |
+
from diffusers.utils import load_image
|
| 9 |
+
|
| 10 |
|
| 11 |
|
| 12 |
class EndpointHandler():
|
| 13 |
+
def __init__(self, path=""):
    """Load the SDXL-Turbo img2img pipeline and move it to the GPU.

    Args:
        path: Local checkout path handed in by the inference toolkit
              (ignored here — the model is pulled from the Hub by id).
    """
    # fp16 weights/variant keep VRAM usage roughly half of fp32.
    self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
    )
    self.pipe.to("cuda")
    # Compile the UNet for lower per-step latency; the first request
    # pays the one-time compilation warm-up cost.
    self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)
    # Run the VAE in fp32 — the SDXL VAE is numerically unstable in fp16.
    self.pipe.upcast_vae()
|
| 19 |
|
| 20 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 21 |
"""
|
|
|
|
| 28 |
# get inputs
|
| 29 |
inputs = data.pop("inputs", data)
|
| 30 |
encoded_image = data.pop("image", None)
|
|
|
|
| 31 |
|
| 32 |
# hyperparamters
|
| 33 |
num_inference_steps = data.pop("num_inference_steps", 25)
|
|
|
|
| 35 |
negative_prompt = data.pop("negative_prompt", None)
|
| 36 |
height = data.pop("height", None)
|
| 37 |
width = data.pop("width", None)
|
| 38 |
+
|
| 39 |
+
# Optional img2img hyperparameters, with defaults for absent keys.
strength = data.pop("strength", 0.7)
denoising_start = data.pop("denoising_start_step", 0)
# FIX: this previously popped "denoising_start_step" again (copy-paste bug),
# so a caller-supplied end step was silently ignored.
denoising_end = data.pop("denoising_end_step", 0)
num_images_per_prompt = data.pop("num_images_per_prompt", 1)
aesthetic_score = data.pop("aesthetic_score", 0.6)
|
| 44 |
+
|
| 45 |
|
| 46 |
# Decode the optional base64-encoded init image; when the request carries
# no image, fall back to None (text-only generation).
image = None
if encoded_image is not None:
    image = self.decode_base64_image(encoded_image)
|
| 51 |
+
|
| 52 |
|
| 53 |
# run inference pipeline
|
| 54 |
out = self.pipe(inputs,
|
| 55 |
image=image,
|
| 56 |
+
strenght=strength,
|
| 57 |
num_inference_steps=num_inference_steps,
|
| 58 |
+
denoising_start_step=denoising_start,
|
| 59 |
+
denoising_end_step=denoising_end,
|
| 60 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 61 |
+
aesthetic_score=aesthetic_score,
|
| 62 |
guidance_scale=guidance_scale,
|
| 63 |
num_images_per_prompt=1,
|
| 64 |
negative_prompt=negative_prompt,
|
|
|
|
| 74 |
base64_image = base64.b64decode(image_string)
|
| 75 |
buffer = BytesIO(base64_image)
|
| 76 |
image = Image.open(buffer)
|
| 77 |
+
return image
|
|
|
|
|
|