"""
server.py - JuliaGPT OpenAI-compatible inference server
Serves POST /v1/chat/completions (streaming + non-streaming) and GET /v1/models.
Loads the Flux.jl GPT-2 model from best_model.jld2 on HF Hub.
Architecture: GPT-2 style - LayerNorm, GELU, combined QKV, learned position embeddings.
6 layers, 384-dim, 6 heads, 38-char vocab, val_loss=2.91.
Weights are extracted from JLD2 (HDF5-based) via h5py, loaded into PyTorch.
Follows the RandyGPT FastAPI/uvicorn pattern for proven HF Spaces compatibility.
"""
import json
import math
import time
import uuid
import os
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel
from typing import List, Optional
from huggingface_hub import hf_hub_download
# ── Model definition (GPT-2 style, matches Flux training) ────────────────────
class CausalSelfAttention(nn.Module):
def __init__(self, n_embd, n_head):
super().__init__()
self.n_head = n_head
self.head_dim = n_embd // n_head
self.scale = 1.0 / math.sqrt(self.head_dim)
self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
self.proj = nn.Linear(n_embd, n_embd, bias=False)
def forward(self, x):
B, T, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.split(C, dim=-1)
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
scores = q @ k.transpose(-2, -1) * self.scale
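        # Additive causal mask: entries above the diagonal are -inf, so softmax zeroes out attention to future positions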
mask = torch.full((T, T), float('-inf'), device=x.device).triu(1)
attn = F.softmax(scores + mask, dim=-1)
out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
return self.proj(out)
class FeedForward(nn.Module):
def __init__(self, n_embd):
super().__init__()
self.fc1 = nn.Linear(n_embd, 4 * n_embd, bias=False)
self.fc2 = nn.Linear(4 * n_embd, n_embd, bias=False)
def forward(self, x):
return self.fc2(F.gelu(self.fc1(x)))
class TransformerBlock(nn.Module):
def __init__(self, n_embd, n_head):
super().__init__()
self.ln1 = nn.LayerNorm(n_embd)
self.attn = CausalSelfAttention(n_embd, n_head)
self.ln2 = nn.LayerNorm(n_embd)
self.ffwd = FeedForward(n_embd)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size):
super().__init__()
self.block_size = block_size
self.wte = nn.Embedding(vocab_size, n_embd)
self.wpe = nn.Embedding(block_size, n_embd)
self.blocks = nn.ModuleList([TransformerBlock(n_embd, n_head) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
def forward(self, ids):
B, T = ids.shape
x = self.wte(ids) + self.wpe(torch.arange(T, device=ids.device).unsqueeze(0))
for block in self.blocks:
x = block(x)
x = self.ln_f(x)
return self.lm_head(x)
@torch.no_grad()
def generate_stream(self, ids, max_new_tokens=200, temperature=0.1,
top_k=8, repetition_penalty=1.3):
self.eval()
generated = []
for i in range(max_new_tokens):
ctx = ids[:, -self.block_size:]
logits = self(ctx)[:, -1, :]
logits = logits[0]
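            # Down-weight tokens already seen in the prompt or generated so far (repetition penalty)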
if repetition_penalty > 1.0:
seen = set()
for t in generated[-self.block_size:]:
seen.add(t)
for t in ctx[0].tolist():
seen.add(t)
for t in seen:
if 0 <= t < logits.shape[0]:
if logits[t] > 0:
logits[t] /= repetition_penalty
else:
logits[t] *= repetition_penalty
logits = logits / max(temperature, 0.01)
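            # Top-k filtering: mask everything below the k-th largest logit before sampling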
if top_k > 0 and top_k < logits.shape[0]:
topk_vals, _ = torch.topk(logits, top_k)
logits[logits < topk_vals[-1]] = float('-inf')
probs = F.softmax(logits, dim=-1)
nxt = torch.multinomial(probs, 1)
ids = torch.cat([ids, nxt.view(1, 1)], dim=1)
token_id = nxt.item()
generated.append(token_id)
is_last = (i == max_new_tokens - 1)
yield token_id, is_last
@torch.no_grad()
def generate(self, ids, max_new_tokens=200, temperature=0.1,
top_k=8, repetition_penalty=1.3):
self.eval()
generated = []
for token_id, _ in self.generate_stream(ids, max_new_tokens, temperature,
top_k, repetition_penalty):
generated.append(token_id)
return generated
# ── Char-level tokenizer ──────────────────────────────────────────────────────
class CharTokenizer:
def __init__(self, uchars):
self.uchars = uchars
self.stoi = {c: i for i, c in enumerate(uchars)}
self.itos = {i: c for i, c in enumerate(uchars)}
self.vocab_size = len(uchars)
def encode(self, text):
return [self.stoi[c] for c in text.lower() if c in self.stoi]
def decode(self, ids):
return "".join(self.itos.get(i, "?") for i in ids)
# ── Load JLD2 weights via h5py ───────────────────────────────────────────────
def load_jld2_gpt2(jld2_path, vocab_path=None):
"""Load Flux GPT-2 weights from JLD2, build PyTorch model."""
print(f"Loading JLD2 from {jld2_path} ...")
f = h5py.File(jld2_path, "r")
ms = f["model_state"][()]
def deref(ref):
return np.array(f[ref])
# Get architecture params
b1 = ms["blocks"]["layers"]["1"]
n_head = int(b1["attn"]["n_head"])
wte_w = deref(ms["wte"]["weight"])
vocab_size, n_embd = wte_w.shape
wpe_w = deref(ms["wpe"]["weight"])
block_size = wpe_w.shape[0]
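    # Flux block structs are stored under 1-based string field names ("1", "2", ...); sort them numerically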
layer_names = sorted(ms["blocks"]["layers"].dtype.names, key=int)
n_layer = len(layer_names)
step = int(f["step"][()])
best_val = float(f["best_val_loss"][()])
print(f" vocab={vocab_size}, embd={n_embd}, heads={n_head}, layers={n_layer}, block={block_size}")
print(f" step={step}, best_val_loss={best_val:.4f}")
# Build PyTorch model
model = GPT(vocab_size, n_embd, n_head, n_layer, block_size)
state = {}
# Embeddings: h5py (vocab, embd) = PyTorch (vocab, embd), no transpose
state["wte.weight"] = torch.tensor(wte_w, dtype=torch.float32)
state["wpe.weight"] = torch.tensor(wpe_w, dtype=torch.float32)
    # Dense weights: h5py gives (in, out) due to Julia column-major -> need .T for PyTorch (out, in)
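    # e.g. with n_embd=384 the qkv Dense weight reads as (384, 1152) via h5py and becomes (1152, 384) after .T, matching nn.Linear's (out, in)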
for i, lname in enumerate(layer_names):
layer = ms["blocks"]["layers"][lname]
# LayerNorm (1D, no transpose)
state[f"blocks.{i}.ln1.weight"] = torch.tensor(deref(layer["ln1"]["diag"]["scale"]), dtype=torch.float32)
state[f"blocks.{i}.ln1.bias"] = torch.tensor(deref(layer["ln1"]["diag"]["bias"]), dtype=torch.float32)
state[f"blocks.{i}.ln2.weight"] = torch.tensor(deref(layer["ln2"]["diag"]["scale"]), dtype=torch.float32)
state[f"blocks.{i}.ln2.bias"] = torch.tensor(deref(layer["ln2"]["diag"]["bias"]), dtype=torch.float32)
# Attention QKV + proj (transpose Dense weights)
state[f"blocks.{i}.attn.qkv.weight"] = torch.tensor(deref(layer["attn"]["qkv"]["weight"]).T.copy(), dtype=torch.float32)
state[f"blocks.{i}.attn.proj.weight"] = torch.tensor(deref(layer["attn"]["proj"]["weight"]).T.copy(), dtype=torch.float32)
# FeedForward (transpose Dense weights)
state[f"blocks.{i}.ffwd.fc1.weight"] = torch.tensor(deref(layer["ffwd"]["net"]["layers"]["1"]["weight"]).T.copy(), dtype=torch.float32)
state[f"blocks.{i}.ffwd.fc2.weight"] = torch.tensor(deref(layer["ffwd"]["net"]["layers"]["3"]["weight"]).T.copy(), dtype=torch.float32)
# Final LayerNorm
state["ln_f.weight"] = torch.tensor(deref(ms["ln_f"]["diag"]["scale"]), dtype=torch.float32)
state["ln_f.bias"] = torch.tensor(deref(ms["ln_f"]["diag"]["bias"]), dtype=torch.float32)
# Output projection (transpose Dense weight)
state["lm_head.weight"] = torch.tensor(deref(ms["lm_head"]["weight"]).T.copy(), dtype=torch.float32)
model.load_state_dict(state)
model.eval()
f.close()
params = sum(p.numel() for p in model.parameters())
print(f" PyTorch model loaded: {params:,} params")
# Load char vocab
tok = None
if vocab_path and os.path.exists(vocab_path):
uchars = json.loads(Path(vocab_path).read_text())
tok = CharTokenizer(uchars)
print(f" Loaded char vocab: {tok.vocab_size} chars")
return model, tok, {
"vocab_size": vocab_size, "n_embd": n_embd, "n_head": n_head,
"n_layer": n_layer, "block_size": block_size, "step": step,
"best_val_loss": best_val, "params": params,
}
# ── Load model at startup ────────────────────────────────────────────────────
REPO = os.environ.get("HF_REPO", "LisaMegaWatts/JuliaGPT")
MODEL_ID = "juliagpt-philosophy"
print(f"Downloading model from {REPO} ...")
jld2_path = hf_hub_download(repo_id=REPO, filename="best_model.jld2")
try:
vocab_path = hf_hub_download(repo_id=REPO, filename="vocab.json")
except Exception:
vocab_path = None
model, tok, hp = load_jld2_gpt2(jld2_path, vocab_path)
n_embd = hp["n_embd"]
n_head = hp["n_head"]
n_layer = hp["n_layer"]
block_size = hp["block_size"]
vocab_size = hp["vocab_size"]
# Fallback tokenizer if vocab.json missing
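# (12 punctuation marks + 26 lowercase letters = 38 chars, matching the vocab size noted in the module docstring)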
if tok is None:
chars = [" ","!","\"","'","(",")",",","-",".",":",";","?","a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z"]
tok = CharTokenizer(chars)
print(f" Built fallback char vocab: {tok.vocab_size} chars")
print(f"\nModel ready β€” {hp['params']:,} params, vocab={tok.vocab_size}, val_loss={hp['best_val_loss']:.4f}")
# ── FastAPI app ───────────────────────────────────────────────────────────────
app = FastAPI(title="JuliaGPT", version="2.0.0")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
def _openai_error(status, message, err_type="invalid_request_error", code=None):
body = {"error": {"message": message, "type": err_type}}
if code:
body["error"]["code"] = code
return JSONResponse(status_code=status, content=body)
@app.exception_handler(HTTPException)
async def http_exc(request, exc):
return _openai_error(exc.status_code, str(exc.detail))
@app.exception_handler(RequestValidationError)
async def val_exc(request, exc):
msg = "; ".join(f"{e['loc'][-1]}: {e['msg']}" for e in exc.errors())
return _openai_error(422, msg, code="invalid_request_error")
@app.get("/")
def root():
return {
"name": "JuliaGPT",
"version": "2.0.0",
"description": "Flux.jl GPT-2 trained on classical philosophy (served via PyTorch)",
"architecture": "GPT-2 (LayerNorm, GELU, combined QKV)",
"model": {
"vocab_size": tok.vocab_size, "n_embd": n_embd,
"n_layer": n_layer, "n_head": n_head,
"block_size": block_size, "params": hp["params"],
},
"endpoints": ["/v1/models", "/v1/chat/completions"],
"features": ["streaming", "OpenAI-compatible"],
}
@app.get("/v1/models")
def list_models():
return {
"object": "list",
"data": [{"id": MODEL_ID, "object": "model",
"created": 1700000000, "owned_by": "juliagpt"}]
}
class Message(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
model: Optional[str] = MODEL_ID
messages: List[Message]
max_tokens: Optional[int] = 200
temperature: Optional[float] = 0.8
top_k: Optional[int] = 20
repetition_penalty: Optional[float] = 1.3
n: Optional[int] = 1
stream: Optional[bool] = False
def _sse(data):
return f"data: {json.dumps(data)}\n\n"
def _stream_completion(ids, max_tokens, temperature, top_k, rep_penalty,
completion_id, _model, _tok):
yield _sse({
"id": completion_id, "object": "chat.completion.chunk",
"created": int(time.time()), "model": MODEL_ID,
"choices": [{"index": 0, "delta": {"role": "assistant", "content": ""},
"finish_reason": None}],
})
token_count = 0
for token_id, is_last in _model.generate_stream(
ids, max_new_tokens=max_tokens, temperature=temperature,
top_k=top_k, repetition_penalty=rep_penalty
):
token_text = _tok.decode([token_id])
token_count += 1
finish_reason = ("length" if token_count >= max_tokens else "stop") if is_last else None
yield _sse({
"id": completion_id, "object": "chat.completion.chunk",
"created": int(time.time()), "model": MODEL_ID,
"choices": [{"index": 0, "delta": {"content": token_text},
"finish_reason": finish_reason}],
})
yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest):
_m, _t = model, tok
prompt = req.messages[-1].content.strip() if req.messages else ""
if not prompt:
raise HTTPException(status_code=400, detail="No content in messages")
ids = _t.encode(prompt)
if not ids:
ids = [0]
max_tokens = max(1, min(req.max_tokens or 200, block_size))
temperature = max(0.01, min(req.temperature or 0.8, 2.0))
top_k = max(1, min(req.top_k or 20, tok.vocab_size))
rep_penalty = max(1.0, min(req.repetition_penalty or 1.3, 3.0))
n = max(1, min(req.n or 1, 4))
completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
tensor = torch.tensor([ids], dtype=torch.long)
if req.stream:
return StreamingResponse(
_stream_completion(tensor, max_tokens, temperature, top_k,
rep_penalty, completion_id, _m, _t),
media_type="text/event-stream",
headers={"X-Accel-Buffering": "no"},
)
choices = []
total_completion_tokens = 0
for i in range(n):
generated = _m.generate(tensor.clone(), max_new_tokens=max_tokens,
temperature=temperature, top_k=top_k,
repetition_penalty=rep_penalty)
text = _t.decode(generated)
total_completion_tokens += len(generated)
choices.append({
"index": i,
"message": {"role": "assistant", "content": text},
"finish_reason": "length" if len(generated) >= max_tokens else "stop",
})
return {
"id": completion_id, "object": "chat.completion",
"created": int(time.time()), "model": MODEL_ID,
"system_fingerprint": "juliagpt-v2",
"choices": choices,
"usage": {
"prompt_tokens": len(ids),
"completion_tokens": total_completion_tokens,
"total_tokens": len(ids) + total_completion_tokens,
},
}
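if __name__ == "__main__":
    # Local launch sketch; port 7860 is an assumption (the common HF Spaces default).
    # On the Space itself the container entrypoint typically starts uvicorn instead.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)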