#!/usr/bin/env python3 """ Ternary Transformer Inference Engine Full Qwen2 architecture inference using ternary (1.58-bit) linear layers with AVX-512 optimized kernels. Zero multiplications in linear layers. Architecture: DeepSeek-R1-Distill-Qwen-1.5B - 28 layers, hidden=1536, intermediate=8960 - GQA: 12 heads, 2 KV heads, head_dim=128 - SwiGLU MLP, RoPE, RMSNorm (c) 2026 OpenTransformers Ltd / Scott Bisset """ import os import json import ctypes import numpy as np from pathlib import Path import time # ============================================================ # Load C kernel # ============================================================ def load_kernel(so_path="ternary_kernel.so"): lib = ctypes.CDLL(so_path) # ternary_matvec_avx512 lib.ternary_matvec_avx512.restype = None lib.ternary_matvec_avx512.argtypes = [ ctypes.c_void_p, # pos_bits ctypes.c_void_p, # neg_bits ctypes.c_void_p, # scales ctypes.c_void_p, # x ctypes.c_void_p, # y ctypes.c_int, # out_dim ctypes.c_int, # in_dim ] # rmsnorm lib.rmsnorm_avx512.restype = None lib.rmsnorm_avx512.argtypes = [ ctypes.c_void_p, # x ctypes.c_void_p, # weight ctypes.c_void_p, # y ctypes.c_int, # dim ctypes.c_float, # eps ] # silu lib.silu_avx512.restype = None lib.silu_avx512.argtypes = [ctypes.c_void_p, ctypes.c_int] # elemwise_mul lib.elemwise_mul_avx512.restype = None lib.elemwise_mul_avx512.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int ] # softmax lib.softmax.restype = None lib.softmax.argtypes = [ctypes.c_void_p, ctypes.c_int] # rope lib.apply_rope.restype = None lib.apply_rope.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_float ] return lib # ============================================================ # Ternary Linear Layer # ============================================================ class TernaryLinear: def __init__(self, pos_bits, neg_bits, scales, out_dim, in_dim, kernel): self.pos = pos_bits # uint64 contiguous array self.neg = neg_bits self.scales = scales # float32 self.out_dim = out_dim self.in_dim = in_dim self.kernel = kernel def forward(self, x): """x: float32[in_dim] -> float32[out_dim]""" y = np.zeros(self.out_dim, dtype=np.float32) self.kernel.ternary_matvec_avx512( self.pos.ctypes.data, self.neg.ctypes.data, self.scales.ctypes.data, x.ctypes.data, y.ctypes.data, self.out_dim, self.in_dim, ) return y # ============================================================ # KV Cache # ============================================================ class KVCache: def __init__(self, n_layers, n_kv_heads, head_dim, max_seq=4096): self.n_layers = n_layers self.max_seq = max_seq # Pre-allocate self.k = [np.zeros((max_seq, n_kv_heads, head_dim), dtype=np.float32) for _ in range(n_layers)] self.v = [np.zeros((max_seq, n_kv_heads, head_dim), dtype=np.float32) for _ in range(n_layers)] self.seq_len = 0 def append(self, layer, k, v): """k, v: [n_kv_heads, head_dim]""" pos = self.seq_len self.k[layer][pos] = k self.v[layer][pos] = v def get(self, layer): """Returns k, v up to current position: [seq_len, n_kv_heads, head_dim]""" return self.k[layer][:self.seq_len + 1], self.v[layer][:self.seq_len + 1] def advance(self): self.seq_len += 1 # ============================================================ # Model # ============================================================ class TernaryQwen: def __init__(self, model_dir, kernel): self.kernel = kernel self.model_dir = model_dir with open(os.path.join(model_dir, "config.json")) as f: self.config = json.load(f) with open(os.path.join(model_dir, "manifest.json")) as f: self.manifest = json.load(f) self.hidden = self.config["hidden_size"] # 1536 self.inter = self.config["intermediate_size"] # 8960 self.n_heads = self.config["num_attention_heads"] # 12 self.n_kv = self.config["num_key_value_heads"] # 2 self.head_dim = self.config["head_dim"] # 128 self.n_layers = self.config["num_hidden_layers"] # 28 self.vocab = self.config["vocab_size"] # 151936 self.rope_theta = self.config["rope_theta"] self.eps = self.config["rms_norm_eps"] print(f"Loading ternary model: {self.n_layers} layers, " f"hidden={self.hidden}, heads={self.n_heads}/{self.n_kv}") t0 = time.time() self._load_weights() print(f"Model loaded in {time.time()-t0:.1f}s") self._compute_memory() def _load_ternary(self, key): """Load a ternary linear layer.""" prefix = os.path.join(self.model_dir, key.replace(".", "_")) shape = self.manifest["ternary"][key] out_dim, in_dim = shape chunks = (in_dim + 63) // 64 pos = np.fromfile(prefix + ".pos", dtype=np.uint64).reshape(out_dim, chunks) neg = np.fromfile(prefix + ".neg", dtype=np.uint64).reshape(out_dim, chunks) scales = np.fromfile(prefix + ".scales", dtype=np.float32) # Make contiguous pos = np.ascontiguousarray(pos) neg = np.ascontiguousarray(neg) return TernaryLinear(pos, neg, scales, out_dim, in_dim, self.kernel) def _load_fp16(self, key): """Load an FP16 tensor.""" prefix = os.path.join(self.model_dir, key.replace(".", "_")) shape = self.manifest["fp16"][key] return np.fromfile(prefix + ".fp16", dtype=np.float16).reshape(shape).astype(np.float32) def _load_weights(self): """Load all weights.""" # Embedding (FP16) self.embed = self._load_fp16("model.embed_tokens.weight") # [vocab, hidden] # Final norm self.final_norm = self._load_fp16("model.norm.weight") # [hidden] # LM head — check if it exists as ternary or fp16 if "lm_head.weight" in self.manifest.get("ternary", {}): self.lm_head = self._load_ternary("lm_head.weight") self.lm_head_ternary = True elif "lm_head.weight" in self.manifest.get("fp16", {}): self.lm_head_w = self._load_fp16("lm_head.weight") self.lm_head_ternary = False else: # Tied embeddings self.lm_head_w = self.embed self.lm_head_ternary = False # Layers self.layers = [] for i in range(self.n_layers): layer = {} prefix = f"model.layers.{i}" # Attention layer["q_proj"] = self._load_ternary(f"{prefix}.self_attn.q_proj.weight") layer["k_proj"] = self._load_ternary(f"{prefix}.self_attn.k_proj.weight") layer["v_proj"] = self._load_ternary(f"{prefix}.self_attn.v_proj.weight") layer["o_proj"] = self._load_ternary(f"{prefix}.self_attn.o_proj.weight") # MLP layer["gate_proj"] = self._load_ternary(f"{prefix}.mlp.gate_proj.weight") layer["up_proj"] = self._load_ternary(f"{prefix}.mlp.up_proj.weight") layer["down_proj"] = self._load_ternary(f"{prefix}.mlp.down_proj.weight") # Norms (FP16 -> FP32) layer["input_norm"] = self._load_fp16(f"{prefix}.input_layernorm.weight") layer["post_norm"] = self._load_fp16(f"{prefix}.post_attention_layernorm.weight") # Load biases if they exist for proj in ["q_proj", "k_proj", "v_proj"]: bias_key = f"{prefix}.self_attn.{proj}.bias" if bias_key in self.manifest.get("fp16", {}): layer[f"{proj}_bias"] = self._load_fp16(bias_key) self.layers.append(layer) if (i + 1) % 7 == 0: print(f" Loaded {i+1}/{self.n_layers} layers") print(f" Loaded {self.n_layers}/{self.n_layers} layers") def _compute_memory(self): """Report memory usage.""" ternary_bytes = 0 fp_bytes = 0 for layer in self.layers: for key in ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]: tl = layer[key] ternary_bytes += tl.pos.nbytes + tl.neg.nbytes + tl.scales.nbytes for key in ["input_norm", "post_norm"]: fp_bytes += layer[key].nbytes fp_bytes += self.embed.nbytes + self.final_norm.nbytes if not self.lm_head_ternary: fp_bytes += self.lm_head_w.nbytes if hasattr(self, 'lm_head_w') else 0 total = ternary_bytes + fp_bytes print(f"\nMemory: ternary={ternary_bytes/1024/1024:.1f}MB, " f"fp={fp_bytes/1024/1024:.1f}MB, total={total/1024/1024:.1f}MB") def _rmsnorm(self, x, weight): """RMSNorm using C kernel.""" y = np.zeros_like(x) self.kernel.rmsnorm_avx512( x.ctypes.data, weight.ctypes.data, y.ctypes.data, len(x), ctypes.c_float(self.eps) ) return y def _attention(self, x, layer, cache, layer_idx, pos): """Grouped-Query Attention.""" h = self.hidden n_h = self.n_heads n_kv = self.n_kv hd = self.head_dim # Project Q, K, V q = layer["q_proj"].forward(x) # [n_heads * head_dim] k = layer["k_proj"].forward(x) # [n_kv * head_dim] v = layer["v_proj"].forward(x) # [n_kv * head_dim] # Add biases if present if "q_proj_bias" in layer: q += layer["q_proj_bias"] if "k_proj_bias" in layer: k += layer["k_proj_bias"] if "v_proj_bias" in layer: v += layer["v_proj_bias"] # Reshape q = q.reshape(n_h, hd) k = k.reshape(n_kv, hd) v = v.reshape(n_kv, hd) # RoPE self.kernel.apply_rope( q.ctypes.data, k.ctypes.data, n_h, n_kv, hd, pos, ctypes.c_float(self.rope_theta) ) # Update KV cache cache.append(layer_idx, k, v) # Get full K, V history k_all, v_all = cache.get(layer_idx) # [seq_len, n_kv, head_dim] seq_len = k_all.shape[0] # GQA: repeat KV heads to match Q heads heads_per_kv = n_h // n_kv # Compute attention for each head output = np.zeros(n_h * hd, dtype=np.float32) scale = 1.0 / np.sqrt(hd) for head in range(n_h): kv_head = head // heads_per_kv q_h = q[head] # [head_dim] # Attention scores: q @ K^T scores = np.dot(k_all[:, kv_head, :], q_h) * scale # [seq_len] # Causal mask (all visible for single token generation) # Softmax scores_max = np.max(scores) scores = np.exp(scores - scores_max) scores /= np.sum(scores) # Weighted sum of values out_h = np.dot(scores, v_all[:, kv_head, :]) # [head_dim] output[head * hd:(head + 1) * hd] = out_h # Output projection return layer["o_proj"].forward(output) def _mlp(self, x, layer): """SwiGLU MLP.""" gate = layer["gate_proj"].forward(x) up = layer["up_proj"].forward(x) # SiLU on gate self.kernel.silu_avx512(gate.ctypes.data, len(gate)) # gate * up self.kernel.elemwise_mul_avx512( gate.ctypes.data, up.ctypes.data, gate.ctypes.data, len(gate) ) # Down projection return layer["down_proj"].forward(gate) def forward_token(self, token_id, cache, pos): """Forward pass for a single token.""" # Embedding lookup x = self.embed[token_id].copy() # [hidden] # Transformer layers for i, layer in enumerate(self.layers): # Pre-attention norm normed = self._rmsnorm(x, layer["input_norm"]) # Self-attention + residual attn_out = self._attention(normed, layer, cache, i, pos) x = x + attn_out # Pre-MLP norm normed = self._rmsnorm(x, layer["post_norm"]) # MLP + residual mlp_out = self._mlp(normed, layer) x = x + mlp_out # Final norm x = self._rmsnorm(x, self.final_norm) return x def logits(self, hidden): """Compute logits from hidden state.""" if self.lm_head_ternary: return self.lm_head.forward(hidden) else: return hidden @ self.lm_head_w.T def generate(self, token_ids, max_new_tokens=256, temperature=0.6, top_p=0.95): """Generate tokens autoregressively.""" cache = KVCache(self.n_layers, self.n_kv, self.head_dim) generated = [] all_tokens = list(token_ids) t_start = time.time() # Prefill: process all input tokens for i, tid in enumerate(token_ids): hidden = self.forward_token(tid, cache, i) if i < len(token_ids) - 1: cache.advance() t_prefill = time.time() - t_start # Decode t_decode_start = time.time() for step in range(max_new_tokens): # Get logits logit_vec = self.logits(hidden) # Sample if temperature < 0.01: next_token = int(np.argmax(logit_vec)) else: logit_vec = logit_vec / temperature # Top-p sampling sorted_idx = np.argsort(logit_vec)[::-1] sorted_logits = logit_vec[sorted_idx] # Softmax max_l = sorted_logits[0] probs = np.exp(sorted_logits - max_l) probs /= probs.sum() cumsum = np.cumsum(probs) cutoff = np.searchsorted(cumsum, top_p) + 1 top_probs = probs[:cutoff] top_probs /= top_probs.sum() top_idx = sorted_idx[:cutoff] next_token = int(np.random.choice(top_idx, p=top_probs)) generated.append(next_token) all_tokens.append(next_token) # Check stop tokens if next_token in [151643, 151644, 151645]: # Qwen EOS tokens break cache.advance() hidden = self.forward_token(next_token, cache, len(all_tokens) - 1) t_total = time.time() - t_start t_decode = time.time() - t_decode_start n_gen = len(generated) stats = { "prefill_ms": t_prefill * 1000, "decode_ms": t_decode * 1000, "total_ms": t_total * 1000, "tokens_generated": n_gen, "tok_per_sec": n_gen / t_decode if t_decode > 0 else 0, "prefill_tokens": len(token_ids), } return generated, stats # ============================================================ # Tokenizer wrapper # ============================================================ class Tokenizer: def __init__(self, model_dir): from tokenizers import Tokenizer as HFTokenizer tok_path = os.path.join(model_dir, "tokenizer.json") if os.path.exists(tok_path): self.tok = HFTokenizer.from_file(tok_path) else: # Try loading from HF from transformers import AutoTokenizer self.tok = AutoTokenizer.from_pretrained(model_dir) self._is_transformers = True return self._is_transformers = False def encode(self, text): if self._is_transformers: return self.tok.encode(text) return self.tok.encode(text).ids def decode(self, ids): if self._is_transformers: return self.tok.decode(ids, skip_special_tokens=True) return self.tok.decode(ids) def apply_chat_template(self, messages): """Build Qwen chat format.""" parts = [] for msg in messages: role = msg["role"] content = msg["content"] parts.append(f"<|im_start|>{role}\n{content}<|im_end|>") parts.append("<|im_start|>assistant\n") return "".join(parts) if __name__ == "__main__": import sys model_dir = sys.argv[1] if len(sys.argv) > 1 else "deepseek-r1-1.5b-ternary" kernel = load_kernel(os.path.join(os.path.dirname(__file__), "ternary_kernel.so")) model = TernaryQwen(model_dir, kernel) # Quick test test_ids = [151644, 8948, 198, 151645, 198, 151644, 872, 198, 9707, 151645, 198, 151644, 77091, 198] print("\nGenerating...") tokens, stats = model.generate(test_ids, max_new_tokens=50, temperature=0.6) print(f"Generated {stats['tokens_generated']} tokens") print(f"Speed: {stats['tok_per_sec']:.1f} tok/s") print(f"Prefill: {stats['prefill_ms']:.0f}ms, Decode: {stats['decode_ms']:.0f}ms") print(f"Token IDs: {tokens}")