| | import math
|
| | import os
|
| | import time
|
| | import random
|
| | from typing import Iterator, List
|
| | import argparse
|
| | import concurrent.futures
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from torch.utils.data import IterableDataset, DataLoader
|
| | from torch import amp
|
| | import sentencepiece as spm
|
| | from tqdm import tqdm
|
| | import itertools
|
| |
|
print(f"--- PYTHON EXECUTING THIS FILE: {__file__} ---")

# TensorBoard is optional: when it is not installed, SummaryWriter is set to
# None and train() silently skips all logging.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None
|
| |
|
| |
|
class Config:
    """Hyperparameters and runtime settings for training."""

    # --- model architecture ---
    vocab_size = 72000
    d_model = 1024
    n_heads = 4
    n_layers = 8
    d_ff = 4096
    seq_len = 1024      # fixed sequence length; tokenizer pads/truncates to this
    dropout = 0.1

    # --- precision / device ---
    dtype = torch.float16   # autocast dtype; GradScaler is enabled iff fp16
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- learning-rate schedule: linear warmup + warm restarts ---
    lr_max = 7e-6
    warmup_steps = 1000
    lr_min = 1e-6
    restart_warmup_steps = 200          # linear re-warmup length at each restart
    lr_restart_cycles = [45000, 2000, 3000]  # cosine-decay length of each cycle

    # --- optimization / checkpointing ---
    batch_size = 5
    grad_accum_steps = 6
    # Total optimizer steps = initial warmup + every cycle's re-warmup + decay.
    max_steps = warmup_steps + sum(lr_restart_cycles) + len(lr_restart_cycles) * restart_warmup_steps
    save_every = 500
    out_dir = './checkpoints'

    max_grad_norm = 1.0   # gradient-clipping threshold (applied after unscale)

    # --- loss shaping ---
    label_smoothing = 0.05
    distill_temp = 1.2
    distill_alpha = 0     # 0 disables the self-distillation term in train()
|
| |
|
| |
|
class SPTokenizer:
    """Thin wrapper around a SentencePiece processor that produces
    fixed-length token-id sequences for the model."""

    def __init__(self, model_file: str, seq_len: int):
        self.sp = spm.SentencePieceProcessor(model_file=model_file)
        self.seq_len = seq_len

    def encode(self, text: str, pad=True):
        """Encode *text* to token ids.

        When *pad* is True (the default) the result is truncated or
        right-padded to exactly ``self.seq_len`` ids; otherwise the raw
        id list is returned unchanged.
        """
        ids = self.sp.encode(text, out_type=int)
        if pad:
            if len(ids) > self.seq_len:
                ids = ids[:self.seq_len]
            else:
                # SentencePiece reports -1 when no <pad> piece was trained;
                # fall back to id 0 (which the training loss also ignores).
                pad_id = self.sp.pad_id() if self.sp.pad_id() != -1 else 0
                # Rebind instead of += so the list from sp.encode is not
                # mutated in place.
                ids = ids + [pad_id] * (self.seq_len - len(ids))
        return ids

    def decode(self, ids: List[int]) -> str:
        """Decode ids back to text, dropping any id outside the vocabulary
        (e.g. sentinel values)."""
        # 'tok_id' rather than 'id' — avoid shadowing the builtin.
        valid_ids = [tok_id for tok_id in ids if 0 <= tok_id < self.sp.vocab_size()]
        return self.sp.decode(valid_ids)
