"""Qwen3 with scaled sequence length via embedding replication.
Extends Qwen3Model/Qwen3ForCausalLM with scale_seq_times additional
embedding tables. During forward, the original token sequence of length L
is expanded to (1 + scale_seq_times) * L via interleaved multi-stream
embedding, then processed by the standard Qwen3 transformer body.
Architecture overview (n = 1 + scale_seq_times):
- n Embedding tables: E_0 (original), E_1, ..., E_{n-1} (new)
- Interleaved layout: [E_0(t1), E_1(t1), ..., E_0(t2), E_1(t2), ...]
- RoPE positions: 0, 1, 2, ..., n*L - 1 (continuous)
- Standard causal attention over all n*L positions
- Contraction: only the last stream's hidden_state per token goes through
lm_head (the stream with the richest context), matching v4dev behavior.
See: Scale_SeqLen_via_Embedding_Replication.md
"""
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers import Qwen3ForCausalLM, Qwen3Model
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, can_return_tuple
from .configuration_qwen3_scale_seq import Qwen3ScaleSeqConfig
class Qwen3ScaleSeqModel(Qwen3Model):
    """Qwen3Model extended with multi-stream embedding for sequence scaling.

    For ``scale_seq_times = k`` the model owns ``k`` extra embedding tables
    in addition to the base ``embed_tokens``; each input token contributes
    ``k + 1`` interleaved embedding vectors to the transformer body.
    """

    config_class = Qwen3ScaleSeqConfig

    def __init__(self, config: Qwen3ScaleSeqConfig):
        super().__init__(config)
        self.scale_seq_times = getattr(config, "scale_seq_times", 0)
        if self.scale_seq_times > 0:
            # One additional embedding table per extra stream; the base
            # class already provides the main table (self.embed_tokens).
            self.scale_seq_embed_tokens_list = nn.ModuleList(
                [
                    nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
                    for _ in range(self.scale_seq_times)
                ]
            )
        self.post_init()

    def _expand_scale_seq(
        self,
        input_ids: torch.LongTensor,
        hidden_states: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Expand hidden_states from (B, T, D) to (B, T * scale, D).

        Layout per original token i:
            [main_emb_i, scale_seq_1_emb_i, ..., scale_seq_N_emb_i]

        Args:
            input_ids: (batch, seq_len) original token ids.
            hidden_states: (batch, seq_len, hidden) main embedding output.

        Returns:
            Expanded tensor of shape (batch, seq_len * scale, hidden).
        """
        target_device = hidden_states.device
        batch, seq_len, hidden = hidden_states.shape
        streams = [hidden_states]
        for table in self.scale_seq_embed_tokens_list:
            # Tables may live on different devices (model-parallel setups):
            # look up on the table's device, then gather on the main one.
            ids_local = input_ids.to(table.weight.device)
            streams.append(table(ids_local).to(target_device))
        # Stack to (B, T, n_streams, D); flattening the middle two axes
        # yields the interleaved per-token layout.
        stacked = torch.stack(streams, dim=2)
        n_streams = self.scale_seq_times + 1
        return stacked.reshape(batch, seq_len * n_streams, hidden)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Expand input_ids into the interleaved multi-stream sequence, then
        delegate to the stock Qwen3Model forward.

        Expansion only happens when scaling is enabled AND the caller passed
        token ids rather than precomputed embeddings.
        """
        should_expand = (
            self.scale_seq_times > 0
            and input_ids is not None
            and inputs_embeds is None
        )
        if should_expand:
            n_streams = self.scale_seq_times + 1
            # Main-table embedding, then interleave the extra streams.
            inputs_embeds = self._expand_scale_seq(
                input_ids, self.embed_tokens(input_ids)
            )
            batch, expanded_len = inputs_embeds.shape[:2]
            # cache_position / position_ids live in the *expanded* space, so
            # they must be rebuilt after expansion.
            seen = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(
                seen, seen + expanded_len, device=inputs_embeds.device
            )
            position_ids = cache_position.unsqueeze(0).expand(batch, -1)
            if attention_mask is not None:
                # Stretch the per-token mask to cover all streams of a token.
                # NOTE(review): assumes a 2-D (batch, seq_len) mask — confirm
                # callers never pass a pre-built 4-D mask here.
                attention_mask = attention_mask.repeat_interleave(n_streams, dim=1)
            input_ids = None  # embeddings already computed; avoid re-lookup
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
class Qwen3ScaleSeqForCausalLM(Qwen3ForCausalLM):
    """Qwen3ForCausalLM with multi-stream embedding for sequence scaling.

    Contraction: after the transformer body produces (B, T*scale, D),
    select only the last stream per token (the one with richest context)
    before applying lm_head, producing (B, T, vocab_size).
    """

    config_class = Qwen3ScaleSeqConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Qwen3ScaleSeqConfig):
        super().__init__(config)
        # Replace the inner model with our scaled version.
        self.model = Qwen3ScaleSeqModel(config)
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the scaled model, contract to logical token space, apply lm_head.

        Args:
            input_ids: (batch, seq_len) token ids in *logical* (unexpanded)
                space; expansion happens inside the inner model.
            labels: optional (batch, seq_len) targets for the LM loss,
                aligned with the contracted (logical) logits.
            logits_to_keep: keep only the last N positions (int) or an
                explicit index tensor, following the HF convention.
            Remaining arguments follow the standard Qwen3ForCausalLM contract.

        Returns:
            CausalLMOutputWithPast whose logits have shape
            (batch, T, vocab_size) in logical token space.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # BUGFIX: resolve use_cache against the config like the other flags.
        # Previously a caller leaving use_cache=None (with config.use_cache
        # True) had the cache built by the inner model but silently dropped
        # from the returned output via `... if use_cache else None`.
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # ---- scale_seq contraction ----
        # Contract expanded hidden_states (B, T*scale, D) back to logical
        # token space (B, T, D) by selecting the last stream per token group
        # (the stream with the richest context), matching v4dev behavior.
        if self.model.scale_seq_times > 0:
            scale = self.model.scale_seq_times + 1
            hidden_states = hidden_states[:, scale - 1::scale, :]

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values if use_cache else None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Explicit public API of this module (used by `from module import *` and docs).
__all__ = ["Qwen3ScaleSeqModel", "Qwen3ScaleSeqForCausalLM"]