Spaces:
Sleeping
Sleeping
import os
from typing import Optional

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch import Tensor, nn

from .app_utils import count_parameters
from .device import device
from .dpt.models import DPTDepthModel
class BaseDepthModel:
    """Abstract interface for monocular depth-estimation models."""

    def __init__(self, image_size: int) -> None:
        """
        Args:
            image_size: Side length in pixels to which inputs are resized.
        """
        self.image_size = image_size
        # Concrete subclasses must assign the actual network; None until then.
        self.model: Optional[nn.Module] = None

    def forward(self, image: Tensor) -> Tensor:
        """Perform forward inference for a single image.

        Args:
            image: Tensor of shape [c, h, w].

        Returns:
            Depth prediction of shape [c, h, w].

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError()

    def batch_forward(self, images: Tensor) -> Tensor:
        """Perform forward inference for a batch of images.

        Args:
            images: Tensor of shape [b, c, h, w].

        Returns:
            Depth predictions of shape [b, c, h, w].

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError()

    def get_number_of_parameters(self) -> int:
        """Return the parameter count of the wrapped model."""
        return count_parameters(self.model)
class DPTDepth(BaseDepthModel):
    """Monocular depth estimation via an omnidata-pretrained DPT-hybrid model.

    On first use the checkpoint is downloaded from the HuggingFace Hub and
    cached under ``weights/``; subsequent constructions load it from disk.
    """

    def __init__(self, image_size: int) -> None:
        """
        Args:
            image_size: Side length in pixels to which inputs are resized
                before being fed to the network.
        """
        super().__init__(image_size)
        print("DPTDepth constructor")
        weights_fname = "omnidata_rgb2depth_dpt_hybrid.pth"
        weights_path = os.path.join("weights", weights_fname)
        if not os.path.isfile(weights_path):
            # Lazy import: huggingface_hub is only needed on the first run,
            # when the checkpoint is not yet cached locally.
            from huggingface_hub import hf_hub_download

            hf_hub_download(
                repo_id="RGBD-SOD/S-MultiMAE",
                filename=weights_fname,
                local_dir="weights",
            )
        # NOTE(review): torch.load on a downloaded file — prefer
        # weights_only=True (torch >= 1.13) if this checkpoint is a plain
        # state dict; confirm before enabling.
        omnidata_ckpt = torch.load(
            weights_path,
            map_location="cpu",
        )
        self.model = DPTDepthModel()
        self.model.load_state_dict(omnidata_ckpt)
        self.model: DPTDepthModel = self.model.to(device).eval()
        # Resize then map [0, 1] inputs to [-1, 1]. Both transforms operate
        # on [..., c, h, w] tensors, so this pipeline is shared by forward()
        # and batch_forward().
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (self.image_size, self.image_size),
                    interpolation=TF.InterpolationMode.BICUBIC,
                ),
                transforms.Normalize(
                    (0.5, 0.5, 0.5),
                    (0.5, 0.5, 0.5),
                ),
            ]
        )

    def forward(self, image: Tensor) -> Tensor:
        """Perform forward inference for a single image.

        Args:
            image: Tensor of shape [c, h, w] (presumably RGB in [0, 1] —
                see the Normalize transform).

        Returns:
            Depth prediction of shape [c, h, w] on ``device``.
        """
        # Inference only: skip autograd bookkeeping on the eval-mode model.
        with torch.no_grad():
            depth_model_input = self.transform(image.unsqueeze(0))
            return self.model.forward(depth_model_input.to(device)).squeeze(0)

    def batch_forward(self, images: Tensor) -> Tensor:
        """Perform forward inference for a batch of images.

        Args:
            images: Tensor of shape [b, c, h, w].

        Returns:
            Depth predictions of shape [b, c, h, w] on ``device``.
        """
        # Reuse the shared transform instead of duplicating the
        # resize/normalize logic inline (previously (x - 0.5) / 0.5 was
        # re-implemented here, which could drift from forward()).
        with torch.no_grad():
            depth_model_input = self.transform(images)
            return self.model(depth_model_input.to(device))