lean-laguna / scripts /stub_server.py
art87able's picture
Lean Laguna: lossless DFlash speculative decoding on Laguna XS.2 (harness, environment, results)
0a55ff6
#!/usr/bin/env python3
"""
stub_server.py — a tiny, stdlib-only OpenAI-compatible STUB so the benchmark and
eval harness (bench/measure.py, evals/humaneval_subset.py) can be exercised
END-TO-END on the Mac, with NO CUDA / vLLM / Laguna. It fakes just enough of the
vLLM surface to shape-test the whole pipeline before the venue.
What it fakes:
* POST /v1/completions — both streaming (SSE, for measure.py) and non-streaming
(single JSON, for humaneval_subset.py). Output is DETERMINISTIC given the prompt,
so two stubs return identical greedy text → the parity check proves "lossless".
* GET /metrics — Prometheus text. With --spec, it exposes the
spec_decode_* counters measure.py reads to compute acceptance length τ
(tuned so τ ≈ 2.6, in the DFlash card's 2.56–3.07 range). Without --spec it's a
plain baseline (no spec counters → measure.py reports τ = None, which is correct).
JVM analogy: this is WireMock for an LLM endpoint — a canned stub standing in for
the real service so you can integration-test the client/harness without the backend.
Usage:
python scripts/stub_server.py --port 8000 # baseline stub
python scripts/stub_server.py --port 8001 --spec # "dflash" stub (has τ metrics)
"""
from __future__ import annotations
import argparse
import json
import threading
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
# Speculative draft length per pass; matches the DFlash card / serve_vllm.py.
GAMMA = 7
# Acceptance length (τ) that measure.py should report for the --spec stub.
TAU_TARGET = 2.6

# One canned completion returned for every prompt, so two stubs always produce
# identical greedy text. The content itself does not matter locally: humaneval
# runs with --no-exec and measure.py only times the response.
COMPLETION = "".join(
    (
        "\n # stub completion (local shape-test only; not a real model)\n",
        " result = 0\n",
        " for i in range(n):\n",
        " result += i\n",
        " return result\n",
    )
)
def _tokens(text: str) -> list[str]:
"""Split into whitespace-preserving 'tokens' so streaming has several chunks."""
out, buf = [], ""
for ch in text:
buf += ch
if ch.isspace():
out.append(buf)
buf = ""
if buf:
out.append(buf)
return out
class State:
    """Shared mutable counters (one instance per server).

    The instance is shared by every request-handler thread, so all access to
    ``emitted`` goes through ``lock``.

    Attributes:
        spec: When true, metrics_text() also emits the spec_decode_* counters.
        emitted: Running total of tokens handed out via add_emitted().
        lock: Guards reads and writes of ``emitted``.
    """

    def __init__(self, spec: bool):
        self.spec = spec
        self.emitted = 0
        self.lock = threading.Lock()

    def add_emitted(self, n: int) -> None:
        """Add *n* tokens to the running total."""
        with self.lock:
            self.emitted += n

    def metrics_text(self) -> str:
        """Render the counters as Prometheus text (newline-terminated).

        Always includes the ``stub_up`` gauge. When ``spec`` is set, also
        includes the three spec_decode_* counters, all derived from one
        snapshot of ``emitted`` so a single scrape is internally consistent
        even while other threads are adding tokens.
        """
        # Read the total exactly once, under the same lock the writers use;
        # otherwise draft/accepted/emitted could come from different totals.
        with self.lock:
            emitted = self.emitted

        lines = [
            "# HELP stub_up 1 if the stub is serving",
            "# TYPE stub_up gauge",
            "stub_up 1",
        ]
        if self.spec:
            # Invert measure.py's math so it recovers TAU_TARGET:
            #   passes = emitted / tau ; draft = passes*gamma ; accepted = emitted - passes
            #   measure.py: passes' = draft/gamma = passes ; committed = accepted + passes = emitted
            #   tau = committed / passes = emitted / passes = TAU_TARGET
            passes = max(emitted / TAU_TARGET, 0.0)
            draft = passes * GAMMA
            accepted = max(emitted - passes, 0.0)
            lines += [
                f"spec_decode_num_draft_tokens {draft:.0f}",
                f"spec_decode_num_accepted_tokens {accepted:.0f}",
                f"spec_decode_num_emitted_tokens {emitted:.0f}",
            ]
        return "\n".join(lines) + "\n"
