"""SymbioSLM-GrammarExpert — OpenAI-compatible inference server. SymbioSLM (5.1M) with progressive-unfreeze LoRA training on CoLA grammar task. Uses LoRA-merged weights (no LoRA wrapper needed at inference). True token-by-token SSE streaming via background thread + queue. """ import json as json_mod import math import os import queue import threading import time import uuid import torch import torch.nn.functional as F import uvicorn from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from huggingface_hub import hf_hub_download from symbio_model import SymbioConfig, SymbioGPT from tokenizer import BPETokenizer # ═══════════════════════════════════════════════════════════════════ # Configuration # ═══════════════════════════════════════════════════════════════════ BASE_REPO = os.environ.get("BASE_REPO", "LisaMegaWatts/SymbioSLM") MODEL_REPO = os.environ.get("MODEL_REPO", "LisaMegaWatts/SymbioSLM-GrammarExpert-Progressive-20260301") PORT = int(os.environ.get("PORT", "7860")) MODEL_FILE = "merged_model_state.pt" # SymbioSLM config — 3 organelles (no attention), matches pretrained JLD2 checkpoint MODEL_CONFIG = SymbioConfig( d_model=256, n_layers=8, n_heads=4, head_dim=64, ffn_mult=4, context_length=256, vocab_size=2000, weight_tying=True, organelles=("causal_conv", "monarch", "long_conv"), conv_kernel_size=4, n_monarch_heads=1, gate_temperature_init=1.0, free_energy_beta=0.001, ) # ═══════════════════════════════════════════════════════════════════ # Load model and tokenizer # ═══════════════════════════════════════════════════════════════════ print(f"Downloading tokenizer from {BASE_REPO} ...") vocab_path = hf_hub_download(repo_id=BASE_REPO, filename="vocab.json") merges_path = hf_hub_download(repo_id=BASE_REPO, filename="merges.txt") print(f"Downloading trained model from {MODEL_REPO} ...") model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE) 
print("Loading tokenizer ...")
tokenizer = BPETokenizer.from_files(vocab_path, merges_path)
print(f" BPE vocab_size = {tokenizer.vocab_size}")

# Keep the model's vocab_size consistent with the tokenizer actually loaded;
# the value hard-coded in MODEL_CONFIG is overridden when they disagree.
if tokenizer.vocab_size != MODEL_CONFIG.vocab_size:
    print(f" Adjusting model vocab_size: {MODEL_CONFIG.vocab_size} -> {tokenizer.vocab_size}")
    MODEL_CONFIG.vocab_size = tokenizer.vocab_size

print("Loading model ...")
model = SymbioGPT(MODEL_CONFIG)
# Load LoRA-merged state (LoRA A@B baked into base weights, no wrapper needed).
# weights_only=True restricts torch.load's unpickling to tensors and plain containers.
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict, strict=True)
print(" Loaded merged weights (LoRA baked in)")
model.eval()

n_params = sum(p.numel() for p in model.parameters())
print(f" Model ready: {n_params/1e6:.1f}M params")
print(f" Config: d={MODEL_CONFIG.d_model}, L={MODEL_CONFIG.n_layers}, "
      f"monarch_heads={MODEL_CONFIG.n_monarch_heads}")
print(f" Organelles: {MODEL_CONFIG.organelles}")


# ═══════════════════════════════════════════════════════════════════
# Generation
# ═══════════════════════════════════════════════════════════════════

# End-of-stream marker for token queues; compared by identity, never a real token.
_SENTINEL = object()


def _sample_next(logits_last, temperature, top_k, top_p):
    """Sample one token id from a 1-D float logits vector.

    Applies temperature scaling (temperatures at or below 0.01 are treated as
    0.01), then optional top-k filtering (``0 < top_k < vocab``), then optional
    nucleus filtering (``top_p < 1.0``), and finally draws from the softmax.
    """
    logits_last = logits_last / (temperature if temperature > 0.01 else 0.01)

    if 0 < top_k < logits_last.size(0):
        threshold = torch.topk(logits_last, top_k).values[-1]
        # Ties with the k-th largest value survive, so slightly more than k may remain.
        logits_last = logits_last.masked_fill(logits_last < threshold, float("-inf"))

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits_last, descending=True)
        probs_sorted = F.softmax(sorted_logits, dim=-1)
        # Mass strictly before each position: the most likely token is always kept.
        mass_before = torch.cumsum(probs_sorted, dim=-1) - probs_sorted
        sorted_logits = sorted_logits.masked_fill(mass_before > top_p, float("-inf"))
        # Undo the sort: original position sorted_indices[i] gets sorted_logits[i].
        logits_last = torch.empty_like(sorted_logits).scatter_(0, sorted_indices, sorted_logits)

    probs = F.softmax(logits_last, dim=-1)
    return int(torch.multinomial(probs, 1).item())


def _text_delta(text, emitted):
    """Return the part of ``text`` that has not been streamed yet.

    ``emitted`` is the text already sent. Normally ``text`` extends it, so the
    new suffix is returned. If the decoder rewrote earlier output (for example
    whitespace clean-up), sent text cannot be retracted. In that case streaming
    resumes after the longest common prefix as a best effort.
    """
    if text.startswith(emitted):
        return text[len(emitted):]
    return text[len(os.path.commonprefix([text, emitted])):]


