"""SymbioGPT-GrammarExpert — OpenAI-compatible inference server.

SymbioGPT-10M base model with Grammar Expert LoRA adapter merged at startup.
The LoRA was discovered via evolutionary search on CoLA (grammar acceptability).

Downloads base checkpoint + LoRA weights from HuggingFace on first run.
True token-by-token SSE streaming via background thread + queue.
"""

import json as json_mod
import math
import os
import queue
import threading
import time
import uuid
from typing import Optional

import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download

from symbio_model import SymbioConfig, SymbioGPT
from tokenizer import BPETokenizer

# ═══════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════

BASE_REPO = os.environ.get("BASE_REPO", "LisaMegaWatts/SymbioGPT-10M")
LORA_REPO = os.environ.get("LORA_REPO", "LisaMegaWatts/SymbioGPT-GrammarExpert-20260301")
PORT = int(os.environ.get("PORT", "7860"))
CHECKPOINT_FILE = "symbio_best.pt"
LORA_FILE = "lora_weights.pt"

# LoRA config (from metadata.json)
LORA_RANK = 8
LORA_ALPHA = 8.0

MODEL_CONFIG = SymbioConfig(
    d_model=320,
    n_layers=8,
    n_heads=5,
    head_dim=64,
    ffn_mult=4,
    context_length=256,
    vocab_size=2000,
    weight_tying=True,
    organelles=("causal_conv", "monarch", "long_conv", "attention"),
    conv_kernel_size=4,
    n_monarch_heads=1,
    gate_temperature_init=1.0,
    free_energy_beta=0.001,
)

# ═══════════════════════════════════════════════════════════════════
# LoRA merging
# ═══════════════════════════════════════════════════════════════════

# Map LoRA short module names to the base model's module paths.
LORA_KEY_MAP = {
    # block.{i}.attn.{proj} -> blocks.{i}.seq_mixer.organelle_modules.attention.{proj}
    "attn": "seq_mixer.organelle_modules.attention",
    # block.{i}.ffn.{proj} -> blocks.{i}.ffn.{proj}
    "ffn": "ffn",
}

_LORA_SUFFIXES = ((".lora_A", "A"), (".lora_B", "B"))


def _lora_key_to_base_key(lora_key):
    """Map ``block.{i}.{module}.{proj}`` to ``blocks.{i}.{path}.{proj}.weight``.

    Returns None when *lora_key* does not have that shape. Modules absent
    from LORA_KEY_MAP are used verbatim.
    """
    parts = lora_key.split(".")
    if len(parts) < 4 or parts[0] != "block":
        return None
    layer_idx, module, proj = parts[1], parts[2], parts[3]
    mapped_module = LORA_KEY_MAP.get(module, module)
    return f"blocks.{layer_idx}.{mapped_module}.{proj}.weight"


def merge_lora(model, lora_state, alpha, rank):
    """Merge LoRA weights into the base model.

    LoRA formula: W_merged = W_base + (B^T @ A^T) * (alpha / rank)
    where A: (in_features, rank), B: (rank, out_features) as stored.

    Args:
        model: module whose ``state_dict()`` holds the base weights; the merged
            weights are written back with ``load_state_dict``.
        lora_state: mapping of ``"<target>.lora_A"`` / ``"<target>.lora_B"``
            keys to tensors, with targets shaped ``block.{i}.{module}.{proj}``.
        alpha: LoRA alpha.
        rank: LoRA rank; the delta is scaled by ``alpha / rank``.

    Returns:
        Number of (A, B) pairs merged. Incomplete pairs, unrecognised keys and
        targets missing from the base model are skipped with a warning.

    Raises:
        ValueError: if a pair's delta cannot match its base weight's shape
            (e.g. the adapter stores A/B in the other orientation).
    """
    base_state = model.state_dict()
    scaling = alpha / rank
    merged_count = 0

    # Group LoRA tensors into {"A": ..., "B": ...} pairs keyed by target.
    lora_pairs = {}
    for key, tensor in lora_state.items():
        for suffix, slot in _LORA_SUFFIXES:
            if key.endswith(suffix):
                lora_pairs.setdefault(key[: -len(suffix)], {})[slot] = tensor
                break

    for lora_key, pair in lora_pairs.items():
        if "A" not in pair or "B" not in pair:
            print(f"  WARNING: incomplete LoRA pair for {lora_key}")
            continue

        base_weight_key = _lora_key_to_base_key(lora_key)
        if base_weight_key is None:
            print(f"  WARNING: unexpected LoRA key format: {lora_key}")
            continue
        if base_weight_key not in base_state:
            print(f"  WARNING: base key not found: {base_weight_key}")
            continue

        base_weight = base_state[base_weight_key]
        A = pair["A"].float()  # (in_features, rank)
        B = pair["B"].float()  # (rank, out_features)

        # Fail with a readable message instead of an opaque matmul/broadcast error.
        if (
            A.dim() != 2
            or B.dim() != 2
            or A.shape[1] != B.shape[0]
            or (B.shape[1], A.shape[0]) != tuple(base_weight.shape)
        ):
            raise ValueError(
                f"LoRA pair {lora_key}: A{tuple(A.shape)} / B{tuple(B.shape)} "
                f"does not produce a delta matching {base_weight_key}"
                f"{tuple(base_weight.shape)}"
            )

        # delta_W = B^T @ A^T = (out, rank) @ (rank, in) = (out, in)
        delta_W = B.T @ A.T
        base_state[base_weight_key] = (
            base_weight.float() + delta_W * scaling
        ).to(base_weight.dtype)
        merged_count += 1

    model.load_state_dict(base_state)
    return merged_count


