RetentionLabs committed on
Commit fa58f7a · verified · 1 Parent(s): 063ab49

Upload folder using huggingface_hub

LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 test-time-training
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,105 @@
1
  ---
2
  license: mit
3
+ language:
4
+ - en
5
+ tags:
6
+ - Test-time Training
7
+ pipeline_tag: text-generation
8
+ base_model:
9
+ - Test-Time-Training/ttt-mlp-350m-books-2k
10
+ library_name: transformers
11
  ---
12
+
13
+ # Learning to (Learn at Test Time): RNNs with Expressive Hidden States
14
+
15
+ [**Paper**](https://arxiv.org/abs/2407.04620)
16
+ | [**JAX Codebase**](https://github.com/test-time-training/ttt-lm-jax)
17
+ | [**Setup**](#environment-setup)
18
+ | [**Quick Start**](#quick-start)
19
+ | [**Inference Benchmark**](https://github.com/test-time-training/ttt-lm-kernels)
20
+
21
+ This is the official PyTorch model implementation of [Learning to (Learn at Test Time): RNNs with Expressive Hidden States](https://arxiv.org/abs/2407.04620).
22
+ We **do not recommend training** with this codebase, because it is written in pure PyTorch without any systems optimization, so training will be slow, especially when the per-device batch size is small.
23
+
24
+
25
+ For training code, or to replicate results from our paper, please view our [JAX codebase](https://github.com/test-time-training/ttt-lm-jax). For inference kernels, or to replicate speed benchmarks from our paper, please view our [kernel implementations](https://github.com/test-time-training/ttt-lm-kernels).
26
+
27
+ ## Abstract
28
+
29
+ Self-attention performs well in long context but has quadratic complexity. Existing RNN layers
30
+ have linear complexity, but their performance in long context is limited by the expressive power
31
+ of their hidden state. We propose a new class of sequence modeling layers with linear complexity
32
+ and an expressive hidden state. The key idea is to make the hidden state a machine learning
33
+ model itself, and the update rule a step of self-supervised learning.
34
+
35
+ Since the hidden state is updated by training even on test sequences, our layers are called **Test-Time Training (TTT) layers**.
36
+ We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model
37
+ and a two-layer MLP respectively.
38
+
39
+ ## Environment Setup
40
+
41
+ ```bash
42
+ pip install "transformers[torch]"
43
+ ```
44
+
45
+ ## Quick Start
46
+
47
+ Our implementation is based on Hugging Face Transformers. You can use the following code to load the model and generate text.
48
+
49
+ ### Load with AutoModel
50
+
51
+ ```python
52
+ import torch
53
+ from transformers import AutoTokenizer, AutoModelForCausalLM
54
+
55
+
56
+ model_id = "RetentionLabs/TTT-Linear-350M-Base-Books-2k"
57
+
58
+ # Load the tokenizer and model from the Hub
59
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
60
+ model = AutoModelForCausalLM.from_pretrained(
61
+ model_id,
62
+ trust_remote_code=True,
63
+ dtype=torch.bfloat16,
64
+ device_map="auto"
65
+ )
66
+
67
+ # Generate
68
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
69
+ inputs = tokenizer("The future of AI is", return_tensors="pt").to(model.device)
70
+ outputs = model.generate(**inputs, max_new_tokens=100)
71
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
72
+ ```
73
+
74
+ ### From scratch
75
+
76
+ ```python
77
+ from transformers import AutoTokenizer
78
+ from modeling_ttt import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS
79
+
80
+ # Initializing a TTT ttt-1b style configuration
81
+ # configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
82
+ configuration = TTTConfig()
83
+
84
+ # Initializing a model from the ttt-1b style configuration
85
+ model = TTTForCausalLM(configuration)
86
+ model.eval()
87
+
88
+ # Accessing the model configuration
89
+ configuration = model.config
90
+
91
+ # Tokenizer
92
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
93
+
94
+ # Prefill
95
+ input_ids = tokenizer("Greeting from TTT!", return_tensors="pt").input_ids
96
+ out = model(input_ids=input_ids)
97
+ print(out.logits)
98
+
99
+ # Decoding
100
+ out_ids = model.generate(input_ids=input_ids, max_length=50)
101
+ out_str = tokenizer.batch_decode(out_ids, skip_special_tokens=True)
102
+ print(out_str)
103
+ ```
104
+
105
+ **Note: This is a naive implementation of TTT layers for tutorial purposes.** This model can be trained with Hugging Face Accelerate or a custom training loop (a minimal sketch follows below). We have released our faster inference kernel and its speed benchmark [here](https://github.com/test-time-training/ttt-lm-kernels).
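+
+ The sketch below uses only standard Hugging Face APIs and assumes the model follows the usual `labels` convention for causal LMs; the texts, learning rate, and other hyperparameters are placeholders rather than values from the paper.
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ model_id = "RetentionLabs/TTT-Linear-350M-Base-Books-2k"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
+ model.train()
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+ texts = ["Example training document.", "Another short training document."]
+
+ for text in texts:
+     batch = tokenizer(text, return_tensors="pt")
+     # The causal LM head computes the cross-entropy loss when labels are provided.
+     outputs = model(**batch, labels=batch["input_ids"])
+     outputs.loss.backward()
+     optimizer.step()
+     optimizer.zero_grad()
+ ```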
config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "architectures": [
3
+ "TTTForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "modeling_ttt.TTTConfig",
7
+ "AutoModel": "modeling_ttt.TTTModel",
8
+ "AutoModelForCausalLM": "modeling_ttt.TTTForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "conv_kernel": 4,
12
+ "dtype": "bfloat16",
13
+ "eos_token_id": 2,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 1024,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 2736,
18
+ "max_position_embeddings": 2048,
19
+ "mini_batch_size": 16,
20
+ "model_type": "ttt",
21
+ "num_attention_heads": 16,
22
+ "num_hidden_layers": 24,
23
+ "pre_conv": true,
24
+ "pretraining_tp": 1,
25
+ "rms_norm_eps": 1e-06,
26
+ "rope_theta": 10000.0,
27
+ "scan_checkpoint_group_size": 0,
28
+ "share_qk": true,
29
+ "transformers_version": "4.57.6",
30
+ "ttt_base_lr": 1.0,
31
+ "ttt_layer_type": "linear",
32
+ "use_cache": true,
33
+ "use_gate": true,
34
+ "vocab_size": 32000
35
+ }
generation_config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 2,
6
+ "transformers_version": "4.57.6",
7
+ "do_sample": true,
8
+ "temperature": 0.7,
9
+ "top_p": 0.9,
10
+ "repetition_penalty": 1.1,
11
+ "max_new_tokens": 512
12
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5502a910469aaabaad74b4242ced63c4d8f3ab4891e08586b384afd2ac1296b4
3
+ size 675438944
modeling_ttt.py ADDED
@@ -0,0 +1,1650 @@
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+ from torch.nn import CrossEntropyLoss
10
+ from torch.utils._pytree import tree_map
11
+
12
+ from transformers import PretrainedConfig
13
+ from transformers.activations import ACT2FN
14
+ from transformers.modeling_outputs import (
15
+ BaseModelOutputWithPast,
16
+ CausalLMOutputWithPast,
17
+ )
18
+ from transformers.generation import GenerationMixin
19
+ from transformers.modeling_utils import PreTrainedModel
20
+ from transformers.utils import ModelOutput, logging
21
+ from transformers.utils.import_utils import is_causal_conv1d_available
22
+
23
+ if is_causal_conv1d_available():
24
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
25
+ else:
26
+ causal_conv1d_update, causal_conv1d_fn = None, None
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ TTT_STANDARD_CONFIGS = {
32
+ "125m": {
33
+ "hidden_size": 768,
34
+ "intermediate_size": 2048,
35
+ "num_hidden_layers": 12,
36
+ "num_attention_heads": 12,
37
+ },
38
+ "350m": {
39
+ "hidden_size": 1024,
40
+ "intermediate_size": 2736,
41
+ "num_hidden_layers": 24,
42
+ "num_attention_heads": 16,
43
+ },
44
+ "760m": {
45
+ "hidden_size": 1536,
46
+ "intermediate_size": 4096,
47
+ "num_hidden_layers": 24,
48
+ "num_attention_heads": 16,
49
+ },
50
+ "1b": {
51
+ "hidden_size": 2048,
52
+ "intermediate_size": 5504,
53
+ "num_hidden_layers": 24,
54
+ "num_attention_heads": 32,
55
+ },
56
+ }
57
+
58
+
59
+ class TTTConfig(PretrainedConfig):
60
+ r"""
61
+ This is the configuration class to store the configuration of a [`TTTModel`]. It is used to instantiate a TTT
62
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
63
+ defaults will yield a similar configuration to that of the TTT-1B.
64
+
65
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
66
+ documentation from [`PretrainedConfig`] for more information.
67
+
68
+
69
+ Args:
70
+ vocab_size (`int`, *optional*, defaults to 32000):
71
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
72
+ `inputs_ids` passed when calling [`LlamaModel`]
73
+ hidden_size (`int`, *optional*, defaults to 4096):
74
+ Dimension of the hidden representations.
75
+ intermediate_size (`int`, *optional*, defaults to 11008):
76
+ Dimension of the MLP representations.
77
+ num_hidden_layers (`int`, *optional*, defaults to 32):
78
+ Number of hidden layers in the Transformer decoder.
79
+ num_attention_heads (`int`, *optional*, defaults to 32):
80
+ Number of attention heads for each attention layer in the Transformer decoder.
81
+ num_key_value_heads (`int`, *optional*):
82
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
83
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
84
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
85
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
86
+ by meanpooling all the original heads within that group. For more details checkout [this
87
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
88
+ `num_attention_heads`.
89
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
90
+ The non-linear activation function (function or string) in the decoder.
91
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
92
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
93
+ Llama 2 up to 4096, CodeLlama up to 16384.
94
+ initializer_range (`float`, *optional*, defaults to 0.02):
95
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
96
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
97
+ The epsilon used by the rms normalization layers.
98
+ use_cache (`bool`, *optional*, defaults to `True`):
99
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
100
+ relevant if `config.is_decoder=True`.
101
+ pad_token_id (`int`, *optional*):
102
+ Padding token id.
103
+ bos_token_id (`int`, *optional*, defaults to 1):
104
+ Beginning of stream token id.
105
+ eos_token_id (`int`, *optional*, defaults to 2):
106
+ End of stream token id.
107
+ pretraining_tp (`int`, *optional*, defaults to 1):
108
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
109
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
110
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
111
+ issue](https://github.com/pytorch/pytorch/issues/76232).
112
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
113
+ Whether to tie weight embeddings
114
+ rope_theta (`float`, *optional*, defaults to 10000.0):
115
+ The base period of the RoPE embeddings.
116
+ rope_scaling (`Dict`, *optional*):
117
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
118
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
119
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
120
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
121
+ these scaling strategies behave:
122
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
123
+ experimental feature, subject to breaking API changes in future versions.
124
+ use_gate (`bool`, *optional*, defaults to `False`): whether to use gating as in the Mamba backbone
125
+ share_qk (`bool`, *optional*, defaults to `False`): whether to share the Q/K projection matrix
126
+ ttt_layer_type (`str`, *optional*, defaults to `"linear"`): TTT block type, "linear" or "mlp", standing for TTT-Linear and TTT-MLP respectively
127
+ ttt_base_lr (`float`, *optional*, defaults to 1.0): base learning rate for the TTT learner
128
+ pre_conv (`bool`, *optional*, defaults to `False`): whether to use a conv layer before the TTT layer
129
+ conv_kernel (`int`, *optional*, defaults to 4): kernel size of the conv layer
130
+ scan_checkpoint_group_size (`int`, *optional*, defaults to 0):
131
+ gradient checkpointing group size along the sequence dimension; 0 means no checkpointing.
132
+ In the JAX implementation, we set it to 4, i.e. 4 mini-batches are grouped into one gradient checkpoint to save memory.
133
+
134
+
135
+ ```python
136
+ >>> from . import TTTModel, TTTConfig
137
+
138
+ >>> # Initializing a TTT ttt-1b style configuration
139
+ >>> configuration = TTTConfig()
140
+
141
+ >>> # Initializing a model from the ttt-1b style configuration
142
+ >>> model = TTTModel(configuration)
143
+
144
+ >>> # Accessing the model configuration
145
+ >>> configuration = model.config
146
+ ```"""
147
+
148
+ model_type = "ttt"
149
+
150
+ def __init__(
151
+ self,
152
+ vocab_size=32000,
153
+ hidden_size=2048,
154
+ intermediate_size=5504,
155
+ num_hidden_layers=24,
156
+ num_attention_heads=32,
157
+ hidden_act="silu",
158
+ max_position_embeddings=2048,
159
+ initializer_range=0.02,
160
+ rms_norm_eps=1e-6,
161
+ use_cache=False,
162
+ pad_token_id=None,
163
+ bos_token_id=1,
164
+ eos_token_id=2,
165
+ pretraining_tp=1,
166
+ tie_word_embeddings=True,
167
+ rope_theta=10000.0,
168
+ use_gate=False,
169
+ share_qk=False,
170
+ ttt_layer_type="linear",
171
+ ttt_base_lr=1.0,
172
+ mini_batch_size=16,
173
+ pre_conv=False,
174
+ conv_kernel=4,
175
+ scan_checkpoint_group_size=0,
176
+ **kwargs,
177
+ ):
178
+ self.vocab_size = vocab_size
179
+ self.max_position_embeddings = max_position_embeddings
180
+ self.hidden_size = hidden_size
181
+ self.intermediate_size = intermediate_size
182
+ self.num_hidden_layers = num_hidden_layers
183
+ self.num_attention_heads = num_attention_heads
184
+
185
+ self.hidden_act = hidden_act
186
+ self.initializer_range = initializer_range
187
+ self.rms_norm_eps = rms_norm_eps
188
+ self.pretraining_tp = pretraining_tp
189
+ self.use_cache = use_cache
190
+ self.rope_theta = rope_theta
191
+
192
+ self.use_gate = use_gate
193
+ self.share_qk = share_qk
194
+ self.ttt_layer_type = ttt_layer_type
195
+ self.ttt_base_lr = ttt_base_lr
196
+ self.mini_batch_size = mini_batch_size
197
+
198
+ self.pre_conv = pre_conv
199
+ self.conv_kernel = conv_kernel
200
+ self.scan_checkpoint_group_size = scan_checkpoint_group_size
201
+
202
+ super().__init__(
203
+ pad_token_id=pad_token_id,
204
+ bos_token_id=bos_token_id,
205
+ eos_token_id=eos_token_id,
206
+ tie_word_embeddings=tie_word_embeddings,
207
+ **kwargs,
208
+ )
209
+
210
+
211
+ ########################
212
+ ### Backbone Modules ###
213
+ ########################
214
+
215
+
216
+ def rotate_half(x):
217
+ """Rotates half the hidden dims of the input."""
218
+ x1 = x[..., : x.shape[-1] // 2]
219
+ x2 = x[..., x.shape[-1] // 2 :]
220
+ return torch.cat((-x2, x1), dim=-1)
221
+
222
+
223
+ def permute_qk(q, k):
224
+ # NOTE: EasyLM and transformers use different methods to compute rotary embeddings
225
+ # we manually reorder the dim here to match our JAX implementation
226
+ # which may not be optimal for speed
227
+ # reference: https://github.com/young-geng/EasyLM/blob/981a2ed9630f44258a94b6f44dff2b7bd203ae8d/EasyLM/models/llama/convert_hf_to_easylm.py#L33
228
+ bsz, num_head, seq_len, head_dim = q.shape
229
+ q = q.reshape(bsz, num_head, seq_len, head_dim // 2, 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim)
230
+ k = k.reshape(bsz, num_head, seq_len, head_dim // 2, 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim)
231
+
232
+ return q, k
233
+
234
+
235
+ def undo_permute_qk(q, k):
236
+ # NOTE: EasyLM and transformers use different methods to compute rotary embeddings
237
+ # we manually undo the reorder the dim here to match our JAX implementation
238
+ # which may not be optimal for speed
239
+ # reference: https://github.com/young-geng/EasyLM/blob/981a2ed9630f44258a94b6f44dff2b7bd203ae8d/EasyLM/models/llama/convert_hf_to_easylm.py#L33
240
+ bsz, num_head, seq_len, head_dim = q.shape
241
+ q = q.reshape(bsz, num_head, seq_len, 2, head_dim // 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim)
242
+ k = k.reshape(bsz, num_head, seq_len, 2, head_dim // 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim)
243
+
244
+ return q, k
245
+
246
+
247
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
248
+ """Applies Rotary Position Embedding to the query and key tensors.
249
+
250
+ Args:
251
+ q (`torch.Tensor`): The query tensor.
252
+ k (`torch.Tensor`): The key tensor.
253
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
254
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
255
+ position_ids (`torch.Tensor`, *optional*):
256
+ Deprecated and unused.
257
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
258
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
259
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
260
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
261
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
262
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
263
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
264
+ Returns:
265
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
266
+ """
267
+ cos = cos.unsqueeze(unsqueeze_dim)
268
+ sin = sin.unsqueeze(unsqueeze_dim)
269
+ q_embed = (q * cos) + (rotate_half(q) * sin)
270
+ k_embed = (k * cos) + (rotate_half(k) * sin)
271
+ return q_embed, k_embed
272
+
273
+
274
+ class RMSNorm(nn.Module):
275
+ def __init__(self, hidden_size, eps=1e-6):
276
+ super().__init__()
277
+ self.weight = nn.Parameter(torch.ones(hidden_size))
278
+ self.variance_epsilon = eps
279
+
280
+ def forward(self, hidden_states):
281
+ input_dtype = hidden_states.dtype
282
+ hidden_states = hidden_states.to(torch.float32)
283
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
284
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
285
+ return self.weight * hidden_states.to(input_dtype)
286
+
287
+
288
+ class SwiGluMLP(nn.Module):
289
+ def __init__(self, config):
290
+ super().__init__()
291
+ self.config = config
292
+ self.hidden_size = config.hidden_size
293
+ self.intermediate_size = config.intermediate_size
294
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
295
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
296
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
297
+ self.act_fn = ACT2FN[config.hidden_act]
298
+
299
+ def forward(self, x):
300
+ if self.config.pretraining_tp > 1:
301
+ slice = self.intermediate_size // self.config.pretraining_tp
302
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
303
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
304
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
305
+
306
+ gate_proj = torch.cat(
307
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)],
308
+ dim=-1,
309
+ )
310
+ up_proj = torch.cat(
311
+ [F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)],
312
+ dim=-1,
313
+ )
314
+
315
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
316
+ down_proj = [
317
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
318
+ ]
319
+ down_proj = sum(down_proj)
320
+ else:
321
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
322
+
323
+ return down_proj
324
+
325
+
326
+ class RotaryEmbedding(nn.Module):
327
+ def __init__(
328
+ self,
329
+ dim,
330
+ max_position_embeddings=16,
331
+ base=10000,
332
+ device=None,
333
+ scaling_factor=1.0,
334
+ ):
335
+ super().__init__()
336
+ self.scaling_factor = scaling_factor
337
+ self.dim = dim
338
+ self.max_position_embeddings = max_position_embeddings
339
+ self.base = base
340
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
341
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
342
+
343
+ @torch.no_grad()
344
+ def forward(self, x, position_ids):
345
+ # x: [bs, num_attention_heads, seq_len, head_size]
346
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
347
+ position_ids_expanded = position_ids[:, None, :].float()
348
+ # Force float32 since bfloat16 loses precision on long contexts
349
+ # See https://github.com/huggingface/transformers/pull/29285
350
+ device_type = x.device.type
351
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
352
+ with torch.autocast(device_type=device_type, enabled=False):
353
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
354
+ emb = torch.cat((freqs, freqs), dim=-1)
355
+ cos = emb.cos()
356
+ sin = emb.sin()
357
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
358
+
359
+
360
+ class Conv(nn.Module):
361
+ def __init__(self, config, layer_idx):
362
+ super().__init__()
363
+ self.config = config
364
+ self.layer_idx = layer_idx
365
+
366
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
367
+ self.conv = nn.Conv1d(
368
+ config.hidden_size,
369
+ config.hidden_size,
370
+ bias=True,
371
+ kernel_size=config.conv_kernel,
372
+ groups=config.hidden_size,
373
+ padding=config.conv_kernel - 1,
374
+ )
375
+
376
+ def __call__(self, hidden_states, cache_params=None):
377
+ seq_len = hidden_states.shape[1]
378
+ hidden_states = self.norm(hidden_states)
379
+ # [B, C, L]
380
+ hidden_states = hidden_states.transpose(1, 2)
381
+
382
+ if causal_conv1d_fn is None:
383
+ if cache_params is not None:
384
+ if cache_params.seqlen_offset > 0:
385
+ conv_state = cache_params.conv_states_dic["pre_conv"][self.layer_idx]
386
+ conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
387
+ conv_state[:, :, -1] = hidden_states[:, :, 0]
388
+ cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_state)
389
+ hidden_states = torch.sum(conv_state * self.conv.weight[:, 0, :], dim=-1)
390
+ hidden_states += self.conv.bias
391
+ hidden_states = hidden_states.unsqueeze(-1)
392
+ else:
393
+ conv_state = nn.functional.pad(
394
+ hidden_states,
395
+ (self.config.conv_kernel - hidden_states.shape[-1], 0),
396
+ )
397
+ cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_state)
398
+ hidden_states = self.conv(hidden_states)[..., :seq_len]
399
+ else:
400
+ hidden_states = self.conv(hidden_states)[..., :seq_len]
401
+ else:
402
+ conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
403
+ if cache_params is not None and cache_params.seqlen_offset > 0:
404
+ hidden_states = causal_conv1d_update(
405
+ hidden_states.squeeze(-1),
406
+ cache_params.conv_states_dic["pre_conv"][self.layer_idx],
407
+ conv_weights,
408
+ self.conv.bias,
409
+ None,
410
+ )
411
+ hidden_states = hidden_states.unsqueeze(-1)
412
+ else:
413
+ if cache_params is not None:
414
+ conv_states = nn.functional.pad(
415
+ hidden_states,
416
+ (self.config.conv_kernel - hidden_states.shape[-1], 0),
417
+ )
418
+ cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_states)
419
+ hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv.bias, activation=None)
420
+
421
+ # [B, L, C]
422
+ hidden_states = hidden_states.transpose(1, 2)
423
+
424
+ return hidden_states
425
+
426
+
427
+ #########################
428
+ ### TTT Layer Modules ###
429
+ #########################
430
+
431
+
432
+ def scan(f, init, xs, out, checkpoint_group=0):
433
+ """Mimic jax.lax.scan function."""
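+     # f: step function mapping (carry, x) -> (new_carry, y), called once per mini-batch.
+     # init: initial carry (here, the dict of TTT parameter states and gradients).
+     # xs: stacked inputs, either a dict of tensors or a list, indexed along dim 0.
+     # out: preallocated tensor that receives each step's output y.
+     # checkpoint_group: if > 0, wraps groups of steps in gradient checkpointing to save memory.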
434
+ carry = init
435
+ if isinstance(xs, dict):
436
+ num_items = len(next(iter(xs.values())))
437
+ else:
438
+ num_items = len(xs[0])
439
+
440
+ def scan_fn(carry, i_start, i_end):
441
+ for i in range(i_start, i_end):
442
+ if isinstance(xs, dict):
443
+ x = {key: tensor[i] for key, tensor in xs.items()}
444
+ else:
445
+ x = [x[i] for x in xs]
446
+ carry, y = f(carry, x)
447
+ out[i] = y
448
+ return carry
449
+
450
+ if checkpoint_group > 0:
451
+ ckpt_every_n = num_items // checkpoint_group
452
+ for k in range(0, num_items, ckpt_every_n):
453
+ carry = torch.utils.checkpoint.checkpoint(
454
+ scan_fn, carry, k, min(k + ckpt_every_n, num_items), use_reentrant=False
455
+ )
456
+ else:
457
+ carry = scan_fn(carry, 0, num_items)
458
+
459
+ return carry, out
460
+
461
+
462
+ def ln_fwd(x, gamma, beta, eps=1e-6):
463
+ "Batch forward for LayerNorm."
464
+
465
+ # Mean and variance computation
466
+ mu = x.mean(dim=-1, keepdim=True)
467
+ var = x.var(dim=-1, keepdim=True, unbiased=False)
468
+
469
+ # Normalization
470
+ std = torch.sqrt(var + eps)
471
+ x_hat = (x - mu) / std
472
+
473
+ # Scale and shift
474
+ y = gamma * x_hat + beta
475
+
476
+ return y
477
+
478
+
479
+ def ln_fused_l2_bwd(x, l2_target, gamma, beta, eps=1e-6):
480
+ "Batch backward for LayerNorm fused with L2 loss."
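+     # Computes d/dx of 0.5 * ||LayerNorm(x; gamma, beta) - l2_target||^2 in closed form:
+     # the residual (y - l2_target) acts as the upstream gradient and is pushed back
+     # through the LayerNorm, avoiding a separate autograd pass.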
481
+ D = x.shape[-1]
482
+
483
+ # Mean and variance computation
484
+ mu = x.mean(dim=-1, keepdim=True)
485
+ var = x.var(dim=-1, keepdim=True, unbiased=False)
486
+
487
+ # Normalization
488
+ std = torch.sqrt(var + eps)
489
+ x_hat = (x - mu) / std
490
+
491
+ # Scale and shift
492
+ y = gamma * x_hat + beta
493
+
494
+ grad_output = y - l2_target
495
+ grad_x_hat = grad_output * gamma
496
+ z = (
497
+ (1.0 / D)
498
+ * (
499
+ D * grad_x_hat
500
+ - grad_x_hat.sum(dim=-1, keepdim=True)
501
+ - x_hat * (grad_x_hat * x_hat).sum(dim=-1, keepdim=True)
502
+ )
503
+ / std
504
+ )
505
+
506
+ return z
507
+
508
+
509
+ # Modified from https://github.com/NVIDIA/Megatron-LM/blob/e33c8f78a35765d5aa37475a144da60e8a2349d1/megatron/core/fusions/fused_bias_gelu.py#L26
510
+ def gelu_bwd(x):
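+     # Derivative of the tanh-approximated GELU, matching F.gelu(..., approximate="tanh");
+     # used to backprop the reconstruction loss through the hidden layer of TTT-MLP.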
511
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
512
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
513
+ return ff
514
+
515
+
516
+ class TTTCache:
517
+ """
518
+ TTTCache is a data structure that holds the last hidden states and gradients for the TTT layer.
519
+
520
+ Arguments:
521
+ model: TTTModel
522
+ batch_size: int
523
+
524
+ Attributes:
525
+ seqlen_offset: int
526
+ mini_batch_size: int
527
+ ttt_params_dict: Dict[str, Dict[int, torch.Tensor]]: "{name}_states" / "{name}_grad" -> layer_idx -> [batch_size, ...]
528
+ conv_states_dic: Dict[str, Dict[int, torch.Tensor]]: conv name -> layer_idx -> [batch_size, ...]
529
+
530
+ """
531
+
532
+ def __init__(self, model, batch_size: int):
533
+ config = model.config
534
+ self.seqlen_offset = 0
535
+ self.mini_batch_size = config.mini_batch_size
536
+
537
+ self.ttt_params_dict = defaultdict(dict)
538
+ if "linear" in config.ttt_layer_type:
539
+ self.ttt_param_names = ["W1", "b1"]
540
+ elif "mlp" in config.ttt_layer_type:
541
+ self.ttt_param_names = ["W1", "b1", "W2", "b2"]
542
+ else:
543
+ raise ValueError(f"TTT Layer Type {config.ttt_layer_type} not supported yet")
544
+
545
+ self.conv_states_dic = defaultdict(dict)
546
+ logger.info(f"Creating cache of size: {batch_size}")
547
+ for layer_idx in range(config.num_hidden_layers):
548
+ for name in self.ttt_param_names:
549
+ weight = getattr(model.layers[layer_idx].seq_modeling_block, name)
550
+ tiled_weight = torch.tile(weight.unsqueeze(0), (batch_size,) + (1,) * weight.dim()).to(model.device)
551
+ self.ttt_params_dict[f"{name}_states"][layer_idx] = tiled_weight
552
+ # for decoding, we need to store the gradients as well
553
+ self.ttt_params_dict[f"{name}_grad"][layer_idx] = torch.zeros_like(tiled_weight)
554
+
555
+ if config.pre_conv:
556
+ self.conv_states_dic["pre_conv"][layer_idx] = torch.zeros(
557
+ batch_size,
558
+ config.hidden_size,
559
+ config.conv_kernel,
560
+ device=model.device,
561
+ )
562
+ if config.share_qk:
563
+ self.conv_states_dic["ttt_conv_q"][layer_idx] = torch.zeros(
564
+ batch_size,
565
+ config.hidden_size,
566
+ config.conv_kernel,
567
+ device=model.device,
568
+ )
569
+ self.conv_states_dic["ttt_conv_k"][layer_idx] = torch.zeros(
570
+ batch_size,
571
+ config.hidden_size,
572
+ config.conv_kernel,
573
+ device=model.device,
574
+ )
575
+
576
+ def update(self, py_tree, layer_idx, seq_len):
577
+ if seq_len % self.mini_batch_size == 0:
578
+ # copy last mini-batch states, clear gradients
579
+ for name in self.ttt_param_names:
580
+ self.ttt_params_dict[f"{name}_states"][layer_idx].copy_(py_tree[f"{name}_states"])
581
+ self.ttt_params_dict[f"{name}_grad"][layer_idx].zero_()
582
+ elif seq_len < self.mini_batch_size:
583
+ if seq_len != 1 and self.seqlen_offset > 0 and self.seqlen_offset % self.mini_batch_size != 0:
584
+ raise ValueError("fractional update not supported yet.")
585
+ if (seq_len + self.seqlen_offset) % self.mini_batch_size == 0:
586
+ # copy last mini-batch states, clear gradients
587
+ for name in self.ttt_param_names:
588
+ self.ttt_params_dict[f"{name}_states"][layer_idx].copy_(py_tree[f"{name}_states"])
589
+ self.ttt_params_dict[f"{name}_grad"][layer_idx].zero_()
590
+ else:
591
+ # copy gradients for the next update
592
+ for name in self.ttt_param_names:
593
+ self.ttt_params_dict[f"{name}_grad"][layer_idx].copy_(py_tree[f"{name}_grad"])
594
+ else:
595
+ raise ValueError(f"seq_len {seq_len} is a partial update not supported yet")
596
+
597
+ def ttt_params_to_dict(self, layer_idx):
598
+ return {name: self.ttt_params_dict[name][layer_idx] for name in self.ttt_params_dict}
599
+
600
+
601
+ class TTTBase(nn.Module):
602
+ def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None):
603
+ super().__init__()
604
+ self.config = config
605
+ self.layer_idx = layer_idx
606
+ if layer_idx is None:
607
+ logger.warning_once(
608
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
609
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
610
+ "when creating this class."
611
+ )
612
+
613
+ self.width = config.hidden_size
614
+ self.hidden_size = config.hidden_size
615
+ self.num_heads = config.num_attention_heads
616
+ self.head_dim = self.width // self.num_heads
617
+ self.mini_batch_size = config.mini_batch_size
618
+
619
+ # token_idx is a scale factor that scales the summation in Eqn. 4
620
+ token_idx = 1.0 / torch.arange(1, self.mini_batch_size + 1)
621
+ self.register_buffer("token_idx", token_idx, persistent=False)
622
+ # make the scale factor learnable
623
+ self.learnable_token_idx = nn.Parameter(torch.zeros((self.mini_batch_size,)))
624
+
625
+ self.share_qk = config.share_qk
626
+ self.conv_kernel = config.conv_kernel
627
+ self._init_qkvo_proj()
628
+ self._init_rope()
629
+ # Learnable eta in Sec. 2.7
630
+ self._init_ttt_lr_gate()
631
+ self._init_ttt_ln()
632
+
633
+ # use gating as in Mamba backbone
634
+ self.use_gate = config.use_gate
635
+ if self.use_gate:
636
+ self.g_proj = nn.Linear(self.width, self.width, bias=False)
637
+
638
+ self.post_norm = nn.LayerNorm(self.width, eps=1e-6)
639
+
640
+ def _init_qkvo_proj(self):
641
+ self.q_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
642
+ # we share the Q/K projection when using the Mamba backbone
643
+ if not self.share_qk:
644
+ self.k_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
645
+ self.v_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
646
+ self.o_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
647
+
648
+ # after sharing the Q/K projection, we use different conv layers for Q and K
649
+ if self.share_qk:
650
+ self.conv_q = nn.Conv1d(
651
+ self.hidden_size,
652
+ self.hidden_size,
653
+ bias=True,
654
+ kernel_size=self.conv_kernel,
655
+ groups=self.hidden_size,
656
+ padding=self.conv_kernel - 1,
657
+ )
658
+ self.conv_k = nn.Conv1d(
659
+ self.hidden_size,
660
+ self.hidden_size,
661
+ bias=True,
662
+ kernel_size=self.conv_kernel,
663
+ groups=self.hidden_size,
664
+ padding=self.conv_kernel - 1,
665
+ )
666
+
667
+ def _init_rope(self):
668
+ self.rope_theta = self.config.rope_theta
669
+ self.rotary_emb = RotaryEmbedding(
670
+ self.head_dim,
671
+ max_position_embeddings=self.mini_batch_size,
672
+ base=self.rope_theta,
673
+ )
674
+
675
+ def _init_ttt_lr_gate(self):
676
+ # [width, 1]
677
+ linear_weight_data = nn.Linear(self.width, 1, bias=True).weight.data
678
+ # prepending head dim -> [num_heads, width, 1]
679
+ self.learnable_ttt_lr_weight = nn.Parameter(
680
+ torch.stack(
681
+ [torch.normal(0, 0.02, size=linear_weight_data.shape) for _ in range(self.num_heads)],
682
+ dim=0,
683
+ )
684
+ )
685
+ linear_bias_data = nn.Linear(self.width, 1, bias=True).bias.data
686
+ # init bias to 0 following original JAX impl.
687
+ # [num_heads, 1]
688
+ self.learnable_ttt_lr_bias = nn.Parameter(
689
+ torch.stack(
690
+ [torch.zeros_like(linear_bias_data) for _ in range(self.num_heads)],
691
+ dim=0,
692
+ )
693
+ )
694
+
695
+ def _init_ttt_ln(self):
696
+ ln_weight_data = nn.LayerNorm(self.head_dim).weight.data
697
+ # prepending head dim -> [num_heads, width]
698
+ self.ttt_norm_weight = nn.Parameter(torch.tile(ln_weight_data.unsqueeze(0), (self.num_heads, 1)))
699
+ ln_bias_data = nn.LayerNorm(self.head_dim).bias.data
700
+ self.ttt_norm_bias = nn.Parameter(torch.tile(ln_bias_data.unsqueeze(0), (self.num_heads, 1)))
701
+
702
+ def get_qkv_projections(self, hidden_states, cache_params: Optional[TTTCache] = None):
703
+ if self.share_qk:
704
+ xq, XV = self.q_proj(hidden_states), self.v_proj(hidden_states)
705
+ seq_len = xq.shape[1]
706
+ xq = xq.transpose(1, 2)
707
+ if causal_conv1d_fn is None:
708
+ if cache_params is not None:
709
+ if cache_params.seqlen_offset > 0:
710
+ conv_q_state = cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx]
711
+ conv_q_state = torch.roll(conv_q_state, shifts=-1, dims=-1)
712
+ conv_q_state[:, :, -1] = xq[:, :, 0]
713
+ cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_state)
714
+ XQ = torch.sum(conv_q_state * self.conv_q.weight[:, 0, :], dim=-1)
715
+ XQ += self.conv_q.bias
716
+ XQ = XQ.unsqueeze(-1)
717
+
718
+ conv_k_state = cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx]
719
+ conv_k_state = torch.roll(conv_k_state, shifts=-1, dims=-1)
720
+ conv_k_state[:, :, -1] = xq[:, :, 0]
721
+ cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_state)
722
+ XK = torch.sum(conv_k_state * self.conv_k.weight[:, 0, :], dim=-1)
723
+ XK += self.conv_k.bias
724
+ XK = XK.unsqueeze(-1)
725
+ else:
726
+ conv_q_state = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
727
+ cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_state)
728
+ XQ = self.conv_q(xq)[..., :seq_len]
729
+ conv_k_state = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
730
+ cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_state)
731
+ XK = self.conv_k(xq)[..., :seq_len]
732
+ else:
733
+ XQ = self.conv_q(xq)[..., :seq_len]
734
+ XK = self.conv_k(xq)[..., :seq_len]
735
+ else:
736
+ conv_q_weights = self.conv_q.weight.view(self.conv_q.weight.size(0), self.conv_q.weight.size(2))
737
+ conv_k_weights = self.conv_k.weight.view(self.conv_k.weight.size(0), self.conv_k.weight.size(2))
738
+ if cache_params is not None and cache_params.seqlen_offset > 0:
739
+ XQ = causal_conv1d_update(
740
+ xq.squeeze(-1),
741
+ cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx],
742
+ conv_q_weights,
743
+ self.conv_q.bias,
744
+ None,
745
+ )
746
+ XQ = XQ.unsqueeze(-1)
747
+ XK = causal_conv1d_update(
748
+ xq.squeeze(-1),
749
+ cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx],
750
+ conv_k_weights,
751
+ self.conv_k.bias,
752
+ None,
753
+ )
754
+ XK = XK.unsqueeze(-1)
755
+ else:
756
+ if cache_params is not None:
757
+ conv_q_states = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
758
+ cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_states)
759
+ conv_k_states = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
760
+ cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_states)
761
+ XQ = causal_conv1d_fn(xq, conv_q_weights, self.conv_q.bias, activation=None)
762
+ XK = causal_conv1d_fn(xq, conv_k_weights, self.conv_k.bias, activation=None)
763
+
764
+ XQ = XQ.transpose(1, 2)
765
+ XK = XK.transpose(1, 2)
766
+ else:
767
+ XQ, XK, XV = (
768
+ self.q_proj(hidden_states),
769
+ self.k_proj(hidden_states),
770
+ self.v_proj(hidden_states),
771
+ )
772
+ return XQ, XK, XV
773
+
774
+ def _split_heads(self, hidden_states):
775
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
776
+
777
+ def get_eta(self, X, mini_batch_step_offset, mini_batch_size):
778
+ # [B, num_heads, num_mini_batch, mini_batch_size, 1]
779
+ ttt_lr = torch.einsum("bnkc,hdc->bhnkd", X, self.learnable_ttt_lr_weight) + self.learnable_ttt_lr_bias.reshape(
780
+ 1, -1, 1, 1, 1
781
+ )
782
+ ttt_lr = F.sigmoid(ttt_lr)
783
+
784
+ # [B, num_heads, num_mini_batch, 1, mini_batch_size]
785
+ ttt_lr = ttt_lr.permute(0, 1, 2, 4, 3)
786
+ ttt_lr_eta = self.config.ttt_base_lr * ttt_lr / self.head_dim
787
+
788
+ # [B, L]
789
+ token_idx = self.token_idx + self.learnable_token_idx
790
+ token_idx = token_idx[mini_batch_step_offset : mini_batch_step_offset + mini_batch_size]
791
+
792
+ # token_idx should be greater than 0
793
+ token_idx = torch.clamp_min(token_idx, 0.0)
794
+
795
+ # NOTE: token_eta is a scale factor that applies to each token in the mini-batch
796
+ # [B, num_heads, num_mini_batch, mini_batch_size, 1]
797
+ token_eta = torch.broadcast_to(
798
+ token_idx.reshape(1, 1, 1, mini_batch_size, 1),
799
+ (X.shape[0], self.num_heads, X.shape[1], mini_batch_size, 1),
800
+ )
801
+
802
+ return token_eta, ttt_lr_eta
803
+
804
+ def apply_gate(self, hidden_states, ttt_output):
805
+ y = self.g_proj(hidden_states)
806
+ # use 'tanh' approximation for matching JAX impl.
807
+ y = F.gelu(y, approximate="tanh")
808
+ output = y * ttt_output
809
+ return output
810
+
811
+ def get_ttt_inputs(self, inputs, mini_batch_size, cache_params):
812
+ XQ = inputs["XQ"]
813
+ XK = inputs["XK"]
814
+ XV = inputs["XV"]
815
+ X = inputs["X"]
816
+ B, L, C = X.shape
817
+ num_mini_batch = L // mini_batch_size
818
+ # [B ,num_mini_batch, mini_batch_size, C]
819
+ X = X.reshape(B, num_mini_batch, mini_batch_size, self.width)
820
+
821
+ XQ = XQ.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim)
822
+ XK = XK.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim)
823
+ XV = XV.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim)
824
+
825
+ if cache_params is not None:
826
+ mini_batch_step_offset = cache_params.seqlen_offset % self.mini_batch_size
827
+ else:
828
+ mini_batch_step_offset = 0
829
+ token_eta, ttt_lr_eta = self.get_eta(X, mini_batch_step_offset, mini_batch_size)
830
+ eta = token_eta * ttt_lr_eta
831
+ # decouple token_coeff and ilr_coeff for decoding
832
+ inputs = {
833
+ "XQ": XQ,
834
+ "XK": XK,
835
+ "XV": XV,
836
+ "eta": eta,
837
+ "token_eta": token_eta,
838
+ "ttt_lr_eta": ttt_lr_eta,
839
+ }
840
+ return inputs
841
+
842
+ def ttt(
843
+ self,
844
+ inputs,
845
+ mini_batch_size,
846
+ last_mini_batch_params_dict,
847
+ cache_params: Optional[TTTCache] = None,
848
+ ):
849
+ raise NotImplementedError("ttt method must be implemented in TTTBase subclasses.")
850
+
851
+ def forward(
852
+ self,
853
+ hidden_states: torch.Tensor,
854
+ attention_mask: Optional[torch.Tensor] = None,
855
+ position_ids: Optional[torch.LongTensor] = None,
856
+ cache_params: Optional[TTTCache] = None,
857
+ ):
858
+ B, L = hidden_states.shape[:2]
859
+ reminder_len = L % self.mini_batch_size
860
+ num_mini_batch = L // self.mini_batch_size
861
+ last_mini_batch_params_dict = None
862
+
863
+ XQ, XK, XV = self.get_qkv_projections(hidden_states, cache_params=cache_params)
864
+
865
+ # [B, L, C] -> [B, L, num_heads, head_dim] -> [B, num_heads, L, head_dim]
866
+ XQ = XQ.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
867
+ XK = XK.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
868
+ XV = XV.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
869
+
870
+ cos, sin = self.rotary_emb(XV, position_ids % self.mini_batch_size)
871
+
872
+ # permute_qk and undo_permute_qk is just for aligning pytorch with jax pre-training
873
+ XQ, XK = permute_qk(XQ, XK)
874
+ XQ, XK = apply_rotary_pos_emb(XQ, XK, cos, sin)
875
+ XQ, XK = undo_permute_qk(XQ, XK)
876
+
877
+ output_hidden_states = []
878
+ # when input sequence length is not a multiple of mini_batch_size
879
+ # we need to compute them separately; when computing the remainder,
880
+ # we will need the last_mini_batch_params_dict to continue TTT learning
881
+ if num_mini_batch > 0:
882
+ inputs = {
883
+ "XQ": XQ[:, :, : num_mini_batch * self.mini_batch_size],
884
+ "XK": XK[:, :, : num_mini_batch * self.mini_batch_size],
885
+ "XV": XV[:, :, : num_mini_batch * self.mini_batch_size],
886
+ "X": hidden_states[:, : num_mini_batch * self.mini_batch_size],
887
+ }
888
+ output_mod, last_mini_batch_params_dict = self.ttt(
889
+ self.get_ttt_inputs(inputs, self.mini_batch_size, cache_params),
890
+ mini_batch_size=self.mini_batch_size,
891
+ last_mini_batch_params_dict=last_mini_batch_params_dict,
892
+ cache_params=cache_params,
893
+ )
894
+ output_hidden_states.append(output_mod)
895
+ if reminder_len > 0:
896
+ inputs = {
897
+ "XQ": XQ[:, :, -reminder_len:],
898
+ "XK": XK[:, :, -reminder_len:],
899
+ "XV": XV[:, :, -reminder_len:],
900
+ "X": hidden_states[:, -reminder_len:],
901
+ }
902
+ output_reminder, _ = self.ttt(
903
+ self.get_ttt_inputs(inputs, reminder_len, cache_params),
904
+ mini_batch_size=reminder_len,
905
+ last_mini_batch_params_dict=last_mini_batch_params_dict,
906
+ cache_params=cache_params,
907
+ )
908
+ output_hidden_states.append(output_reminder)
909
+
910
+ output_hidden_states = torch.cat(output_hidden_states, dim=1)
911
+ output_hidden_states = self.post_norm(output_hidden_states)
912
+ if self.use_gate:
913
+ output_hidden_states = self.apply_gate(hidden_states, output_hidden_states)
914
+ output_hidden_states = self.o_proj(output_hidden_states)
915
+
916
+ return output_hidden_states
917
+
918
+
919
+ class TTTLinear(TTTBase):
920
+ def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None):
921
+ super().__init__(config, layer_idx)
922
+ # TTT model initialization for TTT-Linear
923
+ self.W1 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, self.head_dim, self.head_dim)))
924
+ self.b1 = nn.Parameter(torch.zeros(self.num_heads, 1, self.head_dim))
925
+
926
+ def ttt(
927
+ self,
928
+ inputs,
929
+ mini_batch_size,
930
+ last_mini_batch_params_dict,
931
+ cache_params: Optional[TTTCache] = None,
932
+ ):
933
+ if mini_batch_size is None:
934
+ mini_batch_size = self.mini_batch_size
935
+
936
+ # in this case, we are decoding
937
+ if last_mini_batch_params_dict is None and cache_params is not None:
938
+ last_mini_batch_params_dict = cache_params.ttt_params_to_dict(self.layer_idx)
939
+
940
+ # [B, num_heads, num_mini_batch, mini_batch_size, head_dim]
941
+ B = inputs["XV"].shape[0]
942
+ num_mini_batch = inputs["XV"].shape[2]
943
+ L = inputs["XV"].shape[2] * inputs["XV"].shape[3]
944
+ device = inputs["XV"].device
945
+ dtype = inputs["XV"].dtype
946
+
947
+ # NOTE:
948
+ # for prefilling, we will always use dual form for faster computation
949
+ # we need to use primal form if mini_batch_size is not a multiple of self.mini_batch_size
950
+ # since we need to store the gradient for the next mini-batch computation
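+     # (The dual form folds all K gradient steps of a full mini-batch into a few matrix
+     #  products; the primal form materializes per-token weights and keeps the accumulated
+     #  gradient so a partial mini-batch can be continued from the cache.)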
951
+ use_dual_form = cache_params is None or mini_batch_size % self.mini_batch_size == 0
952
+
953
+ def compute_mini_batch(params_dict, inputs):
954
+ # [B, nh, f, f], nh=num_heads, f=head_dim
955
+ W1_init = params_dict["W1_states"]
956
+ # [B, nh, 1, f]
957
+ b1_init = params_dict["b1_states"]
958
+
959
+ # [B,nh,K,f], K=mini_batch_size
960
+ XQ_mini_batch = inputs["XQ"]
961
+ XV_mini_batch = inputs["XV"]
962
+ XK_mini_batch = inputs["XK"]
963
+ # [B, nh, K, 1]
964
+ eta_mini_batch = inputs["eta"]
965
+ token_eta_mini_batch = inputs["token_eta"]
966
+ ttt_lr_eta_mini_batch = inputs["ttt_lr_eta"]
967
+
968
+ X1 = XK_mini_batch
969
+ # [B,nh,K,f] @ [B,nh,f,f] -> [B,nh,K,f]
970
+ Z1 = X1 @ W1_init + b1_init
971
+ reconstruction_target = XV_mini_batch - XK_mini_batch
972
+
973
+ ln_weight = self.ttt_norm_weight.reshape(self.num_heads, 1, self.head_dim)
974
+ ln_bias = self.ttt_norm_bias.reshape(self.num_heads, 1, self.head_dim)
975
+ # [B,nh,K,f]
976
+ grad_l_wrt_Z1 = ln_fused_l2_bwd(Z1, reconstruction_target, ln_weight, ln_bias)
977
+
978
+ if use_dual_form:
979
+ # [B,nh,K,K]
980
+ Attn1 = torch.tril(XQ_mini_batch @ X1.transpose(-2, -1))
981
+ # [B,nh,1,f] - [B,nh,K,K] @ [B,nh,K,f] -> [B,nh,K,f]
982
+ b1_bar = b1_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z1
983
+ # [B,nh,K,f] @ [B,nh,f,f] - ([B,nh,K,1] * [B,nh,K,K]) @ [B,nh,K,f] + [B,nh,K,f]
984
+ Z1_bar = XQ_mini_batch @ W1_init - (eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar
985
+
986
+ last_eta_mini_batch = eta_mini_batch[:, :, -1, :, None]
987
+ # [B,nh,f,f] - [B,nh,f,K] @ [B,nh,K,f]
988
+ W1_last = W1_init - (last_eta_mini_batch * X1).transpose(-1, -2) @ grad_l_wrt_Z1
989
+ # [B,nh,1,f]
990
+ b1_last = b1_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z1, dim=-2, keepdim=True)
991
+ grad_W1_last = torch.zeros_like(W1_last)
992
+ grad_b1_last = torch.zeros_like(b1_last)
993
+ else:
994
+ ttt_lr_eta_mini_batch = torch.broadcast_to(
995
+ ttt_lr_eta_mini_batch,
996
+ (
997
+ *ttt_lr_eta_mini_batch.shape[:2],
998
+ mini_batch_size,
999
+ mini_batch_size,
1000
+ ),
1001
+ )
1002
+
1003
+ # [B, nh, K, f, f]
1004
+ grad_W1 = torch.einsum("bhki,bhkj->bhkij", X1, grad_l_wrt_Z1)
1005
+ grad_W1 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W1)
1006
+ grad_W1 = grad_W1 + params_dict["W1_grad"].unsqueeze(2)
1007
+ # [B, nh, K, f]
1008
+ grad_b1 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z1)
1009
+ grad_b1 = grad_b1 + params_dict["b1_grad"]
1010
+
1011
+ W1_bar = W1_init.unsqueeze(2) - grad_W1 * token_eta_mini_batch.unsqueeze(-1)
1012
+ b1_bar = b1_init - grad_b1 * token_eta_mini_batch
1013
+
1014
+ # [B, nh, K, 1, f] @ [B, nh, K, f, f]
1015
+ Z1_bar = (XQ_mini_batch.unsqueeze(3) @ W1_bar).squeeze(3) + b1_bar
1016
+
1017
+ W1_last = W1_bar[:, :, -1]
1018
+ b1_last = b1_bar[:, :, -1:]
1019
+ grad_W1_last = grad_W1[:, :, -1]
1020
+ grad_b1_last = grad_b1[:, :, -1:]
1021
+
1022
+ Z1_bar = ln_fwd(Z1_bar, ln_weight, ln_bias)
1023
+
1024
+ XQW_mini_batch = XQ_mini_batch + Z1_bar
1025
+
1026
+ last_param_dict = {
1027
+ "W1_states": W1_last,
1028
+ "b1_states": b1_last,
1029
+ "W1_grad": grad_W1_last,
1030
+ "b1_grad": grad_b1_last,
1031
+ }
1032
+ return last_param_dict, XQW_mini_batch
1033
+
1034
+ if last_mini_batch_params_dict is not None:
1035
+ init_params_dict = last_mini_batch_params_dict
1036
+ else:
1037
+ init_params_dict = {
1038
+ "W1_states": torch.tile(self.W1.unsqueeze(0), dims=(B, 1, 1, 1)),
1039
+ "b1_states": torch.tile(self.b1.unsqueeze(0), dims=(B, 1, 1, 1)),
1040
+ }
1041
+ init_params_dict.update(W1_grad=torch.zeros_like(init_params_dict["W1_states"]))
1042
+ init_params_dict.update(b1_grad=torch.zeros_like(init_params_dict["b1_states"]))
1043
+
1044
+ # [B,num_heads, num_mini_batch, mini_batch_size, f] -> [num_mini_batch, B, num_heads, mini_batch_size, f]
1045
+ inputs = tree_map(lambda x: x.permute(2, 0, 1, 3, 4), inputs)
1046
+
1047
+ # allocate output tensor
1048
+ XQW_batch = torch.empty(
1049
+ (num_mini_batch, B, self.num_heads, mini_batch_size, self.head_dim),
1050
+ device=device,
1051
+ dtype=dtype,
1052
+ )
1053
+ # XQW_batch: [num_mini_batch, B, num_heads, mini_batch_size, head_dim]
1054
+ batch_params_dict, XQW_batch = scan(
1055
+ compute_mini_batch,
1056
+ init_params_dict,
1057
+ inputs,
1058
+ XQW_batch,
1059
+ self.config.scan_checkpoint_group_size if self.training else 0,
1060
+ )
1061
+
1062
+ # [B, num_heads, L, C]
1063
+ if cache_params is not None:
1064
+ cache_params.update(batch_params_dict, self.layer_idx, L)
1065
+
1066
+ # [num_mini_batch, B, num_heads, mini_batch_size, head_dim] -> [B, num_mini_batch, mini_batch_size, num_heads, head_dim]
1067
+ XQW_batch = XQW_batch.permute(1, 0, 3, 2, 4)
1068
+ # [B, L, C]
1069
+ XQW_batch = XQW_batch.reshape(B, L, self.width)
1070
+ return XQW_batch, batch_params_dict
1071
+
1072
+
1073
+ class TTTMLP(TTTBase):
1074
+ def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None):
1075
+ super().__init__(config, layer_idx)
1076
+ # TTT model initialization for TTT-MLP
1077
+ self.W1 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, self.head_dim, 4 * self.head_dim)))
1078
+ self.b1 = nn.Parameter(torch.zeros(self.num_heads, 1, 4 * self.head_dim))
1079
+ self.W2 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, 4 * self.head_dim, self.head_dim)))
1080
+ self.b2 = nn.Parameter(torch.zeros(self.num_heads, 1, self.head_dim))
1081
+
1082
+ def ttt(
1083
+ self,
1084
+ inputs,
1085
+ mini_batch_size,
1086
+ last_mini_batch_params_dict,
1087
+ cache_params: Optional[TTTCache] = None,
1088
+ ):
1089
+ if mini_batch_size is None:
1090
+ mini_batch_size = self.mini_batch_size
1091
+
1092
+ # in this case, we are decoding
1093
+ if last_mini_batch_params_dict is None and cache_params is not None:
1094
+ last_mini_batch_params_dict = cache_params.ttt_params_to_dict(self.layer_idx)
1095
+
1096
+ # [B, num_heads, num_mini_batch, mini_batch_size, head_dim]
1097
+ B = inputs["XV"].shape[0]
1098
+ num_mini_batch = inputs["XV"].shape[2]
1099
+ L = inputs["XV"].shape[2] * inputs["XV"].shape[3]
1100
+ device = inputs["XV"].device
1101
+ dtype = inputs["XV"].dtype
1102
+ # NOTE:
1103
+ # for prefilling, we will always use dual form for faster computation
1104
+ # we need to use primal form if mini_batch_size is not a multiple of self.mini_batch_size
1105
+ # since we need store the gradient for the next mini-batch computation
1106
+ use_dual_form = cache_params is None or mini_batch_size % self.mini_batch_size == 0
+
+ def compute_mini_batch(params_dict, inputs):
+ # [B, nh, f, 4f]
+ W1_init = params_dict["W1_states"]
+ # [B, nh, 1, 4f]
+ b1_init = params_dict["b1_states"]
+ # [B, nh, 4f, f]
+ W2_init = params_dict["W2_states"]
+ # [B, nh, 1, f]
+ b2_init = params_dict["b2_states"]
+
+ # [B,nh,K,f]
+ XQ_mini_batch = inputs["XQ"]
+ XV_mini_batch = inputs["XV"]
+ XK_mini_batch = inputs["XK"]
+ # [B,nh,K,1]
+ eta_mini_batch = inputs["eta"]
+ token_eta_mini_batch = inputs["token_eta"]
+ ttt_lr_eta_mini_batch = inputs["ttt_lr_eta"]
+
+ X1 = XK_mini_batch
+ # [B,nh,K,f] @ [B,nh,f,4f] -> [B,nh,K,4f]
+ Z1 = X1 @ W1_init + b1_init
+ X2 = F.gelu(Z1, approximate="tanh")
+ # [B,nh,K,4f] @ [B,nh,4f,f] -> [B,nh,K,f]
+ Z2 = X2 @ W2_init + b2_init
+ reconstruction_target = XV_mini_batch - XK_mini_batch
+
+ ln_weight = self.ttt_norm_weight.reshape(self.num_heads, 1, self.head_dim)
+ ln_bias = self.ttt_norm_bias.reshape(self.num_heads, 1, self.head_dim)
+ # [B, nh, K, f]
+ grad_l_wrt_Z2 = ln_fused_l2_bwd(Z2, reconstruction_target, ln_weight, ln_bias)
+ # [B, nh, K, 4f]
+ grad_l_wrt_Z1 = grad_l_wrt_Z2 @ W2_init.transpose(-2, -1) * gelu_bwd(Z1)
+
+ if use_dual_form:
+ Attn1 = torch.tril(XQ_mini_batch @ X1.transpose(-2, -1)) # [B,nh,K,K]
+ # [B,nh,1,f] - [B,nh,K,K] @ [B,nh,K,4f] -> [B,nh,K,4f]
+ b1_bar = b1_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z1
+ # [B,nh,K,f] @ [B,nh,f,4f] - ([B,nh,K,1] * [B,nh,K,K]) @ [B,nh,K,4f] + [B,nh,K,4f]
+ Z1_bar = XQ_mini_batch @ W1_init - (eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar
+ X2_bar = F.gelu(Z1_bar, approximate="tanh")
+
+ # [B,nh,K,K]
+ Attn2 = torch.tril(X2_bar @ X2.transpose(-2, -1))
+ # [B,nh,1,f] - [B,nh,K,1] * [B,nh,K,f] -> [B,nh,K,f]
+ b2_bar = b2_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z2
+ # [B,nh,K,f] @ [1,nh,4f,f] - ([B,nh,K,1] * [B,nh,K,K]) @ [B,nh,K,f] + [B,nh,K,f]
+ Z2_bar = X2_bar @ W2_init - (eta_mini_batch * Attn2) @ grad_l_wrt_Z2 + b2_bar
+
+ last_eta_mini_batch = eta_mini_batch[:, :, -1, :, None]
+ # [B,nh,f,4f] - [B,nh,f,K] @ [B,nh,K,4f]
+ W1_last = W1_init - (last_eta_mini_batch * X1).transpose(-1, -2) @ grad_l_wrt_Z1
+ # [B,nh,1,4f]
+ b1_last = b1_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z1, dim=-2, keepdim=True)
+ # [B,nh,4f,f] - [B,nh,4f,K] @ [B,nh,K,f]
+ W2_last = W2_init - (last_eta_mini_batch * X2).transpose(-1, -2) @ grad_l_wrt_Z2
+ # [B,nh,1,f]
+ b2_last = b2_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z2, dim=-2, keepdim=True)
+ grad_W1_last = torch.zeros_like(W1_last)
+ grad_b1_last = torch.zeros_like(b1_last)
+ grad_W2_last = torch.zeros_like(W2_last)
+ grad_b2_last = torch.zeros_like(b2_last)
+
+ else:
+ ttt_lr_eta_mini_batch = torch.broadcast_to(
+ ttt_lr_eta_mini_batch,
+ (
+ *ttt_lr_eta_mini_batch.shape[:2],
+ mini_batch_size,
+ mini_batch_size,
+ ),
+ )
+
+ # [B, nh, K, 4f, f]
+ grad_W2 = torch.einsum("bhki,bhkj->bhkij", X2, grad_l_wrt_Z2)
+ grad_W2 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W2)
+ grad_W2 = grad_W2 + params_dict["W2_grad"].unsqueeze(2)
+ # [B, nh, K, f]
+ grad_b2 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z2)
+ grad_b2 = grad_b2 + params_dict["b2_grad"]
+
+ # [B, nh, K, f, 4f]
+ grad_W1 = torch.einsum("bhki,bhkj->bhkij", X1, grad_l_wrt_Z1)
+ grad_W1 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W1)
+ grad_W1 = grad_W1 + params_dict["W1_grad"].unsqueeze(2)
+ # [B, nh, K, 4f]
+ grad_b1 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z1)
+ grad_b1 = grad_b1 + params_dict["b1_grad"]
+
+ W1_bar = W1_init.unsqueeze(2) - grad_W1 * token_eta_mini_batch.unsqueeze(-1)
+ b1_bar = b1_init - grad_b1 * token_eta_mini_batch
+ W2_bar = W2_init.unsqueeze(2) - grad_W2 * token_eta_mini_batch.unsqueeze(-1)
+ b2_bar = b2_init - grad_b2 * token_eta_mini_batch
+
+ # [B, nh, K, 1, f] @ [B, nh, K, f, 4f] -> [B, nh, K, 4f]
+ Z1_bar = (XQ_mini_batch.unsqueeze(3) @ W1_bar).squeeze(3) + b1_bar
+ X2_bar = F.gelu(Z1_bar, approximate="tanh")
+ Z2_bar = (X2_bar.unsqueeze(3) @ W2_bar).squeeze(3) + b2_bar
+
+ W1_last = W1_bar[:, :, -1]
+ b1_last = b1_bar[:, :, -1:]
+ W2_last = W2_bar[:, :, -1]
+ b2_last = b2_bar[:, :, -1:]
+ grad_W1_last = grad_W1[:, :, -1]
+ grad_b1_last = grad_b1[:, :, -1:]
+ grad_W2_last = grad_W2[:, :, -1]
+ grad_b2_last = grad_b2[:, :, -1:]
+
+ Z2_bar = ln_fwd(Z2_bar, ln_weight, ln_bias)
+
+ XQW_mini_batch = XQ_mini_batch + Z2_bar
+
+ last_param_dict = {
+ "W1_states": W1_last,
+ "b1_states": b1_last,
+ "W2_states": W2_last,
+ "b2_states": b2_last,
+ "W1_grad": grad_W1_last,
+ "b1_grad": grad_b1_last,
+ "W2_grad": grad_W2_last,
+ "b2_grad": grad_b2_last,
+ }
+ return last_param_dict, XQW_mini_batch
+
+ if last_mini_batch_params_dict is not None:
+ init_params_dict = last_mini_batch_params_dict
+ else:
+ init_params_dict = {
+ "W1_states": torch.tile(self.W1.unsqueeze(0), dims=(B, 1, 1, 1)),
+ "b1_states": torch.tile(self.b1.unsqueeze(0), dims=(B, 1, 1, 1)),
+ "W2_states": torch.tile(self.W2.unsqueeze(0), dims=(B, 1, 1, 1)),
+ "b2_states": torch.tile(self.b2.unsqueeze(0), dims=(B, 1, 1, 1)),
+ }
+ init_params_dict.update(W1_grad=torch.zeros_like(init_params_dict["W1_states"]))
+ init_params_dict.update(b1_grad=torch.zeros_like(init_params_dict["b1_states"]))
+ init_params_dict.update(W2_grad=torch.zeros_like(init_params_dict["W2_states"]))
+ init_params_dict.update(b2_grad=torch.zeros_like(init_params_dict["b2_states"]))
+ inputs = tree_map(lambda x: x.permute(2, 0, 1, 3, 4), inputs) # [B,nh,NC,CS,f] -> [NC,B,nh,CS,f]
+ # allocate output tensor
+ XQW_batch = torch.empty(
+ (num_mini_batch, B, self.num_heads, mini_batch_size, self.head_dim),
+ device=device,
+ dtype=dtype,
+ )
+ # XQW_batch: [num_mini_batch, B, num_heads, mini_batch_size, head_dim]
+ batch_params_dict, XQW_batch = scan(
+ compute_mini_batch,
+ init_params_dict,
+ inputs,
+ XQW_batch,
+ self.config.scan_checkpoint_group_size if self.training else 0,
+ )
+
+ # [B, num_heads, L, C]
+ if cache_params is not None:
+ cache_params.update(batch_params_dict, self.layer_idx, L)
+
+ # [num_mini_batch, B, num_heads, mini_batch_size, head_dim] -> [B, num_mini_batch, mini_batch_size, num_heads, head_dim]
+ XQW_batch = XQW_batch.permute(1, 0, 3, 2, 4)
+ # [B, L, C]
+ XQW_batch = XQW_batch.reshape(B, L, self.width)
+ return XQW_batch, batch_params_dict
+
+
+ ################################
+ ### E2E Architecture Modules ###
+ ################################
+
+
+ class Block(nn.Module):
+ def __init__(self, config: TTTConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.pre_conv = config.pre_conv
+
+ if config.ttt_layer_type == "linear":
+ ttt_layer = TTTLinear
+ elif config.ttt_layer_type == "mlp":
+ ttt_layer = TTTMLP
+ else:
+ raise ValueError(f"Invalid ttt_layer_type: {config.ttt_layer_type}")
+
+ self.seq_modeling_block = ttt_layer(config=config, layer_idx=layer_idx)
+
+ self.mlp = SwiGluMLP(config)
+ if self.pre_conv:
+ self.conv = Conv(config, layer_idx)
+
+ self.seq_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.layer_idx = layer_idx
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ cache_params: Optional[TTTCache] = None,
+ ):
+ if self.pre_conv:
+ residual = hidden_states
+ hidden_states = self.conv(hidden_states, cache_params=cache_params)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+
+ hidden_states = self.seq_norm(hidden_states)
+
+ # TTT Layer
+ hidden_states = self.seq_modeling_block(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_params=cache_params,
+ )
+ hidden_states = residual + hidden_states
+
+ # Feed-Forward-Network
+ residual = hidden_states
+ hidden_states = self.ffn_norm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+ class TTTPreTrainedModel(PreTrainedModel):
+ config_class = TTTConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Block"]
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+ @dataclass
+ class TTTOutput(ModelOutput):
+ """
+ Class for the TTT model outputs.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ cache_params (`TTTCache`):
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+ avoid providing the old `input_ids`.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ cache_params: Optional[TTTCache] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ @dataclass
+ class TTTCausalLMOutput(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ cache_params (`TTTCache`):
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+ avoid providing the old `input_ids`.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ cache_params: Optional[TTTCache] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ class TTTModel(TTTPreTrainedModel):
+ """
+ Decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Block`]
+
+ Args:
+ config: TTTConfig
+ """
+
+ def __init__(self, config: TTTConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_params: Optional[TTTCache] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_cache: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_params is None and use_cache:
+ cache_params = TTTCache(self, inputs_embeds.size(0))
+
+ seqlen_offset = 0
+ if cache_params is not None:
+ seqlen_offset = cache_params.seqlen_offset
+ position_ids = torch.arange(
+ seqlen_offset,
+ seqlen_offset + inputs_embeds.shape[1],
+ dtype=torch.long,
+ device=inputs_embeds.device,
+ ).unsqueeze(0)
+
+ hidden_states = inputs_embeds
+
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+
+ for decoder_layer in self.layers:
+ if self.gradient_checkpointing and self.training:
+ hidden_states = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ cache_params,
+ )
+ else:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_params=cache_params,
+ )
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if use_cache:
+ cache_params.seqlen_offset += inputs_embeds.shape[1]
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
+
+ return TTTOutput(
+ last_hidden_state=hidden_states,
+ cache_params=cache_params if use_cache else None,
+ hidden_states=all_hidden_states,
+ )
+
+
+ class TTTForCausalLM(TTTPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = TTTModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ def _update_model_kwargs_for_generation(
+ self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
+ ) -> Dict[str, Any]:
+ model_kwargs["cache_params"] = outputs.get("cache_params", None)
+ # update attention mask
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
+ dim=-1,
+ )
+ return model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ attention_mask=None,
+ cache_params: Optional[TTTCache] = None,
+ inputs_embeds=None,
+ **kwargs,
+ ):
+ # only pass the last token of input_ids if the state is passed along.
+ if cache_params is not None:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+ attention_mask = attention_mask[:, -1].unsqueeze(-1) if attention_mask is not None else None
+
+ if inputs_embeds is not None and cache_params is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "cache_params": cache_params,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+
+ return model_inputs
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_params: Optional[TTTCache] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_cache: Optional[bool] = None,
+ *,
+ output_attentions: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ """
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ assert not output_attentions, "output_attentions is not available in TTTForCausalLM"
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_params=cache_params,
+ inputs_embeds=inputs_embeds,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ )
+
+ hidden_states = outputs[0]
+ if self.config.pretraining_tp > 1:
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+ logits = torch.cat(logits, dim=-1)
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return TTTCausalLMOutput(
+ loss=loss,
+ logits=logits,
+ cache_params=outputs.cache_params,
+ hidden_states=outputs.hidden_states,
+ )
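
The modeling code above exposes the standard `transformers` interface (`from_pretrained`, `generate`, with `cache_params` threaded through `prepare_inputs_for_generation` and `_update_model_kwargs_for_generation`). The snippet below is a minimal, untested sketch of loading this checkpoint and sampling from it; `MODEL_PATH` and the `ttt` import path are placeholders and assumptions, not names confirmed by this commit.

```python
# Minimal usage sketch for the classes defined above.
# Assumptions: MODEL_PATH is a local copy of this repository, and the modeling
# file above is importable as `ttt` (both are placeholders, not confirmed here).
import torch
from transformers import AutoTokenizer

from ttt import TTTForCausalLM  # hypothetical import path for the code above

MODEL_PATH = "path/to/this-repo"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = TTTForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float32)
model.eval()

inputs = tokenizer("The hidden state of a TTT layer is", return_tensors="pt")
with torch.no_grad():
    # generate() routes cache_params through prepare_inputs_for_generation /
    # _update_model_kwargs_for_generation defined in TTTForCausalLM above.
    out = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```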
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "bos_token": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": "</s>",
+ "unk_token": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
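
The special-token map above follows the LLaMA convention: `<s>` as BOS, `</s>` as EOS, `<unk>` as UNK, and `</s>` reused as the padding token. Below is a short, untested sketch of the practical consequence (padding and EOS share a token, so batched inputs should keep their attention mask); `MODEL_PATH` is a placeholder.

```python
# Sketch only; MODEL_PATH is a placeholder for a local copy of this repository.
from transformers import AutoTokenizer

MODEL_PATH = "path/to/this-repo"
tok = AutoTokenizer.from_pretrained(MODEL_PATH)

# special_tokens_map.json above maps onto these tokenizer attributes.
print(tok.bos_token, tok.eos_token, tok.unk_token, tok.pad_token)  # <s> </s> <unk> </s>

# pad_token == eos_token, so keep the attention_mask for padded batches and
# pass pad_token_id explicitly when calling generate() on batched prompts.
batch = tok(
    ["a short prompt", "a noticeably longer prompt than the first one"],
    padding=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape, batch["attention_mask"].shape)
```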
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
+ {
+ "add_bos_token": true,
+ "add_eos_token": false,
+ "add_prefix_space": null,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<s>",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "</s>",
+ "extra_special_tokens": {},
+ "legacy": false,
+ "model_max_length": 1000000000000000019884624838656,
+ "pad_token": "</s>",
+ "padding_side": "right",
+ "sp_model_kwargs": {},
+ "tokenizer_class": "LlamaTokenizer",
+ "unk_token": "<unk>",
+ "use_default_system_prompt": false
+ }
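
Per this config, the tokenizer is a `LlamaTokenizer` that prepends BOS (`add_bos_token: true`) but does not append EOS (`add_eos_token: false`), with ids 0/1/2 reserved for `<unk>`/`<s>`/`</s>` and `model_max_length` left at a sentinel "effectively unbounded" value. A small, untested sanity-check sketch, again with a placeholder `MODEL_PATH`:

```python
# Sketch only; MODEL_PATH is a placeholder for a local copy of this repository.
from transformers import AutoTokenizer

MODEL_PATH = "path/to/this-repo"
tok = AutoTokenizer.from_pretrained(MODEL_PATH)

ids = tok("hello world")["input_ids"]
# add_bos_token=true / add_eos_token=false: encoding starts with BOS (id 1 per
# added_tokens_decoder above) and does not end with EOS (id 2).
print(ids[0], ids[-1])

# model_max_length is a sentinel; the usable context is set by the model and
# its training setup, not by the tokenizer.
print(tok.model_max_length)
```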