"""TIPSv2 DPT dense prediction model for HuggingFace."""
import importlib
import importlib.util
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, PreTrainedModel

from .configuration_dpt import TIPSv2DPTConfig
_this_dir = Path(__file__).parent

# Cache of dynamically loaded sibling modules, keyed by module name, so each
# sibling file is executed at most once per process.
_sibling_cache = {}


def _load_sibling(name, repo_id=None):
    """Dynamically import a sibling ``<name>.py`` module.

    Looks for the file next to this module first; if it is missing and a
    ``repo_id`` is given, downloads it from the Hugging Face Hub. Results are
    memoized in ``_sibling_cache``.

    Args:
        name: Module name without the ``.py`` suffix.
        repo_id: Optional Hub repo to fetch the file from when not found locally.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader could be created for the file.
    """
    # Bare ``import importlib`` does not guarantee the ``importlib.util``
    # submodule is bound; import it explicitly here so this function is
    # self-contained.
    import importlib.util

    if name in _sibling_cache:
        return _sibling_cache[name]
    path = _this_dir / f"{name}.py"
    if not path.exists() and repo_id:
        path = Path(hf_hub_download(repo_id, f"{name}.py"))
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create import spec for {name!r} at {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    _sibling_cache[name] = mod
    return mod
@dataclass
class TIPSv2DPTOutput:
    """Container for the three dense prediction outputs.

    Any field left as ``None`` means the corresponding head was not run.
    """

    # Depth map; produced by the depth head.
    depth: Optional[torch.Tensor] = None
    # Surface normals; produced by the normals head.
    normals: Optional[torch.Tensor] = None
    # Segmentation logits; produced by the segmentation head.
    segmentation: Optional[torch.Tensor] = None
class TIPSv2DPTModel(PreTrainedModel):
    """Dense prediction model (depth, surface normals, semantic segmentation)
    built on a TIPSv2 vision backbone with DPT decoder heads.

    The backbone is fetched lazily from ``config.backbone_repo`` the first
    time features are needed.

    Usage::

        model = AutoModel.from_pretrained("google/tipsv2-l14-dpt", trust_remote_code=True)
        model.eval().cuda()
        outputs = model(pixel_values)
        outputs.depth         # (B, 1, H, W)
        outputs.normals       # (B, 3, H, W)
        outputs.segmentation  # (B, 150, H, W)

        # Single-task entry points
        depth = model.predict_depth(pixel_values)
        normals = model.predict_normals(pixel_values)
        seg = model.predict_segmentation(pixel_values)
    """

    config_class = TIPSv2DPTConfig
    _no_split_modules = []
    _supports_cache_class = False
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # This model ties no weights.
        return {}

    def __init__(self, config: TIPSv2DPTConfig):
        super().__init__(config)
        # The DPT head implementations live in a sibling file; resolve it
        # locally or from the repo this config was loaded from.
        repo = getattr(config, "_name_or_path", None)
        heads = _load_sibling("dpt_head", repo)
        # Constructor arguments common to all three heads.
        shared = dict(
            input_embed_dim=config.embed_dim,
            channels=config.channels,
            post_process_channels=tuple(config.post_process_channels),
            readout_type=config.readout_type,
        )
        self.depth_head = heads.DPTDepthHead(
            num_depth_bins=config.num_depth_bins,
            min_depth=config.min_depth,
            max_depth=config.max_depth,
            **shared,
        )
        self.normals_head = heads.DPTNormalsHead(**shared)
        self.segmentation_head = heads.DPTSegmentationHead(
            num_classes=config.num_seg_classes, **shared
        )
        # Populated on first use by _get_backbone().
        self._backbone = None

    def _get_backbone(self):
        """Return the backbone's vision encoder, downloading it on first use."""
        if self._backbone is None:
            self._backbone = AutoModel.from_pretrained(
                self.config.backbone_repo, trust_remote_code=True
            )
            self._backbone.to(self.device).eval()
        return self._backbone.vision_encoder

    def _extract_intermediate(self, pixel_values):
        """Collect (cls_token, patch_features) pairs from the configured blocks."""
        encoder = self._get_backbone()
        layers = encoder.get_intermediate_layers(
            pixel_values,
            n=self.config.block_indices,
            reshape=True,
            return_class_token=True,
            norm=True,
        )
        # The encoder yields (patch_features, cls_token) pairs; the DPT heads
        # consume them in the opposite order.
        return [(cls_token, patches) for patches, cls_token in layers]

    def _prepare(self, pixel_values):
        """Move inputs to the model device; return (head_inputs, (H, W))."""
        pixel_values = pixel_values.to(self.device)
        height, width = pixel_values.shape[2:]
        return self._extract_intermediate(pixel_values), (height, width)

    @torch.no_grad()
    def predict_depth(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict depth map. Returns (B, 1, H, W)."""
        head_inputs, size = self._prepare(pixel_values)
        return self.depth_head(head_inputs, image_size=size)

    @torch.no_grad()
    def predict_normals(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict surface normals. Returns (B, 3, H, W)."""
        head_inputs, size = self._prepare(pixel_values)
        return self.normals_head(head_inputs, image_size=size)

    @torch.no_grad()
    def predict_segmentation(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict semantic segmentation (ADE20K). Returns (B, 150, H, W)."""
        head_inputs, size = self._prepare(pixel_values)
        return self.segmentation_head(head_inputs, image_size=size)

    def forward(self, pixel_values: torch.Tensor) -> TIPSv2DPTOutput:
        """Run all three tasks on shared backbone features. Returns TIPSv2DPTOutput."""
        head_inputs, size = self._prepare(pixel_values)
        return TIPSv2DPTOutput(
            depth=self.depth_head(head_inputs, image_size=size),
            normals=self.normals_head(head_inputs, image_size=size),
            segmentation=self.segmentation_head(head_inputs, image_size=size),
        )