#!/usr/bin/env python3
"""
HebrewGPT-1B — Standalone generation script.

This script contains the full model architecture definition and can generate
Hebrew text without depending on the HuggingFace transformers library.

Requirements:
    pip install torch sentencepiece

The model classes themselves only need torch; sentencepiece is needed only to
run the script (encode the prompt / decode the output).

Usage:
    python generate.py --prompt "בראשית ברא אלוהים את" --max_tokens 200
    python generate.py --prompt "בית המשפט העליון פסק" --temperature 0.8 --top_k 50
    python generate.py --prompt "בראשית ברא אלוהים את" --temperature 0   # greedy
"""

import argparse
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import sentencepiece as spm
except ImportError:  # model definitions below remain importable without it
    spm = None


# ─────────────────────────────────────────────────────────────────────────────
# Model Architecture
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class ModelConfig:
    vocab_size: int = 32000
    width: int = 2048
    depth: int = 20
    n_heads: int = 16
    head_dim: int = 128
    max_seq_len: int = 2048
    dropout: float = 0.0  # Set to 0.0 for inference
    rope_theta: float = 10000.0


class RMSNorm(nn.Module):
    """Root-mean-square norm: x * rsqrt(mean(x^2) + eps) * weight (math in float32)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (x.float() * norm).type_as(x) * self.weight


class RotaryEmbedding(nn.Module):
    """Precomputes RoPE cos/sin tables of shape (seq_len, dim // 2)."""

    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # Persistent buffer: it is part of the checkpoint's state_dict.
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        # Build on inv_freq's device so a cache extension after model.to(device)
        # does not mix CPU and accelerator tensors in torch.outer.
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        self.register_buffer("cos_cached", freqs.cos(), persistent=False)
        self.register_buffer("sin_cached", freqs.sin(), persistent=False)

    def forward(self, seq_len: int):
        """Return (cos, sin), each (seq_len, dim // 2); grows the cache if needed."""
        if seq_len > self.cos_cached.shape[0]:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply RoPE with interleaved pattern (x[..., ::2], x[..., 1::2]).

    x is (B, T, n_heads, head_dim); cos/sin are (T, head_dim // 2).
    """
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]
    # cos/sin shape: (seq_len, head_dim//2) -> broadcast to (1, seq_len, 1, head_dim//2)
    cos = cos.unsqueeze(0).unsqueeze(2)  # (1, seq, 1, dim//2)
    sin = sin.unsqueeze(0).unsqueeze(2)
    out_even = x_even * cos - x_odd * sin
    out_odd = x_even * sin + x_odd * cos
    # Interleave back
    out = torch.stack([out_even, out_odd], dim=-1).flatten(-2)
    return out


class SwiGLU(nn.Module):
    """Gated MLP: w_down(silu(w_gate(x)) * w_up(x))."""

    def __init__(self, width: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.w_gate = nn.Linear(width, hidden_dim, bias=False)
        self.w_up = nn.Linear(width, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, width, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))


class Attention(nn.Module):
    """Multi-head self-attention with RoPE applied to queries and keys."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        total_dim = config.n_heads * config.head_dim
        self.q_proj = nn.Linear(config.width, total_dim, bias=False)
        self.k_proj = nn.Linear(config.width, total_dim, bias=False)
        self.v_proj = nn.Linear(config.width, total_dim, bias=False)
        self.o_proj = nn.Linear(total_dim, config.width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """x: (B, T, width). Positions where mask == 0 are excluded from attention."""
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim)

        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # (B, n_heads, T, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product attention
        scale = math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) / scale
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)  # (B, n_heads, T, head_dim)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(out)


