---
license: mit
language:
- en
tags:
- Test-time Training
pipeline_tag: text-generation
base_model:
- Test-Time-Training/ttt-mlp-125m-books-2k
library_name: transformers
---

# Learning to (Learn at Test Time): RNNs with Expressive Hidden States

[**Paper**](https://arxiv.org/abs/2407.04620) | [**JAX Codebase**](https://github.com/test-time-training/ttt-lm-jax) | [**Setup**](#environment-setup) | [**Quick Start**](#quick-start) | [**Inference Benchmark**](https://github.com/test-time-training/ttt-lm-kernels)

This is the official PyTorch model implementation of [Learning to (Learn at Test Time): RNNs with Expressive Hidden States](https://arxiv.org/abs/2407.04620).
We **do not recommend training** with this codebase: it is written in pure PyTorch without any systems optimization, so training will be slow, especially when the per-device batch size is small.
For training code, or to replicate results from our paper, please view our [JAX codebase](https://github.com/test-time-training/ttt-lm-jax). For inference kernels, or to replicate speed benchmarks from our paper, please view our [kernel implementations](https://github.com/test-time-training/ttt-lm-kernels).
## Abstract

Self-attention performs well in long context but has quadratic complexity. Existing RNN layers have linear complexity, but their performance in long context is limited by the expressive power of their hidden state. We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state. The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning.

Since the hidden state is updated by training even on test sequences, our layers are called **Test-Time Training (TTT) layers**. We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden states are a linear model and a two-layer MLP respectively.

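The update rule described in the abstract can be sketched in a few lines. The following is a toy, illustrative TTT-Linear-style forward pass written from the abstract's description, not the paper's exact parameterization: the corruption scheme, learning rate, and reconstruction loss here are placeholder choices.

```python
import numpy as np

def ttt_linear_forward(xs, lr=0.1):
    """Toy sketch of a TTT-Linear-style layer.

    The hidden state is itself a linear model W. For every token we take
    one gradient step on a self-supervised reconstruction loss, then use
    the updated W to produce the output token. (Illustrative only; the
    paper's actual layer uses learned projections and other details
    omitted here.)
    """
    d = xs.shape[1]
    W = np.eye(d)                        # hidden state: a linear model
    outputs = []
    for x in xs:                         # process the sequence token by token
        x_corrupt = 0.5 * x              # placeholder corruption for the SSL task
        err = W @ x_corrupt - x          # reconstruction error
        grad = np.outer(err, x_corrupt)  # grad of 0.5 * ||W x' - x||^2 w.r.t. W
        W = W - lr * grad                # update rule = one step of learning
        outputs.append(W @ x)            # output uses the updated hidden state
    return np.stack(outputs), W
```

Because the state `W` is trained as tokens arrive, the same update runs at test time, which is what gives the layer its name.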
## Environment Setup

```bash
pip install "transformers[torch]"
```

## Quick Start
Our implementation is based on Hugging Face Transformers. Use the following code to load the model and generate text.

### Load with AutoModel

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "RetentionLabs/TTT-MLP-125M-Base-Books-2k"

# Initialize the tokenizer and model from the Hub
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto",
)

# Generate
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    inputs = tokenizer("The future of AI is", return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=100)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

### From scratch
```python
from transformers import AutoTokenizer
from modeling_ttt import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS

# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig()

# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM(configuration)
model.eval()

# Accessing the model configuration
configuration = model.config

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Prefill
input_ids = tokenizer("Greeting from TTT!", return_tensors="pt").input_ids
outputs = model(input_ids=input_ids)
print(outputs.logits)

# Decoding
out_ids = model.generate(input_ids=input_ids, max_length=50)
out_str = tokenizer.batch_decode(out_ids, skip_special_tokens=True)
print(out_str)
```

**Note: This is a naive implementation of TTT layers for tutorial purposes.** The model can be trained with Hugging Face Accelerate or a custom training loop. We have released our faster inference kernels and their speed benchmarks [here](https://github.com/test-time-training/ttt-lm-kernels).
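
If you do write a custom training loop, it reduces to the standard forward / loss / backward / step pattern. The sketch below uses `TinyCausalLM`, a hypothetical stand-in included only so the snippet is self-contained; any causal LM that produces next-token logits trains the same way.

```python
import torch
import torch.nn as nn

class TinyCausalLM(nn.Module):
    """Hypothetical stand-in for a causal LM: embeds tokens, predicts the next."""
    def __init__(self, vocab_size=32, d_model=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        return self.head(self.embed(input_ids))  # (batch, seq, vocab) logits

def train_step(model, input_ids, optimizer):
    """One custom training step for next-token prediction."""
    logits = model(input_ids)
    # Shift: position t predicts token t+1
    loss = nn.functional.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        input_ids[:, 1:].reshape(-1),
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

The same `train_step` body works with a Transformers model by passing `labels=input_ids` and reading `outputs.loss` instead of computing the shifted cross-entropy by hand.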