Rzoro committed
Commit ba0c987 · verified · 1 Parent(s): 0291a68

Add Erebus-medium checkpoint at step 20000 (20% trained, loss~8.79)

Files changed (5)
  1. README.md +91 -0
  2. config.json +9 -0
  3. inference_hf.py +272 -0
  4. model.safetensors +3 -0
  5. tokenizer.json +3 -0
README.md ADDED
@@ -0,0 +1,91 @@
+ ---
+ license: mit
+ language:
+ - en
+ tags:
+ - erebus
+ - language-model
+ - causal-lm
+ - foundation-model
+ - pytorch
+ pipeline_tag: text-generation
+ ---
+
+ # Erebus-Medium
+
+ **Erebus-Medium** is a decoder-only causal language model (~454M parameters)
+ trained from scratch as part of the [Erebus](https://github.com/m-np/erebus)
+ foundation-model project.
+
+ ## Model architecture
+
+ | Attribute | Value |
+ |----------------|-------|
+ | Architecture | Decoder-only Transformer (GPT-style) |
+ | Parameters | ~454M |
+ | `d_model` | 1024 |
+ | `n_heads` | 16 |
+ | `n_layers` | 24 |
+ | `d_ff` | 4096 |
+ | `max_seq_len` | 1024 |
+ | Vocabulary | 50,257 (GPT-2 BPE) |
+ | Positional enc | RoPE |
+ | FFN activation | SwiGLU |
+ | Normalisation | RMSNorm (pre-norm) |
+ | Training steps | 20,000 |
+
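The ~454M figure can be reproduced from the other rows of the table. A quick sanity check, assuming no bias terms and an LM head tied to the token embedding (both match the bundled `inference_hf.py`):

```python
# Back-of-the-envelope parameter count from the architecture table.
# Assumes bias-free projections and a tied LM head, as in inference_hf.py.
d_model, n_layers, d_ff, vocab = 1024, 24, 4096, 50257

embedding = vocab * d_model           # token embeddings (shared with LM head)
attn      = 4 * d_model * d_model     # q/k/v/out projections
ffn       = 3 * d_model * d_ff        # SwiGLU: w1, w3, w2
norms     = 2 * d_model               # two RMSNorm gains per block
per_layer = attn + ffn + norms

total = embedding + n_layers * per_layer + d_model  # + final RMSNorm
print(f"{total / 1e6:.1f}M parameters")             # ~454.2M, matching the table
```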
+ ## Training details
+
+ - **Dataset**: FineWeb (`sample-10BT`, ~10B tokens from CommonCrawl)
+ - **Tokeniser**: tiktoken `gpt2` encoding (vocab = 50,257)
+ - **Optimiser**: AdamW (β₁=0.9, β₂=0.95, weight decay=0.1)
+ - **Schedule**: Cosine decay with linear warm-up
+ - **Precision**: bfloat16 mixed precision
+
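The schedule above (cosine decay with linear warm-up) can be sketched as follows. The warm-up length, peak, and floor learning rates here are illustrative assumptions, not values published for this checkpoint:

```python
import math

def lr_at(step, max_steps=20_000, warmup=1_000, peak_lr=3e-4, min_lr=3e-5):
    """Linear warm-up to peak_lr, then cosine decay to min_lr.

    warmup, peak_lr and min_lr are hypothetical values for illustration.
    """
    if step < warmup:
        return peak_lr * step / warmup  # linear ramp from 0 to peak
    progress = (step - warmup) / (max_steps - warmup)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
```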
+ ## How to use
+
+ ```python
+ import json
+ import sys
+
+ import torch
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+
+ # Install: pip install huggingface_hub safetensors tiktoken torch
+
+ # Download model weights and config
+ weights_path = hf_hub_download("Rzoro/erebus-medium", "model.safetensors")
+ config_path = hf_hub_download("Rzoro/erebus-medium", "config.json")
+
+ with open(config_path) as f:
+     cfg_dict = json.load(f)
+
+ # Build the model (requires the erebus repo on your Python path)
+ sys.path.insert(0, "/path/to/erebus")
+ from model import ErebusConfig, Erebus
+
+ config = ErebusConfig(**cfg_dict)
+ model = Erebus(config)
+ # strict=False: the LM head is tied to the token embedding, so the
+ # safetensors file stores that tensor only once
+ model.load_state_dict(load_file(weights_path), strict=False)
+ model.eval()
+
+ # Generate text
+ import tiktoken
+ enc = tiktoken.get_encoding("gpt2")
+ prompt = "The foundation of artificial intelligence is"
+ input_ids = torch.tensor([enc.encode(prompt)], dtype=torch.long)
+ output = model.generate(input_ids, max_new_tokens=100, temperature=0.8)
+ print(enc.decode(output[0].tolist()))
+ ```
+
+ ## Fine-tuning
+
+ Because the weights are in standard PyTorch format and the architecture is a
+ plain decoder-only transformer, you can fine-tune with:
+
+ - **Full fine-tuning**: load the weights and train as usual (the model fits on one GPU)
+ - **LoRA / QLoRA**: apply PEFT adapters for parameter-efficient fine-tuning
+ - **Instruction tuning**: format data with a `### Instruction:` / `### Response:` template
+
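For instruction tuning, a minimal formatting helper might look like this. The helper name and exact whitespace are hypothetical; only the `### Instruction:` / `### Response:` markers come from the list above:

```python
def format_example(instruction: str, response: str) -> str:
    # Hypothetical formatter following the "### Instruction:" /
    # "### Response:" convention; adjust the template to taste.
    return (
        "### Instruction:\n"
        f"{instruction}\n\n"
        "### Response:\n"
        f"{response}"
    )

print(format_example("Name the capital of France.", "Paris."))
```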
+ ## License
+
+ [MIT](LICENSE)
config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "vocab_size": 50257,
+   "d_model": 1024,
+   "n_heads": 16,
+   "n_layers": 24,
+   "d_ff": 4096,
+   "max_seq_len": 1024,
+   "dropout": 0.1
+ }
inference_hf.py ADDED
@@ -0,0 +1,272 @@
+ """
+ inference_hf.py - Self-contained inference script for Erebus models on HuggingFace.
+
+ This file has zero dependency on the rest of the erebus repo.
+ Copy it anywhere and run it as long as you have:
+     pip install torch tiktoken huggingface_hub safetensors
+
+ Usage
+ -----
+ # From HuggingFace Hub
+ python inference_hf.py --hf_repo Rzoro/erebus-small --prompt "The future of AI"
+
+ # Interactive
+ python inference_hf.py --hf_repo Rzoro/erebus-small --interactive
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import math
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # ── Model definition (self-contained copy) ────────────────────────────────────
+
+ @dataclass
+ class ErebusConfig:
+     vocab_size: int = 50257
+     d_model: int = 768
+     n_heads: int = 12
+     n_layers: int = 12
+     d_ff: int = 3072
+     max_seq_len: int = 1024
+     dropout: float = 0.0
+
+
+ class RotaryPositionEmbedding(nn.Module):
+     def __init__(self, head_dim: int, max_seq_len: int = 4096):
+         super().__init__()
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
+         positions = torch.arange(max_seq_len).float()
+         freqs = torch.outer(positions, inv_freq)
+         cos = freqs.cos().repeat_interleave(2, dim=-1).unsqueeze(0).unsqueeze(0)
+         sin = freqs.sin().repeat_interleave(2, dim=-1).unsqueeze(0).unsqueeze(0)
+         self.register_buffer("cos_cached", cos, persistent=False)
+         self.register_buffer("sin_cached", sin, persistent=False)
+
+     @staticmethod
+     def _rotate_half(x):
+         x1, x2 = x[..., 0::2], x[..., 1::2]
+         return torch.stack([-x2, x1], dim=-1).flatten(-2)
+
+     def forward(self, q, k):
+         T = q.size(2)
+         cos, sin = self.cos_cached[:, :, :T], self.sin_cached[:, :, :T]
+         return q * cos + self._rotate_half(q) * sin, k * cos + self._rotate_half(k) * sin
+
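The class above rotates each (even, odd) pair of query/key dimensions by an angle proportional to its position, which makes attention scores depend only on relative offsets. A two-dimensional, pure-Python sketch of that property:

```python
import math

# One RoPE frequency acting on one 2-D pair: rotating q by m*theta and
# k by n*theta leaves q·k dependent only on the offset m - n.
def rotate(vec, pos, theta=0.1):
    x, y = vec
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    return (x * c - y * s, x * s + y * c)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q, k = (1.0, 2.0), (3.0, -1.0)
score = dot(rotate(q, 10), rotate(k, 7))     # positions 10 and 7
shifted = dot(rotate(q, 25), rotate(k, 22))  # both positions shifted by +15
assert abs(score - shifted) < 1e-9           # same offset, same score
```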
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, d_model, n_heads, max_seq_len, dropout=0.0):
+         super().__init__()
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+         self.q_proj = nn.Linear(d_model, d_model, bias=False)
+         self.k_proj = nn.Linear(d_model, d_model, bias=False)
+         self.v_proj = nn.Linear(d_model, d_model, bias=False)
+         self.out_proj = nn.Linear(d_model, d_model, bias=False)
+         self.rope = RotaryPositionEmbedding(self.head_dim, max_seq_len)
+         self._flash = hasattr(F, "scaled_dot_product_attention")
+
+     def forward(self, x):
+         B, T, C = x.shape
+         def split(t): return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         Q, K, V = split(self.q_proj(x)), split(self.k_proj(x)), split(self.v_proj(x))
+         Q, K = self.rope(Q, K)
+         if self._flash:
+             out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
+         else:
+             scale = math.sqrt(self.head_dim)
+             scores = (Q @ K.transpose(-2, -1)) / scale
+             causal = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
+             scores = scores.masked_fill(~causal, float("-inf"))
+             out = torch.softmax(scores, dim=-1) @ V
+         return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, C))
+
+
+ class SwiGLU(nn.Module):
+     def __init__(self, d_model, d_ff):
+         super().__init__()
+         d_ff = (d_ff // 64) * 64  # round hidden size down to a multiple of 64
+         self.w1 = nn.Linear(d_model, d_ff, bias=False)
+         self.w3 = nn.Linear(d_model, d_ff, bias=False)
+         self.w2 = nn.Linear(d_ff, d_model, bias=False)
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, cfg: ErebusConfig):
+         super().__init__()
+         self.norm1 = nn.RMSNorm(cfg.d_model)
+         self.attn = MultiHeadAttention(cfg.d_model, cfg.n_heads, cfg.max_seq_len)
+         self.norm2 = nn.RMSNorm(cfg.d_model)
+         self.ffn = SwiGLU(cfg.d_model, cfg.d_ff)
+
+     def forward(self, x):
+         x = x + self.attn(self.norm1(x))
+         x = x + self.ffn(self.norm2(x))
+         return x
+
+
+ class Erebus(nn.Module):
+     def __init__(self, cfg: ErebusConfig):
+         super().__init__()
+         self.cfg = cfg
+         self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
+         self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
+         self.norm = nn.RMSNorm(cfg.d_model)
+         self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
+         self.lm_head.weight = self.token_emb.weight  # weight tying
+
+     @torch.no_grad()
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_new_tokens: int = 200,
+         temperature: float = 0.8,
+         top_k: int = 50,
+         top_p: float = 0.95,
+         repetition_penalty: float = 1.2,
+         eos_token_id: Optional[int] = None,
+     ) -> torch.Tensor:
+         self.eval()
+         for _ in range(max_new_tokens):
+             ctx = input_ids[:, -self.cfg.max_seq_len:]
+             x = self.token_emb(ctx)
+             for block in self.blocks:
+                 x = block(x)
+             logits = self.lm_head(self.norm(x))[:, -1, :]
+
+             if repetition_penalty != 1.0:
+                 for tok in input_ids[0].unique():
+                     # divide positive logits, multiply negative ones, so the
+                     # penalty always pushes seen tokens down
+                     if logits[0, tok] > 0:
+                         logits[0, tok] /= repetition_penalty
+                     else:
+                         logits[0, tok] *= repetition_penalty
+
+             logits = logits / max(temperature, 1e-8)
+
+             if top_k > 0:
+                 cutoff, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < cutoff[:, [-1]]] = float("-inf")
+
+             if top_p < 1.0:
+                 sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+                 cum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_logits[cum - F.softmax(sorted_logits, dim=-1) > top_p] = float("-inf")
+                 logits.scatter_(1, sorted_idx, sorted_logits)
+
+             next_tok = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
+             input_ids = torch.cat([input_ids, next_tok], dim=1)
+             if eos_token_id is not None and next_tok.item() == eos_token_id:
+                 break
+         return input_ids
+
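The top-p branch above keeps the smallest prefix of probability-sorted tokens whose *preceding* cumulative mass does not exceed `top_p`. The same rule in a few lines of plain Python:

```python
def top_p_keep(probs, top_p=0.8):
    # probs must be sorted in descending order; mirrors the filter in
    # Erebus.generate: drop a token once the cumulative mass before it
    # already exceeds top_p.
    kept, cum = [], 0.0
    for i, p in enumerate(probs):
        if cum > top_p:
            break
        kept.append(i)
        cum += p
    return kept

print(top_p_keep([0.5, 0.3, 0.15, 0.05], top_p=0.8))  # [0, 1, 2]
```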
+
+ # ── Loading helpers ───────────────────────────────────────────────────────────
+
+ def load_from_hf(repo_id: str, device: torch.device) -> Erebus:
+     from huggingface_hub import hf_hub_download
+     from safetensors.torch import load_file
+
+     print(f"Downloading {repo_id} from HuggingFace Hub …")
+     cfg_path = hf_hub_download(repo_id, "config.json")
+     weights_path = hf_hub_download(repo_id, "model.safetensors")
+
+     with open(cfg_path) as f:
+         cfg = ErebusConfig(**json.load(f))
+
+     model = Erebus(cfg)
+     # strict=False: the tied lm_head weight is stored only once
+     model.load_state_dict(load_file(weights_path), strict=False)
+     model.eval().to(device)
+     n = sum(p.numel() for p in model.parameters())
+     print(f"Loaded : {repo_id} ({n/1e6:.1f} M params)\n")
+     return model
+
+
+ def load_from_checkpoint(path: str, device: torch.device) -> Erebus:
+     ckpt = torch.load(path, map_location="cpu", weights_only=False)
+     model = Erebus(ckpt["config"])
+     model.load_state_dict(ckpt["model_state_dict"])
+     model.eval().to(device)
+     n = sum(p.numel() for p in model.parameters())
+     print(f"Loaded : {path} ({n/1e6:.1f} M params, step={ckpt.get('step','?')})\n")
+     return model
+
+
+ # ── CLI ───────────────────────────────────────────────────────────────────────
+
+ def parse_args():
+     p = argparse.ArgumentParser(description="Erebus inference - works with local or HF weights.")
+     src = p.add_mutually_exclusive_group(required=True)
+     src.add_argument("--hf_repo", help="HuggingFace repo id, e.g. Rzoro/erebus-small")
+     src.add_argument("--checkpoint", help="Local .pt checkpoint path")
+
+     inp = p.add_mutually_exclusive_group()
+     inp.add_argument("--prompt", default=None)
+     inp.add_argument("--interactive", action="store_true")
+
+     p.add_argument("--max_new_tokens", type=int, default=200)
+     p.add_argument("--temperature", type=float, default=0.8)
+     p.add_argument("--top_k", type=int, default=50)
+     p.add_argument("--top_p", type=float, default=0.95)
+     p.add_argument("--repetition_penalty", type=float, default=1.2)
+     p.add_argument("--device", default=None)
+     return p.parse_args()
+
+
+ def main():
+     import tiktoken
+     args = parse_args()
+     device = torch.device(
+         args.device if args.device
+         else ("cuda" if torch.cuda.is_available() else "cpu")
+     )
+     print(f"Device : {device}")
+
+     model = load_from_hf(args.hf_repo, device) if args.hf_repo \
+         else load_from_checkpoint(args.checkpoint, device)
+
+     enc = tiktoken.get_encoding("gpt2")
+
+     def run(prompt: str) -> str:
+         ids = torch.tensor([enc.encode(prompt)], dtype=torch.long).to(device)
+         out = model.generate(
+             ids,
+             max_new_tokens=args.max_new_tokens,
+             temperature=args.temperature,
+             top_k=args.top_k,
+             top_p=args.top_p,
+             repetition_penalty=args.repetition_penalty,
+             eos_token_id=enc.eot_token,
+         )
+         return enc.decode(out[0].tolist())
+
+     if args.interactive:
+         print("═" * 60)
+         print("Erebus - interactive mode (quit / Ctrl-C to exit)")
+         print("═" * 60)
+         while True:
+             try:
+                 prompt = input("\nPrompt > ").strip()
+             except (EOFError, KeyboardInterrupt):
+                 print("\nBye!")
+                 break
+             if not prompt or prompt.lower() in ("quit", "exit", "q"):
+                 print("Bye!")
+                 break
+             print("\n" + "─" * 60)
+             print(run(prompt))
+             print("─" * 60)
+     else:
+         prompt = args.prompt or input("Prompt > ").strip()
+         print("\n" + "─" * 60)
+         print(run(prompt))
+         print("─" * 60)
+
+
+ if __name__ == "__main__":
+     main()
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:be8c17d1ad353cb9f83b71feae17af52f39628bcabe7cec08218e4eb9e787152
+ size 1816688016
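The pointer's `size` field is consistent with the ~454M parameter count if the checkpoint is stored in float32 (4 bytes per weight, plus a small safetensors header):

```python
# Cross-check the LFS pointer size against the ~454M parameter count,
# assuming float32 storage (4 bytes per weight).
size_bytes = 1816688016
approx_params = size_bytes / 4
print(f"{approx_params / 1e6:.1f}M")  # ≈ 454.2M, matching the model card
```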
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "encoding": "gpt2"
+ }