Gjm1234 commited on
Commit
80f2df6
·
verified ·
1 Parent(s): fdb21d9

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +53 -43
handler.py CHANGED
@@ -1,62 +1,72 @@
1
- import io
2
- import base64
3
- import torch
4
- import os
5
  from PIL import Image
6
- from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
 
 
 
7
 
8
- BASE_MODEL = "Gjm1234/juggernaut-sfw"
9
- CONTROLNET = "lllyasviel/controlnet-depth-sdxl-1.0"
10
 
11
  class EndpointHandler:
12
- def __init__(self, path=""):
13
- print("🔧 Initializing handler…")
14
 
15
- HF_TOKEN = os.environ.get("HF_TOKEN")
16
- if not HF_TOKEN:
17
- raise RuntimeError("❌ HF_TOKEN not found in environment variables")
18
 
19
- print("🔧 Loading ControlNet with token…")
20
- controlnet = ControlNetModel.from_pretrained(
21
- CONTROLNET,
22
  torch_dtype=torch.float16,
23
- token=HF_TOKEN
24
- )
 
25
 
26
- print("🚀 Loading Juggernaut XL model with token…")
27
- self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
28
- BASE_MODEL,
29
- controlnet=controlnet,
30
  torch_dtype=torch.float16,
31
  use_safetensors=True,
32
- token=HF_TOKEN
 
33
  ).to("cuda")
34
 
35
- self.pipe.enable_xformers_memory_efficient_attention()
36
- print("✅ Pipeline loaded successfully!")
 
 
 
 
37
 
38
  def __call__(self, data):
39
  prompt = data.get("inputs", "")
40
- img_b64 = data.get("image", None)
41
-
42
- # Decode or create blank input
43
- if img_b64:
44
- img_bytes = base64.b64decode(img_b64)
45
- init = Image.open(io.BytesIO(img_bytes)).convert("RGB")
46
- else:
47
- init = Image.new("RGB", (1024, 1024), "white")
48
-
49
- outputs = []
50
- for _ in range(10): # ALWAYS generate 10 images
51
- result = self.pipe(
 
52
  prompt=prompt,
53
- image=init,
54
- num_inference_steps=20,
55
- guidance_scale=6.0,
 
 
 
 
56
  ).images[0]
57
 
58
- buffer = io.BytesIO()
59
- result.save(buffer, format="PNG")
60
- outputs.append(base64.b64encode(buffer.getvalue()).decode())
61
 
62
- return { "images": outputs }
 
1
+ import io, os, torch, base64
 
 
 
2
  from PIL import Image
3
+ from diffusers import (
4
+ StableDiffusionXLPipeline,
5
+ ControlNetModel
6
+ )
7
 
8
+ JUGGERNAUT_REPO = "Gjm1234/juggernaut-sfw"
9
+ CONTROLNET_REPO = "thibaud/controlnet-openpose-sdxl-1.0"
10
 
11
  class EndpointHandler:
12
+ def __init__(self, root=""):
13
+ print("🔧 Initializing Juggernaut + ControlNet")
14
 
15
+ token = os.environ.get("HF_TOKEN")
16
+ if not token:
17
+ raise RuntimeError("❌ Missing HF_TOKEN")
18
 
19
+ print("📥 Loading ControlNet …")
20
+ self.controlnet = ControlNetModel.from_pretrained(
21
+ CONTROLNET_REPO,
22
  torch_dtype=torch.float16,
23
+ use_safetensors=True,
24
+ token=token
25
+ ).to("cuda")
26
 
27
+ print("📥 Loading Juggernaut …")
28
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(
29
+ JUGGERNAUT_REPO,
 
30
  torch_dtype=torch.float16,
31
  use_safetensors=True,
32
+ token=token,
33
+ controlnet=self.controlnet
34
  ).to("cuda")
35
 
36
+ # Memory optimizations
37
+ self.pipe.enable_attention_slicing()
38
+ self.pipe.enable_vae_slicing()
39
+ self.pipe.unet.to(memory_format=torch.channels_last)
40
+
41
+ print("✅ Ready!")
42
 
43
  def __call__(self, data):
44
  prompt = data.get("inputs", "")
45
+ neg = data.get("negative_prompt", "")
46
+
47
+ # Optional: base64 input image for editing
48
+ image_b64 = data.get("image")
49
+ input_image = None
50
+
51
+ if image_b64:
52
+ img_bytes = base64.b64decode(image_b64)
53
+ input_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
54
+
55
+ results = []
56
+ for _ in range(10):
57
+ output = self.pipe(
58
  prompt=prompt,
59
+ negative_prompt=neg,
60
+ image=input_image,
61
+ controlnet_conditioning_scale=0.7,
62
+ num_inference_steps=25,
63
+ guidance_scale=7.5,
64
+ width=1024,
65
+ height=1024,
66
  ).images[0]
67
 
68
+ buf = io.BytesIO()
69
+ output.save(buf, format="PNG")
70
+ results.append(base64.b64encode(buf.getvalue()).decode())
71
 
72
+ return { "images": results }