samsl commited on
Commit
4e1899a
·
1 Parent(s): 8e4970b

try cpu only

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -31,7 +31,7 @@ from rocketshp.network import (
31
  )
32
 
33
  os.environ["OMP_NUM_THREADS"] = "4"
34
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
35
 
36
 
37
  def plot_predictions(
@@ -138,7 +138,8 @@ def predict_rocketshp(
138
  raise gr.Error("Failed to authorize repository access.")
139
 
140
  # Load the model
141
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
142
  model = RocketSHP.load_from_checkpoint(model_variant).to(device)
143
  is_sequence_model = "seq" in model_variant or "mini" in model_variant
144
 
 
31
  )
32
 
33
  os.environ["OMP_NUM_THREADS"] = "4"
34
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
35
 
36
 
37
  def plot_predictions(
 
138
  raise gr.Error("Failed to authorize repository access.")
139
 
140
  # Load the model
141
+ # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
142
+ device = "cpu"
143
  model = RocketSHP.load_from_checkpoint(model_variant).to(device)
144
  is_sequence_model = "seq" in model_variant or "mini" in model_variant
145