Faizack committed commit 360970f (0 parents):

Initial Kronos-small custom deployment

Files changed (9):
  1. .gitattributes +1 -0
  2. .gitignore +4 -0
  3. README.md +120 -0
  4. config.json +13 -0
  5. inference.py +82 -0
  6. kronos.py +835 -0
  7. model.safetensors +3 -0
  8. module.py +581 -0
  9. requirements.txt +6 -0
.gitattributes ADDED
@@ -0,0 +1 @@
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__/
+ *.pyc
+ .env*
+ .DS_Store
README.md ADDED
@@ -0,0 +1,120 @@
+ ## `faizack/kronos-small-custom`
+
+ Custom Hugging Face model repo for deploying the **Kronos-small** time-series forecasting model as an Inference Endpoint.
+
+ This folder is structured so you can **zip it or `git init` it and push directly** to Hugging Face under your account [`faizack`](https://huggingface.co/faizack).
+
+ ### 1. Files expected in this repo
+
+ You will need the following files in the root of the Hugging Face repo:
+
+ - `config.json` – copied or downloaded from `NeoQuasar/Kronos-small`
+ - `model.safetensors` – weights from `NeoQuasar/Kronos-small`
+ - `kronos.py` – Kronos model, tokenizer, and `KronosPredictor` implementations (from the official GitHub repo)
+ - `module.py` – Transformer building blocks used by `kronos.py`
+ - `inference.py` – entrypoint used by Hugging Face Inference Endpoints (already provided here)
+ - `requirements.txt` – Python dependencies (already provided here)
+
+ This folder currently includes:
+
+ - `README.md` (this file)
+ - `inference.py`
+ - `requirements.txt`
+ - `.gitattributes` (Git LFS for safetensors)
+ - `.gitignore`
+
+ You still need to add:
+
+ - `config.json`
+ - `model.safetensors`
+ - `kronos.py`
+ - `module.py`
+
+ ### 2. How to prepare and push to Hugging Face
+
+ From this folder:
+
+ ```bash
+ cd kronos-small-custom
+
+ # (optional) initialize git
+ git init
+ git lfs install
+
+ # Log in to Hugging Face
+ huggingface-cli login
+
+ # Create the remote repo under your account
+ huggingface-cli repo create faizack/kronos-small-custom --type model
+
+ # Add the HF remote
+ git remote add origin https://huggingface.co/faizack/kronos-small-custom
+ ```
+
+ Now copy in the Kronos implementation and weights:
+
+ 1. From the official Kronos GitHub repo, copy:
+    - `kronos.py`
+    - `module.py`
+ 2. From `NeoQuasar/Kronos-small` on Hugging Face, download:
+    - `config.json`
+    - `model.safetensors`
+
+ Then commit and push:
+
+ ```bash
+ git add .
+ git commit -m "Initial Kronos-small custom deployment"
+ git push -u origin main
+ ```
+
+ ### 3. Inference contract
+
+ `inference.py` exposes a `predict(request)` function that Hugging Face Inference Endpoints will call.
+
+ Expected JSON body:
+
+ ```json
+ {
+   "inputs": {
+     "df": [
+       {"open": 1.0, "high": 1.1, "low": 0.9, "close": 1.05},
+       {"open": 1.05, "high": 1.12, "low": 1.0, "close": 1.08}
+     ],
+     "x_timestamp": ["2024-01-01T00:00:00Z", "2024-01-01T01:00:00Z"],
+     "y_timestamp": ["2024-01-01T02:00:00Z", "2024-01-01T03:00:00Z"],
+     "pred_len": 2,
+     "T": 1.0,
+     "top_p": 0.9,
+     "sample_count": 1
+   }
+ }
+ ```
+
+ Response structure (each row also carries the `volume` and `amount` columns that `KronosPredictor.predict` returns):
+
+ ```json
+ {
+   "predictions": [
+     {"open": ..., "high": ..., "low": ..., "close": ..., "volume": ..., "amount": ...},
+     {"open": ..., "high": ..., "low": ..., "close": ..., "volume": ..., "amount": ...}
+   ]
+ }
+ ```
+
+ You can adapt this contract as needed, as long as `predict` returns JSON-serializable data.
+
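Once an endpoint is live, the contract above can be exercised with a plain HTTP client. A minimal sketch — the endpoint URL, token, and the `build_payload` helper are hypothetical placeholders, not part of this repo:

```python
import json

# Hypothetical values -- substitute your own endpoint URL and token after deploying.
ENDPOINT_URL = "https://example.endpoints.huggingface.cloud"
HF_TOKEN = "hf_xxx"

def build_payload(candles, x_ts, y_ts, pred_len, **sampling):
    """Assemble the JSON body that inference.py's predict() expects."""
    return {
        "inputs": {
            "df": candles,
            "x_timestamp": x_ts,
            "y_timestamp": y_ts,
            "pred_len": pred_len,
            **sampling,
        }
    }

payload = build_payload(
    candles=[
        {"open": 1.0, "high": 1.1, "low": 0.9, "close": 1.05},
        {"open": 1.05, "high": 1.12, "low": 1.0, "close": 1.08},
    ],
    x_ts=["2024-01-01T00:00:00Z", "2024-01-01T01:00:00Z"],
    y_ts=["2024-01-01T02:00:00Z", "2024-01-01T03:00:00Z"],
    pred_len=2,
    T=1.0,
    top_p=0.9,
    sample_count=1,
)
body = json.dumps(payload)

# Posting it to the live endpoint would look like:
# requests.post(ENDPOINT_URL, data=body,
#               headers={"Authorization": f"Bearer {HF_TOKEN}"})
```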
config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "attn_dropout_p": 0.1,
+   "d_model": 512,
+   "ff_dim": 1024,
+   "ffn_dropout_p": 0.25,
+   "learn_te": true,
+   "n_heads": 8,
+   "n_layers": 8,
+   "resid_dropout_p": 0.25,
+   "s1_bits": 10,
+   "s2_bits": 10,
+   "token_dropout_p": 0.1
+ }
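The `s1_bits`/`s2_bits` fields set the hierarchical token vocabularies: each sub-token is a bit pattern, so the vocabulary size is `2**bits`, and `kronos.py` uses `s1_bits + s2_bits` as the quantizer codebook dimension. A quick sketch of what this config implies (the JSON here is an abbreviated copy of the file above):

```python
import json

# Abbreviated copy of the config.json above
config = json.loads('{"d_model": 512, "n_heads": 8, "n_layers": 8, "s1_bits": 10, "s2_bits": 10}')

s1_vocab = 2 ** config["s1_bits"]                     # vocabulary of the coarse (s1) tokens
s2_vocab = 2 ** config["s2_bits"]                     # vocabulary of the fine (s2) tokens
codebook_dim = config["s1_bits"] + config["s2_bits"]  # as computed in KronosTokenizer
```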
inference.py ADDED
@@ -0,0 +1,83 @@
+ import os
+ from typing import Any, Dict
+
+ import pandas as pd
+ import torch
+
+ from kronos import Kronos, KronosTokenizer, KronosPredictor  # type: ignore
+
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def _load_components(model_dir: str = "."):
+     """
+     Load tokenizer, model, and predictor from a local directory.
+
+     This is called once at module import time on HF Inference Endpoints.
+     """
+     tokenizer = KronosTokenizer.from_pretrained(model_dir)
+     model = Kronos.from_pretrained(model_dir).to(DEVICE)
+
+     max_context = int(os.getenv("KRONOS_MAX_CONTEXT", "512"))
+
+     predictor = KronosPredictor(
+         model=model,
+         tokenizer=tokenizer,
+         device=DEVICE,
+         max_context=max_context,
+     )
+
+     return tokenizer, model, predictor
+
+
+ TOKENIZER, MODEL, PREDICTOR = _load_components(".")
+
+
+ def predict(request: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Entry point for Hugging Face Inference Endpoints.
+
+     Expected input format:
+
+     {
+         "inputs": {
+             "df": [
+                 {"open": ..., "high": ..., "low": ..., "close": ...},
+                 ...
+             ],
+             "x_timestamp": [...],  # list of ISO8601 strings or timestamps
+             "y_timestamp": [...],  # list of ISO8601 strings or timestamps
+             "pred_len": 120,
+             "T": 1.0,              # optional
+             "top_p": 0.9,          # optional
+             "sample_count": 1      # optional
+         }
+     }
+     """
+     inputs = request.get("inputs", request)
+
+     df = pd.DataFrame(inputs["df"])
+     # Wrap in a Series so calc_time_stamps can use the .dt accessor
+     x_timestamp = pd.Series(pd.to_datetime(inputs["x_timestamp"]))
+     y_timestamp = pd.Series(pd.to_datetime(inputs["y_timestamp"]))
+
+     pred_len = int(inputs["pred_len"])
+     T = float(inputs.get("T", 1.0))
+     top_p = float(inputs.get("top_p", 0.9))
+     sample_count = int(inputs.get("sample_count", 1))
+
+     result_df = PREDICTOR.predict(
+         df=df,
+         x_timestamp=x_timestamp,
+         y_timestamp=y_timestamp,
+         pred_len=pred_len,
+         T=T,
+         top_p=top_p,
+         sample_count=sample_count,
+     )
+
+     # Return a plain dict for JSON serialization
+     return {
+         "predictions": result_df.to_dict(orient="records"),
+     }
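The line `inputs = request.get("inputs", request)` makes the handler tolerant of both a wrapped `{"inputs": {...}}` body and a bare body. A standalone sketch of that parsing step, with no model required (the `parse_inputs` helper is illustrative, not part of the repo):

```python
def parse_inputs(request: dict) -> dict:
    # Same fallback as inference.py: accept {"inputs": {...}} or a bare {...}.
    inputs = request.get("inputs", request)
    return {
        "pred_len": int(inputs["pred_len"]),
        "T": float(inputs.get("T", 1.0)),                  # optional, default 1.0
        "top_p": float(inputs.get("top_p", 0.9)),          # optional, default 0.9
        "sample_count": int(inputs.get("sample_count", 1)),  # optional, default 1
    }

wrapped = parse_inputs({"inputs": {"pred_len": 2}})
bare = parse_inputs({"pred_len": 2, "top_p": 0.5})
```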
kronos.py ADDED
@@ -0,0 +1,835 @@
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin
+ from torch import nn
+ from tqdm import trange
+
+ from module import *  # TransformerBlock, BSQuantizer, embeddings, heads, RMSNorm
+
+
+ class KronosTokenizer(nn.Module, PyTorchModelHubMixin):
+     """
+     KronosTokenizer module for tokenizing input data using a hybrid quantization approach.
+
+     This tokenizer utilizes a combination of encoder and decoder Transformer blocks
+     along with Binary Spherical Quantization (BSQuantizer) to compress and decompress input data.
+
+     Args:
+         d_in (int): Input dimension.
+         d_model (int): Model dimension.
+         n_heads (int): Number of attention heads.
+         ff_dim (int): Feed-forward dimension.
+         n_enc_layers (int): Number of encoder layers.
+         n_dec_layers (int): Number of decoder layers.
+         ffn_dropout_p (float): Dropout probability for feed-forward networks.
+         attn_dropout_p (float): Dropout probability for attention mechanisms.
+         resid_dropout_p (float): Dropout probability for residual connections.
+         s1_bits (int): Number of bits for the pre token in BSQuantizer.
+         s2_bits (int): Number of bits for the post token in BSQuantizer.
+         beta (float): Beta parameter for BSQuantizer.
+         gamma0 (float): Gamma0 parameter for BSQuantizer.
+         gamma (float): Gamma parameter for BSQuantizer.
+         zeta (float): Zeta parameter for BSQuantizer.
+         group_size (int): Group size parameter for BSQuantizer.
+     """
+
+     def __init__(
+         self,
+         d_in,
+         d_model,
+         n_heads,
+         ff_dim,
+         n_enc_layers,
+         n_dec_layers,
+         ffn_dropout_p,
+         attn_dropout_p,
+         resid_dropout_p,
+         s1_bits,
+         s2_bits,
+         beta,
+         gamma0,
+         gamma,
+         zeta,
+         group_size,
+     ):
+         super().__init__()
+         self.d_in = d_in
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.ff_dim = ff_dim
+         self.enc_layers = n_enc_layers
+         self.dec_layers = n_dec_layers
+         self.ffn_dropout_p = ffn_dropout_p
+         self.attn_dropout_p = attn_dropout_p
+         self.resid_dropout_p = resid_dropout_p
+
+         self.s1_bits = s1_bits
+         self.s2_bits = s2_bits
+         self.codebook_dim = s1_bits + s2_bits  # Total dimension of the codebook after quantization
+         self.embed = nn.Linear(self.d_in, self.d_model)
+         self.head = nn.Linear(self.d_model, self.d_in)
+
+         # Encoder Transformer blocks
+         self.encoder = nn.ModuleList(
+             [
+                 TransformerBlock(
+                     self.d_model,
+                     self.n_heads,
+                     self.ff_dim,
+                     self.ffn_dropout_p,
+                     self.attn_dropout_p,
+                     self.resid_dropout_p,
+                 )
+                 for _ in range(self.enc_layers - 1)
+             ]
+         )
+         # Decoder Transformer blocks
+         self.decoder = nn.ModuleList(
+             [
+                 TransformerBlock(
+                     self.d_model,
+                     self.n_heads,
+                     self.ff_dim,
+                     self.ffn_dropout_p,
+                     self.attn_dropout_p,
+                     self.resid_dropout_p,
+                 )
+                 for _ in range(self.dec_layers - 1)
+             ]
+         )
+         # Linear layer before quantization
+         self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim)
+         # Linear layer after quantization (pre part - s1 bits)
+         self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model)
+         # Linear layer after quantization (full codebook)
+         self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model)
+         # BSQuantizer module
+         self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size)
+
+     def forward(self, x):
+         """
+         Forward pass of the KronosTokenizer.
+
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).
+
+         Returns:
+             tuple: A tuple containing:
+                 - tuple: (z_pre, z) - Reconstructed outputs from the decoder using s1_bits and the full
+                   codebook respectively, both of shape (batch_size, seq_len, d_in).
+                 - torch.Tensor: bsq_loss - Loss from the BSQuantizer.
+                 - torch.Tensor: quantized - Quantized representation from the BSQuantizer.
+                 - torch.Tensor: z_indices - Indices from the BSQuantizer.
+         """
+         z = self.embed(x)
+
+         for layer in self.encoder:
+             z = layer(z)
+
+         z = self.quant_embed(z)  # (B, T, codebook_dim)
+
+         bsq_loss, quantized, z_indices = self.tokenizer(z)
+
+         # Extract the first part of the quantized representation (s1_bits)
+         quantized_pre = quantized[:, :, : self.s1_bits]
+         z_pre = self.post_quant_embed_pre(quantized_pre)
+
+         z = self.post_quant_embed(quantized)
+
+         # Decoder layers (for pre part - s1 bits)
+         for layer in self.decoder:
+             z_pre = layer(z_pre)
+         z_pre = self.head(z_pre)
+
+         # Decoder layers (for full codebook)
+         for layer in self.decoder:
+             z = layer(z)
+         z = self.head(z)
+
+         return (z_pre, z), bsq_loss, quantized, z_indices
+
+     def indices_to_bits(self, x, half=False):
+         """
+         Converts indices to bit representations and scales them.
+
+         Args:
+             x (torch.Tensor): Indices tensor.
+             half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False.
+
+         Returns:
+             torch.Tensor: Bit representation tensor.
+         """
+         if half:
+             x1 = x[0]  # x is a pair of index tensors when half is True
+             x2 = x[1]
+             # Mask for bit extraction over half the codebook dimension
+             mask = 2 ** torch.arange(self.codebook_dim // 2, device=x1.device, dtype=torch.long)
+             x1 = (x1.unsqueeze(-1) & mask) != 0  # Extract bits for the first half
+             x2 = (x2.unsqueeze(-1) & mask) != 0  # Extract bits for the second half
+             x = torch.cat([x1, x2], dim=-1)  # Concatenate the bit representations
+         else:
+             mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long)
+             x = (x.unsqueeze(-1) & mask) != 0  # Extract bits
+
+         x = x.float() * 2 - 1  # Convert boolean to bipolar (-1, 1)
+         q_scale = 1.0 / (self.codebook_dim**0.5)  # Scaling factor
+         x = x * q_scale
+         return x
+
+     def encode(self, x, half=False):
+         """
+         Encodes the input data into quantized indices.
+
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).
+             half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False.
+
+         Returns:
+             torch.Tensor: Quantized indices from the BSQuantizer.
+         """
+         z = self.embed(x)
+         for layer in self.encoder:
+             z = layer(z)
+         z = self.quant_embed(z)
+
+         bsq_loss, quantized, z_indices = self.tokenizer(z, half)
+         return z_indices
+
+     def decode(self, x, half=False):
+         """
+         Decodes quantized indices back to the input data space.
+
+         Args:
+             x (torch.Tensor): Quantized indices tensor.
+             half (bool, optional): Whether the indices were generated with half quantization. Defaults to False.
+
+         Returns:
+             torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in).
+         """
+         quantized = self.indices_to_bits(x, half)
+         z = self.post_quant_embed(quantized)
+         for layer in self.decoder:
+             z = layer(z)
+         z = self.head(z)
+         return z
+
+
+ class Kronos(nn.Module, PyTorchModelHubMixin):
+     """
+     Kronos model.
+
+     Args:
+         s1_bits (int): Number of bits for pre tokens.
+         s2_bits (int): Number of bits for post tokens.
+         n_layers (int): Number of Transformer blocks.
+         d_model (int): Dimension of the model's embeddings and hidden states.
+         n_heads (int): Number of attention heads in the MultiheadAttention layers.
+         ff_dim (int): Dimension of the feedforward network in the Transformer blocks.
+         ffn_dropout_p (float): Dropout probability for the feedforward network.
+         attn_dropout_p (float): Dropout probability for the attention layers.
+         resid_dropout_p (float): Dropout probability for residual connections.
+         token_dropout_p (float): Dropout probability for token embeddings.
+         learn_te (bool): Whether to use learnable temporal embeddings.
+     """
+
+     def __init__(
+         self,
+         s1_bits,
+         s2_bits,
+         n_layers,
+         d_model,
+         n_heads,
+         ff_dim,
+         ffn_dropout_p,
+         attn_dropout_p,
+         resid_dropout_p,
+         token_dropout_p,
+         learn_te,
+     ):
+         super().__init__()
+         self.s1_bits = s1_bits
+         self.s2_bits = s2_bits
+         self.n_layers = n_layers
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.learn_te = learn_te
+         self.ff_dim = ff_dim
+         self.ffn_dropout_p = ffn_dropout_p
+         self.attn_dropout_p = attn_dropout_p
+         self.resid_dropout_p = resid_dropout_p
+         self.token_dropout_p = token_dropout_p
+
+         self.s1_vocab_size = 2**self.s1_bits
+         self.token_drop = nn.Dropout(self.token_dropout_p)
+         self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model)
+         self.time_emb = TemporalEmbedding(self.d_model, self.learn_te)
+         self.transformer = nn.ModuleList(
+             [
+                 TransformerBlock(
+                     self.d_model,
+                     self.n_heads,
+                     self.ff_dim,
+                     self.ffn_dropout_p,
+                     self.attn_dropout_p,
+                     self.resid_dropout_p,
+                 )
+                 for _ in range(self.n_layers)
+             ]
+         )
+         self.norm = RMSNorm(self.d_model)
+         self.dep_layer = DependencyAwareLayer(self.d_model)
+         self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             nn.init.xavier_normal_(module.weight)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model**-0.5)
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.ones_(module.weight)
+             nn.init.zeros_(module.bias)
+         elif isinstance(module, RMSNorm):
+             nn.init.ones_(module.weight)
+
+     def forward(
+         self,
+         s1_ids,
+         s2_ids,
+         stamp=None,
+         padding_mask=None,
+         use_teacher_forcing=False,
+         s1_targets=None,
+     ):
+         """
+         Args:
+             s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
+             s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
+             stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
+             padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
+             use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False.
+             s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None.
+
+         Returns:
+             Tuple[torch.Tensor, torch.Tensor]:
+                 - s1_logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
+                 - s2_logits: Logits for s2 token predictions, conditioned on s1. Shape: [batch_size, seq_len, s2_vocab_size]
+         """
+         x = self.embedding([s1_ids, s2_ids])
+         if stamp is not None:
+             time_embedding = self.time_emb(stamp)
+             x = x + time_embedding
+         x = self.token_drop(x)
+
+         for layer in self.transformer:
+             x = layer(x, key_padding_mask=padding_mask)
+
+         x = self.norm(x)
+
+         s1_logits = self.head(x)
+
+         if use_teacher_forcing:
+             sibling_embed = self.embedding.emb_s1(s1_targets)
+         else:
+             s1_probs = F.softmax(s1_logits.detach(), dim=-1)
+             sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape)
+             sibling_embed = self.embedding.emb_s1(sample_s1_ids)
+
+         # Dependency-aware layer: condition s2 decoding on the s1 embeddings
+         x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask)
+         s2_logits = self.head.cond_forward(x2)
+         return s1_logits, s2_logits
+
+     def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):
+         """
+         Decodes only the s1 tokens.
+
+         This method performs a forward pass to predict only s1 tokens. It returns the s1 logits
+         and the context representation from the Transformer, which can be used for subsequent s2 decoding.
+
+         Args:
+             s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
+             s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
+             stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
+             padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
+
+         Returns:
+             Tuple[torch.Tensor, torch.Tensor]:
+                 - s1_logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
+                 - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model]
+         """
+         x = self.embedding([s1_ids, s2_ids])
+         if stamp is not None:
+             time_embedding = self.time_emb(stamp)
+             x = x + time_embedding
+         x = self.token_drop(x)
+
+         for layer in self.transformer:
+             x = layer(x, key_padding_mask=padding_mask)
+
+         x = self.norm(x)
+
+         s1_logits = self.head(x)
+         return s1_logits, x
+
+     def decode_s2(self, context, s1_ids, padding_mask=None):
+         """
+         Decodes the s2 tokens, conditioned on the context and s1 tokens.
+
+         This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`)
+         and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens.
+
+         Args:
+             context (torch.Tensor): Context representation from the Transformer (output of decode_s1).
+                 Shape: [batch_size, seq_len, d_model]
+             s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
+             padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
+
+         Returns:
+             torch.Tensor: s2 logits. Shape: [batch_size, seq_len, s2_vocab_size]
+         """
+         sibling_embed = self.embedding.emb_s1(s1_ids)
+         x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask)
+         return self.head.cond_forward(x2)
+
+
+ def top_k_top_p_filtering(
+     logits,
+     top_k: int = 0,
+     top_p: float = 1.0,
+     filter_value: float = -float("Inf"),
+     min_tokens_to_keep: int = 1,
+ ):
+     """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
+
+     Args:
+         logits: logits distribution of shape (batch size, vocabulary size).
+         top_k: if > 0, keep only the top k tokens with highest probability (top-k filtering).
+         top_p: if < 1.0, keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+             Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).
+
+     Keeps at least min_tokens_to_keep tokens per batch example in the output.
+     From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+     """
+     if top_k > 0:
+         top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
+         # Remove all tokens with a probability less than the last token of the top-k
+         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+         logits[indices_to_remove] = filter_value
+
+     if top_p < 1.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+         # Remove tokens with cumulative probability above the threshold
+         sorted_indices_to_remove = cumulative_probs > top_p
+         if min_tokens_to_keep > 1:
+             # Keep at least min_tokens_to_keep (the shift below preserves the first kept token)
+             sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+         # Shift the indices to the right to keep also the first token above the threshold
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+
+         # Scatter sorted tensors back to the original indexing
+         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+         logits[indices_to_remove] = filter_value
+
+     return logits
+
+
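To see numerically what the nucleus (top-p) branch does, here is a NumPy re-implementation sketch of the same idea for the 1-D case (sort descending, take the cumulative softmax, mask everything past the threshold while keeping the first token above it); `top_p_filter` is an illustrative helper, not part of `kronos.py`:

```python
import numpy as np

def top_p_filter(logits, top_p=0.9, filter_value=-np.inf):
    """NumPy sketch of the nucleus-filtering branch (1-D logits)."""
    order = np.argsort(logits)[::-1]          # indices sorted by descending logit
    sorted_logits = logits[order]
    probs = np.exp(sorted_logits - sorted_logits.max())
    probs /= probs.sum()                      # softmax over the sorted logits
    cum = np.cumsum(probs)
    remove = cum > top_p
    remove[1:] = remove[:-1].copy()           # shift right: keep first token over threshold
    remove[0] = False
    out = logits.astype(float).copy()
    out[order[remove]] = filter_value
    return out

logits = np.array([2.0, 1.0, 0.1, -1.0])
filtered = top_p_filter(logits, top_p=0.9)
# Here only the lowest-probability token falls outside the 0.9 nucleus.
```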
+ def sample_from_logits(
+     logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True
+ ):
+     logits = logits / temperature
+     # Guard against None before comparing, then apply the filters
+     if (top_k is not None and top_k > 0) or (top_p is not None and top_p < 1.0):
+         logits = top_k_top_p_filtering(
+             logits,
+             top_k=top_k if top_k is not None else 0,
+             top_p=top_p if top_p is not None else 1.0,
+         )
+
+     probs = F.softmax(logits, dim=-1)
+
+     if not sample_logits:
+         _, x = torch.topk(probs, k=1, dim=-1)  # greedy: take the most likely token
+     else:
+         x = torch.multinomial(probs, num_samples=1)
+
+     return x
+
+
+ def auto_regressive_inference(
+     tokenizer,
+     model,
+     x,
+     x_stamp,
+     y_stamp,
+     max_context,
+     pred_len,
+     clip=5,
+     T=1.0,
+     top_k=0,
+     top_p=0.99,
+     sample_count=5,
+     verbose=False,
+ ):
+     with torch.no_grad():
+         batch_size = x.size(0)
+         initial_seq_len = x.size(1)
+         x = torch.clip(x, -clip, clip)
+
+         device = x.device
+         # Repeat each series sample_count times so the samples are drawn in parallel
+         x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device)
+         x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device)
+         y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)
+
+         x_token = tokenizer.encode(x, half=True)
+
+         def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
+             if current_seq_len <= max_context - pred_step:
+                 return torch.cat([x_stamp, y_stamp[:, :pred_step, :]], dim=1)
+             else:
+                 start_idx = max_context - pred_step
+                 return torch.cat([x_stamp[:, -start_idx:, :], y_stamp[:, :pred_step, :]], dim=1)
+
+         ran = trange if verbose else range
+         for i in ran(pred_len):
+             current_seq_len = initial_seq_len + i
+
+             if current_seq_len <= max_context:
+                 input_tokens = x_token
+             else:
+                 input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
+
+             current_stamp = get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, i)
+
+             s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp)
+             s1_logits = s1_logits[:, -1, :]
+             sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
+
+             s2_logits = model.decode_s2(context, sample_pre)
+             s2_logits = s2_logits[:, -1, :]
+             sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
+
+             x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
+             x_token[1] = torch.cat([x_token[1], sample_post], dim=1)
+
+             torch.cuda.empty_cache()
+
+         input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
+         z = tokenizer.decode(input_tokens, half=True)
+         z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
+         preds = z.cpu().numpy()
+         preds = np.mean(preds, axis=1)  # Average over the parallel samples
+
+     return preds
+
+
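The generation loop above keeps at most `max_context` tokens of history: once the sequence grows past the window, only the most recent tokens feed the next step (`t[:, -max_context:]`). A list-based sketch of that sliding window, with an illustrative helper name:

```python
def sliding_context(tokens, max_context):
    """Mirror of `input_tokens = t[:, -max_context:]` for a plain list."""
    return tokens[-max_context:]

history = list(range(6))       # pretend these are generated token ids
window = sliding_context(history, max_context=4)
# Only the most recent 4 tokens survive: [2, 3, 4, 5]
```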
+ def calc_time_stamps(x_timestamp):
+     time_df = pd.DataFrame()
+     time_df["minute"] = x_timestamp.dt.minute
+     time_df["hour"] = x_timestamp.dt.hour
+     time_df["weekday"] = x_timestamp.dt.weekday
+     time_df["day"] = x_timestamp.dt.day
+     time_df["month"] = x_timestamp.dt.month
+     return time_df
+
+
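`calc_time_stamps` turns a datetime `Series` into the five calendar features consumed by the temporal embedding. The same transformation, stated standalone:

```python
import pandas as pd

ts = pd.Series(pd.to_datetime(["2024-01-01T09:30:00", "2024-01-01T10:30:00"]))

time_df = pd.DataFrame()
time_df["minute"] = ts.dt.minute
time_df["hour"] = ts.dt.hour
time_df["weekday"] = ts.dt.weekday   # Monday == 0
time_df["day"] = ts.dt.day
time_df["month"] = ts.dt.month
```

Note that the `.dt` accessor exists on a `Series`, not on a bare `DatetimeIndex`, which is why the timestamps are wrapped in `pd.Series` before being handed to the predictor.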
578
+ class KronosPredictor:
579
+
580
+ def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5):
581
+ self.tokenizer = tokenizer
582
+ self.model = model
583
+ self.max_context = max_context
584
+ self.clip = clip
585
+ self.price_cols = ["open", "high", "low", "close"]
586
+ self.vol_col = "volume"
587
+ self.amt_vol = "amount"
588
+ self.time_cols = ["minute", "hour", "weekday", "day", "month"]
589
+ self.device = device
590
+
591
+ self.tokenizer = self.tokenizer.to(self.device)
592
+ self.model = self.model.to(self.device)
593
+
594
+ def generate(
595
+ self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose
596
+ ):
597
+
598
+ x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device)
599
+ x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(
600
+ self.device
601
+ )
602
+ y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(
603
+ self.device
604
+ )
605
+
606
+ preds = auto_regressive_inference(
607
+ self.tokenizer,
608
+ self.model,
609
+ x_tensor,
610
+ x_stamp_tensor,
611
+ y_stamp_tensor,
612
+ self.max_context,
613
+ pred_len,
614
+ self.clip,
615
+ T,
616
+ top_k,
617
+ top_p,
618
+ sample_count,
619
+ verbose,
620
+ )
621
+ preds = preds[:, -pred_len:, :]
622
+ return preds
623
+
624
+ def predict(
625
+ self,
626
+ df,
627
+ x_timestamp,
628
+ y_timestamp,
629
+ pred_len,
630
+ T=1.0,
631
+ top_k=0,
632
+ top_p=0.9,
633
+ sample_count=1,
634
+ verbose=True,
635
+ ):
636
+
637
+ if not isinstance(df, pd.DataFrame):
638
+ raise ValueError("Input must be a pandas DataFrame.")
639
+
640
+ if not all(col in df.columns for col in self.price_cols):
641
+ raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.")
642
+
643
+ df = df.copy()
644
+ if self.vol_col not in df.columns:
645
+ df[self.vol_col] = 0.0 # Fill missing volume with zeros
646
+ df[self.amt_vol] = 0.0 # Fill missing amount with zeros
647
+ if self.amt_vol not in df.columns and self.vol_col in df.columns:
648
+ df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)
649
+
650
+ if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
651
+ raise ValueError(
652
+ "Input DataFrame contains NaN values in price or volume columns."
653
+ )
654
+
655
+ x_time_df = calc_time_stamps(x_timestamp)
656
+ y_time_df = calc_time_stamps(y_timestamp)
657
+
658
+ x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32)
659
+ x_stamp = x_time_df.values.astype(np.float32)
660
+ y_stamp = y_time_df.values.astype(np.float32)
661
+
662
+ x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
663
+
664
+ x = (x - x_mean) / (x_std + 1e-5)
665
+ x = np.clip(x, -self.clip, self.clip)
666
+
667
+ x = x[np.newaxis, :]
668
+ x_stamp = x_stamp[np.newaxis, :]
669
+ y_stamp = y_stamp[np.newaxis, :]
670
+
671
+ preds = self.generate(
672
+ x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose
673
+ )
674
+
675
+ preds = preds.squeeze(0)
676
+ preds = preds * (x_std + 1e-5) + x_mean
677
+
678
+ pred_df = pd.DataFrame(
679
+ preds,
680
+ columns=self.price_cols + [self.vol_col, self.amt_vol],
681
+ index=y_timestamp,
682
+ )
683
+ return pred_df
684
+
+     def predict_batch(
+         self,
+         df_list,
+         x_timestamp_list,
+         y_timestamp_list,
+         pred_len,
+         T=1.0,
+         top_k=0,
+         top_p=0.9,
+         sample_count=1,
+         verbose=True,
+     ):
+         """
+         Perform parallel (batch) prediction on multiple time series. All series must
+         have the same historical length and the same prediction length (pred_len).
+
+         Args:
+             df_list (List[pd.DataFrame]): Input DataFrames, each containing the price columns and optional volume/amount columns.
+             x_timestamp_list (List[pd.DatetimeIndex or Series]): Timestamps of the historical data; each must match the row count of the corresponding DataFrame.
+             y_timestamp_list (List[pd.DatetimeIndex or Series]): Future prediction timestamps; each must have length pred_len.
+             pred_len (int): Number of prediction steps.
+             T (float): Sampling temperature.
+             top_k (int): Top-k filtering threshold.
+             top_p (float): Top-p (nucleus) sampling threshold.
+             sample_count (int): Number of parallel samples per series, averaged internally.
+             verbose (bool): Whether to display autoregressive progress.
+
+         Returns:
+             List[pd.DataFrame]: Prediction results in the same order as the input; each DataFrame contains
+             `open, high, low, close, volume, amount` columns, indexed by the corresponding `y_timestamp`.
+         """
+         # Basic validation
+         if (
+             not isinstance(df_list, (list, tuple))
+             or not isinstance(x_timestamp_list, (list, tuple))
+             or not isinstance(y_timestamp_list, (list, tuple))
+         ):
+             raise ValueError(
+                 "df_list, x_timestamp_list, y_timestamp_list must be list or tuple types."
+             )
+         if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)):
+             raise ValueError(
+                 "df_list, x_timestamp_list, y_timestamp_list must have consistent lengths."
+             )
+
+         num_series = len(df_list)
+
+         x_list = []
+         x_stamp_list = []
+         y_stamp_list = []
+         means = []
+         stds = []
+         seq_lens = []
+         y_lens = []
+
+         for i in range(num_series):
+             df = df_list[i]
+             if not isinstance(df, pd.DataFrame):
+                 raise ValueError(f"Input at index {i} is not a pandas DataFrame.")
+             if not all(col in df.columns for col in self.price_cols):
+                 raise ValueError(
+                     f"DataFrame at index {i} is missing price columns {self.price_cols}."
+                 )
+
+             df = df.copy()
+             if self.vol_col not in df.columns:
+                 df[self.vol_col] = 0.0
+                 df[self.amt_vol] = 0.0
+             if self.amt_vol not in df.columns and self.vol_col in df.columns:
+                 df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)
+
+             if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
+                 raise ValueError(
+                     f"DataFrame at index {i} contains NaN values in price or volume columns."
+                 )
+
+             x_timestamp = x_timestamp_list[i]
+             y_timestamp = y_timestamp_list[i]
+
+             x_time_df = calc_time_stamps(x_timestamp)
+             y_time_df = calc_time_stamps(y_timestamp)
+
+             x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(
+                 np.float32
+             )
+             x_stamp = x_time_df.values.astype(np.float32)
+             y_stamp = y_time_df.values.astype(np.float32)
+
+             if x.shape[0] != x_stamp.shape[0]:
+                 raise ValueError(
+                     f"Inconsistent lengths at index {i}: x has {x.shape[0]} vs x_stamp has {x_stamp.shape[0]}."
+                 )
+             if y_stamp.shape[0] != pred_len:
+                 raise ValueError(
+                     f"y_timestamp length at index {i} should equal pred_len={pred_len}, got {y_stamp.shape[0]}."
+                 )
+
+             # Per-series normalization statistics, kept for denormalization
+             x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
+             x_norm = (x - x_mean) / (x_std + 1e-5)
+             x_norm = np.clip(x_norm, -self.clip, self.clip)
+
+             x_list.append(x_norm)
+             x_stamp_list.append(x_stamp)
+             y_stamp_list.append(y_stamp)
+             means.append(x_mean)
+             stds.append(x_std)
+
+             seq_lens.append(x_norm.shape[0])
+             y_lens.append(y_stamp.shape[0])
+
+         # Batch processing requires one shared historical length and one shared prediction length
+         if len(set(seq_lens)) != 1:
+             raise ValueError(
+                 f"Parallel prediction requires all series to have consistent historical lengths, got: {seq_lens}"
+             )
+         if len(set(y_lens)) != 1:
+             raise ValueError(
+                 f"Parallel prediction requires all series to have consistent prediction lengths, got: {y_lens}"
+             )
+
+         x_batch = np.stack(x_list, axis=0).astype(np.float32)  # (B, seq_len, feat)
+         x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32)  # (B, seq_len, time_feat)
+         y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32)  # (B, pred_len, time_feat)
+
+         preds = self.generate(
+             x_batch,
+             x_stamp_batch,
+             y_stamp_batch,
+             pred_len,
+             T,
+             top_k,
+             top_p,
+             sample_count,
+             verbose,
+         )
+         # preds: (B, pred_len, feat)
+
+         pred_dfs = []
+         for i in range(num_series):
+             # Denormalize each series with its own statistics
+             preds_i = preds[i] * (stds[i] + 1e-5) + means[i]
+             pred_df = pd.DataFrame(
+                 preds_i,
+                 columns=self.price_cols + [self.vol_col, self.amt_vol],
+                 index=y_timestamp_list[i],
+             )
+             pred_dfs.append(pred_df)
+
+         return pred_dfs
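The `predict` method above z-scores each column of the input window, clips extremes, and inverts the transform on the sampled outputs. A minimal standalone sketch of that round trip (the helper names are illustrative, not part of the repo):

```python
import numpy as np

def normalize(x, clip=5.0, eps=1e-5):
    """Per-column z-score normalization with clipping, as done in predict()."""
    mean, std = np.mean(x, axis=0), np.std(x, axis=0)
    x_norm = np.clip((x - mean) / (std + eps), -clip, clip)
    return x_norm, mean, std

def denormalize(x_norm, mean, std, eps=1e-5):
    """Inverse transform applied to the model's sampled outputs."""
    return x_norm * (std + eps) + mean

x = np.array([[100.0, 101.0], [102.0, 103.0], [104.0, 105.0]], dtype=np.float32)
x_norm, mean, std = normalize(x)
x_back = denormalize(x_norm, mean, std)
assert np.allclose(x_back, x, atol=1e-3)
```

Note that the round trip is only exact for values inside the clip range; clipped outliers cannot be recovered.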
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b082dfcbd8e8c142a725c8bbb99781802f38fec81210e13479effb32b3c3e020
+ size 98980656
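The file above is only a Git LFS pointer, not the weights themselves: `oid` is the SHA-256 digest of the real file and `size` its byte length. If you download the resolved `model.safetensors` yourself, a pointer-based integrity check could look like this (the helper and file names are illustrative):

```python
import hashlib
import os

def verify_lfs_object(path, expected_sha256, expected_size):
    """Check a downloaded file against the oid/size fields of its LFS pointer."""
    if os.path.getsize(path) != expected_size:
        return False
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest() == expected_sha256

# Demonstration on a throwaway file
with open("demo.bin", "wb") as f:
    f.write(b"hello")
ok = verify_lfs_object("demo.bin", hashlib.sha256(b"hello").hexdigest(), 5)
assert ok
```

For the real weights you would pass the `oid` and `size` values from the pointer above.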
module.py ADDED
@@ -0,0 +1,581 @@
+ import math
+
+ from einops import rearrange, reduce
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Function
+ import torch.nn.functional as F
+
+
+ class DifferentiableEntropyFunction(Function):
+     @staticmethod
+     def forward(ctx, zq, basis, K, eps):
+         zb = (zq + 1) / 2
+         zi = ((zb * basis).sum(-1)).to(torch.int64)
+         cnt = torch.scatter_reduce(
+             torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype),
+             0,
+             zi.flatten(),
+             torch.ones_like(zi.flatten()).to(zq.dtype),
+             'sum',
+         )
+         prob = (cnt + eps) / (cnt + eps).sum()
+         H = -(prob * torch.log(prob)).sum()
+         ctx.save_for_backward(zq, zi, prob)
+         ctx.K = K
+         return H
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         zq, zi, prob = ctx.saved_tensors
+         grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
+         reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
+         grad_input = reord_grad.unsqueeze(-1) * zq
+         # One gradient per forward input: zq, basis, K, eps
+         return grad_input, None, None, None
+
+
+ def codebook_entropy(zq, basis, K, eps=1e-4):
+     return DifferentiableEntropyFunction.apply(zq, basis, K, eps)
+
+
+ class BinarySphericalQuantizer(nn.Module):
+     def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
+                  input_format='bchw',
+                  soft_entropy=True, group_size=9,
+                  persample_entropy_compute='analytical',
+                  cb_entropy_compute='group',
+                  l2_norm=True,
+                  inv_temperature=1):
+         """
+         Paper link: https://arxiv.org/pdf/2406.07548.pdf
+         This follows the official implementation of the BinarySphericalQuantizer.
+         """
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.beta = beta  # loss weight for commit loss
+         self.gamma0 = gamma0  # loss weight for per-sample entropy penalty
+         self.gamma = gamma  # loss weight for codebook entropy penalty
+         self.zeta = zeta  # loss weight for the overall entropy penalty
+         self.input_format = input_format
+         assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
+         self.num_groups = self.embed_dim // group_size
+         self.group_size = group_size
+         assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'"
+         assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'"
+         self.persample_entropy_compute = persample_entropy_compute
+         self.cb_entropy_compute = cb_entropy_compute
+         self.l2_norm = l2_norm
+         self.inv_temperature = inv_temperature
+
+         self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1))
+         self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1))
+
+         self.num_dimensions = 2 ** embed_dim
+         self.bits_per_index = embed_dim
+
+         # We only need to keep the codebook portion up to the group size,
+         # because we approximate the H loss with this subcode.
+         group_codes = torch.arange(2 ** self.group_size)
+         group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
+         self.register_buffer('group_codebook', group_codebook, persistent=False)
+
+         self.soft_entropy = soft_entropy  # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf
+
+     def quantize(self, z):
+         assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"
+         zhat = torch.where(z > 0,
+                            torch.tensor(1, dtype=z.dtype, device=z.device),
+                            torch.tensor(-1, dtype=z.dtype, device=z.device))
+         # Straight-through estimator: forward uses zhat, backward passes gradients to z
+         return z + (zhat - z).detach()
+
+     def forward(self, z):
+         # if self.input_format == 'bchw':
+         #     z = rearrange(z, 'b c h w -> b h w c')
+         zq = self.quantize(z)
+
+         indices = self.codes_to_indexes(zq.detach())
+         group_indices = self.codes_to_group_indexes(zq.detach())
+         if not self.training:
+             used_codes = torch.unique(indices, return_counts=False)
+         else:
+             used_codes = None
+
+         q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
+
+         if self.soft_entropy:
+             persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
+             entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
+         else:
+             zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
+             persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
+             cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
+             entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
+             avg_prob = None  # not computed on the hard-entropy path
+
+         zq = zq * q_scale
+
+         # commit loss
+         commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
+
+         # if self.input_format == 'bchw':
+         #     zq = rearrange(zq, 'b h w c -> b c h w')
+
+         return (
+             zq,
+             commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
+             {"H": cb_entropy, "used_codes": used_codes, "indices": indices,
+              "group_indices": group_indices, "avg_prob": avg_prob}
+         )
+
+     def soft_entropy_loss(self, z):
+         # If we divide the code into subgroups of size group_size, the codebook for
+         # each subgroup has size 2 ** group_size; the sub-code is the last
+         # group_size bits of the full code.
+         group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
+         divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)
+
+         # Distance between divided_z and the codebook for each subgroup
+         distance = -2 * torch.einsum('... g c, d c -> ... g d', divided_z, group_code_book)
+         prob = (-distance * self.inv_temperature).softmax(dim=-1)
+         if self.persample_entropy_compute == 'analytical':
+             if self.l2_norm:
+                 p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
+             else:
+                 p = torch.sigmoid(-4 * z * self.inv_temperature)
+             prob = torch.stack([p, 1 - p], dim=-1)
+         per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
+
+         # Macro average of the probability of each subgroup
+         avg_prob = reduce(prob, '... g d -> g d', 'mean')
+         codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
+
+         # The entropy approximation is the sum of the subgroup entropies
+         return per_sample_entropy, codebook_entropy.sum(), avg_prob
+
+     def get_hard_per_sample_entropy(self, zb_by_sample):
+         probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
+         persample_entropy = -probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
+         persample_entropy = persample_entropy.sum(-1)
+         return persample_entropy.mean()
+
+     def codes_to_indexes(self, zhat):
+         """Converts a `code` to an index in the codebook.
+
+         Args:
+             zhat: A tensor of shape (B, ..., C) containing codes; values must be in {-1, 1}.
+         """
+         assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
+         return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)
+
+     def codes_to_group_indexes(self, zhat):
+         """Converts a `code` to a list of per-group indexes in the codebook.
+
+         Args:
+             zhat: A tensor of shape (B, ..., C) containing codes; values must be in {-1, 1}.
+         """
+         zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
+         return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)
+
+     def indexes_to_codes(self, indices):
+         """Inverse of `codes_to_indexes`."""
+         indices = indices.unsqueeze(-1)
+         codes_non_centered = torch.remainder(
+             torch.floor_divide(indices, self.basis), 2
+         )
+         return codes_non_centered * 2 - 1
+
+     def group_indexes_to_codes(self, group_indices):
+         """Inverse of `codes_to_group_indexes`."""
+         group_indices = group_indices.unsqueeze(-1)
+         codes_non_centered = torch.remainder(
+             torch.floor_divide(group_indices, self.group_basis), 2
+         )
+         codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
+         return codes_non_centered * 2 - 1
+
+     def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
+         if normalize:
+             probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
+         else:
+             probs = count
+         H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
+         return H
+
+     def get_group_codebook_entry(self, group_indices):
+         z_q = self.group_indexes_to_codes(group_indices)
+         q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
+         z_q = z_q * q_scale
+         if self.input_format == 'bchw':
+             h = w = int(z_q.shape[1] ** 0.5)
+             assert h * w == z_q.shape[1], 'Invalid sequence length'
+             z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
+         return z_q
+
+     def get_codebook_entry(self, indices):
+         z_q = self.indexes_to_codes(indices)
+         q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
+         z_q = z_q * q_scale
+         if self.input_format == 'bchw':
+             h = w = int(z_q.shape[1] ** 0.5)
+             assert h * w == z_q.shape[1], 'Invalid sequence length'
+             z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
+         return z_q
+
+
+ class BSQuantizer(nn.Module):
+     def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
+         super().__init__()
+         self.codebook_dim = s1_bits + s2_bits
+         self.s1_bits = s1_bits
+         self.s2_bits = s2_bits
+         self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size)
+
+     def bits_to_indices(self, bits):
+         bits = (bits >= 0).to(torch.long)
+         indices = 2 ** torch.arange(
+             0,
+             bits.shape[-1],
+             1,
+             dtype=torch.long,
+             device=bits.device,
+         )
+         return (bits * indices).sum(-1)
+
+     def forward(self, z, half=False):
+         z = F.normalize(z, dim=-1)
+         quantized, bsq_loss, metrics = self.bsq(z)
+         if half:
+             q_pre = quantized[:, :, :self.s1_bits]
+             q_post = quantized[:, :, self.s1_bits:]
+             z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)]
+         else:
+             z_indices = self.bits_to_indices(quantized)
+         return bsq_loss, quantized, z_indices
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0):
+         super().__init__()
+         self.w1 = nn.Linear(d_model, ff_dim, bias=False)
+         self.w3 = nn.Linear(d_model, ff_dim, bias=False)
+         self.w2 = nn.Linear(ff_dim, d_model, bias=False)
+         self.ffn_dropout = nn.Dropout(ffn_dropout_p)
+
+     def forward(self, x):
+         # SwiGLU feed-forward
+         return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
+
+
+ class RotaryPositionalEmbedding(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+         self.seq_len_cached = None
+         self.cos_cached = None
+         self.sin_cached = None
+
+     def _update_cos_sin_cache(self, x, seq_len):
+         if seq_len != self.seq_len_cached:
+             self.seq_len_cached = seq_len
+             t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
+             freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+             self.cos_cached = emb.cos()[None, None, :, :]
+             self.sin_cached = emb.sin()[None, None, :, :]
+         return self.cos_cached, self.sin_cached
+
+     def forward(self, q, k):
+         cos, sin = self._update_cos_sin_cache(q, q.shape[-2])
+         return (
+             (q * cos) + (self._rotate_half(q) * sin),
+             (k * cos) + (self._rotate_half(k) * sin),
+         )
+
+     def _rotate_half(self, x):
+         x1, x2 = x.chunk(2, dim=-1)
+         return torch.cat((-x2, x1), dim=-1)
+
+
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, training=True) -> torch.Tensor:
+     L, S = query.size(-2), key.size(-2)
+     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+     attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
+
+     if is_causal:
+         assert attn_mask is None
+         temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).to(query.device)
+         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+         attn_bias = attn_bias.to(query.dtype)
+
+     attn_weight = query @ key.transpose(-2, -1) * scale_factor
+     attn_weight += attn_bias
+
+     if attn_mask is not None:
+         attn_mask_bias = torch.zeros_like(attn_weight)
+         if attn_mask.dtype == torch.bool:
+             attn_mask_bias.masked_fill_(attn_mask, float("-inf"))
+         else:
+             attn_mask_bias += attn_mask
+         attn_weight += attn_mask_bias
+
+     attn_weight = torch.softmax(attn_weight, dim=-1)
+     attn_weight = torch.dropout(attn_weight, dropout_p, train=training)
+     return attn_weight @ value
+
+
+ class MultiHeadAttentionWithRoPE(nn.Module):
+     def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
+         super().__init__()
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+
+         self.q_proj = nn.Linear(d_model, d_model)
+         self.k_proj = nn.Linear(d_model, d_model)
+         self.v_proj = nn.Linear(d_model, d_model)
+         self.out_proj = nn.Linear(d_model, d_model)
+         self.rotary = RotaryPositionalEmbedding(self.head_dim)
+         self.attn_dropout_p = attn_dropout_p
+         self.resid_dropout = nn.Dropout(resid_dropout_p)
+
+     def forward(self, x, key_padding_mask=None):
+         batch_size, seq_len, _ = x.shape
+
+         q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+
+         q, k = self.rotary(q, k)
+
+         if key_padding_mask is not None:
+             attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]
+             attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1)  # [batch, n_heads, q_len, k_len]
+         else:
+             attn_mask = None
+
+         attn_output = scaled_dot_product_attention(
+             q, k, v,
+             attn_mask=attn_mask,
+             dropout_p=self.attn_dropout_p,
+             is_causal=True,
+             training=self.training
+         )
+
+         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
+         return self.resid_dropout(self.out_proj(attn_output))
+
+
+ class MultiHeadCrossAttentionWithRoPE(nn.Module):
+     def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0):
+         super().__init__()
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+
+         self.q_proj = nn.Linear(d_model, d_model)
+         self.k_proj = nn.Linear(d_model, d_model)
+         self.v_proj = nn.Linear(d_model, d_model)
+         self.out_proj = nn.Linear(d_model, d_model)
+         self.rotary = RotaryPositionalEmbedding(self.head_dim)
+         self.attn_dropout_p = attn_dropout_p
+         self.resid_dropout = nn.Dropout(resid_dropout)
+
+     def forward(self, query, key, value, key_padding_mask=None):
+         batch_size, q_len, _ = query.shape
+         _, seq_len, _ = key.shape
+
+         q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+
+         q, k = self.rotary(q, k)
+
+         if key_padding_mask is not None:
+             attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
+             attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1)
+         else:
+             attn_mask = None
+
+         is_causal_flag = self.training
+
+         attn_output = scaled_dot_product_attention(
+             q, k, v,
+             attn_mask=attn_mask,
+             dropout_p=self.attn_dropout_p,
+             is_causal=is_causal_flag,
+             training=self.training
+         )
+
+         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
+         return self.resid_dropout(self.out_proj(attn_output))
+
+
+ class HierarchicalEmbedding(nn.Module):
+     def __init__(self, s1_bits, s2_bits, d_model=256):
+         super().__init__()
+         self.s1_bits = s1_bits
+         self.s2_bits = s2_bits
+
+         vocab_s1 = 2 ** s1_bits
+         vocab_s2 = 2 ** s2_bits
+
+         self.emb_s1 = nn.Embedding(vocab_s1, d_model)
+         self.emb_s2 = nn.Embedding(vocab_s2, d_model)
+         self.d_model = d_model
+         self.fusion_proj = nn.Linear(d_model * 2, d_model)
+
+         nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
+         nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
+
+     def forward(self, token_ids):
+         """Input:
+             token_ids: [batch_size, seq_len] token IDs, or a (s1_ids, s2_ids) pair
+         Output: [batch_size, seq_len, d_model]
+         """
+         if isinstance(token_ids, (tuple, list)):
+             s1_ids, s2_ids = token_ids
+         else:
+             s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits)
+         s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model)
+         s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model)
+         return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1))
+
+
+ class DependencyAwareLayer(nn.Module):
+     def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0):
+         super().__init__()
+         self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout)
+         self.norm = RMSNorm(d_model)
+
+     def forward(self, hidden_states, sibling_embed, key_padding_mask=None):
+         """hidden_states: [batch, seq_len, d_model]
+         sibling_embed: embedding from the other subtoken
+         """
+         attn_out = self.cross_attn(
+             query=sibling_embed,
+             key=hidden_states,
+             value=hidden_states,
+             key_padding_mask=key_padding_mask
+         )
+         return self.norm(hidden_states + attn_out)
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0):
+         super().__init__()
+         self.norm1 = RMSNorm(d_model)
+         self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p)
+         self.norm2 = RMSNorm(d_model)
+         self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p)
+
+     def forward(self, x, key_padding_mask=None):
+         # Pre-norm residual attention
+         residual = x
+         x = self.norm1(x)
+         attn_out = self.self_attn(x, key_padding_mask=key_padding_mask)
+         x = residual + attn_out
+
+         # Pre-norm residual feed-forward
+         residual = x
+         x = self.norm2(x)
+         ffn_out = self.ffn(x)
+         x = residual + ffn_out
+         return x
+
+
+ class DualHead(nn.Module):
+     def __init__(self, s1_bits, s2_bits, d_model):
+         super().__init__()
+         self.vocab_s1 = 2 ** s1_bits
+         self.vocab_s2 = 2 ** s2_bits
+         self.proj_s1 = nn.Linear(d_model, self.vocab_s1)
+         self.proj_s2 = nn.Linear(d_model, self.vocab_s2)
+
+     def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
+         if padding_mask is not None:
+             valid_mask = (padding_mask == 0)
+             s1_logits = s1_logits[valid_mask]
+             s2_logits = s2_logits[valid_mask]
+             s1_targets = s1_targets[valid_mask]
+             s2_targets = s2_targets[valid_mask]
+             ce_s1 = F.cross_entropy(s1_logits, s1_targets)
+             ce_s2 = F.cross_entropy(s2_logits, s2_targets)
+         else:
+             ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1))
+             ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1))
+         ce_loss = (ce_s1 + ce_s2) / 2
+         return ce_loss, ce_s1, ce_s2
+
+     def forward(self, x):
+         return self.proj_s1(x)
+
+     def cond_forward(self, x2):
+         return self.proj_s2(x2)
+
+
+ class FixedEmbedding(nn.Module):
+     def __init__(self, c_in, d_model):
+         super(FixedEmbedding, self).__init__()
+
+         # Sinusoidal table, kept frozen (never updated during training)
+         w = torch.zeros(c_in, d_model).float()
+         w.requires_grad = False
+
+         position = torch.arange(0, c_in).float().unsqueeze(1)
+         div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
+
+         w[:, 0::2] = torch.sin(position * div_term)
+         w[:, 1::2] = torch.cos(position * div_term)
+
+         self.emb = nn.Embedding(c_in, d_model)
+         self.emb.weight = nn.Parameter(w, requires_grad=False)
+
+     def forward(self, x):
+         return self.emb(x).detach()
+
+
+ class TemporalEmbedding(nn.Module):
+     def __init__(self, d_model, learn_pe):
+         super(TemporalEmbedding, self).__init__()
+
+         minute_size = 60
+         hour_size = 24
+         weekday_size = 7
+         day_size = 32
+         month_size = 13
+
+         Embed = FixedEmbedding if not learn_pe else nn.Embedding
+         self.minute_embed = Embed(minute_size, d_model)
+         self.hour_embed = Embed(hour_size, d_model)
+         self.weekday_embed = Embed(weekday_size, d_model)
+         self.day_embed = Embed(day_size, d_model)
+         self.month_embed = Embed(month_size, d_model)
+
+     def forward(self, x):
+         x = x.long()
+
+         minute_x = self.minute_embed(x[:, :, 0])
+         hour_x = self.hour_embed(x[:, :, 1])
+         weekday_x = self.weekday_embed(x[:, :, 2])
+         day_x = self.day_embed(x[:, :, 3])
+         month_x = self.month_embed(x[:, :, 4])
+
+         return hour_x + weekday_x + day_x + month_x + minute_x
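The quantizer in `module.py` snaps each latent dimension to ±1 (`quantize` takes the sign of every coordinate) and converts between these binary codes and integer codebook indices with a power-of-two basis. The same mapping can be sketched in plain numpy (a standalone illustration, separate from the repo's torch modules):

```python
import numpy as np

def codes_to_indexes(zhat, embed_dim):
    """Map {-1, +1} codes to integer codebook indices (big-endian bits)."""
    basis = 2 ** np.arange(embed_dim - 1, -1, -1)
    return (((zhat + 1) / 2) * basis).sum(axis=-1).astype(np.int64)

def indexes_to_codes(indices, embed_dim):
    """Inverse mapping: integer indices back to {-1, +1} codes."""
    basis = 2 ** np.arange(embed_dim - 1, -1, -1)
    bits = (indices[..., None] // basis) % 2
    return bits * 2 - 1

codes = np.sign(np.array([[0.3, -1.2, 0.7], [-0.5, 0.1, -0.9]]))  # the sign/quantize step
idx = codes_to_indexes(codes, 3)  # [1,-1,1] -> 5, [-1,1,-1] -> 2
assert np.array_equal(indexes_to_codes(idx, 3), codes)
```

`BinarySphericalQuantizer` applies the same idea per subgroup of bits (via `group_basis`) so the entropy penalty can be computed over small sub-codebooks instead of all 2^embed_dim codes.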
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ pandas
+ numpy
+ safetensors
+ huggingface_hub
+ einops