"""SymbioGPT-10M — OpenAI-compatible inference server.

Serves a PyTorch SymbioGPT model (4 organelles: CausalConv + Monarch +
LongConv + Attention, fused via OrganelleGate). Downloads checkpoint and
tokenizer from HuggingFace on first run.
"""

import json
import math
import os
import time
import uuid

import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download

from symbio_model import SymbioConfig, SymbioGPT
from tokenizer import BPETokenizer

# ═══════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════

HF_REPO = os.environ.get("HF_REPO", "LisaMegaWatts/SymbioGPT-10M")
PORT = int(os.environ.get("PORT", "7860"))
CHECKPOINT_FILE = "symbio_best.pt"

MODEL_CONFIG = SymbioConfig(
    d_model=320,
    n_layers=8,
    n_heads=5,
    head_dim=64,
    ffn_mult=4,
    context_length=256,
    vocab_size=2000,
    weight_tying=True,
    organelles=("causal_conv", "monarch", "long_conv", "attention"),
    conv_kernel_size=4,
    n_monarch_heads=1,
    gate_temperature_init=1.0,
    free_energy_beta=0.001,
)

# ═══════════════════════════════════════════════════════════════════
# Load model and tokenizer
# ═══════════════════════════════════════════════════════════════════

print(f"Downloading artifacts from {HF_REPO} ...")
ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CHECKPOINT_FILE)
vocab_path = hf_hub_download(repo_id=HF_REPO, filename="vocab.json")
merges_path = hf_hub_download(repo_id=HF_REPO, filename="merges.txt")

print("Loading tokenizer ...")
tokenizer = BPETokenizer.from_files(vocab_path, merges_path)
print(f" BPE vocab_size = {tokenizer.vocab_size}")

# Adjust vocab_size to match tokenizer (presumably so the embedding table
# built from MODEL_CONFIG matches the checkpoint — verify if this changes).
if tokenizer.vocab_size != MODEL_CONFIG.vocab_size:
    print(f" Adjusting model vocab_size: {MODEL_CONFIG.vocab_size} -> {tokenizer.vocab_size}")
    MODEL_CONFIG.vocab_size = tokenizer.vocab_size

print("Loading model ...")
model = SymbioGPT(MODEL_CONFIG)
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)

# Handle both raw state_dict and wrapped checkpoint formats
if "model_state_dict" in checkpoint:
    state_dict = checkpoint["model_state_dict"]
elif "state_dict" in checkpoint:
    state_dict = checkpoint["state_dict"]
else:
    state_dict = checkpoint

# Strip _orig_mod. prefix from torch.compile checkpoints
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

n_params = sum(p.numel() for p in model.parameters())
print(f" Model loaded: {n_params/1e6:.1f}M params")
print(f" Config: d={MODEL_CONFIG.d_model}, L={MODEL_CONFIG.n_layers}, "
      f"H={MODEL_CONFIG.n_heads}, ctx={MODEL_CONFIG.context_length}, "
      f"vocab={MODEL_CONFIG.vocab_size}")
print(f" Organelles: {MODEL_CONFIG.organelles}")

# ═══════════════════════════════════════════════════════════════════
# Generation
# ═══════════════════════════════════════════════════════════════════


@torch.no_grad()
def generate_streaming(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
):
    """Yield decoded token strings one at a time (drives real SSE streaming).

    Each step: temperature scaling (floored at 0.01), optional top-k filtering
    (``top_k`` <= 0 or >= vocab disables it), optional nucleus filtering
    (``top_p`` >= 1.0 disables it), then multinomial sampling. Exactly
    ``max_tokens`` tokens are produced; there is no end-of-sequence stop.
    """
    ctx = MODEL_CONFIG.context_length

    tokens = tokenizer.encode(prompt)
    if not tokens:
        # Nothing to condition on: seed generation with token id 0.
        tokens = [0]
    # The model only ever sees the last `ctx` tokens, so keep just that window.
    idx = torch.tensor([tokens[-ctx:]], dtype=torch.long)

    for _ in range(max_tokens):
        logits = model(idx)
        logits_last = logits[0, -1, :].float() / max(temperature, 0.01)

        if 0 < top_k < logits_last.size(0):
            kth_best = torch.topk(logits_last, top_k).values[-1]
            logits_last[logits_last < kth_best] = float("-inf")

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits_last, descending=True)
            probs_sorted = F.softmax(sorted_logits, dim=-1)
            cumprobs = torch.cumsum(probs_sorted, dim=-1)
            # Drop a token only when the mass *before* it already exceeds
            # top_p, so the most likely token always survives.
            cutoff_mask = cumprobs - probs_sorted > top_p
            sorted_logits[cutoff_mask] = float("-inf")
            # Undo the sort: position sorted_indices[i] receives sorted_logits[i].
            logits_last = sorted_logits.scatter(0, sorted_indices, sorted_logits)

        probs = F.softmax(logits_last, dim=-1)
        next_id = torch.multinomial(probs, 1).item()
        idx = torch.cat([idx, torch.tensor([[next_id]])], dim=1)[:, -ctx:]
        # NOTE(review): tokens are decoded individually; if BPETokenizer is
        # byte-level, a multi-byte character split across tokens may not
        # decode cleanly here — confirm against tokenizer.decode.
        yield tokenizer.decode([next_id])


