gberton committed on
Commit
4e09276
·
verified ·
1 Parent(s): 10532f4

Upload modeling_dpt.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dpt.py +141 -0
modeling_dpt.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TIPSv2 DPT dense prediction model for HuggingFace."""
2
+
3
import importlib
import importlib.util
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, PreTrainedModel

from .configuration_dpt import TIPSv2DPTConfig
14
+
15
+ _this_dir = Path(__file__).parent
16
+ _sibling_cache = {}
17
+
18
+
19
def _load_sibling(name, repo_id=None):
    """Import a sibling ``<name>.py`` module shipped alongside this file.

    The file is looked up next to this module first; if it is absent and
    *repo_id* is given, it is downloaded from the Hub. Loaded modules are
    cached in ``_sibling_cache`` so each file is executed at most once per
    process.

    Args:
        name: Module name without the ``.py`` suffix (e.g. ``"dpt_head"``).
        repo_id: Optional Hub repo id used as a download fallback.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If the file exists neither locally nor on the Hub.
        ImportError: If an import spec cannot be built for the file.
    """
    if name in _sibling_cache:
        return _sibling_cache[name]
    path = _this_dir / f"{name}.py"
    if not path.exists() and repo_id:
        path = Path(hf_hub_download(repo_id, f"{name}.py"))
    # Fail early with a clear message instead of an obscure error from the
    # import machinery when the file is missing.
    if not path.exists():
        raise FileNotFoundError(f"Cannot locate sibling module file: {path}")
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    _sibling_cache[name] = mod
    return mod
30
+
31
+
32
@dataclass
class TIPSv2DPTOutput:
    """Container for the dense-prediction outputs of ``TIPSv2DPTModel``.

    Every field defaults to ``None`` so a caller may populate only the
    tasks that were actually run.
    """

    # Predicted depth map tensor (None when depth was not computed).
    depth: Optional[torch.Tensor] = None
    # Predicted surface-normals tensor (None when normals were not computed).
    normals: Optional[torch.Tensor] = None
    # Per-class segmentation logits tensor (None when segmentation was not computed).
    segmentation: Optional[torch.Tensor] = None
37
+
38
+
39
class TIPSv2DPTModel(PreTrainedModel):
    """TIPSv2 DPT dense prediction model (depth, normals, segmentation).

    The backbone is loaded automatically from the base TIPSv2 model repo.

    Usage::

        model = AutoModel.from_pretrained("google/tipsv2-l14-dpt", trust_remote_code=True)
        model.eval().cuda()

        outputs = model(pixel_values)
        outputs.depth  # (B, 1, H, W)
        outputs.normals  # (B, 3, H, W)
        outputs.segmentation  # (B, 150, H, W)

        # Individual tasks
        depth = model.predict_depth(pixel_values)
        normals = model.predict_normals(pixel_values)
        seg = model.predict_segmentation(pixel_values)
    """

    # HF PreTrainedModel plumbing: config class, and empty tying/splitting
    # declarations since this model declares no tied weights.
    config_class = TIPSv2DPTConfig
    _no_split_modules = []
    _supports_cache_class = False
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # No weights are tied; expose an empty mapping. NOTE(review):
        # presumably required by newer transformers weight-tying accessors —
        # confirm against the targeted transformers version.
        return {}

    def __init__(self, config: TIPSv2DPTConfig):
        """Build the three DPT heads from *config*; the backbone is NOT
        loaded here — it is fetched lazily on first use (see _get_backbone)."""
        super().__init__(config)

        # _name_or_path lets _load_sibling fall back to downloading
        # dpt_head.py from the same Hub repo this model was loaded from.
        repo_id = getattr(config, "_name_or_path", None)
        dpt_mod = _load_sibling("dpt_head", repo_id)

        ppc = tuple(config.post_process_channels)

        self.depth_head = dpt_mod.DPTDepthHead(
            input_embed_dim=config.embed_dim, channels=config.channels,
            post_process_channels=ppc, readout_type=config.readout_type,
            num_depth_bins=config.num_depth_bins,
            min_depth=config.min_depth, max_depth=config.max_depth,
        )
        self.normals_head = dpt_mod.DPTNormalsHead(
            input_embed_dim=config.embed_dim, channels=config.channels,
            post_process_channels=ppc, readout_type=config.readout_type,
        )
        self.segmentation_head = dpt_mod.DPTSegmentationHead(
            input_embed_dim=config.embed_dim, channels=config.channels,
            post_process_channels=ppc, readout_type=config.readout_type,
            num_classes=config.num_seg_classes,
        )
        # Placeholder for the lazily-loaded backbone model.
        self._backbone = None

    def _get_backbone(self):
        """Load the backbone once (remote code allowed), move it to this
        model's device in eval mode, and return its vision encoder."""
        if self._backbone is None:
            self._backbone = AutoModel.from_pretrained(self.config.backbone_repo, trust_remote_code=True)
            self._backbone.to(self.device).eval()
        return self._backbone.vision_encoder

    def _extract_intermediate(self, pixel_values):
        """Extract intermediate backbone features for the DPT heads.

        Returns a list of (class_token, patch_features) pairs, one per
        configured block index.
        """
        backbone = self._get_backbone()
        intermediate = backbone.get_intermediate_layers(
            pixel_values, n=self.config.block_indices,
            reshape=True, return_class_token=True, norm=True,
        )
        # get_intermediate_layers yields (patch_features, class_token);
        # the DPT heads expect the reverse order, so swap each pair.
        return [(cls_tok, patch_feat) for patch_feat, cls_tok in intermediate]

    @torch.no_grad()
    def predict_depth(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict depth map. Returns (B, 1, H, W)."""
        pixel_values = pixel_values.to(self.device)
        h, w = pixel_values.shape[2:]
        dpt_inputs = self._extract_intermediate(pixel_values)
        # The head upsamples its prediction back to the input resolution.
        return self.depth_head(dpt_inputs, image_size=(h, w))

    @torch.no_grad()
    def predict_normals(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict surface normals. Returns (B, 3, H, W)."""
        pixel_values = pixel_values.to(self.device)
        h, w = pixel_values.shape[2:]
        dpt_inputs = self._extract_intermediate(pixel_values)
        return self.normals_head(dpt_inputs, image_size=(h, w))

    @torch.no_grad()
    def predict_segmentation(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict semantic segmentation (ADE20K). Returns (B, 150, H, W)."""
        pixel_values = pixel_values.to(self.device)
        h, w = pixel_values.shape[2:]
        dpt_inputs = self._extract_intermediate(pixel_values)
        return self.segmentation_head(dpt_inputs, image_size=(h, w))

    def forward(self, pixel_values: torch.Tensor) -> TIPSv2DPTOutput:
        """Run all three tasks. Returns TIPSv2DPTOutput.

        Backbone features are extracted once and shared by all three heads.
        Unlike the predict_* helpers, this method is NOT wrapped in
        torch.no_grad(), so gradients flow unless the caller disables them.
        """
        pixel_values = pixel_values.to(self.device)
        h, w = pixel_values.shape[2:]
        dpt_inputs = self._extract_intermediate(pixel_values)
        return TIPSv2DPTOutput(
            depth=self.depth_head(dpt_inputs, image_size=(h, w)),
            normals=self.normals_head(dpt_inputs, image_size=(h, w)),
            segmentation=self.segmentation_head(dpt_inputs, image_size=(h, w)),
        )