Spaces:
Sleeping
Sleeping
Update depth_decoder.py
Browse files- depth_decoder.py +3 -1
depth_decoder.py
CHANGED
|
@@ -18,6 +18,8 @@ class DepthDecoder(nn.Module):
|
|
| 18 |
def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True, batch_norm = True):
|
| 19 |
super(DepthDecoder, self).__init__()
|
| 20 |
|
|
|
|
|
|
|
| 21 |
self.num_output_channels = num_output_channels
|
| 22 |
self.use_skips = use_skips
|
| 23 |
self.upsample_mode = 'nearest'
|
|
@@ -70,7 +72,7 @@ class DepthDecoder(nn.Module):
|
|
| 70 |
x = torch.cat(x, 1)
|
| 71 |
x = self.convs[("upconv", i, 1)](x)
|
| 72 |
if self.batch_norm:
|
| 73 |
-
x = self.bn[('bn', i)].
|
| 74 |
|
| 75 |
|
| 76 |
# batchnorm
|
|
|
|
| 18 |
def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True, batch_norm = True):
|
| 19 |
super(DepthDecoder, self).__init__()
|
| 20 |
|
| 21 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 22 |
+
|
| 23 |
self.num_output_channels = num_output_channels
|
| 24 |
self.use_skips = use_skips
|
| 25 |
self.upsample_mode = 'nearest'
|
|
|
|
| 72 |
x = torch.cat(x, 1)
|
| 73 |
x = self.convs[("upconv", i, 1)](x)
|
| 74 |
if self.batch_norm:
|
| 75 |
+
x = self.bn[('bn', i)].to(self.device)(x)
|
| 76 |
|
| 77 |
|
| 78 |
# batchnorm
|