ksangk commited on
Commit
66a9c17
·
1 Parent(s): e10ae74

update device

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -41,10 +41,11 @@ def load_model(ckpt_path):
41
  return model
42
 
43
  def run_model(model, img: Image.Image):
 
44
  to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
45
- image = to_tensor(img).to(next(model.parameters()).device)
46
  x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)
47
- with torch.no_grad() as no_grad, torch.autocast(device_type="cuda") as amp:
48
  output = model(x)
49
  return output
50
 
 
41
  return model
42
 
43
  def run_model(model, img: Image.Image):
44
+ device = next(model.parameters()).device
45
  to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
46
+ image = to_tensor(img).to(device)
47
  x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)
48
+ with torch.no_grad() as no_grad, torch.autocast(device_type=device.type) as amp:
49
  output = model(x)
50
  return output
51