| | |
| | """ |
| | Ternary Transformer Inference Engine |
| | |
| | Full Qwen2 architecture inference using ternary (1.58-bit) linear layers |
| | with AVX-512 optimized kernels. Zero multiplications in linear layers. |
| | |
| | Architecture: DeepSeek-R1-Distill-Qwen-1.5B |
| | - 28 layers, hidden=1536, intermediate=8960 |
| | - GQA: 12 heads, 2 KV heads, head_dim=128 |
| | - SwiGLU MLP, RoPE, RMSNorm |
| | |
| | (c) 2026 OpenTransformers Ltd / Scott Bisset |
| | """ |
| |
|
| | import os |
| | import json |
| | import ctypes |
| | import numpy as np |
| | from pathlib import Path |
| | import time |
| |
|
| | |
| | |
| | |
def load_kernel(so_path="ternary_kernel.so"):
    """Load the AVX-512 ternary kernel shared library.

    Declares the ctypes signature (argtypes/restype) of every exported
    function so later calls marshal arguments correctly. Returns the
    ctypes.CDLL handle.
    """
    lib = ctypes.CDLL(so_path)

    # Shorthand aliases keep the signature table readable.
    ptr, i32, f32 = ctypes.c_void_p, ctypes.c_int, ctypes.c_float

    # name -> argtypes; every kernel returns void.
    signatures = {
        # pos_bits, neg_bits, scales, x, y, out_dim, in_dim
        "ternary_matvec_avx512": [ptr, ptr, ptr, ptr, ptr, i32, i32],
        # x, weight, y, n, eps
        "rmsnorm_avx512": [ptr, ptr, ptr, i32, f32],
        # x (in place), n
        "silu_avx512": [ptr, i32],
        # a, b, out, n
        "elemwise_mul_avx512": [ptr, ptr, ptr, i32],
        # x (in place), n
        "softmax": [ptr, i32],
        # q, k, n_heads, n_kv_heads, head_dim, pos, theta
        "apply_rope": [ptr, ptr, i32, i32, i32, i32, f32],
    }

    for fn_name, argtypes in signatures.items():
        fn = getattr(lib, fn_name)
        fn.restype = None
        fn.argtypes = argtypes

    return lib
| |
|
| | |
| | |
| | |
class TernaryLinear:
    """A ternary (weights in {-1, 0, +1}) linear layer backed by the C kernel.

    Weights are stored as two uint64 bit-plane matrices (`pos` marks +1
    positions, `neg` marks -1 positions; each uint64 chunk covers 64 input
    columns) plus per-row float32 scales.
    """

    def __init__(self, pos_bits, neg_bits, scales, out_dim, in_dim, kernel):
        # pos_bits / neg_bits: uint64[out_dim, ceil(in_dim / 64)] bit planes.
        self.pos = pos_bits
        self.neg = neg_bits
        # scales: float32 per-row dequantization factors.
        self.scales = scales
        self.out_dim = out_dim
        self.in_dim = in_dim
        # ctypes handle from load_kernel().
        self.kernel = kernel

    def forward(self, x):
        """x: float32[in_dim] -> float32[out_dim]"""
        # The C kernel reads raw float32 memory through x.ctypes.data, so a
        # float64 or non-contiguous input would silently produce garbage.
        # Normalize here; this is a no-op for the common contiguous-float32
        # case, so the hot path pays nothing.
        x = np.ascontiguousarray(x, dtype=np.float32)
        y = np.zeros(self.out_dim, dtype=np.float32)
        self.kernel.ternary_matvec_avx512(
            self.pos.ctypes.data,
            self.neg.ctypes.data,
            self.scales.ctypes.data,
            x.ctypes.data,
            y.ctypes.data,
            self.out_dim,
            self.in_dim,
        )
        return y
