---
datasets:
  - EleutherAI/pile
language:
  - en
library_name: transformers
pipeline_tag: text-generation
license: mit
---

Model Card

Citation

Please consider citing this paper if you use our work:

```bibtex
@article{arora2024simple,
  title={Simple linear attention language models balance the recall-throughput tradeoff},
  author={Arora, Simran and Eyuboglu, Sabri and Zhang, Michael and Timalsina, Aman and Alberti, Silas and Zinsley, Dylan and Zou, James and Rudra, Atri and Ré, Christopher},
  journal={arXiv:2402.18668},
  year={2024}
}
```

This checkpoint, based-1b-50b, is a pretrained Based model.

As quality references, we also provide a pretrained Mamba model (https://huggingface.co/hazyresearch/mamba-1b-50b) and a pretrained attention model with the Llama architecture (https://huggingface.co/hazyresearch/attn-1b-50bn).

All three checkpoints are pretrained with next-token prediction on the same 50B tokens of the Pile, in the exact same data order.
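As a rough illustration of the next-token prediction objective, the sketch below computes the mean cross-entropy of predicting each token from the model's scores at the previous position. It is written in plain Python for readability; the function name `next_token_loss` is hypothetical, and the actual training code in the Based repository batches this on the GPU and may handle details such as padding differently.

```python
import math
from typing import Sequence


def next_token_loss(logits: Sequence[Sequence[float]], tokens: Sequence[int]) -> float:
    """Mean cross-entropy for next-token prediction (hypothetical helper).

    logits[t] holds unnormalized scores over the vocabulary at position t;
    the target for position t is tokens[t + 1].
    """
    if len(logits) != len(tokens) or len(tokens) < 2:
        raise ValueError("need one logit row per token and at least two tokens")
    total = 0.0
    for t in range(len(tokens) - 1):
        row = logits[t]
        # Numerically stable log-sum-exp: log Z = m + log(sum(exp(x - m))).
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        # Negative log-probability of the true next token.
        total += log_z - row[tokens[t + 1]]
    return total / (len(tokens) - 1)


# Uniform scores over a 4-token vocabulary give a loss of log(4), about 1.386.
print(next_token_loss([[0.0] * 4] * 3, [0, 1, 2]))
```

Exponentiating this mean loss gives the perplexity, a standard measure of language modeling quality.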

The WandB training report is available here: https://api.wandb.ai/links/hazy-research/ggo9rst2

Model Sources

The model implementation and training code that produced the model are provided here: https://github.com/HazyResearch/based

Uses

The purpose of this work is to evaluate the language modeling quality of a new efficient architecture, Based.

We include a series of benchmarks that you can use to evaluate quality:

Please reach out to simarora@stanford.edu, eyuboglu@stanford.edu, and mzhang20@stanford.edu with questions.

Use the code below to load the Based checkpoints:

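```python
import torch
from transformers import AutoTokenizer
from based.models.gpt import GPTLMHeadModel

# The Based checkpoints use the GPT-2 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Load the pretrained weights from the Hugging Face Hub and move them to the GPU,
# because the generation example below places its inputs on "cuda".
model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-1b-50b").to("cuda")
```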

The following code will run text generation for a prompt and print out the response.

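```python
# Tokenize the prompt and put it on the same device as the model.
input_ids = tokenizer.encode("If I take one more step, it will be", return_tensors="pt").to("cuda")
# Generate a continuation; max_length=20 bounds the length of the output sequence.
output = model.generate(input_ids, max_length=20)
print(tokenizer.decode(output[0]))
```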
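To make the role of `max_length` concrete, here is a plain greedy decoding loop: at each step it appends the highest-scoring next token until the sequence, prompt included, reaches `max_length` tokens. This is a simplified sketch, not the Based implementation; it assumes greedy (argmax) decoding and assumes that `max_length` counts the prompt tokens, as it does in Hugging Face `transformers`. The `greedy_generate` and `toy_scores` functions are hypothetical stand-ins for the model's generation code.

```python
from typing import Callable, List, Sequence


def greedy_generate(
    prompt_ids: Sequence[int],
    next_token_scores: Callable[[List[int]], Sequence[float]],
    max_length: int,
) -> List[int]:
    """Greedy decoding sketch: max_length bounds prompt + generated tokens."""
    ids = list(prompt_ids)
    while len(ids) < max_length:
        scores = next_token_scores(ids)
        # Pick the index of the highest score (argmax) as the next token.
        ids.append(max(range(len(scores)), key=lambda i: scores[i]))
    return ids


def toy_scores(ids: List[int], vocab_size: int = 5) -> List[float]:
    """Hypothetical stand-in for a model: always prefers (last token + 1) mod vocab_size."""
    scores = [0.0] * vocab_size
    scores[(ids[-1] + 1) % vocab_size] = 1.0
    return scores


print(greedy_generate([0, 1], toy_scores, max_length=6))  # [0, 1, 2, 3, 4, 0]
```

Under this assumption, the 20-token budget in the example above includes the prompt's own tokens, so fewer than 20 new tokens are generated.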

Note: the checkpoints for the other architectures require additional dependencies and slightly different loading code.

To load the attention models, use the following code:

```python
import torch
from transformers import AutoTokenizer
from based.models.transformer.gpt import GPTLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Load the attention (Llama-architecture) baseline and move it to the GPU.
model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/attn-360m").to("cuda")
```

To use the Mamba checkpoints, first run `pip install mamba-ssm`, then use the following code:

```python
import torch
from transformers import AutoTokenizer
from based.models.mamba import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Load the Mamba baseline (requires the mamba-ssm package) and move it to the GPU.
model = MambaLMHeadModel.from_pretrained_hf("hazyresearch/mamba-360m").to("cuda")
```