"""SymbioSLM-GrammarExpert — OpenAI-compatible inference server.

SymbioSLM (5.1M) with progressive-unfreeze LoRA training on the CoLA grammar task.
Uses merged_model_state.pt — LoRA baked into the base weights (fixes previously
missing FFN weights), so no LoRA wrapper is needed at inference.
True token-by-token SSE streaming via background thread + queue.
"""
# Standard library
import json as json_mod
import math
import os
import queue
import threading
import time
import uuid

# Third-party
import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download

# Local
from symbio_model import SymbioConfig, SymbioGPT
from tokenizer import BPETokenizer
# ─────────────────────────────────────────────────────────────────────
# Configuration
# ─────────────────────────────────────────────────────────────────────
BASE_REPO = os.environ.get("BASE_REPO", "LisaMegaWatts/SymbioSLM")
MODEL_REPO = os.environ.get(
    "MODEL_REPO", "LisaMegaWatts/SymbioSLM-GrammarExpert-Progressive-20260301"
)
PORT = int(os.environ.get("PORT", "7860"))
# Merged checkpoint: LoRA A@B is already folded into the base weights.
MODEL_FILE = "merged_model_state.pt"

# SymbioSLM config — 3 organelles (no attention), matches pretrained JLD2 checkpoint.
MODEL_CONFIG = SymbioConfig(
    d_model=256,
    n_layers=8,
    n_heads=4,
    head_dim=64,
    ffn_mult=4,
    context_length=256,
    vocab_size=2000,  # adjusted below if the tokenizer's vocab size differs
    weight_tying=True,
    organelles=("causal_conv", "monarch", "long_conv"),
    conv_kernel_size=4,
    n_monarch_heads=1,
    gate_temperature_init=1.0,
    free_energy_beta=0.001,
)
# ─────────────────────────────────────────────────────────────────────
# Load model and tokenizer (runs once at import time)
# ─────────────────────────────────────────────────────────────────────
print(f"Downloading tokenizer from {BASE_REPO} ...")
vocab_path = hf_hub_download(repo_id=BASE_REPO, filename="vocab.json")
merges_path = hf_hub_download(repo_id=BASE_REPO, filename="merges.txt")

print(f"Downloading trained model from {MODEL_REPO} ...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)

print("Loading tokenizer ...")
tokenizer = BPETokenizer.from_files(vocab_path, merges_path)
print(f" BPE vocab_size = {tokenizer.vocab_size}")
# Keep the embedding table in sync with the actual tokenizer vocabulary.
if tokenizer.vocab_size != MODEL_CONFIG.vocab_size:
    print(f" Adjusting model vocab_size: {MODEL_CONFIG.vocab_size} -> {tokenizer.vocab_size}")
    MODEL_CONFIG.vocab_size = tokenizer.vocab_size

print("Loading model ...")
model = SymbioGPT(MODEL_CONFIG)
# Load LoRA-merged state (LoRA A@B baked into base weights, no wrapper needed).
# weights_only=True avoids arbitrary-code-execution via pickled checkpoints.
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict, strict=True)
print(" Loaded merged weights (LoRA baked in)")
model.eval()  # disable dropout etc. for inference

n_params = sum(p.numel() for p in model.parameters())
print(f" Model ready: {n_params/1e6:.1f}M params")
print(f" Config: d={MODEL_CONFIG.d_model}, L={MODEL_CONFIG.n_layers}, "
      f"monarch_heads={MODEL_CONFIG.n_monarch_heads}")
print(f" Organelles: {MODEL_CONFIG.organelles}")
# ─────────────────────────────────────────────────────────────────────
# Generation
# ─────────────────────────────────────────────────────────────────────
# Sentinel posted to the token queue to signal end-of-generation.
_SENTINEL = object()


