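"""Minimal GPT-style language model trained on question/answer pairs.

The training data is a JSONL file with one {"question": ..., "answer": ...}
object per line. A SentencePiece BPE tokenizer is trained on the fly if the
model file given by --spm_model does not exist. Example invocation (the
script filename here is a placeholder):

    python train_qa_gpt.py --data_path all.jsonl --out_dir checkpoints
"""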
import os
import json
import sentencepiece as spm
from argparse import ArgumentParser
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader


class GPTConfig:
    """Model and training-shape hyperparameters."""

    def __init__(self, vocab_size, n_layer=12, n_head=12, n_embd=768, block_size=1024, dropout=0.1):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.block_size = block_size
        self.dropout = dropout


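# Multi-head causal self-attention: a single linear layer produces Q, K and V
# for all heads, attention scores are scaled by sqrt(head_dim), and a
# lower-triangular mask prevents positions from attending to future tokens.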
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)
        # Reshape to (B, n_head, T, head_dim).
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / (C // self.n_head) ** 0.5
        # Causal mask: position t may only attend to positions <= t.
        mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
        att = att.masked_fill(mask == 0, float('-inf'))
        att = torch.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        y = self.dropout(y)
        return y


class Block(nn.Module):
    """Transformer block: pre-norm attention and MLP sub-layers with residual connections."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


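# Decoder-only GPT: learned token and positional embeddings, a stack of
# Transformer blocks, a final LayerNorm, and a linear vocabulary head.
# When targets are given, the forward pass also returns the next-token
# cross-entropy loss.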
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.block_size = config.block_size

        # Initialise linear and embedding weights.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.block_size, "sequence length exceeds block_size"
        token_embeddings = self.tok_emb(idx)
        x = token_embeddings + self.pos_emb[:, :T, :]
        x = self.drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            # Shift so that the logits at position t predict the target at t+1.
            logits = logits[:, :-1, :].contiguous()
            targets = targets[:, 1:].contiguous()
            loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


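# Dataset of tokenized question/answer pairs. Each JSONL line is flattened to
# "<s>question<sep>answer</s>", encoded with SentencePiece, and truncated or
# right-padded to exactly block_size token ids.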
class QADataset(Dataset):
    def __init__(self, path, sp_model, block_size=1024):
        self.examples = []
        self.block_size = block_size
        self.sp = sp_model
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                obj = json.loads(line)
                q = obj.get('question', '')
                a = obj.get('answer', '')
                # Join question and answer with literal marker strings; these are
                # atomic tokens only if the SentencePiece model defines them as
                # user symbols, otherwise they are encoded as plain text.
                text = "<s>" + q + "<sep>" + a + "</s>"
                ids = self.sp.EncodeAsIds(text)
                if len(ids) > 2:
                    # Skip examples that are effectively empty.
                    self.examples.append(ids)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ids = self.examples[idx]
        # Truncate long examples and right-pad short ones with id 0.
        if len(ids) > self.block_size:
            ids = ids[:self.block_size]
        else:
            ids = ids + [0] * (self.block_size - len(ids))
        return torch.tensor(ids, dtype=torch.long)


def collate_fn(batch):
    # Inputs and targets are identical; GPT.forward does the one-token shift.
    batch = torch.stack(batch, dim=0)
    return batch, batch


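# End-to-end training: build (or load) the SentencePiece tokenizer, construct
# the dataset and model, run the AdamW training loop with gradient clipping,
# and save a final checkpoint.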
def train(args):
    # Train a SentencePiece BPE model from the raw text if none exists yet.
    if not os.path.exists(args.spm_model):
        print('Training SentencePiece model...')
        # Dump every question and answer to a temporary plain-text file.
        tmp_txt = 'spm_input.txt'
        with open(args.data_path, 'r', encoding='utf-8') as fin, open(tmp_txt, 'w', encoding='utf-8') as fout:
            for line in fin:
                obj = json.loads(line)
                text = obj.get('question', '') + '\n' + obj.get('answer', '') + '\n'
                fout.write(text)
        spm.SentencePieceTrainer.Train(
            f'--input={tmp_txt} --model_prefix=spm --vocab_size={args.vocab_size} '
            '--model_type=bpe --character_coverage=0.9995'
        )
        os.remove(tmp_txt)
        sp = spm.SentencePieceProcessor()
        sp.Load('spm.model')
    else:
        sp = spm.SentencePieceProcessor()
        sp.Load(args.spm_model)

    dataset = QADataset(args.data_path, sp, block_size=args.block_size)
    print(f"Loaded {len(dataset)} examples")
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)

    # Assumes the tokenizer's vocabulary size matches --vocab_size.
    config = GPTConfig(vocab_size=args.vocab_size, n_layer=args.n_layer, n_head=args.n_head,
                       n_embd=args.n_embd, block_size=args.block_size, dropout=args.dropout)
    model = GPT(config).to(args.device)

    # Report the total parameter count.
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,} ({param_count/1e9:.3f} B)")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    model.train()
    for epoch in range(args.epochs):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")
        for batch_inputs, batch_targets in pbar:
            batch_inputs = batch_inputs.to(args.device)
            batch_targets = batch_targets.to(args.device)
            logits, loss = model(batch_inputs, targets=batch_targets)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            pbar.set_postfix(loss=loss.item())

    # Save the final checkpoint alongside the tokenizer path and the config.
    os.makedirs(args.out_dir, exist_ok=True)
    torch.save({'model_state': model.state_dict(), 'sp_model': args.spm_model, 'config': vars(config)},
               os.path.join(args.out_dir, 'checkpoint_final.pt'))


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--data_path', type=str, default='all.jsonl')
    parser.add_argument('--spm_model', type=str, default='spm.model')
    parser.add_argument('--vocab_size', type=int, default=32000)
    parser.add_argument('--block_size', type=int, default=1024)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=3)
    parser.add_argument('--n_embd', type=int, default=768)
    parser.add_argument('--batch_size', type=int, default=30)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--out_dir', type=str, default='checkpoints')
    args = parser.parse_args()
    train(args)