jit-lora / src /ane_bridge_py.py
Ex0bit's picture
Add missing ane_bridge_py.py Python wrapper
7397b43
"""
ane_bridge_py.py — Python ctypes wrapper for libane_bridge.dylib
Provides a Pythonic interface to Apple Neural Engine private APIs
via the maderix/ANE C bridge library. Enables compiling and executing
MIL programs on ANE hardware from Python.
Usage:
from ane_bridge_py import ANEBridge
ane = ANEBridge()
kernel = ane.compile_kernel(mil_text, weights, input_sizes, output_sizes)
ane.write_input(kernel, 0, my_numpy_array)
ane.eval(kernel)
result = ane.read_output(kernel, 0, output_shape, dtype=np.float16)
ane.free_kernel(kernel)
"""
import ctypes
import ctypes.util
import os
import numpy as np
from pathlib import Path
from typing import Optional
# Resolve library path relative to this file
_BRIDGE_DIR = Path(__file__).parent / "bridge"
_LIB_PATH = str(_BRIDGE_DIR / "libane_bridge.dylib")
# Max compiles before needing process restart (ANE limitation)
MAX_COMPILE_BUDGET = 110 # Leave margin from the ~119 hard limit
class ANEBridgeError(Exception):
"""Error from ANE bridge operations."""
pass
class ANEBridge:
"""Python wrapper for the ANE C bridge library."""
def __init__(self, lib_path: Optional[str] = None):
lib_path = lib_path or _LIB_PATH
if not os.path.exists(lib_path):
raise ANEBridgeError(
f"ANE bridge library not found at {lib_path}. "
f"Run: cd scripts/ane-engine/bridge && make"
)
self._lib = ctypes.CDLL(lib_path)
self._setup_signatures()
rc = self._lib.ane_bridge_init()
if rc != 0:
raise ANEBridgeError(
"Failed to initialize ANE runtime. "
"Requires macOS 15+ on Apple Silicon."
)
def _setup_signatures(self):
"""Define C function signatures for type safety."""
lib = self._lib
# ane_bridge_init() -> int
lib.ane_bridge_init.restype = ctypes.c_int
lib.ane_bridge_init.argtypes = []
# ane_bridge_compile(...) -> void*
lib.ane_bridge_compile.restype = ctypes.c_void_p
lib.ane_bridge_compile.argtypes = [
ctypes.c_char_p, # mil_text
ctypes.c_size_t, # mil_len
ctypes.POINTER(ctypes.c_uint8), # weight_data
ctypes.c_size_t, # weight_len
ctypes.c_int, # n_inputs
ctypes.POINTER(ctypes.c_size_t), # input_sizes
ctypes.c_int, # n_outputs
ctypes.POINTER(ctypes.c_size_t), # output_sizes
]
# ane_bridge_compile_multi_weights(...) -> void*
lib.ane_bridge_compile_multi_weights.restype = ctypes.c_void_p
lib.ane_bridge_compile_multi_weights.argtypes = [
ctypes.c_char_p, # mil_text
ctypes.c_size_t, # mil_len
ctypes.POINTER(ctypes.c_char_p), # weight_names
ctypes.POINTER(ctypes.POINTER(ctypes.c_uint8)), # weight_datas
ctypes.POINTER(ctypes.c_size_t), # weight_lens
ctypes.c_int, # n_weights
ctypes.c_int, # n_inputs
ctypes.POINTER(ctypes.c_size_t), # input_sizes
ctypes.c_int, # n_outputs
ctypes.POINTER(ctypes.c_size_t), # output_sizes
]
# ane_bridge_eval(kernel) -> bool
lib.ane_bridge_eval.restype = ctypes.c_bool
lib.ane_bridge_eval.argtypes = [ctypes.c_void_p]
# ane_bridge_write_input(kernel, idx, data, bytes) -> void
lib.ane_bridge_write_input.restype = None
lib.ane_bridge_write_input.argtypes = [
ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t
]
# ane_bridge_read_output(kernel, idx, data, bytes) -> void
lib.ane_bridge_read_output.restype = None
lib.ane_bridge_read_output.argtypes = [
ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t
]
# ane_bridge_free(kernel) -> void
lib.ane_bridge_free.restype = None
lib.ane_bridge_free.argtypes = [ctypes.c_void_p]
# ane_bridge_get_compile_count() -> int
lib.ane_bridge_get_compile_count.restype = ctypes.c_int
lib.ane_bridge_get_compile_count.argtypes = []
# ane_bridge_reset_compile_count() -> void
lib.ane_bridge_reset_compile_count.restype = None
lib.ane_bridge_reset_compile_count.argtypes = []
# ane_bridge_build_weight_blob(src, rows, cols, out_len) -> uint8*
lib.ane_bridge_build_weight_blob.restype = ctypes.POINTER(ctypes.c_uint8)
lib.ane_bridge_build_weight_blob.argtypes = [
ctypes.POINTER(ctypes.c_float), ctypes.c_int, ctypes.c_int,
ctypes.POINTER(ctypes.c_size_t)
]
# ane_bridge_build_weight_blob_transposed
lib.ane_bridge_build_weight_blob_transposed.restype = ctypes.POINTER(ctypes.c_uint8)
lib.ane_bridge_build_weight_blob_transposed.argtypes = [
ctypes.POINTER(ctypes.c_float), ctypes.c_int, ctypes.c_int,
ctypes.POINTER(ctypes.c_size_t)
]
# ane_bridge_free_blob(ptr) -> void
lib.ane_bridge_free_blob.restype = None
lib.ane_bridge_free_blob.argtypes = [ctypes.c_void_p]
@property
def compile_count(self) -> int:
"""Current number of ANE compilations in this process."""
return self._lib.ane_bridge_get_compile_count()
@property
def compile_budget_remaining(self) -> int:
"""Remaining compilations before process restart needed."""
return MAX_COMPILE_BUDGET - self.compile_count
def needs_restart(self) -> bool:
"""True if compile budget is exhausted and process needs restart."""
return self.compile_count >= MAX_COMPILE_BUDGET
def reset_compile_count(self):
"""Reset compile counter (call after process restart)."""
self._lib.ane_bridge_reset_compile_count()
def build_weight_blob(self, weights: np.ndarray, transpose: bool = False) -> tuple:
"""Convert numpy float32 weights to ANE blob format (128-byte header + fp16).
Args:
weights: float32 numpy array of shape (rows, cols)
transpose: if True, store in transposed layout
Returns:
(blob_pointer, blob_length) — caller should free via free_blob()
"""
if weights.dtype != np.float32:
weights = weights.astype(np.float32)
weights = np.ascontiguousarray(weights)
rows, cols = weights.shape
out_len = ctypes.c_size_t()
src_ptr = weights.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
if transpose:
blob = self._lib.ane_bridge_build_weight_blob_transposed(
src_ptr, rows, cols, ctypes.byref(out_len))
else:
blob = self._lib.ane_bridge_build_weight_blob(
src_ptr, rows, cols, ctypes.byref(out_len))
if not blob:
raise ANEBridgeError("Failed to build weight blob")
return blob, out_len.value
def free_blob(self, blob_ptr):
"""Free a weight blob allocated by build_weight_blob."""
self._lib.ane_bridge_free_blob(blob_ptr)
def compile_kernel(
self,
mil_text: str,
input_sizes: list[int],
output_sizes: list[int],
weight_data: Optional[bytes] = None,
) -> int:
"""Compile a MIL program with optional single weight blob.
Args:
mil_text: UTF-8 MIL program text
input_sizes: list of byte sizes for each input IOSurface
output_sizes: list of byte sizes for each output IOSurface
weight_data: optional raw weight blob bytes
Returns:
Opaque kernel handle (int). Use with eval(), write_input(), etc.
"""
if self.needs_restart():
raise ANEBridgeError(
f"Compile budget exhausted ({self.compile_count} compiles). "
"Process restart required."
)
mil_bytes = mil_text.encode('utf-8')
n_inputs = len(input_sizes)
n_outputs = len(output_sizes)
c_input_sizes = (ctypes.c_size_t * n_inputs)(*input_sizes)
c_output_sizes = (ctypes.c_size_t * n_outputs)(*output_sizes)
if weight_data:
c_weight = (ctypes.c_uint8 * len(weight_data)).from_buffer_copy(weight_data)
handle = self._lib.ane_bridge_compile(
mil_bytes, len(mil_bytes),
c_weight, len(weight_data),
n_inputs, c_input_sizes,
n_outputs, c_output_sizes)
else:
handle = self._lib.ane_bridge_compile(
mil_bytes, len(mil_bytes),
None, 0,
n_inputs, c_input_sizes,
n_outputs, c_output_sizes)
if not handle:
raise ANEBridgeError("ANE kernel compilation failed")
return handle
def compile_kernel_multi_weights(
self,
mil_text: str,
weights: dict[str, tuple],
input_sizes: list[int],
output_sizes: list[int],
) -> int:
"""Compile a MIL program with multiple named weight blobs.
Args:
mil_text: UTF-8 MIL program text
weights: dict of {name: (blob_ptr, blob_len)} from build_weight_blob()
input_sizes: list of byte sizes for each input IOSurface
output_sizes: list of byte sizes for each output IOSurface
Returns:
Opaque kernel handle
"""
if self.needs_restart():
raise ANEBridgeError(
f"Compile budget exhausted ({self.compile_count} compiles). "
"Process restart required."
)
mil_bytes = mil_text.encode('utf-8')
n_inputs = len(input_sizes)
n_outputs = len(output_sizes)
n_weights = len(weights)
# Build weight arrays
c_names = (ctypes.c_char_p * n_weights)()
c_datas = (ctypes.POINTER(ctypes.c_uint8) * n_weights)()
c_lens = (ctypes.c_size_t * n_weights)()
for i, (name, (blob_ptr, blob_len)) in enumerate(weights.items()):
c_names[i] = name.encode('utf-8')
c_datas[i] = ctypes.cast(blob_ptr, ctypes.POINTER(ctypes.c_uint8))
c_lens[i] = blob_len
c_input_sizes = (ctypes.c_size_t * n_inputs)(*input_sizes)
c_output_sizes = (ctypes.c_size_t * n_outputs)(*output_sizes)
handle = self._lib.ane_bridge_compile_multi_weights(
mil_bytes, len(mil_bytes),
c_names, c_datas, c_lens, n_weights,
n_inputs, c_input_sizes,
n_outputs, c_output_sizes)
if not handle:
raise ANEBridgeError("ANE kernel compilation with multi-weights failed")
return handle
def eval(self, kernel_handle: int) -> bool:
"""Execute a compiled kernel on ANE hardware.
Args:
kernel_handle: handle from compile_kernel()
Returns:
True on success
"""
result = self._lib.ane_bridge_eval(kernel_handle)
if not result:
raise ANEBridgeError("ANE kernel evaluation failed")
return True
def write_input(self, kernel_handle: int, index: int, data: np.ndarray):
"""Write numpy array to kernel input IOSurface.
Args:
kernel_handle: handle from compile_kernel()
index: input tensor index (0-based)
data: numpy array (will be made contiguous if needed)
"""
data = np.ascontiguousarray(data)
self._lib.ane_bridge_write_input(
kernel_handle, index,
data.ctypes.data, data.nbytes)
def read_output(
self,
kernel_handle: int,
index: int,
shape: tuple,
dtype=np.float16,
) -> np.ndarray:
"""Read kernel output IOSurface into numpy array.
Args:
kernel_handle: handle from compile_kernel()
index: output tensor index (0-based)
shape: shape of the output tensor
dtype: numpy dtype (default float16, matching ANE native format)
Returns:
numpy array with output data
"""
out = np.empty(shape, dtype=dtype)
self._lib.ane_bridge_read_output(
kernel_handle, index,
out.ctypes.data, out.nbytes)
return out
def free_kernel(self, kernel_handle: int):
"""Free a compiled kernel and all associated resources."""
if kernel_handle:
self._lib.ane_bridge_free(kernel_handle)
def self_test():
"""Quick self-test to verify ANE bridge works on this machine."""
print("ANE Bridge Self-Test")
print("=" * 40)
try:
ane = ANEBridge()
print(f"[OK] ANE runtime initialized")
print(f" Compile count: {ane.compile_count}")
print(f" Budget remaining: {ane.compile_budget_remaining}")
except ANEBridgeError as e:
print(f"[FAIL] {e}")
return False
# --- Test 1: conv with weights (matches proven sram_probe.m pattern) ---
# Uses fp32 input → cast to fp16 → conv → cast to fp32 output
# ANE has minimum tensor size requirements — use ch=64, sp=16
ch, sp = 64, 16
mil_text = (
'program(1.3)\n'
'[buildInfo = dict<string, string>({{"coremlc-component-MIL", "3510.2.1"}, '
'{"coremlc-version", "3505.4.1"}, '
'{"coremltools-component-milinternal", ""}, '
'{"coremltools-version", "9.0"}})]\n'
'{\n'
f' func main<ios18>(tensor<fp32, [1, {ch}, 1, {sp}]> x) {{\n'
' string c_pad_type = const()[name = string("c_pad_type"), val = string("valid")];\n'
' tensor<int32, [2]> c_strides = const()[name = string("c_strides"), val = tensor<int32, [2]>([1, 1])];\n'
' tensor<int32, [4]> c_pad = const()[name = string("c_pad"), val = tensor<int32, [4]>([0, 0, 0, 0])];\n'
' tensor<int32, [2]> c_dilations = const()[name = string("c_dilations"), val = tensor<int32, [2]>([1, 1])];\n'
' int32 c_groups = const()[name = string("c_groups"), val = int32(1)];\n'
' string to_fp16 = const()[name = string("to_fp16"), val = string("fp16")];\n'
f' tensor<fp16, [1, {ch}, 1, {sp}]> x16 = cast(dtype = to_fp16, x = x)[name = string("cast_in")];\n'
f' tensor<fp16, [{ch}, {ch}, 1, 1]> W = const()[name = string("W"), val = tensor<fp16, [{ch}, {ch}, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64)))];\n'
f' tensor<fp16, [1, {ch}, 1, {sp}]> y16 = conv(dilations = c_dilations, groups = c_groups, pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = W, x = x16)[name = string("conv")];\n'
' string to_fp32 = const()[name = string("to_fp32"), val = string("fp32")];\n'
f' tensor<fp32, [1, {ch}, 1, {sp}]> y = cast(dtype = to_fp32, x = y16)[name = string("cast_out")];\n'
' } -> (y);\n'
'}\n'
)
# Build identity-like weight: eye(ch) so conv is identity transform
W = np.eye(ch, dtype=np.float32)
blob_ptr, blob_len = ane.build_weight_blob(W)
tensor_bytes_in = ch * sp * 4 # fp32 input
tensor_bytes_out = ch * sp * 4 # fp32 output
try:
# Get raw weight bytes from blob pointer
blob_bytes = bytes(ctypes.cast(blob_ptr, ctypes.POINTER(ctypes.c_uint8 * blob_len)).contents)
kernel = ane.compile_kernel(
mil_text,
input_sizes=[tensor_bytes_in],
output_sizes=[tensor_bytes_out],
weight_data=blob_bytes,
)
print(f"[OK] MIL compilation succeeded (handle: 0x{kernel:x})")
print(f" Compile count: {ane.compile_count}")
except ANEBridgeError as e:
print(f"[FAIL] Compilation: {e}")
ane.free_blob(blob_ptr)
return False
finally:
ane.free_blob(blob_ptr)
# Test: evaluate — identity conv should return input
x = np.random.randn(1, ch, 1, sp).astype(np.float32)
try:
ane.write_input(kernel, 0, x)
ane.eval(kernel)
result = ane.read_output(kernel, 0, (1, ch, 1, sp), dtype=np.float32)
# With identity weight matrix, output should ≈ input (fp16 rounding)
if np.allclose(result, x, atol=0.05):
print(f"[OK] ANE evaluation correct (identity conv)")
print(f" Input[:4]: {x.flatten()[:4]}")
print(f" Output[:4]: {result.flatten()[:4]}")
else:
max_err = np.max(np.abs(result - x))
print(f"[WARN] Result differs (max err: {max_err:.4f})")
print(f" Input[:4]: {x.flatten()[:4]}")
print(f" Output[:4]: {result.flatten()[:4]}")
# Don't fail — fp16 rounding can be significant
except ANEBridgeError as e:
print(f"[FAIL] Evaluation: {e}")
ane.free_kernel(kernel)
return False
# Test: weight blob
try:
weights = np.random.randn(4, 4).astype(np.float32)
blob, blob_len = ane.build_weight_blob(weights)
print(f"[OK] Weight blob built ({blob_len} bytes for 4x4 float32)")
ane.free_blob(blob)
except ANEBridgeError as e:
print(f"[FAIL] Weight blob: {e}")
ane.free_kernel(kernel)
return False
ane.free_kernel(kernel)
print(f"\n[PASS] All ANE bridge tests passed")
print(f" Final compile count: {ane.compile_count}")
return True
if __name__ == "__main__":
success = self_test()
exit(0 if success else 1)