mrcuddle commited on
Commit
7235d74
·
verified ·
1 Parent(s): 9b59a85

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +75 -0
handler.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ import base64
4
+ import io
5
+ from PIL import Image
6
+ from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLInpaintPipeline
7
+
8
+ # Set device
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+ if device.type != 'cuda':
12
+ raise ValueError("Need to run on GPU")
13
+
14
+ class SDXLInpaintHandler:
15
+ def __init__(self, path="mrcuddle/URPM-Inpaint-Hyper-SDXL"):
16
+ """Load the SDXL Inpainting model."""
17
+ self.pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
18
+ path, torch_dtype=torch.float16
19
+ )
20
+ self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
21
+ self.pipeline = self.pipeline.to(device)
22
+
23
+ def __call__(self, data: dict):
24
+ """Custom call function for Hugging Face Inference Endpoints."""
25
+ try:
26
+ inputs = data.pop("inputs", data)
27
+ encoded_image = data.pop("image", None)
28
+ encoded_mask_image = data.pop("mask_image", None)
29
+
30
+ num_inference_steps = data.pop("num_inference_steps", 25)
31
+ guidance_scale = data.pop("guidance_scale", 7.5)
32
+ negative_prompt = data.pop("negative_prompt", None)
33
+ height = data.pop("height", None)
34
+ width = data.pop("width", None)
35
+
36
+ # Process images
37
+ if encoded_image and encoded_mask_image:
38
+ image = self.decode_base64_image(encoded_image)
39
+ mask_image = self.decode_base64_image(encoded_mask_image)
40
+ else:
41
+ raise ValueError("Both image and mask_image are required")
42
+
43
+ # Run inference
44
+ output_image = self.pipeline(
45
+ prompt=inputs,
46
+ image=image,
47
+ mask_image=mask_image,
48
+ num_inference_steps=num_inference_steps,
49
+ guidance_scale=guidance_scale,
50
+ num_images_per_prompt=1,
51
+ negative_prompt=negative_prompt,
52
+ height=height,
53
+ width=width
54
+ ).images[0]
55
+
56
+ return json.dumps({"output": self.encode_base64_image(output_image)})
57
+ except Exception as e:
58
+ return json.dumps({"error": str(e)})
59
+
60
+ def decode_base64_image(self, image_string):
61
+ """Decode base64 encoded image."""
62
+ base64_image = base64.b64decode(image_string)
63
+ buffer = io.BytesIO(base64_image)
64
+ return Image.open(buffer).convert("RGB")
65
+
66
+ def encode_base64_image(self, image):
67
+ """Encode PIL image to base64."""
68
+ buffered = io.BytesIO()
69
+ image.save(buffered, format="PNG")
70
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
71
+
72
+ handler = SDXLInpaintHandler()
73
+
74
+ def handle(data: dict):
75
+ return handler(data)