plutosss committed on
Commit
c74335c
·
verified ·
1 Parent(s): 8810ea6

Update depthAnything/depth_anything/dpt.py

Browse files
Files changed (1) hide show
  1. depthAnything/depth_anything/dpt.py +188 -186
depthAnything/depth_anything/dpt.py CHANGED
@@ -1,187 +1,189 @@
1
- import argparse
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
6
-
7
- from depth_anything.blocks import FeatureFusionBlock, _make_scratch
8
-
9
-
10
- def _make_fusion_block(features, use_bn, size = None):
11
- return FeatureFusionBlock(
12
- features,
13
- nn.ReLU(False),
14
- deconv=False,
15
- bn=use_bn,
16
- expand=False,
17
- align_corners=True,
18
- size=size,
19
- )
20
-
21
-
22
- class DPTHead(nn.Module):
23
- def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
24
- super(DPTHead, self).__init__()
25
-
26
- self.nclass = nclass
27
- self.use_clstoken = use_clstoken
28
-
29
- self.projects = nn.ModuleList([
30
- nn.Conv2d(
31
- in_channels=in_channels,
32
- out_channels=out_channel,
33
- kernel_size=1,
34
- stride=1,
35
- padding=0,
36
- ) for out_channel in out_channels
37
- ])
38
-
39
- self.resize_layers = nn.ModuleList([
40
- nn.ConvTranspose2d(
41
- in_channels=out_channels[0],
42
- out_channels=out_channels[0],
43
- kernel_size=4,
44
- stride=4,
45
- padding=0),
46
- nn.ConvTranspose2d(
47
- in_channels=out_channels[1],
48
- out_channels=out_channels[1],
49
- kernel_size=2,
50
- stride=2,
51
- padding=0),
52
- nn.Identity(),
53
- nn.Conv2d(
54
- in_channels=out_channels[3],
55
- out_channels=out_channels[3],
56
- kernel_size=3,
57
- stride=2,
58
- padding=1)
59
- ])
60
-
61
- if use_clstoken:
62
- self.readout_projects = nn.ModuleList()
63
- for _ in range(len(self.projects)):
64
- self.readout_projects.append(
65
- nn.Sequential(
66
- nn.Linear(2 * in_channels, in_channels),
67
- nn.GELU()))
68
-
69
- self.scratch = _make_scratch(
70
- out_channels,
71
- features,
72
- groups=1,
73
- expand=False,
74
- )
75
-
76
- self.scratch.stem_transpose = None
77
-
78
- self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
79
- self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
80
- self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
81
- self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
82
-
83
- head_features_1 = features
84
- head_features_2 = 32
85
-
86
- if nclass > 1:
87
- self.scratch.output_conv = nn.Sequential(
88
- nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
89
- nn.ReLU(True),
90
- nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
91
- )
92
- else:
93
- self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
94
-
95
- self.scratch.output_conv2 = nn.Sequential(
96
- nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
97
- nn.ReLU(True),
98
- nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
99
- nn.ReLU(True),
100
- nn.Identity(),
101
- )
102
-
103
- def forward(self, out_features, patch_h, patch_w):
104
- out = []
105
- for i, x in enumerate(out_features):
106
- if self.use_clstoken:
107
- x, cls_token = x[0], x[1]
108
- readout = cls_token.unsqueeze(1).expand_as(x)
109
- x = self.readout_projects[i](torch.cat((x, readout), -1))
110
- else:
111
- x = x[0]
112
-
113
- x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
114
-
115
- x = self.projects[i](x)
116
- x = self.resize_layers[i](x)
117
-
118
- out.append(x)
119
-
120
- layer_1, layer_2, layer_3, layer_4 = out
121
-
122
- layer_1_rn = self.scratch.layer1_rn(layer_1)
123
- layer_2_rn = self.scratch.layer2_rn(layer_2)
124
- layer_3_rn = self.scratch.layer3_rn(layer_3)
125
- layer_4_rn = self.scratch.layer4_rn(layer_4)
126
-
127
- path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
128
- path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
129
- path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
130
- path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
131
-
132
- out = self.scratch.output_conv1(path_1)
133
- out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
134
- out = self.scratch.output_conv2(out)
135
-
136
- return out
137
-
138
-
139
- class DPT_DINOv2(nn.Module):
140
- def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, localhub=True):
141
- super(DPT_DINOv2, self).__init__()
142
-
143
- assert encoder in ['vits', 'vitb', 'vitl']
144
-
145
- # in case the Internet connection is not stable, please load the DINOv2 locally
146
- if localhub:
147
- self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
148
- else:
149
- self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
150
-
151
- dim = self.pretrained.blocks[0].attn.qkv.in_features
152
-
153
- self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
154
-
155
- def forward(self, x):
156
- h, w = x.shape[-2:]
157
-
158
- features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
159
-
160
- patch_h, patch_w = h // 14, w // 14
161
-
162
- depth = self.depth_head(features, patch_h, patch_w)
163
- depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
164
- depth = F.relu(depth)
165
-
166
- return depth.squeeze(1)
167
-
168
-
169
- class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
170
- def __init__(self, config):
171
- super().__init__(**config)
172
-
173
-
174
- if __name__ == '__main__':
175
- parser = argparse.ArgumentParser()
176
- parser.add_argument(
177
- "--encoder",
178
- default="vits",
179
- type=str,
180
- choices=["vits", "vitb", "vitl"],
181
- )
182
- args = parser.parse_args()
183
-
184
- model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))
185
-
186
- print(model)
 
 
187
 
 
1
+ import argparse
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
6
+
7
+ #from depth_anything.blocks import FeatureFusionBlock, _make_scratch
8
+ from .blocks import FeatureFusionBlock, _make_scratch
9
+
10
+
11
+
12
def _make_fusion_block(features, use_bn, size=None):
    """Build a FeatureFusionBlock configured for the DPT decoder.

    No deconvolution or channel expansion; resizing inside the block uses
    align_corners=True. ``size`` optionally pins the block's output
    resolution instead of the default 2x upsampling.
    """
    fusion_kwargs = dict(
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )
    return FeatureFusionBlock(features, nn.ReLU(False), **fusion_kwargs)
