LisaMegaWatts's picture
Fix: use merged_model_state.pt (LoRA baked into base weights, fixes missing FFN weights)
13790f5 verified
"""SymbioSLM-GrammarExpert — OpenAI-compatible inference server.
SymbioSLM (5.1M parameters) with progressive-unfreeze LoRA training on the CoLA grammar task.
Uses LoRA-merged weights (no LoRA wrapper needed at inference).
True token-by-token SSE streaming via background thread + queue.
"""
import json as json_mod
import math
import os
import queue
import threading
import time
import uuid
import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download
from symbio_model import SymbioConfig, SymbioGPT
from tokenizer import BPETokenizer
# ═══════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════
# Every deployment setting can be overridden through an environment variable;
# the defaults below reproduce the original hard-coded values.
# Repo holding the tokenizer files (vocab.json / merges.txt).
BASE_REPO = os.environ.get("BASE_REPO", "LisaMegaWatts/SymbioSLM")
# Repo holding the fine-tuned (LoRA-merged) weights.
MODEL_REPO = os.environ.get("MODEL_REPO", "LisaMegaWatts/SymbioSLM-GrammarExpert-Progressive-20260301")
# TCP port the server listens on.
PORT = int(os.environ.get("PORT", "7860"))
# Weights file inside MODEL_REPO; overridable like the repos so another
# checkpoint in the same repo can be served without editing code.
MODEL_FILE = os.environ.get("MODEL_FILE", "merged_model_state.pt")
# SymbioSLM config β€” 3 organelles (no attention), matches pretrained JLD2 checkpoint
MODEL_CONFIG = SymbioConfig(
d_model=256,
n_layers=8,
n_heads=4,
head_dim=64,
ffn_mult=4,
context_length=256,
vocab_size=2000,
weight_tying=True,
organelles=("causal_conv", "monarch", "long_conv"),
conv_kernel_size=4,
n_monarch_heads=1,
gate_temperature_init=1.0,
free_energy_beta=0.001,
)
# ═══════════════════════════════════════════════════════════════════
# Load model and tokenizer
# ═══════════════════════════════════════════════════════════════════
# Fetch the tokenizer files from the base repo and the merged weights from the
# fine-tuned repo; the returned local paths are used below.
print(f"Downloading tokenizer from {BASE_REPO} ...")
vocab_path = hf_hub_download(repo_id=BASE_REPO, filename="vocab.json")
merges_path = hf_hub_download(repo_id=BASE_REPO, filename="merges.txt")
print(f"Downloading trained model from {MODEL_REPO} ...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
print("Loading tokenizer ...")
tokenizer = BPETokenizer.from_files(vocab_path, merges_path)
print(f" BPE vocab_size = {tokenizer.vocab_size}")
# Size the model's vocabulary to the tokenizer (mutates the shared MODEL_CONFIG).
# NOTE(review): presumably the strict load below still fails if the checkpoint
# was trained with a different vocabulary size — confirm.
if tokenizer.vocab_size != MODEL_CONFIG.vocab_size:
    print(f" Adjusting model vocab_size: {MODEL_CONFIG.vocab_size} -> {tokenizer.vocab_size}")
    MODEL_CONFIG.vocab_size = tokenizer.vocab_size
print("Loading model ...")
model = SymbioGPT(MODEL_CONFIG)
# Load LoRA-merged state (LoRA A@B baked into base weights, no wrapper needed)
# weights_only=True limits what torch.load will unpickle; strict=True turns any
# missing or unexpected key into an error instead of leaving parameters at
# their freshly-initialized values.
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict, strict=True)
print(" Loaded merged weights (LoRA baked in)")
# Switch to inference mode (matters for layers such as dropout, if present).
model.eval()
# Module.parameters() de-duplicates shared tensors, so tied weights count once.
n_params = sum(p.numel() for p in model.parameters())
print(f" Model ready: {n_params/1e6:.1f}M params")
print(f" Config: d={MODEL_CONFIG.d_model}, L={MODEL_CONFIG.n_layers}, "
      f"monarch_heads={MODEL_CONFIG.n_monarch_heads}")
