"""Load the ``deepseek-r1-1.5b-unary`` weights into ``unary_engine.so`` and prepare a test prompt.

Binds the native engine through ctypes, hands it pointers to NumPy buffers read
from ``MODEL_DIR``, and loads the Hugging Face tokenizer from ``HF_DIR``.
"""
import ctypes
import os
import time

import numpy as np
from transformers import AutoTokenizer

MODEL_DIR = "deepseek-r1-1.5b-unary"
HF_DIR = "deepseek-r1-1.5b-hf"

# --- ctypes bindings for the native engine ----------------------------------
lib = ctypes.CDLL("./unary_engine.so")
lib.model_alloc.restype = ctypes.c_void_p
lib.model_alloc.argtypes = [ctypes.c_int]
lib.model_set_embed.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
lib.model_set_final_norm.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
lib.model_set_lm_head.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int]
lib.layer_set_norms.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p]
lib.layer_set_bias.argtypes = [
    ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p,
]
# (model, layer) + 7 x (sign ptr, planes ptr, scales ptr, out_dim, in_dim) + n_planes,
# matching how the per-projection argument list is built in the layer loop below.
lib.layer_set_linears.argtypes = (
    [ctypes.c_void_p, ctypes.c_int]
    + [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int] * 7
    + [ctypes.c_int]
)
lib.generate.restype = ctypes.c_int
lib.generate.argtypes = [
    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int,
    ctypes.c_float, ctypes.c_float, ctypes.c_int,
]
lib.model_reset_cache.argtypes = [ctypes.c_void_p]

# Every array handed to the engine is appended here so it is never garbage
# collected; presumably the engine keeps the raw pointers rather than copying.
_refs = []


def keep(a):
    """Pin *a* for the life of the process and return its data pointer as an int.

    ``np.ascontiguousarray`` returns *a* unchanged when it is already
    C-contiguous, and otherwise makes a dense copy so the pointer is valid
    for a flat row-major read.
    """
    a = np.ascontiguousarray(a)
    _refs.append(a)
    return a.ctypes.data


N_PLANES = 7
N_LAYERS = 28
VOCAB_SIZE = 151936
HIDDEN_SIZE = 1536
PROJS = [
    'self_attn_q_proj', 'self_attn_k_proj', 'self_attn_v_proj', 'self_attn_o_proj',
    'mlp_gate_proj', 'mlp_up_proj', 'mlp_down_proj',
]
# (out_dim, in_dim) per projection, passed to layer_set_linears in that order.
DIMS = {
    'self_attn_q_proj': (1536, 1536),
    'self_attn_k_proj': (256, 1536),
    'self_attn_v_proj': (256, 1536),
    'self_attn_o_proj': (1536, 1536),
    'mlp_gate_proj': (8960, 1536),
    'mlp_up_proj': (8960, 1536),
    'mlp_down_proj': (1536, 8960),
}


def _weight_path(name):
    """Return the path of weight file *name* inside MODEL_DIR."""
    return os.path.join(MODEL_DIR, name)


def _load_raw(name, dtype):
    """Read weight file *name* as a flat array of *dtype*, with no conversion."""
    return np.fromfile(_weight_path(name), dtype=dtype)


def _load_f32(name):
    """Read an fp16 weight file and widen it to a float32 array."""
    return _load_raw(name, np.float16).astype(np.float32)


m = lib.model_alloc(N_PLANES)
if not m:  # a c_void_p restype maps a NULL return to None
    raise RuntimeError("model_alloc returned NULL")

# Embedding and lm_head are passed as raw fp16 bit patterns (read as uint16),
# not converted to float32 like the norms and biases.
e = _load_raw('model_embed_tokens_weight.fp16', np.uint16)
lib.model_set_embed(m, keep(e))
n = _load_f32('model_norm_weight.fp16')
lib.model_set_final_norm(m, keep(n))
h = _load_raw('lm_head_weight.fp16', np.uint16)
# The engine is told the lm_head shape explicitly; refuse a file that does not
# match rather than letting native code read past the end of the buffer.
if h.size != VOCAB_SIZE * HIDDEN_SIZE:
    raise ValueError(
        f"lm_head has {h.size} elements, expected {VOCAB_SIZE} x {HIDDEN_SIZE}"
    )
