ronnengmail committed · Commit 48488d4 · verified · 1 parent: f6623a9

Upload folder using huggingface_hub

Files changed (8):
  1. README.md +243 -0
  2. best.pt +3 -0
  3. config.json +20 -0
  4. generate.py +330 -0
  5. special_tokens_map.json +6 -0
  6. swa_best.pt +3 -0
  7. tokenizer.model +3 -0
  8. tokenizer_config.json +11 -0
README.md ADDED
---
language:
- he
license: apache-2.0
tags:
- hebrew
- gpt
- causal-lm
- hebrew-nlp
- muon-optimizer
- sentencepiece
- rope
- swiglu
datasets:
- hebrew-wikipedia
- HeNLP/HeDC4
library_name: transformers
pipeline_tag: text-generation
model-index:
- name: HebrewGPT-1B
  results:
  - task:
      type: text-generation
      name: Language Modeling
    metrics:
    - name: Perplexity
      type: perplexity
      value: 29.75
    - name: Top-1 Accuracy
      type: accuracy
      value: 38.4
    - name: Top-5 Accuracy
      type: accuracy
      value: 56.1
---

# HebrewGPT-1B 🇮🇱

**HebrewGPT-1B** is a 1.08-billion-parameter autoregressive language model trained from scratch on 2.48 billion tokens of Hebrew text. It is the first open-source, Hebrew-native GPT model at this scale, featuring a custom architecture with SwiGLU activations, RoPE positional encoding, and RMSNorm, trained with the Muon optimizer combined with Lookahead and Stochastic Weight Averaging (SWA).

This model was developed as part of an autonomous AI research project exploring whether an AI agent could independently conduct meaningful ML research. The full paper and methodology are available at the links below.

- 📄 **Paper**: [Hebrew Language Model Research via Agentic AI](https://d11k83yu06biio.cloudfront.net/paper/hebrew-autoresearch.html)
- 💻 **GitHub**: [AgenticResearcher](https://github.com/fatherRonnen/AgenticResearcher)
- 🔬 **Ablation model**: [HebrewGPT-1B-AdamW](https://huggingface.co/Slasky/HebrewGPT-1B-AdamW) (AdamW baseline)
- 🧪 **Smaller model**: [HebrewGPT-296M](https://huggingface.co/Slasky/HebrewGPT-296M) (296M-parameter variant)

## Model Description

| Parameter | Value |
|---|---|
| Parameters | 1.08B |
| Hidden size (WIDTH) | 2048 |
| Layers (DEPTH) | 20 |
| Attention heads | 16 |
| Head dimension | 128 |
| MLP type | SwiGLU (intermediate_size=5504) |
| Positional encoding | RoPE (interleaved, θ=10000) |
| Normalization | RMSNorm |
| Vocabulary | 32,000 (Hebrew-native SentencePiece BPE) |
| Context length | 2,048 tokens |
| Weight tying | Yes (embedding ↔ output head) |
| Precision | bfloat16 |
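The 1.08B figure in the table can be reproduced from the shapes alone. A back-of-the-envelope sketch (with weight tying, the output head adds no parameters beyond the embedding):

```python
# Parameter count implied by the table above.
vocab, width, depth, inter = 32000, 2048, 20, 5504

emb = vocab * width              # token embedding (tied with the output head)
attn = 4 * width * width         # q, k, v, o projections (16 heads x 128 = 2048)
mlp = 3 * width * inter          # SwiGLU: gate, up, down
norms = 2 * width                # two RMSNorms per layer
per_layer = attn + mlp + norms

total = emb + depth * per_layer + width  # + final RMSNorm
print(f"{total:,}")  # 1,077,495,808 ≈ 1.08B
```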

### Architecture Details

HebrewGPT uses a decoder-only transformer with several modern design choices:

- **SwiGLU MLP**: gate and up projections with SiLU activation; hidden dim = `int(2 × width × 4/3)` rounded up to a multiple of 64, giving 5504
- **RoPE**: rotary position embeddings with an interleaved pattern (`x[..., ::2]`, `x[..., 1::2]`)
- **RMSNorm**: pre-norm architecture with RMSNorm before attention and MLP
- **Weight tying**: the output projection shares weights with the token embeddings
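The SwiGLU hidden-dimension formula in the first bullet works out as follows (the rounding keeps GPU-friendly shapes):

```python
# SwiGLU hidden dim: 2/3 of the usual 4x expansion (to offset the third
# projection matrix), rounded up to a multiple of 64.
width = 2048
hidden = int(2 * width * 4 / 3)      # 5461
hidden = ((hidden + 63) // 64) * 64  # round up to a multiple of 64
print(hidden)  # 5504
```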

## Training Details

### Optimizer
- **Muon** optimizer + **Lookahead** (k=5, α=0.6) + **Stochastic Weight Averaging (SWA)**
- 4 cosine annealing cycles with warm restarts
- Dropout: 0.1
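For readers unfamiliar with Lookahead: after every k inner-optimizer steps, a slow copy of the weights moves a fraction α toward the fast weights, and the fast weights are reset to the slow copy. A minimal sketch of that synchronization step (illustrative only, not the training code; `lookahead_sync` is a hypothetical helper name):

```python
import torch

def lookahead_sync(slow, fast, alpha=0.6):
    """slow <- slow + alpha * (fast - slow); then fast <- slow (in place)."""
    for s, f in zip(slow, fast):
        s.add_(f - s, alpha=alpha)
        f.copy_(s)

# Toy demo: pretend k=5 fast steps moved a weight from 0 to 10.
slow = [torch.zeros(3)]
fast = [torch.full((3,), 10.0)]
lookahead_sync(slow, fast)
print(slow[0])  # tensor([6., 6., 6.])  -- 0 + 0.6 * (10 - 0)
```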

### Data
2.48 billion tokens drawn from 12 Hebrew datasets. The main components:

| Dataset | Proportion |
|---|---|
| Ben Yehuda Project (literature) | 23% |
| Supreme Court rulings | 22% |
| C4 (Hebrew subset) | 20% |
| CC100 (Hebrew) | 19% |
| Hebrew Wikipedia | 12% |
| Task-specific data | 4% |

### Hardware
- **Hardware**: 8× NVIDIA H100 80GB GPUs
- **Training time**: ~8 hours
- **Steps**: ~18,672

## Evaluation Results

### Overall Metrics

| Metric | Value |
|---|---|
| Validation BPB (SWA) | 25.89 |
| Perplexity | 29.75 |
| Top-1 Token Accuracy | 38.4% |
| Top-5 Token Accuracy | 56.1% |
| Top-10 Token Accuracy | 63.6% |

### Domain-Specific Perplexity

| Domain | Perplexity |
|---|---|
| Legal | 5.93 |
| Wikipedia | 11.50 |
| News | 24.81 |
| Conversational | 29.79 |
| Literature | 31.42 |
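For context, perplexity is the exponential of the average per-token negative log-likelihood, so a perplexity of 29.75 means the model is on average about as uncertain as a uniform choice over ~30 tokens. A toy illustration of the computation (made-up probabilities, not this model's data):

```python
import math

# Perplexity = exp(mean negative log-likelihood per token).
# Probabilities the (toy) model assigned to the correct next token:
probs = [0.6, 0.25, 0.1, 0.05]
nll = [-math.log(p) for p in probs]
ppl = math.exp(sum(nll) / len(nll))
print(round(ppl, 2))
```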

### Comparison with Other Hebrew Models

| Model | Top-1 Accuracy | Top-5 Accuracy |
|---|---|---|
| **HebrewGPT-1B (this model)** | **38.4%** | **56.1%** |
| HebrewGPT-296M | 39.6% | 68.4% |
| AlephBERT | ~35% | — |
| HeBERT | ~33% | — |

*Note: AlephBERT and HeBERT are encoder (BERT-based) models and not directly comparable on generation tasks; their token-prediction accuracy is provided only as a reference point for Hebrew language understanding.*

### Optimizer Ablation

Training with AdamW instead of Muon (all else equal) yields val_bpb = 28.09, compared with 25.89 for Muon (a relative degradation of roughly 8.5%), demonstrating a significant advantage for Muon at the 1B scale. See [HebrewGPT-1B-AdamW](https://huggingface.co/Slasky/HebrewGPT-1B-AdamW) for details.

## Usage

> ⚠️ **Custom Architecture**: This model uses a custom architecture and is not a standard HuggingFace `transformers` model. You must use the provided model class definition or reference the [GitHub repository](https://github.com/fatherRonnen/AgenticResearcher).

### Quick Start

```python
import torch
import sentencepiece as spm

# Load tokenizer
sp = spm.SentencePieceProcessor()
sp.Load("tokenizer.model")

# Load model (see generate.py for the full model class definition)
from generate import HebrewGPT, ModelConfig

config = ModelConfig(
    vocab_size=32000,
    width=2048,
    depth=20,
    n_heads=16,
    head_dim=128,
    max_seq_len=2048,
    dropout=0.0,  # no dropout at inference
)
model = HebrewGPT(config)

# Load weights (checkpoints may wrap the state_dict under a "model" key)
state_dict = torch.load("swa_best.pt", map_location="cpu", weights_only=True)
if isinstance(state_dict, dict) and "model" in state_dict:
    state_dict = state_dict["model"]
model.load_state_dict(state_dict)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval().to(device)

# Greedy generation
prompt = "בראשית ברא אלוהים את"
input_ids = sp.Encode(prompt)
input_tensor = torch.tensor([input_ids], device=device)

with torch.no_grad():
    for _ in range(100):
        logits = model(input_tensor)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_tensor = torch.cat([input_tensor, next_token], dim=1)
        if input_tensor.shape[1] >= 2048:  # stay within the context window
            break

generated = sp.Decode(input_tensor[0].tolist())
print(generated)
```

### Full Example

See [`generate.py`](generate.py) in this repository for a complete standalone script with the full model architecture definition and generation utilities.

## Hebrew Generation Examples

<div dir="rtl">

**Prompt**: בראשית ברא אלוהים את

**Generated**: בראשית ברא אלוהים את השמים ואת הארץ. והארץ היתה תוהו ובוהו וחושך על פני תהום...

*("In the beginning God created the heavens and the earth. And the earth was without form and void, and darkness was upon the face of the deep...")*

---

**Prompt**: בית המשפט העליון פסק כי

**Generated**: בית המשפט העליון פסק כי יש לקבל את הערעור ולהחזיר את התיק לדיון מחדש בפני בית המשפט המחוזי...

*("The Supreme Court ruled that the appeal should be granted and the case returned for a new hearing before the District Court...")*

---

**Prompt**: הטכנולוגיה המודרנית משנה את

**Generated**: הטכנולוגיה המודרנית משנה את האופן שבו אנו חיים, עובדים ומתקשרים זה עם זה...

*("Modern technology is changing the way we live, work, and communicate with one another...")*

</div>

*Note: Generated examples are illustrative; actual outputs depend on sampling parameters. English translations in parentheses are provided for reference.*

## Limitations

- **Hebrew-only**: the model was trained exclusively on Hebrew text and has limited ability to handle other languages.
- **No instruction tuning**: this is a base language model; it has not been fine-tuned for chat, instruction following, or safety alignment.
- **Context length**: limited to 2,048 tokens.
- **Training data biases**: the model reflects biases present in its training data, which includes legal documents, literature, and web text.
- **Custom architecture**: requires the provided model class to load; not compatible with standard `AutoModelForCausalLM`.
- **No safety filtering**: the model may generate inappropriate, biased, or factually incorrect content.

## Citation

```bibtex
@article{slasky2025hebrewgpt,
  title={Hebrew Language Model Research via Agentic AI: Training HebrewGPT from Scratch},
  author={Slasky, Ronnen},
  year={2025},
  url={https://d11k83yu06biio.cloudfront.net/paper/hebrew-autoresearch.html}
}
```

## Acknowledgments

- **Loki** — AI research assistant (Claude/Anthropic on OpenClaw) that assisted throughout the research process
- **Andrej Karpathy** — for the autoresearch framework and inspiration
- The Hebrew NLP community, for open datasets

## Contact

- **Author**: Ronnen Slasky
- **Email**: ronnen@slasky.com
- **GitHub**: [fatherRonnen/AgenticResearcher](https://github.com/fatherRonnen/AgenticResearcher)
best.pt ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:be999b76db166cfcbfac73c568c7f717a3147ebb16afa6748c470341f2dabb53
size 8903540921
config.json ADDED

{
  "architectures": ["HebrewGPT"],
  "model_type": "hebrew-gpt",
  "vocab_size": 32000,
  "hidden_size": 2048,
  "num_hidden_layers": 20,
  "num_attention_heads": 16,
  "head_dim": 128,
  "intermediate_size": 5504,
  "max_position_embeddings": 2048,
  "dropout": 0.1,
  "activation": "silu",
  "norm_type": "rmsnorm",
  "rope_theta": 10000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "auto_map": {
    "AutoModel": "generate.HebrewGPT"
  }
}
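Note that config.json uses HF-style field names (`hidden_size`, `num_hidden_layers`) while `generate.py`'s `ModelConfig` uses `width`/`depth`. A small sketch of bridging the two; the `KEY_MAP`/`config_kwargs` helpers below are illustrative assumptions, not part of this repo:

```python
import json

# Map HF-style config.json field names onto the ModelConfig fields used by
# generate.py (see that file for the actual dataclass).
KEY_MAP = {
    "vocab_size": "vocab_size",
    "hidden_size": "width",
    "num_hidden_layers": "depth",
    "num_attention_heads": "n_heads",
    "head_dim": "head_dim",
    "max_position_embeddings": "max_seq_len",
    "rope_theta": "rope_theta",
}

def config_kwargs(cfg: dict) -> dict:
    kwargs = {new: cfg[old] for old, new in KEY_MAP.items() if old in cfg}
    kwargs["dropout"] = 0.0  # always disable dropout at inference
    return kwargs

demo = config_kwargs({"hidden_size": 2048, "num_hidden_layers": 20})
print(demo)  # {'width': 2048, 'depth': 20, 'dropout': 0.0}
```

Usage would then be `ModelConfig(**config_kwargs(json.load(open("config.json"))))` followed by the Quick Start loading steps.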
generate.py ADDED
#!/usr/bin/env python3
"""
HebrewGPT-1B — Standalone generation script.

This script contains the full model architecture definition and can generate
Hebrew text without depending on the HuggingFace transformers library.

Requirements:
    pip install torch sentencepiece

Usage:
    python generate.py --prompt "בראשית ברא אלוהים את" --max_tokens 200
    python generate.py --prompt "בית המשפט העליון פסק" --temperature 0.8 --top_k 50
"""

import argparse
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm


# ─────────────────────────────────────────────────────────────────────────────
# Model Architecture
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class ModelConfig:
    vocab_size: int = 32000
    width: int = 2048
    depth: int = 20
    n_heads: int = 16
    head_dim: int = 128
    max_seq_len: int = 2048
    dropout: float = 0.0  # set to 0.0 for inference
    rope_theta: float = 10000.0


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (x.float() * norm).type_as(x) * self.weight
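A quick numeric check of what RMSNorm does: with the default weight of ones, each vector is rescaled to (approximately) unit root-mean-square; unlike LayerNorm, the mean is not subtracted.

```python
import torch

# RMSNorm sanity check: output vectors have RMS ~= 1 when weight is all-ones.
torch.manual_seed(0)
x = torch.randn(4, 8) * 5.0
norm = x.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
y = x * norm
rms = y.pow(2).mean(-1).sqrt()
print(rms)  # all entries close to 1.0
```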


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        # Build on the same device as inv_freq so a cache rebuild after the
        # model has been moved to GPU does not mix CPU and GPU tensors.
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        self.register_buffer("cos_cached", freqs.cos(), persistent=False)
        self.register_buffer("sin_cached", freqs.sin(), persistent=False)

    def forward(self, seq_len: int):
        if seq_len > self.cos_cached.shape[0]:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply RoPE with interleaved pattern (x[..., ::2], x[..., 1::2])."""
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]

    # cos/sin shape: (seq_len, head_dim//2) -> broadcast to (1, seq, 1, head_dim//2)
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)

    out_even = x_even * cos - x_odd * sin
    out_odd = x_even * sin + x_odd * cos

    # Interleave back
    out = torch.stack([out_even, out_odd], dim=-1).flatten(-2)
    return out
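Because interleaved RoPE is a pure rotation of each (even, odd) channel pair, it preserves vector norms. A standalone numeric check of that property (self-contained, so the rotation is re-implemented inline):

```python
import torch

# RoPE norm-preservation check: rotating each (even, odd) pair leaves the
# per-position vector norm unchanged.
torch.manual_seed(0)
dim, seq = 8, 4
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)
cos, sin = freqs.cos(), freqs.sin()

x = torch.randn(seq, dim)
x_even, x_odd = x[..., ::2], x[..., 1::2]
out = torch.stack([x_even * cos - x_odd * sin,
                   x_even * sin + x_odd * cos], dim=-1).flatten(-2)

print(torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-5))  # True
```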


class SwiGLU(nn.Module):
    def __init__(self, width: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.w_gate = nn.Linear(width, hidden_dim, bias=False)
        self.w_up = nn.Linear(width, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, width, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))


class Attention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        total_dim = config.n_heads * config.head_dim

        self.q_proj = nn.Linear(config.width, total_dim, bias=False)
        self.k_proj = nn.Linear(config.width, total_dim, bias=False)
        self.v_proj = nn.Linear(config.width, total_dim, bias=False)
        self.o_proj = nn.Linear(total_dim, config.width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        B, T, _ = x.shape

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim)

        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # (B, n_heads, T, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product attention
        scale = math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) / scale

        if mask is not None:
            attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)  # (B, n_heads, T, head_dim)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(out)
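With dropout disabled, the manual softmax(QKᵀ/√d)V above is numerically equivalent to PyTorch's fused `F.scaled_dot_product_attention` with `is_causal=True` (available in torch ≥ 2.0), which would be a drop-in speedup. A small equivalence check:

```python
import torch
import torch.nn.functional as F

# Manual causal attention vs. PyTorch's fused kernel (dropout = 0).
torch.manual_seed(0)
B, H, T, D = 1, 2, 5, 16
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

mask = torch.tril(torch.ones(T, T))
attn = (q @ k.transpose(-2, -1)) / (D ** 0.5)
attn = attn.masked_fill(mask == 0, float("-inf")).softmax(dim=-1)
manual = attn @ v

fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, fused, atol=1e-5))  # True
```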


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        hidden_dim = int(2 * config.width * 4 / 3)
        hidden_dim = ((hidden_dim + 63) // 64) * 64  # round up to multiple of 64

        self.ln1 = RMSNorm(config.width)
        self.attn = Attention(config)
        self.ln2 = RMSNorm(config.width)
        self.mlp = SwiGLU(config.width, hidden_dim, config.dropout)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        x = x + self.attn(self.ln1(x), cos, sin, mask)
        x = x + self.mlp(self.ln2(x))
        return x


class HebrewGPT(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.tok_emb = nn.Embedding(config.vocab_size, config.width)
        self.dropout = nn.Dropout(config.dropout)
        self.rotary = RotaryEmbedding(config.head_dim, config.max_seq_len, config.rope_theta)

        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.depth)
        ])

        self.ln_f = RMSNorm(config.width)
        self.head = nn.Linear(config.width, config.vocab_size, bias=False)

        # Weight tying
        self.head.weight = self.tok_emb.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        B, T = input_ids.shape
        device = input_ids.device

        x = self.dropout(self.tok_emb(input_ids))
        cos, sin = self.rotary(T)
        cos = cos.to(device)
        sin = sin.to(device)

        # Causal mask
        mask = torch.tril(torch.ones(T, T, device=device)).unsqueeze(0).unsqueeze(0)

        for layer in self.layers:
            x = layer(x, cos, sin, mask)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 200,
                 temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9) -> torch.Tensor:
        """Autoregressive generation with top-k and top-p (nucleus) sampling."""
        for _ in range(max_new_tokens):
            # Crop to max context length
            idx_cond = input_ids[:, -self.config.max_seq_len:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # Top-k filtering
            if top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                for b in range(logits.shape[0]):
                    logits[b, sorted_indices[b, sorted_indices_to_remove[b]]] = float("-inf")

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids
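The nucleus (top-p) filtering step inside `generate` is easiest to see on a concrete distribution: keep the smallest prefix of tokens (sorted by probability) whose cumulative mass exceeds p, and mask out the rest. An isolated demo:

```python
import torch
import torch.nn.functional as F

# Top-p step from generate(), isolated. Token probs: [0.6, 0.25, 0.1, 0.05].
logits = torch.log(torch.tensor([[0.6, 0.25, 0.1, 0.05]]))
top_p = 0.8

sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum > top_p                     # [F, T, T, T]
remove[:, 1:] = remove[:, :-1].clone()   # shift: keep the token that crosses p
remove[:, 0] = False
logits[0, sorted_idx[0, remove[0]]] = float("-inf")

kept = logits[0] > float("-inf")
print(kept)  # tensor([ True,  True, False, False]) -- only 0.6 and 0.25 survive
```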


# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────

def main():
    parser = argparse.ArgumentParser(description="HebrewGPT-1B Text Generation")
    parser.add_argument("--model_path", type=str, default="swa_best.pt",
                        help="Path to model checkpoint (state_dict)")
    parser.add_argument("--tokenizer_path", type=str, default="tokenizer.model",
                        help="Path to SentencePiece tokenizer model")
    parser.add_argument("--prompt", type=str, default="בראשית ברא אלוהים את",
                        help="Hebrew text prompt")
    parser.add_argument("--max_tokens", type=int, default=200,
                        help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature")
    parser.add_argument("--top_k", type=int, default=50,
                        help="Top-k sampling parameter")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-p (nucleus) sampling parameter")
    parser.add_argument("--device", type=str, default=None,
                        help="Device (cuda/cpu/mps). Auto-detected if not set.")
    # Model config overrides (for different model sizes)
    parser.add_argument("--width", type=int, default=2048)
    parser.add_argument("--depth", type=int, default=20)
    parser.add_argument("--n_heads", type=int, default=16)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--max_seq_len", type=int, default=2048)
    args = parser.parse_args()

    # Device selection
    if args.device:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    # Load tokenizer
    print(f"Loading tokenizer from {args.tokenizer_path}...")
    sp = spm.SentencePieceProcessor()
    sp.Load(args.tokenizer_path)

    # Build model
    config = ModelConfig(
        vocab_size=32000,
        width=args.width,
        depth=args.depth,
        n_heads=args.n_heads,
        head_dim=args.head_dim,
        max_seq_len=args.max_seq_len,
        dropout=0.0,
    )
    print(f"Building HebrewGPT model (width={config.width}, depth={config.depth}, "
          f"heads={config.n_heads})...")
    model = HebrewGPT(config)

    # Load weights
    print(f"Loading weights from {args.model_path}...")
    state_dict = torch.load(args.model_path, map_location="cpu", weights_only=True)
    # Handle wrapped checkpoint format (dict with 'model' key)
    if isinstance(state_dict, dict) and "model" in state_dict:
        state_dict = state_dict["model"]
    model.load_state_dict(state_dict)
    model.eval().to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {param_count:,} parameters")

    # Encode prompt
    print(f"\nPrompt: {args.prompt}")
    input_ids = sp.Encode(args.prompt)
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    # Generate
    print("Generating...\n")
    output_ids = model.generate(
        input_tensor,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )

    # Decode and print
    generated_text = sp.Decode(output_ids[0].tolist())
    print("=" * 60)
    print(generated_text)
    print("=" * 60)


if __name__ == "__main__":
    main()
special_tokens_map.json ADDED

{
  "bos_token": "<s>",
  "eos_token": "</s>",
  "unk_token": "<unk>",
  "pad_token": "<pad>"
}
swa_best.pt ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:91a7d39e2372f5492eed72d2413dcc53022b9664ce61002e086001a4a59b1311
size 4331021947
tokenizer.model ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:ecfbf40eb7e4bf8fcc7d857e1110153319bd9ffd0cc575e8b79afa1b0bd68a28
size 825144
tokenizer_config.json ADDED

{
  "model_type": "sentencepiece",
  "sentencepiece_model_file": "tokenizer.model",
  "vocab_size": 32000,
  "bos_token": "<s>",
  "eos_token": "</s>",
  "unk_token": "<unk>",
  "pad_token": "<pad>",
  "model_max_length": 2048,
  "clean_up_tokenization_spaces": false
}