# ANE-sha256d / test.py
# Uploaded by pkhairkh — initial commit (e2c008f)
import time
import hashlib
import numpy as np
import coremltools as ct
# SHA-256 initial hash values (FIPS 180-4 §5.3.3): first 32 bits of the
# fractional parts of the square roots of the first eight primes.
IV = [
    0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
    0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
]
# SHA-256 round constants (FIPS 180-4 §4.2.2): first 32 bits of the
# fractional parts of the cube roots of the first 64 primes.
K = [
    0x428a2f98,0x71374491,0xb5c0fbcf,0xe9b5dba5,0x3956c25b,0x59f111f1,0x923f82a4,0xab1c5ed5,
    0xd807aa98,0x12835b01,0x243185be,0x550c7dc3,0x72be5d74,0x80deb1fe,0x9bdc06a7,0xc19bf174,
    0xe49b69c1,0xefbe4786,0x0fc19dc6,0x240ca1cc,0x2de92c6f,0x4a7484aa,0x5cb0a9dc,0x76f988da,
    0x983e5152,0xa831c66d,0xb00327c8,0xbf597fc7,0xc6e00bf3,0xd5a79147,0x06ca6351,0x14292967,
    0x27b70a85,0x2e1b2138,0x4d2c6dfc,0x53380d13,0x650a7354,0x766a0abb,0x81c2c92e,0x92722c85,
    0xa2bfe8a1,0xa81a664b,0xc24b8b70,0xc76c51a3,0xd192e819,0xd6990624,0xf40e3585,0x106aa070,
    0x19a4c116,0x1e376c08,0x2748774c,0x34b0bcb5,0x391c0cb3,0x4ed8aa4a,0x5b9cca4f,0x682e6ff3,
    0x748f82ee,0x78a5636f,0x84c87814,0x8cc70208,0x90befffa,0xa4506ceb,0xbef9a3f7,0xc67178f2,
]
def _rotr(x, n): return ((x >> n) | (x << (32 - n))) & 0xFFFFFFFF
def _Sigma0(x): return _rotr(x, 2) ^ _rotr(x, 13) ^ _rotr(x, 22)
def _Sigma1(x): return _rotr(x, 6) ^ _rotr(x, 11) ^ _rotr(x, 25)
def _sigma0(x): return _rotr(x, 7) ^ _rotr(x, 18) ^ (x >> 3)
def _sigma1(x): return _rotr(x, 17) ^ _rotr(x, 19) ^ (x >> 10)
def _pad_blocks(msg: bytes):
L = len(msg)
bitlen = L * 8
pad1 = b'\x80'
pad0_len = (56 - (L + 1) % 64) % 64
pad0 = b'\x00' * pad0_len
length_be = bitlen.to_bytes(8, 'big')
m = msg + pad1 + pad0 + length_be
blocks = []
for off in range(0, len(m), 64):
b = m[off:off+64]
words = [int.from_bytes(b[4*i:4*i+4], 'big') for i in range(16)]
blocks.append(words)
return blocks
def _compress(state, block_words):
    """One SHA-256 compression (FIPS 180-4 §6.2.2).

    state: list of 8 32-bit words (a..h); block_words: 16 message words of one
    512-bit block. Returns the new 8-word state; inputs are not mutated.
    """
    # Message schedule: expand the 16 block words to 64 (W[16..63]).
    W = [0]*64
    W[:16] = block_words[:]
    for t in range(16, 64):
        W[t] = (_sigma1(W[t-2]) + W[t-7] + _sigma0(W[t-15]) + W[t-16]) & 0xFFFFFFFF
    a,b,c,d,e,f,g,h = state
    for t in range(64):
        # T1 = h + Sigma1(e) + Ch(e,f,g) + K[t] + W[t]
        T1 = (h + _Sigma1(e) + ((e & f) ^ ((~e) & g)) + K[t] + W[t]) & 0xFFFFFFFF
        # T2 = Sigma0(a) + Maj(a,b,c)
        T2 = (_Sigma0(a) + ((a & b) ^ (a & c) ^ (b & c))) & 0xFFFFFFFF
        # Rotate the working registers: e absorbs d+T1, a absorbs T1+T2.
        h,g,f,e,d,c,b,a = g,f,e,(d + T1) & 0xFFFFFFFF,c,b,a,(T1 + T2) & 0xFFFFFFFF
    # Feed-forward: add the working variables back into the incoming state.
    return [
        (state[0] + a) & 0xFFFFFFFF, (state[1] + b) & 0xFFFFFFFF,
        (state[2] + c) & 0xFFFFFFFF, (state[3] + d) & 0xFFFFFFFF,
        (state[4] + e) & 0xFFFFFFFF, (state[5] + f) & 0xFFFFFFFF,
        (state[6] + g) & 0xFFFFFFFF, (state[7] + h) & 0xFFFFFFFF,
    ]
