Depth-Anything-3-LiteRT / convert_da3.py
mlboydaisuke's picture
Upload convert_da3.py with huggingface_hub
33377c8 verified
Raw
History Blame Contribute Delete
13.5 kB
#!/usr/bin/env python3
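"""Produce the full GPU-clean DA3-SMALL mono tflite:
weight-prefix fix + C14 (RoPE const) + C12 (qkv-decompose) + LayerScale bake + C15 (pos_embed baked as a
constant, bicubic) + cam-token torch.cat (no SELECT_V2) + ConvTranspose -> exact zero-stuff Conv2d +
DPT align_corners=False / head pos_embed disabled.
Measures depth corr vs the ORIGINAL model, op-checks, then FP16-quantizes."""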
import sys, types, math, numpy as np, json, collections
class _Dummy:
def __getattr__(self, n): return lambda *a, **k: None
# Register an inert stand-in for scipy's compiled PROPACK module; every wrapper name resolves
# to a _Dummy. NOTE(review): presumably the real extension is missing/broken in the
# conversion environment — confirm.
_pp = types.ModuleType('scipy.sparse.linalg._propack')
_pp.__dict__.update({_w: _Dummy() for _w in ('_spropack', '_dpropack', '_cpropack', '_zpropack')})
sys.modules['scipy.sparse.linalg._propack'] = _pp
# Re-create NumPy scalar aliases (np.bool, np.float, ...) that may be absent from the installed
# NumPy, pointing them at the matching builtins; existing attributes are left alone.
_NP_ALIASES = {"bool": bool, "float": float, "int": int, "object": object, "str": str}
for _alias, _builtin in _NP_ALIASES.items():
    if not hasattr(np, _alias):
        setattr(np, _alias, _builtin)
# Put the local ./src tree first on the import path (the depth_anything_3 imports come after).
sys.path.insert(0, "src")
import torch, torch.nn as nn, torch.nn.functional as F, types as _ty
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from depth_anything_3.cfg import create_object
from depth_anything_3.model.dinov2.layers.attention import Attention
import depth_anything_3.model.dinov2.layers.rope as _rope
# C14: RoPE max_position = int(positions.max())+1 is data-dependent (torch.export aborts). Fix = constant
# 128 (>= the 33 a 518 grid needs); extra table rows are unused -> numerically identical to stock.
def _rope_fwd(self, tokens, positions):
fd = tokens.size(-1) // 2
c, s = self._compute_frequency_components(fd, 128, tokens.device, tokens.dtype)
v, h = tokens.chunk(2, dim=-1)
v = self._apply_1d_rope(v, positions[..., 0], c, s)
h = self._apply_1d_rope(h, positions[..., 1], c, s)
return torch.cat((v, h), dim=-1)
# Class-level monkeypatch: every RotaryPositionEmbedding2D module now runs _rope_fwd (constant table).
_rope.RotaryPositionEmbedding2D.forward = _rope_fwd
# Build DA3-SMALL from its Hub config and load the released weights. Checkpoint keys that start
# with "model." have that prefix stripped so they match the bare network's parameter names.
_HF_REPO = "depth-anything/DA3-SMALL"
with open(hf_hub_download(_HF_REPO, "config.json"), encoding="utf-8") as _cfg_file:
    cfg = json.load(_cfg_file)
net = create_object(cfg["config"]).eval()
sd = load_file(hf_hub_download(_HF_REPO, "model.safetensors"))
_PREFIX = "model."
_load_result = net.load_state_dict(
    {(k[len(_PREFIX):] if k.startswith(_PREFIX) else k): v for k, v in sd.items()},
    strict=False,
)
# strict=False would otherwise hide a key-name mismatch (e.g. every weight silently missing).
print(f"weights loaded: {len(_load_result.missing_keys)} missing, "
      f"{len(_load_result.unexpected_keys)} unexpected keys")
class Mono(nn.Module):
    """Single-view depth wrapper: ``forward(x) -> head(...)["depth"]``.

    Holds the DA3 backbone as ``b`` and the head as ``h``. This gives the exporter one
    module with a single image tensor as its only input.
    """

    def __init__(self, n):
        super().__init__()
        self.b = n.backbone
        self.h = n.head

    def forward(self, x):
        height, width = x.shape[-2:]
        # Insert a singleton view axis at dim 1 before calling the backbone.
        views = x.unsqueeze(1)
        feats, _unused = self.b(views, cam_token=None)
        head_out = self.h(feats, height, width, patch_start_idx=0)
        return head_out["depth"]
m = Mono(net).eval()
# CLI: argv[1] = image path, argv[2] = H, argv[3] = W (H defaults to 448, W defaults to H).
# H and W must be multiples of the 14-px patch. Resizing to the native aspect (no padding)
# matches the official preprocessing.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} IMAGE [H] [W]  (H, W: multiples of 14, default 448)")
H_IN = int(sys.argv[2]) if len(sys.argv) > 2 else 448
W_IN = int(sys.argv[3]) if len(sys.argv) > 3 else H_IN
if H_IN <= 0 or W_IN <= 0 or H_IN % 14 or W_IN % 14:
    # The pos_embed bake later uses H_IN // 14 and W_IN // 14 as the token grid, which assumes
    # exact multiples. Fail fast instead of exporting for a truncated grid.
    sys.exit(f"H and W must be positive multiples of 14 (got H={H_IN}, W={W_IN})")
S = H_IN  # back-compat alias used below for the square-ish paths
img = Image.open(sys.argv[1]).convert("RGB").resize((W_IN, H_IN), Image.BILINEAR)
# Mean/std normalisation (ImageNet statistics), then HWC -> 1xCxHxW float32.
a = (np.asarray(img, np.float32) / 255.0 - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
x_img = torch.from_numpy(np.transpose(a, (2, 0, 1))[None].astype(np.float32))
with torch.no_grad():
    d_orig = m(x_img)[0, 0].numpy()  # ORIGINAL (all stock) depth: the accuracy reference
# ---- C12: qkv-decompose -> 4D manual attention ----
def _attn(self, x, pos=None, attn_mask=None):
B,N,C = x.shape; H=self.num_heads; Hd=C//H
q=self.q_lin(x).reshape(B,N,H,Hd).permute(0,2,1,3); k=self.k_lin(x).reshape(B,N,H,Hd).permute(0,2,1,3)
v=self.v_lin(x).reshape(B,N,H,Hd).permute(0,2,1,3); q,k=self.q_norm(q),self.k_norm(k)
if self.rope is not None and pos is not None: q=self.rope(q,pos); k=self.rope(k,pos)
q=q*self.scale; attn=(q@k.transpose(-2,-1)).softmax(-1)
return self.proj_drop(self.proj((attn@v).transpose(1,2).reshape(B,N,C)))
Attention.forward = _attn
# Split every fused qkv Linear into q_lin / k_lin / v_lin (the projections _attn uses).
# Rows [0:C), [C:2C), [2C:3C) of the fused weight/bias become q, k and v respectively.
# Iterate over a snapshot because setattr() adds submodules to the modules being visited.
for mod in list(net.modules()):
    if not isinstance(mod, Attention):
        continue
    C = mod.qkv.in_features
    w, b = mod.qkv.weight, mod.qkv.bias
    for nm, sl in (("q_lin", slice(0, C)), ("k_lin", slice(C, 2 * C)), ("v_lin", slice(2 * C, 3 * C))):
        # Match the fused layer's dtype/device so the split is exact for non-fp32 models too.
        lin = nn.Linear(C, C, bias=b is not None, device=w.device, dtype=w.dtype)
        with torch.no_grad():
            lin.weight.copy_(w[sl])
            if b is not None:
                lin.bias.copy_(b[sl])
        setattr(mod, nm, lin)
# ---- LayerScale bake (GPU FC-layout fix): fold ls1/ls2 gamma into attn.proj / the MLP's last Linear,
# then replace ls1/ls2 with Identity. The LayerScale MUL (FC output [N,C] * gamma [C]) makes ML Drift
# mis-lay-out the token dim ({1,1,1025,384} vs {1025,1,1,384}) -> GPU compile fails. Baking the scale
# into the preceding Linear eliminates the MUL. (MoGe's fix.)
def _fold_scale_into_linear(linear, gamma):
    """Scale ``linear``'s output channels in place so that new(x) == linear(x) * gamma."""
    g = gamma.detach().reshape(-1)  # reshape, not squeeze: keeps a 1-channel gamma 1-D
    with torch.no_grad():
        linear.weight.mul_(g.unsqueeze(1))  # weight is (out, in): scale each output row
        if linear.bias is not None:
            linear.bias.mul_(g)


def _last_linear(mlp):
    """Return the last direct nn.Linear child of ``mlp``, else ``mlp.fc2`` if present, else None."""
    for child in reversed(list(mlp.children())):
        if isinstance(child, nn.Linear):
            return child
    return getattr(mlp, "fc2", None)


def bake_layerscale(model):
    """Fold every LayerScale gamma in ``model`` into the Linear that feeds it.

    For each submodule with ``ls1.gamma`` and ``attn``, gamma is folded into ``attn.proj``.
    For each submodule with ``ls2.gamma`` and ``mlp``, gamma is folded into the MLP's last
    Linear; if none is found, that ls2 is left untouched. Each folded LayerScale is replaced
    by nn.Identity, so a second call bakes nothing.

    Returns:
        int: number of LayerScale modules baked (also printed).
    """
    cnt = 0
    # Snapshot the module list: the loop replaces ls1/ls2 submodules as it goes.
    for block in list(model.modules()):
        ls1 = getattr(block, "ls1", None)
        if hasattr(ls1, "gamma") and hasattr(block, "attn"):
            _fold_scale_into_linear(block.attn.proj, ls1.gamma)
            block.ls1 = nn.Identity()
            cnt += 1
        ls2 = getattr(block, "ls2", None)
        if hasattr(ls2, "gamma") and hasattr(block, "mlp"):
            last = _last_linear(block.mlp)
            if last is not None:
                _fold_scale_into_linear(last, ls2.gamma)
                block.ls2 = nn.Identity()
                cnt += 1
    print(f"baked {cnt} LayerScale into Linear")
    return cnt
bake_layerscale(net)
# ---- C15: pos_embed BAKE. Resizing the constant pos_embed at runtime emits a resize op with 0 runtime
# inputs, which the GPU delegate rejects. Precompute the bicubic-resized pos_embed once for the fixed
# input grid and store it as a constant buffer -> no resize op in the exported graph. ----
GH, GW = H_IN // 14, W_IN // 14
for mod in list(net.modules()):
    if not (hasattr(mod, "interpolate_pos_encoding") and hasattr(mod, "pos_embed")):
        continue
    with torch.no_grad():
        N = mod.pos_embed.shape[1] - 1  # patch positions; index 0 is the class token
        Mb = int(math.sqrt(N))
        if Mb * Mb != N:
            raise ValueError(f"pos_embed has {N} patch positions, not a square grid; cannot bake")
        dim = mod.pos_embed.shape[-1]
        pe = mod.pos_embed.float()
        cls, patch = pe[:, 0], pe[:, 1:]
        patch = F.interpolate(patch.reshape(1, Mb, Mb, dim).permute(0, 3, 1, 2),
                              size=(GH, GW), mode="bicubic", antialias=False)  # bicubic = match official
        patch = patch.permute(0, 2, 3, 1).reshape(1, -1, dim)
        baked = torch.cat((cls.unsqueeze(0), patch), dim=1).to(mod.pos_embed.dtype)  # [1, 1 + GH*GW, dim]
        mod.register_buffer("_baked_pos", baked)
        # The replacement ignores its arguments and always returns the baked constant.
        mod.interpolate_pos_encoding = _ty.MethodType(lambda self, x, w, h: self._baked_pos, mod)
# ---- SELECT_V2 fix: `x[:, :, 0] = cam_token` (in-place index assign) lowers to SELECT_V2 with a
# broadcast 'else' shape the GPU delegate rejects. Replace with an equivalent torch.cat (exact, GPU-clean).
# (alt_start must stay on — the camera-token / global-attn path DOES affect mono depth.) ----
import depth_anything_3.model.dinov2.vision_transformer as _vt
def _patched_gil(self, x, n=1, export_feat_layers=[], **kwargs):
B, S, _, H, W = x.shape
x = self.prepare_tokens_with_masks(x)
output, total_block_len, aux_output = [], len(self.blocks), []
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device)
local_x = x
for i, blk in enumerate(self.blocks):
if i < self.rope_start or self.rope is None:
g_pos, l_pos = None, None
else:
g_pos, l_pos = pos_nodiff, pos
if self.alt_start != -1 and (i == self.alt_start - 1) and x.shape[1] >= _vt.THRESH_FOR_REF_SELECTION and kwargs.get("cam_token", None) is None:
b_idx = _vt.select_reference_view(x, strategy=kwargs.get("ref_view_strategy", "saddle_balanced"))
x = _vt.reorder_by_reference(x, b_idx); local_x = _vt.reorder_by_reference(local_x, b_idx)
if self.alt_start != -1 and i == self.alt_start:
if kwargs.get("cam_token", None) is not None:
cam_token = kwargs.get("cam_token")
else:
ref_token = self.camera_token[:, :1].expand(B, -1, -1)
src_token = self.camera_token[:, 1:].expand(B, S - 1, -1)
cam_token = torch.cat([ref_token, src_token], dim=1)
x = torch.cat([cam_token.unsqueeze(2), x[:, :, 1:]], dim=2) # was: x[:, :, 0] = cam_token
if self.alt_start != -1 and i >= self.alt_start and i % 2 == 1:
x = self.process_attention(x, blk, "global", pos=g_pos, attn_mask=kwargs.get("attn_mask", None))
else:
x = self.process_attention(x, blk, "local", pos=l_pos); local_x = x
if i in blocks_to_take:
out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x
if x.shape[1] >= _vt.THRESH_FOR_REF_SELECTION and self.alt_start != -1 and "b_idx" in locals():
out_x = _vt.restore_original_order(out_x, b_idx)
output.append((out_x[:, :, 0], out_x))
if i in export_feat_layers:
aux_output.append(x)
return output, aux_output
# Class-level monkeypatch: every DinoVisionTransformer now uses _patched_gil (torch.cat cam-token insert).
_vt.DinoVisionTransformer._get_intermediate_layers_not_chunked = _patched_gil
# ---- TRANSPOSE_CONV (Pixel 8a reject): replace ConvTranspose2d(k=s, stride=s) with an EXACT
# GPU-clean equivalent. Zero-stuff the input (nearest-upsample x top-left mask), then run a regular
# Conv2d with the spatially flipped, in/out-swapped kernel. This matches the learned upsampler to ~1e-7,
# so depth stays as sharp as the original (an earlier bilinear + 1x1 approximation blurred it).
# The mask is a precomputed constant buffer (no index-assign / broadcast ops).
class ZeroStuffConvT(nn.Module):
    """Exact replacement for a ConvTranspose2d with kernel_size == stride, baked for one input size.

    Args:
        ct: the nn.ConvTranspose2d to replace. It must have a square kernel equal to its stride,
            zero padding/output_padding, dilation 1 and groups 1.
        H, W: spatial size of the input this layer will receive; the zero-stuff mask is built
            for exactly this size.

    Raises:
        ValueError: if ``ct`` is outside the supported configuration.
    """

    def __init__(self, ct, H, W):
        super().__init__()
        ks, st = tuple(ct.kernel_size), tuple(ct.stride)
        unsupported = (
            ks[0] != ks[1] or ks != st
            or any(p != 0 for p in ct.padding) or any(p != 0 for p in ct.output_padding)
            or any(d != 1 for d in ct.dilation) or ct.groups != 1
        )
        if unsupported:
            raise ValueError(
                "ZeroStuffConvT supports only square kernel == stride ConvTranspose2d without "
                f"padding/output_padding/dilation/groups, got {ct}")
        self.s = st[0]
        self.k = ks[0]
        # ConvTranspose weight is (in, out, kh, kw); a forward conv needs (out, in, kh, kw), flipped.
        # detach(): store a plain constant rather than a graph-connected copy of ct's Parameter.
        self.register_buffer("w", ct.weight.detach().flip(2, 3).transpose(0, 1).contiguous())
        if ct.bias is not None:
            bias = ct.bias.detach().clone()
        else:
            bias = torch.zeros(ct.out_channels, dtype=ct.weight.dtype, device=ct.weight.device)
        self.register_buffer("b", bias)
        s = self.s
        mk = np.zeros((H * s, W * s), np.float32)
        mk[::s, ::s] = 1.0  # keep only the top-left pixel of every s x s cell
        self.register_buffer("mask", torch.from_numpy(mk)[None, None])

    def forward(self, x):
        H, W = x.shape[-2], x.shape[-1]
        s, k = self.s, self.k
        # Nearest-upsample, then mask -> zero-stuffed input (input values on an s-strided lattice).
        xn = F.interpolate(x, size=(H * s, W * s), mode="nearest")
        # A full convolution (padding k-1) with the flipped kernel equals the transposed conv.
        # For k == s, cropping to (H*s, W*s) gives ConvTranspose2d's output size.
        y = F.conv2d(xn * self.mask, self.w, bias=self.b, padding=k - 1)
        return y[:, :, :H * s, :W * s]
# Discover each ConvTranspose2d's input size with a dry run of the export wrapper, then swap every
# ConvTranspose2d that ran for its exact ZeroStuffConvT equivalent.
_ct_hw, _hooks = {}, []


def _mk(nm):
    """Build a forward hook that records the (H, W) of the input seen by module ``nm``."""
    def _h(_module, inp, _out):
        _ct_hw[nm] = (inp[0].shape[-2], inp[0].shape[-1])
    return _h


for _nm, _mod in net.named_modules():
    if isinstance(_mod, nn.ConvTranspose2d):
        _hooks.append(_mod.register_forward_hook(_mk(_nm)))
try:
    with torch.no_grad():
        m(x_img)
finally:
    # Remove the hooks even if the dry run fails, so later forwards/export stay hook-free.
    for _hk in _hooks:
        _hk.remove()


def swap_ct(module, prefix=""):
    """Recursively replace ConvTranspose2d children of ``module`` with ZeroStuffConvT.

    ``prefix`` is ``module``'s dotted name within ``net`` (same naming as ``net.named_modules()``).
    It is used to look up the input size recorded by the dry run. ConvTranspose2d layers the dry
    run never executed are left as-is and reported (presumably outside the exported Mono graph).
    """
    for name, ch in module.named_children():
        full = f"{prefix}.{name}" if prefix else name
        if isinstance(ch, nn.ConvTranspose2d):
            if full not in _ct_hw:
                print(f"swap_ct: {full} not executed in dry run; left as ConvTranspose2d")
                continue
            H, W = _ct_hw[full]
            setattr(module, name, ZeroStuffConvT(ch, H, W))
        else:
            swap_ct(ch, full)


swap_ct(net)
# ---- DPT head: align_corners=True -> False (banned RESIZE_BILINEAR) + drop _add_pos_embed expand (BROADCAST_TO) ----
import depth_anything_3.model.utils.head_utils as _hu
import depth_anything_3.model.dualdpt as _dd
import depth_anything_3.model.dpt as _dpt
_orig_ci = _hu.custom_interpolate


def _ci_no_ac(x, size=None, scale_factor=None, mode="bilinear", align_corners=True):
    """Forward to the original custom_interpolate, always with align_corners=False.

    ``align_corners`` is accepted only for signature compatibility and is ignored, because
    align_corners=True lowers to a RESIZE_BILINEAR variant that is on the banned list.
    """
    forced = dict(size=size, scale_factor=scale_factor, mode=mode, align_corners=False)
    return _orig_ci(x, **forced)


# Rebind custom_interpolate in each of the three modules that hold their own reference to it.
for _ci_owner in (_hu, _dd, _dpt):
    _ci_owner.custom_interpolate = _ci_no_ac
# Head pos-embed-again (UV sin/cos, ratio 0.1): the make_sincos broadcast emits BROADCAST_TO.
# Baking it as constants matches official ~0.0002 better, but adds ~64 MB (a full [1,C,H,W]
# tensor per shape), which is not worth it. Disable it instead: the ratio-0.1 UV refinement is
# negligible compared with the size cost.
_dualdpt_heads = [
    _head for _head in net.modules()
    if isinstance(_head, _dd.DualDPT) and getattr(_head, "pos_embed", False)
]
for _head in _dualdpt_heads:
    _head.pos_embed = False
_n_pe = len(_dualdpt_heads)
print(f"disabled head pos_embed on {_n_pe} DualDPT")
# Run the fully patched (GPU-clean) model on the same image and compare it with the stock output.
with torch.no_grad():
    d_clean = m(x_img)[0, 0].numpy()  # FULLY GPU-CLEAN
_orig_flat, _clean_flat = d_orig.flatten(), d_clean.flatten()
corr = np.corrcoef(_orig_flat, _clean_flat)[0, 1]
_rel_diff_pct = np.abs(d_orig - d_clean).mean() / (np.abs(d_orig).mean() + 1e-9) * 100
print(f"depth corr (original vs full GPU-clean) = {corr:.6f} mean-rel-diff = {_rel_diff_pct:.3f}%")
# ---- convert + op-check (canonical banned list; SELECT_V2 is NOT banned) ----
# Export with a random dummy of the baked input size, reload the .tflite and inspect its ops/tensors.
_tflite_path = "da3_small_gpu.tflite"
dummy = torch.rand(1, 3, H_IN, W_IN)
import litert_torch
litert_torch.convert(m.eval(), (dummy,)).export(_tflite_path)
from ai_edge_litert.interpreter import Interpreter
import os
BANNED = {'GATHER_ND', 'GATHER', 'TOPK_V2', 'PACK', 'SPLIT', 'FLEX_ERF', 'ERF', 'TRANSPOSE_CONV', 'BROADCAST_TO'}
it = Interpreter(model_path=_tflite_path)
it.allocate_tensors()
ops = collections.Counter(d.get('op_name', '?') for d in it._get_ops_details())
bad = {name: count for name, count in ops.items() if name in BANNED}
# Tensors above rank 4 also block the GPU delegate, so count them separately.
over = len([d for d in it.get_tensor_details() if len(d.get('shape', [])) > 4])
print(f"op-check FP32: banned {bad or 'NONE'} | >4D {over} | GELU {ops.get('GELU',0)} | SELECT_V2 {ops.get('SELECT_V2',0)}")
_verdict = "GPU-CLEAN" if not (bad or over) else "BLOCKERS REMAIN"
print("VERDICT:", _verdict)
print(f"FP32 size {os.path.getsize(_tflite_path) / 1e6:.1f} MB")