Alexmikupro commited on
Commit
6dfcee3
·
1 Parent(s): e1150b0

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +2 -2
handler.py CHANGED
@@ -14,14 +14,14 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
 
15
  class EndpointHandler():
16
 
17
- def __init__(self, path="", esrgan_model_path=""):
18
  # Load the StableDiffusionPipeline
19
  self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
20
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  self.pipe = self.pipe.to(device)
22
 
23
  # Load the ESRGAN state dictionary
24
- checkpoint = torch.load(esrgan_model_path)
25
 
26
  # Check if 'params_ema' is in the keys and filter the state_dict
27
  if "params_ema" in checkpoint:
 
14
 
15
  class EndpointHandler():
16
 
17
+ def __init__(self, path=""):
18
  # Load the StableDiffusionPipeline
19
  self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
20
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  self.pipe = self.pipe.to(device)
22
 
23
  # Load the ESRGAN state dictionary
24
+ checkpoint = torch.load("RealESRGAN_x4plus_anime_6B.pth")
25
 
26
  # Check if 'params_ema' is in the keys and filter the state_dict
27
  if "params_ema" in checkpoint: