# Uploaded by LisaMegaWatts with huggingface_hub ("Upload server.py with huggingface_hub", commit 0a315c6, verified).
"""
server.py — OpenAI-compatible inference server for JuliaSLM-compressed-svd.

Serves the SVD-90 compressed JuliaSLM model (4.81M params, ~4.5% smaller than
the original). The checkpoint and tokenizer are downloaded from Hugging Face
on first run.

SVD compression: each linear layer's weight is factorized as W ≈ A @ B
(low-rank), reducing the parameter count at a modest cost in quality
(validation loss 3.756 vs. 3.552 for the uncompressed model).

Endpoints:
    GET  /                     -> health check / API info
    GET  /v1/models            -> list available models
    POST /v1/chat/completions  -> generate text (OpenAI format, streaming supported)
"""
import json
import os
import regex
import time
import uuid
from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Lock
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from juliaslm_svd_model import SVDConfig, JuliaSLM_SVD
# ═══════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════
def _setting(name, default):
    """Return environment variable *name*, or *default* when it is unset."""
    return os.environ.get(name, default)


# Overridable via environment variables of the same name.
# Weights come from the SVD-compressed repo; the tokenizer files
# (vocab.json + merges.txt) come from the original JuliaSLM repo.
HF_MODEL_REPO = _setting("HF_MODEL_REPO", "LisaMegaWatts/JuliaSLM-compressed-svd")
HF_TOKENIZER_REPO = _setting("HF_TOKENIZER_REPO", "LisaMegaWatts/JuliaSLM")
CHECKPOINT_NAME = _setting("CHECKPOINT_NAME", "svd_SVD-90_best.pt")
PORT = int(_setting("PORT", "7860"))

# Fixed settings.
CKPT_DIR = "checkpoints"  # local directory that all artifacts are downloaded into
MODEL_ID = "juliaslm-compressed-svd-90"  # model id reported by the API responses
# ═══════════════════════════════════════════════════════════════════
# BPE Tokenizer (vocab.json + merges.txt)
# ═══════════════════════════════════════════════════════════════════
# GPT-2 pre-tokenization: text is split into English contractions, letter
# runs, digit runs, punctuation runs and whitespace before byte-level BPE is
# applied to each piece. The \p{L} / \p{N} classes need the third-party
# `regex` module (the stdlib `re` does not support them).
_GPT2_SPLIT_SOURCE = (
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
GPT2_PATTERN = regex.compile(_GPT2_SPLIT_SOURCE, flags=regex.UNICODE)
def _build_byte_to_unicode():
bs = list(range(0x21, 0x7F)) + list(range(0xA1, 0xAD)) + list(range(0xAE, 0x100))
cs = list(bs)
n = 0
for b in range(256):
if b not in bs:
bs.append(b)
cs.append(256 + n)
n += 1
return {b: chr(c) for b, c in zip(bs, cs)}
# Forward table: raw byte value -> stand-in character used in vocab/merges.
BYTE_TO_UNICODE = _build_byte_to_unicode()
# Reverse table used when decoding token strings back into raw bytes.
UNICODE_TO_BYTE = dict(zip(BYTE_TO_UNICODE.values(), BYTE_TO_UNICODE.keys()))
class BPETokenizer:
    """Byte-level BPE tokenizer driven by GPT-2 style vocab.json + merges.txt.

    Encoding splits text with GPT2_PATTERN and UTF-8 encodes each piece. The
    bytes are mapped through BYTE_TO_UNICODE, and merges are then applied in
    rank order (lowest rank first). Decoding maps token strings back to bytes
    via UNICODE_TO_BYTE.
    """

    # Default bound on memoized pre-token encodings. The server encodes
    # untrusted prompts, so an unbounded cache could grow without limit.
    DEFAULT_CACHE_SIZE = 50_000

    def __init__(self, vocab_path: str, merges_path: str,
                 cache_size: int = DEFAULT_CACHE_SIZE):
        """Load the vocabulary and the merge table.

        Args:
            vocab_path: JSON file mapping token string -> integer id.
            merges_path: text file with one space-separated merge pair per
                line. An optional "#version" header on the first line is
                skipped.
            cache_size: maximum number of pre-token encodings to memoize.
                The cache is cleared when it fills up; 0 disables caching.
        """
        with open(vocab_path, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)
        self.id_to_token = {v: k for k, v in self.vocab.items()}
        self.merges = []
        self.merge_rank = {}
        with open(merges_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f):
                line = line.strip()
                if not line:
                    continue
                # Only the header is metadata. Other lines may legitimately
                # start with '#' (e.g. the merge "# #"), so they must be kept.
                if line_no == 0 and line.startswith("#version"):
                    continue
                parts = line.split()
                if len(parts) != 2:
                    continue
                pair = (parts[0], parts[1])
                self.merges.append(pair)
                # A repeated pair keeps its first rank, so ranks stay unique.
                self.merge_rank.setdefault(pair, len(self.merge_rank))
        self.cache = {}
        self.cache_size = cache_size

    def _bpe_word(self, chars: list[str]) -> list[str]:
        """Apply BPE merges to one pre-token given as byte-mapped characters.

        Each round finds the adjacent pair with the lowest merge rank and
        merges all of its non-overlapping occurrences, left to right. Rounds
        repeat until no adjacent pair has a rank.
        """
        tokens = list(chars)
        rank = self.merge_rank
        inf = float("inf")
        while len(tokens) >= 2:
            # min() returns the first pair with the lowest rank.
            best_pair = min(zip(tokens, tokens[1:]), key=lambda p: rank.get(p, inf))
            if best_pair not in rank:
                break
            a, b = best_pair
            merged = []
            i = 0
            while i < len(tokens):
                if i < len(tokens) - 1 and tokens[i] == a and tokens[i + 1] == b:
                    merged.append(a + b)
                    i += 2
                else:
                    merged.append(tokens[i])
                    i += 1
            tokens = merged
        return tokens

    def encode(self, text: str) -> list[int]:
        """Encode *text* into token ids.

        BPE pieces missing from the vocabulary are dropped silently.
        """
        ids = []
        for m in GPT2_PATTERN.finditer(text):
            word = m.group()
            word_ids = self.cache.get(word)
            if word_ids is None:
                # errors="replace": lone surrogates (which json.loads can
                # produce from "\ud800" escapes) would otherwise raise.
                raw = word.encode("utf-8", errors="replace")
                chars = [BYTE_TO_UNICODE[b] for b in raw]
                word_ids = [self.vocab[t] for t in self._bpe_word(chars) if t in self.vocab]
                if self.cache_size > 0:
                    if len(self.cache) >= self.cache_size:
                        self.cache.clear()
                    self.cache[word] = word_ids
            ids.extend(word_ids)
        return ids

    def decode(self, ids: list[int]) -> str:
        """Decode token ids back into text.

        Unknown ids are skipped. Byte sequences that are not valid UTF-8
        become U+FFFD (errors="replace").
        """
        text = "".join(self.id_to_token.get(i, "") for i in ids)
        byte_vals = [UNICODE_TO_BYTE[c] for c in text if c in UNICODE_TO_BYTE]
        return bytes(byte_vals).decode("utf-8", errors="replace")
# ═══════════════════════════════════════════════════════════════════
# Sampling helpers
# ═══════════════════════════════════════════════════════════════════
def _sample_logits(logits: torch.Tensor, temperature: float, top_k: int,
top_p: float, vocab_size: int) -> int:
if temperature <= 0:
return logits.argmax().item()
logits = logits / temperature
if 0 < top_k < vocab_size:
topk_vals, _ = torch.topk(logits, top_k)
logits[logits < topk_vals[-1]] = float("-inf")
if top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p
sorted_logits[remove] = float("-inf")
logits = sorted_logits.scatter(0, sorted_idx, sorted_logits)
probs = F.softmax(logits, dim=-1)
return torch.multinomial(probs, 1).item()
# ═══════════════════════════════════════════════════════════════════
# Text generation with KV cache
# ═══════════════════════════════════════════════════════════════════
@torch.inference_mode()
def generate(
    model: JuliaSLM_SVD,
    tokenizer: BPETokenizer,
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
) -> tuple[str, int]:
    """Generate a completion for *prompt* and return it in one piece.

    The prompt is run through the model once to build the KV cache. Each
    sampled token is then fed back on its own, reusing the cache.

    Generation ends after ``max_tokens`` tokens or when the sequence fills
    ``model.config.context_length``. If the prompt is too long, its oldest
    tokens are dropped so that up to ``max_tokens`` positions stay free for
    the completion. At least one prompt token is always kept.

    Returns:
        (completion_text, prompt_token_count). The count is for the full,
        untruncated prompt.

    Raises:
        ValueError: if the prompt encodes to zero tokens. The model needs at
            least one token to condition on.
    """
    config = model.config
    input_ids = tokenizer.encode(prompt)
    prompt_len = len(input_ids)
    if not input_ids:
        raise ValueError("prompt must encode to at least one token")
    # Keeping the last context_length tokens would leave no free positions
    # and an empty completion, so reserve room for the requested tokens.
    keep = max(1, config.context_length - max_tokens)
    ids = input_ids[-keep:]
    x = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    logits, kv_caches = model(x)
    next_logits = logits[0, -1, :].float()
    # Tokens to produce: limited by the request and by the free positions.
    budget = min(max_tokens, config.context_length - len(ids))
    generated_ids = []
    for step in range(budget):
        idx = _sample_logits(next_logits, temperature, top_k, top_p, config.vocab_size)
        generated_ids.append(idx)
        if step == budget - 1:
            break  # logits for a further token would never be used
        x = torch.tensor([[idx]], dtype=torch.long, device=DEVICE)
        logits, kv_caches = model(x, kv_caches)
        next_logits = logits[0, -1, :].float()
    return tokenizer.decode(generated_ids), prompt_len
@torch.inference_mode()
def generate_streaming(
    model: JuliaSLM_SVD,
    tokenizer: BPETokenizer,
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
):
    """Generate a completion for *prompt*, yielding once per sampled token.

    Yields ``(text_piece, prompt_token_count)`` exactly once per generated
    token, so callers can count tokens by counting yields.

    A byte-level BPE token can end partway through a multi-byte UTF-8
    character. Decoding such a token alone would produce U+FFFD. Instead its
    piece is yielded as ``""`` and its bytes are held until a later token
    completes the character. Anything still pending is flushed with the final
    token, so the pieces concatenate to the same text as decoding all ids
    together.

    Prompt truncation and stopping follow the same rules as generate():
    up to ``max_tokens`` positions are reserved and at least one prompt token
    is kept.

    Raises:
        ValueError: if the prompt encodes to zero tokens.
    """
    config = model.config
    input_ids = tokenizer.encode(prompt)
    prompt_len = len(input_ids)
    if not input_ids:
        raise ValueError("prompt must encode to at least one token")
    # Leave room for the completion instead of filling the whole context.
    keep = max(1, config.context_length - max_tokens)
    ids = input_ids[-keep:]
    x = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    logits, kv_caches = model(x)
    next_logits = logits[0, -1, :].float()
    # Tokens to produce: limited by the request and by the free positions.
    budget = min(max_tokens, config.context_length - len(ids))
    pending = []  # sampled ids whose text has not been yielded yet
    for step in range(budget):
        idx = _sample_logits(next_logits, temperature, top_k, top_p, config.vocab_size)
        pending.append(idx)
        is_last = step == budget - 1
        piece = tokenizer.decode(pending)
        if piece.endswith("\ufffd") and not is_last:
            # decode() replaces undecodable bytes with U+FFFD. At the end of
            # the text this usually means a character is still incomplete.
            yield "", prompt_len
        else:
            pending.clear()
            yield piece, prompt_len
        if is_last:
            break  # logits for a further token would never be used
        x = torch.tensor([[idx]], dtype=torch.long, device=DEVICE)
        logits, kv_caches = model(x, kv_caches)
        next_logits = logits[0, -1, :].float()
# ═══════════════════════════════════════════════════════════════════
# Download artifacts from HuggingFace
# ═══════════════════════════════════════════════════════════════════
def ensure_artifacts():
    """Make sure the checkpoint and tokenizer files exist under CKPT_DIR.

    Local copies are reused. Missing files are fetched with hf_hub_download:
    the checkpoint from HF_MODEL_REPO, and vocab.json / merges.txt from
    HF_TOKENIZER_REPO. The size of every file is printed.

    Returns:
        dict mapping "checkpoint", "vocab.json" and "merges.txt" to local paths.
    """
    os.makedirs(CKPT_DIR, exist_ok=True)

    def fetch(repo_id, filename):
        # Download *filename* from *repo_id* unless it is already present,
        # report its size, and return its local path.
        path = os.path.join(CKPT_DIR, filename)
        if not os.path.isfile(path):
            print(f"Downloading {filename} from {repo_id} ...")
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=CKPT_DIR)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f" -> {path} ({size_mb:.1f} MB)")
        return path

    # Model checkpoint from the SVD-compressed repo.
    paths = {"checkpoint": fetch(HF_MODEL_REPO, CHECKPOINT_NAME)}
    # Tokenizer from the original JuliaSLM repo.
    for filename in ("vocab.json", "merges.txt"):
        paths[filename] = fetch(HF_TOKENIZER_REPO, filename)
    return paths
