Spaces:
Runtime error
Runtime error
| from lib import network_auxi as network | |
| from lib.net_tools import get_func | |
| import torch | |
| import torch.nn as nn | |
class RelDepthModel(nn.Module):
    """Relative-depth predictor that wraps a ``DepthModel`` with a named backbone.

    Args:
        backbone: ``'resnet50'`` or ``'resnext101'``; selects the encoder
            identifier handed to ``DepthModel``.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """

    # Public backbone name -> encoder identifier passed to DepthModel.
    _ENCODERS = {
        'resnet50': 'resnet50_stride32',
        'resnext101': 'resnext101_stride32x8d',
    }

    def __init__(self, backbone='resnet50'):
        super(RelDepthModel, self).__init__()
        encoder = self._ENCODERS.get(backbone)
        if encoder is None:
            # Fail loudly here instead of hitting an unbound local later.
            raise ValueError(
                'unsupported backbone {!r}; expected one of {}'.format(
                    backbone, sorted(self._ENCODERS)))
        # Attribute name is kept as-is: it prefixes the state_dict keys.
        self.depth_model = DepthModel(encoder)

    def inference(self, rgb):
        """Run the depth network on ``rgb`` with gradient tracking disabled.

        Args:
            rgb: Input tensor for ``depth_model`` (presumably an image batch
                of shape (N, C, H, W) — confirm against callers).

        Returns:
            The raw output of ``depth_model`` for ``rgb``.
        """
        # Follow the model's own placement rather than hard-coding CUDA:
        # `.cuda()` fails on CPU-only hosts and targets the *current* GPU,
        # which need not be the one holding the weights.
        first_param = next(self.parameters(), None)
        with torch.no_grad():
            if first_param is not None:
                rgb = rgb.to(first_param.device)
            depth = self.depth_model(rgb)
        return depth
class DepthModel(nn.Module):
    """Encoder-decoder depth network assembled from ``lib.network_auxi`` parts.

    Args:
        encoder: Name of an encoder factory. It is qualified with the last
            component of ``network``'s module name (giving
            ``'network_auxi.<encoder>'``), resolved with ``get_func`` and
            called with no arguments to build the encoder. ``get_func`` is
            presumably a dotted-name lookup; confirm in ``lib.net_tools``.
    """

    def __init__(self, encoder):
        super().__init__()
        # Only the final component of the module name is kept.
        module_short_name = network.__name__.rsplit('.', 1)[-1]
        encoder_factory = get_func('.'.join((module_short_name, encoder)))
        # Registration order and attribute names determine state_dict keys.
        self.encoder_modules = encoder_factory()
        self.decoder_modules = network.Decoder()

    def forward(self, x):
        """Encode ``x`` and decode the encoder output into the prediction.

        The encoder output is presumably a set of lateral feature maps
        (per the original local name ``lateral_out``) — confirm in
        ``network_auxi``.
        """
        return self.decoder_modules(self.encoder_modules(x))