tiny-vllm / tiny_vllm /model_runner.py
enCoder's picture
minimal continuous-batching LLM engine
c32c359
"""Minimal Qwen2 forward pass that consumes a paged KV cache.
We deliberately re-implement Qwen2 from scratch (rather than using the HF
forward) so the path of K/V tensors through the cache is fully visible.
Weights are loaded from a HuggingFace checkpoint by matching parameter names.
Layout of inputs per step ("varlen" packing):
input_ids [T_total] concatenated tokens for all seqs
positions [T_total] position-in-sequence of each token
slot_mapping [T_total] where to write new K/V in the cache
segments list of AttnSegment(q_start, q_end, block_table, k_len), one per sequence
For attention, we loop over `segments`: gather each sequence's full K/V from
its block table, run SDPA, scatter the result back into a flat buffer. All
other ops (norms, MLP, projections) run on the full packed tensor.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import EngineConfig
from .paged_kv import PagedKVCache
from .request import Sequence
# ---------------------------------------------------------------------------
# Qwen2 building blocks
# ---------------------------------------------------------------------------
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel gain.

    The statistics are computed in float32 and the result is cast back to
    the dtype of the input.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise the last dimension of ``x`` ([..., hidden])."""
        in_dtype = x.dtype
        # Up-cast first so the mean of squares is accumulated in float32.
        x_fp32 = x.to(torch.float32)
        mean_sq = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normed = x_fp32 * inv_rms
        scaled = self.weight * normed
        return scaled.to(in_dtype)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
half = x.size(-1) // 2
x1, x2 = x[..., :half], x[..., half:]
return torch.cat((-x2, x1), dim=-1)
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding.

    x: [T, H, D], cos/sin: [T, D] → returns [T, H, D]. The cos/sin rows are
    broadcast over the head dimension.
    """
    # Insert a singleton head axis so [T, D] broadcasts against [T, H, D].
    cos_b = torch.unsqueeze(cos, 1)
    sin_b = torch.unsqueeze(sin, 1)
    straight = x * cos_b
    rotated = _rotate_half(x) * sin_b
    return straight + rotated
class Qwen2MLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``, no biases."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # Creation order (gate, up, down) is kept so seeded initialisation
        # and parameter ordering stay the same.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up with a SiLU gate, then project back to ``hidden_size``."""
        gate = F.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
@dataclass
class AttnSegment:
    """Describes where one sequence lives inside the packed batch.

    The query rows of the sequence are ``packed[q_start:q_end]``; its keys
    and values are read from the cache blocks listed in ``block_table``.
    """

    # First row of this sequence in the packed tensor.
    q_start: int
    # One past the last row of this sequence (exclusive bound).
    q_end: int
    # Ids of the KV-cache blocks that hold this sequence.
    block_table: list[int]
    # Total key length: num_computed_tokens + number of query rows.
    k_len: int
