| import time | |
| import hashlib | |
| import numpy as np | |
| import coremltools as ct | |
# SHA-256 initial hash values (FIPS 180-4 §5.3.3): the first 32 bits of the
# fractional parts of the square roots of the first 8 primes.
IV = [
    0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
    0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
]
# SHA-256 round constants (FIPS 180-4 §4.2.2): the first 32 bits of the
# fractional parts of the cube roots of the first 64 primes, one per round.
K = [
    0x428a2f98,0x71374491,0xb5c0fbcf,0xe9b5dba5,0x3956c25b,0x59f111f1,0x923f82a4,0xab1c5ed5,
    0xd807aa98,0x12835b01,0x243185be,0x550c7dc3,0x72be5d74,0x80deb1fe,0x9bdc06a7,0xc19bf174,
    0xe49b69c1,0xefbe4786,0x0fc19dc6,0x240ca1cc,0x2de92c6f,0x4a7484aa,0x5cb0a9dc,0x76f988da,
    0x983e5152,0xa831c66d,0xb00327c8,0xbf597fc7,0xc6e00bf3,0xd5a79147,0x06ca6351,0x14292967,
    0x27b70a85,0x2e1b2138,0x4d2c6dfc,0x53380d13,0x650a7354,0x766a0abb,0x81c2c92e,0x92722c85,
    0xa2bfe8a1,0xa81a664b,0xc24b8b70,0xc76c51a3,0xd192e819,0xd6990624,0xf40e3585,0x106aa070,
    0x19a4c116,0x1e376c08,0x2748774c,0x34b0bcb5,0x391c0cb3,0x4ed8aa4a,0x5b9cca4f,0x682e6ff3,
    0x748f82ee,0x78a5636f,0x84c87814,0x8cc70208,0x90befffa,0xa4506ceb,0xbef9a3f7,0xc67178f2,
]
| def _rotr(x, n): return ((x >> n) | (x << (32 - n))) & 0xFFFFFFFF | |
def _Sigma0(x):
    """SHA-256 big Sigma0 (FIPS 180-4 eq. 4.4): XOR of right-rotations by 2, 13, 22."""
    return _rotr(x, 22) ^ _rotr(x, 13) ^ _rotr(x, 2)
def _Sigma1(x):
    """SHA-256 big Sigma1 (FIPS 180-4 eq. 4.5): XOR of right-rotations by 6, 11, 25."""
    return _rotr(x, 25) ^ _rotr(x, 11) ^ _rotr(x, 6)
def _sigma0(x):
    """SHA-256 small sigma0 (FIPS 180-4 eq. 4.6): rotations by 7 and 18 XOR shift by 3."""
    return (x >> 3) ^ _rotr(x, 18) ^ _rotr(x, 7)
def _sigma1(x):
    """SHA-256 small sigma1 (FIPS 180-4 eq. 4.7): rotations by 17 and 19 XOR shift by 10."""
    return (x >> 10) ^ _rotr(x, 19) ^ _rotr(x, 17)
| def _pad_blocks(msg: bytes): | |
| L = len(msg) | |
| bitlen = L * 8 | |
| pad1 = b'\x80' | |
| pad0_len = (56 - (L + 1) % 64) % 64 | |
| pad0 = b'\x00' * pad0_len | |
| length_be = bitlen.to_bytes(8, 'big') | |
| m = msg + pad1 + pad0 + length_be | |
| blocks = [] | |
| for off in range(0, len(m), 64): | |
| b = m[off:off+64] | |
| words = [int.from_bytes(b[4*i:4*i+4], 'big') for i in range(16)] | |
| blocks.append(words) | |
| return blocks | |
def _compress(state, block_words):
    """Run one SHA-256 compression over a single 16-word block.

    state: list of 8 chaining-value words; block_words: 16 message words.
    Returns the new 8-word chaining value; inputs are not mutated.
    """
    # Expand the 16 block words into the 64-entry message schedule.
    w = list(block_words) + [0] * 48
    for i in range(16, 64):
        w[i] = (_sigma1(w[i - 2]) + w[i - 7] + _sigma0(w[i - 15]) + w[i - 16]) & 0xFFFFFFFF
    a, b, c, d, e, f, g, h = state
    for i in range(64):
        ch = (e & f) ^ (~e & g)
        maj = (a & b) ^ (a & c) ^ (b & c)
        t1 = (h + _Sigma1(e) + ch + K[i] + w[i]) & 0xFFFFFFFF
        t2 = (_Sigma0(a) + maj) & 0xFFFFFFFF
        # Rotate the working registers; each assignment reads the pre-round value.
        h = g
        g = f
        f = e
        e = (d + t1) & 0xFFFFFFFF
        d = c
        c = b
        b = a
        a = (t1 + t2) & 0xFFFFFFFF
    # Feed-forward: add the working registers back into the input state.
    return [(s + v) & 0xFFFFFFFF for s, v in zip(state, (a, b, c, d, e, f, g, h))]
