# KSA-4B-base / modeling_qwen3.py
# Uploaded to the Hugging Face Hub by OpenOneRec ("Upload folder using huggingface_hub", verified commit 6e72c5c).
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_qwen3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
import torch
from torch import nn
import torch.nn.functional as F
#from flash_attn import flash_attn_func
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
GenericForQuestionAnswering,
GenericForSequenceClassification,
GenericForTokenClassification,
GradientCheckpointingLayer,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
# InfinityLM imports for summary attention
from .summary_context import SummaryBatchContext, build_summary_context, build_summary_sliding_context
from summary_attn import summary_attn_func
def _parse_config_pattern(val):
"""Parse a config value that may be an int, list, or Python pattern string like '([4096]*1+[128]*3)*9'."""
if isinstance(val, list):
return val
if isinstance(val, str):
return eval(val)
return val
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension (equivalent to T5LayerNorm).

    The statistics are computed in float32. The result is cast back to the input dtype
    before being scaled by the learned per-channel ``weight`` (initialized to ones).
    """

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normalized = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Qwen3RingBufferCache:
    """
    Ring buffer KV cache with summary support.

    Each layer is classified by its entry in ``sliding_chunk_nums``. Layers whose window
    (``sliding_chunk_num * summary_chunk_size``) is larger than the smallest window in the
    model are "large window" layers; all other layers are "small window" layers.

    Lifecycle: during prefill every layer is a plain append buffer (``update``).
    ``reorganize_after_prefill`` then converts each layer to its decode layout:

    - Large window layers (is_large_window=True): an append-only buffer storing only text KV.
      Summary KV is dropped, since text tokens attend to all text KV directly.
    - Small window layers (is_large_window=False) use three buffers:
      1. key_cache / value_cache: [ring(ws) | old_summaries(growing) | chunk_mirror(<=C)]
         This is the attention input. Once the ring is full it is a single contiguous slice.
      2. new_summary_buf: a ring buffer of size scn holding summaries whose text is still
         in the window (not needed for text-token attention).
      3. chunk_buf: size C, filled with the prefill remainder. Decode-time text writes go
         only to the chunk_mirror region of key_cache.

    Here C = summary_chunk_size, scn = sliding_chunk_num and ws = scn * C.
    RoPE position information is baked into the KV, so physical order doesn't matter.
    """

    is_compileable = False
    # Minimum number of old-summary slots reserved in a small-window layer's key_cache.
    _SUMMARY_INIT_CAP = 512
    # Extra slots allocated whenever an append-only buffer is (re)allocated.
    _APPEND_HEADROOM = 1024

    def __init__(self, config: "Qwen3Config", sliding_chunk_nums: list[int]):
        """
        Args:
            config: model config. Reads ``num_hidden_layers``, plus ``summary_chunk_size``
                and ``summary_token_num`` (both default to 0 when absent).
            sliding_chunk_nums: per-layer window size in chunks, indexed by layer.
        """
        super().__init__()
        self.summary_chunk_size = getattr(config, "summary_chunk_size", 0)
        self.summary_token_num = getattr(config, "summary_token_num", 0)
        self.num_hidden_layers = config.num_hidden_layers
        self.sliding_chunk_nums = sliding_chunk_nums
        n_layers = config.num_hidden_layers
        # Any layer whose window exceeds the smallest window is treated as "large window".
        large_window_threshold = min(sliding_chunk_nums) * self.summary_chunk_size
        self.is_large_window = [sv * self.summary_chunk_size > large_window_threshold for sv in sliding_chunk_nums]
        self.window_sizes = [sv * self.summary_chunk_size for sv in sliding_chunk_nums]
        self.key_cache = [None] * n_layers
        self.value_cache = [None] * n_layers
        # Append-only buffers (large window layers, and every layer before reorganization).
        self._text_len = [0] * n_layers  # used entries
        self._capacity = [0] * n_layers  # allocated length of key_cache along the sequence dim
        # Small window: ring buffer + old summaries stored in key_cache/value_cache.
        self._window_write_ptr = [0] * n_layers
        self._n_valid_window = [0] * n_layers
        self._old_summary_len = [0] * n_layers  # old summaries in key_cache
        self._old_summary_cap = [0] * n_layers
        # New summary ring buffer (small window only): summaries whose text is still in the window.
        self._new_sum_key_buf = [None] * n_layers
        self._new_sum_value_buf = [None] * n_layers
        self._new_sum_len = [0] * n_layers  # how many filled (<= scn)
        self._new_sum_write_ptr = [0] * n_layers  # ring write pointer
        # Current chunk buffer (small window only): holds partial-chunk text KV from prefill.
        self._chunk_key_buf = [None] * n_layers
        self._chunk_value_buf = [None] * n_layers
        self._chunk_buf_len = [0] * n_layers
        # Common counters.
        self.cur_chunk_sizes = [0] * n_layers
        self.true_tokens = [0] * n_layers
        self._total_chunks = [0] * n_layers  # completed chunks count
        self._reorganized = False

    def __len__(self):
        return self.num_hidden_layers

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return how many KV entries are cached for ``layer_idx`` (0 if the index is out of range).

        The result is nonzero when the cache is populated, which is used to detect prefill
        vs decode. ``None`` is treated as layer 0.

        Before ``reorganize_after_prefill``, every layer keeps its prefill KV in the append
        buffer counted by ``_text_len``. Afterwards, small-window layers reset ``_text_len``
        to 0 and count their contents with the ring / chunk / summary counters. Summing
        both therefore gives the right answer in either phase.
        """
        if layer_idx is None:
            layer_idx = 0
        if layer_idx >= self.num_hidden_layers:
            return 0
        if self.is_large_window[layer_idx]:
            return self._text_len[layer_idx]
        return (self._text_len[layer_idx]
                + self._n_valid_window[layer_idx] + self._chunk_buf_len[layer_idx]
                + self._old_summary_len[layer_idx] + self._new_sum_len[layer_idx])

    def get_cur_chunk_size(self, layer_idx: Optional[int] = None) -> int:
        """Return ``cur_chunk_sizes`` for ``layer_idx`` (defaults to the last layer)."""
        if layer_idx is None:
            layer_idx = self.num_hidden_layers - 1
        return self.cur_chunk_sizes[layer_idx]

    def get_true_token_num(self, layer_idx: Optional[int] = None) -> int:
        """Return the number of text (non-summary) tokens seen by ``layer_idx`` (defaults to the last layer)."""
        if layer_idx is None:
            layer_idx = self.num_hidden_layers - 1
        return self.true_tokens[layer_idx]

    def _grow_append_buffer(self, layer_idx: int, used: int, needed: int) -> None:
        """Reallocate the append buffer of ``layer_idx`` if it cannot hold ``needed`` entries.

        The first ``used`` entries are copied over. The new capacity is at least double the
        old one, and at least ``needed + _APPEND_HEADROOM``.
        """
        if needed <= self._capacity[layer_idx]:
            return
        cap = max(needed + self._APPEND_HEADROOM, self._capacity[layer_idx] * 2)
        old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx]
        bsz, heads, _, head_dim = old_k.shape
        new_k = torch.empty(bsz, heads, cap, head_dim, dtype=old_k.dtype, device=old_k.device)
        new_v = torch.empty(bsz, heads, cap, head_dim, dtype=old_v.dtype, device=old_v.device)
        new_k[:, :, :used, :].copy_(old_k[:, :, :used, :])
        new_v[:, :, :used, :].copy_(old_v[:, :, :used, :])
        self.key_cache[layer_idx] = new_k
        self.value_cache[layer_idx] = new_v
        self._capacity[layer_idx] = cap

    # ── Prefill: standard append (before reorganize) ──
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append KV during prefill (before reorganize) and return all KV cached so far.

        Args:
            key_states / value_states: new entries, indexed ``[bsz, heads, seq, head_dim]``;
                they are appended along dim 2.
            layer_idx: layer whose buffer is appended to.
            cache_kwargs: optional. If it has a ``'summary_mask'`` entry, the tokens flagged
                in ``summary_mask[0]`` are not counted in ``true_tokens``.

        Returns:
            ``(keys, values)`` slices covering every entry appended so far for this layer.
        """
        add_len = key_states.shape[-2]
        cur_len = self._text_len[layer_idx]
        new_len = cur_len + add_len
        if self.key_cache[layer_idx] is None:
            cap = new_len + self._APPEND_HEADROOM
            bsz, heads, _, head_dim = key_states.shape
            self.key_cache[layer_idx] = torch.empty(
                bsz, heads, cap, head_dim, dtype=key_states.dtype, device=key_states.device)
            self.value_cache[layer_idx] = torch.empty(
                bsz, heads, cap, head_dim, dtype=value_states.dtype, device=value_states.device)
            self._capacity[layer_idx] = cap
        else:
            self._grow_append_buffer(layer_idx, cur_len, new_len)
        self.key_cache[layer_idx][:, :, cur_len:new_len, :].copy_(key_states)
        self.value_cache[layer_idx][:, :, cur_len:new_len, :].copy_(value_states)
        self._text_len[layer_idx] = new_len
        if self.summary_chunk_size > 0:
            if cache_kwargs and 'summary_mask' in cache_kwargs:
                text_count = add_len - cache_kwargs['summary_mask'][0].sum().item()
            else:
                text_count = add_len
            self.cur_chunk_sizes[layer_idx] += add_len
            self.true_tokens[layer_idx] += text_count
        return self.key_cache[layer_idx][:, :, :new_len, :], self.value_cache[layer_idx][:, :, :new_len, :]

    # ── Reorganize after prefill ──
    def reorganize_after_prefill(self, summary_mask: torch.Tensor):
        """Convert every layer from the prefill append layout to its decode layout.

        Only the first call has an effect. Only row 0 of ``summary_mask`` is used, as a
        boolean selector over the prefill entries (True = summary token).

        Large-window layers keep only their text KV. For small-window layers:
        - the last ``min(scn, complete_chunks)`` complete text chunks fill the ring;
        - summaries in excess of that count (the oldest) become "old" summaries placed
          right after the ring;
        - the remaining summaries go to the new-summary ring;
        - trailing partial-chunk text is copied to both the chunk mirror and chunk_buf.
        """
        if self._reorganized:
            return
        self._reorganized = True
        text_mask = ~summary_mask[0]
        for layer_idx in range(self.num_hidden_layers):
            prefill_len = self._text_len[layer_idx]
            prefill_k = self.key_cache[layer_idx][:, :, :prefill_len, :]
            prefill_v = self.value_cache[layer_idx][:, :, :prefill_len, :]
            bsz, heads, _, head_dim = prefill_k.shape
            device, dtype = prefill_k.device, prefill_k.dtype
            text_k = prefill_k[:, :, text_mask, :]
            text_v = prefill_v[:, :, text_mask, :]
            n_text = text_k.shape[2]
            if self.is_large_window[layer_idx]:
                # Large window: keep only text KV.
                cap = n_text + self._APPEND_HEADROOM
                new_k = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device)
                new_v = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device)
                new_k[:, :, :n_text, :].copy_(text_k)
                new_v[:, :, :n_text, :].copy_(text_v)
                self.key_cache[layer_idx] = new_k
                self.value_cache[layer_idx] = new_v
                self._text_len[layer_idx] = n_text
                self._capacity[layer_idx] = cap
                continue

            # Small window: split summaries into old (text evicted) and new (text in window).
            all_summary_k = prefill_k[:, :, summary_mask[0], :]
            all_summary_v = prefill_v[:, :, summary_mask[0], :]
            n_summary = all_summary_k.shape[2]
            C = self.summary_chunk_size
            ws = self.window_sizes[layer_idx]
            scn = self.sliding_chunk_nums[layer_idx]
            # Split text into complete chunks + partial remainder.
            n_complete_chunks = n_text // C
            n_partial = n_text % C
            n_complete_text = n_complete_chunks * C
            # Window: last scn complete chunks (or all of them if there are fewer).
            n_window_chunks = min(scn, n_complete_chunks)
            n_window_text = n_window_chunks * C
            window_start = n_complete_text - n_window_text
            # Old summaries: text evicted from the ring. New summaries: text still in the ring.
            n_old = max(0, n_summary - n_window_chunks)
            n_new = n_summary - n_old
            # key_cache: [ring(ws) | old_summaries | chunk_mirror(<=C)]
            old_s_cap = max(self._SUMMARY_INIT_CAP, (n_old + 1) * 2)
            total_cap = ws + old_s_cap + C
            new_k = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device)
            new_v = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device)
            if n_window_text > 0:
                new_k[:, :, :n_window_text, :].copy_(text_k[:, :, window_start:n_complete_text, :])
                new_v[:, :, :n_window_text, :].copy_(text_v[:, :, window_start:n_complete_text, :])
            self._n_valid_window[layer_idx] = n_window_text
            self._window_write_ptr[layer_idx] = n_window_text % ws
            # Old summaries go into key_cache right after the ring.
            if n_old > 0:
                new_k[:, :, ws:ws + n_old, :].copy_(all_summary_k[:, :, :n_old, :])
                new_v[:, :, ws:ws + n_old, :].copy_(all_summary_v[:, :, :n_old, :])
            self._old_summary_len[layer_idx] = n_old
            self._old_summary_cap[layer_idx] = old_s_cap
            # Mirror the partial chunk into key_cache right after the old summaries.
            if n_partial > 0:
                mirror_start = ws + n_old
                new_k[:, :, mirror_start:mirror_start + n_partial, :].copy_(
                    text_k[:, :, n_complete_text:, :])
                new_v[:, :, mirror_start:mirror_start + n_partial, :].copy_(
                    text_v[:, :, n_complete_text:, :])
            self.key_cache[layer_idx] = new_k
            self.value_cache[layer_idx] = new_v
            self._capacity[layer_idx] = total_cap
            self._text_len[layer_idx] = 0
            # New summary ring buffer.
            ns_buf_k = torch.empty(bsz, heads, scn, head_dim, dtype=dtype, device=device)
            ns_buf_v = torch.empty(bsz, heads, scn, head_dim, dtype=dtype, device=device)
            if n_new > 0:
                ns_buf_k[:, :, :n_new, :].copy_(all_summary_k[:, :, n_old:, :])
                ns_buf_v[:, :, :n_new, :].copy_(all_summary_v[:, :, n_old:, :])
            self._new_sum_key_buf[layer_idx] = ns_buf_k
            self._new_sum_value_buf[layer_idx] = ns_buf_v
            self._new_sum_len[layer_idx] = n_new
            self._new_sum_write_ptr[layer_idx] = n_new % scn
            # Chunk buffer for the partial remainder.
            chunk_buf_k = torch.empty(bsz, heads, C, head_dim, dtype=dtype, device=device)
            chunk_buf_v = torch.empty(bsz, heads, C, head_dim, dtype=dtype, device=device)
            if n_partial > 0:
                chunk_buf_k[:, :, :n_partial, :].copy_(text_k[:, :, n_complete_text:, :])
                chunk_buf_v[:, :, :n_partial, :].copy_(text_v[:, :, n_complete_text:, :])
            self._chunk_key_buf[layer_idx] = chunk_buf_k
            self._chunk_value_buf[layer_idx] = chunk_buf_v
            self._chunk_buf_len[layer_idx] = n_partial

        block = self.summary_chunk_size + self.summary_token_num
        for layer_idx in range(self.num_hidden_layers):
            self.cur_chunk_sizes[layer_idx] = self.cur_chunk_sizes[layer_idx] % block
            if self.is_large_window[layer_idx]:
                self._total_chunks[layer_idx] = self.true_tokens[layer_idx] // self.summary_chunk_size
            else:
                self._total_chunks[layer_idx] = (
                    self._old_summary_len[layer_idx] + self._new_sum_len[layer_idx])

    # ── Decode: text token update ──
    def update_text(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
        """Write a single text token's KV during decode.

        Large window layers append to the text buffer, growing it if needed. Small window
        layers write into the chunk-mirror region of key_cache at ``ws + n_old + chunk_buf_len``.
        """
        if self.is_large_window[layer_idx]:
            cur = self._text_len[layer_idx]
            new_len = cur + 1
            self._grow_append_buffer(layer_idx, cur, new_len)
            self.key_cache[layer_idx][:, :, cur:new_len, :].copy_(key_states)
            self.value_cache[layer_idx][:, :, cur:new_len, :].copy_(value_states)
            self._text_len[layer_idx] = new_len
        else:
            # Write only to the key_cache mirror region (chunk_buf is not updated during decode).
            ws = self.window_sizes[layer_idx]
            n_old = self._old_summary_len[layer_idx]
            pos = self._chunk_buf_len[layer_idx]
            mirror_pos = ws + n_old + pos
            self.key_cache[layer_idx][:, :, mirror_pos:mirror_pos + 1, :].copy_(key_states)
            self.value_cache[layer_idx][:, :, mirror_pos:mirror_pos + 1, :].copy_(value_states)
            self._chunk_buf_len[layer_idx] = pos + 1
        self.cur_chunk_sizes[layer_idx] += 1
        self.true_tokens[layer_idx] += 1

    # ── Decode: summary token update ──
    def update_summary(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
        """Record summary token KV at a chunk boundary during decode.

        Large window: only the counters advance; no KV is stored.

        Small window. Order matters: the mirror must be flushed before eviction, because the
        evicted summary is written at ``ws + n_old``, which is where the mirror starts.
            1. Flush the chunk mirror into the ring, wrapping around if needed.
            2. If the new-summary ring is full, move its oldest entry to the end of the
               old-summary region in key_cache, growing key_cache if needed.
            3. Write the new summary into the new-summary ring, overwriting the oldest slot.

        NOTE(review): step 3 writes a single slot, so this assumes ``key_states`` holds
        exactly one summary token (summary_token_num == 1). Confirm against callers.
        """
        n_summary = key_states.shape[2]
        if self.is_large_window[layer_idx]:
            self.cur_chunk_sizes[layer_idx] += n_summary
            self._total_chunks[layer_idx] += n_summary
            return
        C = self.summary_chunk_size
        ws = self.window_sizes[layer_idx]
        scn = self.sliding_chunk_nums[layer_idx]
        cbl = self._chunk_buf_len[layer_idx]
        ptr = self._window_write_ptr[layer_idx]
        n_old = self._old_summary_len[layer_idx]

        # Step 1: flush the mirror region into the ring (must happen before eviction touches mirror[0]).
        mirror_start = ws + n_old
        if ptr + cbl <= ws:
            self.key_cache[layer_idx][:, :, ptr:ptr + cbl, :].copy_(
                self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :])
            self.value_cache[layer_idx][:, :, ptr:ptr + cbl, :].copy_(
                self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :])
        else:
            # Wrap-around: fill the ring tail first, then continue from slot 0.
            first = ws - ptr
            self.key_cache[layer_idx][:, :, ptr:ws, :].copy_(
                self.key_cache[layer_idx][:, :, mirror_start:mirror_start + first, :])
            self.value_cache[layer_idx][:, :, ptr:ws, :].copy_(
                self.value_cache[layer_idx][:, :, mirror_start:mirror_start + first, :])
            rest = cbl - first
            self.key_cache[layer_idx][:, :, :rest, :].copy_(
                self.key_cache[layer_idx][:, :, mirror_start + first:mirror_start + cbl, :])
            self.value_cache[layer_idx][:, :, :rest, :].copy_(
                self.value_cache[layer_idx][:, :, mirror_start + first:mirror_start + cbl, :])
        self._window_write_ptr[layer_idx] = (ptr + cbl) % ws
        if self._n_valid_window[layer_idx] < ws:
            self._n_valid_window[layer_idx] = min(ws, self._n_valid_window[layer_idx] + cbl)
        self._chunk_buf_len[layer_idx] = 0

        # Step 2: evict the oldest new summary into the old-summary region (safe now: mirror flushed).
        if self._new_sum_len[layer_idx] >= scn:
            read_ptr = self._new_sum_write_ptr[layer_idx]
            old_dst = ws + n_old  # == mirror_start, but the mirror data is already in the ring
            # +1 for the newly evicted summary, +C for the next chunk's mirror.
            needed = old_dst + 1 + C
            if needed > self._capacity[layer_idx]:
                new_s_cap = max(self._old_summary_cap[layer_idx] * 2, n_old + self._SUMMARY_INIT_CAP)
                new_total = ws + new_s_cap + C
                old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx]
                bsz, heads, _, head_dim = old_k.shape
                nk = torch.empty(bsz, heads, new_total, head_dim, dtype=old_k.dtype, device=old_k.device)
                nv = torch.empty(bsz, heads, new_total, head_dim, dtype=old_v.dtype, device=old_v.device)
                # Only ring + old summaries need copying; the mirror was flushed in step 1.
                copy_len = ws + n_old
                nk[:, :, :copy_len, :].copy_(old_k[:, :, :copy_len, :])
                nv[:, :, :copy_len, :].copy_(old_v[:, :, :copy_len, :])
                self.key_cache[layer_idx] = nk
                self.value_cache[layer_idx] = nv
                self._old_summary_cap[layer_idx] = new_s_cap
                self._capacity[layer_idx] = new_total
            self.key_cache[layer_idx][:, :, old_dst:old_dst + 1, :].copy_(
                self._new_sum_key_buf[layer_idx][:, :, read_ptr:read_ptr + 1, :])
            self.value_cache[layer_idx][:, :, old_dst:old_dst + 1, :].copy_(
                self._new_sum_value_buf[layer_idx][:, :, read_ptr:read_ptr + 1, :])
            self._old_summary_len[layer_idx] += 1

        # Step 3: write the new summary to new_summary_buf, overwriting the oldest slot.
        w_ptr = self._new_sum_write_ptr[layer_idx]
        self._new_sum_key_buf[layer_idx][:, :, w_ptr:w_ptr + 1, :].copy_(key_states)
        self._new_sum_value_buf[layer_idx][:, :, w_ptr:w_ptr + 1, :].copy_(value_states)
        self._new_sum_write_ptr[layer_idx] = (w_ptr + 1) % scn
        if self._new_sum_len[layer_idx] < scn:
            self._new_sum_len[layer_idx] += 1
        self.cur_chunk_sizes[layer_idx] += n_summary
        self._total_chunks[layer_idx] += n_summary

    # ── Decode: get KV for attention ──
    def get_attention_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the KV a text token attends to.

        Large window: ``buffer[:text_len]``.
        Small window, ring full: ``key_cache[:ws + n_old + cbl]``, a single slice with no cat.
        Small window, warmup: concatenation of [valid ring | chunk mirror | old summaries].
        The order differs from the physical layout, which is fine because RoPE is baked in.
        """
        if self.is_large_window[layer_idx]:
            tl = self._text_len[layer_idx]
            return (self.key_cache[layer_idx][:, :, :tl, :],
                    self.value_cache[layer_idx][:, :, :tl, :])
        ws = self.window_sizes[layer_idx]
        nv = self._n_valid_window[layer_idx]
        cbl = self._chunk_buf_len[layer_idx]
        n_old = self._old_summary_len[layer_idx]
        # Steady state: ring full -> [ring(ws) | old_sums(n_old) | chunk_mirror(cbl)] is contiguous.
        if nv == ws:
            end = ws + n_old + cbl
            return (self.key_cache[layer_idx][:, :, :end, :],
                    self.value_cache[layer_idx][:, :, :end, :])
        # Warmup: ring not full, [nv:ws] is a gap -> cat.
        parts_k, parts_v = [], []
        if nv > 0:
            parts_k.append(self.key_cache[layer_idx][:, :, :nv, :])
            parts_v.append(self.value_cache[layer_idx][:, :, :nv, :])
        if cbl > 0:
            mirror_start = ws + n_old
            parts_k.append(self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :])
            parts_v.append(self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :])
        if n_old > 0:
            parts_k.append(self.key_cache[layer_idx][:, :, ws:ws + n_old, :])
            parts_v.append(self.value_cache[layer_idx][:, :, ws:ws + n_old, :])
        if len(parts_k) == 1:
            return parts_k[0], parts_v[0]
        return torch.cat(parts_k, dim=2), torch.cat(parts_v, dim=2)

    def get_current_chunk_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the current chunk's text KV, used by summary-token attention.

        Large window: the last C text entries (presumably called only once at least C text
        tokens exist, i.e. at a chunk boundary; confirm). Small window: the chunk-mirror region.
        """
        C = self.summary_chunk_size
        if self.is_large_window[layer_idx]:
            tl = self._text_len[layer_idx]
            return (self.key_cache[layer_idx][:, :, tl - C:tl, :],
                    self.value_cache[layer_idx][:, :, tl - C:tl, :])
        ws = self.window_sizes[layer_idx]
        n_old = self._old_summary_len[layer_idx]
        cbl = self._chunk_buf_len[layer_idx]
        mirror_start = ws + n_old
        return (self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :],
                self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :])

    def reset_chunk_counter(self):
        """Wrap ``cur_chunk_sizes`` modulo (chunk + summary) size once a chunk boundary step completes."""
        block = self.summary_chunk_size + self.summary_token_num
        for layer_idx in range(self.num_hidden_layers):
            if self.cur_chunk_sizes[layer_idx] >= block:
                self.cur_chunk_sizes[layer_idx] %= block
class Qwen3MLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free. The activation is looked up by name
    (``config.hidden_act``) in ``ACT2FN``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
def rotate_half(x):
    """Rotate the last dimension by half: ``[x1, x2] -> [-x2, x1]``.

    ``x1`` is the first ``size // 2`` elements and ``x2`` is the rest.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding (RoPE) to a query and a key tensor.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into ``cos``/``sin`` so that
            they broadcast against ``q``/``k``. For example, with ``cos``/``sin`` of shape
            ``[batch_size, seq_len, head_dim]``, use 1 when q/k are
            ``[batch_size, heads, seq_len, head_dim]`` and 2 when they are
            ``[batch_size, seq_len, heads, head_dim]``.

    Returns:
        `tuple(torch.Tensor)`: the query and key tensors rotated using RoPE.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    q_rot = _rotate(q)
    k_rot = _rotate(k)
    return q_rot, k_rot
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: maps
    ``(batch, num_key_value_heads, seqlen, head_dim)`` to
    ``(batch, num_attention_heads, seqlen, head_dim)``. Returns the input object itself
    when ``n_rep == 1``.
    """
    bsz, n_kv, seq, dim = hidden_states.shape  # also validates that the input is 4-D
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(bsz, n_kv, n_rep, seq, dim)
    return expanded.reshape(bsz, n_kv * n_rep, seq, dim)
def _sdpa_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Unmasked scaled-dot-product attention over the given keys/values.

    Args:
        module: provides ``num_key_value_groups``, the factor by which ``key``/``value``
            heads are repeated (``repeat_kv``) to match the query heads.
        query: query in SDPA layout (``Qwen3Attention`` passes ``[batch, heads, seq, head_dim]``).
        key / value: KV with ``num_key_value_heads`` heads, same layout as ``query``.
        attention_mask: accepted for interface compatibility but NOT applied (see note below).
        scaling: softmax scale applied to ``query @ key^T``.
        dropout: attention dropout probability.

    Returns:
        ``(attn_output, None)``, where ``attn_output`` has dims 1 and 2 swapped relative to
        the SDPA output (i.e. ``[batch, seq, heads, head_dim]``) and is made contiguous.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    # NOTE(review): attention_mask is not applied and is_causal=False, so every query attends
    # to every key it is given. Presumably callers pass only the KV that should be visible
    # (e.g. the ring-buffer decode path); confirm before relying on this for causal prefill.
    attn_output = F.scaled_dot_product_attention(
        query,
        key_states,
        value_states,
        attn_mask=None,
        dropout_p=dropout,
        is_causal=False,
        # Previously the `scaling` argument was silently ignored (SDPA default 1/sqrt(head_dim)).
        scale=scaling,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Qwen3 variant: per-head RMSNorm on queries and keys (``q_norm`` / ``k_norm``, applied over
    ``head_dim`` only) before RoPE. It uses grouped-query attention, with
    ``num_attention_heads // num_key_value_heads`` query heads per KV head.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout

        q_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        use_bias = config.attention_bias
        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=use_bias)
        # Unlike OLMo, the norm acts on head_dim only, so no reshape is needed after q_norm/k_norm.
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if is_sliding else None

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        leading = hidden_states.shape[:-1]
        per_head = (*leading, -1, self.head_dim)

        q = self.q_norm(self.q_proj(hidden_states).view(per_head)).transpose(1, 2)
        k = self.k_norm(self.k_proj(hidden_states).view(per_head)).transpose(1, 2)
        v = self.v_proj(hidden_states).view(per_head).transpose(1, 2)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position is needed for the static cache.
            extra = {"sin": sin, "cos": cos, "cache_position": cache_position}
            k, v = past_key_values.update(k, v, self.layer_idx, extra)

        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output, attn_weights = _sdpa_attention_forward(
            self,
            q,
            k,
            v,
            attention_mask,
            dropout=dropout_p,
            scaling=self.scaling,
            **kwargs,
        )
        merged = attn_output.reshape(*leading, -1).contiguous()
        return self.o_proj(merged), attn_weights
class Qwen3SummaryAttention(Qwen3Attention):
"""
Summary-aware variant of Qwen3Attention: uses a sliding summary mask.
"""
def __init__(self, config: Qwen3Config, layer_idx: int):
    super().__init__(config, layer_idx)
    self.summary_chunk_size = getattr(self.config, "summary_chunk_size", 0)
    self.summary_token_num = getattr(self.config, "summary_token_num", 0)
    # Resolve this layer's sliding_chunk_num once, so forward never re-parses
    # (and never re-eval()s) the config pattern.
    pattern = _parse_config_pattern(getattr(config, "summary_sliding_chunk_num", 0) or 0)
    self._sliding_chunk_num = pattern[layer_idx] if isinstance(pattern, list) else int(pattern)
    # Dedicated summary-token projections are only created when they would be used
    # (independent parameters enabled and a non-zero mixing coefficient).
    if config.summary_independent_parameters and config.mix_coeff > 0:
        q_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.q_proj_summary = nn.Linear(config.hidden_size, q_width, bias=config.attention_bias)
        self.k_proj_summary = nn.Linear(config.hidden_size, kv_width, bias=config.attention_bias)
        self.v_proj_summary = nn.Linear(config.hidden_size, kv_width, bias=config.attention_bias)
def _get_sliding_chunk_num(self):
return self._sliding_chunk_num
def get_query_key_value_tensors(self, hidden_states):
    """Project ``hidden_states`` with the shared q/k/v weights (RoPE not applied).

    q and k are RMS-normalized per head. All three are returned with dims 1 and 2 swapped,
    i.e. ``[batch, heads, seq, head_dim]`` when ``hidden_states`` is ``[batch, seq, hidden]``.
    """
    per_head = (*hidden_states.shape[:-1], -1, self.head_dim)
    query = self.q_norm(self.q_proj(hidden_states).view(per_head))
    key = self.k_norm(self.k_proj(hidden_states).view(per_head))
    value = self.v_proj(hidden_states).view(per_head)
    return query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
def get_query_key_value_tensors_summary(self, hidden_states):
    """Project ``hidden_states`` with the summary-only q/k/v weights (RoPE not applied).

    This mirrors ``get_query_key_value_tensors``, but uses ``q_proj_summary`` / ``k_proj_summary`` /
    ``v_proj_summary``. The shared ``q_norm`` / ``k_norm`` are still used.
    """
    per_head = (*hidden_states.shape[:-1], -1, self.head_dim)
    query = self.q_norm(self.q_proj_summary(hidden_states).view(per_head))
    key = self.k_norm(self.k_proj_summary(hidden_states).view(per_head))
    value = self.v_proj_summary(hidden_states).view(per_head)
    return query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    summary_ctx: Optional[SummaryBatchContext] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Summary-token attention for a batch of exactly one sequence.

    The cache state selects one of three paths:

    * Prefill (no cache, or the cache has not been reorganized yet): q/k/v for the whole
      (summary-expanded) sequence are appended to the cache and attended with
      ``summary_attn_func``; ``summary_ctx.summary_mask`` marks the summary positions.
    * Single-token decode (``query_len == 1``): the text token is written to the cache and
      attends to the KV returned by ``get_attention_kv``.
    * Chunk-boundary decode (``query_len > 1``): the query is
      ``[text_token, summary_0, ..., summary_{S-1}]``. The text token is handled as in
      single-token decode. The summary tokens attend to the current chunk's text KV plus
      this step's summary KV, causally among themselves. Their KV is then written to the
      cache.

    Args:
        hidden_states: Layer input; its batch dimension must be 1.
        position_embeddings: ``(cos, sin)`` rotary tables.
        attention_mask: Not used by this layer (kept for interface compatibility).
        past_key_values: KV cache. The decode paths call ``update_text``,
            ``get_attention_kv``, ``get_current_chunk_kv`` and ``update_summary`` on it.
        cache_position: Forwarded to ``past_key_values.update`` during prefill.
        summary_ctx: Summary context. Required for prefill and whenever
            ``config.summary_independent_parameters`` is set.

    Returns:
        ``(attn_output, attn_weights)``; ``attn_weights`` is ``None`` on the
        chunk-boundary path.

    Raises:
        ValueError: If the batch size is not 1, or ``summary_ctx`` is missing where it is
            required.
    """
    input_shape = hidden_states.shape[:-1]
    if hidden_states.size(0) != 1:
        raise ValueError("Summary sliding attention only supports batch size=1.")

    # Compute q/k/v for the full sequence once.
    if self.config.summary_independent_parameters:
        if summary_ctx is None:
            raise ValueError("summary_ctx is required when using summary_independent_parameters.")
        summary_mask = summary_ctx.summary_mask
        summary_pos = summary_mask[0]
        assert (summary_mask == summary_mask[0:1]).all()
        if self.config.mix_coeff == 0:
            # With mix_coeff == 0 the summary projections contribute nothing:
            # skip the clone and the extra linear layers.
            query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states)
        else:
            query, key, value = self.get_query_key_value_tensors(hidden_states)
            # Work on copies so summary rows can be overwritten while the shared-projection
            # values in `query`/`key`/`value` remain available for blending.
            query_states = query.clone()
            key_states = key.clone()
            value_states = value.clone()
            hs_summary = hidden_states[:, summary_pos, :]
            if hs_summary.size(1) > 0:
                base_q_summary = query[:, :, summary_pos, :]
                base_k_summary = key[:, :, summary_pos, :]
                base_v_summary = value[:, :, summary_pos, :]
                q_s, k_s, v_s = self.get_query_key_value_tensors_summary(hs_summary)
                # Blend: mix_coeff * summary projection + (1 - mix_coeff) * shared projection.
                q_s = self.config.mix_coeff * q_s + (1.0 - self.config.mix_coeff) * base_q_summary
                k_s = self.config.mix_coeff * k_s + (1.0 - self.config.mix_coeff) * base_k_summary
                v_s = self.config.mix_coeff * v_s + (1.0 - self.config.mix_coeff) * base_v_summary
                query_states[:, :, summary_pos, :] = q_s
                key_states[:, :, summary_pos, :] = k_s
                value_states[:, :, summary_pos, :] = v_s
    else:
        query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    query_len = query_states.shape[2]
    dropout = 0.0 if not self.training else self.attention_dropout

    is_prefill = past_key_values is None or not past_key_values._reorganized
    if is_prefill:
        # Prefill: standard cache append, then summary_attn_func over the whole sequence.
        if summary_ctx is None:
            raise ValueError("summary_ctx is required for the summary attention prefill path.")
        # The batch dimension is 1 (checked above): index it rather than squeeze() so a
        # length-1 sequence still yields a 1-D position mask instead of a 0-d tensor.
        prefill_summary_pos = summary_ctx.summary_mask[0]
        if past_key_values is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
                "summary_mask": summary_ctx.summary_mask,
            }
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
        with torch.cuda.device(query_states.device):
            attn_output, attn_weights = summary_attn_func(
                query_states.transpose(1, 2).contiguous(),
                key_states.transpose(1, 2).contiguous(),
                value_states.transpose(1, 2).contiguous(),
                self.summary_chunk_size,
                self.summary_token_num,
                self._get_sliding_chunk_num(),
                summary_pos=prefill_summary_pos,
            )
    elif query_len == 1:
        # Single text-token decode: write to the cache, attend to the full buffer.
        past_key_values.update_text(key_states, value_states, self.layer_idx)
        k_full, v_full = past_key_values.get_attention_kv(self.layer_idx)
        attn_output, attn_weights = _sdpa_attention_forward(
            self,
            query_states,
            k_full,
            v_full,
            None,
            dropout=dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )
    else:
        # Chunk boundary: query = [text_token, summary_token(s)].
        q_text, q_summary = query_states[:, :, :1, :], query_states[:, :, 1:, :]
        k_text, k_summary = key_states[:, :, :1, :], key_states[:, :, 1:, :]
        v_text, v_summary = value_states[:, :, :1, :], value_states[:, :, 1:, :]

        # 1. Text token: write it to the cache, then attend to the full cached KV.
        past_key_values.update_text(k_text, v_text, self.layer_idx)
        k_full, v_full = past_key_values.get_attention_kv(self.layer_idx)
        text_out, _ = _sdpa_attention_forward(
            self,
            q_text,
            k_full,
            v_full,
            None,
            dropout=dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # 2. Summary tokens: attend to the current chunk's text KV plus this step's summary KV.
        k_chunk, v_chunk = past_key_values.get_current_chunk_kv(self.layer_idx)
        k_chunk_with_self = torch.cat([k_chunk, k_summary], dim=2)
        v_chunk_with_self = torch.cat([v_chunk, v_summary], dim=2)
        num_summary = q_summary.shape[2]
        summary_attn_mask = None
        if num_summary > 1:
            # Intended pattern: every summary token sees the whole chunk, and summary token i
            # additionally sees summary tokens 0..i (causal aligned bottom-right). Without an
            # explicit mask, the SDPA helper would fall back to is_causal=True (assuming
            # `_sdpa_attention_forward` is transformers' sdpa_attention_forward). PyTorch
            # aligns that mask top-left when query and key lengths differ, so summary token i
            # would see only the first i+1 chunk tokens. With one summary token no mask is
            # needed.
            chunk_len = k_chunk.shape[2]
            summary_attn_mask = torch.ones(
                num_summary, chunk_len + num_summary, dtype=torch.bool, device=q_summary.device
            ).tril(diagonal=chunk_len)[None, None]
        summary_out, _ = _sdpa_attention_forward(
            self,
            q_summary,
            k_chunk_with_self,
            v_chunk_with_self,
            summary_attn_mask,
            dropout=dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # 3. Persist the summary KV.
        past_key_values.update_summary(k_summary, v_summary, self.layer_idx)
        # Attention outputs are laid out (batch, seq, heads, head_dim); the reshape below
        # relies on this. Join text and summary outputs along the sequence axis (dim 1), not
        # the heads axis.
        attn_output = torch.cat([text_out, summary_out], dim=1)
        attn_weights = None

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
class Qwen3DecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder block: RMSNorm -> self-attention -> residual, RMSNorm -> MLP -> residual.

    The attention module is ``Qwen3SummaryAttention`` when ``config.use_summary_attention``
    is True and ``config.summary_layer_freq[layer_idx] == 1``. It is plain
    ``Qwen3Attention`` when summary attention is disabled and the flag is 0. Any other
    combination raises ``ValueError``.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        use_summary = getattr(config, "use_summary_attention", False)
        # Use SummaryAttention if enabled in config
        if use_summary is True and config.summary_layer_freq[layer_idx] == 1:
            self.self_attn = Qwen3SummaryAttention(config=config, layer_idx=layer_idx)
        elif use_summary is False and config.summary_layer_freq[layer_idx] == 0:
            self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
        else:
            # Use the getattr value so a config without use_summary_attention still produces
            # this ValueError rather than an AttributeError while formatting the message.
            raise ValueError(
                f'Check config.summary_layer_freq {config.summary_layer_freq} and config.use_summary_attention {use_summary}'
            )
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if getattr(config, "summary_independent_attention_layernorm", False):
            # Separate input RMSNorm weights applied only at summary-token positions.
            self.input_layernorm_summary = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        summary_ctx: Optional[SummaryBatchContext] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run one decoder block and return the updated hidden states.

        ``summary_ctx`` is forwarded to the attention module only when it is a
        ``Qwen3SummaryAttention``. It is required when
        ``config.summary_independent_attention_layernorm`` is set.

        Raises:
            ValueError: If the summary layernorm is enabled but ``summary_ctx`` is ``None``.
        """
        residual = hidden_states
        if getattr(self.config, "summary_independent_attention_layernorm", False):
            if summary_ctx is None:
                raise ValueError("summary_ctx is required when using summary_independent_attention_layernorm.")
            summary_mask = summary_ctx.summary_mask
            assert (summary_mask == summary_mask[0:1]).all(), \
                "summary_mask must be identical across batch"
            hidden_states = self.input_layernorm(hidden_states)
            if summary_mask.any():
                # Move the mask once and use it for both the gather and the scatter.
                summary_pos = summary_mask[0].to(residual.device)
                # Summary positions are re-normalized from the raw residual with their own
                # RMSNorm weights, overwriting the shared-norm output at those positions.
                hidden_states[:, summary_pos, :] = self.input_layernorm_summary(residual[:, summary_pos, :])
        else:
            hidden_states = self.input_layernorm(hidden_states)

        # Self Attention - pass summary_ctx if using summary attention
        attn_kwargs = {
            "hidden_states": hidden_states,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "cache_position": cache_position,
            "position_embeddings": position_embeddings,
            **kwargs,
        }
        if isinstance(self.self_attn, Qwen3SummaryAttention):
            attn_kwargs["summary_ctx"] = summary_ctx
        hidden_states, _ = self.self_attn(**attn_kwargs)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
@auto_docstring
class Qwen3PreTrainedModel(PreTrainedModel):
    # --- model wiring ---
    config: Qwen3Config
    base_model_prefix = "model"
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # --- capability flags (presumably read by transformers' PreTrainedModel machinery) ---
    supports_gradient_checkpointing = True
    _supports_attention_backend = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    # NOTE(review): the summary-attention path calls `.item()` in
    # Qwen3Model._filter_summary_tokens; confirm full-graph compilation really holds there.
    _can_compile_fullgraph = True

    # --- output recording: module classes whose outputs feed hidden_states / attentions ---
    _can_record_outputs = {
        "hidden_states": Qwen3DecoderLayer,
        "attentions": Qwen3Attention,
    }
class Qwen3RotaryEmbedding(nn.Module):
    """Rotary position embedding: turns position ids into ``(cos, sin)`` tables."""

    inv_freq: torch.Tensor  # declared for linters; registered as a non-persistent buffer in __init__

    def __init__(self, config: Qwen3Config, device=None):
        super().__init__()
        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.rope_type = config.rope_parameters["rope_type"]
        # "default" uses the local helper; every other RoPE variant comes from the registry.
        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = inv_freq

    @staticmethod
    def compute_default_rope_parameters(
        config: Optional[Qwen3Config] = None,
        device: Optional["torch.device"] = None,
        seq_len: Optional[int] = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Compute the inverse frequencies of the original (unscaled) RoPE formulation.
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                Model configuration; reads `rope_parameters["rope_theta"]` and the head
                dimension (`head_dim`, or `hidden_size // num_attention_heads`).
            device (`torch.device`):
                Device on which the inverse frequencies are created.
            seq_len (`int`, *optional*):
                Ignored by this RoPE type.
        Returns:
            `(inv_freq, attention_factor)`: the inverse frequencies (`dim // 2` entries,
            one per even index) and the cos/sin scaling factor, which is always 1.0 here.
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        attention_factor = 1.0  # not used by the default RoPE
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (base ** exponents)
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        # mps (and non-string device types) map to the "cpu" autocast device type.
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # keep the math in float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@auto_docstring
class Qwen3Model(Qwen3PreTrainedModel):
    def __init__(self, config: Qwen3Config):
        # Builds the embedding table, decoder stack, final norm and rotary embedding.
        # Side effect: normalizes `config.summary_layer_freq` in place to a per-layer list
        # of 0/1 flags, which Qwen3DecoderLayer uses to choose its attention class.
        import warnings

        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        if not getattr(config, "summary_layer_freq", False):
            # No per-layer pattern given: apply summary attention to every layer or to none.
            default_flag = 1 if getattr(config, "use_summary_attention", False) else 0
            config.summary_layer_freq = [default_flag] * config.num_hidden_layers
            # Actually emit the warning (constructing a Warning object alone does nothing).
            warnings.warn(
                f'Please set config.summary_layer_freq, temp set summary_layer_freq = {config.summary_layer_freq}',
                stacklevel=2,
            )
        else:
            config.summary_layer_freq = _parse_config_pattern(config.summary_layer_freq)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        # Cache per-layer sliding_chunk_nums for KV cache eviction
        _sv = _parse_config_pattern(getattr(config, "summary_sliding_chunk_num", 0) or 0)
        if isinstance(_sv, list):
            self._sliding_chunk_nums = [int(v) for v in _sv]
        else:
            self._sliding_chunk_nums = [int(_sv)] * config.num_hidden_layers
        # Initialize weights and apply final processing
        self.post_init()
def _expand_input_with_summary_tokens(self, input_ids):
    """Insert summary tokens after every full chunk of text tokens (prefill only, vectorized).

    Each full chunk of ``summary_chunk_size`` text tokens is followed by ``summary_token_num``
    summary ids ``summary_token_begin, summary_token_begin + 1, ...``. A trailing partial
    chunk is copied without summary tokens.

    Args:
        input_ids: 2-D ``(batch_size, seq_len)`` token ids.

    Returns:
        Tuple of ``(expanded_input_ids, position_ids, text_only_mask)``, each of shape
        ``(batch_size, expanded_len)``. ``text_only_mask`` is True at text positions. Text
        position ids follow ``config.summary_chunk_position_ids_type``; summary position ids
        follow ``config.summary_token_position_ids_type``. When the chunk size or the summary
        token count is 0, returns ``(input_ids, None, None)``.

    Raises:
        ValueError: On an unknown ``summary_chunk_position_ids_type`` or
            ``summary_token_position_ids_type`` (each is only checked when positions of
            that kind exist).
    """
    summary_chunk = self.config.summary_chunk_size
    summary_num = self.config.summary_token_num
    summary_begin = self.config.summary_token_begin
    # Summary tokens disabled: nothing to expand.
    if summary_chunk == 0 or summary_num == 0:
        return input_ids, None, None
    batch_size, seq_len = input_ids.shape
    device = input_ids.device
    dtype = input_ids.dtype
    # One expanded block = one full text chunk followed by its summary tokens.
    block = summary_chunk + summary_num
    # Number of full chunks and remainder
    n_full_chunks = seq_len // summary_chunk
    remainder = seq_len % summary_chunk
    has_remainder = remainder > 0
    # Total expanded length: full_chunks * block + remainder
    # e.g. summary_chunk=2, summary_num=1, seq_len=5 -> [t0 t1 S | t2 t3 S | t4], length 7
    expanded_len = n_full_chunks * block + (remainder if has_remainder else 0)
    # --- Build expanded_input_ids ---
    # torch.empty is safe: every position is written below (text, summary, or remainder).
    expanded_ids = torch.empty((batch_size, expanded_len), dtype=dtype, device=device)
    text_only_mask = torch.zeros((batch_size, expanded_len), dtype=torch.bool, device=device)
    # Compute text positions: for chunk i, text goes to [i*block, i*block+summary_chunk)
    # Summary positions: [i*block+summary_chunk, (i+1)*block)
    if n_full_chunks > 0:
        chunk_indices = torch.arange(n_full_chunks, device=device)
        # Text source positions in original input_ids
        text_src_offsets = (chunk_indices * summary_chunk).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0)  # [n_full_chunks, summary_chunk]
        # Text dest positions in expanded
        text_dst_offsets = (chunk_indices * block).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0)  # [n_full_chunks, summary_chunk]
        # Summary dest positions
        summary_dst_offsets = (chunk_indices * block + summary_chunk).unsqueeze(1) + torch.arange(summary_num, device=device).unsqueeze(0)  # [n_full_chunks, summary_num]
        text_src_flat = text_src_offsets.reshape(-1)
        text_dst_flat = text_dst_offsets.reshape(-1)
        summary_dst_flat = summary_dst_offsets.reshape(-1)
        # Copy text tokens
        expanded_ids[:, text_dst_flat] = input_ids[:, text_src_flat]
        text_only_mask[:, text_dst_flat] = True
        # Fill summary tokens (the same id sequence after every chunk, for every batch row)
        summary_ids_val = torch.arange(summary_num, device=device, dtype=dtype) + summary_begin
        expanded_ids[:, summary_dst_flat] = summary_ids_val.repeat(n_full_chunks).unsqueeze(0).expand(batch_size, -1)
    # Handle remainder (last partial chunk, no summary tokens)
    if has_remainder:
        rem_start_src = n_full_chunks * summary_chunk
        rem_start_dst = n_full_chunks * block
        rem_offsets = torch.arange(remainder, device=device)
        expanded_ids[:, rem_start_dst + rem_offsets] = input_ids[:, rem_start_src + rem_offsets]
        text_only_mask[:, rem_start_dst + rem_offsets] = True
    # --- Build position_ids ---
    # Also torch.empty: text, summary and remainder positions are all assigned below.
    position_ids = torch.empty((batch_size, expanded_len), dtype=torch.long, device=device)
    if n_full_chunks > 0:
        # Text position IDs: 'origin' = index in the original input, 'inner_chunk' = index within its chunk
        if self.config.summary_chunk_position_ids_type == 'origin':
            text_pos = text_src_flat.unsqueeze(0).expand(batch_size, -1)
        elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
            inner_pos = torch.arange(summary_chunk, device=device).repeat(n_full_chunks)
            text_pos = inner_pos.unsqueeze(0).expand(batch_size, -1)
        else:
            raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
        position_ids[:, text_dst_flat] = text_pos
        # Summary position IDs
        if self.config.summary_token_position_ids_type == 'zeros':
            position_ids[:, summary_dst_flat] = 0
        elif self.config.summary_token_position_ids_type in ('last_chunk_slice_left', 'last_chunk_slice_right'):
            # Vectorized slice_ends computation for all chunks at once.
            # Summary token j of chunk i gets max(i*C + (k*C)//S - 1, i*C), where C = summary_chunk,
            # S = summary_num and k = j ('left') or j + 1 ('right'), i.e. the original-sequence index
            # at the end of the k-th equal slice of that chunk.
            if self.config.summary_token_position_ids_type == 'last_chunk_slice_left':
                idx = torch.arange(0, summary_num, device=device, dtype=torch.long)
            else:
                idx = torch.arange(1, summary_num + 1, device=device, dtype=torch.long)
            # For each chunk i: prev_text_end = i * summary_chunk
            prev_ends = (chunk_indices * summary_chunk).unsqueeze(1)  # [n_full_chunks, 1]
            slice_ends = prev_ends + (idx.unsqueeze(0) * summary_chunk) // summary_num - 1  # [n_full_chunks, summary_num]
            slice_ends = slice_ends.clamp(min=0)
            # Clamp per-chunk: min is prev_text_end for that chunk
            slice_ends = torch.max(slice_ends, prev_ends)
            position_ids[:, summary_dst_flat] = slice_ends.reshape(-1).unsqueeze(0).expand(batch_size, -1)
        else:
            raise ValueError(f'Unknown summary_token_position_ids_type: {self.config.summary_token_position_ids_type}')
    # Remainder position IDs (same convention as full-chunk text positions)
    if has_remainder:
        if self.config.summary_chunk_position_ids_type == 'origin':
            rem_pos = (rem_start_src + rem_offsets).unsqueeze(0).expand(batch_size, -1)
        elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
            rem_pos = rem_offsets.unsqueeze(0).expand(batch_size, -1)
        else:
            raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
        position_ids[:, rem_start_dst + rem_offsets] = rem_pos
    return expanded_ids, position_ids, text_only_mask
def _build_summary_context(self, input_ids, position_ids, is_prefill, use_cache):
    """Return the summary context for the attention layers, or ``None`` when disabled.

    Summaries are enabled when both ``config.summary_chunk_size`` and
    ``config.summary_token_num`` are positive. ``is_prefill`` and ``use_cache`` are
    accepted for interface compatibility but are not used.
    """
    cfg = self.config
    chunk_size, token_num, token_begin = (
        cfg.summary_chunk_size,
        cfg.summary_token_num,
        cfg.summary_token_begin,
    )
    if chunk_size <= 0 or token_num <= 0:
        return None
    return build_summary_sliding_context(
        input_ids=input_ids,
        position_ids=position_ids,
        summary_token_num=token_num,
        summary_token_begin=token_begin,
    )
def _filter_summary_tokens(self, hidden_states, text_only_mask, use_summary, is_decode):
    """Drop summary-token positions from ``hidden_states`` so outputs align with text tokens.

    * ``text_only_mask`` given (prefill): keep the positions where it is True. Every batch
      row is assumed to keep as many positions as row 0.
    * Summary decode with more than one position: keep only the first (text) position.
    * Otherwise: return ``hidden_states`` unchanged.
    """
    if text_only_mask is None:
        if use_summary and is_decode and hidden_states.size(1) > 1:
            return hidden_states[:, :1, :]
        return hidden_states
    batch_size, _, hidden_size = hidden_states.shape
    kept_per_row = text_only_mask[0].sum().item()
    keep = text_only_mask.to(hidden_states.device)
    return hidden_states[keep].reshape(batch_size, kept_per_row, hidden_size)
@check_model_inputs()
@auto_docstring
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    summary_ctx: Optional[SummaryBatchContext] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
    # Flow with summary attention enabled:
    #   prefill: expand input_ids with summary tokens -> embed -> decoder layers -> norm
    #            -> reorganize the cache -> drop summary positions from the output.
    #   decode:  embed the prepared [text] / [text, summary...] ids -> decoder layers -> norm
    #            -> reset_chunk_counter -> keep only the leading text position.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
    use_summary = getattr(self.config, "use_summary_attention", False)
    # Decide the phase once, before this call writes to the cache; every phase-dependent
    # step below (including output filtering) uses this value.
    is_prefill = past_key_values is None or past_key_values.get_seq_length() == 0
    # Prefill phase with summary attention: expand input_ids with summary tokens
    text_only_mask = None
    if use_summary and input_ids is not None and inputs_embeds is None and is_prefill:
        input_ids, position_ids, text_only_mask = self._expand_input_with_summary_tokens(input_ids)
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # Initialize cache
    if use_cache and past_key_values is None:
        if use_summary:
            past_key_values = Qwen3RingBufferCache(
                config=self.config, sliding_chunk_nums=self._sliding_chunk_nums)
        else:
            past_key_values = DynamicCache(config=self.config)
    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)
    # Build summary context if needed
    if use_summary and summary_ctx is None and input_ids is not None:
        summary_ctx = self._build_summary_context(input_ids, position_ids, is_prefill, use_cache)
    causal_mask_mapping = attention_mask
    if not isinstance(causal_mask_mapping, (dict, list)):
        if summary_ctx and summary_ctx.enabled:
            # Prefill: Qwen3SummaryAttention uses summary_attn_func, which needs no dense mask.
            # Decode: attention reads its KV from the ring-buffer cache. Either way no mask
            # is built here.
            causal_mask_mapping = None
        else:
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
    hidden_states = inputs_embeds
    # create position embeddings to be shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)
    for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
        # Masks may be absent, given per layer (list), or keyed by attention type (dict).
        if causal_mask_mapping is None:
            layer_mask = None
        elif isinstance(causal_mask_mapping, list):
            layer_mask = causal_mask_mapping[layer_idx]
        else:
            layer_mask = causal_mask_mapping[decoder_layer.attention_type]
        hidden_states = decoder_layer(
            hidden_states,
            attention_mask=layer_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            summary_ctx=summary_ctx,
            **kwargs,
        )
    hidden_states = self.norm(hidden_states)
    # After prefill: reorganize cache to ring buffer layout
    if use_cache and use_summary and past_key_values is not None and is_prefill:
        if hasattr(past_key_values, 'reorganize_after_prefill') and summary_ctx is not None:
            past_key_values.reorganize_after_prefill(summary_ctx.summary_mask)
    # After every decode step: let the cache update its chunk bookkeeping.
    # NOTE(review): presumably reset_chunk_counter only resets at chunk boundaries -- verify
    # in Qwen3RingBufferCache.
    if use_cache and use_summary and past_key_values is not None and not is_prefill:
        if hasattr(past_key_values, 'reset_chunk_counter'):
            past_key_values.reset_chunk_counter()
    # Filter out summary tokens from output. Use the phase decided at entry: after prefill
    # the cache is non-empty, so re-checking its length would misclassify prefill as decode
    # and truncate an unexpanded prefill to its first position.
    hidden_states = self._filter_summary_tokens(hidden_states, text_only_mask, use_summary, not is_prefill)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values if use_cache else None,
    )
@auto_docstring
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
    # lm_head is declared as tied to the input embedding table.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    # Tensor-parallel and pipeline-parallel plans for the output projection.
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        hidden_size, vocab_size = config.hidden_size, config.vocab_size
        # Backbone first, then the vocabulary projection on top of it.
        self.model = Qwen3Model(config)
        self.vocab_size = vocab_size
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    summary_ctx: Optional[SummaryBatchContext] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    Example:
    ```python
    >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
    >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")
    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        cache_position=cache_position,
        summary_ctx=summary_ctx,
        **kwargs,
    )
    hidden_states = outputs.last_hidden_state

    # Pick the positions to project onto the vocabulary. With the default logits_to_keep=0
    # and no labels (plain inference), only the final position is projected; this avoids a
    # [seq_len, vocab_size] logits tensor.
    if isinstance(logits_to_keep, int):
        if logits_to_keep == 0 and labels is None:
            positions = slice(-1, None)
        else:
            positions = slice(-logits_to_keep, None)
    else:
        positions = logits_to_keep
    logits = self.lm_head(hidden_states[:, positions, :])

    # Keep only the first `truncate_predict_nums` vocabulary entries (151936 when the config
    # does not set it; a value <= 0 disables truncation).
    # NOTE(review): labels >= the truncated width would be out of range for the loss below.
    truncate_n = getattr(self.config, "truncate_predict_nums", 151936)
    if truncate_n > 0:
        logits = logits[..., :truncate_n]

    loss = (
        self.loss_function(logits=logits, labels=labels, vocab_size=logits.shape[-1], **kwargs)
        if labels is not None
        else None
    )
    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
def _build_summary_attention_mask_for_generation(
    self,
    *,
    input_ids: torch.LongTensor,
    past_key_values: Optional[Cache],
    attention_mask: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Choose the attention mask passed to a decode step.

    With the summary ring-buffer cache, attention is resolved from the cache itself, so
    ``None`` is returned. For any other cache the caller's ``attention_mask`` is passed
    through unchanged. ``input_ids`` is currently unused.
    """
    uses_ring_buffer = isinstance(past_key_values, Qwen3RingBufferCache)
    return None if uses_ring_buffer else attention_mask
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[Cache] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """Prepare model inputs for one ``generate`` step, inserting summary tokens at chunk ends.

    Without summary attention this defers to the ``GenerationMixin`` implementation.
    With summary attention:

    * Prefill (empty cache): the full prompt is passed through; ``Qwen3Model.forward``
      inserts the summary tokens.
    * Decode: only ids beyond ``past_key_values.get_true_token_num()`` are passed. When the
      cache reports ``summary_chunk_size - 1`` text tokens in the current chunk, the
      incoming token completes the chunk. ``summary_token_num`` summary ids
      (``summary_token_begin, ...``) are then appended after it, together with their
      position ids.

    Raises:
        ValueError: On an unknown ``summary_chunk_position_ids_type`` or
            ``summary_token_position_ids_type``.
    """
    use_summary = getattr(self.config, "use_summary_attention", False)
    # If not using summary attention, use standard behavior
    if not use_summary:
        return super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )
    summary_chunk_size = getattr(self.config, "summary_chunk_size", 0)
    summary_token_num = getattr(self.config, "summary_token_num", 0)
    summary_token_begin = getattr(self.config, "summary_token_begin", 0)

    # Prefill phase: pass the full sequence; forward() handles summary token insertion.
    if past_key_values is None or past_key_values.get_seq_length() == 0:
        if cache_position is None:
            cache_position = torch.arange(0, input_ids.shape[1], device=input_ids.device)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "cache_position": cache_position,
            "use_cache": kwargs.get("use_cache"),
        }

    # Decode phase. cur_chunk: text tokens already in the current chunk (0 if the cache does
    # not expose get_cur_chunk_size); true_token_num: presumably the number of text tokens
    # already processed (summary tokens excluded).
    cur_chunk = past_key_values.get_cur_chunk_size() if hasattr(past_key_values, "get_cur_chunk_size") else 0
    true_token_num = past_key_values.get_true_token_num()

    # When several ids are passed (e.g. the whole generated sequence), keep only the unseen tail.
    if input_ids.shape[1] > 1:
        new_token_count = input_ids.shape[1] - true_token_num
        assert new_token_count > 0, f'new_token_count={new_token_count} should be greater than 0'
        input_ids = input_ids[:, -new_token_count:]
    device = input_ids.device
    batch_size = input_ids.shape[0]

    def _text_position() -> int:
        # Position id of the incoming text token(s) under summary_chunk_position_ids_type:
        # 'origin' = global text index, 'inner_chunk' = index within the current chunk.
        position_type = self.config.summary_chunk_position_ids_type
        if position_type == 'origin':
            return true_token_num
        if position_type == 'inner_chunk':
            return cur_chunk
        raise ValueError(f'Check config.summary_chunk_position_ids_type: {position_type}')

    if cur_chunk == summary_chunk_size - 1:
        # The incoming text token completes the chunk: append the summary ids after it,
        # giving input_ids = [text_token, summary_tokens].
        summary_ids = (
            torch.arange(summary_token_num, device=device, dtype=input_ids.dtype)
            + summary_token_begin
        ).unsqueeze(0).repeat(batch_size, 1)
        input_ids = torch.cat([input_ids, summary_ids], dim=1)

        text_pos = torch.full((batch_size, 1), _text_position(), device=device, dtype=torch.long)

        summary_position_type = self.config.summary_token_position_ids_type
        if summary_position_type == 'zeros':
            summary_pos = torch.zeros((batch_size, summary_token_num), device=device, dtype=torch.long)
        elif summary_position_type in ('last_chunk_slice_left', 'last_chunk_slice_right'):
            # Split the chunk just completed, [prev_text_end, prev_text_end + summary_chunk_size),
            # into summary_token_num equal slices. Each summary token takes the (global) end
            # position of its slice: 'left' uses slice indices 0..S-1, 'right' uses 1..S.
            # Results are clamped to be >= prev_text_end.
            prev_text_end = true_token_num + 1 - summary_chunk_size
            chunk_len = summary_chunk_size
            first_idx = 0 if summary_position_type == 'last_chunk_slice_left' else 1
            idx = torch.arange(first_idx, first_idx + summary_token_num, device=device, dtype=torch.long)
            slice_ends = (prev_text_end + (idx * chunk_len) // summary_token_num - 1).clamp(min=prev_text_end)
            # One row per batch element, matching text_pos and the 'zeros' branch.
            summary_pos = slice_ends.unsqueeze(0).expand(batch_size, -1)
        else:
            raise ValueError(f'Unknown summary_token_position_ids_type: {summary_position_type}')
        position_ids = torch.cat([text_pos, summary_pos], dim=1)
    elif position_ids is None:
        # Normal decode: only the new text token(s).
        position_ids = torch.full(
            (batch_size, input_ids.shape[1]), _text_position(), device=device, dtype=torch.long
        )

    return {
        "input_ids": input_ids,
        "attention_mask": self._build_summary_attention_mask_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
        ),
        "position_ids": position_ids,
        "past_key_values": past_key_values,
        "cache_position": cache_position,
        "use_cache": kwargs.get("use_cache"),
    }
class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
    """Qwen3 backbone with the generic sequence-classification head from ``transformers.modeling_layers``."""
class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel):
    """Qwen3 backbone with the generic token-classification head from ``transformers.modeling_layers``."""
class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel):
    """Qwen3 backbone with the generic extractive question-answering head."""

    # Backward compatibility: the base model used to be stored under `transformer`, not `model`.
    base_model_prefix = "transformer"
# Public API of this module: the standard Qwen3 model classes plus the summary-attention
# additions (ring-buffer cache, summary attention layer and summary context helpers).
__all__ = [
    "Qwen3ForCausalLM",
    "Qwen3ForQuestionAnswering",
    "Qwen3PreTrainedModel",
    "Qwen3Model",
    "Qwen3ForSequenceClassification",
    "Qwen3ForTokenClassification",
    "Qwen3RingBufferCache",
    "Qwen3SummaryAttention",
    "SummaryBatchContext",
    "build_summary_context",
    "build_summary_sliding_context",
]