geetu040 commited on
Commit
8bc63f3
·
verified ·
1 Parent(s): 4151ca9

Add files using upload-large-folder tool

Browse files
.gitattributes CHANGED
@@ -1,35 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
README.md ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ metrics:
4
+ - mse
5
+ - mae
6
+ - mase
7
+ - wql
8
+ - crps
9
+ pipeline_tag: time-series-forecasting
10
+ datasets:
11
+ - thuml/UTSD
12
+ - Salesforce/lotsa_data
13
+ - Salesforce/GiftEvalPretrain
14
+ - autogluon/chronos_datasets
15
+ tags:
16
+ - time series
17
+ - time-series
18
+ - forecasting
19
+ - foundation models
20
+ - pretrained models
21
+ - time series foundation models
22
+ - quantized
23
+ - 4-bit
24
+ - bitsandbytes
25
+ - unofficial
26
+ library_name: transformers
27
+ base_model:
28
+ - bytedance-research/Timer-S1
29
+ ---
30
+
31
+ # Timer-S1 Quantized 4-bit
32
+
33
+ This repository contains an **unofficial 4-bit BitsAndBytes quantized checkpoint** derived from [`bytedance-research/Timer-S1`](https://huggingface.co/bytedance-research/Timer-S1).
34
+
35
+ Timer-S1 is a time series foundation model for zero-shot forecasting. The original model card describes Timer-S1 as a decoder-only Mixture-of-Experts Transformer with **8.3B** total parameters, **0.75B** activated parameters per token, and a context length of **11,520**. For details about the original model, architecture, training data, benchmark results, and intended use, refer to the upstream model card and the [Timer-S1 technical report](https://arxiv.org/pdf/2603.04791).
36
+
37
+ This upload preserves the upstream Timer-S1 remote-code implementation files and Apache-2.0 license metadata, but stores the model weights as a local 4-bit quantized checkpoint for lower-memory inference.
38
+
39
+ ## Source and Provenance
40
+
41
+ - **Base model**: `bytedance-research/Timer-S1`
42
+ - **Quantization**: BitsAndBytes 4-bit quantization
43
+ - **Status**: unofficial derivative checkpoint
44
+
45
+ No new training or benchmark claims are made for this quantized checkpoint. Numerical outputs may differ slightly from the original bfloat16 checkpoint because the weights are quantized.
46
+
47
+ ## Quantization Details
48
+
49
+ The checkpoint configuration records the following quantization settings:
50
+
51
+ ```json
52
+ {
53
+ "load_in_4bit": true,
54
+ "load_in_8bit": false,
55
+ "quant_method": "bitsandbytes",
56
+ "bnb_4bit_quant_type": "fp4",
57
+ "bnb_4bit_quant_storage": "uint8",
58
+ "bnb_4bit_compute_dtype": "float32",
59
+ "bnb_4bit_use_double_quant": false
60
+ }
61
+ ```
62
+
63
+ The model config also sets `use_cache` to `false`, matching the local quantized checkpoint.
64
+
65
+ ## Quickstart
66
+
67
+ Install the expected runtime dependencies:
68
+
69
+ ```bash
70
+ pip install torch accelerate bitsandbytes "transformers~=4.57.1"
71
+ ```
72
+
73
+ Load the model with Hugging Face Transformers:
74
+
75
+ ```python
76
+ import torch
77
+ from transformers import AutoModelForCausalLM
78
+
79
+ model = AutoModelForCausalLM.from_pretrained(
80
+ "geetu040/Timer-S1-quantized-4bit",
81
+ trust_remote_code=True,
82
+ device_map="auto",
83
+ )
84
+
85
+ batch_size, lookback_length = 1, 2880
86
+ seqs = torch.randn(batch_size, lookback_length).to(model.device)
87
+
88
+ forecast_length = 256
89
+ output = model.generate(seqs, max_new_tokens=forecast_length, revin=True)
90
+
91
+ # Timer-S1 generates forecasts at quantile levels:
92
+ # [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
93
+ print(output.shape) # batch_size x quantile_num(9) x forecast_length
94
+ print(output[0][4]) # median forecast for the first sample
95
+ ```
96
+
97
+ ## Specification
98
+
99
+ - **Architecture**: decoder-only Transformer with MoE
100
+ - **Context length**: up to 11,520
101
+ - **Patch length**: 16
102
+ - **Quantiles**: 0.1 through 0.9
103
+ - **Hidden size**: 1024
104
+ - **Attention heads**: 16
105
+ - **Experts**: 32 total, 2 selected per token
106
+ - **Hidden layers**: 24
107
+ - **Weight format**: `model.safetensors`
108
+ - **Quantization**: BitsAndBytes 4-bit FP4
109
+
110
+ ## License
111
+
112
+ The upstream Timer-S1 model card lists the model under the Apache-2.0 License. This repository preserves that license metadata.
113
+
114
+ ## Citation
115
+
116
+ If you use this quantized checkpoint, cite the original Timer-S1 paper:
117
+
118
+ ```bibtex
119
+ @article{liu2026timer,
120
+ title={Timer-S1: A Billion-Scale Time Series Foundation Model with Serial Scaling},
121
+ author={Liu, Yong and Su, Xingjian and Wang, Shiyu and Zhang, Haoran and Liu, Haixuan and Wang, Yuxuan and Ye, Zhou and Xiang, Yang and Wang, Jianmin and Long, Mingsheng},
122
+ journal={arXiv preprint arXiv:2603.04791},
123
+ year={2026}
124
+ }
125
+ ```
config.json ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "TimerS1ForPrediction"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_TimerS1.TimerS1Config",
7
+ "AutoModelForCausalLM": "modeling_TimerS1.TimerS1ForPrediction"
8
+ },
9
+ "dropout_rate": 0.1,
10
+ "dtype": "bfloat16",
11
+ "hidden_act": "silu",
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "input_token_len": 16,
15
+ "intermediate_size": 4096,
16
+ "max_position_embeddings": 12800,
17
+ "model_type": "Timer-S1",
18
+ "num_attention_heads": 16,
19
+ "num_experts": 32,
20
+ "num_experts_per_token": 2,
21
+ "num_hidden_layers": 24,
22
+ "num_mtp_tokens": 16,
23
+ "output_token_lens": [
24
+ 16
25
+ ],
26
+ "quantiles": [
27
+ 0.1,
28
+ 0.2,
29
+ 0.3,
30
+ 0.4,
31
+ 0.5,
32
+ 0.6,
33
+ 0.7,
34
+ 0.8,
35
+ 0.9
36
+ ],
37
+ "quantization_config": {
38
+ "_load_in_4bit": true,
39
+ "_load_in_8bit": false,
40
+ "bnb_4bit_compute_dtype": "float32",
41
+ "bnb_4bit_quant_storage": "uint8",
42
+ "bnb_4bit_quant_type": "fp4",
43
+ "bnb_4bit_use_double_quant": false,
44
+ "llm_int8_enable_fp32_cpu_offload": false,
45
+ "llm_int8_has_fp16_weight": false,
46
+ "llm_int8_skip_modules": null,
47
+ "llm_int8_threshold": 6.0,
48
+ "load_in_4bit": true,
49
+ "load_in_8bit": false,
50
+ "quant_method": "bitsandbytes"
51
+ },
52
+ "rope_theta": 10000,
53
+ "transformers_version": "4.57.6",
54
+ "use_cache": false
55
+ }
configuration_TimerS1.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http:www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List
16
+ from transformers import PretrainedConfig
17
+
18
+
19
+ class TimerS1Config(PretrainedConfig):
20
+ model_type = "Timer-S1"
21
+ keys_to_ignore_at_inference = ["past_key_values"]
22
+
23
+ def __init__(
24
+ self,
25
+ input_token_len: int = 16,
26
+ hidden_size: int = 1024,
27
+ intermediate_size: int = 4096,
28
+ output_token_lens: List[int] = [16],
29
+ num_hidden_layers: int = 24,
30
+ num_attention_heads: int = 16,
31
+ hidden_act: str = "silu",
32
+ use_cache: bool = True,
33
+ rope_theta: int = 10000,
34
+ dropout_rate: float = 0.1,
35
+ initializer_range: float = 0.02,
36
+ max_position_embeddings: int = 12800,
37
+ quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
38
+ num_experts: int = 32,
39
+ num_experts_per_token: int = 2,
40
+ # MTP configuration
41
+ num_mtp_tokens: int = 16,
42
+ **kwargs,
43
+ ):
44
+ self.input_token_len = input_token_len
45
+ self.hidden_size = hidden_size
46
+ self.intermediate_size = intermediate_size
47
+ self.num_hidden_layers = num_hidden_layers
48
+ self.num_attention_heads = num_attention_heads
49
+ self.hidden_act = hidden_act
50
+ self.output_token_lens = output_token_lens
51
+ self.use_cache = use_cache
52
+ self.rope_theta = rope_theta
53
+ self.dropout_rate = dropout_rate
54
+ self.initializer_range = initializer_range
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.quantiles = quantiles
57
+ self.num_experts = num_experts
58
+ self.num_experts_per_token = num_experts_per_token
59
+ # MTP configuration
60
+ self.num_mtp_tokens = num_mtp_tokens
61
+ super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.57.6"
4
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01fe733c374f791d4b00261f5c4bc35690095ba46c9ad87b6751f77f667727f8
3
+ size 4674120678
modeling_TimerS1.py ADDED
@@ -0,0 +1,836 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http:www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Tuple, List, Union
16
+ import math
17
+ from dataclasses import dataclass
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.nn.functional as F
22
+ from transformers import PreTrainedModel, Cache, DynamicCache
23
+ from transformers.activations import ACT2FN
24
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
25
+ from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
26
+
27
+ from .configuration_TimerS1 import TimerS1Config
28
+ from .ts_generation_mixin import TSGenerationMixin
29
+
30
+
31
+ @dataclass
32
+ class TimerS1CausalLMOutput(MoeCausalLMOutputWithPast):
33
+ """Extends MoeCausalLMOutputWithPast with hidden_states_for_mtp as a proper dataclass field
34
+ so it is reliably registered in the ModelOutput OrderedDict and accessible via attribute access."""
35
+ hidden_states_for_mtp: Optional[torch.FloatTensor] = None
36
+
37
+ def _get_usable_past_kv_length(cache: Cache, new_seq_length: int, layer_idx: int = 0) -> int:
38
+ """Compute the usable past length for the given cache and upcoming new sequence length.
39
+ This mirrors the previous `get_usable_length(new_seq_length, layer_idx)` behavior that existed in
40
+ Transformers < 4.45, while being compatible with the new Cache API.
41
+ """
42
+ try:
43
+ previous_length = cache.get_seq_length(layer_idx)
44
+ # Dynamic layers return -1, static layers return an int
45
+ max_length = cache.get_max_cache_shape(layer_idx)
46
+ if max_length is not None and max_length != -1 and previous_length + new_seq_length > max_length:
47
+ return max_length - new_seq_length
48
+ return previous_length
49
+ except Exception:
50
+ # Best-effort fallback
51
+ return cache.get_seq_length(layer_idx) if hasattr(cache, "get_seq_length") else 0
52
+
53
+ @dataclass
54
+ class TempMoeModelOutputWithPast(MoeModelOutputWithPast):
55
+ last_hidden_state: torch.FloatTensor = None
56
+ past_key_values: Optional[
57
+ Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
58
+ ] = None
59
+ use_legacy_cache: Optional[bool] = None
60
+ past_key_values_length: Optional[int] = None
61
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
62
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
63
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
64
+
65
+ def rotate_half(x):
66
+ x1 = x[..., : x.shape[-1] // 2]
67
+ x2 = x[..., x.shape[-1] // 2:]
68
+ return torch.cat((-x2, x1), dim=-1)
69
+
70
+
71
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
72
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
73
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
74
+ q_embed = (q * cos) + (rotate_half(q) * sin)
75
+ k_embed = (k * cos) + (rotate_half(k) * sin)
76
+ return q_embed, k_embed
77
+
78
+ class RMSNorm(nn.Module):
79
+ def __init__(self, dim: int, eps: float = 1e-6):
80
+ super().__init__()
81
+ self.eps = eps
82
+ self.weight = nn.Parameter(torch.ones(dim))
83
+
84
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
85
+ rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
86
+ x_norm = x / (rms + self.eps)
87
+ return x_norm * self.weight
88
+
89
+ class ResidualBlock(nn.Module):
90
+ def __init__(self, config: TimerS1Config) -> None:
91
+ super().__init__()
92
+ self.out_dim = len(config.quantiles) * config.output_token_lens[-1]
93
+ self.dropout = nn.Dropout(config.dropout_rate)
94
+ self.hidden_layer = nn.Linear(config.hidden_size, config.hidden_size)
95
+ self.act = ACT2FN[config.hidden_act]
96
+ self.output_layer = nn.Linear(config.hidden_size, self.out_dim)
97
+ self.residual_layer = nn.Linear(config.hidden_size, self.out_dim)
98
+
99
+ def forward(self, x: torch.Tensor):
100
+ hid = self.act(self.hidden_layer(x))
101
+ out = self.dropout(self.output_layer(hid))
102
+ return out + self.residual_layer(x)
103
+
104
+
105
+ class TimerS1PatchEmbedding(nn.Module):
106
+ def __init__(self, config: TimerS1Config):
107
+ super().__init__()
108
+ self.dropout = nn.Dropout(config.dropout_rate)
109
+ self.hidden_layer = nn.Linear(config.input_token_len * 2, config.intermediate_size)
110
+ self.act = ACT2FN[config.hidden_act]
111
+ self.output_layer = nn.Linear(config.intermediate_size, config.hidden_size)
112
+ self.residual_layer = nn.Linear(config.input_token_len * 2, config.hidden_size)
113
+ self.input_token_len = config.input_token_len
114
+
115
+ def forward(self, x):
116
+ mask = torch.ones_like(x)
117
+ input_length = x.shape[-1]
118
+ padding_length = (self.input_token_len - (input_length % self.input_token_len)) % self.input_token_len
119
+ x = F.pad(x, (padding_length, 0))
120
+ mask = F.pad(mask, (padding_length, 0))
121
+ x = x.unfold(dimension=-1, size=self.input_token_len, step=self.input_token_len)
122
+ mask = mask.unfold(dimension=-1, size=self.input_token_len, step=self.input_token_len)
123
+ x = torch.cat([x, mask], dim=-1)
124
+ hid = self.act(self.hidden_layer(x))
125
+ out = self.dropout(self.output_layer(hid))
126
+ return out + self.residual_layer(x)
127
+
128
+
129
+ class TimerS1RotaryEmbedding(torch.nn.Module):
130
+ def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
131
+ super().__init__()
132
+ self.dim = dim
133
+ self.max_position_embeddings = max_position_embeddings
134
+ self.base = base
135
+ inv_freq = 1.0 / (
136
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)
137
+ )
138
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
139
+ self._set_cos_sin_cache(
140
+ seq_len=max_position_embeddings,
141
+ device=self.inv_freq.device,
142
+ dtype=torch.get_default_dtype(),
143
+ )
144
+
145
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
146
+ self.max_seq_len_cached = seq_len
147
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
148
+ freqs = torch.outer(t, self.inv_freq)
149
+ emb = torch.cat((freqs, freqs), dim=-1)
150
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
151
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
152
+
153
+ def forward(self, x, seq_len=None):
154
+ if seq_len > self.max_seq_len_cached:
155
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
156
+ return (
157
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
158
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
159
+ )
160
+
161
+ class TimerS1Attention(nn.Module):
162
+ def __init__(self, config: TimerS1Config, layer_idx: Optional[int] = None):
163
+ super().__init__()
164
+ self.layer_idx = layer_idx
165
+ self.hidden_size = config.hidden_size
166
+ self.num_heads = config.num_attention_heads
167
+ self.head_dim = self.hidden_size // self.num_heads
168
+ self.attention_dropout = config.dropout_rate
169
+
170
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
171
+ self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
172
+ self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
173
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
174
+
175
+ # QK-Norm learnable scales
176
+ self.q_scale = nn.Parameter(torch.ones(self.head_dim))
177
+ self.k_scale = nn.Parameter(torch.ones(self.head_dim))
178
+
179
+ # Attention output gate
180
+ self.gate_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
181
+
182
+ self.rotary_emb = TimerS1RotaryEmbedding(
183
+ self.head_dim,
184
+ max_position_embeddings=config.max_position_embeddings,
185
+ base=config.rope_theta,
186
+ )
187
+
188
+ def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
189
+ eps = 1e-6
190
+ q = q * torch.rsqrt(q.pow(2).mean(dim=-1, keepdim=True) + eps) * self.q_scale.view(1, 1, 1, -1)
191
+ k = k * torch.rsqrt(k.pow(2).mean(dim=-1, keepdim=True) + eps) * self.k_scale.view(1, 1, 1, -1)
192
+ return q, k
193
+
194
+ def forward(
195
+ self,
196
+ hidden_states: torch.Tensor,
197
+ attention_mask: Optional[torch.Tensor] = None,
198
+ position_ids: Optional[torch.LongTensor] = None,
199
+ past_key_value: Optional[Cache] = None,
200
+ output_attentions: bool = False,
201
+ **kwargs,
202
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
203
+ bsz, q_len, _ = hidden_states.size()
204
+
205
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
206
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
207
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
208
+
209
+ kv_seq_len = key_states.shape[-2]
210
+ if past_key_value is not None:
211
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
212
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
213
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
214
+
215
+ query_states, key_states = self._apply_qk_norm(query_states, key_states)
216
+
217
+ if past_key_value is not None:
218
+ key_states, value_states = past_key_value.update(
219
+ key_states, value_states, self.layer_idx)
220
+
221
+ attn_output = F.scaled_dot_product_attention(
222
+ query_states,
223
+ key_states,
224
+ value_states,
225
+ attention_mask,
226
+ dropout_p=(self.attention_dropout if self.training else 0.0),
227
+ ) # [bsz, num_heads, q_len, head_dim]
228
+
229
+ gate = torch.sigmoid(self.gate_proj(hidden_states))
230
+ gate = gate.view(bsz, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
231
+ attn_output = attn_output * gate
232
+
233
+ attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
234
+ attn_output = self.o_proj(attn_output)
235
+
236
+ attn_weights = None if not output_attentions else attn_output
237
+ return attn_output, attn_weights, past_key_value
238
+
239
+ class TimerS1MLP(nn.Module):
240
+ def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
241
+ super().__init__()
242
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
243
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
244
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
245
+ self.act_fn = ACT2FN[hidden_act]
246
+
247
+ def forward(self, hidden_state):
248
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
249
+
250
+ class TimerS1ExpertsLayer(nn.Module):
251
+ def __init__(self, config: TimerS1Config):
252
+ super().__init__()
253
+ self.top_k = config.num_experts_per_token
254
+ self.hidden_size = config.hidden_size
255
+ self.num_experts = config.num_experts
256
+ moe_intermediate_size = config.intermediate_size // self.top_k
257
+
258
+ self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
259
+ self.experts = nn.ModuleList([
260
+ TimerS1MLP(
261
+ hidden_size=config.hidden_size,
262
+ intermediate_size=moe_intermediate_size,
263
+ hidden_act=config.hidden_act,
264
+ )
265
+ for _ in range(self.num_experts)
266
+ ])
267
+
268
+ def forward(self, hidden_states: torch.Tensor):
269
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
270
+ hidden_states = hidden_states.view(-1, hidden_dim)
271
+ router_logits = self.gate(hidden_states)
272
+
273
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
274
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
275
+ routing_weights = routing_weights.to(hidden_states.dtype)
276
+
277
+ final_hidden_states = torch.zeros(
278
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
279
+ )
280
+
281
+ expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
282
+
283
+ for expert_idx in range(self.num_experts):
284
+ expert_layer = self.experts[expert_idx]
285
+ idx, top_x = torch.where(expert_mask[expert_idx])
286
+
287
+ if top_x.numel() == 0:
288
+ continue
289
+
290
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
291
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
292
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
293
+
294
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
295
+ return final_hidden_states
296
+
297
+ class TimerS1DecoderLayer(nn.Module):
298
+ def __init__(self, config: TimerS1Config, layer_idx: int):
299
+ super().__init__()
300
+ self.self_attn = TimerS1Attention(config, layer_idx)
301
+ self.ffn_layer = TimerS1ExpertsLayer(config)
302
+ self.norm1 = RMSNorm(config.hidden_size)
303
+ self.norm2 = RMSNorm(config.hidden_size)
304
+
305
+ def forward(
306
+ self,
307
+ hidden_states: torch.Tensor,
308
+ attention_mask: Optional[torch.Tensor] = None,
309
+ position_ids: Optional[torch.LongTensor] = None,
310
+ past_key_value: Optional[Cache] = None,
311
+ output_attentions: Optional[bool] = False,
312
+ use_cache: Optional[bool] = False,
313
+ **kwargs,
314
+ ) -> Tuple[torch.FloatTensor, Optional[torch.Tensor], Optional[Cache]]:
315
+ residual = hidden_states
316
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
317
+ hidden_states=self.norm1(hidden_states),
318
+ attention_mask=attention_mask,
319
+ position_ids=position_ids,
320
+ past_key_value=past_key_value,
321
+ output_attentions=output_attentions,
322
+ )
323
+ hidden_states = residual + hidden_states
324
+
325
+ residual = hidden_states
326
+ hidden_states = self.ffn_layer(self.norm2(hidden_states))
327
+ hidden_states = residual + hidden_states
328
+
329
+ if not output_attentions:
330
+ self_attn_weights = None
331
+ if not use_cache:
332
+ present_key_value = None
333
+
334
+ return hidden_states, self_attn_weights, present_key_value
335
+
336
+
337
+ class TimerS1PreTrainedModel(PreTrainedModel):
338
+ config_class = TimerS1Config
339
+ base_model_prefix = "model"
340
+ supports_gradient_checkpointing = True
341
+ _no_split_modules = ["TimerS1DecoderLayer"]
342
+ _skip_keys_device_placement = "past_key_values"
343
+ _supports_flash_attn_2 = True
344
+ _supports_sdpa = False
345
+ _supports_cache_class = True
346
+
347
+ def _init_weights(self, module):
348
+ std = self.config.initializer_range
349
+ if isinstance(module, nn.Linear):
350
+ module.weight.data.normal_(mean=0.0, std=std)
351
+ if module.bias is not None:
352
+ module.bias.data.zero_()
353
+ elif isinstance(module, nn.Embedding):
354
+ module.weight.data.normal_(mean=0.0, std=std)
355
+ if module.padding_idx is not None:
356
+ module.weight.data[module.padding_idx].zero_()
357
+
358
+ class TimerS1Model(TimerS1PreTrainedModel):
359
+ def __init__(self, config: TimerS1Config):
360
+ super().__init__(config)
361
+ self.embed_layer = TimerS1PatchEmbedding(config)
362
+ self.layers = nn.ModuleList([
363
+ TimerS1DecoderLayer(config, layer_idx)
364
+ for layer_idx in range(config.num_hidden_layers)
365
+ ])
366
+ self.norm = RMSNorm(config.hidden_size)
367
+ self.gradient_checkpointing = False
368
+
369
+ def forward(
370
+ self,
371
+ input_ids: torch.FloatTensor = None,
372
+ attention_mask: Optional[torch.Tensor] = None,
373
+ position_ids: Optional[torch.LongTensor] = None,
374
+ past_key_values: Optional[
375
+ Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
376
+ ] = None,
377
+ inputs_embeds: Optional[torch.FloatTensor] = None,
378
+ use_cache: Optional[bool] = None,
379
+ output_attentions: Optional[bool] = None,
380
+ output_hidden_states: Optional[bool] = None,
381
+ return_dict: Optional[bool] = None,
382
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
383
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
384
+ output_hidden_states = (
385
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
386
+ )
387
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
388
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
389
+
390
+ if input_ids is not None and inputs_embeds is not None:
391
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
392
+ elif input_ids is not None:
393
+ batch_size, seq_length = input_ids.shape
394
+ elif inputs_embeds is not None:
395
+ batch_size, seq_length, _ = inputs_embeds.shape
396
+ else:
397
+ raise ValueError("You must specify either input_ids or inputs_embeds")
398
+
399
+ if inputs_embeds is None:
400
+ inputs_embeds = self.embed_layer(input_ids)
401
+ seq_length = inputs_embeds.shape[1]
402
+
403
+ if self.gradient_checkpointing and self.training and use_cache:
404
+ use_cache = False
405
+
406
+ past_key_values_length = 0
407
+ use_legacy_cache = None
408
+ if use_cache:
409
+ use_legacy_cache = not isinstance(past_key_values, Cache)
410
+ if use_legacy_cache:
411
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
412
+ past_key_values_length = _get_usable_past_kv_length(past_key_values, seq_length)
413
+
414
+ if position_ids is None:
415
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
416
+ position_ids = torch.arange(
417
+ past_key_values_length, seq_length + past_key_values_length,
418
+ dtype=torch.long, device=device,
419
+ ).view(-1, seq_length)
420
+ else:
421
+ position_ids = position_ids.view(-1, seq_length).long()
422
+
423
+ attention_mask = _prepare_4d_causal_attention_mask(
424
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, sliding_window=None,
425
+ )
426
+
427
+ hidden_states = inputs_embeds
428
+
429
+ all_hidden_states = () if output_hidden_states else None
430
+ all_self_attns = () if output_attentions else None
431
+ all_moe_losses = []
432
+
433
+ for decoder_layer in self.layers:
434
+ if output_hidden_states:
435
+ all_hidden_states += (hidden_states,)
436
+
437
+ layer_outputs = decoder_layer(
438
+ hidden_states,
439
+ attention_mask=attention_mask,
440
+ position_ids=position_ids,
441
+ past_key_value=past_key_values,
442
+ output_attentions=output_attentions,
443
+ use_cache=use_cache,
444
+ )
445
+
446
+ hidden_states = layer_outputs[0]
447
+
448
+ if output_attentions:
449
+ all_self_attns += (layer_outputs[1],)
450
+
451
+ hidden_states = self.norm(hidden_states)
452
+ if output_hidden_states:
453
+ all_hidden_states += (hidden_states,)
454
+
455
+ if not return_dict:
456
+ return tuple(
457
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_moe_losses]
458
+ if v is not None
459
+ )
460
+
461
+ return TempMoeModelOutputWithPast(
462
+ last_hidden_state=hidden_states,
463
+ past_key_values=past_key_values,
464
+ hidden_states=all_hidden_states,
465
+ attentions=all_self_attns,
466
+ use_legacy_cache=use_legacy_cache,
467
+ past_key_values_length=past_key_values_length,
468
+ router_logits=all_moe_losses,
469
+ )
470
+
471
+ class TimerS1MTPLayer(nn.Module):
472
+ def __init__(self, config: TimerS1Config, layer_idx: int):
473
+ super().__init__()
474
+ self.hidden_size = config.hidden_size
475
+ self.config = config
476
+ self.layer_idx = layer_idx
477
+ self.norm_hidden = RMSNorm(config.hidden_size)
478
+ self.norm_embeds = RMSNorm(config.hidden_size)
479
+ self.projection_matrix = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
480
+ self.layer = TimerS1DecoderLayer(config, self.layer_idx + self.config.num_hidden_layers)
481
+ self.norm = RMSNorm(config.hidden_size)
482
+ self.gradient_checkpointing = False
483
+
484
+ def forward(
485
+ self,
486
+ hidden_states: torch.FloatTensor = None,
487
+ attention_mask: Optional[torch.Tensor] = None,
488
+ position_ids: Optional[torch.LongTensor] = None,
489
+ past_key_values: Optional[
490
+ Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
491
+ ] = None,
492
+ use_legacy_cache: Optional[bool] = False,
493
+ past_key_values_length: Optional[int] = 0,
494
+ inputs_embeds: Optional[torch.FloatTensor] = None,
495
+ use_cache: Optional[bool] = None,
496
+ output_attentions: Optional[bool] = None,
497
+ output_hidden_states: Optional[bool] = None,
498
+ return_dict: Optional[bool] = None,
499
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
500
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
501
+ output_hidden_states = (
502
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
503
+ )
504
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
505
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
506
+
507
+ if inputs_embeds is not None:
508
+ batch_size, seq_length, _ = inputs_embeds.shape
509
+ else:
510
+ raise ValueError("You must specify inputs_embeds")
511
+
512
+ if self.gradient_checkpointing and self.training:
513
+ if use_cache:
514
+ use_cache = False
515
+
516
+ if position_ids is None:
517
+ device = inputs_embeds.device
518
+ position_ids = torch.arange(
519
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
520
+ )
521
+ position_ids = position_ids.view(-1, seq_length)
522
+ else:
523
+ position_ids = position_ids.view(-1, seq_length).long()
524
+
525
+ attention_mask = _prepare_4d_causal_attention_mask(
526
+ attention_mask,
527
+ (batch_size, seq_length),
528
+ inputs_embeds,
529
+ past_key_values_length,
530
+ sliding_window=None,
531
+ )
532
+
533
+ hidden_states = self.norm_hidden(hidden_states)
534
+ inputs_embeds = self.norm_embeds(inputs_embeds)
535
+ hidden_states = self.projection_matrix(torch.cat([hidden_states, inputs_embeds], dim=-1))
536
+
537
+ all_hidden_states = () if output_hidden_states else None
538
+ all_self_attns = () if output_attentions else None
539
+ all_moe_losses = []
540
+ next_decoder_cache = None
541
+
542
+ if output_hidden_states:
543
+ all_hidden_states += (hidden_states,)
544
+
545
+ if self.gradient_checkpointing and self.training:
546
+ layer_outputs = self._gradient_checkpointing_func(
547
+ self.layer.__call__,
548
+ hidden_states,
549
+ attention_mask,
550
+ position_ids,
551
+ past_key_values,
552
+ output_attentions,
553
+ use_cache,
554
+ )
555
+ else:
556
+ layer_outputs = self.layer(
557
+ hidden_states,
558
+ attention_mask=attention_mask,
559
+ position_ids=position_ids,
560
+ past_key_value=past_key_values,
561
+ output_attentions=output_attentions,
562
+ use_cache=use_cache,
563
+ )
564
+
565
+ hidden_states = layer_outputs[0]
566
+
567
+ if output_attentions:
568
+ all_self_attns += (layer_outputs[1],)
569
+
570
+ if use_cache:
571
+ next_decoder_cache = layer_outputs[2]
572
+
573
+ hidden_states = self.norm(hidden_states)
574
+
575
+ if output_hidden_states:
576
+ all_hidden_states += (hidden_states,)
577
+
578
+ next_cache = None
579
+ if use_cache:
580
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
581
+
582
+ if not return_dict:
583
+ return tuple(
584
+ v
585
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_moe_losses]
586
+ if v is not None
587
+ )
588
+ return MoeModelOutputWithPast(
589
+ last_hidden_state=hidden_states,
590
+ past_key_values=next_cache,
591
+ hidden_states=all_hidden_states,
592
+ attentions=all_self_attns,
593
+ router_logits=all_moe_losses,
594
+ )
595
+
596
+ class TimerS1ForPrediction(TimerS1PreTrainedModel, TSGenerationMixin):
597
+ def __init__(self, config: TimerS1Config):
598
+ super().__init__(config)
599
+ self.config = config
600
+ self.model = TimerS1Model(self.config)
601
+ self.output_patch_embedding = ResidualBlock(config)
602
+ self.num_quantiles = len(config.quantiles)
603
+ if self.config.num_mtp_tokens > 0:
604
+ self.mtp_modules = nn.ModuleList([
605
+ TimerS1MTPLayer(config, layer_idx)
606
+ for layer_idx in range(self.config.num_mtp_tokens)
607
+ ])
608
+ self.post_init()
609
+
610
+ def set_decoder(self, decoder):
611
+ self.model = decoder
612
+
613
+ def get_decoder(self):
614
+ return self.model
615
+
616
+ def forward(
617
+ self,
618
+ input_ids: torch.FloatTensor = None,
619
+ attention_mask: Optional[torch.Tensor] = None,
620
+ position_ids: Optional[torch.LongTensor] = None,
621
+ past_key_values: Optional[
622
+ Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
623
+ ] = None,
624
+ inputs_embeds: Optional[torch.FloatTensor] = None,
625
+ full_input_ids: Optional[torch.FloatTensor] = None,
626
+ full_hidden_states: Optional[torch.FloatTensor] = None,
627
+ use_cache: Optional[bool] = None,
628
+ output_attentions: Optional[bool] = None,
629
+ output_hidden_states: Optional[bool] = None,
630
+ return_dict: Optional[bool] = None,
631
+ max_output_length: Optional[int] = None,
632
+ revin: Optional[bool] = False,
633
+ ) -> Union[Tuple, TimerS1CausalLMOutput]:
634
+
635
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
636
+ output_hidden_states = (
637
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
638
+ )
639
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
640
+
641
+ if revin:
642
+ means = input_ids.mean(1, keepdim=True).detach()
643
+ stdev = input_ids.std(dim=1, keepdim=True, unbiased=False).detach()
644
+ stdev = torch.where(stdev > 1e-2, stdev, torch.tensor(1.0, device=input_ids.device))
645
+ input_ids = (input_ids - means) / stdev
646
+ if full_input_ids is not None:
647
+ fi_means = full_input_ids.mean(1, keepdim=True).detach()
648
+ fi_stdev = full_input_ids.std(dim=1, keepdim=True, unbiased=False).detach()
649
+ fi_stdev = torch.where(
650
+ fi_stdev > 1e-2, fi_stdev, torch.tensor(1.0, device=full_input_ids.device)
651
+ )
652
+ full_input_ids = (full_input_ids - fi_means) / fi_stdev
653
+ if inputs_embeds is None and input_ids is not None:
654
+ inputs_embeds = self.model.embed_layer(input_ids)
655
+ # full_inputs_embeds: embeddings for the complete sequence used by MTP layers (no KV cache)
656
+ if full_input_ids is not None:
657
+ full_inputs_embeds = self.model.embed_layer(full_input_ids)
658
+ else:
659
+ full_inputs_embeds = inputs_embeds
660
+
661
+ outputs = self.model(
662
+ input_ids=None,
663
+ attention_mask=attention_mask,
664
+ position_ids=position_ids,
665
+ past_key_values=past_key_values,
666
+ inputs_embeds=inputs_embeds,
667
+ use_cache=use_cache,
668
+ output_attentions=output_attentions,
669
+ output_hidden_states=output_hidden_states,
670
+ return_dict=return_dict,
671
+ )
672
+
673
+ hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
674
+
675
+ # Accumulate full hidden states across generation steps for MTP layers.
676
+ # When KV cache is enabled, hidden_states only covers new tokens, so we need to
677
+ # prepend accumulated past hidden states (full_hidden_states) to restore the full
678
+ # sequence picture needed by MTP layers.
679
+ # When KV cache is disabled, hidden_states already covers the full sequence
680
+ # (same length as full_inputs_embeds), so no accumulation is needed.
681
+ if full_hidden_states is not None and hidden_states.shape[1] < full_inputs_embeds.shape[1]:
682
+ mtp_hidden_states = torch.cat([full_hidden_states.to(hidden_states.device), hidden_states], dim=1)
683
+ else:
684
+ mtp_hidden_states = hidden_states
685
+
686
+ bsz, L, _ = hidden_states.shape
687
+ predictions = None
688
+ loss = None
689
+ if max_output_length is None:
690
+ output_token_len = self.config.output_token_lens[0]
691
+ max_output_length = output_token_len
692
+ else:
693
+ output_token_len = self.config.output_token_lens[0]
694
+ for h in self.config.output_token_lens[1:]:
695
+ if h > max_output_length:
696
+ break
697
+ output_token_len = h
698
+
699
+ predictions = self.output_patch_embedding(hidden_states[:, -1, :]).reshape(
700
+ bsz, self.num_quantiles, self.config.output_token_lens[-1]
701
+ )
702
+
703
+ if self.config.num_mtp_tokens > 0:
704
+ output_patch_len = self.config.output_token_lens[-1]
705
+ full_out_len = output_patch_len + self.config.input_token_len * self.config.num_mtp_tokens
706
+
707
+ target_len = max(0, min(int(max_output_length), int(full_out_len)))
708
+
709
+ out = torch.zeros(bsz, self.num_quantiles, target_len, device=predictions.device)
710
+ base_fill = min(output_patch_len, target_len)
711
+ if base_fill > 0:
712
+ out[:, :, :base_fill] = predictions[:, :, :base_fill]
713
+
714
+ if target_len <= output_patch_len:
715
+ mtp_steps_needed = 0
716
+ else:
717
+ remaining = target_len - output_patch_len
718
+ mtp_steps_needed = min(
719
+ self.config.num_mtp_tokens,
720
+ math.ceil(remaining / self.config.input_token_len),
721
+ )
722
+
723
+ for k, mtp_module in enumerate(self.mtp_modules):
724
+ if k >= mtp_steps_needed:
725
+ break
726
+
727
+ start_pos = (k + 1) * self.config.input_token_len
728
+ if start_pos >= target_len:
729
+ break
730
+
731
+ mtp_full_len = full_inputs_embeds.shape[1]
732
+ mtp_attention_mask = attention_mask[:, -mtp_full_len:] if attention_mask is not None else None
733
+ mtp_outputs = mtp_module(
734
+ hidden_states=mtp_hidden_states,
735
+ inputs_embeds=full_inputs_embeds,
736
+ attention_mask=mtp_attention_mask,
737
+ output_attentions=output_attentions,
738
+ )
739
+ mtp_hidden_states = mtp_outputs[0]
740
+
741
+ mtp_pred = self.output_patch_embedding(mtp_hidden_states)[:, -1, :]
742
+ mtp_pred = mtp_pred.reshape(bsz, self.num_quantiles, output_patch_len)
743
+
744
+ end_pos = min(start_pos + output_patch_len, target_len)
745
+ take = end_pos - start_pos
746
+ if take > 0:
747
+ out[:, :, start_pos:end_pos] = mtp_pred[:, :, :take]
748
+
749
+ predictions = out
750
+
751
+ if max_output_length is not None and predictions.shape[-1] > max_output_length:
752
+ predictions = predictions[:, :, :max_output_length]
753
+ if revin:
754
+ predictions = predictions * stdev + means
755
+ if not return_dict:
756
+ output = (predictions,) + outputs[1:]
757
+ return (loss,) + output if loss is not None else output
758
+
759
+ return TimerS1CausalLMOutput(
760
+ loss=loss,
761
+ logits=predictions,
762
+ past_key_values=outputs.past_key_values,
763
+ hidden_states=outputs.hidden_states,
764
+ attentions=outputs.attentions,
765
+ router_logits=outputs.router_logits,
766
+ # Pass main-model hidden states as a proper field so that
767
+ # _update_model_kwargs_for_generation can reliably accumulate them
768
+ # for the MTP layers across multi-step generation.
769
+ hidden_states_for_mtp=hidden_states,
770
+ )
771
+
772
+ def prepare_inputs_for_generation(
773
+ self,
774
+ input_ids,
775
+ past_key_values=None,
776
+ attention_mask=None,
777
+ inputs_embeds=None,
778
+ revin=False,
779
+ **kwargs,
780
+ ):
781
+ # full_input_ids always holds the complete original sequence for MTP layers
782
+ full_input_ids = input_ids.clone()
783
+ past_length = 0
784
+ if past_key_values is not None:
785
+ if isinstance(past_key_values, Cache):
786
+ cache_length = past_key_values.get_seq_length(0)
787
+ past_length = cache_length
788
+ try:
789
+ max_cache_length = past_key_values.get_max_cache_shape(0)
790
+ if max_cache_length == -1:
791
+ max_cache_length = None
792
+ except Exception:
793
+ max_cache_length = None
794
+ else:
795
+ cache_length = past_length = past_key_values[0][0].shape[2]
796
+ max_cache_length = None
797
+
798
+ # Trim input_ids to only include unprocessed tokens
799
+ if attention_mask is not None and attention_mask.shape[1] > (
800
+ input_ids.shape[1] // self.config.input_token_len
801
+ ):
802
+ input_ids = input_ids[
803
+ :, -(attention_mask.shape[1] - past_length) * self.config.input_token_len:
804
+ ]
805
+ elif past_length < (input_ids.shape[1] // self.config.input_token_len):
806
+ input_ids = input_ids[:, past_length * self.config.input_token_len:]
807
+
808
+ if (
809
+ max_cache_length is not None
810
+ and attention_mask is not None
811
+ and cache_length + (input_ids.shape[1] // self.config.input_token_len) > max_cache_length
812
+ ):
813
+ attention_mask = attention_mask[:, -max_cache_length:]
814
+
815
+ position_ids = kwargs.get("position_ids", None)
816
+ if attention_mask is not None and position_ids is None:
817
+ position_ids = attention_mask.long().cumsum(-1) - 1
818
+ position_ids.masked_fill_(attention_mask == 0, 1)
819
+ if past_length > 0:
820
+ position_ids = position_ids[:, -(input_ids.shape[1] // self.config.input_token_len):]
821
+
822
+ if inputs_embeds is not None and past_key_values is None:
823
+ model_inputs = {"inputs_embeds": inputs_embeds}
824
+ else:
825
+ model_inputs = {"input_ids": input_ids}
826
+
827
+ model_inputs.update({
828
+ "position_ids": position_ids,
829
+ "past_key_values": past_key_values,
830
+ "use_cache": kwargs.get("use_cache"),
831
+ "attention_mask": attention_mask,
832
+ "revin": revin,
833
+ "full_input_ids": full_input_ids,
834
+ "full_hidden_states": kwargs.get("full_hidden_states"),
835
+ })
836
+ return model_inputs
ts_generation_mixin.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http:www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from typing import Any, Dict, List, Optional, Union, Callable
17
+ import torch
18
+ from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
19
+ from transformers.generation import validate_stopping_criteria, EosTokenCriteria
20
+ from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerationConfig, GenerateOutput
21
+ from transformers.utils import ModelOutput
22
+
23
+ ALL_CACHE_NAMES = [
24
+ "past_key_values", # default
25
+ "cache_params", # mamba-based models
26
+ "state", # rwkv
27
+ "mems", # xlnet
28
+ "past_buckets_states", # reformer
29
+ ]
30
+
31
+ class TSGenerationMixin(GenerationMixin):
32
+ @torch.no_grad()
33
+ def generate(
34
+ self,
35
+ inputs: Optional[torch.Tensor] = None,
36
+ generation_config: Optional[GenerationConfig] = None,
37
+ logits_processor: Optional[LogitsProcessorList] = None,
38
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
39
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
40
+ synced_gpus: Optional[bool] = None,
41
+ assistant_model: Optional["PreTrainedModel"] = None,
42
+ streamer: Optional["BaseStreamer"] = None,
43
+ negative_prompt_ids: Optional[torch.Tensor] = None,
44
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
45
+ revin: Optional[bool] = True,
46
+ **kwargs,
47
+ ) -> Union[GenerateOutput, torch.LongTensor]:
48
+ if len(inputs.shape) != 2:
49
+ raise ValueError('Input shape must be: [batch_size, seq_len]')
50
+ if revin:
51
+ means = inputs.mean(dim=-1, keepdim=True)
52
+ stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
53
+ inputs = (inputs - means) / stdev
54
+ outputs = super().generate(
55
+ inputs=inputs,
56
+ generation_config=generation_config,
57
+ logits_processor=logits_processor,
58
+ stopping_criteria=stopping_criteria,
59
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
60
+ synced_gpus=synced_gpus,
61
+ assistant_model=assistant_model,
62
+ streamer=streamer,
63
+ negative_prompt_ids=negative_prompt_ids,
64
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
65
+ **kwargs,
66
+ )
67
+ if revin:
68
+ stdev = stdev.unsqueeze(1)
69
+ means = means.unsqueeze(1)
70
+ outputs = (outputs * stdev) + means
71
+ return outputs
72
+
73
+ def _sample(
74
+ self,
75
+ input_ids: torch.Tensor,
76
+ logits_processor: Optional[LogitsProcessorList] = None,
77
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
78
+ max_length: Optional[int] = None,
79
+ pad_token_id: Optional[int] = None,
80
+ eos_token_id: Optional[Union[int, List[int]]] = None,
81
+ output_attentions: Optional[bool] = None,
82
+ output_hidden_states: Optional[bool] = None,
83
+ output_scores: Optional[bool] = None,
84
+ output_logits: Optional[bool] = None,
85
+ return_dict_in_generate: Optional[bool] = None,
86
+ synced_gpus: bool = False,
87
+ streamer: Optional["BaseStreamer"] = None,
88
+ **model_kwargs,
89
+ ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
90
+ input_ids = input_ids.to(self.device)
91
+ batch_size, cur_len = input_ids.shape
92
+ # init values
93
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
94
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
95
+ if max_length is not None:
96
+ warnings.warn(
97
+ "`max_length` is deprecated in this function, use"
98
+ " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
99
+ UserWarning,
100
+ )
101
+ stopping_criteria = validate_stopping_criteria(
102
+ stopping_criteria, max_length)
103
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
104
+ if eos_token_id is not None:
105
+ stopping_criteria.append(
106
+ EosTokenCriteria(eos_token_id=eos_token_id))
107
+ else:
108
+ # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
109
+ eos_token_id = [
110
+ criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
111
+ ]
112
+ eos_token_id = eos_token_id[0] if eos_token_id else None
113
+ if eos_token_id is None and self.generation_config.eos_token_id is not None:
114
+ eos_token_id = self.generation_config.eos_token_id
115
+ stopping_criteria.append(
116
+ EosTokenCriteria(eos_token_id=eos_token_id))
117
+
118
+ if isinstance(eos_token_id, int):
119
+ eos_token_id = [eos_token_id]
120
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
121
+ output_attentions = (
122
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
123
+ )
124
+ output_hidden_states = (
125
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
126
+ )
127
+ return_dict_in_generate = (
128
+ return_dict_in_generate
129
+ if return_dict_in_generate is not None
130
+ else self.generation_config.return_dict_in_generate
131
+ )
132
+
133
+ # init attention / hidden states / scores tuples
134
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
135
+ scores = () if (return_dict_in_generate and output_scores) else None
136
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
137
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
138
+ decoder_hidden_states = () if (
139
+ return_dict_in_generate and output_hidden_states) else None
140
+
141
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
142
+ if return_dict_in_generate and self.config.is_encoder_decoder:
143
+ encoder_attentions = model_kwargs["encoder_outputs"].get(
144
+ "attentions") if output_attentions else None
145
+ encoder_hidden_states = (
146
+ model_kwargs["encoder_outputs"].get(
147
+ "hidden_states") if output_hidden_states else None
148
+ )
149
+
150
+ # keep track of which sequences are already finished
151
+ if "inputs_embeds" in model_kwargs:
152
+ cur_len = model_kwargs["inputs_embeds"].shape[1]
153
+ this_peer_finished = False
154
+ unfinished_sequences = torch.ones(
155
+ batch_size, dtype=torch.long, device=input_ids.device)
156
+ model_kwargs["cache_position"] = torch.arange(
157
+ cur_len, device=input_ids.device)
158
+ true_seq_len = (cur_len + self.config.input_token_len - 1) // self.config.input_token_len
159
+ model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:]
160
+ max_length = stopping_criteria.max_length
161
+
162
+ generate_results = None
163
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
164
+ # prepare model inputs
165
+ model_inputs = self.prepare_inputs_for_generation(
166
+ input_ids, **model_kwargs)
167
+
168
+ input_length = input_ids.shape[1]
169
+
170
+ # forward pass to get next token
171
+ outputs = self(
172
+ **model_inputs,
173
+ return_dict=True,
174
+ output_attentions=output_attentions,
175
+ output_hidden_states=output_hidden_states,
176
+ max_output_length=max_length - input_length,
177
+ )
178
+
179
+ if synced_gpus and this_peer_finished:
180
+ continue # don't waste resources running the code we don't need
181
+ next_token_logits = outputs.logits
182
+
183
+ # pre-process distribution
184
+ next_tokens_scores = logits_processor(input_ids, next_token_logits)
185
+
186
+ # Store scores, attentions and hidden_states when required
187
+ if return_dict_in_generate:
188
+ if output_scores:
189
+ scores += (next_tokens_scores,)
190
+ if output_logits:
191
+ raw_logits += (next_token_logits,)
192
+ if output_attentions:
193
+ decoder_attentions += (
194
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (
195
+ outputs.attentions,)
196
+ )
197
+ if self.config.is_encoder_decoder:
198
+ cross_attentions += (outputs.cross_attentions,)
199
+
200
+ if output_hidden_states:
201
+ decoder_hidden_states += (
202
+ (outputs.decoder_hidden_states,)
203
+ if self.config.is_encoder_decoder
204
+ else (outputs.hidden_states,)
205
+ )
206
+
207
+ # argmax
208
+ # next_tokens = torch.argmax(next_tokens_scores, dim=-1)
209
+ next_tokens = next_tokens_scores
210
+
211
+ # finished sentences should have their next token be a padding token
212
+ if eos_token_id is not None:
213
+ if pad_token_id is None:
214
+ raise ValueError(
215
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
216
+ next_tokens = next_tokens * unfinished_sequences + \
217
+ pad_token_id * (1 - unfinished_sequences)
218
+
219
+ # update generated ids, model inputs, and length for next step
220
+ horizon_length = next_tokens.shape[-1] // self.config.input_token_len
221
+
222
+ past_key_values = model_kwargs.get("past_key_values")
223
+ if generate_results is None:
224
+ generate_results = next_tokens
225
+ else:
226
+ generate_results = torch.cat([generate_results, next_tokens], dim=-1)
227
+
228
+ # Use deterministic approach instead of median to avoid CUDA deterministic algorithm issues
229
+ # For flow models, use torch.quantile(p=0.5) which is equivalent to median but deterministic
230
+
231
+ selected_tokens = torch.quantile(next_tokens.float(), q=0.5, dim=1)
232
+ input_ids = torch.cat([input_ids, selected_tokens], dim=-1)
233
+
234
+ if streamer is not None:
235
+ streamer.put(next_tokens.cpu())
236
+ model_kwargs = self._update_model_kwargs_for_generation(
237
+ outputs,
238
+ model_kwargs,
239
+ horizon_length=horizon_length,
240
+ is_encoder_decoder=self.config.is_encoder_decoder,
241
+ )
242
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(
243
+ input_ids, scores)
244
+ this_peer_finished = unfinished_sequences.max() == 0
245
+
246
+ if input_ids.shape[-1] > max_length:
247
+ input_ids = input_ids[:, :max_length]
248
+
249
+ if streamer is not None:
250
+ streamer.end()
251
+
252
+ if return_dict_in_generate:
253
+ if self.config.is_encoder_decoder:
254
+ return GenerateEncoderDecoderOutput(
255
+ sequences=input_ids,
256
+ scores=scores,
257
+ logits=raw_logits,
258
+ encoder_attentions=encoder_attentions,
259
+ encoder_hidden_states=encoder_hidden_states,
260
+ decoder_attentions=decoder_attentions,
261
+ cross_attentions=cross_attentions,
262
+ decoder_hidden_states=decoder_hidden_states,
263
+ past_key_values=model_kwargs.get("past_key_values"),
264
+ )
265
+ else:
266
+ return GenerateDecoderOnlyOutput(
267
+ sequences=input_ids,
268
+ scores=scores,
269
+ logits=raw_logits,
270
+ attentions=decoder_attentions,
271
+ hidden_states=decoder_hidden_states,
272
+ past_key_values=model_kwargs.get("past_key_values"),
273
+ )
274
+ else:
275
+ return generate_results[:, :, :(max_length - cur_len)]
276
+
277
+ def _update_model_kwargs_for_generation(
278
+ self,
279
+ outputs: ModelOutput,
280
+ model_kwargs: Dict[str, Any],
281
+ horizon_length: int = 1,
282
+ is_encoder_decoder: bool = False,
283
+ standardize_cache_format: bool = False,
284
+ ) -> Dict[str, Any]:
285
+ # update past_key_values
286
+ for possible_cache_name in ALL_CACHE_NAMES:
287
+ if possible_cache_name in outputs:
288
+ if possible_cache_name in ("past_buckets_states", "mems"):
289
+ cache_name = "past_key_values"
290
+ else:
291
+ cache_name = possible_cache_name
292
+ model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
293
+ break
294
+
295
+ # update token_type_ids with last value
296
+ if "token_type_ids" in model_kwargs:
297
+ token_type_ids = model_kwargs["token_type_ids"]
298
+ model_kwargs["token_type_ids"] = torch.cat(
299
+ [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
300
+
301
+ if not is_encoder_decoder:
302
+ # update attention mask
303
+ if "attention_mask" in model_kwargs:
304
+ attention_mask = model_kwargs["attention_mask"]
305
+ model_kwargs["attention_mask"] = torch.cat(
306
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], horizon_length))], dim=-1
307
+ )
308
+ else:
309
+ # update decoder attention mask
310
+ if "decoder_attention_mask" in model_kwargs:
311
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
312
+ model_kwargs["decoder_attention_mask"] = torch.cat(
313
+ [decoder_attention_mask, decoder_attention_mask.new_ones(
314
+ (decoder_attention_mask.shape[0], horizon_length))],
315
+ dim=-1,
316
+ )
317
+
318
+ if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
319
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
320
+
321
+ # update full_hidden_states: accumulate hidden states across generation steps for MTP layers
322
+ if hasattr(outputs, "hidden_states_for_mtp") and outputs.hidden_states_for_mtp is not None:
323
+ new_hs = outputs.hidden_states_for_mtp
324
+ if "full_hidden_states" in model_kwargs and model_kwargs["full_hidden_states"] is not None:
325
+ existing = model_kwargs["full_hidden_states"]
326
+ model_kwargs["full_hidden_states"] = torch.cat(
327
+ [existing.to(new_hs.device), new_hs], dim=1
328
+ )
329
+ else:
330
+ model_kwargs["full_hidden_states"] = new_hs
331
+
332
+ return model_kwargs