Gjm1234 committed on
Commit
d2d10a5
·
verified ·
1 Parent(s): 6b33414

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +54 -54
handler.py CHANGED
@@ -1,74 +1,74 @@
1
- import io, os, torch, base64
2
- from PIL import Image
 
3
  from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
 
 
4
 
5
- BASE_MODEL = "Gjm1234/juggernaut-sfw"
6
- CONTROLNET_MODEL = "diffusers/controlnet-depth-sdxl-1.0" # ✅ RECOMMENDED WORKING MODEL
7
 
8
class EndpointHandler:
    """Inference handler serving Juggernaut XL with a depth ControlNet.

    Loads both models once at construction time and renders ten 1024x1024
    variations per request in ``__call__``.
    """

    def __init__(self, root=""):
        print("🔧 Initializing Juggernaut + ControlNet")

        token = os.environ.get("HF_TOKEN")
        if not token:
            raise RuntimeError("❌ Missing HF_TOKEN")

        # 🚫 Disable flash/xformers/SDP — force the plain math attention kernel.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)

        print("📥 Loading ControlNet …")
        self.controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MODEL,
            torch_dtype=torch.float16,
            use_safetensors=True,
            token=token,
        )

        print("📥 Loading Juggernaut XL (base)…")
        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            BASE_MODEL,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            use_safetensors=True,
            token=token,
        ).to("cuda")

        # Memory-saving toggles for the resident-on-GPU pipeline.
        self.pipe.enable_vae_slicing()
        self.pipe.enable_attention_slicing()
        self.pipe.unet.to(memory_format=torch.channels_last)

        print(" Juggernaut + ControlNet ready!")

    def __call__(self, data):
        """Render 10 variations for the request and return them base64-encoded."""
        text = data.get("prompt", "")
        neg_text = data.get("negative_prompt", "")

        control_image = None
        encoded = data.get("image")
        if encoded:
            raw = base64.b64decode(encoded)
            control_image = Image.open(io.BytesIO(raw)).convert("RGB")

        encoded_outputs = []
        # Run 10 variations
        for _ in range(10):
            frame = self.pipe(
                prompt=text,
                negative_prompt=neg_text,
                image=control_image,
                num_inference_steps=25,
                guidance_scale=5.5,
                width=1024,
                height=1024,
            ).images[0]

            stream = io.BytesIO()
            frame.save(stream, format="PNG")
            encoded_outputs.append(base64.b64encode(stream.getvalue()).decode())

        return { "images": encoded_outputs }
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import torch
4
  from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
5
+ from PIL import Image
6
+ import os
7
 
8
# Hub auth token for gated/private repos; None when the env var is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
9
 
10
class EndpointHandler:
    """Inference Endpoint handler: SDXL base model + depth ControlNet.

    Models are loaded once at startup; ``__call__`` serves generation
    requests and returns base64-encoded PNGs.
    """

    def __init__(self, path=""):
        print("🔧 Initializing Juggernaut + ControlNet")

        # Load ControlNet
        print("📥 Loading ControlNet…")
        self.controlnet = ControlNetModel.from_pretrained(
            # FIX: the previous id 'lllyasviel/controlnet-depth-sdxl-1.0' does
            # not exist on the Hub; the working SDXL depth ControlNet (used by
            # the prior revision of this file) is published under 'diffusers'.
            "diffusers/controlnet-depth-sdxl-1.0",
            torch_dtype=torch.float16,
            use_safetensors=True,
            token=HF_TOKEN,
        )

        # Load your big base model repo
        print("📥 Loading Base Model (juggernaut-sfw)…")
        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            "Gjm1234/juggernaut-sfw",
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            use_safetensors=True,
            token=HF_TOKEN,
        )

        # prevent OOM. FIX: do NOT call .to("cuda") first — CPU offload
        # manages device placement itself, and diffusers rejects/conflicts
        # with a pipeline that was already moved to CUDA.
        self.pipe.enable_model_cpu_offload()

        print("🚀 Pipeline Loaded Successfully!")

    def __call__(self, data):
        """Generate images for one request.

        Expected keys in ``data``:
            prompt (str): generation prompt.
            negative_prompt (str, optional): negative prompt.
            num_images (int, optional): images to generate, default 10.
            image (str, optional): base64-encoded conditioning image.

        Returns:
            {"images": [<base64 PNG>, ...]} on success,
            {"error": "<message>"} on failure.
        """
        try:
            prompt = data.get("prompt", "")
            negative_prompt = data.get("negative_prompt", "")
            num_images = data.get("num_images", 10)

            # Handle optional image
            image_data = data.get("image", None)

            if image_data:
                # Base64 → PIL
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            else:
                # ControlNet requires a conditioning image. A flat white canvas
                # carries no depth signal, so this is a degenerate fallback —
                # callers should send a real depth map for meaningful control.
                image = Image.new("RGB", (1024, 1024), "white")

            # Run generation
            output = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=image,
                num_inference_steps=28,
                guidance_scale=6.5,
                num_images_per_prompt=num_images,
            )

            images = []
            for img in output.images:
                buf = io.BytesIO()
                img.save(buf, format="PNG")
                images.append(base64.b64encode(buf.getvalue()).decode("utf-8"))

            return {"images": images}

        except Exception as e:
            # Endpoint boundary: surface the failure as a JSON payload rather
            # than letting the worker 500 with no body.
            return {"error": str(e)}