RetentionLabs committed
Commit 8c2af59 · verified · 1 Parent(s): fa58f7a

Update README.md

Files changed (1): README.md (+105 −105)
README.md CHANGED
The two revisions are identical except for the `base_model` entry:

  base_model:
- - Test-Time-Training/ttt-mlp-350m-books-2k
+ - Test-Time-Training/ttt-linear-350m-books-2k

The updated README.md in full:
---
license: mit
language:
- en
tags:
- Test-time Training
pipeline_tag: text-generation
base_model:
- Test-Time-Training/ttt-linear-350m-books-2k
library_name: transformers
---

# Learning to (Learn at Test Time): RNNs with Expressive Hidden States

[**Paper**](https://arxiv.org/abs/2407.04620)
| [**JAX Codebase**](https://github.com/test-time-training/ttt-lm-jax)
| [**Setup**](#environment-setup)
| [**Quick Start**](#quick-start)
| [**Inference Benchmark**](https://github.com/test-time-training/ttt-lm-kernels)

This is the official PyTorch model implementation of [Learning to (Learn at Test Time): RNNs with Expressive Hidden States](https://arxiv.org/abs/2407.04620).
We **do not recommend training** with this codebase: it is written in pure PyTorch without any systems optimization, so training will be slow, especially when the per-device batch size is small.

For training code, or to replicate results from our paper, please see our [JAX codebase](https://github.com/test-time-training/ttt-lm-jax). For inference kernels, or to replicate the speed benchmarks from our paper, please see our [kernel implementations](https://github.com/test-time-training/ttt-lm-kernels).

## Abstract

Self-attention performs well in long context but has quadratic complexity. Existing RNN layers
have linear complexity, but their performance in long context is limited by the expressive power
of their hidden state. We propose a new class of sequence modeling layers with linear complexity
and an expressive hidden state. The key idea is to make the hidden state a machine learning
model itself, and the update rule a step of self-supervised learning.

Since the hidden state is updated by training even on test sequences, our layers are called **Test-Time Training (TTT) layers**.
We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden states are a linear model
and a two-layer MLP, respectively.
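
To make the update rule concrete, here is a minimal, self-contained sketch of a TTT-Linear-style recurrence in PyTorch. It is a toy illustration, not the paper's exact parameterization: the hidden state is the weight matrix of a linear model, each token triggers one gradient step on a simple self-supervised reconstruction loss, and the layer's output is the updated model's prediction. The function name and the corruption scheme here are ours for illustration.

```python
import torch

def ttt_linear_scan(x, lr=0.1):
    """Toy TTT-Linear recurrence. x: (seq_len, dim) token embeddings."""
    seq_len, dim = x.shape
    W = torch.zeros(dim, dim)                  # hidden state: a linear model
    outputs = []
    for t in range(seq_len):
        x_t = x[t]
        x_view = 0.5 * x_t                     # toy corrupted view of the token
        # Self-supervised loss l(W) = 0.5 * ||W @ x_view - x_t||^2;
        # its gradient w.r.t. W is outer(W @ x_view - x_t, x_view).
        err = W @ x_view - x_t
        W = W - lr * torch.outer(err, x_view)  # update rule: one gradient step
        outputs.append(W @ x_t)                # output of the updated hidden model
    return torch.stack(outputs)

z = ttt_linear_scan(torch.randn(16, 8))
print(z.shape)  # torch.Size([16, 8])
```

Each token costs a constant amount of work and the sequence is traversed once, like an RNN, which is where the linear complexity comes from.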

## Environment Setup

```bash
pip install "transformers[torch]"
```
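
To verify the installation, a quick optional sanity check:

```python
import torch
import transformers

# Print library versions and whether a CUDA device is visible
print(transformers.__version__, torch.__version__, torch.cuda.is_available())
```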

## Quick Start

Our implementation is based on Hugging Face Transformers. You can use the following code to load the model and generate text.

### Load with AutoModel

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "RetentionLabs/TTT-Linear-350M-Base-Books-2k"

# Load the tokenizer and model from the Hub; trust_remote_code=True is needed
# because the TTT architecture is defined by code shipped with the model repo.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto",
)

# Generate (the autocast context assumes a CUDA device; drop it on CPU)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    inputs = tokenizer("The future of AI is", return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=100)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

### From scratch

```python
from transformers import AutoTokenizer
from modeling_ttt import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS

# Initializing a ttt-1b style configuration;
# TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the default below
configuration = TTTConfig()

# Initializing a model (with random weights) from the ttt-1b style configuration
model = TTTForCausalLM(configuration)
model.eval()

# Accessing the model configuration
configuration = model.config

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Prefill: a forward pass returns an output object whose .logits has shape
# (batch, seq_len, vocab_size)
input_ids = tokenizer("Greeting from TTT!", return_tensors="pt").input_ids
outputs = model(input_ids=input_ids)
print(outputs.logits)

# Decoding
out_ids = model.generate(input_ids=input_ids, max_length=50)
out_str = tokenizer.batch_decode(out_ids, skip_special_tokens=True)
print(out_str)
```

**Note: This is a naive implementation of TTT layers for tutorial purposes.** This model can be trained using Hugging Face Accelerate or custom training loops. We have released our faster inference kernel and its speed benchmark [here](https://github.com/test-time-training/ttt-lm-kernels).
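
As a starting point for such a custom loop, here is a minimal sketch of a causal-LM training step, assuming the Quick Start setup above. The in-memory dataset, learning rate, and single-example "batches" are placeholders, and it assumes the remote-code model computes a loss from `labels` as standard Hugging Face causal LMs do; for serious training, the JAX codebase remains the recommended path.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "RetentionLabs/TTT-Linear-350M-Base-Books-2k"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
texts = ["Test-time training updates the hidden state on every sequence."]  # placeholder data

for step, text in enumerate(texts):
    batch = tokenizer(text, return_tensors="pt")
    # For causal LM training, pass the input ids as labels;
    # the model shifts them internally to compute the next-token loss.
    outputs = model(**batch, labels=batch["input_ids"])
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"step {step}: loss {outputs.loss.item():.3f}")
```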