File size: 17,863 Bytes
7397b43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
"""
ane_bridge_py.py — Python ctypes wrapper for libane_bridge.dylib

Provides a Pythonic interface to Apple Neural Engine private APIs
via the maderix/ANE C bridge library. Enables compiling and executing
MIL programs on ANE hardware from Python.

Usage:
    from ane_bridge_py import ANEBridge
    ane = ANEBridge()
    kernel = ane.compile_kernel(mil_text, weights, input_sizes, output_sizes)
    ane.write_input(kernel, 0, my_numpy_array)
    ane.eval(kernel)
    result = ane.read_output(kernel, 0, output_shape, dtype=np.float16)
    ane.free_kernel(kernel)
"""

import ctypes
import ctypes.util
import os
import numpy as np
from pathlib import Path
from typing import Optional

# Resolve library path relative to this file
_BRIDGE_DIR = Path(__file__).parent / "bridge"
_LIB_PATH = str(_BRIDGE_DIR / "libane_bridge.dylib")

# Max compiles before needing process restart (ANE limitation)
MAX_COMPILE_BUDGET = 110  # Leave margin from the ~119 hard limit


class ANEBridgeError(Exception):
    """Error from ANE bridge operations."""
    pass


class ANEBridge:
    """Python wrapper for the ANE C bridge library."""

    def __init__(self, lib_path: Optional[str] = None):
        lib_path = lib_path or _LIB_PATH
        if not os.path.exists(lib_path):
            raise ANEBridgeError(
                f"ANE bridge library not found at {lib_path}. "
                f"Run: cd scripts/ane-engine/bridge && make"
            )

        self._lib = ctypes.CDLL(lib_path)
        self._setup_signatures()

        rc = self._lib.ane_bridge_init()
        if rc != 0:
            raise ANEBridgeError(
                "Failed to initialize ANE runtime. "
                "Requires macOS 15+ on Apple Silicon."
            )

    def _setup_signatures(self):
        """Define C function signatures for type safety."""
        lib = self._lib

        # ane_bridge_init() -> int
        lib.ane_bridge_init.restype = ctypes.c_int
        lib.ane_bridge_init.argtypes = []

        # ane_bridge_compile(...) -> void*
        lib.ane_bridge_compile.restype = ctypes.c_void_p
        lib.ane_bridge_compile.argtypes = [
            ctypes.c_char_p,                   # mil_text
            ctypes.c_size_t,                   # mil_len
            ctypes.POINTER(ctypes.c_uint8),    # weight_data
            ctypes.c_size_t,                   # weight_len
            ctypes.c_int,                      # n_inputs
            ctypes.POINTER(ctypes.c_size_t),   # input_sizes
            ctypes.c_int,                      # n_outputs
            ctypes.POINTER(ctypes.c_size_t),   # output_sizes
        ]

        # ane_bridge_compile_multi_weights(...) -> void*
        lib.ane_bridge_compile_multi_weights.restype = ctypes.c_void_p
        lib.ane_bridge_compile_multi_weights.argtypes = [
            ctypes.c_char_p,                             # mil_text
            ctypes.c_size_t,                             # mil_len
            ctypes.POINTER(ctypes.c_char_p),             # weight_names
            ctypes.POINTER(ctypes.POINTER(ctypes.c_uint8)),  # weight_datas
            ctypes.POINTER(ctypes.c_size_t),             # weight_lens
            ctypes.c_int,                                # n_weights
            ctypes.c_int,                                # n_inputs
            ctypes.POINTER(ctypes.c_size_t),             # input_sizes
            ctypes.c_int,                                # n_outputs
            ctypes.POINTER(ctypes.c_size_t),             # output_sizes
        ]

        # ane_bridge_eval(kernel) -> bool
        lib.ane_bridge_eval.restype = ctypes.c_bool
        lib.ane_bridge_eval.argtypes = [ctypes.c_void_p]

        # ane_bridge_write_input(kernel, idx, data, bytes) -> void
        lib.ane_bridge_write_input.restype = None
        lib.ane_bridge_write_input.argtypes = [
            ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t
        ]

        # ane_bridge_read_output(kernel, idx, data, bytes) -> void
        lib.ane_bridge_read_output.restype = None
        lib.ane_bridge_read_output.argtypes = [
            ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t
        ]

        # ane_bridge_free(kernel) -> void
        lib.ane_bridge_free.restype = None
        lib.ane_bridge_free.argtypes = [ctypes.c_void_p]

        # ane_bridge_get_compile_count() -> int
        lib.ane_bridge_get_compile_count.restype = ctypes.c_int
        lib.ane_bridge_get_compile_count.argtypes = []

        # ane_bridge_reset_compile_count() -> void
        lib.ane_bridge_reset_compile_count.restype = None
        lib.ane_bridge_reset_compile_count.argtypes = []

        # ane_bridge_build_weight_blob(src, rows, cols, out_len) -> uint8*
        lib.ane_bridge_build_weight_blob.restype = ctypes.POINTER(ctypes.c_uint8)
        lib.ane_bridge_build_weight_blob.argtypes = [
            ctypes.POINTER(ctypes.c_float), ctypes.c_int, ctypes.c_int,
            ctypes.POINTER(ctypes.c_size_t)
        ]

        # ane_bridge_build_weight_blob_transposed
        lib.ane_bridge_build_weight_blob_transposed.restype = ctypes.POINTER(ctypes.c_uint8)
        lib.ane_bridge_build_weight_blob_transposed.argtypes = [
            ctypes.POINTER(ctypes.c_float), ctypes.c_int, ctypes.c_int,
            ctypes.POINTER(ctypes.c_size_t)
        ]

        # ane_bridge_free_blob(ptr) -> void
        lib.ane_bridge_free_blob.restype = None
        lib.ane_bridge_free_blob.argtypes = [ctypes.c_void_p]

    @property
    def compile_count(self) -> int:
        """Current number of ANE compilations in this process."""
        return self._lib.ane_bridge_get_compile_count()

    @property
    def compile_budget_remaining(self) -> int:
        """Remaining compilations before process restart needed."""
        return MAX_COMPILE_BUDGET - self.compile_count

    def needs_restart(self) -> bool:
        """True if compile budget is exhausted and process needs restart."""
        return self.compile_count >= MAX_COMPILE_BUDGET

    def reset_compile_count(self):
        """Reset compile counter (call after process restart)."""
        self._lib.ane_bridge_reset_compile_count()

    def build_weight_blob(self, weights: np.ndarray, transpose: bool = False) -> tuple:
        """Convert numpy float32 weights to ANE blob format (128-byte header + fp16).

        Args:
            weights: float32 numpy array of shape (rows, cols)
            transpose: if True, store in transposed layout

        Returns:
            (blob_pointer, blob_length) — caller should free via free_blob()
        """
        if weights.dtype != np.float32:
            weights = weights.astype(np.float32)
        weights = np.ascontiguousarray(weights)

        rows, cols = weights.shape
        out_len = ctypes.c_size_t()
        src_ptr = weights.ctypes.data_as(ctypes.POINTER(ctypes.c_float))

        if transpose:
            blob = self._lib.ane_bridge_build_weight_blob_transposed(
                src_ptr, rows, cols, ctypes.byref(out_len))
        else:
            blob = self._lib.ane_bridge_build_weight_blob(
                src_ptr, rows, cols, ctypes.byref(out_len))

        if not blob:
            raise ANEBridgeError("Failed to build weight blob")

        return blob, out_len.value

    def free_blob(self, blob_ptr):
        """Free a weight blob allocated by build_weight_blob."""
        self._lib.ane_bridge_free_blob(blob_ptr)

    def compile_kernel(
        self,
        mil_text: str,
        input_sizes: list[int],
        output_sizes: list[int],
        weight_data: Optional[bytes] = None,
    ) -> int:
        """Compile a MIL program with optional single weight blob.

        Args:
            mil_text: UTF-8 MIL program text
            input_sizes: list of byte sizes for each input IOSurface
            output_sizes: list of byte sizes for each output IOSurface
            weight_data: optional raw weight blob bytes

        Returns:
            Opaque kernel handle (int). Use with eval(), write_input(), etc.
        """
        if self.needs_restart():
            raise ANEBridgeError(
                f"Compile budget exhausted ({self.compile_count} compiles). "
                "Process restart required."
            )

        mil_bytes = mil_text.encode('utf-8')
        n_inputs = len(input_sizes)
        n_outputs = len(output_sizes)

        c_input_sizes = (ctypes.c_size_t * n_inputs)(*input_sizes)
        c_output_sizes = (ctypes.c_size_t * n_outputs)(*output_sizes)

        if weight_data:
            c_weight = (ctypes.c_uint8 * len(weight_data)).from_buffer_copy(weight_data)
            handle = self._lib.ane_bridge_compile(
                mil_bytes, len(mil_bytes),
                c_weight, len(weight_data),
                n_inputs, c_input_sizes,
                n_outputs, c_output_sizes)
        else:
            handle = self._lib.ane_bridge_compile(
                mil_bytes, len(mil_bytes),
                None, 0,
                n_inputs, c_input_sizes,
                n_outputs, c_output_sizes)

        if not handle:
            raise ANEBridgeError("ANE kernel compilation failed")

        return handle

    def compile_kernel_multi_weights(
        self,
        mil_text: str,
        weights: dict[str, tuple],
        input_sizes: list[int],
        output_sizes: list[int],
    ) -> int:
        """Compile a MIL program with multiple named weight blobs.

        Args:
            mil_text: UTF-8 MIL program text
            weights: dict of {name: (blob_ptr, blob_len)} from build_weight_blob()
            input_sizes: list of byte sizes for each input IOSurface
            output_sizes: list of byte sizes for each output IOSurface

        Returns:
            Opaque kernel handle
        """
        if self.needs_restart():
            raise ANEBridgeError(
                f"Compile budget exhausted ({self.compile_count} compiles). "
                "Process restart required."
            )

        mil_bytes = mil_text.encode('utf-8')
        n_inputs = len(input_sizes)
        n_outputs = len(output_sizes)
        n_weights = len(weights)

        # Build weight arrays
        c_names = (ctypes.c_char_p * n_weights)()
        c_datas = (ctypes.POINTER(ctypes.c_uint8) * n_weights)()
        c_lens = (ctypes.c_size_t * n_weights)()

        for i, (name, (blob_ptr, blob_len)) in enumerate(weights.items()):
            c_names[i] = name.encode('utf-8')
            c_datas[i] = ctypes.cast(blob_ptr, ctypes.POINTER(ctypes.c_uint8))
            c_lens[i] = blob_len

        c_input_sizes = (ctypes.c_size_t * n_inputs)(*input_sizes)
        c_output_sizes = (ctypes.c_size_t * n_outputs)(*output_sizes)

        handle = self._lib.ane_bridge_compile_multi_weights(
            mil_bytes, len(mil_bytes),
            c_names, c_datas, c_lens, n_weights,
            n_inputs, c_input_sizes,
            n_outputs, c_output_sizes)

        if not handle:
            raise ANEBridgeError("ANE kernel compilation with multi-weights failed")

        return handle

    def eval(self, kernel_handle: int) -> bool:
        """Execute a compiled kernel on ANE hardware.

        Args:
            kernel_handle: handle from compile_kernel()

        Returns:
            True on success
        """
        result = self._lib.ane_bridge_eval(kernel_handle)
        if not result:
            raise ANEBridgeError("ANE kernel evaluation failed")
        return True

    def write_input(self, kernel_handle: int, index: int, data: np.ndarray):
        """Write numpy array to kernel input IOSurface.

        Args:
            kernel_handle: handle from compile_kernel()
            index: input tensor index (0-based)
            data: numpy array (will be made contiguous if needed)
        """
        data = np.ascontiguousarray(data)
        self._lib.ane_bridge_write_input(
            kernel_handle, index,
            data.ctypes.data, data.nbytes)

    def read_output(
        self,
        kernel_handle: int,
        index: int,
        shape: tuple,
        dtype=np.float16,
    ) -> np.ndarray:
        """Read kernel output IOSurface into numpy array.

        Args:
            kernel_handle: handle from compile_kernel()
            index: output tensor index (0-based)
            shape: shape of the output tensor
            dtype: numpy dtype (default float16, matching ANE native format)

        Returns:
            numpy array with output data
        """
        out = np.empty(shape, dtype=dtype)
        self._lib.ane_bridge_read_output(
            kernel_handle, index,
            out.ctypes.data, out.nbytes)
        return out

    def free_kernel(self, kernel_handle: int):
        """Free a compiled kernel and all associated resources."""
        if kernel_handle:
            self._lib.ane_bridge_free(kernel_handle)


