Update rex1-base: mixed-2 checkpoint step 710000
- README.md +3 -3
- export_metadata.json +2 -2
- inference.py +7 -0
- model.py +39 -0
- model.safetensors +1 -1
- training_config.yaml +139 -24
README.md
CHANGED
|
@@ -7,7 +7,7 @@ tags:
|
|
| 7 |
- rex
|
| 8 |
---
|
| 9 |
|
| 10 |
-
# REX1
|
| 11 |
|
| 12 |
REX is a recursive decoder-only Transformer language model. This repository uses custom
|
| 13 |
Transformers code, so load it with `trust_remote_code=True`.
|
|
@@ -21,12 +21,12 @@ tokenizer = AutoTokenizer.from_pretrained(".")
|
|
| 21 |
|
| 22 |
## Checkpoint
|
| 23 |
|
| 24 |
-
Exported from `runs/rex-300m/
|
| 25 |
|
| 26 |
## Training Notes
|
| 27 |
|
| 28 |
- Tokenizer: `gpt2`
|
| 29 |
- Context length: `1024`
|
| 30 |
-
- Training output dir: `runs/rex-300m`
|
| 31 |
|
| 32 |
This is a base language model checkpoint and is not instruction-aligned unless noted.
|
|
|
|
| 7 |
- rex
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# REX1 300M mixed-2 step 710000
|
| 11 |
|
| 12 |
REX is a recursive decoder-only Transformer language model. This repository uses custom
|
| 13 |
Transformers code, so load it with `trust_remote_code=True`.
|
|
|
|
| 21 |
|
| 22 |
## Checkpoint
|
| 23 |
|
| 24 |
+
Exported from `runs/rex-300m-mixed-2/ckpt_step710000.pt`.
|
| 25 |
|
| 26 |
## Training Notes
|
| 27 |
|
| 28 |
- Tokenizer: `gpt2`
|
| 29 |
- Context length: `1024`
|
| 30 |
+
- Training output dir: `runs/rex-300m-mixed-2`
|
| 31 |
|
| 32 |
This is a base language model checkpoint and is not instruction-aligned unless noted.
|
export_metadata.json
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
{
|
| 2 |
-
"checkpoint": "runs/rex-300m/
|
| 3 |
-
"step":
|
| 4 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"checkpoint": "runs/rex-300m-mixed-2/ckpt_step710000.pt",
|
| 3 |
+
"step": 710000
|
| 4 |
}
|
inference.py
CHANGED
|
@@ -45,6 +45,12 @@ def build_parser() -> argparse.ArgumentParser:
|
|
| 45 |
parser.add_argument("--max-new-tokens", type=int, default=100, help="Number of tokens to generate")
|
| 46 |
parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature; 0 means greedy")
|
| 47 |
parser.add_argument("--top-k", type=int, default=50, help="Limit sampling to top-k tokens; <=0 disables")
|
| 48 |
return parser
|
| 49 |
|
| 50 |
|
|
@@ -73,6 +79,7 @@ def main() -> None:
|
|
| 73 |
max_new_tokens=args.max_new_tokens,
|
| 74 |
temperature=args.temperature,
|
| 75 |
top_k=top_k,
|
| 76 |
)
|
| 77 |
|
| 78 |
print(tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True))
|
|
|
|
| 45 |
parser.add_argument("--max-new-tokens", type=int, default=100, help="Number of tokens to generate")
|
| 46 |
parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature; 0 means greedy")
|
| 47 |
parser.add_argument("--top-k", type=int, default=50, help="Limit sampling to top-k tokens; <=0 disables")
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--no-repeat-ngram-size",
|
| 50 |
+
type=int,
|
| 51 |
+
default=0,
|
| 52 |
+
help="Prevent repeated n-grams of this size; 0 disables",
|
| 53 |
+
)
|
| 54 |
return parser
|
| 55 |
|
| 56 |
|
|
|
|
| 79 |
max_new_tokens=args.max_new_tokens,
|
| 80 |
temperature=args.temperature,
|
| 81 |
top_k=top_k,
|
| 82 |
+
no_repeat_ngram_size=args.no_repeat_ngram_size,
|
| 83 |
)
|
| 84 |
|
| 85 |
print(tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True))
|
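To illustrate how the new flag reaches `generate`, here is a minimal sketch of the argument parsing. It reproduces only the four sampling flags visible in this diff; the parser description and the sample argument list are assumptions for illustration, and the real `inference.py` defines further arguments that are not shown here.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical reduced parser: only the sampling flags shown in the diff.
    parser = argparse.ArgumentParser(description="Sampling flags sketch")
    parser.add_argument("--max-new-tokens", type=int, default=100, help="Number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature; 0 means greedy")
    parser.add_argument("--top-k", type=int, default=50, help="Limit sampling to top-k tokens; <=0 disables")
    parser.add_argument(
        "--no-repeat-ngram-size",
        type=int,
        default=0,
        help="Prevent repeated n-grams of this size; 0 disables",
    )
    return parser


# argparse maps "--no-repeat-ngram-size" to the attribute "no_repeat_ngram_size".
args = build_parser().parse_args(["--temperature", "0", "--no-repeat-ngram-size", "3"])
print(args.no_repeat_ngram_size, args.temperature)
```

Because the default is `0`, existing invocations that omit the flag keep their previous behavior.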
model.py
CHANGED
|
@@ -211,11 +211,15 @@ class RexForCausalLM(nn.Module):
|
|
| 211 |
max_new_tokens: int,
|
| 212 |
temperature: float = 1.0,
|
| 213 |
top_k: int | None = None,
|
| 214 |
) -> torch.Tensor:
|
| 215 |
self.eval()
|
| 216 |
for _ in range(max_new_tokens):
|
| 217 |
context = input_ids[:, -self.cfg.max_seq_len :]
|
| 218 |
logits = self(context)["logits"][:, -1, :]
|
| 219 |
if temperature < 0:
|
| 220 |
raise ValueError("temperature must be >= 0")
|
| 221 |
if temperature == 0:
|
|
@@ -231,6 +235,41 @@ class RexForCausalLM(nn.Module):
|
|
| 231 |
input_ids = torch.cat([input_ids, next_token], dim=1)
|
| 232 |
return input_ids
|
| 233 |
|
| 234 |
def parameter_count(self, trainable_only: bool = False) -> int:
|
| 235 |
params = self.parameters()
|
| 236 |
if trainable_only:
|
|
|
|
| 211 |
max_new_tokens: int,
|
| 212 |
temperature: float = 1.0,
|
| 213 |
top_k: int | None = None,
|
| 214 |
+
no_repeat_ngram_size: int = 0,
|
| 215 |
) -> torch.Tensor:
|
| 216 |
self.eval()
|
| 217 |
+
if no_repeat_ngram_size < 0:
|
| 218 |
+
raise ValueError("no_repeat_ngram_size must be >= 0")
|
| 219 |
for _ in range(max_new_tokens):
|
| 220 |
context = input_ids[:, -self.cfg.max_seq_len :]
|
| 221 |
logits = self(context)["logits"][:, -1, :]
|
| 222 |
+
logits = self._apply_no_repeat_ngram(logits, input_ids, no_repeat_ngram_size)
|
| 223 |
if temperature < 0:
|
| 224 |
raise ValueError("temperature must be >= 0")
|
| 225 |
if temperature == 0:
|
|
|
|
| 235 |
input_ids = torch.cat([input_ids, next_token], dim=1)
|
| 236 |
return input_ids
|
| 237 |
|
| 238 |
+
@staticmethod
|
| 239 |
+
def _apply_no_repeat_ngram(
|
| 240 |
+
logits: torch.Tensor,
|
| 241 |
+
input_ids: torch.Tensor,
|
| 242 |
+
no_repeat_ngram_size: int,
|
| 243 |
+
) -> torch.Tensor:
|
| 244 |
+
if no_repeat_ngram_size <= 0:
|
| 245 |
+
return logits
|
| 246 |
+
|
| 247 |
+
logits = logits.clone()
|
| 248 |
+
for batch_idx in range(input_ids.size(0)):
|
| 249 |
+
banned_tokens = RexForCausalLM._get_banned_ngram_tokens(
|
| 250 |
+
input_ids[batch_idx].tolist(),
|
| 251 |
+
no_repeat_ngram_size,
|
| 252 |
+
)
|
| 253 |
+
if banned_tokens:
|
| 254 |
+
logits[batch_idx, banned_tokens] = float("-inf")
|
| 255 |
+
return logits
|
| 256 |
+
|
| 257 |
+
@staticmethod
|
| 258 |
+
def _get_banned_ngram_tokens(tokens: list[int], ngram_size: int) -> list[int]:
|
| 259 |
+
if ngram_size == 1:
|
| 260 |
+
return list(set(tokens))
|
| 261 |
+
if len(tokens) < ngram_size - 1:
|
| 262 |
+
return []
|
| 263 |
+
|
| 264 |
+
prefix_to_next: dict[tuple[int, ...], set[int]] = {}
|
| 265 |
+
for i in range(len(tokens) - ngram_size + 1):
|
| 266 |
+
ngram = tokens[i : i + ngram_size]
|
| 267 |
+
prefix = tuple(ngram[:-1])
|
| 268 |
+
prefix_to_next.setdefault(prefix, set()).add(ngram[-1])
|
| 269 |
+
|
| 270 |
+
current_prefix = tuple(tokens[-(ngram_size - 1) :])
|
| 271 |
+
return list(prefix_to_next.get(current_prefix, set()))
|
| 272 |
+
|
| 273 |
def parameter_count(self, trainable_only: bool = False) -> int:
|
| 274 |
params = self.parameters()
|
| 275 |
if trainable_only:
|
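The banned-token lookup added above can be exercised without the model. The following is a standalone sketch that mirrors the logic of `_get_banned_ngram_tokens` as a free function; the function name, the sorted return value (the original returns an unordered list), and the sample token IDs are assumptions made for illustration.

```python
def get_banned_ngram_tokens(tokens: list[int], ngram_size: int) -> list[int]:
    """Return the tokens that would complete an n-gram already present in `tokens`."""
    if ngram_size <= 0:
        return []  # disabled, matching the early return in _apply_no_repeat_ngram
    if ngram_size == 1:
        # Every token seen so far is banned: no token may ever repeat.
        return sorted(set(tokens))
    if len(tokens) < ngram_size - 1:
        return []  # not enough history to form a prefix

    # Map each (n-1)-token prefix to the set of tokens that followed it.
    prefix_to_next: dict[tuple[int, ...], set[int]] = {}
    for i in range(len(tokens) - ngram_size + 1):
        ngram = tokens[i : i + ngram_size]
        prefix_to_next.setdefault(tuple(ngram[:-1]), set()).add(ngram[-1])

    # Ban whatever previously followed the prefix the sequence currently ends with.
    current_prefix = tuple(tokens[-(ngram_size - 1) :])
    return sorted(prefix_to_next.get(current_prefix, set()))


history = [5, 6, 7, 5, 6]
print(get_banned_ngram_tokens(history, 3))  # the sequence ends in (5, 6), which was followed by 7
print(get_banned_ngram_tokens(history, 2))  # the sequence ends in (6,), which was followed by 7
print(get_banned_ngram_tokens([1, 2, 3], 3))  # (2, 3) has not occurred as a prefix before
```

In `generate`, the logits of the returned token IDs are set to `-inf` before sampling, so with `temperature == 0` the greedy choice falls to the best token that is not banned.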
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1196009344
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d5ed9638191454c97255326a3f5aed401e13c60f55f142286c247306d9698770
|
| 3 |
size 1196009344
|
training_config.yaml
CHANGED
|
@@ -15,32 +15,145 @@ data:
|
|
| 15 |
tokenizer_name: gpt2
|
| 16 |
block_size: 1024
|
| 17 |
stride: 1024
|
| 18 |
-
train_bin: data/train.bin
|
| 19 |
-
val_bin: data/val.bin
|
| 20 |
num_workers: 2
|
| 21 |
download:
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
train:
|
| 34 |
seed: 1337
|
| 35 |
device: auto
|
| 36 |
dtype: bfloat16
|
| 37 |
-
out_dir: runs/rex-300m
|
| 38 |
batch_size: 8
|
| 39 |
gradient_accumulation_steps: 1
|
| 40 |
-
epochs:
|
| 41 |
max_steps: null
|
| 42 |
-
learning_rate:
|
| 43 |
-
min_lr:
|
| 44 |
warmup_steps: 1000
|
| 45 |
weight_decay: 0.1
|
| 46 |
betas:
|
|
@@ -49,22 +162,24 @@ train:
|
|
| 49 |
eps: 1.0e-08
|
| 50 |
grad_clip: 1.0
|
| 51 |
compile: true
|
| 52 |
-
resume:
|
| 53 |
log_every: 10
|
| 54 |
-
eval_every:
|
| 55 |
-
eval_batches:
|
| 56 |
-
save_every:
|
| 57 |
wandb:
|
| 58 |
enabled: true
|
| 59 |
project: rex
|
| 60 |
entity: null
|
| 61 |
-
name: rex1
|
| 62 |
group: pretrain
|
| 63 |
tags:
|
| 64 |
- recursive-transformer
|
| 65 |
- 300m
|
| 66 |
-
-
|
| 67 |
-
|
| 68 |
mode: online
|
| 69 |
watch: false
|
| 70 |
watch_log: gradients
|
|
|
|
| 15 |
tokenizer_name: gpt2
|
| 16 |
block_size: 1024
|
| 17 |
stride: 1024
|
| 18 |
+
train_bin: data/mixed-2/train.bin
|
| 19 |
+
val_bin: data/mixed-2/val.bin
|
| 20 |
num_workers: 2
|
| 21 |
download:
|
| 22 |
+
sources:
|
| 23 |
+
- name: fineweb_edu
|
| 24 |
+
dataset_name: HuggingFaceFW/fineweb-edu
|
| 25 |
+
dataset_config: sample-10BT
|
| 26 |
+
text_column: text
|
| 27 |
+
train_split: train
|
| 28 |
+
split_strategy: head
|
| 29 |
+
streaming: true
|
| 30 |
+
max_train_docs: 400000
|
| 31 |
+
max_val_docs: 10000
|
| 32 |
+
- name: cosmopedia_web
|
| 33 |
+
dataset_name: HuggingFaceTB/cosmopedia
|
| 34 |
+
dataset_config: web_samples_v2
|
| 35 |
+
text_column: text
|
| 36 |
+
train_split: train
|
| 37 |
+
split_strategy: head
|
| 38 |
+
streaming: true
|
| 39 |
+
max_train_docs: 50000
|
| 40 |
+
max_val_docs: 5000
|
| 41 |
+
- name: cosmopedia_khanacademy
|
| 42 |
+
dataset_name: HuggingFaceTB/cosmopedia
|
| 43 |
+
dataset_config: khanacademy
|
| 44 |
+
text_column: text
|
| 45 |
+
train_split: train
|
| 46 |
+
split_strategy: head
|
| 47 |
+
streaming: true
|
| 48 |
+
max_train_docs: 50000
|
| 49 |
+
max_val_docs: 5000
|
| 50 |
+
- name: cosmopedia_openstax
|
| 51 |
+
dataset_name: HuggingFaceTB/cosmopedia
|
| 52 |
+
dataset_config: openstax
|
| 53 |
+
text_column: text
|
| 54 |
+
train_split: train
|
| 55 |
+
split_strategy: head
|
| 56 |
+
streaming: true
|
| 57 |
+
max_train_docs: 50000
|
| 58 |
+
max_val_docs: 5000
|
| 59 |
+
- name: cosmopedia_auto_math
|
| 60 |
+
dataset_name: HuggingFaceTB/cosmopedia
|
| 61 |
+
dataset_config: auto_math_text
|
| 62 |
+
text_column: text
|
| 63 |
+
train_split: train
|
| 64 |
+
split_strategy: head
|
| 65 |
+
streaming: true
|
| 66 |
+
max_train_docs: 50000
|
| 67 |
+
max_val_docs: 5000
|
| 68 |
+
- name: cosmopedia_stanford
|
| 69 |
+
dataset_name: HuggingFaceTB/cosmopedia
|
| 70 |
+
dataset_config: stanford
|
| 71 |
+
text_column: text
|
| 72 |
+
train_split: train
|
| 73 |
+
split_strategy: head
|
| 74 |
+
streaming: true
|
| 75 |
+
max_train_docs: 50000
|
| 76 |
+
max_val_docs: 5000
|
| 77 |
+
- name: cosmopedia_wikihow
|
| 78 |
+
dataset_name: HuggingFaceTB/cosmopedia
|
| 79 |
+
dataset_config: wikihow
|
| 80 |
+
text_column: text
|
| 81 |
+
train_split: train
|
| 82 |
+
split_strategy: head
|
| 83 |
+
streaming: true
|
| 84 |
+
max_train_docs: 40000
|
| 85 |
+
max_val_docs: 4000
|
| 86 |
+
- name: wikipedia_en
|
| 87 |
+
dataset_name: wikimedia/wikipedia
|
| 88 |
+
dataset_config: 20231101.en
|
| 89 |
+
text_column: text
|
| 90 |
+
train_split: train
|
| 91 |
+
split_strategy: head
|
| 92 |
+
streaming: true
|
| 93 |
+
max_train_docs: 80000
|
| 94 |
+
max_val_docs: 5000
|
| 95 |
+
- name: open_web_math
|
| 96 |
+
dataset_name: open-web-math/open-web-math
|
| 97 |
+
text_column: text
|
| 98 |
+
train_split: train
|
| 99 |
+
split_strategy: head
|
| 100 |
+
streaming: true
|
| 101 |
+
max_train_docs: 75000
|
| 102 |
+
max_val_docs: 5000
|
| 103 |
+
- name: codeparrot_clean
|
| 104 |
+
dataset_name: codeparrot/codeparrot-clean
|
| 105 |
+
text_column: content
|
| 106 |
+
train_split: train
|
| 107 |
+
split_strategy: head
|
| 108 |
+
streaming: true
|
| 109 |
+
max_train_docs: 34000
|
| 110 |
+
max_val_docs: 2500
|
| 111 |
+
- name: tinystories
|
| 112 |
+
dataset_name: roneneldan/TinyStories
|
| 113 |
+
text_column: text
|
| 114 |
+
train_split: train
|
| 115 |
+
val_split: validation
|
| 116 |
+
split_strategy: head
|
| 117 |
+
streaming: true
|
| 118 |
+
max_train_docs: 200000
|
| 119 |
+
max_val_docs: 5000
|
| 120 |
+
- name: wikitext103
|
| 121 |
+
dataset_name: wikitext
|
| 122 |
+
dataset_config: wikitext-103-raw-v1
|
| 123 |
+
text_column: text
|
| 124 |
+
train_split: train
|
| 125 |
+
val_split: validation
|
| 126 |
+
split_strategy: head
|
| 127 |
+
streaming: false
|
| 128 |
+
max_train_docs: 100000
|
| 129 |
+
max_val_docs: 5000
|
| 130 |
+
- name: arxiv_abstracts
|
| 131 |
+
dataset_name: nick007x/arxiv-papers
|
| 132 |
+
text_column:
|
| 133 |
+
- title
|
| 134 |
+
- subjects
|
| 135 |
+
- abstract
|
| 136 |
+
text_template: 'Title: {title}
|
| 137 |
+
|
| 138 |
+
Subjects: {subjects}
|
| 139 |
+
|
| 140 |
+
Abstract: {abstract}'
|
| 141 |
+
train_split: train
|
| 142 |
+
split_strategy: head
|
| 143 |
+
streaming: true
|
| 144 |
+
max_train_docs: 30000
|
| 145 |
+
max_val_docs: 5000
|
| 146 |
train:
|
| 147 |
seed: 1337
|
| 148 |
device: auto
|
| 149 |
dtype: bfloat16
|
| 150 |
+
out_dir: runs/rex-300m-mixed-2
|
| 151 |
batch_size: 8
|
| 152 |
gradient_accumulation_steps: 1
|
| 153 |
+
epochs: 10
|
| 154 |
max_steps: null
|
| 155 |
+
learning_rate: 5.0e-05
|
| 156 |
+
min_lr: 5.0e-06
|
| 157 |
warmup_steps: 1000
|
| 158 |
weight_decay: 0.1
|
| 159 |
betas:
|
|
|
|
| 162 |
eps: 1.0e-08
|
| 163 |
grad_clip: 1.0
|
| 164 |
compile: true
|
| 165 |
+
resume: runs/rex-300m-mixed-continue/ckpt_step690000.pt
|
| 166 |
log_every: 10
|
| 167 |
+
eval_every: 5000
|
| 168 |
+
eval_batches: 100
|
| 169 |
+
save_every: 10000
|
| 170 |
wandb:
|
| 171 |
enabled: true
|
| 172 |
project: rex
|
| 173 |
entity: null
|
| 174 |
+
name: rex1-mixed-2
|
| 175 |
group: pretrain
|
| 176 |
tags:
|
| 177 |
- recursive-transformer
|
| 178 |
- 300m
|
| 179 |
+
- mixed-2
|
| 180 |
+
- benchmark-mix
|
| 181 |
+
notes: "v2 corpus \u2014 more FineWeb-Edu + Wikipedia, less code/math. Continue\
|
| 182 |
+
\ from mixed-continue step 690k."
|
| 183 |
mode: online
|
| 184 |
watch: false
|
| 185 |
watch_log: gradients
|
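To get a rough picture of the mixed-2 corpus, the `max_train_docs` caps from `download.sources` can be tallied. This is only a sketch: the values are copied from the config above, but they are upper bounds on document counts, not token counts or the number of documents actually downloaded, and the variable names are chosen for illustration.

```python
# max_train_docs caps copied from download.sources in training_config.yaml.
max_train_docs = {
    "fineweb_edu": 400_000,
    "cosmopedia_web": 50_000,
    "cosmopedia_khanacademy": 50_000,
    "cosmopedia_openstax": 50_000,
    "cosmopedia_auto_math": 50_000,
    "cosmopedia_stanford": 50_000,
    "cosmopedia_wikihow": 40_000,
    "wikipedia_en": 80_000,
    "open_web_math": 75_000,
    "codeparrot_clean": 34_000,
    "tinystories": 200_000,
    "wikitext103": 100_000,
    "arxiv_abstracts": 30_000,
}

total_docs = sum(max_train_docs.values())

# Share of the document cap per source, largest first.
for name, cap in sorted(max_train_docs.items(), key=lambda item: item[1], reverse=True):
    print(f"{name:<24} {cap:>8} {cap / total_docs:6.1%}")
print(f"{'total':<24} {total_docs:>8}")
```

Document counts say little about token share, since average document length differs widely between sources such as TinyStories and Wikipedia.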