"""
server.py - OpenAI-compatible inference server for SymbioSLM-ouroboros-lora
Loads the Ouroboros-1MContext-Gemma-270m base model with the symbiogenesis-evolved
LoRA adapter (r=44, all 7 targets, PPL 61.4 on philosophy corpus).
Downloads base model + LoRA adapter from HuggingFace at runtime.
Endpoints:
GET / -> health check / API info
GET /v1/models -> list available models
POST /v1/chat/completions -> generate text (OpenAI format, streaming supported)
"""
import json
import os
import time
import uuid
from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Lock
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# ═══════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════
BASE_MODEL_REPO = os.environ.get(
"BASE_MODEL_REPO", "LisaMegaWatts/Ouroboros-1MContext-Gemma-270m"
)
LORA_REPO = os.environ.get(
"LORA_REPO", "LisaMegaWatts/SymbioSLM-ouroboros-lora-20260301"
)
PORT = int(os.environ.get("PORT", "7860"))
MODEL_ID = "symbioslm-ouroboros-lora"
MAX_CONTEXT = 512 # LoRA was trained with context_length=512
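# The defaults above can be overridden through the environment at launch,
# e.g. (illustrative invocation):
#   PORT=8080 BASE_MODEL_REPO=LisaMegaWatts/Ouroboros-1MContext-Gemma-270m python server.py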
# ═══════════════════════════════════════════════════════════════════
# Sampling helpers
# ═══════════════════════════════════════════════════════════════════
def _sample_logits(logits: torch.Tensor, temperature: float, top_k: int,
top_p: float) -> int:
if temperature <= 0:
return logits.argmax().item()
logits = logits / temperature
    if top_k > 0:
        # Keep only the top_k highest logits; everything else becomes -inf.
        topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < topk_vals[-1]] = float("-inf")
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        # Drop tokens whose cumulative probability *before* them already exceeds
        # top_p; this always keeps at least the single most likely token.
        remove = cum_probs - sorted_probs >= top_p
        sorted_logits[remove] = float("-inf")
        # Scatter the masked logits back to their original vocabulary positions.
        logits = torch.full_like(logits, float("-inf")).scatter(0, sorted_idx, sorted_logits)
probs = F.softmax(logits, dim=-1)
return torch.multinomial(probs, 1).item()
# ═══════════════════════════════════════════════════════════════════
# Text generation (manual decode loop with KV cache)
# ═══════════════════════════════════════════════════════════════════
@torch.inference_mode()
def generate(prompt: str, max_tokens: int = 200, temperature: float = 0.8,
top_k: int = 40, top_p: float = 1.0) -> tuple[str, int]:
inputs = TOKENIZER(prompt, return_tensors="pt").to(DEVICE)
    # Truncate to the training context window, leaving room for at least one new token.
    input_ids = inputs["input_ids"][:, -(MAX_CONTEXT - 1):]
prompt_len = input_ids.shape[1]
generated = input_ids
past_key_values = None
for _ in range(max_tokens):
if generated.shape[1] >= MAX_CONTEXT:
break
outputs = MODEL(
input_ids=generated[:, -1:] if past_key_values is not None else generated,
past_key_values=past_key_values,
use_cache=True,
)
past_key_values = outputs.past_key_values
next_logits = outputs.logits[0, -1, :].float()
idx = _sample_logits(next_logits, temperature, top_k, top_p)
next_token = torch.tensor([[idx]], device=DEVICE)
generated = torch.cat([generated, next_token], dim=1)
if idx == TOKENIZER.eos_token_id:
break
new_ids = generated[0, prompt_len:].tolist()
return TOKENIZER.decode(new_ids, skip_special_tokens=True), prompt_len
@torch.inference_mode()
def generate_streaming(prompt: str, max_tokens: int = 200,
temperature: float = 0.8, top_k: int = 40,
top_p: float = 1.0):
inputs = TOKENIZER(prompt, return_tensors="pt").to(DEVICE)
    # Truncate to the training context window, leaving room for at least one new token.
    input_ids = inputs["input_ids"][:, -(MAX_CONTEXT - 1):]
prompt_len = input_ids.shape[1]
generated = input_ids
past_key_values = None
# Diff-based decode: SentencePiece ▁ prefix and multi-byte UTF-8
# require decoding all generated IDs and diffing against previous output
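    # (Illustrative: if the ids so far decode to "Hel" and, after the next id,
    #  to "Hello", the streamed delta for that step is just "lo".)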
generated_ids = []
prev_text = ""
for _ in range(max_tokens):
if generated.shape[1] >= MAX_CONTEXT:
break
outputs = MODEL(
input_ids=generated[:, -1:] if past_key_values is not None else generated,
past_key_values=past_key_values,
use_cache=True,
)
past_key_values = outputs.past_key_values
next_logits = outputs.logits[0, -1, :].float()
idx = _sample_logits(next_logits, temperature, top_k, top_p)
if idx == TOKENIZER.eos_token_id:
break
generated_ids.append(idx)
        full_text = TOKENIZER.decode(generated_ids, skip_special_tokens=True)
        # Hold back output that currently ends in U+FFFD: the last token covers
        # only part of a multi-byte UTF-8 character, so wait for the next token
        # before emitting a delta.
        if not full_text.endswith("\ufffd"):
            delta = full_text[len(prev_text):]
            prev_text = full_text
            if delta:
                yield delta, prompt_len
        next_token = torch.tensor([[idx]], device=DEVICE)
        generated = torch.cat([generated, next_token], dim=1)
# ═══════════════════════════════════════════════════════════════════
# Load model
# ═══════════════════════════════════════════════════════════════════
print(f"Loading base model: {BASE_MODEL_REPO} ...")
BASE_MODEL = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_REPO,
dtype=torch.float32,
device_map="cpu",
low_cpu_mem_usage=True,
)
print(f"Loading LoRA adapter: {LORA_REPO} ...")
MODEL = PeftModel.from_pretrained(BASE_MODEL, LORA_REPO)
MODEL.eval()
print("Loading tokenizer ...")
TOKENIZER = AutoTokenizer.from_pretrained(BASE_MODEL_REPO)
DEVICE = torch.device("cpu")
MODEL_CREATED_AT = int(time.time())
NUM_PARAMS_BASE = sum(p.numel() for p in BASE_MODEL.parameters())
NUM_PARAMS_LORA = sum(p.numel() for n, p in MODEL.named_parameters() if "lora_" in n)
print(
f"\nModel ready: {NUM_PARAMS_BASE:,} base params + {NUM_PARAMS_LORA:,} LoRA params"
)
print(f" Base: {BASE_MODEL_REPO}")
print(f" LoRA: {LORA_REPO}")
print(f" Context: {MAX_CONTEXT}")
MODEL_LOCK = Lock()
# ═══════════════════════════════════════════════════════════════════
# HTTP helpers
# ═══════════════════════════════════════════════════════════════════
CORS_HEADERS = {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
}
def extract_prompt(messages):
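    """Return the content of the most recent user message (no chat template is applied)."""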
if not messages:
return ""
for msg in reversed(messages):
if msg.get("role") == "user":
return msg.get("content", "")
return messages[-1].get("content", "")
# ═══════════════════════════════════════════════════════════════════
# Request handler
# ═══════════════════════════════════════════════════════════════════
class Handler(BaseHTTPRequestHandler):
def log_message(self, format, *args):
print(f"[{self.log_date_time_string()}] {format % args}")
def _send_json(self, status, body):
data = json.dumps(body).encode()
self.send_response(status)
self.send_header("Content-Type", "application/json")
for k, v in CORS_HEADERS.items():
self.send_header(k, v)
self.send_header("Content-Length", str(len(data)))
self.end_headers()
self.wfile.write(data)
def do_OPTIONS(self):
self.send_response(204)
for k, v in CORS_HEADERS.items():
self.send_header(k, v)
self.end_headers()
def do_GET(self):
if self.path == "/":
self._send_json(200, {
"name": "SymbioSLM-ouroboros-lora",
"version": "1.0.0",
"description": (
"Symbiogenesis-evolved LoRA adapter (r=44, all 7 targets) "
"on Ouroboros-1MContext-Gemma-270m. PPL 309 -> 61 (5x improvement) "
"on philosophy corpus."
),
"architecture": "Gemma3ForCausalLM + PEFT LoRA",
"model": {
"base_model": BASE_MODEL_REPO,
"lora_adapter": LORA_REPO,
"base_params": NUM_PARAMS_BASE,
"lora_params": NUM_PARAMS_LORA,
"lora_rank": 44,
"lora_targets": "q,k,v,o,gate,up,down",
"context_length": MAX_CONTEXT,
"vocab_size": 262144,
},
"evolution": {
"method": "symbiogenesis",
"gelation_step": 7,
"base_ppl": 309.3,
"final_ppl": 61.4,
},
"endpoints": ["/v1/models", "/v1/chat/completions"],
"features": ["streaming", "OpenAI-compatible", "top-k", "top-p", "kv-cache"],
"github": "https://github.com/DavinciDreams/SymbioGPT",
})
elif self.path == "/v1/models":
self._send_json(200, {
"object": "list",
"data": [{
"id": MODEL_ID,
"object": "model",
"created": MODEL_CREATED_AT,
"owned_by": "symbioslm",
}],
})
else:
self._send_json(404, {"error": {
"message": f"Not found: GET {self.path}",
"type": "invalid_request_error",
"code": "not_found",
}})
def do_POST(self):
if self.path != "/v1/chat/completions":
self._send_json(404, {"error": {
"message": f"Not found: POST {self.path}",
"type": "invalid_request_error",
"code": "not_found",
}})
return
content_length = int(self.headers.get("Content-Length", 0))
try:
body = json.loads(self.rfile.read(content_length))
except (json.JSONDecodeError, ValueError):
self._send_json(400, {"error": {
"message": "Invalid JSON in request body",
"type": "invalid_request_error",
"code": "invalid_json",
}})
return
temperature = max(0.0, min(2.0, float(body.get("temperature", 0.8))))
max_tokens = max(1, min(MAX_CONTEXT, int(body.get("max_tokens", 200))))
top_k_val = max(0, int(body.get("top_k", 40)))
top_p_val = max(0.0, min(1.0, float(body.get("top_p", 1.0))))
stream = bool(body.get("stream", False))
messages = body.get("messages", [])
prompt_text = extract_prompt(messages)
completion_id = f"chatcmpl-{uuid.uuid4()}"
created = int(time.time())
with MODEL_LOCK:
if stream:
self._handle_stream(
prompt_text, max_tokens, temperature, top_k_val, top_p_val,
completion_id, created,
)
else:
self._handle_non_stream(
prompt_text, max_tokens, temperature, top_k_val, top_p_val,
completion_id, created,
)
def _handle_stream(self, prompt_text, max_tokens, temperature, top_k, top_p,
completion_id, created):
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.send_header("X-Accel-Buffering", "no")
for k, v in CORS_HEADERS.items():
self.send_header(k, v)
self.end_headers()
def sse(data):
self.wfile.write(f"data: {json.dumps(data)}\n\n".encode())
self.wfile.flush()
sse({
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": MODEL_ID,
"choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
})
token_count = 0
prompt_tokens = 0
for token_str, p_len in generate_streaming(
prompt_text, max_tokens=max_tokens, temperature=temperature,
top_k=top_k, top_p=top_p,
):
token_count += 1
prompt_tokens = p_len
sse({
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": MODEL_ID,
"choices": [{"index": 0, "delta": {"content": token_str}, "finish_reason": None}],
})
sse({
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": MODEL_ID,
"choices": [{"index": 0, "delta": {}, "finish_reason": "length" if token_count >= max_tokens else "stop"}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": token_count,
"total_tokens": prompt_tokens + token_count,
},
})
self.wfile.write(b"data: [DONE]\n\n")
self.wfile.flush()
def _handle_non_stream(self, prompt_text, max_tokens, temperature, top_k, top_p,
completion_id, created):
text, prompt_tokens = generate(
prompt_text, max_tokens=max_tokens, temperature=temperature,
top_k=top_k, top_p=top_p,
)
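        # Re-encoding the decoded text approximates the number of sampled tokens.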
completion_tokens = len(TOKENIZER.encode(text))
finish_reason = "length" if completion_tokens >= max_tokens else "stop"
self._send_json(200, {
"id": completion_id,
"object": "chat.completion",
"created": created,
"model": MODEL_ID,
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": text},
"finish_reason": finish_reason,
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
"system_fingerprint": "symbioslm-ouroboros-lora-v1",
})
# ═══════════════════════════════════════════════════════════════════
# Start server
# ═══════════════════════════════════════════════════════════════════
if __name__ == "__main__":
print(f"\nSymbioSLM-ouroboros-lora server starting on 0.0.0.0:{PORT} ...")
print(f" GET http://localhost:{PORT}/")
print(f" GET http://localhost:{PORT}/v1/models")
print(f" POST http://localhost:{PORT}/v1/chat/completions")
print(f" POST http://localhost:{PORT}/v1/chat/completions (stream=true)")
print()
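    # Any OpenAI-compatible client can also point at this server, e.g. with the
    # official openai Python SDK (illustrative; the server ignores the API key):
    #   from openai import OpenAI
    #   client = OpenAI(base_url=f"http://localhost:{PORT}/v1", api_key="unused")
    #   client.chat.completions.create(
    #       model="symbioslm-ouroboros-lora",
    #       messages=[{"role": "user", "content": "Hello"}],
    #   )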
server = HTTPServer(("0.0.0.0", PORT), Handler)
server.serve_forever()