"""
server.py — OpenAI-compatible inference server for JuliaSLM-compressed-svd

Serves the SVD-90 compressed JuliaSLM model (4.81M params, ~4.5% smaller).
Downloads checkpoint and tokenizer from HuggingFace on first run.

SVD compression: each linear layer W ≈ A @ B (low-rank factorization),
reducing parameter count while preserving model quality.

Endpoints:
    GET  /                     -> health check / API info
    GET  /v1/models            -> list available models
    POST /v1/chat/completions  -> generate text (OpenAI format, streaming supported)
"""

import codecs
import json
import os
import time
import uuid
from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Lock

import regex
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from juliaslm_svd_model import SVDConfig, JuliaSLM_SVD

# ═══════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════

HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "LisaMegaWatts/JuliaSLM-compressed-svd")
HF_TOKENIZER_REPO = os.environ.get("HF_TOKENIZER_REPO", "LisaMegaWatts/JuliaSLM")
CHECKPOINT_NAME = os.environ.get("CHECKPOINT_NAME", "svd_SVD-90_best.pt")
PORT = int(os.environ.get("PORT", "7860"))
CKPT_DIR = "checkpoints"
MODEL_ID = "juliaslm-compressed-svd-90"

# ═══════════════════════════════════════════════════════════════════
# BPE Tokenizer (vocab.json + merges.txt)
# ═══════════════════════════════════════════════════════════════════

# Pre-tokenization pattern: splits text into contraction suffixes, letter
# runs, digit runs, punctuation runs and whitespace before BPE is applied.
GPT2_PATTERN = regex.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
    regex.UNICODE,
)


def _build_byte_to_unicode():
    """Return a mapping from every byte value (0-255) to a distinct character.

    The byte ranges listed in ``bs`` map to the character with the same code
    point; every remaining byte is assigned a code point from 256 upwards, in
    ascending byte order.
    """
    bs = list(range(0x21, 0x7F)) + list(range(0xA1, 0xAD)) + list(range(0xAE, 0x100))
    cs = list(bs)
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return {b: chr(c) for b, c in zip(bs, cs)}


BYTE_TO_UNICODE = _build_byte_to_unicode()
UNICODE_TO_BYTE = {v: k for k, v in BYTE_TO_UNICODE.items()}


