Update handler.py
Browse files- handler.py +24 -65
handler.py
CHANGED
|
@@ -1,85 +1,44 @@
|
|
| 1 |
import base64
|
| 2 |
import io
|
| 3 |
-
import traceback
|
| 4 |
from PIL import Image
|
| 5 |
import torch
|
| 6 |
from diffusers import StableDiffusionImg2ImgPipeline
|
| 7 |
|
| 8 |
# Global pipeline instance
|
|
|
|
| 9 |
pipe = None
|
| 10 |
|
| 11 |
class EndpointHandler:
|
| 12 |
def __init__(self, model_dir: str):
|
| 13 |
-
#
|
| 14 |
-
|
| 15 |
|
| 16 |
def init(self):
|
| 17 |
-
"""
|
| 18 |
-
Load the InstantID-enhanced Stable Diffusion img2img model once when the endpoint starts.
|
| 19 |
-
"""
|
| 20 |
global pipe
|
| 21 |
if pipe is None:
|
|
|
|
| 22 |
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
| 23 |
"karthikAI/InstantID-i2i",
|
| 24 |
revision="main",
|
| 25 |
-
torch_dtype=torch.float16
|
| 26 |
-
|
| 27 |
-
).to(self.device)
|
| 28 |
-
pipe.enable_attention_slicing()
|
| 29 |
|
| 30 |
def inference(self, model_inputs: dict) -> dict:
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# 1. Decode incoming image
|
| 50 |
-
b64_img = model_inputs.get("inputs")
|
| 51 |
-
if not b64_img:
|
| 52 |
-
raise ValueError("No image data provided under 'inputs'.")
|
| 53 |
-
image_bytes = base64.b64decode(b64_img)
|
| 54 |
-
init_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 55 |
-
|
| 56 |
-
# 2. Extract parameters
|
| 57 |
-
params = model_inputs.get("parameters", {})
|
| 58 |
-
prompt = params.get("prompt", "")
|
| 59 |
-
strength = float(params.get("strength", 0.75))
|
| 60 |
-
guidance_scale = float(params.get("guidance_scale", 7.5))
|
| 61 |
-
num_steps = int(params.get("num_inference_steps", 50))
|
| 62 |
-
|
| 63 |
-
# 3. Run the img2img pipeline
|
| 64 |
-
result = pipe(
|
| 65 |
-
prompt=prompt,
|
| 66 |
-
image=init_img,
|
| 67 |
-
strength=strength,
|
| 68 |
-
guidance_scale=guidance_scale,
|
| 69 |
-
num_inference_steps=num_steps,
|
| 70 |
-
)
|
| 71 |
-
out_img = result.images[0]
|
| 72 |
-
|
| 73 |
-
# 4. Encode and return image
|
| 74 |
-
buffer = io.BytesIO()
|
| 75 |
-
out_img.save(buffer, format="PNG")
|
| 76 |
-
generated_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
| 77 |
-
return {"generated_image_base64": generated_b64}
|
| 78 |
-
|
| 79 |
-
except Exception as e:
|
| 80 |
-
# Return detailed error info for debugging
|
| 81 |
-
tb = traceback.format_exc()
|
| 82 |
-
return {
|
| 83 |
-
"error": str(e),
|
| 84 |
-
"traceback": tb
|
| 85 |
-
}
|
|
|
|
| 1 |
import base64
|
| 2 |
import io
|
|
|
|
| 3 |
from PIL import Image
|
| 4 |
import torch
|
| 5 |
from diffusers import StableDiffusionImg2ImgPipeline
|
| 6 |
|
| 7 |
# Global pipeline instance
|
| 8 |
+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 9 |
pipe = None
|
| 10 |
|
| 11 |
class EndpointHandler:
|
| 12 |
def __init__(self, model_dir: str):
|
| 13 |
+
# model_dir is ignored; HF clones your repo here
|
| 14 |
+
pass
|
| 15 |
|
| 16 |
def init(self):
|
|
|
|
|
|
|
|
|
|
| 17 |
global pipe
|
| 18 |
if pipe is None:
|
| 19 |
+
# Load your InstantID img2img model
|
| 20 |
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
| 21 |
"karthikAI/InstantID-i2i",
|
| 22 |
revision="main",
|
| 23 |
+
torch_dtype=torch.float16
|
| 24 |
+
).to(torch_device)
|
|
|
|
|
|
|
| 25 |
|
| 26 |
def inference(self, model_inputs: dict) -> dict:
|
| 27 |
+
# 1) decode base64 image
|
| 28 |
+
b64 = model_inputs.get("inputs")
|
| 29 |
+
if b64 is None:
|
| 30 |
+
return {"error": "No 'inputs' key with base64 image provided."}
|
| 31 |
+
img = Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
|
| 32 |
+
|
| 33 |
+
# 2) extract prompt
|
| 34 |
+
prompt = model_inputs.get("parameters", {}).get("prompt", "")
|
| 35 |
+
|
| 36 |
+
# 3) minimal call: prompt + image only
|
| 37 |
+
out = pipe(prompt=prompt, image=img)
|
| 38 |
+
result_img = out.images[0]
|
| 39 |
+
|
| 40 |
+
# 4) encode output
|
| 41 |
+
buf = io.BytesIO()
|
| 42 |
+
result_img.save(buf, format="PNG")
|
| 43 |
+
b64_out = base64.b64encode(buf.getvalue()).decode()
|
| 44 |
+
return {"generated_image_base64": b64_out}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|