LisaMegaWatts's picture
True token-by-token SSE streaming via thread + queue
21fcfa2 verified
"""SymbioGPT-GrammarExpert — OpenAI-compatible inference server.
SymbioGPT-10M base model with Grammar Expert LoRA adapter merged at startup.
The LoRA was discovered via evolutionary search on CoLA (grammar acceptability).
Downloads base checkpoint + LoRA weights from HuggingFace on first run.
True token-by-token SSE streaming via background thread + queue.
"""
import json as json_mod
import math
import os
import queue
import threading
import time
import uuid

import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download

from symbio_model import SymbioConfig, SymbioGPT
from tokenizer import BPETokenizer
# ═══════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════
# Hub repositories and HTTP port; each may be overridden via environment variables.
BASE_REPO = os.getenv("BASE_REPO", "LisaMegaWatts/SymbioGPT-10M")
LORA_REPO = os.getenv("LORA_REPO", "LisaMegaWatts/SymbioGPT-GrammarExpert-20260301")
PORT = int(os.getenv("PORT", "7860"))

# Artifact names inside those repositories.
CHECKPOINT_FILE = "symbio_best.pt"  # base-model checkpoint, fetched from BASE_REPO
LORA_FILE = "lora_weights.pt"  # adapter factors, fetched from LORA_REPO

# LoRA hyper-parameters (copied from the adapter's metadata.json). The merge
# scales each low-rank update by LORA_ALPHA / LORA_RANK.
LORA_RANK = 8
LORA_ALPHA = 8.0

# Base-model architecture. These values must agree with the checkpoint, which
# is loaded with load_state_dict in its default (strict) mode.
MODEL_CONFIG = SymbioConfig(
    # Core dimensions.
    d_model=320,
    n_layers=8,
    n_heads=5,
    head_dim=64,
    ffn_mult=4,
    context_length=256,
    # Replaced at load time if the tokenizer reports a different vocabulary size.
    vocab_size=2000,
    weight_tying=True,
    # Organelle selection and per-organelle settings.
    organelles=("causal_conv", "monarch", "long_conv", "attention"),
    conv_kernel_size=4,
    n_monarch_heads=1,
    gate_temperature_init=1.0,
    free_energy_beta=0.001,
)
# ═══════════════════════════════════════════════════════════════════
# LoRA merging
# ═══════════════════════════════════════════════════════════════════
# Map the module segment of a LoRA key ("attn", "ffn") to the module path used
# in the base model's state_dict. Segments not listed here are used verbatim.
LORA_KEY_MAP = {
    # block.{i}.attn.{proj} -> blocks.{i}.seq_mixer.organelle_modules.attention.{proj}
    "attn": "seq_mixer.organelle_modules.attention",
    # block.{i}.ffn.{proj} -> blocks.{i}.ffn.{proj}
    "ffn": "ffn",
}


