| |
| """ |
| HebrewGPT-1B — Standalone generation script. |
| |
| This script contains the full model architecture definition and can generate |
| Hebrew text without depending on the HuggingFace transformers library. |
| |
| Requirements: |
| pip install torch sentencepiece |
| |
| Usage: |
| python generate.py --prompt "בראשית ברא אלוהים את" --max_tokens 200 |
| python generate.py --prompt "בית המשפט העליון פסק" --temperature 0.8 --top_k 50 |
| """ |
|
|
| import argparse |
| import math |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import sentencepiece as spm |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class ModelConfig:
    """Architecture hyper-parameters for the HebrewGPT decoder.

    Defaults correspond to the HebrewGPT-1B configuration used by this script.
    Values are validated in ``__post_init__`` so an inconsistent configuration
    fails immediately with a clear message instead of deep inside model
    construction or the first forward pass.
    """

    vocab_size: int = 32000  # rows of the token embedding / tied output head
    width: int = 2048  # residual-stream (model) dimension
    depth: int = 20  # number of TransformerBlocks
    n_heads: int = 16  # attention heads per block
    head_dim: int = 128  # per-head size; must be even for interleaved RoPE
    max_seq_len: int = 2048  # context window; also the initial RoPE cache length
    dropout: float = 0.0  # dropout probability (0.0 for inference)
    rope_theta: float = 10000.0  # RoPE base frequency

    def __post_init__(self):
        """Reject configurations the model cannot be built or run with.

        Raises:
            ValueError: if a size is not positive (depth may be 0), ``head_dim``
                is odd, or ``rope_theta`` is not positive.
        """
        for name in ("vocab_size", "width", "n_heads", "head_dim", "max_seq_len"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        if self.depth < 0:
            raise ValueError(f"depth must be non-negative, got {self.depth}")
        # apply_rotary_emb pairs channels (even, odd); an odd head_dim leaves an
        # unpaired channel and fails with a broadcast error at runtime.
        if self.head_dim % 2 != 0:
            raise ValueError(f"head_dim must be even for rotary embeddings, got {self.head_dim}")
        if self.rope_theta <= 0:
            raise ValueError(f"rope_theta must be positive, got {self.rope_theta}")
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The statistic is computed in float32, the normalized values are cast back
    to the input dtype, and the result is scaled by a learned per-channel gain
    initialised to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = (x32 * inv_rms).type_as(x)
        return normalized * self.weight
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Precomputed cos/sin tables for rotary position embeddings (RoPE).

    Args:
        dim: Rotary dimension; one frequency is produced per channel pair,
            i.e. ``ceil(dim / 2)`` frequencies.
        max_seq_len: Number of positions precomputed at construction time.
        theta: RoPE base frequency.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # Persistent buffer: it is part of the state_dict.
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build cos/sin tables for positions ``0 .. seq_len - 1``.

        Positions are created on the same device as ``inv_freq``. The cache can
        be regrown from ``forward`` after the module was moved (e.g. to CUDA);
        creating them on the default CPU device would make ``torch.outer`` fail
        with a device mismatch.
        """
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        # Non-persistent: derived data, excluded from the state_dict.
        self.register_buffer("cos_cached", freqs.cos(), persistent=False)
        self.register_buffer("sin_cached", freqs.sin(), persistent=False)

    def forward(self, seq_len: int):
        """Return ``(cos, sin)``, each of shape ``(seq_len, len(inv_freq))``.

        Grows the cache on demand when ``seq_len`` exceeds its current length.
        """
        if seq_len > self.cos_cached.shape[0]:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
|
|
|
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply RoPE with interleaved pattern (x[..., ::2], x[..., 1::2]).

    Each channel pair ``(x[2i], x[2i+1])`` is rotated by the angle whose cosine
    and sine are given in ``cos``/``sin`` column ``i``.

    Args:
        x: Query or key tensor. In this file it is laid out as
            ``(batch, seq, heads, head_dim)``; ``cos``/``sin`` are broadcast
            over dims 0 and 2.
        cos: ``(seq, head_dim // 2)`` cosine table.
        sin: ``(seq, head_dim // 2)`` sine table.

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]

    # (T, D/2) -> (1, T, 1, D/2) so the tables broadcast over batch and heads.
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)

    out_even = x_even * cos - x_odd * sin
    out_odd = x_even * sin + x_odd * cos

    # Re-interleave (even, odd) pairs back into the last dimension.
    out = torch.stack([out_even, out_odd], dim=-1).flatten(-2)
    # The tables are float32; without this cast half/bf16 inputs would be
    # promoted to float32 and no longer match the dtype of the value tensor.
    return out.type_as(x)
|
|
|
|
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w_down(silu(w_gate(x)) * w_up(x))``, then dropout.

    Args:
        width: Input and output feature size.
        hidden_dim: Size of the gated intermediate representation.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, width: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        # Bias-free projections, created in the order gate -> up -> down.
        self.w_gate = nn.Linear(width, hidden_dim, bias=False)
        self.w_up = nn.Linear(width, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, width, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w_gate(x))
        hidden = gate * self.w_up(x)
        projected = self.w_down(hidden)
        return self.dropout(projected)
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with interleaved RoPE on queries and keys.

    q/k/v project ``width -> n_heads * head_dim`` and ``o_proj`` maps back to
    ``width``; all projections are bias-free. Attention scores are computed
    explicitly (matmul + softmax) rather than through a fused kernel.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        inner_dim = self.n_heads * self.head_dim

        self.q_proj = nn.Linear(config.width, inner_dim, bias=False)
        self.k_proj = nn.Linear(config.width, inner_dim, bias=False)
        self.v_proj = nn.Linear(config.width, inner_dim, bias=False)
        self.o_proj = nn.Linear(inner_dim, config.width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """Attend over ``x`` of shape ``(B, T, width)``; returns ``(B, T, width)``.

        ``mask`` (optional) must broadcast to ``(B, H, T, T)``; entries equal to
        0 are excluded from attention.
        """
        batch, seq_len, _ = x.shape
        per_head = (batch, seq_len, self.n_heads, self.head_dim)

        # RoPE is applied in (B, T, H, D) layout, then heads move to dim 1.
        q = apply_rotary_emb(self.q_proj(x).view(per_head), cos, sin).transpose(1, 2)
        k = apply_rotary_emb(self.k_proj(x).view(per_head), cos, sin).transpose(1, 2)
        v = self.v_proj(x).view(per_head).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # -inf logits become zero weight after the softmax.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = self.dropout(scores.softmax(dim=-1))

        # (B, H, T, D) -> (B, T, H * D) before the output projection.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.o_proj(merged)
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: ``x + attn(ln1(x))`` followed by ``x + mlp(ln2(x))``."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        # SwiGLU hidden size: 8/3 of the model width, rounded up to a multiple of 64.
        raw_hidden = int(8 * config.width / 3)
        hidden_dim = -(-raw_hidden // 64) * 64

        self.ln1 = RMSNorm(config.width)
        self.attn = Attention(config)
        self.ln2 = RMSNorm(config.width)
        self.mlp = SwiGLU(config.width, hidden_dim, config.dropout)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        attn_out = self.attn(self.ln1(x), cos, sin, mask)
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out
|
|
|
|
class HebrewGPT(nn.Module):
    """Decoder-only language model.

    Token embedding -> ``config.depth`` TransformerBlocks -> final RMSNorm ->
    linear LM head whose weight is tied to the token embedding.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.tok_emb = nn.Embedding(config.vocab_size, config.width)
        self.dropout = nn.Dropout(config.dropout)
        self.rotary = RotaryEmbedding(config.head_dim, config.max_seq_len, config.rope_theta)

        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.depth)
        ])

        self.ln_f = RMSNorm(config.width)
        self.head = nn.Linear(config.width, config.vocab_size, bias=False)

        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.tok_emb.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(B, T, vocab_size)`` for token ids of shape ``(B, T)``."""
        _, T = input_ids.shape
        device = input_ids.device

        x = self.dropout(self.tok_emb(input_ids))
        cos, sin = self.rotary(T)
        cos = cos.to(device)
        sin = sin.to(device)

        # Causal mask (1, 1, T, T): position i may attend only to positions <= i.
        mask = torch.tril(torch.ones(T, T, device=device)).unsqueeze(0).unsqueeze(0)

        for layer in self.layers:
            x = layer(x, cos, sin, mask)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    @staticmethod
    def _filter_logits(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Return a copy of ``(B, V)`` logits with tokens outside top-k / the top-p nucleus set to -inf."""
        if top_k > 0:
            kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept, and
            # always keep the single most likely token.
            sorted_remove[:, 1:] = sorted_remove[:, :-1].clone()
            sorted_remove[:, 0] = False
            # Map the mask from sorted order back to vocabulary order for the
            # whole batch in one scatter (sorted_indices is a permutation per row).
            remove = torch.zeros_like(sorted_remove).scatter(1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float("-inf"))

        return logits

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 200,
                 temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9) -> torch.Tensor:
        """Autoregressive generation with top-k and top-p (nucleus) sampling.

        Args:
            input_ids: ``(B, T)`` prompt token ids; ``T`` must be at least 1.
            max_new_tokens: Number of tokens appended to each sequence.
            temperature: Softmax temperature. ``temperature <= 0`` selects
                greedy (argmax) decoding instead of sampling.
            top_k: Keep only the ``top_k`` most likely tokens (0 disables).
            top_p: Keep the smallest set of tokens whose cumulative probability
                exceeds ``top_p`` (1.0 disables).

        Returns:
            ``(B, T + max_new_tokens)`` tensor: the prompt followed by new tokens.

        Raises:
            ValueError: if ``input_ids`` contains no tokens.
        """
        if input_ids.size(1) == 0:
            raise ValueError("input_ids must contain at least one token")

        for _ in range(max_new_tokens):
            # Crop to the context window the model (and RoPE cache) was built for.
            idx_cond = input_ids[:, -self.config.max_seq_len:]
            logits = self(idx_cond)[:, -1, :]

            if temperature <= 0:
                # Dividing by zero would yield NaN probabilities; decode greedily.
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = self._filter_logits(logits / temperature, top_k, top_p)
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids
|
|
|
|
| |
| |
| |
|
|
def main():
    """CLI entry point: load the tokenizer and checkpoint, then print a sampled continuation of ``--prompt``."""
    parser = argparse.ArgumentParser(description="HebrewGPT-1B Text Generation")
    parser.add_argument("--model_path", type=str, default="swa_best.pt",
                        help="Path to model checkpoint (state_dict)")
    parser.add_argument("--tokenizer_path", type=str, default="tokenizer.model",
                        help="Path to SentencePiece tokenizer model")
    parser.add_argument("--prompt", type=str, default="בראשית ברא אלוהים את",
                        help="Hebrew text prompt")
    parser.add_argument("--max_tokens", type=int, default=200,
                        help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature")
    parser.add_argument("--top_k", type=int, default=50,
                        help="Top-k sampling parameter")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-p (nucleus) sampling parameter")
    parser.add_argument("--device", type=str, default=None,
                        help="Device (cuda/cpu/mps). Auto-detected if not set.")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducible sampling")
    # Architecture flags: these must match the checkpoint being loaded.
    parser.add_argument("--width", type=int, default=2048)
    parser.add_argument("--depth", type=int, default=20)
    parser.add_argument("--n_heads", type=int, default=16)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--max_seq_len", type=int, default=2048)
    args = parser.parse_args()

    # Fail fast with a readable CLI error instead of a deep library traceback.
    for flag, path in (("--model_path", args.model_path),
                       ("--tokenizer_path", args.tokenizer_path)):
        if not Path(path).is_file():
            parser.error(f"{flag}: file not found: {path}")

    if args.seed is not None:
        torch.manual_seed(args.seed)

    # Device selection: explicit flag, else CUDA, else Apple MPS, else CPU.
    if args.device:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    print(f"Loading tokenizer from {args.tokenizer_path}...")
    sp = spm.SentencePieceProcessor()
    sp.Load(args.tokenizer_path)

    print(f"Loading weights from {args.model_path}...")
    state_dict = torch.load(args.model_path, map_location="cpu", weights_only=True)
    # Accept both a bare state_dict and a checkpoint that wraps it under "model".
    if isinstance(state_dict, dict) and "model" in state_dict:
        state_dict = state_dict["model"]

    # The vocabulary size has no CLI flag; take it from the checkpoint's
    # embedding when present so the model shape matches, else default to 32000.
    tok_emb_weight = state_dict.get("tok_emb.weight") if isinstance(state_dict, dict) else None
    vocab_size = tok_emb_weight.shape[0] if tok_emb_weight is not None else 32000

    config = ModelConfig(
        vocab_size=vocab_size,
        width=args.width,
        depth=args.depth,
        n_heads=args.n_heads,
        head_dim=args.head_dim,
        max_seq_len=args.max_seq_len,
        dropout=0.0,
    )
    print(f"Building HebrewGPT model (width={config.width}, depth={config.depth}, "
          f"heads={config.n_heads})...")
    model = HebrewGPT(config)
    model.load_state_dict(state_dict)
    model.eval().to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {param_count:,} parameters")

    print(f"\nPrompt: {args.prompt}")
    input_ids = sp.Encode(args.prompt)
    if not input_ids:
        # An empty prompt encodes to zero tokens, leaving nothing to condition
        # on; start from BOS when the tokenizer defines one (bos_id() < 0 otherwise).
        bos_id = sp.bos_id()
        if bos_id < 0:
            parser.error("--prompt is empty and the tokenizer defines no BOS token")
        input_ids = [bos_id]
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    print("Generating...\n")
    output_ids = model.generate(
        input_tensor,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )

    # Decode prompt + continuation together.
    generated_text = sp.Decode(output_ids[0].tolist())
    print("=" * 60)
    print(generated_text)
    print("=" * 60)
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|