phoebe777777 commited on
Commit
06115fd
·
verified ·
1 Parent(s): d55530a

Update inference_seg.py

Browse files
Files changed (1) hide show
  1. inference_seg.py +5 -2
inference_seg.py CHANGED
@@ -18,13 +18,16 @@ def load_model(use_box=False):
18
  )
19
  MODEL.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
20
  MODEL.eval()
21
- DEVICE = torch.device("cpu")
 
 
 
22
  return MODEL, DEVICE
23
 
24
 
25
  @torch.no_grad()
26
  def run(model, img_path, box=None, device="cpu"):
27
- print(device)
28
  model.move_to_device(device)
29
  model.eval()
30
  with torch.no_grad():
 
18
  )
19
  MODEL.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
20
  MODEL.eval()
21
+ if torch.cuda.is_available():
22
+ DEVICE = torch.device("cuda")
23
+ else:
24
+ DEVICE = torch.device("cpu")
25
  return MODEL, DEVICE
26
 
27
 
28
  @torch.no_grad()
29
  def run(model, img_path, box=None, device="cpu"):
30
+ print("DEVICE": device)
31
  model.move_to_device(device)
32
  model.eval()
33
  with torch.no_grad():