| |
|
| | |
| | |
| | |
class KVCache:
    """Preallocated per-layer key/value cache for autoregressive decoding.

    `seq_len` counts *committed* positions. `append` writes at the current
    (uncommitted) position, `advance` commits it, and `get` exposes history
    plus the entry just written — so attention sees the token currently
    being processed.
    """

    def __init__(self, n_layers, n_kv_heads, head_dim, max_seq=4096):
        self.n_layers = n_layers
        self.max_seq = max_seq
        # One float32 [max_seq, n_kv_heads, head_dim] buffer per layer.
        self.k = [np.zeros((max_seq, n_kv_heads, head_dim), dtype=np.float32)
                  for _ in range(n_layers)]
        self.v = [np.zeros((max_seq, n_kv_heads, head_dim), dtype=np.float32)
                  for _ in range(n_layers)]
        self.seq_len = 0

    def append(self, layer, k, v):
        """Write k, v ([n_kv_heads, head_dim]) at the current position."""
        pos = self.seq_len
        # Fail with a clear message instead of letting numpy raise an opaque
        # IndexError deep inside the forward pass when the cache overflows.
        if pos >= self.max_seq:
            raise IndexError(
                f"KV cache full: position {pos} >= max_seq={self.max_seq}"
            )
        self.k[layer][pos] = k
        self.v[layer][pos] = v

    def get(self, layer):
        """Returns k, v up to current position: [seq_len, n_kv_heads, head_dim]"""
        # +1 so the entry just written by append() (not yet committed by
        # advance()) is included in the attention window.
        return self.k[layer][:self.seq_len + 1], self.v[layer][:self.seq_len + 1]

    def advance(self):
        """Commit the current position."""
        self.seq_len += 1