class Qwen2Attention(nn.Module):
    """Grouped-query self-attention that writes to and reads from a paged KV cache.

    Q/K/V projections carry a bias, the output projection does not. Attention
    itself is evaluated one sequence (segment) at a time.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        layer_idx: int,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.layer_idx = layer_idx

        q_width = num_heads * head_dim
        kv_width = num_kv_heads * head_dim
        self.q_proj = nn.Linear(hidden_size, q_width, bias=True)
        self.k_proj = nn.Linear(hidden_size, kv_width, bias=True)
        self.v_proj = nn.Linear(hidden_size, kv_width, bias=True)
        self.o_proj = nn.Linear(q_width, hidden_size, bias=False)
        self.scale = head_dim ** -0.5

    def forward(
        self,
        hidden_states: torch.Tensor,  # [T, hidden]
        positions: torch.Tensor,      # [T] long
        slot_mapping: torch.Tensor,   # [T] long
        cos_table: torch.Tensor,      # [max_pos, head_dim]
        sin_table: torch.Tensor,      # [max_pos, head_dim]
        segments: list[AttnSegment],
        kv_cache: PagedKVCache,
    ) -> torch.Tensor:
        """Attend over the packed batch and return the projected output [T, hidden]."""
        num_tokens = hidden_states.size(0)
        n_q, n_kv, dim = self.num_heads, self.num_kv_heads, self.head_dim

        q = self.q_proj(hidden_states).view(num_tokens, n_q, dim)
        k = self.k_proj(hidden_states).view(num_tokens, n_kv, dim)
        v = self.v_proj(hidden_states).view(num_tokens, n_kv, dim)

        # Look up the rotary factors for each token's position: [T, head_dim].
        cos = cos_table.index_select(0, positions)
        sin = sin_table.index_select(0, positions)
        q = _apply_rope(q, cos, sin)
        k = _apply_rope(k, cos, sin)

        # The fresh K/V must land in the cache first: the gather below reads
        # them back together with the previously cached tokens.
        kv_cache.write(self.layer_idx, k, v, slot_mapping)

        out = torch.empty_like(q)  # [T, num_heads, head_dim]
        group = n_q // n_kv        # query heads served by each KV head
        device = q.device

        for seg in segments:
            lo, hi = seg.q_start, seg.q_end
            total_k = seg.k_len

            queries = q[lo:hi]  # [q_len, H_q, D]
            keys, values = kv_cache.gather(self.layer_idx, seg.block_table, total_k)
            if group > 1:
                # Duplicate every KV head so K/V line up with the query heads.
                keys = keys.repeat_interleave(group, dim=1)
                values = values.repeat_interleave(group, dim=1)

            n_new = queries.size(0)
            n_cached = total_k - n_new

            # Query row i sits at logical position n_cached + i and may see
            # keys 0..n_cached + i. True marks a key that participates.
            q_pos = torch.arange(n_new, device=device).unsqueeze(1) + n_cached
            k_pos = torch.arange(total_k, device=device).unsqueeze(0)
            visible = k_pos <= q_pos  # [q_len, k_len]

            # SDPA expects [batch, heads, len, head_dim]; use a batch of one.
            attended = F.scaled_dot_product_attention(
                queries.transpose(0, 1).unsqueeze(0),
                keys.transpose(0, 1).unsqueeze(0),
                values.transpose(0, 1).unsqueeze(0),
                attn_mask=visible.unsqueeze(0).unsqueeze(0),  # [1, 1, q_len, k_len]
                scale=self.scale,
            )  # [1, H, q_len, D]

            out[lo:hi] = attended.squeeze(0).transpose(0, 1)

        flat = out.reshape(num_tokens, n_q * dim)
        return self.o_proj(flat)
class Qwen2DecoderLayer(nn.Module):
    """One pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, cfg: dict, layer_idx: int) -> None:
        super().__init__()
        hidden = cfg["hidden_size"]
        eps = cfg["rms_norm_eps"]

        self.input_layernorm = Qwen2RMSNorm(hidden, eps=eps)
        self.self_attn = Qwen2Attention(
            hidden_size=hidden,
            num_heads=cfg["num_attention_heads"],
            num_kv_heads=cfg["num_key_value_heads"],
            head_dim=cfg["head_dim"],
            layer_idx=layer_idx,
        )
        self.post_attention_layernorm = Qwen2RMSNorm(hidden, eps=eps)
        self.mlp = Qwen2MLP(hidden, cfg["intermediate_size"])

    def forward(self, hidden_states, positions, slot_mapping, cos_table, sin_table, segments, kv_cache):
        """Run the block on the packed batch and return the updated hidden states."""
        # Attention sub-block: normalise, attend, add back onto the input.
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states),
            positions,
            slot_mapping,
            cos_table,
            sin_table,
            segments,
            kv_cache,
        )
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block, same normalise / transform / add pattern.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
class Qwen2Model(nn.Module):
    """Token embedding, a stack of decoder layers, and a final RMSNorm."""

    def __init__(self, cfg: dict) -> None:
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg["vocab_size"], cfg["hidden_size"])
        decoder_stack = [
            Qwen2DecoderLayer(cfg, idx) for idx in range(cfg["num_hidden_layers"])
        ]
        self.layers = nn.ModuleList(decoder_stack)
        self.norm = Qwen2RMSNorm(cfg["hidden_size"], eps=cfg["rms_norm_eps"])

    def forward(self, input_ids, positions, slot_mapping, cos_table, sin_table, segments, kv_cache):
        """Embed ``input_ids``, run every layer in order, return normalised states."""
        hidden = self.embed_tokens(input_ids)
        for decoder in self.layers:
            hidden = decoder(
                hidden, positions, slot_mapping, cos_table, sin_table, segments, kv_cache
            )
        return self.norm(hidden)
