Spaces:
Runtime error
Runtime error
Update depth_app.py
Browse files- depth_app.py +6 -3
depth_app.py
CHANGED
|
@@ -7,8 +7,11 @@ import io
|
|
| 7 |
from torchvision import transforms
|
| 8 |
import matplotlib as mpl
|
| 9 |
import matplotlib.cm as cm
|
| 10 |
-
import networks
|
| 11 |
from layers import disp_to_depth_no_scaling
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# Function to load the model
|
| 14 |
def load_model(device, model_path):
|
|
@@ -16,7 +19,7 @@ def load_model(device, model_path):
|
|
| 16 |
encoder_path = os.path.join(model_path, "encoder.pth")
|
| 17 |
depth_decoder_path = os.path.join(model_path, "depth.pth")
|
| 18 |
|
| 19 |
-
encoder =
|
| 20 |
loaded_dict_enc = torch.load(encoder_path, map_location=device)
|
| 21 |
feed_height = loaded_dict_enc['height']
|
| 22 |
feed_width = loaded_dict_enc['width']
|
|
@@ -25,7 +28,7 @@ def load_model(device, model_path):
|
|
| 25 |
encoder.to(device)
|
| 26 |
encoder.eval()
|
| 27 |
|
| 28 |
-
depth_decoder =
|
| 29 |
loaded_dict = torch.load(depth_decoder_path, map_location=device)
|
| 30 |
depth_decoder.load_state_dict(loaded_dict, strict=False)
|
| 31 |
depth_decoder.to(device)
|
|
|
|
| 7 |
from torchvision import transforms
|
| 8 |
import matplotlib as mpl
|
| 9 |
import matplotlib.cm as cm
|
|
|
|
| 10 |
from layers import disp_to_depth_no_scaling
|
| 11 |
+
from resnet_encoder import ResnetEncoder
|
| 12 |
+
from depth_decoder import DepthDecoder
|
| 13 |
+
# from pose_decoder import PoseDecoder
|
| 14 |
+
# from pose_cnn import PoseCNN
|
| 15 |
|
| 16 |
# Function to load the model
|
| 17 |
def load_model(device, model_path):
|
|
|
|
| 19 |
encoder_path = os.path.join(model_path, "encoder.pth")
|
| 20 |
depth_decoder_path = os.path.join(model_path, "depth.pth")
|
| 21 |
|
| 22 |
+
encoder = ResnetEncoder(18, False)
|
| 23 |
loaded_dict_enc = torch.load(encoder_path, map_location=device)
|
| 24 |
feed_height = loaded_dict_enc['height']
|
| 25 |
feed_width = loaded_dict_enc['width']
|
|
|
|
| 28 |
encoder.to(device)
|
| 29 |
encoder.eval()
|
| 30 |
|
| 31 |
+
depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
|
| 32 |
loaded_dict = torch.load(depth_decoder_path, map_location=device)
|
| 33 |
depth_decoder.load_state_dict(loaded_dict, strict=False)
|
| 34 |
depth_decoder.to(device)
|