SAM2.1-Hiera-Tiny-Mask-Decoder / convert_sam2_decoder.py
mlboydaisuke's picture
Mirror the conversion script
a8e7c27 verified
Raw
History Blame Contribute Delete
13.7 kB
"""
SAM 2.1 (hiera-tiny) mask decoder -> LiteRT GPU-clean .tflite (Bucket 1: model-side re-authoring only)
Phase-2 companion to convert_sam2.py (the image encoder). This converts the prompt-conditioned
mask DECODER. The tiny prompt-encoder (point -> sparse tokens, sin/cos) is done HOST-SIDE in Kotlin
(see emit at the end) so the GPU graph stays sin/cos-free; the decoder takes `sparse` as an input.
Walls re-authored (all model-side; no converter patch):
1. Sam2Attention (7x: 2 blocks x 3 + 1 final) : 4D fused attn -> 3D batched SDPA [heads, N, d]
2. ConvTranspose2d (upscale_conv1/2) : -> ZeroStuffConvT (exact zero-stuff + Conv2d), TRANSPOSE_CONV-free
3. mask head (hyper_in @ upscaled) : kept <=4D (no [1,1,4,256,256] 5D tensor); collapse batch==1
4. LayerNorm (9x) : SafeLayerNorm (scale-before-square), fp16-overflow-safe, exact
5. image_positional_embeddings + no-mask dense : baked CONSTANT buffers (host doesn't supply them)
6. multimask_output=True path : static slice [1:], no dynamic-stability argmax/gather/where
Decoder I/O (single point prompt):
inputs : image_embeddings [1,256,64,64], sparse [1,2,256], feat_s1 [1,64,128,128], feat_s0 [1,32,256,256]
outputs: pred_masks [1,3,256,256] (logits, 3 multimask), iou_scores [1,3]
Run:
python convert_sam2_decoder.py # eager parity vs transformers reference (correctness gate)
python convert_sam2_decoder.py --convert # + litert_torch convert + op-gate + fp16
"""
import sys, types, argparse, math
import torch
import torch.nn as nn
import torch.nn.functional as F
# macOS scipy stub (same as convert_sam2.py)
_svdp = types.ModuleType("scipy.sparse.linalg._svdp"); _svdp._svdp = lambda *a, **k: None
sys.modules["scipy.sparse.linalg._svdp"] = _svdp
_opt = types.ModuleType("scipy.optimize"); _opt.linear_sum_assignment = lambda *a, **k: (None, None)
sys.modules["scipy.optimize"] = _opt
from transformers import Sam2Model
MODEL_ID = "facebook/sam2.1-hiera-tiny"
SCRATCH = "/private/tmp/claude-501/-Users-majimadaisuke-Downloads-meeting/4ab9d785-6580-4aef-9d43-30f02ad9879b/scratchpad/sam2"
# ----- LayerNorm. SafeLayerNorm (scale-before-square) protects the encoder's huge deep-stage
# activations from fp16 variance overflow, but the decoder's activations are normal-scale, where
# the down-scaling instead HURTS GPU fp16 (device A/B: SafeLN decoder masks the background).
# PLAIN_LN=1 uses stock LayerNorm for the decoder. -----
import os
_PLAIN_LN = os.environ.get("PLAIN_LN") == "1"
def safe_ln(x, weight, bias, eps, sc=0.03125):
if _PLAIN_LN:
xc = x - x.mean(-1, keepdim=True)
var = (xc * xc).mean(-1, keepdim=True)
return xc * torch.rsqrt(var + eps) * weight + bias
xc = x - x.mean(-1, keepdim=True)
xs = xc * sc
var = (xs * xs).mean(-1, keepdim=True) / (sc * sc)
return xc * torch.rsqrt(var + eps) * weight + bias
# ----- ZeroStuffConvT: ConvTranspose2d(k=s,stride=s) == zero-stuff(nearest x top-left mask) + Conv2d(flipped w) -----
class ZeroStuffConvT(nn.Module):
def __init__(self, ct, H, W):
super().__init__()
self.s = ct.stride[0]; self.k = ct.kernel_size[0]
self.register_buffer("w", ct.weight.flip(2, 3).transpose(0, 1).contiguous())
self.register_buffer("b", ct.bias.detach().clone() if ct.bias is not None else torch.zeros(ct.out_channels))
s = self.s
mk = torch.zeros(H * s, W * s)
mk[::s, ::s] = 1.0
self.register_buffer("mask", mk[None, None])
def forward(self, x):
H, W = x.shape[-2], x.shape[-1]
s, k = self.s, self.k
xn = F.interpolate(x, size=(H * s, W * s), mode="nearest")
y = F.conv2d(xn * self.mask, self.w, bias=self.b, padding=k - 1)
return y[:, :, :H * s, :W * s]
class CleanMaskDecoder(nn.Module):
"""Static single-point SAM2 mask decoder, GPU-clean. batch==1, point_batch==1 collapsed away."""
def __init__(self, model: Sam2Model):
super().__init__()
dec = model.mask_decoder
self.dec = dec
self.layers = dec.transformer.layers
self.final_attn = dec.transformer.final_attn_token_to_image
self.ln_final = dec.transformer.layer_norm_final_attn
self.mlps = dec.output_hypernetworks_mlps
self.iou_head = dec.iou_prediction_head
self.act = dec.activation
self.upscale_ln = dec.upscale_layer_norm # Sam2LayerNorm channels_first (32ch)
# ConvTranspose2d -> ZeroStuffConvT (input sizes are static: 64x64 -> 128 -> 256)
self.upscale_conv1 = ZeroStuffConvT(dec.upscale_conv1, 64, 64) # 256->64, 64x64 -> 128x128
self.upscale_conv2 = ZeroStuffConvT(dec.upscale_conv2, 128, 128) # 64->32, 128x128 -> 256x256
# baked constants
with torch.no_grad():
image_pos = model.get_image_wide_positional_embeddings() # [1,256,64,64]
self.register_buffer("image_pos_flat", image_pos.flatten(2).transpose(1, 2)[0].contiguous()) # [4096,256]
dense = model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(1, 256, 64, 64).contiguous()
self.register_buffer("dense", dense) # [1,256,64,64]
out_tokens = torch.cat([dec.obj_score_token.weight, dec.iou_token.weight, dec.mask_tokens.weight], 0)
self.register_buffer("output_tokens", out_tokens.contiguous()) # [6,256]
def _ln(self, ln_module, x):
return safe_ln(x, ln_module.weight, ln_module.bias, ln_module.eps)
def _attn(self, mod, query, key, value):
"""3D batched SDPA. query [Nq,C], key/value [Nk,C] -> [Nq,C]."""
Nq, Nk = query.shape[0], key.shape[0]
H, hd = mod.num_attention_heads, mod.head_dim
q = mod.q_proj(query).reshape(Nq, H, hd).transpose(0, 1) # [H,Nq,hd]
k = mod.k_proj(key).reshape(Nk, H, hd).transpose(0, 1) # [H,Nk,hd]
v = mod.v_proj(value).reshape(Nk, H, hd).transpose(0, 1) # [H,Nk,hd]
o = F.scaled_dot_product_attention(q, k, v, scale=mod.scaling) # [H,Nq,hd]
o = o.transpose(0, 1).reshape(Nq, H * hd) # [Nq, internal]
return mod.o_proj(o) # [Nq, C]
def _block(self, layer, queries, keys, qpe, kpe, skip):
if skip:
queries = self._attn(layer.self_attn, queries, queries, queries)
else:
qq = queries + qpe
queries = queries + self._attn(layer.self_attn, qq, qq, queries)
queries = self._ln(layer.layer_norm1, queries)
qq = queries + qpe; kk = keys + kpe
queries = queries + self._attn(layer.cross_attn_token_to_image, qq, kk, keys)
queries = self._ln(layer.layer_norm2, queries)
queries = queries + layer.mlp(queries)
queries = self._ln(layer.layer_norm3, queries)
qq = queries + qpe; kk = keys + kpe
keys = keys + self._attn(layer.cross_attn_image_to_token, kk, qq, queries)
keys = self._ln(layer.layer_norm4, keys)
return queries, keys
def forward(self, image_embeddings, sparse, feat_s1, feat_s0):
keys = (image_embeddings + self.dense).flatten(2).transpose(1, 2)[0] # [4096,256]
kpe = self.image_pos_flat # [4096,256]
queries = torch.cat([self.output_tokens, sparse[0]], 0) # [8,256]
qpe = queries # query_point_embedding (constant across layers)
q, k = queries, keys
q, k = self._block(self.layers[0], q, k, qpe, kpe, skip=True)
q, k = self._block(self.layers[1], q, k, qpe, kpe, skip=False)
fq, fk = q + qpe, k + kpe
q = q + self._attn(self.final_attn, fq, fk, k)
q = self._ln(self.ln_final, q)
iou_tok = q[1:2] # [1,256]
mask_toks = q[2:6] # [4,256]
img = k.transpose(0, 1).reshape(1, 256, 64, 64) # [1,256,64,64]
u = self.upscale_conv1(img) + feat_s1 # [1,64,128,128]
# upscale_layer_norm: channels_first SafeLN over the 64 channels
u = u.permute(0, 2, 3, 1)
u = safe_ln(u, self.upscale_ln.weight, self.upscale_ln.bias, self.upscale_ln.eps)
u = u.permute(0, 3, 1, 2)
u = self.act(u)
u = self.act(self.upscale_conv2(u) + feat_s0) # [1,32,256,256]
hyper = torch.cat([self.mlps[j](mask_toks[j:j + 1]) for j in range(4)], 0) # [4,32]
uf = u.reshape(32, 256 * 256) # [32,65536]
masks = (hyper @ uf).reshape(4, 256, 256)[1:].unsqueeze(0) # [1,3,256,256]
iou = self.iou_head(iou_tok)[:, 1:] # [1,3]
return masks, iou
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--convert", action="store_true")
args = ap.parse_args()
m = Sam2Model.from_pretrained(MODEL_ID).eval()
ref = torch.load(f"{SCRATCH}/ref_decoder.pt")
image_embeddings = ref["image_embeddings"] # list of 3
img_emb = image_embeddings[-1] # [1,256,64,64]
feat_s0, feat_s1 = image_embeddings[0], image_embeddings[1]
sparse = ref["sparse"][0] # [1,1,2,256] -> [1,2,256]
ref_masks, ref_iou = ref["masks"], ref["iou"] # [1,1,3,256,256], [1,1,3]
net = CleanMaskDecoder(m).eval()
with torch.no_grad():
masks, iou = net(img_emb, sparse, feat_s1, feat_s0)
rm = ref_masks.reshape(3, -1)
gm = masks.reshape(3, -1)
cos = F.cosine_similarity(gm.flatten(), rm.flatten(), dim=0).item()
mae = (gm - rm).abs().mean().item()
# mask agreement (binary IoU at threshold 0)
inter = ((gm > 0) & (rm > 0)).float().sum().item()
union = ((gm > 0) | (rm > 0)).float().sum().item()
iou_mask = inter / max(union, 1.0)
print(f"[eager] masks cos={cos:.6f} mae={mae:.3e} | binary-IoU(thr0)={iou_mask:.5f}")
print(f"[eager] iou ref={ref_iou.flatten().tolist()}")
print(f"[eager] iou got={iou.flatten().tolist()}")
assert cos > 0.9999, f"re-authoring changed the math! cos={cos}"
print(" -> re-authoring is numerically exact ✓")
if args.convert:
import os, collections, numpy as np, litert_torch
from ai_edge_litert.interpreter import Interpreter
BANNED = {"GATHER_ND", "GATHER", "TOPK_V2", "FLEX_ERF", "ERF", "BROADCAST_TO", "TRANSPOSE_CONV"}
FP32 = f"{SCRATCH}/sam2_tiny_dec_fp32.tflite"
FP16 = f"{SCRATCH}/sam2_tiny_mask_decoder_fp16.tflite"
ex = (img_emb, sparse, feat_s1, feat_s0)
with torch.no_grad():
ref_out = [t.detach().numpy().astype("float64").reshape(-1) for t in net(*ex)]
print("converting (litert_torch) ...")
litert_torch.convert(net, ex).export(FP32)
def gate(path, tag):
it = Interpreter(model_path=path); it.allocate_tensors()
hist = collections.Counter(d["op_name"] for d in it._get_ops_details())
over4d = sum(1 for d in it.get_tensor_details() if len(d.get("shape", [])) > 4)
bad = {k: v for k, v in hist.items() if k in BANNED}
print(f"[{tag}] ops: {dict(sorted(hist.items(), key=lambda kv: -kv[1]))}")
print(f"[{tag}] banned: {bad or 'NONE'} | >4D tensors: {over4d}")
return it, bad, over4d
def parity(it, tag):
ins = it.get_input_details()
order = [img_emb, sparse, feat_s1, feat_s0]
# match each model input slot to our tensors by shape
for d in ins:
want = next(t for t in order if tuple(t.shape) == tuple(d["shape"]))
it.set_tensor(d["index"], want.numpy().astype(d["dtype"]))
it.invoke()
outs = [it.get_tensor(o["index"]).astype("float64").reshape(-1) for o in it.get_output_details()]
for ro in ref_out:
cand = [o for o in outs if o.size == ro.size]
if cand:
c = max(np.corrcoef(ro, o)[0, 1] for o in cand)
print(f"[{tag}] parity corr={c:.6f} (len {ro.size})")
it32, bad, over4d = gate(FP32, "FP32")
parity(it32, "FP32")
print("quantizing fp16 (FLOAT_CASTING) ...")
from ai_edge_quantizer import quantizer, recipe_manager
from ai_edge_quantizer.recipe import AlgorithmName, qtyping
rmgr = recipe_manager.RecipeManager()
rmgr.add_quantization_config(
regex=".*", operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
op_config=qtyping.OpQuantizationConfig(
weight_tensor_config=qtyping.TensorQuantizationConfig(num_bits=16, dtype=qtyping.TensorDataType.FLOAT),
compute_precision=qtyping.ComputePrecision.FLOAT),
algorithm_key=AlgorithmName.FLOAT_CASTING)
if os.path.exists(FP16):
os.remove(FP16)
qt = quantizer.Quantizer(float_model=FP32)
qt.load_quantization_recipe(rmgr.get_quantization_recipe())
qt.quantize().export_model(FP16)
print(f"SIZE fp32 {os.path.getsize(FP32)/1e6:.1f} MB -> fp16 {os.path.getsize(FP16)/1e6:.1f} MB")
it16, bad16, over4d16 = gate(FP16, "FP16")
parity(it16, "FP16")
print(f"\n{'OK -> GPU-clean' if not bad16 and over4d16 == 0 else 'BLOCKERS REMAIN'}: {FP16}")
if __name__ == "__main__":
main()