"""Build a Core ML ML Program that evaluates two chained SHA-256 compressions.

Every 32-bit word is stored as 32 fp16 values (each 0.0 or 1.0) along the
channel axis, least-significant bit in channel 0 (see ``bit_const``).  Boolean
logic is expressed with element-wise arithmetic, and rotations/shifts with 1x1
convolutions whose weights are 32x32 selection matrices.

Stage 1 compresses the 16-word block ``w_init`` starting from the chaining
value ``midstate``.  Stage 2 compresses the 8-word result of stage 1, padded
with the word 0x80000000, six zero words and the word 256, starting from
``IV_vals``.  The converted model is saved as ``sha256d.mlpackage``.
"""

import numpy as np
import coremltools as ct
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.mil import types

_WORD_BITS = 32


def bit_const(v):
    """Return a (1, 32, 1, 1) fp16 MIL constant holding the bits of ``v``.

    Channel ``i`` carries bit ``i`` of ``v`` (LSB first) as 0.0 or 1.0.
    """
    bits = [(v >> i) & 1 for i in range(_WORD_BITS)]
    plane = np.array(bits, dtype=np.float16).reshape(1, _WORD_BITS, 1, 1)
    return mb.const(val=plane)


# ---------------------------------------------------------------------------
# Bitwise logic on {0, 1}-valued tensors
# ---------------------------------------------------------------------------

def band(a, b):
    """Bitwise AND: the product of two 0/1 values."""
    return mb.mul(x=a, y=b)


def bxor(a, b):
    """Bitwise XOR: |a - b| for 0/1 values."""
    return mb.abs(x=mb.sub(x=a, y=b))


def bor(a, b):
    """Bitwise OR: the maximum of two 0/1 values."""
    return mb.maximum(x=a, y=b)


def xor3(a, b, c):
    """Three-input XOR, computed as ||a - b| - c|."""
    return mb.abs(x=mb.sub(x=mb.abs(x=mb.sub(x=a, y=b)), y=c))


def maj(a, b, c):
    """Bitwise majority: 1 where at least two of the three inputs are 1."""
    ab = mb.minimum(x=a, y=b)
    ac = mb.minimum(x=a, y=c)
    bc = mb.minimum(x=b, y=c)
    return mb.maximum(x=mb.maximum(x=ab, y=ac), y=bc)


def ch(e, f, g):
    """Bitwise choose: ``f`` where ``e`` is 1, otherwise ``g``."""
    return bxor(g, band(e, bxor(f, g)))


# ---------------------------------------------------------------------------
# Rotations and shifts as 1x1 convolutions over the bit (channel) axis
# ---------------------------------------------------------------------------

# Weight constants are cached per shift amount so that each distinct matrix is
# emitted only once.  The cached values are MIL variables, so they are created
# lazily while a program is being traced.
_W_ROTR = {}
_W_SHL = {}
_W_SHR = {}


def _w_rotr(k):
    """Weight matrix mapping output bit ``o`` to input bit ``(o + k) % 32``."""
    W = np.zeros((_WORD_BITS, _WORD_BITS, 1, 1), dtype=np.float16)
    for o in range(_WORD_BITS):
        W[o, (o + k) % _WORD_BITS, 0, 0] = 1.0
    return mb.const(val=W)


def _w_shl(k):
    """Weight matrix mapping output bit ``o`` to input bit ``o - k``.

    Output bits whose source index would be negative stay zero.
    """
    W = np.zeros((_WORD_BITS, _WORD_BITS, 1, 1), dtype=np.float16)
    for o in range(_WORD_BITS):
        i = o - k
        if i >= 0:
            W[o, i, 0, 0] = 1.0
    return mb.const(val=W)


def _w_shr(k):
    """Weight matrix mapping output bit ``o`` to input bit ``o + k``.

    Output bits whose source index would be 32 or more stay zero.
    """
    W = np.zeros((_WORD_BITS, _WORD_BITS, 1, 1), dtype=np.float16)
    for o in range(_WORD_BITS):
        i = o + k
        if i < _WORD_BITS:
            W[o, i, 0, 0] = 1.0
    return mb.const(val=W)


