| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, IterableDataset |
| from torch.amp import autocast, GradScaler |
| from model import TransformerConfig, TransformerLanguageModel |
| from tokenizer import load_tokenizer, SpecialToken |
| import json |
| import random |
| import os |
| from tqdm import tqdm |
|
|
| |
def prepare_tokenizer():
    """Extend the base tokenizer with chat-template tokens and save it.

    Loads ``tokenizer.model`` and registers ``<|im_start|>`` (id 50304) and
    ``<|im_end|>`` (id 50305) as special tokens. Each new id that has no entry
    in ``tok.vocab`` gets a placeholder byte string. The function then
    recomputes ``tok.vocab_size`` so it covers every id and saves the result
    via ``tok.save("tokenizer_sft")``.

    Returns:
        The updated tokenizer object.
    """
    tok = load_tokenizer("tokenizer.model")

    new_tokens = {
        SpecialToken("<|im_start|>"): 50304,
        SpecialToken("<|im_end|>"): 50305,
    }
    tok.special_tokens.update(new_tokens)
    # Rebuild the id -> token reverse map so it includes the new entries.
    tok.special_tokens_inv = {v: k for k, v in tok.special_tokens.items()}

    # Give each new id a placeholder byte string if the vocab has nothing for it.
    for st, sid in new_tokens.items():
        if sid not in tok.vocab:
            tok.vocab[sid] = f"<{st.name}>".encode("utf-8", errors="replace")

    # Computed only AFTER the new ids are registered. Special-token ids are
    # included too, so vocab_size really is (largest id + 1). Computing it
    # earlier could leave the new ids outside the reported vocabulary.
    all_ids = [*tok.vocab.keys(), *tok.special_tokens.values()]
    tok.vocab_size = max(all_ids) + 1

    tok.save("tokenizer_sft")
    print(f"Saved tokenizer_sft.model with vocab_size={tok.vocab_size}")
    return tok
|
|
|
|
| |
def expand_model_for_new_tokens(
    old_ckpt_path,
    new_vocab_size,
    config,
    old_vocab_size=50304,
    old_block_size=1024,
):
    """Build a model with a larger vocabulary/context and warm-start it from a checkpoint.

    The checkpoint at ``old_ckpt_path`` is first loaded into a model with the
    old geometry (``old_vocab_size`` x ``old_block_size``). A fresh model is
    then built with ``new_vocab_size`` and ``config.block_size``, and tensors
    are transferred as follows:

    - Tensors with the same shape in both models are copied verbatim.
    - For the token embedding table, ``lm_head`` and the position embedding
      table, when only the leading dimension changed, the overlapping leading
      rows are copied. The remaining rows keep the fresh model's
      initialization.
    - Keys containing ``"mask"`` are left as the new model built them.

    Args:
        old_ckpt_path: path to a ``state_dict`` saved with ``torch.save``.
        new_vocab_size: vocabulary size of the returned model.
        config: ``TransformerConfig`` that supplies ``n_embed``, ``n_heads``,
            ``n_layers``, ``dropout``, ``bias`` and the new ``block_size``.
        old_vocab_size: vocabulary size the checkpoint was trained with.
        old_block_size: context length the checkpoint was trained with.

    Returns:
        The new ``TransformerLanguageModel``. Nothing here moves it off the
        default device.
    """

    def _build(vocab_size, block_size):
        # Same architecture as ``config``; only the table sizes differ.
        return TransformerLanguageModel(
            TransformerConfig(
                vocab_size=vocab_size,
                block_size=block_size,
                n_embed=config.n_embed,
                n_heads=config.n_heads,
                n_layers=config.n_layers,
                dropout=config.dropout,
                bias=config.bias,
            )
        )

    old_model = _build(old_vocab_size, old_block_size)
    # The file is fed straight into load_state_dict, so it is a plain
    # state_dict; the tensor-only loader is sufficient and avoids unpickling code.
    old_model.load_state_dict(
        torch.load(old_ckpt_path, map_location="cpu", weights_only=True)
    )

    new_model = _build(new_vocab_size, config.block_size)

    new_state = new_model.state_dict()
    old_state = old_model.state_dict()
    row_tables = ("token_embedding_table", "lm_head", "position_embedding_table")

    for key, new_tensor in new_state.items():
        if key not in old_state:
            print(f"Key {key} not in old model, initialized randomly.")
            continue
        old_tensor = old_state[key]
        if new_tensor.shape == old_tensor.shape:
            new_tensor.copy_(old_tensor)
            continue

        print(f"Expanding {key}: {old_tensor.shape} -> {new_tensor.shape}")
        if any(name in key for name in row_tables) and new_tensor.shape[1:] == old_tensor.shape[1:]:
            # Copy the overlapping leading rows. Using min() also handles a
            # table that shrinks, where slicing by the old size would fail.
            rows = min(new_tensor.size(0), old_tensor.size(0))
            new_tensor[:rows].copy_(old_tensor[:rows])
        elif "mask" in key:
            # Size-dependent buffer: keep the one the new model built.
            pass
        else:
            print(f"Warning: unexpected shape mismatch for {key}")

    for key in sorted(old_state.keys() - new_state.keys()):
        print(f"Warning: checkpoint key {key} has no counterpart in the new model; dropped.")

    # state_dict() tensors alias the module's parameters/buffers, so the copies
    # above already landed; the explicit reload is kept as a final step.
    new_model.load_state_dict(new_state)
    return new_model