def merge_lora(model, lora_state, alpha, rank):
    """Merge LoRA adapter weights into ``model`` in place.

    Each adapted weight is updated as::

        W_merged = W_base + (B^T @ A^T) * (alpha / rank)

    where, as stored in ``lora_state``, ``A`` has shape (in_features, rank) and
    ``B`` has shape (rank, out_features), so the update has the base weight's
    (out_features, in_features) shape.

    Args:
        model: Module whose ``state_dict()`` contains the target weights under
            ``blocks.{i}.<module path>.<proj>.weight``.
        lora_state: Mapping from ``block.{i}.<module>.<proj>.lora_A`` /
            ``.lora_B`` keys to the factor tensors.
        alpha: LoRA alpha.
        rank: LoRA rank; the update is scaled by ``alpha / rank``.

    Returns:
        The number of (A, B) pairs merged. A pair is skipped with a printed
        warning when it is incomplete, its key has an unexpected format, the
        target weight does not exist, or the factor/target shapes do not fit.
    """
    base_state = model.state_dict()
    scaling = alpha / rank

    # Group the A and B factors of each adapted weight under their shared prefix.
    lora_pairs = {}
    for key, tensor in lora_state.items():
        if key.endswith(".lora_A"):
            lora_pairs.setdefault(key[: -len(".lora_A")], {})["A"] = tensor
        elif key.endswith(".lora_B"):
            lora_pairs.setdefault(key[: -len(".lora_B")], {})["B"] = tensor

    merged_count = 0
    for lora_key, pair in lora_pairs.items():
        if "A" not in pair or "B" not in pair:
            print(f" WARNING: incomplete LoRA pair for {lora_key}")
            continue

        # lora_key format: "block.{i}.{module}.{proj...}"
        # base format:     "blocks.{i}.{mapped module path}.{proj...}.weight"
        parts = lora_key.split(".")
        if len(parts) < 4 or parts[0] != "block":
            print(f" WARNING: unexpected LoRA key format: {lora_key}")
            continue
        layer_idx, module = parts[1], parts[2]
        # Keep every remaining segment so nested projection paths are not truncated.
        proj = ".".join(parts[3:])
        mapped_module = LORA_KEY_MAP.get(module, module)
        base_weight_key = f"blocks.{layer_idx}.{mapped_module}.{proj}.weight"
        if base_weight_key not in base_state:
            print(f" WARNING: base key not found: {base_weight_key}")
            continue

        A = pair["A"].float()  # (in_features, rank)
        B = pair["B"].float()  # (rank, out_features)
        if A.dim() != 2 or B.dim() != 2 or A.shape[1] != B.shape[0]:
            print(f" WARNING: incompatible LoRA factor shapes for {lora_key}: "
                  f"A {tuple(A.shape)}, B {tuple(B.shape)}")
            continue

        base_weight = base_state[base_weight_key]
        # delta_W = B^T @ A^T = (out, rank) @ (rank, in) = (out, in)
        delta_W = B.T @ A.T
        if delta_W.shape != base_weight.shape:
            # Skip instead of letting the addition fail (or broadcast) mid-startup.
            print(f" WARNING: LoRA update shape {tuple(delta_W.shape)} does not match "
                  f"{base_weight_key} shape {tuple(base_weight.shape)}")
            continue

        base_state[base_weight_key] = (
            base_weight.float() + delta_W * scaling
        ).to(base_weight.dtype)
        merged_count += 1

    model.load_state_dict(base_state)
    return merged_count
# ═══════════════════════════════════════════════════════════════════
# Load model and tokenizer
# ═══════════════════════════════════════════════════════════════════
# Fetch every artifact up front: checkpoint and tokenizer files from the base
# repo, then the adapter weights from the LoRA repo.
print(f"Downloading base model from {BASE_REPO} ...")
ckpt_path, vocab_path, merges_path = (
    hf_hub_download(repo_id=BASE_REPO, filename=fname)
    for fname in (CHECKPOINT_FILE, "vocab.json", "merges.txt")
)
print(f"Downloading LoRA from {LORA_REPO} ...")
lora_path = hf_hub_download(repo_id=LORA_REPO, filename=LORA_FILE)

print("Loading tokenizer ...")
tokenizer = BPETokenizer.from_files(vocab_path, merges_path)
print(f" BPE vocab_size = {tokenizer.vocab_size}")
if MODEL_CONFIG.vocab_size != tokenizer.vocab_size:
    # Keep the model config in step with the tokenizer actually loaded; this
    # must happen before the model is constructed below.
    print(f" Adjusting model vocab_size: {MODEL_CONFIG.vocab_size} -> {tokenizer.vocab_size}")
    MODEL_CONFIG.vocab_size = tokenizer.vocab_size

print("Loading base model ...")
model = SymbioGPT(MODEL_CONFIG)
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
# The weights may be nested under a wrapper key; otherwise the file itself is
# treated as the state dict.
state_dict = next(
    (checkpoint[wrapper] for wrapper in ("model_state_dict", "state_dict") if wrapper in checkpoint),
    checkpoint,
)
# Remove any "_orig_mod." segment from parameter names (presumably left by
# saving a torch.compile-wrapped model).
state_dict = {name.replace("_orig_mod.", ""): tensor for name, tensor in state_dict.items()}
model.load_state_dict(state_dict)

