samsl commited on
Commit
f0a85c1
·
1 Parent(s): e6dcb6a

re-enable gpu support

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -31,7 +31,7 @@ from rocketshp.network import (
31
  )
32
 
33
  os.environ["OMP_NUM_THREADS"] = "4"
34
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
35
 
36
 
37
  def plot_predictions(
@@ -138,8 +138,7 @@ def predict_rocketshp(
138
  raise gr.Error("Failed to authorize repository access.")
139
 
140
  # Load the model
141
- # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
142
- device = "cpu"
143
  model = RocketSHP.load_from_checkpoint(model_variant).to(device)
144
  is_sequence_model = "seq" in model_variant or "mini" in model_variant
145
 
 
31
  )
32
 
33
  os.environ["OMP_NUM_THREADS"] = "4"
34
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
35
 
36
 
37
  def plot_predictions(
 
138
  raise gr.Error("Failed to authorize repository access.")
139
 
140
  # Load the model
141
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
142
  model = RocketSHP.load_from_checkpoint(model_variant).to(device)
143
  is_sequence_model = "seq" in model_variant or "mini" in model_variant
144