saassa commited on
Commit
42ca7c4
·
1 Parent(s): 1abde31

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +25 -16
handler.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, Any
2
  import base64
3
  from PIL import Image
4
  from io import BytesIO
@@ -6,12 +6,11 @@ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, Autoen
6
  import torch
7
  from diffusers.utils import load_image
8
 
9
-
10
-
11
  import numpy as np
12
  import cv2
13
  import controlnet_hinter
14
- # ADDED AUTO PIPE, next try replacing
 
15
  # set device
16
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
  if device.type != 'cuda':
@@ -55,20 +54,20 @@ CONTROLNET_MAPPING = {
55
  }
56
  }
57
 
58
-
59
  class EndpointHandler():
60
  def __init__(self, path=""):
61
  # define default controlnet id and load controlnet
62
  self.control_type = "normal"
63
- self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
64
-
65
  # Load StableDiffusionControlNetPipeline
66
  self.stable_diffusion_id = "stablediffusionapi/disney-pixar-cartoon"
67
  self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
68
  controlnet=self.controlnet,
69
- torch_dtype=dtype,
70
  safety_checker=None).to(device)
71
-
72
  # Define Generator with seed
73
  # COMMENTED self.generator = torch.Generator(device="cpu").manual_seed(3)
74
 
@@ -80,6 +79,18 @@ class EndpointHandler():
80
  prompt = data.pop("inputs", None)
81
  image = data.pop("image", None)
82
  controlnet_type = data.pop("controlnet_type", None)
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Check if neither prompt nor image is provided
85
  if prompt is None and image is None:
@@ -93,9 +104,8 @@ class EndpointHandler():
93
  torch_dtype=dtype).to(device)
94
  self.pipe.controlnet = self.controlnet
95
 
96
-
97
- # hyperparamters
98
- negatice_prompt = data.pop("negative_prompt", None)
99
  num_inference_steps = data.pop("num_inference_steps", 150)
100
  guidance_scale = data.pop("guidance_scale", 5)
101
  negative_prompt = data.pop("negative_prompt", None)
@@ -118,12 +128,11 @@ class EndpointHandler():
118
  height=height,
119
  width=width,
120
  controlnet_conditioning_scale=controlnet_conditioning_scale,
121
- guess_mode=True,
122
-
123
  )
124
 
125
- #generator=self.generator COMMENTED from self.pipe
126
- # return first generate PIL image
127
  return out.images[0]
128
 
129
  # helper to decode input image
 
1
+ from typing import Dict, List, Any
2
  import base64
3
  from PIL import Image
4
  from io import BytesIO
 
6
  import torch
7
  from diffusers.utils import load_image
8
 
 
 
9
  import numpy as np
10
  import cv2
11
  import controlnet_hinter
12
+
13
+ # ADDED AUTO PIPE
14
  # set device
15
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
  if device.type != 'cuda':
 
54
  }
55
  }
56
 
57
+
58
  class EndpointHandler():
59
  def __init__(self, path=""):
60
  # define default controlnet id and load controlnet
61
  self.control_type = "normal"
62
+ self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype).to(device)
63
+
64
  # Load StableDiffusionControlNetPipeline
65
  self.stable_diffusion_id = "stablediffusionapi/disney-pixar-cartoon"
66
  self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
67
  controlnet=self.controlnet,
68
+ torch_dtype=dtype,
69
  safety_checker=None).to(device)
70
+
71
  # Define Generator with seed
72
  # COMMENTED self.generator = torch.Generator(device="cpu").manual_seed(3)
73
 
 
79
  prompt = data.pop("inputs", None)
80
  image = data.pop("image", None)
81
  controlnet_type = data.pop("controlnet_type", None)
82
+ stablediffusion_id = data.pop("stablediffusionid", None) # Get the stablediffusionid from the request data
83
+
84
+ if stablediffusion_id is not None and stablediffusion_id != self.stable_diffusion_id:
85
+ # Change the Stable Diffusion model to the new model ID
86
+ self.stable_diffusion_id = stablediffusion_id
87
+ # Reinitialize the pipeline with the new model ID
88
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
89
+ self.stable_diffusion_id,
90
+ controlnet=self.controlnet,
91
+ torch_dtype=dtype,
92
+ safety_checker=None
93
+ ).to(device)
94
 
95
  # Check if neither prompt nor image is provided
96
  if prompt is None and image is None:
 
104
  torch_dtype=dtype).to(device)
105
  self.pipe.controlnet = self.controlnet
106
 
107
+ # hyperparameters
108
+ negative_prompt = data.pop("negative_prompt", None)
 
109
  num_inference_steps = data.pop("num_inference_steps", 150)
110
  guidance_scale = data.pop("guidance_scale", 5)
111
  negative_prompt = data.pop("negative_prompt", None)
 
128
  height=height,
129
  width=width,
130
  controlnet_conditioning_scale=controlnet_conditioning_scale,
131
+ guess_mode=True,
 
132
  )
133
 
134
+ # generator=self.generator COMMENTED from self.pipe
135
+ # return the first generated PIL image
136
  return out.images[0]
137
 
138
  # helper to decode input image