|
| |
|
| |
|
class WeightedTextLineDataset(IterableDataset):
    """Infinite iterable dataset that samples lines from several text files
    in proportion to per-file integer weights and yields fixed-length
    token-id tensors."""

    def __init__(self, file_weights: dict[str, int], tokenizer: SPTokenizer, skip_lines: int = 0):
        super().__init__()
        self.tokenizer = tokenizer
        # Number of (weighted) lines to discard before yielding — used to
        # roughly fast-forward the stream when resuming from a checkpoint.
        self.skip_lines = skip_lines
        if not file_weights: raise ValueError("文件权重字典不能为空。")
        self.filepaths = list(file_weights.keys())
        self.weights = list(file_weights.values())
        print("加权数据集已初始化。")

    def _create_line_iterator(self, filepath):
        """Create an endlessly looping line generator for a single file."""
        while True:
            try:
                with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                    for line in f: yield line
            except Exception as e:
                # Keep the stream alive across transient I/O errors; retry
                # after a short pause.
                print(f"警告:读取文件 {filepath} 时出错: {e}")
                time.sleep(1)

    def __iter__(self) -> Iterator[torch.Tensor]:
        # One infinite iterator per source file.
        iterators = [self._create_line_iterator(fp) for fp in self.filepaths]

        def weighted_line_generator():
            # Draw the next line from a file chosen proportionally to its weight.
            while True:
                chosen_iterator = random.choices(iterators, weights=self.weights, k=1)[0]
                try:
                    yield next(chosen_iterator)
                except StopIteration:
                    # Should not happen (file iterators loop forever); skip
                    # this draw defensively if it ever does.
                    continue

        line_it = weighted_line_generator()
        if self.skip_lines > 0:
            print(f"数据集:正在快速跳过前 {self.skip_lines} 行(按权重分布)...")
            # NOTE(review): islice is lazy — the actual skipping happens on
            # the first iteration below, after the "done" message prints.
            line_it = itertools.islice(line_it, self.skip_lines, None)
            print("跳过完成。")
            # Only skip once per dataset instance, so a re-entered __iter__
            # does not fast-forward again.
            self.skip_lines = 0

        for line in line_it:
            line = line.strip()
            if line:
                tokens = self.tokenizer.encode(line)
                yield torch.tensor(tokens, dtype=torch.long)
|
| |
|
| |
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-channel gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMS over the last dimension; eps guards against division by zero.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        rms = (mean_square + self.eps).sqrt()
        return self.scale * (x / rms)