def prepare_inputs_for_model(msg: bytes):
    """Split SHA-256(msg) between host and model.

    Compresses every padded block except the last on the CPU and returns
    (midstate_words, last_block): the 8-word midstate and the final 16-word
    block left for the model to finish.
    """
    blocks = _pad_blocks(msg)
    state = IV[:]
    # For a single-block message this loop is empty and the midstate is IV.
    for block in blocks[:-1]:
        state = _compress(state, block)
    return state, blocks[-1]
def words_to_bitslice_fp16_3d(words):
    """Bit-slice 32-bit words into an fp16 tensor of shape (32, 1, len(words)).

    Element [i, 0, j] is bit i of words[j], stored as 0.0 or 1.0.
    """
    vals = np.asarray(words, dtype=np.uint32)
    shifts = np.arange(32, dtype=np.uint32).reshape(32, 1)
    bits = (vals[np.newaxis, :] >> shifts) & np.uint32(1)
    return bits.astype(np.float16).reshape(32, 1, len(words))
def batch_words_to_bitslice_fp16(list_of_word_lists):
    """Stack per-message bit-slices into one (N, 32, 1, L) fp16 batch tensor."""
    slices = [words_to_bitslice_fp16_3d(words) for words in list_of_word_lists]
    return np.stack(slices, axis=0)
def bitslice_to_words_fp16(bits):
    """Unpack a bit-sliced (N, 32, 1, L) fp16 tensor back into integers.

    Any element > 0.5 counts as a set bit (bit index = axis-1 position).
    Returns N lists of L Python ints, each masked to 32 bits.
    """
    n_items, _, _, _ = bits.shape
    weights = np.uint64(1) << np.arange(32, dtype=np.uint64)
    set_bits = (np.asarray(bits)[:, :, 0, :] > 0.5).astype(np.uint64)
    packed = (set_bits * weights[np.newaxis, :, np.newaxis]).sum(axis=1)
    packed &= np.uint64(0xFFFFFFFF)
    return [[int(v) for v in packed[n]] for n in range(n_items)]
def words_to_bytes_be(words):
    """Concatenate 32-bit words into a big-endian byte string (4 bytes each)."""
    out = bytearray()
    for w in words:
        out += int(w).to_bytes(4, 'big')
    return bytes(out)
def double_sha256_bytes(msg: bytes):
    """Reference SHA-256d: the SHA-256 digest of the SHA-256 digest of *msg*."""
    inner = hashlib.sha256(msg).digest()
    return hashlib.sha256(inner).digest()
def prepare_batch_from_messages(messages):
    """Build the two batched model inputs for *messages*.

    Returns (mid_tensor, block_tensor): bit-sliced fp16 tensors of the
    per-message midstates and final padded blocks, each shaped (N, 32, 1, L).
    """
    pairs = [prepare_inputs_for_model(m) for m in messages]
    mid_tensor = batch_words_to_bitslice_fp16([ms for ms, _ in pairs])
    block_tensor = batch_words_to_bitslice_fp16([lb for _, lb in pairs])
    return mid_tensor, block_tensor
def verify_model(model_path="sha256d.mlpackage", messages=None, compute_units=ct.ComputeUnit.CPU_AND_NE):
    """Check the Core ML sha256d model against hashlib on known messages.

    Runs one batched predict over *messages* (standard SHA test vectors by
    default), prints a per-message OK/FAIL line, and raises AssertionError
    if any model digest disagrees with the double-SHA256 reference.
    """
    if messages is None:
        messages = [
            b"",
            b"abc",
            b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq",
            b"message digest",
            b"The quick brown fox jumps over the lazy dog",
            b"The quick brown fox jumps over the lazy dog.",
        ]
    mid_tensor, block_tensor = prepare_batch_from_messages(messages)
    mlmodel = ct.models.MLModel(model_path, compute_units=compute_units)
    out = mlmodel.predict({"midstate": mid_tensor, "w_init": block_tensor})
    # The model has a single output; take it regardless of its name.
    y = np.array(out[next(iter(out))])
    out_words = bitslice_to_words_fp16(y)
    failures = 0
    for i, m in enumerate(messages):
        model_digest = words_to_bytes_be(out_words[i]).hex()
        ref_digest = double_sha256_bytes(m).hex()
        matched = model_digest == ref_digest
        if not matched:
            failures += 1
        tag = "OK" if matched else "FAIL"
        print(f"[verify {i}] {tag} model={model_digest} ref={ref_digest}")
    if failures:
        raise AssertionError("Verification failed")