print("Merging LoRA weights ...")
lora_state = torch.load(lora_path, map_location="cpu", weights_only=True)
n_merged = merge_lora(model, lora_state, LORA_ALPHA, LORA_RANK)
print(f" Merged {n_merged} LoRA weight pairs (rank={LORA_RANK}, alpha={LORA_ALPHA})")

model.eval()
n_params = sum(param.numel() for param in model.parameters())
print(f" Model ready: {n_params/1e6:.1f}M params (base + LoRA merged)")
# ═══════════════════════════════════════════════════════════════════
# Generation
# ═══════════════════════════════════════════════════════════════════
_SENTINEL = object()  # pushed onto token_queue to mark the end of a generation


@torch.no_grad()
def generate(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
    token_queue: queue.Queue = None,
) -> str:
    """Sample a continuation of ``prompt`` from the merged model.

    Exactly ``max_tokens`` tokens are sampled; there is no early stop. Each
    step runs the model on the last ``MODEL_CONFIG.context_length`` tokens and
    divides the final-position logits by ``temperature``, floored at 0.01. It
    then applies top-k filtering when ``0 < top_k < vocab size`` and nucleus
    (top-p) filtering when ``top_p < 1.0``, and samples with
    ``torch.multinomial``.

    Args:
        prompt: Text to continue; if it encodes to no tokens, token id 0 is used.
        max_tokens: Number of tokens to sample.
        temperature: Sampling temperature; values below 0.01 act as 0.01.
        top_k: Number of highest logits to keep (ties at the cutoff are kept;
            0 disables the filter).
        top_p: Nucleus cutoff on cumulative probability (1.0 disables it).
        token_queue: If given, the decoded text of each sampled token is put
            on it as soon as it is sampled, followed by ``_SENTINEL``. The
            sentinel is put even when an exception aborts generation, so a
            consumer blocked on ``token_queue.get()`` is always released.

    Returns:
        The decoded text of the sampled tokens (the prompt is not included).
    """
    generated_ids = []
    try:
        tokens = tokenizer.encode(prompt) or [0]
        context_length = MODEL_CONFIG.context_length
        # Only the last `context_length` tokens are ever fed to the model, so a
        # longer prompt is cropped once instead of being carried every step.
        idx = torch.tensor([tokens], dtype=torch.long)[:, -context_length:]
        for _ in range(max_tokens):
            logits = model(idx[:, -context_length:])
            logits_last = logits[0, -1, :].float() / max(0.01, temperature)

            if 0 < top_k < logits_last.size(0):
                # Mask everything strictly below the k-th largest logit.
                threshold = torch.topk(logits_last, top_k).values[-1]
                logits_last[logits_last < threshold] = float("-inf")

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits_last, descending=True)
                probs_sorted = F.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(probs_sorted, dim=-1)
                # Drop a token once the mass *before* it exceeds top_p; for
                # top_p >= 0 the most likely token always survives.
                sorted_logits[cumprobs - probs_sorted > top_p] = float("-inf")
                # Undo the sort: position sorted_indices[i] gets sorted_logits[i].
                logits_last = sorted_logits.scatter(0, sorted_indices, sorted_logits)

            probs = F.softmax(logits_last, dim=-1)
            next_id = torch.multinomial(probs, 1).item()
            generated_ids.append(next_id)
            idx = torch.cat([idx, torch.tensor([[next_id]])], dim=1)
            if token_queue is not None:
                token_queue.put(tokenizer.decode([next_id]))
    finally:
        # Always terminate the stream, even if encoding or sampling raised;
        # otherwise a streaming consumer waiting on the queue would hang forever.
        if token_queue is not None:
            token_queue.put(_SENTINEL)
    return tokenizer.decode(generated_ids)
# ═══════════════════════════════════════════════════════════════════
# FastAPI server
# ═══════════════════════════════════════════════════════════════════
app = FastAPI()

