| | """ |
server.py — OpenAI-compatible inference server for JuliaSLM-compressed-svd
| | |
| | Serves the SVD-90 compressed JuliaSLM model (4.81M params, ~4.5% smaller). |
| | Downloads checkpoint and tokenizer from HuggingFace on first run. |
| | |
SVD compression: each linear layer W ≈ A @ B (low-rank factorization),
| | reducing parameter count while preserving model quality. |
| | |
| | Endpoints: |
| | GET / -> health check / API info |
| | GET /v1/models -> list available models |
| | POST /v1/chat/completions -> generate text (OpenAI format, streaming supported) |
| | """ |
| |
|
| | import json |
| | import os |
| | import regex |
| | import time |
| | import uuid |
| | from http.server import HTTPServer, BaseHTTPRequestHandler |
| | from threading import Lock |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from huggingface_hub import hf_hub_download |
| |
|
| | from juliaslm_svd_model import SVDConfig, JuliaSLM_SVD |
| |
|
| | |
| | |
| | |
| |
|
# Deployment configuration — every value is overridable via environment variables.
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "LisaMegaWatts/JuliaSLM-compressed-svd")
HF_TOKENIZER_REPO = os.environ.get("HF_TOKENIZER_REPO", "LisaMegaWatts/JuliaSLM")
CHECKPOINT_NAME = os.environ.get("CHECKPOINT_NAME", "svd_SVD-90_best.pt")
PORT = int(os.environ.get("PORT", "7860"))
CKPT_DIR = "checkpoints"  # local directory for downloaded artifacts
MODEL_ID = "juliaslm-compressed-svd-90"  # model id reported by /v1/models
| |
|
| | |
| | |
| | |
| |
|
# GPT-2 pre-tokenization pattern: splits text into contractions ('s, 't, ...),
# optionally space-prefixed letter runs, digit runs, punctuation runs, and
# whitespace. Needs the third-party `regex` module for \p{L}/\p{N} classes.
GPT2_PATTERN = regex.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
    regex.UNICODE,
)
| |
|
| |
|
| | def _build_byte_to_unicode(): |
| | bs = list(range(0x21, 0x7F)) + list(range(0xA1, 0xAD)) + list(range(0xAE, 0x100)) |
| | cs = list(bs) |
| | n = 0 |
| | for b in range(256): |
| | if b not in bs: |
| | bs.append(b) |
| | cs.append(256 + n) |
| | n += 1 |
| | return {b: chr(c) for b, c in zip(bs, cs)} |
| |
|
| |
|
# 256-entry bijection between raw byte values and printable unicode chars,
# plus its inverse for decoding token strings back to bytes.
BYTE_TO_UNICODE = _build_byte_to_unicode()
UNICODE_TO_BYTE = {v: k for k, v in BYTE_TO_UNICODE.items()}
| |
|
| |
|
class BPETokenizer:
    """Byte-level BPE tokenizer (GPT-2 style) loaded from vocab.json + merges.txt."""

    def __init__(self, vocab_path: str, merges_path: str):
        # token string -> id, and the reverse map for decoding.
        with open(vocab_path, "r", encoding="utf-8") as fh:
            self.vocab = json.load(fh)
        self.id_to_token = {tid: tok for tok, tid in self.vocab.items()}

        # Ordered merge rules; earlier rules get a lower rank (higher priority).
        self.merges = []
        self.merge_rank = {}
        with open(merges_path, "r", encoding="utf-8") as fh:
            for raw in fh:
                raw = raw.strip()
                if not raw or raw.startswith("#"):
                    continue
                pieces = raw.split()
                if len(pieces) != 2:
                    continue
                pair = (pieces[0], pieces[1])
                self.merge_rank[pair] = len(self.merge_rank)
                self.merges.append(pair)

        # Per-word encode cache. NOTE(review): unbounded — grows with the set
        # of distinct words seen; fine for typical prompts, worth bounding for
        # a very long-lived process.
        self.cache = {}

    def _bpe_word(self, chars: list[str]) -> list[str]:
        """Apply BPE merges to one pre-tokenized word until no rule applies."""
        parts = list(chars)
        while len(parts) >= 2:
            # Find the adjacent pair with the best (lowest) merge rank;
            # ties break toward the leftmost occurrence.
            best_rank, best_i = min(
                (self.merge_rank.get((parts[i], parts[i + 1]), float("inf")), i)
                for i in range(len(parts) - 1)
            )
            if best_rank == float("inf"):
                break  # no mergeable pair left
            left, right = parts[best_i], parts[best_i + 1]

            # Merge every occurrence of (left, right) in a single pass.
            merged = []
            i = 0
            while i < len(parts):
                if i + 1 < len(parts) and parts[i] == left and parts[i + 1] == right:
                    merged.append(left + right)
                    i += 2
                else:
                    merged.append(parts[i])
                    i += 1
            parts = merged
        return parts

    def encode(self, text: str) -> list[int]:
        """Encode *text* to token ids (tokens missing from the vocab are dropped)."""
        ids = []
        for match in GPT2_PATTERN.finditer(text):
            piece = match.group()
            cached = self.cache.get(piece)
            if cached is None:
                byte_chars = [BYTE_TO_UNICODE[b] for b in piece.encode("utf-8")]
                cached = [
                    self.vocab[tok]
                    for tok in self._bpe_word(byte_chars)
                    if tok in self.vocab
                ]
                self.cache[piece] = cached
            ids.extend(cached)
        return ids

    def decode(self, ids: list[int]) -> str:
        """Decode token ids back to text; unknown ids and unmapped chars are skipped."""
        token_text = "".join(self.id_to_token.get(tid, "") for tid in ids)
        raw = bytes(UNICODE_TO_BYTE[ch] for ch in token_text if ch in UNICODE_TO_BYTE)
        return raw.decode("utf-8", errors="replace")
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def _sample_logits(logits: torch.Tensor, temperature: float, top_k: int, |
| | top_p: float, vocab_size: int) -> int: |
| | if temperature <= 0: |
| | return logits.argmax().item() |
| |
|
| | logits = logits / temperature |
| |
|
| | if 0 < top_k < vocab_size: |
| | topk_vals, _ = torch.topk(logits, top_k) |
| | logits[logits < topk_vals[-1]] = float("-inf") |
| |
|
| | if top_p < 1.0: |
| | sorted_logits, sorted_idx = torch.sort(logits, descending=True) |
| | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
| | remove = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p |
| | sorted_logits[remove] = float("-inf") |
| | logits = sorted_logits.scatter(0, sorted_idx, sorted_logits) |
| |
|
| | probs = F.softmax(logits, dim=-1) |
| | return torch.multinomial(probs, 1).item() |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@torch.inference_mode()
def generate(
    model: JuliaSLM_SVD,
    tokenizer: BPETokenizer,
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
) -> tuple[str, int]:
    """Generate a completion for *prompt* and return (text, prompt_token_count).

    The prompt is truncated to the model's context window, run once to prefill
    the KV cache, and then decoded one token at a time (each step feeds only
    the newest token back in). Generation stops at max_tokens or when the
    context window is full, whichever comes first.
    """
    cfg = model.config
    prompt_ids = tokenizer.encode(prompt)
    n_prompt = len(prompt_ids)
    context = prompt_ids[-cfg.context_length:]

    # Prefill: one forward pass over the whole (truncated) prompt.
    step_input = torch.tensor([context], dtype=torch.long, device=DEVICE)
    logits, caches = model(step_input)
    last_logits = logits[0, -1, :].float()

    out_ids = []
    position = len(context)
    for _ in range(max_tokens):
        # Stop before the KV cache would exceed the context window.
        if position >= cfg.context_length:
            break

        tok = _sample_logits(last_logits, temperature, top_k, top_p, cfg.vocab_size)
        out_ids.append(tok)
        position += 1

        # Decode step: only the new token — the cache holds the prefix.
        step_input = torch.tensor([[tok]], dtype=torch.long, device=DEVICE)
        logits, caches = model(step_input, caches)
        last_logits = logits[0, -1, :].float()

    return tokenizer.decode(out_ids), n_prompt