# ═══════════════════════════════════════════════════════════════════
# Load model and tokenizer
# ═══════════════════════════════════════════════════════════════════

print(f"Downloading base model from {BASE_REPO} ...")
ckpt_path = hf_hub_download(repo_id=BASE_REPO, filename=CHECKPOINT_FILE)
vocab_path = hf_hub_download(repo_id=BASE_REPO, filename="vocab.json")
merges_path = hf_hub_download(repo_id=BASE_REPO, filename="merges.txt")

print(f"Downloading LoRA from {LORA_REPO} ...")
lora_path = hf_hub_download(repo_id=LORA_REPO, filename=LORA_FILE)

print("Loading tokenizer ...")
tokenizer = BPETokenizer.from_files(vocab_path, merges_path)
print(f"  BPE vocab_size = {tokenizer.vocab_size}")

if tokenizer.vocab_size != MODEL_CONFIG.vocab_size:
    print(f"  Adjusting model vocab_size: {MODEL_CONFIG.vocab_size} -> {tokenizer.vocab_size}")
    MODEL_CONFIG.vocab_size = tokenizer.vocab_size

print("Loading base model ...")
model = SymbioGPT(MODEL_CONFIG)
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
if "model_state_dict" in checkpoint:
    state_dict = checkpoint["model_state_dict"]
elif "state_dict" in checkpoint:
    state_dict = checkpoint["state_dict"]
else:
    state_dict = checkpoint
# Strip the "_orig_mod." prefix (presumably left by saving a torch.compile'd model).
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)

print("Merging LoRA weights ...")
lora_state = torch.load(lora_path, map_location="cpu", weights_only=True)
n_merged = merge_lora(model, lora_state, LORA_ALPHA, LORA_RANK)
print(f"  Merged {n_merged} LoRA weight pairs (rank={LORA_RANK}, alpha={LORA_ALPHA})")
if n_merged == 0:
    # Otherwise the server would silently serve the plain base model.
    print("  WARNING: no LoRA weights were merged; serving the base model only")

model.eval()
n_params = sum(p.numel() for p in model.parameters())
print(f"  Model ready: {n_params/1e6:.1f}M params (base + LoRA merged)")

# ═══════════════════════════════════════════════════════════════════
# Generation
# ═══════════════════════════════════════════════════════════════════

_SENTINEL = object()  # marks end of generation


