File size: 7,583 Bytes
24dbacf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 | #!/usr/bin/env python3
"""Repackage an original Depth Anything 3 checkpoint into ComfyUI's native layout.
Usage:
python scripts/convert_da3.py \
--input da3_original.safetensors \
--output models/diffusion_models/da3_comfy.safetensors
This applies, *offline*, the exact transform that ComfyUI used to do at load
time for DA3:
* remap the DINOv2 backbone keys ``backbone.pretrained.*`` (upstream DA3)
to the ``Dinov2Model`` runtime layout (``backbone.embeddings.*``,
``backbone.encoder.layer.*``, ``backbone.layernorm.*``);
* split each fused ``attn.qkv`` projection into separate query/key/value
linears;
* drop the unused Gaussian-splat head weights (``gs_head.*``, ``gs_adapter.*``).
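The head (``head.*``), camera encoder/decoder (``cam_enc.*``, ``cam_dec.*``)
and any other keys are passed through unchanged apart from key prefixing:
every output key is stored under a single ``model.`` prefix (an existing
``model.`` prefix on the input is stripped first, so it is never doubled).
``--input`` may also be a ``.pt`` / ``.pth`` file or a directory of
``.safetensors`` shards. After conversion the file loads directly via ComfyUI
auto-detection with no in-code remap.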
"""
import argparse
import glob
import os
import torch
from safetensors.torch import load_file, save_file
# Key prefixes of the (unused) Gaussian-splat head and its adapter; every key
# starting with one of these is removed from the converted checkpoint.
DROP_PREFIXES = (
    "gs_head.",
    "gs_adapter.",
)
def remap_backbone_keys(state_dict, prefix="backbone."):
    """Map ``backbone.pretrained.*`` (upstream DA3) keys to ``Dinov2Model`` layout.

    ``state_dict`` is modified in place and the same object is returned.
    Keys that do not start with ``<prefix>pretrained.`` are not touched; if
    no key starts with it, the dict is returned unchanged.

    * A fixed set of non-block backbone tensors (patch-embedding projection,
      position embedding, cls/camera tokens, final norm) is renamed into
      ``<prefix>embeddings.*`` / ``<prefix>layernorm.*``. Any other non-block
      key under ``<prefix>pretrained.`` is left where it is.
    * Every ``<prefix>pretrained.blocks.<i>.<sub>`` key moves to
      ``<prefix>encoder.layer.<i>.<new sub>``. A fused ``attn.qkv`` weight or
      bias is split along dim 0 into equal query/key/value thirds; each third
      is cloned so it does not share storage with the fused tensor.

    Args:
        state_dict: Dict mapping checkpoint key -> tensor.
        prefix: Prefix under which the backbone lives (default ``"backbone."``).

    Returns:
        The same ``state_dict`` object, remapped.

    Raises:
        ValueError: If a block key has an unrecognised sub-key, or a fused qkv
            tensor's leading dimension is not divisible by 3.
    """
    pre = prefix + "pretrained."
    if not any(k.startswith(pre) for k in state_dict):
        return state_dict

    static_renames = {
        pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight",
        pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias",
        pre + "pos_embed": prefix + "embeddings.position_embeddings",
        pre + "cls_token": prefix + "embeddings.cls_token",
        pre + "camera_token": prefix + "embeddings.camera_token",
        pre + "norm.weight": prefix + "layernorm.weight",
        pre + "norm.bias": prefix + "layernorm.bias",
    }
    for src, dst in static_renames.items():
        if src in state_dict:
            state_dict[dst] = state_dict.pop(src)

    # Per-block rename tables (built once, outside the loop).
    # Fused qkv sub-key -> parameter suffix of the split query/key/value linears.
    qkv_params = {"attn.qkv.weight": "weight", "attn.qkv.bias": "bias"}
    # Whole sub-keys renamed verbatim.
    exact_renames = {
        "ls1.gamma": "layer_scale1.lambda1",
        "ls2.gamma": "layer_scale2.lambda1",
    }
    # Leading component swapped; the remaining suffix (".weight", ...) is kept.
    prefix_renames = (
        ("attn.proj.", "attention.output.dense."),
        ("attn.q_norm.", "attention.q_norm."),
        ("attn.k_norm.", "attention.k_norm."),
        ("mlp.w12.", "mlp.weights_in."),
        ("mlp.w3.", "mlp.weights_out."),
    )
    # Sub-keys whose names already match the target layout.
    passthrough = ("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")

    block_pre = pre + "blocks."
    # Snapshot the keys first: the loop pops and inserts entries.
    block_keys = [key for key in state_dict if key.startswith(block_pre)]
    for k in block_keys:
        rest = k[len(block_pre):]  # e.g. "5.attn.qkv.weight"
        idx_str, _, sub = rest.partition(".")
        target_block = "{}encoder.layer.{}.".format(prefix, idx_str)

        # Fused QKV -> split query/key/value linears.
        if sub in qkv_params:
            qkv = state_dict.pop(k)
            rows = qkv.shape[0]
            if rows % 3:
                # An uneven split would silently give value extra rows.
                raise ValueError(
                    "Fused qkv tensor {} has leading dim {}, which is not "
                    "divisible by 3".format(k, rows)
                )
            c = rows // 3
            for i, name in enumerate(("query", "key", "value")):
                new_key = "{}attention.attention.{}.{}".format(
                    target_block, name, qkv_params[sub]
                )
                state_dict[new_key] = qkv[i * c:(i + 1) * c].clone()
            continue

        if sub in exact_renames:
            new = exact_renames[sub]
        elif sub.startswith(passthrough):
            new = sub
        else:
            for old_head, new_head in prefix_renames:
                if sub.startswith(old_head):
                    new = new_head + sub[len(old_head):]
                    break
            else:
                # Unrecognised key -- fail loudly rather than write a
                # checkpoint with silently unmapped backbone weights.
                raise ValueError("Unrecognised DA3 backbone key: {}".format(k))
        state_dict[target_block + new] = state_dict.pop(k)
    return state_dict
def drop_unused(state_dict):
    """Delete every entry whose key starts with one of ``DROP_PREFIXES``.

    The Gaussian-splat head/adapter weights are not used after conversion.
    ``state_dict`` is edited in place; the same object is returned so the
    call can be chained or reassigned.
    """
    # Collect first: deleting while iterating a dict is not allowed.
    doomed = [key for key in state_dict if key.startswith(DROP_PREFIXES)]
    for key in doomed:
        del state_dict[key]
    return state_dict
def load_state_dict(path):
    """Load a flat ``{key: tensor}`` mapping from *path*.

    Accepted inputs:

    * a directory: every ``*.safetensors`` file in it is loaded in sorted
      filename order and merged into one dict;
    * a path ending in ``.safetensors``: loaded with ``load_file``;
    * anything else: loaded with ``torch.load`` on CPU, then one level of
      ``{"state_dict": ...}`` / ``{"model": ...}`` / ``{"module": ...}``
      nesting is unwrapped (first match in that order).

    Raises:
        FileNotFoundError: *path* is a directory with no ``.safetensors`` files.
        ValueError: Two files in the directory define the same key (merging
            would otherwise silently keep only the later one).
        TypeError: The ``torch.load`` result is not a dict after unwrapping.
    """
    if os.path.isdir(path):
        files = sorted(glob.glob(os.path.join(path, "*.safetensors")))
        if not files:
            raise FileNotFoundError("No .safetensors files in {}".format(path))
        sd = {}
        for f in files:
            shard = load_file(f)
            clash = sd.keys() & shard.keys()
            if clash:
                raise ValueError(
                    "Key(s) defined in more than one file under {} (e.g. {}); "
                    "does the directory hold more than one checkpoint?".format(
                        path, ", ".join(sorted(clash)[:5])
                    )
                )
            sd.update(shard)
        return sd
    if path.endswith(".safetensors"):
        return load_file(path)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. Only convert checkpoints you trust;
    # prefer a .safetensors release when one exists.
    sd = torch.load(path, map_location="cpu", weights_only=False)
    # Unwrap common nesting (e.g. {"model": ...} / {"state_dict": ...}).
    for wrap in ("state_dict", "model", "module"):
        if isinstance(sd, dict) and wrap in sd and isinstance(sd[wrap], dict):
            sd = sd[wrap]
            break
    if not isinstance(sd, dict):
        raise TypeError(
            "Expected a state dict in {}, got {}".format(path, type(sd).__name__)
        )
    return sd
def main():
    """Command-line entry point: convert ``--input`` into a ComfyUI ``--output`` file.

    Steps:
    1. Load the checkpoint and skip any non-tensor entries.
    2. Strip a ``model.`` key prefix.
    3. Remap the backbone, unless it is already in ComfyUI layout.
    4. Drop the Gaussian-head keys.
    5. Re-add ``model.`` to every key and save as safetensors.

    Raises:
        ValueError: If the input has no recognisable backbone keys, or
            stripping ``model.`` would make two keys collide.
    """
    parser = argparse.ArgumentParser(
        description="Repackage an original Depth Anything 3 checkpoint into ComfyUI's native layout"
    )
    parser.add_argument("--input", type=str, required=True,
                        help="Path to original DA3 .safetensors / .pt / .pth file or directory")
    parser.add_argument("--output", type=str, required=True,
                        help="Output .safetensors file path")
    args = parser.parse_args()

    print("Loading: {}".format(args.input))
    sd = load_state_dict(args.input)
    print(" Loaded {} keys".format(len(sd)))

    # safetensors can only store tensors, and the steps below call tensor
    # methods (.shape, .contiguous()). Skip anything else (e.g. scalar
    # metadata left in a .pt checkpoint) with a note instead of crashing.
    non_tensor = sorted(str(k) for k, v in sd.items() if not isinstance(v, torch.Tensor))
    if non_tensor:
        shown = ", ".join(non_tensor[:5]) + (", ..." if len(non_tensor) > 5 else "")
        print(" Skipping {} non-tensor entries: {}".format(len(non_tensor), shown))
        sd = {k: v for k, v in sd.items() if isinstance(v, torch.Tensor)}

    # Original DA3 checkpoints store everything under a "model." string prefix
    # (e.g. "model.backbone.pretrained.*"). Strip it so the remap works on bare
    # "backbone.*" keys, then re-add it at the end: ComfyUI's loader resolves the
    # diffusion-model prefix to "model." for DA3, so the saved file must keep it.
    if any(k.startswith("model.") for k in sd):
        print(' Stripping "model." prefix for processing')
        stripped = {}
        for k, v in sd.items():
            bare = k[len("model."):] if k.startswith("model.") else k
            if bare in stripped:
                # Both "model.<x>" and "<x>" present: refuse to silently keep one.
                raise ValueError(
                    'Key collision after stripping "model." prefix: {}'.format(bare)
                )
            stripped[bare] = v
        sd = stripped

    if any(k.startswith("backbone.pretrained.") for k in sd):
        print(" Remapping backbone (backbone.pretrained.* -> Dinov2Model layout)...")
        sd = remap_backbone_keys(sd, prefix="backbone.")
    elif any(k.startswith("backbone.embeddings.") for k in sd):
        print(" Backbone already in ComfyUI layout, skipping remap.")
    else:
        raise ValueError("Input does not look like a DA3 checkpoint (no backbone.* keys found)")

    n_before = len(sd)
    sd = drop_unused(sd)
    dropped = n_before - len(sd)
    if dropped:
        print(" Dropped {} unused Gaussian-head keys".format(dropped))

    # Re-add the "model." prefix expected by ComfyUI's diffusion-model loader.
    sd = {"model." + k: v for k, v in sd.items()}
    # safetensors requires contiguous tensors; the qkv split slices are cloned
    # by the remap but enforce contiguity defensively for all tensors.
    sd = {k: v.contiguous() for k, v in sd.items()}
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
    save_file(sd, args.output)
    print(" Saved {} keys to {}".format(len(sd), args.output))
# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|