from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

import torch
from torch import nn

from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.models.qwen3.modeling_qwen3 import Qwen3Model, Qwen3PreTrainedModel
from transformers.processing_utils import Unpack

# Newer transformers releases replace `LossKwargs` with `TransformersKwargs`;
# support both so the module imports across versions.
try:
    from transformers.utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, logging
except ImportError:
    from transformers.utils import TransformersKwargs, ModelOutput, auto_docstring, can_return_tuple, logging

from .configuration_qwen3_ts import Qwen3TSConfig

logger = logging.get_logger(__name__)
class TimeSeriesEmbedding(nn.Module):
    """Patch-based MLP encoder that maps raw timeseries into LLM-sized hidden states."""

    def __init__(self, config):
        super().__init__()
        self.patch_size = config['patch_size']
        self.num_layers = config['num_layers']
        self.hidden_size = config['hidden_size']
        self.num_features = config['num_features']
        self.max_sequence_length = config['max_sequence_length']
        self.use_position_embedding = config.get('use_position_embedding', False)
        self.use_position_idx = config.get('use_position_idx', False)
        self.use_layer_norm = config.get('use_layer_norm', False)
        self.embedding_dim = config.get('embedding_dim', 16)

        if self.use_position_embedding:
            # Reserve one extra embedding row as the padding position.
            self.position_embedding = nn.Embedding(self.max_sequence_length + 1, self.embedding_dim)
            self.padding_idx = self.max_sequence_length
            input_size = 1 * self.patch_size + self.embedding_dim * self.patch_size
        elif self.use_position_idx:
            # Each step contributes (value, normalized position).
            input_size = 2 * self.patch_size
        else:
            input_size = 1 * self.patch_size

        # MLP: (num_layers - 1) hidden layers with GELU, then a final projection.
        layers = []
        for _ in range(self.num_layers - 1):
            layers.append(nn.Linear(input_size, self.hidden_size))
            layers.append(nn.GELU())
            input_size = self.hidden_size

        layers.append(nn.Linear(input_size, self.hidden_size))
        if self.use_layer_norm:
            layers.append(nn.LayerNorm(self.hidden_size))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        batch_size = x.size(0)
        x = x.reshape(batch_size, -1, self.num_features)

        # The last feature channel is a validity mask; count valid steps per sample.
        mask = x[:, :, -1].long()
        valid_lengths = mask.sum(dim=1).long()
        patch_cnt = (valid_lengths + self.patch_size - 1) // self.patch_size

        patches_list = []
        all_position_indices = []
        patch_info_list = []

        for i in range(batch_size):
            vl = valid_lengths[i].item()
            pc = patch_cnt[i].item()
            if pc == 0:
                continue

            # Keep only the value channel of the valid prefix.
            xi = x[i, :vl, :1]
            total_padded_length = pc * self.patch_size
            padding_length = total_padded_length - vl

            position_indices = torch.arange(vl, device=x.device)

            if padding_length > 0:
                # Fill the tail of the last patch by repeating the final value.
                last_value = xi[-1:, :]
                padding = last_value.repeat(padding_length, 1)
                xi = torch.cat([xi, padding], dim=0)

                if self.use_position_embedding:
                    # `padding_idx` only exists in position-embedding mode, so
                    # only pad position indices there (the original code hit an
                    # AttributeError in the other modes whenever padding occurred).
                    padding_positions = torch.full((padding_length,), self.padding_idx, device=x.device)
                    position_indices = torch.cat([position_indices, padding_positions], dim=0)

            # Group the flat sequence into (num_patches, patch_size).
            xi = xi.reshape(pc, self.patch_size)

            if self.use_position_embedding:
                position_indices = position_indices.reshape(pc, self.patch_size)
                # Defer the embedding lookup so all samples share one batched call.
                all_position_indices.append(position_indices)
                patch_info_list.append({
                    'xi': xi,
                    'pc': pc,
                    'sample_idx': i,
                })
            elif self.use_position_idx:
                # Append a normalized position channel to each step.
                pos_indices = torch.arange(vl, device=x.device).unsqueeze(1)
                pos_indices = pos_indices / max(1, valid_lengths.max().item() - 1)
                if padding_length > 0:
                    # Mark padded steps with -1.
                    padding_indices = torch.full((padding_length, 1), -1, device=x.device)
                    pos_indices = torch.cat([pos_indices, padding_indices], dim=0)

                xi_combined = torch.cat([xi.reshape(-1, 1), pos_indices], dim=1)
                patch_input = xi_combined.reshape(pc, self.patch_size * 2)
                patches_list.append(patch_input)
            else:
                patches_list.append(xi)

        # Run all deferred position-embedding lookups in a single batched call.
        if self.use_position_embedding and all_position_indices:
            batch_position_indices = torch.cat(all_position_indices, dim=0)
            batch_pos_emb = self.position_embedding(batch_position_indices)

            emb_start_idx = 0
            for patch_info in patch_info_list:
                xi = patch_info['xi']
                pc = patch_info['pc']

                pos_emb = batch_pos_emb[emb_start_idx:emb_start_idx + pc]
                emb_start_idx += pc

                # Concatenate flattened values and position embeddings per patch.
                xi = xi.unsqueeze(-1)
                patch_input = torch.cat([
                    xi.flatten(1),
                    pos_emb.flatten(1),
                ], dim=1)
                patches_list.append(patch_input)

        # Encode every patch from the whole batch in one MLP pass.
        if patches_list:
            x_patches = torch.cat(patches_list, dim=0)
            x = self.mlp(x_patches)
        else:
            x = torch.empty(0, self.hidden_size, device=x.device)

        return x, patch_cnt
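# Hedged usage sketch for the encoder above (illustrative only; the config
# values are assumptions, not a released configuration). Feature 0 carries the
# values and the last feature is the validity mask, matching forward() above.
def _example_timeseries_embedding():  # pragma: no cover - illustrative only
    cfg = {
        "patch_size": 4,
        "num_layers": 2,
        "hidden_size": 8,
        "num_features": 2,
        "max_sequence_length": 16,
    }
    encoder = TimeSeriesEmbedding(cfg)

    # One series with 6 valid steps, zero-padded to 8; the mask marks validity.
    values = torch.arange(8, dtype=torch.float32)
    mask = torch.tensor([1, 1, 1, 1, 1, 1, 0, 0], dtype=torch.float32)
    x = torch.stack([values, mask], dim=-1).unsqueeze(0)  # (1, 8, 2)

    embeddings, patch_cnt = encoder(x)
    # ceil(6 / 4) = 2 patches, each projected to hidden_size.
    assert embeddings.shape == (2, cfg["hidden_size"])
    assert patch_cnt.tolist() == [2]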
@dataclass
class Qwen3TSCausalLMOutputWithPast(CausalLMOutputWithPast):
    """
    Output type of Qwen3TSForCausalLM that includes additional fields for timeseries processing.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            The attention mask used in the forward pass, potentially expanded to accommodate timeseries patches.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss.
        new_token_positions (`torch.LongTensor` of shape `(batch_size, num_new_tokens)`, *optional*):
            Positions where new tokens (not from timeseries) are located in the expanded sequence.
    """

    attention_mask: Optional[torch.FloatTensor] = None
    labels: Optional[torch.LongTensor] = None
    new_token_positions: Optional[torch.LongTensor] = None
# Pick whichever kwargs base class this transformers version provides
# (see the version-dependent import at the top of the module).
try:
    LossKwargs
    _BaseKwargs = LossKwargs
except NameError:
    _BaseKwargs = TransformersKwargs


class KwargsForCausalLM(FlashAttentionKwargs, _BaseKwargs): ...
class Qwen3TSGenerationMixin(GenerationMixin):
    """
    Generation mixin for Qwen3 models with timeseries support.

    This mixin handles the special case where timeseries embeddings expand the sequence length
    during the first forward pass, requiring special attention mask management.
    """

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        timeseries: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Prepare inputs for generation with timeseries support.

        Timeseries are only processed during the first forward pass. In subsequent
        generation steps, they are already embedded in the past_key_values.
        """
        has_ts = timeseries is not None and len(timeseries) > 0

        if has_ts and past_key_values is not None:
            # `Cache.seen_tokens` is deprecated; `get_seq_length()` is the supported accessor.
            if isinstance(past_key_values, Cache):
                past_length = past_key_values.get_seq_length()
            else:
                past_length = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0

            if past_length > 0:
                # The cache already contains the timeseries patches; only feed the newest token.
                input_ids = input_ids[:, -1:]
                timeseries = None
                has_ts = False

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        model_inputs["timeseries"] = timeseries
        return model_inputs

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        """
        Update model keyword arguments for generation, handling the attention mask from outputs.

        This is necessary because timeseries processing can expand the sequence length,
        and we need to use the expanded attention_mask from the model outputs.
        """
        if getattr(outputs, "attention_mask", None) is not None:
            model_kwargs["attention_mask"] = outputs.attention_mask

        return super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            num_new_tokens=num_new_tokens,
        )

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        timeseries: Optional[torch.FloatTensor] = None,
        generation_config=None,
        **kwargs,
    ):
        """
        Generate sequences with timeseries support.

        Args:
            inputs: Input token ids.
            timeseries: Optional timeseries data to be processed in the first forward pass.
            generation_config: Generation configuration.
            **kwargs: Additional keyword arguments for generation.
        """
        if timeseries is not None:
            kwargs["timeseries"] = timeseries

        return super().generate(
            inputs=inputs,
            generation_config=generation_config,
            **kwargs,
        )

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]) -> None:
        """Validate model kwargs, allowing `timeseries` as a valid argument."""
        # Temporarily remove `timeseries` so the base validator does not reject it.
        timeseries = model_kwargs.pop("timeseries", None)

        super()._validate_model_kwargs(model_kwargs)

        if timeseries is not None:
            model_kwargs["timeseries"] = timeseries
class Qwen3TSPreTrainedModel(Qwen3PreTrainedModel):
    config_class = Qwen3TSConfig
@auto_docstring
class Qwen3TSForCausalLM(Qwen3TSPreTrainedModel, Qwen3TSGenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Patch-based encoder that maps raw timeseries into the LM embedding space.
        self.ts_encoder = TimeSeriesEmbedding(config.ts)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model
    def _merge_input_ids_with_time_series_features(
        self, time_series_features, inputs_embeds, input_ids, attention_mask, labels, patch_cnt
    ):
        batch_size, sequence_length = input_ids.shape
        _left_padding = torch.any(attention_mask[:, 0] == 0)
        _right_padding = torch.any(attention_mask[:, -1] == 0)
        left_padding = False
        if batch_size > 1:
            if _left_padding and not _right_padding:
                left_padding = True
            elif not _left_padding and _right_padding:
                left_padding = False
            elif not _left_padding and not _right_padding:
                left_padding = False
            else:
                raise ValueError(f"attention_mask has zeros on both sides, which is invalid: {attention_mask}")
        else:
            if _left_padding and not _right_padding:
                left_padding = True
            else:
                left_padding = False

        # Locate the special tokens that delimit each timeseries span.
        special_ts_token_mask_start = input_ids == self.config.ts_token_start_index
        special_ts_token_mask_end = input_ids == self.config.ts_token_end_index
        special_ts_token_mask = special_ts_token_mask_start | special_ts_token_mask_end

        num_special_ts_tokens = torch.sum(special_ts_token_mask_start, dim=-1)
        total_time_steps, embed_dim = time_series_features.shape

        # Per-sample net expansion: each series contributes its patches while its
        # start/end tokens are overwritten, hence the "- 2" per series.
        patch_index = 0
        num_total_patches = torch.zeros(batch_size, dtype=patch_cnt.dtype, device=patch_cnt.device)
        special_ts_token_mask_start_nonzero = special_ts_token_mask_start.nonzero()
        special_ts_token_mask_start_with_size = special_ts_token_mask_start.clone().long()

        attn_mask_cnt = attention_mask.sum(dim=-1)
        for i in range(batch_size):
            num_ts_in_batch = num_special_ts_tokens[i]
            num_total_patches[i] = patch_cnt[patch_index : patch_index + num_ts_in_batch].sum() - 2 * num_ts_in_batch
            for idx in range(patch_index, patch_index + num_ts_in_batch):
                b_idx, pos = special_ts_token_mask_start_nonzero[idx]
                special_ts_token_mask_start_with_size[b_idx, pos] *= patch_cnt[idx].item() - 2
            patch_index += num_ts_in_batch
            attn_mask_cnt[i] += num_total_patches[i].item()

        # Maximum expanded sequence length across the batch.
        max_embed_dim = sequence_length + num_total_patches.max()

        batch_indices, non_ts_indices = torch.where(~special_ts_token_mask)
        attn_batch_indices, attn_indices = torch.where(attention_mask == 1)

        # New position of each original token after inserting the patches.
        new_token_positions = torch.cumsum((special_ts_token_mask_start_with_size + 1), dim=-1) - 1

        # Shift left-padded (right-aligned) samples so they stay right-aligned.
        nb_ts_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_ts_pad[:, None]

        text_to_overwrite = new_token_positions[batch_indices, non_ts_indices]

        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )

        final_attention_mask = torch.zeros(batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device)
        for i in range(attention_mask.size(0)):
            if left_padding:
                final_attention_mask[i, max_embed_dim - attn_mask_cnt[i] :] = 1
            else:
                final_attention_mask[i, : attn_mask_cnt[i]] = 1

        final_labels = None
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
            )

        target_device = inputs_embeds.device
        batch_indices, non_ts_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_ts_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )

        # Scatter the text embeddings (and labels) into their new positions.
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_ts_indices]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_ts_indices]

        # Every slot not taken by a text token is a candidate patch slot.
        ts_to_overwrite = torch.full(
            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
        )
        ts_to_overwrite[batch_indices, text_to_overwrite] = False

        # Exclude the alignment-padding slots from the patch slots.
        reversed_cumsum = ts_to_overwrite.flip(dims=[-1]).cumsum(-1).flip(dims=[-1]) - 1
        ts_to_overwrite &= reversed_cumsum >= nb_ts_pad[:, None].to(target_device)

        if ts_to_overwrite.sum() != time_series_features.shape[:-1].numel():
            raise ValueError(
                f"The inputs provided to the model are inconsistent. The number of timeseries start tokens is "
                f"{torch.sum(special_ts_token_mask_start)} while the number of timeseries given to the model is "
                f"{len(patch_cnt)}. This prevents correct indexing and breaks batch generation."
            )
        final_embedding[ts_to_overwrite] = time_series_features.contiguous().reshape(-1, embed_dim).to(target_device)

        # Recompute position ids from the expanded attention mask.
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
        if position_ids.size(-1) < input_ids.size(-1):
            position_ids = position_ids[:, -input_ids.size(-1) :]

        # Zero out embeddings at pad-token positions.
        pad_batch_indices, pad_indices = torch.where(input_ids == self.config.pad_token_id)
        if len(pad_batch_indices) > 0:
            indices_to_mask = new_token_positions[pad_batch_indices, pad_indices]
            final_embedding[pad_batch_indices, indices_to_mask] = 0

        # Mark positions of padded tokens with -1.
        new_token_positions = new_token_positions.masked_fill(attention_mask == 0, -1)

        return final_embedding, final_attention_mask, position_ids, final_labels, new_token_positions
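    # Worked example of the position mapping above (assumed toy numbers):
    # input_ids = [A, <ts_start>, <ts_end>, B] with one series of patch_cnt = 4.
    #   special_ts_token_mask_start_with_size = [0, 2, 0, 0]   (patch_cnt - 2)
    #   cumsum(mask + 1) - 1                  = [0, 3, 4, 5]
    # so A keeps position 0, B moves to position 5, the start/end tokens are
    # overwritten, and the 4 patches fill slots 1-4 of the length-6 sequence
    # (max_embed_dim = 4 + (4 - 2) = 6).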
    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        timeseries: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Qwen3TSCausalLMOutputWithPast:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            timeseries (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size)`, *optional*):
                Timeseries data to be encoded and merged with the text embeddings.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices.
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of positions of each input sequence token in the position embeddings.
            past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
                Pre-computed hidden states (keys and values in the attention blocks).
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can directly pass an embedded representation.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.

        Returns:
            [`Qwen3TSCausalLMOutputWithPast`] or `tuple(torch.FloatTensor)`:
                The model outputs, with the attention mask potentially expanded by timeseries patches.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        new_token_positions = None
        if timeseries is not None and timeseries.shape[0] > 0:
            # Encode the raw timeseries into patch embeddings...
            ts_features, patch_cnt = self.ts_encoder(timeseries)
            inputs_embeds = inputs_embeds.to(ts_features.dtype)

            # ...and splice them into the text embeddings, expanding the
            # attention mask, position ids, and labels to the new length.
            inputs_embeds, attention_mask, position_ids, labels, new_token_positions = (
                self._merge_input_ids_with_time_series_features(
                    ts_features, inputs_embeds, input_ids, attention_mask, labels, patch_cnt
                )
            )

        outputs: BaseModelOutputWithPast = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute logits for the requested trailing positions (0 keeps all).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return Qwen3TSCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            attention_mask=attention_mask,
            labels=labels,
            new_token_positions=new_token_positions,
        )
__all__ = ["Qwen3TSForCausalLM"]
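# Hedged end-to-end usage sketch (illustrative, not part of the released API):
# the checkpoint path, prompt format, and the "<ts_start>"/"<ts_end>" token
# strings are assumptions; only the `timeseries=` keyword mirrors the
# `generate` override defined above.
def _example_generate_with_timeseries():  # pragma: no cover - illustrative only
    from transformers import AutoTokenizer

    model = Qwen3TSForCausalLM.from_pretrained("path/to/qwen3-ts-checkpoint")  # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained("path/to/qwen3-ts-checkpoint")   # hypothetical path

    # The prompt must contain one start/end token pair per timeseries.
    prompt = "Describe this series: <ts_start><ts_end>"
    inputs = tokenizer(prompt, return_tensors="pt")

    # (num_series, time_steps, num_features): a value channel followed by a
    # validity-mask channel, matching TimeSeriesEmbedding.forward's reshape.
    timeseries = torch.randn(1, 32, 2)
    timeseries[..., 1] = 1.0  # mark every step as valid

    output_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        timeseries=timeseries,
        max_new_tokens=32,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))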