#!/usr/bin/env python3
"""
Ternary Transformer Inference Engine
Full Qwen2 architecture inference using ternary (1.58-bit) linear layers
with AVX-512 optimized kernels. Zero multiplications in linear layers.
Architecture: DeepSeek-R1-Distill-Qwen-1.5B
- 28 layers, hidden=1536, intermediate=8960
- GQA: 12 heads, 2 KV heads, head_dim=128
- SwiGLU MLP, RoPE, RMSNorm
(c) 2026 OpenTransformers Ltd / Scott Bisset
"""
import os
import json
import ctypes
import numpy as np
import time
# ============================================================
# Load C kernel
# ============================================================
def load_kernel(so_path="ternary_kernel.so"):
lib = ctypes.CDLL(so_path)
# ternary_matvec_avx512
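    # Computes y = W @ x for a ternary W stored as two bit-planes; assuming
    # one FP32 scale per output row (the flat .scales layout suggests this):
    #   y[i] = scales[i] * (sum_{j in pos[i]} x[j] - sum_{j in neg[i]} x[j])
    # so the hot loop is masked adds/subtracts, not multiplies.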
lib.ternary_matvec_avx512.restype = None
lib.ternary_matvec_avx512.argtypes = [
ctypes.c_void_p, # pos_bits
ctypes.c_void_p, # neg_bits
ctypes.c_void_p, # scales
ctypes.c_void_p, # x
ctypes.c_void_p, # y
ctypes.c_int, # out_dim
ctypes.c_int, # in_dim
]
# rmsnorm
lib.rmsnorm_avx512.restype = None
lib.rmsnorm_avx512.argtypes = [
ctypes.c_void_p, # x
ctypes.c_void_p, # weight
ctypes.c_void_p, # y
ctypes.c_int, # dim
ctypes.c_float, # eps
]
# silu
lib.silu_avx512.restype = None
lib.silu_avx512.argtypes = [ctypes.c_void_p, ctypes.c_int]
# elemwise_mul
lib.elemwise_mul_avx512.restype = None
lib.elemwise_mul_avx512.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int
]
# softmax
lib.softmax.restype = None
lib.softmax.argtypes = [ctypes.c_void_p, ctypes.c_int]
# rope
lib.apply_rope.restype = None
lib.apply_rope.argtypes = [
ctypes.c_void_p, ctypes.c_void_p,
ctypes.c_int, ctypes.c_int, ctypes.c_int,
ctypes.c_int, ctypes.c_float
]
return lib
# ============================================================
# Ternary Linear Layer
# ============================================================
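# Weight encoding (as implied by _load_ternary below): every weight is
# constrained to {-1, 0, +1} ("1.58-bit", since log2(3) ~ 1.585 bits).
# Each output row packs its in_dim weights into ceil(in_dim / 64) uint64
# words per bit-plane: a set bit in `pos` marks +1, in `neg` marks -1, and
# a zero weight sets neither. The matvec then reduces to masked
# adds/subtracts followed by one per-row scale.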
class TernaryLinear:
def __init__(self, pos_bits, neg_bits, scales, out_dim, in_dim, kernel):
self.pos = pos_bits # uint64 contiguous array
self.neg = neg_bits
self.scales = scales # float32
self.out_dim = out_dim
self.in_dim = in_dim
self.kernel = kernel
def forward(self, x):
"""x: float32[in_dim] -> float32[out_dim]"""
y = np.zeros(self.out_dim, dtype=np.float32)
self.kernel.ternary_matvec_avx512(
self.pos.ctypes.data,
self.neg.ctypes.data,
self.scales.ctypes.data,
x.ctypes.data,
y.ctypes.data,
self.out_dim,
self.in_dim,
)
return y
# ============================================================
# KV Cache
# ============================================================
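# Convention: seq_len counts committed tokens. append() writes the token
# currently being processed at index seq_len, get() slices :seq_len + 1 so
# attention can see that token, and advance() commits it afterwards.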
class KVCache:
def __init__(self, n_layers, n_kv_heads, head_dim, max_seq=4096):
self.n_layers = n_layers
self.max_seq = max_seq
# Pre-allocate
self.k = [np.zeros((max_seq, n_kv_heads, head_dim), dtype=np.float32) for _ in range(n_layers)]
self.v = [np.zeros((max_seq, n_kv_heads, head_dim), dtype=np.float32) for _ in range(n_layers)]
self.seq_len = 0
def append(self, layer, k, v):
"""k, v: [n_kv_heads, head_dim]"""
pos = self.seq_len
self.k[layer][pos] = k
self.v[layer][pos] = v
def get(self, layer):
"""Returns k, v up to current position: [seq_len, n_kv_heads, head_dim]"""
return self.k[layer][:self.seq_len + 1], self.v[layer][:self.seq_len + 1]
def advance(self):
self.seq_len += 1
# ============================================================
# Model
# ============================================================
class TernaryQwen:
def __init__(self, model_dir, kernel):
self.kernel = kernel
self.model_dir = model_dir
with open(os.path.join(model_dir, "config.json")) as f:
self.config = json.load(f)
with open(os.path.join(model_dir, "manifest.json")) as f:
self.manifest = json.load(f)
self.hidden = self.config["hidden_size"] # 1536
self.inter = self.config["intermediate_size"] # 8960
self.n_heads = self.config["num_attention_heads"] # 12
self.n_kv = self.config["num_key_value_heads"] # 2
self.head_dim = self.config["head_dim"] # 128
self.n_layers = self.config["num_hidden_layers"] # 28
self.vocab = self.config["vocab_size"] # 151936
self.rope_theta = self.config["rope_theta"]
self.eps = self.config["rms_norm_eps"]
print(f"Loading ternary model: {self.n_layers} layers, "
f"hidden={self.hidden}, heads={self.n_heads}/{self.n_kv}")
t0 = time.time()
self._load_weights()
print(f"Model loaded in {time.time()-t0:.1f}s")
self._compute_memory()
def _load_ternary(self, key):
"""Load a ternary linear layer."""
prefix = os.path.join(self.model_dir, key.replace(".", "_"))
shape = self.manifest["ternary"][key]
out_dim, in_dim = shape
chunks = (in_dim + 63) // 64
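        # 64 ternary weights per uint64 word, one packed stream per bit-plane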
pos = np.fromfile(prefix + ".pos", dtype=np.uint64).reshape(out_dim, chunks)
neg = np.fromfile(prefix + ".neg", dtype=np.uint64).reshape(out_dim, chunks)
scales = np.fromfile(prefix + ".scales", dtype=np.float32)
# Make contiguous
pos = np.ascontiguousarray(pos)
neg = np.ascontiguousarray(neg)
return TernaryLinear(pos, neg, scales, out_dim, in_dim, self.kernel)
def _load_fp16(self, key):
"""Load an FP16 tensor."""
prefix = os.path.join(self.model_dir, key.replace(".", "_"))
shape = self.manifest["fp16"][key]
return np.fromfile(prefix + ".fp16", dtype=np.float16).reshape(shape).astype(np.float32)
def _load_weights(self):
"""Load all weights."""
# Embedding (FP16)
self.embed = self._load_fp16("model.embed_tokens.weight") # [vocab, hidden]
# Final norm
self.final_norm = self._load_fp16("model.norm.weight") # [hidden]
        # LM head: stored as ternary, FP16, or tied to the embedding table
if "lm_head.weight" in self.manifest.get("ternary", {}):
self.lm_head = self._load_ternary("lm_head.weight")
self.lm_head_ternary = True
elif "lm_head.weight" in self.manifest.get("fp16", {}):
self.lm_head_w = self._load_fp16("lm_head.weight")
self.lm_head_ternary = False
else:
# Tied embeddings
self.lm_head_w = self.embed
self.lm_head_ternary = False
# Layers
self.layers = []
for i in range(self.n_layers):
layer = {}
prefix = f"model.layers.{i}"
# Attention
layer["q_proj"] = self._load_ternary(f"{prefix}.self_attn.q_proj.weight")
layer["k_proj"] = self._load_ternary(f"{prefix}.self_attn.k_proj.weight")
layer["v_proj"] = self._load_ternary(f"{prefix}.self_attn.v_proj.weight")
layer["o_proj"] = self._load_ternary(f"{prefix}.self_attn.o_proj.weight")
# MLP
layer["gate_proj"] = self._load_ternary(f"{prefix}.mlp.gate_proj.weight")
layer["up_proj"] = self._load_ternary(f"{prefix}.mlp.up_proj.weight")
layer["down_proj"] = self._load_ternary(f"{prefix}.mlp.down_proj.weight")
# Norms (FP16 -> FP32)
layer["input_norm"] = self._load_fp16(f"{prefix}.input_layernorm.weight")
layer["post_norm"] = self._load_fp16(f"{prefix}.post_attention_layernorm.weight")
# Load biases if they exist
for proj in ["q_proj", "k_proj", "v_proj"]:
bias_key = f"{prefix}.self_attn.{proj}.bias"
if bias_key in self.manifest.get("fp16", {}):
layer[f"{proj}_bias"] = self._load_fp16(bias_key)
self.layers.append(layer)
if (i + 1) % 7 == 0:
print(f" Loaded {i+1}/{self.n_layers} layers")
print(f" Loaded {self.n_layers}/{self.n_layers} layers")
def _compute_memory(self):
"""Report memory usage."""
ternary_bytes = 0
fp_bytes = 0
for layer in self.layers:
for key in ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"]:
tl = layer[key]
ternary_bytes += tl.pos.nbytes + tl.neg.nbytes + tl.scales.nbytes
for key in ["input_norm", "post_norm"]:
fp_bytes += layer[key].nbytes
fp_bytes += self.embed.nbytes + self.final_norm.nbytes
        if self.lm_head_ternary:
            tl = self.lm_head
            ternary_bytes += tl.pos.nbytes + tl.neg.nbytes + tl.scales.nbytes
        elif self.lm_head_w is not self.embed:
            # Tied embeddings are already counted under self.embed above
            fp_bytes += self.lm_head_w.nbytes
total = ternary_bytes + fp_bytes
print(f"\nMemory: ternary={ternary_bytes/1024/1024:.1f}MB, "
f"fp={fp_bytes/1024/1024:.1f}MB, total={total/1024/1024:.1f}MB")
def _rmsnorm(self, x, weight):
"""RMSNorm using C kernel."""
y = np.zeros_like(x)
self.kernel.rmsnorm_avx512(
x.ctypes.data, weight.ctypes.data, y.ctypes.data,
len(x), ctypes.c_float(self.eps)
)
return y
def _attention(self, x, layer, cache, layer_idx, pos):
"""Grouped-Query Attention."""
n_h = self.n_heads
n_kv = self.n_kv
hd = self.head_dim
# Project Q, K, V
q = layer["q_proj"].forward(x) # [n_heads * head_dim]
k = layer["k_proj"].forward(x) # [n_kv * head_dim]
v = layer["v_proj"].forward(x) # [n_kv * head_dim]
# Add biases if present
if "q_proj_bias" in layer:
q += layer["q_proj_bias"]
if "k_proj_bias" in layer:
k += layer["k_proj_bias"]
if "v_proj_bias" in layer:
v += layer["v_proj_bias"]
# Reshape
q = q.reshape(n_h, hd)
k = k.reshape(n_kv, hd)
v = v.reshape(n_kv, hd)
# RoPE
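        # Rotates each head's q/k in place by position-dependent angles
        # derived from rope_theta; the exact pair layout is the kernel's.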
self.kernel.apply_rope(
q.ctypes.data, k.ctypes.data,
n_h, n_kv, hd, pos,
ctypes.c_float(self.rope_theta)
)
# Update KV cache
cache.append(layer_idx, k, v)
# Get full K, V history
k_all, v_all = cache.get(layer_idx) # [seq_len, n_kv, head_dim]
# GQA: repeat KV heads to match Q heads
heads_per_kv = n_h // n_kv
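        # e.g. 12 query heads / 2 KV heads -> each KV head serves 6 query heads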
# Compute attention for each head
output = np.zeros(n_h * hd, dtype=np.float32)
scale = 1.0 / np.sqrt(hd)
for head in range(n_h):
kv_head = head // heads_per_kv
q_h = q[head] # [head_dim]
# Attention scores: q @ K^T
scores = np.dot(k_all[:, kv_head, :], q_h) * scale # [seq_len]
            # No causal mask needed: the cache only holds positions <= pos
# Softmax
scores_max = np.max(scores)
scores = np.exp(scores - scores_max)
scores /= np.sum(scores)
# Weighted sum of values
out_h = np.dot(scores, v_all[:, kv_head, :]) # [head_dim]
output[head * hd:(head + 1) * hd] = out_h
# Output projection
return layer["o_proj"].forward(output)
def _mlp(self, x, layer):
"""SwiGLU MLP."""
gate = layer["gate_proj"].forward(x)
up = layer["up_proj"].forward(x)
# SiLU on gate
self.kernel.silu_avx512(gate.ctypes.data, len(gate))
# gate * up
self.kernel.elemwise_mul_avx512(
gate.ctypes.data, up.ctypes.data, gate.ctypes.data, len(gate)
)
# Down projection
return layer["down_proj"].forward(gate)
def forward_token(self, token_id, cache, pos):
"""Forward pass for a single token."""
# Embedding lookup
x = self.embed[token_id].copy() # [hidden]
# Transformer layers
for i, layer in enumerate(self.layers):
# Pre-attention norm
normed = self._rmsnorm(x, layer["input_norm"])
# Self-attention + residual
attn_out = self._attention(normed, layer, cache, i, pos)
x = x + attn_out
# Pre-MLP norm
normed = self._rmsnorm(x, layer["post_norm"])
# MLP + residual
mlp_out = self._mlp(normed, layer)
x = x + mlp_out
# Final norm
x = self._rmsnorm(x, self.final_norm)
return x
def logits(self, hidden):
"""Compute logits from hidden state."""
if self.lm_head_ternary:
return self.lm_head.forward(hidden)
else:
return hidden @ self.lm_head_w.T
def generate(self, token_ids, max_new_tokens=256, temperature=0.6, top_p=0.95):
"""Generate tokens autoregressively."""
cache = KVCache(self.n_layers, self.n_kv, self.head_dim)
generated = []
all_tokens = list(token_ids)
t_start = time.time()
# Prefill: process all input tokens
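        # (one token at a time through the full stack; there is no batched
        # prefill path, so prefill cost scales like decode)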
for i, tid in enumerate(token_ids):
hidden = self.forward_token(tid, cache, i)
if i < len(token_ids) - 1:
cache.advance()
t_prefill = time.time() - t_start
# Decode
t_decode_start = time.time()
for step in range(max_new_tokens):
# Get logits
logit_vec = self.logits(hidden)
# Sample
if temperature < 0.01:
next_token = int(np.argmax(logit_vec))
else:
logit_vec = logit_vec / temperature
# Top-p sampling
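                # Nucleus sampling: keep the smallest prefix of the
                # probability-sorted vocabulary whose cumulative mass
                # reaches top_p, renormalize, and sample from it.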
sorted_idx = np.argsort(logit_vec)[::-1]
sorted_logits = logit_vec[sorted_idx]
# Softmax
max_l = sorted_logits[0]
probs = np.exp(sorted_logits - max_l)
probs /= probs.sum()
cumsum = np.cumsum(probs)
cutoff = np.searchsorted(cumsum, top_p) + 1
top_probs = probs[:cutoff]
top_probs /= top_probs.sum()
top_idx = sorted_idx[:cutoff]
next_token = int(np.random.choice(top_idx, p=top_probs))
generated.append(next_token)
all_tokens.append(next_token)
# Check stop tokens
            if next_token in [151643, 151644, 151645]:  # <|endoftext|>, <|im_start|>, <|im_end|>
break
cache.advance()
hidden = self.forward_token(next_token, cache, len(all_tokens) - 1)
t_total = time.time() - t_start
t_decode = time.time() - t_decode_start
n_gen = len(generated)
stats = {
"prefill_ms": t_prefill * 1000,
"decode_ms": t_decode * 1000,
"total_ms": t_total * 1000,
"tokens_generated": n_gen,
"tok_per_sec": n_gen / t_decode if t_decode > 0 else 0,
"prefill_tokens": len(token_ids),
}
return generated, stats
# ============================================================
# Tokenizer wrapper
# ============================================================
class Tokenizer:
def __init__(self, model_dir):
from tokenizers import Tokenizer as HFTokenizer
tok_path = os.path.join(model_dir, "tokenizer.json")
if os.path.exists(tok_path):
self.tok = HFTokenizer.from_file(tok_path)
else:
# Try loading from HF
from transformers import AutoTokenizer
self.tok = AutoTokenizer.from_pretrained(model_dir)
self._is_transformers = True
return
self._is_transformers = False
def encode(self, text):
if self._is_transformers:
return self.tok.encode(text)
return self.tok.encode(text).ids
def decode(self, ids):
if self._is_transformers:
return self.tok.decode(ids, skip_special_tokens=True)
return self.tok.decode(ids)
def apply_chat_template(self, messages):
"""Build Qwen chat format."""
parts = []
for msg in messages:
role = msg["role"]
content = msg["content"]
parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")
parts.append("<|im_start|>assistant\n")
return "".join(parts)
if __name__ == "__main__":
import sys
model_dir = sys.argv[1] if len(sys.argv) > 1 else "deepseek-r1-1.5b-ternary"
kernel = load_kernel(os.path.join(os.path.dirname(__file__), "ternary_kernel.so"))
model = TernaryQwen(model_dir, kernel)
# Quick test
    # A minimal hard-coded Qwen chat prompt: empty system turn, user turn,
    # then an open assistant turn (see apply_chat_template above)
    test_ids = [151644, 8948, 198, 151645, 198, 151644, 872, 198, 9707, 151645, 198, 151644, 77091, 198]
print("\nGenerating...")
tokens, stats = model.generate(test_ids, max_new_tokens=50, temperature=0.6)
print(f"Generated {stats['tokens_generated']} tokens")
print(f"Speed: {stats['tok_per_sec']:.1f} tok/s")
print(f"Prefill: {stats['prefill_ms']:.0f}ms, Decode: {stats['decode_ms']:.0f}ms")
print(f"Token IDs: {tokens}")