| | """SymbioGPT-GrammarExpert — OpenAI-compatible inference server. |
| | |
| | SymbioGPT-10M base model with Grammar Expert LoRA adapter merged at startup. |
| | The LoRA was discovered via evolutionary search on CoLA (grammar acceptability). |
| | Downloads base checkpoint + LoRA weights from HuggingFace on first run. |
| | |
| | True token-by-token SSE streaming via background thread + queue. |
| | """ |
| | import json as json_mod |
| | import math |
| | import os |
| | import queue |
| | import threading |
| | import time |
| | import uuid |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import uvicorn |
| | from fastapi import FastAPI, Request |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from fastapi.responses import JSONResponse, StreamingResponse |
| | from huggingface_hub import hf_hub_download |
| |
|
| | from symbio_model import SymbioConfig, SymbioGPT |
| | from tokenizer import BPETokenizer |
| |
|
| | |
| | |
| | |
| |
|
# Hugging Face repositories; both can be overridden through the environment.
BASE_REPO = os.getenv("BASE_REPO", "LisaMegaWatts/SymbioGPT-10M")
LORA_REPO = os.getenv(
    "LORA_REPO",
    "LisaMegaWatts/SymbioGPT-GrammarExpert-20260301",
)

# TCP port the HTTP server listens on (environment value is parsed as int).
PORT = int(os.getenv("PORT", "7860"))

# File names fetched from the repositories above.
CHECKPOINT_FILE = "symbio_best.pt"
LORA_FILE = "lora_weights.pt"

# LoRA hyperparameters; the merge step scales the update by alpha / rank.
LORA_RANK = 8
LORA_ALPHA = 8.0
| |
|
# Architecture hyperparameters used to build the base network.
# `vocab_size` is only a default: the startup code overwrites it when the
# loaded tokenizer reports a different size.
MODEL_CONFIG = SymbioConfig(
    # Vocabulary and sequence limits.
    vocab_size=2000,
    context_length=256,
    # Transformer-style dimensions (n_heads * head_dim == d_model == 320).
    d_model=320,
    n_layers=8,
    n_heads=5,
    head_dim=64,
    ffn_mult=4,
    weight_tying=True,
    # Sequence-mixing organelles and their settings.
    organelles=("causal_conv", "monarch", "long_conv", "attention"),
    conv_kernel_size=4,
    n_monarch_heads=1,
    gate_temperature_init=1.0,
    free_energy_beta=0.001,
)
| |
|
| | |
| | |
| | |
| |
|
| | |
# Short module names used in LoRA keys -> attribute path inside a base block.
# Module names that are not listed here are used unchanged.
LORA_KEY_MAP = {
    "attn": "seq_mixer.organelle_modules.attention",
    "ffn": "ffn",
}


