plutosss committed
Commit c68cd5d · verified · 1 parent: 3d1b365

Update depthAnything/depth_anything/dpt.py

Files changed (1)
  1. depthAnything/depth_anything/dpt.py +187 -186
depthAnything/depth_anything/dpt.py CHANGED
@@ -1,187 +1,188 @@
- import argparse
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
-
- from depth_anything.blocks import FeatureFusionBlock, _make_scratch
-
-
- def _make_fusion_block(features, use_bn, size = None):
-     return FeatureFusionBlock(
-         features,
-         nn.ReLU(False),
-         deconv=False,
-         bn=use_bn,
-         expand=False,
-         align_corners=True,
-         size=size,
-     )
-
-
- class DPTHead(nn.Module):
-     def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
-         super(DPTHead, self).__init__()
-
-         self.nclass = nclass
-         self.use_clstoken = use_clstoken
-
-         self.projects = nn.ModuleList([
-             nn.Conv2d(
-                 in_channels=in_channels,
-                 out_channels=out_channel,
-                 kernel_size=1,
-                 stride=1,
-                 padding=0,
-             ) for out_channel in out_channels
-         ])
-
-         self.resize_layers = nn.ModuleList([
-             nn.ConvTranspose2d(
-                 in_channels=out_channels[0],
-                 out_channels=out_channels[0],
-                 kernel_size=4,
-                 stride=4,
-                 padding=0),
-             nn.ConvTranspose2d(
-                 in_channels=out_channels[1],
-                 out_channels=out_channels[1],
-                 kernel_size=2,
-                 stride=2,
-                 padding=0),
-             nn.Identity(),
-             nn.Conv2d(
-                 in_channels=out_channels[3],
-                 out_channels=out_channels[3],
-                 kernel_size=3,
-                 stride=2,
-                 padding=1)
-         ])
-
-         if use_clstoken:
-             self.readout_projects = nn.ModuleList()
-             for _ in range(len(self.projects)):
-                 self.readout_projects.append(
-                     nn.Sequential(
-                         nn.Linear(2 * in_channels, in_channels),
-                         nn.GELU()))
-
-         self.scratch = _make_scratch(
-             out_channels,
-             features,
-             groups=1,
-             expand=False,
-         )
-
-         self.scratch.stem_transpose = None
-
-         self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
-         self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
-         self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
-         self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
-
-         head_features_1 = features
-         head_features_2 = 32
-
-         if nclass > 1:
-             self.scratch.output_conv = nn.Sequential(
-                 nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
-                 nn.ReLU(True),
-                 nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
-             )
-         else:
-             self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
-
-             self.scratch.output_conv2 = nn.Sequential(
-                 nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
-                 nn.ReLU(True),
-                 nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
-                 nn.ReLU(True),
-                 nn.Identity(),
-             )
-
-     def forward(self, out_features, patch_h, patch_w):
-         out = []
-         for i, x in enumerate(out_features):
-             if self.use_clstoken:
-                 x, cls_token = x[0], x[1]
-                 readout = cls_token.unsqueeze(1).expand_as(x)
-                 x = self.readout_projects[i](torch.cat((x, readout), -1))
-             else:
-                 x = x[0]
-
-             x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
-
-             x = self.projects[i](x)
-             x = self.resize_layers[i](x)
-
-             out.append(x)
-
-         layer_1, layer_2, layer_3, layer_4 = out
-
-         layer_1_rn = self.scratch.layer1_rn(layer_1)
-         layer_2_rn = self.scratch.layer2_rn(layer_2)
-         layer_3_rn = self.scratch.layer3_rn(layer_3)
-         layer_4_rn = self.scratch.layer4_rn(layer_4)
-
-         path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
-         path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
-         path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
-         path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
-
-         out = self.scratch.output_conv1(path_1)
-         out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
-         out = self.scratch.output_conv2(out)
-
-         return out
-
-
- class DPT_DINOv2(nn.Module):
-     def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, localhub=True):
-         super(DPT_DINOv2, self).__init__()
-
-         assert encoder in ['vits', 'vitb', 'vitl']
-
-         # in case the Internet connection is not stable, please load the DINOv2 locally
-         if localhub:
-             self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
-         else:
-             self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
-
-         dim = self.pretrained.blocks[0].attn.qkv.in_features
-
-         self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
-
-     def forward(self, x):
-         h, w = x.shape[-2:]
-
-         features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
-
-         patch_h, patch_w = h // 14, w // 14
-
-         depth = self.depth_head(features, patch_h, patch_w)
-         depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
-         depth = F.relu(depth)
-
-         return depth.squeeze(1)
-
-
- class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
-     def __init__(self, config):
-         super().__init__(**config)
-
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--encoder",
-         default="vits",
-         type=str,
-         choices=["vits", "vitb", "vitl"],
-     )
-     args = parser.parse_args()
-
-     model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))
-
-     print(model)
+ import argparse
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
+
+ from depth_anything.blocks import FeatureFusionBlock, _make_scratch
+ from .blocks import FeatureFusionBlock, _make_scratch
+
+
+ def _make_fusion_block(features, use_bn, size = None):
+     return FeatureFusionBlock(
+         features,
+         nn.ReLU(False),
+         deconv=False,
+         bn=use_bn,
+         expand=False,
+         align_corners=True,
+         size=size,
+     )
+
+
+ class DPTHead(nn.Module):
+     def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
+         super(DPTHead, self).__init__()
+
+         self.nclass = nclass
+         self.use_clstoken = use_clstoken
+
+         self.projects = nn.ModuleList([
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channel,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             ) for out_channel in out_channels
+         ])
+
+         self.resize_layers = nn.ModuleList([
+             nn.ConvTranspose2d(
+                 in_channels=out_channels[0],
+                 out_channels=out_channels[0],
+                 kernel_size=4,
+                 stride=4,
+                 padding=0),
+             nn.ConvTranspose2d(
+                 in_channels=out_channels[1],
+                 out_channels=out_channels[1],
+                 kernel_size=2,
+                 stride=2,
+                 padding=0),
+             nn.Identity(),
+             nn.Conv2d(
+                 in_channels=out_channels[3],
+                 out_channels=out_channels[3],
+                 kernel_size=3,
+                 stride=2,
+                 padding=1)
+         ])
+
+         if use_clstoken:
+             self.readout_projects = nn.ModuleList()
+             for _ in range(len(self.projects)):
+                 self.readout_projects.append(
+                     nn.Sequential(
+                         nn.Linear(2 * in_channels, in_channels),
+                         nn.GELU()))
+
+         self.scratch = _make_scratch(
+             out_channels,
+             features,
+             groups=1,
+             expand=False,
+         )
+
+         self.scratch.stem_transpose = None
+
+         self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+         self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+         self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+         self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+         head_features_1 = features
+         head_features_2 = 32
+
+         if nclass > 1:
+             self.scratch.output_conv = nn.Sequential(
+                 nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
+                 nn.ReLU(True),
+                 nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
+             )
+         else:
+             self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
+
+             self.scratch.output_conv2 = nn.Sequential(
+                 nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+                 nn.ReLU(True),
+                 nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+                 nn.ReLU(True),
+                 nn.Identity(),
+             )
+
+     def forward(self, out_features, patch_h, patch_w):
+         out = []
+         for i, x in enumerate(out_features):
+             if self.use_clstoken:
+                 x, cls_token = x[0], x[1]
+                 readout = cls_token.unsqueeze(1).expand_as(x)
+                 x = self.readout_projects[i](torch.cat((x, readout), -1))
+             else:
+                 x = x[0]
+
+             x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+             x = self.projects[i](x)
+             x = self.resize_layers[i](x)
+
+             out.append(x)
+
+         layer_1, layer_2, layer_3, layer_4 = out
+
+         layer_1_rn = self.scratch.layer1_rn(layer_1)
+         layer_2_rn = self.scratch.layer2_rn(layer_2)
+         layer_3_rn = self.scratch.layer3_rn(layer_3)
+         layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+         path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+         path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+         path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+         path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+         out = self.scratch.output_conv1(path_1)
+         out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+         out = self.scratch.output_conv2(out)
+
+         return out
+
+
+ class DPT_DINOv2(nn.Module):
+     def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, localhub=True):
+         super(DPT_DINOv2, self).__init__()
+
+         assert encoder in ['vits', 'vitb', 'vitl']
+
+         # in case the Internet connection is not stable, please load the DINOv2 locally
+         if localhub:
+             self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
+         else:
+             self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
+
+         dim = self.pretrained.blocks[0].attn.qkv.in_features
+
+         self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
+
+     def forward(self, x):
+         h, w = x.shape[-2:]
+
+         features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
+
+         patch_h, patch_w = h // 14, w // 14
+
+         depth = self.depth_head(features, patch_h, patch_w)
+         depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
+         depth = F.relu(depth)
+
+         return depth.squeeze(1)
+
+
+ class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
+     def __init__(self, config):
+         super().__init__(**config)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--encoder",
+         default="vits",
+         type=str,
+         choices=["vits", "vitb", "vitl"],
+     )
+     args = parser.parse_args()
+
+     model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))
+
+     print(model)
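
For reference, the token-to-map step in DPTHead.forward above can be exercised in isolation. A minimal sketch, where the batch size, patch grid, and embedding width are hypothetical toy values chosen only for illustration:

import torch

# Hypothetical toy sizes: batch 1, 37x37 patch grid, embedding width 384.
B, patch_h, patch_w, C = 1, 37, 37, 384
tokens = torch.randn(B, patch_h * patch_w, C)  # (B, N, C) patch tokens from the ViT

# The same operation DPTHead.forward applies: (B, N, C) -> (B, C, patch_h, patch_w)
fmap = tokens.permute(0, 2, 1).reshape(B, C, patch_h, patch_w)
print(fmap.shape)  # torch.Size([1, 384, 37, 37])

Each of the four reassembled maps is then projected and resized by projects[i] and resize_layers[i] before the RefineNet-style fusion.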
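
And a minimal end-to-end usage sketch mirroring the __main__ block. This assumes the depth_anything package is importable, that the LiheYoung/depth_anything_vits14 checkpoint (named as in __main__) is reachable, and that the stored config lets torch.hub resolve the DINOv2 backbone:

import torch
from depth_anything.dpt import DepthAnything

model = DepthAnything.from_pretrained("LiheYoung/depth_anything_vits14").eval()

# Input sides are kept at multiples of 14 (the DINOv2 patch size): 518 = 37 * 14.
x = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    depth = model(x)  # relative depth, interpolated back to (h, w), channel dim squeezed
print(depth.shape)  # torch.Size([1, 518, 518])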