# ==========================
# train.py
# ==========================
# Usage:
#   python train.py --data_path all.jsonl --spm_model spm.model
# Requirements:
#   pip install torch sentencepiece tqdm

import os
import json
import sentencepiece as spm
from argparse import ArgumentParser
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader


# --------------------------
# Simple Decoder-only Transformer (GPT-like)
# --------------------------
class GPTConfig:
    def __init__(self, vocab_size, n_layer=12, n_head=12, n_embd=768,
                 block_size=1024, dropout=0.1):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.block_size = block_size
        self.dropout = dropout


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)  # (B, T, 3*C)
        q, k, v = qkv.split(C, dim=2)
        # reshape for multi-head
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / (C // self.n_head) ** 0.5  # (B, nh, T, T)
        # causal mask
        mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
        att = att.masked_fill(mask == 0, float('-inf'))
        att = torch.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v  # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        y = self.dropout(y)
        return y


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.block_size = config.block_size
        # initialize
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.block_size
        token_embeddings = self.tok_emb(idx)  # (B, T, C)
        x = token_embeddings + self.pos_emb[:, :T, :]
        x = self.drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            # shift logits and targets for next-token prediction
            logits = logits[:, :-1, :].contiguous()
            targets = targets[:, 1:].contiguous()
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )
        return logits, loss
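
# --------------------------
# Sampling helper (sketch)
# --------------------------
# Not part of the original training flow; this is a minimal sketch of how the
# decoder-only model above could be sampled from once trained. The function
# name, its defaults, and the plain-text prompt format are illustrative
# assumptions, not part of the script's command-line interface.
@torch.no_grad()
def generate(model, sp, prompt, max_new_tokens=64, temperature=1.0, device='cpu'):
    """Autoregressively sample a continuation for a text prompt."""
    model.eval()
    ids = sp.EncodeAsIds(prompt)
    idx = torch.tensor([ids], dtype=torch.long, device=device)
    for _ in range(max_new_tokens):
        # crop the context to the model's maximum block size
        idx_cond = idx[:, -model.block_size:]
        logits, _ = model(idx_cond)              # no targets -> full-length logits
        logits = logits[:, -1, :] / temperature  # distribution over the next token
        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return sp.DecodeIds(idx[0].tolist())
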

# --------------------------
# Dataset and helpers
# --------------------------
class QADataset(Dataset):
    def __init__(self, path, sp_model, block_size=1024):
        self.examples = []
        self.block_size = block_size
        self.sp = sp_model
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                obj = json.loads(line)
                q = obj.get('question', '')
                a = obj.get('answer', '')
                # format: question followed by answer. The original separator
                # tokens appear to have been lost; a newline is used here as a
                # stand-in between the two fields.
                text = q + "\n" + a
                ids = self.sp.EncodeAsIds(text)
                if len(ids) > 2:
                    # truncate or pad later in __getitem__
                    self.examples.append(ids)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ids = self.examples[idx]
        # pad/truncate to block_size (id 0 is used for padding; note that the
        # loss does not mask padded positions)
        if len(ids) > self.block_size:
            ids = ids[:self.block_size]
        else:
            ids = ids + [0] * (self.block_size - len(ids))
        return torch.tensor(ids, dtype=torch.long)


def collate_fn(batch):
    batch = torch.stack(batch, dim=0)
    # inputs and targets are the same sequence for causal LM; the model
    # shifts them internally for next-token prediction
    return batch, batch


# --------------------------
# Main training loop
# --------------------------
def train(args):
    # prepare the SentencePiece model (train it if it does not exist yet)
    if not os.path.exists(args.spm_model):
        print('Training SentencePiece model...')
        # create a temporary file with the concatenated text
        tmp_txt = 'spm_input.txt'
        with open(args.data_path, 'r', encoding='utf-8') as fin, \
             open(tmp_txt, 'w', encoding='utf-8') as fout:
            for line in fin:
                obj = json.loads(line)
                text = obj.get('question', '') + '\n' + obj.get('answer', '') + '\n'
                fout.write(text)
        spm.SentencePieceTrainer.Train(
            f'--input={tmp_txt} --model_prefix=spm --vocab_size={args.vocab_size} '
            f'--model_type=bpe --character_coverage=0.9995'
        )
        os.remove(tmp_txt)
        # the trainer writes spm.model / spm.vocab next to the script
        sp = spm.SentencePieceProcessor()
        sp.Load('spm.model')
    else:
        sp = spm.SentencePieceProcessor()
        sp.Load(args.spm_model)

    dataset = QADataset(args.data_path, sp, block_size=args.block_size)
    print(f"Loaded {len(dataset)} examples")
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                            collate_fn=collate_fn)

    config = GPTConfig(vocab_size=args.vocab_size, n_layer=args.n_layer,
                       n_head=args.n_head, n_embd=args.n_embd,
                       block_size=args.block_size, dropout=args.dropout)
    model = GPT(config).to(args.device)

    # print parameter count
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,} ({param_count/1e9:.3f} B)")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    model.train()
    for epoch in range(args.epochs):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")
        for batch_inputs, batch_targets in pbar:
            batch_inputs = batch_inputs.to(args.device)
            batch_targets = batch_targets.to(args.device)
            logits, loss = model(batch_inputs, targets=batch_targets)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            pbar.set_postfix(loss=loss.item())

        # save a checkpoint after each epoch (overwrites the previous one)
        os.makedirs(args.out_dir, exist_ok=True)
        torch.save({'model_state': model.state_dict(),
                    'sp_model': args.spm_model,
                    'config': vars(config)},
                   os.path.join(args.out_dir, 'checkpoint_final.pt'))


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--data_path', type=str, default='all.jsonl')
    parser.add_argument('--spm_model', type=str, default='spm.model')
    parser.add_argument('--vocab_size', type=int, default=32000)
    parser.add_argument('--block_size', type=int, default=1024)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=3)
    parser.add_argument('--n_embd', type=int, default=768)
    parser.add_argument('--batch_size', type=int, default=30)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--out_dir', type=str, default='checkpoints')
    args = parser.parse_args()
    train(args)
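
# --------------------------
# Loading a checkpoint (example, not executed)
# --------------------------
# Sketch of how the checkpoint written by train() can be restored. It relies
# only on the keys saved above ('model_state', 'sp_model', 'config'); the
# path shown is the script's default out_dir and filename.
#
#   ckpt = torch.load('checkpoints/checkpoint_final.pt', map_location='cpu')
#   config = GPTConfig(**ckpt['config'])
#   model = GPT(config)
#   model.load_state_dict(ckpt['model_state'])
#   sp = spm.SentencePieceProcessor()
#   sp.Load(ckpt['sp_model'])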