"""Supervised fine-tuning (SFT) script for a chat-formatted transformer LM.

Pipeline: extend the tokenizer with two chat special tokens, grow the
pretrained model's vocabulary-sized and position-sized tensors, stream padded
SFT samples from a JSONL file, and train with fp16 autocast and gradient
accumulation.
"""

import json
import os
import random

import torch
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from model import TransformerConfig, TransformerLanguageModel
from tokenizer import load_tokenizer, SpecialToken


# ============== 1. Prepare the tokenizer ==============
def prepare_tokenizer():
    """Load ``tokenizer.model``, register the chat special tokens and save it.

    Returns:
        The updated tokenizer object (also saved via ``tok.save("tokenizer_sft")``).
    """
    tok = load_tokenizer("tokenizer.model")

    # Register the new special tokens and rebuild the inverse lookup.
    new_tokens = {
        SpecialToken("<|im_start|>"): 50304,
        SpecialToken("<|im_end|>"): 50305,
    }
    tok.special_tokens.update(new_tokens)
    tok.special_tokens_inv = {v: k for k, v in tok.special_tokens.items()}

    # Make sure every new id has a placeholder entry in the vocab.
    for st, sid in new_tokens.items():
        if sid not in tok.vocab:
            tok.vocab[sid] = f"<{st.name}>".encode("utf-8", errors="replace")

    # vocab_size is max id + 1. It must be computed AFTER the placeholders are
    # inserted, otherwise the freshly added ids are not counted.
    tok.vocab_size = max(tok.vocab) + 1

    tok.save("tokenizer_sft")
    print(f"Saved tokenizer_sft.model with vocab_size={tok.vocab_size}")
    return tok


# ============== 2. Expand the model vocabulary ==============
# Parameters whose first dimension may grow; old rows are copied and the
# remaining rows keep the new model's fresh initialization.
_ROW_EXPANDABLE_KEYS = (
    "token_embedding_table",
    "lm_head",
    "position_embedding_table",
)


def expand_model_for_new_tokens(
    old_ckpt_path,
    new_vocab_size,
    config,
    old_vocab_size=50304,
    old_block_size=1024,
):
    """Build a model with ``new_vocab_size`` and copy pretrained weights into it.

    Args:
        old_ckpt_path: Path to the state dict of the pretrained model.
        new_vocab_size: Vocabulary size of the model to create.
        config: ``TransformerConfig`` supplying block size and architecture
            hyperparameters for the new model (architecture fields are reused
            for the old model as well).
        old_vocab_size: Vocabulary size the checkpoint was trained with.
        old_block_size: Context length the checkpoint was trained with.

    Returns:
        A ``TransformerLanguageModel`` whose same-shaped tensors equal the
        checkpoint's, and whose embedding / head / position tables have their
        leading rows copied from the checkpoint.
    """
    old_config = TransformerConfig(
        vocab_size=old_vocab_size,
        block_size=old_block_size,
        n_embed=config.n_embed,
        n_heads=config.n_heads,
        n_layers=config.n_layers,
        dropout=config.dropout,
        bias=config.bias,
    )
    old_model = TransformerLanguageModel(old_config)
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True if the checkpoint allows it).
    old_model.load_state_dict(torch.load(old_ckpt_path, map_location="cpu"))

    new_config = TransformerConfig(
        vocab_size=new_vocab_size,
        block_size=config.block_size,
        n_embed=config.n_embed,
        n_heads=config.n_heads,
        n_layers=config.n_layers,
        dropout=config.dropout,
        bias=config.bias,
    )
    new_model = TransformerLanguageModel(new_config)

    new_state = new_model.state_dict()
    old_state = old_model.state_dict()

    with torch.no_grad():
        for key, new_tensor in new_state.items():
            if key not in old_state:
                print(f"Key {key} not in old model, initialized randomly.")
                continue

            old_tensor = old_state[key]
            if new_tensor.shape == old_tensor.shape:
                new_tensor.copy_(old_tensor)
                continue

            print(f"Expanding {key}: {old_tensor.shape} -> {new_tensor.shape}")
            if any(name in key for name in _ROW_EXPANDABLE_KEYS):
                # Copy the overlapping leading rows; any extra rows in the new
                # tensor keep their random initialization.
                rows = min(old_tensor.size(0), new_tensor.size(0))
                new_tensor[:rows].copy_(old_tensor[:rows])
            elif "mask" in key:
                # The mask is a buffer that the new model already created at
                # the right size, so there is nothing to copy.
                pass
            else:
                print(f"Warning: unexpected shape mismatch for {key}")

    new_model.load_state_dict(new_state)
    return new_model


