| |
|
| |
|
| |
|
| |
|
| | import cv2
|
| | import torch
|
| | import torch.nn as nn
|
| | from torchvision.transforms import Compose
|
| |
|
| | from .dpt_depth import DPTDepthModel
|
| | from .midas_net import MidasNet
|
| | from .midas_net_custom import MidasNet_small
|
| | from .transforms import NormalizeImage, PrepareForNet, Resize
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def disabled_train(self, mode=True):
    """Stand-in for ``nn.Module.train`` that never switches modes.

    Install it in place of a module's ``train`` so that subsequent
    train/eval mode changes no longer take effect.

    Args:
        mode: Accepted only so the call signature matches
            ``nn.Module.train(mode=True)``; it is ignored.

    Returns:
        ``self``, unchanged.
    """
    del mode  # deliberately unused
    return self
|
| |
|
| |
|
def load_midas_transform(model_type):
    """Build the input preprocessing pipeline for a MiDaS/DPT model type.

    The pipeline is ``Resize`` (keep_aspect_ratio=True, ensure_multiple_of=32,
    bicubic interpolation), then ``NormalizeImage``, then ``PrepareForNet``.

    Args:
        model_type: One of ``'dpt_large'``, ``'dpt_hybrid'``, ``'midas_v21'``
            or ``'midas_v21_small'``.

    Returns:
        A ``Compose`` of the three transforms above, configured for
        ``model_type``.

    Raises:
        AssertionError: If ``model_type`` is not supported. It is raised
            explicitly (not via ``assert``) so the check survives ``python -O``.
    """
    dpt_stats = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # model_type -> (net_w, net_h, resize_method, (mean, std))
    configs = {
        'dpt_large': (384, 384, 'minimal', dpt_stats),
        'dpt_hybrid': (384, 384, 'minimal', dpt_stats),
        'midas_v21': (384, 384, 'upper_bound', imagenet_stats),
        'midas_v21_small': (256, 256, 'upper_bound', imagenet_stats),
    }
    if model_type not in configs:
        raise AssertionError(
            f"model_type '{model_type}' not implemented, use one of: "
            f"{', '.join(configs)}"
        )

    net_w, net_h, resize_mode, (mean, std) = configs[model_type]
    normalization = NormalizeImage(mean=mean, std=std)

    transform = Compose([
        Resize(
            net_w,
            net_h,
            resize_target=None,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method=resize_mode,
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        normalization,
        PrepareForNet(),
    ])

    return transform
|
| |
|
| |
|
def load_model(model_type, model_path):
    """Construct a MiDaS/DPT depth model and its input preprocessing pipeline.

    Args:
        model_type: One of ``'dpt_large'``, ``'dpt_hybrid'``, ``'midas_v21'``
            or ``'midas_v21_small'``.
        model_path: Forwarded unchanged to the model constructor
            (``path=`` for ``DPTDepthModel``, first positional argument for
            ``MidasNet`` / ``MidasNet_small``).

    Returns:
        A ``(model, transform)`` tuple: ``model`` after ``.eval()``, and
        ``transform`` as returned by ``load_midas_transform(model_type)``.

    Raises:
        AssertionError: If ``model_type`` is not supported. It is raised
            explicitly (not via ``assert``) so the check survives ``python -O``,
            and before any model is constructed.
    """
    # Deferred constructors: only the requested model is instantiated.
    builders = {
        'dpt_large': lambda: DPTDepthModel(
            path=model_path,
            backbone='vitl16_384',
            non_negative=True,
        ),
        'dpt_hybrid': lambda: DPTDepthModel(
            path=model_path,
            backbone='vitb_rn50_384',
            non_negative=True,
        ),
        'midas_v21': lambda: MidasNet(model_path, non_negative=True),
        'midas_v21_small': lambda: MidasNet_small(
            model_path,
            features=64,
            backbone='efficientnet_lite3',
            exportable=True,
            non_negative=True,
            blocks={'expand': True},
        ),
    }
    if model_type not in builders:
        raise AssertionError(
            f"model_type '{model_type}' not implemented, use one of: "
            f"{', '.join(builders)}"
        )

    model = builders[model_type]()
    # The per-model resize/normalization settings live in one place:
    # load_midas_transform builds exactly the pipeline this function needs.
    transform = load_midas_transform(model_type)

    return model.eval(), transform
|
| |
|
| |
|
class MiDaSInference(nn.Module):
    """Inference-only wrapper around a MiDaS/DPT depth model.

    ``__init__`` accepts only the names in ``MODEL_TYPES_ISL``;
    ``MODEL_TYPES_TORCH_HUB`` is not referenced by this class's methods.
    """

    MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small']
    MODEL_TYPES_ISL = [
        'dpt_large',
        'dpt_hybrid',
        'midas_v21',
        'midas_v21_small',
    ]

    def __init__(self, model_type, model_path):
        """Load the ``model_type`` model from ``model_path`` and lock its mode.

        Raises:
            AssertionError: If ``model_type`` is not in ``MODEL_TYPES_ISL``.
                It is raised explicitly so the check survives ``python -O``.
        """
        from functools import partial

        super().__init__()
        if model_type not in self.MODEL_TYPES_ISL:
            raise AssertionError(
                f"model_type '{model_type}' not implemented, use one of: "
                f"{', '.join(self.MODEL_TYPES_ISL)}"
            )
        model, _ = load_model(model_type, model_path)  # transform unused here
        self.model = model
        # Bind the model as disabled_train's `self`. Assigning the bare
        # function would store it unbound on the instance, so
        # `self.model.train()` would raise TypeError (missing `self`) and
        # `self.model.eval()` would return False instead of the model.
        # partial (rather than a bound method) stays picklable.
        self.model.train = partial(disabled_train, self.model)

    def forward(self, x):
        """Run the wrapped model on ``x`` without tracking gradients."""
        with torch.no_grad():
            prediction = self.model(x)
        return prediction
|
| |
|