Nap committed on
Commit
f9441d0
·
verified ·
1 Parent(s): 029d229

Modified version of Kijai's ComfyUI-DepthAnythingV2 custom nodes

Browse files

Modified version of Kijai's ComfyUI-DepthAnythingV2 custom nodes that will work with depth_anything_v2_vitg_fp32.safetensors. Just replace the ComfyUI/custom_nodes/comfyui-depthanythingv2/nodes.py file with this one and ensure depth_anything_v2_vitg_fp32.safetensors is in the ComfyUI/models/depthanything/ folder, as it will not be downloaded automatically.

Files changed (1) hide show
  1. nodes.py +190 -0
nodes.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ import os
6
+ from contextlib import nullcontext
7
+
8
+ import comfy.model_management as mm
9
+ from comfy.utils import ProgressBar, load_torch_file
10
+ import folder_paths
11
+
12
+ from .depth_anything_v2.dpt import DepthAnythingV2
13
+
14
+ from contextlib import nullcontext
15
# `accelerate` is an optional dependency: when it imports cleanly the loader can
# build the network without allocating weights and then place each tensor
# directly on the target device. The flag is ALWAYS defined so later code can
# branch on it; previously a failed import left it undefined and the first use
# raised NameError.
try:
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device
    is_accelerate_available = True
except Exception:
    # Deliberately broad: this is a best-effort optional import and any failure
    # (missing package, broken install) should simply select the fallback path.
    is_accelerate_available = False
21
+
22
class DownloadAndLoadDepthAnythingV2Model:
    """Node that loads a Depth Anything V2 checkpoint, downloading it if absent.

    The loaded network is cached on the instance and reused while the same
    checkpoint name is requested again.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input spec: one required `model` checkpoint choice."""
        return {"required": {
            "model": (
                [
                    'depth_anything_v2_vits_fp16.safetensors',
                    'depth_anything_v2_vits_fp32.safetensors',
                    'depth_anything_v2_vitb_fp16.safetensors',
                    'depth_anything_v2_vitb_fp32.safetensors',
                    'depth_anything_v2_vitl_fp16.safetensors',
                    'depth_anything_v2_vitl_fp32.safetensors',
                    'depth_anything_v2_vitg_fp32.safetensors',
                    'depth_anything_v2_metric_hypersim_vitl_fp32.safetensors',
                    'depth_anything_v2_metric_vkitti_vitl_fp32.safetensors'
                ],
                {
                    "default": 'depth_anything_v2_vitl_fp32.safetensors'
                }),
            },
        }

    RETURN_TYPES = ("DAMODEL",)
    RETURN_NAMES = ("da_v2_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "DepthAnythingV2"
    # Backslashes are escaped so the text is unchanged but no longer relies on
    # invalid escape sequences (`\m`, `\d`), which newer Pythons warn about.
    DESCRIPTION = """
Models autodownload to `ComfyUI\\models\\depthanything` from
https://huggingface.co/Kijai/DepthAnythingV2-safetensors/tree/main

fp16 reduces quality by a LOT, not recommended.
"""

    def loadmodel(self, model):
        """Load the checkpoint named `model` and return it wrapped for downstream nodes.

        Args:
            model: Checkpoint file name. "fp16" in the name selects float16,
                otherwise float32; the encoder size is taken from the
                "vits"/"vitb"/"vitl"/"vitg" tag; "metric" enables the metric
                head and "hypersim" selects a max depth of 20.0 instead of 80.0.

        Returns:
            A 1-tuple holding a dict with keys "model", "dtype" and "is_metric".

        Raises:
            ValueError: if the file name contains no known encoder tag.
            FileNotFoundError: if the checkpoint is still missing after the
                download attempt.
        """
        device = mm.get_torch_device()
        dtype = torch.float16 if "fp16" in model else torch.float32
        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
        }
        custom_config = {
            'model_name': model,
        }

        cached_model = getattr(self, 'model', None)
        cached_config = getattr(self, 'current_config', None)
        if cached_model is None or custom_config != cached_config:
            # Most specific tag first, mirroring the original if/elif order.
            encoder = next(
                (tag for tag in ('vitg', 'vitl', 'vitb', 'vits') if tag in model),
                None,
            )
            if encoder is None:
                raise ValueError(f"Cannot determine encoder type from model name: {model}")

            download_path = os.path.join(folder_paths.models_dir, "depthanything")
            model_path = os.path.join(download_path, model)

            if not os.path.exists(model_path):
                print(f"Downloading model to: {model_path}")
                from huggingface_hub import snapshot_download
                snapshot_download(repo_id="Kijai/DepthAnythingV2-safetensors",
                                  allow_patterns=[f"*{model}*"],
                                  local_dir=download_path,
                                  local_dir_use_symlinks=False)
                # The pattern filter downloads nothing when the repository does
                # not host this file; fail with an actionable message instead of
                # an opaque loader error.
                if not os.path.exists(model_path):
                    raise FileNotFoundError(
                        f"Model file not found after download attempt: {model_path}. "
                        f"Place the file in {download_path} manually."
                    )

            print(f"Loading model from: {model_path}")

            max_depth = 20.0 if "hypersim" in model else 80.0

            build_kwargs = dict(model_configs[encoder])
            if 'metric' in model:
                build_kwargs.update(is_metric=True, max_depth=max_depth)

            # Tolerate the module-level flag being undefined (e.g. when the
            # optional accelerate import failed before the flag was assigned).
            use_accelerate = bool(globals().get("is_accelerate_available", False))

            with (init_empty_weights() if use_accelerate else nullcontext()):
                loaded = DepthAnythingV2(**build_kwargs)

            state_dict = load_torch_file(model_path)
            if use_accelerate:
                for key, value in state_dict.items():
                    set_module_tensor_to_device(loaded, key, device=device, dtype=dtype, value=value)
            else:
                loaded.load_state_dict(state_dict)

            loaded.eval()
            # Publish the cache only after a fully successful load, so a failed
            # attempt can never leave an old model registered under the new name.
            self.model = loaded
            self.current_config = custom_config

        da_model = {
            "model": self.model,
            "dtype": dtype,
            "is_metric": self.model.is_metric
        }

        return (da_model,)
