PE-Core-base-patch16-224-LiteRT / convert_pecore.py
mlboydaisuke's picture
Upload convert_pecore.py with huggingface_hub
1c0fc7c verified
Raw
History Blame Contribute Delete
13.4 kB
"""Convert timm Perception Encoder (PE-Core, base/patch16/224) image tower to a
GPU-clean LiteRT .tflite for the ML Drift GPU delegate.
PE-Core (Meta 2025, Apache-2.0) is a CLIP-style ViT image tower. timm exposes it
as `vit_pe_core_base_patch16_224` (weights `timm/vit_pe_core_base_patch16_224.fb`).
Walls re-authored here (all numerically verbatim, weights copied):
* AttentionRope (x12): fused qkv -> 5D reshape head-split = the "C12" GPU wall.
Decompose to separate q/k/v Linears, manual 4D (B,H,N,d) attention.
* RoPE: PE-Core uses the *interleaved* layout (rotate_half=False) whose `rot()`
does strided `x[...,::2]` -> GATHER_ND (GPU-banned). Fix = the proven
even->odd channel permutation baked into q/k weights + `rotate_half`
(slice+neg+concat, 4D) + constant half-layout cos/sin (const-folds to MUL/ADD).
Permuting q AND k identically preserves q.k exactly, so attention is unchanged.
* AttentionPoolLatent: fused kv -> 5D head-split. Decompose kv to k/v Linears.
I/O: input [1,3,224,224] NCHW float32, output [1,1024] L2-normalized image embedding.
~/clipconv/bin/python scripts/convert_pecore.py
"""
import os
import sys
import types
import collections
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import _stub # noqa: F401 (macOS scipy/_propack guard, import FIRST)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
MODEL = "vit_pe_core_base_patch16_224"
IMG = 224
OUT_DIR = os.path.expanduser("~/code/litertlm-convert/out/pecore")
os.makedirs(OUT_DIR, exist_ok=True)
FP32 = os.path.join(OUT_DIR, "pe_core_base_224.tflite")
FP16 = os.path.join(OUT_DIR, "pe_core_base_224_fp16.tflite")
BANNED = {"GATHER_ND", "GATHER", "TOPK_V2", "FLEX_ERF", "ERF", "BROADCAST_TO"}
# -------------------------------------------------- overflow-safe LayerNorm
class SafeLayerNorm(nn.Module):
"""LayerNorm whose variance reduction can't overflow fp16. The ML Drift GPU
delegate computes the sum-of-squares reduction in fp16 even for an fp32 model;
deep-ViT massive activations (|x|~50+) make `sum((x-mean)^2)` exceed fp16 max
(65504) -> wrong normalization that compounds with depth (corr collapses to
~0.28 over 12 blocks). Scaling by `SC` before squaring (and undoing after)
keeps the running sum in range -- mathematically identical to nn.LayerNorm."""
SC = 0.03125 # 1/32: keeps sum((x-mean)*SC)^2 << 65504 for |x|<~290
def __init__(self, ln: nn.LayerNorm):
super().__init__()
self.weight, self.bias, self.eps = ln.weight, ln.bias, ln.eps
def forward(self, x):
xc = x - x.mean(-1, keepdim=True)
xs = xc * self.SC
var = (xs * xs).mean(-1, keepdim=True) / (self.SC * self.SC)
return xc * torch.rsqrt(var + self.eps) * self.weight + self.bias
def patch_layernorm(module):
for name, child in module.named_children():
if isinstance(child, nn.LayerNorm):
setattr(module, name, SafeLayerNorm(child))
else:
patch_layernorm(child)
# ---------------------------------------------------------------- rope (clean)
def rope_rotate_half(x):
# 4D-clean: slice halves, negate, concat. No strided slice, no >4D.
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([-x2, x1], dim=-1)
def apply_half(x, cos, sin):
# x: [B,H,N,d]; cos/sin: [1,1,N,d]
return x * cos + rope_rotate_half(x) * sin
def _even_odd_perm(num_heads, head_dim):
"""Per-head index permutation [0,2,..,1,3,..] that maps the interleaved RoPE
layout to the rotate-half layout (evens then odds within each head)."""
perm = []
for h in range(num_heads):
base = h * head_dim
perm += [base + i for i in range(0, head_dim, 2)]
perm += [base + i for i in range(1, head_dim, 2)]
return torch.tensor(perm, dtype=torch.long)
# ----------------------------------------------- AttentionRope -> 4D + clean rope
def _attn_rope_forward(self, x, rope=None, attn_mask=None, is_causal=False):
B, N, C = x.shape
H, d = self.num_heads, self.head_dim
q = self.q_proj_d(x).reshape(B, N, H, d).transpose(1, 2)
k = self.k_proj_d(x).reshape(B, N, H, d).transpose(1, 2)
v = self.v_proj_d(x).reshape(B, N, H, d).transpose(1, 2)
q, k = self.q_norm(q), self.k_norm(k) # Identity for PE-Core
npt = self.npt_
cos, sin = self.cos_half, self.sin_half
q = torch.cat([q[:, :, :npt, :], apply_half(q[:, :, npt:, :], cos, sin)], dim=2)
k = torch.cat([k[:, :, :npt, :], apply_half(k[:, :, npt:, :], cos, sin)], dim=2)
# SDPA lowers to a 3D batch-matmul with a MATERIALIZED transpose (adj_y=False),
# which the GPU delegate accepts -- unlike explicit q@k.transpose (folds to
# adj_y=True, rejected for non-constant RHS). Default scale = head_dim**-0.5.
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(B, N, self.attn_dim)
out = self.norm(out) # Identity (scale_norm off)
return self.proj(out)
def reauthor_attn_rope(attn, cos_half, sin_half, npt):
C = attn.qkv.in_features
H, d = attn.num_heads, attn.head_dim
w = attn.qkv.weight.data
b = attn.qkv.bias.data if attn.qkv.bias is not None else None
wq, wk, wv = w[:C], w[C:2 * C], w[2 * C:]
perm = _even_odd_perm(H, d)
has_b = b is not None
q_proj = nn.Linear(C, C, bias=has_b)
k_proj = nn.Linear(C, C, bias=has_b)
v_proj = nn.Linear(C, C, bias=has_b)
with torch.no_grad():
q_proj.weight.copy_(wq[perm]) # permute OUTPUT channels (rows)
k_proj.weight.copy_(wk[perm])
v_proj.weight.copy_(wv)
if has_b:
q_proj.bias.copy_(b[:C][perm])
k_proj.bias.copy_(b[C:2 * C][perm])
v_proj.bias.copy_(b[2 * C:])
attn.q_proj_d, attn.k_proj_d, attn.v_proj_d = q_proj, k_proj, v_proj
attn.register_buffer("cos_half", cos_half[None, None]) # [1,1,N,d]
attn.register_buffer("sin_half", sin_half[None, None])
attn.npt_ = npt
attn.forward = types.MethodType(_attn_rope_forward, attn)
# ----------------------------------------------- AttentionPoolLatent -> 4D
def _attn_pool_forward(self, x, attn_mask=None):
# The pooling query is derived from a constant latent (latent_len=1). Both a
# const@non-const BMM (rejected at compile) AND the reordered const-RHS BMM
# (compiles but the GPU delegate MIS-COMPUTES it -> garbage embedding) fail, so
# express the single-query attention as broadcast-multiply + reduce-sum, which
# is exact and GPU-correct.
B, N, C = x.shape
H, d, L = self.num_heads, self.head_dim, self.latent_len
k = self.k_norm(self.k_proj_d(x).reshape(B, N, H, d).transpose(1, 2)) # [B,H,N,d]
v = self.v_proj_d(x).reshape(B, N, H, d).transpose(1, 2) # [B,H,N,d]
qc = self.q_const # [H, L, d] constant, q_norm'd + scaled
# Broadcast-multiply + reduce (no batch-matmul): exact for latent_len=1 and
# avoids the const@non-const BMM that the GPU delegate mis-computes.
scores = (qc.unsqueeze(0) * k).sum(dim=-1) # [B, H, N]
attn = scores.softmax(dim=-1).unsqueeze(-1) # [B, H, N, 1]
out = (attn * v).sum(dim=2).reshape(B, L, C) # [B, L, C]
out = self.proj(out)
if self.mlp is not None:
out = out + self.mlp(self.norm(out))
if self.pool == "token":
out = out[:, 0]
elif self.pool == "avg":
out = out.mean(1)
return out
def reauthor_attn_pool(ap):
assert ap.pos_embed is None, "attn_pool pos_embed not handled"
C = ap.kv.in_features
inner = ap.num_heads * ap.head_dim
has_b = ap.kv.bias is not None
k_proj = nn.Linear(C, inner, bias=has_b)
v_proj = nn.Linear(C, inner, bias=has_b)
with torch.no_grad():
k_proj.weight.copy_(ap.kv.weight.data[:inner])
v_proj.weight.copy_(ap.kv.weight.data[inner:])
if has_b:
k_proj.bias.copy_(ap.kv.bias.data[:inner])
v_proj.bias.copy_(ap.kv.bias.data[inner:])
H, d, L = ap.num_heads, ap.head_dim, ap.latent_len
# constant query: q_norm(q(latent)) * scale -> [H, L, d]
ql = ap.q(ap.latent.expand(1, -1, -1)).reshape(1, L, H, d).transpose(1, 2)
ql = ap.q_norm(ql) * ap.scale
ap.k_proj_d, ap.v_proj_d = k_proj, v_proj
ap.register_buffer("q_const", ql.reshape(H, L, d).detach())
ap.forward = types.MethodType(_attn_pool_forward, ap)
# ------------------------------------------------------------------- wrapper
class PECoreImageEncoder(nn.Module):
def __init__(self, m):
super().__init__()
self.m = m
def forward(self, pixel):
m = self.m
x = m.patch_embed(pixel)
if x.dim() == 4: # [B,Hg,Wg,C] -> [B,N,C]
x = x.flatten(1, 2)
cls = m.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat([cls, x], dim=1)
if m.pos_embed is not None:
x = x + m.pos_embed
x = m.norm_pre(x)
for blk in m.blocks:
x = blk(x) # rope=None default; patched attn uses baked buffers
x = m.norm(x)
x = m.attn_pool(x)
x = m.head(x)
return F.normalize(x, dim=-1)
def build_half_cos_sin(m):
"""Half-layout constant cos/sin [N_patch, head_dim] from timm's interleaved rope."""
emb = m.rope.get_embed() # [N, 2*d] = cat(sin, cos)
sin_emb, cos_emb = emb.chunk(2, -1) # each [N, d] interleaved [s0,s0,s1,s1,...]
s = sin_emb[:, ::2] # [N, d/2] = [s0,s1,...]
c = cos_emb[:, ::2]
sin_half = torch.cat([s, s], dim=-1) # [N, d]
cos_half = torch.cat([c, c], dim=-1)
return cos_half.detach(), sin_half.detach()
def op_hist(path):
from ai_edge_litert.interpreter import Interpreter
it = Interpreter(model_path=path)
it.allocate_tensors()
hist = collections.Counter(d["op_name"] for d in it._get_ops_details())
over4d = sum(1 for d in it.get_tensor_details() if len(d.get("shape", [])) > 4)
return hist, over4d, it
def tflite_run(it, x_nchw):
inp = it.get_input_details()[0]
shp = list(inp["shape"])
x = x_nchw if shp[1] == 3 else np.transpose(x_nchw, (0, 2, 3, 1)).copy()
it.set_tensor(inp["index"], x.astype(inp["dtype"]))
it.invoke()
return it.get_tensor(it.get_output_details()[0]["index"]).astype("float64").reshape(-1)
def main():
torch.manual_seed(0)
print(f"loading {MODEL} (pretrained, apache-2.0) ...")
m = timm.create_model(MODEL, pretrained=True).eval()
x = torch.randn(1, 3, IMG, IMG)
with torch.no_grad():
ref = F.normalize(m(x), dim=-1).numpy().flatten() # original (interleaved rope, fused qkv)
# ---- re-author in place ----
cos_half, sin_half = build_half_cos_sin(m)
npt = m.blocks[0].attn.num_prefix_tokens
for blk in m.blocks:
reauthor_attn_rope(blk.attn, cos_half, sin_half, npt)
reauthor_attn_pool(m.attn_pool)
patch_layernorm(m) # GPU fp16 variance reduction overflows on deep-ViT outliers
enc = PECoreImageEncoder(m).eval()
with torch.no_grad():
got = enc(x).numpy().flatten()
corr = float(np.corrcoef(ref, got)[0, 1])
maxd = float(np.abs(ref - got).max())
print(f"EAGER parity (orig vs re-authored): corr {corr:.8f} max|diff| {maxd:.3e}")
assert corr > 0.9999, "re-authoring changed the math -- fix before convert"
# ---- convert fp32 ----
print("converting (litert_torch) ...")
import litert_torch
litert_torch.convert(enc, (x,)).export(FP32)
hist, over4d, it = op_hist(FP32)
bad = {k: v for k, v in hist.items() if k in BANNED}
print(f"FP32 ops: {dict(sorted(hist.items(), key=lambda kv: -kv[1]))}")
print(f"banned: {bad or 'NONE'} | >4D tensors: {over4d}")
o = tflite_run(it, x.numpy())
print(f"PARITY tflite(fp32) vs torch: corr {np.corrcoef(ref, o)[0,1]:.6f}")
assert not bad and over4d == 0, "GPU blockers remain -- inspect op histogram"
# ---- fp16 FLOAT_CASTING ----
print("quantizing fp16 (FLOAT_CASTING) ...")
from ai_edge_quantizer import quantizer, recipe_manager
from ai_edge_quantizer.recipe import AlgorithmName, qtyping
rm = recipe_manager.RecipeManager()
rm.add_quantization_config(
regex=".*",
operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
op_config=qtyping.OpQuantizationConfig(
weight_tensor_config=qtyping.TensorQuantizationConfig(
num_bits=16, dtype=qtyping.TensorDataType.FLOAT),
compute_precision=qtyping.ComputePrecision.FLOAT,
),
algorithm_key=AlgorithmName.FLOAT_CASTING,
)
if os.path.exists(FP16):
os.remove(FP16)
qt = quantizer.Quantizer(float_model=FP32)
qt.load_quantization_recipe(rm.get_quantization_recipe())
qt.quantize().export_model(FP16)
s32, s16 = os.path.getsize(FP32) / 1e6, os.path.getsize(FP16) / 1e6
print(f"SIZE fp32 {s32:.1f} MB -> fp16 {s16:.1f} MB ({s16/s32*100:.0f}%)")
h16, o16d, it16 = op_hist(FP16)
bad16 = {k: v for k, v in h16.items() if k in BANNED}
print(f"FP16 banned: {bad16 or 'NONE'} | >4D: {o16d}")
o16 = tflite_run(it16, x.numpy())
print(f"PARITY tflite(fp16) vs torch: corr {np.corrcoef(ref, o16)[0,1]:.6f} "
f"fp16-vs-fp32 corr {np.corrcoef(o, o16)[0,1]:.6f}")
print("\nDONE:", FP16)
if __name__ == "__main__":
main()