Shourya Bose
committed on
Commit
·
a9073bb
1
Parent(s):
4cc7625
add timefm weights
Browse files- README.md +8 -0
- example.py +33 -0
- model_kwargs.py +6 -0
- models/TimesFM.py +841 -0
- models/__pycache__/Autoformer.cpython-310.pyc +0 -0
- models/__pycache__/LSTM.cpython-310.pyc +0 -0
- models/__pycache__/LSTNet.cpython-310.pyc +0 -0
- weights/TimesFM_L_512_T_48_HET.pth +3 -0
- weights/TimesFM_L_512_T_48_HOM.pth +3 -0
- weights/TimesFM_L_512_T_4_HET.pth +3 -0
- weights/TimesFM_L_512_T_4_HOM.pth +3 -0
- weights/TimesFM_L_512_T_96_HET.pth +3 -0
- weights/TimesFM_L_512_T_96_HOM.pth +3 -0
README.md
CHANGED
|
@@ -14,6 +14,14 @@ When using the companion [dataset](https://huggingface.co/datasets/APPFL/Illinoi
|
|
| 14 |
- All models accept normalized inputs and produce normalized outputs, i.e. set `normalize = True` when generating the datasets.
|
| 15 |
- For Transformer, Autoformer, Informer, and TimesNet set `transformer = True`, while for LSTM, LSTNet, and PatchTST set `transformer = False`.
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
## Credits
|
| 18 |
|
| 19 |
Some model definitions have been adapted from the code provided in the [TSLib Library](https://github.com/thuml/Time-Series-Library).
|
|
|
|
| 14 |
- All models accept normalized inputs and produce normalized outputs, i.e. set `normalize = True` when generating the datasets.
|
| 15 |
- For Transformer, Autoformer, Informer, and TimesNet set `transformer = True`, while for LSTM, LSTNet, and PatchTST set `transformer = False`.
|
| 16 |
|
| 17 |
+
## Packages
|
| 18 |
+
|
| 19 |
+
Running the code requires only the `numpy` and `torch` (PyTorch) packages. You can install them in your base Python installation or in a `conda` environment.
|
| 20 |
+
|
| 21 |
+
## Example
|
| 22 |
+
|
| 23 |
+
To see how to instantiate the model definitions and load the weights into them, see `example.py`.
|
| 24 |
+
|
| 25 |
## Credits
|
| 26 |
|
| 27 |
Some model definitions have been adapted from the code provided in the [TSLib Library](https://github.com/thuml/Time-Series-Library).
|
example.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import torch

# import models
from models.LSTM import LSTM
from models.LSTNet import LSTNet
from models.Transformer import Transformer
from models.Autoformer import Autoformer
from models.Informer import Informer
from models.PatchTST import PatchTST
from models.TimesNet import TimesNet
from models.TimesFM import TimesFM

# import keyword args
from model_kwargs import *

# Lookback is fixed to 512; lookahead can be one of 4, 48, 96.
# Heterogeneity can be 'HET' or 'HOM'.
lookback, lookahead, heterogeneity = 512, 48, 'HET'

if __name__ == "__main__":

    # Each model class paired with the function that builds its constructor
    # keyword arguments.
    model_specs = [
        (LSTM, lstm_kwargs),
        (LSTNet, lstnet_kwargs),
        (Transformer, transformer_kwargs),
        (Autoformer, autoformer_kwargs),
        (Informer, informer_kwargs),
        (PatchTST, patchtst_kwargs),
        (TimesNet, timesnet_kwargs),
        (TimesFM, timesfm_kwargs),
    ]

    for model_class, kw_fn in model_specs:
        # Instantiate the model with the kwargs for this lookback/lookahead.
        model = model_class(**kw_fn(lookback = lookback, lookahead = lookahead))

        # Weight files are looked up under ./weights relative to the current
        # working directory, so run this script from the repository root.
        weight_file = f'{model_class.__name__}_L_{lookback}_T_{lookahead}_{heterogeneity}.pth'
        weight_path = os.path.join(os.getcwd(), 'weights', weight_file)

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(weight_path, map_location='cpu')
        result = model.load_state_dict(state_dict)

        # Report the outcome of load_state_dict.
        print(f"Loading weight for model {model_class.__name__}, lookback {lookback}, lookahead {lookahead}, heterogeneity {heterogeneity}, and the result was: {result}.")
|
model_kwargs.py
CHANGED
|
@@ -63,4 +63,10 @@ patchtst_kwargs = lambda lookback,lookahead:{
|
|
| 63 |
'd_model': 32*4,
|
| 64 |
'data_idx': [0,3,4,5,6,7],
|
| 65 |
'time_idx': [1,2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
}
|
|
|
|
| 63 |
'd_model': 32*4,
|
| 64 |
'data_idx': [0,3,4,5,6,7],
|
| 65 |
'time_idx': [1,2]
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
# Constructor keyword arguments for TimesFM. `context_len` is fixed at 512
# independently of the `lookback` argument.
timesfm_kwargs = lambda lookback, lookahead: dict(
    lookback=lookback,
    lookahead=lookahead,
    context_len=512,
)
|
models/TimesFM.py
ADDED
|
@@ -0,0 +1,841 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Google LLC
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Pytorch version of patched decoder."""
|
| 15 |
+
|
| 16 |
+
import dataclasses
|
| 17 |
+
import math
|
| 18 |
+
from typing import List, Tuple
|
| 19 |
+
import torch
|
| 20 |
+
from torch import nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _create_quantiles() -> list[float]:
|
| 25 |
+
return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclasses.dataclass
class TimesFMConfig:
    """Hyper-parameters for building the TimesFM patched decoder.

    Every field has a default, so ``TimesFMConfig()`` is a complete config.
    """

    # Number of stacked decoder layers.
    num_layers: int = 20
    # Number of query heads in each attention layer.
    num_heads: int = 16
    # Number of key/value heads in each attention layer.
    num_kv_heads: int = 16
    # Width of the model's hidden representation.
    hidden_size: int = 1280
    # Width of the MLP / residual-block hidden representation.
    intermediate_size: int = 1280
    # Size of each attention head.
    head_dim: int = 80
    # Epsilon passed to the RMS normalization layers.
    rms_norm_eps: float = 1e-6
    # Number of time steps in one input patch.
    patch_len: int = 32
    # Number of time steps produced per output patch.
    horizon_len: int = 128
    # Quantile levels; a fresh list is created for every config instance.
    quantiles: List[float] = dataclasses.field(default_factory=_create_quantiles)
    # Value that marks padded positions in the inputs.
    pad_val: float = 1123581321.0
    # Tolerance used for near-equality / near-zero comparisons.
    tolerance: float = 1e-6
    # The dtype of the weights.
    # NOTE(review): "bfloat32" is not a torch dtype name - confirm the
    # intended value before relying on this field.
    dtype: str = "bfloat32"
    # Whether a positional embedding module is created.
    use_positional_embedding: bool = True
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _masked_mean_std(
|
| 63 |
+
inputs: torch.Tensor,
|
| 64 |
+
padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 65 |
+
"""Calculates mean and standard deviation of `inputs` across axis 1.
|
| 66 |
+
|
| 67 |
+
It excludes values where `padding` is 1.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
inputs: A PyTorch tensor of shape [b, n, p].
|
| 71 |
+
padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
A tuple containing the mean and standard deviation.
|
| 75 |
+
We return the statistics of the first patch with more than three non-padded
|
| 76 |
+
values.
|
| 77 |
+
"""
|
| 78 |
+
# Selecting the first patch with more than 3 unpadded values.
|
| 79 |
+
pad_sum = torch.sum(1 - padding, dim=2)
|
| 80 |
+
|
| 81 |
+
def _get_patch_index(arr: torch.Tensor):
|
| 82 |
+
indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
|
| 83 |
+
row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
|
| 84 |
+
return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
|
| 85 |
+
|
| 86 |
+
patch_indices = _get_patch_index(pad_sum)
|
| 87 |
+
bidxs = torch.arange(inputs.shape[0])
|
| 88 |
+
|
| 89 |
+
arr = inputs[bidxs, patch_indices, :]
|
| 90 |
+
pad = padding[bidxs, patch_indices, :]
|
| 91 |
+
|
| 92 |
+
# Create a mask where padding is 0
|
| 93 |
+
mask = 1 - pad
|
| 94 |
+
|
| 95 |
+
# Calculate the number of valid elements
|
| 96 |
+
num_valid_elements = torch.sum(mask, dim=1)
|
| 97 |
+
num_valid_elements = torch.where(
|
| 98 |
+
num_valid_elements == 0,
|
| 99 |
+
torch.tensor(1,
|
| 100 |
+
dtype=num_valid_elements.dtype,
|
| 101 |
+
device=num_valid_elements.device),
|
| 102 |
+
num_valid_elements,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# Calculate the masked sum and squared sum
|
| 106 |
+
masked_sum = torch.sum(arr * mask, dim=1)
|
| 107 |
+
masked_squared_sum = torch.sum((arr * mask)**2, dim=1)
|
| 108 |
+
|
| 109 |
+
# Calculate the masked mean and standard deviation
|
| 110 |
+
masked_mean = masked_sum / num_valid_elements
|
| 111 |
+
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
|
| 112 |
+
masked_var = torch.where(
|
| 113 |
+
masked_var < 0.0,
|
| 114 |
+
torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
|
| 115 |
+
masked_var,
|
| 116 |
+
)
|
| 117 |
+
masked_std = torch.sqrt(masked_var)
|
| 118 |
+
|
| 119 |
+
return masked_mean, masked_std
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
|
| 123 |
+
"""Shifts rows of seq based on the first 0 in each row of the mask.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
mask: mask tensor of shape [B, N]
|
| 127 |
+
seq: seq tensor of shape [B, N, P]
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
Returns the shifted sequence.
|
| 131 |
+
"""
|
| 132 |
+
batch_size, num_seq, feature_dim = seq.shape
|
| 133 |
+
|
| 134 |
+
new_mask: torch.BoolTensor = mask == 0
|
| 135 |
+
|
| 136 |
+
# Use argmax to find the first True value in each row
|
| 137 |
+
indices = new_mask.to(torch.int32).argmax(dim=1)
|
| 138 |
+
|
| 139 |
+
# Handle rows with all zeros
|
| 140 |
+
indices[~new_mask.any(dim=1)] = -1
|
| 141 |
+
|
| 142 |
+
# Create index ranges for each sequence in the batch
|
| 143 |
+
idx_range = (torch.arange(num_seq).to(
|
| 144 |
+
seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1,
|
| 145 |
+
feature_dim))
|
| 146 |
+
|
| 147 |
+
# Calculate shifted indices for each element in each sequence
|
| 148 |
+
shifted_idx = (idx_range - indices[:, None, None]) % num_seq
|
| 149 |
+
|
| 150 |
+
# Gather values from seq using shifted indices
|
| 151 |
+
shifted_seq = seq.gather(1, shifted_idx)
|
| 152 |
+
|
| 153 |
+
return shifted_seq
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor:
    """Returns -0.7 times the largest value of `dtype` as a 0-dim tensor."""
    limits = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return torch.tensor(-0.7 * limits.max, dtype=dtype)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def apply_mask_to_logits(logits: torch.Tensor,
                         mask: torch.Tensor) -> torch.Tensor:
    """Replaces masked-out logits with a large negative value.

    A position is kept when its mask value is at least half of the large
    negative number for `logits.dtype`; every other position is replaced by
    that large negative number.

    Args:
      logits: A torch.Tensor of logit values.
      mask: A torch.Tensor of mask values, where strongly negative entries
        mark positions to suppress.

    Returns:
      Masked logits.
    """
    floor = get_large_negative_number(logits.dtype)
    keep = mask >= floor * 0.5
    return torch.where(keep, logits, floor)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def convert_paddings_to_mask(
        paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Converts binary paddings to a logit mask ready to add to attention matrix.

    The result is a new tensor; `paddings` is never modified. Integer and
    boolean paddings are accepted as well as floating-point ones.

    Args:
      paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding
        token.
      dtype: floating-point dtype whose large negative value is used for the
        padded positions.

    Returns:
      A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits:
      a large negative value where `paddings` is 1 and zero elsewhere.
    """
    # Insert the head and query axes so the mask broadcasts over them.
    attention_mask = paddings.detach()[:, None, None, :]
    # Multiply out-of-place: an in-place `*=` raises for integer/bool
    # paddings because the float result cannot be written back into them.
    return attention_mask * get_large_negative_number(dtype)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def causal_mask(input_t: torch.Tensor) -> torch.Tensor:
    """Computes and returns causal mask.

    Entry [i, j] holds a large negative value when j > i (a later position)
    and zero otherwise.

    Args:
      input_t: A floating-point torch.Tensor of shape [B, T, D].

    Returns:
      An attention_mask torch.Tensor of shape [1, 1, T, T], with the dtype and
      device of `input_t`.
    """
    assert input_t.dtype.is_floating_point, input_t.dtype
    floor = get_large_negative_number(input_t.dtype)

    steps = torch.arange(input_t.shape[1])
    # True strictly above the diagonal: column index is later than row index.
    is_future = steps[:, None] < steps[None, :]
    mask = is_future.to(input_t.dtype) * floor

    # Add the leading batch and head axes.
    return mask[None, None, :, :].to(input_t.device)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Merges 2 masks by taking their element-wise minimum.

    logscale mask is expected but 0/1 mask is also fine. When exactly one of
    the masks has a singleton third axis, it is first expanded to a square
    mask by taking the minimum of itself and its transpose.

    Args:
      a: torch.Tensor of shape [1|B, 1, 1|T, S].
      b: torch.Tensor of shape [1|B, 1, 1|T, S].

    Returns:
      torch.Tensor of shape [1|B, 1, 1|T, S].
    """

    def _square(key_mask):
        # Combine the key mask with its transposed (query-side) copy.
        return torch.minimum(key_mask.transpose(-1, -2), key_mask)

    a_rows, b_rows = a.shape[2], b.shape[2]
    if a_rows != b_rows:
        if a_rows == 1:
            a = _square(a)
        else:
            assert b_rows == 1
            b = _square(b)

    assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}."
    return torch.minimum(a, b)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class ResidualBlock(nn.Module):
    """TimesFM residual block.

    Computes `output_layer(SiLU(hidden_linear(x))) + residual_layer(x)`, where
    the residual path is a learned linear projection from input to output
    dimensions.
    """

    def __init__(
        self,
        input_dims,
        hidden_dims,
        output_dims,
    ):
        """Creates the block.

        Args:
          input_dims: size of the last axis of the input.
          hidden_dims: size of the hidden representation.
          output_dims: size of the last axis of the output.
        """
        super().__init__()
        self.input_dims = input_dims
        self.hidden_dims = hidden_dims
        self.output_dims = output_dims

        # Main path: linear + SiLU, then a linear projection to the output.
        self.hidden_layer = nn.Sequential(
            nn.Linear(input_dims, hidden_dims),
            nn.SiLU(),
        )
        self.output_layer = nn.Linear(hidden_dims, output_dims)

        # Skip path: linear projection straight from input to output.
        self.residual_layer = nn.Linear(input_dims, output_dims)

    def forward(self, x):
        """Returns the sum of the main path and the skip path for `x`."""
        main_path = self.output_layer(self.hidden_layer(x))
        skip_path = self.residual_layer(x)
        return main_path + skip_path
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class RMSNorm(torch.nn.Module):
    """Pax rms norm in pytorch.

    Normalizes the last axis by its root mean square and multiplies by a
    learned per-channel weight. The weight starts at zero, so without
    `add_unit_offset` the output is zero until weights are loaded or trained.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = False,
    ):
        """Creates the layer.

        Args:
          dim: size of the normalized (last) axis.
          eps: value added to the mean square before the reciprocal sqrt.
          add_unit_offset: if True, scale by `1 + weight` instead of `weight`.
        """
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divides `x` by the root mean square of its last axis."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Returns the normalized, scaled input, cast back to `x`'s dtype."""
        # Normalization and scaling are done in float32.
        gain = self.weight.float()
        if self.add_unit_offset:
            gain = 1 + gain
        scaled = self._norm(x.float()) * gain
        return scaled.type_as(x)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class TransformerMLP(nn.Module):
    """Pax transformer MLP in pytorch.

    Applies LayerNorm -> linear -> ReLU -> linear and adds the result to the
    input as a residual connection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
    ):
        """Creates the MLP.

        Args:
          hidden_size: size of the input and output last axis.
          intermediate_size: size of the inner representation.
        """
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)

    def forward(self, x, paddings=None):
        """Returns `x` plus the MLP update.

        Args:
          x: input tensor; indexed as [batch, sequence, hidden] when
            `paddings` is given.
          paddings: optional tensor indexed as [batch, sequence]; where it is
            1 the MLP update is zeroed, so the input passes through unchanged.
        """
        activated = F.relu(self.gate_proj(self.layer_norm(x)))
        update = self.down_proj(activated)
        if paddings is not None:
            update = update * (1.0 - paddings[:, :, None])
        return update + x
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class TimesFMAttention(nn.Module):
|
| 331 |
+
"""Implements the attention used in TimesFM."""
|
| 332 |
+
|
| 333 |
+
def __init__(
|
| 334 |
+
self,
|
| 335 |
+
hidden_size: int,
|
| 336 |
+
num_heads: int,
|
| 337 |
+
num_kv_heads: int,
|
| 338 |
+
head_dim: int,
|
| 339 |
+
):
|
| 340 |
+
super().__init__()
|
| 341 |
+
|
| 342 |
+
self.num_heads = num_heads
|
| 343 |
+
self.num_kv_heads = num_kv_heads
|
| 344 |
+
|
| 345 |
+
assert self.num_heads % self.num_kv_heads == 0
|
| 346 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
| 347 |
+
|
| 348 |
+
self.hidden_size = hidden_size
|
| 349 |
+
self.head_dim = head_dim
|
| 350 |
+
|
| 351 |
+
self.q_size = self.num_heads * self.head_dim
|
| 352 |
+
self.kv_size = self.num_kv_heads * self.head_dim
|
| 353 |
+
self.scaling = nn.Parameter(
|
| 354 |
+
torch.empty((self.head_dim,), dtype=torch.float32),)
|
| 355 |
+
|
| 356 |
+
self.qkv_proj = nn.Linear(
|
| 357 |
+
self.hidden_size,
|
| 358 |
+
(self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
|
| 359 |
+
)
|
| 360 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
|
| 361 |
+
|
| 362 |
+
def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor:
|
| 363 |
+
# [batch_size, n_local_heads, input_len, head_dim]
|
| 364 |
+
r_softplus_0 = 1.442695041
|
| 365 |
+
softplus_func = torch.nn.Softplus()
|
| 366 |
+
scale = r_softplus_0 / math.sqrt(self.head_dim)
|
| 367 |
+
scale = scale * softplus_func(self.scaling)
|
| 368 |
+
return query * scale[None, None, None, :]
|
| 369 |
+
|
| 370 |
+
def forward(
|
| 371 |
+
self,
|
| 372 |
+
hidden_states: torch.Tensor,
|
| 373 |
+
mask: torch.Tensor,
|
| 374 |
+
kv_write_indices: torch.Tensor | None = None,
|
| 375 |
+
kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 376 |
+
) -> torch.Tensor:
|
| 377 |
+
hidden_states_shape = hidden_states.shape
|
| 378 |
+
assert len(hidden_states_shape) == 3
|
| 379 |
+
|
| 380 |
+
batch_size, input_len, _ = hidden_states_shape
|
| 381 |
+
|
| 382 |
+
qkv = self.qkv_proj(hidden_states)
|
| 383 |
+
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
| 384 |
+
|
| 385 |
+
xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
|
| 386 |
+
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
|
| 387 |
+
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
|
| 388 |
+
xq = self._per_dim_scaling(xq)
|
| 389 |
+
|
| 390 |
+
# Write new kv cache.
|
| 391 |
+
# [batch_size, input_len, n_local_kv_heads, head_dim]
|
| 392 |
+
if kv_cache is not None and kv_write_indices is not None:
|
| 393 |
+
k_cache, v_cache = kv_cache
|
| 394 |
+
k_cache.index_copy_(1, kv_write_indices, xk)
|
| 395 |
+
v_cache.index_copy_(1, kv_write_indices, xv)
|
| 396 |
+
|
| 397 |
+
key = k_cache
|
| 398 |
+
value = v_cache
|
| 399 |
+
else:
|
| 400 |
+
key = xk
|
| 401 |
+
value = xv
|
| 402 |
+
if self.num_kv_heads != self.num_heads:
|
| 403 |
+
# [batch_size, max_seq_len, n_local_heads, head_dim]
|
| 404 |
+
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
|
| 405 |
+
value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2)
|
| 406 |
+
|
| 407 |
+
# [batch_size, n_local_heads, input_len, head_dim]
|
| 408 |
+
q = xq.transpose(1, 2)
|
| 409 |
+
# [batch_size, n_local_heads, max_seq_len, head_dim]
|
| 410 |
+
k = key.transpose(1, 2)
|
| 411 |
+
v = value.transpose(1, 2)
|
| 412 |
+
|
| 413 |
+
# [batch_size, n_local_heads, input_len, max_seq_len]
|
| 414 |
+
scores = torch.matmul(q, k.transpose(2, 3))
|
| 415 |
+
scores = scores + mask
|
| 416 |
+
scores = F.softmax(scores.float(), dim=-1).type_as(q)
|
| 417 |
+
|
| 418 |
+
# [batch_size, n_local_heads, input_len, head_dim]
|
| 419 |
+
output = torch.matmul(scores, v)
|
| 420 |
+
# return scores, output.transpose(1, 2).contiguous()
|
| 421 |
+
|
| 422 |
+
# [batch_size, input_len, hidden_dim]
|
| 423 |
+
output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
|
| 424 |
+
output = self.o_proj(output)
|
| 425 |
+
return scores, output
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
class TimesFMDecoderLayer(nn.Module):
    """Transformer layer: pre-normalized self-attention followed by an MLP."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rms_norm_eps: float = 1e-6,
    ):
        """Creates the attention, MLP and input-normalization sub-modules.

        Args:
          hidden_size: size of the hidden representation.
          intermediate_size: inner size of the MLP.
          num_heads: number of query heads.
          num_kv_heads: number of key/value heads.
          head_dim: size of each attention head.
          rms_norm_eps: epsilon for the input RMS normalization.
        """
        super().__init__()
        attention_kwargs = dict(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
        )
        self.self_attn = TimesFMAttention(**attention_kwargs)
        self.mlp = TransformerMLP(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
        )
        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: torch.Tensor,
        paddings: torch.Tensor,
        kv_write_indices: torch.Tensor | None = None,
        kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Applies the layer.

        Args:
          hidden_states: layer input.
          mask: attention mask forwarded to the self-attention module.
          paddings: padding indicator forwarded to the MLP.
          kv_write_indices: forwarded to the self-attention module.
          kv_cache: forwarded to the self-attention module.

        Returns:
          A tuple `(scores, hidden_states)`: the attention scores returned by
          the self-attention module and the layer output.
        """
        # Self-attention on the normalized input, with a residual connection.
        normed = self.input_layernorm(hidden_states)
        scores, attended = self.self_attn(
            hidden_states=normed,
            mask=mask,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
        )
        hidden_states = hidden_states + attended

        # The MLP adds its own residual connection internally.
        return scores, self.mlp(hidden_states, paddings=paddings)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
class StackedDecoder(nn.Module):
    """Stacked transformer layer: `num_layers` decoder layers run in sequence."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        num_layers: int,
        rms_norm_eps: float = 1e-6,
    ):
        """Creates `num_layers` identically configured decoder layers."""
        super().__init__()

        self.layers = nn.ModuleList([
            TimesFMDecoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                rms_norm_eps=rms_norm_eps,
            ) for _ in range(num_layers)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        paddings: torch.Tensor,
        kv_write_indices: torch.Tensor | None = None,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
    ) -> torch.Tensor:
        """Runs the input through every layer.

        Args:
          hidden_states: input to the first layer.
          paddings: padding indicator; converted to an attention mask and
            also passed to every layer.
          kv_write_indices: passed to every layer.
          kv_caches: optional list with one `(k_cache, v_cache)` pair per
            layer.

        Returns:
          The hidden states produced by the last layer.
        """
        # One combined padding + causal mask is shared by all layers.
        mask = merge_masks(
            convert_paddings_to_mask(paddings, hidden_states.dtype),
            causal_mask(hidden_states),
        )
        for layer_idx, layer in enumerate(self.layers):
            layer_cache = None if kv_caches is None else kv_caches[layer_idx]
            # The per-layer attention scores are not used here.
            _, hidden_states = layer(
                hidden_states=hidden_states,
                mask=mask,
                paddings=paddings,
                kv_write_indices=kv_write_indices,
                kv_cache=layer_cache,
            )
        return hidden_states
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class PositionalEmbedding(torch.nn.Module):
    """Generates position embedding for a given 1-d sequence.

    The embedding is the concatenation of sines and cosines of the position
    scaled by a geometric series of timescales; this module has no learned
    parameters.

    Attributes:
      min_timescale: Start of the geometric index. Determines the periodicity of
        the added signal.
      max_timescale: End of the geometric index. Determines the frequency of the
        added signal.
      embedding_dims: Dimension of the embedding to be generated.
    """

    def __init__(
        self,
        embedding_dims: int,
        min_timescale: int = 1,
        max_timescale: int = 10_000,
    ) -> None:
        super().__init__()
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale
        self.embedding_dims = embedding_dims

    def forward(self, seq_length=None, position=None):
        """Generates a Tensor of sinusoids with different frequencies.

        Args:
          seq_length: an optional Python int defining the output sequence
            length; required when the `position` argument is not specified.
          position: [B, seq_length], optional position for each token in the
            sequence, only required when the sequence is packed.

        Returns:
          [B, seqlen, D] if `position` is specified, else [1, seqlen, D],
          where D is `embedding_dims`. For odd D the last channel is zero.
        """
        if position is None:
            assert seq_length is not None
            # [1, seqlen]
            position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0)
        else:
            assert position.ndim == 2, position.shape

        num_timescales = self.embedding_dims // 2
        log_timescale_increment = math.log(
            float(self.max_timescale) / float(self.min_timescale)) / max(
                num_timescales - 1, 1)
        # Built on the device of `position` so caller-supplied positions on a
        # non-CPU device can be multiplied with it.
        inv_timescales = self.min_timescale * torch.exp(
            torch.arange(num_timescales,
                         dtype=torch.float32,
                         device=position.device) * -log_timescale_increment)
        # [B, seqlen, num_timescales]
        scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(
            0)
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        # F.pad's tuple starts at the LAST axis: pad one zero channel at the
        # end of the embedding axis when embedding_dims is odd.
        signal = F.pad(signal, (0, self.embedding_dims % 2))
        return signal
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
class PatchedTimeSeriesDecoder(nn.Module):
|
| 585 |
+
"""Patched time-series decoder."""
|
| 586 |
+
|
| 587 |
+
def __init__(self, config: TimesFMConfig):
|
| 588 |
+
super().__init__()
|
| 589 |
+
self.config = config
|
| 590 |
+
self.input_ff_layer = ResidualBlock(
|
| 591 |
+
input_dims=2 * config.patch_len,
|
| 592 |
+
output_dims=config.hidden_size,
|
| 593 |
+
hidden_dims=config.intermediate_size,
|
| 594 |
+
)
|
| 595 |
+
self.freq_emb = nn.Embedding(num_embeddings=3,
|
| 596 |
+
embedding_dim=config.hidden_size)
|
| 597 |
+
self.horizon_ff_layer = ResidualBlock(
|
| 598 |
+
input_dims=config.hidden_size,
|
| 599 |
+
output_dims=config.horizon_len * (1 + len(config.quantiles)),
|
| 600 |
+
hidden_dims=config.intermediate_size,
|
| 601 |
+
)
|
| 602 |
+
self.stacked_transformer = StackedDecoder(
|
| 603 |
+
hidden_size=self.config.hidden_size,
|
| 604 |
+
intermediate_size=self.config.intermediate_size,
|
| 605 |
+
num_heads=self.config.num_heads,
|
| 606 |
+
num_kv_heads=self.config.num_kv_heads,
|
| 607 |
+
head_dim=self.config.head_dim,
|
| 608 |
+
num_layers=self.config.num_layers,
|
| 609 |
+
rms_norm_eps=self.config.rms_norm_eps,
|
| 610 |
+
)
|
| 611 |
+
if self.config.use_positional_embedding:
|
| 612 |
+
self.position_emb = PositionalEmbedding(self.config.hidden_size)
|
| 613 |
+
|
| 614 |
+
def _forward_transform(
|
| 615 |
+
self, inputs: torch.Tensor, patched_pads: torch.Tensor
|
| 616 |
+
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
| 617 |
+
"""Input is of shape [B, N, P]."""
|
| 618 |
+
mu, sigma = _masked_mean_std(inputs, patched_pads)
|
| 619 |
+
sigma = torch.where(
|
| 620 |
+
sigma < self.config.tolerance,
|
| 621 |
+
torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
|
| 622 |
+
sigma,
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
# Normalize each patch
|
| 626 |
+
outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
|
| 627 |
+
outputs = torch.where(
|
| 628 |
+
torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
|
| 629 |
+
torch.tensor(self.config.pad_val,
|
| 630 |
+
dtype=outputs.dtype,
|
| 631 |
+
device=outputs.device),
|
| 632 |
+
outputs,
|
| 633 |
+
)
|
| 634 |
+
return outputs, (mu, sigma)
|
| 635 |
+
|
| 636 |
+
def _reverse_transform(
|
| 637 |
+
self, outputs: torch.Tensor, stats: tuple[torch.Tensor,
|
| 638 |
+
torch.Tensor]) -> torch.Tensor:
|
| 639 |
+
"""Output is of shape [B, N, P, Q]."""
|
| 640 |
+
mu, sigma = stats
|
| 641 |
+
return outputs * sigma[:, None, None, None] + mu[:, None, None, None]
|
| 642 |
+
|
| 643 |
+
def _preprocess_input(
|
| 644 |
+
self,
|
| 645 |
+
input_ts: torch.Tensor,
|
| 646 |
+
input_padding: torch.Tensor,
|
| 647 |
+
) -> tuple[
|
| 648 |
+
torch.Tensor,
|
| 649 |
+
torch.Tensor,
|
| 650 |
+
tuple[torch.Tensor, torch.Tensor] | None,
|
| 651 |
+
torch.Tensor,
|
| 652 |
+
]:
|
| 653 |
+
"""Preprocess input for stacked transformer."""
|
| 654 |
+
|
| 655 |
+
# Reshape into patches (using view for efficiency)
|
| 656 |
+
bsize = input_ts.shape[0]
|
| 657 |
+
patched_inputs = input_ts.view(bsize, -1, self.config.patch_len)
|
| 658 |
+
patched_pads = input_padding.view(bsize, -1, self.config.patch_len)
|
| 659 |
+
|
| 660 |
+
patched_inputs = torch.where(
|
| 661 |
+
torch.abs(patched_pads - 1.0) < self.config.tolerance,
|
| 662 |
+
torch.tensor(0.0,
|
| 663 |
+
dtype=patched_inputs.dtype,
|
| 664 |
+
device=patched_inputs.device),
|
| 665 |
+
patched_inputs,
|
| 666 |
+
)
|
| 667 |
+
patched_pads = torch.where(
|
| 668 |
+
torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
|
| 669 |
+
torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
|
| 670 |
+
patched_pads,
|
| 671 |
+
)
|
| 672 |
+
patched_inputs, stats = self._forward_transform(patched_inputs,
|
| 673 |
+
patched_pads)
|
| 674 |
+
|
| 675 |
+
# B x N x D
|
| 676 |
+
patched_inputs = patched_inputs * (1.0 - patched_pads)
|
| 677 |
+
concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
|
| 678 |
+
model_input = self.input_ff_layer(concat_inputs)
|
| 679 |
+
|
| 680 |
+
# A patch should not be padded even if there is at least one zero.
|
| 681 |
+
patched_padding = torch.min(patched_pads,
|
| 682 |
+
dim=-1)[0] # Get the values from the min result
|
| 683 |
+
if self.config.use_positional_embedding:
|
| 684 |
+
pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device)
|
| 685 |
+
pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
|
| 686 |
+
pos_emb = _shift_padded_seq(patched_padding, pos_emb)
|
| 687 |
+
model_input += pos_emb
|
| 688 |
+
|
| 689 |
+
return model_input, patched_padding, stats, patched_inputs
|
| 690 |
+
|
| 691 |
+
def _postprocess_output(
|
| 692 |
+
self,
|
| 693 |
+
model_output: torch.Tensor,
|
| 694 |
+
num_outputs: int,
|
| 695 |
+
stats: tuple[torch.Tensor, torch.Tensor],
|
| 696 |
+
) -> torch.Tensor:
|
| 697 |
+
"""Postprocess output of stacked transformer."""
|
| 698 |
+
|
| 699 |
+
# B x N x (H.Q)
|
| 700 |
+
output_ts = self.horizon_ff_layer(model_output)
|
| 701 |
+
|
| 702 |
+
# Reshape using view
|
| 703 |
+
b, n, _ = output_ts.shape
|
| 704 |
+
output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs)
|
| 705 |
+
|
| 706 |
+
return self._reverse_transform(output_ts, stats)
|
| 707 |
+
|
| 708 |
+
def forward(
|
| 709 |
+
self,
|
| 710 |
+
input_ts: torch.Tensor,
|
| 711 |
+
input_padding: torch.LongTensor,
|
| 712 |
+
freq: torch.Tensor,
|
| 713 |
+
) -> torch.Tensor:
|
| 714 |
+
num_outputs = len(self.config.quantiles) + 1
|
| 715 |
+
model_input, patched_padding, stats, _ = self._preprocess_input(
|
| 716 |
+
input_ts=input_ts,
|
| 717 |
+
input_padding=input_padding,
|
| 718 |
+
)
|
| 719 |
+
f_emb = self.freq_emb(freq) # B x 1 x D
|
| 720 |
+
model_input += f_emb
|
| 721 |
+
model_output = self.stacked_transformer(model_input, patched_padding)
|
| 722 |
+
|
| 723 |
+
output_ts = self._postprocess_output(model_output, num_outputs, stats)
|
| 724 |
+
return output_ts
|
| 725 |
+
|
| 726 |
+
def decode(
|
| 727 |
+
self,
|
| 728 |
+
input_ts: torch.Tensor,
|
| 729 |
+
paddings: torch.Tensor,
|
| 730 |
+
freq: torch.LongTensor,
|
| 731 |
+
horizon_len: int,
|
| 732 |
+
output_patch_len: int | None = None,
|
| 733 |
+
max_len: int = 512,
|
| 734 |
+
return_forecast_on_context: bool = False,
|
| 735 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 736 |
+
"""Auto-regressive decoding without caching.
|
| 737 |
+
|
| 738 |
+
Args:
|
| 739 |
+
input_ts: input time-series and paddings. Time-series shape B x C.
|
| 740 |
+
paddings: padding shape B x (C + H) where H is the prediction length.
|
| 741 |
+
freq: frequency shape B x 1
|
| 742 |
+
horizon_len: prediction length.
|
| 743 |
+
output_patch_len: output length to be fetched from one step of
|
| 744 |
+
auto-regressive decoding.
|
| 745 |
+
max_len: maximum training context length.
|
| 746 |
+
return_forecast_on_context: whether to return the model forecast on the
|
| 747 |
+
context except the first input patch.
|
| 748 |
+
|
| 749 |
+
Returns:
|
| 750 |
+
Tuple of two forecasting results:
|
| 751 |
+
- Point (mean) output predictions as a tensor with shape B x H'.
|
| 752 |
+
- Full predictions (mean and quantiles) as a tensor with shape
|
| 753 |
+
B x H' x (1 + # quantiles).
|
| 754 |
+
In particular, if return_forecast_on_context is True, H' is H plus
|
| 755 |
+
the forecastable context length, i.e. context_len - (first) patch_len.
|
| 756 |
+
"""
|
| 757 |
+
final_out = input_ts
|
| 758 |
+
context_len = final_out.shape[1]
|
| 759 |
+
full_outputs = []
|
| 760 |
+
if paddings.shape[1] != final_out.shape[1] + horizon_len:
|
| 761 |
+
raise ValueError(
|
| 762 |
+
"Length of paddings must match length of input + horizon_len:"
|
| 763 |
+
f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}")
|
| 764 |
+
if output_patch_len is None:
|
| 765 |
+
output_patch_len = self.config.horizon_len
|
| 766 |
+
num_decode_patches = (horizon_len + output_patch_len -
|
| 767 |
+
1) // output_patch_len
|
| 768 |
+
for step_index in range(num_decode_patches):
|
| 769 |
+
current_padding = paddings[:, 0:final_out.shape[1]]
|
| 770 |
+
input_ts = final_out[:, -max_len:]
|
| 771 |
+
input_padding = current_padding[:, -max_len:]
|
| 772 |
+
fprop_outputs = self(input_ts, input_padding, freq)
|
| 773 |
+
if return_forecast_on_context and step_index == 0:
|
| 774 |
+
# For the first decodings step, collect the model forecast on the
|
| 775 |
+
# context except the unavailable first input batch forecast.
|
| 776 |
+
new_full_ts = fprop_outputs[:, :-1, :self.config.patch_len, :]
|
| 777 |
+
new_full_ts = fprop_outputs.view(new_full_ts.size(0), -1,
|
| 778 |
+
new_full_ts.size(3))
|
| 779 |
+
|
| 780 |
+
full_outputs.append(new_full_ts)
|
| 781 |
+
|
| 782 |
+
# (full batch, last patch, output_patch_len, index of mean forecast = 0)
|
| 783 |
+
new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
|
| 784 |
+
new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
|
| 785 |
+
# (full batch, last patch, output_patch_len, all output indices)
|
| 786 |
+
full_outputs.append(new_full_ts)
|
| 787 |
+
final_out = torch.concatenate([final_out, new_ts], axis=-1)
|
| 788 |
+
|
| 789 |
+
if return_forecast_on_context:
|
| 790 |
+
# `full_outputs` indexing starts at after the first input patch.
|
| 791 |
+
full_outputs = torch.concatenate(
|
| 792 |
+
full_outputs,
|
| 793 |
+
axis=1)[:, :(context_len - self.config.patch_len + horizon_len), :]
|
| 794 |
+
else:
|
| 795 |
+
# `full_outputs` indexing starts at the forecast horizon.
|
| 796 |
+
full_outputs = torch.concatenate(full_outputs, axis=1)[:,
|
| 797 |
+
0:horizon_len, :]
|
| 798 |
+
|
| 799 |
+
return (full_outputs[:, :, 0], full_outputs)
|
| 800 |
+
|
| 801 |
+
class TimesFM(nn.Module):
    """Thin wrapper that pads a raw window and runs the TimesFM decoder.

    `forward` takes a 2-D tensor (batch, length), fits it to `context_len`
    points and returns the mean forecast for the next `lookahead` steps.
    """

    def __init__(self, lookback: int = 512, lookahead: int = 96, context_len: int = 512):
        """Creates the inner decoder with a default `TimesFMConfig`.

        Args:
            lookback: stored on the instance; not read by any method of
                this class.
            lookahead: number of future steps to forecast.
            context_len: fixed context length the input is padded or
                truncated to.
        """
        super().__init__()
        self.timesfm = PatchedTimeSeriesDecoder(TimesFMConfig())
        self.lookback = lookback
        self.lookahead = lookahead
        self.context_len = context_len

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Loads `state_dict` straight into the inner decoder.

        The lookup is delegated to `self.timesfm`, so the keys are expected
        to be those of the decoder itself rather than of this wrapper.
        """
        decoder = self.timesfm
        return decoder.load_state_dict(state_dict, *args, **kwargs)

    def state_dict(self, *args, **kwargs):
        """Returns the inner decoder's state dict (mirrors `load_state_dict`)."""
        decoder = self.timesfm
        return decoder.state_dict(*args, **kwargs)

    def pad_tensor(self, x):
        """Fits `x` to `context_len` points and builds the decoder inputs.

        Args:
            x: 2-D tensor of shape (batch, length).

        Returns:
            A 3-tuple `(context, padding, freq)`:
            - context: (batch, context_len). Shorter inputs are left-padded
              with zeros; longer ones keep their last `context_len` points.
            - padding: (batch, context_len + lookahead) mask in `x`'s dtype,
              1 on the left-padded positions and 0 elsewhere.
            - freq: (batch, 1) tensor of zeros with dtype `torch.long`.
        """
        batch, length = x.shape
        ctx = self.context_len
        like_x = {"device": x.device, "dtype": x.dtype}

        if length >= ctx:
            # Enough history: keep the most recent points, nothing is padded.
            context = x[:, -ctx:]
            context_mask = torch.zeros((batch, ctx), **like_x)
        else:
            # Too short: right-align the data and flag the left fill.
            context = torch.zeros((batch, ctx), **like_x)
            context[:, -length:] = x
            context_mask = torch.ones((batch, ctx), **like_x)
            context_mask[:, -length:] = 0

        # The forecast horizon is never marked as padding.
        horizon_mask = torch.zeros((batch, self.lookahead), **like_x)
        freq = torch.zeros((batch, 1), device=x.device, dtype=torch.long)

        return context, torch.cat((context_mask, horizon_mask), dim=-1), freq

    def forward(self, x):
        """Returns the mean forecast for the next `lookahead` steps.

        Args:
            x: 2-D tensor of shape (batch, length).

        Returns:
            The first element of the decoder's `decode` result (the point
            forecast); the quantile outputs are discarded.
        """
        context, padding, freq = self.pad_tensor(x)
        decoded = self.timesfm.decode(context, padding, freq, self.lookahead)
        return decoded[0]  # ignoring quantiles
|
models/__pycache__/Autoformer.cpython-310.pyc
CHANGED
|
Binary files a/models/__pycache__/Autoformer.cpython-310.pyc and b/models/__pycache__/Autoformer.cpython-310.pyc differ
|
|
|
models/__pycache__/LSTM.cpython-310.pyc
CHANGED
|
Binary files a/models/__pycache__/LSTM.cpython-310.pyc and b/models/__pycache__/LSTM.cpython-310.pyc differ
|
|
|
models/__pycache__/LSTNet.cpython-310.pyc
CHANGED
|
Binary files a/models/__pycache__/LSTNet.cpython-310.pyc and b/models/__pycache__/LSTNet.cpython-310.pyc differ
|
|
|
weights/TimesFM_L_512_T_48_HET.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5dd216286c5493e6aaa9aa0f08ccf6e645423e83733e6e2c6be78920f5266cc4
|
| 3 |
+
size 814365703
|
weights/TimesFM_L_512_T_48_HOM.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b1eb69eeaa672c28212c5fe410d4b7d87c41a0868b8874f33308ab932f01ac89
|
| 3 |
+
size 814365703
|
weights/TimesFM_L_512_T_4_HET.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:826f2c6d2f01218f55579997cd057257f6d5817b3856fb9ffd6e70d13c5d8e2a
|
| 3 |
+
size 814365382
|
weights/TimesFM_L_512_T_4_HOM.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d710dc8d8012d226a63d4a983743b48dfeb21d10c1d2bc674b86ec6472b4a060
|
| 3 |
+
size 814365382
|
weights/TimesFM_L_512_T_96_HET.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3987782b50e4e6119cd9d35df3815bb2895ec010100919862c35620d9459767d
|
| 3 |
+
size 814365703
|
weights/TimesFM_L_512_T_96_HOM.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:49f862e58bc92993cf966b06facaddc79a7e8875d8a525561b5ae3fc3b67a1fc
|
| 3 |
+
size 814365703
|