117
+
118
class DepthAnything_V2:
    """Node that runs a loaded Depth Anything V2 model over a batch of images."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input spec: a loaded model and an image batch."""
        return {"required": {
            "da_model": ("DAMODEL", ),
            "images": ("IMAGE", ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    CATEGORY = "DepthAnythingV2"
    DESCRIPTION = """
https://depth-anything-v2.github.io
"""

    @staticmethod
    def _normalize_depth(depth):
        """Min-max scale `depth` to [0, 1]; a constant map yields zeros, not NaN."""
        low = depth.min()
        span = depth.max() - low
        if span == 0:
            # Every value is identical, so dividing would be 0/0.
            return torch.zeros_like(depth)
        return (depth - low) / span

    def process(self, da_model, images):
        """Estimate a depth map for every image and return it as a 3-channel batch.

        Args:
            da_model: Dict with keys "model", "dtype" and "is_metric".
            images: 4-D tensor unpacked as (B, H, W, C). The normalisation
                constants have three entries, so C is presumably 3 (RGB) —
                confirm against the caller.

        Returns:
            A 1-tuple with a float tensor of shape (B, final_H, final_W, 3),
            where the final size is the input size rounded down to even
            numbers, with values in [0, 1]. Metric models are inverted
            (1 - depth).

        Raises:
            ValueError: if height or width is smaller than 14 pixels.
        """
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        model = da_model['model']
        dtype = da_model['dtype']

        B, H, W, C = images.shape

        images = images.permute(0, 3, 1, 2)

        # Round each spatial dimension down to a multiple of 14 before inference.
        orig_H, orig_W = H, W
        H -= H % 14
        W -= W % 14
        if H == 0 or W == 0:
            raise ValueError(
                f"Input images must be at least 14x14 pixels, got {orig_W}x{orig_H}"
            )
        if (H, W) != (orig_H, orig_W):
            images = F.interpolate(images, size=(H, W), mode="bilinear")

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        normalized_images = normalize(images)
        pbar = ProgressBar(B)
        out = []
        autocast_condition = (dtype != torch.float32) and not mm.is_device_mps(device)
        autocast_ctx = (
            torch.autocast(mm.get_autocast_device(device), dtype=dtype)
            if autocast_condition
            else nullcontext()
        )
        model.to(device)
        try:
            # Inference only: disable autograd so no graph is kept per image.
            with torch.no_grad(), autocast_ctx:
                for img in normalized_images:
                    depth = model(img.unsqueeze(0).to(device))
                    depth = self._normalize_depth(depth)
                    out.append(depth.cpu())
                    pbar.update(1)
        finally:
            # Move the model back to the offload device even when inference fails.
            model.to(offload_device)

        depth_out = torch.cat(out, dim=0)
        depth_out = depth_out.unsqueeze(-1).repeat(1, 1, 1, 3).cpu().float()

        # Output size is the original size rounded down to even numbers.
        final_H = (orig_H // 2) * 2
        final_W = (orig_W // 2) * 2

        if depth_out.shape[1] != final_H or depth_out.shape[2] != final_W:
            depth_out = F.interpolate(
                depth_out.permute(0, 3, 1, 2),
                size=(final_H, final_W),
                mode="bilinear",
            ).permute(0, 2, 3, 1)
        depth_out = self._normalize_depth(depth_out)
        depth_out = torch.clamp(depth_out, 0, 1)
        if da_model['is_metric']:
            depth_out = 1 - depth_out
        return (depth_out,)
182
+
183
# Node identifier -> implementing class (presumably read by the host
# application's node loader — confirm against its documentation).
NODE_CLASS_MAPPINGS = dict(
    DepthAnything_V2=DepthAnything_V2,
    DownloadAndLoadDepthAnythingV2Model=DownloadAndLoadDepthAnythingV2Model,
)

# Node identifier -> display label.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    DepthAnything_V2="Depth Anything V2",
    DownloadAndLoadDepthAnythingV2Model="DownloadAndLoadDepthAnythingV2Model",
)