comfyui
File size: 7,583 Bytes
24dbacf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
#!/usr/bin/env python3
"""Repackage an original Depth Anything 3 checkpoint into ComfyUI's native layout.

Usage:
    python scripts/convert_da3.py \
        --input da3_original.safetensors \
        --output models/diffusion_models/da3_comfy.safetensors

This applies, *offline*, the exact transform that ComfyUI used to do at load
time for DA3:

  * remap the DINOv2 backbone keys ``backbone.pretrained.*`` (upstream DA3)
    to the ``Dinov2Model`` runtime layout (``backbone.embeddings.*``,
    ``backbone.encoder.layer.*``, ``backbone.layernorm.*``);
  * split each fused ``attn.qkv`` projection into separate query/key/value
    linears;
  * drop the unused Gaussian-splat head weights (``gs_head.*``, ``gs_adapter.*``).

The head (``head.*``), camera encoder/decoder (``cam_enc.*``, ``cam_dec.*``)
and any other keys are passed through unchanged. After conversion the file
loads directly via ComfyUI auto-detection with no in-code remap.
"""

import argparse
import glob
import os

import torch

from safetensors.torch import load_file, save_file

# Key prefixes of the Gaussian-splat head weights that the converted file does
# not keep; drop_unused() removes every key starting with one of these.
DROP_PREFIXES = ("gs_head.", "gs_adapter.")


def _split_fused_qkv(state_dict, key, target_block, param):
    """Replace the fused qkv tensor at ``key`` with separate query/key/value entries.

    ``param`` is ``"weight"`` or ``"bias"``. The fused tensor is cut into three
    equal chunks along its first dimension, in query, key, value order.
    """
    fused = state_dict[key]
    rows = fused.shape[0]
    # Validate before mutating: an uneven split would silently hand the value
    # projection a different size from query/key.
    if rows % 3 != 0:
        raise ValueError(
            "Fused qkv tensor {} has leading dimension {}, which is not divisible by 3".format(key, rows)
        )
    del state_dict[key]
    c = rows // 3
    base = target_block + "attention.attention."
    # clone() gives each projection its own storage instead of a view that
    # shares memory with the fused tensor.
    state_dict[base + "query." + param] = fused[:c].clone()
    state_dict[base + "key." + param] = fused[c:2 * c].clone()
    state_dict[base + "value." + param] = fused[2 * c:].clone()


def remap_backbone_keys(state_dict, prefix="backbone."):
    """Map ``backbone.pretrained.*`` (upstream DA3) keys to ``Dinov2Model`` layout.

    The dict is modified in place and the same object is returned. Keys that do
    not start with ``prefix + "pretrained."`` are never touched. If no such key
    exists the dict is returned unchanged.

    Args:
        state_dict: Mapping of parameter name to tensor. Fused ``attn.qkv``
            values must expose ``.shape`` and ``.clone()`` and support slicing
            along the first dimension.
        prefix: Backbone key prefix, including the trailing dot.

    Returns:
        ``state_dict`` itself, with the backbone keys renamed and each fused
        ``attn.qkv`` weight/bias split into query/key/value entries.

    Raises:
        ValueError: If a key under ``<prefix>pretrained.blocks.`` has an
            unrecognised sub-key, or a fused qkv tensor's first dimension is
            not divisible by 3.
    """
    pre = prefix + "pretrained."
    if not any(k.startswith(pre) for k in state_dict):
        return state_dict

    # Non-block parameters: fixed one-to-one renames.
    static_renames = {
        pre + "patch_embed.proj.weight":  prefix + "embeddings.patch_embeddings.projection.weight",
        pre + "patch_embed.proj.bias":    prefix + "embeddings.patch_embeddings.projection.bias",
        pre + "pos_embed":                prefix + "embeddings.position_embeddings",
        pre + "cls_token":                prefix + "embeddings.cls_token",
        pre + "camera_token":             prefix + "embeddings.camera_token",
        pre + "norm.weight":              prefix + "layernorm.weight",
        pre + "norm.bias":                prefix + "layernorm.bias",
    }
    for src, dst in static_renames.items():
        if src in state_dict:
            state_dict[dst] = state_dict.pop(src)

    block_pre = pre + "blocks."
    # Snapshot the keys first: the loop below mutates the dict.
    block_keys = [k for k in state_dict if k.startswith(block_pre)]
    for k in block_keys:
        rest = k[len(block_pre):]                 # e.g. "5.attn.qkv.weight"
        idx_str, _, sub = rest.partition(".")
        target_block = "{}encoder.layer.{}.".format(prefix, idx_str)

        # Fused QKV -> split query/key/value linears.
        if sub in ("attn.qkv.weight", "attn.qkv.bias"):
            _split_fused_qkv(state_dict, k, target_block, sub.rpartition(".")[2])
            continue

        # Sub-key remap (suffix preserved).
        if sub.startswith("attn.proj."):
            new = "attention.output.dense." + sub[len("attn.proj."):]
        elif sub.startswith("attn.q_norm."):
            new = "attention.q_norm." + sub[len("attn.q_norm."):]
        elif sub.startswith("attn.k_norm."):
            new = "attention.k_norm." + sub[len("attn.k_norm."):]
        elif sub == "ls1.gamma":
            new = "layer_scale1.lambda1"
        elif sub == "ls2.gamma":
            new = "layer_scale2.lambda1"
        elif sub.startswith("mlp.w12."):
            new = "mlp.weights_in." + sub[len("mlp.w12."):]
        elif sub.startswith("mlp.w3."):
            new = "mlp.weights_out." + sub[len("mlp.w3."):]
        elif sub.startswith(("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")):
            new = sub
        else:
            # Unrecognised block key: fail loudly here rather than writing a
            # file that contains a key in the wrong layout.
            raise ValueError("Unrecognised DA3 backbone key: {}".format(k))

        state_dict[target_block + new] = state_dict.pop(k)

    return state_dict


