"""Inference endpoint handler wrapping the InstructPix2Pix image-editing model."""

import base64
import json  # NOTE(review): currently unused; kept in case other tooling relies on it
from io import BytesIO
from typing import Any, Dict

import torch
from diffusers import EulerAncestralDiscreteScheduler, StableDiffusionInstructPix2PixPipeline
from PIL import Image
from scipy.spatial import KDTree

# NOTE(review): CSS3_HEX_TO_NAMES was removed in newer webcolors releases;
# pin a compatible webcolors version or migrate to webcolors.names("css3").
from webcolors import CSS3_HEX_TO_NAMES, hex_to_rgb


class EndpointHandler:
    """Edit an image according to a natural-language instruction.

    The InstructPix2Pix pipeline is loaded once at construction.
    Each request is served through ``__call__``.
    """

    def __init__(self, path=""):
        """Load the InstructPix2Pix pipeline onto the GPU.

        Args:
            path: Accepted for the endpoint-handler calling convention. The
                model is always loaded from ``model_id`` below, so this value
                is currently unused.
        """
        model_id = "timbrooks/instruct-pix2pix"
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16, safety_checker=None
        )
        self.pipe.to("cuda")
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Run one edit request.

        Args:
            data: Request payload of the form
                ``{"inputs": {"image": <base64 image bytes>, "text": <instruction>}}``.
                The payload is not mutated.

        Returns:
            ``{"image": <base64-encoded PNG of the edited image>}``.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` entry.
            ValueError: if ``"image"`` or ``"text"`` is missing from ``inputs``.
        """
        info = data["inputs"]
        # .get instead of .pop: don't mutate the caller's payload, and don't
        # silently substitute the whole request dict for a missing field.
        image_b64 = info.get("image")
        prompt = info.get("text")
        if image_b64 is None:
            raise ValueError("request 'inputs' must contain a base64-encoded 'image'")
        if prompt is None:
            raise ValueError("request 'inputs' must contain a 'text' instruction")

        raw_image = Image.open(BytesIO(base64.b64decode(image_b64))).convert("RGB")
        images = self.pipe(
            prompt,
            image=raw_image,
            num_inference_steps=25,
            image_guidance_scale=1,
        ).images
        return {"image": self._encode_png_base64(images[0])}

    @staticmethod
    def _encode_png_base64(img):
        """Serialise a PIL image to PNG in memory and return it as base64 text."""
        # In-memory buffer instead of a fixed "./1.png" on disk: no leftover
        # file and no clobbering between concurrent requests.
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def process_image_base64(self, base64_image_data):
        """Decode base64 image data into a 512x512 RGB PIL image.

        Not used by ``__call__``; kept as a utility.
        """
        image_bytes = base64.b64decode(base64_image_data)
        image = Image.open(BytesIO(image_bytes))
        image = image.convert("RGB")
        return image.resize((512, 512))

    def build_prompt(self, text_prompt, color_code):
        """Append a colouring instruction for ``color_code`` to ``text_prompt``.

        Not used by ``__call__``; kept as a utility.
        """
        color_name = self.hex_to_name(color_code)
        coloring_prompt = f" with a {color_name} color applied to only the designated area or key element in the picture by avoiding it becoming the dominant color of the image, leaving the text, logo, and shadows untouched."
        return f"{text_prompt}{coloring_prompt}"

    def hex_to_name(self, hex_color):
        """Return the CSS3 colour name nearest (in RGB space) to ``hex_color``.

        Args:
            hex_color: A six-digit hex colour such as ``"#ff0000"``. The
                leading ``#`` is optional.
        """
        digits = hex_color.lstrip("#")
        rgb_tuple = tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
        names = list(CSS3_HEX_TO_NAMES.values())
        rgb_values = [tuple(hex_to_rgb(color_hex)) for color_hex in CSS3_HEX_TO_NAMES]
        _, index = KDTree(rgb_values).query(rgb_tuple)
        return names[int(index)]


if __name__ == "__main__":
    my_handler = EndpointHandler(path=".")