File size: 14,875 Bytes
25b1d8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
import math
import os
import time
import random
from typing import Iterator, List
import argparse
import concurrent.futures

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from torch import amp
import sentencepiece as spm
from tqdm import tqdm
import itertools

# Sanity print so logs show exactly which script file is executing.
print(f"--- PYTHON EXECUTING THIS FILE: {__file__} ---")

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    # TensorBoard is optional; training proceeds without logging when absent.
    SummaryWriter = None

# --- Configuration ---
class Config:
    """Hyperparameters for the model, training schedule, and regularization."""
    # --- Model architecture ---
    vocab_size = 72000   # SentencePiece vocabulary size
    d_model = 1024       # embedding / hidden width
    n_heads = 4          # attention heads (head_dim = d_model // n_heads = 256)
    n_layers = 8         # number of transformer blocks
    d_ff = 4096          # feed-forward inner width
    seq_len = 1024       # fixed context length; tokenizer pads/truncates to this
    dropout = 0.1
    # --- Training hardware and precision ---
    dtype = torch.float16
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # --- Learning-rate warm-restart schedule ---
    lr_max = 7e-6
    warmup_steps = 1000
    lr_min = 1e-6
    restart_warmup_steps = 200               # short re-warmup at each restart
    lr_restart_cycles = [45000, 2000, 3000]  # cosine-decay length of each cycle
    # --- Batch size and step count ---
    batch_size = 5
    grad_accum_steps = 6
    # Total optimizer steps = initial warmup + all decay cycles + their re-warmups.
    max_steps = warmup_steps + sum(lr_restart_cycles) + len(lr_restart_cycles) * restart_warmup_steps
    save_every = 500
    out_dir = './checkpoints'
    # --- Other training parameters ---
    max_grad_norm = 1.0
    # --- Advanced regularization hyperparameters ---
    label_smoothing = 0.05
    distill_temp = 1.2
    distill_alpha = 0    # self-distillation weight; 0 disables the distill branch

# --- Tokenizer class ---
class SPTokenizer:
    """Thin wrapper around a SentencePiece model that pads/truncates to a fixed length."""

    def __init__(self, model_file: str, seq_len: int):
        self.sp = spm.SentencePieceProcessor(model_file=model_file)
        self.seq_len = seq_len

    def encode(self, text: str, pad=True):
        """Encode *text* to ids; when *pad* is true, truncate or right-pad to seq_len."""
        token_ids = self.sp.encode(text, out_type=int)
        if not pad:
            return token_ids
        if len(token_ids) > self.seq_len:
            return token_ids[:self.seq_len]
        # Fall back to id 0 when the model defines no explicit <pad> token.
        pad_id = 0 if self.sp.pad_id() == -1 else self.sp.pad_id()
        return token_ids + [pad_id] * (self.seq_len - len(token_ids))

    def decode(self, ids: List[int]):
        """Decode ids back to text, dropping anything outside the vocabulary range."""
        in_vocab = [i for i in ids if 0 <= i < self.sp.vocab_size()]
        return self.sp.decode(in_vocab)