def generate(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
    token_queue: "queue.Queue | None" = None,
) -> str:
    """Autoregressively sample up to ``max_tokens`` tokens after ``prompt``.

    Args:
        prompt: Conditioning text, encoded with the module-level tokenizer.
        max_tokens: Number of tokens to sample (there is no EOS early-stop).
        temperature: Softmax temperature; floored at 0.01 to avoid div-by-~0.
        top_k: Keep only the k largest logits (0 disables the filter).
        top_p: Nucleus-sampling threshold (1.0 disables the filter).
        token_queue: Optional queue that receives each decoded token as it is
            sampled, followed by ``_SENTINEL`` when generation ends — the
            sentinel is posted even if the model raises, so SSE consumers
            never block forever.

    Returns:
        The decoded completion (generated tokens only, prompt excluded).
    """
    try:
        tokens = tokenizer.encode(prompt)
        if not tokens:
            tokens = [0]  # avoid an empty context tensor
        idx = torch.tensor([tokens], dtype=torch.long)
        generated_ids = []
        # Inference only — don't build autograd graphs for every step.
        with torch.no_grad():
            for _ in range(max_tokens):
                # Crop the context to the model's maximum window.
                idx_cond = idx[:, -MODEL_CONFIG.context_length:]
                logits = model(idx_cond)
                logits_last = logits[0, -1, :].float()
                # Temperature scaling, floored at 0.01 (same as the old
                # if/else branch: t<=0.01 divides by 0.01).
                logits_last = logits_last / max(temperature, 0.01)
                # Top-k: mask everything below the k-th largest logit.
                if 0 < top_k < logits_last.size(0):
                    threshold = torch.topk(logits_last, top_k).values[-1]
                    logits_last[logits_last < threshold] = float("-inf")
                # Top-p (nucleus): drop the tail once cumulative probability
                # exceeds top_p; the shifted cumsum keeps the first token that
                # crosses the threshold.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits_last, descending=True)
                    probs_sorted = F.softmax(sorted_logits, dim=-1)
                    cumprobs = torch.cumsum(probs_sorted, dim=-1)
                    cutoff_mask = cumprobs - probs_sorted > top_p
                    sorted_logits[cutoff_mask] = float("-inf")
                    # Un-sort back to vocabulary order. sorted_indices is a
                    # full permutation, so the scatter overwrites every slot.
                    logits_last = sorted_logits.scatter(0, sorted_indices, sorted_logits)
                probs = F.softmax(logits_last, dim=-1)
                next_id = torch.multinomial(probs, 1).item()
                generated_ids.append(next_id)
                idx = torch.cat([idx, torch.tensor([[next_id]])], dim=1)
                if token_queue is not None:
                    token_queue.put(tokenizer.decode([next_id]))
        return tokenizer.decode(generated_ids)
    finally:
        # Always signal completion — even on error — otherwise the SSE
        # consumer would block forever on queue.get().
        if token_queue is not None:
            token_queue.put(_SENTINEL)
# ─────────────────────────────────────────────────────────────────────
# FastAPI server
# ─────────────────────────────────────────────────────────────────────
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # public demo endpoint: allow any origin
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_CREATED_AT = int(time.time())  # reported in /v1/models
MODEL_ID = "symbioslm-grammar-expert"
def extract_prompt(messages):
    """Pick the text to condition generation on.

    Prefers the most recent message with role "user"; if none exists,
    falls back to the content of the last message, and returns "" for an
    empty message list.
    """
    if not messages:
        return ""
    latest_user = (
        m.get("content", "") for m in reversed(messages) if m.get("role") == "user"
    )
    return next(latest_user, messages[-1].get("content", ""))
@app.get("/")
def health():
    """Service info / health endpoint (GET /).

    The startup banner advertises this route; the decorator registering it
    was missing, so the endpoint 404'd.
    """
    return {
        "name": "SymbioSLM-GrammarExpert",
        "version": "1.1.0",
        "description": "SymbioSLM (5.1M) with progressive-unfreeze LoRA on CoLA grammar",
        "architecture": "3-organelle decoder (CausalConv + Monarch + LongConv) "
                        "+ OrganelleGate, progressive unfreeze LoRA (rank=30, w1+v+w2, merged)",
        "model": {
            "d_model": MODEL_CONFIG.d_model,
            "n_layers": MODEL_CONFIG.n_layers,
            "n_monarch_heads": MODEL_CONFIG.n_monarch_heads,
            "context_length": MODEL_CONFIG.context_length,
            "vocab_size": MODEL_CONFIG.vocab_size,
            "params": f"{n_params/1e6:.1f}M",
        },
        "organelles": list(MODEL_CONFIG.organelles),
        "endpoints": ["/v1/models", "/v1/chat/completions"],
        "features": ["streaming", "OpenAI-compatible", "top-k", "top-p", "grammar-expert"],
    }
