nabeel.muhammad commited on
Commit
fdcda3c
·
1 Parent(s): 4f24bd0

fix:add device

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +2 -1
src/streamlit_app.py CHANGED
@@ -20,6 +20,7 @@ def load_model():
20
  if re.search(r'in\d+\.running_(mean|var)$', k):
21
  del state_dict[k]
22
  model.load_state_dict(state_dict)
 
23
  model.eval()
24
  return model
25
 
@@ -45,7 +46,7 @@ else:
45
  torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
46
  ])
47
 
48
- image_tensor = preprocess(image).unsqueeze(0)
49
 
50
  with torch.no_grad():
51
  output = model(image_tensor)
 
20
  if re.search(r'in\d+\.running_(mean|var)$', k):
21
  del state_dict[k]
22
  model.load_state_dict(state_dict)
23
+ model = model.to(device)
24
  model.eval()
25
  return model
26
 
 
46
  torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
47
  ])
48
 
49
+ image_tensor = preprocess(image).unsqueeze(0).to(device)
50
 
51
  with torch.no_grad():
52
  output = model(image_tensor)