|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
# Ask the CUDA caching allocator to use expandable segments (reduces fragmentation).
# The variable is only honoured if set before the allocator initialises, hence this
# sits before `import torch`. setdefault keeps any value the user already exported.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
|
|
|
|
import math
from dataclasses import dataclass
from typing import Optional

import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
|
|
|
|
|
# Permit TensorFloat-32 for float32 matmuls (cuBLAS) and cuDNN convolutions.
# These flags only change numerics on GPUs that support TF32.
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True
del _tf32_backend
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(f"[Info] Torch version: {torch.__version__}")
print(f"[Info] CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"[Info] CUDA version: {torch.version.cuda}")

# Flash Attention 2 is an optional CUDA-only extension. Any failure while importing it
# (package missing, ABI/CUDA mismatch in the compiled extension) falls back to PyTorch
# SDPA. It is also left disabled when no GPU is visible, because its kernels only
# accept CUDA tensors.
try:
    from flash_attn import flash_attn_func
except Exception:
    flash_attn_func = None
    FLASH_ATTENTION_2 = False
else:
    FLASH_ATTENTION_2 = torch.cuda.is_available()

if FLASH_ATTENTION_2:
    print("[OK] Flash Attention 2 enabled")
else:
    print("[WARN] Flash Attention 2 not available – using PyTorch SDPA")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ModelArgs:
    """Hyper-parameters of the decoder-only Transformer.

    Raises:
        ValueError: on construction, if the head configuration is inconsistent
            (see ``__post_init__``).
    """

    dim: int = 768  # residual-stream width
    n_layers: int = 12  # number of TransformerBlocks
    n_heads: int = 12  # query heads; head_dim = dim // n_heads
    n_kv_heads: int = 4  # key/value heads, each shared by n_heads // n_kv_heads query heads
    vocab_size: int = 32000  # replaced by the tokenizer's vocab size in __main__
    multiple_of: int = 256  # FFN hidden width is rounded up to a multiple of this
    ffn_dim_multiplier: float = 3.0  # FFN hidden width ~= dim * ffn_dim_multiplier
    norm_eps: float = 1e-5  # RMSNorm epsilon
    max_seq_len: int = 1024  # training context; RoPE tables are built for 2x this

    def __post_init__(self):
        """Reject configurations that would otherwise fail later with opaque shape errors."""
        if self.n_heads <= 0 or self.dim % self.n_heads != 0:
            raise ValueError(
                f"dim ({self.dim}) must be divisible by a positive n_heads ({self.n_heads})"
            )
        if self.n_kv_heads <= 0 or self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be a multiple of a positive "
                f"n_kv_heads ({self.n_kv_heads})"
            )
        # Rotary embeddings rotate channel pairs, so the per-head width must be even.
        if (self.dim // self.n_heads) % 2 != 0:
            raise ValueError(
                f"head dimension dim // n_heads ({self.dim // self.n_heads}) must be even"
            )
|
|
|
|
|
|
|
|
# Input/output locations, relative to the current working directory.
TOKENIZER_MODEL_PATH = "tokenizer.model"
TRAIN_BIN = "dataset.bin"
VALID_BIN = "valid.bin"
CHECKPOINT_DIR = "checkpoints"

# Interval, in global_step units, between intermediate checkpoints and sample generations.
SAVE_EVERY_STEPS = 100_000

# Create the checkpoint directory at import time so later torch.save calls can write into it.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation: ``x / sqrt(mean(x**2, -1) + eps) * weight``.

    No mean-centering and no bias; ``weight`` is a learned per-channel scale
    initialised to ones.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` to unit RMS over its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Compute the statistics in float32 so bf16/fp16 inputs do not lose precision
        # in the mean of squares, then return to the input dtype before scaling.
        return self._norm(x.float()).type_as(x) * self.weight
|
|
|
|
|
|
|
|
def precompute_freqs_cis(dim, seq_len, theta=10000.0):
    """Build rotary-embedding cosine/sine tables.

    Args:
        dim: per-head channel count; one frequency is produced per channel pair.
        seq_len: number of positions to tabulate (0 .. seq_len - 1).
        theta: RoPE base; frequency ``j`` is ``theta ** (-2j / dim)``.

    Returns:
        ``(cos, sin)``, each of shape ``(seq_len, ceil(dim / 2))``, holding the
        cosine/sine of ``position * frequency``.
    """
    exponents = torch.arange(0, dim, 2) / dim
    inv_freq = 1.0 / theta ** exponents
    positions = torch.arange(seq_len)
    angles = positions[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()
|
|
|
|
|
|
|
|
def apply_rotary_emb(x, cos, sin):
    """Rotate interleaved channel pairs ``(x[..., 2i], x[..., 2i+1])`` by per-position angles.

    Args:
        x: tensor whose last axis holds channel pairs; ``cos``/``sin`` are broadcast as
           ``(1, seq, 1, pairs)``, so ``x`` is expected as (batch, seq, heads, head_dim).
        cos, sin: rotation tables of shape (seq, head_dim // 2).

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    cos_b = cos[None, :, None]
    sin_b = sin[None, :, None]
    even, odd = x[..., ::2], x[..., 1::2]
    rotated_even = even * cos_b - odd * sin_b
    rotated_odd = even * sin_b + odd * cos_b
    # Re-interleave the pairs back into the last axis and keep the input dtype.
    interleaved = torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)
    return interleaved.to(x.dtype)
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Causal self-attention with grouped-query key/value heads and rotary embeddings.

    Queries use ``n_heads`` heads; keys/values use ``n_kv_heads`` heads, each shared by
    ``n_heads // n_kv_heads`` consecutive query heads.
    """

    def __init__(self, args):
        super().__init__()
        if args.n_heads <= 0 or args.dim % args.n_heads != 0:
            raise ValueError(
                f"dim ({args.dim}) must be divisible by a positive n_heads ({args.n_heads})"
            )
        if args.n_kv_heads <= 0 or args.n_heads % args.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({args.n_heads}) must be a multiple of a positive "
                f"n_kv_heads ({args.n_kv_heads})"
            )

        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        self.n_kv_heads = args.n_kv_heads
        self.repeat_kv = args.n_heads // args.n_kv_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

    def forward(self, x, cos, sin):
        """Attend causally over the sequence.

        Args:
            x: hidden states of shape (batch, seq_len, dim).
            cos, sin: rotary tables for the first ``seq_len`` positions,
                shape (seq_len, head_dim // 2).

        Returns:
            Tensor of shape (batch, seq_len, dim).
        """
        B, T, _ = x.shape

        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)

        # RoPE rotates every head with the same tables, so rotating the n_kv_heads keys
        # before expanding them equals rotating afterwards, with less work.
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # flash-attn kernels only accept fp16/bf16 CUDA tensors; otherwise use SDPA.
        use_flash = (
            FLASH_ATTENTION_2
            and q.is_cuda
            and q.dtype in (torch.float16, torch.bfloat16)
            and k.dtype == q.dtype
            and v.dtype == q.dtype
        )

        if use_flash:
            # flash_attn_func takes (batch, seqlen, nheads, headdim) -- no transpose --
            # and handles GQA itself when K/V have fewer heads than Q.
            out = flash_attn_func(q, k, v, causal=True)
        else:
            # SDPA wants (batch, heads, seq, head_dim) with equal head counts.
            k = k.repeat_interleave(self.repeat_kv, dim=2)
            v = v.repeat_interleave(self.repeat_kv, dim=2)
            out = F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
            ).transpose(1, 2)

        # Both paths leave out as (batch, seq_len, n_heads, head_dim).
        return self.wo(out.reshape(B, T, -1))
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """SwiGLU feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    The hidden width is ``int(dim * mult)`` rounded up to the next multiple of
    ``multiple_of``.
    """

    def __init__(self, dim, multiple_of, mult):
        super().__init__()
        target = int(dim * mult)
        rounded_up = target + multiple_of - 1
        hidden = rounded_up // multiple_of * multiple_of
        self.w1 = nn.Linear(dim, hidden, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, hidden, bias=False)  # up projection

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm residual layer: ``x + Attn(Norm(x))`` followed by ``h + FFN(Norm(h))``."""

    def __init__(self, args):
        super().__init__()
        self.attn = Attention(args)
        self.ffn = FeedForward(args.dim, args.multiple_of, args.ffn_dim_multiplier)
        self.attn_norm = RMSNorm(args.dim, args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, args.norm_eps)
        # Set to True by Transformer.gradient_checkpointing_enable(). When active (and in
        # training mode) only the FFN sub-layer is recomputed during backward.
        self.gradient_checkpointing = False

    def forward(self, x, cos, sin):
        h = x + self.attn(self.attn_norm(x), cos, sin)

        recompute = self.gradient_checkpointing and self.training
        if recompute:
            ffn_out = torch.utils.checkpoint.checkpoint(self._ffn, h, use_reentrant=False)
        else:
            ffn_out = self._ffn(h)
        return h + ffn_out

    def _ffn(self, x):
        """Normalise then apply the feed-forward network (the checkpointed sub-layer)."""
        return self.ffn(self.ffn_norm(x))
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """Decoder-only LM: token embedding -> ``n_layers`` TransformerBlocks -> RMSNorm -> vocab logits."""

    def __init__(self, args):
        super().__init__()
        self.tok_emb = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList([TransformerBlock(args) for _ in range(args.n_layers)])
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.out = nn.Linear(args.dim, args.vocab_size, bias=False)

        # RoPE tables for 2 * max_seq_len positions. persistent=False keeps them out of
        # the state_dict; they are rebuilt from args on construction.
        cos, sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len * 2)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

        self.apply(self._init)

    def gradient_checkpointing_enable(self):
        """Enable activation checkpointing in every block (it only takes effect in training mode)."""
        for layer in self.layers:
            layer.gradient_checkpointing = True
        print("[OK] Gradient checkpointing enabled")

    def _init(self, m):
        """Initialise every Linear and Embedding weight from N(0, 0.02**2)."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, std=0.02)

    def forward(self, tokens):
        """Map token ids of shape (batch, seq_len) to logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed rotary tables.
        """
        _, seq_len = tokens.shape
        max_len = self.cos_cached.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the rotary cache ({max_len} positions)"
            )

        h = self.tok_emb(tokens)
        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]

        for layer in self.layers:
            h = layer(h, cos, sin)

        return self.out(self.norm(h))

    def get_num_params(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MemmapDataset(Dataset):
    """Fixed-length next-token windows over a flat binary file of token ids.

    The file is memory-mapped read-only as a 1-D array of ``dtype``. Sample ``i``
    starts at ``starts[i]`` and is ``(inputs, targets)``, where ``targets`` is
    ``inputs`` shifted left by one token. Window starts are ``stride`` apart
    (default ``max_seq_len // 2``, i.e. 50% overlap). A final window is always
    aligned to the end of the file so the tail is covered.

    Args:
        path: path to the raw token file.
        max_seq_len: number of input tokens per sample.
        stride: distance between consecutive window starts; defaults to
            ``max_seq_len // 2`` (at least 1).
        dtype: on-disk integer type of the token ids (default ``np.int32``).

    Raises:
        ValueError: if the file holds fewer than ``max_seq_len + 1`` tokens.
    """

    def __init__(
        self,
        path: str,
        max_seq_len: int,
        stride: Optional[int] = None,
        dtype=np.int32,
    ):
        self.tokens = np.memmap(path, dtype=dtype, mode="r")
        self.max_seq_len = max_seq_len
        # max(1, ...) avoids a zero stride (range() error) when max_seq_len == 1.
        self.stride = stride or max(1, max_seq_len // 2)

        # Last start index whose window (max_seq_len + 1 tokens) still fits in the file.
        max_start = len(self.tokens) - (max_seq_len + 1)
        if max_start < 0:
            raise ValueError("Dataset too small for the given max_seq_len")

        self.starts = list(range(0, max_start, self.stride))
        # range() excludes max_start; add it so the final tokens are included
        # (this also covers max_start == 0, where the range is empty).
        if not self.starts or self.starts[-1] != max_start:
            self.starts.append(max_start)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        i = self.starts[idx]
        # astype copies out of the memmap and widens to int64 in one step.
        window = self.tokens[i:i + self.max_seq_len + 1].astype(np.int64)
        seq = torch.from_numpy(window)
        return seq[:-1], seq[1:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sample_next_token(logits, temperature, top_p):
    """Pick one token id from 1-D ``logits``: greedy if ``temperature <= 0``, else nucleus sampling.

    Returns a tensor of shape (1,).
    """
    if temperature <= 0:
        return logits.argmax(dim=-1, keepdim=True)

    logits = logits / temperature
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = torch.softmax(sorted_logits, dim=0).cumsum(dim=0)

    # Drop a token once the mass *before* it already exceeds top_p: shifting the mask
    # right keeps the token that crosses the threshold, and the top token always survives.
    remove = cum_probs > top_p
    remove[1:] = remove[:-1].clone()
    remove[0] = False

    logits[sorted_idx[remove]] = -float("inf")
    probs = torch.softmax(logits, dim=0)
    return torch.multinomial(probs, 1)


@torch.no_grad()
def generate_text(model, tokenizer, prompts,
                  max_new_tokens=128, temperature=0.8, top_p=0.95, eos_id=None,
                  max_context=None):
    """Generate a continuation for each prompt with temperature + top-p sampling.

    The model's train/eval mode is restored on return, so this is safe to call
    in the middle of training.

    Args:
        model: module mapping (1, seq_len) token ids to (1, seq_len, vocab) logits.
        tokenizer: object with ``encode(str) -> list[int]`` and ``decode(list[int]) -> str``
            (and ``eos_id()`` when ``eos_id`` is None).
        prompts: iterable of prompt strings.
        max_new_tokens: maximum tokens generated per prompt.
        temperature: softmax temperature; ``<= 0`` selects greedy decoding.
        top_p: nucleus-sampling probability mass.
        eos_id: stop token; defaults to ``tokenizer.eos_id()``.
        max_context: if given, only the last ``max_context`` tokens are fed to the model.

    Returns:
        dict mapping each prompt to the decoded prompt + generated text.
    """
    was_training = model.training
    model.eval()
    device = next(model.parameters()).device
    if eos_id is None:
        eos_id = tokenizer.eos_id()

    results = {}
    try:
        for prompt in prompts:
            ids = tokenizer.encode(prompt)
            x = torch.tensor([ids], dtype=torch.long, device=device)

            for _ in range(max_new_tokens):
                context = x if max_context is None else x[:, -max_context:]
                logits = model(context)[0, -1].float()
                next_tok = _sample_next_token(logits, temperature, top_p)
                x = torch.cat([x, next_tok.view(1, 1)], dim=1)

                if next_tok.item() == eos_id:
                    break

            results[prompt] = tokenizer.decode(x[0].tolist())
    finally:
        model.train(was_training)

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_SAMPLE_PROMPTS = (
    "Once upon a time",
    "In a distant future",
    "First step to build a rocket",
    "Capital city of France",
    "Artificial intelligence will",
)


def _save_checkpoint(accelerator, model, optimizer, scheduler, args, step, path):
    """Save model/optimizer/scheduler state, the step counter and the model config to ``path``."""
    checkpoint = {
        "step": step,
        "model_state_dict": accelerator.unwrap_model(model).state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "model_args": args,
    }
    torch.save(checkpoint, path)


def _print_samples(accelerator, model, tokenizer, header):
    """Generate continuations of the fixed sample prompts with the unwrapped model and print them."""
    samples = generate_text(
        accelerator.unwrap_model(model),
        tokenizer,
        _SAMPLE_PROMPTS,
        max_new_tokens=100,
        temperature=0.8,
        top_p=0.95,
    )
    print(header)
    for prompt, text in samples.items():
        print(f"Prompt: {prompt}")
        print(f"Generated: {text}")
        print("-" * 50)


@torch.no_grad()
def _evaluate(accelerator, model, loader, ignore_index):
    """Return the mean per-batch cross-entropy over ``loader``, aggregated across all processes."""
    model.eval()
    loss_sum = torch.zeros((), device=accelerator.device)
    n_batches = torch.zeros((), device=accelerator.device)
    for x, y in loader:
        logits = model(x)
        loss_sum += F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            y.view(-1),
            ignore_index=ignore_index,
        ).float()
        n_batches += 1
    loss_sum = accelerator.reduce(loss_sum, reduction="sum")
    n_batches = accelerator.reduce(n_batches, reduction="sum")
    return (loss_sum / n_batches.clamp(min=1)).item()


def train(
    model,
    train_ds,
    valid_ds,
    tokenizer,
    args,
    batch_size=1,
    grad_accum=8,
    epochs=1,
    lr=1e-5,
    warmup_steps=500,
):
    """Train ``model`` with Accelerate: mixed precision, gradient accumulation, AdamW, warmup + cosine LR.

    Every ``SAVE_EVERY_STEPS`` optimizer steps, the main process writes a checkpoint to
    ``CHECKPOINT_DIR`` and prints sample generations. After each epoch it reports the
    validation loss and prints samples. At the end it writes ``final_model.pt``.

    Args:
        model: a ``Transformer``; gradient checkpointing is enabled on it.
        train_ds, valid_ds: datasets yielding ``(inputs, targets)`` token tensors.
        tokenizer: SentencePiece processor; targets equal to ``pad_id()`` are ignored by the loss.
        args: model config stored verbatim in every checkpoint.
        batch_size: per-process micro-batch size.
        grad_accum: micro-batches accumulated per optimizer step.
        epochs: passes over ``train_ds``.
        lr: peak learning rate, reached after ``warmup_steps``.
        warmup_steps: optimizer steps of linear warmup before cosine decay to 0.
    """
    cuda = torch.cuda.is_available()
    if cuda and torch.cuda.is_bf16_supported():
        mixed_precision = "bf16"
    elif cuda:
        mixed_precision = "fp16"
    else:
        # fp16 autocast needs a GPU; train in full precision on CPU.
        mixed_precision = "no"

    accelerator = Accelerator(
        mixed_precision=mixed_precision,
        gradient_accumulation_steps=grad_accum,
    )

    model.gradient_checkpointing_enable()

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        valid_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=0.01,
    )

    total_steps = math.ceil(len(train_loader) / grad_accum) * epochs

    def lr_lambda(step):
        # Linear warmup to the base LR, then cosine decay to 0.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        # Clamp so overshooting total_steps does not climb back up the cosine.
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    model, optimizer, train_loader, valid_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, valid_loader, scheduler
    )
    # prepare() may wrap the model (e.g. DDP); custom methods live on the inner module.
    raw_model = accelerator.unwrap_model(model)
    pad_id = tokenizer.pad_id()

    if accelerator.is_main_process:
        eff_bs = batch_size * grad_accum * accelerator.num_processes
        print(f"Model params: {raw_model.get_num_params():,}")
        print(f"Effective batch size: {eff_bs}")
        print(f"Total optimizer steps: {total_steps}")
        print(f"Flash Attention: {FLASH_ATTENTION_2}")
        print("-" * 60)

    global_step = 0  # counts optimizer steps, not micro-batches

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        pbar = tqdm(
            train_loader,
            disable=not accelerator.is_local_main_process,
            desc=f"Epoch {epoch+1}/{epochs}",
        )

        for step, (x, y) in enumerate(pbar):
            with accelerator.accumulate(model):
                logits = model(x)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1),
                    ignore_index=pad_id,
                )
                accelerator.backward(loss)

                # True once per grad_accum micro-batches: take an optimizer step.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                    if accelerator.is_main_process and global_step % SAVE_EVERY_STEPS == 0:
                        ckpt_path = f"{CHECKPOINT_DIR}/step_{global_step}.pt"
                        _save_checkpoint(
                            accelerator, model, optimizer, scheduler, args, global_step, ckpt_path
                        )
                        print(f"[Checkpoint] Saved complete checkpoint at step {global_step}")
                        _print_samples(
                            accelerator, model, tokenizer,
                            f"[Sample generation @ step {global_step}]",
                        )
                        # Generation switches the model to eval mode; switch back so
                        # gradient checkpointing (gated on self.training) stays active.
                        model.train()

            running_loss += loss.item()
            pbar.set_postfix(
                loss=f"{running_loss/(step+1):.4f}",
                lr=f"{scheduler.get_last_lr()[0]:.2e}",
            )

        val_loss = _evaluate(accelerator, model, valid_loader, pad_id)
        accelerator.print(
            f"[Epoch {epoch+1}] Train Loss: {running_loss/max(1, len(train_loader)):.6f} | "
            f"Val Loss: {val_loss:.6f}"
        )

        if accelerator.is_main_process:
            _print_samples(accelerator, model, tokenizer, "[Sample generation]")

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        _save_checkpoint(
            accelerator, model, optimizer, scheduler, args, global_step,
            f"{CHECKPOINT_DIR}/final_model.pt",
        )
        print("Training complete.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    args = ModelArgs()

    tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
    args.vocab_size = tokenizer.vocab_size()

    train_ds = MemmapDataset(TRAIN_BIN, args.max_seq_len)
    valid_ds = MemmapDataset(VALID_BIN, args.max_seq_len)

    model = Transformer(args)

    # Optional warm start: set to a checkpoint path (e.g. "checkpoints/step_200000.pt").
    # Only the model weights are restored; the optimizer, LR schedule (including warmup)
    # and step counter start fresh in train().
    RESUME_FROM = None

    if RESUME_FROM is not None:
        if os.path.exists(RESUME_FROM):
            print(f"[Resume] Loading checkpoint from {RESUME_FROM}")
            # weights_only=False: checkpoints written by train() also pickle ModelArgs.
            # Only load checkpoint files you trust.
            state = torch.load(RESUME_FROM, map_location="cpu", weights_only=False)

            # Support both the checkpoint-dict format and a bare state_dict (old format).
            if "model_state_dict" in state:
                model.load_state_dict(state["model_state_dict"])
                print(f"[Resume] Loaded model from step {state.get('step', 'unknown')}")
            else:
                model.load_state_dict(state)
                print("[Resume] Loaded model (old format)")
        else:
            print(f"[Resume] Checkpoint not found: {RESUME_FROM} – training from scratch")

    train(
        model,
        train_ds,
        valid_ds,
        tokenizer,
        args,
        batch_size=1,
        grad_accum=8,
        epochs=1,
        lr=1e-5,
    )
|
|
|