oyly commited on
Commit
4cebc5d
·
1 Parent(s): 8c6d269

load ae on gpu

Browse files
Files changed (1) hide show
  1. app.py +1 -0
app.py CHANGED
@@ -21,6 +21,7 @@ def encode(init_image, torch_device, ae):
21
  init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
22
  init_image = init_image.unsqueeze(0)
23
  init_image = init_image.to(torch_device)
 
24
  init_image = ae.encode(init_image.to()).to(torch.bfloat16)
25
  return init_image
26
  from torchvision import transforms
 
21
  init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
22
  init_image = init_image.unsqueeze(0)
23
  init_image = init_image.to(torch_device)
24
+ ae.to(torch_device)
25
  init_image = ae.encode(init_image.to()).to(torch.bfloat16)
26
  return init_image
27
  from torchvision import transforms