class BPETokenizer:
    """Byte-level BPE tokenizer loaded from ``vocab.json`` and ``merges.txt``."""

    # Upper bound on memoised pre-tokens. The cache is keyed by request text,
    # so without a cap a long-running server grows it without limit.
    MAX_CACHE_ENTRIES = 50_000

    def __init__(self, vocab_path: str, merges_path: str):
        """Load the token -> id vocabulary and the ordered merge rules.

        Args:
            vocab_path: Path to a JSON object mapping token string to id.
            merges_path: Path to a text file with one "left right" merge per
                line; blank lines and lines starting with ``#`` are skipped.
        """
        with open(vocab_path, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)
        self.id_to_token = {v: k for k, v in self.vocab.items()}

        self.merges = []
        self.merge_rank = {}
        with open(merges_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                if len(parts) == 2:
                    pair = (parts[0], parts[1])
                    self.merges.append(pair)
                    # Earlier lines get lower ranks, i.e. higher merge priority.
                    self.merge_rank[pair] = len(self.merge_rank)
        self.cache = {}

    def _bpe_word(self, chars: list[str]) -> list[str]:
        """Repeatedly merge the lowest-ranked adjacent pair until none applies."""
        unranked = float("inf")
        tokens = list(chars)
        while len(tokens) >= 2:
            # min() returns the first pair with the lowest rank, matching a
            # left-to-right scan with a strict "<" comparison.
            best_pair = min(
                zip(tokens, tokens[1:]),
                key=lambda pair: self.merge_rank.get(pair, unranked),
            )
            if best_pair not in self.merge_rank:
                break
            a, b = best_pair
            new_tokens = []
            i = 0
            while i < len(tokens):
                if i < len(tokens) - 1 and tokens[i] == a and tokens[i + 1] == b:
                    new_tokens.append(a + b)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens
        return tokens

    def encode(self, text: str) -> list[int]:
        """Convert ``text`` to token ids.

        Tokens that are absent from the vocabulary are dropped silently.
        """
        ids = []
        for m in GPT2_PATTERN.finditer(text):
            word = m.group()
            cached = self.cache.get(word)
            if cached is not None:
                ids.extend(cached)
                continue
            chars = [BYTE_TO_UNICODE[b] for b in word.encode("utf-8")]
            tokens = self._bpe_word(chars)
            word_ids = [self.vocab[t] for t in tokens if t in self.vocab]
            if len(self.cache) < self.MAX_CACHE_ENTRIES:
                self.cache[word] = word_ids
            ids.extend(word_ids)
        return ids

    def decode_bytes(self, ids: list[int]) -> bytes:
        """Return the raw bytes for ``ids`` without UTF-8 decoding.

        Unknown ids and characters outside the byte mapping are skipped. The
        result may end in the middle of a multi-byte UTF-8 character.
        """
        text = "".join(self.id_to_token.get(i, "") for i in ids)
        return bytes(UNICODE_TO_BYTE[c] for c in text if c in UNICODE_TO_BYTE)

    def decode(self, ids: list[int]) -> str:
        """Convert token ids to text; invalid UTF-8 becomes U+FFFD."""
        return self.decode_bytes(ids).decode("utf-8", errors="replace")


# ═══════════════════════════════════════════════════════════════════
# Sampling helpers
# ═══════════════════════════════════════════════════════════════════

def _sample_logits(logits: torch.Tensor, temperature: float, top_k: int,
                   top_p: float, vocab_size: int) -> int:
    """Pick the next token id from a 1-D tensor of logits.

    Args:
        logits: Unnormalised scores, indexed by token id. Not modified.
        temperature: ``<= 0`` selects the argmax; otherwise logits are divided
            by it before sampling.
        top_k: Keep only the ``top_k`` highest logits when
            ``0 < top_k < vocab_size``; other values disable the filter.
        top_p: Nucleus threshold; values ``>= 1.0`` disable the filter.
        vocab_size: Upper bound used to decide whether ``top_k`` applies.

    Returns:
        The sampled token id.
    """
    if temperature <= 0:
        return logits.argmax().item()

    # Division allocates a new tensor, so the in-place masking below never
    # touches the caller's logits.
    logits = logits / temperature

    if 0 < top_k < vocab_size:
        topk_vals, _ = torch.topk(logits, top_k)
        logits[logits < topk_vals[-1]] = float("-inf")

    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        # Drop a token when the probability mass *before* it already reaches
        # top_p.
        remove = cum_probs - sorted_probs >= top_p
        # With top_p == 0 the test above is true for every token, which would
        # leave nothing to sample from; always keep the most likely token.
        remove[0] = False
        sorted_logits[remove] = float("-inf")
        # Undo the sort: position sorted_idx[i] receives sorted_logits[i].
        logits = sorted_logits.scatter(0, sorted_idx, sorted_logits)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()


# ═══════════════════════════════════════════════════════════════════
# Text generation with KV cache
# ═══════════════════════════════════════════════════════════════════

def _encode_prompt(tokenizer: BPETokenizer, prompt: str,
                   context_length: int) -> tuple[list[int], int]:
    """Tokenize ``prompt`` and keep only the last ``context_length`` ids.

    Returns:
        ``(ids, prompt_len)`` where ``prompt_len`` is the token count before
        truncation.
    """
    input_ids = tokenizer.encode(prompt)
    return input_ids[-context_length:], len(input_ids)


@torch.inference_mode()
def _iter_token_ids(model: JuliaSLM_SVD, ids: list[int], max_tokens: int,
                    temperature: float, top_k: int, top_p: float):
    """Yield ``(token_id, is_last)`` for each newly sampled token.

    Generation stops after ``max_tokens`` tokens or once prompt plus
    generated tokens reach ``model.config.context_length``, whichever comes
    first. A prompt that already fills the context yields nothing.

    Raises:
        ValueError: If ``ids`` is empty; there would be no final position to
            read next-token logits from.
    """
    if not ids:
        raise ValueError("prompt encodes to zero tokens; nothing to condition on")

    config = model.config
    seq_len = len(ids)
    if max_tokens <= 0 or seq_len >= config.context_length:
        return

    # Prefill: run the whole prompt once and keep the per-layer caches.
    x = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    logits, kv_caches = model(x)
    next_logits = logits[0, -1, :].float()

    for step in range(max_tokens):
        idx = _sample_logits(next_logits, temperature, top_k, top_p, config.vocab_size)
        seq_len += 1
        is_last = step + 1 >= max_tokens or seq_len >= config.context_length
        yield idx, is_last
        if is_last:
            # No further token will be sampled, so skip the forward pass.
            break
        x = torch.tensor([[idx]], dtype=torch.long, device=DEVICE)
        logits, kv_caches = model(x, kv_caches)
        next_logits = logits[0, -1, :].float()


@torch.inference_mode()
def _generate_with_count(model: JuliaSLM_SVD, tokenizer: BPETokenizer, prompt: str,
                         max_tokens: int, temperature: float, top_k: int,
                         top_p: float) -> tuple[str, int, int]:
    """Generate a completion and report exact token counts.

    Returns:
        ``(text, prompt_len, completion_len)`` where ``completion_len`` is the
        number of tokens actually sampled.
    """
    ids, prompt_len = _encode_prompt(tokenizer, prompt, model.config.context_length)
    generated_ids = [
        idx
        for idx, _ in _iter_token_ids(model, ids, max_tokens, temperature, top_k, top_p)
    ]
    return tokenizer.decode(generated_ids), prompt_len, len(generated_ids)


@torch.inference_mode()
def generate(
    model: JuliaSLM_SVD,
    tokenizer: BPETokenizer,
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
) -> tuple[str, int]:
    """Generate a completion for ``prompt`` in one call.

    Returns:
        ``(text, prompt_len)`` where ``prompt_len`` is the prompt's token
        count before truncation to the context window.

    Raises:
        ValueError: If the prompt encodes to zero tokens.
    """
    text, prompt_len, _ = _generate_with_count(
        model, tokenizer, prompt, max_tokens, temperature, top_k, top_p,
    )
    return text, prompt_len


@torch.inference_mode()
def generate_streaming(
    model: JuliaSLM_SVD,
    tokenizer: BPETokenizer,
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
):
    """Yield ``(text_piece, prompt_len)`` once per generated token.

    Tokens are byte-level, so a single token can end in the middle of a
    multi-byte UTF-8 character. Pieces are produced by an incremental decoder
    that holds back incomplete bytes until the character is complete, so a
    piece may be the empty string. Concatenating all pieces gives the same
    text as decoding every generated token at once.

    Raises:
        ValueError: If the prompt encodes to zero tokens (raised on the first
            iteration).
    """
    ids, prompt_len = _encode_prompt(tokenizer, prompt, model.config.context_length)
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    for idx, is_last in _iter_token_ids(model, ids, max_tokens, temperature, top_k, top_p):
        # final=True on the last token flushes any dangling partial character
        # as U+FFFD instead of silently dropping it.
        piece = decoder.decode(tokenizer.decode_bytes([idx]), final=is_last)
        yield piece, prompt_len


# ═══════════════════════════════════════════════════════════════════
# Download artifacts from HuggingFace
# ═══════════════════════════════════════════════════════════════════

def _download_if_missing(repo_id: str, filename: str) -> str:
    """Fetch ``filename`` from ``repo_id`` into ``CKPT_DIR`` unless present.

    Returns:
        The expected local path, ``CKPT_DIR/filename``.
    """
    local = os.path.join(CKPT_DIR, filename)
    if not os.path.isfile(local):
        print(f"Downloading {filename} from {repo_id} ...")
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=CKPT_DIR)
        sz_mb = os.path.getsize(local) / (1024 * 1024)
        print(f" -> {local} ({sz_mb:.1f} MB)")
    return local