def prepare_inputs_for_model(msg: bytes):
    """Split msg into the inputs the CoreML kernel expects.

    The model performs only the final compression, so every padded block
    except the last is folded into the chaining state here on the CPU.

    Returns:
        (midstate_words, last_block): the 8-word state after compressing
        all but the last block, and the last block's 16 message words.
    """
    blocks = _pad_blocks(msg)
    # No single-block special case needed: for a one-block message,
    # blocks[:-1] is empty and the state simply remains IV.
    state = IV[:]
    for block in blocks[:-1]:
        state = _compress(state, block)
    return state, blocks[-1]
def words_to_bitslice_fp16_3d(words):
    """Bit-slice a list of 32-bit words into a (32, 1, L) fp16 tensor.

    Plane i holds bit i (LSB first) of every word as 0.0/1.0 values.
    """
    w = np.asarray(words, dtype=np.uint32)
    shifts = np.arange(32, dtype=np.uint32)
    # Broadcast every word against every shift amount in one pass.
    planes = ((w[None, :] >> shifts[:, None]) & 1).astype(np.float16)
    return planes[:, None, :]
def batch_words_to_bitslice_fp16(list_of_word_lists):
    """Stack per-sample bit-slices into an (N, 32, 1, L) fp16 batch tensor.

    All word lists must share the same length L; raises IndexError on an
    empty input list (same as before).
    """
    n_samples = len(list_of_word_lists)
    n_words = len(list_of_word_lists[0])
    batch = np.zeros((n_samples, 32, 1, n_words), dtype=np.float16)
    for sample_idx, word_list in enumerate(list_of_word_lists):
        batch[sample_idx] = words_to_bitslice_fp16_3d(word_list)
    return batch
def bitslice_to_words_fp16(bits):
    """Inverse of the bit-slice packing.

    bits: (N, 32, 1, L) fp16 tensor of ~0/1 values (threshold at 0.5).
    Returns N lists of L Python ints, each masked to 32 bits.
    """
    n_samples, _, _, n_words = bits.shape
    result = []
    for sample in range(n_samples):
        plane = bits[sample, :, 0, :]  # (32, L) view of this sample's bits
        row = []
        for word_idx in range(n_words):
            value = sum(1 << bit for bit in range(32) if plane[bit, word_idx] > 0.5)
            row.append(value & 0xFFFFFFFF)
        result.append(row)
    return result
def words_to_bytes_be(words):
    """Serialize 32-bit words to bytes, big-endian, 4 bytes per word."""
    buf = bytearray()
    for word in words:
        buf.extend(int(word).to_bytes(4, 'big'))
    return bytes(buf)
def double_sha256_bytes(msg: bytes):
    """Reference sha256d: SHA-256 applied twice, as used by the model."""
    inner = hashlib.sha256(msg).digest()
    return hashlib.sha256(inner).digest()
def prepare_batch_from_messages(messages):
    """Build the (midstate, last-block) bit-sliced tensors for a message batch.

    Returns two (N, 32, 1, L) fp16 tensors: midstates (L=8) and final
    message blocks (L=16), one row per message.
    """
    pairs = [prepare_inputs_for_model(message) for message in messages]
    midstates = [midstate for midstate, _ in pairs]
    last_blocks = [block for _, block in pairs]
    return (batch_words_to_bitslice_fp16(midstates),
            batch_words_to_bitslice_fp16(last_blocks))
def verify_model(model_path="sha256d.mlpackage", messages=None, compute_units=ct.ComputeUnit.CPU_AND_NE):
    """Check the CoreML model's sha256d digests against hashlib for known vectors.

    Prints one OK/FAIL line per message and raises AssertionError if any
    digest disagrees with the CPU reference.
    """
    if messages is None:
        messages = [
            b"",
            b"abc",
            b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq",
            b"message digest",
            b"The quick brown fox jumps over the lazy dog",
            b"The quick brown fox jumps over the lazy dog.",
        ]
    mid_tensor, block_tensor = prepare_batch_from_messages(messages)
    model = ct.models.MLModel(model_path, compute_units=compute_units)
    prediction = model.predict({"midstate": mid_tensor, "w_init": block_tensor})
    # The model is expected to expose a single output; take the first key.
    first_output = next(iter(prediction.keys()))
    out_words = bitslice_to_words_fp16(np.array(prediction[first_output]))
    failures = 0
    for i, m in enumerate(messages):
        model_digest = words_to_bytes_be(out_words[i]).hex()
        ref_digest = double_sha256_bytes(m).hex()
        tag = "OK" if model_digest == ref_digest else "FAIL"
        print(f"[verify {i}] {tag} model={model_digest} ref={ref_digest}")
        if tag == "FAIL":
            failures += 1
    if failures:
        raise AssertionError("Verification failed")
