# HebrewGPT-1B / generate.py
# ronnengmail's picture
# Upload folder using huggingface_hub
# 48488d4 verified
#!/usr/bin/env python3
"""
HebrewGPT-1B — Standalone generation script.
This script contains the full model architecture definition and can generate
Hebrew text without depending on the HuggingFace transformers library.
Requirements:
pip install torch sentencepiece
Usage:
python generate.py --prompt "בראשית ברא אלוהים את" --max_tokens 200
python generate.py --prompt "בית המשפט העליון פסק" --temperature 0.8 --top_k 50
"""
import argparse
import math
from dataclasses import dataclass
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
# ─────────────────────────────────────────────────────────────────────────────
# Model Architecture
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class ModelConfig:
    """Architecture hyperparameters for HebrewGPT."""

    vocab_size: int = 32000      # SentencePiece vocabulary size
    width: int = 2048            # residual-stream / embedding dimension
    depth: int = 20              # number of transformer blocks
    n_heads: int = 16            # attention heads per block
    head_dim: int = 128          # per-head dimension (n_heads * head_dim = attention width)
    max_seq_len: int = 2048      # context length the RoPE cache is pre-built for
    dropout: float = 0.0  # Set to 0.0 for inference
    rope_theta: float = 10000.0  # RoPE frequency base
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean-centering, no bias.

    The statistic is computed in float32 for numerical stability, then the
    result is cast back to the input dtype before the learned gain is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xf = x.float()
        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (xf * inv_rms).type_as(x) * self.weight
class RotaryEmbedding(nn.Module):
    """Precomputed cos/sin angle tables for rotary position embeddings.

    Tables cover `max_seq_len` positions up front and are lazily regrown if a
    longer sequence is ever requested.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        # Outer product of positions and inverse frequencies -> (seq_len, dim // 2).
        positions = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Non-persistent: derived data, deliberately kept out of checkpoints.
        self.register_buffer("cos_cached", angles.cos(), persistent=False)
        self.register_buffer("sin_cached", angles.sin(), persistent=False)

    def forward(self, seq_len: int):
        # Regrow the cache on demand when a longer sequence shows up.
        if seq_len > self.cos_cached.shape[0]:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved (x[..., ::2], x[..., 1::2]) channel pairs by cos/sin.

    cos/sin have shape (seq_len, head_dim // 2) and broadcast over the batch
    and head axes; x is indexed as (batch, seq, heads, head_dim). Returns a
    tensor shaped like x with the rotated pairs re-interleaved.
    """
    # Broadcast angle tables to (1, seq, 1, head_dim // 2).
    cos_b = cos[None, :, None, :]
    sin_b = sin[None, :, None, :]
    even, odd = x[..., 0::2], x[..., 1::2]
    rotated_even = even * cos_b - odd * sin_b
    rotated_odd = even * sin_b + odd * cos_b
    # Re-interleave the pair channels: (..., d/2, 2) -> (..., d).
    return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)
