---
license: mit
language:
- en
tags:
- Test-time Training
pipeline_tag: text-generation
base_model:
- Test-Time-Training/ttt-linear-125m-books-2k
library_name: transformers
---

# Learning to (Learn at Test Time): RNNs with Expressive Hidden States

[**Paper**](https://arxiv.org/abs/2407.04620) | [**JAX Codebase**](https://github.com/test-time-training/ttt-lm-jax) | [**Setup**](#environment-setup) | [**Quick Start**](#quick-start) | [**Inference Benchmark**](https://github.com/test-time-training/ttt-lm-kernels)

This is the official PyTorch model implementation of [Learning to (Learn at Test Time): RNNs with Expressive Hidden States](https://arxiv.org/abs/2407.04620).
We **do not recommend training** with this codebase: it is written in pure PyTorch without any systems optimization, so training will be slow, especially when the per-device batch size is small.

For training code, or to replicate results from our paper, please view our [JAX codebase](https://github.com/test-time-training/ttt-lm-jax). For inference kernels, or to replicate speed benchmarks from our paper, please view our [kernel implementations](https://github.com/test-time-training/ttt-lm-kernels).

## Abstract

Self-attention performs well in long context but has quadratic complexity. Existing RNN layers have linear complexity, but their performance in long context is limited by the expressive power of their hidden state. We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state. The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning. Since the hidden state is updated by training even on test sequences, our layers are called **Test-Time Training (TTT) layers**. We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer MLP, respectively.
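The abstract's key idea (the hidden state is itself a model, and the update rule is a gradient step on a self-supervised loss) can be sketched in NumPy for the TTT-Linear case. This is a minimal, non-authoritative sketch under stated assumptions: the reconstruction loss `||W k - v||^2`, the projections `theta_k`/`theta_v`/`theta_q`, the zero initial state, the fixed learning rate `lr`, and the helper names `inner_loss`, `ttt_step`, and `ttt_linear` are illustrative choices, not the API of `modeling_ttt`, which may differ in initialization, normalization, and batching.

```python
import numpy as np

# Illustrative sketch only: hypothetical helper names, not the modeling_ttt API.


def inner_loss(W, k, v):
    """Self-supervised reconstruction loss ||W k - v||^2 of the inner linear model."""
    r = W @ k - v
    return float(r @ r)


def ttt_step(W, x_t, theta_k, theta_v, lr):
    """Update rule: one gradient-descent step on inner_loss for token x_t."""
    k = theta_k @ x_t  # input view fed to the inner model (assumed projection)
    v = theta_v @ x_t  # target view the inner model must reconstruct (assumed projection)
    grad = 2.0 * np.outer(W @ k - v, k)  # d/dW ||W k - v||^2
    return W - lr * grad


def ttt_linear(x, theta_k, theta_v, theta_q, lr=0.01):
    """Naive per-token TTT-Linear scan: x of shape (T, d_in) -> z of shape (T, d)."""
    d = theta_k.shape[0]
    W = np.zeros((d, d))  # hidden state = weights of a linear model (assumed zero init)
    z = np.empty((x.shape[0], d))
    for t, x_t in enumerate(x):
        W = ttt_step(W, x_t, theta_k, theta_v, lr)  # train on the current token
        z[t] = W @ (theta_q @ x_t)                  # output: apply the updated model
    return z, W


# Usage with random (hypothetical) projections and inputs
rng = np.random.default_rng(0)
d_in, d = 8, 4
theta_k, theta_v, theta_q = (rng.normal(size=(d, d_in)) / np.sqrt(d_in) for _ in range(3))
x = rng.normal(size=(16, d_in))
z, W = ttt_linear(x, theta_k, theta_v, theta_q)
print(z.shape, W.shape)  # (16, 4) (4, 4)
```

`W` keeps the same `(d, d)` shape however long the sequence is, so each token costs the same amount of work and the total cost grows linearly with sequence length. For TTT-MLP, `W @ k` would be replaced by a two-layer MLP and `grad` by its backpropagated gradient.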
## Environment Setup

```bash
pip install "transformers[torch]"
```

## Quick Start

Our implementation is based on Hugging Face Transformers. You can use the following code to load the model and generate text.

### Load with AutoModel

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "RetentionLabs/TTT-Linear-125M-Base-Books-2k"

# Initializing a model from remote
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,   # the TTT modeling code is downloaded from the repo
    dtype=torch.bfloat16,     # older transformers releases name this `torch_dtype`
    device_map="auto",
)

# Generate (autocast here assumes a CUDA device)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    inputs = tokenizer("The future of AI is", return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=100)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

### From scratch

```python
import torch
from transformers import AutoTokenizer
from modeling_ttt import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS

# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig()

# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM(configuration)
model.eval()

# Accessing the model configuration
configuration = model.config

# Tokenizer (gated repository: requires accepting the Llama 2 license on the Hub)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Prefill: a single forward pass over the prompt
input_ids = tokenizer("Greeting from TTT!", return_tensors="pt").input_ids
with torch.no_grad():
    outputs = model(input_ids=input_ids)
logits = outputs.logits  # shape: (batch, seq_len, vocab_size)
print(logits)

# Decoding
out_ids = model.generate(input_ids=input_ids, max_length=50)
out_str = tokenizer.batch_decode(out_ids, skip_special_tokens=True)
print(out_str)
```

**Note: This is a naive implementation of TTT layers for tutorial purposes.** This model can be trained using Hugging Face Accelerate or custom training loops.
We have released our faster inference kernel and its speed benchmark [here](https://github.com/test-time-training/ttt-lm-kernels).