File size: 2,533 Bytes
1eb16a3
 
 
 
10b0ac1
 
b9a675b
 
 
 
1eb16a3
b9a675b
1eb16a3
10b0ac1
1eb16a3
5300450
 
 
1eb16a3
b9a675b
 
 
 
 
 
 
1eb16a3
b9a675b
 
 
 
 
1eb16a3
b9a675b
 
972801a
 
 
 
 
 
1eb16a3
 
 
 
 
 
 
 
 
 
10b0ac1
 
1eb16a3
 
 
 
 
 
 
 
 
10b0ac1
b9a675b
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import Dict, Any
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
from webcolors import CSS3_HEX_TO_NAMES, hex_to_rgb
from scipy.spatial import KDTree
import base64
from io import BytesIO
from PIL import Image
import json

class EndpointHandler():
    """Endpoint handler that edits an image according to a text instruction.

    Wraps the ``timbrooks/instruct-pix2pix`` StableDiffusionInstructPix2Pix
    pipeline. Each call takes a base64-encoded input image plus an instruction
    prompt and returns the edited image as a base64-encoded PNG.
    """

    def __init__(self, path=""):
        """Load the InstructPix2Pix pipeline in float16 and move it to CUDA.

        Args:
            path: Currently unused; the model is always loaded from the
                hard-coded ``model_id`` below. Presumably kept because the
                hosting toolkit constructs handlers with a ``path`` argument
                (TODO confirm).
        """
        model_id = "timbrooks/instruct-pix2pix"
        # safety_checker=None: no safety-checker model is loaded.
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16, safety_checker=None
        )
        self.pipe.to("cuda")
        # Replace the default scheduler with Euler-ancestral, reusing the
        # pipeline's existing scheduler configuration.
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Apply a text instruction to a base64-encoded image.

        Args:
            data: Request payload shaped like
                ``{"inputs": {"image": <base64 string>, "text": <prompt>}}``.
                The payload is read but not modified.

        Returns:
            ``{"image": <base64-encoded PNG of the edited image>}``.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` entry, or if ``inputs``
                lacks ``"image"`` or ``"text"``.
        """
        info = data["inputs"]
        # Fail fast with a clear message; otherwise a missing field would only
        # surface later as an opaque error inside base64 decoding or the pipeline.
        missing = [key for key in ("image", "text") if key not in info]
        if missing:
            raise KeyError(f"'inputs' is missing required field(s): {missing}")

        # Plain indexing (not pop) so the caller's payload is left untouched.
        image_bytes = base64.b64decode(info["image"])
        prompt = info["text"]
        raw_image = Image.open(BytesIO(image_bytes)).convert("RGB")

        images = self.pipe(
            prompt,
            image=raw_image,
            num_inference_steps=25,
            image_guidance_scale=1,
        ).images

        # Encode in memory rather than via a shared scratch file on disk, which
        # would race between concurrent requests and leave files behind.
        buffer = BytesIO()
        images[0].save(buffer, format="PNG")
        encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return {"image": encoded_string}

    # NOTE(review): the string literal below is disabled helper code kept for
    # reference; it is never executed. If re-enabled, process_image_base64
    # needs ``import io`` (the file only imports ``BytesIO``), and hex_to_name
    # rebuilds its KDTree on every call.
    """
    def process_image_base64(self, base64_image_data):
        # Decode base64 data to bytes
        image_bytes = base64.b64decode(base64_image_data)

        # Convert bytes to an image
        image = Image.open(io.BytesIO(image_bytes))
        image = image.convert("RGB")
        image = image.resize((512, 512))
        return image

    def build_prompt(self, text_prompt, color_code):
        color_name = self.hex_to_name(color_code)
        coloring_prompt = f" with a {color_name} color applied to only the designated area or key element in the picture by avoiding it becoming the dominant color of the image, leaving the text, logo, and shadows untouched."
        result_prompt = f"{text_prompt}{coloring_prompt}"
        return result_prompt

    def hex_to_name(self, hex_color):
        rgb_tuple = tuple(int(hex_color[i:i+2], 16) for i in (1, 3, 5))
        names = []
        rgb_values = []

        for color_hex, color_name in CSS3_HEX_TO_NAMES.items():
            names.append(color_name)
            rgb_values.append(hex_to_rgb(color_hex))

        kdt_db = KDTree(rgb_values)
        distance, index = kdt_db.query(rgb_tuple)
        return names[index]
    """
if __name__ == "__main__":
    # Manual smoke check: constructing the handler loads the InstructPix2Pix
    # pipeline onto CUDA; no request is issued.
    my_handler = EndpointHandler(path=".")