import cv2
import torch
import torch.nn as nn
from torchvision.transforms import Compose

from .dpt_depth import DPTDepthModel
from .midas_net import MidasNet
from .midas_net_custom import MidasNet_small
from .transforms import NormalizeImage, PrepareForNet, Resize


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def load_midas_transform(model_type):
    """Build the MiDaS input transform (resize + normalize + layout) for the
    given model type, without loading any weights."""
    if model_type == 'dpt_large':  # DPT-Large
        net_w, net_h = 384, 384
        resize_mode = 'minimal'
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
                                       std=[0.5, 0.5, 0.5])

    elif model_type == 'dpt_hybrid':  # DPT-Hybrid
        net_w, net_h = 384, 384
        resize_mode = 'minimal'
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
                                       std=[0.5, 0.5, 0.5])

    elif model_type == 'midas_v21':
        net_w, net_h = 384, 384
        resize_mode = 'upper_bound'
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])

    elif model_type == 'midas_v21_small':
        net_w, net_h = 256, 256
        resize_mode = 'upper_bound'
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])

    else:
        raise NotImplementedError(
            f"model_type '{model_type}' not implemented; choose one of: "
            f"dpt_large, dpt_hybrid, midas_v21, midas_v21_small"
        )

    transform = Compose([
        Resize(
            net_w,
            net_h,
            resize_target=None,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method=resize_mode,
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        normalization,
        PrepareForNet(),
    ])

    return transform
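

# A minimal usage sketch for the transform (illustrative, not part of the
# original module). The MiDaS transforms operate on a dict with an "image"
# key holding an RGB float array scaled to [0, 1]; PrepareForNet yields a
# float32 CHW ndarray. The file path below is hypothetical.
#
#   transform = load_midas_transform('dpt_hybrid')
#   img = cv2.cvtColor(cv2.imread('input.png'), cv2.COLOR_BGR2RGB) / 255.0
#   net_input = transform({'image': img})['image']    # float32, 3xHxW
#   batch = torch.from_numpy(net_input).unsqueeze(0)  # 1x3xHxW tensor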


def load_model(model_type, model_path):
    """Load a MiDaS depth model from a local checkpoint.

    Returns the model in eval mode together with the matching input
    transform.
    """
    if model_type == 'dpt_large':  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone='vitl16_384',
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = 'minimal'
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
                                       std=[0.5, 0.5, 0.5])

    elif model_type == 'dpt_hybrid':  # DPT-Hybrid
        model = DPTDepthModel(
            path=model_path,
            backbone='vitb_rn50_384',
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = 'minimal'
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
                                       std=[0.5, 0.5, 0.5])

    elif model_type == 'midas_v21':
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = 'upper_bound'
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])

    elif model_type == 'midas_v21_small':
        model = MidasNet_small(model_path,
                               features=64,
                               backbone='efficientnet_lite3',
                               exportable=True,
                               non_negative=True,
                               blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = 'upper_bound'
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])

    else:
        raise NotImplementedError(
            f"model_type '{model_type}' not implemented; choose one of: "
            f"dpt_large, dpt_hybrid, midas_v21, midas_v21_small"
        )

    transform = Compose([
        Resize(
            net_w,
            net_h,
            resize_target=None,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method=resize_mode,
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        normalization,
        PrepareForNet(),
    ])

    return model.eval(), transform
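

# Sketch of the intended pairing of model and transform (illustrative; the
# checkpoint path is hypothetical). The model comes back in eval mode, and
# the prediction is an inverse relative depth map.
#
#   model, transform = load_model('dpt_hybrid', 'checkpoints/dpt_hybrid.pt')
#   x = torch.from_numpy(transform({'image': img})['image']).unsqueeze(0)
#   with torch.no_grad():
#       depth = model(x)  # shape (1, H, W)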


class MiDaSInference(nn.Module):
    MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small']
    MODEL_TYPES_ISL = [
        'dpt_large',
        'dpt_hybrid',
        'midas_v21',
        'midas_v21_small',
    ]

    def __init__(self, model_type, model_path):
        super().__init__()
        assert model_type in self.MODEL_TYPES_ISL
        model, _ = load_model(model_type, model_path)
        self.model = model
        # Bind disabled_train as a method so that later calls to
        # self.model.train() are no-ops and the model stays in eval mode.
        self.model.train = disabled_train.__get__(self.model)

    def forward(self, x):
        with torch.no_grad():
            prediction = self.model(x)
        return prediction
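

# Minimal usage sketch for MiDaSInference (illustrative; the checkpoint path
# is hypothetical). __init__ discards the transform returned by load_model,
# so inputs must be preprocessed separately, e.g. via load_midas_transform:
#
#   midas = MiDaSInference('dpt_hybrid', 'checkpoints/dpt_hybrid.pt')
#   transform = load_midas_transform('dpt_hybrid')
#   x = torch.from_numpy(transform({'image': img})['image']).unsqueeze(0)
#   depth = midas(x)  # (1, H, W) inverse relative depth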