# fix: real token-by-token streaming (was generating all tokens then splitting by spaces)
"""SymbioGPT-10M — OpenAI-compatible inference server.

Serves a PyTorch SymbioGPT model (4 organelles: CausalConv + Monarch +
LongConv + Attention, fused via OrganelleGate). Downloads checkpoint and
tokenizer from HuggingFace on first run.
"""
| import math | |
| import os | |
| import time | |
| import uuid | |
| import torch | |
| import torch.nn.functional as F | |
| import uvicorn | |
| from fastapi import FastAPI, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from huggingface_hub import hf_hub_download | |
| from symbio_model import SymbioConfig, SymbioGPT | |
| from tokenizer import BPETokenizer | |
# ───────────────────────────────────────────────────────────────────
# Configuration
# ───────────────────────────────────────────────────────────────────
# Checkpoint/tokenizer artifacts live in this HF repo; overridable via env.
HF_REPO = os.environ.get("HF_REPO", "LisaMegaWatts/SymbioGPT-10M")
PORT = int(os.environ.get("PORT", "7860"))  # 7860 = HuggingFace Spaces default
CHECKPOINT_FILE = "symbio_best.pt"
# Architecture hyperparameters — must match the checkpoint being loaded.
MODEL_CONFIG = SymbioConfig(
    d_model=320,
    n_layers=8,
    n_heads=5,
    head_dim=64,
    ffn_mult=4,
    context_length=256,
    vocab_size=2000,  # provisional; overwritten below from the tokenizer
    weight_tying=True,
    organelles=("causal_conv", "monarch", "long_conv", "attention"),
    conv_kernel_size=4,
    n_monarch_heads=1,
    gate_temperature_init=1.0,
    free_energy_beta=0.001,
)
# ───────────────────────────────────────────────────────────────────
# Load model and tokenizer (runs once, at import time)
# ───────────────────────────────────────────────────────────────────
print(f"Downloading artifacts from {HF_REPO} ...")
ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CHECKPOINT_FILE)
vocab_path = hf_hub_download(repo_id=HF_REPO, filename="vocab.json")
merges_path = hf_hub_download(repo_id=HF_REPO, filename="merges.txt")
print("Loading tokenizer ...")
tokenizer = BPETokenizer.from_files(vocab_path, merges_path)
print(f" BPE vocab_size = {tokenizer.vocab_size}")
# Adjust vocab_size to match tokenizer — the checkpoint was trained with this
# tokenizer, so embedding rows line up once the sizes agree.
if tokenizer.vocab_size != MODEL_CONFIG.vocab_size:
    print(f" Adjusting model vocab_size: {MODEL_CONFIG.vocab_size} -> {tokenizer.vocab_size}")
    MODEL_CONFIG.vocab_size = tokenizer.vocab_size
print("Loading model ...")
model = SymbioGPT(MODEL_CONFIG)
# weights_only=True: safe tensor-only load, no arbitrary pickle execution.
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
# Handle both raw state_dict and wrapped checkpoint formats
if "model_state_dict" in checkpoint:
    state_dict = checkpoint["model_state_dict"]
elif "state_dict" in checkpoint:
    state_dict = checkpoint["state_dict"]
else:
    state_dict = checkpoint
# Strip _orig_mod. prefix that torch.compile adds to checkpoint keys
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout etc.
n_params = sum(p.numel() for p in model.parameters())
print(f" Model loaded: {n_params/1e6:.1f}M params")
print(f" Config: d={MODEL_CONFIG.d_model}, L={MODEL_CONFIG.n_layers}, "
      f"H={MODEL_CONFIG.n_heads}, ctx={MODEL_CONFIG.context_length}, "
      f"vocab={MODEL_CONFIG.vocab_size}")