|
|
|
|
| |
class SFTDataset(IterableDataset):
    """Infinite stream of ``(x, y, loss_mask)`` examples drawn from a chat JSONL file.

    Each non-empty line of ``data_file`` is a JSON object with a
    ``"messages"`` list of ``{"role": ..., "content": ...}`` dicts. Each
    message is rendered as ``<|im_start|>``, the role and a newline, the
    content, then ``<|im_end|>`` and a newline. One ``<|endoftext|>`` token
    ends the conversation.

    Every yielded item has ``block_size`` positions:

    - ``x``: input ids.
    - ``y``: targets, i.e. ``x`` shifted by one.
    - ``loss_mask`` (float32), chosen as follows:
      - With probability ``mask_prob``, it is 1 only where the target token
        belongs to an assistant message or is the final EOS.
      - Otherwise it is 1 on every real (non-padding) position.
    - Padding positions always have mask 0.
    """

    def __init__(self, data_file, tokenizer, block_size=2048, mask_prob=0.8):
        self.tokenizer = tokenizer
        self.block_size = block_size
        # Probability of using the assistant-only loss mask for a drawn example.
        self.mask_prob = mask_prob
        self.eos_id = tokenizer.special_tokens[SpecialToken("<|endoftext|>")]

        # Pre-tokenize the whole file once; __iter__ samples from this list.
        self.samples = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                item = json.loads(line)
                tokens, mask = self._encode_messages(item["messages"])
                # The trailing EOS is always marked trainable. It is ignored here
                # so that conversations without any assistant turn are dropped.
                if tokens and any(mask[:-1]):
                    self.samples.append((tokens, mask))

        print(f"Loaded {len(self.samples)} valid SFT samples.")

    def _encode_messages(self, messages):
        """Tokenize a conversation and build its per-token assistant mask.

        Returns:
            ``(token_ids, loss_mask)`` as two equal-length lists.
            ``loss_mask[i]`` is 1 when ``token_ids[i]`` belongs to an assistant
            message (its header and delimiters included) or is the final EOS,
            and 0 otherwise.
        """
        token_ids = []
        loss_mask = []

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            prefix = self.tokenizer.encode_all([
                SpecialToken("<|im_start|>"),
                f"{role}\n",
            ])
            content_ids = self.tokenizer.encode(content)
            suffix = self.tokenizer.encode_all([
                SpecialToken("<|im_end|>"),
                "\n",
            ])

            msg_tokens = prefix + content_ids + suffix
            msg_mask = [1 if role == "assistant" else 0] * len(msg_tokens)

            token_ids.extend(msg_tokens)
            loss_mask.extend(msg_mask)

        # End-of-conversation marker; kept trainable so the model learns to stop.
        token_ids.append(self.eos_id)
        loss_mask.append(1)

        return token_ids, loss_mask

    def __iter__(self):
        """Yield randomly drawn, truncated and padded examples forever.

        Raises:
            ValueError: if the dataset holds no samples.
        """
        if not self.samples:
            raise ValueError("SFTDataset has no samples to draw from.")

        while True:
            tokens, assistant_mask = random.choice(self.samples)

            # block_size + 1 tokens give block_size inputs AND block_size shifted targets.
            max_len = self.block_size + 1
            tokens = tokens[:max_len]
            assistant_mask = assistant_mask[:max_len]

            x = tokens[:-1]
            y = tokens[1:]
            # Position i predicts y[i] == tokens[i + 1], so weight it by whether
            # that TARGET token is trainable (hence [1:], not [:-1]).
            target_mask = assistant_mask[1:]
            n_real = len(x)

            if random.random() < self.mask_prob:
                loss_mask = target_mask
            else:
                # Train on every real token, but never on padding added below.
                loss_mask = [1] * n_real

            pad_len = self.block_size - n_real
            if pad_len > 0:
                x = x + [self.eos_id] * pad_len
                y = y + [self.eos_id] * pad_len
                loss_mask = loss_mask + [0] * pad_len

            yield (
                torch.tensor(x, dtype=torch.int64),
                torch.tensor(y, dtype=torch.int64),
                torch.tensor(loss_mask, dtype=torch.float32),
            )
