"""Verify and benchmark a bit-sliced double-SHA-256 (sha256d) Core ML model.

The Core ML model is expected to take two float16 bit-sliced inputs:

* ``midstate`` -- the SHA-256 chaining state after compressing every message
  block except the last, shape (N, 32, 1, 8).
* ``w_init``   -- the 16 message words of the final padded block,
  shape (N, 32, 1, 16).

Bit-slicing convention (LSB-first): element ``[n, i, 0, j]`` holds bit ``i``
of 32-bit word ``j`` of batch item ``n``, stored as 0.0 / 1.0 in float16.

The pure-Python SHA-256 reference implementation below is used to build
midstates and to cross-check model outputs against ``hashlib``.  It needs
only NumPy; coremltools is required solely by :func:`verify_model` and
:func:`benchmark`, so its import is optional.
"""

import hashlib
import time

import numpy as np

try:
    # coremltools is only needed for model loading/inference; keep the
    # reference/preparation helpers usable when it is not installed.
    import coremltools as ct
except ImportError:
    ct = None

# SHA-256 initial hash values H0..H7 (FIPS 180-4, section 5.3.3).
IV = [
    0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
    0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
]

# SHA-256 round constants K0..K63 (FIPS 180-4, section 4.2.2).
K = [
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
]


def _require_ct():
    """Return the coremltools module, raising a clear error if missing."""
    if ct is None:
        raise RuntimeError(
            "coremltools is required for model loading/inference; "
            "install it with `pip install coremltools`"
        )
    return ct


def _rotr(x: int, n: int) -> int:
    """Rotate the 32-bit value *x* right by *n* bits (0 < n < 32)."""
    return ((x >> n) | (x << (32 - n))) & 0xFFFFFFFF


def _Sigma0(x: int) -> int:
    """SHA-256 big Sigma0: ROTR2 ^ ROTR13 ^ ROTR22."""
    return _rotr(x, 2) ^ _rotr(x, 13) ^ _rotr(x, 22)


def _Sigma1(x: int) -> int:
    """SHA-256 big Sigma1: ROTR6 ^ ROTR11 ^ ROTR25."""
    return _rotr(x, 6) ^ _rotr(x, 11) ^ _rotr(x, 25)


def _sigma0(x: int) -> int:
    """SHA-256 small sigma0 (message schedule): ROTR7 ^ ROTR18 ^ SHR3."""
    return _rotr(x, 7) ^ _rotr(x, 18) ^ (x >> 3)


def _sigma1(x: int) -> int:
    """SHA-256 small sigma1 (message schedule): ROTR17 ^ ROTR19 ^ SHR10."""
    return _rotr(x, 17) ^ _rotr(x, 19) ^ (x >> 10)


def _pad_blocks(msg: bytes):
    """Apply SHA-256 padding to *msg* and split it into 512-bit blocks.

    Returns a list of blocks, each a list of sixteen 32-bit big-endian
    words, with the standard trailer appended: a single 0x80 byte, zero
    padding to 56 mod 64, then the 64-bit big-endian message bit length.
    """
    msg_len = len(msg)
    bitlen = msg_len * 8
    # Pad so the 8-byte length field exactly completes the final block.
    pad_zeros = (56 - (msg_len + 1) % 64) % 64
    padded = msg + b'\x80' + b'\x00' * pad_zeros + bitlen.to_bytes(8, 'big')
    blocks = []
    for off in range(0, len(padded), 64):
        chunk = padded[off:off + 64]
        blocks.append(
            [int.from_bytes(chunk[4 * i:4 * i + 4], 'big') for i in range(16)]
        )
    return blocks


def _compress(state, block_words):
    """One SHA-256 compression: fold a 16-word block into the 8-word state.

    *state* is a list of eight 32-bit ints; *block_words* is a list of
    sixteen 32-bit ints.  Returns the new 8-word state (FIPS 180-4, 6.2.2).
    """
    # Expand the 16 input words to the full 64-word message schedule.
    W = [0] * 64
    W[:16] = block_words[:]
    for t in range(16, 64):
        W[t] = (_sigma1(W[t - 2]) + W[t - 7] + _sigma0(W[t - 15]) + W[t - 16]) & 0xFFFFFFFF
    a, b, c, d, e, f, g, h = state
    for t in range(64):
        ch = (e & f) ^ ((~e) & g)             # Ch(e, f, g)
        maj = (a & b) ^ (a & c) ^ (b & c)     # Maj(a, b, c)
        T1 = (h + _Sigma1(e) + ch + K[t] + W[t]) & 0xFFFFFFFF
        T2 = (_Sigma0(a) + maj) & 0xFFFFFFFF
        h, g, f, e, d, c, b, a = (
            g, f, e, (d + T1) & 0xFFFFFFFF, c, b, a, (T1 + T2) & 0xFFFFFFFF
        )
    # Feed-forward: add the working variables back into the chaining state.
    return [(s + v) & 0xFFFFFFFF for s, v in zip(state, (a, b, c, d, e, f, g, h))]