class Qwen2ForCausalLM(nn.Module):
    """Decoder backbone plus a bias-free linear head over the vocabulary."""

    def __init__(self, cfg: dict) -> None:
        super().__init__()
        hidden, vocab = cfg["hidden_size"], cfg["vocab_size"]
        self.model = Qwen2Model(cfg)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.cfg = cfg

    def tie_weights(self) -> None:
        """Make the output head share the token-embedding weight tensor."""
        shared = self.model.embed_tokens.weight
        self.lm_head.weight = shared
# ---------------------------------------------------------------------------
# ModelRunner: prepares inputs, runs forward, extracts last-token logits.
# ---------------------------------------------------------------------------
@dataclass
class ModelInput:
    """Everything one forward pass over a packed batch needs."""

    # Token ids of all scheduled sequences, concatenated.
    input_ids: torch.Tensor
    # Position-in-sequence of every packed token.
    positions: torch.Tensor
    # Flat cache slot each token's new K/V is written to.
    slot_mapping: torch.Tensor
    # Per-sequence slices of the packed batch.
    segments: list[AttnSegment]
    # Packed-batch index of the LAST token of each scheduled sequence;
    # sampling reads its logits from these rows.
    last_token_indices: torch.Tensor
class ModelRunner:
    """Holds the model, tokenizer, RoPE tables and paged KV cache for one engine.

    Typical use per step: ``prepare_input`` packs the scheduled sequences
    into flat tensors, ``execute`` runs the forward pass and returns the
    logits of each sequence's last token.
    """

    def __init__(self, config: EngineConfig) -> None:
        """Load tokenizer and weights for ``config.model`` and allocate the cache.

        Raises:
            KeyError: if ``config.dtype`` is not "float32", "float16" or
                "bfloat16".
        """
        from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

        self.config = config
        self.device = torch.device(config.device)
        self.dtype = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[config.dtype]

        hf_cfg = AutoConfig.from_pretrained(
            config.model, trust_remote_code=config.trust_remote_code
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.model, trust_remote_code=config.trust_remote_code
        )

        model_type = getattr(hf_cfg, "model_type", "?")
        if model_type not in ("qwen2", "qwen2_moe", "llama"):
            # Llama-style works too because the math is identical; we issue a
            # warning rather than a hard fail.
            print(f"[tiny_vllm] WARNING: model_type={model_type!r}; expected qwen2-like. "
                  "Continuing — assuming Llama-compatible config.")

        cfg = self._extract_model_cfg(hf_cfg)
        head_dim = cfg["head_dim"]
        self.model_cfg = cfg

        # Build our own model, then copy HF weights into it.
        model = Qwen2ForCausalLM(cfg).to(self.device, self.dtype)
        hf_model = AutoModelForCausalLM.from_pretrained(
            config.model, torch_dtype=self.dtype,
            trust_remote_code=config.trust_remote_code,
        )
        missing, _unexpected = model.load_state_dict(hf_model.state_dict(), strict=False)
        del hf_model

        # strict=False leaves every parameter the checkpoint lacks at its
        # random initialisation; repair what we can and report the rest.
        unresolved = self._repair_missing_weights(
            model, missing, cfg["tie_word_embeddings"]
        )
        if unresolved:
            print(f"[tiny_vllm] WARNING: checkpoint has no weights for {unresolved}; "
                  "these parameters keep their random initialisation.")

        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
        self.model = model

        # Precompute RoPE tables (built in float32, then cast once).
        max_pos = min(cfg["max_position_embeddings"], config.max_model_len)
        cos, sin = self._build_rope_tables(head_dim, cfg["rope_theta"], max_pos)
        self.cos_table = cos.to(self.device, self.dtype)
        self.sin_table = sin.to(self.device, self.dtype)

        # Paged KV cache pool.
        self.kv_cache = PagedKVCache(
            num_layers=cfg["num_hidden_layers"],
            num_blocks=config.num_blocks,
            block_size=config.block_size,
            num_kv_heads=cfg["num_key_value_heads"],
            head_dim=head_dim,
            dtype=self.dtype,
            device=self.device,
        )

    @staticmethod
    def _extract_model_cfg(hf_cfg) -> dict:
        """Flatten the HuggingFace config into the plain dict our modules read.

        ``head_dim`` and ``num_key_value_heads`` fall back to derived values
        when the attribute is absent *or* present but ``None``.
        """
        num_heads = hf_cfg.num_attention_heads
        head_dim = getattr(hf_cfg, "head_dim", None) or hf_cfg.hidden_size // num_heads
        num_kv_heads = getattr(hf_cfg, "num_key_value_heads", None) or num_heads
        return {
            "vocab_size": hf_cfg.vocab_size,
            "hidden_size": hf_cfg.hidden_size,
            "intermediate_size": hf_cfg.intermediate_size,
            "num_hidden_layers": hf_cfg.num_hidden_layers,
            "num_attention_heads": num_heads,
            "num_key_value_heads": num_kv_heads,
            "head_dim": head_dim,
            "rms_norm_eps": getattr(hf_cfg, "rms_norm_eps", 1e-6),
            "rope_theta": getattr(hf_cfg, "rope_theta", 10000.0),
            "max_position_embeddings": getattr(hf_cfg, "max_position_embeddings", 4096),
            "tie_word_embeddings": getattr(hf_cfg, "tie_word_embeddings", False),
        }

    @staticmethod
    def _repair_missing_weights(model, missing_keys, tie_word_embeddings: bool) -> list[str]:
        """Fix up parameters the checkpoint did not supply.

        * With tied embeddings the head always shares the embedding tensor,
          which also avoids keeping a second copy of that matrix.
        * A missing ``*.bias`` is zeroed: a checkpoint trained without the
          bias (e.g. Llama-style q/k/v projections) is equivalent to a zero
          bias, whereas the default random initialisation is not.

        Returns the sorted list of keys that are still unresolved.
        """
        unresolved = set(missing_keys or [])
        if tie_word_embeddings:
            model.tie_weights()
            unresolved.discard("lm_head.weight")

        params = dict(model.named_parameters())
        with torch.no_grad():
            for key in sorted(unresolved):
                if key.endswith(".bias") and key in params:
                    params[key].zero_()
                    unresolved.discard(key)
        return sorted(unresolved)

    @staticmethod
    def _build_rope_tables(
        head_dim: int, rope_theta: float, max_pos: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return float32 ``(cos, sin)`` tables, each of shape [max_pos, head_dim]."""
        exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
        inv_freq = 1.0 / (rope_theta ** exponents)
        t = torch.arange(max_pos, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)          # [max_pos, head_dim/2]
        emb = torch.cat((freqs, freqs), dim=-1)   # [max_pos, head_dim]
        return emb.cos(), emb.sin()

    # ---- input building ------------------------------------------------
    def prepare_input(self, scheduled) -> ModelInput:
        """Pack the scheduled sequences into one flat batch.

        ``scheduled`` is an iterable of items exposing ``.seq`` (the
        sequence) and ``.num_tokens`` (how many tokens to process this
        step, starting at ``seq.num_computed_tokens``).

        Raises:
            ValueError: if an item schedules no tokens, or a token position
                lies beyond the precomputed RoPE table.
        """
        input_ids: list[int] = []
        positions: list[int] = []
        slot_mapping: list[int] = []
        segments: list[AttnSegment] = []
        last_indices: list[int] = []
        cursor = 0
        B = self.config.block_size

        for item in scheduled:
            seq = item.seq
            n = item.num_tokens
            if n <= 0:
                # With no tokens, "last token" would be cursor - 1, i.e. a
                # row belonging to the previous sequence (or -1).
                raise ValueError(
                    f"scheduled item must process at least one token, got {n}"
                )

            # Logical token positions this step processes.
            start_pos = seq.num_computed_tokens
            end_pos = start_pos + n
            max_pos = self.cos_table.size(0)
            if end_pos > max_pos:
                # Fail here with a clear message instead of inside the
                # index_select on the RoPE tables during the forward pass.
                raise ValueError(
                    f"token position {end_pos - 1} exceeds the RoPE table "
                    f"length {max_pos}"
                )

            table = seq.block_table
            step_positions = range(start_pos, end_pos)
            input_ids.extend(seq.get_token(pos) for pos in step_positions)
            positions.extend(step_positions)
            # Flat slot = block id * block size + offset inside the block.
            slot_mapping.extend(
                table[pos // B] * B + (pos % B) for pos in step_positions
            )

            q_end = cursor + n
            segments.append(AttnSegment(
                q_start=cursor,
                q_end=q_end,
                block_table=list(table),
                k_len=end_pos,
            ))
            last_indices.append(q_end - 1)
            cursor = q_end

        return ModelInput(
            input_ids=torch.tensor(input_ids, dtype=torch.long, device=self.device),
            positions=torch.tensor(positions, dtype=torch.long, device=self.device),
            slot_mapping=torch.tensor(slot_mapping, dtype=torch.long, device=self.device),
            segments=segments,
            last_token_indices=torch.tensor(last_indices, dtype=torch.long, device=self.device),
        )

    # ---- forward -------------------------------------------------------
    @torch.inference_mode()
    def execute(self, model_input: ModelInput) -> torch.Tensor:
        """Run one forward pass. Returns logits for the LAST token of each
        scheduled sequence: shape [num_seqs, vocab_size]."""
        hidden = self.model.model(
            input_ids=model_input.input_ids,
            positions=model_input.positions,
            slot_mapping=model_input.slot_mapping,
            cos_table=self.cos_table,
            sin_table=self.sin_table,
            segments=model_input.segments,
            kv_cache=self.kv_cache,
        )  # [T, hidden]
        # Only the last row of each sequence goes through the vocabulary head.
        last_hidden = hidden.index_select(0, model_input.last_token_indices)
        logits = self.model.lm_head(last_hidden)  # [num_seqs, vocab]
        return logits

    # ---- helpers -------------------------------------------------------
    @property
    def eos_token_id(self) -> Optional[int]:
        """The tokenizer's end-of-sequence token id."""
        return self.tokenizer.eos_token_id

    def encode(self, text: str) -> list[int]:
        """Tokenize ``text`` without adding special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, token_ids: list[int]) -> str:
        """Turn token ids back into text, skipping special tokens."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=True)

    def detokenize_incremental(self, full_ids: list[int], prev_text_len: int) -> tuple[str, int]:
        """Detokenize the full list, return the new text added since last call
        and the new total length.

        The whole id list is decoded on every call; ``prev_text_len`` is the
        total length returned by the previous call (0 on the first one).
        """
        text = self.tokenizer.decode(full_ids, skip_special_tokens=True)
        return text[prev_text_len:], len(text)