def merge_lora(model, lora_state, alpha, rank):
    """Fold LoRA factors into the weights of ``model`` in place.

    LoRA formula: W_merged = W_base + (B^T @ A^T) * (alpha / rank)
    where, as stored, A is (in_features, rank) and B is (rank, out_features).

    LoRA keys are expected to look like
    ``block.<layer>.<module>.<proj>.lora_A`` / ``.lora_B`` and are mapped to
    ``blocks.<layer>.<mapped module>.<proj>.weight`` in the base state dict.

    Pairs are skipped with a printed warning when they are incomplete, when
    the key has an unexpected format, when the base weight does not exist, or
    when the factor shapes do not produce a delta of exactly the base weight's
    shape. The last check matters because a mismatched delta could otherwise
    be broadcast and silently corrupt the weight.

    Args:
        model: module exposing ``state_dict()`` and ``load_state_dict()``.
        lora_state: mapping of LoRA key -> tensor.
        alpha: LoRA alpha (numerator of the scaling factor).
        rank: LoRA rank (denominator of the scaling factor).

    Returns:
        Number of LoRA pairs that were merged.
    """
    base_state = model.state_dict()
    scaling = alpha / rank
    merged_count = 0

    # Group the A and B factors by the key they adapt.
    lora_pairs = {}
    for key, tensor in lora_state.items():
        for suffix, slot in ((".lora_A", "A"), (".lora_B", "B")):
            if key.endswith(suffix):
                lora_pairs.setdefault(key[: -len(suffix)], {})[slot] = tensor
                break

    for lora_key, pair in lora_pairs.items():
        if "A" not in pair or "B" not in pair:
            print(f" WARNING: incomplete LoRA pair for {lora_key}")
            continue

        parts = lora_key.split(".")
        if len(parts) < 4 or parts[0] != "block":
            print(f" WARNING: unexpected LoRA key format: {lora_key}")
            continue

        layer_idx, module, proj = parts[1:4]
        mapped_module = LORA_KEY_MAP.get(module, module)
        base_weight_key = f"blocks.{layer_idx}.{mapped_module}.{proj}.weight"

        if base_weight_key not in base_state:
            print(f" WARNING: base key not found: {base_weight_key}")
            continue

        base_weight = base_state[base_weight_key]
        A = pair["A"].float()
        B = pair["B"].float()

        # B^T @ A^T has shape (B.shape[1], A.shape[0]); require an exact match
        # with the base weight so nothing is broadcast.
        shapes_ok = (
            A.dim() == 2
            and B.dim() == 2
            and A.shape[1] == B.shape[0]
            and (B.shape[1], A.shape[0]) == tuple(base_weight.shape)
        )
        if not shapes_ok:
            print(
                f" WARNING: shape mismatch for {lora_key}: "
                f"A={tuple(A.shape)}, B={tuple(B.shape)}, "
                f"base={tuple(base_weight.shape)}"
            )
            continue

        delta_W = B.T @ A.T
        # Accumulate in float32, then cast back to the stored dtype.
        base_state[base_weight_key] = (
            base_weight.float() + delta_W * scaling
        ).to(base_weight.dtype)
        merged_count += 1

    model.load_state_dict(base_state)
    return merged_count
| |
|
| |
|
| | |
| | |
| | |
| |
|
# ---------------------------------------------------------------------------
# Startup: fetch artifacts, build the tokenizer and model, merge the adapter.
# Everything below runs once, at import time.
# ---------------------------------------------------------------------------

print(f"Downloading base model from {BASE_REPO} ...")
ckpt_path, vocab_path, merges_path = (
    hf_hub_download(repo_id=BASE_REPO, filename=artifact)
    for artifact in (CHECKPOINT_FILE, "vocab.json", "merges.txt")
)

print(f"Downloading LoRA from {LORA_REPO} ...")
lora_path = hf_hub_download(repo_id=LORA_REPO, filename=LORA_FILE)

print("Loading tokenizer ...")
tokenizer = BPETokenizer.from_files(vocab_path, merges_path)
print(f" BPE vocab_size = {tokenizer.vocab_size}")

# The tokenizer is authoritative for the vocabulary size.
if MODEL_CONFIG.vocab_size != tokenizer.vocab_size:
    print(f" Adjusting model vocab_size: {MODEL_CONFIG.vocab_size} -> {tokenizer.vocab_size}")
    MODEL_CONFIG.vocab_size = tokenizer.vocab_size

print("Loading base model ...")
model = SymbioGPT(MODEL_CONFIG)

checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
# The weights may sit under a wrapper key or be the checkpoint itself.
state_dict = next(
    (
        checkpoint[wrapper]
        for wrapper in ("model_state_dict", "state_dict")
        if wrapper in checkpoint
    ),
    checkpoint,
)
# Drop the "_orig_mod." fragment from parameter names before loading.
state_dict = {
    name.replace("_orig_mod.", ""): tensor
    for name, tensor in state_dict.items()
}
model.load_state_dict(state_dict)