@torch.no_grad()
def generate(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
    token_queue: Optional[queue.Queue] = None,
    stop_event: Optional[threading.Event] = None,
) -> str:
    """Sample a continuation of *prompt* and return it as text.

    Exactly ``max_tokens`` tokens are sampled (there is no end-of-text stop)
    unless *stop_event* is set, in which case sampling stops early.

    Args:
        prompt: prompt text; an empty encoding is replaced by token id 0.
        max_tokens: number of tokens to sample.
        temperature: softmax temperature, floored at 0.01.
        top_k: keep only logits >= the k-th largest (0 or >= vocab disables).
        top_p: nucleus threshold; 1.0 disables.
        token_queue: if given, each sampled token is decoded on its own and
            pushed to the queue as it is generated. ``_SENTINEL`` is always
            pushed last, even if generation raises, so a consumer never blocks
            forever. NOTE(review): per-token decoding may not concatenate to
            exactly the returned string for every tokenizer — confirm.
        stop_event: if given and set, sampling stops before the next token.

    Returns:
        The decoded text of all sampled tokens.
    """
    try:
        tokens = tokenizer.encode(prompt)
        if not tokens:
            tokens = [0]
        idx = torch.tensor([tokens], dtype=torch.long)
        temp = temperature if temperature > 0.01 else 0.01

        generated_ids = []
        for _ in range(max_tokens):
            if stop_event is not None and stop_event.is_set():
                break

            # Keep only the last context_length tokens.
            idx_cond = idx[:, -MODEL_CONFIG.context_length:]
            logits_last = model(idx_cond)[0, -1, :].float() / temp

            if 0 < top_k < logits_last.size(0):
                threshold = torch.topk(logits_last, top_k).values[-1]
                logits_last[logits_last < threshold] = float("-inf")

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits_last, descending=True)
                probs_sorted = F.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(probs_sorted, dim=-1)
                # Drop a token once the mass *before* it already exceeds top_p;
                # the most likely token is therefore always kept.
                cutoff_mask = cumprobs - probs_sorted > top_p
                sorted_logits[cutoff_mask] = float("-inf")
                # Undo the sort: sorted_indices is a permutation, so every slot is written.
                logits_last = sorted_logits.scatter(0, sorted_indices, sorted_logits)

            probs = F.softmax(logits_last, dim=-1)
            next_id = torch.multinomial(probs, 1).item()
            generated_ids.append(next_id)
            idx = torch.cat([idx, torch.tensor([[next_id]])], dim=1)

            if token_queue is not None:
                token_queue.put(tokenizer.decode([next_id]))

        return tokenizer.decode(generated_ids)
    finally:
        # Always terminate the stream so the SSE consumer's q.get() cannot hang.
        if token_queue is not None:
            token_queue.put(_SENTINEL)


# ═══════════════════════════════════════════════════════════════════
# FastAPI server
# ═══════════════════════════════════════════════════════════════════

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_CREATED_AT = int(time.time())
MODEL_ID = "symbiogpt-grammar-expert"


