KarthikAI commited on
Commit
7361a5e
·
verified ·
1 Parent(s): 60325ef

Delete handler.py

Browse files
Files changed (1) hide show
  1. handler.py +0 -36
handler.py DELETED
@@ -1,36 +0,0 @@
1
- import base64
2
- import io
3
- from PIL import Image
4
- import torch
5
- from diffusers import StableDiffusionImg2ImgPipeline
6
-
7
- pipe = None
8
-
9
- class EndpointHandler:
10
- def __init__(self,model_dir):
11
- self.model_dir = model_dir
12
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
13
-
14
- def init(self):
15
- global pipe
16
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
17
- "InstantX/InstantID",
18
- torch_dtype=torch.float16,
19
- safety_checker=None
20
- ).to(self.device)
21
- pipe.enable_attention_slicing()
22
-
23
- def inference(self, model_inputs: dict) -> dict:
24
- img_bytes = base64.b64decode(model_inputs.get("image_base64"))
25
- init_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
26
- output = pipe(
27
- prompt=model_inputs.get("prompt", ""),
28
- image=init_image,
29
- strength=float(model_inputs.get("strength", 0.75)),
30
- guidance_scale=float(model_inputs.get("guidance_scale", 7.5)),
31
- num_inference_steps=int(model_inputs.get("num_inference_steps", 50)),
32
- )
33
- sticker = output.images[0]
34
- buf = io.BytesIO()
35
- sticker.save(buf, format="PNG")
36
- return {"generated_image_base64": base64.b64encode(buf.getvalue()).decode("utf-8")}