Update handler.py
Browse files- handler.py +4 -3
handler.py
CHANGED
|
@@ -11,8 +11,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
| 11 |
if device.type != 'cuda':
|
| 12 |
raise ValueError("Need to run on GPU")
|
| 13 |
|
| 14 |
-
class
|
| 15 |
-
def __init__(self, path="mrcuddle/
|
| 16 |
"""Load the SDXL Inpainting model."""
|
| 17 |
self.pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
|
| 18 |
path, torch_dtype=torch.float16
|
|
@@ -69,7 +69,8 @@ class SDXLInpaintHandler:
|
|
| 69 |
image.save(buffered, format="PNG")
|
| 70 |
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 71 |
|
| 72 |
-
|
|
|
|
| 73 |
|
| 74 |
def handle(data: dict):
|
| 75 |
return handler(data)
|
|
|
|
| 11 |
if device.type != 'cuda':
|
| 12 |
raise ValueError("Need to run on GPU")
|
| 13 |
|
| 14 |
+
class EndpointHandler:
|
| 15 |
+
def __init__(self, path="mrcuddle/URPM-Inpaint-Hyper-SDXL"):
|
| 16 |
"""Load the SDXL Inpainting model."""
|
| 17 |
self.pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
|
| 18 |
path, torch_dtype=torch.float16
|
|
|
|
| 69 |
image.save(buffered, format="PNG")
|
| 70 |
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 71 |
|
| 72 |
+
# Create an instance of EndpointHandler
|
| 73 |
+
handler = EndpointHandler()
|
| 74 |
|
| 75 |
def handle(data: dict):
|
| 76 |
return handler(data)
|