22
+
23
+
24
class DPTHead(nn.Module):
    """DPT decoder head: fuses four transformer feature maps into a dense map.

    Args:
        nclass: number of output channels; 1 selects the depth branch,
            > 1 selects a segmentation-style output conv.
        in_channels: channel width of each incoming transformer feature.
        features: channel width used inside the scratch/fusion layers.
        use_bn: whether the fusion blocks use batch normalization.
        out_channels: per-level widths after the 1x1 projections.
            ``None`` (the default) means ``[256, 512, 1024, 1024]``.
        use_clstoken: if True, each input feature is a (patch_tokens,
            cls_token) pair and the cls token is folded back into the
            patch tokens via a learned readout projection.
    """

    def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=None, use_clstoken=False):
        super(DPTHead, self).__init__()

        # Use None as the default to avoid a shared mutable default argument;
        # the effective default is unchanged.
        if out_channels is None:
            out_channels = [256, 512, 1024, 1024]

        self.nclass = nclass
        self.use_clstoken = use_clstoken

        # 1x1 convs projecting the transformer width to per-level widths.
        self.projects = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            ) for out_channel in out_channels
        ])

        # Turn the four same-resolution maps into a pyramid:
        # 4x up, 2x up, identity, 2x down.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])

        if use_clstoken:
            # One readout per level: concat(patch, cls) -> patch width.
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        nn.Linear(2 * in_channels, in_channels),
                        nn.GELU()))

        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )

        self.scratch.stem_transpose = None

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        head_features_1 = features
        head_features_2 = 32

        if nclass > 1:
            self.scratch.output_conv = nn.Sequential(
                nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
            )
        else:
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)

            # Final ReLU clamps depth to be non-negative; the trailing
            # Identity is a no-op kept so the module layout matches the
            # reference architecture (and any serialized checkpoints).
            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
                nn.ReLU(True),
                nn.Identity(),
            )

    def forward(self, out_features, patch_h, patch_w):
        """Fuse the four transformer features into a single dense map.

        Args:
            out_features: sequence of four features; each item is either
                patch tokens alone or a (patch_tokens, cls_token) pair
                when ``use_clstoken`` is set.
            patch_h, patch_w: patch-grid height/width of the input image.

        Returns:
            Tensor of shape (N, 1, patch_h * 14, patch_w * 14) for the
            depth branch (nclass == 1).
        """
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                # Broadcast the cls token over all patches, concatenate on
                # the channel dim, and project back to the patch width.
                x, cls_token = x[0], x[1]
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
            else:
                x = x[0]

            # (N, HW, C) tokens -> (N, C, patch_h, patch_w) feature map.
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[i](x)
            x = self.resize_layers[i](x)

            out.append(x)

        layer_1, layer_2, layer_3, layer_4 = out

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down fusion: each refinenet merges the coarser path into the
        # next finer level, matching spatial sizes explicitly.
        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv1(path_1)
        # Upsample to the full 14x patch resolution (ViT patch size 14).
        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
        out = self.scratch.output_conv2(out)

        return out
