import math
import os
import re
import time

import numpy as np
import torch

from model import GPTConfig, GPT
|
|
# ---------------------------------------------------------------------------
# Finetuning configuration
# ---------------------------------------------------------------------------

# Paths: checkpoints are written to out_dir, pretrained weights are read from a
# .pt file in init_from, and token data is read from data/<dataset>/.
out_dir = '/home/user/350m_SmaLLMPro_Final'
init_from = '/home/user/350m_fineweb'
dataset = 'alpaca_cleaned_mixed'

# Batching: one optimizer step consumes batch_size * gradient_accumulation_steps
# sequences of block_size tokens (4 * 32 * 1024 = 131,072 tokens).
batch_size = 4
gradient_accumulation_steps = 32
block_size = 1024

# Optimization and learning-rate schedule.
learning_rate = 2e-5   # peak LR; with warmup_iters = 0 the schedule starts here
min_lr = 3e-6          # floor the cosine decay reaches at max_iters
warmup_iters = 0
max_iters = 1500
weight_decay = 0.1
beta1 = 0.9
beta2 = 0.95

# Regularization.
dropout = 0.1

# Runtime.
device = 'cuda'
dtype = 'bfloat16'     # one of 'float32', 'bfloat16', 'float16'
compile = True         # NOTE: shadows the builtin compile(); other code reads this name
save_interval = 500    # write an intermediate checkpoint every N iterations
|
|
# Run-wide setup: ensure the output directory exists, fix the RNG seed, and
# build the autocast context used around every forward pass.
os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337)

# torch.amp.autocast takes a device *type* ('cuda' / 'cpu'), not a full device
# string such as 'cuda:0', so collapse `device` to its type.
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'