# ═══════════════════════════════════════════════════════════════════
# Load model
# ═══════════════════════════════════════════════════════════════════
print("Downloading artifacts...")
ARTIFACT_PATHS = ensure_artifacts()
print("\nLoading SVD-compressed model...")
state_dict = torch.load(ARTIFACT_PATHS["checkpoint"], map_location="cpu", weights_only=True)
# Build config from checkpoint (auto-detects ranks per layer)
CONFIG = SVDConfig.from_checkpoint(state_dict)
MODEL = JuliaSLM_SVD(CONFIG)
# strict=False loads even when parameter names do not line up. That would
# silently leave freshly initialised weights in place, so report mismatches.
# A tied output projection may legitimately appear as missing.
# NOTE(review): assumes JuliaSLM_SVD keeps nn.Module.load_state_dict's return
# value (missing_keys / unexpected_keys); getattr keeps this safe otherwise.
_load_report = MODEL.load_state_dict(state_dict, strict=False)
_missing_keys = getattr(_load_report, "missing_keys", [])
_unexpected_keys = getattr(_load_report, "unexpected_keys", [])
if _missing_keys:
    print(f"WARNING: model keys not found in checkpoint: {_missing_keys}")
if _unexpected_keys:
    print(f"WARNING: checkpoint keys not used by the model: {_unexpected_keys}")
