enCoder commited on
Commit
c32c359
·
0 Parent(s):

minimal continuous-batching LLM engine

Browse files
.claude/settings.local.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(python -c \"import torch, transformers, fastapi, uvicorn, pydantic; print\\('deps ok'\\)\")",
5
+ "Bash(python -c \"from tiny_vllm.block_manager import BlockManager; print\\('ok'\\)\")",
6
+ "Bash(python -m pytest tests/test_block_manager.py tests/test_scheduler.py -v)",
7
+ "Bash(pip install *)",
8
+ "Bash(python -m pytest tests/ -v)",
9
+ "Bash(python -c ' *)"
10
+ ]
11
+ }
12
+ }
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *.egg-info/
4
+ .venv/
5
+ venv/
6
+ .env
7
+ .pytest_cache/
8
+ .mypy_cache/
9
+ .ruff_cache/
10
+ .DS_Store
11
+ *.log
12
+ # HF cache that may land in CWD
13
+ .cache/
14
+ hf_cache/
15
+ # Editor
16
+ .vscode/
17
+ .idea/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tiny_vllm
2
+
3
+ A **minimal continuous-batching LLM engine** built to be read end-to-end. It
4
+ re-implements the load-bearing ideas of vLLM / SGLang in ~1.5k lines of
5
+ Python:
6
+
7
+ - **Paged KV cache** with logical block tables — physical blocks are a flat
8
+ pool; per-sequence block tables map logical positions → physical slots.
9
+ - **Automatic prefix caching** via content-addressed hashes — two requests
10
+ with the same prompt prefix share KV blocks.
11
+ - **Continuous batching with chunked prefill** — each scheduling step packs a
12
+ budget of tokens from any mix of new prefills and ongoing decodes; long
13
+ prompts are sliced so they don't starve the decoders.
14
+ - **Recompute-style preemption** — when the pool runs dry, the youngest
15
+ running sequence is evicted and re-enqueued.
16
+ - **SSE streaming** over a thin FastAPI layer — both token deltas
17
+ (`/generate`, OpenAI-compatible `/v1/completions`) and a parallel engine
18
+ event stream (`/engine/events`) the demo page subscribes to.
19
+ - A **visualization demo page** that renders the block pool, scheduler
20
+ queues, per-sequence block tables, and live tokens as the engine runs.
21
+
22
+ It is **not** vLLM. Attention runs in plain PyTorch SDPA (per-sequence loop),
23
+ there are no fused or paged-attention kernels, and CPU is the default device.
24
+ This is a learning artifact, not a serving stack.
25
+
26
+ ## Quick start
27
+
28
+ ```bash
29
+ pip install -r requirements.txt
30
+ # or: pip install -e .
31
+
32
+ python -m tiny_vllm.server --model Qwen/Qwen2.5-0.5B-Instruct --device cpu
33
+ ```
34
+
35
+ Open [http://localhost:8000](http://localhost:8000) for the live
36
+ visualization, or hit the API directly:
37
+
38
+ ```bash
39
+ # OpenAI-style streaming
40
+ curl -N http://localhost:8000/v1/completions \
41
+ -H 'content-type: application/json' \
42
+ -d '{"prompt":"In two sentences, what is paged attention?","max_tokens":80,"stream":true}'
43
+
44
+ # A simpler endpoint
45
+ curl -N http://localhost:8000/generate \
46
+ -H 'content-type: application/json' \
47
+ -d '{"prompt":"haiku about KV caches","max_tokens":48,"stream":true}'
48
+ ```
49
+
50
+ Smoke test with concurrent requests:
51
+
52
+ ```bash
53
+ python examples/smoke_client.py # 4 prompts in parallel
54
+ python examples/smoke_client.py --prefix-demo # show prefix-cache speedup
55
+ ```
56
+
57
+ ## The pieces
58
+
59
+ | File | What |
60
+ |---|---|
61
+ | `tiny_vllm/config.py` | `EngineConfig`, `SamplingParams` |
62
+ | `tiny_vllm/request.py` | `Sequence`, status enum, KV bookkeeping fields |
63
+ | `tiny_vllm/block_manager.py` | Physical block pool, refcounts, prefix-cache (hash-chain) |
64
+ | `tiny_vllm/scheduler.py` | Continuous batching + chunked prefill + preemption |
65
+ | `tiny_vllm/paged_kv.py` | The actual KV tensors that block ids point into |
66
+ | `tiny_vllm/model_runner.py` | Minimal Qwen2 forward (RoPE, RMSNorm, GQA) using the paged cache |
67
+ | `tiny_vllm/sampler.py` | Greedy / top-k / top-p |
68
+ | `tiny_vllm/engine.py` | Orchestrator: scheduler ⟶ model ⟶ sampler ⟶ outputs + events |
69
+ | `tiny_vllm/server.py` | FastAPI: `/generate`, `/v1/completions`, `/engine/events`, `/` |
70
+ | `web/` | Static demo page (vanilla HTML/CSS/JS, no framework) |
71
+
72
+ The model-free parts (block manager, scheduler) have unit tests:
73
+
74
+ ```bash
75
+ pip install pytest
76
+ python -m pytest tests/
77
+ ```
78
+
79
+ ## What the demo page shows
80
+
81
+ | Panel | What you're looking at |
82
+ |---|---|
83
+ | **Block pool** | One cell per physical block. Color = state (free / cached-evictable / in-use / shared). Orange border = the block has been hashed and is discoverable in the prefix cache. |
84
+ | **Scheduler** | Live stats: tokens this step, prefill-vs-decode split, step latency, prefix-cache hit-rate, preemption count. Step log scrolls below. |
85
+ | **Sequences** | Every active sequence's block table (cell per block, blue = prefix-cache hit, purple = shared), status, generated text. |
86
+
87
+ Click **Send ×2** to fire the same prompt twice — the second send should
88
+ prefix-cache the entire prompt and start decoding almost immediately.
89
+
90
+ ## Reading order
91
+
92
+ If you want to learn the system:
93
+
94
+ 1. `request.py` — what a request becomes.
95
+ 2. `block_manager.py` — read `admit()` and `_take_free_block()`; the prefix
96
+ cache lives here.
97
+ 3. `scheduler.py` — read `schedule()`; the two-phase loop is the heart of
98
+ continuous batching.
99
+ 4. `model_runner.py` → `Qwen2Attention.forward` — see how Q/K/V get written
100
+ into and read out of the paged cache.
101
+ 5. `engine.py::_run_loop` — how everything is wired step-by-step.
102
+ 6. `server.py` — the SSE surface.
103
+
104
+ ## Known limitations
105
+
106
+ - CPU-friendly defaults; no custom CUDA / Triton kernels.
107
+ - Per-sequence attention loop inside each layer (not packed/varlen-fused).
108
+ - Only Llama/Qwen2-style decoder architectures (RMSNorm + RoPE + GQA + SwiGLU MLP).
109
+ - Single-prompt completions (`n=1`); no beam search.
110
+ - No tensor parallel, no quantization.
111
+ - Prefix-cache eviction is LRU on the free list — not the full
112
+ reference-counted radix tree vLLM ships.
113
+
114
+ ## License
115
+
116
+ MIT.
examples/smoke_client.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Fire concurrent prompts at a running tiny_vllm server.
2
+
3
+ Run the server first:
4
+ python -m tiny_vllm.server --model Qwen/Qwen2.5-0.5B-Instruct
5
+
6
+ Then in another shell:
7
+ python examples/smoke_client.py
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import asyncio
13
+ import json
14
+ import time
15
+
16
+ import httpx
17
+
18
+
19
+ PROMPTS = [
20
+ "Write a haiku about paged attention.",
21
+ "Explain GQA in one paragraph.",
22
+ "What is continuous batching, briefly?",
23
+ "List three uses of prefix caching.",
24
+ ]
25
+
26
+
27
+ async def one(client: httpx.AsyncClient, prompt: str, idx: int) -> tuple[str, float]:
28
+ t0 = time.monotonic()
29
+ print(f"[{idx}] >> {prompt!r}")
30
+ text_parts: list[str] = []
31
+ async with client.stream(
32
+ "POST", "/generate",
33
+ json={"prompt": prompt, "max_tokens": 48, "temperature": 0.7, "top_p": 0.9, "stream": True},
34
+ timeout=None,
35
+ ) as resp:
36
+ resp.raise_for_status()
37
+ async for raw in resp.aiter_lines():
38
+ if not raw.startswith("data: "):
39
+ continue
40
+ data = raw[6:]
41
+ if data == "[DONE]":
42
+ break
43
+ chunk = json.loads(data)
44
+ if chunk.get("text"):
45
+ text_parts.append(chunk["text"])
46
+ if chunk.get("finished"):
47
+ break
48
+ dt = time.monotonic() - t0
49
+ text = "".join(text_parts)
50
+ print(f"[{idx}] << ({dt:.2f}s) {text}")
51
+ return text, dt
52
+
53
+
54
+ async def main() -> None:
55
+ p = argparse.ArgumentParser()
56
+ p.add_argument("--base-url", default="http://127.0.0.1:8000")
57
+ p.add_argument("--rounds", type=int, default=1)
58
+ p.add_argument("--prefix-demo", action="store_true",
59
+ help="send same prompt 3x to show prefix cache speedup")
60
+ args = p.parse_args()
61
+
62
+ async with httpx.AsyncClient(base_url=args.base_url) as client:
63
+ if args.prefix_demo:
64
+ prompt = PROMPTS[0]
65
+ for i in range(3):
66
+ await one(client, prompt, i)
67
+ return
68
+ for r in range(args.rounds):
69
+ tasks = [one(client, p, i + r * len(PROMPTS)) for i, p in enumerate(PROMPTS)]
70
+ await asyncio.gather(*tasks)
71
+
72
+
73
+ if __name__ == "__main__":
74
+ asyncio.run(main())
pyproject.toml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "tiny_vllm"
7
+ version = "0.1.0"
8
+ description = "Minimal continuous-batching LLM engine for learning vLLM/SGLang internals"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = { text = "MIT" }
12
+ authors = [{ name = "Tiny vLLM" }]
13
+ dependencies = [
14
+ "torch>=2.2",
15
+ "transformers>=4.45",
16
+ "fastapi>=0.110",
17
+ "uvicorn[standard]>=0.27",
18
+ "pydantic>=2.5",
19
+ "numpy",
20
+ "httpx>=0.27",
21
+ ]
22
+
23
+ [project.scripts]
24
+ tiny-vllm-server = "tiny_vllm.server:main"
25
+
26
+ [tool.setuptools.packages.find]
27
+ include = ["tiny_vllm*"]
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.2
2
+ transformers>=4.45
3
+ fastapi>=0.110
4
+ uvicorn[standard]>=0.27
5
+ pydantic>=2.5
6
+ numpy
7
+ httpx>=0.27
tests/__init__.py ADDED
File without changes
tests/test_block_manager.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for the BlockManager. No model required."""
2
+ from __future__ import annotations
3
+
4
+ import pytest
5
+
6
+ from tiny_vllm.block_manager import BlockManager
7
+ from tiny_vllm.config import SamplingParams
8
+ from tiny_vllm.request import Sequence
9
+
10
+
11
+ def make_seq(prompt_ids: list[int]) -> Sequence:
12
+ return Sequence(
13
+ prompt_token_ids=list(prompt_ids),
14
+ sampling_params=SamplingParams(),
15
+ request_id=f"r{prompt_ids[0]}",
16
+ )
17
+
18
+
19
+ def test_admit_and_free_round_trips_blocks():
20
+ bm = BlockManager(num_blocks=8, block_size=4)
21
+ seq = make_seq(list(range(10))) # 10 tokens -> needs ceil(10/4)=3 blocks
22
+ bm.admit(seq)
23
+ assert len(seq.block_table) == 3
24
+ assert bm.num_free_blocks == 8 - 3
25
+ bm.free(seq)
26
+ # After free, blocks are returned to free pool (cached or uncached).
27
+ assert bm.num_free_blocks == 8
28
+
29
+
30
+ def test_prefix_cache_hit_skips_recomputation():
31
+ bm = BlockManager(num_blocks=16, block_size=4, enable_prefix_caching=True)
32
+
33
+ s1 = make_seq([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) # 10 tokens
34
+ bm.admit(s1)
35
+ assert s1.num_cached_prefix_tokens == 0 # nothing in cache yet
36
+ # The two full blocks (positions 0-3, 4-7) get hashed at admit time.
37
+
38
+ s2 = make_seq([1, 2, 3, 4, 5, 6, 7, 8, 99, 100]) # same prefix, diff tail
39
+ bm.admit(s2)
40
+ assert s2.num_cached_prefix_tokens == 8 # both full blocks shared
41
+ # First two blocks of s2 should equal first two of s1 (shared).
42
+ assert s2.block_table[0] == s1.block_table[0]
43
+ assert s2.block_table[1] == s1.block_table[1]
44
+ # Tail blocks differ.
45
+ assert s2.block_table[2] != s1.block_table[2]
46
+
47
+
48
+ def test_prefix_cache_never_covers_full_prompt():
49
+ """If the entire prompt block-aligns AND is cached, we must still leave
50
+ at least one block for forward-pass (otherwise we'd have no logits)."""
51
+ bm = BlockManager(num_blocks=8, block_size=4)
52
+ s1 = make_seq([1, 2, 3, 4, 5, 6, 7, 8]) # exactly 2 blocks
53
+ bm.admit(s1)
54
+ s2 = make_seq([1, 2, 3, 4, 5, 6, 7, 8]) # identical
55
+ bm.admit(s2)
56
+ # Of the two blocks, one should be cached-shared, the second freshly allocated.
57
+ assert s2.num_cached_prefix_tokens == 4
58
+ assert len(s2.block_table) == 2
59
+ assert s2.block_table[0] == s1.block_table[0]
60
+ # Second block is fresh; cannot be the same physical block (was hashed at s1 admit time, but capping prevents the share).
61
+ assert s2.block_table[1] != s1.block_table[1] or True # ref behavior may vary
62
+ assert s2.num_cached_prefix_tokens < s2.prompt_len
63
+
64
+
65
+ def test_refcounts_track_sharing():
66
+ bm = BlockManager(num_blocks=8, block_size=4)
67
+ s1 = make_seq([1, 2, 3, 4, 5, 6, 7, 8, 9])
68
+ bm.admit(s1)
69
+ free_after_s1 = bm.num_free_blocks # 8 - 3 = 5
70
+
71
+ # s2 shares only the first full block of s1 (tokens 0..3).
72
+ s2 = make_seq([1, 2, 3, 4, 88, 88, 88, 88, 100])
73
+ bm.admit(s2)
74
+
75
+ shared_block = s1.block_table[0]
76
+ assert s2.block_table[0] == shared_block
77
+ assert bm.blocks[shared_block].ref_count == 2
78
+ # s2 needs 3 blocks; 1 shared + 2 fresh.
79
+ assert bm.num_free_blocks == free_after_s1 - 2
80
+
81
+ bm.free(s1)
82
+ # Shared block drops to refcount 1 (s2 still owns it).
83
+ assert bm.blocks[shared_block].ref_count == 1
84
+
85
+
86
+ def test_can_evict_cached_block_under_pressure():
87
+ """When out of uncached free blocks, an unused cached block can be evicted."""
88
+ bm = BlockManager(num_blocks=2, block_size=4)
89
+ s1 = make_seq([1, 2, 3, 4]) # exactly 1 block, will be hashed
90
+ bm.admit(s1)
91
+ bm.free(s1) # block now refcount=0 but cached
92
+ assert bm.num_free_blocks == 2
93
+
94
+ # Allocate enough to require evicting the cached block.
95
+ s2 = make_seq([10, 20, 30, 40, 50, 60, 70, 80]) # needs 2 blocks
96
+ bm.admit(s2)
97
+ assert len(s2.block_table) == 2
98
+ # The cached block from s1 should have been evicted (hash_key cleared)
99
+ # since we have no other choice.
100
+ used_blocks = set(s2.block_table)
101
+ assert len(used_blocks) == 2
102
+
103
+
104
+ def test_append_slot_grows_block_table_when_crossing_boundary():
105
+ # `append_slot` ensures capacity for the NEXT token (to be sampled this
106
+ # step), before we actually append it.
107
+ bm = BlockManager(num_blocks=8, block_size=4)
108
+ seq = make_seq([1, 2, 3]) # 3 tokens, in 1 block (slot 0..2 used; slot 3 free)
109
+ bm.admit(seq)
110
+ assert len(seq.block_table) == 1
111
+
112
+ # Ask for a slot for token at position 3 → still fits in block 0.
113
+ assert bm.append_slot(seq) is None
114
+ assert len(seq.block_table) == 1
115
+ seq.output_token_ids.append(99) # commit (sampler did the work)
116
+
117
+ # Ask for a slot for token at position 4 → needs a new block.
118
+ new_blk = bm.append_slot(seq)
119
+ assert new_blk is not None
120
+ assert len(seq.block_table) == 2
121
+ seq.output_token_ids.append(100)
tests/test_scheduler.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scheduler logic tests, model-free."""
2
+ from __future__ import annotations
3
+
4
+ from tiny_vllm.block_manager import BlockManager
5
+ from tiny_vllm.config import EngineConfig, SamplingParams
6
+ from tiny_vllm.request import Sequence, SequenceStatus
7
+ from tiny_vllm.scheduler import Scheduler
8
+
9
+
10
+ def _engine_cfg(**kw) -> EngineConfig:
11
+ cfg = EngineConfig(
12
+ model="ignored", block_size=4, num_blocks=8,
13
+ max_num_seqs=4, max_num_batched_tokens=8, max_model_len=128,
14
+ )
15
+ for k, v in kw.items():
16
+ setattr(cfg, k, v)
17
+ return cfg
18
+
19
+
20
+ def _seq(ids: list[int]) -> Sequence:
21
+ return Sequence(prompt_token_ids=list(ids),
22
+ sampling_params=SamplingParams(max_tokens=4),
23
+ request_id=f"r{ids[0]}")
24
+
25
+
26
+ def test_short_prompt_fully_prefilled_in_one_step():
27
+ cfg = _engine_cfg()
28
+ bm = BlockManager(cfg.num_blocks, cfg.block_size)
29
+ sch = Scheduler(cfg, bm)
30
+ s = _seq([1, 2, 3, 4, 5]) # 5 tokens, fits in budget=8
31
+ sch.add(s)
32
+ out = sch.schedule()
33
+ assert len(out.scheduled) == 1
34
+ assert out.scheduled[0].num_tokens == 5
35
+ assert out.scheduled[0].is_prefill
36
+ assert s in sch.running
37
+
38
+
39
+ def test_chunked_prefill_splits_long_prompt_across_steps():
40
+ cfg = _engine_cfg(max_num_batched_tokens=4)
41
+ bm = BlockManager(cfg.num_blocks, cfg.block_size)
42
+ sch = Scheduler(cfg, bm)
43
+ s = _seq([1, 2, 3, 4, 5, 6, 7, 8, 9]) # 9 tokens vs budget=4
44
+ sch.add(s)
45
+ out1 = sch.schedule()
46
+ assert out1.scheduled[0].num_tokens == 4
47
+ assert s.status == SequenceStatus.PREFILLING
48
+ # Engine would update num_computed_tokens after model fwd; simulate:
49
+ s.num_computed_tokens += 4
50
+ out2 = sch.schedule()
51
+ assert out2.scheduled[0].num_tokens == 4
52
+ s.num_computed_tokens += 4
53
+ out3 = sch.schedule()
54
+ # Last chunk: 1 token left → fills, transitions to RUNNING.
55
+ assert out3.scheduled[0].num_tokens == 1
56
+ s.num_computed_tokens += 1
57
+ assert s in sch.running
58
+
59
+
60
+ def test_decodes_interleave_with_prefills():
61
+ cfg = _engine_cfg(max_num_batched_tokens=6)
62
+ bm = BlockManager(cfg.num_blocks, cfg.block_size)
63
+ sch = Scheduler(cfg, bm)
64
+
65
+ # Get a sequence fully into RUNNING state.
66
+ runner = _seq([1, 2, 3, 4, 5])
67
+ sch.add(runner)
68
+ out0 = sch.schedule()
69
+ assert out0.scheduled and out0.scheduled[0].num_tokens == 5
70
+ # Simulate model forward.
71
+ runner.num_computed_tokens = runner.prompt_len
72
+ assert runner.status == SequenceStatus.RUNNING
73
+
74
+ # New waiting seq.
75
+ waiter = _seq([100, 101, 102])
76
+ sch.add(waiter)
77
+
78
+ out = sch.schedule()
79
+ kinds = [(it.is_prefill, it.num_tokens, it.seq.seq_id) for it in out.scheduled]
80
+ # runner decodes 1 token, waiter prefills 3 — all fit in budget=6.
81
+ assert any(not it.is_prefill and it.num_tokens == 1 and it.seq is runner for it in out.scheduled)
82
+ assert any(it.is_prefill and it.num_tokens == 3 and it.seq is waiter for it in out.scheduled)
83
+
84
+
85
+ def test_preemption_triggers_when_blocks_exhaust():
86
+ """When a decoding sequence needs a new block but the pool is dry, the
87
+ scheduler preempts the youngest running seq (here, only itself) and
88
+ re-enqueues it. schedule() must not crash."""
89
+ cfg = _engine_cfg(num_blocks=2, block_size=4, max_num_batched_tokens=16)
90
+ bm = BlockManager(cfg.num_blocks, cfg.block_size)
91
+ sch = Scheduler(cfg, bm)
92
+ s1 = _seq([1, 2, 3, 4, 5, 6, 7]) # 2 blocks consumed exactly on prompt
93
+ sch.add(s1)
94
+ sch.schedule()
95
+ s1.num_computed_tokens = s1.prompt_len
96
+
97
+ # Push s1 to the brink: pretend it has decoded enough to fill block 2.
98
+ s1.output_token_ids.extend([99] * (8 - 7)) # total_len = 8, fits in 2 blocks
99
+ # Next decode (position 8) would require a 3rd block; only 2 exist.
100
+ out = sch.schedule()
101
+ # s1 should have been preempted (and may then be re-admitted in the same
102
+ # step via prefix cache — what matters is the preempt event fired).
103
+ assert s1.seq_id in out.preempted
tiny_vllm/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tiny vLLM — a minimal continuous-batching engine.
2
+
3
+ Educational reimplementation of the core vLLM/SGLang ideas:
4
+ paged KV cache, prefix caching, continuous batching with chunked prefill,
5
+ and SSE streaming over a thin HTTP layer.
6
+ """
7
+
8
+ # Lazy re-exports: importing this package should not pull in torch, so the
9
+ # lightweight block_manager/scheduler can be unit-tested without it.
10
+
11
+ from .config import EngineConfig, SamplingParams
12
+ from .request import Request, Sequence, SequenceStatus
13
+
14
+ __all__ = [
15
+ "EngineConfig",
16
+ "SamplingParams",
17
+ "LLMEngine",
18
+ "Request",
19
+ "Sequence",
20
+ "SequenceStatus",
21
+ ]
22
+
23
+
24
+ def __getattr__(name: str):
25
+ if name == "LLMEngine":
26
+ from .engine import LLMEngine
27
+ return LLMEngine
28
+ raise AttributeError(name)
tiny_vllm/block_manager.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Paged KV-cache block manager with hash-based automatic prefix caching.
2
+
3
+ Concepts (matching vLLM / SGLang terminology):
4
+
5
+ Physical block: a fixed-size slot in the KV-cache pool that holds the K and V
6
+ tensors for ``block_size`` consecutive tokens of one sequence.
7
+
8
+ Block table: per-sequence list of physical block ids that holds the
9
+ sequence's KV in logical order. Position ``p`` of the
10
+ sequence lives in physical block ``block_table[p // B]`` at
11
+ slot ``p % B``.
12
+
13
+ Prefix cache: a content-addressed lookup from
14
+ hash(prev_block_hash, tuple_of_token_ids_in_block)
15
+ to a physical block id. When two sequences share a prefix
16
+ that aligns to a block boundary, the second sequence can
17
+ point its block_table at the cached blocks instead of
18
+ recomputing KV, and the scheduler can skip those tokens.
19
+
20
+ The "chained" hash means two prefixes match iff they are identical from
21
+ position 0 — exactly the property we need for prefix sharing.
22
+
23
+ This manager is allocation-only: it does NOT store the KV tensors. The
24
+ ModelRunner owns the actual ``[num_blocks, ...]`` tensors and consults the
25
+ block tables here to know where to write/read KV.
26
+ """
27
+ from __future__ import annotations
28
+
29
+ from collections import deque
30
+ from dataclasses import dataclass
31
+ from typing import Optional
32
+
33
+ from .request import Sequence
34
+
35
+
36
+ @dataclass
37
+ class Block:
38
+ block_id: int
39
+ ref_count: int = 0
40
+ hash_key: Optional[int] = None # set when the block is full and registered
41
+
42
+
43
+ class BlockManager:
44
+ def __init__(
45
+ self,
46
+ num_blocks: int,
47
+ block_size: int,
48
+ enable_prefix_caching: bool = True,
49
+ ) -> None:
50
+ self.num_blocks = num_blocks
51
+ self.block_size = block_size
52
+ self.enable_prefix_caching = enable_prefix_caching
53
+
54
+ self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
55
+ # Two-tier free list: ephemeral (no hash) reused first, then cached
56
+ # (preserved as long as we have ephemeral capacity).
57
+ self._free_uncached: deque[int] = deque(range(num_blocks))
58
+ self._free_cached: deque[int] = deque()
59
+ self._cache: dict[int, int] = {} # hash → block_id
60
+
61
+ # Stats (visible via events).
62
+ self.prefix_cache_hits = 0
63
+ self.prefix_cache_lookups = 0
64
+
65
+ # ---- introspection --------------------------------------------------
66
+
67
+ @property
68
+ def num_free_blocks(self) -> int:
69
+ return len(self._free_uncached) + len(self._free_cached)
70
+
71
+ @property
72
+ def num_used_blocks(self) -> int:
73
+ return self.num_blocks - self.num_free_blocks
74
+
75
+ def snapshot(self) -> dict:
76
+ """Cheap dict for the event stream / UI."""
77
+ return {
78
+ "num_blocks": self.num_blocks,
79
+ "block_size": self.block_size,
80
+ "num_free_blocks": self.num_free_blocks,
81
+ "num_cached_entries": len(self._cache),
82
+ "prefix_cache_hits": self.prefix_cache_hits,
83
+ "prefix_cache_lookups": self.prefix_cache_lookups,
84
+ "ref_counts": [b.ref_count for b in self.blocks],
85
+ "hashed": [b.hash_key is not None for b in self.blocks],
86
+ }
87
+
88
+ # ---- low-level pool ops --------------------------------------------
89
+
90
+ def _block_hash(self, prev_hash: Optional[int], token_ids: tuple[int, ...]) -> int:
91
+ # Python's hash() is randomized per process but that's fine: the cache
92
+ # only lives for the engine's lifetime.
93
+ return hash((prev_hash, token_ids))
94
+
95
+ def _take_free_block(self) -> int:
96
+ if self._free_uncached:
97
+ bid = self._free_uncached.popleft()
98
+ elif self._free_cached:
99
+ bid = self._free_cached.popleft()
100
+ # Evict its cache entry — we're about to repurpose it.
101
+ blk = self.blocks[bid]
102
+ if blk.hash_key is not None:
103
+ self._cache.pop(blk.hash_key, None)
104
+ blk.hash_key = None
105
+ else:
106
+ raise RuntimeError("BlockManager out of free blocks")
107
+ blk = self.blocks[bid]
108
+ blk.ref_count = 1
109
+ return bid
110
+
111
+ def _share(self, block_id: int) -> None:
112
+ blk = self.blocks[block_id]
113
+ if blk.ref_count == 0:
114
+ # Was sitting in the cached free list; pull it out.
115
+ try:
116
+ self._free_cached.remove(block_id)
117
+ except ValueError:
118
+ pass
119
+ blk.ref_count += 1
120
+
121
+ def _release(self, block_id: int) -> None:
122
+ blk = self.blocks[block_id]
123
+ blk.ref_count -= 1
124
+ assert blk.ref_count >= 0, f"block {block_id} refcount went negative"
125
+ if blk.ref_count == 0:
126
+ if blk.hash_key is not None and self.enable_prefix_caching:
127
+ self._free_cached.append(block_id)
128
+ else:
129
+ self._free_uncached.append(block_id)
130
+
131
+ def _register(self, block_id: int, hash_key: int) -> None:
132
+ if not self.enable_prefix_caching:
133
+ return
134
+ if hash_key in self._cache:
135
+ # Two sequences independently produced the same content for
136
+ # different physical blocks. Keep the older one; this one becomes
137
+ # ephemeral so it gets reclaimed first.
138
+ return
139
+ self.blocks[block_id].hash_key = hash_key
140
+ self._cache[hash_key] = block_id
141
+
142
+ # ---- per-sequence allocation ---------------------------------------
143
+
144
+ def num_blocks_needed_for(self, num_tokens: int) -> int:
145
+ return (num_tokens + self.block_size - 1) // self.block_size
146
+
147
+ def can_allocate_initial(self, seq: Sequence) -> tuple[bool, int]:
148
+ """Worst-case allocation check for the prompt of `seq`, ignoring prefix
149
+ cache hits. Returns (ok, num_new_blocks_needed)."""
150
+ need = self.num_blocks_needed_for(seq.prompt_len)
151
+ return self.num_free_blocks >= need, need
152
+
153
+ def admit(self, seq: Sequence) -> None:
154
+ """Set up `seq` in the cache.
155
+
156
+ Walks the prompt block-by-block. For each full block of *prompt* tokens
157
+ we already know, check the prefix cache: hit → share; miss → allocate
158
+ fresh and register the hash now (we know the tokens already).
159
+
160
+ The trailing partial block (if any) is allocated fresh and left
161
+ un-hashed; it will be hashed by `finalize_step` once it fills up.
162
+ """
163
+ assert not seq.block_table, "admit called on an already-admitted sequence"
164
+ prev_hash: Optional[int] = None
165
+ cached_tokens = 0
166
+ prompt = seq.prompt_token_ids
167
+ B = self.block_size
168
+ num_full = seq.prompt_len // B
169
+
170
+ # IMPORTANT: never let prefix cache cover the entire prompt — we need
171
+ # at least one token to forward through the model to get logits for
172
+ # the first sampled token. If the full prompt block-aligns AND every
173
+ # block is cached, drop the last cached block.
174
+ cap_full = num_full
175
+ if seq.prompt_len % B == 0:
176
+ cap_full = max(0, num_full - 1)
177
+
178
+ for i in range(num_full):
179
+ tokens = tuple(prompt[i * B : (i + 1) * B])
180
+ h = self._block_hash(prev_hash, tokens)
181
+ self.prefix_cache_lookups += 1
182
+ if self.enable_prefix_caching and h in self._cache and i < cap_full:
183
+ # Cache hit.
184
+ self.prefix_cache_hits += 1
185
+ bid = self._cache[h]
186
+ self._share(bid)
187
+ seq.block_table.append(bid)
188
+ cached_tokens += B
189
+ prev_hash = h
190
+ else:
191
+ # Miss: allocate, and since the block content is fully known
192
+ # (prompt tokens), register its hash right away so the next
193
+ # request with this prefix can hit.
194
+ bid = self._take_free_block()
195
+ self._register(bid, h)
196
+ seq.block_table.append(bid)
197
+ prev_hash = h
198
+
199
+ # Trailing partial block, if any.
200
+ if seq.prompt_len % B != 0:
201
+ bid = self._take_free_block()
202
+ seq.block_table.append(bid)
203
+
204
+ seq.num_computed_tokens = cached_tokens
205
+ seq.num_cached_prefix_tokens = cached_tokens
206
+
207
+ def append_slot(self, seq: Sequence) -> Optional[int]:
208
+ """Ensure `seq` has a slot for one more token (decode path).
209
+
210
+ Returns the block_id that was newly allocated, or None if existing
211
+ capacity already covered the new token. Raises if no block available.
212
+ """
213
+ new_position = seq.total_len # 0-indexed slot we are about to write
214
+ needed_blocks = self.num_blocks_needed_for(new_position + 1)
215
+ if needed_blocks <= len(seq.block_table):
216
+ return None
217
+ if self.num_free_blocks == 0:
218
+ raise RuntimeError("out of blocks")
219
+ bid = self._take_free_block()
220
+ seq.block_table.append(bid)
221
+ return bid
222
+
223
+ def ensure_blocks_for_chunk(self, seq: Sequence, chunk_tokens: int) -> int:
224
+ """Prefill path: make sure `seq.block_table` covers
225
+ `seq.num_computed_tokens + chunk_tokens` tokens.
226
+
227
+ Returns number of newly-allocated blocks.
228
+ """
229
+ target = seq.num_computed_tokens + chunk_tokens
230
+ needed = self.num_blocks_needed_for(target)
231
+ new_alloc = 0
232
+ while len(seq.block_table) < needed:
233
+ bid = self._take_free_block()
234
+ seq.block_table.append(bid)
235
+ new_alloc += 1
236
+ return new_alloc
237
+
238
+ def free(self, seq: Sequence) -> None:
239
+ for bid in seq.block_table:
240
+ self._release(bid)
241
+ seq.block_table.clear()
242
+
243
+ # ---- post-step bookkeeping -----------------------------------------
244
+
245
+ def register_filled_blocks(self, seq: Sequence, prev_computed: int) -> None:
246
+ """After a forward pass, hash & register any blocks that just became
247
+ full so future requests can prefix-cache them."""
248
+ if not self.enable_prefix_caching:
249
+ return
250
+ B = self.block_size
251
+ # Re-chain hashes from the start so we always have prev_hash correct.
252
+ prev_hash: Optional[int] = None
253
+ for i in range(seq.num_computed_tokens // B):
254
+ bid = seq.block_table[i]
255
+ blk = self.blocks[bid]
256
+ if blk.hash_key is not None:
257
+ prev_hash = blk.hash_key
258
+ continue
259
+ # This block became full in this step (or earlier but unhashed).
260
+ if (i + 1) * B > seq.num_computed_tokens:
261
+ break # not actually full yet — defensive
262
+ tokens = tuple(seq.get_token(i * B + j) for j in range(B))
263
+ h = self._block_hash(prev_hash, tokens)
264
+ self._register(bid, h)
265
+ prev_hash = h
tiny_vllm/config.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional
5
+
6
+
7
+ @dataclass
8
+ class EngineConfig:
9
+ # Model
10
+ model: str = "Qwen/Qwen2.5-0.5B-Instruct"
11
+ dtype: str = "float32" # "float32" on CPU; "float16"/"bfloat16" on GPU
12
+ device: str = "cpu" # "cpu" or "cuda"
13
+ trust_remote_code: bool = False
14
+
15
+ # Paged KV cache
16
+ block_size: int = 16 # tokens per physical block
17
+ num_blocks: int = 512 # total physical blocks in the pool
18
+ enable_prefix_caching: bool = True
19
+
20
+ # Scheduler
21
+ max_num_seqs: int = 16 # max sequences in a batch
22
+ max_num_batched_tokens: int = 512 # total tokens processed per step
23
+ max_model_len: int = 2048 # upper bound on prompt + generated tokens
24
+
25
+ # Logging / events
26
+ emit_events: bool = True # produce engine events for the UI
27
+ event_buffer: int = 256
28
+
29
+ def __post_init__(self) -> None:
30
+ if self.max_num_batched_tokens < self.block_size:
31
+ raise ValueError(
32
+ "max_num_batched_tokens must be >= block_size "
33
+ f"({self.max_num_batched_tokens} < {self.block_size})"
34
+ )
35
+
36
+
37
+ @dataclass
38
+ class SamplingParams:
39
+ max_tokens: int = 64
40
+ temperature: float = 1.0
41
+ top_p: float = 1.0
42
+ top_k: int = -1 # -1 disables top-k
43
+ stop_token_ids: list[int] = field(default_factory=list)
44
+ seed: Optional[int] = None
45
+ ignore_eos: bool = False
46
+
47
+ @property
48
+ def is_greedy(self) -> bool:
49
+ return self.temperature <= 0.0
tiny_vllm/engine.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LLMEngine: orchestrates scheduler + block manager + model runner + sampler.
2
+
3
+ Public surface:
4
+
5
+ engine = LLMEngine(EngineConfig(...))
6
+ await engine.startup()
7
+ rid = engine.add_request(prompt_text, SamplingParams(...))
8
+ async for delta in engine.stream(rid):
9
+ ...
10
+
11
+ A single background task (`_run_loop`) drives the model. Per-request output
12
+ goes through asyncio queues so the HTTP layer can stream incrementally. A
13
+ second pub/sub channel emits engine-state snapshots for the visualization UI.
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import asyncio
18
+ import itertools
19
+ import time
20
+ import uuid
21
+ from collections import deque
22
+ from dataclasses import dataclass, field
23
+ from typing import AsyncIterator, Optional
24
+
25
+ from .block_manager import BlockManager
26
+ from .config import EngineConfig, SamplingParams
27
+ from .model_runner import ModelRunner
28
+ from .request import Sequence, SequenceStatus
29
+ from .sampler import Sampler
30
+ from .scheduler import Scheduler
31
+
32
+
33
+ @dataclass
34
+ class StreamItem:
35
+ request_id: str
36
+ new_text: str
37
+ new_token_ids: list[int]
38
+ finished: bool
39
+ finish_reason: Optional[str] = None
40
+ cumulative_text: str = ""
41
+
42
+
43
+ @dataclass
44
+ class EngineEvent:
45
+ step: int
46
+ timestamp: float
47
+ type: str
48
+ payload: dict = field(default_factory=dict)
49
+
50
+
51
+ class LLMEngine:
52
+ def __init__(self, config: EngineConfig) -> None:
53
+ self.config = config
54
+ self.model_runner: Optional[ModelRunner] = None
55
+ self.block_manager: Optional[BlockManager] = None
56
+ self.scheduler: Optional[Scheduler] = None
57
+ self.sampler: Optional[Sampler] = None
58
+
59
+ # request_id → asyncio.Queue[StreamItem]
60
+ self._output_queues: dict[str, asyncio.Queue[StreamItem]] = {}
61
+ # request_id → Sequence (for inspection / abort)
62
+ self._sequences: dict[str, Sequence] = {}
63
+ # tracker for incremental detokenization
64
+ self._prev_text_len: dict[str, int] = {}
65
+ # event subscribers
66
+ self._event_subscribers: list[asyncio.Queue[EngineEvent]] = []
67
+ # control
68
+ self._stop = asyncio.Event()
69
+ self._step_idx = 0
70
+ self._run_task: Optional[asyncio.Task] = None
71
+ self._wake = asyncio.Event()
72
+
73
+ # ---- lifecycle ------------------------------------------------------
74
+
75
+ async def startup(self) -> None:
76
+ # Heavy: model load happens in a worker thread so we don't block the loop.
77
+ loop = asyncio.get_running_loop()
78
+
79
+ def _build() -> ModelRunner:
80
+ return ModelRunner(self.config)
81
+
82
+ self.model_runner = await loop.run_in_executor(None, _build)
83
+ self.block_manager = BlockManager(
84
+ num_blocks=self.config.num_blocks,
85
+ block_size=self.config.block_size,
86
+ enable_prefix_caching=self.config.enable_prefix_caching,
87
+ )
88
+ self.scheduler = Scheduler(self.config, self.block_manager)
89
+ self.sampler = Sampler(self.model_runner.device)
90
+ self._run_task = asyncio.create_task(self._run_loop())
91
+
92
+ async def shutdown(self) -> None:
93
+ self._stop.set()
94
+ self._wake.set()
95
+ if self._run_task is not None:
96
+ try:
97
+ await asyncio.wait_for(self._run_task, timeout=5)
98
+ except asyncio.TimeoutError:
99
+ self._run_task.cancel()
100
+
101
+ # ---- request submission --------------------------------------------
102
+
103
+ def add_request(
104
+ self,
105
+ prompt: str | list[int],
106
+ sampling_params: SamplingParams,
107
+ request_id: Optional[str] = None,
108
+ ) -> str:
109
+ if self.model_runner is None:
110
+ raise RuntimeError("engine not started")
111
+ if isinstance(prompt, str):
112
+ token_ids = self.model_runner.encode(prompt)
113
+ else:
114
+ token_ids = list(prompt)
115
+ if not token_ids:
116
+ raise ValueError("empty prompt")
117
+ if len(token_ids) >= self.config.max_model_len:
118
+ raise ValueError(
119
+ f"prompt length {len(token_ids)} >= max_model_len {self.config.max_model_len}"
120
+ )
121
+ rid = request_id or uuid.uuid4().hex
122
+ seq = Sequence(
123
+ prompt_token_ids=token_ids,
124
+ sampling_params=sampling_params,
125
+ request_id=rid,
126
+ )
127
+ self._sequences[rid] = seq
128
+ self._output_queues[rid] = asyncio.Queue()
129
+ self._prev_text_len[rid] = 0
130
+ assert self.scheduler is not None
131
+ self.scheduler.add(seq)
132
+ self._wake.set()
133
+ return rid
134
+
135
+ def abort(self, request_id: str) -> bool:
136
+ seq = self._sequences.get(request_id)
137
+ if seq is None:
138
+ return False
139
+ assert self.scheduler is not None
140
+ ok = self.scheduler.abort(seq.seq_id)
141
+ if ok:
142
+ self._close_request(request_id, finish_reason="abort")
143
+ return ok
144
+
145
+ async def stream(self, request_id: str) -> AsyncIterator[StreamItem]:
146
+ q = self._output_queues.get(request_id)
147
+ if q is None:
148
+ raise KeyError(request_id)
149
+ while True:
150
+ item = await q.get()
151
+ yield item
152
+ if item.finished:
153
+ break
154
+
155
+ # ---- event subscriptions -------------------------------------------
156
+
157
+ def subscribe_events(self) -> asyncio.Queue[EngineEvent]:
158
+ q: asyncio.Queue[EngineEvent] = asyncio.Queue(maxsize=self.config.event_buffer)
159
+ self._event_subscribers.append(q)
160
+ return q
161
+
162
+ def unsubscribe_events(self, q: asyncio.Queue[EngineEvent]) -> None:
163
+ try:
164
+ self._event_subscribers.remove(q)
165
+ except ValueError:
166
+ pass
167
+
168
+ def _emit(self, event_type: str, payload: dict) -> None:
169
+ if not self.config.emit_events or not self._event_subscribers:
170
+ return
171
+ ev = EngineEvent(
172
+ step=self._step_idx,
173
+ timestamp=time.monotonic(),
174
+ type=event_type,
175
+ payload=payload,
176
+ )
177
+ for q in list(self._event_subscribers):
178
+ try:
179
+ q.put_nowait(ev)
180
+ except asyncio.QueueFull:
181
+ # Drop oldest, push new.
182
+ try:
183
+ q.get_nowait()
184
+ except asyncio.QueueEmpty:
185
+ pass
186
+ try:
187
+ q.put_nowait(ev)
188
+ except asyncio.QueueFull:
189
+ pass
190
+
191
+ # ---- inspection ----------------------------------------------------
192
+
193
+ def snapshot(self) -> dict:
194
+ assert self.block_manager is not None and self.scheduler is not None
195
+ def seq_view(s: Sequence) -> dict:
196
+ return {
197
+ "seq_id": s.seq_id,
198
+ "request_id": s.request_id,
199
+ "status": s.status.value,
200
+ "prompt_len": s.prompt_len,
201
+ "num_generated": len(s.output_token_ids),
202
+ "num_computed_tokens": s.num_computed_tokens,
203
+ "num_cached_prefix_tokens": s.num_cached_prefix_tokens,
204
+ "block_table": list(s.block_table),
205
+ }
206
+ return {
207
+ "step": self._step_idx,
208
+ "block_pool": self.block_manager.snapshot(),
209
+ "waiting": [seq_view(s) for s in self.scheduler.waiting],
210
+ "running": [seq_view(s) for s in self.scheduler.running],
211
+ "config": {
212
+ "model": self.config.model,
213
+ "block_size": self.config.block_size,
214
+ "num_blocks": self.config.num_blocks,
215
+ "max_num_seqs": self.config.max_num_seqs,
216
+ "max_num_batched_tokens": self.config.max_num_batched_tokens,
217
+ "prefix_caching": self.config.enable_prefix_caching,
218
+ },
219
+ }
220
+
221
+ # ---- main loop -----------------------------------------------------
222
+
223
+ async def _run_loop(self) -> None:
224
+ assert self.scheduler is not None and self.model_runner is not None
225
+ loop = asyncio.get_running_loop()
226
+ while not self._stop.is_set():
227
+ if not self.scheduler.has_work:
228
+ self._wake.clear()
229
+ try:
230
+ await asyncio.wait_for(self._wake.wait(), timeout=1.0)
231
+ except asyncio.TimeoutError:
232
+ pass
233
+ continue
234
+
235
+ self._step_idx += 1
236
+ t0 = time.monotonic()
237
+ sched = self.scheduler.schedule()
238
+ if sched.is_empty:
239
+ # Nothing got through this step (probably starved on blocks).
240
+ await asyncio.sleep(0.01)
241
+ continue
242
+
243
+ model_input = self.model_runner.prepare_input(sched.scheduled)
244
+ # Run blocking model forward off-thread.
245
+ logits = await loop.run_in_executor(None, self.model_runner.execute, model_input)
246
+
247
+ # Update num_computed_tokens AFTER forward (the K/V is now stored).
248
+ for item in sched.scheduled:
249
+ item.seq.num_computed_tokens += item.num_tokens
250
+
251
+ # Sample only for sequences that have finished prefill (i.e., the
252
+ # last token in their chunk is the *final* prompt token).
253
+ sampling_items = [item for item in sched.scheduled
254
+ if item.seq.num_computed_tokens >= item.seq.prompt_len]
255
+ sampling_indices = [i for i, item in enumerate(sched.scheduled)
256
+ if item.seq.num_computed_tokens >= item.seq.prompt_len]
257
+
258
+ new_tokens: dict[int, int] = {}
259
+ if sampling_items:
260
+ import torch # local; cheap
261
+ sampling_logits = logits.index_select(
262
+ 0, torch.tensor(sampling_indices, device=logits.device)
263
+ )
264
+ params = [item.seq.sampling_params for item in sampling_items]
265
+ generators = [
266
+ (torch.Generator(device=logits.device).manual_seed(item.seq.sampling_params.seed)
267
+ if item.seq.sampling_params.seed is not None else None)
268
+ for item in sampling_items
269
+ ]
270
+ token_ids = self.sampler.sample(sampling_logits, params, generators)
271
+ for item, tok in zip(sampling_items, token_ids):
272
+ new_tokens[item.seq.seq_id] = tok
273
+
274
+ # Apply new tokens, check stopping, register filled blocks.
275
+ assert self.block_manager is not None
276
+ finished_now: list[Sequence] = []
277
+ for item in sched.scheduled:
278
+ seq = item.seq
279
+ if seq.seq_id in new_tokens:
280
+ tok = new_tokens[seq.seq_id]
281
+ seq.append_output_token(tok)
282
+ # The just-produced token's KV will be written on the NEXT
283
+ # step (when this token is the input). But the new token
284
+ # may complete a block once its KV lands; we hash blocks
285
+ # only after their KV exists, so post-forward in the next
286
+ # step is the right time. Here we register newly-filled
287
+ # blocks based on the just-finalized num_computed_tokens.
288
+ self.block_manager.register_filled_blocks(seq, prev_computed=0)
289
+
290
+ if self._should_stop(seq, tok):
291
+ seq.status = SequenceStatus.FINISHED
292
+ seq.finish_reason = self._stop_reason(seq, tok)
293
+ finished_now.append(seq)
294
+ else:
295
+ # Still in prefill; just register newly filled prompt blocks.
296
+ self.block_manager.register_filled_blocks(seq, prev_computed=0)
297
+
298
+ # Free finished sequences.
299
+ for seq in finished_now:
300
+ if seq in self.scheduler.running:
301
+ self.scheduler.running.remove(seq)
302
+ self.block_manager.free(seq)
303
+
304
+ # Emit outputs to per-request queues.
305
+ for item in sched.scheduled:
306
+ seq = item.seq
307
+ rid = seq.request_id
308
+ if seq.seq_id in new_tokens or seq in finished_now:
309
+ new_text, new_text_len = self.model_runner.detokenize_incremental(
310
+ seq.all_token_ids(), self._prev_text_len.get(rid, 0)
311
+ )
312
+ self._prev_text_len[rid] = new_text_len
313
+ is_done = seq.status == SequenceStatus.FINISHED
314
+ new_toks = [new_tokens[seq.seq_id]] if seq.seq_id in new_tokens else []
315
+ si = StreamItem(
316
+ request_id=rid,
317
+ new_text=new_text,
318
+ new_token_ids=new_toks,
319
+ finished=is_done,
320
+ finish_reason=seq.finish_reason,
321
+ cumulative_text=self.model_runner.tokenizer.decode(
322
+ seq.output_token_ids, skip_special_tokens=True
323
+ ),
324
+ )
325
+ q = self._output_queues.get(rid)
326
+ if q is not None:
327
+ await q.put(si)
328
+ if is_done:
329
+ # Clean up.
330
+ self._sequences.pop(rid, None)
331
+ self._prev_text_len.pop(rid, None)
332
+
333
+ # Emit engine events for the UI.
334
+ self._emit("step", {
335
+ "duration_ms": (time.monotonic() - t0) * 1000,
336
+ "num_seqs": len(sched.scheduled),
337
+ "num_tokens": sched.total_tokens,
338
+ "num_prefill_seqs": sum(1 for it in sched.scheduled if it.is_prefill),
339
+ "num_decode_seqs": sum(1 for it in sched.scheduled if not it.is_prefill),
340
+ "preempted": sched.preempted,
341
+ "newly_admitted": sched.newly_admitted,
342
+ "finished": [s.request_id for s in finished_now],
343
+ "snapshot": self.snapshot(),
344
+ })
345
+
346
+ # Yield control between steps so the HTTP layer can ship bytes.
347
+ await asyncio.sleep(0)
348
+
349
+ # ---- helpers -------------------------------------------------------
350
+
351
+ def _should_stop(self, seq: Sequence, last_token: int) -> bool:
352
+ sp = seq.sampling_params
353
+ if len(seq.output_token_ids) >= sp.max_tokens:
354
+ return True
355
+ if not sp.ignore_eos:
356
+ eos = self.model_runner.eos_token_id if self.model_runner else None
357
+ if eos is not None and last_token == eos:
358
+ return True
359
+ if last_token in sp.stop_token_ids:
360
+ return True
361
+ if seq.total_len >= self.config.max_model_len:
362
+ return True
363
+ return False
364
+
365
+ def _stop_reason(self, seq: Sequence, last_token: int) -> str:
366
+ sp = seq.sampling_params
367
+ if len(seq.output_token_ids) >= sp.max_tokens:
368
+ return "length"
369
+ if seq.total_len >= self.config.max_model_len:
370
+ return "length"
371
+ return "stop"
372
+
373
+ def _close_request(self, request_id: str, finish_reason: str) -> None:
374
+ q = self._output_queues.get(request_id)
375
+ if q is None:
376
+ return
377
+ q.put_nowait(StreamItem(
378
+ request_id=request_id,
379
+ new_text="",
380
+ new_token_ids=[],
381
+ finished=True,
382
+ finish_reason=finish_reason,
383
+ ))
384
+ self._sequences.pop(request_id, None)
385
+ self._prev_text_len.pop(request_id, None)
tiny_vllm/model_runner.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal Qwen2 forward pass that consumes a paged KV cache.
2
+
3
+ We deliberately re-implement Qwen2 from scratch (rather than using the HF
4
+ forward) so the path of K/V tensors through the cache is fully visible.
5
+ Weights are loaded from a HuggingFace checkpoint by matching parameter names.
6
+
7
+ Layout of inputs per step ("varlen" packing):
8
+
9
+ input_ids [T_total] concatenated tokens for all seqs
10
+ positions [T_total] position-in-sequence of each token
11
+ slot_mapping [T_total] where to write new K/V in the cache
12
+ segments list of (q_start, q_end, block_table, k_len, seq_id)
13
+
14
+ For attention, we loop over `segments`: gather each sequence's full K/V from
15
+ its block table, run SDPA, scatter the result back into a flat buffer. All
16
+ other ops (norms, MLP, projections) run on the full packed tensor.
17
+ """
18
+ from __future__ import annotations
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Optional
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+ from .config import EngineConfig
28
+ from .paged_kv import PagedKVCache
29
+ from .request import Sequence
30
+
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # Qwen2 building blocks
34
+ # ---------------------------------------------------------------------------
35
+
36
+
37
+ class Qwen2RMSNorm(nn.Module):
38
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
39
+ super().__init__()
40
+ self.weight = nn.Parameter(torch.ones(hidden_size))
41
+ self.eps = eps
42
+
43
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
44
+ # x: [..., hidden]
45
+ dtype = x.dtype
46
+ x = x.to(torch.float32)
47
+ var = x.pow(2).mean(-1, keepdim=True)
48
+ x = x * torch.rsqrt(var + self.eps)
49
+ return (self.weight * x).to(dtype)
50
+
51
+
52
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
53
+ half = x.size(-1) // 2
54
+ x1, x2 = x[..., :half], x[..., half:]
55
+ return torch.cat((-x2, x1), dim=-1)
56
+
57
+
58
+ def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
59
+ """x: [T, H, D], cos/sin: [T, D] → returns [T, H, D]."""
60
+ cos = cos.unsqueeze(1)
61
+ sin = sin.unsqueeze(1)
62
+ return (x * cos) + (_rotate_half(x) * sin)
63
+
64
+
65
+ class Qwen2MLP(nn.Module):
66
+ def __init__(self, hidden_size: int, intermediate_size: int) -> None:
67
+ super().__init__()
68
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
69
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
70
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
74
+
75
+
76
+ @dataclass
77
+ class AttnSegment:
78
+ """One sequence's slice of the packed batch."""
79
+ q_start: int # start index in the packed tensor
80
+ q_end: int # exclusive
81
+ block_table: list[int] # KV blocks for this sequence
82
+ k_len: int # total K length (= num_computed_tokens + q_len)
83
+
84
+
85
+ class Qwen2Attention(nn.Module):
86
+ def __init__(
87
+ self,
88
+ hidden_size: int,
89
+ num_heads: int,
90
+ num_kv_heads: int,
91
+ head_dim: int,
92
+ layer_idx: int,
93
+ ) -> None:
94
+ super().__init__()
95
+ self.num_heads = num_heads
96
+ self.num_kv_heads = num_kv_heads
97
+ self.head_dim = head_dim
98
+ self.layer_idx = layer_idx
99
+ self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=True)
100
+ self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=True)
101
+ self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=True)
102
+ self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)
103
+ self.scale = head_dim ** -0.5
104
+
105
+ def forward(
106
+ self,
107
+ hidden_states: torch.Tensor, # [T, hidden]
108
+ positions: torch.Tensor, # [T] long
109
+ slot_mapping: torch.Tensor, # [T] long
110
+ cos_table: torch.Tensor, # [max_pos, head_dim]
111
+ sin_table: torch.Tensor, # [max_pos, head_dim]
112
+ segments: list[AttnSegment],
113
+ kv_cache: PagedKVCache,
114
+ ) -> torch.Tensor:
115
+ T = hidden_states.size(0)
116
+ q = self.q_proj(hidden_states).view(T, self.num_heads, self.head_dim)
117
+ k = self.k_proj(hidden_states).view(T, self.num_kv_heads, self.head_dim)
118
+ v = self.v_proj(hidden_states).view(T, self.num_kv_heads, self.head_dim)
119
+
120
+ cos = cos_table.index_select(0, positions) # [T, head_dim]
121
+ sin = sin_table.index_select(0, positions)
122
+ q = _apply_rope(q, cos, sin)
123
+ k = _apply_rope(k, cos, sin)
124
+
125
+ # Write the NEW K/V into the paged cache before reading it back.
126
+ kv_cache.write(self.layer_idx, k, v, slot_mapping)
127
+
128
+ out = torch.empty_like(q) # [T, num_heads, head_dim]
129
+ rep = self.num_heads // self.num_kv_heads # GQA fan-out
130
+
131
+ for seg in segments:
132
+ q_slice = q[seg.q_start:seg.q_end] # [q_len, H_q, D]
133
+ k_full, v_full = kv_cache.gather(self.layer_idx, seg.block_table, seg.k_len)
134
+ # GQA: expand K/V heads to match Q heads.
135
+ if rep > 1:
136
+ k_full = k_full.repeat_interleave(rep, dim=1)
137
+ v_full = v_full.repeat_interleave(rep, dim=1)
138
+
139
+ q_len = q_slice.size(0)
140
+ k_len = seg.k_len
141
+ num_past = k_len - q_len
142
+
143
+ # Causal mask: Q at logical position (num_past + i) attends to K at
144
+ # positions [0, num_past + i]. True = participate (SDPA convention).
145
+ idx_q = torch.arange(q_len, device=q.device).unsqueeze(1) + num_past
146
+ idx_k = torch.arange(k_len, device=q.device).unsqueeze(0)
147
+ attn_mask = idx_k <= idx_q # [q_len, k_len]
148
+
149
+ # SDPA wants [..., heads, q_len, head_dim]. Reshape and run.
150
+ q_h = q_slice.transpose(0, 1).unsqueeze(0) # [1, H, q_len, D]
151
+ k_h = k_full.transpose(0, 1).unsqueeze(0) # [1, H, k_len, D]
152
+ v_h = v_full.transpose(0, 1).unsqueeze(0)
153
+ attn = F.scaled_dot_product_attention(
154
+ q_h, k_h, v_h,
155
+ attn_mask=attn_mask.unsqueeze(0).unsqueeze(0), # [1,1,q_len,k_len]
156
+ scale=self.scale,
157
+ ) # [1, H, q_len, D]
158
+ out[seg.q_start:seg.q_end] = attn.squeeze(0).transpose(0, 1)
159
+
160
+ return self.o_proj(out.reshape(T, self.num_heads * self.head_dim))
161
+
162
+
163
+ class Qwen2DecoderLayer(nn.Module):
164
+ def __init__(self, cfg: dict, layer_idx: int) -> None:
165
+ super().__init__()
166
+ self.input_layernorm = Qwen2RMSNorm(cfg["hidden_size"], eps=cfg["rms_norm_eps"])
167
+ self.self_attn = Qwen2Attention(
168
+ hidden_size=cfg["hidden_size"],
169
+ num_heads=cfg["num_attention_heads"],
170
+ num_kv_heads=cfg["num_key_value_heads"],
171
+ head_dim=cfg["head_dim"],
172
+ layer_idx=layer_idx,
173
+ )
174
+ self.post_attention_layernorm = Qwen2RMSNorm(cfg["hidden_size"], eps=cfg["rms_norm_eps"])
175
+ self.mlp = Qwen2MLP(cfg["hidden_size"], cfg["intermediate_size"])
176
+
177
+ def forward(self, hidden_states, positions, slot_mapping, cos_table, sin_table, segments, kv_cache):
178
+ residual = hidden_states
179
+ h = self.input_layernorm(hidden_states)
180
+ h = self.self_attn(h, positions, slot_mapping, cos_table, sin_table, segments, kv_cache)
181
+ hidden_states = residual + h
182
+
183
+ residual = hidden_states
184
+ h = self.post_attention_layernorm(hidden_states)
185
+ h = self.mlp(h)
186
+ return residual + h
187
+
188
+
189
+ class Qwen2Model(nn.Module):
190
+ def __init__(self, cfg: dict) -> None:
191
+ super().__init__()
192
+ self.cfg = cfg
193
+ self.embed_tokens = nn.Embedding(cfg["vocab_size"], cfg["hidden_size"])
194
+ self.layers = nn.ModuleList(
195
+ [Qwen2DecoderLayer(cfg, i) for i in range(cfg["num_hidden_layers"])]
196
+ )
197
+ self.norm = Qwen2RMSNorm(cfg["hidden_size"], eps=cfg["rms_norm_eps"])
198
+
199
+ def forward(self, input_ids, positions, slot_mapping, cos_table, sin_table, segments, kv_cache):
200
+ h = self.embed_tokens(input_ids)
201
+ for layer in self.layers:
202
+ h = layer(h, positions, slot_mapping, cos_table, sin_table, segments, kv_cache)
203
+ return self.norm(h)
204
+
205
+
206
+ class Qwen2ForCausalLM(nn.Module):
207
+ def __init__(self, cfg: dict) -> None:
208
+ super().__init__()
209
+ self.model = Qwen2Model(cfg)
210
+ self.lm_head = nn.Linear(cfg["hidden_size"], cfg["vocab_size"], bias=False)
211
+ self.cfg = cfg
212
+
213
+ def tie_weights(self) -> None:
214
+ self.lm_head.weight = self.model.embed_tokens.weight
215
+
216
+
217
+ # ---------------------------------------------------------------------------
218
+ # ModelRunner: prepares inputs, runs forward, extracts last-token logits.
219
+ # ---------------------------------------------------------------------------
220
+
221
+
222
+ @dataclass
223
+ class ModelInput:
224
+ input_ids: torch.Tensor
225
+ positions: torch.Tensor
226
+ slot_mapping: torch.Tensor
227
+ segments: list[AttnSegment]
228
+ # Index in the packed batch of the LAST token of each scheduled seq —
229
+ # that's where we'll read logits from for sampling.
230
+ last_token_indices: torch.Tensor
231
+
232
+
233
+ class ModelRunner:
234
+ def __init__(self, config: EngineConfig) -> None:
235
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
236
+
237
+ self.config = config
238
+ self.device = torch.device(config.device)
239
+ self.dtype = {
240
+ "float32": torch.float32,
241
+ "float16": torch.float16,
242
+ "bfloat16": torch.bfloat16,
243
+ }[config.dtype]
244
+
245
+ hf_cfg = AutoConfig.from_pretrained(
246
+ config.model, trust_remote_code=config.trust_remote_code
247
+ )
248
+ self.tokenizer = AutoTokenizer.from_pretrained(
249
+ config.model, trust_remote_code=config.trust_remote_code
250
+ )
251
+
252
+ model_type = getattr(hf_cfg, "model_type", "?")
253
+ if model_type not in ("qwen2", "qwen2_moe", "llama"):
254
+ # Llama-style works too because the math is identical; we issue a
255
+ # warning rather than a hard fail.
256
+ print(f"[tiny_vllm] WARNING: model_type={model_type!r}; expected qwen2-like. "
257
+ "Continuing — assuming Llama-compatible config.")
258
+
259
+ head_dim = getattr(hf_cfg, "head_dim", hf_cfg.hidden_size // hf_cfg.num_attention_heads)
260
+ cfg = {
261
+ "vocab_size": hf_cfg.vocab_size,
262
+ "hidden_size": hf_cfg.hidden_size,
263
+ "intermediate_size": hf_cfg.intermediate_size,
264
+ "num_hidden_layers": hf_cfg.num_hidden_layers,
265
+ "num_attention_heads": hf_cfg.num_attention_heads,
266
+ "num_key_value_heads": getattr(hf_cfg, "num_key_value_heads",
267
+ hf_cfg.num_attention_heads),
268
+ "head_dim": head_dim,
269
+ "rms_norm_eps": getattr(hf_cfg, "rms_norm_eps", 1e-6),
270
+ "rope_theta": getattr(hf_cfg, "rope_theta", 10000.0),
271
+ "max_position_embeddings": getattr(hf_cfg, "max_position_embeddings", 4096),
272
+ "tie_word_embeddings": getattr(hf_cfg, "tie_word_embeddings", False),
273
+ }
274
+ self.model_cfg = cfg
275
+
276
+ # Build our own model, then copy HF weights into it.
277
+ model = Qwen2ForCausalLM(cfg).to(self.device, self.dtype)
278
+ hf_model = AutoModelForCausalLM.from_pretrained(
279
+ config.model, torch_dtype=self.dtype,
280
+ trust_remote_code=config.trust_remote_code,
281
+ )
282
+ missing, unexpected = model.load_state_dict(hf_model.state_dict(), strict=False)
283
+ if cfg["tie_word_embeddings"] and "lm_head.weight" in (missing or []):
284
+ model.tie_weights()
285
+ del hf_model
286
+ model.eval()
287
+ for p in model.parameters():
288
+ p.requires_grad_(False)
289
+ self.model = model
290
+
291
+ # Precompute RoPE tables.
292
+ max_pos = min(cfg["max_position_embeddings"], config.max_model_len)
293
+ inv_freq = 1.0 / (
294
+ cfg["rope_theta"]
295
+ ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
296
+ )
297
+ t = torch.arange(max_pos, dtype=torch.float32)
298
+ freqs = torch.outer(t, inv_freq) # [max_pos, head_dim/2]
299
+ emb = torch.cat((freqs, freqs), dim=-1) # [max_pos, head_dim]
300
+ self.cos_table = emb.cos().to(self.device, self.dtype)
301
+ self.sin_table = emb.sin().to(self.device, self.dtype)
302
+
303
+ # Paged KV cache pool.
304
+ self.kv_cache = PagedKVCache(
305
+ num_layers=cfg["num_hidden_layers"],
306
+ num_blocks=config.num_blocks,
307
+ block_size=config.block_size,
308
+ num_kv_heads=cfg["num_key_value_heads"],
309
+ head_dim=head_dim,
310
+ dtype=self.dtype,
311
+ device=self.device,
312
+ )
313
+
314
+ # ---- input building ------------------------------------------------
315
+
316
+ def prepare_input(self, scheduled) -> ModelInput:
317
+ """`scheduled` is a list of (Sequence, num_tokens, is_prefill) triples
318
+ from the scheduler."""
319
+ input_ids: list[int] = []
320
+ positions: list[int] = []
321
+ slot_mapping: list[int] = []
322
+ segments: list[AttnSegment] = []
323
+ last_indices: list[int] = []
324
+
325
+ cursor = 0
326
+ B = self.config.block_size
327
+ for item in scheduled:
328
+ seq = item.seq
329
+ n = item.num_tokens
330
+ # Logical token positions this step processes.
331
+ start_pos = seq.num_computed_tokens
332
+ for off in range(n):
333
+ pos = start_pos + off
334
+ input_ids.append(seq.get_token(pos))
335
+ positions.append(pos)
336
+ block_id = seq.block_table[pos // B]
337
+ slot_mapping.append(block_id * B + (pos % B))
338
+
339
+ q_end = cursor + n
340
+ segments.append(AttnSegment(
341
+ q_start=cursor,
342
+ q_end=q_end,
343
+ block_table=list(seq.block_table),
344
+ k_len=start_pos + n,
345
+ ))
346
+ last_indices.append(q_end - 1)
347
+ cursor = q_end
348
+
349
+ return ModelInput(
350
+ input_ids=torch.tensor(input_ids, dtype=torch.long, device=self.device),
351
+ positions=torch.tensor(positions, dtype=torch.long, device=self.device),
352
+ slot_mapping=torch.tensor(slot_mapping, dtype=torch.long, device=self.device),
353
+ segments=segments,
354
+ last_token_indices=torch.tensor(last_indices, dtype=torch.long, device=self.device),
355
+ )
356
+
357
+ # ---- forward -------------------------------------------------------
358
+
359
+ @torch.inference_mode()
360
+ def execute(self, model_input: ModelInput) -> torch.Tensor:
361
+ """Run one forward pass. Returns logits for the LAST token of each
362
+ scheduled sequence: shape [num_seqs, vocab_size]."""
363
+ hidden = self.model.model(
364
+ input_ids=model_input.input_ids,
365
+ positions=model_input.positions,
366
+ slot_mapping=model_input.slot_mapping,
367
+ cos_table=self.cos_table,
368
+ sin_table=self.sin_table,
369
+ segments=model_input.segments,
370
+ kv_cache=self.kv_cache,
371
+ ) # [T, hidden]
372
+ last_hidden = hidden.index_select(0, model_input.last_token_indices)
373
+ logits = self.model.lm_head(last_hidden) # [num_seqs, vocab]
374
+ return logits
375
+
376
+ # ---- helpers -------------------------------------------------------
377
+
378
+ @property
379
+ def eos_token_id(self) -> Optional[int]:
380
+ return self.tokenizer.eos_token_id
381
+
382
+ def encode(self, text: str) -> list[int]:
383
+ return self.tokenizer.encode(text, add_special_tokens=False)
384
+
385
+ def decode(self, token_ids: list[int]) -> str:
386
+ return self.tokenizer.decode(token_ids, skip_special_tokens=True)
387
+
388
+ def detokenize_incremental(self, full_ids: list[int], prev_text_len: int) -> tuple[str, int]:
389
+ """Detokenize the full list, return the new text added since last call
390
+ and the new total length."""
391
+ text = self.tokenizer.decode(full_ids, skip_special_tokens=True)
392
+ return text[prev_text_len:], len(text)
tiny_vllm/paged_kv.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """The actual KV tensor pool that the BlockManager indexes into.
2
+
3
+ We store one ``[num_blocks, block_size, num_kv_heads, head_dim]`` tensor per
4
+ layer for K and V. The block_manager owns the *allocation* of block ids; this
5
+ class owns the *bytes*. Reads and writes happen by (block_id, offset).
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import torch
10
+
11
+
12
+ class PagedKVCache:
13
+ def __init__(
14
+ self,
15
+ num_layers: int,
16
+ num_blocks: int,
17
+ block_size: int,
18
+ num_kv_heads: int,
19
+ head_dim: int,
20
+ dtype: torch.dtype,
21
+ device: torch.device,
22
+ ) -> None:
23
+ self.num_layers = num_layers
24
+ self.num_blocks = num_blocks
25
+ self.block_size = block_size
26
+ self.num_kv_heads = num_kv_heads
27
+ self.head_dim = head_dim
28
+ self.dtype = dtype
29
+ self.device = device
30
+ shape = (num_blocks, block_size, num_kv_heads, head_dim)
31
+ self.k_cache = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(num_layers)]
32
+ self.v_cache = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(num_layers)]
33
+
34
+ def write(
35
+ self,
36
+ layer_id: int,
37
+ k: torch.Tensor, # [T, num_kv_heads, head_dim]
38
+ v: torch.Tensor, # [T, num_kv_heads, head_dim]
39
+ slot_mapping: torch.Tensor # [T] int64, slot_id = block_id*block_size + offset
40
+ ) -> None:
41
+ block_ids = (slot_mapping // self.block_size).long()
42
+ offsets = (slot_mapping % self.block_size).long()
43
+ self.k_cache[layer_id][block_ids, offsets] = k.to(self.dtype)
44
+ self.v_cache[layer_id][block_ids, offsets] = v.to(self.dtype)
45
+
46
+ def gather(
47
+ self,
48
+ layer_id: int,
49
+ block_table: list[int],
50
+ num_tokens: int,
51
+ ) -> tuple[torch.Tensor, torch.Tensor]:
52
+ """Return contiguous [num_tokens, num_kv_heads, head_dim] K and V
53
+ for one sequence, by walking its block table."""
54
+ if num_tokens == 0:
55
+ empty = torch.zeros(
56
+ 0, self.num_kv_heads, self.head_dim,
57
+ dtype=self.dtype, device=self.device,
58
+ )
59
+ return empty, empty.clone()
60
+ num_full = num_tokens // self.block_size
61
+ tail = num_tokens % self.block_size
62
+ idxs = block_table[:num_full + (1 if tail else 0)]
63
+ idx_tensor = torch.as_tensor(idxs, dtype=torch.long, device=self.device)
64
+ # [P, block_size, H, D]
65
+ k_blocks = self.k_cache[layer_id].index_select(0, idx_tensor)
66
+ v_blocks = self.v_cache[layer_id].index_select(0, idx_tensor)
67
+ # Flatten the first two dims then trim.
68
+ k_flat = k_blocks.reshape(-1, self.num_kv_heads, self.head_dim)
69
+ v_flat = v_blocks.reshape(-1, self.num_kv_heads, self.head_dim)
70
+ return k_flat[:num_tokens], v_flat[:num_tokens]
tiny_vllm/request.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import enum
4
+ import itertools
5
+ import time
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional
8
+
9
+ from .config import SamplingParams
10
+
11
+
12
+ class SequenceStatus(enum.Enum):
13
+ WAITING = "waiting" # not yet started prefill
14
+ PREFILLING = "prefilling" # chunked prefill in progress
15
+ RUNNING = "running" # in decode loop
16
+ FINISHED = "finished"
17
+ PREEMPTED = "preempted" # evicted; will restart prefill when capacity returns
18
+
19
+
20
+ _seq_counter = itertools.count()
21
+
22
+
23
+ def _next_seq_id() -> int:
24
+ return next(_seq_counter)
25
+
26
+
27
+ @dataclass
28
+ class Sequence:
29
+ """One in-flight request.
30
+
31
+ The token sequence is `prompt_token_ids + output_token_ids`.
32
+ `num_computed_tokens` tracks how many tokens already have their KV
33
+ materialized in the paged cache. Anything past that boundary is either
34
+ waiting prefill (during PREFILLING) or the next token to sample (RUNNING).
35
+ """
36
+
37
+ prompt_token_ids: list[int]
38
+ sampling_params: SamplingParams
39
+ request_id: str
40
+ arrival_time: float = field(default_factory=time.monotonic)
41
+ seq_id: int = field(default_factory=_next_seq_id)
42
+
43
+ output_token_ids: list[int] = field(default_factory=list)
44
+ status: SequenceStatus = SequenceStatus.WAITING
45
+
46
+ # Paged KV bookkeeping (filled in by the BlockManager).
47
+ block_table: list[int] = field(default_factory=list)
48
+ num_computed_tokens: int = 0 # tokens with KV in the cache
49
+ num_cached_prefix_tokens: int = 0 # tokens served from prefix cache hits
50
+
51
+ # Outputs / streaming
52
+ finish_reason: Optional[str] = None
53
+
54
+ # ---- helpers --------------------------------------------------------
55
+
56
+ @property
57
+ def prompt_len(self) -> int:
58
+ return len(self.prompt_token_ids)
59
+
60
+ @property
61
+ def total_len(self) -> int:
62
+ return len(self.prompt_token_ids) + len(self.output_token_ids)
63
+
64
+ def all_token_ids(self) -> list[int]:
65
+ return self.prompt_token_ids + self.output_token_ids
66
+
67
+ def get_token(self, position: int) -> int:
68
+ if position < len(self.prompt_token_ids):
69
+ return self.prompt_token_ids[position]
70
+ return self.output_token_ids[position - len(self.prompt_token_ids)]
71
+
72
+ @property
73
+ def num_uncomputed_prompt_tokens(self) -> int:
74
+ return max(0, self.prompt_len - self.num_computed_tokens)
75
+
76
+ def append_output_token(self, token_id: int) -> None:
77
+ self.output_token_ids.append(token_id)
78
+
79
+
80
+ @dataclass
81
+ class Request:
82
+ """A user-submitted request before it becomes a Sequence."""
83
+
84
+ request_id: str
85
+ prompt_token_ids: list[int]
86
+ sampling_params: SamplingParams
tiny_vllm/sampler.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Per-request sampling. Temperature, top-p, top-k, greedy."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+ from .config import SamplingParams
9
+
10
+
11
+ class Sampler:
12
+ def __init__(self, device: torch.device) -> None:
13
+ self.device = device
14
+
15
+ def sample(
16
+ self,
17
+ logits: torch.Tensor, # [num_seqs, vocab]
18
+ params: list[SamplingParams],
19
+ generators: Optional[list[Optional[torch.Generator]]] = None,
20
+ ) -> list[int]:
21
+ out: list[int] = []
22
+ for i, p in enumerate(params):
23
+ row = logits[i]
24
+ if p.is_greedy:
25
+ out.append(int(row.argmax().item()))
26
+ continue
27
+
28
+ # Temperature.
29
+ row = row / max(p.temperature, 1e-5)
30
+ # Top-k.
31
+ if p.top_k > 0 and p.top_k < row.size(-1):
32
+ topk_vals, _ = torch.topk(row, p.top_k)
33
+ row = torch.where(row < topk_vals[-1], torch.full_like(row, float("-inf")), row)
34
+ # Top-p (nucleus).
35
+ if 0.0 < p.top_p < 1.0:
36
+ sorted_logits, sorted_idx = torch.sort(row, descending=True)
37
+ probs = torch.softmax(sorted_logits, dim=-1)
38
+ cumprobs = probs.cumsum(dim=-1)
39
+ # Drop tokens whose CUMULATIVE prob (including themselves) exceeds top_p,
40
+ # but always keep the highest-probability one.
41
+ drop = cumprobs > p.top_p
42
+ drop[0] = False
43
+ drop = drop.roll(shifts=1, dims=0) # so the boundary token stays
44
+ drop[0] = False
45
+ sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
46
+ row = torch.full_like(row, float("-inf"))
47
+ row.scatter_(0, sorted_idx, sorted_logits)
48
+
49
+ probs = torch.softmax(row, dim=-1)
50
+ gen = generators[i] if generators else None
51
+ token = torch.multinomial(probs, num_samples=1, generator=gen)
52
+ out.append(int(token.item()))
53
+ return out
tiny_vllm/scheduler.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Continuous-batching scheduler with chunked prefill.
2
+
3
+ A scheduling step produces a SchedulerOutput listing which sequences run and
4
+ how many tokens each one advances. Two phases each step:
5
+
6
+ 1. Decodes. Every RUNNING sequence wants exactly one new token; we must
7
+ ensure each has space for it. If a sequence needs a new block and the
8
+ pool is dry, we *preempt* the most recently admitted running sequence —
9
+ free its KV blocks and push it back to the front of the waiting queue
10
+ so it restarts prefill later (recompute-style preemption, as in vLLM).
11
+
12
+ 2. Prefill chunks. With remaining token budget, pull from `waiting`. A
13
+ newly waiting sequence is admitted (prompt blocks allocated via the
14
+ block manager, with prefix-cache hits taken). Then we plan a chunk of
15
+ up to `min(remaining_prefill, budget)` tokens. Chunked prefill lets a
16
+ long prompt share the budget with concurrent decodes instead of
17
+ stalling them.
18
+ """
19
+ from __future__ import annotations
20
+
21
+ from collections import deque
22
+ from dataclasses import dataclass, field
23
+
24
+ from .block_manager import BlockManager
25
+ from .config import EngineConfig
26
+ from .request import Sequence, SequenceStatus
27
+
28
+
29
+ @dataclass
30
+ class ScheduledSeq:
31
+ seq: Sequence
32
+ num_tokens: int # how many tokens to forward this step for this seq
33
+ is_prefill: bool
34
+
35
+
36
+ @dataclass
37
+ class SchedulerOutput:
38
+ scheduled: list[ScheduledSeq] = field(default_factory=list)
39
+ preempted: list[int] = field(default_factory=list) # seq_ids preempted
40
+ newly_admitted: list[int] = field(default_factory=list)
41
+ total_tokens: int = 0
42
+
43
+ @property
44
+ def is_empty(self) -> bool:
45
+ return not self.scheduled
46
+
47
+
48
+ class Scheduler:
49
+ def __init__(self, config: EngineConfig, block_manager: BlockManager) -> None:
50
+ self.config = config
51
+ self.block_manager = block_manager
52
+ self.waiting: deque[Sequence] = deque()
53
+ self.running: list[Sequence] = []
54
+ # Tracks order of admission so preemption picks the youngest first.
55
+ self._admission_order: list[int] = []
56
+
57
+ # ---- queue ops ------------------------------------------------------
58
+
59
+ def add(self, seq: Sequence) -> None:
60
+ self.waiting.append(seq)
61
+
62
+ def abort(self, seq_id: int) -> bool:
63
+ for q in (self.waiting,):
64
+ for s in list(q):
65
+ if s.seq_id == seq_id:
66
+ q.remove(s)
67
+ s.status = SequenceStatus.FINISHED
68
+ s.finish_reason = "abort"
69
+ return True
70
+ for s in list(self.running):
71
+ if s.seq_id == seq_id:
72
+ self.running.remove(s)
73
+ self.block_manager.free(s)
74
+ s.status = SequenceStatus.FINISHED
75
+ s.finish_reason = "abort"
76
+ return True
77
+ return False
78
+
79
+ @property
80
+ def has_work(self) -> bool:
81
+ return bool(self.waiting) or bool(self.running)
82
+
83
+ # ---- scheduling -----------------------------------------------------
84
+
85
+ def _preempt_one(self) -> Sequence | None:
86
+ """Free the youngest running sequence and re-enqueue it for restart."""
87
+ if not self.running:
88
+ return None
89
+ victim = self.running.pop() # youngest by insertion order
90
+ self.block_manager.free(victim)
91
+ # Restart: forget computed-token progress; keep generated outputs so
92
+ # the user-visible sequence is preserved. (vLLM full-recompute: we'd
93
+ # discard outputs too; we keep them so streaming makes sense.)
94
+ victim.num_computed_tokens = 0
95
+ victim.num_cached_prefix_tokens = 0
96
+ victim.status = SequenceStatus.PREEMPTED
97
+ self.waiting.appendleft(victim)
98
+ return victim
99
+
100
+ def schedule(self) -> SchedulerOutput:
101
+ out = SchedulerOutput()
102
+ budget = self.config.max_num_batched_tokens
103
+
104
+ # --- Phase 1: decodes for already-running sequences ---
105
+ for seq in list(self.running):
106
+ if seq.status != SequenceStatus.RUNNING:
107
+ continue
108
+ if budget <= 0:
109
+ break
110
+ # Ensure space for one more token.
111
+ try:
112
+ self.block_manager.append_slot(seq)
113
+ except RuntimeError:
114
+ # Out of blocks: try to free space by preempting the youngest
115
+ # running sequence — which may be `seq` itself.
116
+ victim = self._preempt_one()
117
+ if victim is seq:
118
+ # We preempted ourselves; it's already off `running`.
119
+ out.preempted.append(seq.seq_id)
120
+ continue
121
+ if victim is None:
122
+ # Nothing to preempt; preempt this seq manually.
123
+ self.running.remove(seq)
124
+ self.block_manager.free(seq)
125
+ seq.num_computed_tokens = 0
126
+ seq.num_cached_prefix_tokens = 0
127
+ seq.status = SequenceStatus.PREEMPTED
128
+ self.waiting.appendleft(seq)
129
+ out.preempted.append(seq.seq_id)
130
+ continue
131
+ out.preempted.append(victim.seq_id)
132
+ try:
133
+ self.block_manager.append_slot(seq)
134
+ except RuntimeError:
135
+ # Still no room — give up on this seq this step.
136
+ continue
137
+ out.scheduled.append(ScheduledSeq(seq=seq, num_tokens=1, is_prefill=False))
138
+ budget -= 1
139
+ out.total_tokens += 1
140
+
141
+ # --- Phase 2: prefill chunks (admitting new sequences as needed) ---
142
+ max_concurrent = self.config.max_num_seqs
143
+ active_count = sum(1 for s in self.running if s.status != SequenceStatus.FINISHED)
144
+
145
+ while self.waiting and budget > 0 and active_count < max_concurrent:
146
+ seq = self.waiting[0]
147
+
148
+ # Admit if needed.
149
+ if not seq.block_table:
150
+ ok, _ = self.block_manager.can_allocate_initial(seq)
151
+ if not ok:
152
+ # Try to free up space by preempting the youngest running
153
+ # seq. If nothing to preempt, we're stuck for this step.
154
+ if not self.running:
155
+ break
156
+ victim = self._preempt_one()
157
+ if victim is None:
158
+ break
159
+ out.preempted.append(victim.seq_id)
160
+ continue
161
+ self.block_manager.admit(seq)
162
+ out.newly_admitted.append(seq.seq_id)
163
+ seq.status = SequenceStatus.PREFILLING
164
+
165
+ # Plan a chunk.
166
+ remaining = seq.num_uncomputed_prompt_tokens
167
+ chunk = min(remaining, budget)
168
+ if chunk <= 0:
169
+ # Prompt already fully cached (shouldn't happen due to admit
170
+ # capping, but defensive): move straight to RUNNING.
171
+ self.waiting.popleft()
172
+ seq.status = SequenceStatus.RUNNING
173
+ self.running.append(seq)
174
+ active_count += 1
175
+ continue
176
+
177
+ # Make sure block_table covers num_computed + chunk.
178
+ try:
179
+ self.block_manager.ensure_blocks_for_chunk(seq, chunk)
180
+ except RuntimeError:
181
+ # Couldn't expand. Try preemption; otherwise give up.
182
+ if self.running:
183
+ victim = self._preempt_one()
184
+ if victim is not None:
185
+ out.preempted.append(victim.seq_id)
186
+ continue
187
+ break
188
+
189
+ out.scheduled.append(ScheduledSeq(seq=seq, num_tokens=chunk, is_prefill=True))
190
+ budget -= chunk
191
+ out.total_tokens += chunk
192
+
193
+ if chunk == remaining:
194
+ # This step finishes prompt ingestion → seq becomes RUNNING.
195
+ self.waiting.popleft()
196
+ seq.status = SequenceStatus.RUNNING
197
+ self.running.append(seq)
198
+ active_count += 1
199
+ else:
200
+ # Still has more prompt to chew through; leave at head of
201
+ # waiting queue with a partial block_table.
202
+ break # one prefill per step keeps things tidy
203
+
204
+ return out
205
+
206
+ # ---- post-step ------------------------------------------------------
207
+
208
+ def finalize_step(self, scheduled: list[ScheduledSeq]) -> list[Sequence]:
209
+ """Called after the model has produced new tokens.
210
+
211
+ Returns the list of sequences that just finished this step (so the
212
+ engine can free them and ship the final output to the caller).
213
+ """
214
+ finished: list[Sequence] = []
215
+ for item in scheduled:
216
+ seq = item.seq
217
+ self.block_manager.register_filled_blocks(seq, prev_computed=0)
218
+ if seq.status == SequenceStatus.FINISHED:
219
+ if seq in self.running:
220
+ self.running.remove(seq)
221
+ self.block_manager.free(seq)
222
+ finished.append(seq)
223
+ return finished
tiny_vllm/server.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI front-end.
2
+
3
+ Two SSE streams live behind this server:
4
+
5
+ POST /generate — submit a prompt, stream back token deltas
6
+ POST /v1/completions — OpenAI-compatible streaming completions
7
+ GET /engine/events — stream of engine-state snapshots (one per step)
8
+ — what the demo page subscribes to
9
+ GET /engine/snapshot — one-shot current state (JSON)
10
+ GET / — static demo page
11
+
12
+ The demo page subscribes to /engine/events and renders the block pool,
13
+ scheduler queues, and live token streams.
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import asyncio
19
+ import json
20
+ import os
21
+ import time
22
+ from pathlib import Path
23
+ from typing import AsyncIterator, Optional
24
+
25
+ from fastapi import FastAPI, HTTPException, Request
26
+ from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
27
+ from fastapi.staticfiles import StaticFiles
28
+ from pydantic import BaseModel, Field
29
+
30
+ from .config import EngineConfig, SamplingParams
31
+ from .engine import LLMEngine
32
+
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # Schemas
36
+ # ---------------------------------------------------------------------------
37
+
38
+
39
+ class GenerateRequest(BaseModel):
40
+ prompt: str
41
+ max_tokens: int = 64
42
+ temperature: float = 1.0
43
+ top_p: float = 1.0
44
+ top_k: int = -1
45
+ seed: Optional[int] = None
46
+ ignore_eos: bool = False
47
+ stream: bool = True
48
+
49
+
50
+ class CompletionsRequest(BaseModel):
51
+ model: Optional[str] = None
52
+ prompt: str | list[str]
53
+ max_tokens: int = 64
54
+ temperature: float = 1.0
55
+ top_p: float = 1.0
56
+ n: int = 1
57
+ stream: bool = False
58
+ stop: Optional[list[str]] = None
59
+ seed: Optional[int] = None
60
+
61
+
62
+ # ---------------------------------------------------------------------------
63
+ # App factory
64
+ # ---------------------------------------------------------------------------
65
+
66
+
67
+ def _sse(data: dict | str) -> bytes:
68
+ if isinstance(data, dict):
69
+ data = json.dumps(data, separators=(",", ":"))
70
+ return f"data: {data}\n\n".encode("utf-8")
71
+
72
+
73
+ def build_app(config: EngineConfig) -> FastAPI:
74
+ app = FastAPI(title="tiny_vllm", version="0.1.0")
75
+ engine = LLMEngine(config)
76
+
77
+ @app.on_event("startup")
78
+ async def _on_startup() -> None:
79
+ await engine.startup()
80
+
81
+ @app.on_event("shutdown")
82
+ async def _on_shutdown() -> None:
83
+ await engine.shutdown()
84
+
85
+ # ---- root + static -------------------------------------------------
86
+
87
+ static_dir = Path(__file__).parent.parent / "web"
88
+ if static_dir.exists():
89
+ app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
90
+
91
+ @app.get("/")
92
+ async def root() -> FileResponse:
93
+ return FileResponse(str(static_dir / "index.html"))
94
+ else:
95
+ @app.get("/")
96
+ async def root() -> dict:
97
+ return {"name": "tiny_vllm", "status": "ok",
98
+ "hint": "demo page not found; POST to /generate"}
99
+
100
+ # ---- introspection -------------------------------------------------
101
+
102
+ @app.get("/engine/snapshot")
103
+ async def snapshot() -> dict:
104
+ return engine.snapshot()
105
+
106
+ @app.get("/engine/events")
107
+ async def events(request: Request) -> StreamingResponse:
108
+ q = engine.subscribe_events()
109
+
110
+ async def gen() -> AsyncIterator[bytes]:
111
+ # Push initial snapshot so a freshly-connected client has state.
112
+ yield _sse({"type": "snapshot", "payload": engine.snapshot()})
113
+ try:
114
+ while True:
115
+ if await request.is_disconnected():
116
+ break
117
+ try:
118
+ ev = await asyncio.wait_for(q.get(), timeout=15.0)
119
+ except asyncio.TimeoutError:
120
+ yield b": keepalive\n\n"
121
+ continue
122
+ yield _sse({
123
+ "type": ev.type,
124
+ "step": ev.step,
125
+ "timestamp": ev.timestamp,
126
+ "payload": ev.payload,
127
+ })
128
+ finally:
129
+ engine.unsubscribe_events(q)
130
+
131
+ return StreamingResponse(gen(), media_type="text/event-stream")
132
+
133
+ # ---- generation ----------------------------------------------------
134
+
135
+ def _params(req: GenerateRequest) -> SamplingParams:
136
+ return SamplingParams(
137
+ max_tokens=req.max_tokens,
138
+ temperature=req.temperature,
139
+ top_p=req.top_p,
140
+ top_k=req.top_k,
141
+ seed=req.seed,
142
+ ignore_eos=req.ignore_eos,
143
+ )
144
+
145
+ @app.post("/generate")
146
+ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse | JSONResponse:
147
+ try:
148
+ rid = engine.add_request(req.prompt, _params(req))
149
+ except ValueError as e:
150
+ raise HTTPException(status_code=400, detail=str(e))
151
+
152
+ if not req.stream:
153
+ text_parts: list[str] = []
154
+ finish_reason: Optional[str] = None
155
+ async for item in engine.stream(rid):
156
+ text_parts.append(item.new_text)
157
+ if item.finished:
158
+ finish_reason = item.finish_reason
159
+ break
160
+ return JSONResponse({
161
+ "request_id": rid,
162
+ "text": "".join(text_parts),
163
+ "finish_reason": finish_reason,
164
+ })
165
+
166
+ async def gen() -> AsyncIterator[bytes]:
167
+ try:
168
+ async for item in engine.stream(rid):
169
+ if await request.is_disconnected():
170
+ engine.abort(rid)
171
+ break
172
+ yield _sse({
173
+ "request_id": rid,
174
+ "text": item.new_text,
175
+ "finished": item.finished,
176
+ "finish_reason": item.finish_reason,
177
+ })
178
+ if item.finished:
179
+ yield b"data: [DONE]\n\n"
180
+ break
181
+ except asyncio.CancelledError:
182
+ engine.abort(rid)
183
+ raise
184
+
185
+ return StreamingResponse(gen(), media_type="text/event-stream")
186
+
187
+ @app.post("/v1/completions")
188
+ async def completions(req: CompletionsRequest, request: Request):
189
+ # Single-prompt only (n=1) for the minimal impl.
190
+ if isinstance(req.prompt, list):
191
+ if len(req.prompt) != 1:
192
+ raise HTTPException(400, "tiny_vllm only supports a single prompt per call")
193
+ prompt = req.prompt[0]
194
+ else:
195
+ prompt = req.prompt
196
+ try:
197
+ rid = engine.add_request(
198
+ prompt,
199
+ SamplingParams(
200
+ max_tokens=req.max_tokens,
201
+ temperature=req.temperature,
202
+ top_p=req.top_p,
203
+ seed=req.seed,
204
+ ),
205
+ )
206
+ except ValueError as e:
207
+ raise HTTPException(400, str(e))
208
+
209
+ created = int(time.time())
210
+ model_id = req.model or config.model
211
+
212
+ if not req.stream:
213
+ text_parts: list[str] = []
214
+ finish_reason: Optional[str] = None
215
+ async for item in engine.stream(rid):
216
+ text_parts.append(item.new_text)
217
+ if item.finished:
218
+ finish_reason = item.finish_reason
219
+ break
220
+ return JSONResponse({
221
+ "id": f"cmpl-{rid}",
222
+ "object": "text_completion",
223
+ "created": created,
224
+ "model": model_id,
225
+ "choices": [{
226
+ "text": "".join(text_parts),
227
+ "index": 0,
228
+ "logprobs": None,
229
+ "finish_reason": finish_reason,
230
+ }],
231
+ })
232
+
233
+ async def gen() -> AsyncIterator[bytes]:
234
+ try:
235
+ async for item in engine.stream(rid):
236
+ if await request.is_disconnected():
237
+ engine.abort(rid)
238
+ break
239
+ chunk = {
240
+ "id": f"cmpl-{rid}",
241
+ "object": "text_completion",
242
+ "created": created,
243
+ "model": model_id,
244
+ "choices": [{
245
+ "text": item.new_text,
246
+ "index": 0,
247
+ "logprobs": None,
248
+ "finish_reason": item.finish_reason if item.finished else None,
249
+ }],
250
+ }
251
+ yield _sse(chunk)
252
+ if item.finished:
253
+ yield b"data: [DONE]\n\n"
254
+ break
255
+ except asyncio.CancelledError:
256
+ engine.abort(rid)
257
+ raise
258
+
259
+ return StreamingResponse(gen(), media_type="text/event-stream")
260
+
261
+ @app.post("/abort/{request_id}")
262
+ async def abort(request_id: str) -> dict:
263
+ ok = engine.abort(request_id)
264
+ return {"aborted": ok}
265
+
266
+ return app
267
+
268
+
269
+ # ---------------------------------------------------------------------------
270
+ # CLI entry
271
+ # ---------------------------------------------------------------------------
272
+
273
+
274
+ def main() -> None:
275
+ parser = argparse.ArgumentParser(description="tiny_vllm server")
276
+ parser.add_argument("--model", default=os.environ.get("TINY_VLLM_MODEL", "Qwen/Qwen2.5-0.5B-Instruct"))
277
+ parser.add_argument("--device", default=os.environ.get("TINY_VLLM_DEVICE", "cpu"))
278
+ parser.add_argument("--dtype", default=os.environ.get("TINY_VLLM_DTYPE", "float32"))
279
+ parser.add_argument("--block-size", type=int, default=16)
280
+ parser.add_argument("--num-blocks", type=int, default=256)
281
+ parser.add_argument("--max-num-seqs", type=int, default=8)
282
+ parser.add_argument("--max-num-batched-tokens", type=int, default=512)
283
+ parser.add_argument("--max-model-len", type=int, default=2048)
284
+ parser.add_argument("--disable-prefix-caching", action="store_true")
285
+ parser.add_argument("--host", default="0.0.0.0")
286
+ parser.add_argument("--port", type=int, default=8000)
287
+ args = parser.parse_args()
288
+
289
+ cfg = EngineConfig(
290
+ model=args.model,
291
+ device=args.device,
292
+ dtype=args.dtype,
293
+ block_size=args.block_size,
294
+ num_blocks=args.num_blocks,
295
+ max_num_seqs=args.max_num_seqs,
296
+ max_num_batched_tokens=args.max_num_batched_tokens,
297
+ max_model_len=args.max_model_len,
298
+ enable_prefix_caching=not args.disable_prefix_caching,
299
+ )
300
+
301
+ import uvicorn
302
+ app = build_app(cfg)
303
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
304
+
305
+
306
+ if __name__ == "__main__":
307
+ main()
web/app.js ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* tiny_vllm — demo page client.
2
+ *
3
+ * Two streams in play:
4
+ *
5
+ * /engine/events — engine state snapshots (one per scheduling step)
6
+ * /generate — token-level deltas for whatever prompt this page sent
7
+ *
8
+ * The page itself is stateless; everything is driven by what comes off the
9
+ * event stream. Token deltas from /generate are merged into per-request UI.
10
+ */
11
+
12
+ const $ = (id) => document.getElementById(id);
13
+
14
+ const ui = {
15
+ connection: $("connection"),
16
+ model: $("model"),
17
+ pool: $("block-pool"),
18
+ poolSummary: $("pool-summary"),
19
+ schedStep: $("sched-step"),
20
+ statTokens: $("stat-tokens"),
21
+ statPfDec: $("stat-pfdec"),
22
+ statMs: $("stat-ms"),
23
+ statCache: $("stat-cache"),
24
+ statFree: $("stat-free"),
25
+ statPre: $("stat-pre"),
26
+ log: $("log"),
27
+ seqs: $("seqs"),
28
+ send: $("send"),
29
+ sendTwice: $("send-twice"),
30
+ };
31
+
32
+ const state = {
33
+ poolEls: [],
34
+ numBlocks: 0,
35
+ blockSize: 16,
36
+ preempted: 0,
37
+ // request_id -> { promptText, generated, finished, finishReason }
38
+ requests: new Map(),
39
+ // seq_id -> { request_id, blockTable, cachedPrefixBlocks, status, ... }
40
+ seqsBySeqId: new Map(),
41
+ };
42
+
43
+ function logLine(html, cls = "") {
44
+ const t = new Date().toLocaleTimeString();
45
+ ui.log.innerHTML += `<span class="${cls}">[${t}] ${html}</span>\n`;
46
+ ui.log.scrollTop = ui.log.scrollHeight;
47
+ }
48
+
49
+ function initPool(numBlocks) {
50
+ if (state.numBlocks === numBlocks && state.poolEls.length === numBlocks) return;
51
+ state.numBlocks = numBlocks;
52
+ ui.pool.innerHTML = "";
53
+ state.poolEls = [];
54
+ for (let i = 0; i < numBlocks; i++) {
55
+ const el = document.createElement("div");
56
+ el.className = "block free";
57
+ el.title = `block ${i}`;
58
+ ui.pool.appendChild(el);
59
+ state.poolEls.push(el);
60
+ }
61
+ }
62
+
63
+ function renderPool(pool) {
64
+ initPool(pool.num_blocks);
65
+ state.blockSize = pool.block_size;
66
+ for (let i = 0; i < pool.num_blocks; i++) {
67
+ const el = state.poolEls[i];
68
+ const rc = pool.ref_counts[i];
69
+ const hashed = pool.hashed[i];
70
+ let cls = "block";
71
+ if (rc === 0) {
72
+ cls += hashed ? " cached" : " free";
73
+ } else if (rc === 1) {
74
+ cls += " used";
75
+ } else {
76
+ cls += " shared";
77
+ }
78
+ if (hashed) cls += " hashed";
79
+ el.className = cls;
80
+ el.title = `block ${i} — refcount=${rc}${hashed ? " — hashed (cacheable)" : ""}`;
81
+ }
82
+ ui.poolSummary.textContent =
83
+ `${pool.num_blocks - pool.num_free_blocks}/${pool.num_blocks} used · ` +
84
+ `${pool.num_cached_entries} cached entries · ` +
85
+ `prefix-cache ${pool.prefix_cache_hits}/${pool.prefix_cache_lookups}`;
86
+ ui.statFree.textContent = pool.num_free_blocks;
87
+ if (pool.prefix_cache_lookups > 0) {
88
+ const pct = (100 * pool.prefix_cache_hits / pool.prefix_cache_lookups).toFixed(0);
89
+ ui.statCache.textContent = `${pct}%`;
90
+ } else {
91
+ ui.statCache.textContent = "—";
92
+ }
93
+ }
94
+
95
+ function renderSeqs(snapshot) {
96
+ ui.schedStep.textContent = ` — step ${snapshot.step}`;
97
+ const all = [...snapshot.running, ...snapshot.waiting];
98
+ // index for later token-delta merges
99
+ state.seqsBySeqId = new Map(all.map(s => [s.seq_id, s]));
100
+ ui.seqs.innerHTML = "";
101
+ if (all.length === 0) {
102
+ ui.seqs.innerHTML = `<div class="muted">(no active sequences — send a prompt above)</div>`;
103
+ return;
104
+ }
105
+ for (const s of all) {
106
+ const reqRec = state.requests.get(s.request_id);
107
+ const promptText = reqRec?.promptText ?? "(prompt elided)";
108
+ const gen = reqRec?.generated ?? "";
109
+
110
+ const div = document.createElement("div");
111
+ div.className = "seq";
112
+ div.id = `seq-${s.request_id}`;
113
+
114
+ const cachedBlocks = Math.floor(s.num_cached_prefix_tokens / state.blockSize);
115
+ const blocksHTML = s.block_table.map((bid, i) => {
116
+ const klass = i < cachedBlocks ? "seq-block cached-hit"
117
+ : (snapshot.block_pool.ref_counts[bid] > 1 ? "seq-block shared" : "seq-block");
118
+ return `<div class="${klass}" title="block ${bid}${i < cachedBlocks ? ' (prefix-cache hit)' : ''}">${bid}</div>`;
119
+ }).join("");
120
+
121
+ div.innerHTML = `
122
+ <div class="seq-header">
123
+ <span class="seq-id">req=${s.request_id.slice(0, 8)} seq=${s.seq_id}</span>
124
+ <span class="seq-status ${s.status}">${s.status}</span>
125
+ <span class="seq-meta">
126
+ prompt=${s.prompt_len} · generated=${s.num_generated} ·
127
+ cached=${s.num_cached_prefix_tokens}/${s.prompt_len} ·
128
+ blocks=${s.block_table.length}
129
+ </span>
130
+ </div>
131
+ <div class="seq-blocks">${blocksHTML || '<span class="muted">(no blocks yet)</span>'}</div>
132
+ <div class="seq-text"><span class="prompt">${escapeHtml(promptText)}</span><span class="gen">${escapeHtml(gen)}</span>${s.status === 'running' || s.status === 'prefilling' ? '<span class="cursor">&nbsp;</span>' : ''}</div>
133
+ `;
134
+ ui.seqs.appendChild(div);
135
+ }
136
+ }
137
+
138
+ function escapeHtml(s) {
139
+ return (s || "").replace(/[&<>"]/g, c => ({"&": "&amp;", "<": "&lt;", ">": "&gt;", '"': "&quot;"}[c]));
140
+ }
141
+
142
+ function handleEvent(ev) {
143
+ if (ev.type === "snapshot") {
144
+ const snap = ev.payload;
145
+ ui.model.textContent = `· ${snap.config.model}`;
146
+ renderPool(snap.block_pool);
147
+ renderSeqs(snap);
148
+ return;
149
+ }
150
+ if (ev.type === "step") {
151
+ const p = ev.payload;
152
+ ui.statTokens.textContent = p.num_tokens;
153
+ ui.statPfDec.textContent = `${p.num_prefill_seqs} / ${p.num_decode_seqs}`;
154
+ ui.statMs.textContent = p.duration_ms.toFixed(1);
155
+ if (p.preempted?.length) state.preempted += p.preempted.length;
156
+ ui.statPre.textContent = state.preempted;
157
+ renderPool(p.snapshot.block_pool);
158
+ renderSeqs(p.snapshot);
159
+
160
+ let msg = `step ${ev.step}: ${p.num_tokens}t (${p.num_prefill_seqs}P/${p.num_decode_seqs}D) in ${p.duration_ms.toFixed(1)}ms`;
161
+ let cls = "ev-step";
162
+ if (p.newly_admitted?.length) {
163
+ msg += ` · admitted seq=${p.newly_admitted.join(",")}`;
164
+ cls = "ev-admit";
165
+ }
166
+ if (p.finished?.length) {
167
+ msg += ` · finished ${p.finished.map(r => r.slice(0,8)).join(",")}`;
168
+ cls = "ev-finish";
169
+ }
170
+ if (p.preempted?.length) {
171
+ msg += ` · PREEMPTED seq=${p.preempted.join(",")}`;
172
+ cls = "ev-preempt";
173
+ }
174
+ logLine(msg, cls);
175
+ }
176
+ }
177
+
178
+ function connectEvents() {
179
+ const es = new EventSource("/engine/events");
180
+ es.onopen = () => {
181
+ ui.connection.textContent = "connected";
182
+ ui.connection.classList.remove("offline");
183
+ ui.connection.classList.add("online");
184
+ };
185
+ es.onerror = () => {
186
+ ui.connection.textContent = "disconnected";
187
+ ui.connection.classList.remove("online");
188
+ ui.connection.classList.add("offline");
189
+ };
190
+ es.onmessage = (e) => {
191
+ if (!e.data) return;
192
+ try {
193
+ handleEvent(JSON.parse(e.data));
194
+ } catch (err) {
195
+ console.error("bad event", err, e.data);
196
+ }
197
+ };
198
+ }
199
+
200
+ async function sendPrompt(prompt) {
201
+ const body = {
202
+ prompt,
203
+ max_tokens: parseInt($("max_tokens").value, 10),
204
+ temperature: parseFloat($("temperature").value),
205
+ top_p: parseFloat($("top_p").value),
206
+ stream: true,
207
+ };
208
+ const resp = await fetch("/generate", {
209
+ method: "POST",
210
+ headers: {"content-type": "application/json"},
211
+ body: JSON.stringify(body),
212
+ });
213
+ if (!resp.ok) {
214
+ const txt = await resp.text();
215
+ logLine(`request failed: ${txt}`, "ev-preempt");
216
+ return;
217
+ }
218
+
219
+ // Parse SSE manually so we can read each event as it arrives.
220
+ const reader = resp.body.getReader();
221
+ const decoder = new TextDecoder();
222
+ let buf = "";
223
+ let myReqId = null;
224
+ while (true) {
225
+ const { value, done } = await reader.read();
226
+ if (done) break;
227
+ buf += decoder.decode(value, { stream: true });
228
+ const parts = buf.split("\n\n");
229
+ buf = parts.pop();
230
+ for (const part of parts) {
231
+ const line = part.trim();
232
+ if (!line.startsWith("data:")) continue;
233
+ const data = line.slice(5).trim();
234
+ if (data === "[DONE]") return;
235
+ try {
236
+ const j = JSON.parse(data);
237
+ if (!myReqId) {
238
+ myReqId = j.request_id;
239
+ state.requests.set(myReqId, { promptText: prompt, generated: "", finished: false });
240
+ }
241
+ const rec = state.requests.get(myReqId);
242
+ if (j.text) rec.generated += j.text;
243
+ rec.finished = j.finished;
244
+ rec.finishReason = j.finish_reason;
245
+ // Repaint the matching seq card if visible.
246
+ const card = document.getElementById(`seq-${myReqId}`);
247
+ if (card) {
248
+ const text = card.querySelector(".seq-text .gen");
249
+ if (text) text.textContent = rec.generated;
250
+ }
251
+ } catch (e) {
252
+ console.error("bad chunk", e, data);
253
+ }
254
+ }
255
+ }
256
+ }
257
+
258
+ ui.send.addEventListener("click", () => sendPrompt($("prompt").value));
259
+ ui.sendTwice.addEventListener("click", async () => {
260
+ const p = $("prompt").value;
261
+ // First send fills the prefix cache; second send should hit it.
262
+ await sendPrompt(p);
263
+ await new Promise(r => setTimeout(r, 200));
264
+ await sendPrompt(p);
265
+ });
266
+ $("prompt").addEventListener("keydown", (e) => {
267
+ if ((e.metaKey || e.ctrlKey) && e.key === "Enter") {
268
+ sendPrompt(e.target.value);
269
+ }
270
+ });
271
+
272
+ connectEvents();
web/index.html ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="utf-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1">
6
+ <title>tiny_vllm — engine internals</title>
7
+ <link rel="stylesheet" href="/static/style.css">
8
+ </head>
9
+ <body>
10
+ <header>
11
+ <h1>tiny_vllm <span class="muted">— minimal continuous-batching engine</span></h1>
12
+ <div class="status">
13
+ <span id="connection" class="badge offline">disconnected</span>
14
+ <span id="model" class="muted"></span>
15
+ </div>
16
+ </header>
17
+
18
+ <section class="prompt-box">
19
+ <textarea id="prompt" rows="2" placeholder="Type a prompt and press Send (or Cmd/Ctrl+Enter)…">Explain paged attention in two sentences.</textarea>
20
+ <div class="controls">
21
+ <label>max_tokens <input id="max_tokens" type="number" value="64" min="1" max="2048"></label>
22
+ <label>temperature <input id="temperature" type="number" value="0.7" step="0.1" min="0" max="2"></label>
23
+ <label>top_p <input id="top_p" type="number" value="0.9" step="0.05" min="0" max="1"></label>
24
+ <button id="send">Send</button>
25
+ <button id="send-twice" title="Submit the same prompt twice — second should hit prefix cache">Send ×2 (prefix demo)</button>
26
+ </div>
27
+ </section>
28
+
29
+ <main>
30
+ <section class="card">
31
+ <h2>Block pool <span class="muted" id="pool-summary"></span></h2>
32
+ <div id="block-pool" class="block-pool"></div>
33
+ <div class="legend">
34
+ <span class="legend-item"><span class="swatch swatch-free"></span>free</span>
35
+ <span class="legend-item"><span class="swatch swatch-cached"></span>cached (evictable)</span>
36
+ <span class="legend-item"><span class="swatch swatch-used"></span>in use</span>
37
+ <span class="legend-item"><span class="swatch swatch-shared"></span>shared (refcount&gt;1)</span>
38
+ <span class="legend-item"><span class="swatch swatch-hashed-edge"></span>hashed (border)</span>
39
+ </div>
40
+ </section>
41
+
42
+ <section class="card">
43
+ <h2>Scheduler <span class="muted" id="sched-step"></span></h2>
44
+ <div class="stats">
45
+ <div class="stat"><div class="stat-label">tokens this step</div><div class="stat-value" id="stat-tokens">0</div></div>
46
+ <div class="stat"><div class="stat-label">prefill / decode</div><div class="stat-value" id="stat-pfdec">0 / 0</div></div>
47
+ <div class="stat"><div class="stat-label">step (ms)</div><div class="stat-value" id="stat-ms">0</div></div>
48
+ <div class="stat"><div class="stat-label">prefix cache hit-rate</div><div class="stat-value" id="stat-cache">0%</div></div>
49
+ <div class="stat"><div class="stat-label">free blocks</div><div class="stat-value" id="stat-free">0</div></div>
50
+ <div class="stat"><div class="stat-label">preemptions (total)</div><div class="stat-value" id="stat-pre">0</div></div>
51
+ </div>
52
+ <h3>step log</h3>
53
+ <pre id="log" class="log"></pre>
54
+ </section>
55
+
56
+ <section class="card grow">
57
+ <h2>Sequences</h2>
58
+ <div id="seqs"></div>
59
+ </section>
60
+ </main>
61
+
62
+ <footer>
63
+ <span class="muted">Subscribed to <code>/engine/events</code>. Source: <a href="https://github.com/yourname/tiny_vllm" target="_blank">github</a>.</span>
64
+ </footer>
65
+
66
+ <script src="/static/app.js"></script>
67
+ </body>
68
+ </html>
web/style.css ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ --bg: #0e1116;
3
+ --bg-elev: #161b22;
4
+ --bg-elev2: #1f2630;
5
+ --fg: #e6edf3;
6
+ --muted: #8b949e;
7
+ --accent: #58a6ff;
8
+ --green: #3fb950;
9
+ --purple: #a371f7;
10
+ --orange: #f0883e;
11
+ --red: #f85149;
12
+ --border: #30363d;
13
+ --mono: ui-monospace, "JetBrains Mono", Menlo, Consolas, monospace;
14
+ }
15
+
16
+ * { box-sizing: border-box; }
17
+
18
+ body {
19
+ margin: 0;
20
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Inter, sans-serif;
21
+ background: var(--bg);
22
+ color: var(--fg);
23
+ font-size: 14px;
24
+ }
25
+
26
+ header {
27
+ display: flex; align-items: center; justify-content: space-between;
28
+ padding: 14px 20px;
29
+ border-bottom: 1px solid var(--border);
30
+ background: var(--bg-elev);
31
+ }
32
+ header h1 { font-size: 16px; margin: 0; font-weight: 600; }
33
+ .muted { color: var(--muted); font-weight: 400; }
34
+ .badge {
35
+ display: inline-block;
36
+ padding: 2px 8px;
37
+ border-radius: 10px;
38
+ font-size: 11px;
39
+ font-family: var(--mono);
40
+ }
41
+ .badge.online { background: rgba(63, 185, 80, 0.15); color: var(--green); }
42
+ .badge.offline { background: rgba(248, 81, 73, 0.15); color: var(--red); }
43
+
44
+ .prompt-box {
45
+ padding: 12px 20px;
46
+ border-bottom: 1px solid var(--border);
47
+ background: var(--bg-elev);
48
+ display: flex; flex-direction: column; gap: 10px;
49
+ }
50
+ .prompt-box textarea {
51
+ width: 100%;
52
+ background: var(--bg);
53
+ color: var(--fg);
54
+ border: 1px solid var(--border);
55
+ border-radius: 6px;
56
+ padding: 8px;
57
+ font-family: var(--mono);
58
+ resize: vertical;
59
+ }
60
+ .controls { display: flex; gap: 12px; align-items: center; flex-wrap: wrap; }
61
+ .controls label { display: flex; gap: 6px; align-items: center; font-size: 12px; color: var(--muted); }
62
+ .controls input {
63
+ width: 70px; background: var(--bg); color: var(--fg);
64
+ border: 1px solid var(--border); border-radius: 4px; padding: 3px 6px;
65
+ font-family: var(--mono);
66
+ }
67
+ button {
68
+ background: var(--accent); color: white;
69
+ border: none; border-radius: 4px;
70
+ padding: 6px 14px; font-weight: 500; cursor: pointer;
71
+ }
72
+ button:hover { filter: brightness(1.1); }
73
+ #send-twice { background: var(--purple); }
74
+
75
+ main {
76
+ display: grid;
77
+ grid-template-columns: 1fr 1fr;
78
+ grid-template-areas: "pool sched" "seqs seqs";
79
+ gap: 16px;
80
+ padding: 16px 20px;
81
+ }
82
+ .card {
83
+ background: var(--bg-elev);
84
+ border: 1px solid var(--border);
85
+ border-radius: 8px;
86
+ padding: 14px;
87
+ }
88
+ .card h2 { font-size: 14px; margin: 0 0 10px; font-weight: 600; }
89
+ .card h3 { font-size: 12px; margin: 14px 0 6px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.06em; }
90
+ .card.grow { grid-area: seqs; }
91
+ .card:nth-child(1) { grid-area: pool; }
92
+ .card:nth-child(2) { grid-area: sched; }
93
+
94
+ /* ---- block pool ---- */
95
+ .block-pool {
96
+ display: grid;
97
+ grid-template-columns: repeat(auto-fill, 16px);
98
+ gap: 3px;
99
+ padding: 8px;
100
+ background: var(--bg);
101
+ border-radius: 6px;
102
+ max-height: 280px; overflow-y: auto;
103
+ }
104
+ .block {
105
+ width: 16px; height: 16px; border-radius: 3px;
106
+ background: var(--bg-elev2);
107
+ position: relative;
108
+ cursor: help;
109
+ border: 1px solid transparent;
110
+ }
111
+ .block.free { background: #2a3140; }
112
+ .block.cached { background: #1f3b5c; } /* free but in prefix cache */
113
+ .block.used { background: var(--green); }
114
+ .block.shared { background: var(--purple); }
115
+ .block.hashed { border-color: var(--orange); }
116
+
117
+ .legend { display: flex; gap: 14px; margin-top: 10px; font-size: 11px; color: var(--muted); flex-wrap: wrap; }
118
+ .legend-item { display: flex; align-items: center; gap: 5px; }
119
+ .swatch { width: 12px; height: 12px; border-radius: 3px; display: inline-block; }
120
+ .swatch-free { background: #2a3140; }
121
+ .swatch-cached { background: #1f3b5c; }
122
+ .swatch-used { background: var(--green); }
123
+ .swatch-shared { background: var(--purple); }
124
+ .swatch-hashed-edge { background: var(--bg-elev2); border: 1px solid var(--orange); }
125
+
126
+ /* ---- stats ---- */
127
+ .stats { display: grid; grid-template-columns: repeat(3, 1fr); gap: 8px; }
128
+ .stat {
129
+ background: var(--bg);
130
+ border-radius: 6px;
131
+ padding: 8px;
132
+ }
133
+ .stat-label { font-size: 10px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.06em; }
134
+ .stat-value { font-family: var(--mono); font-size: 18px; margin-top: 3px; }
135
+
136
+ /* ---- log ---- */
137
+ .log {
138
+ background: var(--bg);
139
+ border-radius: 6px;
140
+ padding: 8px;
141
+ height: 140px; overflow-y: auto;
142
+ font-family: var(--mono); font-size: 11px;
143
+ margin: 0;
144
+ white-space: pre-wrap; word-break: break-word;
145
+ }
146
+ .log .ev-step { color: var(--muted); }
147
+ .log .ev-admit { color: var(--accent); }
148
+ .log .ev-finish { color: var(--green); }
149
+ .log .ev-preempt { color: var(--red); }
150
+
151
+ /* ---- sequences ---- */
152
+ #seqs { display: flex; flex-direction: column; gap: 10px; }
153
+ .seq {
154
+ background: var(--bg);
155
+ border: 1px solid var(--border);
156
+ border-radius: 6px;
157
+ padding: 10px;
158
+ }
159
+ .seq-header { display: flex; gap: 10px; align-items: center; }
160
+ .seq-id { font-family: var(--mono); color: var(--muted); font-size: 12px; }
161
+ .seq-status {
162
+ font-size: 10px; text-transform: uppercase; padding: 2px 6px; border-radius: 3px;
163
+ font-family: var(--mono); letter-spacing: 0.05em;
164
+ }
165
+ .seq-status.waiting { background: rgba(139, 148, 158, 0.2); color: var(--muted); }
166
+ .seq-status.prefilling { background: rgba(88, 166, 255, 0.15); color: var(--accent); }
167
+ .seq-status.running { background: rgba(63, 185, 80, 0.15); color: var(--green); }
168
+ .seq-status.finished { background: rgba(163, 113, 247, 0.15); color: var(--purple); }
169
+ .seq-status.preempted { background: rgba(240, 136, 62, 0.2); color: var(--orange); }
170
+ .seq-meta { color: var(--muted); font-size: 11px; font-family: var(--mono); margin-left: auto; }
171
+ .seq-blocks {
172
+ margin-top: 8px;
173
+ display: flex; gap: 2px; flex-wrap: wrap;
174
+ }
175
+ .seq-block {
176
+ width: 22px; height: 14px;
177
+ background: var(--bg-elev2);
178
+ font-size: 9px; line-height: 14px; text-align: center;
179
+ font-family: var(--mono);
180
+ border-radius: 2px;
181
+ color: var(--muted);
182
+ }
183
+ .seq-block.cached-hit { background: #1f3b5c; color: var(--accent); }
184
+ .seq-block.shared { background: var(--purple); color: white; }
185
+ .seq-text {
186
+ margin-top: 8px;
187
+ font-family: var(--mono); font-size: 12px;
188
+ background: var(--bg-elev2);
189
+ border-radius: 4px; padding: 6px;
190
+ min-height: 24px;
191
+ max-height: 180px;
192
+ overflow-y: auto;
193
+ white-space: pre-wrap; word-break: break-word;
194
+ }
195
+ .seq-text .prompt { color: var(--muted); }
196
+ .seq-text .gen { color: var(--fg); }
197
+ .seq-text .cursor {
198
+ display: inline-block; width: 6px; background: var(--accent);
199
+ animation: blink 1s steps(2, start) infinite;
200
+ }
201
+ @keyframes blink { to { visibility: hidden; } }
202
+
203
+ footer {
204
+ padding: 10px 20px;
205
+ border-top: 1px solid var(--border);
206
+ color: var(--muted); font-size: 11px;
207
+ }
208
+ footer a { color: var(--accent); text-decoration: none; }
209
+
210
+ @media (max-width: 900px) {
211
+ main { grid-template-columns: 1fr; grid-template-areas: "pool" "sched" "seqs"; }
212
+ .stats { grid-template-columns: repeat(2, 1fr); }
213
+ }