def _apply_bit_matrix(x, weight):
    """Apply a 32x32 bit-selection matrix to ``x`` along the channel axis."""
    return mb.conv(x=x, weight=weight, pad_type="valid", groups=1)


def _clamp_shift(k):
    """Clamp a shift count to ``[0, 32]``.

    Negative counts are treated as 0 (no shift).  Counts of 32 or more are
    mapped to 32, for which the shift matrices are all-zero, so the whole word
    is shifted out.
    """
    return max(0, min(k, _WORD_BITS))


def rotr(x, k):
    """Rotate each 32-bit word in ``x`` right by ``k`` bits (``k`` taken mod 32)."""
    k %= _WORD_BITS
    if k == 0:
        return x
    if k not in _W_ROTR:
        _W_ROTR[k] = _w_rotr(k)
    return _apply_bit_matrix(x, _W_ROTR[k])


def shl(x, k):
    """Logical left shift of each 32-bit word in ``x`` by ``k`` bits.

    A count of 32 or more yields an all-zero word; a count of 0 or less
    returns ``x`` unchanged.
    """
    k = _clamp_shift(k)
    if k == 0:
        return x
    if k not in _W_SHL:
        _W_SHL[k] = _w_shl(k)
    return _apply_bit_matrix(x, _W_SHL[k])


def shr(x, k):
    """Logical right shift of each 32-bit word in ``x`` by ``k`` bits.

    A count of 32 or more yields an all-zero word; a count of 0 or less
    returns ``x`` unchanged.
    """
    k = _clamp_shift(k)
    if k == 0:
        return x
    if k not in _W_SHR:
        _W_SHR[k] = _w_shr(k)
    return _apply_bit_matrix(x, _W_SHR[k])


# ---------------------------------------------------------------------------
# Sigma functions
# ---------------------------------------------------------------------------

def Sigma0(x):
    """ROTR2 ^ ROTR13 ^ ROTR22."""
    return xor3(rotr(x, 2), rotr(x, 13), rotr(x, 22))


def Sigma1(x):
    """ROTR6 ^ ROTR11 ^ ROTR25."""
    return xor3(rotr(x, 6), rotr(x, 11), rotr(x, 25))


def sigma0(x):
    """ROTR7 ^ ROTR18 ^ SHR3."""
    return xor3(rotr(x, 7), rotr(x, 18), shr(x, 3))


def sigma1(x):
    """ROTR17 ^ ROTR19 ^ SHR10."""
    return xor3(rotr(x, 17), rotr(x, 19), shr(x, 10))


# ---------------------------------------------------------------------------
# Addition modulo 2**32
# ---------------------------------------------------------------------------

def csa(a, b, c):
    """Carry-save adder: return ``(sum_bits, carry_bits)`` for three operands.

    ``a + b + c == sum_bits + 2 * carry_bits``; the caller shifts the carry
    word left by one before adding it back in.
    """
    return xor3(a, b, c), maj(a, b, c)


def cpa(a, b):
    """Carry-propagate adder: ``(a + b) mod 2**32`` via a parallel-prefix scan.

    ``g`` and ``p`` hold the generate/propagate signals of a growing span of
    bits; after spans of 1, 2, 4, 8 and 16 have been merged, ``g`` holds the
    carry out of every bit position.
    """
    p0 = bxor(a, b)
    p = p0
    g = band(a, b)
    for d in [1, 2, 4, 8, 16]:
        g = bor(g, band(p, shl(g, d)))
        p = band(p, shl(p, d))
    # The carry into bit i is the carry out of bit i - 1.
    return bxor(p0, shl(g, 1))


def add2(a, b):
    """Return ``(a + b) mod 2**32``."""
    return cpa(a, b)


def add3(a, b, c):
    """Return ``(a + b + c) mod 2**32``."""
    s1, c1 = csa(a, b, c)
    return cpa(s1, shl(c1, 1))


def add4(a, b, c, d):
    """Return ``(a + b + c + d) mod 2**32``.

    Two carry-save stages reduce the four operands to two
    (``a + b + c + d == s1 + 2*c1 + d == s2 + 2*c2``), which are then summed
    by a single carry-propagate adder.
    """
    s1, c1 = csa(a, b, c)
    s2, c2 = csa(s1, d, shl(c1, 1))
    return cpa(s2, shl(c2, 1))


