# SymbioGPT-10M-space / server.py
# LisaMegaWatts's picture
# fix: real token-by-token streaming (was generating all tokens then splitting by spaces)
# c79eabb verified
"""SymbioGPT-10M — OpenAI-compatible inference server.
Serves a PyTorch SymbioGPT model (4 organelles: CausalConv + Monarch +
LongConv + Attention, fused via OrganelleGate). Downloads checkpoint and
tokenizer from HuggingFace on first run.
"""
import math
import os
import time
import uuid
import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download
from symbio_model import SymbioConfig, SymbioGPT
from tokenizer import BPETokenizer
# ═══════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════
# Hub repository that hosts the checkpoint and tokenizer files; can be
# overridden with the HF_REPO environment variable.
HF_REPO = os.getenv("HF_REPO", "LisaMegaWatts/SymbioGPT-10M")

# TCP port the HTTP server binds to; can be overridden with PORT.
PORT = int(os.getenv("PORT", "7860"))

# Name of the checkpoint file inside HF_REPO.
CHECKPOINT_FILE = "symbio_best.pt"

# Hyper-parameters handed to SymbioConfig. vocab_size is only a starting
# value: the loading code further down overwrites it when the tokenizer
# reports a different vocabulary size.
MODEL_CONFIG = SymbioConfig(
    # Core transformer dimensions
    d_model=320,
    n_layers=8,
    n_heads=5,
    head_dim=64,
    ffn_mult=4,
    # Sequence / vocabulary limits
    context_length=256,
    vocab_size=2000,
    weight_tying=True,
    # Organelle mix and organelle-specific settings
    organelles=("causal_conv", "monarch", "long_conv", "attention"),
    conv_kernel_size=4,
    n_monarch_heads=1,
    gate_temperature_init=1.0,
    free_energy_beta=0.001,
)
# ═══════════════════════════════════════════════════════════════════
# Load model and tokenizer
# ═══════════════════════════════════════════════════════════════════
# --- Fetch artifacts -------------------------------------------------
# The three files are requested in the same order as before: checkpoint,
# vocabulary, merges.
print(f"Downloading artifacts from {HF_REPO} ...")
ckpt_path, vocab_path, merges_path = (
    hf_hub_download(repo_id=HF_REPO, filename=artifact_name)
    for artifact_name in (CHECKPOINT_FILE, "vocab.json", "merges.txt")
)

# --- Tokenizer -------------------------------------------------------
print("Loading tokenizer ...")
tokenizer = BPETokenizer.from_files(vocab_path, merges_path)
print(f" BPE vocab_size = {tokenizer.vocab_size}")

# The tokenizer is the source of truth for the vocabulary size; bring the
# model configuration in line with it before the model is constructed.
if MODEL_CONFIG.vocab_size != tokenizer.vocab_size:
    print(f" Adjusting model vocab_size: {MODEL_CONFIG.vocab_size} -> {tokenizer.vocab_size}")
    MODEL_CONFIG.vocab_size = tokenizer.vocab_size

# --- Model -----------------------------------------------------------
print("Loading model ...")
model = SymbioGPT(MODEL_CONFIG)
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)

# A checkpoint may wrap the weights under "model_state_dict" or
# "state_dict" (checked in that order); otherwise it is the state dict.
state_dict = next(
    (
        checkpoint[wrapper_key]
        for wrapper_key in ("model_state_dict", "state_dict")
        if wrapper_key in checkpoint
    ),
    checkpoint,
)