def self_test():
    """Quick self-test to verify ANE bridge works on this machine."""
    print("ANE Bridge Self-Test")
    print("=" * 40)

    try:
        ane = ANEBridge()
        print(f"[OK] ANE runtime initialized")
        print(f"     Compile count: {ane.compile_count}")
        print(f"     Budget remaining: {ane.compile_budget_remaining}")
    except ANEBridgeError as e:
        print(f"[FAIL] {e}")
        return False

    # --- Test 1: conv with weights (matches proven sram_probe.m pattern) ---
    # Uses fp32 input → cast to fp16 → conv → cast to fp32 output
    # ANE has minimum tensor size requirements — use ch=64, sp=16
    ch, sp = 64, 16
    mil_text = (
        'program(1.3)\n'
        '[buildInfo = dict<string, string>({{"coremlc-component-MIL", "3510.2.1"}, '
        '{"coremlc-version", "3505.4.1"}, '
        '{"coremltools-component-milinternal", ""}, '
        '{"coremltools-version", "9.0"}})]\n'
        '{\n'
        f'    func main<ios18>(tensor<fp32, [1, {ch}, 1, {sp}]> x) {{\n'
        '        string c_pad_type = const()[name = string("c_pad_type"), val = string("valid")];\n'
        '        tensor<int32, [2]> c_strides = const()[name = string("c_strides"), val = tensor<int32, [2]>([1, 1])];\n'
        '        tensor<int32, [4]> c_pad = const()[name = string("c_pad"), val = tensor<int32, [4]>([0, 0, 0, 0])];\n'
        '        tensor<int32, [2]> c_dilations = const()[name = string("c_dilations"), val = tensor<int32, [2]>([1, 1])];\n'
        '        int32 c_groups = const()[name = string("c_groups"), val = int32(1)];\n'
        '        string to_fp16 = const()[name = string("to_fp16"), val = string("fp16")];\n'
        f'        tensor<fp16, [1, {ch}, 1, {sp}]> x16 = cast(dtype = to_fp16, x = x)[name = string("cast_in")];\n'
        f'        tensor<fp16, [{ch}, {ch}, 1, 1]> W = const()[name = string("W"), val = tensor<fp16, [{ch}, {ch}, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64)))];\n'
        f'        tensor<fp16, [1, {ch}, 1, {sp}]> y16 = conv(dilations = c_dilations, groups = c_groups, pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = W, x = x16)[name = string("conv")];\n'
        '        string to_fp32 = const()[name = string("to_fp32"), val = string("fp32")];\n'
        f'        tensor<fp32, [1, {ch}, 1, {sp}]> y = cast(dtype = to_fp32, x = y16)[name = string("cast_out")];\n'
        '    } -> (y);\n'
        '}\n'
    )

    # Build identity-like weight: eye(ch) so conv is identity transform
    W = np.eye(ch, dtype=np.float32)
    blob_ptr, blob_len = ane.build_weight_blob(W)

    tensor_bytes_in = ch * sp * 4   # fp32 input
    tensor_bytes_out = ch * sp * 4  # fp32 output

    try:
        # Get raw weight bytes from blob pointer
        blob_bytes = bytes(ctypes.cast(blob_ptr, ctypes.POINTER(ctypes.c_uint8 * blob_len)).contents)
        kernel = ane.compile_kernel(
            mil_text,
            input_sizes=[tensor_bytes_in],
            output_sizes=[tensor_bytes_out],
            weight_data=blob_bytes,
        )
        print(f"[OK] MIL compilation succeeded (handle: 0x{kernel:x})")
        print(f"     Compile count: {ane.compile_count}")
    except ANEBridgeError as e:
        print(f"[FAIL] Compilation: {e}")
        ane.free_blob(blob_ptr)
        return False
    finally:
        ane.free_blob(blob_ptr)

    # Test: evaluate — identity conv should return input
    x = np.random.randn(1, ch, 1, sp).astype(np.float32)

    try:
        ane.write_input(kernel, 0, x)
        ane.eval(kernel)
        result = ane.read_output(kernel, 0, (1, ch, 1, sp), dtype=np.float32)

        # With identity weight matrix, output should ≈ input (fp16 rounding)
        if np.allclose(result, x, atol=0.05):
            print(f"[OK] ANE evaluation correct (identity conv)")
            print(f"     Input[:4]:  {x.flatten()[:4]}")
            print(f"     Output[:4]: {result.flatten()[:4]}")
        else:
            max_err = np.max(np.abs(result - x))
            print(f"[WARN] Result differs (max err: {max_err:.4f})")
            print(f"     Input[:4]:  {x.flatten()[:4]}")
            print(f"     Output[:4]: {result.flatten()[:4]}")
            # Don't fail — fp16 rounding can be significant
    except ANEBridgeError as e:
        print(f"[FAIL] Evaluation: {e}")
        ane.free_kernel(kernel)
        return False

    # Test: weight blob
    try:
        weights = np.random.randn(4, 4).astype(np.float32)
        blob, blob_len = ane.build_weight_blob(weights)
        print(f"[OK] Weight blob built ({blob_len} bytes for 4x4 float32)")
        ane.free_blob(blob)
    except ANEBridgeError as e:
        print(f"[FAIL] Weight blob: {e}")
        ane.free_kernel(kernel)
        return False

    ane.free_kernel(kernel)
    print(f"\n[PASS] All ANE bridge tests passed")
    print(f"       Final compile count: {ane.compile_count}")
    return True


if __name__ == "__main__":
    success = self_test()
    exit(0 if success else 1)