class Handler(BaseHTTPRequestHandler):
    """Request handler faking the slice of the vLLM HTTP surface the harness uses.

    Routes:
        GET  /metrics              Prometheus text rendered by the shared State.
        POST /v1/completions       Legacy text completion.
        POST /v1/chat/completions  Chat completion.

    Both POST routes honour ``stream`` (SSE chunks followed by ``[DONE]``) and
    ``max_tokens``. Malformed requests are answered with a 400 JSON error
    rather than an unhandled exception in the handler thread.
    """

    state: "State | None" = None  # set on the class before serving
    DEFAULT_MAX_TOKENS = 64  # used when the request omits max_tokens or sends null

    def log_message(self, *args):  # quiet
        """Suppress the default per-request access log."""

    def _send(self, code: int, body: bytes, ctype: str) -> None:
        """Write a complete response with Content-Type and Content-Length."""
        self.send_response(code)
        self.send_header("Content-Type", ctype)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _route(self) -> str:
        """Return the request path without query string or trailing slashes."""
        return self.path.split("?", 1)[0].rstrip("/")

    def _read_request(self) -> dict:
        """Read the request body and decode it as a JSON object.

        An absent or empty body is treated as ``{}``.

        Raises:
            ValueError: If Content-Length is not an integer, the body is not
                valid JSON, or the decoded value is not a JSON object. The
                message is safe to return to the client.
        """
        try:
            length = int(self.headers.get("Content-Length", 0))
        except ValueError:
            raise ValueError("bad content-length") from None
        # Never pass a negative size to read(): that would block until EOF.
        raw = self.rfile.read(length) if length > 0 else b""
        try:
            req = json.loads(raw or b"{}")
        except ValueError:
            # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            raise ValueError("bad json") from None
        if not isinstance(req, dict):
            raise ValueError("request body must be a JSON object")
        return req

    @classmethod
    def _parse_max_tokens(cls, req: dict) -> int:
        """Return the request's token budget.

        A missing or null ``max_tokens`` yields DEFAULT_MAX_TOKENS.

        Raises:
            ValueError: If the value cannot be converted to an integer or is
                negative (a negative slice bound would silently drop tokens
                from the end of the completion).
        """
        raw = req.get("max_tokens")
        if raw is None:
            return cls.DEFAULT_MAX_TOKENS
        try:
            value = int(raw)
        except (TypeError, ValueError, OverflowError):
            raise ValueError("max_tokens must be an integer") from None
        if value < 0:
            raise ValueError("max_tokens must be non-negative")
        return value

    def _stream(self, toks: list, is_chat: bool) -> None:
        """Send *toks* as SSE chunks, one per token, then the [DONE] sentinel."""
        self.send_response(200)
        self.send_header("Content-Type", "text/event-stream")
        self.end_headers()
        try:
            for t in toks:
                if is_chat:
                    chunk = {"choices": [{"delta": {"content": t}, "index": 0,
                                          "finish_reason": None}]}
                else:
                    chunk = {"choices": [{"text": t, "index": 0,
                                          "finish_reason": None}]}
                self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode())
                self.wfile.flush()
            self.wfile.write(b"data: [DONE]\n\n")
            self.wfile.flush()
        except ConnectionError:
            # The client hung up mid-stream; nothing more can be sent.
            return

    def do_GET(self):
        """Serve /metrics; every other path is a 404."""
        if self._route() == "/metrics":
            self._send(200, self.state.metrics_text().encode(), "text/plain; version=0.0.4")
        else:
            self._send(404, b"not found\n", "text/plain")

    def do_POST(self):
        """Serve the two completion routes with the canned completion."""
        path = self._route()
        # Real vLLM serves both the legacy text route (/v1/completions, used by
        # bench/measure.py) and the chat route (/v1/chat/completions, used by the
        # Kotlin load-test client). The only wire difference is the chunk shape:
        # chat streams {delta:{content:...}}, legacy streams {text:...}.
        is_chat = path == "/v1/chat/completions"
        if not is_chat and path != "/v1/completions":
            self._send(404, b"not found\n", "text/plain")
            return

        try:
            req = self._read_request()
            max_tokens = self._parse_max_tokens(req)
        except ValueError as err:
            body = json.dumps({"error": str(err)}, separators=(",", ":"))
            self._send(400, body.encode(), "application/json")
            return

        toks = _tokens(COMPLETION)[:max_tokens]
        text = "".join(toks)
        self.state.add_emitted(len(toks))

        if req.get("stream"):
            self._stream(toks, is_chat)
            return

        if is_chat:
            body = {
                "id": "stub-chatcmpl",
                "object": "chat.completion",
                "model": req.get("model", "laguna"),
                "choices": [{"message": {"role": "assistant", "content": text},
                             "index": 0, "finish_reason": "stop"}],
            }
        else:
            body = {
                "id": "stub-cmpl",
                "object": "text_completion",
                "model": req.get("model", "laguna"),
                "choices": [{"text": text, "index": 0, "finish_reason": "stop"}],
            }
        self._send(200, json.dumps(body).encode(), "application/json")
def main() -> None:
    """Parse the CLI flags and serve the stub on 127.0.0.1 until interrupted.

    Flags:
        --port  TCP port to bind (default 8000).
        --spec  Also expose the spec_decode_* metrics (simulated DFlash endpoint).
    """
    p = argparse.ArgumentParser(description="Stdlib OpenAI-compatible stub for local harness shape-tests.")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--spec", action="store_true",
                   help="Expose spec_decode_* metrics (simulate the DFlash endpoint, τ≈2.6).")
    args = p.parse_args()

    Handler.state = State(spec=args.spec)
    tag = "dflash-stub (with τ metrics)" if args.spec else "baseline-stub"

    # Leaving the `with` block calls server_close(), which releases the
    # listening socket. shutdown() is not used here: it exists to stop
    # serve_forever() from a different thread, and by the time control gets
    # past serve_forever() in this thread the loop has already ended.
    with ThreadingHTTPServer(("127.0.0.1", args.port), Handler) as srv:
        print(f"[stub] {tag} serving on http://127.0.0.1:{args.port} "
              f"(/v1/completions, /v1/chat/completions, /metrics)")
        try:
            srv.serve_forever()
        except KeyboardInterrupt:
            pass  # Ctrl-C is the normal way to stop the stub


if __name__ == "__main__":
    main()