---
datasets:
- EleutherAI/pile
language:
- en
library_name: transformers
pipeline_tag: text-generation
license: mit
---
# Model Card

## Citation
Please consider citing this paper if you use our work:
```bibtex
@article{arora2024simple,
  title={Simple linear attention language models balance the recall-throughput tradeoff},
  author={Arora, Simran and Eyuboglu, Sabri and Zhang, Michael and Timalsina, Aman and Alberti, Silas and Zinsley, Dylan and Zou, James and Rudra, Atri and Ré, Christopher},
  journal={arXiv:2402.18668},
  year={2024}
}
```
This model is a pretrained Based model.
As quality references, we also provide a pretrained Mamba model (https://huggingface.co/hazyresearch/mamba-1b-50b) and a pretrained attention model with the Llama architecture (https://huggingface.co/hazyresearch/attn-1b-50bn).
All three checkpoints are pretrained with next-token prediction on 50B tokens of the Pile, seeing the training data in exactly the same order.
The WandB report for training is here: https://api.wandb.ai/links/hazy-research/ggo9rst2
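Next-token prediction scores the model's output at position t against the token at position t + 1. The framework-free sketch below computes that shifted cross-entropy for a single sequence. `log_softmax`, `next_token_loss`, and the toy data are hypothetical illustrations; the actual training code (in the repository linked in the Model Sources section) works on batched PyTorch tensors and may differ in details such as masking.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def next_token_loss(logits: Sequence[Sequence[float]], tokens: Sequence[int]) -> float:
    """Mean cross-entropy of predicting tokens[t + 1] from logits[t].

    logits[t] holds the (hypothetical) model's scores over the vocabulary after
    reading tokens[0..t]; the last position has no target and is skipped.
    """
    if len(logits) != len(tokens):
        raise ValueError("expected one logit vector per input token")
    losses = []
    for t in range(len(tokens) - 1):
        target = tokens[t + 1]  # targets are the inputs shifted left by one
        losses.append(-log_softmax(logits[t])[target])
    return sum(losses) / len(losses)


# Toy example: vocabulary of 3 tokens, sequence of 4 tokens.
tokens = [0, 2, 1, 2]
uniform = [[0.0, 0.0, 0.0]] * len(tokens)
print(next_token_loss(uniform, tokens))  # about 1.0986, i.e. log(3)
```

A model that spreads probability uniformly over a vocabulary of size V pays log(V) per token, which is why the loss is reported relative to that baseline in the example.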
## Model Sources
The model implementation and training code that produced the model are provided here: https://github.com/HazyResearch/based
## Uses
The purpose of this work is to evaluate the language modeling quality of a new efficient architecture, Based.
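The cited paper describes Based as combining linear attention, whose feature map is a second-order Taylor approximation of exp(q·k), with small sliding-window attention. The plain-Python sketch below illustrates only the linear-attention part under simplifying assumptions: no sliding window, no query/key projections, and no scaling. Every name in it is hypothetical; the efficient implementation is in the GitHub repository listed under Model Sources. It computes the same causal outputs two ways: a quadratic reference that revisits all previous tokens, and a recurrent form that carries a fixed-size state.

```python
import math
from typing import List, Sequence

Vector = List[float]


def taylor_feature_map(x: Sequence[float]) -> Vector:
    """phi(x) = [1, x, vec(x x^T) / sqrt(2)].

    With this map, phi(q) . phi(k) = 1 + q.k + (q.k)**2 / 2, the second-order
    Taylor expansion of exp(q.k); it equals ((q.k + 1)**2 + 1) / 2 > 0.
    """
    second_order = [xi * xj / math.sqrt(2) for xi in x for xj in x]
    return [1.0, *x, *second_order]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


def attention_quadratic(qs, ks, vs) -> List[Vector]:
    """Reference form: y_i = sum_{j<=i} w_ij * v_j / sum_{j<=i} w_ij,
    with w_ij = phi(q_i) . phi(k_j). Work per step grows with position i."""
    out = []
    for i, q in enumerate(qs):
        fq = taylor_feature_map(q)
        weights = [dot(fq, taylor_feature_map(ks[j])) for j in range(i + 1)]
        total = sum(weights)
        out.append([sum(w * vs[j][c] for j, w in enumerate(weights)) / total
                    for c in range(len(vs[0]))])
    return out


def attention_recurrent(qs, ks, vs) -> List[Vector]:
    """Same outputs from a fixed-size state (S, z), updated once per token."""
    feat_dim = len(taylor_feature_map(ks[0]))
    v_dim = len(vs[0])
    S = [[0.0] * v_dim for _ in range(feat_dim)]  # running sum of phi(k_j) v_j^T
    z = [0.0] * feat_dim                          # running sum of phi(k_j)
    out = []
    for q, k, v in zip(qs, ks, vs):
        fk = taylor_feature_map(k)
        for a in range(feat_dim):
            z[a] += fk[a]
            for c in range(v_dim):
                S[a][c] += fk[a] * v[c]
        fq = taylor_feature_map(q)
        denom = dot(fq, z)
        out.append([sum(fq[a] * S[a][c] for a in range(feat_dim)) / denom
                    for c in range(v_dim)])
    return out


# Toy queries, keys, and values for a sequence of 3 tokens (head dim 2).
qs = [[0.1, -0.2], [0.3, 0.0], [-0.1, 0.4]]
ks = [[0.2, 0.1], [-0.3, 0.2], [0.0, -0.1]]
vs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
quad = attention_quadratic(qs, ks, vs)
rec = attention_recurrent(qs, ks, vs)
print(max(abs(a - b) for ya, yb in zip(quad, rec) for a, b in zip(ya, yb)))
```

The recurrent state has size feat_dim by v_dim no matter how long the sequence gets, so generation memory stays constant; the cost is that everything the model can recall must fit in that state, which is the recall side of the tradeoff named in the paper's title.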
We include a series of benchmarks that you can use to evaluate quality:
- FDA: https://huggingface.co/datasets/hazyresearch/based-fda
- SWDE: https://huggingface.co/datasets/hazyresearch/based-swde
- SQUAD: https://huggingface.co/datasets/hazyresearch/based-squad
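These benchmarks are recall-intensive: the model must pull an answer out of a document supplied in its context. One simple way to score such tasks is a case-insensitive check of whether the gold answer appears in the generated text. The sketch below assumes that metric; the official evaluation code in the Based repository may normalize or truncate generations differently, and `normalize`, `contains_answer`, `contains_accuracy`, and the example pairs are hypothetical.

```python
from typing import Iterable, Tuple


def normalize(text: str) -> str:
    # Lowercase and collapse runs of whitespace so formatting differences are not penalized.
    return " ".join(text.lower().split())


def contains_answer(generation: str, answer: str) -> bool:
    """True if the normalized gold answer appears anywhere in the normalized generation."""
    return normalize(answer) in normalize(generation)


def contains_accuracy(pairs: Iterable[Tuple[str, str]]) -> float:
    """Fraction of (generation, answer) pairs whose answer is contained in the generation."""
    results = [contains_answer(gen, ans) for gen, ans in pairs]
    if not results:
        raise ValueError("no examples to score")
    return sum(results) / len(results)


# Hypothetical model outputs paired with gold answers.
examples = [
    (" Stanford University, Palo Alto", "stanford university"),
    (" 2019\n\nThe device", "2019"),
    (" unknown", "Acme Corp"),
]
print(contains_accuracy(examples))  # two of the three examples match
```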
Please reach out to simarora@stanford.edu, eyuboglu@stanford.edu, and mzhang20@stanford.edu with questions.
Use the code below to load the Based checkpoints:
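```python
import torch
from transformers import AutoTokenizer
from based.models.gpt import GPTLMHeadModel

# The Based checkpoints use the GPT-2 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Move the model to the GPU so it sits on the same device as the inputs.
model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda")
```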
The following code generates a continuation of a prompt and prints it:
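```python
# Tokenize the prompt and place it on the same device as the model.
input_ids = tokenizer.encode("If I take one more step, it will be", return_tensors="pt").to("cuda")
# max_length caps the total sequence length (prompt + generated tokens).
output = model.generate(input_ids, max_length=20)
print(tokenizer.decode(output[0]))
```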
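With `max_length=20`, generation stops once prompt plus continuation reaches 20 tokens, assuming the Based `generate` method follows the Hugging Face convention (`max_length` counts the prompt) and decodes greedily by default. The sketch below illustrates that assumed behavior with a toy stand-in model; `greedy_generate` and `toy_model` are hypothetical helpers, not part of the `based` package.

```python
from typing import Callable, List, Sequence


def greedy_generate(next_token_logits: Callable[[Sequence[int]], Sequence[float]],
                    prompt_ids: Sequence[int], max_length: int) -> List[int]:
    """Append the highest-scoring token until the sequence holds max_length tokens.

    max_length counts the prompt, so at most max_length - len(prompt_ids)
    new tokens are produced.
    """
    ids = list(prompt_ids)
    while len(ids) < max_length:
        logits = next_token_logits(ids)
        ids.append(max(range(len(logits)), key=logits.__getitem__))  # argmax
    return ids


def toy_model(ids: Sequence[int]) -> List[float]:
    # Toy stand-in for a language model over a 5-token vocabulary:
    # it always prefers the token after the last one, wrapping around.
    scores = [0.0] * 5
    scores[(ids[-1] + 1) % 5] = 1.0
    return scores


print(greedy_generate(toy_model, [3], max_length=6))  # [3, 4, 0, 1, 2, 3]
```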
Note: the checkpoints for the other architectures require additional dependencies and slightly different loading code.
To load the Attention models, use the following code:
```python
import torch
from transformers import AutoTokenizer
from based.models.transformer.gpt import GPTLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/attn-360m").to("cuda")
```
To use the Mamba checkpoints, first run `pip install mamba-ssm` and then use the following code:

```python
import torch
from transformers import AutoTokenizer
from based.models.mamba import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = MambaLMHeadModel.from_pretrained_hf("hazyresearch/mamba-360m").to("cuda")
```
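The three architectures differ only in import path, class name, and repository id, so a small dispatch table can keep evaluation scripts architecture-agnostic. This is a hypothetical convenience helper (`CHECKPOINTS`, `resolve`, and `load_model` are not part of the `based` package); it relies only on the `from_pretrained_hf` calls shown in the loading snippets, and imports each model class lazily so that, for example, `mamba-ssm` is needed only when a Mamba checkpoint is requested.

```python
import importlib
from typing import Dict, Tuple

# (module path, class name, Hugging Face repo id), taken from the loading snippets.
CHECKPOINTS: Dict[str, Tuple[str, str, str]] = {
    "based": ("based.models.gpt", "GPTLMHeadModel", "hazyresearch/based-360m"),
    "attn": ("based.models.transformer.gpt", "GPTLMHeadModel", "hazyresearch/attn-360m"),
    "mamba": ("based.models.mamba", "MambaLMHeadModel", "hazyresearch/mamba-360m"),
}


def resolve(arch: str) -> Tuple[str, str, str]:
    """Look up the loading spec for an architecture name, failing with a clear message."""
    try:
        return CHECKPOINTS[arch]
    except KeyError:
        raise ValueError(f"unknown architecture {arch!r}; expected one of {sorted(CHECKPOINTS)}") from None


def load_model(arch: str, device: str = "cuda"):
    """Import the matching class only when needed, then load and move the checkpoint."""
    module_name, class_name, repo_id = resolve(arch)
    model_cls = getattr(importlib.import_module(module_name), class_name)
    return model_cls.from_pretrained_hf(repo_id).to(device)


# Usage (requires a GPU, the based package, and mamba-ssm for "mamba"):
# model = load_model("mamba")
```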