plutosss committed on
Commit
6ca799c
·
verified ·
1 Parent(s): 481ebb3

Delete depthAnything/depth_anything/dpt.py

Browse files
Files changed (1) hide show
  1. depthAnything/depth_anything/dpt.py +0 -188
depthAnything/depth_anything/dpt.py DELETED
@@ -1,188 +0,0 @@
1
- import argparse
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
6
-
7
- from depth_anything.blocks import FeatureFusionBlock, _make_scratch
8
- from .blocks import FeatureFusionBlock, _make_scratch
9
-
10
-
11
def _make_fusion_block(features, use_bn, size=None):
    """Build a FeatureFusionBlock configured for a DPT refinement stage.

    Args:
        features: channel width of the fusion stage.
        use_bn: whether the block applies batch normalization.
        size: optional fixed output size forwarded to the block.
    """
    # All stages share the same non-expanding, align-corners configuration;
    # only the activation instance and optional size vary per call.
    fusion_config = dict(
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )
    return FeatureFusionBlock(features, nn.ReLU(False), **fusion_config)
21
-
22
-
23
class DPTHead(nn.Module):
    """DPT decoder head that turns multi-scale ViT tokens into a dense map.

    The head projects the patch tokens of four transformer stages to
    per-stage channel widths, resizes them into a feature pyramid,
    fuses the pyramid with RefineNet-style blocks, and decodes either a
    multi-class map (``nclass > 1``) or a single-channel depth map
    (``nclass == 1``).

    Args:
        nclass: number of output classes; ``1`` selects the depth decoder.
        in_channels: channel width of the incoming transformer tokens.
        features: internal channel width of the fusion pyramid.
        use_bn: whether fusion blocks use batch normalization.
        out_channels: per-stage channel widths after projection.
            NOTE: the default was changed from a mutable list literal to a
            tuple to avoid the shared-mutable-default pitfall; it is only
            indexed/iterated, so behavior is unchanged.
        use_clstoken: if True, each stage input also carries a class token
            that is folded back into the patch tokens via a readout MLP.
    """

    def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=(256, 512, 1024, 1024), use_clstoken=False):
        super().__init__()

        self.nclass = nclass
        self.use_clstoken = use_clstoken

        # 1x1 convs mapping the transformer width to each pyramid level's width.
        self.projects = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            ) for out_channel in out_channels
        ])

        # Bring the four equal-resolution token maps to pyramid scales:
        # stage 0: x4 upsample, stage 1: x2 upsample, stage 2: identity,
        # stage 3: x2 downsample.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])

        if use_clstoken:
            # One readout MLP per stage: concatenated (token, cls) pairs are
            # projected back down to the token width.
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        nn.Linear(2 * in_channels, in_channels),
                        nn.GELU()))

        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )

        self.scratch.stem_transpose = None

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        head_features_1 = features
        head_features_2 = 32

        if nclass > 1:
            # Multi-class branch: conv -> ReLU -> 1x1 classifier.
            self.scratch.output_conv = nn.Sequential(
                nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
            )
        else:
            # Depth branch: halve channels, upsample in forward(), then
            # decode to one channel with a final non-negativity ReLU.
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
                nn.ReLU(True),
                nn.Identity(),
            )

    def forward(self, out_features, patch_h, patch_w):
        """Decode a dense prediction from four stages of backbone features.

        Args:
            out_features: sequence of four stage outputs; each element is a
                (tokens,) or (tokens, cls_token) tuple depending on
                ``use_clstoken``.
            patch_h, patch_w: token-grid height and width.

        Returns:
            Tensor of shape (B, 1, patch_h*14, patch_w*14) for the depth
            branch (the ``nclass == 1`` path exercised here).
        """
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                # Broadcast the class token across all patches and fold it
                # into each token through the stage's readout MLP.
                x, cls_token = x[0], x[1]
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
            else:
                x = x[0]

            # (B, N, C) tokens -> (B, C, patch_h, patch_w) feature map.
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[i](x)
            x = self.resize_layers[i](x)

            out.append(x)

        layer_1, layer_2, layer_3, layer_4 = out

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down fusion: each refinenet upsamples to the next level's size.
        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv1(path_1)
        # Restore the full (patch * 14) resolution before the final decoder.
        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
        out = self.scratch.output_conv2(out)

        return out
138
-
139
-
140
class DPT_DINOv2(nn.Module):
    """Monocular depth estimator: DINOv2 ViT backbone + DPT decoder head.

    Args:
        encoder: DINOv2 variant, one of ``'vits'``, ``'vitb'``, ``'vitl'``.
        features: fusion-pyramid width passed to :class:`DPTHead`.
        out_channels: per-stage projection widths passed to :class:`DPTHead`.
            NOTE: default changed from a mutable list literal to a tuple to
            avoid the shared-mutable-default pitfall; it is only read.
        use_bn: whether fusion blocks use batch normalization.
        use_clstoken: whether the head consumes the backbone's class token.
        localhub: if True, load the DINOv2 definition from a local
            ``torch.hub`` checkout instead of fetching it from GitHub.
    """

    def __init__(self, encoder='vitl', features=256, out_channels=(256, 512, 1024, 1024), use_bn=False, use_clstoken=False, localhub=True):
        super().__init__()

        assert encoder in ['vits', 'vitb', 'vitl']

        # in case the Internet connection is not stable, please load the DINOv2 locally
        if localhub:
            self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
        else:
            self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))

        # Token embedding width of the loaded backbone, read off its first
        # attention block so it adapts to the chosen variant.
        dim = self.pretrained.blocks[0].attn.qkv.in_features

        self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)

    def forward(self, x):
        """Predict a (B, H, W) depth map for a (B, 3, H, W) image batch.

        Assumes H and W are multiples of the ViT patch size (14) — TODO
        confirm the caller guarantees this; otherwise the token grid and
        the final interpolation sizes diverge.
        """
        h, w = x.shape[-2:]

        # Four intermediate backbone stages, each with its class token.
        features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)

        patch_h, patch_w = h // 14, w // 14

        depth = self.depth_head(features, patch_h, patch_w)
        # Upsample to input resolution and clamp negatives to zero.
        depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
        depth = F.relu(depth)

        return depth.squeeze(1)
168
-
169
-
170
class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
    """DPT_DINOv2 model that can be loaded from the Hugging Face Hub.

    The mixin supplies ``from_pretrained``/``push_to_hub``; this class only
    adapts the hub's config-dict constructor convention to DPT_DINOv2's
    keyword arguments.
    """

    def __init__(self, config):
        # ``config`` is the hub-provided dict of DPT_DINOv2 kwargs.
        super(DepthAnything, self).__init__(**config)
173
-
174
-
175
def _cli_args():
    """Parse command-line options for the demo entry point."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder",
        default="vits",
        type=str,
        choices=["vits", "vitb", "vitl"],
    )
    return parser.parse_args()


if __name__ == '__main__':
    # Pull the matching checkpoint from the Hub and show its architecture.
    opts = _cli_args()
    model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(opts.encoder))
    print(model)
188
-