def add5(a, b, c, d, e):
    """Return ``(a + b + c + d + e) mod 2**32`` using three carry-save stages."""
    s1, c1 = csa(a, b, c)
    s2, c2 = csa(d, e, s1)
    s3, c3 = csa(s2, shl(c1, 1), shl(c2, 1))
    return cpa(s3, shl(c3, 1))


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Per-round constants, one for each of the 64 rounds.
K_vals = [
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
    0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
    0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
    0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
    0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
    0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
    0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
    0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
    0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
]

# Initial chaining value used for the second compression.
IV_vals = [
    0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
    0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
]

# Flexible batch dimension shared by both inputs.
flexN = ct.RangeDim(1, 1024, default=1)
N = flexN.symbol


# ---------------------------------------------------------------------------
# Compression function
# ---------------------------------------------------------------------------

def _expand_schedule(block):
    """Extend a 16-word message block to the 64-word message schedule."""
    w = list(block)
    for t in range(16, 64):
        w.append(
            add4(sigma1(w[t - 2]), w[t - 7], sigma0(w[t - 15]), w[t - 16])
        )
    return w


def _compress(state, schedule, k_bits):
    """Run the 64 rounds and return the updated 8-word chaining value.

    ``state`` is the incoming chaining value (8 words), ``schedule`` the 64
    expanded message words and ``k_bits`` the 64 round constants.  The result
    is ``state`` plus the final working variables, word by word.
    """
    a, b, c, d, e, f, g, h = state
    for w_t, k_t in zip(schedule, k_bits):
        t1 = add5(h, Sigma1(e), ch(e, f, g), w_t, k_t)
        t2 = add2(Sigma0(a), maj(a, b, c))
        a, b, c, d, e, f, g, h = add2(t1, t2), a, b, c, add2(d, t1), e, f, g
    working = (a, b, c, d, e, f, g, h)
    return [add2(s, v) for s, v in zip(state, working)]


@mb.program(
    input_specs=[
        mb.TensorSpec(shape=(N, 32, 1, 8), dtype=types.fp16),
        mb.TensorSpec(shape=(N, 32, 1, 16), dtype=types.fp16),
    ],
    opset_version=ct.target.iOS18,
)
def prog(midstate, w_init):
    """Trace the two chained compressions.

    Args:
        midstate: (N, 32, 1, 8) fp16 tensor; axis 1 indexes the bits and
            axis 3 the eight words of the starting chaining value.
        w_init: (N, 32, 1, 16) fp16 tensor; the sixteen message words of the
            block compressed in stage 1, in the same bit layout.

    Returns:
        The eight result words of stage 2, concatenated along axis 3.
    """
    k_bits = [bit_const(k) for k in K_vals]
    iv_bits = [bit_const(v) for v in IV_vals]

    # Stage 1: compress the supplied block on top of the supplied midstate.
    state1 = mb.split(x=midstate, axis=3, num_splits=8)
    block1 = mb.split(x=w_init, axis=3, num_splits=16)
    digest1 = _compress(state1, _expand_schedule(block1), k_bits)

    # Stage 2: the 8 digest words followed by the padding word 0x80000000,
    # six zero words and the length word 256 form one 16-word block.
    zero = mb.const(val=np.zeros((1, 32, 1, 1), dtype=np.float16))
    block2 = list(digest1)
    block2.append(bit_const(0x80000000))
    block2.extend([zero] * 6)
    block2.append(bit_const(256))
    digest2 = _compress(iv_bits, _expand_schedule(block2), k_bits)

    return mb.concat(values=digest2, axis=3)


mlmodel = ct.convert(
    prog,
    convert_to="mlprogram",
    compute_units=ct.ComputeUnit.CPU_AND_NE,
    minimum_deployment_target=ct.target.iOS18,
    compute_precision=ct.precision.FLOAT16,
    debug=True,
    inputs=[
        ct.TensorType(name="midstate", shape=(flexN, 32, 1, 8), dtype=np.float16),
        ct.TensorType(name="w_init", shape=(flexN, 32, 1, 16), dtype=np.float16),
    ],
)
mlmodel.save("sha256d.mlpackage")