|
| |
|
class MultiHeadSelfAttention(nn.Module):
    """Causal (decoder-style) multi-head self-attention with a cached
    lower-triangular mask."""

    def __init__(self, d_model, n_heads, seq_len, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        # Fused projection for q, k and v in one matmul.
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out = nn.Linear(d_model, d_model)
        causal = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
        self.register_buffer('mask', causal)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch, tokens, width = x.size()
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        def to_heads(t):
            # (B, T, C) -> (B, n_heads, T, head_dim)
            return t.view(batch, tokens, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Accept both the 3-D buffer registered here and a 4-D variant
        # (e.g. from an older checkpoint); slice to the current length.
        if self.mask.dim() == 4:
            causal = self.mask[:, :, :tokens, :tokens]
        else:
            causal = self.mask[:, :tokens, :tokens]
        scores = scores.masked_fill(causal == 0, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)
        context = context.transpose(1, 2).contiguous().reshape(batch, tokens, width)
        return self.out(context)
|
| |
|
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        # Keep the nn.Sequential layout (parameter keys net.0 / net.2) so
        # existing checkpoints continue to load.
        stages = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
|
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: self-attention then feed-forward, each
    wrapped in a residual connection."""

    def __init__(self, d_model, n_heads, d_ff, seq_len, dropout=0.0):
        super().__init__()
        self.attn = MultiHeadSelfAttention(d_model, n_heads, seq_len, dropout)
        self.norm1 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x):
        # Attention sub-layer with residual, then MLP sub-layer with residual.
        x = x + self.attn(self.norm1(x))
        return x + self.ff(self.norm2(x))
|
| |
|
class TinyDecoderModel(nn.Module):
    """Small GPT-style decoder-only language model with learned positional
    embeddings and an untied output head."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Parameter(torch.zeros(cfg.seq_len, cfg.d_model))
        blocks = [
            TransformerBlock(cfg.d_model, cfg.n_heads, cfg.d_ff, cfg.seq_len, cfg.dropout)
            for _ in range(cfg.n_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.ln_f = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Small random init for positions (token embedding keeps its default).
        nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)

    def forward(self, idx):
        """Map (B, T) token ids to (B, T, vocab_size) logits."""
        _, seq = idx.size()
        x = self.tok_emb(idx) + self.pos_emb[:seq]
        for block in self.layers:
            x = block(x)
        return self.head(self.ln_f(x))
|
| |
|
| |
|
def get_lr_with_cyclical_warmup(step: int, cfg: Config) -> float:
    """Learning-rate schedule: linear warmup to lr_max, then a sequence of
    warm-restart cycles (short linear re-warmup followed by cosine decay),
    ending at cfg.lr_min once every cycle is exhausted."""
    if step < cfg.warmup_steps:
        # Initial linear warmup from ~0 to lr_max.
        return cfg.lr_max * (step + 1) / cfg.warmup_steps

    remaining = step - cfg.warmup_steps
    for decay_len in cfg.lr_restart_cycles:
        cycle_total = cfg.restart_warmup_steps + decay_len
        if remaining >= cycle_total:
            # Not in this cycle yet — consume it and look at the next one.
            remaining -= cycle_total
            continue
        if remaining < cfg.restart_warmup_steps:
            # Linear re-warmup from lr_min back up to lr_max.
            return cfg.lr_min + (cfg.lr_max - cfg.lr_min) * (remaining / cfg.restart_warmup_steps)
        # Cosine decay from lr_max down to lr_min over the rest of the cycle.
        progress = (remaining - cfg.restart_warmup_steps) / decay_len
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return cfg.lr_min + cosine * (cfg.lr_max - cfg.lr_min)
    # All cycles consumed: hold the floor LR.
    return cfg.lr_min
|
| |
|
| |
|
| | def save_checkpoint(model, optimizer, scaler, step, out_dir):
|
| | os.makedirs(out_dir, exist_ok=True)
|
| | path = os.path.join(out_dir, f'ckpt_step{step}.pt')
|
| | torch.save({
|
| | 'step': step,
|
| | 'model_state': model.state_dict(),
|
| | 'opt_state': optimizer.state_dict(),
|
| | 'scaler_state': scaler.state_dict(),
|
| | }, path)
|
| | print(f"\n检查点已保存: {path}")
|
| |
|
| |
|
def train(file_weights: dict[str, int], cfg: Config, resume_from_ckpt=None, skip_lines=0):
    """Run the training loop: build tokenizer/model/optimizer, optionally
    resume from a checkpoint, then train with gradient accumulation, AMP,
    gradient clipping and the cyclical warm-restart LR schedule.

    Args:
        file_weights: mapping of text-file path -> integer sampling weight.
        cfg: hyperparameter/configuration object.
        resume_from_ckpt: optional checkpoint path to restore state from.
        skip_lines: dataset lines to fast-forward (0 = derive from the
            resumed step count).
    """
    # NOTE(review): tokenizer model path is hard-coded to a local Windows
    # path — consider making this a CLI argument.
    tokenizer = SPTokenizer(model_file=r"D:\a\uyghur-dictionary\mymodel.model", seq_len=cfg.seq_len)
    writer = SummaryWriter(log_dir=os.path.join('./logs', time.strftime("%Y%m%d-%H%M%S"))) if SummaryWriter else None
    model = TinyDecoderModel(cfg).to(cfg.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr_max)
    # Loss scaling is only needed (and only enabled) for fp16 autocast.
    scaler = amp.GradScaler(enabled=(cfg.dtype == torch.float16))
    step = 0

    if resume_from_ckpt:
        print(f"正在从检查点恢复: {resume_from_ckpt}")
        checkpoint = torch.load(resume_from_ckpt, map_location=cfg.device)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['opt_state'])
        # Older checkpoints may predate scaler state; load it only if present.
        if 'scaler_state' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state'])
        step = checkpoint['step']
        print(f"恢复成功,将从 step {step + 1} 继续训练。")

    # When resuming without an explicit skip count, estimate how many data
    # lines the completed steps already consumed.
    if step > 0 and skip_lines == 0:
        lines_per_step = cfg.batch_size * cfg.grad_accum_steps
        skip_lines = step * lines_per_step

    dataset = WeightedTextLineDataset(file_weights, tokenizer, skip_lines=skip_lines)
    dataloader = DataLoader(dataset, batch_size=cfg.batch_size, num_workers=0)
    data_iter = iter(dataloader)

    model.train()
    optimizer.zero_grad(set_to_none=True)
    start_time = time.time()
    initial_step = step

    pbar = tqdm(range(step, cfg.max_steps), desc="Training Steps", initial=step, total=cfg.max_steps)
    for step in pbar:
        # Set this step's LR on every param group.
        lr = get_lr_with_cyclical_warmup(step, cfg)
        for g in optimizer.param_groups:
            g['lr'] = lr

        accumulated_loss = 0.0
        for i in range(cfg.grad_accum_steps):
            pbar.set_description(f"Step {step+1}/{cfg.max_steps} [Accum. {i+1}/{cfg.grad_accum_steps}]")
            try:
                batch = next(data_iter)
            except StopIteration:
                # The dataset is effectively infinite, but restart defensively.
                data_iter = iter(dataloader)
                batch = next(data_iter)
            batch = batch.to(cfg.device)

            with amp.autocast(device_type=cfg.device.type.replace(':', ''), dtype=cfg.dtype):
                logits = model(batch)
                # Next-token prediction: logits at position t predict token t+1.
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = batch[:, 1:].contiguous()

                loss_hard = F.cross_entropy(
                    shift_logits.view(-1, cfg.vocab_size),
                    shift_labels.view(-1),
                    ignore_index=0,  # padding id (see SPTokenizer.encode fallback)
                    label_smoothing=cfg.label_smoothing
                )

                # Optional self-distillation against the model's own
                # temperature-softened outputs; inert while distill_alpha == 0.
                if cfg.distill_temp > 1.0 and cfg.distill_alpha > 0.0:
                    with torch.no_grad():
                        logits_teacher = shift_logits / cfg.distill_temp
                        q_soft_target = F.softmax(logits_teacher, dim=-1)

                    log_p_cold = F.log_softmax(shift_logits, dim=-1)
                    loss_distill = -torch.sum(q_soft_target * log_p_cold, dim=-1).mean()
                    loss = cfg.distill_alpha * loss_distill + (1.0 - cfg.distill_alpha) * loss_hard
                else:
                    loss = loss_hard

            accumulated_loss += loss.item()
            # Divide so the accumulated gradient averages over micro-batches.
            scaler.scale(loss / cfg.grad_accum_steps).backward()

        # Unscale before clipping so the norm is computed on true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        if writer:
            avg_loss_this_step = accumulated_loss / cfg.grad_accum_steps
            # Skip logging NaN/inf losses (e.g. from fp16 overflow steps).
            if math.isfinite(avg_loss_this_step):
                writer.add_scalar('Loss/Total_Loss', avg_loss_this_step, step)
            writer.add_scalar('Meta/Learning_Rate', lr, step)
            writer.flush()

            pbar.set_postfix(loss=f"{avg_loss_this_step:.4f}", lr=f"{lr:.2e}")

        if (step + 1) % cfg.save_every == 0:
            if writer:
                writer.flush()
            save_checkpoint(model, optimizer, scaler, step + 1, cfg.out_dir)

    # Final checkpoint after the loop finishes.
    if writer:
        writer.flush()
    save_checkpoint(model, optimizer, scaler, step + 1, cfg.out_dir)
    if writer:
        writer.close()
|
| |
|
| | if __name__ == "__main__":
|
| | if torch.cuda.is_available():
|
| | torch.cuda.empty_cache()
|
| |
|
| | parser = argparse.ArgumentParser(description="使用高级正则化和热重启学习率进行模型训练。")
|
| | parser.add_argument('--resume_from_ckpt', type=str, default=None, help='指定要恢复训练的检查点文件路径。')
|
| | parser.add_argument('--skip_lines', type=int, default=0, help='手动指定数据集要跳过的初始行数。')
|
| | args = parser.parse_args()
|
| |
|
| |
|
| | file_weights_map = {
|
| | }
|
| |
|
| | if args.resume_from_ckpt and not os.path.exists(args.resume_from_ckpt):
|
| | print(f"错误: 检查点文件 '{args.resume_from_ckpt}' 不存在!")
|
| | exit()
|
| |
|
| | cfg = Config()
|
| |
|
| |
|
| | train(file_weights=file_weights_map, cfg=cfg, resume_from_ckpt=args.resume_from_ckpt, skip_lines=args.skip_lines) |