print(f" Organelles: {MODEL_CONFIG.organelles}")
# ═══════════════════════════════════════════════════════════════════
# Generation
# ═══════════════════════════════════════════════════════════════════
# End-of-stream marker put on a streaming queue; consumers compare it by identity.
_SENTINEL = object()


@torch.no_grad()
def generate(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
    token_queue: queue.Queue = None,
) -> str:
    """Sample a continuation of ``prompt`` from the module-level model.

    Exactly ``max_tokens`` autoregressive steps are run (there is no early stop
    on an end-of-sequence token). Each step conditions on at most the last
    ``MODEL_CONFIG.context_length`` token ids.

    Args:
        prompt: Text to continue, encoded with the module-level tokenizer. An
            empty encoding falls back to the single token id 0.
        max_tokens: Number of tokens to sample.
        temperature: Softmax temperature; values at or below 0.01 are treated
            as 0.01 (near-greedy).
        top_k: When ``0 < top_k < vocab size``, logits below the k-th largest
            are masked out; otherwise no top-k filtering is applied.
        top_p: When ``< 1.0``, nucleus filtering: a token is dropped once the
            probability mass of all higher-ranked tokens exceeds ``top_p``.
        token_queue: Optional queue for streaming. Each sampled token is put on
            it as decoded text, and ``_SENTINEL`` is put once generation ends,
            including when it ends with an exception. A consumer blocked on
            ``queue.get()`` is therefore never left waiting forever.

    Returns:
        The decoded text of all sampled tokens.
    """
    try:
        tokens = tokenizer.encode(prompt) or [0]
        idx = torch.tensor([tokens], dtype=torch.long)
        generated_ids = []
        for _ in range(max_tokens):
            # Condition only on the most recent context_length tokens.
            idx_cond = idx[:, -MODEL_CONFIG.context_length:]
            logits_last = model(idx_cond)[0, -1, :].float()

            # Floor the temperature at 0.01 to avoid dividing by (nearly) zero.
            logits_last = logits_last / (temperature if temperature > 0.01 else 0.01)

            if 0 < top_k < logits_last.size(0):
                kth_largest = torch.topk(logits_last, top_k).values[-1]
                logits_last[logits_last < kth_largest] = float("-inf")

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits_last, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                # Mass strictly before each ranked token; the top-ranked token has
                # zero mass before it, so it always survives for top_p >= 0.
                mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
                logits_last[sorted_indices[mass_before > top_p]] = float("-inf")

            probs = F.softmax(logits_last, dim=-1)
            next_id = torch.multinomial(probs, 1).item()
            generated_ids.append(next_id)
            idx = torch.cat([idx, torch.tensor([[next_id]], dtype=torch.long)], dim=1)
            if token_queue is not None:
                token_queue.put(tokenizer.decode([next_id]))
        return tokenizer.decode(generated_ids)
    finally:
        # Signal end-of-stream unconditionally so streaming consumers never hang.
        if token_queue is not None:
            token_queue.put(_SENTINEL)
# ═══════════════════════════════════════════════════════════════════
# FastAPI server
# ═══════════════════════════════════════════════════════════════════
# Identity reported by /v1/models and embedded in every completion response.
MODEL_ID = "symbioslm-grammar-expert"
# Captured once at import time so /v1/models reports a stable "created" value.
MODEL_CREATED_AT = int(time.time())