class SwiGLU(nn.Module):
    """Gated feed-forward block: down(silu(gate(x)) * up(x)), all projections bias-free."""

    def __init__(self, width: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.w_gate = nn.Linear(width, hidden_dim, bias=False)
        self.w_up = nn.Linear(width, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, width, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w_gate(x))
        projected = self.w_down(gate * self.w_up(x))
        return self.dropout(projected)
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings, bias-free projections."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        inner_dim = config.n_heads * config.head_dim
        self.q_proj = nn.Linear(config.width, inner_dim, bias=False)
        self.k_proj = nn.Linear(config.width, inner_dim, bias=False)
        self.v_proj = nn.Linear(config.width, inner_dim, bias=False)
        self.o_proj = nn.Linear(inner_dim, config.width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """Attend over x (B, T, width); mask zeros mark disallowed key positions."""
        batch, seq_len, _ = x.shape
        # Project and split into heads: (B, T, H, D).
        q = self.q_proj(x).view(batch, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(batch, seq_len, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(batch, seq_len, self.n_heads, self.head_dim)
        # Rotate queries and keys by position before the dot product.
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        # Bring heads forward: (B, H, T, D).
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        # Scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)  # (B, H, T, D)
        # Merge heads back into (B, T, H * D).
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.o_proj(context)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: residual attention followed by a residual SwiGLU MLP."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        # SwiGLU hidden size: ~(8/3) * width, rounded up to a multiple of 64
        # so the GEMM shapes stay hardware-friendly.
        hidden_dim = int(2 * config.width * 4 / 3)
        hidden_dim = ((hidden_dim + 63) // 64) * 64
        self.ln1 = RMSNorm(config.width)
        self.attn = Attention(config)
        self.ln2 = RMSNorm(config.width)
        self.mlp = SwiGLU(config.width, hidden_dim, config.dropout)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """Apply attention then the gated MLP, each behind a residual shortcut."""
        x = x + self.attn(self.ln1(x), cos, sin, mask)
        x = x + self.mlp(self.ln2(x))
        return x
class HebrewGPT(nn.Module):
    """Decoder-only transformer LM: RoPE attention, RMSNorm, SwiGLU, tied embeddings."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.width)
        self.dropout = nn.Dropout(config.dropout)
        self.rotary = RotaryEmbedding(config.head_dim, config.max_seq_len, config.rope_theta)
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.depth)
        ])
        self.ln_f = RMSNorm(config.width)
        self.head = nn.Linear(config.width, config.vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.tok_emb.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return next-token logits of shape (B, T, vocab_size) for (B, T) token ids."""
        B, T = input_ids.shape
        device = input_ids.device
        x = self.dropout(self.tok_emb(input_ids))
        cos, sin = self.rotary(T)
        cos = cos.to(device)
        sin = sin.to(device)
        # Causal mask: position t may only attend to positions <= t.
        mask = torch.tril(torch.ones(T, T, device=device)).unsqueeze(0).unsqueeze(0)
        for layer in self.layers:
            x = layer(x, cos, sin, mask)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 200,
                 temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9) -> torch.Tensor:
        """Autoregressive generation with top-k and top-p (nucleus) sampling.

        Args:
            input_ids: (B, T) prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: logit divisor; must be > 0.
            top_k: keep only the k most likely tokens (0 disables the filter).
            top_p: nucleus threshold (1.0 disables the filter).

        Returns:
            (B, T + max_new_tokens) tensor of prompt plus sampled token ids.
        """
        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum sequence length.
            idx_cond = input_ids[:, -self.config.max_seq_len:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            # Top-k filtering: drop everything below the k-th largest logit.
            if top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            # Top-p (nucleus) filtering: drop the probability tail beyond top_p.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                # Scatter the removal mask back to vocabulary order in one
                # vectorized op (replaces a per-batch Python loop; identical result).
                indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits = logits.masked_fill(indices_to_remove, float("-inf"))
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids
# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
def main():
    """Parse CLI arguments, load the tokenizer and checkpoint, and print generated text."""
    parser = argparse.ArgumentParser(description="HebrewGPT-1B Text Generation")
    parser.add_argument("--model_path", type=str, default="swa_best.pt",
                        help="Path to model checkpoint (state_dict)")
    parser.add_argument("--tokenizer_path", type=str, default="tokenizer.model",
                        help="Path to SentencePiece tokenizer model")
    parser.add_argument("--prompt", type=str, default="בראשית ברא אלוהים את",
                        help="Hebrew text prompt")
    parser.add_argument("--max_tokens", type=int, default=200,
                        help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature")
    parser.add_argument("--top_k", type=int, default=50,
                        help="Top-k sampling parameter")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-p (nucleus) sampling parameter")
    parser.add_argument("--device", type=str, default=None,
                        help="Device (cuda/cpu/mps). Auto-detected if not set.")
    # Model config overrides (for different model sizes)
    parser.add_argument("--vocab_size", type=int, default=32000,
                        help="Model vocabulary size (must match the checkpoint)")
    parser.add_argument("--width", type=int, default=2048)
    parser.add_argument("--depth", type=int, default=20)
    parser.add_argument("--n_heads", type=int, default=16)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--max_seq_len", type=int, default=2048)
    args = parser.parse_args()

    # Fail early with a clear message if required files are missing.
    if not Path(args.model_path).is_file():
        parser.error(f"model checkpoint not found: {args.model_path}")
    if not Path(args.tokenizer_path).is_file():
        parser.error(f"tokenizer model not found: {args.tokenizer_path}")

    # Device selection: explicit flag wins, then CUDA, then Apple MPS, then CPU.
    if args.device:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Load tokenizer
    print(f"Loading tokenizer from {args.tokenizer_path}...")
    sp = spm.SentencePieceProcessor()
    sp.Load(args.tokenizer_path)

    # Build model (vocab_size now overridable like the other size knobs).
    config = ModelConfig(
        vocab_size=args.vocab_size,
        width=args.width,
        depth=args.depth,
        n_heads=args.n_heads,
        head_dim=args.head_dim,
        max_seq_len=args.max_seq_len,
        dropout=0.0,
    )
    print(f"Building HebrewGPT model (width={config.width}, depth={config.depth}, "
          f"heads={config.n_heads})...")
    model = HebrewGPT(config)

    # Load weights
    print(f"Loading weights from {args.model_path}...")
    state_dict = torch.load(args.model_path, map_location="cpu", weights_only=True)
    # Handle wrapped checkpoint format (dict with 'model' key)
    if isinstance(state_dict, dict) and "model" in state_dict:
        state_dict = state_dict["model"]
    model.load_state_dict(state_dict)
    model.eval().to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {param_count:,} parameters")

    # Encode prompt
    print(f"\nPrompt: {args.prompt}")
    input_ids = sp.Encode(args.prompt)
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    # Generate
    print("Generating...\n")
    output_ids = model.generate(
        input_tensor,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )

    # Decode and print
    generated_text = sp.Decode(output_ids[0].tolist())
    print("=" * 60)
    print(generated_text)
    print("=" * 60)
# Script entry point: run generation only when executed directly.
if __name__ == "__main__":
    main()