# --- Dataset class ---
class WeightedTextLineDataset(IterableDataset):
    """Endless iterable dataset that samples lines from several text files
    according to per-file integer weights and yields fixed-length token tensors."""

    def __init__(self, file_weights: dict[str, int], tokenizer: SPTokenizer, skip_lines: int = 0):
        super().__init__()
        self.tokenizer = tokenizer
        self.skip_lines = skip_lines
        if not file_weights:
            raise ValueError("文件权重字典不能为空。")
        self.filepaths = list(file_weights.keys())
        self.weights = list(file_weights.values())
        print("加权数据集已初始化。")

    def _create_line_iterator(self, filepath):
        """Yield the lines of *filepath* forever, reopening the file at EOF."""
        while True:
            try:
                with open(filepath, 'r', encoding='utf-8', errors='ignore') as fh:
                    yield from fh
            except Exception as e:
                print(f"警告:读取文件 {filepath} 时出错: {e}")
                time.sleep(1)  # brief pause before retrying

    def __iter__(self) -> Iterator[torch.Tensor]:
        # One independent, infinitely looping line iterator per source file.
        per_file_iters = [self._create_line_iterator(fp) for fp in self.filepaths]

        def sample_lines():
            # Pick a file iterator by weight, then take its next line.
            while True:
                source = random.choices(per_file_iters, weights=self.weights, k=1)[0]
                try:
                    yield next(source)
                except StopIteration:
                    # Unreachable in practice: the per-file iterators never end.
                    continue

        stream = sample_lines()
        if self.skip_lines > 0:
            # Fast-forward once past already-consumed lines (e.g. after a resume).
            # NOTE(review): islice is lazy — the skip actually happens as the
            # stream is first consumed, not at this point.
            print(f"数据集:正在快速跳过前 {self.skip_lines} 行(按权重分布)...")
            stream = itertools.islice(stream, self.skip_lines, None)
            print("跳过完成。")
            self.skip_lines = 0  # never skip again on later passes over the dataset

        # Tokenize non-empty lines and emit fixed-length LongTensors.
        for raw in stream:
            raw = raw.strip()
            if raw:
                yield torch.tensor(self.tokenizer.encode(raw), dtype=torch.long)

# --- Model definition ---
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: x / rms(x) scaled by a learned per-dim gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMS over the last dimension; eps keeps the sqrt away from zero.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.scale * (x / rms)

class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with a precomputed lower-triangular mask."""

    def __init__(self, d_model, n_heads, seq_len, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        # Fused projection producing queries, keys, and values in one matmul.
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out = nn.Linear(d_model, d_model)
        # (1, seq_len, seq_len) causal mask, stored as a (non-trainable) buffer.
        self.register_buffer('mask', torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape

        def split_heads(t):
            # (B, T, C) -> (B, n_heads, T, head_dim)
            return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.qkv(x).chunk(3, dim=-1))
        scores = q @ k.transpose(-2, -1) * self.scale
        # Support either a 3-D or 4-D stored mask; slice to the current length.
        causal = self.mask[:, :, :T, :T] if self.mask.dim() == 4 else self.mask[:, :T, :T]
        scores = scores.masked_fill(causal == 0, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        context = (weights @ v).transpose(1, 2).contiguous().reshape(B, T, C)
        return self.out(context)

class FeedForward(nn.Module):
    """Position-wise two-layer MLP (GELU activation) with dropout on the output."""

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        # Kept as a Sequential named `net` so checkpoint state-dict keys match.
        layers = [
            nn.Linear(d_model, d_ff),   # expand
            nn.GELU(),
            nn.Linear(d_ff, d_model),   # project back
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: x + attn(norm(x)), then x + ffn(norm(x))."""

    def __init__(self, d_model, n_heads, d_ff, seq_len, dropout=0.0):
        super().__init__()
        # Submodule names (attn/norm1/ff/norm2) kept for checkpoint compatibility.
        self.attn = MultiHeadSelfAttention(d_model, n_heads, seq_len, dropout)
        self.norm1 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x):
        # Residual around attention, then residual around the feed-forward net.
        h = x + self.attn(self.norm1(x))
        return h + self.ff(self.norm2(h))

class TinyDecoderModel(nn.Module):
    """Decoder-only transformer LM with learned absolute positional embeddings."""

    def __init__(self, cfg: Config):
        super().__init__()
        # Parameter creation order is preserved so seeded inits stay identical.
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Parameter(torch.zeros(cfg.seq_len, cfg.d_model))
        self.layers = nn.ModuleList([
            TransformerBlock(cfg.d_model, cfg.n_heads, cfg.d_ff, cfg.seq_len, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])
        self.ln_f = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)

    def forward(self, idx):
        seq_len = idx.size(1)
        # Token embedding plus the first seq_len positional vectors.
        hidden = self.tok_emb(idx) + self.pos_emb[:seq_len]
        for block in self.layers:
            hidden = block(hidden)
        return self.head(self.ln_f(hidden))

