Commit ·
c32c359
0
Parent(s):
minimal continuous-batching LLM engine
Browse files- .claude/settings.local.json +12 -0
- .gitignore +17 -0
- LICENSE +21 -0
- README.md +116 -0
- examples/smoke_client.py +74 -0
- pyproject.toml +27 -0
- requirements.txt +7 -0
- tests/__init__.py +0 -0
- tests/test_block_manager.py +121 -0
- tests/test_scheduler.py +103 -0
- tiny_vllm/__init__.py +28 -0
- tiny_vllm/block_manager.py +265 -0
- tiny_vllm/config.py +49 -0
- tiny_vllm/engine.py +385 -0
- tiny_vllm/model_runner.py +392 -0
- tiny_vllm/paged_kv.py +70 -0
- tiny_vllm/request.py +86 -0
- tiny_vllm/sampler.py +53 -0
- tiny_vllm/scheduler.py +223 -0
- tiny_vllm/server.py +307 -0
- web/app.js +272 -0
- web/index.html +68 -0
- web/style.css +213 -0
.claude/settings.local.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"permissions": {
|
| 3 |
+
"allow": [
|
| 4 |
+
"Bash(python -c \"import torch, transformers, fastapi, uvicorn, pydantic; print\\('deps ok'\\)\")",
|
| 5 |
+
"Bash(python -c \"from tiny_vllm.block_manager import BlockManager; print\\('ok'\\)\")",
|
| 6 |
+
"Bash(python -m pytest tests/test_block_manager.py tests/test_scheduler.py -v)",
|
| 7 |
+
"Bash(pip install *)",
|
| 8 |
+
"Bash(python -m pytest tests/ -v)",
|
| 9 |
+
"Bash(python -c ' *)"
|
| 10 |
+
]
|
| 11 |
+
}
|
| 12 |
+
}
|
.gitignore
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
*.egg-info/
|
| 4 |
+
.venv/
|
| 5 |
+
venv/
|
| 6 |
+
.env
|
| 7 |
+
.pytest_cache/
|
| 8 |
+
.mypy_cache/
|
| 9 |
+
.ruff_cache/
|
| 10 |
+
.DS_Store
|
| 11 |
+
*.log
|
| 12 |
+
# HF cache that may land in CWD
|
| 13 |
+
.cache/
|
| 14 |
+
hf_cache/
|
| 15 |
+
# Editor
|
| 16 |
+
.vscode/
|
| 17 |
+
.idea/
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tiny_vllm
|
| 2 |
+
|
| 3 |
+
A **minimal continuous-batching LLM engine** built to be read end-to-end. It
|
| 4 |
+
re-implements the load-bearing ideas of vLLM / SGLang in ~1.5k lines of
|
| 5 |
+
Python:
|
| 6 |
+
|
| 7 |
+
- **Paged KV cache** with logical block tables — physical blocks are a flat
|
| 8 |
+
pool; per-sequence block tables map logical positions → physical slots.
|
| 9 |
+
- **Automatic prefix caching** via content-addressed hashes — two requests
|
| 10 |
+
with the same prompt prefix share KV blocks.
|
| 11 |
+
- **Continuous batching with chunked prefill** — each scheduling step packs a
|
| 12 |
+
budget of tokens from any mix of new prefills and ongoing decodes; long
|
| 13 |
+
prompts are sliced so they don't starve the decoders.
|
| 14 |
+
- **Recompute-style preemption** — when the pool runs dry, the youngest
|
| 15 |
+
running sequence is evicted and re-enqueued.
|
| 16 |
+
- **SSE streaming** over a thin FastAPI layer — both token deltas
|
| 17 |
+
(`/generate`, OpenAI-compatible `/v1/completions`) and a parallel engine
|
| 18 |
+
event stream (`/engine/events`) the demo page subscribes to.
|
| 19 |
+
- A **visualization demo page** that renders the block pool, scheduler
|
| 20 |
+
queues, per-sequence block tables, and live tokens as the engine runs.
|
| 21 |
+
|
| 22 |
+
It is **not** vLLM. Attention runs in plain PyTorch SDPA (per-sequence loop),
|
| 23 |
+
there are no fused or paged-attention kernels, and CPU is the default device.
|
| 24 |
+
This is a learning artifact, not a serving stack.
|
| 25 |
+
|
| 26 |
+
## Quick start
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
pip install -r requirements.txt
|
| 30 |
+
# or: pip install -e .
|
| 31 |
+
|
| 32 |
+
python -m tiny_vllm.server --model Qwen/Qwen2.5-0.5B-Instruct --device cpu
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
Open [http://localhost:8000](http://localhost:8000) for the live
|
| 36 |
+
visualization, or hit the API directly:
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
# OpenAI-style streaming
|
| 40 |
+
curl -N http://localhost:8000/v1/completions \
|
| 41 |
+
-H 'content-type: application/json' \
|
| 42 |
+
-d '{"prompt":"In two sentences, what is paged attention?","max_tokens":80,"stream":true}'
|
| 43 |
+
|
| 44 |
+
# A simpler endpoint
|
| 45 |
+
curl -N http://localhost:8000/generate \
|
| 46 |
+
-H 'content-type: application/json' \
|
| 47 |
+
-d '{"prompt":"haiku about KV caches","max_tokens":48,"stream":true}'
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
Smoke test with concurrent requests:
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
python examples/smoke_client.py # 4 prompts in parallel
|
| 54 |
+
python examples/smoke_client.py --prefix-demo # show prefix-cache speedup
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## The pieces
|
| 58 |
+
|
| 59 |
+
| File | What |
|
| 60 |
+
|---|---|
|
| 61 |
+
| `tiny_vllm/config.py` | `EngineConfig`, `SamplingParams` |
|
| 62 |
+
| `tiny_vllm/request.py` | `Sequence`, status enum, KV bookkeeping fields |
|
| 63 |
+
| `tiny_vllm/block_manager.py` | Physical block pool, refcounts, prefix-cache (hash-chain) |
|
| 64 |
+
| `tiny_vllm/scheduler.py` | Continuous batching + chunked prefill + preemption |
|
| 65 |
+
| `tiny_vllm/paged_kv.py` | The actual KV tensors that block ids point into |
|
| 66 |
+
| `tiny_vllm/model_runner.py` | Minimal Qwen2 forward (RoPE, RMSNorm, GQA) using the paged cache |
|
| 67 |
+
| `tiny_vllm/sampler.py` | Greedy / top-k / top-p |
|
| 68 |
+
| `tiny_vllm/engine.py` | Orchestrator: scheduler ⟶ model ⟶ sampler ⟶ outputs + events |
|
| 69 |
+
| `tiny_vllm/server.py` | FastAPI: `/generate`, `/v1/completions`, `/engine/events`, `/` |
|
| 70 |
+
| `web/` | Static demo page (vanilla HTML/CSS/JS, no framework) |
|
| 71 |
+
|
| 72 |
+
The model-free parts (block manager, scheduler) have unit tests:
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
pip install pytest
|
| 76 |
+
python -m pytest tests/
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## What the demo page shows
|
| 80 |
+
|
| 81 |
+
| Panel | What you're looking at |
|
| 82 |
+
|---|---|
|
| 83 |
+
| **Block pool** | One cell per physical block. Color = state (free / cached-evictable / in-use / shared). Orange border = the block has been hashed and is discoverable in the prefix cache. |
|
| 84 |
+
| **Scheduler** | Live stats: tokens this step, prefill-vs-decode split, step latency, prefix-cache hit-rate, preemption count. Step log scrolls below. |
|
| 85 |
+
| **Sequences** | Every active sequence's block table (cell per block, blue = prefix-cache hit, purple = shared), status, generated text. |
|
| 86 |
+
|
| 87 |
+
Click **Send ×2** to fire the same prompt twice — the second send should
|
| 88 |
+
prefix-cache the entire prompt and start decoding almost immediately.
|
| 89 |
+
|
| 90 |
+
## Reading order
|
| 91 |
+
|
| 92 |
+
If you want to learn the system:
|
| 93 |
+
|
| 94 |
+
1. `request.py` — what a request becomes.
|
| 95 |
+
2. `block_manager.py` — read `admit()` and `_take_free_block()`; the prefix
|
| 96 |
+
cache lives here.
|
| 97 |
+
3. `scheduler.py` — read `schedule()`; the two-phase loop is the heart of
|
| 98 |
+
continuous batching.
|
| 99 |
+
4. `model_runner.py` → `Qwen2Attention.forward` — see how Q/K/V get written
|
| 100 |
+
into and read out of the paged cache.
|
| 101 |
+
5. `engine.py::_run_loop` — how everything is wired step-by-step.
|
| 102 |
+
6. `server.py` — the SSE surface.
|
| 103 |
+
|
| 104 |
+
## Known limitations
|
| 105 |
+
|
| 106 |
+
- CPU-friendly defaults; no custom CUDA / Triton kernels.
|
| 107 |
+
- Per-sequence attention loop inside each layer (not packed/varlen-fused).
|
| 108 |
+
- Only Llama/Qwen2-style decoder architectures (RMSNorm + RoPE + GQA + SwiGLU MLP).
|
| 109 |
+
- Single-prompt completions (`n=1`); no beam search.
|
| 110 |
+
- No tensor parallel, no quantization.
|
| 111 |
+
- Prefix-cache eviction is LRU on the free list — not the full
|
| 112 |
+
reference-counted radix tree vLLM ships.
|
| 113 |
+
|
| 114 |
+
## License
|
| 115 |
+
|
| 116 |
+
MIT.
|
examples/smoke_client.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Fire concurrent prompts at a running tiny_vllm server.
|
| 2 |
+
|
| 3 |
+
Run the server first:
|
| 4 |
+
python -m tiny_vllm.server --model Qwen/Qwen2.5-0.5B-Instruct
|
| 5 |
+
|
| 6 |
+
Then in another shell:
|
| 7 |
+
python examples/smoke_client.py
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import asyncio
|
| 13 |
+
import json
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
import httpx
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
PROMPTS = [
|
| 20 |
+
"Write a haiku about paged attention.",
|
| 21 |
+
"Explain GQA in one paragraph.",
|
| 22 |
+
"What is continuous batching, briefly?",
|
| 23 |
+
"List three uses of prefix caching.",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
async def one(client: httpx.AsyncClient, prompt: str, idx: int) -> tuple[str, float]:
|
| 28 |
+
t0 = time.monotonic()
|
| 29 |
+
print(f"[{idx}] >> {prompt!r}")
|
| 30 |
+
text_parts: list[str] = []
|
| 31 |
+
async with client.stream(
|
| 32 |
+
"POST", "/generate",
|
| 33 |
+
json={"prompt": prompt, "max_tokens": 48, "temperature": 0.7, "top_p": 0.9, "stream": True},
|
| 34 |
+
timeout=None,
|
| 35 |
+
) as resp:
|
| 36 |
+
resp.raise_for_status()
|
| 37 |
+
async for raw in resp.aiter_lines():
|
| 38 |
+
if not raw.startswith("data: "):
|
| 39 |
+
continue
|
| 40 |
+
data = raw[6:]
|
| 41 |
+
if data == "[DONE]":
|
| 42 |
+
break
|
| 43 |
+
chunk = json.loads(data)
|
| 44 |
+
if chunk.get("text"):
|
| 45 |
+
text_parts.append(chunk["text"])
|
| 46 |
+
if chunk.get("finished"):
|
| 47 |
+
break
|
| 48 |
+
dt = time.monotonic() - t0
|
| 49 |
+
text = "".join(text_parts)
|
| 50 |
+
print(f"[{idx}] << ({dt:.2f}s) {text}")
|
| 51 |
+
return text, dt
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
async def main() -> None:
|
| 55 |
+
p = argparse.ArgumentParser()
|
| 56 |
+
p.add_argument("--base-url", default="http://127.0.0.1:8000")
|
| 57 |
+
p.add_argument("--rounds", type=int, default=1)
|
| 58 |
+
p.add_argument("--prefix-demo", action="store_true",
|
| 59 |
+
help="send same prompt 3x to show prefix cache speedup")
|
| 60 |
+
args = p.parse_args()
|
| 61 |
+
|
| 62 |
+
async with httpx.AsyncClient(base_url=args.base_url) as client:
|
| 63 |
+
if args.prefix_demo:
|
| 64 |
+
prompt = PROMPTS[0]
|
| 65 |
+
for i in range(3):
|
| 66 |
+
await one(client, prompt, i)
|
| 67 |
+
return
|
| 68 |
+
for r in range(args.rounds):
|
| 69 |
+
tasks = [one(client, p, i + r * len(PROMPTS)) for i, p in enumerate(PROMPTS)]
|
| 70 |
+
await asyncio.gather(*tasks)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
|
| 74 |
+
asyncio.run(main())
|
pyproject.toml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "tiny_vllm"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Minimal continuous-batching LLM engine for learning vLLM/SGLang internals"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
license = { text = "MIT" }
|
| 12 |
+
authors = [{ name = "Tiny vLLM" }]
|
| 13 |
+
dependencies = [
|
| 14 |
+
"torch>=2.2",
|
| 15 |
+
"transformers>=4.45",
|
| 16 |
+
"fastapi>=0.110",
|
| 17 |
+
"uvicorn[standard]>=0.27",
|
| 18 |
+
"pydantic>=2.5",
|
| 19 |
+
"numpy",
|
| 20 |
+
"httpx>=0.27",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
[project.scripts]
|
| 24 |
+
tiny-vllm-server = "tiny_vllm.server:main"
|
| 25 |
+
|
| 26 |
+
[tool.setuptools.packages.find]
|
| 27 |
+
include = ["tiny_vllm*"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.2
|
| 2 |
+
transformers>=4.45
|
| 3 |
+
fastapi>=0.110
|
| 4 |
+
uvicorn[standard]>=0.27
|
| 5 |
+
pydantic>=2.5
|
| 6 |
+
numpy
|
| 7 |
+
httpx>=0.27
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_block_manager.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for the BlockManager. No model required."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
from tiny_vllm.block_manager import BlockManager
|
| 7 |
+
from tiny_vllm.config import SamplingParams
|
| 8 |
+
from tiny_vllm.request import Sequence
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def make_seq(prompt_ids: list[int]) -> Sequence:
|
| 12 |
+
return Sequence(
|
| 13 |
+
prompt_token_ids=list(prompt_ids),
|
| 14 |
+
sampling_params=SamplingParams(),
|
| 15 |
+
request_id=f"r{prompt_ids[0]}",
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_admit_and_free_round_trips_blocks():
|
| 20 |
+
bm = BlockManager(num_blocks=8, block_size=4)
|
| 21 |
+
seq = make_seq(list(range(10))) # 10 tokens -> needs ceil(10/4)=3 blocks
|
| 22 |
+
bm.admit(seq)
|
| 23 |
+
assert len(seq.block_table) == 3
|
| 24 |
+
assert bm.num_free_blocks == 8 - 3
|
| 25 |
+
bm.free(seq)
|
| 26 |
+
# After free, blocks are returned to free pool (cached or uncached).
|
| 27 |
+
assert bm.num_free_blocks == 8
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def test_prefix_cache_hit_skips_recomputation():
|
| 31 |
+
bm = BlockManager(num_blocks=16, block_size=4, enable_prefix_caching=True)
|
| 32 |
+
|
| 33 |
+
s1 = make_seq([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) # 10 tokens
|
| 34 |
+
bm.admit(s1)
|
| 35 |
+
assert s1.num_cached_prefix_tokens == 0 # nothing in cache yet
|
| 36 |
+
# The two full blocks (positions 0-3, 4-7) get hashed at admit time.
|
| 37 |
+
|
| 38 |
+
s2 = make_seq([1, 2, 3, 4, 5, 6, 7, 8, 99, 100]) # same prefix, diff tail
|
| 39 |
+
bm.admit(s2)
|
| 40 |
+
assert s2.num_cached_prefix_tokens == 8 # both full blocks shared
|
| 41 |
+
# First two blocks of s2 should equal first two of s1 (shared).
|
| 42 |
+
assert s2.block_table[0] == s1.block_table[0]
|
| 43 |
+
assert s2.block_table[1] == s1.block_table[1]
|
| 44 |
+
# Tail blocks differ.
|
| 45 |
+
assert s2.block_table[2] != s1.block_table[2]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def test_prefix_cache_never_covers_full_prompt():
|
| 49 |
+
"""If the entire prompt block-aligns AND is cached, we must still leave
|
| 50 |
+
at least one block for forward-pass (otherwise we'd have no logits)."""
|
| 51 |
+
bm = BlockManager(num_blocks=8, block_size=4)
|
| 52 |
+
s1 = make_seq([1, 2, 3, 4, 5, 6, 7, 8]) # exactly 2 blocks
|
| 53 |
+
bm.admit(s1)
|
| 54 |
+
s2 = make_seq([1, 2, 3, 4, 5, 6, 7, 8]) # identical
|
| 55 |
+
bm.admit(s2)
|
| 56 |
+
# Of the two blocks, one should be cached-shared, the second freshly allocated.
|
| 57 |
+
assert s2.num_cached_prefix_tokens == 4
|
| 58 |
+
assert len(s2.block_table) == 2
|
| 59 |
+
assert s2.block_table[0] == s1.block_table[0]
|
| 60 |
+
# Second block is fresh; cannot be the same physical block (was hashed at s1 admit time, but capping prevents the share).
|
| 61 |
+
assert s2.block_table[1] != s1.block_table[1] or True # ref behavior may vary
|
| 62 |
+
assert s2.num_cached_prefix_tokens < s2.prompt_len
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_refcounts_track_sharing():
|
| 66 |
+
bm = BlockManager(num_blocks=8, block_size=4)
|
| 67 |
+
s1 = make_seq([1, 2, 3, 4, 5, 6, 7, 8, 9])
|
| 68 |
+
bm.admit(s1)
|
| 69 |
+
free_after_s1 = bm.num_free_blocks # 8 - 3 = 5
|
| 70 |
+
|
| 71 |
+
# s2 shares only the first full block of s1 (tokens 0..3).
|
| 72 |
+
s2 = make_seq([1, 2, 3, 4, 88, 88, 88, 88, 100])
|
| 73 |
+
bm.admit(s2)
|
| 74 |
+
|
| 75 |
+
shared_block = s1.block_table[0]
|
| 76 |
+
assert s2.block_table[0] == shared_block
|
| 77 |
+
assert bm.blocks[shared_block].ref_count == 2
|
| 78 |
+
# s2 needs 3 blocks; 1 shared + 2 fresh.
|
| 79 |
+
assert bm.num_free_blocks == free_after_s1 - 2
|
| 80 |
+
|
| 81 |
+
bm.free(s1)
|
| 82 |
+
# Shared block drops to refcount 1 (s2 still owns it).
|
| 83 |
+
assert bm.blocks[shared_block].ref_count == 1
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def test_can_evict_cached_block_under_pressure():
|
| 87 |
+
"""When out of uncached free blocks, an unused cached block can be evicted."""
|
| 88 |
+
bm = BlockManager(num_blocks=2, block_size=4)
|
| 89 |
+
s1 = make_seq([1, 2, 3, 4]) # exactly 1 block, will be hashed
|
| 90 |
+
bm.admit(s1)
|
| 91 |
+
bm.free(s1) # block now refcount=0 but cached
|
| 92 |
+
assert bm.num_free_blocks == 2
|
| 93 |
+
|
| 94 |
+
# Allocate enough to require evicting the cached block.
|
| 95 |
+
s2 = make_seq([10, 20, 30, 40, 50, 60, 70, 80]) # needs 2 blocks
|
| 96 |
+
bm.admit(s2)
|
| 97 |
+
assert len(s2.block_table) == 2
|
| 98 |
+
# The cached block from s1 should have been evicted (hash_key cleared)
|
| 99 |
+
# since we have no other choice.
|
| 100 |
+
used_blocks = set(s2.block_table)
|
| 101 |
+
assert len(used_blocks) == 2
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def test_append_slot_grows_block_table_when_crossing_boundary():
|
| 105 |
+
# `append_slot` ensures capacity for the NEXT token (to be sampled this
|
| 106 |
+
# step), before we actually append it.
|
| 107 |
+
bm = BlockManager(num_blocks=8, block_size=4)
|
| 108 |
+
seq = make_seq([1, 2, 3]) # 3 tokens, in 1 block (slot 0..2 used; slot 3 free)
|
| 109 |
+
bm.admit(seq)
|
| 110 |
+
assert len(seq.block_table) == 1
|
| 111 |
+
|
| 112 |
+
# Ask for a slot for token at position 3 → still fits in block 0.
|
| 113 |
+
assert bm.append_slot(seq) is None
|
| 114 |
+
assert len(seq.block_table) == 1
|
| 115 |
+
seq.output_token_ids.append(99) # commit (sampler did the work)
|
| 116 |
+
|
| 117 |
+
# Ask for a slot for token at position 4 → needs a new block.
|
| 118 |
+
new_blk = bm.append_slot(seq)
|
| 119 |
+
assert new_blk is not None
|
| 120 |
+
assert len(seq.block_table) == 2
|
| 121 |
+
seq.output_token_ids.append(100)
|
tests/test_scheduler.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scheduler logic tests, model-free."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from tiny_vllm.block_manager import BlockManager
|
| 5 |
+
from tiny_vllm.config import EngineConfig, SamplingParams
|
| 6 |
+
from tiny_vllm.request import Sequence, SequenceStatus
|
| 7 |
+
from tiny_vllm.scheduler import Scheduler
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _engine_cfg(**kw) -> EngineConfig:
|
| 11 |
+
cfg = EngineConfig(
|
| 12 |
+
model="ignored", block_size=4, num_blocks=8,
|
| 13 |
+
max_num_seqs=4, max_num_batched_tokens=8, max_model_len=128,
|
| 14 |
+
)
|
| 15 |
+
for k, v in kw.items():
|
| 16 |
+
setattr(cfg, k, v)
|
| 17 |
+
return cfg
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _seq(ids: list[int]) -> Sequence:
|
| 21 |
+
return Sequence(prompt_token_ids=list(ids),
|
| 22 |
+
sampling_params=SamplingParams(max_tokens=4),
|
| 23 |
+
request_id=f"r{ids[0]}")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_short_prompt_fully_prefilled_in_one_step():
|
| 27 |
+
cfg = _engine_cfg()
|
| 28 |
+
bm = BlockManager(cfg.num_blocks, cfg.block_size)
|
| 29 |
+
sch = Scheduler(cfg, bm)
|
| 30 |
+
s = _seq([1, 2, 3, 4, 5]) # 5 tokens, fits in budget=8
|
| 31 |
+
sch.add(s)
|
| 32 |
+
out = sch.schedule()
|
| 33 |
+
assert len(out.scheduled) == 1
|
| 34 |
+
assert out.scheduled[0].num_tokens == 5
|
| 35 |
+
assert out.scheduled[0].is_prefill
|
| 36 |
+
assert s in sch.running
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def test_chunked_prefill_splits_long_prompt_across_steps():
|
| 40 |
+
cfg = _engine_cfg(max_num_batched_tokens=4)
|
| 41 |
+
bm = BlockManager(cfg.num_blocks, cfg.block_size)
|
| 42 |
+
sch = Scheduler(cfg, bm)
|
| 43 |
+
s = _seq([1, 2, 3, 4, 5, 6, 7, 8, 9]) # 9 tokens vs budget=4
|
| 44 |
+
sch.add(s)
|
| 45 |
+
out1 = sch.schedule()
|
| 46 |
+
assert out1.scheduled[0].num_tokens == 4
|
| 47 |
+
assert s.status == SequenceStatus.PREFILLING
|
| 48 |
+
# Engine would update num_computed_tokens after model fwd; simulate:
|
| 49 |
+
s.num_computed_tokens += 4
|
| 50 |
+
out2 = sch.schedule()
|
| 51 |
+
assert out2.scheduled[0].num_tokens == 4
|
| 52 |
+
s.num_computed_tokens += 4
|
| 53 |
+
out3 = sch.schedule()
|
| 54 |
+
# Last chunk: 1 token left → fills, transitions to RUNNING.
|
| 55 |
+
assert out3.scheduled[0].num_tokens == 1
|
| 56 |
+
s.num_computed_tokens += 1
|
| 57 |
+
assert s in sch.running
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def test_decodes_interleave_with_prefills():
|
| 61 |
+
cfg = _engine_cfg(max_num_batched_tokens=6)
|
| 62 |
+
bm = BlockManager(cfg.num_blocks, cfg.block_size)
|
| 63 |
+
sch = Scheduler(cfg, bm)
|
| 64 |
+
|
| 65 |
+
# Get a sequence fully into RUNNING state.
|
| 66 |
+
runner = _seq([1, 2, 3, 4, 5])
|
| 67 |
+
sch.add(runner)
|
| 68 |
+
out0 = sch.schedule()
|
| 69 |
+
assert out0.scheduled and out0.scheduled[0].num_tokens == 5
|
| 70 |
+
# Simulate model forward.
|
| 71 |
+
runner.num_computed_tokens = runner.prompt_len
|
| 72 |
+
assert runner.status == SequenceStatus.RUNNING
|
| 73 |
+
|
| 74 |
+
# New waiting seq.
|
| 75 |
+
waiter = _seq([100, 101, 102])
|
| 76 |
+
sch.add(waiter)
|
| 77 |
+
|
| 78 |
+
out = sch.schedule()
|
| 79 |
+
kinds = [(it.is_prefill, it.num_tokens, it.seq.seq_id) for it in out.scheduled]
|
| 80 |
+
# runner decodes 1 token, waiter prefills 3 — all fit in budget=6.
|
| 81 |
+
assert any(not it.is_prefill and it.num_tokens == 1 and it.seq is runner for it in out.scheduled)
|
| 82 |
+
assert any(it.is_prefill and it.num_tokens == 3 and it.seq is waiter for it in out.scheduled)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def test_preemption_triggers_when_blocks_exhaust():
|
| 86 |
+
"""When a decoding sequence needs a new block but the pool is dry, the
|
| 87 |
+
scheduler preempts the youngest running seq (here, only itself) and
|
| 88 |
+
re-enqueues it. schedule() must not crash."""
|
| 89 |
+
cfg = _engine_cfg(num_blocks=2, block_size=4, max_num_batched_tokens=16)
|
| 90 |
+
bm = BlockManager(cfg.num_blocks, cfg.block_size)
|
| 91 |
+
sch = Scheduler(cfg, bm)
|
| 92 |
+
s1 = _seq([1, 2, 3, 4, 5, 6, 7]) # 2 blocks consumed exactly on prompt
|
| 93 |
+
sch.add(s1)
|
| 94 |
+
sch.schedule()
|
| 95 |
+
s1.num_computed_tokens = s1.prompt_len
|
| 96 |
+
|
| 97 |
+
# Push s1 to the brink: pretend it has decoded enough to fill block 2.
|
| 98 |
+
s1.output_token_ids.extend([99] * (8 - 7)) # total_len = 8, fits in 2 blocks
|
| 99 |
+
# Next decode (position 8) would require a 3rd block; only 2 exist.
|
| 100 |
+
out = sch.schedule()
|
| 101 |
+
# s1 should have been preempted (and may then be re-admitted in the same
|
| 102 |
+
# step via prefix cache — what matters is the preempt event fired).
|
| 103 |
+
assert s1.seq_id in out.preempted
|
tiny_vllm/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tiny vLLM — a minimal continuous-batching engine.
|
| 2 |
+
|
| 3 |
+
Educational reimplementation of the core vLLM/SGLang ideas:
|
| 4 |
+
paged KV cache, prefix caching, continuous batching with chunked prefill,
|
| 5 |
+
and SSE streaming over a thin HTTP layer.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
# Lazy re-exports: importing this package should not pull in torch, so the
|
| 9 |
+
# lightweight block_manager/scheduler can be unit-tested without it.
|
| 10 |
+
|
| 11 |
+
from .config import EngineConfig, SamplingParams
|
| 12 |
+
from .request import Request, Sequence, SequenceStatus
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"EngineConfig",
|
| 16 |
+
"SamplingParams",
|
| 17 |
+
"LLMEngine",
|
| 18 |
+
"Request",
|
| 19 |
+
"Sequence",
|
| 20 |
+
"SequenceStatus",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def __getattr__(name: str):
|
| 25 |
+
if name == "LLMEngine":
|
| 26 |
+
from .engine import LLMEngine
|
| 27 |
+
return LLMEngine
|
| 28 |
+
raise AttributeError(name)
|
tiny_vllm/block_manager.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Paged KV-cache block manager with hash-based automatic prefix caching.
|
| 2 |
+
|
| 3 |
+
Concepts (matching vLLM / SGLang terminology):
|
| 4 |
+
|
| 5 |
+
Physical block: a fixed-size slot in the KV-cache pool that holds the K and V
|
| 6 |
+
tensors for ``block_size`` consecutive tokens of one sequence.
|
| 7 |
+
|
| 8 |
+
Block table: per-sequence list of physical block ids that holds the
|
| 9 |
+
sequence's KV in logical order. Position ``p`` of the
|
| 10 |
+
sequence lives in physical block ``block_table[p // B]`` at
|
| 11 |
+
slot ``p % B``.
|
| 12 |
+
|
| 13 |
+
Prefix cache: a content-addressed lookup from
|
| 14 |
+
hash(prev_block_hash, tuple_of_token_ids_in_block)
|
| 15 |
+
to a physical block id. When two sequences share a prefix
|
| 16 |
+
that aligns to a block boundary, the second sequence can
|
| 17 |
+
point its block_table at the cached blocks instead of
|
| 18 |
+
recomputing KV, and the scheduler can skip those tokens.
|
| 19 |
+
|
| 20 |
+
The "chained" hash means two prefixes match iff they are identical from
|
| 21 |
+
position 0 — exactly the property we need for prefix sharing.
|
| 22 |
+
|
| 23 |
+
This manager is allocation-only: it does NOT store the KV tensors. The
|
| 24 |
+
ModelRunner owns the actual ``[num_blocks, ...]`` tensors and consults the
|
| 25 |
+
block tables here to know where to write/read KV.
|
| 26 |
+
"""
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
|
| 29 |
+
from collections import deque
|
| 30 |
+
from dataclasses import dataclass
|
| 31 |
+
from typing import Optional
|
| 32 |
+
|
| 33 |
+
from .request import Sequence
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class Block:
|
| 38 |
+
block_id: int
|
| 39 |
+
ref_count: int = 0
|
| 40 |
+
hash_key: Optional[int] = None # set when the block is full and registered
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class BlockManager:
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
num_blocks: int,
|
| 47 |
+
block_size: int,
|
| 48 |
+
enable_prefix_caching: bool = True,
|
| 49 |
+
) -> None:
|
| 50 |
+
self.num_blocks = num_blocks
|
| 51 |
+
self.block_size = block_size
|
| 52 |
+
self.enable_prefix_caching = enable_prefix_caching
|
| 53 |
+
|
| 54 |
+
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
|
| 55 |
+
# Two-tier free list: ephemeral (no hash) reused first, then cached
|
| 56 |
+
# (preserved as long as we have ephemeral capacity).
|
| 57 |
+
self._free_uncached: deque[int] = deque(range(num_blocks))
|
| 58 |
+
self._free_cached: deque[int] = deque()
|
| 59 |
+
self._cache: dict[int, int] = {} # hash → block_id
|
| 60 |
+
|
| 61 |
+
# Stats (visible via events).
|
| 62 |
+
self.prefix_cache_hits = 0
|
| 63 |
+
self.prefix_cache_lookups = 0
|
| 64 |
+
|
| 65 |
+
# ---- introspection --------------------------------------------------
|
| 66 |
+
|
| 67 |
+
@property
|
| 68 |
+
def num_free_blocks(self) -> int:
|
| 69 |
+
return len(self._free_uncached) + len(self._free_cached)
|
| 70 |
+
|
| 71 |
+
@property
|
| 72 |
+
def num_used_blocks(self) -> int:
|
| 73 |
+
return self.num_blocks - self.num_free_blocks
|
| 74 |
+
|
| 75 |
+
def snapshot(self) -> dict:
|
| 76 |
+
"""Cheap dict for the event stream / UI."""
|
| 77 |
+
return {
|
| 78 |
+
"num_blocks": self.num_blocks,
|
| 79 |
+
"block_size": self.block_size,
|
| 80 |
+
"num_free_blocks": self.num_free_blocks,
|
| 81 |
+
"num_cached_entries": len(self._cache),
|
| 82 |
+
"prefix_cache_hits": self.prefix_cache_hits,
|
| 83 |
+
"prefix_cache_lookups": self.prefix_cache_lookups,
|
| 84 |
+
"ref_counts": [b.ref_count for b in self.blocks],
|
| 85 |
+
"hashed": [b.hash_key is not None for b in self.blocks],
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
# ---- low-level pool ops --------------------------------------------
|
| 89 |
+
|
| 90 |
+
def _block_hash(self, prev_hash: Optional[int], token_ids: tuple[int, ...]) -> int:
|
| 91 |
+
# Python's hash() is randomized per process but that's fine: the cache
|
| 92 |
+
# only lives for the engine's lifetime.
|
| 93 |
+
return hash((prev_hash, token_ids))
|
| 94 |
+
|
| 95 |
+
def _take_free_block(self) -> int:
|
| 96 |
+
if self._free_uncached:
|
| 97 |
+
bid = self._free_uncached.popleft()
|
| 98 |
+
elif self._free_cached:
|
| 99 |
+
bid = self._free_cached.popleft()
|
| 100 |
+
# Evict its cache entry — we're about to repurpose it.
|
| 101 |
+
blk = self.blocks[bid]
|
| 102 |
+
if blk.hash_key is not None:
|
| 103 |
+
self._cache.pop(blk.hash_key, None)
|
| 104 |
+
blk.hash_key = None
|
| 105 |
+
else:
|
| 106 |
+
raise RuntimeError("BlockManager out of free blocks")
|
| 107 |
+
blk = self.blocks[bid]
|
| 108 |
+
blk.ref_count = 1
|
| 109 |
+
return bid
|
| 110 |
+
|
| 111 |
+
def _share(self, block_id: int) -> None:
|
| 112 |
+
blk = self.blocks[block_id]
|
| 113 |
+
if blk.ref_count == 0:
|
| 114 |
+
# Was sitting in the cached free list; pull it out.
|
| 115 |
+
try:
|
| 116 |
+
self._free_cached.remove(block_id)
|
| 117 |
+
except ValueError:
|
| 118 |
+
pass
|
| 119 |
+
blk.ref_count += 1
|
| 120 |
+
|
| 121 |
+
def _release(self, block_id: int) -> None:
|
| 122 |
+
blk = self.blocks[block_id]
|
| 123 |
+
blk.ref_count -= 1
|
| 124 |
+
assert blk.ref_count >= 0, f"block {block_id} refcount went negative"
|
| 125 |
+
if blk.ref_count == 0:
|
| 126 |
+
if blk.hash_key is not None and self.enable_prefix_caching:
|
| 127 |
+
self._free_cached.append(block_id)
|
| 128 |
+
else:
|
| 129 |
+
self._free_uncached.append(block_id)
|
| 130 |
+
|
| 131 |
+
def _register(self, block_id: int, hash_key: int) -> None:
|
| 132 |
+
if not self.enable_prefix_caching:
|
| 133 |
+
return
|
| 134 |
+
if hash_key in self._cache:
|
| 135 |
+
# Two sequences independently produced the same content for
|
| 136 |
+
# different physical blocks. Keep the older one; this one becomes
|
| 137 |
+
# ephemeral so it gets reclaimed first.
|
| 138 |
+
return
|
| 139 |
+
self.blocks[block_id].hash_key = hash_key
|
| 140 |
+
self._cache[hash_key] = block_id
|
| 141 |
+
|
| 142 |
+
# ---- per-sequence allocation ---------------------------------------
|
| 143 |
+
|
| 144 |
+
def num_blocks_needed_for(self, num_tokens: int) -> int:
|
| 145 |
+
return (num_tokens + self.block_size - 1) // self.block_size
|
| 146 |
+
|
| 147 |
+
def can_allocate_initial(self, seq: Sequence) -> tuple[bool, int]:
|
| 148 |
+
"""Worst-case allocation check for the prompt of `seq`, ignoring prefix
|
| 149 |
+
cache hits. Returns (ok, num_new_blocks_needed)."""
|
| 150 |
+
need = self.num_blocks_needed_for(seq.prompt_len)
|
| 151 |
+
return self.num_free_blocks >= need, need
|
| 152 |
+
|
| 153 |
+
def admit(self, seq: Sequence) -> None:
|
| 154 |
+
"""Set up `seq` in the cache.
|
| 155 |
+
|
| 156 |
+
Walks the prompt block-by-block. For each full block of *prompt* tokens
|
| 157 |
+
we already know, check the prefix cache: hit → share; miss → allocate
|
| 158 |
+
fresh and register the hash now (we know the tokens already).
|
| 159 |
+
|
| 160 |
+
The trailing partial block (if any) is allocated fresh and left
|
| 161 |
+
un-hashed; it will be hashed by `finalize_step` once it fills up.
|
| 162 |
+
"""
|
| 163 |
+
assert not seq.block_table, "admit called on an already-admitted sequence"
|
| 164 |
+
prev_hash: Optional[int] = None
|
| 165 |
+
cached_tokens = 0
|
| 166 |
+
prompt = seq.prompt_token_ids
|
| 167 |
+
B = self.block_size
|
| 168 |
+
num_full = seq.prompt_len // B
|
| 169 |
+
|
| 170 |
+
# IMPORTANT: never let prefix cache cover the entire prompt — we need
|
| 171 |
+
# at least one token to forward through the model to get logits for
|
| 172 |
+
# the first sampled token. If the full prompt block-aligns AND every
|
| 173 |
+
# block is cached, drop the last cached block.
|
| 174 |
+
cap_full = num_full
|
| 175 |
+
if seq.prompt_len % B == 0:
|
| 176 |
+
cap_full = max(0, num_full - 1)
|
| 177 |
+
|
| 178 |
+
for i in range(num_full):
|
| 179 |
+
tokens = tuple(prompt[i * B : (i + 1) * B])
|
| 180 |
+
h = self._block_hash(prev_hash, tokens)
|
| 181 |
+
self.prefix_cache_lookups += 1
|
| 182 |
+
if self.enable_prefix_caching and h in self._cache and i < cap_full:
|
| 183 |
+
# Cache hit.
|
| 184 |
+
self.prefix_cache_hits += 1
|
| 185 |
+
bid = self._cache[h]
|
| 186 |
+
self._share(bid)
|
| 187 |
+
seq.block_table.append(bid)
|
| 188 |
+
cached_tokens += B
|
| 189 |
+
prev_hash = h
|
| 190 |
+
else:
|
| 191 |
+
# Miss: allocate, and since the block content is fully known
|
| 192 |
+
# (prompt tokens), register its hash right away so the next
|
| 193 |
+
# request with this prefix can hit.
|
| 194 |
+
bid = self._take_free_block()
|
| 195 |
+
self._register(bid, h)
|
| 196 |
+
seq.block_table.append(bid)
|
| 197 |
+
prev_hash = h
|
| 198 |
+
|
| 199 |
+
# Trailing partial block, if any.
|
| 200 |
+
if seq.prompt_len % B != 0:
|
| 201 |
+
bid = self._take_free_block()
|
| 202 |
+
seq.block_table.append(bid)
|
| 203 |
+
|
| 204 |
+
seq.num_computed_tokens = cached_tokens
|
| 205 |
+
seq.num_cached_prefix_tokens = cached_tokens
|
| 206 |
+
|
| 207 |
+
def append_slot(self, seq: Sequence) -> Optional[int]:
|
| 208 |
+
"""Ensure `seq` has a slot for one more token (decode path).
|
| 209 |
+
|
| 210 |
+
Returns the block_id that was newly allocated, or None if existing
|
| 211 |
+
capacity already covered the new token. Raises if no block available.
|
| 212 |
+
"""
|
| 213 |
+
new_position = seq.total_len # 0-indexed slot we are about to write
|
| 214 |
+
needed_blocks = self.num_blocks_needed_for(new_position + 1)
|
| 215 |
+
if needed_blocks <= len(seq.block_table):
|
| 216 |
+
return None
|
| 217 |
+
if self.num_free_blocks == 0:
|
| 218 |
+
raise RuntimeError("out of blocks")
|
| 219 |
+
bid = self._take_free_block()
|
| 220 |
+
seq.block_table.append(bid)
|
| 221 |
+
return bid
|
| 222 |
+
|
| 223 |
+
def ensure_blocks_for_chunk(self, seq: Sequence, chunk_tokens: int) -> int:
|
| 224 |
+
"""Prefill path: make sure `seq.block_table` covers
|
| 225 |
+
`seq.num_computed_tokens + chunk_tokens` tokens.
|
| 226 |
+
|
| 227 |
+
Returns number of newly-allocated blocks.
|
| 228 |
+
"""
|
| 229 |
+
target = seq.num_computed_tokens + chunk_tokens
|
| 230 |
+
needed = self.num_blocks_needed_for(target)
|
| 231 |
+
new_alloc = 0
|
| 232 |
+
while len(seq.block_table) < needed:
|
| 233 |
+
bid = self._take_free_block()
|
| 234 |
+
seq.block_table.append(bid)
|
| 235 |
+
new_alloc += 1
|
| 236 |
+
return new_alloc
|
| 237 |
+
|
| 238 |
+
def free(self, seq: Sequence) -> None:
|
| 239 |
+
for bid in seq.block_table:
|
| 240 |
+
self._release(bid)
|
| 241 |
+
seq.block_table.clear()
|
| 242 |
+
|
| 243 |
+
# ---- post-step bookkeeping -----------------------------------------
|
| 244 |
+
|
| 245 |
+
def register_filled_blocks(self, seq: Sequence, prev_computed: int) -> None:
|
| 246 |
+
"""After a forward pass, hash & register any blocks that just became
|
| 247 |
+
full so future requests can prefix-cache them."""
|
| 248 |
+
if not self.enable_prefix_caching:
|
| 249 |
+
return
|
| 250 |
+
B = self.block_size
|
| 251 |
+
# Re-chain hashes from the start so we always have prev_hash correct.
|
| 252 |
+
prev_hash: Optional[int] = None
|
| 253 |
+
for i in range(seq.num_computed_tokens // B):
|
| 254 |
+
bid = seq.block_table[i]
|
| 255 |
+
blk = self.blocks[bid]
|
| 256 |
+
if blk.hash_key is not None:
|
| 257 |
+
prev_hash = blk.hash_key
|
| 258 |
+
continue
|
| 259 |
+
# This block became full in this step (or earlier but unhashed).
|
| 260 |
+
if (i + 1) * B > seq.num_computed_tokens:
|
| 261 |
+
break # not actually full yet — defensive
|
| 262 |
+
tokens = tuple(seq.get_token(i * B + j) for j in range(B))
|
| 263 |
+
h = self._block_hash(prev_hash, tokens)
|
| 264 |
+
self._register(bid, h)
|
| 265 |
+
prev_hash = h
|
tiny_vllm/config.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class EngineConfig:
|
| 9 |
+
# Model
|
| 10 |
+
model: str = "Qwen/Qwen2.5-0.5B-Instruct"
|
| 11 |
+
dtype: str = "float32" # "float32" on CPU; "float16"/"bfloat16" on GPU
|
| 12 |
+
device: str = "cpu" # "cpu" or "cuda"
|
| 13 |
+
trust_remote_code: bool = False
|
| 14 |
+
|
| 15 |
+
# Paged KV cache
|
| 16 |
+
block_size: int = 16 # tokens per physical block
|
| 17 |
+
num_blocks: int = 512 # total physical blocks in the pool
|
| 18 |
+
enable_prefix_caching: bool = True
|
| 19 |
+
|
| 20 |
+
# Scheduler
|
| 21 |
+
max_num_seqs: int = 16 # max sequences in a batch
|
| 22 |
+
max_num_batched_tokens: int = 512 # total tokens processed per step
|
| 23 |
+
max_model_len: int = 2048 # upper bound on prompt + generated tokens
|
| 24 |
+
|
| 25 |
+
# Logging / events
|
| 26 |
+
emit_events: bool = True # produce engine events for the UI
|
| 27 |
+
event_buffer: int = 256
|
| 28 |
+
|
| 29 |
+
def __post_init__(self) -> None:
|
| 30 |
+
if self.max_num_batched_tokens < self.block_size:
|
| 31 |
+
raise ValueError(
|
| 32 |
+
"max_num_batched_tokens must be >= block_size "
|
| 33 |
+
f"({self.max_num_batched_tokens} < {self.block_size})"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class SamplingParams:
|
| 39 |
+
max_tokens: int = 64
|
| 40 |
+
temperature: float = 1.0
|
| 41 |
+
top_p: float = 1.0
|
| 42 |
+
top_k: int = -1 # -1 disables top-k
|
| 43 |
+
stop_token_ids: list[int] = field(default_factory=list)
|
| 44 |
+
seed: Optional[int] = None
|
| 45 |
+
ignore_eos: bool = False
|
| 46 |
+
|
| 47 |
+
@property
|
| 48 |
+
def is_greedy(self) -> bool:
|
| 49 |
+
return self.temperature <= 0.0
|
tiny_vllm/engine.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLMEngine: orchestrates scheduler + block manager + model runner + sampler.
|
| 2 |
+
|
| 3 |
+
Public surface:
|
| 4 |
+
|
| 5 |
+
engine = LLMEngine(EngineConfig(...))
|
| 6 |
+
await engine.startup()
|
| 7 |
+
rid = engine.add_request(prompt_text, SamplingParams(...))
|
| 8 |
+
async for delta in engine.stream(rid):
|
| 9 |
+
...
|
| 10 |
+
|
| 11 |
+
A single background task (`_run_loop`) drives the model. Per-request output
|
| 12 |
+
goes through asyncio queues so the HTTP layer can stream incrementally. A
|
| 13 |
+
second pub/sub channel emits engine-state snapshots for the visualization UI.
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import asyncio
|
| 18 |
+
import itertools
|
| 19 |
+
import time
|
| 20 |
+
import uuid
|
| 21 |
+
from collections import deque
|
| 22 |
+
from dataclasses import dataclass, field
|
| 23 |
+
from typing import AsyncIterator, Optional
|
| 24 |
+
|
| 25 |
+
from .block_manager import BlockManager
|
| 26 |
+
from .config import EngineConfig, SamplingParams
|
| 27 |
+
from .model_runner import ModelRunner
|
| 28 |
+
from .request import Sequence, SequenceStatus
|
| 29 |
+
from .sampler import Sampler
|
| 30 |
+
from .scheduler import Scheduler
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class StreamItem:
|
| 35 |
+
request_id: str
|
| 36 |
+
new_text: str
|
| 37 |
+
new_token_ids: list[int]
|
| 38 |
+
finished: bool
|
| 39 |
+
finish_reason: Optional[str] = None
|
| 40 |
+
cumulative_text: str = ""
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
|
| 44 |
+
class EngineEvent:
|
| 45 |
+
step: int
|
| 46 |
+
timestamp: float
|
| 47 |
+
type: str
|
| 48 |
+
payload: dict = field(default_factory=dict)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class LLMEngine:
|
| 52 |
+
def __init__(self, config: EngineConfig) -> None:
|
| 53 |
+
self.config = config
|
| 54 |
+
self.model_runner: Optional[ModelRunner] = None
|
| 55 |
+
self.block_manager: Optional[BlockManager] = None
|
| 56 |
+
self.scheduler: Optional[Scheduler] = None
|
| 57 |
+
self.sampler: Optional[Sampler] = None
|
| 58 |
+
|
| 59 |
+
# request_id → asyncio.Queue[StreamItem]
|
| 60 |
+
self._output_queues: dict[str, asyncio.Queue[StreamItem]] = {}
|
| 61 |
+
# request_id → Sequence (for inspection / abort)
|
| 62 |
+
self._sequences: dict[str, Sequence] = {}
|
| 63 |
+
# tracker for incremental detokenization
|
| 64 |
+
self._prev_text_len: dict[str, int] = {}
|
| 65 |
+
# event subscribers
|
| 66 |
+
self._event_subscribers: list[asyncio.Queue[EngineEvent]] = []
|
| 67 |
+
# control
|
| 68 |
+
self._stop = asyncio.Event()
|
| 69 |
+
self._step_idx = 0
|
| 70 |
+
self._run_task: Optional[asyncio.Task] = None
|
| 71 |
+
self._wake = asyncio.Event()
|
| 72 |
+
|
| 73 |
+
# ---- lifecycle ------------------------------------------------------
|
| 74 |
+
|
| 75 |
+
async def startup(self) -> None:
|
| 76 |
+
# Heavy: model load happens in a worker thread so we don't block the loop.
|
| 77 |
+
loop = asyncio.get_running_loop()
|
| 78 |
+
|
| 79 |
+
def _build() -> ModelRunner:
|
| 80 |
+
return ModelRunner(self.config)
|
| 81 |
+
|
| 82 |
+
self.model_runner = await loop.run_in_executor(None, _build)
|
| 83 |
+
self.block_manager = BlockManager(
|
| 84 |
+
num_blocks=self.config.num_blocks,
|
| 85 |
+
block_size=self.config.block_size,
|
| 86 |
+
enable_prefix_caching=self.config.enable_prefix_caching,
|
| 87 |
+
)
|
| 88 |
+
self.scheduler = Scheduler(self.config, self.block_manager)
|
| 89 |
+
self.sampler = Sampler(self.model_runner.device)
|
| 90 |
+
self._run_task = asyncio.create_task(self._run_loop())
|
| 91 |
+
|
| 92 |
+
async def shutdown(self) -> None:
|
| 93 |
+
self._stop.set()
|
| 94 |
+
self._wake.set()
|
| 95 |
+
if self._run_task is not None:
|
| 96 |
+
try:
|
| 97 |
+
await asyncio.wait_for(self._run_task, timeout=5)
|
| 98 |
+
except asyncio.TimeoutError:
|
| 99 |
+
self._run_task.cancel()
|
| 100 |
+
|
| 101 |
+
# ---- request submission --------------------------------------------
|
| 102 |
+
|
| 103 |
+
def add_request(
|
| 104 |
+
self,
|
| 105 |
+
prompt: str | list[int],
|
| 106 |
+
sampling_params: SamplingParams,
|
| 107 |
+
request_id: Optional[str] = None,
|
| 108 |
+
) -> str:
|
| 109 |
+
if self.model_runner is None:
|
| 110 |
+
raise RuntimeError("engine not started")
|
| 111 |
+
if isinstance(prompt, str):
|
| 112 |
+
token_ids = self.model_runner.encode(prompt)
|
| 113 |
+
else:
|
| 114 |
+
token_ids = list(prompt)
|
| 115 |
+
if not token_ids:
|
| 116 |
+
raise ValueError("empty prompt")
|
| 117 |
+
if len(token_ids) >= self.config.max_model_len:
|
| 118 |
+
raise ValueError(
|
| 119 |
+
f"prompt length {len(token_ids)} >= max_model_len {self.config.max_model_len}"
|
| 120 |
+
)
|
| 121 |
+
rid = request_id or uuid.uuid4().hex
|
| 122 |
+
seq = Sequence(
|
| 123 |
+
prompt_token_ids=token_ids,
|
| 124 |
+
sampling_params=sampling_params,
|
| 125 |
+
request_id=rid,
|
| 126 |
+
)
|
| 127 |
+
self._sequences[rid] = seq
|
| 128 |
+
self._output_queues[rid] = asyncio.Queue()
|
| 129 |
+
self._prev_text_len[rid] = 0
|
| 130 |
+
assert self.scheduler is not None
|
| 131 |
+
self.scheduler.add(seq)
|
| 132 |
+
self._wake.set()
|
| 133 |
+
return rid
|
| 134 |
+
|
| 135 |
+
def abort(self, request_id: str) -> bool:
|
| 136 |
+
seq = self._sequences.get(request_id)
|
| 137 |
+
if seq is None:
|
| 138 |
+
return False
|
| 139 |
+
assert self.scheduler is not None
|
| 140 |
+
ok = self.scheduler.abort(seq.seq_id)
|
| 141 |
+
if ok:
|
| 142 |
+
self._close_request(request_id, finish_reason="abort")
|
| 143 |
+
return ok
|
| 144 |
+
|
| 145 |
+
async def stream(self, request_id: str) -> AsyncIterator[StreamItem]:
|
| 146 |
+
q = self._output_queues.get(request_id)
|
| 147 |
+
if q is None:
|
| 148 |
+
raise KeyError(request_id)
|
| 149 |
+
while True:
|
| 150 |
+
item = await q.get()
|
| 151 |
+
yield item
|
| 152 |
+
if item.finished:
|
| 153 |
+
break
|
| 154 |
+
|
| 155 |
+
# ---- event subscriptions -------------------------------------------
|
| 156 |
+
|
| 157 |
+
def subscribe_events(self) -> asyncio.Queue[EngineEvent]:
|
| 158 |
+
q: asyncio.Queue[EngineEvent] = asyncio.Queue(maxsize=self.config.event_buffer)
|
| 159 |
+
self._event_subscribers.append(q)
|
| 160 |
+
return q
|
| 161 |
+
|
| 162 |
+
def unsubscribe_events(self, q: asyncio.Queue[EngineEvent]) -> None:
|
| 163 |
+
try:
|
| 164 |
+
self._event_subscribers.remove(q)
|
| 165 |
+
except ValueError:
|
| 166 |
+
pass
|
| 167 |
+
|
| 168 |
+
def _emit(self, event_type: str, payload: dict) -> None:
|
| 169 |
+
if not self.config.emit_events or not self._event_subscribers:
|
| 170 |
+
return
|
| 171 |
+
ev = EngineEvent(
|
| 172 |
+
step=self._step_idx,
|
| 173 |
+
timestamp=time.monotonic(),
|
| 174 |
+
type=event_type,
|
| 175 |
+
payload=payload,
|
| 176 |
+
)
|
| 177 |
+
for q in list(self._event_subscribers):
|
| 178 |
+
try:
|
| 179 |
+
q.put_nowait(ev)
|
| 180 |
+
except asyncio.QueueFull:
|
| 181 |
+
# Drop oldest, push new.
|
| 182 |
+
try:
|
| 183 |
+
q.get_nowait()
|
| 184 |
+
except asyncio.QueueEmpty:
|
| 185 |
+
pass
|
| 186 |
+
try:
|
| 187 |
+
q.put_nowait(ev)
|
| 188 |
+
except asyncio.QueueFull:
|
| 189 |
+
pass
|
| 190 |
+
|
| 191 |
+
# ---- inspection ----------------------------------------------------
|
| 192 |
+
|
| 193 |
+
def snapshot(self) -> dict:
|
| 194 |
+
assert self.block_manager is not None and self.scheduler is not None
|
| 195 |
+
def seq_view(s: Sequence) -> dict:
|
| 196 |
+
return {
|
| 197 |
+
"seq_id": s.seq_id,
|
| 198 |
+
"request_id": s.request_id,
|
| 199 |
+
"status": s.status.value,
|
| 200 |
+
"prompt_len": s.prompt_len,
|
| 201 |
+
"num_generated": len(s.output_token_ids),
|
| 202 |
+
"num_computed_tokens": s.num_computed_tokens,
|
| 203 |
+
"num_cached_prefix_tokens": s.num_cached_prefix_tokens,
|
| 204 |
+
"block_table": list(s.block_table),
|
| 205 |
+
}
|
| 206 |
+
return {
|
| 207 |
+
"step": self._step_idx,
|
| 208 |
+
"block_pool": self.block_manager.snapshot(),
|
| 209 |
+
"waiting": [seq_view(s) for s in self.scheduler.waiting],
|
| 210 |
+
"running": [seq_view(s) for s in self.scheduler.running],
|
| 211 |
+
"config": {
|
| 212 |
+
"model": self.config.model,
|
| 213 |
+
"block_size": self.config.block_size,
|
| 214 |
+
"num_blocks": self.config.num_blocks,
|
| 215 |
+
"max_num_seqs": self.config.max_num_seqs,
|
| 216 |
+
"max_num_batched_tokens": self.config.max_num_batched_tokens,
|
| 217 |
+
"prefix_caching": self.config.enable_prefix_caching,
|
| 218 |
+
},
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
# ---- main loop -----------------------------------------------------
|
| 222 |
+
|
| 223 |
+
async def _run_loop(self) -> None:
|
| 224 |
+
assert self.scheduler is not None and self.model_runner is not None
|
| 225 |
+
loop = asyncio.get_running_loop()
|
| 226 |
+
while not self._stop.is_set():
|
| 227 |
+
if not self.scheduler.has_work:
|
| 228 |
+
self._wake.clear()
|
| 229 |
+
try:
|
| 230 |
+
await asyncio.wait_for(self._wake.wait(), timeout=1.0)
|
| 231 |
+
except asyncio.TimeoutError:
|
| 232 |
+
pass
|
| 233 |
+
continue
|
| 234 |
+
|
| 235 |
+
self._step_idx += 1
|
| 236 |
+
t0 = time.monotonic()
|
| 237 |
+
sched = self.scheduler.schedule()
|
| 238 |
+
if sched.is_empty:
|
| 239 |
+
# Nothing got through this step (probably starved on blocks).
|
| 240 |
+
await asyncio.sleep(0.01)
|
| 241 |
+
continue
|
| 242 |
+
|
| 243 |
+
model_input = self.model_runner.prepare_input(sched.scheduled)
|
| 244 |
+
# Run blocking model forward off-thread.
|
| 245 |
+
logits = await loop.run_in_executor(None, self.model_runner.execute, model_input)
|
| 246 |
+
|
| 247 |
+
# Update num_computed_tokens AFTER forward (the K/V is now stored).
|
| 248 |
+
for item in sched.scheduled:
|
| 249 |
+
item.seq.num_computed_tokens += item.num_tokens
|
| 250 |
+
|
| 251 |
+
# Sample only for sequences that have finished prefill (i.e., the
|
| 252 |
+
# last token in their chunk is the *final* prompt token).
|
| 253 |
+
sampling_items = [item for item in sched.scheduled
|
| 254 |
+
if item.seq.num_computed_tokens >= item.seq.prompt_len]
|
| 255 |
+
sampling_indices = [i for i, item in enumerate(sched.scheduled)
|
| 256 |
+
if item.seq.num_computed_tokens >= item.seq.prompt_len]
|
| 257 |
+
|
| 258 |
+
new_tokens: dict[int, int] = {}
|
| 259 |
+
if sampling_items:
|
| 260 |
+
import torch # local; cheap
|
| 261 |
+
sampling_logits = logits.index_select(
|
| 262 |
+
0, torch.tensor(sampling_indices, device=logits.device)
|
| 263 |
+
)
|
| 264 |
+
params = [item.seq.sampling_params for item in sampling_items]
|
| 265 |
+
generators = [
|
| 266 |
+
(torch.Generator(device=logits.device).manual_seed(item.seq.sampling_params.seed)
|
| 267 |
+
if item.seq.sampling_params.seed is not None else None)
|
| 268 |
+
for item in sampling_items
|
| 269 |
+
]
|
| 270 |
+
token_ids = self.sampler.sample(sampling_logits, params, generators)
|
| 271 |
+
for item, tok in zip(sampling_items, token_ids):
|
| 272 |
+
new_tokens[item.seq.seq_id] = tok
|
| 273 |
+
|
| 274 |
+
# Apply new tokens, check stopping, register filled blocks.
|
| 275 |
+
assert self.block_manager is not None
|
| 276 |
+
finished_now: list[Sequence] = []
|
| 277 |
+
for item in sched.scheduled:
|
| 278 |
+
seq = item.seq
|
| 279 |
+
if seq.seq_id in new_tokens:
|
| 280 |
+
tok = new_tokens[seq.seq_id]
|
| 281 |
+
seq.append_output_token(tok)
|
| 282 |
+
# The just-produced token's KV will be written on the NEXT
|
| 283 |
+
# step (when this token is the input). But the new token
|
| 284 |
+
# may complete a block once its KV lands; we hash blocks
|
| 285 |
+
# only after their KV exists, so post-forward in the next
|
| 286 |
+
# step is the right time. Here we register newly-filled
|
| 287 |
+
# blocks based on the just-finalized num_computed_tokens.
|
| 288 |
+
self.block_manager.register_filled_blocks(seq, prev_computed=0)
|
| 289 |
+
|
| 290 |
+
if self._should_stop(seq, tok):
|
| 291 |
+
seq.status = SequenceStatus.FINISHED
|
| 292 |
+
seq.finish_reason = self._stop_reason(seq, tok)
|
| 293 |
+
finished_now.append(seq)
|
| 294 |
+
else:
|
| 295 |
+
# Still in prefill; just register newly filled prompt blocks.
|
| 296 |
+
self.block_manager.register_filled_blocks(seq, prev_computed=0)
|
| 297 |
+
|
| 298 |
+
# Free finished sequences.
|
| 299 |
+
for seq in finished_now:
|
| 300 |
+
if seq in self.scheduler.running:
|
| 301 |
+
self.scheduler.running.remove(seq)
|
| 302 |
+
self.block_manager.free(seq)
|
| 303 |
+
|
| 304 |
+
# Emit outputs to per-request queues.
|
| 305 |
+
for item in sched.scheduled:
|
| 306 |
+
seq = item.seq
|
| 307 |
+
rid = seq.request_id
|
| 308 |
+
if seq.seq_id in new_tokens or seq in finished_now:
|
| 309 |
+
new_text, new_text_len = self.model_runner.detokenize_incremental(
|
| 310 |
+
seq.all_token_ids(), self._prev_text_len.get(rid, 0)
|
| 311 |
+
)
|
| 312 |
+
self._prev_text_len[rid] = new_text_len
|
| 313 |
+
is_done = seq.status == SequenceStatus.FINISHED
|
| 314 |
+
new_toks = [new_tokens[seq.seq_id]] if seq.seq_id in new_tokens else []
|
| 315 |
+
si = StreamItem(
|
| 316 |
+
request_id=rid,
|
| 317 |
+
new_text=new_text,
|
| 318 |
+
new_token_ids=new_toks,
|
| 319 |
+
finished=is_done,
|
| 320 |
+
finish_reason=seq.finish_reason,
|
| 321 |
+
cumulative_text=self.model_runner.tokenizer.decode(
|
| 322 |
+
seq.output_token_ids, skip_special_tokens=True
|
| 323 |
+
),
|
| 324 |
+
)
|
| 325 |
+
q = self._output_queues.get(rid)
|
| 326 |
+
if q is not None:
|
| 327 |
+
await q.put(si)
|
| 328 |
+
if is_done:
|
| 329 |
+
# Clean up.
|
| 330 |
+
self._sequences.pop(rid, None)
|
| 331 |
+
self._prev_text_len.pop(rid, None)
|
| 332 |
+
|
| 333 |
+
# Emit engine events for the UI.
|
| 334 |
+
self._emit("step", {
|
| 335 |
+
"duration_ms": (time.monotonic() - t0) * 1000,
|
| 336 |
+
"num_seqs": len(sched.scheduled),
|
| 337 |
+
"num_tokens": sched.total_tokens,
|
| 338 |
+
"num_prefill_seqs": sum(1 for it in sched.scheduled if it.is_prefill),
|
| 339 |
+
"num_decode_seqs": sum(1 for it in sched.scheduled if not it.is_prefill),
|
| 340 |
+
"preempted": sched.preempted,
|
| 341 |
+
"newly_admitted": sched.newly_admitted,
|
| 342 |
+
"finished": [s.request_id for s in finished_now],
|
| 343 |
+
"snapshot": self.snapshot(),
|
| 344 |
+
})
|
| 345 |
+
|
| 346 |
+
# Yield control between steps so the HTTP layer can ship bytes.
|
| 347 |
+
await asyncio.sleep(0)
|
| 348 |
+
|
| 349 |
+
# ---- helpers -------------------------------------------------------
|
| 350 |
+
|
| 351 |
+
def _should_stop(self, seq: Sequence, last_token: int) -> bool:
|
| 352 |
+
sp = seq.sampling_params
|
| 353 |
+
if len(seq.output_token_ids) >= sp.max_tokens:
|
| 354 |
+
return True
|
| 355 |
+
if not sp.ignore_eos:
|
| 356 |
+
eos = self.model_runner.eos_token_id if self.model_runner else None
|
| 357 |
+
if eos is not None and last_token == eos:
|
| 358 |
+
return True
|
| 359 |
+
if last_token in sp.stop_token_ids:
|
| 360 |
+
return True
|
| 361 |
+
if seq.total_len >= self.config.max_model_len:
|
| 362 |
+
return True
|
| 363 |
+
return False
|
| 364 |
+
|
| 365 |
+
def _stop_reason(self, seq: Sequence, last_token: int) -> str:
|
| 366 |
+
sp = seq.sampling_params
|
| 367 |
+
if len(seq.output_token_ids) >= sp.max_tokens:
|
| 368 |
+
return "length"
|
| 369 |
+
if seq.total_len >= self.config.max_model_len:
|
| 370 |
+
return "length"
|
| 371 |
+
return "stop"
|
| 372 |
+
|
| 373 |
+
def _close_request(self, request_id: str, finish_reason: str) -> None:
|
| 374 |
+
q = self._output_queues.get(request_id)
|
| 375 |
+
if q is None:
|
| 376 |
+
return
|
| 377 |
+
q.put_nowait(StreamItem(
|
| 378 |
+
request_id=request_id,
|
| 379 |
+
new_text="",
|
| 380 |
+
new_token_ids=[],
|
| 381 |
+
finished=True,
|
| 382 |
+
finish_reason=finish_reason,
|
| 383 |
+
))
|
| 384 |
+
self._sequences.pop(request_id, None)
|
| 385 |
+
self._prev_text_len.pop(request_id, None)
|
tiny_vllm/model_runner.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal Qwen2 forward pass that consumes a paged KV cache.
|
| 2 |
+
|
| 3 |
+
We deliberately re-implement Qwen2 from scratch (rather than using the HF
|
| 4 |
+
forward) so the path of K/V tensors through the cache is fully visible.
|
| 5 |
+
Weights are loaded from a HuggingFace checkpoint by matching parameter names.
|
| 6 |
+
|
| 7 |
+
Layout of inputs per step ("varlen" packing):
|
| 8 |
+
|
| 9 |
+
input_ids [T_total] concatenated tokens for all seqs
|
| 10 |
+
positions [T_total] position-in-sequence of each token
|
| 11 |
+
slot_mapping [T_total] where to write new K/V in the cache
|
| 12 |
+
segments list of (q_start, q_end, block_table, k_len, seq_id)
|
| 13 |
+
|
| 14 |
+
For attention, we loop over `segments`: gather each sequence's full K/V from
|
| 15 |
+
its block table, run SDPA, scatter the result back into a flat buffer. All
|
| 16 |
+
other ops (norms, MLP, projections) run on the full packed tensor.
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
|
| 27 |
+
from .config import EngineConfig
|
| 28 |
+
from .paged_kv import PagedKVCache
|
| 29 |
+
from .request import Sequence
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# Qwen2 building blocks
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Qwen2RMSNorm(nn.Module):
|
| 38 |
+
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 41 |
+
self.eps = eps
|
| 42 |
+
|
| 43 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 44 |
+
# x: [..., hidden]
|
| 45 |
+
dtype = x.dtype
|
| 46 |
+
x = x.to(torch.float32)
|
| 47 |
+
var = x.pow(2).mean(-1, keepdim=True)
|
| 48 |
+
x = x * torch.rsqrt(var + self.eps)
|
| 49 |
+
return (self.weight * x).to(dtype)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
half = x.size(-1) // 2
|
| 54 |
+
x1, x2 = x[..., :half], x[..., half:]
|
| 55 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
| 59 |
+
"""x: [T, H, D], cos/sin: [T, D] → returns [T, H, D]."""
|
| 60 |
+
cos = cos.unsqueeze(1)
|
| 61 |
+
sin = sin.unsqueeze(1)
|
| 62 |
+
return (x * cos) + (_rotate_half(x) * sin)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Qwen2MLP(nn.Module):
|
| 66 |
+
def __init__(self, hidden_size: int, intermediate_size: int) -> None:
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 69 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 70 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
| 71 |
+
|
| 72 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 73 |
+
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass
|
| 77 |
+
class AttnSegment:
|
| 78 |
+
"""One sequence's slice of the packed batch."""
|
| 79 |
+
q_start: int # start index in the packed tensor
|
| 80 |
+
q_end: int # exclusive
|
| 81 |
+
block_table: list[int] # KV blocks for this sequence
|
| 82 |
+
k_len: int # total K length (= num_computed_tokens + q_len)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class Qwen2Attention(nn.Module):
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
hidden_size: int,
|
| 89 |
+
num_heads: int,
|
| 90 |
+
num_kv_heads: int,
|
| 91 |
+
head_dim: int,
|
| 92 |
+
layer_idx: int,
|
| 93 |
+
) -> None:
|
| 94 |
+
super().__init__()
|
| 95 |
+
self.num_heads = num_heads
|
| 96 |
+
self.num_kv_heads = num_kv_heads
|
| 97 |
+
self.head_dim = head_dim
|
| 98 |
+
self.layer_idx = layer_idx
|
| 99 |
+
self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=True)
|
| 100 |
+
self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=True)
|
| 101 |
+
self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=True)
|
| 102 |
+
self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)
|
| 103 |
+
self.scale = head_dim ** -0.5
|
| 104 |
+
|
| 105 |
+
def forward(
|
| 106 |
+
self,
|
| 107 |
+
hidden_states: torch.Tensor, # [T, hidden]
|
| 108 |
+
positions: torch.Tensor, # [T] long
|
| 109 |
+
slot_mapping: torch.Tensor, # [T] long
|
| 110 |
+
cos_table: torch.Tensor, # [max_pos, head_dim]
|
| 111 |
+
sin_table: torch.Tensor, # [max_pos, head_dim]
|
| 112 |
+
segments: list[AttnSegment],
|
| 113 |
+
kv_cache: PagedKVCache,
|
| 114 |
+
) -> torch.Tensor:
|
| 115 |
+
T = hidden_states.size(0)
|
| 116 |
+
q = self.q_proj(hidden_states).view(T, self.num_heads, self.head_dim)
|
| 117 |
+
k = self.k_proj(hidden_states).view(T, self.num_kv_heads, self.head_dim)
|
| 118 |
+
v = self.v_proj(hidden_states).view(T, self.num_kv_heads, self.head_dim)
|
| 119 |
+
|
| 120 |
+
cos = cos_table.index_select(0, positions) # [T, head_dim]
|
| 121 |
+
sin = sin_table.index_select(0, positions)
|
| 122 |
+
q = _apply_rope(q, cos, sin)
|
| 123 |
+
k = _apply_rope(k, cos, sin)
|
| 124 |
+
|
| 125 |
+
# Write the NEW K/V into the paged cache before reading it back.
|
| 126 |
+
kv_cache.write(self.layer_idx, k, v, slot_mapping)
|
| 127 |
+
|
| 128 |
+
out = torch.empty_like(q) # [T, num_heads, head_dim]
|
| 129 |
+
rep = self.num_heads // self.num_kv_heads # GQA fan-out
|
| 130 |
+
|
| 131 |
+
for seg in segments:
|
| 132 |
+
q_slice = q[seg.q_start:seg.q_end] # [q_len, H_q, D]
|
| 133 |
+
k_full, v_full = kv_cache.gather(self.layer_idx, seg.block_table, seg.k_len)
|
| 134 |
+
# GQA: expand K/V heads to match Q heads.
|
| 135 |
+
if rep > 1:
|
| 136 |
+
k_full = k_full.repeat_interleave(rep, dim=1)
|
| 137 |
+
v_full = v_full.repeat_interleave(rep, dim=1)
|
| 138 |
+
|
| 139 |
+
q_len = q_slice.size(0)
|
| 140 |
+
k_len = seg.k_len
|
| 141 |
+
num_past = k_len - q_len
|
| 142 |
+
|
| 143 |
+
# Causal mask: Q at logical position (num_past + i) attends to K at
|
| 144 |
+
# positions [0, num_past + i]. True = participate (SDPA convention).
|
| 145 |
+
idx_q = torch.arange(q_len, device=q.device).unsqueeze(1) + num_past
|
| 146 |
+
idx_k = torch.arange(k_len, device=q.device).unsqueeze(0)
|
| 147 |
+
attn_mask = idx_k <= idx_q # [q_len, k_len]
|
| 148 |
+
|
| 149 |
+
# SDPA wants [..., heads, q_len, head_dim]. Reshape and run.
|
| 150 |
+
q_h = q_slice.transpose(0, 1).unsqueeze(0) # [1, H, q_len, D]
|
| 151 |
+
k_h = k_full.transpose(0, 1).unsqueeze(0) # [1, H, k_len, D]
|
| 152 |
+
v_h = v_full.transpose(0, 1).unsqueeze(0)
|
| 153 |
+
attn = F.scaled_dot_product_attention(
|
| 154 |
+
q_h, k_h, v_h,
|
| 155 |
+
attn_mask=attn_mask.unsqueeze(0).unsqueeze(0), # [1,1,q_len,k_len]
|
| 156 |
+
scale=self.scale,
|
| 157 |
+
) # [1, H, q_len, D]
|
| 158 |
+
out[seg.q_start:seg.q_end] = attn.squeeze(0).transpose(0, 1)
|
| 159 |
+
|
| 160 |
+
return self.o_proj(out.reshape(T, self.num_heads * self.head_dim))
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class Qwen2DecoderLayer(nn.Module):
|
| 164 |
+
def __init__(self, cfg: dict, layer_idx: int) -> None:
|
| 165 |
+
super().__init__()
|
| 166 |
+
self.input_layernorm = Qwen2RMSNorm(cfg["hidden_size"], eps=cfg["rms_norm_eps"])
|
| 167 |
+
self.self_attn = Qwen2Attention(
|
| 168 |
+
hidden_size=cfg["hidden_size"],
|
| 169 |
+
num_heads=cfg["num_attention_heads"],
|
| 170 |
+
num_kv_heads=cfg["num_key_value_heads"],
|
| 171 |
+
head_dim=cfg["head_dim"],
|
| 172 |
+
layer_idx=layer_idx,
|
| 173 |
+
)
|
| 174 |
+
self.post_attention_layernorm = Qwen2RMSNorm(cfg["hidden_size"], eps=cfg["rms_norm_eps"])
|
| 175 |
+
self.mlp = Qwen2MLP(cfg["hidden_size"], cfg["intermediate_size"])
|
| 176 |
+
|
| 177 |
+
def forward(self, hidden_states, positions, slot_mapping, cos_table, sin_table, segments, kv_cache):
|
| 178 |
+
residual = hidden_states
|
| 179 |
+
h = self.input_layernorm(hidden_states)
|
| 180 |
+
h = self.self_attn(h, positions, slot_mapping, cos_table, sin_table, segments, kv_cache)
|
| 181 |
+
hidden_states = residual + h
|
| 182 |
+
|
| 183 |
+
residual = hidden_states
|
| 184 |
+
h = self.post_attention_layernorm(hidden_states)
|
| 185 |
+
h = self.mlp(h)
|
| 186 |
+
return residual + h
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class Qwen2Model(nn.Module):
|
| 190 |
+
def __init__(self, cfg: dict) -> None:
|
| 191 |
+
super().__init__()
|
| 192 |
+
self.cfg = cfg
|
| 193 |
+
self.embed_tokens = nn.Embedding(cfg["vocab_size"], cfg["hidden_size"])
|
| 194 |
+
self.layers = nn.ModuleList(
|
| 195 |
+
[Qwen2DecoderLayer(cfg, i) for i in range(cfg["num_hidden_layers"])]
|
| 196 |
+
)
|
| 197 |
+
self.norm = Qwen2RMSNorm(cfg["hidden_size"], eps=cfg["rms_norm_eps"])
|
| 198 |
+
|
| 199 |
+
def forward(self, input_ids, positions, slot_mapping, cos_table, sin_table, segments, kv_cache):
|
| 200 |
+
h = self.embed_tokens(input_ids)
|
| 201 |
+
for layer in self.layers:
|
| 202 |
+
h = layer(h, positions, slot_mapping, cos_table, sin_table, segments, kv_cache)
|
| 203 |
+
return self.norm(h)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class Qwen2ForCausalLM(nn.Module):
|
| 207 |
+
def __init__(self, cfg: dict) -> None:
|
| 208 |
+
super().__init__()
|
| 209 |
+
self.model = Qwen2Model(cfg)
|
| 210 |
+
self.lm_head = nn.Linear(cfg["hidden_size"], cfg["vocab_size"], bias=False)
|
| 211 |
+
self.cfg = cfg
|
| 212 |
+
|
| 213 |
+
def tie_weights(self) -> None:
|
| 214 |
+
self.lm_head.weight = self.model.embed_tokens.weight
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ---------------------------------------------------------------------------
|
| 218 |
+
# ModelRunner: prepares inputs, runs forward, extracts last-token logits.
|
| 219 |
+
# ---------------------------------------------------------------------------
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@dataclass
|
| 223 |
+
class ModelInput:
|
| 224 |
+
input_ids: torch.Tensor
|
| 225 |
+
positions: torch.Tensor
|
| 226 |
+
slot_mapping: torch.Tensor
|
| 227 |
+
segments: list[AttnSegment]
|
| 228 |
+
# Index in the packed batch of the LAST token of each scheduled seq —
|
| 229 |
+
# that's where we'll read logits from for sampling.
|
| 230 |
+
last_token_indices: torch.Tensor
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class ModelRunner:
|
| 234 |
+
def __init__(self, config: EngineConfig) -> None:
|
| 235 |
+
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
|
| 236 |
+
|
| 237 |
+
self.config = config
|
| 238 |
+
self.device = torch.device(config.device)
|
| 239 |
+
self.dtype = {
|
| 240 |
+
"float32": torch.float32,
|
| 241 |
+
"float16": torch.float16,
|
| 242 |
+
"bfloat16": torch.bfloat16,
|
| 243 |
+
}[config.dtype]
|
| 244 |
+
|
| 245 |
+
hf_cfg = AutoConfig.from_pretrained(
|
| 246 |
+
config.model, trust_remote_code=config.trust_remote_code
|
| 247 |
+
)
|
| 248 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 249 |
+
config.model, trust_remote_code=config.trust_remote_code
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
model_type = getattr(hf_cfg, "model_type", "?")
|
| 253 |
+
if model_type not in ("qwen2", "qwen2_moe", "llama"):
|
| 254 |
+
# Llama-style works too because the math is identical; we issue a
|
| 255 |
+
# warning rather than a hard fail.
|
| 256 |
+
print(f"[tiny_vllm] WARNING: model_type={model_type!r}; expected qwen2-like. "
|
| 257 |
+
"Continuing — assuming Llama-compatible config.")
|
| 258 |
+
|
| 259 |
+
head_dim = getattr(hf_cfg, "head_dim", hf_cfg.hidden_size // hf_cfg.num_attention_heads)
|
| 260 |
+
cfg = {
|
| 261 |
+
"vocab_size": hf_cfg.vocab_size,
|
| 262 |
+
"hidden_size": hf_cfg.hidden_size,
|
| 263 |
+
"intermediate_size": hf_cfg.intermediate_size,
|
| 264 |
+
"num_hidden_layers": hf_cfg.num_hidden_layers,
|
| 265 |
+
"num_attention_heads": hf_cfg.num_attention_heads,
|
| 266 |
+
"num_key_value_heads": getattr(hf_cfg, "num_key_value_heads",
|
| 267 |
+
hf_cfg.num_attention_heads),
|
| 268 |
+
"head_dim": head_dim,
|
| 269 |
+
"rms_norm_eps": getattr(hf_cfg, "rms_norm_eps", 1e-6),
|
| 270 |
+
"rope_theta": getattr(hf_cfg, "rope_theta", 10000.0),
|
| 271 |
+
"max_position_embeddings": getattr(hf_cfg, "max_position_embeddings", 4096),
|
| 272 |
+
"tie_word_embeddings": getattr(hf_cfg, "tie_word_embeddings", False),
|
| 273 |
+
}
|
| 274 |
+
self.model_cfg = cfg
|
| 275 |
+
|
| 276 |
+
# Build our own model, then copy HF weights into it.
|
| 277 |
+
model = Qwen2ForCausalLM(cfg).to(self.device, self.dtype)
|
| 278 |
+
hf_model = AutoModelForCausalLM.from_pretrained(
|
| 279 |
+
config.model, torch_dtype=self.dtype,
|
| 280 |
+
trust_remote_code=config.trust_remote_code,
|
| 281 |
+
)
|
| 282 |
+
missing, unexpected = model.load_state_dict(hf_model.state_dict(), strict=False)
|
| 283 |
+
if cfg["tie_word_embeddings"] and "lm_head.weight" in (missing or []):
|
| 284 |
+
model.tie_weights()
|
| 285 |
+
del hf_model
|
| 286 |
+
model.eval()
|
| 287 |
+
for p in model.parameters():
|
| 288 |
+
p.requires_grad_(False)
|
| 289 |
+
self.model = model
|
| 290 |
+
|
| 291 |
+
# Precompute RoPE tables.
|
| 292 |
+
max_pos = min(cfg["max_position_embeddings"], config.max_model_len)
|
| 293 |
+
inv_freq = 1.0 / (
|
| 294 |
+
cfg["rope_theta"]
|
| 295 |
+
** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
|
| 296 |
+
)
|
| 297 |
+
t = torch.arange(max_pos, dtype=torch.float32)
|
| 298 |
+
freqs = torch.outer(t, inv_freq) # [max_pos, head_dim/2]
|
| 299 |
+
emb = torch.cat((freqs, freqs), dim=-1) # [max_pos, head_dim]
|
| 300 |
+
self.cos_table = emb.cos().to(self.device, self.dtype)
|
| 301 |
+
self.sin_table = emb.sin().to(self.device, self.dtype)
|
| 302 |
+
|
| 303 |
+
# Paged KV cache pool.
|
| 304 |
+
self.kv_cache = PagedKVCache(
|
| 305 |
+
num_layers=cfg["num_hidden_layers"],
|
| 306 |
+
num_blocks=config.num_blocks,
|
| 307 |
+
block_size=config.block_size,
|
| 308 |
+
num_kv_heads=cfg["num_key_value_heads"],
|
| 309 |
+
head_dim=head_dim,
|
| 310 |
+
dtype=self.dtype,
|
| 311 |
+
device=self.device,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# ---- input building ------------------------------------------------
|
| 315 |
+
|
| 316 |
+
def prepare_input(self, scheduled) -> ModelInput:
|
| 317 |
+
"""`scheduled` is a list of (Sequence, num_tokens, is_prefill) triples
|
| 318 |
+
from the scheduler."""
|
| 319 |
+
input_ids: list[int] = []
|
| 320 |
+
positions: list[int] = []
|
| 321 |
+
slot_mapping: list[int] = []
|
| 322 |
+
segments: list[AttnSegment] = []
|
| 323 |
+
last_indices: list[int] = []
|
| 324 |
+
|
| 325 |
+
cursor = 0
|
| 326 |
+
B = self.config.block_size
|
| 327 |
+
for item in scheduled:
|
| 328 |
+
seq = item.seq
|
| 329 |
+
n = item.num_tokens
|
| 330 |
+
# Logical token positions this step processes.
|
| 331 |
+
start_pos = seq.num_computed_tokens
|
| 332 |
+
for off in range(n):
|
| 333 |
+
pos = start_pos + off
|
| 334 |
+
input_ids.append(seq.get_token(pos))
|
| 335 |
+
positions.append(pos)
|
| 336 |
+
block_id = seq.block_table[pos // B]
|
| 337 |
+
slot_mapping.append(block_id * B + (pos % B))
|
| 338 |
+
|
| 339 |
+
q_end = cursor + n
|
| 340 |
+
segments.append(AttnSegment(
|
| 341 |
+
q_start=cursor,
|
| 342 |
+
q_end=q_end,
|
| 343 |
+
block_table=list(seq.block_table),
|
| 344 |
+
k_len=start_pos + n,
|
| 345 |
+
))
|
| 346 |
+
last_indices.append(q_end - 1)
|
| 347 |
+
cursor = q_end
|
| 348 |
+
|
| 349 |
+
return ModelInput(
|
| 350 |
+
input_ids=torch.tensor(input_ids, dtype=torch.long, device=self.device),
|
| 351 |
+
positions=torch.tensor(positions, dtype=torch.long, device=self.device),
|
| 352 |
+
slot_mapping=torch.tensor(slot_mapping, dtype=torch.long, device=self.device),
|
| 353 |
+
segments=segments,
|
| 354 |
+
last_token_indices=torch.tensor(last_indices, dtype=torch.long, device=self.device),
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
# ---- forward -------------------------------------------------------
|
| 358 |
+
|
| 359 |
+
@torch.inference_mode()
|
| 360 |
+
def execute(self, model_input: ModelInput) -> torch.Tensor:
|
| 361 |
+
"""Run one forward pass. Returns logits for the LAST token of each
|
| 362 |
+
scheduled sequence: shape [num_seqs, vocab_size]."""
|
| 363 |
+
hidden = self.model.model(
|
| 364 |
+
input_ids=model_input.input_ids,
|
| 365 |
+
positions=model_input.positions,
|
| 366 |
+
slot_mapping=model_input.slot_mapping,
|
| 367 |
+
cos_table=self.cos_table,
|
| 368 |
+
sin_table=self.sin_table,
|
| 369 |
+
segments=model_input.segments,
|
| 370 |
+
kv_cache=self.kv_cache,
|
| 371 |
+
) # [T, hidden]
|
| 372 |
+
last_hidden = hidden.index_select(0, model_input.last_token_indices)
|
| 373 |
+
logits = self.model.lm_head(last_hidden) # [num_seqs, vocab]
|
| 374 |
+
return logits
|
| 375 |
+
|
| 376 |
+
# ---- helpers -------------------------------------------------------
|
| 377 |
+
|
| 378 |
+
@property
|
| 379 |
+
def eos_token_id(self) -> Optional[int]:
|
| 380 |
+
return self.tokenizer.eos_token_id
|
| 381 |
+
|
| 382 |
+
def encode(self, text: str) -> list[int]:
|
| 383 |
+
return self.tokenizer.encode(text, add_special_tokens=False)
|
| 384 |
+
|
| 385 |
+
def decode(self, token_ids: list[int]) -> str:
|
| 386 |
+
return self.tokenizer.decode(token_ids, skip_special_tokens=True)
|
| 387 |
+
|
| 388 |
+
def detokenize_incremental(self, full_ids: list[int], prev_text_len: int) -> tuple[str, int]:
|
| 389 |
+
"""Detokenize the full list, return the new text added since last call
|
| 390 |
+
and the new total length."""
|
| 391 |
+
text = self.tokenizer.decode(full_ids, skip_special_tokens=True)
|
| 392 |
+
return text[prev_text_len:], len(text)
|
tiny_vllm/paged_kv.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""The actual KV tensor pool that the BlockManager indexes into.
|
| 2 |
+
|
| 3 |
+
We store one ``[num_blocks, block_size, num_kv_heads, head_dim]`` tensor per
|
| 4 |
+
layer for K and V. The block_manager owns the *allocation* of block ids; this
|
| 5 |
+
class owns the *bytes*. Reads and writes happen by (block_id, offset).
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PagedKVCache:
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
num_layers: int,
|
| 16 |
+
num_blocks: int,
|
| 17 |
+
block_size: int,
|
| 18 |
+
num_kv_heads: int,
|
| 19 |
+
head_dim: int,
|
| 20 |
+
dtype: torch.dtype,
|
| 21 |
+
device: torch.device,
|
| 22 |
+
) -> None:
|
| 23 |
+
self.num_layers = num_layers
|
| 24 |
+
self.num_blocks = num_blocks
|
| 25 |
+
self.block_size = block_size
|
| 26 |
+
self.num_kv_heads = num_kv_heads
|
| 27 |
+
self.head_dim = head_dim
|
| 28 |
+
self.dtype = dtype
|
| 29 |
+
self.device = device
|
| 30 |
+
shape = (num_blocks, block_size, num_kv_heads, head_dim)
|
| 31 |
+
self.k_cache = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(num_layers)]
|
| 32 |
+
self.v_cache = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(num_layers)]
|
| 33 |
+
|
| 34 |
+
def write(
|
| 35 |
+
self,
|
| 36 |
+
layer_id: int,
|
| 37 |
+
k: torch.Tensor, # [T, num_kv_heads, head_dim]
|
| 38 |
+
v: torch.Tensor, # [T, num_kv_heads, head_dim]
|
| 39 |
+
slot_mapping: torch.Tensor # [T] int64, slot_id = block_id*block_size + offset
|
| 40 |
+
) -> None:
|
| 41 |
+
block_ids = (slot_mapping // self.block_size).long()
|
| 42 |
+
offsets = (slot_mapping % self.block_size).long()
|
| 43 |
+
self.k_cache[layer_id][block_ids, offsets] = k.to(self.dtype)
|
| 44 |
+
self.v_cache[layer_id][block_ids, offsets] = v.to(self.dtype)
|
| 45 |
+
|
| 46 |
+
def gather(
|
| 47 |
+
self,
|
| 48 |
+
layer_id: int,
|
| 49 |
+
block_table: list[int],
|
| 50 |
+
num_tokens: int,
|
| 51 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 52 |
+
"""Return contiguous [num_tokens, num_kv_heads, head_dim] K and V
|
| 53 |
+
for one sequence, by walking its block table."""
|
| 54 |
+
if num_tokens == 0:
|
| 55 |
+
empty = torch.zeros(
|
| 56 |
+
0, self.num_kv_heads, self.head_dim,
|
| 57 |
+
dtype=self.dtype, device=self.device,
|
| 58 |
+
)
|
| 59 |
+
return empty, empty.clone()
|
| 60 |
+
num_full = num_tokens // self.block_size
|
| 61 |
+
tail = num_tokens % self.block_size
|
| 62 |
+
idxs = block_table[:num_full + (1 if tail else 0)]
|
| 63 |
+
idx_tensor = torch.as_tensor(idxs, dtype=torch.long, device=self.device)
|
| 64 |
+
# [P, block_size, H, D]
|
| 65 |
+
k_blocks = self.k_cache[layer_id].index_select(0, idx_tensor)
|
| 66 |
+
v_blocks = self.v_cache[layer_id].index_select(0, idx_tensor)
|
| 67 |
+
# Flatten the first two dims then trim.
|
| 68 |
+
k_flat = k_blocks.reshape(-1, self.num_kv_heads, self.head_dim)
|
| 69 |
+
v_flat = v_blocks.reshape(-1, self.num_kv_heads, self.head_dim)
|
| 70 |
+
return k_flat[:num_tokens], v_flat[:num_tokens]
|
tiny_vllm/request.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import enum
|
| 4 |
+
import itertools
|
| 5 |
+
import time
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from .config import SamplingParams
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SequenceStatus(enum.Enum):
|
| 13 |
+
WAITING = "waiting" # not yet started prefill
|
| 14 |
+
PREFILLING = "prefilling" # chunked prefill in progress
|
| 15 |
+
RUNNING = "running" # in decode loop
|
| 16 |
+
FINISHED = "finished"
|
| 17 |
+
PREEMPTED = "preempted" # evicted; will restart prefill when capacity returns
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
_seq_counter = itertools.count()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _next_seq_id() -> int:
|
| 24 |
+
return next(_seq_counter)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class Sequence:
|
| 29 |
+
"""One in-flight request.
|
| 30 |
+
|
| 31 |
+
The token sequence is `prompt_token_ids + output_token_ids`.
|
| 32 |
+
`num_computed_tokens` tracks how many tokens already have their KV
|
| 33 |
+
materialized in the paged cache. Anything past that boundary is either
|
| 34 |
+
waiting prefill (during PREFILLING) or the next token to sample (RUNNING).
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
prompt_token_ids: list[int]
|
| 38 |
+
sampling_params: SamplingParams
|
| 39 |
+
request_id: str
|
| 40 |
+
arrival_time: float = field(default_factory=time.monotonic)
|
| 41 |
+
seq_id: int = field(default_factory=_next_seq_id)
|
| 42 |
+
|
| 43 |
+
output_token_ids: list[int] = field(default_factory=list)
|
| 44 |
+
status: SequenceStatus = SequenceStatus.WAITING
|
| 45 |
+
|
| 46 |
+
# Paged KV bookkeeping (filled in by the BlockManager).
|
| 47 |
+
block_table: list[int] = field(default_factory=list)
|
| 48 |
+
num_computed_tokens: int = 0 # tokens with KV in the cache
|
| 49 |
+
num_cached_prefix_tokens: int = 0 # tokens served from prefix cache hits
|
| 50 |
+
|
| 51 |
+
# Outputs / streaming
|
| 52 |
+
finish_reason: Optional[str] = None
|
| 53 |
+
|
| 54 |
+
# ---- helpers --------------------------------------------------------
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def prompt_len(self) -> int:
|
| 58 |
+
return len(self.prompt_token_ids)
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def total_len(self) -> int:
|
| 62 |
+
return len(self.prompt_token_ids) + len(self.output_token_ids)
|
| 63 |
+
|
| 64 |
+
def all_token_ids(self) -> list[int]:
|
| 65 |
+
return self.prompt_token_ids + self.output_token_ids
|
| 66 |
+
|
| 67 |
+
def get_token(self, position: int) -> int:
|
| 68 |
+
if position < len(self.prompt_token_ids):
|
| 69 |
+
return self.prompt_token_ids[position]
|
| 70 |
+
return self.output_token_ids[position - len(self.prompt_token_ids)]
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def num_uncomputed_prompt_tokens(self) -> int:
|
| 74 |
+
return max(0, self.prompt_len - self.num_computed_tokens)
|
| 75 |
+
|
| 76 |
+
def append_output_token(self, token_id: int) -> None:
|
| 77 |
+
self.output_token_ids.append(token_id)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@dataclass
|
| 81 |
+
class Request:
|
| 82 |
+
"""A user-submitted request before it becomes a Sequence."""
|
| 83 |
+
|
| 84 |
+
request_id: str
|
| 85 |
+
prompt_token_ids: list[int]
|
| 86 |
+
sampling_params: SamplingParams
|
tiny_vllm/sampler.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Per-request sampling. Temperature, top-p, top-k, greedy."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from .config import SamplingParams
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Sampler:
|
| 12 |
+
def __init__(self, device: torch.device) -> None:
|
| 13 |
+
self.device = device
|
| 14 |
+
|
| 15 |
+
def sample(
|
| 16 |
+
self,
|
| 17 |
+
logits: torch.Tensor, # [num_seqs, vocab]
|
| 18 |
+
params: list[SamplingParams],
|
| 19 |
+
generators: Optional[list[Optional[torch.Generator]]] = None,
|
| 20 |
+
) -> list[int]:
|
| 21 |
+
out: list[int] = []
|
| 22 |
+
for i, p in enumerate(params):
|
| 23 |
+
row = logits[i]
|
| 24 |
+
if p.is_greedy:
|
| 25 |
+
out.append(int(row.argmax().item()))
|
| 26 |
+
continue
|
| 27 |
+
|
| 28 |
+
# Temperature.
|
| 29 |
+
row = row / max(p.temperature, 1e-5)
|
| 30 |
+
# Top-k.
|
| 31 |
+
if p.top_k > 0 and p.top_k < row.size(-1):
|
| 32 |
+
topk_vals, _ = torch.topk(row, p.top_k)
|
| 33 |
+
row = torch.where(row < topk_vals[-1], torch.full_like(row, float("-inf")), row)
|
| 34 |
+
# Top-p (nucleus).
|
| 35 |
+
if 0.0 < p.top_p < 1.0:
|
| 36 |
+
sorted_logits, sorted_idx = torch.sort(row, descending=True)
|
| 37 |
+
probs = torch.softmax(sorted_logits, dim=-1)
|
| 38 |
+
cumprobs = probs.cumsum(dim=-1)
|
| 39 |
+
# Drop tokens whose CUMULATIVE prob (including themselves) exceeds top_p,
|
| 40 |
+
# but always keep the highest-probability one.
|
| 41 |
+
drop = cumprobs > p.top_p
|
| 42 |
+
drop[0] = False
|
| 43 |
+
drop = drop.roll(shifts=1, dims=0) # so the boundary token stays
|
| 44 |
+
drop[0] = False
|
| 45 |
+
sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
|
| 46 |
+
row = torch.full_like(row, float("-inf"))
|
| 47 |
+
row.scatter_(0, sorted_idx, sorted_logits)
|
| 48 |
+
|
| 49 |
+
probs = torch.softmax(row, dim=-1)
|
| 50 |
+
gen = generators[i] if generators else None
|
| 51 |
+
token = torch.multinomial(probs, num_samples=1, generator=gen)
|
| 52 |
+
out.append(int(token.item()))
|
| 53 |
+
return out
|
tiny_vllm/scheduler.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Continuous-batching scheduler with chunked prefill.
|
| 2 |
+
|
| 3 |
+
A scheduling step produces a SchedulerOutput listing which sequences run and
|
| 4 |
+
how many tokens each one advances. Two phases each step:
|
| 5 |
+
|
| 6 |
+
1. Decodes. Every RUNNING sequence wants exactly one new token; we must
|
| 7 |
+
ensure each has space for it. If a sequence needs a new block and the
|
| 8 |
+
pool is dry, we *preempt* the most recently admitted running sequence —
|
| 9 |
+
free its KV blocks and push it back to the front of the waiting queue
|
| 10 |
+
so it restarts prefill later (recompute-style preemption, as in vLLM).
|
| 11 |
+
|
| 12 |
+
2. Prefill chunks. With remaining token budget, pull from `waiting`. A
|
| 13 |
+
newly waiting sequence is admitted (prompt blocks allocated via the
|
| 14 |
+
block manager, with prefix-cache hits taken). Then we plan a chunk of
|
| 15 |
+
up to `min(remaining_prefill, budget)` tokens. Chunked prefill lets a
|
| 16 |
+
long prompt share the budget with concurrent decodes instead of
|
| 17 |
+
stalling them.
|
| 18 |
+
"""
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
from collections import deque
|
| 22 |
+
from dataclasses import dataclass, field
|
| 23 |
+
|
| 24 |
+
from .block_manager import BlockManager
|
| 25 |
+
from .config import EngineConfig
|
| 26 |
+
from .request import Sequence, SequenceStatus
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class ScheduledSeq:
|
| 31 |
+
seq: Sequence
|
| 32 |
+
num_tokens: int # how many tokens to forward this step for this seq
|
| 33 |
+
is_prefill: bool
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class SchedulerOutput:
|
| 38 |
+
scheduled: list[ScheduledSeq] = field(default_factory=list)
|
| 39 |
+
preempted: list[int] = field(default_factory=list) # seq_ids preempted
|
| 40 |
+
newly_admitted: list[int] = field(default_factory=list)
|
| 41 |
+
total_tokens: int = 0
|
| 42 |
+
|
| 43 |
+
@property
|
| 44 |
+
def is_empty(self) -> bool:
|
| 45 |
+
return not self.scheduled
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class Scheduler:
|
| 49 |
+
def __init__(self, config: EngineConfig, block_manager: BlockManager) -> None:
|
| 50 |
+
self.config = config
|
| 51 |
+
self.block_manager = block_manager
|
| 52 |
+
self.waiting: deque[Sequence] = deque()
|
| 53 |
+
self.running: list[Sequence] = []
|
| 54 |
+
# Tracks order of admission so preemption picks the youngest first.
|
| 55 |
+
self._admission_order: list[int] = []
|
| 56 |
+
|
| 57 |
+
# ---- queue ops ------------------------------------------------------
|
| 58 |
+
|
| 59 |
+
def add(self, seq: Sequence) -> None:
|
| 60 |
+
self.waiting.append(seq)
|
| 61 |
+
|
| 62 |
+
def abort(self, seq_id: int) -> bool:
|
| 63 |
+
for q in (self.waiting,):
|
| 64 |
+
for s in list(q):
|
| 65 |
+
if s.seq_id == seq_id:
|
| 66 |
+
q.remove(s)
|
| 67 |
+
s.status = SequenceStatus.FINISHED
|
| 68 |
+
s.finish_reason = "abort"
|
| 69 |
+
return True
|
| 70 |
+
for s in list(self.running):
|
| 71 |
+
if s.seq_id == seq_id:
|
| 72 |
+
self.running.remove(s)
|
| 73 |
+
self.block_manager.free(s)
|
| 74 |
+
s.status = SequenceStatus.FINISHED
|
| 75 |
+
s.finish_reason = "abort"
|
| 76 |
+
return True
|
| 77 |
+
return False
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def has_work(self) -> bool:
|
| 81 |
+
return bool(self.waiting) or bool(self.running)
|
| 82 |
+
|
| 83 |
+
# ---- scheduling -----------------------------------------------------
|
| 84 |
+
|
| 85 |
+
def _preempt_one(self) -> Sequence | None:
|
| 86 |
+
"""Free the youngest running sequence and re-enqueue it for restart."""
|
| 87 |
+
if not self.running:
|
| 88 |
+
return None
|
| 89 |
+
victim = self.running.pop() # youngest by insertion order
|
| 90 |
+
self.block_manager.free(victim)
|
| 91 |
+
# Restart: forget computed-token progress; keep generated outputs so
|
| 92 |
+
# the user-visible sequence is preserved. (vLLM full-recompute: we'd
|
| 93 |
+
# discard outputs too; we keep them so streaming makes sense.)
|
| 94 |
+
victim.num_computed_tokens = 0
|
| 95 |
+
victim.num_cached_prefix_tokens = 0
|
| 96 |
+
victim.status = SequenceStatus.PREEMPTED
|
| 97 |
+
self.waiting.appendleft(victim)
|
| 98 |
+
return victim
|
| 99 |
+
|
| 100 |
+
def schedule(self) -> SchedulerOutput:
|
| 101 |
+
out = SchedulerOutput()
|
| 102 |
+
budget = self.config.max_num_batched_tokens
|
| 103 |
+
|
| 104 |
+
# --- Phase 1: decodes for already-running sequences ---
|
| 105 |
+
for seq in list(self.running):
|
| 106 |
+
if seq.status != SequenceStatus.RUNNING:
|
| 107 |
+
continue
|
| 108 |
+
if budget <= 0:
|
| 109 |
+
break
|
| 110 |
+
# Ensure space for one more token.
|
| 111 |
+
try:
|
| 112 |
+
self.block_manager.append_slot(seq)
|
| 113 |
+
except RuntimeError:
|
| 114 |
+
# Out of blocks: try to free space by preempting the youngest
|
| 115 |
+
# running sequence — which may be `seq` itself.
|
| 116 |
+
victim = self._preempt_one()
|
| 117 |
+
if victim is seq:
|
| 118 |
+
# We preempted ourselves; it's already off `running`.
|
| 119 |
+
out.preempted.append(seq.seq_id)
|
| 120 |
+
continue
|
| 121 |
+
if victim is None:
|
| 122 |
+
# Nothing to preempt; preempt this seq manually.
|
| 123 |
+
self.running.remove(seq)
|
| 124 |
+
self.block_manager.free(seq)
|
| 125 |
+
seq.num_computed_tokens = 0
|
| 126 |
+
seq.num_cached_prefix_tokens = 0
|
| 127 |
+
seq.status = SequenceStatus.PREEMPTED
|
| 128 |
+
self.waiting.appendleft(seq)
|
| 129 |
+
out.preempted.append(seq.seq_id)
|
| 130 |
+
continue
|
| 131 |
+
out.preempted.append(victim.seq_id)
|
| 132 |
+
try:
|
| 133 |
+
self.block_manager.append_slot(seq)
|
| 134 |
+
except RuntimeError:
|
| 135 |
+
# Still no room — give up on this seq this step.
|
| 136 |
+
continue
|
| 137 |
+
out.scheduled.append(ScheduledSeq(seq=seq, num_tokens=1, is_prefill=False))
|
| 138 |
+
budget -= 1
|
| 139 |
+
out.total_tokens += 1
|
| 140 |
+
|
| 141 |
+
# --- Phase 2: prefill chunks (admitting new sequences as needed) ---
|
| 142 |
+
max_concurrent = self.config.max_num_seqs
|
| 143 |
+
active_count = sum(1 for s in self.running if s.status != SequenceStatus.FINISHED)
|
| 144 |
+
|
| 145 |
+
while self.waiting and budget > 0 and active_count < max_concurrent:
|
| 146 |
+
seq = self.waiting[0]
|
| 147 |
+
|
| 148 |
+
# Admit if needed.
|
| 149 |
+
if not seq.block_table:
|
| 150 |
+
ok, _ = self.block_manager.can_allocate_initial(seq)
|
| 151 |
+
if not ok:
|
| 152 |
+
# Try to free up space by preempting the youngest running
|
| 153 |
+
# seq. If nothing to preempt, we're stuck for this step.
|
| 154 |
+
if not self.running:
|
| 155 |
+
break
|
| 156 |
+
victim = self._preempt_one()
|
| 157 |
+
if victim is None:
|
| 158 |
+
break
|
| 159 |
+
out.preempted.append(victim.seq_id)
|
| 160 |
+
continue
|
| 161 |
+
self.block_manager.admit(seq)
|
| 162 |
+
out.newly_admitted.append(seq.seq_id)
|
| 163 |
+
seq.status = SequenceStatus.PREFILLING
|
| 164 |
+
|
| 165 |
+
# Plan a chunk.
|
| 166 |
+
remaining = seq.num_uncomputed_prompt_tokens
|
| 167 |
+
chunk = min(remaining, budget)
|
| 168 |
+
if chunk <= 0:
|
| 169 |
+
# Prompt already fully cached (shouldn't happen due to admit
|
| 170 |
+
# capping, but defensive): move straight to RUNNING.
|
| 171 |
+
self.waiting.popleft()
|
| 172 |
+
seq.status = SequenceStatus.RUNNING
|
| 173 |
+
self.running.append(seq)
|
| 174 |
+
active_count += 1
|
| 175 |
+
continue
|
| 176 |
+
|
| 177 |
+
# Make sure block_table covers num_computed + chunk.
|
| 178 |
+
try:
|
| 179 |
+
self.block_manager.ensure_blocks_for_chunk(seq, chunk)
|
| 180 |
+
except RuntimeError:
|
| 181 |
+
# Couldn't expand. Try preemption; otherwise give up.
|
| 182 |
+
if self.running:
|
| 183 |
+
victim = self._preempt_one()
|
| 184 |
+
if victim is not None:
|
| 185 |
+
out.preempted.append(victim.seq_id)
|
| 186 |
+
continue
|
| 187 |
+
break
|
| 188 |
+
|
| 189 |
+
out.scheduled.append(ScheduledSeq(seq=seq, num_tokens=chunk, is_prefill=True))
|
| 190 |
+
budget -= chunk
|
| 191 |
+
out.total_tokens += chunk
|
| 192 |
+
|
| 193 |
+
if chunk == remaining:
|
| 194 |
+
# This step finishes prompt ingestion → seq becomes RUNNING.
|
| 195 |
+
self.waiting.popleft()
|
| 196 |
+
seq.status = SequenceStatus.RUNNING
|
| 197 |
+
self.running.append(seq)
|
| 198 |
+
active_count += 1
|
| 199 |
+
else:
|
| 200 |
+
# Still has more prompt to chew through; leave at head of
|
| 201 |
+
# waiting queue with a partial block_table.
|
| 202 |
+
break # one prefill per step keeps things tidy
|
| 203 |
+
|
| 204 |
+
return out
|
| 205 |
+
|
| 206 |
+
# ---- post-step ------------------------------------------------------
|
| 207 |
+
|
| 208 |
+
def finalize_step(self, scheduled: list[ScheduledSeq]) -> list[Sequence]:
|
| 209 |
+
"""Called after the model has produced new tokens.
|
| 210 |
+
|
| 211 |
+
Returns the list of sequences that just finished this step (so the
|
| 212 |
+
engine can free them and ship the final output to the caller).
|
| 213 |
+
"""
|
| 214 |
+
finished: list[Sequence] = []
|
| 215 |
+
for item in scheduled:
|
| 216 |
+
seq = item.seq
|
| 217 |
+
self.block_manager.register_filled_blocks(seq, prev_computed=0)
|
| 218 |
+
if seq.status == SequenceStatus.FINISHED:
|
| 219 |
+
if seq in self.running:
|
| 220 |
+
self.running.remove(seq)
|
| 221 |
+
self.block_manager.free(seq)
|
| 222 |
+
finished.append(seq)
|
| 223 |
+
return finished
|
tiny_vllm/server.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI front-end.
|
| 2 |
+
|
| 3 |
+
Two SSE streams live behind this server:
|
| 4 |
+
|
| 5 |
+
POST /generate — submit a prompt, stream back token deltas
|
| 6 |
+
POST /v1/completions — OpenAI-compatible streaming completions
|
| 7 |
+
GET /engine/events — stream of engine-state snapshots (one per step)
|
| 8 |
+
— what the demo page subscribes to
|
| 9 |
+
GET /engine/snapshot — one-shot current state (JSON)
|
| 10 |
+
GET / — static demo page
|
| 11 |
+
|
| 12 |
+
The demo page subscribes to /engine/events and renders the block pool,
|
| 13 |
+
scheduler queues, and live token streams.
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import asyncio
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import time
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import AsyncIterator, Optional
|
| 24 |
+
|
| 25 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 26 |
+
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
| 27 |
+
from fastapi.staticfiles import StaticFiles
|
| 28 |
+
from pydantic import BaseModel, Field
|
| 29 |
+
|
| 30 |
+
from .config import EngineConfig, SamplingParams
|
| 31 |
+
from .engine import LLMEngine
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
# Schemas
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class GenerateRequest(BaseModel):
|
| 40 |
+
prompt: str
|
| 41 |
+
max_tokens: int = 64
|
| 42 |
+
temperature: float = 1.0
|
| 43 |
+
top_p: float = 1.0
|
| 44 |
+
top_k: int = -1
|
| 45 |
+
seed: Optional[int] = None
|
| 46 |
+
ignore_eos: bool = False
|
| 47 |
+
stream: bool = True
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class CompletionsRequest(BaseModel):
|
| 51 |
+
model: Optional[str] = None
|
| 52 |
+
prompt: str | list[str]
|
| 53 |
+
max_tokens: int = 64
|
| 54 |
+
temperature: float = 1.0
|
| 55 |
+
top_p: float = 1.0
|
| 56 |
+
n: int = 1
|
| 57 |
+
stream: bool = False
|
| 58 |
+
stop: Optional[list[str]] = None
|
| 59 |
+
seed: Optional[int] = None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
# App factory
|
| 64 |
+
# ---------------------------------------------------------------------------
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _sse(data: dict | str) -> bytes:
|
| 68 |
+
if isinstance(data, dict):
|
| 69 |
+
data = json.dumps(data, separators=(",", ":"))
|
| 70 |
+
return f"data: {data}\n\n".encode("utf-8")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def build_app(config: EngineConfig) -> FastAPI:
|
| 74 |
+
app = FastAPI(title="tiny_vllm", version="0.1.0")
|
| 75 |
+
engine = LLMEngine(config)
|
| 76 |
+
|
| 77 |
+
@app.on_event("startup")
|
| 78 |
+
async def _on_startup() -> None:
|
| 79 |
+
await engine.startup()
|
| 80 |
+
|
| 81 |
+
@app.on_event("shutdown")
|
| 82 |
+
async def _on_shutdown() -> None:
|
| 83 |
+
await engine.shutdown()
|
| 84 |
+
|
| 85 |
+
# ---- root + static -------------------------------------------------
|
| 86 |
+
|
| 87 |
+
static_dir = Path(__file__).parent.parent / "web"
|
| 88 |
+
if static_dir.exists():
|
| 89 |
+
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
| 90 |
+
|
| 91 |
+
@app.get("/")
|
| 92 |
+
async def root() -> FileResponse:
|
| 93 |
+
return FileResponse(str(static_dir / "index.html"))
|
| 94 |
+
else:
|
| 95 |
+
@app.get("/")
|
| 96 |
+
async def root() -> dict:
|
| 97 |
+
return {"name": "tiny_vllm", "status": "ok",
|
| 98 |
+
"hint": "demo page not found; POST to /generate"}
|
| 99 |
+
|
| 100 |
+
# ---- introspection -------------------------------------------------
|
| 101 |
+
|
| 102 |
+
@app.get("/engine/snapshot")
|
| 103 |
+
async def snapshot() -> dict:
|
| 104 |
+
return engine.snapshot()
|
| 105 |
+
|
| 106 |
+
@app.get("/engine/events")
|
| 107 |
+
async def events(request: Request) -> StreamingResponse:
|
| 108 |
+
q = engine.subscribe_events()
|
| 109 |
+
|
| 110 |
+
async def gen() -> AsyncIterator[bytes]:
|
| 111 |
+
# Push initial snapshot so a freshly-connected client has state.
|
| 112 |
+
yield _sse({"type": "snapshot", "payload": engine.snapshot()})
|
| 113 |
+
try:
|
| 114 |
+
while True:
|
| 115 |
+
if await request.is_disconnected():
|
| 116 |
+
break
|
| 117 |
+
try:
|
| 118 |
+
ev = await asyncio.wait_for(q.get(), timeout=15.0)
|
| 119 |
+
except asyncio.TimeoutError:
|
| 120 |
+
yield b": keepalive\n\n"
|
| 121 |
+
continue
|
| 122 |
+
yield _sse({
|
| 123 |
+
"type": ev.type,
|
| 124 |
+
"step": ev.step,
|
| 125 |
+
"timestamp": ev.timestamp,
|
| 126 |
+
"payload": ev.payload,
|
| 127 |
+
})
|
| 128 |
+
finally:
|
| 129 |
+
engine.unsubscribe_events(q)
|
| 130 |
+
|
| 131 |
+
return StreamingResponse(gen(), media_type="text/event-stream")
|
| 132 |
+
|
| 133 |
+
# ---- generation ----------------------------------------------------
|
| 134 |
+
|
| 135 |
+
def _params(req: GenerateRequest) -> SamplingParams:
|
| 136 |
+
return SamplingParams(
|
| 137 |
+
max_tokens=req.max_tokens,
|
| 138 |
+
temperature=req.temperature,
|
| 139 |
+
top_p=req.top_p,
|
| 140 |
+
top_k=req.top_k,
|
| 141 |
+
seed=req.seed,
|
| 142 |
+
ignore_eos=req.ignore_eos,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
@app.post("/generate")
|
| 146 |
+
async def generate(req: GenerateRequest, request: Request) -> StreamingResponse | JSONResponse:
|
| 147 |
+
try:
|
| 148 |
+
rid = engine.add_request(req.prompt, _params(req))
|
| 149 |
+
except ValueError as e:
|
| 150 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 151 |
+
|
| 152 |
+
if not req.stream:
|
| 153 |
+
text_parts: list[str] = []
|
| 154 |
+
finish_reason: Optional[str] = None
|
| 155 |
+
async for item in engine.stream(rid):
|
| 156 |
+
text_parts.append(item.new_text)
|
| 157 |
+
if item.finished:
|
| 158 |
+
finish_reason = item.finish_reason
|
| 159 |
+
break
|
| 160 |
+
return JSONResponse({
|
| 161 |
+
"request_id": rid,
|
| 162 |
+
"text": "".join(text_parts),
|
| 163 |
+
"finish_reason": finish_reason,
|
| 164 |
+
})
|
| 165 |
+
|
| 166 |
+
async def gen() -> AsyncIterator[bytes]:
|
| 167 |
+
try:
|
| 168 |
+
async for item in engine.stream(rid):
|
| 169 |
+
if await request.is_disconnected():
|
| 170 |
+
engine.abort(rid)
|
| 171 |
+
break
|
| 172 |
+
yield _sse({
|
| 173 |
+
"request_id": rid,
|
| 174 |
+
"text": item.new_text,
|
| 175 |
+
"finished": item.finished,
|
| 176 |
+
"finish_reason": item.finish_reason,
|
| 177 |
+
})
|
| 178 |
+
if item.finished:
|
| 179 |
+
yield b"data: [DONE]\n\n"
|
| 180 |
+
break
|
| 181 |
+
except asyncio.CancelledError:
|
| 182 |
+
engine.abort(rid)
|
| 183 |
+
raise
|
| 184 |
+
|
| 185 |
+
return StreamingResponse(gen(), media_type="text/event-stream")
|
| 186 |
+
|
| 187 |
+
@app.post("/v1/completions")
|
| 188 |
+
async def completions(req: CompletionsRequest, request: Request):
|
| 189 |
+
# Single-prompt only (n=1) for the minimal impl.
|
| 190 |
+
if isinstance(req.prompt, list):
|
| 191 |
+
if len(req.prompt) != 1:
|
| 192 |
+
raise HTTPException(400, "tiny_vllm only supports a single prompt per call")
|
| 193 |
+
prompt = req.prompt[0]
|
| 194 |
+
else:
|
| 195 |
+
prompt = req.prompt
|
| 196 |
+
try:
|
| 197 |
+
rid = engine.add_request(
|
| 198 |
+
prompt,
|
| 199 |
+
SamplingParams(
|
| 200 |
+
max_tokens=req.max_tokens,
|
| 201 |
+
temperature=req.temperature,
|
| 202 |
+
top_p=req.top_p,
|
| 203 |
+
seed=req.seed,
|
| 204 |
+
),
|
| 205 |
+
)
|
| 206 |
+
except ValueError as e:
|
| 207 |
+
raise HTTPException(400, str(e))
|
| 208 |
+
|
| 209 |
+
created = int(time.time())
|
| 210 |
+
model_id = req.model or config.model
|
| 211 |
+
|
| 212 |
+
if not req.stream:
|
| 213 |
+
text_parts: list[str] = []
|
| 214 |
+
finish_reason: Optional[str] = None
|
| 215 |
+
async for item in engine.stream(rid):
|
| 216 |
+
text_parts.append(item.new_text)
|
| 217 |
+
if item.finished:
|
| 218 |
+
finish_reason = item.finish_reason
|
| 219 |
+
break
|
| 220 |
+
return JSONResponse({
|
| 221 |
+
"id": f"cmpl-{rid}",
|
| 222 |
+
"object": "text_completion",
|
| 223 |
+
"created": created,
|
| 224 |
+
"model": model_id,
|
| 225 |
+
"choices": [{
|
| 226 |
+
"text": "".join(text_parts),
|
| 227 |
+
"index": 0,
|
| 228 |
+
"logprobs": None,
|
| 229 |
+
"finish_reason": finish_reason,
|
| 230 |
+
}],
|
| 231 |
+
})
|
| 232 |
+
|
| 233 |
+
async def gen() -> AsyncIterator[bytes]:
|
| 234 |
+
try:
|
| 235 |
+
async for item in engine.stream(rid):
|
| 236 |
+
if await request.is_disconnected():
|
| 237 |
+
engine.abort(rid)
|
| 238 |
+
break
|
| 239 |
+
chunk = {
|
| 240 |
+
"id": f"cmpl-{rid}",
|
| 241 |
+
"object": "text_completion",
|
| 242 |
+
"created": created,
|
| 243 |
+
"model": model_id,
|
| 244 |
+
"choices": [{
|
| 245 |
+
"text": item.new_text,
|
| 246 |
+
"index": 0,
|
| 247 |
+
"logprobs": None,
|
| 248 |
+
"finish_reason": item.finish_reason if item.finished else None,
|
| 249 |
+
}],
|
| 250 |
+
}
|
| 251 |
+
yield _sse(chunk)
|
| 252 |
+
if item.finished:
|
| 253 |
+
yield b"data: [DONE]\n\n"
|
| 254 |
+
break
|
| 255 |
+
except asyncio.CancelledError:
|
| 256 |
+
engine.abort(rid)
|
| 257 |
+
raise
|
| 258 |
+
|
| 259 |
+
return StreamingResponse(gen(), media_type="text/event-stream")
|
| 260 |
+
|
| 261 |
+
@app.post("/abort/{request_id}")
|
| 262 |
+
async def abort(request_id: str) -> dict:
|
| 263 |
+
ok = engine.abort(request_id)
|
| 264 |
+
return {"aborted": ok}
|
| 265 |
+
|
| 266 |
+
return app
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# ---------------------------------------------------------------------------
|
| 270 |
+
# CLI entry
|
| 271 |
+
# ---------------------------------------------------------------------------
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def main() -> None:
|
| 275 |
+
parser = argparse.ArgumentParser(description="tiny_vllm server")
|
| 276 |
+
parser.add_argument("--model", default=os.environ.get("TINY_VLLM_MODEL", "Qwen/Qwen2.5-0.5B-Instruct"))
|
| 277 |
+
parser.add_argument("--device", default=os.environ.get("TINY_VLLM_DEVICE", "cpu"))
|
| 278 |
+
parser.add_argument("--dtype", default=os.environ.get("TINY_VLLM_DTYPE", "float32"))
|
| 279 |
+
parser.add_argument("--block-size", type=int, default=16)
|
| 280 |
+
parser.add_argument("--num-blocks", type=int, default=256)
|
| 281 |
+
parser.add_argument("--max-num-seqs", type=int, default=8)
|
| 282 |
+
parser.add_argument("--max-num-batched-tokens", type=int, default=512)
|
| 283 |
+
parser.add_argument("--max-model-len", type=int, default=2048)
|
| 284 |
+
parser.add_argument("--disable-prefix-caching", action="store_true")
|
| 285 |
+
parser.add_argument("--host", default="0.0.0.0")
|
| 286 |
+
parser.add_argument("--port", type=int, default=8000)
|
| 287 |
+
args = parser.parse_args()
|
| 288 |
+
|
| 289 |
+
cfg = EngineConfig(
|
| 290 |
+
model=args.model,
|
| 291 |
+
device=args.device,
|
| 292 |
+
dtype=args.dtype,
|
| 293 |
+
block_size=args.block_size,
|
| 294 |
+
num_blocks=args.num_blocks,
|
| 295 |
+
max_num_seqs=args.max_num_seqs,
|
| 296 |
+
max_num_batched_tokens=args.max_num_batched_tokens,
|
| 297 |
+
max_model_len=args.max_model_len,
|
| 298 |
+
enable_prefix_caching=not args.disable_prefix_caching,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
import uvicorn
|
| 302 |
+
app = build_app(cfg)
|
| 303 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
if __name__ == "__main__":
|
| 307 |
+
main()
|
web/app.js
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* tiny_vllm — demo page client.
|
| 2 |
+
*
|
| 3 |
+
* Two streams in play:
|
| 4 |
+
*
|
| 5 |
+
* /engine/events — engine state snapshots (one per scheduling step)
|
| 6 |
+
* /generate — token-level deltas for whatever prompt this page sent
|
| 7 |
+
*
|
| 8 |
+
* The page itself is stateless; everything is driven by what comes off the
|
| 9 |
+
* event stream. Token deltas from /generate are merged into per-request UI.
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
const $ = (id) => document.getElementById(id);
|
| 13 |
+
|
| 14 |
+
const ui = {
|
| 15 |
+
connection: $("connection"),
|
| 16 |
+
model: $("model"),
|
| 17 |
+
pool: $("block-pool"),
|
| 18 |
+
poolSummary: $("pool-summary"),
|
| 19 |
+
schedStep: $("sched-step"),
|
| 20 |
+
statTokens: $("stat-tokens"),
|
| 21 |
+
statPfDec: $("stat-pfdec"),
|
| 22 |
+
statMs: $("stat-ms"),
|
| 23 |
+
statCache: $("stat-cache"),
|
| 24 |
+
statFree: $("stat-free"),
|
| 25 |
+
statPre: $("stat-pre"),
|
| 26 |
+
log: $("log"),
|
| 27 |
+
seqs: $("seqs"),
|
| 28 |
+
send: $("send"),
|
| 29 |
+
sendTwice: $("send-twice"),
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
const state = {
|
| 33 |
+
poolEls: [],
|
| 34 |
+
numBlocks: 0,
|
| 35 |
+
blockSize: 16,
|
| 36 |
+
preempted: 0,
|
| 37 |
+
// request_id -> { promptText, generated, finished, finishReason }
|
| 38 |
+
requests: new Map(),
|
| 39 |
+
// seq_id -> { request_id, blockTable, cachedPrefixBlocks, status, ... }
|
| 40 |
+
seqsBySeqId: new Map(),
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
function logLine(html, cls = "") {
|
| 44 |
+
const t = new Date().toLocaleTimeString();
|
| 45 |
+
ui.log.innerHTML += `<span class="${cls}">[${t}] ${html}</span>\n`;
|
| 46 |
+
ui.log.scrollTop = ui.log.scrollHeight;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
function initPool(numBlocks) {
|
| 50 |
+
if (state.numBlocks === numBlocks && state.poolEls.length === numBlocks) return;
|
| 51 |
+
state.numBlocks = numBlocks;
|
| 52 |
+
ui.pool.innerHTML = "";
|
| 53 |
+
state.poolEls = [];
|
| 54 |
+
for (let i = 0; i < numBlocks; i++) {
|
| 55 |
+
const el = document.createElement("div");
|
| 56 |
+
el.className = "block free";
|
| 57 |
+
el.title = `block ${i}`;
|
| 58 |
+
ui.pool.appendChild(el);
|
| 59 |
+
state.poolEls.push(el);
|
| 60 |
+
}
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
function renderPool(pool) {
|
| 64 |
+
initPool(pool.num_blocks);
|
| 65 |
+
state.blockSize = pool.block_size;
|
| 66 |
+
for (let i = 0; i < pool.num_blocks; i++) {
|
| 67 |
+
const el = state.poolEls[i];
|
| 68 |
+
const rc = pool.ref_counts[i];
|
| 69 |
+
const hashed = pool.hashed[i];
|
| 70 |
+
let cls = "block";
|
| 71 |
+
if (rc === 0) {
|
| 72 |
+
cls += hashed ? " cached" : " free";
|
| 73 |
+
} else if (rc === 1) {
|
| 74 |
+
cls += " used";
|
| 75 |
+
} else {
|
| 76 |
+
cls += " shared";
|
| 77 |
+
}
|
| 78 |
+
if (hashed) cls += " hashed";
|
| 79 |
+
el.className = cls;
|
| 80 |
+
el.title = `block ${i} — refcount=${rc}${hashed ? " — hashed (cacheable)" : ""}`;
|
| 81 |
+
}
|
| 82 |
+
ui.poolSummary.textContent =
|
| 83 |
+
`${pool.num_blocks - pool.num_free_blocks}/${pool.num_blocks} used · ` +
|
| 84 |
+
`${pool.num_cached_entries} cached entries · ` +
|
| 85 |
+
`prefix-cache ${pool.prefix_cache_hits}/${pool.prefix_cache_lookups}`;
|
| 86 |
+
ui.statFree.textContent = pool.num_free_blocks;
|
| 87 |
+
if (pool.prefix_cache_lookups > 0) {
|
| 88 |
+
const pct = (100 * pool.prefix_cache_hits / pool.prefix_cache_lookups).toFixed(0);
|
| 89 |
+
ui.statCache.textContent = `${pct}%`;
|
| 90 |
+
} else {
|
| 91 |
+
ui.statCache.textContent = "—";
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
function renderSeqs(snapshot) {
|
| 96 |
+
ui.schedStep.textContent = ` — step ${snapshot.step}`;
|
| 97 |
+
const all = [...snapshot.running, ...snapshot.waiting];
|
| 98 |
+
// index for later token-delta merges
|
| 99 |
+
state.seqsBySeqId = new Map(all.map(s => [s.seq_id, s]));
|
| 100 |
+
ui.seqs.innerHTML = "";
|
| 101 |
+
if (all.length === 0) {
|
| 102 |
+
ui.seqs.innerHTML = `<div class="muted">(no active sequences — send a prompt above)</div>`;
|
| 103 |
+
return;
|
| 104 |
+
}
|
| 105 |
+
for (const s of all) {
|
| 106 |
+
const reqRec = state.requests.get(s.request_id);
|
| 107 |
+
const promptText = reqRec?.promptText ?? "(prompt elided)";
|
| 108 |
+
const gen = reqRec?.generated ?? "";
|
| 109 |
+
|
| 110 |
+
const div = document.createElement("div");
|
| 111 |
+
div.className = "seq";
|
| 112 |
+
div.id = `seq-${s.request_id}`;
|
| 113 |
+
|
| 114 |
+
const cachedBlocks = Math.floor(s.num_cached_prefix_tokens / state.blockSize);
|
| 115 |
+
const blocksHTML = s.block_table.map((bid, i) => {
|
| 116 |
+
const klass = i < cachedBlocks ? "seq-block cached-hit"
|
| 117 |
+
: (snapshot.block_pool.ref_counts[bid] > 1 ? "seq-block shared" : "seq-block");
|
| 118 |
+
return `<div class="${klass}" title="block ${bid}${i < cachedBlocks ? ' (prefix-cache hit)' : ''}">${bid}</div>`;
|
| 119 |
+
}).join("");
|
| 120 |
+
|
| 121 |
+
div.innerHTML = `
|
| 122 |
+
<div class="seq-header">
|
| 123 |
+
<span class="seq-id">req=${s.request_id.slice(0, 8)} seq=${s.seq_id}</span>
|
| 124 |
+
<span class="seq-status ${s.status}">${s.status}</span>
|
| 125 |
+
<span class="seq-meta">
|
| 126 |
+
prompt=${s.prompt_len} · generated=${s.num_generated} ·
|
| 127 |
+
cached=${s.num_cached_prefix_tokens}/${s.prompt_len} ·
|
| 128 |
+
blocks=${s.block_table.length}
|
| 129 |
+
</span>
|
| 130 |
+
</div>
|
| 131 |
+
<div class="seq-blocks">${blocksHTML || '<span class="muted">(no blocks yet)</span>'}</div>
|
| 132 |
+
<div class="seq-text"><span class="prompt">${escapeHtml(promptText)}</span><span class="gen">${escapeHtml(gen)}</span>${s.status === 'running' || s.status === 'prefilling' ? '<span class="cursor"> </span>' : ''}</div>
|
| 133 |
+
`;
|
| 134 |
+
ui.seqs.appendChild(div);
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
function escapeHtml(s) {
|
| 139 |
+
return (s || "").replace(/[&<>"]/g, c => ({"&": "&", "<": "<", ">": ">", '"': """}[c]));
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
function handleEvent(ev) {
|
| 143 |
+
if (ev.type === "snapshot") {
|
| 144 |
+
const snap = ev.payload;
|
| 145 |
+
ui.model.textContent = `· ${snap.config.model}`;
|
| 146 |
+
renderPool(snap.block_pool);
|
| 147 |
+
renderSeqs(snap);
|
| 148 |
+
return;
|
| 149 |
+
}
|
| 150 |
+
if (ev.type === "step") {
|
| 151 |
+
const p = ev.payload;
|
| 152 |
+
ui.statTokens.textContent = p.num_tokens;
|
| 153 |
+
ui.statPfDec.textContent = `${p.num_prefill_seqs} / ${p.num_decode_seqs}`;
|
| 154 |
+
ui.statMs.textContent = p.duration_ms.toFixed(1);
|
| 155 |
+
if (p.preempted?.length) state.preempted += p.preempted.length;
|
| 156 |
+
ui.statPre.textContent = state.preempted;
|
| 157 |
+
renderPool(p.snapshot.block_pool);
|
| 158 |
+
renderSeqs(p.snapshot);
|
| 159 |
+
|
| 160 |
+
let msg = `step ${ev.step}: ${p.num_tokens}t (${p.num_prefill_seqs}P/${p.num_decode_seqs}D) in ${p.duration_ms.toFixed(1)}ms`;
|
| 161 |
+
let cls = "ev-step";
|
| 162 |
+
if (p.newly_admitted?.length) {
|
| 163 |
+
msg += ` · admitted seq=${p.newly_admitted.join(",")}`;
|
| 164 |
+
cls = "ev-admit";
|
| 165 |
+
}
|
| 166 |
+
if (p.finished?.length) {
|
| 167 |
+
msg += ` · finished ${p.finished.map(r => r.slice(0,8)).join(",")}`;
|
| 168 |
+
cls = "ev-finish";
|
| 169 |
+
}
|
| 170 |
+
if (p.preempted?.length) {
|
| 171 |
+
msg += ` · PREEMPTED seq=${p.preempted.join(",")}`;
|
| 172 |
+
cls = "ev-preempt";
|
| 173 |
+
}
|
| 174 |
+
logLine(msg, cls);
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
function connectEvents() {
|
| 179 |
+
const es = new EventSource("/engine/events");
|
| 180 |
+
es.onopen = () => {
|
| 181 |
+
ui.connection.textContent = "connected";
|
| 182 |
+
ui.connection.classList.remove("offline");
|
| 183 |
+
ui.connection.classList.add("online");
|
| 184 |
+
};
|
| 185 |
+
es.onerror = () => {
|
| 186 |
+
ui.connection.textContent = "disconnected";
|
| 187 |
+
ui.connection.classList.remove("online");
|
| 188 |
+
ui.connection.classList.add("offline");
|
| 189 |
+
};
|
| 190 |
+
es.onmessage = (e) => {
|
| 191 |
+
if (!e.data) return;
|
| 192 |
+
try {
|
| 193 |
+
handleEvent(JSON.parse(e.data));
|
| 194 |
+
} catch (err) {
|
| 195 |
+
console.error("bad event", err, e.data);
|
| 196 |
+
}
|
| 197 |
+
};
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
async function sendPrompt(prompt) {
|
| 201 |
+
const body = {
|
| 202 |
+
prompt,
|
| 203 |
+
max_tokens: parseInt($("max_tokens").value, 10),
|
| 204 |
+
temperature: parseFloat($("temperature").value),
|
| 205 |
+
top_p: parseFloat($("top_p").value),
|
| 206 |
+
stream: true,
|
| 207 |
+
};
|
| 208 |
+
const resp = await fetch("/generate", {
|
| 209 |
+
method: "POST",
|
| 210 |
+
headers: {"content-type": "application/json"},
|
| 211 |
+
body: JSON.stringify(body),
|
| 212 |
+
});
|
| 213 |
+
if (!resp.ok) {
|
| 214 |
+
const txt = await resp.text();
|
| 215 |
+
logLine(`request failed: ${txt}`, "ev-preempt");
|
| 216 |
+
return;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
// Parse SSE manually so we can read each event as it arrives.
|
| 220 |
+
const reader = resp.body.getReader();
|
| 221 |
+
const decoder = new TextDecoder();
|
| 222 |
+
let buf = "";
|
| 223 |
+
let myReqId = null;
|
| 224 |
+
while (true) {
|
| 225 |
+
const { value, done } = await reader.read();
|
| 226 |
+
if (done) break;
|
| 227 |
+
buf += decoder.decode(value, { stream: true });
|
| 228 |
+
const parts = buf.split("\n\n");
|
| 229 |
+
buf = parts.pop();
|
| 230 |
+
for (const part of parts) {
|
| 231 |
+
const line = part.trim();
|
| 232 |
+
if (!line.startsWith("data:")) continue;
|
| 233 |
+
const data = line.slice(5).trim();
|
| 234 |
+
if (data === "[DONE]") return;
|
| 235 |
+
try {
|
| 236 |
+
const j = JSON.parse(data);
|
| 237 |
+
if (!myReqId) {
|
| 238 |
+
myReqId = j.request_id;
|
| 239 |
+
state.requests.set(myReqId, { promptText: prompt, generated: "", finished: false });
|
| 240 |
+
}
|
| 241 |
+
const rec = state.requests.get(myReqId);
|
| 242 |
+
if (j.text) rec.generated += j.text;
|
| 243 |
+
rec.finished = j.finished;
|
| 244 |
+
rec.finishReason = j.finish_reason;
|
| 245 |
+
// Repaint the matching seq card if visible.
|
| 246 |
+
const card = document.getElementById(`seq-${myReqId}`);
|
| 247 |
+
if (card) {
|
| 248 |
+
const text = card.querySelector(".seq-text .gen");
|
| 249 |
+
if (text) text.textContent = rec.generated;
|
| 250 |
+
}
|
| 251 |
+
} catch (e) {
|
| 252 |
+
console.error("bad chunk", e, data);
|
| 253 |
+
}
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
ui.send.addEventListener("click", () => sendPrompt($("prompt").value));
|
| 259 |
+
ui.sendTwice.addEventListener("click", async () => {
|
| 260 |
+
const p = $("prompt").value;
|
| 261 |
+
// First send fills the prefix cache; second send should hit it.
|
| 262 |
+
await sendPrompt(p);
|
| 263 |
+
await new Promise(r => setTimeout(r, 200));
|
| 264 |
+
await sendPrompt(p);
|
| 265 |
+
});
|
| 266 |
+
$("prompt").addEventListener("keydown", (e) => {
|
| 267 |
+
if ((e.metaKey || e.ctrlKey) && e.key === "Enter") {
|
| 268 |
+
sendPrompt(e.target.value);
|
| 269 |
+
}
|
| 270 |
+
});
|
| 271 |
+
|
| 272 |
+
connectEvents();
|
web/index.html
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="utf-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
| 6 |
+
<title>tiny_vllm — engine internals</title>
|
| 7 |
+
<link rel="stylesheet" href="/static/style.css">
|
| 8 |
+
</head>
|
| 9 |
+
<body>
|
| 10 |
+
<header>
|
| 11 |
+
<h1>tiny_vllm <span class="muted">— minimal continuous-batching engine</span></h1>
|
| 12 |
+
<div class="status">
|
| 13 |
+
<span id="connection" class="badge offline">disconnected</span>
|
| 14 |
+
<span id="model" class="muted"></span>
|
| 15 |
+
</div>
|
| 16 |
+
</header>
|
| 17 |
+
|
| 18 |
+
<section class="prompt-box">
|
| 19 |
+
<textarea id="prompt" rows="2" placeholder="Type a prompt and press Send (or Cmd/Ctrl+Enter)…">Explain paged attention in two sentences.</textarea>
|
| 20 |
+
<div class="controls">
|
| 21 |
+
<label>max_tokens <input id="max_tokens" type="number" value="64" min="1" max="2048"></label>
|
| 22 |
+
<label>temperature <input id="temperature" type="number" value="0.7" step="0.1" min="0" max="2"></label>
|
| 23 |
+
<label>top_p <input id="top_p" type="number" value="0.9" step="0.05" min="0" max="1"></label>
|
| 24 |
+
<button id="send">Send</button>
|
| 25 |
+
<button id="send-twice" title="Submit the same prompt twice — second should hit prefix cache">Send ×2 (prefix demo)</button>
|
| 26 |
+
</div>
|
| 27 |
+
</section>
|
| 28 |
+
|
| 29 |
+
<main>
|
| 30 |
+
<section class="card">
|
| 31 |
+
<h2>Block pool <span class="muted" id="pool-summary"></span></h2>
|
| 32 |
+
<div id="block-pool" class="block-pool"></div>
|
| 33 |
+
<div class="legend">
|
| 34 |
+
<span class="legend-item"><span class="swatch swatch-free"></span>free</span>
|
| 35 |
+
<span class="legend-item"><span class="swatch swatch-cached"></span>cached (evictable)</span>
|
| 36 |
+
<span class="legend-item"><span class="swatch swatch-used"></span>in use</span>
|
| 37 |
+
<span class="legend-item"><span class="swatch swatch-shared"></span>shared (refcount>1)</span>
|
| 38 |
+
<span class="legend-item"><span class="swatch swatch-hashed-edge"></span>hashed (border)</span>
|
| 39 |
+
</div>
|
| 40 |
+
</section>
|
| 41 |
+
|
| 42 |
+
<section class="card">
|
| 43 |
+
<h2>Scheduler <span class="muted" id="sched-step"></span></h2>
|
| 44 |
+
<div class="stats">
|
| 45 |
+
<div class="stat"><div class="stat-label">tokens this step</div><div class="stat-value" id="stat-tokens">0</div></div>
|
| 46 |
+
<div class="stat"><div class="stat-label">prefill / decode</div><div class="stat-value" id="stat-pfdec">0 / 0</div></div>
|
| 47 |
+
<div class="stat"><div class="stat-label">step (ms)</div><div class="stat-value" id="stat-ms">0</div></div>
|
| 48 |
+
<div class="stat"><div class="stat-label">prefix cache hit-rate</div><div class="stat-value" id="stat-cache">0%</div></div>
|
| 49 |
+
<div class="stat"><div class="stat-label">free blocks</div><div class="stat-value" id="stat-free">0</div></div>
|
| 50 |
+
<div class="stat"><div class="stat-label">preemptions (total)</div><div class="stat-value" id="stat-pre">0</div></div>
|
| 51 |
+
</div>
|
| 52 |
+
<h3>step log</h3>
|
| 53 |
+
<pre id="log" class="log"></pre>
|
| 54 |
+
</section>
|
| 55 |
+
|
| 56 |
+
<section class="card grow">
|
| 57 |
+
<h2>Sequences</h2>
|
| 58 |
+
<div id="seqs"></div>
|
| 59 |
+
</section>
|
| 60 |
+
</main>
|
| 61 |
+
|
| 62 |
+
<footer>
|
| 63 |
+
<span class="muted">Subscribed to <code>/engine/events</code>. Source: <a href="https://github.com/yourname/tiny_vllm" target="_blank">github</a>.</span>
|
| 64 |
+
</footer>
|
| 65 |
+
|
| 66 |
+
<script src="/static/app.js"></script>
|
| 67 |
+
</body>
|
| 68 |
+
</html>
|
web/style.css
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
:root {
|
| 2 |
+
--bg: #0e1116;
|
| 3 |
+
--bg-elev: #161b22;
|
| 4 |
+
--bg-elev2: #1f2630;
|
| 5 |
+
--fg: #e6edf3;
|
| 6 |
+
--muted: #8b949e;
|
| 7 |
+
--accent: #58a6ff;
|
| 8 |
+
--green: #3fb950;
|
| 9 |
+
--purple: #a371f7;
|
| 10 |
+
--orange: #f0883e;
|
| 11 |
+
--red: #f85149;
|
| 12 |
+
--border: #30363d;
|
| 13 |
+
--mono: ui-monospace, "JetBrains Mono", Menlo, Consolas, monospace;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
* { box-sizing: border-box; }
|
| 17 |
+
|
| 18 |
+
body {
|
| 19 |
+
margin: 0;
|
| 20 |
+
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Inter, sans-serif;
|
| 21 |
+
background: var(--bg);
|
| 22 |
+
color: var(--fg);
|
| 23 |
+
font-size: 14px;
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
header {
|
| 27 |
+
display: flex; align-items: center; justify-content: space-between;
|
| 28 |
+
padding: 14px 20px;
|
| 29 |
+
border-bottom: 1px solid var(--border);
|
| 30 |
+
background: var(--bg-elev);
|
| 31 |
+
}
|
| 32 |
+
header h1 { font-size: 16px; margin: 0; font-weight: 600; }
|
| 33 |
+
.muted { color: var(--muted); font-weight: 400; }
|
| 34 |
+
.badge {
|
| 35 |
+
display: inline-block;
|
| 36 |
+
padding: 2px 8px;
|
| 37 |
+
border-radius: 10px;
|
| 38 |
+
font-size: 11px;
|
| 39 |
+
font-family: var(--mono);
|
| 40 |
+
}
|
| 41 |
+
.badge.online { background: rgba(63, 185, 80, 0.15); color: var(--green); }
|
| 42 |
+
.badge.offline { background: rgba(248, 81, 73, 0.15); color: var(--red); }
|
| 43 |
+
|
| 44 |
+
.prompt-box {
|
| 45 |
+
padding: 12px 20px;
|
| 46 |
+
border-bottom: 1px solid var(--border);
|
| 47 |
+
background: var(--bg-elev);
|
| 48 |
+
display: flex; flex-direction: column; gap: 10px;
|
| 49 |
+
}
|
| 50 |
+
.prompt-box textarea {
|
| 51 |
+
width: 100%;
|
| 52 |
+
background: var(--bg);
|
| 53 |
+
color: var(--fg);
|
| 54 |
+
border: 1px solid var(--border);
|
| 55 |
+
border-radius: 6px;
|
| 56 |
+
padding: 8px;
|
| 57 |
+
font-family: var(--mono);
|
| 58 |
+
resize: vertical;
|
| 59 |
+
}
|
| 60 |
+
.controls { display: flex; gap: 12px; align-items: center; flex-wrap: wrap; }
|
| 61 |
+
.controls label { display: flex; gap: 6px; align-items: center; font-size: 12px; color: var(--muted); }
|
| 62 |
+
.controls input {
|
| 63 |
+
width: 70px; background: var(--bg); color: var(--fg);
|
| 64 |
+
border: 1px solid var(--border); border-radius: 4px; padding: 3px 6px;
|
| 65 |
+
font-family: var(--mono);
|
| 66 |
+
}
|
| 67 |
+
button {
|
| 68 |
+
background: var(--accent); color: white;
|
| 69 |
+
border: none; border-radius: 4px;
|
| 70 |
+
padding: 6px 14px; font-weight: 500; cursor: pointer;
|
| 71 |
+
}
|
| 72 |
+
button:hover { filter: brightness(1.1); }
|
| 73 |
+
#send-twice { background: var(--purple); }
|
| 74 |
+
|
| 75 |
+
main {
|
| 76 |
+
display: grid;
|
| 77 |
+
grid-template-columns: 1fr 1fr;
|
| 78 |
+
grid-template-areas: "pool sched" "seqs seqs";
|
| 79 |
+
gap: 16px;
|
| 80 |
+
padding: 16px 20px;
|
| 81 |
+
}
|
| 82 |
+
.card {
|
| 83 |
+
background: var(--bg-elev);
|
| 84 |
+
border: 1px solid var(--border);
|
| 85 |
+
border-radius: 8px;
|
| 86 |
+
padding: 14px;
|
| 87 |
+
}
|
| 88 |
+
.card h2 { font-size: 14px; margin: 0 0 10px; font-weight: 600; }
|
| 89 |
+
.card h3 { font-size: 12px; margin: 14px 0 6px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.06em; }
|
| 90 |
+
.card.grow { grid-area: seqs; }
|
| 91 |
+
.card:nth-child(1) { grid-area: pool; }
|
| 92 |
+
.card:nth-child(2) { grid-area: sched; }
|
| 93 |
+
|
| 94 |
+
/* ---- block pool ---- */
|
| 95 |
+
.block-pool {
|
| 96 |
+
display: grid;
|
| 97 |
+
grid-template-columns: repeat(auto-fill, 16px);
|
| 98 |
+
gap: 3px;
|
| 99 |
+
padding: 8px;
|
| 100 |
+
background: var(--bg);
|
| 101 |
+
border-radius: 6px;
|
| 102 |
+
max-height: 280px; overflow-y: auto;
|
| 103 |
+
}
|
| 104 |
+
.block {
|
| 105 |
+
width: 16px; height: 16px; border-radius: 3px;
|
| 106 |
+
background: var(--bg-elev2);
|
| 107 |
+
position: relative;
|
| 108 |
+
cursor: help;
|
| 109 |
+
border: 1px solid transparent;
|
| 110 |
+
}
|
| 111 |
+
.block.free { background: #2a3140; }
|
| 112 |
+
.block.cached { background: #1f3b5c; } /* free but in prefix cache */
|
| 113 |
+
.block.used { background: var(--green); }
|
| 114 |
+
.block.shared { background: var(--purple); }
|
| 115 |
+
.block.hashed { border-color: var(--orange); }
|
| 116 |
+
|
| 117 |
+
.legend { display: flex; gap: 14px; margin-top: 10px; font-size: 11px; color: var(--muted); flex-wrap: wrap; }
|
| 118 |
+
.legend-item { display: flex; align-items: center; gap: 5px; }
|
| 119 |
+
.swatch { width: 12px; height: 12px; border-radius: 3px; display: inline-block; }
|
| 120 |
+
.swatch-free { background: #2a3140; }
|
| 121 |
+
.swatch-cached { background: #1f3b5c; }
|
| 122 |
+
.swatch-used { background: var(--green); }
|
| 123 |
+
.swatch-shared { background: var(--purple); }
|
| 124 |
+
.swatch-hashed-edge { background: var(--bg-elev2); border: 1px solid var(--orange); }
|
| 125 |
+
|
| 126 |
+
/* ---- stats ---- */
|
| 127 |
+
.stats { display: grid; grid-template-columns: repeat(3, 1fr); gap: 8px; }
|
| 128 |
+
.stat {
|
| 129 |
+
background: var(--bg);
|
| 130 |
+
border-radius: 6px;
|
| 131 |
+
padding: 8px;
|
| 132 |
+
}
|
| 133 |
+
.stat-label { font-size: 10px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.06em; }
|
| 134 |
+
.stat-value { font-family: var(--mono); font-size: 18px; margin-top: 3px; }
|
| 135 |
+
|
| 136 |
+
/* ---- log ---- */
|
| 137 |
+
.log {
|
| 138 |
+
background: var(--bg);
|
| 139 |
+
border-radius: 6px;
|
| 140 |
+
padding: 8px;
|
| 141 |
+
height: 140px; overflow-y: auto;
|
| 142 |
+
font-family: var(--mono); font-size: 11px;
|
| 143 |
+
margin: 0;
|
| 144 |
+
white-space: pre-wrap; word-break: break-word;
|
| 145 |
+
}
|
| 146 |
+
.log .ev-step { color: var(--muted); }
|
| 147 |
+
.log .ev-admit { color: var(--accent); }
|
| 148 |
+
.log .ev-finish { color: var(--green); }
|
| 149 |
+
.log .ev-preempt { color: var(--red); }
|
| 150 |
+
|
| 151 |
+
/* ---- sequences ---- */
|
| 152 |
+
#seqs { display: flex; flex-direction: column; gap: 10px; }
|
| 153 |
+
.seq {
|
| 154 |
+
background: var(--bg);
|
| 155 |
+
border: 1px solid var(--border);
|
| 156 |
+
border-radius: 6px;
|
| 157 |
+
padding: 10px;
|
| 158 |
+
}
|
| 159 |
+
.seq-header { display: flex; gap: 10px; align-items: center; }
|
| 160 |
+
.seq-id { font-family: var(--mono); color: var(--muted); font-size: 12px; }
|
| 161 |
+
.seq-status {
|
| 162 |
+
font-size: 10px; text-transform: uppercase; padding: 2px 6px; border-radius: 3px;
|
| 163 |
+
font-family: var(--mono); letter-spacing: 0.05em;
|
| 164 |
+
}
|
| 165 |
+
.seq-status.waiting { background: rgba(139, 148, 158, 0.2); color: var(--muted); }
|
| 166 |
+
.seq-status.prefilling { background: rgba(88, 166, 255, 0.15); color: var(--accent); }
|
| 167 |
+
.seq-status.running { background: rgba(63, 185, 80, 0.15); color: var(--green); }
|
| 168 |
+
.seq-status.finished { background: rgba(163, 113, 247, 0.15); color: var(--purple); }
|
| 169 |
+
.seq-status.preempted { background: rgba(240, 136, 62, 0.2); color: var(--orange); }
|
| 170 |
+
.seq-meta { color: var(--muted); font-size: 11px; font-family: var(--mono); margin-left: auto; }
|
| 171 |
+
.seq-blocks {
|
| 172 |
+
margin-top: 8px;
|
| 173 |
+
display: flex; gap: 2px; flex-wrap: wrap;
|
| 174 |
+
}
|
| 175 |
+
.seq-block {
|
| 176 |
+
width: 22px; height: 14px;
|
| 177 |
+
background: var(--bg-elev2);
|
| 178 |
+
font-size: 9px; line-height: 14px; text-align: center;
|
| 179 |
+
font-family: var(--mono);
|
| 180 |
+
border-radius: 2px;
|
| 181 |
+
color: var(--muted);
|
| 182 |
+
}
|
| 183 |
+
.seq-block.cached-hit { background: #1f3b5c; color: var(--accent); }
|
| 184 |
+
.seq-block.shared { background: var(--purple); color: white; }
|
| 185 |
+
.seq-text {
|
| 186 |
+
margin-top: 8px;
|
| 187 |
+
font-family: var(--mono); font-size: 12px;
|
| 188 |
+
background: var(--bg-elev2);
|
| 189 |
+
border-radius: 4px; padding: 6px;
|
| 190 |
+
min-height: 24px;
|
| 191 |
+
max-height: 180px;
|
| 192 |
+
overflow-y: auto;
|
| 193 |
+
white-space: pre-wrap; word-break: break-word;
|
| 194 |
+
}
|
| 195 |
+
.seq-text .prompt { color: var(--muted); }
|
| 196 |
+
.seq-text .gen { color: var(--fg); }
|
| 197 |
+
.seq-text .cursor {
|
| 198 |
+
display: inline-block; width: 6px; background: var(--accent);
|
| 199 |
+
animation: blink 1s steps(2, start) infinite;
|
| 200 |
+
}
|
| 201 |
+
@keyframes blink { to { visibility: hidden; } }
|
| 202 |
+
|
| 203 |
+
footer {
|
| 204 |
+
padding: 10px 20px;
|
| 205 |
+
border-top: 1px solid var(--border);
|
| 206 |
+
color: var(--muted); font-size: 11px;
|
| 207 |
+
}
|
| 208 |
+
footer a { color: var(--accent); text-decoration: none; }
|
| 209 |
+
|
| 210 |
+
@media (max-width: 900px) {
|
| 211 |
+
main { grid-template-columns: 1fr; grid-template-areas: "pool" "sched" "seqs"; }
|
| 212 |
+
.stats { grid-template-columns: repeat(2, 1fr); }
|
| 213 |
+
}
|