def ensure_artifacts():
    """Make sure the checkpoint and tokenizer files exist locally.

    Returns:
        A dict with keys ``"checkpoint"``, ``"vocab.json"`` and
        ``"merges.txt"`` mapping to local file paths.
    """
    os.makedirs(CKPT_DIR, exist_ok=True)
    files = {}

    # Model checkpoint from SVD-compressed repo
    files["checkpoint"] = _download_if_missing(HF_MODEL_REPO, CHECKPOINT_NAME)

    # Tokenizer from original JuliaSLM repo
    for fname in ("vocab.json", "merges.txt"):
        files[fname] = _download_if_missing(HF_TOKENIZER_REPO, fname)

    return files


# ═══════════════════════════════════════════════════════════════════
# Load model
# ═══════════════════════════════════════════════════════════════════

print("Downloading artifacts...")
ARTIFACT_PATHS = ensure_artifacts()

print("\nLoading SVD-compressed model...")
state_dict = torch.load(ARTIFACT_PATHS["checkpoint"], map_location="cpu", weights_only=True)

# Build config from checkpoint (auto-detects ranks per layer)
CONFIG = SVDConfig.from_checkpoint(state_dict)
MODEL = JuliaSLM_SVD(CONFIG)
# NOTE(review): strict=False tolerates key mismatches between the checkpoint
# and the model — confirm that is intended rather than masking a bad file.
MODEL.load_state_dict(state_dict, strict=False)
MODEL.eval()
DEVICE = torch.device("cpu")

