Spaces:
Runtime error
Runtime error
igashov
commited on
Commit
·
d5b42eb
1
Parent(s):
f97fe8b
update device
Browse files- src/lightning.py +2 -2
- src/linker_size_lightning.py +1 -1
src/lightning.py
CHANGED
|
@@ -55,7 +55,7 @@ class DDPM(pl.LightningModule):
|
|
| 55 |
self.val_data_prefix = val_data_prefix
|
| 56 |
self.batch_size = batch_size
|
| 57 |
self.lr = lr
|
| 58 |
-
self.torch_device =
|
| 59 |
self.include_charges = include_charges
|
| 60 |
self.test_epochs = test_epochs
|
| 61 |
self.n_stability_samples = n_stability_samples
|
|
@@ -81,7 +81,7 @@ class DDPM(pl.LightningModule):
|
|
| 81 |
in_node_nf=in_node_nf,
|
| 82 |
n_dims=n_dims,
|
| 83 |
context_node_nf=context_node_nf,
|
| 84 |
-
device=torch_device,
|
| 85 |
hidden_nf=hidden_nf,
|
| 86 |
activation=activation,
|
| 87 |
n_layers=n_layers,
|
|
|
|
| 55 |
self.val_data_prefix = val_data_prefix
|
| 56 |
self.batch_size = batch_size
|
| 57 |
self.lr = lr
|
| 58 |
+
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 59 |
self.include_charges = include_charges
|
| 60 |
self.test_epochs = test_epochs
|
| 61 |
self.n_stability_samples = n_stability_samples
|
|
|
|
| 81 |
in_node_nf=in_node_nf,
|
| 82 |
n_dims=n_dims,
|
| 83 |
context_node_nf=context_node_nf,
|
| 84 |
+
device=self.torch_device,
|
| 85 |
hidden_nf=hidden_nf,
|
| 86 |
activation=activation,
|
| 87 |
n_layers=n_layers,
|
src/linker_size_lightning.py
CHANGED
|
@@ -45,7 +45,7 @@ class SizeClassifier(pl.LightningModule):
|
|
| 45 |
hidden_nf=hidden_nf,
|
| 46 |
out_node_nf=out_node_nf,
|
| 47 |
n_layers=n_layers,
|
| 48 |
-
device=torch_device,
|
| 49 |
normalization=normalization,
|
| 50 |
)
|
| 51 |
|
|
|
|
| 45 |
hidden_nf=hidden_nf,
|
| 46 |
out_node_nf=out_node_nf,
|
| 47 |
n_layers=n_layers,
|
| 48 |
+
device=self.torch_device,
|
| 49 |
normalization=normalization,
|
| 50 |
)
|
| 51 |
|