| def _replicate_bitslice(bits3d, N): | |
| x = np.expand_dims(bits3d, 0) | |
| return np.repeat(x, N, axis=0) | |
def benchmark(model_path="sha256d.mlpackage",
              base_message=b"abc",
              batch_sizes=(1,2,4,8,16,32,64,128,256,512,1024),
              warmup=3,
              iters=10,
              compute_units=ct.ComputeUnit.CPU_AND_NE,
              verbose=False):
    """Time model.predict across batch sizes and report latency/throughput.

    For each batch size N: runs `warmup` untimed predictions, then `iters`
    timed ones, and records mean/p50/p90 latency plus hashes per second.
    Returns (results, best): results is a list of
    (N, mean_s, p50_s, p90_s, hashes_per_s) tuples; best is the row with
    the highest throughput. A CSV summary is printed at the end.
    """
    midstate_words, last_block_words = prepare_inputs_for_model(base_message)
    mid3d = words_to_bitslice_fp16_3d(midstate_words)
    blk3d = words_to_bitslice_fp16_3d(last_block_words)
    model = ct.models.MLModel(model_path, compute_units=compute_units)
    results = []
    for N in batch_sizes:
        feed = {"midstate": _replicate_bitslice(mid3d, N),
                "w_init": _replicate_bitslice(blk3d, N)}
        for _ in range(warmup):
            model.predict(feed)
        samples = []
        for _ in range(iters):
            start = time.perf_counter()
            model.predict(feed)
            samples.append(time.perf_counter() - start)
        samples = np.array(samples, dtype=np.float64)
        mean_s = float(samples.mean())
        p50_s = float(np.percentile(samples, 50))
        p90_s = float(np.percentile(samples, 90))
        thr = N / mean_s
        if verbose:
            print(f"[bench] N={N} mean={mean_s*1e3:.3f}ms p50={p50_s*1e3:.3f}ms p90={p90_s*1e3:.3f}ms {thr:.2f} H/s {thr/1e6:.3f} MH/s")
        results.append((N, mean_s, p50_s, p90_s, thr))
    best = max(results, key=lambda row: row[4])
    print("batch,mean_ms,p50_ms,p90_ms,hashes_per_s,MH_per_s")
    for N, mean_s, p50_s, p90_s, thr in results:
        print(f"{N},{mean_s*1e3:.3f},{p50_s*1e3:.3f},{p90_s*1e3:.3f},{thr:.2f},{thr/1e6:.3f}")
    print(f"best_batch={best[0]} best_hashrate={best[4]:.2f} H/s ({best[4]/1e6:.3f} MH/s)")
    return results, best
| if __name__ == "__main__": | |
| import argparse | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--model", type=str, default="sha256d.mlpackage") | |
| parser.add_argument("--verify", action="store_true") | |
| parser.add_argument("--bench", action="store_true") | |
| parser.add_argument("--base", type=str, default="abc") | |
| parser.add_argument("--warmup", type=int, default=3) | |
| parser.add_argument("--iters", type=int, default=10) | |
| parser.add_argument("--batches", type=str, default="1,2,4,8,16,32,64,128,256,512,1024") | |
| parser.add_argument("--cu", type=str, default="CPU_AND_NE", choices=["CPU","CPU_AND_NE"]) | |
| parser.add_argument("--verbose", action="store_true") | |
| args = parser.parse_args() | |
| cu_map = { | |
| "CPU": ct.ComputeUnit.CPU_ONLY, | |
| "CPU_AND_NE": ct.ComputeUnit.CPU_AND_NE, | |
| } | |
| cu = cu_map[args.cu] | |
| if args.verify: | |
| msgs = [ | |
| b"", | |
| b"abc", | |
| b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", | |
| b"message digest", | |
| b"The quick brown fox jumps over the lazy dog", | |
| b"The quick brown fox jumps over the lazy dog.", | |
| ] | |
| verify_model(model_path=args.model, messages=msgs, compute_units=cu) | |
| if args.bench: | |
| batches = tuple(int(x) for x in args.batches.split(",") if x.strip()) | |
| benchmark(model_path=args.model, | |
| base_message=args.base.encode("utf-8"), | |
| batch_sizes=batches, | |
| warmup=args.warmup, | |
| iters=args.iters, | |
| compute_units=cu, | |
| verbose=args.verbose) |