# Apex-1-Instruct-350M / finetune.py
import os
import time
import math
import torch
from model import GPTConfig, GPT
import numpy as np
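# ----- Fine-tuning configuration -----
# out_dir: where fine-tuned checkpoints are written; init_from: directory holding the
# pretraining checkpoint to start from. The dataset directory is expected to contain a
# pre-tokenized train.bin (uint16 token ids) and a matching train_mask.bin (uint8).
# Effective batch size is batch_size * gradient_accumulation_steps = 128 sequences of
# block_size (1024) tokens per optimizer step.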
out_dir = '/home/user/350m_SmaLLMPro_Final'
init_from = '/home/user/350m_fineweb'
dataset = 'alpaca_cleaned_mixed'
batch_size = 4
gradient_accumulation_steps = 32
block_size = 1024
learning_rate = 2e-5
max_iters = 1500
weight_decay = 0.1
dropout = 0.1
warmup_iters = 0
min_lr = 3e-6
beta1, beta2 = 0.9, 0.95
device = 'cuda'
dtype = 'bfloat16'
compile = True
save_interval = 500
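# ----- Runtime setup -----
# Create the output directory, fix the RNG seed, and build the autocast context for
# mixed-precision forward passes in the chosen dtype (bfloat16 by default).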
os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337)
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
data_dir = os.path.join('data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
train_mask = np.memmap(os.path.join(data_dir, 'train_mask.bin'), dtype=np.uint8, mode='r')
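# Sample random windows from the memory-mapped training data. Targets y are the inputs
# shifted by one token. Assumption: train_mask.bin marks the positions that should
# contribute to the loss (e.g. response tokens in the instruction data); wherever the
# mask is 0 the target is set to -100 so the cross-entropy loss ignores that position.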
def get_batch():
    ix = torch.randint(len(train_data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((train_data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((train_data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    m = torch.stack([torch.from_numpy((train_mask[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    y[m == 0] = -100
    x, y = x.to(device), y.to(device)
    return x, y
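# ----- Load the pretraining checkpoint -----
# Pick the last .pt file in init_from (sorted() is lexicographic, so this assumes the
# checkpoint filenames sort in training order) and strip the '_orig_mod.' prefix that
# torch.compile adds to parameter names before loading the state dict.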
print(f"๐Ÿ“ฅ Loading Pretraining-Checkpoint from {init_from}...")
ckpt_files = sorted([f for f in os.listdir(init_from) if f.endswith('.pt')])
if not ckpt_files:
    raise FileNotFoundError("No checkpoint found in init_from directory!")
ckpt_path = os.path.join(init_from, ckpt_files[-1])
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.to(device)
if compile:
print("๐Ÿš€ Compiling Model...")
model = torch.compile(model)
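# Optimizer and loss scaler. configure_optimizers() comes from the model module; in
# nanoGPT-style code it typically builds AdamW and applies weight decay only to 2D+
# parameters (an assumption here, since model.py is not shown). The GradScaler is only
# enabled for float16 — bfloat16 training does not need loss scaling.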
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
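# Learning-rate schedule: linear warmup for warmup_iters steps (0 here, so decay starts
# immediately), then cosine decay from learning_rate down to min_lr.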
def get_lr(it):
if it < warmup_iters: return learning_rate * it / warmup_iters
if it > max_iters: return min_lr
decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
return min_lr + coeff * (learning_rate - min_lr)
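# ----- Training loop -----
# Each iteration accumulates gradients over gradient_accumulation_steps micro-batches,
# clips the global grad norm to 1.0, takes one optimizer step, logs every 10 iterations,
# and saves an intermediate checkpoint every save_interval iterations.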
print(f"๐Ÿ› ๏ธ Starting Finetuning...")
model.train()
t0 = time.time()
for iter_num in range(max_iters + 1):
    # Set the scheduled learning rate for this iteration
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    # Gradient accumulation over several micro-batches
    for micro_step in range(gradient_accumulation_steps):
        X, Y = get_batch()
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps  # scale so gradients average over micro-batches
        scaler.scale(loss).backward()
    # Clip gradients and take the optimizer step
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    if iter_num % 10 == 0:
        dt = time.time() - t0
        print(f"Iter {iter_num}: Loss {loss.item()*gradient_accumulation_steps:.4f}, Time {dt*1000:.2f}ms, LR {lr:.2e}")
        t0 = time.time()
    # Periodic intermediate checkpoint
    if iter_num > 0 and iter_num % save_interval == 0:
        checkpoint_name = f'SmaLLMPro_iter_{iter_num}.pt'
        save_path = os.path.join(out_dir, checkpoint_name)
        print(f"💾 Saving checkpoint: {checkpoint_name}")
        raw_model = model._orig_mod if compile else model  # unwrap the compiled module
        checkpoint_data = {
            'model': raw_model.state_dict(),
            'model_args': checkpoint['model_args'],
            'iter_num': iter_num,
            'lr': lr,
        }
        torch.save(checkpoint_data, save_path)
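# ----- Final export -----
# Save the fine-tuned weights plus model args (no optimizer state) as a single
# checkpoint suitable for inference-only loading.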
print(f"๐Ÿ’พ Finetuning done. Saving SmaLLMPro...")
final_checkpoint = {
    'model': model.state_dict() if not compile else model._orig_mod.state_dict(),
    'model_args': checkpoint['model_args'],
    'config': checkpoint.get('config', {}),
}
torch.save(final_checkpoint, os.path.join(out_dir, 'SmaLLMPro_Final.pt'))
print("โœ… SmaLLMPro saved successfully!")