# --- Learning-rate scheduler ---
def get_lr_with_cyclical_warmup(step: int, cfg: Config) -> float:
    """Learning rate at *step*: linear warmup to lr_max, then repeated
    (linear re-warmup + cosine decay) cycles, and lr_min once all cycles end."""
    # Initial linear warmup from ~0 up to lr_max.
    if step < cfg.warmup_steps:
        return cfg.lr_max * (step + 1) / cfg.warmup_steps

    remaining = step - cfg.warmup_steps
    for decay_len in cfg.lr_restart_cycles:
        cycle_total = cfg.restart_warmup_steps + decay_len
        if remaining >= cycle_total:
            # This cycle is already finished; move on to the next one.
            remaining -= cycle_total
            continue
        if remaining < cfg.restart_warmup_steps:
            # Short linear re-warmup from lr_min back toward lr_max.
            return cfg.lr_min + (cfg.lr_max - cfg.lr_min) * (remaining / cfg.restart_warmup_steps)
        # Cosine decay from lr_max down to lr_min over the rest of the cycle.
        progress = (remaining - cfg.restart_warmup_steps) / decay_len
        coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
        return cfg.lr_min + coeff * (cfg.lr_max - cfg.lr_min)
    # All cycles exhausted: hold at the floor.
    return cfg.lr_min

# --- 检查点函数 ---
def save_checkpoint(model, optimizer, scaler, step, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f'ckpt_step{step}.pt')
    torch.save({
        'step': step,
        'model_state': model.state_dict(),
        'opt_state': optimizer.state_dict(),
        'scaler_state': scaler.state_dict(),
    }, path)
    print(f"\n检查点已保存: {path}")

