Mask Generation
LiteRT
LiteRT
sam2
segment-anything
mask-decoder
interactive-segmentation
on-device
gpu
Instructions to use litert-community/SAM2.1-Hiera-Tiny-Mask-Decoder with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- LiteRT
How to use litert-community/SAM2.1-Hiera-Tiny-Mask-Decoder with LiteRT:
# No code snippets available yet for this library. # To use this model, check the repository files and the library's documentation. # Want to help? PRs adding snippets are welcome at: # https://github.com/huggingface/huggingface.js
- sam2
How to use litert-community/SAM2.1-Hiera-Tiny-Mask-Decoder with sam2:
# Use SAM2 with images import torch from sam2.sam2_image_predictor import SAM2ImagePredictor predictor = SAM2ImagePredictor.from_pretrained("litert-community/SAM2.1-Hiera-Tiny-Mask-Decoder") with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): predictor.set_image(<your_image>) masks, _, _ = predictor.predict(<input_prompts>) # Use SAM2 with videos import torch from sam2.sam2_video_predictor import SAM2VideoPredictor predictor = SAM2VideoPredictor.from_pretrained("litert-community/SAM2.1-Hiera-Tiny-Mask-Decoder") with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): state = predictor.init_state(<your_video>) # add new prompts and instantly get the output on the same frame frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>) # propagate the prompts to get masklets throughout the video for frame_idx, object_ids, masks in predictor.propagate_in_video(state): ... - Notebooks
- Google Colab
- Kaggle
File size: 13,658 Bytes
a8e7c27 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 | """
SAM 2.1 (hiera-tiny) mask decoder -> LiteRT GPU-clean .tflite (Bucket 1: model-side re-authoring only)
Phase-2 companion to convert_sam2.py (the image encoder). This converts the prompt-conditioned
mask DECODER. The tiny prompt-encoder (point -> sparse tokens, sin/cos) is done HOST-SIDE in Kotlin
(see emit at the end) so the GPU graph stays sin/cos-free; the decoder takes `sparse` as an input.
Walls re-authored (all model-side; no converter patch):
1. Sam2Attention (7x: 2 blocks x 3 + 1 final) : 4D fused attn -> 3D batched SDPA [heads, N, d]
2. ConvTranspose2d (upscale_conv1/2) : -> ZeroStuffConvT (exact zero-stuff + Conv2d), TRANSPOSE_CONV-free
3. mask head (hyper_in @ upscaled) : kept <=4D (no [1,1,4,256,256] 5D tensor); collapse batch==1
4. LayerNorm (9x transformer + 1 upscale) : SafeLayerNorm (scale-before-square), fp16-overflow-safe, exact (PLAIN_LN=1 -> stock LayerNorm math)
5. image_positional_embeddings + no-mask dense : baked CONSTANT buffers (host doesn't supply them)
6. multimask_output=True path : static slice [1:], no dynamic-stability argmax/gather/where
Decoder I/O (single point prompt):
inputs : image_embeddings [1,256,64,64], sparse [1,2,256], feat_s1 [1,64,128,128], feat_s0 [1,32,256,256]
outputs: pred_masks [1,3,256,256] (logits, 3 multimask), iou_scores [1,3]
Run:
python convert_sam2_decoder.py # eager parity vs transformers reference (correctness gate)
python convert_sam2_decoder.py --convert # + litert_torch convert + op-gate + fp16
"""
import sys, types, argparse, math
import torch
import torch.nn as nn
import torch.nn.functional as F
# macOS scipy stub (same as convert_sam2.py).
# Two placeholder modules are registered in sys.modules before any scipy import, so
# later imports of these names resolve to the stubs instead of the real modules
# (presumably reached transitively via the transformers import below).
# NOTE(review): scipy.optimize is replaced unconditionally, on every platform; the stub
# only provides linear_sum_assignment, which always returns (None, None).
_svdp = types.ModuleType("scipy.sparse.linalg._svdp")
_opt = types.ModuleType("scipy.optimize")
_svdp._svdp = lambda *a, **k: None
_opt.linear_sum_assignment = lambda *a, **k: (None, None)
sys.modules.update({
    "scipy.sparse.linalg._svdp": _svdp,
    "scipy.optimize": _opt,
})
from transformers import Sam2Model
MODEL_ID = "facebook/sam2.1-hiera-tiny"
# Working directory: main() reads ref_decoder.pt from it and writes the exported
# .tflite files into it. The default is the original author's machine-local scratchpad;
# set SAM2_SCRATCH to use another directory (an empty value falls back to the default).
import os

