KarthikAI commited on
Commit
d66b293
·
verified ·
1 Parent(s): 8204928

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +58 -28
handler.py CHANGED
@@ -1,43 +1,73 @@
1
- # handler.py
2
  import base64
3
  import io
4
  from PIL import Image
5
  import torch
6
  from diffusers import StableDiffusionImg2ImgPipeline
7
 
 
8
  pipe = None
9
 
10
  class EndpointHandler:
11
- def __init__(self, model_dir):
12
- # model_dir is where HF clones your repo; you can ignore it
13
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  def init(self):
 
 
 
16
  global pipe
17
- # Load the SD1.5 + InstantID adapter in one shot
18
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
19
- "karthikAI/InstantID-i2i", # your HF repo
20
- revision="main",
21
- torch_dtype=torch.float16,
22
- safety_checker=None
23
- ).to(self.device)
24
- pipe.enable_attention_slicing()
25
 
26
  def inference(self, model_inputs: dict) -> dict:
27
- # Decode input image (base64)
28
- img_data = base64.b64decode(model_inputs["image_base64"])
29
- init_img = Image.open(io.BytesIO(img_data)).convert("RGB")
30
-
31
- # Run img2img
32
- out = pipe(
33
- prompt=model_inputs.get("prompt", ""),
34
- image=init_img,
35
- strength=float(model_inputs.get("strength", 0.75)),
36
- guidance_scale=float(model_inputs.get("guidance_scale", 7.5)),
37
- num_inference_steps=int(model_inputs.get("num_inference_steps", 50)),
38
- ).images[0]
39
-
40
- # Encode output back to base64
41
- buf = io.BytesIO()
42
- out.save(buf, format="PNG")
43
- return {"generated_image_base64": base64.b64encode(buf.getvalue()).decode("utf-8")}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import base64
2
  import io
3
  from PIL import Image
4
  import torch
5
  from diffusers import StableDiffusionImg2ImgPipeline
6
 
7
+ # Global pipeline instance
8
  pipe = None
9
 
10
  class EndpointHandler:
11
+ def __init__(self, model_dir: str):
12
+ # Determine device based on CUDA availability
13
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  def init(self):
16
+ """
17
+ Load the InstantID-enhanced Stable Diffusion img2img model once when the endpoint starts.
18
+ """
19
  global pipe
20
+ if pipe is None:
21
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
22
+ "karthikAI/InstantID-i2i", # Your HF repo with InstantID adapter
23
+ revision="main",
24
+ torch_dtype=torch.float16,
25
+ safety_checker=None
26
+ ).to(self.device)
27
+ pipe.enable_attention_slicing()
28
 
29
  def inference(self, model_inputs: dict) -> dict:
30
+ """
31
+ Run a single img2img inference.
32
+
33
+ Expects a JSON payload with:
34
+ - "inputs": base64-encoded input image
35
+ - "parameters": {
36
+ "prompt": str,
37
+ "strength": float,
38
+ "guidance_scale": float,
39
+ "num_inference_steps": int,
40
+ }
41
+ Returns a dict with:
42
+ - "generated_image_base64": base64-encoded PNG
43
+ """
44
+ # 1. Decode the incoming image
45
+ b64_img = model_inputs.get("inputs")
46
+ if not b64_img:
47
+ raise ValueError("No image data provided under 'inputs'.")
48
+ image_bytes = base64.b64decode(b64_img)
49
+ init_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
50
+
51
+ # 2. Extract parameters
52
+ params = model_inputs.get("parameters", {})
53
+ prompt = params.get("prompt", "")
54
+ strength = float(params.get("strength", 0.75))
55
+ guidance_scale = float(params.get("guidance_scale", 7.5))
56
+ num_steps = int(params.get("num_inference_steps", 50))
57
+
58
+ # 3. Run the img2img pipeline
59
+ result = pipe(
60
+ prompt=prompt,
61
+ init_image=init_img,
62
+ strength=strength,
63
+ guidance_scale=guidance_scale,
64
+ num_inference_steps=num_steps,
65
+ )
66
+ out_img = result.images[0]
67
+
68
+ # 4. Encode the output image back to base64
69
+ buffer = io.BytesIO()
70
+ out_img.save(buffer, format="PNG")
71
+ generated_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
72
+
73
+ return {"generated_image_base64": generated_b64}