Gjm1234 committed on
Commit
bf168de
·
verified ·
1 Parent(s): b81ba5d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +65 -53
handler.py CHANGED
@@ -1,77 +1,89 @@
1
- from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
2
  import torch
 
 
 
3
  from PIL import Image
4
  import base64
5
  from io import BytesIO
6
- import os
7
 
8
class EndpointHandler:
    """Inference endpoint handler for an SDXL + depth-ControlNet pipeline.

    Loads the ControlNet and base model from the Hugging Face Hub in
    ``__init__`` and serves batched text+image-to-image generation in
    ``__call__``.
    """

    def __init__(self, model_dir):
        """Load the depth ControlNet and the Juggernaut XL base pipeline.

        Args:
            model_dir: Directory provided by the endpoint runtime (unused here;
                weights are pulled from the Hub by repo id).
        """
        print("🔑 Loading HF token...")
        hf_token = os.getenv("HF_TOKEN")

        print("🔧 Loading ControlNet...")
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-depth-sdxl-1.0",
            torch_dtype=torch.float16,
            use_safetensors=True,
            token=hf_token,
        )

        print("🧠 Loading Juggernaut XL...")
        base_model = "Gjm1234/juggernaut-sfw"

        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            base_model,
            controlnet=controlnet,
            torch_dtype=torch.float16,
            use_safetensors=True,
            token=hf_token,
        )

        # IMPORTANT FIX — remove xformers, use PyTorch attention instead.
        # BUGFIX: do NOT chain `.to("cuda")` before enabling model CPU offload.
        # `enable_model_cpu_offload()` manages device placement itself (it moves
        # sub-models to the GPU on demand); pre-moving the whole fp16 SDXL
        # pipeline to CUDA defeats the offload and raises in recent diffusers.
        if hasattr(self.pipe, "enable_model_cpu_offload"):
            self.pipe.enable_model_cpu_offload()
        else:
            self.pipe.to("cuda")

        print("🚀 Pipeline loaded successfully!")

    def __call__(self, data):
        """Generate images for one request.

        Args:
            data: Mapping with optional keys ``prompt``, ``negative_prompt``,
                and ``image`` (base64-encoded control image).

        Returns:
            dict: ``{"images": [<base64 PNG>, ...]}``.
        """
        print("📥 Received request...")

        prompt = data.get("prompt", "")
        negative_prompt = data.get("negative_prompt", "blurry, bad quality, distorted, extra limbs")

        num_images = 10

        # Decode uploaded image
        image_b64 = data.get("image")
        if image_b64:
            print("🖼️ Decoding input image...")
            image_data = base64.b64decode(image_b64)
            init_image = Image.open(BytesIO(image_data)).convert("RGB")
        else:
            print("⚠️ No image uploaded — generating blank control input.")
            # NOTE(review): a blank white control image is meaningless input for
            # a depth ControlNet; kept for backward compatibility — confirm intent.
            init_image = Image.new("RGB", (1024, 1024), "white")

        print("🎨 Generating images...")
        output = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=init_image,
            num_inference_steps=30,
            num_images_per_prompt=num_images,
            guidance_scale=7.0,
        )

        images = output.images

        print("📤 Encoding output images...")
        result = []
        for img in images:
            buffer = BytesIO()
            img.save(buffer, format="PNG")
            b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
            result.append(b64)

        print("✅ Returning images...")
        return {"images": result}
 
 
 
 
1
  import torch
2
+ import os
3
+ import gc
4
+ from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
5
  from PIL import Image
6
  import base64
7
  from io import BytesIO
8
+
9
 
10
class EndpointHandler:
    """Memory-safe inference handler for an SDXL + ControlNet pipeline.

    Loads both models from local subdirectories of ``model_dir`` with
    aggressive VRAM-saving options, and serves single-image generation
    in ``__call__`` with per-request memory cleanup.
    """

    def __init__(self, model_dir):
        """Load the ControlNet and base pipeline from ``model_dir``.

        Args:
            model_dir: Endpoint model directory; expects ``controlnet/`` and
                ``model/`` subdirectories containing the respective weights.
        """
        print("🔧 Initializing improved memory-safe handler...")

        # Prevent CUDA memory fragmentation across repeated requests.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

        # HF token — may be None when loading purely local weights.
        token = os.getenv("HF_TOKEN", None)

        # Load ControlNet
        print("🔧 Loading ControlNet…")
        controlnet = ControlNetModel.from_pretrained(
            os.path.join(model_dir, "controlnet"),
            torch_dtype=torch.float16,
            use_safetensors=True,
            token=token,
        )

        # Load main model
        print("🔧 Loading Juggernaut XL…")
        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            os.path.join(model_dir, "model"),
            controlnet=controlnet,
            torch_dtype=torch.float16,
            use_safetensors=True,
            token=token,
        )

        # VRAM-saving settings.
        # BUGFIX: removed `self.pipe.to("cuda")` here. Sequential CPU offload
        # hooks every sub-module and streams it to the GPU on demand; moving
        # the whole pipeline to CUDA first defeats the offload's purpose and
        # raises an error in current diffusers releases.
        self.pipe.enable_attention_slicing()
        self.pipe.enable_vae_slicing()
        self.pipe.enable_sequential_cpu_offload()

        print(" Pipeline ready!")

    def __call__(self, data):
        """Generate one image for one request.

        Args:
            data: Mapping with required keys ``prompt`` (str) and ``image``
                (base64-encoded ControlNet conditioning image).

        Returns:
            dict: ``{"image": <base64 JPEG>}`` on success, or
            ``{"error": <message>}`` on any failure.
        """
        try:
            prompt = data.get("prompt", "")
            image_b64 = data.get("image", None)

            if not prompt:
                return {"error": "Missing prompt"}

            if not image_b64:
                return {"error": "Missing image input"}

            # Decode ControlNet image
            try:
                image_bytes = base64.b64decode(image_b64)
                control_image = Image.open(BytesIO(image_bytes)).convert("RGB")
            # BUGFIX: narrowed from a bare `except:`, which would also swallow
            # SystemExit/KeyboardInterrupt; decode/parse failures are Exceptions.
            except Exception:
                return {"error": "Invalid base64 image"}

            # Run the pipeline
            result = self.pipe(
                prompt=prompt,
                image=control_image,
                num_inference_steps=20,
                guidance_scale=3.0,
                controlnet_conditioning_scale=1.0,
                height=768,
                width=512,
            ).images[0]

            # Convert output to base64
            buffered = BytesIO()
            result.save(buffered, format="JPEG")
            output_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

            return {"image": output_b64}

        except Exception as e:
            # Top-level endpoint boundary: report, don't crash the worker.
            return {"error": str(e)}

        finally:
            # 🔥 Force GPU/CPU memory cleanup between requests
            torch.cuda.empty_cache()
            gc.collect()