SCRATCH = os.environ.get("SAM2_SCRATCH") or (
    "/private/tmp/claude-501/-Users-majimadaisuke-Downloads-meeting/"
    "4ab9d785-6580-4aef-9d43-30f02ad9879b/scratchpad/sam2"
)
# ----- LayerNorm. SafeLayerNorm (scale-before-square) protects the encoder's huge deep-stage
# activations from fp16 variance overflow, but the decoder's activations are normal-scale, where
# the down-scaling instead HURTS GPU fp16 (device A/B: SafeLN decoder masks the background).
# PLAIN_LN=1 uses stock LayerNorm for the decoder. -----
import os

# Evaluated once at import time; safe_ln consults the flag on every call.
_PLAIN_LN = os.environ.get("PLAIN_LN") == "1"


def safe_ln(x, weight, bias, eps, sc=0.03125):
    """Layer-normalise ``x`` over its last axis, then apply ``weight`` and ``bias``.

    Both paths compute the biased variance of the centred input; they differ only in
    how the squares are formed:

    * ``PLAIN_LN=1``: square the centred values directly (stock LayerNorm math).
    * default: square ``centred * sc`` and divide the mean by ``sc * sc``, so every
      squared term entering the reduction is ``sc**2`` times smaller (1024x for the
      default ``sc = 1/32``).

    NOTE(review): the final division restores the variance to its true scale, so a
    variance that itself exceeds the fp16 range can still overflow at that step.

    Args:
        x: tensor normalised along dim -1.
        weight, bias: affine parameters, broadcast against ``x``.
        eps: added to the variance before ``rsqrt``.
        sc: pre-square scale factor used by the default path.

    Returns:
        Tensor with the same shape as ``x``.
    """
    centred = x - x.mean(-1, keepdim=True)
    if _PLAIN_LN:
        variance = (centred * centred).mean(-1, keepdim=True)
    else:
        shrunk = centred * sc
        variance = (shrunk * shrunk).mean(-1, keepdim=True) / (sc * sc)
    return centred * torch.rsqrt(variance + eps) * weight + bias
# ----- ZeroStuffConvT: ConvTranspose2d(stride=s, padding=0) == zero-stuff(nearest x top-left mask) + Conv2d(flipped w) -----
class ZeroStuffConvT(nn.Module):
    """TRANSPOSE_CONV-free replacement for an ``nn.ConvTranspose2d`` with a static input size.

    The input is upsampled ``s``x with nearest interpolation and multiplied by a constant
    mask that keeps only the top-left pixel of every ``s x s`` cell (zero-stuffing). The
    result is convolved with the spatially flipped, in/out-swapped kernel using
    ``padding = k - 1``, then cropped to the transposed-conv output size ``(H - 1) * s + k``
    (which is ``H * s`` when ``k == s``).

    Args:
        ct: source ``nn.ConvTranspose2d``. Must have a square kernel and stride,
            ``padding == 0``, ``output_padding == 0``, ``dilation == 1`` and ``groups == 1``.
            That is the only configuration this rewrite reproduces.
        H, W: spatial size of the input this module will be called with. The zero-stuff
            mask is baked for exactly ``(H * s, W * s)``, so other input sizes fail to
            broadcast in ``forward``.

    Raises:
        ValueError: if ``ct`` uses a configuration the rewrite does not model.
    """

    def __init__(self, ct, H, W):
        super().__init__()
        s, k = ct.stride[0], ct.kernel_size[0]
        # The equivalence only holds for this configuration; fail loudly rather than
        # silently producing wrong outputs for anything else.
        if (
            ct.stride[1] != s
            or ct.kernel_size[1] != k
            or any(p != 0 for p in ct.padding)
            or any(p != 0 for p in ct.output_padding)
            or any(d != 1 for d in ct.dilation)
            or ct.groups != 1
        ):
            raise ValueError(
                "ZeroStuffConvT needs square kernel/stride, padding=0, output_padding=0, "
                f"dilation=1, groups=1; got kernel_size={ct.kernel_size}, stride={ct.stride}, "
                f"padding={ct.padding}, output_padding={ct.output_padding}, "
                f"dilation={ct.dilation}, groups={ct.groups}"
            )
        self.s = s
        self.k = k
        # ConvTranspose2d weight is [in, out, kH, kW]; conv2d wants [out, in, kH, kW] with the
        # kernel flipped spatially. detach() makes the buffer a plain constant instead of a
        # non-leaf tensor still tied to ct.weight's autograd graph (which breaks deepcopy),
        # matching how the bias is copied below.
        self.register_buffer("w", ct.weight.detach().flip(2, 3).transpose(0, 1).contiguous())
        self.register_buffer(
            "b", ct.bias.detach().clone() if ct.bias is not None else torch.zeros(ct.out_channels)
        )
        # 1 at the top-left pixel of every s x s cell, 0 elsewhere.
        mk = torch.zeros(H * s, W * s)
        mk[::s, ::s] = 1.0
        self.register_buffer("mask", mk[None, None])

    def forward(self, x):
        H, W = x.shape[-2], x.shape[-1]
        s, k = self.s, self.k
        # Nearest upsampling replicates each pixel over its s x s cell; the mask keeps only the
        # top-left copy, which yields the zero-stuffed input without any scatter op.
        xn = F.interpolate(x, size=(H * s, W * s), mode="nearest")
        y = F.conv2d(xn * self.mask, self.w, bias=self.b, padding=k - 1)
        # The conv output is H*s + k - 1 long (the stuffed map carries s - 1 trailing zeros);
        # ConvTranspose2d's output is (H - 1)*s + k, i.e. H*s when k == s.
        return y[:, :, :(H - 1) * s + k, :(W - 1) * s + k]
