| from typing import Dict, Any | |
| import torch | |
| from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler | |
| from PIL import Image | |
| from webcolors import CSS3_HEX_TO_NAMES, hex_to_rgb | |
| from scipy.spatial import KDTree | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| import json | |
class EndpointHandler():
    """Inference-endpoint handler that edits an image from a text instruction.

    Wraps the ``timbrooks/instruct-pix2pix`` Stable Diffusion pipeline. A
    request supplies a base64-encoded image plus an instruction string, and
    the response carries the edited image as a base64-encoded PNG.
    """

    def __init__(self, path=""):
        """Load the InstructPix2Pix pipeline in float16 and move it to CUDA.

        Args:
            path: Accepted for interface compatibility; this constructor does
                not read it (the model id is hard-coded below).
        """
        model_id = "timbrooks/instruct-pix2pix"
        # safety_checker=None: the pipeline is loaded without a safety checker.
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16, safety_checker=None
        )
        self.pipe.to("cuda")
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )

    def __call__(self, data):
        """Run one image-edit request.

        Args:
            data: Request payload shaped like
                ``{"inputs": {"image": <base64 str>, "text": <instruction>}}``.

        Returns:
            ``{"image": <base64-encoded PNG of the edited image>}``.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
            ValueError: If ``"image"`` or ``"text"`` is missing from the
                inputs.
        """
        info = data['inputs']
        # .get() instead of .pop(): the caller's payload must not be mutated,
        # and a missing key must not silently fall back to the whole request.
        image = info.get("image")
        prompt = info.get("text")
        if image is None or prompt is None:
            raise ValueError(
                "data['inputs'] must contain both 'image' (base64-encoded "
                "image) and 'text' (edit instruction)"
            )

        image = base64.b64decode(image)
        raw_images = Image.open(BytesIO(image)).convert('RGB')
        images = self.pipe(
            prompt,
            image=raw_images,
            num_inference_steps=25,
            image_guidance_scale=1,
        ).images
        img = images[0]

        # Encode in memory. Writing to a fixed "./1.png" and reading it back
        # lets concurrent requests clobber each other's output and requires a
        # writable working directory.
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return {'image': encoded_string}

    # NOTE(review): the string literal below is disabled helper code kept for
    # reference; it is never executed. If it is revived, `io.BytesIO` must
    # become `BytesIO` (or `io` must be imported) -- the imports visible in
    # this file only bring in `BytesIO`.
    """
    def process_image_base64(self, base64_image_data):
        # Decode base64 data to bytes
        image_bytes = base64.b64decode(base64_image_data)
        # Convert bytes to an image
        image = Image.open(io.BytesIO(image_bytes))
        image = image.convert("RGB")
        image = image.resize((512, 512))
        return image

    def build_prompt(self, text_prompt, color_code):
        color_name = self.hex_to_name(color_code)
        coloring_prompt = f" with a {color_name} color applied to only the designated area or key element in the picture by avoiding it becoming the dominant color of the image, leaving the text, logo, and shadows untouched."
        result_prompt = f"{text_prompt}{coloring_prompt}"
        return result_prompt

    def hex_to_name(self, hex_color):
        rgb_tuple = tuple(int(hex_color[i:i+2], 16) for i in (1, 3, 5))
        names = []
        rgb_values = []
        for color_hex, color_name in CSS3_HEX_TO_NAMES.items():
            names.append(color_name)
            rgb_values.append(hex_to_rgb(color_hex))
        kdt_db = KDTree(rgb_values)
        distance, index = kdt_db.query(rgb_tuple)
        return names[index]
    """
if __name__ == "__main__":
    # Script entry point: build the handler once, which loads the pipeline.
    my_handler = EndpointHandler(path=".")