def prepare_inputs_for_model(msg: bytes):
    """Split *msg* into (midstate_words, last_block) for the model.

    The midstate is the SHA-256 chaining state after compressing every
    padded block except the last (just the IV for single-block messages);
    the last block is left for the model to compress.
    """
    blocks = _pad_blocks(msg)
    if len(blocks) == 1:
        return IV[:], blocks[0]
    state = IV[:]
    for block in blocks[:-1]:
        state = _compress(state, block)
    return state, blocks[-1]


def words_to_bitslice_fp16_3d(words):
    """Bit-slice a list of L 32-bit words into a (32, 1, L) float16 tensor.

    Element [i, 0, j] is bit i (LSB-first) of words[j], as 0.0 or 1.0.
    """
    w = np.asarray(words, dtype=np.uint32)
    shifts = np.arange(32, dtype=np.uint32)[:, None]        # (32, 1)
    bits = ((w[None, :] >> shifts) & np.uint32(1))          # (32, L)
    return bits.astype(np.float16)[:, None, :]              # (32, 1, L)


def batch_words_to_bitslice_fp16(list_of_word_lists):
    """Stack per-message bit-sliced tensors into an (N, 32, 1, L) batch."""
    return np.stack(
        [words_to_bitslice_fp16_3d(words) for words in list_of_word_lists],
        axis=0,
    )


def bitslice_to_words_fp16(bits):
    """Invert the bit-slicing: (N, B, 1, L) float16 -> N lists of L ints.

    Values are thresholded at 0.5 so slightly noisy fp16 outputs still
    decode to clean bits.
    """
    bits = np.asarray(bits)
    _, n_bits, _, _ = bits.shape
    hard = (bits[:, :, 0, :] > 0.5).astype(np.uint64)       # (N, B, L)
    weights = np.uint64(1) << np.arange(n_bits, dtype=np.uint64)
    vals = (hard * weights[None, :, None]).sum(axis=1) & np.uint64(0xFFFFFFFF)
    return [[int(v) for v in row] for row in vals]


def words_to_bytes_be(words) -> bytes:
    """Concatenate 32-bit words into big-endian bytes."""
    return b''.join(int(w).to_bytes(4, 'big') for w in words)


def double_sha256_bytes(msg: bytes) -> bytes:
    """Reference sha256d: SHA-256 applied twice, via hashlib."""
    return hashlib.sha256(hashlib.sha256(msg).digest()).digest()


def prepare_batch_from_messages(messages):
    """Build the (midstate, w_init) model input tensors for *messages*.

    Returns (mid_tensor, block_tensor) of shapes (N, 32, 1, 8) and
    (N, 32, 1, 16) respectively, both float16.
    """
    midstates = []
    last_blocks = []
    for msg in messages:
        midstate, last_block = prepare_inputs_for_model(msg)
        midstates.append(midstate)
        last_blocks.append(last_block)
    return (
        batch_words_to_bitslice_fp16(midstates),
        batch_words_to_bitslice_fp16(last_blocks),
    )


def verify_model(model_path="sha256d.mlpackage", messages=None, compute_units=None):
    """Run the model on test messages and compare against hashlib sha256d.

    *compute_units* defaults to ``ct.ComputeUnit.CPU_AND_NE`` when None.
    Prints one OK/FAIL line per message and raises AssertionError if any
    digest mismatches.
    """
    ctmod = _require_ct()
    if compute_units is None:
        compute_units = ctmod.ComputeUnit.CPU_AND_NE
    if messages is None:
        messages = [
            b"",
            b"abc",
            b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq",
            b"message digest",
            b"The quick brown fox jumps over the lazy dog",
            b"The quick brown fox jumps over the lazy dog.",
        ]
    mid_tensor, block_tensor = prepare_batch_from_messages(messages)
    mlmodel = ctmod.models.MLModel(model_path, compute_units=compute_units)
    out = mlmodel.predict({"midstate": mid_tensor, "w_init": block_tensor})
    # Single-output model: take whichever name the model exposes.
    out_name = next(iter(out.keys()))
    y = np.array(out[out_name])
    out_words = bitslice_to_words_fp16(y)
    ok = True
    for i, msg in enumerate(messages):
        model_digest = words_to_bytes_be(out_words[i]).hex()
        ref_digest = double_sha256_bytes(msg).hex()
        tag = "OK" if model_digest == ref_digest else "FAIL"
        print(f"[verify {i}] {tag} model={model_digest} ref={ref_digest}")
        if tag == "FAIL":
            ok = False
    if not ok:
        raise AssertionError("Verification failed")