139
+
140
+
141
class DPT_DINOv2(nn.Module):
    """Depth Anything model: a DINOv2 ViT encoder with a DPT depth head.

    Args:
        encoder: DINOv2 variant, one of 'vits', 'vitb', 'vitl'.
        features: channel width inside the DPT head's fusion layers.
        out_channels: per-level projection widths for the head.
            ``None`` (the default) means ``[256, 512, 1024, 1024]``.
        use_bn: whether the head's fusion blocks use batch norm.
        use_clstoken: whether the head folds class tokens into patches.
        localhub: load DINOv2 from a local torch.hub checkout instead of
            fetching from GitHub.

    Raises:
        ValueError: if ``encoder`` is not a supported variant.
    """

    def __init__(self, encoder='vitl', features=256, out_channels=None, use_bn=False, use_clstoken=False, localhub=True):
        super(DPT_DINOv2, self).__init__()

        # Use None as the default to avoid a shared mutable default argument;
        # the effective default is unchanged.
        if out_channels is None:
            out_channels = [256, 512, 1024, 1024]

        # Validate before any hub I/O; raise (not assert) so the check
        # survives python -O.
        if encoder not in ('vits', 'vitb', 'vitl'):
            raise ValueError("encoder must be one of 'vits', 'vitb', 'vitl', got {!r}".format(encoder))

        # in case the Internet connection is not stable, please load the DINOv2 locally
        if localhub:
            self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
        else:
            self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))

        # Transformer embedding width, read off the first attention block.
        dim = self.pretrained.blocks[0].attn.qkv.in_features

        self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)

    def forward(self, x):
        """Predict a depth map for an image batch.

        Args:
            x: image tensor whose last two dims are (h, w); both are
                expected to be multiples of the ViT patch size 14.

        Returns:
            Non-negative depth tensor of shape (N, h, w).
        """
        h, w = x.shape[-2:]

        # Last four intermediate transformer layers, each with its cls token.
        features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)

        patch_h, patch_w = h // 14, w // 14  # ViT-14 patch grid

        depth = self.depth_head(features, patch_h, patch_w)
        depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
        depth = F.relu(depth)

        return depth.squeeze(1)
169
+
170
+
171
class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
    # Hub-loadable variant of DPT_DINOv2: PyTorchModelHubMixin supplies
    # from_pretrained / save_pretrained; `config` is the kwargs dict
    # forwarded verbatim to DPT_DINOv2.__init__.
    def __init__(self, config):
        super().__init__(**config)
174
+
175
+
176
if __name__ == '__main__':
    # Smoke test: pull the pretrained checkpoint for the chosen encoder
    # from the Hub and print the module tree.
    parser = argparse.ArgumentParser()
    parser.add_argument("--encoder", default="vits", type=str,
                        choices=["vits", "vitb", "vitl"])
    args = parser.parse_args()

    repo_id = "LiheYoung/depth_anything_{:}14".format(args.encoder)
    model = DepthAnything.from_pretrained(repo_id)
    print(model)
189