# Fully permissive CORS: any origin, method and header may call the API.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Identifier reported to OpenAI-compatible clients.
MODEL_ID = "symbiogpt-grammar-expert"
# Reported as the model's "created" timestamp: the time this module was loaded.
MODEL_CREATED_AT = int(time.time())
def _content_to_text(content):
    """Normalise an OpenAI chat ``content`` value to a plain string.

    A string is returned unchanged. A list of content parts is reduced to its
    text (plain strings and ``{"type": "text", "text": ...}`` parts, joined
    with newlines). Anything else, including ``None``, becomes "".
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif (
                isinstance(part, dict)
                and part.get("type") == "text"
                and isinstance(part.get("text"), str)
            ):
                texts.append(part["text"])
        return "\n".join(texts)
    return ""


def extract_prompt(messages):
    """Pick the prompt text for a chat request.

    Uses the content of the most recent message whose ``role`` is "user". If
    no message has that role, it falls back to the content of the last
    message. Entries that are not dicts are ignored. Content is always
    normalised to a string, so ``null`` content and OpenAI content-part lists
    are handled instead of leaking non-string values to the tokenizer.

    Args:
        messages: The ``messages`` value from the request body.

    Returns:
        The prompt string; "" when there is nothing usable.
    """
    if not messages:
        return ""
    candidates = [msg for msg in messages if isinstance(msg, dict)]
    if not candidates:
        return ""
    for msg in reversed(candidates):
        if msg.get("role") == "user":
            return _content_to_text(msg.get("content"))
    return _content_to_text(candidates[-1].get("content"))
@app.get("/")
def health():
    """Describe the running service: name, version, architecture and config.

    ``params`` is the parameter count of the loaded model (after the LoRA
    merge), formatted in millions.
    """
    return {
        "name": "SymbioGPT-GrammarExpert",
        "version": "1.1.0",
        "description": "SymbioGPT-10M + Grammar Expert LoRA (evolved on CoLA)",
        # The rank is taken from LORA_RANK so this text cannot drift from the
        # value actually used for the merge.
        "architecture": "4-organelle decoder (CausalConv + Monarch + LongConv + Attention) "
                        f"+ OrganelleGate + LoRA (rank={LORA_RANK}, attn+ffn)",
        "model": {
            "d_model": MODEL_CONFIG.d_model,
            "n_layers": MODEL_CONFIG.n_layers,
            "n_heads": MODEL_CONFIG.n_heads,
            "context_length": MODEL_CONFIG.context_length,
            "vocab_size": MODEL_CONFIG.vocab_size,
            "params": f"{n_params/1e6:.1f}M",
            "lora_rank": LORA_RANK,
        },
        "organelles": list(MODEL_CONFIG.organelles),
        "endpoints": ["/v1/models", "/v1/chat/completions"],
        "features": ["streaming", "OpenAI-compatible", "top-k", "top-p", "grammar-expert-lora"],
    }
@app.get("/v1/models")
def list_models():
    """List the single served model in the OpenAI ``/v1/models`` format."""
    model_card = {
        "id": MODEL_ID,
        "object": "model",
        "created": MODEL_CREATED_AT,
        "owned_by": "symbiogpt",
    }
    return {"object": "list", "data": [model_card]}
def _invalid_request(message):
    """Return an OpenAI-style 400 ``invalid_request_error`` JSON response."""
    return JSONResponse(status_code=400, content={
        "error": {"message": message, "type": "invalid_request_error"}
    })


def _clamped_param(body, key, default, cast, lo, hi):
    """Read ``body[key]``, convert it with ``cast`` and clamp it into ``[lo, hi]``.

    A missing key or an explicit JSON ``null`` means ``default``.

    Raises:
        ValueError: if the value cannot be converted by ``cast`` or is not finite.
    """
    raw = body.get(key)
    if raw is None:
        raw = default
    try:
        value = cast(raw)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(f"'{key}' must be a number") from exc
    if isinstance(value, float) and not math.isfinite(value):
        raise ValueError(f"'{key}' must be a finite number")
    return max(lo, min(hi, value))


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Handle an OpenAI-compatible chat completion request.

    The prompt is the text selected by ``extract_prompt`` from ``messages``.
    ``temperature``, ``max_tokens``, ``top_k`` and ``top_p`` are converted and
    clamped to the ranges this model supports. With ``"stream": true`` the
    reply is sent as server-sent events, one chunk per sampled token;
    otherwise a single ``chat.completion`` object is returned.

    Malformed requests get a 400 ``invalid_request_error``. This covers
    invalid JSON, a non-object body, non-numeric sampling parameters, a
    non-list ``messages``, and non-string prompt content.
    """
    try:
        body = await request.json()
    except Exception:
        return _invalid_request("Invalid JSON")
    if not isinstance(body, dict):
        return _invalid_request("Request body must be a JSON object")

    try:
        temperature = _clamped_param(body, "temperature", 0.8, float, 0.01, 2.0)
        max_tokens = _clamped_param(body, "max_tokens", 200, int, 1, MODEL_CONFIG.context_length)
        top_k_val = _clamped_param(body, "top_k", 40, int, 0, MODEL_CONFIG.vocab_size)
        top_p_val = _clamped_param(body, "top_p", 1.0, float, 0.0, 1.0)
    except ValueError as exc:
        return _invalid_request(str(exc))

    stream = body.get("stream", False)
    messages = body.get("messages")
    if messages is None:
        messages = []
    if not isinstance(messages, list):
        return _invalid_request("'messages' must be a list")
    prompt_text = extract_prompt(messages)
    if prompt_text is None:
        prompt_text = ""
    if not isinstance(prompt_text, str):
        return _invalid_request("Message content must be a string")
    prompt_tokens = len(tokenizer.encode(prompt_text)) if prompt_text else 0
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())

    def sse_event(choice, **extra):
        """Format one ``chat.completion.chunk`` as an SSE ``data:`` event."""
        chunk = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": MODEL_ID,
            "choices": [choice],
            **extra,
        }
        return f"data: {json_mod.dumps(chunk)}\n\n"

    if stream:
        def sse_stream():
            # Initial chunk announces the assistant role.
            yield sse_event({"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None})

            # Generate in a background thread that pushes tokens onto `q`.
            q = queue.Queue()
            gen_thread = threading.Thread(
                target=generate,
                kwargs={
                    "prompt": prompt_text,
                    "max_tokens": max_tokens,
                    "temperature": temperature,
                    "top_k": top_k_val,
                    "top_p": top_p_val,
                    "token_queue": q,
                },
                daemon=True,
            )
            gen_thread.start()

            token_count = 0
            while True:
                try:
                    tok = q.get(timeout=0.5)
                except queue.Empty:
                    if gen_thread.is_alive():
                        continue
                    # The producer has exited: take anything it queued before
                    # exiting; if nothing is left it died without a sentinel,
                    # so stop instead of waiting forever.
                    try:
                        tok = q.get_nowait()
                    except queue.Empty:
                        break
                if tok is _SENTINEL:
                    break
                token_count += 1
                yield sse_event({"index": 0, "delta": {"content": tok}, "finish_reason": None})
            gen_thread.join(timeout=5.0)

            finish_reason = "length" if token_count >= max_tokens else "stop"
            yield sse_event(
                {"index": 0, "delta": {}, "finish_reason": finish_reason},
                usage={
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": token_count,
                    "total_tokens": prompt_tokens + token_count,
                },
            )
            yield "data: [DONE]\n\n"

        return StreamingResponse(sse_stream(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

    # generate() is blocking, CPU-bound work; running it in the threadpool
    # keeps the event loop free to serve other requests meanwhile.
    text = await run_in_threadpool(
        generate,
        prompt_text,
        max_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k_val,
        top_p=top_p_val,
    )
    completion_tokens = len(tokenizer.encode(text))
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": MODEL_ID,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": "length",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
        "system_fingerprint": "symbiogpt-grammar-expert-v1",
    }
if __name__ == "__main__":
    # Startup banner: where the server listens and which routes it exposes.
    banner = (
        f"\nSymbioGPT-GrammarExpert server starting on 0.0.0.0:{PORT} ...",
        f" GET http://localhost:{PORT}/",
        f" GET http://localhost:{PORT}/v1/models",
        f" POST http://localhost:{PORT}/v1/chat/completions",
    )
    for line in banner:
        print(line)
    uvicorn.run(app, host="0.0.0.0", port=PORT)