@torch.no_grad()
def generate(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
) -> str:
    """Generate complete text (non-streaming wrapper around generate_streaming)."""
    return "".join(generate_streaming(prompt, max_tokens, temperature, top_k, top_p))


# ═══════════════════════════════════════════════════════════════════
# FastAPI server
# ═══════════════════════════════════════════════════════════════════

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_CREATED_AT = int(time.time())


def _error_response(message: str) -> JSONResponse:
    """Build an OpenAI-style 400 ``invalid_request_error`` response."""
    return JSONResponse(status_code=400, content={
        "error": {"message": message, "type": "invalid_request_error"}
    })


def _parse_param(body: dict, name: str, default, lo, hi, *, integer: bool = False):
    """Read ``body[name]`` as a finite number clamped to ``[lo, hi]``.

    A missing key or JSON ``null`` yields ``default`` (OpenAI clients often
    send explicit nulls for unset options). With ``integer=True`` the clamped
    value is truncated to ``int`` so it is safe for ``range``/``torch.topk``.

    Raises:
        ValueError: the value is a bool, not numeric, or not finite.
    """
    raw = body.get(name)
    if raw is None:
        value = default
    else:
        if isinstance(raw, bool):
            raise ValueError(f"'{name}' must be a number")
        try:
            value = float(raw)
        except (TypeError, ValueError):
            raise ValueError(f"'{name}' must be a number") from None
        if not math.isfinite(value):
            raise ValueError(f"'{name}' must be a finite number")
    value = max(lo, min(hi, value))
    return int(value) if integer else value