class CleanMaskDecoder(nn.Module):
    """Static single-point SAM2 mask decoder, GPU-clean. batch==1, point_batch==1 collapsed away.

    Re-expresses ``model.mask_decoder`` (two-way transformer -> upscaler -> per-token MLP mask
    head + IoU head) using ops chosen for LiteRT GPU conversion:
    - 3D SDPA instead of 4D fused attention;
    - ZeroStuffConvT instead of ConvTranspose2d;
    - ``safe_ln`` for every LayerNorm;
    - a <=4D mask head;
    - the image positional embedding and no-mask dense embedding baked in as buffers.

    The op sequence below is exactly what the converter sees, so any change should be
    re-checked with the op gate in ``main`` (``--convert``).

    Shapes (fixed; see the module docstring):
        inputs : image_embeddings [1,256,64,64], sparse [1,2,256],
                 feat_s1 [1,64,128,128], feat_s0 [1,32,256,256]
        outputs: mask logits [1,3,256,256], iou scores [1,3]
    """

    def __init__(self, model: Sam2Model):
        """Collect the decoder submodules from ``model`` and bake the constant buffers.

        Args:
            model: transformers ``Sam2Model``. Its ``mask_decoder`` supplies all weights;
                ``get_image_wide_positional_embeddings()`` and
                ``prompt_encoder.no_mask_embed`` supply the baked constants.
        """
        super().__init__()
        dec = model.mask_decoder
        # NOTE(review): registered as a submodule but never read by forward(); it only keeps
        # the whole original decoder reachable from this module (e.g. via state_dict()).
        self.dec = dec
        # Submodules reused as-is; their weights are shared with `dec`.
        self.layers = dec.transformer.layers
        self.final_attn = dec.transformer.final_attn_token_to_image
        self.ln_final = dec.transformer.layer_norm_final_attn
        self.mlps = dec.output_hypernetworks_mlps
        self.iou_head = dec.iou_prediction_head
        self.act = dec.activation
        self.upscale_ln = dec.upscale_layer_norm  # Sam2LayerNorm channels_first; normalises the 64-ch upscale_conv1 output
        # ConvTranspose2d -> ZeroStuffConvT (input sizes are static: 64x64 -> 128 -> 256)
        self.upscale_conv1 = ZeroStuffConvT(dec.upscale_conv1, 64, 64)  # 256->64, 64x64 -> 128x128
        self.upscale_conv2 = ZeroStuffConvT(dec.upscale_conv2, 128, 128)  # 64->32, 128x128 -> 256x256
        # baked constants (the host does not supply them); built without autograd
        with torch.no_grad():
            image_pos = model.get_image_wide_positional_embeddings()  # [1,256,64,64]
            # flattened to [tokens, channels] with batch dropped, to line up with `keys` in forward()
            self.register_buffer("image_pos_flat", image_pos.flatten(2).transpose(1, 2)[0].contiguous())  # [4096,256]
            # no-mask dense prompt embedding materialised over the 64x64 grid (.contiguous()),
            # presumably so no BROADCAST_TO (banned by main's op gate) reaches the graph
            dense = model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(1, 256, 64, 64).contiguous()
            self.register_buffer("dense", dense)  # [1,256,64,64]
            # row order matters: forward() reads row 1 as the IoU token and rows 2..5 as mask tokens
            out_tokens = torch.cat([dec.obj_score_token.weight, dec.iou_token.weight, dec.mask_tokens.weight], 0)
            self.register_buffer("output_tokens", out_tokens.contiguous())  # [6,256]

    def _ln(self, ln_module, x):
        """Apply ``safe_ln`` with ``ln_module``'s weight, bias and eps (last-axis normalisation)."""
        return safe_ln(x, ln_module.weight, ln_module.bias, ln_module.eps)

    def _attn(self, mod, query, key, value):
        """3D batched SDPA. query [Nq,C], key/value [Nk,C] -> [Nq,C].

        ``mod`` provides q/k/v/o projections, ``num_attention_heads``, ``head_dim`` and
        ``scaling``. Heads are moved to the leading (batch) axis so every tensor stays 3D.
        """
        Nq, Nk = query.shape[0], key.shape[0]
        H, hd = mod.num_attention_heads, mod.head_dim
        q = mod.q_proj(query).reshape(Nq, H, hd).transpose(0, 1)  # [H,Nq,hd]
        k = mod.k_proj(key).reshape(Nk, H, hd).transpose(0, 1)  # [H,Nk,hd]
        v = mod.v_proj(value).reshape(Nk, H, hd).transpose(0, 1)  # [H,Nk,hd]
        o = F.scaled_dot_product_attention(q, k, v, scale=mod.scaling)  # [H,Nq,hd]
        o = o.transpose(0, 1).reshape(Nq, H * hd)  # [Nq, internal]
        return mod.o_proj(o)  # [Nq, C]

    def _block(self, layer, queries, keys, qpe, kpe, skip):
        """One two-way transformer layer on batch-free tensors.

        Args:
            layer: transformer layer providing self_attn, cross_attn_token_to_image,
                cross_attn_image_to_token, mlp and layer_norm1..4.
            queries, keys: token [Nq,C] and image [Nk,C] streams.
            qpe, kpe: positional embeddings added to queries/keys before attention.
            skip: if True, self-attention runs without positional embedding and its output
                replaces ``queries`` (no residual); used for the first layer.

        Returns:
            (queries, keys) after token self-attn, token->image attn, MLP and image->token attn.
        """
        if skip:
            queries = self._attn(layer.self_attn, queries, queries, queries)
        else:
            qq = queries + qpe
            queries = queries + self._attn(layer.self_attn, qq, qq, queries)
        queries = self._ln(layer.layer_norm1, queries)
        # tokens attend to the image
        qq = queries + qpe; kk = keys + kpe
        queries = queries + self._attn(layer.cross_attn_token_to_image, qq, kk, keys)
        queries = self._ln(layer.layer_norm2, queries)
        queries = queries + layer.mlp(queries)
        queries = self._ln(layer.layer_norm3, queries)
        # image attends back to the tokens
        qq = queries + qpe; kk = keys + kpe
        keys = keys + self._attn(layer.cross_attn_image_to_token, kk, qq, queries)
        keys = self._ln(layer.layer_norm4, keys)
        return queries, keys

    def forward(self, image_embeddings, sparse, feat_s1, feat_s0):
        """Decode mask logits and IoU scores for one prompt.

        Args:
            image_embeddings: encoder embedding, [1,256,64,64].
            sparse: host-computed prompt tokens, [1,2,256].
            feat_s1, feat_s0: high-resolution encoder features added during upscaling,
                [1,64,128,128] and [1,32,256,256].

        Returns:
            (masks, iou): mask logits [1,3,256,256] and IoU scores [1,3]. Output slot 0 is
            dropped by the static ``[1:]`` slices (the multimask_output=True path).
        """
        # batch (==1) is dropped with [0]; image tokens become [H*W, C]
        keys = (image_embeddings + self.dense).flatten(2).transpose(1, 2)[0]  # [4096,256]
        kpe = self.image_pos_flat  # [4096,256]
        queries = torch.cat([self.output_tokens, sparse[0]], 0)  # [8,256]
        qpe = queries  # query_point_embedding (constant across layers)
        q, k = queries, keys
        q, k = self._block(self.layers[0], q, k, qpe, kpe, skip=True)
        q, k = self._block(self.layers[1], q, k, qpe, kpe, skip=False)
        # final token->image attention
        fq, fk = q + qpe, k + kpe
        q = q + self._attn(self.final_attn, fq, fk, k)
        q = self._ln(self.ln_final, q)
        iou_tok = q[1:2]  # [1,256]
        mask_toks = q[2:6]  # [4,256]
        img = k.transpose(0, 1).reshape(1, 256, 64, 64)  # [1,256,64,64]
        u = self.upscale_conv1(img) + feat_s1  # [1,64,128,128]
        # upscale_layer_norm: channels_first SafeLN over the 64 channels
        # (permute channels to the last axis, normalise, permute back)
        u = u.permute(0, 2, 3, 1)
        u = safe_ln(u, self.upscale_ln.weight, self.upscale_ln.bias, self.upscale_ln.eps)
        u = u.permute(0, 3, 1, 2)
        u = self.act(u)
        u = self.act(self.upscale_conv2(u) + feat_s0)  # [1,32,256,256]
        # mask head: one MLP per mask token gives a 32-d vector; each mask is that vector's
        # dot product with every upscaled pixel (a 2D matmul, so no 5D tensor appears)
        hyper = torch.cat([self.mlps[j](mask_toks[j:j + 1]) for j in range(4)], 0)  # [4,32]
        uf = u.reshape(32, 256 * 256)  # [32,65536]
        masks = (hyper @ uf).reshape(4, 256, 256)[1:].unsqueeze(0)  # [1,3,256,256]
        iou = self.iou_head(iou_tok)[:, 1:]  # [1,3]
        return masks, iou