@app.get("/v1/models")
def list_models():
    """OpenAI-compatible model listing (GET /v1/models).

    The route decorator was missing even though the startup banner
    advertises this endpoint.
    """
    return {
        "object": "list",
        "data": [{
            "id": MODEL_ID,
            "object": "model",
            "created": MODEL_CREATED_AT,
            "owned_by": "symbioslm",
        }],
    }
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completion endpoint (POST /v1/chat/completions).

    Supports streaming (SSE chunks) and non-streaming responses. Sampling
    parameters from the request body are clamped to server-safe ranges.
    The route decorator was missing even though the startup banner
    advertises this endpoint.
    """
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(status_code=400, content={
            "error": {"message": "Invalid JSON", "type": "invalid_request_error"}
        })

    # Clamp sampling knobs to ranges the model can handle.
    temperature = max(0.01, min(2.0, body.get("temperature", 0.8)))
    max_tokens = max(1, min(MODEL_CONFIG.context_length, body.get("max_tokens", 200)))
    top_k_val = max(0, min(MODEL_CONFIG.vocab_size, body.get("top_k", 40)))
    top_p_val = max(0.0, min(1.0, body.get("top_p", 1.0)))
    stream = body.get("stream", False)
    messages = body.get("messages", [])

    prompt_text = extract_prompt(messages)
    prompt_tokens = len(tokenizer.encode(prompt_text)) if prompt_text else 0
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())

    if stream:
        def sse_stream():
            # First chunk announces the assistant role with empty content.
            initial = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
            }
            yield f"data: {json_mod.dumps(initial)}\n\n"

            # Run generation in a background thread; tokens arrive via queue.
            q = queue.Queue()
            gen_thread = threading.Thread(
                target=generate,
                kwargs={
                    "prompt": prompt_text,
                    "max_tokens": max_tokens,
                    "temperature": temperature,
                    "top_k": top_k_val,
                    "top_p": top_p_val,
                    "token_queue": q,
                },
                daemon=True,
            )
            gen_thread.start()

            token_count = 0
            while True:
                try:
                    # Timeout guards against a generator thread that died
                    # without posting the end-of-stream sentinel.
                    tok = q.get(timeout=120.0)
                except queue.Empty:
                    break
                if tok is _SENTINEL:
                    break
                token_count += 1
                chunk = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": MODEL_ID,
                    "choices": [{"index": 0, "delta": {"content": tok}, "finish_reason": None}],
                }
                yield f"data: {json_mod.dumps(chunk)}\n\n"
            gen_thread.join(timeout=5.0)

            # Terminal chunk carries finish_reason plus usage accounting.
            finish = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "length" if token_count >= max_tokens else "stop"}],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": token_count,
                    "total_tokens": prompt_tokens + token_count,
                },
            }
            yield f"data: {json_mod.dumps(finish)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(sse_stream(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
    else:
        # NOTE(review): synchronous generation blocks the event loop for the
        # duration of the request — acceptable for a small demo model.
        text = generate(prompt_text, max_tokens=max_tokens, temperature=temperature,
                        top_k=top_k_val, top_p=top_p_val)
        completion_tokens = len(tokenizer.encode(text))
        return {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": MODEL_ID,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "length",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            "system_fingerprint": "symbioslm-grammar-expert-v1",
        }
if __name__ == "__main__":
    # Direct-run entry point (production deployments may use an ASGI runner).
    print(f"\nSymbioSLM-GrammarExpert server starting on 0.0.0.0:{PORT} ...")
    print(f" GET http://localhost:{PORT}/")
    print(f" GET http://localhost:{PORT}/v1/models")
    print(f" POST http://localhost:{PORT}/v1/chat/completions")
    uvicorn.run(app, host="0.0.0.0", port=PORT)