print("Merging LoRA weights ...")
lora_state = torch.load(lora_path, map_location="cpu", weights_only=True)
n_merged = merge_lora(model, lora_state, LORA_ALPHA, LORA_RANK)
print(f" Merged {n_merged} LoRA weight pairs (rank={LORA_RANK}, alpha={LORA_ALPHA})")

model.eval()
n_params = sum(param.numel() for param in model.parameters())
print(f" Model ready: {n_params/1e6:.1f}M params (base + LoRA merged)")
| |
|
| | |
| | |
| | |
| |
|
# Unique marker pushed onto a token queue to signal that generation is over.
_SENTINEL = object()


@torch.no_grad()
def generate(
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
    token_queue: queue.Queue = None,
) -> str:
    """Sample ``max_tokens`` tokens that continue ``prompt``.

    There is no end-of-sequence handling: exactly ``max_tokens`` tokens are
    sampled unless an exception interrupts the loop.

    Args:
        prompt: text to continue; an empty encoding is replaced by token 0.
        max_tokens: number of tokens to sample.
        temperature: softmax temperature; values below 0.01 are treated as
            0.01.
        top_k: keep only the ``top_k`` highest logits (disabled when it is 0
            or not smaller than the vocabulary size).
        top_p: nucleus threshold (disabled when it is 1.0 or larger).
        token_queue: optional queue for streaming. Each decoded token string
            is pushed as soon as it is sampled, followed by ``_SENTINEL``.
            The sentinel is pushed even when generation fails, so a consumer
            blocked on the queue is always released.

    Returns:
        The decoded text of all sampled tokens.
    """
    generated_ids = []
    context_length = MODEL_CONFIG.context_length
    # Guard against division by (nearly) zero.
    safe_temperature = temperature if temperature > 0.01 else 0.01

    try:
        tokens = tokenizer.encode(prompt)
        if not tokens:
            tokens = [0]
        idx = torch.tensor([tokens], dtype=torch.long)

        for _ in range(max_tokens):
            idx_cond = idx[:, -context_length:]
            logits = model(idx_cond)
            logits_last = logits[0, -1, :].float() / safe_temperature

            if 0 < top_k < logits_last.size(0):
                # Mask everything below the k-th largest logit.
                threshold = torch.topk(logits_last, top_k).values[-1]
                logits_last[logits_last < threshold] = float("-inf")

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits_last, descending=True)
                probs_sorted = F.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(probs_sorted, dim=-1)
                # Drop a token when the mass before it already exceeds top_p;
                # the most likely token is therefore always kept.
                cutoff_mask = cumprobs - probs_sorted > top_p
                sorted_logits[cutoff_mask] = float("-inf")
                # Undo the sort so positions correspond to token ids again.
                logits_last = torch.empty_like(sorted_logits).scatter_(
                    0, sorted_indices, sorted_logits
                )

            probs = F.softmax(logits_last, dim=-1)
            next_id = torch.multinomial(probs, 1).item()
            generated_ids.append(next_id)
            # Only the last context window is ever fed to the model, so keep
            # the running tensor bounded instead of letting it grow.
            idx = torch.cat([idx_cond, torch.tensor([[next_id]])], dim=1)

            if token_queue is not None:
                token_queue.put(tokenizer.decode([next_id]))
    finally:
        # Always release the consumer, including on failure.
        if token_queue is not None:
            token_queue.put(_SENTINEL)

    return tokenizer.decode(generated_ids)
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Identity reported by the OpenAI-compatible endpoints.
MODEL_ID = "symbiogpt-grammar-expert"
# Captured once at import time and reported as the model's creation time.
MODEL_CREATED_AT = int(time.time())

