| import warnings |
| from typing import Any, Callable, Optional, Union |
|
|
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| |
|
|
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.generation import GenerationMixin |
| from transformers.integrations import use_kernel_forward_from_hub |
| from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.modeling_layers import ( |
| GenericForQuestionAnswering, |
| GenericForSequenceClassification, |
| GenericForTokenClassification, |
| GradientCheckpointingLayer, |
| ) |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple |
| from transformers.utils.deprecation import deprecate_kwarg |
| from transformers.utils.generic import check_model_inputs |
| from transformers.models.qwen3.configuration_qwen3 import Qwen3Config |
| |
| from .summary_context import SummaryBatchContext, build_summary_context, build_summary_sliding_context |
| from summary_attn import summary_attn_func |
|
|
|
|
|
|
| def _parse_config_pattern(val): |
| """Parse a config value that may be an int, list, or Python pattern string like '([4096]*1+[128]*3)*9'.""" |
| if isinstance(val, list): |
| return val |
| if isinstance(val, str): |
| # The pattern string comes from a trusted model config, so eval() is intentional here. |
| return eval(val) |
| return val |
|
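| # Worked example (illustrative, using the pattern quoted in the docstring above): |
| # _parse_config_pattern("([4096]*1+[128]*3)*9") evaluates to a 36-element per-layer list |
| # [4096, 128, 128, 128, 4096, 128, ...], i.e. one large-window layer followed by three |
| # small-window layers, repeated nine times. |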
|
|
|
| @use_kernel_forward_from_hub("RMSNorm") |
| class Qwen3RMSNorm(nn.Module): |
| def __init__(self, hidden_size, eps: float = 1e-6) -> None: |
| """ |
| Qwen3RMSNorm is equivalent to T5LayerNorm |
| """ |
| super().__init__() |
| self.weight = nn.Parameter(torch.ones(hidden_size)) |
| self.variance_epsilon = eps |
|
|
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| input_dtype = hidden_states.dtype |
| hidden_states = hidden_states.to(torch.float32) |
| variance = hidden_states.pow(2).mean(-1, keepdim=True) |
| hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
| return self.weight * hidden_states.to(input_dtype) |
|
|
| def extra_repr(self): |
| return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" |
|
|
|
|
| class Qwen3RingBufferCache: |
| """ |
| Ring buffer KV cache with summary support. |
| |
| Two strategies based on per-layer sliding_chunk_num: |
| - Large window layers (is_large_window=True): append-only buffer storing only text KV. |
| Summary KV is NOT stored since text tokens attend to all text KV directly. |
| - Small window layers (is_large_window=False): |
| Three buffers: |
| 1. key_cache: [ring(ws) | old_summaries(growing) | chunk_mirror(≤C)] |
| → attention input, steady state is a single contiguous slice |
| 2. new_summary_buf: ring buffer of size scn (sliding_chunk_num); stores summaries whose |
| text is still inside the window and therefore not yet needed for attention |
| 3. chunk_buf: size C, holds current chunk's text KV |
| |
| RoPE position information is baked into KV, so physical order doesn't matter. |
| """ |
|
|
| is_compileable = False |
| _SUMMARY_INIT_CAP = 512 |
| _APPEND_HEADROOM = 1024 |
|
|
| def __init__(self, config: Qwen3Config, sliding_chunk_nums: list[int]): |
| super().__init__() |
| self.summary_chunk_size = getattr(config, "summary_chunk_size", 0) |
| self.summary_token_num = getattr(config, "summary_token_num", 0) |
| self.num_hidden_layers = config.num_hidden_layers |
|
|
| self.sliding_chunk_nums = sliding_chunk_nums |
| large_window_threshold = min(sliding_chunk_nums) * self.summary_chunk_size |
| self.is_large_window = [sv * self.summary_chunk_size > large_window_threshold for sv in sliding_chunk_nums] |
| self.window_sizes = [sv * self.summary_chunk_size for sv in sliding_chunk_nums] |
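| # Worked example (values borrowed from the pattern in _parse_config_pattern's docstring, |
| # with an assumed C = 128): sliding_chunk_nums = [4096, 128, 128, 128, ...] gives a |
| # threshold of 128 * 128 = 16384, so the 4096-chunk layers (window 524288) are treated as |
| # large-window and the 128-chunk layers (window 16384) as small-window ring buffers. |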
|
|
| self.key_cache = [None for _ in range(config.num_hidden_layers)] |
| self.value_cache = [None for _ in range(config.num_hidden_layers)] |
|
|
| # Append-only state used during prefill and by large-window layers |
| self._text_len = [0] * config.num_hidden_layers |
| self._capacity = [0] * config.num_hidden_layers |
|
|
| # Sliding-window ring state for small-window layers |
| self._window_write_ptr = [0] * config.num_hidden_layers |
| self._n_valid_window = [0] * config.num_hidden_layers |
| self._old_summary_len = [0] * config.num_hidden_layers |
| self._old_summary_cap = [0] * config.num_hidden_layers |
|
|
| # Ring buffer of summaries whose source text is still inside the window |
| self._new_sum_key_buf = [None for _ in range(config.num_hidden_layers)] |
| self._new_sum_value_buf = [None for _ in range(config.num_hidden_layers)] |
| self._new_sum_len = [0] * config.num_hidden_layers |
| self._new_sum_write_ptr = [0] * config.num_hidden_layers |
|
|
| # Buffer holding the current (partial) chunk's text KV |
| self._chunk_key_buf = [None for _ in range(config.num_hidden_layers)] |
| self._chunk_value_buf = [None for _ in range(config.num_hidden_layers)] |
| self._chunk_buf_len = [0] * config.num_hidden_layers |
|
|
| # Per-layer chunk progress counters |
| self.cur_chunk_sizes = [0] * config.num_hidden_layers |
| self.true_tokens = [0] * config.num_hidden_layers |
| self._total_chunks = [0] * config.num_hidden_layers |
| self._reorganized = False |
|
|
| def __len__(self): |
| return self.num_hidden_layers |
|
|
| def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
| """Returns nonzero when cache is populated (used to detect prefill vs decode).""" |
| if layer_idx >= self.num_hidden_layers: |
| return 0 |
| if self.is_large_window[layer_idx]: |
| return self._text_len[layer_idx] |
| else: |
| return (self._n_valid_window[layer_idx] + self._chunk_buf_len[layer_idx] |
| + self._old_summary_len[layer_idx] + self._new_sum_len[layer_idx]) |
|
|
| def get_cur_chunk_size(self, layer_idx: Optional[int] = None) -> int: |
| if layer_idx is None: |
| layer_idx = self.num_hidden_layers - 1 |
| return self.cur_chunk_sizes[layer_idx] |
|
|
| def get_true_token_num(self, layer_idx: Optional[int] = None) -> int: |
| if layer_idx is None: |
| layer_idx = self.num_hidden_layers - 1 |
| return self.true_tokens[layer_idx] |
|
|
| |
|
|
| def update( |
| self, |
| key_states: torch.Tensor, |
| value_states: torch.Tensor, |
| layer_idx: int, |
| cache_kwargs: Optional[dict[str, Any]] = None, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """Append KV during prefill (before reorganize). Returns full KV for prefill attention.""" |
| add_len = key_states.shape[-2] |
| cur_len = self._text_len[layer_idx] |
| new_len = cur_len + add_len |
|
|
| if self.key_cache[layer_idx] is None: |
| cap = new_len + self._APPEND_HEADROOM |
| bsz, heads, _, head_dim = key_states.shape |
| self.key_cache[layer_idx] = torch.empty( |
| bsz, heads, cap, head_dim, dtype=key_states.dtype, device=key_states.device) |
| self.value_cache[layer_idx] = torch.empty( |
| bsz, heads, cap, head_dim, dtype=value_states.dtype, device=value_states.device) |
| self._capacity[layer_idx] = cap |
| elif new_len > self._capacity[layer_idx]: |
| cap = max(new_len + self._APPEND_HEADROOM, self._capacity[layer_idx] * 2) |
| old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx] |
| bsz, heads, _, head_dim = old_k.shape |
| new_k = torch.empty(bsz, heads, cap, head_dim, dtype=old_k.dtype, device=old_k.device) |
| new_v = torch.empty(bsz, heads, cap, head_dim, dtype=old_v.dtype, device=old_v.device) |
| new_k[:, :, :cur_len, :].copy_(old_k[:, :, :cur_len, :]) |
| new_v[:, :, :cur_len, :].copy_(old_v[:, :, :cur_len, :]) |
| self.key_cache[layer_idx] = new_k |
| self.value_cache[layer_idx] = new_v |
| self._capacity[layer_idx] = cap |
|
|
| self.key_cache[layer_idx][:, :, cur_len:new_len, :].copy_(key_states) |
| self.value_cache[layer_idx][:, :, cur_len:new_len, :].copy_(value_states) |
| self._text_len[layer_idx] = new_len |
|
|
| if self.summary_chunk_size > 0: |
| if cache_kwargs and 'summary_mask' in cache_kwargs: |
| text_count = add_len - cache_kwargs['summary_mask'][0].sum().item() |
| else: |
| text_count = add_len |
| self.cur_chunk_sizes[layer_idx] += add_len |
| self.true_tokens[layer_idx] += text_count |
|
|
| return self.key_cache[layer_idx][:, :, :new_len, :], self.value_cache[layer_idx][:, :, :new_len, :] |
|
|
| |
|
|
| def reorganize_after_prefill(self, summary_mask: torch.Tensor): |
| """Reorganize all layers from prefill block layout to ring buffer layout.""" |
| if self._reorganized: |
| return |
| self._reorganized = True |
|
|
| text_mask = ~summary_mask[0] |
|
|
| for layer_idx in range(self.num_hidden_layers): |
| prefill_len = self._text_len[layer_idx] |
| prefill_k = self.key_cache[layer_idx][:, :, :prefill_len, :] |
| prefill_v = self.value_cache[layer_idx][:, :, :prefill_len, :] |
| bsz, heads, _, head_dim = prefill_k.shape |
| device, dtype = prefill_k.device, prefill_k.dtype |
|
|
| text_k = prefill_k[:, :, text_mask, :] |
| text_v = prefill_v[:, :, text_mask, :] |
| n_text = text_k.shape[2] |
|
|
| if self.is_large_window[layer_idx]: |
| |
| cap = n_text + self._APPEND_HEADROOM |
| new_k = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device) |
| new_v = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device) |
| new_k[:, :, :n_text, :].copy_(text_k) |
| new_v[:, :, :n_text, :].copy_(text_v) |
| self.key_cache[layer_idx] = new_k |
| self.value_cache[layer_idx] = new_v |
| self._text_len[layer_idx] = n_text |
| self._capacity[layer_idx] = cap |
| else: |
| |
| all_summary_k = prefill_k[:, :, summary_mask[0], :] |
| all_summary_v = prefill_v[:, :, summary_mask[0], :] |
| n_summary = all_summary_k.shape[2] |
|
|
| C = self.summary_chunk_size |
| ws = self.window_sizes[layer_idx] |
| scn = self.sliding_chunk_nums[layer_idx] |
|
|
| # Split the text into complete chunks and a trailing partial chunk |
| n_complete_chunks = n_text // C |
| n_partial = n_text % C |
| n_complete_text = n_complete_chunks * C |
|
|
| # The most recent scn complete chunks stay inside the sliding window |
| n_window_chunks = min(scn, n_complete_chunks) |
| n_window_text = n_window_chunks * C |
| window_start = n_complete_text - n_window_text |
|
|
| # Summaries older than the window go to the old-summary region; the rest stay "new" |
| n_old = max(0, n_summary - n_window_chunks) |
| n_new = n_summary - n_old |
|
|
| # Allocate the combined buffer: [ring(ws) | old_summaries | chunk_mirror(C)] |
| old_s_cap = max(self._SUMMARY_INIT_CAP, (n_old + 1) * 2) |
| total_cap = ws + old_s_cap + C |
| new_k = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device) |
| new_v = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device) |
|
|
| if n_window_text > 0: |
| new_k[:, :, :n_window_text, :].copy_(text_k[:, :, window_start:n_complete_text, :]) |
| new_v[:, :, :n_window_text, :].copy_(text_v[:, :, window_start:n_complete_text, :]) |
| self._n_valid_window[layer_idx] = n_window_text |
| self._window_write_ptr[layer_idx] = n_window_text % ws |
|
|
| |
| if n_old > 0: |
| new_k[:, :, ws:ws + n_old, :].copy_(all_summary_k[:, :, :n_old, :]) |
| new_v[:, :, ws:ws + n_old, :].copy_(all_summary_v[:, :, :n_old, :]) |
| self._old_summary_len[layer_idx] = n_old |
| self._old_summary_cap[layer_idx] = old_s_cap |
|
|
| |
| if n_partial > 0: |
| mirror_start = ws + n_old |
| new_k[:, :, mirror_start:mirror_start + n_partial, :].copy_( |
| text_k[:, :, n_complete_text:, :]) |
| new_v[:, :, mirror_start:mirror_start + n_partial, :].copy_( |
| text_v[:, :, n_complete_text:, :]) |
|
|
| self.key_cache[layer_idx] = new_k |
| self.value_cache[layer_idx] = new_v |
| self._capacity[layer_idx] = total_cap |
| self._text_len[layer_idx] = 0 |
|
|
| # Ring buffer for "new" summaries (their text is still inside the window) |
| ns_buf_k = torch.empty(bsz, heads, scn, head_dim, dtype=dtype, device=device) |
| ns_buf_v = torch.empty(bsz, heads, scn, head_dim, dtype=dtype, device=device) |
| if n_new > 0: |
| ns_buf_k[:, :, :n_new, :].copy_(all_summary_k[:, :, n_old:, :]) |
| ns_buf_v[:, :, :n_new, :].copy_(all_summary_v[:, :, n_old:, :]) |
| self._new_sum_key_buf[layer_idx] = ns_buf_k |
| self._new_sum_value_buf[layer_idx] = ns_buf_v |
| self._new_sum_len[layer_idx] = n_new |
| self._new_sum_write_ptr[layer_idx] = n_new % scn |
|
|
| # Separate buffer mirroring the current partial chunk's text KV |
| chunk_buf_k = torch.empty(bsz, heads, C, head_dim, dtype=dtype, device=device) |
| chunk_buf_v = torch.empty(bsz, heads, C, head_dim, dtype=dtype, device=device) |
| if n_partial > 0: |
| chunk_buf_k[:, :, :n_partial, :].copy_(text_k[:, :, n_complete_text:, :]) |
| chunk_buf_v[:, :, :n_partial, :].copy_(text_v[:, :, n_complete_text:, :]) |
| self._chunk_key_buf[layer_idx] = chunk_buf_k |
| self._chunk_value_buf[layer_idx] = chunk_buf_v |
| self._chunk_buf_len[layer_idx] = n_partial |
|
|
| block = self.summary_chunk_size + self.summary_token_num |
| for layer_idx in range(self.num_hidden_layers): |
| self.cur_chunk_sizes[layer_idx] = self.cur_chunk_sizes[layer_idx] % block |
| self._total_chunks[layer_idx] = ( |
| self._old_summary_len[layer_idx] + self._new_sum_len[layer_idx] |
| if not self.is_large_window[layer_idx] |
| else (self.true_tokens[layer_idx] // self.summary_chunk_size) |
| ) |
|
|
| |
|
|
| def update_text(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int): |
| """Write a single text token KV during decode.""" |
| if self.is_large_window[layer_idx]: |
| cur = self._text_len[layer_idx] |
| new_len = cur + 1 |
| if new_len > self._capacity[layer_idx]: |
| cap = max(new_len + self._APPEND_HEADROOM, self._capacity[layer_idx] * 2) |
| old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx] |
| bsz, heads, _, head_dim = old_k.shape |
| new_k = torch.empty(bsz, heads, cap, head_dim, dtype=old_k.dtype, device=old_k.device) |
| new_v = torch.empty(bsz, heads, cap, head_dim, dtype=old_v.dtype, device=old_v.device) |
| new_k[:, :, :cur, :].copy_(old_k[:, :, :cur, :]) |
| new_v[:, :, :cur, :].copy_(old_v[:, :, :cur, :]) |
| self.key_cache[layer_idx] = new_k |
| self.value_cache[layer_idx] = new_v |
| self._capacity[layer_idx] = cap |
| self.key_cache[layer_idx][:, :, cur:new_len, :].copy_(key_states) |
| self.value_cache[layer_idx][:, :, cur:new_len, :].copy_(value_states) |
| self._text_len[layer_idx] = new_len |
| else: |
| |
| ws = self.window_sizes[layer_idx] |
| n_old = self._old_summary_len[layer_idx] |
| pos = self._chunk_buf_len[layer_idx] |
| mirror_pos = ws + n_old + pos |
| self.key_cache[layer_idx][:, :, mirror_pos:mirror_pos+1, :].copy_(key_states) |
| self.value_cache[layer_idx][:, :, mirror_pos:mirror_pos+1, :].copy_(value_states) |
| self._chunk_buf_len[layer_idx] = pos + 1 |
|
|
| self.cur_chunk_sizes[layer_idx] += 1 |
| self.true_tokens[layer_idx] += 1 |
|
|
| |
|
|
| def update_summary(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int): |
| """Write summary token KV during decode (chunk boundary). |
| |
| Large window: skip. |
| Small window (order matters — flush mirror before evict to avoid clobbering): |
| 1. Flush mirror region → ring |
| 2. Evict oldest new_summary → old_summary in key_cache (if full) |
| 3. Write new summary → new_summary_buf |
| """ |
| n_summary = key_states.shape[2] |
|
|
| if self.is_large_window[layer_idx]: |
| self.cur_chunk_sizes[layer_idx] += n_summary |
| self._total_chunks[layer_idx] += n_summary |
| return |
|
|
| C = self.summary_chunk_size |
| ws = self.window_sizes[layer_idx] |
| scn = self.sliding_chunk_nums[layer_idx] |
| cbl = self._chunk_buf_len[layer_idx] |
| ptr = self._window_write_ptr[layer_idx] |
| n_old = self._old_summary_len[layer_idx] |
|
|
| # 1. Flush the chunk mirror region into the ring window |
| mirror_start = ws + n_old |
| if ptr + cbl <= ws: |
| self.key_cache[layer_idx][:, :, ptr:ptr + cbl, :].copy_( |
| self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :]) |
| self.value_cache[layer_idx][:, :, ptr:ptr + cbl, :].copy_( |
| self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :]) |
| else: |
| first = ws - ptr |
| self.key_cache[layer_idx][:, :, ptr:ws, :].copy_( |
| self.key_cache[layer_idx][:, :, mirror_start:mirror_start + first, :]) |
| self.value_cache[layer_idx][:, :, ptr:ws, :].copy_( |
| self.value_cache[layer_idx][:, :, mirror_start:mirror_start + first, :]) |
| rest = cbl - first |
| self.key_cache[layer_idx][:, :, :rest, :].copy_( |
| self.key_cache[layer_idx][:, :, mirror_start + first:mirror_start + cbl, :]) |
| self.value_cache[layer_idx][:, :, :rest, :].copy_( |
| self.value_cache[layer_idx][:, :, mirror_start + first:mirror_start + cbl, :]) |
|
|
| self._window_write_ptr[layer_idx] = (ptr + cbl) % ws |
| if self._n_valid_window[layer_idx] < ws: |
| self._n_valid_window[layer_idx] = min(ws, self._n_valid_window[layer_idx] + cbl) |
| self._chunk_buf_len[layer_idx] = 0 |
|
|
| # 2. If the new-summary ring is full, evict its oldest entry into the old-summary region |
| if self._new_sum_len[layer_idx] >= scn: |
| read_ptr = self._new_sum_write_ptr[layer_idx] |
| old_dst = ws + n_old |
|
|
| # Grow key_cache if the old-summary region is out of capacity |
| needed = old_dst + 1 + C |
| if needed > self._capacity[layer_idx]: |
| new_s_cap = max(self._old_summary_cap[layer_idx] * 2, n_old + self._SUMMARY_INIT_CAP) |
| new_total = ws + new_s_cap + C |
| old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx] |
| bsz, heads, _, head_dim = old_k.shape |
| nk = torch.empty(bsz, heads, new_total, head_dim, dtype=old_k.dtype, device=old_k.device) |
| nv = torch.empty(bsz, heads, new_total, head_dim, dtype=old_v.dtype, device=old_v.device) |
| copy_len = ws + n_old |
| nk[:, :, :copy_len, :].copy_(old_k[:, :, :copy_len, :]) |
| nv[:, :, :copy_len, :].copy_(old_v[:, :, :copy_len, :]) |
| self.key_cache[layer_idx] = nk |
| self.value_cache[layer_idx] = nv |
| self._old_summary_cap[layer_idx] = new_s_cap |
| self._capacity[layer_idx] = new_total |
|
|
| self.key_cache[layer_idx][:, :, old_dst:old_dst+1, :].copy_( |
| self._new_sum_key_buf[layer_idx][:, :, read_ptr:read_ptr+1, :]) |
| self.value_cache[layer_idx][:, :, old_dst:old_dst+1, :].copy_( |
| self._new_sum_value_buf[layer_idx][:, :, read_ptr:read_ptr+1, :]) |
| self._old_summary_len[layer_idx] += 1 |
|
|
| # 3. Write the new summary into new_summary_buf |
| w_ptr = self._new_sum_write_ptr[layer_idx] |
| self._new_sum_key_buf[layer_idx][:, :, w_ptr:w_ptr+1, :].copy_(key_states) |
| self._new_sum_value_buf[layer_idx][:, :, w_ptr:w_ptr+1, :].copy_(value_states) |
| self._new_sum_write_ptr[layer_idx] = (w_ptr + 1) % scn |
| if self._new_sum_len[layer_idx] < scn: |
| self._new_sum_len[layer_idx] += 1 |
|
|
| self.cur_chunk_sizes[layer_idx] += n_summary |
| self._total_chunks[layer_idx] += n_summary |
|
|
| |
|
|
| def get_attention_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: |
| """Get full KV for text token attention. |
| |
| Large window: buffer[:text_len] |
| Small window (steady state): key_cache[:ws + n_old + cbl] — single slice, zero cat |
| """ |
| if self.is_large_window[layer_idx]: |
| tl = self._text_len[layer_idx] |
| return (self.key_cache[layer_idx][:, :, :tl, :], |
| self.value_cache[layer_idx][:, :, :tl, :]) |
|
|
| ws = self.window_sizes[layer_idx] |
| nv = self._n_valid_window[layer_idx] |
| cbl = self._chunk_buf_len[layer_idx] |
| n_old = self._old_summary_len[layer_idx] |
|
|
| # Steady state: the window is full, so [ring | old summaries | mirror] is one contiguous slice |
| if nv == ws: |
| end = ws + n_old + cbl |
| return (self.key_cache[layer_idx][:, :, :end, :], |
| self.value_cache[layer_idx][:, :, :end, :]) |
|
|
| # Warm-up: the window is not full yet, so concatenate the valid pieces |
| parts_k, parts_v = [], [] |
| if nv > 0: |
| parts_k.append(self.key_cache[layer_idx][:, :, :nv, :]) |
| parts_v.append(self.value_cache[layer_idx][:, :, :nv, :]) |
| if cbl > 0: |
| mirror_start = ws + n_old |
| parts_k.append(self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :]) |
| parts_v.append(self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :]) |
| if n_old > 0: |
| parts_k.append(self.key_cache[layer_idx][:, :, ws:ws + n_old, :]) |
| parts_v.append(self.value_cache[layer_idx][:, :, ws:ws + n_old, :]) |
| if len(parts_k) == 1: |
| return parts_k[0], parts_v[0] |
| return torch.cat(parts_k, dim=2), torch.cat(parts_v, dim=2) |
|
|
| def get_current_chunk_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: |
| """Get KV of the current chunk's C text tokens for summary token attention.""" |
| C = self.summary_chunk_size |
| if self.is_large_window[layer_idx]: |
| tl = self._text_len[layer_idx] |
| return (self.key_cache[layer_idx][:, :, tl - C:tl, :], |
| self.value_cache[layer_idx][:, :, tl - C:tl, :]) |
| else: |
| ws = self.window_sizes[layer_idx] |
| n_old = self._old_summary_len[layer_idx] |
| cbl = self._chunk_buf_len[layer_idx] |
| mirror_start = ws + n_old |
| return (self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :], |
| self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :]) |
|
|
| def reset_chunk_counter(self): |
| """Reset chunk counters after a chunk boundary step completes.""" |
| block = self.summary_chunk_size + self.summary_token_num |
| for layer_idx in range(self.num_hidden_layers): |
| if self.cur_chunk_sizes[layer_idx] >= block: |
| self.cur_chunk_sizes[layer_idx] %= block |
|
|
|
|
| class Qwen3MLP(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.config = config |
| self.hidden_size = config.hidden_size |
| self.intermediate_size = config.intermediate_size |
| self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
| self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
| self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
| self.act_fn = ACT2FN[config.hidden_act] |
|
|
| def forward(self, x): |
| down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
| return down_proj |
|
|
|
|
| def rotate_half(x): |
| """Rotates half the hidden dims of the input.""" |
| x1 = x[..., : x.shape[-1] // 2] |
| x2 = x[..., x.shape[-1] // 2 :] |
| return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| Args: |
| q (`torch.Tensor`): The query tensor. |
| k (`torch.Tensor`): The key tensor. |
| cos (`torch.Tensor`): The cosine part of the rotary embedding. |
| sin (`torch.Tensor`): The sine part of the rotary embedding. |
| position_ids (`torch.Tensor`, *optional*): |
| Deprecated and unused. |
| unsqueeze_dim (`int`, *optional*, defaults to 1): |
| The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
| sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
| that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
| k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
| cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
| the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
| Returns: |
| `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
| """ |
| cos = cos.unsqueeze(unsqueeze_dim) |
| sin = sin.unsqueeze(unsqueeze_dim) |
| q_embed = (q * cos) + (rotate_half(q) * sin) |
| k_embed = (k * cos) + (rotate_half(k) * sin) |
| return q_embed, k_embed |
|
|
|
|
| def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
| """ |
| This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
| num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
| """ |
| batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
| if n_rep == 1: |
| return hidden_states |
| hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
| return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
|
|
|
|
| def _sdpa_attention_forward( |
| module: nn.Module, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attention_mask: Optional[torch.Tensor], |
| scaling: float, |
| dropout: float = 0.0, |
| **kwargs: Unpack[TransformersKwargs], |
| ): |
| key_states = repeat_kv(key, module.num_key_value_groups) |
| value_states = repeat_kv(value, module.num_key_value_groups) |
| attn_output = F.scaled_dot_product_attention( |
| query, |
| key_states, |
| value_states, |
| # attention_mask is None on the decode path, where the ring-buffer cache already restricts the KV. |
| attn_mask=attention_mask, |
| dropout_p=dropout, |
| scale=scaling, |
| is_causal=False, |
| ) |
| attn_output = attn_output.transpose(1, 2).contiguous() |
| return attn_output, None |
| |
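| # Shape sketch (illustrative head counts): with num_attention_heads = 32 and |
| # num_key_value_heads = 8, num_key_value_groups = 4, so repeat_kv() expands K/V from |
| # (b, 8, s, d) to (b, 32, s, d) before F.scaled_dot_product_attention is called. |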
|
|
|
|
| class Qwen3Attention(nn.Module): |
| """Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
| def __init__(self, config: Qwen3Config, layer_idx: int): |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) |
| self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads |
| self.scaling = self.head_dim**-0.5 |
| self.attention_dropout = config.attention_dropout |
|
|
| self.q_proj = nn.Linear( |
| config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.k_proj = nn.Linear( |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.v_proj = nn.Linear( |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) |
| self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) |
| self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None |
|
|
| @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| attention_mask: Optional[torch.Tensor], |
| past_key_values: Optional[Cache] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| **kwargs: Unpack[FlashAttentionKwargs], |
| ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
| input_shape = hidden_states.shape[:-1] |
| hidden_shape = (*input_shape, -1, self.head_dim) |
|
|
|
|
| query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) |
| key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) |
| value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
|
|
| cos, sin = position_embeddings |
| query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
|
|
| if past_key_values is not None: |
| |
| cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
| key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
| attn_output, attn_weights = _sdpa_attention_forward( |
| self, |
| query_states, |
| key_states, |
| value_states, |
| attention_mask, |
| dropout=0.0 if not self.training else self.attention_dropout, |
| scaling=self.scaling, |
| **kwargs, |
| ) |
|
|
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
| attn_output = self.o_proj(attn_output) |
| return attn_output, attn_weights |
|
|
|
|
| class Qwen3SummaryAttention(Qwen3Attention): |
| """ |
| Summary-aware variant of Qwen3Attention: uses a sliding summary mask. |
| """ |
|
|
| def __init__(self, config: Qwen3Config, layer_idx: int): |
| super().__init__(config, layer_idx) |
| self.summary_chunk_size = getattr(self.config, "summary_chunk_size", 0) |
| self.summary_token_num = getattr(self.config, "summary_token_num", 0) |
|
|
| |
| val = getattr(config, "summary_sliding_chunk_num", 0) or 0 |
| val = _parse_config_pattern(val) |
| if isinstance(val, list): |
| self._sliding_chunk_num = val[layer_idx] |
| else: |
| self._sliding_chunk_num = int(val) |
|
|
| if config.summary_independent_parameters and config.mix_coeff > 0: |
| self.q_proj_summary = nn.Linear( |
| config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.k_proj_summary = nn.Linear( |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.v_proj_summary = nn.Linear( |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
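| # With mix_coeff in (0, 1], summary-token Q/K/V are blended in forward() as |
| # mix_coeff * summary_proj(x) + (1 - mix_coeff) * shared_proj(x); at mix_coeff == 0 the |
| # separate projections are never created or used. |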
|
|
| def _get_sliding_chunk_num(self): |
| return self._sliding_chunk_num |
|
|
| def get_query_key_value_tensors(self, hidden_states): |
| input_shape = hidden_states.shape[:-1] |
| hidden_shape = (*input_shape, -1, self.head_dim) |
| query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) |
| key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) |
| value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
|
|
| return query_states, key_states, value_states |
|
|
| def get_query_key_value_tensors_summary(self, hidden_states): |
| input_shape = hidden_states.shape[:-1] |
| hidden_shape = (*input_shape, -1, self.head_dim) |
| query_states = self.q_norm(self.q_proj_summary(hidden_states).view(hidden_shape)).transpose(1, 2) |
| key_states = self.k_norm(self.k_proj_summary(hidden_states).view(hidden_shape)).transpose(1, 2) |
| value_states = self.v_proj_summary(hidden_states).view(hidden_shape).transpose(1, 2) |
|
|
| return query_states, key_states, value_states |
|
|
| @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| attention_mask: Optional[torch.Tensor] = None, |
| past_key_values: Optional[Cache] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| summary_ctx: Optional[SummaryBatchContext] = None, |
| **kwargs, |
| ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
| input_shape = hidden_states.shape[:-1] |
| if hidden_states.size(0) != 1: |
| raise ValueError("Summary sliding attention only supports batch size=1.") |
|
|
| |
| if self.config.summary_independent_parameters: |
| if summary_ctx is None: |
| raise ValueError("summary_ctx is required when using summary_independent_parameters.") |
| summary_mask = summary_ctx.summary_mask |
| summary_pos = summary_mask[0] |
| assert (summary_mask == summary_mask[0:1]).all() |
|
|
| if self.config.mix_coeff == 0: |
| |
| query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states) |
| else: |
| query, key, value = self.get_query_key_value_tensors(hidden_states) |
|
|
| query_states = query.clone() |
| key_states = key.clone() |
| value_states = value.clone() |
|
|
| hs_summary = hidden_states[:, summary_pos, :] |
| if hs_summary.size(1) > 0: |
| base_q_summary = query[:, :, summary_pos, :] |
| base_k_summary = key[:, :, summary_pos, :] |
| base_v_summary = value[:, :, summary_pos, :] |
|
|
| q_s, k_s, v_s = self.get_query_key_value_tensors_summary(hs_summary) |
|
|
| q_s = self.config.mix_coeff * q_s + (1.0 - self.config.mix_coeff) * base_q_summary |
| k_s = self.config.mix_coeff * k_s + (1.0 - self.config.mix_coeff) * base_k_summary |
| v_s = self.config.mix_coeff * v_s + (1.0 - self.config.mix_coeff) * base_v_summary |
|
|
| query_states[:, :, summary_pos, :] = q_s |
| key_states[:, :, summary_pos, :] = k_s |
| value_states[:, :, summary_pos, :] = v_s |
| else: |
| query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states) |
|
|
| cos, sin = position_embeddings |
| query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
|
|
| query_len = query_states.shape[2] |
| is_prefill = past_key_values is None or not past_key_values._reorganized |
|
|
| if is_prefill: |
| |
| if past_key_values is not None: |
| cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
| if summary_ctx is not None: |
| cache_kwargs["summary_mask"] = summary_ctx.summary_mask |
| key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
| with torch.cuda.device(query_states.device): |
| attn_output, attn_weights = summary_attn_func( |
| query_states.transpose(1,2).contiguous(), |
| key_states.transpose(1,2).contiguous(), |
| value_states.transpose(1,2).contiguous(), |
| self.summary_chunk_size, |
| self.summary_token_num, |
| self._get_sliding_chunk_num(), |
| summary_pos=summary_ctx.summary_mask.squeeze() |
| ) |
| elif query_len == 1: |
| # Regular decode step (no chunk boundary): append the text token and attend over the full cached KV |
| past_key_values.update_text(key_states, value_states, self.layer_idx) |
| k_full, v_full = past_key_values.get_attention_kv(self.layer_idx) |
| attn_output, attn_weights = _sdpa_attention_forward( |
| self, |
| query_states, |
| k_full, |
| v_full, |
| None, |
| dropout=0.0 if not self.training else self.attention_dropout, |
| scaling=self.scaling, |
| sliding_window=self.sliding_window, |
| **kwargs, |
| ) |
| else: |
| |
| # Chunk-boundary decode step: index 0 is the text token, the remaining positions are summary tokens |
| q_text = query_states[:, :, :1, :] |
| q_summary = query_states[:, :, 1:, :] |
| k_text = key_states[:, :, :1, :] |
| v_text = value_states[:, :, :1, :] |
| k_summary = key_states[:, :, 1:, :] |
| v_summary = value_states[:, :, 1:, :] |
|
|
| # The text token writes into the cache and attends over the full cached KV |
| past_key_values.update_text(k_text, v_text, self.layer_idx) |
| k_full, v_full = past_key_values.get_attention_kv(self.layer_idx) |
| text_out, _ = _sdpa_attention_forward( |
| self, |
| q_text, |
| k_full, |
| v_full, |
| None, |
| dropout=0.0 if not self.training else self.attention_dropout, |
| scaling=self.scaling, |
| sliding_window=self.sliding_window, |
| **kwargs, |
| ) |
|
|
| |
| |
| # Summary tokens attend only to the current chunk's text KV plus themselves |
| k_chunk, v_chunk = past_key_values.get_current_chunk_kv(self.layer_idx) |
| k_chunk_with_self = torch.cat([k_chunk, k_summary], dim=2) |
| v_chunk_with_self = torch.cat([v_chunk, v_summary], dim=2) |
| summary_out, _ = _sdpa_attention_forward( |
| self, |
| q_summary, |
| k_chunk_with_self, |
| v_chunk_with_self, |
| None, |
| dropout=0.0 if not self.training else self.attention_dropout, |
| scaling=self.scaling, |
| sliding_window=self.sliding_window, |
| **kwargs, |
| ) |
|
|
| # Persist the summary KV; this also flushes the chunk mirror into the ring window |
| past_key_values.update_summary(k_summary, v_summary, self.layer_idx) |
|
|
| attn_output = torch.cat([text_out, summary_out], dim=2) |
| attn_weights = None |
|
|
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
| attn_output = self.o_proj(attn_output) |
| return attn_output, attn_weights |
|
|
|
|
| class Qwen3DecoderLayer(GradientCheckpointingLayer): |
| def __init__(self, config: Qwen3Config, layer_idx: int): |
| super().__init__() |
| self.config = config |
| self.hidden_size = config.hidden_size |
|
|
| |
| if getattr(config, "use_summary_attention", False) is True and config.summary_layer_freq[layer_idx] == 1: |
| self.self_attn = Qwen3SummaryAttention(config=config, layer_idx=layer_idx) |
| elif getattr(config, "use_summary_attention", False) is False and config.summary_layer_freq[layer_idx] == 0: |
| self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) |
| else: |
| raise ValueError( |
| f'Inconsistent config: use_summary_attention={config.use_summary_attention} but ' |
| f'summary_layer_freq[{layer_idx}]={config.summary_layer_freq[layer_idx]}; the two must agree for every layer.' |
| ) |
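| # e.g. with use_summary_attention=True, Qwen3Model.__init__ defaults summary_layer_freq to |
| # [1] * num_hidden_layers, so every layer builds a Qwen3SummaryAttention; mixed 0/1 patterns |
| # are rejected by the check above. |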
|
|
| self.mlp = Qwen3MLP(config) |
| self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| if getattr(config, "summary_independent_attention_layernorm", False): |
| self.input_layernorm_summary = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.attention_type = config.layer_types[layer_idx] |
|
|
| @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, |
| summary_ctx: Optional[SummaryBatchContext] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> torch.Tensor: |
| residual = hidden_states |
| if getattr(self.config, "summary_independent_attention_layernorm", False): |
| summary_mask = summary_ctx.summary_mask |
| assert (summary_mask == summary_mask[0:1]).all(), \ |
| "summary_mask must be identical across batch" |
| hidden_states = self.input_layernorm(hidden_states) |
| if summary_mask.any(): |
| hidden_summary = residual[:, summary_mask[0].to(residual.device), :] |
| hidden_summary = self.input_layernorm_summary(hidden_summary) |
| hidden_states[:, summary_mask[0], :] = hidden_summary |
| else: |
| hidden_states = self.input_layernorm(hidden_states) |
| |
| |
| attn_kwargs = { |
| "hidden_states": hidden_states, |
| "attention_mask": attention_mask, |
| "position_ids": position_ids, |
| "past_key_values": past_key_values, |
| "use_cache": use_cache, |
| "cache_position": cache_position, |
| "position_embeddings": position_embeddings, |
| **kwargs, |
| } |
| if isinstance(self.self_attn, Qwen3SummaryAttention): |
| attn_kwargs["summary_ctx"] = summary_ctx |
| |
| hidden_states, _ = self.self_attn(**attn_kwargs) |
| hidden_states = residual + hidden_states |
|
|
| # Fully connected |
| residual = hidden_states |
| hidden_states = self.post_attention_layernorm(hidden_states) |
| hidden_states = self.mlp(hidden_states) |
| hidden_states = residual + hidden_states |
| return hidden_states |
|
|
|
|
| @auto_docstring |
| class Qwen3PreTrainedModel(PreTrainedModel): |
| config: Qwen3Config |
| base_model_prefix = "model" |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Qwen3DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
|
|
| _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Qwen3DecoderLayer, |
| "attentions": Qwen3Attention, |
| } |
|
|
|
|
| class Qwen3RotaryEmbedding(nn.Module): |
| inv_freq: torch.Tensor |
|
|
| def __init__(self, config: Qwen3Config, device=None): |
| super().__init__() |
| self.max_seq_len_cached = config.max_position_embeddings |
| self.original_max_seq_len = config.max_position_embeddings |
|
|
| self.config = config |
|
|
| self.rope_type = self.config.rope_parameters["rope_type"] |
| rope_init_fn: Callable = self.compute_default_rope_parameters |
| if self.rope_type != "default": |
| rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] |
| inv_freq, self.attention_scaling = rope_init_fn(self.config, device) |
|
|
| self.register_buffer("inv_freq", inv_freq, persistent=False) |
| self.original_inv_freq = inv_freq |
|
|
| @staticmethod |
| def compute_default_rope_parameters( |
| config: Optional[Qwen3Config] = None, |
| device: Optional["torch.device"] = None, |
| seq_len: Optional[int] = None, |
| ) -> tuple["torch.Tensor", float]: |
| """ |
| Computes the inverse frequencies according to the original RoPE implementation |
| Args: |
| config ([`~transformers.PreTrainedConfig`]): |
| The model configuration. |
| device (`torch.device`): |
| The device to use for initialization of the inverse frequencies. |
| seq_len (`int`, *optional*): |
| The current sequence length. Unused for this type of RoPE. |
| Returns: |
| Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the |
| post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). |
| """ |
| base = config.rope_parameters["rope_theta"] |
| dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads |
|
|
| attention_factor = 1.0 |
|
|
| |
| inv_freq = 1.0 / ( |
| base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) |
| ) |
| return inv_freq, attention_factor |
|
|
| @torch.no_grad() |
| @dynamic_rope_update |
| def forward(self, x, position_ids): |
| inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) |
| position_ids_expanded = position_ids[:, None, :].float() |
|
|
| device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
| with torch.autocast(device_type=device_type, enabled=False): |
| freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) |
| emb = torch.cat((freqs, freqs), dim=-1) |
| cos = emb.cos() * self.attention_scaling |
| sin = emb.sin() * self.attention_scaling |
|
|
| return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
|
|
| @auto_docstring |
| class Qwen3Model(Qwen3PreTrainedModel): |
| def __init__(self, config: Qwen3Config): |
| super().__init__(config) |
| self.padding_idx = config.pad_token_id |
| self.vocab_size = config.vocab_size |
| if not getattr(config, "summary_layer_freq", False): |
| if config.use_summary_attention: |
| config.summary_layer_freq = [1]*config.num_hidden_layers |
| else: |
| config.summary_layer_freq = [0]*config.num_hidden_layers |
| warnings.warn(f'config.summary_layer_freq was not set; defaulting to {config.summary_layer_freq}.') |
| else: |
| config.summary_layer_freq = _parse_config_pattern(config.summary_layer_freq) |
|
|
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
| self.layers = nn.ModuleList( |
| [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] |
| ) |
| self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.rotary_emb = Qwen3RotaryEmbedding(config=config) |
| self.gradient_checkpointing = False |
| self.has_sliding_layers = "sliding_attention" in self.config.layer_types |
|
|
| |
| _sv = _parse_config_pattern(getattr(config, "summary_sliding_chunk_num", 0) or 0) |
| if isinstance(_sv, list): |
| self._sliding_chunk_nums = [int(v) for v in _sv] |
| else: |
| self._sliding_chunk_nums = [int(_sv)] * config.num_hidden_layers |
|
|
| |
| self.post_init() |
|
|
| def _expand_input_with_summary_tokens(self, input_ids): |
| """Expand input_ids with summary tokens for prefill phase (vectorized). |
| |
| Returns: |
| Tuple of (expanded_input_ids, position_ids, text_only_mask) |
| """ |
| summary_chunk = self.config.summary_chunk_size |
| summary_num = self.config.summary_token_num |
| summary_begin = self.config.summary_token_begin |
|
|
| if summary_chunk == 0 or summary_num == 0: |
| return input_ids, None, None |
|
|
| batch_size, seq_len = input_ids.shape |
| device = input_ids.device |
| dtype = input_ids.dtype |
| block = summary_chunk + summary_num |
|
|
| # Number of complete chunks and length of the trailing partial chunk |
| n_full_chunks = seq_len // summary_chunk |
| remainder = seq_len % summary_chunk |
| has_remainder = remainder > 0 |
|
|
| |
| expanded_len = n_full_chunks * block + (remainder if has_remainder else 0) |
|
|
| |
| expanded_ids = torch.empty((batch_size, expanded_len), dtype=dtype, device=device) |
| text_only_mask = torch.zeros((batch_size, expanded_len), dtype=torch.bool, device=device) |
|
|
| |
| # Vectorized scatter: source/destination offsets for text and summary token ids |
| if n_full_chunks > 0: |
| chunk_indices = torch.arange(n_full_chunks, device=device) |
| |
| text_src_offsets = (chunk_indices * summary_chunk).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0) |
| |
| text_dst_offsets = (chunk_indices * block).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0) |
| |
| summary_dst_offsets = (chunk_indices * block + summary_chunk).unsqueeze(1) + torch.arange(summary_num, device=device).unsqueeze(0) |
|
|
| text_src_flat = text_src_offsets.reshape(-1) |
| text_dst_flat = text_dst_offsets.reshape(-1) |
| summary_dst_flat = summary_dst_offsets.reshape(-1) |
|
|
| |
| expanded_ids[:, text_dst_flat] = input_ids[:, text_src_flat] |
| text_only_mask[:, text_dst_flat] = True |
|
|
| |
| summary_ids_val = torch.arange(summary_num, device=device, dtype=dtype) + summary_begin |
| expanded_ids[:, summary_dst_flat] = summary_ids_val.repeat(n_full_chunks).unsqueeze(0).expand(batch_size, -1) |
|
|
| |
| if has_remainder: |
| rem_start_src = n_full_chunks * summary_chunk |
| rem_start_dst = n_full_chunks * block |
| rem_offsets = torch.arange(remainder, device=device) |
| expanded_ids[:, rem_start_dst + rem_offsets] = input_ids[:, rem_start_src + rem_offsets] |
| text_only_mask[:, rem_start_dst + rem_offsets] = True |
|
|
| # Build position ids according to the configured position-id schemes |
| position_ids = torch.empty((batch_size, expanded_len), dtype=torch.long, device=device) |
|
|
| if n_full_chunks > 0: |
| |
| if self.config.summary_chunk_position_ids_type == 'origin': |
| text_pos = text_src_flat.unsqueeze(0).expand(batch_size, -1) |
| elif self.config.summary_chunk_position_ids_type == 'inner_chunk': |
| inner_pos = torch.arange(summary_chunk, device=device).repeat(n_full_chunks) |
| text_pos = inner_pos.unsqueeze(0).expand(batch_size, -1) |
| else: |
| raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}') |
| position_ids[:, text_dst_flat] = text_pos |
|
|
| |
| if self.config.summary_token_position_ids_type == 'zeros': |
| position_ids[:, summary_dst_flat] = 0 |
| elif self.config.summary_token_position_ids_type in ('last_chunk_slice_left', 'last_chunk_slice_right'): |
| |
| if self.config.summary_token_position_ids_type == 'last_chunk_slice_left': |
| idx = torch.arange(0, summary_num, device=device, dtype=torch.long) |
| else: |
| idx = torch.arange(1, summary_num + 1, device=device, dtype=torch.long) |
| |
| prev_ends = (chunk_indices * summary_chunk).unsqueeze(1) |
| slice_ends = prev_ends + (idx.unsqueeze(0) * summary_chunk) // summary_num - 1 |
| slice_ends = slice_ends.clamp(min=0) |
| |
| slice_ends = torch.max(slice_ends, prev_ends) |
| position_ids[:, summary_dst_flat] = slice_ends.reshape(-1).unsqueeze(0).expand(batch_size, -1) |
| else: |
| raise ValueError(f'Unknown summary_token_position_ids_type: {self.config.summary_token_position_ids_type}') |
|
|
| |
| if has_remainder: |
| if self.config.summary_chunk_position_ids_type == 'origin': |
| rem_pos = (rem_start_src + rem_offsets).unsqueeze(0).expand(batch_size, -1) |
| elif self.config.summary_chunk_position_ids_type == 'inner_chunk': |
| rem_pos = rem_offsets.unsqueeze(0).expand(batch_size, -1) |
| else: |
| raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}') |
| position_ids[:, rem_start_dst + rem_offsets] = rem_pos |
|
|
| return expanded_ids, position_ids, text_only_mask |
| |
| def _build_summary_context(self, input_ids, position_ids, is_prefill, use_cache): |
| """Build summary context for attention layers.""" |
| summary_chunk = self.config.summary_chunk_size |
| summary_num = self.config.summary_token_num |
| summary_begin = self.config.summary_token_begin |
|
|
| if summary_chunk > 0 and summary_num > 0: |
| return build_summary_sliding_context( |
| input_ids=input_ids, |
| position_ids=position_ids, |
| summary_token_num=summary_num, |
| summary_token_begin=summary_begin, |
| ) |
| return None |
| |
| def _filter_summary_tokens(self, hidden_states, text_only_mask, use_summary, is_decode): |
| """Filter out summary tokens from output hidden states.""" |
| if text_only_mask is not None: |
| |
| batch_size, _, hidden_size = hidden_states.shape |
| text_length = text_only_mask[0].sum().item() |
| return hidden_states[text_only_mask.to(hidden_states.device)].reshape(batch_size, text_length, hidden_size) |
| elif use_summary and is_decode and hidden_states.size(1) > 1: |
| |
| return hidden_states[:, :1, :] |
| return hidden_states |
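| # For instance (illustrative, matching the example in _expand_input_with_summary_tokens: |
| # C = 4, one summary token per chunk, 10 text tokens): prefill produces 12 hidden states and |
| # this filter returns the 10 text positions, so downstream logits line up with input_ids. |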
|
|
| @check_model_inputs() |
| @auto_docstring |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| use_cache: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| summary_ctx: Optional[SummaryBatchContext] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> BaseModelOutputWithPast: |
| if (input_ids is None) ^ (inputs_embeds is not None): |
| raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
| use_summary = getattr(self.config, "use_summary_attention", False) |
| is_prefill = past_key_values is None or past_key_values.get_seq_length() == 0 |
| |
| |
| text_only_mask = None |
| if use_summary and input_ids is not None and inputs_embeds is None and is_prefill: |
| input_ids, position_ids, text_only_mask = self._expand_input_with_summary_tokens(input_ids) |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.embed_tokens(input_ids) |
|
|
| |
| if use_cache and past_key_values is None: |
| if use_summary: |
| past_key_values = Qwen3RingBufferCache( |
| config=self.config, sliding_chunk_nums=self._sliding_chunk_nums) |
| else: |
| past_key_values = DynamicCache(config=self.config) |
|
|
| if cache_position is None: |
| past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| cache_position = torch.arange( |
| past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
| ) |
|
|
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
|
|
| |
| if use_summary and summary_ctx is None and input_ids is not None: |
| summary_ctx = self._build_summary_context(input_ids, position_ids, is_prefill, use_cache) |
|
|
| causal_mask_mapping = attention_mask |
| if not isinstance(causal_mask_mapping, (dict, list)): |
| if summary_ctx and summary_ctx.enabled: |
| # The summary attention kernel (prefill) and the ring-buffer cache (decode) handle |
| # masking themselves, so no dense causal mask is materialized here. |
| causal_mask_mapping = None |
| else: |
| |
| mask_kwargs = { |
| "config": self.config, |
| "input_embeds": inputs_embeds, |
| "attention_mask": attention_mask, |
| "cache_position": cache_position, |
| "past_key_values": past_key_values, |
| "position_ids": position_ids, |
| } |
| |
| causal_mask_mapping = { |
| "full_attention": create_causal_mask(**mask_kwargs), |
| } |
| |
| if self.has_sliding_layers: |
| causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) |
|
|
| hidden_states = inputs_embeds |
|
|
| |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
|
| for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): |
| if causal_mask_mapping is None: |
| layer_mask = None |
| elif isinstance(causal_mask_mapping, list): |
| layer_mask = causal_mask_mapping[layer_idx] |
| else: |
| layer_mask = causal_mask_mapping[decoder_layer.attention_type] |
| hidden_states = decoder_layer( |
| hidden_states, |
| attention_mask=layer_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| summary_ctx=summary_ctx, |
| **kwargs, |
| ) |
|
|
| hidden_states = self.norm(hidden_states) |
|
|
| |
| if use_cache and use_summary and past_key_values is not None and is_prefill: |
| if hasattr(past_key_values, 'reorganize_after_prefill') and summary_ctx is not None: |
| past_key_values.reorganize_after_prefill(summary_ctx.summary_mask) |
|
|
| |
| if use_cache and use_summary and past_key_values is not None and not is_prefill: |
| if hasattr(past_key_values, 'reset_chunk_counter'): |
| past_key_values.reset_chunk_counter() |
| |
| |
| hidden_states = self._filter_summary_tokens(hidden_states, text_only_mask, use_summary, |
| past_key_values is not None and past_key_values.get_seq_length() > 0) |
|
|
| return BaseModelOutputWithPast( |
| last_hidden_state=hidden_states, |
| past_key_values=past_key_values if use_cache else None, |
| ) |
|
|
|
|
| @auto_docstring |
| class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): |
| _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} |
| _tp_plan = {"lm_head": "colwise_rep"} |
| _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| self.model = Qwen3Model(config) |
| self.vocab_size = config.vocab_size |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
| |
| self.post_init() |
|
|
| @can_return_tuple |
| @auto_docstring |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| summary_ctx: Optional[SummaryBatchContext] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> CausalLMOutputWithPast: |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
| config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
| (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
| |
| Example: |
| |
| ```python |
| >>> from transformers import AutoTokenizer, Qwen3ForCausalLM |
| |
| >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") |
| >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") |
| |
| >>> prompt = "Hey, are you conscious? Can you talk to me?" |
| >>> inputs = tokenizer(prompt, return_tensors="pt") |
| |
| >>> # Generate |
| >>> generate_ids = model.generate(inputs.input_ids, max_length=30) |
| >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
| "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." |
| ```""" |
| outputs: BaseModelOutputWithPast = self.model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| summary_ctx=summary_ctx, |
| **kwargs, |
| ) |
|
|
| hidden_states = outputs.last_hidden_state |
| |
| if isinstance(logits_to_keep, int) and logits_to_keep == 0 and labels is None: |
| # Pure generation: only the last position's logits are needed |
| logits = self.lm_head(hidden_states[:, -1:, :]) |
| else: |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
|
|
| truncate_n = getattr(self.config, "truncate_predict_nums", 151936) |
| if truncate_n > 0: |
| logits = logits[..., :truncate_n] |
|
|
| loss = None |
| if labels is not None: |
| loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.shape[-1], **kwargs) |
|
|
| return CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
| def _build_summary_attention_mask_for_generation( |
| self, |
| *, |
| input_ids: torch.LongTensor, |
| past_key_values: Optional[Cache], |
| attention_mask: Optional[torch.Tensor], |
| ) -> Optional[torch.Tensor]: |
| """Ring buffer cache handles attention internally — no mask needed for decode.""" |
| if isinstance(past_key_values, Qwen3RingBufferCache): |
| return None |
| return attention_mask |
|
|
| def prepare_inputs_for_generation( |
| self, |
| input_ids: torch.LongTensor, |
| past_key_values: Optional[Cache] = None, |
| attention_mask: Optional[torch.LongTensor] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| **kwargs, |
| ): |
| use_summary = getattr(self.config, "use_summary_attention", False) |
|
|
| |
| if not use_summary: |
| return super().prepare_inputs_for_generation( |
| input_ids=input_ids, |
| past_key_values=past_key_values, |
| attention_mask=attention_mask, |
| inputs_embeds=inputs_embeds, |
| cache_position=cache_position, |
| position_ids=position_ids, |
| **kwargs, |
| ) |
|
|
| |
| summary_chunk_size = getattr(self.config, "summary_chunk_size", 0) |
| summary_token_num = getattr(self.config, "summary_token_num", 0) |
| summary_token_begin = getattr(self.config, "summary_token_begin", 0) |
| |
| # Prefill: pass inputs through unchanged; the model inserts summary tokens itself |
| if past_key_values is None or past_key_values.get_seq_length() == 0: |
| if cache_position is None: |
| cache_position = torch.arange(0, input_ids.shape[1], device=input_ids.device) |
| |
| return { |
| "input_ids": input_ids, |
| "attention_mask": attention_mask, |
| "position_ids": position_ids, |
| "past_key_values": past_key_values, |
| "cache_position": cache_position, |
| "use_cache": kwargs.get("use_cache"), |
| } |
| |
| |
| |
| cur_chunk = past_key_values.get_cur_chunk_size() if hasattr(past_key_values, "get_cur_chunk_size") else 0 |
| true_token_num = past_key_values.get_true_token_num() |
| |
| |
| if input_ids.shape[1] > 1: |
| # generate() passes the full sequence; keep only the tokens not yet in the cache |
| new_token_count = input_ids.shape[1] - true_token_num |
| assert new_token_count > 0, f'new_token_count={new_token_count} should be greater than 0' |
| input_ids = input_ids[:, -new_token_count:] |
| device = input_ids.device |
| |
| |
| if cur_chunk == summary_chunk_size - 1: |
| # The incoming token completes the current chunk: append the summary placeholder ids |
| batch_size = input_ids.shape[0] |
| summary_ids = ( |
| torch.arange(summary_token_num, device=device, dtype=input_ids.dtype) |
| + summary_token_begin |
| ).unsqueeze(0).repeat(batch_size, 1) |
| |
| |
| input_ids = torch.cat([input_ids, summary_ids], dim=1) |
| |
| # Position id for the new text token |
| if self.config.summary_chunk_position_ids_type == 'origin': |
| text_pos = torch.full((batch_size, 1), past_key_values.get_true_token_num(), device=device, dtype=torch.long) |
| elif self.config.summary_chunk_position_ids_type == 'inner_chunk': |
| text_pos = torch.full((batch_size, 1), cur_chunk, device=device, dtype=torch.long) |
| else: |
| raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}') |
| # Position ids for the appended summary tokens |
| if self.config.summary_token_position_ids_type == 'zeros': |
| summary_pos = torch.zeros((batch_size, summary_token_num), device=device, dtype=torch.long) |
| elif self.config.summary_token_position_ids_type == 'last_chunk_slice_left': |
| |
| prev_text_end = past_key_values.get_true_token_num()+1-summary_chunk_size |
| cur_text_end = past_key_values.get_true_token_num()+1 |
| chunk_len = cur_text_end - prev_text_end |
|
|
| idx = torch.arange(0, summary_token_num, device=device, dtype=torch.long,) |
|
|
| |
| slice_ends = prev_text_end + (idx * chunk_len) // summary_token_num - 1 |
| slice_ends = slice_ends.clamp(min=prev_text_end) |
|
|
| summary_pos = slice_ends.to(dtype=torch.long, device=device).unsqueeze(0) |
| elif self.config.summary_token_position_ids_type == 'last_chunk_slice_right': |
| |
| prev_text_end = past_key_values.get_true_token_num()+1-summary_chunk_size |
| cur_text_end = past_key_values.get_true_token_num()+1 |
| chunk_len = cur_text_end - prev_text_end |
|
|
| idx = torch.arange(1, summary_token_num + 1, device=device, dtype=torch.long,) |
|
|
| |
| slice_ends = prev_text_end + (idx * chunk_len) // summary_token_num - 1 |
| slice_ends = slice_ends.clamp(min=prev_text_end) |
|
|
| summary_pos = slice_ends.to(dtype=torch.long, device=device).unsqueeze(0) |
|
|
| else: |
| raise ValueError(f'Unknown summary_token_position_ids_type: {self.config.summary_token_position_ids_type}') |
|
|
| position_ids = torch.cat([text_pos, summary_pos], dim=1) |
| else: |
| |
| if position_ids is None: |
| batch_size = input_ids.shape[0] |
| if self.config.summary_chunk_position_ids_type == 'origin': |
| position_ids = torch.full((batch_size, input_ids.shape[1]), past_key_values.get_true_token_num(), device=input_ids.device, dtype=torch.long) |
| elif self.config.summary_chunk_position_ids_type == 'inner_chunk': |
| position_ids = torch.full((batch_size, input_ids.shape[1]), cur_chunk, device=input_ids.device, dtype=torch.long) |
| else: |
| raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}') |
| return { |
| "input_ids": input_ids, |
| "attention_mask": self._build_summary_attention_mask_for_generation( |
| input_ids=input_ids, |
| past_key_values=past_key_values, |
| attention_mask=attention_mask, |
| ), |
| "position_ids": position_ids, |
| "past_key_values": past_key_values, |
| "cache_position": cache_position, |
| "use_cache": kwargs.get("use_cache"), |
| } |
|
|
|
|
| class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel): |
| pass |
|
|
|
|
| class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel): |
| pass |
|
|
|
|
| class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel): |
| base_model_prefix = "transformer" |
|
|
|
|
| __all__ = [ |
| "Qwen3ForCausalLM", |
| "Qwen3ForQuestionAnswering", |
| "Qwen3PreTrainedModel", |
| "Qwen3Model", |
| "Qwen3ForSequenceClassification", |
| "Qwen3ForTokenClassification", |
| "Qwen3RingBufferCache", |
| "Qwen3SummaryAttention", |
| "SummaryBatchContext", |
| "build_summary_context", |
| "build_summary_sliding_context", |
| ] |
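| # Minimal end-to-end sketch (illustrative; every config value below is an assumption, and |
| # generation with summary attention currently requires batch size 1 — see |
| # Qwen3SummaryAttention.forward): |
| # |
| #     from transformers import AutoTokenizer |
| #     config = Qwen3Config.from_pretrained("Qwen/Qwen3-8B") |
| #     config.use_summary_attention = True |
| #     config.summary_chunk_size = 128 |
| #     config.summary_token_num = 1 |
| #     config.summary_token_begin = config.vocab_size            # summary ids sit past the text vocab |
| #     config.vocab_size += config.summary_token_num             # extend embeddings for the summary ids |
| #     config.summary_sliding_chunk_num = "([4096]*1+[128]*3)*9" |
| #     config.summary_layer_freq = [1] * config.num_hidden_layers |
| #     config.summary_chunk_position_ids_type = "origin" |
| #     config.summary_token_position_ids_type = "last_chunk_slice_right" |
| #     config.summary_independent_parameters = False |
| #     config.mix_coeff = 0 |
| #     model = Qwen3ForCausalLM(config) |
| #     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") |
| #     inputs = tokenizer("Hello", return_tensors="pt") |
| #     out = model.generate(inputs.input_ids, max_new_tokens=32) |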
|
|