guxinhao commited on
Commit
d041ad0
·
verified ·
1 Parent(s): bdc72f3

Upload 10 files

Browse files
README.md CHANGED
@@ -1,3 +1,128 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ metrics:
4
+ - mse
5
+ - mae
6
+ - mase
7
+ - wql
8
+ - crps
9
+ pipeline_tag: time-series-forecasting
10
+ datasets:
11
+ - thuml/UTSD
12
+ - Salesforce/lotsa_data
13
+ - Salesforce/GiftEvalPretrain
14
+ - autogluon/chronos_datasets
15
+ tags:
16
+ - time series
17
+ - time-series
18
+ - forecasting
19
+ - foundation models
20
+ - pretrained models
21
+ - time series foundation models
22
+ library_name: transformers
23
+ ---
24
+
25
+ # Timer-S1
26
+
27
+ Timer-S1 is a time series foundation model with **8.3B** total parameters, **0.75B** activated parameters per token, and a context length of **11,520**.
28
+
29
+ The model supports **zero-shot forecasting** (predicting without dataset-specific training) at different quantile levels.
30
+
31
+ For more details, please refer to our [technical report](https://arxiv.org/pdf/2603.04791).
32
+
33
+ ![image](https://cdn-uploads.huggingface.co/production/uploads/64fbe24a2d20ced4e91de38a/7Udz1nO2V1Nk0pw5cW4gG.png)
34
+
35
+ **Architecture**: Timer-S1 is a decoder-only Mixture-of-Experts (MoE) Transformer. For time series forecasting (a sequential problem where each step depends on previous ones), we propose **TimeSTP**, enabling multi-step prediction with cost-effective **serial computations**.
36
+ ![image](https://cdn-uploads.huggingface.co/production/uploads/64fbe24a2d20ced4e91de38a/1XsUZDPw8DJebZwH-Ievd.png)
37
+
38
+ **Performance**: Timer-S1 achieves state-of-the-art results on [GIFT-Eval](https://huggingface.co/spaces/Salesforce/GIFT-Eval). The model excels particularly at **medium-term** and **long-term** forecasting tasks.
39
+
40
+ ![image](https://cdn-uploads.huggingface.co/production/uploads/64fbe24a2d20ced4e91de38a/XDOekWBIGBoc8nTDI-WBI.png)
41
+
42
+ ![image](https://cdn-uploads.huggingface.co/production/uploads/64fbe24a2d20ced4e91de38a/r7eGVKBIRI8h7lMre4-lP.png)
43
+
44
+ **Post Training**: Timer-S1 undergoes post-training, including continued pre-training (**CPT**) and long-context extension (**LCE**), which improves short-term and long-context performance.
45
+
46
+ ![image](https://cdn-uploads.huggingface.co/production/uploads/69ce7cea1430d60211285e20/9KqUVPPkA6DMr_EnhpD_O.png)
47
+
48
+
49
+ ## Quickstart
50
+
51
+ ```
52
+ pip install torch accelerate transformers~=4.57.1
53
+ ```
54
+
55
+ ```python
56
+ import torch
57
+ from transformers import AutoModelForCausalLM
58
+
59
+ # load pretrain model
60
+ # supports different lookback/forecast lengths
61
+ model = AutoModelForCausalLM.from_pretrained(
62
+ 'bytedance-research/Timer-S1',
63
+ trust_remote_code=True,
64
+ device_map="auto"
65
+ )
66
+
67
+ # use local model
68
+ # model = AutoModelForCausalLM.from_pretrained(
69
+ # 'path_to_timer_s1',
70
+ # trust_remote_code=True,
71
+ # device_map="auto"
72
+ # )
73
+
74
+ # prepare input
75
+ batch_size, lookback_length = 64, 11520
76
+ seqs = torch.randn(batch_size, lookback_length).to(model.device)
77
+
78
+ # Note that Timer-S1 generates predictions at fixed quantile levels
79
+ forecast_length = 256
80
+
81
+ output = model.generate(seqs, max_new_tokens=forecast_length, revin=True)
82
+
83
+ # produce quantile forecasts in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
84
+ print(output.shape) # batch_size x quantile_num(9) x forecast_length
85
+
86
+ # produce the median forecast of the first sample
87
+ print(output[0][4])
88
+ ```
89
+
90
+
91
+ This model support inference using either CPU or GPU. To load this model on GPU, we recommend a GPU with **at least 40GB VRAM** (e.g., A100 40GB/80GB, or H100).
92
+
93
+ > **Encounter out-of-memory at runtime?** Try the following options:
94
+ > ```python
95
+ > # Option 1: reduce batch size or context length
96
+ > batch_size, lookback_length = 1, 2880
97
+ >
98
+ > # Option 2: disable KV cache at runtime (or edit it in config.json for a permanent change)
99
+ > model.config.use_cache = False # there is no efficiency impact for cases where the prediction horizon does not exceed 256.
100
+ > ```
101
+
102
+ ## Specification
103
+
104
+ * **Architecture**: decoder-only Transformer with MoE
105
+ * **Context Length**: up to 11,520
106
+ * **ReNorm**: default=True
107
+ * **KV Cache**: default=True
108
+ * **Patch Length**: 16
109
+ * **Total Parameters**: 8.3B
110
+ * **Activated Parameters**: 0.75B
111
+ * **Number of Layers**: 40
112
+
113
+
114
+ ## License Agreement
115
+
116
+ This model is licensed under the Apache-2.0 License.
117
+
118
+ ## Citation
119
+
120
+ If you find Timer-S1 helpful for your research, please cite our paper:
121
+ ```
122
+ @article{liu2026timer,
123
+ title={Timer-S1: A Billion-Scale Time Series Foundation Model with Serial Scaling},
124
+ author={Liu, Yong and Su, Xingjian and Wang, Shiyu and Zhang, Haoran and Liu, Haixuan and Wang, Yuxuan and Ye, Zhou and Xiang, Yang and Wang, Jianmin and Long, Mingsheng},
125
+ journal={arXiv preprint arXiv:2603.04791},
126
+ year={2026}
127
+ }
128
+ ```
config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Timer-S1"
4
+ ],
5
+ "dropout_rate": 0.1,
6
+ "hidden_act": "silu",
7
+ "hidden_size": 1024,
8
+ "initializer_range": 0.02,
9
+ "input_token_len": 16,
10
+ "intermediate_size": 4096,
11
+ "max_position_embeddings": 12800,
12
+ "model_type": "Timer-S1",
13
+ "auto_map": {
14
+ "AutoConfig": "configuration_TimerS1.TimerS1Config",
15
+ "AutoModelForCausalLM": "modeling_TimerS1.TimerS1ForPrediction"
16
+ },
17
+ "num_attention_heads": 16,
18
+ "num_experts": 32,
19
+ "num_experts_per_token": 2,
20
+ "num_hidden_layers": 24,
21
+ "num_mtp_tokens": 16,
22
+ "output_token_lens": [
23
+ 16
24
+ ],
25
+ "quantiles": [
26
+ 0.1,
27
+ 0.2,
28
+ 0.3,
29
+ 0.4,
30
+ 0.5,
31
+ 0.6,
32
+ 0.7,
33
+ 0.8,
34
+ 0.9
35
+ ],
36
+ "rope_theta": 10000,
37
+ "torch_dtype": "bfloat16",
38
+ "use_cache": true
39
+ }
configuration_TimerS1.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http:www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List
16
+ from transformers import PretrainedConfig
17
+
18
+
19
+ class TimerS1Config(PretrainedConfig):
20
+ model_type = "Timer-S1"
21
+ keys_to_ignore_at_inference = ["past_key_values"]
22
+
23
+ def __init__(
24
+ self,
25
+ input_token_len: int = 16,
26
+ hidden_size: int = 1024,
27
+ intermediate_size: int = 4096,
28
+ output_token_lens: List[int] = [16],
29
+ num_hidden_layers: int = 24,
30
+ num_attention_heads: int = 16,
31
+ hidden_act: str = "silu",
32
+ use_cache: bool = True,
33
+ rope_theta: int = 10000,
34
+ dropout_rate: float = 0.1,
35
+ initializer_range: float = 0.02,
36
+ max_position_embeddings: int = 12800,
37
+ quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
38
+ num_experts: int = 32,
39
+ num_experts_per_token: int = 2,
40
+ # MTP configuration
41
+ num_mtp_tokens: int = 16,
42
+ **kwargs,
43
+ ):
44
+ self.input_token_len = input_token_len
45
+ self.hidden_size = hidden_size
46
+ self.intermediate_size = intermediate_size
47
+ self.num_hidden_layers = num_hidden_layers
48
+ self.num_attention_heads = num_attention_heads
49
+ self.hidden_act = hidden_act
50
+ self.output_token_lens = output_token_lens
51
+ self.use_cache = use_cache
52
+ self.rope_theta = rope_theta
53
+ self.dropout_rate = dropout_rate
54
+ self.initializer_range = initializer_range
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.quantiles = quantiles
57
+ self.num_experts = num_experts
58
+ self.num_experts_per_token = num_experts_per_token
59
+ # MTP configuration
60
+ self.num_mtp_tokens = num_mtp_tokens
61
+ super().__init__(**kwargs)
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8735f5cca42b847670cfc1a03118f092d3cd49150787a0254e863b05e892022c
3
+ size 4999034896
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d44c1ccffbce8263b98c5f9f5ac8027078d7166ffb13b7a55f1934f6b0c4c370
3
+ size 4999220272
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6859deeac01a834733852c2d97f47e653137a13d37cffdc7e33633a50ec3dda7
3
+ size 4996606248
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce442e869597b171db455ad31868722dbff0a1bd88ed77ccb21c5a4438a2edd3
3
+ size 1613025584
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_TimerS1.py ADDED
@@ -0,0 +1,837 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http:www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Tuple, List, Union
16
+ import math
17
+ from dataclasses import dataclass
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.nn.functional as F
22
+ from transformers import PreTrainedModel, Cache, DynamicCache
23
+ from transformers.activations import ACT2FN
24
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
25
+ from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
26
+
27
+ from .configuration_TimerS1 import TimerS1Config
28
+ from .ts_generation_mixin import TSGenerationMixin
29
+
30
+
31
+ @dataclass
32
+ class TimerS1CausalLMOutput(MoeCausalLMOutputWithPast):
33
+ """Extends MoeCausalLMOutputWithPast with hidden_states_for_mtp as a proper dataclass field
34
+ so it is reliably registered in the ModelOutput OrderedDict and accessible via attribute access."""
35
+ hidden_states_for_mtp: Optional[torch.FloatTensor] = None
36
+
37
+ def _get_usable_past_kv_length(cache: Cache, new_seq_length: int, layer_idx: int = 0) -> int:
38
+ """Compute the usable past length for the given cache and upcoming new sequence length.
39
+
40
+ This mirrors the previous `get_usable_length(new_seq_length, layer_idx)` behavior that existed in
41
+ Transformers < 4.45, while being compatible with the new Cache API.
42
+ """
43
+ try:
44
+ previous_length = cache.get_seq_length(layer_idx)
45
+ # Dynamic layers return -1, static layers return an int
46
+ max_length = cache.get_max_cache_shape(layer_idx)
47
+ if max_length is not None and max_length != -1 and previous_length + new_seq_length > max_length:
48
+ return max_length - new_seq_length
49
+ return previous_length
50
+ except Exception:
51
+ # Best-effort fallback
52
+ return cache.get_seq_length(layer_idx) if hasattr(cache, "get_seq_length") else 0
53
+
54
+ @dataclass
55
+ class TempMoeModelOutputWithPast(MoeModelOutputWithPast):
56
+ last_hidden_state: torch.FloatTensor = None
57
+ past_key_values: Optional[
58
+ Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
59
+ ] = None
60
+ use_legacy_cache: Optional[bool] = None
61
+ past_key_values_length: Optional[int] = None
62
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
63
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
64
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
65
+
66
+ def rotate_half(x):
67
+ x1 = x[..., : x.shape[-1] // 2]
68
+ x2 = x[..., x.shape[-1] // 2:]
69
+ return torch.cat((-x2, x1), dim=-1)
70
+
71
+
72
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
73
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
74
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
75
+ q_embed = (q * cos) + (rotate_half(q) * sin)
76
+ k_embed = (k * cos) + (rotate_half(k) * sin)
77
+ return q_embed, k_embed
78
+
79
+ class RMSNorm(nn.Module):
80
+ def __init__(self, dim: int, eps: float = 1e-6):
81
+ super().__init__()
82
+ self.eps = eps
83
+ self.weight = nn.Parameter(torch.ones(dim))
84
+
85
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
86
+ rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
87
+ x_norm = x / (rms + self.eps)
88
+ return x_norm * self.weight
89
+
90
+ class ResidualBlock(nn.Module):
91
+ def __init__(self, config: TimerS1Config) -> None:
92
+ super().__init__()
93
+ self.out_dim = len(config.quantiles) * config.output_token_lens[-1]
94
+ self.dropout = nn.Dropout(config.dropout_rate)
95
+ self.hidden_layer = nn.Linear(config.hidden_size, config.hidden_size)
96
+ self.act = ACT2FN[config.hidden_act]
97
+ self.output_layer = nn.Linear(config.hidden_size, self.out_dim)
98
+ self.residual_layer = nn.Linear(config.hidden_size, self.out_dim)
99
+
100
+ def forward(self, x: torch.Tensor):
101
+ hid = self.act(self.hidden_layer(x))
102
+ out = self.dropout(self.output_layer(hid))
103
+ return out + self.residual_layer(x)
104
+
105
+
106
+ class TimerS1PatchEmbedding(nn.Module):
107
+ def __init__(self, config: TimerS1Config):
108
+ super().__init__()
109
+ self.dropout = nn.Dropout(config.dropout_rate)
110
+ self.hidden_layer = nn.Linear(config.input_token_len * 2, config.intermediate_size)
111
+ self.act = ACT2FN[config.hidden_act]
112
+ self.output_layer = nn.Linear(config.intermediate_size, config.hidden_size)
113
+ self.residual_layer = nn.Linear(config.input_token_len * 2, config.hidden_size)
114
+ self.input_token_len = config.input_token_len
115
+
116
+ def forward(self, x):
117
+ mask = torch.ones_like(x)
118
+ input_length = x.shape[-1]
119
+ padding_length = (self.input_token_len - (input_length % self.input_token_len)) % self.input_token_len
120
+ x = F.pad(x, (padding_length, 0))
121
+ mask = F.pad(mask, (padding_length, 0))
122
+ x = x.unfold(dimension=-1, size=self.input_token_len, step=self.input_token_len)
123
+ mask = mask.unfold(dimension=-1, size=self.input_token_len, step=self.input_token_len)
124
+ x = torch.cat([x, mask], dim=-1)
125
+ hid = self.act(self.hidden_layer(x))
126
+ out = self.dropout(self.output_layer(hid))
127
+ return out + self.residual_layer(x)
128
+
129
+
130
+ class TimerS1RotaryEmbedding(torch.nn.Module):
131
+ def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
132
+ super().__init__()
133
+ self.dim = dim
134
+ self.max_position_embeddings = max_position_embeddings
135
+ self.base = base
136
+ inv_freq = 1.0 / (
137
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)
138
+ )
139
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
140
+ self._set_cos_sin_cache(
141
+ seq_len=max_position_embeddings,
142
+ device=self.inv_freq.device,
143
+ dtype=torch.get_default_dtype(),
144
+ )
145
+
146
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
147
+ self.max_seq_len_cached = seq_len
148
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
149
+ freqs = torch.outer(t, self.inv_freq)
150
+ emb = torch.cat((freqs, freqs), dim=-1)
151
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
152
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
153
+
154
+ def forward(self, x, seq_len=None):
155
+ if seq_len > self.max_seq_len_cached:
156
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
157
+ return (
158
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
159
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
160
+ )
161
+
162
+ class TimerS1Attention(nn.Module):
163
+ def __init__(self, config: TimerS1Config, layer_idx: Optional[int] = None):
164
+ super().__init__()
165
+ self.layer_idx = layer_idx
166
+ self.hidden_size = config.hidden_size
167
+ self.num_heads = config.num_attention_heads
168
+ self.head_dim = self.hidden_size // self.num_heads
169
+ self.attention_dropout = config.dropout_rate
170
+
171
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
172
+ self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
173
+ self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
174
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
175
+
176
+ # QK-Norm learnable scales
177
+ self.q_scale = nn.Parameter(torch.ones(self.head_dim))
178
+ self.k_scale = nn.Parameter(torch.ones(self.head_dim))
179
+
180
+ # Attention output gate
181
+ self.gate_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
182
+
183
+ self.rotary_emb = TimerS1RotaryEmbedding(
184
+ self.head_dim,
185
+ max_position_embeddings=config.max_position_embeddings,
186
+ base=config.rope_theta,
187
+ )
188
+
189
+ def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
190
+ eps = 1e-6
191
+ q = q * torch.rsqrt(q.pow(2).mean(dim=-1, keepdim=True) + eps) * self.q_scale.view(1, 1, 1, -1)
192
+ k = k * torch.rsqrt(k.pow(2).mean(dim=-1, keepdim=True) + eps) * self.k_scale.view(1, 1, 1, -1)
193
+ return q, k
194
+
195
+ def forward(
196
+ self,
197
+ hidden_states: torch.Tensor,
198
+ attention_mask: Optional[torch.Tensor] = None,
199
+ position_ids: Optional[torch.LongTensor] = None,
200
+ past_key_value: Optional[Cache] = None,
201
+ output_attentions: bool = False,
202
+ **kwargs,
203
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
204
+ bsz, q_len, _ = hidden_states.size()
205
+
206
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
207
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
208
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
209
+
210
+ kv_seq_len = key_states.shape[-2]
211
+ if past_key_value is not None:
212
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
213
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
214
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
215
+
216
+ query_states, key_states = self._apply_qk_norm(query_states, key_states)
217
+
218
+ if past_key_value is not None:
219
+ key_states, value_states = past_key_value.update(
220
+ key_states, value_states, self.layer_idx)
221
+
222
+ attn_output = F.scaled_dot_product_attention(
223
+ query_states,
224
+ key_states,
225
+ value_states,
226
+ attention_mask,
227
+ dropout_p=(self.attention_dropout if self.training else 0.0),
228
+ ) # [bsz, num_heads, q_len, head_dim]
229
+
230
+ gate = torch.sigmoid(self.gate_proj(hidden_states))
231
+ gate = gate.view(bsz, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
232
+ attn_output = attn_output * gate
233
+
234
+ attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
235
+ attn_output = self.o_proj(attn_output)
236
+
237
+ attn_weights = None if not output_attentions else attn_output
238
+ return attn_output, attn_weights, past_key_value
239
+
240
+ class TimerS1MLP(nn.Module):
241
+ def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
242
+ super().__init__()
243
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
244
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
245
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
246
+ self.act_fn = ACT2FN[hidden_act]
247
+
248
+ def forward(self, hidden_state):
249
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
250
+
251
+ class TimerS1ExpertsLayer(nn.Module):
252
+ def __init__(self, config: TimerS1Config):
253
+ super().__init__()
254
+ self.top_k = config.num_experts_per_token
255
+ self.hidden_size = config.hidden_size
256
+ self.num_experts = config.num_experts
257
+ moe_intermediate_size = config.intermediate_size // self.top_k
258
+
259
+ self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
260
+ self.experts = nn.ModuleList([
261
+ TimerS1MLP(
262
+ hidden_size=config.hidden_size,
263
+ intermediate_size=moe_intermediate_size,
264
+ hidden_act=config.hidden_act,
265
+ )
266
+ for _ in range(self.num_experts)
267
+ ])
268
+
269
+ def forward(self, hidden_states: torch.Tensor):
270
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
271
+ hidden_states = hidden_states.view(-1, hidden_dim)
272
+ router_logits = self.gate(hidden_states)
273
+
274
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
275
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
276
+ routing_weights = routing_weights.to(hidden_states.dtype)
277
+
278
+ final_hidden_states = torch.zeros(
279
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
280
+ )
281
+
282
+ expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
283
+
284
+ for expert_idx in range(self.num_experts):
285
+ expert_layer = self.experts[expert_idx]
286
+ idx, top_x = torch.where(expert_mask[expert_idx])
287
+
288
+ if top_x.numel() == 0:
289
+ continue
290
+
291
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
292
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
293
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
294
+
295
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
296
+ return final_hidden_states
297
+
298
+ class TimerS1DecoderLayer(nn.Module):
299
+ def __init__(self, config: TimerS1Config, layer_idx: int):
300
+ super().__init__()
301
+ self.self_attn = TimerS1Attention(config, layer_idx)
302
+ self.ffn_layer = TimerS1ExpertsLayer(config)
303
+ self.norm1 = RMSNorm(config.hidden_size)
304
+ self.norm2 = RMSNorm(config.hidden_size)
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ attention_mask: Optional[torch.Tensor] = None,
310
+ position_ids: Optional[torch.LongTensor] = None,
311
+ past_key_value: Optional[Cache] = None,
312
+ output_attentions: Optional[bool] = False,
313
+ use_cache: Optional[bool] = False,
314
+ **kwargs,
315
+ ) -> Tuple[torch.FloatTensor, Optional[torch.Tensor], Optional[Cache]]:
316
+ residual = hidden_states
317
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
318
+ hidden_states=self.norm1(hidden_states),
319
+ attention_mask=attention_mask,
320
+ position_ids=position_ids,
321
+ past_key_value=past_key_value,
322
+ output_attentions=output_attentions,
323
+ )
324
+ hidden_states = residual + hidden_states
325
+
326
+ residual = hidden_states
327
+ hidden_states = self.ffn_layer(self.norm2(hidden_states))
328
+ hidden_states = residual + hidden_states
329
+
330
+ if not output_attentions:
331
+ self_attn_weights = None
332
+ if not use_cache:
333
+ present_key_value = None
334
+
335
+ return hidden_states, self_attn_weights, present_key_value
336
+
337
+
338
+ class TimerS1PreTrainedModel(PreTrainedModel):
339
+ config_class = TimerS1Config
340
+ base_model_prefix = "model"
341
+ supports_gradient_checkpointing = True
342
+ _no_split_modules = ["TimerS1DecoderLayer"]
343
+ _skip_keys_device_placement = "past_key_values"
344
+ _supports_flash_attn_2 = True
345
+ _supports_sdpa = False
346
+ _supports_cache_class = True
347
+
348
+ def _init_weights(self, module):
349
+ std = self.config.initializer_range
350
+ if isinstance(module, nn.Linear):
351
+ module.weight.data.normal_(mean=0.0, std=std)
352
+ if module.bias is not None:
353
+ module.bias.data.zero_()
354
+ elif isinstance(module, nn.Embedding):
355
+ module.weight.data.normal_(mean=0.0, std=std)
356
+ if module.padding_idx is not None:
357
+ module.weight.data[module.padding_idx].zero_()
358
+
359
+ class TimerS1Model(TimerS1PreTrainedModel):
360
+ def __init__(self, config: TimerS1Config):
361
+ super().__init__(config)
362
+ self.embed_layer = TimerS1PatchEmbedding(config)
363
+ self.layers = nn.ModuleList([
364
+ TimerS1DecoderLayer(config, layer_idx)
365
+ for layer_idx in range(config.num_hidden_layers)
366
+ ])
367
+ self.norm = RMSNorm(config.hidden_size)
368
+ self.gradient_checkpointing = False
369
+
370
+ def forward(
371
+ self,
372
+ input_ids: torch.FloatTensor = None,
373
+ attention_mask: Optional[torch.Tensor] = None,
374
+ position_ids: Optional[torch.LongTensor] = None,
375
+ past_key_values: Optional[
376
+ Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
377
+ ] = None,
378
+ inputs_embeds: Optional[torch.FloatTensor] = None,
379
+ use_cache: Optional[bool] = None,
380
+ output_attentions: Optional[bool] = None,
381
+ output_hidden_states: Optional[bool] = None,
382
+ return_dict: Optional[bool] = None,
383
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
384
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
385
+ output_hidden_states = (
386
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
387
+ )
388
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
389
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
390
+
391
+ if input_ids is not None and inputs_embeds is not None:
392
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
393
+ elif input_ids is not None:
394
+ batch_size, seq_length = input_ids.shape
395
+ elif inputs_embeds is not None:
396
+ batch_size, seq_length, _ = inputs_embeds.shape
397
+ else:
398
+ raise ValueError("You must specify either input_ids or inputs_embeds")
399
+
400
+ if inputs_embeds is None:
401
+ inputs_embeds = self.embed_layer(input_ids)
402
+ seq_length = inputs_embeds.shape[1]
403
+
404
+ if self.gradient_checkpointing and self.training and use_cache:
405
+ use_cache = False
406
+
407
+ past_key_values_length = 0
408
+ use_legacy_cache = None
409
+ if use_cache:
410
+ use_legacy_cache = not isinstance(past_key_values, Cache)
411
+ if use_legacy_cache:
412
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
413
+ past_key_values_length = _get_usable_past_kv_length(past_key_values, seq_length)
414
+
415
+ if position_ids is None:
416
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
417
+ position_ids = torch.arange(
418
+ past_key_values_length, seq_length + past_key_values_length,
419
+ dtype=torch.long, device=device,
420
+ ).view(-1, seq_length)
421
+ else:
422
+ position_ids = position_ids.view(-1, seq_length).long()
423
+
424
+ attention_mask = _prepare_4d_causal_attention_mask(
425
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, sliding_window=None,
426
+ )
427
+
428
+ hidden_states = inputs_embeds
429
+
430
+ all_hidden_states = () if output_hidden_states else None
431
+ all_self_attns = () if output_attentions else None
432
+ all_moe_losses = []
433
+
434
+ for decoder_layer in self.layers:
435
+ if output_hidden_states:
436
+ all_hidden_states += (hidden_states,)
437
+
438
+ layer_outputs = decoder_layer(
439
+ hidden_states,
440
+ attention_mask=attention_mask,
441
+ position_ids=position_ids,
442
+ past_key_value=past_key_values,
443
+ output_attentions=output_attentions,
444
+ use_cache=use_cache,
445
+ )
446
+
447
+ hidden_states = layer_outputs[0]
448
+
449
+ if output_attentions:
450
+ all_self_attns += (layer_outputs[1],)
451
+
452
+ hidden_states = self.norm(hidden_states)
453
+ if output_hidden_states:
454
+ all_hidden_states += (hidden_states,)
455
+
456
+ if not return_dict:
457
+ return tuple(
458
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_moe_losses]
459
+ if v is not None
460
+ )
461
+
462
+ return TempMoeModelOutputWithPast(
463
+ last_hidden_state=hidden_states,
464
+ past_key_values=past_key_values,
465
+ hidden_states=all_hidden_states,
466
+ attentions=all_self_attns,
467
+ use_legacy_cache=use_legacy_cache,
468
+ past_key_values_length=past_key_values_length,
469
+ router_logits=all_moe_losses,
470
+ )
471
+
472
+ class TimerS1MTPLayer(nn.Module):
473
+ def __init__(self, config: TimerS1Config, layer_idx: int):
474
+ super().__init__()
475
+ self.hidden_size = config.hidden_size
476
+ self.config = config
477
+ self.layer_idx = layer_idx
478
+ self.norm_hidden = RMSNorm(config.hidden_size)
479
+ self.norm_embeds = RMSNorm(config.hidden_size)
480
+ self.projection_matrix = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
481
+ self.layer = TimerS1DecoderLayer(config, self.layer_idx + self.config.num_hidden_layers)
482
+ self.norm = RMSNorm(config.hidden_size)
483
+ self.gradient_checkpointing = False
484
+
485
+ def forward(
486
+ self,
487
+ hidden_states: torch.FloatTensor = None,
488
+ attention_mask: Optional[torch.Tensor] = None,
489
+ position_ids: Optional[torch.LongTensor] = None,
490
+ past_key_values: Optional[
491
+ Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
492
+ ] = None,
493
+ use_legacy_cache: Optional[bool] = False,
494
+ past_key_values_length: Optional[int] = 0,
495
+ inputs_embeds: Optional[torch.FloatTensor] = None,
496
+ use_cache: Optional[bool] = None,
497
+ output_attentions: Optional[bool] = None,
498
+ output_hidden_states: Optional[bool] = None,
499
+ return_dict: Optional[bool] = None,
500
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
501
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
502
+ output_hidden_states = (
503
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
504
+ )
505
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
506
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
507
+
508
+ if inputs_embeds is not None:
509
+ batch_size, seq_length, _ = inputs_embeds.shape
510
+ else:
511
+ raise ValueError("You must specify inputs_embeds")
512
+
513
+ if self.gradient_checkpointing and self.training:
514
+ if use_cache:
515
+ use_cache = False
516
+
517
+ if position_ids is None:
518
+ device = inputs_embeds.device
519
+ position_ids = torch.arange(
520
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
521
+ )
522
+ position_ids = position_ids.view(-1, seq_length)
523
+ else:
524
+ position_ids = position_ids.view(-1, seq_length).long()
525
+
526
+ attention_mask = _prepare_4d_causal_attention_mask(
527
+ attention_mask,
528
+ (batch_size, seq_length),
529
+ inputs_embeds,
530
+ past_key_values_length,
531
+ sliding_window=None,
532
+ )
533
+
534
+ hidden_states = self.norm_hidden(hidden_states)
535
+ inputs_embeds = self.norm_embeds(inputs_embeds)
536
+ hidden_states = self.projection_matrix(torch.cat([hidden_states, inputs_embeds], dim=-1))
537
+
538
+ all_hidden_states = () if output_hidden_states else None
539
+ all_self_attns = () if output_attentions else None
540
+ all_moe_losses = []
541
+ next_decoder_cache = None
542
+
543
+ if output_hidden_states:
544
+ all_hidden_states += (hidden_states,)
545
+
546
+ if self.gradient_checkpointing and self.training:
547
+ layer_outputs = self._gradient_checkpointing_func(
548
+ self.layer.__call__,
549
+ hidden_states,
550
+ attention_mask,
551
+ position_ids,
552
+ past_key_values,
553
+ output_attentions,
554
+ use_cache,
555
+ )
556
+ else:
557
+ layer_outputs = self.layer(
558
+ hidden_states,
559
+ attention_mask=attention_mask,
560
+ position_ids=position_ids,
561
+ past_key_value=past_key_values,
562
+ output_attentions=output_attentions,
563
+ use_cache=use_cache,
564
+ )
565
+
566
+ hidden_states = layer_outputs[0]
567
+
568
+ if output_attentions:
569
+ all_self_attns += (layer_outputs[1],)
570
+
571
+ if use_cache:
572
+ next_decoder_cache = layer_outputs[2]
573
+
574
+ hidden_states = self.norm(hidden_states)
575
+
576
+ if output_hidden_states:
577
+ all_hidden_states += (hidden_states,)
578
+
579
+ next_cache = None
580
+ if use_cache:
581
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
582
+
583
+ if not return_dict:
584
+ return tuple(
585
+ v
586
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_moe_losses]
587
+ if v is not None
588
+ )
589
+ return MoeModelOutputWithPast(
590
+ last_hidden_state=hidden_states,
591
+ past_key_values=next_cache,
592
+ hidden_states=all_hidden_states,
593
+ attentions=all_self_attns,
594
+ router_logits=all_moe_losses,
595
+ )
596
+
597
+ class TimerS1ForPrediction(TimerS1PreTrainedModel, TSGenerationMixin):
598
+ def __init__(self, config: TimerS1Config):
599
+ super().__init__(config)
600
+ self.config = config
601
+ self.model = TimerS1Model(self.config)
602
+ self.output_patch_embedding = ResidualBlock(config)
603
+ self.num_quantiles = len(config.quantiles)
604
+ if self.config.num_mtp_tokens > 0:
605
+ self.mtp_modules = nn.ModuleList([
606
+ TimerS1MTPLayer(config, layer_idx)
607
+ for layer_idx in range(self.config.num_mtp_tokens)
608
+ ])
609
+ self.post_init()
610
+
611
+ def set_decoder(self, decoder):
612
+ self.model = decoder
613
+
614
+ def get_decoder(self):
615
+ return self.model
616
+
617
+ def forward(
618
+ self,
619
+ input_ids: torch.FloatTensor = None,
620
+ attention_mask: Optional[torch.Tensor] = None,
621
+ position_ids: Optional[torch.LongTensor] = None,
622
+ past_key_values: Optional[
623
+ Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
624
+ ] = None,
625
+ inputs_embeds: Optional[torch.FloatTensor] = None,
626
+ full_input_ids: Optional[torch.FloatTensor] = None,
627
+ full_hidden_states: Optional[torch.FloatTensor] = None,
628
+ use_cache: Optional[bool] = None,
629
+ output_attentions: Optional[bool] = None,
630
+ output_hidden_states: Optional[bool] = None,
631
+ return_dict: Optional[bool] = None,
632
+ max_output_length: Optional[int] = None,
633
+ revin: Optional[bool] = False,
634
+ ) -> Union[Tuple, TimerS1CausalLMOutput]:
635
+
636
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
637
+ output_hidden_states = (
638
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
639
+ )
640
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
641
+
642
+ if revin:
643
+ means = input_ids.mean(1, keepdim=True).detach()
644
+ stdev = input_ids.std(dim=1, keepdim=True, unbiased=False).detach()
645
+ stdev = torch.where(stdev > 1e-2, stdev, torch.tensor(1.0, device=input_ids.device))
646
+ input_ids = (input_ids - means) / stdev
647
+ if full_input_ids is not None:
648
+ fi_means = full_input_ids.mean(1, keepdim=True).detach()
649
+ fi_stdev = full_input_ids.std(dim=1, keepdim=True, unbiased=False).detach()
650
+ fi_stdev = torch.where(
651
+ fi_stdev > 1e-2, fi_stdev, torch.tensor(1.0, device=full_input_ids.device)
652
+ )
653
+ full_input_ids = (full_input_ids - fi_means) / fi_stdev
654
+ if inputs_embeds is None and input_ids is not None:
655
+ inputs_embeds = self.model.embed_layer(input_ids)
656
+ # full_inputs_embeds: embeddings for the complete sequence used by MTP layers (no KV cache)
657
+ if full_input_ids is not None:
658
+ full_inputs_embeds = self.model.embed_layer(full_input_ids)
659
+ else:
660
+ full_inputs_embeds = inputs_embeds
661
+
662
+ outputs = self.model(
663
+ input_ids=None,
664
+ attention_mask=attention_mask,
665
+ position_ids=position_ids,
666
+ past_key_values=past_key_values,
667
+ inputs_embeds=inputs_embeds,
668
+ use_cache=use_cache,
669
+ output_attentions=output_attentions,
670
+ output_hidden_states=output_hidden_states,
671
+ return_dict=return_dict,
672
+ )
673
+
674
+ hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
675
+
676
+ # Accumulate full hidden states across generation steps for MTP layers.
677
+ # When KV cache is enabled, hidden_states only covers new tokens, so we need to
678
+ # prepend accumulated past hidden states (full_hidden_states) to restore the full
679
+ # sequence picture needed by MTP layers.
680
+ # When KV cache is disabled, hidden_states already covers the full sequence
681
+ # (same length as full_inputs_embeds), so no accumulation is needed.
682
+ if full_hidden_states is not None and hidden_states.shape[1] < full_inputs_embeds.shape[1]:
683
+ mtp_hidden_states = torch.cat([full_hidden_states.to(hidden_states.device), hidden_states], dim=1)
684
+ else:
685
+ mtp_hidden_states = hidden_states
686
+
687
+ bsz, L, _ = hidden_states.shape
688
+ predictions = None
689
+ loss = None
690
+ if max_output_length is None:
691
+ output_token_len = self.config.output_token_lens[0]
692
+ max_output_length = output_token_len
693
+ else:
694
+ output_token_len = self.config.output_token_lens[0]
695
+ for h in self.config.output_token_lens[1:]:
696
+ if h > max_output_length:
697
+ break
698
+ output_token_len = h
699
+
700
+ predictions = self.output_patch_embedding(hidden_states[:, -1, :]).reshape(
701
+ bsz, self.num_quantiles, self.config.output_token_lens[-1]
702
+ )
703
+
704
+ if self.config.num_mtp_tokens > 0:
705
+ output_patch_len = self.config.output_token_lens[-1]
706
+ full_out_len = output_patch_len + self.config.input_token_len * self.config.num_mtp_tokens
707
+
708
+ target_len = max(0, min(int(max_output_length), int(full_out_len)))
709
+
710
+ out = torch.zeros(bsz, self.num_quantiles, target_len, device=predictions.device)
711
+ base_fill = min(output_patch_len, target_len)
712
+ if base_fill > 0:
713
+ out[:, :, :base_fill] = predictions[:, :, :base_fill]
714
+
715
+ if target_len <= output_patch_len:
716
+ mtp_steps_needed = 0
717
+ else:
718
+ remaining = target_len - output_patch_len
719
+ mtp_steps_needed = min(
720
+ self.config.num_mtp_tokens,
721
+ math.ceil(remaining / self.config.input_token_len),
722
+ )
723
+
724
+ for k, mtp_module in enumerate(self.mtp_modules):
725
+ if k >= mtp_steps_needed:
726
+ break
727
+
728
+ start_pos = (k + 1) * self.config.input_token_len
729
+ if start_pos >= target_len:
730
+ break
731
+
732
+ mtp_full_len = full_inputs_embeds.shape[1]
733
+ mtp_attention_mask = attention_mask[:, -mtp_full_len:] if attention_mask is not None else None
734
+ mtp_outputs = mtp_module(
735
+ hidden_states=mtp_hidden_states,
736
+ inputs_embeds=full_inputs_embeds,
737
+ attention_mask=mtp_attention_mask,
738
+ output_attentions=output_attentions,
739
+ )
740
+ mtp_hidden_states = mtp_outputs[0]
741
+
742
+ mtp_pred = self.output_patch_embedding(mtp_hidden_states)[:, -1, :]
743
+ mtp_pred = mtp_pred.reshape(bsz, self.num_quantiles, output_patch_len)
744
+
745
+ end_pos = min(start_pos + output_patch_len, target_len)
746
+ take = end_pos - start_pos
747
+ if take > 0:
748
+ out[:, :, start_pos:end_pos] = mtp_pred[:, :, :take]
749
+
750
+ predictions = out
751
+
752
+ if max_output_length is not None and predictions.shape[-1] > max_output_length:
753
+ predictions = predictions[:, :, :max_output_length]
754
+ if revin:
755
+ predictions = predictions * stdev + means
756
+ if not return_dict:
757
+ output = (predictions,) + outputs[1:]
758
+ return (loss,) + output if loss is not None else output
759
+
760
+ return TimerS1CausalLMOutput(
761
+ loss=loss,
762
+ logits=predictions,
763
+ past_key_values=outputs.past_key_values,
764
+ hidden_states=outputs.hidden_states,
765
+ attentions=outputs.attentions,
766
+ router_logits=outputs.router_logits,
767
+ # Pass main-model hidden states as a proper field so that
768
+ # _update_model_kwargs_for_generation can reliably accumulate them
769
+ # for the MTP layers across multi-step generation.
770
+ hidden_states_for_mtp=hidden_states,
771
+ )
772
+
773
+ def prepare_inputs_for_generation(
774
+ self,
775
+ input_ids,
776
+ past_key_values=None,
777
+ attention_mask=None,
778
+ inputs_embeds=None,
779
+ revin=False,
780
+ **kwargs,
781
+ ):
782
+ # full_input_ids always holds the complete original sequence for MTP layers
783
+ full_input_ids = input_ids.clone()
784
+ past_length = 0
785
+ if past_key_values is not None:
786
+ if isinstance(past_key_values, Cache):
787
+ cache_length = past_key_values.get_seq_length(0)
788
+ past_length = cache_length
789
+ try:
790
+ max_cache_length = past_key_values.get_max_cache_shape(0)
791
+ if max_cache_length == -1:
792
+ max_cache_length = None
793
+ except Exception:
794
+ max_cache_length = None
795
+ else:
796
+ cache_length = past_length = past_key_values[0][0].shape[2]
797
+ max_cache_length = None
798
+
799
+ # Trim input_ids to only include unprocessed tokens
800
+ if attention_mask is not None and attention_mask.shape[1] > (
801
+ input_ids.shape[1] // self.config.input_token_len
802
+ ):
803
+ input_ids = input_ids[
804
+ :, -(attention_mask.shape[1] - past_length) * self.config.input_token_len:
805
+ ]
806
+ elif past_length < (input_ids.shape[1] // self.config.input_token_len):
807
+ input_ids = input_ids[:, past_length * self.config.input_token_len:]
808
+
809
+ if (
810
+ max_cache_length is not None
811
+ and attention_mask is not None
812
+ and cache_length + (input_ids.shape[1] // self.config.input_token_len) > max_cache_length
813
+ ):
814
+ attention_mask = attention_mask[:, -max_cache_length:]
815
+
816
+ position_ids = kwargs.get("position_ids", None)
817
+ if attention_mask is not None and position_ids is None:
818
+ position_ids = attention_mask.long().cumsum(-1) - 1
819
+ position_ids.masked_fill_(attention_mask == 0, 1)
820
+ if past_length > 0:
821
+ position_ids = position_ids[:, -(input_ids.shape[1] // self.config.input_token_len):]
822
+
823
+ if inputs_embeds is not None and past_key_values is None:
824
+ model_inputs = {"inputs_embeds": inputs_embeds}
825
+ else:
826
+ model_inputs = {"input_ids": input_ids}
827
+
828
+ model_inputs.update({
829
+ "position_ids": position_ids,
830
+ "past_key_values": past_key_values,
831
+ "use_cache": kwargs.get("use_cache"),
832
+ "attention_mask": attention_mask,
833
+ "revin": revin,
834
+ "full_input_ids": full_input_ids,
835
+ "full_hidden_states": kwargs.get("full_hidden_states"),
836
+ })
837
+ return model_inputs
ts_generation_mixin.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http:www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from typing import Any, Dict, List, Optional, Union, Callable
17
+ import torch
18
+ from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
19
+ from transformers.generation import validate_stopping_criteria, EosTokenCriteria
20
+ from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerationConfig, GenerateOutput
21
+ from transformers.utils import ModelOutput
22
+
23
+ ALL_CACHE_NAMES = [
24
+ "past_key_values", # default
25
+ "cache_params", # mamba-based models
26
+ "state", # rwkv
27
+ "mems", # xlnet
28
+ "past_buckets_states", # reformer
29
+ ]
30
+
31
+ class TSGenerationMixin(GenerationMixin):
32
+ @torch.no_grad()
33
+ def generate(
34
+ self,
35
+ inputs: Optional[torch.Tensor] = None,
36
+ generation_config: Optional[GenerationConfig] = None,
37
+ logits_processor: Optional[LogitsProcessorList] = None,
38
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
39
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
40
+ synced_gpus: Optional[bool] = None,
41
+ assistant_model: Optional["PreTrainedModel"] = None,
42
+ streamer: Optional["BaseStreamer"] = None,
43
+ negative_prompt_ids: Optional[torch.Tensor] = None,
44
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
45
+ revin: Optional[bool] = True,
46
+ **kwargs,
47
+ ) -> Union[GenerateOutput, torch.LongTensor]:
48
+ if len(inputs.shape) != 2:
49
+ raise ValueError('Input shape must be: [batch_size, seq_len]')
50
+ if revin:
51
+ means = inputs.mean(dim=-1, keepdim=True)
52
+ stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
53
+ inputs = (inputs - means) / stdev
54
+ outputs = super().generate(
55
+ inputs=inputs,
56
+ generation_config=generation_config,
57
+ logits_processor=logits_processor,
58
+ stopping_criteria=stopping_criteria,
59
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
60
+ synced_gpus=synced_gpus,
61
+ assistant_model=assistant_model,
62
+ streamer=streamer,
63
+ negative_prompt_ids=negative_prompt_ids,
64
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
65
+ **kwargs,
66
+ )
67
+ if revin:
68
+ stdev = stdev.unsqueeze(1)
69
+ means = means.unsqueeze(1)
70
+ outputs = (outputs * stdev) + means
71
+ return outputs
72
+
73
+ def _sample(
74
+ self,
75
+ input_ids: torch.Tensor,
76
+ logits_processor: Optional[LogitsProcessorList] = None,
77
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
78
+ max_length: Optional[int] = None,
79
+ pad_token_id: Optional[int] = None,
80
+ eos_token_id: Optional[Union[int, List[int]]] = None,
81
+ output_attentions: Optional[bool] = None,
82
+ output_hidden_states: Optional[bool] = None,
83
+ output_scores: Optional[bool] = None,
84
+ output_logits: Optional[bool] = None,
85
+ return_dict_in_generate: Optional[bool] = None,
86
+ synced_gpus: bool = False,
87
+ streamer: Optional["BaseStreamer"] = None,
88
+ **model_kwargs,
89
+ ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
90
+ input_ids = input_ids.to(self.device)
91
+ batch_size, cur_len = input_ids.shape
92
+ # init values
93
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
94
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
95
+ if max_length is not None:
96
+ warnings.warn(
97
+ "`max_length` is deprecated in this function, use"
98
+ " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
99
+ UserWarning,
100
+ )
101
+ stopping_criteria = validate_stopping_criteria(
102
+ stopping_criteria, max_length)
103
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
104
+ if eos_token_id is not None:
105
+ stopping_criteria.append(
106
+ EosTokenCriteria(eos_token_id=eos_token_id))
107
+ else:
108
+ # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
109
+ eos_token_id = [
110
+ criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
111
+ ]
112
+ eos_token_id = eos_token_id[0] if eos_token_id else None
113
+ if eos_token_id is None and self.generation_config.eos_token_id is not None:
114
+ eos_token_id = self.generation_config.eos_token_id
115
+ stopping_criteria.append(
116
+ EosTokenCriteria(eos_token_id=eos_token_id))
117
+
118
+ if isinstance(eos_token_id, int):
119
+ eos_token_id = [eos_token_id]
120
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
121
+ output_attentions = (
122
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
123
+ )
124
+ output_hidden_states = (
125
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
126
+ )
127
+ return_dict_in_generate = (
128
+ return_dict_in_generate
129
+ if return_dict_in_generate is not None
130
+ else self.generation_config.return_dict_in_generate
131
+ )
132
+
133
+ # init attention / hidden states / scores tuples
134
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
135
+ scores = () if (return_dict_in_generate and output_scores) else None
136
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
137
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
138
+ decoder_hidden_states = () if (
139
+ return_dict_in_generate and output_hidden_states) else None
140
+
141
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
142
+ if return_dict_in_generate and self.config.is_encoder_decoder:
143
+ encoder_attentions = model_kwargs["encoder_outputs"].get(
144
+ "attentions") if output_attentions else None
145
+ encoder_hidden_states = (
146
+ model_kwargs["encoder_outputs"].get(
147
+ "hidden_states") if output_hidden_states else None
148
+ )
149
+
150
+ # keep track of which sequences are already finished
151
+ if "inputs_embeds" in model_kwargs:
152
+ cur_len = model_kwargs["inputs_embeds"].shape[1]
153
+ this_peer_finished = False
154
+ unfinished_sequences = torch.ones(
155
+ batch_size, dtype=torch.long, device=input_ids.device)
156
+ model_kwargs["cache_position"] = torch.arange(
157
+ cur_len, device=input_ids.device)
158
+ true_seq_len = (cur_len + self.config.input_token_len - 1) // self.config.input_token_len
159
+ model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:]
160
+ max_length = stopping_criteria.max_length
161
+
162
+ generate_results = None
163
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
164
+ # prepare model inputs
165
+ model_inputs = self.prepare_inputs_for_generation(
166
+ input_ids, **model_kwargs)
167
+
168
+ input_length = input_ids.shape[1]
169
+
170
+ # forward pass to get next token
171
+ outputs = self(
172
+ **model_inputs,
173
+ return_dict=True,
174
+ output_attentions=output_attentions,
175
+ output_hidden_states=output_hidden_states,
176
+ max_output_length=max_length - input_length,
177
+ )
178
+
179
+ if synced_gpus and this_peer_finished:
180
+ continue # don't waste resources running the code we don't need
181
+ next_token_logits = outputs.logits
182
+
183
+ # pre-process distribution
184
+ next_tokens_scores = logits_processor(input_ids, next_token_logits)
185
+
186
+ # Store scores, attentions and hidden_states when required
187
+ if return_dict_in_generate:
188
+ if output_scores:
189
+ scores += (next_tokens_scores,)
190
+ if output_logits:
191
+ raw_logits += (next_token_logits,)
192
+ if output_attentions:
193
+ decoder_attentions += (
194
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (
195
+ outputs.attentions,)
196
+ )
197
+ if self.config.is_encoder_decoder:
198
+ cross_attentions += (outputs.cross_attentions,)
199
+
200
+ if output_hidden_states:
201
+ decoder_hidden_states += (
202
+ (outputs.decoder_hidden_states,)
203
+ if self.config.is_encoder_decoder
204
+ else (outputs.hidden_states,)
205
+ )
206
+
207
+ # argmax
208
+ # next_tokens = torch.argmax(next_tokens_scores, dim=-1)
209
+ next_tokens = next_tokens_scores
210
+
211
+ # finished sentences should have their next token be a padding token
212
+ if eos_token_id is not None:
213
+ if pad_token_id is None:
214
+ raise ValueError(
215
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
216
+ next_tokens = next_tokens * unfinished_sequences + \
217
+ pad_token_id * (1 - unfinished_sequences)
218
+
219
+ # update generated ids, model inputs, and length for next step
220
+ horizon_length = next_tokens.shape[-1] // self.config.input_token_len
221
+
222
+ past_key_values = model_kwargs.get("past_key_values")
223
+ if generate_results is None:
224
+ generate_results = next_tokens
225
+ else:
226
+ generate_results = torch.cat([generate_results, next_tokens], dim=-1)
227
+
228
+ # Use deterministic approach instead of median to avoid CUDA deterministic algorithm issues
229
+ # For flow models, use torch.quantile(p=0.5) which is equivalent to median but deterministic
230
+
231
+ selected_tokens = torch.quantile(next_tokens.float(), q=0.5, dim=1)
232
+ input_ids = torch.cat([input_ids, selected_tokens], dim=-1)
233
+
234
+ if streamer is not None:
235
+ streamer.put(next_tokens.cpu())
236
+ model_kwargs = self._update_model_kwargs_for_generation(
237
+ outputs,
238
+ model_kwargs,
239
+ horizon_length=horizon_length,
240
+ is_encoder_decoder=self.config.is_encoder_decoder,
241
+ )
242
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(
243
+ input_ids, scores)
244
+ this_peer_finished = unfinished_sequences.max() == 0
245
+
246
+ if input_ids.shape[-1] > max_length:
247
+ input_ids = input_ids[:, :max_length]
248
+
249
+ if streamer is not None:
250
+ streamer.end()
251
+
252
+ if return_dict_in_generate:
253
+ if self.config.is_encoder_decoder:
254
+ return GenerateEncoderDecoderOutput(
255
+ sequences=input_ids,
256
+ scores=scores,
257
+ logits=raw_logits,
258
+ encoder_attentions=encoder_attentions,
259
+ encoder_hidden_states=encoder_hidden_states,
260
+ decoder_attentions=decoder_attentions,
261
+ cross_attentions=cross_attentions,
262
+ decoder_hidden_states=decoder_hidden_states,
263
+ past_key_values=model_kwargs.get("past_key_values"),
264
+ )
265
+ else:
266
+ return GenerateDecoderOnlyOutput(
267
+ sequences=input_ids,
268
+ scores=scores,
269
+ logits=raw_logits,
270
+ attentions=decoder_attentions,
271
+ hidden_states=decoder_hidden_states,
272
+ past_key_values=model_kwargs.get("past_key_values"),
273
+ )
274
+ else:
275
+ return generate_results[:, :, :(max_length - cur_len)]
276
+
277
+ def _update_model_kwargs_for_generation(
278
+ self,
279
+ outputs: ModelOutput,
280
+ model_kwargs: Dict[str, Any],
281
+ horizon_length: int = 1,
282
+ is_encoder_decoder: bool = False,
283
+ standardize_cache_format: bool = False,
284
+ ) -> Dict[str, Any]:
285
+ # update past_key_values
286
+ for possible_cache_name in ALL_CACHE_NAMES:
287
+ if possible_cache_name in outputs:
288
+ if possible_cache_name in ("past_buckets_states", "mems"):
289
+ cache_name = "past_key_values"
290
+ else:
291
+ cache_name = possible_cache_name
292
+ model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
293
+ break
294
+
295
+ # update token_type_ids with last value
296
+ if "token_type_ids" in model_kwargs:
297
+ token_type_ids = model_kwargs["token_type_ids"]
298
+ model_kwargs["token_type_ids"] = torch.cat(
299
+ [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
300
+
301
+ if not is_encoder_decoder:
302
+ # update attention mask
303
+ if "attention_mask" in model_kwargs:
304
+ attention_mask = model_kwargs["attention_mask"]
305
+ model_kwargs["attention_mask"] = torch.cat(
306
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], horizon_length))], dim=-1
307
+ )
308
+ else:
309
+ # update decoder attention mask
310
+ if "decoder_attention_mask" in model_kwargs:
311
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
312
+ model_kwargs["decoder_attention_mask"] = torch.cat(
313
+ [decoder_attention_mask, decoder_attention_mask.new_ones(
314
+ (decoder_attention_mask.shape[0], horizon_length))],
315
+ dim=-1,
316
+ )
317
+
318
+ if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
319
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
320
+
321
+ # update full_hidden_states: accumulate hidden states across generation steps for MTP layers
322
+ if hasattr(outputs, "hidden_states_for_mtp") and outputs.hidden_states_for_mtp is not None:
323
+ new_hs = outputs.hidden_states_for_mtp
324
+ if "full_hidden_states" in model_kwargs and model_kwargs["full_hidden_states"] is not None:
325
+ existing = model_kwargs["full_hidden_states"]
326
+ model_kwargs["full_hidden_states"] = torch.cat(
327
+ [existing.to(new_hs.device), new_hs], dim=1
328
+ )
329
+ else:
330
+ model_kwargs["full_hidden_states"] = new_hs
331
+
332
+ return model_kwargs