import math
import os
import time
import random
from typing import Iterator, List
import argparse
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from torch import amp
import sentencepiece as spm
from tqdm import tqdm

print(f"--- PYTHON EXECUTING THIS FILE: {__file__} ---")

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None


# --- Configuration ---
class Config:
    # --- Model architecture ---
    vocab_size = 72000
    d_model = 1024
    n_heads = 4
    n_layers = 8
    d_ff = 4096
    seq_len = 1024
    dropout = 0.1

    # --- Hardware and precision ---
    dtype = torch.float16
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- Learning-rate warm-restart schedule ---
    lr_max = 7e-6
    warmup_steps = 1000
    lr_min = 1e-6
    restart_warmup_steps = 200
    lr_restart_cycles = [45000, 2000, 3000]

    # --- Batching and step counts ---
    batch_size = 5
    grad_accum_steps = 6
    max_steps = warmup_steps + sum(lr_restart_cycles) + len(lr_restart_cycles) * restart_warmup_steps
    save_every = 500
    out_dir = './checkpoints'

    # --- Other training parameters ---
    max_grad_norm = 1.0

    # --- Regularization hyperparameters ---
    label_smoothing = 0.05
    distill_temp = 1.2
    distill_alpha = 0


# --- Tokenizer ---
class SPTokenizer:
    def __init__(self, model_file: str, seq_len: int):
        self.sp = spm.SentencePieceProcessor(model_file=model_file)
        self.seq_len = seq_len

    def encode(self, text: str, pad=True):
        ids = self.sp.encode(text, out_type=int)
        if pad:
            if len(ids) > self.seq_len:
                ids = ids[:self.seq_len]
            else:
                pad_id = self.sp.pad_id() if self.sp.pad_id() != -1 else 0
                ids += [pad_id] * (self.seq_len - len(ids))
        return ids

    def decode(self, ids: List[int]):
        valid_ids = [i for i in ids if 0 <= i < self.sp.vocab_size()]
        return self.sp.decode(valid_ids)


# --- Dataset ---
class WeightedTextLineDataset(IterableDataset):
    def __init__(self, file_weights: dict[str, int], tokenizer: SPTokenizer, skip_lines: int = 0):
        super().__init__()
        self.tokenizer = tokenizer
        self.skip_lines = skip_lines
        if not file_weights:
            raise ValueError("file_weights must not be empty.")
        self.filepaths = list(file_weights.keys())
        self.weights = list(file_weights.values())
        print("Weighted dataset initialized.")

    def _create_line_iterator(self, filepath):
        """Create an endlessly looping line generator for a single file."""
        while True:
            try:
                with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                    for line in f:
                        yield line
            except Exception as e:
                print(f"Warning: error while reading {filepath}: {e}")
                time.sleep(1)  # brief pause before retrying

    def __iter__(self) -> Iterator[torch.Tensor]:
        # One independent, endlessly looping line iterator per file
        iterators = [self._create_line_iterator(fp) for fp in self.filepaths]

        # A single weighted line generator over all files
        def weighted_line_generator():
            while True:
                # Pick one file's iterator at random, according to the weights
                chosen_iterator = random.choices(iterators, weights=self.weights, k=1)[0]
                try:
                    # Pull the next line from the chosen iterator
                    yield next(chosen_iterator)
                except StopIteration:
                    # _create_line_iterator loops forever, so this should never happen
                    continue

        # Perform a one-time, efficient skip on the combined weighted iterator
        # (islice defers the actual skipping until the first items are drawn)
        line_it = weighted_line_generator()
        if self.skip_lines > 0:
            print(f"Dataset: fast-skipping the first {self.skip_lines} lines (weighted)...")
            line_it = itertools.islice(line_it, self.skip_lines, None)
            print("Skip complete.")
            self.skip_lines = 0  # don't skip again on later passes over the dataset

        # From here on, tokenize and yield each non-empty line
        for line in line_it:
            line = line.strip()
            if line:
                tokens = self.tokenizer.encode(line)
                yield torch.tensor(tokens, dtype=torch.long)


# --- Model ---
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return x / norm * self.scale
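
# Note on RMSNorm: unlike LayerNorm, it normalizes only by the root mean
# square of the activations (no mean subtraction and no bias term), which is
# slightly cheaper and is the usual choice in pre-norm decoder blocks like
# the ones below.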
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, seq_len, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out = nn.Linear(d_model, d_model)
        # Causal mask buffer of shape (1, seq_len, seq_len)
        self.register_buffer('mask', torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        att = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Note: mask-slicing logic kept from the older version; the buffer is
        # 3-D here, so the (1, T, T) slice broadcasts over batch and heads.
        mask = self.mask[:, :, :T, :T] if self.mask.dim() == 4 else self.mask[:, :T, :T]
        att = att.masked_fill(mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        out = torch.matmul(att, v)
        out = out.transpose(1, 2).contiguous().reshape(B, T, C)
        return self.out(out)


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, seq_len, dropout=0.0):
        super().__init__()
        self.attn = MultiHeadSelfAttention(d_model, n_heads, seq_len, dropout)
        self.norm1 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x):
        # Pre-norm residual connections
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x


class TinyDecoderModel(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Parameter(torch.zeros(cfg.seq_len, cfg.d_model))
        self.layers = nn.ModuleList([
            TransformerBlock(cfg.d_model, cfg.n_heads, cfg.d_ff, cfg.seq_len, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])
        self.ln_f = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)

    def forward(self, idx):
        B, T = idx.size()
        x = self.tok_emb(idx) + self.pos_emb[:T]
        for layer in self.layers:
            x = layer(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits


# --- Learning-rate scheduler ---
def get_lr_with_cyclical_warmup(step: int, cfg: Config) -> float:
    # Initial linear warmup
    if step < cfg.warmup_steps:
        return cfg.lr_max * (step + 1) / cfg.warmup_steps
    effective_step = step - cfg.warmup_steps
    # Walk through the restart cycles; each cycle is a short re-warmup
    # followed by a cosine decay back down to lr_min
    for cycle_len in cfg.lr_restart_cycles:
        full_cycle_len = cfg.restart_warmup_steps + cycle_len
        if effective_step < full_cycle_len:
            step_in_this_cycle = effective_step
            if step_in_this_cycle < cfg.restart_warmup_steps:
                return cfg.lr_min + (cfg.lr_max - cfg.lr_min) * (step_in_this_cycle / cfg.restart_warmup_steps)
            else:
                step_in_decay = step_in_this_cycle - cfg.restart_warmup_steps
                decay_ratio = step_in_decay / cycle_len
                coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
                return cfg.lr_min + coeff * (cfg.lr_max - cfg.lr_min)
        effective_step -= full_cycle_len
    return cfg.lr_min


# --- Checkpointing ---
def save_checkpoint(model, optimizer, scaler, step, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f'ckpt_step{step}.pt')
    torch.save({
        'step': step,
        'model_state': model.state_dict(),
        'opt_state': optimizer.state_dict(),
        'scaler_state': scaler.state_dict(),
    }, path)
    print(f"\nCheckpoint saved: {path}")
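
# A minimal sketch for eyeballing the warm-restart schedule (an illustrative
# helper added here, not part of the original training flow; call it manually,
# e.g. from a REPL):
def _preview_lr_schedule(cfg: Config, every: int = 1000) -> None:
    for s in range(0, cfg.max_steps, every):
        print(f"step {s:>6d}: lr = {get_lr_with_cyclical_warmup(s, cfg):.2e}")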
# --- Main training loop ---
def train(file_weights: dict[str, int], cfg: Config, resume_from_ckpt=None, skip_lines=0):
    tokenizer = SPTokenizer(model_file=r"D:\a\uyghur-dictionary\mymodel.model", seq_len=cfg.seq_len)
    writer = SummaryWriter(log_dir=os.path.join('./logs', time.strftime("%Y%m%d-%H%M%S"))) if SummaryWriter else None
    model = TinyDecoderModel(cfg).to(cfg.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr_max)
    scaler = amp.GradScaler(enabled=(cfg.dtype == torch.float16))

    step = 0
    if resume_from_ckpt:
        print(f"Resuming from checkpoint: {resume_from_ckpt}")
        checkpoint = torch.load(resume_from_ckpt, map_location=cfg.device)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['opt_state'])
        if 'scaler_state' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state'])
        step = checkpoint['step']
        print(f"Resume successful; training continues from step {step + 1}.")

    # If no explicit skip was requested, estimate how many lines the resumed
    # run has already consumed and skip past them
    if step > 0 and skip_lines == 0:
        lines_per_step = cfg.batch_size * cfg.grad_accum_steps
        skip_lines = step * lines_per_step

    dataset = WeightedTextLineDataset(file_weights, tokenizer, skip_lines=skip_lines)
    dataloader = DataLoader(dataset, batch_size=cfg.batch_size, num_workers=0)
    data_iter = iter(dataloader)

    model.train()
    optimizer.zero_grad(set_to_none=True)
    start_time = time.time()
    initial_step = step

    pbar = tqdm(range(step, cfg.max_steps), desc="Training Steps", initial=step, total=cfg.max_steps)
    for step in pbar:
        lr = get_lr_with_cyclical_warmup(step, cfg)
        for g in optimizer.param_groups:
            g['lr'] = lr

        accumulated_loss = 0.0
        for i in range(cfg.grad_accum_steps):
            pbar.set_description(f"Step {step+1}/{cfg.max_steps} [Accum. {i+1}/{cfg.grad_accum_steps}]")
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                batch = next(data_iter)
            batch = batch.to(cfg.device)

            with amp.autocast(device_type=cfg.device.type, dtype=cfg.dtype):
                logits = model(batch)
                # Next-token prediction: shift logits and labels by one position
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = batch[:, 1:].contiguous()
                loss_hard = F.cross_entropy(
                    shift_logits.view(-1, cfg.vocab_size),
                    shift_labels.view(-1),
                    ignore_index=0,
                    label_smoothing=cfg.label_smoothing
                )
                if cfg.distill_temp > 1.0 and cfg.distill_alpha > 0.0:
                    # Self-distillation: the model's own temperature-softened
                    # distribution serves as the soft target
                    with torch.no_grad():
                        logits_teacher = shift_logits / cfg.distill_temp
                        q_soft_target = F.softmax(logits_teacher, dim=-1)
                    log_p_cold = F.log_softmax(shift_logits, dim=-1)
                    loss_distill = -torch.sum(q_soft_target * log_p_cold, dim=-1).mean()
                    loss = cfg.distill_alpha * loss_distill + (1.0 - cfg.distill_alpha) * loss_hard
                else:
                    loss = loss_hard

            accumulated_loss += loss.item()
            scaler.scale(loss / cfg.grad_accum_steps).backward()

        # Optimizer step once per accumulation window
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        avg_loss_this_step = accumulated_loss / cfg.grad_accum_steps
        if writer:
            if math.isfinite(avg_loss_this_step):
                writer.add_scalar('Loss/Total_Loss', avg_loss_this_step, step)
            writer.add_scalar('Meta/Learning_Rate', lr, step)
            writer.flush()
        pbar.set_postfix(loss=f"{avg_loss_this_step:.4f}", lr=f"{lr:.2e}")

        if (step + 1) % cfg.save_every == 0:
            if writer:
                writer.flush()
            save_checkpoint(model, optimizer, scaler, step + 1, cfg.out_dir)

    if writer:
        writer.flush()
    save_checkpoint(model, optimizer, scaler, step + 1, cfg.out_dir)
    if writer:
        writer.close()


if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    parser = argparse.ArgumentParser(description="Train the model with regularization and a warm-restart learning-rate schedule.")
    parser.add_argument('--resume_from_ckpt', type=str, default=None,
                        help='Path to a checkpoint file to resume training from.')
    parser.add_argument('--skip_lines', type=int, default=0,
                        help='Manually specify the number of initial dataset lines to skip.')
    args = parser.parse_args()

    # --- [Core] Define the files and their sampling weights in a dict ---
    file_weights_map = {
    }

    if args.resume_from_ckpt and not os.path.exists(args.resume_from_ckpt):
        print(f"Error: checkpoint file '{args.resume_from_ckpt}' does not exist!")
        raise SystemExit(1)

    cfg = Config()

    # --- [Core] Pass the dict and the matching argument names into train ---
    train(file_weights=file_weights_map,
          cfg=cfg,
          resume_from_ckpt=args.resume_from_ckpt,
          skip_lines=args.skip_lines)
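
# Example invocations (the script name is a placeholder; note that
# file_weights_map must be populated with {path: weight} entries before
# training will do anything, e.g. {"data/corpus_a.txt": 3, "data/corpus_b.txt": 1}):
#   python train.py
#   python train.py --resume_from_ckpt ./checkpoints/ckpt_step500.pt
#   python train.py --skip_lines 100000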