re-enable gpu support
Browse files
app.py
CHANGED
|
@@ -31,7 +31,7 @@ from rocketshp.network import (
|
|
| 31 |
)
|
| 32 |
|
| 33 |
os.environ["OMP_NUM_THREADS"] = "4"
|
| 34 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
| 35 |
|
| 36 |
|
| 37 |
def plot_predictions(
|
|
@@ -138,8 +138,7 @@ def predict_rocketshp(
|
|
| 138 |
raise gr.Error("Failed to authorize repository access.")
|
| 139 |
|
| 140 |
# Load the model
|
| 141 |
-
|
| 142 |
-
device = "cpu"
|
| 143 |
model = RocketSHP.load_from_checkpoint(model_variant).to(device)
|
| 144 |
is_sequence_model = "seq" in model_variant or "mini" in model_variant
|
| 145 |
|
|
|
|
| 31 |
)
|
| 32 |
|
| 33 |
os.environ["OMP_NUM_THREADS"] = "4"
|
| 34 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
| 35 |
|
| 36 |
|
| 37 |
def plot_predictions(
|
|
|
|
| 138 |
raise gr.Error("Failed to authorize repository access.")
|
| 139 |
|
| 140 |
# Load the model
|
| 141 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 142 |
model = RocketSHP.load_from_checkpoint(model_variant).to(device)
|
| 143 |
is_sequence_model = "seq" in model_variant or "mini" in model_variant
|
| 144 |
|