File size: 5,250 Bytes
d4f10cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""TIPSv2 DPT dense prediction model for HuggingFace."""

import importlib
import importlib.util
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, PreTrainedModel

from .configuration_dpt import TIPSv2DPTConfig

_this_dir = Path(__file__).parent
# Memoizes dynamically loaded sibling modules by name so each file is
# executed at most once per process.
_sibling_cache = {}


def _load_sibling(name, repo_id=None):
    """Load a sibling ``<name>.py`` module, fetching it from the hub if needed.

    Looks for the file next to this module first; when it is absent (e.g. this
    file is running from a partial hub snapshot) and *repo_id* is given, the
    file is downloaded with ``hf_hub_download``. Loaded modules are cached in
    ``_sibling_cache``.

    Args:
        name: Module name without the ``.py`` suffix.
        repo_id: Optional hub repo to download the file from as a fallback.

    Returns:
        The executed module object.

    Raises:
        ImportError: if the file cannot be located or an import spec cannot
            be built for it.
    """
    import importlib.util  # local import: `import importlib` alone does not guarantee the submodule

    cached = _sibling_cache.get(name)
    if cached is not None:
        return cached
    path = _this_dir / f"{name}.py"
    if not path.exists():
        if not repo_id:
            # Fail early with a clear message instead of a confusing
            # FileNotFoundError from exec_module below.
            raise ImportError(f"Sibling module {name!r} not found at {path}")
        path = Path(hf_hub_download(repo_id, f"{name}.py"))
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not build an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    _sibling_cache[name] = mod
    return mod


@dataclass
class TIPSv2DPTOutput:
    """Container for the dense-prediction outputs of the TIPSv2 DPT model.

    Each field defaults to ``None`` so partial results can be represented.
    """

    # (B, 1, H, W) depth map.
    depth: Optional[torch.Tensor] = None
    # (B, 3, H, W) surface normals.
    normals: Optional[torch.Tensor] = None
    # (B, 150, H, W) ADE20K segmentation logits.
    segmentation: Optional[torch.Tensor] = None

class TIPSv2DPTModel(PreTrainedModel):
    """TIPSv2 DPT dense prediction model (depth, normals, segmentation).

    The backbone is loaded automatically from the base TIPSv2 model repo.

    Usage::

        model = AutoModel.from_pretrained("google/tipsv2-l14-dpt", trust_remote_code=True)
        model.eval().cuda()

        outputs = model(pixel_values)
        outputs.depth          # (B, 1, H, W)
        outputs.normals        # (B, 3, H, W)
        outputs.segmentation   # (B, 150, H, W)

        # Individual tasks
        depth = model.predict_depth(pixel_values)
        normals = model.predict_normals(pixel_values)
        seg = model.predict_segmentation(pixel_values)
    """

    config_class = TIPSv2DPTConfig
    _no_split_modules = []
    _supports_cache_class = False
    # No weights are tied: the three heads are independent and the backbone
    # lives in a separately loaded model.
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # Some transformers versions consult this mapping during save/load;
        # return an empty dict since nothing is tied.
        return {}

    def __init__(self, config: TIPSv2DPTConfig):
        """Build the three DPT heads from *config*; the backbone is lazy.

        Args:
            config: Model configuration (head dims, depth bins, class count,
                backbone repo id, ...).
        """
        super().__init__(config)

        # The DPT head implementation ships as a sibling dpt_head.py; when
        # running from a hub snapshot it is downloaded from the same repo.
        repo_id = getattr(config, "_name_or_path", None)
        dpt_mod = _load_sibling("dpt_head", repo_id)

        ppc = tuple(config.post_process_channels)
        # Keyword arguments shared by all three heads.
        common = dict(
            input_embed_dim=config.embed_dim, channels=config.channels,
            post_process_channels=ppc, readout_type=config.readout_type,
        )
        self.depth_head = dpt_mod.DPTDepthHead(
            num_depth_bins=config.num_depth_bins,
            min_depth=config.min_depth, max_depth=config.max_depth,
            **common,
        )
        self.normals_head = dpt_mod.DPTNormalsHead(**common)
        self.segmentation_head = dpt_mod.DPTSegmentationHead(
            num_classes=config.num_seg_classes, **common,
        )
        # Loaded on first use (see _get_backbone) so that constructing this
        # model does not immediately download the backbone repo.
        self._backbone = None

    def _get_backbone(self):
        """Lazily load the TIPSv2 backbone and return its vision encoder."""
        if self._backbone is None:
            self._backbone = AutoModel.from_pretrained(self.config.backbone_repo, trust_remote_code=True)
            self._backbone.to(self.device).eval()
        return self._backbone.vision_encoder

    def _extract_intermediate(self, pixel_values):
        """Return (cls_token, patch_features) pairs from the backbone.

        The backbone yields ``(patch_feat, cls_tok)`` tuples; the DPT heads
        expect the opposite order, hence the swap in the final list.
        """
        backbone = self._get_backbone()
        intermediate = backbone.get_intermediate_layers(
            pixel_values, n=self.config.block_indices,
            reshape=True, return_class_token=True, norm=True,
        )
        return [(cls_tok, patch_feat) for patch_feat, cls_tok in intermediate]

    def _prepare_inputs(self, pixel_values):
        """Move pixels to the model device and run the shared backbone pass.

        Returns:
            Tuple ``(dpt_inputs, (height, width))`` where ``dpt_inputs`` is
            the list produced by :meth:`_extract_intermediate`.
        """
        pixel_values = pixel_values.to(self.device)
        h, w = pixel_values.shape[2:]
        return self._extract_intermediate(pixel_values), (h, w)

    @torch.no_grad()
    def predict_depth(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict depth map. Returns (B, 1, H, W)."""
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return self.depth_head(dpt_inputs, image_size=size)

    @torch.no_grad()
    def predict_normals(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict surface normals. Returns (B, 3, H, W)."""
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return self.normals_head(dpt_inputs, image_size=size)

    @torch.no_grad()
    def predict_segmentation(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict semantic segmentation (ADE20K). Returns (B, 150, H, W)."""
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return self.segmentation_head(dpt_inputs, image_size=size)

    def forward(self, pixel_values: torch.Tensor) -> TIPSv2DPTOutput:
        """Run all three tasks on one shared backbone pass.

        Note: unlike the ``predict_*`` helpers this is deliberately not
        wrapped in ``torch.no_grad`` so the heads remain trainable.
        """
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return TIPSv2DPTOutput(
            depth=self.depth_head(dpt_inputs, image_size=size),
            normals=self.normals_head(dpt_inputs, image_size=size),
            segmentation=self.segmentation_head(dpt_inputs, image_size=size),
        )