def drop_unused(state_dict):
    """Delete, in place, every entry whose key starts with a ``DROP_PREFIXES`` prefix.

    Returns the same ``state_dict`` object that was passed in.
    """
    # Collect the matches first so the dict is not resized while being iterated.
    stale = [name for name in state_dict if name.startswith(DROP_PREFIXES)]
    for name in stale:
        del state_dict[name]
    return state_dict


def load_state_dict(path):
    """Load a checkpoint's state dict from a file or a directory of shards.

    Args:
        path: One of
            * a directory: every ``*.safetensors`` file in it is loaded in
              sorted filename order and merged into one dict;
            * a ``.safetensors`` file (extension matched case-insensitively);
            * any other file, which is read with ``torch.load`` onto the CPU.

    Returns:
        A dict mapping key names to the loaded values. For ``torch.load``
        inputs, one level of ``"state_dict"`` / ``"model"`` / ``"module"``
        nesting is unwrapped when that entry is itself a dict.

    Raises:
        FileNotFoundError: If ``path`` is a directory with no ``*.safetensors``
            files.
        TypeError: If the object read by ``torch.load`` is not a dict.
    """
    if os.path.isdir(path):
        files = sorted(glob.glob(os.path.join(path, "*.safetensors")))
        if not files:
            raise FileNotFoundError("No .safetensors files in {}".format(path))
        sd = {}
        for f in files:
            # On duplicate keys the shard that sorts later wins.
            sd.update(load_file(f))
        return sd

    if path.lower().endswith(".safetensors"):
        return load_file(path)

    # SECURITY: weights_only=False performs full unpickling, which can execute
    # arbitrary code embedded in the file. Only run this on checkpoints from a
    # trusted source; prefer the .safetensors release when one exists.
    sd = torch.load(path, map_location="cpu", weights_only=False)
    # Unwrap common nesting (e.g. {"model": ...} / {"state_dict": ...}).
    for wrap in ("state_dict", "model", "module"):
        if isinstance(sd, dict) and wrap in sd and isinstance(sd[wrap], dict):
            sd = sd[wrap]
            break
    if not isinstance(sd, dict):
        # Fail here with a clear message instead of an obscure error later on.
        raise TypeError(
            "Expected a state dict in {}, got {}".format(path, type(sd).__name__)
        )
    return sd


def main():
    """Command-line entry point: load a DA3 checkpoint, convert it, save it.

    Reads ``--input`` and ``--output`` from the command line, remaps the
    backbone keys when they are in the upstream layout, drops the unused
    Gaussian-head keys, and writes the result as a ``.safetensors`` file.

    Raises:
        ValueError: If the input has neither ``backbone.pretrained.*`` nor
            ``backbone.embeddings.*`` keys.
    """
    cli = argparse.ArgumentParser(
        description="Repackage an original Depth Anything 3 checkpoint into ComfyUI's native layout"
    )
    cli.add_argument("--input", type=str, required=True,
                     help="Path to original DA3 .safetensors / .pt / .pth file or directory")
    cli.add_argument("--output", type=str, required=True,
                     help="Output .safetensors file path")
    opts = cli.parse_args()

    print(f"Loading: {opts.input}")
    weights = load_state_dict(opts.input)
    print(f"  Loaded {len(weights)} keys")

    # Upstream checkpoints keep everything under a "model." string prefix.
    # Remove it so the remap sees bare "backbone.*" names; it is put back on
    # every key just before saving.
    model_prefix = "model."
    if any(name.startswith(model_prefix) for name in weights):
        print('  Stripping "model." prefix for processing')
        bare = {}
        for name, tensor in weights.items():
            if name.startswith(model_prefix):
                name = name[len(model_prefix):]
            bare[name] = tensor
        weights = bare

    needs_remap = any(name.startswith("backbone.pretrained.") for name in weights)
    if needs_remap:
        print("  Remapping backbone (backbone.pretrained.* -> Dinov2Model layout)...")
        weights = remap_backbone_keys(weights, prefix="backbone.")
    elif any(name.startswith("backbone.embeddings.") for name in weights):
        print("  Backbone already in ComfyUI layout, skipping remap.")
    else:
        raise ValueError("Input does not look like a DA3 checkpoint (no backbone.* keys found)")

    count_before = len(weights)
    weights = drop_unused(weights)
    removed = count_before - len(weights)
    if removed:
        print(f"  Dropped {removed} unused Gaussian-head keys")

    # Restore the "model." prefix on every key and make each tensor contiguous
    # in the same pass (safetensors requires contiguous tensors).
    weights = {
        "model." + name: tensor.contiguous()
        for name, tensor in weights.items()
    }

    out_dir = os.path.dirname(os.path.abspath(opts.output))
    os.makedirs(out_dir, exist_ok=True)
    save_file(weights, opts.output)
    print(f"  Saved {len(weights)} keys to {opts.output}")


# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()