""" server.py — JuliaGPT OpenAI-compatible inference server Serves POST /v1/chat/completions (streaming + non-streaming) and GET /v1/models. Loads the Flux.jl GPT-2 model from best_model.jld2 on HF Hub. Architecture: GPT-2 style — LayerNorm, GELU, combined QKV, learned position embeddings. 6 layers, 384-dim, 6 heads, 38-char vocab, val_loss=2.91. Weights are extracted from JLD2 (HDF5-based) via h5py, loaded into PyTorch. Follows the RandyGPT FastAPI/uvicorn pattern for proven HF Spaces compatibility. """ import json import math import time import uuid import os import h5py import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from pathlib import Path from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, StreamingResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.exceptions import RequestValidationError from pydantic import BaseModel from typing import List, Optional from huggingface_hub import hf_hub_download # ── Model definition (GPT-2 style, matches Flux training) ──────────────────── class CausalSelfAttention(nn.Module): def __init__(self, n_embd, n_head): super().__init__() self.n_head = n_head self.head_dim = n_embd // n_head self.scale = 1.0 / math.sqrt(self.head_dim) self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False) self.proj = nn.Linear(n_embd, n_embd, bias=False) def forward(self, x): B, T, C = x.shape qkv = self.qkv(x) q, k, v = qkv.split(C, dim=-1) q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) scores = q @ k.transpose(-2, -1) * self.scale mask = torch.full((T, T), float('-inf'), device=x.device).triu(1) attn = F.softmax(scores + mask, dim=-1) out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C) return self.proj(out) class FeedForward(nn.Module): def __init__(self, n_embd): super().__init__() self.fc1 = nn.Linear(n_embd, 4 * n_embd, bias=False) self.fc2 = nn.Linear(4 * n_embd, n_embd, bias=False) def forward(self, x): return self.fc2(F.gelu(self.fc1(x))) class TransformerBlock(nn.Module): def __init__(self, n_embd, n_head): super().__init__() self.ln1 = nn.LayerNorm(n_embd) self.attn = CausalSelfAttention(n_embd, n_head) self.ln2 = nn.LayerNorm(n_embd) self.ffwd = FeedForward(n_embd) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.ffwd(self.ln2(x)) return x class GPT(nn.Module): def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size): super().__init__() self.block_size = block_size self.wte = nn.Embedding(vocab_size, n_embd) self.wpe = nn.Embedding(block_size, n_embd) self.blocks = nn.ModuleList([TransformerBlock(n_embd, n_head) for _ in range(n_layer)]) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) def forward(self, ids): B, T = ids.shape x = self.wte(ids) + self.wpe(torch.arange(T, device=ids.device).unsqueeze(0)) for block in self.blocks: x = block(x) x = self.ln_f(x) return self.lm_head(x) @torch.no_grad() def generate_stream(self, ids, max_new_tokens=200, temperature=0.1, top_k=8, repetition_penalty=1.3): self.eval() generated = [] for i in range(max_new_tokens): ctx = ids[:, -self.block_size:] logits = self(ctx)[:, -1, :] logits = logits[0] if repetition_penalty > 1.0: seen = set() for t in generated[-self.block_size:]: seen.add(t) for t in ctx[0].tolist(): seen.add(t) for t in seen: if 0 <= t < logits.shape[0]: if logits[t] > 0: logits[t] /= repetition_penalty 
                        else:
                            logits[t] *= repetition_penalty
            logits = logits / max(temperature, 0.01)
            if top_k > 0 and top_k < logits.shape[0]:
                topk_vals, _ = torch.topk(logits, top_k)
                logits[logits < topk_vals[-1]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1)
            ids = torch.cat([ids, nxt.view(1, 1)], dim=1)
            token_id = nxt.item()
            generated.append(token_id)
            is_last = (i == max_new_tokens - 1)
            yield token_id, is_last

    @torch.no_grad()
    def generate(self, ids, max_new_tokens=200, temperature=0.1, top_k=8, repetition_penalty=1.3):
        self.eval()
        generated = []
        for token_id, _ in self.generate_stream(ids, max_new_tokens, temperature, top_k, repetition_penalty):
            generated.append(token_id)
        return generated


# ── Char-level tokenizer ──────────────────────────────────────────────────────

class CharTokenizer:
    def __init__(self, uchars):
        self.uchars = uchars
        self.stoi = {c: i for i, c in enumerate(uchars)}
        self.itos = {i: c for i, c in enumerate(uchars)}
        self.vocab_size = len(uchars)

    def encode(self, text):
        return [self.stoi[c] for c in text.lower() if c in self.stoi]

    def decode(self, ids):
        return "".join(self.itos.get(i, "?") for i in ids)


# ── Load JLD2 weights via h5py ───────────────────────────────────────────────

def load_jld2_gpt2(jld2_path, vocab_path=None):
    """Load Flux GPT-2 weights from JLD2, build PyTorch model."""
    print(f"Loading JLD2 from {jld2_path} ...")
    f = h5py.File(jld2_path, "r")
    ms = f["model_state"][()]

    def deref(ref):
        return np.array(f[ref])

    # Get architecture params
    b1 = ms["blocks"]["layers"]["1"]
    n_head = int(b1["attn"]["n_head"])
    wte_w = deref(ms["wte"]["weight"])
    vocab_size, n_embd = wte_w.shape
    wpe_w = deref(ms["wpe"]["weight"])
    block_size = wpe_w.shape[0]
    layer_names = sorted(ms["blocks"]["layers"].dtype.names, key=int)
    n_layer = len(layer_names)
    step = int(f["step"][()])
    best_val = float(f["best_val_loss"][()])
    print(f" vocab={vocab_size}, embd={n_embd}, heads={n_head}, layers={n_layer}, block={block_size}")
    print(f" step={step}, best_val_loss={best_val:.4f}")

    # Build PyTorch model
    model = GPT(vocab_size, n_embd, n_head, n_layer, block_size)
    state = {}

    # Embeddings: h5py (vocab, embd) = PyTorch (vocab, embd), no transpose
    state["wte.weight"] = torch.tensor(wte_w, dtype=torch.float32)
    state["wpe.weight"] = torch.tensor(wpe_w, dtype=torch.float32)

    # Dense weights: h5py gives (in, out) due to Julia column-major → need .T for PyTorch (out, in)
    for i, lname in enumerate(layer_names):
        layer = ms["blocks"]["layers"][lname]
        # LayerNorm (1D, no transpose)
        state[f"blocks.{i}.ln1.weight"] = torch.tensor(deref(layer["ln1"]["diag"]["scale"]), dtype=torch.float32)
        state[f"blocks.{i}.ln1.bias"] = torch.tensor(deref(layer["ln1"]["diag"]["bias"]), dtype=torch.float32)
        state[f"blocks.{i}.ln2.weight"] = torch.tensor(deref(layer["ln2"]["diag"]["scale"]), dtype=torch.float32)
        state[f"blocks.{i}.ln2.bias"] = torch.tensor(deref(layer["ln2"]["diag"]["bias"]), dtype=torch.float32)
        # Attention QKV + proj (transpose Dense weights)
        state[f"blocks.{i}.attn.qkv.weight"] = torch.tensor(deref(layer["attn"]["qkv"]["weight"]).T.copy(), dtype=torch.float32)
        state[f"blocks.{i}.attn.proj.weight"] = torch.tensor(deref(layer["attn"]["proj"]["weight"]).T.copy(), dtype=torch.float32)
        # FeedForward (transpose Dense weights)
        state[f"blocks.{i}.ffwd.fc1.weight"] = torch.tensor(deref(layer["ffwd"]["net"]["layers"]["1"]["weight"]).T.copy(), dtype=torch.float32)
        state[f"blocks.{i}.ffwd.fc2.weight"] = torch.tensor(deref(layer["ffwd"]["net"]["layers"]["3"]["weight"]).T.copy(), dtype=torch.float32)
    # Final LayerNorm
    state["ln_f.weight"] = torch.tensor(deref(ms["ln_f"]["diag"]["scale"]), dtype=torch.float32)
    state["ln_f.bias"] = torch.tensor(deref(ms["ln_f"]["diag"]["bias"]), dtype=torch.float32)

    # Output projection (transpose Dense weight)
    state["lm_head.weight"] = torch.tensor(deref(ms["lm_head"]["weight"]).T.copy(), dtype=torch.float32)

    model.load_state_dict(state)
    model.eval()
    f.close()
    params = sum(p.numel() for p in model.parameters())
    print(f" PyTorch model loaded: {params:,} params")

    # Load char vocab
    tok = None
    if vocab_path and os.path.exists(vocab_path):
        uchars = json.loads(Path(vocab_path).read_text())
        tok = CharTokenizer(uchars)
        print(f" Loaded char vocab: {tok.vocab_size} chars")

    return model, tok, {
        "vocab_size": vocab_size,
        "n_embd": n_embd,
        "n_head": n_head,
        "n_layer": n_layer,
        "block_size": block_size,
        "step": step,
        "best_val_loss": best_val,
        "params": params,
    }


# ── Load model at startup ────────────────────────────────────────────────────

REPO = os.environ.get("HF_REPO", "LisaMegaWatts/JuliaGPT")
MODEL_ID = "juliagpt-philosophy"

print(f"Downloading model from {REPO} ...")
jld2_path = hf_hub_download(repo_id=REPO, filename="best_model.jld2")
try:
    vocab_path = hf_hub_download(repo_id=REPO, filename="vocab.json")
except Exception:
    vocab_path = None

model, tok, hp = load_jld2_gpt2(jld2_path, vocab_path)
n_embd = hp["n_embd"]
n_head = hp["n_head"]
n_layer = hp["n_layer"]
block_size = hp["block_size"]
vocab_size = hp["vocab_size"]

# Fallback tokenizer if vocab.json missing
if tok is None:
    chars = [" ", "!", "\"", "'", "(", ")", ",", "-", ".", ":", ";", "?",
             "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
             "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"]
    tok = CharTokenizer(chars)
    print(f" Built fallback char vocab: {tok.vocab_size} chars")

print(f"\nModel ready — {hp['params']:,} params, vocab={tok.vocab_size}, val_loss={hp['best_val_loss']:.4f}")


# ── FastAPI app ───────────────────────────────────────────────────────────────

app = FastAPI(title="JuliaGPT", version="2.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


def _openai_error(status, message, err_type="invalid_request_error", code=None):
    body = {"error": {"message": message, "type": err_type}}
    if code:
        body["error"]["code"] = code
    return JSONResponse(status_code=status, content=body)


@app.exception_handler(HTTPException)
async def http_exc(request, exc):
    return _openai_error(exc.status_code, str(exc.detail))


@app.exception_handler(RequestValidationError)
async def val_exc(request, exc):
    msg = "; ".join(f"{e['loc'][-1]}: {e['msg']}" for e in exc.errors())
    return _openai_error(422, msg, code="invalid_request_error")


@app.get("/")
def root():
    return {
        "name": "JuliaGPT",
        "version": "2.0.0",
        "description": "Flux.jl GPT-2 trained on classical philosophy (served via PyTorch)",
        "architecture": "GPT-2 (LayerNorm, GELU, combined QKV)",
        "model": {
            "vocab_size": tok.vocab_size,
            "n_embd": n_embd,
            "n_layer": n_layer,
            "n_head": n_head,
            "block_size": block_size,
            "params": hp["params"],
        },
        "endpoints": ["/v1/models", "/v1/chat/completions"],
        "features": ["streaming", "OpenAI-compatible"],
    }


@app.get("/v1/models")
def list_models():
    return {
        "object": "list",
        "data": [{"id": MODEL_ID, "object": "model", "created": 1700000000, "owned_by": "juliagpt"}],
    }


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: Optional[str] = MODEL_ID
    messages: List[Message]
    max_tokens: Optional[int] = 200
    temperature: Optional[float] = 0.8
    top_k: Optional[int] = 20
    repetition_penalty: Optional[float] = 1.3
    n: Optional[int] = 1
    stream: Optional[bool] = False


def _sse(data):
    return f"data: {json.dumps(data)}\n\n"


def _stream_completion(ids, max_tokens, temperature, top_k, rep_penalty, completion_id, _model, _tok):
    yield _sse({
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": MODEL_ID,
        "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
    })
    token_count = 0
    for token_id, is_last in _model.generate_stream(
        ids, max_new_tokens=max_tokens, temperature=temperature, top_k=top_k, repetition_penalty=rep_penalty
    ):
        token_text = _tok.decode([token_id])
        token_count += 1
        finish_reason = ("length" if token_count >= max_tokens else "stop") if is_last else None
        yield _sse({
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": MODEL_ID,
            "choices": [{"index": 0, "delta": {"content": token_text}, "finish_reason": finish_reason}],
        })
    yield "data: [DONE]\n\n"


@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest):
    _m, _t = model, tok
    prompt = req.messages[-1].content.strip() if req.messages else ""
    if not prompt:
        raise HTTPException(status_code=400, detail="No content in messages")
    ids = _t.encode(prompt)
    if not ids:
        ids = [0]

    max_tokens = max(1, min(req.max_tokens or 200, block_size))
    temperature = max(0.01, min(req.temperature or 0.8, 2.0))
    top_k = max(1, min(req.top_k or 20, tok.vocab_size))
    rep_penalty = max(1.0, min(req.repetition_penalty or 1.3, 3.0))
    n = max(1, min(req.n or 1, 4))
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
    tensor = torch.tensor([ids], dtype=torch.long)

    if req.stream:
        return StreamingResponse(
            _stream_completion(tensor, max_tokens, temperature, top_k, rep_penalty, completion_id, _m, _t),
            media_type="text/event-stream",
            headers={"X-Accel-Buffering": "no"},
        )

    choices = []
    total_completion_tokens = 0
    for i in range(n):
        generated = _m.generate(tensor.clone(), max_new_tokens=max_tokens, temperature=temperature,
                                top_k=top_k, repetition_penalty=rep_penalty)
        text = _t.decode(generated)
        total_completion_tokens += len(generated)
        choices.append({
            "index": i,
            "message": {"role": "assistant", "content": text},
            "finish_reason": "length" if len(generated) >= max_tokens else "stop",
        })

    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ID,
        "system_fingerprint": "juliagpt-v2",
        "choices": choices,
        "usage": {
            "prompt_tokens": len(ids),
            "completion_tokens": total_completion_tokens,
            "total_tokens": len(ids) + total_completion_tokens,
        },
    }
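

# ── Local entrypoint (sketch) ─────────────────────────────────────────────────
# Minimal launch sketch: the module docstring cites the FastAPI/uvicorn pattern
# for HF Spaces, but no entrypoint appears above. This assumes uvicorn is
# installed and that port 7860 (the HF Spaces default) is the intended port;
# adjust host/port to match the actual deployment.
#
# Example request once the server is up (assumed host/port):
#   curl -s http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "what is virtue?"}], "max_tokens": 120}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)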