# ============== 3. SFT dataset ==============
class SFTDataset(IterableDataset):
    """Endless stream of randomly drawn, padded SFT samples.

    Each yielded item is ``(x, y, mask)``: ``x`` and ``y`` are int64 tensors of
    length ``block_size`` (inputs and next-token targets) and ``mask`` is a
    float32 tensor of per-target loss weights.
    """

    def __init__(self, data_file, tokenizer, block_size=2048, mask_prob=0.8):
        """Read and encode every JSONL record of ``data_file``.

        Args:
            data_file: Path to a JSONL file; each non-empty line is a JSON
                object with a ``"messages"`` list.
            tokenizer: Tokenizer providing ``encode``, ``encode_all`` and
                ``special_tokens``.
            block_size: Length of the yielded tensors.
            mask_prob: Probability of using the assistant-only loss mask;
                otherwise the loss covers every non-padding target.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.mask_prob = mask_prob
        self.eos_id = tokenizer.special_tokens[SpecialToken("<|endoftext|>")]

        # Pre-load and encode all data. Sequences are truncated once here to
        # block_size + 1 tokens (room for both the shifted x and y).
        max_len = block_size + 1
        self.samples = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                item = json.loads(line)
                tokens, mask = self._encode_messages(item["messages"])
                # Keep only samples with at least one assistant-flagged target
                # inside the training window. Index 0 is never a target and
                # the trailing EOS is always flagged, so both are excluded
                # from the check.
                if not any(mask[1:-1][:block_size]):
                    continue
                self.samples.append((tokens[:max_len], mask[:max_len]))
        print(f"Loaded {len(self.samples)} valid SFT samples.")

    def _encode_messages(self, messages):
        """Encode a conversation into token ids and a per-token loss flag list.

        Every message is rendered as ``<|im_start|>{role}\\n{content}<|im_end|>\\n``;
        all tokens of an ``assistant`` message are flagged 1, others 0. A final
        EOS token flagged 1 is appended.

        Returns:
            ``(token_ids, loss_mask)`` as two lists of equal length.
        """
        # The closing tokens are identical for every message; encode them once.
        suffix = self.tokenizer.encode_all([
            SpecialToken("<|im_end|>"),
            "\n",
        ])

        token_ids = []
        loss_mask = []
        for msg in messages:
            role = msg["role"]
            prefix = self.tokenizer.encode_all([
                SpecialToken("<|im_start|>"),
                f"{role}\n",
            ])
            msg_tokens = prefix + self.tokenizer.encode(msg["content"]) + suffix
            flag = 1 if role == "assistant" else 0

            token_ids.extend(msg_tokens)
            loss_mask.extend([flag] * len(msg_tokens))

        # Append EOS and mark it as a training target.
        token_ids.append(self.eos_id)
        loss_mask.append(1)

        return token_ids, loss_mask

    def __iter__(self):
        """Yield ``(x, y, mask)`` tensors forever, sampling with replacement.

        Raises:
            ValueError: If no valid sample was loaded.
        """
        if not self.samples:
            raise ValueError("SFTDataset has no valid samples to iterate over.")

        while True:
            tokens, token_flags = random.choice(self.samples)

            x = tokens[:-1]
            y = tokens[1:]
            # The loss is computed against y, so the mask must be shifted the
            # same way as y (flag of the token being predicted).
            mask = token_flags[1:]

            # Pad up to block_size; padded positions never contribute to loss.
            valid_len = len(x)
            pad_len = max(self.block_size - valid_len, 0)
            if pad_len > 0:
                padding = [self.eos_id] * pad_len
                x = x + padding
                y = y + padding
                mask = mask + [0] * pad_len

            # mask_prob of the time train on assistant tokens only; otherwise
            # train on every real (non-padding) target.
            if random.random() < self.mask_prob:
                final_mask = mask
            else:
                final_mask = [1] * valid_len + [0] * pad_len

            yield (
                torch.tensor(x, dtype=torch.int64),
                torch.tensor(y, dtype=torch.int64),
                torch.tensor(final_mask, dtype=torch.float32),
            )


# ============== 4. Text generation smoke test ==============
@torch.no_grad()
def gen_text(model, tokenizer, text, device="cuda:0", max_new_tokens=200):
    """Generate a reply to ``text`` formatted as a single user turn.

    The model is switched to eval mode for generation and its previous
    train/eval mode is restored afterwards, even if generation raises.

    Args:
        model: Model exposing ``generate(ids, max_new_tokens=...)``.
        tokenizer: Tokenizer exposing ``encode_all`` and ``decode``.
        text: User message content.
        device: Device the prompt tensor is moved to.
        max_new_tokens: Forwarded to ``model.generate``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the first generated sequence
        (prompt included).
    """
    was_training = model.training
    model.eval()
    try:
        prompt = tokenizer.encode_all([
            SpecialToken("<|im_start|>"),
            "user\n",
            text,
            SpecialToken("<|im_end|>"),
            "\n",
            SpecialToken("<|im_start|>"),
            "assistant\n",
        ])
        ids = torch.tensor(prompt, dtype=torch.int64).to(device).view(1, -1)
        output_ids = model.generate(ids, max_new_tokens=max_new_tokens)[0, :]
        return tokenizer.decode(output_ids.tolist())
    finally:
        # Restore the caller's mode so a failed generation cannot leave the
        # model stuck in eval mode during training.
        model.train(was_training)


# ============== 5. Training ==============
def train():
    """Run the SFT loop and write checkpoints under ``checkpoints/sft``."""
    device = "cuda:0"
    block_size = 2048
    new_vocab_size = 50306
    batch_size = 4
    gradient_accumulation_steps = 4
    learning_rate = 1e-5
    max_iters = 2000
    save_interval = 200
    eval_interval = 50

    # Prepare the tokenizer (build it on first run, reuse it afterwards).
    if not os.path.exists("tokenizer_sft.model"):
        tokenizer = prepare_tokenizer()
    else:
        tokenizer = load_tokenizer("tokenizer_sft.model")
        print(f"Loaded tokenizer_sft.model with vocab_size={tokenizer.vocab_size}")

    # Model configuration.
    config = TransformerConfig(
        vocab_size=new_vocab_size,
        block_size=block_size,
        n_embed=768,
        n_heads=12,
        n_layers=12,
        dropout=0.0,
        bias=True,
    )

    # Expand the vocabulary and load the pretrained weights.
    print("Expanding model vocab and loading checkpoint 150000.pt...")
    model = expand_model_for_new_tokens("checkpoints/new/150000.pt", new_vocab_size, config)
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded. Total parameters: {total_params / 1e6:.2f}M")

    # Dataset (infinite stream, so the iterator never raises StopIteration).
    dataset = SFTDataset(
        "data/novels_sft_dataset.jsonl",
        tokenizer,
        block_size=block_size,
        mask_prob=0.8,
    )
    loader = DataLoader(dataset, batch_size=batch_size)
    data_iter = iter(loader)

    # Optimizer.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # AMP mixed precision: fp16 autocast with dynamic loss scaling.
    scaler = GradScaler("cuda")

    os.makedirs("checkpoints/sft", exist_ok=True)

    model.train()
    # The loop runs max_iters + 1 steps (0..max_iters inclusive).
    pbar = tqdm(total=max_iters + 1, desc="SFT Training")

    for iter_num in range(max_iters + 1):
        optimizer.zero_grad(set_to_none=True)
        accum_loss = 0.0

        for _ in range(gradient_accumulation_steps):
            x, y, mask = next(data_iter)
            x = x.to(device)
            y = y.to(device)
            mask = mask.to(device)

            with autocast("cuda", dtype=torch.float16):
                logits, _ = model(x, device=device)
                logits = logits.view(-1, config.vocab_size)
                mask_flat = mask.view(-1)

                token_loss = F.cross_entropy(logits, y.view(-1), reduction="none")
                # Masked mean; the epsilon guards against an all-zero mask.
                loss = (token_loss * mask_flat).sum() / (mask_flat.sum() + 1e-8)
                loss = loss / gradient_accumulation_steps

            scaler.scale(loss).backward()
            accum_loss += loss.item()

        # Gradient clipping (gradients must be unscaled first).
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        pbar.update(1)
        pbar.set_postfix(loss=f"{accum_loss:.4f}")

        if iter_num % eval_interval == 0:
            print(f"\n[Step {iter_num}] Loss: {accum_loss:.4f}")
            # Generation is a best-effort smoke test; a failure here must not
            # abort training.
            try:
                decoded = gen_text(model, tokenizer, "写一个恋爱喜剧轻小说,主角是能听到物品心声的高中生。", device=device)
                # Keep only the plain-string pieces of the decoded output
                # (presumably decode may also yield SpecialToken items —
                # confirm against the tokenizer).
                text_out = "".join(tok for tok in decoded if isinstance(tok, str))
                print(f"Sample output: {text_out[:200]}...")
            except Exception as e:
                print(f"Generation error: {e}")

        if iter_num > 0 and (iter_num % save_interval == 0 or iter_num == max_iters):
            ckpt_path = f"checkpoints/sft/sft_{iter_num}.pt"
            torch.save(model.state_dict(), ckpt_path)
            print(f"\nSaved checkpoint: {ckpt_path}")

    pbar.close()
    final_path = "checkpoints/sft/sft_final.pt"
    torch.save(model.state_dict(), final_path)
    print(f"Training complete. Final model saved to {final_path}")


if __name__ == "__main__":
    train()