MODEL.eval()
DEVICE = torch.device("cpu")
print("Loading tokenizer...")
TOKENIZER = BPETokenizer(
    ARTIFACT_PATHS["vocab.json"],
    ARTIFACT_PATHS["merges.txt"],
)
MODEL_CREATED_AT = int(time.time())
NUM_PARAMS = MODEL.num_parameters
print(
    f"\nSVD-compressed model ready: vocab={CONFIG.vocab_size}, d_model={CONFIG.d_model}, "
    f"layers={CONFIG.n_layers}, heads={CONFIG.n_heads}, "
    f"ctx={CONFIG.context_length}, params={NUM_PARAMS:,}"
)
print("SVD-90 compression: ~4.5% parameter reduction")
# Attention still scans every cached position, so decoding is not O(1).
print("KV cache enabled: each new token reuses the cached keys/values")
# Held by the request handler while it runs the model and tokenizer, so only
# one generation runs at a time.
MODEL_LOCK = Lock()
# ═══════════════════════════════════════════════════════════════════
# HTTP helpers
# ═══════════════════════════════════════════════════════════════════
# Permissive CORS, added to every response: any origin may call the API
# from a browser, including requests that send Content-Type or Authorization.
CORS_HEADERS = dict([
    ("Access-Control-Allow-Origin", "*"),
    ("Access-Control-Allow-Methods", "GET, POST, OPTIONS"),
    ("Access-Control-Allow-Headers", "Content-Type, Authorization"),
])
def _content_text(content):
    """Flatten a chat message "content" value into plain text.

    Strings pass through unchanged. For an OpenAI-style list of content
    parts, the "text" of every {"type": "text"} part is joined with newlines.
    Anything else (None, images only, other types) yields "".
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                text = part.get("text")
                if isinstance(text, str):
                    texts.append(text)
        return "\n".join(texts)
    return ""


def extract_prompt(messages):
    """Return the text of the most recent user message.

    Falls back to the last message of any role when there is no user
    message. Returns "" for a missing or empty message list. Entries that are
    not dicts are ignored, and the result is always a str.
    """
    if not isinstance(messages, (list, tuple)) or not messages:
        return ""
    dict_messages = [msg for msg in messages if isinstance(msg, dict)]
    if not dict_messages:
        return ""
    for msg in reversed(dict_messages):
        if msg.get("role") == "user":
            return _content_text(msg.get("content", ""))
    return _content_text(dict_messages[-1].get("content", ""))
# ═══════════════════════════════════════════════════════════════════
# Request handler
# ═══════════════════════════════════════════════════════════════════
class Handler(BaseHTTPRequestHandler):
    """HTTP handler for the OpenAI-compatible endpoints.

    GET  /                     -> server and model info
    GET  /v1/models            -> model list
    POST /v1/chat/completions  -> chat completion (JSON, or SSE when stream=true)
    OPTIONS                    -> CORS preflight
    """

    # Largest request body accepted, in bytes. Chat prompts for this model
    # are tiny, and the limit stops a client from making the server buffer
    # arbitrarily large bodies.
    MAX_BODY_BYTES = 1 << 20

    # generate() and generate_streaming() have no stop-token handling. They
    # end only when max_tokens is reached or the context window is full, and
    # OpenAI reports both cases as "length".
    _FINISH_REASON = "length"

    # Reference figures reported by GET /.
    _ORIGINAL_PARAMS = 5_040_000
    _SVD_VAL_LOSS = 3.756
    _ORIGINAL_VAL_LOSS = 3.552

    def log_message(self, format, *args):
        # Log to stdout as "[date] message". The base class writes to stderr
        # and includes the client address.
        print(f"[{self.log_date_time_string()}] {format % args}")

    def _route_path(self):
        """Return the request path without any query string."""
        return self.path.split("?", 1)[0]

    @staticmethod
    def _number(body, key, default, cast):
        """Return body[key] converted with *cast* (int or float).

        A missing key or an explicit JSON null yields *default*. Raises
        TypeError, ValueError or OverflowError when *cast* rejects the value.
        """
        value = body.get(key)
        return default if value is None else cast(value)

    def _send_json(self, status, body):
        """Send *body* as a JSON response with CORS headers."""
        data = json.dumps(body).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def _send_error(self, status, message, code):
        """Send an OpenAI-style error object with the given HTTP status."""
        self._send_json(status, {"error": {
            "message": message,
            "type": "invalid_request_error",
            "code": code,
        }})

    def _read_json_body(self):
        """Read and parse the request body, which must be a JSON object.

        Returns the parsed dict, or None after an error response (400 or 413)
        has already been sent.
        """
        try:
            length = int(self.headers.get("Content-Length", 0))
        except ValueError:
            length = -1
        # A negative length would make rfile.read() block until EOF.
        if length < 0:
            self._send_error(400, "Invalid Content-Length header", "invalid_request")
            return None
        if length > self.MAX_BODY_BYTES:
            self._send_error(413, "Request body too large", "request_too_large")
            return None
        try:
            body = json.loads(self.rfile.read(length))
        except ValueError:  # includes json.JSONDecodeError and UnicodeDecodeError
            self._send_error(400, "Invalid JSON in request body", "invalid_json")
            return None
        if not isinstance(body, dict):
            self._send_error(400, "Request body must be a JSON object", "invalid_request")
            return None
        return body

    def do_OPTIONS(self):
        """Answer a CORS preflight request."""
        self.send_response(204)
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.end_headers()

    def do_GET(self):
        """Serve server info (/) and the model list (/v1/models)."""
        path = self._route_path()
        if path == "/":
            self._send_json(200, {
                "name": "JuliaSLM-compressed-svd",
                "version": "1.0.0",
                "description": "SVD-compressed JuliaSLM — low-rank factorized weight matrices for efficient inference",
                "architecture": "MHA + RoPE + SwiGLU + RMSNorm + weight tying + SVD compression",
                "compression": {
                    "method": "SVD-90",
                    "original_params": self._ORIGINAL_PARAMS,
                    "compressed_params": NUM_PARAMS,
                    "reduction_pct": round((1 - NUM_PARAMS / self._ORIGINAL_PARAMS) * 100, 1),
                    "val_loss": self._SVD_VAL_LOSS,
                    "original_val_loss": self._ORIGINAL_VAL_LOSS,
                },
                "model": {
                    "vocab_size": CONFIG.vocab_size,
                    "d_model": CONFIG.d_model,
                    "n_layers": CONFIG.n_layers,
                    "n_heads": CONFIG.n_heads,
                    "context_length": CONFIG.context_length,
                    "parameters": NUM_PARAMS,
                },
                "endpoints": ["/v1/models", "/v1/chat/completions"],
                "features": ["streaming", "OpenAI-compatible", "top-k", "top-p", "kv-cache"],
                "compatible_with": ["OpenAI API", "OpenRouter"],
            })
        elif path == "/v1/models":
            self._send_json(200, {
                "object": "list",
                "data": [{
                    "id": MODEL_ID,
                    "object": "model",
                    "created": MODEL_CREATED_AT,
                    "owned_by": "juliaslm",
                }],
            })
        else:
            self._send_error(404, f"Not found: GET {self.path}", "not_found")

    def do_POST(self):
        """Handle /v1/chat/completions: validate input, then generate."""
        if self._route_path() != "/v1/chat/completions":
            self._send_error(404, f"Not found: POST {self.path}", "not_found")
            return
        body = self._read_json_body()
        if body is None:
            return  # an error response has already been sent
        try:
            temperature = max(0.0, min(2.0, self._number(body, "temperature", 0.8, float)))
            max_tokens = max(1, min(CONFIG.context_length,
                                    self._number(body, "max_tokens", 200, int)))
            top_k_val = max(0, min(CONFIG.vocab_size, self._number(body, "top_k", 40, int)))
            top_p_val = max(0.0, min(1.0, self._number(body, "top_p", 1.0, float)))
        except (TypeError, ValueError, OverflowError):
            self._send_error(
                400, "temperature, max_tokens, top_k and top_p must be numbers",
                "invalid_parameter",
            )
            return
        stream = bool(body.get("stream", False))
        prompt_text = extract_prompt(body.get("messages", []))
        # An empty prompt encodes to zero tokens, and generation would then
        # index the last position of an empty sequence. Reject it up front.
        if not isinstance(prompt_text, str) or not prompt_text:
            self._send_error(400, "messages must contain non-empty text content",
                             "invalid_messages")
            return
        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created = int(time.time())
        handle = self._handle_stream if stream else self._handle_non_stream
        with MODEL_LOCK:
            handle(prompt_text, max_tokens, temperature, top_k_val, top_p_val,
                   completion_id, created)

    def _handle_stream(self, prompt_text, max_tokens, temperature, top_k, top_p,
                       completion_id, created):
        """Stream the completion as Server-Sent Events (OpenAI chunk format).

        Sends a role chunk, one content chunk per generated token, a final
        chunk with finish_reason and usage, and then "data: [DONE]".
        """
        self.send_response(200)
        self.send_header("Content-Type", "text/event-stream")
        self.send_header("Cache-Control", "no-cache")
        self.send_header("X-Accel-Buffering", "no")  # ask nginx-style proxies not to buffer
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.end_headers()

        def sse(delta, finish_reason=None, **extra):
            # Write one chat.completion.chunk event and flush it immediately.
            payload = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
                **extra,
            }
            self.wfile.write(f"data: {json.dumps(payload)}\n\n".encode())
            self.wfile.flush()

        token_count = 0
        prompt_tokens = 0
        try:
            sse({"role": "assistant", "content": ""})
            for token_str, p_len in generate_streaming(
                MODEL, TOKENIZER, prompt_text,
                max_tokens=max_tokens, temperature=temperature,
                top_k=top_k, top_p=top_p,
            ):
                token_count += 1
                prompt_tokens = p_len
                sse({"content": token_str})
            sse({}, self._FINISH_REASON, usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": token_count,
                "total_tokens": prompt_tokens + token_count,
            })
            self.wfile.write(b"data: [DONE]\n\n")
            self.wfile.flush()
        except (BrokenPipeError, ConnectionResetError):
            # The client went away mid-stream and nothing more can be
            # delivered. Log it rather than surface a traceback.
            self.log_message("client disconnected during %s", completion_id)

    def _handle_non_stream(self, prompt_text, max_tokens, temperature, top_k, top_p,
                           completion_id, created):
        """Generate the whole completion and return it as one JSON response."""
        text, prompt_tokens = generate(
            MODEL, TOKENIZER, prompt_text,
            max_tokens=max_tokens, temperature=temperature,
            top_k=top_k, top_p=top_p,
        )
        # generate() returns only text, so the completion size is estimated
        # by re-encoding it. It can differ slightly from the sampled count.
        completion_tokens = len(TOKENIZER.encode(text))
        self._send_json(200, {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": MODEL_ID,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": self._FINISH_REASON,
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            "system_fingerprint": "juliaslm-svd90-v1",
        })
# ═══════════════════════════════════════════════════════════════════
# Start server
# ═══════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # ThreadingHTTPServer instead of HTTPServer keeps GET / and /v1/models
    # responsive while a long generation runs. do_POST still holds
    # MODEL_LOCK, so only one request uses the model at a time.
    from http.server import ThreadingHTTPServer

    print(f"\nJuliaSLM-compressed-svd server starting on 0.0.0.0:{PORT} ...")
    print(f"  GET  http://localhost:{PORT}/")
    print(f"  GET  http://localhost:{PORT}/v1/models")
    print(f"  POST http://localhost:{PORT}/v1/chat/completions")
    print(f"  POST http://localhost:{PORT}/v1/chat/completions  (stream=true)")
    print()
    server = ThreadingHTTPServer(("0.0.0.0", PORT), Handler)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("\nShutting down...")
    finally:
        server.server_close()  # release the listening socket