def _replicate_bitslice(bits3d, N):
x = np.expand_dims(bits3d, 0)
return np.repeat(x, N, axis=0)
def benchmark(model_path="sha256d.mlpackage",
              base_message=b"abc",
              batch_sizes=(1,2,4,8,16,32,64,128,256,512,1024),
              warmup=3,
              iters=10,
              compute_units=ct.ComputeUnit.CPU_AND_NE,
              verbose=False):
    """Measure model hash throughput across batch sizes.

    For each N in batch_sizes, replicates the inputs for *base_message* N
    times, runs `warmup` untimed predicts, then times `iters` predicts.
    Prints a CSV summary plus the best batch size, and returns
    (results, best) where each entry is (N, mean_s, p50_s, p90_s, hashes/s).
    """
    ms, lb = prepare_inputs_for_model(base_message)
    mid3d = words_to_bitslice_fp16_3d(ms)
    blk3d = words_to_bitslice_fp16_3d(lb)
    mlmodel = ct.models.MLModel(model_path, compute_units=compute_units)
    results = []
    for N in batch_sizes:
        # Same message replicated N times -> one batched predict call per hash batch.
        mid = _replicate_bitslice(mid3d, N)
        blk = _replicate_bitslice(blk3d, N)
        # Untimed warmup runs (model load / compilation / cache effects).
        for _ in range(warmup):
            mlmodel.predict({"midstate": mid, "w_init": blk})
        times = []
        for _ in range(iters):
            t0 = time.perf_counter()
            mlmodel.predict({"midstate": mid, "w_init": blk})
            t1 = time.perf_counter()
            times.append(t1 - t0)
        times = np.array(times, dtype=np.float64)
        mean_s = float(times.mean())
        p50_s = float(np.percentile(times, 50))
        p90_s = float(np.percentile(times, 90))
        # Throughput in hashes/second at this batch size.
        thr = N / mean_s
        if verbose:
            print(f"[bench] N={N} mean={mean_s*1e3:.3f}ms p50={p50_s*1e3:.3f}ms p90={p90_s*1e3:.3f}ms {thr:.2f} H/s {thr/1e6:.3f} MH/s")
        results.append((N, mean_s, p50_s, p90_s, thr))
    # Best = highest throughput (tuple index 4).
    best = max(results, key=lambda r: r[4])
    print("batch,mean_ms,p50_ms,p90_ms,hashes_per_s,MH_per_s")
    for N, mean_s, p50_s, p90_s, thr in results:
        print(f"{N},{mean_s*1e3:.3f},{p50_s*1e3:.3f},{p90_s*1e3:.3f},{thr:.2f},{thr/1e6:.3f}")
    print(f"best_batch={best[0]} best_hashrate={best[4]:.2f} H/s ({best[4]/1e6:.3f} MH/s)")
    return results, best
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="sha256d.mlpackage")
    parser.add_argument("--verify", action="store_true")
    parser.add_argument("--bench", action="store_true")
    parser.add_argument("--base", type=str, default="abc")
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--iters", type=int, default=10)
    parser.add_argument("--batches", type=str, default="1,2,4,8,16,32,64,128,256,512,1024")
    parser.add_argument("--cu", type=str, default="CPU_AND_NE", choices=["CPU","CPU_AND_NE"])
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # Map CLI flag to coremltools compute-unit enum.
    cu_map = {
        "CPU": ct.ComputeUnit.CPU_ONLY,
        "CPU_AND_NE": ct.ComputeUnit.CPU_AND_NE,
    }
    cu = cu_map[args.cu]

    if args.verify:
        # Rely on verify_model's built-in default test vectors instead of
        # duplicating the same message list here (keeps the two in sync).
        verify_model(model_path=args.model, compute_units=cu)
    if args.bench:
        batches = tuple(int(x) for x in args.batches.split(",") if x.strip())
        benchmark(model_path=args.model,
                  base_message=args.base.encode("utf-8"),
                  batch_sizes=batches,
                  warmup=args.warmup,
                  iters=args.iters,
                  compute_units=cu,
                  verbose=args.verbose)