"""Train a small transformer language model (MemoryOptimizedBigramLM) on
keyword-to-poem data, with an auxiliary weighted loss on the <END> stop token
so that generations learn to terminate cleanly."""

import os
import json
import math
import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
import matplotlib.pyplot as plt

from model_optimized import MemoryOptimizedBigramLM

# Hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 8
num_iter = 10000          # total training iterations
eval_interval = 500       # evaluate every N iterations
eval_iters = 200          # batches per evaluation pass
d_model = 512             # embedding dimension
h = 8                     # attention heads
Nx = 6                    # transformer blocks
dropout_rate = 0.2
lr_rate = 1e-4
max_seq_len = 2048

# Stop-token training: extra weight on <END> positions in the loss
stop_token_weight = 1
enable_stop_training = True

model_save_dir = "saved_models"
os.makedirs(model_save_dir, exist_ok=True)
torch.manual_seed(1337)

# Tokenizer
sp = spm.SentencePieceProcessor()
sp.load("tokenizer.model")


def encode(s):
    return sp.encode(s, out_type=int)


def decode(tokens):
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.tolist()
    text = sp.decode(tokens)
    # Truncate at the first <END> marker so generations stop cleanly.
    if "<END>" in text:
        text = text.split("<END>")[0]
    return text.strip()
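# Illustrative example (the token ids are hypothetical): if sp.decode(tokens)
# yields "白日依山尽 <END> trailing noise", then decode(tokens) returns
# "白日依山尽", since everything from the first "<END>" onwards is dropped.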

vocab_size = sp.get_piece_size()
print(f"Vocabulary size: {vocab_size}")

# <END> must exist as a single piece in the tokenizer. Note that in the Python
# API piece_to_id returns the <unk> id (not -1) for unknown pieces, so we
# verify with an id -> piece round-trip instead of comparing against -1.
end_token_id = sp.piece_to_id("<END>")
if sp.id_to_piece(end_token_id) != "<END>":
    raise ValueError("Tokenizer has no <END> token; add it via user_defined_symbols and retrain the tokenizer")


def detect_pad_id(sp):
    """Return the id of a known pad piece, or fall back to id 1."""
    for cand in ["<pad>", "<PAD>", "<blank>", "[PAD]"]:
        pid = sp.piece_to_id(cand)
        if sp.id_to_piece(pid) == cand:
            return pid
    return 1


pad_id = detect_pad_id(sp)
print(f"<END> id={end_token_id}, pad_id={pad_id}")

# Data loading
all_lines = []
data_path = "data.txt"
if not os.path.exists(data_path):
    raise FileNotFoundError(f"{data_path} not found")

# Instruction prefix prepended to every sample. Kept in Chinese because it must
# match the prompts used at inference time; it reads "Based on the given
# keywords, compose a poem."
instruction_prompt = "请根据给定的关键词,创作一首诗句。\n"
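# Expected data.txt layout (field names are illustrative, taken from the test
# prompts at the bottom of this file): one sample per block, blank lines
# ignored, each block terminated by a line that is exactly "<END>":
#
#   关键词: 风 雾 寂寞
#   诗词: ...
#   <END>
#   关键词: 信 天涯 晚风
#   诗词: ...
#   <END>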

with open(data_path, 'r', encoding='utf-8') as f:
    buf = []
    for raw in f:
        line = raw.rstrip("\n")
        if line.strip() == "":
            continue
        buf.append(line.strip())

        if line.strip() == "<END>":
            # End of one sample: prepend the instruction and tokenize.
            sample_text = instruction_prompt + "\n".join(buf)
            tokens = encode(sample_text)
            if not tokens:
                buf = []
                continue
            # Guarantee the sample ends with the <END> token id.
            if tokens[-1] != end_token_id:
                tokens = tokens + [end_token_id]
            if len(tokens) <= max_seq_len:
                all_lines.append(tokens)
            buf = []

if buf:
    print("Warning: trailing sample at end of file not terminated by <END> (ignored):", buf[:2])

print(f"Parsed samples: {len(all_lines)}")
if len(all_lines) == 0:
    raise RuntimeError("Parsed 0 samples; check the data.txt format (each block must end with a standalone <END> line)")

# 90/10 train/validation split
split_90perc = int(0.9 * len(all_lines))
train_lines = all_lines[:split_90perc]
valid_lines = all_lines[split_90perc:]
print(f"Train samples: {len(train_lines)}, validation samples: {len(valid_lines)}")


def analyze_stop_token_distribution(n=200):
    """Report how many of the first n training samples contain <END> and where."""
    print("\nAnalyzing <END> distribution:")
    sample_cnt = min(n, len(train_lines))
    end_positions = []
    found = 0
    for tokens in train_lines[:sample_cnt]:
        if end_token_id in tokens:
            pos = tokens.index(end_token_id)
            end_positions.append(pos)
            found += 1
    if found:
        print(f"Samples containing <END> among the first {sample_cnt}: {found}/{sample_cnt}")
        print(f"Mean <END> position: {np.mean(end_positions):.1f}")
    else:
        print("No <END> found; check the tokenizer and that samples contain a literal '<END>' line.")


analyze_stop_token_distribution()


class EnhancedLoss(nn.Module):
    """Cross-entropy with an extra weighted term on <END> target positions."""

    def __init__(self, stop_token_id, stop_weight=2.0, pad_id=1):
        super().__init__()
        self.stop_token_id = stop_token_id
        self.stop_weight = stop_weight
        self.pad_id = pad_id

    def forward(self, logits, targets):
        """
        logits: (B, T, V)
        targets: (B, T)
        Returns: total_loss (Tensor), standard_loss (float), stop_loss (float)
        """
        B, T, V = logits.shape
        logits_flat = logits.view(-1, V)
        targets_flat = targets.view(-1)

        # Per-token loss; padding positions contribute zero via ignore_index.
        loss_flat = F.cross_entropy(logits_flat, targets_flat, ignore_index=self.pad_id, reduction='none')
        mask = (targets_flat != self.pad_id).float()

        # Mean over non-pad tokens only.
        denom = mask.sum().clamp(min=1.0)
        standard_loss = (loss_flat * mask).sum() / denom

        # Extra penalty on positions where the target is <END>.
        stop_mask = (targets_flat == self.stop_token_id)
        stop_loss = torch.tensor(0.0, device=standard_loss.device)
        if enable_stop_training and stop_mask.any():
            stop_loss_vals = loss_flat[stop_mask]
            if stop_loss_vals.numel() > 0:
                stop_loss = stop_loss_vals.mean()
            total_loss = standard_loss + self.stop_weight * stop_loss
            return total_loss, standard_loss.item(), float(stop_loss.item())

        return standard_loss, float(standard_loss.item()), 0.0
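# Usage sketch (shapes are illustrative): given logits of shape (B, T, V) and
# targets of shape (B, T),
#   criterion = EnhancedLoss(end_token_id, stop_weight=2.0, pad_id=pad_id)
#   total, std, stop = criterion(logits, targets)
# total = std + stop_weight * (mean CE over positions whose target is <END>),
# so with stop_weight=2.0 the model is pushed harder to predict <END> exactly
# where samples terminate.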


def get_batch(split, batch_size_override=None):
    """Sample a batch of variable-length lines and right-pad to the batch max."""
    current_batch_size = batch_size_override if batch_size_override else batch_size
    dataset = train_lines if split == "train" else valid_lines

    idxs = np.random.randint(0, len(dataset), current_batch_size)
    batch_lines = [dataset[i] for i in idxs]

    # Next-token prediction: x is the sequence, y is the sequence shifted by one.
    x = [torch.tensor(l[:-1], dtype=torch.long) for l in batch_lines]
    y = [torch.tensor(l[1:], dtype=torch.long) for l in batch_lines]
    max_len = max(len(xx) for xx in x)

    x = torch.stack([F.pad(xx, (0, max_len - len(xx)), value=pad_id) for xx in x]).to(device)
    y = torch.stack([F.pad(yy, (0, max_len - len(yy)), value=pad_id) for yy in y]).to(device)
    return x, y
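# Padding sketch (hypothetical ids, pad_id=1): lines [5, 9, 2, <END>] and
# [7, 3, <END>] become
#   x = [[5, 9, 2], [7, 3, 1]]    y = [[9, 2, <END>], [3, <END>, 1]]
# and the pad positions contribute zero loss via ignore_index above.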


@torch.no_grad()
def estimate_loss_and_ppl(model, criterion):
    """Average loss and perplexity over eval_iters small batches per split."""
    result = {}
    model.eval()

    for split in ['train', 'valid']:
        std_losses = []
        stop_losses = []
        for _ in range(eval_iters):
            X, Y = get_batch(split, batch_size_override=4)
            logits, _ = model(X, Y)
            _, std_loss, s_loss = criterion(logits, Y)
            std_losses.append(std_loss)
            stop_losses.append(s_loss)
            del X, Y, logits
            if device == 'cuda':
                torch.cuda.empty_cache()
        avg_std = float(np.mean(std_losses)) if len(std_losses) > 0 else float('nan')
        avg_stop = float(np.mean(stop_losses)) if len(stop_losses) > 0 else 0.0
        # Cap the exponent: math.exp raises OverflowError for large losses.
        ppl = math.exp(min(avg_std, 700)) if math.isfinite(avg_std) else float('inf')
        result[f'{split}_loss'] = avg_std
        result[f'{split}_ppl'] = ppl
        result[f'{split}_stop_loss'] = avg_stop

    model.train()
    return result
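# Perplexity is exp(mean cross-entropy); e.g. avg_std = 4.0 gives
# ppl = e^4 ≈ 54.6, i.e. the model is on average as uncertain as a uniform
# choice among ~55 tokens.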


def save_model(model, optimizer, iteration, train_losses, valid_losses, train_ppls, valid_ppls, final=False):
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint = {
        'iteration': iteration,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict() if optimizer is not None else None,
        'train_losses': train_losses,
        'valid_losses': valid_losses,
        'train_ppls': train_ppls,
        'valid_ppls': valid_ppls,
        'vocab_size': vocab_size,
        'd_model': d_model,
        'h': h,
        'Nx': Nx,
        'dropout_rate': dropout_rate,
        'save_time': timestamp
    }
    if final:
        filename = f"{model_save_dir}/gpt_model_enhanced_stop_{timestamp}.pth"
    else:
        filename = f"{model_save_dir}/gpt_model_checkpoint_enhanced_stop_{timestamp}_iter_{iteration}.pth"
    torch.save(checkpoint, filename)
    print(f"Model saved to: {filename}")

    # Also dump the loss/PPL history as JSON for easy inspection.
    info_filename = f"{model_save_dir}/training_info_{timestamp}.json"
    with open(info_filename, 'w', encoding='utf-8') as f:
        json.dump({
            'timestamp': timestamp,
            'iteration': iteration,
            'train_losses': train_losses,
            'valid_losses': valid_losses,
            'train_ppls': train_ppls,
            'valid_ppls': valid_ppls
        }, f, indent=2, ensure_ascii=False)
    print(f"Training info saved to: {info_filename}")


def plot_training_curves(train_losses, valid_losses, train_ppls, valid_ppls):
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    iterations = list(range(0, len(train_losses) * eval_interval, eval_interval))
    ax1.plot(iterations, train_losses, label='Train Loss', color='blue', linewidth=2)
    ax1.plot(iterations, valid_losses, label='Validation Loss', color='red', linewidth=2)
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax2.plot(iterations, train_ppls, label='Train PPL', color='blue', linewidth=2)
    ax2.plot(iterations, valid_ppls, label='Validation PPL', color='red', linewidth=2)
    ax2.set_xlabel('Iterations')
    ax2.set_ylabel('Perplexity (PPL)')
    ax2.set_title('Training and Validation Perplexity')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'{model_save_dir}/training_curves_{timestamp}.png', dpi=300, bbox_inches='tight')
    plt.show()


def load_pretrained_model(pretrained_path="saved_models/gpt_model_enhanced_stop_20251003_200243.pth"):
    model = MemoryOptimizedBigramLM(
        vocab_size=vocab_size,
        d_model=d_model,
        max_seq_len=max_seq_len,
        h=h,
        Nx=Nx,
        dropout_rate=dropout_rate
    ).to(device)

    if pretrained_path is None or not os.path.exists(pretrained_path):
        print("No pretrained weights specified or found; starting from scratch.")
        return model, None

    try:
        # weights_only=False because the checkpoint also stores Python objects.
        checkpoint = torch.load(pretrained_path, map_location=device, weights_only=False)
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        # Keep only tensor entries, in case metadata was mixed into the dict.
        filtered_state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
        model.load_state_dict(filtered_state_dict, strict=False)
        print("✅ Loaded pretrained model weights")
        return model, checkpoint
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        print("Training from scratch instead...")
        return model, None
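# Note: strict=False only tolerates missing or unexpected keys. A shape
# mismatch (e.g. a checkpoint saved with a different d_model or vocab_size)
# still raises; it is caught by the except branch above and training falls
# back to random initialization.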


def test_end_rate(model, prompts, top_k=50, temp=0.9, max_new=300):
    """Fraction of prompts whose generation emits <END> within max_new tokens."""
    end_id = end_token_id
    cnt = 0
    for p in prompts:
        ids = torch.tensor([encode(p)], dtype=torch.long, device=device)
        gen = model.generate(ids, max_new, temperature=temp, top_k=top_k, eos_token_id=end_id)
        gen_ids = gen[0].tolist()
        if end_id in gen_ids:
            cnt += 1
    print(f"end_rate: {cnt}/{len(prompts)} = {cnt/len(prompts):.3f}")


def main(pretrained_path=None):
    model, pretrained_ckpt = load_pretrained_model(pretrained_path)

    # Zero the pad embedding so padding carries no signal.
    with torch.no_grad():
        if pad_id < model.token_embedding.weight.shape[0]:
            model.token_embedding.weight[pad_id].zero_()

    criterion = EnhancedLoss(end_token_id, stop_token_weight, pad_id=pad_id)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr_rate)

    if pretrained_ckpt and pretrained_ckpt.get('optimizer_state_dict') is not None:
        try:
            optimizer.load_state_dict(pretrained_ckpt['optimizer_state_dict'])
            print("✅ Loaded optimizer state")
        except Exception:
            print("⚠️ Could not load optimizer state (ignored)")

    train_losses, valid_losses, train_ppls, valid_ppls = [], [], [], []

    print("Starting enhanced stop-token training...")
    step = 0  # defined before the loop so the except branches can reference it
    try:
        for step in range(num_iter):
            if step % eval_interval == 0:
                if device == 'cuda':
                    torch.cuda.empty_cache()
                results = estimate_loss_and_ppl(model, criterion)
                train_losses.append(results['train_loss'])
                valid_losses.append(results['valid_loss'])
                train_ppls.append(results['train_ppl'])
                valid_ppls.append(results['valid_ppl'])
                print(f"step {step}: train_loss={results['train_loss']:.4f}, valid_loss={results['valid_loss']:.4f}, "
                      f"train_ppl={results['train_ppl']:.2f}, valid_ppl={results['valid_ppl']:.2f}")
                if results.get('train_stop_loss', 0.0) > 0:
                    print(f"  stop_loss={results['train_stop_loss']:.4f}")

            xb, yb = get_batch("train")
            logits, _ = model(xb, yb)
            loss, std_loss, stop_loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            if step % 100 == 0 and device == 'cuda':
                torch.cuda.empty_cache()

    except KeyboardInterrupt:
        print("\nTraining interrupted; saving current progress...")
        save_model(model, optimizer, step, train_losses, valid_losses, train_ppls, valid_ppls, final=False)
    except RuntimeError as e:
        print(f"\nRuntime error: {e}")
        save_model(model, optimizer, step, train_losses, valid_losses, train_ppls, valid_ppls, final=False)
        raise e

    plot_training_curves(train_losses, valid_losses, train_ppls, valid_ppls)
    save_model(model, optimizer, num_iter, train_losses, valid_losses, train_ppls, valid_ppls, final=True)

    # Quick stop-token check. The prompts stay in Chinese to match the training
    # data format ("keywords: ... poem:").
    print("\nTesting stop behavior (increase max_new for a stricter test):")
    test_prompts = [
        "请根据给定的关键词,创作一首诗句。\n关键词: 风 雾 寂寞\n诗词:",
        "请根据给定的关键词,创作一首诗句。\n关键词: 信 天涯 晚风\n诗词:",
        "请根据给定的关键词,创作一首诗句。\n关键词: 贴心 改变 自信\n诗词:"
    ]
    test_end_rate(model, test_prompts, top_k=50, temp=0.9, max_new=300)


if __name__ == "__main__":
    pretrained_checkpoint_path = "saved_models/gpt_model_enhanced_stop_20251003_200243.pth"
    main(pretrained_checkpoint_path)