Instructions to use DavidSeyserHF/rex1-base with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use DavidSeyserHF/rex1-base with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("text-generation", model="DavidSeyserHF/rex1-base", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("DavidSeyserHF/rex1-base", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use DavidSeyserHF/rex1-base with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "DavidSeyserHF/rex1-base"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "DavidSeyserHF/rex1-base",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/DavidSeyserHF/rex1-base
- SGLang
How to use DavidSeyserHF/rex1-base with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "DavidSeyserHF/rex1-base" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "DavidSeyserHF/rex1-base",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "DavidSeyserHF/rex1-base" \
        --host 0.0.0.0 \
        --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "DavidSeyserHF/rex1-base",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use DavidSeyserHF/rex1-base with Docker Model Runner:
docker model run hf.co/DavidSeyserHF/rex1-base
| """REX: a recursive decoder-only Transformer language model.""" | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
@dataclass
class RexConfig:
    """Hyperparameters for a REX model.

    The class is a plain dataclass so that it can be converted to and from a
    dictionary (``to_dict`` / ``from_dict``) and stored as JSON.
    """

    vocab_size: int = 50_257  # rows of the token embedding / lm_head output size
    max_seq_len: int = 2048  # longest sequence ``encode`` accepts
    d_model: int = 1536  # width of the hidden states
    n_heads: int = 16  # attention heads; must divide d_model
    n_layers: int = 8  # number of distinct blocks in the stack
    recurrence_steps: int = 2  # how many times the whole stack is applied
    ffn_dim: int = 3968  # hidden width of the SwiGLU feed-forward
    dropout: float = 0.0
    norm_eps: float = 1e-5
    tie_embeddings: bool = True  # share lm_head weight with the token embedding
    use_step_embeddings: bool = True  # add a learned vector per recurrence step
    initializer_range: float = 0.02  # std of the normal weight initialiser

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "RexConfig":
        """Build a config from ``data``, silently ignoring unknown keys.

        Keys missing from ``data`` keep their dataclass defaults.
        """
        known = cls.__dataclass_fields__
        return cls(**{key: value for key, value in data.items() if key in known})

    def to_dict(self) -> dict[str, Any]:
        """Return the config as a plain ``dict`` of field name to value."""
        return asdict(self)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Gain starts at one, so the layer initially only rescales by the RMS.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` and return a tensor with the same dtype as the input."""
        input_dtype = x.dtype
        # The statistics are computed in float32 and cast back at the end.
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalised = x_fp32 * torch.rsqrt(mean_square + self.eps)
        scaled = self.weight * normalised
        return scaled.to(input_dtype)
class RotaryEmbedding(nn.Module):
    """Precomputed cosine and sine tables for rotary position embeddings."""

    def __init__(self, dim: int, max_seq_len: int, base: float = 10_000.0):
        super().__init__()
        # One inverse frequency per pair of channels.
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(max_seq_len, dtype=torch.float)
        # angles[p, i] = positions[p] * inv_freq[i]
        angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
        # Non-persistent buffers: excluded from the state dict.
        for name, table in (("cos", angles.cos()), ("sin", angles.sin())):
            self.register_buffer(name, table, persistent=False)

    def forward(self, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the first ``seq_len`` rows of the cosine and sine tables."""
        cos_rows = self.cos[:seq_len]
        sin_rows = self.sin[:seq_len]
        return cos_rows, sin_rows
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map every adjacent pair ``(a, b)`` of the last dimension to ``(-b, a)``."""
    even = x[..., 0::2]
    odd = x[..., 1::2]
    # Stack as (..., pairs, 2), then merge the pair axis back into the last one
    # so the output is interleaved like the input.
    pairs = torch.stack([odd.neg(), even], dim=-1)
    return pairs.flatten(start_dim=-2)
def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply a rotary position rotation to ``x`` using per-pair ``cos``/``sin`` tables.

    ``cos`` and ``sin`` carry one value per channel pair in their last
    dimension. Each value is repeated twice so it lines up with both members
    of the pair, and two leading singleton axes are added so the tables
    broadcast against ``x``.
    """
    cos_full, sin_full = (
        table.repeat_interleave(2, dim=-1)[None, None, :, :] for table in (cos, sin)
    )
    rotated = _rotate_half(x)
    return (x * cos_full) + (rotated * sin_full)
def _safe_torch_load(path: str | Path, map_location: str | torch.device | None) -> Any:
    """Load a ``torch.save`` file, preferring the restricted ``weights_only`` mode.

    If ``torch.load`` raises ``TypeError`` for the ``weights_only`` call, the
    load is retried without that argument.
    """
    load_kwargs = {"map_location": map_location}
    try:
        return torch.load(path, weights_only=True, **load_kwargs)
    except TypeError:
        # NOTE(security): this fallback performs an unrestricted pickle load.
        # Only use it on files from a trusted source.
        return torch.load(path, **load_kwargs)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings."""

    def __init__(self, cfg: RexConfig):
        super().__init__()
        head_dim, remainder = divmod(cfg.d_model, cfg.n_heads)
        if remainder != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.n_heads = cfg.n_heads
        self.head_dim = head_dim
        # The rotary rotation acts on channel pairs, so each head needs an even width.
        if head_dim % 2 != 0:
            raise ValueError("attention head_dim must be even for rotary embeddings")
        # A single fused projection produces queries, keys and values.
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.out = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout = cfg.dropout
        self.rotary = RotaryEmbedding(head_dim, cfg.max_seq_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend causally over ``x`` of shape ``[batch, seq, width]``."""
        bsz, seq_len, width = x.shape
        per_head_shape = (bsz, seq_len, self.n_heads, self.head_dim)
        # Split the fused projection, then move the head axis ahead of the sequence axis.
        q, k, v = (
            part.view(per_head_shape).transpose(1, 2)
            for part in self.qkv(x).chunk(3, dim=-1)
        )

        cos, sin = self.rotary(seq_len)
        q = apply_rotary(q, cos.to(q.device), sin.to(q.device))
        k = apply_rotary(k, cos.to(k.device), sin.to(k.device))

        # Attention dropout is active only in training mode.
        attn_dropout = self.dropout if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            q, k, v, dropout_p=attn_dropout, is_causal=True
        )
        merged = attended.transpose(1, 2).contiguous().view(bsz, seq_len, width)
        return self.out(merged)
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))`` followed by dropout."""

    def __init__(self, cfg: RexConfig):
        super().__init__()
        model_width, hidden_width = cfg.d_model, cfg.ffn_dim
        self.w1 = nn.Linear(model_width, hidden_width, bias=False)  # gate branch
        self.w2 = nn.Linear(hidden_width, model_width, bias=False)  # output projection
        self.w3 = nn.Linear(model_width, hidden_width, bias=False)  # value branch
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        projected = self.w2(gate * value)
        return self.dropout(projected)
class RexBlock(nn.Module):
    """Pre-norm Transformer block: attention and feed-forward, each with a residual."""

    def __init__(self, cfg: RexConfig):
        super().__init__()
        width, eps = cfg.d_model, cfg.norm_eps
        self.attn_norm = RMSNorm(width, eps)
        self.attn = CausalSelfAttention(cfg)
        self.ffn_norm = RMSNorm(width, eps)
        self.ffn = SwiGLU(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run attention then feed-forward, adding each result back onto the input."""
        attn_out = self.attn(self.attn_norm(x))
        hidden = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(hidden))
        return hidden + ffn_out
class RexForCausalLM(nn.Module):
    """Decoder-only LM with a stack of blocks reused across recursive passes."""

    def __init__(self, cfg: RexConfig):
        """Build the embedding, block stack, final norm and output head from ``cfg``.

        Raises:
            ValueError: if ``cfg.recurrence_steps`` is smaller than 1.
        """
        super().__init__()
        if cfg.recurrence_steps < 1:
            raise ValueError("recurrence_steps must be >= 1")
        self.cfg = cfg
        self.token_embedding = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([RexBlock(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.d_model, cfg.norm_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            # Input embedding and output projection share one parameter tensor.
            self.lm_head.weight = self.token_embedding.weight
        if cfg.use_step_embeddings:
            # One learned vector per recurrence step, initialised to zero.
            self.step_embedding = nn.Parameter(torch.zeros(cfg.recurrence_steps, cfg.d_model))
        else:
            # Registering None keeps the attribute present so ``encode`` can test it.
            self.register_parameter("step_embedding", None)
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Normal-initialise Linear/Embedding weights and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.cfg.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.cfg.initializer_range)

    def encode(self, input_ids: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Return contextual token representations for downstream tasks.

        Args:
            input_ids: token ids of shape ``[batch, seq]``.
            normalize: when true, apply the final RMSNorm to the output.

        Raises:
            ValueError: if ``input_ids`` is not 2-D or is longer than
                ``cfg.max_seq_len``.
        """
        if input_ids.ndim != 2:
            raise ValueError("input_ids must have shape [batch, seq]")
        if input_ids.size(1) > self.cfg.max_seq_len:
            raise ValueError(f"sequence length exceeds max_seq_len={self.cfg.max_seq_len}")
        x = self.drop(self.token_embedding(input_ids))
        # The same stack of blocks is applied ``recurrence_steps`` times.
        for step in range(self.cfg.recurrence_steps):
            if self.step_embedding is not None:
                x = x + self.step_embedding[step].view(1, 1, -1)
            for block in self.blocks:
                x = block(x)
        if normalize:
            x = self.final_norm(x)
        return x

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | None]:
        """Compute logits and, when ``labels`` is given, the next-token loss.

        The shift is done here: logits at position ``t`` are scored against
        ``labels`` at position ``t + 1``. Label value ``-100`` is ignored.

        Returns:
            A dict with keys ``"logits"`` and ``"loss"`` (``None`` without labels).
        """
        hidden_states = self.encode(input_ids, normalize=True)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )
        return {"logits": logits, "loss": loss}

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
        no_repeat_ngram_size: int = 0,
    ) -> torch.Tensor:
        """Autoregressively append ``max_new_tokens`` tokens to ``input_ids``.

        The module is switched to eval mode and is not switched back. Decoding
        runs without gradient tracking, since only token ids are returned.

        Args:
            input_ids: prompt token ids of shape ``[batch, seq]``.
            max_new_tokens: number of tokens to append.
            temperature: ``0`` selects greedy argmax; positive values scale the
                logits before sampling.
            top_k: when set, sample only from the ``top_k`` highest logits.
            no_repeat_ngram_size: when positive, forbid tokens that would
                repeat an n-gram of this size already present in the sequence.

        Raises:
            ValueError: if ``no_repeat_ngram_size`` or ``temperature`` is negative.
        """
        self.eval()
        # Arguments are validated once up front; they do not change between steps.
        if no_repeat_ngram_size < 0:
            raise ValueError("no_repeat_ngram_size must be >= 0")
        if temperature < 0:
            raise ValueError("temperature must be >= 0")
        for _ in range(max_new_tokens):
            # Only the most recent ``max_seq_len`` tokens are fed to the model.
            context = input_ids[:, -self.cfg.max_seq_len :]
            logits = self(context)["logits"][:, -1, :]
            # The n-gram ban looks at the full sequence, not only the context window.
            logits = self._apply_no_repeat_ngram(logits, input_ids, no_repeat_ngram_size)
            if temperature == 0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Mask everything below the k-th largest logit of each row.
                    logits = logits.masked_fill(logits < values[:, [-1]], float("-inf"))
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids

    @staticmethod
    def _apply_no_repeat_ngram(
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        no_repeat_ngram_size: int,
    ) -> torch.Tensor:
        """Return ``logits`` with banned next tokens set to ``-inf`` per batch row.

        The input tensor is returned unchanged when ``no_repeat_ngram_size`` is
        not positive; otherwise a modified copy is returned.
        """
        if no_repeat_ngram_size <= 0:
            return logits
        logits = logits.clone()
        for batch_idx in range(input_ids.size(0)):
            banned_tokens = RexForCausalLM._get_banned_ngram_tokens(
                input_ids[batch_idx].tolist(),
                no_repeat_ngram_size,
            )
            if banned_tokens:
                logits[batch_idx, banned_tokens] = float("-inf")
        return logits

    @staticmethod
    def _get_banned_ngram_tokens(tokens: list[int], ngram_size: int) -> list[int]:
        """List the tokens that would complete an n-gram already seen in ``tokens``.

        For ``ngram_size == 1`` every token already present is banned. Otherwise
        the last ``ngram_size - 1`` tokens form the current prefix, and every
        token that followed that prefix earlier in the sequence is banned.
        """
        if ngram_size == 1:
            return list(set(tokens))
        if len(tokens) < ngram_size - 1:
            return []
        prefix_to_next: dict[tuple[int, ...], set[int]] = {}
        for i in range(len(tokens) - ngram_size + 1):
            ngram = tokens[i : i + ngram_size]
            prefix = tuple(ngram[:-1])
            prefix_to_next.setdefault(prefix, set()).add(ngram[-1])
        current_prefix = tuple(tokens[-(ngram_size - 1) :])
        return list(prefix_to_next.get(current_prefix, set()))

    def parameter_count(self, trainable_only: bool = False) -> int:
        """Return the total element count of the parameters from ``self.parameters()``.

        With ``trainable_only`` set, only parameters with ``requires_grad`` count.
        """
        params = self.parameters()
        if trainable_only:
            params = (p for p in params if p.requires_grad)
        return sum(p.numel() for p in params)

    def save_pretrained(self, save_directory: str | Path) -> None:
        """Save model weights and config in a lightweight HF-style folder.

        Writes ``config.json`` and ``pytorch_model.bin`` into ``save_directory``,
        creating the directory if needed.
        """
        save_path = Path(save_directory)
        save_path.mkdir(parents=True, exist_ok=True)
        with open(save_path / "config.json", "w", encoding="utf-8") as f:
            json.dump(self.cfg.to_dict(), f, indent=2)
            f.write("\n")
        torch.save(self.state_dict(), save_path / "pytorch_model.bin")

    @classmethod
    def from_pretrained(
        cls,
        load_directory: str | Path,
        map_location: str | torch.device | None = "cpu",
        strict: bool = True,
    ) -> "RexForCausalLM":
        """Load a model saved by ``save_pretrained``.

        Args:
            load_directory: folder holding ``config.json`` and ``pytorch_model.bin``.
            map_location: forwarded to the weight loader.
            strict: forwarded to ``load_state_dict``.
        """
        load_path = Path(load_directory)
        with open(load_path / "config.json", "r", encoding="utf-8") as f:
            cfg = RexConfig.from_dict(json.load(f))
        model = cls(cfg)
        state_dict = _safe_torch_load(load_path / "pytorch_model.bin", map_location)
        model.load_state_dict(state_dict, strict=strict)
        return model

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = "cpu",
        strict: bool = True,
    ) -> "RexForCausalLM":
        """Load from a training checkpoint produced by ``train.py``.

        The model config is read from ``checkpoint["model_config"]`` or, failing
        that, ``checkpoint["config"]["model"]``. Weights come from
        ``checkpoint["model"]`` when present, otherwise from the checkpoint
        mapping itself.

        Raises:
            ValueError: if neither config location is present.
        """
        checkpoint = _safe_torch_load(checkpoint_path, map_location)
        cfg_data = checkpoint.get("model_config")
        if cfg_data is None:
            # ``or {}`` also covers a checkpoint that stores ``"config": None``.
            cfg_data = (checkpoint.get("config") or {}).get("model")
        if cfg_data is None:
            raise ValueError("checkpoint does not contain model_config or config.model")
        model = cls(RexConfig.from_dict(cfg_data))
        state_dict = checkpoint.get("model", checkpoint)
        model.load_state_dict(state_dict, strict=strict)
        return model
def build_model(config: dict[str, Any] | RexConfig | None = None) -> RexForCausalLM:
    """Construct a ``RexForCausalLM`` from a config object, a plain dict, or defaults.

    ``None`` uses the default ``RexConfig``; a ``RexConfig`` is used as-is;
    anything else is passed through ``RexConfig.from_dict``.
    """
    if isinstance(config, RexConfig):
        return RexForCausalLM(config)
    resolved = RexConfig() if config is None else RexConfig.from_dict(config)
    return RexForCausalLM(resolved)