import math
import os
import time
import random
from typing import Iterator, List
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from torch import amp
import sentencepiece as spm
from tqdm import tqdm
import itertools
print(f"--- PYTHON EXECUTING THIS FILE: {__file__} ---")
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None
# --- Configuration ---
class Config:
    # --- Model architecture ---
    vocab_size = 72000
    d_model = 1024
    n_heads = 4
    n_layers = 8
    d_ff = 4096
    seq_len = 1024
    dropout = 0.1
    # --- Training hardware and precision ---
    dtype = torch.float16
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # --- Learning-rate warm-restart configuration ---
    lr_max = 7e-6
    warmup_steps = 1000
    lr_min = 1e-6
    restart_warmup_steps = 200
    lr_restart_cycles = [45000, 2000, 3000]
    # --- Batch size and step counts ---
    batch_size = 5
    grad_accum_steps = 6
    max_steps = warmup_steps + sum(lr_restart_cycles) + len(lr_restart_cycles) * restart_warmup_steps
    save_every = 500
    out_dir = './checkpoints'
    # --- Other training parameters ---
    max_grad_norm = 1.0
    # --- Advanced regularization hyperparameters ---
    label_smoothing = 0.05
    distill_temp = 1.2
    distill_alpha = 0  # 0 disables the self-distillation term in the training loop
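# With the defaults above, max_steps evaluates to
# 1000 + (45000 + 2000 + 3000) + 3 * 200 = 51600 optimizer steps.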
# --- Tokenizer ---
class SPTokenizer:
    def __init__(self, model_file: str, seq_len: int):
        self.sp = spm.SentencePieceProcessor(model_file=model_file)
        self.seq_len = seq_len

    def encode(self, text: str, pad=True):
        """Encode text to token ids, truncating or padding to seq_len."""
        ids = self.sp.encode(text, out_type=int)
        if pad:
            if len(ids) > self.seq_len:
                ids = ids[:self.seq_len]
            else:
                pad_id = self.sp.pad_id() if self.sp.pad_id() != -1 else 0
                ids += [pad_id] * (self.seq_len - len(ids))
        return ids

    def decode(self, ids: List[int]):
        """Decode ids, dropping anything outside the valid vocabulary range."""
        valid_ids = [i for i in ids if 0 <= i < self.sp.vocab_size()]
        return self.sp.decode(valid_ids)
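# A minimal usage sketch (hypothetical path; assumes a trained SentencePiece
# model whose vocabulary matches Config.vocab_size):
#   tok = SPTokenizer(model_file='mymodel.model', seq_len=1024)
#   ids = tok.encode("example sentence")   # always length 1024 when pad=True
#   text = tok.decode(ids)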
# --- Dataset ---
class WeightedTextLineDataset(IterableDataset):
    def __init__(self, file_weights: dict[str, int], tokenizer: SPTokenizer, skip_lines: int = 0):
        super().__init__()
        self.tokenizer = tokenizer
        self.skip_lines = skip_lines
        if not file_weights:
            raise ValueError("The file-weight dictionary must not be empty.")
        self.filepaths = list(file_weights.keys())
        self.weights = list(file_weights.values())
        print("Weighted dataset initialized.")

    def _create_line_iterator(self, filepath):
        """Create an endlessly looping line generator for a single file."""
        while True:
            try:
                with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                    for line in f:
                        yield line
            except Exception as e:
                print(f"Warning: error while reading {filepath}: {e}")
                time.sleep(1)  # wait briefly before retrying

    def __iter__(self) -> Iterator[torch.Tensor]:
        # Create an independent, endlessly looping line iterator per file
        iterators = [self._create_line_iterator(fp) for fp in self.filepaths]

        # Build a single weighted line generator over all files
        def weighted_line_generator():
            while True:
                # Pick one file's iterator at random, according to the weights
                chosen_iterator = random.choices(iterators, weights=self.weights, k=1)[0]
                try:
                    # Pull the next line from the chosen iterator
                    yield next(chosen_iterator)
                except StopIteration:
                    # _create_line_iterator loops forever, so this should never happen
                    continue

        # Perform a one-off, efficient skip on the combined weighted iterator
        line_it = weighted_line_generator()
        if self.skip_lines > 0:
            print(f"Dataset: fast-skipping the first {self.skip_lines} lines (weighted)...")
            # islice is lazy: the skip is applied as the iterator is first consumed
            line_it = itertools.islice(line_it, self.skip_lines, None)
            print("Skip scheduled.")
            self.skip_lines = 0  # ensure later passes over the dataset do not skip again

        # Process and yield data from the correct position onward
        for line in line_it:
            line = line.strip()
            if line:
                tokens = self.tokenizer.encode(line)
                yield torch.tensor(tokens, dtype=torch.long)
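# For example (hypothetical files), {'a.txt': 3, 'b.txt': 1} draws roughly 75%
# of lines from a.txt and 25% from b.txt, since random.choices samples
# proportionally to the given weights.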
# --- Model definition ---
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMSNorm: x / sqrt(mean(x^2) + eps), followed by a learned per-channel scale
        norm = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return x / norm * self.scale
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, seq_len, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out = nn.Linear(d_model, d_model)
        # Causal mask: a lower-triangular (1, seq_len, seq_len) buffer
        self.register_buffer('mask', torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        att = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Note: attention-mask logic kept from the older version; it accepts
        # either a 3-D or a 4-D mask buffer and broadcasts over batch and heads.
        mask = self.mask[:, :, :T, :T] if self.mask.dim() == 4 else self.mask[:, :T, :T]
        att = att.masked_fill(mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        out = torch.matmul(att, v)
        out = out.transpose(1, 2).contiguous().reshape(B, T, C)
        return self.out(out)
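# On PyTorch >= 2.0, the matmul/softmax sequence above could be replaced by the
# fused F.scaled_dot_product_attention(q, k, v, is_causal=True), which applies
# the same causal masking without materializing the mask buffer. Kept as-is here.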
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, seq_len, dropout=0.0):
        super().__init__()
        self.attn = MultiHeadSelfAttention(d_model, n_heads, seq_len, dropout)
        self.norm1 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x):
        # Pre-norm residual layout: x + Attn(Norm(x)), then x + FF(Norm(x))
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x
class TinyDecoderModel(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        # Learned absolute positional embeddings
        self.pos_emb = nn.Parameter(torch.zeros(cfg.seq_len, cfg.d_model))
        self.layers = nn.ModuleList([
            TransformerBlock(cfg.d_model, cfg.n_heads, cfg.d_ff, cfg.seq_len, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])
        self.ln_f = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)

    def forward(self, idx):
        B, T = idx.size()
        x = self.tok_emb(idx) + self.pos_emb[:T]
        for layer in self.layers:
            x = layer(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
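# Rough parameter count with the default Config (back-of-the-envelope, untied
# embeddings): tok_emb and head are each 72000 * 1024 ~ 73.7M; each of the 8
# blocks holds ~12.6M (qkv 3.1M, out 1.0M, two FF projections ~8.4M), so the
# model totals roughly 250M parameters.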
# --- Learning-rate scheduler ---
def get_lr_with_cyclical_warmup(step: int, cfg: Config) -> float:
    # Initial linear warmup from 0 up to lr_max
    if step < cfg.warmup_steps:
        return cfg.lr_max * (step + 1) / cfg.warmup_steps
    effective_step = step - cfg.warmup_steps
    # Then a sequence of warm restarts: each cycle is a short linear re-warmup
    # from lr_min to lr_max followed by a cosine decay back down to lr_min.
    for cycle_len in cfg.lr_restart_cycles:
        full_cycle_len = cfg.restart_warmup_steps + cycle_len
        if effective_step < full_cycle_len:
            step_in_this_cycle = effective_step
            if step_in_this_cycle < cfg.restart_warmup_steps:
                return cfg.lr_min + (cfg.lr_max - cfg.lr_min) * (step_in_this_cycle / cfg.restart_warmup_steps)
            else:
                step_in_decay = step_in_this_cycle - cfg.restart_warmup_steps
                decay_ratio = step_in_decay / cycle_len
                coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
                return cfg.lr_min + coeff * (cfg.lr_max - cfg.lr_min)
        effective_step -= full_cycle_len
    # Past the final cycle, hold at lr_min
    return cfg.lr_min
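# A quick sanity check of the schedule shape (commented out; prints a handful
# of sampled learning rates if run standalone):
#   cfg = Config()
#   for s in (0, cfg.warmup_steps - 1, cfg.warmup_steps,
#             cfg.warmup_steps + cfg.restart_warmup_steps, cfg.max_steps - 1):
#       print(s, get_lr_with_cyclical_warmup(s, cfg))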
# --- Checkpointing ---
def save_checkpoint(model, optimizer, scaler, step, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f'ckpt_step{step}.pt')
    torch.save({
        'step': step,
        'model_state': model.state_dict(),
        'opt_state': optimizer.state_dict(),
        'scaler_state': scaler.state_dict(),
    }, path)
    print(f"\nCheckpoint saved: {path}")
# --- Main training loop ---
def train(file_weights: dict[str, int], cfg: Config, resume_from_ckpt=None, skip_lines=0):
    tokenizer = SPTokenizer(model_file=r"D:\a\uyghur-dictionary\mymodel.model", seq_len=cfg.seq_len)
    writer = SummaryWriter(log_dir=os.path.join('./logs', time.strftime("%Y%m%d-%H%M%S"))) if SummaryWriter else None
    model = TinyDecoderModel(cfg).to(cfg.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr_max)
    scaler = amp.GradScaler(enabled=(cfg.dtype == torch.float16))
    step = 0
    if resume_from_ckpt:
        print(f"Resuming from checkpoint: {resume_from_ckpt}")
        checkpoint = torch.load(resume_from_ckpt, map_location=cfg.device)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['opt_state'])
        if 'scaler_state' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state'])
        step = checkpoint['step']
        print(f"Resume successful; training will continue from step {step + 1}.")
    # Skip the lines already consumed, so the data stream resumes in place
    if step > 0 and skip_lines == 0:
        lines_per_step = cfg.batch_size * cfg.grad_accum_steps
        skip_lines = step * lines_per_step
    dataset = WeightedTextLineDataset(file_weights, tokenizer, skip_lines=skip_lines)
    dataloader = DataLoader(dataset, batch_size=cfg.batch_size, num_workers=0)
    data_iter = iter(dataloader)
    model.train()
    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm(range(step, cfg.max_steps), desc="Training Steps", initial=step, total=cfg.max_steps)
    for step in pbar:
        lr = get_lr_with_cyclical_warmup(step, cfg)
        for g in optimizer.param_groups:
            g['lr'] = lr
        accumulated_loss = 0.0
        for i in range(cfg.grad_accum_steps):
            pbar.set_description(f"Step {step+1}/{cfg.max_steps} [Accum. {i+1}/{cfg.grad_accum_steps}]")
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                batch = next(data_iter)
            batch = batch.to(cfg.device)
            with amp.autocast(device_type=cfg.device.type, dtype=cfg.dtype):
                logits = model(batch)
                # Next-token prediction: drop the last logit, shift labels by one
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = batch[:, 1:].contiguous()
                loss_hard = F.cross_entropy(
                    shift_logits.view(-1, cfg.vocab_size),
                    shift_labels.view(-1),
                    ignore_index=0,
                    label_smoothing=cfg.label_smoothing
                )
                if cfg.distill_temp > 1.0 and cfg.distill_alpha > 0.0:
                    # Self-distillation: the model's own temperature-softened
                    # distribution serves as a soft target
                    with torch.no_grad():
                        logits_teacher = shift_logits / cfg.distill_temp
                        q_soft_target = F.softmax(logits_teacher, dim=-1)
                    log_p_cold = F.log_softmax(shift_logits, dim=-1)
                    loss_distill = -torch.sum(q_soft_target * log_p_cold, dim=-1).mean()
                    loss = cfg.distill_alpha * loss_distill + (1.0 - cfg.distill_alpha) * loss_hard
                else:
                    loss = loss_hard
            accumulated_loss += loss.item()
            # Scale the loss down so gradients average over the accumulation window
            scaler.scale(loss / cfg.grad_accum_steps).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        avg_loss_this_step = accumulated_loss / cfg.grad_accum_steps
        if writer and math.isfinite(avg_loss_this_step):
            writer.add_scalar('Loss/Total_Loss', avg_loss_this_step, step)
            writer.add_scalar('Meta/Learning_Rate', lr, step)
            writer.flush()
        pbar.set_postfix(loss=f"{avg_loss_this_step:.4f}", lr=f"{lr:.2e}")
        if (step + 1) % cfg.save_every == 0:
            if writer:
                writer.flush()
            save_checkpoint(model, optimizer, scaler, step + 1, cfg.out_dir)
    if writer:
        writer.flush()
    save_checkpoint(model, optimizer, scaler, step + 1, cfg.out_dir)
    if writer:
        writer.close()
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    parser = argparse.ArgumentParser(description="Train the model with advanced regularization and warm-restart learning rates.")
    parser.add_argument('--resume_from_ckpt', type=str, default=None, help='Path to a checkpoint file to resume training from.')
    parser.add_argument('--skip_lines', type=int, default=0, help='Manually specify how many initial dataset lines to skip.')
    args = parser.parse_args()
    # --- [Core] Define the files and their sampling weights as a dictionary ---
    file_weights_map = {
    }
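    # Note: file_weights_map must be filled in before training; the dataset
    # raises a ValueError on an empty dictionary. A hypothetical example:
    #   file_weights_map = {'data/corpus_a.txt': 3, 'data/corpus_b.txt': 1}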
    if args.resume_from_ckpt and not os.path.exists(args.resume_from_ckpt):
        print(f"Error: checkpoint file '{args.resume_from_ckpt}' does not exist!")
        exit()
    cfg = Config()
    # --- [Core] Pass the dictionary and the correct argument names into train() ---
    train(file_weights=file_weights_map, cfg=cfg, resume_from_ckpt=args.resume_from_ckpt, skip_lines=args.skip_lines)