app = FastAPI()
# Permissive CORS: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
| |
|
| |
|
def _content_to_text(content):
    """Flatten a chat message ``content`` value into a plain string."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        # Array-of-parts form: keep the text parts, ignore everything else.
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                texts.append(part["text"])
        return "\n".join(texts)
    return str(content)


def extract_prompt(messages):
    """Return the text of the most recent user message.

    Falls back to the last message when no message has role ``"user"``.
    The result is always a string:

    * a missing or ``None`` content yields ``""``;
    * a list of content parts is reduced to its text parts joined by
      newlines;
    * entries that are not dicts are ignored;
    * a ``messages`` value that is not a list/tuple (or is empty) yields
      ``""``.

    Args:
        messages: list of chat message dicts with ``role`` / ``content``.

    Returns:
        The prompt text as ``str``.
    """
    if not isinstance(messages, (list, tuple)):
        return ""
    entries = [msg for msg in messages if isinstance(msg, dict)]
    if not entries:
        return ""
    chosen = next(
        (msg for msg in reversed(entries) if msg.get("role") == "user"),
        entries[-1],
    )
    return _content_to_text(chosen.get("content", ""))
| |
|
| |
|
@app.get("/")
def health():
    """Describe the service: identity, architecture, model sizes, endpoints."""
    cfg = MODEL_CONFIG
    model_info = {
        "d_model": cfg.d_model,
        "n_layers": cfg.n_layers,
        "n_heads": cfg.n_heads,
        "context_length": cfg.context_length,
        "vocab_size": cfg.vocab_size,
        "params": f"{n_params/1e6:.1f}M",
        "lora_rank": LORA_RANK,
    }
    architecture = (
        "4-organelle decoder (CausalConv + Monarch + LongConv + Attention) "
        "+ OrganelleGate + LoRA (rank=8, attn+ffn)"
    )
    return {
        "name": "SymbioGPT-GrammarExpert",
        "version": "1.1.0",
        "description": "SymbioGPT-10M + Grammar Expert LoRA (evolved on CoLA)",
        "architecture": architecture,
        "model": model_info,
        "organelles": list(cfg.organelles),
        "endpoints": ["/v1/models", "/v1/chat/completions"],
        "features": ["streaming", "OpenAI-compatible", "top-k", "top-p", "grammar-expert-lora"],
    }
| |
|
| |
|
@app.get("/v1/models")
def list_models():
    """Return the OpenAI-style model listing; this server exposes one model."""
    model_entry = {
        "id": MODEL_ID,
        "object": "model",
        "created": MODEL_CREATED_AT,
        "owned_by": "symbiogpt",
    }
    return {"object": "list", "data": [model_entry]}
| |
|
| |
|
def _clamped(value, default, low, high, cast=float):
    """Coerce ``value`` with ``cast`` and clamp it to ``[low, high]``.

    ``None``, non-numeric values, NaN and values that cannot be converted
    fall back to ``default`` instead of raising.
    """
    try:
        number = cast(value)
    except (TypeError, ValueError, OverflowError):
        number = cast(default)
    if math.isnan(number):
        number = cast(default)
    return max(low, min(high, number))


def _invalid_request(message):
    """Build an OpenAI-style 400 error response."""
    return JSONResponse(status_code=400, content={
        "error": {"message": message, "type": "invalid_request_error"}
    })


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completion endpoint.

    Reads ``messages``, ``temperature``, ``max_tokens``, ``top_k``, ``top_p``
    and ``stream`` from the JSON body. Sampling parameters that are missing,
    ``null`` or not numeric fall back to their defaults, and all of them are
    clamped to the supported range. ``max_tokens`` and ``top_k`` are coerced
    to integers.

    With ``stream`` set, the response is a server-sent event stream fed by a
    background generation thread; otherwise a single JSON completion is
    returned. Non-streaming generation runs in the default executor so it
    does not block the event loop.

    Returns a 400 response when the body is not a JSON object or
    ``messages`` is malformed.
    """
    # Local import: only this handler needs it.
    import asyncio

    try:
        body = await request.json()
    except Exception:
        return _invalid_request("Invalid JSON")
    if not isinstance(body, dict):
        return _invalid_request("Request body must be a JSON object")

    temperature = _clamped(body.get("temperature"), 0.8, 0.01, 2.0)
    max_tokens = _clamped(
        body.get("max_tokens"), 200, 1, MODEL_CONFIG.context_length, cast=int
    )
    top_k_val = _clamped(
        body.get("top_k"), 40, 0, MODEL_CONFIG.vocab_size, cast=int
    )
    top_p_val = _clamped(body.get("top_p"), 1.0, 0.0, 1.0)
    stream = bool(body.get("stream", False))

    messages = body.get("messages")
    if messages is None:
        messages = []
    if not isinstance(messages, list):
        return _invalid_request("'messages' must be a list")
    try:
        prompt_text = extract_prompt(messages)
    except (AttributeError, TypeError):
        return _invalid_request("Invalid 'messages' entries")
    # The tokenizer needs text; normalise null or structured content.
    if prompt_text is None:
        prompt_text = ""
    elif not isinstance(prompt_text, str):
        prompt_text = str(prompt_text)
    prompt_tokens = len(tokenizer.encode(prompt_text)) if prompt_text else 0

    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())

    if stream:
        def sse_stream():
            # Opening chunk announces the assistant role.
            initial = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
            }
            yield f"data: {json_mod.dumps(initial)}\n\n"

            # Generation runs in a background thread and feeds this queue.
            q = queue.Queue()
            gen_thread = threading.Thread(
                target=generate,
                kwargs={
                    "prompt": prompt_text,
                    "max_tokens": max_tokens,
                    "temperature": temperature,
                    "top_k": top_k_val,
                    "top_p": top_p_val,
                    "token_queue": q,
                },
                daemon=True,
            )
            gen_thread.start()

            def next_item():
                # Poll with a timeout so a worker that died without queuing
                # the sentinel cannot block this stream forever.
                while True:
                    try:
                        return q.get(timeout=1.0)
                    except queue.Empty:
                        if gen_thread.is_alive():
                            continue
                        # The worker is gone: drain anything it queued just
                        # before exiting, then stop.
                        try:
                            return q.get_nowait()
                        except queue.Empty:
                            return _SENTINEL

            token_count = 0
            while True:
                tok = next_item()
                if tok is _SENTINEL:
                    break
                token_count += 1
                chunk = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": MODEL_ID,
                    "choices": [{"index": 0, "delta": {"content": tok}, "finish_reason": None}],
                }
                yield f"data: {json_mod.dumps(chunk)}\n\n"

            gen_thread.join(timeout=5.0)

            # Closing chunk carries the finish reason and token usage.
            finish = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "length" if token_count >= max_tokens else "stop"}],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": token_count,
                    "total_tokens": prompt_tokens + token_count,
                },
            }
            yield f"data: {json_mod.dumps(finish)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(sse_stream(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

    # Non-streaming: run the blocking generation off the event loop.
    loop = asyncio.get_running_loop()
    text = await loop.run_in_executor(
        None,
        lambda: generate(prompt_text, max_tokens=max_tokens, temperature=temperature,
                         top_k=top_k_val, top_p=top_p_val),
    )
    # generate() returns text only, so the count comes from re-encoding it.
    completion_tokens = len(tokenizer.encode(text))

    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": MODEL_ID,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": "length",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
        "system_fingerprint": "symbiogpt-grammar-expert-v1",
    }
| |
|
| |
|
if __name__ == "__main__":
    # Announce the bind address and the available routes, then serve.
    print(f"\nSymbioGPT-GrammarExpert server starting on 0.0.0.0:{PORT} ...")
    for verb, route in (
        ("GET", "/"),
        ("GET", "/v1/models"),
        ("POST", "/v1/chat/completions"),
    ):
        print(f" {verb} http://localhost:{PORT}{route}")
    uvicorn.run(app, host="0.0.0.0", port=PORT)
| |
|