File size: 16,103 Bytes
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
"""Minimal Qwen2 forward pass that consumes a paged KV cache.

We deliberately re-implement Qwen2 from scratch (rather than using the HF
forward) so the path of K/V tensors through the cache is fully visible.
Weights are loaded from a HuggingFace checkpoint by matching parameter names.

Layout of inputs per step ("varlen" packing):

  input_ids       [T_total]               concatenated tokens for all seqs
  positions       [T_total]               position-in-sequence of each token
  slot_mapping    [T_total]               where to write new K/V in the cache
  segments        list of AttnSegment(q_start, q_end, block_table, k_len)

For attention, we loop over `segments`: gather each sequence's full K/V from
its block table, run SDPA, scatter the result back into a flat buffer.  All
other ops (norms, MLP, projections) run on the full packed tensor.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import EngineConfig
from .paged_kv import PagedKVCache
from .request import Sequence


# ---------------------------------------------------------------------------
# Qwen2 building blocks
# ---------------------------------------------------------------------------


class Qwen2RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel gain.

    The statistics are computed in float32; the result is cast back to the
    dtype the input arrived in.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` ([..., hidden]) over its last dimension."""
        input_dtype = x.dtype
        x_f32 = x.to(torch.float32)
        mean_square = torch.mean(x_f32.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = x_f32 * inv_rms
        return (self.weight * normed).to(input_dtype)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the upper one.

    With the last dimension split at ``size // 2`` into ``[lower, upper]``,
    the result is ``[-upper, lower]``.
    """
    split_at = x.shape[-1] // 2
    lower = x[..., :split_at]
    upper = x[..., split_at:]
    return torch.cat((torch.neg(upper), lower), dim=-1)


def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to ``x``.

    x: [T, H, D], cos/sin: [T, D] → returns [T, H, D].  cos and sin receive a
    singleton axis at position 1 so they broadcast over the H heads.
    """
    cos_b = cos[:, None]
    sin_b = sin[:, None]
    unrotated = x * cos_b
    rotated = _rotate_half(x) * sin_b
    return unrotated + rotated


class Qwen2MLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``.

    All three projections are bias-free linear layers.
    """

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        narrow, wide = hidden_size, intermediate_size
        self.gate_proj = nn.Linear(narrow, wide, bias=False)
        self.up_proj = nn.Linear(narrow, wide, bias=False)
        self.down_proj = nn.Linear(wide, narrow, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the gated MLP on ``x``; the last dimension is preserved."""
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


@dataclass
class AttnSegment:
    """One sequence's slice of the packed batch.

    Fields:
        q_start: start index of the slice in the packed tensor.
        q_end: end index of the slice (exclusive).
        block_table: KV blocks for this sequence.
        k_len: total K length (= num_computed_tokens + q_len).
    """

    q_start: int
    q_end: int
    block_table: list[int]
    k_len: int


class Qwen2Attention(nn.Module):
    """Grouped-query self-attention that writes to and reads from a paged KV cache.

    The Q/K/V projections carry a bias; the output projection does not.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        layer_idx: int,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.layer_idx = layer_idx

        q_width = num_heads * head_dim
        kv_width = num_kv_heads * head_dim
        self.q_proj = nn.Linear(hidden_size, q_width, bias=True)
        self.k_proj = nn.Linear(hidden_size, kv_width, bias=True)
        self.v_proj = nn.Linear(hidden_size, kv_width, bias=True)
        self.o_proj = nn.Linear(q_width, hidden_size, bias=False)
        self.scale = head_dim ** -0.5

    def forward(
        self,
        hidden_states: torch.Tensor,        # [T, hidden]
        positions: torch.Tensor,            # [T] long
        slot_mapping: torch.Tensor,         # [T] long
        cos_table: torch.Tensor,            # [max_pos, head_dim]
        sin_table: torch.Tensor,            # [max_pos, head_dim]
        segments: list[AttnSegment],
        kv_cache: PagedKVCache,
    ) -> torch.Tensor:
        """Attend over the packed batch one segment at a time.

        The freshly projected (and rotated) K/V are written into ``kv_cache``
        first; each segment then gathers its full K/V history back from the
        cache, so new and past tokens follow the same path.  Returns the
        output-projected result, one row per packed token.
        """
        num_tokens = hidden_states.size(0)
        q = self.q_proj(hidden_states).view(num_tokens, self.num_heads, self.head_dim)
        k = self.k_proj(hidden_states).view(num_tokens, self.num_kv_heads, self.head_dim)
        v = self.v_proj(hidden_states).view(num_tokens, self.num_kv_heads, self.head_dim)

        # Look up the rotation for every token's position: [T, head_dim] each.
        cos = cos_table.index_select(0, positions)
        sin = sin_table.index_select(0, positions)
        q = _apply_rope(q, cos, sin)
        k = _apply_rope(k, cos, sin)

        # The new K/V must be in the cache before any segment gathers from it.
        kv_cache.write(self.layer_idx, k, v, slot_mapping)

        attn_out = torch.empty_like(q)  # [T, num_heads, head_dim]
        group = self.num_heads // self.num_kv_heads  # Q heads per KV head (GQA)
        device = q.device

        for seg in segments:
            queries = q[seg.q_start:seg.q_end]  # [q_len, H_q, D]
            keys, values = kv_cache.gather(self.layer_idx, seg.block_table, seg.k_len)
            if group > 1:
                # GQA: repeat each KV head so the head axis matches Q.
                keys = keys.repeat_interleave(group, dim=1)
                values = values.repeat_interleave(group, dim=1)

            q_len = queries.size(0)
            num_past = seg.k_len - q_len

            # Query i sits at logical position num_past + i and may look at
            # keys 0..num_past + i.  True marks an allowed pair (SDPA convention).
            q_pos = torch.arange(num_past, num_past + q_len, device=device).unsqueeze(1)
            k_pos = torch.arange(seg.k_len, device=device).unsqueeze(0)
            causal = k_pos <= q_pos  # [q_len, k_len]

            # SDPA expects [batch, heads, len, head_dim]; add a batch axis of 1.
            result = F.scaled_dot_product_attention(
                queries.transpose(0, 1).unsqueeze(0),  # [1, H, q_len, D]
                keys.transpose(0, 1).unsqueeze(0),     # [1, H, k_len, D]
                values.transpose(0, 1).unsqueeze(0),
                attn_mask=causal[None, None],          # [1, 1, q_len, k_len]
                scale=self.scale,
            )                                          # [1, H, q_len, D]
            attn_out[seg.q_start:seg.q_end] = result.squeeze(0).transpose(0, 1)

        flat = attn_out.reshape(num_tokens, self.num_heads * self.head_dim)
        return self.o_proj(flat)


class Qwen2DecoderLayer(nn.Module):
    """One pre-norm transformer block: attention, then MLP, each with a residual."""

    def __init__(self, cfg: dict, layer_idx: int) -> None:
        super().__init__()
        hidden = cfg["hidden_size"]
        eps = cfg["rms_norm_eps"]

        self.input_layernorm = Qwen2RMSNorm(hidden, eps=eps)
        self.self_attn = Qwen2Attention(
            hidden_size=hidden,
            num_heads=cfg["num_attention_heads"],
            num_kv_heads=cfg["num_key_value_heads"],
            head_dim=cfg["head_dim"],
            layer_idx=layer_idx,
        )
        self.post_attention_layernorm = Qwen2RMSNorm(hidden, eps=eps)
        self.mlp = Qwen2MLP(hidden, cfg["intermediate_size"])

    def forward(self, hidden_states, positions, slot_mapping, cos_table, sin_table, segments, kv_cache):
        """Return ``hidden_states`` after the attention and MLP sub-layers."""
        # Attention sub-layer: normalize, attend, add back onto the input.
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states),
            positions, slot_mapping, cos_table, sin_table, segments, kv_cache,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-layer, same normalize / transform / add pattern.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out


class Qwen2Model(nn.Module):
    """Token embedding, the stack of decoder layers, and the final norm."""

    def __init__(self, cfg: dict) -> None:
        super().__init__()
        self.cfg = cfg
        hidden = cfg["hidden_size"]
        self.embed_tokens = nn.Embedding(cfg["vocab_size"], hidden)
        self.layers = nn.ModuleList(
            Qwen2DecoderLayer(cfg, idx) for idx in range(cfg["num_hidden_layers"])
        )
        self.norm = Qwen2RMSNorm(hidden, eps=cfg["rms_norm_eps"])

    def forward(self, input_ids, positions, slot_mapping, cos_table, sin_table, segments, kv_cache):
        """Embed ``input_ids``, run every layer in order, and normalize the result."""
        # Every layer receives the same per-step arguments; only the hidden
        # state changes from one layer to the next.
        per_step = (positions, slot_mapping, cos_table, sin_table, segments, kv_cache)
        hidden = self.embed_tokens(input_ids)
        for block in self.layers:
            hidden = block(hidden, *per_step)
        return self.norm(hidden)


class Qwen2ForCausalLM(nn.Module):
    """The Qwen2 backbone plus a bias-free projection onto the vocabulary."""

    def __init__(self, cfg: dict) -> None:
        super().__init__()
        self.cfg = cfg
        self.model = Qwen2Model(cfg)
        hidden, vocab = cfg["hidden_size"], cfg["vocab_size"]
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def tie_weights(self) -> None:
        """Make ``lm_head`` share its weight parameter with the token embedding."""
        shared = self.model.embed_tokens.weight
        self.lm_head.weight = shared


# ---------------------------------------------------------------------------
# ModelRunner: prepares inputs, runs forward, extracts last-token logits.
# ---------------------------------------------------------------------------


@dataclass
class ModelInput:
    """Everything one forward step needs, in packed form.

    ``last_token_indices`` holds, for each scheduled sequence, the index in
    the packed batch of its LAST token — the row logits are read from for
    sampling.
    """

    input_ids: torch.Tensor
    positions: torch.Tensor
    slot_mapping: torch.Tensor
    segments: list[AttnSegment]
    last_token_indices: torch.Tensor


class ModelRunner:
    def __init__(self, config: EngineConfig) -> None:
        """Load tokenizer and weights, then build the RoPE tables and KV pool.

        Args:
            config: engine settings.  This constructor reads ``model``,
                ``device``, ``dtype``, ``trust_remote_code``,
                ``max_model_len``, ``num_blocks`` and ``block_size``.

        Raises:
            KeyError: if ``config.dtype`` is not "float32", "float16" or
                "bfloat16".
            RuntimeError: if the checkpoint lacks a weight this model needs.
                Missing biases are zero-filled and a tied ``lm_head`` is
                shared with the embedding, so neither of those counts.
        """
        from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

        self.config = config
        self.device = torch.device(config.device)
        self.dtype = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[config.dtype]

        hf_cfg = AutoConfig.from_pretrained(
            config.model, trust_remote_code=config.trust_remote_code
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.model, trust_remote_code=config.trust_remote_code
        )

        model_type = getattr(hf_cfg, "model_type", "?")
        if model_type not in ("qwen2", "qwen2_moe", "llama"):
            # Other Llama-style checkpoints may still load (absent attention
            # biases are zero-filled below), so warn rather than hard fail.
            print(f"[tiny_vllm] WARNING: model_type={model_type!r}; expected qwen2-like. "
                  "Continuing — assuming Llama-compatible config.")

        # `or` also covers configs that define head_dim but leave it as None.
        head_dim = (
            getattr(hf_cfg, "head_dim", None)
            or hf_cfg.hidden_size // hf_cfg.num_attention_heads
        )
        cfg = {
            "vocab_size": hf_cfg.vocab_size,
            "hidden_size": hf_cfg.hidden_size,
            "intermediate_size": hf_cfg.intermediate_size,
            "num_hidden_layers": hf_cfg.num_hidden_layers,
            "num_attention_heads": hf_cfg.num_attention_heads,
            "num_key_value_heads": getattr(hf_cfg, "num_key_value_heads",
                                            hf_cfg.num_attention_heads),
            "head_dim": head_dim,
            "rms_norm_eps": getattr(hf_cfg, "rms_norm_eps", 1e-6),
            "rope_theta": getattr(hf_cfg, "rope_theta", 10000.0),
            "max_position_embeddings": getattr(hf_cfg, "max_position_embeddings", 4096),
            "tie_word_embeddings": getattr(hf_cfg, "tie_word_embeddings", False),
        }
        self.model_cfg = cfg

        # Build our own model, then copy HF weights into it.
        model = Qwen2ForCausalLM(cfg).to(self.device, self.dtype)
        hf_model = AutoModelForCausalLM.from_pretrained(
            config.model, torch_dtype=self.dtype,
            trust_remote_code=config.trust_remote_code,
        )
        load_result = model.load_state_dict(hf_model.state_dict(), strict=False)
        del hf_model

        # strict=False leaves every parameter the checkpoint did not provide
        # at its random initial value, so each missing key must be accounted
        # for explicitly or the model would silently produce garbage.
        missing = set(load_result.missing_keys)

        if cfg["tie_word_embeddings"]:
            # Share the embedding matrix even when the checkpoint also carried
            # lm_head.weight: the two are identical for a tied model, and
            # sharing avoids keeping a second vocab x hidden copy.
            model.tie_weights()
            missing.discard("lm_head.weight")

        # The only biases in this model are on q/k/v.  A checkpoint without
        # them (Llama-style attention) is equivalent to zero biases.
        absent_biases = {name for name in missing if name.endswith(".bias")}
        if absent_biases:
            params = dict(model.named_parameters())
            with torch.no_grad():
                for name in absent_biases:
                    params[name].zero_()
            missing -= absent_biases

        if missing:
            raise RuntimeError(
                f"[tiny_vllm] checkpoint {config.model!r} does not provide "
                f"{len(missing)} required weight(s), e.g. {sorted(missing)[:5]}"
            )
        unexpected = list(load_result.unexpected_keys)
        if unexpected:
            print(f"[tiny_vllm] WARNING: ignoring {len(unexpected)} checkpoint "
                  f"tensor(s) this model has no slot for, e.g. {unexpected[:5]}")

        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
        self.model = model

        # Precompute RoPE tables.
        max_pos = min(cfg["max_position_embeddings"], config.max_model_len)
        inv_freq = 1.0 / (
            cfg["rope_theta"]
            ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
        )
        t = torch.arange(max_pos, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)        # [max_pos, head_dim/2]
        emb = torch.cat((freqs, freqs), dim=-1)  # [max_pos, head_dim]
        self.cos_table = emb.cos().to(self.device, self.dtype)
        self.sin_table = emb.sin().to(self.device, self.dtype)

        # Paged KV cache pool.
        self.kv_cache = PagedKVCache(
            num_layers=cfg["num_hidden_layers"],
            num_blocks=config.num_blocks,
            block_size=config.block_size,
            num_kv_heads=cfg["num_key_value_heads"],
            head_dim=head_dim,
            dtype=self.dtype,
            device=self.device,
        )

    # ---- input building ------------------------------------------------

    def prepare_input(self, scheduled) -> ModelInput:
        """Pack the scheduled sequences into one flat ("varlen") batch.

        Args:
            scheduled: the scheduler's (Sequence, num_tokens, is_prefill)
                records.  Only ``item.seq`` and ``item.num_tokens`` are read
                here; from the sequence this uses ``num_computed_tokens``,
                ``block_table`` and ``get_token(pos)``.

        Returns:
            A ModelInput whose tensors are ``torch.long`` on ``self.device``,
            with one AttnSegment and one last-token index per scheduled item,
            in the order the items were given.

        Raises:
            ValueError: if an item schedules fewer than one token.  Such an
                item has no last token, and ``q_end - 1`` would point at the
                previous sequence's final token (or at -1) instead.
        """
        input_ids: list[int] = []
        positions: list[int] = []
        slot_mapping: list[int] = []
        segments: list[AttnSegment] = []
        last_indices: list[int] = []

        block_size = self.config.block_size
        cursor = 0  # write position in the packed batch
        for item in scheduled:
            seq = item.seq
            n = item.num_tokens
            if n < 1:
                raise ValueError(
                    f"[tiny_vllm] a scheduled item must cover at least one token, got {n}"
                )

            # Logical token positions this step processes.
            start_pos = seq.num_computed_tokens
            end_pos = start_pos + n
            block_table = seq.block_table
            for pos in range(start_pos, end_pos):
                input_ids.append(seq.get_token(pos))
                positions.append(pos)
                # Physical slot = owning block's base offset + offset in block.
                block_idx, offset = divmod(pos, block_size)
                slot_mapping.append(block_table[block_idx] * block_size + offset)

            q_end = cursor + n
            segments.append(AttnSegment(
                q_start=cursor,
                q_end=q_end,
                block_table=list(block_table),  # copy: detach from the live table
                k_len=end_pos,
            ))
            last_indices.append(q_end - 1)
            cursor = q_end

        return ModelInput(
            input_ids=torch.tensor(input_ids, dtype=torch.long, device=self.device),
            positions=torch.tensor(positions, dtype=torch.long, device=self.device),
            slot_mapping=torch.tensor(slot_mapping, dtype=torch.long, device=self.device),
            segments=segments,
            last_token_indices=torch.tensor(last_indices, dtype=torch.long, device=self.device),
        )

    # ---- forward -------------------------------------------------------

    @torch.inference_mode()
    def execute(self, model_input: ModelInput) -> torch.Tensor:
        """Run one forward pass.  Returns logits for the LAST token of each
        scheduled sequence: shape [num_seqs, vocab_size]."""
        hidden = self.model.model(
            input_ids=model_input.input_ids,
            positions=model_input.positions,
            slot_mapping=model_input.slot_mapping,
            cos_table=self.cos_table,
            sin_table=self.sin_table,
            segments=model_input.segments,
            kv_cache=self.kv_cache,
        )                                                # [T, hidden]
        last_hidden = hidden.index_select(0, model_input.last_token_indices)
        logits = self.model.lm_head(last_hidden)         # [num_seqs, vocab]
        return logits

    # ---- helpers -------------------------------------------------------

    @property
    def eos_token_id(self) -> Optional[int]:
        return self.tokenizer.eos_token_id

    def encode(self, text: str) -> list[int]:
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, token_ids: list[int]) -> str:
        return self.tokenizer.decode(token_ids, skip_special_tokens=True)

    def detokenize_incremental(self, full_ids: list[int], prev_text_len: int) -> tuple[str, int]:
        """Detokenize the full list, return the new text added since last call
        and the new total length."""
        text = self.tokenizer.decode(full_ids, skip_special_tokens=True)
        return text[prev_text_len:], len(text)