def _content_to_text(content):
    """Flatten an OpenAI message ``content`` to plain text.

    Accepts a string, ``None`` (-> ""), or a list of content parts, where
    strings and ``{"type": "text", "text": ...}`` parts are kept and joined
    with newlines. Any other value is converted with ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(str(part.get("text", "")))
        return "\n".join(texts)
    return str(content)


def extract_prompt(messages):
    """Return the text of the last user message, else of the last message."""
    if not messages:
        return ""
    for msg in reversed(messages):
        if isinstance(msg, dict) and msg.get("role") == "user":
            return _content_to_text(msg.get("content", ""))
    last = messages[-1]
    return _content_to_text(last.get("content", "")) if isinstance(last, dict) else ""


@app.get("/")
def health():
    """Describe the server, model configuration and endpoints."""
    return {
        "name": "SymbioGPT-GrammarExpert",
        "version": "1.1.0",
        "description": "SymbioGPT-10M + Grammar Expert LoRA (evolved on CoLA)",
        "architecture": "4-organelle decoder (CausalConv + Monarch + LongConv + Attention) "
                        "+ OrganelleGate + LoRA (rank=8, attn+ffn)",
        "model": {
            "d_model": MODEL_CONFIG.d_model,
            "n_layers": MODEL_CONFIG.n_layers,
            "n_heads": MODEL_CONFIG.n_heads,
            "context_length": MODEL_CONFIG.context_length,
            "vocab_size": MODEL_CONFIG.vocab_size,
            "params": f"{n_params/1e6:.1f}M",
            "lora_rank": LORA_RANK,
        },
        "organelles": list(MODEL_CONFIG.organelles),
        "endpoints": ["/v1/models", "/v1/chat/completions"],
        "features": ["streaming", "OpenAI-compatible", "top-k", "top-p", "grammar-expert-lora"],
    }


@app.get("/v1/models")
def list_models():
    """OpenAI-style model listing with the single served model."""
    return {
        "object": "list",
        "data": [{
            "id": MODEL_ID,
            "object": "model",
            "created": MODEL_CREATED_AT,
            "owned_by": "symbiogpt",
        }],
    }


class _BadRequest(ValueError):
    """An invalid request field; reported to the client as HTTP 400."""


def _bad_request(message):
    """Build an OpenAI-style 400 error response."""
    return JSONResponse(status_code=400, content={
        "error": {"message": message, "type": "invalid_request_error"}
    })


def _numeric_param(body, key, default, lo, hi, cast):
    """Read ``body[key]`` as a number, cast it and clamp it to [lo, hi].

    A missing key or JSON ``null`` yields *default*. Booleans, non-numbers
    and non-finite values raise _BadRequest.
    """
    value = body.get(key)
    if value is None:
        value = default
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise _BadRequest(f"'{key}' must be a number")
    try:
        finite = math.isfinite(value)
    except OverflowError:  # int too large to convert to float
        finite = False
    if not finite:
        raise _BadRequest(f"'{key}' must be a finite number")
    return max(lo, min(hi, cast(value)))


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completion, streamed (SSE) or as one response."""
    try:
        body = await request.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError
        return _bad_request("Invalid JSON")
    if not isinstance(body, dict):
        return _bad_request("Request body must be a JSON object")

    try:
        temperature = _numeric_param(body, "temperature", 0.8, 0.01, 2.0, float)
        max_tokens = _numeric_param(body, "max_tokens", 200, 1, MODEL_CONFIG.context_length, int)
        top_k_val = _numeric_param(body, "top_k", 40, 0, MODEL_CONFIG.vocab_size, int)
        top_p_val = _numeric_param(body, "top_p", 1.0, 0.0, 1.0, float)
    except _BadRequest as err:
        return _bad_request(str(err))

    stream = bool(body.get("stream", False))
    messages = body.get("messages") or []
    if not isinstance(messages, list):
        return _bad_request("'messages' must be a list")

    prompt_text = extract_prompt(messages)
    prompt_tokens = len(tokenizer.encode(prompt_text)) if prompt_text else 0

    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())

    if stream:
        def sse_stream():
            def sse(payload):
                return f"data: {json_mod.dumps(payload)}\n\n"

            def chunk(delta, finish_reason=None):
                return {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": MODEL_ID,
                    "choices": [{"index": 0, "delta": delta,
                                 "finish_reason": finish_reason}],
                }

            # Initial chunk with role
            yield sse(chunk({"role": "assistant", "content": ""}))

            q = queue.Queue()
            stop = threading.Event()
            errors = []

            def worker():
                try:
                    generate(
                        prompt_text,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_k=top_k_val,
                        top_p=top_p_val,
                        token_queue=q,
                        stop_event=stop,
                    )
                except Exception as err:  # reported to the client below
                    print(f"  ERROR: generation failed: {err!r}")
                    errors.append(err)

            # Generate in a background thread; tokens arrive through q.
            gen_thread = threading.Thread(target=worker, daemon=True)
            gen_thread.start()

            token_count = 0
            try:
                while True:
                    tok = q.get()
                    if tok is _SENTINEL:
                        break
                    token_count += 1
                    yield sse(chunk({"content": tok}))
            finally:
                # Also runs if this iterator is closed early (e.g. the client went
                # away): tell the worker to stop sampling instead of running on.
                stop.set()

            # Join before reading `errors`: the sentinel is queued before the
            # worker records a failure.
            gen_thread.join(timeout=5.0)
            if errors:
                yield sse({"error": {"message": "Generation failed", "type": "server_error"}})
                yield "data: [DONE]\n\n"
                return

            # Final chunk
            finish = chunk(
                {}, "length" if token_count >= max_tokens else "stop"
            )
            finish["usage"] = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": token_count,
                "total_tokens": prompt_tokens + token_count,
            }
            yield sse(finish)
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            sse_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    # CPU-bound: run off the event loop so other requests keep being served.
    text = await run_in_threadpool(
        generate,
        prompt_text,
        max_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k_val,
        top_p=top_p_val,
    )
    # generate() samples exactly max_tokens tokens when no stop_event is given;
    # re-encoding the decoded text could report a different count.
    completion_tokens = max_tokens
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": MODEL_ID,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": "length",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
        "system_fingerprint": "symbiogpt-grammar-expert-v1",
    }


if __name__ == "__main__":
    print(f"\nSymbioGPT-GrammarExpert server starting on 0.0.0.0:{PORT} ...")
    print(f"  GET  http://localhost:{PORT}/")
    print(f"  GET  http://localhost:{PORT}/v1/models")
    print(f"  POST http://localhost:{PORT}/v1/chat/completions")
    uvicorn.run(app, host="0.0.0.0", port=PORT)