def _message_text(content) -> str:
    """Flatten an OpenAI message ``content`` field into plain text.

    Accepts a string, ``None`` (-> ""), or a content-part array, whose
    ``{"type": "text", "text": ...}`` parts (and bare strings) are
    concatenated; non-text parts such as images are ignored.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif (isinstance(part, dict) and part.get("type") == "text"
                  and isinstance(part.get("text"), str)):
                pieces.append(part["text"])
        return "".join(pieces)
    return str(content)


def extract_prompt(messages):
    """Return the text of the most recent user message.

    Falls back to the last message's content when no user turn exists, and to
    "" for an empty message list. Entries that are not objects are ignored.
    """
    if not messages:
        return ""
    dict_messages = [msg for msg in messages if isinstance(msg, dict)]
    for msg in reversed(dict_messages):
        if msg.get("role") == "user":
            return _message_text(msg.get("content"))
    if dict_messages:
        return _message_text(dict_messages[-1].get("content"))
    return ""


@app.get("/")
def health():
    """Describe the served model and the available endpoints."""
    return {
        "name": "SymbioGPT-10M",
        "version": "1.0.0",
        "description": "Multi-organelle GPT trained on classical philosophy — "
                       "CausalConv + Monarch + LongConv + Attention fused via OrganelleGate",
        "architecture": "Decoder-only (4 organelles + OrganelleGate, RoPE, RMSNorm, SwiGLU, "
                        "SkipGate, weight-tied)",
        "model": {
            "d_model": MODEL_CONFIG.d_model,
            "n_layers": MODEL_CONFIG.n_layers,
            "n_heads": MODEL_CONFIG.n_heads,
            "head_dim": MODEL_CONFIG.head_dim,
            "context_length": MODEL_CONFIG.context_length,
            "vocab_size": MODEL_CONFIG.vocab_size,
            "n_monarch_heads": MODEL_CONFIG.n_monarch_heads,
            "params": f"{n_params/1e6:.1f}M",
        },
        "organelles": list(MODEL_CONFIG.organelles),
        "endpoints": ["/v1/models", "/v1/chat/completions"],
        "features": ["streaming", "OpenAI-compatible", "top-k", "top-p"],
        "compatible_with": ["OpenAI API", "OpenRouter"],
    }


@app.get("/v1/models")
def list_models():
    """OpenAI ``/v1/models``: list the single served model."""
    return {
        "object": "list",
        "data": [{
            "id": "symbiogpt-10m",
            "object": "model",
            "created": MODEL_CREATED_AT,
            "owned_by": "symbiogpt",
        }],
    }


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI ``/v1/chat/completions`` (streaming via SSE or a single JSON reply).

    Only the latest user message is used as the prompt. Invalid bodies or
    parameters produce a 400 ``invalid_request_error``.
    """
    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        return _error_response("Invalid JSON")
    if not isinstance(body, dict):
        return _error_response("Request body must be a JSON object")

    messages = body.get("messages") or []
    if not isinstance(messages, list):
        return _error_response("'messages' must be an array")

    try:
        temperature = _parse_param(body, "temperature", 0.8, 0.01, 2.0)
        max_tokens = _parse_param(body, "max_tokens", 200, 1,
                                  MODEL_CONFIG.context_length, integer=True)
        top_k_val = _parse_param(body, "top_k", 40, 0,
                                 MODEL_CONFIG.vocab_size, integer=True)
        top_p_val = _parse_param(body, "top_p", 1.0, 0.0, 1.0)
        n_completions = _parse_param(body, "n", 1, 1, 4, integer=True)
    except ValueError as exc:
        return _error_response(str(exc))
    stream = bool(body.get("stream", False))

    prompt_text = extract_prompt(messages)
    prompt_tokens = len(tokenizer.encode(prompt_text)) if prompt_text else 0
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())

    if stream:
        def sse_event(delta, finish_reason=None, **extra):
            """Serialize one ``chat.completion.chunk`` as an SSE data line."""
            chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": "symbiogpt-10m",
                "choices": [{"index": 0, "delta": delta,
                             "finish_reason": finish_reason}],
                **extra,
            }
            return f"data: {json.dumps(chunk)}\n\n"

        def sse_stream():
            # Sync generator: StreamingResponse iterates it in a worker
            # thread, so token generation does not block the event loop.
            yield sse_event({"role": "assistant", "content": ""})
            token_count = 0
            for token_str in generate_streaming(
                prompt_text, max_tokens=max_tokens, temperature=temperature,
                top_k=top_k_val, top_p=top_p_val,
            ):
                token_count += 1
                yield sse_event({"content": token_str})
            yield sse_event(
                {},
                "length" if token_count >= max_tokens else "stop",
                usage={
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": token_count,
                    "total_tokens": prompt_tokens + token_count,
                },
            )
            yield "data: [DONE]\n\n"

        return StreamingResponse(sse_stream(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache",
                                          "X-Accel-Buffering": "no"})

    def run_completions():
        return [
            generate(prompt_text, max_tokens=max_tokens, temperature=temperature,
                     top_k=top_k_val, top_p=top_p_val)
            for _ in range(n_completions)
        ]

    # Generation is synchronous and CPU-bound; calling it directly inside this
    # async handler would freeze the event loop (and every other request)
    # until all completions finish. Offload it like the streaming path does.
    texts = await run_in_threadpool(run_completions)

    choices = [
        {
            "index": i,
            "message": {"role": "assistant", "content": text},
            "finish_reason": "length",
        }
        for i, text in enumerate(texts)
    ]
    # Exact: generate_streaming has no early stop, so every completion is
    # exactly max_tokens tokens long.
    completion_tokens = max_tokens * n_completions
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": "symbiogpt-10m",
        "choices": choices,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
        "system_fingerprint": "symbiogpt-10m-v1",
    }


if __name__ == "__main__":
    print(f"\nSymbioGPT-10M server starting on 0.0.0.0:{PORT} ...")
    print(f" GET http://localhost:{PORT}/")
    print(f" GET http://localhost:{PORT}/v1/models")
    print(f" POST http://localhost:{PORT}/v1/chat/completions")
    uvicorn.run(app, host="0.0.0.0", port=PORT)