"""Qwen3 with scaled sequence length via embedding replication.

Extends Qwen3Model/Qwen3ForCausalLM with ``scale_seq_times`` additional
embedding tables. During forward, the original token sequence of length L is
expanded to ``(1 + scale_seq_times) * L`` via interleaved multi-stream
embedding, then processed by the standard Qwen3 transformer body.

Architecture overview (n = 1 + scale_seq_times):

- n embedding tables: E_0 (original), E_1, ..., E_{n-1} (new)
- Interleaved layout: [E_0(t1), E_1(t1), ..., E_0(t2), E_1(t2), ...]
- RoPE positions: 0, 1, 2, ..., n*L - 1 (continuous)
- Standard causal attention over all n*L positions
- Contraction: only the last stream's hidden_state per token goes through
  lm_head (the stream with the richest context), matching v4dev behavior.

See: Scale_SeqLen_via_Embedding_Replication.md
"""

from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import Qwen3ForCausalLM, Qwen3Model
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, can_return_tuple

from .configuration_qwen3_scale_seq import Qwen3ScaleSeqConfig


class Qwen3ScaleSeqModel(Qwen3Model):
    """Qwen3Model extended with multi-stream embedding for sequence scaling.

    When ``config.scale_seq_times > 0``, the model owns ``scale_seq_times``
    extra embedding tables in addition to the inherited ``embed_tokens``.
    ``forward`` interleaves all streams per token before running the
    standard transformer body.
    """

    config_class = Qwen3ScaleSeqConfig

    def __init__(self, config: Qwen3ScaleSeqConfig):
        super().__init__(config)
        # Number of ADDITIONAL embedding streams; 0 disables scaling entirely
        # and the model behaves exactly like a vanilla Qwen3Model.
        self.scale_seq_times: int = getattr(config, "scale_seq_times", 0)
        if self.scale_seq_times > 0:
            # One extra vocab-sized embedding table per additional stream.
            # self.padding_idx is set by the Qwen3Model base __init__.
            self.scale_seq_embed_tokens_list = nn.ModuleList(
                [
                    nn.Embedding(
                        config.vocab_size,
                        config.hidden_size,
                        self.padding_idx,
                    )
                    for _ in range(self.scale_seq_times)
                ]
            )
        self.post_init()

    def _expand_scale_seq(
        self,
        input_ids: torch.LongTensor,
        hidden_states: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Expand hidden_states from (B, T, D) to (B, T * scale, D).

        Layout per original token i:
            [main_emb_i, scale_seq_1_emb_i, ..., scale_seq_N_emb_i]

        Args:
            input_ids: (batch, seq_len) original token ids.
            hidden_states: (batch, seq_len, hidden) main embedding output.

        Returns:
            Expanded tensor of shape (batch, seq_len * scale, hidden),
            where scale == scale_seq_times + 1.
        """
        device = hidden_states.device
        B, T, D = hidden_states.shape
        # (B, T, D) -> (B, T, 1, D); stream 0 is the main embedding.
        parts = [hidden_states.unsqueeze(2)]
        for s in range(self.scale_seq_times):
            emb_module = self.scale_seq_embed_tokens_list[s]
            # The extra tables may live on a different device under model
            # parallelism — look up on their device, then move back.
            hs_s = emb_module(input_ids.to(emb_module.weight.device)).to(device)
            parts.append(hs_s.unsqueeze(2))  # (B, T, 1, D)
        # (B, T, scale, D) -> (B, T * scale, D): streams interleave per token.
        expanded = torch.cat(parts, dim=2)
        return expanded.reshape(B, T * (self.scale_seq_times + 1), D)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the transformer body, expanding the sequence first when enabled.

        Expansion only happens when ``scale_seq_times > 0`` AND ``input_ids``
        is given without ``inputs_embeds``; a caller that supplies
        ``inputs_embeds`` directly bypasses the multi-stream expansion.

        Returns hidden states in the EXPANDED space (B, T*scale, D); the
        contraction back to logical token space happens in the CausalLM
        wrapper, not here.
        """
        if (
            self.scale_seq_times > 0
            and input_ids is not None
            and inputs_embeds is None
        ):
            scale = self.scale_seq_times + 1
            # Compute main embedding, then expand with scale_seq streams.
            inputs_embeds = self.embed_tokens(input_ids)
            inputs_embeds = self._expand_scale_seq(input_ids, inputs_embeds)
            B = inputs_embeds.shape[0]
            T_expanded = inputs_embeds.shape[1]
            # Recompute cache_position and position_ids in expanded space so
            # RoPE positions run 0..n*L-1 continuously across streams.
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + T_expanded,
                device=inputs_embeds.device,
            )
            position_ids = cache_position.unsqueeze(0).expand(B, -1)
            # Expand attention_mask to match expanded sequence length: each
            # logical token's mask bit is repeated once per stream. With a
            # cache, past tokens were expanded the same way, so repeating the
            # full logical-space mask stays consistent with the cache length.
            if attention_mask is not None:
                attention_mask = attention_mask.repeat_interleave(scale, dim=1)
            input_ids = None  # avoid double embedding lookup in super().forward()
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )


class Qwen3ScaleSeqForCausalLM(Qwen3ForCausalLM):
    """Qwen3ForCausalLM with multi-stream embedding for sequence scaling.

    Contraction: after the transformer body produces (B, T*scale, D), select
    only the last stream per token (the one with richest context) before
    applying lm_head, producing (B, T, vocab_size).
    """

    config_class = Qwen3ScaleSeqConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Qwen3ScaleSeqConfig):
        super().__init__(config)
        # Replace the inner model with our scaled version.
        self.model = Qwen3ScaleSeqModel(config)
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Causal-LM forward with scale_seq contraction before lm_head.

        Args:
            labels: (batch, seq_len) targets in LOGICAL token space; the
                contraction restores that length before the loss is computed.
            logits_to_keep: as in Qwen3ForCausalLM — int N keeps the last N
                logical positions, a tensor selects explicit indices.

        Returns:
            CausalLMOutputWithPast with logits of shape
            (batch, kept_positions, vocab_size). NOTE(review):
            ``hidden_states``/``attentions``, when requested, are returned in
            the EXPANDED (B, T*scale, ...) space, uncontracted.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # BUGFIX: resolve use_cache from config before it is consulted below.
        # Previously a None use_cache (the default) silently dropped the
        # KV cache from the output even when config.use_cache was True.
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs[0]
        # ---- scale_seq contraction ----
        # Contract expanded hidden_states (B, T*scale, D) back to logical
        # token space (B, T, D) by selecting the last stream per token group
        # (the stream with the richest context), matching v4dev behavior.
        # NOTE(review): this assumes the body actually expanded the sequence;
        # a caller passing pre-built inputs_embeds skips expansion but still
        # hits this contraction — confirm that path is unused.
        if self.model.scale_seq_times > 0:
            scale = self.model.scale_seq_times + 1
            hidden_states = hidden_states[:, scale - 1::scale, :]
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            # loss_function performs the causal shift internally; logits and
            # labels are already aligned in logical token space here.
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values if use_cache else None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["Qwen3ScaleSeqModel", "Qwen3ScaleSeqForCausalLM"]