Gjm1234 committed on
Commit
6b33414
·
verified ·
1 Parent(s): 80f2df6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -27
handler.py CHANGED
@@ -1,12 +1,9 @@
1
  import io, os, torch, base64
2
  from PIL import Image
3
- from diffusers import (
4
- StableDiffusionXLPipeline,
5
- ControlNetModel
6
- )
7
 
8
- JUGGERNAUT_REPO = "Gjm1234/juggernaut-sfw"
9
- CONTROLNET_REPO = "thibaud/controlnet-openpose-sdxl-1.0"
10
 
11
  class EndpointHandler:
12
  def __init__(self, root=""):
@@ -16,57 +13,62 @@ class EndpointHandler:
16
  if not token:
17
  raise RuntimeError("❌ Missing HF_TOKEN")
18
 
 
 
 
 
 
19
  print("📥 Loading ControlNet …")
20
  self.controlnet = ControlNetModel.from_pretrained(
21
- CONTROLNET_REPO,
22
  torch_dtype=torch.float16,
23
  use_safetensors=True,
24
  token=token
25
- ).to("cuda")
26
 
27
- print("📥 Loading Juggernaut …")
28
- self.pipe = StableDiffusionXLPipeline.from_pretrained(
29
- JUGGERNAUT_REPO,
 
30
  torch_dtype=torch.float16,
31
  use_safetensors=True,
32
- token=token,
33
- controlnet=self.controlnet
34
  ).to("cuda")
35
 
36
- # Memory optimizations
37
- self.pipe.enable_attention_slicing()
38
  self.pipe.enable_vae_slicing()
 
39
  self.pipe.unet.to(memory_format=torch.channels_last)
40
 
41
- print("✅ Ready!")
42
 
43
  def __call__(self, data):
44
- prompt = data.get("inputs", "")
45
- neg = data.get("negative_prompt", "")
46
 
47
- # Optional: base64 input image for editing
 
 
48
  image_b64 = data.get("image")
49
  input_image = None
50
 
51
  if image_b64:
52
- img_bytes = base64.b64decode(image_b64)
53
- input_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
54
 
55
  results = []
 
 
56
  for _ in range(10):
57
- output = self.pipe(
58
  prompt=prompt,
59
- negative_prompt=neg,
60
  image=input_image,
61
- controlnet_conditioning_scale=0.7,
62
  num_inference_steps=25,
63
- guidance_scale=7.5,
64
  width=1024,
65
- height=1024,
66
  ).images[0]
67
 
68
  buf = io.BytesIO()
69
- output.save(buf, format="PNG")
70
  results.append(base64.b64encode(buf.getvalue()).decode())
71
 
72
  return { "images": results }
 
1
  import io, os, torch, base64
2
  from PIL import Image
3
+ from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
 
 
 
4
 
5
+ BASE_MODEL = "Gjm1234/juggernaut-sfw"
6
+ CONTROLNET_MODEL = "diffusers/controlnet-depth-sdxl-1.0" # ✅ RECOMMENDED WORKING MODEL
7
 
8
  class EndpointHandler:
9
  def __init__(self, root=""):
 
13
  if not token:
14
  raise RuntimeError("❌ Missing HF_TOKEN")
15
 
16
+ # 🚫 Disable flash/xformers/SDP
17
+ torch.backends.cuda.enable_flash_sdp(False)
18
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
19
+ torch.backends.cuda.enable_math_sdp(True)
20
+
21
  print("📥 Loading ControlNet …")
22
  self.controlnet = ControlNetModel.from_pretrained(
23
+ CONTROLNET_MODEL,
24
  torch_dtype=torch.float16,
25
  use_safetensors=True,
26
  token=token
27
+ )
28
 
29
+ print("📥 Loading Juggernaut XL (base)…")
30
+ self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
31
+ BASE_MODEL,
32
+ controlnet=self.controlnet,
33
  torch_dtype=torch.float16,
34
  use_safetensors=True,
35
+ token=token
 
36
  ).to("cuda")
37
 
 
 
38
  self.pipe.enable_vae_slicing()
39
+ self.pipe.enable_attention_slicing()
40
  self.pipe.unet.to(memory_format=torch.channels_last)
41
 
42
+ print("✅ Juggernaut + ControlNet ready!")
43
 
44
  def __call__(self, data):
 
 
45
 
46
+ prompt = data.get("prompt", "")
47
+ negative = data.get("negative_prompt", "")
48
+
49
  image_b64 = data.get("image")
50
  input_image = None
51
 
52
  if image_b64:
53
+ decoded = base64.b64decode(image_b64)
54
+ input_image = Image.open(io.BytesIO(decoded)).convert("RGB")
55
 
56
  results = []
57
+
58
+ # Run 10 variations
59
  for _ in range(10):
60
+ out = self.pipe(
61
  prompt=prompt,
62
+ negative_prompt=negative,
63
  image=input_image,
 
64
  num_inference_steps=25,
65
+ guidance_scale=5.5,
66
  width=1024,
67
+ height=1024
68
  ).images[0]
69
 
70
  buf = io.BytesIO()
71
+ out.save(buf, format="PNG")
72
  results.append(base64.b64encode(buf.getvalue()).decode())
73
 
74
  return { "images": results }