# tipsv2-g14-dpt / modeling_dpt.py
# Uploaded by gberton via huggingface_hub (commit 4e09276, verified).
"""TIPSv2 DPT dense prediction model for HuggingFace."""
import importlib
import importlib.util
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, PreTrainedModel

from .configuration_dpt import TIPSv2DPTConfig
_this_dir = Path(__file__).parent
_sibling_cache = {}
def _load_sibling(name, repo_id=None):
    """Load a sibling ``<name>.py`` module, downloading it from the Hub if needed.

    The file is looked up next to this module first; if it is absent and
    *repo_id* is given, it is fetched from that Hub repository instead.
    Loaded modules are memoized in ``_sibling_cache`` so each sibling file
    is executed at most once per process.

    Args:
        name: Module name without the ``.py`` suffix (e.g. ``"dpt_head"``).
        repo_id: Optional Hub repo id used as a download fallback.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for the file.
    """
    if name in _sibling_cache:
        return _sibling_cache[name]
    path = _this_dir / f"{name}.py"
    if not path.exists() and repo_id:
        path = Path(hf_hub_download(repo_id, f"{name}.py"))
    spec = importlib.util.spec_from_file_location(name, str(path))
    # spec_from_file_location returns None (or a loader-less spec) for
    # unloadable paths; fail with a clear error instead of an AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load sibling module {name!r} from {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    _sibling_cache[name] = mod
    return mod
@dataclass
class TIPSv2DPTOutput:
    """Container for the dense-prediction outputs of the TIPSv2 DPT model.

    Each field holds the corresponding head's output tensor, or ``None``
    when that task was not run.
    """

    # Depth map, (B, 1, H, W).
    depth: Optional[torch.Tensor] = None
    # Surface normals, (B, 3, H, W).
    normals: Optional[torch.Tensor] = None
    # ADE20K segmentation logits, (B, 150, H, W).
    segmentation: Optional[torch.Tensor] = None
class TIPSv2DPTModel(PreTrainedModel):
    """TIPSv2 DPT dense prediction model (depth, normals, segmentation).

    The backbone is loaded automatically from the base TIPSv2 model repo.

    Usage::

        model = AutoModel.from_pretrained("google/tipsv2-l14-dpt", trust_remote_code=True)
        model.eval().cuda()
        outputs = model(pixel_values)
        outputs.depth         # (B, 1, H, W)
        outputs.normals       # (B, 3, H, W)
        outputs.segmentation  # (B, 150, H, W)

        # Individual tasks
        depth = model.predict_depth(pixel_values)
        normals = model.predict_normals(pixel_values)
        seg = model.predict_segmentation(pixel_values)
    """

    config_class = TIPSv2DPTConfig
    _no_split_modules = []
    _supports_cache_class = False
    # No weights are tied between modules in this model.
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        """Return the tied-weights mapping (empty: nothing is tied)."""
        return {}

    def __init__(self, config: TIPSv2DPTConfig):
        """Build the three DPT task heads from *config*.

        The ``dpt_head`` module ships alongside this file; if it is not
        present locally it is downloaded from the checkpoint repo
        (``config._name_or_path``). The backbone itself is loaded lazily
        on first use — see :meth:`_get_backbone`.
        """
        super().__init__(config)
        repo_id = getattr(config, "_name_or_path", None)
        dpt_mod = _load_sibling("dpt_head", repo_id)
        ppc = tuple(config.post_process_channels)
        self.depth_head = dpt_mod.DPTDepthHead(
            input_embed_dim=config.embed_dim, channels=config.channels,
            post_process_channels=ppc, readout_type=config.readout_type,
            num_depth_bins=config.num_depth_bins,
            min_depth=config.min_depth, max_depth=config.max_depth,
        )
        self.normals_head = dpt_mod.DPTNormalsHead(
            input_embed_dim=config.embed_dim, channels=config.channels,
            post_process_channels=ppc, readout_type=config.readout_type,
        )
        self.segmentation_head = dpt_mod.DPTSegmentationHead(
            input_embed_dim=config.embed_dim, channels=config.channels,
            post_process_channels=ppc, readout_type=config.readout_type,
            num_classes=config.num_seg_classes,
        )
        # Lazily-initialized frozen backbone (see _get_backbone).
        self._backbone = None

    def _get_backbone(self):
        """Load the backbone on first use and return its vision encoder.

        The backbone is fetched from ``config.backbone_repo``, moved to this
        model's device, and put in eval mode. Subsequent calls reuse the
        cached instance.
        """
        if self._backbone is None:
            self._backbone = AutoModel.from_pretrained(
                self.config.backbone_repo, trust_remote_code=True
            )
            self._backbone.to(self.device).eval()
        return self._backbone.vision_encoder

    def _extract_intermediate(self, pixel_values):
        """Extract intermediate backbone features for the DPT heads.

        The backbone yields ``(patch_features, class_token)`` pairs; the
        heads consume them in the opposite ``(class_token, patch_features)``
        order, so each pair is swapped here.
        """
        backbone = self._get_backbone()
        intermediate = backbone.get_intermediate_layers(
            pixel_values, n=self.config.block_indices,
            reshape=True, return_class_token=True, norm=True,
        )
        return [(cls_tok, patch_feat) for patch_feat, cls_tok in intermediate]

    def _prepare_inputs(self, pixel_values):
        """Shared preamble for all inference entry points.

        Moves *pixel_values* to the model device and extracts backbone
        features once. Assumes a 4-D ``(B, C, H, W)`` input.

        Returns:
            Tuple ``(dpt_inputs, (h, w))`` where ``(h, w)`` is the input
            spatial size used to upsample the head outputs.
        """
        pixel_values = pixel_values.to(self.device)
        h, w = pixel_values.shape[2:]
        return self._extract_intermediate(pixel_values), (h, w)

    @torch.no_grad()
    def predict_depth(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict depth map. Returns (B, 1, H, W)."""
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return self.depth_head(dpt_inputs, image_size=size)

    @torch.no_grad()
    def predict_normals(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict surface normals. Returns (B, 3, H, W)."""
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return self.normals_head(dpt_inputs, image_size=size)

    @torch.no_grad()
    def predict_segmentation(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict semantic segmentation (ADE20K). Returns (B, 150, H, W)."""
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return self.segmentation_head(dpt_inputs, image_size=size)

    def forward(self, pixel_values: torch.Tensor) -> TIPSv2DPTOutput:
        """Run all three tasks on one feature extraction pass.

        Args:
            pixel_values: Input images, (B, C, H, W).

        Returns:
            TIPSv2DPTOutput with ``depth``, ``normals`` and ``segmentation``.
        """
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return TIPSv2DPTOutput(
            depth=self.depth_head(dpt_inputs, image_size=size),
            normals=self.normals_head(dpt_inputs, image_size=size),
            segmentation=self.segmentation_head(dpt_inputs, image_size=size),
        )