Gjm1234 committed on
Commit
e60d5f7
·
verified ·
1 Parent(s): bf168de

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +37 -69
handler.py CHANGED
@@ -1,89 +1,57 @@
 
 
1
  import torch
2
- import os
3
- import gc
4
- from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
5
  from PIL import Image
6
- import base64
7
- from io import BytesIO
8
 
 
 
9
 
10
  class EndpointHandler:
11
- def __init__(self, model_dir):
12
- print("🔧 Initializing improved memory-safe handler...")
13
 
14
- # Prevent cuda fragmentation
15
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
16
-
17
- # HF token
18
- token = os.getenv("HF_TOKEN", None)
19
-
20
- # Load ControlNet
21
- print("🔧 Loading ControlNet…")
22
  controlnet = ControlNetModel.from_pretrained(
23
- os.path.join(model_dir, "controlnet"),
24
- torch_dtype=torch.float16,
25
- use_safetensors=True,
26
- token=token
27
  )
28
 
29
- # Load main model
30
- print("🔧 Loading Juggernaut XL…")
31
  self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
32
- os.path.join(model_dir, "model"),
33
  controlnet=controlnet,
34
  torch_dtype=torch.float16,
35
- use_safetensors=True,
36
- token=token
37
- )
38
 
39
- # VRAM-saving settings
40
- self.pipe.to("cuda")
41
- self.pipe.enable_attention_slicing()
42
- self.pipe.enable_vae_slicing()
43
- self.pipe.enable_sequential_cpu_offload()
44
-
45
- print("✅ Pipeline ready!")
46
 
47
  def __call__(self, data):
48
- try:
49
- prompt = data.get("prompt", "")
50
- image_b64 = data.get("image", None)
51
-
52
- if not prompt:
53
- return {"error": "Missing prompt"}
54
-
55
- if not image_b64:
56
- return {"error": "Missing image input"}
57
-
58
- # Decode ControlNet image
59
- try:
60
- image_bytes = base64.b64decode(image_b64)
61
- control_image = Image.open(BytesIO(image_bytes)).convert("RGB")
62
- except:
63
- return {"error": "Invalid base64 image"}
64
-
65
- # Run the pipeline
66
  result = self.pipe(
67
  prompt=prompt,
68
- image=control_image,
69
- num_inference_steps=20,
70
- guidance_scale=3.0,
71
- controlnet_conditioning_scale=1.0,
72
- height=768,
73
- width=512
74
  ).images[0]
75
 
76
- # Convert output to base64
77
- buffered = BytesIO()
78
- result.save(buffered, format="JPEG")
79
- output_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
80
-
81
- return {"image": output_b64}
82
-
83
- except Exception as e:
84
- return {"error": str(e)}
85
 
86
- finally:
87
- # 🔥 Force GPU/CPU memory cleanup
88
- torch.cuda.empty_cache()
89
- gc.collect()
 
1
+ import io
2
+ import base64
3
  import torch
 
 
 
4
  from PIL import Image
5
+ from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
 
6
 
7
+ BASE_MODEL = "Gjm1234/juggernaut-sfw"
8
+ CONTROLNET = "lllyasviel/controlnet-depth-sdxl-1.0"
9
 
10
  class EndpointHandler:
11
+ def __init__(self, path=""):
12
+ print("🔧 Initializing handler loading remote models...")
13
 
14
+ print("🔧 Loading ControlNet...")
 
 
 
 
 
 
 
15
  controlnet = ControlNetModel.from_pretrained(
16
+ CONTROLNET,
17
+ torch_dtype=torch.float16
 
 
18
  )
19
 
20
+ print("🚀 Loading Juggernaut XL main model...")
 
21
  self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
22
+ BASE_MODEL,
23
  controlnet=controlnet,
24
  torch_dtype=torch.float16,
25
+ use_safetensors=True
26
+ ).to("cuda")
 
27
 
28
+ self.pipe.enable_xformers_memory_efficient_attention()
29
+ print("✅ Pipeline ready")
 
 
 
 
 
30
 
31
  def __call__(self, data):
32
+ prompt = data.get("inputs", "")
33
+ img_b64 = data.get("image", None)
34
+
35
+ # Decode input image OR generate blank white one
36
+ if img_b64:
37
+ img_bytes = base64.b64decode(img_b64)
38
+ init = Image.open(io.BytesIO(img_bytes)).convert("RGB")
39
+ else:
40
+ init = Image.new("RGB", (1024, 1024), "white")
41
+
42
+ outputs = []
43
+ for _ in range(10): # always 10 variations
 
 
 
 
 
 
44
  result = self.pipe(
45
  prompt=prompt,
46
+ image=init,
47
+ num_inference_steps=25,
48
+ guidance_scale=6.0,
49
+ width=1024,
50
+ height=1024,
 
51
  ).images[0]
52
 
53
+ buf = io.BytesIO()
54
+ result.save(buf, format="PNG")
55
+ outputs.append(base64.b64encode(buf.getvalue()).decode())
 
 
 
 
 
 
56
 
57
+ return { "images": outputs }