# ==========================
# train.py
# ==========================
# Usage:
# python train.py --data_path all.jsonl --spm_model spm.model
# Requirements:
# pip install torch sentencepiece tqdm
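# Data format (one JSON object per line in --data_path, as read by QADataset
# below), e.g.:
#   {"question": "What is the capital of France?", "answer": "Paris"}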
import os
import json
import sentencepiece as spm
from argparse import ArgumentParser
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
# --------------------------
# Simple Decoder-only Transformer (GPT-like)
# --------------------------
class GPTConfig:
    def __init__(self, vocab_size, n_layer=12, n_head=12, n_embd=768, block_size=1024, dropout=0.1, pad_id=0):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.block_size = block_size
        self.dropout = dropout
        self.pad_id = pad_id  # token id masked out of the loss (padding)
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
self.n_head = config.n_head
self.dropout = nn.Dropout(config.dropout)
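        # a single dropout module is reused on both the attention weights and
        # the projected output (GPT-2 keeps two separate rates; one is fine here)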
    def forward(self, x):
B, T, C = x.size()
qkv = self.c_attn(x) # (B, T, 3*C)
q, k, v = qkv.split(C, dim=2)
# reshape for multi-head
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1,2)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1,2)
att = (q @ k.transpose(-2, -1)) / (C // self.n_head) ** 0.5 # (B, nh, T, T)
# causal mask
mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
att = att.masked_fill(mask == 0, float('-inf'))
att = torch.softmax(att, dim=-1)
att = self.dropout(att)
y = att @ v # (B, nh, T, hs)
y = y.transpose(1,2).contiguous().view(B, T, C)
y = self.c_proj(y)
y = self.dropout(y)
return y
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.ln2 = nn.LayerNorm(config.n_embd)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(),
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.dropout),
)
def forward(self, x):
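        # pre-norm residual layout: LayerNorm feeds each sub-layer, and the
        # identity path carries x around it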
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
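        # learned absolute position embeddings, one row per position up to block_size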
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
self.drop = nn.Dropout(config.dropout)
self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd)
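        # untied output head (no weight sharing with tok_emb)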
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.block_size = config.block_size
        self.pad_id = getattr(config, 'pad_id', 0)
        # GPT-2-style init: normal(0, 0.02) weights, zero biases
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx, targets=None):
B, T = idx.size()
        assert T <= self.block_size, f"sequence length {T} exceeds block_size {self.block_size}"
token_embeddings = self.tok_emb(idx) # (B, T, C)
x = token_embeddings + self.pos_emb[:, :T, :]
x = self.drop(x)
for block in self.blocks:
x = block(x)
x = self.ln_f(x)
logits = self.head(x)
loss = None
        if targets is not None:
            # shift so that the logit at position t predicts the token at t+1
            logits = logits[:, :-1, :].contiguous()
            targets = targets[:, 1:].contiguous()
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=self.pad_id,  # padding contributes no gradient
            )
return logits, loss
# --------------------------
# Dataset and helpers
# --------------------------
class QADataset(Dataset):
    def __init__(self, path, sp_model, block_size=1024):
        self.examples = []
        self.block_size = block_size
        self.sp = sp_model
        # pad with the tokenizer's pad id if it has one, else fall back to 0
        self.pad_id = sp_model.pad_id() if sp_model.pad_id() >= 0 else 0
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                obj = json.loads(line)
                q = obj.get('question', '')
                a = obj.get('answer', '')
                # format: <bos> question <sep> answer <eos>, using the real
                # BOS/EOS control ids; "<sep>" is a user-defined symbol in the
                # SentencePiece model, so it encodes as a single piece
                ids = [sp_model.bos_id()] + sp_model.EncodeAsIds(q + "<sep>" + a) + [sp_model.eos_id()]
                if len(ids) > 2:
                    self.examples.append(ids)
def __len__(self):
return len(self.examples)
    def __getitem__(self, idx):
        ids = self.examples[idx]
        # truncate long sequences; right-pad short ones to block_size
        if len(ids) > self.block_size:
            ids = ids[:self.block_size]
        else:
            ids = ids + [self.pad_id] * (self.block_size - len(ids))
        return torch.tensor(ids, dtype=torch.long)
def collate_fn(batch):
    # inputs and targets are the same sequence; GPT.forward shifts them by
    # one position internally for next-token prediction
    batch = torch.stack(batch, dim=0)
    return batch, batch
# --------------------------
# Main training loop
# --------------------------
def train(args):
    # prepare the SentencePiece model (train one if it doesn't exist yet)
    if not os.path.exists(args.spm_model):
        print('Training SentencePiece model...')
        # dump all question/answer text into a temporary file for the trainer
        tmp_txt = 'spm_input.txt'
        with open(args.data_path, 'r', encoding='utf-8') as fin, open(tmp_txt, 'w', encoding='utf-8') as fout:
            for line in fin:
                obj = json.loads(line)
                fout.write(obj.get('question', '') + '\n' + obj.get('answer', '') + '\n')
        # derive the model prefix from --spm_model so the file we train is the
        # file we load; reserve a pad id and register <sep> as a single symbol
        model_prefix = os.path.splitext(args.spm_model)[0]
        spm.SentencePieceTrainer.Train(
            f'--input={tmp_txt} --model_prefix={model_prefix} '
            f'--vocab_size={args.vocab_size} --model_type=bpe '
            f'--character_coverage=0.9995 --pad_id=3 --user_defined_symbols=<sep>'
        )
        os.remove(tmp_txt)
    sp = spm.SentencePieceProcessor()
    sp.Load(args.spm_model)
dataset = QADataset(args.data_path, sp, block_size=args.block_size)
print(f"Loaded {len(dataset)} examples")
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
    # size the model's vocab from the tokenizer that was actually loaded, not
    # from the CLI flag, so the embedding table always matches
    config = GPTConfig(vocab_size=sp.GetPieceSize(), n_layer=args.n_layer, n_head=args.n_head,
                       n_embd=args.n_embd, block_size=args.block_size, dropout=args.dropout,
                       pad_id=sp.pad_id() if sp.pad_id() >= 0 else 0)
model = GPT(config).to(args.device)
# print parameter count
param_count = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {param_count:,} ({param_count/1e9:.3f} B)")
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
model.train()
for epoch in range(args.epochs):
pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")
for batch_inputs, batch_targets in pbar:
batch_inputs = batch_inputs.to(args.device)
batch_targets = batch_targets.to(args.device)
logits, loss = model(batch_inputs, targets=batch_targets)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
pbar.set_postfix(loss=loss.item())
        # save a checkpoint after each epoch (the same file is overwritten,
        # so it always holds the latest weights)
        os.makedirs(args.out_dir, exist_ok=True)
        torch.save({'model_state': model.state_dict(), 'sp_model': args.spm_model, 'config': vars(config)},
                   os.path.join(args.out_dir, 'checkpoint_final.pt'))
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--data_path', type=str, default='all.jsonl')
parser.add_argument('--spm_model', type=str, default='spm.model')
parser.add_argument('--vocab_size', type=int, default=32000)
parser.add_argument('--block_size', type=int, default=1024)
parser.add_argument('--n_layer', type=int, default=3)
parser.add_argument('--n_head', type=int, default=3)
parser.add_argument('--n_embd', type=int, default=768)
parser.add_argument('--batch_size', type=int, default=30)
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--out_dir', type=str, default='checkpoints')
args = parser.parse_args()
train(args)
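# --------------------------
# Reloading the checkpoint later (a minimal sketch; the path assumes the
# default --out_dir and the save call above):
# --------------------------
#   ckpt = torch.load('checkpoints/checkpoint_final.pt', map_location='cpu')
#   config = GPTConfig(**ckpt['config'])
#   model = GPT(config)
#   model.load_state_dict(ckpt['model_state'])
#   sp = spm.SentencePieceProcessor(); sp.Load(ckpt['sp_model'])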