"""TIPSv2 DPT dense prediction model for HuggingFace.""" import importlib import os from dataclasses import dataclass from pathlib import Path from typing import Optional import torch from huggingface_hub import hf_hub_download from transformers import AutoModel, PreTrainedModel from .configuration_dpt import TIPSv2DPTConfig _this_dir = Path(__file__).parent _sibling_cache = {} def _load_sibling(name, repo_id=None): if name in _sibling_cache: return _sibling_cache[name] path = _this_dir / f"{name}.py" if not path.exists() and repo_id: path = Path(hf_hub_download(repo_id, f"{name}.py")) spec = importlib.util.spec_from_file_location(name, str(path)) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) _sibling_cache[name] = mod return mod @dataclass class TIPSv2DPTOutput: depth: Optional[torch.Tensor] = None normals: Optional[torch.Tensor] = None segmentation: Optional[torch.Tensor] = None class TIPSv2DPTModel(PreTrainedModel): """TIPSv2 DPT dense prediction model (depth, normals, segmentation). The backbone is loaded automatically from the base TIPSv2 model repo. Usage:: model = AutoModel.from_pretrained("google/tipsv2-l14-dpt", trust_remote_code=True) model.eval().cuda() outputs = model(pixel_values) outputs.depth # (B, 1, H, W) outputs.normals # (B, 3, H, W) outputs.segmentation # (B, 150, H, W) # Individual tasks depth = model.predict_depth(pixel_values) normals = model.predict_normals(pixel_values) seg = model.predict_segmentation(pixel_values) """ config_class = TIPSv2DPTConfig _no_split_modules = [] _supports_cache_class = False _tied_weights_keys = [] @property def all_tied_weights_keys(self): return {} def __init__(self, config: TIPSv2DPTConfig): super().__init__(config) repo_id = getattr(config, "_name_or_path", None) dpt_mod = _load_sibling("dpt_head", repo_id) ppc = tuple(config.post_process_channels) self.depth_head = dpt_mod.DPTDepthHead( input_embed_dim=config.embed_dim, channels=config.channels, post_process_channels=ppc, readout_type=config.readout_type, num_depth_bins=config.num_depth_bins, min_depth=config.min_depth, max_depth=config.max_depth, ) self.normals_head = dpt_mod.DPTNormalsHead( input_embed_dim=config.embed_dim, channels=config.channels, post_process_channels=ppc, readout_type=config.readout_type, ) self.segmentation_head = dpt_mod.DPTSegmentationHead( input_embed_dim=config.embed_dim, channels=config.channels, post_process_channels=ppc, readout_type=config.readout_type, num_classes=config.num_seg_classes, ) self._backbone = None def _get_backbone(self): if self._backbone is None: self._backbone = AutoModel.from_pretrained(self.config.backbone_repo, trust_remote_code=True) self._backbone.to(self.device).eval() return self._backbone.vision_encoder def _extract_intermediate(self, pixel_values): backbone = self._get_backbone() intermediate = backbone.get_intermediate_layers( pixel_values, n=self.config.block_indices, reshape=True, return_class_token=True, norm=True, ) return [(cls_tok, patch_feat) for patch_feat, cls_tok in intermediate] @torch.no_grad() def predict_depth(self, pixel_values: torch.Tensor) -> torch.Tensor: """Predict depth map. Returns (B, 1, H, W).""" pixel_values = pixel_values.to(self.device) h, w = pixel_values.shape[2:] dpt_inputs = self._extract_intermediate(pixel_values) return self.depth_head(dpt_inputs, image_size=(h, w)) @torch.no_grad() def predict_normals(self, pixel_values: torch.Tensor) -> torch.Tensor: """Predict surface normals. Returns (B, 3, H, W).""" pixel_values = pixel_values.to(self.device) h, w = pixel_values.shape[2:] dpt_inputs = self._extract_intermediate(pixel_values) return self.normals_head(dpt_inputs, image_size=(h, w)) @torch.no_grad() def predict_segmentation(self, pixel_values: torch.Tensor) -> torch.Tensor: """Predict semantic segmentation (ADE20K). Returns (B, 150, H, W).""" pixel_values = pixel_values.to(self.device) h, w = pixel_values.shape[2:] dpt_inputs = self._extract_intermediate(pixel_values) return self.segmentation_head(dpt_inputs, image_size=(h, w)) def forward(self, pixel_values: torch.Tensor) -> TIPSv2DPTOutput: """Run all three tasks. Returns TIPSv2DPTOutput.""" pixel_values = pixel_values.to(self.device) h, w = pixel_values.shape[2:] dpt_inputs = self._extract_intermediate(pixel_values) return TIPSv2DPTOutput( depth=self.depth_head(dpt_inputs, image_size=(h, w)), normals=self.normals_head(dpt_inputs, image_size=(h, w)), segmentation=self.segmentation_head(dpt_inputs, image_size=(h, w)), )