print(f" Organelles: {MODEL_CONFIG.organelles}")
# ───────────────────────────────────────────────────────────────────
# Generation
# ───────────────────────────────────────────────────────────────────
def generate_streaming(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
):
    """Generator yielding decoded token strings one at a time for real SSE streaming.

    Args:
        prompt: Conditioning text (empty prompts fall back to token id 0).
        max_tokens: Exact number of tokens generated — there is no EOS early stop.
        temperature: Softmax temperature; values below 0.01 are clamped up to
            0.01 (near-greedy) instead of dividing by ~zero.
        top_k: Keep only the k largest logits (0 disables the filter).
        top_p: Nucleus-sampling threshold (1.0 disables the filter).
    """
    tokens = tokenizer.encode(prompt)
    if not tokens:
        # Model needs at least one conditioning token.
        tokens = [0]
    idx = torch.tensor([tokens], dtype=torch.long)
    for _ in range(max_tokens):
        # Crop the running sequence to the model's trained context window.
        idx_cond = idx[:, -MODEL_CONFIG.context_length:]
        # no_grad: without it the whole generation loop builds an autograd
        # graph, leaking memory on every forward pass.
        with torch.no_grad():
            logits = model(idx_cond)
        logits_last = logits[0, -1, :].float()
        # Clamp instead of branching: identical behavior to the old
        # `if temperature > 0.01` dance.
        logits_last = logits_last / max(temperature, 0.01)
        if 0 < top_k < logits_last.size(0):
            # Mask everything strictly below the k-th largest logit.
            threshold = torch.topk(logits_last, top_k).values[-1]
            logits_last[logits_last < threshold] = float("-inf")
        if top_p < 1.0:
            # Nucleus sampling: drop the probability tail beyond top_p.
            sorted_logits, sorted_indices = torch.sort(logits_last, descending=True)
            probs_sorted = F.softmax(sorted_logits, dim=-1)
            cumprobs = torch.cumsum(probs_sorted, dim=-1)
            # Shift by one token so the first token crossing the threshold
            # is still kept (standard nucleus-sampling convention).
            cutoff_mask = cumprobs - probs_sorted > top_p
            sorted_logits[cutoff_mask] = float("-inf")
            # Scatter the filtered logits back to their original vocab
            # positions (sorted_indices is a full permutation, so every
            # slot gets written).
            logits_last = torch.full_like(logits_last, float("-inf")).scatter(
                0, sorted_indices, sorted_logits
            )
        probs = F.softmax(logits_last, dim=-1)
        next_id = torch.multinomial(probs, 1).item()
        idx = torch.cat([idx, torch.tensor([[next_id]])], dim=1)
        yield tokenizer.decode([next_id])
def generate(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
) -> str:
    """Run the streaming generator to exhaustion and return the full text."""
    pieces = generate_streaming(prompt, max_tokens, temperature, top_k, top_p)
    return "".join(pieces)
# ───────────────────────────────────────────────────────────────────
# FastAPI server
# ───────────────────────────────────────────────────────────────────
app = FastAPI()
# Wide-open CORS so browser-based playgrounds can call the API directly.
# NOTE(review): fine for a public demo; tighten allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Reported as the model's "created" timestamp in /v1/models responses.
MODEL_CREATED_AT = int(time.time())
def extract_prompt(messages):
    """Return the prompt text from an OpenAI-style ``messages`` list.

    Picks the most recent message with role ``"user"``; if none exists,
    falls back to the last message regardless of role. Returns "" for an
    empty list. Handles both plain-string ``content`` and the OpenAI
    content-parts format (a list of ``{"type": "text", "text": ...}``
    dicts), which the original crashed on downstream in the tokenizer.
    """
    if not messages:
        return ""
    for msg in reversed(messages):
        if msg.get("role") == "user":
            return _content_to_text(msg.get("content", ""))
    return _content_to_text(messages[-1].get("content", ""))