|
|
|
|
| |
@torch.no_grad()
def gen_text(model, tokenizer, text, device="cuda:0", max_new_tokens=200):
    """Generate a reply to one user message using the chat template.

    The prompt is built as a user turn containing ``text``, followed by an
    opened assistant turn. The model is switched to eval mode for generation.
    Its previous train/eval mode is restored afterwards, even if generation
    raises.

    Args:
        model: model exposing ``generate(ids, max_new_tokens=...)``.
        tokenizer: tokenizer exposing ``encode_all`` and ``decode``.
        text: user message content.
        device: device the prompt tensor is created on.
        max_new_tokens: generation budget passed to ``model.generate``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the full output sequence.
        This includes the prompt if ``generate`` returns it, which cannot be
        confirmed from here.
    """
    prompt_ids = tokenizer.encode_all([
        SpecialToken("<|im_start|>"),
        "user\n",
        text,
        SpecialToken("<|im_end|>"),
        "\n",
        SpecialToken("<|im_start|>"),
        "assistant\n",
    ])
    ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device).view(1, -1)

    was_training = model.training
    model.eval()
    try:
        output_ids = model.generate(ids, max_new_tokens=max_new_tokens)[0, :]
    finally:
        # Restore the caller's mode even if generation fails, so a training loop
        # that catches the error does not silently continue in eval mode.
        model.train(was_training)
    return tokenizer.decode(output_ids.tolist())