app = FastAPI()
# Wildcard CORS so browser clients on any origin can call the API;
# allow_credentials is left at its default.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
def _content_to_text(content):
    """Normalize one chat-message ``content`` value to plain text.

    ``None`` becomes ``""``; a string is returned unchanged; a list of
    OpenAI-style content parts keeps only its text parts (bare strings or
    ``{"type": "text", "text": ...}`` dicts), joined with newlines; any other
    value is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(str(part.get("text") or ""))
        return "\n".join(texts)
    return str(content)


def extract_prompt(messages):
    """Return the prompt text for a chat request.

    The content of the most recent ``"user"`` message is used. If there is no
    user message, the content of the last message is used instead. Entries that
    are not dicts are ignored.

    Args:
        messages: Sequence of chat message dicts with ``role``/``content`` keys.
            ``None`` or an empty sequence yields ``""``.

    Returns:
        The prompt as a ``str``. A missing or ``None`` content yields ``""``, and
        a list of content parts is reduced to its text parts.
    """
    dict_messages = [m for m in (messages or []) if isinstance(m, dict)]
    if not dict_messages:
        return ""
    chosen = next(
        (m for m in reversed(dict_messages) if m.get("role") == "user"),
        dict_messages[-1],
    )
    return _content_to_text(chosen.get("content"))
@app.get("/")
def health():
    """Describe the service: identity, architecture, model shape, and endpoints."""
    cfg = MODEL_CONFIG
    model_info = dict(
        d_model=cfg.d_model,
        n_layers=cfg.n_layers,
        n_monarch_heads=cfg.n_monarch_heads,
        context_length=cfg.context_length,
        vocab_size=cfg.vocab_size,
        params=f"{n_params/1e6:.1f}M",
    )
    architecture = (
        "3-organelle decoder (CausalConv + Monarch + LongConv) "
        "+ OrganelleGate, progressive unfreeze LoRA (rank=30, w1+v+w2, merged)"
    )
    # Key order is preserved so the JSON layout matches earlier responses.
    return {
        "name": "SymbioSLM-GrammarExpert",
        "version": "1.1.0",
        "description": "SymbioSLM (5.1M) with progressive-unfreeze LoRA on CoLA grammar",
        "architecture": architecture,
        "model": model_info,
        "organelles": list(cfg.organelles),
        "endpoints": ["/v1/models", "/v1/chat/completions"],
        "features": ["streaming", "OpenAI-compatible", "top-k", "top-p", "grammar-expert"],
    }
@app.get("/v1/models")
def list_models():
    """List the single served model in OpenAI ``/v1/models`` format."""
    model_card = {
        "id": MODEL_ID,
        "object": "model",
        "created": MODEL_CREATED_AT,
        "owned_by": "symbioslm",
    }
    return {"object": "list", "data": [model_card]}
def _bad_request(message):
    """Return an OpenAI-style HTTP 400 ``invalid_request_error`` response."""
    return JSONResponse(status_code=400, content={
        "error": {"message": message, "type": "invalid_request_error"}
    })


def _read_number(body, key, default, lo, hi, integer=False):
    """Read ``body[key]`` as a finite number clamped to ``[lo, hi]``.

    A missing key or a JSON ``null`` falls back to ``default``. Anything
    ``float()`` accepts is allowed (ints, floats, numeric strings). With
    ``integer=True`` the clamped value is truncated to ``int``, so downstream
    ``range()``/``torch.topk`` calls always receive an integer.

    Raises:
        ValueError: if the value cannot be converted to a finite float.
    """
    raw = body.get(key)
    if raw is None:
        raw = default
    try:
        value = float(raw)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(f"'{key}' must be a number") from err
    if not math.isfinite(value):
        raise ValueError(f"'{key}' must be a finite number")
    value = max(lo, min(hi, value))
    return int(value) if integer else value


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completion endpoint.

    JSON body fields read:
    - ``messages``: only the text chosen by ``extract_prompt`` is sent to the model.
    - ``temperature``: clamped to [0.01, 2.0], default 0.8.
    - ``max_tokens``: clamped to [1, context_length], default 200.
    - ``top_k``: clamped to [0, vocab_size], default 40.
    - ``top_p``: clamped to [0, 1], default 1.0.
    - ``stream``: selects the response form.

    When ``stream`` is truthy, the response is Server-Sent Events:
    1. an initial role chunk;
    2. one chunk per sampled token;
    3. a final chunk with ``finish_reason`` and ``usage``;
    4. ``data: [DONE]``.

    Otherwise a single ``chat.completion`` object is returned.

    The following yield HTTP 400 ``invalid_request_error``:
    - malformed JSON or a non-object body;
    - non-numeric or non-finite sampling parameters;
    - a malformed ``messages`` field.
    """
    # Local import keeps this block self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError subclass ValueError
        return _bad_request("Invalid JSON")
    if not isinstance(body, dict):
        return _bad_request("Request body must be a JSON object")

    try:
        temperature = _read_number(body, "temperature", 0.8, 0.01, 2.0)
        max_tokens = _read_number(
            body, "max_tokens", 200, 1, MODEL_CONFIG.context_length, integer=True)
        top_k_val = _read_number(
            body, "top_k", 40, 0, MODEL_CONFIG.vocab_size, integer=True)
        top_p_val = _read_number(body, "top_p", 1.0, 0.0, 1.0)
    except ValueError as err:
        return _bad_request(str(err))
    stream = body.get("stream", False)

    messages = body.get("messages") or []
    if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
        return _bad_request("'messages' must be a list of message objects")
    prompt_text = extract_prompt(messages)
    if prompt_text is None:
        prompt_text = ""
    if not isinstance(prompt_text, str):
        return _bad_request("Message content must be a string")

    prompt_tokens = len(tokenizer.encode(prompt_text)) if prompt_text else 0
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())

    if stream:
        def sse_stream():
            initial = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
            }
            yield f"data: {json_mod.dumps(initial)}\n\n"
            q = queue.Queue()

            def produce():
                try:
                    generate(
                        prompt=prompt_text,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_k=top_k_val,
                        top_p=top_p_val,
                        token_queue=q,
                    )
                finally:
                    # Always unblock the consumer loop below, even if generation
                    # raises before queuing its own end marker. Any extra
                    # sentinel left on the queue is simply never read.
                    q.put(_SENTINEL)

            gen_thread = threading.Thread(target=produce, daemon=True)
            gen_thread.start()
            token_count = 0
            while True:
                tok = q.get()
                if tok is _SENTINEL:
                    break
                token_count += 1
                chunk = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": MODEL_ID,
                    "choices": [{"index": 0, "delta": {"content": tok}, "finish_reason": None}],
                }
                yield f"data: {json_mod.dumps(chunk)}\n\n"
            gen_thread.join(timeout=5.0)
            finish = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "length" if token_count >= max_tokens else "stop"}],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": token_count,
                    "total_tokens": prompt_tokens + token_count,
                },
            }
            yield f"data: {json_mod.dumps(finish)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(sse_stream(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

    # Generation is CPU-bound; run it in the threadpool so the event loop keeps
    # serving other requests meanwhile.
    text = await run_in_threadpool(
        generate,
        prompt_text,
        max_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k_val,
        top_p=top_p_val,
    )
    # NOTE: re-encoding the decoded text only approximates the number of sampled
    # tokens, since BPE re-tokenization may merge tokens differently.
    completion_tokens = len(tokenizer.encode(text))
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": MODEL_ID,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": "length",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
        "system_fingerprint": "symbioslm-grammar-expert-v1",
    }
if __name__ == "__main__":
    # Print the banner and the reachable endpoints, then serve until interrupted.
    print(f"\nSymbioSLM-GrammarExpert server starting on 0.0.0.0:{PORT} ...")
    routes = (("GET", "/"), ("GET", "/v1/models"), ("POST", "/v1/chat/completions"))
    for verb, path in routes:
        print(f" {verb} http://localhost:{PORT}{path}")
    uvicorn.run(app, host="0.0.0.0", port=PORT)