lib.model_set_lm_head(m, keep(h), VOCAB_SIZE, HIDDEN_SIZE)

for layer in range(N_LAYERS):
    prefix = f'model_layers_{layer}_'
    inorm = _load_f32(prefix + 'input_layernorm_weight.fp16')
    pnorm = _load_f32(prefix + 'post_attention_layernorm_weight.fp16')
    lib.layer_set_norms(m, layer, keep(inorm), keep(pnorm))
    qb, kb, vb = (_load_f32(f'{prefix}self_attn_{x}_proj_bias.fp16') for x in 'qkv')
    lib.layer_set_bias(m, layer, keep(qb), keep(kb), keep(vb))
    pa = []
    for pn in PROJS:
        base = _weight_path(f'{prefix}{pn}_weight')
        s = np.fromfile(base + '.sign', dtype=np.uint64)
        p = np.fromfile(base + '.planes', dtype=np.uint64)
        sc = np.fromfile(base + '.scales', dtype=np.float32)
        # Named out_dim/in_dim (not `id`) so the builtin id() is not shadowed
        # at module scope for the rest of the file.
        out_dim, in_dim = DIMS[pn]
        pa.extend([keep(s), keep(p), keep(sc), out_dim, in_dim])
    lib.layer_set_linears(m, layer, *pa, N_PLANES)

# NOTE(review): trust_remote_code=True lets the tokenizer directory execute
# Python code; acceptable only because HF_DIR is a local, trusted checkout.
tok = AutoTokenizer.from_pretrained(HF_DIR, trust_remote_code=True)

# Test with actual prompt
prompt = "What is 2+2? Think step by step."
# --- Generation smoke tests ---------------------------------------------------
MAX_NEW_TOKENS = 64
TOP_P = 0.9  # passed as generate()'s 7th argument; presumably nucleus top-p

ids = tok.encode(prompt)
inp = np.array(ids, dtype=np.int32)
print(f"Prompt: {prompt} ({len(ids)} tokens)")


def _run_generation(label, temperature, out_buf):
    """Run one generation of the prompt into *out_buf* and print timing and text.

    Resets the engine cache first so each run starts from the same state,
    asks for at most ``out_buf.size`` tokens, and decodes only the tokens the
    engine reports as written.

    Returns:
        ``(n_generated, elapsed_seconds, decoded_text)``.

    Raises:
        RuntimeError: if ``generate`` returns a negative count, which cannot be
            a valid token count and would otherwise slice the buffer from the end.
    """
    print(f"\n--- {label} ---")
    lib.model_reset_cache(m)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    n_gen = lib.generate(
        m, inp.ctypes.data, len(ids), out_buf.ctypes.data, out_buf.size,
        ctypes.c_float(temperature), ctypes.c_float(TOP_P), tok.eos_token_id,
    )
    elapsed = time.perf_counter() - t0
    if n_gen < 0:
        raise RuntimeError(f"generate() returned error code {n_gen}")
    text = tok.decode(out_buf[:n_gen].tolist(), skip_special_tokens=False)
    rate = n_gen / elapsed if elapsed > 0 else float("inf")
    print(f"{n_gen} tokens, {elapsed:.1f}s, {rate:.1f} tok/s")
    print(f"Output: {text}")
    return n_gen, elapsed, text


# Test greedy first (temperature 0.0)
out = np.zeros(MAX_NEW_TOKENS, dtype=np.int32)
ng, dt, text = _run_generation("Greedy", 0.0, out)

# Test with temperature
out2 = np.zeros(MAX_NEW_TOKENS, dtype=np.int32)
ng2, dt2, text2 = _run_generation("Temperature=0.6", 0.6, out2)