| |
|
| |
|
@torch.inference_mode()
def generate_streaming(
    model: JuliaSLM_SVD,
    tokenizer: BPETokenizer,
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
):
    """Yield (token_text, prompt_token_count) pairs one sampled token at a time.

    Same decoding loop as generate(), but each token is yielded as soon as it
    is sampled. NOTE(review): decoding one token at a time means a multi-byte
    UTF-8 character split across two tokens is emitted as U+FFFD replacement
    characters — confirm whether the vocab makes this reachable in practice.
    """
    cfg = model.config
    prompt_ids = tokenizer.encode(prompt)
    n_prompt = len(prompt_ids)
    context = prompt_ids[-cfg.context_length:]

    # Prefill: run the whole (truncated) prompt once to build the KV cache.
    step_input = torch.tensor([context], dtype=torch.long, device=DEVICE)
    logits, caches = model(step_input)
    last_logits = logits[0, -1, :].float()

    position = len(context)
    for _ in range(max_tokens):
        # Stop before the KV cache would exceed the context window.
        if position >= cfg.context_length:
            break

        tok = _sample_logits(last_logits, temperature, top_k, top_p, cfg.vocab_size)
        position += 1

        yield tokenizer.decode([tok]), n_prompt

        # Decode step: feed only the new token; the cache holds the rest.
        step_input = torch.tensor([[tok]], dtype=torch.long, device=DEVICE)
        logits, caches = model(step_input, caches)
        last_logits = logits[0, -1, :].float()
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def ensure_artifacts():
    """Ensure the checkpoint and tokenizer files exist locally; return their paths.

    Downloads from HuggingFace Hub only when a file is missing, and always
    prints the on-disk path and size. Returns a dict with keys "checkpoint",
    "vocab.json", and "merges.txt".
    """
    os.makedirs(CKPT_DIR, exist_ok=True)
    files = {}

    def _fetch(repo_id, filename, key):
        # Download-if-missing, then record the local path under *key*.
        local = os.path.join(CKPT_DIR, filename)
        if not os.path.isfile(local):
            print(f"Downloading {filename} from {repo_id} ...")
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=CKPT_DIR)
        size_mb = os.path.getsize(local) / (1024 * 1024)
        print(f" -> {local} ({size_mb:.1f} MB)")
        files[key] = local

    _fetch(HF_MODEL_REPO, CHECKPOINT_NAME, "checkpoint")
    for fname in ("vocab.json", "merges.txt"):
        _fetch(HF_TOKENIZER_REPO, fname, fname)

    return files