@torch.no_grad()
def _generate_ids(prompt, max_tokens, temperature, top_k, top_p, token_queue=None):
    """Autoregressively sample exactly ``max_tokens`` token ids after ``prompt``.

    Returns the list of generated ids (prompt excluded). An empty encoding of
    the prompt is replaced by token id 0 so the model always has context.

    If ``token_queue`` is given, newly decoded text is pushed as it becomes
    available. The pushes are diffs of the decoded full sequence, so the
    concatenation equals ``tokenizer.decode(ids)``. ``_SENTINEL`` is ALWAYS
    pushed last, even if generation raises, so a consumer blocked on
    ``queue.get()`` can never hang.
    """
    generated_ids = []
    emitted = ""
    try:
        tokens = tokenizer.encode(prompt)
        if not tokens:
            tokens = [0]
        idx = torch.tensor([tokens], dtype=torch.long)

        for _ in range(max_tokens):
            # Sliding window: the model only ever sees the last context_length ids.
            logits = model(idx[:, -MODEL_CONFIG.context_length:])
            next_id = _sample_next(logits[0, -1, :].float(), temperature, top_k, top_p)
            generated_ids.append(next_id)
            idx = torch.cat([idx, torch.tensor([[next_id]], dtype=torch.long)], dim=1)

            if token_queue is not None:
                text = tokenizer.decode(generated_ids)
                # Hold back output ending in U+FFFD: presumably a multi-byte
                # character split across tokens. This assumes the decoder uses
                # errors="replace" — TODO confirm. The held-back text is flushed
                # on a later step or after the loop.
                if not text.endswith("\ufffd"):
                    delta = _text_delta(text, emitted)
                    if delta:
                        token_queue.put(delta)
                    emitted = text

        if token_queue is not None and generated_ids:
            delta = _text_delta(tokenizer.decode(generated_ids), emitted)
            if delta:
                token_queue.put(delta)
        return generated_ids
    finally:
        if token_queue is not None:
            token_queue.put(_SENTINEL)


def generate(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
    token_queue: queue.Queue = None,
) -> str:
    """Generate ``max_tokens`` tokens after ``prompt`` and return the decoded text.

    If ``token_queue`` is given, text pieces are pushed to it while generating,
    followed by ``_SENTINEL``, which is pushed even on error. Pieces are not
    necessarily one per token; see ``_generate_ids``.
    """
    ids = _generate_ids(prompt, max_tokens, temperature, top_k, top_p, token_queue)
    return tokenizer.decode(ids)


# ═══════════════════════════════════════════════════════════════════
# FastAPI server
# ═══════════════════════════════════════════════════════════════════

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_CREATED_AT = int(time.time())
MODEL_ID = "symbioslm-grammar-expert"


def _error_response(message, status_code=400):
    """Build an OpenAI-style ``invalid_request_error`` JSON response."""
    return JSONResponse(status_code=status_code, content={
        "error": {"message": message, "type": "invalid_request_error"}
    })


def _sse(payload):
    """Format ``payload`` as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json_mod.dumps(payload)}\n\n"


def _read_number(body, key, default, lo, hi, integer=False):
    """Return numeric request field ``key`` clamped to ``[lo, hi]``.

    A missing or JSON-null field yields ``default``. Raises ValueError for
    non-numeric values (booleans included), non-finite floats, and, when
    ``integer`` is true, non-integral floats. Integral floats such as ``100.0``
    are converted to ``int`` because the value feeds ``range()``/``torch.topk``.
    """
    value = body.get(key)
    if value is None:
        value = default
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ValueError(f"'{key}' must be a number")
    if isinstance(value, float):
        if not math.isfinite(value):
            raise ValueError(f"'{key}' must be a finite number")
        if integer:
            if not value.is_integer():
                raise ValueError(f"'{key}' must be an integer")
            value = int(value)
    return max(lo, min(hi, value))


def _content_text(content):
    """Return the text of an OpenAI message ``content`` value.

    Strings are returned unchanged and ``None`` becomes "". A list of content
    parts yields its ``{"type": "text"}`` parts joined by newlines. Any other
    type yields "".
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "\n".join(
            part["text"]
            for part in content
            if isinstance(part, dict)
            and part.get("type") == "text"
            and isinstance(part.get("text"), str)
        )
    return ""


def extract_prompt(messages):
    """Return the text of the most recent user message.

    Falls back to the last message when there is no user message. Returns ""
    when ``messages`` is empty, is not a list, or contains no dict entries.
    """
    if not isinstance(messages, list):
        return ""
    entries = [msg for msg in messages if isinstance(msg, dict)]
    if not entries:
        return ""
    chosen = next(
        (msg for msg in reversed(entries) if msg.get("role") == "user"),
        entries[-1],
    )
    return _content_text(chosen.get("content"))


