import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

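# Checkpoints to compare on the validation set: the 900k-step pre-training
# checkpoint plus the continual-training checkpoints saved every 100k steps.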
CHECKPOINTS = {
    "pretrain_900k": "checkpoints/step_900000.pt",
    "continual_100k": "checkpoints/step_100000.pt",
    "continual_200k": "checkpoints/step_200000.pt",
    "continual_300k": "checkpoints/step_300000.pt",
    "continual_400k": "checkpoints/step_400000.pt",
    "continual_500k": "checkpoints/step_500000.pt",
}

TOKENIZER_MODEL_PATH = "tokenizer.model"
VALID_BIN = "valid.bin"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1

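# Model configuration. The field values must match the architecture the
# checkpoints were trained with, or load_state_dict() will fail on shape
# mismatches (vocab_size is overridden from the tokenizer in main()).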
@dataclass
class ModelArgs:
    dim: int = 768
    n_layers: int = 12
    n_heads: int = 12
    n_kv_heads: int = 4
    vocab_size: int = 32000
    multiple_of: int = 256
    ffn_dim_multiplier: float = 3.0
    norm_eps: float = 1e-5
    max_seq_len: int = 1024

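# Sliding-window dataset over a flat int32 token memmap. Windows are
# max_seq_len + 1 tokens long and start every `stride` tokens (default:
# half a window); the final window is pinned to the end of the file so no
# tokens are dropped. Each item is an (input, target) pair shifted by one.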
class MemmapDataset(Dataset):
    def __init__(self, path: str, max_seq_len: int, stride=None):
        self.tokens = np.memmap(path, dtype=np.int32, mode="r")
        self.max_seq_len = max_seq_len
        self.stride = stride or max_seq_len // 2

        max_start = len(self.tokens) - (max_seq_len + 1)
        if max_start <= 0:
            raise ValueError("Dataset too small")

        self.starts = list(range(0, max_start, self.stride))
        if self.starts[-1] != max_start:
            self.starts.append(max_start)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        i = self.starts[idx]
        seq = torch.from_numpy(
            self.tokens[i:i + self.max_seq_len + 1].copy()
        ).long()
        return seq[:-1], seq[1:]

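# Llama-style decoder components: RMSNorm, rotary position embeddings,
# grouped-query attention, and a SwiGLU feed-forward block.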
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

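# Rotary embedding helpers: per-pair frequencies 1 / theta^(2i/d) are combined
# with positions via an outer product, and the resulting cos/sin tables are
# applied to the even/odd channels of q and k.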
def precompute_freqs_cis(dim, seq_len, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
    t = torch.arange(seq_len)
    freqs = torch.outer(t, freqs)
    return freqs.cos(), freqs.sin()


def apply_rotary_emb(x, cos, sin):
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

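# Grouped-query attention: n_kv_heads K/V heads are repeat_interleave'd to
# match the n_heads query heads, RoPE is applied to q and k, and PyTorch's
# scaled_dot_product_attention supplies the causal mask.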
class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        self.n_kv_heads = args.n_kv_heads
        self.repeat_kv = args.n_heads // args.n_kv_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

    def forward(self, x, cos, sin):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)
        k = k.repeat_interleave(self.repeat_kv, dim=2)
        v = v.repeat_interleave(self.repeat_kv, dim=2)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out)

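# SwiGLU feed-forward: hidden width is dim * ffn_dim_multiplier rounded up to
# a multiple of `multiple_of`.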
class FeedForward(nn.Module):
    def __init__(self, dim, multiple_of, mult):
        super().__init__()
        hidden = multiple_of * ((int(dim * mult) + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.attn = Attention(args)
        self.ffn = FeedForward(args.dim, args.multiple_of, args.ffn_dim_multiplier)
        self.attn_norm = RMSNorm(args.dim, args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, args.norm_eps)

    def forward(self, x, cos, sin):
        x = x + self.attn(self.attn_norm(x), cos, sin)
        x = x + self.ffn(self.ffn_norm(x))
        return x

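# Pre-norm decoder: token embedding -> n_layers TransformerBlocks -> final
# RMSNorm -> untied output projection. Rotary cos/sin tables are cached for
# 2x max_seq_len positions.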
class Transformer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.tok_emb = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList([TransformerBlock(args) for _ in range(args.n_layers)])
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.out = nn.Linear(args.dim, args.vocab_size, bias=False)
        cos, sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len * 2)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(self, tokens):
        B, T = tokens.shape
        h = self.tok_emb(tokens)
        cos = self.cos_cached[:T]
        sin = self.sin_cached[:T]
        for layer in self.layers:
            h = layer(h, cos, sin)
        h = self.norm(h)
        return self.out(h)

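# Load one checkpoint (either a raw state_dict or a training checkpoint with a
# "model_state_dict" entry), run it over the validation loader, and return the
# mean cross-entropy per non-pad token.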
def evaluate_checkpoint(path, valid_loader, tokenizer, args):
    ckpt = torch.load(path, map_location="cpu", weights_only=False)

    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt

    model = Transformer(args)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()

    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for x, y in tqdm(valid_loader, desc="validation", leave=False):
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            logits = model(x)
            # Sum the loss over tokens and normalize at the end, so the result
            # is an exact per-token average rather than a mean of batch means.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                y.view(-1),
                ignore_index=tokenizer.pad_id(),
                reduction="sum",
            )

            total_loss += loss.item()
            total_tokens += (y != tokenizer.pad_id()).sum().item()

    return total_loss / total_tokens

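# Evaluate every checkpoint on the same validation split and print per-checkpoint
# loss and perplexity, followed by a summary.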
def main():
    args = ModelArgs()
    tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
    args.vocab_size = tokenizer.vocab_size()

    valid_ds = MemmapDataset(VALID_BIN, args.max_seq_len)
    valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    print("=" * 70)
    print("Checkpoint comparison (validation)")
    print("=" * 70)

    results = {}
    for name, path in CHECKPOINTS.items():
        print(f"[Eval] {name}")
        loss = evaluate_checkpoint(path, valid_loader, tokenizer, args)
        ppl = math.exp(loss)
        results[name] = (loss, ppl)
        print(f" Val loss: {loss:.6f}")
        print(f" Perplexity: {ppl:.2f}")
        print("-" * 50)

    print("\nSummary:")
    for name, (loss, ppl) in results.items():
        print(f"{name:20s} | loss {loss:.6f} | ppl {ppl:.2f}")
    print("=" * 70)


if __name__ == "__main__":
    main()