| |
|
| | |
| | |
| | |
class TernaryQwen:
    """Qwen2-architecture transformer running ternary (1.58-bit) linears.

    Single-sequence, single-token (batch=1) inference. All per-layer
    projections (q/k/v/o, gate/up/down) — and optionally the LM head — go
    through the bit-packed ternary matvec kernel; embeddings, norm weights
    and q/k/v biases stay in floating point.
    """

    def __init__(self, model_dir, kernel):
        """model_dir: directory holding config.json, manifest.json and the
        per-tensor weight files; kernel: ctypes handle from load_kernel()."""
        self.kernel = kernel
        self.model_dir = model_dir

        # config.json is the standard HF model config; manifest.json maps
        # tensor names to storage format ("ternary" / "fp16") and shapes.
        with open(os.path.join(model_dir, "config.json")) as f:
            self.config = json.load(f)
        with open(os.path.join(model_dir, "manifest.json")) as f:
            self.manifest = json.load(f)

        self.hidden = self.config["hidden_size"]
        self.inter = self.config["intermediate_size"]
        self.n_heads = self.config["num_attention_heads"]
        self.n_kv = self.config["num_key_value_heads"]
        # NOTE(review): assumes config.json carries an explicit "head_dim"
        # key — some HF configs omit it; confirm for the target checkpoint.
        self.head_dim = self.config["head_dim"]
        self.n_layers = self.config["num_hidden_layers"]
        self.vocab = self.config["vocab_size"]
        self.rope_theta = self.config["rope_theta"]
        self.eps = self.config["rms_norm_eps"]

        print(f"Loading ternary model: {self.n_layers} layers, "
              f"hidden={self.hidden}, heads={self.n_heads}/{self.n_kv}")

        t0 = time.time()
        self._load_weights()
        print(f"Model loaded in {time.time()-t0:.1f}s")

        self._compute_memory()

    def _load_ternary(self, key):
        """Load a ternary linear layer.

        Tensor "a.b.c" lives in files a_b_c.pos / .neg / .scales: two uint64
        bit-plane matrices (+1 and -1 weight positions, 64 input columns per
        chunk) plus per-row float32 scales.
        """
        prefix = os.path.join(self.model_dir, key.replace(".", "_"))
        shape = self.manifest["ternary"][key]
        out_dim, in_dim = shape
        # Number of 64-wide uint64 chunks covering the input dimension.
        chunks = (in_dim + 63) // 64

        pos = np.fromfile(prefix + ".pos", dtype=np.uint64).reshape(out_dim, chunks)
        neg = np.fromfile(prefix + ".neg", dtype=np.uint64).reshape(out_dim, chunks)
        scales = np.fromfile(prefix + ".scales", dtype=np.float32)

        # The C kernel reads raw pointers, so the bit planes must be contiguous.
        pos = np.ascontiguousarray(pos)
        neg = np.ascontiguousarray(neg)

        return TernaryLinear(pos, neg, scales, out_dim, in_dim, self.kernel)

    def _load_fp16(self, key):
        """Load an FP16 tensor, upcast to float32 for the kernels."""
        prefix = os.path.join(self.model_dir, key.replace(".", "_"))
        shape = self.manifest["fp16"][key]
        return np.fromfile(prefix + ".fp16", dtype=np.float16).reshape(shape).astype(np.float32)

    def _load_weights(self):
        """Load all weights: embeddings, norms, LM head and per-layer tensors."""
        # Token embedding table [vocab, hidden].
        self.embed = self._load_fp16("model.embed_tokens.weight")

        # Final RMSNorm weight.
        self.final_norm = self._load_fp16("model.norm.weight")

        # LM head: ternary if quantized, fp16 if stored separately,
        # otherwise tied to the embedding table.
        if "lm_head.weight" in self.manifest.get("ternary", {}):
            self.lm_head = self._load_ternary("lm_head.weight")
            self.lm_head_ternary = True
        elif "lm_head.weight" in self.manifest.get("fp16", {}):
            self.lm_head_w = self._load_fp16("lm_head.weight")
            self.lm_head_ternary = False
        else:
            # Tied weights: reuse the embedding matrix as the output head.
            self.lm_head_w = self.embed
            self.lm_head_ternary = False

        # Per-layer weights, each layer a dict of named tensors/linears.
        self.layers = []
        for i in range(self.n_layers):
            layer = {}
            prefix = f"model.layers.{i}"

            # Attention projections (ternary).
            layer["q_proj"] = self._load_ternary(f"{prefix}.self_attn.q_proj.weight")
            layer["k_proj"] = self._load_ternary(f"{prefix}.self_attn.k_proj.weight")
            layer["v_proj"] = self._load_ternary(f"{prefix}.self_attn.v_proj.weight")
            layer["o_proj"] = self._load_ternary(f"{prefix}.self_attn.o_proj.weight")

            # SwiGLU MLP projections (ternary).
            layer["gate_proj"] = self._load_ternary(f"{prefix}.mlp.gate_proj.weight")
            layer["up_proj"] = self._load_ternary(f"{prefix}.mlp.up_proj.weight")
            layer["down_proj"] = self._load_ternary(f"{prefix}.mlp.down_proj.weight")

            # RMSNorm weights (fp16 -> float32).
            layer["input_norm"] = self._load_fp16(f"{prefix}.input_layernorm.weight")
            layer["post_norm"] = self._load_fp16(f"{prefix}.post_attention_layernorm.weight")

            # Optional q/k/v biases — loaded only if the manifest has them.
            for proj in ["q_proj", "k_proj", "v_proj"]:
                bias_key = f"{prefix}.self_attn.{proj}.bias"
                if bias_key in self.manifest.get("fp16", {}):
                    layer[f"{proj}_bias"] = self._load_fp16(bias_key)

            self.layers.append(layer)
            # Progress report every 7 layers (28 layers -> 4 updates).
            if (i + 1) % 7 == 0:
                print(f" Loaded {i+1}/{self.n_layers} layers")

        print(f" Loaded {self.n_layers}/{self.n_layers} layers")

    def _compute_memory(self):
        """Report memory usage."""
        ternary_bytes = 0
        fp_bytes = 0

        # Ternary layers: bit planes + per-row scales.
        for layer in self.layers:
            for key in ["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"]:
                tl = layer[key]
                ternary_bytes += tl.pos.nbytes + tl.neg.nbytes + tl.scales.nbytes
            for key in ["input_norm", "post_norm"]:
                fp_bytes += layer[key].nbytes

        fp_bytes += self.embed.nbytes + self.final_norm.nbytes
        if not self.lm_head_ternary:
            # NOTE(review): with tied weights lm_head_w *is* self.embed, so
            # the embedding table gets counted twice here.
            fp_bytes += self.lm_head_w.nbytes if hasattr(self, 'lm_head_w') else 0

        total = ternary_bytes + fp_bytes
        print(f"\nMemory: ternary={ternary_bytes/1024/1024:.1f}MB, "
              f"fp={fp_bytes/1024/1024:.1f}MB, total={total/1024/1024:.1f}MB")

    def _rmsnorm(self, x, weight):
        """RMSNorm using C kernel. Returns a new array; x is untouched."""
        y = np.zeros_like(x)
        self.kernel.rmsnorm_avx512(
            x.ctypes.data, weight.ctypes.data, y.ctypes.data,
            len(x), ctypes.c_float(self.eps)
        )
        return y

    def _attention(self, x, layer, cache, layer_idx, pos):
        """Grouped-Query Attention for one token at sequence position `pos`."""
        h = self.hidden  # (unused; kept for readability)
        n_h = self.n_heads
        n_kv = self.n_kv
        hd = self.head_dim

        # Ternary projections: q -> [n_h*hd], k/v -> [n_kv*hd].
        q = layer["q_proj"].forward(x)
        k = layer["k_proj"].forward(x)
        v = layer["v_proj"].forward(x)

        # Float biases, present only if loaded from the manifest.
        if "q_proj_bias" in layer:
            q += layer["q_proj_bias"]
        if "k_proj_bias" in layer:
            k += layer["k_proj_bias"]
        if "v_proj_bias" in layer:
            v += layer["v_proj_bias"]

        # Split flat vectors into per-head views.
        q = q.reshape(n_h, hd)
        k = k.reshape(n_kv, hd)
        v = v.reshape(n_kv, hd)

        # Rotary position embedding, applied in place to q and k by the kernel.
        self.kernel.apply_rope(
            q.ctypes.data, k.ctypes.data,
            n_h, n_kv, hd, pos,
            ctypes.c_float(self.rope_theta)
        )

        # Store this position's k/v in the cache...
        cache.append(layer_idx, k, v)

        # ...then read back the full window (history + current token).
        k_all, v_all = cache.get(layer_idx)
        seq_len = k_all.shape[0]

        # GQA: each kv head serves n_h // n_kv query heads.
        heads_per_kv = n_h // n_kv

        output = np.zeros(n_h * hd, dtype=np.float32)
        scale = 1.0 / np.sqrt(hd)

        for head in range(n_h):
            kv_head = head // heads_per_kv
            q_h = q[head]

            # Scaled dot-product scores against all cached keys: [seq_len].
            scores = np.dot(k_all[:, kv_head, :], q_h) * scale

            # Numerically-stable softmax in numpy (the C `softmax` kernel is
            # bound in load_kernel but not used on this path).
            scores_max = np.max(scores)
            scores = np.exp(scores - scores_max)
            scores /= np.sum(scores)

            # Weighted sum of cached values -> [head_dim].
            out_h = np.dot(scores, v_all[:, kv_head, :])
            output[head * hd:(head + 1) * hd] = out_h

        # Output projection back to the hidden size.
        return layer["o_proj"].forward(output)

    def _mlp(self, x, layer):
        """SwiGLU MLP: down_proj( silu(gate_proj(x)) * up_proj(x) )."""
        gate = layer["gate_proj"].forward(x)
        up = layer["up_proj"].forward(x)

        # SiLU applied in place to the gate activations.
        self.kernel.silu_avx512(gate.ctypes.data, len(gate))

        # Elementwise gate *= up (result written back into gate's buffer).
        self.kernel.elemwise_mul_avx512(
            gate.ctypes.data, up.ctypes.data, gate.ctypes.data, len(gate)
        )

        return layer["down_proj"].forward(gate)

    def forward_token(self, token_id, cache, pos):
        """Forward pass for a single token.

        Returns the final-normed hidden state [hidden], ready for logits().
        Side effect: appends this position's k/v to `cache` in every layer.
        """
        # copy() so residual updates don't mutate the embedding table row.
        x = self.embed[token_id].copy()

        for i, layer in enumerate(self.layers):
            # Pre-norm attention block with residual connection.
            normed = self._rmsnorm(x, layer["input_norm"])

            attn_out = self._attention(normed, layer, cache, i, pos)
            x = x + attn_out

            # Pre-norm MLP block with residual connection.
            normed = self._rmsnorm(x, layer["post_norm"])

            mlp_out = self._mlp(normed, layer)
            x = x + mlp_out

        # Final RMSNorm before the LM head.
        x = self._rmsnorm(x, self.final_norm)

        return x

    def logits(self, hidden):
        """Compute vocab logits [vocab_size] from the hidden state."""
        if self.lm_head_ternary:
            return self.lm_head.forward(hidden)
        else:
            # fp16 or tied-embedding head: plain float matvec.
            return hidden @ self.lm_head_w.T

    def generate(self, token_ids, max_new_tokens=256, temperature=0.6, top_p=0.95):
        """Generate tokens autoregressively.

        Returns (generated_ids, stats) where stats holds prefill/decode
        timings and throughput. temperature < 0.01 selects greedy decoding;
        otherwise nucleus (top-p) sampling is used.
        """
        cache = KVCache(self.n_layers, self.n_kv, self.head_dim)

        generated = []
        all_tokens = list(token_ids)

        t_start = time.time()

        # Prefill: run the prompt token by token. The cache is deliberately
        # NOT advanced after the last prompt token, so the decode loop's
        # advance() before the next forward keeps positions aligned.
        for i, tid in enumerate(token_ids):
            hidden = self.forward_token(tid, cache, i)
            if i < len(token_ids) - 1:
                cache.advance()

        t_prefill = time.time() - t_start

        # Decode loop: sample a token, then feed it back in.
        t_decode_start = time.time()
        for step in range(max_new_tokens):
            logit_vec = self.logits(hidden)

            if temperature < 0.01:
                # Greedy decoding.
                next_token = int(np.argmax(logit_vec))
            else:
                logit_vec = logit_vec / temperature

                # Sort logits descending for nucleus filtering.
                sorted_idx = np.argsort(logit_vec)[::-1]
                sorted_logits = logit_vec[sorted_idx]

                # Numerically-stable softmax over the sorted logits.
                max_l = sorted_logits[0]
                probs = np.exp(sorted_logits - max_l)
                probs /= probs.sum()

                # Keep the smallest prefix whose cumulative mass covers top_p,
                # then renormalize and sample from it.
                cumsum = np.cumsum(probs)
                cutoff = np.searchsorted(cumsum, top_p) + 1

                top_probs = probs[:cutoff]
                top_probs /= top_probs.sum()
                top_idx = sorted_idx[:cutoff]

                next_token = int(np.random.choice(top_idx, p=top_probs))

            generated.append(next_token)
            all_tokens.append(next_token)

            # Stop tokens — presumably Qwen's special tokens
            # (<|endoftext|> / <|im_start|> / <|im_end|>); confirm against
            # the tokenizer config.
            if next_token in [151643, 151644, 151645]:
                break

            cache.advance()
            hidden = self.forward_token(next_token, cache, len(all_tokens) - 1)

        t_total = time.time() - t_start
        t_decode = time.time() - t_decode_start
        n_gen = len(generated)

        stats = {
            "prefill_ms": t_prefill * 1000,
            "decode_ms": t_decode * 1000,
            "total_ms": t_total * 1000,
            "tokens_generated": n_gen,
            "tok_per_sec": n_gen / t_decode if t_decode > 0 else 0,
            "prefill_tokens": len(token_ids),
        }

        return generated, stats