print("Loading tokenizer...")
TOKENIZER = BPETokenizer(
    ARTIFACT_PATHS["vocab.json"],
    ARTIFACT_PATHS["merges.txt"],
)

MODEL_CREATED_AT = int(time.time())
NUM_PARAMS = MODEL.num_parameters

print(
    f"\nSVD-compressed model ready: vocab={CONFIG.vocab_size}, d_model={CONFIG.d_model}, "
    f"layers={CONFIG.n_layers}, heads={CONFIG.n_heads}, "
    f"ctx={CONFIG.context_length}, params={NUM_PARAMS:,}"
)
print("SVD-90 compression: ~4.5% parameter reduction")
print("KV cache enabled: O(1) per-token decoding")

# Serialises generation so only one request uses the model at a time.
MODEL_LOCK = Lock()

# ═══════════════════════════════════════════════════════════════════
# HTTP helpers
# ═══════════════════════════════════════════════════════════════════

CORS_HEADERS = {
    "Access-Control-Allow-Origin": "*",
    "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
    "Access-Control-Allow-Headers": "Content-Type, Authorization",
}


def _content_to_text(content) -> str:
    """Normalise a message ``content`` value to plain text.

    Accepts a string, or a list whose items are strings or dicts carrying a
    string ``"text"`` field (the multi-part content form). Anything else,
    including ``None``, becomes the empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                parts.append(part["text"])
        return "".join(parts)
    return ""


def extract_prompt(messages):
    """Return the text of the most recent user message.

    Falls back to the content of the last message when no message has role
    ``"user"``. Returns ``""`` for an empty or malformed ``messages`` value.
    """
    if not isinstance(messages, list) or not messages:
        return ""
    for msg in reversed(messages):
        if isinstance(msg, dict) and msg.get("role") == "user":
            return _content_to_text(msg.get("content", ""))
    last = messages[-1]
    if isinstance(last, dict):
        return _content_to_text(last.get("content", ""))
    return ""


def _request_value(body: dict, key: str, default, cast):
    """Read ``body[key]`` through ``cast``, using ``default`` if absent or null.

    Raises:
        TypeError, ValueError, OverflowError: If the value cannot be converted.
    """
    value = body.get(key)
    if value is None:
        return default
    return cast(value)


# ═══════════════════════════════════════════════════════════════════
# Request handler
# ═══════════════════════════════════════════════════════════════════

class Handler(BaseHTTPRequestHandler):
    """HTTP handler implementing the info, model-list and chat endpoints."""

    def log_message(self, format, *args):
        """Write access-log lines to stdout."""
        print(f"[{self.log_date_time_string()}] {format % args}")

    def _send_json(self, status, body):
        """Send ``body`` as a JSON response with CORS headers."""
        data = json.dumps(body).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def _send_error(self, status, message, code):
        """Send an error payload of type ``invalid_request_error``."""
        self._send_json(status, {"error": {
            "message": message,
            "type": "invalid_request_error",
            "code": code,
        }})

    def do_OPTIONS(self):
        """Answer CORS preflight requests."""
        self.send_response(204)
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.end_headers()

    def do_GET(self):
        """Serve ``/`` (API info) and ``/v1/models``; 404 otherwise."""
        if self.path == "/":
            self._send_json(200, {
                "name": "JuliaSLM-compressed-svd",
                "version": "1.0.0",
                "description": "SVD-compressed JuliaSLM — low-rank factorized weight matrices for efficient inference",
                "architecture": "MHA + RoPE + SwiGLU + RMSNorm + weight tying + SVD compression",
                "compression": {
                    "method": "SVD-90",
                    "original_params": 5_040_000,
                    "compressed_params": NUM_PARAMS,
                    "reduction_pct": round((1 - NUM_PARAMS / 5_040_000) * 100, 1),
                    "val_loss": 3.756,
                    "original_val_loss": 3.552,
                },
                "model": {
                    "vocab_size": CONFIG.vocab_size,
                    "d_model": CONFIG.d_model,
                    "n_layers": CONFIG.n_layers,
                    "n_heads": CONFIG.n_heads,
                    "context_length": CONFIG.context_length,
                    "parameters": NUM_PARAMS,
                },
                "endpoints": ["/v1/models", "/v1/chat/completions"],
                "features": ["streaming", "OpenAI-compatible", "top-k", "top-p", "kv-cache"],
                "compatible_with": ["OpenAI API", "OpenRouter"],
            })
        elif self.path == "/v1/models":
            self._send_json(200, {
                "object": "list",
                "data": [{
                    "id": MODEL_ID,
                    "object": "model",
                    "created": MODEL_CREATED_AT,
                    "owned_by": "juliaslm",
                }],
            })
        else:
            self._send_error(404, f"Not found: GET {self.path}", "not_found")

    def do_POST(self):
        """Handle ``POST /v1/chat/completions``.

        Malformed requests (bad JSON, non-object body, unparseable sampling
        parameters, empty prompt) are answered with HTTP 400 instead of
        raising inside the handler.
        """
        if self.path != "/v1/chat/completions":
            self._send_error(404, f"Not found: POST {self.path}", "not_found")
            return

        try:
            content_length = int(self.headers.get("Content-Length", 0))
            if content_length < 0:
                # read(-1) would block until the client closes the socket.
                raise ValueError("negative Content-Length")
            body = json.loads(self.rfile.read(content_length))
        except (json.JSONDecodeError, ValueError):
            self._send_error(400, "Invalid JSON in request body", "invalid_json")
            return

        if not isinstance(body, dict):
            self._send_error(400, "Request body must be a JSON object", "invalid_request")
            return

        try:
            temperature = max(0.0, min(2.0, _request_value(body, "temperature", 0.8, float)))
            max_tokens = max(1, min(CONFIG.context_length, _request_value(body, "max_tokens", 200, int)))
            top_k_val = max(0, min(CONFIG.vocab_size, _request_value(body, "top_k", 40, int)))
            top_p_val = max(0.0, min(1.0, _request_value(body, "top_p", 1.0, float)))
        except (TypeError, ValueError, OverflowError):
            self._send_error(
                400,
                "temperature, max_tokens, top_k and top_p must be numbers",
                "invalid_parameter",
            )
            return

        stream = bool(body.get("stream", False))
        messages = body.get("messages", [])
        prompt_text = extract_prompt(messages)
        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created = int(time.time())

        with MODEL_LOCK:
            # Reject before any response headers go out: generation needs at
            # least one prompt token to read next-token logits from.
            if not TOKENIZER.encode(prompt_text):
                self._send_error(
                    400,
                    "messages must contain non-empty text content",
                    "empty_prompt",
                )
                return
            if stream:
                self._handle_stream(
                    prompt_text, max_tokens, temperature, top_k_val, top_p_val,
                    completion_id, created,
                )
            else:
                self._handle_non_stream(
                    prompt_text, max_tokens, temperature, top_k_val, top_p_val,
                    completion_id, created,
                )

    def _handle_stream(self, prompt_text, max_tokens, temperature, top_k, top_p,
                       completion_id, created):
        """Stream the completion as server-sent ``chat.completion.chunk`` events."""
        self.send_response(200)
        self.send_header("Content-Type", "text/event-stream")
        self.send_header("Cache-Control", "no-cache")
        self.send_header("X-Accel-Buffering", "no")
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.end_headers()

        def sse(data):
            self.wfile.write(f"data: {json.dumps(data)}\n\n".encode())
            self.wfile.flush()

        sse({
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": MODEL_ID,
            "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
        })

        token_count = 0
        prompt_tokens = 0
        for token_str, p_len in generate_streaming(
            MODEL, TOKENIZER, prompt_text,
            max_tokens=max_tokens, temperature=temperature,
            top_k=top_k, top_p=top_p,
        ):
            token_count += 1
            prompt_tokens = p_len
            if not token_str:
                # Token ended mid-character; its text arrives with a later one.
                continue
            sse({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {"content": token_str}, "finish_reason": None}],
            })

        sse({
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": MODEL_ID,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "length" if token_count >= max_tokens else "stop"}],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": token_count,
                "total_tokens": prompt_tokens + token_count,
            },
        })
        self.wfile.write(b"data: [DONE]\n\n")
        self.wfile.flush()

    def _handle_non_stream(self, prompt_text, max_tokens, temperature, top_k, top_p,
                           completion_id, created):
        """Generate the full completion and send one ``chat.completion`` object."""
        # Use the number of tokens actually sampled. Re-encoding the decoded
        # text can give a different count and hence a wrong finish_reason.
        text, prompt_tokens, completion_tokens = _generate_with_count(
            MODEL, TOKENIZER, prompt_text,
            max_tokens=max_tokens, temperature=temperature,
            top_k=top_k, top_p=top_p,
        )
        finish_reason = "length" if completion_tokens >= max_tokens else "stop"
        self._send_json(200, {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": MODEL_ID,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": finish_reason,
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            "system_fingerprint": "juliaslm-svd90-v1",
        })


# ═══════════════════════════════════════════════════════════════════
# Start server
# ═══════════════════════════════════════════════════════════════════

if __name__ == "__main__":
    print(f"\nJuliaSLM-compressed-svd server starting on 0.0.0.0:{PORT} ...")
    print(f" GET http://localhost:{PORT}/")
    print(f" GET http://localhost:{PORT}/v1/models")
    print(f" POST http://localhost:{PORT}/v1/chat/completions")
    print(f" POST http://localhost:{PORT}/v1/chat/completions (stream=true)")
    print()
    server = HTTPServer(("0.0.0.0", PORT), Handler)
    server.serve_forever()