| | """Qwen3 with scaled sequence length via embedding replication. |
| | |
| | Extends Qwen3Model/Qwen3ForCausalLM with scale_seq_times additional |
| | embedding tables. During forward, the original token sequence of length L |
| | is expanded to (1 + scale_seq_times) * L via interleaved multi-stream |
| | embedding, then processed by the standard Qwen3 transformer body. |
| | |
| | Architecture overview (n = 1 + scale_seq_times): |
| | - n Embedding tables: E_0 (original), E_1, ..., E_{n-1} (new) |
| | - Interleaved layout: [E_0(t1), E_1(t1), ..., E_0(t2), E_1(t2), ...] |
| | - RoPE positions: 0, 1, 2, ..., n*L - 1 (continuous) |
| | - Standard causal attention over all n*L positions |
| | - Contraction: only the last stream's hidden_state per token goes through |
| | lm_head (the stream with the richest context), matching v4dev behavior. |
| | |
| | See: Scale_SeqLen_via_Embedding_Replication.md |
| | """ |
| |
|
| | from typing import Optional, Tuple, Union |
| |
|
| | import torch |
| | from torch import nn |
| | from transformers import Qwen3ForCausalLM, Qwen3Model |
| | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| | from transformers.processing_utils import Unpack |
| | from transformers.utils import TransformersKwargs, can_return_tuple |
| |
|
| | from .configuration_qwen3_scale_seq import Qwen3ScaleSeqConfig |
| |
|
| |
|
class Qwen3ScaleSeqModel(Qwen3Model):
    """Qwen3Model extended with multi-stream embedding for sequence scaling.

    When ``config.scale_seq_times > 0``, each input token is embedded by
    ``1 + scale_seq_times`` independent embedding tables, and the resulting
    vectors are interleaved along the sequence axis before the transformer
    body runs. All other behavior delegates to the parent ``Qwen3Model``.
    """

    config_class = Qwen3ScaleSeqConfig

    def __init__(self, config: Qwen3ScaleSeqConfig):
        super().__init__(config)
        # Number of ADDITIONAL streams; total streams per token is
        # scale_seq_times + 1 (the main embed_tokens table plus these).
        # getattr with default 0 keeps plain Qwen3 configs working unchanged.
        self.scale_seq_times = getattr(config, "scale_seq_times", 0)

        if self.scale_seq_times > 0:
            # One extra embedding table per additional stream, shaped like the
            # parent's embed_tokens (vocab_size x hidden_size, same padding_idx).
            self.scale_seq_embed_tokens_list = nn.ModuleList(
                [
                    nn.Embedding(
                        config.vocab_size,
                        config.hidden_size,
                        self.padding_idx,
                    )
                    for _ in range(self.scale_seq_times)
                ]
            )

        # Re-run weight init / tying so the new embedding tables are covered.
        self.post_init()

    def _expand_scale_seq(
        self,
        input_ids: torch.LongTensor,
        hidden_states: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Expand hidden_states from (B, T, D) to (B, T * scale, D).

        Layout per original token i:
            [main_emb_i, scale_seq_1_emb_i, ..., scale_seq_N_emb_i]

        Args:
            input_ids: (batch, seq_len) original token ids.
            hidden_states: (batch, seq_len, hidden) main embedding output.

        Returns:
            Expanded tensor of shape (batch, seq_len * scale, hidden),
            where scale == self.scale_seq_times + 1.
        """
        device = hidden_states.device
        B, T, D = hidden_states.shape

        # Start with the main stream; each part gets a new stream axis at dim 2.
        parts = [hidden_states.unsqueeze(2)]

        for s in range(self.scale_seq_times):
            emb_module = self.scale_seq_embed_tokens_list[s]
            # The extra tables may live on a different device (e.g. under model
            # parallelism) — look up on the table's device, then move back.
            hs_s = emb_module(input_ids.to(emb_module.weight.device)).to(device)
            parts.append(hs_s.unsqueeze(2))

        # (B, T, scale, D) -> (B, T * scale, D): reshape flattens the stream
        # axis into the sequence axis, producing the interleaved layout above.
        expanded = torch.cat(parts, dim=2)
        return expanded.reshape(B, T * (self.scale_seq_times + 1), D)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the Qwen3 body, expanding the sequence first when scaling is on.

        The expansion branch triggers only when scaling is configured, token
        ids were given, and the caller did not already supply inputs_embeds
        (callers passing embeds bypass expansion entirely).
        """
        if (
            self.scale_seq_times > 0
            and input_ids is not None
            and inputs_embeds is None
        ):
            scale = self.scale_seq_times + 1

            # Embed with the main table, then interleave the extra streams.
            inputs_embeds = self.embed_tokens(input_ids)
            inputs_embeds = self._expand_scale_seq(input_ids, inputs_embeds)

            B = inputs_embeds.shape[0]
            T_expanded = inputs_embeds.shape[1]

            # Rebuild cache_position / position_ids in EXPANDED coordinates,
            # overriding any caller-supplied values (which would be in original
            # token coordinates). NOTE(review): assumes get_seq_length() counts
            # expanded positions, which holds because the cache only ever sees
            # expanded sequences from this model — confirm for custom caches.
            past_seen_tokens = (
                past_key_values.get_seq_length()
                if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + T_expanded,
                device=inputs_embeds.device,
            )
            position_ids = cache_position.unsqueeze(0).expand(B, -1)

            # Each original mask entry covers all `scale` streams of its token.
            if attention_mask is not None:
                attention_mask = attention_mask.repeat_interleave(scale, dim=1)

            # Embeds are already computed; prevent the parent from re-embedding.
            input_ids = None

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
| |
|
| |
|
class Qwen3ScaleSeqForCausalLM(Qwen3ForCausalLM):
    """Qwen3ForCausalLM with multi-stream embedding for sequence scaling.

    Contraction: after the transformer body produces (B, T*scale, D),
    select only the last stream per token (the one with richest context)
    before applying lm_head, producing (B, T, vocab_size). Labels therefore
    stay in original (unexpanded) token coordinates.
    """

    config_class = Qwen3ScaleSeqConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Qwen3ScaleSeqConfig):
        super().__init__(config)
        # Replace the stock Qwen3Model body with the sequence-scaling variant.
        self.model = Qwen3ScaleSeqModel(config)
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute LM logits (and optionally loss) over contracted positions.

        Args:
            input_ids: (batch, seq_len) token ids in ORIGINAL coordinates.
            labels: (batch, seq_len) targets in original coordinates; loss is
                computed only when given.
            logits_to_keep: as in Qwen3ForCausalLM — 0 keeps all positions,
                an int keeps the last N contracted positions, a tensor indexes.

        Returns:
            CausalLMOutputWithPast with logits of shape
            (batch, kept_positions, vocab_size).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # BUGFIX: resolve use_cache against the config like the flags above.
        # Previously a caller-default use_cache=None was falsy, so the returned
        # past_key_values was silently dropped and incremental decoding broke.
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]

        # Contract (B, T*scale, D) -> (B, T, D): keep only the LAST stream of
        # each token group, i.e. indices scale-1, 2*scale-1, ...
        if self.model.scale_seq_times > 0:
            scale = self.model.scale_seq_times + 1
            hidden_states = hidden_states[:, scale - 1::scale, :]

        # slice(-0, None) == slice(0, None), so logits_to_keep == 0 keeps all.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values if use_cache else None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|
| |
|
| | __all__ = ["Qwen3ScaleSeqModel", "Qwen3ScaleSeqForCausalLM"] |
| |
|