| |
|
| | |
| | |
| | |
class Tokenizer:
    """Tokenizer wrapper: prefers a `tokenizers` fast tokenizer loaded from
    tokenizer.json, falling back to a `transformers` AutoTokenizer when
    that file is absent."""

    def __init__(self, model_dir):
        from tokenizers import Tokenizer as HFTokenizer
        tok_path = os.path.join(model_dir, "tokenizer.json")
        if not os.path.exists(tok_path):
            # No tokenizer.json: fall back to the transformers loader.
            from transformers import AutoTokenizer
            self.tok = AutoTokenizer.from_pretrained(model_dir)
            self._is_transformers = True
            return
        self.tok = HFTokenizer.from_file(tok_path)
        self._is_transformers = False

    def encode(self, text):
        encoded = self.tok.encode(text)
        # transformers returns ids directly; tokenizers returns an Encoding.
        return encoded if self._is_transformers else encoded.ids

    def decode(self, ids):
        if not self._is_transformers:
            return self.tok.decode(ids)
        return self.tok.decode(ids, skip_special_tokens=True)

    def apply_chat_template(self, messages):
        """Build Qwen chat format."""
        rendered = [
            f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>"
            for msg in messages
        ]
        rendered.append("<|im_start|>assistant\n")
        return "".join(rendered)
| |
|
if __name__ == "__main__":
    import sys

    # Model directory from argv[1], with a default checkpoint name.
    target_dir = sys.argv[1] if len(sys.argv) > 1 else "deepseek-r1-1.5b-ternary"
    so_file = os.path.join(os.path.dirname(__file__), "ternary_kernel.so")
    lib = load_kernel(so_file)

    engine = TernaryQwen(target_dir, lib)

    # Pre-tokenized Qwen chat prompt (system + user turns).
    prompt_ids = [151644, 8948, 198, 151645, 198, 151644, 872, 198, 9707, 151645, 198, 151644, 77091, 198]

    print("\nGenerating...")
    tokens, stats = engine.generate(prompt_ids, max_new_tokens=50, temperature=0.6)
    print(f"Generated {stats['tokens_generated']} tokens")
    print(f"Speed: {stats['tok_per_sec']:.1f} tok/s")
    print(f"Prefill: {stats['prefill_ms']:.0f}ms, Decode: {stats['decode_ms']:.0f}ms")
    print(f"Token IDs: {tokens}")
| |
|