Gjm1234 committed on
Commit
e134161
·
verified ·
1 Parent(s): aadf329

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +56 -29
handler.py CHANGED
@@ -1,47 +1,74 @@
import base64
from io import BytesIO

import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image


class EndpointHandler:
    """Inference-endpoint handler for the Juggernaut-SFW SDXL checkpoint.

    Loads the pipeline once at startup and serves text-to-image requests,
    returning the generated images as base64-encoded PNG strings.
    """

    def __init__(self, path=""):
        print("🔧 Loading Juggernaut-SFW model...")

        # Load in fp16 to halve VRAM; device placement is handled below by
        # enable_model_cpu_offload(), NOT by an explicit .to("cuda").
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "Gjm1234/juggernaut-sfw",
            torch_dtype=torch.float16,
        )

        # Disable controlnet entirely (you are not loading any ControlNet weights)
        print("⚠️ ControlNet disabled — no weights provided.")

        # Memory optimisation.
        # BUGFIX: the original chained .to("cuda") before calling
        # enable_model_cpu_offload(). Per the diffusers docs, offloading
        # manages device placement itself and must not be combined with a
        # prior full-pipeline .to("cuda") — doing so defeats the offload.
        self.pipe.enable_attention_slicing()
        self.pipe.enable_model_cpu_offload()

    def __call__(self, data):
        """Generate images for ``data["prompt"]``.

        Parameters:
            data: request dict; reads "prompt" (required) and
                  "num_images" (optional, default 4, clamped to >= 1).

        Returns:
            {"images": [<base64 PNG>, ...]} on success,
            {"error": <message>} when the prompt is missing.
        """
        prompt = data.get("prompt", None)
        if prompt is None:
            return {"error": "prompt is required"}

        # Robust parse: non-numeric client input falls back to the default
        # of 4 instead of raising and killing the request.
        try:
            num_images = int(data.get("num_images", 4))
        except (TypeError, ValueError):
            num_images = 4
        if num_images < 1:
            num_images = 1

        print("🎨 Generating images…")

        images = self.pipe(
            prompt=prompt,
            num_inference_steps=25,
            guidance_scale=6,
            num_images_per_prompt=num_images
        ).images

        # Encode all images to Base64
        encoded = []
        for img in images:
            buf = BytesIO()
            img.save(buf, format="PNG")
            encoded.append(base64.b64encode(buf.getvalue()).decode("utf-8"))

        return {"images": encoded}
 
 
import base64
import io
from typing import Any, Dict

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
from PIL import Image


class EndpointHandler:
    """Endpoint handler: SDXL text-to-image, with optional image-to-image.

    If the request carries a base64 "image", it is used as the init image
    for an img2img pass; otherwise a plain txt2img pass runs. Results are
    returned as base64-encoded PNG strings.
    """

    def __init__(self, model_dir: str, **kwargs):
        print("🔥 Initializing Juggernaut XL Handler (Prompt + Optional Image)...")

        # Load XL model from your big repo
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "Gjm1234/juggernaut-sfw",
            torch_dtype=torch.float16,
            use_safetensors=True
        ).to("cuda")

        # BUGFIX: StableDiffusionXLPipeline.__call__ does not accept
        # `image=` / `strength=`, so the original img2img branch raised a
        # TypeError whenever a client sent an image. Build a dedicated
        # img2img pipeline that shares the same loaded components (no
        # additional VRAM) for that branch.
        self.img2img = StableDiffusionXLImg2ImgPipeline(**self.pipe.components)

        self.pipe.enable_attention_slicing()

        print("✅ Pipeline loaded successfully.")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        Parameters:
            data: request body; must contain an "inputs" dict with
                  "prompt" (required), and optionally "num_images"
                  (default 10) and "image" (base64-encoded init image).

        Returns:
            {"images": [<base64 PNG>, ...]} on success, or
            {"error": <message>} on invalid input.
        """
        # Must receive `inputs`
        if "inputs" not in data:
            return {"error": "Body must contain 'inputs' object"}

        inputs = data["inputs"]

        prompt = inputs.get("prompt", None)
        if not prompt:
            return {"error": "prompt is required"}

        # BUGFIX: coerce to int and clamp to >= 1 so bad client input
        # cannot crash the pipeline call (default preserved: 10).
        try:
            num_images = int(inputs.get("num_images", 10))
        except (TypeError, ValueError):
            num_images = 10
        if num_images < 1:
            num_images = 1

        image_b64 = inputs.get("image", None)

        init_image = None
        if image_b64:
            try:
                img_bytes = base64.b64decode(image_b64)
                init_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            except Exception as e:
                return {"error": f"Invalid image data: {str(e)}"}

        # Run txt2img OR img2img depending on whether image was sent
        if init_image is None:
            print("🎨 Running TEXT → IMAGE")
            output = self.pipe(
                prompt=prompt,
                num_images_per_prompt=num_images
            )
        else:
            print("🎨 Running IMAGE → IMAGE")
            # Dispatch to the img2img pipeline, which does accept
            # image=/strength= (see BUGFIX note in __init__).
            output = self.img2img(
                prompt=prompt,
                image=init_image,
                strength=0.6,
                num_images_per_prompt=num_images
            )

        images = output.images

        # convert to base64 array
        results = []
        for img in images:
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
            results.append(img_b64)

        print(f"✅ Returning {len(results)} images.")
        return {"images": results}