def _content_to_text(content):
    """Coerce an OpenAI message ``content`` field to a plain string."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Content-parts format: concatenate the text of each text part.
        return "".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return "" if content is None else str(content)
@app.get("/")
def health():
    """Landing/health endpoint: static metadata describing the served model.

    The route decorator was missing, so this handler was never registered
    even though the file's endpoint listing and the __main__ banner both
    advertise "/".
    """
    return {
        "name": "SymbioGPT-10M",
        "version": "1.0.0",
        "description": "Multi-organelle GPT trained on classical philosophy — "
                       "CausalConv + Monarch + LongConv + Attention fused via OrganelleGate",
        "architecture": "Decoder-only (4 organelles + OrganelleGate, RoPE, RMSNorm, SwiGLU, "
                        "SkipGate, weight-tied)",
        "model": {
            "d_model": MODEL_CONFIG.d_model,
            "n_layers": MODEL_CONFIG.n_layers,
            "n_heads": MODEL_CONFIG.n_heads,
            "head_dim": MODEL_CONFIG.head_dim,
            "context_length": MODEL_CONFIG.context_length,
            "vocab_size": MODEL_CONFIG.vocab_size,
            "n_monarch_heads": MODEL_CONFIG.n_monarch_heads,
            "params": f"{n_params/1e6:.1f}M",
        },
        "organelles": list(MODEL_CONFIG.organelles),
        "endpoints": ["/v1/models", "/v1/chat/completions"],
        "features": ["streaming", "OpenAI-compatible", "top-k", "top-p"],
        "compatible_with": ["OpenAI API", "OpenRouter"],
    }
@app.get("/v1/models")
def list_models():
    """OpenAI-compatible model listing (single model entry).

    The route decorator was missing, so /v1/models returned 404 despite
    being advertised by the health endpoint and the __main__ banner.
    """
    return {
        "object": "list",
        "data": [{
            "id": "symbiogpt-10m",
            "object": "model",
            "created": MODEL_CREATED_AT,
            "owned_by": "symbiogpt",
        }],
    }
def _param(body, key, default):
    """Read a request field, treating a missing key or an explicit JSON
    null (common from OpenAI SDKs) as "use the server default"."""
    value = body.get(key)
    return default if value is None else value


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible /v1/chat/completions endpoint.

    Supports SSE streaming (``stream: true``) and blocking responses; the
    non-streaming path additionally honors ``n`` (1..4 choices). Sampling
    parameters are clamped to model-safe ranges.

    Fixes vs. original: the route decorator was missing (handler never
    registered); explicit JSON nulls for sampling params no longer raise
    TypeError; a valid-JSON non-object body now yields 400 instead of 500.
    """
    try:
        body = await request.json()
    except Exception:
        body = None
    if not isinstance(body, dict):
        # Covers both unparseable JSON and valid-but-non-object payloads
        # (a top-level array used to crash with AttributeError -> 500).
        return JSONResponse(status_code=400, content={
            "error": {"message": "Invalid JSON", "type": "invalid_request_error"}
        })
    # Clamp all sampling knobs to safe ranges.
    temperature = max(0.01, min(2.0, _param(body, "temperature", 0.8)))
    max_tokens = max(1, min(MODEL_CONFIG.context_length, _param(body, "max_tokens", 200)))
    top_k_val = max(0, min(MODEL_CONFIG.vocab_size, _param(body, "top_k", 40)))
    top_p_val = max(0.0, min(1.0, _param(body, "top_p", 1.0)))
    stream = body.get("stream", False)
    messages = body.get("messages", [])
    prompt_text = extract_prompt(messages)
    prompt_tokens = len(tokenizer.encode(prompt_text)) if prompt_text else 0
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())
    if stream:
        import json as json_mod

        def sse_stream():
            # Opening chunk announces the assistant role (OpenAI convention).
            initial = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": "symbiogpt-10m",
                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
            }
            yield f"data: {json_mod.dumps(initial)}\n\n"
            token_count = 0
            # One SSE chunk per generated token — real streaming.
            for token_str in generate_streaming(
                prompt_text, max_tokens=max_tokens, temperature=temperature,
                top_k=top_k_val, top_p=top_p_val,
            ):
                token_count += 1
                chunk = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": "symbiogpt-10m",
                    "choices": [{"index": 0, "delta": {"content": token_str}, "finish_reason": None}],
                }
                yield f"data: {json_mod.dumps(chunk)}\n\n"
            # Final chunk carries finish_reason + usage, then the DONE marker.
            finish = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": "symbiogpt-10m",
                "choices": [{"index": 0, "delta": {}, "finish_reason": "length" if token_count >= max_tokens else "stop"}],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": token_count,
                    "total_tokens": prompt_tokens + token_count,
                },
            }
            yield f"data: {json_mod.dumps(finish)}\n\n"
            yield "data: [DONE]\n\n"
        # X-Accel-Buffering: no — stops nginx-style proxies from buffering SSE.
        return StreamingResponse(sse_stream(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
    else:
        n_completions = max(1, min(4, _param(body, "n", 1)))
        choices = []
        for i in range(n_completions):
            text = generate(prompt_text, max_tokens=max_tokens, temperature=temperature,
                            top_k=top_k_val, top_p=top_p_val)
            choices.append({
                "index": i,
                "message": {"role": "assistant", "content": text},
                # The generator always emits exactly max_tokens tokens
                # (no EOS early stop), so "length" is accurate here.
                "finish_reason": "length",
            })
        return {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": "symbiogpt-10m",
            "choices": choices,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": max_tokens * n_completions,
                "total_tokens": prompt_tokens + max_tokens * n_completions,
            },
            "system_fingerprint": "symbiogpt-10m-v1",
        }
| if __name__ == "__main__": | |
| print(f"\nSymbioGPT-10M server starting on 0.0.0.0:{PORT} ...") | |
| print(f" GET http://localhost:{PORT}/") | |
| print(f" GET http://localhost:{PORT}/v1/models") | |
| print(f" POST http://localhost:{PORT}/v1/chat/completions") | |
| uvicorn.run(app, host="0.0.0.0", port=PORT) | |