enesbol commited on
Commit
10b0ac1
·
1 Parent(s): 1eb16a3
Files changed (1) hide show
  1. handler.py +9 -11
handler.py CHANGED
@@ -2,16 +2,14 @@ from typing import Dict, Any
2
  import torch
3
  from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
4
  from PIL import Image
5
- import torch
6
- from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
7
-
8
- model_id = "timbrooks/instruct-pix2pix"
9
 
10
  class EndpointHandler:
11
 
12
- def __init__(self):
13
  model_id = "timbrooks/instruct-pix2pix"
14
- self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None).to("cuda")
15
  self.model.scheduler = EulerAncestralDiscreteScheduler.from_config(self.model.scheduler.config)
16
 
17
  def __call__(self, data: Dict[str, Any]) -> Image:
@@ -20,6 +18,7 @@ class EndpointHandler:
20
  data (dict):
21
  The payload with the input image, text prompt, color code, and optional parameters.
22
  """
 
23
  input_image = self.load_process_image(data['image_path'])
24
  text_prompt = data['text_prompt']
25
  color_code = data['color_code']
@@ -36,10 +35,9 @@ class EndpointHandler:
36
  num_inference_steps=50,
37
  guidance_scale=guidance_scale,
38
  image_guidance_scale=image_guidance_scale,
39
-
40
  ).images
41
 
42
- return images
43
 
44
  def load_process_image(self, image_path):
45
  image = Image.open(image_path)
@@ -53,8 +51,8 @@ class EndpointHandler:
53
  result_prompt = f"{text_prompt}{coloring_prompt}"
54
  return result_prompt
55
 
56
- def hex_to_name(hex_color):
57
- rgb_tuple = tuple(int(a[i:i+2], 16) for i in (1, 3, 5))
58
  names = []
59
  rgb_values = []
60
 
@@ -64,4 +62,4 @@ class EndpointHandler:
64
 
65
  kdt_db = KDTree(rgb_values)
66
  distance, index = kdt_db.query(rgb_tuple)
67
- return names[index]
 
2
  import torch
3
  from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
4
  from PIL import Image
5
+ from webcolors import CSS3_HEX_TO_NAMES, hex_to_rgb
6
+ from scipy.spatial import KDTree
 
 
7
 
8
  class EndpointHandler:
9
 
10
+ def __init__(self, path=""):
11
  model_id = "timbrooks/instruct-pix2pix"
12
+ self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(path, torch_dtype=torch.float16, safety_checker=None).to("cuda")
13
  self.model.scheduler = EulerAncestralDiscreteScheduler.from_config(self.model.scheduler.config)
14
 
15
  def __call__(self, data: Dict[str, Any]) -> Image:
 
18
  data (dict):
19
  The payload with the input image, text prompt, color code, and optional parameters.
20
  """
21
+
22
  input_image = self.load_process_image(data['image_path'])
23
  text_prompt = data['text_prompt']
24
  color_code = data['color_code']
 
35
  num_inference_steps=50,
36
  guidance_scale=guidance_scale,
37
  image_guidance_scale=image_guidance_scale,
 
38
  ).images
39
 
40
+ return images
41
 
42
  def load_process_image(self, image_path):
43
  image = Image.open(image_path)
 
51
  result_prompt = f"{text_prompt}{coloring_prompt}"
52
  return result_prompt
53
 
54
+ def hex_to_name(self, hex_color):
55
+ rgb_tuple = tuple(int(hex_color[i:i+2], 16) for i in (1, 3, 5))
56
  names = []
57
  rgb_values = []
58
 
 
62
 
63
  kdt_db = KDTree(rgb_values)
64
  distance, index = kdt_db.query(rgb_tuple)
65
+ return names[index]