# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_qwen3.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional, Union import torch from torch import nn import torch.nn.functional as F #from flash_attn import flash_attn_func from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.generation import GenerationMixin from transformers.integrations import use_kernel_forward_from_hub from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_layers import ( GenericForQuestionAnswering, GenericForSequenceClassification, GenericForTokenClassification, GradientCheckpointingLayer, ) from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import check_model_inputs from transformers.models.qwen3.configuration_qwen3 import Qwen3Config # InfinityLM imports for summary attention from .summary_context import SummaryBatchContext, build_summary_context, build_summary_sliding_context from summary_attn import summary_attn_func def _parse_config_pattern(val): """Parse a config value that may be an int, list, or Python pattern string like '([4096]*1+[128]*3)*9'.""" if isinstance(val, list): return val if isinstance(val, str): return eval(val) return val @use_kernel_forward_from_hub("RMSNorm") class Qwen3RMSNorm(nn.Module): def __init__(self, hidden_size, eps: float = 1e-6) -> None: """ Qwen3RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Qwen3RingBufferCache: """ Ring buffer KV cache with summary support. Two strategies based on per-layer sliding_chunk_num: - Large window layers (is_large_window=True): append-only buffer storing only text KV. Summary KV is NOT stored since text tokens attend to all text KV directly. - Small window layers (is_large_window=False): Three buffers: 1. key_cache: [ring(ws) | old_summaries(growing) | chunk_mirror(≀C)] β†’ attention input, steady state is a single contiguous slice 2. new_summary_buf: ring buffer of size scn, stores summaries whose text is still in the window (not needed for attention) 3. chunk_buf: size C, holds current chunk's text KV RoPE position information is baked into KV, so physical order doesn't matter. """ is_compileable = False _SUMMARY_INIT_CAP = 512 _APPEND_HEADROOM = 1024 def __init__(self, config: Qwen3Config, sliding_chunk_nums: list[int]): super().__init__() self.summary_chunk_size = getattr(config, "summary_chunk_size", 0) self.summary_token_num = getattr(config, "summary_token_num", 0) self.num_hidden_layers = config.num_hidden_layers self.sliding_chunk_nums = sliding_chunk_nums large_window_threshold = min(sliding_chunk_nums) * self.summary_chunk_size self.is_large_window = [sv * self.summary_chunk_size > large_window_threshold for sv in sliding_chunk_nums] self.window_sizes = [sv * self.summary_chunk_size for sv in sliding_chunk_nums] self.key_cache = [None for _ in range(config.num_hidden_layers)] self.value_cache = [None for _ in range(config.num_hidden_layers)] # Large window: append-only self._text_len = [0] * config.num_hidden_layers self._capacity = [0] * config.num_hidden_layers # Small window: ring buffer + summary self._window_write_ptr = [0] * config.num_hidden_layers self._n_valid_window = [0] * config.num_hidden_layers self._old_summary_len = [0] * config.num_hidden_layers # old summaries in key_cache self._old_summary_cap = [0] * config.num_hidden_layers # New summary ring buffer (small window only): summaries whose text is still in window self._new_sum_key_buf = [None for _ in range(config.num_hidden_layers)] self._new_sum_value_buf = [None for _ in range(config.num_hidden_layers)] self._new_sum_len = [0] * config.num_hidden_layers # how many filled (≀ scn) self._new_sum_write_ptr = [0] * config.num_hidden_layers # ring write pointer # Current chunk buffer (small window only): holds partial chunk text KV self._chunk_key_buf = [None for _ in range(config.num_hidden_layers)] self._chunk_value_buf = [None for _ in range(config.num_hidden_layers)] self._chunk_buf_len = [0] * config.num_hidden_layers # Common self.cur_chunk_sizes = [0] * config.num_hidden_layers self.true_tokens = [0] * config.num_hidden_layers self._total_chunks = [0] * config.num_hidden_layers # completed chunks count self._reorganized = False def __len__(self): return self.num_hidden_layers def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns nonzero when cache is populated (used to detect prefill vs decode).""" if layer_idx >= self.num_hidden_layers: return 0 if self.is_large_window[layer_idx]: return self._text_len[layer_idx] else: return (self._n_valid_window[layer_idx] + self._chunk_buf_len[layer_idx] + self._old_summary_len[layer_idx] + self._new_sum_len[layer_idx]) def get_cur_chunk_size(self, layer_idx: Optional[int] = None) -> int: if layer_idx is None: layer_idx = self.num_hidden_layers - 1 return self.cur_chunk_sizes[layer_idx] def get_true_token_num(self, layer_idx: Optional[int] = None) -> int: if layer_idx is None: layer_idx = self.num_hidden_layers - 1 return self.true_tokens[layer_idx] # ── Prefill: standard append (before reorganize) ── def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Append KV during prefill (before reorganize). Returns full KV for prefill attention.""" add_len = key_states.shape[-2] cur_len = self._text_len[layer_idx] new_len = cur_len + add_len if self.key_cache[layer_idx] is None: cap = new_len + self._APPEND_HEADROOM bsz, heads, _, head_dim = key_states.shape self.key_cache[layer_idx] = torch.empty( bsz, heads, cap, head_dim, dtype=key_states.dtype, device=key_states.device) self.value_cache[layer_idx] = torch.empty( bsz, heads, cap, head_dim, dtype=value_states.dtype, device=value_states.device) self._capacity[layer_idx] = cap elif new_len > self._capacity[layer_idx]: cap = max(new_len + self._APPEND_HEADROOM, self._capacity[layer_idx] * 2) old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx] bsz, heads, _, head_dim = old_k.shape new_k = torch.empty(bsz, heads, cap, head_dim, dtype=old_k.dtype, device=old_k.device) new_v = torch.empty(bsz, heads, cap, head_dim, dtype=old_v.dtype, device=old_v.device) new_k[:, :, :cur_len, :].copy_(old_k[:, :, :cur_len, :]) new_v[:, :, :cur_len, :].copy_(old_v[:, :, :cur_len, :]) self.key_cache[layer_idx] = new_k self.value_cache[layer_idx] = new_v self._capacity[layer_idx] = cap self.key_cache[layer_idx][:, :, cur_len:new_len, :].copy_(key_states) self.value_cache[layer_idx][:, :, cur_len:new_len, :].copy_(value_states) self._text_len[layer_idx] = new_len if self.summary_chunk_size > 0: if cache_kwargs and 'summary_mask' in cache_kwargs: text_count = add_len - cache_kwargs['summary_mask'][0].sum().item() else: text_count = add_len self.cur_chunk_sizes[layer_idx] += add_len self.true_tokens[layer_idx] += text_count return self.key_cache[layer_idx][:, :, :new_len, :], self.value_cache[layer_idx][:, :, :new_len, :] # ── Reorganize after prefill ── def reorganize_after_prefill(self, summary_mask: torch.Tensor): """Reorganize all layers from prefill block layout to ring buffer layout.""" if self._reorganized: return self._reorganized = True text_mask = ~summary_mask[0] for layer_idx in range(self.num_hidden_layers): prefill_len = self._text_len[layer_idx] prefill_k = self.key_cache[layer_idx][:, :, :prefill_len, :] prefill_v = self.value_cache[layer_idx][:, :, :prefill_len, :] bsz, heads, _, head_dim = prefill_k.shape device, dtype = prefill_k.device, prefill_k.dtype text_k = prefill_k[:, :, text_mask, :] text_v = prefill_v[:, :, text_mask, :] n_text = text_k.shape[2] if self.is_large_window[layer_idx]: # Large window: keep only text KV cap = n_text + self._APPEND_HEADROOM new_k = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device) new_v = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device) new_k[:, :, :n_text, :].copy_(text_k) new_v[:, :, :n_text, :].copy_(text_v) self.key_cache[layer_idx] = new_k self.value_cache[layer_idx] = new_v self._text_len[layer_idx] = n_text self._capacity[layer_idx] = cap else: # Small window: split summaries into old (evicted) and new (in window) all_summary_k = prefill_k[:, :, summary_mask[0], :] all_summary_v = prefill_v[:, :, summary_mask[0], :] n_summary = all_summary_k.shape[2] C = self.summary_chunk_size ws = self.window_sizes[layer_idx] scn = self.sliding_chunk_nums[layer_idx] # Split text into complete chunks + partial remainder n_complete_chunks = n_text // C n_partial = n_text % C n_complete_text = n_complete_chunks * C # Window: last scn complete chunks (or all if fewer) n_window_chunks = min(scn, n_complete_chunks) n_window_text = n_window_chunks * C window_start = n_complete_text - n_window_text # Split summaries: old (text evicted from ring) vs new (text in ring) n_old = max(0, n_summary - n_window_chunks) n_new = n_summary - n_old # key_cache: [ring(ws) | old_summaries | chunk_mirror(≀C)] old_s_cap = max(self._SUMMARY_INIT_CAP, (n_old + 1) * 2) total_cap = ws + old_s_cap + C new_k = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device) new_v = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device) if n_window_text > 0: new_k[:, :, :n_window_text, :].copy_(text_k[:, :, window_start:n_complete_text, :]) new_v[:, :, :n_window_text, :].copy_(text_v[:, :, window_start:n_complete_text, :]) self._n_valid_window[layer_idx] = n_window_text self._window_write_ptr[layer_idx] = n_window_text % ws # Old summaries go into key_cache after ring if n_old > 0: new_k[:, :, ws:ws + n_old, :].copy_(all_summary_k[:, :, :n_old, :]) new_v[:, :, ws:ws + n_old, :].copy_(all_summary_v[:, :, :n_old, :]) self._old_summary_len[layer_idx] = n_old self._old_summary_cap[layer_idx] = old_s_cap # Mirror partial chunk into key_cache after old_summaries if n_partial > 0: mirror_start = ws + n_old new_k[:, :, mirror_start:mirror_start + n_partial, :].copy_( text_k[:, :, n_complete_text:, :]) new_v[:, :, mirror_start:mirror_start + n_partial, :].copy_( text_v[:, :, n_complete_text:, :]) self.key_cache[layer_idx] = new_k self.value_cache[layer_idx] = new_v self._capacity[layer_idx] = total_cap self._text_len[layer_idx] = 0 # New summary ring buffer ns_buf_k = torch.empty(bsz, heads, scn, head_dim, dtype=dtype, device=device) ns_buf_v = torch.empty(bsz, heads, scn, head_dim, dtype=dtype, device=device) if n_new > 0: ns_buf_k[:, :, :n_new, :].copy_(all_summary_k[:, :, n_old:, :]) ns_buf_v[:, :, :n_new, :].copy_(all_summary_v[:, :, n_old:, :]) self._new_sum_key_buf[layer_idx] = ns_buf_k self._new_sum_value_buf[layer_idx] = ns_buf_v self._new_sum_len[layer_idx] = n_new self._new_sum_write_ptr[layer_idx] = n_new % scn # Chunk buffer for partial remainder chunk_buf_k = torch.empty(bsz, heads, C, head_dim, dtype=dtype, device=device) chunk_buf_v = torch.empty(bsz, heads, C, head_dim, dtype=dtype, device=device) if n_partial > 0: chunk_buf_k[:, :, :n_partial, :].copy_(text_k[:, :, n_complete_text:, :]) chunk_buf_v[:, :, :n_partial, :].copy_(text_v[:, :, n_complete_text:, :]) self._chunk_key_buf[layer_idx] = chunk_buf_k self._chunk_value_buf[layer_idx] = chunk_buf_v self._chunk_buf_len[layer_idx] = n_partial block = self.summary_chunk_size + self.summary_token_num for layer_idx in range(self.num_hidden_layers): self.cur_chunk_sizes[layer_idx] = self.cur_chunk_sizes[layer_idx] % block self._total_chunks[layer_idx] = ( self._old_summary_len[layer_idx] + self._new_sum_len[layer_idx] if not self.is_large_window[layer_idx] else (self.true_tokens[layer_idx] // self.summary_chunk_size) ) # ── Decode: text token update ── def update_text(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int): """Write a single text token KV during decode.""" if self.is_large_window[layer_idx]: cur = self._text_len[layer_idx] new_len = cur + 1 if new_len > self._capacity[layer_idx]: cap = max(new_len + self._APPEND_HEADROOM, self._capacity[layer_idx] * 2) old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx] bsz, heads, _, head_dim = old_k.shape new_k = torch.empty(bsz, heads, cap, head_dim, dtype=old_k.dtype, device=old_k.device) new_v = torch.empty(bsz, heads, cap, head_dim, dtype=old_v.dtype, device=old_v.device) new_k[:, :, :cur, :].copy_(old_k[:, :, :cur, :]) new_v[:, :, :cur, :].copy_(old_v[:, :, :cur, :]) self.key_cache[layer_idx] = new_k self.value_cache[layer_idx] = new_v self._capacity[layer_idx] = cap self.key_cache[layer_idx][:, :, cur:new_len, :].copy_(key_states) self.value_cache[layer_idx][:, :, cur:new_len, :].copy_(value_states) self._text_len[layer_idx] = new_len else: # Write only to key_cache mirror region (chunk_buf eliminated) ws = self.window_sizes[layer_idx] n_old = self._old_summary_len[layer_idx] pos = self._chunk_buf_len[layer_idx] mirror_pos = ws + n_old + pos self.key_cache[layer_idx][:, :, mirror_pos:mirror_pos+1, :].copy_(key_states) self.value_cache[layer_idx][:, :, mirror_pos:mirror_pos+1, :].copy_(value_states) self._chunk_buf_len[layer_idx] = pos + 1 self.cur_chunk_sizes[layer_idx] += 1 self.true_tokens[layer_idx] += 1 # ── Decode: summary token update ── def update_summary(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int): """Write summary token KV during decode (chunk boundary). Large window: skip. Small window (order matters β€” flush mirror before evict to avoid clobbering): 1. Flush mirror region β†’ ring 2. Evict oldest new_summary β†’ old_summary in key_cache (if full) 3. Write new summary β†’ new_summary_buf """ n_summary = key_states.shape[2] if self.is_large_window[layer_idx]: self.cur_chunk_sizes[layer_idx] += n_summary self._total_chunks[layer_idx] += n_summary return C = self.summary_chunk_size ws = self.window_sizes[layer_idx] scn = self.sliding_chunk_nums[layer_idx] cbl = self._chunk_buf_len[layer_idx] ptr = self._window_write_ptr[layer_idx] n_old = self._old_summary_len[layer_idx] # Step 1: Flush mirror region β†’ ring (must happen before evict touches mirror[0]) mirror_start = ws + n_old if ptr + cbl <= ws: self.key_cache[layer_idx][:, :, ptr:ptr + cbl, :].copy_( self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :]) self.value_cache[layer_idx][:, :, ptr:ptr + cbl, :].copy_( self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :]) else: first = ws - ptr self.key_cache[layer_idx][:, :, ptr:ws, :].copy_( self.key_cache[layer_idx][:, :, mirror_start:mirror_start + first, :]) self.value_cache[layer_idx][:, :, ptr:ws, :].copy_( self.value_cache[layer_idx][:, :, mirror_start:mirror_start + first, :]) rest = cbl - first self.key_cache[layer_idx][:, :, :rest, :].copy_( self.key_cache[layer_idx][:, :, mirror_start + first:mirror_start + cbl, :]) self.value_cache[layer_idx][:, :, :rest, :].copy_( self.value_cache[layer_idx][:, :, mirror_start + first:mirror_start + cbl, :]) self._window_write_ptr[layer_idx] = (ptr + cbl) % ws if self._n_valid_window[layer_idx] < ws: self._n_valid_window[layer_idx] = min(ws, self._n_valid_window[layer_idx] + cbl) self._chunk_buf_len[layer_idx] = 0 # Step 2: Evict oldest new_summary β†’ old_summary (now safe β€” mirror already flushed) if self._new_sum_len[layer_idx] >= scn: read_ptr = self._new_sum_write_ptr[layer_idx] old_dst = ws + n_old # == mirror_start, but mirror data is already in ring # Check capacity for old_summary growth needed = old_dst + 1 + C # +1 for new old_sum, +C for future chunk mirror if needed > self._capacity[layer_idx]: new_s_cap = max(self._old_summary_cap[layer_idx] * 2, n_old + self._SUMMARY_INIT_CAP) new_total = ws + new_s_cap + C old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx] bsz, heads, _, head_dim = old_k.shape nk = torch.empty(bsz, heads, new_total, head_dim, dtype=old_k.dtype, device=old_k.device) nv = torch.empty(bsz, heads, new_total, head_dim, dtype=old_v.dtype, device=old_v.device) copy_len = ws + n_old nk[:, :, :copy_len, :].copy_(old_k[:, :, :copy_len, :]) nv[:, :, :copy_len, :].copy_(old_v[:, :, :copy_len, :]) self.key_cache[layer_idx] = nk self.value_cache[layer_idx] = nv self._old_summary_cap[layer_idx] = new_s_cap self._capacity[layer_idx] = new_total self.key_cache[layer_idx][:, :, old_dst:old_dst+1, :].copy_( self._new_sum_key_buf[layer_idx][:, :, read_ptr:read_ptr+1, :]) self.value_cache[layer_idx][:, :, old_dst:old_dst+1, :].copy_( self._new_sum_value_buf[layer_idx][:, :, read_ptr:read_ptr+1, :]) self._old_summary_len[layer_idx] += 1 # Step 3: Write new summary to new_summary_buf (overwrite oldest slot) w_ptr = self._new_sum_write_ptr[layer_idx] self._new_sum_key_buf[layer_idx][:, :, w_ptr:w_ptr+1, :].copy_(key_states) self._new_sum_value_buf[layer_idx][:, :, w_ptr:w_ptr+1, :].copy_(value_states) self._new_sum_write_ptr[layer_idx] = (w_ptr + 1) % scn if self._new_sum_len[layer_idx] < scn: self._new_sum_len[layer_idx] += 1 self.cur_chunk_sizes[layer_idx] += n_summary self._total_chunks[layer_idx] += n_summary # ── Decode: get KV for attention ── def get_attention_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: """Get full KV for text token attention. Large window: buffer[:text_len] Small window (steady state): key_cache[:ws + n_old + cbl] β€” single slice, zero cat """ if self.is_large_window[layer_idx]: tl = self._text_len[layer_idx] return (self.key_cache[layer_idx][:, :, :tl, :], self.value_cache[layer_idx][:, :, :tl, :]) ws = self.window_sizes[layer_idx] nv = self._n_valid_window[layer_idx] cbl = self._chunk_buf_len[layer_idx] n_old = self._old_summary_len[layer_idx] # Steady state: ring full β†’ [ring(ws) | old_sums(n_old) | chunk_mirror(cbl)] contiguous if nv == ws: end = ws + n_old + cbl return (self.key_cache[layer_idx][:, :, :end, :], self.value_cache[layer_idx][:, :, :end, :]) # Warmup: ring not full, [nv:ws] is gap β†’ cat parts_k, parts_v = [], [] if nv > 0: parts_k.append(self.key_cache[layer_idx][:, :, :nv, :]) parts_v.append(self.value_cache[layer_idx][:, :, :nv, :]) if cbl > 0: mirror_start = ws + n_old parts_k.append(self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :]) parts_v.append(self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :]) if n_old > 0: parts_k.append(self.key_cache[layer_idx][:, :, ws:ws + n_old, :]) parts_v.append(self.value_cache[layer_idx][:, :, ws:ws + n_old, :]) if len(parts_k) == 1: return parts_k[0], parts_v[0] return torch.cat(parts_k, dim=2), torch.cat(parts_v, dim=2) def get_current_chunk_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: """Get KV of the current chunk's C text tokens for summary token attention.""" C = self.summary_chunk_size if self.is_large_window[layer_idx]: tl = self._text_len[layer_idx] return (self.key_cache[layer_idx][:, :, tl - C:tl, :], self.value_cache[layer_idx][:, :, tl - C:tl, :]) else: ws = self.window_sizes[layer_idx] n_old = self._old_summary_len[layer_idx] cbl = self._chunk_buf_len[layer_idx] mirror_start = ws + n_old return (self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :], self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :]) def reset_chunk_counter(self): """Reset chunk counters after a chunk boundary step completes.""" block = self.summary_chunk_size + self.summary_token_num for layer_idx in range(self.num_hidden_layers): if self.cur_chunk_sizes[layer_idx] >= block: self.cur_chunk_sizes[layer_idx] %= block class Qwen3MLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def _sdpa_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_output = F.scaled_dot_product_attention( query, key_states, value_states, attn_mask=None, dropout_p=dropout, is_causal=False, ) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None class Qwen3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Qwen3Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attn_output, attn_weights = _sdpa_attention_forward( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Qwen3SummaryAttention(Qwen3Attention): """ Summary-aware variant of Qwen3Attention: uses a sliding summary mask. """ def __init__(self, config: Qwen3Config, layer_idx: int): super().__init__(config, layer_idx) self.summary_chunk_size = getattr(self.config, "summary_chunk_size", 0) self.summary_token_num = getattr(self.config, "summary_token_num", 0) # Cache sliding_chunk_num to avoid eval() on every forward call val = getattr(config, "summary_sliding_chunk_num", 0) or 0 val = _parse_config_pattern(val) if isinstance(val, list): self._sliding_chunk_num = val[layer_idx] else: self._sliding_chunk_num = int(val) if config.summary_independent_parameters and config.mix_coeff > 0: self.q_proj_summary = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj_summary = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj_summary = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) def _get_sliding_chunk_num(self): return self._sliding_chunk_num def get_query_key_value_tensors(self, hidden_states): input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) return query_states, key_states, value_states def get_query_key_value_tensors_summary(self, hidden_states): input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_norm(self.q_proj_summary(hidden_states).view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj_summary(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj_summary(hidden_states).view(hidden_shape).transpose(1, 2) return query_states, key_states, value_states @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, summary_ctx: Optional[SummaryBatchContext] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] if hidden_states.size(0) != 1: raise ValueError("Summary sliding attention only supports batch size=1.") # Compute q/k/v for the full sequence once. if self.config.summary_independent_parameters: if summary_ctx is None: raise ValueError("summary_ctx is required when using summary_independent_parameters.") summary_mask = summary_ctx.summary_mask summary_pos = summary_mask[0] assert (summary_mask == summary_mask[0:1]).all() if self.config.mix_coeff == 0: # When mix_coeff=0, summary projections have no effect β€” skip clone + extra linear query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states) else: query, key, value = self.get_query_key_value_tensors(hidden_states) query_states = query.clone() key_states = key.clone() value_states = value.clone() hs_summary = hidden_states[:, summary_pos, :] if hs_summary.size(1) > 0: base_q_summary = query[:, :, summary_pos, :] base_k_summary = key[:, :, summary_pos, :] base_v_summary = value[:, :, summary_pos, :] q_s, k_s, v_s = self.get_query_key_value_tensors_summary(hs_summary) q_s = self.config.mix_coeff * q_s + (1.0 - self.config.mix_coeff) * base_q_summary k_s = self.config.mix_coeff * k_s + (1.0 - self.config.mix_coeff) * base_k_summary v_s = self.config.mix_coeff * v_s + (1.0 - self.config.mix_coeff) * base_v_summary query_states[:, :, summary_pos, :] = q_s key_states[:, :, summary_pos, :] = k_s value_states[:, :, summary_pos, :] = v_s else: query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_len = query_states.shape[2] is_prefill = past_key_values is None or not past_key_values._reorganized if is_prefill: # Prefill: use standard append and summary_attn_func if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} if summary_ctx is not None: cache_kwargs["summary_mask"] = summary_ctx.summary_mask key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) with torch.cuda.device(query_states.device): attn_output, attn_weights = summary_attn_func( query_states.transpose(1,2).contiguous(), key_states.transpose(1,2).contiguous(), value_states.transpose(1,2).contiguous(), self.summary_chunk_size, self.summary_token_num, self._get_sliding_chunk_num(), summary_pos=summary_ctx.summary_mask.squeeze() ) elif query_len == 1: # Single text token decode: write to cache, attend to full buffer past_key_values.update_text(key_states, value_states, self.layer_idx) k_full, v_full = past_key_values.get_attention_kv(self.layer_idx) attn_output, attn_weights = _sdpa_attention_forward( self, query_states, k_full, v_full, None, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=self.sliding_window, **kwargs, ) else: # Chunk boundary: query = [text_token, summary_token(s)] # Split into text (first token) and summary (remaining tokens) q_text = query_states[:, :, :1, :] q_summary = query_states[:, :, 1:, :] k_text = key_states[:, :, :1, :] v_text = value_states[:, :, :1, :] k_summary = key_states[:, :, 1:, :] v_summary = value_states[:, :, 1:, :] # 1. Write text token to cache, get full KV, run text attention past_key_values.update_text(k_text, v_text, self.layer_idx) k_full, v_full = past_key_values.get_attention_kv(self.layer_idx) text_out, _ = _sdpa_attention_forward( self, q_text, k_full, v_full, None, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=self.sliding_window, **kwargs, ) # 2. Summary attention: attend to current chunk's C text tokens + own KV (self-attention) # The original model includes the summary token's own KV in its attention # (causal within summary positions). With S=1, this is just self-attention. k_chunk, v_chunk = past_key_values.get_current_chunk_kv(self.layer_idx) k_chunk_with_self = torch.cat([k_chunk, k_summary], dim=2) v_chunk_with_self = torch.cat([v_chunk, v_summary], dim=2) summary_out, _ = _sdpa_attention_forward( self, q_summary, k_chunk_with_self, v_chunk_with_self, None, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=self.sliding_window, **kwargs, ) # 3. Write summary KV to cache past_key_values.update_summary(k_summary, v_summary, self.layer_idx) attn_output = torch.cat([text_out, summary_out], dim=2) attn_weights = None attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Qwen3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen3Config, layer_idx: int): super().__init__() self.config = config self.hidden_size = config.hidden_size # Use SummaryAttention if enabled in config if getattr(config, "use_summary_attention", False) is True and config.summary_layer_freq[layer_idx] == 1: self.self_attn = Qwen3SummaryAttention(config=config, layer_idx=layer_idx) elif getattr(config, "use_summary_attention", False) is False and config.summary_layer_freq[layer_idx] == 0: self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) else: raise ValueError(f'Check config.summary_layer_freq {config.summary_layer_freq} and config.use_summary_attention {config.use_summary_attention}') self.mlp = Qwen3MLP(config) self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if getattr(config, "summary_independent_attention_layernorm", False): self.input_layernorm_summary = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC summary_ctx: Optional[SummaryBatchContext] = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states if getattr(self.config, "summary_independent_attention_layernorm", False): summary_mask = summary_ctx.summary_mask assert (summary_mask == summary_mask[0:1]).all(), \ "summary_mask must be identical across batch" hidden_states = self.input_layernorm(hidden_states) if summary_mask.any(): hidden_summary = residual[:, summary_mask[0].to(residual.device), :] hidden_summary = self.input_layernorm_summary(hidden_summary) hidden_states[:, summary_mask[0], :] = hidden_summary else: hidden_states = self.input_layernorm(hidden_states) # Self Attention - pass summary_ctx if using summary attention attn_kwargs = { "hidden_states": hidden_states, "attention_mask": attention_mask, "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": use_cache, "cache_position": cache_position, "position_embeddings": position_embeddings, **kwargs, } if isinstance(self.self_attn, Qwen3SummaryAttention): attn_kwargs["summary_ctx"] = summary_ctx hidden_states, _ = self.self_attn(**attn_kwargs) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states @auto_docstring class Qwen3PreTrainedModel(PreTrainedModel): config: Qwen3Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen3DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Qwen3DecoderLayer, "attentions": Qwen3Attention, } class Qwen3RotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: Qwen3Config, device=None): super().__init__() self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_type = self.config.rope_parameters["rope_type"] rope_init_fn: Callable = self.compute_default_rope_parameters if self.rope_type != "default": rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = inv_freq @staticmethod def compute_default_rope_parameters( config: Optional[Qwen3Config] = None, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies according to the original RoPE implementation Args: config ([`~transformers.PreTrainedConfig`]): The model configuration. device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ base = config.rope_parameters["rope_theta"] dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads attention_factor = 1.0 # Unused in this type of RoPE # Compute the inverse frequencies inv_freq = 1.0 / ( base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) ) return inv_freq, attention_factor @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @auto_docstring class Qwen3Model(Qwen3PreTrainedModel): def __init__(self, config: Qwen3Config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size if not getattr(config, "summary_layer_freq", False): if config.use_summary_attention: config.summary_layer_freq = [1]*config.num_hidden_layers else: config.summary_layer_freq = [0]*config.num_hidden_layers Warning(f'Please set config.summary_layer_freq, temp set summary_layer_freq = {config.num_hidden_layers}') else: config.summary_layer_freq = _parse_config_pattern(config.summary_layer_freq) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen3RotaryEmbedding(config=config) self.gradient_checkpointing = False self.has_sliding_layers = "sliding_attention" in self.config.layer_types # Cache per-layer sliding_chunk_nums for KV cache eviction _sv = _parse_config_pattern(getattr(config, "summary_sliding_chunk_num", 0) or 0) if isinstance(_sv, list): self._sliding_chunk_nums = [int(v) for v in _sv] else: self._sliding_chunk_nums = [int(_sv)] * config.num_hidden_layers # Initialize weights and apply final processing self.post_init() def _expand_input_with_summary_tokens(self, input_ids): """Expand input_ids with summary tokens for prefill phase (vectorized). Returns: Tuple of (expanded_input_ids, position_ids, text_only_mask) """ summary_chunk = self.config.summary_chunk_size summary_num = self.config.summary_token_num summary_begin = self.config.summary_token_begin if summary_chunk == 0 or summary_num == 0: return input_ids, None, None batch_size, seq_len = input_ids.shape device = input_ids.device dtype = input_ids.dtype block = summary_chunk + summary_num # Number of full chunks and remainder n_full_chunks = seq_len // summary_chunk remainder = seq_len % summary_chunk has_remainder = remainder > 0 # Total expanded length: full_chunks * block + remainder expanded_len = n_full_chunks * block + (remainder if has_remainder else 0) # --- Build expanded_input_ids --- expanded_ids = torch.empty((batch_size, expanded_len), dtype=dtype, device=device) text_only_mask = torch.zeros((batch_size, expanded_len), dtype=torch.bool, device=device) # Compute text positions: for chunk i, text goes to [i*block, i*block+summary_chunk) # Summary positions: [i*block+summary_chunk, (i+1)*block) if n_full_chunks > 0: chunk_indices = torch.arange(n_full_chunks, device=device) # Text source positions in original input_ids text_src_offsets = (chunk_indices * summary_chunk).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0) # [n_full_chunks, summary_chunk] # Text dest positions in expanded text_dst_offsets = (chunk_indices * block).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0) # [n_full_chunks, summary_chunk] # Summary dest positions summary_dst_offsets = (chunk_indices * block + summary_chunk).unsqueeze(1) + torch.arange(summary_num, device=device).unsqueeze(0) # [n_full_chunks, summary_num] text_src_flat = text_src_offsets.reshape(-1) text_dst_flat = text_dst_offsets.reshape(-1) summary_dst_flat = summary_dst_offsets.reshape(-1) # Copy text tokens expanded_ids[:, text_dst_flat] = input_ids[:, text_src_flat] text_only_mask[:, text_dst_flat] = True # Fill summary tokens summary_ids_val = torch.arange(summary_num, device=device, dtype=dtype) + summary_begin expanded_ids[:, summary_dst_flat] = summary_ids_val.repeat(n_full_chunks).unsqueeze(0).expand(batch_size, -1) # Handle remainder (last partial chunk, no summary tokens) if has_remainder: rem_start_src = n_full_chunks * summary_chunk rem_start_dst = n_full_chunks * block rem_offsets = torch.arange(remainder, device=device) expanded_ids[:, rem_start_dst + rem_offsets] = input_ids[:, rem_start_src + rem_offsets] text_only_mask[:, rem_start_dst + rem_offsets] = True # --- Build position_ids --- position_ids = torch.empty((batch_size, expanded_len), dtype=torch.long, device=device) if n_full_chunks > 0: # Text position IDs if self.config.summary_chunk_position_ids_type == 'origin': text_pos = text_src_flat.unsqueeze(0).expand(batch_size, -1) elif self.config.summary_chunk_position_ids_type == 'inner_chunk': inner_pos = torch.arange(summary_chunk, device=device).repeat(n_full_chunks) text_pos = inner_pos.unsqueeze(0).expand(batch_size, -1) else: raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}') position_ids[:, text_dst_flat] = text_pos # Summary position IDs if self.config.summary_token_position_ids_type == 'zeros': position_ids[:, summary_dst_flat] = 0 elif self.config.summary_token_position_ids_type in ('last_chunk_slice_left', 'last_chunk_slice_right'): # Vectorized slice_ends computation for all chunks at once if self.config.summary_token_position_ids_type == 'last_chunk_slice_left': idx = torch.arange(0, summary_num, device=device, dtype=torch.long) else: idx = torch.arange(1, summary_num + 1, device=device, dtype=torch.long) # For each chunk i: prev_text_end = i * summary_chunk prev_ends = (chunk_indices * summary_chunk).unsqueeze(1) # [n_full_chunks, 1] slice_ends = prev_ends + (idx.unsqueeze(0) * summary_chunk) // summary_num - 1 # [n_full_chunks, summary_num] slice_ends = slice_ends.clamp(min=0) # Clamp per-chunk: min is prev_text_end for that chunk slice_ends = torch.max(slice_ends, prev_ends) position_ids[:, summary_dst_flat] = slice_ends.reshape(-1).unsqueeze(0).expand(batch_size, -1) else: raise ValueError(f'Unknown summary_token_position_ids_type: {self.config.summary_token_position_ids_type}') # Remainder position IDs if has_remainder: if self.config.summary_chunk_position_ids_type == 'origin': rem_pos = (rem_start_src + rem_offsets).unsqueeze(0).expand(batch_size, -1) elif self.config.summary_chunk_position_ids_type == 'inner_chunk': rem_pos = rem_offsets.unsqueeze(0).expand(batch_size, -1) else: raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}') position_ids[:, rem_start_dst + rem_offsets] = rem_pos return expanded_ids, position_ids, text_only_mask def _build_summary_context(self, input_ids, position_ids, is_prefill, use_cache): """Build summary context for attention layers.""" summary_chunk = self.config.summary_chunk_size summary_num = self.config.summary_token_num summary_begin = self.config.summary_token_begin if summary_chunk > 0 and summary_num > 0: return build_summary_sliding_context( input_ids=input_ids, position_ids=position_ids, summary_token_num=summary_num, summary_token_begin=summary_begin, ) return None def _filter_summary_tokens(self, hidden_states, text_only_mask, use_summary, is_decode): """Filter out summary tokens from output hidden states.""" if text_only_mask is not None: # Prefill: vectorized filtering using boolean mask batch_size, _, hidden_size = hidden_states.shape text_length = text_only_mask[0].sum().item() return hidden_states[text_only_mask.to(hidden_states.device)].reshape(batch_size, text_length, hidden_size) elif use_summary and is_decode and hidden_states.size(1) > 1: # Decode: if we have multiple tokens, only return the first (text token) return hidden_states[:, :1, :] return hidden_states @check_model_inputs() @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, summary_ctx: Optional[SummaryBatchContext] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") use_summary = getattr(self.config, "use_summary_attention", False) is_prefill = past_key_values is None or past_key_values.get_seq_length() == 0 # Prefill phase with summary attention: expand input_ids with summary tokens text_only_mask = None if use_summary and input_ids is not None and inputs_embeds is None and is_prefill: input_ids, position_ids, text_only_mask = self._expand_input_with_summary_tokens(input_ids) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) # Initialize cache if use_cache and past_key_values is None: if use_summary: past_key_values = Qwen3RingBufferCache( config=self.config, sliding_chunk_nums=self._sliding_chunk_nums) else: past_key_values = DynamicCache(config=self.config) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) # Build summary context if needed if use_summary and summary_ctx is None and input_ids is not None: summary_ctx = self._build_summary_context(input_ids, position_ids, is_prefill, use_cache) causal_mask_mapping = attention_mask if not isinstance(causal_mask_mapping, (dict, list)): if summary_ctx and summary_ctx.enabled: seq_len = inputs_embeds.shape[1] # During prefill, Qwen3SummaryAttention uses summary_attn_func # which does not need a dense mask. Skip expensive mask construction. # During decode, prepare_inputs_for_generation already computed # per-layer keep_indices and passed them as attention_mask (list). # If we reach here with a non-list, it means no mask is needed. causal_mask_mapping = None else: # Prepare mask arguments mask_kwargs = { "config": self.config, "input_embeds": inputs_embeds, "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, "position_ids": position_ids, } # Create the masks - disable causal mask when summary context is enabled causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), } # The sliding window alternating layers are not always activated depending on the config if self.has_sliding_layers: causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): if causal_mask_mapping is None: layer_mask = None elif isinstance(causal_mask_mapping, list): layer_mask = causal_mask_mapping[layer_idx] else: layer_mask = causal_mask_mapping[decoder_layer.attention_type] hidden_states = decoder_layer( hidden_states, attention_mask=layer_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, summary_ctx=summary_ctx, **kwargs, ) hidden_states = self.norm(hidden_states) # After prefill: reorganize cache to ring buffer layout if use_cache and use_summary and past_key_values is not None and is_prefill: if hasattr(past_key_values, 'reorganize_after_prefill') and summary_ctx is not None: past_key_values.reorganize_after_prefill(summary_ctx.summary_mask) # After chunk boundary decode: reset chunk counters if use_cache and use_summary and past_key_values is not None and not is_prefill: if hasattr(past_key_values, 'reset_chunk_counter'): past_key_values.reset_chunk_counter() # Filter out summary tokens from output hidden_states = self._filter_summary_tokens(hidden_states, text_only_mask, use_summary, past_key_values is not None and past_key_values.get_seq_length() > 0) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, ) @auto_docstring class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) self.model = Qwen3Model(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, summary_ctx: Optional[SummaryBatchContext] = None, **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: ```python >>> from transformers import AutoTokenizer, Qwen3ForCausalLM >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, summary_ctx=summary_ctx, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss if isinstance(logits_to_keep, int) and logits_to_keep == 0 and labels is None: # Inference: only need last token's logits to avoid OOM from [seq_len, vocab_size] logits = self.lm_head(hidden_states[:, -1:, :]) else: slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) truncate_n = getattr(self.config, "truncate_predict_nums", 151936) if truncate_n > 0: logits = logits[..., :truncate_n] loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.shape[-1], **kwargs) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def _build_summary_attention_mask_for_generation( self, *, input_ids: torch.LongTensor, past_key_values: Optional[Cache], attention_mask: Optional[torch.Tensor], ) -> Optional[torch.Tensor]: """Ring buffer cache handles attention internally β€” no mask needed for decode.""" if isinstance(past_key_values, Qwen3RingBufferCache): return None return attention_mask def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, cache_position: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, **kwargs, ): use_summary = getattr(self.config, "use_summary_attention", False) # If not using summary attention, use standard behavior if not use_summary: return super().prepare_inputs_for_generation( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, cache_position=cache_position, position_ids=position_ids, **kwargs, ) # For summary attention: handle cache-based input slicing summary_chunk_size = getattr(self.config, "summary_chunk_size", 0) summary_token_num = getattr(self.config, "summary_token_num", 0) summary_token_begin = getattr(self.config, "summary_token_begin", 0) # Prefill phase: pass full sequence, forward() will handle summary token insertion if past_key_values is None or past_key_values.get_seq_length() == 0: if cache_position is None: cache_position = torch.arange(0, input_ids.shape[1], device=input_ids.device) return { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, "past_key_values": past_key_values, "cache_position": cache_position, "use_cache": kwargs.get("use_cache"), } # Decode phase: only pass new tokens not in cache # Get current chunk size (number of text tokens in current chunk) cur_chunk = past_key_values.get_cur_chunk_size() if hasattr(past_key_values, "get_cur_chunk_size") else 0 true_token_num = past_key_values.get_true_token_num() # Only take the new tokens that haven't been processed if input_ids.shape[1] > 1: # Slice to get only new tokens new_token_count = input_ids.shape[1] - true_token_num assert new_token_count > 0, f'new_token_count={new_token_count} should be greater than 0' input_ids = input_ids[:, -new_token_count:] device = input_ids.device # Check if we need to insert summary tokens # If cur_chunk >= summary_chunk_size, we need to generate summary tokens if cur_chunk == summary_chunk_size - 1: # Insert summary tokens batch_size = input_ids.shape[0] summary_ids = ( torch.arange(summary_token_num, device=device, dtype=input_ids.dtype) + summary_token_begin ).unsqueeze(0).repeat(batch_size, 1) # Concatenate: [text_token, summary_tokens] input_ids = torch.cat([input_ids, summary_ids], dim=1) # Position IDs: text token uses cur_chunk, summary tokens use 0 if self.config.summary_chunk_position_ids_type == 'origin': text_pos = torch.full((batch_size, 1), past_key_values.get_true_token_num(), device=device, dtype=torch.long) elif self.config.summary_chunk_position_ids_type == 'inner_chunk': text_pos = torch.full((batch_size, 1), cur_chunk, device=device, dtype=torch.long) else: raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}') if self.config.summary_token_position_ids_type == 'zeros': summary_pos = torch.zeros((batch_size, summary_token_num), device=device, dtype=torch.long) elif self.config.summary_token_position_ids_type == 'last_chunk_slice_left': # η­‰εˆ†ζˆ summary_num 份,每δΈͺ summary token 取对应 slice ηš„ζœ«ε°Ύ prev_text_end = past_key_values.get_true_token_num()+1-summary_chunk_size cur_text_end = past_key_values.get_true_token_num()+1 chunk_len = cur_text_end - prev_text_end idx = torch.arange(0, summary_token_num, device=device, dtype=torch.long,) # ζ―δΈ€δ»½ηš„ζœ«ε°ΎοΌˆε…¨ε±€ positionοΌ‰ slice_ends = prev_text_end + (idx * chunk_len) // summary_token_num - 1 slice_ends = slice_ends.clamp(min=prev_text_end) summary_pos = slice_ends.to(dtype=torch.long, device=device).unsqueeze(0) elif self.config.summary_token_position_ids_type == 'last_chunk_slice_right': # η­‰εˆ†ζˆ summary_num 份,每δΈͺ summary token 取对应 slice ηš„ζœ«ε°Ύ prev_text_end = past_key_values.get_true_token_num()+1-summary_chunk_size cur_text_end = past_key_values.get_true_token_num()+1 chunk_len = cur_text_end - prev_text_end idx = torch.arange(1, summary_token_num + 1, device=device, dtype=torch.long,) # ζ―δΈ€δ»½ηš„ζœ«ε°ΎοΌˆε…¨ε±€ positionοΌ‰ slice_ends = prev_text_end + (idx * chunk_len) // summary_token_num - 1 slice_ends = slice_ends.clamp(min=prev_text_end) summary_pos = slice_ends.to(dtype=torch.long, device=device).unsqueeze(0) else: raise ValueError('') position_ids = torch.cat([text_pos, summary_pos], dim=1) else: # Normal decode: just the new text token with position = cur_chunk if position_ids is None: batch_size = input_ids.shape[0] if self.config.summary_chunk_position_ids_type == 'origin': position_ids = torch.full((batch_size, input_ids.shape[1]), past_key_values.get_true_token_num(), device=input_ids.device, dtype=torch.long) elif self.config.summary_chunk_position_ids_type == 'inner_chunk': position_ids = torch.full((batch_size, input_ids.shape[1]), cur_chunk, device=input_ids.device, dtype=torch.long) else: raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}') return { "input_ids": input_ids, "attention_mask": self._build_summary_attention_mask_for_generation( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, ), "position_ids": position_ids, "past_key_values": past_key_values, "cache_position": cache_position, "use_cache": kwargs.get("use_cache"), } class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel): pass class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel): pass class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel): base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model` __all__ = [ "Qwen3ForCausalLM", "Qwen3ForQuestionAnswering", "Qwen3PreTrainedModel", "Qwen3Model", "Qwen3ForSequenceClassification", "Qwen3ForTokenClassification", "Qwen3RingBufferCache", "Qwen3SummaryAttention", "SummaryBatchContext", "build_summary_context", "build_summary_sliding_context", ]