def main():
    """Eager parity gate for CleanMaskDecoder, plus optional LiteRT conversion.

    Always:
    - Load the HF model and the reference dump ``{SCRATCH}/ref_decoder.pt`` and run
      CleanMaskDecoder eagerly.
    - Compare its masks and IoU with the stored reference (cosine, MAE, and binary IoU at
      logit threshold 0).
    - Abort unless the cosine exceeds 0.9999.

    With ``--convert``:
    - Export an fp32 .tflite via litert_torch and report its op histogram, banned ops and
      >4D tensors.
    - Print interpreter-vs-eager correlation.
    - Cast the weights to fp16 with ai_edge_quantizer and repeat the report on the fp16 model.

    Raises:
        AssertionError: if the eager re-authoring drifts from the reference.
        ValueError: if a converted model's input shape matches none of the example inputs.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--convert", action="store_true")
    args = ap.parse_args()
    m = Sam2Model.from_pretrained(MODEL_ID).eval()
    # NOTE(review): torch.load unpickles; keep SCRATCH pointed at trusted files only.
    ref = torch.load(f"{SCRATCH}/ref_decoder.pt")
    image_embeddings = ref["image_embeddings"]  # list of 3
    img_emb = image_embeddings[-1]  # [1,256,64,64]
    feat_s0, feat_s1 = image_embeddings[0], image_embeddings[1]
    sparse = ref["sparse"][0]  # [1,1,2,256] -> [1,2,256]
    ref_masks, ref_iou = ref["masks"], ref["iou"]  # [1,1,3,256,256], [1,1,3]
    net = CleanMaskDecoder(m).eval()
    with torch.no_grad():
        masks, iou = net(img_emb, sparse, feat_s1, feat_s0)
    rm = ref_masks.reshape(3, -1)
    gm = masks.reshape(3, -1)
    cos = F.cosine_similarity(gm.flatten(), rm.flatten(), dim=0).item()
    mae = (gm - rm).abs().mean().item()
    # mask agreement (binary IoU at threshold 0)
    inter = ((gm > 0) & (rm > 0)).float().sum().item()
    union = ((gm > 0) | (rm > 0)).float().sum().item()
    iou_mask = inter / max(union, 1.0)
    print(f"[eager] masks cos={cos:.6f} mae={mae:.3e} | binary-IoU(thr0)={iou_mask:.5f}")
    print(f"[eager] iou ref={ref_iou.flatten().tolist()}")
    print(f"[eager] iou got={iou.flatten().tolist()}")
    # Correctness gate: an explicit raise (not `assert`) so it still fires under `python -O`;
    # `not (cos > ...)` also rejects a NaN similarity, exactly as the original assert did.
    if not (cos > 0.9999):
        raise AssertionError(f"re-authoring changed the math! cos={cos}")
    print(" -> re-authoring is numerically exact ✓")
    if not args.convert:
        return

    import collections
    import os

    import numpy as np
    import litert_torch
    from ai_edge_litert.interpreter import Interpreter

    BANNED = {"GATHER_ND", "GATHER", "TOPK_V2", "FLEX_ERF", "ERF", "BROADCAST_TO", "TRANSPOSE_CONV"}
    FP32 = f"{SCRATCH}/sam2_tiny_dec_fp32.tflite"
    FP16 = f"{SCRATCH}/sam2_tiny_mask_decoder_fp16.tflite"
    ex = (img_emb, sparse, feat_s1, feat_s0)
    # Eager reference for interpreter parity: the outputs already computed above for exactly
    # these inputs under no_grad (no need to run the network a second time).
    ref_out = [t.detach().numpy().astype("float64").reshape(-1) for t in (masks, iou)]
    print("converting (litert_torch) ...")
    litert_torch.convert(net, ex).export(FP32)

    def gate(path, tag):
        """Load a .tflite, print its op histogram, and report banned ops and >4D tensors."""
        it = Interpreter(model_path=path)
        it.allocate_tensors()
        hist = collections.Counter(d["op_name"] for d in it._get_ops_details())
        over4d = sum(1 for d in it.get_tensor_details() if len(d.get("shape", [])) > 4)
        bad = {k: v for k, v in hist.items() if k in BANNED}
        print(f"[{tag}] ops: {dict(sorted(hist.items(), key=lambda kv: -kv[1]))}")
        print(f"[{tag}] banned: {bad or 'NONE'} | >4D tensors: {over4d}")
        return it, bad, over4d

    def parity(it, tag):
        """Feed the example inputs to ``it`` and print its correlation with the eager outputs."""
        order = [img_emb, sparse, feat_s1, feat_s0]
        # match each model input slot to our tensors by shape
        # (the four example shapes are all distinct, so the match is unambiguous)
        for d in it.get_input_details():
            want = next((t for t in order if tuple(t.shape) == tuple(d["shape"])), None)
            if want is None:
                raise ValueError(
                    f"[{tag}] no example input matches interpreter input "
                    f"{d.get('name')!r} with shape {[int(n) for n in d['shape']]}"
                )
            it.set_tensor(d["index"], want.numpy().astype(d["dtype"]))
        it.invoke()
        outs = [it.get_tensor(o["index"]).astype("float64").reshape(-1) for o in it.get_output_details()]
        # outputs are paired with references by element count (masks and iou sizes differ)
        for ro in ref_out:
            cand = [o for o in outs if o.size == ro.size]
            if cand:
                c = max(np.corrcoef(ro, o)[0, 1] for o in cand)
                print(f"[{tag}] parity corr={c:.6f} (len {ro.size})")

    it32, bad, over4d = gate(FP32, "FP32")
    parity(it32, "FP32")
    print("quantizing fp16 (FLOAT_CASTING) ...")
    from ai_edge_quantizer import quantizer, recipe_manager
    from ai_edge_quantizer.recipe import AlgorithmName, qtyping

    rmgr = recipe_manager.RecipeManager()
    rmgr.add_quantization_config(
        regex=".*", operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=qtyping.TensorQuantizationConfig(num_bits=16, dtype=qtyping.TensorDataType.FLOAT),
            compute_precision=qtyping.ComputePrecision.FLOAT),
        algorithm_key=AlgorithmName.FLOAT_CASTING)
    if os.path.exists(FP16):
        os.remove(FP16)
    qt = quantizer.Quantizer(float_model=FP32)
    qt.load_quantization_recipe(rmgr.get_quantization_recipe())
    qt.quantize().export_model(FP16)
    print(f"SIZE fp32 {os.path.getsize(FP32)/1e6:.1f} MB -> fp16 {os.path.getsize(FP16)/1e6:.1f} MB")
    it16, bad16, over4d16 = gate(FP16, "FP16")
    parity(it16, "FP16")
    print(f"\n{'OK -> GPU-clean' if not bad16 and over4d16 == 0 else 'BLOCKERS REMAIN'}: {FP16}")


if __name__ == "__main__":
    main()
|