def _replicate_bitslice(bits3d, N):
    """Replicate a (32, 1, L) tensor N times along a new batch axis."""
    return np.repeat(np.expand_dims(bits3d, 0), N, axis=0)


def benchmark(model_path="sha256d.mlpackage", base_message=b"abc",
              batch_sizes=(1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024),
              warmup=3, iters=10, compute_units=None, verbose=False):
    """Measure model throughput across batch sizes.

    For each batch size N, the prepared input is replicated N times, the
    model is run *warmup* times untimed, then *iters* timed runs are
    aggregated.  *compute_units* defaults to ``ct.ComputeUnit.CPU_AND_NE``
    when None.  Prints a CSV summary and returns ``(results, best)`` where
    each results entry is ``(N, mean_s, p50_s, p90_s, hashes_per_s)``.
    """
    ctmod = _require_ct()
    if compute_units is None:
        compute_units = ctmod.ComputeUnit.CPU_AND_NE
    midstate, last_block = prepare_inputs_for_model(base_message)
    mid3d = words_to_bitslice_fp16_3d(midstate)
    blk3d = words_to_bitslice_fp16_3d(last_block)
    mlmodel = ctmod.models.MLModel(model_path, compute_units=compute_units)
    results = []
    for N in batch_sizes:
        mid = _replicate_bitslice(mid3d, N)
        blk = _replicate_bitslice(blk3d, N)
        for _ in range(warmup):
            mlmodel.predict({"midstate": mid, "w_init": blk})
        times = []
        for _ in range(iters):
            t0 = time.perf_counter()
            mlmodel.predict({"midstate": mid, "w_init": blk})
            t1 = time.perf_counter()
            times.append(t1 - t0)
        times = np.array(times, dtype=np.float64)
        mean_s = float(times.mean())
        p50_s = float(np.percentile(times, 50))
        p90_s = float(np.percentile(times, 90))
        thr = N / mean_s
        if verbose:
            print(f"[bench] N={N} mean={mean_s*1e3:.3f}ms p50={p50_s*1e3:.3f}ms p90={p90_s*1e3:.3f}ms {thr:.2f} H/s {thr/1e6:.3f} MH/s")
        results.append((N, mean_s, p50_s, p90_s, thr))
    best = max(results, key=lambda r: r[4])
    print("batch,mean_ms,p50_ms,p90_ms,hashes_per_s,MH_per_s")
    for N, mean_s, p50_s, p90_s, thr in results:
        print(f"{N},{mean_s*1e3:.3f},{p50_s*1e3:.3f},{p90_s*1e3:.3f},{thr:.2f},{thr/1e6:.3f}")
    print(f"best_batch={best[0]} best_hashrate={best[4]:.2f} H/s ({best[4]/1e6:.3f} MH/s)")
    return results, best


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="sha256d.mlpackage")
    parser.add_argument("--verify", action="store_true")
    parser.add_argument("--bench", action="store_true")
    parser.add_argument("--base", type=str, default="abc")
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--iters", type=int, default=10)
    parser.add_argument("--batches", type=str, default="1,2,4,8,16,32,64,128,256,512,1024")
    parser.add_argument("--cu", type=str, default="CPU_AND_NE", choices=["CPU", "CPU_AND_NE"])
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # Resolve compute units up front; fails clearly if coremltools is absent.
    cu_enum = _require_ct().ComputeUnit
    cu_map = {
        "CPU": cu_enum.CPU_ONLY,
        "CPU_AND_NE": cu_enum.CPU_AND_NE,
    }
    cu = cu_map[args.cu]

    if args.verify:
        msgs = [
            b"",
            b"abc",
            b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq",
            b"message digest",
            b"The quick brown fox jumps over the lazy dog",
            b"The quick brown fox jumps over the lazy dog.",
        ]
        verify_model(model_path=args.model, messages=msgs, compute_units=cu)

    if args.bench:
        batches = tuple(int(x) for x in args.batches.split(",") if x.strip())
        benchmark(
            model_path=args.model,
            base_message=args.base.encode("utf-8"),
            batch_sizes=batches,
            warmup=args.warmup,
            iters=args.iters,
            compute_units=cu,
            verbose=args.verbose,
        )