|
|
|
|
| |
def train():
    """Supervised fine-tuning (SFT) entry point.

    Steps:

    1. Loads the extended SFT tokenizer, or creates it on first run.
    2. Warm-starts a 12-layer model from ``checkpoints/new/150000.pt`` with an
       enlarged vocabulary.
    3. Trains on ``data/novels_sft_dataset.jsonl`` using fp16 autocast,
       gradient accumulation and gradient-norm clipping.

    Exactly ``max_iters`` optimizer steps are run. A generation sample is
    printed at step 1 and every ``eval_interval`` steps. A checkpoint is saved
    every ``save_interval`` steps and at the last step, and a final
    ``sft_final.pt`` is written at the end.
    """
    device = "cuda:0"
    block_size = 2048
    # Must cover the chat-token ids assigned in prepare_tokenizer (50304, 50305).
    new_vocab_size = 50306
    batch_size = 4
    gradient_accumulation_steps = 4
    learning_rate = 1e-5
    max_iters = 2000
    save_interval = 200
    eval_interval = 50

    # Reuse the extended tokenizer if an earlier run already saved it.
    if not os.path.exists("tokenizer_sft.model"):
        tokenizer = prepare_tokenizer()
    else:
        tokenizer = load_tokenizer("tokenizer_sft.model")
        print(f"Loaded tokenizer_sft.model with vocab_size={tokenizer.vocab_size}")

    config = TransformerConfig(
        vocab_size=new_vocab_size,
        block_size=block_size,
        n_embed=768,
        n_heads=12,
        n_layers=12,
        dropout=0.0,
        bias=True,
    )

    print("Expanding model vocab and loading checkpoint 150000.pt...")
    model = expand_model_for_new_tokens("checkpoints/new/150000.pt", new_vocab_size, config)
    model = model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded. Total parameters: {total_params / 1e6:.2f}M")

    dataset = SFTDataset(
        "data/novels_sft_dataset.jsonl",
        tokenizer,
        block_size=block_size,
        mask_prob=0.8,
    )
    # The dataset yields forever, so one iterator serves the whole run.
    data_iter = iter(DataLoader(dataset, batch_size=batch_size))

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # fp16 autocast needs loss scaling to keep small gradients from underflowing.
    scaler = GradScaler("cuda")

    os.makedirs("checkpoints/sft", exist_ok=True)

    model.train()
    with tqdm(total=max_iters, desc="SFT Training") as pbar:
        # 1-based so the loop performs exactly max_iters optimizer steps,
        # matching the progress-bar total.
        for iter_num in range(1, max_iters + 1):
            optimizer.zero_grad(set_to_none=True)
            accum_loss = 0.0

            for _ in range(gradient_accumulation_steps):
                x, y, mask = (t.to(device) for t in next(data_iter))

                with autocast("cuda", dtype=torch.float16):
                    logits, _ = model(x, device=device)
                    logits = logits.view(-1, config.vocab_size)
                    mask_flat = mask.view(-1)
                    # Per-token loss, averaged over the positions the mask selects.
                    token_loss = F.cross_entropy(logits, y.view(-1), reduction="none")
                    loss = (token_loss * mask_flat).sum() / (mask_flat.sum() + 1e-8)
                    # Scale so the accumulated gradient is the mean over micro-batches.
                    loss = loss / gradient_accumulation_steps

                scaler.scale(loss).backward()
                accum_loss += loss.item()

            # Unscale first so clipping applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            pbar.update(1)
            pbar.set_postfix(loss=f"{accum_loss:.4f}")

            if iter_num == 1 or iter_num % eval_interval == 0:
                print(f"\n[Step {iter_num}] Loss: {accum_loss:.4f}")
                try:
                    decoded = gen_text(model, tokenizer, "写一个恋爱喜剧轻小说,主角是能听到物品心声的高中生。", device=device)
                    # decode()'s return type is not visible here. Keep only str
                    # pieces; this is a no-op when it returns a plain str.
                    text_out = "".join(piece for piece in decoded if isinstance(piece, str))
                    print(f"Sample output: {text_out[:200]}...")
                except Exception as e:  # the preview is best-effort; never abort training
                    print(f"Generation error: {e}")
                # Guarantee train mode even if generation failed midway.
                model.train()

            if iter_num % save_interval == 0 or iter_num == max_iters:
                ckpt_path = f"checkpoints/sft/sft_{iter_num}.pt"
                torch.save(model.state_dict(), ckpt_path)
                print(f"\nSaved checkpoint: {ckpt_path}")

    final_path = "checkpoints/sft/sft_final.pt"
    torch.save(model.state_dict(), final_path)
    print(f"Training complete. Final model saved to {final_path}")
|
|
|
|
if __name__ == "__main__":
    # Start SFT only when this file is executed directly, never on import.
    train()
|
|