# --- Main training function ---
def train(file_weights: dict[str, int], cfg: Config, resume_from_ckpt=None, skip_lines=0):
    """Train TinyDecoderModel on weighted line-based text data.

    Args:
        file_weights: mapping of text-file path -> integer sampling weight.
        cfg: hyperparameter bundle (see Config).
        resume_from_ckpt: optional checkpoint path; restores model, optimizer,
            scaler, and step counter.
        skip_lines: lines to skip at the start of the dataset. When 0 and
            resuming, it is derived from the restored step so the data stream
            roughly lines up with where training left off.
    """
    # NOTE(review): tokenizer model path is hard-coded; consider making it configurable.
    tokenizer = SPTokenizer(model_file=r"D:\a\uyghur-dictionary\mymodel.model", seq_len=cfg.seq_len)
    writer = SummaryWriter(log_dir=os.path.join('./logs', time.strftime("%Y%m%d-%H%M%S"))) if SummaryWriter else None
    model = TinyDecoderModel(cfg).to(cfg.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr_max)
    scaler = amp.GradScaler(enabled=(cfg.dtype == torch.float16))
    step = 0

    if resume_from_ckpt:
        print(f"正在从检查点恢复: {resume_from_ckpt}")
        checkpoint = torch.load(resume_from_ckpt, map_location=cfg.device)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['opt_state'])
        if 'scaler_state' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state'])
        step = checkpoint['step']
        print(f"恢复成功,将从 step {step + 1} 继续训练。")

    # When resuming without an explicit skip, estimate consumed lines from the step count.
    if step > 0 and skip_lines == 0:
        lines_per_step = cfg.batch_size * cfg.grad_accum_steps
        skip_lines = step * lines_per_step

    dataset = WeightedTextLineDataset(file_weights, tokenizer, skip_lines=skip_lines)
    dataloader = DataLoader(dataset, batch_size=cfg.batch_size, num_workers=0)
    data_iter = iter(dataloader)

    model.train()
    optimizer.zero_grad(set_to_none=True)

    pbar = tqdm(range(step, cfg.max_steps), desc="Training Steps", initial=step, total=cfg.max_steps)
    for step in pbar:
        # Cyclical warm-restart schedule, applied to every param group.
        lr = get_lr_with_cyclical_warmup(step, cfg)
        for g in optimizer.param_groups:
            g['lr'] = lr

        accumulated_loss = 0.0
        for i in range(cfg.grad_accum_steps):
            pbar.set_description(f"Step {step+1}/{cfg.max_steps} [Accum. {i+1}/{cfg.grad_accum_steps}]")
            try:
                batch = next(data_iter)
            except StopIteration:
                # The dataset is effectively infinite, but restart the loader just in case.
                data_iter = iter(dataloader)
                batch = next(data_iter)
            batch = batch.to(cfg.device)

            with amp.autocast(device_type=cfg.device.type.replace(':', ''), dtype=cfg.dtype):
                logits = model(batch)
                # Next-token prediction: drop the last logit, shift labels by one.
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = batch[:, 1:].contiguous()

                loss_hard = F.cross_entropy(
                    shift_logits.view(-1, cfg.vocab_size),
                    shift_labels.view(-1),
                    ignore_index=0,  # pad id 0 is excluded from the loss
                    label_smoothing=cfg.label_smoothing
                )

                if cfg.distill_temp > 1.0 and cfg.distill_alpha > 0.0:
                    # Self-distillation: the temperature-softened distribution of
                    # the model's own logits serves as the soft target.
                    with torch.no_grad():
                        logits_teacher = shift_logits / cfg.distill_temp
                        q_soft_target = F.softmax(logits_teacher, dim=-1)

                    log_p_cold = F.log_softmax(shift_logits, dim=-1)
                    loss_distill = -torch.sum(q_soft_target * log_p_cold, dim=-1).mean()
                    loss = cfg.distill_alpha * loss_distill + (1.0 - cfg.distill_alpha) * loss_hard
                else:
                    loss = loss_hard

            accumulated_loss += loss.item()
            # Scale down so the accumulated gradient matches one large batch.
            scaler.scale(loss / cfg.grad_accum_steps).backward()

        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        # BUGFIX: compute the step-average loss unconditionally. It was previously
        # assigned only inside `if writer:`, so pbar.set_postfix below raised a
        # NameError whenever tensorboard was not installed (writer is None).
        avg_loss_this_step = accumulated_loss / cfg.grad_accum_steps
        if writer and math.isfinite(avg_loss_this_step):
            writer.add_scalar('Loss/Total_Loss', avg_loss_this_step, step)
            writer.add_scalar('Meta/Learning_Rate', lr, step)
            writer.flush()

        pbar.set_postfix(loss=f"{avg_loss_this_step:.4f}", lr=f"{lr:.2e}")

        if (step + 1) % cfg.save_every == 0:
            if writer:
                writer.flush()
            save_checkpoint(model, optimizer, scaler, step + 1, cfg.out_dir)

    # Final checkpoint and log flush.
    if writer:
        writer.flush()
    save_checkpoint(model, optimizer, scaler, step + 1, cfg.out_dir)
    if writer:
        writer.close()

if __name__ == "__main__":
    # Release any stale cached CUDA memory before building the model.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    parser = argparse.ArgumentParser(description="使用高级正则化和热重启学习率进行模型训练。")
    parser.add_argument('--resume_from_ckpt', type=str, default=None, help='指定要恢复训练的检查点文件路径。')
    parser.add_argument('--skip_lines', type=int, default=0, help='手动指定数据集要跳过的初始行数。')
    args = parser.parse_args()
    
    # --- [Core] Map each training file path to its integer sampling weight ---
    # NOTE(review): this map is empty, so WeightedTextLineDataset will raise
    # ValueError at startup — populate {path: weight} entries before running.
    file_weights_map = {
    }
    
    if args.resume_from_ckpt and not os.path.exists(args.resume_from_ckpt):
        print(f"错误: 检查点文件 '{args.resume_from_ckpt}' 不存在!")
        exit()
        
    cfg = Config()
    
    # --- [Core] Hand the weight map and parsed CLI arguments to train() ---
    train(file_weights=file_weights_map, cfg=cfg, resume_from_ckpt=args.resume_from_ckpt, skip_lines=args.skip_lines)