@app.get("/")
def health():
    return {
        "name": "SymbioSLM-GrammarExpert",
        "version": "1.1.0",
        "description": "SymbioSLM (5.1M) with progressive-unfreeze LoRA on CoLA grammar",
        "architecture": "3-organelle decoder (CausalConv + Monarch + LongConv) "
                        "+ OrganelleGate, progressive unfreeze LoRA (rank=30, w1+v+w2, merged)",
        "model": {
            "d_model": MODEL_CONFIG.d_model,
            "n_layers": MODEL_CONFIG.n_layers,
            "n_monarch_heads": MODEL_CONFIG.n_monarch_heads,
            "context_length": MODEL_CONFIG.context_length,
            "vocab_size": MODEL_CONFIG.vocab_size,
            "params": f"{n_params/1e6:.1f}M",
        },
        "organelles": list(MODEL_CONFIG.organelles),
        "endpoints": ["/v1/models", "/v1/chat/completions"],
        "features": ["streaming", "OpenAI-compatible", "top-k", "top-p", "grammar-expert"],
    }


@app.get("/v1/models")
def list_models():
    return {
        "object": "list",
        "data": [{
            "id": MODEL_ID,
            "object": "model",
            "created": MODEL_CREATED_AT,
            "owned_by": "symbioslm",
        }],
    }


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completion, streamed over SSE or returned as one JSON body.

    Only the most recent user message is used as the prompt. Sampling
    parameters are validated and clamped, and invalid values produce a 400
    ``invalid_request_error``.
    """
    import asyncio  # function-local so the module's import block stays untouched

    try:
        body = await request.json()
    except Exception:
        return _error_response("Invalid JSON")
    if not isinstance(body, dict):
        return _error_response("Request body must be a JSON object")

    try:
        temperature = _read_number(body, "temperature", 0.8, 0.01, 2.0)
        max_tokens = _read_number(body, "max_tokens", 200, 1, MODEL_CONFIG.context_length,
                                  integer=True)
        top_k_val = _read_number(body, "top_k", 40, 0, MODEL_CONFIG.vocab_size, integer=True)
        top_p_val = _read_number(body, "top_p", 1.0, 0.0, 1.0)
    except ValueError as exc:
        return _error_response(str(exc))

    stream = body.get("stream", False)
    prompt_text = extract_prompt(body.get("messages", []))
    prompt_tokens = len(tokenizer.encode(prompt_text)) if prompt_text else 0

    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())
    gen_kwargs = {
        "prompt": prompt_text,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_k": top_k_val,
        "top_p": top_p_val,
    }

    def usage(completion_tokens):
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }

    def finish_reason(n_generated):
        # _generate_ids has no early stop, so "stop" only appears if fewer ids came back.
        return "length" if n_generated >= max_tokens else "stop"

    if stream:
        def chunk(delta, finish=None):
            return {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": delta, "finish_reason": finish}],
            }

        def sse_stream():
            # NOTE(review): StreamingResponse (Starlette) is expected to iterate
            # sync generators in a threadpool, so the blocking q.get() below
            # should not stall the event loop.
            yield _sse(chunk({"role": "assistant", "content": ""}))

            q = queue.Queue()
            outcome = {}

            def worker():
                try:
                    outcome["ids"] = _generate_ids(token_queue=q, **gen_kwargs)
                except Exception as exc:  # reported to the client after the sentinel
                    outcome["error"] = exc

            gen_thread = threading.Thread(target=worker, daemon=True)
            gen_thread.start()

            while True:
                piece = q.get()
                if piece is _SENTINEL:  # always queued, even when generation fails
                    break
                yield _sse(chunk({"content": piece}))

            # Wait for the worker to record its outcome after queuing the sentinel.
            gen_thread.join(timeout=5.0)
            ids = outcome.get("ids")
            if ids is None:
                print(f"Generation failed: {outcome.get('error')!r}")
                yield _sse({"error": {"message": "Generation failed", "type": "server_error"}})
            else:
                final = chunk({}, finish_reason(len(ids)))
                final["usage"] = usage(len(ids))
                yield _sse(final)
            yield "data: [DONE]\n\n"

        return StreamingResponse(sse_stream(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

    # Generation is CPU-bound: run it in the default executor so the event loop
    # keeps serving other requests (health checks, other streams) meanwhile.
    loop = asyncio.get_running_loop()
    ids = await loop.run_in_executor(None, lambda: _generate_ids(**gen_kwargs))
    text = tokenizer.decode(ids)
    completion_tokens = len(ids)  # actual sampled tokens, not a re-encoding of the text
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": MODEL_ID,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": finish_reason(completion_tokens),
        }],
        "usage": usage(completion_tokens),
        "system_fingerprint": "symbioslm-grammar-expert-v1",
    }


if __name__ == "__main__":
    print(f"\nSymbioSLM-GrammarExpert server starting on 0.0.0.0:{PORT} ...")
    print(f" GET http://localhost:{PORT}/")
    print(f" GET http://localhost:{PORT}/v1/models")
    print(f" POST http://localhost:{PORT}/v1/chat/completions")
    uvicorn.run(app, host="0.0.0.0", port=PORT)