# Map the configured dtype name to the torch dtype (unknown names raise KeyError).
ptdtype = dict(
    float32=torch.float32,
    bfloat16=torch.bfloat16,
    float16=torch.float16,
)[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
|
# Token ids (uint16) and the per-token loss mask (uint8) for the finetuning
# split, both memory-mapped read-only.
data_dir = os.path.join('data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
train_mask = np.memmap(os.path.join(data_dir, 'train_mask.bin'), dtype=np.uint8, mode='r')

# get_batch() slices both arrays with the same offsets, so they must be aligned
# token-for-token; a length mismatch means the mask belongs to different data.
if len(train_mask) != len(train_data):
    raise ValueError(
        f"train_mask.bin has {len(train_mask)} entries but train.bin has "
        f"{len(train_data)} tokens in {data_dir}; they must be the same length"
    )
# Sampling needs at least block_size tokens plus one shifted target token.
if len(train_data) <= block_size:
    raise ValueError(
        f"train.bin in {data_dir} has {len(train_data)} tokens; "
        f"need more than block_size={block_size}"
    )
|
|
def get_batch():
    """Sample a random training batch with non-target positions masked out.

    Draws ``batch_size`` random windows of ``block_size`` tokens from
    ``train_data``. ``x`` is each window and ``y`` is the same window shifted
    one token to the right (next-token targets). Wherever ``train_mask`` is 0
    at a target position, that target is replaced by -100.

    NOTE(review): -100 is PyTorch's default ``ignore_index`` for cross-entropy,
    but stock nanoGPT ``GPT.forward`` passes ``ignore_index=-1``. Confirm that
    model.py ignores -100; otherwise masked targets are out-of-range class ids.

    Returns:
        (x, y): int64 tensors of shape (batch_size, block_size) on ``device``.
    """
    # The upper bound keeps the shifted target slice [i+1, i+1+block_size) in range.
    ix = torch.randint(len(train_data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy(train_data[i:i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(train_data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])
    # Compare in numpy and keep a bool tensor rather than widening the uint8
    # mask to int64 just to test it against zero.
    ignore = torch.stack([torch.from_numpy(train_mask[i + 1:i + 1 + block_size] == 0) for i in ix])
    y[ignore] = -100

    if device_type == 'cuda':
        # Pinned host memory allows an asynchronous host-to-device copy.
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y
|
|
print(f"๐ฅ Loading Pretraining-Checkpoint from {init_from}...")


def _natural_sort_key(filename):
    """Sort key that orders embedded integers numerically ('iter_500' < 'iter_1000')."""
    # re.split with a capturing group alternates text (even index) and digit
    # runs (odd index), so keys always compare str-to-str and int-to-int.
    return [int(part) if idx % 2 else part for idx, part in enumerate(re.split(r'(\d+)', filename))]


# Plain sorted() is lexicographic and would rank 'iter_500.pt' after
# 'iter_1000.pt'; the natural key makes [-1] the highest-numbered checkpoint.
ckpt_files = sorted((f for f in os.listdir(init_from) if f.endswith('.pt')), key=_natural_sort_key)
if not ckpt_files:
    raise FileNotFoundError("No checkpoint found in init_from directory!")

ckpt_path = os.path.join(init_from, ckpt_files[-1])
# Load onto the CPU: the model is built on the CPU and only moved with
# model.to(device) below, so this avoids staging a second copy on the GPU.
checkpoint = torch.load(ckpt_path, map_location='cpu')

# Rebuild the architecture from the pretraining hyperparameters, but apply the
# finetuning dropout from the config; without this override the `dropout`
# setting is ignored and the pretraining value is reused.
model_args = dict(checkpoint['model_args'])
model_args['dropout'] = dropout
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
state_dict = checkpoint['model']

# State dicts taken from a torch.compile-wrapped model carry an '_orig_mod.'
# key prefix; strip it so the keys match the plain GPT module.
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

model.load_state_dict(state_dict)
model.to(device)

# Only metadata (model_args, config, ...) is used after this point. Drop the
# pretraining weights and any stored optimizer state so they are not kept
# alive for the whole finetuning run.
del state_dict
checkpoint = {k: v for k, v in checkpoint.items() if k not in ('model', 'optimizer')}
|
|
# Optionally wrap the model with torch.compile. The wrapper exposes the
# original module as `_orig_mod`, which the checkpoint-saving code reads.
if compile:
    print("๐ Compiling Model...")
    model = torch.compile(model)

# Optimizer construction is delegated to the model's own configure_optimizers.
optimizer = model.configure_optimizers(
    weight_decay,
    learning_rate,
    (beta1, beta2),
    device_type,
)

# Loss scaling is only enabled for float16. For bfloat16/float32 the scaler is
# built disabled, and its scale/unscale_/step/update calls pass straight through.
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
|
|
def get_lr(it):
    """Return the learning rate for iteration ``it``.

    The schedule has three phases:
    * ``it < warmup_iters``: linear ramp from 0 toward ``learning_rate``.
    * ``warmup_iters <= it <= max_iters``: cosine decay from ``learning_rate``
      down to ``min_lr``.
    * ``it > max_iters``: held at ``min_lr``.
    """
    # Warmup phase.
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # Past the end of the schedule.
    if it > max_iters:
        return min_lr
    # Cosine phase: `progress` runs 0 -> 1, `coeff` runs 1 -> 0.
    progress = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
    lr_span = learning_rate - min_lr
    return min_lr + coeff * lr_span
|
|
print(f"๐ ๏ธ Starting Finetuning...")
model.train()
t0 = time.time()
last_log_iter = -1  # iteration of the previous log line; -1 so iter 0 counts as one step

for iter_num in range(max_iters + 1):
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Accumulate gradients over several micro-batches to emulate a larger batch.
    # Each micro-loss is pre-divided by gradient_accumulation_steps, so their sum
    # in loss_accum is the mean loss over the whole optimizer step.
    loss_accum = 0.0
    for _ in range(gradient_accumulation_steps):
        X, Y = get_batch()
        with ctx:
            _, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps
        loss_accum += loss.detach()
        scaler.scale(loss).backward()

    # Unscale first so clipping applies to the true gradient norm.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    if iter_num % 10 == 0:
        # Converting the loss to a Python float waits for queued GPU work, so
        # read it before taking the timestamp to time completed iterations.
        loss_value = float(loss_accum)
        t1 = time.time()
        # Average wall time per iteration since the previous log line.
        dt = (t1 - t0) / (iter_num - last_log_iter)
        print(f"Iter {iter_num}: Loss {loss_value:.4f}, Time {dt*1000:.2f}ms, LR {lr:.2e}")
        t0 = t1
        last_log_iter = iter_num

    if iter_num > 0 and iter_num % save_interval == 0:
        checkpoint_name = f'SmaLLMPro_iter_{iter_num}.pt'
        save_path = os.path.join(out_dir, checkpoint_name)
        print(f"๐พ Saving checkpoint: {checkpoint_name}")
        # Save the underlying module when compiled so keys have no '_orig_mod.' prefix.
        raw_model = model._orig_mod if compile else model
        checkpoint_data = {
            'model': raw_model.state_dict(),
            'model_args': checkpoint['model_args'],
            'config': checkpoint.get('config', {}),
            'iter_num': iter_num,
            'lr': lr,
        }
        torch.save(checkpoint_data, save_path)
|
|
print(f"๐พ Finetuning done. Saving SmaLLMPro...")
# Save the underlying module when compiled so state-dict keys carry no
# '_orig_mod.' prefix (same choice as the periodic checkpoints).
raw_model = model._orig_mod if compile else model
final_checkpoint = {
    'model': raw_model.state_dict(),
    'model_args': checkpoint['model_args'],
    'config': checkpoint.get('config', {}),
}
torch.save(final_checkpoint, os.path.join(out_dir, 'SmaLLMPro_Final.pt'))
# The original literal was split across two lines by a mis-decoded emoji,
# leaving an unterminated string (SyntaxError).
print("✅ SmaLLMPro saved successfully!")