class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention then SwiGLU MLP."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        # Must stay exactly as trained, or checkpoint shapes will not match.
        hidden_dim = int(2 * config.width * 4 / 3)
        hidden_dim = ((hidden_dim + 63) // 64) * 64  # Round up to multiple of 64
        self.ln1 = RMSNorm(config.width)
        self.attn = Attention(config)
        self.ln2 = RMSNorm(config.width)
        self.mlp = SwiGLU(config.width, hidden_dim, config.dropout)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = x + self.attn(self.ln1(x), cos, sin, mask)
        x = x + self.mlp(self.ln2(x))
        return x


class HebrewGPT(nn.Module):
    """Decoder-only transformer with tied input/output embeddings."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.width)
        self.dropout = nn.Dropout(config.dropout)
        self.rotary = RotaryEmbedding(config.head_dim, config.max_seq_len, config.rope_theta)
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.depth)
        ])
        self.ln_f = RMSNorm(config.width)
        self.head = nn.Linear(config.width, config.vocab_size, bias=False)
        # Weight tying
        self.head.weight = self.tok_emb.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map (B, T) token ids to (B, T, vocab_size) next-token logits."""
        B, T = input_ids.shape
        device = input_ids.device

        x = self.dropout(self.tok_emb(input_ids))
        cos, sin = self.rotary(T)
        cos = cos.to(device)
        sin = sin.to(device)

        # Causal mask (bool: 4x smaller than float; Attention tests mask == 0)
        mask = torch.ones(T, T, dtype=torch.bool, device=device).tril().unsqueeze(0).unsqueeze(0)

        for layer in self.layers:
            x = layer(x, cos, sin, mask)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    @staticmethod
    def _sample(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Sample one token per row from temperature-scaled (B, V) logits -> (B, 1)."""
        # Top-k filtering
        if top_k > 0:
            kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits = logits.masked_fill(logits < kth, float("-inf"))

        # Top-p (nucleus) filtering, vectorized over the batch
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept,
            # and always keep the most likely token.
            sorted_remove[:, 1:] = sorted_remove[:, :-1].clone()
            sorted_remove[:, 0] = False
            # Map the removal mask from sorted order back to vocabulary order.
            remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float("-inf"))

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 200,
                 temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9,
                 eos_id: Optional[int] = None) -> torch.Tensor:
        """Autoregressive generation with top-k and top-p (nucleus) sampling.

        Args:
            input_ids: (B, T) prompt token ids; T must be >= 1.
            max_new_tokens: maximum number of tokens appended to each row.
            temperature: softmax temperature; 0 selects greedy (argmax) decoding.
            top_k: keep only the k highest logits (<= 0 disables).
            top_p: nucleus threshold (>= 1.0 disables).
            eos_id: optional end-of-sequence id. A row that emits it is padded
                with eos_id afterwards; generation stops once all rows have.

        Returns:
            (B, T + n) tensor of token ids, n <= max_new_tokens.

        Raises:
            ValueError: if input_ids has no tokens or temperature is negative.
        """
        if input_ids.size(1) == 0:
            raise ValueError("input_ids must contain at least one token")
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        finished = None
        if eos_id is not None:
            finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)

        for _ in range(max_new_tokens):
            # Crop to max context length
            idx_cond = input_ids[:, -self.config.max_seq_len:]
            logits = self(idx_cond)[:, -1, :]

            if temperature == 0:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                next_token = self._sample(logits / temperature, top_k, top_p)

            if finished is not None:
                next_token = next_token.masked_fill(finished.unsqueeze(-1), eos_id)
                finished |= next_token.squeeze(-1) == eos_id

            input_ids = torch.cat([input_ids, next_token], dim=1)

            if finished is not None and bool(finished.all()):
                break

        return input_ids


# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────

def main():
    """Parse CLI args, load tokenizer and checkpoint, then print one generation."""
    parser = argparse.ArgumentParser(description="HebrewGPT-1B Text Generation")
    parser.add_argument("--model_path", type=str, default="swa_best.pt",
                        help="Path to model checkpoint (state_dict)")
    parser.add_argument("--tokenizer_path", type=str, default="tokenizer.model",
                        help="Path to SentencePiece tokenizer model")
    parser.add_argument("--prompt", type=str, default="בראשית ברא אלוהים את",
                        help="Hebrew text prompt")
    parser.add_argument("--max_tokens", type=int, default=200,
                        help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature (0 = greedy decoding)")
    parser.add_argument("--top_k", type=int, default=50,
                        help="Top-k sampling parameter")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-p (nucleus) sampling parameter")
    parser.add_argument("--device", type=str, default=None,
                        help="Device (cuda/cpu/mps). Auto-detected if not set.")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducible sampling")

    # Model config overrides (for different model sizes)
    parser.add_argument("--vocab_size", type=int, default=32000)
    parser.add_argument("--width", type=int, default=2048)
    parser.add_argument("--depth", type=int, default=20)
    parser.add_argument("--n_heads", type=int, default=16)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--max_seq_len", type=int, default=2048)

    args = parser.parse_args()

    if spm is None:
        parser.error("sentencepiece is required to run this script: pip install sentencepiece")
    if args.temperature < 0:
        parser.error("--temperature must be >= 0 (use 0 for greedy decoding)")
    if args.seed is not None:
        torch.manual_seed(args.seed)

    # Device selection
    if args.device:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Load tokenizer
    print(f"Loading tokenizer from {args.tokenizer_path}...")
    sp = spm.SentencePieceProcessor()
    sp.Load(args.tokenizer_path)

    # Build model
    config = ModelConfig(
        vocab_size=args.vocab_size,
        width=args.width,
        depth=args.depth,
        n_heads=args.n_heads,
        head_dim=args.head_dim,
        max_seq_len=args.max_seq_len,
        dropout=0.0,
    )
    print(f"Building HebrewGPT model (width={config.width}, depth={config.depth}, "
          f"heads={config.n_heads})...")
    model = HebrewGPT(config)

    # Load weights
    print(f"Loading weights from {args.model_path}...")
    state_dict = torch.load(args.model_path, map_location="cpu", weights_only=True)

    # Handle wrapped checkpoint format (dict with 'model' key)
    if isinstance(state_dict, dict) and "model" in state_dict:
        state_dict = state_dict["model"]

    model.load_state_dict(state_dict)
    model.eval().to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {param_count:,} parameters")

    # Encode prompt
    print(f"\nPrompt: {args.prompt}")
    input_ids = sp.Encode(args.prompt)
    if not input_ids:
        # An empty prompt would leave the model nothing to condition on.
        bos = sp.bos_id()  # -1 when the tokenizer defines no BOS
        if bos < 0:
            parser.error("--prompt encodes to zero tokens and the tokenizer has no BOS id")
        input_ids = [bos]
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    # Generate
    print("Generating...\n")
    eos = sp.eos_id()  # -1 when the tokenizer defines no EOS
    output_ids = model.generate(
        input_tensor,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        eos_id=eos if eos >= 0 else None,
    )

    # Decode and print
    generated_text = sp.Decode(output_ids[0].tolist())
    print("=" * 60)
    print(generated_text)
    print("=" * 60)


if __name__ == "__main__":
    main()