mrcuddle commited on
Commit
d483ab6
·
verified ·
1 Parent(s): 6c830d7

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -3
handler.py CHANGED
@@ -11,8 +11,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
  if device.type != 'cuda':
12
  raise ValueError("Need to run on GPU")
13
 
14
- class SDXLInpaintHandler:
15
- def __init__(self, path="mrcuddle/urpm-inpaint-sdxl"):
16
  """Load the SDXL Inpainting model."""
17
  self.pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
18
  path, torch_dtype=torch.float16
@@ -69,7 +69,8 @@ class SDXLInpaintHandler:
69
  image.save(buffered, format="PNG")
70
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
71
 
72
- handler = SDXLInpaintHandler()
 
73
 
74
  def handle(data: dict):
75
  return handler(data)
 
11
  if device.type != 'cuda':
12
  raise ValueError("Need to run on GPU")
13
 
14
+ class EndpointHandler:
15
+ def __init__(self, path="mrcuddle/URPM-Inpaint-Hyper-SDXL"):
16
  """Load the SDXL Inpainting model."""
17
  self.pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
18
  path, torch_dtype=torch.float16
 
69
  image.save(buffered, format="PNG")
70
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
71
 
72
+ # Create an instance of EndpointHandler
73
+ handler = EndpointHandler()
74
 
75
  def handle(data: dict):
76
  return handler(data)