Gjm1234 committed on
Commit
aadf329
·
verified ·
1 Parent(s): a91abf3

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +40 -80
handler.py CHANGED
@@ -1,87 +1,47 @@
1
  import base64
2
- import io
3
- from PIL import Image
4
  import torch
5
- from diffusers import (
6
- StableDiffusionXLControlNetPipeline,
7
- ControlNetModel,
8
- AutoencoderKL,
9
- EulerAncestralDiscreteScheduler
10
- )
11
 
12
class EndpointHandler:
    """Inference handler for Juggernaut-SFW SDXL with optional ControlNet conditioning.

    Loads the base pipeline and ControlNet weights once at startup; serves
    requests that contain a text prompt and, optionally, a base64-encoded
    conditioning image.
    """

    def __init__(self, model_dir):
        # `model_dir` is supplied by the endpoint runtime; weights are pulled
        # from the hub repo below, so it is unused here.
        BASE_MODEL = "Gjm1234/juggernaut-sfw"  # MAIN MODEL REPO

        # FIX: a hub repo id cannot contain a subpath ("owner/repo/controlnet"
        # is invalid). The ControlNet weights living under "controlnet/" inside
        # the repo must be addressed with the `subfolder` argument.
        print("🔧 Loading ControlNet from:", BASE_MODEL, "(subfolder=controlnet)")
        self.controlnet = ControlNetModel.from_pretrained(
            BASE_MODEL,
            subfolder="controlnet",
            torch_dtype=torch.float16,
        )

        print("🔧 Loading base model pipeline from:", BASE_MODEL)
        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            BASE_MODEL,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        )

        # Memory-friendly settings.
        # FIX: do NOT call .to("cuda") before enable_model_cpu_offload() —
        # offloading manages device placement itself and a manual GPU move
        # defeats (and on recent diffusers versions, errors out of) the offload.
        self.pipe.enable_model_cpu_offload()
        try:
            # xformers is an optional dependency; fall back gracefully.
            self.pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            print("⚠️ xformers unavailable — using default attention.")

        print("✅ Pipeline loaded successfully.")

    def __call__(self, data):
        """Generate images from `data`.

        Parameters:
            data: dict with keys
                - "prompt" (str, default ""): text prompt
                - "num_images" (int-ish, default 1): images per prompt
                - "image" (str, optional): base64-encoded conditioning image

        Returns:
            {"images": [<base64 JPEG>, ...]}
        """
        prompt = data.get("prompt", "")
        num_images = int(data.get("num_images", 1))

        # Optional base64 image input.
        image_b64 = data.get("image", None)

        init_image = None
        if image_b64:
            try:
                image_bytes = base64.b64decode(image_b64)
                init_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            except Exception:
                # FIX: narrowed from a bare `except:`. Undecodable payloads
                # deliberately fall back to text-only generation.
                init_image = None

        print("📥 Prompt:", prompt)
        print("📥 Image provided:", init_image is not None)

        # ---- Handle optional image ----
        if init_image is None:
            # No image → use text-only diffusion
            print("🎨 Running TEXT-ONLY generation...")
            out = self.pipe(
                prompt=prompt,
                num_inference_steps=25,
                num_images_per_prompt=num_images,
            )
        else:
            # ControlNet uses the image
            print("🎛 Running ControlNet IMAGE + TEXT...")
            out = self.pipe(
                prompt=prompt,
                image=init_image,
                num_inference_steps=25,
                num_images_per_prompt=num_images,
            )

        # Encode every generated image as base64 JPEG.
        images = []
        for img in out.images:
            buffer = io.BytesIO()
            img.save(buffer, format="JPEG")
            images.append(base64.b64encode(buffer.getvalue()).decode("utf-8"))

        return {"images": images}
 
1
  import base64
2
+ from io import BytesIO
3
+ from diffusers import StableDiffusionXLPipeline
4
  import torch
5
+ from PIL import Image
 
 
 
 
 
6
 
7
class EndpointHandler:
    """Inference handler for the Juggernaut-SFW SDXL text-to-image model.

    Loads the pipeline once at startup and serves text-to-image requests,
    returning the generated images as base64-encoded PNG strings.
    """

    def __init__(self, path=""):
        # `path` is the model directory supplied by the endpoint runtime; the
        # weights are pulled from the hub repo instead, so it is unused here.
        print("🔧 Loading Juggernaut-SFW model...")
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "Gjm1234/juggernaut-sfw",
            torch_dtype=torch.float16,
        )

        # Disable controlnet entirely (you are not loading any ControlNet weights)
        print("⚠️ ControlNet disabled — no weights provided.")

        # Memory optimisation
        self.pipe.enable_attention_slicing()
        # FIX: do NOT call .to("cuda") before enable_model_cpu_offload() —
        # offloading manages device placement itself, and a manual GPU move
        # defeats (and on recent diffusers versions, errors out of) the offload.
        self.pipe.enable_model_cpu_offload()

    def __call__(self, data):
        """Generate images for `data["prompt"]`.

        Parameters:
            data: dict with keys
                - "prompt" (str, required)
                - "num_images" (int-ish, optional, default 4, clamped to >= 1)

        Returns:
            {"images": [<base64 PNG>, ...]} on success,
            {"error": <message>} on bad input.
        """
        prompt = data.get("prompt", None)
        if prompt is None:
            return {"error": "prompt is required"}

        # FIX: a non-numeric count previously raised out of the handler (HTTP
        # 500); report it through the same error-dict channel as a missing
        # prompt instead.
        try:
            num_images = int(data.get("num_images", 4))
        except (TypeError, ValueError):
            return {"error": "num_images must be an integer"}
        if num_images < 1:
            num_images = 1

        print("🎨 Generating images…")

        images = self.pipe(
            prompt=prompt,
            num_inference_steps=25,
            guidance_scale=6,
            num_images_per_prompt=num_images,
        ).images

        # Encode all images to Base64
        encoded = []
        for img in images:
            buf = BytesIO()
            img.save(buf, format="PNG")
            encoded.append(base64.b64encode(buf.getvalue()).decode("utf-8"))

        return {"images": encoded}