| |
|
| |
|
| | |
| | |
| | |
| |
|
# ---- Module-level startup: fetch artifacts, build model + tokenizer ----
# Runs at import time so the server is fully ready before it binds the port.

print("Downloading artifacts...")
ARTIFACT_PATHS = ensure_artifacts()

print("\nLoading SVD-compressed model...")
# weights_only=True restricts torch.load to tensor data (no arbitrary pickle code).
state_dict = torch.load(ARTIFACT_PATHS["checkpoint"], map_location="cpu", weights_only=True)

# Model hyperparameters are recovered from the checkpoint itself.
CONFIG = SVDConfig.from_checkpoint(state_dict)
MODEL = JuliaSLM_SVD(CONFIG)
# NOTE(review): strict=False silently ignores missing/unexpected keys —
# confirm the checkpoint matches the SVD architecture exactly.
MODEL.load_state_dict(state_dict, strict=False)
MODEL.eval()
DEVICE = torch.device("cpu")  # CPU-only inference

print("Loading tokenizer...")
TOKENIZER = BPETokenizer(
    ARTIFACT_PATHS["vocab.json"],
    ARTIFACT_PATHS["merges.txt"],
)

MODEL_CREATED_AT = int(time.time())  # reported as "created" by /v1/models
NUM_PARAMS = MODEL.num_parameters
print(
    f"\nSVD-compressed model ready: vocab={CONFIG.vocab_size}, d_model={CONFIG.d_model}, "
    f"layers={CONFIG.n_layers}, heads={CONFIG.n_heads}, "
    f"ctx={CONFIG.context_length}, params={NUM_PARAMS:,}"
)
print("SVD-90 compression: ~4.5% parameter reduction")
print("KV cache enabled: O(1) per-token decoding")