# Checkpoints saved from a torch.compile'd model carry "_orig_mod." in
# their parameter names; drop it so the names match the plain module.
state_dict = {
    param_name.replace("_orig_mod.", ""): tensor
    for param_name, tensor in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()

# Total parameter count, reported here and by the health endpoint.
n_params = sum(param.numel() for param in model.parameters())
print(f" Model loaded: {n_params / 1e6:.1f}M params")
print(
    f" Config: d={MODEL_CONFIG.d_model}, L={MODEL_CONFIG.n_layers}, "
    f"H={MODEL_CONFIG.n_heads}, ctx={MODEL_CONFIG.context_length}, "
    f"vocab={MODEL_CONFIG.vocab_size}"
)
print(f" Organelles: {MODEL_CONFIG.organelles}")
# ═══════════════════════════════════════════════════════════════════
# Generation
# ═══════════════════════════════════════════════════════════════════
@torch.no_grad()
def generate_streaming(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
):
    """Sample from the model and yield each new token's decoded text.

    The prompt is encoded with the module-level tokenizer (an empty encoding
    is replaced by the single token id 0). On every step the model is run on
    the most recent ``MODEL_CONFIG.context_length`` token ids, the logits of
    the last position are filtered, and one token id is drawn from the
    resulting distribution.

    Args:
        prompt: Text used to seed generation.
        max_tokens: Number of sampling steps; exactly this many strings are
            yielded because there is no end-of-sequence check.
        temperature: Divisor applied to the logits. Values of 0.01 or less
            are treated as 0.01.
        top_k: Keep only logits at least as large as the k-th largest one.
            Ignored unless ``0 < top_k < vocabulary size``.
        top_p: Nucleus threshold, applied only when below 1.0. A token is
            dropped when the probability mass ranked ahead of it already
            exceeds ``top_p``; the top-ranked token is always kept.

    Yields:
        ``tokenizer.decode([token_id])`` for each sampled token id.
    """
    prompt_ids = tokenizer.encode(prompt) or [0]
    context = torch.tensor([prompt_ids], dtype=torch.long)

    # Floor the temperature once; it does not change between steps.
    divisor = temperature if temperature > 0.01 else 0.01
    neg_inf = float("-inf")

    for _step in range(max_tokens):
        window = context[:, -MODEL_CONFIG.context_length:]
        scores = model(window)[0, -1, :].float() / divisor

        # Top-k: mask everything strictly below the k-th best score.
        if top_k > 0 and top_k < scores.size(0):
            kth_best = torch.topk(scores, top_k).values[-1]
            scores[scores < kth_best] = neg_inf

        # Top-p: work in rank order, then scatter back to vocabulary order.
        if top_p < 1.0:
            ranked, order = torch.sort(scores, descending=True)
            ranked_probs = F.softmax(ranked, dim=-1)
            mass_ahead = torch.cumsum(ranked_probs, dim=-1) - ranked_probs
            ranked[mass_ahead > top_p] = neg_inf
            scores = ranked.scatter(0, order, ranked)

        distribution = F.softmax(scores, dim=-1)
        sampled_id = torch.multinomial(distribution, 1).item()
        context = torch.cat([context, torch.tensor([[sampled_id]])], dim=1)
        yield tokenizer.decode([sampled_id])
@torch.no_grad()
def generate(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
) -> str:
    """Run ``generate_streaming`` to completion and return the joined text.

    All arguments are forwarded unchanged to ``generate_streaming``; the
    yielded token strings are concatenated without any separator.
    """
    pieces = generate_streaming(
        prompt,
        max_tokens,
        temperature,
        top_k,
        top_p,
    )
    return "".join(pieces)
# ═══════════════════════════════════════════════════════════════════
# FastAPI server
# ═══════════════════════════════════════════════════════════════════
# ASGI application object; the route handlers below register on it.
app = FastAPI()

# CORS is left fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Unix timestamp taken once at import time; reported as the model's
# "created" value by the /v1/models endpoint.
MODEL_CREATED_AT = int(time.time())
def _content_to_text(content):
    """Flatten one chat message ``content`` value into plain text.

    Handles the shapes a chat-completions client may send: a plain string,
    ``None`` (JSON ``null``), or a list of content parts where text parts
    are dicts carrying a string ``"text"`` field. Non-text parts (for
    example image parts) are skipped. Any other value is stringified.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                texts.append(part["text"])
        return "".join(texts)
    return str(content)


def extract_prompt(messages):
    """Pick the prompt text out of a chat ``messages`` list.

    The content of the most recent message whose role is ``"user"`` is
    used. When no user message exists, the content of the last message is
    used instead.

    Args:
        messages: List of message dicts (``{"role": ..., "content": ...}``).
            Entries that are not dicts are ignored, and a value that is not
            a list/tuple is treated as empty.

    Returns:
        The prompt as a ``str`` -- always a string, ``""`` when nothing
        usable is found, so the result is safe to hand to the tokenizer.
    """
    if not isinstance(messages, (list, tuple)):
        return ""
    turns = [msg for msg in messages if isinstance(msg, dict)]
    if not turns:
        return ""
    for msg in reversed(turns):
        if msg.get("role") == "user":
            return _content_to_text(msg.get("content"))
    return _content_to_text(turns[-1].get("content"))
@app.get("/")
def health():
    """Describe the running service.

    Returns a JSON-serialisable dict with static descriptive text, the
    model hyper-parameters read from ``MODEL_CONFIG``, the parameter count
    computed at load time, and the list of available endpoints.
    """
    return {
        "name": "SymbioGPT-10M",
        "version": "1.0.0",
        # The dash below is a real em dash; it had been stored as the
        # mis-decoded byte sequence and was served garbled to clients.
        "description": "Multi-organelle GPT trained on classical philosophy — "
                       "CausalConv + Monarch + LongConv + Attention fused via OrganelleGate",
        "architecture": "Decoder-only (4 organelles + OrganelleGate, RoPE, RMSNorm, SwiGLU, "
                        "SkipGate, weight-tied)",
        "model": {
            "d_model": MODEL_CONFIG.d_model,
            "n_layers": MODEL_CONFIG.n_layers,
            "n_heads": MODEL_CONFIG.n_heads,
            "head_dim": MODEL_CONFIG.head_dim,
            "context_length": MODEL_CONFIG.context_length,
            "vocab_size": MODEL_CONFIG.vocab_size,
            "n_monarch_heads": MODEL_CONFIG.n_monarch_heads,
            "params": f"{n_params/1e6:.1f}M",
        },
        "organelles": list(MODEL_CONFIG.organelles),
        "endpoints": ["/v1/models", "/v1/chat/completions"],
        "features": ["streaming", "OpenAI-compatible", "top-k", "top-p"],
        "compatible_with": ["OpenAI API", "OpenRouter"],
    }
@app.get("/v1/models")
def list_models():
    """Return the model listing: a one-element list for ``symbiogpt-10m``."""
    model_entry = {
        "id": "symbiogpt-10m",
        "object": "model",
        "created": MODEL_CREATED_AT,
        "owned_by": "symbiogpt",
    }
    return {"object": "list", "data": [model_entry]}
def _invalid_request(message):
    """Return a 400 response carrying an OpenAI-style error payload."""
    return JSONResponse(status_code=400, content={
        "error": {"message": message, "type": "invalid_request_error"}
    })


def _clamped_param(body, key, default, low, high, cast=float):
    """Read a numeric request parameter, clamp it, and convert it.

    Args:
        body: Parsed JSON request object (a dict).
        key: Name of the field to read.
        default: Value used when the field is missing or JSON ``null``.
        low, high: Inclusive bounds the value is clamped to.
        cast: ``float`` or ``int``; applied after clamping so it only ever
            sees an in-range value.

    Raises:
        ValueError: if the supplied value is not a JSON number.
    """
    raw = body.get(key)
    if raw is None:
        raw = default
    # bool is a subclass of int, but true/false is not a meaningful number.
    if isinstance(raw, bool) or not isinstance(raw, (int, float)):
        raise ValueError(f"'{key}' must be a number")
    return cast(max(low, min(high, raw)))


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Handle an OpenAI-style chat completion request.

    Reads ``messages``, ``temperature``, ``max_tokens``, ``top_k``,
    ``top_p``, ``n`` and ``stream`` from the JSON body. Numeric fields are
    clamped to their supported ranges; missing or ``null`` fields fall back
    to defaults. Malformed bodies produce a 400 error payload instead of an
    unhandled exception.

    With ``stream`` truthy, a Server-Sent-Events response is returned that
    emits one chunk per generated token followed by a final chunk with
    usage figures and a ``[DONE]`` marker (``n`` is not used in this mode).
    Otherwise ``n`` (1-4) completions are generated and returned in a single
    ``chat.completion`` object.
    """
    import json as json_mod
    from fastapi.concurrency import run_in_threadpool

    try:
        body = await request.json()
    except Exception:
        return _invalid_request("Invalid JSON")
    if not isinstance(body, dict):
        return _invalid_request("Request body must be a JSON object")

    try:
        temperature = _clamped_param(body, "temperature", 0.8, 0.01, 2.0)
        max_tokens = _clamped_param(
            body, "max_tokens", 200, 1, MODEL_CONFIG.context_length, cast=int)
        top_k_val = _clamped_param(
            body, "top_k", 40, 0, MODEL_CONFIG.vocab_size, cast=int)
        top_p_val = _clamped_param(body, "top_p", 1.0, 0.0, 1.0)
        n_completions = _clamped_param(body, "n", 1, 1, 4, cast=int)
    except ValueError as exc:
        return _invalid_request(str(exc))

    stream = body.get("stream", False)

    messages = body.get("messages") or []
    if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
        return _invalid_request("'messages' must be a list of message objects")

    prompt_text = extract_prompt(messages)
    # The tokenizer needs text: treat a null content as empty and reject
    # anything else that is not a string.
    if prompt_text is None:
        prompt_text = ""
    elif not isinstance(prompt_text, str):
        return _invalid_request("Message 'content' must be a string")

    prompt_tokens = len(tokenizer.encode(prompt_text)) if prompt_text else 0
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())

    if stream:
        def sse_event(delta, finish_reason=None, **extra):
            """Format one ``chat.completion.chunk`` as an SSE data line."""
            chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": "symbiogpt-10m",
                "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
            }
            chunk.update(extra)
            return f"data: {json_mod.dumps(chunk)}\n\n"

        def sse_stream():
            # Opening chunk announces the assistant role with empty content.
            yield sse_event({"role": "assistant", "content": ""})
            token_count = 0
            for token_str in generate_streaming(
                prompt_text, max_tokens=max_tokens, temperature=temperature,
                top_k=top_k_val, top_p=top_p_val,
            ):
                token_count += 1
                yield sse_event({"content": token_str})
            # Closing chunk carries the finish reason and token usage.
            yield sse_event(
                {},
                finish_reason="length" if token_count >= max_tokens else "stop",
                usage={
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": token_count,
                    "total_tokens": prompt_tokens + token_count,
                },
            )
            yield "data: [DONE]\n\n"

        # sse_stream is a plain (sync) generator, which StreamingResponse
        # iterates in a worker thread rather than on the event loop.
        return StreamingResponse(sse_stream(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

    def run_completions():
        """Generate all requested completions sequentially."""
        return [
            generate(prompt_text, max_tokens=max_tokens, temperature=temperature,
                     top_k=top_k_val, top_p=top_p_val)
            for _ in range(n_completions)
        ]

    # Generation is blocking CPU work; running it directly in this coroutine
    # would stall the event loop (and every other request) until it finishes.
    texts = await run_in_threadpool(run_completions)

    choices = [
        {
            "index": i,
            "message": {"role": "assistant", "content": text},
            # generate() always samples exactly max_tokens tokens.
            "finish_reason": "length",
        }
        for i, text in enumerate(texts)
    ]
    completion_tokens = max_tokens * n_completions
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": "symbiogpt-10m",
        "choices": choices,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
        "system_fingerprint": "symbiogpt-10m-v1",
    }
if __name__ == "__main__":
    # Announce the bind address and list the available routes before
    # handing control to uvicorn.
    base_url = f"http://localhost:{PORT}"
    print(f"\nSymbioGPT-10M server starting on 0.0.0.0:{PORT} ...")
    for verb, route in (
        ("GET", "/"),
        ("GET", "/v1/models"),
        ("POST", "/v1/chat/completions"),
    ):
        print(f" {verb} {base_url}{route}")
    uvicorn.run(app, host="0.0.0.0", port=PORT)