File size: 3,753 Bytes
19ed98b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python3
"""
OpenAI-compatible API server for Ternary Transformer Engine.
Drop-in replacement for llama-server.

(c) 2026 OpenTransformers Ltd / Scott Bisset
"""

import json
import os
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer, ThreadingHTTPServer

from inference import TernaryQwen, Tokenizer, load_kernel

# Deployment configuration, overridable via environment variables.
MODEL_DIR = os.environ.get("TERNARY_MODEL_DIR", "deepseek-r1-1.5b-ternary")
TOKENIZER_DIR = os.environ.get("TOKENIZER_DIR", "deepseek-r1-1.5b-hf")
HOST = os.environ.get("HOST", "127.0.0.1")
PORT = int(os.environ.get("PORT", "8080"))

# The native kernel ships next to this script; resolve it relative to
# __file__ so the server works regardless of the current working directory.
print("Loading ternary kernel...")
kernel = load_kernel(os.path.join(os.path.dirname(__file__), "ternary_kernel.so"))

# Model and tokenizer are loaded once at import time and shared as module
# globals by every request handled by Handler below.
print(f"Loading model from {MODEL_DIR}...")
model = TernaryQwen(MODEL_DIR, kernel)

print(f"Loading tokenizer from {TOKENIZER_DIR}...")
tokenizer = Tokenizer(TOKENIZER_DIR)

# Serializes access to the model: generation is assumed not to be
# thread-safe, so concurrent requests take turns.
lock = threading.Lock()
print("Ready!")

class Handler(BaseHTTPRequestHandler):
    """Request handler exposing a minimal OpenAI-compatible API.

    Routes:
        POST /v1/chat/completions -- run chat generation on the shared model.
        GET  /health              -- liveness probe.

    Relies on module-level globals: ``model``, ``tokenizer``, ``lock``.
    """

    def _send_json(self, status: int, payload: dict) -> None:
        # Serialize once so we can emit a correct Content-Length header,
        # which the original handler omitted.
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_POST(self):
        """Handle POST /v1/chat/completions; 404 for any other path."""
        if self.path != "/v1/chat/completions":
            self.send_response(404)
            self.end_headers()
            return

        try:
            length = int(self.headers.get("Content-Length", 0))
            body = json.loads(self.rfile.read(length))
        except ValueError as e:
            # Covers both a non-numeric Content-Length and malformed JSON
            # (JSONDecodeError subclasses ValueError). Report 400 instead
            # of crashing the handler with an unhandled exception.
            self._send_json(400, {"error": {
                "message": f"invalid request body: {e}",
                "type": "invalid_request_error",
            }})
            return

        messages = body.get("messages", [])
        max_tokens = body.get("max_tokens", 256)
        temperature = body.get("temperature", 0.6)
        top_p = body.get("top_p", 0.95)

        # Build the prompt via the model's chat template, then tokenize.
        prompt = tokenizer.apply_chat_template(messages)
        input_ids = tokenizer.encode(prompt)

        # Generation is not assumed thread-safe; serialize requests.
        with lock:
            gen_ids, stats = model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )

        text = tokenizer.decode(gen_ids)

        # OpenAI semantics: "length" when the token budget was exhausted,
        # "stop" when generation ended on its own (EOS).
        finish_reason = (
            "length" if stats["tokens_generated"] >= max_tokens else "stop"
        )

        self._send_json(200, {
            "id": f"chatcmpl-ternary-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "DeepSeek-R1-Distill-Qwen-1.5B-TERNARY",
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": finish_reason,
            }],
            "usage": {
                "prompt_tokens": len(input_ids),
                "completion_tokens": stats["tokens_generated"],
                "total_tokens": len(input_ids) + stats["tokens_generated"],
            },
            # llama-server-style timing block, sourced from the engine stats.
            "timings": {
                "prompt_n": stats["prefill_tokens"],
                "prompt_ms": stats["prefill_ms"],
                "predicted_n": stats["tokens_generated"],
                "predicted_ms": stats["decode_ms"],
                "predicted_per_second": stats["tok_per_sec"],
            },
        })

    def do_GET(self):
        """Handle GET /health; 404 for any other path."""
        if self.path == "/health":
            self._send_json(200, {"status": "ok", "engine": "ternary-avx512"})
        else:
            self.send_response(404)
            self.end_headers()

    def log_message(self, format, *args):
        # Suppress the default per-request stderr logging.
        pass

if __name__ == "__main__":
    # ThreadingHTTPServer lets /health answer while a generation request
    # holds the model lock; generation itself remains serialized by the
    # module-level lock, so the model is never entered concurrently.
    server = ThreadingHTTPServer((HOST, PORT), Handler)
    print(f"Ternary engine serving on {HOST}:{PORT}")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop the server; exit quietly
        # instead of dumping a traceback.
        pass
    finally:
        # Release the listening socket on any exit path.
        server.server_close()