# Serializes model access across requests (generation mutates no shared state,
# but concurrent forward passes are avoided defensively).
MODEL_LOCK = Lock()
| |
|
| | |
| | |
| | |
| |
|
# Permissive CORS so browser clients on any origin can call the API directly.
CORS_HEADERS = {
    "Access-Control-Allow-Origin": "*",
    "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
    "Access-Control-Allow-Headers": "Content-Type, Authorization",
}
| |
|
| |
|
def extract_prompt(messages):
    """Return the most recent user message's content from an OpenAI-style list.

    Falls back to the last message's content when no user message exists, and
    to "" when the list is empty. Fix vs. the previous version: non-dict
    entries in *messages* raised AttributeError (an unhandled 500 for the
    caller); they are now skipped, and a non-dict final message yields "".
    All previously-valid inputs return exactly the same result.
    """
    if not messages:
        return ""
    for msg in reversed(messages):
        if isinstance(msg, dict) and msg.get("role") == "user":
            return msg.get("content", "")
    last = messages[-1]
    return last.get("content", "") if isinstance(last, dict) else ""
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class Handler(BaseHTTPRequestHandler):
    """OpenAI-compatible request handler.

    Routes:
      GET  /                    -> health check / capability summary
      GET  /v1/models           -> model listing
      POST /v1/chat/completions -> generation (streaming and non-streaming)

    Fixes vs. the previous version: malformed sampling parameters (e.g.
    ``"temperature": null``) and non-object JSON bodies now return 400 instead
    of crashing with an unhandled 500; an empty prompt now returns 400 instead
    of triggering an IndexError deep inside the model; a client disconnect
    mid-stream no longer raises an unhandled BrokenPipeError; and a
    mis-encoded character in the "/" description string is repaired.
    """

    def log_message(self, format, *args):
        # Route BaseHTTPRequestHandler's default stderr logging to stdout
        # with a timestamp prefix.
        print(f"[{self.log_date_time_string()}] {format % args}")

    def _send_json(self, status, body):
        """Serialize *body* as JSON and send it with CORS headers + Content-Length."""
        data = json.dumps(body).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def do_OPTIONS(self):
        """CORS preflight: 204 with the allow-* headers, no body."""
        self.send_response(204)
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.end_headers()

    def do_GET(self):
        if self.path == "/":
            # Health check / API info with model and compression metadata.
            self._send_json(200, {
                "name": "JuliaSLM-compressed-svd",
                "version": "1.0.0",
                "description": "SVD-compressed JuliaSLM — low-rank factorized weight matrices for efficient inference",
                "architecture": "MHA + RoPE + SwiGLU + RMSNorm + weight tying + SVD compression",
                "compression": {
                    "method": "SVD-90",
                    "original_params": 5_040_000,
                    "compressed_params": NUM_PARAMS,
                    "reduction_pct": round((1 - NUM_PARAMS / 5_040_000) * 100, 1),
                    "val_loss": 3.756,
                    "original_val_loss": 3.552,
                },
                "model": {
                    "vocab_size": CONFIG.vocab_size,
                    "d_model": CONFIG.d_model,
                    "n_layers": CONFIG.n_layers,
                    "n_heads": CONFIG.n_heads,
                    "context_length": CONFIG.context_length,
                    "parameters": NUM_PARAMS,
                },
                "endpoints": ["/v1/models", "/v1/chat/completions"],
                "features": ["streaming", "OpenAI-compatible", "top-k", "top-p", "kv-cache"],
                "compatible_with": ["OpenAI API", "OpenRouter"],
            })
        elif self.path == "/v1/models":
            self._send_json(200, {
                "object": "list",
                "data": [{
                    "id": MODEL_ID,
                    "object": "model",
                    "created": MODEL_CREATED_AT,
                    "owned_by": "juliaslm",
                }],
            })
        else:
            self._send_json(404, {"error": {
                "message": f"Not found: GET {self.path}",
                "type": "invalid_request_error",
                "code": "not_found",
            }})

    def do_POST(self):
        if self.path != "/v1/chat/completions":
            self._send_json(404, {"error": {
                "message": f"Not found: POST {self.path}",
                "type": "invalid_request_error",
                "code": "not_found",
            }})
            return

        content_length = int(self.headers.get("Content-Length", 0))
        try:
            body = json.loads(self.rfile.read(content_length))
        except (json.JSONDecodeError, ValueError):
            self._send_json(400, {"error": {
                "message": "Invalid JSON in request body",
                "type": "invalid_request_error",
                "code": "invalid_json",
            }})
            return
        if not isinstance(body, dict):
            # A JSON array/scalar body would crash on .get() below.
            self._send_json(400, {"error": {
                "message": "Request body must be a JSON object",
                "type": "invalid_request_error",
                "code": "invalid_json",
            }})
            return

        # Clamp sampling parameters to safe ranges; reject non-numeric values
        # (e.g. null) with a 400 instead of an unhandled TypeError.
        try:
            temperature = max(0.0, min(2.0, float(body.get("temperature", 0.8))))
            max_tokens = max(1, min(CONFIG.context_length, int(body.get("max_tokens", 200))))
            top_k_val = max(0, min(CONFIG.vocab_size, int(body.get("top_k", 40))))
            top_p_val = max(0.0, min(1.0, float(body.get("top_p", 1.0))))
        except (TypeError, ValueError):
            self._send_json(400, {"error": {
                "message": "temperature, max_tokens, top_k and top_p must be numeric",
                "type": "invalid_request_error",
                "code": "invalid_parameter",
            }})
            return

        messages = body.get("messages", [])
        if not isinstance(messages, list):
            self._send_json(400, {"error": {
                "message": "messages must be a list of {role, content} objects",
                "type": "invalid_request_error",
                "code": "invalid_messages",
            }})
            return
        prompt_text = extract_prompt(messages)
        if not prompt_text:
            # An empty prompt encodes to zero tokens, which the model cannot
            # prefill (it would fail on logits[0, -1, :]).
            self._send_json(400, {"error": {
                "message": "messages must contain at least one entry with non-empty content",
                "type": "invalid_request_error",
                "code": "invalid_messages",
            }})
            return

        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created = int(time.time())
        stream = bool(body.get("stream", False))

        # One generation at a time; the lock keeps model access serialized.
        with MODEL_LOCK:
            if stream:
                self._handle_stream(
                    prompt_text, max_tokens, temperature, top_k_val, top_p_val,
                    completion_id, created,
                )
            else:
                self._handle_non_stream(
                    prompt_text, max_tokens, temperature, top_k_val, top_p_val,
                    completion_id, created,
                )

    def _handle_stream(self, prompt_text, max_tokens, temperature, top_k, top_p,
                       completion_id, created):
        """Stream completion chunks as server-sent events in OpenAI chunk format."""
        self.send_response(200)
        self.send_header("Content-Type", "text/event-stream")
        self.send_header("Cache-Control", "no-cache")
        self.send_header("X-Accel-Buffering", "no")  # disable proxy buffering
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.end_headers()

        def sse(data):
            # One SSE event per chunk; flush so tokens reach the client promptly.
            self.wfile.write(f"data: {json.dumps(data)}\n\n".encode())
            self.wfile.flush()

        try:
            # First chunk announces the assistant role with an empty delta.
            sse({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
            })

            token_count = 0
            prompt_tokens = 0
            for token_str, p_len in generate_streaming(
                MODEL, TOKENIZER, prompt_text,
                max_tokens=max_tokens, temperature=temperature,
                top_k=top_k, top_p=top_p,
            ):
                token_count += 1
                prompt_tokens = p_len
                sse({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": MODEL_ID,
                    "choices": [{"index": 0, "delta": {"content": token_str}, "finish_reason": None}],
                })

            # Final chunk carries the finish reason and usage, then [DONE].
            sse({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "length" if token_count >= max_tokens else "stop"}],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": token_count,
                    "total_tokens": prompt_tokens + token_count,
                },
            })
            self.wfile.write(b"data: [DONE]\n\n")
            self.wfile.flush()
        except (BrokenPipeError, ConnectionResetError):
            # Client disconnected mid-stream; abandon generation quietly.
            return

    def _handle_non_stream(self, prompt_text, max_tokens, temperature, top_k, top_p,
                           completion_id, created):
        """Run a full generation and send one OpenAI chat.completion response."""
        text, prompt_tokens = generate(
            MODEL, TOKENIZER, prompt_text,
            max_tokens=max_tokens, temperature=temperature,
            top_k=top_k, top_p=top_p,
        )
        completion_tokens = len(TOKENIZER.encode(text))
        finish_reason = "length" if completion_tokens >= max_tokens else "stop"

        self._send_json(200, {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": MODEL_ID,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": finish_reason,
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            "system_fingerprint": "juliaslm-svd90-v1",
        })
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == "__main__":
    # Startup banner listing the routes this server answers.
    print(f"\nJuliaSLM-compressed-svd server starting on 0.0.0.0:{PORT} ...")
    print(f"  GET http://localhost:{PORT}/")
    print(f"  GET http://localhost:{PORT}/v1/models")
    print(f"  POST http://localhost:{PORT}/v1/chat/completions")
    print(f"  POST http://localhost:{PORT}/v1/chat/completions (stream=true)")
    print()

    # Plain single-threaded HTTPServer: requests are handled one at a time.
    server = HTTPServer(("0.0.0.0", PORT), Handler)
    server.serve_forever()
| |
|