# pix2pix_serving / handler.py
# enesbol's picture
# Update handler.py
# 5300450
from typing import Dict, Any
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
from webcolors import CSS3_HEX_TO_NAMES, hex_to_rgb
from scipy.spatial import KDTree
import base64
from io import BytesIO
from PIL import Image
import json
class EndpointHandler():
    """Callable handler that edits an image from a text instruction.

    The handler wraps the ``timbrooks/instruct-pix2pix`` diffusion pipeline:
    it receives a base64-encoded image plus a prompt and returns the edited
    image, again base64-encoded.
    """

    def __init__(self, path=""):
        """Load the InstructPix2Pix pipeline and move it to the GPU.

        Args:
            path: Not used by this handler; the weights are always loaded
                from the hard-coded ``timbrooks/instruct-pix2pix`` model id.
        """
        model_id = "timbrooks/instruct-pix2pix"
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
        self.pipe.to("cuda")
        # Replace the default scheduler with Euler-ancestral, reusing the
        # configuration of the scheduler the pipeline was loaded with.
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Run one image edit.

        Args:
            data: Request payload. ``data["inputs"]`` must be a mapping with
                an ``"image"`` entry (base64-encoded image bytes) and a
                ``"text"`` entry (the edit instruction).

        Returns:
            ``{"image": <base64-encoded PNG of the edited image>}``.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
            ValueError: If ``"image"`` or ``"text"`` is missing from
                ``data["inputs"]``.
        """
        info = data["inputs"]

        # Read without popping so the caller's payload is left untouched, and
        # fail with a clear message instead of passing a fallback object into
        # the decoder or the pipeline.
        encoded_input = info.get("image")
        prompt = info.get("text")
        if encoded_input is None:
            raise ValueError("'inputs' must contain a base64-encoded 'image'")
        if prompt is None:
            raise ValueError("'inputs' must contain a 'text' prompt")

        source_image = Image.open(BytesIO(base64.b64decode(encoded_input))).convert("RGB")

        edited = self.pipe(
            prompt,
            image=source_image,
            num_inference_steps=25,
            image_guidance_scale=1,
        ).images[0]

        # Encode in memory rather than through a fixed "./1.png" file, so
        # concurrent requests cannot overwrite each other's output and no
        # file is left behind on disk.
        buffer = BytesIO()
        edited.save(buffer, format="PNG")
        encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return {"image": encoded_string}
# NOTE(review): the string literal below is a bare expression statement, so
# the helper methods it contains (process_image_base64, build_prompt,
# hex_to_name) are disabled and never executed. Within the visible file, the
# webcolors and KDTree imports are referenced only from this text.
# If the helpers are ever re-enabled: process_image_base64 calls
# `io.BytesIO`, but the visible imports only provide `BytesIO`
# (`from io import BytesIO`), so `io` would be undefined.
"""
def process_image_base64(self, base64_image_data):
# Decode base64 data to bytes
image_bytes = base64.b64decode(base64_image_data)
# Convert bytes to an image
image = Image.open(io.BytesIO(image_bytes))
image = image.convert("RGB")
image = image.resize((512, 512))
return image
def build_prompt(self, text_prompt, color_code):
color_name = self.hex_to_name(color_code)
coloring_prompt = f" with a {color_name} color applied to only the designated area or key element in the picture by avoiding it becoming the dominant color of the image, leaving the text, logo, and shadows untouched."
result_prompt = f"{text_prompt}{coloring_prompt}"
return result_prompt
def hex_to_name(self, hex_color):
rgb_tuple = tuple(int(hex_color[i:i+2], 16) for i in (1, 3, 5))
names = []
rgb_values = []
for color_hex, color_name in CSS3_HEX_TO_NAMES.items():
names.append(color_name)
rgb_values.append(hex_to_rgb(color_hex))
kdt_db = KDTree(rgb_values)
distance, index = kdt_db.query(rgb_tuple)
return names[index]
"""
# Script entry point: when the file is run directly, build one handler
# instance (which loads the pipeline in its constructor).
if __name__ == "__main__":
    my_handler = EndpointHandler(path=".")